Code example #1
File: test.py Project: wangkua1/CloserLookFewShot
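The listing omits the module-level imports. Assuming the CloserLookFewShot repository layout, they would look roughly as follows; feature_evaluation and get_image_size are helpers defined elsewhere in the project:

import os
import time

import numpy as np
import torch

import backbone
import configs
import data.feature_loader as feat_loader
from data.datamgr import SetDataManager
from io_utils import model_dict, get_assigned_file, get_best_file
from methods.baselinefinetune import BaselineFinetune
from methods.protonet import ProtoNet
from methods.matchingnet import MatchingNet
from methods.relationnet import RelationNet
from methods.maml import MAML
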
def run_test(params):
    print('Testing ...')
    acc_all = []

    if hasattr(params, 'iter_num'):
        iter_num = params.iter_num
    else:
        iter_num = 600

    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
        params.model = 'Conv4S'

    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model],
                                 loss_type='dist',
                                 **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
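        # Setting these class attributes makes the backbone build MAML-compatible fast-weight layers.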
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model],
                     approx=(params.method == 'maml_approx'),
                     **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char']:  # MAML uses different hyperparameters on omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    if hasattr(params, 'checkpoint_dir'):
        checkpoint_dir = params.checkpoint_dir
    else:
        checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, params.method)
        if params.train_aug:
            checkpoint_dir += '_aug'
        if params.method not in ['baseline', 'baseline++']:
            checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                 params.n_shot)

    #modelfile   = get_resume_file(checkpoint_dir)

    if params.method not in ['baseline', 'baseline++']:
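        # baseline / baseline++ are finetuned from saved features at test time, so no episodic checkpoint is loaded for them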
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split
    if params.method in ['maml', 'maml_approx']:  # MAML does not support testing on pre-extracted features
        image_size = get_image_size(params)

        datamgr = SetDataManager(image_size,
                                 n_eposide=iter_num,  # sic: the repo's SetDataManager spells this keyword "n_eposide"
                                 n_query=15,
                                 **few_shot_params)

        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  #We perform adaptation on MAML simply by updating more times.
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)
        err = 1.96 * acc_std / np.sqrt(iter_num)  # computed here as well, since the results.txt write below uses err

    else:
        novel_file = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"),
            split_str + ".hdf5")  # default split = novel, but you can also test base or val classes
        cl_data_file = feat_loader.init_loader(novel_file)
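        # The .hdf5 feature file is produced beforehand by the repo's save_features.py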

        for i in range(iter_num):
            acc = feature_evaluation(cl_data_file,
                                     model,
                                     n_query=15,
                                     adaptation=params.adaptation,
                                     **few_shot_params)
            acc_all.append(acc)

        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        err = 1.96 * acc_std / np.sqrt(iter_num)
        print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, err))
        os.remove(novel_file)
    with open(os.path.join(checkpoint_dir, 'results.txt'), 'a') as f:
        timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
        aug_str = '-aug' if params.train_aug else ''
        aug_str += '-adapted' if params.adaptation else ''
        if params.method in ['baseline', 'baseline++']:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.test_n_way)
        else:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_train %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.train_n_way, params.test_n_way)
        acc_str = '%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean,
                                                        err)
        f.write('Time: %s, Setting: %s, Acc: %s \n' %
                (timestamp, exp_setting, acc_str))

    res = {params.n_shot: (acc_mean, err)}
    torch.save(res, os.path.join(checkpoint_dir, 'result.pth'))
    return res
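
For orientation, here is a minimal, hypothetical driver for run_test. The attribute names mirror exactly what the function reads above; types.SimpleNamespace merely stands in for the argparse namespace the repo builds, and the chosen values are illustrative:

from types import SimpleNamespace

params = SimpleNamespace(
    dataset='miniImagenet',   # or 'CUB', 'cross', 'omniglot', 'cross_char'
    model='Conv4',            # key into model_dict
    method='protonet',        # selects one of the branches above
    train_n_way=5,
    test_n_way=5,
    n_shot=5,
    train_aug=False,
    split='novel',            # 'base' and 'val' also work
    save_iter=-1,             # -1 means "load the best checkpoint"
    adaptation=False,
)

res = run_test(params)        # returns {n_shot: (acc_mean, err)}

Omitting iter_num and checkpoint_dir is deliberate: the hasattr checks at the top of run_test fall back to 600 test episodes and a path derived from configs.save_dir.
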
Code example #2
File: finetune_ml.py Project: Haoqing-Wang/CDFSL-ATA
def finetune(novel_loader, n_pseudo=75, n_way=5, n_support=5):
    iter_num = len(novel_loader)
    acc_all = []

    checkpoint_dir = '%s/checkpoints/%s/best_model.tar' % (params.save_dir,
                                                           params.name)
    state = torch.load(checkpoint_dir)['state']
    for ti, (x, _) in enumerate(novel_loader):  # x:(5, 20, 3, 224, 224)
        # Build a fresh model for this task; it is re-initialized from the checkpoint below so finetuning cannot leak across tasks
        if params.method == 'MatchingNet':
            model = MatchingNet(model_dict[params.model],
                                n_way=n_way,
                                n_support=n_support).cuda()
        elif params.method == 'RelationNet':
            model = RelationNet(model_dict[params.model],
                                n_way=n_way,
                                n_support=n_support).cuda()
        elif params.method == 'ProtoNet':
            model = ProtoNet(model_dict[params.model],
                             n_way=n_way,
                             n_support=n_support).cuda()
        elif params.method == 'GNN':
            model = GnnNet(model_dict[params.model],
                           n_way=n_way,
                           n_support=n_support).cuda()
        elif params.method == 'TPN':
            model = TPN(model_dict[params.model],
                        n_way=n_way,
                        n_support=n_support).cuda()
        else:
            raise ValueError('Unknown method: %s' % params.method)
        # Update model
        if 'FWT' in params.name:
            model_params = model.state_dict()
            pretrained_dict = {
                k: v
                for k, v in state.items() if k in model_params
            }
            model_params.update(pretrained_dict)
            model.load_state_dict(model_params)
        else:
            model.load_state_dict(state)

        x = x.cuda()
        # Finetune components initialization
        xs = x[:, :n_support].reshape(-1, *x.size()[2:])  # (25, 3, 224, 224)
        pseudo_q_generator = PseudoSampleGenerator(n_way, n_support, n_pseudo)
        loss_fun = nn.CrossEntropyLoss().cuda()
        opt = torch.optim.Adam(model.parameters())
        # Finetune process
        n_query = n_pseudo // n_way
        pseudo_set_y = torch.from_numpy(np.repeat(range(n_way), n_query)).cuda()
        model.n_query = n_query
        model.train()
        for epoch in range(params.finetune_epoch):
            opt.zero_grad()
            pseudo_set = pseudo_q_generator.generate(xs)  # (5, n_support+n_query, 3, 224, 224)
            scores = model.set_forward(pseudo_set)  # (5*n_query, 5)
            loss = loss_fun(scores, pseudo_set_y)
            loss.backward()
            opt.step()
            del pseudo_set, scores, loss
        torch.cuda.empty_cache()

        # Inference process
        n_query = x.size(1) - n_support
        model.n_query = n_query
        yq = np.repeat(range(n_way), n_query)
        with torch.no_grad():
            scores = model.set_forward(x)  # (80, 5)
            _, topk_labels = scores.data.topk(1, 1, True, True)
            topk_ind = topk_labels.cpu().numpy()  # (80, 1)
            top1_correct = np.sum(topk_ind[:, 0] == yq)
            acc = top1_correct * 100. / (n_way * n_query)
            acc_all.append(acc)
        del scores, topk_labels
        torch.cuda.empty_cache()
        print('Task %d : %4.2f%%' % (ti, acc))

    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('Test Acc = %4.2f%% +- %4.2f%%' %
          (acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
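
PseudoSampleGenerator is the one component this excerpt does not show. As a rough sketch of the contract implied by the shape comments above (an assumption about the CDFSL-ATA implementation, not a copy of it), it turns the labeled support images into a pseudo episode whose "queries" are perturbed copies of those same images:

import torch

class PseudoSampleGenerator:
    # Hypothetical sketch: the real CDFSL-ATA class may augment differently.
    def __init__(self, n_way, n_support, n_pseudo):
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_pseudo // n_way

    def generate(self, xs):
        # xs: (n_way*n_support, 3, H, W), the flattened support set
        xs = xs.view(self.n_way, self.n_support, *xs.size()[1:])
        # Draw n_query support images per class and perturb them (here: a horizontal flip)
        idx = torch.randint(self.n_support, (self.n_way, self.n_query))
        queries = torch.stack([xs[w, idx[w]] for w in range(self.n_way)])
        queries = torch.flip(queries, dims=[-1])
        # (n_way, n_support + n_query, 3, H, W), the layout model.set_forward expects
        return torch.cat([xs, queries], dim=1)

Because the pseudo queries are derived from the support set alone, the finetuning loop above never touches the real query labels; those are used only in the final no-grad evaluation.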