Esempio n. 1
0
    def test_decode_split_commutes_with_sum_false_false(self):
        """Tests that splitting decode works as expected with commutes_with_sum.

        Two PlusOne stages are chained; neither commutes with sum, so no part
        of the decoding may be deferred until after summation.
        """
        composer = core_encoder.EncoderComposer(
            test_utils.PlusOneEncodingStage())
        composer = composer.add_parent(test_utils.PlusOneEncodingStage(),
                                       P1_VALS)
        encoder = composer.make()
        self.assertFalse(encoder.fully_commutes_with_sum)
        data = self._data_for_test_decode_split(encoder, tf.constant(3.0))

        # Round trip must be exact, and the encoding applies +1 twice.
        self.assertEqual(data.x, data.decoded_x_after_sum)
        expected_encoded_x = {P1_VALS: {P1_VALS: data.x + 1.0 + 1.0}}
        self.assertAllEqual(expected_encoded_x, data.encoded_x)

        # With nothing commuting, the pre-sum decode already yields x.
        self.assertEqual(data.x, data.decoded_x_before_sum)
        expected_commuting_structure = {
            COMMUTE: False,
            CHILDREN: {
                P1_VALS: {
                    COMMUTE: False,
                    CHILDREN: {},
                },
            },
        }
        self.assertEqual(expected_commuting_structure,
                         encoder.commuting_structure)
    def test_composite_encoder(self):
        """Tests functionality with a general, composite `Encoder`.

        The encoding tree is:
          SignIntFloatEncodingStage
            [SIF_SIGNS] -> TimesTwoEncodingStage
            [SIF_INTS] -> PlusOneEncodingStage
            [SIF_FLOATS] -> TimesTwoEncodingStage
              [T2_VALS] -> PlusOneEncodingStage
        """
        encoder = core_encoder.EncoderComposer(
            test_utils.SignIntFloatEncodingStage())
        encoder.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
        encoder.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
        encoder.add_child(test_utils.TimesTwoEncodingStage(),
                          SIF_FLOATS).add_child(
                              test_utils.PlusOneEncodingStage(), T2_VALS)
        encoder = simple_encoder.SimpleEncoder(encoder.make())

        x = tf.constant(1.2)
        encoded_x, decode_fn = encoder.encode(x)
        decoded_x = decode_fn(encoded_x)

        x_py, encoded_x_py, decoded_x_py = self.evaluate(
            [x, encoded_x, decoded_x])
        self.assertAllClose(x_py, decoded_x_py)
        self.assertAllClose(
            2.0, _encoded_x_field(encoded_x_py, [TENSORS, SIF_SIGNS, T2_VALS]))
        self.assertAllClose(
            2.0, _encoded_x_field(encoded_x_py, [TENSORS, SIF_INTS, P1_VALS]))
        # The leaf under [SIF_FLOATS][T2_VALS] is a PlusOneEncodingStage, so its
        # output lives under P1_VALS (PN_VALS belongs to PlusOneOverN).
        # Expected value: fractional part 0.2, doubled to 0.4, plus one.
        self.assertAllClose(
            1.4,
            _encoded_x_field(encoded_x_py,
                             [TENSORS, SIF_FLOATS, T2_VALS, P1_VALS]))
    def test_python_constants_not_exposed(self):
        """Tests that only TensorFlow values are exposed to users.

        The same three-stage chain is built twice: once with Python float
        constants and once with TF variables. Only the TF-valued params should
        surface through `get_params`.
        """
        x_fn = lambda: tf.constant(1.0)
        tensorspec = tf.TensorSpec.from_tensor(x_fn())

        def build_gather_encoder(a, b):
            # Root SimpleLinear -> [SL_VALS] PlusOne -> [P1_VALS] SimpleLinear,
            # built leaf-first via add_parent.
            composer = core_encoder.EncoderComposer(
                test_utils.SimpleLinearEncodingStage(a, b))
            composer = composer.add_parent(test_utils.PlusOneEncodingStage(),
                                           P1_VALS)
            composer = composer.add_parent(
                test_utils.SimpleLinearEncodingStage(a, b), SL_VALS)
            return gather_encoder.GatherEncoder.from_encoder(
                composer.make(), tensorspec)

        encoder_py = build_gather_encoder(2.0, 3.0)
        a_var = tf.compat.v1.get_variable('a_var', initializer=2.0)
        b_var = tf.compat.v1.get_variable('b_var', initializer=3.0)
        encoder_tf = build_gather_encoder(a_var, b_var)

        (encode_params_py, decode_before_sum_params_py,
         decode_after_sum_params_py) = encoder_py.get_params()
        (encode_params_tf, decode_before_sum_params_tf,
         decode_after_sum_params_tf) = encoder_tf.get_params()

        # Python-constant params are hidden from the user and made statically
        # available where needed; TF-valued params are exposed.
        self.assertLen(encode_params_py, 1)
        self.assertLen(encode_params_tf, 5)
        self.assertLen(decode_before_sum_params_py, 1)
        self.assertLen(decode_before_sum_params_tf, 3)
        self.assertEmpty(decode_after_sum_params_py)
        self.assertLen(decode_after_sum_params_tf, 2)
