开发者

Pytorch中torch.unsqueeze()与torch.squeeze()函数详细解析

开发者 https://www.devze.com 2023-02-15 09:24 出处:网络 作者: cv_lhp
目录一. torch.squeeze()函数解析1. 官网链接2. torch.squeeze()函数解析3. 代码举例3.1 输入size=(2, 1, 2, 1, 2)的张量3.2 把x中维度大小为1的所有维都已删除3.3 把x中第一维删除,但是第一维大小为2,不为1,因此
目录
  • 一. torch.squeeze()函数解析
    • 1. 官网链接
    • 2. torch.squeeze()函数解析
    • 3. 代码举例
      • 3.1 输入size=(2, 1, 2, 1, 2)的张量
      • 3.2 把x中维度大小为1的所有维都已删除
      • 3.3 把x中第一维删除,但是第一维大小为2,不为1,因此结果删除不掉
      • 3.4 把x中第二维删除,因为第二维大小是1,因此可以删掉
      • 3.5 把编程客栈x中最后一维删除,但是最后一维大小为2,不为1,因此结果删除不掉
  • 二.torch.unsqueeze()函数解析
    • 1. 官网链接
      • 2. torch.unsqueeze()函数解析
        • 3. 代码举例
        • 总结

          一. torch.squeeze()函数解析

          1. 官网链接

          torch.squeeze(),如下图所示:

          Pytorch中torch.unsqueeze()与torch.squeeze()函数详细解析

          2. torch.squeeze()函数解析

          torch.squeeze(input, dim=None, out=None) 

          squeeze()函数的功能是维度压缩。返回一个tensor(张量),其中 input 中维度大小为1的所有维都已删除。

          举个例子:如果 input 的形状为 (A×1×B×C×1×D),那么返回的tensor的形状则为 (A×B×C×D)

          当给定 dim 时,那么只在给定的维度(dimension)上进行压缩操作,注意给定的维度大小必须是1,否则不能进行压缩。

          举个例python子:如果 input 的形状为 (A×1×B),squeeze(input, dim=0)后,返回的tensor不变,因为第0维的大小为A,不是1;squeeze(input, 1)后,返回的tensor将被压缩为 (A×B)。

          3. 代码举例

          3.1 输入size=(2, 1, 2, 1, 2)的张量

          x = torch.randn(size=(2, 1, 2, 1, 2))
          x.shape
          

          输出结果如下:

          torch.Size([2, 1, 2, 1, 2])

          3.2 把x中维度大小为1的所有维都已删除

          y = torch.squeeze(x)#表示把x中维度大小为1的所有维都已删除
          y.shape
          

          输出结果如下:

          torch.Size([2, 2, 2])

          3.3 把x中第一维删除,但是第一维大小为2,不为1,因此结果删除不掉

          y =GSAOWkugk torch.squeeze(x,0)#表示把x中第一维删除,但是第一维大小为2,不为1,因此结果删除不掉
          y.shape
          

          输出结果如下:

          torch.Size([2, 1, 2, 1, 2])

          3.4 把x中第二维删除,因为第二维大小是1,因此可以删掉

          y = torch.squeeze(x,1)#表示把x中第二维删除,因为第二维大小是1,因此可以删掉
          y.shape
          

          输出结果如下:

          torch.Size([2, 2, 1, 2])

          3.5 把x中最后一维删除,但是最后一维大小为2,不为1,因此结果删除不掉

          y = torch.squeeze(x,dim=-1)#表示把x中最后一维删除,但是最后一维大小为2,不为1,因此结果删除不掉
          y.shape
          

          输出结果如下:

          torch.Size([2, 1, 2, 1, 2])

          二.torch.unsqueeze()函数解析

          1. 官网链接

          torch.unsqueeze(),如下图所示:

          Pytorch中torch.unsqueeze()与torch.squeeze()函数详细解析

          2. torch.uns开发者_C培训queeze()函数解析

          torch.unsqueeze(input, dim) → Tensor
          

          unsqueeze()函数起升维的作用,参数dim表示在哪个地方加一个维度,注意dim范围在:[-input.dim() - 1, input.dim() + 1]之间,比如输入input是一维,则dim=0时数据为行方向扩,dim=1时为列方编程客栈向扩,再大错误。

          3. 代码举例

          3.1 输入一维张量,在第0维(行)扩展,第0维大小为1

          x = torch.tensor([1, 2, 3, 4])
          y = torch.unsqueeze(x, 0)#在第0维扩展,第0维大小为1
          y,y.shape
          

          输出结果如下:

          (tensor([[1, 2, 3, 4]]), torch.Size([1, 4]))

          3.2 在第1维(列)扩展,第1维大小为1

          y = torch.unsqueeze(x, 1)#在第1维扩展,第1维大小为1
          y,y.shape
          

          输出结果如下:

          (tensor([[1],

                   [2],

                   [3],

                   [4]]),

           torch.Size([4, 1]))

          3.3 在第最后一维(也就是倒数第一维进行)扩展,最后一维大小为1

          y = torch.unsqueeze(x, -1)#在第最后一维扩展,最后一维大小为1
          y,y.shape
          

          输出结果如下:

          (tensor([[1],

                   [2],

                   [3],

                   [4]]),

           torch.Size([4, 1]))

          总结

          到此这篇关于Pytorch中torch.unsqueeze()与torch.squeeze()函数的文章就介绍到这了,更多相关Pytorch torch.unsqueeze()与torch.squeeze()函数内容请搜索我们以前的文GSAOWkugk章或继续浏览下面的相关文章希望大家以后多多支持我们!

          0

          精彩评论

          暂无评论...
          验证码 换一张
          取 消