def testWrapperKerasStyle(self, wrapper, wrapper_v2):
    """Checks the `_keras_style` flag on V2-wrapped and V1-wrapped cells.

    A `BasicRNNCell` wrapped by `wrapper_v2` must either lack the
    `_keras_style` attribute or carry it as None, while one wrapped by
    `wrapper` must expose a falsy `_keras_style`.

    Args:
      wrapper: The V1 wrapper class applied to a fresh `BasicRNNCell`.
      wrapper_v2: The V2 wrapper class applied to a fresh `BasicRNNCell`.
    """
    # V2 wrapper: a missing attribute is folded into None by the getattr default.
    v2_cell = wrapper_v2(rnn_cell_impl.BasicRNNCell(1))
    v2_flag = getattr(v2_cell, "_keras_style", None)
    self.assertIsNone(v2_flag)

    # V1 wrapper: the attribute is read directly, so it must exist and be falsy.
    v1_cell = wrapper(rnn_cell_impl.BasicRNNCell(1))
    v1_flag = v1_cell._keras_style
    self.assertFalse(v1_flag)
def testWrapperV2Caller(self, wrapper):
    """Tests that wrapper V2 is using the LayerRNNCell's caller.

    Builds a two-cell `MultiRNNCell` under `keras_style_scope`, wraps it with
    `wrapper`, calls the wrapped cell once, and then checks that the first
    inner cell owns exactly two weights whose names all contain "_wrapper".

    Args:
      wrapper: The V2 wrapper class applied to the stacked base cell.
    """
    # NOTE(review): the keras-style scope is assumed to cover only the
    # construction of the base cell -- confirm against the upstream layout.
    with legacy_base_layer.keras_style_scope():
        stacked_cells = [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)]
        base_cell = rnn_cell_impl.MultiRNNCell(stacked_cells)
    wrapped_cell = wrapper(base_cell)

    # One float32 [[1]] tensor as the input and one as the per-cell state.
    unit_input = tf.convert_to_tensor([[1]], dtype=tf.float32)
    unit_state = tf.convert_to_tensor([[1]], dtype=tf.float32)

    # The call is made before the weights are inspected; its result is unused.
    wrapped_cell(unit_input, [unit_state, unit_state])

    first_cell_weights = base_cell._cells[0].weights
    self.assertLen(first_cell_weights, expected_len=2)
    self.assertTrue(
        all("_wrapper" in weight.name for weight in first_cell_weights))