转载

pytorch中Tensor各种操作

import torch

pytorch中Tensor初始化

#返回全零Tensor
a = torch.zeros(2,3); print(a)

#返回shape和参数一样的全1Tensor, zeros_like类似
b = torch.ones_like(a); print(b)

#torch.arange(start=0, end, step=1) 
c = torch.arange(0, 10, 1); print(c)

#初始化指定值的Tensor
e = torch.full((2,3), 2); print(e)


#************随机初始化***********
#初始化为[0,1)内的均匀分布随机数
a = torch.rand((2,3)); print(a)

#初始化为[0,1)内的均匀分布随机数,不过shape与参数相同
b = torch.rand_like(a); print(b)

#返回标准正态分布(0,1)的随机数
torch.randn(2,3)

#torch.normal(mean, std, out=None)
c = torch.normal(torch.randn(2,3), torch.randn(2,3)); print(c)
tensor([[0., 0., 0.],
        [0., 0., 0.]])
tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([[2., 2., 2.],
        [2., 2., 2.]])
tensor([[0.5680, 0.8983, 0.3664],
        [0.1633, 0.9800, 0.8474]])
tensor([[0.5968, 0.0573, 0.2997],
        [0.7349, 0.0017, 0.4628]])
tensor([[ 0.5925,  1.8533, -1.8251],
        [-1.9088, -0.1117,  0.6028]])

pytorch中contiguous

调用view之前最好先contiguous

x.contiguous().view()

因为view需要tensor的内存是整块的

x = torch.ones(10, 10)
print( x.is_contiguous())  # True
print (x.transpose(0, 1).is_contiguous())  # False
print (x.transpose(0, 1).contiguous().is_contiguous())
True
False
True

pytorch中transpose

transpose就是转置,只能操作两位坐标轴转换

x = torch.randn(2, 3, 4)
x1 = x.transpose(0, 1)
x2 = x.transpose(1, 0)
# x1, x2相同
x, x1, x2
(tensor([[[ 0.4893, -0.0905, -1.5506,  1.1095],
          [-1.0713, -0.4694, -0.6739, -0.9530],
          [ 1.6341,  0.3761, -1.6181,  0.7335]],
 
         [[-0.4032, -0.3206,  0.1094,  2.7867],
          [-0.1956,  0.3200, -1.3562,  0.7300],
          [-0.3488,  0.5682,  0.8729, -0.4489]]]),
 tensor([[[ 0.4893, -0.0905, -1.5506,  1.1095],
          [-0.4032, -0.3206,  0.1094,  2.7867]],
 
         [[-1.0713, -0.4694, -0.6739, -0.9530],
          [-0.1956,  0.3200, -1.3562,  0.7300]],
 
         [[ 1.6341,  0.3761, -1.6181,  0.7335],
          [-0.3488,  0.5682,  0.8729, -0.4489]]]),
 tensor([[[ 0.4893, -0.0905, -1.5506,  1.1095],
          [-0.4032, -0.3206,  0.1094,  2.7867]],
 
         [[-1.0713, -0.4694, -0.6739, -0.9530],
          [-0.1956,  0.3200, -1.3562,  0.7300]],
 
         [[ 1.6341,  0.3761, -1.6181,  0.7335],
          [-0.3488,  0.5682,  0.8729, -0.4489]]]))

pytorch中cat

cat的第一个参数必须是tuple of Tensors, 第二个参数指定连接的方式,通俗来说是横着还是竖着,默认是横着

a = torch.randn(1,3)
b = torch.randn(1,3)
c = torch.cat((a,b))
print(a.shape); print(b.shape);print(c.shape)
print('\n')
a = torch.randn(2,3)
b = torch.randn(2,1)
c = torch.cat((a,b), 1)
print(a.shape); print(b.shape); print(c.shape)
torch.Size([1, 3])
torch.Size([1, 3])
torch.Size([2, 3])


torch.Size([2, 3])
torch.Size([2, 1])
torch.Size([2, 4])

pytorch中stack

与cat操作不同的是stack操作过程中会增加新的维度,具体看例子解释

All tensors need to be of the same size.

a = torch.randn(2, 3)
b = torch.randn(2,3)
c = torch.stack([a,b], 2)
print(a.shape); print(b.shape); print(c.shape)
print(a); print(b); print(c)
torch.Size([2, 3])
torch.Size([2, 3])
torch.Size([2, 3, 2])
tensor([[ 0.1350, -0.6451,  0.4002],
        [ 1.2582,  0.5484,  0.9806]])
tensor([[-2.3363, -0.8511, -1.5627],
        [-0.3785,  0.2618, -1.0212]])
tensor([[[ 0.1350, -2.3363],
         [-0.6451, -0.8511],
         [ 0.4002, -1.5627]],

        [[ 1.2582, -0.3785],
         [ 0.5484,  0.2618],
         [ 0.9806, -1.0212]]])
正文到此结束
本文目录