Example #1
    def train_and_validate(self, train_data, valid_data):
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(train_data,
                                    batch_size=self.batch_size,
                                    train=True,
                                    shuffle=self.shuffle)
        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)
            self.model.train()

            start = time.time()
            total_valid_duration = 0
            processed_tokens = self.total_tokens
            count = 0

            for batch in iter(train_iter):
                # reactivate training
                self.model.train()
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672
                update = count == 0
                # print(count, update, self.steps)
                batch_loss = self._train_batch(batch, update=update)
                count = self.batch_multiplier if update else count
                count -= 1

                # log learning progress
                if self.model.training and self.steps % self.logging_freq == 0 \
                        and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - processed_tokens
                    self.logger.info(
                        "Epoch %d Step: %d Loss: %f Tokens per Sec: %f",
                        epoch_no + 1, self.steps, batch_loss,
                        elapsed_tokens / elapsed)
                    start = time.time()
                    total_valid_duration = 0
                    # reset so tokens/sec covers only the interval since last log
                    processed_tokens = self.total_tokens

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    valid_start_time = time.time()

                    valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores = \
                        validate_on_data(
                            batch_size=self.batch_size, data=valid_data,
                            eval_metric=self.eval_metric,
                            level=self.level, model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            criterion=self.criterion)

                    if self.ckpt_metric == "loss":
                        ckpt_score = valid_loss
                    elif self.ckpt_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.ckpt_metric)
                        new_best = True
                        self.save_checkpoint()

                    # pass validation score or loss or ppl to scheduler
                    if self.schedule_metric == "loss":
                        # schedule based on loss
                        schedule_score = valid_loss
                    elif self.schedule_metric in ["ppl", "perplexity"]:
                        # schedule based on perplexity
                        schedule_score = valid_ppl
                    else:
                        # schedule based on evaluation score
                        schedule_score = valid_score
                    if self.scheduler is not None:
                        self.scheduler.step(schedule_score)

                    # append to validation report
                    self._add_report(valid_score=valid_score,
                                     valid_loss=valid_loss,
                                     valid_ppl=valid_ppl,
                                     eval_metric=self.eval_metric,
                                     new_best=new_best)

                    # always print the first `print_valid_sents` sentences
                    for p in range(self.print_valid_sents):
                        self.logger.debug("Example #%d", p)
                        self.logger.debug("\tRaw source: %s",
                                          valid_sources_raw[p])
                        self.logger.debug("\tSource: %s", valid_sources[p])
                        self.logger.debug("\tReference: %s",
                                          valid_references[p])
                        self.logger.debug("\tRaw hypothesis: %s",
                                          valid_hypotheses_raw[p])
                        self.logger.debug("\tHypothesis: %s",
                                          valid_hypotheses[p])
                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        'Validation result at epoch %d, step %d: %s: %f, '
                        'loss: %f, ppl: %f, duration: %.4fs', epoch_no + 1,
                        self.steps, self.eval_metric, valid_score, valid_loss,
                        valid_ppl, valid_duration)

                    # store validation set outputs
                    self.store_outputs(valid_hypotheses)

                    # store attention plots for first three sentences of
                    # valid data and one randomly chosen example
                    store_attention_plots(attentions=valid_attention_scores,
                                          targets=valid_hypotheses_raw,
                                          sources=[s for s in valid_data.src],
                                          idx=[
                                              0, 1, 2,
                                              np.random.randint(
                                                  0, len(valid_hypotheses))
                                          ],
                                          output_prefix="{}/att.{}".format(
                                              self.model_dir, self.steps))

                if self.stop:
                    break
            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break
        else:
            self.logger.info('Training ended after %d epochs.', epoch_no + 1)
        self.logger.info('Best validation result at step %d: %f %s.',
                         self.best_ckpt_iteration, self.best_ckpt_score,
                         self.ckpt_metric)
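
The `batch_multiplier` counter in Example #1 implements gradient accumulation: the loss of several consecutive mini-batches is backpropagated before a single optimizer update, which emulates a larger effective batch size without extra memory. A minimal, self-contained sketch of the same pattern (the model, optimizer, data loader and loss function here are generic stand-ins, not the trainer's actual attributes):

import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader


def train_with_accumulation(model: nn.Module, optimizer: Optimizer,
                            data_loader: DataLoader, loss_fn,
                            batch_multiplier: int = 4) -> None:
    """Accumulate gradients over `batch_multiplier` mini-batches,
    then perform a single optimizer update (sketch)."""
    model.train()
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(data_loader, 1):
        loss = loss_fn(model(inputs), targets)
        # scale so the accumulated gradient matches one large batch
        (loss / batch_multiplier).backward()
        if i % batch_multiplier == 0:
            optimizer.step()
            optimizer.zero_grad()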
Example #2
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset) \
            -> None:
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(train_data,
                                    batch_size=self.batch_size,
                                    batch_type=self.batch_type,
                                    train=True,
                                    shuffle=self.shuffle)

        # For the last batch of an epoch, batch_multiplier needs to be
        # adjusted to the number of leftover training examples
        leftover_batch_size = len(train_data) % (self.batch_multiplier *
                                                 self.batch_size)

        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no)

            self.model.train()

            # Reset statistics for each epoch.
            start = time.time()
            total_valid_duration = 0
            start_tokens = self.total_tokens
            self.current_batch_multiplier = self.batch_multiplier
            self.optimizer.zero_grad()
            count = self.current_batch_multiplier - 1
            epoch_loss = 0

            for i, batch in enumerate(iter(train_iter)):
                # reactivate training
                self.model.train()
                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672

                # Set current_batch_multiplier to the number of leftover
                # examples for the last batch of the epoch.
                # This only works if batch_type == "sentence".
                if self.batch_type == "sentence":
                    if self.batch_multiplier > 1 and i == len(train_iter) - \
                            math.ceil(leftover_batch_size / self.batch_size):
                        self.current_batch_multiplier = math.ceil(
                            leftover_batch_size / self.batch_size)
                        count = self.current_batch_multiplier - 1

                update = count == 0
                # print(count, update, self.steps)
                batch_loss = self._train_batch(batch,
                                               update=update,
                                               count=count)

                # Only log the final batch_loss once the full
                # (accumulated) batch is complete
                if update:
                    self.tb_writer.add_scalar("train/train_batch_loss",
                                              batch_loss, self.steps)

                count = self.batch_multiplier if update else count
                count -= 1

                # Only add complete batch_loss of full mini-batch to epoch_loss
                if update:
                    epoch_loss += batch_loss.detach().cpu().numpy()

                if self.scheduler is not None and \
                        self.scheduler_step_at == "step" and update:
                    self.scheduler.step()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - start_tokens
                    self.logger.info(
                        "Epoch %3d Step: %8d Batch Loss: %12.6f "
                        "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                        self.steps, batch_loss, elapsed_tokens / elapsed,
                        self.optimizer.param_groups[0]["lr"])
                    start = time.time()
                    total_valid_duration = 0
                    start_tokens = self.total_tokens

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    valid_start_time = time.time()

                    valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores = \
                        validate_on_data(
                            logger=self.logger,
                            batch_size=self.eval_batch_size,
                            data=valid_data,
                            eval_metric=self.eval_metric,
                            level=self.level, model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            loss_function=self.loss,
                            beam_size=1,  # greedy validations
                            batch_type=self.eval_batch_type,
                            postprocess=True # always remove BPE for validation
                        )

                    self.tb_writer.add_scalar("valid/valid_loss", valid_loss,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_score", valid_score,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                              self.steps)

                    if self.early_stopping_metric == "loss":
                        ckpt_score = valid_loss
                    elif self.early_stopping_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    if self.scheduler is not None \
                            and self.scheduler_step_at == "validation":
                        self.scheduler.step(ckpt_score)

                    # append to validation report
                    self._add_report(valid_score=valid_score,
                                     valid_loss=valid_loss,
                                     valid_ppl=valid_ppl,
                                     eval_metric=self.eval_metric,
                                     new_best=new_best)

                    self._log_examples(
                        sources_raw=[v for v in valid_sources_raw],
                        sources=valid_sources,
                        hypotheses_raw=valid_hypotheses_raw,
                        hypotheses=valid_hypotheses,
                        references=valid_references)

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        'Validation result (greedy) at epoch %3d, '
                        'step %8d: %s: %6.2f, loss: %8.4f, ppl: %8.4f, '
                        'duration: %.4fs', epoch_no + 1, self.steps,
                        self.eval_metric, valid_score, valid_loss, valid_ppl,
                        valid_duration)

                    # store validation set outputs
                    self._store_outputs(valid_hypotheses)

                    # store attention plots for selected valid sentences
                    if valid_attention_scores:
                        store_attention_plots(
                            attentions=valid_attention_scores,
                            targets=valid_hypotheses_raw,
                            sources=[s for s in valid_data.src],
                            indices=self.log_valid_sents,
                            output_prefix="{}/att.{}".format(
                                self.model_dir, self.steps),
                            tb_writer=self.tb_writer,
                            steps=self.steps)

                if self.stop:
                    break
            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            self.logger.info('Epoch %3d: total training loss %.2f',
                             epoch_no + 1, epoch_loss)
        else:
            self.logger.info('Training ended after %3d epochs.', epoch_no + 1)
        self.logger.info(
            'Best validation result (greedy) at step '
            '%8d: %6.2f %s.', self.best_ckpt_iteration, self.best_ckpt_score,
            self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer
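
Example #2 logs training and validation scalars through `self.tb_writer`. A minimal sketch of how such a writer can be created with PyTorch's built-in `torch.utils.tensorboard` (the log directory and scalar values are illustrative assumptions):

from torch.utils.tensorboard import SummaryWriter

# hypothetical setup; a trainer would typically create this once in __init__
tb_writer = SummaryWriter(log_dir="model_dir/tensorboard")

step = 100
tb_writer.add_scalar("train/train_batch_loss", 1.234, step)
tb_writer.add_scalar("valid/valid_ppl", 12.5, step)
tb_writer.close()  # flush pending events to disk when training ends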
Example #3
    def _validate(self, valid_data, epoch_no):
        valid_start_time = time.time()

        valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        valid_hypotheses_raw, valid_attention_scores = \
            validate_on_data(
                batch_size=self.eval_batch_size,
                batch_class=self.batch_class,
                data=valid_data,
                eval_metric=self.eval_metric,
                level=self.level, model=self.model,
                use_cuda=self.use_cuda,
                max_output_length=self.max_output_length,
                compute_loss=True,
                beam_size=1,                # greedy validations
                batch_type=self.eval_batch_type,
                postprocess=True,           # always remove BPE for validation
                bpe_type=self.bpe_type,     # "subword-nmt" or "sentencepiece"
                sacrebleu=self.sacrebleu,   # sacrebleu options
                n_gpu=self.n_gpu
            )

        self.tb_writer.add_scalar("valid/valid_loss", valid_loss,
                                  self.stats.steps)
        self.tb_writer.add_scalar("valid/valid_score", valid_score,
                                  self.stats.steps)
        self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                  self.stats.steps)

        if self.early_stopping_metric == "loss":
            ckpt_score = valid_loss
        elif self.early_stopping_metric in ["ppl", "perplexity"]:
            ckpt_score = valid_ppl
        else:
            ckpt_score = valid_score

        if self.scheduler is not None \
                and self.scheduler_step_at == "validation":
            self.scheduler.step(ckpt_score)

        new_best = False
        if self.stats.is_best(ckpt_score):
            self.stats.best_ckpt_score = ckpt_score
            self.stats.best_ckpt_iter = self.stats.steps
            logger.info('Hooray! New best validation result [%s]!',
                        self.early_stopping_metric)
            if self.ckpt_queue.maxlen > 0:
                logger.info("Saving new checkpoint.")
                new_best = True
                self._save_checkpoint(new_best)
        elif self.save_latest_checkpoint:
            self._save_checkpoint(new_best)

        # append to validation report
        self._add_report(valid_score=valid_score,
                         valid_loss=valid_loss,
                         valid_ppl=valid_ppl,
                         eval_metric=self.eval_metric,
                         new_best=new_best)

        self._log_examples(sources_raw=[v for v in valid_sources_raw],
                           sources=valid_sources,
                           hypotheses_raw=valid_hypotheses_raw,
                           hypotheses=valid_hypotheses,
                           references=valid_references)

        valid_duration = time.time() - valid_start_time
        logger.info(
            'Validation result (greedy) at epoch %3d, '
            'step %8d: %s: %6.2f, loss: %8.4f, ppl: %8.4f, '
            'duration: %.4fs', epoch_no + 1, self.stats.steps,
            self.eval_metric, valid_score, valid_loss, valid_ppl,
            valid_duration)

        # store validation set outputs
        self._store_outputs(valid_hypotheses)

        # store attention plots for selected valid sentences
        if valid_attention_scores:
            store_attention_plots(attentions=valid_attention_scores,
                                  targets=valid_hypotheses_raw,
                                  sources=[s for s in valid_data.src],
                                  indices=self.log_valid_sents,
                                  output_prefix="{}/att.{}".format(
                                      self.model_dir, self.stats.steps),
                                  tb_writer=self.tb_writer,
                                  steps=self.stats.steps)

        return valid_duration
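
Whether a new `ckpt_score` counts as the best checkpoint depends on the direction of the early-stopping metric: loss and perplexity improve when they decrease, while evaluation scores such as BLEU improve when they increase. A small illustrative sketch of an `is_best` check along these lines (the class and attribute names are assumptions, not necessarily those of `self.stats` above):

class TrainStatistics:
    """Illustrative checkpoint bookkeeping for early stopping."""

    def __init__(self, early_stopping_metric: str):
        # loss and perplexity are minimized; scores like BLEU are maximized
        self.minimize_metric = early_stopping_metric in (
            "loss", "ppl", "perplexity")
        self.best_ckpt_score = float("inf") if self.minimize_metric \
            else float("-inf")
        self.best_ckpt_iter = 0

    def is_best(self, score: float) -> bool:
        if self.minimize_metric:
            return score < self.best_ckpt_score
        return score > self.best_ckpt_score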
Example #4
def test(cfg_file,
         ckpt: str,
         batch_class: Batch = Batch,
         output_path: str = None,
         save_attention: bool = False,
         datasets: dict = None) -> None:
    """
    Main test function. Handles loading a model from checkpoint, generating
    translations and storing them and attention plots.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param batch_class: class type of batch
    :param output_path: path to output
    :param datasets: datasets to predict
    :param save_attention: whether to save the computed attention weights
    """

    cfg = load_config(cfg_file)
    model_dir = cfg["training"]["model_dir"]

    if len(logger.handlers) == 0:
        _ = make_logger(model_dir, mode="test")  # version string returned

    # when checkpoint is not specified, take latest (best) from model dir
    step = "best"  # default label in case a checkpoint path was given
    if ckpt is None:
        ckpt = get_latest_checkpoint(model_dir)
        try:
            step = ckpt.split(model_dir + "/")[1].split(".ckpt")[0]
        except IndexError:
            step = "best"

    # load the data
    if datasets is None:
        _, dev_data, test_data, src_vocab, trg_vocab = load_data(
            data_cfg=cfg["data"], datasets=["dev", "test"])
        data_to_predict = {"dev": dev_data, "test": test_data}
    else:  # avoid loading the data again
        data_to_predict = {"dev": datasets["dev"], "test": datasets["test"]}
        src_vocab = datasets["src_vocab"]
        trg_vocab = datasets["trg_vocab"]

    # parse test args
    batch_size, batch_type, use_cuda, device, n_gpu, level, eval_metric, \
        max_output_length, beam_size, beam_alpha, postprocess, \
        bpe_type, sacrebleu, decoding_description, tokenizer_info \
        = parse_test_args(cfg, mode="test")

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
    model.load_state_dict(model_checkpoint["model_state"])

    if use_cuda:
        model.to(device)

    # multi-gpu eval
    if n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = _DataParallel(model)

    for data_set_name, data_set in data_to_predict.items():
        if data_set is None:
            continue

        dataset_file = cfg["data"][data_set_name] + "." + cfg["data"]["trg"]
        logger.info("Decoding on %s set (%s)...", data_set_name, dataset_file)

        #pylint: disable=unused-variable
        score, loss, ppl, sources, sources_raw, references, hypotheses, \
        hypotheses_raw, attention_scores = validate_on_data(
            model, data=data_set, batch_size=batch_size,
            batch_class=batch_class, batch_type=batch_type, level=level,
            max_output_length=max_output_length, eval_metric=eval_metric,
            use_cuda=use_cuda, compute_loss=False, beam_size=beam_size,
            beam_alpha=beam_alpha, postprocess=postprocess,
            bpe_type=bpe_type, sacrebleu=sacrebleu, n_gpu=n_gpu)
        #pylint: enable=unused-variable

        if "trg" in data_set.fields:
            logger.info("%4s %s%s: %6.2f [%s]", data_set_name, eval_metric,
                        tokenizer_info, score, decoding_description)
        else:
            logger.info("No references given for %s -> no evaluation.",
                        data_set_name)

        if save_attention:
            if attention_scores:
                attention_name = "{}.{}.att".format(data_set_name, step)
                attention_path = os.path.join(model_dir, attention_name)
                logger.info(
                    "Saving attention plots. This might take a while..")
                store_attention_plots(attentions=attention_scores,
                                      targets=hypotheses_raw,
                                      sources=data_set.src,
                                      indices=range(len(hypotheses)),
                                      output_prefix=attention_path)
                logger.info("Attention plots saved to: %s", attention_path)
            else:
                logger.warning("Attention scores could not be saved. "
                               "Note that attention scores are not available "
                               "when using beam search. "
                               "Set beam_size to 1 for greedy decoding.")

        if output_path is not None:
            output_path_set = "{}.{}".format(output_path, data_set_name)
            with open(output_path_set, mode="w", encoding="utf-8") as out_file:
                for hyp in hypotheses:
                    out_file.write(hyp + "\n")
            logger.info("Translations saved to: %s", output_path_set)
Example #5
def test(cfg_file,
         ckpt,
         output_path: str = None,
         save_attention: bool = False,
         logger: logging.Logger = None,
         data_to_test: str = None) -> None:
    """
    Main test function. Handles loading a model from checkpoint, generating
    translations and storing them and attention plots.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param output_path: path to output
    :param save_attention: whether to save the computed attention weights
    :param logger: log output to this logger (creates new logger if not set)
    """

    if logger is None:
        logger = logging.getLogger(__name__)
        FORMAT = '%(asctime)-15s - %(message)s'
        logging.basicConfig(format=FORMAT)
        logger.setLevel(level=logging.DEBUG)

    cfg = load_config(cfg_file)
    train_cfg = cfg["training"]
    data_cfg = cfg["data"]
    test_cfg = cfg.get("testing", {})  # the "testing" section may be absent

    if "test" not in data_cfg.keys():
        raise ValueError("Test data must be specified in config.")

    # when checkpoint is not specified, take latest (best) from model dir
    model_dir = train_cfg["model_dir"]
    step = "best"  # default label in case a checkpoint path was given
    if ckpt is None:
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError("No checkpoint at {}.".format(model_dir))
        try:
            step = ckpt.split(model_dir + "/")[1].split(".ckpt")[0]
        except IndexError:
            step = "best"

    batch_size = train_cfg.get("eval_batch_size", train_cfg["batch_size"])
    batch_type = train_cfg.get("batch_type", "sentence")
    use_cuda = train_cfg.get("use_cuda", False)
    assert "level" in data_cfg or "trg_level" in data_cfg
    trg_level = data_cfg.get("level", data_cfg["trg_level"])

    eval_metric = train_cfg["eval_metric"]
    if isinstance(eval_metric, str):
        eval_metric = [eval_metric]
    max_output_length = test_cfg.get("max_output_length",
                                     train_cfg.get("max_output_length", None))

    # load the data
    data = load_data(data_cfg)
    dev_data = data["dev_data"]
    test_data = data["test_data"]
    vocabs = data["vocabs"]

    data_to_predict = {"dev": dev_data, "test": test_data}
    if data_to_test is not None:
        assert data_to_test in data_to_predict
        data_to_predict = {data_to_test: data_to_predict[data_to_test]}

    # load model state from disk
    if isinstance(ckpt, str):
        ckpt = [ckpt]
    models = []
    for c in ckpt:
        model_checkpoint = load_checkpoint(c, use_cuda=use_cuda)

        # build model and load parameters into it
        m = build_model(cfg["model"], vocabs=vocabs)
        m.load_state_dict(model_checkpoint["model_state"])
        models.append(m)
    model = models[0] if len(models) == 1 else EnsembleModel(*models)

    if use_cuda:
        model.cuda()  # should this exist?

    # whether to use beam search for decoding, 0: greedy decoding
    beam_sizes = beam_alpha = 0
    if "testing" in cfg.keys():
        beam_sizes = test_cfg.get("beam_size", 0)
        beam_alpha = test_cfg.get("alpha", 0)
    beam_sizes = [beam_sizes] if isinstance(beam_sizes, int) else beam_sizes
    assert beam_alpha >= 0, "Use alpha >= 0"

    method = test_cfg.get("method", None)
    max_hyps = test_cfg.get("max_hyps", 1)  # only for the enumerate thing

    validate_by_label = test_cfg.get("validate_by_label",
                                     train_cfg.get("validate_by_label", False))
    forced_sparsity = test_cfg.get("forced_sparsity",
                                   train_cfg.get("forced_sparsity", False))

    for beam_size in beam_sizes:
        for data_set_name, data_set in data_to_predict.items():
            valid_results = validate_on_data(
                model,
                data=data_set,
                batch_size=batch_size,
                batch_type=batch_type,
                trg_level=trg_level,
                max_output_length=max_output_length,
                eval_metrics=eval_metric,
                use_cuda=use_cuda,
                loss_function=None,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                save_attention=save_attention,
                validate_by_label=validate_by_label,
                forced_sparsity=forced_sparsity,
                method=method,
                max_hyps=max_hyps,
                break_at_p=test_cfg.get("break_at_p", 1.0),
                break_at_argmax=test_cfg.get("break_at_argmax", False),
                short_depth=test_cfg.get("short_depth", 0))
            scores = valid_results[0]
            hypotheses, hypotheses_raw = valid_results[2:4]
            scores_by_label = valid_results[5]

            if "trg" in data_set.fields:
                log_scores(logger, data_set_name, scores, scores_by_label,
                           beam_size, beam_alpha)
            else:
                logger.info("No references given for %s -> no evaluation.",
                            data_set_name)

            attention_scores = valid_results[4]
            if save_attention and not attention_scores:
                logger.warning("Attention scores could not be saved. "
                               "Note that attention scores are not "
                               "available when using beam search. "
                               "Set beam_size to 0 for greedy decoding.")
            if save_attention and attention_scores:
                # currently this will break for transformers
                logger.info("Saving attention plots. This might be slow.")
                store_attention_plots(attentions=attention_scores,
                                      targets=hypotheses_raw,
                                      sources=[s for s in data_set.src],
                                      indices=range(len(hypotheses)),
                                      model_dir=model_dir,
                                      steps=step,
                                      data_set_name=data_set_name)
                logger.info("Attention plots saved to: %s", model_dir)

            if output_path is not None:
                output_path_set = "{}.{}".format(output_path, data_set_name)
                with open(output_path_set, mode="w", encoding="utf-8") as outf:
                    for hyp in hypotheses:
                        outf.write(hyp + "\n")
                logger.info("Translations saved to: %s", output_path_set)
Example #6
def test(cfg_file,
         ckpt: str,
         output_path: str = None,
         save_attention: bool = False,
         logger: Logger = None) -> None:
    """
    Main test function. Handles loading a model from checkpoint, generating
    translations and storing them and attention plots.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param output_path: path to output
    :param save_attention: whether to save the computed attention weights
    :param logger: log output to this logger (creates new logger if not set)
    """

    if logger is None:
        logger = make_logger()

    cfg = load_config(cfg_file)

    # when checkpoint is not specified, take latest (best) from model dir
    step = "best"
    model_dir = cfg["training"]["model_dir"]
    if ckpt is None:
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError(
                "No checkpoint found in directory {}.".format(model_dir))
        try:
            step = ckpt.split(model_dir + "/")[1].split(".ckpt")[0]
        except IndexError:
            step = "best"

    architecture = cfg["model"].get("architecture", "encoder-decoder")
    batch_size = cfg["training"].get("eval_batch_size",
                                     cfg["training"]["batch_size"])
    batch_type = cfg["training"].get(
        "eval_batch_type", cfg["training"].get("batch_type", "sentence"))
    use_cuda = cfg["training"].get("use_cuda", False)
    level = cfg["data"]["level"]
    eval_metric = cfg["training"]["eval_metric"]
    max_output_length = cfg["training"].get("max_output_length", None)

    # original encoder-decoder testing
    if architecture == "encoder-decoder":
        if "test" not in cfg["data"].keys():
            raise ValueError("Test data must be specified in config.")
        # load the data
        _, dev_data, test_data, src_vocab, trg_vocab = load_data(
            data_cfg=cfg["data"])
        data_to_predict = {"dev": dev_data, "test": test_data}

        # load model state from disk
        model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

        # build model and load parameters into it
        model = build_model(cfg["model"],
                            src_vocab=src_vocab,
                            trg_vocab=trg_vocab)
        model.load_state_dict(model_checkpoint["model_state"])

        if use_cuda:
            model.cuda()

        # whether to use beam search for decoding, 0: greedy decoding
        if "testing" in cfg.keys():
            beam_size = cfg["testing"].get("beam_size", 1)
            beam_alpha = cfg["testing"].get("alpha", -1)
            postprocess = cfg["testing"].get("postprocess", True)
        else:
            beam_size = 1
            beam_alpha = -1
            postprocess = True

        for data_set_name, data_set in data_to_predict.items():

            # pylint: disable=unused-variable
            score, loss, ppl, sources, sources_raw, references, hypotheses, \
            hypotheses_raw, attention_scores = validate_on_data(
                model, data=data_set, batch_size=batch_size,
                batch_type=batch_type, level=level,
                max_output_length=max_output_length, eval_metric=eval_metric,
                use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
                beam_alpha=beam_alpha, logger=logger, postprocess=postprocess)
            # pylint: enable=unused-variable

            if "trg" in data_set.fields:
                decoding_description = "Greedy decoding" if beam_size < 2 else \
                    "Beam search decoding with beam size = {} and alpha = {}". \
                        format(beam_size, beam_alpha)
                logger.info("%4s %s: %6.2f [%s]", data_set_name, eval_metric,
                            score, decoding_description)
            else:
                logger.info("No references given for %s -> no evaluation.",
                            data_set_name)

            if save_attention:
                if attention_scores:
                    attention_name = "{}.{}.att".format(data_set_name, step)
                    attention_path = os.path.join(model_dir, attention_name)
                    logger.info(
                        "Saving attention plots. This might take a while..")
                    store_attention_plots(attentions=attention_scores,
                                          targets=hypotheses_raw,
                                          sources=data_set.src,
                                          indices=range(len(hypotheses)),
                                          output_prefix=attention_path)
                    logger.info("Attention plots saved to: %s", attention_path)
                else:
                    logger.warning(
                        "Attention scores could not be saved. "
                        "Note that attention scores are not available "
                        "when using beam search. "
                        "Set beam_size to 1 for greedy decoding.")

            if output_path is not None:
                output_path_set = "{}.{}".format(output_path, data_set_name)
                with open(output_path_set, mode="w",
                          encoding="utf-8") as out_file:
                    for hyp in hypotheses:
                        out_file.write(hyp + "\n")
                logger.info("Translations saved to: %s", output_path_set)
    else:
        # unsupervised NMT testing
        if "src2trg_test" not in cfg["data"].keys(
        ) or "trg2src_test" not in cfg["data"].keys():
            raise ValueError("Test data must be specified in config.")
        # load the data
        _, _, _, _, dev_src2trg, dev_trg2src, test_src2trg, test_trg2src, \
            src_vocab, trg_vocab, _ = \
            load_unsupervised_data(data_cfg=cfg["data"])
        data_to_predict = {
            "src2trg": {
                "dev_src2trg": dev_src2trg,
                "test_src2trg": test_src2trg
            },
            "trg2src": {
                "dev_trg2src": dev_trg2src,
                "test_trg2src": test_trg2src
            }
        }

        # load model state from disk
        model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

        # build model and load parameters into it
        model = build_model(cfg["model"],
                            src_vocab=src_vocab,
                            trg_vocab=trg_vocab)
        assert isinstance(model, UnsupervisedNMTModel)
        model.src2src_translator.load_state_dict(
            model_checkpoint["src2src_model_state"])
        model.trg2trg_translator.load_state_dict(
            model_checkpoint["trg2trg_model_state"])
        model.src2trg_translator.load_state_dict(
            model_checkpoint["src2trg_model_state"])
        model.trg2src_translator.load_state_dict(
            model_checkpoint["trg2src_model_state"])

        if use_cuda:
            # move all four translators to the GPU
            model.src2src_translator.cuda()
            model.trg2trg_translator.cuda()
            model.src2trg_translator.cuda()
            model.trg2src_translator.cuda()

        # whether to use beam search for decoding, 0: greedy decoding
        if "testing" in cfg.keys():
            beam_size = cfg["testing"].get("beam_size", 1)
            beam_alpha = cfg["testing"].get("alpha", -1)
            postprocess = cfg["testing"].get("postprocess", True)
        else:
            beam_size = 1
            beam_alpha = -1
            postprocess = True

        for translation_direction, dataset_dict in data_to_predict.items():
            # choose correct translator
            if translation_direction == "src2trg":
                model_to_use = model.src2trg_translator
            else:
                model_to_use = model.trg2src_translator

            for dataset_name, dataset in dataset_dict.items():
                score, loss, ppl, sources, sources_raw, references, hypotheses, \
                hypotheses_raw, attention_scores = validate_on_data(
                    model_to_use, data=dataset, batch_size=batch_size,
                    batch_type=batch_type, level=level,
                    max_output_length=max_output_length, eval_metric=eval_metric,
                    use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
                    beam_alpha=beam_alpha, logger=logger, postprocess=postprocess)

                if "trg" in dataset.fields:
                    decoding_description = "Greedy decoding" if beam_size < 2 else \
                        "Beam search decoding with beam size = {} and alpha = {}". \
                            format(beam_size, beam_alpha)
                    logger.info("%4s %s: %6.2f [%s]", dataset_name,
                                eval_metric, score, decoding_description)
                else:
                    logger.info("No references given for %s -> no evaluation.",
                                dataset_name)

                if save_attention:
                    if attention_scores:
                        attention_name = "{}.{}.att".format(dataset_name, step)
                        attention_path = os.path.join(model_dir,
                                                      attention_name)
                        logger.info(
                            "Saving attention plots. This might take a while.."
                        )
                        store_attention_plots(attentions=attention_scores,
                                              targets=hypotheses_raw,
                                              sources=dataset.src,
                                              indices=list(
                                                  range(len(hypotheses))),
                                              output_prefix=attention_path)
                        logger.info("Attention plots saved to: %s",
                                    attention_path)
                    else:
                        logger.warning(
                            "Attention scores could not be saved. "
                            "Note that attention scores are not available "
                            "when using beam search. "
                            "Set beam_size to 1 for greedy decoding.")

                if output_path is not None:
                    output_path_set = "{}.{}".format(output_path, dataset_name)
                    with open(output_path_set, mode="w",
                              encoding="utf-8") as out_file:
                        for hyp in hypotheses:
                            out_file.write(hyp + "\n")
                    logger.info("Translations saved to: %s", output_path_set)
Example #7
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset) \
            -> None:
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(train_data,
                                    batch_size=self.batch_size,
                                    train=True,
                                    shuffle=self.shuffle)
        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no)

            self.model.train()

            start = time.time()
            total_valid_duration = 0
            processed_tokens = self.total_tokens
            count = 0
            epoch_loss = 0

            for batch in iter(train_iter):
                # reactivate training
                self.model.train()
                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672
                update = count == 0
                # print(count, update, self.steps)
                batch_loss = self._train_batch(batch, update=update)
                self.tb_writer.add_scalar("train/train_batch_loss", batch_loss,
                                          self.steps)
                count = self.batch_multiplier if update else count
                count -= 1
                epoch_loss += batch_loss.detach().cpu().numpy()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - processed_tokens
                    self.logger.info(
                        "Epoch %d Step: %d Batch Loss: %f Tokens per Sec: %f",
                        epoch_no + 1, self.steps, batch_loss,
                        elapsed_tokens / elapsed)
                    start = time.time()
                    total_valid_duration = 0
                    # reset so tokens/sec covers only the interval since last log
                    processed_tokens = self.total_tokens

                # validate on the entire dev set
                if valid_data is not None and \
                    self.steps % self.validation_freq == 0 and update:

                    valid_start_time = time.time()

                    valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores, \
                        valid_logps = validate_on_data(
                            batch_size=self.batch_size, data=valid_data,
                            eval_metric=self.eval_metric,
                            level=self.level, model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            loss_function=self.loss,
                            return_logp=self.return_logp)

                    self.tb_writer.add_scalar("valid/valid_loss", valid_loss,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_score", valid_score,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                              self.steps)

                    if self.early_stopping_metric == "loss":
                        ckpt_score = valid_loss
                    elif self.early_stopping_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    if self.scheduler is not None \
                            and self.scheduler_step_at == "validation":
                        self.scheduler.step(ckpt_score)

                    # append to validation report
                    self._add_report(valid_score=valid_score,
                                     valid_loss=valid_loss,
                                     valid_ppl=valid_ppl,
                                     eval_metric=self.eval_metric,
                                     new_best=new_best)

                    self._log_examples(sources_raw=valid_sources_raw,
                                       sources=valid_sources,
                                       hypotheses_raw=valid_hypotheses_raw,
                                       hypotheses=valid_hypotheses,
                                       references=valid_references)

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        'Validation result at epoch %d, step %d: %s: %f, '
                        'loss: %f, ppl: %f, duration: %.4fs', epoch_no + 1,
                        self.steps, self.eval_metric, valid_score, valid_loss,
                        valid_ppl, valid_duration)

                    # store validation set outputs
                    self._store_outputs(
                        valid_hypotheses if self.post_process else
                        [" ".join(v) for v in valid_hypotheses_raw],
                        valid_logps if self.return_logp else None)

                    # store attention plots for selected valid sentences
                    store_attention_plots(attentions=valid_attention_scores,
                                          targets=valid_hypotheses_raw,
                                          sources=[s for s in valid_data.src],
                                          indices=self.log_valid_sents,
                                          output_prefix="{}/att.{}".format(
                                              self.model_dir, self.steps),
                                          tb_writer=self.tb_writer,
                                          steps=self.steps)

                if self.save_freq > 0 and self.steps % self.save_freq == 0:
                    # drop a checkpoint every `save_freq` update steps;
                    # note that `steps` counts updates, so the number of
                    # batches seen is batch_multiplier * steps
                    self.logger.info(
                        "Saving new checkpoint! Batches passed: %d, "
                        "number of updates: %d",
                        self.batch_multiplier * self.steps, self.steps)
                    self._save_checkpoint()

                if self.stop:
                    break

            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            self.logger.info('Epoch %d: total training loss %.2f',
                             epoch_no + 1, epoch_loss)
        else:
            self.logger.info('Training ended after %d epochs.', epoch_no + 1)

        if valid_data is not None:
            self.logger.info('Best validation result at step %d: %f %s.',
                             self.best_ckpt_iteration, self.best_ckpt_score,
                             self.early_stopping_metric)
Example #8
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset):
        """
        Train the model and validate it on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(
            train_data,
            batch_size=self.batch_size,
            batch_type=self.batch_type,
            train=True,
            shuffle=self.shuffle)
        for epoch_no in range(1, self.epochs + 1):
            self.logger.info("EPOCH %d", epoch_no)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no - 1)  # 0-based indexing

            self.model.train()

            start = time.time()
            total_valid_duration = 0
            processed_tokens = self.total_tokens
            epoch_loss = 0

            for i, batch in enumerate(iter(train_iter), 1):
                # reactivate training
                self.model.train()
                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672
                update = i % self.batch_multiplier == 0
                batch_loss = self._train_batch(batch, update=update)

                self.log_tensorboard("train", batch_loss=batch_loss)

                epoch_loss += batch_loss.detach().cpu().numpy()

                if self.scheduler is not None and \
                        self.scheduler_step_at == "step" and update:
                    self.scheduler.step()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - processed_tokens
                    self.logger.info(
                        "Epoch %3d Step: %8d Batch Loss: %12.6f "
                        "Tokens per Sec: %8.0f, Lr: %.6f",
                        epoch_no, self.steps, batch_loss,
                        elapsed_tokens / elapsed,
                        self.optimizer.param_groups[0]["lr"])
                    start = time.time()
                    total_valid_duration = 0
                    processed_tokens = self.total_tokens

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    valid_start_time = time.time()

                    # it would be nice to include loss and ppl in valid_scores
                    valid_scores, valid_sources, valid_sources_raw, \
                        valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores, \
                        scores_by_lang, by_lang = validate_on_data(
                            batch_size=self.eval_batch_size,
                            data=valid_data,
                            eval_metrics=self.eval_metrics,
                            attn_metrics=self.attn_metrics,
                            src_level=self.src_level,
                            trg_level=self.trg_level,
                            model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            loss_function=self.loss,
                            beam_size=0,  # greedy validations
                            batch_type=self.eval_batch_type,
                            save_attention=self.plot_attention,
                            log_sparsity=self.log_sparsity,
                            apply_mask=self.valid_apply_mask
                        )

                    ckpt_score = valid_scores[self.early_stopping_metric]
                    self.log_tensorboard("valid", **valid_scores)

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    if self.scheduler is not None \
                            and self.scheduler_step_at == "validation":
                        self.scheduler.step(ckpt_score)

                    # append to validation report
                    self._add_report(
                        valid_scores=valid_scores,
                        eval_metrics=self.eval_metrics,
                        new_best=new_best)

                    self._log_examples(
                        sources_raw=valid_sources_raw,
                        sources=valid_sources,
                        hypotheses_raw=valid_hypotheses_raw,
                        hypotheses=valid_hypotheses,
                        references=valid_references
                    )

                    labeled_scores = sorted(valid_scores.items())
                    eval_report = ", ".join("{}: {:.5f}".format(n, v)
                                            for n, v in labeled_scores)

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration

                    self.logger.info(
                        'Validation result at epoch %3d, step %8d: %s, '
                        'duration: %.4fs',
                        epoch_no, self.steps, eval_report, valid_duration)

                    if scores_by_lang is not None:
                        for metric, scores in scores_by_lang.items():
                            # make a report
                            lang_report = [metric]
                            numbers = sorted(scores.items())
                            lang_report.extend(["{}: {:.5f}".format(k, v)
                                                for k, v in numbers])

                            self.logger.info("\n\t".join(lang_report))

                    # store validation set outputs
                    self._store_outputs(valid_hypotheses)

                    # store attention plots for selected valid sentences
                    if valid_attention_scores and self.plot_attention:
                        store_attention_plots(
                                attentions=valid_attention_scores,
                                sources=[s for s in valid_data.src],
                                targets=valid_hypotheses_raw,
                                indices=self.log_valid_sents,
                                model_dir=self.model_dir,
                                tb_writer=self.tb_writer,
                                steps=self.steps)

                if self.stop:
                    break
            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            self.logger.info(
                'Epoch %3d: total training loss %.2f', epoch_no, epoch_loss)
        else:
            self.logger.info('Training ended after %3d epochs.', epoch_no)
        self.logger.info('Best validation result at step %8d: %6.2f %s.',
                         self.best_ckpt_iteration, self.best_ckpt_score,
                         self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer
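The checkpointing logic above hinges on self.early_stopping_metric and self.is_best(...). Below is a minimal, self-contained sketch of that comparison, assuming (as these trainers do) that loss and perplexity are minimized while other evaluation scores are maximized; CheckpointTracker and its attribute names are illustrative, not the trainer's actual API.

import operator

# Sketch of the is_best / best-checkpoint bookkeeping used in the loops above.
# Names (CheckpointTracker, minimize_metric) are illustrative only.
class CheckpointTracker:
    def __init__(self, early_stopping_metric: str):
        # loss and perplexity improve by going down, score metrics by going up
        self.minimize_metric = early_stopping_metric in {"loss", "ppl", "perplexity"}
        self.is_better = operator.lt if self.minimize_metric else operator.gt
        self.best_ckpt_score = float("inf") if self.minimize_metric else float("-inf")
        self.best_ckpt_iteration = 0

    def is_best(self, score: float) -> bool:
        return self.is_better(score, self.best_ckpt_score)

    def update(self, score: float, step: int) -> bool:
        """Return True (and remember score/step) if this is a new best."""
        if self.is_best(score):
            self.best_ckpt_score = score
            self.best_ckpt_iteration = step
            return True
        return False

tracker = CheckpointTracker("ppl")
assert tracker.update(12.3, step=1000)      # first validation is always a new best
assert not tracker.update(15.0, step=2000)  # higher perplexity is worse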
Example #9
def test(cfg_file,
         ckpt,  # str or list now
         output_path: str = None,
         save_attention: bool = False,
         logger: logging.Logger = None) -> None:
    """
    Main test function. Handles loading a model from checkpoint, generating
    translations and storing them and attention plots.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param output_path: path to output
    :param save_attention: whether to save the computed attention weights
    :param logger: log output to this logger (creates new logger if not set)
    """

    if logger is None:
        logger = logging.getLogger(__name__)
        FORMAT = '%(asctime)-15s - %(message)s'
        logging.basicConfig(format=FORMAT)
        logger.setLevel(level=logging.DEBUG)

    cfg = load_config(cfg_file)
    train_cfg = cfg["training"]
    data_cfg = cfg["data"]
    test_cfg = cfg["testing"]

    if "test" not in data_cfg.keys():
        raise ValueError("Test data must be specified in config.")

    model_dir = train_cfg["model_dir"]
    # when checkpoint is not specified, take the latest (best) from the model dir
    if ckpt is None:
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError("No checkpoint found in directory {}."
                                    .format(model_dir))
    # infer the training step from the checkpoint file name (used for attention plots)
    try:
        first_ckpt = ckpt if isinstance(ckpt, str) else ckpt[0]
        step = first_ckpt.split(model_dir + "/")[1].split(".ckpt")[0]
    except IndexError:
        step = "best"

    batch_size = train_cfg.get("eval_batch_size", train_cfg["batch_size"])
    batch_type = train_cfg.get("eval_batch_type", train_cfg.get("batch_type", "sentence"))
    use_cuda = train_cfg.get("use_cuda", False)
    src_level = data_cfg.get("src_level", data_cfg.get("level", "word"))
    trg_level = data_cfg.get("trg_level", data_cfg.get("level", "word"))

    eval_metric = train_cfg["eval_metric"]
    if isinstance(eval_metric, str):
        eval_metric = [eval_metric]
    attn_metric = train_cfg.get("attn_metric", [])
    if isinstance(attn_metric, str):
        attn_metric = [attn_metric]
    max_output_length = train_cfg.get("max_output_length", None)

    # load the data
    data = load_data(data_cfg)
    dev_data = data["dev_data"]
    test_data = data["test_data"]
    vocabs = data["vocabs"]

    data_to_predict = {"dev": dev_data, "test": test_data}

    # load model state from disk
    if isinstance(ckpt, str):
        ckpt = [ckpt]
    individual_models = []
    for c in ckpt:
        model_checkpoint = load_checkpoint(c, use_cuda=use_cuda)

        # build model and load parameters into it
        m = build_model(cfg["model"], vocabs=vocabs)
        m.load_state_dict(model_checkpoint["model_state"])
        individual_models.append(m)
    if len(individual_models) == 1:
        model = individual_models[0]
    else:
        model = EnsembleModel(*individual_models)

    if use_cuda:
        model.cuda()

    # whether to use beam search for decoding, 0: greedy decoding
    if "testing" in cfg.keys():
        beam_sizes = test_cfg.get("beam_size", 0)
        beam_alpha = test_cfg.get("alpha", 0)
    else:
        beam_sizes = 0
        beam_alpha = 0
    if isinstance(beam_sizes, int):
        beam_sizes = [beam_sizes]
    assert beam_alpha >= 0, "Use alpha >= 0"

    for beam_size in beam_sizes:
        for data_set_name, data_set in data_to_predict.items():

            #pylint: disable=unused-variable
            scores, sources, sources_raw, references, hypotheses, \
            hypotheses_raw, attention_scores, scores_by_lang, by_lang = validate_on_data(
                model, data=data_set, batch_size=batch_size,
                batch_type=batch_type,
                src_level=src_level, trg_level=trg_level,
                max_output_length=max_output_length, eval_metrics=eval_metric,
                attn_metrics=attn_metric,
                use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
                beam_alpha=beam_alpha, save_attention=save_attention)
            #pylint: enable=unused-variable

            if "trg" in data_set.fields:
                labeled_scores = sorted(scores.items())
                eval_report = ", ".join("{}: {:.5f}".format(n, v)
                                        for n, v in labeled_scores)
                decoding_description = "Greedy decoding" if beam_size == 0 else \
                    "Beam search decoding with beam size = {} and alpha = {}".\
                        format(beam_size, beam_alpha)
                logger.info("%4s %s: [%s]",
                            data_set_name, eval_report, decoding_description)
                if scores_by_lang is not None:
                    for metric, scores in scores_by_lang.items():
                        # make a report
                        lang_report = [metric]
                        numbers = sorted(scores.items())
                        lang_report.extend(["{}: {:.5f}".format(k, v)
                                            for k, v in numbers])

                        logger.info("\n\t".join(lang_report))
            else:
                logger.info("No references given for %s -> no evaluation.",
                            data_set_name)

            if save_attention:
                # currently this will break for transformers
                if attention_scores:
                    #attention_name = "{}.{}.att".format(data_set_name, step)
                    #attention_path = os.path.join(model_dir, attention_name)
                    logger.info("Saving attention plots. This might take a while..")
                    store_attention_plots(attentions=attention_scores,
                                          targets=hypotheses_raw,
                                          sources=[s for s in data_set.src],
                                          indices=range(len(hypotheses)),
                                          model_dir=model_dir,
                                          steps=step,
                                          data_set_name=data_set_name)
                    logger.info("Attention plots saved to: %s", model_dir)
                else:
                    logger.warning("Attention scores could not be saved. "
                                   "Note that attention scores are not available "
                                   "when using beam search. "
                                   "Set beam_size to 0 for greedy decoding.")

            if output_path is not None:
                for lang, ref_and_hyp in by_lang.items():
                    if lang is None:
                        # monolingual case
                        output_path_set = "{}.{}".format(output_path, data_set_name)
                    else:
                        output_path_set = "{}.{}.{}".format(output_path, lang, data_set_name)
                    if isinstance(ref_and_hyp[0], str):
                        hyps = ref_and_hyp
                    else:
                        hyps = [hyp for (ref, hyp) in ref_and_hyp]
                    with open(output_path_set, mode="w", encoding="utf-8") as out_file:
                        for hyp in hyps:
                            out_file.write(hyp + "\n")
                    logger.info("Translations saved to: %s", output_path_set)
Example #10
def test(cfg_file,
         ckpt: str = None,
         output_path: str = None,
         save_attention: bool = False):
    """
    Main test function. Handles loading a model from checkpoint, generating
    translations and storing them and attention plots.

    :param cfg_file:
    :param ckpt:
    :param output_path:
    :param save_attention:
    :return:
    """

    cfg = load_config(cfg_file)

    if "test" not in cfg["data"].keys():
        raise ValueError("Test data must be specified in config.")

    model_dir = cfg["training"]["model_dir"]
    # when checkpoint is not specified, take the latest (best) from the model dir
    if ckpt is None:
        ckpt = get_latest_checkpoint(model_dir)
    # infer the training step from the checkpoint file name (used for attention plots)
    try:
        step = ckpt.split(model_dir + "/")[1].split(".ckpt")[0]
    except IndexError:
        step = "best"

    batch_size = cfg["training"]["batch_size"]
    use_cuda = cfg["training"].get("use_cuda", False)
    level = cfg["data"]["level"]
    eval_metric = cfg["training"]["eval_metric"]
    max_output_length = cfg["training"].get("max_output_length", None)

    # load the data
    # TODO load only test data
    train_data, dev_data, test_data, src_vocab, trg_vocab = \
        load_data(cfg=cfg)

    # TODO specify this differently
    data_to_predict = {"dev": dev_data, "test": test_data}

    # load model state from disk
    model_checkpoint = load_model_from_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
    model.load_state_dict(model_checkpoint["model_state"])

    if use_cuda:
        model.cuda()

    # whether to use beam search for decoding, 0: greedy decoding
    if "testing" in cfg.keys():
        beam_size = cfg["testing"].get("beam_size", 0)
        beam_alpha = cfg["testing"].get("alpha", -1)
    else:
        beam_size = 0
        beam_alpha = -1

    for data_set_name, data_set in data_to_predict.items():

        score, loss, ppl, sources, sources_raw, references, hypotheses, \
        hypotheses_raw, attention_scores = validate_on_data(
            model, data=data_set, batch_size=batch_size, level=level,
            max_output_length=max_output_length, eval_metric=eval_metric,
            use_cuda=use_cuda, criterion=None, beam_size=beam_size,
            beam_alpha=beam_alpha)

        if "trg" in data_set.fields:
            decoding_description = "Greedy decoding" if beam_size == 0 else \
                "Beam search decoding with beam size = {} and alpha = {}".format(
                    beam_size, beam_alpha)
            print("{:4s} {}: {} [{}]".format(data_set_name, eval_metric, score,
                                             decoding_description))
        else:
            print("No references given for {} -> no evaluation.".format(
                data_set_name))

        if attention_scores is not None and save_attention:
            attention_path = "{}/{}.{}.att".format(dir, data_set_name, step)
            print("Attention plots saved to: {}.xx".format(attention_path))
            store_attention_plots(attentions=attention_scores,
                                  targets=hypotheses_raw,
                                  sources=[s for s in data_set.src],
                                  idx=range(len(hypotheses)),
                                  output_prefix=attention_path)

        if output_path is not None:
            output_path_set = "{}.{}".format(output_path, data_set_name)
            with open(output_path_set, mode="w", encoding="utf-8") as f:
                for h in hypotheses:
                    f.write(h + "\n")
            print("Translations saved to: {}".format(output_path_set))
Example #11
def test(cfg_file,
         ckpt: str,
         output_path: str = None,
         save_attention: bool = False,
         logger: Logger = None) -> None:
    """
    Main test function. Handles loading a model from checkpoint, generating
    translations and storing them and attention plots.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param output_path: path to output
    :param save_attention: whether to save the computed attention weights
    :param logger: log output to this logger (creates new logger if not set)
    """

    if logger is None:
        logger = make_logger()

    cfg = load_config(cfg_file)

    if "test" not in cfg["data"].keys():
        raise ValueError("Test data must be specified in config.")

    model_dir = cfg["training"]["model_dir"]
    # when checkpoint is not specified, take the latest (best) from the model dir
    if ckpt is None:
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError(
                "No checkpoint found in directory {}.".format(model_dir))
    # infer the training step from the checkpoint file name (used for attention plots)
    try:
        step = ckpt.split(model_dir + "/")[1].split(".ckpt")[0]
    except IndexError:
        step = "best"

    batch_size = cfg["training"].get("eval_batch_size",
                                     cfg["training"]["batch_size"])
    batch_type = cfg["training"].get(
        "eval_batch_type", cfg["training"].get("batch_type", "sentence"))
    use_cuda = cfg["training"].get("use_cuda", False)
    level = cfg["data"]["level"]
    eval_metric = cfg["training"]["eval_metric"]
    max_output_length = cfg["training"].get("max_output_length", None)

    # load the data
    _, dev_data, test_data, src_vocab, trg_vocab = load_data(
        data_cfg=cfg["data"])

    data_to_predict = {"dev": dev_data, "test": test_data}

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
    model.load_state_dict(model_checkpoint["model_state"])

    if use_cuda:
        model.cuda()

    # whether to use beam search for decoding; beam_size < 2 means greedy decoding
    if "testing" in cfg.keys():
        beam_size = cfg["testing"].get("beam_size", 1)
        beam_alpha = cfg["testing"].get("alpha", -1)
    else:
        beam_size = 1
        beam_alpha = -1

    for data_set_name, data_set in data_to_predict.items():

        #pylint: disable=unused-variable
        score, loss, ppl, sources, sources_raw, references, hypotheses, \
        hypotheses_raw, attention_scores, valid_hypotheses_full_n_best, scores = validate_on_data(
            model, data=data_set, batch_size=batch_size,
            batch_type=batch_type, level=level,
            max_output_length=max_output_length, eval_metric=eval_metric,
            use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
            beam_alpha=beam_alpha, logger=logger)
        #pylint: enable=unused-variable

        if "trg" in data_set.fields:
            decoding_description = "Greedy decoding" if beam_size < 2 else \
                "Beam search decoding with beam size = {} and alpha = {}".\
                    format(beam_size, beam_alpha)
            logger.info("%4s %s: %6.2f [%s]", data_set_name, eval_metric,
                        score, decoding_description)
        else:
            logger.info("No references given for %s -> no evaluation.",
                        data_set_name)

        if save_attention:
            if attention_scores:
                attention_name = "{}.{}.att".format(data_set_name, step)
                attention_path = os.path.join(model_dir, attention_name)
                logger.info(
                    "Saving attention plots. This might take a while..")
                store_attention_plots(attentions=attention_scores,
                                      targets=hypotheses_raw,
                                      sources=data_set.src,
                                      indices=range(len(hypotheses)),
                                      output_prefix=attention_path)
                logger.info("Attention plots saved to: %s", attention_path)
            else:
                logger.warning("Attention scores could not be saved. "
                               "Note that attention scores are not available "
                               "when using beam search. "
                               "Set beam_size to 1 for greedy decoding.")

        if output_path is not None:
            '''
            output_path_set = "{}.{}".format(output_path, data_set_name)
            with open(output_path_set, mode="w", encoding="utf-8") as out_file:
                for hyp in hypotheses:
                    out_file.write(hyp + "\n")


            #sy_debug
            alt_output = "{}.n_best.{}".format(output_path, data_set_name)
            with open(alt_output, mode="w", encoding="utf-8") as out_file:
                for n in valid_hypotheses_full_n_best:
                    out_file.write(n + "\n")
'''

            #@Shiya: exporting hypothesis and associated score to .csv file
            #TODO: write_to_csv(hyps,scores)
            def write_to_csv(hyps: list, scores: list):
                import csv

                output_file = "{}.n_csv.{}".format(output_path, data_set_name)
                with open(output_file, mode="w", newline='',
                          encoding="utf-8") as out_file:
                    fieldnames = ['Predictions', 'Scores']
                    writer = csv.DictWriter(out_file, fieldnames=fieldnames)
                    writer.writeheader()

                    for prediction, score in zip(hyps, scores):
                        writer.writerow({
                            fieldnames[0]: prediction,
                            fieldnames[1]: score
                        })

            write_to_csv(valid_hypotheses_full_n_best, scores)
Example #12
def test(cfg_file,
         ckpt: str,
         output_path: str = None,
         save_attention: bool = False,
         logger: logging.Logger = None) -> None:
    """
    Main test function. Handles loading a model from checkpoint, generating
    translations and storing them and attention plots.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param output_path: path to output
    :param save_attention: whether to save the computed attention weights
    :param logger: log output to this logger (creates new logger if not set)
    """

    if logger is None:
        logger = logging.getLogger(__name__)
        FORMAT = '%(asctime)-15s - %(message)s'
        logging.basicConfig(format=FORMAT)
        logger.setLevel(level=logging.DEBUG)

    cfg = load_config(cfg_file)

    if "test" not in cfg["data"].keys():
        raise ValueError("Test data must be specified in config.")

    model_dir = cfg["training"]["model_dir"]
    # when checkpoint is not specified, take the latest (best) from the model dir
    if ckpt is None:
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError(
                "No checkpoint found in directory {}.".format(model_dir))
    # infer the training step from the checkpoint file name (used for attention plots)
    try:
        step = ckpt.split(model_dir + "/")[1].split(".ckpt")[0]
    except IndexError:
        step = "best"

    batch_size = cfg["training"]["batch_size"]
    batch_type = cfg["training"].get("batch_type", "sentence")
    use_cuda = cfg["training"].get("use_cuda", False)
    level = cfg["data"]["level"]
    eval_metric = cfg["training"]["eval_metric"]
    max_output_length = cfg["training"].get("max_output_length", None)

    # load the data
    _, dev_data, test_data, src_vocab, trg_vocab = load_data(
        data_cfg=cfg["data"])

    data_to_predict = {"dev": dev_data, "test": test_data}

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
    model.load_state_dict(model_checkpoint["model_state"])

    if use_cuda:
        model.cuda()

    # whether to use beam search for decoding, 0: greedy decoding
    if "testing" in cfg.keys():
        beam_size = cfg["testing"].get("beam_size", 0)
        beam_alpha = cfg["testing"].get("alpha", -1)
    else:
        beam_size = 0
        beam_alpha = -1

    for data_set_name, data_set in data_to_predict.items():

        #pylint: disable=unused-variable
        score, loss, ppl, sources, sources_raw, references, hypotheses, \
        hypotheses_raw, attention_scores = validate_on_data(
            model, data=data_set, batch_size=batch_size,
            batch_type=batch_type, level=level,
            max_output_length=max_output_length, eval_metric=eval_metric,
            use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
            beam_alpha=beam_alpha)
        #pylint: enable=unused-variable

        if "trg" in data_set.fields:
            decoding_description = "Greedy decoding" if beam_size == 0 else \
                "Beam search decoding with beam size = {} and alpha = {}".\
                    format(beam_size, beam_alpha)
            logger.info("%4s %s: %6.2f [%s]", data_set_name, eval_metric,
                        score, decoding_description)
        else:
            logger.info("No references given for %s -> no evaluation.",
                        data_set_name)

        if save_attention:
            if attention_scores:
                attention_name = "{}.{}.att".format(data_set_name, step)
                attention_path = os.path.join(model_dir, attention_name)
                logger.info(
                    "Saving attention plots. This might take a while..")
                store_attention_plots(attentions=attention_scores,
                                      targets=hypotheses_raw,
                                      sources=[s for s in data_set.src],
                                      indices=range(len(hypotheses)),
                                      output_prefix=attention_path)
                logger.info("Attention plots saved to: %s", attention_path)
            else:
                logger.warning("Attention scores could not be saved. "
                               "Note that attention scores are not available "
                               "when using beam search. "
                               "Set beam_size to 0 for greedy decoding.")

        if output_path is not None:
            output_path_set = "{}.{}".format(output_path, data_set_name)
            with open(output_path_set, mode="w", encoding="utf-8") as out_file:
                for hyp in hypotheses:
                    out_file.write(hyp + "\n")
            logger.info("Translations saved to: %s", output_path_set)
Example #13
def test(cfg_file,
         ckpt: str,
         output_path: str = None,
         save_attention: bool = False) -> None:
    """
    Main test function. Handles loading a model from checkpoint, generating
    translations and storing them and attention plots.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param output_path: path to output
    :param save_attention: whether to save the computed attention weights
    """

    cfg = load_config(cfg_file)

    if "test" not in cfg["data"].keys():
        raise ValueError("Test data must be specified in config.")

    model_dir = cfg["training"]["model_dir"]
    # when checkpoint is not specified, take the latest (best) from the model dir
    if ckpt is None:
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError(
                "No checkpoint found in directory {}.".format(model_dir))
    # infer the training step from the checkpoint file name (used for attention plots)
    try:
        step = ckpt.split(model_dir + "/")[1].split(".ckpt")[0]
    except IndexError:
        step = "best"

    batch_size = cfg["training"]["batch_size"]
    use_cuda = cfg["training"].get("use_cuda", False)
    level = cfg["data"]["level"]
    eval_metric = cfg["training"]["eval_metric"]
    max_output_length = cfg["training"].get("max_output_length", None)

    # load the data
    _, dev_data, test_data, src_vocab, trg_vocab = load_data(
        data_cfg=cfg["data"])

    data_to_predict = {"dev": dev_data, "test": test_data}

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
    model.load_state_dict(model_checkpoint["model_state"])

    if use_cuda:
        model.cuda()

    # whether to use beam search for decoding, 0: greedy decoding
    if "testing" in cfg.keys():
        beam_size = cfg["testing"].get("beam_size", 0)
        beam_alpha = cfg["testing"].get("alpha", -1)
    else:
        beam_size = 0
        beam_alpha = -1

    for data_set_name, data_set in data_to_predict.items():
        if data_set is None:
            # e.g. no valid_data
            continue

        #pylint: disable=unused-variable
        score, loss, ppl, sources, sources_raw, references, hypotheses, \
        hypotheses_raw, attention_scores, logprobs = validate_on_data(
            model, data=data_set, batch_size=batch_size, level=level,
            max_output_length=max_output_length, eval_metric=eval_metric,
            use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
            beam_alpha=beam_alpha)
        #pylint: enable=unused-variable

        if "trg" in data_set.fields:
            decoding_description = "Greedy decoding" if beam_size == 0 else \
                "Beam search decoding with beam size = {} and alpha = {}".\
                    format(beam_size, beam_alpha)
            print("{:4s} {}: {} [{}]".format(data_set_name, eval_metric, score,
                                             decoding_description))
        else:
            print("No references given for {} -> no evaluation.".format(
                data_set_name))

        if attention_scores is not None and save_attention:
            attention_path = "{}/{}.{}.att".format(model_dir, data_set_name,
                                                   step)
            print("Attention plots saved to: {}.xx".format(attention_path))
            store_attention_plots(attentions=attention_scores,
                                  targets=hypotheses_raw,
                                  sources=[s for s in data_set.src],
                                  indices=range(len(hypotheses)),
                                  output_prefix=attention_path)

        if output_path is not None:
            output_path_set = "{}.{}".format(output_path, data_set_name)
            with open(output_path_set, mode="w", encoding="utf-8") as out_file:
                if cfg["data"].get("post_process", True):
                    for hyp in hypotheses:
                        out_file.write(hyp + "\n")
                else:
                    for hyp in hypotheses_raw:
                        out_file.write(" ".join(hyp) + "\n")
            print("Translations saved to: {}".format(output_path_set))
Example #14
def test(cfg_file,
         ckpt: str,
         output_path: str = None,
         save_attention: bool = False,
         logger: logging.Logger = None) -> None:
    """
    Main test function. Handles loading a model from checkpoint, generating
    translations and storing them and attention plots.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load
    :param output_path: path to output
    :param save_attention: whether to save the computed attention weights
    :param logger: log output to this logger (creates new logger if not set)
    """

    if logger is None:
        logger = logging.getLogger(__name__)
        FORMAT = '%(asctime)-15s - %(message)s'
        logging.basicConfig(format=FORMAT)
        logger.setLevel(level=logging.DEBUG)

    cfg = load_config(cfg_file)

    if "test" not in cfg["data"].keys():
        raise ValueError("Test data must be specified in config.")

    model_dir = cfg["training"]["model_dir"]
    # when checkpoint is not specified, take the latest (best) from the model dir
    if ckpt is None:
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError(
                "No checkpoint found in directory {}.".format(model_dir))
    # infer the training step from the checkpoint file name (used for attention plots)
    try:
        step = ckpt.split(model_dir + "/")[1].split(".ckpt")[0]
    except IndexError:
        step = "best"

    batch_size = cfg["training"].get("eval_batch_size",
                                     cfg["training"]["batch_size"])
    batch_type = cfg["training"].get(
        "eval_batch_type", cfg["training"].get("batch_type", "sentence"))
    use_cuda = cfg["training"].get("use_cuda", False)
    level = cfg["data"]["level"]
    eval_metric = cfg["training"]["eval_metric"]
    max_output_length = cfg["training"].get("max_output_length", None)

    # load the data
    _, dev_data, test_data,\
    src_vocab, trg_vocab,\
    _, dev_kb, test_kb,\
    _, dev_kb_lookup, test_kb_lookup, \
    _, dev_kb_lengths, test_kb_lengths,\
    _, dev_kb_truvals, test_kb_truvals, \
    trv_vocab, canon_fun,\
         dev_data_canon, test_data_canon \
        = load_data(
        data_cfg=cfg["data"]
    )

    report_entf1_on_canonicals = cfg["training"].get(
        "report_entf1_on_canonicals", False)

    kb_task = test_kb is not None

    data_to_predict = {"dev": dev_data, "test": test_data}

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    model = build_model(cfg["model"],
                        src_vocab=src_vocab,
                        trg_vocab=trg_vocab,
                        trv_vocab=trv_vocab,
                        canonizer=canon_fun)
    model.load_state_dict(model_checkpoint["model_state"])

    # FIXME: for the moment (testing only), consider overriding model.canonize
    # with the canon_fun loaded from the test data; this should hopefully not
    # be an issue for the gridsearch results

    if use_cuda:
        model.cuda()  # move to GPU

    # whether to use beam search for decoding, 0: greedy decoding
    if "testing" in cfg.keys():
        beam_size = cfg["testing"].get("beam_size", 0)
        beam_alpha = cfg["testing"].get("alpha", -1)
    else:
        beam_size = 0
        beam_alpha = -1

    for data_set_name, data_set in data_to_predict.items():

        if data_set_name == "dev":
            kb_info = [
                dev_kb, dev_kb_lookup, dev_kb_lengths, dev_kb_truvals,
                dev_data_canon
            ]
        elif data_set_name == "test":
            kb_info = [
                test_kb, test_kb_lookup, test_kb_lengths, test_kb_truvals,
                test_data_canon
            ]
        else:
            raise ValueError((data_set_name, data_set))

        #pylint: disable=unused-variable
        score, loss, ppl, sources, sources_raw, references, hypotheses, \
        hypotheses_raw, attention_scores, kb_att_scores, ent_f1, ent_mcc = validate_on_data(
            model,
            data=data_set,
            batch_size=batch_size,
            batch_type=batch_type,
            level=level,
            max_output_length=max_output_length,
            eval_metric=eval_metric,
            use_cuda=use_cuda,
            loss_function=None,
            beam_size=beam_size,
            beam_alpha=beam_alpha,
            kb_task=kb_task,
            valid_kb=kb_info[0],
            valid_kb_lkp=kb_info[1],
            valid_kb_lens=kb_info[2],
            valid_kb_truvals=kb_info[3],
            valid_data_canon=kb_info[4],
            report_on_canonicals=report_entf1_on_canonicals
            )
        """
                batch_size=self.eval_batch_size,
                data=valid_data,
                eval_metric=self.eval_metric,
                level=self.level, 
                model=self.model,
                use_cuda=self.use_cuda,
                max_output_length=self.max_output_length,
                loss_function=self.loss,
                beam_size=0,  
                batch_type=self.eval_batch_type,
                kb_task=kb_task,
                valid_kb=valid_kb,
                valid_kb_lkp=valid_kb_lkp,
                valid_kb_lens=valid_kb_lens,
                valid_kb_truvals=valid_kb_truvals
        """
        #pylint: enable=unused-variable

        if "trg" in data_set.fields:
            decoding_description = "Greedy decoding" if beam_size == 0 else \
                "Beam search decoding with beam size = {} and alpha = {}".\
                    format(beam_size, beam_alpha)

            logger.info("%4s %s: %6.2f f1: %6.2f mcc: %6.2f [%s]",
                        data_set_name, eval_metric, score, ent_f1, ent_mcc,
                        decoding_description)
        else:
            logger.info("No references given for %s -> no evaluation.",
                        data_set_name)

        if save_attention:
            if attention_scores:
                attention_name = "{}.{}.att".format(data_set_name, step)
                attention_path = os.path.join(model_dir, attention_name)

                logger.info(
                    "Saving attention plots. This might take a while..")
                store_attention_plots(attentions=attention_scores,
                                      targets=hypotheses_raw,
                                      sources=data_set.src,
                                      indices=range(len(hypotheses)),
                                      output_prefix=attention_path)
                logger.info("Attention plots saved to: %s", attention_path)
            if kb_att_scores:
                kb_att_name = "{}.{}.kbatt".format(data_set_name, step)
                kb_att_path = os.path.join(model_dir, kb_att_name)
                store_attention_plots(
                    attentions=kb_att_scores,
                    targets=hypotheses_raw,
                    sources=list(data_set.kbsrc),  #TODO
                    indices=range(len(hypotheses)),
                    output_prefix=kb_att_path,
                    kb_info=(dev_kb_lookup, dev_kb_lengths,
                             list(data_set.kbtrg)))
                logger.info("KB Attention plots saved to: %s", attention_path)

            else:
                logger.warning("Attention scores could not be saved. "
                               "Note that attention scores are not available "
                               "when using beam search. "
                               "Set beam_size to 0 for greedy decoding.")

        if output_path is not None:
            output_path_set = "{}.{}".format(output_path, data_set_name)
            with open(output_path_set, mode="w", encoding="utf-8") as out_file:
                for hyp in hypotheses:
                    out_file.write(hyp + "\n")
            logger.info("Translations saved to: %s", output_path_set)
Example #15
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset,
                           kb_task=None, train_kb: TranslationDataset = None,
                           train_kb_lkp: list = [], train_kb_lens: list = [],
                           train_kb_truvals: TranslationDataset = None,
                           valid_kb: Tuple = None, valid_kb_lkp: list = [],
                           valid_kb_lens: list = [],
                           valid_kb_truvals: list = [],
                           valid_data_canon: list = []) -> None:
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        :param kb_task: is not None if kb_task should be executed
        :param train_kb: TranslationDataset holding the loaded train kb data
        :param train_kb_lkp: list mapping train example indices to the corresponding kb indices
        :param train_kb_lens: list with the number of triples per kb
        :param train_kb_truvals: TranslationDataset holding the loaded train kb true values
        :param valid_kb: TranslationDataset holding the loaded valid kb data
        :param valid_kb_lkp: list mapping valid example indices to the corresponding kb indices
        :param valid_kb_lens: list with the number of triples per kb
        :param valid_kb_truvals: true values of the valid kb entries (FIXME TODO)
        :param valid_data_canon: canonical form of the valid data, required to report the loss
        """

        if kb_task:
            train_iter = make_data_iter_kb(train_data,
                                           train_kb,
                                           train_kb_lkp,
                                           train_kb_lens,
                                           train_kb_truvals,
                                           batch_size=self.batch_size,
                                           batch_type=self.batch_type,
                                           train=True,
                                           shuffle=self.shuffle,
                                           canonize=self.model.canonize)
        else:
            train_iter = make_data_iter(train_data,
                                        batch_size=self.batch_size,
                                        batch_type=self.batch_type,
                                        train=True,
                                        shuffle=self.shuffle)

        with torch.autograd.set_detect_anomaly(True):
            for epoch_no in range(self.epochs):
                self.logger.info("EPOCH %d", epoch_no + 1)

                if self.scheduler is not None and self.scheduler_step_at == "epoch":
                    self.scheduler.step(epoch=epoch_no)

                self.model.train()

                start = time.time()
                total_valid_duration = 0
                processed_tokens = self.total_tokens
                count = self.batch_multiplier - 1
                epoch_loss = 0

                for batch in iter(train_iter):
                    # reactivate training
                    self.model.train()

                    # create a Batch object from torchtext batch
                    batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda) if not kb_task else \
                        Batch_with_KB(batch, self.pad_index, use_cuda=self.use_cuda)

                    if kb_task:
                        assert hasattr(batch, "kbsrc"), dir(batch)
                        assert hasattr(batch, "kbtrg"), dir(batch)
                        assert hasattr(batch, "kbtrv"), dir(batch)

                    # only update every batch_multiplier batches
                    # see https://medium.com/@davidlmorton/
                    # increasing-mini-batch-size-without-increasing-
                    # memory-6794e10db672
                    update = count == 0

                    batch_loss = self._train_batch(batch, update=update)

                    if update:
                        self.tb_writer.add_scalar("train/train_batch_loss",
                                                  batch_loss, self.steps)

                    count = self.batch_multiplier if update else count
                    count -= 1
                    epoch_loss += batch_loss.detach().cpu().numpy()

                    if self.scheduler is not None and \
                            self.scheduler_step_at == "step" and update:
                        self.scheduler.step()

                    # log learning progress
                    if self.steps % self.logging_freq == 0 and update:
                        elapsed = time.time() - start - total_valid_duration
                        elapsed_tokens = self.total_tokens - processed_tokens
                        self.logger.info(
                            "Epoch %3d Step: %8d Batch Loss: %12.6f "
                            "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                            self.steps, batch_loss, elapsed_tokens / elapsed,
                            self.optimizer.param_groups[0]["lr"])
                        start = time.time()
                        total_valid_duration = 0

                    # validate on the entire dev set
                    if self.steps % self.validation_freq == 0 and update:

                        if self.manage_decoder_timer:
                            self._log_decoder_timer_stats("train")
                            self.decoder_timer.reset()

                        valid_start_time = time.time()


                        valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                            valid_hypotheses_raw, valid_attention_scores, valid_kb_att_scores, \
                            valid_ent_f1, valid_ent_mcc = \
                            validate_on_data(
                                batch_size=self.eval_batch_size,
                                data=valid_data,
                                eval_metric=self.eval_metric,
                                level=self.level,
                                model=self.model,
                                use_cuda=self.use_cuda,
                                max_output_length=self.max_output_length,
                                loss_function=self.loss,
                                beam_size=0,  # greedy validations #FIXME XXX NOTE TODO BUG set to 0 again!
                                batch_type=self.eval_batch_type,
                                kb_task=kb_task,
                                valid_kb=valid_kb,
                                valid_kb_lkp=valid_kb_lkp,
                                valid_kb_lens=valid_kb_lens,
                                valid_kb_truvals=valid_kb_truvals,
                                valid_data_canon=valid_data_canon,
                                report_on_canonicals=self.report_entf1_on_canonicals
                            )

                        if self.manage_decoder_timer:
                            self._log_decoder_timer_stats("valid")
                            self.decoder_timer.reset()

                        self.tb_writer.add_scalar("valid/valid_loss",
                                                  valid_loss, self.steps)
                        self.tb_writer.add_scalar("valid/valid_score",
                                                  valid_score, self.steps)
                        self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                                  self.steps)

                        if self.early_stopping_metric == "loss":
                            ckpt_score = valid_loss
                        elif self.early_stopping_metric in [
                                "ppl", "perplexity"
                        ]:
                            ckpt_score = valid_ppl
                        else:
                            ckpt_score = valid_score

                        new_best = False
                        if self.is_best(ckpt_score):
                            self.best_ckpt_score = ckpt_score
                            self.best_ckpt_iteration = self.steps
                            self.logger.info(
                                'Hooray! New best validation result [%s]!',
                                self.early_stopping_metric)
                            if self.ckpt_queue.maxsize > 0:
                                self.logger.info("Saving new checkpoint.")
                                new_best = True
                                self._save_checkpoint()

                        if self.scheduler is not None \
                                and self.scheduler_step_at == "validation":
                            self.scheduler.step(ckpt_score)

                        # append to validation report
                        self._add_report(valid_score=valid_score,
                                         valid_loss=valid_loss,
                                         valid_ppl=valid_ppl,
                                         eval_metric=self.eval_metric,
                                         valid_ent_f1=valid_ent_f1,
                                         valid_ent_mcc=valid_ent_mcc,
                                         new_best=new_best)

                        # pylint: disable=unnecessary-comprehension
                        self._log_examples(
                            sources_raw=[v for v in valid_sources_raw],
                            sources=valid_sources,
                            hypotheses_raw=valid_hypotheses_raw,
                            hypotheses=valid_hypotheses,
                            references=valid_references)

                        valid_duration = time.time() - valid_start_time
                        total_valid_duration += valid_duration
                        self.logger.info(
                            'Validation result at epoch %3d, step %8d: %s: %6.2f, '
                            'loss: %8.4f, ppl: %8.4f, duration: %.4fs',
                            epoch_no + 1, self.steps, self.eval_metric,
                            valid_score, valid_loss, valid_ppl, valid_duration)

                        # store validation set outputs
                        self._store_outputs(valid_hypotheses)

                        valid_src = list(valid_data.src)
                        # store attention plots for selected valid sentences
                        if valid_attention_scores:
                            plot_success_ratio = store_attention_plots(
                                attentions=valid_attention_scores,
                                targets=valid_hypotheses_raw,
                                sources=valid_src,
                                indices=self.log_valid_sents,
                                output_prefix="{}/att.{}".format(
                                    self.model_dir, self.steps),
                                tb_writer=self.tb_writer,
                                steps=self.steps)
                            self.logger.info(
                                f"stored {plot_success_ratio} valid att scores!"
                            )
                        if valid_kb_att_scores:
                            plot_success_ratio = store_attention_plots(
                                attentions=valid_kb_att_scores,
                                targets=valid_hypotheses_raw,
                                sources=list(valid_kb.kbsrc),
                                indices=self.log_valid_sents,
                                output_prefix="{}/kbatt.{}".format(
                                    self.model_dir, self.steps),
                                tb_writer=self.tb_writer,
                                steps=self.steps,
                                kb_info=(valid_kb_lkp, valid_kb_lens,
                                         valid_kb_truvals),
                                on_the_fly_info=(valid_src, valid_kb,
                                                 self.model.canonize,
                                                 self.model.trg_vocab))
                            self.logger.info(
                                f"stored {plot_success_ratio} valid kb att scores!"
                            )
                        else:
                            self.logger.info(
                                "theres no valid kb att scores...")
                    if self.stop:
                        break
                if self.stop:
                    self.logger.info(
                        'Training ended since minimum lr %f was reached.',
                        self.learning_rate_min)
                    break

                self.logger.info('Epoch %3d: total training loss %.2f',
                                 epoch_no + 1, epoch_loss)
            else:
                self.logger.info('Training ended after %3d epochs.',
                                 epoch_no + 1)
            self.logger.info('Best validation result at step %8d: %6.2f %s.',
                             self.best_ckpt_iteration, self.best_ckpt_score,
                             self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer
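Both training loops accumulate gradients over batch_multiplier mini-batches before each optimizer step (the update = count == 0 bookkeeping). A standalone sketch of that gradient-accumulation pattern in plain PyTorch; model, loader and loss_fn are placeholders:

import torch

def train_with_accumulation(model, loader, loss_fn, optimizer, batch_multiplier=4):
    """Step the optimizer only every batch_multiplier batches."""
    model.train()
    optimizer.zero_grad()
    for i, (src, trg) in enumerate(loader):
        loss = loss_fn(model(src), trg) / batch_multiplier  # keep gradient scale comparable
        loss.backward()                                     # gradients accumulate across calls
        if (i + 1) % batch_multiplier == 0:
            optimizer.step()
            optimizer.zero_grad()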