Ejemplo n.º 1
0
    def testRNNLayersWithRNNCellParams(self):
        """Prunes an RNN layer built from a list of four different cells.

        The model holds a single `keras.layers.RNN` wrapped by
        `prune.prune_low_magnitude`. Sparsity is asserted to be 0.0 before
        training and 0.5 after one `fit` call driven by `UpdatePruningStep`;
        the 0.5 figure presumably comes from the schedule in `self.params`
        (not visible here -- confirm against the fixture setup).
        """
        model = keras.Sequential()

        # One cell of each supported kind, all with 10 units.
        stacked_cells = [
            layers.LSTMCell(10),
            layers.GRUCell(10),
            layers.PeepholeLSTMCell(10),
            layers.SimpleRNNCell(10),
        ]
        pruned_rnn = prune.prune_low_magnitude(
            keras.layers.RNN(stacked_cells),
            input_shape=(3, 4),
            **self.params)
        model.add(pruned_rnn)

        model.compile(
            loss='categorical_crossentropy',
            optimizer='sgd',
            metrics=['accuracy'])

        # Nothing has been pruned before the first training step.
        test_utils.assert_model_sparsity(self, 0.0, model)

        # Random features and targets shaped after the model's own
        # input/output, with `self._batch` applied using 32.
        input_dims = self._batch(model.input.get_shape().as_list(), 32)
        features = np.random.randn(*input_dims)
        output_dims = self._batch(model.output.get_shape().as_list(), 32)
        targets = np.random.randn(*output_dims)

        step_callback = pruning_callbacks.UpdatePruningStep()
        model.fit(features, targets, callbacks=[step_callback])

        test_utils.assert_model_sparsity(self, 0.5, model)

        self._check_strip_pruning_matches_original(model, 0.5)
  def testWrapperWeights(self, wrapper):
    """Checks that a wrapper cell reports the variables of the cell it wraps.

    Args:
      wrapper: Callable applied to a `SimpleRNNCell`. Its result is used as
        the cell of an `RNN` layer and must expose the wrapped cell as
        `.cell`.
    """

    def names_of(variables):
      # Variable names are the only thing compared below.
      return [variable.name for variable in variables]

    wrapped_cell = wrapper(layers.SimpleRNNCell(1, name="basic_rnn_cell"))
    layer = layers.RNN(wrapped_cell)
    # The layer is called once before any weights are inspected
    # (presumably this is what creates the variables -- confirm).
    layer(ops.convert_to_tensor([[[1]]], dtype=dtypes.float32))

    expected_names = [
        "rnn/" + suffix
        for suffix in ("kernel:0", "recurrent_kernel:0", "bias:0")
    ]

    self.assertLen(wrapped_cell.weights, 3)
    self.assertCountEqual(names_of(wrapped_cell.weights), expected_names)
    self.assertCountEqual(
        names_of(wrapped_cell.trainable_variables), expected_names)
    self.assertCountEqual(
        names_of(wrapped_cell.non_trainable_variables), [])
    # The inner cell must expose exactly the same variables.
    self.assertCountEqual(names_of(wrapped_cell.cell.weights), expected_names)