Ejemplo n.º 1
0
        def encode_fn(x, params):
            """See the `encode` method of this class.

            Validates the inputs, packs `params` into the recorded
            `encode_params` structure, merges in the recorded Python values and
            runs `encoder.encode`.

            Args:
              x: The value to be encoded. Must be compatible with `tensorspec`.
              params: The flat encoding parameters, checked against
                `encode_params_spec`.

            Returns:
              A tuple of two elements: a `dict` built from the joined string
              paths of the TensorFlow part of the encoded structure, and a
              `tuple` of the flattened `state_update_tensors`.

            Raises:
              ValueError: If `x` is not compatible with `tensorspec`.
            """
            if not tensorspec.is_compatible_with(x):
                raise ValueError(
                    'The provided x is not compatible with the expected tensorspec.'
                )
            py_utils.assert_compatible(encode_params_spec, params)

            # Rebuild the nested parameter structure from the flat input, then
            # add back the values that were split off as Python values.
            packed_params = tf.nest.pack_sequence_as(
                internal_structure['encode_params'], params)
            full_params = py_utils.merge_dicts(
                packed_params, internal_py_values['encode_params'])

            encoded_x, state_update_tensors, input_shapes = encoder.encode(
                x, full_params)

            # Only the first part of the split shapes is kept here.
            shapes_before_sum, _ = (
                core_encoder.split_shapes_by_commuting_structure(
                    input_shapes, commuting_structure))

            py_part, tf_part = py_utils.split_dict_py_tf({
                _TENSORS: encoded_x,
                _SHAPES: shapes_before_sum,
            })

            _add_to_structure('encoded_structure', tf_part)
            _add_to_structure('state_update_tensors', state_update_tensors)
            _add_to_py_values('encoded_structure', py_part)

            flat_encoded = dict(
                py_utils.flatten_with_joined_string_paths(tf_part))
            flat_state_update = tuple(tf.nest.flatten(state_update_tensors))
            return flat_encoded, flat_state_update
        def encode_fn(x, flat_state):
            """Encodes `x` with parameters derived from `flat_state`.

            Args:
              x: The value passed to `encoder.encode`.
              flat_state: The flat state, packed into the structure stored in
                `state_py_structure['state']`.

            Returns:
              A tuple of two elements: the TensorFlow part of the flattened
              encoded structure (a dictionary keyed by '/'-joined paths), and
              a `tuple` with the flattened updated state.
            """
            state = nest.pack_sequence_as(state_py_structure['state'],
                                          flat_state)
            encode_params, decode_params = encoder.get_params(state)
            encoded_x, state_update_tensors, input_shapes = encoder.encode(
                x, encode_params)
            new_state = encoder.update_state(state, state_update_tensors)
            updated_flat_state = tuple(nest.flatten(new_state))

            # Collapse the nested structures used by the underlying encoder
            # into one flat dictionary, which is simpler for users of
            # SimpleEncoder to manipulate.
            nested_encoded = {
                _TENSORS: encoded_x,
                _PARAMS: decode_params,
                _SHAPES: input_shapes
            }
            flat_pairs = nest.flatten_with_joined_string_paths(
                nested_encoded, separator='/')
            flat_py_part, flat_tf_part = py_utils.split_dict_py_tf(
                dict(flat_pairs))

            # Record the structure only while `encoded_py_structure` is empty.
            if not encoded_py_structure:
                encoded_py_structure['full'] = nest.map_structure(
                    lambda _: None, nested_encoded)
                encoded_py_structure['flat_py'] = flat_py_part
            return flat_tf_part, updated_flat_state
Ejemplo n.º 3
0
def _make_decode_fn(encoder, encoded_x, decode_params, input_shapes):
    """Builds a decoding function together with its flat argument.

    The three inputs may be complex, nested dictionaries (see the `Encoder`
    class documentation for details on their structure). To keep the
    user-facing structure simple, this helper:

    1. Combines `encoded_x`, `decode_params` and `input_shapes` into a single
       dictionary and flattens it with the `nest` utility, joining the paths
       with '/'.
    2. Splits the flat dictionary into a part holding TensorFlow objects and a
       part holding everything else. Only the TensorFlow part is returned, and
       it is the expected input of the returned decoding function.

    The decoding function merges the values it is given with the
    non-TensorFlow values, repacks them into the nested dictionary and
    forwards the pieces to `encoder.decode`.

    Args:
      encoder: An `Encoder` object that was used to generate the other
        arguments.
      encoded_x: The `encoded_x` value returned by `encoder.encode`.
      decode_params: The `decode_params` value returned by
        `encoder.get_params`.
      input_shapes: The `input_shapes` value returned by `encoder.encode`.

    Returns:
      A tuple `(flat_tf_structure, decode_fn)`: the flat dictionary of
      TensorFlow values, and the function that decodes a dictionary with the
      same keys.
    """
    nested_structure = {
        _TENSORS: encoded_x,
        _PARAMS: decode_params,
        _SHAPES: input_shapes,
    }
    flat_pairs = nest.flatten_with_joined_string_paths(nested_structure,
                                                       separator='/')
    flat_py_part, flat_tf_part = py_utils.split_dict_py_tf(dict(flat_pairs))

    def decode_fn(encoded_structure):
        """Decoding function corresponding to the input arguments."""
        with tf.name_scope(None, 'simple_encoder_decode',
                           nest.flatten(encoded_structure)):
            # The caller must hand back exactly the keys that were exposed.
            provided_keys = set(encoded_structure.keys())
            expected_keys = set(flat_tf_part.keys())
            if provided_keys != expected_keys:
                raise ValueError(
                    'The provided encoded_structure has unexpected structure. Please '
                    'make sure the structure of the dictionary was not changed.'
                )
            merged = py_utils.merge_dicts(encoded_structure, flat_py_part)
            repacked = nest.pack_sequence_as(nested_structure,
                                             nest.flatten(merged))
            return encoder.decode(repacked[_TENSORS],
                                  repacked[_PARAMS],
                                  repacked[_SHAPES])

    return flat_tf_part, decode_fn
