def testWrapperKerasStyle(self, wrapper, wrapper_v2):
  """Tests if wrapper cell is instantiated in keras style scope.

  The V2 wrapper must either lack a `_keras_style` attribute or have it set
  to None, while the V1 wrapper must expose a falsy `_keras_style`.
  """
  # pylint: disable=protected-access
  v2_cell = wrapper_v2(rnn_cell_impl.BasicRNNCell(1))
  self.assertIsNone(getattr(v2_cell, "_keras_style", None))

  v1_cell = wrapper(rnn_cell_impl.BasicRNNCell(1))
  self.assertFalse(v1_cell._keras_style)
  # pylint: enable=protected-access
def testDropoutWrapperKerasStyle(self):
  """Tests if DropoutWrapperV2 cell is instantiated in keras style scope.

  DropoutWrapperV2 must report a truthy `_keras_style`; the V1
  DropoutWrapper must report a falsy one.
  """
  # pylint: disable=protected-access
  v2_wrapper = rnn_cell_impl.DropoutWrapperV2(rnn_cell_impl.BasicRNNCell(1))
  self.assertTrue(v2_wrapper._keras_style)

  v1_wrapper = rnn_cell_impl.DropoutWrapper(rnn_cell_impl.BasicRNNCell(1))
  self.assertFalse(v1_wrapper._keras_style)
  # pylint: enable=protected-access
def testCellGetInitialState(self):
  """Exercises argument validation and results of `get_initial_state`.

  Invalid combinations (nothing given, batch size or dtype disagreeing with
  `inputs`) must raise ValueError. Valid calls, either inferring everything
  from `inputs` or passing a dynamic batch size plus dtype explicitly, must
  produce a state with static shape [None, 5] and the dtype of `inputs`.
  """
  cell = rnn_cell_impl.BasicRNNCell(5)

  def check_state(state):
    # Batch dimension stays unknown (placeholder batch is None); 5 columns
    # come from BasicRNNCell(5).
    self.assertEqual(state.shape.as_list(), [None, 5])
    self.assertEqual(state.dtype, inputs.dtype)

  with self.assertRaisesRegexp(ValueError,
                               "batch_size and dtype cannot be None"):
    cell.get_initial_state(None, None, None)

  inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 1))

  batch_mismatch = "batch size from input tensor is different from"
  with self.assertRaisesRegexp(ValueError, batch_mismatch):
    cell.get_initial_state(inputs=inputs, batch_size=50, dtype=None)
  with self.assertRaisesRegexp(ValueError, batch_mismatch):
    cell.get_initial_state(
        inputs=inputs, batch_size=constant_op.constant(50), dtype=None)

  with self.assertRaisesRegexp(ValueError,
                               "dtype from input tensor is different from"):
    cell.get_initial_state(
        inputs=inputs, batch_size=None, dtype=dtypes.int16)

  # Everything inferred from `inputs`.
  check_state(
      cell.get_initial_state(inputs=inputs, batch_size=None, dtype=None))

  # Dynamic batch size tensor and dtype given explicitly, no inputs.
  dynamic_batch = array_ops.shape(inputs)[0]
  check_state(cell.get_initial_state(None, dynamic_batch, inputs.dtype))
def testBasicRNNCellNotTrainable(self):
  """Checks BasicRNNCell variables created through a non-trainable getter.

  A custom getter on the "root" variable scope forces `trainable=False`.
  The cell's kernel and bias must therefore appear only in
  `non_trainable_variables`, and the cell must still run and produce a
  (1, 2) output.
  """
  # `test_session` is deprecated; `cached_session` is its replacement.
  with self.cached_session() as sess:

    def not_trainable_getter(getter, *args, **kwargs):
      # Force every variable requested in this scope to be non-trainable.
      kwargs["trainable"] = False
      return getter(*args, **kwargs)

    with variable_scope.variable_scope(
        "root",
        initializer=init_ops.constant_initializer(0.5),
        custom_getter=not_trainable_getter):
      x = array_ops.zeros([1, 2])
      m = array_ops.zeros([1, 2])
      cell = rnn_cell_impl.BasicRNNCell(2)
      g, _ = cell(x, m)
      self.assertFalse(cell.trainable_variables)
      # pylint: disable=protected-access
      self.assertEqual([
          "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
          "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
      ], [v.name for v in cell.non_trainable_variables])
      # pylint: enable=protected-access
      sess.run([variables_lib.global_variables_initializer()])
      res = sess.run([g], {
          x.name: np.array([[1., 1.]]),
          m.name: np.array([[0.1, 0.1]])
      })
      self.assertEqual(res[0].shape, (1, 2))
