Intro to PyTorch DataLoader
Tags: PyTorch, deep learning, tutorial
What Is a DataLoader?
In PyTorch, `DataLoader` (from `torch.utils.data`) wraps a dataset and lets you iterate over it efficiently. It handles batching, shuffling, and multi-process data loading, all of which are central to model training.
Step 1: Create a Custom Dataset
Start by subclassing `torch.utils.data.Dataset` and implementing two methods: `__len__`, which returns the number of samples, and `__getitem__`, which returns the sample at a given index. Here's an example that uses the integers 0 to 99 as inputs and derives simulated binary labels from their parity.
```python
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100)
        self.labels = self.data % 2  # binary labels: 0 for even, 1 for odd

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
```
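Because a `DataLoader` only relies on `__len__` and `__getitem__`, you can sanity-check a dataset by calling them directly before batching anything. A minimal sketch (the class is repeated so the snippet runs on its own; index `7` is an arbitrary choice):

```python
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100)
        self.labels = self.data % 2  # binary labels: 0 for even, 1 for odd

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

dataset = MyDataset()
print(len(dataset))  # 100

# Fetch a single sample exactly as a DataLoader would
x, y = dataset[7]
print(x.item(), y.item())  # 7 1
```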
Step 2: Use DataLoader to Load in Batches
```python
from torch.utils.data import DataLoader

dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
```
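To see what batching does, the sketch below compares the first batch with one assembled by hand. It assumes a `TensorDataset` stand-in holding the same samples as `MyDataset`, and uses `shuffle=False` so the order is predictable. With the default collation, each field of the sampled `(x, y)` tuples is stacked into a single tensor.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in with the same samples as MyDataset: inputs 0..99, labels = input % 2
data = torch.arange(100)
dataset = TensorDataset(data, data % 2)

loader = DataLoader(dataset, batch_size=10, shuffle=False)
print(len(loader))  # 10 batches: 100 samples / 10 per batch

x, y = next(iter(loader))

# Manual equivalent of the first batch: fetch 10 samples and stack each field
samples = [dataset[i] for i in range(10)]
x_manual = torch.stack([s[0] for s in samples])
y_manual = torch.stack([s[1] for s in samples])
print(torch.equal(x, x_manual), torch.equal(y, y_manual))  # True True
```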
You can now easily loop through the dataset in batches:
```python
for batch_idx, (x, y) in enumerate(dataloader):
    if batch_idx < 2:  # only print the first two batches
        print(f"Batch {batch_idx}")
        print("x:", x)
        print("y:", y)
        print("---")
```
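Sample output (your values will differ, because `shuffle=True` draws a new random order each epoch):
```text
Batch 0
x: tensor([58, 82, 22, 98, 81, 68, 62, 24, 69, 20])
y: tensor([0, 0, 0, 0, 1, 0, 0, 0, 1, 0])
---
Batch 1
x: tensor([77, 32, 52, 97, 41, 93, 27, 65, 33, 39])
y: tensor([1, 0, 0, 1, 1, 1, 1, 1, 1, 1])
---
```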
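With `shuffle=True`, the order changes from epoch to epoch, but every epoch still visits each sample exactly once. A small sketch that checks this over two passes, again assuming a `TensorDataset` stand-in for `MyDataset` (the printed first five values will vary between runs):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in with the same samples as MyDataset
data = torch.arange(100)
dataloader = DataLoader(TensorDataset(data, data % 2), batch_size=10, shuffle=True)

for epoch in range(2):
    # Concatenate every x batch from one full pass over the loader
    seen = torch.cat([x for x, _ in dataloader])
    # Sorting the visited inputs should recover 0..99 exactly once each
    covers_all = torch.equal(seen.sort().values, torch.arange(100))
    print(f"epoch {epoch}: first five = {seen[:5].tolist()}, covers all samples: {covers_all}")
```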
Step 3: Use a Real-World Dataset (California Housing)
You can also wrap real-world data in a dataset. Here we load the California Housing dataset with scikit-learn and wrap the arrays in a `TensorDataset`.
```python
import torch
from sklearn.datasets import fetch_california_housing
from torch.utils.data import DataLoader, TensorDataset

# Load data (downloaded on first use); keep the first 1,000 rows to stay small
california = fetch_california_housing()
X = california.data[:1000]
y = california.target[:1000]

# Convert to tensors; reshape targets into a column of shape [N, 1]
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)

# Wrap in a dataset and a DataLoader
real_dataset = TensorDataset(X_tensor, y_tensor)
real_loader = DataLoader(real_dataset, batch_size=32, shuffle=True)
```
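`TensorDataset` indexes all wrapped tensors along their first dimension, so item `i` is the pair (`X_tensor[i]`, `y_tensor[i]`). The sketch below assumes random stand-in tensors with the same shapes as the California Housing slice, so it runs without downloading anything:

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical stand-in for the California Housing slice: 1,000 rows, 8 features
X_tensor = torch.randn(1000, 8)
y_tensor = torch.randn(1000, 1)

real_dataset = TensorDataset(X_tensor, y_tensor)
x0, y0 = real_dataset[0]  # row 0 of each wrapped tensor

print(len(real_dataset))  # 1000
print(x0.shape, y0.shape)  # torch.Size([8]) torch.Size([1])
print(torch.equal(x0, X_tensor[0]))  # True
```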
Preview the first batch:
```python
features, targets = next(iter(real_loader))
print(features.shape)  # [32, 8]
print(targets.shape)   # [32, 1]
```
```text
torch.Size([32, 8])
torch.Size([32, 1])
```
Step 4: Visualization
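The histogram below shows one feature (column 0, median income) across the 32 samples in the batch previewed above.
```python
import matplotlib.pyplot as plt

feature_index = 0  # column 0 of the California Housing features is MedInc (median income)
plt.hist(features[:, feature_index].numpy(), bins=20, edgecolor='k')
plt.title("Distribution of Feature: Median Income")
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.grid(True)
plt.show()
```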
Common DataLoader Parameters
Parameter | Meaning |
---|---|
`batch_size` | Number of samples per batch (default: 1) |
`shuffle` | Whether to reshuffle the data at every epoch |
`num_workers` | Number of subprocesses used for data loading; `0` (the default) loads data in the main process |
`drop_last` | Whether to drop the last batch when it is smaller than `batch_size` |
`pin_memory` | Copy batches into page-locked (pinned) host memory, which speeds up transfer to a CUDA GPU |
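The effect of `drop_last` is easiest to see when the dataset size is not a multiple of `batch_size`. A sketch assuming 1,000 random stand-in samples (the same count as the California Housing slice) and `batch_size=32`: since 1000 = 31 * 32 + 8, the final batch holds only 8 samples unless it is dropped.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data: 1,000 samples with 8 features and 1 target each
dataset = TensorDataset(torch.randn(1000, 8), torch.randn(1000, 1))

keep_last = DataLoader(dataset, batch_size=32)                 # drop_last=False by default
dropping = DataLoader(dataset, batch_size=32, drop_last=True)  # discard the short final batch

print(len(keep_last), len(dropping))  # 32 31

last_x, _ = list(keep_last)[-1]
print(last_x.shape)  # torch.Size([8, 8])
```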
Reproducibility Tip
To make shuffling deterministic, pass a seeded `torch.Generator` as the `generator` argument:
```python
g = torch.Generator()
g.manual_seed(42)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, generator=g)
```
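One way to check that the seed really fixes the order is to build two loaders from identically seeded generators and compare their first batches. A minimal sketch; `first_batch` is a hypothetical helper and the dataset is a `TensorDataset` stand-in for `MyDataset`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.arange(100)
dataset = TensorDataset(data, data % 2)  # stand-in for MyDataset

def first_batch(seed):
    # Hypothetical helper: fresh generator, fresh loader, return the first x batch
    g = torch.Generator()
    g.manual_seed(seed)
    loader = DataLoader(dataset, batch_size=10, shuffle=True, generator=g)
    x, _ = next(iter(loader))
    return x

# Same seed -> same shuffle order on every run
print(torch.equal(first_batch(42), first_batch(42)))  # True
```

With `num_workers > 0`, any randomness inside worker processes (for example, random augmentations) needs its own seeding; the PyTorch reproducibility notes show a `worker_init_fn` for that.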
Summary
PyTorch’s `DataLoader` is a flexible and powerful abstraction for handling dataset loading. Whether you’re using built-in datasets or your own, `DataLoader` gives you:
- Mini-batch loading
- Shuffling for randomness
- Parallelism with `num_workers`