def load_mask_detector(path_model='/content/drive/MyDrive/frinks/fewShots/CloserLookFewShot/checkpoints_masks_Conv4_baseline_aug/20.tar',
                       path_data='/content/drive/MyDrive/frinks/Faces/data'):
    """Load a face-mask detection model using the backend selected by the
    module-level ``flag`` / ``flag_fewShots`` globals.

    Args:
        path_model: checkpoint path for the selected backend. Defaults to the
            few-shot Conv4 baseline checkpoint. (A fastai alternative used
            previously: '/content/drive/MyDrive/frinks/models/fastai_resnet101'.)
        path_data: image folder used only by the fastai branch to build the
            DataBunch the learner is constructed from.

    Returns:
        The loaded model object (type depends on the selected backend).

    NOTE(review): relies on module-level globals ``flag`` and ``flag_fewShots``
    being defined before this is called — confirm at the call site.
    """
    # Bug fix: the original compared strings with `is` (identity), which is
    # unreliable across interpreters and a SyntaxWarning since Python 3.8.
    if flag == 'torch':
        if not flag_fewShots:
            # Plain serialized torch model.
            model = torch.load(path_model)
        if flag_fewShots:
            # Few-shot baseline: rebuild the architecture, then restore weights.
            model = BaselineTrain(backbone.Conv4, 4)
            model_dict = torch.load(path_model)
            model.load_state_dict(model_dict['state'])
            model = model.cuda()
    elif flag == 'fastai':
        data = ImageDataBunch.from_folder(path_data, valid_pct=0.2, size=120)
        model = cnn_learner(data, models.resnet101, metrics=error_rate)
        model.load(path_model)
    else:
        # Fallback backend (project-defined loader).
        model = LoadModel(path_model)
    return model
def get_model(params):
    """Build a baseline / baseline++ model and optionally restore weights.

    Args:
        params: argparse namespace with (at least) method, dataset, model,
            num_classes, resume, warmup, train_aug, checkpoint_dir, start_epoch.

    Returns:
        The constructed (and possibly checkpoint-initialized) model.

    Raises:
        ValueError: for an unknown ``params.method`` or a missing warm-up file.
    """
    if params.method in ['baseline', 'baseline++']:
        # Base-class label ids are used directly as classifier targets, so the
        # classifier must be at least as wide as the largest label id.
        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'
        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            # baseline++ replaces the linear classifier with a cosine-distance one.
            model = BaselineTrain(model_dict[params.model], params.num_classes, loss_type='dist')
    else:
        raise ValueError('Unknown method')
    if params.resume:
        # Resume full training state (model weights + epoch counter).
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            params.start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  # We also support warmup from pretrained baseline feature, but we never used in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        tmp = torch.load(warmup_resume_file)
        if tmp is not None:
            state = tmp['state']
            state_keys = list(state.keys())
            # Keep only the feature-extractor weights, renamed to match the
            # bare backbone's parameter names; drop classifier weights.
            for i, key in enumerate(state_keys):
                if "feature." in key:
                    newkey = key.replace(
                        "feature.", ""
                    )  # an architecture model has attribute 'feature', load architecture feature to backbone by casting name from 'feature.trunk.xx' to 'trunk.xx'
                    state[newkey] = state.pop(key)
                else:
                    state.pop(key)
            model.feature.load_state_dict(state)
        else:
            raise ValueError('No warm_up file')
    return model