Esempio n. 4
0
 def test_add_child_parent_bad_key_raises(self):
     """Tests that add_child and add_parent raise KeyError for an invalid key."""
     encoder = core_encoder.EncoderComposer(
         test_utils.TimesTwoEncodingStage())
     for attach in (encoder.add_child, encoder.add_parent):
         with self.assertRaises(KeyError):
             attach(test_utils.PlusOneEncodingStage(), '___bad_key')
Esempio n. 5
0
    def test_tree_encoder(self):
        """Tests that the encoder works as a proper tree, not only a chain."""
        root = core_encoder.EncoderComposer(
            test_utils.SignIntFloatEncodingStage())
        root.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
        root.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
        floats_branch = root.add_child(test_utils.PlusOneEncodingStage(),
                                       SIF_FLOATS)
        floats_branch.add_child(test_utils.TimesTwoEncodingStage(), P1_VALS)
        encoder = root.make()

        x = tf.constant([0.0, 0.1, -0.1, 0.9, -0.9, 1.6, -2.2])
        state = encoder.initial_state()
        encode_params, decode_params = encoder.get_params(state)
        encoded_x, _, input_shapes = encoder.encode(x, encode_params)
        decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
        x_np, encoded_np, decoded_np = self.evaluate(
            [x, encoded_x, decoded_x])

        self.assertAllClose(x_np, decoded_np)
        # Signs are doubled, integer parts incremented, and fractional parts
        # incremented and then doubled.
        signs_doubled = np.array([0.0, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0])
        ints_plus_one = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 3.0])
        floats_plus_one_doubled = np.array(
            [2.0, 2.2, 2.2, 3.8, 3.8, 3.2, 2.4])
        expected_encoded_x = {
            SIF_SIGNS: {T2_VALS: signs_doubled},
            SIF_INTS: {P1_VALS: ints_plus_one},
            SIF_FLOATS: {P1_VALS: {T2_VALS: floats_plus_one_doubled}},
        }
        self.assertAllClose(expected_encoded_x, encoded_np)
Esempio n. 6
0
    def test_add_child_semantics(self):
        """Tests that add_child returns the child composer, not the parent."""
        parent = core_encoder.EncoderComposer(
            test_utils.TimesTwoEncodingStage())
        parent.add_child(test_utils.PlusOneEncodingStage(), T2_VALS)
        encoder_from_parent = parent.make()

        returned = core_encoder.EncoderComposer(
            test_utils.TimesTwoEncodingStage()).add_child(
                test_utils.PlusOneEncodingStage(), T2_VALS)
        encoder_from_returned = returned.make()

        # add_child hands back the newly created node, so calling make() on
        # its return value builds only the child node; the two trees must
        # therefore differ.
        self.assertNotEqual(encoder_from_parent.children.keys(),
                            encoder_from_returned.children.keys())
