def get_train_loader(dataset):
    """Build the training DataLoader for a source or target domain.

    Every branch normalizes with ``params.dataset_mean`` / ``params.dataset_std``.
    The returned loader uses ``params.batch_size`` and shuffles each epoch.

    :param dataset: dataset name; one of ``'MNIST'``, ``'MNIST_M'``, ``'SVHN'``
        or ``'SynDig'``.
    :return: ``DataLoader`` over the training data of ``dataset``.
    :raises ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset == 'MNIST':
        transform = transforms.Compose([
            transforms.ToTensor(),
            # MNIST images are grayscale; replicate the single channel 3x so
            # the tensors have the same channel count as the colour domains.
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std),
        ])
        data = datasets.MNIST(root=params.mnist_path, train=True,
                              transform=transform, download=True)
    elif dataset == 'MNIST_M':
        transform = transforms.Compose([
            transforms.RandomCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std),
        ])
        data = MNISTM(transform=transform)
    elif dataset == 'SVHN':
        transform = transforms.Compose([
            # Random 28x28 crop used as training-time augmentation.
            transforms.RandomCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std),
        ])
        # Train on the union of SVHN's 'train' and 'extra' splits.
        data1 = datasets.SVHN(root=params.svhn_path, split='train',
                              transform=transform, download=True)
        data2 = datasets.SVHN(root=params.svhn_path, split='extra',
                              transform=transform, download=True)
        data = torch.utils.data.ConcatDataset((data1, data2))
    elif dataset == 'SynDig':
        transform = transforms.Compose([
            transforms.RandomCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std),
        ])
        # download=False: the SynDig data is expected to already be on disk.
        data = SynDig.SynDig(root=params.syndig_path, split='train',
                             transform=transform, download=False)
    else:
        # ValueError subclasses Exception, so existing `except Exception`
        # callers keep working.
        raise ValueError('There is no dataset named {}'.format(str(dataset)))

    return DataLoader(dataset=data, batch_size=params.batch_size, shuffle=True)
def get_test_loader(dataset):
    """Build the test DataLoader for a source or target domain.

    Every branch normalizes with ``params.dataset_mean`` / ``params.dataset_std``.
    The colour domains use a deterministic ``CenterCrop(28)`` here, in place of
    the ``RandomCrop`` applied at training time. The loader uses
    ``params.batch_size``.

    :param dataset: dataset name; one of ``'MNIST'``, ``'MNIST_M'``, ``'SVHN'``
        or ``'SynDig'``.
    :return: ``DataLoader`` over the test data of ``dataset``.
    :raises ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset == 'MNIST':
        transform = transforms.Compose([
            transforms.ToTensor(),
            # Grayscale -> 3 channels, exactly as in the MNIST training
            # transform. Without this, test tensors would be 1-channel while
            # the model is trained on 3-channel input.
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std),
        ])
        data = datasets.MNIST(root=params.mnist_path, train=False,
                              transform=transform, download=True)
    elif dataset == 'MNIST_M':
        transform = transforms.Compose([
            transforms.CenterCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std),
        ])
        # Test images are read from an ImageFolder layout under '<mnistm_path>/test'.
        data = datasets.ImageFolder(root=params.mnistm_path + '/test',
                                    transform=transform)
    elif dataset == 'SVHN':
        transform = transforms.Compose([
            transforms.CenterCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std),
        ])
        data = datasets.SVHN(root=params.svhn_path, split='test',
                             transform=transform, download=True)
    elif dataset == 'SynDig':
        transform = transforms.Compose([
            transforms.CenterCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(mean=params.dataset_mean, std=params.dataset_std),
        ])
        # download=False: the SynDig data is expected to already be on disk.
        data = SynDig.SynDig(root=params.syndig_path, split='test',
                             transform=transform, download=False)
    else:
        # ValueError subclasses Exception, so existing `except Exception`
        # callers keep working.
        raise ValueError('There is no dataset named {}'.format(str(dataset)))

    # NOTE(review): shuffling a test loader is unusual; kept for backward
    # compatibility with existing callers.
    return DataLoader(dataset=data, batch_size=params.batch_size, shuffle=True)