def score_dataset(model, dataset, print_params=None, output_file=None):
    """Outputs the model scores for the dataset.

    Args:
      model: A :class:`opennmt.models.Model` instance.
      dataset: A ``tf.data.Dataset`` instance outputting parallel features and
        labels.
      print_params: A dictionary of parameters passed to
        :meth:`opennmt.models.Model.print_score`.
      output_file: If set, outputs are saved in this file, otherwise they are
        printed on the standard output.
    """
    if output_file:
        stream = open(output_file, encoding="utf-8", mode="w")
    else:
        stream = sys.stdout

    write_fn = lambda batch: (
        model.print_score(batch, params=print_params, stream=stream))
    index_fn = lambda batch: batch.get("index")
    ordered_writer = misc.OrderRestorer(index_fn, write_fn)

    score_fn = tf.function(model.score, input_signature=dataset.element_spec)
    for features, labels in dataset:
        results = score_fn(features, labels)
        results = tf.nest.map_structure(lambda t: t.numpy(), results)
        for batch in misc.extract_batches(results):
            ordered_writer.push(batch)

    if output_file:
        stream.close()

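# Hedged usage sketch for score_dataset. The model choice, vocabulary and test file
# paths, and the make_evaluation_dataset call below are assumptions for illustration
# (checkpoint restoration is omitted); adapt them to the actual setup.
import opennmt

model = opennmt.models.TransformerBase()
model.initialize({
    "source_vocabulary": "src_vocab.txt",  # hypothetical path
    "target_vocabulary": "tgt_vocab.txt",  # hypothetical path
})
dataset = model.examples_inputter.make_evaluation_dataset(
    "src_test.txt", "tgt_test.txt", batch_size=32)  # assumed signature
score_dataset(model, dataset, output_file="scores.txt")
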
def testEventOrderRestorer(self):
    events = []
    restorer = misc.OrderRestorer(
        index_fn=lambda x: x[0],
        callback_fn=lambda x: events.append(x))
    self.assertFalse(restorer.push((2, "toto")))
    self.assertFalse(restorer.push((1, "tata")))
    self.assertFalse(restorer.push((3, "foo")))
    self.assertTrue(restorer.push((0, "bar")))
    self.assertTrue(restorer.push((4, "titi")))
    with self.assertRaises(ValueError):
        restorer.push((2, "invalid"))
    self.assertEqual(len(events), 5)
    self.assertTupleEqual(events[0], (0, "bar"))
    self.assertTupleEqual(events[1], (1, "tata"))
    self.assertTupleEqual(events[2], (2, "toto"))
    self.assertTupleEqual(events[3], (3, "foo"))
    self.assertTupleEqual(events[4], (4, "titi"))

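# For illustration only: a minimal sketch of the buffering behavior the test above
# exercises. This is NOT the actual misc.OrderRestorer implementation from OpenNMT-tf,
# only the underlying idea: items arriving out of order are held in a dictionary and
# flushed to the callback once the next expected index becomes available.
class OrderRestorerSketch(object):

    def __init__(self, index_fn, callback_fn):
        self._index_fn = index_fn
        self._callback_fn = callback_fn
        self._next_index = 0
        self._buffer = {}

    def push(self, item):
        """Buffers the item; returns True if at least one item was written out."""
        index = self._index_fn(item)
        if index < self._next_index or index in self._buffer:
            raise ValueError("Index %d was already pushed" % index)
        self._buffer[index] = item
        if index != self._next_index:
            return False  # An earlier index is still missing: keep buffering.
        while self._next_index in self._buffer:
            self._callback_fn(self._buffer.pop(self._next_index))
            self._next_index += 1
        return True
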
def __call__(self, step):
    """Runs the evaluator.

    Args:
      step: The current training step.

    Returns:
      A dictionary of evaluation metrics.
    """
    tf.get_logger().info("Running evaluation for step %d", step)
    output_file = None
    output_path = None
    if self._save_predictions:
        output_path = os.path.join(self._eval_dir, "predictions.txt.%d" % step)
        output_file = tf.io.gfile.GFile(output_path, "w")
        params = {"n_best": 1}
        write_fn = lambda prediction: (
            self._model.print_prediction(prediction, params=params, stream=output_file))
        index_fn = lambda prediction: prediction.get("index")
        ordered_writer = misc.OrderRestorer(index_fn, write_fn)

    loss_num = 0
    loss_den = 0
    metrics = self._model.get_metrics()
    for source, target in self._dataset:
        loss, predictions = self._eval_fn(source, target)
        if isinstance(loss, tuple):
            loss_num += loss[0]
            loss_den += loss[1]
        else:
            loss_num += loss
            loss_den += 1
        if metrics:
            self._model.update_metrics(metrics, predictions, target)
        if output_file is not None:
            predictions = {k: v.numpy() for k, v in predictions.items()}
            for prediction in misc.extract_batches(predictions):
                ordered_writer.push(prediction)
    if loss_den == 0:
        raise RuntimeError("No examples were evaluated")
    loss = loss_num / loss_den

    results = dict(loss=loss, perplexity=tf.math.exp(loss))
    if metrics:
        for name, metric in metrics.items():
            results[name] = metric.result()
    if self._save_predictions:
        tf.get_logger().info("Evaluation predictions saved to %s", output_path)
        output_file.close()
        for scorer in self._scorers:
            score = scorer(self._labels_file, output_path)
            if isinstance(score, dict):
                results.update(score)
            else:
                results[scorer.name] = score

    for name, value in results.items():
        if isinstance(value, tf.Tensor):
            results[name] = value.numpy()
    self._record_results(step, results)
    self._maybe_export(step, results)
    self._maybe_garbage_collect_exports()
    return results

def infer(self,
          features_file,
          predictions_file=None,
          checkpoint_path=None,
          log_time=False):
    """Runs inference.

    Args:
      features_file: The file(s) to infer from.
      predictions_file: If set, predictions are saved in this file.
      checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
        the latest is used.
      log_time: If ``True``, several time metrics will be printed in the logs
        at the end of the inference loop.
    """
    checkpoint, config = self._init_run()
    checkpoint.restore(checkpoint_path=checkpoint_path, weights_only=True)
    model = checkpoint.model
    infer_config = config["infer"]
    dataset = model.examples_inputter.make_inference_dataset(
        features_file,
        infer_config["batch_size"],
        length_bucket_width=infer_config["length_bucket_width"],
        prefetch_buffer_size=infer_config.get("prefetch_buffer_size"))

    @dataset_util.function_on_next(dataset, as_numpy=True)
    def _predict(next_fn):
        source = next_fn()
        return model.infer(source)

    if predictions_file:
        stream = io.open(predictions_file, encoding="utf-8", mode="w")
    else:
        stream = sys.stdout

    ordered_writer = None
    write_fn = lambda prediction: (
        model.print_prediction(prediction, params=infer_config, stream=stream))

    total_time = 0
    total_tokens = 0
    total_examples = 0
    start_time = time.time()

    for predictions in _predict():  # pylint: disable=no-value-for-parameter
        end_time = time.time()
        if log_time:
            total_time += end_time - start_time
            batch_size = next(six.itervalues(predictions)).shape[0]
            total_examples += batch_size
            length = predictions.get("length")
            if length is not None:
                if len(length.shape) == 2:
                    length = length[:, 0]
                total_tokens += sum(length)
        for prediction in misc.extract_batches(predictions):
            if "index" in prediction:
                if ordered_writer is None:
                    ordered_writer = misc.OrderRestorer(
                        index_fn=lambda prediction: prediction["index"],
                        callback_fn=write_fn)
                ordered_writer.push(prediction)
            else:
                write_fn(prediction)
        start_time = time.time()

    if log_time:
        tf.get_logger().info("Total prediction time (s): %f", total_time)
        tf.get_logger().info("Average prediction time (s): %f",
                             total_time / total_examples)
        if total_tokens > 0:
            tf.get_logger().info("Tokens per second: %f", total_tokens / total_time)

    if predictions_file:
        stream.close()

def infer(self,
          features_file,
          predictions_file=None,
          checkpoint_path=None,
          log_time=False):
    """Runs inference.

    Args:
      features_file: The file(s) to infer from.
      predictions_file: If set, predictions are saved in this file.
      checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
        the latest is used.
      log_time: If ``True``, several time metrics will be printed in the logs
        at the end of the inference loop.
    """
    checkpoint, config = self._init_run()
    checkpoint.restore(checkpoint_path=checkpoint_path, weights_only=True)
    model = checkpoint.model
    infer_config = config["infer"]
    dataset = model.examples_inputter.make_inference_dataset(
        features_file,
        infer_config["batch_size"],
        length_bucket_width=infer_config["length_bucket_width"],
        prefetch_buffer_size=infer_config.get("prefetch_buffer_size"))

    if predictions_file:
        stream = open(predictions_file, encoding="utf-8", mode="w")
    else:
        stream = sys.stdout

    infer_fn = tf.function(model.infer, input_signature=(dataset.element_spec,))
    infer_fn.get_concrete_function()  # Trace the function now.

    # Inference might return out-of-order predictions. The OrderRestorer utility is
    # used to write predictions in their original order.
    write_fn = lambda prediction: (
        model.print_prediction(prediction, params=infer_config, stream=stream))
    index_fn = lambda prediction: prediction.get("index")
    ordered_writer = misc.OrderRestorer(index_fn, write_fn)

    total_time = 0
    total_tokens = 0
    total_examples = 0
    start_time = time.time()

    for source in dataset:
        predictions = infer_fn(source)
        predictions = tf.nest.map_structure(lambda t: t.numpy(), predictions)
        for prediction in misc.extract_batches(predictions):
            ordered_writer.push(prediction)
        if log_time:
            batch_size = next(iter(predictions.values())).shape[0]
            total_examples += batch_size
            length = predictions.get("length")
            if length is not None:
                if len(length.shape) == 2:
                    length = length[:, 0]
                total_tokens += sum(length)

    if log_time:
        end_time = time.time()
        total_time = end_time - start_time
        tf.get_logger().info("Total prediction time (s): %f", total_time)
        tf.get_logger().info("Average prediction time (s): %f",
                             total_time / total_examples)
        if total_tokens > 0:
            tf.get_logger().info("Tokens per second: %f", total_tokens / total_time)

    if predictions_file:
        stream.close()

def predict_dataset(model, dataset, print_params=None, predictions_file=None, log_time=False):
    """Outputs the model predictions for the dataset.

    To run inference on strings directly, see
    :meth:`opennmt.models.Model.serve_function`.

    Args:
      model: A :class:`opennmt.models.Model` instance.
      dataset: A ``tf.data.Dataset`` instance outputting features.
      print_params: A dictionary of parameters passed to
        :meth:`opennmt.models.Model.print_prediction`.
      predictions_file: If set, predictions are saved in this file, otherwise
        they are printed on the standard output.
      log_time: If ``True``, several time metrics will be printed in the logs
        at the end of the inference loop.
    """
    if predictions_file:
        stream = open(predictions_file, encoding="utf-8", mode="w")
    else:
        stream = sys.stdout

    infer_fn = tf.function(model.infer, input_signature=(dataset.element_spec,))
    if not tf.config.functions_run_eagerly():
        tf.get_logger().info("Tracing and optimizing the inference graph...")
        infer_fn.get_concrete_function()  # Trace the function now.

    # Inference might return out-of-order predictions. The OrderRestorer utility is
    # used to write predictions in their original order.
    write_fn = lambda prediction: (
        model.print_prediction(prediction, params=print_params, stream=stream))
    index_fn = lambda prediction: prediction.get("index")
    ordered_writer = misc.OrderRestorer(index_fn, write_fn)

    total_time = 0
    total_tokens = 0
    total_examples = 0
    start_time = time.time()

    # When the inference dataset is bucketized, it can happen that no output is
    # written for a long time. To avoid confusion and to not give the impression
    # that the process is stuck, we ensure that something is logged regularly.
    max_time_without_output = 10
    last_output_time = start_time

    for features in dataset:
        predictions = infer_fn(features)
        predictions = tf.nest.map_structure(lambda t: t.numpy(), predictions)
        batch_time = time.time()
        for prediction in misc.extract_batches(predictions):
            written = ordered_writer.push(prediction)
            if written:
                last_output_time = batch_time
            else:
                time_without_output = batch_time - last_output_time
                if time_without_output >= max_time_without_output:
                    tf.get_logger().info(
                        "%d predictions are buffered, but waiting for the prediction of "
                        "line %d to advance the output...",
                        ordered_writer.buffer_size,
                        ordered_writer.next_index + 1,
                    )
                    last_output_time = batch_time
        if log_time:
            batch_size = next(iter(predictions.values())).shape[0]
            total_examples += batch_size
            length = predictions.get("length")
            if length is not None:
                if len(length.shape) == 2:
                    length = length[:, 0]
                total_tokens += sum(length)

    if log_time:
        end_time = time.time()
        total_time = end_time - start_time
        tf.get_logger().info("Total prediction time (s): %f", total_time)
        tf.get_logger().info("Average prediction time (s): %f",
                             total_time / total_examples)
        if total_tokens > 0:
            tf.get_logger().info("Tokens per second: %f", total_tokens / total_time)

    if predictions_file:
        stream.close()

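# Hedged usage sketch for predict_dataset, mirroring how the Runner methods in this
# section build the inference dataset. The model choice, vocabulary paths, and batch
# size are placeholders; restoring trained weights is omitted for brevity.
import opennmt

model = opennmt.models.TransformerBase()
model.initialize({
    "source_vocabulary": "src_vocab.txt",  # hypothetical path
    "target_vocabulary": "tgt_vocab.txt",  # hypothetical path
})
dataset = model.examples_inputter.make_inference_dataset("test.txt", 32)
predict_dataset(model, dataset, predictions_file="predictions.txt", log_time=True)
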
def infer(self,
          features_file,
          predictions_file=None,
          checkpoint_path=None,
          log_time=False):
    """Runs inference.

    Args:
      features_file: The file(s) to infer from.
      predictions_file: If set, predictions are saved in this file.
      checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
        the latest is used.
      log_time: If ``True``, several time metrics will be printed in the logs
        at the end of the inference loop.
    """
    config = self._finalize_config()
    model = self._init_model(config)
    checkpoint = checkpoint_util.Checkpoint.from_config(config, model)
    checkpoint.restore(checkpoint_path=checkpoint_path, weights_only=True)
    infer_config = config["infer"]
    dataset = model.examples_inputter.make_inference_dataset(
        features_file,
        infer_config["batch_size"],
        length_bucket_width=infer_config["length_bucket_width"],
        prefetch_buffer_size=infer_config.get("prefetch_buffer_size"))

    if predictions_file:
        stream = open(predictions_file, encoding="utf-8", mode="w")
    else:
        stream = sys.stdout

    infer_fn = tf.function(model.infer, input_signature=(dataset.element_spec,))
    tf.get_logger().info("Tracing and optimizing the inference graph...")
    infer_fn.get_concrete_function()  # Trace the function now.

    # Inference might return out-of-order predictions. The OrderRestorer utility is
    # used to write predictions in their original order.
    write_fn = lambda prediction: (
        model.print_prediction(prediction, params=infer_config, stream=stream))
    index_fn = lambda prediction: prediction.get("index")
    ordered_writer = misc.OrderRestorer(index_fn, write_fn)

    total_time = 0
    total_tokens = 0
    total_examples = 0
    start_time = time.time()

    # When the inference dataset is bucketized, it can happen that no output is
    # written for a long time. To avoid confusion and to not give the impression
    # that the process is stuck, we ensure that something is logged regularly.
    max_time_without_output = 10
    last_output_time = start_time

    for source in dataset:
        predictions = infer_fn(source)
        predictions = tf.nest.map_structure(lambda t: t.numpy(), predictions)
        batch_time = time.time()
        for prediction in misc.extract_batches(predictions):
            written = ordered_writer.push(prediction)
            if written:
                last_output_time = batch_time
            else:
                time_without_output = batch_time - last_output_time
                if time_without_output >= max_time_without_output:
                    tf.get_logger().info(
                        "%d predictions are buffered, but waiting for the prediction of "
                        "line %d to advance the output...",
                        ordered_writer.buffer_size,
                        ordered_writer.next_index + 1)
                    last_output_time = batch_time
        if log_time:
            batch_size = next(iter(predictions.values())).shape[0]
            total_examples += batch_size
            length = predictions.get("length")
            if length is not None:
                if len(length.shape) == 2:
                    length = length[:, 0]
                total_tokens += sum(length)

    if log_time:
        end_time = time.time()
        total_time = end_time - start_time
        tf.get_logger().info("Total prediction time (s): %f", total_time)
        tf.get_logger().info("Average prediction time (s): %f",
                             total_time / total_examples)
        if total_tokens > 0:
            tf.get_logger().info("Tokens per second: %f", total_tokens / total_time)

    if predictions_file:
        stream.close()

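# Hedged end-to-end sketch of calling this infer method through a Runner, assuming
# the OpenNMT-tf 2.x Runner constructor; the run directory, data paths, and model
# choice are placeholders and the exact constructor may differ between versions.
import opennmt

config = {
    "model_dir": "run/",                       # hypothetical run directory
    "data": {
        "source_vocabulary": "src_vocab.txt",  # hypothetical path
        "target_vocabulary": "tgt_vocab.txt",  # hypothetical path
    },
}
runner = opennmt.Runner(opennmt.models.TransformerBase(), config, auto_config=True)
runner.infer("test.txt", predictions_file="predictions.txt", log_time=True)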