Esempio n. 7
0
    def test_decode_split_commutes_with_sum_true_false_true(self):
        """Tests that splitting decode works as expected with commutes_with_sum.

        This test chains three encoding stages: the first *does* commute with
        sum, the second *does not*, and the third *does* again. Together, only
        the first one should commute with sum, and the rest should not.
        """
        encoder = core_encoder.EncoderComposer(
            test_utils.TimesTwoEncodingStage()).add_parent(
                test_utils.PlusOneEncodingStage(),
                P1_VALS).add_parent(test_utils.TimesTwoEncodingStage(),
                                    T2_VALS).make()
        self.assertFalse(encoder.fully_commutes_with_sum)
        data = self._data_for_test_decode_split(encoder, tf.constant(3.0))

        # Test the encoding is as expected.
        self.assertEqual(data.x, data.decoded_x_after_sum)
        self.assertAllEqual(
            {
                T2_VALS: {
                    P1_VALS: {
                        T2_VALS: (data.x * 2.0 + 1.0) * 2.0
                    }
                },
            }, data.encoded_x)
        # Only first part commutes with sum.
        self.assertAllEqual({T2_VALS: data.x * 2.0}, data.decoded_x_before_sum)
        # The commuting structure marks only the root as commuting. The leaf
        # TimesTwo stage does not commute here because its parent (PlusOne)
        # does not, matching the docstring's "the rest should not".
        self.assertEqual(
            {
                COMMUTE: True,
                CHILDREN: {
                    T2_VALS: {
                        COMMUTE: False,
                        CHILDREN: {
                            P1_VALS: {
                                COMMUTE: False,
                                CHILDREN: {}
                            }
                        }
                    }
                }
            }, encoder.commuting_structure)
Esempio n. 8
0
    def test_composite_encoder(self):
        """Tests a composite, stateful `Encoder` over several iterations.

        The float branch ends in a PlusOneOverNEncodingStage, so its encoded
        value is expected to change with the iteration number.
        """
        x = tf.constant(1.2)
        composer = core_encoder.EncoderComposer(
            test_utils.SignIntFloatEncodingStage())
        composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
        composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
        floats_node = composer.add_child(test_utils.TimesTwoEncodingStage(),
                                         SIF_FLOATS)
        floats_node.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
        encoder = simple_encoder.SimpleEncoder(composer.make(),
                                               tf.TensorSpec.from_tensor(x))

        state = encoder.initial_state()
        iteration = _make_iteration_function(encoder)
        for step in range(1, 10):
            x, encoded_x, decoded_x, state = self.evaluate(iteration(x, state))
            self.assertAllClose(x, decoded_x)
            signs_val = _encoded_x_field(encoded_x,
                                         [TENSORS, SIF_SIGNS, T2_VALS])
            self.assertAllClose(2.0, signs_val)
            ints_val = _encoded_x_field(encoded_x, [TENSORS, SIF_INTS, P1_VALS])
            self.assertAllClose(2.0, ints_val)
            # Fractional part 0.2 doubled to 0.4, plus 1/step.
            floats_val = _encoded_x_field(
                encoded_x, [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS])
            self.assertAllClose(0.4 + 1 / step, floats_val)
    def test_full_commutativity_with_sum(self):
        """Tests that fully commutes with sum property works.

        A single TimesTwo stage and a chain of two TimesTwo stages are
        expected to fully commute with sum; the composite SignIntFloat tree
        built at the end is not.
        """
        spec = tf.TensorSpec((2, ), tf.float32)

        single_stage = core_encoder.EncoderComposer(
            test_utils.TimesTwoEncodingStage()).make()
        gathered = gather_encoder.GatherEncoder.from_encoder(single_stage, spec)
        self.assertTrue(gathered.fully_commutes_with_sum)

        two_stage_chain = core_encoder.EncoderComposer(
            test_utils.TimesTwoEncodingStage()).add_parent(
                test_utils.TimesTwoEncodingStage(), T2_VALS).make()
        gathered = gather_encoder.GatherEncoder.from_encoder(
            two_stage_chain, spec)
        self.assertTrue(gathered.fully_commutes_with_sum)

        tree = core_encoder.EncoderComposer(
            test_utils.SignIntFloatEncodingStage())
        tree.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
        tree.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
        floats_node = tree.add_child(test_utils.TimesTwoEncodingStage(),
                                     SIF_FLOATS)
        floats_node.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
        gathered = gather_encoder.GatherEncoder.from_encoder(tree.make(), spec)
        self.assertFalse(gathered.fully_commutes_with_sum)
