8.2. Dataset Creation and Loading
Custom datasets can be created and loaded with the Dataset and DataLoader classes provided by PyTorch.
8.2.1. Dataset
The following dataset classes are available:

Dataset()
: Takes no input arguments; custom datasets subclass it.

TensorDataset(*tensors)
: Builds a dataset from several Tensors; the inputs are multiple Tensors.

ConcatDataset(datasets)
: Concatenates multiple datasets into one; the input is a list of datasets (see the sketch after this list).

Subset(dataset, indices)
: Takes a subset of a dataset; the inputs are a dataset and a list of indices (see the sketch after this list).
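ConcatDataset and Subset are not used in the examples below. The following minimal sketch (the tensor shapes and index list are made up purely for illustration) shows how they combine with TensorDataset:

import torch as th
from torch.utils.data import TensorDataset, ConcatDataset, Subset

# two small TensorDatasets built from random data (shapes are illustrative)
a = TensorDataset(th.randn(10, 3), th.randn(10))
b = TensorDataset(th.randn(6, 3), th.randn(6))

# ConcatDataset joins the two datasets end to end
ab = ConcatDataset([a, b])
print(len(ab))           # 16

# Subset keeps only the samples at the given indices
sub = Subset(ab, indices=[0, 2, 4, 6])
print(len(sub))          # 4
print(sub[0][0].size())  # torch.Size([3])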
Using Dataset
To use Dataset, subclass it and override its __init__, __getitem__, and __len__ methods. As an example, assume a network with two inputs \(x_1, x_2\) and one output \(y\); the MyDataset class is implemented as follows:
import torch as th
from torch.utils.data import Dataset, DataLoader

# === data: two inputs and one target, 10 samples each
x1 = th.randn((10, 3, 128, 128))
x2 = th.randn((10, 2, 128, 128))
y = th.randn((10, 128, 128))
epochs = 4

# === Dataset
print("---Dataset")


class MyDataset(Dataset):
    def __init__(self, x1, x2, y):
        self.x1 = x1
        self.x2 = x2
        self.y = y
        self.len = y.shape[0]

    def __getitem__(self, index):
        # one sample: the index-th slice of each tensor
        return self.x1[index], self.x2[index], self.y[index]

    def __len__(self):
        return self.len


mydataset = MyDataset(x1, x2, y)
dataloader = DataLoader(dataset=mydataset,
                        batch_size=3, shuffle=True, num_workers=2)

for epoch in range(epochs):
    print(epoch)
    for i, data in enumerate(dataloader):
        x1v, x2v, yv = data
        print(x1v.size(), x2v.size(), yv.size())
Using TensorDataset
The TensorDataset class is very simple to use and removes the need to override class methods. Assuming again a network with two inputs \(x_1, x_2\) and one output \(y\), the dataset is built directly from the tensors:
import torch as th
from torch.utils.data import DataLoader, TensorDataset

# === data: two inputs and one target, 10 samples each
x1 = th.randn((10, 3, 128, 128))
x2 = th.randn((10, 2, 128, 128))
y = th.randn((10, 128, 128))
epochs = 4

# === TensorDataset: wrap the tensors directly, no subclassing needed
print("---TensorDataset")
mydataset = TensorDataset(x1, x2, y)
dataloader = DataLoader(dataset=mydataset,
                        batch_size=3, shuffle=True, num_workers=2)

for epoch in range(epochs):
    print(epoch)
    for i, data in enumerate(dataloader):
        x1v, x2v, yv = data
        print(x1v.size(), x2v.size(), yv.size())