开发者

Pytorch模型的保存/复用/迁移实现代码

开发者 https://www.devze.com 2023-05-06 09:31 出处:网络 作者: 信海
目录模型的保存与复用模型定义和参数打印模型保存模型推理模型再训练模型迁移参考文献本文整理了Pytorch框架下模型的保存、复用、推理、再训练和迁移等实现。
目录
  • 模型的保存与复用
    • 模型定义和参数打印
    • 模型保存
    • 模型推理
    • 模型再训练
    • 模型迁移
  • 参考文献

    本文整理了Pytorch框架下模型的保存、复用、推理、再训练和迁移等实现。

    模型的保存与复用

    模型定义和参数打印

    # 定义模型结构
    class LenNet(nn.Module):
        def __init__(self):
            super(LenNet, self).__init__()
            self.conv = nn.Sequential(  # [BATch, 1, 28, 28]
                nn.Conv2d(1, 8, 5, 2),  # [batch, 1, 28, 28]
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),  # [batch, 8, 14, 14]
                nn开发者_JAVA学习.Conv2d(8, 16, 5),  # [batch, 16, 10, 10]
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),  # [batch, 16, 5, 5]
            )
            self.fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(16*5*5, 128),
                nn.ReLU(inplace=True),
                nn.Linear(128, 64),
                nn.ReLU(inplace=True),
                nn.Linear(64, 10)
            )
        def forward(self, X):
            return self.fc(self.conv(X))
    # 查看模型参数
    # 网络模型中的参数model.state_dict()是以字典形式保存(实质上是collections模块中的OrderedDict)
    model = LenNet()
    print("Model's state_dict:")
    for param_tensor in model.state_dict():
        print(param_tensor, "\t", model.state_dict()[param_tensor].size())
    # 参数名中的fc和conv前缀是根据定义nn.Sequential()时的名字所确定。
    # 参数名中的数字表示每个Sequential()中网络层所在的位置。
    print(model.state_dict().keys())  # 打印键
    print(model.state_dict().values())  # 打印值
    # 优化器optimizer的参数打印类似
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    print("Optimizer's state_dict:")
    for var_name in optimizer.state_dict():
       print(var_name, "\t", optimizer.state_dict()[var_name])

    模型保存

    import os
    # 指定保存的模型名称时Pytorch官方建议的后缀为.pt或者.pth
    model_save_dir = './model_logs/'
    model_save_path = os.path.join(model_save_dir, 'LeNet.pt')
    torch.save(model.state_dict(), model_save_path)
    # 在训练过程中保存某个条件下的最优模型,可以如下操作
    best_model_state = deepcopy(model.state_dict()) 
    torch.save(best_model_state, model_save_path)
    # 下面这种方法是错误的,因为best_model_state只是model.state_dict()的引用,会随着训练的改变而改变
    best_model_state = model.state_dict() 
    torch.save(best_model_state, model_save_path)

    模型推理

    def inference(data_iter, device, model_save_dir):
    	model = LeNet()  # 初始化现有模型的权重参数
        model.to(device)
        model_save_path = os.path.join(model_save_dir, 'LeNet.pt')
        # 如果本地存在模型,则加载本地模型参数覆http://www.devze.com盖原有模型
        if os.pathandroid.exists(model_save_path): 
            loaded_paras = torch.load(model_save_path)
            model.load_state_dict(loaded_paras)
            model.eval()
        with torch.no_grad():  # 开始推理
            acc_sum, n = 0., 0
            for x, y in data_iter:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                acc_sum += (logits.argmax(1) == y).float().sum().item()
                n += len(y)
            print("Accuracy in test data is : ", acc_sum / n)

    模型再训练

    class MyModel:
        def __init__(self,编程客栈
                     batch_size=64,
                     epochs=5,
                     learning_rate=0.001,
                     model_save_dir='./MODEL'):
            self.batch_size = batch_size
            self.epochs = epochs
            self.learning_rate = learning_rate
            self.model_save_dir = model_save_dir
            self.model = LeNet()
        def train(self):
            train_iter, test_iter = load_dataset(self.batch_size)
            # 在训练过程中只保存网络权重,在再训练时只载入网络权重参数初始化网络训练。这里是核心部分,开始。
            if not os.path.exists(self.model_save_dir):
                os.makedirs(self.model_save_dir)
            model_save_path = os.path.join(self.model_http://www.devze.comsave_dir, 'model.pt')
            if os.path.exists(model_save_path):
                loaded_paras = torch.load(model_save_path)
                self.model.load_state_dict(loaded_paras)
                print("#### 成功载入已有模型,进行再训练...")
            # 结束  
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)  
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.model.to(device)
            for epoch in range(self.epochs):
                for i, (x, y) in enumerate(train_iter):
                    x, y = x.to(device), y.to(device)
                    loss, logits = self.model(x)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()  
                    if i % 100 == 0:
                        acc = (logits.argmax(1) == y).float().mean()
                        print("Epochs[http://www.devze.com{}/{}]---batch[{}/{}]---acc {:.4}---loss {:.4}".format(
                            epoch, self.epochs, len(train_iter), i, acc, loss.item()))
                print("Epochs[{}/{}]--acc on test {:.4}".format(epoch, self.epochs,
                                                                self.evaLuate(test_iter, self.model, device)))
                torch.save(self.model.state_dict(), model_save_path)
        @staticmethod
        def evaluate(data_iter, model, device):
            with torch.no_grad():
                acc_sum, n = 0.0, 0
                for x, y in data_iter:
                    x, y = x.to(device), y.to(device)
                    logits = model(x)
                    acc_sum += (logits.argmax(1) == y).float().sum().item()
                    n += len(y)
                return acc_sum / n
    # 在保存参数的时候,将优化器参数、损失值等可一同保存,然后在恢复模型时连同其它参数一起恢复
    model_save_path = os.path.join(model_save_dir, 'LeNet.pt')
    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                ...
                }, model_save_path)
    # 加载方式如下
    checkpoint = torch.load(model_save_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    模型迁移

    # 定义新模型NewLeNet 和LeNet区别在于新增了一个全连接层
    class NewLenNet(nn.Module):
        def __init__(self):
            super(NewLenNet, self).__init__()
            self.conv = nn.Sequential(  # [batch, 1, 28, 28]
                nn.Conv2d(1, 8, 5, 2),  # [batch, 1, 28, 28]
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),  # [batch, 8, 14, 14]
                nn.Conv2d(8, 16, 5),  # [batch, 16, 10, 10]
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),  # [batch, 16, 5, 5]
            )
            self.fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(16*5*5, 128),
                nn.ReLU(inplace=True),
                nn.Linear(128, 64), # 这层以前和LeNet结构一致 可以用LeNet的参数来进行替换
                nn.ReLU(inplace=True),
                nn.Linear(64, 32),
                nn.ReLU(inplace=True),
                nn.Linear(32, 10)
            )
        def forward(self, X):
            return self.fc(self.conv(X))
    # 定义替换函数 匹配两个网络 size相同处地方进行参数替换
    def para_state_dict(model, model_save_dir):
        state_dict = deepcopy(model.state_dict())
        model_save_path = os.path.join(model_save_dir, 'model.pt')
        if os.path.exists(model_save_path):
            loaded_paras = torch.load(model_save_path)
            for key in state_dict:  # 在新的网络模型中遍历对应参数
                if key in loaded_paras and state_dict[key].size() == loaded_paras[key].size():
                    print("成功初始化参数:", key)
                    state_dict[key] = loaded_paras[key]
        return state_dict
    # 更新一下模型迁移后的训练代码
    def train(self):
            train_iter, test_iter = load_dataset(self.batch_size)
            if not os.path.exists(self.model_save_dir):
                os.makedirs(self.model_save_dir)
            model_save_path = os.path.join(self.model_save_dir, 'model_new.pt')
            old_model = os.path.join(self.model_save_dir, 'LeNet.pt')
            if os.path.exists(old_model):
                state_dict = para_state_dict(self.model, self.model_save_dir)  # 调用迁移代码 将LeNet的前几层参数迁移到NewLeNet
                self.model.load_state_dict(state_dict)
                print("#### 成功载入已有模型,进行再训练...")
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)  
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.model.to(device)
            for epoch in range(self.epochs):
                for i, (x, y) in enumerate(train_iter):
                    x, y = x.to(device), y.to(device)
                    loss, logits = self.model(x)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()  
                    if i % 100 == 0:
                        acc = (logits.argmax(1) == y).float().mean()
                        print("Epochs[{}/{}]---batch[{}/{}]---acc {:.4}---loss {:.4}".format(
                            epoch, self.epochs, len(train_iter), i, acc, loss.item()))
                print("Epochs[{}/{}]--acc on test {:.4}".format(epoch, self.epochs,
                                                                self.evaluate(test_iter, self.model, device)))
                torch.save(self.model.state_dict(), model_save_path)
    # 这里更新未进行训练的推理
    def inference(data_iter, device, model_save_dir='./MODEL'):
        model = NewLeNet()  # 初始化现有模型的权重参数
        print("初始化参数 conv.0.bias 为:", model.state_dict()['conv.0.bias'])
        model.to(device)
        state_dict = para_state_dict(model, model_save_dir) # 迁移模型参数
        model.load_state_dict(state_dict)
        model.eval()
        print("载入本地模型重新初始化 conv.0.bias 为:", model.state_dict()['conv.0.bias'])
        with torch.no_grad():
            acc_sum, n = 0.0, 0
            for x, y in data_iter:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                acc_sum += (logits.argmax(1) == y).float().sum().item()
                n += len(y)
            print("Accuracy in test data is :", acc_sum / n)

    参考文献

    [1] https://github.com/moon-hotel/DeepLearningWithMe

    到此这篇关于Pytorch模型的保存/复用/迁移的文章就介绍到这了,更多相关Pytorch模型保存迁移内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

    0

    精彩评论

    暂无评论...
    验证码 换一张
    取 消

    关注公众号