def load_data(composed_transforms: transforms.Compose, split: str) -> DataLoader:
    """Loads the data given the config.dataset.

    :param composed_transforms: the augmentations to apply to the dataset
    :param split: either 'train' or 'test' (case-insensitive); the dataset
        returned also depends on config.fully_train - 'train' will return
        42500 images for evolution, but 50000 for fully training. 'test'
        will return 7500 images for evolution, but 10000 for fully training.
    :return: a shuffled DataLoader over the selected subset
    :raises ValueError: if split is not train/test, or config.dataset is unknown

    Note: the validation/train split does not try to balance the classes of
    data, it just takes the first n for the train set and the remaining data
    goes to the validation set.
    """
    # Normalize once so every later comparison agrees with the validation here.
    # (Previously only the validation lower-cased, so e.g. 'Train' slipped
    # through validation but then matched neither branch below.)
    split = split.lower()
    if split not in ('train', 'test'):
        raise ValueError('Parameter split can be one of train or test, but received: ' + str(split))

    # Load the underlying torchvision train set when asked for 'train', or when
    # asked for 'test' during evolution (the evo 'test' set is carved out of
    # the train set; the real test set is only used when fully training).
    train: bool = split == 'train' or (split == 'test' and not config.fully_train)

    dataset_args = {
        'root': DataManager.get_datasets_folder(),
        'train': train,
        'download': config.download_dataset,
        'transform': composed_transforms
    }

    if config.dataset == 'mnist':
        dataset = MNIST(**dataset_args)
    elif config.dataset == 'cifar10':
        dataset = CIFAR10(**dataset_args)
    elif config.dataset == 'cifar100':
        dataset = CIFAR100(**dataset_args)
    elif config.dataset == 'fashionMnist':
        dataset = FashionMNIST(**dataset_args)
    elif config.dataset == 'custom':
        dataset = get_generic_dataset(composed_transforms, train)
    else:
        raise ValueError('config.dataset can be one of mnist, cifar10, cifar100, fashionMnist or custom, '
                         'but received: ' + str(config.dataset))

    if train and not config.fully_train:
        # Splitting the train set into a train and valid set: the first n
        # samples become the evolution train set, the remainder the evo 'test'
        # set. No class balancing is attempted (see docstring).
        train_size = int(len(dataset) * (1 - config.validation_split))
        if split == 'train':
            dataset = Subset(dataset, range(train_size))
        else:  # split == 'test' — the only other value after validation above
            dataset = Subset(dataset, range(train_size, len(dataset)))

    # print(split, 'set size in', 'FT' if config.fully_train else 'evo', len(dataset))
    # TODO: test num workers and pin memory
    return DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=0, pin_memory=False)
def load_data(composed_transforms: transforms.Compose, split: str) -> DataLoader:
    """Loads the data given the config.dataset.

    :param composed_transforms: the augmentations to apply to the dataset
    :param split: one of 'train', 'validation' or 'test'
    :return: a shuffled DataLoader over the requested split
    :raises ValueError: if split or config.dataset is not a recognized value

    Note: the train/validation subsets are drawn by an unseeded random_split,
    so they differ between calls — TODO confirm this is intended.
    """
    if split not in ('train', 'test', 'validation'):
        raise ValueError(
            'Parameter split can be one of train, test or validation, but received: ' + str(split))

    # Both 'train' and 'validation' are carved out of the torchvision train set;
    # only 'test' uses the held-out test set.
    train: bool = split in ('train', 'validation')

    dataset_args = {
        'root': DataManager.get_datasets_folder(),
        'train': train,
        'download': True,
        'transform': composed_transforms
    }

    if config.dataset == 'mnist':
        dataset = MNIST(**dataset_args)
    elif config.dataset == 'cifar10':
        dataset = CIFAR10(**dataset_args)
    elif config.dataset == 'custom':
        dataset = get_generic_dataset(composed_transforms, train)
    else:
        raise ValueError(
            'config.dataset can be one of mnist, cifar10 or custom, but received: ' + str(config.dataset))

    if train:
        # Splitting the train set into a train and valid set. Use fresh names
        # for the subsets instead of rebinding the boolean `train` (the
        # original reused `train` for a Subset, changing its type mid-function).
        train_size = int(len(dataset) * (1 - config.validation_split))
        validation_size = len(dataset) - train_size
        train_subset, valid_subset = random_split(dataset, [train_size, validation_size])
        dataset = train_subset if split == 'train' else valid_subset

    # TODO: test num workers and pin memory
    return DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=0, pin_memory=False)
def load_data(batch_size=Config.batch_size, dataset=""):
    """Loads a dataset using the torch DataLoader and the settings in Config.

    :param batch_size: batch size for both loaders (defaults to Config.batch_size)
    :param dataset: dataset name; when empty, falls back to Config.dataset.
        Options: 'mnist', 'fashion_mnist', 'cifar10'.
    :return: (train_loader, test_loader) tuple of shuffled DataLoaders
    :raises Exception: if the dataset name is not recognized
    """
    data_loader_args = {
        'num_workers': Config.num_workers,
        # BUG FIX: this was `False if Config.device != 'cpu' else False`, which
        # is always False regardless of the condition. Pinning host memory only
        # helps when batches are copied to a non-CPU device, hence:
        'pin_memory': Config.device != 'cpu'
    }

    data_path = DataManager.get_datasets_folder()

    colour_image_transform = transforms.Compose([
        # image transform goes here
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    black_and_white_image_transform = transforms.Compose([
        # image transform goes here
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    download = False

    if dataset == "":
        dataset = Config.dataset.lower()

    def _make_loaders(dataset_cls, transform):
        """Build the (train, test) DataLoader pair for one torchvision dataset class."""
        train_loader = DataLoader(
            dataset_cls(data_path, train=True, download=download, transform=transform),
            batch_size=batch_size, shuffle=True, **data_loader_args)
        test_loader = DataLoader(
            dataset_cls(data_path, train=False, download=download, transform=transform),
            batch_size=batch_size, shuffle=True, **data_loader_args)
        return train_loader, test_loader

    # Dispatch table replaces three near-identical if/elif branches.
    dataset_specs = {
        'mnist': (datasets.MNIST, black_and_white_image_transform),
        'fashion_mnist': (datasets.FashionMNIST, black_and_white_image_transform),
        'cifar10': (datasets.CIFAR10, colour_image_transform),
    }

    if dataset not in dataset_specs:
        raise Exception(
            'Invalid dataset name, options are fashion_mnist, mnist or cifar10 you provided:' + dataset)

    dataset_cls, transform = dataset_specs[dataset]
    train_loader, test_loader = _make_loaders(dataset_cls, transform)

    return train_loader, test_loader