def load_presaved_model_for_train(model, params):
    """Restore model weights before training (resume / warmup / explicit file).

    Args:
        model: the model to (partially) load weights into.
        params: run configuration; reads start_epoch, stop_epoch, method,
            resume, warmup, dataset, train_aug, checkpoint_dir, loadfile.

    Returns:
        (model, start_epoch, stop_epoch) tuple.

    Raises:
        ValueError: if params.warmup is set but no baseline checkpoint exists.
    """
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # maml use multiple tasks in one update

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
            del tmp
    elif params.warmup:
        # We also support warmup from pretrained baseline feature,
        # but we never used in our paper.
        baseline_checkpoint_dir = 'checkpoints/%s/%s_%s' % (
            params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        # BUGFIX: get_resume_file returns None when no checkpoint exists (the
        # resume branch above guards for this); torch.load(None) would raise a
        # confusing error before the original `if tmp is not None` check could
        # ever fire.  Check the path itself and raise the intended ValueError.
        if warmup_resume_file is None:
            raise ValueError('No warm_up file')
        tmp = torch.load(warmup_resume_file)
        state = tmp['state']
        for key in list(state.keys()):
            if "feature." in key:
                # an architecture model has attribute 'feature'; load the
                # backbone by renaming 'feature.trunk.xx' -> 'trunk.xx'
                state[key.replace("feature.", "")] = state.pop(key)
            else:
                state.pop(key)  # drop classifier / non-feature weights
        model.feature.load_state_dict(state)

    if params.loadfile != '':
        print('Loading model from: ' + params.loadfile)
        checkpoint = torch.load(params.loadfile)
        # remove last layer for baseline: keep everything except the
        # classifier head and the loss function parameters
        pretrained_dict = {
            k: v
            for k, v in checkpoint['state'].items()
            if 'classifier' not in k and 'loss_fn' not in k
        }
        model.load_state_dict(pretrained_dict, strict=False)
    return model, start_epoch, stop_epoch
def load_states(self, checkpoint_dir):
    """Restore every auxiliary loss-engine state from the latest checkpoint.

    The checkpoint is a dict; 'epoch' and 'state' belong to the model itself
    and are skipped — every other key names an entry in self.losses_engines.
    """
    checkpoint = torch.load(get_resume_file(checkpoint_dir))
    model_level_keys = ('epoch', 'state')
    for name, engine_state in checkpoint.items():
        if name in model_level_keys:
            continue
        self.losses_engines[name].load_state_dict(engine_state)
def _keep_feature_weights(state):
    """In-place: keep only backbone entries of a checkpoint state dict.

    An architecture model stores the backbone under attribute 'feature', so
    keys look like 'feature.trunk.xx'; rename them to 'trunk.xx' and drop
    everything else (classifier / loss-head weights).
    """
    for key in list(state.keys()):
        if "feature." in key:
            state[key.replace("feature.", "")] = state.pop(key)
        else:
            state.pop(key)


def load_weight_file_for_test(model, params):
    """Choose the weight file for the test process and load it into model.feature.

    Selection order: params.loadfile if given; otherwise an assigned iteration,
    the latest checkpoint (baseline methods) or the best checkpoint.
    Returns the model on GPU in eval mode.
    """
    if params.loadfile != '':
        modelfile = params.loadfile
        checkpoint_dir = params.loadfile
    else:
        checkpoint_dir = params.checkpoint_dir  # checkpoint path
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        elif params.method in ['baseline', 'baseline++']:
            modelfile = get_resume_file(checkpoint_dir)
        else:
            modelfile = get_best_file(checkpoint_dir)  # the best.tar file
    assert modelfile, "can not find model weight file in {}".format(
        checkpoint_dir)
    print("use model weight file: ", modelfile)

    # The original code duplicated the exact same prefix-stripping logic in
    # separate maml / non-maml branches (and re-checked `modelfile is not None`
    # after the assert above); both reduced to this single path.
    tmp = torch.load(modelfile)
    state = tmp['state']
    _keep_feature_weights(state)
    model.feature.load_state_dict(state)

    model = model.cuda()
    model.eval()
    return model
model = wrn_mixup_model.wrn28_10( num_classes=params.num_classes, dct_status=params.dct_status) elif params.model == 'ResNet18': model = res_mixup_model.resnet18( num_classes=params.num_classes) if params.method == 'baseline++': if use_gpu: if torch.cuda.device_count() > 1: model = torch.nn.DataParallel( model, device_ids=range(torch.cuda.device_count())) model.cuda() if params.resume: resume_file = get_resume_file(params.checkpoint_dir) tmp = torch.load(resume_file) start_epoch = tmp['epoch'] + 1 state = tmp['state'] model.load_state_dict(state) model = torch.nn.DataParallel(model).cuda() cudnn.benchmark = True optimization = 'Adam' model = train_baseline(base_loader, base_loader_test, val_loader, model, start_epoch, start_epoch + stop_epoch, params, {}) elif params.method == 'S2M2_R': if use_gpu: if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(
def test_loop(novel_loader,
              return_std=False,
              loss_type="softmax",
              n_query=15,
              models_to_use=[],  # NOTE(review): mutable default argument — shared across calls; confirm no caller mutates it
              finetune_each_model=False,
              n_way=5,
              n_support=5):  # overwrite parrent function
    """Episodic evaluation using an ensemble of pretrained backbones.

    For each novel episode: load one pretrained backbone per dataset name in
    models_to_use, optionally fine-tune each on the support set, extract
    per-layer pooled embeddings from every backbone, select the best embedding
    combination via train_selection(), train a fresh linear classifier on the
    concatenated support embeddings, and score the query set.  Prints per-task
    accuracy and the final mean with a 95% confidence interval.
    """
    correct = 0  # unused accumulator
    count = 0    # unused accumulator
    iter_num = len(novel_loader)
    acc_all = []
    for _, (x, y) in enumerate(novel_loader):
        ###############################################################################################
        # one fresh (untrained) backbone instance per model to ensemble
        pretrained_models = []
        for _ in range(len(models_to_use)):
            pretrained_models.append(model_dict[params.model]())
        ###############################################################################################
        # load each backbone's checkpoint, keeping only 'feature.*' weights
        for idx, dataset_name in enumerate(models_to_use):
            checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
                configs.save_dir, models_to_use[idx], params.model,
                params.method)
            if params.train_aug:
                checkpoint_dir += '_aug'
            params.save_iter = -1  # NOTE(review): forced to -1 here, so the branch below never calls get_assigned_file
            if params.save_iter != -1:
                modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
            elif params.method in ['baseline', 'baseline++']:
                modelfile = get_resume_file(checkpoint_dir)
            else:
                modelfile = get_best_file(checkpoint_dir)
            tmp = torch.load(modelfile)
            state = tmp['state']
            state_keys = list(state.keys())
            for _, key in enumerate(state_keys):
                if "feature." in key:
                    # an architecture model has attribute 'feature'; rename
                    # 'feature.trunk.xx' -> 'trunk.xx' to match the bare backbone
                    newkey = key.replace("feature.", "")
                    state[newkey] = state.pop(key)
                else:
                    state.pop(key)  # drop classifier / loss-head weights
            pretrained_models[idx].load_state_dict(state)
        ###############################################################################################
        # split the episode: first n_support images per class are support, rest are query
        n_query = x.size(1) - n_support
        x = x.cuda()
        x_var = Variable(x)
        batch_size = 4
        support_size = n_way * n_support
        ##################################################################################
        if finetune_each_model:
            # fine-tune every backbone on the support set with a throwaway classifier
            for idx, model_name in enumerate(pretrained_models):
                pretrained_models[idx].cuda()
                pretrained_models[idx].train()
                x_a_i = x_var[:, :n_support, :, :, :].contiguous().view(
                    n_way * n_support, *x.size()[2:])  # (25, 3, 224, 224)
                loss_fn = nn.CrossEntropyLoss().cuda()
                cnet = Classifier(pretrained_models[idx].final_feat_dim,
                                  n_way).cuda()
                classifier_opt = torch.optim.SGD(cnet.parameters(),
                                                 lr=0.01,
                                                 momentum=0.9,
                                                 dampening=0.9,
                                                 weight_decay=0.001)
                feature_opt = torch.optim.SGD(
                    pretrained_models[idx].parameters(),
                    lr=0.01,
                    momentum=0.9,
                    dampening=0.9,
                    weight_decay=0.001)
                x_a_i = Variable(x_a_i).cuda()
                # support labels: [0]*n_support + [1]*n_support + ...
                y_a_i = Variable(
                    torch.from_numpy(np.repeat(range(n_way),
                                               n_support))).cuda()  # (25,)
                train_size = support_size
                batch_size = 4
                for epoch in range(100):
                    rand_id = np.random.permutation(train_size)
                    for j in range(0, train_size, batch_size):
                        classifier_opt.zero_grad()
                        feature_opt.zero_grad()
                        #####################################
                        selected_id = torch.from_numpy(
                            rand_id[j:min(j + batch_size, train_size)]).cuda()
                        z_batch = x_a_i[selected_id]
                        y_batch = y_a_i[selected_id]
                        #####################################
                        outputs = pretrained_models[idx](z_batch)
                        outputs = cnet(outputs)
                        #####################################
                        loss = loss_fn(outputs, y_batch)
                        loss.backward()
                        # zero out every backbone gradient entry below the
                        # per-tensor median magnitude before stepping, so only
                        # the largest-gradient weights are updated
                        for k, param in enumerate(
                                pretrained_models[idx].parameters()):
                            param.grad[torch.lt(
                                torch.abs(param.grad),
                                torch.abs(param.grad).median())] = 0.0
                        classifier_opt.step()
                        feature_opt.step()
        ###############################################################################################
        for idx, model_name in enumerate(pretrained_models):
            pretrained_models[idx].cuda()
            pretrained_models[idx].eval()
        ###############################################################################################
        # support-set embeddings: run the support batch through each backbone's
        # trunk module-by-module, average-pooling every 4D intermediate output
        all_embeddings_train = []
        for idx, model_name in enumerate(pretrained_models):
            model_embeddings = []
            x_a_i = x_var[:, :n_support, :, :, :].contiguous().view(
                n_way * n_support, *x.size()[2:])  # (25, 3, 224, 224)
            # NOTE(review): the inner loop rebinds `idx`; correct here because
            # `pretrained_models[idx].trunk` is evaluated before rebinding, but fragile
            for idx, module in enumerate(pretrained_models[idx].trunk):
                x_a_i = module(x_a_i)
                if len(list(x_a_i.size())) == 4:
                    embedding = F.adaptive_avg_pool2d(x_a_i, (1, 1)).squeeze()
                    model_embeddings.append(embedding.detach())
            if params.model == "ResNet10" or params.model == "ResNet18":
                # keep only mid/late residual-stage embeddings
                model_embeddings = model_embeddings[4:-1]
            elif params.model == "Conv4":
                model_embeddings = model_embeddings
            all_embeddings_train.append(model_embeddings)
        ##########################################################
        # pick the best embedding/model combination from support embeddings
        y_a_i = np.repeat(range(n_way), n_support)
        embeddings_idx_of_each, embeddings_idx_model, embeddings_train, embeddings_best_of_each = train_selection(
            all_embeddings_train,
            y_a_i,
            support_size,
            n_support,
            n_way,
            with_replacement=True)
        ##########################################################
        # query-set embeddings, extracted exactly like the support ones above
        all_embeddings_test = []
        for idx, model_name in enumerate(pretrained_models):
            model_embeddings = []
            x_b_i = x_var[:, n_support:, :, :, :].contiguous().view(
                n_way * n_query, *x.size()[2:])
            for idx, module in enumerate(pretrained_models[idx].trunk):
                x_b_i = module(x_b_i)
                if len(list(x_b_i.size())) == 4:
                    embedding = F.adaptive_avg_pool2d(x_b_i, (1, 1)).squeeze()
                    model_embeddings.append(embedding.detach())
            if params.model == "ResNet10" or params.model == "ResNet18":
                model_embeddings = model_embeddings[4:-1]
            elif params.model == "Conv4":
                model_embeddings = model_embeddings
            all_embeddings_test.append(model_embeddings)
        ############################################################################################
        # concatenate the selected query embeddings along the feature dimension
        embeddings_test = []
        for index in embeddings_idx_model:
            embeddings_test.append(
                all_embeddings_test[index][embeddings_idx_of_each[index]])
        embeddings_test = torch.cat(embeddings_test, 1)
        ############################################################################################
        # train a fresh linear classifier on the selected support embeddings
        y_a_i = Variable(torch.from_numpy(np.repeat(
            range(n_way), n_support))).cuda()  # (25,)
        net = Classifier(embeddings_test.size()[1], n_way).cuda()
        loss_fn = nn.CrossEntropyLoss().cuda()
        classifier_opt = torch.optim.SGD(net.parameters(),
                                         lr=0.01,
                                         momentum=0.9,
                                         dampening=0.9,
                                         weight_decay=0.001)
        total_epoch = 100
        embeddings_train = Variable(embeddings_train.cuda())
        net.train()
        for epoch in range(total_epoch):
            rand_id = np.random.permutation(support_size)
            for j in range(0, support_size, batch_size):
                classifier_opt.zero_grad()
                #####################################
                selected_id = torch.from_numpy(
                    rand_id[j:min(j + batch_size, support_size)]).cuda()
                z_batch = embeddings_train[selected_id]
                y_batch = y_a_i[selected_id]
                #####################################
                outputs = net(z_batch)
                #####################################
                loss = loss_fn(outputs, y_batch)
                loss.backward()
                classifier_opt.step()
        # score the query set with the trained classifier (top-1 accuracy)
        embeddings_test = Variable(embeddings_test.cuda())
        scores = net(embeddings_test)
        y_query = np.repeat(range(n_way), n_query)
        topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
        topk_ind = topk_labels.cpu().numpy()
        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        correct_this, count_this = float(top1_correct), len(y_query)
        print(correct_this / count_this * 100)
        acc_all.append((correct_this / count_this * 100))
    ###############################################################################################
    # mean accuracy over all episodes with a 95% confidence interval
    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('%d Test Acc = %4.2f%% +- %4.2f%%' %
          (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
def __init__(self, module): super(WrappedModel, self).__init__() self.module = module # that I actually define. def forward(self, x): return self.module(x) model = backbone.WideResNet28_10( flatten = True, beta_value = 50.) checkpoint_dir = './checkpoints/%s/%s_%s_%s' %('cifar', 'WideResNet28_10', 'art' , 'cifar') model = WrappedModel(model) print("resuming" , checkpoint_dir) resume_file = get_resume_file(checkpoint_dir) if resume_file is not None: print("resume_file" , resume_file) tmp = torch.load(resume_file) model.load_state_dict(tmp['state']) else: print("error no file found") exit() model = model.cuda() model.eval() def normalize(x): mean = torch.tensor([0.4914, 0.4822, 0.4465]) std = torch.tensor([0.2023, 0.1994, 0.2010])
aug=params.train_aug) val_datamgr = SimpleDataManager(image_size, batch_size=64) val_loader = val_datamgr.get_data_loader(val_file, aug=False) model = SSL_Train(model_dict[params.model], params.num_classes) else: raise ValueError('Unknown method') model = model.cuda() #Prepare checkpoint_dir params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % ( configs.save_dir, params.dataset, params.model, params.method) if params.train_aug: params.checkpoint_dir += '_aug' if not os.path.isdir(params.checkpoint_dir): os.makedirs(params.checkpoint_dir) print('checkpoint_dir', params.checkpoint_dir) if params.resume: resume_file = get_resume_file(params.checkpoint_dir) if resume_file is not None: tmp = torch.load(resume_file) start_epoch = tmp['epoch'] + 1 model.load_state_dict(tmp['state']) start_epoch = params.start_epoch stop_epoch = params.stop_epoch model = train(base_loader, val_loader, model, optimization, start_epoch, stop_epoch, params)
lsl=args.lsl, language_model=lang_model, lang_supervision=args.lang_supervision, l3=args.l3, l3_model=l3_model, l3_n_infer=args.l3_n_infer) model = model.cuda() os.makedirs(args.checkpoint_dir, exist_ok=True) start_epoch = args.start_epoch stop_epoch = args.stop_epoch if args.resume: resume_file = get_resume_file(args.checkpoint_dir) if resume_file is not None: tmp = torch.load(resume_file) start_epoch = tmp["epoch"] + 1 model.load_state_dict(tmp["state"]) metrics_fname = "metrics_{}.json".format(args.n) train( base_loader, val_loader, model, start_epoch, stop_epoch, args, metrics_fname=metrics_fname,
def finetune(novel_loader, n_query = 15, pretrained_dataset='miniImageNet', freeze_backbone = False, n_way = 5, n_support = 5):
    """Per-episode transfer evaluation: load the pretrained backbone, train a
    new linear classifier on the support set (fine-tuning the backbone unless
    freeze_backbone), then measure top-1 accuracy on the query set.  Prints
    per-task accuracy and the final mean with a 95% confidence interval.
    """
    correct = 0  # unused accumulator
    count = 0    # unused accumulator
    iter_num = len(novel_loader)
    acc_all = []
    for _, (x, y) in enumerate(novel_loader):
        ###############################################################################################
        # load pretrained model on miniImageNet
        pretrained_model = model_dict[params.model]()
        checkpoint_dir = '%s/checkpoints/%s/%s_%s' %(configs.save_dir, pretrained_dataset, params.model, params.method)
        if params.train_aug:
            checkpoint_dir += '_aug'
        params.save_iter = -1  # NOTE(review): forced to -1, so get_assigned_file below is never reached
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        elif params.method in ['baseline', 'baseline++'] :
            modelfile = get_resume_file(checkpoint_dir)
        else:
            modelfile = get_best_file(checkpoint_dir)
        tmp = torch.load(modelfile)
        state = tmp['state']
        state_keys = list(state.keys())
        for _, key in enumerate(state_keys):
            if "feature." in key:
                # checkpoint stores the backbone under 'feature.'; rename
                # 'feature.trunk.xx' -> 'trunk.xx' to match the bare backbone
                newkey = key.replace("feature.","")
                state[newkey] = state.pop(key)
            else:
                state.pop(key)  # drop classifier / loss-head weights
        pretrained_model.load_state_dict(state)
        ###############################################################################################
        # fresh linear head per task: (backbone feature dim -> n_way)
        classifier = Classifier(pretrained_model.final_feat_dim, n_way)
        ###############################################################################################
        # split the episode: first n_support images per class are support, rest query
        n_query = x.size(1) - n_support
        x = x.cuda()
        x_var = Variable(x)

        batch_size = 4
        support_size = n_way * n_support

        # support labels: [0]*n_support + [1]*n_support + ...
        y_a_i = Variable( torch.from_numpy( np.repeat(range( n_way ), n_support ) )).cuda() # (25,)

        x_b_i = x_var[:, n_support:,:,:,:].contiguous().view( n_way* n_query,   *x.size()[2:])  # query images, flattened
        x_a_i = x_var[:,:n_support,:,:,:].contiguous().view( n_way* n_support, *x.size()[2:]) # support images (25, 3, 224, 224)

        ###############################################################################################
        loss_fn = nn.CrossEntropyLoss().cuda()
        classifier_opt = torch.optim.SGD(classifier.parameters(), lr = 0.01, momentum=0.9, dampening=0.9, weight_decay=0.001)

        if freeze_backbone is False:
            # separate optimizer for backbone fine-tuning
            delta_opt = torch.optim.SGD(filter(lambda p: p.requires_grad, pretrained_model.parameters()), lr = 0.01)

        pretrained_model.cuda()
        classifier.cuda()
        ###############################################################################################
        # train the classifier (and optionally the backbone) on the support set
        total_epoch = 100

        if freeze_backbone is False:
            pretrained_model.train()
        else:
            pretrained_model.eval()

        classifier.train()

        for epoch in range(total_epoch):
            rand_id = np.random.permutation(support_size)

            for j in range(0, support_size, batch_size):
                classifier_opt.zero_grad()
                if freeze_backbone is False:
                    delta_opt.zero_grad()

                #####################################
                selected_id = torch.from_numpy( rand_id[j: min(j+batch_size, support_size)]).cuda()

                z_batch = x_a_i[selected_id]
                y_batch = y_a_i[selected_id]
                #####################################

                output = pretrained_model(z_batch)
                output = classifier(output)
                loss = loss_fn(output, y_batch)

                #####################################
                loss.backward()

                classifier_opt.step()

                if freeze_backbone is False:
                    delta_opt.step()

        # evaluate on the query set (top-1 accuracy)
        pretrained_model.eval()
        classifier.eval()

        output = pretrained_model(x_b_i.cuda())
        scores = classifier(output)

        y_query = np.repeat(range( n_way ), n_query )
        topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
        topk_ind = topk_labels.cpu().numpy()

        top1_correct = np.sum(topk_ind[:,0] == y_query)
        correct_this, count_this = float(top1_correct), len(y_query)
        print (correct_this/ count_this *100)
        acc_all.append((correct_this/ count_this *100))

        ###############################################################################################

    # mean accuracy over all episodes with a 95% confidence interval
    acc_all  = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std  = np.std(acc_all)
    print('%d Test Acc = %4.2f%% +- %4.2f%%' %(iter_num,  acc_mean, 1.96* acc_std/np.sqrt(iter_num)))
def run(params):
    """End-to-end training entry point: resolve data files, build loaders and
    the model for params.method, prepare the checkpoint directory, optionally
    resume/warm-up, and launch training.

    Raises:
        ValueError: on an unknown params.method, or when params.warmup is set
            but no baseline checkpoint exists.
    """
    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        if params.base_json:
            print(f'Using base classes from {params.base_json}')
            base_file = params.base_json
        else:
            base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'

    image_size = get_image_size(params)

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    optimization = 'Adam'

    # -1 means "pick a default stop epoch for this method/dataset"
    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                # This is different as stated in the open-review paper.
                # However, using 400 epoch in baseline actually lead to over-fitting
                params.stop_epoch = 200
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  # default
        else:  # meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            # elif params.n_shot == 5:
            #     params.stop_epoch = 400
            else:
                params.stop_epoch = 600  # default for 5-shot

    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')
    elif params.method in [
            'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax',
            'maml', 'maml_approx'
    ]:
        # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
        if 'n_episode' not in params:
            if params.stop_epoch >= 500:
                # long runs: trade epochs for more episodes per epoch
                params.n_episode = 1000
                params.stop_epoch = int(params.stop_epoch / 10)
            else:
                params.n_episode = 100
        print(f'| Using {params.n_episode} n_episode for trainloader...')
        print(f'| Using Stop epoch {params.stop_epoch}')

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      n_episode=params.n_episode,
                                      **train_few_shot_params)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if params.method == 'protonet':
            model = ProtoNet(model_dict[params.model],
                             **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            # relation networks need a non-flattened (spatial) feature map
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
            model = RelationNet(feature_model,
                                loss_type=loss_type,
                                **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
            if params.dataset in ['omniglot', 'cross_char']:
                # maml use different parameter in omniglot
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()
    print(model)

    if hasattr(params, 'logdir'):
        params.checkpoint_dir = params.logdir
    else:
        params.checkpoint_dir = '%s/ckpts/%s/%s_%s_%s' % (
            configs.save_dir, params.dataset, params.model, params.method,
            params.base_json)
        if params.train_aug:
            params.checkpoint_dir += '_aug'
        if not params.method in ['baseline', 'baseline++']:
            params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                        params.n_shot)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # maml use multiple tasks in one update

    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:
        # We also support warmup from pretrained baseline feature,
        # but we never used in our paper.
        baseline_checkpoint_dir = '%s/ckpts/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        # BUGFIX: get_resume_file returns None when no checkpoint exists;
        # torch.load(None) raises before the original `if tmp is not None`
        # check could run.  Check the path and raise the intended ValueError.
        if warmup_resume_file is None:
            raise ValueError('No warm_up file')
        tmp = torch.load(warmup_resume_file)
        state = tmp['state']
        for key in list(state.keys()):
            if "feature." in key:
                # rename 'feature.trunk.xx' -> 'trunk.xx' to load into the backbone
                state[key.replace("feature.", "")] = state.pop(key)
            else:
                state.pop(key)  # drop classifier / non-feature weights
        model.feature.load_state_dict(state)

    model = train(base_loader, val_loader, model, optimization, start_epoch,
                  stop_epoch, params)
def meta_test(novel_loader,
              n_query=15,
              pretrained_dataset='miniImageNet',
              freeze_backbone=False,
              n_way=5,
              n_support=5):
    """Meta-testing: for each of the novel episodes, load the backbone
    pretrained on `pretrained_dataset`, train a fresh linear classifier on the
    support set for 100 epochs (batch size 4), optionally fine-tuning the
    backbone, then score the query set.  Prints per-task accuracy and the
    final mean with a 95% confidence interval.
    """
    correct = 0  # unused accumulator
    count = 0    # unused accumulator
    iter_num = len(novel_loader)  # number of episodes, e.g. 600
    acc_all = []
    for ti, (x, y) in enumerate(novel_loader):
        ###############################################################################################
        # load pretrained model on miniImageNet
        pretrained_model = model_dict[params.model]()
        checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, pretrained_dataset, params.model, params.method)
        if params.train_aug:
            checkpoint_dir += '_aug'
        params.save_iter = -1  # NOTE(review): forced to -1, so get_assigned_file below is never reached
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        elif params.method in ['baseline', 'baseline++']:
            modelfile = get_resume_file(checkpoint_dir)
        else:
            modelfile = get_best_file(checkpoint_dir)
        print(
            "load from %s" % (modelfile)
        )  # e.g. "./logs/checkpoints/miniImagenet/ResNet10_baseline_aug/399.tar"
        tmp = torch.load(modelfile)
        state = tmp['state']
        state_keys = list(state.keys())
        for _, key in enumerate(state_keys):
            if "feature." in key:
                # checkpoint stores the backbone under 'feature.'; rename
                # 'feature.trunk.xx' -> 'trunk.xx' to match the bare backbone
                newkey = key.replace("feature.", "")
                state[newkey] = state.pop(key)  # replace key name
            else:
                state.pop(key)  # remove classifier
        pretrained_model.load_state_dict(state)  # load checkpoints
        # train a new linear classifier per task: (final_feat_dim, n_way),
        # e.g. (512, 5) for ResNet10
        classifier = Classifier(pretrained_model.final_feat_dim, n_way)
        ###############################################################################################
        # split data into support set (n_support per class) and query set (rest)
        n_query = x.size(1) - n_support  # x: [n_way, n_support+n_query, C, H, W]
        x = x.cuda()
        x_var = Variable(x)
        # y is unused here: support/query labels are rebuilt with np.repeat below
        batch_size = 4
        support_size = n_way * n_support  # e.g. 5*5 = 25
        # support labels: [0]*n_support + [1]*n_support + ... -> shape (25,)
        y_a_i = Variable(torch.from_numpy(np.repeat(range(n_way),
                                                    n_support))).cuda()
        # query images flattened to (n_way*n_query, C, H, W), e.g. (75, 3, 224, 224)
        x_b_i = x_var[:, n_support:, :, :, :].contiguous().view(
            n_way * n_query, *x.size()[2:])  # query set
        # support images flattened to (n_way*n_support, C, H, W), e.g. (25, 3, 224, 224)
        x_a_i = x_var[:, :n_support, :, :, :].contiguous().view(
            n_way * n_support, *x.size()[2:])  # support set
        ################################################################################################
        # loss function and optimizer setting
        loss_fn = nn.CrossEntropyLoss().cuda()
        classifier_opt = torch.optim.SGD(classifier.parameters(),
                                         lr=0.01,
                                         momentum=0.9,
                                         dampening=0.9,
                                         weight_decay=0.001)
        if freeze_backbone is False:
            # separate optimizer for backbone fine-tuning
            delta_opt = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                               pretrained_model.parameters()),
                                        lr=0.01)
        pretrained_model.cuda()  # pretrained on `pretrained_dataset`, not ImageNet
        classifier.cuda()
        ###############################################################################################
        # fine-tuning: train the task-specific classifier (and optionally the
        # backbone) on the entire support set for 100 epochs, batch size 4
        total_epoch = 100
        if freeze_backbone is False:
            pretrained_model.train()
        else:
            pretrained_model.eval()
        classifier.train()  # the classifier is task-specific and always trained
        for epoch in range(total_epoch):
            rand_id = np.random.permutation(support_size)
            for j in range(0, support_size, batch_size):
                classifier_opt.zero_grad()
                if freeze_backbone is False:
                    delta_opt.zero_grad()
                # fetch a random mini-batch of up to batch_size support samples
                selected_id = torch.from_numpy(
                    rand_id[j:min(j + batch_size, support_size)]).cuda()
                z_batch = x_a_i[selected_id]
                y_batch = y_a_i[selected_id]
                output = pretrained_model(z_batch)  # features
                output = classifier(output)  # predictions
                loss = loss_fn(output, y_batch)
                loss.backward()
                classifier_opt.step()
                if freeze_backbone is False:
                    delta_opt.step()
        ##############################################################################################
        # inference on the query set
        pretrained_model.eval()
        classifier.eval()
        output = pretrained_model(x_b_i.cuda())  # features
        scores = classifier(output)  # predictions
        # query labels: [0]*n_query + [1]*n_query + ... -> shape (n_way*n_query,)
        y_query = np.repeat(range(n_way), n_query)
        # top-1 prediction per query image: topk(k=1, dim=1, largest, sorted)
        topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
        topk_ind = topk_labels.cpu().numpy()
        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        correct_this, count_this = float(top1_correct), len(y_query)
        acc_all.append((correct_this / count_this * 100))
        print("Task %d : %4.2f%% Now avg: %4.2f%%" %
              (ti, correct_this / count_this * 100, np.mean(acc_all)))
    ###############################################################################################
    # mean accuracy over all episodes with a 95% confidence interval
    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('%d Test Acc = %4.2f%% +- %4.2f%%' %
          (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
batch_size=params.bs, shuffle=False, num_workers=0) config = { 'epsilon': 8.0 / 255, 'num_steps': 5, 'step_size': 2.0 / 255, 'random_start': True, 'loss_func': 'xent', } teacher = model.Model(net='_'.join(params.teacher.split('_')[:-1]), num_classes=params.num_classes) teacher_dir = '%s/%s/teacher/%s' % (SAVE_DIR, params.dataset, params.teacher) teacher_file = get_resume_file(teacher_dir) print('Teacher file:', teacher_file) tmp = torch.load(teacher_file) teacher.feature.load_state_dict(tmp['feature']) teacher.classifier.load_state_dict(tmp['classifier']) teacher.eval() model = model.Model(net=params.model, num_classes=params.num_classes) optimization = 'Adam' params.checkpoint_dir = '%s/%s/student2/%s_%s_%s' % ( SAVE_DIR, params.dataset, params.model, params.method, params.teacher) if params.exp != 'gbp': params.checkpoint_dir += '_%s' % (params.exp) if params.e != 8.0: params.checkpoint_dir += '_eps{}'.format(params.e)
def train_s2m2(base_loader, base_loader_test, model, params, tmp):
    """Train `model` with S2M2-style joint supervised mixup + rotation self-supervision.

    Each batch contributes two losses whose gradients are accumulated before a
    single optimizer step: (1) a manifold-mixup cross-entropy on the full batch
    and (2) an auxiliary 4-way rotation-prediction loss (0/90/180/270 degrees)
    on a quarter of the batch, averaged with the plain classification loss.

    Args:
        base_loader: training DataLoader yielding (inputs, targets).
        base_loader_test: validation DataLoader, passed to `test_s2m2` each epoch.
        model: backbone+classifier; its forward supports
            `model(inputs, targets, mixup_hidden=True, mixup_alpha=..., lam=...)`
            returning (features, outputs, target_a, target_b), and plain
            `model(inputs)` returning (features, outputs).
        params: experiment config (start_epoch, stop_epoch, alpha, resume,
            checkpoint_dir, save_freq).
        tmp: pre-loaded checkpoint dict; if it has a 'rotate' entry, the
            rotation head is warm-started from it.

    Returns:
        The trained model.

    NOTE(review): the rotation head assumes a 640-dim feature vector
    (nn.Linear(640, 4)) — confirm against the backbone's feature width.
    """

    def mixup_criterion(criterion, pred, y_a, y_b, lam):
        # Mixup loss: convex combination of losses against both mixed targets.
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

    criterion = nn.CrossEntropyLoss()

    # Auxiliary head that predicts which of the 4 rotations was applied.
    rotate_classifier = nn.Sequential(nn.Linear(640, 4))
    rotate_classifier.to(device)
    model.to(device)

    if 'rotate' in tmp:
        print("loading rotate model")
        rotate_classifier.load_state_dict(tmp['rotate'])

    # One optimizer drives both the backbone and the rotation head.
    optimizer = torch.optim.Adam([{
        'params': model.parameters()
    }, {
        'params': rotate_classifier.parameters()
    }])

    # stop_epoch is fixed from the configured start; resuming only moves start_epoch.
    start_epoch, stop_epoch = params.start_epoch, params.start_epoch + params.stop_epoch

    if params.resume:
        checkpoint = get_resume_file(params.checkpoint_dir)
        print('resumefile: {}'.format(checkpoint))
        checkpoint = torch.load(checkpoint, map_location=device)
        model.load_state_dict(checkpoint['state'])
        # BUGFIX: resume from the epoch AFTER the checkpointed one. The
        # checkpoint is written at the end of epoch `checkpoint['epoch']`, so
        # restarting at the same index would repeat a finished epoch. This also
        # matches the other resume paths in this file, which use epoch + 1.
        start_epoch = checkpoint['epoch'] + 1
        print('Model loaded')

    print("stop_epoch", start_epoch, stop_epoch)

    for epoch in range(start_epoch, stop_epoch):
        print('\nEpoch: %d' % epoch)
        model.train()
        train_loss, rotate_loss = 0, 0
        correct, total = 0, 0
        torch.cuda.empty_cache()

        for batch_idx, (inputs, targets) in enumerate(base_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            lam = np.random.beta(params.alpha, params.alpha)
            f, outputs, target_a, target_b = model(inputs,
                                                   targets,
                                                   mixup_hidden=True,
                                                   mixup_alpha=params.alpha,
                                                   lam=lam)
            loss = mixup_criterion(criterion, outputs, target_a, target_b, lam)
            train_loss += loss.data.item()
            optimizer.zero_grad()
            # Deliberately NO optimizer.step() here: mixup gradients are kept
            # and accumulated with the rotation-loss gradients below.
            loss.backward()

            # Mixup accuracy: credit is split between the two mixed targets.
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (
                lam * predicted.eq(target_a.data).cpu().sum().float() +
                (1 - lam) * predicted.eq(target_b.data).cpu().sum().float())

            # Build the rotation batch from a random quarter of the inputs:
            # each picked image yields its 0/90/180/270-degree variants.
            bs = inputs.size(0)
            inputs_, targets_, a_ = [], [], []
            indices = np.arange(bs)
            np.random.shuffle(indices)
            split_size = int(bs / 4)
            for j in indices[0:split_size]:
                # transpose(2, 1).flip(1) rotates an image tensor by 90 degrees.
                x90 = inputs[j].transpose(2, 1).flip(1)
                x180 = x90.transpose(2, 1).flip(1)
                x270 = x180.transpose(2, 1).flip(1)
                inputs_ += [inputs[j], x90, x180, x270]
                targets_ += [targets[j] for _ in range(4)]
                a_ += [
                    torch.tensor(0),
                    torch.tensor(1),
                    torch.tensor(2),
                    torch.tensor(3)
                ]
            inputs = Variable(torch.stack(inputs_, 0))
            targets = Variable(torch.stack(targets_, 0))
            a_ = Variable(torch.stack(a_, 0))
            inputs, targets, a_ = inputs.to(device), targets.to(device), a_.to(
                device)

            rf, outputs = model(inputs)
            rotate_outputs = rotate_classifier(rf)
            rloss = criterion(rotate_outputs, a_)  # rotation prediction loss
            closs = criterion(outputs, targets)    # plain classification loss
            loss = (rloss + closs) / 2.0
            rotate_loss += rloss.data.item()

            # Second backward accumulates onto the mixup gradients; single step.
            loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 50 == 0:
                print(
                    '{0}/{1}'.format(batch_idx, len(base_loader)),
                    'Loss: %.3f | Acc: %.3f%% | RotLoss: %.3f ' %
                    (train_loss / (batch_idx + 1), 100. * correct / total,
                     rotate_loss / (batch_idx + 1)))

        # Periodic checkpointing (and always on the final epoch).
        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            if not os.path.isdir(params.checkpoint_dir):
                os.makedirs(params.checkpoint_dir)
            outfile = os.path.join(params.checkpoint_dir,
                                   '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)
            print('Model saved')

        test_s2m2(base_loader_test, model, criterion)

    return model
if params.train_aug: params.checkpoint_dir += '_aug' if not params.method in ['baseline', 'baseline++']: params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot) if not os.path.isdir(params.checkpoint_dir): os.makedirs(params.checkpoint_dir) start_epoch = params.start_epoch stop_epoch = params.stop_epoch if params.method == 'maml' or params.method == 'maml_approx': stop_epoch = params.stop_epoch * model.n_task #maml use multiple tasks in one update if params.resume: resume_file = get_resume_file(params.checkpoint_dir) if resume_file is not None: tmp = torch.load(resume_file) start_epoch = tmp['epoch'] + 1 model.load_state_dict(tmp['state']) elif params.warmup: #We also support warmup from pretrained baseline feature, but we never used in our paper baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % ( configs.save_dir, params.dataset, params.model, 'baseline') if params.train_aug: baseline_checkpoint_dir += '_aug' warmup_resume_file = get_resume_file(baseline_checkpoint_dir) tmp = torch.load(warmup_resume_file) if tmp is not None: state = tmp['state'] state_keys = list(state.keys()) for i, key in enumerate(state_keys):
def finetune(novel_loader,
             n_query=15,
             freeze_backbone=False,
             n_way=5,
             n_support=5,
             loadpath='',
             adaptation=False,
             pretrained_dataset='miniImagenet',
             proto_init=False):
    """Evaluate few-shot transfer on novel tasks, optionally fine-tuning per task.

    For every episode from `novel_loader`: reload the pretrained backbone,
    attach either a trainable linear `Classifier` (adaptation=True, fine-tuned
    on the support set) or a `ProtoClassifier` (adaptation=False), then score
    the query set and accumulate top-1 accuracy. Prints the mean accuracy with
    a 95% confidence interval at the end.

    Args:
        novel_loader: loader yielding (x, y) episodes; x is indexed as
            [way, shot+query, ...] (sliced on dim 1 below).
        n_query: default query count; overwritten per episode from x's shape.
        freeze_backbone: if False, the backbone is also updated during
            per-task adaptation via a second optimizer.
        n_way / n_support: episode configuration.
        loadpath, pretrained_dataset: unused in this body — NOTE(review):
            checkpoint location comes from the global `params` instead; confirm
            these parameters are intentionally ignored.
        adaptation: choose gradient fine-tuning vs. prototype classification.
        proto_init: initialise the linear classifier from class prototypes.

    Relies on module-level globals: `params`, `configs`, `model_dict`,
    `get_assigned_file` / `get_resume_file` / `get_best_file`,
    `Classifier`, `ProtoClassifier`. Requires CUDA.
    """
    correct = 0  # NOTE(review): unused; accuracy uses correct_this/count_this below
    count = 0    # NOTE(review): unused
    iter_num = len(novel_loader)
    acc_all = []
    with tqdm(enumerate(novel_loader), total=len(novel_loader)) as pbar:
        for _, (x, y) in pbar:  #, position=1, #leave=False):
            ###############################################################################################
            # Load the pretrained backbone weights for this episode.
            pretrained_model = model_dict[params.model]()
            checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s%s_%s%s' % (
                configs.save_dir, params.dataset, params.model, params.method,
                params.n_support, "s" if params.no_aug_support else "s_aug",
                params.n_query, "q" if params.no_aug_query else "q_aug")
            checkpoint_dir += "_bs{}".format(params.batch_size)
            if params.save_iter != -1:
                modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
            elif params.method in ['baseline', 'baseline++']:
                modelfile = get_resume_file(checkpoint_dir)
            else:
                modelfile = get_best_file(checkpoint_dir)
            tmp = torch.load(modelfile)
            state = tmp['state']
            state_keys = list(state.keys())
            for _, key in enumerate(state_keys):
                if "feature." in key:
                    # Keep only backbone weights, renamed 'feature.trunk.xx' -> 'trunk.xx';
                    # classifier/other keys are dropped.
                    newkey = key.replace("feature.", "")
                    state[newkey] = state.pop(key)
                else:
                    state.pop(key)
            pretrained_model.load_state_dict(state)
            pretrained_model.cuda()
            pretrained_model.train()
            ###############################################################################################
            # Choose the per-task classifier head.
            if adaptation:
                classifier = Classifier(pretrained_model.final_feat_dim, n_way)
                classifier.cuda()
                classifier.train()
            else:
                classifier = ProtoClassifier(n_way, n_support, n_query)
            ###############################################################################################
            # Split the episode into support (first n_support per class) and query.
            n_query = x.size(1) - n_support  # shadows the parameter on purpose
            x = x.cuda()
            x_var = Variable(x)
            batch_size = n_way
            support_size = n_way * n_support
            # Support labels are simply 0..n_way-1 repeated n_support times.
            y_a_i = Variable(
                torch.from_numpy(np.repeat(range(n_way),
                                           n_support))).cuda()  # (25,)
            x_b_i = x_var[:, n_support:, :, :, :].contiguous().view(
                n_way * n_query, *x.size()[2:])
            x_a_i = x_var[:, :n_support, :, :, :].contiguous().view(
                n_way * n_support, *x.size()[2:])  # (25, 3, 224, 224)
            # Support embeddings under eval mode (used for prototype init / proto scoring).
            pretrained_model.eval()
            z_a_i = pretrained_model(x_a_i.cuda())
            pretrained_model.train()
            ###############################################################################################
            loss_fn = nn.CrossEntropyLoss().cuda()
            if adaptation:
                inner_lr = params.lr_rate
                if proto_init:  # Initialise as distance classifer (distance to prototypes)
                    classifier.init_params_from_prototypes(
                        z_a_i, n_way, n_support)
                #classifier_opt = torch.optim.SGD(classifier.parameters(), lr = inner_lr, momentum=0.9, dampening=0.9, weight_decay=0.001)
                classifier_opt = torch.optim.Adam(classifier.parameters(),
                                                  lr=inner_lr)
                if freeze_backbone is False:
                    # Second optimizer fine-tunes the backbone jointly with the head.
                    delta_opt = torch.optim.Adam(filter(
                        lambda p: p.requires_grad,
                        pretrained_model.parameters()),
                                                 lr=inner_lr)
                total_epoch = params.ft_steps
                if freeze_backbone is False:
                    pretrained_model.train()
                else:
                    pretrained_model.eval()
                classifier.train()
                #for epoch in range(total_epoch):
                for epoch in tqdm(range(total_epoch),
                                  total=total_epoch,
                                  leave=False):
                    # Shuffle support indices each epoch; iterate in mini-batches.
                    rand_id = np.random.permutation(support_size)
                    for j in range(0, support_size, batch_size):
                        classifier_opt.zero_grad()
                        if freeze_backbone is False:
                            delta_opt.zero_grad()
                        #####################################
                        selected_id = torch.from_numpy(
                            rand_id[j:min(j + batch_size,
                                          support_size)]).cuda()
                        z_batch = x_a_i[selected_id]
                        y_batch = y_a_i[selected_id]
                        #####################################
                        output = pretrained_model(z_batch)
                        output = classifier(output)
                        loss = loss_fn(output, y_batch)
                        #####################################
                        loss.backward()
                        classifier_opt.step()
                        if freeze_backbone is False:
                            delta_opt.step()
            # Inference on the query set.
            classifier.eval()
            pretrained_model.eval()
            output = pretrained_model(x_b_i.cuda())
            if adaptation:
                scores = classifier(output)
            else:
                # ProtoClassifier scores queries against support embeddings.
                scores = classifier(z_a_i, y_a_i, output)
            y_query = np.repeat(range(n_way), n_query)
            topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
            topk_ind = topk_labels.cpu().numpy()
            top1_correct = np.sum(topk_ind[:, 0] == y_query)
            correct_this, count_this = float(top1_correct), len(y_query)
            #print (correct_this/ count_this *100)
            acc_all.append((correct_this / count_this * 100))
            ###############################################################################################
            pbar.set_postfix(avg_acc=np.mean(np.asarray(acc_all)))
    # Report mean accuracy with a 95% confidence interval over all episodes.
    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('%d Test Acc = %4.2f%% +- %4.2f%%' %
          (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))
loss_type=loss_type, **few_shot_params) elif params.method in ["dampnet_full_class"]: model = dampnet_full_class.DampNet(model_dict[params.model], **few_shot_params) elif params.method == "baseline": checkpoint_dir_b = '%s/checkpoints/%s/%s_%s' % ( configs.save_dir, pretrained_dataset, params.model, "baseline") if params.train_aug: checkpoint_dir_b += '_aug' if params.save_iter != -1: modelfile_b = get_assigned_file(checkpoint_dir_b, 400) elif params.method in ['baseline', 'baseline++']: modelfile_b = get_resume_file(checkpoint_dir_b) else: modelfile_b = get_best_file(checkpoint_dir_b) tmp_b = torch.load(modelfile_b) state_b = tmp_b['state'] elif params.method == "all": #model = ProtoNet( model_dict[params.model], **few_shot_params ) checkpoint_dir = '%s/checkpoints/%s/%s_%s' % ( configs.save_dir, 'miniImageNet', params.model, "protonet") model_2 = GnnNet(model_dict[params.model], **few_shot_params) checkpoint_dir2 = '%s/checkpoints/%s/%s_%s' % ( configs.save_dir, 'miniImageNet', params.model, "gnnnet") #model_3 = dampnet_full_class.DampNet( model_dict[params.model], **few_shot_params ) checkpoint_dir3 = '%s/checkpoints/%s/%s_%s' % (
def __init__(self, params):
    """Build datasets, loaders, model, logging, and (optionally) resume state.

    Resolves the base/val/novel split files from `params.train_dataset` /
    `params.test_dataset` (with 'cross' and 'cross_char' special cases), picks
    the image size from the backbone, constructs episodic or plain data
    loaders, instantiates the Baseline model on CUDA, sets up a TensorBoard
    writer, and restores weights when `params.resume` or test mode is set.

    Args:
        params: experiment configuration namespace; mutated in place
            (model, train_num_query, stop_epoch, checkpoint_dir, start_epoch).
    """
    np.random.seed(10)  # fixed seed for reproducible episode sampling
    if params.train_dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.train_dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        base_file = configs.data_dir[params.train_dataset] + 'base.json'
        val_file = configs.data_dir[params.train_dataset] + 'val.json'

    # Conv backbones use small images (28 for character datasets, 84 otherwise);
    # ResNet-style backbones use 224.
    if 'Conv' in params.model:
        if params.train_dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        image_size = 224

    if params.train_dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    if params.train_dataset == 'omniglot':
        assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
    if params.train_dataset == 'cross_char':
        assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

    # Scale train-time query count so the total query budget matches test time.
    params.train_num_query = max(
        1,
        int(params.test_num_query * params.test_num_way /
            params.train_num_way))

    if params.episodic:
        train_few_shot_params = dict(n_way=params.train_num_way,
                                     n_support=params.train_num_shot,
                                     n_query=params.train_num_query)
        base_datamgr = SetDataManager(image_size, **train_few_shot_params)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
    else:
        base_datamgr = SimpleDataManager(image_size, batch_size=32)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

    if params.test_dataset == 'cross':
        novel_file = configs.data_dir['CUB'] + 'novel.json'
    elif params.test_dataset == 'cross_char':
        novel_file = configs.data_dir['emnist'] + 'novel.json'
    else:
        novel_file = configs.data_dir[params.test_dataset] + 'novel.json'

    val_datamgr = SimpleDataManager(image_size, batch_size=64)
    val_loader = val_datamgr.get_data_loader(novel_file, aug=False)
    novel_datamgr = SimpleDataManager(image_size, batch_size=64)
    novel_loader = novel_datamgr.get_data_loader(novel_file, aug=False)

    optimizer = params.optimizer

    # Default epoch budgets when not set explicitly.
    if params.stop_epoch == -1:
        if params.train_dataset in ['omniglot', 'cross_char']:
            params.stop_epoch = 5
        elif params.train_dataset in ['CUB']:
            params.stop_epoch = 200  # This is different as stated in the open-review paper. However, using 400 epoch in baseline actually lead to over-fitting
        elif params.train_dataset in ['miniImagenet', 'cross']:
            params.stop_epoch = 300
        else:
            params.stop_epoch = 300

    shake_config = {
        'shake_forward': params.shake_forward,
        'shake_backward': params.shake_backward,
        'shake_picture': params.shake_picture
    }
    train_param = {
        'loss_type': params.train_loss_type,
        'temperature': params.train_temperature,
        'margin': params.train_margin,
        'lr': params.train_lr,
        'shake': params.shake,
        'shake_config': shake_config,
        'episodic': params.episodic,
        'num_way': params.train_num_way,
        'num_shot': params.train_num_shot,
        'num_query': params.train_num_query,
        'num_classes': params.num_classes
    }
    test_param = {
        'loss_type': params.test_loss_type,
        'temperature': params.test_temperature,
        'margin': params.test_margin,
        'lr': params.test_lr,
        'num_way': params.test_num_way,
        'num_shot': params.test_num_shot,
        'num_query': params.test_num_query
    }

    model = Baseline(model_dict[params.model], params.entropy, train_param,
                     test_param)
    model = model.cuda()

    key = params.tag  # run tag used for TensorBoard log dir and self.key
    writer = SummaryWriter(log_dir=os.path.join(params.vis_log, key))

    params.checkpoint_dir = '%s/checkpoints/%s/%s' % (
        configs.save_dir, params.train_dataset, params.checkpoint_dir)
    if not os.path.isdir(params.vis_log):
        os.makedirs(params.vis_log)
    outfile_template = os.path.join(
        params.checkpoint_dir.replace("checkpoints", "features"), "%s.hdf5")
    if params.mode == 'train' and not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    if params.resume or params.mode == 'test':
        if params.mode == 'test':
            # Test mode: load the best checkpoint's backbone into a separate
            # feature extractor, renaming 'feature.xx' -> 'xx' and dropping
            # non-backbone entries.
            self.feature_model = model_dict[params.model]().cuda()
            resume_file = get_best_file(params.checkpoint_dir)
            tmp = torch.load(resume_file)
            state = tmp['state']
            state_keys = list(state.keys())
            # BUGFIX: the loop variable was named `key`, clobbering the run
            # tag above, so self.key ended up holding the last state-dict key
            # in test mode. Use a distinct name.
            for i, state_key in enumerate(state_keys):
                if "feature." in state_key:
                    newkey = state_key.replace("feature.", "")
                    state[newkey] = state.pop(state_key)
                else:
                    state.pop(state_key)
            self.feature_model.load_state_dict(state)
            self.feature_model.eval()
        else:
            # Resume training: restore full model state and continue from the
            # epoch after the checkpointed one.
            resume_file = get_resume_file(params.checkpoint_dir)
            tmp = torch.load(resume_file)
            state = tmp['state']
            model.load_state_dict(state)
            params.start_epoch = tmp['epoch'] + 1
        print('Info: Model loaded!!!')

    self.params = params
    self.val_file = val_file
    self.base_file = base_file
    self.image_size = image_size
    self.optimizer = optimizer
    self.outfile_template = outfile_template
    self.novel_loader = novel_loader
    self.base_loader = base_loader
    self.val_loader = val_loader
    self.writer = writer
    self.model = model
    self.key = key