[BZOJ2588]Spoj 10628. Count on a Tree

LCA

这题经过seter重做数据后,和spoj原题很大的不同是:强制在线。也就是说求LCA不能再用各种离线算法(比如Tarjan)。在线求LCA的算法我能想到的有两个:一个是DFS序+RMQ,一个是倍增算法。因为几乎没打过倍增算法,所以就用倍增算法做了。

倍增求LCA的主要思路是:求出i点往上2^j个父亲parent[i][j],还有i点的深度。转移很愉悦parent[i][j] = parent[parent[i][j-1]][j-1]。然后查询的时候就先让两个点u, v爬到同一个高度,然后再从同一个高度爬到LCA。

可持久化线段树

这题竟然是可持久化线段树,我一开始的时候,还以为是树链剖分之类,后来Orz cxjyxx_me和 Orz Yangyue之后发现这个可以用可持久化线段树做。(询问区间第k大肯定是建立权值线段树,对于原始的区间第k大参见:从区间第k大讲起

具体的做法就是先求出一个特殊的DFS序:当一个点入栈的时候在可持久化线段树上新建版本+u,在出栈的时候建立版本-u。这个时候可以发现,对于每个点u,我们找到它那个+u的版本,如果直接从这个版本走下去,会发现恰好是u这个点到根(一般就当作1了)这条链上的所有数。因此,对于查询路径(u, v),我们可以用版本+u + 版本+v - 版本+lca - 版本+parent[lca]来表示(u, v)路径上的所有数。然后就很愉悦的求第k大了。

代码

写起来还是很愉悦的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <queue>
#include <vector>
#include <cstdio>
#include <algorithm>
using std::vector;

const int MAXN = 100002, MAXDEEP = 20;
struct Edge
{
  int v;
  Edge *next;
} g[MAXN*2], *header[MAXN];
void AddEdge(const int x, const int y)
{
  static int LinkSize;
  Edge* const node = g+(LinkSize++);
  node->v = y;
  node->next = header[x];
  header[x] = node;
}
int n, m, a[MAXN], p[MAXN][MAXDEEP], deep[MAXN];
vector<int> v;
void GetParent()
{
  std::queue<int> Q;
  for (Q.push(1); !Q.empty(); Q.pop())
  {
    const int u = Q.front();
    deep[u] = deep[p[u][0]]+1;
    for (int i = 1; i < MAXDEEP; ++i)
      p[u][i] = p[p[u][i-1]][i-1];
    for (Edge *e = header[u]; e; e = e->next)
      if (p[u][0] != e->v)
        p[e->v][0] = u, Q.push(e->v);
  }
}
int GetLCA(int u, int v)
{
  if (deep[u] < deep[v]) std::swap(u, v);
  for (int i; deep[u] != deep[v]; u = p[u][i-1])
    for (i = 1; deep[p[u][i]] >= deep[v]; ++i);
  for (int i; u != v; u = p[u][i-1], v = p[v][i-1])
    for (i = 1; p[u][i] != p[v][i]; ++i);
  return v;
}
struct Node
{
  int sum;
  Node *lef, *rig;
} _memory[MAXN*MAXDEEP*2], *root[MAXN*2+1];
int _memory_size, _st, _ed, _x, dfn[MAXN*2];
Node* _new_archive(Node* const pre, const int lef, const int rig)
{
  Node* const node = _memory+(_memory_size++);
  node->sum = pre->sum + _x;
  node->lef = pre->lef;
  node->rig = pre->rig;
  if (lef == rig) return node;
  const int mid = (lef+rig)/2;
  if (_st <= mid) node->lef = _new_archive(pre->lef, lef, mid);
  if (mid+1 <= _ed) node->rig = _new_archive(pre->rig, mid+1, rig);
  node->sum = node->lef->sum + node->rig->sum;
  return node;
}
int _query(Node* const pre, Node* const lca, Node* const u, Node* const v, const int lef, const int rig)
{
  if (lef == rig) return lef;
  const int mid = (lef+rig)/2;
  const int cnt = u->lef->sum + v->lef->sum - lca->lef->sum - pre->lef->sum;
  if (_x <= cnt) return _query(pre->lef, lca->lef, u->lef, v->lef, lef, mid);
  _x -= cnt;
  return _query(pre->rig, lca->rig, u->rig, v->rig, mid+1, rig);
}
Node* NewArchive(Node* const pre, const int pos, const int delta)
{
  _st = _ed = pos, _x = delta;
  return _new_archive(pre, 1, n);
}
int Query(const int pre, const int lca, const int u, const int v, const int k)
{
  _x = k;
  return _query(root[dfn[pre]], root[dfn[lca]], root[dfn[u]], root[dfn[v]], 1, n);
}
Node* Build(const int lef, const int rig)
{
  Node* const node = _memory+(_memory_size++);
  node->sum = 0;
  if (lef == rig) return node;
  const int mid = (lef+rig)/2;
  node->lef = Build(lef, mid);
  node->rig = Build(mid+1, rig);
  return node;
}
void DFSBuild(const int u)
{
  static int version;
  dfn[u] = ++version;
  root[version] = NewArchive(root[version-1], a[u], +1);
  for (Edge *e = header[u]; e; e = e->next)
    if (e->v != p[u][0])
      DFSBuild(e->v);
  ++version;
  root[version] = NewArchive(root[version-1], a[u], -1);
}
int main()
{
  scanf("%d%d", &n, &m);
  for (int i = 1; i <= n; ++i)
  {
    scanf("%d", a+i);
    v.push_back(a[i]);
  }
  std::sort(v.begin(), v.end());
  v.resize(std::unique(v.begin(), v.end())-v.begin());
  for (int i = 1; i <= n; ++i)
    a[i] = std::lower_bound(v.begin(), v.end(), a[i])-v.begin()+1;
  for (int i = 1, x, y; i < n; ++i)
  {
    scanf("%d%d", &x, &y);
    AddEdge(x, y);
    AddEdge(y, x);
  }
  GetParent();
  root[0] = Build(1, n);
  DFSBuild(1);
  for (int i = 0, last = 0, x, y, k; i < m; ++i)
  {
    scanf("%d%d%d", &x, &y, &k);
    x ^= last;
    const int lca = GetLCA(x, y);
    if (i) printf("\n");
    printf("%d", last = v[Query(p[lca][0], lca, x, y, k)-1]);
  }
}

Comments