PyTorch 拼接与拆分-Tensor基本操作

PyTorch 拼接与拆分-Tensor基本操作

码农世界 2024-06-18 后端 82 次浏览 0个评论

拼接: cat, stack …

  • 使用 cat 在指定维度 dim 上拼接: torch.cat(element_list, dim)

    >>> a = torch.rand(2,3) 
    >>> b = torch.rand(1,3) 
    >>> c = torch.cat([a,b], dim=0) 
    >>> c.shape
    torch.Size([3, 3])
    
  • 使用 stack 在新增维度 dim 上拼接: torch.cat(element_list, dim),

    • 注:element_list 中 element 的 shape 必须完全一致
      >>> a = torch.rand(2,3) 
      >>> b = torch.rand(2,3) 
      >>> c = torch.stack([a,b], dim=0)
      >>> c.shape
      torch.Size([2, 2, 3])
      

      拆分:split,chunk …

      • 使用 split 根据长度拆分:a.split(l, dim)
        • 注:长度不一样时:a.split(l_list, dim)
          >>> a.split(1, dim=0)  # 或 a.split([1,1], dim=0)
          (tensor([[0.7967, 0.5056, 0.7963]]), tensor([[0.8603, 0.7029, 0.7590]]))
          
        • 使用 chunk根据数量拆分:a.chunk(n, dim)
          >>> a.chunk(2, dim=0) 
          (tensor([[0.7967, 0.5056, 0.7963]]), tensor([[0.8603, 0.7029, 0.7590]]))
          

          • B站视频参考资料

转载请注明来自码农世界,本文标题:《PyTorch 拼接与拆分-Tensor基本操作》

百度分享代码,如果开启HTTPS请参考李洋个人博客
每一天,每一秒,你所做的决定都会改变你的人生!

发表评论

快捷回复:

评论列表 (暂无评论,82人围观)参与讨论

还没有评论,来说两句吧...

Top