def get_miniimagenet_dataloaders_torchmeta(args):
    """Build train/val/test ``BatchMetaDataLoader``s for mini-ImageNet via torchmeta.

    Side effects on ``args``: sets ``trainin_with_epochs``, ``data_path`` and
    ``criterion``.  Reads ``n_classes``, ``k_shots``, ``k_eval``,
    ``meta_batch_size_train``, ``meta_batch_size_eval`` and ``num_workers``.

    Returns:
        Tuple ``(meta_train_dataloader, meta_val_dataloader,
        meta_test_dataloader)``.
    """
    # NOTE(review): "trainin_with_epochs" looks like a typo for
    # "training_with_epochs" -- confirm against this attribute's readers
    # before renaming.
    args.trainin_with_epochs = False
    args.data_path = Path('~/data/').expanduser()  # for some datasets this is enough
    args.criterion = nn.CrossEntropyLoss()
    # args.image_size = 84 # do we need this?
    from torchmeta.datasets.helpers import miniimagenet

    # ImageNet channel statistics; augmentation is applied to the train split
    # only -- the val/test datasets keep the helper's default transform.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(84),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4,
                               contrast=0.4,
                               saturation=0.4,
                               hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    # Task parameters shared by all three splits.
    common = dict(ways=args.n_classes,
                  shots=args.k_shots,
                  test_shots=args.k_eval,
                  download=True)
    dataset_train = miniimagenet(args.data_path,
                                 transform=train_transform,
                                 meta_split='train',
                                 **common)
    dataset_val = miniimagenet(args.data_path, meta_split='val', **common)
    dataset_test = miniimagenet(args.data_path, meta_split='test', **common)

    def as_loader(task_dataset, batch_size):
        # One meta-batch = `batch_size` few-shot tasks.
        return BatchMetaDataLoader(task_dataset,
                                   batch_size=batch_size,
                                   num_workers=args.num_workers)

    return (as_loader(dataset_train, args.meta_batch_size_train),
            as_loader(dataset_val, args.meta_batch_size_eval),
            as_loader(dataset_test, args.meta_batch_size_eval))
