题解 | Rinne Loves Edges

Rinne Loves Edges

https://www.nowcoder.com/practice/080d6bd6edb043bd855f633f70edde1d

首先注意到这是一棵树,可以看到数据范围里面M = N - 1 ,然后题目要求度为1的节点都不能到达S,度为1的节点就是叶子节点。

所以题目演变为切断叶子节点与父节点之间的路径或者切断叶子节点通向S的路径,我们思考状态和代价,对于一个节点来说,我们可以通过直接切断他与父节点之间的路径来切断他的子树上所有叶子节点通向S的路径,当然我们可以切断他的子树上的一些边来切断,发现高层的点依靠底层的点的值来进行判断,存在依靠关系,所以时候自下而上的dfs遍历树,对每一个节点,达到最优状态,而对于叶子节点,他的最优只能是切断它与父节点的边。

#include <bits/stdc++.h>
#define int long long
using namespace std;
#define endl '\n'
#define pb push_back
#define ull unsigned long long
#define all(a) a.begin(), a.end()
#define vi vector<int>
#define vii vector<vector<int>>
#define fi first
#define se second
#define vs vector<string>
#define eb emplace_back
#define in insert
#define pf push_front
#define sep "================"
#define ios                      \
    ios::sync_with_stdio(false); \
    cin.tie(0);                  \
    cout.tie(0);
const int inf = 2e18 + 9;
const int mod1 = 1e9 + 7;
const int mod2 = 998244353;
typedef pair<int, int> pll;
typedef long double db;
inline void pt(int x)
{
    if (x < 0)
        putchar('-'), x = -x;
    if (x > 9)
        pt(x / 10);
    putchar(x % 10 + '0');
}
inline void print(int x) { pt(x), puts(""); }
inline void print(pll x) { pt(x.fi), putchar(' '), pt(x.se), putchar('\n'); }
inline void print(vi &vec)
{
    for (const auto t : vec)
        pt(t), putchar(' ');
    puts("");
}
inline void print(vector<pll> &vec)
{
    puts(sep);
    for (const auto v : vec)
    {
        print(v);
    }
    puts(sep);
}
inline int gcd(int a, int b) { return b ? gcd(b, a % b) : a; };
inline int qsm(int a, int b, int mod);
inline int lem(int a, int b) { return a * b / (gcd(a, b)); }
vii mul(vii &h1, vii h2)
{
    int a1 = h1.size();
    int a2 = h2[0].size();
    int a3 = h1[0].size();
    vii ans(a1, vi(a2, 0));
    for (int i = 0; i < a1; i++)
    {
        for (int j = 0; j < a2; j++)
        {
            for (int k = 0; k < a3; k++)
            {
                ans[i][j] = (ans[i][j] + (__int128)h1[i][k] * h2[k][j]);
            }
        }
    }
    return ans;
}
inline int read()
{
    int x = 0;
    short f = 1;
    char c = getchar();
    while ((c < '0' || c > '9') && c != '-')
        c = getchar();
    if (c == '-')
        f = -1, c = getchar();
    while (c >= '0' && c <= '9')
        x = x * 10 + c - '0', c = getchar();
    x *= f;
    return x;
}
int per(int n, int m)
{
    if (m == 0)
    {
        return 1;
    }
    int result = 1;
    for (int i = 0; i < m; ++i)
    {
        result *= (n - i);
    }
    return result;
}
void C()
{
    const int N = 200005 + 10;
    int inv[N];
    int jie[N];
    jie[0] = 1;
    for (int i = 1; i < N; i++)
    {
        jie[i] = jie[i - 1] * i % mod1;
    }
    inv[N - 1] = qsm(jie[N - 1], mod1, mod1 - 2);
    for (int i = N - 2; i >= 0; i--)
    {
        inv[i] = inv[i + 1] * (i + 1) % mod1;
    }
}
// inline int comb(int n, int r){if (r < 0 || r > n)return 0;return jie[n] * inv[r] % mod1 * inv[n - r] % mod1;}
int dir[4][2] = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}};
int dirx[8] = {-1, -1, -1, 0, 0, 1, 1, 1};
int diry[8] = {-1, 0, 1, -1, 1, -1, 0, 1};
void init()
{
}
void work()
{
    int n, m, s;
    cin >> n >> m >> s;
    vector<vector<pll>> g(n + 1);
    for (int i = 1; i <= m; i++)
    {
        int u, v, w;
        cin >> u >> v >> w;
        g[u].pb({v, w});
        g[v].pb({u, w});
    }
    int ans = 0;
    vi dp(n + 1 , inf) ;
    auto dfs = [&](auto &&dfs , int u , int pa)->void{
        if(g[u][0].fi == pa && g[u].size() == 1)
        {
            dp[u] = g[u][0].se ; 
            return ; 
        }
        int res = 0 ; 
        for(auto &[v , w] : g[u])
        {
            if(v == pa)
            {
                dp[u] = min(dp[u] , w) ;
            }
            else
            {
                dfs(dfs , v , u);
                res += dp[v] ; 
            }
        }
        dp[u] = min(dp[u] , res) ;
    };
    dfs(dfs , s , 0);
    cout << dp[s] << endl ; 
}
signed main()
{
    init();
    ios;
    int t = 1;
    while (t--)
    {
        work();
    }
    return 0;
}
inline int qsm(int a, int b, int mod)
{
    int res = 1;
    while (b)
    {
        if (b & 1)
        {
            res = res * a % mod;
        }
        b >>= 1;
        a = a * a % mod;
    }
    return res % mod;
}

全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务