开发者

PyTorch基础之torch.nn.CrossEntropyLoss交叉熵损失

开发者 https://www.devze.com 2023-02-03 09:18 出处:网络 作者: gy笨瓜
目录torch.nn.CrossEntropyLoss交叉熵损失torch.nnCrossEntropyLosss使用的公式sampletorch.nn.CrossEntropyLoss使用流程测试torch.nn.CrossEntropyLoss的reduction参数为默认值"mean"交叉熵损失nn.CrossE
目录
  • torch.nn.CrossEntropyLoss交叉熵损失
    • torch.nnCrossEntropyLosss使用的公式
    • sample
    • torch.nn.CrossEntropyLoss使用流程
    • 测试torch.nn.CrossEntropyLoss的reduction参数为默认值"mean"
  • 交叉熵损失nn.CrossEntropyLoss()的真正计算过程
    • 正确公式
    • 综上,可得两个结论
  • 总结

    torch.nn.CrossEntropyLoss交叉熵损失

    本文只考虑基本情况,未考虑加权。

    torch.nnCrossEntropyLosss使用的公式

    PyTorch基础之torch.nn.CrossEntropyLoss交叉熵损失

    目标类别采用one-hot编码

    其中,class表示当前样本类别在one-hot编码中对应的索引(从0开始),

    x[j]表示预测函数的第j个输出

    公式(1)表示先对预测函数使用softmax计算每个类别的概率,再使用log(以e为底)计算后的相反数表示当前类别的损失,只表示其中一个样本的损失计算方式,非全部样本。

    每个样本使用one-hot编码表示所属类别时,只有一项为1,因此与基本的交叉熵损失函数相比,省略了其它值为0的项,只剩(1)所表示的项。

    sample

    PyTorch基础之torch.nn.CrossEntropyLoss交叉熵损失

    torch.nn.CrossEntropyLoss使用流程

    torch.nn.CrossEntropyLoss为一个类,并非单独一个函数,使用到的相关简单参数会在使用中说明,并非对所有参数进行说明。

    首先创建类对象

    In [1]: import torch
    In [2]: import torch.nn as nn
    In [3]: loss_function = nn.CrossEntropyLoss(reduction="none")

    参数reduction默认为"mean",表示对所有样本的loss取均值,最终返回只有一个值

    参数reduction取"none",表示保留每一个样本的loss

    计算损失

    In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
    In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
    In [6]: loss_value = loss_function(pred, class_index)
    In [7]: loss_value
    Out[7]: tensor([1.5210, 0.6247]) # 与上述【sample】计算一致

    实际计算损失值调用函数时,传入pred预测值与class_index类别索引

    在传入每个类别时,class_index应为一维,长度为样本个数,每个元素表示对应样本的类别索引,非one-hot编码方式传入

    测试torch.nn.CrossEntropyLoss的reduction参数为默认值"mean"

    In [1]: import torch
    In [2]: import torch.nn as nn
    In [3]: loss_function = nn.CrossEntropyLoss(reduction="mean")
    In [4]: pred = torc编程客栈h.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
    In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
    In [6]: loss_value = loss_function(pred, class_index)
    In [7]: loss_value
    Out[7]: 1.073 # 与上述【sample】计算一致

    交叉熵损失nn.CrossEntropyLoss()的真正计算过程

    对于多分类损失函数Cross Entropy Loss,就不过多的解释,网上的博客不计其数。在这里,讲讲对于CE Loss的一些真正的理解。

    首先大部分博客给出的公式如下:

    PyTorch基础之torch.nn.CrossEntropyLoss交叉熵损失

    其中p为真实标签值,q为预测值。

    在低维复现此公式,结果如下。在此强调一点,pytorch中CE Loss并不会将输入的target映射为one-hot编码格式,而是直接取下标进行计算。

    import torch
    import torch.nn as nn
    import math
    import numpy as np
    
    #官方的实现
    entroy=nn.CrossEntropyLoss()
    input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
    target = torch.tensor([0,1,2])
    output = entroy(input, target)
    print(output)
    #输出 tensor(1.1142)
    
    #自己实现
    input=np.array(input)
    target = np.array(target)
    def cross_entorpy(input, target):
        output = 0
        length = len(t开发者_Go学习arget)
        for i in range(length):
            hou = 0
            for j in input[i]:
                hou += np.log(input[i][target[i]])
            output += -hou
        return np.around(output / length, 4)
    print(cross_entorpy(input, target))
    #输出 3.8162
    

    我们按照官方给的CE Lo编程客栈ss和根据公式得到的答案并不相同,说明公式是有问题的。

    正确公式

    PyTorch基础之torch.nn.CrossEntropyLoss交叉熵损失

    实现代码如下

    import torphpch
    import torch.nn as nn
    import math
    import numpy as np
    
    entroy=nn.CrossEntropyLoss()
    input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
    target = torch.tensor([0,1,2])
    output = entroy(input, target)
    print(output)
    #输出 tensor(1.1142)
    #%%
    input=np.array(input)
    target = np.array(target)
    def cross_entorpy(input, target):
        output = 0
        length = len(target)
        for i in range(length):
            hou = 0
            for j in input[i]:
                hou += np.exp(j)
            output += -input[i][target[i]] + np.log(hou)
        return np.around(output / length, 4)
    print(cross_entorpy(input, target))
    #输出 1.1142
    

    对比自己实现的公式和android官方给出的结果,可以验证公式的正确性。

    观察公式可以发现其实nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合版本。

    nn.logSoftmax(),公式如下

    PyTorch基础之torch.nn.CrossEntropyLoss交叉熵损失

    nn.NLLLoss(),公式如下

    PyTorch基础之torch.nn.CrossEntropyLoss交叉熵损失

    将nn.logSoftmax()作为变量带入nn.NLLLoss()可得

    PyTorch基础之torch.nn.CrossEntropyLoss交叉熵损失

    因为

    PyTorch基础之torch.nn.CrossEntropyLoss交叉熵损失

    可看做一个常量,故上式可化简为:

    PyTorch基础之torch.nn.CrossEntropyLoss交叉熵损失

    对比nn.Cross Entropy Loss公式,结果显而易见。

    验证代码如下。

    import torch
    import torch.nn as nn
    import math
    import numpy as np
    
    entroy=nn.CrossEntropyLoss()
    input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
    target = torch.tensor([0,1,2])
    output = entroy(input, target)
    print(output)
    # 输出为tensor(1.1142)
    m = nn.LogSoftmax()
    loss = nn.NLLLoss()
    input=m(input)
    output = loss(input, target)
    print(output)
    # 输出为tensor(1.1OMcSfiF142)
    

    综上,可得两个结论

    1.nn.Cross Entropy Loss的公式。

    PyTorch基础之torch.nn.CrossEntropyLoss交叉熵损失

    2.nn.Cross Entropy Loss为nn.logSoftmax()和nn.NLLLoss()的整合版本。

    总结

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

    0

    精彩评论

    暂无评论...
    验证码 换一张
    取 消

    关注公众号