def test(test_data, model, args, num_episodes, verbose=True, sampled_tasks=None):
    '''
        Evaluate the model on a bag of sampled tasks. Return the mean accuracy
        and its std.

        Each task is scored on a scratch copy of the model (fast_model), so
        `model` itself is never handed to test_one. Before every task,
        _copy_weights is called with `model` as first argument and the copy
        as second (presumably resetting the copy to the weights of `model`).

        Args:
            test_data: data given to ParallelSampler when sampled_tasks is None.
            model (dict): {'ebd': ..., 'clf': ...}
            args: reads args.notqdm; also forwarded to ParallelSampler and
                test_one.
            num_episodes (int): maximum number of tasks to evaluate.
            verbose (bool): print the mean and std when True.
            sampled_tasks: optional iterable of tasks; sampled from test_data
                when omitted.

        Returns:
            (mean, std) of the values returned by test_one.
    '''
    # clone the original model
    fast_model = {
        'ebd': copy.deepcopy(model['ebd']),
        'clf': copy.deepcopy(model['clf']),
    }

    if sampled_tasks is None:
        sampled_tasks = ParallelSampler(test_data, args,
                                        num_episodes).get_epoch()

    acc = []

    use_tqdm = not args.notqdm
    sampled_tasks = enumerate(sampled_tasks)
    if use_tqdm:
        sampled_tasks = tqdm(sampled_tasks, total=num_episodes, ncols=80,
                             leave=False,
                             desc=colored('Testing on val', 'yellow'))

    for i, task in sampled_tasks:
        # Enforce the episode cap whether or not the progress bar is shown;
        # only the tqdm wrapper has (and needs) a close() call.
        if i == num_episodes:
            if use_tqdm:
                sampled_tasks.close()
            break

        _copy_weights(model['ebd'], fast_model['ebd'])
        _copy_weights(model['clf'], fast_model['clf'])
        acc.append(test_one(task, fast_model, args))

    acc = np.array(acc)

    if verbose:
        print("{}, {:s} {:>7.4f}, {:s} {:>7.4f}".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            colored("acc mean", "blue"),
            np.mean(acc),
            colored("std", "blue"),
            np.std(acc),
        ))

    return np.mean(acc), np.std(acc)
def test(test_data, model, args, num_episodes, verbose=True, sampled_tasks=None):
    '''
        Score the model on a collection of sampled tasks.

        Both sub-modules of the model are switched to eval mode, every task
        is passed to test_one, and the mean and std of the results are
        returned (and printed when verbose is True).

        Args:
            test_data: data given to ParallelSampler when sampled_tasks is None.
            model (dict): {'ebd': ..., 'clf': ...}
            args: reads args.notqdm; also forwarded to ParallelSampler and
                test_one.
            num_episodes (int): number of tasks to sample; also used as the
                progress-bar total.
            verbose (bool): print the summary line when True.
            sampled_tasks: optional iterable of tasks to use instead of
                sampling new ones.

        Returns:
            (mean, std) of the values returned by test_one.
    '''
    for part in ('ebd', 'clf'):
        model[part].eval()

    episodes = sampled_tasks
    if episodes is None:
        sampler = ParallelSampler(test_data, args, num_episodes)
        episodes = sampler.get_epoch()

    if not args.notqdm:
        episodes = tqdm(episodes, total=num_episodes, ncols=80, leave=False,
                        desc=colored('Testing on val', 'yellow'))

    scores = np.array([test_one(episode, model, args) for episode in episodes])

    if verbose:
        stamp = datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')
        summary = "{}, {:s} {:>7.4f}, {:s} {:>7.4f}".format(
            stamp,
            colored("acc mean", "blue"),
            np.mean(scores),
            colored("std", "blue"),
            np.std(scores),
        )
        print(summary, flush=True)

    return np.mean(scores), np.std(scores)
