POJ 3415 Common Substrings

http://poj.org/problem?id=3415

https://vjudge.net/problem/POJ-3415

問題概要

  • 文字列$A,B$と整数$K$が与えられる
  • $A,B$の長さ$k(\geq K)$の部分文字列のペア$A[i\dots(i+k)],B[j\dots(j+k)]$のうち、先頭$K$文字が一致するものの個数を数え上げよ

制約

  • $1\leq |A|,|B|\leq 10^5$
  • $1\leq K\leq \min{|A|,|B|}$

解法メモ

  • LCP配列を用いて数え上げる。

  • $A,B$を連結した文字列のSA,LCPを求めてうまくやろうとすると、SA上である値が$A$のものなのか$B$のものなのかを区別する必要があり、かなり面倒。

  • そこで、$A,B$の区別なく数え上げた後、余計な部分を引くことを考える。

  • $F(S)$を、文字列$S$の異なる位置にある部分文字列のペアのうち、先頭$K$文字が一致するものの個数 とする

  • これは [[Codeforces 123D String]] と似たような問題になり、LCP配列の区間の内、最小値が$K$以下になるものについて、$\min - K + 1$の合計を求めることで求めることができる

    • これはstackを使って、ヒストグラム内最大長方形と似たようなコードで線形時間で解くことができる。
  • これを使えば、最終的に求めたいものは$F(A+B)-F(A)-F(B)$となる。

実装例

最初setをつかって$O((|A|+|B|)\log (|A|+|B|))$で書いたがTLが厳しくて通らず、stackを使った線形時間の解法で通した。

#include <algorithm>
#include <cassert>
#include <iostream>
#include <set>
#include <stack>
#include <string>
#include <vector>
typedef long long ll;
#define rep(i, n) for (int i = 0, i##_len = (n); i < i##_len; ++i)
#define rep2(i, m, n) for (int i = (m), i##_len = (n); i < i##_len; ++i)
using namespace std;

vector<int> sa_is(vector<int> v, int upper) {
    // 1 <= v[i] <= upper
    if (v.size() == 0) return vector<int>();
    else if (v.size() == 1) return vector<int>(1, 0);
    else if (v.size() == 2) {
        vector<int> res(2, 0);
        if (v[0] < v[1]) res[1] = 1;
        else res[0] = 1;
        return res;
    }
    v.push_back(0);  // sentinel
    const int n = v.size();

    vector<int> bl(upper + 1), br(upper + 1);  // bucket range
    for (int i = 0; i < n; ++i) br[v[i]]++;
    for (int i = 1; i <= upper; ++i) br[i] += br[i - 1];
    for (int i = 1; i <= upper; ++i) bl[i] = br[i - 1];

    vector<int> is_l(n);
    vector<int> lms, sa(n, -1), lms_ord(n);  // lms_ord[i] := 0 -> not lms, 1~ -> 1-indexed
    for (int i = n - 2; i >= 0; --i) is_l[i] = (v[i] == v[i + 1]) ? is_l[i + 1] : (v[i] > v[i + 1]);
    for (int i = 1; i < n; ++i) {
        if (!is_l[i] && is_l[i - 1]) {
            sa[--br[v[i]]] = i;
            lms_ord[i] = ~int(lms.size());
            lms.push_back(i);
        }
    }

    for (int i = 0; i < upper; ++i) br[i] = bl[i + 1];
    br[upper] = n;

    for (int i = 0; i < n; ++i)
        if (sa[i] > 0 && is_l[sa[i] - 1]) sa[bl[v[sa[i] - 1]]++] = sa[i] - 1;
    for (int i = 1; i <= upper; ++i) bl[i] = br[i - 1];
    for (int i = 1; i < n; i++)
        if (sa[i] > -1 && !is_l[sa[i]]) sa[i] = -1;
    for (int i = n - 1; i >= 1; i--)
        if (sa[i] > 0 && !is_l[sa[i] - 1]) sa[--br[v[sa[i] - 1]]] = sa[i] - 1;
    for (int i = 0; i < upper; ++i) br[i] = bl[i + 1];
    br[upper] = n;

    vector<int> lms_substr_sorted(lms.size());
    int cnt = 0;
    for (int i = 0; i < n; ++i)
        if (sa[i] > -1 && lms_ord[sa[i]]) lms_substr_sorted[cnt++] = sa[i];

    // same lms_substr -> same rank
    vector<int> ord(lms.size());
    ord[0] = 1;
    for (int i = 0; i < int(lms.size()) - 1; ++i) {
        int l1 = lms_substr_sorted[i], l2 = lms_substr_sorted[i + 1];
        if (l1 > l2) swap(l1, l2);
        if (l2 == n - 1) ord[i + 1] = ord[i] + 1;
        else {
            int p1 = l1, p2 = l2;
            bool f = true;
            while (p1 <= lms[~lms_ord[l1] + 1] && p2 < n)
                if (v[p1] == v[p2]) ++p1, ++p2;
                else {
                    f = false;
                    break;
                }
            ord[i + 1] = f ? ord[i] : ord[i] + 1;
        }
    }
    vector<int> va(lms.size());  // make array of appearance order

    for (int i = 0; i < int(lms.size()); ++i) va[~lms_ord[lms_substr_sorted[i]]] = ord[i];

    vector<int> lms_sorted = sa_is(va, ord.back());

    // place lms at correct position
    fill(sa.begin(), sa.end(), -1);
    for (int i = lms.size() - 1; i >= 0; i--) sa[--br[v[lms[lms_sorted[i]]]]] = lms[lms_sorted[i]];
    for (int i = 0; i < upper; ++i) br[i] = bl[i + 1];
    br[upper] = n;
    for (int i = 0; i < n; ++i)
        if (sa[i] > 0 && is_l[sa[i] - 1]) sa[bl[v[sa[i] - 1]]++] = sa[i] - 1;
    for (int i = 1; i < n; i++)
        if (sa[i] > -1 && !is_l[sa[i]]) sa[i] = -1;
    for (int i = n - 1; i >= 1; i--)
        if (sa[i] > 0 && !is_l[sa[i] - 1]) sa[--br[v[sa[i] - 1]]] = sa[i] - 1;
    sa.erase(sa.begin());
    return sa;
}
vector<int> sa_is(string s) {
    vector<int> v(s.size());
    for (int i = 0; i < int(s.size()); i++) v[i] = s[i];
    return sa_is(v, 255);
}
vector<int> lcp(const string& s, const vector<int>& sa) {
    assert(sa.size() == s.size());
    vector<int> rank(s.size()), _lcp(s.size() - 1);
    const int n = s.size();
    for (int i = 0; i < n; i++) rank[sa[i]] = i;
    int h = 0;
    for (int i = 0; i < n; i++) {
        if (rank[i] == 0) continue;
        int j = sa[rank[i] - 1];
        if (h > 0) h--;
        for (; j + h < n && i + h < n; h++)
            if (s[j + h] != s[i + h]) break;
        _lcp[rank[i] - 1] = h;
    }
    return _lcp;
}

