Exemplo n.º 1
0
def loadmodel(hook_fn):
    """Build the model described by ``settings`` and attach ``hook_fn`` to it.

    When ``settings.MODEL_FILE`` is ``None`` a pretrained torchvision model
    named ``settings.MODEL`` is created. Otherwise the file is loaded with
    ``torch.load`` and interpreted as follows:

    * a ``dict`` (including ``OrderedDict``) and ``settings.MODEL ==
      'genotype'``: a ``ProtoNet`` is built and ``checkpoint['state']`` is
      loaded into it;
    * any other ``dict``: a torchvision model with
      ``settings.NUM_CLASSES`` outputs is built; with
      ``settings.MODEL_PARALLEL`` the weights come from
      ``checkpoint['state_dict']`` with every ``'module.'`` substring removed
      from the keys, otherwise the checkpoint itself is used as state dict;
    * anything else: the loaded object is used as the model directly.

    Args:
        hook_fn: Forward hook registered on the selected feature layer(s).

    Returns:
        The model in eval mode, moved to CUDA when ``settings.GPU`` is truthy.

    Raises:
        AttributeError: If a name in ``settings.FEATURE_NAMES`` is not a
            direct submodule of the (non-genotype) model.
    """
    if settings.MODEL_FILE is None:
        model = torchvision.models.__dict__[settings.MODEL](pretrained=True)
    else:
        # Without a GPU, remap CUDA-saved tensors to the CPU; otherwise
        # torch.load tries to restore them on a device that is not used.
        # With settings.GPU enabled the default behaviour is kept.
        map_location = None if settings.GPU else 'cpu'
        checkpoint = torch.load(settings.MODEL_FILE,
                                map_location=map_location)
        if isinstance(checkpoint, dict):
            if settings.MODEL == 'genotype':
                # The backbone reads its architecture from module-level
                # attributes, so they must be set before building the model.
                backbone.geneC = settings.CHANNELS
                backbone.geneLayers = settings.LAYERS
                backbone.geneName = settings.GENE_NAME
                model = ProtoNet(Genotype, settings.N_WAY, settings.N_SHOT)
                state_dict = checkpoint['state']
            else:
                model = torchvision.models.__dict__[settings.MODEL](
                    num_classes=settings.NUM_CLASSES)
                if settings.MODEL_PARALLEL:
                    # The data parallel layer adds 'module' before each
                    # layer name; strip it so the keys match a plain model.
                    state_dict = {
                        key.replace('module.', ''): value
                        for key, value in checkpoint['state_dict'].items()
                    }
                else:
                    state_dict = checkpoint
            model.load_state_dict(state_dict)
        else:
            # The file holds a whole pickled model rather than its weights.
            model = checkpoint

    for name in settings.FEATURE_NAMES:
        if settings.MODEL == 'genotype':
            # NOTE(review): the hooked cell depends only on settings.LAYER,
            # so the hook is registered once per entry of FEATURE_NAMES on
            # the same module - confirm FEATURE_NAMES has a single entry.
            model._modules.get('feature')._modules.get('cells')._modules.get(
                '%d' % (settings.LAYER - 1)).register_forward_hook(hook_fn)
        else:
            layer = model._modules.get(name)
            if layer is None:
                raise AttributeError(
                    'model has no top-level submodule named %r' % (name, ))
            layer.register_forward_hook(hook_fn)

    if settings.GPU:
        model.cuda()
    model.eval()
    return model
