在这个页面中,我们将提供一些 PyTorch 的样例代码,帮助您更好地理解和应用 PyTorch 进行机器学习。
样例列表
以下是我们在 PyTorch 教程中的一些样例代码:
线性回归
线性回归是最简单的机器学习算法之一,用于预测一个连续值。
import torch
import torch.nn as nn
# 定义模型
class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression, self).__init__()
self.linear = nn.Linear(1, 1)
def forward(self, x):
return self.linear(x)
# 创建模型实例
model = LinearRegression()
# 输入数据
x = torch.tensor([[1], [2], [3], [4], [5]])
y = torch.tensor([[2], [4], [6], [8], [10]])
# 训练模型
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(100):
optimizer.zero_grad()
outputs = model(x)
loss = criterion(outputs, y)
loss.backward()
optimizer.step()
print('模型训练完成')
逻辑回归
逻辑回归用于分类问题,例如二分类。
import torch
import torch.nn as nn
# 定义模型
class LogisticRegression(nn.Module):
def __init__(self):
super(LogisticRegression, self).__init__()
self.linear = nn.Linear(1, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.linear(x)
return self.sigmoid(x)
# 创建模型实例
model = LogisticRegression()
# 输入数据
x = torch.tensor([[1], [2], [3], [4], [5]])
y = torch.tensor([[0], [0], [1], [1], [1]])
# 训练模型
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(100):
optimizer.zero_grad()
outputs = model(x)
loss = criterion(outputs, y)
loss.backward()
optimizer.step()
print('模型训练完成')
神经网络
神经网络可以用于更复杂的任务,例如图像识别。
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义模型
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 创建模型实例
model = NeuralNetwork()
# 输入数据
x = torch.randn(1, 784)
# 训练模型
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(100):
optimizer.zero_grad()
outputs = model(x)
loss = criterion(outputs, torch.tensor([0]))
loss.backward()
optimizer.step()
print('模型训练完成')
更多资源
如果您想了解更多关于 PyTorch 的内容,请访问我们的 PyTorch 教程页面。