def train(train_data, val_data, model, args):
    '''
        Train the model with batched meta-updates and use val_data for early
        stopping.

        Every epoch runs args.train_episodes meta-updates. Each meta-update
        passes args.maml_batchsize tasks through train_one on a scratch copy
        of the model and then calls _meta_update on `model`. After each epoch
        the model is scored on validation tasks; whenever the validation
        accuracy improves, the weights are checkpointed under ./tmp-runs/.
        Training stops after args.patience epochs in a row without
        improvement, or after args.train_epochs epochs.

        The best checkpoint (if any) is loaded back into `model`, and when
        args.save is set it is also written to ./saved-runs/ together with a
        dump of args.

        Args:
            train_data: data used to sample training tasks.
            val_data: data used to sample validation tasks.
            model (dict): {'ebd': ..., 'clf': ...}; updated in place.
            args: reads lr, train_episodes, maml_batchsize, val_episodes,
                train_epochs, notqdm, clip_grad, patience and save.

        Returns:
            None
    '''
    def _now():
        # timestamp prefix shared by every log line
        return datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')

    # creating a tmp directory to save the models
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    os.makedirs(out_dir, exist_ok=True)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    # NOTE: the learning rate is constant; no LR scheduler is stepped in
    # this loop.
    opt = torch.optim.Adam(grad_param(model, ['ebd', 'clf']), lr=args.lr)

    # clone the original model
    fast_model = {
        'ebd': copy.deepcopy(model['ebd']),
        'clf': copy.deepcopy(model['clf']),
    }

    print("{}, Start training".format(_now()))

    train_gen = ParallelSampler(train_data, args,
                                args.train_episodes * args.maml_batchsize)
    val_gen = ParallelSampler(val_data, args, args.val_episodes)

    for ep in range(args.train_epochs):
        sampled_tasks = train_gen.get_epoch()

        meta_grad_dict = {'clf': [], 'ebd': []}

        train_episodes = range(args.train_episodes)
        if not args.notqdm:
            train_episodes = tqdm(train_episodes, ncols=80, leave=False,
                                  desc=colored('Training on train', 'yellow'))

        for _ in train_episodes:
            # update the initialization based on a batch of tasks
            total_grad = {'ebd': [], 'clf': []}

            for _ in range(args.maml_batchsize):
                task = next(sampled_tasks)

                # clone the current initialization
                _copy_weights(model['ebd'], fast_model['ebd'])
                _copy_weights(model['clf'], fast_model['clf'])

                # get the meta gradient
                train_one(task, fast_model, args, total_grad)

            # `task` is the last task of the batch
            ebd_grad, clf_grad = _meta_update(
                model, total_grad, opt, task, args.maml_batchsize,
                args.clip_grad)
            meta_grad_dict['ebd'].append(ebd_grad)
            meta_grad_dict['clf'].append(clf_grad)

        # evaluate training accuracy
        if ep % 10 == 0:
            acc, std = test(train_data, model, args,
                            args.train_episodes * args.maml_batchsize,
                            False, train_gen.get_epoch())
            print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                _now(),
                "ep", ep,
                colored("train", "red"),
                colored("acc:", "blue"), acc, std,
            ), flush=True)

        # evaluate validation accuracy
        cur_acc, cur_std = test(val_data, model, args, args.val_episodes,
                                False, val_gen.get_epoch())
        print(("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} "
               "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
            _now(),
            "ep", ep,
            colored("val  ", "cyan"),
            colored("acc:", "blue"), cur_acc, cur_std,
            colored("train stats", "cyan"),
            colored("ebd_grad:", "blue"),
            np.mean(np.array(meta_grad_dict['ebd'])),
            colored("clf_grad:", "blue"),
            np.mean(np.array(meta_grad_dict['clf']))), flush=True)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(_now(), best_path))

            torch.save(model['ebd'].state_dict(), best_path + '.ebd')
            torch.save(model['clf'].state_dict(), best_path + '.clf')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past patience epochs
        if sub_cycle == args.patience:
            break

    print("{}, End of training. Restore the best weights".format(_now()))

    if best_path is None:
        # No epoch beat the initial best_acc of 0 (or no epoch ran), so no
        # checkpoint exists; keep the weights the model currently has.
        print("{}, No checkpoint was saved; keeping the current weights".format(
            _now()), flush=True)
    else:
        # restore the best saved model
        model['ebd'].load_state_dict(torch.load(best_path + '.ebd'))
        model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "saved-runs",
                         str(int(time.time() * 1e7))))
        os.makedirs(out_dir, exist_ok=True)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(_now(), best_path),
              flush=True)

        torch.save(model['ebd'].state_dict(), best_path + '.ebd')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))
