def test_decode_split_commutes_with_sum_true_false_true(self):
    """Checks split decoding for a commute / no-commute / commute chain.

    Three encoding stages are chained: the first *does* commute with sum,
    the second *does not*, and the third *does* again. Taken together, only
    the first stage should commute with sum, so `decoded_x_before_sum` is
    expected to still carry the first stage's encoding.
    """
    # Build the chain step by step, starting from the innermost stage.
    composer = core_encoder.EncoderComposer(
        test_utils.TimesTwoEncodingStage())
    composer = composer.add_parent(
        test_utils.PlusOneEncodingStage(), P1_VALS)
    composer = composer.add_parent(
        test_utils.TimesTwoEncodingStage(), T2_VALS)
    encoder = composer.make()
    self.assertFalse(encoder.fully_commutes_with_sum)

    data = self._data_for_test_decode_split(encoder, tf.constant(3.0))

    # The full round trip must give back the input.
    self.assertEqual(data.x, data.decoded_x_after_sum)

    # All three stages are applied when encoding.
    fully_encoded = {
        T2_VALS: {
            P1_VALS: {
                T2_VALS: (data.x * 2.0 + 1.0) * 2.0,
            },
        },
    }
    self.assertAllEqual(fully_encoded, data.encoded_x)

    # Only the first stage commutes with sum, so only its encoding is left.
    partially_decoded = {T2_VALS: data.x * 2.0}
    self.assertAllEqual(partially_decoded, data.decoded_x_before_sum)
def test_tree_encoder(self):
    """Checks that an encoder built as a branching tree works, not only a chain."""
    composer = core_encoder.EncoderComposer(
        test_utils.SignIntFloatEncodingStage())
    composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
    composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
    # The floats branch is itself a two-stage chain.
    floats_branch = composer.add_child(
        test_utils.PlusOneEncodingStage(), SIF_FLOATS)
    floats_branch.add_child(test_utils.TimesTwoEncodingStage(), P1_VALS)
    encoder = composer.make()

    x = tf.constant([0.0, 0.1, -0.1, 0.9, -0.9, 1.6, -2.2])
    state = encoder.initial_state()
    encode_params, decode_params = encoder.get_params(state)
    encoded_x, _, input_shapes = encoder.encode(x, encode_params)
    decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)

    x_val, encoded_val, decoded_val = self.evaluate(
        [x, encoded_x, decoded_x])

    # Decoding must invert encoding.
    self.assertAllClose(x_val, decoded_val)

    expected_encoded_x = {
        SIF_SIGNS: {
            T2_VALS: np.array([0.0, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0]),
        },
        SIF_INTS: {
            P1_VALS: np.array([1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 3.0]),
        },
        SIF_FLOATS: {
            P1_VALS: {
                T2_VALS: np.array([2.0, 2.2, 2.2, 3.8, 3.8, 3.2, 2.4]),
            },
        },
    }
    self.assertAllClose(expected_encoded_x, encoded_val)
def test_decode_split_commutes_with_sum_true_true(self):
    """Checks split decoding when every stage commutes with sum.

    Two encoding stages are chained and both commute with sum, so the
    encoder as a whole should commute with sum as well.
    """
    composer = core_encoder.EncoderComposer(
        test_utils.TimesTwoEncodingStage())
    encoder = composer.add_parent(
        test_utils.TimesTwoEncodingStage(), T2_VALS).make()
    self.assertTrue(encoder.fully_commutes_with_sum)

    data = self._data_for_test_decode_split(encoder, tf.constant(3.0))

    # The full round trip must give back the input.
    self.assertEqual(data.x, data.decoded_x_after_sum)

    # Both stages are applied when encoding.
    fully_encoded = {
        T2_VALS: {
            T2_VALS: data.x * 2.0 * 2.0,
        },
    }
    self.assertAllEqual(fully_encoded, data.encoded_x)

    # Everything commutes with sum, so nothing is decoded before the sum.
    self.assertAllEqual(data.encoded_x, data.decoded_x_before_sum)

    leaf_structure = {COMMUTE: True, CHILDREN: {}}
    root_structure = {COMMUTE: True, CHILDREN: {T2_VALS: leaf_structure}}
    self.assertEqual(root_structure, encoder.commuting_structure)
