Example #1
0
        def encode_fn(x, params):
            """Encodes `x` with the wrapped encoder; see the `encode` method.

            Args:
              x: Value to encode; must satisfy `tensorspec.is_compatible_with`.
              params: Flat encode parameters, checked against
                `encode_params_spec`.

            Returns:
              A pair `(flat_encoded, flat_updates)`: the TF part of the encoded
              structure flattened into a dict via
              `py_utils.flatten_with_joined_string_paths`, and the state update
              tensors flattened into a tuple.

            Raises:
              ValueError: If `x` is not compatible with `tensorspec`.
            """
            if not tensorspec.is_compatible_with(x):
                raise ValueError(
                    'The provided x is not compatible with the expected tensorspec.'
                )
            py_utils.assert_compatible(encode_params_spec, params)

            # Re-nest the flat TF params, then merge back the Python-only values
            # stored under 'encode_params' in `internal_py_values`.
            packed_params = tf.nest.pack_sequence_as(
                internal_structure['encode_params'], params)
            full_params = py_utils.merge_dicts(
                packed_params, internal_py_values['encode_params'])

            encoded_x, state_update_tensors, input_shapes = encoder.encode(
                x, full_params)
            # Only the before-sum portion of the input shapes is kept in the
            # encoded structure.
            shapes_before_sum, _ = (
                core_encoder.split_shapes_by_commuting_structure(
                    input_shapes, commuting_structure))

            encoded_structure = {
                _TENSORS: encoded_x,
                _SHAPES: shapes_before_sum,
            }
            py_part, tf_part = py_utils.split_dict_py_tf(encoded_structure)

            # Record the TF structure and the Python-only values so the flat
            # outputs can later be re-nested and merged.
            _add_to_structure('encoded_structure', tf_part)
            _add_to_structure('state_update_tensors', state_update_tensors)
            _add_to_py_values('encoded_structure', py_part)

            flat_encoded = dict(py_utils.flatten_with_joined_string_paths(tf_part))
            flat_updates = tuple(tf.nest.flatten(state_update_tensors))
            return flat_encoded, flat_updates
Example #2
0
        def decode_before_sum_fn(encoded_structure, params):
            """Partially decodes an encoded value; see `decode_before_sum`.

            Args:
              encoded_structure: Flat TF part of the encoded structure, checked
                against `encoded_structure_spec`.
              params: Flat parameters, checked against
                `decode_before_sum_params_spec`.

            Returns:
              The result of `encoder.decode_before_sum`. If it is a dict, it is
              flattened via `py_utils.flatten_with_joined_string_paths` into a
              dict; otherwise it is returned unchanged.
            """
            py_utils.assert_compatible(encoded_structure_spec, encoded_structure)
            py_utils.assert_compatible(decode_before_sum_params_spec, params)

            # Rebuild both nested inputs from their recorded structures and merge
            # in the Python-only values stored alongside them.
            full_structure = py_utils.merge_dicts(
                tf.nest.pack_sequence_as(internal_structure['encoded_structure'],
                                         tf.nest.flatten(encoded_structure)),
                internal_py_values['encoded_structure'])
            full_params = py_utils.merge_dicts(
                tf.nest.pack_sequence_as(
                    internal_structure['decode_before_sum_params'], params),
                internal_py_values['decode_before_sum_params'])

            tensors = full_structure[_TENSORS]
            shapes = full_structure[_SHAPES]
            partial = encoder.decode_before_sum(tensors, full_params, shapes)

            _add_to_structure('part_decoded_structure', partial)
            if not isinstance(partial, dict):
                return partial
            return dict(py_utils.flatten_with_joined_string_paths(partial))
Example #3
0
 def update_state_fn(flat_state, state_update_tensors):
   """Applies state update tensors to the state; see `update_state`.

   Args:
     flat_state: Flat state values, checked against `flat_state_spec` and
       re-nested using the structure recorded under 'state'.
     state_update_tensors: Flat update tensors, re-nested using the structure
       recorded under 'state_update_tensors'.

   Returns:
     The state returned by `encoder.update_state`, flattened into a tuple.
   """
   py_utils.assert_compatible(flat_state_spec, flat_state)
   nested_state = tf.nest.pack_sequence_as(
       internal_structure['state'], flat_state)
   nested_updates = tf.nest.pack_sequence_as(
       internal_structure['state_update_tensors'], state_update_tensors)
   return tuple(
       tf.nest.flatten(encoder.update_state(nested_state, nested_updates)))
Example #4
0
 def test_assert_compatible(self):
   """`assert_compatible` accepts matching nested and unknown-shape specs."""
   nested_spec = [
       tf.TensorSpec((2,), tf.int32),
       (tf.TensorSpec((2, 3), tf.float64), tf.TensorSpec((), tf.float32)),
   ]
   nested_value = [
       tf.zeros((2,), tf.int32),
       (tf.zeros((2, 3), tf.float64), tf.zeros((), tf.float32)),
   ]
   py_utils.assert_compatible(nested_spec, nested_value)

   # A spec with shape None should accept a value of any shape.
   unknown_shape_spec = tf.TensorSpec(None, tf.int32)
   py_utils.assert_compatible(unknown_shape_spec, tf.zeros((2,), tf.int32))
