Ejemplo n.º 1
0
                               delimiter='\t')
                    sys.exit(0)

    split = args.split
    if split == 'attributes' and args.method != 'protonet':
        raise NotImplementedError
    if args.method in ['maml', 'maml_approx'
                       ]:  #maml do not support testing with feature
        if 'Conv' in args.model:
            image_size = 84
        else:
            image_size = 224

        datamgr = SetDataManager(image_size,
                                 n_episode=iter_num,
                                 n_query=15,
                                 **few_shot_args,
                                 args=args)

        if args.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        else:
            loadfile = configs.data_dir[args.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if args.adaptation:
            model.task_update_num = 100  #We perform adaptation on MAML simply by updating more times.
        model.eval()
Ejemplo n.º 2
0
 if params.method in ['baseline++', 'S2M2_R', 'rotation']:
     if params.dct_status:
         base_datamgr = SimpleDataManager(image_size_dct,
                                          batch_size=params.batch_size)
         base_loader = base_datamgr.get_data_loader_dct(
             base_file,
             aug=params.train_aug,
             filter_size=params.filter_size)
         base_datamgr_test = SimpleDataManager(
             image_size_dct, batch_size=params.test_batch_size)
         base_loader_test = base_datamgr_test.get_data_loader_dct(
             base_file, aug=False, filter_size=params.filter_size)
         test_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
         val_datamgr = SetDataManager(image_size_dct,
                                      n_query=15,
                                      **test_few_shot_params)
         val_loader = val_datamgr.get_data_loader_dct(
             val_file, aug=False, filter_size=params.filter_size)
     else:
         base_datamgr = SimpleDataManager(image_size,
                                          batch_size=params.batch_size)
         base_loader = base_datamgr.get_data_loader(base_file,
                                                    aug=params.train_aug)
         base_datamgr_test = SimpleDataManager(
             image_size, batch_size=params.test_batch_size)
         base_loader_test = base_datamgr_test.get_data_loader(base_file,
                                                              aug=False)
         test_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
         val_datamgr = SetDataManager(image_size,
def main_train(params):
    """Train a few-shot classification model and save checkpoints.

    Builds the data loaders and model implied by ``params`` (dataset,
    backbone, method, way/shot configuration), optionally resumes from a
    previous checkpoint or warm-starts the backbone from a pretrained
    baseline, then runs the training loop and persists results.

    Args:
        params: parsed argument namespace. Must provide at least
            ``dataset``, ``model``, ``method``, ``train_n_way``,
            ``test_n_way``, ``n_shot``, ``train_aug``, ``num_classes``,
            ``start_epoch``, ``stop_epoch``, ``resume`` and ``warmup``.

    Raises:
        ValueError: if ``params.method`` is unknown, or if warmup is
            requested but no pretrained baseline checkpoint exists.

    Side effects: writes checkpoints under ``params.checkpoint_dir``
    (which this function also sets on ``params``) and saves metrics via
    ``ResultsLogger``.
    """
    _set_seed(params)

    results_logger = ResultsLogger(params)

    # Resolve training (base) and validation split files. The two
    # 'cross*' settings train on one domain and validate on another.
    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'

    # Input resolution depends on backbone family and dataset.
    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        image_size = 224

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        # Character datasets use the single-channel Conv4 variant.
        params.model = 'Conv4S'

    optimization = 'Adam'

    # Pick a default epoch budget when the user did not set one (-1).
    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                params.stop_epoch = 200  # This is different as stated in the open-review paper. However, using 400 epoch in baseline actually lead to over-fitting
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  # default
        else:  # meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 600  # default

    if params.method in ['baseline', 'baseline++']:
        # Baselines train a plain classifier over all base classes,
        # so simple (non-episodic) loaders suffice.
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

    elif params.method in [
            'DKT', 'protonet', 'matchingnet', 'relationnet',
            'relationnet_softmax', 'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params)  # n_eposide=100
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if params.method == 'DKT':
            model = DKT(model_dict[params.model], **train_few_shot_params)
            model.init_summary()
        elif params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            # RelationNet needs unflattened (spatial) features.
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'

            model = RelationNet(feature_model,
                                loss_type=loss_type,
                                **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            # Switch the backbone blocks into their MAML-aware variants
            # before model construction.
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
            if params.dataset in [
                    'omniglot', 'cross_char'
            ]:  # maml use different parameter in omniglot
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    # Checkpoint directory encodes dataset/backbone/method and, for
    # meta-learning methods, the way/shot configuration.
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                    params.n_shot)
    os.makedirs(params.checkpoint_dir, exist_ok=True)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # maml use multiple tasks in one update

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  # We also support warmup from pretrained baseline feature, but we never used in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        # Guard BEFORE torch.load: get_resume_file may return None (the
        # resume branch above checks this too); previously a missing
        # warmup checkpoint crashed inside torch.load instead of raising
        # the intended error.
        if warmup_resume_file is None:
            raise ValueError('No warm_up file')
        tmp = torch.load(warmup_resume_file)
        state = tmp['state']
        state_keys = list(state.keys())
        for i, key in enumerate(state_keys):
            if "feature." in key:
                newkey = key.replace(
                    "feature.", ""
                )  # an architecture model has attribute 'feature', load architecture feature to backbone by casting name from 'feature.trunk.xx' to 'trunk.xx'
                state[newkey] = state.pop(key)
            else:
                state.pop(key)
        model.feature.load_state_dict(state)

    model = train(base_loader, val_loader, model, optimization, start_epoch,
                  stop_epoch, params, results_logger)
    results_logger.save()
Ejemplo n.º 4
0
                                           rotation=params.rotation,
                                           isAircraft=isAircraft,
                                           grey=params.grey,
                                           shuffle=False)
        if params.dataset_unlabel is not None:
            base_loader_u = base_datamgr_u.get_data_loader(
                base_file_unlabel, aug=params.train_aug)
        else:
            base_loader_u = base_datamgr_u.get_data_loader(
                base_file, aug=params.train_aug)

        train_few_shot_params    = dict(n_way = params.train_n_way, n_support = params.n_shot, \
                                        jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation)
        base_datamgr_l = SetDataManager(image_size,
                                        n_query=n_query,
                                        **train_few_shot_params,
                                        isAircraft=isAircraft,
                                        grey=params.grey)
        base_loader_l = base_datamgr_l.get_data_loader(base_file,
                                                       aug=params.train_aug)

        test_few_shot_params     = dict(n_way = params.test_n_way, n_support = params.n_shot, \
                                        jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params,
                                     isAircraft=isAircraft,
                                     grey=params.grey)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor
Ejemplo n.º 5
0
                                                   params.name)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    # dataloader
    print('\n--- Prepare dataloader ---')
    print('\ttrain with seen domain {}'.format(params.dataset))
    print('\tval with seen domain {}'.format(params.testset))
    base_file = os.path.join(params.data_dir, params.dataset, 'base.json')
    val_file = os.path.join(params.data_dir, params.testset, 'val.json')

    # model
    image_size = 224
    n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
    base_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  n_way=params.train_n_way,
                                  n_support=params.n_shot)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
    val_datamgr = SetDataManager(image_size,
                                 n_query=n_query,
                                 n_way=params.test_n_way,
                                 n_support=params.n_shot)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    if params.method == 'MatchingNet':
        model = MatchingNet(model_dict[params.model],
                            n_way=params.train_n_way,
                            n_support=params.n_shot).cuda()
    elif params.method == 'RelationNet':
        model = RelationNet(model_dict[params.model],
                            n_way=params.train_n_way,
Ejemplo n.º 6
0
                num_layers=args.rnn_num_layers,
                dropout=args.rnn_dropout,
            )
            l3_model = l3_model.cuda()

        embedding_model = embedding_model.cuda()
        lang_model = lang_model.cuda()

    # if test_n_way is smaller than train_n_way, reduce n_query to keep batch
    # size small
    n_query = max(1, int(16 * args.test_n_way / args.train_n_way))

    train_few_shot_args = dict(n_way=args.train_n_way, n_support=args.n_shot)
    base_datamgr = SetDataManager("CUB",
                                  84,
                                  n_query=n_query,
                                  **train_few_shot_args,
                                  args=args)
    print("Loading train data")

    base_loader = base_datamgr.get_data_loader(
        base_file,
        aug=True,
        lang_dir=constants.LANG_DIR,
        normalize=True,
        vocab=vocab,
        # Maximum training data restrictions only apply at train time
        max_class=args.max_class,
        max_img_per_class=args.max_img_per_class,
        max_lang_per_class=args.max_lang_per_class,
    )
Ejemplo n.º 7
0
    base_file = configs.data_dir[params.dataset] + 'base.json'
    val_file = configs.data_dir[params.dataset] + 'val.json'
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch

    base_datamgr = SimpleDataManager(image_size, batch_size=params.batch_size)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
    base_datamgr_test = SimpleDataManager(image_size,
                                          batch_size=params.test_batch_size)
    base_loader_test = base_datamgr_test.get_data_loader(base_file, aug=False)
    test_few_shot_params = dict(n_way=5, n_support=1)
    val_datamgr = SetDataManager(image_size,
                                 n_query=15,
                                 **test_few_shot_params)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    if params.method == 'manifold_mixup':
        print(params.num_classes)
        model = wrn_mixup_model.wrn28_10(params.num_classes)
    elif params.method == 'S2M2_R':
        model = wrn_mixup_model.wrn28_10(params.num_classes)
    elif params.method == 'rotation':
        model = BaselineTrain(model_dict[params.model],
                              params.num_classes,
                              loss_type='dist')

    if params.method == 'S2M2_R':
        if use_gpu:
