倍增/ST表

树上倍增

利用二进制的思想,二进制的数可以表示出所有十进制的数,令f[i][j]表示i的第图片说明 个祖先,i = f[i][j]表示i跳到第j个祖先上去,假如我们要求i的第k个祖先,我们不需要一步一步的往上走,利用二进制的思想,我们可以一次条多步的往上走,例如要求i的第5个祖先,5的二进制形式是101,即图片说明 那么我们先跳到i的第一个祖先t,然后t跳到t的第4个祖先,也就求到了i的第五个祖先。
树上倍增的应用

1) 求某个节点的第k个祖先
2) 求a,b的最近公共祖先LCA
......
如何求f[i][j]呢,对树根节点跑一边dfs即可

void dfs(int so,int fa)
{
    f[so][0] = fa;
    //h[so] = h[fa] + 1;
    for(int i = 1; i <= 19; i++) {
        f[so][i] = f[f[so][i-1]][i-1];//倍增,so的第2^i个祖先也就是so的第2^(i-1)个祖先的第2^(i-1)的祖先,可以画图理解。
    }
    for(int i = head[so]; ~i; i = edge[i].next){
        int to = edge[i].v;
        if(to == fa) continue;
        len[to] = len[so]+1;
        dfs(to,so);
    } 
}

1) 求某个节点的第k个祖先

int Kth(int u,int k)
{
    int b = 0;
    while(k){
        if(k & 1) u = f[u][i];
        k >>= 1;
        b++;
    }
    return u;
}

2) 求x,y的最近公共祖先LCA
图片说明
求法:首先我们要保证x,y在同一层,假设x在更深的一层,那么我们要让他上升到和y一样的高度,利用二进制的思想跳上去,我们优先往远的地方跳,如果大于了h[y],那么不跳,如果小于等于h[y],就要跳,为什么呢,因为当前祖先的深度都比y深了,而我们是从高处往低处枚举的,之后的祖先也一定比y更深,所有我们先跳上去,去判断当前祖先的祖先与y的深度。到达同一层之后,如果x==y那么直接返回x,否则的后继续寻找最近的公共祖先。看下面代码:

int lca(int x,int y)
{
        //h[x]表示节点x的深度。
    if(h[y] > h[x]) swap(x,y);// 保证x更深
    for(int i = 19; i >= 0; i--){
        if(h[f[x][i]] >= h[y]){
            x = f[x][i];
        }
    } 
    if(x == y) return x;
    for(int i = 19; i >= 0; i--){
        if(f[x][i] != f[y][i]){
            x = f[x][i];
            y = f[y][i];
        }
    }
    return f[x][0];//x的第一个祖先最近。
}

LCA模板题

https://www.luogu.com.cn/problem/P3379

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <stack>
#include <queue>
#include <string.h>
#include <cmath>
#include <bitset>
#define ll long long
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;
using namespace std;

//void print(char a[])
//{
//    printf("%d %c\n",strlen(a),a[2]);
//} 
struct edeg{
    int v,next;
}edge[2000005];

int head[1000005],f[1000005][20],h[1000005],cnt;

void add(int u,int v)
{
    edge[cnt].next = head[u];
    edge[cnt].v = v;
    head[u] = cnt++;
}

void dfs(int s,int fa)
{
    f[s][0] = fa;
    for(int i = 1; i <= 19; i++){
        f[s][i] = f[f[s][i - 1]][i - 1];
    }
    h[s] = h[fa] + 1;
    for(int i = head[s]; ~i; i = edge[i].next){
        int to = edge[i].v;
        if(to == fa) continue;
        dfs(to,s);
    } 
}

int lca(int x,int y)
{
    if(h[y] > h[x]) swap(x,y);//保证x更深。
    for(int i = 19; i >= 0; i--){
        if(h[f[x][i]] >= h[y]) x = f[x][i];
    }
    if(x == y) return x;
    for(int i = 19; i >= 0; i--){
        if(f[x][i] != f[y][i]){
            x = f[x][i];
            y = f[y][i];
        }
    }
    return f[y][0];
}

