def test_adaptive_stage(self):
    """Checks that two stacked adaptive stages track iterations independently.

    A PlusOneOverN stage is attached as the child of another PlusOneOverN
    stage at the `PN_VALS` key. On the i-th round both stages should report
    an iteration state of `i`, neither should emit state-update tensors, the
    round trip should be lossless, and the doubly encoded value should equal
    `x + 2 / i`.
    """
    child_stage = test_utils.PlusOneOverNEncodingStage()
    parent_stage = test_utils.PlusOneOverNEncodingStage()
    encoder = core_encoder.EncoderComposer(child_stage).add_parent(
        parent_stage, PN_VALS).make()

    # No stage in this tree produces state-update tensors.
    expected_updates = {
        TENSORS: {},
        CHILDREN: {PN_VALS: {TENSORS: {}, CHILDREN: {}}},
    }

    x = tf.constant(1.0)
    current_state = encoder.initial_state()
    for step in range(1, 5):
        prior_state = current_state
        encode_params, decode_params = encoder.get_params(prior_state)
        encoded_x, update_tensors, input_shapes = encoder.encode(
            x, encode_params)
        decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
        current_state = encoder.update_state(prior_state, update_tensors)
        data = self.evaluate(
            test_utils.TestData(x, encoded_x, decoded_x, prior_state,
                                update_tensors, current_state))

        # Both the parent and the child keep their own iteration counter.
        expected_prior_state = {
            STATE: {PN_ITER_STATE: step},
            CHILDREN: {
                PN_VALS: {STATE: {PN_ITER_STATE: step}, CHILDREN: {}},
            },
        }
        self.assertAllClose(data.x, data.decoded_x)
        self.assertAllEqual(expected_prior_state, data.initial_state)
        self.assertDictEqual(expected_updates, data.state_update_tensors)
        # Each of the two stages adds 1 / step.
        self.assertAllClose(data.x + 2 / step,
                            data.encoded_x[PN_VALS][PN_VALS])
def test_composite_encoder(self):
    """Runs a composite encoding tree through `SimpleEncoder` for 9 rounds.

    Tree: SignIntFloat splits the input; signs go through TimesTwo, ints
    through PlusOne, and floats through TimesTwo followed by PlusOneOverN.
    The round trip must be lossless and, on round i, the PlusOneOverN leaf
    should hold `0.4 + 1 / i`.
    """
    x = tf.constant(1.2)
    composer = core_encoder.EncoderComposer(
        test_utils.SignIntFloatEncodingStage())
    composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
    composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
    float_branch = composer.add_child(test_utils.TimesTwoEncodingStage(),
                                      SIF_FLOATS)
    float_branch.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
    encoder = simple_encoder.SimpleEncoder(composer.make(),
                                           tf.TensorSpec.from_tensor(x))

    state = encoder.initial_state()
    iteration = _make_iteration_function(encoder)
    for round_idx in range(1, 10):
        x, encoded_x, decoded_x, state = self.evaluate(iteration(x, state))
        self.assertAllClose(x, decoded_x)
        self.assertAllClose(
            2.0, _encoded_x_field(encoded_x, [TENSORS, SIF_SIGNS, T2_VALS]))
        self.assertAllClose(
            2.0, _encoded_x_field(encoded_x, [TENSORS, SIF_INTS, P1_VALS]))
        self.assertAllClose(
            0.4 + 1 / round_idx,
            _encoded_x_field(encoded_x,
                             [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS]))
def test_composite_encoder(self):
    """Drives a composite encoding tree through `StatefulSimpleEncoder`.

    Tree: SignIntFloat splits the input; signs go through TimesTwo, ints
    through PlusOne, and floats through TimesTwo followed by PlusOneOverN.
    The encode/decode graph is built once and evaluated repeatedly; on the
    i-th evaluation the PlusOneOverN leaf should hold `0.4 + 1 / i`.
    """
    composer = core_encoder.EncoderComposer(
        test_utils.SignIntFloatEncodingStage())
    composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
    composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
    float_branch = composer.add_child(test_utils.TimesTwoEncodingStage(),
                                      SIF_FLOATS)
    float_branch.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
    encoder = simple_encoder.StatefulSimpleEncoder(composer.make())

    x = tf.constant(1.2)
    encoder.initialize()
    encoded_x, decode_fn = encoder.encode(x)
    decoded_x = decode_fn(encoded_x)
    self.evaluate(tf.global_variables_initializer())

    for round_idx in range(1, 10):
        x_val, encoded_val, decoded_val = self.evaluate(
            [x, encoded_x, decoded_x])
        self.assertAllClose(x_val, decoded_val)
        self.assertAllClose(
            2.0, _encoded_x_field(encoded_val, [TENSORS, SIF_SIGNS, T2_VALS]))
        self.assertAllClose(
            2.0, _encoded_x_field(encoded_val, [TENSORS, SIF_INTS, P1_VALS]))
        self.assertAllClose(
            0.4 + 1 / round_idx,
            _encoded_x_field(encoded_val,
                             [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS]))
