def train(model: nn.Module, loader: DataLoader, class_loss: nn.Module, optimizer: Optimizer, scheduler: _LRScheduler, epoch: int, callback: VisdomLogger, freq: int, ex: Experiment = None) -> None:
    """Train ``model`` for one epoch over ``loader``.

    Args:
        model: network returning ``(logits, features)`` for an input batch.
        loader: training loader yielding ``(batch, labels, indices)`` triples.
        class_loss: per-sample classification loss; reduced here with ``.mean()``.
        optimizer: stepped once per batch.
        scheduler: LR scheduler, also stepped once per batch.
        epoch: current epoch index (progress-bar title and logging x-axis).
        callback: optional Visdom logger; running averages are sent every
            ``freq`` batches.
        freq: logging period in batches (ignored when ``callback`` is None).
        ex: optional sacred ``Experiment``; the whole epoch's per-batch
            loss/accuracy history is replayed into it at the end.
    """
    model.train()
    device = next(model.parameters()).device

    def to_device(tensor):
        # Follow the model's device; non_blocking pairs with pinned-memory
        # loaders for asynchronous host-to-device copies.
        return tensor.to(device, non_blocking=True)

    loader_length = len(loader)
    train_losses = AverageMeter(device=device, length=loader_length)
    train_accs = AverageMeter(device=device, length=loader_length)

    pbar = tqdm(loader, ncols=80, desc='Training [{:03d}]'.format(epoch))
    for i, (batch, labels, indices) in enumerate(pbar):
        batch, labels, indices = map(to_device, (batch, labels, indices))
        logits, features = model(batch)  # features unused in this loop
        loss = class_loss(logits, labels).mean()
        acc = (logits.detach().argmax(1) == labels).float().mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # per-batch LR schedule

        train_losses.append(loss)
        train_accs.append(acc)

        if callback is not None and (i + 1) % freq == 0:
            # Fractional epoch as the x-axis value.
            step = epoch + i / loader_length
            callback.scalar('xent', step, train_losses.last_avg,
                            title='Train Losses')
            callback.scalar('train_acc', step, train_accs.last_avg,
                            title='Train Acc')

    if ex is not None:
        # Replay the epoch's recorded per-batch values into sacred.
        # Distinct loop names avoid shadowing the training-loop variables.
        for batch_idx, (batch_loss, batch_acc) in enumerate(
                zip(train_losses.values_list, train_accs.values_list)):
            step = epoch + batch_idx / loader_length
            ex.log_scalar('train.loss', batch_loss, step=step)
            ex.log_scalar('train.acc', batch_acc, step=step)
