Example #1
    def joey_translate(self, message_text, model, src_vocab, trg_vocab, logger,
                       beam_size, beam_alpha, level, lowercase,
                       max_output_length, use_cuda, nbest, cuda_device):

        sentence = message_text.strip()
        if lowercase:
            sentence = sentence.lower()

        # load the data which consists only of this sentence
        test_data, src_vocab, trg_vocab = self.load_line_as_data(
            lowercase=lowercase,
            line=sentence,
            src_vocab=src_vocab,
            trg_vocab=trg_vocab,
            level=level)

        # generate outputs
        score, loss, ppl, sources, sources_raw, references, hypotheses, \
        hypotheses_raw, attention_scores = validate_on_data(
            model=model, data=test_data, batch_size=1, level=level,
            max_output_length=max_output_length, eval_metric=None,
            use_cuda=use_cuda, beam_size=beam_size,
            beam_alpha=beam_alpha, n_best=nbest, cuda_device=cuda_device, n_gpu=1)

        return hypotheses[0] if nbest == 1 else hypotheses
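
The return line encodes a small API convention: callers get a bare string when nbest == 1 and the full n-best list otherwise. A toy, self-contained illustration of that branching (the hypotheses here are made up):

def pick(hypotheses, nbest):
    # mirrors the return convention of joey_translate above
    return hypotheses[0] if nbest == 1 else hypotheses

print(pick(["guten Morgen", "guten Tag"], 1))  # 'guten Morgen'
print(pick(["guten Morgen", "guten Tag"], 2))  # ['guten Morgen', 'guten Tag']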
Example #2
def translate(message_text, model, src_vocab, trg_vocab, preprocess,
              postprocess, logger, beam_size, beam_alpha, level, lowercase,
              max_output_length, use_cuda):
    """
    Translate a text message.

    :param message_text: Slack command, could be text.
    :param model: The Joey NMT model.
    :param src_vocab: Source vocabulary.
    :param trg_vocab: Target vocabulary.
    :param preprocess: Preprocessing pipeline (a list).
    :param postprocess: Postprocessing pipeline (a list).
    :param logger: Logger instance.
    :param beam_size: Beam size for decoding.
    :param beam_alpha: Beam alpha for decoding.
    :param level: Segmentation level.
    :param lowercase: Lowercasing.
    :param max_output_length: Maximum output length.
    :param use_cuda: Using CUDA or not.
    :return: The translated response string.
    """
    sentence = message_text.strip()
    # remove emojis
    emoji_pattern = re.compile(r":[a-zA-Z]+:")
    sentence = re.sub(emoji_pattern, "", sentence)
    sentence = sentence.strip()
    if lowercase:
        sentence = sentence.lower()
    for p in preprocess:
        sentence = p(sentence)

    # load the data which consists only of this sentence
    test_data, src_vocab, trg_vocab = load_line_as_data(lowercase=lowercase,
                                                        line=sentence,
                                                        src_vocab=src_vocab,
                                                        trg_vocab=trg_vocab,
                                                        level=level)

    # generate outputs
    score, loss, ppl, sources, sources_raw, references, hypotheses, \
    hypotheses_raw, attention_scores = validate_on_data(
        model, data=test_data, batch_size=1, level=level,
        max_output_length=max_output_length, eval_metric=None,
        use_cuda=use_cuda, loss_function=None, beam_size=beam_size,
        beam_alpha=beam_alpha, logger=logger)

    # post-process
    if level == "char":
        response = "".join(hypotheses)
    else:
        response = " ".join(hypotheses)

    for p in postprocess:
        response = p(response)

    return response
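
The emoji filter above targets Slack-style shortcodes of the form :smile:. A self-contained check of what the pattern actually removes:

import re

emoji_pattern = re.compile(r":[a-zA-Z]+:")
message = "translate this :smile: please :thumbsup:"
print(emoji_pattern.sub("", message).strip())
# -> 'translate this  please'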
Example #3
    def _translate(self, n_best):
        (batch_size, batch_type, use_cuda, device, n_gpu, level, eval_metric,
         max_output_length, beam_size, beam_alpha, postprocess, bpe_type,
         sacrebleu, _, _) = self.parsed_cfg

        (score, loss, ppl, sources, sources_raw, references, hypotheses,
         hypotheses_raw, attention_scores) = validate_on_data(
            self.model, data=self.test_data, batch_size=batch_size,
            batch_type=batch_type, level=level, use_cuda=use_cuda,
            max_output_length=max_output_length, eval_metric=None,
            compute_loss=False, beam_size=beam_size, beam_alpha=beam_alpha,
            postprocess=postprocess, bpe_type=bpe_type, sacrebleu=sacrebleu,
            n_gpu=n_gpu, n_best=n_best)
        return sources, hypotheses
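
With n_best > 1, the hypotheses list holds several candidates per source sentence. A small helper for regrouping them, assuming (not guaranteed by this snippet alone) that the n_best candidates for each sentence come back consecutively:

def group_nbest(hypotheses, n_best):
    # hypothetical helper; relies on the consecutive-candidates assumption
    return [hypotheses[k:k + n_best]
            for k in range(0, len(hypotheses), n_best)]

print(group_nbest(["a1", "a2", "b1", "b2"], n_best=2))
# -> [['a1', 'a2'], ['b1', 'b2']]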
Example #4
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset) \
            -> None:
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(train_data,
                                    batch_size=self.batch_size,
                                    batch_type=self.batch_type,
                                    train=True,
                                    shuffle=self.shuffle)

        # For the last batch in each epoch, batch_multiplier needs to be
        # adjusted to fit the number of leftover training examples
        leftover_batch_size = len(train_data) % (self.batch_multiplier *
                                                 self.batch_size)

        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no)

            self.model.train()

            # Reset statistics for each epoch.
            start = time.time()
            total_valid_duration = 0
            start_tokens = self.total_tokens
            self.current_batch_multiplier = self.batch_multiplier
            self.optimizer.zero_grad()
            count = self.current_batch_multiplier - 1
            epoch_loss = 0

            for i, batch in enumerate(iter(train_iter)):
                # reactivate training
                self.model.train()
                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672

                # Set current_batch_multiplier to fit the
                # number of leftover examples for the last batch in the epoch.
                # Only works if batch_type == "sentence".
                if self.batch_type == "sentence":
                    if self.batch_multiplier > 1 and i == len(train_iter) - \
                            math.ceil(leftover_batch_size / self.batch_size):
                        self.current_batch_multiplier = math.ceil(
                            leftover_batch_size / self.batch_size)
                        count = self.current_batch_multiplier - 1

                update = count == 0
                # print(count, update, self.steps)
                batch_loss = self._train_batch(batch,
                                               update=update,
                                               count=count)

                # Only save the finally computed batch_loss of the full batch
                if update:
                    self.tb_writer.add_scalar("train/train_batch_loss",
                                              batch_loss, self.steps)

                count = self.batch_multiplier if update else count
                count -= 1

                # Only add complete batch_loss of full mini-batch to epoch_loss
                if update:
                    epoch_loss += batch_loss.detach().cpu().numpy()

                if self.scheduler is not None and \
                        self.scheduler_step_at == "step" and update:
                    self.scheduler.step()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - start_tokens
                    self.logger.info(
                        "Epoch %3d Step: %8d Batch Loss: %12.6f "
                        "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                        self.steps, batch_loss, elapsed_tokens / elapsed,
                        self.optimizer.param_groups[0]["lr"])
                    start = time.time()
                    total_valid_duration = 0
                    start_tokens = self.total_tokens

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    valid_start_time = time.time()

                    valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores = \
                        validate_on_data(
                            logger=self.logger,
                            batch_size=self.eval_batch_size,
                            data=valid_data,
                            eval_metric=self.eval_metric,
                            level=self.level, model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            loss_function=self.loss,
                            beam_size=1,  # greedy validations
                            batch_type=self.eval_batch_type,
                            postprocess=True  # always remove BPE for validation
                        )

                    self.tb_writer.add_scalar("valid/valid_loss", valid_loss,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_score", valid_score,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                              self.steps)

                    if self.early_stopping_metric == "loss":
                        ckpt_score = valid_loss
                    elif self.early_stopping_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    if self.scheduler is not None \
                            and self.scheduler_step_at == "validation":
                        self.scheduler.step(ckpt_score)

                    # append to validation report
                    self._add_report(valid_score=valid_score,
                                     valid_loss=valid_loss,
                                     valid_ppl=valid_ppl,
                                     eval_metric=self.eval_metric,
                                     new_best=new_best)

                    self._log_examples(
                        sources_raw=[v for v in valid_sources_raw],
                        sources=valid_sources,
                        hypotheses_raw=valid_hypotheses_raw,
                        hypotheses=valid_hypotheses,
                        references=valid_references)

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        'Validation result (greedy) at epoch %3d, '
                        'step %8d: %s: %6.2f, loss: %8.4f, ppl: %8.4f, '
                        'duration: %.4fs', epoch_no + 1, self.steps,
                        self.eval_metric, valid_score, valid_loss, valid_ppl,
                        valid_duration)

                    # store validation set outputs
                    self._store_outputs(valid_hypotheses)

                    # store attention plots for selected valid sentences
                    if valid_attention_scores:
                        store_attention_plots(
                            attentions=valid_attention_scores,
                            targets=valid_hypotheses_raw,
                            sources=[s for s in valid_data.src],
                            indices=self.log_valid_sents,
                            output_prefix="{}/att.{}".format(
                                self.model_dir, self.steps),
                            tb_writer=self.tb_writer,
                            steps=self.steps)

                if self.stop:
                    break
            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            self.logger.info('Epoch %3d: total training loss %.2f',
                             epoch_no + 1, epoch_loss)
        else:
            self.logger.info('Training ended after %3d epochs.', epoch_no + 1)
        self.logger.info(
            'Best validation result (greedy) at step '
            '%8d: %6.2f %s.', self.best_ckpt_iteration, self.best_ckpt_score,
            self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer
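
The batch_multiplier bookkeeping above implements gradient accumulation (see the Medium link in the comments): gradients from several small batches are summed before a single optimizer step, emulating a larger batch without extra memory. A minimal standalone sketch with toy data:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
accum_steps = 4  # plays the role of batch_multiplier
data = [(torch.randn(8, 4), torch.randn(8, 2)) for _ in range(8)]  # toy batches

optimizer.zero_grad()
for step, (x, y) in enumerate(data):
    loss = loss_fn(model(x), y) / accum_steps  # scale so the summed gradients
    loss.backward()                            # match one large batch
    if (step + 1) % accum_steps == 0:          # update every accum_steps batches
        optimizer.step()
        optimizer.zero_grad()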
Example #5
    def translate(self, message_text, model, src_vocab, trg_vocab, preprocess, postprocess,
              logger, beam_size, beam_alpha, level, lowercase,
              max_output_length, use_cuda):
      """
      Translate a text message.

      :param message_text: Slack command, could be text.
      :param model: The Joey NMT model.
      :param src_vocab: Source vocabulary.
      :param trg_vocab: Target vocabulary.
      :param preprocess: Preprocessing pipeline (a list).
      :param postprocess: Postprocessing pipeline (a list).
      :param logger: Logger instance.
      :param beam_size: Beam size for decoding.
      :param beam_alpha: Beam alpha for decoding.
      :param level: Segmentation level.
      :param lowercase: Lowercasing.
      :param max_output_length: Maximum output length.
      :param use_cuda: Using CUDA or not.
      :return: The translated response string.
      """
      sentence = message_text.strip()
      # remove emojis
      emoji_pattern = re.compile(r":[a-zA-Z]+:")
      sentence = re.sub(emoji_pattern, "", sentence)
      sentence = sentence.strip()
      if lowercase:
          sentence = sentence.lower()
      for p in preprocess:
          sentence = p(sentence)

      # load the data which consists only of this sentence
      test_data, src_vocab, trg_vocab = load_line_as_data(lowercase=lowercase,
          line=sentence, src_vocab=src_vocab, trg_vocab=trg_vocab, level=level)

      # generate outputs
      score, loss, ppl, sources, sources_raw, references, hypotheses, \
      hypotheses_raw, attention_scores = validate_on_data(
          model, data=test_data, batch_size=1, level=level,
          max_output_length=max_output_length, eval_metric=None,
          use_cuda=use_cuda, beam_size=beam_size,
          beam_alpha=beam_alpha, n_gpu=0)

      #  validate_on_data(model: Model, data: Dataset,
      #                batch_size: int,
      #                use_cuda: bool, max_output_length: int,
      #                level: str, eval_metric: Optional[str],
      #                n_gpu: int,
      #                batch_class: Batch = Batch,
      #                compute_loss: bool = False,
      #                beam_size: int = 1, beam_alpha: int = -1,
      #                batch_type: str = "sentence",
      #                postprocess: bool = True,
      #                bpe_type: str = "subword-nmt",
      #                sacrebleu: dict = None) \

      # post-process
      if level == "char":
          response = "".join(hypotheses)
      else:
          response = " ".join(hypotheses)

      for p in postprocess:
          response = p(response)

      return response
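
The post-processing step differs only in the join character: char-level segmentation produces one token per character, so pieces are concatenated without separators. A toy illustration of the two joins (outputs made up):

char_hypotheses = ["h", "a", "l", "l", "o"]   # char-level pieces
word_hypotheses = ["guten", "Tag"]            # word-level tokens
print("".join(char_hypotheses))   # hallo
print(" ".join(word_hypotheses))  # guten Tag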
Example #6
def train(cfg_file):
    """
    Main training function. After training, also test on test data if given.

    :param cfg_file: path to configuration yaml file
    :return:
    """
    cfg = load_config(cfg_file)
    # set the random seed
    # torch.backends.cudnn.deterministic = True
    seed = cfg["training"].get("random_seed", 42)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # load the data
    train_data, dev_data, test_data, src_vocab, trg_vocab = \
        load_data(cfg=cfg)

    # build an encoder-decoder model
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)

    # for training management, e.g. early stopping and model selection
    trainer = TrainManager(model=model, config=cfg)

    # store copy of original training config in model dir
    shutil.copy2(cfg_file, trainer.model_dir + "/config.yaml")

    # print config
    log_cfg(cfg, trainer.logger)

    log_data_info(train_data=train_data,
                  valid_data=dev_data,
                  test_data=test_data,
                  src_vocab=src_vocab,
                  trg_vocab=trg_vocab,
                  logging_function=trainer.logger.info)
    model.log_parameters_list(logging_function=trainer.logger.info)

    logging.info(model)

    # store the vocabs
    src_vocab_file = "{}/src_vocab.txt".format(cfg["training"]["model_dir"])
    src_vocab.to_file(src_vocab_file)
    trg_vocab_file = "{}/trg_vocab.txt".format(cfg["training"]["model_dir"])
    trg_vocab.to_file(trg_vocab_file)

    # train the model
    trainer.train_and_validate(train_data=train_data, valid_data=dev_data)

    if test_data is not None:
        trainer.load_checkpoint("{}/{}.ckpt".format(
            trainer.model_dir, trainer.best_ckpt_iteration))
        # test model
        if "testing" in cfg.keys():
            beam_size = cfg["testing"].get("beam_size", 0)
            beam_alpha = cfg["testing"].get("alpha", -1)
        else:
            beam_size = 0
            beam_alpha = -1

        # pylint: disable=unused-variable
        score, loss, ppl, sources, sources_raw, references, hypotheses, \
            hypotheses_raw, attention_scores = validate_on_data(
                data=test_data, batch_size=trainer.batch_size,
                eval_metric=trainer.eval_metric, level=trainer.level,
                max_output_length=trainer.max_output_length,
                model=model, use_cuda=trainer.use_cuda, criterion=None,
                beam_size=beam_size, beam_alpha=beam_alpha)

        if "trg" in test_data.fields:
            decoding_description = "Greedy decoding" if beam_size == 0 else \
                "Beam search decoding with beam size = {} and alpha = {}"\
                    .format(beam_size, beam_alpha)
            trainer.logger.info("Test data result: %f %s [%s]", score,
                                trainer.eval_metric, decoding_description)
        else:
            trainer.logger.info(
                "No references given for %s.%s -> no evaluation.",
                cfg["data"]["test"], cfg["data"]["src"])

        output_path_set = "{}/{}.{}".format(trainer.model_dir, "test",
                                            cfg["data"]["trg"])
        with open(output_path_set, mode="w", encoding="utf-8") as f:
            for h in hypotheses:
                f.write(h + "\n")
        trainer.logger.info("Test translations saved to: %s", output_path_set)
Example #7
    def train_and_validate(self, train_data, valid_data):
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        :return:
        """
        train_iter = make_data_iter(train_data,
                                    batch_size=self.batch_size,
                                    train=True,
                                    shuffle=self.shuffle)
        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)
            self.model.train()

            start = time.time()
            total_valid_duration = 0
            processed_tokens = self.total_tokens
            count = 0

            for batch in iter(train_iter):
                # reactivate training
                self.model.train()
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672
                update = count == 0
                # print(count, update, self.steps)
                batch_loss = self._train_batch(batch, update=update)
                count = self.batch_multiplier if update else count
                count -= 1

                # log learning progress
                if self.model.training and self.steps % self.logging_freq == 0 \
                        and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - processed_tokens
                    self.logger.info(
                        "Epoch %d Step: %d Loss: %f Tokens per Sec: %f",
                        epoch_no + 1, self.steps, batch_loss,
                        elapsed_tokens / elapsed)
                    start = time.time()
                    total_valid_duration = 0

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    valid_start_time = time.time()

                    valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores = \
                        validate_on_data(
                            batch_size=self.batch_size, data=valid_data,
                            eval_metric=self.eval_metric,
                            level=self.level, model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            criterion=self.criterion)

                    if self.ckpt_metric == "loss":
                        ckpt_score = valid_loss
                    elif self.ckpt_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.ckpt_metric)
                        new_best = True
                        self.save_checkpoint()

                    # pass validation score or loss or ppl to scheduler
                    if self.schedule_metric == "loss":
                        # schedule based on loss
                        schedule_score = valid_loss
                    elif self.schedule_metric in ["ppl", "perplexity"]:
                        # schedule based on perplexity
                        schedule_score = valid_ppl
                    else:
                        # schedule based on evaluation score
                        schedule_score = valid_score
                    if self.scheduler is not None:
                        self.scheduler.step(schedule_score)

                    # append to validation report
                    self._add_report(valid_score=valid_score,
                                     valid_loss=valid_loss,
                                     valid_ppl=valid_ppl,
                                     eval_metric=self.eval_metric,
                                     new_best=new_best)

                    # always print first x sentences
                    for p in range(self.print_valid_sents):
                        self.logger.debug("Example #%d", p)
                        self.logger.debug("\tRaw source: %s",
                                          valid_sources_raw[p])
                        self.logger.debug("\tSource: %s", valid_sources[p])
                        self.logger.debug("\tReference: %s",
                                          valid_references[p])
                        self.logger.debug("\tRaw hypothesis: %s",
                                          valid_hypotheses_raw[p])
                        self.logger.debug("\tHypothesis: %s",
                                          valid_hypotheses[p])
                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        'Validation result at epoch %d, step %d: %s: %f, '
                        'loss: %f, ppl: %f, duration: %.4fs', epoch_no + 1,
                        self.steps, self.eval_metric, valid_score, valid_loss,
                        valid_ppl, valid_duration)

                    # store validation set outputs
                    self.store_outputs(valid_hypotheses)

                    # store attention plots for first three sentences of
                    # valid data and one randomly chosen example
                    store_attention_plots(attentions=valid_attention_scores,
                                          targets=valid_hypotheses_raw,
                                          sources=[s for s in valid_data.src],
                                          idx=[
                                              0, 1, 2,
                                              np.random.randint(
                                                  0, len(valid_hypotheses))
                                          ],
                                          output_prefix="{}/att.{}".format(
                                              self.model_dir, self.steps))

                if self.stop:
                    break
            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break
        else:
            self.logger.info('Training ended after %d epochs.', epoch_no + 1)
        self.logger.info('Best validation result at step %d: %f %s.',
                         self.best_ckpt_iteration, self.best_ckpt_score,
                         self.ckpt_metric)
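
The scheduler.step(schedule_score) call implies a plateau-style scheduler that watches a validation metric. A minimal sketch with PyTorch's ReduceLROnPlateau, assuming the schedule metric is minimized (loss or perplexity):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# mode="min": a lower validation loss/ppl counts as an improvement
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2)

for valid_loss in [2.0, 1.9, 1.9, 1.9, 1.9, 1.9]:  # made-up validation losses
    scheduler.step(valid_loss)
    print(optimizer.param_groups[0]["lr"])  # halves once patience is exhausted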
Example #8
    def _validate(self, valid_data, epoch_no):
        valid_start_time = time.time()

        valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        valid_hypotheses_raw, valid_attention_scores = \
            validate_on_data(
                batch_size=self.eval_batch_size,
                batch_class=self.batch_class,
                data=valid_data,
                eval_metric=self.eval_metric,
                level=self.level, model=self.model,
                use_cuda=self.use_cuda,
                max_output_length=self.max_output_length,
                compute_loss=True,
                beam_size=1,                # greedy validations
                batch_type=self.eval_batch_type,
                postprocess=True,           # always remove BPE for validation
                bpe_type=self.bpe_type,     # "subword-nmt" or "sentencepiece"
                sacrebleu=self.sacrebleu,   # sacrebleu options
                n_gpu=self.n_gpu
            )

        self.tb_writer.add_scalar("valid/valid_loss", valid_loss,
                                  self.stats.steps)
        self.tb_writer.add_scalar("valid/valid_score", valid_score,
                                  self.stats.steps)
        self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                  self.stats.steps)

        if self.early_stopping_metric == "loss":
            ckpt_score = valid_loss
        elif self.early_stopping_metric in ["ppl", "perplexity"]:
            ckpt_score = valid_ppl
        else:
            ckpt_score = valid_score

        if self.scheduler is not None \
                and self.scheduler_step_at == "validation":
            self.scheduler.step(ckpt_score)

        new_best = False
        if self.stats.is_best(ckpt_score):
            self.stats.best_ckpt_score = ckpt_score
            self.stats.best_ckpt_iter = self.stats.steps
            logger.info('Hooray! New best validation result [%s]!',
                        self.early_stopping_metric)
            if self.ckpt_queue.maxlen > 0:
                logger.info("Saving new checkpoint.")
                new_best = True
                self._save_checkpoint(new_best)
        elif self.save_latest_checkpoint:
            self._save_checkpoint(new_best)

        # append to validation report
        self._add_report(valid_score=valid_score,
                         valid_loss=valid_loss,
                         valid_ppl=valid_ppl,
                         eval_metric=self.eval_metric,
                         new_best=new_best)

        self._log_examples(sources_raw=[v for v in valid_sources_raw],
                           sources=valid_sources,
                           hypotheses_raw=valid_hypotheses_raw,
                           hypotheses=valid_hypotheses,
                           references=valid_references)

        valid_duration = time.time() - valid_start_time
        logger.info(
            'Validation result (greedy) at epoch %3d, '
            'step %8d: %s: %6.2f, loss: %8.4f, ppl: %8.4f, '
            'duration: %.4fs', epoch_no + 1, self.stats.steps,
            self.eval_metric, valid_score, valid_loss, valid_ppl,
            valid_duration)

        # store validation set outputs
        self._store_outputs(valid_hypotheses)

        # store attention plots for selected valid sentences
        if valid_attention_scores:
            store_attention_plots(attentions=valid_attention_scores,
                                  targets=valid_hypotheses_raw,
                                  sources=[s for s in valid_data.src],
                                  indices=self.log_valid_sents,
                                  output_prefix="{}/att.{}".format(
                                      self.model_dir, self.stats.steps),
                                  tb_writer=self.tb_writer,
                                  steps=self.stats.steps)

        return valid_duration
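
The ckpt_score selection feeds self.stats.is_best; the comparison behind it presumably flips direction depending on the metric, along these lines (a sketch, not the actual TrainStatistics code):

def is_best(score, best_so_far, minimize_metric):
    # loss/ppl improve downward; scores like BLEU improve upward
    return score < best_so_far if minimize_metric else score > best_so_far

print(is_best(1.8, 2.0, minimize_metric=True))     # True: lower loss is better
print(is_best(25.0, 27.0, minimize_metric=False))  # False: lower BLEU is worse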
Example #9
def Q_learning(cfg_file: str) -> None:
    """
    Main training function. After training, also test on test data if given.
    :param cfg_file: path to configuration yaml file
    """
    cfg = load_config(cfg_file)  # config is a dict
    # make logger
    model_dir = make_model_dir(cfg["training"]["model_dir"],
                               overwrite=cfg["training"].get(
                                   "overwrite", False))
    _ = make_logger(model_dir, mode="train")  # version string returned
    # TODO: save version number in model checkpoints

    # set the random seed
    set_seed(seed=cfg["training"].get("random_seed", 42))

    # load the data
    print("loadding data here")
    train_data, dev_data, test_data, src_vocab, trg_vocab = load_data(
        data_cfg=cfg["data"])
    # The training data is filtered to include sentences up to `max_sent_length`
    #     on source and target side.

    # training config:
    train_config = cfg["training"]
    shuffle = train_config.get("shuffle", True)
    batch_size = train_config["batch_size"]
    mini_BATCH_SIZE = train_config["mini_batch_size"]
    batch_type = train_config.get("batch_type", "sentence")
    outer_epochs = train_config.get("outer_epochs", 10)
    inner_epochs = train_config.get("inner_epochs", 10)
    TARGET_UPDATE = train_config.get("target_update", 10)
    Gamma = train_config.get("Gamma", 0.999)
    use_cuda = train_config["use_cuda"] and torch.cuda.is_available()

    # validation part config
    # validation
    validation_freq = train_config.get("validation_freq", 1000)
    ckpt_queue = queue.Queue(maxsize=train_config.get("keep_last_ckpts", 5))
    eval_batch_size = train_config.get("eval_batch_size", batch_size)
    level = cfg["data"]["level"]

    eval_metric = train_config.get("eval_metric", "bleu")
    n_gpu = torch.cuda.device_count() if use_cuda else 0
    eval_batch_type = train_config.get("eval_batch_type", batch_type)
    # eval options
    test_config = cfg["testing"]
    bpe_type = test_config.get("bpe_type", "subword-nmt")
    sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
    max_output_length = train_config.get("max_output_length", None)
    minimize_metric = True
    # initialize training statistics
    stats = TrainStatistics(
        steps=0,
        stop=False,
        total_tokens=0,
        best_ckpt_iter=0,
        best_ckpt_score=np.inf if minimize_metric else -np.inf,
        minimize_metric=minimize_metric)

    early_stopping_metric = train_config.get("early_stopping_metric",
                                             "eval_metric")

    if early_stopping_metric in ["ppl", "loss"]:
        stats.minimize_metric = True
        stats.best_ckpt_score = np.inf
    elif early_stopping_metric == "eval_metric":
        if eval_metric in [
                "bleu", "chrf", "token_accuracy", "sequence_accuracy"
        ]:
            stats.minimize_metric = False
            stats.best_ckpt_score = -np.inf

        # eval metric that has to get minimized (not yet implemented)
        else:
            stats.minimize_metric = True

    # data loader (modified from the train_and_validate function):
    # returns a torchtext iterator for a torchtext dataset
    # containing src and optionally trg
    train_iter = make_data_iter(train_data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                train=True,
                                shuffle=shuffle)

    # initialize the Replay Memory D with capacity N
    memory = ReplayMemory(10000)
    steps_done = 0

    # initialize two DQN networks
    policy_net = build_model(cfg["model"],
                             src_vocab=src_vocab,
                             trg_vocab=trg_vocab)  # Q_network
    target_net = build_model(cfg["model"],
                             src_vocab=src_vocab,
                             trg_vocab=trg_vocab)  # Q_hat_network
    #logger.info(policy_net.src_vocab.stoi)
    #print("###############trg vocab: ", len(target_net.trg_vocab.stoi))
    #print("trg embed: ", target_net.trg_embed.vocab_size)
    if use_cuda:
        policy_net.cuda()
        target_net.cuda()

    target_net.load_state_dict(policy_net.state_dict())
    # Initialize target net Q_hat with weights equal to policy_net

    target_net.eval()  # target_net not update the parameters, test mode

    # Optimizer
    optimizer = build_optimizer(config=cfg["training"],
                                parameters=policy_net.parameters())
    # Loss function
    mse_loss = torch.nn.MSELoss()

    pad_index = policy_net.pad_index
    # print('!!!'*10, pad_index)

    cross_entropy_loss = XentLoss(pad_index=pad_index)
    policy_net.loss_function = cross_entropy_loss

    # learning rate scheduling
    scheduler, scheduler_step_at = build_scheduler(
        config=train_config,
        scheduler_mode="min" if minimize_metric else "max",
        optimizer=optimizer,
        hidden_size=cfg["model"]["encoder"]["hidden_size"])

    # model parameters
    if "load_model" in train_config.keys():
        load_model_path = train_config["load_model"]
        reset_best_ckpt = train_config.get("reset_best_ckpt", False)
        reset_scheduler = train_config.get("reset_scheduler", False)
        reset_optimizer = train_config.get("reset_optimizer", False)
        reset_iter_state = train_config.get("reset_iter_state", False)

        print('settings', reset_best_ckpt, reset_iter_state, reset_optimizer,
              reset_scheduler)

        logger.info("Loading model from %s", load_model_path)
        model_checkpoint = load_checkpoint(path=load_model_path,
                                           use_cuda=use_cuda)

        # restore model and optimizer parameters
        policy_net.load_state_dict(model_checkpoint["model_state"])

        if not reset_optimizer:
            optimizer.load_state_dict(model_checkpoint["optimizer_state"])
        else:
            logger.info("Reset optimizer.")
        if not reset_scheduler:
            if model_checkpoint["scheduler_state"] is not None and \
                    scheduler is not None:
                scheduler.load_state_dict(model_checkpoint["scheduler_state"])
        else:
            logger.info("Reset scheduler.")

        if not reset_best_ckpt:
            stats.best_ckpt_score = model_checkpoint["best_ckpt_score"]
            stats.best_ckpt_iter = model_checkpoint["best_ckpt_iteration"]
            print('stats.best_ckpt_score', stats.best_ckpt_score)
            print('stats.best_ckpt_iter', stats.best_ckpt_iter)
        else:
            logger.info("Reset tracking of the best checkpoint.")

        if (not reset_iter_state and model_checkpoint.get(
                'train_iter_state', None) is not None):
            train_iter_state = model_checkpoint["train_iter_state"]

        # move parameters to cuda

        target_net.load_state_dict(policy_net.state_dict())
        # Initialize target net Q_hat with weights equal to policy_net

        target_net.eval()

        if use_cuda:
            policy_net.cuda()
            target_net.cuda()

    for i_episode in range(outer_epochs):
        # Outer loop

        # get batch
        for i, batch in enumerate(iter(train_iter)):  # joeynmt training.py 377

            # create a Batch object from torchtext batch
            # ( use class Batch from batch.py)
            # return the sentences same length (with padding) in one batch
            batch = Batch(batch, policy_net.pad_index, use_cuda=use_cuda)
            # we want to get batch.src and batch.trg
            # the shape of batch.src: (batch_size * length of the sentence)

            # source here is represented by the word index not word embedding.

            encoder_output_batch, _, _, _ = policy_net(
                return_type="encode",
                src=batch.src,
                src_length=batch.src_length,
                src_mask=batch.src_mask,
            )

            trans_output_batch, _ = transformer_greedy(
                src_mask=batch.src_mask,
                max_output_length=max_output_length,
                model=policy_net,
                encoder_output=encoder_output_batch,
                steps_done=steps_done,
                use_cuda=use_cuda)
            #print('steps_done',steps_done)

            steps_done += 1

            #print('trans_output_batch.shape is:', trans_output_batch.shape)
            # batch_size * max_translation_sentence_length
            #print('batch.src', batch.src)
            #print('batch.trg', batch.trg)
            print('batch.trg.shape is:', batch.trg.shape)
            print('trans_output_batch', trans_output_batch)

            reward_batch = [
            ]  # Get the reward_batch (Get the bleu score of the sentences in a batch)

            # use j here to avoid shadowing the outer batch index i
            for j in range(int(batch.src.shape[0])):
                all_outputs = [(trans_output_batch[j])[1:]]
                all_ref = [batch.trg[j]]
                sentence_score = calculate_bleu(model=policy_net,
                                                level=level,
                                                raw_hypo=all_outputs,
                                                raw_ref=all_ref)
                reward_batch.append(sentence_score)

            print('reward batch is', reward_batch)
            reward_batch = torch.tensor(reward_batch, dtype=torch.float)

            # reward_batch = bleu(hypotheses, references, tokenize="13a")
            # print('reward_batch.shape', reward_batch.shape)

            # make prefix and push tuples into memory
            push_sample_to_memory(model=policy_net,
                                  level=level,
                                  eos_index=policy_net.eos_index,
                                  memory=memory,
                                  src_batch=batch.src,
                                  trg_batch=batch.trg,
                                  trans_output_batch=trans_output_batch,
                                  reward_batch=reward_batch,
                                  max_output_length=max_output_length)
            print(memory.capacity, len(memory.memory))

            if len(memory.memory) == memory.capacity:
                # inner loop
                for t in range(inner_epochs):
                    # Sample mini-batch from the memory
                    transitions = memory.sample(mini_BATCH_SIZE)
                    # transition = [Transition(source=array([]), prefix=array([]), next_word= int, reward= int),
                    #               Transition(source=array([]), prefix=array([]), next_word= int, reward= int,...]
                    # Each Transition is what we push into memory for one sentence: memory.push(source, prefix, next_word, reward_batch[i])
                    mini_batch = Transition(*zip(*transitions))
                    # merge the same class in transition together
                    # mini_batch = Transition(source=(array([]), array([]),...), prefix=(array([],...),
                    #               next_word=array([...]), reward=array([...]))
                    # mini_batch.reward is tuple: length is mini_BATCH_SIZE.
                    #print('mini_batch', mini_batch)

                    #concatenate together into a tensor.
                    words = []
                    for word in mini_batch.next_word:
                        new_word = word.unsqueeze(0)
                        words.append(new_word)
                    mini_next_word = torch.cat(
                        words)  # shape (mini_BATCH_SIZE,)
                    mini_reward = torch.tensor(
                        mini_batch.reward)  # shape (mini_BATCH_SIZE,)

                    #print('mini_batch.finish', mini_batch.finish)

                    mini_is_eos = torch.Tensor(mini_batch.finish)
                    #print(mini_is_eos)

                    mini_src_length = [
                        len(item) for item in mini_batch.source_sentence
                    ]
                    mini_src_length = torch.Tensor(mini_src_length)

                    mini_src = pad_sequence(mini_batch.source_sentence,
                                            batch_first=True,
                                            padding_value=float(pad_index))
                    # shape (mini_BATCH_SIZE, max_length_src)

                    length_prefix = [len(item) for item in mini_batch.prefix]
                    mini_prefix_length = torch.Tensor(length_prefix)

                    prefix_list = []
                    for prefix_ in mini_batch.prefix:
                        prefix_ = torch.from_numpy(prefix_)
                        prefix_list.append(prefix_)

                    mini_prefix = pad_sequence(prefix_list,
                                               batch_first=True,
                                               padding_value=pad_index)
                    # shape (mini_BATCH_SIZE, max_length_prefix)

                    mini_src_mask = (mini_src != pad_index).unsqueeze(1)
                    mini_trg_mask = (mini_prefix != pad_index).unsqueeze(1)

                    #print('mini_src',  mini_src)
                    #print('mini_src_length', mini_src_length)
                    #print('mini_src_mask', mini_src_mask)
                    #print('mini_prefix', mini_prefix)
                    #print('mini_trg_mask', mini_trg_mask)

                    #print('mini_reward', mini_reward)

                    # max_length_src = torch.max(mini_src_length) #max([len(item) for item in mini_batch.source_sentence])

                    if use_cuda:
                        mini_src = mini_src.cuda()
                        mini_prefix = mini_prefix.cuda()
                        mini_src_mask = mini_src_mask.cuda()
                        mini_src_length = mini_src_length.cuda()
                        mini_trg_mask = mini_trg_mask.cuda()
                        mini_next_word = mini_next_word.cuda()

                    # print(next(policy_net.parameters()).is_cuda)
                    # print(mini_trg_mask.get_device())
                    # calculate the Q_value
                    logits_Q, _, _, _ = policy_net._encode_decode(
                        src=mini_src,
                        trg_input=mini_prefix,
                        src_mask=mini_src_mask,
                        src_length=mini_src_length,
                        trg_mask=
                        mini_trg_mask  # trg_mask = (self.trg_input != pad_index).unsqueeze(1)
                    )
                    #print('mini_prefix_length', mini_prefix_length)

                    #print('logits_Q.shape', logits_Q.shape) # torch.Size([64, 99, 31716])
                    #print('logits_Q', logits_Q)

                    # length_prefix = max([len(item) for item in mini_batch.prefix])
                    # logits_Q shape: batch_size * length of the sentence * total number of words in corpus.
                    logits_Q = logits_Q[range(mini_BATCH_SIZE),
                                        mini_prefix_length.long() - 1, :]
                    #print('logits_Q_.shape', logits_Q.shape) #shape(mini_batch_size, num_words)
                    # logits shape: mini_batch_size * total number of words in corpus
                    Q_value = logits_Q[range(mini_BATCH_SIZE), mini_next_word]
                    #print('mini_next_word', mini_next_word)
                    #print("Q_value", Q_value)

                    mini_prefix_add = torch.cat(
                        [mini_prefix, mini_next_word.unsqueeze(1)], dim=1)
                    #print('mini_prefix_add', mini_prefix_add)
                    mini_trg_mask_add = (mini_prefix_add !=
                                         pad_index).unsqueeze(1)
                    #print('mini_trg_mask_add', mini_trg_mask_add)

                    if use_cuda:
                        mini_prefix_add = mini_prefix_add.cuda()
                        mini_trg_mask_add = mini_trg_mask_add.cuda()

                    logits_Q_hat, _, _, _ = target_net._encode_decode(
                        src=mini_src,
                        trg_input=mini_prefix_add,
                        src_mask=mini_src_mask,
                        src_length=mini_src_length,
                        trg_mask=mini_trg_mask_add)
                    #print('mini_prefix_add.shape', mini_prefix_add.shape)
                    #print('logits_Q_hat.shape', logits_Q_hat.shape)
                    #print('mini_prefix_length.long()', mini_prefix_length.long())
                    logits_Q_hat = logits_Q_hat[range(mini_BATCH_SIZE),
                                                mini_prefix_length.long(), :]
                    Q_hat_value, _ = torch.max(logits_Q_hat, dim=1)
                    #print('Q_hat_value', Q_hat_value)

                    if use_cuda:

                        Q_hat_value = Q_hat_value.cuda()
                        mini_reward = mini_reward.cuda()
                        mini_is_eos = mini_is_eos.cuda()

                    yj = mini_reward.float() + Gamma * Q_hat_value
                    #print('yj', yj)
                    # boolean mask: a long tensor would be treated as
                    # positional indices rather than a mask
                    eos_mask = mini_is_eos.bool()
                    #print('mini_is_eos', mini_is_eos)
                    yj[eos_mask] = mini_reward[eos_mask]
                    #print('yj', yj)
                    #print('Q_value1', Q_value)

                    yj = yj.detach()  # detach is not in-place; rebind so the targets carry no gradient
                    # Optimize the model
                    policy_net.zero_grad()

                    # Compute loss
                    loss = mse_loss(yj, Q_value)
                    print('loss', loss)
                    logger.info("step = {}, loss = {}".format(
                        stats.steps, loss.item()))
                    loss.backward()
                    #for param in policy_net.parameters():
                    #   param.grad.data.clamp_(-1, 1)
                    optimizer.step()

                    stats.steps += 1
                    #print('step', stats.steps)

                    if stats.steps % TARGET_UPDATE == 0:
                        #print('update the parameters in target_net.')
                        target_net.load_state_dict(policy_net.state_dict())

                    if stats.steps % validation_freq == 0:  # Validation
                        print('Start validation')

                        valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores = \
                            validate_on_data(
                                model=policy_net,
                                data=dev_data,
                                batch_size=eval_batch_size,
                                use_cuda=use_cuda,
                                level=level,
                                eval_metric=eval_metric,
                                n_gpu=n_gpu,
                                compute_loss=True,
                                beam_size=1,
                                beam_alpha=-1,
                                batch_type=eval_batch_type,
                                postprocess=True,
                                bpe_type=bpe_type,
                                sacrebleu=sacrebleu,
                                max_output_length=max_output_length
                            )
                        print(
                            'validation_loss: {}, validation_score: {}'.format(
                                valid_loss, valid_score))
                        logger.info(valid_loss)
                        print('average loss: total_loss/n_tokens:', valid_ppl)

                        if early_stopping_metric == "loss":
                            ckpt_score = valid_loss
                        elif early_stopping_metric in ["ppl", "perplexity"]:
                            ckpt_score = valid_ppl
                        else:
                            ckpt_score = valid_score
                        if stats.is_best(ckpt_score):
                            stats.best_ckpt_score = ckpt_score
                            stats.best_ckpt_iter = stats.steps
                            logger.info(
                                'Hooray! New best validation result [%s]!',
                                early_stopping_metric)
                            if ckpt_queue.maxsize > 0:
                                logger.info("Saving new checkpoint.")

                                # Inlined from _save_checkpoint(): save the
                                # model's current parameters and the training
                                # state to a checkpoint. The training state
                                # contains the total number of training steps
                                # and tokens, the best checkpoint score and
                                # iteration so far, and the optimizer state.
                                model_path = "{}/{}.ckpt".format(
                                    model_dir, stats.steps)
                                model_state_dict = policy_net.module.state_dict() \
                                    if isinstance(policy_net, torch.nn.DataParallel) \
                                    else policy_net.state_dict()
                                state = {
                                    "steps": stats.steps,
                                    "total_tokens": stats.total_tokens,
                                    "best_ckpt_score": stats.best_ckpt_score,
                                    "best_ckpt_iteration":
                                    stats.best_ckpt_iter,
                                    "model_state": model_state_dict,
                                    "optimizer_state": optimizer.state_dict(),
                                    # "scheduler_state": scheduler.state_dict() if
                                    # self.scheduler is not None else None,
                                    # 'amp_state': amp.state_dict() if self.fp16 else None
                                }
                                torch.save(state, model_path)
                                if ckpt_queue.full():
                                    to_delete = ckpt_queue.get(
                                    )  # delete oldest ckpt
                                    try:
                                        os.remove(to_delete)
                                    except FileNotFoundError:
                                        logger.warning(
                                            "Wanted to delete old checkpoint %s but "
                                            "file does not exist.", to_delete)

                                ckpt_queue.put(model_path)

                                best_path = "{}/best.ckpt".format(model_dir)
                                try:
                                    # create/modify symbolic link for best checkpoint
                                    symlink_update(
                                        "{}.ckpt".format(stats.steps),
                                        best_path)
                                except OSError:
                                    # overwrite best.ckpt
                                    torch.save(state, best_path)
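
The inner loop computes the standard DQN target: y_j = r_j for terminal transitions, and y_j = r_j + Gamma * max_a Q_hat(s', a) otherwise. An isolated sketch of that computation with toy tensors, using a boolean end-of-sentence mask as in the fix above:

import torch

gamma = 0.999
reward = torch.tensor([0.3, 0.7, 0.1])       # toy per-sentence BLEU rewards
q_hat_max = torch.tensor([1.2, 0.9, 1.5])    # max_a Q_hat(s', a) from target net
is_eos = torch.tensor([False, True, False])  # did the transition end at <eos>?

yj = reward + gamma * q_hat_max  # bootstrap target for non-terminal steps
yj[is_eos] = reward[is_eos]      # terminal steps use the reward alone
yj = yj.detach()                 # targets must carry no gradient
print(yj)  # tensor([1.4988, 0.7000, 1.5985])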
Example #10
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset):
        """
        Train the model and validate it on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(
            train_data,
            batch_size=self.batch_size,
            batch_type=self.batch_type,
            train=True,
            shuffle=self.shuffle)
        for epoch_no in range(1, self.epochs + 1):
            self.logger.info("EPOCH %d", epoch_no)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no - 1)  # 0-based indexing

            self.model.train()

            start = time.time()
            total_valid_duration = 0
            processed_tokens = self.total_tokens
            epoch_loss = 0

            for i, batch in enumerate(iter(train_iter), 1):
                # reactivate training
                self.model.train()
                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672
                update = i % self.batch_multiplier == 0
                batch_loss = self._train_batch(batch, update=update)

                self.log_tensorboard("train", batch_loss=batch_loss)

                epoch_loss += batch_loss.detach().cpu().numpy()

                if self.scheduler is not None and \
                        self.scheduler_step_at == "step" and update:
                    self.scheduler.step()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - processed_tokens
                    self.logger.info(
                        "Epoch %3d Step: %8d Batch Loss: %12.6f "
                        "Tokens per Sec: %8.0f, Lr: %.6f",
                        epoch_no, self.steps, batch_loss,
                        elapsed_tokens / elapsed,
                        self.optimizer.param_groups[0]["lr"])
                    start = time.time()
                    total_valid_duration = 0
                    processed_tokens = self.total_tokens

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    valid_start_time = time.time()

                    # it would be nice to include loss and ppl in valid_scores
                    valid_scores, valid_sources, valid_sources_raw, \
                        valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores, \
                        scores_by_lang, by_lang = validate_on_data(
                            batch_size=self.eval_batch_size,
                            data=valid_data,
                            eval_metrics=self.eval_metrics,
                            attn_metrics=self.attn_metrics,
                            src_level=self.src_level,
                            trg_level=self.trg_level,
                            model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            loss_function=self.loss,
                            beam_size=0,  # greedy validations
                            batch_type=self.eval_batch_type,
                            save_attention=self.plot_attention,
                            log_sparsity=self.log_sparsity,
                            apply_mask=self.valid_apply_mask
                        )

                    ckpt_score = valid_scores[self.early_stopping_metric]
                    self.log_tensorboard("valid", **valid_scores)

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    if self.scheduler is not None \
                            and self.scheduler_step_at == "validation":
                        self.scheduler.step(ckpt_score)

                    # append to validation report
                    self._add_report(
                        valid_scores=valid_scores,
                        eval_metrics=self.eval_metrics,
                        new_best=new_best)

                    self._log_examples(
                        sources_raw=valid_sources_raw,
                        sources=valid_sources,
                        hypotheses_raw=valid_hypotheses_raw,
                        hypotheses=valid_hypotheses,
                        references=valid_references
                    )

                    labeled_scores = sorted(valid_scores.items())
                    eval_report = ", ".join("{}: {:.5f}".format(n, v)
                                            for n, v in labeled_scores)

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration

                    self.logger.info(
                        'Validation result at epoch %3d, step %8d: %s, '
                        'duration: %.4fs',
                        epoch_no, self.steps, eval_report, valid_duration)

                    if scores_by_lang is not None:
                        for metric, scores in scores_by_lang.items():
                            # make a report
                            lang_report = [metric]
                            numbers = sorted(scores.items())
                            lang_report.extend(["{}: {:.5f}".format(k, v)
                                                for k, v in numbers])

                            self.logger.info("\n\t".join(lang_report))

                    # store validation set outputs
                    self._store_outputs(valid_hypotheses)

                    # store attention plots for selected valid sentences
                    if valid_attention_scores and self.plot_attention:
                        store_attention_plots(
                                attentions=valid_attention_scores,
                                sources=[s for s in valid_data.src],
                                targets=valid_hypotheses_raw,
                                indices=self.log_valid_sents,
                                model_dir=self.model_dir,
                                tb_writer=self.tb_writer,
                                steps=self.steps)

                if self.stop:
                    break
            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            self.logger.info(
                'Epoch %3d: total training loss %.2f', epoch_no, epoch_loss)
        else:
            self.logger.info('Training ended after %3d epochs.', epoch_no)
        self.logger.info('Best validation result at step %8d: %6.2f %s.',
                         self.best_ckpt_iteration, self.best_ckpt_score,
                         self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer
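The batch_multiplier logic in the loop above is plain gradient accumulation: losses from several mini-batches are backpropagated into the same gradient buffers, and the optimizer steps only once per group, emulating a larger effective batch without extra memory. A self-contained sketch of the pattern with a toy model (illustrative names, not Joey NMT code):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
batch_multiplier = 4  # effective batch = 4 mini-batches

optimizer.zero_grad()
for i, (x, y) in enumerate(
        [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(16)], 1):
    loss = loss_fn(model(x), y) / batch_multiplier  # scale so gradients average
    loss.backward()                                 # accumulate into .grad
    if i % batch_multiplier == 0:                   # update every k-th batch
        optimizer.step()
        optimizer.zero_grad()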
Beispiel #11
0
def train(cfg_file: str) -> None:
    """
    Main training function. After training, also test on test data if given.

    :param cfg_file: path to configuration yaml file
    """
    cfg = load_config(cfg_file)

    # set the random seed
    set_seed(seed=cfg["training"].get("random_seed", 42))

    # load the data
    train_data, dev_data, test_data, src_vocab, trg_vocab = load_data(
        data_cfg=cfg["data"])

    # build an encoder-decoder model
    model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)

    # for training management, e.g. early stopping and model selection
    trainer = TrainManager(model=model, config=cfg)

    # store copy of original training config in model dir
    shutil.copy2(cfg_file, trainer.model_dir + "/config.yaml")

    # log all entries of config
    log_cfg(cfg, trainer.logger)

    log_data_info(train_data=train_data,
                  valid_data=dev_data,
                  test_data=test_data,
                  src_vocab=src_vocab,
                  trg_vocab=trg_vocab,
                  logging_function=trainer.logger.info)

    # store the vocabs
    src_vocab_file = "{}/src_vocab.txt".format(cfg["training"]["model_dir"])
    src_vocab.to_file(src_vocab_file)
    trg_vocab_file = "{}/trg_vocab.txt".format(cfg["training"]["model_dir"])
    trg_vocab.to_file(trg_vocab_file)

    # train the model
    trainer.train_and_validate(train_data=train_data, valid_data=dev_data)

    # test the model with the best checkpoint
    if test_data is not None:

        # load checkpoint
        if trainer.best_ckpt_iteration > 0:
            checkpoint_path = "{}/{}.ckpt".format(trainer.model_dir,
                                                  trainer.best_ckpt_iteration)
        else:
            # fall back to the most recent checkpoint written via save_freq
            checkpoint_path = get_latest_checkpoint(trainer.model_dir)
        try:
            trainer.init_from_checkpoint(checkpoint_path)
        except AssertionError:
            trainer.logger.warning(
                "Checkpoint %s does not exist. "
                "Skipping testing.", checkpoint_path)
            if trainer.best_ckpt_iteration == 0 \
                and trainer.best_ckpt_score in [np.inf, -np.inf]:
                trainer.logger.warning(
                    "It seems like no checkpoint was written, "
                    "since no improvement was obtained over the initial model."
                )
            return

        # generate hypotheses for test data
        if "testing" in cfg.keys():
            beam_size = cfg["testing"].get("beam_size", 0)
            beam_alpha = cfg["testing"].get("alpha", -1)
            return_logp = cfg["testing"].get("return_logp", False)
        else:
            beam_size = 0
            beam_alpha = -1
            return_logp = False

        # pylint: disable=unused-variable
        score, loss, ppl, sources, sources_raw, references, hypotheses, \
            hypotheses_raw, attention_scores, log_probs = validate_on_data(
                data=test_data, batch_size=trainer.batch_size,
                eval_metric=trainer.eval_metric, level=trainer.level,
                max_output_length=trainer.max_output_length,
                model=model, use_cuda=trainer.use_cuda, loss_function=None,
                beam_size=beam_size, beam_alpha=beam_alpha,
                return_logp=return_logp)

        if "trg" in test_data.fields:
            decoding_description = "Greedy decoding" if beam_size == 0 else \
                "Beam search decoding with beam size = {} and alpha = {}"\
                    .format(beam_size, beam_alpha)
            trainer.logger.info("Test data result: %f %s [%s]", score,
                                trainer.eval_metric, decoding_description)
        else:
            trainer.logger.info(
                "No references given for %s.%s -> no evaluation.",
                cfg["data"]["test"], cfg["data"]["src"])

        output_path_set = "{}/{}.{}".format(trainer.model_dir, "test",
                                            cfg["data"]["trg"])
        with open(output_path_set, mode="w", encoding="utf-8") as f:
            for h in hypotheses:
                f.write("{}\n".format(h))
        trainer.logger.info("Test translations saved to: %s", output_path_set)

        if return_logp:
            output_path_set_logp = output_path_set + ".logp"
            with open(output_path_set_logp, mode="w", encoding="utf-8") as f:
                for lp in log_probs:
                    f.write("{}\n".format(lp))
            trainer.logger.info("Test log probs saved to: %s",
                                output_path_set_logp)
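When no new best checkpoint was recorded, the test branch above falls back to get_latest_checkpoint. A plausible sketch of that helper, assuming checkpoints are stored as *.ckpt files in the model directory and ordered by modification time (both assumptions; the real helper may key on step numbers instead):

import glob
import os


def get_latest_checkpoint(ckpt_dir: str):
    """Return the most recently modified *.ckpt in ckpt_dir, or None."""
    checkpoints = glob.glob(os.path.join(ckpt_dir, "*.ckpt"))
    if not checkpoints:
        return None
    return max(checkpoints, key=os.path.getmtime)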
Beispiel #12
0
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset) \
            -> None:
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(train_data,
                                    batch_size=self.batch_size,
                                    train=True,
                                    shuffle=self.shuffle)
        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no)

            self.model.train()

            start = time.time()
            total_valid_duration = 0
            processed_tokens = self.total_tokens
            count = self.batch_multiplier - 1  # first update after batch_multiplier batches
            epoch_loss = 0

            for batch in iter(train_iter):
                # reactivate training
                self.model.train()
                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672
                update = count == 0
                batch_loss = self._train_batch(batch, update=update)
                self.tb_writer.add_scalar("train/train_batch_loss", batch_loss,
                                          self.steps)
                count = self.batch_multiplier if update else count
                count -= 1
                epoch_loss += batch_loss.detach().cpu().numpy()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - processed_tokens
                    self.logger.info(
                        "Epoch %d Step: %d Batch Loss: %f Tokens per Sec: %f",
                        epoch_no + 1, self.steps, batch_loss,
                        elapsed_tokens / elapsed)
                    start = time.time()
                    total_valid_duration = 0

                # validate on the entire dev set
                if valid_data is not None and \
                    self.steps % self.validation_freq == 0 and update:

                    valid_start_time = time.time()

                    valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores, \
                        valid_logps = validate_on_data(
                            batch_size=self.batch_size, data=valid_data,
                            eval_metric=self.eval_metric,
                            level=self.level, model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            loss_function=self.loss,
                            return_logp=self.return_logp)

                    self.tb_writer.add_scalar("valid/valid_loss", valid_loss,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_score", valid_score,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                              self.steps)

                    if self.early_stopping_metric == "loss":
                        ckpt_score = valid_loss
                    elif self.early_stopping_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    if self.scheduler is not None \
                            and self.scheduler_step_at == "validation":
                        self.scheduler.step(ckpt_score)

                    # append to validation report
                    self._add_report(valid_score=valid_score,
                                     valid_loss=valid_loss,
                                     valid_ppl=valid_ppl,
                                     eval_metric=self.eval_metric,
                                     new_best=new_best)

                    self._log_examples(sources_raw=valid_sources_raw,
                                       sources=valid_sources,
                                       hypotheses_raw=valid_hypotheses_raw,
                                       hypotheses=valid_hypotheses,
                                       references=valid_references)

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        'Validation result at epoch %d, step %d: %s: %f, '
                        'loss: %f, ppl: %f, duration: %.4fs', epoch_no + 1,
                        self.steps, self.eval_metric, valid_score, valid_loss,
                        valid_ppl, valid_duration)

                    # store validation set outputs
                    self._store_outputs(
                        valid_hypotheses if self.post_process else
                        [" ".join(v) for v in valid_hypotheses_raw],
                        valid_logps if self.return_logp else None)

                    # store attention plots for selected valid sentences
                    store_attention_plots(attentions=valid_attention_scores,
                                          targets=valid_hypotheses_raw,
                                          sources=[s for s in valid_data.src],
                                          indices=self.log_valid_sents,
                                          output_prefix="{}/att.{}".format(
                                              self.model_dir, self.steps),
                                          tb_writer=self.tb_writer,
                                          steps=self.steps)

                if self.save_freq > 0 and self.steps % self.save_freq == 0:
                    # drop a checkpoint every save_freq steps; the message
                    # reports batches as batch_multiplier * number of updates
                    self.logger.info("Saving new checkpoint! "
                                     "Batches passed: {}, "
                                     "Number of updates: {}".format(
                                         self.batch_multiplier * self.steps,
                                         self.steps))
                    self._save_checkpoint()

                if self.stop:
                    break

            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            self.logger.info('Epoch %d: total training loss %.2f',
                             epoch_no + 1, epoch_loss)
        else:
            self.logger.info('Training ended after %d epochs.', epoch_no + 1)

        if valid_data is not None:
            self.logger.info('Best validation result at step %d: %f %s.',
                             self.best_ckpt_iteration, self.best_ckpt_score,
                             self.early_stopping_metric)
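The ckpt_score selection above only works together with an is_best that knows the metric's direction: loss and perplexity improve downward, while evaluation scores such as BLEU improve upward. A minimal sketch of that comparison, assuming a minimize_metric flag is derived from the configured early-stopping metric (names here are illustrative):

early_stopping_metric = "ppl"  # e.g. taken from the training config
minimize_metric = early_stopping_metric in ("loss", "ppl", "perplexity")
best_ckpt_score = float("inf") if minimize_metric else float("-inf")


def is_best(score: float) -> bool:
    # lower is better for loss/perplexity, higher for scores like BLEU
    if minimize_metric:
        return score < best_ckpt_score
    return score > best_ckpt_score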
Beispiel #13
0
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset,
                           kb_task=None,
                           train_kb: TranslationDataset = None,
                           train_kb_lkp: list = [],
                           train_kb_lens: list = [],
                           train_kb_truvals: TranslationDataset = None,
                           valid_kb: Tuple = None,
                           valid_kb_lkp: list = [],
                           valid_kb_lens: list = [],
                           valid_kb_truvals: list = [],
                           valid_data_canon: list = []) -> None:
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        :param kb_task: is not None if kb_task should be executed
        :param train_kb: TranslationDataset holding the loaded train kb data
        :param train_kb_lkp: list mapping train example indices to KB indices
        :param train_kb_lens: list with the number of triples per KB
        :param valid_kb: TranslationDataset holding the loaded valid KB data
        :param valid_kb_lkp: list mapping valid example indices to KB indices
        :param valid_kb_lens: list with the number of triples per KB
        :param valid_kb_truvals: FIXME TODO
        :param valid_data_canon: canonical valid data, required to report loss
        """

        if kb_task:
            train_iter = make_data_iter_kb(train_data,
                                           train_kb,
                                           train_kb_lkp,
                                           train_kb_lens,
                                           train_kb_truvals,
                                           batch_size=self.batch_size,
                                           batch_type=self.batch_type,
                                           train=True,
                                           shuffle=self.shuffle,
                                           canonize=self.model.canonize)
        else:
            train_iter = make_data_iter(train_data,
                                        batch_size=self.batch_size,
                                        batch_type=self.batch_type,
                                        train=True,
                                        shuffle=self.shuffle)

        # anomaly detection helps locate NaN/Inf gradients but slows training
        with torch.autograd.set_detect_anomaly(True):
            for epoch_no in range(self.epochs):
                self.logger.info("EPOCH %d", epoch_no + 1)

                if self.scheduler is not None and self.scheduler_step_at == "epoch":
                    self.scheduler.step(epoch=epoch_no)

                self.model.train()

                start = time.time()
                total_valid_duration = 0
                processed_tokens = self.total_tokens
                count = self.batch_multiplier - 1
                epoch_loss = 0

                for batch in iter(train_iter):
                    # reactivate training
                    self.model.train()

                    # create a Batch object from torchtext batch
                    batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda) if not kb_task else \
                        Batch_with_KB(batch, self.pad_index, use_cuda=self.use_cuda)

                    if kb_task:
                        assert hasattr(batch, "kbsrc"), dir(batch)
                        assert hasattr(batch, "kbtrg"), dir(batch)
                        assert hasattr(batch, "kbtrv"), dir(batch)

                    # only update every batch_multiplier batches
                    # see https://medium.com/@davidlmorton/
                    # increasing-mini-batch-size-without-increasing-
                    # memory-6794e10db672
                    update = count == 0

                    batch_loss = self._train_batch(batch, update=update)

                    if update:
                        self.tb_writer.add_scalar("train/train_batch_loss",
                                                  batch_loss, self.steps)

                    count = self.batch_multiplier if update else count
                    count -= 1
                    epoch_loss += batch_loss.detach().cpu().numpy()

                    if self.scheduler is not None and \
                            self.scheduler_step_at == "step" and update:
                        self.scheduler.step()

                    # log learning progress
                    if self.steps % self.logging_freq == 0 and update:
                        elapsed = time.time() - start - total_valid_duration
                        elapsed_tokens = self.total_tokens - processed_tokens
                        self.logger.info(
                            "Epoch %3d Step: %8d Batch Loss: %12.6f "
                            "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                            self.steps, batch_loss, elapsed_tokens / elapsed,
                            self.optimizer.param_groups[0]["lr"])
                        start = time.time()
                        total_valid_duration = 0

                    # validate on the entire dev set
                    if self.steps % self.validation_freq == 0 and update:

                        if self.manage_decoder_timer:
                            self._log_decoder_timer_stats("train")
                            self.decoder_timer.reset()

                        valid_start_time = time.time()


                        valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                            valid_hypotheses_raw, valid_attention_scores, valid_kb_att_scores, \
                            valid_ent_f1, valid_ent_mcc = \
                            validate_on_data(
                                batch_size=self.eval_batch_size,
                                data=valid_data,
                                eval_metric=self.eval_metric,
                                level=self.level,
                                model=self.model,
                                use_cuda=self.use_cuda,
                                max_output_length=self.max_output_length,
                                loss_function=self.loss,
                                beam_size=0,  # greedy validation decoding (FIXME: keep at 0)
                                batch_type=self.eval_batch_type,
                                kb_task=kb_task,
                                valid_kb=valid_kb,
                                valid_kb_lkp=valid_kb_lkp,
                                valid_kb_lens=valid_kb_lens,
                                valid_kb_truvals=valid_kb_truvals,
                                valid_data_canon=valid_data_canon,
                                report_on_canonicals=self.report_entf1_on_canonicals
                            )

                        if self.manage_decoder_timer:
                            self._log_decoder_timer_stats("valid")
                            self.decoder_timer.reset()

                        self.tb_writer.add_scalar("valid/valid_loss",
                                                  valid_loss, self.steps)
                        self.tb_writer.add_scalar("valid/valid_score",
                                                  valid_score, self.steps)
                        self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                                  self.steps)

                        if self.early_stopping_metric == "loss":
                            ckpt_score = valid_loss
                        elif self.early_stopping_metric in [
                                "ppl", "perplexity"
                        ]:
                            ckpt_score = valid_ppl
                        else:
                            ckpt_score = valid_score

                        new_best = False
                        if self.is_best(ckpt_score):
                            self.best_ckpt_score = ckpt_score
                            self.best_ckpt_iteration = self.steps
                            self.logger.info(
                                'Hooray! New best validation result [%s]!',
                                self.early_stopping_metric)
                            if self.ckpt_queue.maxsize > 0:
                                self.logger.info("Saving new checkpoint.")
                                new_best = True
                                self._save_checkpoint()

                        if self.scheduler is not None \
                                and self.scheduler_step_at == "validation":
                            self.scheduler.step(ckpt_score)

                        # append to validation report
                        self._add_report(valid_score=valid_score,
                                         valid_loss=valid_loss,
                                         valid_ppl=valid_ppl,
                                         eval_metric=self.eval_metric,
                                         valid_ent_f1=valid_ent_f1,
                                         valid_ent_mcc=valid_ent_mcc,
                                         new_best=new_best)

                        # pylint: disable=unnecessary-comprehension
                        self._log_examples(
                            sources_raw=[v for v in valid_sources_raw],
                            sources=valid_sources,
                            hypotheses_raw=valid_hypotheses_raw,
                            hypotheses=valid_hypotheses,
                            references=valid_references)

                        valid_duration = time.time() - valid_start_time
                        total_valid_duration += valid_duration
                        self.logger.info(
                            'Validation result at epoch %3d, step %8d: %s: %6.2f, '
                            'loss: %8.4f, ppl: %8.4f, duration: %.4fs',
                            epoch_no + 1, self.steps, self.eval_metric,
                            valid_score, valid_loss, valid_ppl, valid_duration)

                        # store validation set outputs
                        self._store_outputs(valid_hypotheses)

                        valid_src = list(valid_data.src)
                        # store attention plots for selected valid sentences
                        if valid_attention_scores:
                            plot_success_ratio = store_attention_plots(
                                attentions=valid_attention_scores,
                                targets=valid_hypotheses_raw,
                                sources=valid_src,
                                indices=self.log_valid_sents,
                                output_prefix="{}/att.{}".format(
                                    self.model_dir, self.steps),
                                tb_writer=self.tb_writer,
                                steps=self.steps)
                            self.logger.info(
                                "stored %s valid att scores!",
                                plot_success_ratio)
                        if valid_kb_att_scores:
                            plot_success_ratio = store_attention_plots(
                                attentions=valid_kb_att_scores,
                                targets=valid_hypotheses_raw,
                                sources=list(valid_kb.kbsrc),
                                indices=self.log_valid_sents,
                                output_prefix="{}/kbatt.{}".format(
                                    self.model_dir, self.steps),
                                tb_writer=self.tb_writer,
                                steps=self.steps,
                                kb_info=(valid_kb_lkp, valid_kb_lens,
                                         valid_kb_truvals),
                                on_the_fly_info=(valid_src, valid_kb,
                                                 self.model.canonize,
                                                 self.model.trg_vocab))
                            self.logger.info(
                                "stored %s valid kb att scores!",
                                plot_success_ratio)
                        else:
                            self.logger.info(
                                "there are no valid KB attention scores...")
                    if self.stop:
                        break
                if self.stop:
                    self.logger.info(
                        'Training ended since minimum lr %f was reached.',
                        self.learning_rate_min)
                    break

                self.logger.info('Epoch %3d: total training loss %.2f',
                                 epoch_no + 1, epoch_loss)
            else:
                self.logger.info('Training ended after %3d epochs.',
                                 epoch_no + 1)
            self.logger.info('Best validation result at step %8d: %6.2f %s.',
                             self.best_ckpt_iteration, self.best_ckpt_score,
                             self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer
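store_attention_plots receives one attention matrix per hypothesis and writes heatmaps to disk and to TensorBoard. A stripped-down sketch of the core drawing step for a single sentence pair, using matplotlib; the function name, token handling, and file layout here are illustrative, not the helper's actual interface:

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend, writes files only
import matplotlib.pyplot as plt


def plot_attention(scores, src_tokens, trg_tokens, path):
    """Save a (trg_len x src_len) attention heatmap to path."""
    fig, ax = plt.subplots()
    ax.imshow(scores, cmap="viridis", aspect="auto")
    ax.set_xticks(range(len(src_tokens)))
    ax.set_xticklabels(src_tokens, rotation=90)
    ax.set_yticks(range(len(trg_tokens)))
    ax.set_yticklabels(trg_tokens)
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


plot_attention(np.random.rand(3, 4), ["ein", "kleines", "Haus", "."],
               ["a", "small", "house"], "att.0.png")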