def test(test_data, class_names, optCLF, model, args, num_episodes, verbose=True):
    '''
        Evaluate the model on a bag of sampled tasks.
        Return the mean accuracy and its std.
    '''
    # Keep both modules in train mode: test_one adapts the classifier on each
    # task's support set before scoring the query set.
    model['G'].train()
    model['clf'].train()

    acc = []
    for ep in range(num_episodes):
        # Sample the target classes (and source classes) for this episode.
        sampled_classes, source_classes = task_sampler(test_data, args)

        train_gen = ParallelSampler(test_data, args, sampled_classes,
                                    source_classes, args.train_episodes)
        sampled_tasks = train_gen.get_epoch()

        grad = {'clf': [], 'G': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks, total=train_gen.num_episodes,
                                 ncols=80, leave=False,
                                 desc=colored('Testing', 'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            q_acc = test_one(task, class_names, model, optCLF, args, grad)
            acc.append(q_acc.cpu().item())

    acc = np.array(acc)

    if verbose:
        print("{}, {:s} {:>7.4f}, {:s} {:>7.4f}".format(
            datetime.datetime.now(),
            colored("test acc mean", "blue"),
            np.mean(acc),
            colored("test std", "blue"),
            np.std(acc),
        ), flush=True)

    return np.mean(acc), np.std(acc)
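# A minimal sketch of the per-task routine that test() assumes test_one()
# implements (hypothetical -- the real test_one is defined elsewhere in this
# repo): one classifier fine-tuning step on the support set via optCLF, then
# query-set accuracy. This adaptation step is why test() keeps model['G'] and
# model['clf'] in train mode rather than eval mode.
def _test_one_sketch(task, class_names, model, optCLF, args, grad):
    support, query = task
    XS = model['G'](support)                          # embed support examples
    loss = torch.nn.functional.cross_entropy(
        model['clf'](XS), support['label'])
    optCLF.zero_grad()
    loss.backward()
    optCLF.step()                                     # task-specific update
    XQ = model['G'](query)                            # embed query examples
    pred = torch.argmax(model['clf'](XQ), dim=1)
    return (pred == query['label']).float().mean()    # accuracy as a tensor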
def train(train_data, val_data, model, args):
    '''
        Train the model.
        Use val_data to do early stopping.
    '''
    # create a tmp directory to save the models
    out_dir = os.path.abspath(os.path.join(
        os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    optG = torch.optim.Adam(grad_param(model, ['G', 'clf']), lr=args.lr_g)
    optD = torch.optim.Adam(grad_param(model, ['D']), lr=args.lr_d)

    if args.lr_scheduler == 'ReduceLROnPlateau':
        schedulerG = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optG, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
        schedulerD = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optD, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
    elif args.lr_scheduler == 'ExponentialLR':
        schedulerG = torch.optim.lr_scheduler.ExponentialLR(
            optG, gamma=args.ExponentialLR_gamma)
        schedulerD = torch.optim.lr_scheduler.ExponentialLR(
            optD, gamma=args.ExponentialLR_gamma)

    print("{}, Start training".format(datetime.datetime.now()), flush=True)

    train_gen_val = ParallelSampler_Test(train_data, args, args.val_episodes)
    val_gen = ParallelSampler_Test(val_data, args, args.val_episodes)

    for ep in range(args.train_epochs):
        sampled_classes, source_classes = task_sampler(train_data, args)

        train_gen = ParallelSampler(train_data, args, sampled_classes,
                                    source_classes, args.train_episodes)
        sampled_tasks = train_gen.get_epoch()

        grad = {'clf': [], 'G': [], 'D': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks, total=train_gen.num_episodes,
                                 ncols=80, leave=False,
                                 desc=colored('Training on train', 'yellow'))

        d_acc = 0
        for task in sampled_tasks:
            if task is None:
                break
            d_acc += train_one(task, model, optG, optD, args, grad)

        d_acc = d_acc / args.train_episodes
        print("---------------ep:" + str(ep) + " d_acc:" + str(d_acc) +
              "-----------")

        if ep % 10 == 0:
            acc, std, _ = test(train_data, model, args, args.val_episodes,
                               False, train_gen_val.get_epoch())
            print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                datetime.datetime.now(),
                "ep", ep,
                colored("train", "red"),
                colored("acc:", "blue"),
                acc, std,
            ), flush=True)

        # Evaluate validation accuracy
        cur_acc, cur_std, _ = test(val_data, model, args, args.val_episodes,
                                   False, val_gen.get_epoch())
        print(("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
               "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
            datetime.datetime.now(),
            "ep", ep,
            colored("val ", "cyan"),
            colored("acc:", "blue"), cur_acc, cur_std,
            colored("train stats", "cyan"),
            colored("G_grad:", "blue"), np.mean(np.array(grad['G'])),
            colored("clf_grad:", "blue"), np.mean(np.array(grad['clf'])),
        ), flush=True)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(
                datetime.datetime.now(), best_path))
            torch.save(model['G'].state_dict(), best_path + '.G')
            torch.save(model['D'].state_dict(), best_path + '.D')
            torch.save(model['clf'].state_dict(), best_path + '.clf')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past patience epochs
        if sub_cycle == args.patience:
            break

        if args.lr_scheduler == 'ReduceLROnPlateau':
            schedulerG.step(cur_acc)
            schedulerD.step(cur_acc)
        elif args.lr_scheduler == 'ExponentialLR':
            schedulerG.step()
            schedulerD.step()

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now()), flush=True)

    # restore the best saved model
    model['G'].load_state_dict(torch.load(best_path + '.G'))
    model['D'].load_state_dict(torch.load(best_path + '.D'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(os.path.join(
            os.path.curdir, "saved-runs", str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')
        print("{}, Save best model to {}".format(
            datetime.datetime.now(), best_path), flush=True)
        torch.save(model['G'].state_dict(), best_path + '.G')
        torch.save(model['D'].state_dict(), best_path + '.D')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return
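# Both train() variants here collect trainable parameters through a
# grad_param helper. A plausible implementation (a sketch; the actual helper
# is defined in the repo's utilities, not here) chains the parameters of the
# named sub-modules so a single optimizer can update several entries of the
# model dict at once:
import itertools

def _grad_param_sketch(model, keys):
    # e.g. _grad_param_sketch(model, ['G', 'clf']) yields every trainable
    # parameter of the embedding network and the classifier.
    return filter(lambda p: p.requires_grad,
                  itertools.chain.from_iterable(
                      model[key].parameters() for key in keys))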
def train(train_data, val_data, model, class_names, args):
    '''
        Train the model.
        Use val_data to do early stopping.
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    optG = torch.optim.Adam(grad_param(model, ['G']), lr=args.meta_lr)
    optG2 = torch.optim.Adam(grad_param(model, ['G2']), lr=args.task_lr)
    optCLF = torch.optim.Adam(grad_param(model, ['clf']), lr=args.task_lr)

    if args.lr_scheduler == 'ReduceLROnPlateau':
        schedulerG = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optG, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
        schedulerCLF = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optCLF, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
    elif args.lr_scheduler == 'ExponentialLR':
        schedulerG = torch.optim.lr_scheduler.ExponentialLR(
            optG, gamma=args.ExponentialLR_gamma)
        schedulerCLF = torch.optim.lr_scheduler.ExponentialLR(
            optCLF, gamma=args.ExponentialLR_gamma)

    print("{}, Start training".format(datetime.datetime.now()), flush=True)

    acc = 0
    loss = 0
    for ep in range(args.train_epochs):
        sampled_classes, source_classes = task_sampler(train_data, args)

        # Restrict the class-name features to the classes sampled for this task.
        class_names_dict = {}
        class_names_dict['label'] = class_names['label'][sampled_classes]
        class_names_dict['text'] = class_names['text'][sampled_classes]
        class_names_dict['text_len'] = class_names['text_len'][sampled_classes]
        class_names_dict['is_support'] = False

        train_gen = ParallelSampler(train_data, args, sampled_classes,
                                    source_classes, args.train_episodes)
        sampled_tasks = train_gen.get_epoch()
        class_names_dict = utils.to_tensor(class_names_dict, args.cuda,
                                           exclude_keys=['is_support'])

        grad = {'clf': [], 'G': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks, total=train_gen.num_episodes,
                                 ncols=80, leave=False,
                                 desc=colored('Training on train', 'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            q_loss, q_acc = train_one(task, class_names_dict, model, optG,
                                      optG2, optCLF, args, grad)
            acc += q_acc
            loss += q_loss

        if ep % 100 == 0:
            print("--------[TRAIN] ep:" + str(ep) + ", loss:" +
                  str(q_loss.item()) + ", acc:" + str(q_acc.item()) +
                  "-----------")

        if (ep % 200 == 0) and (ep != 0):
            acc = acc / args.train_episodes / 200
            loss = loss / args.train_episodes / 200
            print("--------[TRAIN] ep:" + str(ep) + ", mean_loss:" +
                  str(loss.item()) + ", mean_acc:" + str(acc.item()) +
                  "-----------")

            # Evaluate a copy of the model so test-time adaptation does not
            # disturb the meta-trained weights.
            net = copy.deepcopy(model)
            acc, std = test(train_data, class_names, optCLF, net, args,
                            args.test_epochs, False)
            print("[TRAIN] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                datetime.datetime.now(),
                "ep", ep,
                colored("train", "red"),
                colored("acc:", "blue"),
                acc, std,
            ), flush=True)
            acc = 0
            loss = 0

            # Evaluate validation accuracy
            cur_acc, cur_std = test(val_data, class_names, optCLF, net, args,
                                    args.test_epochs, False)
            print(("[EVAL] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
                   "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
                datetime.datetime.now(),
                "ep", ep,
                colored("val ", "cyan"),
                colored("acc:", "blue"), cur_acc, cur_std,
                colored("train stats", "cyan"),
                colored("G_grad:", "blue"), np.mean(np.array(grad['G'])),
                colored("clf_grad:", "blue"), np.mean(np.array(grad['clf'])),
            ), flush=True)

            # Update the current best model if val acc is better
            if cur_acc > best_acc:
                best_acc = cur_acc
                best_path = os.path.join(out_dir, str(ep))

                # save current model
                print("{}, Save cur best model to {}".format(
                    datetime.datetime.now(), best_path))
                torch.save(model['G'].state_dict(), best_path + '.G')
                torch.save(model['G2'].state_dict(), best_path + '.G2')
                torch.save(model['clf'].state_dict(), best_path + '.clf')

                sub_cycle = 0
            else:
                sub_cycle += 1

            # Break if the val acc hasn't improved in the past patience epochs
            if sub_cycle == args.patience:
                break

            if args.lr_scheduler == 'ReduceLROnPlateau':
                schedulerG.step(cur_acc)
                schedulerCLF.step(cur_acc)
            elif args.lr_scheduler == 'ExponentialLR':
                schedulerG.step()
                schedulerCLF.step()

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now()), flush=True)

    # restore the best saved model
    model['G'].load_state_dict(torch.load(best_path + '.G'))
    model['G2'].load_state_dict(torch.load(best_path + '.G2'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "saved-runs",
                         str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')
        print("{}, Save best model to {}".format(
            datetime.datetime.now(), best_path), flush=True)
        torch.save(model['G'].state_dict(), best_path + '.G')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return optG, optCLF
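# End-to-end usage sketch for the meta-learning variant (all loader/factory
# names below are hypothetical; the repo's own entry point wires these up):
#
#   train_data, val_data, test_data, class_names = load_dataset(args)
#   model = {'G': G, 'G2': G2, 'clf': clf}   # modules built from args
#   optG, optCLF = train(train_data, val_data, model, class_names, args)
#   test_acc, test_std = test(test_data, class_names, optCLF, model, args,
#                             num_episodes=args.test_epochs)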