def main():
    """
    Creates a temporary file for the given input which is
    used to create a dataset, that is then evaluated on the given model.
    The generated summary is printed to standard out.
    """
    args, unknown_args = prepare_arg_parser().parse_known_args()
    model_file = args.model_file

    with suppress_stdout_stderr():
        model, _optimizer, vocab, _stats, cfg = train.load_model(
            model_file, unknown_args
        )

    fd, filename = tempfile.mkstemp()
    try:
        with os.fdopen(fd, "w") as f:
            input_ = sys.stdin.read()
            article = preprocess.parse(input_)
            print(f"{article}\tSUMMARY_STUB", file=f)

        with suppress_stdout_stderr():
            dataset = Dataset(filename, vocab, cfg)

        batch = next(dataset.generator(1, cfg.pointer))

        # don't enforce a minimum summary length (useful for short command-line summaries)
        cfg.min_summary_length = 1
        bs = BeamSearch(model, cfg=cfg)
        summary = evaluate.batch_to_text(bs, batch)[0]
        print(f"SUMMARY:\n{summary}")
    finally:
        os.remove(filename)
class Trainer:
    """Trainer class to simplify training and saving a model"""
    def __init__(self, model, optimizer, vocab, cfg, stats=None):
        """
        Create a trainer instance for the given model.
        This includes creating datasets based on the `train_file` and
        `valid_file` in the given `cfg`.

        :param model: The Seq2Seq model to train
        :param optimizer: The optimizer for the `model`
        :param vocab: A `Vocabulary` instance to be used
        :param cfg: The current `Config` providing epochs, batch_size, file paths, etc.
        :param stats:
            Optional dict with values such as epoch/running_avg_loss etc. used
            when resuming training; defaults to all-zero stats
        """

        if stats is None:
            stats = defaultdict(int)

        self.cfg = cfg
        self.dataset = Dataset(cfg.train_file, vocab, cfg)
        self.validation_dataset = Dataset(cfg.valid_file,
                                          vocab,
                                          cfg,
                                          evaluation=True)

        self.model = model
        self.optimizer = optimizer
        self.validator = cfg.validator.lower()
        # if cfg.validate_every == 0, we validate once every epoch
        self.validate_every = (
            math.ceil(len(self.dataset) / cfg.batch_size)  # batches/epoch
            if cfg.validate_every == 0 else cfg.validate_every)
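        # e.g. with 10,000 training pairs, batch_size=16 and validate_every=0,
        # this yields ceil(10000 / 16) = 625, i.e. one validation per epoch
        # (numbers purely illustrative)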
        self.rouge_valid = self.validator == "rouge"
        self.loss_valid = self.validator == "loss"
        self.scheduler = (
            None if cfg.learning_rate_decay >= 1.0 else ReduceLROnPlateau(
                optimizer,
                mode="min" if self.validator == "loss" else "max",
                factor=cfg.learning_rate_decay,
                patience=cfg.learning_rate_patience,
                verbose=True,
            ))
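        # ReduceLROnPlateau multiplies the learning rate by `factor` once the
        # validation metric has not improved for `patience` validation steps;
        # mode="min" treats lower scores as better (loss), "max" higher (ROUGE).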
        self.coverage_loss_weight_decay = cfg.coverage_loss_weight_decay
        assert self.validator in ["loss", "rouge"]
        self.early_stopping = cfg.early_stopping
        self.patience = cfg.patience

        self.epoch = stats["epoch"]
        self.iteration = stats["iteration"]
        self.running_avg_loss = stats["running_avg_loss"]
        self.running_avg_cov_loss = stats["running_avg_cov_loss"]
        self.best_validation_score = stats["best_validation_score"]
        self.current_validation_score = stats.get("current_validation_score",
                                                  0)
        self.current_patience = stats.get("current_patience", 0)
        self.model_identifier = stats["model_identifier"]
        self.time_training = stats["time_training"]

        # Updated and managed from train function and context
        self.training_start_time = None
        self.writer = None
        self.pbar = None

    def __enter__(self):
        # Create summary writer for tensorboard
        log_dir = os.path.splitext(self.cfg.output_file)[0] + "_log"
        print(f"Tensorboard logging directory: {log_dir}")
        self.writer = SummaryWriter(log_dir)

        print(f"Training on {DEVICE.type.upper()}")

        # Create pbar instance for clean progress tracking
        # The total is whichever limit is reached first: epochs or iterations
        epoch_total = (math.ceil(len(self.dataset) / self.cfg.batch_size) *
                       self.cfg.epochs)
        total = min(epoch_total, self.cfg.iterations)
        self.pbar = tqdm(
            total=total,
            initial=self.iteration,
            desc="Training",
            postfix={
                "loss": self.running_avg_loss,
                "cov": self.running_avg_cov_loss
            },
            bar_format=
            "{desc}: {n_fmt}/{total_fmt}{postfix} [{elapsed},{rate_fmt}]",
            leave=True,
        )
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._update_time_training()
        self.save_model()
        self.writer.close()
        self.pbar.close()
        if DEVICE.type == "cuda":
            print("Emptying cuda cache...")
            torch.cuda.empty_cache()

    def _update_progress(self, loss, cov_loss):
        """Updates progress, both in tensorboardX, and in the progress bar"""
        self.writer.add_scalar("train/loss", loss, self.iteration)
        self.writer.add_scalar("train/cov_loss", cov_loss, self.iteration)
        self.writer.add_scalar("train/running_avg_loss", self.running_avg_loss,
                               self.iteration)
        self.writer.add_scalar("train/running_avg_cov_loss",
                               self.running_avg_cov_loss, self.iteration)
        self.pbar.update()
        postfix = {
            "loss": round(self.running_avg_loss, 2),
            "cov": round(self.running_avg_cov_loss, 2),
        }
        self.pbar.set_postfix(postfix)

    def _update_running_avg_loss(self, loss, cov_loss=0, decay=0.99):
        """
        Updates the running avg losses
        :param loss: The new loss to update the running loss based on
        :param cov_loss: The new optional cov_loss to update the running loss based on
        :param decay: Optional decay value for calculating the running avg (default 0.99)
        """
        self.running_avg_loss = (loss if self.running_avg_loss == 0 else
                                 self.running_avg_loss * decay +
                                 (1 - decay) * loss)
        self.running_avg_cov_loss = (cov_loss if self.running_avg_cov_loss == 0
                                     else self.running_avg_cov_loss * decay +
                                     (1 - decay) * cov_loss)

    def _update_time_training(self):
        """
        Update `self.time_training` based on `self.training_start_time`
        and on the current time. Resets `self.training_start_time` to the
        current time, to avoid including time intervals more than once in
        the total training time.
        """
        if self.training_start_time is not None:
            elapsed = time.time() - self.training_start_time
            self.time_training += elapsed
            # ensure we reset training_start_time since
            # this session has been added to time_training
            self.training_start_time = time.time()

    def _validation_improved(self, new, ref):
        """
        Check whether a new validation score is better than a reference score,
        according to the currently configured validator.

        :param new: The new score to check for improvement
        :param ref: The reference score to compare against
        :returns: True if `new` is an improvement over `ref`
        """
        return new > ref if self.rouge_valid else (new < ref or ref == 0)

    def _validate(self):
        """
        Run validation of the model using the method defined in `self.validator`.
        If the model evaluates to a better score than `self.best_validation_score`,
        it is saved to `self.cfg.output_file` with a '_best' suffix.

        :returns: `True` if we should early stop, otherwise `False`.
        """

        self.model.eval()
        new = (valid.get_validation_score(self.model, self.validation_dataset,
                                          self.cfg)
               if self.validator == "rouge" else valid.get_validation_loss(
                   self.model, self.validation_dataset, self.cfg))
        self.model.train()

        self.writer.add_scalar("validation/score", new, self.iteration)
        old = self.current_validation_score
        best = self.best_validation_score
        self.current_validation_score = new
        if self._validation_improved(new, best):
            self.current_patience = 0
            self.pbar.write(
                f"Validation improved: {best:5.2f} -> {new:5.2f} (new best)")
            self.best_validation_score = new
            self.save_model(suffix="_best")
        else:
            self.current_patience += 1
            s = "improved" if self._validation_improved(new,
                                                        old) else "declined"
            self.pbar.write(
                "Validation {}: {:5.2f} -> {:5.2f} (P: {}/{})".format(
                    s, old, new, self.current_patience, self.patience))

            # decay coverage loss weight
            if self.coverage_loss_weight_decay < 1:
                old_w = self.model.coverage_loss_weight
                new_w = old_w * self.coverage_loss_weight_decay
                self.pbar.write(
                    f"Decaying coverage loss weight: {old_w:.2f} -> {new_w:.2f}"
                )
                self.model.coverage_loss_weight = new_w

        if self.scheduler is not None:
            self.pbar.clear()
            # inform the scheduler of the new validation score
            self.scheduler.step(new)

    def save_model(self, suffix=""):
        """
        Saves current model to `self.cfg.output_file`.

        :param suffix:
            Optional suffix to append to default save location.
            Used to save checkpoints for best model so far (suffix="best").
        """
        self._update_time_training()
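        # The suffix is inserted before the file extension,
        # e.g. output_file="out/model.pt" with suffix="_best" -> "out/model_best.pt"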
        destination = suffix.join(os.path.splitext(self.cfg.output_file))

        model_state_dict = (self.model.module.state_dict() if isinstance(
            self.model, nn.DataParallel) else self.model.state_dict())

        torch.save(
            {
                "model_state_dict": model_state_dict,
                "optimizer_state_dict": self.optimizer.state_dict(),
                "vocab": self.dataset.vocab,
                "config": self.cfg,
                "stats": {
                    "epoch": self.epoch,
                    "iteration": self.iteration,
                    "running_avg_loss": self.running_avg_loss,
                    "running_avg_cov_loss": self.running_avg_cov_loss,
                    "best_validation_score": self.best_validation_score,
                    "current_validation_score": self.current_validation_score,
                    "current_patience": self.current_patience,
                    "model_identifier": self.model_identifier,
                    "time_training": self.time_training,
                },
            },
            destination,
        )

    def train(self):
        """Start training"""
        self.training_start_time = time.time()
        for _epoch in range(self.epoch, self.epoch + self.cfg.epochs):
            generator = self.dataset.generator(self.cfg.batch_size,
                                               self.cfg.pointer)
            for _batch_idx, batch in enumerate(generator):
                self.iteration += 1
                loss, cov_loss = self.train_batch(batch)
                self._update_running_avg_loss(loss, cov_loss)
                self._update_progress(loss, cov_loss)

                if self.iteration % self.validate_every == 0:
                    self._validate()
                    if self.early_stopping and self.current_patience > self.patience:
                        self.pbar.write("Early stopping...")
                        return

                if self.iteration >= self.cfg.iterations:
                    return

                if self.iteration % self.cfg.save_every == 0:
                    self.save_model()

            self.epoch += 1

        self._update_time_training()

    def train_batch(self, batch):
        """
        Run a single training iteration.
        :param batch: the current `data.Batch` instance to process

        :returns: A tuple of the loss and part thereof that is coverage loss
        """
        self.optimizer.zero_grad()
        loss, cov_loss, _output = self.model(batch)
        loss.mean().backward()
        clip_grad_norm_(self.model.parameters(), self.cfg.max_grad_norm)
        self.optimizer.step()

        return loss.mean().item(), cov_loss.mean().item()