delimiter='\t') sys.exit(0) split = args.split if split == 'attributes' and args.method != 'protonet': raise NotImplementedError if args.method in ['maml', 'maml_approx' ]: #maml do not support testing with feature if 'Conv' in args.model: image_size = 84 else: image_size = 224 datamgr = SetDataManager(image_size, n_episode=iter_num, n_query=15, **few_shot_args, args=args) if args.dataset == 'cross': if split == 'base': loadfile = configs.data_dir['miniImagenet'] + 'all.json' else: loadfile = configs.data_dir['CUB'] + split + '.json' else: loadfile = configs.data_dir[args.dataset] + split + '.json' novel_loader = datamgr.get_data_loader(loadfile, aug=False) if args.adaptation: model.task_update_num = 100 #We perform adaptation on MAML simply by updating more times. model.eval()
if params.method in ['baseline++', 'S2M2_R', 'rotation']: if params.dct_status: base_datamgr = SimpleDataManager(image_size_dct, batch_size=params.batch_size) base_loader = base_datamgr.get_data_loader_dct( base_file, aug=params.train_aug, filter_size=params.filter_size) base_datamgr_test = SimpleDataManager( image_size_dct, batch_size=params.test_batch_size) base_loader_test = base_datamgr_test.get_data_loader_dct( base_file, aug=False, filter_size=params.filter_size) test_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot) val_datamgr = SetDataManager(image_size_dct, n_query=15, **test_few_shot_params) val_loader = val_datamgr.get_data_loader_dct( val_file, aug=False, filter_size=params.filter_size) else: base_datamgr = SimpleDataManager(image_size, batch_size=params.batch_size) base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug) base_datamgr_test = SimpleDataManager( image_size, batch_size=params.test_batch_size) base_loader_test = base_datamgr_test.get_data_loader(base_file, aug=False) test_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot) val_datamgr = SetDataManager(image_size,
def main_train(params):
    """Train a baseline or meta-learning few-shot classifier end-to-end.

    Resolves the train/val episode files, input image size and default stop
    epoch from ``params``, builds the data loaders and the model for
    ``params.method``, optionally resumes from a checkpoint or warm-starts
    the backbone from a pretrained baseline, then runs ``train`` and saves
    the results log.

    Args:
        params: parsed command-line options (presumably an argparse
            Namespace -- verify against the caller). Fields read include
            ``dataset``, ``model``, ``method``, ``train_aug``, ``n_shot``,
            ``train_n_way``, ``test_n_way``, ``stop_epoch``, ``start_epoch``,
            ``resume``, ``warmup`` and ``num_classes``. ``params.model``,
            ``params.stop_epoch`` and ``params.checkpoint_dir`` are mutated
            in place.

    Raises:
        ValueError: on an unknown ``params.method``, or when warmup is
            requested but no baseline checkpoint exists.
    """
    _set_seed(params)
    results_logger = ResultsLogger(params)

    # --- resolve train/val episode files; the 'cross*' setups train on one
    # domain and validate on another ---
    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'

    # --- input resolution: Conv backbones use 84x84 (28x28 for character
    # datasets), everything else 224x224 ---
    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        image_size = 224

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        # character datasets switch to the 'S' variant of Conv4
        params.model = 'Conv4S'

    optimization = 'Adam'

    # --- fill in a default stop epoch when not given explicitly (-1) ---
    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                params.stop_epoch = 200  # This is different as stated in the open-review paper.
                # However, using 400 epoch in baseline actually lead to over-fitting
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  # default
        else:  # meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 600  # default

    # --- data loaders and model, per method family ---
    if params.method in ['baseline', 'baseline++']:
        # plain mini-batch classification loaders (no episodes)
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            # baseline++ replaces the linear classifier with a cosine-distance head
            model = BaselineTrain(model_dict[params.model], params.num_classes,
                                  loss_type='dist')

    elif params.method in [
            'DKT', 'protonet', 'matchingnet', 'relationnet',
            'relationnet_softmax', 'maml', 'maml_approx'
    ]:
        # episodic (meta-learning) loaders
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

        train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size, n_query=n_query,
                                      **train_few_shot_params)  # n_eposide=100
        base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size, n_query=n_query, **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if (params.method == 'DKT'):
            model = DKT(model_dict[params.model], **train_few_shot_params)
            model.init_summary()
        elif params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model], **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            # RelationNet needs a non-flattened (spatial) feature map, hence
            # the 'NP' backbone variants / flatten=False
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
            model = RelationNet(feature_model, loss_type=loss_type,
                                **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            # enable the MAML fast-weight path in all backbone building blocks
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
            if params.dataset in [
                    'omniglot', 'cross_char'
            ]:  # maml use different parameter in omniglot
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    # --- checkpoint directory encodes dataset/model/method (+aug, +way/shot) ---
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # maml use multiple tasks in one update

    if params.resume:
        # continue from the latest checkpoint in checkpoint_dir, if any
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  # We also support warmup from pretrained baseline feature, but we never used in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        tmp = torch.load(warmup_resume_file)
        if tmp is not None:
            state = tmp['state']
            state_keys = list(state.keys())
            for i, key in enumerate(state_keys):
                if "feature." in key:
                    # an architecture model has attribute 'feature', load
                    # architecture feature to backbone by casting name from
                    # 'feature.trunk.xx' to 'trunk.xx'
                    newkey = key.replace("feature.", "")
                    state[newkey] = state.pop(key)
                else:
                    # drop non-backbone weights; only the feature extractor
                    # is warm-started
                    state.pop(key)
            model.feature.load_state_dict(state)
        else:
            raise ValueError('No warm_up file')

    model = train(base_loader, val_loader, model, optimization, start_epoch,
                  stop_epoch, params, results_logger)
    results_logger.save()
rotation=params.rotation, isAircraft=isAircraft, grey=params.grey, shuffle=False) if params.dataset_unlabel is not None: base_loader_u = base_datamgr_u.get_data_loader( base_file_unlabel, aug=params.train_aug) else: base_loader_u = base_datamgr_u.get_data_loader( base_file, aug=params.train_aug) train_few_shot_params = dict(n_way = params.train_n_way, n_support = params.n_shot, \ jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation) base_datamgr_l = SetDataManager(image_size, n_query=n_query, **train_few_shot_params, isAircraft=isAircraft, grey=params.grey) base_loader_l = base_datamgr_l.get_data_loader(base_file, aug=params.train_aug) test_few_shot_params = dict(n_way = params.test_n_way, n_support = params.n_shot, \ jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation) val_datamgr = SetDataManager(image_size, n_query=n_query, **test_few_shot_params, isAircraft=isAircraft, grey=params.grey) val_loader = val_datamgr.get_data_loader(val_file, aug=False) #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor
params.name) if not os.path.isdir(params.checkpoint_dir): os.makedirs(params.checkpoint_dir) # dataloader print('\n--- Prepare dataloader ---') print('\ttrain with seen domain {}'.format(params.dataset)) print('\tval with seen domain {}'.format(params.testset)) base_file = os.path.join(params.data_dir, params.dataset, 'base.json') val_file = os.path.join(params.data_dir, params.testset, 'val.json') # model image_size = 224 n_query = max(1, int(16 * params.test_n_way / params.train_n_way)) base_datamgr = SetDataManager(image_size, n_query=n_query, n_way=params.train_n_way, n_support=params.n_shot) base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug) val_datamgr = SetDataManager(image_size, n_query=n_query, n_way=params.test_n_way, n_support=params.n_shot) val_loader = val_datamgr.get_data_loader(val_file, aug=False) if params.method == 'MatchingNet': model = MatchingNet(model_dict[params.model], n_way=params.train_n_way, n_support=params.n_shot).cuda() elif params.method == 'RelationNet': model = RelationNet(model_dict[params.model], n_way=params.train_n_way,
num_layers=args.rnn_num_layers, dropout=args.rnn_dropout, ) l3_model = l3_model.cuda() embedding_model = embedding_model.cuda() lang_model = lang_model.cuda() # if test_n_way is smaller than train_n_way, reduce n_query to keep batch # size small n_query = max(1, int(16 * args.test_n_way / args.train_n_way)) train_few_shot_args = dict(n_way=args.train_n_way, n_support=args.n_shot) base_datamgr = SetDataManager("CUB", 84, n_query=n_query, **train_few_shot_args, args=args) print("Loading train data") base_loader = base_datamgr.get_data_loader( base_file, aug=True, lang_dir=constants.LANG_DIR, normalize=True, vocab=vocab, # Maximum training data restrictions only apply at train time max_class=args.max_class, max_img_per_class=args.max_img_per_class, max_lang_per_class=args.max_lang_per_class, )
base_file = configs.data_dir[params.dataset] + 'base.json' val_file = configs.data_dir[params.dataset] + 'val.json' params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % ( configs.save_dir, params.dataset, params.model, params.method) start_epoch = params.start_epoch stop_epoch = params.stop_epoch base_datamgr = SimpleDataManager(image_size, batch_size=params.batch_size) base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug) base_datamgr_test = SimpleDataManager(image_size, batch_size=params.test_batch_size) base_loader_test = base_datamgr_test.get_data_loader(base_file, aug=False) test_few_shot_params = dict(n_way=5, n_support=1) val_datamgr = SetDataManager(image_size, n_query=15, **test_few_shot_params) val_loader = val_datamgr.get_data_loader(val_file, aug=False) if params.method == 'manifold_mixup': print(params.num_classes) model = wrn_mixup_model.wrn28_10(params.num_classes) elif params.method == 'S2M2_R': model = wrn_mixup_model.wrn28_10(params.num_classes) elif params.method == 'rotation': model = BaselineTrain(model_dict[params.model], params.num_classes, loss_type='dist') if params.method == 'S2M2_R': if use_gpu:
def explain_gnnnet():
    """Produce LRP (layer-wise relevance propagation) explanations for GnnNet.

    Loads a trained GnnNet checkpoint, runs one test episode, replays the
    GNN forward pass layer by layer, back-propagates relevance from each
    class's logits through the GNN layers, the fc encoder and the image
    backbone, and saves each query image together with its per-class
    relevance heatmap under ``<save_dir>/explanations/<name>``.

    NOTE(review): paths, method, dataset and the 5-way mask below are
    hard-coded; only the first episode of the loader is processed (``break``
    at the end of the loop).
    """
    params = options.parse_args('test')
    feature_model = backbone.model_dict['ResNet10']
    # hard-coded test configuration, overriding the parsed options
    params.method = 'gnnnet'
    params.dataset = 'miniImagenet'
    # name relationnet --testset miniImagenet
    params.name = 'gnn'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    # Conv backbones take 84x84 inputs, everything else 224x224
    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    n_query = 1  # one query image per class in each episode
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size, n_query=n_query, **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)

    # --- build the metric-based model ---
    print(' build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = backbone.model_dict[params.model]
        loss_type = 'LRP'
        model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
    else:
        raise ValueError('Unknown method')

    # --- load the checkpoint (best epoch unless a specific one is given) ---
    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
            print('loaded pretrained model')
        except RuntimeError:
            # shape mismatches: retry without strict key/shape checking
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            # checkpoint stored under 'model_state' instead of 'state';
            # squeeze batch-norm running stats whose shapes do not match
            for k in tmp['model_state']:  ##### revise latter
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)
        except:
            raise
    model = model.cuda()
    model.eval()
    model.n_query = n_query

    # attach LRP hooks to the fc encoder now; hooks for the image backbone
    # are added later, just before its relevance is computed
    lrp_preset = lrp_presets.SequentialPresetA()
    feature_model = model.feature
    fc_encoder = model.fc
    gnn_net = model.gnn
    lrp_wrapper.add_lrp(fc_encoder, lrp_preset)
    # (a commented-out sanity loop that replayed the forward pass to verify
    # test accuracy was removed here for clarity)

    with open(
            '/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json',
            'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations', params.name)
    if not os.path.isdir(explanation_save_dir):
        os.makedirs(explanation_save_dir)

    for batch_idx, (x, y, p) in enumerate(data_loader):
        # x: episode images, y: global labels, p: image paths per slot
        print(p)
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        x = x.cuda()

        # build one-hot support labels; the query slot of every class gets
        # an all-zero label vector
        support_label = torch.from_numpy(
            np.repeat(range(model.n_way),
                      model.n_support)).unsqueeze(1)  #torch.Size([25, 1])
        support_label = torch.zeros(model.n_way * model.n_support,
                                    model.n_way).scatter(1, support_label, 1).view(
                                        model.n_way, model.n_support, model.n_way)
        support_label = torch.cat(
            [support_label, torch.zeros(model.n_way, 1, model.n_way)], dim=1)
        support_label = support_label.view(1, -1, model.n_way)
        support_label = support_label.cuda()  #torch.Size([1, 30, 5])

        # flatten the episode into a single batch and encode it
        x = x.contiguous()
        x = x.view(-1, *x.size()[2:])  #torch.Size([30, 3, 224, 224])
        x_feature = feature_model(x)  #torch.Size([30, 512])
        x_fc_encoded = fc_encoder(x_feature)  #torch.Size([30, 128])
        z = x_fc_encoded.view(model.n_way, -1, x_fc_encoded.size(1))  # (5,6,128)

        gnn_feature = [
            torch.cat([
                z[:, :model.n_support],
                z[:, model.n_support + i:model.n_support + i + 1]
            ],
                      dim=1).view(1, -1, z.size(2)) for i in range(model.n_query)
        ]  # model.n_query is the number of query images for each class
        # gnn_feature is grouped into n_query groups; each group contains the
        # support image features concatenated with one query image's features.
        gnn_nodes = torch.cat(
            [torch.cat([z, support_label], dim=2) for z in gnn_feature], dim=0
        )  # the features are concatenated with the one hot label; for the
        # unknown (query) image the one hot label is all zero

        # --- replay the GNN forward pass step by step so the intermediate
        # activations (W1/W2/Wl, gnn_nodes_1/2) are available for LRP ---
        print('x', gnn_nodes.shape)
        W_init = torch.eye(
            gnn_nodes.size(1), device=gnn_nodes.device
        ).unsqueeze(0).repeat(gnn_nodes.size(0), 1, 1).unsqueeze(
            3
        )  # (n_querry, n_way*(num_support + 1), n_way*(num_support + 1), 1)
        # first GNN iteration
        W1 = gnn_net._modules['layer_w{}'.format(0)](
            gnn_nodes, W_init
        )  # (n_querry, n_way*(num_support + 1), n_way*(num_support + 1), 2)
        x_new1 = F.leaky_relu(gnn_net._modules['layer_l{}'.format(0)](
            [W1, gnn_nodes])[1])  # (num_querry, num_support + 1, num_outputs)
        gnn_nodes_1 = torch.cat([gnn_nodes, x_new1],
                                2)  # (concat more features) torch.Size([1, 30, 181])
        # second GNN iteration
        W2 = gnn_net._modules['layer_w{}'.format(1)](
            gnn_nodes_1, W_init
        )  # (n_querry, n_way*(num_support + 1), n_way*(num_support + 1), 2)
        x_new2 = F.leaky_relu(gnn_net._modules['layer_l{}'.format(1)](
            [W2, gnn_nodes_1])[1])  # (num_querry, num_support + 1, num_outputs)
        gnn_nodes_2 = torch.cat([gnn_nodes_1, x_new2],
                                2)  # (concat more features) torch.Size([1, 30, 229])
        # final layer produces the per-node class scores
        Wl = gnn_net.w_comp_last(gnn_nodes_2, W_init)  #torch.Size([1, 30, 30, 2])
        scores = gnn_net.layer_last(
            [Wl, gnn_nodes_2])[1]  # (num_querry, num_support + 1, num_way)
        print(scores.shape)

        # convert softmax scores to logits for LRP
        scores_sf = torch.softmax(scores, dim=-1)
        gnn_logits = torch.log(LRPutil.LOGIT_BETA * scores_sf / (1 - scores_sf))
        gnn_logits = gnn_logits.view(-1, model.n_way, model.n_support + n_query,
                                     model.n_way)
        # predictions come from the last (query) slot of each class group
        query_scores = scores.view(
            model.n_query, model.n_way, model.n_support + 1,
            model.n_way)[:, :, -1].permute(1, 0,
                                           2).contiguous().view(-1, model.n_way)
        preds = query_scores.data.cpu().numpy().argmax(axis=-1)

        # --- back-propagate relevance once per class k ---
        for k in range(model.n_way):
            # keep only class k's logit on the query nodes
            # NOTE(review): mask size 5 hard-codes 5-way — confirm
            mask = torch.zeros(5).cuda()
            mask[k] = 1
            gnn_logits_cls = gnn_logits.clone()
            gnn_logits_cls[:, :, -1] = gnn_logits_cls[:, :, -1] * mask
            gnn_logits_cls = gnn_logits_cls.view(-1, model.n_way)
            # relevance through the three Gconv layers, peeling off the
            # newly-concatenated 48-d block each time; node layout is
            # [128 fc features | 5 one-hot | 48 (iter 1) | 48 (iter 2)],
            # hence the narrow offsets 181 and 133 below
            relevance_gnn_nodes_2 = explain_Gconv(gnn_logits_cls,
                                                  gnn_net.layer_last, Wl,
                                                  gnn_nodes_2)
            relevance_x_new2 = relevance_gnn_nodes_2.narrow(-1, 181, 48)
            relevance_gnn_nodes_1 = explain_Gconv(
                relevance_x_new2, gnn_net._modules['layer_l{}'.format(1)], W2,
                gnn_nodes_1)
            relevance_x_new1 = relevance_gnn_nodes_1.narrow(-1, 133, 48)
            relevance_gnn_nodes = explain_Gconv(
                relevance_x_new1, gnn_net._modules['layer_l{}'.format(0)], W1,
                gnn_nodes)
            # accumulate the relevance that each layer assigns to the
            # original 128-d fc features
            relevance_gnn_features = relevance_gnn_nodes.narrow(-1, 0, 128)
            print(relevance_gnn_features.shape)
            relevance_gnn_features += relevance_gnn_nodes_1.narrow(-1, 0, 128)
            relevance_gnn_features += relevance_gnn_nodes_2.narrow(
                -1, 0, 128)  #[2, 30, 128]
            relevance_gnn_features = relevance_gnn_features.view(
                n_query, model.n_way, model.n_support + 1, 128)

            # collect only the query slots' relevance
            for i in range(n_query):
                query_i = relevance_gnn_features[i][:, model.n_support:model.n_support + 1]
                if i == 0:
                    relevance_z = query_i
                else:
                    relevance_z = torch.cat((relevance_z, query_i), 1)
            relevance_z = relevance_z.view(-1, 128)

            # relevance through the fc encoder back to the 512-d backbone features
            query_feature = x_feature.view(model.n_way, -1,
                                           512)[:, model.n_support:]
            query_feature = query_feature.contiguous()
            query_feature = query_feature.view(n_query * model.n_way, 512)
            relevance_query_features = fc_encoder.compute_lrp(
                query_feature, target=relevance_z)

            # relevance through the image backbone down to pixel space
            query_images = x.view(model.n_way, -1,
                                  *x.size()[1:])[:, model.n_support:]
            query_images = query_images.contiguous()
            query_images = query_images.view(-1, *x.size()[1:]).detach()
            lrp_wrapper.add_lrp(feature_model, lrp_preset)
            relevance_query_images = feature_model.compute_lrp(
                query_images, target=relevance_query_features)
            print(relevance_query_images.shape)

            # --- save the original query images and their heatmaps ---
            for j in range(n_query * model.n_way):
                predict_class = label_to_readableclass[preds[j]]
                true_class = query_gt_class[int(j % model.n_way)][int(
                    j // model.n_way)]
                explain_class = label_to_readableclass[k]
                # NOTE(review): .strip('.jpg') strips characters, not the
                # suffix — works for typical names but is fragile
                img_name = query_img_path[int(j % model.n_way)][int(
                    j // model.n_way)].split('/')[-1]
                if not os.path.isdir(
                        os.path.join(explanation_save_dir,
                                     'episode' + str(batch_idx),
                                     img_name.strip('.jpg'))):
                    os.makedirs(
                        os.path.join(explanation_save_dir,
                                     'episode' + str(batch_idx),
                                     img_name.strip('.jpg')))
                save_path = os.path.join(explanation_save_dir,
                                         'episode' + str(batch_idx),
                                         img_name.strip('.jpg'))
                # save the raw query image once
                if not os.path.exists(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name)):
                    original_img = Image.fromarray(
                        np.uint8(
                            project(query_images[j].permute(
                                1, 2, 0).detach().cpu().numpy())))
                    original_img.save(
                        os.path.join(
                            save_path,
                            true_class + '_' + predict_class + img_name))
                img_relevance = relevance_query_images.narrow(0, j, 1)
                print(predict_class, true_class, explain_class)
                # render the relevance map as a heatmap image and save it
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(
                        save_path,
                        true_class + '_' + explain_class + '_lrp_hm.jpg'))
        break  # only the first episode is explained
top1_correct = np.sum(topk_ind[:, 0] == yq) acc = top1_correct * 100. / (n_way * n_query) acc_all.append(acc) print('Task %d : %4.2f%%' % (ti, acc)) acc_all = np.asarray(acc_all) acc_mean = np.mean(acc_all) acc_std = np.std(acc_all) print('Test Acc = %4.2f +- %4.2f%%' % (acc_mean, 1.96 * acc_std / np.sqrt(iter_num))) if __name__ == '__main__': np.random.seed(10) params = parse_args() image_size = 224 iter_num = 2000 n_query = 16 print('Loading target dataset!') novel_file = os.path.join(params.data_dir, params.dataset, 'novel.json') datamgr = SetDataManager(image_size, n_query=n_query, n_way=params.test_n_way, n_support=params.n_shot, n_eposide=iter_num) novel_loader = datamgr.get_data_loader(novel_file, aug=False) evaluate(novel_loader, n_way=params.test_n_way, n_support=params.n_shot)
pred = scores.data.cpu().numpy().argmax(axis=1) y = np.repeat(range(n_way), n_query) acc = np.mean(pred == y) * 100 return acc if __name__ == '__main__': params = parse_args('test') acc_all = [] iter_num = 600 few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) datamgr = SetDataManager(n_eposide=iter_num, n_query=15, **few_shot_params) novel_loader = datamgr.get_data_loader(root='./filelists/tabula_muris', mode='test') x_dim = novel_loader.dataset.get_dim() go_mask = novel_loader.dataset.go_mask if params.method == 'baseline': model = BaselineFinetune(backbone.FCNet(x_dim), **few_shot_params) elif params.method == 'baseline++': model = BaselineFinetune(backbone.FCNet(x_dim), loss_type='dist', **few_shot_params) elif params.method == 'protonet': model = ProtoNet(backbone.FCNet(x_dim), **few_shot_params) elif params.method == 'comet':
def explain_relationnet():
    """Generate LRP (Layer-wise Relevance Propagation) heat-map explanations
    for a trained RelationNet few-shot model on miniImagenet.

    For the first three episodes of the chosen split, the relation score of
    every (query image, class prototype) pair is propagated back through the
    relation module and then through the feature backbone; the per-class
    relevance maps are saved as JPEG heat-maps next to a copy of the original
    query image.

    Side effects: loads a checkpoint and a class-label JSON from hard-coded
    paths and writes images under <save_dir>/explanations/<name>/episodeN/.
    Requires CUDA.
    """
    params = options.parse_args('test')
    feature_model = backbone.model_dict['ResNet10']
    # Hard-coded experiment configuration (overrides whatever was parsed).
    params.method = 'relationnet'
    params.dataset = 'miniImagenet'
    params.name = 'relationnet'
    params.testset = 'miniImagenet'
    params.data_dir = '/home/sunjiamei/work/fewshotlearning/dataset/'
    params.save_dir = '/home/sunjiamei/work/fewshotlearning/CrossDomainFewShot-master/output'

    image_size = 84 if 'Conv' in params.model else 224
    split = params.split
    n_query = 1
    loadfile = os.path.join(params.data_dir, params.testset, split + '.json')
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    data_datamgr = SetDataManager(image_size, n_query=n_query, **few_shot_params)
    data_loader = data_datamgr.get_data_loader(loadfile, aug=False)

    # --- build the metric-based model -----------------------------------
    print(' build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(backbone.model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = backbone.model_dict[params.model]
        # 'LRPmse' selects the LRP-compatible relation head.
        loss_type = 'LRPmse'
        model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
    else:
        raise ValueError('Unknown method')

    # --- restore the checkpoint -----------------------------------------
    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
        except RuntimeError:
            # Shape/key mismatch: fall back to a non-strict load.
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            # Older checkpoints store under 'model_state' with extra dims on
            # batch-norm running stats; squeeze them before loading.
            for k in tmp['model_state']:
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)
    model = model.cuda()
    model.eval()
    model.n_query = n_query

    # --- attach LRP hooks to deep copies of both sub-networks ------------
    preset = lrp_presets.SequentialPresetA()
    feature_model = copy.deepcopy(model.feature)
    lrp_wrapper.add_lrp(feature_model, preset=preset)
    relation_model = copy.deepcopy(model.relation_module)
    lrp_wrapper.add_lrp(relation_model, preset=preset)
    with open(
            '/home/sunjiamei/work/fewshotlearning/dataset/miniImagenet/class_to_readablelabel.json',
            'r') as f:
        class_to_readable = json.load(f)
    explanation_save_dir = os.path.join(params.save_dir, 'explanations', params.name)
    if not os.path.isdir(explanation_save_dir):
        os.makedirs(explanation_save_dir)

    for i, (x, y, p) in enumerate(data_loader):
        # x: (n_way, n_support + n_query, 3, H, W) episode images
        # y: (n_way, n_support + n_query) global labels
        # p: image paths; list of n_support + n_query tuples, each of length n_way
        if i >= 3:  # only explain the first three episodes
            break
        label_to_readableclass, query_img_path, query_gt_class = LRPutil.get_class_label(
            p, class_to_readable, model.n_query)
        z_support, z_query = model.parse_feature(x, is_feature=False)
        z_support = z_support.contiguous()
        z_proto = z_support.view(model.n_way, model.n_support, *model.feat_dim).mean(1)
        z_query = z_query.contiguous().view(model.n_way * model.n_query, *model.feat_dim)

        # Pair every query feature with every class prototype.
        z_proto_ext = z_proto.unsqueeze(0).repeat(model.n_query * model.n_way, 1, 1, 1, 1)
        z_query_ext = z_query.unsqueeze(0).repeat(model.n_way, 1, 1, 1, 1)
        z_query_ext = torch.transpose(z_query_ext, 0, 1)
        extend_final_feat_dim = model.feat_dim.copy()
        extend_final_feat_dim[0] *= 2  # prototype + query stacked on channel axis
        relation_pairs = torch.cat((z_proto_ext, z_query_ext), 2).view(-1, *extend_final_feat_dim)
        relations = relation_model(relation_pairs)
        scores = relations.view(-1, model.n_way)
        preds = scores.data.cpu().numpy().argmax(axis=1)
        relations = relations.view(-1, model.n_way)
        relations_sf = torch.softmax(relations, dim=-1)
        # Convert relation probabilities to scaled log-odds logits for LRP.
        relations_logits = torch.log(LRPutil.LOGIT_BETA * relations_sf / (1 - relations_sf))
        relations_logits = relations_logits.view(-1, 1)
        relevance_relations = relation_model.compute_lrp(relation_pairs, target=relations_logits)
        # Keep only the query half of the relevance (second channel block).
        relevance_z_query = torch.narrow(relevance_relations, 1, model.feat_dim[0], model.feat_dim[0])
        relevance_z_query = relevance_z_query.view(
            model.n_query * model.n_way, model.n_way, *relevance_z_query.size()[1:])
        query_img = x.narrow(1, model.n_support, model.n_query).view(
            model.n_way * model.n_query, *x.size()[2:])

        for k in range(model.n_way):
            relevance_querry_cls = torch.narrow(relevance_z_query, 1, k, 1).squeeze(1)
            relevance_querry_img = feature_model.compute_lrp(
                query_img.cuda(), target=relevance_querry_cls)
            for j in range(model.n_query * model.n_way):
                predict_class = label_to_readableclass[preds[j]]
                true_class = query_gt_class[int(j % model.n_way)][int(j // model.n_way)]
                explain_class = label_to_readableclass[k]
                img_name = query_img_path[int(j % model.n_way)][int(j // model.n_way)].split('/')[-1]
                # BUG FIX: the original used img_name.strip('.jpg'), which
                # strips any of the characters '.', 'j', 'p', 'g' from BOTH
                # ends ('pigeon.jpg' -> 'igeon'); use the real file stem.
                img_stem = os.path.splitext(img_name)[0]
                save_path = os.path.join(explanation_save_dir, 'episode' + str(i), img_stem)
                if not os.path.isdir(save_path):
                    os.makedirs(save_path)
                if not os.path.exists(os.path.join(save_path, true_class + '_' + predict_class + img_name)):
                    original_img = Image.fromarray(
                        np.uint8(project(query_img[j].permute(1, 2, 0).cpu().numpy())))
                    original_img.save(
                        os.path.join(save_path, true_class + '_' + predict_class + img_name))
                img_relevance = relevance_querry_img.narrow(0, j, 1)
                print(predict_class, true_class, explain_class)
                hm = img_relevance.permute(0, 2, 3, 1).cpu().detach().numpy()
                hm = LRPutil.gamma(hm)
                hm = LRPutil.heatmap(hm)[0]
                hm = project(hm)
                hp_img = Image.fromarray(np.uint8(hm))
                hp_img.save(
                    os.path.join(save_path, true_class + '_' + explain_class + '_lrp_hm.jpg'))
split = 'novel' if params.save_iter != -1: split_str = split + "_" + str(params.save_iter) else: split_str = split iter_num = 600 few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) acc_all = [] model = load_weight_file_for_test(model, params) print_model_params(model, params) if params.method in ['maml', 'maml_approx']: datamgr = SetDataManager(params.image_size, n_eposide=iter_num, n_query=15, **few_shot_params, isAircraft=(params.dataset == 'aircrafts')) loadfile = os.path.join('filelists', params.test_dataset, 'novel.json') novel_loader = datamgr.get_data_loader(loadfile, aug=False) if params.adaptation: model.task_update_num = 100 #We perform adaptation on MAML simply by updating more times. model.eval() acc_mean, acc_std = model.test_loop(novel_loader, return_std=True) params = parse_args('mytest') loadfile = os.path.join('filelists', params.test_dataset, 'novel.json') else: if "recognition36" in params.test_dataset: loadfile = os.path.join('filelists', params.test_dataset,
if params.n_shot == 1: params.stop_epoch = 600 elif params.n_shot == 5: params.stop_epoch = 400 else: params.stop_epoch = 600 if params.method in ['tcmaml', 'tcmaml_approx']: n_query = max( 1, int(16 * params.test_n_way / params.train_n_way) ) # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot) train_datamgr = SetDataManager( image_size, n_query=n_query, **train_few_shot_params ) # default number of episodes (tasks) is 100 per epoch train_loader = train_datamgr.get_data_loader(train_file, aug=params.train_aug) test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) val_datamgr = SetDataManager(image_size, n_query=n_query, **test_few_shot_params) val_loader = val_datamgr.get_data_loader(val_file, aug=False) backbone.ConvBlock.maml = True backbone.SimpleBlock.maml = True backbone.ResNet.maml = True
# elif params.method == 'baseline++': # model = BaselineTrain( model_dict[params.model], params.num_classes, \ # loss_type = 'dist', jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation, tracking=params.tracking) if params.method in [ 'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax', 'maml', 'maml_approx' ]: n_query = max( 1, int(params.n_query * params.test_n_way / params.train_n_way) ) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small train_few_shot_params = dict(n_way = params.train_n_way, n_support = params.n_shot, \ jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation) base_datamgr = SetDataManager(image_size, n_query=n_query, **train_few_shot_params, isAircraft=isAircraft) base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug) base_loader1 = copy.deepcopy(base_loader) images = torch.empty(0, 3, 224, 224) ### [total_images, 3, 224, 224] for i, inputs in enumerate(base_loader1): print(i) x = inputs[0] ### [5,21,3,224,224] x = x.view(105, *x.size()[2:]) ### [105,3,224,224] # print(x.size()) images = torch.cat([images, x], dim=0) print(len(images)) dataset = JigsawDataset(images)
print(' train with single seen domain {}'.format(params.dataset)) base_file = os.path.join(params.data_dir, params.dataset, 'base.json') val_file = os.path.join(params.data_dir, params.dataset, 'val.json') # model print('\n--- build model ---') if 'Conv' in params.model: image_size = 84 else: image_size = 224 if params.method in ['maml_baseline'] : print(' training the {} with backbone {}'.format(params.method, params.model)) n_query = 15 train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot) base_datamgr = SetDataManager(image_size, n_query=n_query, n_eposide=100, **train_few_shot_params) base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug) test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) val_datamgr = SetDataManager(image_size, n_query=n_query, n_eposide=100, **test_few_shot_params) val_loader = val_datamgr.get_data_loader(val_file, aug=False) # base_datamgr = SimpleDataManager(image_size, batch_size=16) # base_loader = base_datamgr.get_data_loader(base_file , aug=params.train_aug ) # val_datamgr = SimpleDataManager(image_size, batch_size=64) # val_loader = val_datamgr.get_data_loader(val_file, aug=False) else: raise ValueError('Unknown method') model = MAMLBaseline(params, tf_path=params.tf_dir) model.cuda()
def main():
    """Script entry point: build data loaders, encoder, method model,
    optimizer and scheduler from CLI args, then either run evaluation
    (``args.test``) or the full training loop with per-epoch validation,
    best/last checkpointing and an ETA printout.
    """
    timer = Timer()
    args, writer = init()
    train_file = args.dataset_dir + 'train.json'
    val_file = args.dataset_dir + 'val.json'
    few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot, n_query=args.n_query)
    # Use a tiny number of episodes per epoch when debugging.
    n_episode = 10 if args.debug else 100

    # Baseline trains on plain mini-batches; meta methods train on episodes.
    if args.method_type is Method_type.baseline:
        train_datamgr = SimpleDataManager(train_file, args.dataset_dir, args.image_size,
                                          batch_size=64)
        train_loader = train_datamgr.get_data_loader(aug=True)
    else:
        train_datamgr = SetDataManager(train_file, args.dataset_dir, args.image_size,
                                       n_episode=n_episode, mode='train', **few_shot_params)
        train_loader = train_datamgr.get_data_loader(aug=True)
    val_datamgr = SetDataManager(val_file, args.dataset_dir, args.image_size,
                                 n_episode=n_episode, mode='val', **few_shot_params)
    val_loader = val_datamgr.get_data_loader(aug=False)

    if args.model_type is Model_type.ConvNet:
        # BUG FIX: this branch used to `pass`, leaving `encoder` unbound and
        # crashing later with a NameError; fail fast with a clear message.
        raise NotImplementedError('ConvNet encoder is not implemented')
    elif args.model_type is Model_type.ResNet12:
        from methods.backbone import ResNet12
        encoder = ResNet12()
    else:
        raise ValueError('unknown model type: {}'.format(args.model_type))

    if args.method_type is Method_type.baseline:
        from methods.baselinetrain import BaselineTrain
        model = BaselineTrain(encoder, args)
    elif args.method_type is Method_type.protonet:
        from methods.protonet import ProtoNet
        model = ProtoNet(encoder, args)
    else:
        raise ValueError('unknown method type: {}'.format(args.method_type))

    from torch.optim import SGD, lr_scheduler
    if args.method_type is Method_type.baseline:
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9,
                        weight_decay=args.weight_decay)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch,
                                                   eta_min=0, last_epoch=-1)
    else:
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9,
                        weight_decay=5e-4, nesterov=True)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)

    args.ngpu = torch.cuda.device_count()
    torch.backends.cudnn.benchmark = True
    model = model.cuda()
    # Episode labels are fixed (0..n_way-1, each repeated n_query times); build once.
    label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
    label = label.cuda()

    if args.test:
        test(model, label, args, few_shot_params)
        return

    if args.resume:
        resume_OK = resume_model(model, optimizer, args, scheduler)
    else:
        resume_OK = False
    # Warm-start from pretrained weights only when not resuming a run.
    if (not resume_OK) and (args.warmup is not None):
        load_pretrained_weights(model, args)

    if args.debug:
        args.max_epoch = args.start_epoch + 1

    for epoch in range(args.start_epoch, args.max_epoch):
        train_one_epoch(model, optimizer, args, train_loader, label, writer, epoch)
        scheduler.step()
        vl, va = val(model, args, val_loader, label)
        if writer is not None:
            writer.add_scalar('data/val_acc', float(va), epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))
        if va >= args.max_acc:
            args.max_acc = va
            print('saving the best model! acc={:.4f}'.format(va))
            save_model(model, optimizer, args, epoch, args.max_acc, 'max_acc', scheduler)
        # Always refresh the rolling last-epoch checkpoint for resuming.
        save_model(model, optimizer, args, epoch, args.max_acc, 'epoch-last', scheduler)
        if epoch != 0:
            print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
    if writer is not None:
        writer.close()
    test(model, label, args, few_shot_params)
assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class' if params.dataset == 'cross_char': assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class' if params.method == 'baseline': model = BaselineTrain(model_dict[params.model], params.num_classes) elif params.method == 'baseline++': model = BaselineTrain(model_dict[params.model], params.num_classes, loss_type='dist') elif params.method in [ 'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax', 'maml', 'maml_approx' ]: base_datamgr = SetDataManager(image_size, params.batchsize) base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug, ifshuffle=True) #val_datamgr = SetDataManager(image_size, params.batchsize) #val_loader = val_datamgr.get_data_loader( val_file, aug = False,ifshuffle = False) #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor if params.method == 'protonet': model = ProtoNet(model_dict[params.model], **train_few_shot_params) elif params.method == 'matchingnet': model = MatchingNet(model_dict[params.model], **train_few_shot_params) elif params.method in ['relationnet', 'relationnet_softmax']: if params.model == 'Conv4':
def get_logits_targets(params):
    """Collect raw model logits and ground-truth targets over test episodes.

    Builds the few-shot model selected by ``params.method``, restores its
    checkpoint, then runs ``iter_num`` (600) evaluation episodes. MAML-style
    methods (and DKT) are fed raw images from a data loader; all other
    methods are fed pre-extracted features from the saved ``.hdf5`` file.

    Returns:
        (logits, targets): two concatenated tensors stacking every episode's
        class scores and the matching labels. Requires CUDA.
    """
    iter_num = 600
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    # --- build the model for the requested method ------------------------
    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'DKT':
        model = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        # RelationNet needs non-flattened feature maps ("NP" variants).
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'),
                     **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char']:  # maml uses different parameters on omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')
    model = model.cuda()

    # --- locate and load the checkpoint ----------------------------------
    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (configs.save_dir, params.dataset,
                                                  params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
    if params.method not in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " + str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split

    if params.method in ['maml', 'maml_approx', 'DKT']:  # these do not support testing from features
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224
        datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=15, **few_shot_params)
        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'
        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # adaptation on MAML = simply more inner updates
        model.eval()
        logits_list = []
        targets_list = []
        for x, _ in novel_loader:
            logits = model.get_logits(x).detach()
            targets = torch.tensor(np.repeat(range(params.test_n_way), model.n_query)).cuda()
            logits_list.append(logits)
            targets_list.append(targets)
    else:
        # Feature-based evaluation: sample episodes from saved features.
        novel_file = os.path.join(checkpoint_dir.replace("checkpoints", "features"),
                                  split_str + ".hdf5")
        cl_data_file = feat_loader.init_loader(novel_file)
        logits_list = []
        targets_list = []
        n_query = 15
        n_way = few_shot_params['n_way']
        n_support = few_shot_params['n_support']
        # BUG FIX: random.sample requires a sequence; dict views are rejected
        # on Python >= 3.11, so materialise the class list once.
        class_list = list(cl_data_file.keys())
        for _ in range(iter_num):
            select_class = random.sample(class_list, n_way)
            z_all = []
            for cl in select_class:
                img_feat = cl_data_file[cl]
                perm_ids = np.random.permutation(len(img_feat)).tolist()
                # stack each batch (renamed idx: used to shadow the episode index)
                z_all.append([np.squeeze(img_feat[perm_ids[idx]])
                              for idx in range(n_support + n_query)])
            z_all = torch.from_numpy(np.array(z_all))
            model.n_query = n_query
            logits = model.set_forward(z_all, is_feature=True).detach()
            targets = torch.tensor(np.repeat(range(n_way), n_query)).cuda()
            logits_list.append(logits)
            targets_list.append(targets)
    return torch.cat(logits_list, 0), torch.cat(targets_list, 0)