Ejemplo n.º 8
0
def explain_gnnnet():
    """Produce LRP (layer-wise relevance propagation) heatmaps for a GnnNet.

    Loads a trained GnnNet few-shot model from hard-coded paths, runs one
    episode through the feature extractor / fc encoder / GNN manually
    (mirroring the model's forward pass step by step), then propagates
    relevance back through the GNN layers, the fc encoder and the image
    backbone, saving the original query images and their relevance heatmaps
    under ``<save_dir>/explanations/<name>``.

    NOTE(review): dataset/checkpoint paths are machine-specific and
    hard-coded below; only the 'gnnnet' branch is exercised even though
    other method branches exist.
    """
    params = options.parse_args('test')
    feature_model = backbone.model_dict['ResNet10']
    params.method = 'gnnnet'
    params.dataset = 'miniImagenet'  # name relationnet --testset miniImagenet
    params.name = 'gnn'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    # Input resolution follows the backbone family.
    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    n_query = 1
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)

    # model
    print('  build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone.model_dict[params.model],
                            **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = backbone.model_dict[params.model]
        loss_type = 'LRP'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
    else:
        raise ValueError('Unknown method')

    # Load the trained weights; fall back to non-strict loading when the
    # checkpoint's keys don't match exactly.
    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    # print(checkpoint_dir)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
        # print(modelfile)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
            print('loaded pretrained model')
        except RuntimeError:
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            for k in tmp['model_state']:  ##### revise latter
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)
        except:
            raise

    model = model.cuda()
    model.eval()
    model.n_query = n_query
    # for module in model.modules():
    #   print(type(module))
    # Attach LRP hooks to the fc encoder; the backbone gets its hooks
    # later, inside the per-class loop.
    lrp_preset = lrp_presets.SequentialPresetA()
    feature_model = model.feature
    fc_encoder = model.fc
    gnn_net = model.gnn
    lrp_wrapper.add_lrp(fc_encoder, lrp_preset)
    # lrp_wrapper.add_lrp(feature_model, lrp_preset)
    # lrp_wrapper.add_lrp(fc_encoder,lrp_preset)
    # lrp_wrapper.add_lrp(feature_model, lrp_preset)

    # acc = 0
    # count = 0
    # tested the forward pass is correct by observing the accuracy
    # for i, (x, _, _) in enumerate(data_loader):
    #   x = x.cuda()
    #   support_label = torch.from_numpy(np.repeat(range(model.n_way), model.n_support)).unsqueeze(1)
    #   support_label = torch.zeros(model.n_way*model.n_support, model.n_way).scatter(1, support_label, 1).view(model.n_way, model.n_support, model.n_way)
    #   support_label = torch.cat([support_label, torch.zeros(model.n_way, 1, model.n_way)], dim=1)
    #   support_label = support_label.view(1, -1, model.n_way)
    #   support_label = support_label.cuda()
    #   x = x.view(-1, *x.size()[2:])
    #
    #   x_feature = feature_model(x)
    #   x_fc_encoded = fc_encoder(x_feature)
    #   z = x_fc_encoded.view(model.n_way, -1, x_fc_encoded.size(1))
    #   gnn_feature = [
    #     torch.cat([z[:, :model.n_support], z[:, model.n_support + i:model.n_support + i + 1]], dim=1).view(1, -1, z.size(2))
    #     for i in range(model.n_query)]
    #   gnn_nodes = torch.cat([torch.cat([z, support_label], dim=2) for z in gnn_feature], dim=0)
    #   scores = gnn_net(gnn_nodes)
    #   scores = scores.view(model.n_query, model.n_way, model.n_support + 1, model.n_way)[:, :, -1].permute(1, 0,
    #                                                                                                    2).contiguous().view(
    #     -1, model.n_way)
    #   pred = scores.data.cpu().numpy().argmax(axis=1)
    #   y = np.repeat(range(model.n_way), n_query)
    #   acc += np.sum(pred == y)
    #   count += len(y)
    #   # print(1.0*acc/count)
    # print(1.0*acc/count)

    # Map dataset class ids to human-readable labels for the saved files.
    with open(
            '/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json',
            'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations',
                                        params.name)
    if not os.path.isdir(explanation_save_dir):
        os.makedirs(explanation_save_dir)
    for batch_idx, (x, y, p) in enumerate(data_loader):
        print(p)
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        x = x.cuda()
        # One-hot labels for support images; the query slot gets all-zero
        # labels (unknown class).
        support_label = torch.from_numpy(
            np.repeat(range(model.n_way),
                      model.n_support)).unsqueeze(1)  #torch.Size([25, 1])
        support_label = torch.zeros(model.n_way * model.n_support,
                                    model.n_way).scatter(1, support_label,
                                                         1).view(
                                                             model.n_way,
                                                             model.n_support,
                                                             model.n_way)
        support_label = torch.cat(
            [support_label,
             torch.zeros(model.n_way, 1, model.n_way)], dim=1)
        support_label = support_label.view(1, -1, model.n_way)
        support_label = support_label.cuda()  #torch.Size([1, 30, 5])
        x = x.contiguous()
        x = x.view(-1, *x.size()[2:])  #torch.Size([30, 3, 224, 224])
        x_feature = feature_model(x)  #torch.Size([30, 512])
        x_fc_encoded = fc_encoder(x_feature)  #torch.Size([30, 128])
        z = x_fc_encoded.view(model.n_way, -1,
                              x_fc_encoded.size(1))  # (5,6,128)
        gnn_feature = [
            torch.cat([
                z[:, :model.n_support],
                z[:, model.n_support + i:model.n_support + i + 1]
            ],
                      dim=1).view(1, -1, z.size(2))
            for i in range(model.n_query)
        ]  # model.n_query is the number of query images for each class
        # gnn_feature is grouped into n_query groups. each group contains the support image features concatenated with one query image features.
        # print(len(gnn_feature), gnn_feature[0].shape)
        gnn_nodes = torch.cat(
            [torch.cat([z, support_label], dim=2) for z in gnn_feature], dim=0
        )  # the features are concatenated with the one hot label. for the unknow image the one hot label is all zero

        #  perform gnn_net step by step
        #  the first iteration
        print('x', gnn_nodes.shape)
        W_init = torch.eye(
            gnn_nodes.size(1), device=gnn_nodes.device
        ).unsqueeze(0).repeat(gnn_nodes.size(0), 1, 1).unsqueeze(
            3
        )  # (n_querry, n_way*(num_support + 1), n_way*(num_support + 1), 1)
        # print(W_init.shape)

        W1 = gnn_net._modules['layer_w{}'.format(0)](
            gnn_nodes, W_init
        )  # (n_querry, n_way*(num_support + 1), n_way*(num_support + 1), 2)
        # print(Wi.shape)
        x_new1 = F.leaky_relu(gnn_net._modules['layer_l{}'.format(0)](
            [W1, gnn_nodes])[1])  # (num_querry, num_support + 1, num_outputs)
        # print(x_new1.shape)  #torch.Size([1, 30, 48])
        gnn_nodes_1 = torch.cat([gnn_nodes, x_new1],
                                2)  # (concat more features)
        # print('gn1',gnn_nodes_1.shape) #torch.Size([1, 30, 181])

        #  the second iteration
        W2 = gnn_net._modules['layer_w{}'.format(1)](
            gnn_nodes_1, W_init
        )  # (n_querry, n_way*(num_support + 1), n_way*(num_support + 1), 2)
        x_new2 = F.leaky_relu(gnn_net._modules['layer_l{}'.format(1)](
            [W2,
             gnn_nodes_1])[1])  # (num_querry, num_support + 1, num_outputs)
        # print(x_new2.shape)
        gnn_nodes_2 = torch.cat([gnn_nodes_1, x_new2],
                                2)  # (concat more features)
        # print('gn2', gnn_nodes_2.shape)  #torch.Size([1, 30, 229])

        Wl = gnn_net.w_comp_last(gnn_nodes_2, W_init)
        # print(Wl.shape)  #torch.Size([1, 30, 30, 2])
        scores = gnn_net.layer_last(
            [Wl, gnn_nodes_2])[1]  # (num_querry, num_support + 1, num_way)
        print(scores.shape)

        scores_sf = torch.softmax(scores, dim=-1)
        # print(scores_sf)

        # Convert softmax scores to logits for LRP (logit trick).
        gnn_logits = torch.log(LRPutil.LOGIT_BETA * scores_sf /
                               (1 - scores_sf))
        gnn_logits = gnn_logits.view(-1, model.n_way,
                                     model.n_support + n_query, model.n_way)
        # print(gnn_logits)
        query_scores = scores.view(
            model.n_query, model.n_way, model.n_support + 1,
            model.n_way)[:, :,
                         -1].permute(1, 0,
                                     2).contiguous().view(-1, model.n_way)
        preds = query_scores.data.cpu().numpy().argmax(axis=-1)
        # print(preds.shape)
        # Explain each of the n_way classes in turn: mask out all other
        # classes' logits at the query slot, then propagate relevance
        # back through the three GNN stages, the fc encoder and the
        # backbone.
        for k in range(model.n_way):
            mask = torch.zeros(5).cuda()
            mask[k] = 1
            gnn_logits_cls = gnn_logits.clone()
            gnn_logits_cls[:, :, -1] = gnn_logits_cls[:, :, -1] * mask
            # print(gnn_logits_cls)
            # print(gnn_logits_cls.shape)
            gnn_logits_cls = gnn_logits_cls.view(-1, model.n_way)
            # NOTE(review): the narrow offsets below (181/48, 133/48,
            # 0/128) assume the default gnn feature widths seen in the
            # shape comments above — confirm if the model config changes.
            relevance_gnn_nodes_2 = explain_Gconv(gnn_logits_cls,
                                                  gnn_net.layer_last, Wl,
                                                  gnn_nodes_2)
            relevance_x_new2 = relevance_gnn_nodes_2.narrow(-1, 181, 48)
            # relevance_gnn_nodes = relevance_gnn_nodes_2
            relevance_gnn_nodes_1 = explain_Gconv(
                relevance_x_new2, gnn_net._modules['layer_l{}'.format(1)], W2,
                gnn_nodes_1)
            relevance_x_new1 = relevance_gnn_nodes_1.narrow(-1, 133, 48)
            relevance_gnn_nodes = explain_Gconv(
                relevance_x_new1, gnn_net._modules['layer_l{}'.format(0)], W1,
                gnn_nodes)
            relevance_gnn_features = relevance_gnn_nodes.narrow(-1, 0, 128)
            print(relevance_gnn_features.shape)
            relevance_gnn_features += relevance_gnn_nodes_1.narrow(-1, 0, 128)
            relevance_gnn_features += relevance_gnn_nodes_2.narrow(
                -1, 0, 128)  #[2, 30, 128]
            relevance_gnn_features = relevance_gnn_features.view(
                n_query, model.n_way, model.n_support + 1, 128)
            # Collect the relevance of the query slot from each group.
            for i in range(n_query):
                query_i = relevance_gnn_features[i][:, model.
                                                    n_support:model.n_support +
                                                    1]
                if i == 0:
                    relevance_z = query_i
                else:
                    relevance_z = torch.cat((relevance_z, query_i), 1)
            relevance_z = relevance_z.view(-1, 128)
            query_feature = x_feature.view(model.n_way, -1,
                                           512)[:, model.n_support:]
            # print(query_feature.shape)
            query_feature = query_feature.contiguous()
            query_feature = query_feature.view(n_query * model.n_way, 512)
            # print(query_feature.shape)
            relevance_query_features = fc_encoder.compute_lrp(
                query_feature, target=relevance_z)
            # print(relevance_query_features.shape)
            # print(relevance_gnn_features.shape)
            # explain the fc layer and the image encoder
            query_images = x.view(model.n_way, -1,
                                  *x.size()[1:])[:, model.n_support:]
            query_images = query_images.contiguous()

            query_images = query_images.view(-1, *x.size()[1:]).detach()
            # print(query_images.shape)
            lrp_wrapper.add_lrp(feature_model, lrp_preset)
            relevance_query_images = feature_model.compute_lrp(
                query_images, target=relevance_query_features)
            print(relevance_query_images.shape)

            # Save the original query image (once) and its LRP heatmap
            # for the class currently being explained.
            for j in range(n_query * model.n_way):
                predict_class = label_to_readableclass[preds[j]]
                true_class = query_gt_class[int(j % model.n_way)][int(
                    j // model.n_way)]
                explain_class = label_to_readableclass[k]
                img_name = query_img_path[int(j % model.n_way)][int(
                    j // model.n_way)].split('/')[-1]
                if not os.path.isdir(
                        os.path.join(explanation_save_dir, 'episode' +
                                     str(batch_idx), img_name.strip('.jpg'))):
                    os.makedirs(
                        os.path.join(explanation_save_dir,
                                     'episode' + str(batch_idx),
                                     img_name.strip('.jpg')))
                save_path = os.path.join(explanation_save_dir,
                                         'episode' + str(batch_idx),
                                         img_name.strip('.jpg'))
                if not os.path.exists(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name)):
                    original_img = Image.fromarray(
                        np.uint8(
                            project(query_images[j].permute(
                                1, 2, 0).detach().cpu().numpy())))
                    original_img.save(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name))

                img_relevance = relevance_query_images.narrow(0, j, 1)
                print(predict_class, true_class, explain_class)
                # assert relevance_querry_cls[j].sum() != 0
                # assert img_relevance.sum()!=0
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(
                        save_path,
                        true_class + '_' + explain_class + '_lrp_hm.jpg'))

        # Only the first episode is explained.
        break
Ejemplo n.º 9
0
            top1_correct = np.sum(topk_ind[:, 0] == yq)
            acc = top1_correct * 100. / (n_way * n_query)
            acc_all.append(acc)
        print('Task %d : %4.2f%%' % (ti, acc))

    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('Test Acc = %4.2f +- %4.2f%%' %
          (acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))


if __name__ == '__main__':
    # Fixed seed so episode sampling is reproducible across runs.
    np.random.seed(10)
    params = parse_args()

    image_size = 224
    # NOTE(review): iter_num (and possibly n_query) appear to be read as
    # globals by evaluate() above — confirm before renaming or moving.
    iter_num = 2000  # number of evaluation episodes
    n_query = 16

    print('Loading target dataset!')
    novel_file = os.path.join(params.data_dir, params.dataset, 'novel.json')
    datamgr = SetDataManager(image_size,
                             n_query=n_query,
                             n_way=params.test_n_way,
                             n_support=params.n_shot,
                             n_eposide=iter_num)
    novel_loader = datamgr.get_data_loader(novel_file, aug=False)

    evaluate(novel_loader, n_way=params.test_n_way, n_support=params.n_shot)
Ejemplo n.º 10
0
    pred = scores.data.cpu().numpy().argmax(axis=1)
    y = np.repeat(range(n_way), n_query)
    acc = np.mean(pred == y) * 100
    return acc


if __name__ == '__main__':
    params = parse_args('test')

    acc_all = []

    iter_num = 600

    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    datamgr = SetDataManager(n_eposide=iter_num, n_query=15, **few_shot_params)
    novel_loader = datamgr.get_data_loader(root='./filelists/tabula_muris',
                                           mode='test')

    x_dim = novel_loader.dataset.get_dim()
    go_mask = novel_loader.dataset.go_mask

    if params.method == 'baseline':
        model = BaselineFinetune(backbone.FCNet(x_dim), **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(backbone.FCNet(x_dim),
                                 loss_type='dist',
                                 **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(backbone.FCNet(x_dim), **few_shot_params)
    elif params.method == 'comet':
Ejemplo n.º 11
0
def explain_relationnet():
    """Generate LRP (Layer-wise Relevance Propagation) explanations for a
    trained RelationNet few-shot model on miniImagenet.

    For the first few test episodes, the relevance of each relation score is
    propagated back through the relation module and the feature backbone to
    produce a per-class pixel heatmap for every query image. Heatmaps (and a
    copy of each original query image) are written under
    ``<save_dir>/explanations/<name>/episode<i>/<image-stem>/``.

    Side effects: reads a checkpoint and a class-label JSON from hard-coded
    paths, creates directories, writes JPEG files, and requires CUDA.
    """
    params = options.parse_args('test')
    feature_model = backbone.model_dict['ResNet10']
    # Hard-coded test configuration: explain relationnet on miniImagenet.
    params.method = 'relationnet'
    params.dataset = 'miniImagenet'
    params.name = 'relationnet'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    # Conv backbones use 84x84 crops, ResNet backbones 224x224.
    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    n_query = 1
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)

    # ---- build the metric-based model -------------------------------------
    print('  build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone.model_dict[params.model],
                            **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        # RelationNet needs non-pooled ("NP") conv feature maps.
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = backbone.model_dict[params.model]
        loss_type = 'LRPmse'  # mse-style loss variant instrumented for LRP
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
    else:
        raise ValueError('Unknown method')

    # ---- load the trained checkpoint --------------------------------------
    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
        except RuntimeError:
            # Shape mismatch between checkpoint and model: load what fits.
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            # Older checkpoints store weights under 'model_state' and keep
            # batch-norm running stats with an extra singleton dimension.
            for k in tmp['model_state']:
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)

    model = model.cuda()
    model.eval()
    model.n_query = n_query

    # Wrap detached copies of the backbone and relation module with LRP hooks
    # so relevance can be propagated without touching the live model.
    preset = lrp_presets.SequentialPresetA()
    feature_model = copy.deepcopy(model.feature)
    lrp_wrapper.add_lrp(feature_model, preset=preset)
    relation_model = copy.deepcopy(model.relation_module)
    lrp_wrapper.add_lrp(relation_model, preset=preset)

    with open(
            '/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json',
            'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations',
                                        params.name)
    if not os.path.isdir(explanation_save_dir):
        os.makedirs(explanation_save_dir)

    for i, (x, y, p) in enumerate(data_loader):
        # x: images, (n_way, n_support + n_query, 3, img_size, img_size)
        # y: global labels, (n_way, n_support + n_query)
        # p: image paths, list of n_support + n_query tuples of length n_way
        if i >= 3:
            # Only explain the first three episodes.
            break
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        # Encode supports/queries; prototypes are per-class support means.
        z_support, z_query = model.parse_feature(x, is_feature=False)
        z_support = z_support.contiguous()
        z_proto = z_support.view(model.n_way, model.n_support,
                                 *model.feat_dim).mean(1)
        z_query = z_query.contiguous().view(model.n_way * model.n_query,
                                            *model.feat_dim)
        # Pair every query with every prototype for the relation module.
        z_proto_ext = z_proto.unsqueeze(0).repeat(model.n_query * model.n_way,
                                                  1, 1, 1, 1)
        z_query_ext = z_query.unsqueeze(0).repeat(model.n_way, 1, 1, 1, 1)
        z_query_ext = torch.transpose(z_query_ext, 0, 1)
        extend_final_feat_dim = model.feat_dim.copy()
        extend_final_feat_dim[0] *= 2  # prototype and query are concatenated
        relation_pairs = torch.cat((z_proto_ext, z_query_ext),
                                   2).view(-1, *extend_final_feat_dim)
        relations = relation_model(relation_pairs)
        scores = relations.view(-1, model.n_way)
        preds = scores.data.cpu().numpy().argmax(axis=1)
        relations = relations.view(-1, model.n_way)
        relations_sf = torch.softmax(relations, dim=-1)
        # Map softmax probabilities to logit-like targets for LRP.
        relations_logits = torch.log(LRPutil.LOGIT_BETA * relations_sf /
                                     (1 - relations_sf))
        relations_logits = relations_logits.view(-1, 1)
        relevance_relations = relation_model.compute_lrp(
            relation_pairs, target=relations_logits)
        # Keep only the relevance on the query half of each concatenated pair.
        relevance_z_query = torch.narrow(relevance_relations, 1,
                                         model.feat_dim[0], model.feat_dim[0])
        relevance_z_query = relevance_z_query.view(
            model.n_query * model.n_way, model.n_way,
            *relevance_z_query.size()[1:])
        query_img = x.narrow(1, model.n_support,
                             model.n_query).view(model.n_way * model.n_query,
                                                 *x.size()[2:])
        for k in range(model.n_way):
            # Relevance of every query w.r.t. class k, pushed back to pixels.
            relevance_querry_cls = torch.narrow(relevance_z_query, 1, k,
                                                1).squeeze(1)
            relevance_querry_img = feature_model.compute_lrp(
                query_img.cuda(), target=relevance_querry_cls)
            for j in range(model.n_query * model.n_way):
                predict_class = label_to_readableclass[preds[j]]
                true_class = query_gt_class[int(j % model.n_way)][int(
                    j // model.n_way)]
                explain_class = label_to_readableclass[k]
                img_name = query_img_path[int(j % model.n_way)][int(
                    j // model.n_way)].split('/')[-1]
                # BUG FIX: img_name.strip('.jpg') strips any of the chars
                # '.', 'j', 'p', 'g' from BOTH ends of the name (mangling
                # stems that start or end with those letters); splitext
                # removes only the extension.
                img_stem = os.path.splitext(img_name)[0]
                save_path = os.path.join(explanation_save_dir,
                                         'episode' + str(i), img_stem)
                if not os.path.isdir(save_path):
                    os.makedirs(save_path)
                if not os.path.exists(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name)):
                    # Save the projected original query image once per image.
                    original_img = Image.fromarray(
                        np.uint8(
                            project(query_img[j].permute(1, 2,
                                                         0).cpu().numpy())))
                    original_img.save(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name))

                img_relevance = relevance_querry_img.narrow(0, j, 1)
                print(predict_class, true_class, explain_class)
                # Render the relevance map as a heatmap image and save it.
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(
                        save_path,
                        true_class + '_' + explain_class + '_lrp_hm.jpg'))
Ejemplo n.º 12
0
    split = 'novel'
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split

    iter_num = 600
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    acc_all = []
    model = load_weight_file_for_test(model, params)
    print_model_params(model, params)

    if params.method in ['maml', 'maml_approx']:
        datamgr = SetDataManager(params.image_size,
                                 n_eposide=iter_num,
                                 n_query=15,
                                 **few_shot_params,
                                 isAircraft=(params.dataset == 'aircrafts'))
        loadfile = os.path.join('filelists', params.test_dataset, 'novel.json')
        novel_loader = datamgr.get_data_loader(loadfile, aug=False)

        if params.adaptation:
            model.task_update_num = 100  #We perform adaptation on MAML simply by updating more times.
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)
        params = parse_args('mytest')
        loadfile = os.path.join('filelists', params.test_dataset, 'novel.json')

    else:
        if "recognition36" in params.test_dataset:
            loadfile = os.path.join('filelists', params.test_dataset,
Ejemplo n.º 13
0
        if params.n_shot == 1:
            params.stop_epoch = 600
        elif params.n_shot == 5:
            params.stop_epoch = 400
        else:
            params.stop_epoch = 600

    if params.method in ['tcmaml', 'tcmaml_approx']:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        train_datamgr = SetDataManager(
            image_size, n_query=n_query, **train_few_shot_params
        )  # default number of episodes (tasks) is 100 per epoch
        train_loader = train_datamgr.get_data_loader(train_file,
                                                     aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.ResNet.maml = True
Ejemplo n.º 14
0
    #    elif params.method == 'baseline++':
    #        model           = BaselineTrain( model_dict[params.model], params.num_classes, \
    #                                        loss_type = 'dist', jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation, tracking=params.tracking)

    if params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(params.n_query * params.test_n_way / params.train_n_way)
        )  #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

        train_few_shot_params    = dict(n_way = params.train_n_way, n_support = params.n_shot, \
                                        jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params,
                                      isAircraft=isAircraft)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

        base_loader1 = copy.deepcopy(base_loader)
        images = torch.empty(0, 3, 224, 224)  ### [total_images, 3, 224, 224]
        for i, inputs in enumerate(base_loader1):
            print(i)
            x = inputs[0]  ### [5,21,3,224,224]
            x = x.view(105, *x.size()[2:])  ### [105,3,224,224]
            # print(x.size())
            images = torch.cat([images, x], dim=0)

        print(len(images))
        dataset = JigsawDataset(images)
Ejemplo n.º 15
0
    print('  train with single seen domain {}'.format(params.dataset))
    base_file  = os.path.join(params.data_dir, params.dataset, 'base.json')
    val_file   = os.path.join(params.data_dir, params.dataset, 'val.json')

  # model
  print('\n--- build model ---')
  if 'Conv' in params.model:
    image_size = 84
  else:
    image_size = 224

  if params.method in ['maml_baseline'] :
    print('  training the {} with backbone {}'.format(params.method, params.model))
    n_query = 15
    train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot)
    base_datamgr          = SetDataManager(image_size, n_query=n_query, n_eposide=100,  **train_few_shot_params)
    base_loader           = base_datamgr.get_data_loader(base_file, aug=params.train_aug)

    test_few_shot_params  = dict(n_way=params.test_n_way, n_support=params.n_shot)
    val_datamgr           = SetDataManager(image_size, n_query=n_query, n_eposide=100, **test_few_shot_params)
    val_loader            = val_datamgr.get_data_loader(val_file, aug=False)
  
    # base_datamgr    = SimpleDataManager(image_size, batch_size=16)
    # base_loader     = base_datamgr.get_data_loader(base_file , aug=params.train_aug )
    # val_datamgr     = SimpleDataManager(image_size, batch_size=64)
    # val_loader      = val_datamgr.get_data_loader(val_file, aug=False)
  else:
    raise ValueError('Unknown method')
  model = MAMLBaseline(params, tf_path=params.tf_dir)
  model.cuda()
Ejemplo n.º 16
0
def main():
    """Train (or test) a few-shot model end-to-end.

    Pipeline: parse args -> build train/val loaders (episodic for
    meta-methods, plain mini-batches for the baseline) -> build encoder and
    method wrapper -> optimizer + LR schedule -> optional resume / warmup ->
    epoch loop with validation, best-model checkpointing -> final test.
    """
    timer = Timer()
    args, writer = init()

    train_file = args.dataset_dir + 'train.json'
    val_file = args.dataset_dir + 'val.json'

    few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot, n_query=args.n_query)
    n_episode = 10 if args.debug else 100
    if args.method_type is Method_type.baseline:
        # Baseline pretraining uses ordinary mini-batches, not episodes.
        train_datamgr = SimpleDataManager(train_file, args.dataset_dir, args.image_size, batch_size=64)
        train_loader = train_datamgr.get_data_loader(aug=True)
    else:
        train_datamgr = SetDataManager(train_file, args.dataset_dir, args.image_size,
                                       n_episode=n_episode, mode='train', **few_shot_params)
        train_loader = train_datamgr.get_data_loader(aug=True)

    # Validation is always episodic so it matches the few-shot test protocol.
    val_datamgr = SetDataManager(val_file, args.dataset_dir, args.image_size,
                                 n_episode=n_episode, mode='val', **few_shot_params)
    val_loader = val_datamgr.get_data_loader(aug=False)

    if args.model_type is Model_type.ConvNet:
        # BUG FIX: this branch used to `pass`, leaving `encoder` unbound and
        # crashing later with an opaque NameError; fail fast instead.
        raise NotImplementedError('ConvNet backbone is not supported yet')
    elif args.model_type is Model_type.ResNet12:
        from methods.backbone import ResNet12
        encoder = ResNet12()
    else:
        raise ValueError('unknown model type: {}'.format(args.model_type))

    if args.method_type is Method_type.baseline:
        from methods.baselinetrain import BaselineTrain
        model = BaselineTrain(encoder, args)
    elif args.method_type is Method_type.protonet:
        from methods.protonet import ProtoNet
        model = ProtoNet(encoder, args)
    else:
        raise ValueError('unknown method type: {}'.format(args.method_type))

    from torch.optim import SGD, lr_scheduler
    if args.method_type is Method_type.baseline:
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9,
                        weight_decay=args.weight_decay)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0, last_epoch=-1)
    else:
        optimizer = torch.optim.SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4,
                                    nesterov=True)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)

    args.ngpu = torch.cuda.device_count()
    torch.backends.cudnn.benchmark = True
    model = model.cuda()

    # Episode labels are fixed: way indices 0..n_way-1, each n_query times.
    label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
    label = label.cuda()

    if args.test:
        test(model, label, args, few_shot_params)
        return

    if args.resume:
        resume_OK = resume_model(model, optimizer, args, scheduler)
    else:
        resume_OK = False
    if (not resume_OK) and (args.warmup is not None):
        # No checkpoint resumed: optionally start from pretrained weights.
        load_pretrained_weights(model, args)

    if args.debug:
        args.max_epoch = args.start_epoch + 1

    for epoch in range(args.start_epoch, args.max_epoch):
        train_one_epoch(model, optimizer, args, train_loader, label, writer, epoch)
        scheduler.step()

        vl, va = val(model, args, val_loader, label)
        if writer is not None:
            writer.add_scalar('data/val_acc', float(va), epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        # Track the best validation accuracy and always keep the last epoch.
        if va >= args.max_acc:
            args.max_acc = va
            print('saving the best model! acc={:.4f}'.format(va))
            save_model(model, optimizer, args, epoch, args.max_acc, 'max_acc', scheduler)
        save_model(model, optimizer, args, epoch, args.max_acc, 'epoch-last', scheduler)
        if epoch != 0:
            print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
    if writer is not None:
        writer.close()
    test(model, label, args, few_shot_params)
Ejemplo n.º 17
0
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        base_datamgr = SetDataManager(image_size, params.batchsize)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug,
                                                   ifshuffle=True)

        #val_datamgr             = SetDataManager(image_size, params.batchsize)
        #val_loader              = val_datamgr.get_data_loader( val_file, aug = False,ifshuffle = False)
        #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            if params.model == 'Conv4':
