其实树剖的部分没难度,麻烦在最后一段的区间合并
...
啊啊啊等会再说
我们先看看怎么定义一个区间:
struct Node
{
int l, r, c[2], cnt;//l, r仅在线段树中使用
//c:左右两端颜色, cnt:颜色总数
bool t; //t:是否推平,仅在线段树中使用
inline Node operator+(const Node &that) const
{//合并两个区间,左加右,this在链的前部分
Node ret;
ret.l = l, ret.r = that.r;//仅在线段树中有意义
ret.cnt = cnt + that.cnt - (c[1] == that.c[0]);
ret.c[0] = c[0], ret.c[1] = that.c[1];
ret.t = 0;
return ret;
}
inline int mid(void) { return (l + r) >> 1; }
inline void assign(int color) { t = cnt = 1, c[0] = c[1] = color; }
} nd[4 * N];
我想合并应该没问题:
xxxooo|oooxxx = 2 + 2 - 1
xxxooo|xxxooo = 2 + 2
主要看左区间最右边和右区间最左边是否一样
我想推平应该没问题,不涉及复杂合并(合并只在线段树中完成)
void set(int l, int r, int c)
{
while (top[l] != top[r])
{
if (depth[top[l]] < depth[top[r]])
l ^= r ^= l ^= r;
SegTree::set(dfn[top[l]], dfn[l], c);
l = fa[top[l]];
}
if (depth[l] > depth[r])
l ^= r ^= l ^= r;
SegTree::set(dfn[l], dfn[r], c);
}
然后是玄学的get:
int get(int l, int r)
{
static SegTree::Node pnode[2], node;
pnode[0].cnt = pnode[1].cnt = 0;
pnode[0].c[0] = pnode[0].c[1] = pnode[1].c[0] = pnode[1].c[1] = -1;
//初始化区间。由于离开线段树,区间的l,r,t就没意义了,于是忽略。
bool left = true;//我们要保证这个函数的l, r对应的是正确的临时区间
//当然也可以不用标志,直接交换node,不过常数较大
while (top[l] != top[r])
{
if (depth[top[l]] < depth[top[r]])
l ^= r ^= l ^= r, left = !left;//交换l,r同时也要交换标志。
pnode[left] = SegTree::get(dfn[top[l]], dfn[l]) + pnode[left];
//链靠上(浅),在线段树靠左
l = fa[top[l]];
}
if (depth[l] > depth[r])
{
//std::swap(pnode[0], pnode[1]);//不要学我,这就是没想清楚就动键盘的下场
l ^= r ^= l ^= r;
left = !left;
}
//这时候我们知道了l对应的区间靠上(浅),r对应的区间靠下(深)
std::swap(pnode[left].c[0], pnode[left].c[1]);//注意合并过程,默认右边颜色为-1
//由于要理成一条完整的链,我们让浅的临时节点翻转,使其左端颜色为-1
node = SegTree::get(dfn[l], dfn[r]);
node = pnode[left] + node;//[-1 浅][浅 深]
return (node + pnode[!left]).cnt;//[-1 深][深 -1]
}
完整代码:
#include <cstdio>
#include <cctype>
#include <cassert>
#include <algorithm>
struct Read
{
//...
} read;
const int N = 100000 + 10;
int n, origin[N], dfn[N], adfn[N], size[N], top[N], love[N], depth[N], fa[N];
namespace SegTree
{
struct Node
{
int l, r, c[2], cnt;
bool t;
inline Node operator+(const Node &that) const
{
Node ret;
ret.l = l, ret.r = that.r;
ret.cnt = cnt + that.cnt - (c[1] == that.c[0]);
ret.c[0] = c[0], ret.c[1] = that.c[1];
ret.t = 0;
return ret;
}
inline int mid(void) { return (l + r) >> 1; }
inline void assign(int color) { t = cnt = 1, c[0] = c[1] = color; }
} nd[4 * N];
#define LV (p << 1)
#define RV (LV + 1)
#define P (nd[p])
#define L (nd[LV])
#define R (nd[RV])
inline void build(int p, int l, int r)
{
if (l == r)
{
P.l = l, P.r = r;
P.c[0] = P.c[1] = origin[adfn[l]];
P.cnt = 1;
return;
}
int mid = (l + r) >> 1;
build(LV, l, mid);
build(RV, mid + 1, r);
P = L + R;
}
inline void pushdown(int p)
{
if (P.t)
{
P.t = false;
L.assign(P.c[0]);
R.assign(P.c[1]);
}
}
int l, r, c;
inline Node get(int p)
{
if (l <= P.l && P.r <= r)
return P;
pushdown(p);
int mid = P.mid();
if (l <= mid && mid < r)
return get(LV) + get(RV);
if (l <= mid)
return get(LV);
return get(RV);
}
inline Node get(int l_, int r_)
{
l = l_, r = r_;
return get(1);
}
inline void set(int p)
{
if (l <= P.l && P.r <= r)
{
P.assign(c);
return;
}
pushdown(p);
int mid = P.mid();
if (l <= mid)
set(LV);
if (mid < r)
set(RV);
P = L + R;
}
inline void set(int l_, int r_, int c_)
{
l = l_, r = r_, c = c_;
set(1);
}
}; // namespace SegTree
struct Link
{
int p;
Link *next;
} * head[N], poool[N * 2], *pool = poool;
inline void connect(int u, int v)
{
pool->p = v;
pool->next = head[u];
head[u] = pool++;
}
void dfs1(int p)
{
size[p] = 1;
for (Link *now = head[p]; now; now = now->next)
if (size[now->p] == 0)
{
fa[now->p] = p;
depth[now->p] = depth[p] + 1;
dfs1(now->p), size[p] += size[now->p];
if (size[now->p] > size[love[p]])
love[p] = now->p;
}
}
void dfs2(int p)
{
dfn[p] = ++dfn[0];
adfn[dfn[p]] = p;
if (top[p] == 0)
top[p] = p;
if (love[p] == 0)
return;
top[love[p]] = top[p];
dfs2(love[p]);
for (Link *now = head[p]; now; now = now->next)
if (now->p != fa[p] && now->p != love[p])
dfs2(now->p);
}
void set(int l, int r, int c)
{
while (top[l] != top[r])
{
if (depth[top[l]] < depth[top[r]])
l ^= r ^= l ^= r;
SegTree::set(dfn[top[l]], dfn[l], c);
l = fa[top[l]];
}
if (depth[l] > depth[r])
l ^= r ^= l ^= r;
SegTree::set(dfn[l], dfn[r], c);
}
int get(int l, int r)
{
static SegTree::Node pnode[2], node;
pnode[0].cnt = pnode[1].cnt = 0;
pnode[0].c[0] = pnode[0].c[1] = pnode[1].c[0] = pnode[1].c[1] = -1;
bool left = true;
while (top[l] != top[r])
{
if (depth[top[l]] < depth[top[r]])
l ^= r ^= l ^= r, left = !left;
pnode[left] = SegTree::get(dfn[top[l]], dfn[l]) + pnode[left];
l = fa[top[l]];
}
if (depth[l] > depth[r])
{
//std::swap(pnode[0], pnode[1]);
l ^= r ^= l ^= r;
left = !left;
}
std::swap(pnode[left].c[0], pnode[left].c[1]);
node = SegTree::get(dfn[l], dfn[r]);
node = pnode[left] + node;
return (node + pnode[!left]).cnt;
}
int main(void)
{
n = read;
int m = read;
for (int i = 1; i <= n; ++i)
origin[i] = read;
for (int i = 1; i != n; ++i)
{
int x, y;
scanf("%d %d", &x, &y);
connect(x, y);
connect(y, x);
}
dfs1(1);
dfs2(1);
SegTree::build(1, 1, n);
while (m--)
{
static char buf[4];
scanf("%s", buf);
if (*buf == 'C')
{
int l = read, r = read, x = read;
set(l, r, x);
}
else
{
int l = read, r = read;
printf("%d\n", get(l, r));
}
}
}