Exemplo n.º 1
0
def evaluate(novel_loader, n_way=5, n_support=5):
    """Evaluate the best saved checkpoint on few-shot episodes and print accuracy.

    Builds the model selected by the module-level ``params.method``, loads the
    weights stored under ``<params.save_dir>/checkpoints/<params.name>/best_model.tar``
    and scores every episode yielded by ``novel_loader`` with
    ``model.set_forward``. The per-task top-1 accuracy and the final mean with
    a ``1.96 * std / sqrt(num_tasks)`` interval are printed; nothing is
    returned.

    Args:
        novel_loader: Sized iterable yielding ``(x, _)`` pairs. ``x.size(1)``
            is taken to be ``n_support + n_query``.
        n_way: Number of classes per episode.
        n_support: Number of support samples per class.

    Raises:
        AssertionError: If ``params.method`` is not one of the supported
            method names.
    """
    iter_num = len(novel_loader)
    acc_all = []
    # Model
    if params.method == 'MatchingNet':
        model = MatchingNet(model_dict[params.model],
                            n_way=n_way,
                            n_support=n_support).cuda()
    elif params.method == 'RelationNet':
        model = RelationNet(model_dict[params.model],
                            n_way=n_way,
                            n_support=n_support).cuda()
    elif params.method == 'ProtoNet':
        model = ProtoNet(model_dict[params.model],
                         n_way=n_way,
                         n_support=n_support).cuda()
    elif params.method == 'GNN':
        model = GnnNet(model_dict[params.model],
                       n_way=n_way,
                       n_support=n_support).cuda()
    elif params.method == 'TPN':
        model = TPN(model_dict[params.model], n_way=n_way,
                    n_support=n_support).cuda()
    else:
        print("Please specify the method!")
        # Raise explicitly: `assert False` is stripped under `python -O`,
        # which would let execution continue with `model` undefined.
        raise AssertionError('Unsupported method: %r' % (params.method,))
    # Update model
    checkpoint_dir = '%s/checkpoints/%s/best_model.tar' % (params.save_dir,
                                                           params.name)
    state = torch.load(checkpoint_dir)['state']
    if 'FWT' in params.name:
        # Only copy the checkpoint entries that exist in this model; any
        # extra keys in the checkpoint are ignored.
        model_params = model.state_dict()
        pretrained_dict = {k: v for k, v in state.items() if k in model_params}
        model_params.update(pretrained_dict)
        model.load_state_dict(model_params)
    else:
        model.load_state_dict(state)

    # For TPN model, we compute Batch Norm statistics from the test-time support set, not the exponential moving averages.
    if params.method != 'TPN':
        model.eval()
    for ti, (x, _) in enumerate(novel_loader):  # x:(5, 20, 3, 224, 224)
        x = x.cuda()
        # Everything after the first n_support samples of a class is a query.
        n_query = x.size(1) - n_support
        model.n_query = n_query
        # Ground-truth episode labels: n_query consecutive entries per class.
        yq = np.repeat(range(n_way), n_query)
        with torch.no_grad():
            scores = model.set_forward(x)  # (80, 5)
            _, topk_labels = scores.data.topk(1, 1, True, True)
            topk_ind = topk_labels.cpu().numpy()  # (80, 1)
            top1_correct = np.sum(topk_ind[:, 0] == yq)
            acc = top1_correct * 100. / (n_way * n_query)
            acc_all.append(acc)
        print('Task %d : %4.2f%%' % (ti, acc))

    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('Test Acc = %4.2f +- %4.2f%%' %
          (acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
Exemplo n.º 2
0
                            n_support=params.n_shot).cuda()
    elif params.method == 'RelationNet':
        model = RelationNet(model_dict[params.model],
                            n_way=params.train_n_way,
                            n_support=params.n_shot).cuda()
    elif params.method == 'RelationNetLRP':
        model = RelationNetLRP(model_dict[params.model],
                               n_way=params.train_n_way,
                               n_support=params.n_shot).cuda()
    elif params.method == 'ProtoNet':
        model = ProtoNet(model_dict[params.model],
                         n_way=params.train_n_way,
                         n_support=params.n_shot).cuda()
    elif params.method == 'GNN':
        model = GnnNet(model_dict[params.model],
                       n_way=params.train_n_way,
                       n_support=params.n_shot).cuda()
    elif params.method == 'GNNLRP':
        model = GnnNetLRP(model_dict[params.model],
                          n_way=params.train_n_way,
                          n_support=params.n_shot).cuda()
    elif params.method == 'TPN':
        model = TPN(model_dict[params.model],
                    n_way=params.train_n_way,
                    n_support=params.n_shot).cuda()
    else:
        print("Please specify the method!")
        assert (False)

    # load model
    start_epoch = params.start_epoch
Exemplo n.º 3
0
                                                   aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model],
                             tf_path=params.tf_dir,
                             **train_few_shot_params)
        elif params.method == 'gnnnet':
            model = GnnNet(model_dict[params.model],
                           tf_path=params.tf_dir,
                           **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                tf_path=params.tf_dir,
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            else:
                feature_model = model_dict[params.model]
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
            model = RelationNet(feature_model,
                                loss_type=loss_type,
Exemplo n.º 4
0
            feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
            model = RelationNet(feature_model,
                                loss_type=loss_type,
                                **train_few_shot_params)
        elif params.method == 'maml':
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'metaoptnet':
            model = MetaOptNet(model_dict[params.model],
                               **train_few_shot_params)
        elif params.method == 'gnnnet':
            if params.n_shot != 50:
                model = GnnNet(model_dict[params.model],
                               **train_few_shot_params)
            else:
                model = gnnnet_copy.GnnNet(model_dict[params.model],
                                           **train_few_shot_params)

        elif params.method == 'gnnnet_maml':
            gnnnet.GnnNet.maml = True
            gnn.Gconv.maml = True
            gnn.Wcompute.maml = True
            model = gnnnet.GnnNet(model_dict[params.model],
                                  **train_few_shot_params)
            print(model.maml)
        elif params.method == 'gnnnet_neg_margin':
            model = gnnnet_neg_margin.GnnNet(model_dict[params.model],
                                             **train_few_shot_params)
        elif params.method == 'gnnnet_normalized':
Exemplo n.º 5
0
    np.random.seed(10)
    params = parse_args('train')

    ##################################################################
    image_size = 224
    iter_num = 600
    pretrained_dataset = "miniImageNet"
    ds = False

    n_query = max(
        1, int(16 * params.test_n_way / params.train_n_way)
    )  #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.method in ["gnnnet", "gnnnet_maml"]:
        model = GnnNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'relationnet':
        feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)

    elif params.method in ["dampnet_full_class"]:
        model = dampnet_full_class.DampNet(model_dict[params.model],
                                           **few_shot_params)
    elif params.method == "baseline":
        checkpoint_dir_b = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, pretrained_dataset, params.model, "baseline")