// ヒストグラムの中で、h[i]が最小となるような区間を返す。ただし、h[i]が一致する場合はiが小さいほうが小さいと見なす
template <typename T>
vector<pair<int, int> > largest_rectangle_in_histogram(const vector<T>& h) {
    const int n = h.size();
    vector<pair<int, int> > res(n);
    vector<int> st;
    for (int i = 0; i < n; ++i) {
        while (!st.empty() && h[st.back()] > h[i]) st.pop_back();
        res[i].first = st.empty() ? 0 : st.back() + 1;
        st.push_back(i);
    }
    st.clear();
    for (int i = n - 1; i >= 0; --i) {
        while (!st.empty() && h[st.back()] >= h[i]) st.pop_back();
        res[i].second = st.empty() ? n - 1 : st.back() - 1;
        st.push_back(i);
    }
    return res;
}
ll range_min_sum_leq_k_(const vector<int>& v, int k) {
    const int n = v.size();
    vector<pair<int, int> > lr = largest_rectangle_in_histogram<int>(v);
    ll res = 0;
    rep(i, n) {
        if (v[i] >= k) {
            ll ln = i - lr[i].first + 1, rn = lr[i].second - i + 1;
            res += ln * rn * (v[i] - k + 1);
        }
    }
    return res;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    while (true) {
        int k;
        cin >> k;
        if (k == 0) break;
        string a, b;
        cin >> a >> b;
        int n = a.size(), m = b.size();

        string ab = a;
        ab.push_back('{');
        rep(i, m) ab.push_back(b[i]);

        vector<int> sa_ab = sa_is(ab);
        vector<int> la_ab = lcp(ab, sa_ab);

        vector<int> sa_a = sa_is(a);
        vector<int> la_a = lcp(a, sa_a);
        vector<int> sa_b = sa_is(b);
        vector<int> la_b = lcp(b, sa_b);

        ll ans = 0;
        ans += range_min_sum_leq_k_(la_ab, k);
        ans -= range_min_sum_leq_k_(la_a, k);
        ans -= range_min_sum_leq_k_(la_b, k);
        cout << ans << "\n";
    }
}