def testRNNCellSerialization(self):
  """Round-trips TF RNN cells through `keras.layers.RNN` config serialization.

  For each cell: build a Keras model around `keras.layers.RNN(cell)`, record
  its predictions and weights, and rebuild the layer from its config. Then
  load the same weights and check that the predictions match.
  """
  # The custom_objects is important here since rnn_cell_impl is not visible
  # as a Keras layer, and also has a name conflict with keras.LSTMCell and
  # GRUCell. It does not depend on the loop variable, so build it once.
  custom_objects = {
      "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
      "GRUCell": rnn_cell_impl.GRUCell,
      "LSTMCell": rnn_cell_impl.LSTMCell,
      "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell
  }
  for cell in [
      # `cell_clip` is a float clip value; pass 1.0 explicitly instead of
      # relying on `True` being coerced to 1.
      rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=1.0),
      rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32),
      rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32),
      rnn_cell_impl.GRUCell(32, dtype=dtypes.float32)
  ]:
    with self.cached_session():
      x = keras.Input((None, 5))
      layer = keras.layers.RNN(cell)
      y = layer(x)
      model = keras.models.Model(x, y)
      model.compile(optimizer="rmsprop", loss="mse")

      # Test basic case serialization.
      x_np = np.random.random((6, 5, 5))
      y_np = model.predict(x_np)
      weights = model.get_weights()
      config = layer.get_config()
      layer = keras.layers.RNN.from_config(
          config, custom_objects=custom_objects)
      y = layer(x)
      model = keras.models.Model(x, y)
      model.set_weights(weights)
      y_np_2 = model.predict(x_np)
      self.assertAllClose(y_np, y_np_2, atol=1e-4)
def testDropoutWrapperProperties(self, wrapper_type):
  """A wrapper must expose the wrapped cell and forward its sizes.

  Regression test for GitHub issue 15810.
  """
  inner = rnn_cell_impl.BasicRNNCell(10)
  wrapped = wrapper_type(inner)
  self.assertEqual(wrapped.wrapped_cell, inner)
  # Read each property lazily so a failure on the first is reported first.
  for attr_name in ("state_size", "output_size"):
    self.assertEqual(getattr(wrapped, attr_name), 10)
def _CreateCudnnCompatibleCanonicalRNN(rnn, inputs, is_bidi=False, scope=None):
  """Builds a canonical (non-cuDNN) RNN using cuDNN-compatible cells.

  Args:
    rnn: object exposing `rnn_mode`, `num_units` and `num_layers`
      (presumably a cuDNN RNN model; verify against callers).
    inputs: input tensor, consumed with `time_major=True`.
    is_bidi: if True, build a stacked bidirectional RNN; otherwise a
      unidirectional `MultiRNNCell` run through `dynamic_rnn`.
    scope: optional scope forwarded to the RNN builder.

  Returns:
    Unidirectional: whatever `rnn_lib.dynamic_rnn` returns.
    Bidirectional: `(outputs, (output_state_fw, output_state_bw))`.

  Raises:
    ValueError: if `rnn.rnn_mode` is not one of the four supported modes.
  """
  mode = rnn.rnn_mode
  num_units = rnn.num_units
  num_layers = rnn.num_layers

  # To reuse cuDNN-trained models, must use cudnn compatible rnn cells.
  if mode == CUDNN_LSTM:
    cell_cls, extra_args = cudnn_rnn_ops.CudnnCompatibleLSTMCell, ()
  elif mode == CUDNN_GRU:
    cell_cls, extra_args = cudnn_rnn_ops.CudnnCompatibleGRUCell, ()
  elif mode == CUDNN_RNN_TANH:
    cell_cls, extra_args = rnn_cell_impl.BasicRNNCell, (math_ops.tanh,)
  elif mode == CUDNN_RNN_RELU:
    cell_cls, extra_args = rnn_cell_impl.BasicRNNCell, (gen_nn_ops.relu,)
  else:
    raise ValueError("%s is not supported!" % mode)

  def make_cell():
    # A fresh cell per layer/direction; layers must not share instances.
    return cell_cls(num_units, *extra_args)

  if is_bidi:
    forward_cells = [make_cell() for _ in range(num_layers)]
    backward_cells = [make_cell() for _ in range(num_layers)]
    outputs, state_fw, state_bw = (
        contrib_rnn_lib.stack_bidirectional_dynamic_rnn(
            forward_cells,
            backward_cells,
            inputs,
            dtype=dtypes.float32,
            time_major=True,
            scope=scope))
    return outputs, (state_fw, state_bw)

  stacked_cell = rnn_cell_impl.MultiRNNCell(
      [make_cell() for _ in range(num_layers)])
  return rnn_lib.dynamic_rnn(
      stacked_cell, inputs, dtype=dtypes.float32, time_major=True,
      scope=scope)