vocab_srt = [v[0] for v in vocab_srt] with open(args.embeddings_file, "w") as fout: fout.write("\n".join(vocab_srt)) fout.write("\n") np.savetxt(args.embeddings_metadata, weights, fmt="%f", delimiter="\t") sys.exit(0) # Run the test loop for 600 iterations ITER_NUM = 600 N_QUERY = 15 test_datamgr = SetDataManager( "CUB", 84, n_query=N_QUERY, n_way=args.test_n_way, n_support=args.n_shot, n_episode=ITER_NUM, args=args, ) test_loader = test_datamgr.get_data_loader( os.path.join(constants.DATA_DIR, f"{args.split}.json"), aug=False, lang_dir=constants.LANG_DIR, normalize=False, vocab=vocab, ) normalizer = TransformLoader(84).get_normalize() model.eval()
def meta_train(self,
               config,
               method,
               descriptor_str,
               debug=True,
               use_test=False,
               require_pretrain=False,
               metric="acc"):
    """Meta-train `method` on the dataset described by `config`, validating
    each epoch and checkpointing, then run the matching meta-test.

    Args:
        config: dict with 'dataset', 'model', 'method', 'n_shot',
            'train_episodes', 'val_episodes', 'end_epoch' and optionally
            'weight_decay' and 'params' (extra kwargs for `method`).
        method: model class/callable; called as method(backbone, **kwargs).
        descriptor_str: tag used in result-file and checkpoint-dir names.
        debug: forwarded to the data loaders (presumably shrinks them —
            TODO confirm against SetDataManager).
        use_test: validate on novel.json instead of val.json.
        require_pretrain: pass the pretrained model into `method` as a
            constructor kwarg instead of attaching it afterwards.
        metric: metric name forwarded to model.test_loop.
    """
    config["meta_training"] = True
    # NOTE: params is self.params — every assignment below mutates shared
    # runner state, not a local copy.
    params = self.params
    params.save_freq = 10
    # Shrink n_query when testing with more ways than training, to keep the
    # episode batch size comparable.
    params.n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
    params.dataset = config["dataset"]
    params.model = config["model"]
    params.method = config["method"]
    params.n_shot = config["n_shot"]
    train_episodes = config["train_episodes"]
    val_episodes = config["val_episodes"]
    end_epoch = config["end_epoch"]
    if "weight_decay" in config:
        weight_decay = config["weight_decay"]
    else:
        weight_decay = 0
    result_dir = "results/meta/%s" % (params.dataset)
    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)
    result_file = os.path.join(
        result_dir,
        "%s_%s_%s.txt" % (params.method, params.model, descriptor_str))
    self.few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot,
                                n_query=self.params.n_query)
    params.checkpoint_dir = '%s/checkpoints/%s/%s/%s_%s_%s' % (
        configs.save_dir, params.dataset, params.method, params.model,
        descriptor_str, params.n_shot)
    params.stop_epoch = 100
    self.initialize(params, False)
    image_size = self.image_size
    pretrain = PretrainedModel(self.params)
    # 'cross' trains on miniImagenet and validates on CUB.
    if use_test:
        file_name = "novel.json"
    else:
        file_name = "val.json"
    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'base.json'
        val_file = configs.data_dir['CUB'] + file_name
    else:
        base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + file_name
    n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
    train_few_shot_params = dict(n_way=params.train_n_way,
                                 n_support=params.n_shot)
    base_datamgr = SetDataManager(image_size,
                                  n_query=n_query,
                                  **train_few_shot_params,
                                  n_eposide=train_episodes)
    base_loader = base_datamgr.get_data_loader(base_file,
                                               aug=params.train_aug,
                                               debug=debug)
    test_few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot)
    val_datamgr = SetDataManager(image_size,
                                 n_query=n_query,
                                 **test_few_shot_params,
                                 n_eposide=val_episodes)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False, debug=debug)
    backbone = self.get_backbone()
    if "params" in config:
        model_params = config["params"]
    else:
        model_params = {}
    if require_pretrain:
        model_params["pretrain"] = pretrain
    model = method(backbone, **train_few_shot_params, **model_params)
    model = model.cuda()
    # Freeze backbone: features come from the frozen pretrained model instead.
    model.feature = None
    if not require_pretrain:
        model.pretrain = pretrain
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=weight_decay)
    max_acc = 0
    for epoch in range(0, end_epoch):
        model.epoch = epoch
        model.train()
        model.train_loop(
            epoch, base_loader,
            optimizer)  # train_loop mutates model in place; nothing to return
        model.eval()
        if not os.path.isdir(params.checkpoint_dir):
            os.makedirs(params.checkpoint_dir)
        acc = model.test_loop(val_loader, metric=metric)
        message = "Epoch: %d, Validation accuracy: %.3f, Best validation accuracy: %.3f" % (
            epoch, acc, max_acc)
        print(message)
        append_to_file(result_file, message)
        if acc > max_acc:
            print("best model! save...")
            max_acc = acc
            outfile = os.path.join(params.checkpoint_dir, 'best_model.tar')
            torch.save({
                'epoch': epoch,
                'state': model.state_dict()
            }, outfile)
        # Periodic snapshot; NOTE(review): uses params.stop_epoch (100), not
        # end_epoch, for the "last epoch" check — confirm this is intended.
        if (epoch % params.save_freq == 0) or (epoch == params.stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir,
                                   '{:d}.tar'.format(epoch))
            torch.save({
                'epoch': epoch,
                'state': model.state_dict()
            }, outfile)
    self.meta_test(config, method, descriptor_str, debug, require_pretrain)