def test_not_fully_defined_shape_raises(self):
    """`GatherEncoder.from_encoder` rejects a spec with an unknown dimension."""
    stage = test_utils.PlusOneOverNEncodingStage()
    encoder = core_encoder.EncoderComposer(stage).make()
    partial_spec = tf.TensorSpec((None,), tf.float32)
    with self.assertRaisesRegex(TypeError, 'fully defined'):
        gather_encoder.GatherEncoder.from_encoder(encoder, partial_spec)
def test_input_tensorspec(self):
    """`SimpleEncoder.input_tensorspec` is compatible with its source tensor."""
    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    inner = core_encoder.EncoderComposer(
        test_utils.PlusOneOverNEncodingStage()).make()
    spec = tf.TensorSpec.from_tensor(x)
    encoder = simple_encoder.SimpleEncoder(inner, spec)
    self.assertTrue(encoder.input_tensorspec.is_compatible_with(x))
def test_full_commutativity_with_sum(self):
    """Checks `fully_commutes_with_sum` on three encoding trees.

    A lone TimesTwo stage and a TimesTwo -> TimesTwo chain are expected to
    fully commute with sum; the composite SignIntFloat tree is not.
    """
    spec = tf.TensorSpec((2,), tf.float32)

    single_stage = core_encoder.EncoderComposer(
        test_utils.TimesTwoEncodingStage()).make()
    single_gather = gather_encoder.GatherEncoder.from_encoder(
        single_stage, spec)
    self.assertTrue(single_gather.fully_commutes_with_sum)

    chained_stages = core_encoder.EncoderComposer(
        test_utils.TimesTwoEncodingStage()).add_parent(
            test_utils.TimesTwoEncodingStage(), T2_VALS).make()
    chained_gather = gather_encoder.GatherEncoder.from_encoder(
        chained_stages, spec)
    self.assertTrue(chained_gather.fully_commutes_with_sum)

    composer = core_encoder.EncoderComposer(
        test_utils.SignIntFloatEncodingStage())
    composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
    composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
    float_branch = composer.add_child(test_utils.TimesTwoEncodingStage(),
                                      SIF_FLOATS)
    float_branch.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
    composite_gather = gather_encoder.GatherEncoder.from_encoder(
        composer.make(), spec)
    self.assertFalse(composite_gather.fully_commutes_with_sum)
def test_not_a_tensorspec_raises(self, not_a_tensorspec):
    """`GatherEncoder.from_encoder` rejects a non-`TensorSpec` argument."""
    stage = test_utils.PlusOneOverNEncodingStage()
    encoder = core_encoder.EncoderComposer(stage).make()
    with self.assertRaisesRegex(TypeError, 'TensorSpec'):
        gather_encoder.GatherEncoder.from_encoder(encoder, not_a_tensorspec)
def test_uninitialized_encode_raises(self):
    """Calling `encode` before `initialize` must raise a RuntimeError."""
    inner = core_encoder.EncoderComposer(
        test_utils.PlusOneOverNEncodingStage()).make()
    encoder = simple_encoder.StatefulSimpleEncoder(inner)
    x = tf.constant(1.0)
    with self.assertRaisesRegex(RuntimeError, 'not been initialized'):
        encoder.encode(x)
def test_input_tensorspec(self):
    """`GatherEncoder.input_tensorspec` is compatible with its source tensor."""
    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    inner = core_encoder.EncoderComposer(
        test_utils.PlusOneOverNEncodingStage()).make()
    encoder = gather_encoder.GatherEncoder.from_encoder(
        inner, tf.TensorSpec.from_tensor(x))
    reported_spec = encoder.input_tensorspec
    self.assertTrue(reported_spec.is_compatible_with(x))
