Exemple #1
0
    def test_adaptive_stage(self):
        """Checks a parent/child pair of adaptive `PlusOneOverNEncodingStage`s.

        Every iteration runs the full get_params -> encode -> decode ->
        update_state cycle. It then verifies the iteration counter of both
        stages, the (empty) state update tensors, and the doubly-encoded value.
        """
        encoder = core_encoder.EncoderComposer(
            test_utils.PlusOneOverNEncodingStage()).add_parent(
                test_utils.PlusOneOverNEncodingStage(), PN_VALS).make()
        x = tf.constant(1.0)
        state = encoder.initial_state()

        for step in range(1, 5):
            previous_state = state
            encode_params, decode_params = encoder.get_params(previous_state)
            encoded_x, update_tensors, input_shapes = encoder.encode(
                x, encode_params)
            decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
            state = encoder.update_state(previous_state, update_tensors)
            data = self.evaluate(
                test_utils.TestData(x, encoded_x, decoded_x, previous_state,
                                    update_tensors, state))

            # Both stages must report the same iteration counter.
            expected_initial_state = {
                STATE: {PN_ITER_STATE: step},
                CHILDREN: {
                    PN_VALS: {STATE: {PN_ITER_STATE: step}, CHILDREN: {}},
                },
            }
            # Neither stage emits any state update tensors.
            expected_state_update_tensors = {
                TENSORS: {},
                CHILDREN: {PN_VALS: {TENSORS: {}, CHILDREN: {}}},
            }

            self.assertAllClose(data.x, data.decoded_x)
            self.assertAllEqual(expected_initial_state, data.initial_state)
            self.assertDictEqual(expected_state_update_tensors,
                                 data.state_update_tensors)
            # Two stacked PlusOneOverN stages: the expected encoding is x + 2 / step.
            self.assertAllClose(data.x + 2 * 1 / step,
                                data.encoded_x[PN_VALS][PN_VALS])
Exemple #2
0
    def test_composite_encoder(self):
        """Runs `SimpleEncoder` over a multi-level composite encoder tree.

        Tree: SignIntFloatEncodingStage with TimesTwo on signs, PlusOne on
        ints, and TimesTwo -> PlusOneOverN on floats.
        """
        x = tf.constant(1.2)
        composer = core_encoder.EncoderComposer(
            test_utils.SignIntFloatEncodingStage())
        composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
        composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
        floats_branch = composer.add_child(test_utils.TimesTwoEncodingStage(),
                                           SIF_FLOATS)
        floats_branch.add_child(test_utils.PlusOneOverNEncodingStage(),
                                T2_VALS)
        encoder = simple_encoder.SimpleEncoder(composer.make(),
                                               tf.TensorSpec.from_tensor(x))

        state = encoder.initial_state()
        iteration = _make_iteration_function(encoder)
        for step in range(1, 10):
            x, encoded_x, decoded_x, state = self.evaluate(iteration(x, state))
            self.assertAllClose(x, decoded_x)
            signs_field = _encoded_x_field(encoded_x,
                                           [TENSORS, SIF_SIGNS, T2_VALS])
            ints_field = _encoded_x_field(encoded_x,
                                          [TENSORS, SIF_INTS, P1_VALS])
            floats_field = _encoded_x_field(
                encoded_x, [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS])
            self.assertAllClose(2.0, signs_field)
            self.assertAllClose(2.0, ints_field)
            self.assertAllClose(0.4 + 1 / step, floats_field)
  def test_composite_encoder(self):
    """Runs `StatefulSimpleEncoder` over a multi-level composite encoder tree.

    The encode/decode tensors are built once and evaluated repeatedly. The
    PlusOneOverN contribution is expected to change as 1 / step between
    evaluations.
    """
    composer = core_encoder.EncoderComposer(
        test_utils.SignIntFloatEncodingStage())
    composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
    composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
    floats_branch = composer.add_child(test_utils.TimesTwoEncodingStage(),
                                       SIF_FLOATS)
    floats_branch.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
    encoder = simple_encoder.StatefulSimpleEncoder(composer.make())

    x = tf.constant(1.2)
    encoder.initialize()
    encoded_x, decode_fn = encoder.encode(x)
    decoded_x = decode_fn(encoded_x)

    self.evaluate(tf.global_variables_initializer())
    fetches = [x, encoded_x, decoded_x]
    for step in range(1, 10):
      x_val, encoded_val, decoded_val = self.evaluate(fetches)
      self.assertAllClose(x_val, decoded_val)
      self.assertAllClose(
          2.0, _encoded_x_field(encoded_val, [TENSORS, SIF_SIGNS, T2_VALS]))
      self.assertAllClose(
          2.0, _encoded_x_field(encoded_val, [TENSORS, SIF_INTS, P1_VALS]))
      self.assertAllClose(
          0.4 + 1 / step,
          _encoded_x_field(encoded_val,
                           [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS]))
 def test_not_fully_defined_shape_raises(self):
     """Checks `from_encoder` rejects a TensorSpec with an unknown dimension."""
     encoder = core_encoder.EncoderComposer(
         test_utils.PlusOneOverNEncodingStage()).make()
     partially_known_spec = tf.TensorSpec((None,), tf.float32)
     with self.assertRaisesRegex(TypeError, 'fully defined'):
         gather_encoder.GatherEncoder.from_encoder(encoder,
                                                   partially_known_spec)
