pytorch基础(八):Dataloader的简单使用

前言

  本系列主要是对pytorch基础知识学习的一个记录,尽量保持博客的更新进度和自己的学习进度。本人也处于学习阶段,博客中涉及到的知识可能存在某些问题,希望大家批评指正。另外,本博客中的有些内容基于吴恩达老师深度学习课程,我会尽量说明一下,但不敢保证全面。

一、构造数据类Dataset

  要想使用Dataloader,我们需要构造一个适用于待解决问题的一个数据类,该数据类必须继承Dataset,下面是一个简单的例子:

from torch.utils.data import DataLoader, Dataset
class MnistData(Dataset):
    def __init__(self, data_path, label_path):
        super(MnistData, self).__init__()
        self.data, self.label = load_mnist(data_path, label_path)
        self.data = self.data[0:1000, :]
        self.label = self.label[0:1000]
        self.len = self.label.shape[0]

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return self.len

  这里我为手写数字图片构造了一个数据类,因为数据集比较简单(数据+标签),因此数据类中的成员变量也不是很多。我个人觉得还是要具体问题具体分析,当需要处理文本数据时,构造的类就会复杂许多。==loadmnist== 是读取文件的一个函数。
  当你构建数据类时,你必须继承 ==Dataset== 类,并复写
getitem 函数和 len _ 函数

二、使用Dataloader

  pytorch给出的Dataloader解释如下:

在这里插入图片描述

  Dataloder中参数还是非常多的,我暂时用到过的并不多,主要是以下三个参数:

1.dataset:Dataset类型,传入提前构造好的数据类。
2.batch_size:int类型,批处理的大小,不用自己划分数据集。
3.shuffle:bool类型,当设置为True时,每个epoch会随机打乱数据集。

使用Dataloader:

mnist_data = MnistData(train_data_path, train_label_path)
train_loader = DataLoader(dataset=mnist_data, batch_size=32, shuffle=True)

遍历Dataloader:

for epoch in range(epoch_num):
    epoch_cost = 0
    for i, data in enumerate(train_loader):
        img_data, labels = data

总结

  本人对Dataloader的了解其实并不很透彻,只会一些基本使用,在今后的情形中若碰见比较复杂的情形,我会完善这一篇博客。

暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