def get_omniglot_loaders(arguments):
    """Build train/test DataLoaders over the two Omniglot splits.

    The background split (with light affine augmentation) is used for
    training and the evaluation split for testing; both loaders shuffle
    and pin memory.

    Args:
        arguments: config object; reads ``preload_all_data`` and ``batch_size``.

    Returns:
        (train_loader, test_loader) tuple.
    """
    if arguments.preload_all_data:
        raise NotImplementedError

    normalize = transforms.Normalize((0.1307, ), (0.3081, ))
    train_transform = transforms.Compose([
        transforms.RandomAffine(10),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        # transforms.RandomCrop(70),
        transforms.ToTensor(),
        normalize,
    ])

    def build_loader(background, transform):
        # Shared DataLoader settings for both splits.
        split = datasets.Omniglot(DATASET_PATH,
                                  background=background,
                                  download=True,
                                  transform=transform)
        return torch.utils.data.DataLoader(split,
                                           batch_size=arguments.batch_size,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=NUM_WORKERS)

    return build_loader(True, train_transform), build_loader(False, test_transform)
def get_inverted_omniglot_loaders(arguments, mean=(0.5,), std=(0.5,)):
    """Build train/test loaders over Omniglot with inverted pixel values.

    Images are mapped ``x -> 1 - x`` (white-on-black strokes), so the
    normalisation mean must be inverted as well.  The standard deviation of
    ``1 - x`` equals that of ``x``, so ``std`` is used unchanged (with the
    default (0.5,) this matches the original intent exactly).

    Args:
        arguments: dict-like config; reads 'preload_all_data' and 'batch_size'.
        mean, std: per-channel statistics of the *un-inverted* images.

    Returns:
        (train_loader, test_loader) tuple; the test loader is unshuffled.
    """
    print("Using mean", mean)  # e.g. Omniglot stats: (1-0.92206,), (0.08426,)
    if arguments['preload_all_data']:
        raise NotImplementedError

    # Bug fix: the original computed ``1 - mean`` where ``mean`` is a tuple,
    # which raises TypeError at runtime.  Invert element-wise instead.
    inverted_mean = tuple(1.0 - m for m in mean)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: 1 - x),
        transforms.Normalize(inverted_mean, std),
    ])

    train_loader = torch.utils.data.DataLoader(
        datasets.Omniglot(DATASET_PATH, background=True, download=True,
                          transform=transform),
        batch_size=arguments['batch_size'],
        shuffle=True,
        pin_memory=True,
        num_workers=NUM_WORKERS,
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.Omniglot(DATASET_PATH, background=False, download=True,
                          transform=transform),
        batch_size=arguments['batch_size'],
        shuffle=False,
        pin_memory=True,
        num_workers=NUM_WORKERS,
    )
    return train_loader, test_loader
def main():
    """Entry point: build Omniglot loaders and train or evaluate the controller."""
    controller = Controller()
    # Worker/pinning options only make sense when CUDA is available.
    kwargs = {"num_workers": 1, "pin_memory": True} if FLAGS.CUDA else {}

    base_transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize(mean, std) should be added for better
        # performance, plus other augmentation transforms.
    ])

    # Bug fixes: torchvision's Omniglot has no ``train`` argument -- splits
    # are selected with ``background`` -- and DataLoader's keyword is
    # ``batch_size`` (``batch_sizer`` raised TypeError).  ``download=True``
    # is also set on the test split so it is fetched on first use.
    train_loader = DataLoader(
        datasets.Omniglot(
            "data", background=True, download=True, transform=base_transform),
        batch_size=FLAGS.BATCH_SIZE,
        shuffle=True,
        **kwargs)
    test_loader = DataLoader(
        datasets.Omniglot(
            "data", background=False, download=True, transform=base_transform),
        batch_size=FLAGS.BATCH_SIZE,
        shuffle=True,
        **kwargs)

    if FLAGS.TRAIN:
        train(controller, train_loader)
    else:
        test(controller, test_loader)
def _get_dataset(self):
    """Resolve ``self.dataset_name`` into ``self.train_dataset`` /
    ``self.test_dataset`` and record ``self.input_shape`` and
    ``self.num_classes``.

    Raises:
        Exception: if the dataset name is not recognised.
    """
    print("Loading {} dataset from {}.".format(self.dataset_name, self.datadir))
    augment_transforms = []
    # Default pipeline: plain tensor conversion, no normalisation.
    image_transforms = tfs.Compose([tfs.ToTensor()])
    if self.dataset_name == 'omniglot':
        # Omniglot: 1x105x105 grayscale images, 1623 character classes.
        self.input_shape, self.num_classes = (1, 105, 105), 1623
        self.train_dataset = datasets.Omniglot(self.datadir, background=True, target_transform=None,
                                               download=True, transform=image_transforms)
        self.test_dataset = datasets.Omniglot(self.datadir, background=False, target_transform=None,
                                              download=True, transform=image_transforms)
    elif self.dataset_name == 'cifar100':
        self.input_shape, self.num_classes = (3, 32, 32), 100
        if self.augment:
            print("Using augmentation on train dataset.")
            # Standard CIFAR augmentation: pad-and-crop plus horizontal flip.
            augment_transforms = [tfs.RandomCrop(32, padding=4), tfs.RandomHorizontalFlip()]
        # CIFAR-100 per-channel statistics.
        image_transforms = [tfs.ToTensor(), tfs.Normalize(mean=[0.507, 0.487, 0.441],
                                                          std=[0.267, 0.256, 0.276])]
        train_transforms = tfs.Compose(augment_transforms + image_transforms)
        test_transforms = tfs.Compose(image_transforms)
        self.train_dataset = datasets.CIFAR100(self.datadir, train=True, download=True,
                                               transform=train_transforms)
        self.test_dataset = datasets.CIFAR100(self.datadir, train=False, download=True,
                                              transform=test_transforms)
    elif self.dataset_name == 'cifar10':
        self.input_shape, self.num_classes = (3, 32, 32), 10
        self.train_dataset = datasets.CIFAR10(self.datadir, train=True, download=True,
                                              transform=image_transforms)
        self.test_dataset = datasets.CIFAR10(self.datadir, train=False, download=True,
                                             transform=image_transforms)
    else:
        raise Exception("{} dataset not found!".format(self.dataset_name))
def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1):
    """Omniglot data loader: background split for training, evaluation
    split for evaluation, each with its own target transform.

    Args:
        data_dir: root directory for the Omniglot data.
        batch_size, shuffle, validation_split, num_workers: forwarded to
            the base loader class.
    """
    self.data_dir = data_dir
    # NOTE(review): these two datasets are never used directly, but
    # constructing them with download=True appears intended to ensure the
    # raw files exist on disk before OmniglotTargetTransform scans
    # data_dir just below -- confirm before removing.
    dataset = datasets.Omniglot(self.data_dir, download=True, background=True)
    eval_dataset = datasets.Omniglot(self.data_dir, download=True, background=False)
    target_transform = OmniglotTargetTransform(self.data_dir, background=True)
    eval_target_transform = OmniglotTargetTransform(self.data_dir, background=False)
    # omni_transforms is presumably a module-level transform pipeline
    # defined elsewhere in the file -- TODO confirm.
    self.dataset = datasets.Omniglot(self.data_dir, background=True, download=True,
                                     transform=omni_transforms,
                                     target_transform=target_transform)
    self.eval_dataset = datasets.Omniglot(self.data_dir, background=False, download=True,
                                          transform=omni_transforms,
                                          target_transform=eval_target_transform)
    # Materialise every target; this loads each sample once, so it is
    # O(dataset) work at construction time.
    self.targets = np.array([self.dataset[i][1] for i in range(len(self.dataset))])
    eval_targets = np.array([self.eval_dataset[i][1] for i in range(len(self.eval_dataset))])
    super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers,
                     self.targets, drop_train_last=False, drop_valid_last=False,
                     evaluation=(self.eval_dataset, eval_targets))
def omniglot():
    """Collect download configs for both Omniglot splits, chained together
    (background split first, then evaluation)."""
    def split_configs(background):
        # Human-readable split name used in the config label.
        label = f"Omniglot, {'background' if background else 'evaluation'}"
        return collect_download_configs(
            lambda: datasets.Omniglot(ROOT, background=background, download=True),
            name=label,
        )

    # Build the per-split collections eagerly (as the original chain(*[...])
    # did), then flatten them into one iterator.
    return itertools.chain.from_iterable(
        [split_configs(background) for background in (True, False)])
