def testWrapperKerasStyle(self, wrapper, wrapper_v2):
    """Checks the `_keras_style` flag on V2 and V1 wrapper instances.

    The V2 wrapper must either lack a `_keras_style` attribute or hold
    `None` in it; the V1 wrapper must report a falsy `_keras_style`.
    """
    v2_instance = wrapper_v2(rnn_cell_impl.BasicRNNCell(1))
    keras_style_flag = getattr(v2_instance, "_keras_style", None)
    self.assertIsNone(keras_style_flag)

    v1_instance = wrapper(rnn_cell_impl.BasicRNNCell(1))
    self.assertFalse(v1_instance._keras_style)
  def testDropoutWrapperKerasStyle(self):
    """Only DropoutWrapperV2 reports a truthy `_keras_style` flag.

    DropoutWrapperV2 must set the flag; the V1 DropoutWrapper must not.
    """
    inner_v2 = rnn_cell_impl.BasicRNNCell(1)
    dropout_v2 = rnn_cell_impl.DropoutWrapperV2(inner_v2)
    self.assertTrue(dropout_v2._keras_style)

    inner_v1 = rnn_cell_impl.BasicRNNCell(1)
    dropout_v1 = rnn_cell_impl.DropoutWrapper(inner_v1)
    self.assertFalse(dropout_v1._keras_style)
Example #3
0
    def testCellGetInitialState(self):
        """Validates `get_initial_state` argument checks and its results.

        Invalid calls: no inputs/batch_size/dtype at all, or a batch size
        (Python int or tensor) or dtype that conflicts with `inputs`.
        Valid calls: inferring everything from `inputs`, or passing a
        dynamic batch size and dtype explicitly.
        """
        cell = rnn_cell_impl.BasicRNNCell(5)
        # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex`
        # and was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    "batch_size and dtype cannot be None"):
            cell.get_initial_state(None, None, None)

        # The batch dimension is left unknown (None) in the placeholder.
        inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 1))
        with self.assertRaisesRegex(
                ValueError, "batch size from input tensor is different from"):
            cell.get_initial_state(inputs=inputs, batch_size=50, dtype=None)

        # Same mismatch, with the batch size passed as a constant tensor.
        with self.assertRaisesRegex(
                ValueError, "batch size from input tensor is different from"):
            cell.get_initial_state(inputs=inputs,
                                   batch_size=constant_op.constant(50),
                                   dtype=None)

        # int16 conflicts with the float32 placeholder.
        with self.assertRaisesRegex(
                ValueError, "dtype from input tensor is different from"):
            cell.get_initial_state(inputs=inputs,
                                   batch_size=None,
                                   dtype=dtypes.int16)

        # Everything inferred from `inputs`: unknown batch, 5 units.
        initial_state = cell.get_initial_state(inputs=inputs,
                                               batch_size=None,
                                               dtype=None)
        self.assertEqual(initial_state.shape.as_list(), [None, 5])
        self.assertEqual(initial_state.dtype, inputs.dtype)

        # Explicit dynamic batch size (a tensor) and dtype, no `inputs`.
        batch = array_ops.shape(inputs)[0]
        dtype = inputs.dtype
        initial_state = cell.get_initial_state(None, batch, dtype)
        self.assertEqual(initial_state.shape.as_list(), [None, 5])
        self.assertEqual(initial_state.dtype, inputs.dtype)
Example #4
0
    def testBasicRNNCellNotTrainable(self):
        """Variables created through a non-trainable custom getter stay frozen.

        The custom getter forces `trainable=False`, so the cell must expose
        no trainable variables and must list its weights and bias as
        non-trainable. A forward step must still run and yield a (1, 2)
        output.
        """
        # `test_session` is deprecated; `cached_session` is its documented
        # replacement in TensorFlow's test utilities.
        with self.cached_session() as sess:

            def not_trainable_getter(getter, *args, **kwargs):
                # Create every variable requested through this scope as
                # non-trainable.
                kwargs["trainable"] = False
                return getter(*args, **kwargs)

            with variable_scope.variable_scope(
                    "root",
                    initializer=init_ops.constant_initializer(0.5),
                    custom_getter=not_trainable_getter):
                x = array_ops.zeros([1, 2])
                m = array_ops.zeros([1, 2])
                cell = rnn_cell_impl.BasicRNNCell(2)
                g, _ = cell(x, m)
                self.assertFalse(cell.trainable_variables)
                self.assertEqual([
                    "root/basic_rnn_cell/%s:0" %
                    rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
                    "root/basic_rnn_cell/%s:0" %
                    rnn_cell_impl._BIAS_VARIABLE_NAME
                ], [v.name for v in cell.non_trainable_variables])
                sess.run([variables_lib.global_variables_initializer()])
                # Feed concrete values in place of the zeros tensors.
                res = sess.run([g], {
                    x.name: np.array([[1., 1.]]),
                    m.name: np.array([[0.1, 0.1]])
                })
                self.assertEqual(res[0].shape, (1, 2))
