Example #1
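    # Context assumed but not shown in this snippet: a runner class method, with
    # `io`, `sys`, `tensorflow as tf` (1.x), and OpenNMT-tf's `hooks` and
    # `OrderRestorer` utilities imported at module level.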
    def infer(self,
              features_file,
              predictions_file=None,
              checkpoint_path=None,
              log_time=False):
        """Runs inference.

        Args:
          features_file: The file(s) to infer from.
          predictions_file: If set, predictions are saved in this file.
          checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
            the latest is used.
          log_time: If ``True``, several time metrics will be printed in the logs
            at the end of the inference loop.
        """
        if "infer" not in self._config:
            self._config["infer"] = {}
        if checkpoint_path is not None and tf.gfile.IsDirectory(
                checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

        input_fn = self._model.input_fn(
            tf.estimator.ModeKeys.PREDICT,
            self._config["infer"]["batch_size"],
            self._config["data"],
            features_file,
            bucket_width=self._config["infer"]["bucket_width"],
            num_threads=self._config["infer"].get("num_threads"),
            prefetch_buffer_size=self._config["infer"].get(
                "prefetch_buffer_size"))

        if predictions_file:
            stream = io.open(predictions_file, encoding="utf-8", mode="w")
        else:
            stream = sys.stdout

        infer_hooks = []
        if log_time:
            infer_hooks.append(hooks.LogPredictionTimeHook())

        ordered_writer = None
        write_fn = lambda prediction: (self._model.print_prediction(
            prediction, params=self._config["infer"], stream=stream))

        for prediction in self._estimator.predict(
                input_fn=input_fn,
                checkpoint_path=checkpoint_path,
                hooks=infer_hooks):
            # If an index is part of the prediction, outputs may arrive out of order.
            if "index" in prediction:
                if ordered_writer is None:
                    ordered_writer = OrderRestorer(
                        index_fn=lambda prediction: prediction["index"],
                        callback_fn=write_fn)
                ordered_writer.push(prediction)
            else:
                write_fn(prediction)

        if predictions_file:
            stream.close()
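
A minimal usage sketch for this method, assuming a hypothetical runner object
built from an OpenNMT-tf model and configuration (the constructor wiring of
`_model`, `_config`, and `_estimator` is not part of the snippet above):

    runner.infer(
        "data/test.txt",                    # features_file
        predictions_file="data/test.out",   # write here instead of stdout
        checkpoint_path="models/run1",      # a directory: latest checkpoint is used
        log_time=True)                      # print timing metrics after the loop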
Example #2
    def infer_multiple(
        self,
        features_paths: List[Union[str, List[str]]],
        predictions_paths: List[str],
        checkpoint_path: Optional[str] = None,
    ) -> None:
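        """Runs inference on several input files with a single checkpoint restore.

        Args:
          features_paths: For each entry, the file (or list of files, for
            multi-source models) to infer from.
          predictions_paths: For each entry, the file where predictions are saved.
          checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
            the latest is used.
        """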
        config = self._finalize_config()
        model: Model = self._init_model(config)
        checkpoint = Checkpoint.from_config(config, model)
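        # Restore model weights only; optimizer state is not needed for inference.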
        checkpoint.restore(checkpoint_path=checkpoint_path, weights_only=True)
        infer_config = config["infer"]
        for features_path, predictions_path in zip(features_paths,
                                                   predictions_paths):
            dataset = model.examples_inputter.make_inference_dataset(
                features_path,
                infer_config["batch_size"],
                length_bucket_width=infer_config["length_bucket_width"],
                prefetch_buffer_size=infer_config.get("prefetch_buffer_size"),
            )

            with open(predictions_path, encoding="utf-8", mode="w") as stream:
                infer_fn = tf.function(
                    model.infer, input_signature=(dataset.element_spec, ))
                if not tf.config.functions_run_eagerly():
                    tf.get_logger().info(
                        "Tracing and optimizing the inference graph...")
                    infer_fn.get_concrete_function()  # Trace the function now.

                # Inference might return out-of-order predictions. The OrderRestorer utility is
                # used to write predictions in their original order.
                ordered_writer = OrderRestorer(
                    lambda pred: pred.get("index"),
                    lambda pred: (model.print_prediction(
                        pred, params=infer_config, stream=stream)),
                )

                for source in dataset:
                    predictions = infer_fn(source)
                    predictions = tf.nest.map_structure(
                        lambda t: t.numpy(), predictions)
                    for prediction in extract_batches(predictions):
                        ordered_writer.push(prediction)
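
A usage sketch, again assuming a hypothetical runner object; note that each
entry of `features_paths` may be a single path or a list of paths, matching the
`List[Union[str, List[str]]]` annotation:

    runner.infer_multiple(
        features_paths=["test1.de", "test2.de"],
        predictions_paths=["test1.en", "test2.en"],
        checkpoint_path=None)  # None: the latest checkpoint is restored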
Example #3
    def infer(self,
              features_file,
              predictions_file=None,
              checkpoint_path=None,
              log_time=False):
        """Runs inference.

        Args:
          features_file: The file(s) to infer from.
          predictions_file: If set, predictions are saved in this file.
          checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
            the latest is used.
          log_time: If ``True``, several time metrics will be printed in the logs
            at the end of the inference loop.
        """
        need_ae = self._ae_model is not None
        if checkpoint_path is not None and tf.gfile.IsDirectory(
                checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

        input_fn = estimator_util.make_input_fn(
            self._model,
            tf.estimator.ModeKeys.PREDICT,
            self._config["infer"]["batch_size"],
            features_file=features_file,
            bucket_width=self._config["infer"]["bucket_width"],
            num_threads=self._config["infer"].get("num_threads"),
            prefetch_buffer_size=self._config["infer"].get(
                "prefetch_buffer_size"),
            return_dataset=False,
            need_ae=need_ae)

        if predictions_file:
            stream = io.open(predictions_file, encoding="utf-8", mode="w")
        else:
            stream = sys.stdout

        infer_hooks = []
        if log_time:
            infer_hooks.append(hooks.LogPredictionTimeHook())

        ordered_writer = None
        write_fn = lambda prediction: (self._model.print_prediction(
            prediction, params=self._config["infer"], stream=stream))

        estimator = self._make_estimator()
        sample_count = 0
        for prediction in estimator.predict(input_fn=input_fn,
                                            checkpoint_path=checkpoint_path,
                                            hooks=infer_hooks):
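            # Log a timestamped progress line every 1000 predictions.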
            if sample_count % 1000 == 0:
                now_time = time.strftime('%Y-%m-%d %H:%M:%S',
                                         time.localtime(time.time()))
                tf.logging.info("{}:{}".format(now_time, sample_count))
            sample_count += 1
            # If an index is part of the prediction, outputs may arrive out of order.
            if "index" in prediction:
                if ordered_writer is None:
                    ordered_writer = OrderRestorer(
                        index_fn=lambda prediction: prediction["index"],
                        callback_fn=write_fn)
                ordered_writer.push(prediction)
            else:
                write_fn(prediction)

        if predictions_file:
            stream.close()
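
All three examples rely on OrderRestorer to write predictions back in their
original order when the input pipeline reorders examples (e.g., by length
bucketing). A minimal sketch of the pattern, as an illustrative
reimplementation rather than OpenNMT-tf's actual class:

    class MiniOrderRestorer:
        """Buffers out-of-order items and emits them to a callback in index order."""

        def __init__(self, index_fn, callback_fn):
            self._index_fn = index_fn
            self._callback_fn = callback_fn
            self._next_index = 0
            self._buffer = {}

        def push(self, item):
            index = self._index_fn(item)
            if index is None:
                self._callback_fn(item)  # No index: emit immediately.
                return
            self._buffer[index] = item
            # Flush every consecutive item starting from the next expected index.
            while self._next_index in self._buffer:
                self._callback_fn(self._buffer.pop(self._next_index))
                self._next_index += 1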