def test_composite_encoder(self):
    """Runs a composite `Encoder` through `SimpleEncoder` for several rounds.

    Each round feeds the evaluated value and state of the previous round back
    into the iteration function, and checks both the decoded value and the
    individual encoded fields. The expected value of the floats branch
    depends on the round number.
    """
    x = tf.constant(1.2)

    composer = core_encoder.EncoderComposer(
        test_utils.SignIntFloatEncodingStage())
    composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
    composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
    floats_branch = composer.add_child(
        test_utils.TimesTwoEncodingStage(), SIF_FLOATS)
    floats_branch.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
    encoder = simple_encoder.SimpleEncoder(
        composer.make(), tf.TensorSpec.from_tensor(x))

    state = encoder.initial_state()
    iteration = _make_iteration_function(encoder)

    for i in range(1, 10):
        x, encoded_x, decoded_x, state = self.evaluate(iteration(x, state))
        self.assertAllClose(x, decoded_x)

        # (expected value, path of the field inside the encoded structure).
        # Rebuilt every round so each lookup receives a fresh path list.
        expectations = (
            (2.0, [TENSORS, SIF_SIGNS, T2_VALS]),
            (2.0, [TENSORS, SIF_INTS, P1_VALS]),
            (0.4 + 1 / i, [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS]),
        )
        for expected, path in expectations:
            self.assertAllClose(expected, _encoded_x_field(encoded_x, path))
def test_full_commutativity_with_sum(self):
    """Checks the `fully_commutes_with_sum` property of `GatherEncoder`.

    Covers a single stage, a chain of two stages, and a composite tree; the
    property is expected to be true for the first two and false for the tree.
    """
    spec = tf.TensorSpec((2,), tf.float32)

    # Case 1: a single TimesTwo stage.
    single_stage = core_encoder.EncoderComposer(
        test_utils.TimesTwoEncodingStage()).make()
    encoder = gather_encoder.GatherEncoder.from_encoder(single_stage, spec)
    self.assertTrue(encoder.fully_commutes_with_sum)

    # Case 2: two chained TimesTwo stages.
    chain_composer = core_encoder.EncoderComposer(
        test_utils.TimesTwoEncodingStage())
    two_stage_chain = chain_composer.add_parent(
        test_utils.TimesTwoEncodingStage(), T2_VALS).make()
    encoder = gather_encoder.GatherEncoder.from_encoder(two_stage_chain, spec)
    self.assertTrue(encoder.fully_commutes_with_sum)

    # Case 3: a composite tree with several kinds of stages.
    tree_composer = core_encoder.EncoderComposer(
        test_utils.SignIntFloatEncodingStage())
    tree_composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
    tree_composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
    floats_branch = tree_composer.add_child(
        test_utils.TimesTwoEncodingStage(), SIF_FLOATS)
    floats_branch.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
    encoder = gather_encoder.GatherEncoder.from_encoder(
        tree_composer.make(), spec)
    self.assertFalse(encoder.fully_commutes_with_sum)
def test_composite_encoder(self):
    """Runs a composite `Encoder` through `StatefulSimpleEncoder`.

    The encode/decode tensors are created once and then evaluated repeatedly;
    the expected value of the floats branch depends on how many times the
    tensors have been evaluated.
    """
    composer = core_encoder.EncoderComposer(
        test_utils.SignIntFloatEncodingStage())
    composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
    composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
    floats_branch = composer.add_child(
        test_utils.TimesTwoEncodingStage(), SIF_FLOATS)
    floats_branch.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
    encoder = simple_encoder.StatefulSimpleEncoder(composer.make())

    x = tf.constant(1.2)
    encoder.initialize()
    encoded_x, decode_fn = encoder.encode(x)
    decoded_x = decode_fn(encoded_x)
    self.evaluate(tf.global_variables_initializer())

    fetches = [x, encoded_x, decoded_x]
    for i in range(1, 10):
        x_py, encoded_x_py, decoded_x_py = self.evaluate(fetches)
        self.assertAllClose(x_py, decoded_x_py)

        # (expected value, path of the field inside the encoded structure).
        # Rebuilt every round so each lookup receives a fresh path list.
        expectations = (
            (2.0, [TENSORS, SIF_SIGNS, T2_VALS]),
            (2.0, [TENSORS, SIF_INTS, P1_VALS]),
            (0.4 + 1 / i, [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS]),
        )
        for expected, path in expectations:
            self.assertAllClose(
                expected, _encoded_x_field(encoded_x_py, path))
