elif params.model == 'Conv4S': feature_model = backbone.Conv4SNP else: feature_model = lambda: model_dict[params.model](flatten=False) loss_type = 'mse' if params.method == 'relationnet' else 'softmax' model = RelationNet(feature_model, loss_type=loss_type, **train_few_shot_params) elif params.method in ['maml', 'maml_approx']: backbone.ConvBlock.maml = True backbone.SimpleBlock.maml = True backbone.BottleneckBlock.maml = True backbone.ResNet.maml = True model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), **train_few_shot_params) if params.dataset in ['omniglot', 'cross_char' ]: #maml use different parameter in omniglot model.n_task = 32 model.task_update_num = 1 model.train_lr = 0.1 else: raise ValueError('Unknown method') model = to_cuda(model) params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % ( configs.save_dir, params.dataset, params.model, params.method) if params.train_aug: params.checkpoint_dir += '_aug'
def single_test(params, results_logger):
    """Evaluate a trained few-shot model once and log/append the accuracy.

    Args:
        params: argparse namespace (method, model, dataset, split, ...).
        results_logger: logger with a ``log(key, value)`` interface.

    Returns:
        float: mean test accuracy over the evaluation episodes.
    """
    acc_all = []
    iter_num = 600
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    # Build the model for the requested method.
    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'DKT':
        model = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        # RelationNet needs an unflattened (spatial) feature map.
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            def feature_model():
                return model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char']:
            # maml uses different hyper-parameters on omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    # Baselines finetune from scratch at test time; others load a checkpoint.
    if params.method not in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " + str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split

    if params.method in ['maml', 'maml_approx', 'DKT']:  # maml does not support testing with cached features
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224

        datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=15, **few_shot_params)
        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # We perform adaptation on MAML simply by updating more times.
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)
    else:
        # Evaluate from precomputed features.
        # default split = novel, but you can also test base or val classes
        novel_file = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"), split_str + ".hdf5")
        cl_data_file = feat_loader.init_loader(novel_file)
        for _ in range(iter_num):
            acc = feature_evaluation(cl_data_file, model, n_query=15,
                                     adaptation=params.adaptation, **few_shot_params)
            acc_all.append(acc)
        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)

    # 95% confidence interval half-width; compute the report string once and
    # reuse it for both stdout and the results file.
    acc_ci = 1.96 * acc_std / np.sqrt(iter_num)
    acc_str = '%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, acc_ci)
    print(acc_str)

    with open('record/results.txt', 'a') as f:
        timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
        aug_str = '-aug' if params.train_aug else ''
        aug_str += '-adapted' if params.adaptation else ''
        if params.method in ['baseline', 'baseline++']:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.test_n_way)
        else:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_train %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.train_n_way, params.test_n_way)
        f.write('Time: %s, Setting: %s, Acc: %s \n' % (timestamp, exp_setting, acc_str))

    results_logger.log("single_test_acc", acc_mean)
    results_logger.log("single_test_acc_std", acc_ci)
    results_logger.log("time", timestamp)
    results_logger.log("exp_setting", exp_setting)
    results_logger.log("acc_str", acc_str)
    return acc_mean
def main_train(params):
    """Build data loaders and a model per ``params``, then run training.

    Args:
        params: argparse namespace (method, model, dataset, n_shot, ...).

    Side effects: creates the checkpoint directory, mutates
    ``params.checkpoint_dir`` / ``params.stop_epoch``, saves results.
    """
    _set_seed(params)
    results_logger = ResultsLogger(params)

    # Resolve dataset json files.
    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'

    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        image_size = 224

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    optimization = 'Adam'

    # Pick a default epoch budget when the user did not specify one.
    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                # Different from the open-review paper: 400 epochs of
                # baseline on CUB actually leads to over-fitting.
                params.stop_epoch = 200
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  # default
        else:  # meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 600  # default

    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model], params.num_classes, loss_type='dist')
    elif params.method in ['DKT', 'protonet', 'matchingnet', 'relationnet',
                           'relationnet_softmax', 'maml', 'maml_approx']:
        # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))

        train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size, n_query=n_query, **train_few_shot_params)  # n_eposide=100
        base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size, n_query=n_query, **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if params.method == 'DKT':
            model = DKT(model_dict[params.model], **train_few_shot_params)
            model.init_summary()
        elif params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model], **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            # RelationNet needs an unflattened (spatial) feature map.
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                def feature_model():
                    return model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
            model = RelationNet(feature_model, loss_type=loss_type, **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), **train_few_shot_params)
            if params.dataset in ['omniglot', 'cross_char']:
                # maml uses different hyper-parameters on omniglot
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method in ['maml', 'maml_approx']:
        stop_epoch = params.stop_epoch * model.n_task  # maml uses multiple tasks in one update

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:
        # We also support warmup from pretrained baseline feature, but we never used in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        tmp = torch.load(warmup_resume_file)
        if tmp is not None:
            state = tmp['state']
            for key in list(state.keys()):
                if "feature." in key:
                    # an architecture model has attribute 'feature'; load the
                    # architecture feature into the backbone by renaming
                    # 'feature.trunk.xx' to 'trunk.xx'
                    newkey = key.replace("feature.", "")
                    state[newkey] = state.pop(key)
                else:
                    state.pop(key)
            model.feature.load_state_dict(state)
        else:
            raise ValueError('No warm_up file')

    model = train(base_loader, val_loader, model, optimization,
                  start_epoch, stop_epoch, params, results_logger)
    results_logger.save()
