Day63:深度学习项目实战
在前面的章节中,我们学习了什么是深度学习,如何构建不同的深度学习的神经网络,如何优化网络,优化模型,如何对网络中的参数实现可视化,深度学习的基础知识我们都已经学习,现在需要的就是加强我们对知识的掌握。
这一节,我们开启一个任务图像分类任务的实战练习,同样是我们熟悉的CIFAR-10数据集,但不同的是我们将完整走一遍使用深度学习进行项目实战的流程。
1. 数据获取与分析
CIFAR-10数据集包含10个不同类别的图像,每个类别有6000张32x32像素的彩色图像。我们的目标是训练一个模型,能够准确地对这些图像进行分类。我们首先需要下载数据集,它是torchvision的内置数据集,可以使用代码直接下载:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# 加载CIFAR-10数据集
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 类别名称
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
然后,我们其实查看一下数据集,看看数据集到底长什么样子,有没有什么需要处理的,或者了解数据集什么样之后便与我们后续调优:
# 显示一些训练图像
def imshow(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# 随机获取一些训练图像
dataiter = iter(trainloader)
images, labels = next(dataiter)
# 显示图像及其标签
imshow(torchvision.utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
从上图可以看出,整体上图像质量不高,属于低分辨率的小图像,对硬件的要求也不会太高。
2. 数据预处理
一般进行图像任务的深度学习之前,我们会考虑将图像进行标准化,同时为了提升模型的鲁棒性与泛化性,防止过拟合,我们也会考虑给数据进行一些翻转、裁剪的增强:
# 数据预处理
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
shuffle=False, num_workers=2)
在上述代码中,我们定义了用于训练集和测试集的数据预处理操作。训练集使用了数据增强技术,包括随机裁剪和水平翻转,以增加数据的多样性。测试集是我们用于验证模型性能的,因此要保证完整性和标准,不需要进行数据增强,但是要和训练集保持一致的标准化。
3. 模型、损失函数、优化器
在这一步中,我们不再自己定义模型。我们使用前人实践经验出来的好模型,比如预训练的ResNet-18模型,这些模型与预训练参数都在torchvision
库中,可以直接加载。
import torch.nn as nn
import torchvision.mode
剩余60%内容,订阅专栏后可继续查看/也可单篇购买
1. AI爱好者,爱搞事的 2. 想要掌握第二门语言的Javaer或者golanger 3. 决定考计算机领域研究生,给实验室搬砖的uu,强烈建议你花时间学完这个,后续搬砖比较猛 4. 任何对编程感兴趣的,且愿意掌握一门技能的人