结构化剪枝后的分类预测

结构化剪枝后的分类预测

https://www.nowcoder.com/practice/e00ba135dfa24b6ea8c2aa3bb3cdd67f

结构化剪枝后的分类预测

题目分析

给定样本矩阵 )、线性分类器权重矩阵 )和剪枝比例 ,模拟"行剪枝"过程:按 L1 范数从小到大移除 中最不重要的 行及 中对应的列,然后用剪枝后的矩阵做线性变换 + Softmax 预测每个样本的类别。

思路

矩阵剪枝 + 线性分类模拟

按题意逐步实现即可,关键是把每一步的细节处理正确:

  1. 计算剪枝行数 。特判:当 时,强制 (至少剪一行)。
  1. 计算每行 L1 范数:对 的第 行,。按 L1 从小到大排序,移除最小的 行。
  1. 特征对齐 移除的行索引对应 要移除的列索引。保留剩余特征,得到 )和 )。
  1. 线性变换,得到 的得分矩阵。
  1. Stable Softmax 预测:对每个样本的得分向量,先减去最大值再求 ,取 argmax 作为预测类别(相同取最左)。注意:由于 是单调递增函数,减去最大值后 argmax 不变,所以实际上直接对原始得分取 argmax 即可,Softmax 不影响结果。

以样例验证: ,但 ,所以 。三行 L1 范数为 ,移除第 1 行(索引从 0 开始)。剪枝后 保留第 0、2 列, 保留第 0、2 行。计算 后取 argmax,得到 ,与预期一致。

复杂度

  • 时间复杂度:,主要瓶颈是矩阵乘法
  • 空间复杂度:,存储输入矩阵

代码

import sys
import math

def main():
    data = sys.stdin.read().split()
    idx = 0
    n = int(data[idx]); idx += 1
    d = int(data[idx]); idx += 1
    c = int(data[idx]); idx += 1

    X = []
    for i in range(n):
        row = []
        for j in range(d):
            row.append(float(data[idx])); idx += 1
        X.append(row)

    W = []
    for i in range(d):
        row = []
        for j in range(c):
            row.append(float(data[idx])); idx += 1
        W.append(row)

    ratio = float(data[idx])

    # 计算剪枝行数
    k = int(math.floor(ratio * d))
    if ratio > 0 and k == 0:
        k = 1

    # 按 L1 范数排序,找出要移除的行
    l1 = [(sum(abs(W[i][j]) for j in range(c)), i) for i in range(d)]
    l1.sort()
    removed = set(l1[i][1] for i in range(k))
    kept = [i for i in range(d) if i not in removed]

    # 剪枝:保留对应行/列
    W_p = [W[i] for i in kept]
    X_p = [[X[i][j] for j in kept] for i in range(n)]

    # 矩阵乘法 + argmax
    d_new = len(kept)
    res = []
    for i in range(n):
        scores = []
        for j in range(c):
            s = 0.0
            for t in range(d_new):
                s += X_p[i][t] * W_p[t][j]
            scores.append(s)
        # argmax(取最左)
        best = 0
        for j in range(1, c):
            if scores[j] > scores[best]:
                best = j
        res.append(str(best))

    print(' '.join(res))

if __name__ == '__main__':
    main()
全部评论

相关推荐

有很多问题,求大佬们解答,谢谢大佬们:不知道现在该怎么投实习,该怎么准备内心很纠结学校课程和实习到底怎么选择, 自己也不想课程学业这边出问题, 是不是只能投暑期实习,具体时间该怎么安排前端面试也需要准备算法么, 自己的算法能力很薄弱, 面试题需要准备到什么程度?没有ai项目经验的话,我该如何去补充,如何去找好的ai项目
smile丶snow:1.简历尽量一页,比如教育经历那里,全日制,计算机学院这些可以去掉没啥用好浪费空间。 熟悉三件套就没必要写了吧。js基本上是这样写 * JavaScript核心:深入理解 JS 运行机制(事件循环 Event Loop、微任务/宏任务),熟练掌握 Promise/Async 异步编程 模型。 熟悉可以改成熟练掌握。组件库写一个ant感觉就行,多写了浪费空间。 旅游项目是不是jonas的natours啊,我之前简历也有这个。我之前是这样写的 全栈思维: 熟悉 Node.js/Express 后端架构,掌握 MongoDB 数据库设计与聚合查询 工程化我觉得还是少些吧,不写就问的少,如果你真的了解的话可以写。 1.实习的话推荐大厂官网和aoob上面投,我自己有写一个校招网站的小网站可以直达~github主页上面有,顺便求个关注( 2.大三下一般课程比较少了吧,如果学校比较严的话可以多沉淀一会,如果不太严可以请dai课然后去实习,尽量找个近一些的就行。暑期实习不是暑假才实习哦,基本是上3月底4月初发offer就可以过去了,然后大概暑假的时候走转正流程答辩。 3.大厂算法题+js手写体。hot100+常见的比如数组转树,Promise.all,deepClone,之类 js手写都不难其实。算法看自己能力吧,我其实算法能力也不行。 4.自己平时没有用AI Coding吗?自己想一下怎么让AI帮你更好的写代码~比如Skill的诞生,OpenSpec的诞生,不都是我们想让AI更好帮我们写代码吗。
我的实习日记
点赞 评论 收藏
分享
求你们别卷了的大学生...:你不骂他,我就要骂你了
今天你投了哪些公司?
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务