Example #1
    def __call__(self, eval_desc="reorder"):
        """
        Args:
            eval_desc:

        Returns: eval the target bleu and origin bleu for paraphrase model

        """
        # remember the current train/eval mode so it can be restored afterwards
        training = self.model.training
        self.model.eval()
        dev_log_dict = {}
        inp_examples, inp_ids = batchify_examples(examples=self.eval_set, sort_key=self.sort_key,
                                                  batch_size=self.batch_size)
        eval_start = time.time()
        for batch_examples in inp_examples:
            batch_ret = self.model.get_loss(examples=batch_examples, return_enc_state=False, train_iter=-1)
            dev_log_dict = update_tracker(batch_ret, dev_log_dict)
        use_time = time.time() - eval_start
        self.model.training = training
        return {
            'ACC': dev_log_dict['Acc'].mean().item(),
            'Relax_ACC': dev_log_dict['relax_correct'] * 100.0 / dev_log_dict['count'],
            'Relative_ACC': dev_log_dict['relative_correct'] * 100.0 / dev_log_dict['relative_count'],
            'Binary_ACC': dev_log_dict['binary_correct'] * 100.0 / dev_log_dict['relative_count'],
            'EVAL TIME': use_time,
            "EVAL SPEED": len(self.eval_set) / use_time
        }
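
The helpers `batchify_examples` and `update_tracker` used above are not shown in these examples. A minimal sketch of what `update_tracker` might look like, assuming each batch returns a dict of tensors and scalar counters (the actual implementation in the repository may differ):

import torch

def update_tracker(batch_ret, tracker):
    """Accumulate per-batch results into a running tracker (illustrative sketch, names assumed)."""
    for key, val in batch_ret.items():
        if isinstance(val, torch.Tensor):
            # detach and flatten so calls like tracker['Acc'].mean() work later
            val = val.detach().reshape(-1)
            tracker[key] = val if key not in tracker else torch.cat([tracker[key], val])
        else:
            # scalar counters such as 'count' or 'relax_correct' are summed across batches
            tracker[key] = tracker.get(key, 0) + val
    return tracker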
Example #2
    def eval_elbo(self, eval_desc='vae-elbo', eval_step=None):
        model = self.model
        args = self.model.args
        training = model.training
        model.eval()
        # default to step = 2 * x0 (past the KL-annealing schedule) when no explicit eval step is given
        step = eval_step if eval_step is not None else 2 * args.x0
        ret_track = {}

        batch_examples, _ = batchify_examples(examples=self.eval_set,
                                              batch_size=self.batch_size,
                                              sort=False)
        for batch in batch_examples:
            ret_loss = model.get_loss(batch, step)
            ret_track = update_tracker(ret_loss, ret_track)

        if self.write_down:
            write_result(ret_track,
                         fname=os.path.join(self.out_dir,
                                            eval_desc + ".score"))
        model.training = training
        return ret_track
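
`write_result` is likewise assumed rather than defined here. A plausible sketch, assuming it accepts either a list of output lines or a dict of tracked scores and writes them to `fname`:

import os

def write_result(result, fname):
    """Write evaluation output to disk: dicts as 'key: value' lines, lists one item per line (sketch)."""
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    with open(fname, "w") as f:
        if isinstance(result, dict):
            for key, val in result.items():
                f.write("{}: {}\n".format(key, val))
        else:
            for line in result:
                f.write("{}\n".format(line))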
Example #3
    def evaluate_reconstruction(self,
                                examples,
                                eval_desc,
                                eval_step=None,
                                write_down=True):
        training = self.vae.training
        self.vae.eval()
        step = eval_step if eval_step is not None else 2 * self.vae.args.x0
        eval_results = evaluate(examples,
                                self.vae,
                                sort_key='src',
                                eval_tgt='src',
                                batch_size=self.eval_batch_size)
        pred = predict_to_plain(eval_results['predict'])
        gold = reference_to_plain(eval_results['reference'])

        dev_set = Dataset(examples)
        dev_track = {}
        for dev_examples in dev_set.batch_iter(
                batch_size=self.train_batch_size, shuffle=False):
            ret_loss = self.vae.get_loss(dev_examples, step)
            dev_track = update_tracker(ret_loss, dev_track)
        if write_down:
            write_result(pred,
                         fname=os.path.join(self.out_dir, eval_desc + ".pred"))
            write_result(gold,
                         fname=os.path.join(self.out_dir, eval_desc + ".gold"))
            write_result(dev_track,
                         fname=os.path.join(self.out_dir,
                                            eval_desc + ".score"))
            with open(os.path.join(
                    self.out_dir,
                    eval_desc + ".score",
            ), "a") as f:
                f.write("{}:{}".format(self.vae.args.eval_mode,
                                       eval_results['accuracy']))
        self.vae.training = training
        return dev_track, eval_results
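
Example #3, and the training loops below, iterate over examples through a `Dataset` wrapper with a `batch_iter` method. The class is not part of these snippets; a minimal sketch of the assumed interface:

import random

class Dataset(object):
    """Thin wrapper over a list of examples that yields (optionally shuffled) mini-batches (sketch)."""

    def __init__(self, examples):
        self.examples = examples

    def batch_iter(self, batch_size, shuffle=False):
        order = list(range(len(self.examples)))
        if shuffle:
            random.shuffle(order)
        for start in range(0, len(order), batch_size):
            yield [self.examples[i] for i in order[start:start + batch_size]]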
Example #4
def train_vae(main_args, model_args, model=None):
    para_eval_dir = "/home/user_data/baoy/projects/seq2seq_parser/data/quora-mh/unsupervised"
    para_eval_list = ["dev.para.txt"]
    if "task_type" in model_args and model_args.task_type is not None:
        main_args.task_type = model_args.task_type
    dir_ret = get_model_info(main_args, model_args)
    model, optimizer, vocab = build_model(main_args, model_args, model)
    model_file = dir_ret['model_file']
    log_dir = dir_ret['log_dir']
    out_dir = dir_ret['out_dir']
    train_set, dev_set = load_data(main_args)

    evaluator = SyntaxVaeEvaluator(
        model=model,
        out_dir=out_dir,
        train_batch_size=main_args.batch_size,
        batch_size=model_args.eval_bs,
    )

    writer = SummaryWriter(log_dir)
    writer.add_text("model", str(model))
    writer.add_text("args", str(main_args))

    train_iter = main_args.start_iter
    epoch = num_trial = patience = 0
    history_elbo = []
    history_bleu = []
    max_kl_item = -1
    max_kl_weight = None

    continue_anneal = model_args.peak_anneal

    if model_args.peak_anneal:
        model_args.warm_up = 0

    memory_temp_count = 0

    t_type = torch.Tensor

    adv_training = model_args.dis_train and model_args.adv_train
    if adv_training:
        print("has the adv training process")
    adv_syn = model_args.adv_syn > 0. or model_args.infer_weight * model_args.inf_sem
    adv_sem = model_args.adv_sem > 0. or model_args.infer_weight * model_args.inf_syn

    print(model_args.dev_item.lower())
    while True:
        epoch += 1
        train_track = {}
        for batch_examples in train_set.batch_iter(batch_size=main_args.batch_size, shuffle=True):
            train_iter += 1
            if adv_training:
                ret_loss = model.get_loss(batch_examples, train_iter, is_dis=True)
                if adv_syn:
                    dis_syn_loss = ret_loss['dis syn']
                    optimizer.zero_grad()
                    dis_syn_loss.backward()
                    if main_args.clip_grad > 0.:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), main_args.clip_grad)
                    optimizer.step()
                if adv_sem:
                    ret_loss = model.get_loss(batch_examples, train_iter, is_dis=True)
                    dis_sem_loss = ret_loss['dis sem']
                    optimizer.zero_grad()
                    dis_sem_loss.backward()
                    if main_args.clip_grad > 0.:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), main_args.clip_grad)
                    optimizer.step()

            ret_loss = model.get_loss(batch_examples, train_iter)
            loss = ret_loss['Loss']
            optimizer.zero_grad()
            loss.backward()
            if main_args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(), main_args.clip_grad)

            optimizer.step()
            # tracker = update_track(loss, train_avg_kl, train_avg_nll, tracker)
            train_track = update_tracker(ret_loss, train_track)
            if train_iter % main_args.log_every == 0:
                train_avg_nll = ret_loss['NLL Loss']
                train_avg_kl = ret_loss['KL Loss']
                _kl_weight = ret_loss['KL Weight']
                for key, val in ret_loss.items():
                    writer.add_scalar(
                        'Train-Iter/VAE/{}'.format(key),
                        val.item() if isinstance(val, t_type) else val,
                        train_iter
                    )

                print("\rTrain-Iter %04d, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f, WD-Drop %6.3f"
                      % (train_iter, loss.item(), train_avg_nll, train_avg_kl, _kl_weight, model.step_unk_rate),
                      end=' ')
                writer.add_scalar(
                    tag='Optimize/lr',
                    scalar_value=optimizer.param_groups[0]['lr'],
                    global_step=train_iter,
                )

            if train_iter % main_args.dev_every == 0 and train_iter > model_args.warm_up:
                # dev_track, eval_results = _test_vae(model, dev_set, main_args, train_iter)
                dev_track, eval_results = evaluator.evaluate_reconstruction(examples=dev_set.examples,
                                                                            eval_desc="dev{}".format(train_iter),
                                                                            eval_step=train_iter, write_down=False)
                _weight = model.get_kl_weight(step=train_iter)
                _kl_item = torch.mean(dev_track['KL Item'])
                # writer.add_scalar("VAE/Valid-Iter/KL Item", _kl_item, train_iter)
                for key, val in dev_track.items():
                    writer.add_scalar(
                        'Valid-Iter/VAE/{}'.format(key),
                        torch.mean(val) if isinstance(val, t_type) else val,
                        train_iter
                    )
                if continue_anneal and model.step_kl_weight is None:
                    if _kl_item > max_kl_item:
                        max_kl_item = _kl_item
                        max_kl_weight = _weight
                    else:
                        if (max_kl_item - _kl_item) > model_args.stop_clip_kl:
                            model.step_kl_weight = max_kl_weight
                            writer.add_text(tag='peak_anneal',
                                            text_string="fixed the kl weight:{} with kl peak:{} at step:{}".format(
                                                max_kl_weight,
                                                max_kl_item,
                                                train_iter
                                            ), global_step=train_iter)
                            continue_anneal = False
                dev_elbo = torch.mean(dev_track['Model Score'])
                writer.add_scalar("Evaluation/VAE/Dev Score", dev_elbo, train_iter)

                # evaluate bleu
                dev_bleu = eval_results['accuracy']
                print()
                print("Valid-Iter %04d, NLL_Loss:%9.4f, KL_Loss: %9.4f, Sum Score:%9.4f BLEU:%9.4f" % (
                    train_iter,
                    torch.mean(dev_track['NLL Loss']),
                    torch.mean(dev_track['KL Loss']),
                    dev_elbo,
                    eval_results['accuracy']), file=sys.stderr
                      )
                writer.add_scalar(
                    tag='Evaluation/VAE/Iter %s' % model.args.eval_mode,
                    scalar_value=dev_bleu,
                    global_step=train_iter
                )
                if model_args.dev_item == "ELBO" or model_args.dev_item.lower() == "para-elbo" or model_args.dev_item.lower() == "gen-elbo":
                    is_better = history_elbo == [] or dev_elbo < min(history_elbo)
                elif model_args.dev_item == "BLEU" or model_args.dev_item.lower() == "para-bleu" or model_args.dev_item.lower() == "gen-bleu":
                    is_better = history_bleu == [] or dev_bleu > max(history_bleu)

                history_elbo.append(dev_elbo)
                writer.add_scalar("Evaluation/VAE/Best ELBO Score", min(history_elbo), train_iter)
                history_bleu.append(dev_bleu)
                writer.add_scalar("Evaluation/VAE/Best BLEU Score", max(history_bleu), train_iter)

                if is_better:
                    writer.add_scalar(
                        tag='Evaluation/VAE/Best %s' % model.args.eval_mode,
                        scalar_value=dev_bleu,
                        global_step=train_iter
                    )
                    writer.add_scalar(
                        tag='Evaluation/VAE/Best NLL-LOSS',
                        scalar_value=torch.mean(dev_track['NLL Loss']),
                        global_step=train_iter
                    )
                    writer.add_scalar(
                        tag='Evaluation/VAE/Best KL-LOSS',
                        scalar_value=torch.mean(dev_track['KL Loss']),
                        global_step=train_iter
                    )
                    if train_iter * 2 > model_args.x0:
                        memory_temp_count = 3

                if model_args.dev_item.lower().startswith("gen") and memory_temp_count > 0:
                    evaluator.evaluate_generation(
                        sample_size=len(dev_set.examples),
                        eval_desc="gen_iter{}".format(train_iter),
                    )
                    memory_temp_count -= 1

                if model_args.dev_item.lower().startswith("para") and memory_temp_count > 0:

                    para_score = evaluator.evaluate_para(
                        eval_dir=para_eval_dir,
                        eval_list=para_eval_list,
                        eval_desc="para_iter{}".format(train_iter)
                    )
                    if memory_temp_count == 3:
                        writer.add_scalar(
                            tag='Evaluation/VAE/Para Dev Ori-BLEU',
                            scalar_value=para_score[0][0],
                            global_step=train_iter
                        )
                        writer.add_scalar(
                            tag='Evaluation/VAE/Para Dev Tgt-BLEU',
                            scalar_value=para_score[0][1],
                            global_step=train_iter
                        )
                        if len(para_score) > 1:
                            writer.add_scalar(
                                tag='Evaluation/VAE/Para Test Ori-BLEU',
                                scalar_value=para_score[1][0],
                                global_step=train_iter
                            )
                            writer.add_scalar(
                                tag='Evaluation/VAE/Para Test Tgt-BLEU',
                                scalar_value=para_score[1][1],
                                global_step=train_iter
                            )
                    memory_temp_count -= 1

                model, optimizer, num_trial, patience = get_lr_schedule(
                    is_better=is_better,
                    model_file=model_file,
                    main_args=main_args,
                    patience=patience,
                    num_trial=num_trial,
                    model=model,
                    optimizer=optimizer,
                    reload_model=model_args.reload_model,
                )
                model.train()
        elbo = torch.mean(train_track['Model Score'])
        print()
        print("Train-Epoch %02d, Score %9.4f" % (epoch, elbo))
        for key, val in train_track.items():
            writer.add_scalar(
                'Train-Epoch/VAE/{}'.format(key),
                torch.mean(val) if isinstance(val, t_type) else val,
                epoch
            )
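
`model.get_kl_weight(step)`, `model.step_kl_weight` and the `x0` hyper-parameter point to the usual KL-annealing schedule for text VAEs. The schedule itself is not shown; below is an illustrative sigmoid-annealing sketch, where the rate `k` and the use of `x0` as the midpoint are assumptions rather than the repository's actual settings:

import math

def get_kl_weight(step, x0=10000, k=0.0025, step_kl_weight=None):
    """Sigmoid KL-annealing weight (illustrative sketch, not the model's real code)."""
    if step_kl_weight is not None:
        # peak-anneal mode: the weight has been frozen at the value that produced the KL peak
        return step_kl_weight
    return 1.0 / (1.0 + math.exp(-k * (step - x0)))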
Example #5
def training(main_args, model_args, model=None):
    if "task_type" in model_args and model_args.task_type is not None:
        main_args.task_type = model_args.task_type
    dir_ret = get_model_info(main_args=main_args, model_args=model_args)
    model, optimizer, vocab = build_model(main_args=main_args,
                                          model_args=model_args,
                                          model=model)
    train_set, dev_set = load_data(main_args)
    model_file = dir_ret['model_file']
    log_dir = dir_ret['log_dir']
    out_dir = dir_ret['out_dir']

    writer = SummaryWriter(log_dir)
    GlobalOps.writer_ops = writer
    writer.add_text("main_args", str(main_args))
    writer.add_text("model_args", str(model_args))

    print("...... Start Training ......")

    train_iter = main_args.start_iter
    train_nums, train_loss = 0., 0.
    epoch, num_trial, patience = 0, 0, 0

    history_scores = []
    task_type = main_args.task_type
    eval_select = eval_key_dict[task_type.lower()]
    sort_key = sort_key_dict[
        model_args.sort_key] if "sort_key" in model_args else sort_key_dict[
            model_args.enc_type]
    evaluator = get_evaluator(eval_choice=eval_select,
                              model=model,
                              eval_set=dev_set.examples,
                              eval_lists=main_args.eval_lists,
                              sort_key=sort_key,
                              eval_tgt="tgt",
                              batch_size=model_args.eval_bs,
                              out_dir=out_dir,
                              write_down=True)
    print("Dev ITEM: ", evaluator.score_item)
    adv_training = "adv_train" in model_args and model_args.adv_train
    adv_syn, adv_sem = False, False

    def hyper_init():
        adv_syn_ = (model_args.adv_syn +
                    model_args.infer_weight * model_args.inf_sem) > 0.
        adv_sem_ = (model_args.adv_sem +
                    model_args.infer_weight * model_args.inf_syn) > 0.
        return adv_syn_, adv_sem_

    if adv_training:
        adv_syn, adv_sem = hyper_init()

    def normal_training():
        # zero the gradients here; backward() and optimizer.step() run in the outer training loop
        optimizer.zero_grad()
        batch_ret_ = model.get_loss(examples=batch_examples,
                                    return_enc_state=False,
                                    train_iter=train_iter)
        return batch_ret_

    def universal_training():
        if adv_training:
            ret_loss = model.get_loss(examples=batch_examples,
                                      train_iter=train_iter,
                                      is_dis=True)
            if adv_syn:
                dis_syn_loss = ret_loss['dis syn']
                optimizer.zero_grad()
                dis_syn_loss.backward()
                if main_args.clip_grad > 0.:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   main_args.clip_grad)
                optimizer.step()
            if adv_sem:
                ret_loss = model.get_loss(examples=batch_examples,
                                          train_iter=train_iter,
                                          is_dis=True)
                dis_sem_loss = ret_loss['dis sem']
                optimizer.zero_grad()
                dis_sem_loss.backward()
                if main_args.clip_grad > 0.:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   main_args.clip_grad)
                optimizer.step()
        return normal_training()

    while True:
        epoch += 1
        epoch_begin = time.time()
        train_log_dict = {}

        for batch_examples in train_set.batch_iter(
                batch_size=main_args.batch_size, shuffle=True):
            train_iter += 1
            train_nums += len(batch_examples)
            # batch_ret = model.get_loss(examples=batch_examples, return_enc_state=False, train_iter=train_iter)
            batch_ret = universal_training()
            batch_loss = batch_ret['Loss']
            train_loss += batch_loss.sum().item()
            torch.mean(batch_loss).backward()

            if main_args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               main_args.clip_grad)
            optimizer.step()

            train_log_dict = update_tracker(batch_ret, train_log_dict)

            if train_iter % main_args.log_every == 0:
                print('\r[Iter %d] Train loss=%.5f' %
                      (train_iter, train_loss / train_nums),
                      file=sys.stdout,
                      end=" ")
                for key, val in train_log_dict.items():
                    if isinstance(val, torch.Tensor):
                        writer.add_scalar(tag="{}/Train/{}".format(
                            task_type, key),
                                          scalar_value=torch.mean(val).item(),
                                          global_step=train_iter)
                writer.add_scalar(tag="Optimize/lr",
                                  scalar_value=optimizer.param_groups[0]['lr'],
                                  global_step=train_iter)
                writer.add_scalar(
                    tag='Optimize/trial',
                    scalar_value=num_trial,
                    global_step=train_iter,
                )
                writer.add_scalar(
                    tag='Optimize/patience',
                    scalar_value=patience,
                    global_step=train_iter,
                )

            if train_iter % main_args.dev_every == 0 and train_iter > model_args.warm_up:
                eval_result_dict = evaluator()
                dev_acc = eval_result_dict[evaluator.score_item]
                if isinstance(dev_acc, torch.Tensor):
                    dev_acc = dev_acc.sum().item()
                print('\r[Iter %d] %s %s=%.3f took %d s' %
                      (train_iter, task_type, evaluator.score_item, dev_acc,
                       eval_result_dict['EVAL TIME']),
                      file=sys.stdout)
                is_better = (history_scores
                             == []) or dev_acc > max(history_scores)
                history_scores.append(dev_acc)

                writer.add_scalar(tag='%s/Valid/Best %s' %
                                  (task_type, evaluator.score_item),
                                  scalar_value=max(history_scores),
                                  global_step=train_iter)
                for key, val in eval_result_dict.items():
                    writer.add_scalar(tag="{}/Valid/{}".format(task_type, key),
                                      scalar_value=val.sum().item() if
                                      isinstance(val, torch.Tensor) else val,
                                      global_step=train_iter)

                model, optimizer, num_trial, patience = get_lr_schedule(
                    is_better=is_better,
                    model_file=model_file,
                    main_args=main_args,
                    patience=patience,
                    num_trial=num_trial,
                    model=model,
                    optimizer=optimizer,
                    reload_model=False)
                model.train()

        epoch_time = time.time() - epoch_begin
        print('\r[Epoch %d] epoch elapsed %ds' % (epoch, epoch_time),
              file=sys.stdout)
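
Both training loops hand early-stopping bookkeeping to `get_lr_schedule`, which is not reproduced here. Below is a sketch of the patience/trial pattern it appears to implement; the fields `main_args.patience`, `main_args.max_num_trial` and `main_args.lr_decay` are assumptions, and the real helper may save and restore the model differently:

import torch

def get_lr_schedule(is_better, model_file, main_args, patience, num_trial,
                    model, optimizer, reload_model=True):
    """Patience-based LR decay with optional best-checkpoint reloading (illustrative sketch)."""
    if is_better:
        patience = 0
        torch.save(model.state_dict(), model_file)  # keep the best model seen so far
    else:
        patience += 1
        if patience >= main_args.patience:
            num_trial += 1
            if num_trial >= main_args.max_num_trial:
                raise SystemExit("early stop: reached the maximum number of trials")
            if reload_model:
                # restart from the best checkpoint before lowering the learning rate
                model.load_state_dict(torch.load(model_file))
            for group in optimizer.param_groups:
                group['lr'] *= main_args.lr_decay
            patience = 0
    return model, optimizer, num_trial, patience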