コード例 #1
0
  def test_build_encode_decode_tf_computations_for_broadcast(
      self, encoder_constructor):
    """Broadcast encode/decode computations have compatible type signatures.

    Wraps the parameterized encoder as a simple encoder over a `(20,)` float32
    tensor, builds the broadcast encode and decode TF computations, and checks
    that:
      * encode takes the encoder state as its first input and returns it as
        its first output,
      * the second output of encode is exactly the input of decode, and
      * encode's second input and decode's output are both the value type.
    """
    spec = tf.TensorSpec((20,), tf.float32)
    simple_encoder = te.encoders.as_simple_encoder(encoder_constructor(), spec)
    _, state_type = encoding_utils._build_initial_state_tf_computation(
        simple_encoder)
    value_type = computation_types.to_type(spec)

    encode_fn, decode_fn = (
        encoding_utils._build_encode_decode_tf_computations_for_broadcast(
            state_type, value_type, simple_encoder))
    encode_sig = encode_fn.type_signature
    decode_sig = decode_fn.type_signature

    # Encode threads the state through: (state, value) -> (state, payload).
    self.assertEqual(state_type, encode_sig.parameter[0])
    self.assertEqual(state_type, encode_sig.result[0])
    # Whatever encode emits as its payload is what decode consumes.
    self.assertEqual(encode_sig.result[1], decode_sig.parameter)
    # Round trip: encode accepts value_type and decode gives value_type back.
    self.assertEqual(value_type, encode_sig.parameter[1])
    self.assertEqual(value_type, decode_sig.result)
コード例 #2
0
  def test_build_tf_computations_for_sum(self, encoder_constructor):
    """Partial gather computations have matching input/output signatures.

    Although named for "sum", this test wraps the parameterized encoder as a
    gather encoder over a `(20,)` float32 tensor. It builds the partial
    computations with `encoding_utils._build_tf_computations_for_gather`. It
    then checks that the relevant types line up across every stage:
    get_params -> encode -> zero/accumulate/merge/report -> decode_after_sum,
    and that update_state matches the state and the state-update tensors.
    """
    spec = tf.TensorSpec((20,), tf.float32)
    gather_encoder = te.encoders.as_gather_encoder(encoder_constructor(), spec)
    _, state_type = encoding_utils._build_initial_state_tf_computation(
        gather_encoder)
    value_type = computation_types.to_type(spec)
    fns = encoding_utils._build_tf_computations_for_gather(
        state_type, value_type, gather_encoder)

    # get_params: state -> (encode params, decode-before-sum params,
    #                       decode-after-sum params).
    params_sig = fns.get_params_fn.type_signature
    self.assertEqual(state_type, params_sig.parameter)
    encode_params_type = params_sig.result[0]
    decode_before_sum_params_type = params_sig.result[1]
    decode_after_sum_params_type = params_sig.result[2]

    # encode: (value, encode params, decode-before-sum params) -> (...).
    # Its third output holds the tensors later used to update the state.
    encode_sig = fns.encode_fn.type_signature
    self.assertEqual(value_type, encode_sig.parameter[0])
    self.assertEqual(encode_params_type, encode_sig.parameter[1])
    self.assertEqual(decode_before_sum_params_type, encode_sig.parameter[2])
    state_update_tensors_type = encode_sig.result[2]

    # The accumulator produced by zero carries the state-update tensors.
    accumulator_type = fns.zero_fn.type_signature.result
    self.assertEqual(state_update_tensors_type,
                     accumulator_type.state_update_tensors)

    # accumulate folds one encode output into the accumulator.
    accumulate_sig = fns.accumulate_fn.type_signature
    self.assertEqual(accumulator_type, accumulate_sig.parameter[0])
    self.assertEqual(encode_sig.result, accumulate_sig.parameter[1])
    self.assertEqual(accumulator_type, accumulate_sig.result)

    # merge combines two accumulators into one.
    merge_sig = fns.merge_fn.type_signature
    self.assertEqual(accumulator_type, merge_sig.parameter[0])
    self.assertEqual(accumulator_type, merge_sig.parameter[1])
    self.assertEqual(accumulator_type, merge_sig.result)

    # report leaves the accumulator type unchanged.
    report_sig = fns.report_fn.type_signature
    self.assertEqual(accumulator_type, report_sig.parameter)
    self.assertEqual(accumulator_type, report_sig.result)

    # decode_after_sum: (accumulated values, decode-after-sum params) -> value.
    decode_sig = fns.decode_after_sum_fn.type_signature
    self.assertEqual(accumulator_type.values, decode_sig.parameter[0])
    self.assertEqual(decode_after_sum_params_type, decode_sig.parameter[1])
    self.assertEqual(value_type, decode_sig.result)

    # update_state: (state, state-update tensors) -> state.
    update_sig = fns.update_state_fn.type_signature
    self.assertEqual(state_type, update_sig.parameter[0])
    self.assertEqual(state_update_tensors_type, update_sig.parameter[1])
    self.assertEqual(state_type, update_sig.result)