Example #1
def evaluate(novel_loader, n_way=5, n_support=5):
    iter_num = len(novel_loader)
    acc_all = []
    # Model
    if params.method == 'MatchingNet':
        model = MatchingNet(model_dict[params.model],
                            n_way=n_way,
                            n_support=n_support).cuda()
    elif params.method == 'RelationNet':
        model = RelationNet(model_dict[params.model],
                            n_way=n_way,
                            n_support=n_support).cuda()
    elif params.method == 'ProtoNet':
        model = ProtoNet(model_dict[params.model],
                         n_way=n_way,
                         n_support=n_support).cuda()
    elif params.method == 'GNN':
        model = GnnNet(model_dict[params.model],
                       n_way=n_way,
                       n_support=n_support).cuda()
    elif params.method == 'TPN':
        model = TPN(model_dict[params.model], n_way=n_way,
                    n_support=n_support).cuda()
    else:
        raise ValueError('Unknown method: %s' % params.method)
    # Update model
    checkpoint_dir = '%s/checkpoints/%s/best_model.tar' % (params.save_dir,
                                                           params.name)
    state = torch.load(checkpoint_dir)['state']
    if 'FWT' in params.name:
        model_params = model.state_dict()
        pretrained_dict = {k: v for k, v in state.items() if k in model_params}
        model_params.update(pretrained_dict)
        model.load_state_dict(model_params)
    else:
        model.load_state_dict(state)

    # For the TPN model, BatchNorm statistics are computed from the test-time support set rather than the exponential moving averages, so it stays in train mode.
    if params.method != 'TPN':
        model.eval()
    for ti, (x, _) in enumerate(novel_loader):  # x:(5, 20, 3, 224, 224)
        x = x.cuda()
        n_query = x.size(1) - n_support
        model.n_query = n_query
        yq = np.repeat(range(n_way), n_query)
        with torch.no_grad():
            scores = model.set_forward(x)  # (80, 5)
            _, topk_labels = scores.data.topk(1, 1, True, True)
            topk_ind = topk_labels.cpu().numpy()  # (80, 1)
            top1_correct = np.sum(topk_ind[:, 0] == yq)
            acc = top1_correct * 100. / (n_way * n_query)
            acc_all.append(acc)
        print('Task %d : %4.2f%%' % (ti, acc))

    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('Test Acc = %4.2f +- %4.2f%%' %
          (acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
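A minimal driver sketch for evaluate(), assuming the SetDataManager API used in the later examples on this page; the novel-split path is hypothetical:

# Hypothetical usage; loader construction mirrors the later examples.
novel_file = configs.data_dir['miniImagenet'] + 'novel.json'  # assumed path
datamgr = SetDataManager(224, n_way=5, n_support=5, n_query=15)
novel_loader = datamgr.get_data_loader(novel_file, aug=False)
evaluate(novel_loader, n_way=5, n_support=5)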
Example #2
def prepare_model(args):
    if args.method == 'protonet':
        from methods.protonet import ProtoNet
        model = ProtoNet(args)
    elif args.method == 'relationnet':
        from methods.relationnet import RelationNet
        model = RelationNet(args)
    else:
        raise ValueError('Unknown method: %s' % args.method)
    return model
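The same dispatch can be written as a table of lazy imports, which keeps the error path in one place. A hypothetical refactor (the _METHODS registry and prepare_model_v2 are not part of the source):

import importlib

_METHODS = {  # hypothetical registry mirroring the imports above
    'protonet': ('methods.protonet', 'ProtoNet'),
    'relationnet': ('methods.relationnet', 'RelationNet'),
}

def prepare_model_v2(args):
    try:
        module_name, class_name = _METHODS[args.method]
    except KeyError:
        raise ValueError('Unknown method: %s' % args.method)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(args)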
Example #3
def loadmodel(hook_fn):
    if settings.MODEL_FILE is None:
        model = torchvision.models.__dict__[settings.MODEL](pretrained=True)
    else:
        checkpoint = torch.load(settings.MODEL_FILE)
        if isinstance(checkpoint, dict):  # an OrderedDict state dict also passes this check
            if settings.MODEL == 'genotype':
                backbone.geneC = settings.CHANNELS  #params.channels
                backbone.geneLayers = settings.LAYERS  #params.layers
                backbone.geneName = settings.GENE_NAME  #params.gene_name
                model = ProtoNet(Genotype, settings.N_WAY, settings.N_SHOT)
                state_dict = checkpoint['state']
                model.load_state_dict(state_dict)
            else:
                model = torchvision.models.__dict__[settings.MODEL](
                    num_classes=settings.NUM_CLASSES)
                if settings.MODEL_PARALLEL:
                    state_dict = {
                        k.replace('module.', ''): v
                        for k, v in checkpoint['state_dict'].items()
                    }  # DataParallel prefixes every parameter name with 'module.'
                else:
                    state_dict = checkpoint
                model.load_state_dict(state_dict)
        else:
            model = checkpoint
    for name in settings.FEATURE_NAMES:
        if settings.MODEL == 'genotype':
            model._modules.get('feature')._modules.get('cells')._modules.get(
                '%d' % (settings.LAYER - 1)).register_forward_hook(hook_fn)
        else:
            model._modules.get(name).register_forward_hook(hook_fn)
    if settings.GPU:
        model.cuda()
    model.eval()
    return model
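loadmodel expects a forward-hook callable. A minimal sketch of one, using the standard register_forward_hook signature (module, input, output); the features list is illustrative:

features = []

def hook_fn(module, input, output):
    # keep a detached CPU copy of the hooked layer's activations
    features.append(output.detach().cpu())

model = loadmodel(hook_fn)  # features fills up on every forward pass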
Example #4
    def __init__(self, params):
        super(LFTNet, self).__init__()
        backbone.FeatureWiseTransformation2d_fw.feature_augment = True
        backbone.ConvBlock.FWT = True
        backbone.SimpleBlock.FWT = True
        backbone.ResNet.FWT = True

        if params.method == 'ProtoNet':
            model = ProtoNet(model_dict[params.model],
                             n_way=params.train_n_way,
                             n_support=params.n_shot)
        elif params.method == 'MatchingNet':
            backbone.LSTMCell.FWT = True
            model = MatchingNet(model_dict[params.model],
                                n_way=params.train_n_way,
                                n_support=params.n_shot)
        elif params.method == 'RelationNet':
            relationnet.RelationConvBlock.FWT = True
            relationnet.RelationModule.FWT = True
            model = relationnet.RelationNet(model_dict[params.model],
                                            n_way=params.train_n_way,
                                            n_support=params.n_shot)
        elif params.method == 'GNN':
            gnnnet.GnnNet.FWT = True
            gnn.Gconv.FWT = True
            gnn.Wcompute.FWT = True
            model = gnnnet.GnnNet(model_dict[params.model],
                                  n_way=params.train_n_way,
                                  n_support=params.n_shot)
        elif params.method == 'TPN':
            tpn.RelationNetwork.FWT = True
            model = tpn.TPN(model_dict[params.model],
                            n_way=params.train_n_way,
                            n_support=params.n_shot).cuda()
        else:
            raise ValueError('Unknown method')
        self.model = model
        print('\ttrain with {} framework'.format(params.method))

        # optimizer
        model_params = self.split_model_parameters()
        self.model_optim = torch.optim.Adam(model_params)

        # total epochs
        self.total_epoch = params.stop_epoch
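split_model_parameters is defined elsewhere in LFTNet. A plausible sketch, assuming the feature-wise-transformation parameters are identified by 'gamma' or 'beta' in their names and everything else goes to the Adam optimizer:

def split_model_parameters(self):
    # collect ordinary weights; skip the FWT scaling parameters (assumed naming)
    model_params = []
    for name, p in self.model.named_parameters():
        if 'gamma' not in name and 'beta' not in name:
            model_params.append(p)
    return model_params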
Example #5
            model = backbone.Conv4NP()
        elif params.model == 'Conv6':
            model = backbone.Conv6NP()
        elif params.model == 'Conv4S':
            model = backbone.Conv4SNP()
        else:
            model = model_dict[params.model](flatten=False)
    elif params.method in ['maml', 'maml_approx']:
        raise ValueError('MAML does not support saving features')
    else:
        train_few_shot_params    = dict(n_way = params.train_n_way, n_support = params.n_shot, \
                                        jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation, tracking=params.tracking)
        if params.method == 'protonet':
            print("USE BN:", not params.no_bn)
            model = ProtoNet(model_dict[params.model],
                             **train_few_shot_params,
                             use_bn=(not params.no_bn))
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                **train_few_shot_params)
        else:  # baseline and baseline++
            if isinstance(model_dict[params.model], str):
                if model_dict[params.model] == 'resnet18':
                    model = ResidualNet('ImageNet', 18, 1000, None)
            else:
                model = model_dict[params.model]()

    model = model.cuda()
    if params.method != 'baseline':
        model.feature = model.feature.cuda()
def main_train(params):
    _set_seed(params)

    results_logger = ResultsLogger(params)

    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'
    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        image_size = 224
    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
        params.model = 'Conv4S'
    optimization = 'Adam'
    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                params.stop_epoch = 200  # differs from the OpenReview paper; 400 epochs on the baseline actually leads to over-fitting
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  # default
        else:  # meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 600  # default
    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'num_classes must cover the max label id in the base classes'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'num_classes must cover the max label id in the base classes'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

    elif params.method in [
            'DKT', 'protonet', 'matchingnet', 'relationnet',
            'relationnet_softmax', 'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params)  # n_eposide=100
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if params.method == 'DKT':
            model = DKT(model_dict[params.model], **train_few_shot_params)
            model.init_summary()
        elif params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'

            model = RelationNet(feature_model,
                                loss_type=loss_type,
                                **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
            if params.dataset in [
                    'omniglot', 'cross_char'
            ]:  # MAML uses different hyperparameters on omniglot
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')
    model = model.cuda()
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                    params.n_shot)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # MAML runs multiple tasks per update
    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  # warmup from a pretrained baseline feature is supported, but was never used in the paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        if warmup_resume_file is not None:  # guard before torch.load, which cannot take None
            tmp = torch.load(warmup_resume_file)
            state = tmp['state']
            state_keys = list(state.keys())
            for i, key in enumerate(state_keys):
                if "feature." in key:
                    newkey = key.replace(
                        "feature.", ""
                    )  # the model stores its backbone under 'feature', so rename 'feature.trunk.xx' to 'trunk.xx' before loading into the backbone
                    state[newkey] = state.pop(key)
                else:
                    state.pop(key)
            model.feature.load_state_dict(state)
        else:
            raise ValueError('No warm_up file')

    model = train(base_loader, val_loader, model, optimization, start_epoch,
                  stop_epoch, params, results_logger)
    results_logger.save()
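_set_seed is not shown in this excerpt. A common implementation, assuming params carries a seed attribute:

import random

import numpy as np
import torch

def _set_seed(params):
    # seed every RNG the training loop touches, for reproducibility
    random.seed(params.seed)
    np.random.seed(params.seed)
    torch.manual_seed(params.seed)
    torch.cuda.manual_seed_all(params.seed)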
Example #7
def meta_test(novel_loader,
              n_query=15,
              pretrained_dataset='miniImageNet',
              freeze_backbone=False,
              n_pseudo=100,
              n_way=5,
              n_support=5):
    # Default arguments, kept for reference:
    #few_shot_params = {"n_way": 5, "n_support": 5}
    #pretrained_dataset = "miniImageNet"
    #n_pseudo = 100
    #n_way = 5  # five classes
    #n_support = 5  # each class contains 5 support images, so 25 support images in total
    #freeze_backbone = True
    #n_query = 15  # each class contains 15 query images, so 75 query images in total
    correct = 0
    count = 0

    iter_num = len(novel_loader)  #600

    acc_all = []
    for ti, (x, y) in enumerate(novel_loader):  # one iteration per task (600 total)

        ###############################################################################################
        # load pretrained model on miniImageNet
        if params.method == 'protonet':
            pretrained_model = ProtoNet(model_dict[params.model],
                                        n_way=n_way,
                                        n_support=n_support)
        elif 'mytpn' in params.method:
            pretrained_model = MyTPN(model_dict[params.model],
                                     n_way=n_way,
                                     n_support=n_support)
        else:
            raise ValueError('Unknown method: %s' % params.method)

        checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, pretrained_dataset, params.model, params.method)
        if params.train_aug:
            checkpoint_dir += '_aug'
        checkpoint_dir += '_5way_5shot'

        params.save_iter = -1  # forced to -1, so the best checkpoint is always loaded below
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        print(
            "load from %s" % (modelfile)
        )  #logs/checkpoints/miniImageNet/ResNet10_protonet_aug_5way_5shot/best_model.tar
        tmp = torch.load(modelfile)
        state = tmp['state']
        pretrained_model.load_state_dict(state)  #load checkpoints to model
        pretrained_model.cuda()
        ###############################################################################################
        # split data into support set and query set
        n_query = x.size(1) - n_support  #20-5=15

        x = x.cuda()  ##torch.Size([5, 20, 3, 224, 224])
        x_var = Variable(x)

        support_size = n_way * n_support  #25

        y_a_i = Variable(torch.from_numpy(np.repeat(
            range(n_way), n_support))).cuda()  # (25,)

        x_b_i = x_var[:, n_support:, :, :, :].contiguous().view(
            n_way * n_query,
            *x.size()[2:])  # query set (75,3,224,224)
        x_a_i = x_var[:, :n_support, :, :, :].contiguous().view(
            n_way * n_support,
            *x.size()[2:])  # support set (25,3,224,224)

        if not freeze_backbone:
            ###############################################################################################
            # Finetune components initialization
            pseudo_q_generator = PseudoQeuryGenerator(n_way, n_support,
                                                      n_pseudo)
            delta_opt = torch.optim.Adam(
                filter(lambda p: p.requires_grad,
                       pretrained_model.parameters()))

            ###############################################################################################
            # finetune process
            finetune_epoch = 100

            fine_tune_n_query = n_pseudo // n_way  # 100//5 =20
            pretrained_model.n_query = fine_tune_n_query  #20
            pretrained_model.train()

            z_support = x_a_i.view(n_way, n_support,
                                   *x_a_i.size()[1:])  #(5,5,3,224,224)

            for epoch in range(finetune_epoch):  # 100 epochs
                delta_opt.zero_grad()  #clear feature extractor gradient

                # generate pseudo query images
                pseudo_query_set, _ = pseudo_q_generator.generate(x_a_i)
                pseudo_query_set = pseudo_query_set.cuda().view(
                    n_way, fine_tune_n_query,
                    *x_a_i.size()[1:])  # (5, 20, 3, 224, 224)

                x = torch.cat((z_support, pseudo_query_set), dim=1)

                loss = pretrained_model.set_forward_loss(x)
                loss.backward()
                delta_opt.step()

        ###############################################################################################
        # inference

        pretrained_model.eval()

        pretrained_model.n_query = n_query  #15
        with torch.no_grad():
            scores = pretrained_model.set_forward(
                x_var.cuda())  #set_forward in protonet.py

        y_query = np.repeat(range(n_way),
                            n_query)  #[0,...0, ...4,...4] with shape (75)
        topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
        # topk(k=1, dim=1, largest=True, sorted=True): row-wise top-1 predictions
        topk_ind = topk_labels.cpu().numpy()

        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        correct_this, count_this = float(top1_correct), len(y_query)

        acc_all.append((correct_this / count_this * 100))
        print("Task %d : %4.2f%%  Now avg: %4.2f%%" %
              (ti, correct_this / count_this * 100, np.mean(acc_all)))
        ###############################################################################################
    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('%d Test Acc = %4.2f%% +- %4.2f%%' %
          (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
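PseudoQeuryGenerator (spelling as in the source library) is defined elsewhere. A rough, purely illustrative sketch of what such a generator could do, resampling and lightly perturbing the support images to produce n_pseudo pseudo queries:

import torch

class PseudoQueryGeneratorSketch:
    def __init__(self, n_way, n_support, n_pseudo):
        self.n_way, self.n_support = n_way, n_support
        self.n_per_class = n_pseudo // n_way

    def generate(self, support):  # support: (n_way * n_support, C, H, W)
        x = support.view(self.n_way, self.n_support, *support.size()[1:])
        # resample support images per class and add light noise
        idx = torch.randint(self.n_support, (self.n_way, self.n_per_class))
        pseudo = torch.stack([x[c, idx[c]] for c in range(self.n_way)])
        pseudo = pseudo + 0.01 * torch.randn_like(pseudo)
        labels = torch.arange(self.n_way).repeat_interleave(self.n_per_class)
        return pseudo.view(-1, *support.size()[1:]), labels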
Example #8
            # tx, _ = next(tl)
            base_loader = [base_loader, target_loader]

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot_test)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if params.method == 'protonet':
            if params.adversarial:
                model = ProtoNet(model_dict[params.model],
                                 params.test_n_way,
                                 params.n_shot,
                                 discriminator=backbone.Disc_model(
                                     params.train_n_way),
                                 cosine=params.cosine)
            elif params.adaptFinetune:
                assert not params.adversarial
                model = ProtoNet(
                    model_dict[params.model],
                    params.test_n_way,
                    params.n_shot,
                    adaptive_classifier=backbone.Adaptive_Classifier(
                        params.train_n_way),
                    cosine=params.cosine)
            else:
                model = ProtoNet(model_dict[params.model],
                                 params.test_n_way,
                                 params.n_shot,
Example #9
        base_loader_l = base_datamgr_l.get_data_loader(base_file,
                                                       aug=params.train_aug)

        test_few_shot_params     = dict(n_way = params.test_n_way, n_support = params.n_shot, \
                                        jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params,
                                     isAircraft=isAircraft,
                                     grey=params.grey)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model],
                             **train_few_shot_params,
                             use_bn=(not params.no_bn),
                             pretrain=params.pretrain)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
Example #10
                                    n_query=n_query,
                                    **train_few_shot_params)
    target_loader = target_datamgr.get_data_loader(novel_file, aug=False)
    base_loader = [base_loader, target_loader]

    test_few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot_test)
    val_datamgr = SetDataManager(image_size,
                                 n_query=n_query,
                                 **test_few_shot_params)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    if params.method == 'protonet':
        modelT = ProtoNet(model_dict[params.model],
                          params.test_n_way,
                          params.n_shot,
                          discriminator=backbone.Disc_model(
                              params.train_n_way))
        modelS = ProtoNet(model_dict[params.model], params.test_n_way,
                          params.n_shot)
    # pre_train or warm start
    if params.load_modelpth:
        modelS = load_model(modelS, params.load_modelpth)
        modelT = load_model(modelT, params.load_modelpth)
        print('preloading: ', params.load_modelpth)

    featexS = modelS.feature.to(DEVICE1)
    modelT = modelT.to(DEVICE2)

    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
def calculate_dist(n_way, n_shot, task_num, episode, batch_size=64):

    model = ProtoNet(model_dict["ResNet18"], n_way, n_shot)

    device = "cuda:0"
    with torch.cuda.device(device):
        model = model.cuda()

    resume_file = "/home/takumi/research/CloserLookFewShot/instance_selection/feature_space/{0}.tar".format(
        episode)
    tmp = torch.load(resume_file)
    model.load_state_dict(tmp['state'])
    model.eval()  # freeze BatchNorm statistics before extracting features

    base_file = "/home/takumi/research/CloserLookFewShot/filelists/full_Imagenet_except_testclass.json"

    candidate_data_manager = SimpleDataManager(224, batch_size)

    candidate_data_loader = candidate_data_manager.get_data_loader(
        base_file, aug=False, shuffle=False)

    task_file = "/home/takumi/research/CloserLookFewShot/instance_selection/task/few_shot_task{0}.json".format(
        task_num)

    few_image_list = task_train_reader(task_file, n_shot, n_way)

    few_image_feature_list = []
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        for images in few_image_list:
            images = images.to(device)
            features2 = model.feature(images)
            features = torch.mean(features2, dim=0)
            few_image_feature_list.append(features)

    dim = few_image_feature_list[0].size()[0]
    features_ave = torch.zeros(0, dim).to(device)
    for x in few_image_feature_list:
        features_ave = torch.cat([features_ave, x.unsqueeze(0)], dim=0)

    with open(base_file) as f:
        base = json.load(f)

    image_dists = []

    image_id = 0
    for x, labels in tqdm.tqdm(candidate_data_loader):
        x = x.to(device)
        with torch.no_grad():
            y = model.feature(x)
        dist = euclidean_dist(y, features_ave)
        dist_min, _ = torch.min(dist, dim=1)

        dist_min = dist_min.tolist()
        labels = labels.tolist()

        for i in range(len(dist_min)):
            image_dists.append(
                [dist_min[i], labels[i], base["image_names"][image_id]])
            image_id += 1

    image_dists.sort()

    task_image_dist = dict()

    task_image_dist["label_names"] = copy.deepcopy(base["label_names"])
    task_image_dist["image_names"] = []
    task_image_dist["image_labels"] = []
    task_image_dist["distance"] = []

    for i in range(len(image_dists)):
        task_image_dist["image_names"].append(image_dists[i][2])
        task_image_dist["image_labels"].append(image_dists[i][1])
        task_image_dist["distance"].append(image_dists[i][0])

    with open(
            "/home/takumi/research/CloserLookFewShot/instance_selection/task/task{0}_dataset_dist_{1}.json"
            .format(task_num, episode), "w") as f:
        json.dump(task_image_dist, f)

    print(len(task_image_dist["image_names"]),
          len(task_image_dist["image_labels"]),
          len(task_image_dist["distance"]))
Example #12
def main():
    timer = Timer()
    args, writer = init()

    train_file = args.dataset_dir + 'train.json'
    val_file = args.dataset_dir + 'val.json'

    few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot, n_query=args.n_query)
    n_episode = 10 if args.debug else 100
    if args.method_type is Method_type.baseline:
        train_datamgr = SimpleDataManager(train_file, args.dataset_dir, args.image_size, batch_size=64)
        train_loader = train_datamgr.get_data_loader(aug = True)
    else:
        train_datamgr = SetDataManager(train_file, args.dataset_dir, args.image_size,
                                       n_episode=n_episode, mode='train', **few_shot_params)
        train_loader = train_datamgr.get_data_loader(aug=True)

    val_datamgr = SetDataManager(val_file, args.dataset_dir, args.image_size,
                                     n_episode=n_episode, mode='val', **few_shot_params)
    val_loader = val_datamgr.get_data_loader(aug=False)

    if args.model_type is Model_type.ConvNet:
        raise NotImplementedError('ConvNet encoder is not wired up here')  # 'pass' would leave encoder undefined
    elif args.model_type is Model_type.ResNet12:
        from methods.backbone import ResNet12
        encoder = ResNet12()
    else:
        raise ValueError('Unknown model type: {}'.format(args.model_type))

    if args.method_type is Method_type.baseline:
        from methods.baselinetrain import BaselineTrain
        model = BaselineTrain(encoder, args)
    elif args.method_type is Method_type.protonet:
        from methods.protonet import ProtoNet
        model = ProtoNet(encoder, args)
    else:
        raise ValueError('Unknown method type: {}'.format(args.method_type))

    from torch.optim import SGD, lr_scheduler
    if args.method_type is Method_type.baseline:
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0, last_epoch=-1)
    else:
        optimizer = torch.optim.SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4,
                                    nesterov=True)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)

    args.ngpu = torch.cuda.device_count()
    torch.backends.cudnn.benchmark = True
    model = model.cuda()

    label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
    label = label.cuda()

    if args.test:
        test(model, label, args, few_shot_params)
        return

    if args.resume:
        resume_OK = resume_model(model, optimizer, args, scheduler)
    else:
        resume_OK = False
    if (not resume_OK) and (args.warmup is not None):
        load_pretrained_weights(model, args)

    if args.debug:
        args.max_epoch = args.start_epoch + 1

    for epoch in range(args.start_epoch, args.max_epoch):
        train_one_epoch(model, optimizer, args, train_loader, label, writer, epoch)
        scheduler.step()

        vl, va = val(model, args, val_loader, label)
        if writer is not None:
            writer.add_scalar('data/val_acc', float(va), epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        if va >= args.max_acc:
            args.max_acc = va
            print('saving the best model! acc={:.4f}'.format(va))
            save_model(model, optimizer, args, epoch, args.max_acc, 'max_acc', scheduler)
        save_model(model, optimizer, args, epoch, args.max_acc, 'epoch-last', scheduler)
        if epoch != 0:
            print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
    if writer is not None:
        writer.close()
    test(model, label, args, few_shot_params)
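resume_model and save_model are defined elsewhere. A minimal matching pair consistent with how they are called above; the checkpoint keys and args.save_dir are assumptions:

import os

import torch

def save_model(model, optimizer, args, epoch, max_acc, name, scheduler):
    torch.save({
        'epoch': epoch,
        'max_acc': max_acc,
        'state': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }, os.path.join(args.save_dir, name + '.pth'))

def resume_model(model, optimizer, args, scheduler):
    ckpt_path = os.path.join(args.save_dir, 'epoch-last.pth')
    if not os.path.exists(ckpt_path):
        return False
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt['state'])
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])
    args.start_epoch = ckpt['epoch'] + 1
    args.max_acc = ckpt['max_acc']
    return True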
Example #13
def explain_gnnnet():
    params = options.parse_args('test')
    feature_model = backbone.model_dict['ResNet10']
    params.method = 'gnnnet'
    params.dataset = 'miniImagenet'  # name relationnet --testset miniImagenet
    params.name = 'gnn'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    n_query = 1
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)

    # model
    print('  build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone.model_dict[params.model],
                            **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = backbone.model_dict[params.model]
        loss_type = 'LRP'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
    else:
        raise ValueError('Unknown method')

    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    # print(checkpoint_dir)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
        # print(modelfile)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
            print('loaded pretrained model')
        except RuntimeError:
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            for k in tmp['model_state']:  # TODO: revise later
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)
        except:
            raise

    model = model.cuda()
    model.eval()
    model.n_query = n_query
    # for module in model.modules():
    #   print(type(module))
    lrp_preset = lrp_presets.SequentialPresetA()
    feature_model = model.feature
    fc_encoder = model.fc
    gnn_net = model.gnn
    lrp_wrapper.add_lrp(fc_encoder, lrp_preset)

    # acc = 0
    # count = 0
    # the commented loop below verified the forward pass by checking accuracy
    # for i, (x, _, _) in enumerate(data_loader):
    #   x = x.cuda()
    #   support_label = torch.from_numpy(np.repeat(range(model.n_way), model.n_support)).unsqueeze(1)
    #   support_label = torch.zeros(model.n_way*model.n_support, model.n_way).scatter(1, support_label, 1).view(model.n_way, model.n_support, model.n_way)
    #   support_label = torch.cat([support_label, torch.zeros(model.n_way, 1, model.n_way)], dim=1)
    #   support_label = support_label.view(1, -1, model.n_way)
    #   support_label = support_label.cuda()
    #   x = x.view(-1, *x.size()[2:])
    #
    #   x_feature = feature_model(x)
    #   x_fc_encoded = fc_encoder(x_feature)
    #   z = x_fc_encoded.view(model.n_way, -1, x_fc_encoded.size(1))
    #   gnn_feature = [
    #     torch.cat([z[:, :model.n_support], z[:, model.n_support + i:model.n_support + i + 1]], dim=1).view(1, -1, z.size(2))
    #     for i in range(model.n_query)]
    #   gnn_nodes = torch.cat([torch.cat([z, support_label], dim=2) for z in gnn_feature], dim=0)
    #   scores = gnn_net(gnn_nodes)
    #   scores = scores.view(model.n_query, model.n_way, model.n_support + 1, model.n_way)[:, :, -1].permute(1, 0,
    #                                                                                                    2).contiguous().view(
    #     -1, model.n_way)
    #   pred = scores.data.cpu().numpy().argmax(axis=1)
    #   y = np.repeat(range(model.n_way), n_query)
    #   acc += np.sum(pred == y)
    #   count += len(y)
    #   # print(1.0*acc/count)
    # print(1.0*acc/count)
    with open(
            '/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json',
            'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations',
                                        params.name)
    if not os.path.isdir(explanation_save_dir):
        os.makedirs(explanation_save_dir)
    for batch_idx, (x, y, p) in enumerate(data_loader):
        print(p)
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        x = x.cuda()
        support_label = torch.from_numpy(
            np.repeat(range(model.n_way),
                      model.n_support)).unsqueeze(1)  #torch.Size([25, 1])
        support_label = torch.zeros(model.n_way * model.n_support,
                                    model.n_way).scatter(1, support_label,
                                                         1).view(
                                                             model.n_way,
                                                             model.n_support,
                                                             model.n_way)
        support_label = torch.cat(
            [support_label,
             torch.zeros(model.n_way, 1, model.n_way)], dim=1)
        support_label = support_label.view(1, -1, model.n_way)
        support_label = support_label.cuda()  #torch.Size([1, 30, 5])
        x = x.contiguous()
        x = x.view(-1, *x.size()[2:])  #torch.Size([30, 3, 224, 224])
        x_feature = feature_model(x)  #torch.Size([30, 512])
        x_fc_encoded = fc_encoder(x_feature)  #torch.Size([30, 128])
        z = x_fc_encoded.view(model.n_way, -1,
                              x_fc_encoded.size(1))  # (5,6,128)
        gnn_feature = [
            torch.cat([
                z[:, :model.n_support],
                z[:, model.n_support + i:model.n_support + i + 1]
            ],
                      dim=1).view(1, -1, z.size(2))
            for i in range(model.n_query)
        ]  # model.n_query is the number of query images for each class
        # gnn_feature has n_query groups; each group holds the support features concatenated with one query image's features.
        # print(len(gnn_feature), gnn_feature[0].shape)
        gnn_nodes = torch.cat(
            [torch.cat([z, support_label], dim=2) for z in gnn_feature], dim=0
        )  # features are concatenated with the one-hot label; for the unknown query image the one-hot label is all zeros

        #  perform gnn_net step by step
        #  the first iteration
        print('x', gnn_nodes.shape)
        W_init = torch.eye(
            gnn_nodes.size(1), device=gnn_nodes.device
        ).unsqueeze(0).repeat(gnn_nodes.size(0), 1, 1).unsqueeze(
            3
        )  # (n_query, n_way*(n_support + 1), n_way*(n_support + 1), 1)
        # print(W_init.shape)

        W1 = gnn_net._modules['layer_w{}'.format(0)](
            gnn_nodes, W_init
        )  # (n_query, n_way*(n_support + 1), n_way*(n_support + 1), 2)
        # print(Wi.shape)
        x_new1 = F.leaky_relu(gnn_net._modules['layer_l{}'.format(0)](
            [W1, gnn_nodes])[1])  # (n_query, n_support + 1, num_outputs)
        # print(x_new1.shape)  #torch.Size([1, 30, 48])
        gnn_nodes_1 = torch.cat([gnn_nodes, x_new1],
                                2)  # (concat more features)
        # print('gn1',gnn_nodes_1.shape) #torch.Size([1, 30, 181])

        #  the second iteration
        W2 = gnn_net._modules['layer_w{}'.format(1)](
            gnn_nodes_1, W_init
        )  # (n_query, n_way*(n_support + 1), n_way*(n_support + 1), 2)
        x_new2 = F.leaky_relu(gnn_net._modules['layer_l{}'.format(1)](
            [W2,
             gnn_nodes_1])[1])  # (n_query, n_support + 1, num_outputs)
        # print(x_new2.shape)
        gnn_nodes_2 = torch.cat([gnn_nodes_1, x_new2],
                                2)  # (concat more features)
        # print('gn2', gnn_nodes_2.shape)  #torch.Size([1, 30, 229])

        Wl = gnn_net.w_comp_last(gnn_nodes_2, W_init)
        # print(Wl.shape)  #torch.Size([1, 30, 30, 2])
        scores = gnn_net.layer_last(
            [Wl, gnn_nodes_2])[1]  # (n_query, n_support + 1, n_way)
        print(scores.shape)

        scores_sf = torch.softmax(scores, dim=-1)
        # print(scores_sf)

        gnn_logits = torch.log(LRPutil.LOGIT_BETA * scores_sf /
                               (1 - scores_sf))
        gnn_logits = gnn_logits.view(-1, model.n_way,
                                     model.n_support + n_query, model.n_way)
        # print(gnn_logits)
        query_scores = scores.view(
            model.n_query, model.n_way, model.n_support + 1,
            model.n_way)[:, :,
                         -1].permute(1, 0,
                                     2).contiguous().view(-1, model.n_way)
        preds = query_scores.data.cpu().numpy().argmax(axis=-1)
        # print(preds.shape)
        for k in range(model.n_way):
            mask = torch.zeros(5).cuda()
            mask[k] = 1
            gnn_logits_cls = gnn_logits.clone()
            gnn_logits_cls[:, :, -1] = gnn_logits_cls[:, :, -1] * mask
            # print(gnn_logits_cls)
            # print(gnn_logits_cls.shape)
            gnn_logits_cls = gnn_logits_cls.view(-1, model.n_way)
            relevance_gnn_nodes_2 = explain_Gconv(gnn_logits_cls,
                                                  gnn_net.layer_last, Wl,
                                                  gnn_nodes_2)
            relevance_x_new2 = relevance_gnn_nodes_2.narrow(-1, 181, 48)
            # relevance_gnn_nodes = relevance_gnn_nodes_2
            relevance_gnn_nodes_1 = explain_Gconv(
                relevance_x_new2, gnn_net._modules['layer_l{}'.format(1)], W2,
                gnn_nodes_1)
            relevance_x_new1 = relevance_gnn_nodes_1.narrow(-1, 133, 48)
            relevance_gnn_nodes = explain_Gconv(
                relevance_x_new1, gnn_net._modules['layer_l{}'.format(0)], W1,
                gnn_nodes)
            relevance_gnn_features = relevance_gnn_nodes.narrow(-1, 0, 128)
            print(relevance_gnn_features.shape)
            relevance_gnn_features += relevance_gnn_nodes_1.narrow(-1, 0, 128)
            relevance_gnn_features += relevance_gnn_nodes_2.narrow(
                -1, 0, 128)  #[2, 30, 128]
            relevance_gnn_features = relevance_gnn_features.view(
                n_query, model.n_way, model.n_support + 1, 128)
            for i in range(n_query):
                query_i = relevance_gnn_features[i][:, model.n_support:model.n_support + 1]
                if i == 0:
                    relevance_z = query_i
                else:
                    relevance_z = torch.cat((relevance_z, query_i), 1)
            relevance_z = relevance_z.view(-1, 128)
            query_feature = x_feature.view(model.n_way, -1,
                                           512)[:, model.n_support:]
            # print(query_feature.shape)
            query_feature = query_feature.contiguous()
            query_feature = query_feature.view(n_query * model.n_way, 512)
            # print(query_feature.shape)
            relevance_query_features = fc_encoder.compute_lrp(
                query_feature, target=relevance_z)
            # print(relevance_query_features.shape)
            # print(relevance_gnn_features.shape)
            # explain the fc layer and the image encoder
            query_images = x.view(model.n_way, -1,
                                  *x.size()[1:])[:, model.n_support:]
            query_images = query_images.contiguous()

            query_images = query_images.view(-1, *x.size()[1:]).detach()
            # print(query_images.shape)
            lrp_wrapper.add_lrp(feature_model, lrp_preset)
            relevance_query_images = feature_model.compute_lrp(
                query_images, target=relevance_query_features)
            print(relevance_query_images.shape)

            for j in range(n_query * model.n_way):
                predict_class = label_to_readableclass[preds[j]]
                true_class = query_gt_class[int(j % model.n_way)][int(
                    j // model.n_way)]
                explain_class = label_to_readableclass[k]
                img_name = query_img_path[int(j % model.n_way)][int(
                    j // model.n_way)].split('/')[-1]
                img_stem = os.path.splitext(img_name)[0]  # strip('.jpg') would also trim leading/trailing 'j', 'p', 'g' characters
                save_path = os.path.join(explanation_save_dir,
                                         'episode' + str(batch_idx), img_stem)
                if not os.path.isdir(save_path):
                    os.makedirs(save_path)
                if not os.path.exists(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name)):
                    original_img = Image.fromarray(
                        np.uint8(
                            project(query_images[j].permute(
                                1, 2, 0).detach().cpu().numpy())))
                    original_img.save(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name))

                img_relevance = relevance_query_images.narrow(0, j, 1)
                print(predict_class, true_class, explain_class)
                # assert relevance_query_cls[j].sum() != 0
                # assert img_relevance.sum() != 0
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(
                        save_path,
                        true_class + '_' + explain_class + '_lrp_hm.jpg'))

        break
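The project helper used above (and in the next example) is not shown; a plausible version, assuming it linearly rescales an array to [0, 255] for saving with PIL:

def project(img):
    # rescale to the 0..255 range so np.uint8(project(...)) yields a valid image
    img = img - img.min()
    if img.max() > 0:
        img = img / img.max()
    return img * 255.0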
Example #14
def explain_relationnet():
    # print(sys.path)
    params = options.parse_args('test')
    feature_model = backbone.model_dict['ResNet10']
    params.method = 'relationnet'
    params.dataset = 'miniImagenet'  # name relationnet --testset miniImagenet
    params.name = 'relationnet'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    n_query = 1
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)

    acc_all = []
    iter_num = 1000

    # model
    print('  build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone.model_dict[params.model],
                            **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = backbone.model_dict[params.model]
        loss_type = 'LRPmse'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
    else:
        raise ValueError('Unknown method')

    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    # print(checkpoint_dir)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
        # print(modelfile)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
        except RuntimeError:
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            for k in tmp['model_state']:  # TODO: revise later
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)
        except:
            raise

    model = model.cuda()
    model.eval()
    model.n_query = n_query
    # ---test the accuracy on the test set to verify the model is loaded----
    acc = 0
    count = 0
    # for i, (x, y) in enumerate(data_loader):
    #   scores = model.set_forward(x)
    #   pred = scores.data.cpu().numpy().argmax(axis=1)
    #   y = np.repeat(range(model.n_way), n_query)
    #   acc += np.sum(pred == y)
    #   count += len(y)
    #   # print(1.0*acc/count)
    # print(1.0*acc/count)
    preset = lrp_presets.SequentialPresetA()

    feature_model = copy.deepcopy(model.feature)
    lrp_wrapper.add_lrp(feature_model, preset=preset)
    relation_model = copy.deepcopy(model.relation_module)
    # print(relation_model)
    lrp_wrapper.add_lrp(relation_model, preset=preset)
    with open(
            '/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json',
            'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations',
                                        params.name)
    if not os.path.isdir(explanation_save_dir):
        os.makedirs(explanation_save_dir)
    for i, (x, y, p) in enumerate(data_loader):
        '''x: images of shape (n_way, n_support + n_query, 3, img_size, img_size)
        y: global labels of shape (n_way, n_support + n_query)
        p: image paths, a list of n_support + n_query tuples, each of length n_way'''
        if i >= 3:
            break
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        z_support, z_query = model.parse_feature(x, is_feature=False)
        z_support = z_support.contiguous()
        z_proto = z_support.view(model.n_way, model.n_support,
                                 *model.feat_dim).mean(1)
        # print(z_proto.shape)
        z_query = z_query.contiguous().view(model.n_way * model.n_query,
                                            *model.feat_dim)
        # print(z_query.shape)
        # get relations with metric function
        z_proto_ext = z_proto.unsqueeze(0).repeat(model.n_query * model.n_way,
                                                  1, 1, 1, 1)
        # print(z_proto_ext.shape)
        z_query_ext = z_query.unsqueeze(0).repeat(model.n_way, 1, 1, 1, 1)

        z_query_ext = torch.transpose(z_query_ext, 0, 1)
        # print(z_query_ext.shape)
        extend_final_feat_dim = model.feat_dim.copy()
        extend_final_feat_dim[0] *= 2
        relation_pairs = torch.cat((z_proto_ext, z_query_ext),
                                   2).view(-1, *extend_final_feat_dim)
        # print(relation_pairs.shape)
        relations = relation_model(relation_pairs)
        # print(relations)
        scores = relations.view(-1, model.n_way)
        preds = scores.data.cpu().numpy().argmax(axis=1)
        # print(preds.shape)
        relations = relations.view(-1, model.n_way)
        # print(relations)
        relations_sf = torch.softmax(relations, dim=-1)
        # print(relations_sf)
        relations_logits = torch.log(LRPutil.LOGIT_BETA * relations_sf /
                                     (1 - relations_sf))
        # print(relations_logits)
        # print(preds)
        relations_logits = relations_logits.view(-1, 1)
        relevance_relations = relation_model.compute_lrp(
            relation_pairs, target=relations_logits)
        # print(relevance_relations.shape)
        # print(model.feat_dim)
        relevance_z_query = torch.narrow(relevance_relations, 1,
                                         model.feat_dim[0], model.feat_dim[0])
        # print(relevance_z_query.shape)
        relevance_z_query = relevance_z_query.view(
            model.n_query * model.n_way, model.n_way,
            *relevance_z_query.size()[1:])
        # print(relevance_z_query.shape)
        query_img = x.narrow(1, model.n_support,
                             model.n_query).view(model.n_way * model.n_query,
                                                 *x.size()[2:])
        # query_img_copy = query_img.view(model.n_way, model.n_query, *x.size()[2:])
        # print(query_img.shape)
        for k in range(model.n_way):
            relevance_query_cls = torch.narrow(relevance_z_query, 1, k,
                                               1).squeeze(1)
            # print(relevance_query_cls.shape)
            relevance_query_img = feature_model.compute_lrp(
                query_img.cuda(), target=relevance_query_cls)
            # print(relevance_query_img.max(), relevance_query_img.min())
            # print(relevance_query_img.shape)
            for j in range(model.n_query * model.n_way):
                predict_class = label_to_readableclass[preds[j]]
                true_class = query_gt_class[int(j % model.n_way)][int(
                    j // model.n_way)]
                explain_class = label_to_readableclass[k]
                img_name = query_img_path[int(j % model.n_way)][int(
                    j // model.n_way)].split('/')[-1]
                img_stem = os.path.splitext(img_name)[0]  # strip('.jpg') would also trim leading/trailing 'j', 'p', 'g' characters
                save_path = os.path.join(explanation_save_dir,
                                         'episode' + str(i), img_stem)
                if not os.path.isdir(save_path):
                    os.makedirs(save_path)
                if not os.path.exists(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name)):
                    original_img = Image.fromarray(
                        np.uint8(
                            project(query_img[j].permute(1, 2,
                                                         0).cpu().numpy())))
                    original_img.save(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name))

                img_relevance = relevance_query_img.narrow(0, j, 1)
                print(predict_class, true_class, explain_class)
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(
                        save_path,
                        true_class + '_' + explain_class + '_lrp_hm.jpg'))
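For reference, the project() helper used above is not shown in this snippet; it presumably rescales an array into the displayable 0-255 range before the uint8 cast. A minimal sketch under that assumption (the actual implementation lives elsewhere in this example's repository):

import numpy as np

def project(arr, eps=1e-12):
    # hypothetical helper: linearly map an array onto [0, 255] for image export
    arr = arr - arr.min()
    return arr / (arr.max() + eps) * 255.0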
Ejemplo n.º 15
0
        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(n_query=n_query, **train_few_shot_params)
        base_loader = base_datamgr.get_data_loader(
            root='./filelists/tabula_muris', mode='train')
        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(n_query=n_query, **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(
            root='./filelists/tabula_muris', mode='val')
        # a batch from SetDataManager is a [n_way, n_support + n_query, dim, w, h] tensor
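        # e.g. n_way=5, n_support=5, n_query=16 yields a (5, 21, dim, w, h) batch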
        go_mask = base_loader.dataset.go_mask
        x_dim = base_loader.dataset.get_dim()

        if params.method == 'protonet':
            model = ProtoNet(backbone.FCNet(x_dim), **train_few_shot_params)
        elif params.method == 'comet':
            model = COMET(backbone.EnFCNet(x_dim, go_mask),
                          **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(backbone.FCNet(x_dim), **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:

            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'

            model = RelationNet(backbone.FCNet(x_dim),
                                loss_type=loss_type,
                                **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            model = MAML(backbone.FCNet(x_dim),
                         approx=(params.method == 'maml_approx'),
Ejemplo n.º 16
0
    base_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **train_few_shot_params)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)

    test_few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot)
    val_datamgr = SetDataManager(image_size,
                                 n_query=n_query,
                                 **test_few_shot_params)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    if params.method == 'manifold_mixup':
        model = wrn_mixup_model.wrn28_10(64)
    elif params.method == 'S2M2_R':
        model = ProtoNet(model_dict[params.model], params.train_n_way,
                         params.n_shot)
    elif params.method == 'rotation':
        model = BaselineTrain(model_dict[params.model], 64, loss_type='dist')

    if params.method == 'S2M2_R':

        if use_gpu:
            if torch.cuda.device_count() > 1:
                model = torch.nn.DataParallel(model,
                                              device_ids=range(
                                                  torch.cuda.device_count()))
            model.cuda()

        if params.resume:
            resume_file = get_resume_file(params.checkpoint_dir)
            print("resume_file", resume_file)
Ejemplo n.º 17
0
    base_datamgr = SetDataManager(image_size,
                                  n_query=params.n_query,
                                  **train_few_shot_params,
                                  num_views=params.num_views)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)

    test_few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot)
    val_datamgr = SetDataManager(image_size,
                                 n_query=params.n_query,
                                 **test_few_shot_params,
                                 num_views=params.num_views)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    backbone = model_dict[params.model]
    model = ProtoNet(backbone, params.num_views, **train_few_shot_params)
    model = model.cuda()
    # model = torch.nn.DataParallel(model).cuda()

    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    params.checkpoint_dir += '_%dway_%dshot_%dviews_lr%f' % (
        params.train_n_way, params.n_shot, params.num_views, params.lr)
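    # e.g. '<save_dir>/checkpoints/miniImagenet/ResNet10_protonet_aug_5way_5shot_4views_lr0.001000'
    # (illustrative values only; the pattern above is filled in from params)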
    os.makedirs(params.checkpoint_dir, exist_ok=True)

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
Ejemplo n.º 18
0
def get_logits_targets(params):
    acc_all = []
    iter_num = 600
    few_shot_params = dict(n_way = params.test_n_way , n_support = params.n_shot) 

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
        params.model = 'Conv4S'

    if params.method == 'baseline':
        model           = BaselineFinetune( model_dict[params.model], **few_shot_params )
    elif params.method == 'baseline++':
        model           = BaselineFinetune( model_dict[params.model], loss_type = 'dist', **few_shot_params )
    elif params.method == 'protonet':
        model           = ProtoNet( model_dict[params.model], **few_shot_params )
    elif params.method == 'DKT':
        model           = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model           = MatchingNet( model_dict[params.model], **few_shot_params )
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4': 
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6': 
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S': 
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model]( flatten = False )
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model           = RelationNet( feature_model, loss_type = loss_type , **few_shot_params )
    elif params.method in ['maml' , 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(  model_dict[params.model], approx = (params.method == 'maml_approx') , **few_shot_params )
        if params.dataset in ['omniglot', 'cross_char']:  # MAML uses different hyperparameters on omniglot
            model.n_task     = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    checkpoint_dir = '%s/checkpoints/%s/%s_%s' %(configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' %( params.train_n_way, params.n_shot)

    if params.method not in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile   = get_assigned_file(checkpoint_dir,params.save_iter)
        else:
            modelfile   = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " + str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" +str(params.save_iter)
    else:
        split_str = split
    if params.method in ['maml', 'maml_approx', 'DKT']:  # MAML does not support testing on precomputed features
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84 
        else:
            image_size = 224

        datamgr         = SetDataManager(image_size, n_eposide = iter_num, n_query = 15 , **few_shot_params)
        
        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json' 
            else:
                loadfile   = configs.data_dir['CUB'] + split +'.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json' 
            else:
                loadfile  = configs.data_dir['emnist'] + split +'.json' 
        else: 
            loadfile    = configs.data_dir[params.dataset] + split + '.json'

        novel_loader     = datamgr.get_data_loader( loadfile, aug = False)
        if params.adaptation:
            model.task_update_num = 100 #We perform adaptation on MAML simply by updating more times.
        model.eval()

        logits_list = list()
        targets_list = list()
        for i, (x, _) in enumerate(novel_loader):
            logits = model.get_logits(x).detach()
            targets = torch.tensor(np.repeat(range(params.test_n_way), model.n_query)).cuda()
            logits_list.append(logits)
            targets_list.append(targets)
    else:
        novel_file = os.path.join( checkpoint_dir.replace("checkpoints","features"), split_str +".hdf5")
        cl_data_file = feat_loader.init_loader(novel_file)
        logits_list = list()
        targets_list = list()
        n_query = 15
        n_way = few_shot_params['n_way']
        n_support = few_shot_params['n_support']
        class_list = list(cl_data_file.keys())  # random.sample needs a sequence, not a dict view
        for i in range(iter_num):
            #----------------------
            select_class = random.sample(class_list,n_way)
            z_all  = []
            for cl in select_class:
                img_feat = cl_data_file[cl]
                perm_ids = np.random.permutation(len(img_feat)).tolist()
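                # take the first n_support + n_query shuffled features per class to form one episode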
                z_all.append( [ np.squeeze( img_feat[perm_ids[i]]) for i in range(n_support+n_query) ] )     # stack each batch
            z_all = torch.from_numpy(np.array(z_all))
            model.n_query = n_query
            logits  = model.set_forward(z_all, is_feature = True).detach()
            targets = torch.tensor(np.repeat(range(n_way), n_query)).cuda()
            logits_list.append(logits)
            targets_list.append(targets)
            #----------------------
    return torch.cat(logits_list, 0), torch.cat(targets_list, 0)
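The concatenated logits and targets returned above can be scored in one pass; a minimal sketch of such a consumer (assuming both tensors end up on the same device):

logits, targets = get_logits_targets(params)
preds = logits.argmax(dim=1)
acc = (preds == targets).float().mean().item() * 100
print('Aggregate accuracy over all episodes: %4.2f%%' % acc)
Ejemplo n.º 19
0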
def meta_test(novel_loader,
              n_query=15,
              pretrained_dataset='miniImageNet',
              freeze_backbone=False,
              n_pseudo=100,
              n_way=5,
              n_support=5):
    correct = 0
    count = 0

    iter_num = len(novel_loader)

    acc_all = []
    for ti, (x, y) in enumerate(novel_loader):

        ###############################################################################################
        # load pretrained model on miniImageNet
        if params.method == 'protonet':
            pretrained_model = ProtoNet(model_dict[params.model],
                                        n_way=n_way,
                                        n_support=n_support)
        else:
            raise ValueError('Unknown method')

        checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, pretrained_dataset, params.model, params.method)
        if params.train_aug:
            checkpoint_dir += '_aug'
        checkpoint_dir += '_5way_5shot'

        params.save_iter = -1  # always evaluate the best checkpoint
        modelfile = get_best_file(checkpoint_dir)

        tmp = torch.load(modelfile)
        state = tmp['state']
        pretrained_model.load_state_dict(state)
        pretrained_model.cuda()
        ###############################################################################################
        # split data into support set and query set
        n_query = x.size(1) - n_support

        x = x.cuda()
        x_var = Variable(x)

        support_size = n_way * n_support

        y_a_i = Variable(torch.from_numpy(np.repeat(
            range(n_way), n_support))).cuda()  # (25,)

        x_b_i = x_var[:, n_support:, :, :, :].contiguous().view(
            n_way * n_query,
            *x.size()[2:])  # query set
        x_a_i = x_var[:, :n_support, :, :, :].contiguous().view(
            n_way * n_support,
            *x.size()[2:])  # support set

        if not freeze_backbone:
            ###############################################################################################
            # Finetune components initialization
            pseudo_q_generator = PseudoQeuryGenerator(n_way, n_support,
                                                      n_pseudo)
            delta_opt = torch.optim.Adam(
                filter(lambda p: p.requires_grad,
                       pretrained_model.parameters()))

            ###############################################################################################
            # finetune process
            finetune_epoch = 100

            fine_tune_n_query = n_pseudo // n_way
            pretrained_model.n_query = fine_tune_n_query
            pretrained_model.train()

            z_support = x_a_i.view(n_way, n_support, *x_a_i.size()[1:])

            for epoch in range(finetune_epoch):
                delta_opt.zero_grad()

                # generate pseudo query images
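                # (the pseudo queries are synthesized from the support images,
                # so finetuning needs no real query data)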
                pseudo_query_set, _ = pseudo_q_generator.generate(x_a_i)
                pseudo_query_set = pseudo_query_set.cuda().view(
                    n_way, fine_tune_n_query, *x_a_i.size()[1:])

                x = torch.cat((z_support, pseudo_query_set), dim=1)

                loss = pretrained_model.set_forward_loss(x)
                loss.backward()
                delta_opt.step()

        ###############################################################################################
        # inference

        pretrained_model.eval()

        pretrained_model.n_query = n_query
        with torch.no_grad():
            scores = pretrained_model.set_forward(x_var.cuda())

        y_query = np.repeat(range(n_way), n_query)
        topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
        topk_ind = topk_labels.cpu().numpy()

        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        correct_this, count_this = float(top1_correct), len(y_query)

        acc_all.append((correct_this / count_this * 100))
        print("Task %d : %4.2f%%  Now avg: %4.2f%%" %
              (ti, correct_this / count_this * 100, np.mean(acc_all)))
        ###############################################################################################
    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
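    # report the mean with a 95% confidence interval: 1.96 * std / sqrt(#tasks)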
    print('%d Test Acc = %4.2f%% +- %4.2f%%' %
          (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
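A call site for meta_test might look as follows (a sketch only; the episodic novel_loader and a freeze_backbone flag are assumed to come from this repository's data managers and argument parser):

meta_test(novel_loader,
          n_query=15,
          pretrained_dataset='miniImageNet',
          freeze_backbone=params.freeze_backbone,
          n_pseudo=100,
          n_way=5,
          n_support=5)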
Ejemplo n.º 20
0
                                        **train_few_shot_params)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   params.max_length)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot,
                                    n_query=params.n_query)  # (5,5)
        val_datamgr = SetREDataManager(max_length=params.max_length,
                                       **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, params.max_length)

        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model],
                             tf_path=params.tf_dir,
                             **train_few_shot_params,
                             proto_attention=params.proto_attention,
                             distance=params.proto_distance,
                             common=params.proto_common)
        elif params.method == 'gnnnet':
            model = GnnNet(model_dict[params.model],
                           tf_path=params.tf_dir,
                           **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                tf_path=params.tf_dir,
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
Ejemplo n.º 21
0
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model],
                             tf_path=params.tf_dir,
                             **train_few_shot_params)
        elif params.method == 'gnnnet':
            model = GnnNet(model_dict[params.model],
                           tf_path=params.tf_dir,
                           **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                tf_path=params.tf_dir,
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            else:
Ejemplo n.º 22
0
    pred = scores.data.cpu().numpy().argmax(axis=1)
    y = np.repeat(range(selected_n_way), n_query)
    acc = np.mean(pred == y) * 100
    return acc

if __name__ == '__main__':
    params = parse_args('test')
    acc_all = []
    iter_num = 1000
    n_query = 100
    few_shot_params = dict(n_way=params.test_n_way, n_query=n_query,
                           max_n_way=params.test_max_way, min_n_way=params.test_min_way,
                           max_shot=params.max_shot, min_shot=params.min_shot,
                           fixed_way=params.fixed_way)

    
    if params.method == 'protonet':
        model           = ProtoNet( model_dict[params.model], **few_shot_params )
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    checkpoint_dir = '%s/checkpoints/%s/%s_%s' %(configs.save_dir, params.dataset, params.model, params.method)
    
    if params.train_aug:
        checkpoint_dir += '_aug'
    
    if params.train_n_way != -1:
        checkpoint_dir += '_%d-way_' % params.train_n_way
    else:
        checkpoint_dir += '_random-way_'
Ejemplo n.º 23
0
            model           = BaselineTrain( encoder, params.num_classes, loss_type = 'dist')

    elif params.method in ['protonet','matchingnet','relationnet', 'relationnet_softmax', 'maml', 'maml_approx']:
        n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
 
        train_few_shot_params    = dict(n_way = params.train_n_way, n_support = params.n_shot) 
        base_datamgr            = SetDataManager(image_size, n_query = n_query,  **train_few_shot_params)
        base_loader             = base_datamgr.get_data_loader( base_file , aug = params.train_aug )
         
        test_few_shot_params     = dict(n_way = params.test_n_way, n_support = params.n_shot) 
        val_datamgr             = SetDataManager(image_size, n_query = n_query, **test_few_shot_params)
        val_loader              = val_datamgr.get_data_loader( val_file, aug = False) 
        # a batch from SetDataManager is a [n_way, n_support + n_query, dim, w, h] tensor

        if params.method == 'protonet':
            model           = ProtoNet( encoder, **train_few_shot_params )
        elif params.method == 'matchingnet':
            model           = MatchingNet( model_dict[params.model], **train_few_shot_params )
        elif params.method in ['relationnet', 'relationnet_softmax']:
            if params.model == 'Conv4': 
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6': 
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S': 
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model]( flatten = False )
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'

            model           = RelationNet( feature_model, loss_type = loss_type , **train_few_shot_params )
        elif params.method in ['maml' , 'maml_approx']:
Ejemplo n.º 24
0
def get_model(params, mode):
    '''
    Args:
        params: argparse params
        mode: (str), 'train', 'test'
    '''
    print('get_model() start...')
    few_shot_params = get_few_shot_params(params, mode)
    
    if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
        assert 'Conv4' in params.model and not params.train_aug, 'omniglot/cross_char only supports Conv4 without augmentation'
        params.model = params.model.replace('Conv4', 'Conv4S')  # so that e.g. Conv4Drop also becomes Conv4SDrop
        if params.recons_decoder is not None:
            if 'ConvS' not in params.recons_decoder:
                raise ValueError('omniglot / cross_char should use ConvS/HiddenConvS decoder.')
    
    if params.method in ['baseline', 'baseline++'] and mode == 'train':
        assert params.num_classes >= n_base_classes[params.dataset]
    if params.recons_decoder is None:
        print('params.recons_decoder is None')
        recons_decoder = None
    else:
        recons_decoder = decoder_dict[params.recons_decoder]
        print('recons_decoder:\n', recons_decoder)

    backbone_func = get_backbone_func(params)
    
    if 'baseline' in params.method:
        loss_types = {
            'baseline':'softmax', 
            'baseline++':'dist', 
        }
        loss_type = loss_types[params.method]
        
        if recons_decoder is None and params.min_gram is None: # default baseline/baseline++
            if mode == 'train':
                model = BaselineTrain(
                    model_func = backbone_func, loss_type = loss_type, 
                    num_class = params.num_classes, **few_shot_params)
            elif mode == 'test':
                model = BaselineFinetune(
                    model_func = backbone_func, loss_type = loss_type, 
                    **few_shot_params, finetune_dropout_p = params.finetune_dropout_p)
        else: # other settings for baseline
            if params.min_gram is not None:
                min_gram_params = {
                    'min_gram':params.min_gram, 
                    'lambda_gram':params.lambda_gram, 
                }
                if mode == 'train':
                    model = BaselineTrainMinGram(
                        model_func = backbone_func, loss_type = loss_type, 
                        num_class = params.num_classes, **few_shot_params, **min_gram_params)
                elif mode == 'test':
                    model = BaselineFinetune(
                        model_func = backbone_func, loss_type = loss_type, 
                        **few_shot_params, finetune_dropout_p = params.finetune_dropout_p)
            
    
    elif params.method == 'protonet':
        # default ProtoNet
        if recons_decoder is None and params.min_gram is None:
            model = ProtoNet( backbone_func, **few_shot_params )
        else: # other settings
            if params.min_gram is not None:
                min_gram_params = {
                    'min_gram':params.min_gram, 
                    'lambda_gram':params.lambda_gram, 
                }
                model = ProtoNetMinGram(backbone_func, **few_shot_params, **min_gram_params)

            if params.recons_decoder is not None:
                if 'Hidden' in params.recons_decoder:
                    if params.recons_decoder == 'HiddenConv': # 'HiddenConv', 'HiddenConvS'
                        model = ProtoNetAE2(backbone_func, **few_shot_params, recons_func=recons_decoder, lambda_d=params.recons_lambda, extract_layer = 2)
                    elif params.recons_decoder == 'HiddenConvS': # 'HiddenConv', 'HiddenConvS'
                        model = ProtoNetAE2(backbone_func, **few_shot_params, recons_func=recons_decoder, lambda_d=params.recons_lambda, extract_layer = 2, is_color=False)
                    elif params.recons_decoder == 'HiddenRes10':
                        model = ProtoNetAE2(backbone_func, **few_shot_params, recons_func=recons_decoder, lambda_d=params.recons_lambda, extract_layer = 6)
                    elif params.recons_decoder == 'HiddenRes18':
                        model = ProtoNetAE2(backbone_func, **few_shot_params, recons_func=recons_decoder, lambda_d=params.recons_lambda, extract_layer = 8)
                else:
                    if 'ConvS' in params.recons_decoder:
                        model = ProtoNetAE(backbone_func, **few_shot_params, recons_func=recons_decoder, lambda_d=params.recons_lambda, is_color=False)
                    else:
                        model = ProtoNetAE(backbone_func, **few_shot_params, recons_func=recons_decoder, lambda_d=params.recons_lambda, is_color=True)
    elif params.method == 'matchingnet':
        model           = MatchingNet( backbone_func, **few_shot_params )
    elif params.method in ['relationnet', 'relationnet_softmax']:
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'

        model           = RelationNet( backbone_func, loss_type = loss_type , **few_shot_params )
    elif params.method in ['maml' , 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model           = MAML(  backbone_func, approx = (params.method == 'maml_approx') , **few_shot_params )
        if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
            # MAML uses different hyperparameters on omniglot-style datasets
            model.n_task     = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unexpected params.method: %s' % params.method)
    
    print('get_model() finished.')
    return model
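A typical call site for get_model (a sketch; parse_args and the surrounding training loop are defined elsewhere in this repository):

params = parse_args('train')
model = get_model(params, mode='train').cuda()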
Ejemplo n.º 25
0
    if params.method == 'MatchingNet':
        model = MatchingNet(model_dict[params.model],
                            n_way=params.train_n_way,
                            n_support=params.n_shot).cuda()
    elif params.method == 'RelationNet':
        model = RelationNet(model_dict[params.model],
                            n_way=params.train_n_way,
                            n_support=params.n_shot).cuda()
    elif params.method == 'RelationNetLRP':
        model = RelationNetLRP(model_dict[params.model],
                               n_way=params.train_n_way,
                               n_support=params.n_shot).cuda()
    elif params.method == 'ProtoNet':
        model = ProtoNet(model_dict[params.model],
                         n_way=params.train_n_way,
                         n_support=params.n_shot).cuda()
    elif params.method == 'GNN':
        model = GnnNet(model_dict[params.model],
                       n_way=params.train_n_way,
                       n_support=params.n_shot).cuda()
    elif params.method == 'GNNLRP':
        model = GnnNetLRP(model_dict[params.model],
                          n_way=params.train_n_way,
                          n_support=params.n_shot).cuda()
    elif params.method == 'TPN':
        model = TPN(model_dict[params.model],
                    n_way=params.train_n_way,
                    n_support=params.n_shot).cuda()
    else:
        raise ValueError("Please specify a valid method!")
Ejemplo n.º 26
0
    save_features(model, data_loader, featurefile)

    print('\nStage 2: evaluate')
    acc_all = []
    iter_num = 1000
    few_shot_params = dict(n_way=params.test_n_way,
                           n_support=params.n_shot)  # 5,5
    # model
    print('  build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(model_dict[params.model],
                         **few_shot_params,
                         proto_attention=params.proto_attention,
                         distance=params.proto_distance,
                         common=params.proto_common)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = re_backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = re_backbone.Conv6NP
        else:
            feature_model = model_dict[params.model]
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model,
Ejemplo n.º 27
0
def single_test(params, results_logger):
    acc_all = []

    iter_num = 600

    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
        params.model = 'Conv4S'

    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model],
                                 loss_type='dist',
                                 **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'DKT':
        model = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model],
                     approx=(params.method == 'maml_approx'),
                     **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char']:
            # MAML uses different hyperparameters on omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    if params.method not in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " +
                  str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split
    if params.method in ['maml', 'maml_approx', 'DKT']:
        # MAML does not support testing on precomputed features
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224

        datamgr = SetDataManager(image_size,
                                 n_eposide=iter_num,
                                 n_query=15,
                                 **few_shot_params)

        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # We perform adaptation on MAML simply by updating more times.
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)

    else:
        # default split = novel, but you can also test base or val classes
        novel_file = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"),
            split_str + ".hdf5")
        cl_data_file = feat_loader.init_loader(novel_file)

        for i in range(iter_num):
            acc = feature_evaluation(cl_data_file,
                                     model,
                                     n_query=15,
                                     adaptation=params.adaptation,
                                     **few_shot_params)
            acc_all.append(acc)

        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        print('%d Test Acc = %4.2f%% +- %4.2f%%' %
              (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
    with open('record/results.txt', 'a') as f:
        timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
        aug_str = '-aug' if params.train_aug else ''
        aug_str += '-adapted' if params.adaptation else ''
        if params.method in ['baseline', 'baseline++']:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.test_n_way)
        else:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_train %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.train_n_way, params.test_n_way)
        acc_str = '%d Test Acc = %4.2f%% +- %4.2f%%' % (
            iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num))
        f.write('Time: %s, Setting: %s, Acc: %s \n' %
                (timestamp, exp_setting, acc_str))
        results_logger.log("single_test_acc", acc_mean)
        results_logger.log("single_test_acc_std",
                           1.96 * acc_std / np.sqrt(iter_num))
        results_logger.log("time", timestamp)
        results_logger.log("exp_setting", exp_setting)
        results_logger.log("acc_str", acc_str)
    return acc_mean
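A driver for single_test might look like this (a sketch; parse_args appears in the neighboring examples, and results_logger is assumed to be this repository's logging helper):

params = parse_args('test')
acc = single_test(params, results_logger)
print('mean accuracy: %4.2f%%' % acc)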
Ejemplo n.º 28
0
    else:
        featurefile = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"), split + ".hdf5")
    os.makedirs(os.path.dirname(featurefile), exist_ok=True)
    save_features(model, data_loader, featurefile)

    print('\nStage 2: evaluate')
    acc_all = []
    iter_num = 1000
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    # model
    print('  build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = model_dict[params.model]
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
Ejemplo n.º 29
0
    params = parse_args('train')

    ##################################################################
    image_size = 224
    iter_num = 600
    pretrained_dataset = "miniImageNet"

    n_query = max(
        1, int(16 * params.test_n_way / params.train_n_way)
    )  #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.method in ["gnnnet", "gnnnet_maml"]:
        model = GnnNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'relationnet':
        feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)

    elif params.method in ["dampnet_full_class"]:
        model = dampnet_full_class.DampNet(model_dict[params.model],
                                           **few_shot_params)
    elif params.method == "baseline":
        checkpoint_dir_b = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, pretrained_dataset, params.model, "baseline")
        if params.train_aug:
            checkpoint_dir_b += '_aug'
Ejemplo n.º 30
0
                                                   num_workers=32)

        test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot,
                                    jigsaw=params.jigsaw, lbda=params.lbda,
                                    rotation=params.rotation)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params,
                                     isAircraft=isAircraft)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model],
                             **train_few_shot_params,
                             use_bn=(not params.no_bn),
                             pretrain=params.pretrain,
                             image_loader=image_loader,
                             len_dataset=len_dataset)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'

            model = RelationNet(feature_model,
                                loss_type=loss_type,
                                **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True