def get_logits_targets(params):
    """Collect raw model logits and matching episode targets on the test
    split, e.g. for downstream calibration / uncertainty analysis.

    Builds the few-shot model selected by ``params.method``, loads its best
    (or explicitly assigned) checkpoint, then runs ``iter_num`` episodes.
    MAML-style methods and DKT are run on raw images (they cannot reuse
    precomputed features); all other methods are evaluated on features saved
    in ``features/<...>/<split>.hdf5``.

    Args:
        params: parsed command-line namespace (dataset, model, method,
            test_n_way, n_shot, train_aug, save_iter, split, adaptation, ...).

    Returns:
        Tuple ``(logits, targets)`` of tensors concatenated over all
        episodes, where ``targets[k]`` is the ground-truth way index for
        ``logits[k]``.
    """
    iter_num = 600
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, \
            'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    # ---- build the model for the requested method -------------------------
    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model], loss_type='dist',
                                 **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'DKT':
        model = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        # RelationNet needs unflattened (spatial) feature maps.
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model, loss_type=loss_type,
                            **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model],
                     approx=(params.method == 'maml_approx'),
                     **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char']:
            # maml uses different hyper-parameters on omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    # ---- locate and load the checkpoint -----------------------------------
    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    if params.method not in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " + str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split

    if params.method in ['maml', 'maml_approx', 'DKT']:
        # These methods cannot test on precomputed features: run on images.
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224

        datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=15,
                                 **few_shot_params)

        # Cross-domain setups read the episode file from a different dataset.
        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            # Adaptation on MAML is simply more inner-loop updates.
            model.task_update_num = 100
        model.eval()

        logits_list = list()
        targets_list = list()
        for i, (x, _) in enumerate(novel_loader):
            logits = model.get_logits(x).detach()
            targets = torch.tensor(
                np.repeat(range(params.test_n_way), model.n_query)).cuda()
            logits_list.append(logits)
            targets_list.append(targets)
    else:
        # Feature-based evaluation: sample episodes from saved features.
        novel_file = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"),
            split_str + ".hdf5")
        cl_data_file = feat_loader.init_loader(novel_file)
        logits_list = list()
        targets_list = list()
        n_query = 15
        n_way = few_shot_params['n_way']
        n_support = few_shot_params['n_support']
        # BUG FIX: random.sample() requires a sequence in Python 3; passing
        # a bare dict_keys view raises TypeError.
        class_list = list(cl_data_file.keys())
        for i in range(iter_num):
            select_class = random.sample(class_list, n_way)
            z_all = []
            for cl in select_class:
                img_feat = cl_data_file[cl]
                perm_ids = np.random.permutation(len(img_feat)).tolist()
                # Stack n_support + n_query randomly drawn features per class
                # (loop variable renamed so it no longer shadows the outer i).
                z_all.append([np.squeeze(img_feat[perm_ids[k]])
                              for k in range(n_support + n_query)])
            z_all = torch.from_numpy(np.array(z_all))
            model.n_query = n_query
            logits = model.set_forward(z_all, is_feature=True).detach()
            targets = torch.tensor(np.repeat(range(n_way), n_query)).cuda()
            logits_list.append(logits)
            targets_list.append(targets)
    return torch.cat(logits_list, 0), torch.cat(targets_list, 0)
