POJ 3708 Recurrent Function

http://poj.org/problem?id=3708

https://vjudge.net/problem/POJ-3708

問題概要

  • ある関数$f(x)$があり、以下のように定義される $$ \begin{align} f(j) &= a_j & \text{for } 1 \leq j < d, \ f(d \times n + j) &= d \times f(n) + b_j & \text{for } 0 \leq j < d \text{ and } n \geq 1, \end{align} $$

  • ここで、集合 ${a_i}$ は ${1, 2, \ldots, d-1}$ から選ばれ、集合 ${b_i}$ は ${0, 1, \ldots, d-1}$ から選ばれる

  • 以下を定義する $$ f_x(m) = f(f(f(\cdots f(m)))) \quad x\text{ times}

$$

  • 正の整数 $m$ と $k$ が与えられる。$f_x(m) = k$ となる最小の非負整数 $x$ が存在するか、存在する場合はその値を求めよ
  • なお、答えが $2^{63}$ 未満であることが保証される

入力

マルチテストケースで、各ケースは以下のようになっている。整数-1のみの行が与えられたとき入力終了

d
a[1] a[2] ... a[d-1]
b[0] b[1] b[2] ... b[d-1]
m
k

制約

  • $2 \leq d \leq 100$
  • $0 \leq m \leq 10^{100}$
  • $0 \leq k \leq 10^{100}$

解法メモ

かなり難しかった。

f(x)の定義について

まず$f(x)$の定義が分かりづらい。再帰関数で書くと少しわかりやすくなる。