Exemple #5
0
 def test_input_tensorspec(self):
     """Checks `SimpleEncoder.input_tensorspec` is compatible with its source tensor."""
     x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
     base_encoder = core_encoder.EncoderComposer(
         test_utils.PlusOneOverNEncodingStage()).make()
     spec = tf.TensorSpec.from_tensor(x)
     encoder = simple_encoder.SimpleEncoder(base_encoder, spec)
     self.assertTrue(encoder.input_tensorspec.is_compatible_with(x))
    def test_full_commutativity_with_sum(self):
        """Checks `fully_commutes_with_sum` for three encoder trees.

        A single TimesTwo stage and two chained TimesTwo stages are expected
        to commute with sum. The composite SignIntFloat tree is not.
        """
        spec = tf.TensorSpec((2,), tf.float32)

        single_stage = core_encoder.EncoderComposer(
            test_utils.TimesTwoEncodingStage()).make()
        self.assertTrue(
            gather_encoder.GatherEncoder.from_encoder(
                single_stage, spec).fully_commutes_with_sum)

        two_stages = core_encoder.EncoderComposer(
            test_utils.TimesTwoEncodingStage()).add_parent(
                test_utils.TimesTwoEncodingStage(), T2_VALS).make()
        self.assertTrue(
            gather_encoder.GatherEncoder.from_encoder(
                two_stages, spec).fully_commutes_with_sum)

        composite = core_encoder.EncoderComposer(
            test_utils.SignIntFloatEncodingStage())
        composite.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
        composite.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
        floats_branch = composite.add_child(
            test_utils.TimesTwoEncodingStage(), SIF_FLOATS)
        floats_branch.add_child(test_utils.PlusOneOverNEncodingStage(),
                                T2_VALS)
        self.assertFalse(
            gather_encoder.GatherEncoder.from_encoder(
                composite.make(), spec).fully_commutes_with_sum)
 def test_not_a_tensorspec_raises(self, not_a_tensorspec):
     """Checks `from_encoder` raises TypeError for a non-TensorSpec argument."""
     base_encoder = core_encoder.EncoderComposer(
         test_utils.PlusOneOverNEncodingStage()).make()
     with self.assertRaisesRegex(TypeError, 'TensorSpec'):
         gather_encoder.GatherEncoder.from_encoder(base_encoder,
                                                   not_a_tensorspec)
 def test_uninitialized_encode_raises(self):
   """Checks `encode` fails on a `StatefulSimpleEncoder` that was never initialized."""
   stateful_encoder = simple_encoder.StatefulSimpleEncoder(
       core_encoder.EncoderComposer(
           test_utils.PlusOneOverNEncodingStage()).make())
   value = tf.constant(1.0)
   with self.assertRaisesRegex(RuntimeError, 'not been initialized'):
     stateful_encoder.encode(value)
 def test_input_tensorspec(self):
     """Checks `GatherEncoder.input_tensorspec` is compatible with its source tensor."""
     x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
     base_encoder = core_encoder.EncoderComposer(
         test_utils.PlusOneOverNEncodingStage()).make()
     encoder = gather_encoder.GatherEncoder.from_encoder(
         base_encoder, tf.TensorSpec.from_tensor(x))
     self.assertTrue(encoder.input_tensorspec.is_compatible_with(x))
 def test_multiple_initialize_raises(self):
   """Checks a second `initialize` call on the same stateful encoder raises."""
   stateful_encoder = simple_encoder.StatefulSimpleEncoder(
       core_encoder.EncoderComposer(
           test_utils.PlusOneOverNEncodingStage()).make())
   stateful_encoder.initialize()  # The first call is expected to succeed.
   with self.assertRaisesRegex(RuntimeError, 'already initialized'):
     stateful_encoder.initialize()
 def test_multiple_encode_raises(self):
   """Checks `encode` on a stateful encoder may be called at most once."""
   stateful_encoder = simple_encoder.StatefulSimpleEncoder(
       core_encoder.EncoderComposer(
           test_utils.PlusOneOverNEncodingStage()).make())
   stateful_encoder.initialize()
   value = tf.constant(1.0)
   stateful_encoder.encode(value)  # The first call is expected to succeed.
   with self.assertRaisesRegex(RuntimeError, 'only once'):
     stateful_encoder.encode(value)
  def test_modifying_encoded_x_raises(self):
    """Checks `decode_fn` rejects an `encoded_x` whose keys were changed.

    Both an extra key and a dictionary with every key missing must raise
    ValueError.
    """
    encoder = simple_encoder.SimpleEncoder(
        core_encoder.EncoderComposer(
            test_utils.PlusOneOverNEncodingStage()).make())

    x = tf.constant(1.0)
    encoded_x, decode_fn = encoder.encode(x)
    encoded_x['__NOT_EXPECTED_KEY__'] = None  # Inject an unexpected key.
    for tampered in (encoded_x, {}):
      with self.assertRaises(ValueError):
        decode_fn(tampered)
