import torch
from torch.utils.data import SubsetRandomSampler, DataLoader

# Build a toy dataset of 100 random scalars, then restrict sampling to the
# first 20 indices only.  SubsetRandomSampler draws those indices in a
# random order each time the loader is iterated.
dataset = torch.utils.data.TensorDataset(torch.randn(100))
subset_indices = list(range(20))
sampler = SubsetRandomSampler(subset_indices)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=1)

# Print the first 10 sampled batches, then stop.
for batch_idx, batch in enumerate(dataloader):
    if batch_idx == 10:
        break
    print(batch)
tensor([ 1.6447]) tensor([-0.4465]) tensor([-1.5022]) tensor([-0.2073]) tensor([-0.0603]) tensor([-0.5113]) tensor([-0.1399]) tensor([ 0.7900]) tensor([ 0.1792]) tensor([ 1.2661])
import torch from torch.utils.data import SubsetRandomSampler, DataLoader from torchvision import datasets, transforms # Load MNIST dataset trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor()) testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor()) # Create a subset of the training set train_indices = list(range(0, 5000)) train_sampler = SubsetRandomSampler(train_indices) # Create dataloaders for training and testing trainloader = DataLoader(trainset, sampler=train_sampler, batch_size=32) testloader = DataLoader(testset, batch_size=32) # Define a simple neural network class Net(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(784, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = x.view(-1, 784) x = F.relu(self.fc1(x)) x = self.fc2(x) return x net = Net() # Train the neural network using the subset of the dataset criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) for epoch in range(10): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 100 == 99: # print every 100 mini-batches print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100)) running_loss = 0.0 print('Finished Training')Overall, SubsetRandomSampler is a useful class in the PyTorch torch.utils.data package for creating subsets of datasets and randomly sampling elements from those subsets.