Example #5
0
  def testRNNCellSerialization(self):
    """Round-trips rnn_cell_impl cells through a keras RNN layer config.

    For each cell, a keras model is built and run. The RNN layer is then
    rebuilt from its config, the original weights are copied into the new
    model, and both models' predictions must agree to within 1e-4.
    """
    candidate_cells = [
        rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True),
        rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32),
        rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32),
        rnn_cell_impl.GRUCell(32, dtype=dtypes.float32)
    ]
    # rnn_cell_impl is not visible as a Keras layer, and its class names
    # clash with keras' own LSTMCell / GRUCell. from_config therefore has to
    # be told explicitly which classes to instantiate.
    cell_registry = {
        "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
        "GRUCell": rnn_cell_impl.GRUCell,
        "LSTMCell": rnn_cell_impl.LSTMCell,
        "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell
    }
    for candidate in candidate_cells:
      with self.cached_session():
        model_input = keras.Input((None, 5))
        original_layer = keras.layers.RNN(candidate)
        original_output = original_layer(model_input)
        original_model = keras.models.Model(model_input, original_output)
        original_model.compile(optimizer="rmsprop", loss="mse")

        # Reference predictions and weights from the original model.
        batch = np.random.random((6, 5, 5))
        expected = original_model.predict(batch)
        saved_weights = original_model.get_weights()

        restored_layer = keras.layers.RNN.from_config(
            original_layer.get_config(), custom_objects=cell_registry)
        restored_output = restored_layer(model_input)
        restored_model = keras.models.Model(model_input, restored_output)
        restored_model.set_weights(saved_weights)
        actual = restored_model.predict(batch)
        self.assertAllClose(expected, actual, atol=1e-4)
 def testDropoutWrapperProperties(self, wrapper_type):
   """The wrapper built by `wrapper_type` exposes and mirrors its inner cell.

   `wrapped_cell` must return the wrapped cell, and `state_size` and
   `output_size` must both equal the inner cell's 10 units.
   """
   base_cell = rnn_cell_impl.BasicRNNCell(10)
   dropout_cell = wrapper_type(base_cell)
   # Regression check for Github issue 15810.
   self.assertEqual(dropout_cell.wrapped_cell, base_cell)
   self.assertEqual(dropout_cell.state_size, 10)
   self.assertEqual(dropout_cell.output_size, 10)