int main()
{
    int n,q,m,s;
    cnt = 0;
    memset(head,-1,sizeof(head));
    scanf("%d%d%d",&n,&m,&s);
    for(int i = 1; i < n; i++){
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    dfs(s,0);
    for(int i = 1; i <= m; i++){
        int x,y;
        scanf("%d%d",&x,&y);
        printf("%d\n",lca(x,y));
    }
    return 0;
}

https://ac.nowcoder.com/acm/contest/7009/E
定义数组len[i]为节点i距离根节点的距离。
树中任意两点间的距离:len[a] + len[b] - 2*len[lca(a,b)].
该题保证len[a] < t 并且 len[lca(b,c)] + len[b] < t即牛牛不被挨揍。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <stack>
#include <queue>
#include <string.h>
#include <cmath>
#include <bitset>
#define ll long long
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;
using namespace std;

//void print(char a[])
//{
//    printf("%d %c\n",strlen(a),a[2]);
//} 
struct edeg{
    int v,next;
}edge[2000005];

int head[500005],f[500005][20],len[500005],h[500005],cnt,s,ans = 0;
void add(int u,int v)
{
    edge[cnt].next = head[u];
    edge[cnt].v = v;
    head[u] = cnt++;
}

void dfs(int so,int fa)
{
    f[so][0] = fa;
    h[so] = h[fa] + 1;
    for(int i = 1; i <= 19; i++) {
        f[so][i] = f[f[so][i-1]][i-1];
    }
    for(int i = head[so]; ~i; i = edge[i].next){
        int to = edge[i].v;
        if(to == fa) continue;
        len[to] = len[so]+1;
        dfs(to,so);
    } 
}

int lca(int x,int y)
{
    if(h[y] > h[x]) swap(x,y);// 保证x更深
    for(int i = 19; i >= 0; i--){
        if(h[f[x][i]] >= h[y]){
            x = f[x][i];
        }
    } 
    if(x == y) return x;
    for(int i = 19; i >= 0; i--){
        if(f[x][i] != f[y][i]){
            x = f[x][i];
            y = f[y][i];
        }
    }
    return f[x][0];
}

//ll qpow(ll a,ll b)
//{
//    ll res = 1;
//    while(b){
//        if(b & 1) res = res * a % mod;
//        b >>= 1;
//        a = a * a % mod;
//    } 
//    return res;
//}
ll exgcd(int a,int b,ll &x,ll &y){
    if(b == 0){
        x = 1;
        y = 0;
        return a;
    }
    int t = exgcd(b,a%b,y,x);
    y = y - a/b*x;
    return t;
}

ll inv(int a,int mod)
{
    ll x,y;
    int t = exgcd(a,mod,x,y);
    return (x % mod + mod) % mod;
}

int main()
{
    int n,q,m;
    cnt = 0;
    memset(head,-1,sizeof(head));
    scanf("%d",&n);
    for(int i = 1; i < n; i++){
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y); 
        add(y,x);
    }
    scanf("%d",&q);
    int a,b,c,t;
    dfs(1,0);
    for(int i = 1; i <= q; i++){
        scanf("%d%d%d%d",&a,&b,&c,&t);
        int zx = lca(b,c);//t1是b,c的最近公共祖先
        int  l;
        l = len[b] + len[c] - 2 * len[zx] + len[b];
        if(len[a] < t && l < t) ans++;
    }
    ll q1 = inv(1ll*q,1ll*mod);
    ll ans1 = (1ll*ans % mod * q1 % mod);
    printf("%lld", ans1 % mod); 
    return 0;
}

城市网络:倍增

https://ac.nowcoder.com/acm/problem/13331
思路:分析题目所给的信息,题目说每经过一个城市,如果这个城市的珠宝比当前手中的珠宝价值更高,则购入,问购买次数,很明显买完的珠宝从最后一次购买到第一次购买的价值是递减的,也就是说我从一个城市往上走,要进行购买珠宝的城市的价值是递增的。
所有我们只找比当前所在城市珠宝价值大的祖先节点,即h[i][j]表示i城市的第2^j个祖先且价值比i城市大。
再看城市u->v,为什么我们查找的时候只要找到两个城市在同一层即可呢,因为从u->v一定会经过他们的最近公共祖先,而他们的最近公共祖先又一定比v->最近公共祖先里面的所有祖先价值都高,我们不会去购买,所以跳到同一层即可。
https://ac.nowcoder.com/discuss/395376?type=101&order=0&pos=1&page=2&channel=1009&source_id=discuss_tag

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <stack>
#include <queue>
#include <string.h>
#include <cmath>
#include <bitset>
#define ll long long
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;
using namespace std;

