当前位置:优草派 > 问答 > Python问答

pytorch 实现张量tensor,图片,CPU,GPU,数组等的转换

标签: Python  Python开发  PyTorch  作者: xujincan

回答:

PyTorch是一种基于Python的科学计算库,它是一个用于构建深度神经网络的开源机器学习库。PyTorch具有高度可扩展性,可用于处理各种类型的数据,包括张量(tensor)、图片、CPU、GPU、数组等。在本文中,我们将深入探讨如何使用PyTorch实现这些数据类型的转换。一、张量tensor的转换

张量是PyTorch中最基本的数据类型。它类似于Numpy中的数组,但可以在GPU上运行,这使得PyTorch比Numpy更快。下面是如何在PyTorch中创建张量的代码:

```python

import torch

# 创建一个大小为5x3的未初始化张量

x = torch.empty(5, 3)

print(x)

# 创建一个大小为5x3的随机张量

x = torch.rand(5, 3)

print(x)

# 创建一个大小为5x3的全0张量,数据类型为long

x = torch.zeros(5, 3, dtype=torch.long)

print(x)

```

可以使用`size()`方法来查看张量的大小:

```python

print(x.size())

```

张量可以在CPU和GPU之间转换。下面是将张量从CPU转移到GPU的代码:

```python

# 在GPU上创建一个大小为5x3的随机张量

x = torch.rand(5, 3).cuda()

# 将张量从GPU转移到CPU

x = x.cpu()

# 将张量从CPU转移到GPU

x = x.cuda()

```

二、图片的转换

在PyTorch中,可以使用`torchvision`模块处理图像。下面是如何将图像转换为张量的代码:

```python

import torch

import torchvision

import torchvision.transforms as transforms

# 定义变换

transform = transforms.Compose(

[transforms.ToTensor()])

# 加载数据集

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,

download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,

shuffle=True, num_workers=2)

# 显示图像

import matplotlib.pyplot as plt

import numpy as np

# 定义类别

classes = ('plane', 'car', 'bird', 'cat',

'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 获取随机数据

dataiter = iter(trainloader)

images, labels = dataiter.next()

# 显示图像

def imshow(img):

img = img / 2 + 0.5 # 非标准化

npimg = img.numpy()

plt.imshow(np.transpose(npimg, (1, 2, 0)))

plt.show()

imshow(torchvision.utils.make_grid(images))

print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

```

三、CPU和GPU的转换

在PyTorch中,可以在CPU和GPU之间转换数据。下面是如何将数据从CPU转移到GPU的代码:

```python

import torch

# 在GPU上创建一个大小为5x3的随机张量

x = torch.rand(5, 3).cuda()

# 将张量从GPU转移到CPU

x = x.cpu()

# 将张量从CPU转移到GPU

x = x.cuda()

```

四、数组的转换

在PyTorch中,可以使用`numpy()`方法将张量转换为Numpy数组。下面是如何将张量转换为Numpy数组的代码:

```python

import torch

# 创建一个大小为5x3的随机张量

x = torch.rand(5, 3)

# 将张量转换为Numpy数组

y = x.numpy()

# 将Numpy数组转换为张量

z = torch.from_numpy(y)

```

五、

TOP 10
  • 周排行
  • 月排行