Example #5
0
    def decode_after_sum_fn(part_decoded_structure, params, num_summands):
      """Completes decoding; see the `decode_after_sum` method of this class.

      Args:
        part_decoded_structure: Flat partially decoded value, checked against
          `part_decoded_structure_spec`.
        params: Flat parameters, checked against `decode_after_sum_params_spec`.
          After re-nesting and merging, they hold the decoder params under
          `_PARAMS` and the shapes under `_SHAPES`.
        num_summands: Forwarded unchanged to `encoder.decode_after_sum`.

      Returns:
        The value returned by `encoder.decode_after_sum`.
      """
      py_utils.assert_compatible(part_decoded_structure_spec,
                                 part_decoded_structure)
      py_utils.assert_compatible(decode_after_sum_params_spec, params)

      nested_partial = tf.nest.pack_sequence_as(
          internal_structure['part_decoded_structure'],
          tf.nest.flatten(part_decoded_structure))
      packed_params = tf.nest.pack_sequence_as(
          internal_structure['decode_after_sum_params'], params)
      merged_params = py_utils.merge_dicts(
          packed_params, internal_py_values['decode_after_sum_params'])

      return encoder.decode_after_sum(nested_partial, merged_params[_PARAMS],
                                      num_summands, merged_params[_SHAPES])
Example #6
0
        def get_params_fn(flat_state):
            """Derives encode/decode params from the state; see `get_params`.

            Args:
              flat_state: Flat state values, checked against `flat_state_spec`
                and re-nested using the structure recorded under 'state'.

            Returns:
              A 3-tuple of tuples holding the flattened TF parts of, in order:
              the encode params, the decode-before-sum params, and the
              decode-after-sum params. The last group bundles the decoder's
              after-sum params (`_PARAMS`) with the after-sum input shapes
              (`_SHAPES`).
            """
            py_utils.assert_compatible(flat_state_spec, flat_state)
            nested_state = tf.nest.pack_sequence_as(
                internal_structure['state'], flat_state)

            encode_params, all_decode_params = encoder.get_params(nested_state)
            before_sum_params, after_sum_params = (
                core_encoder.split_params_by_commuting_structure(
                    all_decode_params, commuting_structure))

            # decode_after_sum needs part of the input shapes that encode
            # produces. Obtain them by encoding a zero tensor built from
            # `tensorspec`, then fold the after-sum portion into the params.
            _, _, dummy_shapes = encoder.encode(
                tf.zeros(tensorspec.shape, tensorspec.dtype), encode_params)
            _, shapes_after_sum = (
                core_encoder.split_shapes_by_commuting_structure(
                    dummy_shapes, commuting_structure))
            after_sum_bundle = {
                _PARAMS: after_sum_params,
                _SHAPES: shapes_after_sum,
            }

            # Split each group into Python-only and TF values. Record all TF
            # structures first, then all Python values, then return the
            # flattened TF values.
            groups = (
                ('encode_params', encode_params),
                ('decode_before_sum_params', before_sum_params),
                ('decode_after_sum_params', after_sum_bundle),
            )
            split_groups = []
            for group_name, group_value in groups:
                py_part, tf_part = py_utils.split_dict_py_tf(group_value)
                split_groups.append((group_name, py_part, tf_part))

            for group_name, _, tf_part in split_groups:
                _add_to_structure(group_name, tf_part)
            for group_name, py_part, _ in split_groups:
                _add_to_py_values(group_name, py_part)

            return tuple(
                tuple(tf.nest.flatten(tf_part)) for _, _, tf_part in split_groups)
Example #7
0
 def test_assert_compatible_raises_type_error(self, not_a_spec):
   """`assert_compatible` raises TypeError for the given `not_a_spec`."""
   expected_error = TypeError
   with self.assertRaises(expected_error):  # pylint: disable=g-error-prone-assert-raises
     py_utils.assert_compatible(not_a_spec, None)
Example #8
0
 def test_assert_compatible_raises_incompatible_shapes(self, dtype):
   """A value of shape (3,) against a (2,) spec raises ValueError."""
   two_element_spec = tf.TensorSpec((2,), dtype)
   three_element_value = tf.zeros((3,), dtype)
   with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
     py_utils.assert_compatible(two_element_spec, three_element_value)
Example #9
0
 def test_assert_compatible_raises_incompatible_dtype(self, *shape):
   """A float64 value against a float32 spec of equal shape raises ValueError."""
   float32_spec = tf.TensorSpec(shape, tf.float32)
   float64_value = tf.zeros(shape, tf.float64)
   with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
     py_utils.assert_compatible(float32_spec, float64_value)