def encode_fn(x, params):
  """See the `encode` method of this class.

  Args:
    x: A tensor to encode; must be compatible with `tensorspec`.
    params: Flat encode parameters, compatible with `encode_params_spec`.

  Returns:
    A tuple `(encoded_tensors, state_update_tensors)`: a dict of the
    TensorFlow-valued encoded structure keyed by joined string paths, and a
    flat tuple of the state update tensors.

  Raises:
    ValueError: If `x` is not compatible with `tensorspec`.
  """
  if not tensorspec.is_compatible_with(x):
    raise ValueError(
        'The provided x is not compatible with the expected tensorspec.'
    )
  py_utils.assert_compatible(encode_params_spec, params)
  # Re-attach the flat `params` to the remembered nested structure, then fold
  # the non-TensorFlow (python) parameter values back in.
  params = py_utils.merge_dicts(
      tf.nest.pack_sequence_as(internal_structure['encode_params'], params),
      internal_py_values['encode_params'])
  encoded_x, state_update_tensors, input_shapes = encoder.encode(x, params)
  # Keep only the shapes relevant before summation; the after-sum shapes are
  # folded into the decode-after-sum params in `get_params_fn`.
  input_shapes_before_sum, _ = (
      core_encoder.split_shapes_by_commuting_structure(
          input_shapes, commuting_structure))
  encoded_structure = {
      _TENSORS: encoded_x,
      _SHAPES: input_shapes_before_sum
  }
  # Split python (non-tensor) values from tensors and record both halves so
  # later calls can reconstruct the full structure from flat tensors.
  encoded_structure_py, encoded_structure_tf = py_utils.split_dict_py_tf(
      encoded_structure)
  _add_to_structure('encoded_structure', encoded_structure_tf)
  _add_to_structure('state_update_tensors', state_update_tensors)
  _add_to_py_values('encoded_structure', encoded_structure_py)
  return (dict(
      py_utils.flatten_with_joined_string_paths(encoded_structure_tf)),
          tuple(tf.nest.flatten(state_update_tensors)))
def decode_before_sum_fn(encoded_structure, params):
  """See the `decode_before_sum` method of this class."""
  py_utils.assert_compatible(encoded_structure_spec, encoded_structure)
  py_utils.assert_compatible(decode_before_sum_params_spec, params)
  # Reconstruct the full encoded structure: repack the flat tensors into the
  # remembered nesting, then fold the recorded python values back in.
  repacked_structure = tf.nest.pack_sequence_as(
      internal_structure['encoded_structure'],
      tf.nest.flatten(encoded_structure))
  full_structure = py_utils.merge_dicts(
      repacked_structure, internal_py_values['encoded_structure'])
  # Same reconstruction for the decode-before-sum parameters.
  repacked_params = tf.nest.pack_sequence_as(
      internal_structure['decode_before_sum_params'], params)
  merged_params = py_utils.merge_dicts(
      repacked_params, internal_py_values['decode_before_sum_params'])
  partially_decoded = encoder.decode_before_sum(
      full_structure[_TENSORS], merged_params, full_structure[_SHAPES])
  _add_to_structure('part_decoded_structure', partially_decoded)
  # A dict result is flattened to string-path keys; anything else is
  # returned as-is.
  if not isinstance(partially_decoded, dict):
    return partially_decoded
  return dict(
      py_utils.flatten_with_joined_string_paths(partially_decoded))
def update_state_fn(flat_state, state_update_tensors):
  """See the `update_state` method of this class."""
  py_utils.assert_compatible(flat_state_spec, flat_state)
  # Repack both flat inputs into their remembered nested structures before
  # delegating to the wrapped encoder.
  packed_state = tf.nest.pack_sequence_as(
      internal_structure['state'], flat_state)
  packed_updates = tf.nest.pack_sequence_as(
      internal_structure['state_update_tensors'], state_update_tensors)
  new_state = encoder.update_state(packed_state, packed_updates)
  return tuple(tf.nest.flatten(new_state))
def test_assert_compatible(self):
  """Compatible nested specs and values must not raise."""
  nested_spec = [
      tf.TensorSpec((2,), tf.int32),
      (tf.TensorSpec((2, 3), tf.float64), tf.TensorSpec((), tf.float32)),
  ]
  nested_value = [
      tf.zeros((2,), tf.int32),
      (tf.zeros((2, 3), tf.float64), tf.zeros((), tf.float32)),
  ]
  # A structure of values matching the structure of specs passes.
  py_utils.assert_compatible(nested_spec, nested_value)
  # A spec with unknown shape accepts a concrete shape of the same dtype.
  py_utils.assert_compatible(
      tf.TensorSpec(None, tf.int32), tf.zeros((2,), tf.int32))
