def testRNNLayersWithRNNCellParams(self):
    """Prunes an RNN layer built from four heterogeneous cells.

    The RNN stacks an LSTM, GRU, peephole LSTM and simple RNN cell. The
    model's sparsity is asserted to be 0.0 before training and 0.5 after one
    `fit` call that runs the `UpdatePruningStep` callback. The model is then
    handed to `_check_strip_pruning_matches_original` (defined elsewhere)
    with the same 0.5 target.
    """
    model = keras.Sequential()
    stacked_cells = [
        layers.LSTMCell(10),
        layers.GRUCell(10),
        layers.PeepholeLSTMCell(10),
        layers.SimpleRNNCell(10),
    ]
    model.add(
        prune.prune_low_magnitude(
            keras.layers.RNN(stacked_cells),
            input_shape=(3, 4),
            **self.params))
    model.compile(
        loss='categorical_crossentropy',
        optimizer='sgd',
        metrics=['accuracy'])

    test_utils.assert_model_sparsity(self, 0.0, model)

    # Random data shaped from the model's input/output via `self._batch`
    # (defined elsewhere; presumably it fills in a batch size of 32 — verify).
    # Inputs are drawn before targets so RNG consumption order is unchanged.
    x_train = np.random.randn(
        *self._batch(model.input.get_shape().as_list(), 32))
    y_train = np.random.randn(
        *self._batch(model.output.get_shape().as_list(), 32))
    model.fit(
        x_train,
        y_train,
        callbacks=[pruning_callbacks.UpdatePruningStep()])

    test_utils.assert_model_sparsity(self, 0.5, model)
    self._check_strip_pruning_matches_original(model, 0.5)
def testDropoutWrapperSerialization(self):
    """Round-trips `DropoutWrapper` through `get_config`/`from_config`.

    Three configurations are covered: no state filter, a lambda filter that
    returns True, and a named function filter that returns False. Every
    reconstruction is checked to be an instance of the wrapper class. The
    unfiltered case also checks that the config is unchanged by the round
    trip, and the filtered cases check that the rebuilt filter still returns
    the original function's value.
    """
    wrapper_cls = rnn_cell_wrapper_v2.DropoutWrapper
    cell = layers.GRUCell(10)

    def round_trip(wrapper):
        # Serialize, rebuild from the config, and verify the rebuilt type so
        # every case (not only the first) checks it.
        config = wrapper.get_config()
        reconstructed = wrapper_cls.from_config(config)
        self.assertIsInstance(reconstructed, wrapper_cls)
        return config, reconstructed

    # No filter: the config must survive the round trip unchanged.
    config, reconstructed_wrapper = round_trip(wrapper_cls(cell))
    self.assertDictEqual(config, reconstructed_wrapper.get_config())

    # Lambda filter: the deserialized filter must still return True.
    _, reconstructed_wrapper = round_trip(
        wrapper_cls(cell, dropout_state_filter_visitor=lambda s: True))
    self.assertTrue(reconstructed_wrapper._dropout_state_filter(None))

    def dropout_state_filter_visitor(unused_state):
        return False

    # Named function filter: the deserialized filter must still return False.
    _, reconstructed_wrapper = round_trip(
        wrapper_cls(
            cell, dropout_state_filter_visitor=dropout_state_filter_visitor))
    self.assertFalse(reconstructed_wrapper._dropout_state_filter(None))
def testReturnsProvider_KerasRNNLayerWithKerasRNNCells(self):
    """Checks the provider returned for an RNN over [LSTMCell, GRUCell].

    The expected weights are each cell's `kernel` and `recurrent_kernel`. The
    expected activations are each cell's `activation` and
    `recurrent_activation`. Both are listed cell by cell in stacking order.
    """
    lstm_cell = l.LSTMCell(3)
    gru_cell = l.GRUCell(2)
    model = keras.Sequential(
        [l.RNN([lstm_cell, gru_cell], input_shape=(3, 2))])
    layer = model.layers[0]

    provider = self.quantize_registry.get_quantize_provider(layer)
    weights, weight_quantizers = self._convert_list(
        provider.get_weights_and_quantizers(layer))
    activations, activation_quantizers = self._convert_list(
        provider.get_activations_and_quantizers(layer))

    self._assert_weight_quantizers(weight_quantizers)
    expected_weights = [
        lstm_cell.kernel,
        lstm_cell.recurrent_kernel,
        gru_cell.kernel,
        gru_cell.recurrent_kernel,
    ]
    self.assertEqual(expected_weights, weights)

    self._assert_activation_quantizers(activation_quantizers)
    expected_activations = [
        lstm_cell.activation,
        lstm_cell.recurrent_activation,
        gru_cell.activation,
        gru_cell.recurrent_activation,
    ]
    self.assertEqual(expected_activations, activations)
def setUp(self):
    """Builds a two-cell RNN layer and the RNN quantize provider under test.

    Sets `self.cell1` (LSTMCell(3)), `self.cell2` (GRUCell(2)), `self.layer`
    (an RNN over both cells, built with input_shape=(3, 2)), and
    `self.quantize_provider`.
    """
    super(TFLiteQuantizeProviderRNNTest, self).setUp()

    self.cell1 = l.LSTMCell(3)
    self.cell2 = l.GRUCell(2)
    cells = [self.cell1, self.cell2]
    self.layer = l.RNN(cells)
    self.layer.build(input_shape=(3, 2))

    # One attribute-name list per cell (presumably matched to the cells by
    # position — confirm against TFLiteQuantizeProviderRNN). Each
    # comprehension yields distinct, equal lists, as the literals did.
    weight_attrs = [['kernel', 'recurrent_kernel'] for _ in cells]
    activation_attrs = [
        ['activation', 'recurrent_activation'] for _ in cells]
    self.quantize_provider = (
        tflite_quantize_registry.TFLiteQuantizeProviderRNN(
            weight_attrs, activation_attrs, False))
def testSupports_KerasRNNLayerWithKerasRNNCells(self):
    """Checks the registry supports RNN layers over one cell and a cell list."""
    single_cell_rnn = l.RNN(cell=l.LSTMCell(10))
    self.assertTrue(self.quantize_registry.supports(single_cell_rnn))

    stacked_cells_rnn = l.RNN(cell=[l.LSTMCell(10), l.GRUCell(10)])
    self.assertTrue(self.quantize_registry.supports(stacked_cells_rnn))