Intro to PyTorch DataLoader

Categories: PyTorch, deep learning, tutorial
Author: Kai Tan
Published: April 22, 2024

What Is a DataLoader?

In PyTorch, the DataLoader (torch.utils.data.DataLoader) wraps a dataset and lets you iterate over it in mini-batches. It takes care of batching, optional shuffling, and multi-process data loading, which you would otherwise have to write by hand in every training loop.
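To make those jobs concrete, here is a deliberately simplified, single-process sketch of what a DataLoader does in one epoch: choose an index order, group the indices into batches, fetch each sample with `dataset[i]`, and stack the results into tensors. The function name `simple_batches` is hypothetical, and this is not PyTorch's actual implementation, which also handles samplers, worker processes, and arbitrary sample structures.

```python
import random
import torch

def simple_batches(dataset, batch_size, shuffle=False, seed=None):
    """Hypothetical, simplified stand-in for one epoch of a DataLoader."""
    indices = list(range(len(dataset)))
    if shuffle:
        # New permutation of the indices; the seed makes it repeatable
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # Fetch each sample by index, as a map-style dataset allows
        samples = [dataset[i] for i in indices[start:start + batch_size]]
        # "Collate": turn a list of (x, y) pairs into one x tensor and one y tensor
        xs, ys = zip(*samples)
        yield torch.stack(xs), torch.stack(ys)

# Toy dataset: a plain list of (value, parity) tensor pairs
toy = [(torch.tensor(i), torch.tensor(i % 2)) for i in range(25)]
batches = list(simple_batches(toy, batch_size=10, shuffle=True, seed=0))
print([len(x) for x, _ in batches])  # [10, 10, 5]
```

The last batch is smaller because 25 is not a multiple of 10; the real DataLoader behaves the same way unless you set drop_last=True.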


Step 1: Create a Custom Dataset

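Start by subclassing torch.utils.data.Dataset and implementing two methods: __len__, which returns the number of samples, and __getitem__, which returns the sample at a given index. In this example the samples are the integers 0 to 99 and each label is the number's parity (0 for even, 1 for odd).

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100)  # samples: the integers 0..99
        self.labels = self.data % 2    # binary labels: 0 for even, 1 for odd

    def __len__(self):
        # Number of samples; DataLoader uses this to know the valid indices
        return len(self.data)

    def __getitem__(self, idx):
        # Return one (sample, label) pair for the given index
        return self.data[idx], self.labels[idx]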
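Before wrapping a dataset in a DataLoader, it is worth checking that len() and indexing behave as expected, because these two methods are all the DataLoader relies on for a map-style dataset. A quick check, repeating the class so the snippet runs on its own:

```python
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100)
        self.labels = self.data % 2

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

dataset = MyDataset()
print(len(dataset))  # 100
print(dataset[7])    # (tensor(7), tensor(1))
print(dataset[42])   # (tensor(42), tensor(0))
```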

Step 2: Use DataLoader to Load in Batches

from torch.utils.data import DataLoader

dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
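Under the hood, the DataLoader fetches batch_size samples with `dataset[i]` and passes the list to a collate function. The default one, exposed as `torch.utils.data.default_collate` in recent PyTorch releases, stacks matching fields along a new batch dimension, so a list of (x, y) pairs becomes one batched x and one batched y. That is why the loop that follows can unpack (x, y). A small illustration with hand-made samples:

```python
import torch
from torch.utils.data import DataLoader, default_collate

# Three samples shaped like MyDataset's (value, label) pairs
samples = [(torch.tensor(4), torch.tensor(0)),
           (torch.tensor(7), torch.tensor(1)),
           (torch.tensor(9), torch.tensor(1))]

# Collate by hand: each field is stacked into one tensor
x, y = default_collate(samples)
print(x)  # tensor([4, 7, 9])
print(y)  # tensor([0, 1, 1])

# A DataLoader without shuffling produces the same batch
# (a plain list works as a map-style dataset: it has __len__ and __getitem__)
loader = DataLoader(samples, batch_size=3)
bx, by = next(iter(loader))
print(torch.equal(bx, x), torch.equal(by, y))  # True True
```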

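Each iteration yields one batch as a pair of tensors (x, y). To print the first two batches:

for batch_idx, (x, y) in enumerate(dataloader):
    if batch_idx >= 2:
        break  # stop after the first two batches
    print(f"Batch {batch_idx}")
    print("x:", x)
    print("y:", y)
    print("---")

Example output (the values change on every run because shuffle=True):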
Batch 0
x: tensor([58, 82, 22, 98, 81, 68, 62, 24, 69, 20])
y: tensor([0, 0, 0, 0, 1, 0, 0, 0, 1, 0])
---
Batch 1
x: tensor([77, 32, 52, 97, 41, 93, 27, 65, 33, 39])
y: tensor([1, 0, 0, 1, 1, 1, 1, 1, 1, 1])
---
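Shuffling reorders samples, but each epoch is still a permutation of the dataset: with shuffle=True every sample appears exactly once per pass, and a fresh order is drawn each time you start iterating. A small check, using a TensorDataset with the same 100 values as MyDataset; the helper `epoch_order` is illustrative only:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.arange(100)
dataset = TensorDataset(data, data % 2)
loader = DataLoader(dataset, batch_size=10, shuffle=True)

def epoch_order(loader):
    """Concatenate the x values of every batch in one pass over the loader."""
    return torch.cat([x for x, _ in loader]).tolist()

first = epoch_order(loader)
second = epoch_order(loader)

print(sorted(first) == list(range(100)))   # True: each sample exactly once
print(sorted(second) == list(range(100)))  # True
print(first == second)                     # almost surely False: new order each epoch
```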

Step 3: Use a Real-World Dataset (California Housing)

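You can also wrap data from other libraries. Here the California Housing dataset from scikit-learn (8 numeric features, with the median house value as the target) is converted to tensors and wrapped in a TensorDataset.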

import torch
from sklearn.datasets import fetch_california_housing
from torch.utils.data import DataLoader, TensorDataset

# Load the data (downloaded and cached on first call); keep the first 1,000 rows
california = fetch_california_housing()
X = california.data[:1000]    # shape (1000, 8)
y = california.target[:1000]  # shape (1000,)

# Convert to float32 tensors; reshape the targets into a column of shape (1000, 1)
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)

# Wrap in a dataset: real_dataset[i] returns (X_tensor[i], y_tensor[i])
real_dataset = TensorDataset(X_tensor, y_tensor)
real_loader = DataLoader(real_dataset, batch_size=32, shuffle=True)
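The feature columns in this dataset have very different scales (population counts versus average rooms per household, for instance), so you may want to standardize them before training. A minimal sketch that computes per-column z-scores with plain tensor operations, mirroring what scikit-learn's StandardScaler does (subtract the column mean, divide by the population standard deviation). Random data with mixed scales stands in for X_tensor and y_tensor so the snippet runs offline; only the 1,000 x 8 shape matches the real slice.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Stand-in for X_tensor / y_tensor: 1,000 rows, 8 features with mixed scales
scales = torch.tensor([1.0, 10.0, 2.0, 0.5, 1000.0, 3.0, 2.0, 2.0])
X_tensor = torch.randn(1000, 8) * scales + 5.0
y_tensor = torch.randn(1000, 1)

# Per-column mean and population std (divide by N, as StandardScaler does)
mean = X_tensor.mean(dim=0)
std = ((X_tensor - mean) ** 2).mean(dim=0).sqrt()
X_scaled = (X_tensor - mean) / std

scaled_loader = DataLoader(TensorDataset(X_scaled, y_tensor), batch_size=32, shuffle=True)
print(X_scaled.shape)  # torch.Size([1000, 8])
```

If you later split the data into training and validation sets, compute mean and std on the training portion only, so no information from the validation data leaks into preprocessing.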

Preview the first batch:

features, targets = next(iter(real_loader))
print(features.shape)  # [32, 8]
print(targets.shape)   # [32, 1]
torch.Size([32, 8])
torch.Size([32, 1])
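In training, a loader like real_loader is typically consumed once per epoch inside the optimization loop. The sketch below assumes a plain linear model (nn.Linear), mean squared error, and SGD, none of which come from this post, and uses synthetic data with the same 8-feature shape so it runs without downloading anything:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Synthetic stand-in for the housing tensors: 1,000 samples, 8 features
X = torch.randn(1000, 8)
true_w = torch.randn(8, 1)
y = X @ true_w + 0.1 * torch.randn(1000, 1)

loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
loss_fn = nn.MSELoss()

epoch_losses = []
for epoch in range(5):
    total = 0.0
    for features, targets in loader:  # one shuffled pass over the data
        optimizer.zero_grad()
        loss = loss_fn(model(features), targets)
        loss.backward()
        optimizer.step()
        total += loss.item() * len(features)  # weight by batch size
    epoch_losses.append(total / len(loader.dataset))
    print(f"epoch {epoch}: mean loss {epoch_losses[-1]:.4f}")
```

Weighting each batch loss by its size keeps the epoch average correct even when the last batch is smaller than batch_size.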

Step 4: Visualization

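Plot a histogram of one feature from the batch previewed in Step 3. Because features holds a single shuffled batch, the histogram covers only those 32 samples, not the full dataset.

import matplotlib.pyplot as plt

feature_index = 0  # column 0 of the California Housing features is MedInc (median income)
plt.hist(features[:, feature_index].numpy(), bins=20, edgecolor='k')
plt.title("Distribution of Feature: Median Income (one batch of 32)")
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.grid(True)
plt.show()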
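A single batch is only a small sample. To summarize features over the whole dataset while still reading it through the loader (useful when the data is too large to hold as one tensor), you can accumulate sums batch by batch over one epoch. A sketch with synthetic stand-in data; the variable names are illustrative:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
X = torch.rand(1000, 8) * 10  # stand-in for the 1,000 x 8 feature tensor
y = torch.rand(1000, 1)
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

count = 0
total = torch.zeros(8, dtype=torch.float64)
total_sq = torch.zeros(8, dtype=torch.float64)
for features, _ in loader:
    f = features.double()  # accumulate in float64 to limit rounding error
    count += f.shape[0]
    total += f.sum(dim=0)
    total_sq += (f ** 2).sum(dim=0)

mean = total / count
std = (total_sq / count - mean ** 2).sqrt()  # population standard deviation
print(count)  # 1000
```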


Common DataLoader Parameters

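Parameter    Meaning
batch_size   Number of samples per batch (default: 1)
shuffle      If True, reshuffle the data at the start of every epoch (off by default)
num_workers  Number of worker subprocesses used for loading; 0 (the default) loads in the main process
drop_last    If True, drop the last batch when it is smaller than batch_size (default: False)
pin_memory   If True, return batches in page-locked (pinned) memory, which can speed up copies to a CUDA GPU (default: False)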
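The effect of drop_last is easiest to see with a dataset size that is not a multiple of batch_size. With 100 samples and batch_size=32 there are three full batches and one leftover batch of 4, and len(loader) reports the number of batches:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.arange(100)
dataset = TensorDataset(data, data % 2)

keep = DataLoader(dataset, batch_size=32, shuffle=True)
drop = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

print(len(keep), [len(x) for x, _ in keep])  # 4 [32, 32, 32, 4]
print(len(drop), [len(x) for x, _ in drop])  # 3 [32, 32, 32]
```

With drop_last=True the 4 leftover samples are skipped in that epoch; because shuffle=True, which 4 are skipped changes from epoch to epoch.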

Reproducibility Tip

To make the shuffle order repeatable across runs, pass a seeded torch.Generator to the DataLoader:

g = torch.Generator()
g.manual_seed(42)

dataloader = DataLoader(dataset, batch_size=10, shuffle=True, generator=g)
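To confirm the seed works, build two loaders with identically seeded generators and compare the order they produce. A minimal check, using a TensorDataset stand-in for dataset; the helper `first_epoch` is illustrative only:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.arange(100)
dataset = TensorDataset(data, data % 2)

def first_epoch(seed):
    """Return the sample order of one epoch from a loader seeded with `seed`."""
    g = torch.Generator()
    g.manual_seed(seed)
    loader = DataLoader(dataset, batch_size=10, shuffle=True, generator=g)
    return torch.cat([x for x, _ in loader]).tolist()

print(first_epoch(42) == first_epoch(42))  # True: same seed, same order
print(first_epoch(42) == first_epoch(7))   # almost surely False
```

A single generator keeps advancing as you iterate, so the second epoch of a seeded loader differs from the first but is itself repeatable from run to run.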

Summary

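PyTorch’s DataLoader separates where the data comes from (a Dataset) from how it is fed to a model. Whether you write your own Dataset, as in Step 1, or wrap existing tensors with TensorDataset, as in Step 3, DataLoader gives you:

  • Mini-batch loading, with samples collated into batched tensors
  • Per-epoch shuffling, repeatable with a seeded torch.Generator
  • Parallel loading with num_workers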