Exemplo n.º 2
0
def meta_test(novel_loader,
              n_query=15,
              pretrained_dataset='miniImageNet',
              freeze_backbone=False,
              n_pseudo=100,
              n_way=5,
              n_support=5):
    """Evaluate a pretrained few-shot model on every task of ``novel_loader``.

    For each task a fresh model is built and restored from the best
    checkpoint (so fine-tuning on one task never carries over to the next),
    optionally fine-tuned for 100 steps on pseudo queries generated from the
    support images, and then scored on the task's real query images.
    Per-task accuracy and the final mean with a 95% confidence half-width
    are printed; nothing is returned.

    Side effect: ``params.save_iter`` is set to ``-1``, i.e. the best
    checkpoint is always used.

    Args:
        novel_loader: Sized iterable of ``(x, y)`` tasks. ``x`` is indexed
            as ``x[class, sample, ...]``; the first ``n_support`` samples of
            every class are the support set, the rest are queries. ``y`` is
            not used.
        n_query: Ignored. The number of queries per class is recomputed for
            each task as ``x.size(1) - n_support``.
        pretrained_dataset: Dataset name used in the checkpoint directory.
        freeze_backbone: When falsy, fine-tune the model on each task before
            inference.
        n_pseudo: Total number of pseudo query images per fine-tuning step;
            ``n_pseudo // n_way`` are used per class.
        n_way: Number of classes per task.
        n_support: Number of support images per class.

    Raises:
        ValueError: If ``params.method`` is neither ``'protonet'`` nor a
            name containing ``'mytpn'``.
    """
    if params.method == 'protonet':
        model_cls = ProtoNet
    elif 'mytpn' in params.method:
        model_cls = MyTPN
    else:
        # Previously this surfaced as an UnboundLocalError further down.
        raise ValueError('Unsupported few-shot method: %r' %
                         (params.method, ))

    # The checkpoint location does not depend on the task, so resolve and
    # read it once instead of once per task.
    # NOTE: the '_5way_5shot' suffix is fixed regardless of n_way/n_support.
    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, pretrained_dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    checkpoint_dir += '_5way_5shot'

    params.save_iter = -1  # always evaluate the best checkpoint
    modelfile = get_best_file(checkpoint_dir)
    print("load from %s" % (modelfile))
    state = torch.load(modelfile)['state']

    finetune_epoch = 100
    iter_num = len(novel_loader)
    acc_all = []

    for ti, (x, y) in enumerate(novel_loader):
        # Start every task from the same pretrained weights.
        pretrained_model = model_cls(model_dict[params.model],
                                     n_way=n_way,
                                     n_support=n_support)
        pretrained_model.load_state_dict(state)
        pretrained_model.cuda()

        n_query = x.size(1) - n_support
        episode = x.cuda()

        if not freeze_backbone:
            # Support images flattened to (n_way * n_support, ...).
            x_a_i = episode[:, :n_support].contiguous().view(
                n_way * n_support, *episode.size()[2:])

            pseudo_q_genrator = PseudoQeuryGenerator(n_way, n_support,
                                                     n_pseudo)
            delta_opt = torch.optim.Adam(
                filter(lambda p: p.requires_grad,
                       pretrained_model.parameters()))

            fine_tune_n_query = n_pseudo // n_way
            pretrained_model.n_query = fine_tune_n_query
            pretrained_model.train()

            z_support = x_a_i.view(n_way, n_support, *x_a_i.size()[1:])

            for _epoch in range(finetune_epoch):
                delta_opt.zero_grad()

                # Build a training episode: real support + pseudo queries.
                pseudo_query_set, _ = pseudo_q_genrator.generate(x_a_i)
                pseudo_query_set = pseudo_query_set.cuda().view(
                    n_way, fine_tune_n_query, *x_a_i.size()[1:])
                finetune_batch = torch.cat((z_support, pseudo_query_set),
                                           dim=1)

                loss = pretrained_model.set_forward_loss(finetune_batch)
                loss.backward()
                delta_opt.step()

        # Inference on the real query images of this task.
        pretrained_model.eval()
        pretrained_model.n_query = n_query
        with torch.no_grad():
            scores = pretrained_model.set_forward(episode)

        # Queries are ordered class by class: [0,...,0, 1,...,1, ...].
        y_query = np.repeat(range(n_way), n_query)
        # Top-1 prediction per query (k=1, dim=1, largest, sorted).
        _, topk_labels = scores.data.topk(1, 1, True, True)
        topk_ind = topk_labels.cpu().numpy()

        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        task_acc = float(top1_correct) / len(y_query) * 100

        acc_all.append(task_acc)
        print("Task %d : %4.2f%%  Now avg: %4.2f%%" %
              (ti, task_acc, np.mean(acc_all)))

    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('%d Test Acc = %4.2f%% +- %4.2f%%' %
          (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
def meta_test(novel_loader,
              n_query=15,
              pretrained_dataset='miniImageNet',
              freeze_backbone=False,
              n_pseudo=100,
              n_way=5,
              n_support=5):
    """Evaluate a pretrained ProtoNet on every task of ``novel_loader``.

    For each task a fresh ``ProtoNet`` is built and restored from the best
    checkpoint (so fine-tuning on one task never carries over to the next),
    optionally fine-tuned for 100 steps on pseudo queries generated from the
    support images, and then scored on the task's real query images.
    Per-task accuracy and the final mean with a 95% confidence half-width
    are printed; nothing is returned.

    Side effect: ``params.save_iter`` is set to ``-1``, i.e. the best
    checkpoint is always used.

    Args:
        novel_loader: Sized iterable of ``(x, y)`` tasks. ``x`` is indexed
            as ``x[class, sample, ...]``; the first ``n_support`` samples of
            every class are the support set, the rest are queries. ``y`` is
            not used.
        n_query: Ignored. The number of queries per class is recomputed for
            each task as ``x.size(1) - n_support``.
        pretrained_dataset: Dataset name used in the checkpoint directory.
        freeze_backbone: When falsy, fine-tune the model on each task before
            inference.
        n_pseudo: Total number of pseudo query images per fine-tuning step;
            ``n_pseudo // n_way`` are used per class.
        n_way: Number of classes per task.
        n_support: Number of support images per class.

    Raises:
        ValueError: If ``params.method`` is not ``'protonet'``.
    """
    if params.method != 'protonet':
        # Previously this surfaced as an UnboundLocalError further down.
        raise ValueError('Unsupported few-shot method: %r' %
                         (params.method, ))

    # The checkpoint location does not depend on the task, so resolve and
    # read it once instead of once per task.
    # NOTE: the '_5way_5shot' suffix is fixed regardless of n_way/n_support.
    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, pretrained_dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    checkpoint_dir += '_5way_5shot'

    params.save_iter = -1  # always evaluate the best checkpoint
    modelfile = get_best_file(checkpoint_dir)
    state = torch.load(modelfile)['state']

    finetune_epoch = 100
    iter_num = len(novel_loader)
    acc_all = []

    for ti, (x, y) in enumerate(novel_loader):
        # Start every task from the same pretrained weights.
        pretrained_model = ProtoNet(model_dict[params.model],
                                    n_way=n_way,
                                    n_support=n_support)
        pretrained_model.load_state_dict(state)
        pretrained_model.cuda()

        n_query = x.size(1) - n_support
        episode = x.cuda()

        if not freeze_backbone:
            # Support images flattened to (n_way * n_support, ...).
            x_a_i = episode[:, :n_support].contiguous().view(
                n_way * n_support, *episode.size()[2:])

            pseudo_q_genrator = PseudoQeuryGenerator(n_way, n_support,
                                                     n_pseudo)
            delta_opt = torch.optim.Adam(
                filter(lambda p: p.requires_grad,
                       pretrained_model.parameters()))

            fine_tune_n_query = n_pseudo // n_way
            pretrained_model.n_query = fine_tune_n_query
            pretrained_model.train()

            z_support = x_a_i.view(n_way, n_support, *x_a_i.size()[1:])

            for _epoch in range(finetune_epoch):
                delta_opt.zero_grad()

                # Build a training episode: real support + pseudo queries.
                pseudo_query_set, _ = pseudo_q_genrator.generate(x_a_i)
                pseudo_query_set = pseudo_query_set.cuda().view(
                    n_way, fine_tune_n_query, *x_a_i.size()[1:])
                finetune_batch = torch.cat((z_support, pseudo_query_set),
                                           dim=1)

                loss = pretrained_model.set_forward_loss(finetune_batch)
                loss.backward()
                delta_opt.step()

        # Inference on the real query images of this task.
        pretrained_model.eval()
        pretrained_model.n_query = n_query
        with torch.no_grad():
            scores = pretrained_model.set_forward(episode)

        # Queries are ordered class by class: [0,...,0, 1,...,1, ...].
        y_query = np.repeat(range(n_way), n_query)
        # Top-1 prediction per query (k=1, dim=1, largest, sorted).
        _, topk_labels = scores.data.topk(1, 1, True, True)
        topk_ind = topk_labels.cpu().numpy()

        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        task_acc = float(top1_correct) / len(y_query) * 100

        acc_all.append(task_acc)
        print("Task %d : %4.2f%%  Now avg: %4.2f%%" %
              (ti, task_acc, np.mean(acc_all)))

    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('%d Test Acc = %4.2f%% +- %4.2f%%' %
          (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
def calculate_dist(n_way, n_shot, task_num, episode, batch_size=64):
    """Rank every candidate image by its distance to a few-shot task.

    A ``ProtoNet`` with a ResNet18 backbone is restored from the checkpoint
    ``<episode>.tar``. The mean backbone feature of each batch returned by
    ``task_train_reader`` is used as a reference point, and every image of
    the candidate file list is assigned the smallest ``euclidean_dist``
    value to any reference point. The candidates, sorted by ascending
    distance, are written to ``task<task_num>_dataset_dist_<episode>.json``
    with the keys ``label_names``, ``image_names``, ``image_labels`` and
    ``distance``; the lengths of the three lists are printed.

    Args:
        n_way: Number of classes, forwarded to ``ProtoNet`` and
            ``task_train_reader``.
        n_shot: Number of shots, forwarded to ``ProtoNet`` and
            ``task_train_reader``.
        task_num: Index of the few-shot task file to read and of the output
            file to write.
        episode: Checkpoint identifier; selects the ``.tar`` file to load
            and is part of the output file name.
        batch_size: Batch size of the candidate data loader.
    """
    model = ProtoNet(model_dict["ResNet18"], n_way, n_shot)

    device = "cuda:0"
    with torch.cuda.device(device):
        model = model.cuda()

    resume_file = "/home/takumi/research/CloserLookFewShot/instance_selection/feature_space/{0}.tar".format(
        episode)
    tmp = torch.load(resume_file)
    model.load_state_dict(tmp['state'])

    base_file = "/home/takumi/research/CloserLookFewShot/filelists/full_Imagenet_except_testclass.json"

    candidate_data_manager = SimpleDataManager(224, batch_size)
    # shuffle=False is required: candidates are matched to
    # base["image_names"] purely by their position in the loader.
    candidate_data_loader = candidate_data_manager.get_data_loader(
        base_file, aug=False, shuffle=False)

    task_file = "/home/takumi/research/CloserLookFewShot/instance_selection/task/few_shot_task{0}.json".format(
        task_num)
    # presumably one batch of support images per class - confirm against
    # task_train_reader.
    few_image_list = task_train_reader(task_file, n_shot, n_way)

    with open(base_file) as f:
        base = json.load(f)
    image_names = base["image_names"]

    image_dists = []
    # Only forward passes are needed, so do not record autograd graphs.
    # NOTE(review): the model is left in training mode, as before; if it
    # contains batch-norm/dropout layers, model.eval() is probably wanted.
    with torch.no_grad():
        # One reference feature per support batch: the mean over its images.
        features_ave = torch.stack([
            torch.mean(model.feature(images.to(device)), dim=0)
            for images in few_image_list
        ], dim=0)

        image_id = 0
        for x, labels in tqdm.tqdm(candidate_data_loader):
            dist = euclidean_dist(model.feature(x.to(device)), features_ave)
            # Keep, for every candidate (row), its closest reference point.
            dist_min, _ = torch.min(dist, dim=1)

            for dist_value, label in zip(dist_min.tolist(), labels.tolist()):
                image_dists.append(
                    [dist_value, label, image_names[image_id]])
                image_id += 1

    # Ascending by distance; ties fall back to label, then image name.
    image_dists.sort()

    task_image_dist = {
        "label_names": copy.deepcopy(base["label_names"]),
        "image_names": [entry[2] for entry in image_dists],
        "image_labels": [entry[1] for entry in image_dists],
        "distance": [entry[0] for entry in image_dists],
    }

    with open(
            "/home/takumi/research/CloserLookFewShot/instance_selection/task/task{0}_dataset_dist_{1}.json"
            .format(task_num, episode), "w") as f:
        json.dump(task_image_dist, f)

    print(len(task_image_dist["image_names"]),
          len(task_image_dist["image_labels"]),
          len(task_image_dist["distance"]))
Exemplo n.º 5
0
                                  n_query=params.n_query,
                                  **train_few_shot_params,
                                  num_views=params.num_views)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)

    test_few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot)
    val_datamgr = SetDataManager(image_size,
                                 n_query=params.n_query,
                                 **test_few_shot_params,
                                 num_views=params.num_views)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    backbone = model_dict[params.model]
    model = ProtoNet(backbone, params.num_views, **train_few_shot_params)
    model = model.cuda()
    # model = torch.nn.DataParallel(model).cuda()

    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    params.checkpoint_dir += '_%dway_%dshot_%dviews_lr%f' % (
        params.train_n_way, params.n_shot, params.num_views, params.lr)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
Exemplo n.º 6
0
def explain_gnnnet():
    """Write LRP heatmaps for the query images of one GnnNet episode.

    The function parses the test options, overrides them for a GnnNet
    evaluated on miniImagenet, restores the checkpoint and processes the
    FIRST episode of the data loader only. The GNN is run layer by layer so
    that the intermediate adjacency tensors and node features are available
    for the backward relevance pass. For every class ``k`` the relevance of
    the query nodes is propagated back through the GNN layers
    (``explain_Gconv``), the fc encoder and the image encoder
    (``compute_lrp``), and one heatmap per query image is saved under
    ``<save_dir>/explanations/<name>/episode<idx>/<image stem>/`` together
    with a copy of the query image itself.

    NOTE(review): several expressions below (``n_support + n_query`` vs
    ``n_support + 1``, and the ``j % n_way`` / ``j // n_way`` lookups) are
    only mutually consistent for ``n_query == 1``, which is what this
    function hard-codes.
    """
    params = options.parse_args('test')
    params.method = 'gnnnet'
    params.dataset = 'miniImagenet'
    params.name = 'gnn'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    image_size = 84 if 'Conv' in params.model else 224
    split = params.split
    n_query = 1
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)

    # model (params.method is fixed to 'gnnnet' above)
    print('  build metric-based model')
    model = GnnNet(backbone.model_dict[params.model], **few_shot_params)

    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
            print('loaded pretrained model')
        except RuntimeError:
            # Key/shape mismatch: load whatever matches.
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            # Checkpoint stores the weights under 'model_state' instead.
            for key in tmp['model_state']:  ##### revise latter
                if 'running' in key:
                    tmp['model_state'][key] = tmp['model_state'][
                        key].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)

    model = model.cuda()
    model.eval()
    model.n_query = n_query

    lrp_preset = lrp_presets.SequentialPresetA()
    feature_model = model.feature
    fc_encoder = model.fc
    gnn_net = model.gnn
    lrp_wrapper.add_lrp(fc_encoder, lrp_preset)

    with open(
            '/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json',
            'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations',
                                        params.name)
    os.makedirs(explanation_save_dir, exist_ok=True)

    for batch_idx, (x, y, p) in enumerate(data_loader):
        print(p)
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        episode_dir = os.path.join(explanation_save_dir,
                                   'episode' + str(batch_idx))
        x = x.cuda()

        # One-hot labels for the support nodes, plus one all-zero row per
        # class for the (unlabelled) query node.
        support_label = torch.from_numpy(
            np.repeat(range(model.n_way), model.n_support)).unsqueeze(1)
        support_label = torch.zeros(model.n_way * model.n_support,
                                    model.n_way).scatter(
                                        1, support_label,
                                        1).view(model.n_way, model.n_support,
                                                model.n_way)
        support_label = torch.cat(
            [support_label,
             torch.zeros(model.n_way, 1, model.n_way)], dim=1)
        support_label = support_label.view(1, -1, model.n_way)
        support_label = support_label.cuda()

        x = x.contiguous()
        x = x.view(-1, *x.size()[2:])  # flatten (class, sample) into a batch
        x_feature = feature_model(x)
        x_fc_encoded = fc_encoder(x_feature)
        image_feat_dim = x_feature.size(-1)
        node_feat_dim = x_fc_encoded.size(1)
        z = x_fc_encoded.view(model.n_way, -1, node_feat_dim)

        # One group per query index: all support features of every class
        # followed by that class's i-th query feature.
        gnn_feature = [
            torch.cat([
                z[:, :model.n_support],
                z[:, model.n_support + i:model.n_support + i + 1]
            ],
                      dim=1).view(1, -1, z.size(2))
            for i in range(model.n_query)
        ]
        # Node input = encoded feature concatenated with the one-hot label.
        gnn_nodes = torch.cat(
            [torch.cat([feat, support_label], dim=2) for feat in gnn_feature],
            dim=0)

        # Run the GNN step by step so W1/W2/Wl and the intermediate node
        # features can be reused by the relevance pass below.
        print('x', gnn_nodes.shape)
        W_init = torch.eye(
            gnn_nodes.size(1),
            device=gnn_nodes.device).unsqueeze(0).repeat(
                gnn_nodes.size(0), 1, 1).unsqueeze(3)

        # first iteration
        W1 = gnn_net._modules['layer_w{}'.format(0)](gnn_nodes, W_init)
        x_new1 = F.leaky_relu(gnn_net._modules['layer_l{}'.format(0)](
            [W1, gnn_nodes])[1])
        gnn_nodes_1 = torch.cat([gnn_nodes, x_new1], 2)

        # second iteration
        W2 = gnn_net._modules['layer_w{}'.format(1)](gnn_nodes_1, W_init)
        x_new2 = F.leaky_relu(gnn_net._modules['layer_l{}'.format(1)](
            [W2, gnn_nodes_1])[1])
        gnn_nodes_2 = torch.cat([gnn_nodes_1, x_new2], 2)

        # output layer
        Wl = gnn_net.w_comp_last(gnn_nodes_2, W_init)
        scores = gnn_net.layer_last([Wl, gnn_nodes_2])[1]
        print(scores.shape)

        scores_sf = torch.softmax(scores, dim=-1)
        gnn_logits = torch.log(LRPutil.LOGIT_BETA * scores_sf /
                               (1 - scores_sf))
        gnn_logits = gnn_logits.view(-1, model.n_way,
                                     model.n_support + n_query, model.n_way)

        # Scores of the query node of every class, ordered class-major.
        query_scores = scores.view(
            model.n_query, model.n_way, model.n_support + 1,
            model.n_way)[:, :, -1].permute(1, 0, 2).contiguous().view(
                -1, model.n_way)
        preds = query_scores.data.cpu().numpy().argmax(axis=-1)

        for k in range(model.n_way):
            # Keep only class k in the logits of the query nodes; this is
            # the initial relevance for explaining class k.
            mask = torch.zeros(model.n_way, device=gnn_logits.device)
            mask[k] = 1
            gnn_logits_cls = gnn_logits.clone()
            gnn_logits_cls[:, :, -1] = gnn_logits_cls[:, :, -1] * mask
            gnn_logits_cls = gnn_logits_cls.view(-1, model.n_way)

            # Walk the GNN backwards. Each layer's input was built as
            # cat([previous nodes, new features]), so the relevance of the
            # new features sits right after the previous node width.
            relevance_gnn_nodes_2 = explain_Gconv(gnn_logits_cls,
                                                  gnn_net.layer_last, Wl,
                                                  gnn_nodes_2)
            relevance_x_new2 = relevance_gnn_nodes_2.narrow(
                -1, gnn_nodes_1.size(2), x_new2.size(2))
            relevance_gnn_nodes_1 = explain_Gconv(
                relevance_x_new2, gnn_net._modules['layer_l{}'.format(1)], W2,
                gnn_nodes_1)
            relevance_x_new1 = relevance_gnn_nodes_1.narrow(
                -1, gnn_nodes.size(2), x_new1.size(2))
            relevance_gnn_nodes = explain_Gconv(
                relevance_x_new1, gnn_net._modules['layer_l{}'.format(0)], W1,
                gnn_nodes)

            # The encoded image features are the first node_feat_dim columns
            # at every level; their relevance is summed over all three.
            relevance_gnn_features = relevance_gnn_nodes.narrow(
                -1, 0, node_feat_dim)
            print(relevance_gnn_features.shape)
            relevance_gnn_features += relevance_gnn_nodes_1.narrow(
                -1, 0, node_feat_dim)
            relevance_gnn_features += relevance_gnn_nodes_2.narrow(
                -1, 0, node_feat_dim)
            relevance_gnn_features = relevance_gnn_features.view(
                n_query, model.n_way, model.n_support + 1, node_feat_dim)

            # Relevance of the query node of every class, class-major.
            relevance_z = torch.cat([
                relevance_gnn_features[i][:, model.n_support:model.n_support +
                                          1] for i in range(n_query)
            ], 1)
            relevance_z = relevance_z.view(-1, node_feat_dim)

            query_feature = x_feature.view(
                model.n_way, -1, image_feat_dim)[:, model.n_support:]
            query_feature = query_feature.contiguous()
            query_feature = query_feature.view(n_query * model.n_way,
                                               image_feat_dim)
            relevance_query_features = fc_encoder.compute_lrp(
                query_feature, target=relevance_z)

            # explain the fc layer and the image encoder
            query_images = x.view(model.n_way, -1,
                                  *x.size()[1:])[:, model.n_support:]
            query_images = query_images.contiguous()
            query_images = query_images.view(-1, *x.size()[1:]).detach()
            # NOTE(review): add_lrp runs once per explained class, after the
            # plain forward pass above; presumably idempotent - confirm
            # before hoisting it out of the loop.
            lrp_wrapper.add_lrp(feature_model, lrp_preset)
            relevance_query_images = feature_model.compute_lrp(
                query_images, target=relevance_query_features)
            print(relevance_query_images.shape)

            explain_class = label_to_readableclass[k]
            for j in range(n_query * model.n_way):
                predict_class = label_to_readableclass[preds[j]]
                true_class = query_gt_class[j % model.n_way][j // model.n_way]
                img_name = query_img_path[j % model.n_way][
                    j // model.n_way].split('/')[-1]

                # str.strip('.jpg') removes any of the characters . j p g
                # from both ends (e.g. 'dog.jpg' -> 'do'); drop the file
                # extension properly instead.
                save_path = os.path.join(episode_dir,
                                         os.path.splitext(img_name)[0])
                os.makedirs(save_path, exist_ok=True)

                # Save the query image itself once per predicted label.
                original_path = os.path.join(
                    save_path, true_class + '_' + predict_class + img_name)
                if not os.path.exists(original_path):
                    original_img = Image.fromarray(
                        np.uint8(
                            project(query_images[j].permute(
                                1, 2, 0).detach().cpu().numpy())))
                    original_img.save(original_path)

                img_relevance = relevance_query_images.narrow(0, j, 1)
                print(predict_class, true_class, explain_class)
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(
                        save_path,
                        true_class + '_' + explain_class + '_lrp_hm.jpg'))

        # Only the first episode is explained.
        break
Exemplo n.º 7
0
def explain_relationnet(max_episodes=3):
    """Save LRP relevance heatmaps for a RelationNet model on miniImagenet.

    The function parses the 'test' options, overrides them with the fixed
    RelationNet/miniImagenet configuration below, builds the model, restores
    a checkpoint if one is found, attaches LRP to deep copies of the feature
    extractor and the relation module, and then, for the first
    ``max_episodes`` episodes of the data loader, writes to disk:

    * each query image (once per episode), and
    * one heatmap per (query image, explained class) pair.

    Files are written below
    ``<params.save_dir>/explanations/<params.name>/episode<i>/<image stem>/``.

    Args:
        max_episodes: number of episodes from the data loader to explain.
            Defaults to 3, the value that was previously hard-coded.

    Raises:
        ValueError: if ``params.method`` is not one of the known methods.
    """
    params = options.parse_args('test')
    params.method = 'relationnet'
    params.dataset = 'miniImagenet'  # name relationnet --testset miniImagenet
    params.name = 'relationnet'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    image_size = 84 if 'Conv' in params.model else 224
    n_query = 1
    loadfile = os.path.join(params.data_dir, params.testset,
                            params.split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)

    # model
    print('  build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone.model_dict[params.model],
                            **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = backbone.model_dict[params.model]
        model = RelationNet(feature_model,
                            loss_type='LRPmse',
                            **few_shot_params)
    else:
        raise ValueError('Unknown method')

    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
        except RuntimeError:
            # Key/shape mismatch: retry while tolerating missing/extra keys.
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            # Checkpoint without a 'state' entry: fall back to 'model_state'
            # and squeeze the 'running*' entries before loading.
            for k in tmp['model_state']:  ##### revise latter
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)
    else:
        # Previously this case was silent; explanations produced without a
        # restored checkpoint are unlikely to be what the caller wants.
        print('warning! no checkpoint file found in %s; '
              'model weights were not loaded' % checkpoint_dir)

    model = model.cuda()
    model.eval()
    model.n_query = n_query

    # LRP is attached to copies so the original modules stay untouched.
    preset = lrp_presets.SequentialPresetA()
    feature_model = copy.deepcopy(model.feature)
    lrp_wrapper.add_lrp(feature_model, preset=preset)
    relation_model = copy.deepcopy(model.relation_module)
    lrp_wrapper.add_lrp(relation_model, preset=preset)

    with open(
            '/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json',
            'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations',
                                        params.name)
    os.makedirs(explanation_save_dir, exist_ok=True)

    for i, (x, _, p) in enumerate(data_loader):
        # Per the original author's note:
        #   x: images, shape (n_way, n_support + n_query, 3, img_size, img_size)
        #   second item: global labels, shape (n_way, n_support + n_query)
        #   p: image paths, a list of n_query + n_support tuples of length n_way
        if i >= max_episodes:
            break
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        n_queries = model.n_query * model.n_way

        z_support, z_query = model.parse_feature(x, is_feature=False)
        z_proto = z_support.contiguous().view(model.n_way, model.n_support,
                                              *model.feat_dim).mean(1)
        z_query = z_query.contiguous().view(n_queries, *model.feat_dim)

        # Pair every query feature with every class prototype and concatenate
        # each pair along the channel axis.
        z_proto_ext = z_proto.unsqueeze(0).repeat(n_queries, 1, 1, 1, 1)
        z_query_ext = z_query.unsqueeze(0).repeat(model.n_way, 1, 1, 1, 1)
        z_query_ext = torch.transpose(z_query_ext, 0, 1)
        extend_final_feat_dim = model.feat_dim.copy()
        extend_final_feat_dim[0] *= 2
        relation_pairs = torch.cat((z_proto_ext, z_query_ext),
                                   2).view(-1, *extend_final_feat_dim)

        relations = relation_model(relation_pairs).view(-1, model.n_way)
        preds = relations.data.cpu().numpy().argmax(axis=1)

        # Turn the softmax probabilities into log-odds scaled by LOGIT_BETA
        # and use them as the LRP starting relevance, one value per pair.
        relations_sf = torch.softmax(relations, dim=-1)
        relations_logits = torch.log(LRPutil.LOGIT_BETA * relations_sf /
                                     (1 - relations_sf)).view(-1, 1)
        relevance_relations = relation_model.compute_lrp(
            relation_pairs, target=relations_logits)

        # The query features are the second half of the concatenated channels,
        # so keep only the relevance that landed on that half.
        relevance_z_query = torch.narrow(relevance_relations, 1,
                                         model.feat_dim[0], model.feat_dim[0])
        relevance_z_query = relevance_z_query.view(
            n_queries, model.n_way, *relevance_z_query.size()[1:])

        query_img = x.narrow(1, model.n_support,
                             model.n_query).view(n_queries, *x.size()[2:])

        # Per-query metadata does not depend on the explained class, so it is
        # resolved (and the original image saved) once per episode.
        episode_dir = os.path.join(explanation_save_dir, 'episode' + str(i))
        query_records = []
        for j in range(n_queries):
            # NOTE(review): index order follows the original code; it relies
            # on the layout returned by LRPutil.get_class_label -- confirm
            # before using n_query > 1.
            row, col = j % model.n_way, j // model.n_way
            predict_class = label_to_readableclass[preds[j]]
            true_class = query_gt_class[row][col]
            img_name = query_img_path[row][col].split('/')[-1]
            # splitext drops only the extension; str.strip('.jpg') would also
            # eat any leading/trailing '.', 'j', 'p' or 'g' characters.
            save_path = os.path.join(episode_dir,
                                     os.path.splitext(img_name)[0])
            os.makedirs(save_path, exist_ok=True)
            original_path = os.path.join(
                save_path, true_class + '_' + predict_class + img_name)
            if not os.path.exists(original_path):
                original_img = Image.fromarray(
                    np.uint8(
                        project(query_img[j].permute(1, 2,
                                                     0).cpu().numpy())))
                original_img.save(original_path)
            query_records.append((predict_class, true_class, save_path))

        for k in range(model.n_way):
            explain_class = label_to_readableclass[k]
            # Relevance of every query feature w.r.t. class k, propagated
            # back through the feature extractor to the input images.
            relevance_querry_cls = torch.narrow(relevance_z_query, 1, k,
                                                1).squeeze(1)
            relevance_querry_img = feature_model.compute_lrp(
                query_img.cuda(), target=relevance_querry_cls)
            for j, (predict_class, true_class,
                    save_path) in enumerate(query_records):
                print(predict_class, true_class, explain_class)
                img_relevance = relevance_querry_img.narrow(0, j, 1)
                # Channels-last numpy array for the heatmap utilities.
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(
                        save_path,
                        true_class + '_' + explain_class + '_lrp_hm.jpg'))
Exemplo n.º 8
0
    if params.method == 'manifold_mixup':
        model = wrn_mixup_model.wrn28_10(64)
    elif params.method == 'S2M2_R':
        model = ProtoNet(model_dict[params.model], params.train_n_way,
                         params.n_shot)
    elif params.method == 'rotation':
        model = BaselineTrain(model_dict[params.model], 64, loss_type='dist')

    if params.method == 'S2M2_R':

        if use_gpu:
            if torch.cuda.device_count() > 1:
                model = torch.nn.DataParallel(model,
                                              device_ids=range(
                                                  torch.cuda.device_count()))
            model.cuda()

        if params.resume:
            resume_file = get_resume_file(params.checkpoint_dir)
            print("resume_file", resume_file)
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            print("restored epoch is", tmp['epoch'])
            state = tmp['state']
            model.load_state_dict(state)

        else:
            resume_file = get_resume_file(params.checkpoint_dir)
            # resume_file = './checkpoints/cifar/ProtoNet_from_S2M2_R_SGD_lr0.0001/best_model.tar'
            print("resume_file", resume_file)
            tmp = torch.load(resume_file)