# Assumed module-level imports for the methods below. The ``hooks`` and
# ``OrderRestorer`` paths follow the upstream OpenNMT-tf v1 layout;
# ``estimator_util`` (used by the last ``infer`` variant) is fork-specific.
import io
import os
import sys
import time

import tensorflow as tf

from opennmt.utils import hooks
from opennmt.utils.misc import OrderRestorer


def infer(self,
          features_file,
          predictions_file=None,
          checkpoint_path=None,
          log_time=False):
  """Runs inference.

  Args:
    features_file: The file(s) to infer from.
    predictions_file: If set, predictions are saved in this file.
    checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
      the latest is used.
    log_time: If ``True``, several time metrics will be printed in the logs
      at the end of the inference loop.
  """
  if "infer" not in self._config:
    self._config["infer"] = {}
  if checkpoint_path is not None and tf.gfile.IsDirectory(checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
  input_fn = self._model.input_fn(
      tf.estimator.ModeKeys.PREDICT,
      self._config["infer"]["batch_size"],
      self._config["data"],
      features_file,
      bucket_width=self._config["infer"]["bucket_width"],
      num_threads=self._config["infer"].get("num_threads"),
      prefetch_buffer_size=self._config["infer"].get("prefetch_buffer_size"))
  if predictions_file:
    stream = io.open(predictions_file, encoding="utf-8", mode="w")
  else:
    stream = sys.stdout
  infer_hooks = []
  if log_time:
    infer_hooks.append(hooks.LogPredictionTimeHook())
  ordered_writer = None
  write_fn = lambda prediction: self._model.print_prediction(
      prediction, params=self._config["infer"], stream=stream)
  for prediction in self._estimator.predict(
      input_fn=input_fn,
      checkpoint_path=checkpoint_path,
      hooks=infer_hooks):
    # If the index is part of the prediction, predictions may arrive out of order.
    if "index" in prediction:
      if ordered_writer is None:
        ordered_writer = OrderRestorer(
            index_fn=lambda prediction: prediction["index"],
            callback_fn=write_fn)
      ordered_writer.push(prediction)
    else:
      write_fn(prediction)
  if predictions_file:
    stream.close()
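# Illustrative sketch (an assumption, not the project's ``OrderRestorer``
# implementation): with a parallel input pipeline, predictions tagged with an
# "index" key can arrive out of order, so they are buffered until every
# earlier index has been written. A minimal version of that pattern:
class _OrderRestorerSketch(object):

  def __init__(self, index_fn, callback_fn):
    self._index_fn = index_fn
    self._callback_fn = callback_fn
    self._next_index = 0
    self._buffer = {}

  def push(self, item):
    self._buffer[self._index_fn(item)] = item
    # Flush every consecutive buffered item, starting from the next
    # expected index.
    while self._next_index in self._buffer:
      self._callback_fn(self._buffer.pop(self._next_index))
      self._next_index += 1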
def get_alignment(self, features_file, checkpoint_path=None, log_time=False):
  """Runs inference and returns the predicted sentences with their alignments.

  Args:
    features_file: The file(s) to infer from.
    checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
      the latest is used.
    log_time: If ``True``, several time metrics will be printed in the logs
      at the end of the inference loop.

  Returns:
    A ``(sentences, alignments)`` tuple with one entry per selected hypothesis.
  """
  if "infer" not in self._config:
    self._config["infer"] = {}
  if checkpoint_path is not None and os.path.isdir(checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
  batch_size = self._config["infer"].get("batch_size", 1)
  input_fn = self._model.input_fn(
      tf.estimator.ModeKeys.PREDICT,
      batch_size,
      self._config["data"],
      features_file,
      num_threads=self._config["infer"].get("num_threads"),
      prefetch_buffer_size=self._config["infer"].get("prefetch_buffer_size"))
  infer_hooks = []
  if log_time:
    infer_hooks.append(hooks.LogPredictionTimeHook())

  def _parse_prediction(prediction, params=None):
    n_best = params and params.get("n_best")
    n_best = n_best or 1
    if n_best > len(prediction["tokens"]):
      raise ValueError("n_best cannot be greater than beam_width")
    n_best_sentences = []
    n_alignments = []
    for i in range(n_best):
      tokens = prediction["tokens"][i][:prediction["length"][i] - 1]  # Ignore </s>.
      sentence = self._model.target_inputter.tokenizer.detokenize(tokens)
      n_best_sentences.append(sentence)
      # Alignment matrix of this hypothesis.
      n_alignments.append(prediction["alignment"][i])
    return n_best_sentences, n_alignments

  sentences = []
  alignments = []
  for prediction in self._estimator.predict(
      input_fn=input_fn,
      checkpoint_path=checkpoint_path,
      hooks=infer_hooks):
    n_best_sentences, n_alignments = _parse_prediction(
        prediction, params=self._config["infer"])
    sentences.extend(n_best_sentences)
    alignments.extend(n_alignments)
  return sentences, alignments
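# Hedged usage sketch for ``get_alignment``: this helper and the shape
# assumption are illustrative, not part of the original API. It assumes each
# returned alignment is a ``[target_length, source_length]`` attention matrix.
import numpy as np

def alignments_to_hard_links(alignments):
  """Reduces soft attention matrices to hard alignments by taking, for each
  target token, the source position with the highest attention weight."""
  return [np.argmax(alignment, axis=-1) for alignment in alignments]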
def infer(self,
          features_file,
          predictions_file=None,
          checkpoint_path=None,
          log_time=False):
  """Runs inference.

  Args:
    features_file: The file(s) to infer from.
    predictions_file: If set, predictions are saved in this file.
    checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
      the latest is used.
    log_time: If ``True``, several time metrics will be printed in the logs
      at the end of the inference loop.
  """
  if "infer" not in self._config:
    self._config["infer"] = {}
  if checkpoint_path is not None and os.path.isdir(checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
  batch_size = self._config["infer"].get("batch_size", 1)
  input_fn = self._model.input_fn(
      tf.estimator.ModeKeys.PREDICT,
      batch_size,
      self._config["data"],
      features_file,
      num_threads=self._config["infer"].get("num_threads"),
      prefetch_buffer_size=self._config["infer"].get("prefetch_buffer_size"))
  if predictions_file:
    stream = io.open(predictions_file, encoding="utf-8", mode="w")
  else:
    stream = sys.stdout
  infer_hooks = []
  if log_time:
    infer_hooks.append(hooks.LogPredictionTimeHook())
  for prediction in self._estimator.predict(
      input_fn=input_fn,
      checkpoint_path=checkpoint_path,
      hooks=infer_hooks):
    self._model.print_prediction(
        prediction, params=self._config["infer"], stream=stream)
  if predictions_file:
    stream.close()
def infer(self,
          features_file,
          predictions_file=None,
          checkpoint_path=None,
          log_time=False):
  """Runs inference.

  Args:
    features_file: The file(s) to infer from.
    predictions_file: If set, predictions are saved in this file.
    checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
      the latest is used.
    log_time: If ``True``, several time metrics will be printed in the logs
      at the end of the inference loop.
  """
  need_ae = self._ae_model is not None
  if checkpoint_path is not None and tf.gfile.IsDirectory(checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
  input_fn = estimator_util.make_input_fn(
      self._model,
      tf.estimator.ModeKeys.PREDICT,
      self._config["infer"]["batch_size"],
      features_file=features_file,
      bucket_width=self._config["infer"]["bucket_width"],
      num_threads=self._config["infer"].get("num_threads"),
      prefetch_buffer_size=self._config["infer"].get("prefetch_buffer_size"),
      return_dataset=False,
      need_ae=need_ae)
  if predictions_file:
    stream = io.open(predictions_file, encoding="utf-8", mode="w")
  else:
    stream = sys.stdout
  infer_hooks = []
  if log_time:
    infer_hooks.append(hooks.LogPredictionTimeHook())
  ordered_writer = None
  write_fn = lambda prediction: self._model.print_prediction(
      prediction, params=self._config["infer"], stream=stream)
  estimator = self._make_estimator()
  sample_count = 0
  for prediction in estimator.predict(
      input_fn=input_fn,
      checkpoint_path=checkpoint_path,
      hooks=infer_hooks):
    # Log progress every 1000 predictions.
    if sample_count % 1000 == 0:
      now_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
      tf.logging.info("{}:{}".format(now_time, sample_count))
    sample_count += 1
    # If the index is part of the prediction, predictions may arrive out of order.
    if "index" in prediction:
      if ordered_writer is None:
        ordered_writer = OrderRestorer(
            index_fn=lambda prediction: prediction["index"],
            callback_fn=write_fn)
      ordered_writer.push(prediction)
    else:
      write_fn(prediction)
  if predictions_file:
    stream.close()
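# Hedged usage sketch: ``Runner`` is a hypothetical name for the class these
# methods belong to, and the file names are placeholders; the config keys
# mirror the ones read above (``infer/batch_size``, ``infer/bucket_width``).
#
#   runner = Runner(model, config)
#   runner.infer("src-test.txt", predictions_file="pred-test.txt", log_time=True)
#   sentences, alignments = runner.get_alignment("src-test.txt")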