题解 | #二分 K-means子网分割#

二分 K-means子网分割

http://www.nowcoder.com/questionTerminal/308edb1f38464124ba30197b48ebc3cf

import math
# If you need to import additional packages&nbs***bsp;classes, please import here.
N = int(input().strip())  # 期望分割的子网数量,用N表示
M = int(input().strip())  # 全网站点总数,用M表示
net_list = []  # 网络站点坐标列表,用net_list表示,列表中的每个元素是一个二元元组,第一位代表x坐标,第二位代表y坐标
for i in range(M):
    x_axis, y_axis = input().strip().split()
    net_list.append((int(x_axis), int(y_axis)))

def distance(x1, x2):
    d = (x1[0] - x2[0]) ** 2 + (x1[1] - x2[1]) ** 2
    return d

def kmeans(nodes_id):
    nodes = []
    for ii in nodes_id:
        nodes.append(net_list[ii])
    nodes.sort()
    centers = [nodes[0], nodes[-1]]

    group = [[], []]

    step = 0
    error = 1
    while step < 1000 and error > 1e-6:
        group = cal_group(nodes_id, centers)
        centers_new = cal_centers(group)
        error = cal_error(centers, centers_new)
        centers = centers_new
        step += 1
    return group, centers

def cal_error(centers, centers_new):
    error = math.sqrt(distance(centers[0], centers_new[0])) + math.sqrt(distance(centers[1], centers_new[1]))
    return error

def cal_centers(group):
    centers = [[0, 0], [0, 0]]
    if group[0]:
        for ii in group[0]:
            centers[0][0] += net_list[ii][0]
            centers[0][1] += net_list[ii][1]
        centers[0][0] /= len(group[0])
        centers[0][1] /= len(group[0])

    if group[1]:
        for ii in group[1]:
            centers[1][0] += net_list[ii][0]
            centers[1][1] += net_list[ii][1]
        centers[1][0] /= len(group[1])
        centers[1][1] /= len(group[1])
    return centers

def cal_group(nodes_id, centers):
    group = [[], []]
    for ii in nodes_id:
        d0 = distance(net_list[ii], centers[0])
        d1 = distance(net_list[ii], centers[1])
        if d0 > d1:
            group[1].append(ii)
        else:
            group[0].append(ii)
    return group

def bin_kmeans(groups, centers):
    n = len(groups)

    id_min = 0
    SSE_min = float('inf')
    for ii in range(n):
        groups_cur = groups.copy()
        centers_cur = centers.copy()
        group = groups_cur[ii]
        groups_cur.pop(ii)
        centers_cur.pop(ii)
        group_new, centers_new = kmeans(group)
        if group_new[0]:
            groups_cur.append(group_new[0])
            centers_cur.append(centers_new[0])
        if group_new[1]:
            groups_cur.append(group_new[1])
            centers_cur.append(centers_new[1])

        SSE_cur = cal_SSE(groups_cur, centers_cur)

        if SSE_cur < SSE_min:
            SSE_min = SSE_cur
            id_min = ii

    groups_cur = groups.copy()
    centers_cur = centers.copy()
    group = groups_cur[id_min]
    groups_cur.pop(id_min)
    centers_cur.pop(id_min)
    group_new, centers_new = kmeans(group)
    if group_new[0]:
        groups_cur.append(group_new[0])
        centers_cur.append(centers_new[0])
    if group_new[1]:
        groups_cur.append(group_new[1])
        centers_cur.append(centers_new[1])

    return groups_cur, centers_cur

def cal_SSE(groups, centers):
    SSE = 0
    n = len(groups)
    for ii in range(n):
        SSE += dist_in_group(groups[ii], centers[ii])

    return SSE

def dist_in_group(group, center):
    d = 0
    for ii in group:
        d += distance(net_list[ii], center)
    return d

def func():
    # please define the python3 input here. For example: a,b = map(int, input().strip().split())
    nodes_id = list(range(M))

    if N == 1:
        print(M)

    else:
        groups, centers = kmeans(nodes_id)
        n = len(groups)
        nums = []
        for group in groups:
            nums.append(len(group))
        nums.sort(reverse=True)
        re = []
        for num in nums:
            re.append(str(num))
        print(' '.join(re))
        while n < N:
            groups, centers = bin_kmeans(groups, centers)
            n = len(groups)
            nums = []
            for group in groups:
                nums.append(len(group))
            nums.sort(reverse=True)
            re = []
            for num in nums:
                re.append(str(num))
            print(' '.join(re))

if __name__ == "__main__":
    func()

全部评论

相关推荐

StephenZ_:我9月份找的第一段实习也是遇到这种骗子公司了,问他后端有多少人和我说7个正职,进去一看只有一个后端剩下的都是产品前端算法(没错甚至还有算法)。还是某制造业中大厂,我离职的时候还阴阳怪气我
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务