开发者

PyTorch实现线性回归详细过程

开发者 https://www.devze.com 2022-12-12 12:52 出处:网络 作者: 心️升明月
目录一、实现步骤1、准备数据2、设计模型3、构造损失函数和优化器4、训练过程5、结果展示二、参考文献一、实现步骤
目录
  • 一、实现步骤
    • 1、准备数据
    • 2、设计模型
    • 3、构造损失函数和优化器
    • 4、训练过程
    • 5、结果展示
  • 二、参考文献

    一、实现步骤

    1、准备数据

    x_data = torch.tensor([[1.0],[2.0],[3.0]])
    y_data = torch.tensor([[2.0],[4.0],[6.0]])

    2、设计模型

    class LinearModel(torch.nn.Module):
      def __init__(self):
        super(LinearModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)
       
      def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
       
    model = LinearModel() 

    3、构造损失函数和优化器

    criterion = torch.nn.MSELoss(reduction='sum')
    optimhttp://www.cppcns.comizer = torch.optim.SGD(model.parameters(),lr=0.01)

    4、训练过程

    epoch_list = []
    loss_list = []
    w_list = []
    b_list = []
    for epoch in range(1000):
      y_pred = model(x_data)      # 计算预测值
      loss = criterion(y_pred, y_data) # 计算损失
      print(epoch,loss)
     
      epoch_list.append(epoch)
      loss_list.append(loss.data.item())
      w_list.append(model.linear.weight.item())
      b_list.append(model.linear.bias.item())
     
      optimizer.zero_grad()  # 梯度归零
      loss.backward()     # 反向传播
      optimizer.step()    # 更新

    5、结果展示

    展示最终的权重和偏置:

    # 输出权重和偏置
    print('w = ',model.linear.weight.item())
    print('b = ',model.linear.bias.item())

    结果为:

    w =  1.9998501539230347

    b =  0.0003405189490877092

    模型测试:

    # 测试模型
    x_test = torch.tensor([[4.0]])
    y_test = model(x_test)
    print('y_pred = ',y_test.data)
    
    y_pred = tensor([[7.9997]])

    分别绘制损失值随迭代次数变化的二维曲线图和其随权重与偏置变化的三维散点图:

    # 二维曲线图
    plt.phttp://www.cppcns.comlot(epoch_list,loss_list,'b')
    plt编程客栈.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()
    
    # 三维散点图
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(w_list,b_list,loss_list,c='r')
    #设置坐MfJdEM标轴
    ax.set_xlabel('weight')
    ax.set_ylabel('bias')
    ax.set_zlabel('loss')
    plt.show()

    结果如下图所示:

    PyTorch实现线性回归详细过程

    PyTorch实现线性回归详细过程

     到此这篇关于PyTorch实现线性回归详细过程的文章就介绍到这了,更多相关PyTorch线性回归内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

    二、参考文献

    • [1] https://MfJdEMwww.bilibili.com/video/BV1Y7411d7Ys?p=5

    0

    精彩评论

    暂无评论...
    验证码 换一张
    取 消

    关注公众号