Example #1
0
    def test_no_label_smoothing(self):
        """
        Check XentLoss with smoothing disabled (smoothing = 0.0).

        The smoothed targets are expected to be plain one-hot rows, with
        all-zero rows for positions whose target is the padding index,
        and the loss on the fixed prediction is expected to be 5.6268.
        """
        criterion = XentLoss(pad_index=0, smoothing=0.0)

        # the same distribution is predicted at every position;
        # batch x seq_len x vocab_size: 3 x 2 x 5
        token_probs = [0.1, 0.1, 0.6, 0.1, 0.1]
        predict = torch.FloatTensor([[token_probs, token_probs]
                                     for _ in range(3)])

        # batch x seq_len: 3 x 2
        targets = torch.LongTensor([[2, 1], [2, 0], [1, 0]])

        # test the smoothing function: should still be one-hot
        flat_targets = targets.view(-1)
        vocab_size = predict.size(-1)
        smoothed_targets = criterion._smooth_targets(targets=flat_targets,
                                                     vocab_size=vocab_size)

        assert torch.max(smoothed_targets) == 1
        assert torch.min(smoothed_targets) == 0

        # one row per flattened target; rows for target 0 (padding) are empty
        expected_targets = torch.Tensor([
            [0., 0., 1., 0., 0.],
            [0., 1., 0., 0., 0.],
            [0., 0., 1., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 1., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
        ])
        self.assertTensorAlmostEqual(smoothed_targets, expected_targets)

        # test the loss computation on log-probabilities
        loss_value = criterion(predict.log(), targets)
        self.assertTensorAlmostEqual(loss_value, 5.6268)
Example #2
0
    def test_label_smoothing(self):
        """
        Check XentLoss with label smoothing enabled (smoothing = 0.4).

        The smoothed targets are expected to put 1 - smoothing on the gold
        index, nothing on the padding index, and all-zero rows for positions
        whose target is the padding index. The loss on the fixed prediction
        is expected to be 2.1326.
        """
        smoothing = 0.4
        criterion = XentLoss(pad_index=0, smoothing=smoothing)

        # the same distribution is predicted at every position;
        # batch x seq_len x vocab_size: 3 x 2 x 5
        token_probs = [0.1, 0.1, 0.6, 0.1, 0.1]
        predict = torch.FloatTensor([[token_probs, token_probs]
                                     for _ in range(3)])

        # batch x seq_len: 3 x 2
        targets = torch.LongTensor([[2, 1], [2, 0], [1, 0]])

        # test the smoothing function
        flat_targets = targets.view(-1)
        vocab_size = predict.size(-1)
        smoothed_targets = criterion._smooth_targets(targets=flat_targets,
                                                     vocab_size=vocab_size)

        # one row per flattened target; rows for target 0 (padding) are empty
        expected_targets = torch.Tensor([
            [0.0000, 0.1333, 0.6000, 0.1333, 0.1333],
            [0.0000, 0.6000, 0.1333, 0.1333, 0.1333],
            [0.0000, 0.1333, 0.6000, 0.1333, 0.1333],
            [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.6000, 0.1333, 0.1333, 0.1333],
            [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        ])
        self.assertTensorAlmostEqual(smoothed_targets, expected_targets)
        assert torch.max(smoothed_targets) == 1 - smoothing

        # test the loss computation on log-probabilities
        loss_value = criterion(predict.log(), targets)
        self.assertTensorAlmostEqual(loss_value, 2.1326)
Example #3
0
    def __init__(self, model: Model, config: dict) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        Reads the "training" section of the configuration (plus
        config["data"]["level"] and
        config["model"]["encoder"]["hidden_size"]), creates the model
        directory, the logger and the Tensorboard writer, builds the loss,
        gradient clipper, optimizer and scheduler, and, if "load_model" is
        configured, restores the training state from that checkpoint.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        :raises ConfigurationError: if "normalization", "eval_metric",
            "early_stopping_metric" or the data "level" is set to a value
            that is not supported
        """
        train_config = config["training"]

        # files for logging and storing
        self.model_dir = make_model_dir(train_config["model_dir"],
                                        overwrite=train_config.get(
                                            "overwrite", False))
        self.logger = make_logger("{}/train.log".format(self.model_dir))
        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = "{}/validations.txt".format(self.model_dir)
        self.tb_writer = SummaryWriter(log_dir=self.model_dir +
                                       "/tensorboard/")

        # model (the parameter list is logged, so the logger must exist first)
        self.model = model
        self.pad_index = self.model.pad_index
        self.bos_index = self.model.bos_index
        self._log_parameters_list()

        # objective
        self.label_smoothing = train_config.get("label_smoothing", 0.0)
        self.loss = XentLoss(pad_index=self.pad_index,
                             smoothing=self.label_smoothing)
        self.normalization = train_config.get("normalization", "batch")
        if self.normalization not in ["batch", "tokens", "none"]:
            # adjacent literals are concatenated verbatim, so the separating
            # space has to be part of the first literal
            raise ConfigurationError("Invalid normalization option. "
                                     "Valid options: "
                                     "'batch', 'tokens', 'none'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)

        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(config=train_config,
                                         parameters=model.parameters())

        # validation & early stopping
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        self.ckpt_queue = queue.Queue(
            maxsize=train_config.get("keep_last_ckpts", 5))
        self.eval_metric = train_config.get("eval_metric", "bleu")
        if self.eval_metric not in [
                'bleu', 'chrf', 'token_accuracy', 'sequence_accuracy'
        ]:
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: 'bleu', 'chrf', "
                                     "'token_accuracy', 'sequence_accuracy'.")
        self.early_stopping_metric = train_config.get("early_stopping_metric",
                                                      "eval_metric")

        # early_stopping_metric decides on how to find the early stopping point:
        # ckpts are written when there's a new high/low score for this metric.
        # If we schedule after BLEU/chrf/accuracy, we want to maximize the
        # score, else we want to minimize it.
        if self.early_stopping_metric in ["ppl", "loss"]:
            self.minimize_metric = True
        elif self.early_stopping_metric == "eval_metric":
            if self.eval_metric in [
                    "bleu", "chrf", "token_accuracy", "sequence_accuracy"
            ]:
                self.minimize_metric = False
            # eval metric that has to get minimized (not yet implemented);
            # unreachable for now since eval_metric was validated above
            else:
                self.minimize_metric = True
        else:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', 'eval_metric'.")

        # learning rate scheduling
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=config["model"]["encoder"]["hidden_size"])

        # data & batch handling
        self.level = config["data"]["level"]
        if self.level not in ["word", "bpe", "char"]:
            raise ConfigurationError("Invalid segmentation level. "
                                     "Valid options: 'word', 'bpe', 'char'.")
        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        self.batch_type = train_config.get("batch_type", "sentence")
        # evaluation falls back to the training batch settings
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        self.batch_multiplier = train_config.get("batch_multiplier", 1)
        self.current_batch_multiplier = self.batch_multiplier

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU
        self.use_cuda = train_config["use_cuda"]
        if self.use_cuda:
            self.model.cuda()
            self.loss.cuda()

        # initialize accumulated batch loss (needed for batch_multiplier)
        self.norm_batch_loss_accumulated = 0
        # initialize training statistics
        self.steps = 0
        # stop training if this flag is True by reaching learning rate minimum
        self.stop = False
        self.total_tokens = 0
        self.best_ckpt_iteration = 0
        # initial values for best scores: any first score counts as a new best
        self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf
        # comparison function for scores
        self.is_best = lambda score: score < self.best_ckpt_score \
            if self.minimize_metric else score > self.best_ckpt_score

        # model parameters: optionally continue from a stored checkpoint
        if "load_model" in train_config:
            model_load_path = train_config["load_model"]
            self.logger.info("Loading model from %s", model_load_path)
            reset_best_ckpt = train_config.get("reset_best_ckpt", False)
            reset_scheduler = train_config.get("reset_scheduler", False)
            reset_optimizer = train_config.get("reset_optimizer", False)
            self.init_from_checkpoint(model_load_path,
                                      reset_best_ckpt=reset_best_ckpt,
                                      reset_scheduler=reset_scheduler,
                                      reset_optimizer=reset_optimizer)
Example #4
0
class TrainManager:
    """ Manages training loop, validations, learning rate scheduling
    and early stopping."""
    def __init__(self, model: Model, config: dict) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        Reads the "training" section of the configuration (plus
        config["data"]["level"] and
        config["model"]["encoder"]["hidden_size"]), creates the model
        directory, the logger and the Tensorboard writer, builds the loss,
        gradient clipper, optimizer and scheduler, and, if "load_model" is
        configured, restores the training state from that checkpoint.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        :raises ConfigurationError: if "normalization", "eval_metric",
            "early_stopping_metric" or the data "level" is set to a value
            that is not supported
        """
        train_config = config["training"]

        # files for logging and storing
        self.model_dir = make_model_dir(train_config["model_dir"],
                                        overwrite=train_config.get(
                                            "overwrite", False))
        self.logger = make_logger("{}/train.log".format(self.model_dir))
        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = "{}/validations.txt".format(self.model_dir)
        self.tb_writer = SummaryWriter(log_dir=self.model_dir +
                                       "/tensorboard/")

        # model (the parameter list is logged, so the logger must exist first)
        self.model = model
        self.pad_index = self.model.pad_index
        self.bos_index = self.model.bos_index
        self._log_parameters_list()

        # objective
        self.label_smoothing = train_config.get("label_smoothing", 0.0)
        self.loss = XentLoss(pad_index=self.pad_index,
                             smoothing=self.label_smoothing)
        self.normalization = train_config.get("normalization", "batch")
        if self.normalization not in ["batch", "tokens", "none"]:
            # adjacent literals are concatenated verbatim, so the separating
            # space has to be part of the first literal
            raise ConfigurationError("Invalid normalization option. "
                                     "Valid options: "
                                     "'batch', 'tokens', 'none'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)

        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(config=train_config,
                                         parameters=model.parameters())

        # validation & early stopping
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        self.ckpt_queue = queue.Queue(
            maxsize=train_config.get("keep_last_ckpts", 5))
        self.eval_metric = train_config.get("eval_metric", "bleu")
        if self.eval_metric not in [
                'bleu', 'chrf', 'token_accuracy', 'sequence_accuracy'
        ]:
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: 'bleu', 'chrf', "
                                     "'token_accuracy', 'sequence_accuracy'.")
        self.early_stopping_metric = train_config.get("early_stopping_metric",
                                                      "eval_metric")

        # early_stopping_metric decides on how to find the early stopping point:
        # ckpts are written when there's a new high/low score for this metric.
        # If we schedule after BLEU/chrf/accuracy, we want to maximize the
        # score, else we want to minimize it.
        if self.early_stopping_metric in ["ppl", "loss"]:
            self.minimize_metric = True
        elif self.early_stopping_metric == "eval_metric":
            if self.eval_metric in [
                    "bleu", "chrf", "token_accuracy", "sequence_accuracy"
            ]:
                self.minimize_metric = False
            # eval metric that has to get minimized (not yet implemented);
            # unreachable for now since eval_metric was validated above
            else:
                self.minimize_metric = True
        else:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', 'eval_metric'.")

        # learning rate scheduling
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=config["model"]["encoder"]["hidden_size"])

        # data & batch handling
        self.level = config["data"]["level"]
        if self.level not in ["word", "bpe", "char"]:
            raise ConfigurationError("Invalid segmentation level. "
                                     "Valid options: 'word', 'bpe', 'char'.")
        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        self.batch_type = train_config.get("batch_type", "sentence")
        # evaluation falls back to the training batch settings
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        self.batch_multiplier = train_config.get("batch_multiplier", 1)
        self.current_batch_multiplier = self.batch_multiplier

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU
        self.use_cuda = train_config["use_cuda"]
        if self.use_cuda:
            self.model.cuda()
            self.loss.cuda()

        # initialize accumulated batch loss (needed for batch_multiplier)
        self.norm_batch_loss_accumulated = 0
        # initialize training statistics
        self.steps = 0
        # stop training if this flag is True by reaching learning rate minimum
        self.stop = False
        self.total_tokens = 0
        self.best_ckpt_iteration = 0
        # initial values for best scores: any first score counts as a new best
        self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf
        # comparison function for scores
        self.is_best = lambda score: score < self.best_ckpt_score \
            if self.minimize_metric else score > self.best_ckpt_score

        # model parameters: optionally continue from a stored checkpoint
        if "load_model" in train_config:
            model_load_path = train_config["load_model"]
            self.logger.info("Loading model from %s", model_load_path)
            reset_best_ckpt = train_config.get("reset_best_ckpt", False)
            reset_scheduler = train_config.get("reset_scheduler", False)
            reset_optimizer = train_config.get("reset_optimizer", False)
            self.init_from_checkpoint(model_load_path,
                                      reset_best_ckpt=reset_best_ckpt,
                                      reset_scheduler=reset_scheduler,
                                      reset_optimizer=reset_optimizer)

    def _save_checkpoint(self) -> None:
        """
        Write the current model parameters and training state to
        `<model_dir>/<steps>.ckpt` and make `best.ckpt` refer to it.

        The stored state holds the number of training steps and training
        tokens so far, the best checkpoint score and its iteration, and the
        state dicts of model, optimizer and (if there is one) scheduler.

        If the checkpoint queue is already full, the oldest checkpoint file
        is removed from disk before the new path is enqueued.
        """
        ckpt_name = "{}.ckpt".format(self.steps)
        model_path = "{}/{}.ckpt".format(self.model_dir, self.steps)

        state = {
            "steps": self.steps,
            "total_tokens": self.total_tokens,
            "best_ckpt_score": self.best_ckpt_score,
            "best_ckpt_iteration": self.best_ckpt_iteration,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
        }
        # a scheduler is optional; None is stored when there is none
        if self.scheduler is not None:
            state["scheduler_state"] = self.scheduler.state_dict()
        else:
            state["scheduler_state"] = None
        torch.save(state, model_path)

        # make room in the queue by dropping the oldest checkpoint
        if self.ckpt_queue.full():
            oldest_path = self.ckpt_queue.get()
            try:
                os.remove(oldest_path)
            except FileNotFoundError:
                self.logger.warning(
                    "Wanted to delete old checkpoint %s but "
                    "file does not exist.", oldest_path)
        self.ckpt_queue.put(model_path)

        best_path = "{}/best.ckpt".format(self.model_dir)
        try:
            # create/modify symbolic link for best checkpoint
            symlink_update(ckpt_name, best_path)
        except OSError:
            # linking failed: overwrite best.ckpt with a full copy instead
            torch.save(state, best_path)

    def init_from_checkpoint(self,
                             path: str,
                             reset_best_ckpt: bool = False,
                             reset_scheduler: bool = False,
                             reset_optimizer: bool = False) -> None:
        """
        Restore the trainer state from a checkpoint file.

        Besides the model parameters, the checkpoint holds optimizer and
        scheduler states and the training counters, see
        `self._save_checkpoint`. Model parameters and the step/token
        counters are always restored; the other parts can be skipped
        with the reset flags.

        :param path: path to checkpoint
        :param reset_best_ckpt: reset tracking of the best checkpoint,
                                use for domain adaptation with a new dev
                                set or when using a new metric for fine-tuning.
        :param reset_scheduler: reset the learning rate scheduler, and do not
                                use the one stored in the checkpoint.
        :param reset_optimizer: reset the optimizer, and do not use the one
                                stored in the checkpoint.
        """
        ckpt = load_checkpoint(path=path, use_cuda=self.use_cuda)

        # model parameters are restored unconditionally
        self.model.load_state_dict(ckpt["model_state"])

        if reset_optimizer:
            self.logger.info("Reset optimizer.")
        else:
            self.optimizer.load_state_dict(ckpt["optimizer_state"])

        if reset_scheduler:
            self.logger.info("Reset scheduler.")
        elif ckpt["scheduler_state"] is not None and \
                self.scheduler is not None:
            # only possible if both the checkpoint and this trainer have one
            self.scheduler.load_state_dict(ckpt["scheduler_state"])

        # restore counts
        self.steps = ckpt["steps"]
        self.total_tokens = ckpt["total_tokens"]

        if reset_best_ckpt:
            self.logger.info("Reset tracking of the best checkpoint.")
        else:
            self.best_ckpt_score = ckpt["best_ckpt_score"]
            self.best_ckpt_iteration = ckpt["best_ckpt_iteration"]

        # move parameters to cuda
        if self.use_cuda:
            self.model.cuda()

    # pylint: disable=unnecessary-comprehension
    # pylint: disable=too-many-branches
    # pylint: disable=too-many-statements
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset) \
            -> None:
        """
        Train the model and validate it from time to time on the validation set.

        Runs up to `self.epochs` epochs over `train_data`. Gradients are
        accumulated over `self.batch_multiplier` batches before each update.
        After an update, every `self.logging_freq` steps the progress is
        logged and every `self.validation_freq` steps the model is validated
        on `valid_data`; a checkpoint is saved whenever the early stopping
        metric reaches a new best value. Training ends early once `self.stop`
        is set, which `_add_report` does when the learning rate drops below
        `self.learning_rate_min`. The Tensorboard writer is closed at the end.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(train_data,
                                    batch_size=self.batch_size,
                                    batch_type=self.batch_type,
                                    train=True,
                                    shuffle=self.shuffle)

        # For last batch in epoch batch_multiplier needs to be adjusted
        # to fit the number of leftover training examples
        leftover_batch_size = len(train_data) % (self.batch_multiplier *
                                                 self.batch_size)

        # NOTE(review): if self.epochs is 0 the loop body never runs and the
        # `else` branch below reads `epoch_no` before it is bound.
        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no)

            self.model.train()

            # Reset statistics for each epoch.
            start = time.time()
            total_valid_duration = 0
            start_tokens = self.total_tokens
            self.current_batch_multiplier = self.batch_multiplier
            self.optimizer.zero_grad()
            # `count` is the number of batches left before the next update;
            # an update is made when it reaches 0
            count = self.current_batch_multiplier - 1
            epoch_loss = 0

            for i, batch in enumerate(iter(train_iter)):
                # reactivate training
                self.model.train()
                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672

                # Set current_batch_multiplier to fit
                # number of leftover examples for last batch in epoch
                # Only works if batch_type == sentence
                if self.batch_type == "sentence":
                    if self.batch_multiplier > 1 and i == len(train_iter) - \
                            math.ceil(leftover_batch_size / self.batch_size):
                        self.current_batch_multiplier = math.ceil(
                            leftover_batch_size / self.batch_size)
                        count = self.current_batch_multiplier - 1

                update = count == 0
                # print(count, update, self.steps)
                batch_loss = self._train_batch(batch,
                                               update=update,
                                               count=count)

                # Only save finally computed batch_loss of full batch
                if update:
                    self.tb_writer.add_scalar("train/train_batch_loss",
                                              batch_loss, self.steps)

                # restart the countdown after an update, then count this batch
                count = self.batch_multiplier if update else count
                count -= 1

                # Only add complete batch_loss of full mini-batch to epoch_loss
                if update:
                    epoch_loss += batch_loss.detach().cpu().numpy()

                if self.scheduler is not None and \
                        self.scheduler_step_at == "step" and update:
                    self.scheduler.step()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    # time spent validating is excluded from the throughput
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - start_tokens
                    self.logger.info(
                        "Epoch %3d Step: %8d Batch Loss: %12.6f "
                        "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                        self.steps, batch_loss, elapsed_tokens / elapsed,
                        self.optimizer.param_groups[0]["lr"])
                    start = time.time()
                    total_valid_duration = 0
                    start_tokens = self.total_tokens

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    valid_start_time = time.time()

                    valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores = \
                        validate_on_data(
                            logger=self.logger,
                            batch_size=self.eval_batch_size,
                            data=valid_data,
                            eval_metric=self.eval_metric,
                            level=self.level, model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            loss_function=self.loss,
                            beam_size=1,  # greedy validations
                            batch_type=self.eval_batch_type,
                            postprocess=True  # always remove BPE for validation
                        )

                    self.tb_writer.add_scalar("valid/valid_loss", valid_loss,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_score", valid_score,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                              self.steps)

                    # pick the score that early stopping is based on
                    if self.early_stopping_metric == "loss":
                        ckpt_score = valid_loss
                    elif self.early_stopping_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        # a queue size of 0 disables checkpoint saving here
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    if self.scheduler is not None \
                            and self.scheduler_step_at == "validation":
                        self.scheduler.step(ckpt_score)

                    # append to validation report
                    # (this call may also set self.stop, see _add_report)
                    self._add_report(valid_score=valid_score,
                                     valid_loss=valid_loss,
                                     valid_ppl=valid_ppl,
                                     eval_metric=self.eval_metric,
                                     new_best=new_best)

                    self._log_examples(
                        sources_raw=[v for v in valid_sources_raw],
                        sources=valid_sources,
                        hypotheses_raw=valid_hypotheses_raw,
                        hypotheses=valid_hypotheses,
                        references=valid_references)

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        'Validation result (greedy) at epoch %3d, '
                        'step %8d: %s: %6.2f, loss: %8.4f, ppl: %8.4f, '
                        'duration: %.4fs', epoch_no + 1, self.steps,
                        self.eval_metric, valid_score, valid_loss, valid_ppl,
                        valid_duration)

                    # store validation set outputs
                    self._store_outputs(valid_hypotheses)

                    # store attention plots for selected valid sentences
                    if valid_attention_scores:
                        store_attention_plots(
                            attentions=valid_attention_scores,
                            targets=valid_hypotheses_raw,
                            sources=[s for s in valid_data.src],
                            indices=self.log_valid_sents,
                            output_prefix="{}/att.{}".format(
                                self.model_dir, self.steps),
                            tb_writer=self.tb_writer,
                            steps=self.steps)

                # leave the batch loop as soon as the stop flag is set
                if self.stop:
                    break
            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            self.logger.info('Epoch %3d: total training loss %.2f',
                             epoch_no + 1, epoch_loss)
        else:
            # for/else: only reached when no epoch was ended by `break`
            self.logger.info('Training ended after %3d epochs.', epoch_no + 1)
        self.logger.info(
            'Best validation result (greedy) at step '
            '%8d: %6.2f %s.', self.best_ckpt_iteration, self.best_ckpt_score,
            self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer

    def _train_batch(self,
                     batch: Batch,
                     update: bool = True,
                     count: int = 1) -> Tensor:
        """
        Train the model on one batch: compute the normalized loss and either
        accumulate it or run backward and make a gradient step.

        Without `update`, the normalized loss is only stored in
        `self.norm_batch_loss_accumulated` (started afresh on the first
        batch of an accumulation cycle, added to otherwise). With `update`,
        the accumulated loss is added in when more than one batch is
        accumulated per update, and averaged over those batches unless
        normalization is "none". Then the loss is back-propagated, gradients
        are clipped if a clipper is set, the optimizer steps and the step
        counter is incremented. The token counter grows in both cases.

        :param batch: training batch
        :param update: if False, only store the loss. if True also make update
        :param count: number of portions (batch_size) left before update
        :return: normalized loss of this batch; on an update with more than
            one accumulated batch, the combined loss that was back-propagated
        :raises NotImplementedError: if `self.normalization` is not one of
            'batch', 'tokens' or 'none'
        """
        raw_loss = self.model.get_loss_for_batch(batch=batch,
                                                 loss_function=self.loss)

        # pick the quantity the summed loss is divided by
        scheme = self.normalization
        if scheme == "batch":
            denominator = batch.nseqs
        elif scheme == "tokens":
            denominator = batch.ntokens
        elif scheme == "none":
            denominator = 1
        else:
            raise NotImplementedError(
                "Only normalize by 'batch' or 'tokens' "
                "or summation of loss 'none' implemented")

        norm_batch_loss = raw_loss / denominator

        if not update:
            if count == self.current_batch_multiplier - 1:
                # first batch of a new accumulation cycle: start over
                self.norm_batch_loss_accumulated = norm_batch_loss
            else:
                # accumulate loss of current batch_size * batch_multiplier loss
                self.norm_batch_loss_accumulated += norm_batch_loss
        else:
            if self.current_batch_multiplier > 1:
                norm_batch_loss = self.norm_batch_loss_accumulated + \
                    norm_batch_loss
                # summed losses ("none") are not averaged over the batches
                if self.normalization != "none":
                    norm_batch_loss = norm_batch_loss / \
                        self.current_batch_multiplier

            norm_batch_loss.backward()

            if self.clip_grad_fun is not None:
                # clip gradients (in-place)
                self.clip_grad_fun(params=self.model.parameters())

            # make gradient step
            self.optimizer.step()
            self.optimizer.zero_grad()

            # increment step counter
            self.steps += 1

        # increment token counter
        self.total_tokens += batch.ntokens

        return norm_batch_loss

    def _add_report(self,
                    valid_score: float,
                    valid_ppl: float,
                    valid_loss: float,
                    eval_metric: str,
                    new_best: bool = False) -> None:
        """
        Append a one-line report to validation logging file.

        As a side effect, `self.stop` is set to True when the current
        learning rate is below `self.learning_rate_min`.

        :param valid_score: validation evaluation score [eval_metric]
        :param valid_ppl: validation perplexity
        :param valid_loss: validation loss (sum over whole validation set)
        :param eval_metric: evaluation metric, e.g. "bleu"
        :param new_best: whether this is a new best model
        """
        # the learning rate of the last param group is reported
        # (other param groups are ignored for now); -1 if there is no group
        current_lr = -1
        for group in self.optimizer.param_groups:
            current_lr = group['lr']

        if current_lr < self.learning_rate_min:
            self.stop = True

        # new best models are marked with a star at the end of the line
        if new_best:
            best_marker = "*"
        else:
            best_marker = ""

        report_line = ("Steps: {}\tLoss: {:.5f}\tPPL: {:.5f}\t{}: {:.5f}\t"
                       "LR: {:.8f}\t{}\n").format(self.steps, valid_loss,
                                                  valid_ppl, eval_metric,
                                                  valid_score, current_lr,
                                                  best_marker)

        with open(self.valid_report_file, 'a') as opened_file:
            opened_file.write(report_line)

    def _log_parameters_list(self) -> None:
        """
        Log the number of trainable parameters and their sorted names.

        Only parameters with `requires_grad` set are counted and listed.
        Fails with an AssertionError if there is no trainable parameter.
        """
        # number of elements per trainable parameter tensor
        element_counts = [
            np.prod(param.size()) for param in self.model.parameters()
            if param.requires_grad
        ]
        self.logger.info("Total params: %d", sum(element_counts))

        trainable_names = []
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                trainable_names.append(name)
        self.logger.info("Trainable parameters: %s", sorted(trainable_names))

        # a model without any trainable parameter cannot be trained
        assert trainable_names

    def _log_examples(self,
                      sources: List[str],
                      hypotheses: List[str],
                      references: List[str],
                      sources_raw: List[List[str]] = None,
                      hypotheses_raw: List[List[str]] = None,
                      references_raw: List[List[str]] = None) -> None:
        """
        Log the examples whose indices are listed in `self.log_valid_sents`.

        Decoded strings are logged at INFO level; raw token lists, where
        given, are logged at DEBUG level. Indices that are not smaller than
        `len(sources)` are skipped.

        :param sources: decoded sources (list of strings)
        :param hypotheses: decoded hypotheses (list of strings)
        :param references: decoded references (list of strings)
        :param sources_raw: raw sources (list of list of tokens)
        :param hypotheses_raw: raw hypotheses (list of list of tokens)
        :param references_raw: raw references (list of list of tokens)
        """
        # (log template, data) pairs in the order they are emitted
        raw_fields = (
            ("\tRaw source:     %s", sources_raw),
            ("\tRaw reference:  %s", references_raw),
            ("\tRaw hypothesis: %s", hypotheses_raw),
        )
        decoded_fields = (
            ("\tSource:     %s", sources),
            ("\tReference:  %s", references),
            ("\tHypothesis: %s", hypotheses),
        )

        for idx in self.log_valid_sents:
            # requested index is not available in this validation set
            if idx >= len(sources):
                continue

            self.logger.info("Example #%d", idx)

            for template, raw in raw_fields:
                if raw is not None:
                    self.logger.debug(template, raw[idx])

            for template, decoded in decoded_fields:
                self.logger.info(template, decoded[idx])

    def _store_outputs(self, hypotheses: List[str]) -> None:
        """
        Dump the current validation hypotheses, one per line, to the file
        `<self.model_dir>/<self.steps>.hyps`.

        :param hypotheses: list of strings
        """
        output_path = "{}/{}.hyps".format(self.model_dir, self.steps)
        with open(output_path, 'w') as hyps_file:
            hyps_file.writelines(
                "{}\n".format(hypothesis) for hypothesis in hypotheses)
Example #5
0
    def __init__(self,
                 model: Model,
                 config: dict,
                 batch_class: Batch = Batch) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        Reads the "training", "testing", "data" and "model" sections of
        `config`, builds the loss, optimizer, gradient clipper and scheduler,
        optionally initialises fp16 training and restores a checkpoint, and
        wraps the model for multi-GPU training when more than one GPU is used.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        :param batch_class: batch class to encapsulate the torch class
        :raises ConfigurationError: for an invalid "normalization",
            "eval_metric", "early_stopping_metric" or data "level" setting
        :raises ImportError: if fp16 training is requested but apex is not
            loaded
        """
        train_config = config["training"]
        self.batch_class = batch_class

        # files for logging and storing
        self.model_dir = train_config["model_dir"]
        assert os.path.exists(self.model_dir)

        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = "{}/validations.txt".format(self.model_dir)
        self.tb_writer = SummaryWriter(log_dir=self.model_dir +
                                       "/tensorboard/")

        self.save_latest_checkpoint = train_config.get("save_latest_ckpt",
                                                       True)

        # model
        self.model = model
        self._log_parameters_list()

        # objective
        self.label_smoothing = train_config.get("label_smoothing", 0.0)
        self.model.loss_function = XentLoss(pad_index=self.model.pad_index,
                                            smoothing=self.label_smoothing)
        self.normalization = train_config.get("normalization", "batch")
        if self.normalization not in ["batch", "tokens", "none"]:
            # the message parts are concatenated, so each needs its own
            # trailing space
            raise ConfigurationError("Invalid normalization option. "
                                     "Valid options: "
                                     "'batch', 'tokens', 'none'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)

        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(config=train_config,
                                         parameters=model.parameters())

        # validation & early stopping
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        self.ckpt_queue = collections.deque(
            maxlen=train_config.get("keep_last_ckpts", 5))
        self.eval_metric = train_config.get("eval_metric", "bleu")
        if self.eval_metric not in [
                'bleu', 'chrf', 'token_accuracy', 'sequence_accuracy'
        ]:
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: 'bleu', 'chrf', "
                                     "'token_accuracy', 'sequence_accuracy'.")
        self.early_stopping_metric = train_config.get("early_stopping_metric",
                                                      "eval_metric")

        # early_stopping_metric decides on how to find the early stopping point:
        # ckpts are written when there's a new high/low score for this metric.
        # If we schedule after BLEU/chrf/accuracy, we want to maximize the
        # score, else we want to minimize it.
        if self.early_stopping_metric in ["ppl", "loss"]:
            self.minimize_metric = True
        elif self.early_stopping_metric == "eval_metric":
            if self.eval_metric in [
                    "bleu", "chrf", "token_accuracy", "sequence_accuracy"
            ]:
                self.minimize_metric = False
            # eval metric that has to get minimized (not yet implemented)
            else:
                self.minimize_metric = True
        else:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', 'eval_metric'.")

        # eval options
        test_config = config["testing"]
        self.bpe_type = test_config.get("bpe_type", "subword-nmt")
        # defaults, overridden by an optional "sacrebleu" sub-section
        self.sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
        if "sacrebleu" in test_config:
            sacrebleu_config = test_config["sacrebleu"]
            self.sacrebleu["remove_whitespace"] = sacrebleu_config.get(
                "remove_whitespace", True)
            self.sacrebleu["tokenize"] = sacrebleu_config.get(
                "tokenize", "13a")

        # learning rate scheduling
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=config["model"]["encoder"]["hidden_size"])

        # data & batch handling
        self.level = config["data"]["level"]
        if self.level not in ["word", "bpe", "char"]:
            raise ConfigurationError("Invalid segmentation level. "
                                     "Valid options: 'word', 'bpe', 'char'.")
        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        # Placeholder so that we can use the train_iter in other functions.
        self.train_iter = None
        self.train_iter_state = None
        # per-device batch_size = self.batch_size // self.n_gpu
        self.batch_type = train_config.get("batch_type", "sentence")
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        # per-device eval_batch_size = self.eval_batch_size // self.n_gpu
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        self.batch_multiplier = train_config.get("batch_multiplier", 1)

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU
        self.use_cuda = train_config["use_cuda"] and torch.cuda.is_available()
        self.n_gpu = torch.cuda.device_count() if self.use_cuda else 0
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        if self.use_cuda:
            self.model.to(self.device)

        # fp16
        self.fp16 = train_config.get("fp16", False)
        if self.fp16:
            if 'apex' not in sys.modules:
                # NOTE(review): `no_apex` is not defined in this method; it
                # presumably comes from a module-level
                # `except ImportError as no_apex:`. Python 3 unbinds such a
                # name when the except clause ends, which would turn this
                # into a NameError - verify against the module header.
                raise ImportError("Please install apex from "
                                  "https://www.github.com/nvidia/apex "
                                  "to use fp16 training.") from no_apex
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level='O1')
            # opt level: one of {"O0", "O1", "O2", "O3"}
            # see https://nvidia.github.io/apex/amp.html#opt-levels

        # initialize training statistics
        self.stats = self.TrainStatistics(
            steps=0,
            stop=False,
            total_tokens=0,
            best_ckpt_iter=0,
            best_ckpt_score=np.inf if self.minimize_metric else -np.inf,
            minimize_metric=self.minimize_metric)

        # model parameters
        if "load_model" in train_config:
            self.init_from_checkpoint(
                train_config["load_model"],
                reset_best_ckpt=train_config.get("reset_best_ckpt", False),
                reset_scheduler=train_config.get("reset_scheduler", False),
                reset_optimizer=train_config.get("reset_optimizer", False),
                reset_iter_state=train_config.get("reset_iter_state", False))

        # multi-gpu training (should be after apex fp16 initialization)
        if self.n_gpu > 1:
            self.model = _DataParallel(self.model)