def testWrapperV2Caller(self, wrapper):
  """Tests that wrapper V2 is using the LayerRNNCell's caller.

  The wrapped cells are built in Keras style. After calling the wrapper
  once, the first inner cell must own exactly 2 weights, and each weight
  name must contain "_wrapper".
  """
  # NOTE(review): only the MultiRNNCell construction is assumed to be inside
  # keras_style_scope; the original's nesting was lost to line collapsing.
  with base_layer.keras_style_scope():
    multi_cell = rnn_cell_impl.MultiRNNCell(
        [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)])
  wrapped = wrapper(multi_cell)

  inputs = ops.convert_to_tensor([[1]], dtype=dtypes.float32)
  state = ops.convert_to_tensor([[1]], dtype=dtypes.float32)
  wrapped(inputs, [state, state])

  first_cell_weights = multi_cell._cells[0].weights  # pylint: disable=protected-access
  self.assertLen(first_cell_weights, expected_len=2)
  self.assertTrue(all("_wrapper" in w.name for w in first_cell_weights))
def _rnn_input(apply_wrapper):
  """Creates a RNN layer with/without wrapper and returns built rnn cell.

  Args:
    apply_wrapper: if truthy, wrap the MultiRNNCell in DropoutWrapperV2
      before handing it to `keras_layers.RNN`.

  Returns:
    The first BasicRNNCell inside the MultiRNNCell, after the RNN layer has
    been called once on a [[[1]]] float32 input.
  """
  # NOTE(review): only the MultiRNNCell construction is assumed to be inside
  # keras_style_scope; the original's nesting was lost to line collapsing.
  with base_layer.keras_style_scope():
    multi_cell = rnn_cell_impl.MultiRNNCell(
        [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)])

  layer_cell = (rnn_cell_impl.DropoutWrapperV2(multi_cell)
                if apply_wrapper else multi_cell)
  # Calling the layer once builds the cells' weights.
  keras_layers.RNN(layer_cell)(
      ops.convert_to_tensor([[[1]]], dtype=dtypes.float32))
  return multi_cell._cells[0]  # pylint: disable=protected-access
def testWrapperCheckpointing(self):
  """Inner-cell variables must round-trip through a checkpoint of the wrapper.

  For each wrapper kind: build and call the wrapper, then set the inner
  cell's bias to 40 and save. Clobber the bias with 0, restore, and expect
  40 again.
  """
  wrapper_factories = [
      rnn_cell_impl.DropoutWrapper,
      rnn_cell_impl.ResidualWrapper,
      lambda cell: rnn_cell_impl.MultiRNNCell([cell]),
  ]
  for make_wrapper in wrapper_factories:
    inner = rnn_cell_impl.BasicRNNCell(1)
    wrapped = make_wrapper(inner)

    step_input = array_ops.ones([1, 1])
    wrapped(step_input,
            state=wrapped.zero_state(batch_size=1, dtype=dtypes.float32))
    self.evaluate([v.initializer for v in inner.variables])

    checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapped)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    bias = inner._bias  # pylint: disable=protected-access

    self.evaluate(bias.assign([40.]))
    save_path = checkpoint.save(prefix)
    self.evaluate(bias.assign([0.]))
    checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    self.assertAllEqual([40.], self.evaluate(bias))
def testBasicRNNCell(self):
  """Builds a BasicRNNCell under "root" and checks its variables and output.

  Both the kernel and the bias must be trainable and named under
  "root/basic_rnn_cell/". No variable may be non-trainable, and a single
  step on a 1x2 input must yield a (1, 2) output.
  """
  with self.cached_session() as sess:
    with variable_scope.variable_scope(
        "root", initializer=init_ops.constant_initializer(0.5)):
      x = array_ops.zeros([1, 2])
      m = array_ops.zeros([1, 2])
      cell = rnn_cell_impl.BasicRNNCell(2)
      output, _ = cell(x, m)

      name_fmt = "root/basic_rnn_cell/%s:0"
      # pylint: disable=protected-access
      expected_names = [
          name_fmt % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
          name_fmt % rnn_cell_impl._BIAS_VARIABLE_NAME,
      ]
      # pylint: enable=protected-access
      self.assertEqual(expected_names,
                       [v.name for v in cell.trainable_variables])
      self.assertFalse(cell.non_trainable_variables)

      sess.run([variables_lib.global_variables_initializer()])
      feed = {
          x.name: np.array([[1., 1.]]),
          m.name: np.array([[0.1, 0.1]]),
      }
      (result,) = sess.run([output], feed)
      self.assertEqual(result.shape, (1, 2))
