def testWrapperV2Caller(self, wrapper):
        """Checks that a V2 wrapper goes through the LayerRNNCell's caller.

        A legacy two-layer MultiRNNCell is built inside the Keras-style scope,
        wrapped with ``wrapper`` and invoked once. Afterwards the first inner
        cell must own exactly two weights, each with "_wrapper" in its name.
        """
        with legacy_base_layer.keras_style_scope():
            # Both inner cells and the stacking cell are created in-scope.
            inner_cells = [legacy_cells.BasicRNNCell(1) for _ in range(2)]
            base_cell = legacy_cells.MultiRNNCell(inner_cells)

        wrapped_cell = wrapper(base_cell)
        step_input = tf.convert_to_tensor([[1]], dtype=tf.float32)
        layer_state = tf.convert_to_tensor([[1]], dtype=tf.float32)
        # One state tensor per stacked layer; the output itself is unused.
        wrapped_cell(step_input, [layer_state, layer_state])

        # Reach into the private cell list to inspect the first layer.
        first_layer_weights = base_cell._cells[0].weights
        self.assertLen(first_layer_weights, expected_len=2)
        every_name_tagged = all(
            "_wrapper" in weight.name for weight in first_layer_weights)
        self.assertTrue(every_name_tagged)
Beispiel #2
0
 def testWrapperKerasStyle(self, wrapper):
     """Tests that a wrapped legacy cell is NOT flagged as Keras-style.

     The wrapper is applied to a plain ``BasicRNNCell(1)``. The check passes
     when ``_keras_style`` is unset on the wrapped cell, or is set to a falsy
     value such as ``None`` or ``False``.
     """
     wrapped_cell = wrapper(legacy_cells.BasicRNNCell(1))
     # Use a falsy default so a missing attribute does not raise
     # AttributeError. Plain attribute access would raise in that case, and
     # assertIsNone would reject an explicit False.
     self.assertFalse(getattr(wrapped_cell, "_keras_style", False))