Exemplo n.º 6
0
def explain_gnnnet():
    """Write LRP heatmaps for the query images of one GnnNet test episode.

    Parses the test options, overrides method/dataset/paths with hard-coded
    values, builds the model and loads its checkpoint. It then replays the
    GNN forward pass layer by layer on the first episode of the data loader
    so that every intermediate tensor is available for relevance propagation.
    For each class ``k`` the relevance is propagated back through the GNN
    layers, the fc encoder and the image encoder, and one heatmap per query
    image is saved under ``<save_dir>/explanations/<name>/episode<idx>/<image>/``
    next to a copy of the query image itself.

    Only the first episode is processed (the episode loop ends with
    ``break``). The function returns nothing.

    Raises:
        ValueError: If ``params.method`` is not a supported method name.
    """
    params = options.parse_args('test')
    params.method = 'gnnnet'
    params.dataset = 'miniImagenet'  # name relationnet --testset miniImagenet
    params.name = 'gnn'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    n_query = 1
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)

    # model
    print('  build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone.model_dict[params.model],
                            **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = backbone.model_dict[params.model]
        loss_type = 'LRP'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
    else:
        raise ValueError('Unknown method')

    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
            print('loaded pretrained model')
        except RuntimeError:
            # Key/shape mismatch: retry, loading only what matches.
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            # Checkpoint has no 'state' entry; fall back to 'model_state'.
            for k in tmp['model_state']:  ##### revise latter
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)

    model = model.cuda()
    model.eval()
    model.n_query = n_query
    lrp_preset = lrp_presets.SequentialPresetA()
    feature_model = model.feature
    fc_encoder = model.fc
    gnn_net = model.gnn
    lrp_wrapper.add_lrp(fc_encoder, lrp_preset)

    with open(
            '/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json',
            'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations',
                                        params.name)
    if not os.path.isdir(explanation_save_dir):
        os.makedirs(explanation_save_dir)
    # NOTE: the shape comments below are the original author's observations
    # (they correspond to a 5-way, 5-shot, 1-query episode).
    for batch_idx, (x, y, p) in enumerate(data_loader):
        print(p)
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        x = x.cuda()
        # One-hot labels for the support nodes ...
        support_label = torch.from_numpy(
            np.repeat(range(model.n_way),
                      model.n_support)).unsqueeze(1)  #torch.Size([25, 1])
        support_label = torch.zeros(model.n_way * model.n_support,
                                    model.n_way).scatter(1, support_label,
                                                         1).view(
                                                             model.n_way,
                                                             model.n_support,
                                                             model.n_way)
        # ... plus an all-zero label row per class for the unknown query node.
        support_label = torch.cat(
            [support_label,
             torch.zeros(model.n_way, 1, model.n_way)], dim=1)
        support_label = support_label.view(1, -1, model.n_way)
        support_label = support_label.cuda()  #torch.Size([1, 30, 5])
        x = x.contiguous()
        x = x.view(-1, *x.size()[2:])  #torch.Size([30, 3, 224, 224])
        x_feature = feature_model(x)  #torch.Size([30, 512])
        x_fc_encoded = fc_encoder(x_feature)  #torch.Size([30, 128])
        z = x_fc_encoded.view(model.n_way, -1,
                              x_fc_encoded.size(1))  # (5,6,128)
        # gnn_feature is grouped into n_query groups. Each group contains the
        # support image features concatenated with one query image feature.
        gnn_feature = [
            torch.cat([
                z[:, :model.n_support],
                z[:, model.n_support + i:model.n_support + i + 1]
            ],
                      dim=1).view(1, -1, z.size(2))
            for i in range(model.n_query)
        ]  # model.n_query is the number of query images for each class
        # The features are concatenated with the one-hot label; for the
        # unknown image the one-hot label is all zero.
        gnn_nodes = torch.cat(
            [torch.cat([feat, support_label], dim=2) for feat in gnn_feature],
            dim=0)

        #  perform gnn_net step by step
        #  the first iteration
        print('x', gnn_nodes.shape)
        W_init = torch.eye(
            gnn_nodes.size(1), device=gnn_nodes.device
        ).unsqueeze(0).repeat(gnn_nodes.size(0), 1, 1).unsqueeze(
            3
        )  # (n_querry, n_way*(num_support + 1), n_way*(num_support + 1), 1)

        W1 = gnn_net._modules['layer_w{}'.format(0)](
            gnn_nodes, W_init
        )  # (n_querry, n_way*(num_support + 1), n_way*(num_support + 1), 2)
        x_new1 = F.leaky_relu(gnn_net._modules['layer_l{}'.format(0)](
            [W1, gnn_nodes])[1])  # (num_querry, num_support + 1, num_outputs)
        gnn_nodes_1 = torch.cat([gnn_nodes, x_new1],
                                2)  # (concat more features) [1, 30, 181]

        #  the second iteration
        W2 = gnn_net._modules['layer_w{}'.format(1)](
            gnn_nodes_1, W_init
        )  # (n_querry, n_way*(num_support + 1), n_way*(num_support + 1), 2)
        x_new2 = F.leaky_relu(gnn_net._modules['layer_l{}'.format(1)](
            [W2,
             gnn_nodes_1])[1])  # (num_querry, num_support + 1, num_outputs)
        gnn_nodes_2 = torch.cat([gnn_nodes_1, x_new2],
                                2)  # (concat more features) [1, 30, 229]

        Wl = gnn_net.w_comp_last(gnn_nodes_2, W_init)  # [1, 30, 30, 2]
        scores = gnn_net.layer_last(
            [Wl, gnn_nodes_2])[1]  # (num_querry, num_support + 1, num_way)
        print(scores.shape)

        scores_sf = torch.softmax(scores, dim=-1)

        gnn_logits = torch.log(LRPutil.LOGIT_BETA * scores_sf /
                               (1 - scores_sf))
        # Each graph holds n_support + 1 nodes per class (supports plus one
        # query), matching the view used for query_scores below.
        gnn_logits = gnn_logits.view(-1, model.n_way, model.n_support + 1,
                                     model.n_way)
        query_scores = scores.view(
            model.n_query, model.n_way, model.n_support + 1,
            model.n_way)[:, :,
                         -1].permute(1, 0,
                                     2).contiguous().view(-1, model.n_way)
        preds = query_scores.data.cpu().numpy().argmax(axis=-1)
        for k in range(model.n_way):
            # Keep only class k's logit on the query nodes (last node of each
            # class group); support-node logits are left untouched.
            mask = torch.zeros(model.n_way, device=gnn_logits.device)
            mask[k] = 1
            gnn_logits_cls = gnn_logits.clone()
            gnn_logits_cls[:, :, -1] = gnn_logits_cls[:, :, -1] * mask
            gnn_logits_cls = gnn_logits_cls.view(-1, model.n_way)
            # NOTE(review): the narrow() offsets below (128, 133, 181, 48) are
            # hard-coded; they appear to be the fc feature width, that width
            # plus the 5-dim one-hot label, and the width after the first GNN
            # layer's 48 outputs -- confirm before changing the architecture
            # or n_way.
            relevance_gnn_nodes_2 = explain_Gconv(gnn_logits_cls,
                                                  gnn_net.layer_last, Wl,
                                                  gnn_nodes_2)
            relevance_x_new2 = relevance_gnn_nodes_2.narrow(-1, 181, 48)
            relevance_gnn_nodes_1 = explain_Gconv(
                relevance_x_new2, gnn_net._modules['layer_l{}'.format(1)], W2,
                gnn_nodes_1)
            relevance_x_new1 = relevance_gnn_nodes_1.narrow(-1, 133, 48)
            relevance_gnn_nodes = explain_Gconv(
                relevance_x_new1, gnn_net._modules['layer_l{}'.format(0)], W1,
                gnn_nodes)
            # Sum the relevance that reached the encoded image features
            # through all three concatenation paths.
            relevance_gnn_features = relevance_gnn_nodes.narrow(-1, 0, 128)
            print(relevance_gnn_features.shape)
            relevance_gnn_features += relevance_gnn_nodes_1.narrow(-1, 0, 128)
            relevance_gnn_features += relevance_gnn_nodes_2.narrow(
                -1, 0, 128)  #[2, 30, 128]
            relevance_gnn_features = relevance_gnn_features.view(
                n_query, model.n_way, model.n_support + 1, 128)
            # Collect the relevance of the query node of every graph.
            for i in range(n_query):
                query_i = relevance_gnn_features[i][:, model.
                                                    n_support:model.n_support +
                                                    1]
                if i == 0:
                    relevance_z = query_i
                else:
                    relevance_z = torch.cat((relevance_z, query_i), 1)
            relevance_z = relevance_z.view(-1, 128)
            query_feature = x_feature.view(model.n_way, -1,
                                           512)[:, model.n_support:]
            query_feature = query_feature.contiguous()
            query_feature = query_feature.view(n_query * model.n_way, 512)
            relevance_query_features = fc_encoder.compute_lrp(
                query_feature, target=relevance_z)
            # explain the fc layer and the image encoder
            query_images = x.view(model.n_way, -1,
                                  *x.size()[1:])[:, model.n_support:]
            query_images = query_images.contiguous()

            query_images = query_images.view(-1, *x.size()[1:]).detach()
            # NOTE(review): add_lrp is re-applied to feature_model on every
            # class iteration; kept as in the original because whether it is
            # idempotent is not visible here.
            lrp_wrapper.add_lrp(feature_model, lrp_preset)
            relevance_query_images = feature_model.compute_lrp(
                query_images, target=relevance_query_features)
            print(relevance_query_images.shape)

            for j in range(n_query * model.n_way):
                predict_class = label_to_readableclass[preds[j]]
                true_class = query_gt_class[int(j % model.n_way)][int(
                    j // model.n_way)]
                explain_class = label_to_readableclass[k]
                img_name = query_img_path[int(j % model.n_way)][int(
                    j // model.n_way)].split('/')[-1]
                # splitext drops only the extension; str.strip('.jpg') would
                # strip any of the characters '.', 'j', 'p', 'g' from both
                # ends of the name.
                img_stem = os.path.splitext(img_name)[0]
                save_path = os.path.join(explanation_save_dir,
                                         'episode' + str(batch_idx), img_stem)
                if not os.path.isdir(save_path):
                    os.makedirs(save_path)
                original_path = os.path.join(
                    save_path, true_class + '_' + predict_class + img_name)
                # The query image itself is written once, not once per class.
                if not os.path.exists(original_path):
                    original_img = Image.fromarray(
                        np.uint8(
                            project(query_images[j].permute(
                                1, 2, 0).detach().cpu().numpy())))
                    original_img.save(original_path)

                img_relevance = relevance_query_images.narrow(0, j, 1)
                print(predict_class, true_class, explain_class)
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(
                        save_path,
                        true_class + '_' + explain_class + '_lrp_hm.jpg'))

        # Only the first episode is explained.
        break
Exemplo n.º 7
0
def explain_relationnet():
    """Write LRP heatmaps for the query images of RelationNet test episodes.

    Parses the test options, overrides method/dataset/paths with hard-coded
    values, builds the model and loads its checkpoint. LRP is attached to
    deep copies of the feature extractor and the relation module. For the
    first three episodes of the data loader, the relation scores are turned
    into logits and propagated back through the relation module and the
    feature extractor, and one heatmap per (query image, class) pair is saved
    under ``<save_dir>/explanations/<name>/episode<idx>/<image>/`` next to a
    copy of the query image itself.

    The function returns nothing.

    Raises:
        ValueError: If ``params.method`` is not a supported method name.
    """
    params = options.parse_args('test')
    params.method = 'relationnet'
    params.dataset = 'miniImagenet'  # name relationnet --testset miniImagenet
    params.name = 'relationnet'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    n_query = 1
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)

    # model
    print('  build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone.model_dict[params.model],
                            **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = backbone.model_dict[params.model]
        loss_type = 'LRPmse'
        model = RelationNet(feature_model,
                            loss_type=loss_type,
                            **few_shot_params)
    else:
        raise ValueError('Unknown method')

    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
        except RuntimeError:
            # Key/shape mismatch: retry, loading only what matches.
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            # Checkpoint has no 'state' entry; fall back to 'model_state'.
            for k in tmp['model_state']:  ##### revise latter
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)

    model = model.cuda()
    model.eval()
    model.n_query = n_query
    preset = lrp_presets.SequentialPresetA()

    # LRP is attached to copies so the loaded model itself stays unwrapped.
    feature_model = copy.deepcopy(model.feature)
    lrp_wrapper.add_lrp(feature_model, preset=preset)
    relation_model = copy.deepcopy(model.relation_module)
    lrp_wrapper.add_lrp(relation_model, preset=preset)
    with open(
            '/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json',
            'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations',
                                        params.name)
    if not os.path.isdir(explanation_save_dir):
        os.makedirs(explanation_save_dir)
    for i, (x, y, p) in enumerate(data_loader):
        # x is the images with shape (n_way, n_support + n_query, 3, img_size, img_size)
        # y is the global labels of the images with shape (n_way, n_support + n_query)
        # p is the image path as a list of tuples, length is n_query + n_support,
        # each tuple element is with length n_way
        if i >= 3:
            break
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        z_support, z_query = model.parse_feature(x, is_feature=False)
        z_support = z_support.contiguous()
        # Class prototype = mean of that class's support features.
        z_proto = z_support.view(model.n_way, model.n_support,
                                 *model.feat_dim).mean(1)
        z_query = z_query.contiguous().view(model.n_way * model.n_query,
                                            *model.feat_dim)
        # get relations with metric function: pair every query with every
        # prototype by concatenating them along the channel dimension
        z_proto_ext = z_proto.unsqueeze(0).repeat(model.n_query * model.n_way,
                                                  1, 1, 1, 1)
        z_query_ext = z_query.unsqueeze(0).repeat(model.n_way, 1, 1, 1, 1)

        z_query_ext = torch.transpose(z_query_ext, 0, 1)
        extend_final_feat_dim = model.feat_dim.copy()
        extend_final_feat_dim[0] *= 2
        relation_pairs = torch.cat((z_proto_ext, z_query_ext),
                                   2).view(-1, *extend_final_feat_dim)
        relations = relation_model(relation_pairs)
        relations = relations.view(-1, model.n_way)
        preds = relations.data.cpu().numpy().argmax(axis=1)
        relations_sf = torch.softmax(relations, dim=-1)
        relations_logits = torch.log(LRPutil.LOGIT_BETA * relations_sf /
                                     (1 - relations_sf))
        relations_logits = relations_logits.view(-1, 1)
        relevance_relations = relation_model.compute_lrp(
            relation_pairs, target=relations_logits)
        # The query features are the second half of the concatenated
        # channels (the prototype occupies the first feat_dim[0] channels).
        relevance_z_query = torch.narrow(relevance_relations, 1,
                                         model.feat_dim[0], model.feat_dim[0])
        relevance_z_query = relevance_z_query.view(
            model.n_query * model.n_way, model.n_way,
            *relevance_z_query.size()[1:])
        query_img = x.narrow(1, model.n_support,
                             model.n_query).view(model.n_way * model.n_query,
                                                 *x.size()[2:])
        for k in range(model.n_way):
            # Relevance of every query w.r.t. its relation to class k.
            relevance_querry_cls = torch.narrow(relevance_z_query, 1, k,
                                                1).squeeze(1)
            relevance_querry_img = feature_model.compute_lrp(
                query_img.cuda(), target=relevance_querry_cls)
            for j in range(model.n_query * model.n_way):
                predict_class = label_to_readableclass[preds[j]]
                true_class = query_gt_class[int(j % model.n_way)][int(
                    j // model.n_way)]
                explain_class = label_to_readableclass[k]
                img_name = query_img_path[int(j % model.n_way)][int(
                    j // model.n_way)].split('/')[-1]
                # splitext drops only the extension; str.strip('.jpg') would
                # strip any of the characters '.', 'j', 'p', 'g' from both
                # ends of the name.
                img_stem = os.path.splitext(img_name)[0]
                save_path = os.path.join(explanation_save_dir,
                                         'episode' + str(i), img_stem)
                if not os.path.isdir(save_path):
                    os.makedirs(save_path)
                original_path = os.path.join(
                    save_path, true_class + '_' + predict_class + img_name)
                # The query image itself is written once, not once per class.
                if not os.path.exists(original_path):
                    original_img = Image.fromarray(
                        np.uint8(
                            project(query_img[j].permute(1, 2,
                                                         0).cpu().numpy())))
                    original_img.save(original_path)

                img_relevance = relevance_querry_img.narrow(0, j, 1)
                print(predict_class, true_class, explain_class)
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(
                        save_path,
                        true_class + '_' + explain_class + '_lrp_hm.jpg'))
Exemplo n.º 8
0
def finetune(novel_loader, n_pseudo=75, n_way=5, n_support=5):
    """Fine-tune a freshly restored model on every novel task and report accuracy.

    For each episode yielded by ``novel_loader`` a new model is built
    according to ``params.method`` / ``params.model`` and initialised from
    ``<params.save_dir>/checkpoints/<params.name>/best_model.tar``.  The model
    is then optimised with Adam for ``params.finetune_epoch`` steps on pseudo
    episodes generated from the episode's support samples, after which the
    whole episode is scored and the top-1 accuracy is recorded.  Per-task
    accuracy and the final mean with a 1.96 * std / sqrt(N) interval are
    printed; nothing is returned.

    Args:
        novel_loader: Sized iterable yielding ``(x, _)`` pairs.  ``x`` is
            indexed as ``x[class, sample, ...]``; the first ``n_support``
            samples along dim 1 are used as the support set and the remaining
            ones as queries.
        n_pseudo: Pseudo-sample budget handed to ``PseudoSampleGenerator``;
            ``n_pseudo // n_way`` is used as the per-class query count while
            fine-tuning.
        n_way: Number of classes per episode.
        n_support: Number of support samples per class.

    Raises:
        ValueError: If ``params.method`` is not one of the supported methods.
    """
    iter_num = len(novel_loader)
    acc_all = []

    # Resolve the model class once, before any expensive work.  An explicit
    # exception is raised instead of ``assert False`` because assertions are
    # stripped under ``python -O``, which would let execution continue with
    # ``model`` unbound.  Only the selected class name is looked up.
    method = params.method
    if method == 'MatchingNet':
        net_cls = MatchingNet
    elif method == 'RelationNet':
        net_cls = RelationNet
    elif method == 'ProtoNet':
        net_cls = ProtoNet
    elif method == 'GNN':
        net_cls = GnnNet
    elif method == 'TPN':
        net_cls = TPN
    else:
        raise ValueError(
            "Please specify the method! Got %r; expected one of "
            "'MatchingNet', 'RelationNet', 'ProtoNet', 'GNN', 'TPN'." %
            (method, ))

    checkpoint_dir = '%s/checkpoints/%s/best_model.tar' % (params.save_dir,
                                                           params.name)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint_dir)['state']
    # 'FWT' checkpoints are loaded partially: only keys the model also has.
    partial_load = 'FWT' in params.name

    for ti, (x, _) in enumerate(novel_loader):  # x:(5, 20, 3, 224, 224)
        # A new model per task, so fine-tuning on one task never leaks into
        # the next one.
        model = net_cls(model_dict[params.model],
                        n_way=n_way,
                        n_support=n_support).cuda()
        if partial_load:
            model_params = model.state_dict()
            model_params.update(
                {k: v
                 for k, v in state.items() if k in model_params})
            model.load_state_dict(model_params)
        else:
            model.load_state_dict(state)

        x = x.cuda()
        # Finetune components initialization
        xs = x[:, :n_support].reshape(-1, *x.size()[2:])  # (25, 3, 224, 224)
        pseudo_generator = PseudoSampleGenerator(n_way, n_support, n_pseudo)
        loss_fun = nn.CrossEntropyLoss().cuda()
        opt = torch.optim.Adam(model.parameters())

        # Finetune process
        n_query = n_pseudo // n_way
        pseudo_set_y = torch.from_numpy(np.repeat(range(n_way),
                                                  n_query)).cuda()
        model.n_query = n_query
        model.train()
        for _epoch in range(params.finetune_epoch):
            opt.zero_grad()
            pseudo_set = pseudo_generator.generate(
                xs)  # (5, n_support+n_query, 3, 224, 224)
            scores = model.set_forward(pseudo_set)  # (5*n_query, 5)
            loss = loss_fun(scores, pseudo_set_y)
            loss.backward()
            opt.step()
            # Drop references so the cached GPU memory can be released below.
            del pseudo_set, scores, loss
        torch.cuda.empty_cache()

        # Inference process
        # NOTE(review): the model is still in train mode here (no eval()
        # call); kept as in the original -- confirm this is intentional.
        n_query = x.size(1) - n_support
        model.n_query = n_query
        yq = np.repeat(range(n_way), n_query)
        with torch.no_grad():
            scores = model.set_forward(x)  # (80, 5)
            _, topk_labels = scores.topk(1, 1, True, True)
            topk_ind = topk_labels.cpu().numpy()  # (80, 1)
            top1_correct = np.sum(topk_ind[:, 0] == yq)
            acc = top1_correct * 100. / (n_way * n_query)
            acc_all.append(acc)
        del scores, topk_labels
        torch.cuda.empty_cache()
        print('Task %d : %4.2f%%' % (ti, acc))

    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('Test Acc = %4.2f +- %4.2f%%' %
          (acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))