示例#1
0
        def decode_before_sum_fn(encoded_structure, params):
            """Runs the part of decoding that happens before summation.

            See the `decode_before_sum` method of this class.

            Args:
              encoded_structure: Value checked against
                `encoded_structure_spec`; it is flattened and repacked into
                the layout stored in `internal_structure`.
              params: Value checked against `decode_before_sum_params_spec`;
                it is passed to `tf.nest.pack_sequence_as` as the flat
                sequence.

            Returns:
              The output of `encoder.decode_before_sum`. If that output is a
              `dict`, a new `dict` built from
              `py_utils.flatten_with_joined_string_paths` is returned instead.
            """
            py_utils.assert_compatible(encoded_structure_spec,
                                       encoded_structure)
            py_utils.assert_compatible(decode_before_sum_params_spec, params)

            # Repack the inputs into the recorded structures, then merge in the
            # values kept in `internal_py_values`.
            repacked_encoded = tf.nest.pack_sequence_as(
                internal_structure['encoded_structure'],
                tf.nest.flatten(encoded_structure))
            full_encoded = py_utils.merge_dicts(
                repacked_encoded, internal_py_values['encoded_structure'])

            repacked_params = tf.nest.pack_sequence_as(
                internal_structure['decode_before_sum_params'], params)
            full_params = py_utils.merge_dicts(
                repacked_params,
                internal_py_values['decode_before_sum_params'])

            tensors_to_decode = full_encoded[_TENSORS]
            shapes_before_sum = full_encoded[_SHAPES]
            part_decoded_structure = encoder.decode_before_sum(
                tensors_to_decode, full_params, shapes_before_sum)

            # Record the result so it is available outside of this function.
            _add_to_structure('part_decoded_structure', part_decoded_structure)

            if not isinstance(part_decoded_structure, dict):
                return part_decoded_structure
            return dict(
                py_utils.flatten_with_joined_string_paths(
                    part_decoded_structure))
示例#2
0
        def encode_fn(x, params):
            """Encodes `x` and returns the flattened results.

            See the `encode` method of this class.

            Args:
              x: Value to encode; must be compatible with `tensorspec`.
              params: Value checked against `encode_params_spec`; it is passed
                to `tf.nest.pack_sequence_as` as the flat sequence.

            Returns:
              A 2-tuple: a `dict` built from
              `py_utils.flatten_with_joined_string_paths` applied to the second
              output of `py_utils.split_dict_py_tf`, and a `tuple` of the
              flattened `state_update_tensors`.

            Raises:
              ValueError: If `x` is not compatible with `tensorspec`.
            """
            if not tensorspec.is_compatible_with(x):
                raise ValueError(
                    'The provided x is not compatible with the expected tensorspec.'
                )
            py_utils.assert_compatible(encode_params_spec, params)

            # Repack the params into the recorded structure, then merge in the
            # values kept in `internal_py_values`.
            repacked_params = tf.nest.pack_sequence_as(
                internal_structure['encode_params'], params)
            full_params = py_utils.merge_dicts(
                repacked_params, internal_py_values['encode_params'])

            encoded_x, state_update_tensors, input_shapes = encoder.encode(
                x, full_params)

            # Only the first element of the split is used here.
            split_shapes = core_encoder.split_shapes_by_commuting_structure(
                input_shapes, commuting_structure)
            input_shapes_before_sum, _ = split_shapes

            encoded_structure = {
                _TENSORS: encoded_x,
                _SHAPES: input_shapes_before_sum,
            }
            py_part, tf_part = py_utils.split_dict_py_tf(encoded_structure)

            # Record the pieces so they are available outside of this function.
            _add_to_structure('encoded_structure', tf_part)
            _add_to_structure('state_update_tensors', state_update_tensors)
            _add_to_py_values('encoded_structure', py_part)

            flat_encoded = dict(
                py_utils.flatten_with_joined_string_paths(tf_part))
            flat_state_updates = tuple(tf.nest.flatten(state_update_tensors))
            return (flat_encoded, flat_state_updates)
 def decode_fn(encoded_structure):
     """Decodes `encoded_structure` using `encoder.decode`.

     The input is merged with `encoded_py_structure['flat_py']`, repacked into
     the layout of `encoded_py_structure['full']`, and its `_TENSORS`,
     `_PARAMS` and `_SHAPES` entries are handed to `encoder.decode`.

     Args:
       encoded_structure: A dictionary to be decoded.

     Returns:
       The value returned by `encoder.decode`.
     """
     merged_flat = py_utils.merge_dicts(
         encoded_structure, encoded_py_structure['flat_py'])
     full_structure = nest.pack_sequence_as(
         encoded_py_structure['full'], nest.flatten(merged_flat))

     tensors = full_structure[_TENSORS]
     decode_params = full_structure[_PARAMS]
     shapes = full_structure[_SHAPES]
     return encoder.decode(tensors, decode_params, shapes)
