def __call__(self, eval_desc="reorder"): """ Args: eval_desc: Returns: eval the target bleu and origin bleu for paraphrase model """ training = self.model.training self.model.eval() dev_log_dict = {} inp_examples, inp_ids = batchify_examples(examples=self.eval_set, sort_key=self.sort_key, batch_size=self.batch_size) eval_start = time.time() for batch_examples in inp_examples: batch_ret = self.model.get_loss(examples=batch_examples, return_enc_state=False, train_iter=-1) dev_log_dict = update_tracker(batch_ret, dev_log_dict) use_time = time.time() - eval_start self.model.training = training return { 'ACC': dev_log_dict['Acc'].mean().item(), 'Relax_ACC': dev_log_dict['relax_correct'] * 100.0 / dev_log_dict['count'], 'Relative_ACC': dev_log_dict['relative_correct'] * 100.0 / dev_log_dict['relative_count'], 'Binary_ACC': dev_log_dict['binary_correct'] * 100.0 / dev_log_dict['relative_count'], 'EVAL TIME': use_time, "EVAL SPEED": len(self.eval_set) / use_time }
def eval_elbo(self, eval_desc='vae-elbo', eval_step=None): model = self.model args = self.model.args training = model.training model.eval() step = eval_step if eval_step is not None else 2 * args.x0 ret_track = {} batch_examples, _ = batchify_examples(examples=self.eval_set, batch_size=self.batch_size, sort=False) for batch in batch_examples: ret_loss = model.get_loss(batch, step) ret_track = update_tracker(ret_loss, ret_track) if self.write_down: write_result(ret_track, fname=os.path.join(self.out_dir, eval_desc + ".score")) model.training = training return ret_track
def evaluate_reconstruction(self, examples, eval_desc, eval_step=None, write_down=True): training = self.vae.training self.vae.eval() step = eval_step if eval_step is not None else 2 * self.vae.args.x0 eval_results = evaluate(examples, self.vae, sort_key='src', eval_tgt='src', batch_size=self.eval_batch_size) pred = predict_to_plain(eval_results['predict']) gold = reference_to_plain(eval_results['reference']) dev_set = Dataset(examples) dev_track = {} for dev_examples in dev_set.batch_iter( batch_size=self.train_batch_size, shuffle=False): ret_loss = self.vae.get_loss(dev_examples, step) dev_track = update_tracker(ret_loss, dev_track) if write_down: write_result(pred, fname=os.path.join(self.out_dir, eval_desc + ".pred")) write_result(gold, fname=os.path.join(self.out_dir, eval_desc + ".gold")) write_result(dev_track, fname=os.path.join(self.out_dir, eval_desc + ".score")) with open(os.path.join( self.out_dir, eval_desc + ".score", ), "a") as f: f.write("{}:{}".format(self.vae.args.eval_mode, eval_results['accuracy'])) self.vae.training = training return dev_track, eval_results
def train_vae(main_args, model_args, model=None): para_eval_dir = "/home/user_data/baoy/projects/seq2seq_parser/data/quora-mh/unsupervised" para_eval_list = ["dev.para.txt"] if "task_type" in model_args and model_args.task_type is not None: main_args.task_type = model_args.task_type dir_ret = get_model_info(main_args, model_args) model, optimizer, vocab = build_model(main_args, model_args, model) model_file = dir_ret['model_file'] log_dir = dir_ret['log_dir'] out_dir = dir_ret['out_dir'] train_set, dev_set = load_data(main_args) model, optimizer, vocab = build_model(main_args, model_args, model) evaluator = SyntaxVaeEvaluator( model=model, out_dir=out_dir, train_batch_size=main_args.batch_size, batch_size=model_args.eval_bs, ) writer = SummaryWriter(log_dir) writer.add_text("model", str(model)) writer.add_text("args", str(main_args)) train_iter = main_args.start_iter epoch = num_trial = patience = 0 history_elbo = [] history_bleu = [] max_kl_item = -1 max_kl_weight = None continue_anneal = model_args.peak_anneal if model_args.peak_anneal: model_args.warm_up = 0 memory_temp_count = 0 t_type = torch.Tensor adv_training = model_args.dis_train and model_args.adv_train if adv_training: print("has the adv training process") adv_syn = model_args.adv_syn > 0. or model_args.infer_weight * model_args.inf_sem adv_sem = model_args.adv_sem > 0. or model_args.infer_weight * model_args.inf_syn print(model_args.dev_item.lower()) while True: epoch += 1 train_track = {} for batch_examples in train_set.batch_iter(batch_size=main_args.batch_size, shuffle=True): train_iter += 1 if adv_training: ret_loss = model.get_loss(batch_examples, train_iter, is_dis=True) if adv_syn: dis_syn_loss = ret_loss['dis syn'] optimizer.zero_grad() dis_syn_loss.backward() if main_args.clip_grad > 0.: torch.nn.utils.clip_grad_norm_(model.parameters(), main_args.clip_grad) # optimizer.step() if adv_sem: ret_loss = model.get_loss(batch_examples, train_iter, is_dis=True) dis_sem_loss = ret_loss['dis sem'] optimizer.zero_grad() dis_sem_loss.backward() if main_args.clip_grad > 0.: torch.nn.utils.clip_grad_norm_(model.parameters(), main_args.clip_grad) # optimizer.step() ret_loss = model.get_loss(batch_examples, train_iter) loss = ret_loss['Loss'] optimizer.zero_grad() loss.backward() if main_args.clip_grad > 0.: torch.nn.utils.clip_grad_norm_(model.parameters(), main_args.clip_grad) optimizer.step() train_iter += 1 # tracker = update_track(loss, train_avg_kl, train_avg_nll, tracker) train_track = update_tracker(ret_loss, train_track) if train_iter % main_args.log_every == 0: train_avg_nll = ret_loss['NLL Loss'] train_avg_kl = ret_loss['KL Loss'] _kl_weight = ret_loss['KL Weight'] for key, val in ret_loss.items(): writer.add_scalar( 'Train-Iter/VAE/{}'.format(key), val.item() if isinstance(val, t_type) else val, train_iter ) print("\rTrain-Iter %04d, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f, WD-Drop %6.3f" % (train_iter, loss.item(), train_avg_nll, train_avg_kl, _kl_weight, model.step_unk_rate), end=' ') writer.add_scalar( tag='Optimize/lr', scalar_value=optimizer.param_groups[0]['lr'], global_step=train_iter, ) if train_iter % main_args.dev_every == 0 and train_iter > model_args.warm_up: # dev_track, eval_results = _test_vae(model, dev_set, main_args, train_iter) dev_track, eval_results = evaluator.evaluate_reconstruction(examples=dev_set.examples, eval_desc="dev{}".format(train_iter), eval_step=train_iter, write_down=False) _weight = model.get_kl_weight(step=train_iter) _kl_item = torch.mean(dev_track['KL Item']) # writer.add_scalar("VAE/Valid-Iter/KL Item", _kl_item, train_iter) for key, val in dev_track.items(): writer.add_scalar( 'Valid-Iter/VAE/{}'.format(key), torch.mean(val) if isinstance(val, t_type) else val, train_iter ) if continue_anneal and model.step_kl_weight is None: if _kl_item > max_kl_item: max_kl_item = _kl_item max_kl_weight = _weight else: if (max_kl_item - _kl_item) > model_args.stop_clip_kl: model.step_kl_weight = max_kl_weight writer.add_text(tag='peak_anneal', text_string="fixed the kl weight:{} with kl peak:{} at step:{}".format( max_kl_weight, max_kl_item, train_iter ), global_step=train_iter) continue_anneal = False dev_elbo = torch.mean(dev_track['Model Score']) writer.add_scalar("Evaluation/VAE/Dev Score", dev_elbo, train_iter) # evaluate bleu dev_bleu = eval_results['accuracy'] print() print("Valid-Iter %04d, NLL_Loss:%9.4f, KL_Loss: %9.4f, Sum Score:%9.4f BLEU:%9.4f" % ( train_iter, torch.mean(dev_track['NLL Loss']), torch.mean(dev_track['KL Loss']), dev_elbo, eval_results['accuracy']), file=sys.stderr ) writer.add_scalar( tag='Evaluation/VAE/Iter %s' % model.args.eval_mode, scalar_value=dev_bleu, global_step=train_iter ) if model_args.dev_item == "ELBO" or model_args.dev_item.lower() == "para-elbo" or model_args.dev_item.lower() == "gen-elbo": is_better = history_elbo == [] or dev_elbo < min(history_elbo) elif model_args.dev_item == "BLEU" or model_args.dev_item.lower() == "para-bleu" or model_args.dev_item.lower() == "gen-bleu": is_better = history_bleu == [] or dev_bleu > max(history_bleu) history_elbo.append(dev_elbo) writer.add_scalar("Evaluation/VAE/Best ELBO Score", min(history_elbo), train_iter) history_bleu.append(dev_bleu) writer.add_scalar("Evaluation/VAE/Best BLEU Score", max(history_bleu), train_iter) if is_better: writer.add_scalar( tag='Evaluation/VAE/Best %s' % model.args.eval_mode, scalar_value=dev_bleu, global_step=train_iter ) writer.add_scalar( tag='Evaluation/VAE/Best NLL-LOSS', scalar_value=torch.mean(dev_track['NLL Loss']), global_step=train_iter ) writer.add_scalar( tag='Evaluation/VAE/Best KL-LOSS', scalar_value=torch.mean(dev_track['KL Loss']), global_step=train_iter ) if train_iter * 2 > model_args.x0: memory_temp_count = 3 if model_args.dev_item.lower().startswith("gen") and memory_temp_count > 0: evaluator.evaluate_generation( sample_size=len(dev_set.examples), eval_desc="gen_iter{}".format(train_iter), ) memory_temp_count -= 1 if model_args.dev_item.lower().startswith("para") and memory_temp_count > 0: para_score = evaluator.evaluate_para( eval_dir=para_eval_dir, eval_list=para_eval_list, eval_desc="para_iter{}".format(train_iter) ) if memory_temp_count == 3: writer.add_scalar( tag='Evaluation/VAE/Para Dev Ori-BLEU', scalar_value=para_score[0][0], global_step=train_iter ) writer.add_scalar( tag='Evaluation/VAE/Para Dev Tgt-BLEU', scalar_value=para_score[0][1], global_step=train_iter ) if len(para_score) > 1: writer.add_scalar( tag='Evaluation/VAE/Para Test Ori-BLEU', scalar_value=para_score[1][0], global_step=train_iter ) writer.add_scalar( tag='Evaluation/VAE/Para Test Tgt-BLEU', scalar_value=para_score[1][1], global_step=train_iter ) memory_temp_count -= 1 model, optimizer, num_trial, patience = get_lr_schedule( is_better=is_better, model_file=model_file, main_args=main_args, patience=patience, num_trial=num_trial, model=model, optimizer=optimizer, reload_model=model_args.reload_model, ) model.train() elbo = torch.mean(train_track['Model Score']) print() print("Train-Epoch %02d, Score %9.4f" % (epoch, elbo)) for key, val in train_track.items(): writer.add_scalar( 'Train-Epoch/VAE/{}'.format(key), torch.mean(val) if isinstance(val, t_type) else val, epoch )
def training(main_args, model_args, model=None): if "task_type" in model_args and model_args.task_type is not None: main_args.task_type = model_args.task_type dir_ret = get_model_info(main_args=main_args, model_args=model_args) model, optimizer, vocab = build_model(main_args=main_args, model_args=model_args, model=model) train_set, dev_set = load_data(main_args) model_file = dir_ret['model_file'] log_dir = dir_ret['log_dir'] out_dir = dir_ret['out_dir'] writer = SummaryWriter(log_dir) GlobalOps.writer_ops = writer writer.add_text("main_args", str(main_args)) writer.add_text("model_args", str(model_args)) print("...... Start Training ......") train_iter = main_args.start_iter train_nums, train_loss = 0., 0. epoch, num_trial, patience, = 0, 0, 0 history_scores = [] task_type = main_args.task_type eval_select = eval_key_dict[task_type.lower()] sort_key = sort_key_dict[ model_args.sort_key] if "sort_key" in model_args else sort_key_dict[ model_args.enc_type] evaluator = get_evaluator(eval_choice=eval_select, model=model, eval_set=dev_set.examples, eval_lists=main_args.eval_lists, sort_key=sort_key, eval_tgt="tgt", batch_size=model_args.eval_bs, out_dir=out_dir, write_down=True) print("Dev ITEM: ", evaluator.score_item) adv_training = "adv_train" in model_args and model_args.adv_train adv_syn, adv_sem = False, False def hyper_init(): adv_syn_ = (model_args.adv_syn + model_args.infer_weight * model_args.inf_sem) > 0. adv_sem_ = (model_args.adv_sem + model_args.infer_weight * model_args.inf_syn) > 0. return adv_syn_, adv_sem_ if adv_training: adv_syn, adv_sem = hyper_init() def normal_training(): optimizer.zero_grad() batch_ret_ = model.get_loss(examples=batch_examples, return_enc_state=False, train_iter=train_iter) batch_loss_ = batch_ret_['Loss'] return batch_ret_ def universal_training(): if adv_training: ret_loss = model.get_loss(examples=batch_examples, train_iter=train_iter, is_dis=True) if adv_syn: dis_syn_loss = ret_loss['dis syn'] optimizer.zero_grad() dis_syn_loss.backward() if main_args.clip_grad > 0.: torch.nn.utils.clip_grad_norm_(model.parameters(), main_args.clip_grad) optimizer.step() if adv_sem: ret_loss = model.get_loss(examples=batch_examples, train_iter=train_iter, is_dis=True) dis_sem_loss = ret_loss['dis sem'] optimizer.zero_grad() dis_sem_loss.backward() if main_args.clip_grad > 0.: torch.nn.utils.clip_grad_norm_(model.parameters(), main_args.clip_grad) optimizer.step() return normal_training() while True: epoch += 1 epoch_begin = time.time() train_log_dict = {} for batch_examples in train_set.batch_iter( batch_size=main_args.batch_size, shuffle=True): train_iter += 1 train_nums += len(batch_examples) # batch_ret = model.get_loss(examples=batch_examples, return_enc_state=False, train_iter=train_iter) batch_ret = universal_training() batch_loss = batch_ret['Loss'] train_loss += batch_loss.sum().item() torch.mean(batch_loss).backward() if main_args.clip_grad > 0.: torch.nn.utils.clip_grad_norm_(model.parameters(), main_args.clip_grad) optimizer.step() train_log_dict = update_tracker(batch_ret, train_log_dict) if train_iter % main_args.log_every == 0: print('\r[Iter %d] Train loss=%.5f' % (train_iter, train_loss / train_nums), file=sys.stdout, end=" ") for key, val in train_log_dict.items(): if isinstance(val, torch.Tensor): writer.add_scalar(tag="{}/Train/{}".format( task_type, key), scalar_value=torch.mean(val).item(), global_step=train_iter) writer.add_scalar(tag="Optimize/lr", scalar_value=optimizer.param_groups[0]['lr'], global_step=train_iter) writer.add_scalar( tag='Optimize/trial', scalar_value=num_trial, global_step=train_iter, ) writer.add_scalar( tag='Optimize/patience', scalar_value=patience, global_step=train_iter, ) if train_iter % main_args.dev_every == 0 and train_iter > model_args.warm_up: eval_result_dict = evaluator() dev_acc = eval_result_dict[evaluator.score_item] if isinstance(dev_acc, torch.Tensor): dev_acc = dev_acc.sum().item() print('\r[Iter %d] %s %s=%.3f took %d s' % (train_iter, task_type, evaluator.score_item, dev_acc, eval_result_dict['EVAL TIME']), file=sys.stdout) is_better = (history_scores == []) or dev_acc > max(history_scores) history_scores.append(dev_acc) writer.add_scalar(tag='%s/Valid/Best %s' % (task_type, evaluator.score_item), scalar_value=max(history_scores), global_step=train_iter) for key, val in eval_result_dict.items(): writer.add_scalar(tag="{}/Valid/{}".format(task_type, key), scalar_value=val.sum().item() if isinstance(val, torch.Tensor) else val, global_step=train_iter) model, optimizer, num_trial, patience = get_lr_schedule( is_better=is_better, model_file=model_file, main_args=main_args, patience=patience, num_trial=num_trial, model=model, optimizer=optimizer, reload_model=False) model.train() epoch_time = time.time() - epoch_begin print('\r[Epoch %d] epoch elapsed %ds' % (epoch, epoch_time), file=sys.stdout)