AOJ 1312 Where’s Wally
https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=1312
https://vjudge.net/problem/Aizu-1312
問題概要
- base64形式で$H\times W$の2値画像、$P\times P$の2値パターンが与えられる
- 画像内の$P\times P$領域の中で、パターンの0・90・180・270度回転・反転いずれかに一致するものはいくつあるか
制約
- $1\leq W\leq 1000$
- $1\leq H\leq 1000$
- $1\leq P\leq 1000$
解法メモ
- 2次元でRolling Hashをやればよい
- 前処理として、あらかじめ画像とパターン(反転・回転)のハッシュを計算しておく
- 画像の中の各$P\times P$領域のハッシュがパターンの反転・回転8通りのいずれかと一致する場所をカウントしていけばよい。
- (書いていて思ったが8種とハッシュをかけている…?)
実装例
BASE64の中身を初めて調べた。Rolling Hashに加えて、BASE64デコード処理や回転・反転も書かないといけなくて、まあまあ実装が面倒。
#include <algorithm>
#include <array>
#include <cassert>
#include <iostream>
#include <random>
#include <set>
#include <vector>
#define rep(i, n) for (int i = 0, i##_len = (n); i < i##_len; ++i)
using namespace std;
/// Expands base64 text into a flat bit sequence, 6 bits per input character,
/// most significant bit first.
///
/// The padding character '=' contributes six zero bits (it is not stripped),
/// so the result always holds exactly `s.size() * 6` bits.
///
/// @param s base64 text whose length must be a multiple of 4 (asserted).
/// @return  the decoded bits; any character outside the base64 alphabet
///          triggers an assertion failure.
vector<bool> base64_decoding(const string& s) {
    assert(s.size() % 4 == 0);

    // Maps one base64 character to its 6-bit value.
    const auto sextet = [](char c) -> unsigned {
        if (c == '=') return 0u;
        if (c >= 'A' && c <= 'Z') return static_cast<unsigned>(c - 'A');
        if (c >= 'a' && c <= 'z') return static_cast<unsigned>(c - 'a') + 26u;
        if (c >= '0' && c <= '9') return static_cast<unsigned>(c - '0') + 52u;
        if (c == '+') return 62u;
        if (c == '/') return 63u;
        assert(false);
        return 0u;
    };

    vector<bool> bits(s.size() * 6);
    size_t out = 0;
    for (const char c : s) {
        const unsigned value = sextet(c);
        // Emit the six bits of this character from the highest to the lowest.
        for (int shift = 5; shift >= 0; --shift) {
            bits[out++] = ((value >> shift) & 1u) != 0;
        }
    }
    return bits;
}
/// 2-D polynomial rolling hash evaluated under MOD_NUM independent prime moduli.
///
/// For every modulus p with its pair of bases (B1, B2), `compute_hash` builds
/// the 2-D prefix table
///     H[y][x] = sum_{i < y, j < x} a[i][j] * B1^j * B2^i   (mod p),
/// with row 0 and column 0 equal to zero. `get_hash` then obtains the hash of
/// any rectangle by inclusion-exclusion and multiplies by B1^{-left} and
/// B2^{-top}, so equal rectangles hash equally regardless of their position.
///
/// Moduli and bases are drawn from `std::random_device` in the constructor;
/// hashes are therefore only comparable when produced by the same instance.
template <size_t MOD_NUM>
struct RollingHash2d {
    /// One residue per modulus.
    using Hash = array<int, MOD_NUM>;
    /// Prefix table of size (rows + 1) x (cols + 1).
    using Hashes = vector<vector<Hash>>;
    static_assert(MOD_NUM < 40, "MOD_NUM must be less than 40");