def episodic_validate(args: argparse.Namespace, val_loader: torch.utils.data.DataLoader, model: DDP, use_callback: bool, suffix: str = 'test') -> Tuple[torch.tensor, torch.tensor]:
    """Episodic (few-shot) evaluation of ``model``, repeated ``args.n_runs`` times.

    Each run performs ``test_num / batch_size_val`` episodes. An episode
    assembles a batch of support/query tasks from ``val_loader``, extracts
    features through ``model.module``, runs RePRI inference with a fresh
    ``Classifier`` and accumulates per-class IoU and cross-entropy loss.

    Args:
        args: experiment namespace (reads ``test_num``, ``batch_size_val``,
            ``n_runs``, ``shot``, ``image_size``, ``norm_feat``,
            ``visdom_port``, ``save_oracle``).
        val_loader: task loader yielding
            ``(qry_img, q_label, spprt_imgs, s_label, subcls, _, _)``.
        model: DDP-wrapped segmentation network.
        use_callback: create a VisdomLogger for the classifier's live plots.
        suffix: kept for API compatibility; unused in this function.

    Returns:
        ``(mean mIoU over runs, mean CE loss over runs)``.
    """
    print('==> Start testing')
    model.eval()
    nb_episodes = int(args.test_num / args.batch_size_val)
    # One GPU per process: the rank index doubles as the CUDA device id.
    device = dist.get_rank()

    # ========== Metrics initialization ==========
    H, W = args.image_size, args.image_size
    c = model.module.bottleneck_dim
    h = model.module.feature_res[0]
    w = model.module.feature_res[1]

    runtimes = torch.zeros(args.n_runs)
    deltas_init = torch.zeros((args.n_runs, nb_episodes, args.batch_size_val))
    deltas_final = torch.zeros((args.n_runs, nb_episodes, args.batch_size_val))
    val_IoUs = np.zeros(args.n_runs)
    val_losses = np.zeros(args.n_runs)

    # Create the optional live-plotting callback once (it was previously
    # re-instantiated on every episode).
    callback = VisdomLogger(port=args.visdom_port) if use_callback else None

    # Treat the loader as an endless task stream: restart the iterator when
    # exhausted. Replaces a py2-style ``.next()`` guarded by a bare except.
    iter_loader = iter(val_loader)

    # ========== Perform the runs ==========
    for run in tqdm(range(args.n_runs)):

        # =============== Initialize the metric dictionaries ===============
        loss_meter = AverageMeter()
        iter_num = 0
        cls_intersection = defaultdict(int)  # default value is 0
        cls_union = defaultdict(int)
        IoU = defaultdict(int)

        # =============== episode = group of tasks ===============
        runtime = 0
        for e in tqdm(range(nb_episodes)):
            t0 = time.time()

            # Pre-allocated episode buffers; 255 marks "ignore" pixels.
            features_s = torch.zeros(args.batch_size_val, args.shot, c, h, w).to(device)
            features_q = torch.zeros(args.batch_size_val, 1, c, h, w).to(device)
            gt_s = 255 * torch.ones(args.batch_size_val, args.shot,
                                    args.image_size, args.image_size).long().to(device)
            gt_q = 255 * torch.ones(args.batch_size_val, 1,
                                    args.image_size, args.image_size).long().to(device)
            n_shots = torch.zeros(args.batch_size_val).to(device)
            classes = []  # All classes considered in the tasks

            # ====== Generate tasks and extract features for each task ======
            for i in range(args.batch_size_val):
                try:
                    qry_img, q_label, spprt_imgs, s_label, subcls, _, _ = next(iter_loader)
                except StopIteration:
                    iter_loader = iter(val_loader)
                    qry_img, q_label, spprt_imgs, s_label, subcls, _, _ = next(iter_loader)
                iter_num += 1

                q_label = q_label.to(device, non_blocking=True)
                spprt_imgs = spprt_imgs.to(device, non_blocking=True)
                s_label = s_label.to(device, non_blocking=True)
                qry_img = qry_img.to(device, non_blocking=True)

                f_s = model.module.extract_features(spprt_imgs.squeeze(0))
                f_q = model.module.extract_features(qry_img)

                shot = f_s.size(0)
                n_shots[i] = shot
                features_s[i, :shot] = f_s.detach()
                features_q[i] = f_q.detach()
                gt_s[i, :shot] = s_label
                gt_q[i, 0] = q_label
                classes.append([class_.item() for class_ in subcls])

            # ====== Normalize features along channel dimension ======
            if args.norm_feat:
                features_s = F.normalize(features_s, dim=2)
                features_q = F.normalize(features_q, dim=2)

            # ====== Initialize the classifier + prototypes + F/B parameter ======
            classifier = Classifier(args)
            classifier.init_prototypes(features_s, features_q, gt_s, gt_q,
                                       classes, callback)
            batch_deltas = classifier.compute_FB_param(features_q=features_q, gt_q=gt_q)
            deltas_init[run, e, :] = batch_deltas.cpu()

            # ====== Perform RePRI inference ======
            batch_deltas = classifier.RePRI(features_s, features_q, gt_s, gt_q,
                                            classes, n_shots, callback)
            # .cpu() for consistency with deltas_init: the destination tensor
            # lives on the CPU.
            deltas_final[run, e, :] = batch_deltas.cpu()
            t1 = time.time()
            runtime += t1 - t0

            logits = classifier.get_logits(features_q)  # [n_tasks, shot, h, w]
            logits = F.interpolate(logits, size=(H, W), mode='bilinear',
                                   align_corners=True)
            probas = classifier.get_probas(logits).detach()
            intersection, union, _ = batch_intersectionAndUnionGPU(
                probas, gt_q, 2)  # [n_tasks, shot, num_class]
            intersection, union = intersection.cpu(), union.cpu()

            # ================== Log metrics ==================
            one_hot_gt = to_one_hot(gt_q, 2)
            valid_pixels = gt_q != 255
            loss = classifier.get_ce(probas, valid_pixels, one_hot_gt,
                                     reduction='mean')
            loss_meter.update(loss.item())

            for i, task_classes in enumerate(classes):
                for j, class_ in enumerate(task_classes):
                    # j + 1: channel 0 is background and is not counted.
                    cls_intersection[class_] += intersection[i, 0, j + 1]
                    cls_union[class_] += union[i, 0, j + 1]

            for class_ in cls_union:
                IoU[class_] = cls_intersection[class_] / (cls_union[class_] + 1e-10)

            if iter_num % 200 == 0:
                mIoU = np.mean([IoU[i] for i in IoU])
                print('Test: [{}/{}] '
                      'mIoU {:.4f} '
                      'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) '.format(
                          iter_num, args.test_num, mIoU, loss_meter=loss_meter))

        runtimes[run] = runtime
        mIoU = np.mean(list(IoU.values()))
        print('mIoU---Val result: mIoU {:.4f}.'.format(mIoU))
        for class_ in cls_union:
            print("Class {} : {:.4f}".format(class_, IoU[class_]))
        val_IoUs[run] = mIoU
        val_losses[run] = loss_meter.avg

    # ================== Save metrics ==================
    if args.save_oracle:
        root = os.path.join('plots', 'oracle')
        os.makedirs(root, exist_ok=True)
        np.save(os.path.join(root, 'delta_init.npy'), deltas_init.numpy())
        np.save(os.path.join(root, 'delta_final.npy'), deltas_final.numpy())

    print('Average mIoU over {} runs --- {:.4f}.'.format(args.n_runs, val_IoUs.mean()))
    print('Average runtime / run --- {:.4f}.'.format(runtimes.mean()))
    return val_IoUs.mean(), val_losses.mean()