Ejemplo n.º 4
0
 def test_split_dict_py_tf_basic(self):
   """Checks `split_dict_py_tf` on a flat dictionary with one value of each kind."""
   tf_value = tf.constant(2.0)
   py_part, tf_part = py_utils.split_dict_py_tf({'py': 1.0, 'tf': tf_value})
   # The Python float and the TensorFlow constant end up in separate dicts.
   self.assertDictEqual({'py': 1.0}, py_part)
   self.assertDictEqual({'tf': tf_value}, tf_part)
Ejemplo n.º 5
0
        def get_params_fn(flat_state):
            """See the `get_params` method of this class.

            Args:
              flat_state: The flat state, checked against `flat_state_spec`
                and packed into the structure in
                `internal_structure['state']`.

            Returns:
              A tuple of three tuples holding the flattened TensorFlow parts
              of, in order, the encode parameters, the decode-before-sum
              parameters and the decode-after-sum parameters.
            """
            py_utils.assert_compatible(flat_state_spec, flat_state)
            state = tf.nest.pack_sequence_as(internal_structure['state'],
                                             flat_state)

            encode_params, decode_params = encoder.get_params(state)
            before_sum_params, after_sum_params = (
                core_encoder.split_params_by_commuting_structure(
                    decode_params, commuting_structure))

            # Run `encode` on a zeros tensor matching `tensorspec` purely to
            # obtain `input_shapes`; the part relevant to decode_after_sum is
            # folded into the params exposed to the user.
            zeros_input = tf.zeros(tensorspec.shape, tensorspec.dtype)
            _, _, input_shapes = encoder.encode(zeros_input, encode_params)
            _, shapes_after_sum = (
                core_encoder.split_shapes_by_commuting_structure(
                    input_shapes, commuting_structure))
            after_sum_with_shapes = {
                _PARAMS: after_sum_params,
                _SHAPES: shapes_after_sum
            }

            encode_py, encode_tf = py_utils.split_dict_py_tf(encode_params)
            before_sum_py, before_sum_tf = py_utils.split_dict_py_tf(
                before_sum_params)
            after_sum_py, after_sum_tf = py_utils.split_dict_py_tf(
                after_sum_with_shapes)

            # Register the TensorFlow parts first, then the Python parts.
            for key, tf_part in (
                    ('encode_params', encode_tf),
                    ('decode_before_sum_params', before_sum_tf),
                    ('decode_after_sum_params', after_sum_tf)):
                _add_to_structure(key, tf_part)
            for key, py_part in (
                    ('encode_params', encode_py),
                    ('decode_before_sum_params', before_sum_py),
                    ('decode_after_sum_params', after_sum_py)):
                _add_to_py_values(key, py_part)

            flat_encode = tuple(tf.nest.flatten(encode_tf))
            flat_before_sum = tuple(tf.nest.flatten(before_sum_tf))
            flat_after_sum = tuple(tf.nest.flatten(after_sum_tf))
            return flat_encode, flat_before_sum, flat_after_sum
Ejemplo n.º 6
0
  def test_split_merge_identity(self, **test_dict):
    """Checks that splitting and then merging gives back the input.

    The output of `split_dict_py_tf` is passed straight to `merge_dicts`, and
    the merged dictionary must equal the original one.

    Args:
      **test_dict: A dictionary to be used for the test.
    """
    split_parts = py_utils.split_dict_py_tf(test_dict)
    merged = py_utils.merge_dicts(*split_parts)
    self.assertDictEqual(merged, test_dict)
Ejemplo n.º 7
0
 def test_split_dict_py_tf_nested(self):
   """Checks `split_dict_py_tf` on a dictionary containing a nested dictionary."""
   inner_const = tf.constant(1.0)
   outer_const = tf.constant(2.0)
   mixed = {
       'nested': {'a': 1.0, 'b': inner_const},
       'py': 'string',
       'tf': outer_const,
   }
   py_part, tf_part = py_utils.split_dict_py_tf(mixed)
   # The nested dictionary is split too: 'a' stays with the Python values and
   # 'b' goes with the TensorFlow values.
   self.assertDictEqual({'nested': {'a': 1.0}, 'py': 'string'}, py_part)
   self.assertDictEqual(
       {'nested': {'b': inner_const}, 'tf': outer_const}, tf_part)
Ejemplo n.º 8
0
 def test_split_dict_py_tf_empty(self):
   """Checks that splitting an empty dictionary yields two empty dictionaries."""
   py_part, tf_part = py_utils.split_dict_py_tf({})
   self.assertDictEqual({}, py_part)
   self.assertDictEqual({}, tf_part)
Ejemplo n.º 9
0
 def test_split_dict_py_tf_raises(self, bad_input):
   """Checks that `split_dict_py_tf` raises `TypeError` for `bad_input`."""
   self.assertRaises(TypeError, py_utils.split_dict_py_tf, bad_input)