def _predict_without_distribute_strategy(self, model, input_fn):
  """Predicts the dataset without using distribute strategy.

  Calls `input_fn()` once and feeds every batch's features straight to
  `model.predict_on_batch`, with no distribution strategy wrapping the call.

  Args:
    model: Object exposing `predict_on_batch(features)`. Element `[0]` of its
      result is paired with start logits and element `[1]` with end logits.
    input_fn: Zero-argument callable returning an iterable of
      `(features, label)` pairs. The label is ignored. `features['unique_ids']`
      must yield items with a `.numpy()` method (presumably eager tensors --
      confirm against the input pipeline).

  Returns:
    A list of `run_squad_helper.RawResult`, one per
    `(unique_id, start_logits, end_logits)` triple, in iteration order.
  """
  dataset = input_fn()
  predictions = []
  for batch_features, _ in dataset:
    batch_outputs = model.predict_on_batch(batch_features)
    ids = batch_features['unique_ids']
    start_batch = batch_outputs[0]
    end_batch = batch_outputs[1]
    # zip stops at the shortest of the three sequences.
    for example_id, start, end in zip(ids, start_batch, end_batch):
      predictions.append(
          run_squad_helper.RawResult(
              unique_id=example_id.numpy(),
              start_logits=start.tolist(),
              end_logits=end.tolist()))
      processed = len(predictions)
      # Progress log once every 100 accumulated records.
      if processed % 100 == 0:
        tf.compat.v1.logging.info('Made predictions for %d records.',
                                  processed)
  return predictions
def predict_tflite(self, tflite_filepath, dataset):
  """Predicts the dataset for TFLite model in `tflite_filepath`.

  Builds a `model_util.LiteRunner` for the model file and runs it over every
  batch of `dataset`. The runner receives this object's
  `reorder_input_details` and `reorder_output_details`.

  Args:
    tflite_filepath: Path to the TFLite model file. It is passed unchanged
      to `model_util.LiteRunner`.
    dataset: Iterable of `(features, label)` pairs. The label is ignored.
      `features['unique_ids']` must yield items with a `.numpy()` method
      (presumably eager tensors -- confirm against the caller).

  Returns:
    A list of `run_squad_helper.RawResult`, one per
    `(unique_id, start_logits, end_logits)` triple, in iteration order.
    Element `[0]` of each runner output supplies the start logits and
    element `[1]` supplies the end logits.
  """
  collected = []
  runner = model_util.LiteRunner(
      tflite_filepath,
      self.reorder_input_details,
      self.reorder_output_details,
  )
  for batch_features, _ in dataset:
    batch_outputs = runner.run(batch_features)
    ids = batch_features['unique_ids']
    start_batch = batch_outputs[0]
    end_batch = batch_outputs[1]
    # zip stops at the shortest of the three sequences.
    for example_id, start, end in zip(ids, start_batch, end_batch):
      collected.append(
          run_squad_helper.RawResult(
              unique_id=example_id.numpy(),
              start_logits=start.tolist(),
              end_logits=end.tolist()))
      processed = len(collected)
      # Progress log once every 100 accumulated records.
      if processed % 100 == 0:
        tf.compat.v1.logging.info('Made predictions for %d records.',
                                  processed)
  return collected