Exemple #13
0
    def test_basic_encode_decode(self):
        """Runs `SimpleEncoder` with a single PlusOneOverN stage for 9 steps.

        Decoding must recover x. At step i the encoded value must equal
        1.0 + 1 / i.
        """
        x = tf.constant(1.0, tf.float32)
        stage_encoder = core_encoder.EncoderComposer(
            test_utils.PlusOneOverNEncodingStage()).make()
        encoder = simple_encoder.SimpleEncoder(stage_encoder,
                                               tf.TensorSpec.from_tensor(x))

        state = encoder.initial_state()
        iteration = _make_iteration_function(encoder)
        for step in range(1, 10):
            x, encoded_x, decoded_x, state = self.evaluate(iteration(x, state))
            self.assertAllClose(x, decoded_x)
            encoded_value = _encoded_x_field(encoded_x, [TENSORS, PN_VALS])
            self.assertAllClose(1.0 + 1 / step, encoded_value)
  def test_basic_encode_decode(self):
    """Runs `StatefulSimpleEncoder` with a single PlusOneOverN stage.

    The same tensors are evaluated 9 times. Decoding must recover x, and at
    evaluation i the encoded value must equal 1.0 + 1 / i.
    """
    stage_encoder = core_encoder.EncoderComposer(
        test_utils.PlusOneOverNEncodingStage()).make()
    encoder = simple_encoder.StatefulSimpleEncoder(stage_encoder)

    x = tf.constant(1.0)
    encoder.initialize()
    encoded_x, decode_fn = encoder.encode(x)
    decoded_x = decode_fn(encoded_x)

    self.evaluate(tf.global_variables_initializer())
    fetches = [x, encoded_x, decoded_x]
    for step in range(1, 10):
      x_val, encoded_val, decoded_val = self.evaluate(fetches)
      self.assertAllClose(x_val, decoded_val)
      self.assertAllClose(1.0 + 1 / step,
                          _encoded_x_field(encoded_val, [TENSORS, PN_VALS]))
