开发者

Python Pytorch深度学习之图像分类器

开发者 https://www.devze.com 2022-12-03 13:48 出处:网络 作者: 柚子味的羊
目录一、简介二、数据集三、训练一个图像分类器1、导入package吧2、归一化处理+贴标签吧3、先来康康训练集中的照片吧4、定义一个神经网络吧5、定义一个损失函数和优化器吧6、训练网络吧7、在测试集上测试一下网络吧8
目录
  • 一、简介
  • 二、数据集
  • 三、训练一个图像分类器
    • 1、导入package吧
    • 2、归一化处理+贴标签吧
    • 3、先来康康训练集中的照片吧
    • 4、定义一个神经网络吧
    • 5、定义一个损失函数和优化器吧
    • 6、训练网络吧
    • 7、在测试集上测试一下网络吧
    • 8、分别查看一下训练效果吧
  • 总结

    一、简介

    通常,当处理图像、文本、语音或视频数据时,可以使用标准python将数据加载到numpy数组格式,然后将这个数组转换成torch.*Tensor

    • 对于图像,可以用Pillow,OpenCV
    • 对于语音,可以用scipy,librosa
    • 对于文本,可以直接用Python或Cython基础数据加载模块,或者用NLTK和SpaCy

    特别是对于视觉,Pytorch已经创建了一个叫torchvision的package,该报包含了支持加载类似Imagenet,CIFAR10,MNIST等公共数据集的数据加载模快torchvision.datasets和支持加载图像数据数据转换模块torch.utils.data.DataLoader。这提供了极大地便利,并避免了编写“样板代码”

    二、数据集

    对于本小节,使用CIFAR10数据集,它包含了是个类别:airplane,automobile,bird,cat,deer,dog,frog,horse,ship,truck。CIFAR10中的图像尺寸是33232,也就是RGB的3层颜色通道,每层通道内的尺寸为32*32

    三、训练一个图像分类器

    训练图像分类器的步骤

    • 使用torchvision加载并且归一化CIFAR10的训练和测试数据集
    • 定义一个卷积神经网络
    • 定义一个损失函数
    • 在训练样本数据上训练网络
    • 在测试样本数据上测试网络

    1、导入package吧

    # 使用torchvision,加载并归一化CIFAR10
    import torch
    import torchvision
    import torchvision.transforms as transforms
    

    2、归一化处理+贴标签吧

    # torchvision数据集的输出是范围在[0,1]之间的PILImage,将他们转换成归一化范围为[-1,1]之间的张量Tensor
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]
        )
    # 训练集
    trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=False,transform=transform)
    trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2)
    # 测试集
    testset=torchvision.datasets.CIFAR10(root='./data',train=False,download=False,transform=transform)
    testloader=torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=2)
    classes=("plane","car","bird","cat","deer","dog","frog","horse","ship","truck")
    

    3、先来康康训练集中的照片吧

    # 展示其中的训练照片
    import matplotlib.pyplot as plt
    import numpy as np
    # 定义图片显示的function
    def imshow(img):
        img=img/2+0.5
        npimg=img.numpy()
        plt.imshow(np.transpose(npimg,(1,2,0)))
        plt.show()
    # 得到随机训练图像
    dataiter=iter(trainloader)
    images,labels=dathttp://www.cppcns.comaiter.next()
    # 展示kDqpUntqA图片
    imshow(torchvision.utils.make_grid(images))
    #打印标签labels
    print(' '.join("%5s"%classes[labels[j]] for j in range(4)))
    

    运行结果

    Python Pytorch深度学习之图像分类器

    Python Pytorch深度学习之图像分类器

    注:初学的猿仔们如果Spyder不显示图片,自己设置一下就OK,在Tools——>Preferences中设置如下:

    Python Pytorch深度学习之图像分类器

    4、定义一个神经网络吧

    此处,复制前一节的神经网络(在这里),并修改为3通道的图片(之前定义的是1通道)

    #%%
    # 定义卷积神经网络
    import torch.nn as nn
    import torch.nn.functional as F
    class Net(nn.Module):
        def __init__(self):
            super(Net,self).__init__()
            # 1个输入,6个输出,5*5的卷积
            # 内核
            self.conv1=nn.Conv2d(3,6,5)#定义三个通道
            self.pool=nn.MaxPool2d(2,2)
            self.conv2=nn.Conv2d(6,16,5)
            # 映射函数:线编程客栈性——y=Wx+b
            self.fc1=nn.Linear(16*5*5,120)#输入特征值:16*5*5,输出特征值:120
            self.fc2=nn.Linear(120,84)
            self.fc3=nn.Linear(84,10)
        def forward(self,x):
            x=self.pool(F.relu(self.conv1(x)))
            x=self.pool(F.relu(self.conv2(x)))
            x=x.view(-1,16*5*5)
            x=F.relu(self.fc1(x))
            x=F.relu(self.fc2(x))
            x=self.fc3(x)
            return x
    net=Net()
    

    Tips:在Spyder中可用使用“#%%”得到cell块,之后对每个cell进行运行,快捷键(Ctrl+Enter)——>我太爱用快捷键了,无论是什么能用键盘坚决不用鼠标(是真的懒吧!!!)

    5、定义一个损失函数和优化器吧

    使用分类交叉熵Cross-Entropy做损失函数,动量SGD做优化器

    #%%
    # 定义一个损失函数和优化器
    import torch.optim as optim
    criterion=nn.CrossEntropyLoss()
    optimizer=optim.SGD(net.parameters(), lr=0.001,momentum=0.9)
    

    6、训练网络吧

    此处只需要在数据迭代器上循环输入网络和优化器

    #%%训练网络
    for epoch in range(2):
        running_loss=0.0
        for i,data in enumerate(trainloader,0):
            #得到输入
            inputs,labels=data
            # 将参数的梯度值置零
            optimizer.zero_grad()
            #反向传播+优化
            outputs=net(inputs)
            loss=criterion(outputs,labels)
            loss.backward()
            optimizer.step()
            #打印数据
            running_loss+=loss.item()
            if i% 2000==1999:
                print('[%d,编程客栈%5d] loss: %.3f'%(epoch+1,i+1,running_loss/2000))#每2000个输出一次
    print('Finished Training')
    

    运行结果

    Python Pytorch深度学习之图像分类器

    7、在测试集上测试一下网络吧

    已经通过训练数据集对网络进行了两次训练,但是我们需要检查是否已经学到了东西。将使用神经网络的输出作为预测的类标来检查网络的预测性能,用样本的真实类标校对,如过预测正确,将样本添加到正确预测的列表中

    #%%
    #在测试集上显示
    outputs=net(images)
    # 输出是预测与十个类的相似程度,与某一个类的近似程度越高,网络就越认为图像是属于这一类别
    # 打印其中最相似类别类标
    _, predictd=torch.max(outputs,1)
    print('Predicted:',' '.join('%5s'% classes[predictd[j]]
                                for j in range(4)))
    

    运行结果

    Python Pytorch深度学习之图像分类器

    把网络放在整个数据集上看看具体表现

    #%% 结果看起来还好55%,看看网络在整个数据集的表现
    correct=0
    total=0
    with torch.no_grad():
        for data in testloader:
            images,labels=data
            outputs=net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted==labels).sum().item()
    print('Accuracy of the network on the 10000 test images:%d %%' % (
        100*correct/total))
    

    运行结果

    Python Pytorch深度学习之图像分类器

    8、分别查看一下训练效果吧

    #%%分类查看
    class_correct=list(0. for i in range(10))
    class_total=list(0. for i in range(10))
    with torch.no_grad():
        for data in testloader:
            images,labels=data
            outputs=net(images)
            _, predictd=torch.max(outputs,1)
            c=(predictd==labels).squeeze()
            for i in range(4):
                label=labels[i]
                class_correct[label]+=c[i].item()
                class_total[label]+=1
                
    for i in range(10):
        print('Accuracy of %5s:%2d %%'% (classes[i],100*class_correct[i]/class_total[i]))
    

    运行结果

    Python Pytorch深度学习之图像分类器

    总结

    本篇文章就到这里了,希望能够给你带来帮助,也希望您能够多多关注我http://www.cppcns.com们的更多内容!

    0

    精彩评论

    暂无评论...
    验证码 换一张
    取 消

    关注公众号