Esempio n. 10
0
    def test_commutes_with_sum(self):
        """Tests that commutativity works, provided appropriate num_summands."""
        encoder = core_encoder.EncoderComposer(
            test_utils.PlusOneEncodingStage()).add_parent(
                test_utils.SimpleLinearEncodingStage(2.0, 3.0),
                SL_VALS).make()

        x = tf.constant(3.0)
        initial_state = encoder.initial_state()
        encode_params, decode_params = encoder.get_params(initial_state)
        encoded_x, _, input_shapes = encoder.encode(x, encode_params)
        decoded_x_before_sum = encoder.decode_before_sum(
            encoded_x, decode_params, input_shapes)

        # Simulate summing three identical copies of the partially decoded
        # values; decode_after_sum is told how many summands there were.
        summed = tf.nest.map_structure(lambda t: t + t + t,
                                       decoded_x_before_sum)
        num_summands = 3
        decoded_x_after_sum = encoder.decode_after_sum(
            summed, decode_params, num_summands, input_shapes)

        data = self.evaluate(
            self.commutation_test_data(x, encoded_x, decoded_x_before_sum,
                                       decoded_x_after_sum))
        self.assertEqual(3.0, data.x)
        # Root SimpleLinear(2.0, 3.0) gives 2x + 3; the child PlusOne adds 1.
        self.assertAllEqual({SL_VALS: {P1_VALS: (data.x * 2.0 + 3.0) + 1.0}},
                            data.encoded_x)
        # Only the PlusOne stage is undone before summation.
        self.assertAllEqual({SL_VALS: data.x * 2.0 + 3.0},
                            data.decoded_x_before_sum)
        # Three summands of 3.0.
        self.assertEqual(9.0, data.decoded_x_after_sum)
Esempio n. 11
0
 def test_as_adaptive_encoding_stage_identity(self):
     """Tests that an already-adaptive stage is returned as-is, not re-wrapped."""
     stage = encoding_stage.NoneStateAdaptiveEncodingStage(
         test_utils.PlusOneEncodingStage())
     self.assertIs(stage, encoding_stage.as_adaptive_encoding_stage(stage))
Esempio n. 12
0
    def test_adaptive_stage_using_state_update_tensors(self):
        """Tests adaptive encoding stage with state update tensors.

        Runs four encode/decode/update_state rounds, threading the updated
        state of each round into the next.
        """
        composer = core_encoder.EncoderComposer(
            test_utils.AdaptiveNormalizeEncodingStage())
        encoder = composer.add_parent(test_utils.PlusOneEncodingStage(),
                                      P1_VALS).make()
        x = tf.constant(1.0)
        state = encoder.initial_state()

        for _ in range(4):
            previous_state = state
            encode_params, decode_params = encoder.get_params(previous_state)
            encoded_x, update_tensors, input_shapes = encoder.encode(
                x, encode_params)
            decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
            state = encoder.update_state(previous_state, update_tensors)
            data = self.evaluate(
                test_utils.TestData(x, encoded_x, decoded_x, previous_state,
                                    update_tensors, state))

            self.assertAllClose(data.x, data.decoded_x)
            factor_state = (
                data.initial_state[CHILDREN][P1_VALS][STATE][AN_FACTOR_STATE])
            self.assertLessEqual(factor_state, 1.0)
            norm_update = (
                data.state_update_tensors[CHILDREN][P1_VALS][TENSORS]
                [AN_NORM_UPDATE])
            self.assertEqual(norm_update, 2.0)
            self.assertLessEqual(data.encoded_x[P1_VALS][AN_VALS], 2.0)
  def test_basic_encode_decode(self):
    """Tests a single PlusOne stage round-trips through SimpleEncoder."""
    inner_encoder = core_encoder.EncoderComposer(
        test_utils.PlusOneEncodingStage()).make()
    encoder = simple_encoder.SimpleEncoder(inner_encoder)

    x = tf.constant(1.0)
    encoded_x, decode_fn = encoder.encode(x)
    decoded_x = decode_fn(encoded_x)

    x_val, encoded_val, decoded_val = self.evaluate([x, encoded_x, decoded_x])
    self.assertAllClose(x_val, decoded_val)
    # PlusOne maps 1.0 to 2.0.
    self.assertAllClose(2.0, _encoded_x_field(encoded_val, [TENSORS, P1_VALS]))
  def test_modifying_encoded_x_raises(self):
    """Tests decode_fn raises if the encoded_x dictionary is modified."""
    inner_encoder = core_encoder.EncoderComposer(
        test_utils.PlusOneEncodingStage()).make()
    encoder = simple_encoder.SimpleEncoder(inner_encoder)

    encoded_x, decode_fn = encoder.encode(tf.constant(1.0))
    # An extra, unexpected key must be rejected...
    encoded_x['__NOT_EXPECTED_KEY__'] = None
    with self.assertRaises(ValueError):
      decode_fn(encoded_x)
    # ...and so must a dictionary missing the expected entries.
    with self.assertRaises(ValueError):
      decode_fn({})
  def test_encode_multiple_objects(self):
    """Tests one SimpleEncoder can encode inputs of several different shapes."""
    encoder = simple_encoder.SimpleEncoder(
        core_encoder.EncoderComposer(test_utils.PlusOneEncodingStage()).make())

    shapes = [(2,), (2, 3), (2, 3, 4)]
    for shape in shapes:
      ones = np.ones(shape, np.float32)
      x = tf.constant(ones)
      encoded_x, decode_fn = encoder.encode(x)
      decoded_x = decode_fn(encoded_x)

      x_val, encoded_val, decoded_val = self.evaluate(
          [x, encoded_x, decoded_x])
      self.assertAllClose(x_val, decoded_val)
      self.assertAllClose(2.0 * ones,
                          _encoded_x_field(encoded_val, [TENSORS, P1_VALS]))
