首页 > 试题广场 >

结构化剪枝后的分类预测

[编程题]结构化剪枝后的分类预测
  • 热度指数:50 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解

在终端设备上部署模型时,常需要先压缩网络规模。现给定一批样本矩阵 X(n 行 d 列)、一层线性分类器的权重矩阵 W(d 行 c 列),以及剪枝比例 ratio。请对 W 进行“按行剪枝”(即移除整行,对应丢弃一个输入特征),然后用剪枝后的模型对每个样本做预测,输出每行样本的预测类别索引(从 0 开始)。

任务要求

  • 剪枝指标:对 W 的每一行计算 L1 范数(该行各元素绝对值之和)。L1 越小,越不重要。
  • 剪枝行数:k = floor(ratio × d)。若 ratio > 0 且 floor(ratio × d) = 0,则令 k = 1(至少剪 1 行)。
  • 剪枝规则:移除 L1 范数最小的 k 行,得到新权重 W'(形状为 (d−k) × c)。
  • 特征对齐:将 X 中与被移除行同索引的列一并删除,得到 X'(形状为 n × (d−k))。
  • 线性输出:h = X' × W',得到大小为 n × c 的分数矩阵。
  • 稳定 Softmax:对 h 的每一行 i,先减去该行最大值,再做 softmax,得到概率分布 y_i。softmax 仅用于说明稳定做法;最终类别索引与直接对 h 行取最大位置相同。
  • 预测结果:对每行取 argmax(若有并列则取最左的列索引),输出为一行,用空格分隔各样本类别索引。

输入描述:
  • 第一行:三个整数 n d c。
  • 接着 n 行:每行 d 个浮点数,构成矩阵 X。
  • 接着 d 行:每行 c 个浮点数,构成矩阵 W。
  • 最后一行:一个浮点数 ratio(0 <= ratio <= 1.0)。


输出描述:
  • 一行,输出 n 个整数,空格分隔,为每个样本的预测类别索引。
示例1

输入

3 3 2
1 0 0
0 1 0
0 0 1
2 1
0 -1
-2 3
0.33

输出

0 0 1

说明

d=3,ratio=0.33 → floor(0.33×3)=0,但 ratio>0,因此 k=1;
W 的行 L1:row1=3,row2=1,row3=5 → 移除 row2;
删除 X 的第 2 列,得到 X';W 删除对应行得到 W';
计算 h=X'W' 后逐行取最大位置,得到预测 0 0 1。

备注:

注意

  • ratio=0 表示不剪枝;
  • 当 ratio>0 但 floor(ratio × d)=0 时,按要求仍需剪 1 行;
  • 若多类别分数相等,取索引更小的类别。

这道题你会答吗?花几分钟告诉大家答案吧!