/*
影の魔女 (AOJ 2315)

http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=2315

問題 (Problem)

解法 (Solution)

解説参照。(See the editorial.)
*/

#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <iostream>
#include <math.h>
#include <assert.h>
#include <vector>
#include <map>

using namespace std;
typedef long long ll;
typedef unsigned int uint;
typedef unsigned long long ull;
// Tolerance used for pivot selection / zero tests in floating-point code.
static const long double EPS = 1e-9;
// Use the long double arc-cosine so PI actually carries long double precision
// (acos(-1.0) is evaluated in double and only then widened).
static const long double PI = acosl(-1.0L);

// Loop helpers: REP = [0, n), FOR = [s, n), FOREQ = [s, n], FORIT = iterate c.
#define REP(i, n) for (int i = 0; i < (int)(n); i++)
#define FOR(i, s, n) for (int i = (s); i < (int)(n); i++)
#define FOREQ(i, s, n) for (int i = (s); i <= (int)(n); i++)
#define FORIT(it, c) for (__typeof((c).begin())it = (c).begin(); it != (c).end(); it++)
// Fills every byte of the array object v with h.
#define MEMSET(v, h) memset((v), h, sizeof(v))

// Dense vector / row-major dense matrix of long double.
typedef vector<long double> Array;
typedef vector<Array> Matrix;


void PrintMatrix(const Matrix &matrix) {
  for (int y = 0; y < (int)matrix.size(); y++) {
    for (int x = 0; x < (int)matrix[y].size(); x++) {
      printf("%.3Lf ", matrix[y][x]);
    }
    puts("");
  }
}

// Solves the square linear system `matrix * ret = vect` by Gaussian
// elimination with partial pivoting followed by back substitution.
// Both arguments are taken by value because they are overwritten in place.
// If the largest remaining entry of a pivot column is below EPS, the column
// is treated as singular: elimination skips it and back substitution sets
// that unknown to 0 rather than dividing by ~0 and producing inf/NaN.
Array GaussElimination(Matrix matrix, Array vect) {
  const int n = vect.size();
  Array ret(n, 0.0);
  REP(x, n) {
    // Partial pivoting: move the row with the largest |entry| in column x up.
    int pivot = x;
    FOR(i, x + 1, n) {
      if (fabs(matrix[i][x]) - fabs(matrix[pivot][x]) > EPS) { pivot = i; }
    }
    swap(matrix[x], matrix[pivot]);
    swap(vect[x], vect[pivot]);
    if (fabs(matrix[x][x]) < EPS) { continue; }
    // Eliminate column x from every row below the pivot row.
    FOR(y, x + 1, n) {
      const long double ratio = -matrix[y][x] / matrix[x][x];
      matrix[y][x] = 0.0;
      FOR(i, x + 1, n) {
        matrix[y][i] += matrix[x][i] * ratio;
      }
      vect[y] += vect[x] * ratio;
    }
  }
  // Back substitution on the resulting upper-triangular system.
  for (int x = n - 1; x >= 0; x--) {
    if (fabs(matrix[x][x]) < EPS) {
      // No usable pivot: the unknown is not uniquely determined; pick 0.
      ret[x] = 0.0;
      continue;
    }
    long double sum = vect[x];
    for (int i = n - 1; i > x; i--) {
      sum -= ret[i] * matrix[x][i];
    }
    ret[x] = sum / matrix[x][x];
  }
  return ret;
}

// Returns the matrix product lhs * rhs, where lhs is h x in and rhs is
// in x w (row-major). An empty lhs yields an empty matrix; an empty rhs
// (valid only when in == 0) yields h rows of width 0.
Matrix mul(const Matrix &lhs, const Matrix &rhs) {
  const int h = lhs.size();
  if (h == 0) { return Matrix(); }
  const int in = lhs[0].size();
  // Check the inner dimension before touching rhs[0].
  assert((int)rhs.size() == in);
  const int w = rhs.empty() ? 0 : (int)rhs[0].size();
  Matrix ret(h, Array(w, 0.0));
  // y-i-x order walks rhs[i] and ret[y] contiguously.
  REP(y, h) REP(i, in) {
    const long double coef = lhs[y][i];
    REP(x, w) {
      ret[y][x] += coef * rhs[i][x];
    }
  }
  return ret;
}

// Raises the square matrix `base` to the non-negative integer `power` by
// binary exponentiation, using mul() for each product. power == 0 (or an
// empty matrix) returns the identity of matching size.
Matrix pow(Matrix base, ll power) {
  assert(power >= 0);
  const int h = base.size();
  if (h == 0) { return Matrix(); }
  const int w = base[0].size();
  assert(h == w);
  Matrix ret(h, Array(w, 0));
  REP(i, h) { ret[i][i] = 1; }
  while (power > 0) {
    if (power & 1) {
      ret = mul(ret, base);
    }
    power >>= 1;
    // Once the top bit is consumed, a further squaring would be discarded.
    if (power > 0) {
      base = mul(base, base);
    }
  }
  return ret;
}

