Exemple #1
0
  def _make_estimator(self):
    """Builds the ``tf.estimator.Estimator`` described by ``self._config``.

    The run configuration is assembled from the "train" section of the
    configuration. The "params" section is handed to the estimator as its
    ``params`` and is given a default "num_hypotheses" entry taken from the
    "n_best" inference option (1 when unset).

    Returns:
      A ``tf.estimator.Estimator`` instance.
    """
    params = self._config["params"]
    train_options = self._config["train"]
    save_summary_steps = train_options["save_summary_steps"]
    accum_steps = params.get("gradients_accum", 1)

    # The step count logging period is the summary period multiplied by the
    # "gradients_accum" parameter.
    run_config = tf.estimator.RunConfig(
        model_dir=self._config["model_dir"],
        tf_random_seed=self._seed,
        save_summary_steps=save_summary_steps,
        session_config=self._session_config,
        log_step_count_steps=accum_steps * save_summary_steps)

    # Forward the checkpoint frequency options as soon as one of them is set;
    # the one that is missing is passed as None.
    checkpoint_options = ("save_checkpoints_secs", "save_checkpoints_steps")
    if any(option in train_options for option in checkpoint_options):
      run_config = run_config.replace(
          **{option: train_options.get(option) for option in checkpoint_options})

    # Non chief instances get both checkpoint frequency options reset to None.
    if not self.is_chief():
      run_config = run_config.replace(
          save_checkpoints_secs=None, save_checkpoints_steps=None)

    if "keep_checkpoint_max" in train_options:
      keep_max = train_options["keep_checkpoint_max"]
      run_config = run_config.replace(keep_checkpoint_max=keep_max)

    # An explicit "num_hypotheses" parameter takes priority over "n_best".
    n_best = self._config["infer"].get("n_best", 1)
    params.setdefault("num_hypotheses", n_best)

    devices = get_devices(
        num_devices=self._num_devices,
        session_config=self._session_config)
    model_fn = estimator_util.make_model_fn(
        self._model,
        eval_prediction_hooks_fn=self._make_eval_prediction_hooks_fn(),
        devices=devices,
        hvd=self._hvd)
    return tf.estimator.Estimator(model_fn, config=run_config, params=params)
Exemple #2
0
  def model_fn(self, num_devices=1, eval_prediction_hooks_fn=None, devices=None, hvd=None):
    """Creates the model function for this model.

    This is a thin wrapper that forwards this model and the given options to
    ``estimator.make_model_fn``.

    Args:
      num_devices: The number of devices used for training.
      eval_prediction_hooks_fn: A callable that takes the model predictions
        during evaluation and returns an iterable of evaluation hooks (e.g. for
        saving predictions on disk, running external evaluators, etc.).
      devices: The list of devices used for training, if known.
      hvd: Optional Horovod object.

    Returns:
      The value returned by ``estimator.make_model_fn``.

    See Also:
      ``tf.estimator.Estimator`` 's ``model_fn`` argument for more details about
      arguments and the returned value.
    """
    # Keep the options in one place and pass them all by keyword.
    options = dict(
        eval_prediction_hooks_fn=eval_prediction_hooks_fn,
        num_devices=num_devices,
        devices=devices,
        hvd=hvd)
    return estimator.make_model_fn(self, **options)