//void print(char a[])
//{
//    printf("%d %c\n",strlen(a),a[2]);
//} 

int a[500005],en[500005],f[500005][20],h[500005],cnt;
vector<int>edge[500005];
void dfs(int s,int fa)
{
    int x = fa;
    //找到第一个比a[s]大的祖先节点 
    for(int i = 19; i >= 0; i--)
        if(f[x][i] && a[f[x][i]] <= a[s]) x = f[x][i];
    if(a[x] > a[s]) f[s][0] = x;
    else f[s][0] = f[x][0];
    for(int i = 1; i <= 19; i++) f[s][i] = f[f[s][i-1]][i-1];
    h[s] = h[fa] + 1;
    int len = edge[s].size();
    for(int i = 0;i < len; i++){
        int to = edge[s][i];
        if(to == fa) continue;
        dfs(to,s);
    }
}

int main()
{
    int n,q;
    cnt = 0;
    scanf("%d%d",&n,&q);
    for(int i = 1; i <= n; i++) scanf("%d",&a[i]);
    for(int i = 1; i < n; i++){
        int x,y;
        scanf("%d%d",&x,&y);
        edge[x].push_back(y);
        edge[y].push_back(x);
    }
    for(int i = n+1; i <= n+q; i++){
        int x,y,val;
        scanf("%d%d%d",&x,&y,&val);
        edge[i].push_back(x);
        edge[x].push_back(i);
        a[i] = val;
        en[i - n] = y;
    }
    dfs(1,0);
    int ans = 0;
    for(int i = 1; i <= q; i++){
        ans = 0;
        int x = i + n;
        int v = en[i];
        for(int j = 19; j >= 0; j--){
            if(h[f[x][j]] >= h[v]){
                ans += (1 << j);
                x = f[x][j];
            }
        }
        printf("%d\n",ans);
    }
    return 0;
}

Borrow Classroom:LCA

https://ac.nowcoder.com/acm/problem/13813
lca模板题:判断lca(a,c)是否为一,为一说明在len[a] < len[c]+len[b,c],不为一则可以相等。
图片说明

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <stack>
#include <queue>
#include <string.h>
#include <cmath>
#include <bitset>
//#include <set>
#define ll long long
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;
using namespace std;

vector<int>a[100001];
int f[100001][20],len[100001],h[100001];
void dfs(int s,int fa)
{
    f[s][0] = fa;
    h[s] = h[fa] + 1;
    for(int i = 1; i <= 17; i++) f[s][i] = f[f[s][i-1]][i-1];
    int len1 = a[s].size();
    for(int i = 0; i < len1; i++){
        int v = a[s][i];
        if(v == fa) continue;
        len[v] = len[s] + 1;
        dfs(v,s);
    }
}

int lca(int x,int y)
{
    if(len[y] > len[x]) swap(x,y);
    for(int i = 17; i >= 0; i--){
        if(len[f[x][i]] >= len[y]) x = f[x][i];
    }
    if(x == y) return x;
    for(int i = 17; i >= 0; i--){
        if(f[x][i] != f[y][i]){
            x = f[x][i];
            y = f[y][i];
        }
    }
    return f[x][0];
}

int main()
{
    //ios::sync_with_stdio(false);
    int t,n,q,x,y;
    cin >> t;
    while(t--){
    cin >> n >> q;
    for(int i = 1; i <= n; i++){
        a[i].clear();
    }
    for(int i = 1; i < n; i++){
        cin >> x >> y;
        a[x].push_back(y);
        a[y].push_back(x);
    }
    len[0] = 0;
    dfs(1,0);
    int A,B,C,t1,t2;
    while(q--){
        t1 = t2 = 0;
        cin >> A >> B >> C;
        int zx = lca(B,C);
        t1 += len[B] + len[C] - 2*len[zx] + len[C];
        t2 = len[A];
        if(t2 < t1) cout << "YES\n";
        else if(t2 == t1 && lca(A,C) != 1) cout << "YES\n";
        else cout << "NO\n";
    }
}
    return 0;
}
全部评论

相关推荐

05-03 12:45
西南大学 Java
nsnzkv:你这项目写的内容太多了,说实话都是在给自己挖坑,就算简历过了,后面面试也难受
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客企业服务