PyTorch 的 Tensor 和 Function:一场数学快递员的奇幻冒险

想象一下,PyTorch 是一个超级智能的"数学快递公司",专门负责计算各种数学运算(比如加法、乘法、求平均等)。在这个公司里,有两个关键角色:

  1. Tensor(张量)​:就像快递包裹,里面装着数据(数字),并且可以贴上"需要追踪物流"(requires_grad=True)或"不需要追踪物流"(requires_grad=False)的标签。
  2. Function(函数)​:就像快递员,负责执行运算(比如加法、乘法),并且记录下"这个包裹是怎么来的"(grad_fn)。

1. Tensor:会"追踪物流信息"的快递包裹

​(1)创建 Tensor:默认不追踪物流

import torch

# 创建一个普通的 Tensor(默认不追踪物流)
x = torch.ones(2, 2)  # 2x2 的全1矩阵
print(x.requires_grad)  # False,表示"不需要追踪物流"
print(x.grad_fn)  # None,表示"没有快递员记录它的来源"

解释

  • x 是一个普通的 Tensor,就像一个普通快递包裹,公司不会记录它的物流信息(requires_grad=False),也没有快递员(grad_fn=None)知道它是怎么来的。

(2)设置 requires_grad=True:开始追踪物流

# 创建一个需要追踪物流的 Tensor
x = torch.ones(2, 2, requires_grad=True)  # 现在要追踪它的物流!
print(x.requires_grad)  # True,表示"需要追踪物流"
print(x.grad_fn)  # None,因为它是直接创建的,没有快递员记录

解释

  • x 现在是一个"VIP包裹",公司会记录它的物流信息(requires_grad=True),但因为它是直接创建的(没有经过任何运算),所以没有快递员(grad_fn=None)知道它是怎么来的。

(3)进行运算:快递员登场!​

# 对 x 进行加法运算
y = x + 2  # y = x + 2
print(y)  # tensor([[3., 3.], [3., 3.]])
print(y.grad_fn)  # <AddBackward>,表示"这个包裹是由加法快递员送来的"

解释

  • y 是通过 x + 2 运算得到的,所以公司派了一个"加法快递员"(grad_fn=<AddBackward>)来记录这个运算过程。

(4)叶子节点 vs 非叶子节点

print(x.is_leaf, y.is_leaf)  # True False

解释

  • x 是直接创建的,没有经过任何运算,所以它是"叶子节点"(is_leaf=True)。
  • y 是通过运算得到的,所以它是"非叶子节点"(is_leaf=False),并且有快递员(grad_fn)记录它的来源。

2. Function:快递员的工作记录

每个 Tensor 都有一个 grad_fn 属性,记录它是通过什么运算得到的:

  • 如果是直接创建的(比如 torch.ones()),grad_fn=None(没有快递员)。
  • 如果是通过运算得到的(比如 x + 2),grad_fn 就是执行这个运算的"快递员"(比如 <AddBackward>)。

3. 反向传播:计算梯度

​(1)计算 out 并反向传播

# 继续运算
z = y * y * 3  # z = y² * 3
out = z.mean()  # out = z 的平均值

print(z, out)  # tensor([[27., 27.], [27., 27.]]) tensor(27.)
print(out.grad_fn)  # <MeanBackward1>,表示"这个包裹是由求平均快递员送来的"

解释

  • z 是通过 y * y * 3 运算得到的,所以它的快递员是 <MulBackward>
  • out 是 z 的平均值,所以它的快递员是 <MeanBackward1>

(2)调用 backward() 计算梯度

out.backward()  # 反向传播,计算梯度
print(x.grad)  # tensor([[4.5, 4.5], [4.5, 4.5]])

解释

  • out.backward() 相当于让公司"反向追踪物流",计算 x 的梯度(x.grad)。
  • 最终 x.grad 是 [[4.5, 4.5], [4.5, 4.5]],表示 x 对 out 的贡献有多大。

4. 如何停止追踪物流?

​(1)detach():把包裹从物流系统里移除

# 把 x 从物流系统里移除
x_detached = x.detach()  # 现在 x_detached 不会追踪物流
print(x_detached.requires_grad)  # False

解释

  • detach() 就像把包裹从快递公司的物流系统里移除,之后它不会再记录任何运算。

(2)with torch.no_grad():临时关闭物流追踪

# 临时关闭物流追踪
with torch.no_grad():
    a = x + 2  # 这个运算不会被记录
print(a.requires_grad)  # False

解释

  • with torch.no_grad() 就像进入一个"无物流追踪区",在这个区域里的所有运算都不会被记录。

总结

这样,PyTorch 的 Tensor 和 Function 就像一个高效的"数学快递公司",既能追踪计算过程,又能灵活控制是否需要计算梯度!

Python核心知识唠明白 文章被收录于专栏

想学Python怕被线程池|元组解包劝退?本专栏用打工人打工魂|拆快递|交换奶茶的生活化比喻,把核心知识点讲成唠家常!从线程池原理到元组解包技巧,每篇带代码实战+避坑指南,小白边看边练,无痛掌握。新手入门、老萌新优化代码都适用;学完直接上手批量下载、处理Excel、优化爬虫,Python原来这么简单好玩!

全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务