def test_add_child_semantics(self):
    """Checks that `make` depends on which composer node it is called on.

    `add_child` returns the newly created node, so calling `make` on its
    return value creates only the child node rather than the whole tree.
    """
    # Keep a handle on the root and call `make` on it.
    root = core_encoder.EncoderComposer(test_utils.TimesTwoEncodingStage())
    root.add_child(test_utils.PlusOneEncodingStage(), T2_VALS)
    encoder_from_root = root.make()

    # Call `make` directly on the node returned by `add_child`.
    child = core_encoder.EncoderComposer(
        test_utils.TimesTwoEncodingStage()).add_child(
            test_utils.PlusOneEncodingStage(), T2_VALS)
    encoder_from_child = child.make()

    # The two encoders must expose different sets of children.
    self.assertNotEqual(
        encoder_from_root.children.keys(),
        encoder_from_child.children.keys())
def test_add_child_parent_bad_key_raises(self):
    """Checks that `add_child` and `add_parent` raise `KeyError` for a bad key."""
    composer = core_encoder.EncoderComposer(
        test_utils.TimesTwoEncodingStage())
    for add_stage in (composer.add_child, composer.add_parent):
        with self.assertRaises(KeyError):
            add_stage(test_utils.PlusOneEncodingStage(), '___bad_key')
def test_composite_encoder(self):
    """Runs a composite `Encoder` through `GatherEncoder` for several rounds.

    Each round encodes `num_summands` copies of the input, checks every
    encoded summand and the partially and fully decoded results, and then
    feeds the updated state into the next round.
    """
    x_fn = lambda: tf.constant(1.2)

    composer = core_encoder.EncoderComposer(
        test_utils.SignIntFloatEncodingStage())
    composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
    composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
    floats_branch = composer.add_child(
        test_utils.TimesTwoEncodingStage(), SIF_FLOATS)
    floats_branch.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
    encoder = gather_encoder.GatherEncoder.from_encoder(
        composer.make(), tf.TensorSpec.from_tensor(x_fn()))

    num_summands = 3
    iteration = _make_iteration_function(encoder, x_fn, num_summands)
    state = encoder.initial_state()

    for i in range(1, 5):
        data = self.evaluate(iteration(state))

        for j in range(num_summands):
            summand = data.encoded_x[j]
            # (expected value, path of the field in the encoded structure).
            # Rebuilt per summand so each lookup receives a fresh path list.
            expectations = (
                (2.0, [TENSORS, SIF_SIGNS, T2_VALS]),
                (2.0, [TENSORS, SIF_INTS, P1_VALS]),
                (0.4 + 1 / i, [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS]),
            )
            for expected, path in expectations:
                self.assertAllClose(
                    expected, _encoded_x_field(summand, path))

            self.assertAllClose(data.x[j], data.part_decoded_x[j])
            self.assertAllClose(
                data.x[j] * num_summands, data.summed_part_decoded_x)
            self.assertAllClose(data.x[j] * num_summands, data.decoded_x)

        # The state is expected to advance by one every round.
        self.assertEqual((i,), data.initial_state)
        self.assertEqual((i + 1,), data.updated_state)
        state = data.updated_state
def test_encoder_is_reusable(self):
    """Checks that a single encoder can encode several different inputs."""
    composer = core_encoder.EncoderComposer(
        test_utils.PlusOneEncodingStage())
    encoder = composer.add_parent(
        test_utils.TimesTwoEncodingStage(), T2_VALS).make()

    # Inputs of different ranks, all created before any encoding happens.
    shapes = [(3,), (3, 4), (3, 4, 5)]
    inputs = [tf.random.normal(shape) for shape in shapes]

    for x_tensor in inputs:
        state = encoder.initial_state()
        encode_params, decode_params = encoder.get_params(state)
        encoded_x, _, input_shapes = encoder.encode(x_tensor, encode_params)
        decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)

        x_val, encoded_val, decoded_val = self.evaluate(
            [x_tensor, encoded_x, decoded_x])

        self.assertAllClose(x_val, decoded_val)
        expected_encoded = {T2_VALS: {P1_VALS: x_val * 2.0 + 1.0}}
        self.assertAllClose(expected_encoded, encoded_val)
