Пример #1
0
  def testWrapperKerasStyle(self, wrapper, wrapper_v2):
    """Checks the `_keras_style` marker on v2- and v1-wrapped cells.

    A cell wrapped by `wrapper_v2` must either lack a `_keras_style`
    attribute or have it set to None. A cell wrapped by the legacy
    `wrapper` must expose `_keras_style` with a falsy value.

    Args:
      wrapper: Legacy (v1) wrapper class under test.
      wrapper_v2: V2 wrapper class under test.
    """
    # The v2 wrapper is built first, matching the original construction order.
    v2_wrapped = wrapper_v2(rnn_cell_impl.BasicRNNCell(1))
    v2_keras_style = getattr(v2_wrapped, "_keras_style", None)
    self.assertIsNone(v2_keras_style)

    v1_wrapped = wrapper(rnn_cell_impl.BasicRNNCell(1))
    v1_keras_style = v1_wrapped._keras_style
    self.assertFalse(v1_keras_style)
Пример #2
0
  def testWrapperV2Caller(self, wrapper):
    """Tests that wrapper V2 is using the LayerRNNCell's caller.

    Builds a MultiRNNCell of two single-unit BasicRNNCells inside
    `legacy_base_layer.keras_style_scope()`, then wraps it with `wrapper`
    and calls it once. Afterwards the first inner cell must own exactly two
    weights, and every weight name must contain "_wrapper".

    Args:
      wrapper: The wrapper class under test (presumably supplied by the
        test's parameterization, which is not visible here).
    """
    # Only the base cell is created inside the keras style scope; the wrapper
    # itself is constructed outside of it.
    with legacy_base_layer.keras_style_scope():
      base_cell = rnn_cell_impl.MultiRNNCell(
          [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)])
    rnn_cell = wrapper(base_cell)
    inputs = tf.convert_to_tensor([[1]], dtype=tf.float32)
    state = tf.convert_to_tensor([[1]], dtype=tf.float32)
    # One state tensor per inner cell (two cells were built above).
    _ = rnn_cell(inputs, [state, state])
    weights = base_cell._cells[0].weights
    self.assertLen(weights, expected_len=2)
    # Assert per weight so a failure names the offending variable instead of
    # reporting a bare "False is not true" from assertTrue(all(...)).
    for weight in weights:
      self.assertIn("_wrapper", weight.name)