Pytorch中torch.cat()函数



torch.cat() 是一个在 PyTorch 中用来拼接张量的函数。它的基本语法如下:

torch.cat(tensors, dim=0)
  • tensors 是一个张量的列表,你想要将它们拼接在一起。
  • dim 是你想要进行拼接的维度。默认是 0,也就是在第一个维度上进行拼接。

举个例子,假设你有两个形状为 (3,) 的一维张量 a 和 b

a = torch.tensor([1, 2, 3])  
b = torch.tensor([4, 5, 6])

你可以使用 torch.cat() 将它们拼接在一起,形成一个形状为 (6,) 的张量:

c = torch.cat((a, b))

c现在是一个形状为(6,)的张量,其值为[1, 2, 3, 4, 5, 6]`。

如果你在更高的维度上进行拼接,你需要确保除拼接维度外,其他所有维度的大小都是相同的。例如,如果你有两个形状为 (3, 2) 的二维张量 a 和 b

a = torch.tensor([[1, 2], [3, 4], [5, 6]])  
b = torch.tensor([[7, 8], [9, 10], [11, 12]])

你可以在第一个维度(dim=0)上拼接它们,形成一个形状为 (6, 2) 的张量:

c = torch.cat((a, b), dim=0)

c现在是一个形状为(6, 2)` 的张量,其值为:

[[ 1,  2],  
 [ 3,  4],  
 [ 5,  6],  
 [ 7,  8],  
 [ 9, 10],  
 [11, 12]]

你也可以在第二个维度(dim=1)上拼接它们,形成一个形状为 (3, 4) 的张量:

c = torch.cat((a, b), dim=1)

在这种情况下,c 是一个形状为 (3, 4) 的张量,其值为:

[[ 1,  2,  7,  8],  
 [ 3,  4,  9, 10],  
 [ 5,  6, 11, 12]]