def test_remove_optional(self):
     """Checks TPU wrapper specs against the wrapped no-cast preprocessor.

     Input specifications must be forwarded unchanged. Output specifications
     must equal the wrapped preprocessor's output specifications with float32
     replaced by bfloat16 and the 'optional_value' entry removed.
     """
     base_preprocessor = noop_preprocessor.NoOpPreprocessor(
         model_feature_specification_fn=lambda mode: _FEATURE_SPEC_NO_CAST,
         model_label_specification_fn=lambda mode: _LABEL_SPEC_NO_CAST)
     wrapper = tpu_preprocessor_wrapper.TPUPreprocessorWrapper(
         preprocessor=base_preprocessor)

     def expected_tpu_spec(spec):
         # Build the expected wrapper output: cast float32 -> bfloat16, then
         # drop the optional entry (a KeyError here means it was missing).
         tpu_spec = tensorspec_utils.replace_dtype(
             spec, from_dtype=tf.float32, to_dtype=tf.bfloat16)
         del tpu_spec['optional_value']
         return tpu_spec

     # Input specifications pass through the wrapper unchanged.
     in_spec_getters = (
         (wrapper.get_in_feature_specification,
          base_preprocessor.get_in_feature_specification),
         (wrapper.get_in_label_specification,
          base_preprocessor.get_in_label_specification),
     )
     for wrapped_getter, base_getter in in_spec_getters:
         self.assertDictEqual(
             wrapped_getter(_MODE_TRAIN), base_getter(_MODE_TRAIN))

     # Output specifications are cast to bfloat16 with optional specs removed.
     out_spec_getters = (
         (wrapper.get_out_feature_specification,
          base_preprocessor.get_out_feature_specification),
         (wrapper.get_out_label_specification,
          base_preprocessor.get_out_label_specification),
     )
     for wrapped_getter, base_getter in out_spec_getters:
         expected = expected_tpu_spec(base_getter(_MODE_TRAIN))
         self.assertDictEqual(wrapped_getter(_MODE_TRAIN), expected)
# Example 2
 def get_feature_specification(self, mode):
   """Returns the wrapped model's feature specification cast for TPU use.

   The feature specification of the wrapped model is passed through
   tensorspec_utils.replace_dtype so that float32 specs become bfloat16.

   Args:
     mode: mode key forwarded to the wrapped model's
       get_feature_specification.

   Returns:
     The wrapped model's feature specification with bfloat16 replacing
     float32.
   """
   model_feature_spec = self._t2r_model.get_feature_specification(mode)
   return tensorspec_utils.replace_dtype(
       model_feature_spec, from_dtype=tf.float32, to_dtype=tf.bfloat16)
    def get_in_label_specification(self, mode):
        """Returns the input label specification expected by preprocess_fn.

        The wrapped preprocessor's input label specification is returned with
        bfloat16 specs replaced by float32.

        Arguments:
          mode: mode key for this label specification.

        Returns:
          A TensorSpecStruct describing the required and optional tensors.
        """
        wrapped_in_label_spec = (
            self._preprocessor.get_in_label_specification(mode))
        return tensorspec_utils.replace_dtype(
            wrapped_in_label_spec, from_dtype=tf.bfloat16, to_dtype=tf.float32)
    def get_out_label_specification(self, mode):
        """Returns the output label specification produced by preprocess_fn.

        All optional specs are stripped first, to further reduce
        communication and computation overhead when feeding TPUs. In the
        remaining required specs, float32 is then replaced by bfloat16.

        Arguments:
          mode: mode key for this label specification.

        Returns:
          A TensorSpecStruct describing only the required tensors.
        """
        wrapped_out_label_spec = (
            self._preprocessor.get_out_label_specification(mode))
        required_label_spec = tensorspec_utils.filter_required_flat_tensor_spec(
            wrapped_out_label_spec)
        return tensorspec_utils.replace_dtype(
            required_label_spec, from_dtype=tf.float32, to_dtype=tf.bfloat16)