def test_python_constants_not_exposed(self):
  """Only TensorFlow-valued params should be exposed by `get_params`.

  The same composed encoder is built twice: once with Python float stage
  parameters and once with `tf.Variable` stage parameters. The number of
  params each one exposes is then compared.
  """
  tensorspec = tf.TensorSpec.from_tensor(tf.constant(1.0))

  def build_gather_encoder(a, b):
    # Identical composition for both encoders; only the kind of (a, b)
    # differs between the two calls below.
    composer = core_encoder.EncoderComposer(
        test_utils.SimpleLinearEncodingStage(a, b))
    composer = composer.add_parent(test_utils.PlusOneEncodingStage(), P1_VALS)
    composer = composer.add_parent(
        test_utils.SimpleLinearEncodingStage(a, b), SL_VALS)
    return gather_encoder.GatherEncoder.from_encoder(composer.make(),
                                                     tensorspec)

  encoder_py = build_gather_encoder(2.0, 3.0)
  a_var = tf.compat.v1.get_variable('a_var', initializer=2.0)
  b_var = tf.compat.v1.get_variable('b_var', initializer=3.0)
  encoder_tf = build_gather_encoder(a_var, b_var)

  encode_py, before_sum_py, after_sum_py = encoder_py.get_params()
  encode_tf, before_sum_tf, after_sum_tf = encoder_tf.get_params()

  # Python constants (not tf.Tensors) are kept away from the user and made
  # statically available where they are needed instead.
  self.assertLen(encode_py, 1)
  self.assertLen(encode_tf, 5)
  self.assertLen(before_sum_py, 1)
  self.assertLen(before_sum_tf, 3)
  self.assertEmpty(after_sum_py)
  self.assertLen(after_sum_tf, 2)
def test_python_constants_not_exposed(self):
  """Tests that only TensorFlow values are exposed to users."""
  x = tf.constant(1.0)
  tensorspec = tf.TensorSpec.from_tensor(x)
  encoder_py = simple_encoder.SimpleEncoder(
      core_encoder.EncoderComposer(
          test_utils.SimpleLinearEncodingStage(2.0, 3.0)).make(), tensorspec)
  a_var = tf.compat.v1.get_variable('a_var', initializer=2.0)
  b_var = tf.compat.v1.get_variable('b_var', initializer=3.0)
  encoder_tf = simple_encoder.SimpleEncoder(
      core_encoder.EncoderComposer(
          test_utils.SimpleLinearEncodingStage(a_var, b_var)).make(),
      tensorspec)

  state_py = encoder_py.initial_state()
  state_tf = encoder_tf.initial_state()
  iteration_py = _make_iteration_function(encoder_py)
  iteration_tf = _make_iteration_function(encoder_tf)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  _, encoded_x_py, decoded_x_py, _ = self.evaluate(
      iteration_py(x, state_py))
  _, encoded_x_tf, decoded_x_tf, _ = self.evaluate(
      iteration_tf(x, state_tf))

  # encoded_x_tf should have exactly two more elements than encoded_x_py.
  # They correspond to the two variables passed to the constructor of the
  # SimpleLinearEncodingStage used by encoder_tf, which are exposed as params.
  # For encoder_py the same parameters are Python floats (2.0 and 3.0), and
  # should thus be hidden from users.
  self.assertLen(encoded_x_tf, len(encoded_x_py) + 2)

  # Make sure functionality is still the same.
  self.assertAllClose(x, decoded_x_tf)
  self.assertAllClose(x, decoded_x_py)
def test_commutes_with_sum(self):
  """Tests that commutativity works, provided appropriate num_summands."""
  encoder = core_encoder.EncoderComposer(
      test_utils.PlusOneEncodingStage()).add_parent(
          test_utils.SimpleLinearEncodingStage(2.0, 3.0), SL_VALS).make()
  x = tf.constant(3.0)
  encode_params, decode_params = encoder.get_params(
      encoder.initial_state())
  encoded_x, _, input_shapes = encoder.encode(x, encode_params)
  decoded_x_before_sum = encoder.decode_before_sum(
      encoded_x, decode_params, input_shapes)

  # Trivial summing of the encoded - and partially decoded - values: each
  # leaf is added to itself num_summands times. The same num_summands is then
  # passed to decode_after_sum, so the two cannot drift apart.
  num_summands = 3
  part_decoded_and_summed_x = tf.nest.map_structure(
      lambda value: tf.add_n([value] * num_summands), decoded_x_before_sum)
  decoded_x_after_sum = encoder.decode_after_sum(
      part_decoded_and_summed_x, decode_params, num_summands, input_shapes)

  data = self.evaluate(
      self.commutation_test_data(x, encoded_x, decoded_x_before_sum,
                                 decoded_x_after_sum))
  self.assertEqual(3.0, data.x)
  expected_encoded_x = {SL_VALS: {P1_VALS: (data.x * 2.0 + 3.0) + 1.0}}
  self.assertAllEqual(expected_encoded_x, data.encoded_x)
  expected_decoded_x_before_sum = {SL_VALS: data.x * 2.0 + 3.0}
  self.assertAllEqual(expected_decoded_x_before_sum,
                      data.decoded_x_before_sum)
  # Decoding after the sum recovers the sum of num_summands copies of x.
  self.assertEqual(num_summands * data.x, data.decoded_x_after_sum)