Ejemplo n.º 19
0
                vocab_srt = [v[0] for v in vocab_srt]
                with open(args.embeddings_file, "w") as fout:
                    fout.write("\n".join(vocab_srt))
                    fout.write("\n")
                np.savetxt(args.embeddings_metadata, weights, fmt="%f", delimiter="\t")
                sys.exit(0)

    # Run the test loop for 600 iterations
    ITER_NUM = 600
    N_QUERY = 15

    test_datamgr = SetDataManager(
        "CUB",
        84,
        n_query=N_QUERY,
        n_way=args.test_n_way,
        n_support=args.n_shot,
        n_episode=ITER_NUM,
        args=args,
    )
    test_loader = test_datamgr.get_data_loader(
        os.path.join(constants.DATA_DIR, f"{args.split}.json"),
        aug=False,
        lang_dir=constants.LANG_DIR,
        normalize=False,
        vocab=vocab,
    )
    normalizer = TransformLoader(84).get_normalize()

    model.eval()
Ejemplo n.º 20
0
    def meta_train(self,
                   config,
                   method,
                   descriptor_str,
                   debug=True,
                   use_test=False,
                   require_pretrain=False,
                   metric="acc"):
        """Meta-train ``method`` per ``config`` on top of a frozen pretrained
        backbone, checkpoint the best validation model, then run meta-test.

        Args:
            config: dict with "dataset", "model", "method", "n_shot",
                "train_episodes", "val_episodes", "end_epoch"; optionally
                "weight_decay" and "params" (kwargs forwarded to ``method``).
            method: model class, called as
                ``method(backbone, n_way=..., n_support=..., **model_params)``.
            descriptor_str: tag used in checkpoint and result-file names.
            debug: forwarded to the data loaders (presumably a reduced run —
                TODO confirm loader semantics).
            use_test: if True, validate on novel.json instead of val.json.
            require_pretrain: pass the pretrained model into ``method``'s
                constructor rather than attaching it afterwards.
            metric: validation metric name forwarded to ``model.test_loop``.
        """
        config["meta_training"] = True
        params = self.params
        params.save_freq = 10
        # Shrink query count when testing with fewer ways than training so
        # the per-episode batch stays roughly the same size.
        params.n_query = max(1,
                             int(16 * params.test_n_way / params.train_n_way))
        params.dataset = config["dataset"]
        params.model = config["model"]
        params.method = config["method"]
        params.n_shot = config["n_shot"]
        train_episodes = config["train_episodes"]
        val_episodes = config["val_episodes"]
        end_epoch = config["end_epoch"]
        if "weight_decay" in config:
            weight_decay = config["weight_decay"]
        else:
            weight_decay = 0

        # Per-epoch validation messages are appended to this result file.
        result_dir = "results/meta/%s" % (params.dataset)
        if not os.path.isdir(result_dir):
            os.makedirs(result_dir)
        result_file = os.path.join(
            result_dir,
            "%s_%s_%s.txt" % (params.method, params.model, descriptor_str))

        self.few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot,
                                    n_query=self.params.n_query)
        params.checkpoint_dir = '%s/checkpoints/%s/%s/%s_%s_%s' % (
            configs.save_dir, params.dataset, params.method, params.model,
            descriptor_str, params.n_shot)
        # NOTE(review): stop_epoch is pinned to 100 while the training loop
        # below runs to config["end_epoch"]; the periodic-save condition uses
        # stop_epoch - 1 — confirm this mismatch is intended.
        params.stop_epoch = 100
        self.initialize(params, False)
        image_size = self.image_size
        pretrain = PretrainedModel(self.params)

        # Validate on the novel (test) split only when explicitly requested.
        if use_test:
            file_name = "novel.json"
        else:
            file_name = "val.json"
        if params.dataset == 'cross':
            base_file = configs.data_dir['miniImagenet'] + 'base.json'
            val_file = configs.data_dir['CUB'] + file_name
        else:
            base_file = configs.data_dir[params.dataset] + 'base.json'
            val_file = configs.data_dir[params.dataset] + file_name

        # Same formula as params.n_query above, recomputed locally.
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params,
                                      n_eposide=train_episodes)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug,
                                                   debug=debug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params,
                                     n_eposide=val_episodes)
        val_loader = val_datamgr.get_data_loader(val_file,
                                                 aug=False,
                                                 debug=debug)

        backbone = self.get_backbone()
        if "params" in config:
            model_params = config["params"]
        else:
            model_params = {}
        if require_pretrain:
            model_params["pretrain"] = pretrain

        model = method(backbone, **train_few_shot_params, **model_params)
        model = model.cuda()

        # Freeze backbone
        model.feature = None
        if not require_pretrain:
            model.pretrain = pretrain
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=weight_decay)

        max_acc = 0
        for epoch in range(0, end_epoch):
            model.epoch = epoch
            model.train()
            model.train_loop(
                epoch, base_loader,
                optimizer)  # model are called by reference, no need to return
            model.eval()

            if not os.path.isdir(params.checkpoint_dir):
                os.makedirs(params.checkpoint_dir)

            acc = model.test_loop(val_loader, metric=metric)
            message = "Epoch: %d, Validation accuracy: %.3f, Best validation accuracy: %.3f" % (
                epoch, acc, max_acc)
            print(message)
            append_to_file(result_file, message)

            # Keep the checkpoint with the best validation accuracy so far.
            if acc > max_acc:
                print("best model! save...")
                max_acc = acc
                outfile = os.path.join(params.checkpoint_dir, 'best_model.tar')
                torch.save({
                    'epoch': epoch,
                    'state': model.state_dict()
                }, outfile)

            # Periodic snapshot every save_freq epochs (plus stop_epoch - 1).
            if (epoch % params.save_freq == 0) or (epoch
                                                   == params.stop_epoch - 1):
                outfile = os.path.join(params.checkpoint_dir,
                                       '{:d}.tar'.format(epoch))
                torch.save({
                    'epoch': epoch,
                    'state': model.state_dict()
                }, outfile)
        self.meta_test(config, method, descriptor_str, debug, require_pretrain)
