# --- Meta-training: outer loop over meta-iterations ---------------------------
# NOTE(review): source formatting was mangled (all statements on one physical
# line); indentation below is reconstructed from the evident control flow.
# Early-stopping bookkeeping; presumably consumed further down in the loop
# (outside this fragment) — TODO confirm.
patience = 100
best_loss = 10000000
stop_count = 0
# Main loop
for meta_iteration in range(config.epochs):
    ## save original weights to make the update
    # Snapshot of the meta-model before the per-task fast-adaptation steps.
    weights_original = deepcopy(meta_net.state_dict())
    train_loss_before = []  # exp(val loss) per task BEFORE any adaptation
    train_loss_meta = []    # exp(val loss) per task AFTER the fast updates
    #loss accumulate from a batch of tasks
    batch_loss = 0
    for _ in range(meta_batch_size):
        # Get task: draw one persona from the task iterator.
        # `fix_dialnum_train` selects a loader balanced by dialogue count.
        if config.fix_dialnum_train:
            train_iter, val_iter = p.get_balanced_loader(
                persona=tasks_iter.__next__(),
                batch_size=config.batch_size,
                split='train')
        else:
            train_iter, val_iter = p.get_data_loader(
                persona=tasks_iter.__next__(),
                batch_size=config.batch_size,
                split='train')
        #before first update
        v_loss, v_ppl = do_evaluation(meta_net, val_iter)
        train_loss_before.append(math.exp(v_loss))
        # Update fast nets: run `config.meta_iteration` inner steps on the
        # task's train split; returns the post-adaptation validation loss.
        val_loss, v_ppl = do_learning_fix_step(
            meta_net, train_iter, val_iter,
            iterations=config.meta_iteration)
        train_loss_meta.append(math.exp(val_loss.item()))
        # Accumulate the task loss (a tensor — .item() above implies so) for
        # the meta-update performed after this inner loop.
        batch_loss += val_loss
    # log
# clip gradient nn.utils.clip_grad_norm_(meta_net.parameters(), config.max_grad_norm) meta_optimizer.step() print('Meta_iteration:', meta_iteration) val_loss_before = [] val_loss_meta = [] weights_original = deepcopy(meta_net.state_dict()) for idx, per in enumerate(p.get_personas('valid')): #num_of_dialog = p.get_num_of_dialog(persona=per, split='valid') #for dial_i in range(num_of_dialog): if config.fix_dialnum_train: train_iter, val_iter = p.get_balanced_loader( persona=per, batch_size=config.batch_size, split='valid', fold=0) else: train_iter, val_iter = p.get_data_loader( persona=per, batch_size=config.batch_size, split='valid', fold=0) # zero shot result loss, ppl = do_evaluation(meta_net, val_iter) val_loss_before.append(math.exp(loss)) # meta tuning val_loss, val_ppl = do_learning_fix_step( meta_net, train_iter, val_iter, iterations=config.meta_iteration)
# Build model, optimizer, and set states print("Test model", config.model) model = Transformer(p.vocab, model_file_path=config.save_path, is_eval=False) fine_tune = [] iter_per_task = [] iterations = 11 weights_original = deepcopy(model.state_dict()) tasks = p.get_personas('test') for per in tqdm(tasks): num_of_dialog = p.get_num_of_dialog(persona=per, split='test') for val_dial_index in range(num_of_dialog): if config.fix_dialnum_train: train_iter, val_iter = p.get_balanced_loader( persona=per, batch_size=config.batch_size, split='test', fold=val_dial_index, dial_num=config.k_shot) else: train_iter, val_iter = p.get_data_loader( persona=per, batch_size=config.batch_size, split='test', fold=val_dial_index) logger = do_learning(model, train_iter, val_iter, iterations=iterations) fine_tune.append(logger) model.load_state_dict(