Esempio n. 16
0
    def test_encoder_is_reusable(self):
        """Tests that the same encoder can be used to encode multiple objects."""
        encoder = core_encoder.EncoderComposer(
            test_utils.PlusOneEncodingStage()).add_parent(
                test_utils.TimesTwoEncodingStage(), T2_VALS).make()
        shapes = [(3, ), (3, 4), (3, 4, 5)]
        inputs = [tf.random.normal(shape) for shape in shapes]

        for x in inputs:
            encode_params, decode_params = encoder.get_params(
                encoder.initial_state())
            encoded_x, _, input_shapes = encoder.encode(x, encode_params)
            decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
            x_val, encoded_val, decoded_val = self.evaluate(
                [x, encoded_x, decoded_x])

            self.assertAllClose(x_val, decoded_val)
            # Root TimesTwo doubles x, then the child PlusOne adds one.
            expected = {T2_VALS: {P1_VALS: x_val * 2.0 + 1.0}}
            self.assertAllClose(expected, encoded_val)
Esempio n. 17
0
    def test_add_parent(self):
        """Tests that add_parent builds the tree from the leaf upwards.

        Expected tree: TimesTwo -> [T2_VALS] PlusOne -> [P1_VALS] ReduceMean.
        """
        encoder = core_encoder.EncoderComposer(
            test_utils.ReduceMeanEncodingStage()).add_parent(
                test_utils.PlusOneEncodingStage(), P1_VALS).add_parent(
                    test_utils.TimesTwoEncodingStage(), T2_VALS).make()

        node = encoder
        self.assertIsInstance(node, core_encoder.Encoder)
        self.assertIsInstance(node.stage._wrapped_stage,
                              test_utils.TimesTwoEncodingStage)
        descent = (
            (T2_VALS, test_utils.PlusOneEncodingStage),
            (P1_VALS, test_utils.ReduceMeanEncodingStage),
        )
        for key, expected_stage_type in descent:
            node = node.children[key]
            self.assertIsInstance(node, core_encoder.Encoder)
            self.assertIsInstance(node.stage._wrapped_stage,
                                  expected_stage_type)
    def test_composite_encoder(self):
        """Tests a composite `GatherEncoder` with several summands and rounds."""
        x_fn = lambda: tf.constant(1.2)
        composer = core_encoder.EncoderComposer(
            test_utils.SignIntFloatEncodingStage())
        composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
        composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
        floats_node = composer.add_child(test_utils.TimesTwoEncodingStage(),
                                         SIF_FLOATS)
        floats_node.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
        encoder = gather_encoder.GatherEncoder.from_encoder(
            composer.make(), tf.TensorSpec.from_tensor(x_fn()))

        num_summands = 3
        iteration = _make_iteration_function(encoder, x_fn, num_summands)
        state = encoder.initial_state()

        for step in range(1, 5):
            data = self.evaluate(iteration(state))
            for j in range(num_summands):
                encoded_j = data.encoded_x[j]
                self.assertAllClose(
                    2.0,
                    _encoded_x_field(encoded_j, [TENSORS, SIF_SIGNS, T2_VALS]))
                self.assertAllClose(
                    2.0,
                    _encoded_x_field(encoded_j, [TENSORS, SIF_INTS, P1_VALS]))
                # Fractional part 0.2 doubled to 0.4, plus 1/step.
                self.assertAllClose(
                    0.4 + 1 / step,
                    _encoded_x_field(encoded_j,
                                     [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS]))
                self.assertAllClose(data.x[j], data.part_decoded_x[j])
                expected_total = data.x[j] * num_summands
                self.assertAllClose(expected_total, data.summed_part_decoded_x)
                self.assertAllClose(expected_total, data.decoded_x)

            # State is expected to advance from (step,) to (step + 1,).
            self.assertEqual((step, ), data.initial_state)
            self.assertEqual((step + 1, ), data.updated_state)
            state = data.updated_state