Ejemplo n.º 21
0
    #code from mvcnn, add logging later
    #parse_args

    # num_models = 1000 #max number of models to use per class, add this functionality later
    # n_models_train = num_models*num_views

    #     if params.num_views and params.num_views >=5:
    #         n_query = max(1, int(8* params.test_n_way/params.train_n_way)) #why is this required?
    #     else:
    #         n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #why is this required?

    train_few_shot_params = dict(n_way=params.train_n_way,
                                 n_support=params.n_shot)
    base_datamgr = SetDataManager(image_size,
                                  n_query=params.n_query,
                                  **train_few_shot_params,
                                  num_views=params.num_views)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)

    test_few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot)
    val_datamgr = SetDataManager(image_size,
                                 n_query=params.n_query,
                                 **test_few_shot_params,
                                 num_views=params.num_views)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    backbone = model_dict[params.model]
    model = ProtoNet(backbone, params.num_views, **train_few_shot_params)
    model = model.cuda()
    # model = torch.nn.DataParallel(model).cuda()
Ejemplo n.º 22
0
    def meta_test(self,
                  config,
                  method,
                  descriptor_str,
                  debug=True,
                  require_pretrain=False,
                  metric="acc"):
        """Evaluate the best meta-trained checkpoint on the novel split.

        Rebuilds the experiment state from ``config``, locates the
        ``best_model.tar`` checkpoint written during meta-training, loads it
        into a freshly constructed model, freezes the backbone and runs
        ``model.test_loop`` over episodes of the novel split.

        Args:
            config: dict carrying at least "dataset", "model", "method" and
                "n_shot"; optional extra constructor kwargs under "params".
            method: model class (constructor) taking (backbone, n_way,
                n_support, **kwargs).
            descriptor_str: tag embedded in the checkpoint directory path.
            debug: forwarded to the data loader (presumably limits the data
                for quick runs — TODO confirm against get_data_loader).
            require_pretrain: if True the pretrained model is passed to the
                constructor; otherwise it is attached after construction.
            metric: metric name forwarded to ``model.test_loop``.
        """
        # NOTE(review): mutates the caller's config dict (here and, when
        # "params" is present, again below via model_params aliasing).
        config["meta_training"] = True
        params = self.params
        params.save_freq = 50
        # Shrink n_query when testing fewer ways than training, keeping the
        # per-episode batch size comparable.
        params.n_query = max(1,
                             int(16 * params.test_n_way / params.train_n_way))
        params.dataset = config["dataset"]
        params.model = config["model"]
        params.method = config["method"]
        params.n_shot = config["n_shot"]

        self.few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot,
                                    n_query=self.params.n_query)
        # Must match the directory layout used when the checkpoint was saved.
        params.checkpoint_dir = '%s/checkpoints/%s/%s/%s_%s_%s' % (
            configs.save_dir, params.dataset, params.method, params.model,
            descriptor_str, params.n_shot)
        # stop_epoch doubles as the number of test episodes below.
        params.stop_epoch = 2000
        self.initialize(params, False)
        image_size = self.image_size
        pretrain = PretrainedModel(self.params)

        # 'cross' = train on miniImagenet, evaluate on CUB novel classes.
        if params.dataset == 'cross':
            test_file = configs.data_dir['CUB'] + 'novel.json'
        else:
            test_file = configs.data_dir[params.dataset] + 'novel.json'

        # Standard few-shot evaluation protocol: 15 queries per class.
        n_query = 15

        few_shot_params = dict(n_way=params.test_n_way,
                               n_support=params.n_shot)
        datamgr = SetDataManager(image_size,
                                 n_query=n_query,
                                 **few_shot_params,
                                 n_eposide=params.stop_epoch)
        loader = datamgr.get_data_loader(test_file, aug=False, debug=debug)

        # NOTE(review): model_params aliases config["params"] when present,
        # so the pretrain assignment below also writes into config.
        if "params" in config:
            model_params = config["params"]
        else:
            model_params = {}
        if require_pretrain:
            model_params["pretrain"] = pretrain

        backbone = self.get_backbone()
        model = method(backbone, **few_shot_params, **model_params)
        model = model.cuda()
        model_file = os.path.join(params.checkpoint_dir, "best_model.tar")
        # Load model: update a fresh state dict so keys missing from the
        # checkpoint keep their freshly initialized values.
        state_dict = model.state_dict()
        saved_states = torch.load(model_file)["state"]
        state_dict.update(saved_states)
        model.load_state_dict(state_dict)
        # Freeze backbone: features come from the pretrained model instead.
        model.feature = None
        model.pretrain = pretrain

        model.eval()
        acc = model.test_loop(loader, metric=metric)
        print(acc)