def train(train_data, val_data, model, args):
    '''
        Train the model
        Use val_data to do early stopping

        Every epoch passes the sampled training tasks through train_one.
        After each epoch the model is scored on validation tasks; whenever
        the validation accuracy improves, the weights are checkpointed under
        ./tmp-runs/. Training stops after args.patience epochs in a row
        without improvement, or after args.train_epochs epochs.

        The best checkpoint (if any) is loaded back into `model`, and when
        args.save is set it is also written to ./saved-runs/ together with a
        dump of args.

        Args:
            train_data: data used to sample training tasks.
            val_data: data used to sample validation tasks.
            model (dict): {'ebd': embedding, 'clf': classifier}
            args: reads lr, train_episodes, val_episodes, train_epochs,
                notqdm, patience and save.

        Returns:
            None
    '''
    def _now():
        # timestamp prefix shared by every log line
        return datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')

    # creating a tmp directory to save the models
    out_dir = os.path.abspath(os.path.join(
        os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    os.makedirs(out_dir, exist_ok=True)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    # parameters handed to the optimizer, as selected by grad_param for the
    # 'ebd' and 'clf' parts of the model
    params_to_opt = grad_param(model, ['ebd', 'clf'])
    # NOTE: the learning rate is constant; no LR scheduler is stepped in
    # this loop.
    opt = torch.optim.Adam(params_to_opt, lr=args.lr)

    print("{}, Start training".format(_now()), flush=True)

    train_gen = ParallelSampler(train_data, args, args.train_episodes)
    # separate sampler over the training data, used only to report
    # training accuracy
    train_gen_val = ParallelSampler(train_data, args, args.val_episodes)
    val_gen = ParallelSampler(val_data, args, args.val_episodes)

    for ep in range(args.train_epochs):
        sampled_tasks = train_gen.get_epoch()

        grad = {'clf': [], 'ebd': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks, total=train_gen.num_episodes,
                                 ncols=80, leave=False,
                                 desc=colored('Training on train', 'yellow'))

        for task in sampled_tasks:
            # a None task ends the epoch early
            if task is None:
                break
            train_one(task, model, opt, args, grad)

        # evaluate training accuracy
        if ep % 10 == 0:
            acc, std = test(train_data, model, args, args.val_episodes, False,
                            train_gen_val.get_epoch())
            print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                _now(),
                "ep", ep,
                colored("train", "red"),
                colored("acc:", "blue"), acc, std,
            ), flush=True)

        # Evaluate validation accuracy
        cur_acc, cur_std = test(val_data, model, args, args.val_episodes,
                                False, val_gen.get_epoch())
        print(("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
               "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
            _now(),
            "ep", ep,
            colored("val  ", "cyan"),
            colored("acc:", "blue"), cur_acc, cur_std,
            colored("train stats", "cyan"),
            colored("ebd_grad:", "blue"), np.mean(np.array(grad['ebd'])),
            colored("clf_grad:", "blue"), np.mean(np.array(grad['clf'])),
        ), flush=True)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(_now(), best_path))

            torch.save(model['ebd'].state_dict(), best_path + '.ebd')
            torch.save(model['clf'].state_dict(), best_path + '.clf')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past patience epochs
        if sub_cycle == args.patience:
            break

    print("{}, End of training. Restore the best weights".format(_now()),
          flush=True)

    if best_path is None:
        # No epoch beat the initial best_acc of 0 (or no epoch ran), so no
        # checkpoint exists; keep the weights the model currently has.
        print("{}, No checkpoint was saved; keeping the current weights".format(
            _now()), flush=True)
    else:
        # restore the best saved model
        model['ebd'].load_state_dict(torch.load(best_path + '.ebd'))
        model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(os.path.join(
            os.path.curdir, "saved-runs", str(int(time.time() * 1e7))))
        os.makedirs(out_dir, exist_ok=True)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(_now(), best_path),
              flush=True)

        torch.save(model['ebd'].state_dict(), best_path + '.ebd')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))