[BZOJ2142]礼物

这题很容易看出来求n!/(A1! * A2! * ... * An! * (n-ΣAi)!) % M,但是模数太大,用一种比较特殊的方法,这个也是对付阶乘取模的一个通用方法。对于比较特殊的情况,可以参考[BZOJ2313]分组

大体思路

  • 把模数M拆成质因子相乘的形式p1^c1 * p2^c2 * ... * pk^ck,然后对每一项pi^ci求答案,最后中国剩余定理合并。
  • 把阶乘拆成n! = A * pi^B的形式,其中A为除去所有pi因子之后mod pi^ci的结果。这样做除法的时候,就可以指数相减,A再用扩展欧几里得算法搞定。

拆阶乘

n! = (n%p^c)! * ((p^c)!^(n/p^c)) * (n/p)!

前面两项可以预处理fac[i]数组和快速幂解决,后面那一项递归。需要注意的是,算fac[i]的时候,要把所有的p的倍数去掉。也就是说:

fac[i] = fac[i-1] * (i % p ? i : 1)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <cstdio>
#include <cstring>
#include <algorithm>

typedef std::pair<long long, long long> Pair;
struct Triple
{
  long long x, y, z;
  Triple() { }
  Triple(const long long a, const long long b, const long long c): x(a), y(b), z(c) { }
};
Triple ExtendedGCD(const long long a, const long long b)
{
  if (!b) return Triple(1, 0, a);
  const Triple last(ExtendedGCD(b, a%b));
  return Triple(last.y, last.x - a / b * last.y, last.z);
}
long long Power(long long base, long long k, const long long mod)
{
  long long ans = 1;
  for (; k; k >>= 1)
  {
    if (k & 1) (ans *= base) %= mod;
    (base *= base) %= mod;
  }
  return ans;
}

int N, M, P, sum, facs, a[5];
long long p[100], pc[100], fac[100][100001], m[100];
Pair FnModPc(long long n, const int i)
{
  if (n == 0) return Pair(1, 0);
  const Pair last(FnModPc(n/p[i], i));
  return Pair(Power(fac[i][pc[i]-1], n/pc[i], pc[i]) * fac[i][n%pc[i]] % pc[i] * last.first % pc[i], n/p[i] + last.second);
}
inline void AdivB(Pair& A, const Pair& B, const int i)
{
  if (A.second < B.second) { A.first = A.second = 0; return; }
  const Triple gcd(ExtendedGCD(B.first, pc[i]));
  A.first = gcd.x * A.first % pc[i];
  A.second -= B.second;
}
void Factor(long long& x, const long long d)
{
  p[facs] = d, pc[facs] = 1, fac[facs][0] = 1;
  for (; x % d == 0; x /= d) pc[facs] *= d;
  for (int i = 1; i <= pc[facs]; ++i)
    fac[facs][i] = fac[facs][i-1] * (i % p[facs] ? i : 1) % pc[facs];
  ++facs;
}

int main()
{
  scanf("%d%d%d", &P, &N, &M);
  for (int i = 0; i < M; sum += a[i++])
    scanf("%d", a+i);
  if (sum > N)
  {
    printf("Impossible");
    return 0;
  }
  long long x = P;
  for (long long i = 2; i*i <= x; ++i)
    if (x % i == 0)
      Factor(x, i);
  if (x != 1)
    Factor(x, x);
  for (int i = 0; i < facs; ++i)
  {
    Pair A(FnModPc(N, i));
    AdivB(A, FnModPc(N-sum, i), i);
    for (int j = 0; j < M; ++j)
      AdivB(A, FnModPc(a[j], i), i);
    m[i] = A.first * Power(p[i], A.second, pc[i]) % pc[i];
  }
  long long ans = 0;
  for (int i = 0; i < facs; ++i)
  {
    const long long Mi = P / pc[i];
    const Triple gcd(ExtendedGCD(Mi, pc[i]));
    ans += (Mi * gcd.x * m[i]) % P;
  }
  printf("%lld", (ans+P)%P);
}

Comments