def main(epochs, cpu, cudnn_flag, visdom_port, visdom_freq, temp_dir, seed, no_bias_decay, label_smoothing, temperature):
    device = torch.device('cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
    callback = VisdomLogger(port=visdom_port) if visdom_port else None
    if cudnn_flag == 'deterministic':
        setattr(cudnn, cudnn_flag, True)

    torch.manual_seed(seed)
    loaders, recall_ks = get_loaders()

    torch.manual_seed(seed)
    model = get_model(num_classes=loaders.num_classes)
    class_loss = SmoothCrossEntropy(epsilon=label_smoothing, temperature=temperature)

    model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # optionally exclude 1-d parameters (biases, batch-norm affine params) from weight decay
    parameters = []
    if no_bias_decay:
        parameters.append({'params': [par for par in model.parameters() if par.dim() != 1]})
        parameters.append({'params': [par for par in model.parameters() if par.dim() == 1],
                           'weight_decay': 0})
    else:
        parameters.append({'params': model.parameters()})

    optimizer, scheduler = get_optimizer_scheduler(parameters=parameters,
                                                   loader_length=len(loaders.train))

    # setup partial function to simplify call
    eval_function = partial(evaluate, model=model, recall=recall_ks,
                            query_loader=loaders.query, gallery_loader=loaders.gallery)

    # initial validation pass to initialize the best-model tracker
    metrics = eval_function()
    if callback is not None:
        callback.scalars(['l2', 'cosine'], 0,
                         [metrics.recall['l2'][1], metrics.recall['cosine'][1]],
                         title='Val Recall@1')
    pprint(metrics.recall)
    best_val = (0, metrics.recall, deepcopy(model.state_dict()))

    torch.manual_seed(seed)
    for epoch in range(epochs):
        if cudnn_flag == 'benchmark':
            setattr(cudnn, cudnn_flag, True)

        train(model=model, loader=loaders.train, class_loss=class_loss,
              optimizer=optimizer, scheduler=scheduler, epoch=epoch,
              callback=callback, freq=visdom_freq, ex=ex)

        # validation
        if cudnn_flag == 'benchmark':
            setattr(cudnn, cudnn_flag, False)
        metrics = eval_function()
        print('Validation [{:03d}]'.format(epoch))
        pprint(metrics.recall)

        ex.log_scalar('val.recall_l2@1', metrics.recall['l2'][1], step=epoch + 1)
        ex.log_scalar('val.recall_cosine@1', metrics.recall['cosine'][1], step=epoch + 1)
        if callback is not None:
            callback.scalars(['l2', 'cosine'], epoch + 1,
                             [metrics.recall['l2'][1], metrics.recall['cosine'][1]],
                             title='Val Recall')

        # save model dict if the chosen validation metric is better
        if metrics.recall['cosine'][1] >= best_val[1]['cosine'][1]:
            best_val = (epoch + 1, metrics.recall, deepcopy(model.state_dict()))

    # logging
    ex.info['recall'] = best_val[1]

    # saving
    save_name = os.path.join(temp_dir,
                             '{}_{}.pt'.format(ex.current_run.config['model']['arch'],
                                               ex.current_run.config['dataset']['name']))
    torch.save(state_dict_to_cpu(best_val[2]), save_name)
    ex.add_artifact(save_name)

    if callback is not None:
        save_name = os.path.join(temp_dir, 'visdom_data.pt')
        callback.save(save_name)
        ex.add_artifact(save_name)

    return best_val[1]['cosine'][1]
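
# The loop above uses `SmoothCrossEntropy(epsilon=label_smoothing, temperature=temperature)`,
# whose definition is outside this section. Below is a minimal, illustrative sketch of such a
# loss, assuming the standard label-smoothing formulation applied to temperature-scaled logits
# (it assumes `torch` and `torch.nn as nn` are imported at module level, as elsewhere here);
# the actual `SmoothCrossEntropy` in this codebase may differ.
class SmoothCrossEntropySketch(nn.Module):
    """Illustrative cross-entropy with label smoothing and temperature scaling."""

    def __init__(self, epsilon=0.1, temperature=1.0):
        super().__init__()
        self.epsilon = epsilon
        self.temperature = temperature

    def forward(self, logits, targets):
        # temperature-scaled log-probabilities
        log_probs = torch.log_softmax(logits / self.temperature, dim=1)
        n_classes = logits.size(1)
        # smoothed targets: epsilon spread uniformly, the remaining mass on the true class
        with torch.no_grad():
            true_dist = torch.full_like(log_probs, self.epsilon / n_classes)
            true_dist.scatter_(1, targets.unsqueeze(1),
                               1.0 - self.epsilon + self.epsilon / n_classes)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=1))
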
def main(args):
    rng = np.random.RandomState(args.seed)

    if args.test:
        assert args.checkpoint is not None, 'Please inform the checkpoint (trained model)'

    if args.logdir is None:
        logdir = get_logdir(args)
    else:
        logdir = pathlib.Path(args.logdir)
    if not logdir.exists():
        logdir.mkdir()
    print('Writing logs to {}'.format(logdir))

    device = torch.device('cuda', args.gpu_idx) if torch.cuda.is_available() else torch.device('cpu')

    if args.port is not None:
        logger = VisdomLogger(port=args.port)
    else:
        logger = None

    print('Loading Data')
    x, y, yforg, usermapping, filenames = load_dataset(args.dataset_path)

    dev_users = range(args.dev_users[0], args.dev_users[1])

    if args.devset_size is not None:
        # Randomly select users from the dev set
        dev_users = rng.choice(dev_users, args.devset_size, replace=False)

    if args.devset_sk_size is not None:
        assert args.devset_sk_size <= len(dev_users), 'devset-sk-size should be smaller than devset-size'

        # Randomly select users from the dev set to have skilled forgeries (others don't)
        dev_sk_users = set(rng.choice(dev_users, args.devset_sk_size, replace=False))
    else:
        dev_sk_users = set(dev_users)

    print('{} users in dev set; {} users with skilled forgeries'.format(len(dev_users), len(dev_sk_users)))

    if args.exp_users is not None:
        val_users = range(args.exp_users[0], args.exp_users[1])
        print('Testing with users from {} to {}'.format(args.exp_users[0], args.exp_users[1]))
    elif args.use_testset:
        val_users = range(0, 300)
        print('Testing with Exploitation set')
    else:
        val_users = range(300, 350)

    print('Initializing model')
    base_model = models.available_models[args.model]().to(device)
    weights = base_model.build_weights(device)
    maml = MAML(base_model, args.num_updates, args.num_updates, args.train_lr,
                args.meta_lr, args.meta_min_lr, args.epochs, args.learn_task_lr,
                weights, device, logger,
                loss_function=balanced_binary_cross_entropy, is_classification=True)

    if args.checkpoint:
        params = torch.load(args.checkpoint)
        maml.load(params)

    if args.test:
        test_and_save(args, device, logdir, maml, val_users, x, y, yforg)
        return

    # Pretraining
    if args.pretrain_epochs > 0:
        print('Pre-training')
        data = util.get_subset((x, y, yforg), subset=range(350, 881))
        wrapped_model = PretrainWrapper(base_model, weights)

        if not args.pretrain_forg:
            data = util.remove_forgeries(data, forg_idx=2)

        train_loader, val_loader = pretrain.setup_data_loaders(data, 32, args.input_size)
        n_classes = len(np.unique(y))

        classification_layer = nn.Linear(base_model.feature_space_size, n_classes).to(device)
        if args.pretrain_forg:
            forg_layer = nn.Linear(base_model.feature_space_size, 1).to(device)
        else:
            forg_layer = nn.Module()  # Stub module with no parameters

        pretrain_args = argparse.Namespace(lr=0.01, lr_decay=0.1, lr_decay_times=1,
                                           momentum=0.9, weight_decay=0.001,
                                           forg=args.pretrain_forg,
                                           lamb=args.pretrain_forg_lambda,
                                           epochs=args.pretrain_epochs)
        print(pretrain_args)
        pretrain.train(wrapped_model, classification_layer, forg_layer,
                       train_loader, val_loader, device, logger, pretrain_args, logdir=None)

    # MAML training
    trainset = MAMLDataSet(data=(x, y, yforg), subset=dev_users, sk_subset=dev_sk_users,
                           num_gen_train=args.num_gen, num_rf_train=args.num_rf,
                           num_gen_test=args.num_gen_test, num_rf_test=args.num_rf_test,
                           num_sk_test=args.num_sk_test, input_shape=args.input_size,
                           test=False, rng=np.random.RandomState(args.seed))

    val_set = MAMLDataSet(data=(x, y, yforg), subset=val_users,
                          num_gen_train=args.num_gen, num_rf_train=args.num_rf,
                          num_gen_test=args.num_gen_test, num_rf_test=args.num_rf_test,
                          num_sk_test=args.num_sk_test, input_shape=args.input_size,
                          test=True, rng=np.random.RandomState(args.seed))

    loader = DataLoader(trainset, batch_size=args.meta_batch_size, shuffle=True,
                        num_workers=2, collate_fn=trainset.collate_fn)

    print('Training')
    best_val_acc = 0
    with tqdm(initial=0, total=len(loader) * args.epochs) as pbar:
        if args.checkpoint is not None:
            postupdate_accs, postupdate_losses, preupdate_losses = test_one_epoch(
                maml, val_set, device, args.num_updates)
            if logger:
                for i in range(args.num_updates):
                    logger.scalar('val_postupdate_loss_{}'.format(i), 0,
                                  np.mean(postupdate_losses, axis=0)[i])
                    logger.scalar('val_postupdate_acc_{}'.format(i), 0,
                                  np.mean(postupdate_accs, axis=0)[i])

        for epoch in range(args.epochs):
            loss_weights = get_per_step_loss_importance_vector(args.num_updates, args.msl_epochs, epoch)
            n_batches = len(loader)
            for step, item in enumerate(loader):
                item = move_to_gpu(*item, device=device)
                maml.meta_learning_step((item[0], item[1]), (item[2], item[3]),
                                        loss_weights, epoch + step / n_batches)
                pbar.update(1)
            maml.scheduler.step()

            postupdate_accs, postupdate_losses, preupdate_losses = test_one_epoch(
                maml, val_set, device, args.num_updates)

            if logger:
                for i in range(args.num_updates):
                    logger.scalar('val_postupdate_loss_{}'.format(i), epoch + 1,
                                  np.mean(postupdate_losses, axis=0)[i])
                    logger.scalar('val_postupdate_acc_{}'.format(i), epoch + 1,
                                  np.mean(postupdate_accs, axis=0)[i])
                logger.save(logdir / 'train_curves.pickle')

            this_val_loss = np.mean(postupdate_losses, axis=0)[-1]
            this_val_acc = np.mean(postupdate_accs, axis=0)[-1]

            if this_val_acc > best_val_acc:
                best_val_acc = this_val_acc
                torch.save(maml.parameters, logdir / 'best_model.pth')

            print('Epoch {}. Val loss: {:.4f}. Val Acc: {:.2f}%'.format(
                epoch, this_val_loss, this_val_acc * 100))

    # Re-load best parameters and test with 10 folds
    params = torch.load(logdir / 'best_model.pth')
    maml.load(params)
    test_and_save(args, device, logdir, maml, val_users, x, y, yforg)
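
# The epoch loop above weights per-inner-step losses via
# `get_per_step_loss_importance_vector(args.num_updates, args.msl_epochs, epoch)`, which is
# defined outside this section. The sketch below shows one common way such a vector is built,
# following the multi-step loss (MSL) annealing of MAML++ (Antoniou et al.): weights start
# near uniform and are gradually shifted onto the final inner-update step over `msl_epochs`
# epochs. This is an illustrative assumption about the intended behaviour, not the exact
# implementation used here (it relies only on `numpy as np`, already imported in this file).
def per_step_loss_importance_vector_sketch(num_updates, msl_epochs, epoch):
    """Illustrative per-step loss weights that always sum to 1."""
    weights = np.ones(num_updates) / num_updates
    decay = 1.0 / num_updates / max(msl_epochs, 1)
    min_non_final = 0.03 / num_updates
    # decay every step's weight except the last, clamped at a small floor
    weights[:-1] = np.maximum(weights[:-1] - epoch * decay, min_non_final)
    # the final step absorbs the remaining mass
    weights[-1] = 1.0 - weights[:-1].sum()
    return weights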