Example #1
def score_dataset(model, dataset, print_params=None, output_file=None):
    """Outputs the model scores for the dataset.

    Args:
      model: A :class:`opennmt.models.Model` instance.
      dataset: A ``tf.data.Dataset`` instance outputting parallel features and
        labels.
      print_params: A dictionary of parameters passed to
        :meth:`opennmt.models.Model.print_score`.
      output_file: If set, outputs are saved in this file, otherwise they are
        printed on the standard output.
    """
    if output_file:
        stream = open(output_file, encoding="utf-8", mode="w")
    else:
        stream = sys.stdout

    write_fn = lambda batch: (model.print_score(
        batch, params=print_params, stream=stream))
    index_fn = lambda batch: batch.get("index")
    ordered_writer = misc.OrderRestorer(index_fn, write_fn)

    score_fn = tf.function(model.score, input_signature=dataset.element_spec)
    for features, labels in dataset:
        results = score_fn(features, labels)
        results = tf.nest.map_structure(lambda t: t.numpy(), results)
        for batch in misc.extract_batches(results):
            ordered_writer.push(batch)

    if output_file:
        stream.close()
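The scoring loop above is traced once as a graph by passing ``dataset.element_spec`` as the ``tf.function`` input signature. Below is a minimal, self-contained sketch of that pattern with toy data and a toy scoring function; nothing in it comes from the OpenNMT model.

import tensorflow as tf

# Toy parallel dataset: each element is a (features, labels) pair of dicts.
dataset = tf.data.Dataset.from_tensor_slices(
    ({"ids": [[1, 2], [3, 4]]}, {"ids": [[5, 6], [7, 8]]})).batch(2)

def toy_score(features, labels):
    # Placeholder "score": sum of the label ids per example.
    return {"score": tf.reduce_sum(labels["ids"], axis=1)}

# element_spec matches the (features, labels) structure, so the function is
# traced once for all batches instead of being retraced per input shape.
score_fn = tf.function(toy_score, input_signature=dataset.element_spec)
for features, labels in dataset:
    results = score_fn(features, labels)
    print(tf.nest.map_structure(lambda t: t.numpy(), results))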
Example #2
    def testEventOrderRestorer(self):
        events = []
        restorer = misc.OrderRestorer(index_fn=lambda x: x[0],
                                      callback_fn=lambda x: events.append(x))
        self.assertFalse(restorer.push((2, "toto")))
        self.assertFalse(restorer.push((1, "tata")))
        self.assertFalse(restorer.push((3, "foo")))
        self.assertTrue(restorer.push((0, "bar")))
        self.assertTrue(restorer.push((4, "titi")))
        with self.assertRaises(ValueError):
            restorer.push((2, "invalid"))
        self.assertEqual(len(events), 5)
        self.assertTupleEqual(events[0], (0, "bar"))
        self.assertTupleEqual(events[1], (1, "tata"))
        self.assertTupleEqual(events[2], (2, "toto"))
        self.assertTupleEqual(events[3], (3, "foo"))
        self.assertTupleEqual(events[4], (4, "titi"))
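The test pins down the contract: ``push`` buffers items that arrive early, returns ``True`` only when it actually emits something, and rejects indices that were already flushed. Below is a toy stand-in honoring that contract; it is a sketch, not the library implementation in the ``misc`` module used above.

class ToyOrderRestorer:
    """Minimal sketch of the behavior exercised by the test above."""

    def __init__(self, index_fn, callback_fn):
        self._index_fn = index_fn
        self._callback_fn = callback_fn
        self._next_index = 0
        self._buffer = {}

    def push(self, item):
        index = self._index_fn(item)
        if index < self._next_index:
            raise ValueError("Index %d was already emitted" % index)
        self._buffer[index] = item
        emitted = False
        # Flush every buffered item whose turn has come.
        while self._next_index in self._buffer:
            self._callback_fn(self._buffer.pop(self._next_index))
            self._next_index += 1
            emitted = True
        return emitted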
Example #3
    def __call__(self, step):
        """Runs the evaluator.

        Args:
          step: The current training step.

        Returns:
          A dictionary of evaluation metrics.
        """
        tf.get_logger().info("Running evaluation for step %d", step)
        output_file = None
        output_path = None
        if self._save_predictions:
            output_path = os.path.join(self._eval_dir,
                                       "predictions.txt.%d" % step)
            output_file = tf.io.gfile.GFile(output_path, "w")
            params = {"n_best": 1}
            write_fn = lambda prediction: (self._model.print_prediction(
                prediction, params=params, stream=output_file))
            index_fn = lambda prediction: prediction.get("index")
            ordered_writer = misc.OrderRestorer(index_fn, write_fn)

        loss_num = 0
        loss_den = 0
        metrics = self._model.get_metrics()
        for source, target in self._dataset:
            loss, predictions = self._eval_fn(source, target)
            if isinstance(loss, tuple):
                loss_num += loss[0]
                loss_den += loss[1]
            else:
                loss_num += loss
                loss_den += 1
            if metrics:
                self._model.update_metrics(metrics, predictions, target)
            if output_file is not None:
                predictions = {k: v.numpy() for k, v in predictions.items()}
                for prediction in misc.extract_batches(predictions):
                    ordered_writer.push(prediction)
        if loss_den == 0:
            raise RuntimeError("No examples were evaluated")
        loss = loss_num / loss_den

        results = dict(loss=loss, perplexity=tf.math.exp(loss))
        if metrics:
            for name, metric in metrics.items():
                results[name] = metric.result()
        if self._save_predictions:
            tf.get_logger().info("Evaluation predictions saved to %s",
                                 output_path)
            output_file.close()
            for scorer in self._scorers:
                score = scorer(self._labels_file, output_path)
                if isinstance(score, dict):
                    results.update(score)
                else:
                    results[scorer.name] = score

        for name, value in results.items():
            if isinstance(value, tf.Tensor):
                results[name] = value.numpy()

        self._record_results(step, results)
        self._maybe_export(step, results)
        self._maybe_garbage_collect_exports()
        return results
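The evaluator accumulates the loss either as a running ``(numerator, denominator)`` pair or as a per-batch average counted once, and reports perplexity as the exponential of the averaged loss. A quick illustration with made-up numbers (not from any real run):

import math

# Each batch reports (total token loss, token count); the values are toy data.
batch_losses = [(12.3, 10), (25.0, 20), (6.1, 5)]
loss_num = sum(num for num, _ in batch_losses)
loss_den = sum(den for _, den in batch_losses)
loss = loss_num / loss_den          # average loss per token
perplexity = math.exp(loss)         # same formula as tf.math.exp(loss) above
print(loss, perplexity)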
Example #4
    def infer(self,
              features_file,
              predictions_file=None,
              checkpoint_path=None,
              log_time=False):
        """Runs inference.

        Args:
          features_file: The file(s) to infer from.
          predictions_file: If set, predictions are saved in this file.
          checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
            the latest is used.
          log_time: If ``True``, several time metrics will be printed in the logs at
            the end of the inference loop.
        """
        checkpoint, config = self._init_run()
        checkpoint.restore(checkpoint_path=checkpoint_path, weights_only=True)
        model = checkpoint.model
        infer_config = config["infer"]
        dataset = model.examples_inputter.make_inference_dataset(
            features_file,
            infer_config["batch_size"],
            length_bucket_width=infer_config["length_bucket_width"],
            prefetch_buffer_size=infer_config.get("prefetch_buffer_size"))

        @dataset_util.function_on_next(dataset, as_numpy=True)
        def _predict(next_fn):
            source = next_fn()
            return model.infer(source)

        if predictions_file:
            stream = io.open(predictions_file, encoding="utf-8", mode="w")
        else:
            stream = sys.stdout

        ordered_writer = None
        write_fn = lambda prediction: (model.print_prediction(
            prediction, params=infer_config, stream=stream))

        total_time = 0
        total_tokens = 0
        total_examples = 0
        start_time = time.time()

        for predictions in _predict():  # pylint: disable=no-value-for-parameter
            end_time = time.time()
            if log_time:
                total_time += end_time - start_time
                batch_size = next(six.itervalues(predictions)).shape[0]
                total_examples += batch_size
                length = predictions.get("length")
                if length is not None:
                    if len(length.shape) == 2:
                        length = length[:, 0]
                    total_tokens += sum(length)
            for prediction in misc.extract_batches(predictions):
                if "index" in prediction:
                    if ordered_writer is None:
                        ordered_writer = misc.OrderRestorer(
                            index_fn=lambda prediction: prediction["index"],
                            callback_fn=write_fn)
                    ordered_writer.push(prediction)
                else:
                    write_fn(prediction)
            start_time = time.time()

        if log_time:
            tf.get_logger().info("Total prediction time (s): %f", total_time)
            tf.get_logger().info("Average prediction time (s): %f",
                                 total_time / total_examples)
            if total_tokens > 0:
                tf.get_logger().info("Tokens per second: %f",
                                     total_tokens / total_time)
        if predictions_file:
            stream.close()
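When ``log_time`` is enabled, the ``length`` output may be two-dimensional (one row per example, one column per hypothesis), and only the first column is counted. A small NumPy illustration with made-up values:

import numpy as np

length = np.array([[7, 6], [5, 5], [9, 8]])  # hypothetical (batch, n_best) lengths
if len(length.shape) == 2:
    length = length[:, 0]        # keep only the top hypothesis, as in the loop above
total_tokens = int(sum(length))  # 7 + 5 + 9 = 21
print(total_tokens)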
Example #5
    def infer(self,
              features_file,
              predictions_file=None,
              checkpoint_path=None,
              log_time=False):
        """Runs inference.

        Args:
          features_file: The file(s) to infer from.
          predictions_file: If set, predictions are saved in this file.
          checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
            the latest is used.
          log_time: If ``True``, several time metrics will be printed in the logs at
            the end of the inference loop.
        """
        checkpoint, config = self._init_run()
        checkpoint.restore(checkpoint_path=checkpoint_path, weights_only=True)
        model = checkpoint.model
        infer_config = config["infer"]
        dataset = model.examples_inputter.make_inference_dataset(
            features_file,
            infer_config["batch_size"],
            length_bucket_width=infer_config["length_bucket_width"],
            prefetch_buffer_size=infer_config.get("prefetch_buffer_size"))

        if predictions_file:
            stream = open(predictions_file, encoding="utf-8", mode="w")
        else:
            stream = sys.stdout

        infer_fn = tf.function(model.infer,
                               input_signature=(dataset.element_spec, ))
        infer_fn.get_concrete_function()  # Trace the function now.

        # Inference might return out-of-order predictions. The OrderRestorer utility is
        # used to write predictions in their original order.
        write_fn = lambda prediction: (model.print_prediction(
            prediction, params=infer_config, stream=stream))
        index_fn = lambda prediction: prediction.get("index")
        ordered_writer = misc.OrderRestorer(index_fn, write_fn)

        total_time = 0
        total_tokens = 0
        total_examples = 0
        start_time = time.time()

        for source in dataset:
            predictions = infer_fn(source)
            predictions = tf.nest.map_structure(lambda t: t.numpy(),
                                                predictions)
            for prediction in misc.extract_batches(predictions):
                ordered_writer.push(prediction)
            if log_time:
                batch_size = next(iter(predictions.values())).shape[0]
                total_examples += batch_size
                length = predictions.get("length")
                if length is not None:
                    if len(length.shape) == 2:
                        length = length[:, 0]
                    total_tokens += sum(length)

        if log_time:
            end_time = time.time()
            total_time = end_time - start_time
            tf.get_logger().info("Total prediction time (s): %f", total_time)
            tf.get_logger().info("Average prediction time (s): %f",
                                 total_time / total_examples)
            if total_tokens > 0:
                tf.get_logger().info("Tokens per second: %f",
                                     total_tokens / total_time)
        if predictions_file:
            stream.close()
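``misc.extract_batches`` is what turns the batched NumPy outputs into per-example dictionaries that can be pushed one by one. A rough, self-contained equivalent of the idea (a sketch, not the library code):

import numpy as np

def extract_batches_sketch(tensors):
    # Assumes a flat dict whose values all share the same batch dimension.
    batch_size = next(iter(tensors.values())).shape[0]
    for i in range(batch_size):
        yield {name: value[i] for name, value in tensors.items()}

predictions = {"tokens": np.array([["c", "d"], ["a", "b"]]),
               "index": np.array([1, 0])}
for prediction in extract_batches_sketch(predictions):
    print(prediction["index"], prediction["tokens"])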
Example #6
def predict_dataset(model,
                    dataset,
                    print_params=None,
                    predictions_file=None,
                    log_time=False):
    """Outputs the model predictions for the dataset.

    To run inference on strings directly, see
    :meth:`opennmt.models.Model.serve_function`.

    Args:
      model: A :class:`opennmt.models.Model` instance.
      dataset: A ``tf.data.Dataset`` instance outputting features.
      print_params: A dictionary of parameters passed to
        :meth:`opennmt.models.Model.print_prediction`.
      predictions_file: If set, predictions are saved in this file, otherwise they
        are printed on the standard output.
      log_time: If ``True``, several time metrics will be printed in the logs at
        the end of the inference loop.
    """
    if predictions_file:
        stream = open(predictions_file, encoding="utf-8", mode="w")
    else:
        stream = sys.stdout

    infer_fn = tf.function(model.infer,
                           input_signature=(dataset.element_spec, ))
    if not tf.config.functions_run_eagerly():
        tf.get_logger().info("Tracing and optimizing the inference graph...")
        infer_fn.get_concrete_function()  # Trace the function now.

    # Inference might return out-of-order predictions. The OrderRestorer utility is
    # used to write predictions in their original order.
    write_fn = lambda prediction: (model.print_prediction(
        prediction, params=print_params, stream=stream))
    index_fn = lambda prediction: prediction.get("index")
    ordered_writer = misc.OrderRestorer(index_fn, write_fn)

    total_time = 0
    total_tokens = 0
    total_examples = 0
    start_time = time.time()

    # When the inference dataset is bucketized, it can happen that no output is
    # written for a long time. To avoid confusion and the impression that the
    # process is stuck, we ensure that something is logged regularly.
    max_time_without_output = 10
    last_output_time = start_time

    for features in dataset:
        predictions = infer_fn(features)
        predictions = tf.nest.map_structure(lambda t: t.numpy(), predictions)
        batch_time = time.time()

        for prediction in misc.extract_batches(predictions):
            written = ordered_writer.push(prediction)
            if written:
                last_output_time = batch_time
            else:
                time_without_output = batch_time - last_output_time
                if time_without_output >= max_time_without_output:
                    tf.get_logger().info(
                        "%d predictions are buffered, but waiting for the prediction of "
                        "line %d to advance the output...",
                        ordered_writer.buffer_size,
                        ordered_writer.next_index + 1,
                    )
                    last_output_time = batch_time

        if log_time:
            batch_size = next(iter(predictions.values())).shape[0]
            total_examples += batch_size
            length = predictions.get("length")
            if length is not None:
                if len(length.shape) == 2:
                    length = length[:, 0]
                total_tokens += sum(length)

    if log_time:
        end_time = time.time()
        total_time = end_time - start_time
        tf.get_logger().info("Total prediction time (s): %f", total_time)
        tf.get_logger().info("Average prediction time (s): %f",
                             total_time / total_examples)
        if total_tokens > 0:
            tf.get_logger().info("Tokens per second: %f",
                                 total_tokens / total_time)
    if predictions_file:
        stream.close()
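A hypothetical way to call this helper, assuming a restored OpenNMT-tf ``model`` and a plain-text source file; the file names and batch size below are placeholders, not part of the source:

# Placeholders: `model` is a restored opennmt.models.Model, "src-test.txt" a source file.
dataset = model.examples_inputter.make_inference_dataset("src-test.txt", 32)
predict_dataset(model,
                dataset,
                print_params={"n_best": 1},
                predictions_file="predictions.txt",
                log_time=True)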
Example #7
    def infer(self,
              features_file,
              predictions_file=None,
              checkpoint_path=None,
              log_time=False):
        """Runs inference.

        Args:
          features_file: The file(s) to infer from.
          predictions_file: If set, predictions are saved in this file.
          checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
            the latest is used.
          log_time: If ``True``, several time metrics will be printed in the logs at
            the end of the inference loop.
        """
        config = self._finalize_config()
        model = self._init_model(config)
        checkpoint = checkpoint_util.Checkpoint.from_config(config, model)
        checkpoint.restore(checkpoint_path=checkpoint_path, weights_only=True)
        infer_config = config["infer"]
        dataset = model.examples_inputter.make_inference_dataset(
            features_file,
            infer_config["batch_size"],
            length_bucket_width=infer_config["length_bucket_width"],
            prefetch_buffer_size=infer_config.get("prefetch_buffer_size"))

        if predictions_file:
            stream = open(predictions_file, encoding="utf-8", mode="w")
        else:
            stream = sys.stdout

        infer_fn = tf.function(model.infer,
                               input_signature=(dataset.element_spec, ))
        tf.get_logger().info("Tracing and optimizing the inference graph...")
        infer_fn.get_concrete_function()  # Trace the function now.

        # Inference might return out-of-order predictions. The OrderRestorer utility is
        # used to write predictions in their original order.
        write_fn = lambda prediction: (model.print_prediction(
            prediction, params=infer_config, stream=stream))
        index_fn = lambda prediction: prediction.get("index")
        ordered_writer = misc.OrderRestorer(index_fn, write_fn)

        total_time = 0
        total_tokens = 0
        total_examples = 0
        start_time = time.time()

        # When the inference dataset is bucketized, it can happen that no output is
        # written for a long time. To avoid confusion and the impression that the
        # process is stuck, we ensure that something is logged regularly.
        max_time_without_output = 10
        last_output_time = start_time

        for source in dataset:
            predictions = infer_fn(source)
            predictions = tf.nest.map_structure(lambda t: t.numpy(),
                                                predictions)
            batch_time = time.time()

            for prediction in misc.extract_batches(predictions):
                written = ordered_writer.push(prediction)
                if written:
                    last_output_time = batch_time
                else:
                    time_without_output = batch_time - last_output_time
                    if time_without_output >= max_time_without_output:
                        tf.get_logger().info(
                            "%d predictions are buffered, but waiting for the prediction of "
                            "line %d to advance the output...",
                            ordered_writer.buffer_size,
                            ordered_writer.next_index + 1)
                        last_output_time = batch_time

            if log_time:
                batch_size = next(iter(predictions.values())).shape[0]
                total_examples += batch_size
                length = predictions.get("length")
                if length is not None:
                    if len(length.shape) == 2:
                        length = length[:, 0]
                    total_tokens += sum(length)

        if log_time:
            end_time = time.time()
            total_time = end_time - start_time
            tf.get_logger().info("Total prediction time (s): %f", total_time)
            tf.get_logger().info("Average prediction time (s): %f",
                                 total_time / total_examples)
            if total_tokens > 0:
                tf.get_logger().info("Tokens per second: %f",
                                     total_tokens / total_time)
        if predictions_file:
            stream.close()
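The only reason outputs can come back out of order in the first place is length bucketing: when ``length_bucket_width`` is set, examples of similar length are batched together, so batches no longer follow file order. A toy illustration of bucketing with plain TensorFlow (``Dataset.bucket_by_sequence_length`` is available in TF 2.6+; this is not the OpenNMT inputter):

import tensorflow as tf

sequences = [[1], [2, 3, 4, 5], [6, 7], [8, 9, 10, 11, 12], [13]]
dataset = tf.data.Dataset.from_generator(
    lambda: sequences, output_signature=tf.TensorSpec([None], tf.int32))
# Short and long sequences go to different buckets, so emission order differs
# from file order.
dataset = dataset.bucket_by_sequence_length(
    element_length_func=lambda x: tf.shape(x)[0],
    bucket_boundaries=[3],
    bucket_batch_sizes=[2, 2])
for batch in dataset:
    print(batch.numpy())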