def decode_before_sum_fn(encoded_structure, params):
  """Runs the part of decoding that happens before summation.

  See the `decode_before_sum` method of this class.

  Both inputs are first validated against their specs. Each is then rebuilt
  into its full nested form by packing it into the recorded internal
  structure and merging back the Python (non-TF) values that were stripped
  out earlier.

  Args:
    encoded_structure: Flat encoded value, checked against
      `encoded_structure_spec`.
    params: Flat decode-before-sum parameters, checked against
      `decode_before_sum_params_spec`.

  Returns:
    If the partially decoded result is a dict, a flat dict keyed by joined
    string paths. Otherwise the partially decoded result, unchanged.
  """
  py_utils.assert_compatible(encoded_structure_spec, encoded_structure)
  py_utils.assert_compatible(decode_before_sum_params_spec, params)

  # Restore the nested layouts and re-attach the Python-only values.
  full_encoded = py_utils.merge_dicts(
      tf.nest.pack_sequence_as(internal_structure['encoded_structure'],
                               tf.nest.flatten(encoded_structure)),
      internal_py_values['encoded_structure'])
  full_params = py_utils.merge_dicts(
      tf.nest.pack_sequence_as(internal_structure['decode_before_sum_params'],
                               params),
      internal_py_values['decode_before_sum_params'])

  partial = encoder.decode_before_sum(full_encoded[_TENSORS], full_params,
                                      full_encoded[_SHAPES])
  # Record the layout so a later stage can repack the flattened output.
  _add_to_structure('part_decoded_structure', partial)

  if not isinstance(partial, dict):
    return partial
  return dict(py_utils.flatten_with_joined_string_paths(partial))
def encode_fn(x, params):
  """Encodes `x` using the wrapped encoder.

  See the `encode` method of this class.

  Args:
    x: The value to encode. It must be compatible with `tensorspec`.
    params: Flat encode parameters, checked against `encode_params_spec`.

  Returns:
    A 2-tuple. The first element is a flat dict (keyed by joined string paths)
    of the TF part of the encoded structure. The second element is a tuple of
    the flattened state update tensors.

  Raises:
    ValueError: If `x` is not compatible with `tensorspec`.
  """
  if not tensorspec.is_compatible_with(x):
    raise ValueError(
        'The provided x is not compatible with the expected tensorspec.')
  py_utils.assert_compatible(encode_params_spec, params)

  full_params = py_utils.merge_dicts(
      tf.nest.pack_sequence_as(internal_structure['encode_params'], params),
      internal_py_values['encode_params'])
  encoded_x, state_update_tensors, input_shapes = encoder.encode(
      x, full_params)

  # Only the first part of the split (the shapes before summation) is kept
  # alongside the encoded tensors.
  shapes_before_sum, _ = core_encoder.split_shapes_by_commuting_structure(
      input_shapes, commuting_structure)

  py_part, tf_part = py_utils.split_dict_py_tf(
      {_TENSORS: encoded_x, _SHAPES: shapes_before_sum})
  # Record structure and Python-only values so the decode side can rebuild
  # the full nested value from the flat TF outputs.
  _add_to_structure('encoded_structure', tf_part)
  _add_to_structure('state_update_tensors', state_update_tensors)
  _add_to_py_values('encoded_structure', py_part)

  flat_encoded = dict(py_utils.flatten_with_joined_string_paths(tf_part))
  flat_state_updates = tuple(tf.nest.flatten(state_update_tensors))
  return flat_encoded, flat_state_updates
def decode_fn(encoded_structure):
  """Decodes a flat encoded structure with the wrapped encoder.

  The flat input is merged with the stored Python values from
  `encoded_py_structure['flat_py']`. The result is repacked into the layout
  of `encoded_py_structure['full']`, and the tensors, params and shapes
  entries are then passed to `encoder.decode`.

  Args:
    encoded_structure: Flat dict produced by the matching encode step.

  Returns:
    Whatever `encoder.decode` returns.
  """
  merged = py_utils.merge_dicts(encoded_structure,
                                encoded_py_structure['flat_py'])
  restored = nest.pack_sequence_as(encoded_py_structure['full'],
                                   nest.flatten(merged))
  tensors = restored[_TENSORS]
  decode_params = restored[_PARAMS]
  shapes = restored[_SHAPES]
  return encoder.decode(tensors, decode_params, shapes)
