首页 > 试题广场 >

小O的树上加边

[编程题]小O的树上加边
  • 热度指数:131 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
小O有一棵 n 个点组成的树,树上的点从 1n 编号,树上的边是无向的。任意一棵树都是二分图,小O想知道她最多可以给树加多少条边,使得新的图仍然是二分图。

二分图的定义:如果一个图的所有点可以被分成两个集合 AB,使得所有的边都是一端在 A 中,一端在 B 中,那么这个图就是二分图。

输入描述:
第一行输入一个整数 n\ (1 \leq n \leq 10^5 ),表示树上的点数。
此后 n-1 行,第 i 行输入两个整数 u_iv_i\ (1 \leq u_i, v_i \leq n ),表示树上的一条边连接了点 u_i 和点 v_i。保证这些边可以形成一棵树。


输出描述:
在一行上输出一个正整数,表示最多可以给树加多少条边,使得新的图仍然是二分图。

示例1

输入

4
1 2
2 3
3 4

输出

1

说明

可以添加一条边 (1, 4),新的图仍然是二分图。
import sys

# 确保将两个集合区分开来,return (len(A)*len(B) - len(a))

def get_input():
    for line in sys.stdin:
        for word in line.split():
            yield word

tokens = get_input()

graph1 = []
graph2 = []
N = 0
try:
    line1 = next(tokens)
    N = int(line1)
    while 1:
        u = next(tokens)
        v = next(tokens)
        if u<v:
            graph1.append(u)
            graph2.append(v)
        else:
            graph1.append(v)
            graph2.append(u)

except StopIteration:
    pass

# # in case the input is not in order, but need to change u and v both from int to str previously
# combined = zip(graph1, graph2)
# combined = sorted(combined)
# graph1, graph2 = zip(*combined)

group1 = []
group2 = []

for idx in range(len(graph1)):
    if idx == 0:
        group1.append(graph1[idx])
        group2.append(graph2[idx])
    else:
        if graph1[idx] in group1 and graph2[idx] not in group2:
            group2.append(graph2[idx])
        if graph1[idx] in group2 and graph2[idx] not in group1:
            group1.append(graph2[idx])
        if graph2[idx] in group1 and graph1[idx] not in group2:
            group2.append(graph1[idx])
        if graph2[idx] in group2 and graph1[idx] not in group1:
            group1.append(graph1[idx])
        if graph2[idx] not in group2 and graph2[idx] not in group1:
            if graph1[idx] not in group2 and graph1[idx] not in group1:
                group1.append(graph1[idx])
                group2.append(graph1[idx])
G1 = len(group1)
G2 = len(group2)
res = N - G1 - G2

if res==0:
    print(len(group1)*len(group2)-len(graph1))
else:
    if res<=abs(G1-G2):
        print(max(G1,G2)*(min(G1,G2)+res)-len(graph1))
    else:
        print((N-int(N/2))*int(N/2)-len(graph1))
# for debug use
if res<0:
    print('error')
    print(group1)
    print(group2)

   


发表于 2026-04-25 05:44:28 回复(0)