def single_test(params, results_logger):
    """Run one full few-shot evaluation pass and append the result to record/results.txt.

    Builds the model named by ``params.method``, loads its best (or assigned-iter)
    checkpoint, evaluates over 600 episodes, logs the mean accuracy with a 95%
    confidence interval, and returns the mean accuracy.

    Args:
        params: argparse namespace from parse_args('test') — expects at least
            dataset, model, method, n_shot, test_n_way, train_n_way, train_aug,
            save_iter, split, adaptation.
        results_logger: object with a .log(key, value) method.

    Returns:
        float: mean episode accuracy (percentage).
    """
    acc_all = []
    iter_num = 600  # number of evaluation episodes
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    # Omniglot-style datasets are grayscale; swap in the single-channel Conv4S backbone.
    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    # --- Build the meta-test model for the requested method. ---
    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'DKT':
        model = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        # RelationNet needs non-flattened (spatial) feature maps, hence the *NP backbones.
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        # Flip the backbone blocks into MAML mode (fast-weight support) before construction.
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char']:  # maml use different parameter in omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    # Checkpoint directory encodes dataset/model/method (+aug, +way/shot for episodic methods).
    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    # modelfile = get_resume_file(checkpoint_dir)

    # Baselines are finetuned from saved features at test time, so no checkpoint is loaded here.
    if not params.method in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " + str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split

    if params.method in ['maml', 'maml_approx', 'DKT']:  # maml do not support testing with feature
        # Evaluate directly on images: pick image size from backbone/dataset convention.
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224
        datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=15, **few_shot_params)
        # Cross-domain setups train on one dataset and test on another's split file.
        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'
        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # We perform adaptation on MAML simply by updating more times.
        model.eval()
        acc_mean, acc_std = model.test_loop(novel_loader, return_std=True)
    else:
        # Feature-based evaluation: load pre-extracted HDF5 features and sample episodes from them.
        novel_file = os.path.join(
            checkpoint_dir.replace("checkpoints", "features"),
            split_str + ".hdf5")  # defaut split = novel, but you can also test base or val classes
        cl_data_file = feat_loader.init_loader(novel_file)
        for i in range(iter_num):
            acc = feature_evaluation(cl_data_file, model, n_query=15,
                                     adaptation=params.adaptation, **few_shot_params)
            acc_all.append(acc)
        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        # 1.96 * std / sqrt(n) is the half-width of the 95% confidence interval.
        print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))

    # Append a one-line summary of this run to the shared results file.
    with open('record/results.txt', 'a') as f:
        timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
        aug_str = '-aug' if params.train_aug else ''
        aug_str += '-adapted' if params.adaptation else ''
        if params.method in ['baseline', 'baseline++']:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.test_n_way)
        else:
            exp_setting = '%s-%s-%s-%s%s %sshot %sway_train %sway_test' % (
                params.dataset, split_str, params.model, params.method,
                aug_str, params.n_shot, params.train_n_way, params.test_n_way)
        acc_str = '%d Test Acc = %4.2f%% +- %4.2f%%' % (
            iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num))
        f.write('Time: %s, Setting: %s, Acc: %s \n' % (timestamp, exp_setting, acc_str))

    results_logger.log("single_test_acc", acc_mean)
    results_logger.log("single_test_acc_std", 1.96 * acc_std / np.sqrt(iter_num))
    results_logger.log("time", timestamp)
    results_logger.log("exp_setting", exp_setting)
    results_logger.log("acc_str", acc_str)
    return acc_mean