download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False, num_workers=0) model_pgp = alexnet().to(cuda) epochs = 170 prune_rate = 0.05 remove_ratio = 0.5 optimizer_pgp = optim.SGD(model_pgp.parameters(), lr=lr, momentum=momentum) zero_initializer = functools.partial(torch.nn.init.constant_, val=0) logger = VisdomLogger(port=10999) logger = LoggerForSacred(logger) if not os.path.exists("temp_target.p"): loss_acc, model_architecture = pgp.pgp(epochs, prune_rate, remove_ratio, testloader, gng_criterion, cuda=cuda, model=model_pgp, train_loader=trainloader, train_ratio=1, prune_once=False, initializer_fn=zero_initializer,
def main(epochs, cpu, cudnn_flag, visdom_port, visdom_freq, temp_dir, seed, no_bias_decay, label_smoothing, temperature):
    """Train a classification/metric-learning model and return the best
    cosine recall@1 on the validation set.

    Args:
        epochs: number of training epochs.
        cpu: force CPU even when CUDA is available.
        cudnn_flag: 'deterministic' or 'benchmark' cuDNN mode.
        visdom_port: Visdom port, falsy to disable live plotting.
        visdom_freq: batch period for Visdom logging inside ``train``.
        temp_dir: directory where the best model and Visdom data are saved.
        seed: RNG seed, re-applied before each stochastic stage for
            reproducibility across stages.
        no_bias_decay: exempt 1-d parameters (biases, norm weights) from
            weight decay.
        label_smoothing, temperature: SmoothCrossEntropy hyper-parameters.

    Returns:
        Best cosine recall@1 observed across all validations.
    """
    device = torch.device(
        'cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
    callback = VisdomLogger(port=visdom_port) if visdom_port else None
    if cudnn_flag == 'deterministic':
        setattr(cudnn, cudnn_flag, True)

    torch.manual_seed(seed)
    loaders, recall_ks = get_loaders()

    torch.manual_seed(seed)
    model = get_model(num_classes=loaders.num_classes)
    class_loss = SmoothCrossEntropy(epsilon=label_smoothing,
                                    temperature=temperature)
    model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # Optionally exempt 1-d parameters (biases, BN/LN weights) from decay.
    parameters = []
    if no_bias_decay:
        parameters.append(
            {'params': [par for par in model.parameters() if par.dim() != 1]})
        parameters.append({
            'params': [par for par in model.parameters() if par.dim() == 1],
            'weight_decay': 0
        })
    else:
        parameters.append({'params': model.parameters()})

    optimizer, scheduler = get_optimizer_scheduler(
        parameters=parameters, loader_length=len(loaders.train))

    # setup partial function to simplify call
    eval_function = partial(evaluate, model=model, recall=recall_ks,
                            query_loader=loaders.query,
                            gallery_loader=loaders.gallery)

    # Baseline validation before any training.
    metrics = eval_function()
    if callback is not None:
        callback.scalars(
            ['l2', 'cosine'], 0,
            [metrics.recall['l2'][1], metrics.recall['cosine'][1]],
            title='Val Recall@1')
    pprint(metrics.recall)
    best_val = (0, metrics.recall, deepcopy(model.state_dict()))

    torch.manual_seed(seed)
    for epoch in range(epochs):
        if cudnn_flag == 'benchmark':
            setattr(cudnn, cudnn_flag, True)

        train(model=model, loader=loaders.train, class_loss=class_loss,
              optimizer=optimizer, scheduler=scheduler, epoch=epoch,
              callback=callback, freq=visdom_freq, ex=ex)

        # validation; benchmark off since eval batch shapes may differ
        if cudnn_flag == 'benchmark':
            setattr(cudnn, cudnn_flag, False)
        metrics = eval_function()

        # was ``print(...),`` — a py2 leftover producing a useless tuple
        print('Validation [{:03d}]'.format(epoch))
        pprint(metrics.recall)

        ex.log_scalar('val.recall_l2@1', metrics.recall['l2'][1],
                      step=epoch + 1)
        ex.log_scalar('val.recall_cosine@1', metrics.recall['cosine'][1],
                      step=epoch + 1)
        if callback is not None:
            callback.scalars(
                ['l2', 'cosine'], epoch + 1,
                [metrics.recall['l2'][1], metrics.recall['cosine'][1]],
                title='Val Recall')

        # save model dict if the chosen validation metric is better
        if metrics.recall['cosine'][1] >= best_val[1]['cosine'][1]:
            best_val = (epoch + 1, metrics.recall, deepcopy(model.state_dict()))

    # logging
    ex.info['recall'] = best_val[1]

    # saving
    save_name = os.path.join(
        temp_dir, '{}_{}.pt'.format(ex.current_run.config['model']['arch'],
                                    ex.current_run.config['dataset']['name']))
    torch.save(state_dict_to_cpu(best_val[2]), save_name)
    ex.add_artifact(save_name)

    if callback is not None:
        save_name = os.path.join(temp_dir, 'visdom_data.pt')
        callback.save(save_name)
        ex.add_artifact(save_name)

    return best_val[1]['cosine'][1]
def main_worker(rank: int, world_size: int, args: argparse.Namespace) -> None:
    """Per-process entry point for (optionally distributed) training.

    Builds the model/optimizer, trains for ``args.epochs`` epochs, validates
    after every epoch, keeps the best checkpoint by validation mIoU, and
    persists the metric curves as ``.npy`` files under the model directory.

    Args:
        rank: this process's rank; also used as its CUDA device index.
        world_size: total number of processes.
        args: full experiment namespace.
    """
    print(f"==> Running process rank {rank}.")
    setup(args, rank, world_size)

    if args.manual_seed is not None:
        # Deterministic mode; per-rank seed offsets keep shuffling distinct.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.cuda.manual_seed(args.manual_seed + rank)
        np.random.seed(args.manual_seed + rank)
        torch.manual_seed(args.manual_seed + rank)
        torch.cuda.manual_seed_all(args.manual_seed + rank)
        random.seed(args.manual_seed + rank)

    callback = None if args.visdom_port == -1 else VisdomLogger(
        port=args.visdom_port)

    # ========== Model + Optimizer ==========
    model = get_model(args).to(rank)
    modules_ori = [
        model.layer0, model.layer1, model.layer2, model.layer3, model.layer4
    ]
    modules_new = [model.ppm, model.bottleneck, model.classifier]
    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.lr))
    for module in modules_new:
        # Freshly initialized heads train with a scaled-up learning rate.
        params_list.append(
            dict(params=module.parameters(), lr=args.lr * args.scale_lr))
    optimizer = get_optimizer(args, params_list)

    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DDP(model, device_ids=[rank])

    savedir = get_model_dir(args)

    # ========== Validation ==================
    validate_fn = episodic_validate if args.episodic_val else standard_validate

    # ========== Data =====================
    train_loader, train_sampler = get_train_loader(args)
    # mode='train' means that we will validate on images from the validation
    # set, but with the base classes
    val_loader, _ = get_val_loader(args)

    # ========== Scheduler ================
    scheduler = get_scheduler(args, optimizer, len(train_loader))

    # ========== Metrics initialization ====
    max_val_mIoU = 0.
    iter_per_epoch = 5 if args.debug else len(train_loader)
    log_iter = int(iter_per_epoch / args.log_freq) + 1

    metrics: Dict[str, Tensor] = {
        "val_mIou": torch.zeros((args.epochs, 1)).type(torch.float32),
        "val_loss": torch.zeros((args.epochs, 1)).type(torch.float32),
        "train_mIou": torch.zeros((args.epochs, log_iter)).type(torch.float32),
        "train_loss": torch.zeros((args.epochs, log_iter)).type(torch.float32),
    }

    # ========== Training =================
    for epoch in tqdm(range(args.epochs)):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        train_mIou, train_loss = do_epoch(args=args,
                                          train_loader=train_loader,
                                          iter_per_epoch=iter_per_epoch,
                                          model=model,
                                          optimizer=optimizer,
                                          scheduler=scheduler,
                                          epoch=epoch,
                                          callback=callback,
                                          log_iter=log_iter)

        val_mIou, val_loss = validate_fn(args=args,
                                         val_loader=val_loader,
                                         model=model,
                                         use_callback=False,
                                         suffix=f'train_{epoch}')
        if args.distributed:
            # Average the validation metrics across processes.
            dist.all_reduce(val_mIou)
            dist.all_reduce(val_loss)
            val_mIou /= world_size
            val_loss /= world_size

        if main_process(args):
            # Live plot if desired with visdom
            if callback is not None:
                callback.scalar('val_loss', epoch, val_loss,
                                title='Validiation Loss')
                callback.scalar('mIoU_val', epoch, val_mIou,
                                title='Val metrics')

            # Model selection on validation mIoU.
            if val_mIou.item() > max_val_mIoU:
                max_val_mIoU = val_mIou.item()
                os.makedirs(savedir, exist_ok=True)
                filename = os.path.join(savedir, 'best.pth')
                if args.save_models:
                    print('Saving checkpoint to: ' + filename)
                    torch.save(
                        {
                            'epoch': epoch,
                            'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        }, filename)
            print("=> Max_mIoU = {:.3f}".format(max_val_mIoU))

            # Record and persist the metrics. An explicit mapping replaces
            # the previous ``eval(k)`` lookup of local variables.
            epoch_values = {
                "val_mIou": val_mIou,
                "val_loss": val_loss,
                "train_mIou": train_mIou,
                "train_loss": train_loss,
            }
            for k in metrics:
                metrics[k][epoch] = epoch_values[k]
            for k, e in metrics.items():
                path = os.path.join(savedir, f"{k}.npy")
                np.save(path, e.cpu().numpy())

    if args.save_models and main_process(args):
        filename = os.path.join(savedir, 'final.pth')
        # was an f-string with no placeholder, which dropped the path
        print(f'Saving checkpoint to: {filename}')
        torch.save(
            {
                'epoch': args.epochs,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, filename)
    cleanup()
def do_epoch(args: argparse.Namespace, train_loader: torch.utils.data.DataLoader, model: DDP, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, epoch: int, callback: VisdomLogger, iter_per_epoch: int, log_iter: int) -> Tuple[torch.tensor, torch.tensor]:
    """Train ``model`` for one epoch and return sampled metric curves.

    Every ``args.log_freq`` iterations the current batch is re-scored in
    eval mode; the (optionally all-reduced) mIoU and running loss average
    are stored into fixed-size curves of length ``log_iter``.

    Args:
        args: experiment namespace (scheduler type, log_freq, distributed, ...).
        train_loader: training loader yielding ``(images, gt)`` pairs.
        model: DDP-wrapped network.
        optimizer: stepped once per batch.
        scheduler: stepped per batch if cosine, otherwise once per epoch.
        epoch: current epoch index (for the global iteration counter).
        callback: optional Visdom logger (main process only).
        iter_per_epoch: number of batches to run this epoch.
        log_iter: length of the returned metric curves.

    Returns:
        ``(train_mIous, train_losses)`` tensors of shape ``(log_iter,)``.
    """
    loss_meter = AverageMeter()
    train_losses = torch.zeros(log_iter).to(dist.get_rank())
    train_mIous = torch.zeros(log_iter).to(dist.get_rank())

    iterable_train_loader = iter(train_loader)
    if main_process(args):
        bar = tqdm(range(iter_per_epoch))
    else:
        bar = range(iter_per_epoch)

    for i in bar:
        model.train()
        current_iter = epoch * len(train_loader) + i + 1

        # was ``iterable_train_loader.next()``: the py2-style method was
        # removed from Python-3 DataLoader iterators.
        images, gt = next(iterable_train_loader)
        images = images.to(dist.get_rank(), non_blocking=True)
        gt = gt.to(dist.get_rank(), non_blocking=True)

        loss = compute_loss(
            args=args,
            model=model,
            images=images,
            targets=gt.long(),
            num_classes=args.num_classes_tr,
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if args.scheduler == 'cosine':
            scheduler.step()  # cosine schedule advances per batch

        if i % args.log_freq == 0:
            model.eval()
            with torch.no_grad():  # metrics only: skip autograd bookkeeping
                logits = model(images)
            intersection, union, target = intersectionAndUnionGPU(
                logits.argmax(1), gt, args.num_classes_tr, 255)
            if args.distributed:
                dist.all_reduce(loss)
                dist.all_reduce(intersection)
                dist.all_reduce(union)
                dist.all_reduce(target)

            allAcc = (intersection.sum() / (target.sum() + 1e-10))  # scalar
            mAcc = (intersection / (target + 1e-10)).mean()
            mIoU = (intersection / (union + 1e-10)).mean()
            loss_meter.update(loss.item() / dist.get_world_size())

            if main_process(args):
                if callback is not None:
                    t = current_iter / len(train_loader)
                    callback.scalar('loss_train_batch', t, loss_meter.avg,
                                    title='Loss')
                    callback.scalars(['mIoU', 'mAcc', 'allAcc'], t,
                                     [mIoU, mAcc, allAcc],
                                     title='Training metrics')
                    # Only the first parameter group's LR is plotted.
                    for index, param_group in enumerate(
                            optimizer.param_groups):
                        lr = param_group['lr']
                        callback.scalar('lr', t, lr, title='Learning rate')
                        break

            train_losses[int(i / args.log_freq)] = loss_meter.avg
            train_mIous[int(i / args.log_freq)] = mIoU

    if args.scheduler != 'cosine':
        scheduler.step()  # step-wise schedules advance once per epoch
    return train_mIous, train_losses
def main(seed, pretrain, resume, evaluate, print_runtime, epochs, disable_tqdm, visdom_port, ckpt_path, make_plot, cuda):
    """Entry point: optionally load or resume a model, then evaluate or train.

    Args:
        seed: RNG seed; enables deterministic cuDNN when not None.
        pretrain: directory holding ``checkpoint.pth.tar`` for warm-starting
            (only matching parameter names are loaded).
        resume: resume optimizer/model state from ``ckpt_path``.
        evaluate: run the full evaluation and return, skipping training.
        print_runtime: accepted for config compatibility; unused here.
        epochs: total number of training epochs.
        disable_tqdm: suppress progress bars.
        visdom_port: Visdom port, or None to disable live plotting.
        ckpt_path: checkpoint directory for resuming/saving.
        make_plot: accepted for config compatibility; unused here.
        cuda: run on GPU when True.

    Returns:
        The evaluator's results (full evaluation output).
    """
    device = torch.device("cuda" if cuda else "cpu")
    callback = None if visdom_port is None else VisdomLogger(port=visdom_port)

    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        torch.cuda.set_device(0)

    # create model
    print("=> Creating model '{}'".format(
        ex.current_run.config['model']['arch']))
    model = torch.nn.DataParallel(get_model()).cuda()
    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))
    optimizer = get_optimizer(model)

    if pretrain:
        pretrain = os.path.join(pretrain, 'checkpoint.pth.tar')
        if os.path.isfile(pretrain):
            print("=> loading pretrained weight '{}'".format(pretrain))
            checkpoint = torch.load(pretrain)
            model_dict = model.state_dict()
            # Keep only weights whose names exist in the current architecture.
            params = checkpoint['state_dict']
            params = {k: v for k, v in params.items() if k in model_dict}
            model_dict.update(params)
            model.load_state_dict(model_dict)
        else:
            print(
                '[Warning]: Did not find pretrained model {}'.format(pretrain))

    # Defaults first, so a missing resume checkpoint cannot leave these
    # unbound (previously they were only set in the ``else`` branch).
    start_epoch = 0
    best_prec1 = -1
    if resume:
        resume_path = ckpt_path + '/checkpoint.pth.tar'
        if os.path.isfile(resume_path):
            print("=> loading checkpoint '{}'".format(resume_path))
            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_path, checkpoint['epoch']))
        else:
            print('[Warning]: Did not find checkpoint {}'.format(resume_path))

    cudnn.benchmark = True

    # Data loading code
    evaluator = Evaluator(device=device, ex=ex)
    if evaluate:
        print("Evaluating")
        results = evaluator.run_full_evaluation(model=model,
                                                model_path=ckpt_path,
                                                callback=callback)
        return results

    # If this line is reached, then training the model
    trainer = Trainer(device=device, ex=ex)
    scheduler = get_scheduler(optimizer=optimizer,
                              num_batches=len(trainer.train_loader),
                              epochs=epochs)
    tqdm_loop = warp_tqdm(list(range(start_epoch, epochs)),
                          disable_tqdm=disable_tqdm)
    for epoch in tqdm_loop:
        # Do one epoch
        trainer.do_epoch(model=model, optimizer=optimizer, epoch=epoch,
                         scheduler=scheduler, disable_tqdm=disable_tqdm,
                         callback=callback)

        # Evaluation on validation set
        prec1 = trainer.meta_val(model=model, disable_tqdm=disable_tqdm,
                                 epoch=epoch, callback=callback)
        print('Meta Val {}: {}'.format(epoch, prec1))
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if not disable_tqdm:
            tqdm_loop.set_description(
                'Best Acc {:.2f}'.format(best_prec1 * 100.))

        # Save checkpoint
        save_checkpoint(state={
            'epoch': epoch + 1,
            'arch': ex.current_run.config['model']['arch'],
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict()
        },
                        is_best=is_best,
                        folder=ckpt_path)
        if scheduler is not None:
            scheduler.step()

    # Final evaluation on test set
    results = evaluator.run_full_evaluation(model=model, model_path=ckpt_path)
    return results