Ejemplo n.º 23
0
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        if params.n_shot_test == -1:  # modify val loader support
            params.n_shot_test = params.n_shot
        else:  # modify target loader support
            train_few_shot_params['n_support'] = params.n_shot_test
        if params.adversarial or params.adaptFinetune:
            target_datamgr = SetDataManager(image_size,
                                            n_query=n_query,
                                            **train_few_shot_params)
            target_loader = target_datamgr.get_data_loader(novel_file,
                                                           aug=False)
            # ipdb.set_trace()
            # bl, tl = iter(base_loader), iter(target_loader)
Ejemplo n.º 24
0
    else:
        raise print('train_aug is wrong')
    print('Testing! {} shots on {} dataset with {} epochs of {}({})'.format(
        params.n_shot, params.testset, params.save_epoch, name, params.method))
    # dataset
    print('  build dataset')
    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    test_few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot)
    val_datamgr = SetDataManager(image_size,
                                 n_query=params.n_query,
                                 **test_few_shot_params)
    val_loader = val_datamgr.get_data_loader(loadfile, aug=False)

    datasets = params.dataset
    datasets.remove(params.testset)

    base_loaders = [
        val_datamgr.get_data_loader(os.path.join(params.data_dir, dataset,
                                                 'base.json'),
                                    aug=False) for dataset in datasets
    ]

    print('  build feature encoder')
    # feature encoder
    checkpoint_dir = '%s/checkmodels/%s' % (params.save_dir, name)
Ejemplo n.º 25
0
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split
    if params.method in ['maml', 'maml_approx'
                         ]:  #maml do not support testing with feature
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224

        datamgr = SetDataManager(image_size,
                                 n_eposide=iter_num,
                                 n_query=15,
                                 **few_shot_params)

        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'
Ejemplo n.º 26
0
    image_size = 224
    base_datamgr = SimpleDataManager(image_size, batch_size=batch_size)
    base_loader = base_datamgr.get_data_loader(base_file, aug=True)

    base_jigsaw_datamgr = JigsawDataManger(
        image_size,
        batch_size=batch_size,
        max_replace_block_num=params.jig_replace_num_train)
    base_jigsaw_loader = base_jigsaw_datamgr.get_data_loader(base_file,
                                                             aug=False)

    extra_data = 15  # extra_unlabeled data
    val_datamgr = SetDataManager(image_size,
                                 n_way=params.test_n_way,
                                 n_support=params.n_shot,
                                 n_query=params.n_query + extra_data,
                                 n_eposide=50)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    if params.dataset == "miniImagenet":
        num_class = 64
    elif params.dataset == "tieredImagenet":
        num_class = 351
    elif params.dataset == "caltech256":
        num_class = 257
    elif params.dataset == "CUB":
        num_class = 200  # set to 200 since the label range 0~199 even though there are only 100 classes to be trained
    else:
        raise ValueError('Unknown dataset')
Ejemplo n.º 27
0
def single_test(params, results_logger):
    """Run one test pass for the configured few-shot method and log accuracy.

    Builds the model for ``params.method``, loads its checkpoint from under
    ``configs.save_dir``, then evaluates ``iter_num`` episodes: MAML-family
    and DKT run episodes directly on images, every other method evaluates on
    pre-extracted features. The result is appended to ``record/results.txt``
    and sent to ``results_logger``.

    Args:
        params: parsed experiment options (dataset, model, method, n_shot,
            test_n_way, train_n_way, split, save_iter, train_aug, adaptation).
        results_logger: object exposing ``log(key, value)``.

    Returns:
        Mean accuracy (percent) over the test episodes.

    Raises:
        ValueError: if ``params.method`` is not recognized.
    """
    acc_all = []

    iter_num = 600

    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    # Build the few-shot model for the requested method.
    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model],
                                 loss_type='dist',
                                 **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'DKT':
        model = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        # RelationNet needs unflattened (spatial) features from the backbone.
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        # Switch the backbone blocks into MAML mode (fast-weight support).
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model],
                     approx=(params.method == 'maml_approx'),
                     **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char'
                              ]:  # maml use different parameter in omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    # Checkpoint directory must mirror the layout used during training.
    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    # Baselines are re-finetuned at test time, so there are no meta-trained
    # weights to load for them.
    if not params.method in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            # Fixed message: the checkpoint is saved as 'best_model.tar',
            # not 'best_file.tar'.
            print("[WARNING] Cannot find 'best_model.tar' in: " +
                  str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split
    if params.method in ['maml', 'maml_approx',
                         'DKT']:  # maml do not support testing with feature
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224

        datamgr = SetDataManager(image_size,
                                 n_eposide=iter_num,
                                 n_query=15,
                                 **few_shot_params)

        # Resolve the episode file for cross-domain / cross-character setups.
        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # We perform adaptation on MAML simply by updating more times.
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)

    else:
        # Feature-based evaluation: load features pre-extracted by save_features.
        novel_file = os.path.join(
            checkpoint_dir.replace("checkpoints",
                                   "features"), split_str + ".hdf5"
        )  # defaut split = novel, but you can also test base or val classes
        cl_data_file = feat_loader.init_loader(novel_file)

        for i in range(iter_num):
            acc = feature_evaluation(cl_data_file,
                                     model,
                                     n_query=15,
                                     adaptation=params.adaptation,
                                     **few_shot_params)
            acc_all.append(acc)

        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        print('%d Test Acc = %4.2f%% +- %4.2f%%' %
              (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))

    # Robustness fix: create the record directory if it does not exist yet,
    # otherwise the open() below crashes on a fresh checkout.
    os.makedirs('record', exist_ok=True)
    with open('record/results.txt', 'a') as f:
        timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
        aug_str = '-aug' if params.train_aug else ''
        aug_str += '-adapted' if params.adaptation else ''
        if params.method in ['baseline', 'baseline++']:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.test_n_way)
        else:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_train %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.train_n_way, params.test_n_way)
        acc_str = '%d Test Acc = %4.2f%% +- %4.2f%%' % (
            iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num))
        f.write('Time: %s, Setting: %s, Acc: %s \n' %
                (timestamp, exp_setting, acc_str))
        results_logger.log("single_test_acc", acc_mean)
        results_logger.log("single_test_acc_std",
                           1.96 * acc_std / np.sqrt(iter_num))
        results_logger.log("time", timestamp)
        results_logger.log("exp_setting", exp_setting)
        results_logger.log("acc_str", acc_str)
    return acc_mean
