Aho-Corasick

Spaghetti Sourceの複数パターン検索 (Aho-Corasick)で紹介されているスライドを参考にして、書かれている疑似コードをほぼそのままC++で書き直してみたもの。

Trieのvalues_フィールドが保持しているのは、そのノードに到達した時点でマッチしている単語の集合。ここではvectorになっているが、速度が要求されるならばvectorにして添字を保存しても良いし、単語数が少ないならば情報をビットに詰め込んでしまっても良い。実際UVA 11019 - Matrix Matcherを解く際には、単語数が100以下と少なく、また速度が要求されていたため、long long [2]とした(その場合下記の和集合を求める処理が論理和で済んでしまう)。

スライドのp14にあるout(u) := out(u) ∪ out(f (u));に相当する処理は

          for (int i = 0; i < u->fail_->values_.size(); ++i) {
            u->values_.push_back(u->fail_->values_[i]);
          }

となっている。これは少々重たい気がする。実際この処理を飛ばしてMatchメソッドを

  void Match(const string& text) {
    Trie* current = root_;
    for (int i = 0; i < text.size(); ++i) {
      int c = text[i];
      while (!current->edge_[c]) {
        current = current->fail_;
      }
      current = current->edge_[c];
      // 次の行が増えた
      for (Trie* t = current; current != current->fail_; current = current->fail_) {
        for (int i = 0; i < t->values_.size(); ++i) {
          cout << t->values_[i] << endl;
        }
      }
    }
  }

のように書き換えても動作は変わらない。これは何度か試してみたのだけれど、遅くなるだけだったのでやめた。オートマトン構築のコストより検索のコストがボトルネックになる事が多いので、そこは気にしないで良いのかも。

#include <iostream>
#include <vector>
#include <queue>
#define MAX 256

using namespace std;

class Aho {
 public:
  Aho(const vector<string>& v) {
    root_ = new Trie();
    for (int i = 0; i < v.size(); ++i) {
      root_->insert(v[i]);
    }

    for (int i = 0; i < MAX; ++i) {
      if (!root_->edge_[i]) {
        root_->edge_[i] = root_;
      }
    }
    root_->fail_ = root_;

    queue<Trie*> que;
    for (int i = 0; i < MAX; ++i) {
      Trie* q = root_->edge_[i];
      if (q != root_) {
        q->fail_ = root_;
        que.push(q);
      }
    }
    while (!que.empty()) {
      Trie* r = que.front(); que.pop();
      for (int i = 0; i < MAX; ++i) {
        Trie* u = r->edge_[i];
        if (u) {
          que.push(u);
          Trie* v = r->fail_;
          while (!v->edge_[i]) {
            v = v->fail_;
          }
          u->fail_ = v->edge_[i];
          for (int i = 0; i < u->fail_->values_.size(); ++i) {
            u->values_.push_back(u->fail_->values_[i]);
          }
        }
      }
    }
  }

  ~Aho() { delete root_; }

  void Match(const string& text) {
    Trie* current = root_;
    for (int i = 0; i < text.size(); ++i) {
      int c = text[i];
      while (!current->edge_[c]) {
        current = current->fail_;
      }
      current = current->edge_[c];
      for (int i = 0; i < current->values_.size(); ++i) {
        cout << current->values_[i] << endl;
      }
    }
  }

 private:

  class Trie {
   public:
    Trie() {
      for (int i = 0; i < MAX; ++i) {
        edge_[i] = NULL;
      }
      fail_ = NULL;
    }

    ~Trie() {
      for (int i = 0; i < MAX; ++i) {
        if (edge_[i] && edge_[i] != this) {
          delete edge_[i];
        }
      }
    }

    void insert(const string& s) {
      Trie* t = this;
      for (int i = 0; i < s.size(); ++i) {
        int c = s[i];
        if (!t->edge_[c]) t->edge_[c] = new Trie();
        t = t->edge_[c];
      }
      t->values_.push_back(s);
    }

   private:
    friend class Aho;
    Trie* edge_[MAX];
    Trie* fail_;
    vector<string> values_;
  };

  Trie* root_;
};

int main() {
  vector<string> dictionary;
  dictionary.push_back("he");
  dictionary.push_back("she");
  dictionary.push_back("his");
  dictionary.push_back("hers");

  Aho aho(dictionary);
  string text = "ushers";
  aho.Match(text);
}