def test_build_encode_decode_tf_computations_for_broadcast(
    self, encoder_constructor):
  """Checks that the broadcast encode/decode computations have matching types.

  Asserts that the encode computation takes the encoder state type as its
  first parameter and `value_type` as its second. It must return the state
  type as its first result element. The second result element must be
  exactly the parameter type of the decode computation, and decode must
  return `value_type`.
  """
  spec = tf.TensorSpec((20,), tf.float32)
  simple_encoder = te.encoders.as_simple_encoder(encoder_constructor(), spec)
  _, state_type = encoding_utils._build_initial_state_tf_computation(
      simple_encoder)
  value_type = computation_types.to_type(spec)

  encode_fn, decode_fn = (
      encoding_utils._build_encode_decode_tf_computations_for_broadcast(
          state_type, value_type, simple_encoder))
  encode_sig = encode_fn.type_signature
  decode_sig = decode_fn.type_signature

  # The state type appears both as encode's first input and first output.
  self.assertEqual(state_type, encode_sig.parameter[0])
  self.assertEqual(state_type, encode_sig.result[0])
  # Whatever encode emits as its second output is what decode consumes.
  self.assertEqual(encode_sig.result[1], decode_sig.parameter)
  # Decode returns the same type that encode accepts as its value input.
  self.assertEqual(value_type, encode_sig.parameter[1])
  self.assertEqual(value_type, decode_sig.result)
def test_build_tf_computations_for_sum(self, encoder_constructor):
  """Checks that the partial gather computations have matching signatures.

  Despite the test name, the computations under test come from
  `encoding_utils._build_tf_computations_for_gather`. Each partial
  computation's relevant input types are compared against the output types
  of the computation that feeds it.
  """
  spec = tf.TensorSpec((20,), tf.float32)
  gather_encoder = te.encoders.as_gather_encoder(encoder_constructor(), spec)
  _, state_type = encoding_utils._build_initial_state_tf_computation(
      gather_encoder)
  value_type = computation_types.to_type(spec)
  comps = encoding_utils._build_tf_computations_for_gather(
      state_type, value_type, gather_encoder)

  # get_params: state -> (encode params, decode-before-sum params,
  #                       decode-after-sum params).
  params_sig = comps.get_params_fn.type_signature
  self.assertEqual(state_type, params_sig.parameter)
  encode_params_type = params_sig.result[0]
  before_sum_params_type = params_sig.result[1]
  after_sum_params_type = params_sig.result[2]

  # encode: (value, encode params, decode-before-sum params) -> result whose
  # third element carries the state update tensors.
  encode_sig = comps.encode_fn.type_signature
  self.assertEqual(value_type, encode_sig.parameter[0])
  self.assertEqual(encode_params_type, encode_sig.parameter[1])
  self.assertEqual(before_sum_params_type, encode_sig.parameter[2])
  state_update_tensors_type = encode_sig.result[2]

  # The accumulator produced by zero keeps the same type through
  # accumulate, merge and report.
  accumulator_type = comps.zero_fn.type_signature.result
  self.assertEqual(state_update_tensors_type,
                   accumulator_type.state_update_tensors)

  accumulate_sig = comps.accumulate_fn.type_signature
  self.assertEqual(accumulator_type, accumulate_sig.parameter[0])
  self.assertEqual(encode_sig.result, accumulate_sig.parameter[1])
  self.assertEqual(accumulator_type, accumulate_sig.result)

  merge_sig = comps.merge_fn.type_signature
  self.assertEqual(accumulator_type, merge_sig.parameter[0])
  self.assertEqual(accumulator_type, merge_sig.parameter[1])
  self.assertEqual(accumulator_type, merge_sig.result)

  report_sig = comps.report_fn.type_signature
  self.assertEqual(accumulator_type, report_sig.parameter)
  self.assertEqual(accumulator_type, report_sig.result)

  # decode_after_sum: (accumulator `values` field, decode-after-sum params)
  # -> value.
  decode_sig = comps.decode_after_sum_fn.type_signature
  self.assertEqual(accumulator_type.values, decode_sig.parameter[0])
  self.assertEqual(after_sum_params_type, decode_sig.parameter[1])
  self.assertEqual(value_type, decode_sig.result)

  # update_state: (state, state update tensors) -> state.
  update_sig = comps.update_state_fn.type_signature
  self.assertEqual(state_type, update_sig.parameter[0])
  self.assertEqual(state_update_tensors_type, update_sig.parameter[1])
  self.assertEqual(state_type, update_sig.result)