def __init__(self,
             model_feature_specification_fn=None,
             model_label_specification_fn=None,
             is_model_device_tpu=False):
        """Initialize an instance.

        The provided specifications are used both for the in and out
        specification. The _preprocess_fn will not alter the provided tensors.

        Every provided specification function is called once for each
        estimator mode (TRAIN, PREDICT, EVAL) and its result is validated
        immediately, so a malformed spec fails at construction time.

        Args:
          model_feature_specification_fn: (Optional) A function which takes
            mode as an argument and returns a valid spec structure for the
            features, preferably a (hierarchical) namedtuple of TensorSpecs and
            OptionalTensorSpecs.
          model_label_specification_fn: (Optional) A function which takes mode
            as an argument and returns a valid spec structure for the labels,
            preferably a (hierarchical) namedtuple of TensorSpecs and
            OptionalTensorSpecs.
          is_model_device_tpu: True if the model is operating on TPU and
            otherwise False. This information is useful to do type conversions
            and strip unnecessary information from preprocessing since no
            summaries are generated on TPUs.

        Raises:
          Any error raised by tensorspec_utils.assert_valid_spec_structure
          (e.g. a ValueError for an invalid structure) propagates unchanged.
        """
        for spec_generator in (model_feature_specification_fn,
                               model_label_specification_fn):
            # Whether a generator was supplied does not depend on the mode, so
            # check it once per generator rather than once per mode.
            if not spec_generator:
                continue
            for estimator_mode in (ModeKeys.TRAIN, ModeKeys.PREDICT,
                                   ModeKeys.EVAL):
                tensorspec_utils.assert_valid_spec_structure(
                    spec_generator(estimator_mode))

        self._model_feature_specification_fn = model_feature_specification_fn
        self._model_label_specification_fn = model_label_specification_fn
        self._is_model_device_tpu = is_model_device_tpu
Beispiel #2
0
 def set_specification_from_model(self, t2r_model):
     """Set up specs via the parent class, then capture base preprocessor specs.

     See base class documentation for the inherited behavior. Afterwards the
     in-feature and in-label specifications of the model's base preprocessor
     are read for PREDICT mode, stored on the instance, and validated.

     Args:
       t2r_model: The model whose `preprocessor.base_preprocessor` supplies the
         base specifications.
     """
     parent = super(MockMetaExportGenerator, self)
     parent.set_specification_from_model(t2r_model)

     base_preprocessor = t2r_model.preprocessor.base_preprocessor
     predict_mode = tf.estimator.ModeKeys.PREDICT

     self._base_feature_spec = base_preprocessor.get_in_feature_specification(
         predict_mode)
     tensorspec_utils.assert_valid_spec_structure(self._base_feature_spec)

     self._base_label_spec = base_preprocessor.get_in_label_specification(
         predict_mode)
     tensorspec_utils.assert_valid_spec_structure(self._base_label_spec)
Beispiel #3
0
  def set_specification_from_model(self, t2r_model):
    """Derive feature specs and the preprocessing function from a model.

    The input and output feature specifications of the model's preprocessor
    are fetched for the module-level MODE, stored, and validated. The
    preprocessor's `preprocess` method is then bound to that mode, and the
    model's class name is recorded.

    Args:
      t2r_model: A T2R model instance exposing a `preprocessor`.
    """
    preprocessor = t2r_model.preprocessor

    in_spec = preprocessor.get_in_feature_specification(MODE)
    self._feature_spec = in_spec
    tensorspec_utils.assert_valid_spec_structure(in_spec)

    out_spec = preprocessor.get_out_feature_specification(MODE)
    self._out_feature_spec = out_spec
    tensorspec_utils.assert_valid_spec_structure(out_spec)

    self._preprocess_fn = functools.partial(preprocessor.preprocess, mode=MODE)
    self._model_name = type(t2r_model).__name__
    def set_specification_from_model(self, t2r_model, mode):
        """Collect and validate every spec needed to build an input pipeline.

        The in/out feature and label specifications are read from the model's
        preprocessor for the given mode. Each is stored and validated, and the
        preprocessor's `preprocess` method is bound to that mode.

        Args:
          t2r_model: A T2RModel whose preprocessor supplies all necessary
            feature and label specifications.
          mode: A tf.estimator.ModeKeys value selecting which specifications
            to fetch.
        """
        preprocessor = t2r_model.preprocessor

        in_feature_spec = preprocessor.get_in_feature_specification(mode)
        self._feature_spec = in_feature_spec
        tensorspec_utils.assert_valid_spec_structure(in_feature_spec)

        in_label_spec = preprocessor.get_in_label_specification(mode)
        self._label_spec = in_label_spec
        tensorspec_utils.assert_valid_spec_structure(in_label_spec)

        # The output specifications are needed to verify that what the dataset
        # inputs produce fulfills our specification.
        out_feature_spec = preprocessor.get_out_feature_specification(mode)
        self._out_feature_spec = out_feature_spec
        tensorspec_utils.assert_valid_spec_structure(out_feature_spec)

        out_label_spec = preprocessor.get_out_label_specification(mode)
        self._out_label_spec = out_label_spec
        tensorspec_utils.assert_valid_spec_structure(out_label_spec)

        self._preprocess_fn = functools.partial(
            preprocessor.preprocess, mode=mode)
Beispiel #5
0
 def test_assert_valid_spec_structure_invalid(self, spec_or_tensors):
     """An invalid spec structure must be rejected with a ValueError."""
     self.assertRaises(ValueError, utils.assert_valid_spec_structure,
                       spec_or_tensors)
Beispiel #6
0
 def test_assert_valid_spec_structure_is_valid(self, collection_type):
     """A collection built by _make_tensorspec_collection validates cleanly."""
     utils.assert_valid_spec_structure(
         self._make_tensorspec_collection(collection_type))
 def set_label_specifications(self, label_spec, out_label_spec):
     """Validate and then store the in and out label specifications.

     Args:
       label_spec: The (input) label spec structure.
       out_label_spec: The output label spec structure.
     """
     # Both specs are validated before either is stored, so a failure leaves
     # the instance untouched.
     for spec in (label_spec, out_label_spec):
         tensorspec_utils.assert_valid_spec_structure(spec)
     self._label_spec, self._out_label_spec = label_spec, out_label_spec
 def set_feature_specifications(self, feature_spec, out_feature_spec):
     """Validate and then store the in and out feature specifications.

     Args:
       feature_spec: The (input) feature spec structure.
       out_feature_spec: The output feature spec structure.
     """
     # Both specs are validated before either is stored, so a failure leaves
     # the instance untouched.
     for spec in (feature_spec, out_feature_spec):
         tensorspec_utils.assert_valid_spec_structure(spec)
     self._feature_spec, self._out_feature_spec = feature_spec, out_feature_spec
Beispiel #9
0
 def test_assert_valid_spec_structure_invalid(self, spec_or_tensors):
     """An invalid spec structure must be rejected with a ValueError."""
     self.assertRaises(  # pylint: disable=g-error-prone-assert-raises
         ValueError, utils.assert_valid_spec_structure, spec_or_tensors)