예제 #1
0
    def testRNNLayersWithRNNCellParams(self):
        """Prunes an RNN layer over stacked cells and checks sparsity after fit.

        The RNN stacks LSTM, GRU, peephole-LSTM and simple-RNN cells (10 units
        each). Model sparsity is asserted to be 0.0 right after compiling and
        0.5 after one `fit` call that uses the `UpdatePruningStep` callback;
        `_check_strip_pruning_matches_original` is then invoked with the same
        0.5 target.
        """
        model = keras.Sequential()
        stacked_cells = [
            layers.LSTMCell(10),
            layers.GRUCell(10),
            layers.PeepholeLSTMCell(10),
            layers.SimpleRNNCell(10),
        ]
        pruned_rnn = prune.prune_low_magnitude(
            keras.layers.RNN(stacked_cells),
            input_shape=(3, 4),
            **self.params)
        model.add(pruned_rnn)

        model.compile(
            loss='categorical_crossentropy',
            optimizer='sgd',
            metrics=['accuracy'])
        test_utils.assert_model_sparsity(self, 0.0, model)

        # Random features/targets shaped from the model's input/output shapes
        # via `_batch(..., 32)`; presumably `_batch` fills in the batch
        # dimension with 32 — confirm against the helper.
        input_shape = self._batch(model.input.get_shape().as_list(), 32)
        features = np.random.randn(*input_shape)
        output_shape = self._batch(model.output.get_shape().as_list(), 32)
        targets = np.random.randn(*output_shape)
        model.fit(
            features,
            targets,
            callbacks=[pruning_callbacks.UpdatePruningStep()])

        test_utils.assert_model_sparsity(self, 0.5, model)

        self._check_strip_pruning_matches_original(model, 0.5)
예제 #2
0
  def testDropoutWrapperSerialization(self):
    """Round-trips DropoutWrapper configs, with and without a state filter.

    Three wrappers around the same GRU cell are serialized with `get_config`
    and rebuilt with `from_config`:
      * the default wrapper, whose reconstructed config must equal the
        original config;
      * a wrapper with a lambda `dropout_state_filter_visitor` returning True;
      * a wrapper with a named-function visitor returning False.
    Every reconstruction is asserted to be a `DropoutWrapper`. The two filter
    variants also assert that the rebuilt `_dropout_state_filter` returns the
    same value as the visitor it was built from.
    """
    wrapper_cls = rnn_cell_wrapper_v2.DropoutWrapper
    cell = layers.GRUCell(10)

    def round_trip(wrapper):
      """Rebuilds `wrapper` from its config; returns (config, rebuilt)."""
      config = wrapper.get_config()
      reconstructed = wrapper_cls.from_config(config)
      # Type check applied to every round trip, not only the default one.
      self.assertIsInstance(reconstructed, wrapper_cls)
      return config, reconstructed

    # Default wrapper: the rebuilt wrapper must report an identical config.
    config, reconstructed_wrapper = round_trip(wrapper_cls(cell))
    self.assertDictEqual(config, reconstructed_wrapper.get_config())

    # Lambda visitor: must survive serialization and still return True.
    _, reconstructed_wrapper = round_trip(
        wrapper_cls(cell, dropout_state_filter_visitor=lambda s: True))
    self.assertTrue(reconstructed_wrapper._dropout_state_filter(None))

    def dropout_state_filter_visitor(unused_state):
      return False

    # Named-function visitor: must survive serialization and return False.
    _, reconstructed_wrapper = round_trip(
        wrapper_cls(
            cell, dropout_state_filter_visitor=dropout_state_filter_visitor))
    self.assertFalse(reconstructed_wrapper._dropout_state_filter(None))
    def testReturnsProvider_KerasRNNLayerWithKerasRNNCells(self):
        """Checks the quantize provider for an RNN layer over LSTM + GRU cells.

        The provider obtained from the registry must list, in cell order, each
        cell's `kernel` and `recurrent_kernel` as weights and each cell's
        `activation` and `recurrent_activation` as activations. The paired
        quantizers are handed to `_assert_weight_quantizers` and
        `_assert_activation_quantizers` respectively.
        """
        lstm_cell = l.LSTMCell(3)
        gru_cell = l.GRUCell(2)
        rnn_layer = l.RNN([lstm_cell, gru_cell], input_shape=(3, 2))
        model = keras.Sequential([rnn_layer])
        layer = model.layers[0]

        quantize_provider = self.quantize_registry.get_quantize_provider(layer)

        raw_weights = quantize_provider.get_weights_and_quantizers(layer)
        weights, weight_quantizers = self._convert_list(raw_weights)
        raw_activations = quantize_provider.get_activations_and_quantizers(
            layer)
        activations, activation_quantizers = self._convert_list(
            raw_activations)

        self._assert_weight_quantizers(weight_quantizers)
        expected_weights = [
            lstm_cell.kernel,
            lstm_cell.recurrent_kernel,
            gru_cell.kernel,
            gru_cell.recurrent_kernel,
        ]
        self.assertEqual(expected_weights, weights)

        self._assert_activation_quantizers(activation_quantizers)
        expected_activations = [
            lstm_cell.activation,
            lstm_cell.recurrent_activation,
            gru_cell.activation,
            gru_cell.recurrent_activation,
        ]
        self.assertEqual(expected_activations, activations)
    def setUp(self):
        """Builds a two-cell (LSTM + GRU) RNN layer and an RNN quantize provider."""
        super(TFLiteQuantizeProviderRNNTest, self).setUp()

        self.cell1 = l.LSTMCell(3)
        self.cell2 = l.GRUCell(2)
        self.layer = l.RNN([self.cell1, self.cell2])
        self.layer.build(input_shape=(3, 2))

        # Two identical (but distinct) attribute-name lists each; presumably
        # one list per cell built above — confirm against the provider.
        per_cell_weight_attrs = [
            ['kernel', 'recurrent_kernel'] for _ in range(2)]
        per_cell_activation_attrs = [
            ['activation', 'recurrent_activation'] for _ in range(2)]
        # NOTE(review): the meaning of the trailing positional `False` is not
        # visible here — check TFLiteQuantizeProviderRNN's signature.
        self.quantize_provider = (
            tflite_quantize_registry.TFLiteQuantizeProviderRNN(
                per_cell_weight_attrs, per_cell_activation_attrs, False))
예제 #5
0
 def testSupports_KerasRNNLayerWithKerasRNNCells(self):
   """Checks the registry supports RNN layers over one cell or a cell list."""
   single_cell_rnn = l.RNN(cell=l.LSTMCell(10))
   self.assertTrue(self.quantize_registry.supports(single_cell_rnn))

   stacked_cells_rnn = l.RNN(cell=[l.LSTMCell(10), l.GRUCell(10)])
   self.assertTrue(self.quantize_registry.supports(stacked_cells_rnn))