def test_multiple_initialize_raises(self):
    """A second call to `initialize` must raise a RuntimeError."""
    inner = core_encoder.EncoderComposer(
        test_utils.PlusOneOverNEncodingStage()).make()
    encoder = simple_encoder.StatefulSimpleEncoder(inner)
    encoder.initialize()  # First call is expected to succeed.
    with self.assertRaisesRegex(RuntimeError, 'already initialized'):
        encoder.initialize()
def test_multiple_encode_raises(self):
    """A second call to `encode` on a stateful encoder must raise."""
    inner = core_encoder.EncoderComposer(
        test_utils.PlusOneOverNEncodingStage()).make()
    encoder = simple_encoder.StatefulSimpleEncoder(inner)
    encoder.initialize()
    x = tf.constant(1.0)
    encoder.encode(x)  # First call is expected to succeed.
    with self.assertRaisesRegex(RuntimeError, 'only once'):
        encoder.encode(x)
def test_modifying_encoded_x_raises(self):
    """The returned `decode_fn` rejects an `encoded_x` with altered keys."""
    # NOTE(review): this test constructs `SimpleEncoder` with a single argument
    # and expects `encode` to return `(encoded_x, decode_fn)`. Confirm that this
    # matches the `simple_encoder` API version this test targets.
    inner = core_encoder.EncoderComposer(
        test_utils.PlusOneOverNEncodingStage()).make()
    encoder = simple_encoder.SimpleEncoder(inner)
    x = tf.constant(1.0)
    encoded_x, decode_fn = encoder.encode(x)

    # An unexpected extra key must be rejected...
    encoded_x['__NOT_EXPECTED_KEY__'] = None
    with self.assertRaises(ValueError):
        decode_fn(encoded_x)
    # ...and so must an empty structure.
    with self.assertRaises(ValueError):
        decode_fn({})
def test_basic_encode_decode(self):
    """Round-trips a scalar through one adaptive stage for 9 rounds.

    The round trip must be lossless, and on round n the encoded value should
    be `1.0 + 1 / n`.
    """
    x = tf.constant(1.0, tf.float32)
    stage_encoder = core_encoder.EncoderComposer(
        test_utils.PlusOneOverNEncodingStage()).make()
    encoder = simple_encoder.SimpleEncoder(stage_encoder,
                                           tf.TensorSpec.from_tensor(x))
    state = encoder.initial_state()
    iteration = _make_iteration_function(encoder)
    for round_idx in range(1, 10):
        x, encoded_x, decoded_x, state = self.evaluate(iteration(x, state))
        self.assertAllClose(x, decoded_x)
        expected = 1.0 + 1 / round_idx
        self.assertAllClose(expected,
                            _encoded_x_field(encoded_x, [TENSORS, PN_VALS]))
def test_basic_encode_decode(self):
    """Builds one stateful encode/decode graph and evaluates it 9 times.

    The round trip must be lossless, and on the n-th evaluation the encoded
    value should be `1.0 + 1 / n`.
    """
    inner = core_encoder.EncoderComposer(
        test_utils.PlusOneOverNEncodingStage()).make()
    encoder = simple_encoder.StatefulSimpleEncoder(inner)
    x = tf.constant(1.0)
    encoder.initialize()
    encoded_x, decode_fn = encoder.encode(x)
    decoded_x = decode_fn(encoded_x)
    self.evaluate(tf.global_variables_initializer())

    for round_idx in range(1, 10):
        x_val, encoded_val, decoded_val = self.evaluate(
            [x, encoded_x, decoded_x])
        self.assertAllClose(x_val, decoded_val)
        self.assertAllClose(1.0 + 1 / round_idx,
                            _encoded_x_field(encoded_val, [TENSORS, PN_VALS]))