    /// Pool of candidate moduli; MOD_NUM distinct entries are picked at construction.
    array<long long, 50> MODS = {
        999999503, 999999527, 999999541, 999999587, 999999599, 999999607, 999999613, 999999667, 999999677, 999999733,
        999999739, 999999751, 999999757, 999999761, 999999797, 999999883, 999999893, 999999929, 999999937, 1000000007,
        1000000009, 1000000021, 1000000033, 1000000087, 1000000093, 1000000097, 1000000103, 1000000123, 1000000181, 1000000207,
        1000000223, 1000000241, 1000000271, 1000000289, 1000000297, 1000000321, 1000000349, 1000000363, 1000000403, 1000000409,
        1000000411, 1000000427, 1000000433, 1000000439, 1000000447, 1000000453, 1000000459, 1000000483, 1000000513, 1000000531,
    };

private:
    array<long long, MOD_NUM> MOD;       // the moduli actually in use
    array<long long, MOD_NUM * 2> BASE;  // BASE[2k] = column base, BASE[2k+1] = row base of modulus k
    // power_table[b][x] = BASE[b]^x and power_table_inv[b][x] = BASE[b]^{-x}, grown on demand.
    vector<vector<long long>> power_table, power_table_inv;

    /// Modular inverse of `a` modulo `M` via the extended Euclidean algorithm.
    /// The result is in [0, M); it is a true inverse only when gcd(a, M) == 1.
    long long mod_inv(long long a, long long M) {
        long long b = M, u = 1, v = 0;
        while (b) {
            const long long t = a / b;
            a -= t * b;
            swap(a, b);
            u -= t * v;
            swap(u, v);
        }
        u %= M;
        if (u < 0) u += M;
        return u;
    }

    /// Grows both power tables of base `baseIdx` until index `x` is available.
    void extend_tables(size_t x, size_t baseIdx) {
        const long long mod = MOD[baseIdx / 2];
        vector<long long>& pw = power_table[baseIdx];
        vector<long long>& pw_inv = power_table_inv[baseIdx];
        while (pw.size() <= x) {
            pw.push_back(pw.back() * BASE[baseIdx] % mod);
            pw_inv.push_back(mod_inv(pw.back(), mod));
        }
    }

    /// BASE[baseIdx]^x modulo the matching modulus.
    long long power(size_t x, size_t baseIdx) {
        extend_tables(x, baseIdx);
        return power_table[baseIdx][x];
    }

    /// Inverse of BASE[baseIdx]^x modulo the matching modulus.
    long long power_inv(size_t x, size_t baseIdx) {
        extend_tables(x, baseIdx);
        return power_table_inv[baseIdx][x];
    }

    /// Builds the (n + 1) x (m + 1) prefix table of `a` for modulus `mIdx`.
    /// The width is taken from the first row; shorter rows make `.at()` throw.
    template <typename T>
    vector<vector<int>> _compute_hash(const vector<vector<T>>& a, size_t mIdx) {
        const size_t n = a.size();
        const size_t m = (n == 0) ? 0 : a[0].size();  // never touch a[0] of an empty matrix
        const size_t b1Idx = mIdx * 2, b2Idx = mIdx * 2 + 1;
        const long long mod = MOD[mIdx];
        vector<vector<int>> hash_(n + 1, vector<int>(m + 1, 0));

        // Pass 1: weight each cell by B1^column and take prefix sums along rows.
        for (size_t i = 1; i <= n; ++i) {
            for (size_t j = 1; j <= m; ++j) {
                long long v = a.at(i - 1).at(j - 1) % mod;
                if (v < 0) v += mod;  // keep residues in [0, mod) even for negative cells
                v = v * power(j - 1, b1Idx) % mod;
                v += hash_[i][j - 1];
                if (v >= mod) v -= mod;
                hash_[i][j] = static_cast<int>(v);
            }
        }
        // Pass 2: weight each row prefix by B2^row and take prefix sums along columns.
        for (size_t i = 1; i <= n; ++i) {
            for (size_t j = 1; j <= m; ++j) {
                long long v = hash_[i][j] * power(i - 1, b2Idx) % mod;
                v += hash_[i - 1][j];
                if (v >= mod) v -= mod;
                hash_[i][j] = static_cast<int>(v);
            }
        }
        return hash_;
    }

    /// Hash (for modulus `mIdx`) of the h x w rectangle whose top-left table
    /// coordinate is (i, j); here i and j are 1-based.
    int _get_hash(const Hashes& hashes, size_t i, size_t j, size_t h, size_t w, size_t mIdx) {
        const long long mod = MOD[mIdx];
        const size_t bottom = i + h - 1;
        const size_t right = j + w - 1;
        // Inclusion-exclusion on the prefix table.
        long long res = hashes[bottom][right][mIdx];
        res -= hashes[i - 1][right][mIdx];
        res -= hashes[bottom][j - 1][mIdx];
        res += hashes[i - 1][j - 1][mIdx];
        res %= mod;
        if (res < 0) res += mod;
        // Cancel the positional factor B1^(j-1) * B2^(i-1).
        res = res * power_inv(j - 1, mIdx * 2) % mod;
        res = res * power_inv(i - 1, mIdx * 2 + 1) % mod;
        return static_cast<int>(res);
    }

public:
    /// Picks MOD_NUM distinct moduli from MODS and two bases in [2, mod - 1] per modulus.
    RollingHash2d() {
        random_device rnd;
        set<int> mods_idx;
        while (mods_idx.size() < size_t(MOD_NUM)) mods_idx.insert(rnd() % MODS.size());
        const vector<int> chosen(mods_idx.begin(), mods_idx.end());
        for (size_t i = 0; i < MOD_NUM; ++i) MOD[i] = MODS[chosen[i]];
        for (size_t i = 0; i < MOD_NUM * 2; ++i) BASE[i] = rnd() % (MOD[i / 2] - 2) + 2;
        power_table.resize(MOD_NUM * 2, vector<long long>(1, 1));
        power_table_inv.resize(MOD_NUM * 2, vector<long long>(1, 1));
    }

    /// Builds the prefix-hash table of a 0-indexed matrix.
    ///
    /// @param a rectangular matrix of integral values (may be empty).
    /// @return  table of size (rows + 1) x (cols + 1); an empty matrix yields
    ///          the 1 x 1 all-zero table.
    template <typename T>
    Hashes compute_hash(const vector<vector<T>>& a) {
        const size_t n = a.size();
        const size_t m = (n == 0) ? 0 : a[0].size();
        Hashes res(n + 1, vector<Hash>(m + 1));  // value-initialised, i.e. all zeros
        for (size_t k = 0; k < MOD_NUM; ++k) {
            const vector<vector<int>> table = _compute_hash(a, k);
            for (size_t y = 1; y <= n; ++y) {
                for (size_t x = 1; x <= m; ++x) res[y][x][k] = table[y][x];
            }
        }
        return res;
    }

    /// Overload for 1-indexed text: row 0 and column 0 of `a` are ignored and
    /// hashed as zeros, so the returned table is built over an
    /// a.size() x a[1].size() matrix whose first row and column are zero.
    /// Requires at least two rows and a[1] of length >= 2 (asserted).
    Hashes compute_hash(const vector<string>& a) {
        const int n = int(a.size()) - 1;
        assert(n >= 1);
        const int m = int(a[1].size()) - 1;
        assert(m >= 1);
        vector<vector<int>> a_(n + 1, vector<int>(m + 1));
        for (int i = 1; i <= n; ++i) {
            for (int j = 1; j <= m; ++j) a_[i][j] = a[i][j];
        }
        return compute_hash(a_);
    }

