Skip to content

Instantly share code, notes, and snippets.

@morawi
Created February 18, 2021 13:49
Show Gist options
  • Save morawi/42d9a99e135e832f9fd54fb678151efb to your computer and use it in GitHub Desktop.
Save morawi/42d9a99e135e832f9fd54fb678151efb to your computer and use it in GitHub Desktop.
Using/Testing PyTorch SubSetRandSampler
import torch
from torch.utils.data import SubsetRandomSampler as SubSetRandSampler
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
# Number of validation images to draw at random from the dataset.
# A sampler is used instead of shuffle=True because only a subset of the
# dataset is wanted, not a reordering of all of it. (The alternative the
# author considered was an `if i > my_sampler_size: break` check inside the
# validation loop.)
my_sampler_size = 5000

val_transforms = transforms.Compose([
    # transforms.Resize(image_size, interpolation=Image.BICUBIC),
    # transforms.CenterCrop(crop_size),
    transforms.ToTensor(),
    # normalize,
])
val_dataset = datasets.CIFAR10(
    '../data',
    train=False,
    download=True,
    transform=val_transforms,
)

# Pick *distinct* indices. torch.randint draws with replacement, so it would
# hand the sampler duplicate indices and the subset would cover fewer unique
# images than requested. A truncated random permutation has no repeats.
# The size is clamped so a request larger than the dataset cannot fail, and
# .tolist() gives the sampler plain Python ints rather than 0-d tensors.
subset_size = min(my_sampler_size, len(val_dataset))
subset_indices = torch.randperm(len(val_dataset))[:subset_size].tolist()

validate_significance_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=100,
    num_workers=0,
    sampler=SubSetRandSampler(subset_indices),  # e.g. SubSetRandSampler(range(1000)) for a fixed subset
    shuffle=False,  # shuffle=True cannot be combined with an explicit sampler
    pin_memory=True,
)
# Let's check it out: count every sample the loader yields (the total should
# match the size of the sampled subset) and preview one image from each of
# the first two batches.
total_samples = 0  # not named `sum`, which would shadow the builtin
for i, (images, target) in enumerate(validate_significance_loader):
    total_samples += len(target)
    if i < 2:
        # Show the i-th image of the i-th batch; permute moves the channel
        # axis last, which is the layout plt.imshow expects.
        plt.imshow(images[i].permute(1, 2, 0))
        plt.show()
        print(target[i])
    # No `break` here: leaving after the first batch would make the printed
    # total equal to a single batch instead of the whole sampled subset.
print(total_samples)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment