Rabbit Lunch (AOJ 2374)

http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=2374

問題（問題文は上記の AOJ のページを参照）

解法

詳細は忘れてしまったが、たしか値の大きいキウイから greedy に取っていきつつ、取った分の値を後からまとめてずらす（下のコードでは `dif` というオフセットを使い、保持中のグループの値を carrot 1 つごとに 1 ずつ遅延して減らす）ことで解けた気がする。
下のコードは、たしか m-judge では TLE したが、AOJ では通った（AC）。

#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <iostream>
#include <math.h>
#include <assert.h>
#include <vector>

using namespace std;

// Short integer aliases.
typedef unsigned long long ull;
typedef unsigned int uint;
typedef long long ll;

// Floating-point helpers; not referenced by the code in this file.
static const double PI = acos(-1.0);
static const double EPS = 1e-9;

// Counting loops: REP walks [0, cnt), FOR walks [from, to), FOREQ walks [from, to].
#define REP(var, cnt) for (int var = 0; var < (int)(cnt); var++)
#define FOR(var, from, to) for (int var = (from); var < (int)(to); var++)
#define FOREQ(var, from, to) for (int var = (from); var <= (int)(to); var++)
// Walks a container `cont` with an iterator named `iter` (GNU __typeof).
#define FORIT(iter, cont) for (__typeof((cont).begin())iter = (cont).begin(); iter != (cont).end(); iter++)
// Sets every byte of `arr` to `byte`; `arr` must be an array, not a pointer,
// because the size comes from sizeof.
#define MEMSET(arr, byte) memset((arr), byte, sizeof(arr))

// Size of every work buffer below; it must exceed max(M, N) so that the
// value-indexed histograms in main stay in bounds.
static const int kBufferSize = 3000000;

// Carrot values in ascending order (expanded from carrotCnt in main).
int carrot[kBufferSize];
// Scratch array for kiwi values; the greedy in main works from kiwiCnt instead.
int kiwi[kBufferSize];
// carrotCnt[v]: number of carrots whose value is v.
int carrotCnt[kBufferSize];
// kiwiCnt[v]: number of kiwi types with value v that the greedy in main is
// not currently holding on `stack`.
int kiwiCnt[kBufferSize];
// Groups held by the greedy: first = number of kiwi types, second = tag
// (value + dif at push time, see stackPush).
std::pair<int, int> stack[kBufferSize];
// Index of the top entry of `stack`; -1 means the stack is empty.
int head = -1;
// Offset subtracted from tags; main increments it once per carrot.
int dif = 0;

// Pushes a group of `kiwi` kiwi types tagged with `cnt` onto the global
// `stack`. If the current top entry already carries the same tag, the count
// is added to that entry instead of creating a new one, so two neighbouring
// entries never share a tag. (Inside this function the parameter `kiwi`
// shadows the global array of the same name.)
inline void stackPush(int kiwi, int cnt) {
  const bool mergeIntoTop = head != -1 && stack[head].second == cnt;
  if (mergeIntoTop) {
    stack[head].first += kiwi;
    return;
  }
  stack[++head] = make_pair(kiwi, cnt);
}

// Entry point. Reads M, N and the generator parameters m0, md, n0, nd. It
// builds the carrot and kiwi value histograms from the recurrences below,
// runs the greedy, and prints the resulting total.
//
// State used by the greedy loop:
//   carrot[0..M)  carrot values in ascending order.
//   kiwiCnt[v]    number of kiwi types whose current value is v and that are
//                 not held on `stack`.
//   index         largest value still scanned in kiwiCnt; groups above it
//                 have been moved onto `stack`.
//   sum           total number of kiwi types currently held on `stack`.
//   stack, dif    held groups, tagged with (value + dif) via stackPush. `dif`
//                 grows by one per carrot, so a group's effective value is
//                 (tag - dif). In other words, every held group loses one unit
//                 per carrot without being touched.
int main() {
  // All work arrays are globals and therefore already zero-initialized.
  ll M, N, m0, n0, md, nd;
  // Refuse to run on malformed input instead of using uninitialized values.
  if (scanf("%lld %lld %lld %lld %lld %lld", &M, &N, &m0, &md, &n0, &nd) != 6) {
    return 1;
  }

  // Histogram of the M carrot values. Every value after the first (m0) is
  // reduced modulo N + 1, so indices stay in [0, N] (assuming m0 <= N).
  ll prev = m0;
  REP(i, M) {
    carrotCnt[prev]++;
    prev = (prev * 58 + md) % (N + 1);
  }
  // Histogram of the N kiwi values; indices stay in [0, M] (assuming n0 <= M).
  prev = n0;
  REP(i, N) {
    kiwiCnt[prev]++;
    prev = (prev * 58 + nd) % (M + 1);
  }

  // Counting sort: expand the carrot histogram into ascending order. The kiwi
  // side is consumed directly from the kiwiCnt histogram, so it is not
  // expanded into a sorted array (doing so was dead work).
  int m = 0;
  REP(i, N + 1) {
    REP (j, carrotCnt[i]) { carrot[m++] = i; }
  }
  //assert(m == M);

  ll ans = 0;
  int sum = 0;
  int index = M;
  REP(iter, M) {
    int rest = 0;
    int use = 0;
    // Move kiwi groups onto the stack, largest value first, until the held
    // total reaches this carrot's value or no positive value is left.
    while (sum < carrot[iter] && index > 0) {
      if (sum + kiwiCnt[index] > carrot[iter]) {
        // Taking the whole group would overshoot. Take only `use` of its
        // types, count them now, and lower them to value index - 1 directly
        // instead of holding them on the stack.
        rest = sum + kiwiCnt[index] - carrot[iter];
        use = kiwiCnt[index] - rest;
        ans += use;
        kiwiCnt[index - 1] += use;
        kiwiCnt[index] -= use;
        break;
      } else {
        // Hold the whole group, tagged so that (tag - dif) == index.
        sum += kiwiCnt[index];
        stackPush(kiwiCnt[index], index + dif);
      }
      kiwiCnt[index] = 0;
      index--;
    }
    // Expected invariants (asserts left disabled for speed): every held
    // group sits above `index`, and the held total does not exceed the
    // carrot's value.
    //assert(head == -1 || stack[head].second - dif > index);
    //assert(sum <= carrot[iter]);

    // Every held kiwi type is used once more for this carrot.
    ans += sum;
    // Lazily lower the effective value of every held group by one.
    dif++;
    // The top entry is the most recently pushed group, which has the
    // smallest effective value. If that value has dropped to `index` (or to
    // 0), hand the group back to the kiwiCnt histogram.
    if (head >= 0 && (stack[head].second - dif == 0 || stack[head].second - dif == index)) {
      sum -= stack[head].first;
      kiwiCnt[index] += stack[head].first;
      head--;
    }
  }
  printf("%lld\n", ans);
}