def setup(self, stage=None):
    self.usps_test = USPS(self.data_dir, train=False, transform=transforms.ToTensor(), download=True)
    usps_full = USPS(self.data_dir, train=True, transform=transforms.ToTensor(), download=True)
    # 6,000 + 1,291 = 7,291, the full USPS training split
    self.usps_train, self.usps_val = random_split(usps_full, [6000, 1291])
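# Hypothetical companion methods for the LightningDataModule above, showing how
# the splits created in setup() are usually consumed; the batch size is illustrative.
def train_dataloader(self):
    return DataLoader(self.usps_train, batch_size=64, shuffle=True)

def val_dataloader(self):
    return DataLoader(self.usps_val, batch_size=64)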
def load_usps(img_size=28, augment=False, **kwargs):
    transformations = [transforms.Resize(img_size), transforms.ToTensor()]
    if augment:
        # augmentations operate on tensors, so they come after ToTensor
        transformations.append(transforms.Lambda(lambda x: random_affine_augmentation(x)))
        transformations.append(transforms.Lambda(lambda x: gaussian_blur(x)))
    img_transform = transforms.Compose(transformations)
    test_transform = transforms.Compose(
        [transforms.Resize(img_size), transforms.ToTensor()])
    train_set = USPS('../data', train=True, transform=img_transform, download=True)
    # train=False was missing here, so both loaders silently used the training split
    test_set = USPS('../data', train=False, transform=test_transform, download=True)
    return get_loader(train_set, **kwargs), get_loader(test_set, **kwargs)
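# `get_loader` is referenced above but not shown; a minimal sketch under the
# assumption that it simply forwards keyword arguments to DataLoader.
def get_loader(dataset, batch_size=32, shuffle=True, num_workers=0, **kwargs):
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers, **kwargs)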
def get_data(domain_name: str, split="train"):
    train = split == "train"
    if domain_name == "mnist":
        return do.from_pytorch(MNIST(data_path, download=True, train=train))
    if domain_name == "usps":
        return do.from_pytorch(USPS(data_path, download=True, train=train))
    raise ValueError(f"unknown domain: {domain_name}")  # was an implicit `return None`
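# Hypothetical usage: `do` is assumed to be a dataset-ops wrapper imported
# elsewhere in this module; this pulls the USPS training split through it.
usps_train = get_data("usps", split="train")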
def select_test_dataset(dataset_name, testing=False):
    """Select a test dataset from the options below.

    Parameters
    ----------
    dataset_name: dataset name given as a string, e.g. 'fmnist'.
    testing: testing flag. If True, only the first 1000 samples are returned.

    Returns
    -------
    vec_data: the dataset as a numpy array of shape N x D, where N is the
        number of samples and D is the dimension of the feature vector.
    labels: the labels of the samples, of shape N x 1.
    """
    if dataset_name == 'fmnist':
        f_mnist = FashionMNIST(root="./datasets", train=False, download=True)
        data = f_mnist.data.numpy()
        vec_data = np.reshape(data, (data.shape[0], -1))
        vec_data = np.float32(vec_data)
        labels = f_mnist.targets.numpy()
    elif dataset_name == 'usps':
        usps = USPS(root="./datasets", train=False, download=True)
        data = usps.data  # USPS stores its data as a numpy array already
        vec_data = np.reshape(data, (data.shape[0], -1))
        vec_data = np.float32(vec_data)
        labels = np.float32(usps.targets)
    elif dataset_name == 'char':
        digits = datasets.load_digits()
        n_samples = len(digits.images)
        data = digits.images.reshape((n_samples, -1))
        vec_data = np.float32(data)
        labels = digits.target
    elif dataset_name == 'charx':
        file_name = file_path + "/datasets/char_x.npy"
        data = np.load(file_name, allow_pickle=True)
        vec_data, labels = data[2], data[3]
    else:
        print('The dataset you asked for is not available. Gave you MNIST instead.')
        d_mnist = MNIST(root="./datasets", train=False, download=True)
        data = d_mnist.data.numpy()
        vec_data = np.reshape(data, (data.shape[0], -1))
        vec_data = np.float32(vec_data)
        labels = d_mnist.targets.numpy()

    if testing:
        return vec_data[:1000], labels[:1000]
    return vec_data, labels
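# Example usage: the USPS test split as flat float32 vectors. USPS images are
# 16x16 grayscale, so D = 256; testing=True caps the return at 1,000 samples.
vec_data, labels = select_test_dataset('usps', testing=True)
assert vec_data.shape == (1000, 256)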
def main():
    args = parse_args()

    if args.debug or not args.non_deterministic:
        np.random.seed(1)
        torch.manual_seed(1)
        torch.cuda.manual_seed(1)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # torch.set_deterministic(True)
        # grid_sampler_2d_backward_cuda does not have a deterministic implementation
    if args.debug:
        torch.autograd.set_detect_anomaly(True)

    dataloader_args = EasyDict(
        batch_size=args.batch_size, shuffle=False,
        num_workers=0 if args.debug else args.data_workers)

    if args.dataset == 'mnist':
        args.num_classes = 10
        args.im_channels = 1
        args.image_size = (40, 40)

        from torchvision.datasets import MNIST
        t = transforms.Compose([
            transforms.RandomCrop(size=(40, 40), pad_if_needed=True),
            transforms.ToTensor(),
            # norm_1c
        ])
        train_dataloader = DataLoader(
            MNIST(data_path / 'mnist', train=True, transform=t, download=True),
            **dataloader_args)
        val_dataloader = DataLoader(
            MNIST(data_path / 'mnist', train=False, transform=t, download=True),
            **dataloader_args)
    elif args.dataset == 'usps':
        args.num_classes = 10
        args.im_channels = 1
        args.image_size = (40, 40)

        from torchvision.datasets import USPS
        t = transforms.Compose([
            transforms.RandomCrop(size=(40, 40), pad_if_needed=True),
            transforms.ToTensor(),
            # norm_1c
        ])
        train_dataloader = DataLoader(
            USPS(data_path / 'usps', train=True, transform=t, download=True),
            **dataloader_args)
        val_dataloader = DataLoader(
            USPS(data_path / 'usps', train=False, transform=t, download=True),
            **dataloader_args)
    elif args.dataset == 'constellation':
        data_gen = create_constellation(
            batch_size=args.batch_size,
            shuffle_corners=True,
            gaussian_noise=.0,
            drop_prob=0.5,
            which_patterns=[[0], [1], [0]],
            rotation_percent=180 / 360.,
            max_scale=3.,
            min_scale=3.,
            use_scale_schedule=False,
            schedule_steps=0,
        )
        train_dataloader = DataLoader(data_gen, **dataloader_args)
        val_dataloader = DataLoader(data_gen, **dataloader_args)
    elif args.dataset == 'cifar10':
        args.num_classes = 10
        args.im_channels = 3
        args.image_size = (32, 32)

        from torchvision.datasets import CIFAR10
        t = transforms.Compose([transforms.ToTensor()])
        train_dataloader = DataLoader(
            CIFAR10(data_path / 'cifar10', train=True, transform=t, download=True),
            **dataloader_args)
        val_dataloader = DataLoader(
            CIFAR10(data_path / 'cifar10', train=False, transform=t, download=True),
            **dataloader_args)
    elif args.dataset == 'svhn':
        args.num_classes = 10
        args.im_channels = 3
        args.image_size = (32, 32)

        from torchvision.datasets import SVHN
        t = transforms.Compose([transforms.ToTensor()])
        train_dataloader = DataLoader(
            SVHN(data_path / 'svhn', split='train', transform=t, download=True),
            **dataloader_args)
        val_dataloader = DataLoader(
            SVHN(data_path / 'svhn', split='test', transform=t, download=True),
            **dataloader_args)
    else:
        raise NotImplementedError()

    logger = WandbLogger(project=args.log.project, name=args.log.run_name,
                         entity=args.log.team, config=args,
                         offline=not args.log.upload)

    if args.model == 'ccae':
        from scae.modules.attention import SetTransformer
        from scae.modules.capsule import CapsuleLayer
        from scae.models.ccae import CCAE

        encoder = SetTransformer(2)
        decoder = CapsuleLayer(input_dims=32, n_caps=3, n_caps_dims=2, n_votes=4,
                               n_caps_params=32, n_hiddens=128,
                               learn_vote_scale=True, deformations=True,
                               noise_type='uniform', noise_scale=4.,
                               similarity_transform=False)
        model = CCAE(encoder, decoder, args)
        # logger.watch(encoder._encoder, log='all', log_freq=args.log_frequency)
        # logger.watch(decoder, log='all', log_freq=args.log_frequency)
    elif args.model == 'pcae':
        from scae.modules.part_capsule_ae import CapsuleImageEncoder, TemplateImageDecoder
        from scae.models.pcae import PCAE

        encoder = CapsuleImageEncoder(args)
        decoder = TemplateImageDecoder(args)
        model = PCAE(encoder, decoder, args)
        logger.watch(encoder._encoder, log='all', log_freq=args.log.frequency)
        logger.watch(decoder, log='all', log_freq=args.log.frequency)
    elif args.model == 'ocae':
        from scae.modules.object_capsule_ae import SetTransformer, ImageCapsule
        from scae.models.ocae import OCAE

        encoder = SetTransformer()
        decoder = ImageCapsule()
        model = OCAE(encoder, decoder, args)
        # TODO: after ccae
    else:
        raise NotImplementedError()

    # Execute experiment
    lr_logger = cb.LearningRateMonitor(logging_interval='step')
    best_checkpointer = cb.ModelCheckpoint(save_top_k=1, monitor='val_rec_ll',
                                           filepath=logger.experiment.dir)
    last_checkpointer = cb.ModelCheckpoint(save_last=True,
                                           filepath=logger.experiment.dir)
    trainer = pl.Trainer(
        max_epochs=args.num_epochs,
        logger=logger,
        callbacks=[lr_logger, best_checkpointer, last_checkpointer])
    trainer.fit(model, train_dataloader, val_dataloader)
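# Compatibility note (an assumption, not verified against this repo's pinned
# version): `ModelCheckpoint(filepath=...)` is the pre-1.2 PyTorch Lightning
# API; on later releases the equivalent arguments would be:
#
#     best_checkpointer = cb.ModelCheckpoint(save_top_k=1, monitor='val_rec_ll',
#                                            dirpath=logger.experiment.dir)
#     last_checkpointer = cb.ModelCheckpoint(save_last=True,
#                                            dirpath=logger.experiment.dir)
#
# The script entry point is not shown above; the usual idiom:
if __name__ == '__main__':
    main()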
def setup(self, stage=None):
    self.usps_test = self.colorize_dataset(
        USPS(self.data_dir, train=False, download=True))
    usps_full = self.colorize_dataset(
        USPS(self.data_dir, train=True, download=True))
    self.usps_train, self.usps_val = random_split(usps_full, [6000, 1291])
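# `colorize_dataset` is referenced above but not shown; a minimal sketch,
# assuming it merely replicates the grayscale channel to RGB via the transform.
def colorize_dataset(self, dataset):
    dataset.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # 1 channel -> 3 channels
    ])
    return dataset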
def select_dataset(dataset_name, input_dim=2, n_samples=10000):
    """
    :param n_samples: number of points returned. If 0, all data points are
        returned; for artificial data, n_samples=0 raises an error.
    """
    if dataset_name == 'fmnist':
        f_mnist = FashionMNIST(root="./datasets", download=True)
        data = f_mnist.data.numpy()
        vec_data = np.float32(np.reshape(data, (data.shape[0], -1)))
        labels = f_mnist.targets.numpy()
    elif dataset_name == 'emnist':
        emnist = EMNIST(root="./datasets", download=True, split='byclass')
        data = emnist.data.numpy()
        vec_data = np.float32(np.reshape(data, (data.shape[0], -1)))
        labels = emnist.targets.numpy()
    elif dataset_name == 'kmnist':
        kmnist = KMNIST(root="./datasets", download=True)
        data = kmnist.data.numpy()
        vec_data = np.float32(np.reshape(data, (data.shape[0], -1)))
        labels = kmnist.targets.numpy()
    elif dataset_name == 'usps':
        usps = USPS(root="./datasets", download=True)
        data = usps.data  # already a numpy array
        vec_data = np.float32(np.reshape(data, (data.shape[0], -1)))
        labels = np.float32(usps.targets)
    elif dataset_name == 'news':
        newsgroups_train = fetch_20newsgroups(data_home='./datasets', subset='train',
                                              remove=('headers', 'footers', 'quotes'))
        vectorizer = TfidfVectorizer()
        vec_data = np.float32(vectorizer.fit_transform(newsgroups_train.data).toarray())
        labels = np.float32(newsgroups_train.target)
    elif dataset_name == 'cover_type':
        file_name = file_path + "/datasets/covtype.data"
        train_data = np.array(pd.read_csv(file_name, sep=','))
        vec_data = np.float32(train_data[:, :-1])
        labels = np.float32(train_data[:, -1])
    elif dataset_name == 'char':
        digits = datasets.load_digits()
        n_samples = len(digits.images)  # shadows the parameter; all 1,797 digits are returned
        data = digits.images.reshape((n_samples, -1))
        vec_data = np.float32(data)
        labels = digits.target
    elif dataset_name == 'charx':
        file_name = file_path + "/datasets/char_x.npy"
        data = np.load(file_name, allow_pickle=True)
        vec_data, labels = data[0], data[1]
    elif dataset_name == 'kdd_cup':
        cover_train = fetch_kddcup99(data_home='./datasets', download_if_missing=True)
        vec_data = cover_train.data
        string_labels = cover_train.target
        vec_data, labels = feature_tranformers.vectorizer_kdd(data=vec_data,
                                                              labels=string_labels)
    elif dataset_name == 'aggregation':
        file_name = file_path + "/2d_data/Aggregation.csv"
        a = np.array(pd.read_csv(file_name, sep=';'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'compound':
        file_name = file_path + "/2d_data/Compound.txt"
        a = np.array(pd.read_csv(file_name, sep='\t'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'd31':
        file_name = file_path + "/2d_data/D31.txt"
        a = np.array(pd.read_csv(file_name, sep='\t'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'flame':
        file_name = file_path + "/2d_data/flame.txt"
        a = np.array(pd.read_csv(file_name, sep='\t'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'path_based':
        file_name = file_path + "/2d_data/pathbased.txt"
        a = np.array(pd.read_csv(file_name, sep='\t'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'r15':
        file_name = file_path + "/2d_data/R15.txt"
        a = np.array(pd.read_csv(file_name, sep='\t'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'spiral':
        file_name = file_path + "/2d_data/spiral.txt"
        a = np.array(pd.read_csv(file_name, sep='\t'))
        vec_data = a[:, 0:2]
        labels = a[:, 2]
    elif dataset_name == 'birch1':
        file_name = file_path + "/2d_data/birch1.txt"
        a = np.array(pd.read_csv(file_name, delimiter=r"\s+"))
        vec_data = a[:, :]
        labels = np.ones((vec_data.shape[0]))
    elif dataset_name == 'birch2':
        file_name = file_path + "/2d_data/birch2.txt"
        a = np.array(pd.read_csv(file_name, delimiter=r"\s+"))
        vec_data = a[:, :]
        labels = np.ones((vec_data.shape[0]))
    elif dataset_name == 'birch3':
        file_name = file_path + "/2d_data/birch3.txt"
        a = np.array(pd.read_csv(file_name, delimiter=r"\s+"))
        vec_data = a[:, :]
        labels = np.ones((vec_data.shape[0]))
    elif dataset_name == 'worms':
        file_name = file_path + "/2d_data/worms/worms_2d.txt"
        a = np.array(pd.read_csv(file_name, sep=' '))
        vec_data = a[:, :]
        labels = np.ones((vec_data.shape[0]))
    elif dataset_name == 't48k':
        file_name = file_path + "/2d_data/t4.8k.txt"
        a = np.array(pd.read_csv(file_name, sep=' '))
        vec_data = a[1:, :]
        labels = np.ones((vec_data.shape[0]))
    elif dataset_name == 'moons':
        data, labels = make_moons(n_samples=5000)
        vec_data = np.float32(data)
        labels = np.float32(labels)
    elif dataset_name == 'circles':
        data, labels = make_circles(n_samples=5000)
        vec_data = np.float32(data)
        labels = np.float32(labels)
    elif dataset_name == 'blobs':
        data, labels = make_blobs(n_samples=n_samples, centers=3)
        vec_data = np.float32(data)
        labels = np.float32(labels)
    elif dataset_name == 'gmm':
        # two well-separated isotropic Gaussians
        mean_1 = np.zeros(input_dim)
        mean_2 = 100 * np.ones(input_dim)
        cov = np.eye(input_dim)
        data_1 = np.random.multivariate_normal(mean_1, cov, int(n_samples / 2))
        labels_1 = np.ones(int(n_samples / 2))
        labels_2 = 2 * np.ones(int(n_samples / 2))
        data_2 = np.random.multivariate_normal(mean_2, cov, int(n_samples / 2))
        vec_data = np.concatenate([data_1, data_2], axis=0)
        labels = np.concatenate([labels_1, labels_2], axis=0)
    elif dataset_name == 'uniform':
        vec_data = np.random.uniform(0, 1, size=(n_samples, input_dim)) * 10
        labels = np.ones(n_samples)
    elif dataset_name == 'mnist_pc':
        d_mnist = MNIST(root="./datasets", download=True)
        mnist = d_mnist.data.numpy()
        data = np.float32(np.reshape(mnist, (mnist.shape[0], -1)))
        pca_data = PCA(n_components=input_dim).fit_transform(data)
        n_indices = np.random.randint(pca_data.shape[0], size=n_samples)
        vec_data = pca_data[n_indices]
        labels = d_mnist.targets.numpy()[n_indices]
    elif dataset_name == 'usps_pc':
        d_usps = USPS(root="./datasets", download=True)
        usps = d_usps.data
        data = np.float32(np.reshape(usps, (usps.shape[0], -1)))
        pca_data = PCA(n_components=input_dim).fit_transform(data)
        n_indices = np.random.randint(pca_data.shape[0], size=n_samples)
        vec_data = pca_data[n_indices]
        # index the labels too, so they stay aligned with the subsampled rows
        labels = np.float32(d_usps.targets)[n_indices]
    elif dataset_name == 'char_pc':
        digits = datasets.load_digits()
        n_samples = len(digits.images)
        data = np.float32(digits.images.reshape((n_samples, -1)))
        targets = digits.target
        pca_data = PCA(n_components=input_dim).fit_transform(data)
        n_indices = np.random.randint(pca_data.shape[0], size=n_samples)
        vec_data = pca_data[n_indices]
        labels = targets[n_indices]  # keep labels aligned with the resampled rows
    else:
        d_mnist = MNIST(root="./datasets", download=True)
        data = d_mnist.data.numpy()
        vec_data = np.float32(np.reshape(data, (data.shape[0], -1)))
        labels = d_mnist.targets.numpy()

    if 0 < n_samples < vec_data.shape[0]:
        rand_indices = np.random.choice(vec_data.shape[0], size=(n_samples,), replace=False)
        return vec_data[rand_indices], labels[rand_indices]
    return vec_data, labels
if use_y_to_verify_performance:
    plot_3d(transformed, y_for_verification, show=False)
    plt.show()


if __name__ == "__main__":
    from torchvision.datasets import MNIST, USPS, FashionMNIST, CIFAR10
    from torchtext.datasets import AG_NEWS

    n = None
    # semisupervised_proportion = .2
    e = DEN(n_components=2, internal_dim=128)

    USPS_data_train = USPS("./", train=True, download=True)
    USPS_data_test = USPS("./", train=False, download=True)
    USPS_data = ConcatDataset([USPS_data_test, USPS_data_train])
    X, y = zip(*USPS_data)
    y_numpy = np.array(y[:n])
    X_numpy = np.array(
        [np.asarray(X[i]) for i in range(n if n is not None else len(X))])
    X = torch.Tensor(X_numpy).unsqueeze(1)
    # which = np.random.choice(len(y_numpy), int((1 - semisupervised_proportion) * len(y_numpy)), replace=False)
    # y_for_verification = copy.deepcopy(y_numpy)
    # y_numpy[which] = -1
    # news_train, news_test = AG_NEWS('./', ngrams=1)
    # X, y = zip(*([item[1], item[0]] for item in news_test))
def get_dataset(args, config):
    if config.data.dataset == 'CIFAR10':
        if config.data.random_flip:
            dataset = CIFAR10(os.path.join('datasets', 'cifar10'), train=True, download=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.data.image_size),
                                  transforms.RandomHorizontalFlip(p=0.5),
                                  transforms.ToTensor(),
                              ]))
        else:
            dataset = CIFAR10(os.path.join('datasets', 'cifar10'), train=True, download=True,
                              transform=transforms.Compose([
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]))
        test_dataset = CIFAR10(os.path.join('datasets', 'cifar10_test'), train=False, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(config.data.image_size),
                                   transforms.ToTensor(),
                               ]))
    elif config.data.dataset == 'CELEBA':
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]), download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]), download=True)
        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]), download=True)
    elif config.data.dataset == "CELEBA-32px":
        # downsample to 32px first, then rescale to the working resolution
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(32),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]), download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(32),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]), download=True)
        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(32),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]), download=True)
    elif config.data.dataset == "CELEBA-8px":
        # downsample to 8px first, then rescale to the working resolution
        if config.data.random_flip:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(8),
                                 transforms.Resize(config.data.image_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                             ]), download=True)
        else:
            dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='train',
                             transform=transforms.Compose([
                                 transforms.CenterCrop(140),
                                 transforms.Resize(8),
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]), download=True)
        test_dataset = CelebA(root=os.path.join('datasets', 'celeba'), split='test',
                              transform=transforms.Compose([
                                  transforms.CenterCrop(140),
                                  transforms.Resize(8),
                                  transforms.Resize(config.data.image_size),
                                  transforms.ToTensor(),
                              ]), download=True)
    elif config.data.dataset == 'LSUN':
        train_folder = '{}_train'.format(config.data.category)
        val_folder = '{}_val'.format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(root=os.path.join('datasets', 'lsun'), classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor(),
                           ]))
        else:
            dataset = LSUN(root=os.path.join('datasets', 'lsun'), classes=[train_folder],
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.CenterCrop(config.data.image_size),
                               transforms.ToTensor(),
                           ]))
        test_dataset = LSUN(root=os.path.join('datasets', 'lsun'), classes=[val_folder],
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.CenterCrop(config.data.image_size),
                                transforms.ToTensor(),
                            ]))
    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(path=os.path.join('datasets', 'FFHQ'),
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.ToTensor(),
                           ]),
                           resolution=config.data.image_size)
        else:
            dataset = FFHQ(path=os.path.join('datasets', 'FFHQ'),
                           transform=transforms.ToTensor(),
                           resolution=config.data.image_size)

        # deterministic 90/10 train/test split, independent of the global RNG state
        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices = indices[:int(num_items * 0.9)]
        test_indices = indices[int(num_items * 0.9):]
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)
    elif config.data.dataset == "MNIST":
        if config.data.random_flip:
            dataset = MNIST(root=os.path.join('datasets', 'MNIST'), train=True, download=True,
                            transform=transforms.Compose([
                                transforms.RandomHorizontalFlip(p=0.5),
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor(),
                            ]))
        else:
            dataset = MNIST(root=os.path.join('datasets', 'MNIST'), train=True, download=True,
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor(),
                            ]))
        test_dataset = MNIST(root=os.path.join('datasets', 'MNIST'), train=False, download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(config.data.image_size),
                                 transforms.ToTensor(),
                             ]))
    elif config.data.dataset == "USPS":
        if config.data.random_flip:
            dataset = USPS(root=os.path.join('datasets', 'USPS'), train=True, download=True,
                           transform=transforms.Compose([
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor(),
                           ]))
        else:
            dataset = USPS(root=os.path.join('datasets', 'USPS'), train=True, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor(),
                           ]))
        test_dataset = USPS(root=os.path.join('datasets', 'USPS'), train=False, download=True,
                            transform=transforms.Compose([
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor(),
                            ]))
    elif config.data.dataset == "USPS-Pad":
        if config.data.random_flip:
            dataset = USPS(root=os.path.join('datasets', 'USPS'), train=True, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(20),  # resize and pad like MNIST
                               transforms.Pad(4),
                               transforms.RandomHorizontalFlip(p=0.5),
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor(),
                           ]))
        else:
            dataset = USPS(root=os.path.join('datasets', 'USPS'), train=True, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(20),  # resize and pad like MNIST
                               transforms.Pad(4),
                               transforms.Resize(config.data.image_size),
                               transforms.ToTensor(),
                           ]))
        test_dataset = USPS(root=os.path.join('datasets', 'USPS'), train=False, download=True,
                            transform=transforms.Compose([
                                transforms.Resize(20),  # resize and pad like MNIST
                                transforms.Pad(4),
                                transforms.Resize(config.data.image_size),
                                transforms.ToTensor(),
                            ]))
    elif config.data.dataset.upper() == "GAUSSIAN":
        if config.data.num_workers != 0:
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. "
                "Gaussian data is sampled at runtime and doing so with multiple "
                "workers may cause a CUDA error.")
        if config.data.isotropic:
            dim = config.data.dim
            rank = config.data.rank
            cov = np.diag(np.pad(np.ones((rank,)), [(0, dim - rank)]))
            mean = np.zeros((dim,))
        else:
            cov = np.array(config.data.cov)
            mean = np.array(config.data.mean)
        # note: this previously checked hasattr(config.data.dataset, "shape"),
        # but config.data.dataset is a string, so the check was always False
        shape = config.data.shape if hasattr(config.data, "shape") else None
        dataset = Gaussian(device=args.device, cov=cov, mean=mean, shape=shape)
        test_dataset = Gaussian(device=args.device, cov=cov, mean=mean, shape=shape)
    elif config.data.dataset.upper() == "GAUSSIAN-HD":
        if config.data.num_workers != 0:
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. "
                "Gaussian data is sampled at runtime and doing so with multiple "
                "workers may cause a CUDA error.")
        cov = np.load(config.data.cov_path)
        mean = np.load(config.data.mean_path)
        dataset = Gaussian(device=args.device, cov=cov, mean=mean)
        test_dataset = Gaussian(device=args.device, cov=cov, mean=mean)
    elif config.data.dataset.upper() == "GAUSSIAN-HD-UNIT":
        # Use this when GAUSSIAN with the isotropic option is infeasible due to the
        # high dimensionality of the desired samples: if the dimension is too high,
        # passing a huge covariance matrix is slow.
        if config.data.num_workers != 0:
            raise ValueError(
                "If using a Gaussian dataset, num_workers must be zero. "
                "Gaussian data is sampled at runtime and doing so with multiple "
                "workers may cause a CUDA error.")
        shape = config.data.shape if hasattr(config.data, "shape") else None
        dataset = Gaussian(device=args.device, mean=None, cov=None, shape=shape, iid_unit=True)
        test_dataset = Gaussian(device=args.device, mean=None, cov=None, shape=shape, iid_unit=True)

    return dataset, test_dataset
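# Hypothetical smoke test for the USPS branch: the real configs come from YAML
# elsewhere in the codebase; SimpleNamespace just mimics the attribute access
# that get_dataset expects.
from types import SimpleNamespace

config = SimpleNamespace(data=SimpleNamespace(dataset='USPS', image_size=32,
                                              random_flip=False))
args = SimpleNamespace(device='cpu')
train_set, test_set = get_dataset(args, config)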