拼接: cat, stack …
-
使用 cat 在指定维度 dim 上拼接: torch.cat(element_list, dim)
>>> a = torch.rand(2,3) >>> b = torch.rand(1,3) >>> c = torch.cat([a,b], dim=0) >>> c.shape torch.Size([3, 3])
-
使用 stack 在新增维度 dim 上拼接: torch.cat(element_list, dim),
- 注:element_list 中 element 的 shape 必须完全一致
>>> a = torch.rand(2,3) >>> b = torch.rand(2,3) >>> c = torch.stack([a,b], dim=0) >>> c.shape torch.Size([2, 2, 3])
拆分:split,chunk …
- 使用 split 根据长度拆分:a.split(l, dim)
- 注:长度不一样时:a.split(l_list, dim)
>>> a.split(1, dim=0) # 或 a.split([1,1], dim=0) (tensor([[0.7967, 0.5056, 0.7963]]), tensor([[0.8603, 0.7029, 0.7590]]))
- 使用 chunk根据数量拆分:a.chunk(n, dim)
>>> a.chunk(2, dim=0) (tensor([[0.7967, 0.5056, 0.7963]]), tensor([[0.8603, 0.7029, 0.7590]]))
- B站视频参考资料
- 注:长度不一样时:a.split(l_list, dim)
- 使用 split 根据长度拆分:a.split(l, dim)
- 注:element_list 中 element 的 shape 必须完全一致
还没有评论,来说两句吧...