def test_input_signature_enforced(self):
    """Tests that encode/decode input signature is enforced.

    Each of three malformed inputs must be rejected with a `ValueError`:
    an `x` whose shape differs from the construction-time spec, a state with
    an extra element, and an encoded structure with an extra key.
    """
    x = tf.constant(1.0)
    encoder = simple_encoder.SimpleEncoder(
        core_encoder.EncoderComposer(
            test_utils.PlusOneOverNEncodingStage()).make(),
        tf.TensorSpec.from_tensor(x))
    state = encoder.initial_state()

    # The malformed inputs are built outside the `assertRaises` contexts so
    # that only the call under test can satisfy each assertion.

    # `x` stacked twice no longer matches the scalar spec.
    bad_x = tf.stack([x, x])
    with self.assertRaises(ValueError):
        encoder.encode(bad_x, state)

    # One more element than `initial_state` produced.
    bad_state = state + (x,)
    with self.assertRaises(ValueError):
        encoder.encode(x, bad_state)

    # `encode` returns a 2-tuple whose first element is the encoded structure.
    # It must be unpacked: calling `dict()` on the whole tuple raises a
    # ValueError by itself, which previously made the final assertion pass
    # without `decode` ever being called.
    encoded_x, _ = encoder.encode(x, state)
    bad_encoded_x = dict(encoded_x)
    bad_encoded_x.update({'x': x})
    with self.assertRaises(ValueError):
        encoder.decode(bad_encoded_x)
def test_none_state_equal_to_initial_state(self):
    """Omitting the state gives the same results as passing `initial_state()`."""
    def x_fn():
        return tf.constant(1.0)

    encoder = gather_encoder.GatherEncoder.from_encoder(
        core_encoder.EncoderComposer(
            test_utils.PlusOneOverNEncodingStage()).make(),
        tf.TensorSpec.from_tensor(x_fn()))
    num_summands = 3

    stateful_iteration = _make_iteration_function(encoder, x_fn, num_summands)
    state = encoder.initial_state()
    stateless_iteration = _make_stateless_iteration_function(
        encoder, x_fn, num_summands)

    with_state = self.evaluate(stateful_iteration(state))
    without_state = self.evaluate(stateless_iteration())
    for field in ('encoded_x', 'decoded_x'):
        self.assertAllClose(getattr(with_state, field),
                            getattr(without_state, field))
def test_basic_encode_decode(self):
    """Encodes several random summands per round through a `GatherEncoder`.

    On round i, every summand's encoded value should be its input plus
    `1 / i`. The state should advance from `(i,)` to `(i + 1,)`.
    """
    def x_fn():
        return tf.random.uniform((12,))

    encoder = gather_encoder.GatherEncoder.from_encoder(
        core_encoder.EncoderComposer(
            test_utils.PlusOneOverNEncodingStage()).make(),
        tf.TensorSpec.from_tensor(x_fn()))
    num_summands = 3

    iteration = _make_iteration_function(encoder, x_fn, num_summands)
    state = encoder.initial_state()
    for round_idx in range(1, 5):
        data = self.evaluate(iteration(state))
        for j in range(num_summands):
            expected = data.x[j] + 1 / round_idx
            self.assertAllClose(
                expected, _encoded_x_field(data.encoded_x[j],
                                           [TENSORS, PN_VALS]))
        self.assertEqual((round_idx,), data.initial_state)
        self.assertEqual((round_idx + 1,), data.updated_state)
        state = data.updated_state
def test_composite_encoder(self):
    """Runs a composite tree wrapped in a `GatherEncoder` for 4 rounds.

    Tree: SignIntFloat splits the input; signs go through TimesTwo, ints
    through PlusOne, and floats through TimesTwo followed by PlusOneOverN.
    For each summand, the partial decoding should equal its input. Both the
    summed partial decodings and the final decoding should equal
    `num_summands * x`.
    """
    def x_fn():
        return tf.constant(1.2)

    composer = core_encoder.EncoderComposer(
        test_utils.SignIntFloatEncodingStage())
    composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
    composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
    float_branch = composer.add_child(test_utils.TimesTwoEncodingStage(),
                                      SIF_FLOATS)
    float_branch.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
    encoder = gather_encoder.GatherEncoder.from_encoder(
        composer.make(), tf.TensorSpec.from_tensor(x_fn()))

    num_summands = 3
    iteration = _make_iteration_function(encoder, x_fn, num_summands)
    state = encoder.initial_state()
    for round_idx in range(1, 5):
        data = self.evaluate(iteration(state))
        for j in range(num_summands):
            encoded_j = data.encoded_x[j]
            self.assertAllClose(
                2.0, _encoded_x_field(encoded_j,
                                      [TENSORS, SIF_SIGNS, T2_VALS]))
            self.assertAllClose(
                2.0, _encoded_x_field(encoded_j,
                                      [TENSORS, SIF_INTS, P1_VALS]))
            self.assertAllClose(
                0.4 + 1 / round_idx,
                _encoded_x_field(encoded_j,
                                 [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS]))
            self.assertAllClose(data.x[j], data.part_decoded_x[j])
            expected_total = data.x[j] * num_summands
            self.assertAllClose(expected_total, data.summed_part_decoded_x)
            self.assertAllClose(expected_total, data.decoded_x)
        self.assertEqual((round_idx,), data.initial_state)
        self.assertEqual((round_idx + 1,), data.updated_state)
        state = data.updated_state