def create_miniimagenet_data_loader(
    root,
    meta_split,
    k_way,
    n_shot,
    n_query,
    batch_size,
    num_workers,
    download=False,
    seed=None,
):
    """Build a shuffled torchmeta ``BatchMetaDataLoader`` over MiniImagenet.

    Args:
        root: MiniImagenet root folder (containing a ``miniimagenet``
            subfolder with the preprocessed json files or the downloaded
            tar.gz file).
        meta_split: see ``torchmeta.datasets.MiniImagenet``.
        k_way: Number of classes per task.
        n_shot: Number of support samples per class.
        n_query: Number of query (test) images per class.
        batch_size: Meta batch size (tasks per batch).
        num_workers: Number of workers for data preprocessing.
        download: Download (and run the dataset-specific preprocessing on)
            the files if needed.
        seed: Seed to be used in the meta-dataset.

    Returns:
        A torchmeta :class:`BatchMetaDataLoader` object.
    """
    # Positional order of the torchmeta helper is (folder, shots, ways).
    task_dataset = miniimagenet(
        root,
        n_shot,
        k_way,
        meta_split=meta_split,
        test_shots=n_query,
        download=download,
        seed=seed,
    )
    return BatchMetaDataLoader(
        task_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
def dataset_f(args, meta_split: "Literal['train', 'val', 'test'] | None" = None):
    """Instantiate the meta-dataset selected by ``args.dataset``.

    Args:
        args: Namespace providing ``dataset``, ``support_samples``,
            ``num_classes``, ``query_samples`` and ``seed``.  May be mutated:
            ``num_classes`` is capped at 16 for the mini-imagenet val split.
        meta_split: Which meta split to build; defaults to 'train'.
            (Annotation fixed: the default ``None`` means the hint must admit
            ``None`` as well.)

    Returns:
        A torchmeta dataset (or the zero-shot CUB dataset).

    Raises:
        ValueError: If ``args.dataset`` is not a supported name (previously
            this fell through and silently returned ``None``).
    """
    if meta_split is None:
        meta_split = 'train'
    meta_train = meta_split == 'train'
    meta_val = meta_split == 'val'
    meta_test = meta_split == 'test'
    dataset = args.dataset
    # mini-imagenet's validation split only provides 16 classes.
    if dataset == 'miniimagenet' and meta_val and args.num_classes > 16:
        args.num_classes = 16
        print(
            'set num classes of mini_imagenet val to 16 because is the maximum'
        )
    # Keyword arguments shared by every torchmeta helper below.
    dataset_kwargs = dict(
        folder=DATAFOLDER,
        shots=args.support_samples,
        ways=args.num_classes,
        shuffle=True,
        test_shots=args.query_samples,
        seed=args.seed,
        target_transform=Categorical(num_classes=args.num_classes),
        download=True,
        meta_train=meta_train,
        meta_val=meta_val,
        meta_test=meta_test)
    if dataset == 'omniglot':
        # Standard omniglot class augmentation: rotated copies of each class.
        return omniglot(
            **dataset_kwargs,
            class_augmentations=[Rotation([90, 180, 270])],
        )
    elif dataset == 'miniimagenet':
        tg.set_dim('NUM_FEATURES', 1600)  # backbone feature size for this dataset
        return miniimagenet(**dataset_kwargs)
    elif dataset.upper() == 'CUB':
        if args.support_samples == 0:
            # Zero-shot setting: precomputed embeddings instead of images.
            from cub_dataset import CubDatasetEmbeddingsZeroShot
            print('Instantiating CubDatasetEmbeddingsZeroShot')
            return CubDatasetEmbeddingsZeroShot(DATAFOLDER, meta_split,
                                                args.query_samples,
                                                args.num_classes)
        else:
            return cub(**dataset_kwargs)
    raise ValueError('Unknown dataset: {}'.format(dataset))
def main():
    """Hyperparameter-optimization (hypergradient) meta-training entry point.

    Trains a meta-model on omniglot or miniimagenet, computing hypergradients
    either with conjugate gradient (CG) or fixed-point iteration, and
    periodically evaluates on the meta-test split.
    """
    parser = argparse.ArgumentParser(description='Data HyperCleaner')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--dataset',
                        type=str,
                        default='omniglot',
                        metavar='N',
                        help='omniglot or miniimagenet')
    parser.add_argument('--hg-mode',
                        type=str,
                        default='CG',
                        metavar='N',
                        help='hypergradient approximation: CG or fixed_point')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    args = parser.parse_args()

    # Logging / evaluation cadence (in outer-loop steps).
    log_interval = 100
    eval_interval = 500
    inner_log_interval = None
    inner_log_interval_test = None
    ways = 5  # classes per few-shot task
    batch_size = 16  # tasks per meta-batch
    n_tasks_test = 1000  # usually 1000 tasks are used for testing

    # Per-dataset hyperparameters: reg_param is the inner-objective
    # regularization weight, T the inner steps, K the hypergradient steps.
    if args.dataset == 'omniglot':
        reg_param = 2  # reg_param = 2
        T, K = 16, 5  # T, K = 16, 5
    elif args.dataset == 'miniimagenet':
        reg_param = 0.5  # reg_param = 0.5
        T, K = 10, 5  # T, K = 10, 5
    else:
        raise NotImplementedError(args.dataset, " not implemented!")
    T_test = T
    inner_lr = .1

    # Dump the local hyperparameter configuration for the run log.
    loc = locals()
    del loc['parser']
    del loc['args']
    print(args, '\n', loc, '\n')

    cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

    # the following are for reproducibility on GPU, see https://pytorch.org/docs/master/notes/randomness.html
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    torch.random.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Dataset-specific datasets + backbone (1-shot, 15 query per class).
    if args.dataset == 'omniglot':
        dataset = omniglot("data",
                           ways=ways,
                           shots=1,
                           test_shots=15,
                           meta_train=True,
                           download=True)
        test_dataset = omniglot("data",
                                ways=ways,
                                shots=1,
                                test_shots=15,
                                meta_test=True,
                                download=True)
        meta_model = get_cnn_omniglot(64, ways).to(device)
    elif args.dataset == 'miniimagenet':
        dataset = miniimagenet("data",
                               ways=ways,
                               shots=1,
                               test_shots=15,
                               meta_train=True,
                               download=True)
        test_dataset = miniimagenet("data",
                                    ways=ways,
                                    shots=1,
                                    test_shots=15,
                                    meta_test=True,
                                    download=True)
        meta_model = get_cnn_miniimagenet(32, ways).to(device)
    else:
        raise NotImplementedError(
            "DATASET NOT IMPLEMENTED! only omniglot and miniimagenet ")

    dataloader = BatchMetaDataLoader(dataset, batch_size=batch_size, **kwargs)
    test_dataloader = BatchMetaDataLoader(test_dataset,
                                          batch_size=batch_size,
                                          **kwargs)

    # Outer (meta) optimizer over the meta-model parameters.
    outer_opt = torch.optim.Adam(params=meta_model.parameters())
    # outer_opt = torch.optim.SGD(lr=0.1, params=meta_model.parameters())
    inner_opt_class = hg.GradientDescent
    inner_opt_kwargs = {'step_size': inner_lr}

    def get_inner_opt(train_loss):
        # Fresh inner optimizer bound to one task's training loss.
        return inner_opt_class(train_loss, **inner_opt_kwargs)

    # Outer loop: one iteration per meta-batch of tasks.
    for k, batch in enumerate(dataloader):
        start_time = time.time()
        meta_model.train()
        tr_xs, tr_ys = batch["train"][0].to(device), batch["train"][1].to(
            device)
        tst_xs, tst_ys = batch["test"][0].to(device), batch["test"][1].to(
            device)
        outer_opt.zero_grad()
        val_loss, val_acc = 0, 0
        forward_time, backward_time = 0, 0
        for t_idx, (tr_x, tr_y, tst_x,
                    tst_y) in enumerate(zip(tr_xs, tr_ys, tst_xs, tst_ys)):
            start_time_task = time.time()
            # single task set up
            task = Task(reg_param,
                        meta_model, (tr_x, tr_y, tst_x, tst_y),
                        batch_size=tr_xs.shape[0])
            inner_opt = get_inner_opt(task.train_loss_f)
            # single task inner loop: adapt a detached copy of the
            # meta-parameters for T steps.
            params = [
                p.detach().clone().requires_grad_(True)
                for p in meta_model.parameters()
            ]
            last_param = inner_loop(meta_model.parameters(),
                                    params,
                                    inner_opt,
                                    T,
                                    log_interval=inner_log_interval)[-1]
            forward_time_task = time.time() - start_time_task
            # single task hypergradient computation (accumulates grads on
            # meta_model.parameters() as a side effect).
            if args.hg_mode == 'CG':
                # This is the approximation used in the paper CG stands for conjugate gradient
                cg_fp_map = hg.GradientDescent(loss_f=task.train_loss_f,
                                               step_size=1.)
                hg.CG(last_param,
                      list(meta_model.parameters()),
                      K=K,
                      fp_map=cg_fp_map,
                      outer_loss=task.val_loss_f)
            elif args.hg_mode == 'fixed_point':
                hg.fixed_point(last_param,
                               list(meta_model.parameters()),
                               K=K,
                               fp_map=inner_opt,
                               outer_loss=task.val_loss_f)
            backward_time_task = time.time(
            ) - start_time_task - forward_time_task
            # Running metrics over the meta-batch.
            val_loss += task.val_loss
            val_acc += task.val_acc / task.batch_size
            forward_time += forward_time_task
            backward_time += backward_time_task
        # Apply the accumulated hypergradients.
        outer_opt.step()
        step_time = time.time() - start_time
        if k % log_interval == 0:
            print(
                'MT k={} ({:.3f}s F: {:.3f}s, B: {:.3f}s) Val Loss: {:.2e}, Val Acc: {:.2f}.'
                .format(k, step_time, forward_time, backward_time, val_loss,
                        100. * val_acc))
        if k % eval_interval == 0:
            test_losses, test_accs = evaluate(
                n_tasks_test,
                test_dataloader,
                meta_model,
                T_test,
                get_inner_opt,
                reg_param,
                log_interval=inner_log_interval_test)
            print(
                "Test loss {:.2e} +- {:.2e}: Test acc: {:.2f} +- {:.2e} (mean +- std over {} tasks)."
                .format(test_losses.mean(), test_losses.std(),
                        100. * test_accs.mean(), 100. * test_accs.std(),
                        len(test_losses)))