Exemple #15
0
    def test_input_signature_enforced(self):
        """Tests that encode/decode input signature is enforced.

        Each malformed input is built *before* entering its `assertRaises`
        context. Only the encoder call under test can then satisfy the
        assertion: a ValueError raised while constructing the bad input
        would otherwise make the check pass vacuously.
        """
        x = tf.constant(1.0)
        encoder = simple_encoder.SimpleEncoder(
            core_encoder.EncoderComposer(
                test_utils.PlusOneOverNEncodingStage()).make(),
            tf.TensorSpec.from_tensor(x))

        state = encoder.initial_state()

        # `x` with a shape that does not match the encoder's TensorSpec.
        bad_x = tf.stack([x, x])
        with self.assertRaises(ValueError):
            encoder.encode(bad_x, state)

        # State with an extra element appended.
        bad_state = state + (x, )
        with self.assertRaises(ValueError):
            encoder.encode(x, bad_state)

        # `encode` returns a pair whose first element is the encoded
        # dictionary; only that dictionary is a valid input to `decode`.
        # Passing the whole pair to `dict(...)` could itself raise ValueError
        # and hide whether `decode` validates its input.
        encoded_x, _ = encoder.encode(x, state)
        bad_encoded_x = dict(encoded_x)
        bad_encoded_x.update({'x': x})
        with self.assertRaises(ValueError):
            encoder.decode(bad_encoded_x)
    def test_none_state_equal_to_initial_state(self):
        """Checks that omitting the state behaves like passing `initial_state()`."""
        def make_x():
            return tf.constant(1.0)

        encoder = gather_encoder.GatherEncoder.from_encoder(
            core_encoder.EncoderComposer(
                test_utils.PlusOneOverNEncodingStage()).make(),
            tf.TensorSpec.from_tensor(make_x()))

        num_summands = 3
        with_state_fn = _make_iteration_function(encoder, make_x, num_summands)
        state = encoder.initial_state()
        without_state_fn = _make_stateless_iteration_function(
            encoder, make_x, num_summands)

        with_state = self.evaluate(with_state_fn(state))
        without_state = self.evaluate(without_state_fn())

        self.assertAllClose(with_state.encoded_x, without_state.encoded_x)
        self.assertAllClose(with_state.decoded_x, without_state.decoded_x)
    def test_basic_encode_decode(self):
        """Runs a `GatherEncoder` over 3 random summands for 4 rounds.

        Each summand's encoded value must equal its input plus 1 / i. The
        state must advance from (i,) to (i + 1,) every round.
        """
        def make_x():
            return tf.random.uniform((12, ))

        encoder = gather_encoder.GatherEncoder.from_encoder(
            core_encoder.EncoderComposer(
                test_utils.PlusOneOverNEncodingStage()).make(),
            tf.TensorSpec.from_tensor(make_x()))

        num_summands = 3
        iteration = _make_iteration_function(encoder, make_x, num_summands)
        state = encoder.initial_state()

        for step in range(1, 5):
            data = self.evaluate(iteration(state))
            for j in range(num_summands):
                encoded_value = _encoded_x_field(data.encoded_x[j],
                                                 [TENSORS, PN_VALS])
                self.assertAllClose(data.x[j] + 1 / step, encoded_value)
            self.assertEqual((step, ), data.initial_state)
            self.assertEqual((step + 1, ), data.updated_state)
            state = data.updated_state
    def test_composite_encoder(self):
        """Runs a `GatherEncoder` built from a composite encoder tree.

        Tree: SignIntFloatEncodingStage with TimesTwo on signs, PlusOne on
        ints, and TimesTwo -> PlusOneOverN on floats. Each round checks the
        per-summand encodings and the partial and full decodings. It also
        checks that the state counter advances by one.
        """
        def make_x():
            return tf.constant(1.2)

        composer = core_encoder.EncoderComposer(
            test_utils.SignIntFloatEncodingStage())
        composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
        composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
        floats_branch = composer.add_child(test_utils.TimesTwoEncodingStage(),
                                           SIF_FLOATS)
        floats_branch.add_child(test_utils.PlusOneOverNEncodingStage(),
                                T2_VALS)
        encoder = gather_encoder.GatherEncoder.from_encoder(
            composer.make(), tf.TensorSpec.from_tensor(make_x()))

        num_summands = 3
        iteration = _make_iteration_function(encoder, make_x, num_summands)
        state = encoder.initial_state()

        for step in range(1, 5):
            data = self.evaluate(iteration(state))
            for j in range(num_summands):
                encoded_j = data.encoded_x[j]
                x_j = data.x[j]
                self.assertAllClose(
                    2.0,
                    _encoded_x_field(encoded_j, [TENSORS, SIF_SIGNS, T2_VALS]))
                self.assertAllClose(
                    2.0,
                    _encoded_x_field(encoded_j, [TENSORS, SIF_INTS, P1_VALS]))
                self.assertAllClose(
                    0.4 + 1 / step,
                    _encoded_x_field(encoded_j,
                                     [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS]))
                self.assertAllClose(x_j, data.part_decoded_x[j])
                # All summands are equal, so the aggregate is x_j * num_summands.
                self.assertAllClose(x_j * num_summands,
                                    data.summed_part_decoded_x)
                self.assertAllClose(x_j * num_summands, data.decoded_x)

            self.assertEqual((step, ), data.initial_state)
            self.assertEqual((step + 1, ), data.updated_state)
            state = data.updated_state