// Per-test-case state, written in main() and read by naive().
// probability[i][j]: probability that the sum of i values, each uniform on
// 1..n (i dice with n faces), equals j; filled in main(). Rows are indexed up
// to k and columns up to n * k, so the fixed extents assume k < 20 and
// n * k < 210 -- NOTE(review): confirm against the problem's constraints.
long double probability[20][210];
// Start position; main() makes it non-negative before use.
ll s;
// n: faces per die (each contributes 1..n); k: number of dice summed per move.
int n, k;
// nk = n * k: the largest single move. Positions below nk are solved with a
// linear system in main(); positions >= nk never bounce.
int nk;

long double naive(const Array &lastDist) {
  long double ret = 0.0;
  map<ll, long double> dists;
  dists[s] = 1.0;
  while (!dists.empty()) {
    ll pos = dists.rbegin()->first;
    long double p = dists.rbegin()->second;
    assert(p >= -EPS);
    dists.erase(pos);
    if (pos < nk) {
      ret += lastDist[pos] * p;
      continue;
    }
    ret += p;
    FOREQ(i, 1, nk) {
      assert(pos >= i);
      dists[pos - i] += p * probability[k][i];
    }
  }
  return ret;
}

// Processes records "s n k" until EOF or a malformed record. For each record
// it prints the expected number of moves to reach position 0 from |s|. One
// move is pos -> |pos - X|, where X is the sum of k dice, each uniform on
// 1..n. For n == 1 the walk is deterministic, and -1 is printed when position
// 0 is never reached.
int main() {
  // Require all three fields: a partial read must not reuse stale n / k.
  while (scanf("%lld %d %d", &s, &n, &k) == 3) {

    // initialize
    MEMSET(probability, 0);
    nk = n * k;

    // corner case
    // Negate explicitly: abs() on a long long may bind to int abs(int) on
    // some toolchains and truncate large start positions.
    if (s < 0) { s = -s; }
    if (n == 1) {
      // Every move is exactly k. That takes s / k moves if k divides s;
      // otherwise the walk bounces between s % k and k - s % k forever.
      ll ans = -1;
      if (s % k == 0) { ans = s / k; }
      printf("%lld\n", ans);
      continue;
    }

    // Below, probability[][] is accessed at rows <= k and columns <= nk.
    assert(k < (int)(sizeof(probability) / sizeof(probability[0])));
    assert(nk < (int)(sizeof(probability[0]) / sizeof(probability[0][0])));

    // calc probability
    // probability[i][j] = P(sum of i dice == j). The support of i dice is
    // [i, i * n], so only those j are propagated. Skipping the zero entries
    // keeps every written column <= (i + 1) * n <= nk, inside the array.
    probability[0][0] = 1.0;
    REP(i, k) {
      FOREQ(j, i, i * n) {
        REP(d, n) {
          probability[i + 1][j + d + 1] += 1.0 / n * probability[i][j];
        }
      }
    }

    // calc last loop
    // lastDist[i] (0 <= i < nk) is the expected number of moves to reach 0
    // from i, where a move goes i -> |i - j| with probability probability[k][j]:
    //   E[0] = 0,   E[i] - sum_j probability[k][j] * E[|i - j|] = 1.
    Matrix matrix(nk, Array(nk, 0.0));
    Array vect(nk, 1.0);
    matrix[0][0] = 1.0;
    vect[0] = 0.0;
    FOR(i, 1, nk) {
      matrix[i][i] = 1.0;
      vect[i] = 1.0;
      FOREQ(j, 1, nk) {
        int d = abs(i - j);
        assert(d < nk);
        matrix[i][d] -= probability[k][j];
      }
    }
    Array lastDist = GaussElimination(matrix, vect);

    // calc answer
    long double ans = 0.0;
    if (s < nk) {
      ans = lastDist[s];
    } else {
      // From positions >= nk a move never bounces (j <= nk <= pos), so the
      // walk only decreases until it first drops below nk. `base` is the
      // companion matrix of g(t) = sum_j P(j) g(t - j), plus a running sum of
      // g in row/column nk. Column 0 of base^s then holds:
      //   base[i][0], i < nk : probability that the walk pos -> pos - j
      //                        started at s visits position i,
      //   base[nk][0]        : sum of those visit probabilities over 0..s-1.
      Matrix base(nk + 1, Array(nk + 1, 0.0));
      REP(i, nk) {
        base[0][i] = probability[k][i + 1];
        base[nk][i] = probability[k][i + 1];
      }
      base[0][nk] = 0.0;
      base[nk][nk] = 1.0;
      REP(i, nk - 1) {
        base[i + 1][i] = 1.0;
      }
      base = pow(base, s);
      REP(i, nk) {
        // Visits to i that came from i+1..nk-1 are already covered by
        // lastDist, so p is the probability of landing on i straight from a
        // position >= nk. The same amounts are removed from the move count.
        // The count left from base[nk][0] is then the expected number of moves
        // made from positions >= nk.
        long double p = base[i][0];
        REP(j, nk - i - 1) {
          p -= probability[k][j + 1] * base[i + j + 1][0];
          ans -= probability[k][j + 1] * base[i + j + 1][0];
        }
        ans += p * lastDist[i];
      }
      ans += base[nk][0];
    }

    printf("%.8Lf\n", ans);
  }
}