val_datamgr = SimpleDataManager(image_size) val_loader = val_datamgr.get_data_loader(val_file, batch_size=64, aug=False) if params.dataset == "omniglot": assert ( params.num_classes >= 4112 ), "class number need to be larger than max label id in base class" if params.dataset == "cross_char": assert ( params.num_classes >= 1597 ), "class number need to be larger than max label id in base class" if params.method == "baseline": model = BaselineTrain(model_dict[params.model], params.num_classes) elif params.method == "baseline++": model = BaselineTrain(model_dict[params.model], params.num_classes, loss_type="dist") elif params.method in [ "protonet", "matchingnet", "relationnet", "relationnet_softmax", "maml", "maml_approx", ]: n_query = max( 1, int(16 * params.test_n_way / params.train_n_way)
base_datamgr = caltech256_few_shot.SimpleDataManager(image_size, batch_size=16) base_loader = base_datamgr.get_data_loader(aug=False) params.num_classes = 257 elif params.dataset == "DTD": base_datamgr = DTD_few_shot.SimpleDataManager(image_size, batch_size=16) base_loader = base_datamgr.get_data_loader(aug=True) else: raise ValueError('Unknown dataset') #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') #print(device) model = BaselineTrain(model_dict[params.model], params.num_classes) elif params.method in [ 'dampnet_full_class', 'dampnet_full_sparse', 'protonet_damp', 'maml', 'relationnet', 'dampnet_full', 'dampnet', 'protonet', 'gnnnet', 'gnnnet_maml', 'metaoptnet', 'gnnnet_normalized', 'gnnnet_neg_margin' ]: n_query = max( 1, int(16 * params.test_n_way / params.train_n_way) ) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot) test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
image_size = 224 optimization = 'Adam' if params.method in ['baseline']: if params.dataset == "miniImageNet": base_file = configs.data_dir['miniImagenet'] + 'base.json' base_datamgr = SimpleDataManager(image_size, batch_size=16) base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug) else: raise ValueError('Unknown dataset') model = BaselineTrain(model_dict[params.model], params.num_classes) elif params.method in ['protonet']: n_query = max( 1, int(16 * params.test_n_way / params.train_n_way) ) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot) test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) if params.dataset == "miniImageNet": base_file = configs.data_dir['miniImagenet'] + 'base.json' base_datamgr = SetDataManager(image_size, n_query=n_query,
image_size = 224 optimization = 'Adam' if params.method in ['baseline', 'myModel']: if params.dataset == "miniImageNet": datamgr = miniImageNet_few_shot.SimpleDataManager(image_size, batch_size=180) base_loader = datamgr.get_data_loader(aug=params.train_aug) val_loader = None else: raise ValueError('Unknown dataset') if params.method == 'baseline': model = BaselineTrain(model_dict[params.model], params.num_classes) else: model = MyModelTrain(model_dict[params.model], params.num_classes, params.margin, params.embed_dim, params.logit_scale) elif params.method in ['protonet', 'myprotonet']: n_query = max( 1, int(16 * params.test_n_way / params.train_n_way) ) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot) test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) if params.dataset == "miniImageNet":
def main_train(params):
    """End-to-end training driver: pick data files, build loaders and model
    for ``params.method``, optionally resume/warm-up, then train and log.

    Args:
        params: argparse namespace (dataset, model, method, n_shot,
            train/test_n_way, num_classes, train_aug, resume, warmup,
            start_epoch, stop_epoch, ...).

    Raises:
        ValueError: for an unknown method or a missing warm-up checkpoint.
    """
    _set_seed(params)
    results_logger = ResultsLogger(params)
    # Select base (train) / val split files; 'cross' variants mix datasets.
    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'
    # Conv backbones use small inputs; ResNet-style backbones use 224.
    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        image_size = 224
    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        # Conv4S is the single-channel (grayscale) variant of Conv4.
        params.model = 'Conv4S'
    optimization = 'Adam'
    # stop_epoch == -1 means "use per-method/per-dataset defaults".
    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                # This is different as stated in the open-review paper.
                # However, using 400 epoch in baseline actually lead to over-fitting
                params.stop_epoch = 200
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  # default
        else:  # meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 600  # default
    if params.method in ['baseline', 'baseline++']:
        # Plain mini-batch loaders for standard classification pre-training.
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # Classifier must cover the largest base-class label id.
        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'
        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model], params.num_classes, loss_type='dist')
    elif params.method in [
            'DKT', 'protonet', 'matchingnet', 'relationnet',
            'relationnet_softmax', 'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
        train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size, n_query=n_query, **train_few_shot_params)  # n_eposide=100
        base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
        test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size, n_query=n_query, **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor
        if (params.method == 'DKT'):
            model = DKT(model_dict[params.model], **train_few_shot_params)
            model.init_summary()
        elif params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model], **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            # RelationNet needs unflattened (spatial) features -> *NP variants.
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
            model = RelationNet(feature_model, loss_type=loss_type, **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            # Switch backbone building blocks into MAML (fast-weight) mode.
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
            if params.dataset in ['omniglot', 'cross_char']:  # maml use different parameter in omniglot
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')
    model = model.cuda()
    # Checkpoint directory encodes dataset/backbone/method (+aug, +way/shot).
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # maml use multiple tasks in one update
    if params.resume:
        # Resume full model state from the latest checkpoint in this run's dir.
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  # We also support warmup from pretrained baseline feature, but we never used in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        tmp = torch.load(warmup_resume_file)
        if tmp is not None:
            state = tmp['state']
            state_keys = list(state.keys())
            # Keep only feature-extractor weights, renamed for the bare backbone.
            for i, key in enumerate(state_keys):
                if "feature." in key:
                    newkey = key.replace(
                        "feature.", ""
                    )  # an architecture model has attribute 'feature', load architecture feature to backbone by casting name from 'feature.trunk.xx' to 'trunk.xx'
                    state[newkey] = state.pop(key)
                else:
                    state.pop(key)
            model.feature.load_state_dict(state)
        else:
            raise ValueError('No warm_up file')
    model = train(base_loader, val_loader, model, optimization, start_epoch,
                  stop_epoch, params, results_logger)
    results_logger.save()
else: print(f'{key} will be removed') del state[key] msg = model_moco.load_state_dict(state, strict=False) assert len(msg.missing_keys) == 0 and len(msg.unexpected_keys) == 0, "loading model is wrong" # get bottom of ResNet encoder = moco.ResNetBottom(model_moco.encoder_q) if params.method in ['baseline', 'baseline++'] : base_datamgr = SimpleDataManager(image_size, batch_size = 64) base_loader = base_datamgr.get_data_loader( base_file , aug = params.train_aug ) val_datamgr = SimpleDataManager(image_size, batch_size = 256) val_loader = val_datamgr.get_data_loader( val_file, aug = False) if params.method == 'baseline': model = BaselineTrain( model_dict[params.model], params.num_classes) elif params.method == 'baseline++': # model = BaselineTrain( model_dict[params.model], params.num_classes, loss_type = 'dist') model = BaselineTrain( encoder, params.num_classes, loss_type = 'dist') elif params.method in ['protonet','matchingnet','relationnet', 'relationnet_softmax', 'maml', 'maml_approx']: n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small train_few_shot_params = dict(n_way = params.train_n_way, n_support = params.n_shot) base_datamgr = SetDataManager(image_size, n_query = n_query, **train_few_shot_params) base_loader = base_datamgr.get_data_loader( base_file , aug = params.train_aug ) test_few_shot_params = dict(n_way = params.test_n_way, n_support = params.n_shot) val_datamgr = SetDataManager(image_size, n_query = n_query, **test_few_shot_params) val_loader = val_datamgr.get_data_loader( val_file, aug = False) #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor
params.stop_epoch = 40 #default else: #meta-learning methods params.stop_epoch = 60 #default if params.method in ['baseline', 'baseline++']: base_datamgr = SimpleDataManager(batch_size=16) base_loader = base_datamgr.get_data_loader( root='./filelists/tabula_muris', mode='train') val_datamgr = SimpleDataManager(batch_size=64) val_loader = val_datamgr.get_data_loader( root='./filelists/tabula_muris', mode='val') x_dim = base_loader.dataset.get_dim() if params.method == 'baseline': model = BaselineTrain(backbone.FCNet(x_dim), params.num_classes) elif params.method == 'baseline++': model = BaselineTrain(backbone.FCNet(x_dim), params.num_classes, loss_type='dist') elif params.method in [ 'protonet', 'comet', 'matchingnet', 'relationnet', 'relationnet_softmax', 'maml', 'maml_approx' ]: n_query = max( 1, int(16 * params.test_n_way / params.train_n_way) ) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot) base_datamgr = SetDataManager(n_query=n_query, **train_few_shot_params)
base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug) test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) val_datamgr = SetDataManager(image_size, n_query=n_query, **test_few_shot_params) val_loader = val_datamgr.get_data_loader(val_file, aug=False) if params.method == 'manifold_mixup': model = wrn_mixup_model.wrn28_10(64) elif params.method == 'S2M2_R': model = ProtoNet(model_dict[params.model], params.train_n_way, params.n_shot) elif params.method == 'rotation': model = BaselineTrain(model_dict[params.model], 64, loss_type='dist') if params.method == 'S2M2_R': if use_gpu: if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model, device_ids=range( torch.cuda.device_count())) model.cuda() if params.resume: resume_file = get_resume_file(params.checkpoint_dir) print("resume_file", resume_file) tmp = torch.load(resume_file) start_epoch = tmp['epoch'] + 1
datamgr = miniImageNet_few_shot.SimpleDataManager(image_size, batch_size=16) print( "datamgr = miniImageNet_few_shot.SimpleDataManager(image_size, batch_size = 16)" ) base_loader = datamgr.get_data_loader( aug=params.train_aug) #waste lots of time print( "base_loader = datamgr.get_data_loader(aug = params.train_aug )" ) val_loader = None print("load miniIMageNet [END]") else: raise ValueError('Unknown dataset') print("load model [START]") model = BaselineTrain(model_dict[params.model], params.num_classes) print("load model [END]") elif params.method in ['protonet']: n_query = max( 1, int(16 * params.test_n_way / params.train_n_way) ) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot) test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) if params.dataset == "miniImageNet": datamgr = miniImageNet_few_shot.SetDataManager( image_size,
start_epoch = params.start_epoch stop_epoch = params.stop_epoch base_datamgr = SimpleDataManager(image_size, batch_size=params.batch_size) base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug) val_datamgr = SimpleDataManager(image_size, batch_size=params.test_batch_size) val_loader = base_datamgr.get_data_loader(base_file, aug=False) if params.method == 'manifold_mixup': model = wrn_mixup_model.wrn28_10(64, 0.9) elif params.method == 'S2M2_R': model = wrn_mixup_model.wrn28_10(64, 0.9) elif params.method == 'rotation': model = BaselineTrain(model_dict[params.model], 64, dropRate=0.9, loss_type='dist') if params.method == 'S2M2_R': if use_gpu: if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model, device_ids=range( torch.cuda.device_count())) model.cuda() if params.resume: resume_file = get_resume_file(params.checkpoint_dir) print("resume_file", resume_file) tmp = torch.load(resume_file)
def select_model(params):
    """Select and construct the model specified by ``params``.

    For baseline methods this also fixes ``params.num_classes`` from the
    dataset name; for meta-learning methods it assembles the few-shot
    episode parameters and builds the corresponding meta-learner.

    Args:
        params: argparse namespace (method, dataset, model, train_n_way,
            n_shot, jigsaw, lbda, rotation, tracking, ...).

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``params.method`` is not recognized.
    """
    if params.method in ['baseline', 'baseline++']:
        # Number of base classes per dataset (classifier output width).
        if params.dataset == 'CUB':
            params.num_classes = 200
        elif params.dataset == 'cars':
            params.num_classes = 196
        elif params.dataset == 'aircrafts':
            params.num_classes = 100
        elif params.dataset == 'dogs':
            params.num_classes = 120
        elif params.dataset == 'flowers':
            params.num_classes = 102
        elif params.dataset == 'miniImagenet':
            params.num_classes = 100
        elif params.dataset == 'tieredImagenet':
            params.num_classes = 608
        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes,
                                  jigsaw=params.jigsaw, lbda=params.lbda,
                                  rotation=params.rotation, tracking=params.tracking)
        elif params.method == 'baseline++':
            # baseline++ uses a cosine-distance classifier head.
            model = BaselineTrain(model_dict[params.model], params.num_classes,
                                  loss_type='dist', jigsaw=params.jigsaw, lbda=params.lbda,
                                  rotation=params.rotation, tracking=params.tracking)
    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        # Episode configuration shared by all meta-learning constructors,
        # including auxiliary self-supervision options (jigsaw/rotation).
        train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot,
                                     jigsaw=params.jigsaw, lbda=params.lbda,
                                     rotation=params.rotation)
        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params,
                             use_bn=(not params.no_bn), pretrain=params.pretrain,
                             tracking=params.tracking)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model], **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            # RelationNet consumes unflattened (spatial) feature maps.
            feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
            model = RelationNet(feature_model, loss_type=loss_type, **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            # Switch every backbone building block into MAML (fast-weight) mode.
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            BasicBlock.maml = True
            Bottleneck.maml = True
            ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
    else:
        raise ValueError('Unknown method')
    return model
if 'Conv' in params.model: image_size = 84 else: image_size = 224 if params.method in ['baseline', 'baseline++']: print(' pre-training the feature encoder {} using method {}'.format( params.model, params.method)) base_datamgr = SimpleDataManager(image_size, batch_size=16) base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug) val_datamgr = SimpleDataManager(image_size, batch_size=64) val_loader = val_datamgr.get_data_loader(val_file, aug=False) if params.method == 'baseline': model = BaselineTrain(model_dict[params.model], params.num_classes, tf_path=params.tf_dir) elif params.method == 'baseline++': model = BaselineTrain(model_dict[params.model], params.num_classes, loss_type='dist', tf_path=params.tf_dir) elif params.method in [ 'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax', 'gnnnet' ]: print( ' baseline training the model {} with feature encoder {}'.format( params.method, params.model))
print('Running up to {} epochs'.format(params.stop_epoch)) device = torch.device("cuda:0" if torch.cuda.device_count() > 0 else "cpu") if params.method in ['baseline', 'baseline++'] : base_datamgr = SimpleDataManager(image_size, batch_size = 16) base_loader = base_datamgr.get_data_loader( base_file , aug = params.train_aug ) val_datamgr = SimpleDataManager(image_size, batch_size = 64) val_loader = val_datamgr.get_data_loader( val_file, aug = False) if params.dataset == 'omniglot': assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class' if params.dataset == 'cross_char': assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class' if params.method == 'baseline': model = BaselineTrain( model_dict[params.model], params.num_classes, device=device) elif params.method == 'baseline++': model = BaselineTrain( model_dict[params.model], params.num_classes, loss_type = 'dist', init_orthogonal = params.ortho, ortho_reg = params.ortho_reg, device=device) elif params.method in ['protonet','matchingnet','relationnet', 'relationnet_softmax', 'maml', 'maml_approx']: n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small train_few_shot_params = dict(n_way = params.train_n_way, n_support = params.n_shot) base_datamgr = SetDataManager(image_size, n_query = n_query, **train_few_shot_params) base_loader = base_datamgr.get_data_loader( base_file , aug = params.train_aug ) test_few_shot_params = dict(n_way = params.test_n_way, n_support = params.n_shot) val_datamgr = SetDataManager(image_size, n_query = n_query, **test_few_shot_params) val_loader = val_datamgr.get_data_loader( val_file, aug = False) #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor
params.num_classes = 200 elif params.dataset == 'cars_original': params.num_classes = 196 elif params.dataset == 'aircrafts_original': params.num_classes = 100 elif params.dataset == 'dogs_original': params.num_classes = 120 elif params.dataset == 'flowers_original': params.num_classes = 102 elif params.dataset == 'miniImagenet': params.num_classes = 100 elif params.dataset == 'tieredImagenet': params.num_classes = 608 if params.method == 'baseline': model = BaselineTrain( model_dict[params.model], params.num_classes, \ jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation, tracking=params.tracking) elif params.method == 'baseline++': model = BaselineTrain( model_dict[params.model], params.num_classes, \ loss_type = 'dist', jigsaw=params.jigsaw, lbda=params.lbda, rotation=params.rotation, tracking=params.tracking) elif params.method in [ 'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax', 'maml', 'maml_approx' ]: n_query = max( 1, int(params.n_query * params.test_n_way / params.train_n_way) ) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small print('n_query:', n_query) base_datamgr_u = SimpleDataManager(image_size, batch_size=params.bs,
def main():
    """Entry point: build data loaders, model, and optimizer from CLI args,
    then either run evaluation (``args.test``) or the train/val loop with
    best-model checkpointing and a final test pass.
    """
    timer = Timer()
    args, writer = init()
    train_file = args.dataset_dir + 'train.json'
    val_file = args.dataset_dir + 'val.json'
    few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot, n_query=args.n_query)
    # Fewer episodes per epoch in debug mode for quick iteration.
    n_episode = 10 if args.debug else 100
    if args.method_type is Method_type.baseline:
        # Baseline pre-trains with plain mini-batches, not episodes.
        train_datamgr = SimpleDataManager(train_file, args.dataset_dir, args.image_size, batch_size=64)
        train_loader = train_datamgr.get_data_loader(aug=True)
    else:
        train_datamgr = SetDataManager(train_file, args.dataset_dir, args.image_size,
                                       n_episode=n_episode, mode='train', **few_shot_params)
        train_loader = train_datamgr.get_data_loader(aug=True)
    # Validation is always episodic so val accuracy is comparable across methods.
    val_datamgr = SetDataManager(val_file, args.dataset_dir, args.image_size,
                                 n_episode=n_episode, mode='val', **few_shot_params)
    val_loader = val_datamgr.get_data_loader(aug=False)
    if args.model_type is Model_type.ConvNet:
        # NOTE(review): no encoder is assigned on this path, so the later
        # model construction would raise NameError — confirm ConvNet support.
        pass
    elif args.model_type is Model_type.ResNet12:
        from methods.backbone import ResNet12
        encoder = ResNet12()
    else:
        raise ValueError('')
    if args.method_type is Method_type.baseline:
        from methods.baselinetrain import BaselineTrain
        model = BaselineTrain(encoder, args)
    elif args.method_type is Method_type.protonet:
        from methods.protonet import ProtoNet
        model = ProtoNet(encoder, args)
    else:
        raise ValueError('')
    from torch.optim import SGD, lr_scheduler
    if args.method_type is Method_type.baseline:
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9,
                        weight_decay=args.weight_decay)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch,
                                                   eta_min=0, last_epoch=-1)
    else:
        optimizer = torch.optim.SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9,
                                    weight_decay=5e-4, nesterov=True)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size,
                                                    gamma=0.5)
    args.ngpu = torch.cuda.device_count()
    torch.backends.cudnn.benchmark = True
    model = model.cuda()
    # Episode labels are fixed (n_way classes x n_query each), so build once.
    label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
    label = label.cuda()
    if args.test:
        test(model, label, args, few_shot_params)
        return
    if args.resume:
        resume_OK = resume_model(model, optimizer, args, scheduler)
    else:
        resume_OK = False
    # Warm-start from pretrained weights only when resume failed/was skipped.
    if (not resume_OK) and (args.warmup is not None):
        load_pretrained_weights(model, args)
    if args.debug:
        args.max_epoch = args.start_epoch + 1
    for epoch in range(args.start_epoch, args.max_epoch):
        train_one_epoch(model, optimizer, args, train_loader, label, writer, epoch)
        scheduler.step()
        vl, va = val(model, args, val_loader, label)
        if writer is not None:
            writer.add_scalar('data/val_acc', float(va), epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))
        # Track and checkpoint the best validation accuracy seen so far.
        if va >= args.max_acc:
            args.max_acc = va
            print('saving the best model! acc={:.4f}'.format(va))
            save_model(model, optimizer, args, epoch, args.max_acc, 'max_acc', scheduler)
        # Always keep a rolling last-epoch checkpoint for resuming.
        save_model(model, optimizer, args, epoch, args.max_acc, 'epoch-last', scheduler)
        if epoch != 0:
            print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
    if writer is not None:
        writer.close()
    test(model, label, args, few_shot_params)
