POJ 3729 Facer’s string
http://poj.org/problem?id=3729
https://vjudge.net/problem/POJ-3729
問題概要
-
長さ$N,M$の非負整数列$S_1,S_2$と整数$K$が与えられる
-
$S_1$の末尾に$S_1,S_2$のどの要素とも一致しない番兵$-1$を追加した数列を$S_1’$とする
-
$S_1’$の長さ$K+1$の連続部分列$S_3$のうち以下の2条件を満たすものはいくつあるか
- $S_3$は$S_2$に含まれない
- 先頭$K$要素$S_3[1\dots K]$は$S_2$に含まれる
制約
- $S_1, S_2$は$10^4$以下の非負整数からなる
- $0\leq N, M\leq 5\times 10^4$
解法メモ
思いついた解法はSA+LCPを使う解法、ロリハを使う解法の2つ
SA+LCPを使う解法
-
$S_1’+S_2$のSA、LCPを求めてこれについて考える
-
答えは、「(長さ$K$の部分列$S_3$で先頭$K$が$S_2$に含まれるものの個数) - (長さ$K+1$の部分列$S_4$で先頭$K$が$S_2$に含まれるものの個数)」とみなすことができる
-
同じ処理を2回、文字列長$K, K+1$に対して行ってからその差を求めるだけでいい
-
LCPのminが$K$以上になる区間は線形時間で列挙でき、その区間は先頭$K$文字が一致する接尾辞の集合と考えることができる。
-
この区間ごとに、$S_1’$に属する接尾辞と$S_2$に属する接尾辞の数を数え上げ、後者が1つでも存在すれば前者を足し合わせていけばいい。
ローリングハッシュを使う解法
-
$S_2$について、長さ$K$の部分文字列と$K+1$の部分文字列のハッシュをset1, set2に保管しておく
-
$S_1’$について各始点からの長さ$K$の部分文字列がset1に存在し、$K+1$の部分文字列がset2に存在しなければ答えに1を加算していけばよい
-
実装時にはTLがきついのでsetではなくvectorをソートしてlower_boundで存在判定を行った
実装例
SA解
#include <algorithm>
#include <cassert>
#include <iostream>
#include <vector>
typedef long long ll;
#define rep(i, n) for (int i = 0, i##_len = (n); i < i##_len; ++i)
using namespace std;
vector<int> sa_is(vector<int> v, int upper) {
// 1 <= v[i] <= upper
if (v.size() == 0) return vector<int>();
else if (v.size() == 1) return vector<int>(1, 0);
else if (v.size() == 2) {
vector<int> res(2, 0);
if (v[0] < v[1]) res[1] = 1;
else res[0] = 1;
return res;
}
v.push_back(0); // sentinel
const int n = v.size();
vector<int> bl(upper + 1), br(upper + 1); // bucket range
for (int i = 0; i < n; ++i) br[v[i]]++;
for (int i = 1; i <= upper; ++i) br[i] += br[i - 1];
for (int i = 1; i <= upper; ++i) bl[i] = br[i - 1];
vector<int> is_l(n);
vector<int> lms, sa(n, -1), lms_ord(n); // lms_ord[i] := 0 -> not lms, 1~ -> 1-indexed
for (int i = n - 2; i >= 0; --i) is_l[i] = (v[i] == v[i + 1]) ? is_l[i + 1] : (v[i] > v[i + 1]);
for (int i = 1; i < n; ++i) {
if (!is_l[i] && is_l[i - 1]) {
sa[--br[v[i]]] = i;
lms_ord[i] = ~int(lms.size());
lms.push_back(i);
}
}
for (int i = 0; i < upper; ++i) br[i] = bl[i + 1];
br[upper] = n;
for (int i = 0; i < n; ++i)
if (sa[i] > 0 && is_l[sa[i] - 1]) sa[bl[v[sa[i] - 1]]++] = sa[i] - 1;
for (int i = 1; i <= upper; ++i) bl[i] = br[i - 1];
for (int i = 1; i < n; i++)
if (sa[i] > -1 && !is_l[sa[i]]) sa[i] = -1;
for (int i = n - 1; i >= 1; i--)
if (sa[i] > 0 && !is_l[sa[i] - 1]) sa[--br[v[sa[i] - 1]]] = sa[i] - 1;
for (int i = 0; i < upper; ++i) br[i] = bl[i + 1];
br[upper] = n;
vector<int> lms_substr_sorted(lms.size());
int cnt = 0;
for (int i = 0; i < n; ++i)
if (sa[i] > -1 && lms_ord[sa[i]]) lms_substr_sorted[cnt++] = sa[i];
// same lms_substr -> same rank
vector<int> ord(lms.size());
ord[0] = 1;
for (int i = 0; i < int(lms.size()) - 1; ++i) {
int l1 = lms_substr_sorted[i], l2 = lms_substr_sorted[i + 1];
if (l1 > l2) swap(l1, l2);
if (l2 == n - 1) ord[i + 1] = ord[i] + 1;
else {
int p1 = l1, p2 = l2;
bool f = true;
while (p1 <= lms[~lms_ord[l1] + 1] && p2 < n)
if (v[p1] == v[p2]) ++p1, ++p2;
else {
f = false;
break;
}
ord[i + 1] = f ? ord[i] : ord[i] + 1;
}
}
vector<int> va(lms.size()); // make array of appearance order
for (int i = 0; i < int(lms.size()); ++i) va[~lms_ord[lms_substr_sorted[i]]] = ord[i];
vector<int> lms_sorted = sa_is(va, ord.back());
// place lms at correct position
fill(sa.begin(), sa.end(), -1);
for (int i = lms.size() - 1; i >= 0; i--) sa[--br[v[lms[lms_sorted[i]]]]] = lms[lms_sorted[i]];
for (int i = 0; i < upper; ++i) br[i] = bl[i + 1];
br[upper] = n;
for (int i = 0; i < n; ++i)
if (sa[i] > 0 && is_l[sa[i] - 1]) sa[bl[v[sa[i] - 1]]++] = sa[i] - 1;
for (int i = 1; i < n; i++)
if (sa[i] > -1 && !is_l[sa[i]]) sa[i] = -1;
for (int i = n - 1; i >= 1; i--)
if (sa[i] > 0 && !is_l[sa[i] - 1]) sa[--br[v[sa[i] - 1]]] = sa[i] - 1;
sa.erase(sa.begin());
return sa;
}
vector<int> sa_is(vector<int> v) {
vector<int> cv = v;
sort(cv.begin(), cv.end());
cv.erase(unique(cv.begin(), cv.end()), cv.end());
for (int i = 0; i < (int)v.size(); ++i) v[i] = lower_bound(cv.begin(), cv.end(), v[i]) - cv.begin() + 1;
return sa_is(v, cv.size());
}
vector<int> lcp(const vector<int>& s, const vector<int>& sa) {
assert(sa.size() == s.size());
const int n = s.size();
vector<int> rank(n), _lcp(n - 1);
for (int i = 0; i < n; i++) rank[sa[i]] = i;
int h = 0;
for (int i = 0; i < n; i++) {
if (rank[i] == 0) continue;
int j = sa[rank[i] - 1];
if (h > 0) h--;
for (; j + h < n && i + h < n; h++)
if (s[j + h] != s[i + h]) break;
_lcp[rank[i] - 1] = h;
}
return _lcp;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n, m, k;
while (cin >> n >> m >> k) {
vector<int> s1(n), s2(m);
rep(i, n) cin >> s1[i];
rep(i, m) cin >> s2[i];
vector<int> s1s2 = s1;
s1s2.push_back(1000000000);
rep(i, m) s1s2.push_back(s2[i]);
vector<int> sa = sa_is(s1s2);
vector<int> lcp_ = lcp(s1s2, sa);
ll cnt_k = 0, cnt_k_1 = 0;
rep(i, n + m) {
if (lcp_[i] < k) continue;
int j = i;
ll cnt_s1 = sa[i] < n, cnt_s2 = 1 - cnt_s1;
while (j < n + m && lcp_[j] >= k) {
if (sa[j + 1] < n) {
cnt_s1++;
} else {
cnt_s2++;
}
j++;
}
if (cnt_s2 > 0) cnt_k += cnt_s1;
i = j - 1;
}
rep(i, n + m) {
if (lcp_[i] < k + 1) continue;
int j = i;
ll cnt_s1 = sa[i] < n, cnt_s2 = 1 - cnt_s1;
while (j < n + m && lcp_[j] >= k + 1) {
if (sa[j + 1] < n) {
cnt_s1++;
} else {
cnt_s2++;
}
j++;
}
if (cnt_s2 > 0) cnt_k_1 += cnt_s1;
i = j - 1;
}
cout << cnt_k - cnt_k_1 << '\n';
}
}
ロリハ解
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <ctime>
#include <string>
#include <vector>
typedef long long ll;
#define rep(i, n) for (int i = 0, i##_len = (n); i < i##_len; ++i)
using namespace std;
ll MOD[] = {
999999503, 999999527, 999999541, 999999587, 999999599, 999999607, 999999613, 999999667, 999999677, 999999733, 999999739, 999999751, 999999757, 999999761, 999999797, 999999883, 999999893, 999999929, 999999937, 1000000007, 1000000009, 1000000021, 1000000033, 1000000087, 1000000093, 1000000097, 1000000103, 1000000123, 1000000181, 1000000207, 1000000223, 1000000241, 1000000271, 1000000289, 1000000297, 1000000321, 1000000349, 1000000363, 1000000403, 1000000409, 1000000411, 1000000427, 1000000433, 1000000439, 1000000447, 1000000453, 1000000459, 1000000483,
};
class RollingHash {
private:
ll MOD1, MOD2, BASE1, BASE2;
vector<ll> power1, power2;
public:
struct Hash {
ll h1, h2;
int length;
Hash(ll h1, ll h2, int length) : h1(h1), h2(h2), length(length) {}
bool operator<(const Hash& rhs) const {
if (length != rhs.length) return length < rhs.length;
if (h1 != rhs.h1) return h1 < rhs.h1;
return h2 < rhs.h2;
}
bool operator==(const Hash& rhs) const { return h1 == rhs.h1 && h2 == rhs.h2 && length == rhs.length; }
};
typedef pair<vector<ll>, vector<ll> > Hashes;
RollingHash() {
srand(time(0));
MOD1 = MOD[rand() % 50];
do {
MOD2 = MOD[rand() % 50];
} while (MOD1 == MOD2);
BASE1 = rand() % 1000 + 1000;
do {
BASE2 = rand() % 1000 + 1000;
} while (BASE1 == BASE2);
power1.push_back(1);
power2.push_back(1);
}
Hashes operator()(const string& str) {
int size = str.size();
vector<ll> hash1(size + 1), hash2(size + 1);
for (int i = 1; i <= size; i++) {
hash1[i] = (hash1[i - 1] + str[i - 1]) * BASE1 % MOD1;
hash2[i] = (hash2[i - 1] + str[i - 1]) * BASE2 % MOD2;
}
while (power1.size() <= size) {
power1.push_back(power1.back() * BASE1 % MOD1);
power2.push_back(power2.back() * BASE2 % MOD2);
}
return make_pair(hash1, hash2);
}
template <typename T>
Hashes operator()(const vector<T>& v) {
int size = v.size();
vector<ll> hash1(size + 1), hash2(size + 1);
for (int i = 1; i <= size; i++) {
hash1[i] = (hash1[i - 1] + v[i - 1] % MOD1) * BASE1 % MOD1;
hash2[i] = (hash2[i - 1] + v[i - 1] % MOD2) * BASE2 % MOD2;
}
while (power1.size() <= size) {
power1.push_back(power1.back() * BASE1 % MOD1);
power2.push_back(power2.back() * BASE2 % MOD2);
}
return make_pair(hash1, hash2);
}
Hash get(const Hashes& hashes, int l, int r) {
assert(0 <= l && l <= r && r <= hashes.first.size() - 1);
return Hash(((hashes.first[r] - hashes.first[l] * power1[r - l]) % MOD1 + MOD1) % MOD1, ((hashes.second[r] - hashes.second[l] * power2[r - l]) % MOD2 + MOD2) % MOD2, r - l);
}
Hash concat(const Hash& h1, const Hash& h2) {
ll h1_1 = h1.h1, h1_2 = h1.h2, h2_1 = h2.h1, h2_2 = h2.h2;
ll h1_length = h1.length, h2_length = h2.length;
return Hash((h1_1 * power1[h2_length] + h2_1) % MOD1, (h1_2 * power2[h2_length] + h2_2) % MOD2, h1_length + h2_length);
}
} RH;
typedef RollingHash::Hashes Hashes;
typedef RollingHash::Hash Hash;
// Usage: string s = "abc"; Hashes rh = RH(s); Hash hash = RH.get(rh, 0, 3);
int main() {
int n, m, k;
while (scanf("%d%d%d", &n, &m, &k) != EOF) {
vector<int> s1(n), s2(m);
rep(i, n) scanf("%d", &s1[i]);
rep(i, m) scanf("%d", &s2[i]);
s1.push_back(10000000);
Hashes rh1 = RH(s1);
Hashes rh2 = RH(s2);
vector<Hash> st1, st2;
rep(i, m - k + 1) { st1.push_back(RH.get(rh2, i, i + k)); }
rep(i, m - (k + 1) + 1) { st2.push_back(RH.get(rh2, i, i + (k + 1))); }
sort(st1.begin(), st1.end());
sort(st2.begin(), st2.end());
ll ans = 0;
rep(i, (n + 1) - (k + 1) + 1) {
Hash h1 = RH.get(rh1, i, i + k);
Hash h2 = RH.get(rh1, i, i + (k + 1));
vector<Hash>::iterator it1 = lower_bound(st1.begin(), st1.end(), h1);
vector<Hash>::iterator it2 = lower_bound(st2.begin(), st2.end(), h2);
bool f1 = it1 != st1.end() && *it1 == h1;
bool f2 = it2 != st2.end() && *it2 == h2;
if (f1 && !f2) ans++;
}
printf("%lld\n", ans);
}
}