题解 | #标签在前K个近邻中的出现次数#

标签在前K个近邻中的出现次数

http://www.nowcoder.com/questionTerminal/0665b50ee0bb46f488a5ac329b033333

常规的KNN分类, 只是最后有相同频次的标签需要额外处理
class KNN:
    def __init__(self, K, M, N, S, sample, label):
        self.K = K
        self.M = M
        self.N = N
        self.S = S
        self.label = label
        self.sample = sample
        return

    def fit(self, X):
        distances = []
        for ii in range(self.M):
            distances.append((self.euclidean_distance(X, self.sample[ii]), ii))
        distances.sort()
        distances = distances[0:self.K]

        count_label = [0 for _ in range(self.S)]
        for ii in range(self.K):
            count_label[self.label[distances[ii][1]]] += 1

        max = 0
        arg_max = 0
        for ii in range(self.S):
            if count_label[ii] > max:
                max = count_label[ii]
                arg_max = ii

        arg_maxs = [arg_max]
        for ii in range(self.S):
            if count_label[ii] == max:
                arg_maxs.append(ii)

        if len(arg_maxs) == 1:
            return arg_maxs[0], count_label[arg_maxs[0]]
        else:
            dist_in_argmax = []
            for ii in range(self.K):
                if self.label[distances[ii][1]] in arg_maxs:
                    dist_in_argmax.append(distances[ii])
            # dist_in_argmax.sort()
            # return self.label[dist_in_argmax[0][1]], count_label[self.label[dist_in_argmax[0][1]]]
            min_dist = dist_in_argmax[0][0]
            arg_min_dist = dist_in_argmax[0][1]
            for dist in dist_in_argmax:
                if dist[0] < min_dist:
                    min_dist = dist[0]
                    arg_min_dist = dist[1]
            return self.label[arg_min_dist], count_label[self.label[arg_min_dist]]


    def euclidean_distance(self, x1, x2):
        distant = 0.0
        for ii in range(self.N):
            distant += (x1[ii] - x2[ii]) ** 2
        distant = distant ** 0.5
        return distant

if __name__ == "__main__":
    k, m, n, s = map(int, input().split())
    sample_need_classify = list(map(float, input().split()))
    sample_read = []
    label_read = []
    for _ in range(m):
        read_line = list(map(float, input().split()))
        sample_read.append(read_line[0:n])
        label_read.append(int(read_line[n]))

    knn = KNN(k, m, n, s, sample_read, label_read)
    argmax, countlabel = knn.fit(sample_need_classify)
    print(argmax, countlabel)


全部评论

相关推荐

点赞 评论 收藏
分享
11-06 16:50
门头沟学院 Java
用微笑面对困难:word打字比赛二等奖的我,也要来凑合凑合
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务