PyTorch 的 Tensor 和 Function:一场数学快递员的奇幻冒险
想象一下,PyTorch 是一个超级智能的"数学快递公司",专门负责计算各种数学运算(比如加法、乘法、求平均等)。在这个公司里,有两个关键角色:
- Tensor(张量):就像快递包裹,里面装着数据(数字),并且可以贴上"需要追踪物流"(
requires_grad=True
)或"不需要追踪物流"(requires_grad=False
)的标签。 - Function(函数):就像快递员,负责执行运算(比如加法、乘法),并且记录下"这个包裹是怎么来的"(
grad_fn
)。
1. Tensor:会"追踪物流信息"的快递包裹
(1)创建 Tensor:默认不追踪物流
import torch # 创建一个普通的 Tensor(默认不追踪物流) x = torch.ones(2, 2) # 2x2 的全1矩阵 print(x.requires_grad) # False,表示"不需要追踪物流" print(x.grad_fn) # None,表示"没有快递员记录它的来源"
解释:
x
是一个普通的 Tensor,就像一个普通快递包裹,公司不会记录它的物流信息(requires_grad=False
),也没有快递员(grad_fn=None
)知道它是怎么来的。
(2)设置 requires_grad=True
:开始追踪物流
# 创建一个需要追踪物流的 Tensor x = torch.ones(2, 2, requires_grad=True) # 现在要追踪它的物流! print(x.requires_grad) # True,表示"需要追踪物流" print(x.grad_fn) # None,因为它是直接创建的,没有快递员记录
解释:
x
现在是一个"VIP包裹",公司会记录它的物流信息(requires_grad=True
),但因为它是直接创建的(没有经过任何运算),所以没有快递员(grad_fn=None
)知道它是怎么来的。
(3)进行运算:快递员登场!
# 对 x 进行加法运算 y = x + 2 # y = x + 2 print(y) # tensor([[3., 3.], [3., 3.]]) print(y.grad_fn) # <AddBackward>,表示"这个包裹是由加法快递员送来的"
解释:
y
是通过x + 2
运算得到的,所以公司派了一个"加法快递员"(grad_fn=<AddBackward>
)来记录这个运算过程。
(4)叶子节点 vs 非叶子节点
print(x.is_leaf, y.is_leaf) # True False
解释:
x
是直接创建的,没有经过任何运算,所以它是"叶子节点"(is_leaf=True
)。y
是通过运算得到的,所以它是"非叶子节点"(is_leaf=False
),并且有快递员(grad_fn
)记录它的来源。
2. Function:快递员的工作记录
每个 Tensor 都有一个 grad_fn
属性,记录它是通过什么运算得到的:
- 如果是直接创建的(比如
torch.ones()
),grad_fn=None
(没有快递员)。 - 如果是通过运算得到的(比如
x + 2
),grad_fn
就是执行这个运算的"快递员"(比如<AddBackward>
)。
3. 反向传播:计算梯度
(1)计算 out
并反向传播
# 继续运算 z = y * y * 3 # z = y² * 3 out = z.mean() # out = z 的平均值 print(z, out) # tensor([[27., 27.], [27., 27.]]) tensor(27.) print(out.grad_fn) # <MeanBackward1>,表示"这个包裹是由求平均快递员送来的"
解释:
z
是通过y * y * 3
运算得到的,所以它的快递员是<MulBackward>
。out
是z
的平均值,所以它的快递员是<MeanBackward1>
。
(2)调用 backward()
计算梯度
out.backward() # 反向传播,计算梯度 print(x.grad) # tensor([[4.5, 4.5], [4.5, 4.5]])
解释:
out.backward()
相当于让公司"反向追踪物流",计算x
的梯度(x.grad
)。- 最终
x.grad
是[[4.5, 4.5], [4.5, 4.5]]
,表示x
对out
的贡献有多大。
4. 如何停止追踪物流?
(1)detach()
:把包裹从物流系统里移除
# 把 x 从物流系统里移除 x_detached = x.detach() # 现在 x_detached 不会追踪物流 print(x_detached.requires_grad) # False
解释:
detach()
就像把包裹从快递公司的物流系统里移除,之后它不会再记录任何运算。
(2)with torch.no_grad()
:临时关闭物流追踪
# 临时关闭物流追踪 with torch.no_grad(): a = x + 2 # 这个运算不会被记录 print(a.requires_grad) # False
解释:
with torch.no_grad()
就像进入一个"无物流追踪区",在这个区域里的所有运算都不会被记录。
总结
这样,PyTorch 的 Tensor 和 Function 就像一个高效的"数学快递公司",既能追踪计算过程,又能灵活控制是否需要计算梯度!
Python核心知识唠明白 文章被收录于专栏
想学Python怕被线程池|元组解包劝退?本专栏用打工人打工魂|拆快递|交换奶茶的生活化比喻,把核心知识点讲成唠家常!从线程池原理到元组解包技巧,每篇带代码实战+避坑指南,小白边看边练,无痛掌握。新手入门、老萌新优化代码都适用;学完直接上手批量下载、处理Excel、优化爬虫,Python原来这么简单好玩!