def test_param_control_from_outside(self):
  """Tests that behavior can be controlled from outside, if needed."""
  # tf.compat.v1 keeps this working under TF2, where the top-level
  # tf.get_variable / tf.global_variables_initializer / tf.assign are gone.
  a_var = tf.compat.v1.get_variable('a_var', initializer=2.0)
  b_var = tf.compat.v1.get_variable('b_var', initializer=3.0)
  encoder = simple_encoder.SimpleEncoder(
      core_encoder.EncoderComposer(
          test_utils.SimpleLinearEncodingStage(a_var, b_var)).make())
  x = tf.constant(1.0)
  encoded_x, decode_fn = encoder.encode(x)
  decoded_x = decode_fn(encoded_x)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  x_py, encoded_x_py, decoded_x_py = self.evaluate(
      [x, encoded_x, decoded_x])
  self.assertAllClose(x_py, decoded_x_py)
  # 1.0 * 2.0 + 3.0
  self.assertAllClose(5.0, _encoded_x_field(encoded_x_py, [TENSORS, SL_VALS]))

  # Changing the variables should change the behavior of the encoder.
  self.evaluate([
      tf.compat.v1.assign(a_var, 5.0),
      tf.compat.v1.assign(b_var, -7.0)
  ])
  x_py, encoded_x_py, decoded_x_py = self.evaluate(
      [x, encoded_x, decoded_x])
  self.assertAllClose(x_py, decoded_x_py)
  # 1.0 * 5.0 - 7.0
  self.assertAllClose(-2.0, _encoded_x_field(encoded_x_py, [TENSORS, SL_VALS]))
def test_as_adaptive_encoding_stage(self):
  """Wrapping a plain stage yields an adaptive stage with identical behavior.

  The wrapper should add only empty state plumbing, forward attribute lookups
  to the wrapped stage, and leave the encode/decode results unchanged.
  """
  a_var = tf.compat.v1.get_variable('a', initializer=2.0)
  b_var = tf.compat.v1.get_variable('b', initializer=3.0)
  inner = test_utils.SimpleLinearEncodingStage(a_var, b_var)
  adaptive = encoding_stage.as_adaptive_encoding_stage(inner)
  self.assertIsInstance(adaptive,
                        encoding_stage.AdaptiveEncodingStageInterface)

  x = tf.constant(2.0)
  state = adaptive.initial_state()
  encode_params, decode_params = adaptive.get_params(state)
  encoded_x, state_update_tensors = adaptive.encode(x, encode_params)
  updated_state = adaptive.update_state(state, state_update_tensors)
  decoded_x = adaptive.decode(encoded_x, decode_params)

  # Everything state-related that the wrapper adds is empty.
  for empty_dict in (state, state_update_tensors, updated_state):
    self.assertDictEqual({}, empty_dict)
  self.assertDictEqual({}, adaptive.state_update_aggregation_modes)

  # Attribute lookups fall through (via __getattr__) to the wrapped stage.
  for attr, var in (('_a', a_var), ('_b', b_var)):
    forwarded = getattr(adaptive, attr)
    self.assertIsInstance(forwarded, tf.Variable)
    self.assertIs(forwarded, var)

  # Properties describing the stage are identical to the wrapped stage's.
  for prop in ('name', 'compressible_tensors_keys', 'commutes_with_sum',
               'decode_needs_input_shape'):
    self.assertEqual(getattr(inner, prop), getattr(adaptive, prop))

  self.evaluate(tf.compat.v1.global_variables_initializer())
  evaluated = test_utils.TestData(*self.evaluate([x, encoded_x, decoded_x]))
  self.assertEqual(2.0, evaluated.x)
  encoded_key = test_utils.SimpleLinearEncodingStage.ENCODED_VALUES_KEY
  self.assertEqual(7.0, evaluated.encoded_x[encoded_key])
  self.assertEqual(2.0, evaluated.decoded_x)
def test_basic_encode_decode_tf_constructor_parameters(self):
  """Tests the core functionality with `tf.Variable` constructor parameters.

  The stage's parameters are backed by variables. Reassigning them after the
  graph is built must therefore change the encoded values that same graph
  produces.
  """
  # tf.compat.v1 keeps this working under TF2, where the top-level
  # tf.get_variable / tf.global_variables_initializer / tf.assign are gone.
  a_var = tf.compat.v1.get_variable('a_var', initializer=self._DEFAULT_A)
  b_var = tf.compat.v1.get_variable('b_var', initializer=self._DEFAULT_B)
  stage = test_utils.SimpleLinearEncodingStage(a_var, b_var)
  # Initialize through self.evaluate, the same way the variables are
  # reassigned further below.
  self.evaluate(tf.compat.v1.global_variables_initializer())

  x = self.default_input()
  encode_params, decode_params = stage.get_params()
  encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                              decode_params)
  test_data = self.evaluate_test_data(
      test_utils.TestData(x, encoded_x, decoded_x))
  self.common_asserts_for_test_data(test_data)

  # Change the variables and verify the behavior of stage changes.
  self.evaluate([
      tf.compat.v1.assign(a_var, 5.0),
      tf.compat.v1.assign(b_var, 6.0)
  ])
  test_data = self.evaluate_test_data(
      test_utils.TestData(x, encoded_x, decoded_x))
  self.assertAllClose(test_data.x * 5.0 + 6.0,
                      test_data.encoded_x[self._ENCODED_VALUES_KEY])
def default_encoding_stage(self):
  """Returns a `SimpleLinearEncodingStage` built from the class defaults.

  Uses `_DEFAULT_A` and `_DEFAULT_B` as the stage's two parameters. See base
  class.
  """
  a, b = self._DEFAULT_A, self._DEFAULT_B
  return test_utils.SimpleLinearEncodingStage(a, b)