def test_add_parent(self):
    """Checks the encoder chain produced by repeated `add_parent` calls.

    Starting from a `ReduceMeanEncodingStage`, two parents are added. The
    resulting encoder is expected to wrap the last added parent at the top
    and the initial stage at the bottom.
    """
    composer = core_encoder.EncoderComposer(
        test_utils.ReduceMeanEncodingStage())
    composer = composer.add_parent(
        test_utils.PlusOneEncodingStage(), P1_VALS)
    composer = composer.add_parent(
        test_utils.TimesTwoEncodingStage(), T2_VALS)
    encoder = composer.make()

    # (expected wrapped stage type, key leading to the next level or None).
    expected_levels = (
        (test_utils.TimesTwoEncodingStage, T2_VALS),
        (test_utils.PlusOneEncodingStage, P1_VALS),
        (test_utils.ReduceMeanEncodingStage, None),
    )
    node = encoder
    for stage_type, child_key in expected_levels:
        self.assertIsInstance(node, core_encoder.Encoder)
        self.assertIsInstance(node.stage._wrapped_stage, stage_type)
        if child_key is not None:
            node = node.children[child_key]
def test_decode_split_commutes_with_sum_false_true(self):
    """Checks split decoding for a no-commute / commute chain.

    Two encoding stages are chained: the first *does not* commute with sum,
    the second *does*. Taken together, nothing should commute with sum, so
    `decoded_x_before_sum` is expected to already be fully decoded.
    """
    composer = core_encoder.EncoderComposer(
        test_utils.TimesTwoEncodingStage())
    encoder = composer.add_parent(
        test_utils.PlusOneEncodingStage(), P1_VALS).make()
    self.assertFalse(encoder.fully_commutes_with_sum)

    data = self._data_for_test_decode_split(encoder, tf.constant(3.0))

    # The full round trip must give back the input.
    self.assertEqual(data.x, data.decoded_x_after_sum)

    # Both stages are applied when encoding.
    fully_encoded = {
        P1_VALS: {
            T2_VALS: (data.x + 1.0) * 2.0,
        },
    }
    self.assertAllEqual(fully_encoded, data.encoded_x)

    # Nothing commutes with sum, so the value is fully decoded before it.
    self.assertEqual(data.x, data.decoded_x_before_sum)
def test_commutativity_with_sum(self):
    """Checks a `GatherEncoder` whose encoder commutes with sum.

    For several numbers of summands, the partially decoded values are
    expected to equal the encoded values, and both the summed encoded values
    and the final decoded value are expected to scale with the number of
    summands.
    """
    x_fn = lambda: tf.constant([1.0, 3.0])
    core = core_encoder.EncoderComposer(
        test_utils.TimesTwoEncodingStage()).make()
    encoder = gather_encoder.GatherEncoder.from_encoder(
        core, tf.TensorSpec.from_tensor(x_fn()))

    for num_summands in [1, 3, 7]:
        iteration = _make_iteration_function(encoder, x_fn, num_summands)
        data = self.evaluate(iteration(encoder.initial_state()))

        for i in range(num_summands):
            self.assertAllClose([1.0, 3.0], data.x[i])
            self.assertAllClose(
                [2.0, 6.0],
                _encoded_x_field(data.encoded_x[i], [TENSORS, T2_VALS]))
            # Nothing is decoded before the sum: first entries must match.
            first_part_decoded = list(data.part_decoded_x[i].values())[0]
            first_encoded = list(data.encoded_x[i].values())[0]
            self.assertAllClose(first_part_decoded, first_encoded)

        expected_sum = np.array([2.0, 6.0]) * num_summands
        first_summed = list(data.summed_part_decoded_x.values())[0]
        self.assertAllClose(expected_sum, first_summed)

        expected_decoded = np.array([1.0, 3.0]) * num_summands
        self.assertAllClose(expected_decoded, data.decoded_x)
def test_add_child_repeat_key_raises(self):
    """Checks that adding a second child under the same key raises `KeyError`."""
    composer = core_encoder.EncoderComposer(
        test_utils.TimesTwoEncodingStage())
    # The first child under T2_VALS is accepted.
    composer.add_child(test_utils.PlusOneEncodingStage(), T2_VALS)
    # The same key cannot be used again.
    with self.assertRaises(KeyError):
        composer.add_child(test_utils.PlusOneEncodingStage(), T2_VALS)