#code from mvcnn, add logging later #parse_args # num_models = 1000 #max number of models to use per class, add this functionality later # n_models_train = num_models*num_views # if params.num_views and params.num_views >=5: # n_query = max(1, int(8* params.test_n_way/params.train_n_way)) #why is this required? # else: # n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #why is this required? train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot) base_datamgr = SetDataManager(image_size, n_query=params.n_query, **train_few_shot_params, num_views=params.num_views) base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug) test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) val_datamgr = SetDataManager(image_size, n_query=params.n_query, **test_few_shot_params, num_views=params.num_views) val_loader = val_datamgr.get_data_loader(val_file, aug=False) backbone = model_dict[params.model] model = ProtoNet(backbone, params.num_views, **train_few_shot_params) model = model.cuda() # model = torch.nn.DataParallel(model).cuda()
def meta_test(self,
              config,
              method,
              descriptor_str,
              debug=True,
              require_pretrain=False,
              metric="acc"):
    """Evaluate a meta-trained model on the novel split.

    Rebuilds the model described by `config`, restores the best checkpoint
    saved by meta_train, freezes the backbone (features are supplied by the
    frozen pretrained model) and prints the test-loop result over
    params.stop_epoch (2000) episodes.
    """
    config["meta_training"] = True
    # self.params is shared runner state; the assignments below mutate it.
    params = self.params
    params.save_freq = 50
    params.n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
    for key in ("dataset", "model", "method", "n_shot"):
        setattr(params, key, config[key])
    self.few_shot_params = dict(n_way=params.test_n_way,
                                n_support=params.n_shot,
                                n_query=self.params.n_query)
    params.checkpoint_dir = (
        f"{configs.save_dir}/checkpoints/{params.dataset}/{params.method}/"
        f"{params.model}_{descriptor_str}_{params.n_shot}")
    params.stop_epoch = 2000
    self.initialize(params, False)
    image_size = self.image_size
    pretrained = PretrainedModel(self.params)

    # 'cross' evaluates on CUB novel classes.
    dataset_key = 'CUB' if params.dataset == 'cross' else params.dataset
    test_file = configs.data_dir[dataset_key] + 'novel.json'

    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    datamgr = SetDataManager(image_size,
                             n_query=15,
                             **few_shot_params,
                             n_eposide=params.stop_epoch)
    loader = datamgr.get_data_loader(test_file, aug=False, debug=debug)

    model_params = config.get("params", {})
    if require_pretrain:
        model_params["pretrain"] = pretrained
    net = method(self.get_backbone(), **few_shot_params, **model_params)
    net = net.cuda()

    # Restore the best checkpoint on top of the freshly built state dict so
    # that keys missing from the checkpoint keep their initial values.
    checkpoint = os.path.join(params.checkpoint_dir, "best_model.tar")
    merged = net.state_dict()
    merged.update(torch.load(checkpoint)["state"])
    net.load_state_dict(merged)

    # Freeze backbone: features come from the frozen pretrained model.
    net.feature = None
    net.pretrain = pretrained
    net.eval()
    print(net.test_loop(loader, metric=metric))