Ejemplo n.º 28
0
def get_train_val_loader(params, source_val):
    """Build the base (train) and validation data loaders for ``params.method``.

    Args:
        params: experiment namespace; must provide method, train_aug,
            train_n_way, test_n_way and, for the meta methods, the
            vaegan/aug options read below.
        source_val: if True, additionally build a loader over the source
            domain's validation file.

    Returns:
        ``(base_loader, val_loader)`` or, when ``source_val`` is True,
        ``(base_loader, val_loader, source_val_loader)``.

    Raises:
        ValueError: if ``params.method`` is not recognized.
    """
    # Imported here to prevent a circular import.
    from data.datamgr import SimpleDataManager, SetDataManager, AugSetDataManager, VAESetDataManager

    image_size = get_img_size(params)
    base_file, val_file = get_train_val_filename(params)
    if source_val:
        source_val_file = get_source_val_filename(params)

    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

        # Validation is episodic even for baselines (fine-tune at validation).
        # If test_n_way < train_n_way, reduce n_query to keep batch size small.
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
        val_few_shot_params = get_few_shot_params(params, 'val')
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **val_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        if source_val:
            source_val_datamgr = SetDataManager(image_size,
                                                n_query=n_query,
                                                **val_few_shot_params)
            # Bug fix: the freshly built source_val_datamgr was previously
            # ignored and val_datamgr reused here.
            source_val_loader = source_val_datamgr.get_data_loader(
                source_val_file, aug=False)

    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        # If test_n_way < train_n_way, reduce n_query to keep batch size small.
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))

        train_few_shot_params = get_few_shot_params(params, 'train')
        val_few_shot_params = get_few_shot_params(params, 'val')
        if params.vaegan_exp is not None:
            # TODO: confirm whether the restored VAE-GAN should ever run in
            # training mode here.
            is_training = False
            vaegan = restore_vaegan(params.dataset,
                                    params.vaegan_exp,
                                    params.vaegan_step,
                                    is_training=is_training)

            base_datamgr = VAESetDataManager(
                image_size,
                n_query=n_query,
                vaegan_exp=params.vaegan_exp,
                vaegan_step=params.vaegan_step,
                vaegan_is_train=params.vaegan_is_train,
                lambda_zlogvar=params.zvar_lambda,
                fake_prob=params.fake_prob,
                **train_few_shot_params)
            # NOTE(review): unclear whether this should read the train_val or
            # the val split — kept as val.
            val_datamgr = SetDataManager(image_size,
                                         n_query=n_query,
                                         **val_few_shot_params)

        elif params.aug_target is None:  # common case: no augmentation target
            assert params.aug_type is None

            base_datamgr = SetDataManager(image_size,
                                          n_query=n_query,
                                          **train_few_shot_params)
            val_datamgr = SetDataManager(image_size,
                                         n_query=n_query,
                                         **val_few_shot_params)
        else:
            aug_type = params.aug_type
            assert aug_type is not None
            base_datamgr = AugSetDataManager(image_size,
                                             n_query=n_query,
                                             aug_type=aug_type,
                                             aug_target=params.aug_target,
                                             **train_few_shot_params)
            val_datamgr = AugSetDataManager(image_size,
                                            n_query=n_query,
                                            aug_type=aug_type,
                                            aug_target='test-sample',
                                            **val_few_shot_params)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        if source_val:
            # The source-domain validation loader shares val_datamgr's
            # episodic configuration. (A duplicate, never-used manager was
            # previously built in the common sub-branch; removed as dead code.)
            source_val_loader = val_datamgr.get_data_loader(source_val_file,
                                                            aug=False)
        # A SetDataManager batch: [n_way, n_support + n_query, n_channel, w, h]

    else:
        raise ValueError('Unknown method')

    if source_val:
        return base_loader, val_loader, source_val_loader
    else:
        return base_loader, val_loader
