8.2. Creating and Loading Datasets

Custom datasets are built and loaded with the Dataset and DataLoader classes provided in PyTorch's torch.utils.data module.

8.2.1. Dataset

torch.utils.data provides the following dataset classes:

  • Dataset(): takes no arguments; the abstract base class that custom datasets subclass

  • TensorDataset(*tensors): wraps several ``Tensor``s into one dataset; the input is any number of tensors that share the same first (sample) dimension

  • ConcatDataset(datasets): concatenates several datasets into one; the input is a sequence of datasets (see the sketch after this list)

  • Subset(dataset, indices): selects a subset of a dataset; the input is a dataset and a sequence of indices

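ConcatDataset and Subset are not used in the examples below; the following minimal sketch shows how both behave (the two small TensorDataset objects are made up purely for illustration):

import torch as th
from torch.utils.data import TensorDataset, ConcatDataset, Subset

# two small datasets with 4 and 6 samples respectively
dsa = TensorDataset(th.randn((4, 2)), th.randn((4,)))
dsb = TensorDataset(th.randn((6, 2)), th.randn((6,)))

merged = ConcatDataset([dsa, dsb])       # 4 + 6 = 10 samples
sub = Subset(merged, indices=[0, 2, 4])  # keep samples 0, 2 and 4
print(len(merged), len(sub))             # 10 3
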
Using Dataset

A subclass of Dataset must override the __init__, __len__, and __getitem__ methods. As an example, suppose the network has two inputs \(x_1, x_2\) and one output \(y\); a MyDataset class is then implemented as follows.

Code 8.5 demo_DatasetDataLoader.py
import torch as th
from torch.utils.data import Dataset, DataLoader


# === data: two inputs and one target, 10 samples each

x1 = th.randn((10, 3, 128, 128))
x2 = th.randn((10, 2, 128, 128))
y = th.randn((10, 128, 128))

epochs = 4

# === Dataset

print("---Dataset")

class MyDataset(Dataset):

    def __init__(self, x1, x2, y):
        self.x1 = x1
        self.x2 = x2
        self.y = y
        self.len = y.shape[0]

    def __getitem__(self, index):
        return self.x1[index], self.x2[index], self.y[index]

    def __len__(self):
        return self.len

mydataset = MyDataset(x1, x2, y)

dataloader = DataLoader(dataset=mydataset,
                        batch_size=3, shuffle=True, num_workers=2)

for epoch in range(epochs):
    print(epoch)
    for i, data in enumerate(dataloader):
        x1v, x2v, yv = data
        print(x1v.size(), x2v.size(), yv.size())

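With 10 samples and batch_size=3, each epoch yields three full batches and a final batch of one sample, so the loop prints torch.Size([3, 3, 128, 128]) torch.Size([3, 2, 128, 128]) torch.Size([3, 128, 128]) three times, then the same shapes with a leading 1. Note that num_workers=2 starts worker processes; on platforms that use the spawn start method (e.g. Windows), the DataLoader loop must sit under an if __name__ == '__main__': guard.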

Using TensorDataset

The TensorDataset class is very simple to use and spares you from overriding any methods; every tensor passed to it must have the same size along the first (sample) dimension. For the same example, a network with two inputs \(x_1, x_2\) and one output \(y\), the implementation is as follows.

Code 8.6 demo_TensorDatasetDataLoader.py
import torch as th
from torch.utils.data import DataLoader, TensorDataset


# === data: two inputs and one target, 10 samples each

x1 = th.randn((10, 3, 128, 128))
x2 = th.randn((10, 2, 128, 128))
y = th.randn((10, 128, 128))

epochs = 4

# === TensorDataset

print("---TensorDataset")
mydataset = TensorDataset(x1, x2, y)

dataloader = DataLoader(dataset=mydataset,
                        batch_size=3, shuffle=True, num_workers=2)

for epoch in range(epochs):
    print(epoch)
    for i, data in enumerate(dataloader):
        x1v, x2v, yv = data
        print(x1v.size(), x2v.size(), yv.size())
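
TensorDataset simply indexes every wrapped tensor along its first dimension, so mydataset[0] returns the tuple (x1[0], x2[0], y[0]) and the loop prints the same batch shapes as in Code 8.5.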

8.2.2. DataLoader