def download(self):
    """Download both Omniglot splits and flatten them into ``self.root_dir``.

    Uses torchvision to fetch the raw archives under ``data/omniglot-py``,
    moves every alphabet directory from both splits into the processed
    directory, then removes the raw download tree.
    """
    origin_dir = 'data/omniglot-py'
    processed_dir = self.root_dir
    # Instantiating the datasets is done purely for the download side effect.
    dset.Omniglot(root='data', background=False, download=True)
    dset.Omniglot(root='data', background=True, download=True)
    # Fix: the original ``try: os.mkdir(...) except OSError: pass`` silently
    # swallowed *all* OSErrors (e.g. PermissionError), not just "directory
    # already exists"; makedirs with exist_ok only tolerates the latter.
    os.makedirs(processed_dir, exist_ok=True)
    for split in ['images_background', 'images_evaluation']:
        for alphabet in os.listdir(os.path.join(origin_dir, split)):
            shutil.move(os.path.join(origin_dir, split, alphabet), processed_dir)
    # Raw download tree is no longer needed once everything has moved.
    shutil.rmtree(origin_dir)
def __init__(self, height=32, length=32):
    """Set up the Omniglot task source.

    Downloads the raw dataset, builds the task list, splits it into
    validation and training tasks, and prepares resize/to-tensor transforms.
    """
    self.height = height
    self.length = length
    # Omniglot images are single-channel.
    self.channels = 1
    self.data = datasets.Omniglot(root='./data', download=True)
    # Derive tasks from the raw data, then split them.
    self.make_tasks()
    self.split_validation_and_training_task()
    # Transforms applied lazily when samples are fetched.
    self.resize = transforms.Resize((self.height, self.length))
    self.to_tensor = transforms.ToTensor()
def __init__(self, height=32, length=32):
    """Prepare the Omniglot task source: download the data, build the task
    list, split it, and set up the image transforms."""
    self.height = height
    self.length = length
    # Grayscale images -- a single channel.
    self.channels = 1
    self.data = datasets.Omniglot(root='./data', download=True)
    # Build tasks from the raw data and split the resulting set.
    self.task_maker()
    self.split_dataset()
    # Transforms used when individual images are fetched.
    self.resize = transforms.Resize((self.height, self.length))
    self.tensor = transforms.ToTensor()
def __init__(self, height=32, length=32,
             root='C:/Users/kashi/Documents/CMPE_258/proj/FIGR-master/FIGR-master/data/'):
    """Omniglot task source.

    Generalisation: the dataset location, previously hard-coded to one
    machine's absolute path, is now the ``root`` parameter.  The default
    preserves the old behaviour; pass e.g. ``root='./data'`` for a
    portable location.

    Args:
        height, length: target image size for the Resize transform.
        root: directory containing (or to receive) the Omniglot download.
    """
    self.channels = 1  # Omniglot images are grayscale
    self.height = height
    self.length = length
    self.data = datasets.Omniglot(root=root, download=True)
    self.make_tasks()
    self.split_validation_and_training_task()
    self.resize = transforms.Resize((self.height, self.length))
    self.to_tensor = transforms.ToTensor()
def load_data_and_initialize_loaders(data_name, train_batch, test_batch):
    """Create train/test DataLoaders for MNIST, FashionMNIST or Omniglot.

    Args:
        data_name: dataset name, case-insensitive ('mnist', 'fashion',
            'fashionmnist', or 'omniglot').
        train_batch, test_batch: batch sizes for the two loaders.

    Returns:
        (train_loader, test_loader) tuple; both loaders shuffle.

    Raises:
        ValueError: if ``data_name`` is not recognized.  (The original had
        this branch commented out and crashed later with an
        UnboundLocalError instead.)
    """
    data_name = data_name.lower()
    kwargs = {'num_workers': 1, 'pin_memory': True}
    if data_name == 'mnist':
        train_data = datasets.MNIST('./data', train=True, download=True,
                                    transform=transforms.ToTensor())
        test_data = datasets.MNIST('./data', train=False,
                                   transform=transforms.ToTensor())
    elif data_name == 'fashion' or data_name == 'fashionmnist':
        train_data = datasets.FashionMNIST('./data', train=True, download=True,
                                           transform=transforms.ToTensor())
        test_data = datasets.FashionMNIST('./data', train=False,
                                          transform=transforms.ToTensor())
    elif data_name == 'omniglot':
        train_data = datasets.Omniglot(root='./data', background=True, download=True,
                                       transform=transforms.ToTensor())
        test_data = datasets.Omniglot(root='./data', background=False, download=True,
                                      transform=transforms.ToTensor())
    else:
        # Fail fast with a clear message instead of a NameError below.
        raise ValueError("Data name not recognized: {}".format(data_name))
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=train_batch,
                                               shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=test_batch,
                                              shuffle=True, **kwargs)
    return train_loader, test_loader
def load_omniglot(args):
    """Return a shuffled, normalised DataLoader over Omniglot (background split)."""
    torch.cuda.manual_seed(1)  # fix the CUDA RNG for reproducibility
    loader_kwargs = {'num_workers': 1, 'pin_memory': True, 'drop_last': True}
    # Optionally read from the shared scratch area instead of the local dir.
    path = '/scratch/eecs-share/ratzlafn/' + 'data_o/' if args.scratch else 'data_o/'
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    omniglot = datasets.Omniglot(path, download=True, transform=preprocess)
    return torch.utils.data.DataLoader(omniglot, batch_size=32, shuffle=True,
                                       **loader_kwargs)
def few_shot_omniglot(root, train=True, batch_size=128, timesteps=15, n_jobs=0, resize=28, **kwargs):
    """Build a few-shot episodic DataLoader over Omniglot.

    ``train`` selects the background (True) or evaluation (False) split.
    Images are resized, converted to tensors and inverted (1 - x).  Extra
    keyword arguments are forwarded to FewShotSampler.
    """
    preprocess = transforms.Compose([
        transforms.Resize([resize, resize], interpolation=Image.NEAREST),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: 1 - x),
    ])
    episodes = dset.Omniglot(root=root, download=True, background=train,
                             transform=preprocess)
    # NOTE(review): relies on torchvision's private ``_character_images``
    # attribute for class/image counts -- confirm it exists in the pinned
    # torchvision version.
    num_classes = len(episodes._character_images)
    images_per_class = len(episodes._character_images[0])
    batcher = FewShotSampler(
        num_cls=num_classes,
        batch_size=batch_size,
        separate=True,
        timesteps=timesteps,
        **kwargs,
        img_per_class=images_per_class,
    )
    return torch.utils.data.DataLoader(
        dataset=episodes,
        batch_sampler=batcher,
        collate_fn=FewShotCollate(timesteps=timesteps),
        num_workers=n_jobs,
    )
def __load_data(subset) -> pd.DataFrame:
    """Walk the downloaded Omniglot image tree for ``subset`` ('background'
    or 'evaluation') and return a DataFrame with one row per image and the
    columns: subset, alphabet, class_name, filepath, id, class_id.
    """
    # Download the data if it is not present yet.
    datasets.Omniglot(root=config.DATA_PATH, download=True,
                      background=(subset == 'background'))

    # Count the total number of files first so the progress bar has a total.
    print(f'loading omniglot dataset ({subset})')
    total_images = 0
    for root, folders, files in os.walk(
            f'{config.DATA_PATH}/omniglot-py/images_{subset}/'):
        total_images += len(files)

    # Walk the file system and build one record per image, attaching the
    # alphabet and character-class attributes derived from the path.
    progress = tqdm(total=total_images)
    images = list()
    for root, folders, files in os.walk(
            f'{config.DATA_PATH}/omniglot-py/images_{subset}/'):
        alphabet = root.split('/')[-2]
        class_name = alphabet + '.' + root.split('/')[-1]
        for f in files:
            images.append({
                'subset': subset,
                'alphabet': alphabet,
                'class_name': class_name,
                'filepath': os.path.join(root, f)
            })
            progress.update(1)
    progress.close()

    # Convert the records into a DataFrame.
    df = pd.DataFrame(images)
    df = df.assign(id=df.index.values)  # add the index value as an 'id' column
    unique_characters = sorted(df['class_name'].unique())
    num_classes = len(df['class_name'].unique())
    class_name_to_id = {
        unique_characters[i]: i
        for i in range(num_classes)
    }
    df = df.assign(
        class_id=df['class_name'].apply(lambda c: class_name_to_id[
            c]))  # assign each class a unique id, stored as 'class_id'
    return df
