How to use PyTorch DataLoader
A tutorial on PyTorch DataLoader, Dataset, SequentialSampler, and RandomSampler
- What PyTorch DataLoader and Dataset are
- The data loading order and Sampler
- Getting the random indices from RandomSampler
If you use PyTorch as your deep learning framework, you'll likely need a DataLoader in your model training loop.
In this tutorial, you'll learn about
- How to construct a custom Dataset class
- How to use DataLoader to split a dataset into batches
- How to randomize a dataset in DataLoader
- How to return the randomized index in DataLoader
A custom Dataset class must implement three methods:
- __init__: instantiates the Dataset object
- __len__: returns the number of samples in the dataset
- __getitem__: loads and returns the sample at the given index idx
Here's an example custom dataset that takes in a pandas DataFrame with columns "text" and "label".
```python
from torch.utils.data import Dataset

class NewsDataset(Dataset):
    def __init__(self, df):
        self.df = df
        self.texts = df['text']
        self.labels = df['label']

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        text = self.texts.iloc[idx]
        label = self.labels.iloc[idx]
        return text, label
```
We'll use the first 25 rows of the AG News dataset from the Hugging Face datasets library as our example data.

```python
from datasets import load_dataset

agnews = load_dataset('ag_news')
agnews.set_format(type="pandas")
news_df = agnews['train'][:25]
news_df.head()
```
Instantiate a NewsDataset:

```python
news_dataset = NewsDataset(news_df)
```
Check that news_dataset has a length:

```python
len(news_dataset)
```
This means that news_dataset contains 25 news samples.
Now check that news_dataset is indexable:

```python
news_dataset[7]
```
We can see that index 7 of news_dataset holds a news item with label 2.
Next, wrap the dataset in a DataLoader:

```python
from torch.utils.data import DataLoader

news_dataloader = DataLoader(news_dataset, batch_size=4)
```
To load the data in batches, we can iterate over the DataLoader:

```python
for batch_index, (batch, label) in enumerate(news_dataloader):
    print(f'batch index: {batch_index},\n label: {label},\n batch: {batch}')
```
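As a sanity check on the loop above: with 25 samples and batch_size=4, the DataLoader yields 7 batches, the last holding a single sample (since drop_last defaults to False). The arithmetic can be sketched in plain Python, no torch needed:

```python
import math

num_samples = 25   # len(news_dataset)
batch_size = 4

# with drop_last=False (the DataLoader default), the short final batch is kept
num_batches = math.ceil(num_samples / batch_size)
last_batch_size = num_samples - (num_batches - 1) * batch_size

print(num_batches)      # 7
print(last_batch_size)  # 1
```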
The data loading order and Sampler
A DataLoader leverages a Sampler to generate the dataset indices for each batch. We can use either the shuffle parameter or the sampler parameter to control the order of the indices.
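Conceptually, the DataLoader draws indices from its sampler and groups them into batches before fetching samples. Here is a rough pure-Python sketch of that inner loop (a simplification; the real implementation also handles collation, workers, and more):

```python
def batched_indices(sampler, batch_size):
    """Group an index stream into lists of at most batch_size indices."""
    batch = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # final short batch (drop_last=False behavior)
        yield batch

# range(10) stands in for a sequential sampler over a 10-sample dataset
batches = list(batched_indices(range(10), 4))
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

Note that shuffle and an explicit sampler are mutually exclusive in PyTorch: passing both shuffle=True and sampler to a DataLoader raises a ValueError.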
SequentialSampler
By default, shuffle is False, and the DataLoader uses a SequentialSampler that returns indices sequentially.
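Behaviorally, a SequentialSampler amounts to iterating over range(len(dataset)). A minimal pure-Python stand-in (MySequentialSampler is an illustrative name, not part of torch):

```python
class MySequentialSampler:
    """Minimal stand-in for torch.utils.data.SequentialSampler."""
    def __init__(self, data_source):
        self.data_source = data_source

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        # indices come out in order: 0, 1, 2, ...
        return iter(range(len(self.data_source)))

sampler = MySequentialSampler(['a', 'b', 'c', 'd'])
print(list(sampler))  # [0, 1, 2, 3]
```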
```python
news_dataloader = DataLoader(news_dataset, batch_size=4)
type(news_dataloader.sampler)
```
We can also explicitly create a SequentialSampler and pass it to a DataLoader via the sampler parameter.
```python
from torch.utils.data import SequentialSampler

# initiate a SequentialSampler
sequential_news_sampler = SequentialSampler(news_dataset)

# let's print out the indices in the SequentialSampler
for i in sequential_news_sampler:
    print(i)

news_dataloader = DataLoader(news_dataset, batch_size=4, sampler=sequential_news_sampler)
```
```python
for batch_index, (batch, label) in enumerate(news_dataloader):
    print(f'batch index: {batch_index},\n label: {label},\n batch: {batch}')
```
To randomize the loading order, set shuffle to True:

```python
news_dataloader = DataLoader(news_dataset, batch_size=4, shuffle=True)

for batch_index, (batch, label) in enumerate(news_dataloader):
    print(f'batch index: {batch_index},\n label: {label},\n batch: {batch}')
```
Note that the DataLoader automatically uses a RandomSampler as its sampler when shuffle is True.

```python
type(news_dataloader.sampler)
```
We can also pass a RandomSampler to a DataLoader explicitly. A RandomSampler returns indices in random order.
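Without replacement (the default), a RandomSampler amounts to a fresh random permutation of range(len(dataset)) each epoch. A pure-Python stand-in (MyRandomSampler is an illustrative name, not part of torch):

```python
import random

class MyRandomSampler:
    """Minimal stand-in for torch.utils.data.RandomSampler (without replacement)."""
    def __init__(self, data_source):
        self.data_source = data_source

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        indices = list(range(len(self.data_source)))
        random.shuffle(indices)  # a fresh permutation on every pass
        return iter(indices)

perm = list(MyRandomSampler(list(range(25))))
# every index appears exactly once, just in random order
print(sorted(perm) == list(range(25)))  # True
```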
```python
from torch.utils.data import RandomSampler

# initiate a RandomSampler
random_news_sampler = RandomSampler(news_dataset)

# print out the indices in the RandomSampler
for i in random_news_sampler:
    print(i)
```
We can then pass random_news_sampler to the DataLoader via the sampler parameter.
```python
random_news_dataloader = DataLoader(news_dataset, batch_size=4, sampler=random_news_sampler)

for batch_index, (batch, label) in enumerate(random_news_dataloader):
    print(f'batch index: {batch_index},\n label: {label},\n batch: {batch}')
```
Sometimes we might want to get the random indices back from the DataLoader. One way to do this is to have the __getitem__ method of the Dataset return a triple (index, data, label). Define a new NewsDataset class:
```python
class NewsDataset(Dataset):
    def __init__(self, df):
        self.df = df
        self.texts = df['text']
        self.labels = df['label']

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        text = self.texts.iloc[idx]
        label = self.labels.iloc[idx]
        return idx, text, label  # <--- make the dataset return the index as well
```
Let's see if the NewsDataset works as desired.
```python
# initiate a NewsDataset
news_dataset = NewsDataset(news_df)

# access the news of index 8
news_dataset[8]
```
You can see in the output that news_dataset now returns the index as well. We can then create a DataLoader with this news_dataset.
```python
# initiate a new RandomSampler
random_news_sampler = RandomSampler(news_dataset)

# initiate a new DataLoader
random_news_dataloader = DataLoader(news_dataset, batch_size=4, sampler=random_news_sampler)
```
When looping through the batches, unpack the triple (index, batch, label) to capture what we get from the dataloader.

```python
for batch_index, (index, batch, label) in enumerate(random_news_dataloader):
    print(f'batch index: {batch_index},\n index: {index},\n label: {label},\n batch: {batch}')
```
You can see that the batches now also carry the indices of their samples.
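To see why returning the index works, here is the whole pattern condensed into pure Python, with a shuffled index list standing in for the RandomSampler (toy_data and getitem are illustrative names, not part of torch):

```python
import random

toy_data = ['news0', 'news1', 'news2', 'news3', 'news4', 'news5']

def getitem(idx):
    """__getitem__-style access that returns (idx, sample)."""
    return idx, toy_data[idx]

# shuffled indices, as a RandomSampler would produce
order = list(range(len(toy_data)))
random.shuffle(order)

# batches of (idx, sample) pairs, batch_size = 4
batch_size = 4
batches = [[getitem(i) for i in order[s:s + batch_size]]
           for s in range(0, len(order), batch_size)]

# the carried indices let us map every sample back to its original position
recovered = {idx: sample for batch in batches for idx, sample in batch}
print(all(toy_data[idx] == sample for idx, sample in recovered.items()))  # True
```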