def _crop_chw(img, top, left, size):
    """Return the ``size`` x ``size`` window of a channel-first array at (top, left)."""
    return img[:, top:top + size, left:left + size]


def make_celeba_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, num_workers=4, pin_memory=False):
    """Build a DataLoader that crops, resizes and normalizes CelebA images.

    Each sample goes through: ``ToTensor`` -> fixed 108 x 108 crop ->
    ``ToPILImage`` -> ``Resize`` to ``(resize, resize)`` -> ``ToTensor`` ->
    ``Normalize`` with mean 0.5 and std 0.5 on each of three channels.

    Args:
        img_paths: Image paths, forwarded to ``torchlib.DiskImageDataset``.
        batch_size: Number of samples per batch.
        resize: Side length of the square output images.
        drop_remainder: Forwarded to ``DataLoader`` as ``drop_last``.
        shuffle: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.
        pin_memory: Forwarded to ``DataLoader``.

    Returns:
        A ``(data_loader, img_shape)`` tuple where ``img_shape`` is
        ``(resize, resize, 3)``.
    """
    # Function-scope import so this block is self-contained.
    import functools

    crop_size = 108
    # The offsets centre the crop window for inputs that are 218 pixels high
    # and 178 pixels wide; other sizes get the same fixed offsets.
    offset_height = (218 - crop_size) // 2
    offset_width = (178 - crop_size) // 2

    # A partial over a module-level function (rather than a local lambda) can
    # be pickled, which DataLoader workers need when processes are spawned.
    crop = functools.partial(_crop_chw, top=offset_height, left=offset_width, size=crop_size)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(crop),
        transforms.ToPILImage(),
        transforms.Resize(size=(resize, resize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    dataset = torchlib.DiskImageDataset(img_paths, map_fn=transform)
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_remainder,
        pin_memory=pin_memory,
    )

    # NOTE(review): listed as (height, width, channels); the tensors produced
    # by ToTensor are presumably channel-first — confirm against callers.
    img_shape = (resize, resize, 3)

    return data_loader, img_shape
def make_custom_datset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, num_workers=4, pin_memory=False, custom_transforms=None):
    """Build a DataLoader for a custom image set with optional extra preprocessing.

    Each sample goes through the callables in ``custom_transforms`` (if any),
    then ``Resize`` to ``(resize, resize)`` -> ``ToTensor`` -> ``Normalize``
    with mean 0.5 and std 0.5 on each of three channels.

    Args:
        img_paths: Image paths, forwarded to ``torchlib.DiskImageDataset``.
        batch_size: Number of samples per batch.
        resize: Side length of the square output images.
        drop_remainder: Forwarded to ``DataLoader`` as ``drop_last``.
        shuffle: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.
        pin_memory: Forwarded to ``DataLoader``.
        custom_transforms: Optional iterable of callables applied, in order,
            before the resize step. ``None`` (the default) adds nothing.

    Returns:
        A ``(data_loader, img_shape)`` tuple where ``img_shape`` is
        ``(resize, resize, 3)``.
    """
    # ======================================
    # =               custom               =
    # ======================================
    # Custom preprocessing goes here. This slot used to hold a bare `...`
    # placeholder, which is not callable and made the pipeline raise on the
    # first sample; an empty list keeps the function usable as-is.
    custom_steps = [] if custom_transforms is None else list(custom_transforms)
    # ======================================
    # =               custom               =
    # ======================================

    transform = transforms.Compose(custom_steps + [
        transforms.Resize(size=(resize, resize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    dataset = torchlib.DiskImageDataset(img_paths, map_fn=transform)
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_remainder,
        pin_memory=pin_memory,
    )

    img_shape = (resize, resize, 3)

    return data_loader, img_shape