def load_omniglot(root_dir=None, batch_size=20, shuffle=True, transform=None, download=True):
    """Load Omniglot into train/validation/test DataLoaders of binary images.

    The background split is divided into train and validation subsets (the
    last 1345 images become the validation set); the evaluation split is
    the test set.  Labels are dummy zero tensors -- the data is unlabeled.

    Returns:
        (train_loader, test_loader, valid_loader, dataset_type), where
        dataset_type is the string "binary".
    """
    dataset_type = "binary"
    if root_dir is None:
        # Default to a 'datasets' directory next to the entry-point script.
        root_dir = pathlib.Path(sys.argv[0]).parents[0] / 'datasets'
        # root_dir = str(root_dir)
    if transform is None:
        transform = transforms.ToTensor()
    train_dataset = datasets.Omniglot(root_dir, transform=transform, download=download)
    test_dataset = datasets.Omniglot(root_dir, transform=transform, download=download,
                                     background=False)
    train_data = np.zeros((len(train_dataset), 105, 105))
    test_data = np.zeros((len(test_dataset), 105, 105))
    # Bug fix: ToTensor() already scales pixels to [0, 1]; the original code
    # divided by 255 a second time, collapsing the images to near zero.
    for i, (image, _) in enumerate(train_dataset):
        train_data[i] = image.numpy()
    for i, (image, _) in enumerate(test_dataset):
        test_data[i] = image.numpy()
    if shuffle:
        np.random.shuffle(train_data)
        np.random.shuffle(test_data)
    train_data = torch.from_numpy(train_data)
    test_data = torch.from_numpy(test_data)
    # No real labels: use zero tensors of matching shape.
    train_labels = torch.zeros(train_data.shape)
    test_labels = torch.zeros(test_data.shape)
    train_dataset = data_utils.TensorDataset(train_data.float(), train_labels)
    test_dataset = data_utils.TensorDataset(test_data.float(), test_labels)
    size_train = len(train_dataset)
    indices = list(range(size_train))
    # Reserve the last 1345 examples for validation
    # ("given by god to make the batch size reasonable").
    val_split = size_train - 1345
    train_idx, valid_idx = indices[:val_split], indices[val_split:]
    train_sampler = SubsetSampler(train_idx)
    valid_sampler = SubsetSampler(valid_idx)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        shuffle=False,
    )
    valid_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               sampler=valid_sampler,
                                               shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=shuffle)
    return train_loader, test_loader, valid_loader, dataset_type