    /// Hash of the h x w rectangle whose top-left cell is (i, j), 0-indexed.
    ///
    /// @param hashes table returned by `compute_hash` of this same instance.
    /// @return one residue per modulus; rectangles with equal content give
    ///         equal results wherever they are located.
    Hash get_hash(const Hashes& hashes, size_t i, size_t j, size_t h, size_t w) {
        assert(!hashes.empty());
        // The table has one more row/column than the matrix, hence strict '<'.
        assert(i + h < hashes.size() && j + w < hashes[0].size());
        ++i, ++j;
        Hash res;
        for (size_t k = 0; k < MOD_NUM; ++k) res[k] = _get_hash(hashes, i, j, h, w, k);
        return res;
    }
};
// Hasher type used by this solution: RollingHash2d instantiated with two moduli.
using RollingHash = RollingHash2d<2>;
// Global hasher instance. Its moduli and bases are drawn randomly in the
// constructor, so hashes that are compared with each other must all come from
// this same object.
RollingHash RH;
/// Rotates a square matrix in place by `rotate` quarter turns and then,
/// if `invert` is set, mirrors every row left-to-right.
///
/// With `src` denoting the matrix before the call and n its side length:
///   rotate == 1: a[i][j] = src[j][n - 1 - i]
///   rotate == 2: a[i][j] = src[n - 1 - i][n - 1 - j]
///   rotate == 3: a[i][j] = src[n - 1 - j][i]
///
/// @param a      square matrix (every row must have a.size() elements); an
///               empty matrix is accepted and left untouched.
/// @param rotate number of quarter turns, 0..3 (asserted).
/// @param invert whether to reverse each row after rotating.
template <typename T>
void rotate_invert(vector<vector<T>>& a, int rotate, bool invert) {
    assert(0 <= rotate && rotate < 4);
    if (a.empty()) return;  // nothing to do, and a[0] must not be touched

    const size_t n = a.size();
    // Every row is indexed up to n - 1 below, so all of them must be square-sized.
    for (const auto& row : a) {
        assert(row.size() == n);
        (void)row;
    }

    if (rotate != 0) {
        const vector<vector<T>> src = a;  // snapshot: the rotation reads while writing
        for (size_t i = 0; i < n; ++i) {
            for (size_t j = 0; j < n; ++j) {
                if (rotate == 1) {
                    a[i][j] = src[j][n - 1 - i];
                } else if (rotate == 2) {
                    a[i][j] = src[n - 1 - i][n - 1 - j];
                } else {
                    a[i][j] = src[n - 1 - j][i];
                }
            }
        }
    }

    if (invert) {
        for (auto& row : a) reverse(row.begin(), row.end());
    }
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
while (true) {
int w, h, p;
cin >> w >> h >> p;
if (w == 0 && h == 0 && p == 0) break;
vector<vector<bool>> image(h), pattern(p);
rep(i, h) {
string s;
cin >> s;
while (s.size() % 4 != 0) s.push_back('='); // この問題では必要ないが、base64の仕様に合わせている
image[i] = base64_decoding(s);
image[i].resize(w);
}
rep(i, p) {
string s;
cin >> s;
while (s.size() % 4 != 0) s.push_back('=');
pattern[i] = base64_decoding(s);
pattern[i].resize(p);
}
auto image_hashes = RH.compute_hash(image);
vector<RollingHash::Hash> pattern_hash(8); // 回転4通り * 反転2通りで8通りのパターンのハッシュを計算
rep(i, 4) {
rep(j, 2) {
auto pattern_ = pattern;
rotate_invert(pattern_, i, j);
auto pattern_hashes_ = RH.compute_hash(pattern_);
pattern_hash[i * 2 + j] = RH.get_hash(pattern_hashes_, 0, 0, p, p);
}
}
int ans = 0;
rep(i, h - p + 1) {
rep(j, w - p + 1) {
auto image_hash = RH.get_hash(image_hashes, i, j, p, p);
bool found = false;
rep(k, 8) {
if (image_hash == pattern_hash[k]) {
found = true;
break;
}
}
if (found) ++ans;
}
}
cout << ans << "\n";
}
}