Example #6
0
def Q_learning(cfg_file: str) -> None:
    """
    Train a translation model with Q-learning on replayed transitions.

    A policy network greedily translates each training batch, sentence-level
    BLEU scores are used as rewards, and (source, prefix, next word, reward,
    finish) transitions are pushed into a replay memory. Once the memory is
    full, mini-batches are sampled from it and the policy network is fitted
    with an MSE loss against targets computed by a periodically synchronised
    target network. The policy network is validated every `validation_freq`
    update steps and a checkpoint is written for each new best score.

    :param cfg_file: path to configuration yaml file
    """
    cfg = load_config(cfg_file)  # config is a dict
    # make logger
    model_dir = make_model_dir(cfg["training"]["model_dir"],
                               overwrite=cfg["training"].get(
                                   "overwrite", False))
    _ = make_logger(model_dir, mode="train")  # version string returned
    # TODO: save version number in model checkpoints

    # set the random seed
    set_seed(seed=cfg["training"].get("random_seed", 42))

    # load the data
    print("loadding data here")
    train_data, dev_data, test_data, src_vocab, trg_vocab = load_data(
        data_cfg=cfg["data"])
    # The training data is filtered to include sentences up to `max_sent_length`
    #     on source and target side.

    # training config:
    train_config = cfg["training"]
    shuffle = train_config.get("shuffle", True)
    batch_size = train_config["batch_size"]
    mini_BATCH_SIZE = train_config["mini_batch_size"]
    batch_type = train_config.get("batch_type", "sentence")
    outer_epochs = train_config.get("outer_epochs", 10)
    inner_epochs = train_config.get("inner_epochs", 10)
    TARGET_UPDATE = train_config.get("target_update", 10)
    Gamma = train_config.get("Gamma", 0.999)
    use_cuda = train_config["use_cuda"] and torch.cuda.is_available()

    # validation config
    validation_freq = train_config.get("validation_freq", 1000)
    ckpt_queue = queue.Queue(maxsize=train_config.get("keep_last_ckpts", 5))
    eval_batch_size = train_config.get("eval_batch_size", batch_size)
    level = cfg["data"]["level"]

    eval_metric = train_config.get("eval_metric", "bleu")
    n_gpu = torch.cuda.device_count() if use_cuda else 0
    eval_batch_type = train_config.get("eval_batch_type", batch_type)
    # eval options
    test_config = cfg["testing"]
    bpe_type = test_config.get("bpe_type", "subword-nmt")
    sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
    max_output_length = train_config.get("max_output_length", None)
    minimize_metric = True
    # initialize training statistics
    stats = TrainStatistics(
        steps=0,
        stop=False,
        total_tokens=0,
        best_ckpt_iter=0,
        best_ckpt_score=np.inf if minimize_metric else -np.inf,
        minimize_metric=minimize_metric)

    early_stopping_metric = train_config.get("early_stopping_metric",
                                             "eval_metric")

    # decide whether the early stopping metric is minimized or maximized
    if early_stopping_metric in ["ppl", "loss"]:
        stats.minimize_metric = True
        stats.best_ckpt_score = np.inf
    elif early_stopping_metric == "eval_metric":
        if eval_metric in [
                "bleu", "chrf", "token_accuracy", "sequence_accuracy"
        ]:
            stats.minimize_metric = False
            stats.best_ckpt_score = -np.inf

        # eval metric that has to get minimized (not yet implemented)
        else:
            stats.minimize_metric = True

    # data loader (modified from train_and_validate function)
    train_iter = make_data_iter(train_data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                train=True,
                                shuffle=shuffle)

    # initialize the Replay Memory D with capacity N
    memory = ReplayMemory(10000)
    steps_done = 0

    # initialize two DQN networks
    policy_net = build_model(cfg["model"],
                             src_vocab=src_vocab,
                             trg_vocab=trg_vocab)  # Q_network
    target_net = build_model(cfg["model"],
                             src_vocab=src_vocab,
                             trg_vocab=trg_vocab)  # Q_hat_network
    if use_cuda:
        policy_net.cuda()
        target_net.cuda()

    # Initialize target net Q_hat with weights equal to policy_net
    target_net.load_state_dict(policy_net.state_dict())

    # target_net is never optimized directly, keep it in test mode
    target_net.eval()

    # Optimizer
    optimizer = build_optimizer(config=cfg["training"],
                                parameters=policy_net.parameters())
    # Loss function
    mse_loss = torch.nn.MSELoss()

    pad_index = policy_net.pad_index

    cross_entropy_loss = XentLoss(pad_index=pad_index)
    policy_net.loss_function = cross_entropy_loss

    # learning rate scheduling; the mode follows the direction derived from
    # the early stopping metric above, not the hard-coded initial value
    scheduler, scheduler_step_at = build_scheduler(
        config=train_config,
        scheduler_mode="min" if stats.minimize_metric else "max",
        optimizer=optimizer,
        hidden_size=cfg["model"]["encoder"]["hidden_size"])

    # model parameters
    if "load_model" in train_config:
        load_model_path = train_config["load_model"]
        reset_best_ckpt = train_config.get("reset_best_ckpt", False)
        reset_scheduler = train_config.get("reset_scheduler", False)
        reset_optimizer = train_config.get("reset_optimizer", False)
        reset_iter_state = train_config.get("reset_iter_state", False)

        print('settings', reset_best_ckpt, reset_iter_state, reset_optimizer,
              reset_scheduler)

        logger.info("Loading model from %s", load_model_path)
        model_checkpoint = load_checkpoint(path=load_model_path,
                                           use_cuda=use_cuda)

        # restore model and optimizer parameters
        policy_net.load_state_dict(model_checkpoint["model_state"])

        if not reset_optimizer:
            optimizer.load_state_dict(model_checkpoint["optimizer_state"])
        else:
            logger.info("Reset optimizer.")
        if not reset_scheduler:
            # checkpoints written further below carry no "scheduler_state"
            # entry, so a missing key must not abort the resume
            scheduler_state = model_checkpoint.get("scheduler_state")
            if scheduler_state is not None and scheduler is not None:
                scheduler.load_state_dict(scheduler_state)
        else:
            logger.info("Reset scheduler.")

        if not reset_best_ckpt:
            stats.best_ckpt_score = model_checkpoint["best_ckpt_score"]
            stats.best_ckpt_iter = model_checkpoint["best_ckpt_iteration"]
            print('stats.best_ckpt_score', stats.best_ckpt_score)
            print('stats.best_ckpt_iter', stats.best_ckpt_iter)
        else:
            logger.info("Reset tracking of the best checkpoint.")

        if (not reset_iter_state and model_checkpoint.get(
                'train_iter_state', None) is not None):
            # NOTE(review): the iterator state is read here but never applied
            # to `train_iter` anywhere in this function.
            train_iter_state = model_checkpoint["train_iter_state"]

        # Re-initialize target net Q_hat with the restored policy weights
        target_net.load_state_dict(policy_net.state_dict())
        target_net.eval()

        # move parameters to cuda
        if use_cuda:
            policy_net.cuda()
            target_net.cuda()

    # Outer loop
    for i_episode in range(outer_epochs):

        # get batch
        for batch in iter(train_iter):

            # create a Batch object from torchtext batch
            # (use class Batch from batch.py);
            # sentences in one batch are padded to the same length.
            # Source is represented by word indices, not word embeddings.
            batch = Batch(batch, policy_net.pad_index, use_cuda=use_cuda)

            encoder_output_batch, _, _, _ = policy_net(
                return_type="encode",
                src=batch.src,
                src_length=batch.src_length,
                src_mask=batch.src_mask,
            )

            trans_output_batch, _ = transformer_greedy(
                src_mask=batch.src_mask,
                max_output_length=max_output_length,
                model=policy_net,
                encoder_output=encoder_output_batch,
                steps_done=steps_done,
                use_cuda=use_cuda)

            steps_done += 1

            print('batch.trg.shape is:', batch.trg.shape)
            print('trans_output_batch', trans_output_batch)

            # Get the reward_batch
            # (the bleu score of each sentence in the batch)
            reward_batch = []
            for sent_idx in range(int(batch.src.shape[0])):
                # the first output position is dropped before scoring
                all_outputs = [(trans_output_batch[sent_idx])[1:]]
                all_ref = [batch.trg[sent_idx]]
                sentence_score = calculate_bleu(model=policy_net,
                                                level=level,
                                                raw_hypo=all_outputs,
                                                raw_ref=all_ref)
                reward_batch.append(sentence_score)

            print('reward batch is', reward_batch)
            reward_batch = torch.tensor(reward_batch, dtype=torch.float)

            # make prefix and push tuples into memory
            push_sample_to_memory(model=policy_net,
                                  level=level,
                                  eos_index=policy_net.eos_index,
                                  memory=memory,
                                  src_batch=batch.src,
                                  trg_batch=batch.trg,
                                  trans_output_batch=trans_output_batch,
                                  reward_batch=reward_batch,
                                  max_output_length=max_output_length)
            print(memory.capacity, len(memory.memory))

            # the network is only updated once the replay memory is full
            if len(memory.memory) == memory.capacity:
                # inner loop
                for t in range(inner_epochs):
                    # Sample mini-batch from the memory
                    transitions = memory.sample(mini_BATCH_SIZE)
                    # merge the fields of all sampled transitions: every
                    # attribute of mini_batch is a tuple with one entry per
                    # sampled transition
                    mini_batch = Transition(*zip(*transitions))

                    # concatenate together into a tensor.
                    mini_next_word = torch.cat([
                        word.unsqueeze(0) for word in mini_batch.next_word
                    ])  # shape (mini_BATCH_SIZE,)
                    mini_reward = torch.tensor(
                        mini_batch.reward)  # shape (mini_BATCH_SIZE,)

                    mini_is_eos = torch.Tensor(mini_batch.finish)

                    mini_src_length = torch.Tensor(
                        [len(item) for item in mini_batch.source_sentence])

                    mini_src = pad_sequence(mini_batch.source_sentence,
                                            batch_first=True,
                                            padding_value=float(pad_index))
                    # shape (mini_BATCH_SIZE, max_length_src)

                    length_prefix = [len(item) for item in mini_batch.prefix]
                    mini_prefix_length = torch.Tensor(length_prefix)

                    prefix_list = [
                        torch.from_numpy(prefix_)
                        for prefix_ in mini_batch.prefix
                    ]

                    mini_prefix = pad_sequence(prefix_list,
                                               batch_first=True,
                                               padding_value=pad_index)
                    # shape (mini_BATCH_SIZE, max_length_prefix)

                    mini_src_mask = (mini_src != pad_index).unsqueeze(1)
                    mini_trg_mask = (mini_prefix != pad_index).unsqueeze(1)

                    if use_cuda:
                        mini_src = mini_src.cuda()
                        mini_prefix = mini_prefix.cuda()
                        mini_src_mask = mini_src_mask.cuda()
                        mini_src_length = mini_src_length.cuda()
                        mini_trg_mask = mini_trg_mask.cuda()
                        mini_next_word = mini_next_word.cuda()

                    # calculate the Q_value
                    logits_Q, _, _, _ = policy_net._encode_decode(
                        src=mini_src,
                        trg_input=mini_prefix,
                        src_mask=mini_src_mask,
                        src_length=mini_src_length,
                        trg_mask=mini_trg_mask)

                    # pick the output at the last real prefix position:
                    # (mini_BATCH_SIZE, seq_len, vocab)
                    #   -> (mini_BATCH_SIZE, vocab)
                    logits_Q = logits_Q[range(mini_BATCH_SIZE),
                                        mini_prefix_length.long() - 1, :]
                    # Q value of the action (next word) that was taken
                    Q_value = logits_Q[range(mini_BATCH_SIZE), mini_next_word]

                    # NOTE(review): the next word is appended after the
                    # *padded* prefix, so for prefixes shorter than the
                    # longest one position `prefix_length` (read below) holds
                    # a pad token rather than the next word - confirm this
                    # is intended.
                    mini_prefix_add = torch.cat(
                        [mini_prefix, mini_next_word.unsqueeze(1)], dim=1)
                    mini_trg_mask_add = (mini_prefix_add !=
                                         pad_index).unsqueeze(1)

                    if use_cuda:
                        mini_prefix_add = mini_prefix_add.cuda()
                        mini_trg_mask_add = mini_trg_mask_add.cuda()

                    logits_Q_hat, _, _, _ = target_net._encode_decode(
                        src=mini_src,
                        trg_input=mini_prefix_add,
                        src_mask=mini_src_mask,
                        src_length=mini_src_length,
                        trg_mask=mini_trg_mask_add)
                    logits_Q_hat = logits_Q_hat[range(mini_BATCH_SIZE),
                                                mini_prefix_length.long(), :]
                    Q_hat_value, _ = torch.max(logits_Q_hat, dim=1)

                    if use_cuda:
                        Q_hat_value = Q_hat_value.cuda()
                        mini_reward = mini_reward.cuda()
                        mini_is_eos = mini_is_eos.cuda()

                    # The target is a constant for the optimization, so it is
                    # cut off from the graph of the target network; detach()
                    # is not in-place, its result has to be kept.
                    yj = (mini_reward.float() + Gamma * Q_hat_value).detach()

                    # Finished transitions have no successor state: their
                    # target is the reward only. The flags are applied as a
                    # boolean mask; using their 0/1 values as integer indices
                    # would only ever touch rows 0 and 1.
                    # assumes `finish` holds 0/1 (or bool) end-of-sequence
                    # flags - TODO confirm against push_sample_to_memory
                    done_mask = mini_is_eos != 0
                    yj[done_mask] = mini_reward[done_mask].float()

                    # Optimize the model
                    policy_net.zero_grad()

                    # Compute loss
                    loss = mse_loss(yj, Q_value)
                    print('loss', loss)
                    logger.info("step = {}, loss = {}".format(
                        stats.steps, loss.item()))
                    loss.backward()
                    optimizer.step()

                    stats.steps += 1

                    # periodically synchronise the target network
                    if stats.steps % TARGET_UPDATE == 0:
                        target_net.load_state_dict(policy_net.state_dict())

                    if stats.steps % validation_freq == 0:  # Validation
                        print('Start validation')

                        valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores = \
                            validate_on_data(
                                model=policy_net,
                                data=dev_data,
                                batch_size=eval_batch_size,
                                use_cuda=use_cuda,
                                level=level,
                                eval_metric=eval_metric,
                                n_gpu=n_gpu,
                                compute_loss=True,
                                beam_size=1,
                                beam_alpha=-1,
                                batch_type=eval_batch_type,
                                postprocess=True,
                                bpe_type=bpe_type,
                                sacrebleu=sacrebleu,
                                max_output_length=max_output_length
                            )
                        print(
                            'validation_loss: {}, validation_score: {}'.format(
                                valid_loss, valid_score))
                        logger.info(valid_loss)
                        print('average loss: total_loss/n_tokens:', valid_ppl)

                        if early_stopping_metric == "loss":
                            ckpt_score = valid_loss
                        elif early_stopping_metric in ["ppl", "perplexity"]:
                            ckpt_score = valid_ppl
                        else:
                            ckpt_score = valid_score
                        if stats.is_best(ckpt_score):
                            stats.best_ckpt_score = ckpt_score
                            stats.best_ckpt_iter = stats.steps
                            logger.info(
                                'Hooray! New best validation result [%s]!',
                                early_stopping_metric)
                            if ckpt_queue.maxsize > 0:
                                logger.info("Saving new checkpoint.")

                                # Save the model's current parameters and the
                                # training state to a checkpoint. The training
                                # state contains the number of training steps
                                # and tokens, the best checkpoint score and
                                # iteration so far, and the optimizer state.
                                # NOTE(review): scheduler and amp states are
                                # not part of the checkpoint.
                                model_path = "{}/{}.ckpt".format(
                                    model_dir, stats.steps)
                                model_state_dict = policy_net.module.state_dict() \
                                    if isinstance(policy_net, torch.nn.DataParallel) \
                                    else policy_net.state_dict()
                                state = {
                                    "steps": stats.steps,
                                    "total_tokens": stats.total_tokens,
                                    "best_ckpt_score": stats.best_ckpt_score,
                                    "best_ckpt_iteration":
                                    stats.best_ckpt_iter,
                                    "model_state": model_state_dict,
                                    "optimizer_state": optimizer.state_dict(),
                                }
                                torch.save(state, model_path)
                                if ckpt_queue.full():
                                    # delete oldest ckpt
                                    to_delete = ckpt_queue.get()
                                    try:
                                        os.remove(to_delete)
                                    except FileNotFoundError:
                                        logger.warning(
                                            "Wanted to delete old checkpoint %s but "
                                            "file does not exist.", to_delete)

                                ckpt_queue.put(model_path)

                                best_path = "{}/best.ckpt".format(model_dir)
                                try:
                                    # create/modify symbolic link for best checkpoint
                                    symlink_update(
                                        "{}.ckpt".format(stats.steps),
                                        best_path)
                                except OSError:
                                    # overwrite best.ckpt
                                    torch.save(state, best_path)