model = BaselineTrain(model_dict[params.model], params.num_classes, loss_type='dist') elif params.method in [ 'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax', 'maml', 'maml_approx' ]: n_query = max( 1, int(16 * params.test_n_way / params.train_n_way) ) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot) base_datamgr = SetDataManager(image_size, n_query=n_query, **train_few_shot_params) base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug) if params.n_shot_test == -1: # modify val loader support params.n_shot_test = params.n_shot else: # modify target loader support train_few_shot_params['n_support'] = params.n_shot_test if params.adversarial or params.adaptFinetune: target_datamgr = SetDataManager(image_size, n_query=n_query, **train_few_shot_params) target_loader = target_datamgr.get_data_loader(novel_file, aug=False) # ipdb.set_trace() # bl, tl = iter(base_loader), iter(target_loader)
else: raise print('train_aug is wrong') print('Testing! {} shots on {} dataset with {} epochs of {}({})'.format( params.n_shot, params.testset, params.save_epoch, name, params.method)) # dataset print(' build dataset') if 'Conv' in params.model: image_size = 84 else: image_size = 224 split = params.split loadfile = os.path.join(params.data_dir, params.testset, split + '.json') test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) val_datamgr = SetDataManager(image_size, n_query=params.n_query, **test_few_shot_params) val_loader = val_datamgr.get_data_loader(loadfile, aug=False) datasets = params.dataset datasets.remove(params.testset) base_loaders = [ val_datamgr.get_data_loader(os.path.join(params.data_dir, dataset, 'base.json'), aug=False) for dataset in datasets ] print(' build feature encoder') # feature encoder checkpoint_dir = '%s/checkmodels/%s' % (params.save_dir, name)
if params.save_iter != -1: split_str = split + "_" + str(params.save_iter) else: split_str = split if params.method in ['maml', 'maml_approx' ]: #maml do not support testing with feature if 'Conv' in params.model: if params.dataset in ['omniglot', 'cross_char']: image_size = 28 else: image_size = 84 else: image_size = 224 datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=15, **few_shot_params) if params.dataset == 'cross': if split == 'base': loadfile = configs.data_dir['miniImagenet'] + 'all.json' else: loadfile = configs.data_dir['CUB'] + split + '.json' elif params.dataset == 'cross_char': if split == 'base': loadfile = configs.data_dir['omniglot'] + 'noLatin.json' else: loadfile = configs.data_dir['emnist'] + split + '.json' else: loadfile = configs.data_dir[params.dataset] + split + '.json'
image_size = 224 base_datamgr = SimpleDataManager(image_size, batch_size=batch_size) base_loader = base_datamgr.get_data_loader(base_file, aug=True) base_jigsaw_datamgr = JigsawDataManger( image_size, batch_size=batch_size, max_replace_block_num=params.jig_replace_num_train) base_jigsaw_loader = base_jigsaw_datamgr.get_data_loader(base_file, aug=False) extra_data = 15 # extra_unlabeled data val_datamgr = SetDataManager(image_size, n_way=params.test_n_way, n_support=params.n_shot, n_query=params.n_query + extra_data, n_eposide=50) val_loader = val_datamgr.get_data_loader(val_file, aug=False) if params.dataset == "miniImagenet": num_class = 64 elif params.dataset == "tieredImagenet": num_class = 351 elif params.dataset == "caltech256": num_class = 257 elif params.dataset == "CUB": num_class = 200 # set to 200 since the label range 0~199 even though there are only 100 classes to be trained else: raise ValueError('Unknown dataset')
def single_test(params, results_logger):
    """Run one full test pass for the configured method and dataset.

    Builds the model named by ``params.method``, loads its checkpoint,
    evaluates it over 600 episodes (episodic loop for maml/DKT, otherwise
    on pre-extracted features), appends the result to record/results.txt,
    logs it through ``results_logger``, and returns the mean accuracy.
    """
    acc_all = []
    iter_num = 600  # number of test episodes
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        # Character datasets are only supported with the Conv4 backbone
        # (silently switched to its 'S' variant) and no augmentation.
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    # ---- build the model for the requested method ----
    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model],
                                 loss_type='dist',
                                 **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'DKT':
        model = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        # RelationNet needs a non-flattened feature map, hence the NP variants.
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        # Switch the backbone blocks into their MAML (fast-weight) mode
        # before instantiating the model.
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model],
                     approx=(params.method == 'maml_approx'),
                     **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char']:
            # MAML uses different hyper-parameters on character datasets.
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    # ---- locate the checkpoint directory for this configuration ----
    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    # modelfile = get_resume_file(checkpoint_dir)

    # Baseline/baseline++ are finetuned at test time from extracted
    # features, so no trained episodic weights are loaded for them.
    if not params.method in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " +
                  str(checkpoint_dir))

    split = params.split
    # split_str distinguishes results saved at a specific iteration.
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split

    if params.method in ['maml', 'maml_approx', 'DKT']:
        # These methods do not support testing on pre-extracted features;
        # run the episodic loop on raw images instead.
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224
        datamgr = SetDataManager(image_size,
                                 n_eposide=iter_num,
                                 n_query=15,
                                 **few_shot_params)
        # Pick the episode file: cross-domain setups train on one dataset
        # and evaluate on another.
        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'
        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            # Adaptation on MAML is performed simply by updating more times.
            model.task_update_num = 100
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)
    else:
        # Evaluate from pre-extracted features.
        # Default split is novel, but base or val classes can be tested too.
        novel_file = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"),
            split_str + ".hdf5")
        cl_data_file = feat_loader.init_loader(novel_file)
        for i in range(iter_num):
            acc = feature_evaluation(cl_data_file,
                                     model,
                                     n_query=15,
                                     adaptation=params.adaptation,
                                     **few_shot_params)
            acc_all.append(acc)
        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        print('%d Test Acc = %4.2f%% +- %4.2f%%' %
              (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))

    # ---- append result line to the shared results file ----
    with open('record/results.txt', 'a') as f:
        timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
        aug_str = '-aug' if params.train_aug else ''
        aug_str += '-adapted' if params.adaptation else ''
        if params.method in ['baseline', 'baseline++']:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.test_n_way)
        else:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_train %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.train_n_way, params.test_n_way)
        acc_str = '%d Test Acc = %4.2f%% +- %4.2f%%' % (
            iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num))
        f.write('Time: %s, Setting: %s, Acc: %s \n' %
                (timestamp, exp_setting, acc_str))

    # The logged std is the half-width of the 95% confidence interval.
    results_logger.log("single_test_acc", acc_mean)
    results_logger.log("single_test_acc_std",
                       1.96 * acc_std / np.sqrt(iter_num))
    results_logger.log("time", timestamp)
    results_logger.log("exp_setting", exp_setting)
    results_logger.log("acc_str", acc_str)
    return acc_mean
