Example no. 1
0
    def __init__(self,
                 estimator: tf.estimator.Estimator,
                 feature_columns: List[fc2.FeatureColumn],
                 serving_input_fn: Callable[..., Any],
                 output_key: Optional[Text] = None,
                 **kwargs):
        """Initializes an EstimatorMetadataBuilder.

        Args:
          estimator: Estimator instance to observe and save.
          feature_columns: A group of feature columns to export metadata for.
            These feature columns need to be basic feature columns and not
            derived columns such as embedding, indicator, bucketized.
          serving_input_fn: Serving input function to be used when exporting
            the model.
          output_key: Output key to find the model's relevant output tensors.
            Some valid values are logits, probabilities. If not provided, will
            default to logits and regression outputs.
          **kwargs: Any keyword arguments to be passed to export_saved_model.
            add_meta_graph() function.

        Raises:
          ValueError: If estimator is not a tf.estimator.Estimator instance or
            feature_columns is empty.
        """
        if not isinstance(estimator, tf.estimator.Estimator):
            raise ValueError('A valid estimator needs to be provided.')
        self._estimator = estimator
        if not feature_columns:
            raise ValueError('feature_columns cannot be empty.')
        self._feature_columns = feature_columns
        self._output_key = output_key
        self._monkey_patcher = monkey_patch_utils.EstimatorMonkeyPatchHelper()
        self._serving_input_fn = serving_input_fn
        self._save_args = kwargs
        # Metadata is generated lazily by the first save_model_with_metadata()
        # call and cached here. Initialize the cache so the `if not
        # self._metadata` check there cannot raise AttributeError.
        self._metadata = None
    def save_model_with_metadata(self, file_path: str) -> str:
        """Saves the model and the generated metadata to the given file path.

        New metadata will not be generated for each call to this function since
        an Estimator is static. Calling this function with different paths will
        save the model and the same metadata to all paths.

        Args:
          file_path: Path to save the model and the metadata. It can be a GCS
            bucket or a local folder. The folder needs to be empty.

        Returns:
          Full file path where the model and the metadata are written.
        """
        # A fresh patcher is created per export so patched state from one call
        # never leaks into the next.
        monkey_patcher = monkey_patch_utils.EstimatorMonkeyPatchHelper()
        with monkey_patcher.exporting_context(self._output_key):
            model_path = self._estimator.export_saved_model(
                file_path, self._serving_input_fn, **self._save_args)

        # Generate metadata only once and cache it. getattr guards against
        # instances whose __init__ did not pre-initialize `_metadata`, which
        # previously raised AttributeError on the first call.
        if not getattr(self, '_metadata', None):
            self._metadata = self._create_metadata_from_tensors(
                monkey_patcher.feature_tensors_dict,
                monkey_patcher.crossed_columns,
                [fc.name for fc in self._feature_columns],
                monkey_patcher.output_tensors_dict)
        utils.write_metadata_to_file(self._metadata.to_dict(), model_path)
        return model_path
 def test_exporting_context_no_output_key_classifier(self):
     """Exporting a classifier with no output key captures the logits tensor."""
     helper = monkey_patch_utils.EstimatorMonkeyPatchHelper()
     with helper.exporting_context():
         self.classifier_dnn.export_saved_model(
             tf.test.get_temp_dir(), self._get_json_serving_input_fn)
     self.assertIn(prediction_keys.PredictionKeys.LOGITS,
                   helper.output_tensors_dict)
 def test_exporting_context_weighted_column(self):
     """A weighted column's input tensors are observed during export."""
     helper = monkey_patch_utils.EstimatorMonkeyPatchHelper()
     with helper.exporting_context('logits'):
         self.weighted_linear.export_saved_model(
             tf.test.get_temp_dir(),
             self._get_weighted_json_serving_input_fn)
     observed = helper.feature_tensors_dict
     self.assertLen(observed, 3)
     self.assertLen(observed['language'][0].input_tensor, 2)
 def test_exporting_context_export_works_after_exception(self):
     """Export still works after a failed exporting_context.

     The original try/except left `path` unbound when the expected ValueError
     was not raised, producing a confusing NameError instead of a clear test
     failure; assertRaises makes the expectation explicit.
     """
     helper = monkey_patch_utils.EstimatorMonkeyPatchHelper()
     # 'none' is not a valid output key, so exporting inside the context is
     # expected to fail.
     with self.assertRaises(ValueError):
         with helper.exporting_context('none'):
             self.classifier_dnn.export_saved_model(
                 tf.test.get_temp_dir(), self._get_json_serving_input_fn)
     # A plain export afterwards must succeed, proving the monkey patches
     # were unwound despite the exception.
     path = self.classifier_dnn.export_saved_model(
         tf.test.get_temp_dir(), self._get_json_serving_input_fn)
     # Folder should contain 'saved_model.pb' and 'variables' files.
     self.assertLen(os.listdir(path), 2)
 def test_observe_with_boosted_trees(self):
     """Boosted-tree export observes base features with no encoded tensors."""
     helper = monkey_patch_utils.EstimatorMonkeyPatchHelper()
     with helper.exporting_context('logits'):
         self.boosted_tree.export_saved_model(
             tf.test.get_temp_dir(), self._get_json_serving_input_fn)
     observed = helper.feature_tensors_dict
     self.assertLen(observed, 3)
     for column in self.base_columns:
         self.assertIn(column.name, observed)
         self.assertLen(observed[column.name], 1)
         self.assertEmpty(observed[column.name][0].encoded_tensors)
     self.assertEmpty(helper.crossed_columns)
     self.assertLen(helper.output_tensors_dict, 1)
 def test_exporting_context_wide_deep(self):
     """Wide & deep export observes each base feature twice plus encodings."""
     helper = monkey_patch_utils.EstimatorMonkeyPatchHelper()
     with helper.exporting_context('logits'):
         self.wide_deep_classifier.export_saved_model(
             tf.test.get_temp_dir(), self._get_json_serving_input_fn)
     observed = helper.feature_tensors_dict
     self.assertLen(observed, 3)
     # Each base column feeds both the wide and the deep part of the model.
     for column in self.base_columns:
         self.assertIn(column.name, observed)
         self.assertLen(observed[column.name], 2)
     self.assertEmpty(observed['age'][0].encoded_tensors)
     self.assertLen(observed['language'][0].encoded_tensors, 3)
     self.assertLen(observed['class_identity'][0].encoded_tensors, 2)
     self.assertLen(helper.crossed_columns, 2)
     self.assertLen(helper.output_tensors_dict, 1)
 def test_exporting_context_incorrect_output_key(self):
     """An unknown output key raises a descriptive ValueError during export."""
     helper = monkey_patch_utils.EstimatorMonkeyPatchHelper()
     with self.assertRaisesRegex(ValueError, 'Output key .* is not found'):
         with helper.exporting_context('none'):
             self.regressor.export_saved_model(
                 tf.test.get_temp_dir(), self._get_json_serving_input_fn)