Example #1
    def __init__(
        self,
        config: Config,
        model: SILTransformer,
        src_spp: Optional[sp.SentencePieceProcessor],
        trg_spp: Optional[sp.SentencePieceProcessor],
        step: int,
        checkpoint_path: Path,
        type: Optional[str] = None,
    ):
        self.types: List[str] = []
        if type is not None:
            self.types.append(type)
        self.step = step
        # Configuration priority: user config > auto config > default config.
        new_config = copy.deepcopy(_CONFIG_FALLBACK)
        merge_dict(new_config, model.auto_config())
        merge_dict(new_config, config.root)
        new_config["params"].setdefault("num_hypotheses",
                                        new_config["infer"].get("n_best", 1))
        new_config["params"].setdefault(
            "average_loss_in_time",
            new_config["train"]["batch_type"] == "tokens")
        new_config["infer"]["n_best"] = 1
        self.config = new_config

        self.src_spp = src_spp
        self.trg_spp = trg_spp
        self.model: SILTransformer = clone_layer(model)
        self.model.initialize(self.config["data"],
                              params=self.config["params"])
        self._analyze_fn: Optional[Function] = None

        self.checkpoint_path = checkpoint_path
        self.checkpoint: Optional[Checkpoint] = None
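
The comment in the constructor spells out the merge order: the fallback defaults are applied first, then the model's automatic configuration, then the user configuration, so later layers override earlier ones. A minimal sketch of that priority chain, using a hypothetical recursive merge helper (the real merge_dict is imported by the snippet and not shown here); the keys and values are illustrative only:

import copy

def merge_dict(base, overrides):
    # Recursively merge `overrides` into `base`; values in `overrides` win.
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            merge_dict(base[key], value)
        else:
            base[key] = value

default_config = {"params": {"beam_width": 4}, "infer": {"n_best": 1}}
auto_config = {"params": {"beam_width": 8}}    # e.g. from model.auto_config()
user_config = {"infer": {"n_best": 5}}         # e.g. from config.root

new_config = copy.deepcopy(default_config)
merge_dict(new_config, auto_config)  # auto config overrides the defaults
merge_dict(new_config, user_config)  # user config overrides everything else
# new_config == {"params": {"beam_width": 8}, "infer": {"n_best": 5}}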
Example #2
    def _init_model(self, config):
        model = misc.clone_layer(self._model)
        model.initialize(config["data"], params=config["params"])
        # Only create an optimizer when the parameters define one;
        # inference-only configurations can skip it.
        if "optimizer" in config["params"]:
            optimizer = model.get_optimizer()
        else:
            optimizer = None
        # Wrap the model (and optional optimizer) in a managed checkpoint.
        checkpoint = checkpoint_util.Checkpoint(
            model,
            optimizer=optimizer,
            model_dir=config.get("model_dir"),
            keep_checkpoint_max=config["train"].get("keep_checkpoint_max", 8))
        return checkpoint
Example #3
    def restore(self, checkpoint_path=None, weights_only=False):
        """Restores a checkpoint.

    Args:
      checkpoint_path: Path a checkpoint to restore. If not set, the latest
        checkpoint from :obj:`model_dir` will be restored.
      weights_only: Only restore model weights.

    Returns:
      Path to the restored checkpoint.
    """
        if weights_only:
            checkpoint = tf.train.Checkpoint(model=self._model)
        else:
            checkpoint = self._checkpoint
        if checkpoint_path is not None:
            if tf.io.gfile.isdir(checkpoint_path):
                checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
        elif self._checkpoint_manager.latest_checkpoint is not None:
            checkpoint_path = self._checkpoint_manager.latest_checkpoint
        if checkpoint_path is None:
            tf.get_logger().warning("No checkpoint to restore in %s",
                                    self._model_dir)
            return None
        if is_v1_checkpoint(checkpoint_path):
            tf.get_logger().info("Upgrading V1 checkpoint...")
            # Work with copies of model and optimizer as the downstream task might
            # need to create the variable differently (e.g. under a distribution
            # strategy scope).
            tmp_model = misc.clone_layer(self._model)
            tmp_optimizer = copy.deepcopy(
                self._optimizer) if self._optimizer is not None else None
            tmp_model.create_variables(optimizer=tmp_optimizer)
            step = _restore_v1_checkpoint(checkpoint_path,
                                          tmp_model,
                                          optimizer=tmp_optimizer)
            # Save an updated checkpoint in the model directory and restore this one instead.
            tmp_checkpoint = Checkpoint(tmp_model,
                                        optimizer=tmp_optimizer,
                                        model_dir=self._model_dir)
            checkpoint_path = tmp_checkpoint.save(step)
            return self.restore(checkpoint_path=checkpoint_path,
                                weights_only=weights_only)
        load_status = checkpoint.restore(checkpoint_path)
        if weights_only:
            load_status.expect_partial()
        tf.get_logger().info("Restored checkpoint %s", checkpoint_path)
        return checkpoint_path
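
A hedged usage sketch of the restore flow above, assuming the Checkpoint wrapper is constructed as in Example #2; the model, optimizer, directory, and step number are placeholders:

checkpoint = checkpoint_util.Checkpoint(
    model,
    optimizer=optimizer,
    model_dir="run/baseline")

# Restore the latest checkpoint found in model_dir (weights and optimizer state).
restored_path = checkpoint.restore()

# Or restore a specific checkpoint, weights only; expect_partial() then silences
# warnings about optimizer variables that are not restored.
restored_path = checkpoint.restore(
    checkpoint_path="run/baseline/ckpt-5000", weights_only=True)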
Example #4
    def __init__(self,
                 model,
                 config,
                 auto_config=False,
                 mixed_precision=False,
                 seed=None):
        """Initializes the runner parameters.

        Args:
          model: A :class:`opennmt.models.Model` instance to run or a callable that
            returns such instance.
          config: The run configuration.
          auto_config: If ``True``, use automatic configuration values defined by
            :obj:`model`.
          mixed_precision: Enable mixed precision.
          seed: The random seed to set.

        Raises:
          TypeError: if :obj:`model` is not a :class:`opennmt.models.Model` instance
            or a callable.
        """
        if isinstance(model, models.Model):
            self._model = model
            self._model_fn = lambda: misc.clone_layer(model)
        elif callable(model):
            self._model = model()
            self._model_fn = model
        else:
            raise TypeError(
                "model should be a opennmt.models.Model instance or a callable"
            )
        tf.get_logger().info("Using OpenNMT-tf version %s", __version__)
        tf.get_logger().info("Using model:\n%s", self._model)
        self._optimizer = None
        self._config = copy.deepcopy(config)
        self._auto_config = auto_config
        self._mixed_precision = mixed_precision
        if mixed_precision:
            tf.config.optimizer.set_experimental_options(
                {"auto_mixed_precision": True})
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)
            tf.random.set_seed(seed)
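
A usage sketch for this constructor, assuming it belongs to an OpenNMT-tf style Runner class; the class name, the config contents, and the seed value are assumptions, only the parameters come from the signature above:

config = {
    "model_dir": "run/baseline",
    "data": {},    # vocabulary and data file paths would go here
    "params": {},
    "train": {},
}

runner = Runner(
    model,               # an opennmt.models.Model instance, or a callable returning one
    config,
    auto_config=True,    # merge in the model's automatic configuration
    mixed_precision=False,
    seed=42)             # seeds numpy, random, and tf.random for reproducibility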
Example #5
    def _init_model(self, config):
        model = misc.clone_layer(self._model)
        model.initialize(config["data"], params=config["params"])
        return model