def get_train_val_loader(params, source_val):
    """Build the training and validation data loaders for ``params.method``.

    Args:
        params: experiment configuration namespace (method, dataset, n-way,
            n-shot, augmentation flags, optional vaegan/aug settings, ...).
        source_val: when truthy, additionally build a loader over the
            source-domain validation file.

    Returns:
        ``(base_loader, val_loader)`` — or
        ``(base_loader, val_loader, source_val_loader)`` when ``source_val``.

    Raises:
        ValueError: if ``params.method`` is not a supported method.
    """
    # Imported here (not at module level) to prevent a circular import.
    from data.datamgr import SimpleDataManager, SetDataManager, AugSetDataManager, VAESetDataManager

    image_size = get_img_size(params)
    base_file, val_file = get_train_val_filename(params)
    if source_val:
        source_val_file = get_source_val_filename(params)

    if params.method in ['baseline', 'baseline++']:
        # Plain-batch training data; episodic validation (fine-tune at val).
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        # If test_n_way is smaller than train_n_way, reduce n_query to keep
        # the episode batch size small.
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
        val_few_shot_params = get_few_shot_params(params, 'val')
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **val_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        if source_val:
            source_val_datamgr = SetDataManager(image_size,
                                                n_query=n_query,
                                                **val_few_shot_params)
            # BUG FIX: previously loaded through val_datamgr, leaving
            # source_val_datamgr dead; use the manager built for this loader.
            source_val_loader = source_val_datamgr.get_data_loader(
                source_val_file, aug=False)
    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        # If test_n_way is smaller than train_n_way, reduce n_query to keep
        # the episode batch size small.
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
        train_few_shot_params = get_few_shot_params(params, 'train')
        val_few_shot_params = get_few_shot_params(params, 'val')

        if params.vaegan_exp is not None:
            # TODO
            is_training = False
            # NOTE(review): the returned handle is unused here; presumably
            # restore_vaegan is kept for its side effects — confirm.
            vaegan = restore_vaegan(params.dataset,
                                    params.vaegan_exp,
                                    params.vaegan_step,
                                    is_training=is_training)
            base_datamgr = VAESetDataManager(
                image_size,
                n_query=n_query,
                vaegan_exp=params.vaegan_exp,
                vaegan_step=params.vaegan_step,
                vaegan_is_train=params.vaegan_is_train,
                lambda_zlogvar=params.zvar_lambda,
                fake_prob=params.fake_prob,
                **train_few_shot_params)
            # train_val or val???
            val_datamgr = SetDataManager(image_size,
                                         n_query=n_query,
                                         **val_few_shot_params)
        elif params.aug_target is None:  # Common case.
            assert params.aug_type is None
            base_datamgr = SetDataManager(image_size,
                                          n_query=n_query,
                                          **train_few_shot_params)
            val_datamgr = SetDataManager(image_size,
                                         n_query=n_query,
                                         **val_few_shot_params)
            if source_val:
                source_val_datamgr = SetDataManager(image_size,
                                                    n_query=n_query,
                                                    **val_few_shot_params)
        else:
            aug_type = params.aug_type
            assert aug_type is not None
            base_datamgr = AugSetDataManager(image_size,
                                             n_query=n_query,
                                             aug_type=aug_type,
                                             aug_target=params.aug_target,
                                             **train_few_shot_params)
            val_datamgr = AugSetDataManager(image_size,
                                            n_query=n_query,
                                            aug_type=aug_type,
                                            aug_target='test-sample',
                                            **val_few_shot_params)

        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        if source_val:
            # BUG FIX: use the dedicated source-val manager when one was
            # built (common case); the vaegan/aug branches never built one,
            # so fall back to val_datamgr there as before.
            if params.vaegan_exp is None and params.aug_target is None:
                source_val_loader = source_val_datamgr.get_data_loader(
                    source_val_file, aug=False)
            else:
                source_val_loader = val_datamgr.get_data_loader(
                    source_val_file, aug=False)
        # A batch from SetDataManager is a
        # [n_way, n_support + n_query, n_channel, w, h] tensor.
    else:
        raise ValueError('Unknown method')

    if source_val:
        return base_loader, val_loader, source_val_loader
    else:
        return base_loader, val_loader
