def load_mask_detector():
    ############### Change path to path of model ###################
    # path_model = '/content/drive/MyDrive/frinks/models/fastai_resnet101'
    path_model = '/content/drive/MyDrive/frinks/fewShots/CloserLookFewShot/checkpoints_masks_Conv4_baseline_aug/20.tar'
    path_data = '/content/drive/MyDrive/frinks/Faces/data'

    if flag == 'torch':
        if not flag_fewShots:
            model = torch.load(path_model)
        else:
            # few-shot checkpoint: rebuild the BaselineTrain wrapper and load its state dict
            model = BaselineTrain(backbone.Conv4, 4)
            checkpoint = torch.load(path_model)
            model.load_state_dict(checkpoint['state'])
            model = model.cuda()
    elif flag == 'fastai':
        data = ImageDataBunch.from_folder(path_data, valid_pct=0.2, size=120)
        model = cnn_learner(data, models.resnet101, metrics=error_rate)
        model.load(path_model)
    else:
        model = LoadModel(path_model)
    return model
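# --- Usage sketch (not from the original script) ------------------------------
# A minimal example of how load_mask_detector() might be driven. It assumes
# `flag` and `flag_fewShots` are module-level switches read by the function;
# the values below are illustrative, not taken from the original code.
flag = 'torch'        # 'torch', 'fastai', or anything else (falls back to LoadModel)
flag_fewShots = True  # True -> load a few-shot BaselineTrain checkpoint ('state' key)

detector = load_mask_detector()  # in this configuration, returns a CUDA model ready for inference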
def main_train(params):
    _set_seed(params)
    results_logger = ResultsLogger(params)

    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'

    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        image_size = 224

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
        params.model = 'Conv4S'

    optimization = 'Adam'

    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                # differs from the value stated in the OpenReview paper; 400 epochs for the baseline leads to over-fitting
                params.stop_epoch = 200
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  # default
        else:  # meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 600  # default

    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number needs to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number needs to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model], params.num_classes, loss_type='dist')

    elif params.method in ['DKT', 'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax', 'maml', 'maml_approx']:
        # if test_n_way is smaller than train_n_way, reduce n_query to keep the batch size small
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))

        train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size, n_query=n_query, **train_few_shot_params)  # n_eposide=100
        base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size, n_query=n_query, **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # a batch from SetDataManager is a [n_way, n_support + n_query, dim, w, h] tensor

        if params.method == 'DKT':
            model = DKT(model_dict[params.model], **train_few_shot_params)
            model.init_summary()
        elif params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model], **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
            model = RelationNet(feature_model, loss_type=loss_type, **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), **train_few_shot_params)
            if params.dataset in ['omniglot', 'cross_char']:  # MAML uses different parameters on omniglot
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method in ['maml', 'maml_approx']:
        stop_epoch = params.stop_epoch * model.n_task  # MAML performs multiple tasks in one update

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:
        # we also support warm-up from pretrained baseline features, but this was never used in the paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        tmp = torch.load(warmup_resume_file)
        if tmp is not None:
            state = tmp['state']
            state_keys = list(state.keys())
            for key in state_keys:
                if "feature." in key:
                    # the full model keeps its backbone under the 'feature' attribute, so rename
                    # 'feature.trunk.xx' to 'trunk.xx' before loading into the backbone alone
                    newkey = key.replace("feature.", "")
                    state[newkey] = state.pop(key)
                else:
                    state.pop(key)
            model.feature.load_state_dict(state)
        else:
            raise ValueError('No warm_up file')

    model = train(base_loader, val_loader, model, optimization, start_epoch, stop_epoch, params, results_logger)
    results_logger.save()
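# --- Invocation sketch (assumption, not the repo's own entry point) -----------
# main_train() only needs a namespace carrying the fields it reads above; the
# original project builds this with its own argument parser. The Namespace
# below is a hand-rolled stand-in with illustrative values, and helpers such
# as _set_seed() / ResultsLogger() may expect additional fields (e.g. a seed).
from argparse import Namespace

if __name__ == '__main__':
    params = Namespace(
        dataset='miniImagenet', model='Conv4', method='protonet',
        train_aug=True, n_shot=5, train_n_way=5, test_n_way=5,
        num_classes=200,                 # only used by baseline / baseline++
        start_epoch=0, stop_epoch=-1,    # -1 -> use the per-method default above
        resume=False, warmup=False, seed=0,
    )
    main_train(params)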
elif params.method in ['S2M2_R', 'rotation']:
    if params.model == 'WideResNet28_10':
        model = wrn_mixup_model.wrn28_10(num_classes=params.num_classes,
                                         dct_status=params.dct_status)
    elif params.model == 'ResNet18':
        model = res_mixup_model.resnet18(num_classes=params.num_classes)

if params.method == 'baseline++':
    if use_gpu:
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
        model.cuda()
    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        tmp = torch.load(resume_file)
        start_epoch = tmp['epoch'] + 1
        state = tmp['state']
        model.load_state_dict(state)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    optimization = 'Adam'
    model = train_baseline(base_loader, base_loader_test, val_loader, model,
                           start_epoch, start_epoch + stop_epoch, params, {})

elif params.method == 'S2M2_R':
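# --- Helper sketch (assumption) ------------------------------------------------
# The resume paths above rely on get_resume_file() to locate the latest
# checkpoint in a directory. A minimal version, assuming checkpoints are saved
# as '<epoch>.tar' files, might look like the following; the project's actual
# helper may differ.
import glob
import os

def get_resume_file_sketch(checkpoint_dir):
    filelist = glob.glob(os.path.join(checkpoint_dir, '*.tar'))
    epochs = [int(os.path.splitext(os.path.basename(f))[0])
              for f in filelist
              if os.path.splitext(os.path.basename(f))[0].isdigit()]
    if not epochs:
        return None  # nothing to resume from
    return os.path.join(checkpoint_dir, '%d.tar' % max(epochs))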
def main():
    timer = Timer()
    args, writer = init()

    train_file = args.dataset_dir + 'train.json'
    val_file = args.dataset_dir + 'val.json'

    few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot, n_query=args.n_query)
    n_episode = 10 if args.debug else 100

    if args.method_type is Method_type.baseline:
        train_datamgr = SimpleDataManager(train_file, args.dataset_dir, args.image_size, batch_size=64)
        train_loader = train_datamgr.get_data_loader(aug=True)
    else:
        train_datamgr = SetDataManager(train_file, args.dataset_dir, args.image_size,
                                       n_episode=n_episode, mode='train', **few_shot_params)
        train_loader = train_datamgr.get_data_loader(aug=True)

    val_datamgr = SetDataManager(val_file, args.dataset_dir, args.image_size,
                                 n_episode=n_episode, mode='val', **few_shot_params)
    val_loader = val_datamgr.get_data_loader(aug=False)

    if args.model_type is Model_type.ConvNet:
        pass  # ConvNet encoder is not set up in this excerpt
    elif args.model_type is Model_type.ResNet12:
        from methods.backbone import ResNet12
        encoder = ResNet12()
    else:
        raise ValueError('Unknown model_type')

    if args.method_type is Method_type.baseline:
        from methods.baselinetrain import BaselineTrain
        model = BaselineTrain(encoder, args)
    elif args.method_type is Method_type.protonet:
        from methods.protonet import ProtoNet
        model = ProtoNet(encoder, args)
    else:
        raise ValueError('Unknown method_type')

    from torch.optim import SGD, lr_scheduler
    if args.method_type is Method_type.baseline:
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9,
                        weight_decay=args.weight_decay)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch,
                                                   eta_min=0, last_epoch=-1)
    else:
        optimizer = torch.optim.SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9,
                                    weight_decay=5e-4, nesterov=True)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)

    args.ngpu = torch.cuda.device_count()
    torch.backends.cudnn.benchmark = True
    model = model.cuda()

    # episodic query labels: [0,...,0, 1,...,1, ..., n_way-1,...] with n_query repeats each
    label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
    label = label.cuda()

    if args.test:
        test(model, label, args, few_shot_params)
        return

    if args.resume:
        resume_OK = resume_model(model, optimizer, args, scheduler)
    else:
        resume_OK = False
    if (not resume_OK) and (args.warmup is not None):
        load_pretrained_weights(model, args)

    if args.debug:
        args.max_epoch = args.start_epoch + 1

    for epoch in range(args.start_epoch, args.max_epoch):
        train_one_epoch(model, optimizer, args, train_loader, label, writer, epoch)
        scheduler.step()

        vl, va = val(model, args, val_loader, label)
        if writer is not None:
            writer.add_scalar('data/val_acc', float(va), epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        if va >= args.max_acc:
            args.max_acc = va
            print('saving the best model! acc={:.4f}'.format(va))
            save_model(model, optimizer, args, epoch, args.max_acc, 'max_acc', scheduler)
        save_model(model, optimizer, args, epoch, args.max_acc, 'epoch-last', scheduler)

        if epoch != 0:
            print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))

    if writer is not None:
        writer.close()
    test(model, label, args, few_shot_params)
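# --- Entry-point sketch (assumption) -------------------------------------------
# init() is expected to parse the command line into `args` (dataset_dir,
# image_size, n_way, n_shot, n_query, method_type, model_type, lr,
# weight_decay, max_epoch, step_size, start_epoch, max_acc, debug, test,
# resume, warmup, ...) and to return a summary writer (or None). The guard
# below is the usual way such a script is launched; the original entry point
# may differ.
if __name__ == '__main__':
    main()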