示例#4
0
  def test_split_merge_identity(self, **test_dict):
    """Tests that splitting and then merging a dictionary is an identity.

    The dictionary is first split with `split_dict_py_tf`, the resulting parts
    are passed to `merge_dicts`, and the outcome is compared to the input.

    Args:
      **test_dict: A dictionary to be used for the test.
    """
    split_parts = py_utils.split_dict_py_tf(test_dict)
    round_tripped = py_utils.merge_dicts(*split_parts)
    self.assertDictEqual(round_tripped, test_dict)
示例#5
0
 def decode_fn(encoded_structure):
     """Decodes `encoded_structure` using `encoder.decode`.

     Args:
       encoded_structure: A dictionary whose keys must equal the keys of
         `flat_encoded_structure_tf`.

     Returns:
       The value returned by `encoder.decode`.

     Raises:
       ValueError: If the keys of `encoded_structure` differ from the keys of
         `flat_encoded_structure_tf`.
     """
     with tf.name_scope(None, 'simple_encoder_decode',
                        nest.flatten(encoded_structure)):
         provided_keys = set(encoded_structure.keys())
         expected_keys = set(flat_encoded_structure_tf.keys())
         if provided_keys != expected_keys:
             raise ValueError(
                 'The provided encoded_structure has unexpected structure. Please '
                 'make sure the structure of the dictionary was not changed.'
             )

         # Merge in `flat_encoded_structure_py` and restore the full layout.
         merged_flat = py_utils.merge_dicts(
             encoded_structure, flat_encoded_structure_py)
         full_structure = nest.pack_sequence_as(
             full_encoded_structure, nest.flatten(merged_flat))

         tensors = full_structure[_TENSORS]
         decode_params = full_structure[_PARAMS]
         shapes = full_structure[_SHAPES]
         return encoder.decode(tensors, decode_params, shapes)
示例#6
0
    def decode_after_sum_fn(part_decoded_structure, params, num_summands):
      """Runs the part of decoding that happens after summation.

      See the `decode_after_sum` method of this class.

      Args:
        part_decoded_structure: Value checked against
          `part_decoded_structure_spec`; it is flattened and repacked into the
          layout stored in `internal_structure`.
        params: Value checked against `decode_after_sum_params_spec`; it is
          passed to `tf.nest.pack_sequence_as` as the flat sequence.
        num_summands: Forwarded unchanged to `encoder.decode_after_sum`.

      Returns:
        The value returned by `encoder.decode_after_sum`.
      """
      py_utils.assert_compatible(part_decoded_structure_spec,
                                 part_decoded_structure)
      py_utils.assert_compatible(decode_after_sum_params_spec, params)

      repacked_part_decoded = tf.nest.pack_sequence_as(
          internal_structure['part_decoded_structure'],
          tf.nest.flatten(part_decoded_structure))

      # Repack the params into the recorded structure, then merge in the values
      # kept in `internal_py_values`.
      repacked_params = tf.nest.pack_sequence_as(
          internal_structure['decode_after_sum_params'], params)
      full_params = py_utils.merge_dicts(
          repacked_params, internal_py_values['decode_after_sum_params'])

      decode_params = full_params[_PARAMS]
      input_shapes = full_params[_SHAPES]
      return encoder.decode_after_sum(repacked_part_decoded, decode_params,
                                      num_summands, input_shapes)
示例#7
0
 def test_merge_dicts_raises(self, bad_dict1, bad_dict2, error_type):
   """Tests that `merge_dicts` raises `error_type` for either argument order."""
   argument_orders = ((bad_dict1, bad_dict2), (bad_dict2, bad_dict1))
   for first, second in argument_orders:
     with self.assertRaises(error_type):
       py_utils.merge_dicts(first, second)
示例#8
0
 def test_merge_dicts(self, dict1, dict2, expected_dict):
   """Tests that `merge_dicts` gives `expected_dict` for either argument order."""
   argument_orders = ((dict1, dict2), (dict2, dict1))
   for first, second in argument_orders:
     merged = py_utils.merge_dicts(first, second)
     self.assertDictEqual(expected_dict, merged)