Example #7
0
def _CreateCudnnCompatibleCanonicalRNN(rnn, inputs, is_bidi=False, scope=None):
    """Builds a canonical-cell RNN that mirrors the configuration of `rnn`.

    Cells are chosen from `rnn.rnn_mode` among cuDNN-compatible cell classes,
    so that models trained with cuDNN can be reused.

    Args:
      rnn: object exposing `rnn_mode`, `num_units` and `num_layers`.
      inputs: input tensor, consumed with `time_major=True`.
      is_bidi: if True, build a stacked bidirectional RNN; otherwise a single
        MultiRNNCell driven by `rnn_lib.dynamic_rnn`.
      scope: optional scope forwarded to the RNN builder.

    Returns:
      Unidirectional: the result of `rnn_lib.dynamic_rnn`.
      Bidirectional: `(outputs, (output_state_fw, output_state_bw))`.

    Raises:
      ValueError: if `rnn.rnn_mode` matches none of the four supported modes.
    """
    mode = rnn.rnn_mode
    num_units = rnn.num_units
    num_layers = rnn.num_layers

    # (mode, cell factory) pairs, compared against `mode` in order with `==`.
    # Only cuDNN-compatible cells can reuse cuDNN-trained models.
    factories = (
        (CUDNN_LSTM,
         lambda: cudnn_rnn_ops.CudnnCompatibleLSTMCell(num_units)),
        (CUDNN_GRU,
         lambda: cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units)),
        (CUDNN_RNN_TANH,
         lambda: rnn_cell_impl.BasicRNNCell(num_units, math_ops.tanh)),
        (CUDNN_RNN_RELU,
         lambda: rnn_cell_impl.BasicRNNCell(num_units, gen_nn_ops.relu)),
    )
    for candidate_mode, factory in factories:
        if mode == candidate_mode:
            single_cell = factory
            break
    else:
        raise ValueError("%s is not supported!" % mode)

    if is_bidi:
        cells_fw = [single_cell() for _ in range(num_layers)]
        cells_bw = [single_cell() for _ in range(num_layers)]
        outputs, output_state_fw, output_state_bw = (
            contrib_rnn_lib.stack_bidirectional_dynamic_rnn(
                cells_fw,
                cells_bw,
                inputs,
                dtype=dtypes.float32,
                time_major=True,
                scope=scope))
        return outputs, (output_state_fw, output_state_bw)

    stacked_cell = rnn_cell_impl.MultiRNNCell(
        [single_cell() for _ in range(num_layers)])
    return rnn_lib.dynamic_rnn(stacked_cell,
                               inputs,
                               dtype=dtypes.float32,
                               time_major=True,
                               scope=scope)
  def testWrapperV2Caller(self, wrapper):
    """Tests that wrapper V2 is using the LayerRNNCell's caller.

    The inner MultiRNNCell is created under `keras_style_scope`. After the
    wrapper is called once, the first inner cell must own exactly two
    weights, and each weight's name must contain "_wrapper".
    """
    with base_layer.keras_style_scope():
      base_cell = rnn_cell_impl.MultiRNNCell(
          [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)])
    rnn_cell = wrapper(base_cell)
    inputs = ops.convert_to_tensor([[1]], dtype=dtypes.float32)
    state = ops.convert_to_tensor([[1]], dtype=dtypes.float32)
    _ = rnn_cell(inputs, [state, state])
    weights = base_cell._cells[0].weights
    self.assertLen(weights, expected_len=2)
    # Check each weight individually so a failure names the offending
    # variable; `assertTrue(all([...]))` would only report "False is not true".
    for weight in weights:
      self.assertIn("_wrapper", weight.name)
 def _rnn_input(apply_wrapper):
   """Runs a keras RNN over a 2-layer MultiRNNCell and returns its first cell.

   Args:
     apply_wrapper: if truthy, wrap the MultiRNNCell in DropoutWrapperV2
       before handing it to the keras RNN layer.

   Returns:
     The first BasicRNNCell inside the MultiRNNCell, after the RNN layer has
     been called once on a constant [[[1]]] input, so the cell is built.
   """
   with base_layer.keras_style_scope():
     multi_cell = rnn_cell_impl.MultiRNNCell(
         [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)])
   layer_cell = (rnn_cell_impl.DropoutWrapperV2(multi_cell)
                 if apply_wrapper else multi_cell)
   rnn_layer = keras_layers.RNN(layer_cell)
   rnn_layer(ops.convert_to_tensor([[[1]]], dtype=dtypes.float32))
   return multi_cell._cells[0]
 def testWrapperCheckpointing(self):
   """Inner cell variables round-trip through a checkpoint of the wrapper.

   Three wrappers are covered: DropoutWrapper, ResidualWrapper, and a
   one-cell MultiRNNCell. For each, the inner cell's bias is set to 40 and
   checkpointed via the wrapper, then overwritten with 0. After restoring,
   the bias must read 40 again.
   """
   wrapper_factories = [
       rnn_cell_impl.DropoutWrapper,
       rnn_cell_impl.ResidualWrapper,
       lambda inner: rnn_cell_impl.MultiRNNCell([inner])]
   for make_wrapper in wrapper_factories:
     inner_cell = rnn_cell_impl.BasicRNNCell(1)
     wrapped = make_wrapper(inner_cell)
     # One forward call so the inner cell creates its variables.
     step_input = array_ops.ones([1, 1])
     wrapped(step_input,
             state=wrapped.zero_state(batch_size=1, dtype=dtypes.float32))
     self.evaluate([var.initializer for var in inner_cell.variables])
     checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapped)
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
     self.evaluate(inner_cell._bias.assign([40.]))
     save_path = checkpoint.save(prefix)
     # Clobber the saved value so the restore below is observable.
     self.evaluate(inner_cell._bias.assign([0.]))
     restore_status = checkpoint.restore(save_path)
     restore_status.assert_consumed().run_restore_ops()
     self.assertAllEqual([40.], self.evaluate(inner_cell._bias))
 def testBasicRNNCell(self):
   """BasicRNNCell creates trainable weights/bias under its scope and runs.

   Only the two trainable variables, and no non-trainable ones, must exist
   under "root/basic_rnn_cell". One step on fed inputs must yield a (1, 2)
   output.
   """
   expected_names = [
       "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
       "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
   ]
   with self.cached_session() as sess:
     with variable_scope.variable_scope(
         "root", initializer=init_ops.constant_initializer(0.5)):
       step_input = array_ops.zeros([1, 2])
       prev_state = array_ops.zeros([1, 2])
       cell = rnn_cell_impl.BasicRNNCell(2)
       output, _ = cell(step_input, prev_state)
       self.assertEqual(expected_names,
                        [var.name for var in cell.trainable_variables])
       self.assertFalse(cell.non_trainable_variables)
       sess.run([variables_lib.global_variables_initializer()])
       # Feed concrete values in place of the zeros tensors.
       feed = {
           step_input.name: np.array([[1., 1.]]),
           prev_state.name: np.array([[0.1, 0.1]])
       }
       (result,) = sess.run([output], feed)
       self.assertEqual(result.shape, (1, 2))
