开发者

如何从PyTorch中获取过程特征图实例详解

开发者 https://www.devze.com 2023-01-11 09:24 出处:网络 作者: ViperL1
目录一、获取Tensor①类型转换②张量拆解③图像展示总结一、获取Tensor 神经网络在运算过程中实际上是以Tensor为格式进行计算的,我们只需稍稍改动一下forward函数即可从运算过程中抓到Tensor
目录
  • 一、获取Tensor
    • ①类型转换
    • ②张量拆解
    • ③图像展示
  • 总结

    一、获取Tensor

    神经网络在运算过程中实际上是以Tensor为格式进行计算的,我们只需稍稍改动一下forward函数即可从运算过程中抓到Tensor

    代码如下:

    base_feature = self.extractor.forward(x)    #正常的前向传递
    feature=base_feature.detach()               #抓取tensor
    feature_imshow(feature)                     #展示函数(关键代码)

    通过将过程张量赋值给一个临时变量,即可将其从前向传递中分离出来且不影响原来的前向传递函数,这种方法远比复杂的hook函数更实用。

    将Tensor数据取开发者_Python到后到可视化还需要进行以下几步:

    ①类型转换

    如果js网络是在cuda中进行运算,则需要将提取到的tensor转换为cpu类型才能进行接下来的运算

    inp = inp.cpu()        #类型转换

    ②张量拆解

    网络中的张量一般是高维度的,需要对其进行降维,一般降至两php维即可进行显示。这里以Faster R-CNN中的resnet50特征提取网络为例:输出其特征图尺寸为:[1,1024,68,38],可以很明显的看出,第一维实际上是BATch_size,在图像显示中不需要,可以直接去除;第二维1024则是网络提取到的特征图张数,故可以对第二python维进行遍历;而第3,4维是特征图的尺寸,直接显示即可。

    inp=inp.squeeze(0)    #除去第一维
     
    for i in range(len(inp)):
        plt.imshow(transforms.ToPILImage()(inp[i]))    #遍历第二维并将其转换为图像

    ③图像展示

    选取你需要的特征图像,进行保存或使用plt进展示

    完整的展示函数如下:

    def feature_imshow(inp, title=None):
        inp = inp.cpu()
        inp=inp.squeeze(0)
        print(inp.shape)
        plt.figure(figsize=(12, 7))
        for i in ra编程客栈nge(len(inp)):
            plt.subplot(4, 5, i+1)    #第一二个参数为图像个数,第三参数为图像位置
            plt.imshow(transforms.ToPILImage()(inp[i]))
            i+=1
        plt.show()
        plt.pause(0.001)

    如何从PyTorch中获取过程特征图实例详解

    总结

    到此这篇关于如何从PyTowww.devze.comrch中获取过程特征图的文章就介绍到这了,更多相关PyTorch获取过程特征图内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

    0

    精彩评论

    暂无评论...
    验证码 换一张
    取 消

    关注公众号