torch.cat()
是一个在 PyTorch 中用来拼接张量的函数。它的基本语法如下:
torch.cat(tensors, dim=0)
tensors
是一个张量的列表,你想要将它们拼接在一起。dim
是你想要进行拼接的维度。默认是 0,也就是在第一个维度上进行拼接。举个例子,假设你有两个形状为 (3,)
的一维张量 a
和 b
:
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
你可以使用 torch.cat()
将它们拼接在一起,形成一个形状为 (6,)
的张量:
c = torch.cat((a, b))
c现在是一个形状为
(6,)的张量,其值为
[1, 2, 3, 4, 5, 6]`。
如果你在更高的维度上进行拼接,你需要确保除拼接维度外,其他所有维度的大小都是相同的。例如,如果你有两个形状为 (3, 2)
的二维张量 a
和 b
:
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = torch.tensor([[7, 8], [9, 10], [11, 12]])
你可以在第一个维度(dim=0
)上拼接它们,形成一个形状为 (6, 2)
的张量:
c = torch.cat((a, b), dim=0)
c现在是一个形状为
(6, 2)` 的张量,其值为:
[[ 1, 2],
[ 3, 4],
[ 5, 6],
[ 7, 8],
[ 9, 10],
[11, 12]]
你也可以在第二个维度(dim=1
)上拼接它们,形成一个形状为 (3, 4)
的张量:
c = torch.cat((a, b), dim=1)
在这种情况下,c
是一个形状为 (3, 4)
的张量,其值为:
[[ 1, 2, 7, 8],
[ 3, 4, 9, 10],
[ 5, 6, 11, 12]]