feature_model = backbone.Conv6NP elif params.model == 'Conv4S': feature_model = backbone.Conv4SNP else: feature_model = lambda: model_dict[params.model](flatten=False) loss_type = 'mse' if params.method == 'relationnet' else 'softmax' model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params) elif params.method in ['maml', 'maml_approx']: backbone.ConvBlock.maml = True backbone.SimpleBlock.maml = True backbone.BottleneckBlock.maml = True backbone.ResNet.maml = True model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), test_n_support=params.test_n_shot, **few_shot_params) # model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), test_n_support=params.tes,**few_shot_params) model.task_update_num = 5 if params.dataset in ['omniglot', 'cross_char' ]: # maml use different parameter in omniglot model.n_task = 32 model.task_update_num = 1 model.train_lr = 0.1 elif params.method == 'Ours': model = Ours( model_dict[params.model], # test_n_support=1, **few_shot_params) test_n_support=params.test_n_shot, **few_shot_params) else:
def get_model(params, mode):
    '''Build and return the model described by ``params`` for a given phase.

    Args:
        params: argparse params
        mode: (str), 'train', 'test'

    Returns:
        the constructed (un-cuda'd) model.

    Raises:
        ValueError: unexpected method, or an incompatible decoder for
            omniglot/cross_char datasets.
    '''
    print('get_model() start...')
    few_shot_params = get_few_shot_params(params, mode)

    if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
        assert 'Conv4' in params.model and not params.train_aug, 'omniglot/cross_char only support Conv4 without augmentation'
        # replace() rather than assignment so variants like Conv4Drop
        # also map to their 'S' counterpart (Conv4SDrop).
        params.model = params.model.replace('Conv4', 'Conv4S')
        if params.recons_decoder is not None:
            if 'ConvS' not in params.recons_decoder:
                raise ValueError('omniglot / cross_char should use ConvS/HiddenConvS decoder.')

    if params.method in ['baseline', 'baseline++'] and mode == 'train':
        # classifier head must cover every base-class label id
        assert params.num_classes >= n_base_classes[params.dataset]

    if params.recons_decoder is None:
        print('params.recons_decoder == None')
        recons_decoder = None
    else:
        recons_decoder = decoder_dict[params.recons_decoder]
        print('recons_decoder:\n', recons_decoder)

    backbone_func = get_backbone_func(params)

    if 'baseline' in params.method:
        loss_types = {
            'baseline': 'softmax',
            'baseline++': 'dist',
        }
        loss_type = loss_types[params.method]
        if recons_decoder is None and params.min_gram is None:
            # default baseline/baseline++
            if mode == 'train':
                model = BaselineTrain(
                    model_func=backbone_func, loss_type=loss_type,
                    num_class=params.num_classes, **few_shot_params)
            elif mode == 'test':
                model = BaselineFinetune(
                    model_func=backbone_func, loss_type=loss_type,
                    **few_shot_params,
                    finetune_dropout_p=params.finetune_dropout_p)
        else:  # other settings for baseline
            if params.min_gram is not None:
                min_gram_params = {
                    'min_gram': params.min_gram,
                    'lambda_gram': params.lambda_gram,
                }
                if mode == 'train':
                    model = BaselineTrainMinGram(
                        model_func=backbone_func, loss_type=loss_type,
                        num_class=params.num_classes,
                        **few_shot_params, **min_gram_params)
                elif mode == 'test':
                    # NOTE(review): min-gram is train-only; plain finetune at test time.
                    model = BaselineFinetune(
                        model_func=backbone_func, loss_type=loss_type,
                        **few_shot_params,
                        finetune_dropout_p=params.finetune_dropout_p)
    elif params.method == 'protonet':
        if recons_decoder is None and params.min_gram is None:
            # default ProtoNet
            model = ProtoNet(backbone_func, **few_shot_params)
        else:  # other settings
            if params.min_gram is not None:
                min_gram_params = {
                    'min_gram': params.min_gram,
                    'lambda_gram': params.lambda_gram,
                }
                model = ProtoNetMinGram(backbone_func, **few_shot_params, **min_gram_params)
            if params.recons_decoder is not None:
                if 'Hidden' in params.recons_decoder:
                    # decoders that reconstruct from an intermediate layer
                    if params.recons_decoder == 'HiddenConv':
                        model = ProtoNetAE2(backbone_func, **few_shot_params,
                                            recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda,
                                            extract_layer=2)
                    elif params.recons_decoder == 'HiddenConvS':
                        model = ProtoNetAE2(backbone_func, **few_shot_params,
                                            recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda,
                                            extract_layer=2, is_color=False)
                    elif params.recons_decoder == 'HiddenRes10':
                        model = ProtoNetAE2(backbone_func, **few_shot_params,
                                            recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda,
                                            extract_layer=6)
                    elif params.recons_decoder == 'HiddenRes18':
                        model = ProtoNetAE2(backbone_func, **few_shot_params,
                                            recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda,
                                            extract_layer=8)
                else:
                    # 'ConvS' decoders are for the grayscale datasets
                    model = ProtoNetAE(backbone_func, **few_shot_params,
                                       recons_func=recons_decoder,
                                       lambda_d=params.recons_lambda,
                                       is_color='ConvS' not in params.recons_decoder)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone_func, **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(backbone_func, loss_type=loss_type, **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(backbone_func, approx=(params.method == 'maml_approx'), **few_shot_params)
        if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
            # maml uses different hyper-parameters on omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unexpected params.method: %s' % (params.method))

    print('get_model() finished.')
    return model
**test_few_shot_params) val_loader = val_datamgr.get_data_loader(val_file, aug=False) train_few_shot_params['dropout_method'] = params.dropout_method train_few_shot_params['dropout_rate'] = params.dropout_rate train_few_shot_params['dropout_schedule'] = params.dropout_schedule # prepare model print('--- prepare model {} (backbone {}) ---'.format( params.method, params.model)) if params.method in ['maml', 'maml_approx']: backbone.ConvBlock.maml = True backbone.SimpleBlock.maml = True backbone.BottleneckBlock.maml = True backbone.ResNet.maml = True model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), tf_path=params.tf_dir, **train_few_shot_params) else: raise ValueError('Unknown method') model = model.cuda() max_acc = 0 total_it = 0 start_epoch = params.start_epoch stop_epoch = params.stop_epoch if params.method not in ['baseline', 'baseline++']: stop_epoch = params.stop_epoch * model.batch_size # resume/warmup or not if params.resume != '': resume_file = get_resume_file('%s/checkpoints/%s' %
print('\n--- build MAML model ---') print(' train with model: %s'%params.model) if 'Conv' in params.model: image_size = 84 else: image_size = 224 n_query = max(1, int(16* params.test_n_way/params.train_n_way)) train_few_shot_params = dict(n_way = params.train_n_way, n_support = params.n_shot) base_datamgr = SetDataManager(image_size, n_query = n_query, n_eposide = 100, **train_few_shot_params) test_few_shot_params = dict(n_way = params.test_n_way, n_support = params.n_shot) val_datamgr = SetDataManager(image_size, n_query = n_query, n_eposide = 100, **test_few_shot_params) val_loader = val_datamgr.get_data_loader( val_file, aug = False) val_loader_nd = val_datamgr.get_data_loader( val_file, aug = False) model = MAML(params, tf_path=params.tf_dir) model.cuda() # resume training start_epoch = params.start_epoch stop_epoch = params.stop_epoch if params.resume != '': resume_file = get_resume_file('%s/checkpoints/%s'%(params.save_dir, params.resume), params.resume_epoch) if resume_file is not None: start_epoch = model.resume(resume_file) print(' resume the training with at {} epoch (model file {})'.format(start_epoch, params.resume)) else: raise ValueError('No resume file') # load pre-trained feature encoder else: if params.warmup == 'scratch':
if __name__ == '__main__':
    # Parse test-time arguments and announce the run.
    params = parse_args('test')
    print(params)
    print('test: {}'.format(params.name))

    acc_all = []
    iter_num = 100

    # create model -- only MAML variants are supported by this script
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    if params.method not in ['maml', 'maml_approx']:
        raise ValueError('Unknown method')
    backbone.ConvBlock.maml = True
    backbone.SimpleBlock.maml = True
    backbone.BottleneckBlock.maml = True
    backbone.ResNet.maml = True
    model = MAML(model_dict[params.model],
                 approx=(params.method == 'maml_approx'),
                 **few_shot_params)
    model = model.cuda()

    # load model weights from the named checkpoint directory
    checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name)
    if params.save_iter == -1:
        modelfile = get_best_file(checkpoint_dir)
    else:
        modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
    if modelfile is not None:
        saved = torch.load(modelfile)
        model.load_state_dict(saved['state'])
