def __init__(self,
             model_feature_specification_fn=None,
             model_label_specification_fn=None,
             is_model_device_tpu=False):
  """Initialize an instance.

  The provided specifications are used both for the in and out specification.
  The _preprocess_fn will not alter the provided tensors.

  Args:
    model_feature_specification_fn: (Optional) A function which takes mode as
      an argument and returns a valid spec structure for the features,
      preferably a (hierarchical) namedtuple of TensorSpecs and
      OptionalTensorSpecs.
    model_label_specification_fn: (Optional) A function which takes mode as an
      argument and returns a valid spec structure for the labels, preferably a
      (hierarchical) namedtuple of TensorSpecs and OptionalTensorSpecs.
    is_model_device_tpu: True if the model is operating on TPU and otherwise
      False. This information is useful to do type conversions and strip
      unnecessary information from preprocessing since no summaries are
      generated on TPUs.

  Raises:
    Whatever tensorspec_utils.assert_valid_spec_structure raises (ValueError
    for invalid structures, per its tests) if a provided spec function returns
    an invalid spec structure for any estimator mode.
  """
  # Eagerly validate every provided spec function for all three estimator
  # modes so that a malformed spec is reported at construction time.
  for spec_generator in (model_feature_specification_fn,
                         model_label_specification_fn):
    # Absent (falsy) spec functions are optional and skipped entirely.
    if not spec_generator:
      continue
    for estimator_mode in (ModeKeys.TRAIN, ModeKeys.PREDICT, ModeKeys.EVAL):
      tensorspec_utils.assert_valid_spec_structure(
          spec_generator(estimator_mode))
  self._model_feature_specification_fn = model_feature_specification_fn
  self._model_label_specification_fn = model_label_specification_fn
  self._is_model_device_tpu = is_model_device_tpu
def set_specification_from_model(self, t2r_model):
  """See base class documentation.

  After delegating to the parent implementation, this additionally fetches
  the PREDICT-mode in-feature and in-label specifications of the model's
  base preprocessor, stores them, and validates each one.

  Args:
    t2r_model: A T2R model whose preprocessor exposes a `base_preprocessor`.
  """
  super(MockMetaExportGenerator, self).set_specification_from_model(t2r_model)
  predict_mode = tf.estimator.ModeKeys.PREDICT

  # Feature spec of the wrapped (base) preprocessor.
  self._base_feature_spec = (
      t2r_model.preprocessor.base_preprocessor.get_in_feature_specification(
          predict_mode))
  tensorspec_utils.assert_valid_spec_structure(self._base_feature_spec)

  # Label spec of the wrapped (base) preprocessor.
  self._base_label_spec = (
      t2r_model.preprocessor.base_preprocessor.get_in_label_specification(
          predict_mode))
  tensorspec_utils.assert_valid_spec_structure(self._base_label_spec)
def set_specification_from_model(self, t2r_model):
  """Capture feature specs, the preprocess function and the model's name.

  The in- and out-feature specifications are fetched from the model's
  preprocessor for the module-level MODE, stored, and validated. The
  preprocessor's `preprocess` method is stored with `mode=MODE` pre-bound,
  and the model's class name is recorded.

  Args:
    t2r_model: A T2R model instance.
  """
  model_preprocessor = t2r_model.preprocessor

  self._feature_spec = model_preprocessor.get_in_feature_specification(MODE)
  tensorspec_utils.assert_valid_spec_structure(self._feature_spec)

  self._out_feature_spec = model_preprocessor.get_out_feature_specification(
      MODE)
  tensorspec_utils.assert_valid_spec_structure(self._out_feature_spec)

  # Pre-bind the mode so later calls need not pass it explicitly.
  self._preprocess_fn = functools.partial(
      model_preprocessor.preprocess, mode=MODE)
  self._model_name = type(t2r_model).__name__
def set_specification_from_model(self, t2r_model, mode):
  """Set all specifications needed to create and verify an input pipeline.

  The in/out feature and label specifications are fetched from the model's
  preprocessor for `mode`, stored on the instance, and validated. The
  preprocessor's `preprocess` method is also stored with `mode` pre-bound.

  Args:
    t2r_model: A T2RModel from which we get all necessary feature and label
      specifications.
    mode: One of the tf.estimator.ModeKeys values (e.g. TRAIN, EVAL, PREDICT)
      selecting which specifications to fetch.

  Raises:
    Whatever tensorspec_utils.assert_valid_spec_structure raises (ValueError
    for invalid structures, per its tests) if any fetched spec is invalid.
  """
  preprocessor = t2r_model.preprocessor
  # Specifications of the preprocessor's inputs.
  self._feature_spec = preprocessor.get_in_feature_specification(mode)
  tensorspec_utils.assert_valid_spec_structure(self._feature_spec)
  self._label_spec = preprocessor.get_in_label_specification(mode)
  tensorspec_utils.assert_valid_spec_structure(self._label_spec)
  # It is necessary to verify that the output of the dataset inputs fulfills
  # our specification.
  self._out_feature_spec = preprocessor.get_out_feature_specification(mode)
  tensorspec_utils.assert_valid_spec_structure(self._out_feature_spec)
  self._out_label_spec = preprocessor.get_out_label_specification(mode)
  tensorspec_utils.assert_valid_spec_structure(self._out_label_spec)
  # Pre-bind the mode so later calls need not pass it explicitly.
  self._preprocess_fn = functools.partial(preprocessor.preprocess, mode=mode)
def test_assert_valid_spec_structure_invalid(self, spec_or_tensors):
  """Validating an invalid spec structure must raise ValueError."""
  self.assertRaises(ValueError, utils.assert_valid_spec_structure,
                    spec_or_tensors)
def test_assert_valid_spec_structure_is_valid(self, collection_type):
  """A spec collection built by _make_tensorspec_collection must validate."""
  utils.assert_valid_spec_structure(
      self._make_tensorspec_collection(collection_type))
def set_label_specifications(self, label_spec, out_label_spec):
  """Validate and store the in- and out-label specifications.

  Both specs are validated before either is stored, so if validation raises,
  the previously stored label specs are left unchanged.

  Args:
    label_spec: The in-label spec structure.
    out_label_spec: The out-label spec structure.
  """
  for candidate_spec in (label_spec, out_label_spec):
    tensorspec_utils.assert_valid_spec_structure(candidate_spec)
  self._label_spec, self._out_label_spec = label_spec, out_label_spec
def set_feature_specifications(self, feature_spec, out_feature_spec):
  """Validate and store the in- and out-feature specifications.

  Both specs are validated before either is stored, so if validation raises,
  the previously stored feature specs are left unchanged.

  Args:
    feature_spec: The in-feature spec structure.
    out_feature_spec: The out-feature spec structure.
  """
  for candidate_spec in (feature_spec, out_feature_spec):
    tensorspec_utils.assert_valid_spec_structure(candidate_spec)
  self._feature_spec, self._out_feature_spec = feature_spec, out_feature_spec
def test_assert_valid_spec_structure_invalid(self, spec_or_tensors):
  """Malformed specs or raw tensors must be rejected with ValueError."""
  validate = utils.assert_valid_spec_structure
  self.assertRaises(ValueError, validate, spec_or_tensors)