library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub kk2a/library

✔ string/aho_corasick.hpp

Depends on: data_structure/trie.hpp

Verified with

Code

#ifndef KK2_STRING_AHO_CORASICK_HPP
#define KK2_STRING_AHO_CORASICK_HPP 1

#include <algorithm>
#include <cassert>
#include <queue>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include "../data_structure/trie.hpp"

namespace kk2 {

/// Aho-Corasick automaton built on top of Trie.
///
/// The underlying Trie has one extra slot per node: slot FAIL (== char_size)
/// stores the failure link. After build(), slots [0, char_size) hold the
/// complete goto function, so they are never -1.
///
/// Usage: add() every pattern, call build() exactly once, then run queries.
/// Adding patterns after build() is not supported.
///
/// @tparam char_size number of distinct characters (character c maps to c - margin)
/// @tparam margin    character that maps to index 0
template <int char_size, int margin> struct AhoCorasick : Trie<char_size + 1, margin> {
    using Trie<char_size + 1, margin>::Trie;
    using Trie<char_size + 1, margin>::count;

    // Index of the failure-link slot inside each node's nxt array.
    constexpr static int FAIL = char_size;
    // correct[v]: number of pattern ids accepted at v or at any node on v's
    //             failure chain (filled by build()).
    // perm:       all nodes in BFS order, root first (filled by build()).
    std::vector<int> correct, perm;

    /// Computes failure links, completes the goto function and fills
    /// correct / perm. Must be called once, after all add() calls.
    void build() {
        // A second call would see the root's goto slots already filled
        // (no longer -1), enqueue the root itself and never terminate.
        assert(perm.empty() && "AhoCorasick::build() must be called only once");

        const int n = this->size();
        correct.assign(n, 0);
        perm.assign(n, 0);
        for (int v = 0; v < n; ++v) correct[v] = (int)this->nodes[v].accept.size();

        int filled = 0;
        perm[filled++] = this->root;

        std::queue<int> que;
        // Depth-1 nodes fail to the root; missing root edges loop back to the root.
        // The FAIL slot (i == char_size) is also -1 here, so the root's fail link
        // becomes the root itself.
        for (int i = 0; i <= char_size; ++i) {
            const int child = this->nodes[this->root].nxt[i];
            if (child == -1) {
                this->nodes[this->root].nxt[i] = this->root;
            } else {
                this->nodes[child].nxt[FAIL] = this->root;
                que.push(child);
            }
        }

        // BFS guarantees a node's failure target (strictly shallower) is
        // complete before the node itself is processed.
        while (!que.empty()) {
            const int v = que.front();
            que.pop();
            perm[filled++] = v;

            auto &node = this->nodes[v];
            const int fail = node.nxt[FAIL];
            correct[v] += correct[fail];

            for (int i = 0; i < char_size; ++i) {
                int &child = node.nxt[i];
                if (child == -1) {
                    child = this->nodes[fail].nxt[i];
                } else {
                    this->nodes[child].nxt[FAIL] = this->nodes[fail].nxt[i];
                    que.push(child);
                }
            }
        }
    }

    /// Total number of pattern occurrences in `str` (every pattern id at every
    /// end position is counted). Scanning starts from state `now_`.
    long long all_match(const std::string &str, int now_ = 0) const {
        long long res = 0;
        for (char c : str) {
            now_ = this->nodes[now_].nxt[c - margin];
            res += correct[now_];
        }
        return res;
    }

    /// Occurrence count of each pattern id in `str`, indexed by id.
    /// Scanning starts from state `now_`.
    std::vector<long long> each_match(const std::string &str, int now_ = 0) const {
        std::vector<int> visit_cnt(this->size());
        for (char c : str) {
            now_ = this->nodes[now_].nxt[c - margin];
            ++visit_cnt[now_];
        }
        std::vector<long long> res(this->count());
        // Reverse BFS order: push each node's visit count to its failure
        // target, so every node ends up with the number of positions at which
        // it is a suffix of the scanned text.
        for (int i = this->size() - 1; i > 0; --i) {
            const int v = perm[i];
            visit_cnt[this->nodes[v].nxt[FAIL]] += visit_cnt[v];
            for (int idx : this->nodes[v].accept) res[idx] += visit_cnt[v];
        }
        return res;
    }

    /// One automaton transition from state `now` on character `c` (requires build()).
    int move(int now, char c) const { return this->nodes[now].nxt[c - margin]; }

    /// Number of pattern ids matched when the automaton is in state `node`.
    int count(int node) const { return correct[node]; }
};

} // namespace kk2

