import copy
import datetime
import os
import random
import time

import numpy as np
import torch
from termcolor import colored
from tqdm import tqdm

# Project-local helpers (task_sampler, SerialSampler, train_one, test_one,
# grad_param, pre_calculate) are assumed to be imported from the surrounding
# package; their exact module paths are not shown in this file.


def test(test_data, class_names, optG, model, criterion, args, test_epoch,
         verbose=True):
    '''
        Evaluate the model on a bag of sampled tasks. Return the mean accuracy
        and its std.

        NOTE: this definition is shadowed by the test() redefined at the
        bottom of this file; Python binds the name to whichever def runs
        last, so callers (including train()) get that later version.
    '''
    # model['G'].train()
    acc = []
    for ep in range(test_epoch):
        # sample an N-way task and build a one-episode sampler for it
        sampled_classes, source_classes = task_sampler(test_data, args)
        train_gen = SerialSampler(test_data, args, sampled_classes,
                                  source_classes, 1)

        sampled_tasks = train_gen.get_epoch()
        for task in sampled_tasks:
            if task is None:
                break
            q_acc = test_one(task, class_names, model, optG, criterion, args,
                             grad={})
            acc.append(q_acc.cpu().item())

    acc = np.array(acc)

    if verbose:
        print("{}, {:s} {:>7.4f}, {:s} {:>7.4f}".format(
            datetime.datetime.now(),
            colored("test acc mean", "blue"),
            np.mean(acc),
            colored("test std", "blue"),
            np.std(acc),
        ), flush=True)

    return np.mean(acc), np.std(acc)
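# A small reporting helper: a minimal sketch, assuming the (mean, std) pair
# returned by test() comes from n independent episodes and that a normal
# approximation is acceptable (the usual "acc ± ci" convention in few-shot
# papers). The function name is illustrative, not part of the project API.
def confidence_interval_95(mean, std, n):
    # std / sqrt(n) estimates the standard error of the mean; 1.96 is the
    # two-sided z-score for 95% coverage
    half_width = 1.96 * std / np.sqrt(n)
    return mean - half_width, mean + half_width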
def train(train_data, val_data, test_data, model, class_names, criterion, args):
    '''
        Train the model
        Use val_data to do early stopping
    '''
    # create a tmp directory to save the models
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    if args.STS:
        classes_sample_p, example_prob_metrix = pre_calculate(
            train_data, class_names, model['G'], args)
    else:
        classes_sample_p, example_prob_metrix = None, None

    optG = torch.optim.Adam(grad_param(model, ['G']), lr=args.meta_lr,
                            weight_decay=args.weight_decay)

    if args.lr_scheduler == 'ReduceLROnPlateau':
        schedulerG = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optG, 'max', patience=args.patience // 2, factor=0.1, verbose=True)
    elif args.lr_scheduler == 'ExponentialLR':
        schedulerG = torch.optim.lr_scheduler.ExponentialLR(
            optG, gamma=args.ExponentialLR_gamma)

    print("{}, Start training".format(datetime.datetime.now()), flush=True)

    acc = 0
    loss = 0
    for ep in range(args.train_epochs):
        ep_loss = 0
        for _ in range(args.train_episodes):
            # sample a task (optionally biased by the STS sampling
            # probabilities) and run one episode on it
            sampled_classes, source_classes = task_sampler(
                train_data, args, classes_sample_p)
            train_gen = SerialSampler(train_data, args, sampled_classes,
                                      source_classes, 1, example_prob_metrix)

            sampled_tasks = train_gen.get_epoch()
            grad = {'clf': [], 'G': []}

            if not args.notqdm:
                sampled_tasks = tqdm(sampled_tasks,
                                     total=train_gen.num_episodes,
                                     ncols=80,
                                     leave=False,
                                     desc=colored('Training on train',
                                                  'yellow'))

            for task in sampled_tasks:
                if task is None:
                    break
                q_loss, q_acc = train_one(task, class_names, model, optG,
                                          criterion, args, grad)
                acc += q_acc
                loss = loss + q_loss
                ep_loss = ep_loss + q_loss

        # average the accumulated episode losses and take one meta-update step
        ep_loss = ep_loss / args.train_episodes
        optG.zero_grad()
        ep_loss.backward()
        optG.step()

        if ep % 100 == 0:
            print("{}:".format(colored('--------[TRAIN] ep', 'blue')) + str(ep)
                  + ", loss:" + str(q_loss.item())
                  + ", acc:" + str(q_acc.item()) + "-----------")

        test_count = 100
        # note: this also fires at ep == 0, where the running mean below is
        # still divided by a full test_count window of epochs
        if ep % test_count == 0:
            acc = acc / args.train_episodes / test_count
            loss = loss / args.train_episodes / test_count
            print("{}:".format(colored('--------[TRAIN] ep', 'blue')) + str(ep)
                  + ", mean_loss:" + str(loss.item())
                  + ", mean_acc:" + str(acc.item()) + "-----------")

            net = copy.deepcopy(model)
            acc = 0
            loss = 0

            # Evaluate test accuracy
            cur_acc, cur_std = test(test_data, class_names, optG, net,
                                    criterion, args, args.test_epochs, False)
            print(("[TEST] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, ").format(
                datetime.datetime.now(),
                "ep", ep,
                colored("test ", "cyan"),
                colored("acc:", "blue"), cur_acc, cur_std,
            ), flush=True)

            # Evaluate validation accuracy
            cur_acc, cur_std = test(val_data, class_names, optG, net,
                                    criterion, args, args.test_epochs, False)
            print(("[EVAL] {}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, ").format(
                datetime.datetime.now(),
                "ep", ep,
                colored("val ", "cyan"),
                colored("acc:", "blue"), cur_acc, cur_std,
            ), flush=True)

            # Update the current best model if val acc is better
            if cur_acc > best_acc:
                best_acc = cur_acc
                best_path = os.path.join(out_dir, str(ep))

                # save current model
                print("{}, Save cur best model to {}".format(
                    datetime.datetime.now(), best_path))
                torch.save(model['G'].state_dict(), best_path + '.G')

                sub_cycle = 0
            else:
                sub_cycle += 1

            # Break if the val acc hasn't improved in the past patience
            # validation rounds
            if sub_cycle == args.patience:
                break

            if args.lr_scheduler == 'ReduceLROnPlateau':
                schedulerG.step(cur_acc)
            elif args.lr_scheduler == 'ExponentialLR':
                schedulerG.step()

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now()), flush=True)

    # restore the best saved model
    model['G'].load_state_dict(torch.load(best_path + '.G'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "saved-runs",
                         str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(datetime.datetime.now(),
                                                 best_path), flush=True)

        torch.save(model['G'].state_dict(), best_path + '.G')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return optG
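# A minimal end-to-end sketch of the intended call order. load_dataset and
# build_encoder are hypothetical stand-ins for the project's own loaders and
# encoder constructor; only the train()/test() calls mirror this file.
def run_experiment(args):
    train_data, val_data, test_data, class_names = load_dataset(args)  # hypothetical
    model = {'G': build_encoder(args)}  # hypothetical; train()/test() only touch 'G'
    criterion = torch.nn.CrossEntropyLoss()

    # train() early-stops on validation accuracy and restores the best
    # G weights before returning the meta-optimizer
    optG = train(train_data, val_data, test_data, model, class_names,
                 criterion, args)

    # final evaluation on freshly sampled test tasks
    return test(test_data, class_names, optG, model, criterion, args,
                args.test_epochs)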
def test(test_data, class_names, optG, model, criterion, args, num_episodes,
         verbose=True):
    '''
        Evaluate the model on a bag of sampled tasks. Return the mean accuracy
        and its std.

        This definition shadows the test() at the top of the file, so it is
        the version train() actually calls.
    '''
    model['G'].train()
    acc = []
    for ep in range(num_episodes):
        if ep % 100 == 0:
            print(ep)

        sampled_classes, source_classes = task_sampler(test_data, args)

        train_gen = SerialSampler(test_data, args, sampled_classes,
                                  source_classes, args.train_episodes)

        sampled_tasks = train_gen.get_epoch()
        grad = {'clf': [], 'G': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks,
                                 total=train_gen.num_episodes,
                                 ncols=80,
                                 leave=False,
                                 desc=colored('Testing on test', 'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            q_acc = test_one(task, class_names, model, optG, criterion, args,
                             grad)
            acc.append(q_acc.cpu().item())

    acc = np.array(acc)

    if verbose:
        print("{}, {:s} {:>7.4f}, {:s} {:>7.4f}".format(
            datetime.datetime.now(),
            colored("test acc mean", "blue"),
            np.mean(acc),
            colored("test std", "blue"),
            np.std(acc),
        ), flush=True)

    return np.mean(acc), np.std(acc)
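# A minimal reproducibility sketch, assuming the goal is to make the sampled
# evaluation tasks repeatable across runs: seed every RNG that task_sampler
# and the model might draw from before calling test(). The helper name is
# illustrative, not part of the project API.
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # guarded so the sketch also runs on CPU-only machines
        torch.cuda.manual_seed_all(seed)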