def testRNNLayersWithRNNCellParams(self):
    """Prune an RNN layer that stacks several cell types and check sparsity.

    A `keras.layers.RNN` wrapping four stacked cells (LSTM, GRU, peephole
    LSTM, simple RNN) is wrapped with `prune.prune_low_magnitude`. The test
    then checks the following:
      * the model reports 0.0 sparsity before training;
      * after one `fit` call with the `UpdatePruningStep` callback, the
        model reports 0.5 sparsity;
      * `_check_strip_pruning_matches_original(model, 0.5)` passes.

    NOTE(review): `self.params` is assumed to carry the pruning-schedule
    keyword arguments for `prune_low_magnitude`; confirm against the
    test class setup.
    """
    model = keras.Sequential()

    stacked_cells = [
        layers.LSTMCell(10),
        layers.GRUCell(10),
        layers.PeepholeLSTMCell(10),
        layers.SimpleRNNCell(10),
    ]
    rnn_layer = keras.layers.RNN(stacked_cells)
    pruned_rnn = prune.prune_low_magnitude(
        rnn_layer, input_shape=(3, 4), **self.params)
    model.add(pruned_rnn)

    model.compile(
        loss='categorical_crossentropy',
        optimizer='sgd',
        metrics=['accuracy'],
    )

    # Before any training step, no weights should have been pruned.
    test_utils.assert_model_sparsity(self, 0.0, model)

    batch_size = 32
    input_dims = model.input.get_shape().as_list()
    features = np.random.randn(*self._batch(input_dims, batch_size))
    output_dims = model.output.get_shape().as_list()
    targets = np.random.randn(*self._batch(output_dims, batch_size))

    # UpdatePruningStep advances the pruning step so the schedule is applied.
    model.fit(
        features,
        targets,
        callbacks=[pruning_callbacks.UpdatePruningStep()],
    )

    test_utils.assert_model_sparsity(self, 0.5, model)
    self._check_strip_pruning_matches_original(model, 0.5)
def testWrapperWeights(self, wrapper):
    """Tests that wrapper weights contain wrapped cells weights.

    Args:
      wrapper: Callable that wraps an RNN cell; it is applied to a
        `SimpleRNNCell`, and the result is expected to expose the wrapped
        cell as `.cell`.
    """
    inner_cell = layers.SimpleRNNCell(1, name="basic_rnn_cell")
    wrapped_cell = wrapper(inner_cell)
    layer = layers.RNN(wrapped_cell)

    # Calling the layer once on a single-element input creates the weights.
    single_step = ops.convert_to_tensor([[[1]]], dtype=dtypes.float32)
    layer(single_step)

    # The variables are created under the enclosing RNN layer's "rnn/" scope.
    weight_suffixes = ("kernel:0", "recurrent_kernel:0", "bias:0")
    expected_weights = ["rnn/" + suffix for suffix in weight_suffixes]

    def _names(variables):
        # Collect variable names so collections can be compared order-free.
        return [variable.name for variable in variables]

    self.assertLen(wrapped_cell.weights, 3)
    self.assertCountEqual(_names(wrapped_cell.weights), expected_weights)
    self.assertCountEqual(
        _names(wrapped_cell.trainable_variables), expected_weights)
    self.assertCountEqual(_names(wrapped_cell.non_trainable_variables), [])
    self.assertCountEqual(_names(wrapped_cell.cell.weights), expected_weights)