Example #12
0
  def testSimpleRNNCellAndBasicRNNCellComparison(self):
    """keras SimpleRNNCell and BasicRNNCell agree when given equal weights.

    Weights are taken from one built SimpleRNNCell and converted to
    BasicRNNCell's layout. Each cell is then unrolled with dynamic_rnn in its
    own fresh graph on the same data. Outputs and final states must match to
    within 1e-5.
    """
    input_dim = 10
    units = 5
    timesteps = 4
    batch_size = 20
    (x_train, _), _ = testing_utils.get_test_data(
        train_samples=batch_size,
        test_samples=0,
        input_shape=(timesteps, input_dim),
        num_classes=units)
    weight_source = keras.layers.SimpleRNNCell(units)
    weight_source.build((None, input_dim))
    keras_weights = weight_source.get_weights()
    # SimpleRNNCell keeps kernel, recurrent_kernel and bias separately.
    # BasicRNNCell has a single kernel that stacks
    # [kernel; recurrent_kernel] along axis 0, plus the bias.
    kernel, recurrent_kernel, bias = keras_weights
    tf_weights = [np.concatenate((kernel, recurrent_kernel)), bias]

    def run_in_fresh_graph(make_cell, weights):
      # Unroll a new cell over x_train in its own graph; return [out, state].
      with self.session(graph=ops_lib.Graph()) as sess:
        inputs = array_ops.placeholder(
            dtypes.float32, shape=(None, timesteps, input_dim))
        cell = make_cell()
        outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
        cell.set_weights(weights)
        return sess.run([outputs, state], {inputs: x_train})

    k_out, k_state = run_in_fresh_graph(
        lambda: keras.layers.SimpleRNNCell(units), keras_weights)
    tf_out, tf_state = run_in_fresh_graph(
        lambda: rnn_cell_impl.BasicRNNCell(units), tf_weights)

    self.assertAllClose(tf_out, k_out, atol=1e-5)
    self.assertAllClose(tf_state, k_state, atol=1e-5)
Example #13
0
 def testBasicRNNCellMatch(self):
   """Compares BasicRNNCell with a _SlimRNNCell wrapping `basic_rnn_cell`.

   Both cells run on the same random inputs under a "root" scope with a
   constant 0.5 initializer. Their output and state static shapes must match,
   and their evaluated outputs and states must be close.
   """
   batch_size = 32
   input_size = 100
   num_units = 10
   # `test_session` is deprecated; `cached_session` is its documented
   # replacement in TensorFlow's test utilities.
   with self.cached_session() as sess:
     with variable_scope.variable_scope(
         "root", initializer=init_ops.constant_initializer(0.5)):
       inputs = random_ops.random_uniform((batch_size, input_size))
       _, initial_state = basic_rnn_cell(inputs, None, num_units)
       rnn_cell = rnn_cell_impl.BasicRNNCell(num_units)
       outputs, state = rnn_cell(inputs, initial_state)
       # NOTE(review): presumably lets the slim cell's `basic_rnn_cell` call
       # reuse the variables already created in "root" -- confirm.
       variable_scope.get_variable_scope().reuse_variables()
       my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
       # pylint: disable=protected-access
       slim_cell = rnn_cell_impl._SlimRNNCell(my_cell)
       # pylint: enable=protected-access
       slim_outputs, slim_state = slim_cell(inputs, initial_state)
       self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
       self.assertEqual(slim_state.get_shape(), state.get_shape())
       sess.run([variables_lib.global_variables_initializer()])
       # res = [slim_outputs, slim_state, outputs, state]
       res = sess.run([slim_outputs, slim_state, outputs, state])
       self.assertAllClose(res[0], res[2])
       self.assertAllClose(res[1], res[3])
Example #14
0
 def testWrappedCellProperty(self):
     """DropoutWrapper exposes the cell it wraps through `wrapped_cell`."""
     cell = rnn_cell_impl.BasicRNNCell(10)
     wrapper = rnn_cell_impl.DropoutWrapper(cell)
     # Regression check for Github issue 15810.
     self.assertEqual(wrapper.wrapped_cell, cell)