Code example #1
  def _build_train_spec(self, checkpoint_path):
    train_hooks = [
        hooks.LogParametersCountHook()]

    if checkpoint_path is not None:
      train_hooks.append(hooks.LoadWeightsFromCheckpointHook(checkpoint_path))
    if self._hvd is not None:
      train_hooks.append(self._hvd.BroadcastGlobalVariablesHook(0))

    train_steps = self._config["train"].get("train_steps")
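    # Under Horovod every worker steps in lockstep and consumes one batch per
    # step, so the configured step count is split across the workers.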
    if train_steps is not None and self._hvd is not None:
      train_steps //= self._hvd.size()
    train_spec = tf.estimator.TrainSpec(
        input_fn=estimator_util.make_input_fn(
            self._model,
            tf.estimator.ModeKeys.TRAIN,
            self._config["train"]["batch_size"],
            features_file=self._config["data"]["train_features_file"],
            labels_file=self._config["data"].get("train_labels_file"),
            batch_type=self._config["train"]["batch_type"],
            batch_multiplier=self._num_devices,
            bucket_width=self._config["train"]["bucket_width"],
            maximum_features_length=self._config["train"].get("maximum_features_length"),
            maximum_labels_length=self._config["train"].get("maximum_labels_length"),
            shuffle_buffer_size=self._config["train"]["sample_buffer_size"],
            single_pass=self._config["train"].get("single_pass", False),
            num_shards=self._hvd.size() if self._hvd is not None else 1,
            shard_index=self._hvd.rank() if self._hvd is not None else 0,
            num_threads=self._config["train"].get("num_threads"),
            prefetch_buffer_size=self._config["train"].get("prefetch_buffer_size"),
            return_dataset=False),
        max_steps=train_steps,
        hooks=train_hooks)
    return train_spec
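The spec returned here is designed to be paired with an eval spec (see code example #3) and handed to tf.estimator.train_and_evaluate. A minimal driver sketch, assuming a runner instance that exposes the _make_estimator() helper visible in code example #2:

    import tensorflow as tf

    # Hypothetical driver; `runner` stands in for an instance of the class these
    # methods belong to, with _make_estimator() as used in code example #2.
    def run_training(runner, checkpoint_path=None):
        estimator = runner._make_estimator()
        train_spec = runner._build_train_spec(checkpoint_path)
        eval_spec = runner._build_eval_spec()  # built as in code example #3
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)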
Code example #2
File: runner.py Project: mfomicheva/OpenNMT-tf
    def infer(self,
              features_file,
              predictions_file=None,
              checkpoint_path=None,
              log_time=False):
        """Runs inference.

        Args:
          features_file: The file(s) to infer from.
          predictions_file: If set, predictions are saved in this file.
          checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
            the latest is used.
          log_time: If ``True``, several time metrics will be printed in the logs at
            the end of the inference loop.
        """
        if checkpoint_path is not None and tf.gfile.IsDirectory(
                checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

        input_fn = estimator_util.make_input_fn(
            self._model,
            tf.estimator.ModeKeys.PREDICT,
            self._config["infer"]["batch_size"],
            features_file=features_file,
            bucket_width=self._config["infer"]["bucket_width"],
            num_threads=self._config["infer"].get("num_threads"),
            prefetch_buffer_size=self._config["infer"].get(
                "prefetch_buffer_size"),
            return_dataset=False)

        if predictions_file:
            stream = io.open(predictions_file, encoding="utf-8", mode="w")
        else:
            stream = sys.stdout

        infer_hooks = []
        if log_time:
            infer_hooks.append(hooks.LogPredictionTimeHook())

        ordered_writer = None
        write_fn = lambda prediction: (self._model.print_prediction(
            prediction, params=self._config["infer"], stream=stream))

        estimator = self._make_estimator()
        for prediction in estimator.predict(input_fn=input_fn,
                                            checkpoint_path=checkpoint_path,
                                            hooks=infer_hooks):
            # If the index is part of the prediction, they may be out of order.
            if "index" in prediction:
                if ordered_writer is None:
                    ordered_writer = OrderRestorer(
                        index_fn=lambda prediction: prediction["index"],
                        callback_fn=write_fn)
                ordered_writer.push(prediction)
            else:
                write_fn(prediction)

        if predictions_file:
            stream.close()
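A hedged call-site sketch for this method. Construction of the runner is not part of this excerpt, so `runner` is assumed to be a fully configured instance and the file paths are placeholders:

    # Hypothetical usage; `runner` is assumed initialized, paths are placeholders.
    runner.infer(
        "data/test.txt",               # features_file: sentences to decode
        predictions_file="test.out",   # written here instead of stdout
        checkpoint_path=None,          # None selects the latest checkpoint
        log_time=True)                 # print timing metrics after the loop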
Code example #3
File: runner.py Project: Desperado-Jia/OpenNMT-tf
  def _build_eval_spec(self):
    eval_spec = tf.estimator.EvalSpec(
        input_fn=estimator_util.make_input_fn(
            self._model,
            tf.estimator.ModeKeys.EVAL,
            self._config["eval"]["batch_size"],
            features_file=self._config["data"]["eval_features_file"],
            labels_file=self._config["data"].get("eval_labels_file"),
            num_threads=self._config["eval"].get("num_threads"),
            prefetch_buffer_size=self._config["eval"].get("prefetch_buffer_size")),
        steps=None,
        exporters=_make_exporters(
            self._config["eval"]["exporters"],
            estimator_util.make_serving_input_fn(self._model),
            assets_extra=self._get_model_assets()),
        throttle_secs=self._config["eval"]["eval_delay"])
    return eval_spec
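Besides being scheduled by tf.estimator.train_and_evaluate (where throttle_secs bounds how often evaluation may run), the same pipeline pieces can drive a one-off evaluation. A sketch, again assuming the _make_estimator() helper from code example #2:

    import tensorflow as tf

    # One-off evaluation sketch: call the estimator directly with the same kind
    # of input_fn the EvalSpec above builds (only the required args are shown).
    def evaluate_once(runner, checkpoint_path=None):
        estimator = runner._make_estimator()
        input_fn = estimator_util.make_input_fn(
            runner._model,
            tf.estimator.ModeKeys.EVAL,
            runner._config["eval"]["batch_size"],
            features_file=runner._config["data"]["eval_features_file"],
            labels_file=runner._config["data"].get("eval_labels_file"))
        return estimator.evaluate(input_fn, checkpoint_path=checkpoint_path)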
Code example #4
  def input_fn(self,
               mode,
               batch_size,
               metadata=None,
               features_file=None,
               labels_file=None,
               batch_type="examples",
               batch_multiplier=1,
               bucket_width=None,
               single_pass=False,
               num_threads=None,
               sample_buffer_size=None,
               prefetch_buffer_size=None,
               maximum_features_length=None,
               maximum_labels_length=None,
               num_shards=1,
               shard_index=0):
    """Returns an input function.

    Args:
      mode: A ``tf.estimator.ModeKeys`` mode.
      batch_size: The batch size to use.
      metadata: A dictionary containing additional metadata set
        by the user. Required if ``Model.initialize()`` has not been called.
      features_file: The file containing input features.
      labels_file: The file containing output labels.
      batch_type: The training batching strategy to use: can be "examples" or
        "tokens".
      batch_multiplier: The batch size multiplier to prepare splitting across
        replicated graph parts.
      bucket_width: The width of the length buckets to select batch candidates
        from. ``None`` to not constrain batch formation.
      single_pass: If ``True``, makes a single pass over the training data.
      num_threads: The number of elements processed in parallel.
      sample_buffer_size: The number of elements from which to sample.
      prefetch_buffer_size: The number of batches to prefetch asynchronously. If
        ``None``, use an automatically tuned value on TensorFlow 1.8+ and 1 on
        older versions.
      maximum_features_length: The maximum length or list of maximum lengths of
        the features sequence(s). ``None`` to not constrain the length.
      maximum_labels_length: The maximum length of the labels sequence.
        ``None`` to not constrain the length.
      num_shards: The number of data shards (usually the number of workers in a
        distributed setting).
      shard_index: The shard index this input pipeline should read from.

    Returns:
      A callable that returns the next element.

    See Also:
      ``tf.estimator.Estimator``.
    """
    if metadata is not None:
      self.initialize(metadata)
    return estimator.make_input_fn(
        self,
        mode,
        batch_size,
        features_file,
        labels_file=labels_file,
        batch_type=batch_type,
        batch_multiplier=batch_multiplier,
        bucket_width=bucket_width,
        maximum_features_length=maximum_features_length,
        maximum_labels_length=maximum_labels_length,
        shuffle_buffer_size=sample_buffer_size,
        single_pass=single_pass,
        num_shards=num_shards,
        shard_index=shard_index,
        num_threads=num_threads,
        prefetch_buffer_size=prefetch_buffer_size,
        return_dataset=False)
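A usage sketch for this wrapper: the returned callable plugs directly into an estimator. The metadata keys and file paths below are placeholders for whatever the concrete model's initialize() expects:

    import tensorflow as tf

    # Hypothetical call site; `model` is a concrete OpenNMT-tf model instance,
    # `estimator` a tf.estimator.Estimator wrapping it. Metadata keys and paths
    # are placeholders for what the model's initialize() actually expects.
    train_input_fn = model.input_fn(
        tf.estimator.ModeKeys.TRAIN,
        batch_size=64,
        metadata={"source_words_vocabulary": "src_vocab.txt",
                  "target_words_vocabulary": "tgt_vocab.txt"},
        features_file="train.src",
        labels_file="train.tgt",
        sample_buffer_size=500000)
    estimator.train(train_input_fn, max_steps=100000)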