PyTorch入门

基本操作

Tensors可以迁移到cuda上进行操作

1
2
3
4
5
# let us run this cell only if CUDA is available
if torch.cuda.is_available():
x = x.cuda()
y = y.cuda()
x + y

autograd

pytorch所有张量的中心是autograd库。而autograd的中心组件是autograd.Variable,它包装了一个Tensor,并基本支持了所有Tensor的操作。可以使用.data来访问内置的tensor。可以使用.grad来访问计算出来的梯度。而用.backward()来计算梯度。

数据处理

torch.utils.data.Dataset是一个代表数据集的抽象类。自定义的数据集需要继承Dataset这个类,并且对其中的方法进行重写。

  • __len__,返回数据集中数据的数量
  • __getitem__支持类似dataset[i]的这种索引。

torch.utils.data.Dataloader包含其他更高级的数据用法,比如batchingshuffling、以及并行地读取数据。

torch.nn

Parameters

torch.nn.parameter

这是Variable的一个子类,同时当它作为module的参数时,自动加入到parameters()的

torch.Tensor操作

view 操作必须在同一块内存里完成,因此可能需要configuous操作来将数据放到一个内存块中。

Switch to CUDA

我的做法是在code中定义一个CUDA全局的量,表示第几块GPU,但是有一个问题,就是如果有一个tensor是在GPU1上,然后直接进行print,会导致GPU0上有内存被当前进程占用。

如果需要吧tensor打印出来,可以使用tensor.cpu().numpy(),先将tensor放到cpu上。

可以设置一个TensorWrapper来包装tensor或者variable

1
2
3
4
5
def TensorWrapper(tensor, cuda=CUDA):
if cuda > -1:
return tensor.cuda(cuda)
else:
return tensor

同样的,当模型是在GPU上的时候,保存模型也会占用GPU0的内存,所以可以先把模型拷贝到CPU上,然后再保存,最后再将其移回GPU。

使用Dataloader和Dataset

pytorch提供了内建的数据提供模块,可以用来构建数据训练的pipeline,而且可以提供了多进程的可能,让整个处理速度的瓶颈尽可能地转移到GPU上。

Dataset

这个可以帮助存储整个数据集,当数据集很大的时候,还可以做到on-the-fly的读写方式。

1
2
3
4
5
6
7
8
9
10
11
from torch.utils.data import Dataset, DataLoader
class TextClassificationDataset(Dataset):
def __init__(self, X, y, is_cuda):
self.X = X
self.y = y

def __len__(self):
return len(self.y)

def __getitem__(self, idx):
return {'X':self.X[idx], 'y':self.y[idx]}

DataLoader

然后就可以在用DataLoader来包装我们定义的TextClassificationDataset来自动地包装我们的数据,返回的数据类型是torch.Tensor

使用torch.nn.DataParallel

这个模块可以把输入切成若干份,放到各个GPU上去跑,然后再将最后的结果merge起来。这个方法只需要在实例化好的模型外再套一层torch.nn.DataParallel即可。如果要调用原先模型中自定义的一些函数,除了forward之外都需要改成调用model.module了。

bidirectional lstm

1
2
lstm = nn.LSTM(size1, size2, 1, bidirectional=True)
output, (h,c) = lstm(input)

分为两个lstm,然后将原有的input转换方向,变成reversed_input,把inputreversed_input分别输入到LSTM中,然后得到两个输出,一个是outputfw和outputbw。将outputbw反转,拼接上outputfw,就是output的结果,而将不反转的outputbw和ouptutfw的最后一步的状态拼接起来就是h,c同理。

分享到