def run_test(params):
    """Evaluate a trained few-shot model and log/save the resulting accuracy.

    Builds the model named by ``params.method``/``params.model``, loads the best
    (or an assigned-iteration) checkpoint, evaluates ``iter_num`` episodes, and
    appends the result to ``results.txt`` / saves it to ``result.pth`` inside the
    checkpoint directory.

    Args:
        params: argparse-style namespace. Reads at least ``dataset``, ``model``,
            ``method``, ``n_shot``, ``train_n_way``, ``test_n_way``, ``split``,
            ``save_iter``, ``train_aug``, ``adaptation``; optionally
            ``iter_num`` and ``checkpoint_dir``.

    Returns:
        dict mapping ``params.n_shot`` to ``(acc_mean, err)`` where ``err`` is
        the 95% confidence half-width.
    """
    print('Testing ...')
    acc_all = []

    # Default episode count when not supplied by the caller.
    iter_num = params.iter_num if hasattr(params, 'iter_num') else 600

    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, \
            'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    # ---- Build the meta-learning model for the requested method ----
    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        # RelationNet needs an un-flattened (spatial) feature map.
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        # Enable the MAML fast-weight machinery on all backbone block types.
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model],
                     approx=(params.method == 'maml_approx'),
                     **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char']:  # maml use different parameter in omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    # ---- Resolve the checkpoint directory ----
    if hasattr(params, 'checkpoint_dir'):
        checkpoint_dir = params.checkpoint_dir
    else:
        checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, params.method)
        if params.train_aug:
            checkpoint_dir += '_aug'
        if params.method not in ['baseline', 'baseline++']:
            checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    # Baselines evaluate from pre-extracted features; other methods load weights.
    if params.method not in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])

    split = params.split
    # Tag the split with the iteration when testing a specific checkpoint.
    split_str = split + "_" + str(params.save_iter) if params.save_iter != -1 else split

    if params.method in ['maml', 'maml_approx']:  # maml do not support testing with feature
        image_size = get_image_size(params)
        datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=15, **few_shot_params)

        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # We perform adaptation on MAML simply by updating more times.
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)
    else:
        novel_file = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"),
            split_str + ".hdf5"
        )  # default split = novel, but you can also test base or val classes
        cl_data_file = feat_loader.init_loader(novel_file)
        for i in range(iter_num):
            acc = feature_evaluation(cl_data_file, model, n_query=15,
                                     adaptation=params.adaptation, **few_shot_params)
            acc_all.append(acc)
        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        os.remove(novel_file)

    # BUGFIX: compute the confidence half-width after BOTH branches; previously
    # `err` was only set in the feature-evaluation branch, so the MAML path
    # raised NameError when formatting acc_str / building res below.
    err = 1.96 * acc_std / np.sqrt(iter_num)
    print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, err))

    # ---- Append a human-readable record and save the machine-readable result ----
    with open(os.path.join(checkpoint_dir, 'results.txt'), 'a') as f:
        timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
        aug_str = '-aug' if params.train_aug else ''
        aug_str += '-adapted' if params.adaptation else ''
        if params.method in ['baseline', 'baseline++']:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_test' % (
                params.dataset, split_str, params.model, params.method, aug_str,
                params.n_shot, params.test_n_way)
        else:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_train %sway_test' % (
                params.dataset, split_str, params.model, params.method, aug_str,
                params.n_shot, params.train_n_way, params.test_n_way)
        acc_str = '%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, err)
        f.write('Time: %s, Setting: %s, Acc: %s \n' % (timestamp, exp_setting, acc_str))

    res = {params.n_shot: (acc_mean, err)}
    torch.save(res, os.path.join(checkpoint_dir, 'result.pth'))
    return res
# NOTE(review): this top-level chunk duplicates the model-construction logic
# from run_test() above and is truncated mid-way through the MAML branch
# (nothing uses `model` afterwards). It looks like a copy/paste remnant, or a
# second function whose `def` header was lost in formatting — confirm against
# upstream history before relying on it; it likely should be deleted or
# restored to a complete function.
few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
if params.dataset in ['omniglot', 'cross_char']:
    # Omniglot-style datasets are restricted to the small Conv4 backbone.
    assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
    params.model = 'Conv4S'
if params.method == 'baseline':
    model = BaselineFinetune(model_dict[params.model], **few_shot_params)
elif params.method == 'baseline++':
    model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params)
elif params.method == 'protonet':
    model = ProtoNet(model_dict[params.model], **few_shot_params)
elif params.method == 'matchingnet':
    model = MatchingNet(model_dict[params.model], **few_shot_params)
elif params.method in ['relationnet', 'relationnet_softmax']:
    # RelationNet variants need an un-flattened (spatial) feature map.
    if params.model == 'Conv4':
        feature_model = backbone.Conv4NP
    elif params.model == 'Conv6':
        feature_model = backbone.Conv6NP
    elif params.model == 'Conv4S':
        feature_model = backbone.Conv4SNP
    else:
        feature_model = lambda: model_dict[params.model](flatten=False)
    loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
    model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
