def __init__(self,
             estimator: tf.estimator.Estimator,
             feature_columns: List[fc2.FeatureColumn],
             serving_input_fn: Callable[..., Any],
             output_key: Optional[Text] = None,
             **kwargs):
  """Initialize an EstimatorMetadataBuilder.

  Args:
    estimator: Estimator instance to observe and save.
    feature_columns: A group of feature columns to export metadata for. These
      feature columns need to be basic feature columns and not derived
      columns such as embedding, indicator, bucketized.
    serving_input_fn: Serving input function to be used when exporting the
      model.
    output_key: Output key to find the model's relevant output tensors. Some
      valid values are logits, probabilities. If not provided, will default to
      logits and regression outputs.
    **kwargs: Any keyword arguments to be passed to the estimator's
      export_saved_model() function.

  Raises:
    ValueError: If `estimator` is not a tf.estimator.Estimator, or if
      `feature_columns` is empty.
  """
  if not isinstance(estimator, tf.estimator.Estimator):
    raise ValueError('A valid estimator needs to be provided.')
  self._estimator = estimator
  if not feature_columns:
    raise ValueError('feature_columns cannot be empty.')
  self._feature_columns = feature_columns
  self._output_key = output_key
  self._monkey_patcher = monkey_patch_utils.EstimatorMonkeyPatchHelper()
  self._serving_input_fn = serving_input_fn
  self._save_args = kwargs
  # save_model_with_metadata() reads this before assigning it; it is filled in
  # on the first save and reused for later saves.
  self._metadata = None
def save_model_with_metadata(self, file_path: str) -> str:
  """Saves the model and the generated metadata to the given file path.

  New metadata will not be generated for each call to this function since an
  Estimator is static. Calling this function with different paths will save
  the model and the same metadata to all paths.

  Args:
    file_path: Path to save the model and the metadata. It can be a GCS bucket
      or a local folder. The folder needs to be empty.

  Returns:
    Full file path where the model and the metadata are written.
  """
  # A fresh helper per call, so each export's observed tensors start empty.
  monkey_patcher = monkey_patch_utils.EstimatorMonkeyPatchHelper()
  with monkey_patcher.exporting_context(self._output_key):
    model_path = self._estimator.export_saved_model(
        file_path, self._serving_input_fn, **self._save_args)
  # Per the TF Estimator docs, export_saved_model() returns the export
  # directory as bytes; decode it so the value matches the declared `str`
  # return type and can be joined with str path components.
  if isinstance(model_path, bytes):
    model_path = model_path.decode('utf-8')
  # getattr() tolerates `_metadata` never having been initialized.
  if not getattr(self, '_metadata', None):
    self._metadata = self._create_metadata_from_tensors(
        monkey_patcher.feature_tensors_dict,
        monkey_patcher.crossed_columns,
        [fc.name for fc in self._feature_columns],
        monkey_patcher.output_tensors_dict)
  utils.write_metadata_to_file(self._metadata.to_dict(), model_path)
  return model_path
def test_exporting_context_no_output_key_classifier(self):
  """With no output key, exporting the DNN classifier captures the logits."""
  helper = monkey_patch_utils.EstimatorMonkeyPatchHelper()
  with helper.exporting_context():
    self.classifier_dnn.export_saved_model(
        tf.test.get_temp_dir(), self._get_json_serving_input_fn)
  captured_outputs = helper.output_tensors_dict
  self.assertIn(prediction_keys.PredictionKeys.LOGITS, captured_outputs)
def test_exporting_context_weighted_column(self):
  """Checks the feature tensors captured while exporting the weighted model."""
  helper = monkey_patch_utils.EstimatorMonkeyPatchHelper()
  with helper.exporting_context('logits'):
    self.weighted_linear.export_saved_model(
        tf.test.get_temp_dir(), self._get_weighted_json_serving_input_fn)
  observed = helper.feature_tensors_dict
  self.assertLen(observed, 3)
  # Presumably ids plus weights for the weighted 'language' column -- confirm.
  language_entry = observed['language'][0]
  self.assertLen(language_entry.input_tensor, 2)
def test_exporting_context_export_works_after_exception(self):
  """An export that fails inside exporting_context must not break later ones."""
  patcher = monkey_patch_utils.EstimatorMonkeyPatchHelper()
  # 'none' is not a valid output key, so this export is expected to raise.
  # Asserting it explicitly keeps the test from passing vacuously when no
  # exception occurs (the assertion used to live only in an except branch).
  with self.assertRaises(ValueError):
    with patcher.exporting_context('none'):
      self.classifier_dnn.export_saved_model(
          tf.test.get_temp_dir(), self._get_json_serving_input_fn)
  # Outside the context, a plain export must succeed.
  path = self.classifier_dnn.export_saved_model(
      tf.test.get_temp_dir(), self._get_json_serving_input_fn)
  # Folder should contain 'saved_model.pb' and 'variables' files.
  self.assertLen(os.listdir(path), 2)
def test_observe_with_boosted_trees(self):
  """Checks the tensors captured while exporting the boosted-trees model."""
  helper = monkey_patch_utils.EstimatorMonkeyPatchHelper()
  with helper.exporting_context('logits'):
    self.boosted_tree.export_saved_model(
        tf.test.get_temp_dir(), self._get_json_serving_input_fn)
  observed = helper.feature_tensors_dict
  self.assertLen(observed, 3)
  # Each base column has exactly one captured entry with no encoded tensors.
  for column in self.base_columns:
    name = column.name
    self.assertIn(name, observed)
    self.assertLen(observed[name], 1)
    self.assertEmpty(observed[name][0].encoded_tensors)
  self.assertEmpty(helper.crossed_columns)
  self.assertLen(helper.output_tensors_dict, 1)
def test_exporting_context_wide_deep(self):
  """Checks the tensors captured while exporting the wide & deep classifier."""
  helper = monkey_patch_utils.EstimatorMonkeyPatchHelper()
  with helper.exporting_context('logits'):
    self.wide_deep_classifier.export_saved_model(
        tf.test.get_temp_dir(), self._get_json_serving_input_fn)
  observed = helper.feature_tensors_dict
  self.assertLen(observed, 3)
  # Every base column is captured twice (presumably once per tower -- confirm).
  for column in self.base_columns:
    name = column.name
    self.assertIn(name, observed)
    self.assertLen(observed[name], 2)
  self.assertEmpty(observed['age'][0].encoded_tensors)
  self.assertLen(observed['language'][0].encoded_tensors, 3)
  self.assertLen(observed['class_identity'][0].encoded_tensors, 2)
  self.assertLen(helper.crossed_columns, 2)
  self.assertLen(helper.output_tensors_dict, 1)
def test_exporting_context_incorrect_output_key(self):
  """An unknown output key makes the patched regressor export raise."""
  helper = monkey_patch_utils.EstimatorMonkeyPatchHelper()
  expected_error = 'Output key .* is not found'
  with self.assertRaisesRegex(ValueError, expected_error):
    with helper.exporting_context('none'):
      self.regressor.export_saved_model(
          tf.test.get_temp_dir(), self._get_json_serving_input_fn)