def loadmodel(hook_fn):
    if settings.MODEL_FILE is None:
        model = torchvision.models.__dict__[settings.MODEL](pretrained=True)
    else:
        checkpoint = torch.load(settings.MODEL_FILE)
        if isinstance(checkpoint, dict):  # covers both dict and OrderedDict state files
            if settings.MODEL == 'genotype':
                backbone.geneC = settings.CHANNELS      # params.channels
                backbone.geneLayers = settings.LAYERS   # params.layers
                backbone.geneName = settings.GENE_NAME  # params.gene_name
                model = ProtoNet(Genotype, settings.N_WAY, settings.N_SHOT)
                state_dict = checkpoint['state']
                model.load_state_dict(state_dict)
            else:
                model = torchvision.models.__dict__[settings.MODEL](
                    num_classes=settings.NUM_CLASSES)
                if settings.MODEL_PARALLEL:
                    # DataParallel prefixes every layer name with 'module.';
                    # strip it so the keys match the plain model.
                    state_dict = {
                        k.replace('module.', ''): v
                        for k, v in checkpoint['state_dict'].items()
                    }
                else:
                    state_dict = checkpoint
                model.load_state_dict(state_dict)
        else:
            # the checkpoint file holds the whole serialized model
            model = checkpoint
    for name in settings.FEATURE_NAMES:
        if settings.MODEL == 'genotype':
            model._modules.get('feature')._modules.get('cells')._modules.get(
                '%d' % (settings.LAYER - 1)).register_forward_hook(hook_fn)
        else:
            model._modules.get(name).register_forward_hook(hook_fn)
    if settings.GPU:
        model.cuda()
    model.eval()
    return model
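# Usage sketch (not part of the original file): a hook_fn that caches the
# hooked layer's activations. register_forward_hook calls the hook with
# (module, input, output) on every forward pass.
features_blobs = []

def hook_feature(module, inp, output):
    # detach so the cached activations do not keep the autograd graph alive
    features_blobs.append(output.detach().cpu())

# model = loadmodel(hook_feature)  # hypothetical call, assuming settings is configured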
def meta_test(novel_loader,
              n_query=15,
              pretrained_dataset='miniImageNet',
              freeze_backbone=False,
              n_pseudo=100,
              n_way=5,
              n_support=5):
    # n_way=5: five classes per task
    # n_support=5: 5 support images per class (25 support images in total)
    # n_query=15: 15 query images per class (75 query images in total)
    # n_pseudo=100: number of pseudo query images used for finetuning
    correct = 0
    count = 0
    iter_num = len(novel_loader)  # typically 600 tasks
    acc_all = []

    for ti, (x, y) in enumerate(novel_loader):  # one iteration per task
        # ------------------------------------------------------------------
        # load the model pretrained on miniImageNet
        if params.method == 'protonet':
            pretrained_model = ProtoNet(model_dict[params.model],
                                        n_way=n_way,
                                        n_support=n_support)
        elif 'mytpn' in params.method:
            pretrained_model = MyTPN(model_dict[params.model],
                                     n_way=n_way,
                                     n_support=n_support)

        checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, pretrained_dataset, params.model, params.method)
        if params.train_aug:
            checkpoint_dir += '_aug'
        checkpoint_dir += '_5way_5shot'

        params.save_iter = -1  # forced to -1, so the best checkpoint is always used
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        # e.g. logs/checkpoints/miniImageNet/ResNet10_protonet_aug_5way_5shot/best_model.tar
        print("load from %s" % modelfile)
        tmp = torch.load(modelfile)
        state = tmp['state']
        pretrained_model.load_state_dict(state)
        pretrained_model.cuda()

        # ------------------------------------------------------------------
        # split data into support set and query set
        n_query = x.size(1) - n_support  # 20 - 5 = 15
        x = x.cuda()  # torch.Size([5, 20, 3, 224, 224])
        x_var = Variable(x)

        support_size = n_way * n_support  # 25
        y_a_i = Variable(torch.from_numpy(np.repeat(
            range(n_way), n_support))).cuda()  # (25,)
        x_b_i = x_var[:, n_support:, :, :, :].contiguous().view(
            n_way * n_query, *x.size()[2:])  # query set (75, 3, 224, 224)
        x_a_i = x_var[:, :n_support, :, :, :].contiguous().view(
            n_way * n_support, *x.size()[2:])  # support set (25, 3, 224, 224)

        if not freeze_backbone:
            # --------------------------------------------------------------
            # finetuning components
            pseudo_q_generator = PseudoQeuryGenerator(n_way, n_support,
                                                      n_pseudo)
            delta_opt = torch.optim.Adam(
                filter(lambda p: p.requires_grad,
                       pretrained_model.parameters()))

            # --------------------------------------------------------------
            # finetuning loop
            finetune_epoch = 100
            fine_tune_n_query = n_pseudo // n_way  # 100 // 5 = 20
            pretrained_model.n_query = fine_tune_n_query
            pretrained_model.train()

            z_support = x_a_i.view(n_way, n_support,
                                   *x_a_i.size()[1:])  # (5, 5, 3, 224, 224)
            for epoch in range(finetune_epoch):
                delta_opt.zero_grad()  # clear feature-extractor gradients
                # generate pseudo query images from the support set
                pseudo_query_set, _ = pseudo_q_generator.generate(x_a_i)
                pseudo_query_set = pseudo_query_set.cuda().view(
                    n_way, fine_tune_n_query,
                    *x_a_i.size()[1:])  # (5, 20, 3, 224, 224)
                x = torch.cat((z_support, pseudo_query_set), dim=1)
                loss = pretrained_model.set_forward_loss(x)
                loss.backward()
                delta_opt.step()

        # ------------------------------------------------------------------
        # inference
        pretrained_model.eval()
        pretrained_model.n_query = n_query  # 15
        with torch.no_grad():
            scores = pretrained_model.set_forward(
                x_var.cuda())  # set_forward in protonet.py

        y_query = np.repeat(range(n_way),
                            n_query)  # [0,...,0, ..., 4,...,4], shape (75,)
        # topk(k=1, dim=1, largest=True, sorted=True): top-1 prediction per row
        topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
        topk_ind = topk_labels.cpu().numpy()
        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        correct_this, count_this = float(top1_correct), len(y_query)

        acc_all.append(correct_this / count_this * 100)
        print("Task %d : %4.2f%%  Now avg: %4.2f%%" %
              (ti, correct_this / count_this * 100, np.mean(acc_all)))

    # ----------------------------------------------------------------------
    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    # 95% confidence half-width: 1.96 * std / sqrt(n)
    print('%d Test Acc = %4.2f%% +- %4.2f%%' %
          (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
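# Hypothetical sketch of the pseudo-query idea used in the finetuning loop
# above; the real PseudoQeuryGenerator is defined elsewhere, so the class and
# method bodies below are assumptions. It must return n_pseudo images ordered
# way-major so that the .view(n_way, fine_tune_n_query, ...) above succeeds.
import torch

class PseudoQueryGeneratorSketch:
    def __init__(self, n_way, n_support, n_pseudo):
        self.n_way = n_way
        self.n_support = n_support
        self.per_way = n_pseudo // n_way

    def generate(self, support_images):
        # support_images: (n_way * n_support, C, H, W), grouped by way
        c, h, w = support_images.size()[1:]
        grouped = support_images.view(self.n_way, self.n_support, c, h, w)
        # cycle through each way's supports; a horizontal flip stands in for
        # whatever augmentation the real generator applies
        idx = torch.arange(self.per_way) % self.n_support
        pseudo = grouped[:, idx].flip(-1)  # (n_way, per_way, C, H, W)
        labels = torch.arange(self.n_way).repeat_interleave(self.per_way)
        return pseudo.reshape(-1, c, h, w), labels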
def calculate_dist(n_way, n_shot, task_num, episode, batch_size=64):
    model = ProtoNet(model_dict["ResNet18"], n_way, n_shot)
    device = "cuda:0"
    with torch.cuda.device(device):
        model = model.cuda()

    resume_file = "/home/takumi/research/CloserLookFewShot/instance_selection/feature_space/{0}.tar".format(episode)
    tmp = torch.load(resume_file)
    model.load_state_dict(tmp['state'])

    base_file = "/home/takumi/research/CloserLookFewShot/filelists/full_Imagenet_except_testclass.json"
    candidate_data_manager = SimpleDataManager(224, batch_size)
    candidate_data_loader = candidate_data_manager.get_data_loader(
        base_file, aug=False, shuffle=False)

    task_file = "/home/takumi/research/CloserLookFewShot/instance_selection/task/few_shot_task{0}.json".format(task_num)
    few_image_list = task_train_reader(task_file, n_shot, n_way)

    # average each class's support features to obtain its prototype
    few_image_feature_list = []
    for images in few_image_list:
        images = images.to(device)
        features2 = model.feature(images)
        features = torch.mean(features2, dim=0)
        few_image_feature_list.append(features)

    # stack the prototypes into a single (n_way, dim) tensor
    dim = few_image_feature_list[0].size()[0]
    features_ave = torch.zeros(0, dim).to(device)
    for x in few_image_feature_list:
        features_ave = torch.cat([features_ave, x.unsqueeze(0)], dim=0)

    with open(base_file) as f:
        base = json.load(f)

    # for every candidate image, record its distance to the nearest prototype
    image_dists = []
    image_id = 0
    for x, labels in tqdm.tqdm(candidate_data_loader):
        x = x.to(device)
        y = model.feature(x)
        dist = euclidean_dist(y, features_ave)
        dist_min, _ = torch.min(dist, dim=1)
        dist_min = dist_min.tolist()
        labels = labels.tolist()
        for i in range(len(dist_min)):
            image_dists.append(
                [dist_min[i], labels[i], base["image_names"][image_id]])
            image_id += 1
    image_dists.sort()

    # dump the candidates sorted by distance, in the base file's format
    task_image_dist = dict()
    task_image_dist["label_names"] = copy.deepcopy(base["label_names"])
    task_image_dist["image_names"] = []
    task_image_dist["image_labels"] = []
    task_image_dist["distance"] = []
    for i in range(len(image_dists)):
        task_image_dist["image_names"].append(image_dists[i][2])
        task_image_dist["image_labels"].append(image_dists[i][1])
        task_image_dist["distance"].append(image_dists[i][0])

    with open(
            "/home/takumi/research/CloserLookFewShot/instance_selection/task/task{0}_dataset_dist_{1}.json"
            .format(task_num, episode), "w") as f:
        json.dump(task_image_dist, f)
    print(len(task_image_dist["image_names"]),
          len(task_image_dist["image_labels"]),
          len(task_image_dist["distance"]))
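# euclidean_dist is not defined in this file; below is a minimal sketch
# consistent with how it is called above (the standard ProtoNet-style pairwise
# squared Euclidean distance), assuming x: (n, d) and y: (m, d).
import torch

def euclidean_dist(x, y):
    n, m, d = x.size(0), y.size(0), x.size(1)
    x = x.unsqueeze(1).expand(n, m, d)  # (n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)  # (n, m, d)
    return torch.pow(x - y, 2).sum(2)   # (n, m)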
    base_datamgr = SetDataManager(image_size,
                                  n_query=params.n_query,
                                  **train_few_shot_params,
                                  num_views=params.num_views)
    base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)

    test_few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot)
    val_datamgr = SetDataManager(image_size,
                                 n_query=params.n_query,
                                 **test_few_shot_params,
                                 num_views=params.num_views)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)

    backbone = model_dict[params.model]
    model = ProtoNet(backbone, params.num_views, **train_few_shot_params)
    model = model.cuda()
    # model = torch.nn.DataParallel(model).cuda()

    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    params.checkpoint_dir += '_%dway_%dshot_%dviews_lr%f' % (
        params.train_n_way, params.n_shot, params.num_views, params.lr)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
def explain_gnnnet():
    params = options.parse_args('test')
    feature_model = backbone.model_dict['ResNet10']
    params.method = 'gnnnet'
    params.dataset = 'miniImagenet'  # --name gnn --testset miniImagenet
    params.name = 'gnn'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    n_query = 1
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size, n_query=n_query, **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)

    # model
    print(' build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = backbone.model_dict[params.model]
        loss_type = 'LRP'
        model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
    else:
        raise ValueError('Unknown method')

    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
            print('loaded pretrained model')
        except RuntimeError:
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            for k in tmp['model_state']:  # revise later
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)
        except:
            raise

    model = model.cuda()
    model.eval()
    model.n_query = n_query

    lrp_preset = lrp_presets.SequentialPresetA()
    feature_model = model.feature
    fc_encoder = model.fc
    gnn_net = model.gnn
    lrp_wrapper.add_lrp(fc_encoder, lrp_preset)
    # (a commented-out accuracy check over data_loader verified the forward pass)

    with open('/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json', 'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations', params.name)
    if not os.path.isdir(explanation_save_dir):
        os.makedirs(explanation_save_dir)

    for batch_idx, (x, y, p) in enumerate(data_loader):
        print(p)
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        x = x.cuda()
        # one-hot support labels; the query slot gets an all-zero label
        support_label = torch.from_numpy(
            np.repeat(range(model.n_way), model.n_support)).unsqueeze(1)  # torch.Size([25, 1])
        support_label = torch.zeros(
            model.n_way * model.n_support, model.n_way).scatter(
                1, support_label, 1).view(model.n_way, model.n_support, model.n_way)
        support_label = torch.cat(
            [support_label, torch.zeros(model.n_way, 1, model.n_way)], dim=1)
        support_label = support_label.view(1, -1, model.n_way)
        support_label = support_label.cuda()  # torch.Size([1, 30, 5])

        x = x.contiguous()
        x = x.view(-1, *x.size()[2:])  # torch.Size([30, 3, 224, 224])
        x_feature = feature_model(x)  # torch.Size([30, 512])
        x_fc_encoded = fc_encoder(x_feature)  # torch.Size([30, 128])
        z = x_fc_encoded.view(model.n_way, -1, x_fc_encoded.size(1))  # (5, 6, 128)
        # group the features into n_query groups; each group holds the support
        # features concatenated with one query image's features
        # (model.n_query is the number of query images per class)
        gnn_feature = [
            torch.cat([
                z[:, :model.n_support],
                z[:, model.n_support + i:model.n_support + i + 1]
            ], dim=1).view(1, -1, z.size(2)) for i in range(model.n_query)
        ]
        # concatenate each node's feature with its one-hot label; the unknown
        # query image gets an all-zero label
        gnn_nodes = torch.cat(
            [torch.cat([z, support_label], dim=2) for z in gnn_feature], dim=0)

        # run gnn_net step by step
        print('x', gnn_nodes.shape)
        # first iteration
        W_init = torch.eye(gnn_nodes.size(1), device=gnn_nodes.device).unsqueeze(
            0).repeat(gnn_nodes.size(0), 1, 1).unsqueeze(3)
        # (n_query, n_way*(n_support+1), n_way*(n_support+1), 1)
        W1 = gnn_net._modules['layer_w{}'.format(0)](gnn_nodes, W_init)
        # (n_query, n_way*(n_support+1), n_way*(n_support+1), 2)
        x_new1 = F.leaky_relu(
            gnn_net._modules['layer_l{}'.format(0)]([W1, gnn_nodes])[1])
        # torch.Size([1, 30, 48])
        gnn_nodes_1 = torch.cat([gnn_nodes, x_new1], 2)  # torch.Size([1, 30, 181])

        # second iteration
        W2 = gnn_net._modules['layer_w{}'.format(1)](gnn_nodes_1, W_init)
        x_new2 = F.leaky_relu(
            gnn_net._modules['layer_l{}'.format(1)]([W2, gnn_nodes_1])[1])
        gnn_nodes_2 = torch.cat([gnn_nodes_1, x_new2], 2)  # torch.Size([1, 30, 229])

        Wl = gnn_net.w_comp_last(gnn_nodes_2, W_init)  # torch.Size([1, 30, 30, 2])
        scores = gnn_net.layer_last([Wl, gnn_nodes_2])[1]  # (n_query, n_way*(n_support+1), n_way)
        print(scores.shape)
        scores_sf = torch.softmax(scores, dim=-1)
        gnn_logits = torch.log(LRPutil.LOGIT_BETA * scores_sf / (1 - scores_sf))
        gnn_logits = gnn_logits.view(-1, model.n_way,
                                     model.n_support + n_query, model.n_way)
        query_scores = scores.view(
            model.n_query, model.n_way, model.n_support + 1,
            model.n_way)[:, :, -1].permute(1, 0, 2).contiguous().view(-1, model.n_way)
        preds = query_scores.data.cpu().numpy().argmax(axis=-1)

        for k in range(model.n_way):
            # keep only class k's logit on the query nodes, then propagate the
            # relevance back through the three graph-conv stages
            mask = torch.zeros(5).cuda()
            mask[k] = 1
            gnn_logits_cls = gnn_logits.clone()
            gnn_logits_cls[:, :, -1] = gnn_logits_cls[:, :, -1] * mask
            gnn_logits_cls = gnn_logits_cls.view(-1, model.n_way)
            relevance_gnn_nodes_2 = explain_Gconv(gnn_logits_cls,
                                                  gnn_net.layer_last, Wl,
                                                  gnn_nodes_2)
            # narrow offsets pick the feature slices of the concatenated nodes:
            # 181:229 is x_new2 inside gnn_nodes_2, 133:181 is x_new1 inside gnn_nodes_1
            relevance_x_new2 = relevance_gnn_nodes_2.narrow(-1, 181, 48)
            relevance_gnn_nodes_1 = explain_Gconv(
                relevance_x_new2, gnn_net._modules['layer_l{}'.format(1)], W2,
                gnn_nodes_1)
            relevance_x_new1 = relevance_gnn_nodes_1.narrow(-1, 133, 48)
            relevance_gnn_nodes = explain_Gconv(
                relevance_x_new1, gnn_net._modules['layer_l{}'.format(0)], W1,
                gnn_nodes)
            relevance_gnn_features = relevance_gnn_nodes.narrow(-1, 0, 128)
            print(relevance_gnn_features.shape)
            relevance_gnn_features += relevance_gnn_nodes_1.narrow(-1, 0, 128)
            relevance_gnn_features += relevance_gnn_nodes_2.narrow(-1, 0, 128)  # [2, 30, 128]
            relevance_gnn_features = relevance_gnn_features.view(
                n_query, model.n_way, model.n_support + 1, 128)

            # collect the relevance of the query node from each group
            for i in range(n_query):
                query_i = relevance_gnn_features[i][:, model.n_support:model.n_support + 1]
                if i == 0:
                    relevance_z = query_i
                else:
                    relevance_z = torch.cat((relevance_z, query_i), 1)
            relevance_z = relevance_z.view(-1, 128)

            query_feature = x_feature.view(model.n_way, -1, 512)[:, model.n_support:]
            query_feature = query_feature.contiguous()
            query_feature = query_feature.view(n_query * model.n_way, 512)
            relevance_query_features = fc_encoder.compute_lrp(
                query_feature, target=relevance_z)

            # explain the fc layer and the image encoder
            query_images = x.view(model.n_way, -1, *x.size()[1:])[:, model.n_support:]
            query_images = query_images.contiguous()
            query_images = query_images.view(-1, *x.size()[1:]).detach()
            lrp_wrapper.add_lrp(feature_model, lrp_preset)
            relevance_query_images = feature_model.compute_lrp(
                query_images, target=relevance_query_features)
            print(relevance_query_images.shape)

            for j in range(n_query * model.n_way):
                predict_class = label_to_readableclass[preds[j]]
                true_class = query_gt_class[int(j % model.n_way)][int(j // model.n_way)]
                explain_class = label_to_readableclass[k]
                img_name = query_img_path[int(j % model.n_way)][int(
                    j // model.n_way)].split('/')[-1]
                save_path = os.path.join(explanation_save_dir,
                                         'episode' + str(batch_idx),
                                         img_name.strip('.jpg'))
                if not os.path.isdir(save_path):
                    os.makedirs(save_path)
                if not os.path.exists(
                        os.path.join(save_path,
                                     true_class + '_' + predict_class + img_name)):
                    original_img = Image.fromarray(
                        np.uint8(project(
                            query_images[j].permute(1, 2, 0).detach().cpu().numpy())))
                    original_img.save(
                        os.path.join(save_path,
                                     true_class + '_' + predict_class + img_name))
                img_relevance = relevance_query_images.narrow(0, j, 1)
                print(predict_class, true_class, explain_class)
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(save_path,
                                 true_class + '_' + explain_class + '_lrp_hm.jpg'))
        break  # only explain the first episode
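# Sketch of the logit trick used above to turn softmax probabilities into
# signed LRP targets; LRPutil.LOGIT_BETA's value is not shown in this file,
# so the beta below is an assumption. Probabilities above 1/(1+beta) map to
# positive relevance targets, everything else to negative ones.
import torch

def softmax_to_lrp_target(scores, beta=4.0):  # beta=4.0 is a placeholder
    p = torch.softmax(scores, dim=-1)
    return torch.log(beta * p / (1.0 - p))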
def explain_relationnet():
    params = options.parse_args('test')
    feature_model = backbone.model_dict['ResNet10']
    params.method = 'relationnet'
    params.dataset = 'miniImagenet'  # --name relationnet --testset miniImagenet
    params.name = 'relationnet'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    n_query = 1
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size, n_query=n_query, **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)
    acc_all = []
    iter_num = 1000

    # model
    print(' build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = backbone.model_dict[params.model]
        loss_type = 'LRPmse'
        model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
    else:
        raise ValueError('Unknown method')

    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
        except RuntimeError:
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            for k in tmp['model_state']:  # revise later
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)
        except:
            raise

    model = model.cuda()
    model.eval()
    model.n_query = n_query
    # (a commented-out accuracy check over data_loader verified the model loads correctly)
    acc = 0
    count = 0

    preset = lrp_presets.SequentialPresetA()
    feature_model = copy.deepcopy(model.feature)
    lrp_wrapper.add_lrp(feature_model, preset=preset)
    relation_model = copy.deepcopy(model.relation_module)
    lrp_wrapper.add_lrp(relation_model, preset=preset)

    with open('/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json', 'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations', params.name)
    if not os.path.isdir(explanation_save_dir):
        os.makedirs(explanation_save_dir)

    for i, (x, y, p) in enumerate(data_loader):
        '''x: images of shape (n_way, n_support + n_query, 3, img_size, img_size)
        y: global labels of shape (n_way, n_support + n_query)
        p: image paths as a list of n_support + n_query tuples, each of length n_way'''
        if i >= 3:
            break
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        z_support, z_query = model.parse_feature(x, is_feature=False)
        z_support = z_support.contiguous()
        z_proto = z_support.view(model.n_way, model.n_support, *model.feat_dim).mean(1)
        z_query = z_query.contiguous().view(model.n_way * model.n_query, *model.feat_dim)

        # build relation pairs with the metric function: every query feature is
        # paired with every class prototype along a doubled channel dimension
        z_proto_ext = z_proto.unsqueeze(0).repeat(model.n_query * model.n_way, 1, 1, 1, 1)
        z_query_ext = z_query.unsqueeze(0).repeat(model.n_way, 1, 1, 1, 1)
        z_query_ext = torch.transpose(z_query_ext, 0, 1)
        extend_final_feat_dim = model.feat_dim.copy()
        extend_final_feat_dim[0] *= 2
        relation_pairs = torch.cat((z_proto_ext, z_query_ext),
                                   2).view(-1, *extend_final_feat_dim)
        relations = relation_model(relation_pairs)
        scores = relations.view(-1, model.n_way)
        preds = scores.data.cpu().numpy().argmax(axis=1)

        relations = relations.view(-1, model.n_way)
        relations_sf = torch.softmax(relations, dim=-1)
        relations_logits = torch.log(LRPutil.LOGIT_BETA * relations_sf / (1 - relations_sf))
        relations_logits = relations_logits.view(-1, 1)
        relevance_relations = relation_model.compute_lrp(relation_pairs,
                                                         target=relations_logits)
        # the second half of each concatenated pair is the query feature
        relevance_z_query = torch.narrow(relevance_relations, 1,
                                         model.feat_dim[0], model.feat_dim[0])
        relevance_z_query = relevance_z_query.view(
            model.n_query * model.n_way, model.n_way, *relevance_z_query.size()[1:])

        query_img = x.narrow(1, model.n_support, model.n_query).view(
            model.n_way * model.n_query, *x.size()[2:])
        # query_img_copy = query_img.view(model.n_way, model.n_query, *x.size()[2:])

        for k in range(model.n_way):
            relevance_query_cls = torch.narrow(relevance_z_query, 1, k, 1).squeeze(1)
            relevance_query_img = feature_model.compute_lrp(
                query_img.cuda(), target=relevance_query_cls)
            for j in range(model.n_query * model.n_way):
                predict_class = label_to_readableclass[preds[j]]
                true_class = query_gt_class[int(j % model.n_way)][int(j // model.n_way)]
                explain_class = label_to_readableclass[k]
                img_name = query_img_path[int(j % model.n_way)][int(
                    j // model.n_way)].split('/')[-1]
                save_path = os.path.join(explanation_save_dir, 'episode' + str(i),
                                         img_name.strip('.jpg'))
                if not os.path.isdir(save_path):
                    os.makedirs(save_path)
                if not os.path.exists(
                        os.path.join(save_path,
                                     true_class + '_' + predict_class + img_name)):
                    original_img = Image.fromarray(
                        np.uint8(project(query_img[j].permute(1, 2, 0).cpu().numpy())))
                    original_img.save(
                        os.path.join(save_path,
                                     true_class + '_' + predict_class + img_name))
                img_relevance = relevance_query_img.narrow(0, j, 1)
                print(predict_class, true_class, explain_class)
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(save_path,
                                 true_class + '_' + explain_class + '_lrp_hm.jpg'))
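# Toy shape check (illustration only; the feat_dim values below are assumed,
# not taken from this file) for the relation-pair construction above.
import torch

n_way, n_query = 5, 1
feat_dim = [64, 19, 19]  # assumed Conv4NP-style feature map
z_proto = torch.randn(n_way, *feat_dim)
z_query = torch.randn(n_way * n_query, *feat_dim)
z_proto_ext = z_proto.unsqueeze(0).repeat(n_query * n_way, 1, 1, 1, 1)
z_query_ext = torch.transpose(z_query.unsqueeze(0).repeat(n_way, 1, 1, 1, 1), 0, 1)
pairs = torch.cat((z_proto_ext, z_query_ext), 2).view(-1, feat_dim[0] * 2, *feat_dim[1:])
assert pairs.shape == (n_way * n_query * n_way, 128, 19, 19)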
if params.method == 'manifold_mixup':
    model = wrn_mixup_model.wrn28_10(64)
elif params.method == 'S2M2_R':
    model = ProtoNet(model_dict[params.model], params.train_n_way, params.n_shot)
elif params.method == 'rotation':
    model = BaselineTrain(model_dict[params.model], 64, loss_type='dist')

if params.method == 'S2M2_R':
    if use_gpu:
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model,
                                          device_ids=range(torch.cuda.device_count()))
        model.cuda()

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        print("resume_file", resume_file)
        tmp = torch.load(resume_file)
        start_epoch = tmp['epoch'] + 1
        print("restored epoch is", tmp['epoch'])
        state = tmp['state']
        model.load_state_dict(state)
    else:
        resume_file = get_resume_file(params.checkpoint_dir)
        # resume_file = './checkpoints/cifar/ProtoNet_from_S2M2_R_SGD_lr0.0001/best_model.tar'
        print("resume_file", resume_file)
        tmp = torch.load(resume_file)
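# Hedged helper sketch (not in the original file): when a checkpoint saved
# from a DataParallel-wrapped model is loaded into a plain model, the
# 'module.' key prefix must be stripped, mirroring the convention handled in
# loadmodel above.
def strip_module_prefix(state_dict):
    return {k.replace('module.', '', 1): v for k, v in state_dict.items()}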