def test_split_merge_identity(self, **test_dict):
  """Tests that splitting and then merging a dictionary is an identity.

  Splitting `test_dict` with `split_dict_py_tf` and recombining the two
  resulting parts with `merge_dicts` must reproduce the original dictionary.

  Args:
    **test_dict: A dictionary to be used for the test.
  """
  py_part, tf_part = py_utils.split_dict_py_tf(test_dict)
  self.assertDictEqual(py_utils.merge_dicts(py_part, tf_part), test_dict)
def decode_fn(encoded_structure):
  """Decoding function corresponding to the input arguments.

  The flat `encoded_structure` must have exactly the keys of
  `flat_encoded_structure_tf`. It is merged with the stored Python values,
  repacked into `full_encoded_structure`, and decoded by `encoder.decode`.

  Args:
    encoded_structure: Flat dict of encoded values.

  Returns:
    Whatever `encoder.decode` returns.

  Raises:
    ValueError: If the keys of `encoded_structure` differ from the expected
      flat keys.
  """
  with tf.name_scope(None, 'simple_encoder_decode',
                     nest.flatten(encoded_structure)):
    provided_keys = set(encoded_structure.keys())
    expected_keys = set(flat_encoded_structure_tf.keys())
    if provided_keys != expected_keys:
      raise ValueError(
          'The provided encoded_structure has unexpected structure. Please '
          'make sure the structure of the dictionary was not changed.')

    merged = py_utils.merge_dicts(encoded_structure,
                                  flat_encoded_structure_py)
    restored = nest.pack_sequence_as(full_encoded_structure,
                                     nest.flatten(merged))
    return encoder.decode(restored[_TENSORS], restored[_PARAMS],
                          restored[_SHAPES])
def decode_after_sum_fn(part_decoded_structure, params, num_summands):
  """Finishes decoding after the partially decoded values have been summed.

  See the `decode_after_sum` method of this class.

  Args:
    part_decoded_structure: Flat partially decoded value, checked against
      `part_decoded_structure_spec`.
    params: Flat decode-after-sum parameters, checked against
      `decode_after_sum_params_spec`.
    num_summands: Passed through unchanged to `encoder.decode_after_sum`.

  Returns:
    Whatever `encoder.decode_after_sum` returns.
  """
  py_utils.assert_compatible(part_decoded_structure_spec,
                             part_decoded_structure)
  py_utils.assert_compatible(decode_after_sum_params_spec, params)

  # The partially decoded value has no Python-only part, so it is only
  # repacked. The params also get their stored Python values merged back in.
  restored_part = tf.nest.pack_sequence_as(
      internal_structure['part_decoded_structure'],
      tf.nest.flatten(part_decoded_structure))
  full_params = py_utils.merge_dicts(
      tf.nest.pack_sequence_as(internal_structure['decode_after_sum_params'],
                               params),
      internal_py_values['decode_after_sum_params'])

  return encoder.decode_after_sum(restored_part, full_params[_PARAMS],
                                  num_summands, full_params[_SHAPES])
def test_merge_dicts_raises(self, bad_dict1, bad_dict2, error_type):
  """Tests that `merge_dicts` raises appropriate error.

  The error must be raised regardless of argument order, so both orders
  are checked.
  """
  for first, second in ((bad_dict1, bad_dict2), (bad_dict2, bad_dict1)):
    with self.assertRaises(error_type):
      py_utils.merge_dicts(first, second)
def test_merge_dicts(self, dict1, dict2, expected_dict):
  """Tests that `merge_dicts` works as expected.

  The merge must produce `expected_dict` in both argument orders.
  """
  for first, second in ((dict1, dict2), (dict2, dict1)):
    self.assertDictEqual(expected_dict, py_utils.merge_dicts(first, second))