#endif // KK2_STRING_AHO_CORASICK_HPP
#line 1 "string/aho_corasick.hpp"



#include <algorithm>
#include <queue>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#line 1 "data_structure/trie.hpp"



#include <cassert>
#include <cstring>
#include <functional>
#line 9 "data_structure/trie.hpp"

namespace kk2 {

/// A single trie node with `char_size` child slots.
template <int char_size> struct TrieNode {
    // Child node index for each character slot; -1 means "no child".
    int nxt[char_size];
    // Number of added strings that pass through this node; strings that end
    // exactly at this node are not counted here (they are listed in `accept`).
    int passing;
    // Ids of the added strings that end at this node.
    std::vector<int> accept;

    TrieNode() : passing(0) { std::fill(nxt, nxt + char_size, -1); }
};

/// Trie over `char_size` consecutive characters; character c is stored in
/// slot c - margin. Strings receive ids 0, 1, 2, ... in insertion order.
template <int char_size, int margin> struct Trie {
    using Node = TrieNode<char_size>;

    std::vector<Node> nodes;
    constexpr static int root = 0;

    Trie() { nodes.emplace_back(); }

    /// @param Num expected number of nodes; only used to reserve storage.
    Trie(int Num) {
        nodes.reserve(Num);
        nodes.emplace_back();
    }

    /// Appends a fresh empty node and returns its index.
    int push_node() {
        nodes.emplace_back();
        return (int)nodes.size() - 1;
    }

    void update_direct(int node, int id) { nodes[node].accept.push_back(id); }
    void update_child(int node) { ++nodes[node].passing; }

    /// Inserts a non-empty string. Its id is the number of strings added before it.
    void add(const std::string &str) {
        assert(!str.empty());
        const int id = nodes[root].passing;
        int now = root;
        for (char c : str) {
            const int d = c - margin;
            assert(0 <= d && d < char_size);
            if (nodes[now].nxt[d] == -1) {
                // push_node() may reallocate `nodes`; create the child first
                // so no reference into `nodes` is held across the call.
                const int child = push_node();
                nodes[now].nxt[d] = child;
            }
            update_child(now);
            now = nodes[now].nxt[d];
        }
        update_direct(now, id);
    }

    /// Calls f(id) for every added string that is a prefix of `str`.
    template <void (*f)(int)> void query(const std::string &str) {
        query(str, [](int idx) { f(idx); });
    }

    /// Calls f(id) for every added string that is a prefix of `str`, shortest first.
    template <class F> void query(const std::string &str, const F &f) {
        int now = root;
        for (char c : str) {
            for (int idx : nodes[now].accept) f(idx);
            const int d = c - margin;
            assert(0 <= d && d < char_size);
            now = nodes[now].nxt[d];
            if (now == -1) return;
        }
        for (int idx : nodes[now].accept) f(idx);
    }

    /// Number of strings added so far.
    int count() const { return nodes[root].passing; }
    /// Number of nodes, including the root.
    int size() const { return (int)nodes.size(); }

    /// Number of added strings that have the string spelled by `node_idx` as a prefix.
    int size(int node_idx) const {
        return (int)nodes[node_idx].accept.size() + nodes[node_idx].passing;
    }
};

} // namespace kk2


#line 12 "string/aho_corasick.hpp"

