Pytorch 基础教程

PyTorch 基础教程

original icon
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://www.knowledgedict.com/tutorial/pytorch-tensor.html

PyTorch 张量(Tensor)详解


PyTorch 的张量(Tensor)作为其核心数据结构。张量类似于 NumPy 数组,但具有 GPU 支持,可用于深度学习模型的构建和训练。本文将详细介绍 PyTorch 的张量,包括如何创建张量、张量的元素类型、张量的操作、张量形状操作以及与 NumPy 的互操作性。

创建张量

在 PyTorch 中,可以通过多种方式创建张量,以下是一些常见的方法:

从列表或 NumPy 数组创建

import torch
import numpy as np

# 从 Python 列表创建张量
list_data = [1, 2, 3, 4, 5]
tensor_from_list = torch.tensor(list_data)

# 从 NumPy 数组创建张量
numpy_data = np.array([6, 7, 8, 9, 10])
tensor_from_numpy = torch.from_numpy(numpy_data)

当然也可以从都是从数据类型的元组中构建:

import torch

# 从 Python 数字元组创建张量
tuple_data = (11, 12, 13, 14, 15)
tensor_from_tuple = torch.tensor(tuple_data)

使用特定值创建

pytorch 提供了几个特定值创建 tensor 的方法,分别是全零、全一和随机数的张量,示例如下:

import torch

# 创建全零张量
zeros_tensor = torch.zeros(2, 3)

# 创建全一张量
ones_tensor = torch.ones(3, 4)

# 创建随机张量(均匀分布)
rand_tensor = torch.rand(4, 5)

使用范围创建

创建等差张量或创建等间隔张量,示例如下:

import torch

# 创建等差张量
arange_tensor = torch.arange(0, 10, step=2)  # 0, 2, 4, 6, 8

# 创建等间隔张量
linspace_tensor = torch.linspace(0, 1, steps=5)  # 0.00, 0.25, 0.50, 0.75, 1.00

张量的元素类型

PyTorch 张量可以具有不同的数据类型,常见的包括 torch.float32(默认)、torch.int64torch.bool 等。可以通过 dtype 属性来查看和修改张量的数据类型:

import torch

# 查看张量的数据类型
tensor = torch.tensor([1, 2, 3])
print(tensor.dtype)  # 输出: torch.int64

# 修改张量的数据类型
float_tensor = tensor.to(torch.float32)

张量操作

加法操作

在 PyTorch 中,你可以使用多种方法对张量进行相加。以下列出了常见的4种的方法。

使用运算符 + 进行张量相加:

使用 + 运算符是最常见的方法,它允许你对两个张量进行逐元素相加。前提是张量的形状必须兼容,即有相同的形状。

import torch

# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])

# 相加
result = tensor1 + tensor2
print(result)  # 输出: tensor([5, 7, 9])

使用 torch.add 函数进行张量相加:

torch.add 函数可以用于执行逐元素的张量相加,它还可以指定输出张量来保存结果。

import torch

# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])

# 使用 torch.add 进行相加
result = torch.add(tensor1, tensor2)
print(result)  # 输出: tensor([5, 7, 9])

使用 torch.add 函数的 out 参数指定输出张量:

torch.add 函数的 out 参数允许你指定一个输出张量,这可以用于保存结果而不创建新的张量。

import torch

# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])

# 创建一个空的输出张量
result = torch.empty(3)

# 使用 torch.add 将结果保存到输出张量
torch.add(tensor1, tensor2, out=result)
print(result)  # 输出: tensor([5., 7., 9.])

使用原地操作符 += 进行张量相加:

原地操作符 += 允许你在不创建新张量的情况下就地修改一个张量。

import torch

# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])

# 使用原地操作符 +=
tensor1 += tensor2
print(tensor1)  # 输出: tensor([5, 7, 9])

这些方法都可以用于对 PyTorch 张量进行相加操作。你可以选择适合你需求的方法,具体取决于是否需要保存结果、是否需要原地修改张量等因素。

乘法操作

使用运算符 * 进行张量逐元素乘法:

使用 * 运算符是最常见的方法,它允许你对两个张量进行逐元素乘法。前提是张量的形状必须兼容,即有相同的形状。

import torch

# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])

# 逐元素相乘
result = tensor1 * tensor2
print(result)  # 输出: tensor([ 4, 10, 18])

使用 torch.mul 函数进行逐元素乘法:

torch.mul 函数可以用于执行逐元素的张量乘法,它还可以指定输出张量来保存结果。torch.mul 函数它还有一个别称 torch.multiply

import torch

# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])

# 使用 torch.mul 进行逐元素相乘
result = torch.mul(tensor1, tensor2)
print(result)  # 输出: tensor([ 4, 10, 18])

使用原地操作符 *= 进行逐元素乘法:

原地操作符 *= 允许你在不创建新张量的情况下就地修改一个张量。

import torch

# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])

