Example #1
    def infer(self,
              features_file,
              predictions_file=None,
              checkpoint_path=None,
              log_time=False):
        """Runs inference.

    Args:
      features_file: The file(s) to infer from.
      predictions_file: If set, predictions are saved in this file.
      checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
        the latest is used.
      log_time: If ``True``, several time metrics will be printed in the logs at
        the end of the inference loop.
    """
        if "infer" not in self._config:
            self._config["infer"] = {}
        if checkpoint_path is not None and tf.gfile.IsDirectory(
                checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

        input_fn = self._model.input_fn(
            tf.estimator.ModeKeys.PREDICT,
            self._config["infer"]["batch_size"],
            self._config["data"],
            features_file,
            bucket_width=self._config["infer"]["bucket_width"],
            num_threads=self._config["infer"].get("num_threads"),
            prefetch_buffer_size=self._config["infer"].get(
                "prefetch_buffer_size"))

        if predictions_file:
            stream = io.open(predictions_file, encoding="utf-8", mode="w")
        else:
            stream = sys.stdout

        infer_hooks = []
        if log_time:
            infer_hooks.append(hooks.LogPredictionTimeHook())

        ordered_writer = None
        write_fn = lambda prediction: (self._model.print_prediction(
            prediction, params=self._config["infer"], stream=stream))

        for prediction in self._estimator.predict(
                input_fn=input_fn,
                checkpoint_path=checkpoint_path,
                hooks=infer_hooks):
            # If the index is part of the prediction, predictions may come out of order.
            if "index" in prediction:
                if ordered_writer is None:
                    ordered_writer = OrderRestorer(
                        index_fn=lambda prediction: prediction["index"],
                        callback_fn=write_fn)
                ordered_writer.push(prediction)
            else:
                write_fn(prediction)

        if predictions_file:
            stream.close()
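
The ordered writing above buffers predictions until every earlier index has been emitted. A minimal sketch of that pattern, with SimpleOrderRestorer as a hypothetical stand-in for the OrderRestorer imported by the code above (it assumes contiguous indices starting at 0):

    class SimpleOrderRestorer(object):

        def __init__(self, index_fn, callback_fn):
            self._index_fn = index_fn
            self._callback_fn = callback_fn
            self._next_index = 0
            self._buffer = {}

        def push(self, item):
            # Buffer the item, then flush every buffered item whose turn has come.
            self._buffer[self._index_fn(item)] = item
            while self._next_index in self._buffer:
                self._callback_fn(self._buffer.pop(self._next_index))
                self._next_index += 1

    # Usage: items pushed out of order are emitted in index order (0, 1, 2).
    writer = SimpleOrderRestorer(index_fn=lambda p: p["index"], callback_fn=print)
    for p in ({"index": 1}, {"index": 0}, {"index": 2}):
        writer.push(p)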
Example #2
  def get_alignment(self,
                    features_file,
                    checkpoint_path=None,
                    log_time=False):
    """Runs inference.

    Args:
      features_file: The file(s) to infer from.
      checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
        the latest is used.
      log_time: If ``True``, several time metrics will be printed in the logs at
        the end of the inference loop.
    """
    if "infer" not in self._config:
      self._config["infer"] = {}
    if checkpoint_path is not None and os.path.isdir(checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

    batch_size = self._config["infer"].get("batch_size", 1)
    input_fn = self._model.input_fn(
        tf.estimator.ModeKeys.PREDICT,
        batch_size,
        self._config["data"],
        features_file,
        num_threads=self._config["infer"].get("num_threads"),
        prefetch_buffer_size=self._config["infer"].get("prefetch_buffer_size"))

    infer_hooks = []
    if log_time:
      infer_hooks.append(hooks.LogPredictionTimeHook())

    def _parse_prediction(prediction, params=None):
      n_best = params and params.get("n_best")
      n_best = n_best or 1

      if n_best > len(prediction["tokens"]):
        raise ValueError("n_best cannot be greater than beam_width")

      n_best_sentences = []
      n_alignments = []
      for i in range(n_best):
        tokens = prediction["tokens"][i][:prediction["length"][i] - 1] # Ignore </s>.
        sentence = self._model.target_inputter.tokenizer.detokenize(tokens)
        n_best_sentences.append(sentence)
        n_alignments.append(prediction["alignment"])
      return n_best_sentences, n_alignments

    sentences = []
    alignments = []
    for prediction in self._estimator.predict(
        input_fn=input_fn,
        checkpoint_path=checkpoint_path,
        hooks=infer_hooks):
      n_best_sentences, n_alignments = _parse_prediction(
          prediction, params=self._config["infer"])

      sentences.extend(n_best_sentences)
      alignments.extend(n_alignments)
    
    return sentences, alignments
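
Each entry in alignments above is whatever prediction["alignment"] holds. Assuming it is a [target_length, source_length] matrix of attention weights (an assumption about the model, not guaranteed by the code above), a hard source position per target token can be recovered with an argmax:

    import numpy as np

    def hard_alignment(alignment_matrix):
        # For each target position (row), pick the source position (column)
        # with the highest attention weight.
        return np.argmax(np.asarray(alignment_matrix), axis=-1)

    # Usage: a 2-token target attending over a 2-token source.
    print(hard_alignment([[0.9, 0.1], [0.2, 0.8]]))  # -> [0 1]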
Example #3
    def infer(self,
              features_file,
              predictions_file=None,
              checkpoint_path=None,
              log_time=False):
        """Runs inference.

    Args:
      features_file: The file(s) to infer from.
      predictions_file: If set, predictions are saved in this file.
      checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
        the latest is used.
      log_time: If ``True``, several time metrics will be printed in the logs at
        the end of the inference loop.
    """
        if "infer" not in self._config:
            self._config["infer"] = {}
        if checkpoint_path is not None and os.path.isdir(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

        batch_size = self._config["infer"].get("batch_size", 1)
        input_fn = self._model.input_fn(
            tf.estimator.ModeKeys.PREDICT,
            batch_size,
            self._config["data"],
            features_file,
            num_threads=self._config["infer"].get("num_threads"),
            prefetch_buffer_size=self._config["infer"].get(
                "prefetch_buffer_size"))

        if predictions_file:
            stream = io.open(predictions_file, encoding="utf-8", mode="w")
        else:
            stream = sys.stdout

        infer_hooks = []
        if log_time:
            infer_hooks.append(hooks.LogPredictionTimeHook())

        for prediction in self._estimator.predict(
                input_fn=input_fn,
                checkpoint_path=checkpoint_path,
                hooks=infer_hooks):
            self._model.print_prediction(prediction,
                                         params=self._config["infer"],
                                         stream=stream)

        if predictions_file:
            stream.close()
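
Examples #1 and #3 open the output stream conditionally and must remember to close it only when it is a file. A context manager can make that bookkeeping automatic; a minimal sketch (the helper name open_output is hypothetical, not part of the code above):

    import contextlib
    import io
    import sys

    @contextlib.contextmanager
    def open_output(path=None):
        # Yield a writable text stream: the file at `path` if given, else stdout.
        # Only the file is closed on exit; stdout is left open.
        if path is None:
            yield sys.stdout
        else:
            with io.open(path, encoding="utf-8", mode="w") as stream:
                yield stream

    # Usage: replaces the open/close pairs around the prediction loop.
    with open_output("predictions.txt") as stream:
        stream.write("a prediction\n")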
Example #4
    def infer(self,
              features_file,
              predictions_file=None,
              checkpoint_path=None,
              log_time=False):
        """Runs inference.

    Args:
      features_file: The file(s) to infer from.
      predictions_file: If set, predictions are saved in this file.
      checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
        the latest is used.
      log_time: If ``True``, several time metrics will be printed in the logs at
        the end of the inference loop.
    """
        need_ae = self._ae_model is not None
        if checkpoint_path is not None and tf.gfile.IsDirectory(
                checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

        input_fn = estimator_util.make_input_fn(
            self._model,
            tf.estimator.ModeKeys.PREDICT,
            self._config["infer"]["batch_size"],
            features_file=features_file,
            bucket_width=self._config["infer"]["bucket_width"],
            num_threads=self._config["infer"].get("num_threads"),
            prefetch_buffer_size=self._config["infer"].get(
                "prefetch_buffer_size"),
            return_dataset=False,
            need_ae=need_ae)

        if predictions_file:
            stream = io.open(predictions_file, encoding="utf-8", mode="w")
        else:
            stream = sys.stdout

        infer_hooks = []
        if log_time:
            infer_hooks.append(hooks.LogPredictionTimeHook())

        ordered_writer = None
        write_fn = lambda prediction: (self._model.print_prediction(
            prediction, params=self._config["infer"], stream=stream))

        estimator = self._make_estimator()
        sample_count = 0
        for prediction in estimator.predict(input_fn=input_fn,
                                            checkpoint_path=checkpoint_path,
                                            hooks=infer_hooks):
            if sample_count % 1000 == 0:
                now_time = time.strftime('%Y-%m-%d %H:%M:%S',
                                         time.localtime(time.time()))
                tf.logging.info("{}:{}".format(now_time, sample_count))
            sample_count += 1
            # If the index is part of the prediction, predictions may come out of order.
            if "index" in prediction:
                if ordered_writer is None:
                    ordered_writer = OrderRestorer(
                        index_fn=lambda prediction: prediction["index"],
                        callback_fn=write_fn)
                ordered_writer.push(prediction)
            else:
                write_fn(prediction)

        if predictions_file:
            stream.close()
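
Example #4 interleaves progress logging with the prediction loop. The same effect can be factored out by wrapping any iterable; a minimal sketch using the standard logging module in place of the TF 1.x tf.logging used above (the helper name with_progress is hypothetical):

    import logging
    import time

    def with_progress(iterable, every=1000):
        # Yield items unchanged, logging a timestamped running count
        # every `every` items (including the first).
        for count, item in enumerate(iterable):
            if count % every == 0:
                now = time.strftime("%Y-%m-%d %H:%M:%S",
                                    time.localtime(time.time()))
                logging.info("%s: %d", now, count)
            yield item

    # Usage:
    logging.basicConfig(level=logging.INFO)
    for _ in with_progress(range(2500)):
        pass  # logs at counts 0, 1000, and 2000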