namespace kk2 {

/// Aho-Corasick automaton built on top of Trie.
///
/// The underlying Trie has one extra slot per node: slot FAIL (== char_size)
/// stores the failure link. After build(), slots [0, char_size) hold the
/// complete goto function, so they are never -1.
///
/// Usage: add() every pattern, call build() exactly once, then run queries.
/// Adding patterns after build() is not supported.
///
/// @tparam char_size number of distinct characters (character c maps to c - margin)
/// @tparam margin    character that maps to index 0
template <int char_size, int margin> struct AhoCorasick : Trie<char_size + 1, margin> {
    using Trie<char_size + 1, margin>::Trie;
    using Trie<char_size + 1, margin>::count;

    // Index of the failure-link slot inside each node's nxt array.
    constexpr static int FAIL = char_size;
    // correct[v]: number of pattern ids accepted at v or at any node on v's
    //             failure chain (filled by build()).
    // perm:       all nodes in BFS order, root first (filled by build()).
    std::vector<int> correct, perm;

    /// Computes failure links, completes the goto function and fills
    /// correct / perm. Must be called once, after all add() calls.
    void build() {
        // A second call would see the root's goto slots already filled
        // (no longer -1), enqueue the root itself and never terminate.
        assert(perm.empty() && "AhoCorasick::build() must be called only once");

        const int n = this->size();
        correct.assign(n, 0);
        perm.assign(n, 0);
        for (int v = 0; v < n; ++v) correct[v] = (int)this->nodes[v].accept.size();

        int filled = 0;
        perm[filled++] = this->root;

        std::queue<int> que;
        // Depth-1 nodes fail to the root; missing root edges loop back to the root.
        // The FAIL slot (i == char_size) is also -1 here, so the root's fail link
        // becomes the root itself.
        for (int i = 0; i <= char_size; ++i) {
            const int child = this->nodes[this->root].nxt[i];
            if (child == -1) {
                this->nodes[this->root].nxt[i] = this->root;
            } else {
                this->nodes[child].nxt[FAIL] = this->root;
                que.push(child);
            }
        }

        // BFS guarantees a node's failure target (strictly shallower) is
        // complete before the node itself is processed.
        while (!que.empty()) {
            const int v = que.front();
            que.pop();
            perm[filled++] = v;

            auto &node = this->nodes[v];
            const int fail = node.nxt[FAIL];
            correct[v] += correct[fail];

            for (int i = 0; i < char_size; ++i) {
                int &child = node.nxt[i];
                if (child == -1) {
                    child = this->nodes[fail].nxt[i];
                } else {
                    this->nodes[child].nxt[FAIL] = this->nodes[fail].nxt[i];
                    que.push(child);
                }
            }
        }
    }

    /// Total number of pattern occurrences in `str` (every pattern id at every
    /// end position is counted). Scanning starts from state `now_`.
    long long all_match(const std::string &str, int now_ = 0) const {
        long long res = 0;
        for (char c : str) {
            now_ = this->nodes[now_].nxt[c - margin];
            res += correct[now_];
        }
        return res;
    }

    /// Occurrence count of each pattern id in `str`, indexed by id.
    /// Scanning starts from state `now_`.
    std::vector<long long> each_match(const std::string &str, int now_ = 0) const {
        std::vector<int> visit_cnt(this->size());
        for (char c : str) {
            now_ = this->nodes[now_].nxt[c - margin];
            ++visit_cnt[now_];
        }
        std::vector<long long> res(this->count());
        // Reverse BFS order: push each node's visit count to its failure
        // target, so every node ends up with the number of positions at which
        // it is a suffix of the scanned text.
        for (int i = this->size() - 1; i > 0; --i) {
            const int v = perm[i];
            visit_cnt[this->nodes[v].nxt[FAIL]] += visit_cnt[v];
            for (int idx : this->nodes[v].accept) res[idx] += visit_cnt[v];
        }
        return res;
    }

    /// One automaton transition from state `now` on character `c` (requires build()).
    int move(int now, char c) const { return this->nodes[now].nxt[c - margin]; }

    /// Number of pattern ids matched when the automaton is in state `node`.
    int count(int node) const { return correct[node]; }
};

} // namespace kk2
Back to top page