def train_baseline(base_loader, base_loader_test, val_loader, model, start_epoch, stop_epoch, params, tmp):
    """Standard supervised training of a baseline classifier with per-epoch few-shot validation.

    Each epoch: train with cross-entropy on ``base_loader``, report accuracy on
    ``base_loader_test``, then run a 5-way few-shot finetuning validation on
    cached validation episodes; the best model (by the 300-step finetune
    accuracy) is saved as ``best.tar``.

    Args:
        base_loader: training DataLoader of (image, label) batches.
        base_loader_test: held-out DataLoader used for plain classification accuracy.
        val_loader: episodic validation loader; its batches are cached to disk on first run.
        model: backbone+classifier returning (features, logits) from forward().
        start_epoch, stop_epoch: epoch range to train over.
        params: argparse namespace (uses dct_status, channels, checkpoint_dir,
            dataset, save_freq, model, train_n_way, n_shot).
        tmp: unused here (kept for signature parity with train_s2m2).

    Returns:
        The trained model.

    NOTE(review): relies on module-level globals `use_gpu`, `path`, `image_size`
    and `image_size_dct` defined elsewhere in this file — confirm they are set
    before calling.
    """
    # DCT-preprocessed inputs carry a configurable channel count; RGB otherwise.
    if params.dct_status:
        channels = params.channels
    else:
        channels = 3
    val_acc_best = 0.0
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    # Cache the validation episodes to disk so every epoch (and rerun) sees the
    # exact same episodes.
    if path.exists(params.checkpoint_dir + '/val_' + params.dataset + '.pt'):
        loader = torch.load(params.checkpoint_dir + '/val_' + params.dataset + '.pt')
    else:
        loader = []
        for ii, (x, _) in enumerate(val_loader):
            loader.append(x)
            #print("head of train_dct: ", x.shape)
        torch.save(loader, params.checkpoint_dir + '/val_' + params.dataset + '.pt')

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.Adam(model.parameters())
    print("stop_epoch", start_epoch, stop_epoch)

    for epoch in range(start_epoch, stop_epoch):
        print('\nEpoch: %d' % epoch)
        # ---- training pass ----
        model.train()
        train_loss = 0
        reg_loss = 0
        correct = 0
        correct1 = 0.0
        total = 0
        for batch_idx, (input_var, target_var) in enumerate(base_loader):
            if use_gpu:
                input_var, target_var = input_var.cuda(), target_var.cuda()
            input_dct_var, target_var = Variable(input_var), Variable(target_var)
            # model returns (features, logits); only logits are used for the loss here
            f, outputs = model.forward(input_dct_var)
            loss = criterion(outputs, target_var)
            train_loss += loss.data.item()
            _, predicted = torch.max(outputs.data, 1)
            total += target_var.size(0)
            correct += predicted.eq(target_var.data).cpu().sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if batch_idx % 50 == 0:
                print('{0}/{1}'.format(batch_idx, len(base_loader)),
                      'Loss: %.3f | Acc: %.3f%% '
                      % (train_loss / (batch_idx + 1), 100. * correct / total))

        if not os.path.isdir(params.checkpoint_dir):
            os.makedirs(params.checkpoint_dir)
        # Periodic checkpoint (and always at the final epoch).
        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        # ---- plain classification evaluation on the held-out base split ----
        model.eval()
        with torch.no_grad():
            test_loss = 0
            correct = 0
            total = 0
            for batch_idx, (inputs, targets) in enumerate(base_loader_test):
                if use_gpu:
                    inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = Variable(inputs), Variable(targets)
                f, outputs = model.forward(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.data.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += predicted.eq(targets.data).cpu().sum()
            print('Loss: %.3f | Acc: %.3f%%'
                  % (test_loss / (batch_idx + 1), 100. * correct / total))
        torch.cuda.empty_cache()

        # ---- few-shot validation: finetune a distance-based classifier on cached episodes ----
        valmodel = BaselineFinetune(model_dict[params.model], params.train_n_way,
                                    params.n_shot, loss_type='dist')
        valmodel.n_query = 15
        acc_all1, acc_all2, acc_all3 = [], [], []  # accuracies at finetune steps 100/200/300
        for i, x in enumerate(loader):
            # print("len of loader: ",len(loader))
            # print("shape of x: ",x.shape)
            if params.dct_status:
                x = x.view(-1, channels, image_size_dct, image_size_dct)
            else:
                x = x.view(-1, channels, image_size, image_size)
            if use_gpu:
                x = x.cuda()
            with torch.no_grad():
                f, scores = model(x)
            # reshape flat features into (n_way, n_support + n_query, feat_dim)
            f = f.view(params.train_n_way, params.n_shot + valmodel.n_query, -1)
            scores = valmodel.set_forward_adaptation(f.cpu())
            acc = []
            for each_score in scores:
                pred = each_score.data.cpu().numpy().argmax(axis=1)
                # assumes a 5-way episode with 15 queries per class — TODO confirm
                y = np.repeat(range(5), 15)
                acc.append(np.mean(pred == y) * 100)
            acc_all1.append(acc[0])
            acc_all2.append(acc[1])
            acc_all3.append(acc[2])
        print('Test Acc at 100= %4.2f%%' % (np.mean(acc_all1)))
        print('Test Acc at 200= %4.2f%%' % (np.mean(acc_all2)))
        print('Test Acc at 300= %4.2f%%' % (np.mean(acc_all3)))
        # Track the best model by the 300-step finetune accuracy.
        if np.mean(acc_all3) > val_acc_best:
            val_acc_best = np.mean(acc_all3)
            bestfile = os.path.join(params.checkpoint_dir, 'best.tar')
            torch.save({'epoch': epoch, 'state': model.state_dict()}, bestfile)
    return model
def train_s2m2(base_loader, base_loader_test, val_loader, model, start_epoch, stop_epoch, params, tmp):
    """S2M2-style training: manifold mixup plus a rotation self-supervision head.

    Each batch contributes two backward passes before a single optimizer step:
    (1) the mixup cross-entropy loss, and (2) the averaged rotation-prediction +
    classification loss on a rotated quarter of the batch — i.e. gradients from
    both losses are deliberately accumulated. Per-epoch few-shot validation
    mirrors train_baseline; the best model is saved with the rotate head.

    Args:
        base_loader, base_loader_test, val_loader: as in train_baseline.
        model: backbone supporting both the mixup forward signature
            (inputs, targets, mixup_hidden=..., mixup_alpha=..., lam=...) and a
            plain (features, logits) forward.
        start_epoch, stop_epoch: epoch range.
        params: argparse namespace (uses dct_status, channels, checkpoint_dir,
            dataset, model, alpha, save_freq, train_n_way, n_shot).
        tmp: checkpoint dict; if it contains 'rotate', that state is loaded into
            the rotation classifier.

    Returns:
        The trained model.

    NOTE(review): relies on module-level globals `use_gpu`, `path`, `image_size`
    and `image_size_dct` defined elsewhere in this file. If params.model is
    neither 'WideResNet28_10' nor 'ResNet18', `rotate_classifier` is never
    bound and this raises NameError — confirm callers only pass those two.
    """
    if params.dct_status:
        channels = params.channels
    else:
        channels = 3
    val_acc_best = 0.0
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    # Cache validation episodes to disk so every epoch sees identical episodes.
    if path.exists(params.checkpoint_dir + '/val_' + params.dataset + '.pt'):
        loader = torch.load(params.checkpoint_dir + '/val_' + params.dataset + '.pt')
    else:
        loader = []
        for _, (x, _) in enumerate(val_loader):
            loader.append(x)
        torch.save(loader, params.checkpoint_dir + '/val_' + params.dataset + '.pt')

    def mixup_criterion(criterion, pred, y_a, y_b, lam):
        # Convex combination of losses against both mixed-in label sets.
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

    criterion = nn.CrossEntropyLoss()

    # 4-way linear head predicting the rotation applied (0/90/180/270 degrees);
    # input width matches the backbone's feature dimension.
    if params.model == 'WideResNet28_10':
        rotate_classifier = nn.Sequential(nn.Linear(640, 4))
    elif params.model == 'ResNet18':
        rotate_classifier = nn.Sequential(nn.Linear(512, 4))
    rotate_classifier.cuda()
    if 'rotate' in tmp:
        print("loading rotate model")
        rotate_classifier.load_state_dict(tmp['rotate'])

    # One optimizer over both the backbone and the rotation head.
    optimizer = torch.optim.Adam([{
        'params': model.parameters()
    }, {
        'params': rotate_classifier.parameters()
    }])
    print("stop_epoch", start_epoch, stop_epoch)

    for epoch in range(start_epoch, stop_epoch):
        print('\nEpoch: %d' % epoch)
        model.train()
        train_loss = 0
        rotate_loss = 0
        correct = 0
        total = 0
        torch.cuda.empty_cache()
        print("inside base_loader: ", len(base_loader))
        for batch_idx, (inputs, targets) in enumerate(base_loader):
            if use_gpu:
                inputs, targets = inputs.cuda(), targets.cuda()
            #print("shape of input: ", inputs.shape)
            # ---- loss 1: manifold mixup ----
            lam = np.random.beta(params.alpha, params.alpha)
            f, outputs, target_a, target_b = model(inputs, targets, mixup_hidden=True,
                                                   mixup_alpha=params.alpha, lam=lam)
            loss = mixup_criterion(criterion, outputs, target_a, target_b, lam)
            train_loss += loss.data.item()
            optimizer.zero_grad()
            loss.backward()  # first backward; gradients kept until the step below
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            # Mixup accuracy: lam-weighted agreement with both label sets.
            correct += (lam * predicted.eq(target_a.data).cpu().sum().float()
                        + (1 - lam) * predicted.eq(target_b.data).cpu().sum().float())

            # ---- loss 2: rotation self-supervision on a random quarter of the batch ----
            bs = inputs.size(0)
            inputs_ = []
            targets_ = []
            a_ = []  # rotation labels 0..3
            indices = np.arange(bs)
            np.random.shuffle(indices)
            split_size = int(bs / 4)
            for j in indices[0:split_size]:
                # 90-degree rotation = transpose spatial dims then flip.
                x90 = inputs[j].transpose(2, 1).flip(1)
                x180 = x90.transpose(2, 1).flip(1)
                x270 = x180.transpose(2, 1).flip(1)
                inputs_ += [inputs[j], x90, x180, x270]
                targets_ += [targets[j] for _ in range(4)]
                a_ += [
                    torch.tensor(0),
                    torch.tensor(1),
                    torch.tensor(2),
                    torch.tensor(3)
                ]
            inputs = Variable(torch.stack(inputs_, 0))
            targets = Variable(torch.stack(targets_, 0))
            a_ = Variable(torch.stack(a_, 0))
            if use_gpu:
                inputs = inputs.cuda()
                targets = targets.cuda()
                a_ = a_.cuda()
            rf, outputs = model(inputs)
            rotate_outputs = rotate_classifier(rf)
            rloss = criterion(rotate_outputs, a_)
            closs = criterion(outputs, targets)
            loss = (rloss + closs) / 2.0
            rotate_loss += rloss.data.item()
            loss.backward()   # second backward accumulates onto mixup gradients
            optimizer.step()  # single step applies both losses

            if batch_idx % 50 == 0:
                print('{0}/{1}'.format(batch_idx, len(base_loader)),
                      'Loss: %.3f | Acc: %.3f%% | RotLoss: %.3f '
                      % (train_loss / (batch_idx + 1), 100. * correct / total,
                         rotate_loss / (batch_idx + 1)))

        if not os.path.isdir(params.checkpoint_dir):
            os.makedirs(params.checkpoint_dir)
        # Periodic checkpoint (and always at the final epoch).
        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        # ---- plain classification evaluation on the held-out base split ----
        model.eval()
        with torch.no_grad():
            test_loss = 0
            correct = 0
            total = 0
            for batch_idx, (inputs, targets) in enumerate(base_loader_test):
                if use_gpu:
                    inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = Variable(inputs), Variable(targets)
                f, outputs = model.forward(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.data.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += predicted.eq(targets.data).cpu().sum()
            print('Loss: %.3f | Acc: %.3f%%'
                  % (test_loss / (batch_idx + 1), 100. * correct / total))

        # ---- few-shot validation on cached episodes (see train_baseline) ----
        if params.dct_status:
            valmodel = BaselineFinetune(model_dict[params.model + '_dct'],
                                        params.train_n_way, params.n_shot, loss_type='dist')
        else:
            valmodel = BaselineFinetune(model_dict[params.model],
                                        params.train_n_way, params.n_shot, loss_type='dist')
        valmodel.n_query = 15
        acc_all1, acc_all2, acc_all3 = [], [], []  # accuracies at finetune steps 100/200/300
        for i, x in enumerate(loader):
            if params.dct_status:
                x = x.view(-1, channels, image_size_dct, image_size_dct)
            else:
                x = x.view(-1, channels, image_size, image_size)
            if use_gpu:
                x = x.cuda()
            with torch.no_grad():
                f, scores = model(x)
            f = f.view(params.train_n_way, params.n_shot + valmodel.n_query, -1)
            scores = valmodel.set_forward_adaptation(f.cpu())
            acc = []
            for each_score in scores:
                pred = each_score.data.cpu().numpy().argmax(axis=1)
                # assumes a 5-way episode with 15 queries per class — TODO confirm
                y = np.repeat(range(5), 15)
                acc.append(np.mean(pred == y) * 100)
            acc_all1.append(acc[0])
            acc_all2.append(acc[1])
            acc_all3.append(acc[2])
        print('Test Acc at 100= %4.2f%%' % (np.mean(acc_all1)))
        print('Test Acc at 200= %4.2f%%' % (np.mean(acc_all2)))
        print('Test Acc at 300= %4.2f%%' % (np.mean(acc_all3)))
        # Best model also stores the rotation head so training can resume fully.
        if np.mean(acc_all3) > val_acc_best:
            val_acc_best = np.mean(acc_all3)
            bestfile = os.path.join(params.checkpoint_dir, 'best.tar')
            torch.save(
                {
                    'epoch': epoch,
                    'state': model.state_dict(),
                    'rotate': rotate_classifier.state_dict()
                }, bestfile)
    return model
datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=15, **few_shot_params, isAircraft=isAircraft, grey=params.grey) loadfile = os.path.join('filelists', params.dataset, 'novel.json') novel_loader = datamgr.get_data_loader(loadfile, aug=False) if params.adaptation: model.task_update_num = 100 #We perform adaptation on MAML simply by updating more times. model.eval() acc_mean, acc_std = model.test_loop(novel_loader, return_std=True) else: if params.method == 'baseline': model = BaselineFinetune(model_dict[params.model], **few_shot_params) elif params.method == 'baseline++': model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params) if params.save_iter != -1: outfile = os.path.join( checkpoint_dir.replace("checkpoints", "features"), split + "_" + str(params.save_iter) + ".hdf5") else: outfile = os.path.join( checkpoint_dir.replace("checkpoints", "features"), split + ".hdf5") datamgr = SimpleDataManager(image_size,
if __name__ == '__main__': params = parse_args('test') set_cuda(params.cuda) acc_all = [] iter_num = 600 few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) if params.dataset in ['omniglot', 'cross_char']: assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation' params.model = 'Conv4S' if params.method == 'baseline': model = BaselineFinetune(model_dict[params.model], **few_shot_params) elif params.method == 'baseline++': model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params) elif params.method == 'protonet': model = ProtoNet(model_dict[params.model], **few_shot_params) elif params.method == 'matchingnet': model = MatchingNet(model_dict[params.model], **few_shot_params) elif params.method in ['relationnet', 'relationnet_softmax']: if params.model == 'Conv4': feature_model = backbone.Conv4NP elif params.model == 'Conv6': feature_model = backbone.Conv6NP elif params.model == 'Conv4S': feature_model = backbone.Conv4SNP
def get_model(params, mode):
    """Construct the few-shot model requested by ``params`` for the given phase.

    Dispatches on ``params.method`` (baseline/baseline++, protonet and its
    autoencoder/min-gram variants, matchingnet, relationnet(+softmax),
    maml/maml_approx) and wires in the backbone, optional reconstruction
    decoder, and method-specific hyperparameters.

    Args:
        params: argparse namespace (uses dataset, model, method, train_aug,
            recons_decoder, recons_lambda, min_gram, lambda_gram, num_classes,
            finetune_dropout_p, ...).
        mode: (str) 'train' or 'test' — selects e.g. BaselineTrain vs
            BaselineFinetune.

    Returns:
        The constructed (not yet cuda-moved) model.

    Raises:
        ValueError: for an unknown method, or an incompatible decoder on
            omniglot/cross_char datasets.
    """
    print('get_model() start...')
    few_shot_params = get_few_shot_params(params, mode)

    if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
        assert 'Conv4' in params.model and not params.train_aug, 'omniglot/cross_char only support Conv4 without augmentation'
        # Replace within the name so variants like Conv4Drop become Conv4SDrop too.
        params.model = params.model.replace('Conv4', 'Conv4S')
        if params.recons_decoder is not None:
            if 'ConvS' not in params.recons_decoder:
                raise ValueError('omniglot / cross_char should use ConvS/HiddenConvS decoder.')

    # Classifier-head methods need the output layer to cover every base-class label id.
    if params.method in ['baseline', 'baseline++'] and mode == 'train':
        assert params.num_classes >= n_base_classes[params.dataset]

    if params.recons_decoder is None:
        print('params.recons_decoder == None')
        recons_decoder = None
    else:
        recons_decoder = decoder_dict[params.recons_decoder]
        print('recons_decoder:\n', recons_decoder)

    backbone_func = get_backbone_func(params)

    if 'baseline' in params.method:
        loss_types = {
            'baseline': 'softmax',
            'baseline++': 'dist',
        }
        loss_type = loss_types[params.method]
        if recons_decoder is None and params.min_gram is None:
            # Default baseline/baseline++.
            if mode == 'train':
                model = BaselineTrain(
                    model_func=backbone_func, loss_type=loss_type,
                    num_class=params.num_classes, **few_shot_params)
            elif mode == 'test':
                model = BaselineFinetune(
                    model_func=backbone_func, loss_type=loss_type,
                    **few_shot_params, finetune_dropout_p=params.finetune_dropout_p)
        else:
            # Baseline with the min-gram regularizer (training only; testing
            # still uses the plain finetune head).
            if params.min_gram is not None:
                min_gram_params = {
                    'min_gram': params.min_gram,
                    'lambda_gram': params.lambda_gram,
                }
                if mode == 'train':
                    model = BaselineTrainMinGram(
                        model_func=backbone_func, loss_type=loss_type,
                        num_class=params.num_classes, **few_shot_params,
                        **min_gram_params)
                elif mode == 'test':
                    model = BaselineFinetune(
                        model_func=backbone_func, loss_type=loss_type,
                        **few_shot_params, finetune_dropout_p=params.finetune_dropout_p)
    elif params.method == 'protonet':
        if recons_decoder is None and params.min_gram is None:
            # Default ProtoNet.
            model = ProtoNet(backbone_func, **few_shot_params)
        else:
            if params.min_gram is not None:
                min_gram_params = {
                    'min_gram': params.min_gram,
                    'lambda_gram': params.lambda_gram,
                }
                model = ProtoNetMinGram(backbone_func, **few_shot_params, **min_gram_params)
            if params.recons_decoder is not None:
                if 'Hidden' in params.recons_decoder:
                    # 'Hidden*' decoders reconstruct from an intermediate layer;
                    # extract_layer selects which one.
                    if params.recons_decoder == 'HiddenConv':
                        model = ProtoNetAE2(backbone_func, **few_shot_params,
                                            recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda,
                                            extract_layer=2)
                    elif params.recons_decoder == 'HiddenConvS':
                        model = ProtoNetAE2(backbone_func, **few_shot_params,
                                            recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda,
                                            extract_layer=2, is_color=False)
                    elif params.recons_decoder == 'HiddenRes10':
                        model = ProtoNetAE2(backbone_func, **few_shot_params,
                                            recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda,
                                            extract_layer=6)
                    elif params.recons_decoder == 'HiddenRes18':
                        model = ProtoNetAE2(backbone_func, **few_shot_params,
                                            recons_func=recons_decoder,
                                            lambda_d=params.recons_lambda,
                                            extract_layer=8)
                else:
                    # Input-space decoders: 'ConvS' variants are grayscale.
                    if 'ConvS' in params.recons_decoder:
                        model = ProtoNetAE(backbone_func, **few_shot_params,
                                           recons_func=recons_decoder,
                                           lambda_d=params.recons_lambda,
                                           is_color=False)
                    else:
                        model = ProtoNetAE(backbone_func, **few_shot_params,
                                           recons_func=recons_decoder,
                                           lambda_d=params.recons_lambda,
                                           is_color=True)
    elif params.method == 'matchingnet':
        model = MatchingNet(backbone_func, **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(backbone_func, loss_type=loss_type, **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        # Flip backbone blocks into MAML (fast-weight) mode before construction.
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(backbone_func, approx=(params.method == 'maml_approx'), **few_shot_params)
        if 'omniglot' in params.dataset or 'cross_char' in params.dataset:
            # MAML uses different hyperparameters on omniglot-style datasets.
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unexpected params.method: %s' % (params.method))

    print('get_model() finished.')
    return model
return acc if __name__ == '__main__': params = parse_args('test') acc_all = [] if params.dataset == 'CUB': iter_num = 600 else: iter_num = 10000 few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) model = BaselineFinetune(model_dict[params.model], **few_shot_params) if torch.cuda.is_available(): model = model.cuda() if params.dct_status == False: params.channels = 3 checkpoint_dir = '%s/checkpoints/%s/%s_%s_%sway_%sshot' % ( configs.save_dir, params.dataset, params.model, params.method, params.test_n_way, params.n_shot) if params.train_aug: checkpoint_dir += '_aug' if params.dct_status: checkpoint_dir += '_dct'
def get_logits_targets(params):
    """Collect raw logits and ground-truth targets over 600 evaluation episodes.

    Builds the model for ``params.method``, loads its checkpoint, then either
    runs episodes from images (maml/maml_approx/DKT) or samples episodes from
    pre-extracted HDF5 features, accumulating logits and targets for e.g.
    calibration analysis.

    Args:
        params: argparse namespace from parse_args('test') — expects dataset,
            model, method, n_shot, test_n_way, train_n_way, train_aug,
            save_iter, split, adaptation.

    Returns:
        (logits, targets): two tensors concatenated over all episodes along dim 0.
    """
    acc_all = []
    iter_num = 600  # number of evaluation episodes
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    # Omniglot-style datasets are grayscale; swap in the single-channel Conv4S backbone.
    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'

    # --- Build the meta-test model for the requested method. ---
    if params.method == 'baseline':
        model = BaselineFinetune(model_dict[params.model], **few_shot_params)
    elif params.method == 'baseline++':
        model = BaselineFinetune(model_dict[params.model], loss_type='dist', **few_shot_params)
    elif params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'DKT':
        model = DKT(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        # RelationNet needs non-flattened (spatial) feature maps, hence the *NP backbones.
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        elif params.model == 'Conv4S':
            feature_model = backbone.Conv4SNP
        else:
            feature_model = lambda: model_dict[params.model](flatten=False)
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
    elif params.method in ['maml', 'maml_approx']:
        # Flip backbone blocks into MAML (fast-weight) mode before construction.
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        model = MAML(model_dict[params.model], approx=(params.method == 'maml_approx'), **few_shot_params)
        if params.dataset in ['omniglot', 'cross_char']:  # maml use different parameter in omniglot
            model.n_task = 32
            model.task_update_num = 1
            model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')

    model = model.cuda()

    # Checkpoint directory encodes dataset/model/method (+aug, +way/shot for episodic methods).
    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    #modelfile = get_resume_file(checkpoint_dir)

    # Baselines are finetuned from saved features at test time, so no checkpoint is loaded here.
    if not params.method in ['baseline', 'baseline++']:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        else:
            modelfile = get_best_file(checkpoint_dir)
        if modelfile is not None:
            tmp = torch.load(modelfile)
            model.load_state_dict(tmp['state'])
        else:
            print("[WARNING] Cannot find 'best_file.tar' in: " + str(checkpoint_dir))

    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split

    if params.method in ['maml', 'maml_approx', 'DKT']:  # maml do not support testing with feature
        # Evaluate directly on images: pick image size from backbone/dataset convention.
        if 'Conv' in params.model:
            if params.dataset in ['omniglot', 'cross_char']:
                image_size = 28
            else:
                image_size = 84
        else:
            image_size = 224
        datamgr = SetDataManager(image_size, n_eposide=iter_num, n_query=15, **few_shot_params)
        # Cross-domain setups train on one dataset and test on another's split file.
        if params.dataset == 'cross':
            if split == 'base':
                loadfile = configs.data_dir['miniImagenet'] + 'all.json'
            else:
                loadfile = configs.data_dir['CUB'] + split + '.json'
        elif params.dataset == 'cross_char':
            if split == 'base':
                loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
            else:
                loadfile = configs.data_dir['emnist'] + split + '.json'
        else:
            loadfile = configs.data_dir[params.dataset] + split + '.json'
        novel_loader = datamgr.get_data_loader(loadfile, aug=False)
        if params.adaptation:
            model.task_update_num = 100  # We perform adaptation on MAML simply by updating more times.
        model.eval()
        logits_list = list()
        targets_list = list()
        for i, (x, _) in enumerate(novel_loader):
            logits = model.get_logits(x).detach()
            # Episode targets: each of the n_way classes repeated for its n_query samples.
            targets = torch.tensor(np.repeat(range(params.test_n_way), model.n_query)).cuda()
            logits_list.append(logits)    #.cpu().detach().numpy())
            targets_list.append(targets)  #.cpu().detach().numpy())
    else:
        # Feature-based episodes sampled from pre-extracted HDF5 features.
        novel_file = os.path.join(checkpoint_dir.replace("checkpoints", "features"), split_str + ".hdf5")
        cl_data_file = feat_loader.init_loader(novel_file)
        logits_list = list()
        targets_list = list()
        n_query = 15
        n_way = few_shot_params['n_way']
        n_support = few_shot_params['n_support']
        # FIX: materialize the keys — random.sample() requires a sequence, and
        # passing a dict view/set raises TypeError on Python 3.11+.
        class_list = list(cl_data_file.keys())
        for i in range(iter_num):
            # ----------------------
            select_class = random.sample(class_list, n_way)
            z_all = []
            for cl in select_class:
                img_feat = cl_data_file[cl]
                perm_ids = np.random.permutation(len(img_feat)).tolist()
                # Take the first n_support + n_query features in permuted order.
                z_all.append([np.squeeze(img_feat[perm_ids[j]]) for j in range(n_support + n_query)])
            # stack each batch
            z_all = torch.from_numpy(np.array(z_all))
            model.n_query = n_query
            logits = model.set_forward(z_all, is_feature=True).detach()
            targets = torch.tensor(np.repeat(range(n_way), n_query)).cuda()
            logits_list.append(logits)
            targets_list.append(targets)
            # ----------------------
    return torch.cat(logits_list, 0), torch.cat(targets_list, 0)
iter_num = 600 few_shot_args = dict(n_way=args.test_n_way, n_support=args.n_shot) l3_model = None # Always load vocab for compatibility if args.dataset == 'CUB': vocab = lang_utils.load_vocab(configs.lang_dir) else: vocab, *_ = lang_utils.load_scenes_vocab( './filelists/scenes/AbstractScenes_v1.1/Sentences_1002.txt') if args.method == 'baseline': model = BaselineFinetune(model_dict[args.model], **few_shot_args) elif args.method == 'baseline++': model = BaselineFinetune(model_dict[args.model], loss_type='dist', **few_shot_args) elif args.method == 'protonet': lang_model = None if args.ml2 or args.l3: # Add sos/eos tokens to vocabulary embedding_model = nn.Embedding(len(vocab), args.lang_emb_size) lang_input_size = 1600 if args.model.startswith('Conv') else 512 if args.language_task == 'decode': # FIXME: This won't work for all models lang_model_func = lambda inpsize: TextProposal( embedding_model, input_size=inpsize,
if params.dataset == 'omniglot': assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class' if params.dataset == 'cross_char': assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class' if params.method == 'baseline': train_model = BaselineTrain(model_dict[params.model], params.num_classes, loss_type=params.loss_type, margin=params.margin, centered=params.centered, temperature=params.temperature) test_model = BaselineFinetune(model_dict[params.model], params.num_classes, params.n_shot, loss_type=params.loss_type, margin=params.margin, centered=params.centered, temperature=params.temperature) elif params.method in [ 'simplenet', 'protonet', 'matchingnet', 'relationnet', 'relationnet_softmax', 'maml', 'maml_approx' ]: n_query = max( 1, int(16 * params.test_n_way / params.train_n_way) ) # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot) base_datamgr = SetDataManager(image_size,
pred = scores.data.cpu().numpy().argmax(axis=1) y = np.repeat(range(n_way), n_query) acc = np.mean(pred == y) * 100 return acc if __name__ == '__main__': params = parse_args('test') acc_all = [] iter_num = 600 few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) if params.method == 'baseline': model = BaselineFinetune(model_dict[params.model], **few_shot_params) elif params.method == 'protonet': model = ProtoNet(model_dict[params.model], **few_shot_params) else: raise ValueError('Unknown method') model = model.cuda() checkpoint_dir = '%s/checkpoints/%s/%s_%s' % ( configs.save_dir, 'miniImageNet', params.model, params.method) if params.train_aug: checkpoint_dir += '_aug' if not params.method in ['baseline']: checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot) if not params.method in ['baseline']:
iter_num = 600 few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot) params.reminder = "%s_TempR-%s_%s_Margin-%s_%s" % ( params.model, params.ratio, params.loss_type, params.large_margin, time.time()) if params.dataset in ['omniglot', 'cross_char']: assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation' params.model = 'Conv4S' if params.method == 'baseline': model = BaselineFinetune(model_dict[params.model], ratio=params.ratio, loss_type=params.loss_type, margin=params.large_margin, centered=params.centered, **few_shot_params) elif params.method == 'simplenet': model = SimpleNetFinetune(model_dict[params.model], **few_shot_params) elif params.method == 'protonet': model = ProtoNet(model_dict[params.model], **few_shot_params) elif params.method == 'matchingnet': model = MatchingNet(model_dict[params.model], **few_shot_params) elif params.method in ['relationnet', 'relationnet_softmax']: if params.model == 'Conv4': feature_model = backbone.Conv4NP elif params.model == 'Conv6': feature_model = backbone.Conv6NP elif params.model == 'Conv4S': feature_model = backbone.Conv4SNP