base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug) base_datamgr_test = SimpleDataManager( image_size, batch_size=params.test_batch_size) base_loader_test = base_datamgr_test.get_data_loader(base_file, aug=False) test_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot) val_datamgr = SetDataManager(image_size, n_query=15, **test_few_shot_params) val_loader = val_datamgr.get_data_loader(val_file, aug=False) if params.method == 'baseline++': model = BaselineTrain(model_dict[params.model], params.num_classes, loss_type='dist') elif params.method == 'manifold_mixup': if params.model == 'WideResNet28_10': model = wrn_mixup_model.wrn28_10(params.num_classes) elif params.model == 'ResNet18': model = res_mixup_model.resnet18( num_classes=params.num_classes) elif params.method == 'S2M2_R' or 'rotation': if params.model == 'WideResNet28_10': model = wrn_mixup_model.wrn28_10( num_classes=params.num_classes, dct_status=params.dct_status) elif params.model == 'ResNet18':
def get_model(params, mode):
    '''Construct the train- or test-time model for ``params.method``,
    including decoder/min-gram variants of baseline and ProtoNet.

    Args:
        params: argparse params
        mode: (str), 'train', 'test'

    Returns:
        The constructed model.

    Raises:
        ValueError: for an incompatible decoder choice or unexpected method.
    '''
    print('get_model() start...')
    # few_shot_params_d = get_few_shot_params(params, None)
    # few_shot_params = few_shot_params_d[mode]
    few_shot_params = get_few_shot_params(params, mode)
    if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
        # if params.dataset in ['omniglot', 'cross_char', 'cross_char_half', 'cross_char_quarter', ...]:
        # assert params.model == 'Conv4' and not params.train_aug ,'omniglot only support Conv4 without augmentation'
        assert 'Conv4' in params.model and not params.train_aug, 'omniglot/cross_char only support Conv4 without augmentation'
        params.model = params.model.replace('Conv4', 'Conv4S')  # because Conv4Drop should also be Conv4SDrop
        if params.recons_decoder is not None:
            # Single-channel datasets need the grayscale (ConvS) decoders.
            if 'ConvS' not in params.recons_decoder:
                raise ValueError('omniglot / cross_char should use ConvS/HiddenConvS decoder.')
    # if mode == 'train':
    #     params.num_classes = n_base_class_map[params.dataset]
    if params.method in ['baseline', 'baseline++'] and mode == 'train':
        # Classifier must be at least as wide as the dataset's base-class count.
        assert params.num_classes >= n_base_classes[params.dataset]
        # if params.dataset == 'omniglot': # 4112/688/1692
        #     assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        # if params.dataset == 'cross_char': # 1597/31/31
        #     assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'
        # if params.dataset == 'cross_char_half': # 758/31/31
        #     assert params.num_classes >= 758, 'class number need to be larger than max label id in base class'
        # if params.dataset in ['cross_char_quarter', 'cross_char_quarter_10shot']: # 350/31/31
        #     assert params.num_classes >= 350, 'class number need to be larger than max label id in base class'
        # if params.dataset == 'cross_char_base3lang': # 69/31/31
        #     assert params.num_classes >= 69, 'class number need to be larger than max label id in base class'
        # if params.dataset == 'miniImagenet': # 64/16/20
        #     assert params.num_classes >= 64, 'class number need to be larger than max label id in base class'
        # if params.dataset == 'CUB': # 100/50/50
        #     assert params.num_classes >= 100, 'class number need to be larger than max label id in base class'
        # if params.dataset == 'cross': # 64+16+20/50/50
        #     assert params.num_classes >= 100, 'class number need to be larger than max label id in base class'
        # if params.dataset == 'cross_base80cl': # 80/50/50
        #     assert params.num_classes >= 100, 'class number need to be larger than max label id in base class'
    if params.recons_decoder == None:
        print('params.recons_decoder == None')
        recons_decoder = None
    else:
        # Look up the reconstruction-decoder class by name.
        recons_decoder = decoder_dict[params.recons_decoder]
        print('recons_decoder:\n', recons_decoder)
    backbone_func = get_backbone_func(params)
    if 'baseline' in params.method:
        loss_types = {
            'baseline': 'softmax',
            'baseline++': 'dist',
        }
        loss_type = loss_types[params.method]
        if recons_decoder is None and params.min_gram is None:  # default baseline/baseline++
            if mode == 'train':
                model = BaselineTrain(model_func=backbone_func, loss_type=loss_type,
                                      num_class=params.num_classes, **few_shot_params)
            elif mode == 'test':
                # Test-time baseline fine-tunes a fresh classifier per episode.
                model = BaselineFinetune(model_func=backbone_func, loss_type=loss_type,
                                         **few_shot_params,
                                         finetune_dropout_p=params.finetune_dropout_p)
        else:  # other settings for baseline
            if params.min_gram is not None:
                min_gram_params = {
                    'min_gram': params.min_gram,
                    'lambda_gram': params.lambda_gram,
                }
                if mode == 'train':
                    model = BaselineTrainMinGram(model_func=backbone_func, loss_type=loss_type,
                                                 num_class=params.num_classes,
                                                 **few_shot_params, **min_gram_params)
                elif mode == 'test':
                    model = BaselineFinetune(model_func=backbone_func, loss_type=loss_type,
                                             **few_shot_params,
                                             finetune_dropout_p=params.finetune_dropout_p)
                    # model = BaselineFinetuneMinGram(backbone_func, loss_type = loss_type, **few_shot_params, **min_gram_params)
    elif params.method == 'protonet':
        if recons_decoder is None and params.min_gram is None:  # default ProtoNet
            model = ProtoNet(backbone_func, **few_shot_params)
        else:  # other settings
            if params.min_gram is not None:
                min_gram_params = {
                    'min_gram': params.min_gram,
                    'lambda_gram': params.lambda_gram,
                }
                model = ProtoNetMinGram(backbone_func, **few_shot_params, **min_gram_params)
            if params.recons_decoder is not None:
                if 'Hidden' in params.recons_decoder:
                    # 'Hidden*' decoders reconstruct from an intermediate layer;
                    # extract_layer selects which backbone layer feeds the decoder.
                    if params.recons_decoder == 'HiddenConv':  # 'HiddenConv', 'HiddenConvS'
                        model = ProtoNetAE2(backbone_func, **few_shot_params,
                                            recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda, extract_layer=2)
                    elif params.recons_decoder == 'HiddenConvS':  # 'HiddenConv', 'HiddenConvS'
                        model = ProtoNetAE2(backbone_func, **few_shot_params,
                                            recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda, extract_layer=2,
                                            is_color=False)
                    elif params.recons_decoder == 'HiddenRes10':
                        model = ProtoNetAE2(backbone_func, **few_shot_params,
                                            recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda, extract_layer=6)
                    elif params.recons_decoder == 'HiddenRes18':
                        model = ProtoNetAE2(backbone_func, **few_shot_params,
                                            recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda, extract_layer=8)
                else:
                    # Non-hidden decoders reconstruct from the final embedding;
                    # ConvS variants are grayscale.
                    if 'ConvS' in params.recons_decoder:
                        model = ProtoNetAE(backbone_func, **few_shot_params,
                                           recons_func=recons_decoder,
                                           lambda_d=params.recons_lambda, is_color=False)
                    else:
                        model = ProtoNetAE(backbone_func, **few_shot_params,
                                           recons_func=recons_decoder,
                                           lambda_d=params.recons_lambda, is_color=True)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone_func, **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        # if params.model == 'Conv4':
        #     feature_model = backbone.Conv4NP
        # elif params.model == 'Conv6':
        #     feature_model = backbone.Conv6NP
        # elif params.model == 'Conv4S':
        #     feature_model = backbone.Conv4SNP
        # else:
        #     feature_model = lambda: model_dict[params.model]( flatten = False )
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(backbone_func, loss_type=loss_type, **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        # Switch backbone building blocks into MAML (fast-weight) mode.
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(backbone_func, approx=(params.method == 'maml_approx'), **few_shot_params)
        if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
            # if params.dataset in ['omniglot', 'cross_char', 'cross_char_half']: #maml use different parameter in omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unexpected params.method: %s' % (params.method))
    print('get_model() finished.')
    return model
params.stop_epoch = 600 #default if params.method in ['baseline', 'baseline++']: base_datamgr = SimpleDataManager(image_size, batch_size=16) base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug) val_datamgr = SimpleDataManager(image_size, batch_size=64) val_loader = val_datamgr.get_data_loader(val_file, aug=False) if params.dataset == 'omniglot': assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class' if params.dataset == 'cross_char': assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class' if params.method == 'baseline': model = BaselineTrain(model_dict[params.model], params.num_classes) elif params.method == 'baseline++': model = BaselineTrain(model_dict[params.model], params.num_classes, loss_type='dist') elif params.method in [ 'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax', 'maml', 'maml_approx' ]: n_query = max( 1, int(16 * params.test_n_way / params.train_n_way) ) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot)