Example No. 1
Evaluate a trained few-shot model episode by episode on novel classes and report the mean accuracy with a 95% confidence interval.
def evaluate(novel_loader, n_way=5, n_support=5):
    iter_num = len(novel_loader)
    acc_all = []
    # Model
    if params.method == 'MatchingNet':
        model = MatchingNet(model_dict[params.model],
                            n_way=n_way,
                            n_support=n_support).cuda()
    elif params.method == 'RelationNet':
        model = RelationNet(model_dict[params.model],
                            n_way=n_way,
                            n_support=n_support).cuda()
    elif params.method == 'ProtoNet':
        model = ProtoNet(model_dict[params.model],
                         n_way=n_way,
                         n_support=n_support).cuda()
    elif params.method == 'GNN':
        model = GnnNet(model_dict[params.model],
                       n_way=n_way,
                       n_support=n_support).cuda()
    elif params.method == 'TPN':
        model = TPN(model_dict[params.model], n_way=n_way,
                    n_support=n_support).cuda()
    else:
        raise ValueError('Unknown method: %s' % params.method)
    # Update model
    checkpoint_dir = '%s/checkpoints/%s/best_model.tar' % (params.save_dir,
                                                           params.name)
    state = torch.load(checkpoint_dir)['state']
    if 'FWT' in params.name:
        # For FWT-trained checkpoints, keep only the parameters this model
        # defines; any extra keys in the checkpoint are skipped.
        model_params = model.state_dict()
        pretrained_dict = {k: v for k, v in state.items() if k in model_params}
        model_params.update(pretrained_dict)
        model.load_state_dict(model_params)
    else:
        model.load_state_dict(state)

    # Keep TPN in train mode so that Batch Norm uses statistics computed from
    # the test-time support set rather than the stored running averages.
    if params.method != 'TPN':
        model.eval()
    for ti, (x, _) in enumerate(novel_loader):  # x:(5, 20, 3, 224, 224)
        x = x.cuda()
        n_query = x.size(1) - n_support
        model.n_query = n_query
        yq = np.repeat(range(n_way), n_query)
        with torch.no_grad():
            scores = model.set_forward(x)  # (75, 5)
            _, topk_labels = scores.data.topk(1, 1, True, True)
            topk_ind = topk_labels.cpu().numpy()  # (75, 1)
            top1_correct = np.sum(topk_ind[:, 0] == yq)
            acc = top1_correct * 100. / (n_way * n_query)
            acc_all.append(acc)
        print('Task %d : %4.2f%%' % (ti, acc))

    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('Test Acc = %4.2f +- %4.2f%%' %
          (acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
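
Each batch from novel_loader above is one episode: a (n_way, n_support + n_query, C, H, W) tensor whose per-class rows hold the support images first and the query images after them. A minimal sketch of that layout and the support/query split, using a random stand-in tensor rather than a real loader:

import numpy as np
import torch

n_way, n_support = 5, 5
episode = torch.randn(n_way, 20, 3, 224, 224)  # one hypothetical novel task

n_query = episode.size(1) - n_support   # 15 query images per class
xs = episode[:, :n_support]             # support set: (5, 5, 3, 224, 224)
xq = episode[:, n_support:]             # query set:   (5, 15, 3, 224, 224)
yq = np.repeat(range(n_way), n_query)   # query labels, as in the example: (75,)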
Example No. 2
Training-script excerpt: build the model, resume from a checkpoint or load pre-trained weights, then run the training loop.
        model = TPN(model_dict[params.model],
                    n_way=params.train_n_way,
                    n_support=params.n_shot).cuda()
    else:
        raise ValueError('Unknown method: %s' % params.method)

    # load model
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.resume_epoch > 0:
        resume_file = os.path.join(params.checkpoint_dir,
                                   '{:d}.tar'.format(params.resume_epoch))
        tmp = torch.load(resume_file)
        start_epoch = tmp['epoch'] + 1
        model.load_state_dict(tmp['state'])
        print('\tResume the training weight at {} epoch.'.format(start_epoch))
    else:
        path = '%s/checkpoints/%s/399.tar' % (params.save_dir,
                                              params.resume_dir)
        state = torch.load(path)['state']
        model_params = model.state_dict()
        pretrained_dict = {k: v for k, v in state.items() if k in model_params}
        print(pretrained_dict.keys())  # show which keys were taken from the checkpoint
        model_params.update(pretrained_dict)
        model.load_state_dict(model_params)

    # training
    print('\n--- start the training ---')
    model = train(base_loader, val_loader, model, start_epoch, stop_epoch,
                  params)
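
The resume branch above reads checkpoints as a dict with 'epoch' and 'state' keys from files named '{epoch}.tar'. A minimal sketch of a save call that matches that layout; save_checkpoint is a hypothetical helper, not part of the excerpt:

import os
import torch

def save_checkpoint(model, epoch, checkpoint_dir):
    # Mirror the {'epoch': ..., 'state': ...} layout the resume branch loads.
    outfile = os.path.join(checkpoint_dir, '{:d}.tar'.format(epoch))
    torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)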
Example No. 3
Per-task fine-tuning: before evaluating each novel episode, reload the trained weights and fine-tune on pseudo query samples generated from the support set.
def finetune(novel_loader, n_pseudo=75, n_way=5, n_support=5):
    iter_num = len(novel_loader)
    acc_all = []

    checkpoint_dir = '%s/checkpoints/%s/best_model.tar' % (params.save_dir,
                                                           params.name)
    state = torch.load(checkpoint_dir)['state']
    for ti, (x, _) in enumerate(novel_loader):  # x:(5, 20, 3, 224, 224)
        # Model
        if params.method == 'MatchingNet':
            model = MatchingNet(model_dict[params.model],
                                n_way=n_way,
                                n_support=n_support).cuda()
        elif params.method == 'RelationNet':
            model = RelationNet(model_dict[params.model],
                                n_way=n_way,
                                n_support=n_support).cuda()
        elif params.method == 'ProtoNet':
            model = ProtoNet(model_dict[params.model],
                             n_way=n_way,
                             n_support=n_support).cuda()
        elif params.method == 'GNN':
            model = GnnNet(model_dict[params.model],
                           n_way=n_way,
                           n_support=n_support).cuda()
        elif params.method == 'TPN':
            model = TPN(model_dict[params.model],
                        n_way=n_way,
                        n_support=n_support).cuda()
        else:
            raise ValueError('Unknown method: %s' % params.method)
        # Update model
        if 'FWT' in params.name:
            model_params = model.state_dict()
            pretrained_dict = {
                k: v
                for k, v in state.items() if k in model_params
            }
            model_params.update(pretrained_dict)
            model.load_state_dict(model_params)
        else:
            model.load_state_dict(state)

        x = x.cuda()
        # Finetune components initialization
        xs = x[:, :n_support].reshape(-1, *x.size()[2:])  # (25, 3, 224, 224)
        pseudo_q_generator = PseudoSampleGenerator(n_way, n_support, n_pseudo)
        loss_fun = nn.CrossEntropyLoss().cuda()
        opt = torch.optim.Adam(model.parameters())
        # Finetune process
        n_query = n_pseudo // n_way
        pseudo_set_y = torch.from_numpy(np.repeat(range(n_way),
                                                  n_query)).cuda()
        model.n_query = n_query
        model.train()
        for epoch in range(params.finetune_epoch):
            opt.zero_grad()
            pseudo_set = pseudo_q_generator.generate(xs)  # (5, n_support+n_query, 3, 224, 224)
            scores = model.set_forward(pseudo_set)  # (5*n_query, 5)
            loss = loss_fun(scores, pseudo_set_y)
            loss.backward()
            opt.step()
            del pseudo_set, scores, loss
        torch.cuda.empty_cache()

        # Inference process
        n_query = x.size(1) - n_support
        model.n_query = n_query
        yq = np.repeat(range(n_way), n_query)
        with torch.no_grad():
            scores = model.set_forward(x)  # (75, 5)
            _, topk_labels = scores.data.topk(1, 1, True, True)
            topk_ind = topk_labels.cpu().numpy()  # (75, 1)
            top1_correct = np.sum(topk_ind[:, 0] == yq)
            acc = top1_correct * 100. / (n_way * n_query)
            acc_all.append(acc)
        del scores, topk_labels
        torch.cuda.empty_cache()
        print('Task %d : %4.2f%%' % (ti, acc))

    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('Test Acc = %4.2f +- %4.2f%%' %
          (acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
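
PseudoSampleGenerator is not defined in this excerpt; all that is visible is its constructor (n_way, n_support, n_pseudo) and a generate(xs) method that turns the flattened support set into a (n_way, n_support + n_query, C, H, W) pseudo episode. A minimal sketch of one way to satisfy that interface, assuming pseudo queries are made by lightly perturbing randomly chosen support images; the repository's actual generator may differ:

import torch

class PseudoSampleGenerator:
    """Builds pseudo episodes from the support set alone (hypothetical sketch)."""

    def __init__(self, n_way, n_support, n_pseudo):
        self.n_way = n_way
        self.n_support = n_support
        self.n_pseudo_per_way = n_pseudo // n_way

    def generate(self, xs):
        # xs: flattened support set, (n_way * n_support, C, H, W).
        xs = xs.view(self.n_way, self.n_support, *xs.size()[1:])
        queries = []
        for _ in range(self.n_pseudo_per_way):
            # Pick one support image per class and perturb it
            # (a stand-in for a real augmentation pipeline).
            idx = torch.randint(self.n_support, (1,)).item()
            q = xs[:, idx] + 0.01 * torch.randn_like(xs[:, idx])
            queries.append(q.unsqueeze(1))
        # Supports first, pseudo queries after, per class:
        # (n_way, n_support + n_pseudo // n_way, C, H, W).
        return torch.cat([xs] + queries, dim=1)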