PyTorch中的squeeze和unsqueeze:张量的"瘦身"与"增肥"魔法
想象你有一个张量(Tensor),它就像一包压缩饼干(squeeze)或一盒膨化食品(unsqueeze)。PyTorch的squeeze和unsqueeze操作,就是让张量瘦身或增肥的魔法!
1. 什么是squeeze?(张量的"瘦身术")
squeeze的作用是"挤掉"张量里所有大小为1的维度,就像把一包多层压缩饼干压扁成一包普通饼干。
数学解释:
- 如果张量某个维度的长度=1(比如
shape=[1, 3, 1, 4]),squeeze会直接删除这个维度,变成[3, 4]。 - 如果不指定维度,它会挤掉所有大小为1的维度。
- 如果指定维度(比如
squeeze(0)),则只挤掉第0维(如果它的长度是1)。
类比:
- 原始张量:
[1, 3, 1, 4](像一包4层压缩饼干,其中第0层和第2层只有1片饼干) -
squeeze()后:[3, 4](直接压扁成普通饼干,去掉所有空层) -
squeeze(0)后:[3, 1, 4](只去掉第0层,第2层还是1片饼干)
代码示例:
import torch x = torch.randn(1, 3, 1, 4) # shape=[1, 3, 1, 4] y = x.squeeze() # shape=[3, 4](挤掉所有大小为1的维度) z = x.squeeze(0) # shape=[3, 1, 4](只挤掉第0维)
2. 什么是unsqueeze?(张量的"增肥术")
unsqueeze的作用是"增加一个大小为1的维度",就像把一包普通饼干包装进一个新盒子,变成多层压缩饼干。
数学解释:
unsqueeze(dim)会在指定维度dim的位置插入一个长度=1的维度。- 比如
shape=[3, 4]→unsqueeze(0)→shape=[1, 3, 4](在第0维外面套一个盒子)。
类比:
- 原始张量:
[3, 4](普通饼干) -
unsqueeze(0)后:[1, 3, 4](套一个新盒子,变成1层压缩饼干) -
unsqueeze(1)后:[3, 1, 4](在第1维外面套一个盒子,变成3×1×4的膨化饼干)
代码示例:
import torch x = torch.randn(3, 4) # shape=[3, 4] y = x.unsqueeze(0) # shape=[1, 3, 4](在第0维外面套一个盒子) z = x.unsqueeze(1) # shape=[3, 1, 4](在第1维外面套一个盒子)
3. 什么时候用squeeze和unsqueeze?
✅ squeeze的典型场景:
- 神经网络输入/输出调整:比如模型输出
[1, 10](batch=1),但你想去掉batch维度,变成[10]。 - 矩阵运算前对齐维度:比如
[3, 1]和[3, 4]相乘会报错,但squeeze后[3]和[3, 4]可以广播(Broadcasting)。
✅ unsqueeze的典型场景:
- 增加batch维度:比如
[3, 4]→[1, 3, 4],让数据变成1个样本的batch,方便输入模型。 - 广播(Broadcasting)计算:比如
[3, 4]和[4]不能直接相加,但unsqueeze(0)后[1, 4]可以广播成[3, 4]。
4. 常见错误 & 注意事项
❌ 错误1:试图squeeze一个长度>1的维度
x = torch.randn(2, 3) # shape=[2, 3] y = x.squeeze(0) # 报错!因为第0维长度=2(不是1)
❌ 错误2:unsqueeze的维度超出范围
x = torch.randn(3, 4) # shape=[3, 4] y = x.unsqueeze(3) # 报错!因为当前只有2维,最大只能`unsqueeze(2)`
5. 总结
- x.squeeze(): 挤掉张量x所有大小为1的维度
- x.squeeze(0): 只挤掉张量x指定维度大小为1的维度
- x.unsqueeze(0): 在张量x指定维度插入一个长度为1的维度
- 适用场景:调整维度以匹配神经网络输入、广播计算等。
掌握这两个操作后,你的PyTorch张量维度管理能力会直接起飞!
Python核心知识唠明白 文章被收录于专栏
想学Python怕被线程池|元组解包劝退?本专栏用打工人打工魂|拆快递|交换奶茶的生活化比喻,把核心知识点讲成唠家常!从线程池原理到元组解包技巧,每篇带代码实战+避坑指南,小白边看边练,无痛掌握。新手入门、老萌新优化代码都适用;学完直接上手批量下载、处理Excel、优化爬虫,Python原来这么简单好玩!

三奇智元机器人科技有限公司公司福利 70人发布