def test_correct_structure(self):
    """Checks the structure and naming of objects produced by an `Encoder`.

    The encoding tree under test is:

        SignIntFloatEncodingStage
            [SIF_SIGNS] -> TimesTwoEncodingStage
            [SIF_INTS] -> PlusOneEncodingStage
            [SIF_FLOATS] -> PlusOneOverNEncodingStage
                [PN_VALS] -> AdaptiveNormalizeEncodingStage

    The structured objects created by the methods of `Encoder` (state,
    params, encoded values, state update tensors, input shapes and the
    commuting structure) are compared against the expected nesting, and
    selected tensor names are checked to contain the expected scope path.
    """
    sif_stage = test_utils.SignIntFloatEncodingStage()
    times_two_stage = test_utils.TimesTwoEncodingStage()
    plus_one_stage = test_utils.PlusOneEncodingStage()
    plus_one_over_n_stage = test_utils.PlusOneOverNEncodingStage()
    adaptive_normalize_stage = test_utils.AdaptiveNormalizeEncodingStage()

    composer = core_encoder.EncoderComposer(sif_stage)
    composer.add_child(times_two_stage, SIF_SIGNS)
    composer.add_child(plus_one_stage, SIF_INTS)
    floats_branch = composer.add_child(plus_one_over_n_stage, SIF_FLOATS)
    floats_branch.add_child(adaptive_normalize_stage, PN_VALS)
    encoder = composer.make()

    # Create all intermediary objects.
    x = tf.constant(1.0)
    initial_state = encoder.initial_state()
    encode_params, decode_params = encoder.get_params(initial_state)
    encoded_x, state_update_tensors, input_shapes = encoder.encode(
        x, encode_params)
    decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
    updated_state = encoder.update_state(initial_state, state_update_tensors)
    commuting_structure = encoder.commuting_structure

    # Scope paths of the individual stages, shared by the name checks below.
    ints_scope = sif_stage.name + '/' + SIF_INTS + '/' + plus_one_stage.name
    signs_scope = (
        sif_stage.name + '/' + SIF_SIGNS + '/' + times_two_stage.name)
    floats_scope = (
        sif_stage.name + '/' + SIF_FLOATS + '/' + plus_one_over_n_stage.name)
    normalize_scope = (
        floats_scope + '/' + PN_VALS + '/' + adaptive_normalize_stage.name)

    # Verify the structure and naming of the state objects.
    expected_state_structure = {
        STATE: {},
        CHILDREN: {
            SIF_INTS: {STATE: {}, CHILDREN: {}},
            SIF_SIGNS: {STATE: {}, CHILDREN: {}},
            SIF_FLOATS: {
                STATE: {PN_ITER_STATE: None},
                CHILDREN: {
                    PN_VALS: {
                        STATE: {AN_FACTOR_STATE: None},
                        CHILDREN: {},
                    },
                },
            },
        },
    }
    for state in [initial_state, updated_state]:
        tf.nest.assert_same_structure(expected_state_structure, state)

    initial_floats = initial_state[CHILDREN][SIF_FLOATS]
    self.assertIn(
        'encoder_initial_state/' + floats_scope +
        encoding_stage.INITIAL_STATE_SCOPE_SUFFIX,
        initial_floats[STATE][PN_ITER_STATE].name)
    self.assertIn(
        'encoder_initial_state/' + normalize_scope +
        encoding_stage.INITIAL_STATE_SCOPE_SUFFIX,
        initial_floats[CHILDREN][PN_VALS][STATE][AN_FACTOR_STATE].name)

    updated_floats = updated_state[CHILDREN][SIF_FLOATS]
    self.assertIn(
        'encoder_update_state/' + floats_scope +
        encoding_stage.UPDATE_STATE_SCOPE_SUFFIX,
        updated_floats[STATE][PN_ITER_STATE].name)
    self.assertIn(
        'encoder_update_state/' + normalize_scope +
        encoding_stage.UPDATE_STATE_SCOPE_SUFFIX,
        updated_floats[CHILDREN][PN_VALS][STATE][AN_FACTOR_STATE].name)

    # Verify the structure and naming of the encode and decode params.
    expected_params_structure = {
        PARAMS: {},
        CHILDREN: {
            SIF_INTS: {PARAMS: {P1_ADD_PARAM: None}, CHILDREN: {}},
            SIF_SIGNS: {PARAMS: {T2_FACTOR_PARAM: None}, CHILDREN: {}},
            SIF_FLOATS: {
                PARAMS: {PN_ADD_PARAM: None},
                CHILDREN: {
                    PN_VALS: {
                        PARAMS: {AN_FACTOR_PARAM: None},
                        CHILDREN: {},
                    },
                },
            },
        },
    }
    for params in [encode_params, decode_params]:
        tf.nest.assert_same_structure(expected_params_structure, params)
        self.assertIn(
            'encoder_get_params/' + ints_scope +
            encoding_stage.GET_PARAMS_SCOPE_SUFFIX,
            params[CHILDREN][SIF_INTS][PARAMS][P1_ADD_PARAM].name)
        self.assertIn(
            'encoder_get_params/' + signs_scope +
            encoding_stage.GET_PARAMS_SCOPE_SUFFIX,
            params[CHILDREN][SIF_SIGNS][PARAMS][T2_FACTOR_PARAM].name)
        self.assertIn(
            'encoder_get_params/' + floats_scope +
            encoding_stage.GET_PARAMS_SCOPE_SUFFIX,
            params[CHILDREN][SIF_FLOATS][PARAMS][PN_ADD_PARAM].name)
        # Note: we do not check the value of
        # params[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS][PARAMS][AN_FACTOR_PARAM]
        # because the get_params method of adaptive_normalize_stage does not
        # modify the graph, only passes through the provided state tensor.

    # Verify the structure and naming of the encoded values.
    expected_encoded_structure = {
        SIF_INTS: {P1_VALS: None},
        SIF_SIGNS: {T2_VALS: None},
        SIF_FLOATS: {PN_VALS: {AN_VALS: None}},
    }
    tf.nest.assert_same_structure(expected_encoded_structure, encoded_x)
    self.assertIn(
        'encoder_encode/' + ints_scope + encoding_stage.ENCODE_SCOPE_SUFFIX,
        encoded_x[SIF_INTS][P1_VALS].name)
    self.assertIn(
        'encoder_encode/' + signs_scope + encoding_stage.ENCODE_SCOPE_SUFFIX,
        encoded_x[SIF_SIGNS][T2_VALS].name)
    self.assertIn(
        'encoder_encode/' + normalize_scope +
        encoding_stage.ENCODE_SCOPE_SUFFIX,
        encoded_x[SIF_FLOATS][PN_VALS][AN_VALS].name)

    # Verify the structure and naming of the state update tensors.
    expected_update_structure = {
        TENSORS: {},
        CHILDREN: {
            SIF_INTS: {TENSORS: {}, CHILDREN: {}},
            SIF_SIGNS: {TENSORS: {}, CHILDREN: {}},
            SIF_FLOATS: {
                TENSORS: {},
                CHILDREN: {
                    PN_VALS: {
                        TENSORS: {AN_NORM_UPDATE: None},
                        CHILDREN: {},
                    },
                },
            },
        },
    }
    tf.nest.assert_same_structure(
        expected_update_structure, state_update_tensors)
    tf.nest.assert_same_structure(
        state_update_tensors, encoder.state_update_aggregation_modes)
    normalize_updates = (
        state_update_tensors[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS])
    self.assertIn(
        'encoder_encode/' + normalize_scope +
        encoding_stage.ENCODE_SCOPE_SUFFIX,
        normalize_updates[TENSORS][AN_NORM_UPDATE].name)

    # Verify the structure of the recorded input shapes.
    expected_shapes_structure = {
        SHAPE: None,
        CHILDREN: {
            SIF_INTS: {SHAPE: None, CHILDREN: {}},
            SIF_SIGNS: {SHAPE: None, CHILDREN: {}},
            SIF_FLOATS: {
                SHAPE: None,
                CHILDREN: {
                    PN_VALS: {SHAPE: None, CHILDREN: {}},
                },
            },
        },
    }
    tf.nest.assert_same_structure(expected_shapes_structure, input_shapes)

    # The decoded value is a single tensor created in the decode scope.
    self.assertTrue(tf.is_tensor(decoded_x))
    self.assertIn('encoder_decode/', decoded_x.name)

    # Verify the commuting structure; no entry is expected to be true.
    expected_commuting_structure = {
        COMMUTE: None,
        CHILDREN: {
            SIF_INTS: {COMMUTE: None, CHILDREN: {}},
            SIF_SIGNS: {COMMUTE: None, CHILDREN: {}},
            SIF_FLOATS: {
                COMMUTE: None,
                CHILDREN: {
                    PN_VALS: {COMMUTE: None, CHILDREN: {}},
                },
            },
        },
    }
    tf.nest.assert_same_structure(
        expected_commuting_structure, commuting_structure)
    for item in tf.nest.flatten(commuting_structure):
        self.assertEqual(False, item)
def default_encoding_stage(self):
    """Returns a new `TimesTwoEncodingStage` instance. See base class."""
    stage = test_utils.TimesTwoEncodingStage()
    return stage