Example #7
0
class TrainManager:
    """ Manages training loop, validations, learning rate scheduling
    and early stopping."""
    def __init__(self, model: Model, config: dict) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        Reads the "training" section of `config` (plus `config["data"]["level"]`
        and `config["model"]["encoder"]["hidden_size"]`), creates the model
        directory, the logger and the Tensorboard writer, and builds the loss,
        optimizer, gradient clipper, learning rate scheduler and scheduled
        sampling function. If "load_model" is configured, the training state
        is restored from that checkpoint at the end of the setup.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        :raises ConfigurationError: for an invalid "normalization",
            "eval_metric", "early_stopping_metric" or data "level" setting
        """
        train_config = config["training"]

        # files for logging and storing
        self.model_dir = make_model_dir(train_config["model_dir"],
                                        overwrite=train_config.get(
                                            "overwrite", False))
        self.logger = make_logger(model_dir=self.model_dir)
        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = "{}/validations.txt".format(self.model_dir)
        self.tb_writer = SummaryWriter(log_dir=self.model_dir +
                                       "/tensorboard/")

        # model
        self.model = model
        self.pad_index = self.model.pad_index
        self.bos_index = self.model.bos_index
        # needs self.logger and self.model, so it has to come after both
        self._log_parameters_list()

        # objective
        self.label_smoothing = train_config.get("label_smoothing", 0.0)
        self.loss = XentLoss(pad_index=self.pad_index,
                             smoothing=self.label_smoothing)
        self.normalization = train_config.get("normalization", "batch")
        if self.normalization not in ["batch", "tokens"]:
            raise ConfigurationError("Invalid normalization. "
                                     "Valid options: 'batch', 'tokens'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)

        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(config=train_config,
                                         parameters=model.parameters())

        # validation & early stopping
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        # holds the paths of the checkpoints that are currently kept on disk
        self.ckpt_queue = queue.Queue(
            maxsize=train_config.get("keep_last_ckpts", 5))
        self.eval_metric = train_config.get("eval_metric", "bleu")
        if self.eval_metric not in [
                'bleu', 'chrf', 'sequence_accuracy', 'token_accuracy'
        ]:
            # the message lists every value accepted by the check above
            raise ConfigurationError(
                "Invalid setting for 'eval_metric', "
                "valid options: 'bleu', 'chrf', 'sequence_accuracy', "
                "'token_accuracy'.")
        self.early_stopping_metric = train_config.get("early_stopping_metric",
                                                      "eval_metric")

        # if we schedule after BLEU/chrf, we want to maximize it, else minimize
        # early_stopping_metric decides on how to find the early stopping point:
        # ckpts are written when there's a new high/low score for this metric
        if self.early_stopping_metric in ["ppl", "loss"]:
            self.minimize_metric = True
        elif self.early_stopping_metric == "eval_metric":
            if self.eval_metric in [
                    "bleu", "chrf", "sequence_accuracy", "token_accuracy"
            ]:
                self.minimize_metric = False
            else:  # eval metric that has to get minimized (not yet implemented)
                self.minimize_metric = True
        else:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', 'eval_metric'.")

        # learning rate scheduling
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=config["model"]["encoder"]["hidden_size"])

        # scheduled sampling
        self.scheduled_sampling = build_scheduled_sampling(config=train_config)
        self.minibatch_count = 0

        # data & batch handling
        self.level = config["data"]["level"]
        if self.level not in ["word", "bpe", "char"]:
            raise ConfigurationError("Invalid segmentation level. "
                                     "Valid options: 'word', 'bpe', 'char'.")
        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        self.batch_type = train_config.get("batch_type", "sentence")
        # evaluation falls back to the training batch settings
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        # number of batches whose gradients are accumulated per update
        self.batch_multiplier = train_config.get("batch_multiplier", 1)

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU
        self.use_cuda = train_config["use_cuda"]
        self.logger.info("Training with or without cuda=%s",
                         str(self.use_cuda))
        if self.use_cuda:
            self.model.cuda()
            self.loss.cuda()

        # forwarded to `validate_on_data` as `report_on_canonicals`
        self.report_entf1_on_canonicals = train_config.get(
            "report_entf1_on_canonicals", False)

        # initialize training statistics
        self.steps = 0
        # stop training if this flag is True by reaching learning rate minimum
        self.stop = False
        self.total_tokens = 0
        self.best_ckpt_iteration = 0
        # initial values for best scores
        self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf
        # comparison function for scores
        self.is_best = lambda score: score < self.best_ckpt_score \
            if self.minimize_metric else score > self.best_ckpt_score

        # model parameters
        if "load_model" in train_config:
            model_load_path = train_config["load_model"]
            self.logger.info("Loading model from %s", model_load_path)
            self.init_from_checkpoint(model_load_path)

        # the timer object is owned by the decoder; keep a handle for logging
        self.manage_decoder_timer = train_config.get("manage_decoder_timer",
                                                     True)
        if self.manage_decoder_timer:
            self.decoder_timer = self.model.decoder.timer

    def _save_checkpoint(self) -> None:
        """
        Save the model's current parameters and the training state to a
        checkpoint.

        The training state contains the total number of training steps,
        the total number of training tokens,
        the best checkpoint score and iteration so far,
        and optimizer and scheduler states.

        The checkpoint is written to `<model_dir>/<steps>.ckpt`. If the
        checkpoint queue is full, the oldest checkpoint file is deleted first.
        Afterwards `<model_dir>/best.ckpt` is pointed at the new checkpoint;
        if the link cannot be created, the state is saved to `best.ckpt`
        directly.
        """
        model_path = "{}/{}.ckpt".format(self.model_dir, self.steps)
        state = {
            "steps": self.steps,
            "total_tokens": self.total_tokens,
            "best_ckpt_score": self.best_ckpt_score,
            "best_ckpt_iteration": self.best_ckpt_iteration,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict()
                               if self.scheduler is not None else None,
        }
        torch.save(state, model_path)
        if self.ckpt_queue.full():
            to_delete = self.ckpt_queue.get()  # delete oldest ckpt
            try:
                os.remove(to_delete)
            except FileNotFoundError:
                self.logger.warning(
                    "Wanted to delete old checkpoint %s but "
                    "file does not exist.", to_delete)

        self.ckpt_queue.put(model_path)

        best_path = "{}/best.ckpt".format(self.model_dir)
        try:
            # create/modify symbolic link for best checkpoint
            symlink_update("{}.ckpt".format(self.steps), best_path)
        except OSError:
            # Linking failed (presumably a platform or file system without
            # symlink support): fall back to a full copy so that best.ckpt
            # exists and training continues.
            torch.save(state, best_path)

    def init_from_checkpoint(self, path: str) -> None:
        """
        Initialize the trainer from a given checkpoint file.

        This checkpoint file contains not only model parameters, but also
        scheduler and optimizer states, see `self._save_checkpoint`.
        A checkpoint without a "scheduler_state" entry is accepted; the
        scheduler then keeps its freshly built state.

        :param path: path to checkpoint
        """
        model_checkpoint = load_checkpoint(path=path, use_cuda=self.use_cuda)

        # restore model and optimizer parameters
        self.model.load_state_dict(model_checkpoint["model_state"])

        self.optimizer.load_state_dict(model_checkpoint["optimizer_state"])

        # `.get` instead of indexing: checkpoints written without the
        # scheduler entry must not fail with a KeyError here
        # (assumes the loaded checkpoint is a dict, as `_save_checkpoint`
        # writes one)
        scheduler_state = model_checkpoint.get("scheduler_state")
        if scheduler_state is not None and self.scheduler is not None:
            self.scheduler.load_state_dict(scheduler_state)

        # restore counts
        self.steps = model_checkpoint["steps"]
        self.total_tokens = model_checkpoint["total_tokens"]
        self.best_ckpt_score = model_checkpoint["best_ckpt_score"]
        self.best_ckpt_iteration = model_checkpoint["best_ckpt_iteration"]

        # move parameters to cuda
        if self.use_cuda:
            self.model.cuda()

    # the typing here lies
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset, kb_task=None, train_kb: TranslationDataset =None,\
        train_kb_lkp: list = [], train_kb_lens: list = [], train_kb_truvals: TranslationDataset=None, valid_kb: Tuple=None, \
        valid_kb_lkp: list=[], valid_kb_lens: list = [], valid_kb_truvals:list=[],
        valid_data_canon: list=[]) \
            -> None:
        """
        Train the model and validate it from time to time on the validation set.

        Runs `self.epochs` epochs over the training data. Gradients are
        accumulated over `self.batch_multiplier` batches per parameter update.
        Every `self.validation_freq` updates the model is validated; a new
        best score triggers a checkpoint. Training stops early once
        `self.stop` is set (done by `_add_report` when the learning rate
        drops below `self.learning_rate_min`).

        NOTE(review): the list parameters use mutable default arguments;
        they are only passed on in this method, but the callees could
        mutate the shared default objects.

        :param train_data: training data
        :param valid_data: validation data
        :param kb_task: if truthy, batches are built with knowledge base data
        :param train_kb: TranslationDataset holding the loaded train kb data
        :param train_kb_lkp: List with train example index to corresponding kb indices
        :param train_kb_lens: List with num of triples per kb
        :param train_kb_truvals: passed through to `make_data_iter_kb`
        :param valid_kb: TranslationDataset holding the loaded valid kb data
        :param valid_kb_lkp: List with valid example index to corresponding kb indices
        :param valid_kb_lens: List with num of triples per kb
        :param valid_kb_truvals: passed through to `validate_on_data` and to
            the kb attention plots
        :param valid_data_canon: passed through to `validate_on_data`
            (original note: required to report loss)
        """

        # the kb task needs an iterator that also yields the kb fields
        if kb_task:
            train_iter = make_data_iter_kb(train_data,
                                           train_kb,
                                           train_kb_lkp,
                                           train_kb_lens,
                                           train_kb_truvals,
                                           batch_size=self.batch_size,
                                           batch_type=self.batch_type,
                                           train=True,
                                           shuffle=self.shuffle,
                                           canonize=self.model.canonize)
        else:
            train_iter = make_data_iter(train_data,
                                        batch_size=self.batch_size,
                                        batch_type=self.batch_type,
                                        train=True,
                                        shuffle=self.shuffle)

        # NOTE(review): anomaly detection stays enabled for the whole
        # training run; it is a debugging aid and presumably slows training
        # down - confirm whether it is still wanted.
        with torch.autograd.set_detect_anomaly(True):
            for epoch_no in range(self.epochs):
                self.logger.info("EPOCH %d", epoch_no + 1)

                if self.scheduler is not None and self.scheduler_step_at == "epoch":
                    self.scheduler.step(epoch=epoch_no)

                self.model.train()

                start = time.time()
                total_valid_duration = 0
                processed_tokens = self.total_tokens
                # countdown to the next update; reaches 0 on every
                # `batch_multiplier`-th batch
                count = self.batch_multiplier - 1
                epoch_loss = 0

                for batch in iter(train_iter):
                    # reactivate training
                    self.model.train()

                    # create a Batch object from torchtext batch
                    batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda) if not kb_task else \
                        Batch_with_KB(batch, self.pad_index, use_cuda=self.use_cuda)

                    # sanity check: kb batches must carry the kb fields
                    if kb_task:
                        assert hasattr(batch, "kbsrc"), dir(batch)
                        assert hasattr(batch, "kbtrg"), dir(batch)
                        assert hasattr(batch, "kbtrv"), dir(batch)

                    # only update every batch_multiplier batches
                    # see https://medium.com/@davidlmorton/
                    # increasing-mini-batch-size-without-increasing-
                    # memory-6794e10db672
                    update = count == 0

                    batch_loss = self._train_batch(batch, update=update)

                    if update:
                        self.tb_writer.add_scalar("train/train_batch_loss",
                                                  batch_loss, self.steps)

                    # restart the countdown after an update
                    count = self.batch_multiplier if update else count
                    count -= 1
                    epoch_loss += batch_loss.detach().cpu().numpy()

                    if self.scheduler is not None and \
                            self.scheduler_step_at == "step" and update:
                        self.scheduler.step()

                    # log learning progress
                    if self.steps % self.logging_freq == 0 and update:
                        # time spent validating is excluded from throughput
                        elapsed = time.time() - start - total_valid_duration
                        elapsed_tokens = self.total_tokens - processed_tokens
                        self.logger.info(
                            "Epoch %3d Step: %8d Batch Loss: %12.6f "
                            "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                            self.steps, batch_loss, elapsed_tokens / elapsed,
                            self.optimizer.param_groups[0]["lr"])
                        start = time.time()
                        total_valid_duration = 0

                    # validate on the entire dev set
                    if self.steps % self.validation_freq == 0 and update:

                        # report and reset the timer so that training and
                        # validation statistics are kept apart
                        if self.manage_decoder_timer:
                            self._log_decoder_timer_stats("train")
                            self.decoder_timer.reset()

                        valid_start_time = time.time()


                        valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                            valid_hypotheses_raw, valid_attention_scores, valid_kb_att_scores, \
                            valid_ent_f1, valid_ent_mcc = \
                            validate_on_data(
                                batch_size=self.eval_batch_size,
                                data=valid_data,
                                eval_metric=self.eval_metric,
                                level=self.level,
                                model=self.model,
                                use_cuda=self.use_cuda,
                                max_output_length=self.max_output_length,
                                loss_function=self.loss,
                                beam_size=0,  # greedy validations #FIXME XXX NOTE TODO BUG set to 0 again!
                                batch_type=self.eval_batch_type,
                                kb_task=kb_task,
                                valid_kb=valid_kb,
                                valid_kb_lkp=valid_kb_lkp,
                                valid_kb_lens=valid_kb_lens,
                                valid_kb_truvals=valid_kb_truvals,
                                valid_data_canon=valid_data_canon,
                                report_on_canonicals=self.report_entf1_on_canonicals
                            )

                        if self.manage_decoder_timer:
                            self._log_decoder_timer_stats("valid")
                            self.decoder_timer.reset()

                        self.tb_writer.add_scalar("valid/valid_loss",
                                                  valid_loss, self.steps)
                        self.tb_writer.add_scalar("valid/valid_score",
                                                  valid_score, self.steps)
                        self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                                  self.steps)

                        # select the score that decides about checkpoints
                        if self.early_stopping_metric == "loss":
                            ckpt_score = valid_loss
                        elif self.early_stopping_metric in [
                                "ppl", "perplexity"
                        ]:
                            ckpt_score = valid_ppl
                        else:
                            ckpt_score = valid_score

                        new_best = False
                        if self.is_best(ckpt_score):
                            self.best_ckpt_score = ckpt_score
                            self.best_ckpt_iteration = self.steps
                            self.logger.info(
                                'Hooray! New best validation result [%s]!',
                                self.early_stopping_metric)
                            # a queue size of 0 disables checkpoint saving
                            if self.ckpt_queue.maxsize > 0:
                                self.logger.info("Saving new checkpoint.")
                                new_best = True
                                self._save_checkpoint()

                        if self.scheduler is not None \
                                and self.scheduler_step_at == "validation":
                            self.scheduler.step(ckpt_score)

                        # append to validation report
                        # (this call also sets self.stop once the learning
                        # rate is below the configured minimum)
                        self._add_report(valid_score=valid_score,
                                         valid_loss=valid_loss,
                                         valid_ppl=valid_ppl,
                                         eval_metric=self.eval_metric,
                                         valid_ent_f1=valid_ent_f1,
                                         valid_ent_mcc=valid_ent_mcc,
                                         new_best=new_best)

                        # pylint: disable=unnecessary-comprehension
                        self._log_examples(
                            sources_raw=[v for v in valid_sources_raw],
                            sources=valid_sources,
                            hypotheses_raw=valid_hypotheses_raw,
                            hypotheses=valid_hypotheses,
                            references=valid_references)

                        valid_duration = time.time() - valid_start_time
                        total_valid_duration += valid_duration
                        self.logger.info(
                            'Validation result at epoch %3d, step %8d: %s: %6.2f, '
                            'loss: %8.4f, ppl: %8.4f, duration: %.4fs',
                            epoch_no + 1, self.steps, self.eval_metric,
                            valid_score, valid_loss, valid_ppl, valid_duration)

                        # store validation set outputs
                        self._store_outputs(valid_hypotheses)

                        valid_src = list(valid_data.src)
                        # store attention plots for selected valid sentences
                        if valid_attention_scores:
                            plot_success_ratio = store_attention_plots(
                                attentions=valid_attention_scores,
                                targets=valid_hypotheses_raw,
                                sources=valid_src,
                                indices=self.log_valid_sents,
                                output_prefix="{}/att.{}".format(
                                    self.model_dir, self.steps),
                                tb_writer=self.tb_writer,
                                steps=self.steps)
                            self.logger.info(
                                f"stored {plot_success_ratio} valid att scores!"
                            )
                        # same for the knowledge base attention, if there is any
                        if valid_kb_att_scores:
                            plot_success_ratio = store_attention_plots(
                                attentions=valid_kb_att_scores,
                                targets=valid_hypotheses_raw,
                                sources=list(valid_kb.kbsrc),
                                indices=self.log_valid_sents,
                                output_prefix="{}/kbatt.{}".format(
                                    self.model_dir, self.steps),
                                tb_writer=self.tb_writer,
                                steps=self.steps,
                                kb_info=(valid_kb_lkp, valid_kb_lens,
                                         valid_kb_truvals),
                                on_the_fly_info=(valid_src, valid_kb,
                                                 self.model.canonize,
                                                 self.model.trg_vocab))
                            self.logger.info(
                                f"stored {plot_success_ratio} valid kb att scores!"
                            )
                        else:
                            self.logger.info(
                                "theres no valid kb att scores...")
                    # leave the batch loop; the epoch loop is left below
                    if self.stop:
                        break
                if self.stop:
                    self.logger.info(
                        'Training ended since minimum lr %f was reached.',
                        self.learning_rate_min)
                    break

                self.logger.info('Epoch %3d: total training loss %.2f',
                                 epoch_no + 1, epoch_loss)
            else:
                # for-else: only reached when the epoch loop was not left
                # via `break`, i.e. all epochs were run.
                # NOTE(review): with self.epochs == 0 the loop body never
                # runs and `epoch_no` is unbound here.
                self.logger.info('Training ended after %3d epochs.',
                                 epoch_no + 1)
            self.logger.info('Best validation result at step %8d: %6.2f %s.',
                             self.best_ckpt_iteration, self.best_ckpt_score,
                             self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer

    def _train_batch(self, batch: Batch, update: bool = True) -> Tensor:
        """
        Train the model on one batch: Compute the loss, make a gradient step.

        :param batch: training batch
        :param update: if False, only store gradient. if True also make update
        :return: loss for batch (sum)
        """

        batch_loss = self.model.get_loss_for_batch(batch=batch,
                                                   loss_function=self.loss,
                                                   e_i=self.scheduled_sampling(
                                                       self.minibatch_count))

        # normalize batch loss
        if self.normalization == "batch":
            normalizer = batch.nseqs
        elif self.normalization == "tokens":
            normalizer = batch.ntokens
        else:
            raise NotImplementedError("Only normalize by 'batch' or 'tokens'")

        norm_batch_loss = batch_loss / normalizer
        # division needed since loss.backward sums the gradients until updated
        norm_batch_multiply = norm_batch_loss / self.batch_multiplier

        # compute gradients
        norm_batch_multiply.backward()

        if self.clip_grad_fun is not None:
            # clip gradients (in-place)
            self.clip_grad_fun(params=self.model.parameters())

        if update:
            # make gradient step
            self.optimizer.step()
            self.optimizer.zero_grad()

            # increment step counter
            self.steps += 1

        # increment token counter
        self.total_tokens += batch.ntokens

        # increment minibatch count for scheduled sampling
        self.minibatch_count += 1

        return norm_batch_loss

    def _add_report(self,
                    valid_score: float,
                    valid_ppl: float,
                    valid_loss: float,
                    eval_metric: str,
                    valid_ent_f1: float = None,
                    valid_ent_mcc: float = None,
                    new_best: bool = False) -> None:
        """
        Append a one-line report to validation logging file.

        :param valid_score: validation evaluation score [eval_metric]
        :param valid_ppl: validation perplexity
        :param valid_loss: validation loss (sum over whole validation set)
        :param eval_metric: evaluation metric, e.g. "bleu"
        :param valid_ent_f1: average validation entity f1
        :param new_best: whether this is a new best model
        """
        current_lr = -1
        # ignores other param groups for now
        for param_group in self.optimizer.param_groups:
            current_lr = param_group['lr']

        if current_lr < self.learning_rate_min:
            self.stop = True

        with open(self.valid_report_file, 'a') as opened_file:
            opened_file.write(
                "Steps: {}\tLoss: {:.5f}\tPPL: {:.5f}\t{}: {:.5f}\t"
                "LR: {:.8f}\tmbtch: {}\teps_i: {:.5f}\tentF1: {:.5f}\tentMCC: {:.5f}\t {}\n"
                .format(self.steps, valid_loss, valid_ppl, eval_metric,
                        valid_score, current_lr, self.minibatch_count,
                        self.scheduled_sampling(self.minibatch_count),
                        valid_ent_f1, valid_ent_mcc, "*" if new_best else ""))

    def _log_decoder_timer_stats(self, task):
        """
        Write decoder timer stats to log, for the given task.
        """
        assert hasattr(
            self, "decoder_timer"
        ), f"to log decoder timer stats, make sure we have a timer"

        stats = self.decoder_timer.logAllParams()
        stats_lines = "".join([
            f"{activity}: n={n}, avg={avg}\n"
            for activity, (n, avg) in stats.items()
        ])

        self.logger.info("Decoder Timer Stats for task %s:", task)
        self.logger.info("%s", stats_lines)

    def _log_parameters_list(self) -> None:
        """
        Write all model parameters (name, shape) to the log.
        """
        model_parameters = filter(lambda p: p.requires_grad,
                                  self.model.parameters())
        n_params = sum([np.prod(p.size()) for p in model_parameters])
        self.logger.info("Total params: %d", n_params)
        trainable_params = [
            n for (n, p) in self.model.named_parameters() if p.requires_grad
        ]
        self.logger.info("Trainable parameters: %s", sorted(trainable_params))
        assert trainable_params

    def _log_examples(self,
                      sources: List[str],
                      hypotheses: List[str],
                      references: List[str],
                      sources_raw: List[List[str]] = None,
                      hypotheses_raw: List[List[str]] = None,
                      references_raw: List[List[str]] = None) -> None:
        """
        Log a the first `self.log_valid_sents` sentences from given examples.

        :param sources: decoded sources (list of strings)
        :param hypotheses: decoded hypotheses (list of strings)
        :param references: decoded references (list of strings)
        :param sources_raw: raw sources (list of list of tokens)
        :param hypotheses_raw: raw hypotheses (list of list of tokens)
        :param references_raw: raw references (list of list of tokens)
        """
        for p in self.log_valid_sents:

            if p >= len(sources):
                continue

            try:
                self.logger.info("Example #%d", p)
            except Exception as e:
                self.logger.info(
                    "Example #%d",
                    "Encountered an Error while logging this example: " +
                    str(e)) + "; going to next example"
                print(
                    f"encountered an Error ({e})while logging this example:  {p}"
                )
                continue

            if sources_raw is not None:
                self.logger.debug("\tRaw source:     %s", sources_raw[p])
            if references_raw is not None:
                self.logger.debug("\tRaw reference:  %s", references_raw[p])
            if hypotheses_raw is not None:
                self.logger.debug("\tRaw hypothesis: %s", hypotheses_raw[p])

            self.logger.info("\tSource:     %s", sources[p])
            self.logger.info("\tReference:  %s", references[p])
            self.logger.info("\tHypothesis: %s", hypotheses[p])

    def _store_outputs(self, hypotheses: List[str]) -> None:
        """
        Write current validation outputs to file in `self.model_dir.`

        :param hypotheses: list of strings
        """
        current_valid_output_file = "{}/{}.hyps".format(
            self.model_dir, self.steps)
        with open(current_valid_output_file, 'w') as opened_file:
            for hyp in hypotheses:
                opened_file.write("{}\n".format(hyp))