Pytorch主要使用
发表于|更新于
|浏览量:
PyTorch核心使用指南
PyTorch是一个开源的Python机器学习库,广泛应用于深度学习研究和生产环境。本文将详细介绍PyTorch的核心使用方法。
1. 张量基础
张量(Tensor)是PyTorch中最基本的数据结构,类似于NumPy的ndarray,但支持GPU加速。
1 2 3 4 5 6 7 8 9
| import torch
x = torch.tensor([1, 2, 3]) y = torch.rand(2, 3)
z = x + y m = torch.matmul(y.T, y)
|
2. 自动微分
PyTorch的自动微分系统(autograd)是其核心特性之一。
1 2 3 4
| x = torch.tensor(2.0, requires_grad=True) y = x ** 2 + 3 * x + 1 y.backward() print(x.grad)
|
3. 神经网络构建
使用torch.nn
模块构建神经网络:
1 2 3 4 5 6 7 8 9 10 11 12 13
| import torch.nn as nn import torch.nn.functional as F
class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(784, 256) self.fc2 = nn.Linear(256, 10) def forward(self, x): x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1)
|
4. 数据加载与预处理
PyTorch提供了Dataset
和DataLoader
来处理数据:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset): def __init__(self, data, labels): self.data = data self.labels = labels def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx], self.labels[idx]
dataset = CustomDataset(X_train, y_train) dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
|
5. 模型训练流程
完整的训练循环示例:
1 2 3 4 5 6 7 8 9 10 11 12
| model = Net() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss()
for epoch in range(10): for data, target in dataloader: optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() print(f'Epoch {epoch}, Loss: {loss.item()}')
|
6. GPU加速
PyTorch支持CUDA加速:
1 2 3
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) data = data.to(device)
|
7. 模型保存与加载
保存和恢复模型状态:
1 2 3 4 5 6 7
| torch.save(model.state_dict(), 'model.pth')
model = Net() model.load_state_dict(torch.load('model.pth')) model.eval()
|
8. 常见问题
- 梯度爆炸:使用梯度裁剪
torch.nn.utils.clip_grad_norm_
- 过拟合:添加Dropout层
nn.Dropout(p=0.5)
- 内存不足:减小batch size或使用梯度累积
总结
PyTorch提供了灵活高效的深度学习开发环境。通过掌握这些核心功能,您可以快速构建和训练各种神经网络模型。