<题解>[SDOI2011]染色

洛谷的题目链接

其实树剖的部分没难度,麻烦在最后一段的区间合并
...
啊啊啊等会再说

我们先看看怎么定义一个区间:

struct Node
{
    int l, r, c[2], cnt;//l, r仅在线段树中使用
                        //c:左右两端颜色, cnt:颜色总数
    bool t;             //t:是否推平,仅在线段树中使用
    inline Node operator+(const Node &that) const
    {//合并两个区间,左加右,this在链的前部分
        Node ret;
        ret.l = l, ret.r = that.r;//仅在线段树中有意义
        ret.cnt = cnt + that.cnt - (c[1] == that.c[0]);
        ret.c[0] = c[0], ret.c[1] = that.c[1];
        ret.t = 0;
        return ret;
    }
    inline int mid(void) { return (l + r) >> 1; }
    inline void assign(int color) { t = cnt = 1, c[0] = c[1] = color; }
} nd[4 * N];

我想合并应该没问题:

xxxooo|oooxxx = 2 + 2 - 1
xxxooo|xxxooo = 2 + 2
主要看左区间最右边和右区间最左边是否一样

我想推平应该没问题,不涉及复杂合并(合并只在线段树中完成)

void set(int l, int r, int c)
{
    while (top[l] != top[r])
    {
        if (depth[top[l]] < depth[top[r]])
            l ^= r ^= l ^= r;
        SegTree::set(dfn[top[l]], dfn[l], c);
        l = fa[top[l]];
    }
    if (depth[l] > depth[r])
        l ^= r ^= l ^= r;
    SegTree::set(dfn[l], dfn[r], c);
}

然后是玄学的get:

int get(int l, int r)
{
    static SegTree::Node pnode[2], node;
    pnode[0].cnt = pnode[1].cnt = 0;
    pnode[0].c[0] = pnode[0].c[1] = pnode[1].c[0] = pnode[1].c[1] = -1;
    //初始化区间。由于离开线段树,区间的l,r,t就没意义了,于是忽略。
    bool left = true;//我们要保证这个函数的l, r对应的是正确的临时区间
    //当然也可以不用标志,直接交换node,不过常数较大
    while (top[l] != top[r])
    {
        if (depth[top[l]] < depth[top[r]])
            l ^= r ^= l ^= r, left = !left;//交换l,r同时也要交换标志。
        pnode[left] = SegTree::get(dfn[top[l]], dfn[l]) + pnode[left];
        //链靠上(浅),在线段树靠左
        l = fa[top[l]];
    }
    if (depth[l] > depth[r])
    {
        //std::swap(pnode[0], pnode[1]);//不要学我,这就是没想清楚就动键盘的下场
        l ^= r ^= l ^= r;
        left = !left;
    }
    //这时候我们知道了l对应的区间靠上(浅),r对应的区间靠下(深)
    std::swap(pnode[left].c[0], pnode[left].c[1]);//注意合并过程,默认右边颜色为-1
    //由于要理成一条完整的链,我们让浅的临时节点翻转,使其左端颜色为-1

    node = SegTree::get(dfn[l], dfn[r]);
    node = pnode[left] + node;//[-1  浅][浅  深]
    return (node + pnode[!left]).cnt;//[-1  深][深  -1]
}

完整代码:

#include <cstdio>
#include <cctype>
#include <cassert>
#include <algorithm>

struct Read
{
    //...
} read;

const int N = 100000 + 10;

int n, origin[N], dfn[N], adfn[N], size[N], top[N], love[N], depth[N], fa[N];