elif params.method in ['maml', 'maml_approx']:
    # Fragment ends here — the remaining MAML setup present in run_test()
    # (SimpleBlock/BottleneckBlock/ResNet flags, MAML(...)) is missing.
    backbone.ConvBlock.maml = True
def finetune(novel_loader, n_pseudo=75, n_way=5, n_support=5):
    """Per-task finetuning evaluation with pseudo query samples.

    For each novel task, rebuilds the model from the saved best checkpoint,
    finetunes it on pseudo episodes generated from the support set, then
    measures query accuracy. Prints per-task and aggregate accuracy.

    Args:
        novel_loader: iterable of ``(x, _)`` task batches;
            x is assumed to be (n_way, n_support+n_query, 3, H, W) — e.g.
            (5, 20, 3, 224, 224) — TODO confirm against the data manager.
        n_pseudo: total number of pseudo query samples per finetune episode.
        n_way: number of classes per task.
        n_support: number of support samples per class.

    Raises:
        ValueError: if the global ``params.method`` is not a known method.
    """
    iter_num = len(novel_loader)
    acc_all = []

    checkpoint_dir = '%s/checkpoints/%s/best_model.tar' % (params.save_dir, params.name)
    state = torch.load(checkpoint_dir)['state']

    for ti, (x, _) in enumerate(novel_loader):  # x:(5, 20, 3, 224, 224)
        # A fresh model per task: finetuning must not leak across tasks.
        if params.method == 'MatchingNet':
            model = MatchingNet(model_dict[params.model], n_way=n_way, n_support=n_support).cuda()
        elif params.method == 'RelationNet':
            model = RelationNet(model_dict[params.model], n_way=n_way, n_support=n_support).cuda()
        elif params.method == 'ProtoNet':
            model = ProtoNet(model_dict[params.model], n_way=n_way, n_support=n_support).cuda()
        elif params.method == 'GNN':
            model = GnnNet(model_dict[params.model], n_way=n_way, n_support=n_support).cuda()
        elif params.method == 'TPN':
            model = TPN(model_dict[params.model], n_way=n_way, n_support=n_support).cuda()
        else:
            # BUGFIX: was `print(...)` + `assert False` — asserts are stripped
            # under `python -O`; raise explicitly (consistent with run_test).
            raise ValueError('Unknown method: %s' % params.method)

        # Load pretrained weights; FWT checkpoints carry extra keys that the
        # plain model does not have, so load only the intersection.
        if 'FWT' in params.name:
            model_params = model.state_dict()
            pretrained_dict = {k: v for k, v in state.items() if k in model_params}
            model_params.update(pretrained_dict)
            model.load_state_dict(model_params)
        else:
            model.load_state_dict(state)

        x = x.cuda()

        # ---- Finetune components initialization ----
        xs = x[:, :n_support].reshape(-1, *x.size()[2:])  # (n_way*n_support, 3, 224, 224)
        pseudo_q_generator = PseudoSampleGenerator(n_way, n_support, n_pseudo)  # typo fixed: genrator
        loss_fun = nn.CrossEntropyLoss().cuda()
        opt = torch.optim.Adam(model.parameters())

        # ---- Finetune process on pseudo episodes ----
        n_query = n_pseudo // n_way
        pseudo_set_y = torch.from_numpy(np.repeat(range(n_way), n_query)).cuda()
        model.n_query = n_query
        model.train()
        for epoch in range(params.finetune_epoch):
            opt.zero_grad()
            pseudo_set = pseudo_q_generator.generate(xs)  # (n_way, n_support+n_query, 3, 224, 224)
            scores = model.set_forward(pseudo_set)  # (n_way*n_query, n_way)
            loss = loss_fun(scores, pseudo_set_y)
            loss.backward()
            opt.step()
            # Free episode tensors eagerly to keep GPU memory flat.
            del pseudo_set, scores, loss
            torch.cuda.empty_cache()

        # ---- Inference on the real query set ----
        n_query = x.size(1) - n_support
        model.n_query = n_query
        yq = np.repeat(range(n_way), n_query)
        with torch.no_grad():
            scores = model.set_forward(x)  # (n_way*n_query, n_way)
            _, topk_labels = scores.data.topk(1, 1, True, True)
            topk_ind = topk_labels.cpu().numpy()  # (n_way*n_query, 1)
            top1_correct = np.sum(topk_ind[:, 0] == yq)
            acc = top1_correct * 100. / (n_way * n_query)
            acc_all.append(acc)
            del scores, topk_labels
            torch.cuda.empty_cache()
        print('Task %d : %4.2f%%' % (ti, acc))

    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('Test Acc = %4.2f +- %4.2f%%' % (acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))