def main(args):
    """Full MAML pipeline for signature verification: data selection,
    optional pretraining, meta-training with per-epoch validation, and a
    final test of the best checkpoint.

    Reads many fields from ``args`` (seed, logdir, dataset_path, dev_users,
    devset sizes, exp_users/use_testset, model, learning rates, epochs,
    num_updates, pretraining options, batch/task sizes, checkpoint, test).
    """
    rng = np.random.RandomState(args.seed)
    if args.test:
        assert args.checkpoint is not None, 'Please inform the checkpoint (trained model)'
    if args.logdir is None:
        logdir = get_logdir(args)
    else:
        logdir = pathlib.Path(args.logdir)
        if not logdir.exists():
            # NOTE(review): plain mkdir() — fails if the parent directory is
            # missing; consider mkdir(parents=True, exist_ok=True).
            logdir.mkdir()
    print('Writing logs to {}'.format(logdir))

    device = torch.device(
        'cuda', args.gpu_idx) if torch.cuda.is_available() else torch.device('cpu')

    if args.port is not None:
        logger = VisdomLogger(port=args.port)
    else:
        logger = None

    print('Loading Data')
    # usermapping and filenames are loaded but unused in this function.
    x, y, yforg, usermapping, filenames = load_dataset(args.dataset_path)
    dev_users = range(args.dev_users[0], args.dev_users[1])

    if args.devset_size is not None:
        # Randomly select users from the dev set
        dev_users = rng.choice(dev_users, args.devset_size, replace=False)

    if args.devset_sk_size is not None:
        assert args.devset_sk_size <= len(
            dev_users), 'devset-sk-size should be smaller than devset-size'

        # Randomly select users from the dev set to have skilled forgeries (others don't)
        dev_sk_users = set(
            rng.choice(dev_users, args.devset_sk_size, replace=False))
    else:
        dev_sk_users = set(dev_users)

    print('{} users in dev set; {} users with skilled forgeries'.format(
        len(dev_users), len(dev_sk_users)))

    # Choose the validation/test user range: explicit range, exploitation
    # set (0-300), or the default held-out range 300-350.
    if args.exp_users is not None:
        val_users = range(args.exp_users[0], args.exp_users[1])
        print('Testing with users from {} to {}'.format(
            args.exp_users[0], args.exp_users[1]))
    elif args.use_testset:
        val_users = range(0, 300)
        print('Testing with Exploitation set')
    else:
        val_users = range(300, 350)

    print('Initializing model')
    base_model = models.available_models[args.model]().to(device)
    weights = base_model.build_weights(device)
    maml = MAML(base_model, args.num_updates, args.num_updates, args.train_lr,
                args.meta_lr, args.meta_min_lr, args.epochs,
                args.learn_task_lr, weights, device, logger,
                loss_function=balanced_binary_cross_entropy,
                is_classification=True)

    if args.checkpoint:
        params = torch.load(args.checkpoint)
        maml.load(params)

    if args.test:
        # Test-only mode: evaluate the loaded checkpoint and exit.
        test_and_save(args, device, logdir, maml, val_users, x, y, yforg)
        return

    # Pretraining
    if args.pretrain_epochs > 0:
        print('Pre-training')
        # Users 350-881 are reserved for pretraining.
        data = util.get_subset((x, y, yforg), subset=range(350, 881))
        wrapped_model = PretrainWrapper(base_model, weights)

        if not args.pretrain_forg:
            data = util.remove_forgeries(data, forg_idx=2)

        train_loader, val_loader = pretrain.setup_data_loaders(
            data, 32, args.input_size)
        n_classes = len(np.unique(y))

        classification_layer = nn.Linear(base_model.feature_space_size,
                                         n_classes).to(device)
        if args.pretrain_forg:
            forg_layer = nn.Linear(base_model.feature_space_size,
                                   1).to(device)
        else:
            forg_layer = nn.Module()  # Stub module with no parameters

        pretrain_args = argparse.Namespace(lr=0.01,
                                           lr_decay=0.1,
                                           lr_decay_times=1,
                                           momentum=0.9,
                                           weight_decay=0.001,
                                           forg=args.pretrain_forg,
                                           lamb=args.pretrain_forg_lambda,
                                           epochs=args.pretrain_epochs)
        print(pretrain_args)

        pretrain.train(wrapped_model, classification_layer, forg_layer,
                       train_loader, val_loader, device, logger,
                       pretrain_args, logdir=None)

    # MAML training
    trainset = MAMLDataSet(data=(x, y, yforg),
                           subset=dev_users,
                           sk_subset=dev_sk_users,
                           num_gen_train=args.num_gen,
                           num_rf_train=args.num_rf,
                           num_gen_test=args.num_gen_test,
                           num_rf_test=args.num_rf_test,
                           num_sk_test=args.num_sk_test,
                           input_shape=args.input_size,
                           test=False,
                           rng=np.random.RandomState(args.seed))

    val_set = MAMLDataSet(data=(x, y, yforg),
                          subset=val_users,
                          num_gen_train=args.num_gen,
                          num_rf_train=args.num_rf,
                          num_gen_test=args.num_gen_test,
                          num_rf_test=args.num_rf_test,
                          num_sk_test=args.num_sk_test,
                          input_shape=args.input_size,
                          test=True,
                          rng=np.random.RandomState(args.seed))

    loader = DataLoader(trainset,
                        batch_size=args.meta_batch_size,
                        shuffle=True,
                        num_workers=2,
                        collate_fn=trainset.collate_fn)

    print('Training')
    best_val_acc = 0
    with tqdm(initial=0, total=len(loader) * args.epochs) as pbar:
        if args.checkpoint is not None:
            # Baseline validation of the loaded checkpoint (logged at step 0).
            postupdate_accs, postupdate_losses, preupdate_losses = test_one_epoch(
                maml, val_set, device, args.num_updates)
            if logger:
                for i in range(args.num_updates):
                    logger.scalar('val_postupdate_loss_{}'.format(i), 0,
                                  np.mean(postupdate_losses, axis=0)[i])
                    logger.scalar('val_postupdate_acc_{}'.format(i), 0,
                                  np.mean(postupdate_accs, axis=0)[i])

        for epoch in range(args.epochs):
            loss_weights = get_per_step_loss_importance_vector(
                args.num_updates, args.msl_epochs, epoch)
            n_batches = len(loader)

            for step, item in enumerate(loader):
                item = move_to_gpu(*item, device=device)
                # item = (train_x, train_y, test_x, test_y); the fractional
                # epoch (epoch + step / n_batches) is passed for scheduling.
                maml.meta_learning_step((item[0], item[1]),
                                        (item[2], item[3]), loss_weights,
                                        epoch + step / n_batches)
                pbar.update(1)
            maml.scheduler.step()

            postupdate_accs, postupdate_losses, preupdate_losses = test_one_epoch(
                maml, val_set, device, args.num_updates)

            if logger:
                for i in range(args.num_updates):
                    logger.scalar('val_postupdate_loss_{}'.format(i),
                                  epoch + 1,
                                  np.mean(postupdate_losses, axis=0)[i])
                    logger.scalar('val_postupdate_acc_{}'.format(i),
                                  epoch + 1,
                                  np.mean(postupdate_accs, axis=0)[i])
                logger.save(logdir / 'train_curves.pickle')

            # Model selection uses the accuracy after the LAST inner update.
            this_val_loss = np.mean(postupdate_losses, axis=0)[-1]
            this_val_acc = np.mean(postupdate_accs, axis=0)[-1]

            if this_val_acc > best_val_acc:
                best_val_acc = this_val_acc
                # NOTE(review): ``maml.parameters`` is saved without being
                # called — presumably an attribute on MAML (the ``weights``
                # passed at construction), not nn.Module.parameters(); confirm.
                torch.save(maml.parameters, logdir / 'best_model.pth')

            print('Epoch {}. Val loss: {:.4f}. Val Acc: {:.2f}%'.format(
                epoch, this_val_loss, this_val_acc * 100))

    # Re-load best parameters and test with 10 folds
    # NOTE(review): assumes at least one epoch improved best_val_acc so that
    # best_model.pth exists; otherwise this torch.load raises — confirm.
    params = torch.load(logdir / 'best_model.pth')
    maml.load(params)

    test_and_save(args, device, logdir, maml, val_users, x, y, yforg)