Exemple #19
0
    def test_none_state_equal_to_initial_state(self):
        """Checks `encode(x)` without state matches `encode(x, initial_state())`."""
        x = tf.constant(1.0)
        encoder = simple_encoder.SimpleEncoder(
            core_encoder.EncoderComposer(
                test_utils.PlusOneOverNEncodingStage()).make(),
            tf.TensorSpec.from_tensor(x))

        state = encoder.initial_state()
        stateful_iteration = _make_iteration_function(encoder)

        @tf.function
        def encode_decode_without_state(value):
            encoded, _ = encoder.encode(value)
            return encoded, encoder.decode(encoded)

        _, stateful_encoded, stateful_decoded, _ = self.evaluate(
            stateful_iteration(x, state))
        stateless_encoded, stateless_decoded = self.evaluate(
            encode_decode_without_state(x))

        self.assertAllClose(stateful_encoded, stateless_encoded)
        self.assertAllClose(stateful_decoded, stateless_decoded)
Exemple #20
0
    def test_correct_structure(self):
        """Tests that structured objects look like what they should.

        This test creates the following encoding tree:
        SignIntFloatEncodingStage
            [SIF_SIGNS] -> TimesTwoEncodingStage
            [SIF_INTS] -> PlusOneEncodingStage
            [SIF_FLOATS] -> PlusOneOverNEncodingStage
                [PN_VALS] -> AdaptiveNormalizeEncodingStage
        It verifies that the structured objects created by the methods of
        `Encoder` have the expected nesting. It also checks that the tensors
        inside them were created under the expected name scopes.
        """
        sif_stage = test_utils.SignIntFloatEncodingStage()
        times_two_stage = test_utils.TimesTwoEncodingStage()
        plus_one_stage = test_utils.PlusOneEncodingStage()
        # NOTE(review): despite its name, this is a PlusOneOverNEncodingStage.
        plus_n_squared_stage = test_utils.PlusOneOverNEncodingStage()
        adaptive_normalize_stage = test_utils.AdaptiveNormalizeEncodingStage()

        # Build the tree described in the docstring.
        encoder = core_encoder.EncoderComposer(sif_stage)
        encoder.add_child(times_two_stage, SIF_SIGNS)
        encoder.add_child(plus_one_stage, SIF_INTS)
        encoder.add_child(plus_n_squared_stage,
                          SIF_FLOATS).add_child(adaptive_normalize_stage,
                                                PN_VALS)
        encoder = encoder.make()

        # Create all intermediary objects.
        x = tf.constant(1.0)
        initial_state = encoder.initial_state()
        encode_params, decode_params = encoder.get_params(initial_state)
        encoded_x, state_update_tensors, input_shapes = encoder.encode(
            x, encode_params)
        decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
        updated_state = encoder.update_state(initial_state,
                                             state_update_tensors)
        commuting_structure = encoder.commuting_structure

        # Verify the structure and naming of those objects is as expected.
        # Initial and updated state share one structure. Only the
        # PlusOneOverN stage and its AdaptiveNormalize child have state
        # entries; the other stages have empty STATE dicts.
        for state in [initial_state, updated_state]:
            tf.nest.assert_same_structure(
                {
                    STATE: {},
                    CHILDREN: {
                        SIF_INTS: {
                            STATE: {},
                            CHILDREN: {}
                        },
                        SIF_SIGNS: {
                            STATE: {},
                            CHILDREN: {}
                        },
                        SIF_FLOATS: {
                            STATE: {
                                PN_ITER_STATE: None
                            },
                            CHILDREN: {
                                PN_VALS: {
                                    STATE: {
                                        AN_FACTOR_STATE: None
                                    },
                                    CHILDREN: {}
                                }
                            }
                        }
                    }
                }, state)
        # Tensor names must contain the path: encoder method scope, then
        # stage names interleaved with the child keys, then the stage's
        # method-specific scope suffix.
        self.assertIn(
            'encoder_initial_state/' + sif_stage.name + '/' + SIF_FLOATS +
            '/' + plus_n_squared_stage.name +
            encoding_stage.INITIAL_STATE_SCOPE_SUFFIX,
            initial_state[CHILDREN][SIF_FLOATS][STATE][PN_ITER_STATE].name)
        self.assertIn(
            'encoder_initial_state/' + sif_stage.name + '/' + SIF_FLOATS +
            '/' + plus_n_squared_stage.name + '/' + PN_VALS + '/' +
            adaptive_normalize_stage.name +
            encoding_stage.INITIAL_STATE_SCOPE_SUFFIX, initial_state[CHILDREN]
            [SIF_FLOATS][CHILDREN][PN_VALS][STATE][AN_FACTOR_STATE].name)
        self.assertIn(
            'encoder_update_state/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
            plus_n_squared_stage.name +
            encoding_stage.UPDATE_STATE_SCOPE_SUFFIX,
            updated_state[CHILDREN][SIF_FLOATS][STATE][PN_ITER_STATE].name)
        self.assertIn(
            'encoder_update_state/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
            plus_n_squared_stage.name + '/' + PN_VALS + '/' +
            adaptive_normalize_stage.name +
            encoding_stage.UPDATE_STATE_SCOPE_SUFFIX, updated_state[CHILDREN]
            [SIF_FLOATS][CHILDREN][PN_VALS][STATE][AN_FACTOR_STATE].name)

        # Encode and decode params share one structure. Every stage except
        # the SignIntFloat root exposes exactly one param.
        for params in [encode_params, decode_params]:
            tf.nest.assert_same_structure(
                {
                    PARAMS: {},
                    CHILDREN: {
                        SIF_INTS: {
                            PARAMS: {
                                P1_ADD_PARAM: None
                            },
                            CHILDREN: {}
                        },
                        SIF_SIGNS: {
                            PARAMS: {
                                T2_FACTOR_PARAM: None
                            },
                            CHILDREN: {}
                        },
                        SIF_FLOATS: {
                            PARAMS: {
                                PN_ADD_PARAM: None
                            },
                            CHILDREN: {
                                PN_VALS: {
                                    PARAMS: {
                                        AN_FACTOR_PARAM: None
                                    },
                                    CHILDREN: {}
                                }
                            }
                        }
                    }
                }, params)
            self.assertIn(
                'encoder_get_params/' + sif_stage.name + '/' + SIF_INTS + '/' +
                plus_one_stage.name + encoding_stage.GET_PARAMS_SCOPE_SUFFIX,
                params[CHILDREN][SIF_INTS][PARAMS][P1_ADD_PARAM].name)
            self.assertIn(
                'encoder_get_params/' + sif_stage.name + '/' + SIF_SIGNS +
                '/' + times_two_stage.name +
                encoding_stage.GET_PARAMS_SCOPE_SUFFIX,
                params[CHILDREN][SIF_SIGNS][PARAMS][T2_FACTOR_PARAM].name)
            self.assertIn(
                'encoder_get_params/' + sif_stage.name + '/' + SIF_FLOATS +
                '/' + plus_n_squared_stage.name +
                encoding_stage.GET_PARAMS_SCOPE_SUFFIX,
                params[CHILDREN][SIF_FLOATS][PARAMS][PN_ADD_PARAM].name)
            # Note: we do not check the value of
            # params[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS][PARAMS][AN_FACTOR_PARAM]
            # because the get_params method of adaptive_normalize_stage does not
            # modify the graph, only passes through the provided state tensor.

        # encoded_x is keyed by child keys only (no TENSORS/CHILDREN wrappers),
        # with one leaf tensor per terminal stage.
        tf.nest.assert_same_structure(
            {
                SIF_INTS: {
                    P1_VALS: None
                },
                SIF_SIGNS: {
                    T2_VALS: None
                },
                SIF_FLOATS: {
                    PN_VALS: {
                        AN_VALS: None
                    }
                }
            }, encoded_x)
        self.assertIn(
            'encoder_encode/' + sif_stage.name + '/' + SIF_INTS + '/' +
            plus_one_stage.name + encoding_stage.ENCODE_SCOPE_SUFFIX,
            encoded_x[SIF_INTS][P1_VALS].name)
        self.assertIn(
            'encoder_encode/' + sif_stage.name + '/' + SIF_SIGNS + '/' +
            times_two_stage.name + encoding_stage.ENCODE_SCOPE_SUFFIX,
            encoded_x[SIF_SIGNS][T2_VALS].name)
        self.assertIn(
            'encoder_encode/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
            plus_n_squared_stage.name + '/' + PN_VALS + '/' +
            adaptive_normalize_stage.name + encoding_stage.ENCODE_SCOPE_SUFFIX,
            encoded_x[SIF_FLOATS][PN_VALS][AN_VALS].name)

        # Only the adaptive normalize stage emits a state update tensor. The
        # structure must also line up with the aggregation modes.
        tf.nest.assert_same_structure(
            {
                TENSORS: {},
                CHILDREN: {
                    SIF_INTS: {
                        TENSORS: {},
                        CHILDREN: {}
                    },
                    SIF_SIGNS: {
                        TENSORS: {},
                        CHILDREN: {}
                    },
                    SIF_FLOATS: {
                        TENSORS: {},
                        CHILDREN: {
                            PN_VALS: {
                                TENSORS: {
                                    AN_NORM_UPDATE: None
                                },
                                CHILDREN: {}
                            }
                        }
                    }
                }
            }, state_update_tensors)
        tf.nest.assert_same_structure(state_update_tensors,
                                      encoder.state_update_aggregation_modes)
        self.assertIn(
            'encoder_encode/' + sif_stage.name + '/' + SIF_FLOATS + '/' +
            plus_n_squared_stage.name + '/' + PN_VALS + '/' +
            adaptive_normalize_stage.name + encoding_stage.ENCODE_SCOPE_SUFFIX,
            state_update_tensors[CHILDREN][SIF_FLOATS][CHILDREN][PN_VALS]
            [TENSORS][AN_NORM_UPDATE].name)

        # input_shapes records one SHAPE entry for every stage in the tree.
        tf.nest.assert_same_structure(
            {
                SHAPE: None,
                CHILDREN: {
                    SIF_INTS: {
                        SHAPE: None,
                        CHILDREN: {}
                    },
                    SIF_SIGNS: {
                        SHAPE: None,
                        CHILDREN: {}
                    },
                    SIF_FLOATS: {
                        SHAPE: None,
                        CHILDREN: {
                            PN_VALS: {
                                SHAPE: None,
                                CHILDREN: {}
                            }
                        }
                    }
                }
            }, input_shapes)
        # The decoded value is a single tensor created under 'encoder_decode/'.
        self.assertTrue(tf.is_tensor(decoded_x))
        self.assertIn('encoder_decode/', decoded_x.name)

        # commuting_structure mirrors the tree with one COMMUTE flag per stage.
        tf.nest.assert_same_structure(
            {
                COMMUTE: None,
                CHILDREN: {
                    SIF_INTS: {
                        COMMUTE: None,
                        CHILDREN: {}
                    },
                    SIF_SIGNS: {
                        COMMUTE: None,
                        CHILDREN: {}
                    },
                    SIF_FLOATS: {
                        COMMUTE: None,
                        CHILDREN: {
                            PN_VALS: {
                                COMMUTE: None,
                                CHILDREN: {}
                            }
                        }
                    }
                }
            }, commuting_structure)
        # Every flag in this tree is expected to be False.
        for item in tf.nest.flatten(commuting_structure):
            self.assertEqual(False, item)
 def test_is_adaptive_stage(self):
   """Checks `is_adaptive_stage` on one non-adaptive and one adaptive stage."""
   non_adaptive_stage = test_utils.PlusOneEncodingStage()
   self.assertFalse(test_utils.is_adaptive_stage(non_adaptive_stage))
   adaptive_stage = test_utils.PlusOneOverNEncodingStage()
   self.assertTrue(test_utils.is_adaptive_stage(adaptive_stage))
 def default_encoding_stage(self):
   """See base class.

   Returns a newly constructed `PlusOneOverNEncodingStage` on every call.
   """
   stage = test_utils.PlusOneOverNEncodingStage()
   return stage
 def test_not_a_tensorspec_raises(self, bad_tensorspec):
   """Checks `SimpleEncoderV2` raises TypeError for a non-TensorSpec argument."""
   base_encoder = core_encoder.EncoderComposer(
       test_utils.PlusOneOverNEncodingStage()).make()
   with self.assertRaisesRegex(TypeError, 'TensorSpec'):
     simple_encoder.SimpleEncoderV2(base_encoder, bad_tensorspec)