def decode_after_sum_fn(part_decoded_structure, params, num_summands):
  """See the `decode_after_sum` method of this class."""
  py_utils.assert_compatible(part_decoded_structure_spec,
                             part_decoded_structure)
  py_utils.assert_compatible(decode_after_sum_params_spec, params)
  # Repack the partially decoded values into the remembered nesting.
  repacked_structure = tf.nest.pack_sequence_as(
      internal_structure['part_decoded_structure'],
      tf.nest.flatten(part_decoded_structure))
  # Reconstruct the full params dict, which bundles the actual decode params
  # together with the after-sum input shapes.
  merged_params = py_utils.merge_dicts(
      tf.nest.pack_sequence_as(
          internal_structure['decode_after_sum_params'], params),
      internal_py_values['decode_after_sum_params'])
  return encoder.decode_after_sum(repacked_structure,
                                  merged_params[_PARAMS], num_summands,
                                  merged_params[_SHAPES])
def get_params_fn(flat_state):
  """See the `get_params` method of this class.

  Args:
    flat_state: Flat state tensors, compatible with `flat_state_spec`.

  Returns:
    A tuple of three flat tuples of tensors: the encode params, the
    decode-before-sum params and the decode-after-sum params (each in
    `tf.nest.flatten` order of its TensorFlow-valued part).
  """
  py_utils.assert_compatible(flat_state_spec, flat_state)
  state = tf.nest.pack_sequence_as(internal_structure['state'], flat_state)
  encode_params, decode_params = encoder.get_params(state)
  # Partition the decode params into the part applied before summation and
  # the part applied after, per the commuting structure.
  decode_before_sum_params, decode_after_sum_params = (
      core_encoder.split_params_by_commuting_structure(
          decode_params, commuting_structure))
  # Get the portion of input_shapes that will be relevant in the
  # decode_after_sum method and fold it into the params exposed to user.
  _, _, input_shapes = encoder.encode(
      tf.zeros(tensorspec.shape, tensorspec.dtype), encode_params)
  _, input_shapes_after_sum = (
      core_encoder.split_shapes_by_commuting_structure(
          input_shapes, commuting_structure))
  decode_after_sum_params = {
      _PARAMS: decode_after_sum_params,
      _SHAPES: input_shapes_after_sum
  }
  # Split python (non-tensor) values from tensors for each param group; the
  # tensor halves are returned flat, and both halves are recorded so the
  # full dicts can be reconstructed later.
  encode_params_py, encode_params_tf = py_utils.split_dict_py_tf(
      encode_params)
  decode_before_sum_params_py, decode_before_sum_params_tf = (
      py_utils.split_dict_py_tf(decode_before_sum_params))
  decode_after_sum_params_py, decode_after_sum_params_tf = (
      py_utils.split_dict_py_tf(decode_after_sum_params))
  _add_to_structure('encode_params', encode_params_tf)
  _add_to_structure('decode_before_sum_params', decode_before_sum_params_tf)
  _add_to_structure('decode_after_sum_params', decode_after_sum_params_tf)
  _add_to_py_values('encode_params', encode_params_py)
  _add_to_py_values('decode_before_sum_params', decode_before_sum_params_py)
  _add_to_py_values('decode_after_sum_params', decode_after_sum_params_py)
  return (tuple(tf.nest.flatten(encode_params_tf)),
          tuple(tf.nest.flatten(decode_before_sum_params_tf)),
          tuple(tf.nest.flatten(decode_after_sum_params_tf)))
def test_assert_compatible_raises_type_error(self, not_a_spec):
  """A first argument that is not a TensorSpec structure raises TypeError."""
  with self.assertRaises(TypeError):  # pylint: disable=g-error-prone-assert-raises
    py_utils.assert_compatible(not_a_spec, None)
def test_assert_compatible_raises_incompatible_shapes(self, dtype):
  """A value whose shape differs from the spec's must be rejected."""
  expected_spec = tf.TensorSpec((2,), dtype)
  mismatched_value = tf.zeros((3,), dtype)
  with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
    py_utils.assert_compatible(expected_spec, mismatched_value)
def test_assert_compatible_raises_incompatible_dtype(self, *shape):
  """A value whose dtype differs from the spec's must be rejected."""
  float32_spec = tf.TensorSpec(shape, tf.float32)
  float64_value = tf.zeros(shape, tf.float64)
  with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
    py_utils.assert_compatible(float32_spec, float64_value)