def get_logits_targets(params):
    """Run a trained model over test episodes and collect raw outputs.

    Args:
        params: argparse namespace (method, model, dataset, split, ...).

    Returns:
        (logits, targets): concatenated tensors over all episodes, suitable
        for e.g. calibration analysis.
    """
    iter_num = 600
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    # Build the model for the requested method.
    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'DKT':
        model = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        # RelationNet needs an unflattened (spatial) feature map.
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            def feature_model():
                return model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char']:
            # maml uses different hyper-parameters on omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    # Baselines finetune from scratch at test time; others load a checkpoint.
    if params.method not in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " + str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split

    if params.method in ['maml', 'maml_approx', 'DKT']:  # maml does not support testing with cached features
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224

        datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=15, **few_shot_params)
        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'

        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # We perform adaptation on MAML simply by updating more times.
        model.eval()

        logits_list = []
        targets_list = []
        for x, _ in novel_loader:
            logits = model.get_logits(x).detach()
            targets = torch.tensor(np.repeat(range(params.test_n_way), model.n_query)).cuda()
            logits_list.append(logits)
            targets_list.append(targets)
    else:
        # Evaluate from precomputed features.
        novel_file = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"), split_str + ".hdf5")
        cl_data_file = feat_loader.init_loader(novel_file)

        logits_list = []
        targets_list = []
        n_query = 15
        n_way = few_shot_params['n_way']
        n_support = few_shot_params['n_support']
        # list() is required: random.sample rejects a dict view on Python 3
        class_list = list(cl_data_file.keys())
        for _ in range(iter_num):
            select_class = random.sample(class_list, n_way)
            z_all = []
            for cl in select_class:
                img_feat = cl_data_file[cl]
                perm_ids = np.random.permutation(len(img_feat)).tolist()
                # take n_support + n_query random features per class
                z_all.append([np.squeeze(img_feat[perm_ids[j]])
                              for j in range(n_support + n_query)])
            z_all = torch.from_numpy(np.array(z_all))  # stack each batch
            model.n_query = n_query
            logits = model.set_forward(z_all, is_feature=True).detach()
            targets = torch.tensor(np.repeat(range(n_way), n_query)).cuda()
            logits_list.append(logits)
            targets_list.append(targets)

    return torch.cat(logits_list, 0), torch.cat(targets_list, 0)
