Example #1
  def _make_estimator(self):
    """Builds the tf.estimator.Estimator from the run configuration."""
    params = self._config["params"]
    train_config = self._config["train"]
    summary_steps = train_config["save_summary_steps"]

    run_config = tf.estimator.RunConfig(
        model_dir=self._config["model_dir"],
        tf_random_seed=self._seed,
        save_summary_steps=summary_steps,
        session_config=self._session_config,
        log_step_count_steps=params.get("gradients_accum", 1) * summary_steps)
    # Apply user-provided checkpoint frequency overrides, if any.
    if "save_checkpoints_steps" in train_config or "save_checkpoints_secs" in train_config:
      run_config = run_config.replace(
          save_checkpoints_secs=train_config.get("save_checkpoints_secs"),
          save_checkpoints_steps=train_config.get("save_checkpoints_steps"))
    if "keep_checkpoint_max" in train_config:
      run_config = run_config.replace(
          keep_checkpoint_max=train_config["keep_checkpoint_max"])

    devices = get_devices(num_devices=self._num_devices, session_config=self._session_config)
    return tf.estimator.Estimator(
        self._model.model_fn(
            eval_prediction_hooks_fn=self._make_eval_prediction_hooks_fn(),
            devices=devices),
        config=run_config,
        params=params)
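For reference, a minimal sketch of the configuration dictionary this method reads from self._config. The keys mirror the lookups above; the concrete values are illustrative placeholders, not library defaults.

example_config = {
    "model_dir": "run/baseline",         # where checkpoints and summaries are written
    "params": {
        "gradients_accum": 4,            # optional, defaults to 1
    },
    "train": {
        "save_summary_steps": 100,       # required by this method
        "save_checkpoints_steps": 1000,  # optional RunConfig override
        "keep_checkpoint_max": 5,        # optional RunConfig override
    },
}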
Example #2
  def _make_estimator(self):
    """Builds the tf.estimator.Estimator from the run configuration."""
    params = self._config["params"]
    train_config = self._config["train"]
    summary_steps = train_config["save_summary_steps"]

    run_config = tf.estimator.RunConfig(
        model_dir=self._config["model_dir"],
        tf_random_seed=self._seed,
        save_summary_steps=summary_steps,
        session_config=self._session_config,
        log_step_count_steps=params.get("gradients_accum", 1) * summary_steps)
    # Apply user-provided checkpoint frequency overrides, if any.
    if "save_checkpoints_steps" in train_config or "save_checkpoints_secs" in train_config:
      run_config = run_config.replace(
          save_checkpoints_secs=train_config.get("save_checkpoints_secs"),
          save_checkpoints_steps=train_config.get("save_checkpoints_steps"))
    # Only the chief worker should save checkpoints.
    if not self.is_chief():
      run_config = run_config.replace(
          save_checkpoints_secs=None,
          save_checkpoints_steps=None)
    if "keep_checkpoint_max" in train_config:
      run_config = run_config.replace(
          keep_checkpoint_max=train_config["keep_checkpoint_max"])

    # Default the number of decoding hypotheses to the requested n-best size.
    params.setdefault("num_hypotheses", self._config["infer"].get("n_best", 1))

    devices = get_devices(num_devices=self._num_devices, session_config=self._session_config)
    return tf.estimator.Estimator(
        estimator_util.make_model_fn(
            self._model,
            eval_prediction_hooks_fn=self._make_eval_prediction_hooks_fn(),
            devices=devices,
            hvd=self._hvd),
        config=run_config,
        params=params)
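Both examples reassign run_config after every override because tf.estimator.RunConfig is immutable: replace() returns a modified copy rather than mutating in place. A minimal sketch of the pattern used for non-chief workers above, assuming TensorFlow 1.x:

import tensorflow as tf

base = tf.estimator.RunConfig(model_dir="run/baseline", save_checkpoints_steps=1000)
# Disable periodic checkpointing, as done for non-chief workers above.
no_ckpt = base.replace(save_checkpoints_steps=None, save_checkpoints_secs=None)

print(base.save_checkpoints_steps)     # 1000
print(no_ckpt.save_checkpoints_steps)  # None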
Example #3
  def __init__(self,
               model,
               config,
               seed=None,
               num_devices=1,
               gpu_allow_growth=False,
               session_config=None,
               auto_config=False):
    """Initializes the runner parameters.

    Args:
      model: A :class:`opennmt.models.model.Model` instance to run.
      config: The run configuration.
      seed: The random seed to set.
      num_devices: The number of devices (GPUs) to use for training.
      gpu_allow_growth: Allow GPU memory to grow dynamically.
      session_config: ``tf.ConfigProto`` overrides.
      auto_config: If ``True``, use automatic configuration values defined by
        :obj:`model`.

    Raises:
      NotImplementedError: If :obj:`auto_config` is ``True`` but :obj:`model`
        does not define any automatic configuration values.
    """
    self._model = model
    self._num_devices = num_devices

    # Configuration priority: user config > auto config > default config.
    self._config = copy.deepcopy(_CONFIG_FALLBACK)
    if auto_config:
      model_config = self._model.auto_config(num_devices=num_devices)
      if not model_config:
        raise NotImplementedError("This model does not define any automatic configuration values")
      misc.merge_dict(self._config, model_config)
    misc.merge_dict(self._config, config)
    tf.logging.info(
        "Using parameters: %s", json.dumps(self._config, indent=2, sort_keys=True))

    session_config_base = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        gpu_options=tf.GPUOptions(
            allow_growth=gpu_allow_growth))

    # Disable layout optimizer for better conv1d performance, see:
    # https://github.com/tensorflow/tensorflow/issues/20309
    # This field does not exist in TensorFlow 1.4, so guard against the
    # exception.
    try:
      rewrite_options = text_format.Parse("""
          graph_options {
            rewrite_options {
              layout_optimizer: OFF
            }
          }
          """, tf.ConfigProto())
      session_config_base.MergeFrom(rewrite_options)
    except text_format.ParseError:
      pass

    if session_config is not None:
      session_config_base.MergeFrom(session_config)
    session_config = session_config_base
    run_config = tf.estimator.RunConfig(
        model_dir=self._config["model_dir"],
        session_config=session_config,
        tf_random_seed=seed)

    # Seed the NumPy and Python RNGs in addition to TensorFlow's graph-level seed.
    np.random.seed(seed)
    random.seed(seed)

    # Apply training-related overrides from the user configuration.
    if "train" in self._config:
      if "save_summary_steps" in self._config["train"]:
        accum = self._config["params"].get("gradients_accum", 1)
        summary_steps = self._config["train"]["save_summary_steps"]
        run_config = run_config.replace(
            save_summary_steps=summary_steps,
            log_step_count_steps=accum * summary_steps)
      if "save_checkpoints_steps" in self._config["train"]:
        run_config = run_config.replace(
            save_checkpoints_secs=None,
            save_checkpoints_steps=self._config["train"]["save_checkpoints_steps"])
      if "keep_checkpoint_max" in self._config["train"]:
        run_config = run_config.replace(
            keep_checkpoint_max=self._config["train"]["keep_checkpoint_max"])

    devices = get_devices(num_devices=num_devices, session_config=session_config)
    self._estimator = tf.estimator.Estimator(
        self._model.model_fn(
            eval_prediction_hooks_fn=self._make_eval_prediction_hooks_fn(),
            devices=devices),
        config=run_config,
        params=self._config["params"])
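A hedged usage sketch for this constructor, assuming the enclosing class is named Runner; the model instance and configuration values below are placeholders, not values taken from the library.

import tensorflow as tf

config = {
    "model_dir": "run/my_experiment",
    "params": {},
    "train": {
        "save_summary_steps": 100,
        "save_checkpoints_steps": 5000,
        "keep_checkpoint_max": 3,
    },
}

# Optional overrides merged into the base ConfigProto built in the constructor.
session_overrides = tf.ConfigProto(intra_op_parallelism_threads=8)

runner = Runner(
    my_model,                  # placeholder: an opennmt.models.model.Model instance
    config,
    seed=42,
    num_devices=2,
    gpu_allow_growth=True,
    session_config=session_overrides)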