[BZOJ2876][Noi2012]骑行川藏

这题如果会拉格朗日乘数法的话,那就是裸题了。首先大概讲一下这题的意思。给定s[], k[], v'[], E,求一组v[],满足E = sum{ k[i]*s[i]*(v[i]-v'[i])^2 },并且使得t = sum{ s[i]/v[i] }最小。

很明显这个就是一个多元函数求最值的问题,拉格朗日乘数法就是解决这种问题的方法。

拉格朗日乘数法的核心是:对于多元函数f和约束方程g = 0,存在一个负数λ,使得∇f = λ ∇g,其中∇ff的梯度向量,也就是对每个元取偏导构成的向量。所谓偏导,就是把某一个元看作是变量,其他的当作常量,求导数。然后,通过方程组∇f = λ ∇g以及g = 0就可以解出取最值的时候,所有变量的值了。

对于这题,可以得到:

-s1 / v1^2 = λ * 2 * k1 * s1 * (v1 - v'1)
-s2 / v2^2 = λ * 2 * k2 * s2 * (v2 - v'2)
...

sum{ ki * si * (vi - v'i) ^ 2 } = E

然后解这个方程组就可以解出viλ。但是解出来之后能干嘛呢?

我们观察到,如果λ变大(注意,λ是负数,也就是绝对值变小),那么vi就会变大,接下来就会导致所需要的能量变大。也就是说,我们可以二分λ,然后解出vi,计算sum{ ki * si * (vi - v'i) ^ 2 }直到等于E

vi的话,可以用二分也可以用牛顿法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <cmath>
#include <cstdio>
#include <algorithm>

inline double sqr(const double x) { return x*x; }
inline double cube(const double x) { return x*x*x; }
struct quad
{ //Ax^3 + Bx^2 + Cx + D
  double A, B, C, D;
  quad(const double a, const double b, const double c, const double d): A(a), B(b), C(c), D(d) { }
  double operator()(const double x) const { return A*cube(x)+B*sqr(x)+C*x+D; }
};
int n;
double E, lef = -1e3, rig, s[10000], k[10000], x[10000], v[10000], maxv[10000];
inline double solve(const quad& f, double x)
{
  const quad d(0, 3*f.A, 2*f.B, f.C);
  double x0;
  do x0 = x, x -= f(x)/d(x);
  while (std::fabs(x-x0) > 1e-12);
  return x;
}
inline double check(const double lambda)
{
  double sum = 0;
  for (int i = 0; i < n; ++i)
  {
    x[i] = solve(quad(2*lambda*k[i], -2*lambda*k[i]*v[i], 0, 1), (std::max(v[i], .0)+maxv[i])/2);
    sum += k[i]*s[i]*sqr(x[i]-v[i]);
  }
  return sum;
}
int main()
{
  scanf("%d%lf", &n, &E);
  for (int i = 0; i < n; ++i)
  {
    scanf("%lf%lf%lf", s+i, k+i, v+i);
    maxv[i] = s[i] > 1e-12 ? std::sqrt(E/(k[i]*s[i]))+v[i] : 0;
    rig = std::min(rig, -1/(2*k[i]*(maxv[i]-v[i])*sqr(maxv[i])));
  }
  while (rig-lef > 1e-12)
  {
    const double mid = (lef+rig)/2;
    if (check(mid) >= E) rig = mid;
    else lef = mid;
  }
  double ans = 0;
  for (int i = 0; i < n; ++i) ans += s[i]/x[i];
  printf("%.10f", ans);
}

Comments