# 使用原地操作符 *=
tensor1 *= tensor2
print(tensor1)  # 输出: tensor([ 4, 10, 18])

使用 torch.matmul 进行矩阵乘法:

torch.matmul 函数用于执行矩阵乘法,它适用于二维张量(矩阵)。矩阵乘法的前提是第一个矩阵的列数必须等于第二个矩阵的行数。

import torch

# 创建两个矩阵
matrix1 = torch.tensor([[1, 2], [3, 4]])
matrix2 = torch.tensor([[5, 6], [7, 8]])

# 执行矩阵乘法
result = torch.matmul(matrix1, matrix2)
print(result)
# 输出:
# tensor([[19, 22],
#         [43, 50]])

使用 torch.dot 函数进行点乘:

点乘,也称为内积或数量积,是两个向量的逐元素相乘然后求和的操作。在线性代数中,它可以表示为:a · b = a1 * b1 + a2 * b2 + ... + an * bn

计算 input 和 output 的点乘,此函数要求 input 和 output 都必须是一维的张量(其 shape 属性中只有一个值),并且要求两者元素个数相同。

import torch

# 创建两个向量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# 计算点乘
dot_product = torch.dot(a, b)
print(dot_product)  # 输出: tensor(32)

使用 torch.mm 函数进行矩阵相乘:

在线性代数中,矩阵乘法是一种常见的操作,用于将一个矩阵与另一个矩阵相乘。在矩阵乘法中,第一个矩阵的列数必须等于第二个矩阵的行数。

在 PyTorch 中,你可以使用 torch.mm 函数来执行两个矩阵的乘法(适用于二维张量)。

import torch

# 创建两个矩阵
matrix1 = torch.tensor([[1, 2], [3, 4]])
matrix2 = torch.tensor([[5, 6], [7, 8]])

# 执行矩阵乘法
result = torch.mm(matrix1, matrix2)
print(result)
# 输出:
# tensor([[19, 22],
#         [43, 50]])

使用 torch.mv 函数进行矩阵和向量相乘:

torch.mv 是 PyTorch 中用于执行矩阵和向量相乘(Matrix-Vector Multiplication)的函数。这个函数非常有用,因为在深度学习和线性代数中,矩阵和向量相乘是一个常见的操作,通常用于线性变换和特征提取。

torch.mv(mat, vec, out=None)
  • mat:要进行矩阵和向量相乘的二维张量(矩阵)。
  • vec:要与矩阵相乘的一维张量(向量)。
  • out(可选):指定一个输出张量,用于存储结果。如果不提供此参数,函数将创建一个新的张量来保存结果。
import torch

# 创建一个矩阵
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6]])

# 创建一个向量
vector = torch.tensor([7, 8, 9])

# 使用 torch.mv 进行矩阵和向量相乘
result = torch.mv(matrix, vector)
print(result)

切片操作

# 创建一个二维张量
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# 切片操作
row1 = matrix[0, :]  # 获取第一行
col2 = matrix[:, 1]  # 获取第二列

# 切片并修改张量
matrix[1, 2] = 10

复制和转置

# 复制张量
original_tensor = torch.tensor([1, 2, 3])
copied_tensor = original_tensor.clone()  # 或者使用 copied_tensor = original_tensor.copy()

# 转置张量
original_matrix = torch.tensor([[1, 2],
                                [3, 4]])
transposed_matrix = original_matrix.T

张量形状操作

PyTorch 提供了许多用于修改张量形状的操作,如 reshapeviewpermute 等。

# 改变张量形状
x = torch.arange(12)
reshaped_x = x.reshape(3, 4)  # 或者使用 x.view(3, 4)

# 转置维度
tensor = torch.tensor([[1, 2], [3, 4]])
permuted_tensor = tensor.permute(1, 0)  # 将维度交换

与 NumPy 的互操作

PyTorch 与 NumPy 之间有很好的互操作性,可以轻松地将张量转换为 NumPy 数组,反之亦然。

# 将 PyTorch 张量转换为 NumPy 数组
pytorch_tensor = torch.tensor([1, 2, 3])
numpy_array = pytorch_tensor.numpy()

# 将 NumPy 数组转换为 PyTorch 张量
numpy_array = np.array([4, 5, 6])
pytorch_tensor = torch.from_numpy(numpy_array)

 

Elasticsearch索引的配置项主要分为静态配置属性和动态配置属性,静态配置属性是索引创建后不能修改,而动态配置属性则可以随时修改。r ...
基于 elasticsearch 构建的业务中最常用的聚合查询就是 terms aggregation,它基于 term 粒度的词或数字值进 ...
在 es 使用中,开发者想配置自身业务中沉淀的同义词(synonyms)表,并基于该同义词库配置包含其的分析器(analyzer),主要分为 ...
elasticsearch 删除索引操作能够用单个命令来进行完成,有不同的操作形式,具体如下: ...
es bool 查询是把任意多个简单查询组合在一起,使用 must、should、must_not、filter 选项来表示简单查询之间的逻 ...