import csv
from torchmeta.datasets.helpers import miniimagenet
from torchmeta.utils.data import BatchMetaDataLoader
from maml import MAML
from train import adaptation, test
import pickle

# Pin the run to the second GPU when available; fall back to CPU.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

# dataset
# 5-way 5-shot tasks with 15 query images per class; meta-batch of 2 tasks.
trainset = miniimagenet("data",
                        ways=5,
                        shots=5,
                        test_shots=15,
                        meta_train=True,
                        download=True)
trainloader = BatchMetaDataLoader(trainset,
                                  batch_size=2,
                                  num_workers=4,
                                  shuffle=True)
testset = miniimagenet("data",
                       ways=5,
                       shots=5,
                       test_shots=15,
                       meta_test=True,
                       download=True)
# NOTE(review): the following statement is truncated in the source chunk --
# its remaining arguments are not visible here.
testloader = BatchMetaDataLoader(testset, batch_size=2,
def generate_dataloaders(args):
    """Create train/val/test ``BatchMetaDataLoader``s for the chosen dataset.

    Reads from ``args``: ``dataset``, ``data_path``, ``k_shot``, ``n_way``,
    ``k_query``, ``base_dataset_{train,val,test}``, ``seed``,
    ``tasks_per_metaupdate`` and ``n_workers``.

    The original code duplicated six near-identical dataset/dataloader
    constructions across the omniglot and miniimagenet branches; this
    version selects the dataset factory once and builds each split with a
    shared helper (behaviour unchanged).

    Returns:
        ``(dataloader_train, dataloader_val, dataloader_test)``

    Raises:
        Exception: If ``args.dataset`` is neither "omniglot" nor
            "miniimagenet".
    """
    if args.dataset == "omniglot":
        dataset_factory = omniglot
    elif args.dataset == "miniimagenet":
        dataset_factory = miniimagenet
    else:
        raise Exception("Dataset {} not implemented".format(args.dataset))

    def _make_loader(meta_split, with_seed):
        # Build one split's dataset and wrap it in a meta dataloader.
        # `with_seed`: the original code seeds only the train and val
        # datasets; the test split is deliberately left unseeded.
        dataset_kwargs = dict(
            folder=args.data_path,
            shots=args.k_shot,
            ways=args.n_way,
            shuffle=True,
            test_shots=args.k_query,
            meta_split=meta_split,
            download=True,  # Only downloads if not in data_path
        )
        if with_seed:
            dataset_kwargs["seed"] = args.seed
        dataset = dataset_factory(**dataset_kwargs)
        return BatchMetaDataLoader(
            dataset,
            batch_size=args.tasks_per_metaupdate,
            shuffle=True,
            num_workers=args.n_workers,
        )

    dataloader_train = _make_loader(args.base_dataset_train, with_seed=True)
    dataloader_val = _make_loader(args.base_dataset_val, with_seed=True)
    dataloader_test = _make_loader(args.base_dataset_test, with_seed=False)
    return dataloader_train, dataloader_val, dataloader_test
from torchmeta.datasets.helpers import miniimagenet
from torchmeta.utils.data import BatchMetaDataLoader
from meta_dataloader import get_meta_loader

data_path = 'few_data/'
# 5-way 5-shot tasks, 15 query images per class, one task per batch.
dataset = miniimagenet(data_path,
                       ways=5,
                       shots=5,
                       test_shots=15,
                       meta_train=True,
                       download=True)
dataloader = BatchMetaDataLoader(dataset, batch_size=1, num_workers=4)
# for i, batch in enumerate(dataloader):
#     train_inputs, train_targets = batch["train"]
#     print('Train inputs shape: {0}'.format(train_inputs.shape))  # (16, 25, 1, 28, 28)
#     print('Train targets shape: {0}'.format(train_targets.shape))  # (16, 25)
#
#     test_inputs, test_targets = batch["test"]
#     print('Test inputs shape: {0}'.format(test_inputs.shape))  # (16, 75, 1, 28, 28)
#     print('Test targets shape: {0}'.format(test_targets.shape))  # (16, 75)
#
#     print(train_targets)
#     if i > 5:
#         break
# NOTE(review): the following statement is truncated in the source chunk --
# its remaining arguments are not visible here.
metadataloader = get_meta_loader(data_path, 'miniimagenet', ways=5, shots=5,
def load_dataset(args, mode):
    """Build the torchmeta dataset selected by ``args.dataset`` for a split.

    Fixes two defects of the original: the identical keyword arguments were
    duplicated across ten branches, and an unrecognized dataset name reached
    ``return dataset`` with ``dataset`` unbound (UnboundLocalError); now a
    ``ValueError`` is raised instead.

    Args:
        args: Namespace with ``folder``, ``num_ways``, ``num_shots``,
            ``download`` and ``dataset``.  Mutated: the boolean flags
            ``meta_train`` / ``meta_val`` / ``meta_test`` are set from
            ``mode``.
        mode: One of 'meta_train', 'meta_valid' or 'meta_test'.  Any other
            value leaves the flags on ``args`` untouched (preserving the
            original permissive behaviour).

    Returns:
        The constructed torchmeta dataset.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    folder = args.folder
    ways = args.num_ways
    shots = args.num_shots
    test_shots = 15  # fixed query-set size per class
    download = args.download
    shuffle = True

    # Map the requested mode onto the three mutually exclusive split flags.
    split_flags = {
        'meta_train': (True, False, False),
        'meta_valid': (False, True, False),
        'meta_test': (False, False, True),
    }
    if mode in split_flags:
        args.meta_train, args.meta_val, args.meta_test = split_flags[mode]

    # Every dataset helper below takes the exact same keyword arguments.
    dataset_kwargs = dict(folder=folder,
                          shots=shots,
                          ways=ways,
                          shuffle=shuffle,
                          test_shots=test_shots,
                          meta_train=args.meta_train,
                          meta_val=args.meta_val,
                          meta_test=args.meta_test,
                          download=download)

    # Branches kept explicit (rather than a dispatch dict) so that only the
    # selected helper's name is resolved, as in the original.
    if args.dataset == 'miniimagenet':
        return miniimagenet(**dataset_kwargs)
    elif args.dataset == 'tieredimagenet':
        return tieredimagenet(**dataset_kwargs)
    elif args.dataset == 'cifar_fs':
        return cifar_fs(**dataset_kwargs)
    elif args.dataset == 'fc100':
        return fc100(**dataset_kwargs)
    elif args.dataset == 'cub':
        return cub(**dataset_kwargs)
    elif args.dataset == 'vgg_flower':
        return vgg_flower(**dataset_kwargs)
    elif args.dataset == 'aircraft':
        return aircraft(**dataset_kwargs)
    elif args.dataset == 'traffic_sign':
        return traffic_sign(**dataset_kwargs)
    elif args.dataset == 'svhn':
        return svhn(**dataset_kwargs)
    elif args.dataset == 'cars':
        return cars(**dataset_kwargs)
    raise ValueError('Unknown dataset: {}'.format(args.dataset))
def train(args):
    """MAML meta-training on mini-ImageNet with TensorBoard logging.

    Runs ``args.num_batches`` outer steps; each meta-batch adapts the model
    per task with one inner gradient step and accumulates the query-set loss
    for the outer update.  Optionally saves the final weights.
    """
    dataset = miniimagenet(args.folder,
                           shots=args.num_shots,
                           ways=args.num_ways,
                           shuffle=True,
                           test_shots=15,
                           meta_train=True,
                           download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    # 3 input channels; 84 presumably the input image size -- TODO confirm
    # against ConvolutionalNeuralNetwork's signature.
    model = ConvolutionalNeuralNetwork(3,
                                       84,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            # Clear gradients accumulated by the previous outer step.
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            # Inner loop: one adaptation per task in the meta-batch.
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(train_inputs, train_targets, test_inputs,
                                   test_targets)):
                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                # writer.add_scalar('Loss/inner_loss', np.random.random(), task_idx)
                # NOTE(review): logging the image grid and the graph on every
                # task of every batch is expensive -- confirm this is intended.
                grid = torchvision.utils.make_grid(train_input)
                writer.add_image('images', grid, 0)
                writer.add_graph(model, train_input)

                model.zero_grad()
                # One (first-order optional) gradient step on the support set.
                params = update_parameters(model,
                                           inner_loss,
                                           step_size=args.step_size,
                                           first_order=args.first_order)

                # Query-set loss with the adapted parameters.
                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                # writer.add_scalar('Loss/outer_loss', np.random.random(), n_iter)
                for name, grads in model.meta_named_parameters():
                    writer.add_histogram(name, grads, batch_idx)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)
                # NOTE(review): `grads` here is the stale last value from the
                # loop above (and holds a parameter tensor, not gradients) --
                # confirm which tensor was meant to be logged.
                writer.add_histogram('meta parameters', grads, batch_idx)

            # Average loss/accuracy over the meta-batch before the outer step.
            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)
            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            writer.add_scalar('Accuracy/test', accuracy.item(), batch_idx)
            if batch_idx >= args.num_batches:
                break
    writer.close()

    # Save model
    if args.output_folder is not None:
        # NOTE(review): filename says "omniglot" although this trains on
        # mini-ImageNet -- confirm before relying on the saved path.
        filename = os.path.join(
            args.output_folder, 'maml_omniglot_'
            '{0}shot_{1}way.pt'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
def train(args):
    """MAML meta-training on mini-ImageNet, tracking the best batch accuracy.

    Runs ``args.num_batches`` outer steps; each meta-batch adapts the model
    per task with one inner gradient step and accumulates the query-set loss
    for the outer update.  Prints the maximum batch accuracy seen and
    optionally saves the final weights.
    """
    dataset = miniimagenet(args.folder,
                           shots=args.num_shots,
                           ways=args.num_ways,
                           shuffle=True,
                           test_shots=15,
                           meta_train=True,
                           download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    # 3 input channels; fc_in_size/conv_kernel configure the backbone for
    # 84x84 mini-ImageNet inputs.
    model = ConvolutionalNeuralNetwork(3,
                                       args.num_ways,
                                       hidden_size=args.hidden_size,
                                       fc_in_size=32 * 5 * 5,
                                       conv_kernel=[3, 3, 3, 2])
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    max_acc = 0
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            # Clear gradients accumulated by the previous outer step.
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            # Inner loop: one adaptation per task in the meta-batch.
            for task_idx, (train_input, train_target, test_input, test_target) in \
                    enumerate(zip(train_inputs, train_targets, test_inputs, test_targets)):
                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                # One (first-order optional) gradient step on the support set.
                params = update_parameters(model,
                                           inner_loss,
                                           step_size=args.step_size,
                                           first_order=args.first_order)

                # Query-set loss with the adapted parameters.
                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            # Average loss/accuracy over the meta-batch before the outer step.
            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)
            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            max_acc = max(max_acc, accuracy.item())
            if batch_idx >= args.num_batches:
                break
    print('max acc during training is: ', max_acc)

    # Save model
    if args.output_folder is not None:
        # NOTE(review): filename says "omniglot" although this trains on
        # mini-ImageNet -- confirm before relying on the saved path.
        filename = os.path.join(
            args.output_folder, 'maml_omniglot_'
            '{0}shot_{1}way.pt'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)