def test_none_state_equal_to_initial_state(self):
    """Encoding without a state matches encoding with `initial_state()`."""
    x = tf.constant(1.0)
    encoder = simple_encoder.SimpleEncoder(
        core_encoder.EncoderComposer(
            test_utils.PlusOneOverNEncodingStage()).make(),
        tf.TensorSpec.from_tensor(x))
    state = encoder.initial_state()
    stateful_iteration = _make_iteration_function(encoder)

    @tf.function
    def encode_decode_without_state(value):
        encoded, _ = encoder.encode(value)
        return encoded, encoder.decode(encoded)

    _, encoded_with_state, decoded_with_state, _ = self.evaluate(
        stateful_iteration(x, state))
    encoded_without_state, decoded_without_state = self.evaluate(
        encode_decode_without_state(x))
    self.assertAllClose(encoded_with_state, encoded_without_state)
    self.assertAllClose(decoded_with_state, decoded_without_state)
def test_correct_structure(self):
    """Tests that structured objects look like what they should.

    This test creates the following encoding tree:

        SignIntFloatEncodingStage
            [SIF_SIGNS]  -> TimesTwoEncodingStage
            [SIF_INTS]   -> PlusOneEncodingStage
            [SIF_FLOATS] -> PlusOneOverNEncodingStage
                [PN_VALS] -> AdaptiveNormalizeEncodingStage

    It then verifies two things for the objects created by the methods of
    `Encoder`: they have the expected nested structure, and selected tensors
    are created under the expected name scopes.
    """
    sif_stage = test_utils.SignIntFloatEncodingStage()
    times_two_stage = test_utils.TimesTwoEncodingStage()
    plus_one_stage = test_utils.PlusOneEncodingStage()
    # NOTE(review): despite its name, `plus_n_squared_stage` holds a
    # `PlusOneOverNEncodingStage` instance.
    plus_n_squared_stage = test_utils.PlusOneOverNEncodingStage()
    adaptive_normalize_stage = test_utils.AdaptiveNormalizeEncodingStage()

    encoder = core_encoder.EncoderComposer(sif_stage)
    encoder.add_child(times_two_stage, SIF_SIGNS)
    encoder.add_child(plus_one_stage, SIF_INTS)
    encoder.add_child(plus_n_squared_stage, SIF_FLOATS).add_child(
        adaptive_normalize_stage, PN_VALS)
    encoder = encoder.make()

    # Create all intermediary objects.
    x = tf.constant(1.0)
    initial_state = encoder.initial_state()
    encode_params, decode_params = encoder.get_params(initial_state)
    encoded_x, state_update_tensors, input_shapes = encoder.encode(
        x, encode_params)
    decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
    updated_state = encoder.update_state(initial_state, state_update_tensors)
    commuting_structure = encoder.commuting_structure

    # Verify the structure and naming of those objects is as expected.
    # The dict literals below are templates for
    # `tf.nest.assert_same_structure`: only the nesting and keys are compared,
    # and the `None` leaves stand in for tensors. The `assertIn` checks
    # expect each tensor's name to contain the `Encoder` method scope (e.g.
    # 'encoder_initial_state/'), then the chain of stage names and child keys
    # leading to the stage, then the matching `encoding_stage.*_SUFFIX`.

    # Initial and updated state share one structure. Only the two adaptive
    # stages (PlusOneOverN and AdaptiveNormalize) carry state entries.
    for state in [initial_state, updated_state]:
        tf.nest.assert_same_structure(
            {
                STATE: {},
                CHILDREN: {
                    SIF_INTS: {
                        STATE: {},
                        CHILDREN: {}
                    },
                    SIF_SIGNS: {
                        STATE: {},
                        CHILDREN: {}
                    },
                    SIF_FLOATS: {
                        STATE: {
                            PN_ITER_STATE: None
                        },
                        CHILDREN: {
                            PN_VALS: {
                                STATE: {
                                    AN_FACTOR_STATE: None
                                },
                                CHILDREN: {}
                            }
                        }
                    }
                }
            }, state)
    self.assertIn(
        'encoder_initial_state/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
        plus_n_squared_stage.name + encoding_stage.INITIAL_STATE_SCOPE_SUFFIX,
        initial_state[CHILDREN][SIF_FLOATS][STATE][PN_ITER_STATE].name)
    self.assertIn(
        'encoder_initial_state/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
        plus_n_squared_stage.name + '/' + PN_VALS + '/' +
        adaptive_normalize_stage.name +
        encoding_stage.INITIAL_STATE_SCOPE_SUFFIX,
        initial_state[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS][STATE]
        [AN_FACTOR_STATE].name)
    self.assertIn(
        'encoder_update_state/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
        plus_n_squared_stage.name + encoding_stage.UPDATE_STATE_SCOPE_SUFFIX,
        updated_state[CHILDREN][SIF_FLOATS][STATE][PN_ITER_STATE].name)
    self.assertIn(
        'encoder_update_state/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
        plus_n_squared_stage.name + '/' + PN_VALS + '/' +
        adaptive_normalize_stage.name +
        encoding_stage.UPDATE_STATE_SCOPE_SUFFIX,
        updated_state[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS][STATE]
        [AN_FACTOR_STATE].name)

    # Both parameter structures returned by `get_params` are expected to have
    # the same structure and naming.
    for params in [encode_params, decode_params]:
        tf.nest.assert_same_structure(
            {
                PARAMS: {},
                CHILDREN: {
                    SIF_INTS: {
                        PARAMS: {
                            P1_ADD_PARAM: None
                        },
                        CHILDREN: {}
                    },
                    SIF_SIGNS: {
                        PARAMS: {
                            T2_FACTOR_PARAM: None
                        },
                        CHILDREN: {}
                    },
                    SIF_FLOATS: {
                        PARAMS: {
                            PN_ADD_PARAM: None
                        },
                        CHILDREN: {
                            PN_VALS: {
                                PARAMS: {
                                    AN_FACTOR_PARAM: None
                                },
                                CHILDREN: {}
                            }
                        }
                    }
                }
            }, params)
        self.assertIn(
            'encoder_get_params/' + sif_stage.name + '/' + SIF_INTS + '/' +
            plus_one_stage.name + encoding_stage.GET_PARAMS_SCOPE_SUFFIX,
            params[CHILDREN][SIF_INTS][PARAMS][P1_ADD_PARAM].name)
        self.assertIn(
            'encoder_get_params/' + sif_stage.name + '/' + SIF_SIGNS + '/' +
            times_two_stage.name + encoding_stage.GET_PARAMS_SCOPE_SUFFIX,
            params[CHILDREN][SIF_SIGNS][PARAMS][T2_FACTOR_PARAM].name)
        self.assertIn(
            'encoder_get_params/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
            plus_n_squared_stage.name + encoding_stage.GET_PARAMS_SCOPE_SUFFIX,
            params[CHILDREN][SIF_FLOATS][PARAMS][PN_ADD_PARAM].name)
        # Note: we do not check the value of
        # params[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS][PARAMS][AN_FACTOR_PARAM]
        # because the get_params method of adaptive_normalize_stage does not
        # modify the graph, only passes through the provided state tensor.

    # `encoded_x` holds only the leaf outputs, keyed by the path of child keys.
    tf.nest.assert_same_structure(
        {
            SIF_INTS: {
                P1_VALS: None
            },
            SIF_SIGNS: {
                T2_VALS: None
            },
            SIF_FLOATS: {
                PN_VALS: {
                    AN_VALS: None
                }
            }
        }, encoded_x)
    self.assertIn(
        'encoder_encode/' + sif_stage.name + '/' + SIF_INTS + '/' +
        plus_one_stage.name + encoding_stage.ENCODE_SCOPE_SUFFIX,
        encoded_x[SIF_INTS][P1_VALS].name)
    self.assertIn(
        'encoder_encode/' + sif_stage.name + '/' + SIF_SIGNS + '/' +
        times_two_stage.name + encoding_stage.ENCODE_SCOPE_SUFFIX,
        encoded_x[SIF_SIGNS][T2_VALS].name)
    self.assertIn(
        'encoder_encode/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
        plus_n_squared_stage.name + '/' + PN_VALS + '/' +
        adaptive_normalize_stage.name + encoding_stage.ENCODE_SCOPE_SUFFIX,
        encoded_x[SIF_FLOATS][PN_VALS][AN_VALS].name)

    # Only the AdaptiveNormalize stage emits a state-update tensor. The
    # aggregation-modes structure must mirror the update-tensor structure.
    tf.nest.assert_same_structure(
        {
            TENSORS: {},
            CHILDREN: {
                SIF_INTS: {
                    TENSORS: {},
                    CHILDREN: {}
                },
                SIF_SIGNS: {
                    TENSORS: {},
                    CHILDREN: {}
                },
                SIF_FLOATS: {
                    TENSORS: {},
                    CHILDREN: {
                        PN_VALS: {
                            TENSORS: {
                                AN_NORM_UPDATE: None
                            },
                            CHILDREN: {}
                        }
                    }
                }
            }
        }, state_update_tensors)
    tf.nest.assert_same_structure(state_update_tensors,
                                  encoder.state_update_aggregation_modes)
    self.assertIn(
        'encoder_encode/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
        plus_n_squared_stage.name + '/' + PN_VALS + '/' +
        adaptive_normalize_stage.name + encoding_stage.ENCODE_SCOPE_SUFFIX,
        state_update_tensors[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS]
        [TENSORS][AN_NORM_UPDATE].name)

    # Every node of the tree records one input shape.
    tf.nest.assert_same_structure(
        {
            SHAPE: None,
            CHILDREN: {
                SIF_INTS: {
                    SHAPE: None,
                    CHILDREN: {}
                },
                SIF_SIGNS: {
                    SHAPE: None,
                    CHILDREN: {}
                },
                SIF_FLOATS: {
                    SHAPE: None,
                    CHILDREN: {
                        PN_VALS: {
                            SHAPE: None,
                            CHILDREN: {}
                        }
                    }
                }
            }
        }, input_shapes)

    self.assertTrue(tf.is_tensor(decoded_x))
    self.assertIn('encoder_decode/', decoded_x.name)

    # Every node of the tree has a commute flag, and none of the stages in
    # this tree is expected to commute with sum.
    tf.nest.assert_same_structure(
        {
            COMMUTE: None,
            CHILDREN: {
                SIF_INTS: {
                    COMMUTE: None,
                    CHILDREN: {}
                },
                SIF_SIGNS: {
                    COMMUTE: None,
                    CHILDREN: {}
                },
                SIF_FLOATS: {
                    COMMUTE: None,
                    CHILDREN: {
                        PN_VALS: {
                            COMMUTE: None,
                            CHILDREN: {}
                        }
                    }
                }
            }
        }, commuting_structure)
    for item in tf.nest.flatten(commuting_structure):
        self.assertEqual(False, item)
def test_is_adaptive_stage(self):
    """`is_adaptive_stage` is False for PlusOne and True for PlusOneOverN."""
    non_adaptive_stage = test_utils.PlusOneEncodingStage()
    adaptive_stage = test_utils.PlusOneOverNEncodingStage()
    self.assertFalse(test_utils.is_adaptive_stage(non_adaptive_stage))
    self.assertTrue(test_utils.is_adaptive_stage(adaptive_stage))
def default_encoding_stage(self):
    """Returns a new `PlusOneOverNEncodingStage` (see base class)."""
    stage = test_utils.PlusOneOverNEncodingStage()
    return stage
def test_not_a_tensorspec_raises(self, bad_tensorspec):
    """`SimpleEncoderV2` rejects a tensorspec argument of the wrong type."""
    stage = test_utils.PlusOneOverNEncodingStage()
    encoder = core_encoder.EncoderComposer(stage).make()
    with self.assertRaisesRegex(TypeError, 'TensorSpec'):
        simple_encoder.SimpleEncoderV2(encoder, bad_tensorspec)