if params.method in ['baseline', 'baseline++'] : base_datamgr = SimpleDataManager(image_size, batch_size = 16) base_loader = base_datamgr.get_data_loader( base_file , aug = params.train_aug ) val_datamgr = SimpleDataManager(image_size, batch_size = 64) val_loader = val_datamgr.get_data_loader( val_file, aug = False) if params.method == 'baseline': model = BaselineTrain( model_dict[params.model], params.num_classes) elif params.method == 'baseline++': model = BaselineTrain( model_dict[params.model], params.num_classes, loss_type = 'dist') elif params.method in ['protonet','matchingnet','relationnet', 'relationnet_softmax', 'maml', 'maml_approx']: n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small train_few_shot_params = dict(n_way = params.train_n_way, n_support = params.n_shot) base_datamgr = SetDataManager(image_size, n_query = n_query, **train_few_shot_params) base_loader = base_datamgr.get_data_loader( base_file , aug = params.train_aug ) test_few_shot_params = dict(n_way = params.test_n_way, n_support = params.n_shot) val_datamgr = SetDataManager(image_size, n_query = n_query, **test_few_shot_params) val_loader = val_datamgr.get_data_loader( val_file, aug = False) #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor if params.method == 'protonet': model = ProtoNet( model_dict[params.model], **train_few_shot_params ) elif params.method == 'matchingnet': model = MatchingNet( model_dict[params.model], **train_few_shot_params ) elif params.method in ['relationnet', 'relationnet_softmax']: if params.model == 'Conv4': feature_model = backbone.Conv4NP elif params.model == 'Conv6':
def _make_test_record(timestamp, acc_mean, acc_std, load_epoch):
    """Format one accuracy record (values beyond params) for the csv."""
    return {
        'time': timestamp,
        'acc_mean': '%4.2f' % (acc_mean),
        'acc_std': '%4.2f' % (acc_std),
        'epoch': load_epoch,
    }


