首页 > 教程 > 如何使用 Pytorch 中的 DataSet 和 DataLoader

如何使用 Pytorch 中的 DataSet 和 DataLoader

时间:2024-10-30 | 来源: | 阅读:107

话题: a T S C 深度学习

如何使用数据集DataSet? 在介绍DataLoader之前,需要先了解数据集DataSet的使用。Pytorch中集成了很多已经处理好的数据集,在pytorch的torchvision、torchtext等模块有一些典型的数据集,可以通过配置来下载使用。 以CIFAR10 数据集为例,文档已经描

Pytorch中集成了大量已处理好的数据集,在torchvision、torchtext等模块中都有一些典型的数据集,用户可以通过配置来下载并使用这些数据集。例如,CIFAR10 数据集已经被描述得非常清晰了。其中要注意的是 transform 这个参数,可以用来将图像转换为所需要的格式,比如将PIL格式的图像转化为tensor格式的图像。

在介绍 DataLoader 之前,需要先了解如何使用 DataSet。Pytorch 中的DataSet是一个存储所有数据(例如图像、音频)的容器。DataLoader 就是另一个具有更好收纳功能的容器,其中分隔开来很多小隔间,可以自己设定一个小隔间有多少个数据集的数据来组成,每次将数据放进收纳小隔间的时候要不要把源数据集打乱再进行收纳等等。

给定了一个数据集,用户可以决定如何从数据集里面拿取数据来进行训练,比如一次拿取多少数据作为一个对象来对数据集进行分割,对数据集进行分割之前要不要打乱数据集等等。DataLoader的结果就是一个对数据集进行分割的大字典列表,列表中的每个对象都是由设置的多少个数据集的对象组合而成的。

如何使用 DataLoader?

__getitem__方法

首先需要先理解 __getitem__ 方法。__getitem__被称为魔法方法,在Python中定义一个类的时候,如果想要通过键来得到类的输出值,就需要 __getitem__ 方法。__getitem__ 方法的作用就是在调用类的时候自动的运行 __getitem__ 方法的内容,得结果并返回。

    
        class Fib():
            def __init__(self,start=0,step=1):
                self.step=step
            def __getitem__(self, key):
                a = key+self.step
                return a
        s=Fib()
        s[1]
    

例如,在Pytorch中的CIFAR10数据集中,可以看到源码中的 __getitem__ 方法是这样的。

    
        def __getitem__(self, index: int) -> Tuple[Any, Any]:
            img, target = self.data[index], self.targets[index]
            img = Image.fromarray(img)
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)
            return img, target
    

DataLoader 语法

用户可以在Pytorch的Documents文档中查看DataLoader的使用方法。其中介绍几个比较常用的参数,例如 dataset,batch_size,shuffle,num_workers 和 drop_last。其中,batch_size表示在数据集容器中一次拿取多少数据,shuffle表示是否在每次操作的时候打乱数据集,一般选择为True。num_workers表示多线程进行拿取数据操作,0表示只在主线程中操作。drop_last表示如果拿取数据有余数,是否保留最后剩下的部分。

  • dataset:就是用户的数据集,构建好数据集对象后传入即可。

  • shuffle:是否在每次操作的时候打乱数据集,一般选择为True。

  • num_workers: 多线程进行拿取数据操作,0表示只在主线程中操作。

  • drop_last:如果拿取数据有余数,是否保留最后剩下的部分。

例如,在后面的代码中,如果设置 drop_last=False,那么一共有156次数据拿取,并且最后一次剩余的部分不会被丢弃。如果设置 drop_last=True,那么最后剩余的部分被丢弃,并且拿取次数也少了一次。

使用 DataLoader

初步使用的代码如下:

    
        import torchvision
        from torch.utils.data import DataLoader
        from torch.utils.tensorboard import SummaryWriter
        test_data=torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor())
        test_dataloader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=True)
        writer=SummaryWriter("logs")
        step=0
        for data in test_dataloader:
            images,targets=data
            writer.add_images("test_03",images,step)
            step=step+1
        writer.close()
    

然后配合使用tensorboard就可以直观体会到它的使用方法了。


湘ICP备2022002427号-10湘公网安备:43070202000427号
© 2013~2019 haote.com 好特网