namespace SegTree
{
struct Node
{
    int l, r, c[2], cnt;
    bool t;
    inline Node operator+(const Node &that) const
    {
        Node ret;
        ret.l = l, ret.r = that.r;
        ret.cnt = cnt + that.cnt - (c[1] == that.c[0]);
        ret.c[0] = c[0], ret.c[1] = that.c[1];
        ret.t = 0;
        return ret;
    }
    inline int mid(void) { return (l + r) >> 1; }
    inline void assign(int color) { t = cnt = 1, c[0] = c[1] = color; }
} nd[4 * N];
#define LV (p << 1)
#define RV (LV + 1)
#define P (nd[p])
#define L (nd[LV])
#define R (nd[RV])
inline void build(int p, int l, int r)
{
    if (l == r)
    {
        P.l = l, P.r = r;
        P.c[0] = P.c[1] = origin[adfn[l]];
        P.cnt = 1;
        return;
    }
    int mid = (l + r) >> 1;
    build(LV, l, mid);
    build(RV, mid + 1, r);
    P = L + R;
}
inline void pushdown(int p)
{
    if (P.t)
    {
        P.t = false;
        L.assign(P.c[0]);
        R.assign(P.c[1]);
    }
}
int l, r, c;
inline Node get(int p)
{
    if (l <= P.l && P.r <= r)
        return P;
    pushdown(p);
    int mid = P.mid();
    if (l <= mid && mid < r)
        return get(LV) + get(RV);
    if (l <= mid)
        return get(LV);
    return get(RV);
}
inline Node get(int l_, int r_)
{
    l = l_, r = r_;
    return get(1);
}
inline void set(int p)
{
    if (l <= P.l && P.r <= r)
    {
        P.assign(c);
        return;
    }
    pushdown(p);
    int mid = P.mid();
    if (l <= mid)
        set(LV);
    if (mid < r)
        set(RV);
    P = L + R;
}

inline void set(int l_, int r_, int c_)
{
    l = l_, r = r_, c = c_;
    set(1);
}
}; // namespace SegTree

struct Link
{
    int p;
    Link *next;
} * head[N], poool[N * 2], *pool = poool;

inline void connect(int u, int v)
{
    pool->p = v;
    pool->next = head[u];
    head[u] = pool++;
}

void dfs1(int p)
{
    size[p] = 1;
    for (Link *now = head[p]; now; now = now->next)
        if (size[now->p] == 0)
        {
            fa[now->p] = p;
            depth[now->p] = depth[p] + 1;
            dfs1(now->p), size[p] += size[now->p];
            if (size[now->p] > size[love[p]])
                love[p] = now->p;
        }
}

void dfs2(int p)
{
    dfn[p] = ++dfn[0];
    adfn[dfn[p]] = p;
    if (top[p] == 0)
        top[p] = p;
    if (love[p] == 0)
        return;
    top[love[p]] = top[p];
    dfs2(love[p]);
    for (Link *now = head[p]; now; now = now->next)
        if (now->p != fa[p] && now->p != love[p])
            dfs2(now->p);
}

void set(int l, int r, int c)
{
    while (top[l] != top[r])
    {
        if (depth[top[l]] < depth[top[r]])
            l ^= r ^= l ^= r;
        SegTree::set(dfn[top[l]], dfn[l], c);
        l = fa[top[l]];
    }
    if (depth[l] > depth[r])
        l ^= r ^= l ^= r;
    SegTree::set(dfn[l], dfn[r], c);
}

int get(int l, int r)
{
    static SegTree::Node pnode[2], node;
    pnode[0].cnt = pnode[1].cnt = 0;
    pnode[0].c[0] = pnode[0].c[1] = pnode[1].c[0] = pnode[1].c[1] = -1;
    bool left = true;
    while (top[l] != top[r])
    {
        if (depth[top[l]] < depth[top[r]])
            l ^= r ^= l ^= r, left = !left;
        pnode[left] = SegTree::get(dfn[top[l]], dfn[l]) + pnode[left];
        l = fa[top[l]];
    }
    if (depth[l] > depth[r])
    {
        //std::swap(pnode[0], pnode[1]);
        l ^= r ^= l ^= r;
        left = !left;
    }
    std::swap(pnode[left].c[0], pnode[left].c[1]);
    node = SegTree::get(dfn[l], dfn[r]);
    node = pnode[left] + node;
    return (node + pnode[!left]).cnt;
}

int main(void)
{
    n = read;
    int m = read;
    for (int i = 1; i <= n; ++i)
        origin[i] = read;
    for (int i = 1; i != n; ++i)
    {
        int x, y;
        scanf("%d %d", &x, &y);
        connect(x, y);
        connect(y, x);
    }
    dfs1(1);
    dfs2(1);
    SegTree::build(1, 1, n);
    while (m--)
    {
        static char buf[4];
        scanf("%s", buf);
        if (*buf == 'C')
        {
            int l = read, r = read, x = read;
            set(l, r, x);
        }
        else
        {
            int l = read, r = read;
            printf("%d\n", get(l, r));
        }
    }
}
点赞

发表评论

电子邮件地址不会被公开。必填项已用 * 标注