def _finish_exp_test(params, should_del_features, start_time):
    """Shared tail of exp_test(): optional feature cleanup and timing prints."""
    if should_del_features:
        del_features(params)
    end_time = datetime.datetime.now()
    print('exp_test() start at', start_time, ', end at', end_time, '.\n')
    print('exp_test() totally took:', end_time - start_time)


def exp_test(params, n_episodes, should_del_features=False):
    """Evaluate a trained model over ``n_episodes`` few-shot test episodes.

    MAML variants are evaluated episodically on raw images; all other
    methods are evaluated on pre-extracted features, optionally over
    ``params.n_test_candidates`` candidate feature sets and (for candidate
    runs) over a list of ``params.frac_ensemble`` settings.

    Args:
        params: experiment configuration namespace.
        n_episodes: number of test episodes to run.
        should_del_features: delete extracted feature files afterwards.

    Returns:
        ``(extra_record, task_datas)`` for single experiments, or
        ``(frac_extra_records, ep_task_data_each_frac)`` when
        ``params.frac_ensemble`` is a list.
    """
    start_time = datetime.datetime.now()
    print('exp_test() started at', start_time)
    set_random_seed(0)  # reproduce "normal" testing exactly
    if params.gpu_id:
        set_gpu_id(params.gpu_id)

    model = get_model(params, 'test')

    ########## get settings ##########
    n_shot = params.test_n_shot if params.test_n_shot is not None else params.n_shot
    few_shot_params = dict(n_way=params.test_n_way, n_support=n_shot)
    if params.gpu_id:
        model = model.cuda()
    else:
        model = to_device(model)

    checkpoint_dir = get_checkpoint_dir(params)
    print('loading from:', checkpoint_dir)
    if params.save_iter != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
    else:
        modelfile = get_best_file(checkpoint_dir)

    ########## load model ##########
    if modelfile is not None:
        if params.gpu_id is None:
            tmp = torch.load(modelfile)
        else:
            # TODO: figure out why a fixed map_location is needed here.
            print('params.gpu_id =', params.gpu_id)
            # see: https://hackmd.io/koKAo6kURn2YBqjoXXDhaw#RuntimeError-CUDA-error-invalid-device-ordinal
            map_location = 'cuda:0'
            tmp = torch.load(modelfile, map_location=map_location)
    # baseline/baseline++ fine-tune from features, so no weights are loaded.
    if params.method not in ['baseline', 'baseline++']:
        model.load_state_dict(tmp['state'])
        print('Model successfully loaded.')
    else:
        print('No need to load model for baseline/baseline++ when testing.')
    load_epoch = int(tmp['epoch'])

    ########## testing ##########
    if params.method in ['maml', 'maml_approx']:
        # MAML does not support testing on pre-extracted features.
        image_size = get_img_size(params)
        load_file = get_loadfile_path(params, params.split)
        datamgr = SetDataManager(image_size,
                                 n_episode=n_episodes,
                                 n_query=15,
                                 **few_shot_params)
        # BUG FIX: previously referenced undefined name `loadfile`.
        novel_loader = datamgr.get_data_loader(load_file, aug=False)
        if params.adaptation:
            # Adaptation on MAML is performed simply by updating more times.
            model.task_update_num = 100
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)

        ########## last record and post-process ##########
        torch.cuda.empty_cache()
        timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
        extra_record = _make_test_record(timestamp, acc_mean, acc_std,
                                         load_epoch)
        # BUG FIX: task_datas was never assigned in this branch (NameError);
        # MAML evaluation produces no per-task data.
        task_datas = None
        _finish_exp_test(params, should_del_features, start_time)
        return extra_record, task_datas

    else:  # not MAML: evaluate directly on extracted features
        acc_all = []
        all_feature_files = get_all_feature_files(params)

        if params.n_test_candidates is None:  # common setting (no candidates)
            # Per-episode dicts: acc (can replace acc_all), img_path,
            # img_is_correct, etc.
            task_datas = [None] * n_episodes
            feature_file = all_feature_files[0]
            cl_feature, cl_filepath = feat_loader.init_loader(
                feature_file, return_path=True)
            cl_feature_single = [cl_feature]
            for i in tqdm(range(n_episodes)):
                # TODO afterward: fix data list? can only fix class list?
                task_data = feature_evaluation(
                    cl_feature_single,
                    model,
                    params=params,
                    n_query=15,
                    **few_shot_params,
                    cl_filepath=cl_filepath,
                )
                acc_all.append(task_data['acc'])
                task_datas[i] = task_data

            acc_all = np.asarray(acc_all)
            acc_mean = np.mean(acc_all)
            acc_std = np.std(acc_all)
            print('loaded from %d epoch model.' % (load_epoch))
            print('%d episodes, Test Acc = %4.2f%% +- %4.2f%%' %
                  (n_episodes, acc_mean, 1.96 * acc_std / np.sqrt(n_episodes)))

            ########## last record and post-process ##########
            torch.cuda.empty_cache()
            timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
            extra_record = _make_test_record(timestamp, acc_mean, acc_std,
                                             load_epoch)
            _finish_exp_test(params, should_del_features, start_time)
            return extra_record, task_datas

        else:  # n_test_candidates settings
            # Features of each class of each candidate.
            candidate_cl_feature = []
            print('Loading features of %s candidates into dictionaries...' %
                  (params.n_test_candidates))
            for n in tqdm(range(params.n_test_candidates)):
                nth_feature_file = all_feature_files[n]
                # NOTE(review): cl_filepath keeps the paths of the LAST
                # candidate after this loop, as in the original — confirm
                # that all candidates share the same file list.
                cl_feature, cl_filepath = feat_loader.init_loader(
                    nth_feature_file, return_path=True)
                candidate_cl_feature.append(cl_feature)
            print('Evaluating...')

            is_single_exp = not isinstance(params.frac_ensemble, list)
            if is_single_exp:
                task_datas = [None] * n_episodes
                ########## test and record acc ##########
                for i in tqdm(range(n_episodes)):
                    task_data = feature_evaluation(
                        candidate_cl_feature,
                        model,
                        params=params,
                        n_query=15,
                        **few_shot_params,
                        cl_filepath=cl_filepath,
                    )
                    acc_all.append(task_data['acc'])
                    task_datas[i] = task_data
                    gc.collect()  # large per-episode feature slices

                acc_all = np.asarray(acc_all)
                acc_mean = np.mean(acc_all)
                acc_std = np.std(acc_all)
                print('loaded from %d epoch model.' % (load_epoch))
                print('%d episodes, Test Acc = %4.2f%% +- %4.2f%%' %
                      (n_episodes, acc_mean,
                       1.96 * acc_std / np.sqrt(n_episodes)))
                collected = gc.collect()
                print("garbage collector: collected %d objects." % (collected))

                ########## last record and post-process ##########
                torch.cuda.empty_cache()
                timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
                extra_record = _make_test_record(timestamp, acc_mean, acc_std,
                                                 load_epoch)
                _finish_exp_test(params, should_del_features, start_time)
                return extra_record, task_datas

            else:
                ########## multi-frac_ensemble exps ##########
                n_fracs = len(params.frac_ensemble)
                # One accuracy list / mean / std / task-data list per frac.
                frac_acc_alls = [[0] * n_episodes for _ in range(n_fracs)]
                frac_acc_means = [None] * n_fracs
                frac_acc_stds = [None] * n_fracs
                ep_task_data_each_frac = [[None] * n_episodes
                                          for _ in range(n_fracs)]

                for ep_id in tqdm(range(n_episodes)):
                    # feature_evaluation returns one task_data per frac here.
                    frac_task_data = feature_evaluation(
                        candidate_cl_feature,
                        model,
                        params=params,
                        n_query=15,
                        **few_shot_params,
                        cl_filepath=cl_filepath,
                    )
                    for frac_id in range(n_fracs):
                        task_data = frac_task_data[frac_id]
                        frac_acc_alls[frac_id][ep_id] = task_data['acc']
                        ep_task_data_each_frac[frac_id][ep_id] = task_data
                    gc.collect()
                gc.collect()

                for frac_id in range(n_fracs):
                    frac_acc_alls[frac_id] = np.asarray(frac_acc_alls[frac_id])
                    acc_mean = np.mean(frac_acc_alls[frac_id])
                    acc_std = np.std(frac_acc_alls[frac_id])
                    frac_acc_means[frac_id] = acc_mean
                    frac_acc_stds[frac_id] = acc_std
                    print('loaded from %d epoch model, frac_ensemble:' %
                          (load_epoch), params.frac_ensemble[frac_id])
                    print('%d episodes, Test Acc = %4.2f%% +- %4.2f%%' %
                          (n_episodes, acc_mean,
                           1.96 * acc_std / np.sqrt(n_episodes)))

                ########## last record and post-process ##########
                torch.cuda.empty_cache()
                timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
                frac_extra_records = [
                    _make_test_record(timestamp, frac_acc_means[frac_id],
                                      frac_acc_stds[frac_id], load_epoch)
                    for frac_id in range(n_fracs)
                ]
                _finish_exp_test(params, should_del_features, start_time)
                return frac_extra_records, ep_task_data_each_frac