Ejemplo n.º 29
0
    if params.method in ['baseline', 'baseline++'] :
        base_datamgr    = SimpleDataManager(image_size, batch_size = 16)
        base_loader     = base_datamgr.get_data_loader( base_file , aug = params.train_aug )
        val_datamgr     = SimpleDataManager(image_size, batch_size = 64)
        val_loader      = val_datamgr.get_data_loader( val_file, aug = False) 

        if params.method == 'baseline':
            model           = BaselineTrain( model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model           = BaselineTrain( model_dict[params.model], params.num_classes, loss_type = 'dist')

    elif params.method in ['protonet','matchingnet','relationnet', 'relationnet_softmax', 'maml', 'maml_approx']:
        n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
 
        train_few_shot_params    = dict(n_way = params.train_n_way, n_support = params.n_shot) 
        base_datamgr            = SetDataManager(image_size, n_query = n_query,  **train_few_shot_params)
        base_loader             = base_datamgr.get_data_loader( base_file , aug = params.train_aug )
         
        test_few_shot_params     = dict(n_way = params.test_n_way, n_support = params.n_shot) 
        val_datamgr             = SetDataManager(image_size, n_query = n_query, **test_few_shot_params)
        val_loader              = val_datamgr.get_data_loader( val_file, aug = False) 
        #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor        

        if params.method == 'protonet':
            model           = ProtoNet( model_dict[params.model], **train_few_shot_params )
        elif params.method == 'matchingnet':
            model           = MatchingNet( model_dict[params.model], **train_few_shot_params )
        elif params.method in ['relationnet', 'relationnet_softmax']:
            if params.model == 'Conv4': 
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6': 
Ejemplo n.º 30
0
def exp_test(params, n_episodes, should_del_features=False):#, show_data=False):
    start_time = datetime.datetime.now()
    print('exp_test() started at',start_time)
    
    set_random_seed(0) # successfully reproduce "normal" testing. 
    
    if params.gpu_id:
        set_gpu_id(params.gpu_id)
    
#     acc_all = []

    model = get_model(params, 'test')
    
    ########## get settings ##########
    n_shot = params.test_n_shot if params.test_n_shot is not None else params.n_shot
    few_shot_params = dict(n_way = params.test_n_way , n_support = n_shot)
    if params.gpu_id:
        model = model.cuda()
    else:
        model = to_device(model)
    checkpoint_dir = get_checkpoint_dir(params)
    print('loading from:',checkpoint_dir)
    if params.save_iter != -1:
        modelfile   = get_assigned_file(checkpoint_dir, params.save_iter)
    else:
        modelfile   = get_best_file(checkpoint_dir)
    
    ########## load model ##########
    if modelfile is not None:
        if params.gpu_id is None:
            tmp = torch.load(modelfile)
        else: # TODO: figure out WTF is going on here
            print('params.gpu_id =', params.gpu_id)
            map_location = 'cuda:0'
#             gpu_str = 'cuda:' + '0'#str(params.gpu_id)
#             map_location = {'cuda:1':gpu_str, 'cuda:0':gpu_str} # see here: https://hackmd.io/koKAo6kURn2YBqjoXXDhaw#RuntimeError-CUDA-error-invalid-device-ordinal
            tmp = torch.load(modelfile, map_location=map_location)
#                 tmp = torch.load(modelfile)
        if not params.method in ['baseline', 'baseline++'] : 
            # if 'baseline' or 'baseline++' then NO NEED to load model !!!
            model.load_state_dict(tmp['state'])
            print('Model successfully loaded.')
        else:
            print('No need to load model for baseline/baseline++ when testing.')
        load_epoch = int(tmp['epoch'])
    
    ########## testing ##########
    if params.method in ['maml', 'maml_approx']: #maml do not support testing with feature
        image_size = get_img_size(params)
        load_file = get_loadfile_path(params, params.split)

        datamgr         = SetDataManager(image_size, n_episode = n_episodes, n_query = 15 , **few_shot_params)
        
        novel_loader     = datamgr.get_data_loader( loadfile, aug = False)
        if params.adaptation:
            model.task_update_num = 100 #We perform adaptation on MAML simply by updating more times.
        model.eval()
        acc_mean, acc_std = model.test_loop( novel_loader, return_std = True)
        
        ########## last record and post-process ##########
        torch.cuda.empty_cache()
        timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
        # TODO afterward: compute this
        acc_str = '%4.2f%% +- %4.2f%%' % (acc_mean, 1.96* acc_std/np.sqrt(n_episodes))
        # writing settings into csv
        acc_mean_str = '%4.2f' % (acc_mean)
        acc_std_str = '%4.2f' %(acc_std)
        # record beyond params
        extra_record = {'time':timestamp, 'acc_mean':acc_mean_str, 'acc_std':acc_std_str, 'epoch':load_epoch}
        if should_del_features:
            del_features(params)
        end_time = datetime.datetime.now()
        print('exp_test() start at', start_time, ', end at', end_time, '.\n')
        print('exp_test() totally took:', end_time-start_time)
        return extra_record, task_datas

    else: # not MAML
        acc_all = []
#         # draw_task: initialize task acc(actually can replace acc_all), img_path, img_is_correct, etc.
#         task_datas = [None]*n_episodes # list of dict
        # directly use extracted features
        all_feature_files = get_all_feature_files(params)
        
        if params.n_test_candidates is None: # common setting (no candidate)
            # draw_task: initialize task acc(actually can replace acc_all), img_path, img_is_correct, etc.
            task_datas = [None]*n_episodes # list of dict
            
            feature_file = all_feature_files[0]
            cl_feature, cl_filepath = feat_loader.init_loader(feature_file, return_path=True)
            cl_feature_single = [cl_feature]
            
            for i in tqdm(range(n_episodes)):
                # TODO afterward: fix data list? can only fix class list?
                task_data = feature_evaluation(
                    cl_feature_single, model, params=params, n_query=15, **few_shot_params, 
                    cl_filepath=cl_filepath,
                )
                acc = task_data['acc']
                acc_all.append(acc)
                task_datas[i] = task_data
            
            acc_all  = np.asarray(acc_all)
            acc_mean = np.mean(acc_all)
            acc_std  = np.std(acc_all)
            print('loaded from %d epoch model.' %(load_epoch))
            print('%d episodes, Test Acc = %4.2f%% +- %4.2f%%' %(n_episodes, acc_mean, 1.96* acc_std/np.sqrt(n_episodes)))
            
            ########## last record and post-process ##########
            torch.cuda.empty_cache()
            timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
            # TODO afterward: compute this
            acc_str = '%4.2f%% +- %4.2f%%' % (acc_mean, 1.96* acc_std/np.sqrt(n_episodes))
            # writing settings into csv
            acc_mean_str = '%4.2f' % (acc_mean)
            acc_std_str = '%4.2f' %(acc_std)
            # record beyond params
            extra_record = {'time':timestamp, 'acc_mean':acc_mean_str, 'acc_std':acc_std_str, 'epoch':load_epoch}
            if should_del_features:
                del_features(params)
            end_time = datetime.datetime.now()
            print('exp_test() start at', start_time, ', end at', end_time, '.\n')
            print('exp_test() totally took:', end_time-start_time)
            return extra_record, task_datas
        else: # n_test_candidates settings
            # Multi-candidate evaluation: one feature file was extracted per
            # candidate model; load every candidate's class->features dict up front.
                
            candidate_cl_feature = [] # features of each class of each candidates
            print('Loading features of %s candidates into dictionaries...' %(params.n_test_candidates))
            for n in tqdm(range(params.n_test_candidates)):
                nth_feature_file = all_feature_files[n]
                cl_feature, cl_filepath = feat_loader.init_loader(nth_feature_file, return_path=True)
                candidate_cl_feature.append(cl_feature)
            # NOTE(review): cl_filepath keeps only the LAST candidate's image-path
            # mapping; this assumes all candidates index the same images — confirm.

            print('Evaluating...')
            
            # TODO: frac_acc_all
            # frac_ensemble may be a scalar (single experiment) or a list of
            # fractions (one experiment per fraction, sharing episode draws).
            is_single_exp = not isinstance(params.frac_ensemble, list)
            if is_single_exp:
                # draw_task: initialize task acc(actually can replace acc_all), img_path, img_is_correct, etc.
                task_datas = [None]*n_episodes # list of dict
                ########## test and record acc ##########
                for i in tqdm(range(n_episodes)):
                    # TODO afterward: fix data list? can only fix class list?

                    # Evaluate one episode across all candidate feature sets;
                    # task_data carries at least the episode accuracy under 'acc'.
                    task_data = feature_evaluation(
                        candidate_cl_feature, model, params=params, n_query=15, **few_shot_params, 
                        cl_filepath=cl_filepath,
                    )
                    acc = task_data['acc']
                    acc_all.append(acc)
                    task_datas[i] = task_data
                    
                    # Collect garbage every episode to bound memory growth.
                    collected = gc.collect()
#                     print("Garbage collector: collected %d objects." % (collected))

                acc_all  = np.asarray(acc_all)
                acc_mean = np.mean(acc_all)
                acc_std  = np.std(acc_all)
                print('loaded from %d epoch model.' %(load_epoch))
                # 95% confidence interval via normal approximation.
                print('%d episodes, Test Acc = %4.2f%% +- %4.2f%%' %(n_episodes, acc_mean, 1.96* acc_std/np.sqrt(n_episodes)))
                collected = gc.collect()
                print("garbage collector: collected %d objects." % (collected))
                
                ########## last record and post-process ##########
                torch.cuda.empty_cache()
                # Timestamp this run so the CSV record can be matched against logs.
                timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
                # TODO afterward: compute this
                # NOTE(review): acc_str is assigned but never used before the return below.
                acc_str = '%4.2f%% +- %4.2f%%' % (acc_mean, 1.96* acc_std/np.sqrt(n_episodes))
                # writing settings into csv
                acc_mean_str = '%4.2f' % (acc_mean)
                acc_std_str = '%4.2f' %(acc_std)
                # record beyond params
                extra_record = {'time':timestamp, 'acc_mean':acc_mean_str, 'acc_std':acc_std_str, 'epoch':load_epoch}
                if should_del_features:
                    # Remove the extracted-feature files from disk now that evaluation is done.
                    del_features(params)
                end_time = datetime.datetime.now()
                print('exp_test() start at', start_time, ', end at', end_time, '.\n')
                print('exp_test() totally took:', end_time-start_time)
                return extra_record, task_datas
            else: ########## multi-frac_ensemble exps ##########
                # params.frac_ensemble is a list: run one experiment per ensemble
                # fraction, reusing the same sampled episodes for every fraction.
                
                ########## (haven't modified) test and record acc ##########
                n_fracs = len(params.frac_ensemble)
                
                ##### initialize frac_data #####
                frac_acc_alls = [[0]*n_episodes for _ in range(n_fracs)]  # per-frac, per-episode accuracy
                frac_acc_means = [None]*n_fracs
                frac_acc_stds = [None]*n_fracs
                # draw_task: initialize task acc(actually can replace acc_all), img_path, img_is_correct, etc.
                ep_task_data_each_frac = [[None]*n_episodes for _ in range(n_fracs)] # list of list of dict

                for ep_id in tqdm(range(n_episodes)):
                    # TODO afterward: fix data list? can only fix class list?
                    
                    # TODO my_utils.py: feature_eval return frac_task_data
                    # NOTE(review): this branch assumes feature_evaluation returns a
                    # list of per-frac task dicts when frac_ensemble is a list — confirm.
                    frac_task_data = feature_evaluation(
                        candidate_cl_feature, model, params=params, n_query=15, **few_shot_params, 
                        cl_filepath=cl_filepath,
                    )
                    for frac_id in range(n_fracs):
                        task_data = frac_task_data[frac_id]
                        # TODO: i think here's the problem???
                        acc = task_data['acc']
                        frac_acc_alls[frac_id][ep_id] = acc
                        ep_task_data_each_frac[frac_id][ep_id] = task_data
                        
                        # NOTE(review): gc.collect() in this inner per-frac loop is
                        # expensive and redundant with the per-episode collect below.
                        collected = gc.collect()
#                         print("Garbage collector: collected %d objects." % (collected))
                    collected = gc.collect()
#                     print("Garbage collector: collected %d objects." % (collected))
                ### debug
#                 print('frac_acc_alls:', frac_acc_alls)
#                 yee
                # Summarize mean/std accuracy for each ensemble fraction.
                for frac_id in range(n_fracs):
                    frac_acc_alls[frac_id]  = np.asarray(frac_acc_alls[frac_id])
                    acc_all = frac_acc_alls[frac_id]
                    acc_mean = np.mean(acc_all)
                    acc_std = np.std(acc_all)
                    frac_acc_means[frac_id] = acc_mean
                    frac_acc_stds[frac_id]  = acc_std
                    print('loaded from %d epoch model, frac_ensemble:'%(load_epoch), params.frac_ensemble[frac_id])
                    # 95% confidence interval via normal approximation.
                    print('%d episodes, Test Acc = %4.2f%% +- %4.2f%%' %(n_episodes, acc_mean, 1.96* acc_std/np.sqrt(n_episodes)))
                
                ########## (haven't modified) last record and post-process ##########
                torch.cuda.empty_cache()
                # One shared timestamp for all per-frac records of this run.
                timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
                # TODO afterward: compute this
#                 acc_str = '%4.2f%% +- %4.2f%%' % (acc_mean, 1.96* acc_std/np.sqrt(n_episodes))
                # Build one CSV record per fraction (shared timestamp/epoch).
                frac_extra_records = []
                for frac_id in range(n_fracs):
                    # writing settings into csv
                    acc_mean = frac_acc_means[frac_id]
                    acc_std = frac_acc_stds[frac_id]
                    acc_mean_str = '%4.2f' % (acc_mean)
                    acc_std_str = '%4.2f' %(acc_std)
                    # record beyond params
                    extra_record = {'time':timestamp, 'acc_mean':acc_mean_str, 'acc_std':acc_std_str, 'epoch':load_epoch}
                    frac_extra_records.append(extra_record)
                
                if should_del_features:
                    # Remove the extracted-feature files from disk now that evaluation is done.
                    del_features(params)
                end_time = datetime.datetime.now()
                print('exp_test() start at', start_time, ', end at', end_time, '.\n')
                print('exp_test() totally took:', end_time-start_time)
                
                return frac_extra_records, ep_task_data_each_frac