Tensor
Tensor (张量) 是类似数组和矩阵的多维数据结构。Tensor 和 NumPy 的 ndarray 非常相似,但是张量可以在 GPU 上使用来加速计算。
在开始之前,有必要了解 Tensor 作为一个类的基本属性。这将有助于我们更好地理解 Tensor 的使用。
定义
面向对象结构
在 torch\tensor.py
中,我们可以看到 Tensor 的定义:
class Tensor(torch._C._TensorBase):
...
可以看到,Tensor 是 torch._C._TensorBase
的子类。
在 torch\_C\__init__.pyi
中,我们可以看到 torch._C._TensorBase
的类型声明:
# Defined in torch/csrc/autograd/python_variable.cpp
class _TensorBase(metaclass=_TensorMeta):
shape: Size
dtype: _dtype
device: device
data: _TensorBase
grad: Optional[_TensorBase]
grad_fn: Optional[_FunctionBase]
...
它告诉我们 torch._C._TensorBase
是一个定义在 torch/csrc/autograd/python_variable.cpp
中的 C++ 类,并且在下方给出了其属性和方法的类型声明。
了解到这种层次目前对我们来说已经足够,不必深究 C++ 的内部机理。
属性和方法
我们简要介绍这几个属性和方法:
shape
: Tensor 的形状,即 Tensor 的维度。# Defined in torch/csrc/Size.cpp class Size(Tuple[_int, ...]): ...
可以看出,
shape
是一个整数型元组。我们可以通过shape
属性来确定和获取 Tensor 的形状。>>> x = torch.rand(3, 4) >>> x.shape torch.Size([3, 4])
dtype
: Tensor 的数据类型。如torch.float32
、torch.int64
等。Data type dtype 32-bit floating point torch.float32
ortorch.float
64-bit floating point torch.float64
ortorch.double
16-bit floating point torch.float16
ortorch.half
16-bit floating point torch.bfloat16
32-bit complex torch.complex32
ortorch.chalf
64-bit complex torch.complex64
ortorch.cfloat
128-bit complex torch.complex128
ortorch.cdouble
8-bit integer (unsigned) torch.uint8
8-bit integer (signed) torch.int8
16-bit integer (signed) torch.int16
ortorch.short
32-bit integer (signed) torch.int32
ortorch.int
64-bit integer (signed) torch.int64
ortorch.long
Boolean torch.bool
quantized 8-bit integer (unsigned) torch.quint8
quantized 8-bit integer (signed) torch.qint8
quantized 32-bit integer (signed) torch.qint32
quantized 4-bit integer (unsigned) torch.quint4x2
device
: Tensor 所在的设备,即 CPU 或 GPU。data
: Tensor 的数据,即 Tensor 的值。grad
: Tensor 的梯度。grad_fn
: Tensor 的梯度函数。
使用
初始化
Tensor 可以通过以下方式初始化:
通过数据
可以直接将多维列表 (__builtins__.list
) 传入 torch.tensor()
。数据类型和形状是自动推断的。
此时两者并不共享内存。
data = [[1, 2],[3, 4]]
print(f"data = \n{data}")
t_data = torch.tensor(data)
print(f"t_data = \n{t_data}\n")
print(f"data 和 t_data 是否共享内存:{t_data.data_ptr() == id(data[0][0])}")
data =
[[1, 2], [3, 4]]
t_data =
tensor([[1, 2],
[3, 4]])
data 和 t_data 是否共享内存:False
通过 NumPy 数组
注意,这个方法从 NumPy 数组中读取数据,而不是复制它,它们仍然共享同样的内存。
因此,如果你改变了原始 NumPy 数组,这些改变也将反映在 Tensor 中;反之同理。
np_array = np.array(data)
print(f"np_array = \n{np_array}\n")
t_np = torch.from_numpy(np_array)
print(f"t_np = \n{t_np}\n")
print(f"np_array 和 t_np 是否共享内存:{t_np.data_ptr() == np_array.ctypes.data}")
np_array =
[[1 2]
[3 4]]
t_np =
tensor([[1, 2],
[3, 4]], dtype=torch.int32)
np_array 和 t_np 是否共享内存:True
通过其他 Tensor
这些方法将重用输入 Tensor 的属性,例如 shape
和 dtype
,除非我们提供新的值。
t_ones = torch.ones_like(t_data) # 保留 t_data 的所有属性,只是覆盖数据
print(f"Ones Tensor: \n{t_ones}\n")
t_rand = torch.rand_like(t_data, dtype=torch.float) # 重写覆盖 dtype 属性,保留其他属性
print(f"Random Tensor, dtype changed to float: \n{t_rand}\n")
Ones Tensor:
tensor([[1, 1],
[1, 1]])
Random Tensor, dtype changed to float:
tensor([[0.4963, 0.7682],
[0.0885, 0.1320]])
Creation Ops
通过Creation Ops 是一些用于创建 Tensor 的函数。
一般地,传入参数 shape
是必需的,其他参数是可选的。
shape = (2,3)
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape, dtype=torch.int)
zeros_tensor = torch.zeros(shape, dtype=torch.int)
print(f"Random Tensor: \n{rand_tensor}\n")
print(f"Ones Tensor: \n{ones_tensor}\n")
print(f"Zeros Tensor: \n{zeros_tensor}")
Random Tensor:
tensor([[0.3074, 0.6341, 0.4901],
[0.8964, 0.4556, 0.6323]])
Ones Tensor:
tensor([[1, 1, 1],
[1, 1, 1]], dtype=torch.int32)
Zeros Tensor:
tensor([[0, 0, 0],
[0, 0, 0]], dtype=torch.int32)
属性
方法
原教程 写的不够好,弄了个全 1 Tensor,难以直观表现相关操作的区别。这里做了修改。
同时把更多常用的方法穿插融入到了下面的例子中。
Tensor Views
与 NumPy 类似,我们可以使用 view()
来重塑 Tensor 的形状。
注意,view()
返回的新 Tensor 和原始 Tensor 共享相同的数据,所以改变其中一个,另一个也会改变。
t = torch.tensor(range(12))
print(f"t = \n{t}\n")
view1 = t.view(2, 6)
print(f"view1 = \n{view1}\n")
print(f"t 和 view1 共享内存:{ t.storage().data_ptr() == view1.storage().data_ptr() }")
t =
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
view1 =
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11]])
t 和 view1 共享内存:True
注意,view()
需要确保传入的 Tensor 是邻接的(contiguous,即内存中的数据是连续的),否则会报错。
在下面的例子中,我们使用了 .t()
来转置 Tensor(同 base.transpose(0, 1)
),这个方法返回的也是 Tensor 的一个 View,而不是一个新的 Tensor。
显然,这种方法作出的 View 不是邻接的。我们可以使用 contiguous()
方法将其转换为邻接的。
base = torch.tensor([[0, 1],[2, 3]])
print(f"base = \n{base}")
print(f"base is contiguous: {base.is_contiguous()}\n")
transposed = base.t()
print(f"transposed = \n{transposed}")
print(f"transposed is contiguous: {transposed.is_contiguous()}\n")
try:
view2 = transposed.view(4)
except RuntimeError as e:
print(f"运行时错误: {e}\n")
contiguouse = transposed.contiguous()
view2 = contiguouse.view(4)
print(f"view2 = \n{view2}")
base =
tensor([[0, 1],
[2, 3]])
base is contiguous: True
transposed =
tensor([[0, 2],
[1, 3]])
transposed is contiguous: False
运行时错误: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
view2 =
tensor([0, 2, 1, 3])
因此,PyTorch 提供了 reshape()
方法,它相当于 view(contiguous())
。
当原 Tensor 邻接时,它返回的也是原 Tensor 的一个 View,否则返回的是一个新的 Tensor。
view3 = t.reshape(3,4)
print(f"# 令 view3 = t.reshape(3,4)")
print(f"view3 = \n{view3}")
print(f"t 和 view3 共享内存:{ t.storage().data_ptr() == view3.storage().data_ptr() }\n")
t[0] = 0
view4 = transposed.reshape(4)
print(f"# 令 view4 = transposed.reshape(4)")
print(f"view4 = \n{view4}")
print(f"transposed 和 view4 共享内存:{ transposed.storage().data_ptr() == view4.storage().data_ptr() }")
# 令 view3 = t.reshape(3,4)
view3 =
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
t 和 view3 共享内存:True
# 令 view4 = transposed.reshape(4)
view4 =
tensor([0, 2, 1, 3])
transposed 和 view4 共享内存:False
索引及切片
与 NumPy 类似,我们可以使用 []
来索引和切片 Tensor。
t = t.reshape(3, 4)
print(f"First row : {t[0]}")
print(f"First column: {t[:, 0]}")
print(f"Last column: {t[..., -1]}")
t[:,1] = 0
print(t)
First row : tensor([0, 1, 2, 3])
First column: tensor([0, 4, 8])
Last column: tensor([ 3, 7, 11])
tensor([[ 0, 0, 2, 3],
[ 4, 0, 6, 7],
[ 8, 0, 10, 11]])
数学运算
t.add_(1)
print(t)
tensor([[ 1, 1, 3, 4],
[ 5, 1, 7, 8],
[ 9, 1, 11, 12]])
t1 = torch.cat([t, t, t], dim=1)
print(t1)
tensor([[ 1, 1, 3, 4, 1, 1, 3, 4, 1, 1, 3, 4],
[ 5, 1, 7, 8, 5, 1, 7, 8, 5, 1, 7, 8],
[ 9, 1, 11, 12, 9, 1, 11, 12, 9, 1, 11, 12]])
print(f"tensor.mul(tensor)\n{t.mul(t)}\n")
print(f"tensor * tensor\n{t * t}\n")
print(f"tensor.mul(2)\n{t.mul(2)}\n")
tensor.mul(tensor)
tensor([[ 1, 1, 9, 16],
[ 25, 1, 49, 64],
[ 81, 1, 121, 144]])
tensor * tensor
tensor([[ 1, 1, 9, 16],
[ 25, 1, 49, 64],
[ 81, 1, 121, 144]])
tensor.mul(2)
tensor([[ 2, 2, 6, 8],
[10, 2, 14, 16],
[18, 2, 22, 24]])