def testSimpleRNNCellAndBasicRNNCellComparison(self):
  """keras SimpleRNNCell and BasicRNNCell must agree on equivalent weights.

  Weights are generated once from a reference SimpleRNNCell and converted
  to the BasicRNNCell layout. Both cells are then run through
  `rnn.dynamic_rnn` in separate graphs on the same data, and their outputs
  and final states are compared.
  """
  input_dim = 10
  units = 5
  timesteps = 4
  num_samples = 20
  (x_train, _), _ = testing_utils.get_test_data(
      train_samples=num_samples,
      test_samples=0,
      input_shape=(timesteps, input_dim),
      num_classes=units)

  reference = keras.layers.SimpleRNNCell(units)
  reference.build((None, input_dim))
  # SimpleRNNCell holds 3 weights (kernel, recurrent_kernel, bias), whereas
  # BasicRNNCell holds 2 weights (kernel, bias), where its kernel is
  # [kernel, recurrent_kernel] concatenated along the first axis.
  keras_weights = reference.get_weights()
  kernel, recurrent_kernel, bias = keras_weights
  tf_weights = [np.concatenate((kernel, recurrent_kernel)), bias]

  def run_in_fresh_graph(make_cell, weights):
    # Builds the cell inside its own graph, loads `weights`, and returns
    # [outputs, final_state] evaluated on x_train.
    with self.session(graph=ops_lib.Graph()) as sess:
      inputs = array_ops.placeholder(
          dtypes.float32, shape=(None, timesteps, input_dim))
      cell = make_cell()
      out, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
      cell.set_weights(weights)
      return sess.run([out, state], {inputs: x_train})

  k_out, k_state = run_in_fresh_graph(
      lambda: keras.layers.SimpleRNNCell(units), keras_weights)
  tf_out, tf_state = run_in_fresh_graph(
      lambda: rnn_cell_impl.BasicRNNCell(units), tf_weights)

  self.assertAllClose(tf_out, k_out, atol=1e-5)
  self.assertAllClose(tf_state, k_state, atol=1e-5)
def testBasicRNNCellMatch(self):
  """Checks that `_SlimRNNCell(basic_rnn_cell)` matches BasicRNNCell.

  Both cells run in the "root" scope with a constant 0.5 initializer, from
  the same initial state produced by `basic_rnn_cell`. The slim cell is
  built after `reuse_variables()`. Output/state static shapes and evaluated
  values must agree.
  """
  batch_size = 32
  input_size = 100
  num_units = 10
  # `test_session` is deprecated; `cached_session` is its replacement.
  with self.cached_session() as sess:
    with variable_scope.variable_scope(
        "root", initializer=init_ops.constant_initializer(0.5)):
      inputs = random_ops.random_uniform((batch_size, input_size))
      _, initial_state = basic_rnn_cell(inputs, None, num_units)
      rnn_cell = rnn_cell_impl.BasicRNNCell(num_units)
      outputs, state = rnn_cell(inputs, initial_state)
      # NOTE(review): presumably lets basic_rnn_cell below pick up the
      # variables already created in this scope; confirm against its body.
      variable_scope.get_variable_scope().reuse_variables()
      my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
      # pylint: disable=protected-access
      slim_cell = rnn_cell_impl._SlimRNNCell(my_cell)
      # pylint: enable=protected-access
      slim_outputs, slim_state = slim_cell(inputs, initial_state)
      self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
      self.assertEqual(slim_state.get_shape(), state.get_shape())
      sess.run([variables_lib.global_variables_initializer()])
      res = sess.run([slim_outputs, slim_state, outputs, state])
      self.assertAllClose(res[0], res[2])
      self.assertAllClose(res[1], res[3])
def testWrappedCellProperty(self):
  """DropoutWrapper.wrapped_cell must return the cell it wraps.

  Regression test for GitHub issue 15810.
  """
  inner = rnn_cell_impl.BasicRNNCell(10)
  self.assertEqual(rnn_cell_impl.DropoutWrapper(inner).wrapped_cell, inner)