def f(x):
    if x < d:
        return a[x]
    return f(x // d) * d + b[x % d]

つまり、$f(x)$は以下のような処理を行っていると考えられる。

def f(x):
    xl = xをd進表記したリスト
    for i in range(len(xl)):
		if i == 0: # 最上位桁だけは順列aで置換
	  	  xl[i] = a[xl[i]]
		else: # それ以外の桁は順列bで置換
	    	xl[i] = b[xl[i]]
	return to_int(xl)

問題文を言い換えると以下のようになる。なお、整数$n$を$d$進表記した時の$i$桁目を$n[i]$とする。(最上位桁が$n[1]$)

$1,\cdots,d-1$の順列$a$, $0,\cdots,d-1$の順列$b$が与えられる。 関数$f(x)$は以下のような関数である。

$x$を$d$進表記した時、$f(x)$は$x$と同じ桁数の整数を返し、その時の$i$桁目$f(x)[i]$は以下のようになる。

$$ \begin{align}f(x)[i] =\left{\begin{array}{ll} a_{x[i]}&\text{if }i=1\ b_{x[i]}&\text{otherwise} \end{array} \right. \end{align} $$ $m$に何回$f$を適用すると$k$になるか求めよ

ここから解法

  • $a, b$は順列なので、各桁の変換は有向閉路上の移動とみなすことができる。

  • 各$k[i], m[i]$が同じ閉路上にあれば、それらが一致する$f$の適用回数は閉路のサイズを$c$として$ct+k\quad(t\geq1)$のような形で表すことができる。

  • このような式が桁数分並び、それらすべてに当てはまる適用回数を求めればいい。これはGarnerのアルゴリズム(ACLのCRT)で求めることができる。

実装例

最初の$d$進に変換するところで多倍長整数がないと面倒なのでJavaで通した

import java.util.Scanner;
import java.util.ArrayList;
import java.math.BigInteger;
import java.util.Collections;

public class Main {

    private static long[] inv_gcd(long a, long b) { // 拡張ユークリッド互除法
        a %= b;
        if (a < b)
            a += b;
        if (a == 0) {
            return new long[] { b, 0 };
        }

        long s = b, t = a;
        long m0 = 0, m1 = 1;

        while (t > 0) {
            long u = s / t;
            s -= t * u;
            m0 -= m1 * u;
            long tmp = s;
            s = t;
            t = tmp;
            tmp = m0;
            m0 = m1;
            m1 = tmp;
        }
        if (m0 < 0)
            m0 += b / s;
        return new long[] { s, m0 };
    }

    private static long[] crt(long[] r, long[] m) { // ほぼACLと同じ
        int n = r.length;
        long r0 = 0, m0 = 1;
        for (int i = 0; i < n; i++) {
            long r1 = r[i] % m[i];
            if (r1 < 0)
                r1 += m[i];
            long m1 = m[i];
            if (m0 < m1) {
                long tmp = r0;
                r0 = r1;
                r1 = tmp;
                tmp = m0;
                m0 = m1;
                m1 = tmp;
            }
            if (m0 % m1 == 0) {
                if (r0 % m1 != r1)
                    return new long[] { 0, 0 };
                continue;
            }
            long[] inv = inv_gcd(m0, m1);
            long g = inv[0], im = inv[1];

            long u1 = (m1 / g);
            if ((r1 - r0) % g != 0)
                return new long[] { 0, 0 };

            long x = (r1 - r0) / g % u1 * im % u1;

            r0 += x * m0;
            m0 *= u1;
            if (r0 < 0)
                r0 += m0;
        }
        return new long[] { r0, m0 };
    }

    private static long solve(int d, ArrayList<Integer> a, ArrayList<Integer> b, BigInteger m, BigInteger k) {
        // m,kをd進数に変換
        ArrayList<Integer> ml = new ArrayList<Integer>();
        ArrayList<Integer> kl = new ArrayList<Integer>();
        while (m.compareTo(BigInteger.ZERO) > 0) {
            ml.add(m.mod(BigInteger.valueOf(d)).intValue());
            m = m.divide(BigInteger.valueOf(d));
        }
        Collections.reverse(ml);
        while (k.compareTo(BigInteger.ZERO) > 0) {
            kl.add(k.mod(BigInteger.valueOf(d)).intValue());
            k = k.divide(BigInteger.valueOf(d));
        }
        Collections.reverse(kl);

        // 桁数が異なる場合は不可能
        if (ml.size() != kl.size()) {
            return -1;
        }

        // 1桁目のループ周期を求める。 x0*t+y0回fを適用すると1桁目がkと一致する
        ArrayList<Boolean> used0 = new ArrayList<Boolean>(d);
        for (int i = 0; i < d; i++)
            used0.add(false);
        int x0 = 0, y0 = -1;
        int ml0 = ml.get(0);
        while (true) {
            if (ml0 == kl.get(0) && y0 == -1)
                y0 = x0;
            ml0 = a.get(ml0);
            if (used0.get(ml0)) {
                break;
            }
            x0++;
            used0.set(ml0, true);
        }

        // 2桁目以降について、順列を閉路に分解、各閉路の情報を求める
        ArrayList<Integer> cycleId = new ArrayList<Integer>(d); // 各数がどの閉路に属するか
        ArrayList<Integer> cycleSize = new ArrayList<Integer>(); // 各閉路のサイズ
        ArrayList<Integer> index = new ArrayList<Integer>(d); // 各数が閉路内で何番目か
        for (int i = 0; i < d; i++) {
            cycleId.add(-1);
            index.add(-1);
        }
        for (int i = 0; i < d; i++) {
            if (cycleId.get(i) != -1)
                continue;
            int p = i;
            int cnt = 0;
            while (cycleId.get(p) == -1) {
                cycleId.set(p, cycleSize.size());
                index.set(p, cnt);
                p = b.get(p);
                cnt++;
            }
            cycleSize.add(cnt);
        }

        // 2桁目以降の各桁の数字だけをkと一致させるために必要なfの適用回数・周期を求める
        boolean ok = used0.get(kl.get(0));
        ArrayList<Long> ra = new ArrayList<Long>(); // 何回fを適用すれば一致するか
        ArrayList<Long> ma = new ArrayList<Long>(); // その周期
        ra.add((long) y0);
        ma.add((long) x0);

        for (int i = 1; i < ml.size(); i++) {
            if (cycleId.get(ml.get(i)) != cycleId.get(kl.get(i))) { // 何回fを適用してもその桁が一致しない場合
                ok = false;
                break;
            }
            int num = index.get(kl.get(i)) - index.get(ml.get(i));
            if (num < 0)
                num += cycleSize.get(cycleId.get(ml.get(i)));
            ra.add((long) num);
            ma.add((long) cycleSize.get(cycleId.get(ml.get(i))));
        }

        if (!ok) {
            return -1;
        }

        // CRTで解を求める
        long[] _ra = new long[ra.size()]; // arrayに変換
        long[] _ma = new long[ma.size()];
        for (int i = 0; i < ra.size(); i++) {
            _ra[i] = ra.get(i);
            _ma[i] = ma.get(i);
        }
        long[] ans = crt(_ra, _ma);

        if (ans[0] == 0 && ans[1] == 0) {
            return -1;
        } else {
            return ans[0];
        }
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);

        while (true) {
            // 入力ここから
            int d = scanner.nextInt();
            if (d == -1)
                break;
            ArrayList<Integer> a = new ArrayList<Integer>(d);
            ArrayList<Integer> b = new ArrayList<Integer>(d);
            a.add(0);
            for (int i = 1; i < d; i++) {
                a.add(scanner.nextInt());
            }
            for (int i = 0; i < d; i++) {
                b.add(scanner.nextInt());
            }
            BigInteger m = scanner.nextBigInteger();
            BigInteger k = scanner.nextBigInteger();
            // 入力ここまで

            long ans = solve(d, a, b, m, k);
            if (ans == -1) {
                System.out.println("NO");
            } else {
                System.out.println(ans);
            }
        }

        scanner.close();
    }
}