Esempio n. 19
0
 def test_add_child_repeat_key_raises(self):
     """Tests that attaching a second child at an occupied key raises KeyError."""
     composer = core_encoder.EncoderComposer(
         test_utils.TimesTwoEncodingStage())
     composer.add_child(test_utils.PlusOneEncodingStage(), T2_VALS)
     # T2_VALS is already taken by the first child.
     with self.assertRaises(KeyError):
         composer.add_child(test_utils.PlusOneEncodingStage(), T2_VALS)
 def default_encoding_stage(self):
   """See base class.

   Builds and returns a new `test_utils.PlusOneEncodingStage` on every call.
   """
   stage = test_utils.PlusOneEncodingStage()
   return stage
 def test_is_adaptive_stage(self):
   """Tests `is_adaptive_stage` on a plain and on an adaptive stage."""
   plain_stage = test_utils.PlusOneEncodingStage()
   self.assertFalse(test_utils.is_adaptive_stage(plain_stage))
   adaptive_stage = test_utils.PlusOneOverNEncodingStage()
   self.assertTrue(test_utils.is_adaptive_stage(adaptive_stage))
Esempio n. 22
0
    def test_correct_structure(self):
        """Tests that structured objects look like what they should.

        This test creates the following encoding tree:
        SignIntFloatEncodingStage
            [SIF_SIGNS] -> TimesTwoEncodingStage
            [SIF_INTS] -> PlusOneEncodingStage
            [SIF_FLOATS] -> PlusOneOverNEncodingStage
                [PN_VALS] -> AdaptiveNormalizeEncodingStage
        It then verifies that the structured objects created by the methods of
        `Encoder` (initial/updated state, encode/decode params, encoded values,
        state update tensors, input shapes and commuting structure) have the
        expected nested structure. It also checks that selected tensors carry
        names containing the expected scope path
        '<method scope>/<stage name>/<key>/.../<stage name><suffix>'.
        """
        # Keep a reference to every stage so its `name` can be used when
        # building the expected scope strings below.
        sif_stage = test_utils.SignIntFloatEncodingStage()
        times_two_stage = test_utils.TimesTwoEncodingStage()
        plus_one_stage = test_utils.PlusOneEncodingStage()
        # NOTE(review): despite its name, this variable holds a
        # PlusOneOverNEncodingStage (see the tree in the docstring).
        plus_n_squared_stage = test_utils.PlusOneOverNEncodingStage()
        adaptive_normalize_stage = test_utils.AdaptiveNormalizeEncodingStage()

        encoder = core_encoder.EncoderComposer(sif_stage)
        encoder.add_child(times_two_stage, SIF_SIGNS)
        encoder.add_child(plus_one_stage, SIF_INTS)
        encoder.add_child(plus_n_squared_stage,
                          SIF_FLOATS).add_child(adaptive_normalize_stage,
                                                PN_VALS)
        encoder = encoder.make()

        # Create all intermediary objects.
        # NOTE(review): the name checks below presumably rely on each Encoder
        # method being called exactly once here (TF graph mode uniquifies
        # reused scope names) — keep that in mind when reordering or adding
        # calls.
        x = tf.constant(1.0)
        initial_state = encoder.initial_state()
        encode_params, decode_params = encoder.get_params(initial_state)
        encoded_x, state_update_tensors, input_shapes = encoder.encode(
            x, encode_params)
        decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
        updated_state = encoder.update_state(initial_state,
                                             state_update_tensors)
        commuting_structure = encoder.commuting_structure

        # Verify the structure and naming of those objects is as expected.
        # State: only the two adaptive stages (PlusOneOverN and
        # AdaptiveNormalize) are expected to carry state entries.
        for state in [initial_state, updated_state]:
            tf.nest.assert_same_structure(
                {
                    STATE: {},
                    CHILDREN: {
                        SIF_INTS: {
                            STATE: {},
                            CHILDREN: {}
                        },
                        SIF_SIGNS: {
                            STATE: {},
                            CHILDREN: {}
                        },
                        SIF_FLOATS: {
                            STATE: {
                                PN_ITER_STATE: None
                            },
                            CHILDREN: {
                                PN_VALS: {
                                    STATE: {
                                        AN_FACTOR_STATE: None
                                    },
                                    CHILDREN: {}
                                }
                            }
                        }
                    }
                }, state)
        self.assertIn(
            'encoder_initial_state/' + sif_stage.name + '/' + SIF_FLOATS +
            '/' + plus_n_squared_stage.name +
            encoding_stage.INITIAL_STATE_SCOPE_SUFFIX,
            initial_state[CHILDREN][SIF_FLOATS][STATE][PN_ITER_STATE].name)
        self.assertIn(
            'encoder_initial_state/' + sif_stage.name + '/' + SIF_FLOATS +
            '/' + plus_n_squared_stage.name + '/' + PN_VALS + '/' +
            adaptive_normalize_stage.name +
            encoding_stage.INITIAL_STATE_SCOPE_SUFFIX, initial_state[CHILDREN]
            [SIF_FLOATS][CHILDREN][PN_VALS][STATE][AN_FACTOR_STATE].name)
        self.assertIn(
            'encoder_update_state/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
            plus_n_squared_stage.name +
            encoding_stage.UPDATE_STATE_SCOPE_SUFFIX,
            updated_state[CHILDREN][SIF_FLOATS][STATE][PN_ITER_STATE].name)
        self.assertIn(
            'encoder_update_state/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
            plus_n_squared_stage.name + '/' + PN_VALS + '/' +
            adaptive_normalize_stage.name +
            encoding_stage.UPDATE_STATE_SCOPE_SUFFIX, updated_state[CHILDREN]
            [SIF_FLOATS][CHILDREN][PN_VALS][STATE][AN_FACTOR_STATE].name)

        # Encode and decode params are expected to share one structure, with
        # one param per non-root stage.
        for params in [encode_params, decode_params]:
            tf.nest.assert_same_structure(
                {
                    PARAMS: {},
                    CHILDREN: {
                        SIF_INTS: {
                            PARAMS: {
                                P1_ADD_PARAM: None
                            },
                            CHILDREN: {}
                        },
                        SIF_SIGNS: {
                            PARAMS: {
                                T2_FACTOR_PARAM: None
                            },
                            CHILDREN: {}
                        },
                        SIF_FLOATS: {
                            PARAMS: {
                                PN_ADD_PARAM: None
                            },
                            CHILDREN: {
                                PN_VALS: {
                                    PARAMS: {
                                        AN_FACTOR_PARAM: None
                                    },
                                    CHILDREN: {}
                                }
                            }
                        }
                    }
                }, params)
            self.assertIn(
                'encoder_get_params/' + sif_stage.name + '/' + SIF_INTS + '/' +
                plus_one_stage.name + encoding_stage.GET_PARAMS_SCOPE_SUFFIX,
                params[CHILDREN][SIF_INTS][PARAMS][P1_ADD_PARAM].name)
            self.assertIn(
                'encoder_get_params/' + sif_stage.name + '/' + SIF_SIGNS +
                '/' + times_two_stage.name +
                encoding_stage.GET_PARAMS_SCOPE_SUFFIX,
                params[CHILDREN][SIF_SIGNS][PARAMS][T2_FACTOR_PARAM].name)
            self.assertIn(
                'encoder_get_params/' + sif_stage.name + '/' + SIF_FLOATS +
                '/' + plus_n_squared_stage.name +
                encoding_stage.GET_PARAMS_SCOPE_SUFFIX,
                params[CHILDREN][SIF_FLOATS][PARAMS][PN_ADD_PARAM].name)
            # Note: we do not check the value of
            # params[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS][PARAMS][AN_FACTOR_PARAM]
            # because the get_params method of adaptive_normalize_stage does not
            # modify the graph, only passes through the provided state tensor.

        # Encoded values are nested by key only (no TENSORS/CHILDREN wrappers),
        # one leaf per terminal stage.
        tf.nest.assert_same_structure(
            {
                SIF_INTS: {
                    P1_VALS: None
                },
                SIF_SIGNS: {
                    T2_VALS: None
                },
                SIF_FLOATS: {
                    PN_VALS: {
                        AN_VALS: None
                    }
                }
            }, encoded_x)
        self.assertIn(
            'encoder_encode/' + sif_stage.name + '/' + SIF_INTS + '/' +
            plus_one_stage.name + encoding_stage.ENCODE_SCOPE_SUFFIX,
            encoded_x[SIF_INTS][P1_VALS].name)
        self.assertIn(
            'encoder_encode/' + sif_stage.name + '/' + SIF_SIGNS + '/' +
            times_two_stage.name + encoding_stage.ENCODE_SCOPE_SUFFIX,
            encoded_x[SIF_SIGNS][T2_VALS].name)
        self.assertIn(
            'encoder_encode/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
            plus_n_squared_stage.name + '/' + PN_VALS + '/' +
            adaptive_normalize_stage.name + encoding_stage.ENCODE_SCOPE_SUFFIX,
            encoded_x[SIF_FLOATS][PN_VALS][AN_VALS].name)

        # Only AdaptiveNormalize is expected to emit a state update tensor;
        # the aggregation modes must mirror that structure.
        tf.nest.assert_same_structure(
            {
                TENSORS: {},
                CHILDREN: {
                    SIF_INTS: {
                        TENSORS: {},
                        CHILDREN: {}
                    },
                    SIF_SIGNS: {
                        TENSORS: {},
                        CHILDREN: {}
                    },
                    SIF_FLOATS: {
                        TENSORS: {},
                        CHILDREN: {
                            PN_VALS: {
                                TENSORS: {
                                    AN_NORM_UPDATE: None
                                },
                                CHILDREN: {}
                            }
                        }
                    }
                }
            }, state_update_tensors)
        tf.nest.assert_same_structure(state_update_tensors,
                                      encoder.state_update_aggregation_modes)
        self.assertIn(
            'encoder_encode/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
            plus_n_squared_stage.name + '/' + PN_VALS + '/' +
            adaptive_normalize_stage.name + encoding_stage.ENCODE_SCOPE_SUFFIX,
            state_update_tensors[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS]
            [TENSORS][AN_NORM_UPDATE].name)

        # Every node in the tree records one SHAPE entry.
        tf.nest.assert_same_structure(
            {
                SHAPE: None,
                CHILDREN: {
                    SIF_INTS: {
                        SHAPE: None,
                        CHILDREN: {}
                    },
                    SIF_SIGNS: {
                        SHAPE: None,
                        CHILDREN: {}
                    },
                    SIF_FLOATS: {
                        SHAPE: None,
                        CHILDREN: {
                            PN_VALS: {
                                SHAPE: None,
                                CHILDREN: {}
                            }
                        }
                    }
                }
            }, input_shapes)
        # Decoding collapses the tree back into a single tensor.
        self.assertTrue(tf.is_tensor(decoded_x))
        self.assertIn('encoder_decode/', decoded_x.name)

        # Every node records a COMMUTE flag, and in this tree none of them is
        # expected to be True.
        tf.nest.assert_same_structure(
            {
                COMMUTE: None,
                CHILDREN: {
                    SIF_INTS: {
                        COMMUTE: None,
                        CHILDREN: {}
                    },
                    SIF_SIGNS: {
                        COMMUTE: None,
                        CHILDREN: {}
                    },
                    SIF_FLOATS: {
                        COMMUTE: None,
                        CHILDREN: {
                            PN_VALS: {
                                COMMUTE: None,
                                CHILDREN: {}
                            }
                        }
                    }
                }
            }, commuting_structure)
        for item in tf.nest.flatten(commuting_structure):
            self.assertEqual(False, item)