model = ProtoNet(backbone.FCNet(x_dim), **train_few_shot_params) elif params.method == 'comet': model = COMET(backbone.EnFCNet(x_dim, go_mask), **train_few_shot_params) elif params.method == 'matchingnet': model = MatchingNet(backbone.FCNet(x_dim), **train_few_shot_params) elif params.method in ['relationnet', 'relationnet_softmax']: loss_type = 'mse' if params.method == 'relationnet' else 'softmax' model = RelationNet(backbone.FCNet(x_dim), loss_type=loss_type, **train_few_shot_params) elif params.method in ['maml', 'maml_approx']: model = MAML(backbone.FCNet(x_dim), approx=(params.method == 'maml_approx'), **train_few_shot_params) if params.dataset in ['omniglot', 'cross_char' ]: #maml use different parameter in omniglot model.n_task = 32 model.task_update_num = 1 model.train_lr = 0.1 else: raise ValueError('Unknown method') model = model.cuda() params.checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s' % ( configs.save_dir, params.dataset, params.model, params.method, params.exp_str) if params.train_aug:
feature_model = backbone.Conv4NP elif args.model == 'Conv6': feature_model = backbone.Conv6NP else: feature_model = lambda: model_dict[args.model](flatten=False) loss_type = 'mse' if args.method == 'relationnet' else 'softmax' model = RelationNet(feature_model, loss_type=loss_type, **few_shot_args) elif args.method in ['maml', 'maml_approx']: backbone.ConvBlock.maml = True backbone.SimpleBlock.maml = True backbone.BottleneckBlock.maml = True backbone.ResNet.maml = True model = MAML(model_dict[args.model], approx=(args.method == 'maml_approx'), **few_shot_args) else: raise ValueError('Unknown method') model = model.cuda() if args.checkpoint_dir is None: checkpoint_dir = '%s/checkpoints/%s/%s_%s' % ( configs.save_dir, args.dataset, args.model, args.method) if args.train_aug: checkpoint_dir += '_aug' if not args.method in ['baseline', 'baseline++']: checkpoint_dir += '_%dway_%dshot' % (args.train_n_way, args.n_shot) else: checkpoint_dir = args.checkpoint_dir
def select_model(params):
    """Select which model to use based on params.

    Args:
        params: argparse namespace (method, model, dataset, jigsaw, ...).

    Returns:
        the constructed (un-cuda'd) model.

    Raises:
        ValueError: if ``params.method`` is not recognized.
    """
    if params.method in ['baseline', 'baseline++']:
        # Number of base classes per dataset for the classification head.
        # Unknown datasets keep whatever params.num_classes already holds.
        num_classes_per_dataset = {
            'CUB': 200,
            'cars': 196,
            'aircrafts': 100,
            'dogs': 120,
            'flowers': 102,
            'miniImagenet': 100,
            'tieredImagenet': 608,
        }
        if params.dataset in num_classes_per_dataset:
            params.num_classes = num_classes_per_dataset[params.dataset]

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes,
                                  jigsaw=params.jigsaw, lbda=params.lbda,
                                  rotation=params.rotation, tracking=params.tracking)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model], params.num_classes,
                                  loss_type='dist',
                                  jigsaw=params.jigsaw, lbda=params.lbda,
                                  rotation=params.rotation, tracking=params.tracking)
    elif params.method in ['protonet', 'matchingnet', 'relationnet',
                           'relationnet_softmax', 'maml', 'maml_approx']:
        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot,
                                     jigsaw=params.jigsaw, lbda=params.lbda,
                                     rotation=params.rotation)
        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params,
                             use_bn=(not params.no_bn),
                             pretrain=params.pretrain, tracking=params.tracking)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model], **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            # RelationNet needs an unflattened (spatial) feature map.
            def feature_model():
                return model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
            model = RelationNet(feature_model, loss_type=loss_type, **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            BasicBlock.maml = True
            Bottleneck.maml = True
            ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
    else:
        raise ValueError('Unknown method')
    return model
feature_model = backbone.Conv6NP elif params.model == "Conv4S": feature_model = backbone.Conv4SNP else: feature_model = lambda: model_dict[params.model](flatten=False) loss_type = "mse" if params.method == "relationnet" else "softmax" model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params) elif params.method in ["maml", "maml_approx"]: backbone.ConvBlock.maml = True backbone.SimpleBlock.maml = True backbone.BottleneckBlock.maml = True backbone.ResNet.maml = True model = MAML(model_dict[params.model], approx=(params.method == "maml_approx"), **few_shot_params) if params.dataset in [ "omniglot", "cross_char", ]: # maml use different parameter in omniglot model.n_task = 32 model.task_update_num = 1 model.train_lr = 0.1 else: raise ValueError("Unknown method") model = model.cuda() checkpoint_dir = "%s/checkpoints/%s/%s_%s" % ( configs.save_dir,
n_support=params.n_shot) datamgr = SetDataManager(image_size, n_query=n_query, n_eposide=n_task, **test_few_shot_params) # model print('\n--- build MAML model ---') print(' test with model: %s' % params.model) params.tf_dir = '%s/log/%s' % (params.save_dir, params.name) params.checkpoint_dir = '%s/checkpoints/%s' % (params.save_dir, params.name) if not os.path.isdir(params.checkpoint_dir): os.makedirs(params.checkpoint_dir) model = MAML(params, tf_path=params.tf_dir) model.cuda() print('\nStage 2: evaluate') # resume model if params.save_epoch != -1: modelfile = get_assigned_file(params.checkpoint_dir, params.save_epoch) else: modelfile = get_best_file(params.checkpoint_dir) print(" load model: %s" % modelfile) # start evaluate print('\n--- start the testing ---') n_exp = params.n_exp n_iter = params.n_iter tf_path = '%s/log_test/%s_iter_%s_%s' % (params.save_dir, params.name,