def main():
    """Amortized approximation of an SGLD-trained Bayesian MNIST model.

    Trains (or loads) a mean network ``fmodel`` and a concentration network
    ``gmodel`` that approximate the posterior predictive of an SGLD sample
    ensemble, then evaluates them on the MNIST test set and on an
    out-of-distribution dataset.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='Amortized approximation on MNIST')
    parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--approx-epochs', type=int, default=200, metavar='N',
                        help='number of epochs to approx (default: 10)')
    parser.add_argument('--lr', type=float, default=1e-2, metavar='LR',
                        help='learning rate (default: 0.0005)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--S', type=int, default=100, metavar='N',
                        help='number of posterior samples from the Bayesian model')
    parser.add_argument('--model-path', type=str, default='../saved_models/mnist_sgld/', metavar='N',
                        help='number of posterior samples from the Bayesian model')
    parser.add_argument('--from-approx-model', type=int, default=1, metavar='N',
                        help='if our model is loaded or trained')
    parser.add_argument('--test-ood-from-disk', type=int, default=1,
                        help='generate test samples or load from disk')
    parser.add_argument('--ood-name', type=str, default='omniglot',
                        help='name of the used ood dataset')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}

    # MNIST train/test data, normalised to [-1, 1].
    tr_data = MNIST('../data', train=True, transform=transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))]), download=True)
    te_data = MNIST('../data', train=False, transform=transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))]), download=True)
    train_loader = torch.utils.data.DataLoader(
        tr_data, batch_size=args.batch_size, shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        te_data, batch_size=args.batch_size, shuffle=False, **kwargs)

    # Out-of-distribution dataset, preprocessed identically to MNIST.
    if args.ood_name == 'omniglot':
        ood_data = datasets.Omniglot('../../data', download=True, transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]))
    elif args.ood_name == 'SEMEION':
        ood_data = datasets.SEMEION('../../data', download=True, transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]))
    ood_loader = torch.utils.data.DataLoader(
        ood_data, batch_size=args.batch_size, shuffle=False, **kwargs)

    # Teacher: MLP trained with SGLD, weights loaded from disk.
    model = mnist_mlp(dropout=False).to(device)
    model.load_state_dict(torch.load(args.model_path + 'sgld-mnist.pt'))
    test(args, model, device, test_loader)

    if args.from_approx_model == 0:
        # Pre-generated SGLD posterior samples on the training set, used as
        # the regression target while training the approximation.
        output_samples = torch.load(args.model_path + 'mnist-sgld-train-samples.pt')

    # --------------- training approx ---------
    fmodel = mnist_mlp_h().to(device)  # approximation: predictive mean network
    gmodel = mnist_mlp_g().to(device)  # approximation: concentration network

    if args.from_approx_model == 0:
        g_optimizer = optim.SGD(gmodel.parameters(), lr=args.lr)
        f_optimizer = optim.SGD(fmodel.parameters(), lr=args.lr)
        best_acc = 0
        for epoch in range(1, args.approx_epochs + 1):
            train_approx(args, fmodel, gmodel, device, train_loader,
                         f_optimizer, g_optimizer, output_samples, epoch)
            acc = test(args, fmodel, device, test_loader)
            # if (args.save_approx_model == 1):
            # Keep only the checkpoints from the best-accuracy epoch.
            if acc > best_acc:
                torch.save(fmodel.state_dict(),
                           args.model_path + 'sgld-mnist-mmd-mean.pt')
                torch.save(gmodel.state_dict(),
                           args.model_path + 'sgld-mnist-mmd-conc.pt')
                best_acc = acc
    else:
        fmodel.load_state_dict(torch.load(args.model_path + 'sgld-mnist-mmd-mean.pt'))
        gmodel.load_state_dict(torch.load(args.model_path + 'sgld-mnist-mmd-conc.pt'))

    print('generating teacher particles for testing&ood data ...')
    # generate particles for test and ood dataset
    # NOTE(review): train mode here presumably keeps the teacher stochastic
    # while sampling the ensemble -- confirm intended.
    model.train()
    if args.test_ood_from_disk == 1:
        teacher_test_samples = torch.load(args.model_path + 'mnist-sgld-test-samples.pt')
    else:
        with torch.no_grad():
            # obtain ensemble outputs: 500 stochastic passes over the test set
            all_samples = []
            for i in range(500):
                samples_a_round = []
                for data, target in test_loader:
                    data = data.to(device)
                    data = data.view(data.shape[0], -1)
                    output = F.softmax(model(data))
                    samples_a_round.append(output)
                samples_a_round = torch.cat(samples_a_round).cpu()
                all_samples.append(samples_a_round)
            # shape: (num_examples, num_rounds, num_classes)
            all_samples = torch.stack(all_samples).permute(1, 0, 2)
            torch.save(all_samples, args.model_path + 'mnist-sgld-test-samples.pt')
            teacher_test_samples = all_samples

    if args.test_ood_from_disk == 1:
        teacher_ood_samples = torch.load(args.model_path + 'mnist-sgld-'
                                         + args.ood_name + '-samples.pt')
    else:
        with torch.no_grad():
            # obtain ensemble outputs: same procedure on the OOD loader
            all_samples = []
            for i in range(500):
                samples_a_round = []
                for data, target in ood_loader:
                    data = data.to(device)
                    data = data.view(data.shape[0], -1)
                    output = F.softmax(model(data))
                    samples_a_round.append(output)
                samples_a_round = torch.cat(samples_a_round).cpu()
                all_samples.append(samples_a_round)
            all_samples = torch.stack(all_samples).permute(1, 0, 2)
            torch.save(all_samples, args.model_path + 'mnist-sgld-'
                       + args.ood_name + '-samples.pt')
            teacher_ood_samples = all_samples

    eval_approx(args, fmodel, gmodel, device, test_loader, ood_loader,
                teacher_test_samples, teacher_ood_samples)
def run(seed):
    """Train a VAE with configurable flow prior/posterior on the chosen
    dataset and evaluate ELBO and a log-likelihood lower bound on its
    test split.

    Bug fix: torchvision's Omniglot takes ``background=``, not ``train=``
    (the original raised TypeError and made train and test the same
    split); the training set now uses the background split and the test
    set the evaluation split.
    """
    assert torch.cuda.is_available()
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Create training data: dynamically binarized images (per-pixel Bernoulli).
    data_transform = tvtransforms.Compose(
        [tvtransforms.ToTensor(),
         tvtransforms.Lambda(torch.bernoulli)])
    if args.dataset_name == 'mnist':
        dataset = datasets.MNIST(root=os.path.join(utils.get_data_root(),
                                                   'mnist'),
                                 train=True,
                                 download=True,
                                 transform=data_transform)
        test_dataset = datasets.MNIST(root=os.path.join(
            utils.get_data_root(), 'mnist'),
                                      train=False,
                                      download=True,
                                      transform=data_transform)
    elif args.dataset_name == 'fashion-mnist':
        dataset = datasets.FashionMNIST(root=os.path.join(
            utils.get_data_root(), 'fashion-mnist'),
                                        train=True,
                                        download=True,
                                        transform=data_transform)
        test_dataset = datasets.FashionMNIST(root=os.path.join(
            utils.get_data_root(), 'fashion-mnist'),
                                             train=False,
                                             download=True,
                                             transform=data_transform)
    elif args.dataset_name == 'omniglot':
        # Fixed: select splits via ``background`` (train = background split).
        dataset = datasets.Omniglot(root=os.path.join(utils.get_data_root(),
                                                      'omniglot'),
                                    background=True,
                                    download=True,
                                    transform=data_transform)
        test_dataset = datasets.Omniglot(root=os.path.join(
            utils.get_data_root(), 'omniglot'),
                                         background=False,
                                         download=True,
                                         transform=data_transform)
    elif args.dataset_name == 'emnist':
        # EMNIST images are stored rotated and mirrored; undo before binarizing.
        rotate = partial(tvF.rotate, angle=-90)
        hflip = tvF.hflip
        data_transform = tvtransforms.Compose([
            tvtransforms.Lambda(rotate),
            tvtransforms.Lambda(hflip),
            tvtransforms.ToTensor(),
            tvtransforms.Lambda(torch.bernoulli)
        ])
        dataset = datasets.EMNIST(root=os.path.join(utils.get_data_root(),
                                                    'emnist'),
                                  split='letters',
                                  train=True,
                                  transform=data_transform,
                                  download=True)
        test_dataset = datasets.EMNIST(root=os.path.join(
            utils.get_data_root(), 'emnist'),
                                       split='letters',
                                       train=False,
                                       transform=data_transform,
                                       download=True)
    else:
        raise ValueError

    # Size of the validation split carved off the end of the shuffled data.
    if args.dataset_name == 'omniglot':
        split = -1345
    elif args.dataset_name == 'emnist':
        split = -20000
    else:
        split = -10000

    indices = np.arange(len(dataset))
    np.random.shuffle(indices)
    train_indices, val_indices = indices[:split], indices[split:]

    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)

    train_loader = data.DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=4 if args.dataset_name == 'emnist' else 0)
    train_generator = data_.batch_generator(train_loader)

    val_loader = data.DataLoader(dataset=dataset,
                                 batch_size=1024,
                                 sampler=val_sampler,
                                 shuffle=False,
                                 drop_last=False)
    # A single fixed validation batch is reused for monitoring.
    val_batch = next(iter(val_loader))[0]

    test_loader = data.DataLoader(
        test_dataset,
        batch_size=16,
        shuffle=False,
        drop_last=False,
    )

    def create_linear_transform():
        # Invertible linear layer inserted between flow steps.
        if args.linear_type == 'lu':
            return transforms.CompositeTransform([
                transforms.RandomPermutation(args.latent_features),
                transforms.LULinear(args.latent_features, identity_init=True)
            ])
        elif args.linear_type == 'svd':
            return transforms.SVDLinear(args.latent_features,
                                        num_householder=4,
                                        identity_init=True)
        elif args.linear_type == 'perm':
            return transforms.RandomPermutation(args.latent_features)
        else:
            raise ValueError

    def create_base_transform(i, context_features=None):
        # One flow step; coupling layers use alternating binary masks.
        if args.prior_type == 'affine-coupling':
            return transforms.AffineCouplingTransform(
                mask=utils.create_alternating_binary_mask(
                    features=args.latent_features, even=(i % 2 == 0)),
                transform_net_create_fn=lambda in_features, out_features: nn_.
                ResidualNet(in_features=in_features,
                            out_features=out_features,
                            hidden_features=args.hidden_features,
                            context_features=context_features,
                            num_blocks=args.num_transform_blocks,
                            activation=F.relu,
                            dropout_probability=args.dropout_probability,
                            use_batch_norm=args.use_batch_norm))
        elif args.prior_type == 'rq-coupling':
            return transforms.PiecewiseRationalQuadraticCouplingTransform(
                mask=utils.create_alternating_binary_mask(
                    features=args.latent_features, even=(i % 2 == 0)),
                transform_net_create_fn=lambda in_features, out_features: nn_.
                ResidualNet(in_features=in_features,
                            out_features=out_features,
                            hidden_features=args.hidden_features,
                            context_features=context_features,
                            num_blocks=args.num_transform_blocks,
                            activation=F.relu,
                            dropout_probability=args.dropout_probability,
                            use_batch_norm=args.use_batch_norm),
                num_bins=args.num_bins,
                tails='linear',
                tail_bound=args.tail_bound,
                apply_unconditional_transform=args.
                apply_unconditional_transform,
            )
        elif args.prior_type == 'rl-coupling':
            return transforms.PiecewiseRationalLinearCouplingTransform(
                mask=utils.create_alternating_binary_mask(
                    features=args.latent_features, even=(i % 2 == 0)),
                transform_net_create_fn=lambda in_features, out_features: nn_.
                ResidualNet(in_features=in_features,
                            out_features=out_features,
                            hidden_features=args.hidden_features,
                            context_features=context_features,
                            num_blocks=args.num_transform_blocks,
                            activation=F.relu,
                            dropout_probability=args.dropout_probability,
                            use_batch_norm=args.use_batch_norm),
                num_bins=args.num_bins,
                tails='linear',
                tail_bound=args.tail_bound,
                apply_unconditional_transform=args.
                apply_unconditional_transform,
            )
        elif args.prior_type == 'affine-autoregressive':
            return transforms.MaskedAffineAutoregressiveTransform(
                features=args.latent_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_blocks=args.num_transform_blocks,
                use_residual_blocks=True,
                random_mask=False,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)
        elif args.prior_type == 'rq-autoregressive':
            return transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
                features=args.latent_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_bins=args.num_bins,
                tails='linear',
                tail_bound=args.tail_bound,
                num_blocks=args.num_transform_blocks,
                use_residual_blocks=True,
                random_mask=False,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)
        elif args.prior_type == 'rl-autoregressive':
            return transforms.MaskedPiecewiseRationalLinearAutoregressiveTransform(
                features=args.latent_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_bins=args.num_bins,
                tails='linear',
                tail_bound=args.tail_bound,
                num_blocks=args.num_transform_blocks,
                use_residual_blocks=True,
                random_mask=False,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)
        else:
            raise ValueError

    # ---------------
    # prior
    # ---------------
    def create_prior():
        if args.prior_type == 'standard-normal':
            prior = distributions_.StandardNormal((args.latent_features, ))
        else:
            distribution = distributions_.StandardNormal(
                (args.latent_features, ))
            transform = transforms.CompositeTransform([
                transforms.CompositeTransform(
                    [create_linear_transform(),
                     create_base_transform(i)])
                for i in range(args.num_flow_steps)
            ])
            transform = transforms.CompositeTransform(
                [transform, create_linear_transform()])
            prior = flows.Flow(transform, distribution)
        return prior

    # ---------------
    # inputs encoder
    # ---------------
    def create_inputs_encoder():
        if args.approximate_posterior_type == 'diagonal-normal':
            inputs_encoder = None
        else:
            inputs_encoder = nn_.ConvEncoder(
                context_features=args.context_features,
                channels_multiplier=16,
                dropout_probability=args.dropout_probability_encoder_decoder)
        return inputs_encoder

    # ---------------
    # approximate posterior
    # ---------------
    def create_approximate_posterior():
        if args.approximate_posterior_type == 'diagonal-normal':
            context_encoder = nn_.ConvEncoder(
                context_features=args.context_features,
                channels_multiplier=16,
                dropout_probability=args.dropout_probability_encoder_decoder)
            approximate_posterior = distributions_.ConditionalDiagonalNormal(
                shape=[args.latent_features], context_encoder=context_encoder)
        else:
            context_encoder = nn.Linear(args.context_features,
                                        2 * args.latent_features)
            distribution = distributions_.ConditionalDiagonalNormal(
                shape=[args.latent_features], context_encoder=context_encoder)
            transform = transforms.CompositeTransform([
                transforms.CompositeTransform([
                    create_linear_transform(),
                    create_base_transform(
                        i, context_features=args.context_features)
                ]) for i in range(args.num_flow_steps)
            ])
            transform = transforms.CompositeTransform(
                [transform, create_linear_transform()])
            approximate_posterior = flows.Flow(
                transforms.InverseTransform(transform), distribution)
        return approximate_posterior

    # ---------------
    # likelihood
    # ---------------
    def create_likelihood():
        latent_decoder = nn_.ConvDecoder(
            latent_features=args.latent_features,
            channels_multiplier=16,
            dropout_probability=args.dropout_probability_encoder_decoder)
        likelihood = distributions_.ConditionalIndependentBernoulli(
            shape=[1, 28, 28], context_encoder=latent_decoder)
        return likelihood

    prior = create_prior()
    approximate_posterior = create_approximate_posterior()
    likelihood = create_likelihood()
    inputs_encoder = create_inputs_encoder()

    model = vae.VariationalAutoencoder(
        prior=prior,
        approximate_posterior=approximate_posterior,
        likelihood=likelihood,
        inputs_encoder=inputs_encoder)

    n_params = utils.get_num_parameters(model)
    print('There are {} trainable parameters in this model.'.format(n_params))

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=args.num_training_steps, eta_min=0)

    def get_kl_multiplier(step):
        # KL-annealing schedule for the ELBO.
        if args.kl_multiplier_schedule == 'constant':
            return args.kl_multiplier_initial
        elif args.kl_multiplier_schedule == 'linear':
            multiplier = min(
                step / (args.num_training_steps * args.kl_warmup_fraction),
                1.)
            return args.kl_multiplier_initial * (1. + multiplier)

    # create summary writer and write to log directory
    timestamp = cutils.get_timestamp()
    if cutils.on_cluster():
        timestamp += '||{}'.format(os.environ['SLURM_JOB_ID'])
    log_dir = os.path.join(cutils.get_log_root(), args.dataset_name,
                           timestamp)
    while True:
        try:
            writer = SummaryWriter(log_dir=log_dir, max_queue=20)
            break
        except FileExistsError:
            sleep(5)
    filename = os.path.join(log_dir, 'config.json')
    with open(filename, 'w') as file:
        json.dump(vars(args), file)

    best_val_elbo = -np.inf
    tbar = tqdm(range(args.num_training_steps))
    for step in tbar:
        model.train()
        optimizer.zero_grad()

        batch = next(train_generator)[0].to(device)
        elbo = model.stochastic_elbo(batch,
                                     kl_multiplier=get_kl_multiplier(step))
        loss = -torch.mean(elbo)
        loss.backward()
        optimizer.step()
        scheduler.step(step)

        if (step + 1) % args.monitor_interval == 0:
            model.eval()
            with torch.no_grad():
                elbo = model.stochastic_elbo(val_batch.to(device))
                mean_val_elbo = elbo.mean()

            # Checkpoint whenever validation ELBO improves.
            if mean_val_elbo > best_val_elbo:
                best_val_elbo = mean_val_elbo
                path = os.path.join(
                    cutils.get_checkpoint_root(),
                    '{}-best-val-{}.t'.format(args.dataset_name, timestamp))
                torch.save(model.state_dict(), path)
            writer.add_scalar(tag='val-elbo',
                              scalar_value=mean_val_elbo,
                              global_step=step)
            writer.add_scalar(tag='best-val-elbo',
                              scalar_value=best_val_elbo,
                              global_step=step)

            with torch.no_grad():
                samples = model.sample(64)
            fig, ax = plt.subplots(figsize=(10, 10))
            cutils.gridimshow(make_grid(samples.view(64, 1, 28, 28), nrow=8),
                              ax)
            writer.add_figure(tag='vae-samples', figure=fig,
                              global_step=step)
            plt.close()

    # load best val model
    path = os.path.join(
        cutils.get_checkpoint_root(),
        '{}-best-val-{}.t'.format(args.dataset_name, timestamp))
    model.load_state_dict(torch.load(path))
    model.eval()

    np.random.seed(5)
    torch.manual_seed(5)

    # compute elbo on test set
    with torch.no_grad():
        elbo = torch.Tensor([])
        log_prob_lower_bound = torch.Tensor([])
        for batch in tqdm(test_loader):
            elbo_ = model.stochastic_elbo(batch[0].to(device))
            elbo = torch.cat([elbo, elbo_])
            log_prob_lower_bound_ = model.log_prob_lower_bound(
                batch[0].to(device), num_samples=1000)
            log_prob_lower_bound = torch.cat(
                [log_prob_lower_bound, log_prob_lower_bound_])
    path = os.path.join(
        log_dir, '{}-prior-{}-posterior-{}-elbo.npy'.format(
            args.dataset_name, args.prior_type,
            args.approximate_posterior_type))
    np.save(path, utils.tensor2numpy(elbo))
    path = os.path.join(
        log_dir, '{}-prior-{}-posterior-{}-log-prob-lower-bound.npy'.format(
            args.dataset_name, args.prior_type,
            args.approximate_posterior_type))
    np.save(path, utils.tensor2numpy(log_prob_lower_bound))

    # save elbo and log prob lower bound
    mean_elbo = elbo.mean()
    std_elbo = elbo.std()
    mean_log_prob_lower_bound = log_prob_lower_bound.mean()
    std_log_prob_lower_bound = log_prob_lower_bound.std()
    s = 'ELBO: {:.2f} +- {:.2f}, LOG PROB LOWER BOUND: {:.2f} +- {:.2f}'.format(
        mean_elbo.item(), 2 * std_elbo.item() / np.sqrt(len(test_dataset)),
        mean_log_prob_lower_bound.item(),
        2 * std_log_prob_lower_bound.item() / np.sqrt(len(test_dataset)))
    filename = os.path.join(log_dir, 'test-results.txt')
    with open(filename, 'w') as file:
        file.write(s)
def _get_dataset(self, train: bool, transform: Any) -> torch.utils.data.Dataset:
    """Return the requested Omniglot split.

    ``train`` maps onto torchvision's ``background`` flag; the data is
    assumed to be present on disk already (no download).
    """
    split_kwargs = dict(background=train, download=False, transform=transform)
    return datasets.Omniglot(self.data_folder, **split_kwargs)
import torch import torch.nn.functional as F import utils from torchvision import datasets, transforms import VAE_model as vae import easy_conv_vae as conv_vae import maml_class import numpy as np import pickle from torch import optim import matplotlib.pyplot as plt import random from tqdm import tqdm import seaborn as sns data = datasets.Omniglot(root='./data', download=True) source_task_number=500 task_set_number = 200 data_set = utils.get_dataset(data,50, 20, 10) random.shuffle(data_set) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") epochs=200 betas = [1.] for j,beta in enumerate(betas): model = conv_vae.VAE(z_dim=128) model.to(device) maml = maml_class.MAML(model=model, data=data_set, inner_lr=1e-2,
def train_gan(opt):
    """Build the generator/discriminator pair and dataset selected by `opt`,
    then hand everything to `train_g` for the actual training loop.
    """
    # Output directories for sample images and checkpoints.
    os.makedirs(os.path.join(opt.savingroot, opt.dataset, 'images'), exist_ok=True)
    os.makedirs(os.path.join(opt.savingroot, opt.dataset, 'chkpts'), exist_ok=True)

    # Build network. An auxiliary-classifier head is enabled only for the
    # AC / Twin_AC losses; Projection (and anything else) runs without it.
    use_ac = opt.loss_type in ('Twin_AC', 'AC')

    if opt.model_type == 'sa':
        netd_g = nn.DataParallel(
            SA_Discriminator(n_class=opt.num_classes, nc=opt.nc, AC=use_ac,
                             Resolution=opt.image_size, ch=64).cuda())
        netg = nn.DataParallel(
            SA_Generator(n_class=opt.num_classes, code_dim=opt.nz, nc=opt.nc,
                         SN=opt.SN, Resolution=opt.image_size, ch=32).cuda())
    elif opt.model_type == 'big':
        netd_g = nn.DataParallel(
            Discriminator(n_classes=opt.num_classes,
                          resolution=opt.image_size, AC=use_ac).cuda())
        netg = nn.DataParallel(
            Generator(n_classes=opt.num_classes, resolution=opt.image_size,
                      SN=opt.SN).cuda())

    # Dataset dispatch keyed on opt.data_r.
    if opt.data_r == 'MNIST':
        dataset = dset.MNIST(root=opt.dataroot, download=True, transform=tsfm)
    elif opt.data_r == 'CIFAR10':
        dataset = dset.CIFAR10(root=opt.dataroot, download=True, transform=tsfm)
    elif opt.data_r == 'CIFAR100':
        dataset = dset.CIFAR100(root=opt.dataroot, download=True, transform=tsfm)
    elif opt.data_r == 'CUB':
        dataset = dset.ImageFolder(
            root='/home/yanwuxu/CUB_200_2011_processed/ImageNet/ImageNet/',
            transform=tsfm)
    elif opt.data_r == 'VGGFACE':
        dataset = ILSVRC_HDF5(root='../data/VGGFACE64.hdf5', transform=tsfm)
    elif opt.data_r == 'IMAGENET100':
        dataset = Load_numpy_data(root='../data/ImageNet100.pt', transform=tsfm)
    elif opt.data_r == 'MNIST_overlap':
        dataset = Load_gray_data(root='../data/overlap_MNIST.pt', transform=tsfm)
    elif opt.data_r == 'OMNIGLOT':
        dataset = dset.Omniglot(
            '../result', transform=tsfm,
            download=True)  #Load_gray_data(root='omniglot.pt', transform=tsfm)

    print('training_start')
    print(opt.loss_type)
    step = 0
    train_g(netd_g, netg, dataset, step, opt)
# VAE-on-Omniglot training setup: build fixed train/test splits and a conv
# VAE with Adam, ready for the training loop that follows this chunk.
import utils
import easy_conv_vae as conv_vae
import pickle
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# NOTE(review): torch, torchvision (datasets / transforms) and torch.optim
# are used below but not imported in this chunk — confirm they are imported
# elsewhere in the file.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 20 examples per task; the first train_number examples form the train split.
number_of_tasks = 500
train_number = 20 * number_of_tasks

transform = transforms.Compose(
    [transforms.Resize((28, 28)), transforms.ToTensor()])
# NOTE(review): no download=True here — assumes ./data already holds Omniglot.
data = datasets.Omniglot(root='./data', transform=transform)

# list(data) materialises (and transforms) the whole dataset in memory once,
# so the loaders below iterate plain Python lists.
train_set = list(data)[:train_number]
test_set = list(data)[train_number:train_number + 1000]
train_loader = torch.utils.data.DataLoader(train_set,
                                           batch_size=20,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set,
                                          batch_size=20,
                                          shuffle=True)

model = conv_vae.VAE(z_dim=128).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 20
beta = 1.

# Reconstruction + KL divergence losses summed over all elements and batch
nn.init.normal_(m.weight,0,2e-2) nn.init.normal_(m.bias, 0.5, 1e-2) #testing net=Siamese_Net() # print(net) dummyx1=torch.randn(64,1,105,105) dummyx2=torch.randn(64,1,105,105) o=net(dummyx1,dummyx2) print(o.shape) """# **Omniglot Dataset**""" # download the dataset omni_train=datasets.Omniglot(root='/content', background= True, download = True ) omni_test=datasets.Omniglot(root='/content', background= False, download = True ) # Omniglot dataset- didnt rename cause would have had to rename it everywhere class Face_Dataset(Dataset): def __init__(self,root_dir, job='train',ways=10,transform=None): super(Face_Dataset,self).__init__() self.root_dir=root_dir self.job=job self.all_classes=os.listdir(root_dir) self.num_classes=len(self.all_classes)
def main():
    """Train an MLP on MNIST with SGLD, snapshot posterior weight samples,
    and dump per-sample softmax predictions for the train/test sets and the
    Omniglot / SEMEION out-of-distribution sets (teacher targets for
    distillation).
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description='run approximation to LeNet on Mnist')
    parser.add_argument('--batch-size',
                        type=int,
                        default=256,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=100,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.0005)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--dropout-rate',
                        type=float,
                        default=0.5,
                        metavar='p_drop',
                        help='dropout rate')
    parser.add_argument(
        '--S',
        type=int,
        default=500,
        metavar='N',
        help='number of posterior samples from the Bayesian model')
    parser.add_argument(
        '--model-path',
        type=str,
        default='../saved_models/mnist_sgld/',
        metavar='N',
        help='number of posterior samples from the Bayesian model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}

    # Initial loaders (shuffled train split) used for SGLD optimisation.
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data',
                       train=True,
                       download=True,
                       transform=transforms.Compose(
                           [transforms.ToTensor(),
                            transforms.Normalize((0.5, ), (0.5, ))])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data',
                       train=False,
                       transform=transforms.Compose(
                           [transforms.ToTensor(),
                            transforms.Normalize((0.5, ), (0.5, ))])),
        batch_size=args.test_batch_size,
        shuffle=False,
        **kwargs)

    model = mnist_mlp(dropout=False).to(device)
    optimizer = SGLD(model.parameters(), lr=args.lr)

    import copy
    import pickle as pkl

    # Burn-in: plain training epochs before sample collection starts.
    for epoch in range(1, args.epochs + 1):
        train_bayesian(args, model, device, train_loader, optimizer, epoch)
        print("epoch: {}".format(epoch))
        test(args, model, device, test_loader)

    # save models
    torch.save(model.state_dict(), args.model_path + 'sgld-mnist.pt')

    # save samples
    # Keep taking SGLD steps and snapshot the full state_dict after every
    # step until args.S posterior samples have been collected.
    param_samples = []
    while (1):
        for idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            data = data.view(data.shape[0], -1)  # flatten images for the MLP
            output = model(data)
            loss = F.nll_loss(F.log_softmax(output, dim=1), target)
            loss.backward()
            optimizer.step()
            param_samples.append(copy.deepcopy(model.state_dict()))
            if param_samples.__len__() >= args.S:
                print('1', len(param_samples))
                break
        if param_samples.__len__() >= args.S:
            print('2', len(param_samples))
            break
    with open(args.model_path + "sgld_samples.pkl", "wb") as f:
        print('3', len(param_samples))
        pkl.dump(param_samples, f)
    test(args, model, device, test_loader)

    # Rebuild both loaders without shuffling so every posterior sample sees
    # the examples in the same, fixed order (outputs must line up).
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data',
                       train=True,
                       download=True,
                       transform=transforms.Compose(
                           [transforms.ToTensor(),
                            transforms.Normalize((0.5, ), (0.5, ))])),
        batch_size=args.batch_size,
        shuffle=False,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data',
                       train=False,
                       transform=transforms.Compose(
                           [transforms.ToTensor(),
                            transforms.Normalize((0.5, ), (0.5, ))])),
        batch_size=args.test_batch_size,
        shuffle=False,
        **kwargs)

    # generate teacher train samples
    with torch.no_grad():
        # obtain ensemble outputs
        all_samples = []
        # NOTE(review): 500 is hard-coded here (and in the three loops
        # below); it indexes param_samples[i] and will break if --S != 500.
        for i in range(500):
            samples_a_round = []
            model.load_state_dict(param_samples[i])
            for data, target in train_loader:
                data = data.to(device)
                data = data.view(data.shape[0], -1)
                # NOTE(review): F.softmax without an explicit dim relies on
                # deprecated implicit-dim behaviour — confirm dim=1 intended.
                output = F.softmax(model(data))
                samples_a_round.append(output)
            samples_a_round = torch.cat(samples_a_round).cpu()
            all_samples.append(samples_a_round)
        # -> (num_examples, num_posterior_samples, num_classes)
        all_samples = torch.stack(all_samples).permute(1, 0, 2)
    torch.save(all_samples, args.model_path + 'mnist-sgld-train-samples.pt')

    # generate teacher test samples
    with torch.no_grad():
        # obtain ensemble outputs
        all_samples = []
        for i in range(500):
            samples_a_round = []
            model.load_state_dict(param_samples[i])
            for data, target in test_loader:
                data = data.to(device)
                data = data.view(data.shape[0], -1)
                output = F.softmax(model(data))
                samples_a_round.append(output)
            samples_a_round = torch.cat(samples_a_round).cpu()
            all_samples.append(samples_a_round)
        all_samples = torch.stack(all_samples).permute(1, 0, 2)
    torch.save(all_samples, args.model_path + 'mnist-sgld-test-samples.pt')

    # generate teacher omniglot samples
    # Omniglot is resized/normalised to match the MNIST input pipeline and
    # used as an out-of-distribution probe.
    ood_data = datasets.Omniglot(
        '../../data',
        download=True,
        transform=transforms.Compose([
            # transforms.ToPILImage(),
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, )),
        ]))
    ood_loader = torch.utils.data.DataLoader(ood_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             **kwargs)
    with torch.no_grad():
        # obtain ensemble outputs
        all_samples = []
        for i in range(500):
            samples_a_round = []
            model.load_state_dict(param_samples[i])
            for data, target in ood_loader:
                data = data.to(device)
                data = data.view(data.shape[0], -1)
                output = F.softmax(model(data))
                samples_a_round.append(output)
            samples_a_round = torch.cat(samples_a_round).cpu()
            all_samples.append(samples_a_round)
        all_samples = torch.stack(all_samples).permute(1, 0, 2)
    torch.save(all_samples,
               args.model_path + 'mnist-sgld-omniglot-samples.pt')

    # generate teacher SEMEION samples
    ood_data = datasets.SEMEION(
        '../../data',
        download=True,
        transform=transforms.Compose([
            # transforms.ToPILImage(),
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, )),
        ]))
    ood_loader = torch.utils.data.DataLoader(ood_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             **kwargs)
    with torch.no_grad():
        # obtain ensemble outputs
        all_samples = []
        for i in range(500):
            samples_a_round = []
            model.load_state_dict(param_samples[i])
            for data, target in ood_loader:
                data = data.to(device)
                data = data.view(data.shape[0], -1)
                output = F.softmax(model(data))
                samples_a_round.append(output)
            samples_a_round = torch.cat(samples_a_round).cpu()
            all_samples.append(samples_a_round)
        all_samples = torch.stack(all_samples).permute(1, 0, 2)
    torch.save(all_samples,
               args.model_path + 'mnist-sgld-SEMEION-samples.pt')
def get_datasets(
    dataset_name,
    frac_val=FRAC_VAL,
    batch_size=8,
    img_shape=None,
    nn_architecture=None,
    train_params=None,
    synthetic_params=None,
    class_2d=None,
    kwargs=None,
):
    """Return a (train, val) pair for the named dataset.

    Most branches return torch Dataset objects; the "mri"/"segmentation"/
    "fmri" branch returns DataLoaders instead (early return below) — callers
    must cope with both shapes. Raises ValueError for unknown names.
    """
    if kwargs is None:
        kwargs = KWARGS
    img_shape_no_channel = None
    if img_shape is not None:
        # Drop the leading channel dim; resize transforms expect (H, W).
        img_shape_no_channel = img_shape[1:]
    # TODO(nina): Consistency in datasets: add channels for all
    logging.info("Loading data from dataset: %s" % dataset_name)
    if dataset_name == "mnist":
        train_dataset, val_dataset = get_dataset_mnist()
    elif dataset_name == "omniglot":
        # Resize only when a target shape was requested.
        if img_shape_no_channel is not None:
            transform = transforms.Compose([
                transforms.Resize(img_shape_no_channel),
                transforms.ToTensor()
            ])
        else:
            transform = transforms.ToTensor()
        dataset = datasets.Omniglot("../data",
                                    download=True,
                                    transform=transform)
        train_dataset, val_dataset = split_dataset(dataset, frac_val=frac_val)
    elif dataset_name in [
            "cryo_sim",
            "randomrot1D_nodisorder",
            "randomrot1D_multiPDB",
            "randomrot_nodisorder",
    ]:
        dataset = get_dataset_cryo(dataset_name, img_shape_no_channel, kwargs)
        train_dataset, val_dataset = split_dataset(dataset)
    elif dataset_name == "cryo_sphere":
        dataset = get_dataset_cryo_sphere(img_shape_no_channel, kwargs)
        train_dataset, val_dataset = split_dataset(dataset)
    elif dataset_name == "cryo_exp":
        dataset = get_dataset_cryo_exp(img_shape_no_channel, kwargs)
        train_dataset, val_dataset = split_dataset(dataset)
    elif dataset_name == "cryo_exp_class_2d":
        dataset = get_dataset_cryo_exp_class_2d(img_shape_no_channel,
                                                class_2d)  # , kwargs)
        train_dataset, val_dataset = split_dataset(dataset)
    elif dataset_name == "cryo_exp_3d":
        dataset = get_dataset_cryo_exp_3d(img_shape_no_channel, kwargs)
        train_dataset, val_dataset = split_dataset(dataset)
    elif dataset_name == "connectomes":
        train_dataset, val_dataset = get_dataset_connectomes(
            img_shape_no_channel=img_shape_no_channel)
    elif dataset_name == "connectomes_simu":
        train_dataset, val_dataset = get_dataset_connectomes_simu(
            img_shape_no_channel=img_shape_no_channel)
    elif dataset_name == "connectomes_schizophrenia":
        # Third returned element (test split, presumably) is discarded here.
        train_dataset, val_dataset, _ = get_dataset_connectomes_schizophrenia()
    elif dataset_name in ["mri", "segmentation", "fmri"]:
        train_loader, val_loader = get_loaders_brain(dataset_name, frac_val,
                                                     batch_size,
                                                     img_shape_no_channel,
                                                     kwargs)
        # Early return: brain datasets come pre-wrapped as DataLoaders.
        return train_loader, val_loader
    elif dataset_name == "synthetic":
        dataset = make_synthetic_dataset_and_decoder(
            synthetic_params=synthetic_params,
            nn_architecture=nn_architecture,
            train_params=train_params,
        )
        train_dataset, val_dataset = split_dataset(dataset)
    else:
        raise ValueError("Unknown dataset name: %s" % dataset_name)
    return train_dataset, val_dataset
# Omniglot loader setup: images are resized to 28x28, mapped from [0, 1] to
# [-1, 1], then sign-flipped so strokes are bright on a dark background.
# Fix: the `kwargs` dict was defined twice identically — kept once; the
# named lambdas are now plain functions (PEP 8 E731), same names/behaviour.
batch_size = 32
max_length = 15


def rescaling(x):
    """Map a [0, 1] tensor to [-1, 1]."""
    return (x - .5) * 2.


def rescaling_inv(x):
    """Inverse of `rescaling`: map [-1, 1] back to [0, 1]."""
    return .5 * x + .5


def flip(x):
    """Negate pixel values (invert the image around 0)."""
    return -x


def resizing(x):
    """Resize a PIL image to 28x28 (runs before ToTensor)."""
    return x.resize((28, 28))


omni_transforms = transforms.Compose(
    [resizing, transforms.ToTensor(), rescaling, flip])
#TODO: check this, but i think i don't want rescaling

kwargs = {'num_workers': 1, 'pin_memory': True, 'drop_last': True}

# Background split = train, evaluation split = test.
train_loader = torch.utils.data.DataLoader(
    datasets.Omniglot('../vhe/data',
                      download=True,
                      background=True,
                      transform=omni_transforms),
    batch_size=batch_size,
    shuffle=True,
    **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.Omniglot('../vhe/data',
                      download=True,
                      background=False,
                      transform=omni_transforms),
    batch_size=batch_size,
    shuffle=True,
    **kwargs)
def load_dataset(
    root,
    name,
    dataset_size=None,
    transform=None,
    split="train",
    fraction=1.0,
    batch_size=None,
):
    """Load a dataset by (case/hyphen-insensitive) name and return it wrapped
    in IgnoreLabelDataset, optionally subsampled.

    Args:
        root: base directory holding the per-dataset folders.
        name: dataset identifier; lower-cased and '-' mapped to '_' first.
        dataset_size: optional cap on the number of examples.
        transform: torchvision transform applied to each example.
        split: "train" or "test".
        fraction: keep this fraction of the data (0 < fraction <= 1).
        batch_size: lower bound on the returned size (so a batch is fillable).

    Raises:
        NotImplementedError: for unrecognised names.
    """
    # Normalise once: every comparison below must use the underscore form.
    name = name.lower().replace("-", "_")
    assert split in ("train", "test")
    if name == "mnist":
        data = datasets.MNIST(
            os.path.join(root, "MNIST"),
            train=split == "train",
            download=True,
            transform=transform,
        )
    elif name == "fashion_mnist":
        # BUG FIX: was "fashion-mnist", unreachable after the '-'->'_'
        # normalisation above (requests fell through to NotImplementedError).
        data = datasets.FashionMNIST(
            os.path.join(root, "FASHION-MNIST"),
            train=split == "train",
            download=True,
            transform=transform,
        )
    elif name == "cifar10":
        data = datasets.CIFAR10(
            os.path.join(root, "CIFAR10"),
            train=split == "train",
            download=True,
            transform=transform,
        )
    elif name == "svhn":
        data = datasets.SVHN(
            os.path.join(root, "SVHN"), split=split, download=True, transform=transform
        )
    elif name == "stl10":
        # STL-10's unlabeled images augment the train split.
        if split == "train":
            split += "+unlabeled"
        data = datasets.STL10(
            os.path.join(root, "STL10"), split=split, download=True, transform=transform
        )
    elif name == "lsun_bed":
        # BUG FIX: was "lsun-bed", unreachable after normalisation (see above).
        data = datasets.LSUN(
            os.path.join(root, "LSUN"), classes=["bedroom_train"], transform=transform
        )
    elif name == "omniglot":
        data = datasets.Omniglot(
            os.path.join(root, "OMNIGLOT"),
            background=split == "train",
            download=True,
            transform=transform,
        )
    elif name.startswith("test"):
        # Synthetic data for tests: name encodes channels/size as test_C_D.
        _, c, d = name.split("_")
        data = datasets.FakeData(
            size=dataset_size,
            image_size=(int(c), int(d), int(d)),
            num_classes=2,
            transform=transform,
        )
        data.labels = np.random.randint(0, 2, dataset_size)
    elif name == "scaly":
        data = SingleFolderDataset(os.path.join(root, "SCALY"), transform=transform)
    elif name == "celeba":
        center_crop = transforms.CenterCrop(178)  # deal with non squared celebA
        transform = transforms.Compose([center_crop, transform])
        data = SingleFolderDataset(
            os.path.join(root, "celebA", "img_align_celeba"), transform=transform
        )
    else:
        raise NotImplementedError(name)

    assert 0 < fraction <= 1
    assert dataset_size is None or dataset_size <= len(data)
    if dataset_size is None:
        dataset_size = len(data)
    # Requested size, but never smaller than a batch / one per GPU / 1.
    size = min(int(fraction * len(data)), dataset_size)
    size = max(
        size, batch_size if batch_size is not None else 0,
        torch.cuda.device_count(), 1
    )
    if size < len(data):
        points = np.random.choice(range(len(data)), replace=False, size=size)
        data = torch.utils.data.Subset(data, points)
    return IgnoreLabelDataset(data)
# NOTE(review): `forward` is a method of the conv classifier class defined
# just above this chunk — keep it indented under that class in the real file.
def forward(self, x):
    # Four stacked conv stages, then flatten per-sample and classify.
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.conv4(x)
    x = x.view(x.size(0), -1)
    return self.logits(x)


if __name__ == "__main__":
    # MAML meta-training on Omniglot: 5-way 5-shot tasks, images 28x28.
    trans = transforms.Compose(
        [transforms.Resize((28, 28)), transforms.ToTensor()])
    # NOTE(review): no download flag — assumes ./Omniglot/ is already populated.
    tasks = Omniglot_Task_Distribution(
        datasets.Omniglot('./Omniglot/', transform=trans), 20)
    N, K = 5, 5
    # Sampled but unused here — presumably a smoke test of the task sampler.
    task = tasks.sample_task(N, K, 15)
    meta_model = Classifier(N)
    maml = MAML(meta_model.cuda(),
                tasks,
                inner_lr=0.01,
                meta_lr=0.001,
                K=10,
                inner_steps=1,
                tasks_per_meta_batch=32,
                criterion=nn.CrossEntropyLoss())
    maml.main_loop(num_iterations=100)
# Interior of a dataset-selection chain (the leading `if` branch is above
# this chunk): picks loaders plus matching loss/sample ops per dataset.
elif 'cifar' in args.dataset:
    # CIFAR-10: 3-channel, so only the discretized logistic mixture is
    # supported (no 3D softmax implementation).
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(args.data_dir, train=True, download=True,
                         transform=ds_transforms),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(args.data_dir, train=False,
                         transform=ds_transforms),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    if args.nr_logistic_mix:
        loss_op = lambda real, fake: discretized_mix_logistic_loss(real, fake)
        sample_op = lambda x: sample_from_discretized_mix_logistic(
            x, args.nr_logistic_mix)
    else:
        raise NotImplementedError("No 3D Softmax")
elif 'omni' in args.dataset:
    # Omniglot: background split for training, evaluation split for test.
    # NOTE(review): batch_size is hard-coded to 1 here (args.batch_size is
    # ignored) — confirm that is intentional.
    train_loader = torch.utils.data.DataLoader(
        datasets.Omniglot(args.data_dir, download=True, background=True,
                          transform=omni_transforms),
        batch_size=1, shuffle=True, **kwargs)
    #d = datasets.Omniglot(args.data_dir, download=True,
    #                      background=True, transform=omni_transforms)
    test_loader = torch.utils.data.DataLoader(
        datasets.Omniglot(args.data_dir, download=True, background=False,
                          transform=omni_transforms),
        batch_size=1, shuffle=True, **kwargs)
    if args.nr_logistic_mix:
        # Single-channel variants of the logistic-mixture loss/sampler.
        loss_op = lambda real, fake: discretized_mix_logistic_loss_1d(
            real, fake)
        sample_op = lambda x: sample_from_discretized_mix_logistic_1d(
            x, args.nr_logistic_mix)
    else:
        loss_op = lambda real, fake: softmax_loss_1d(real, fake)
        sample_op = lambda x: sample_from_softmax_1d(x)
import torchvision.datasets as tvd

if __name__ == "__main__":
    # Pre-download MNIST, CIFAR-10 and Omniglot (both splits) into ./data so
    # later training runs never hit the network.
    data_dir = "./data"
    for is_train in (True, False):
        tvd.MNIST(data_dir, download=True, train=is_train)
        tvd.CIFAR10(data_dir, download=True, train=is_train)
        tvd.Omniglot(data_dir, download=True, background=is_train)
def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0,
             num_workers=1, background=True):
    """Build an Omniglot data loader.

    `background=True` selects the background (train) alphabet set,
    False the evaluation set. Images go through the module-level
    `omni_transforms`; the base-class constructor handles batching,
    shuffling and the validation split.
    """
    self.data_dir = data_dir
    omniglot = datasets.Omniglot(self.data_dir,
                                 background=background,
                                 download=True,
                                 transform=omni_transforms)
    self.dataset = omniglot
    super().__init__(self.dataset, batch_size, shuffle, validation_split,
                     num_workers)