def testDoesNotSupport_RNNLayerWithCustomRNNCell(self):
        """RNN layers built around user-defined cells must be rejected."""

        class MinimalRNNCell(l.Layer):
            """Bare-bones custom cell used only to exercise the negative path."""

            def __init__(self, units, **kwargs):
                self.units = units
                self.state_size = units
                super(MinimalRNNCell, self).__init__(**kwargs)

        # A custom cell makes the whole RNN unsupported, whether it is the
        # only cell or mixed in with built-in Keras cells.
        unsupported_layers = [
            l.RNN(cell=MinimalRNNCell(10)),
            l.RNN(cell=[l.LSTMCell(10), MinimalRNNCell(10)]),
        ]
        for rnn_layer in unsupported_layers:
            self.assertFalse(self.quantize_registry.supports(rnn_layer))
# Beispiel #2 (example separator left over from the original source dump)
def _make_rnn_layer(rnn_cell_fn, units, cell_type, return_sequences):
  """Assert arguments are valid and return a Keras RNN layer.

  Args:
    rnn_cell_fn: A function that returns a RNN cell instance that will be used
      to construct the RNN. May be `None`, in which case a cell factory is
      derived from `units` and `cell_type`.
    units: Iterable of integer number of hidden units per RNN layer.
    cell_type: A class producing a RNN cell or a string specifying the cell
      type.
    return_sequences: A boolean indicating whether to return the last output
      in the output sequence, or the full sequence.

  Returns:
    A tf.keras.layers.RNN layer.
  """
  _verify_rnn_cell_input(rnn_cell_fn, units, cell_type)
  # Fast path: a known cell-type key with a scalar unit count maps directly to
  # a dedicated Keras layer (e.g. LSTM/GRU) instead of a generic RNN wrapper.
  if cell_type in _CELL_TYPE_TO_LAYER_MAPPING and isinstance(units, int):
    return _CELL_TYPE_TO_LAYER_MAPPING[cell_type](
        units=units, return_sequences=return_sequences)
  if not rnn_cell_fn:
    # No explicit factory supplied: build one from units/cell_type, falling
    # back to the simple RNN cell when the caller left the type unspecified.
    if cell_type == USE_DEFAULT:
      cell_type = _SIMPLE_RNN_KEY
    rnn_cell_fn = _make_rnn_cell_fn(units, cell_type)

  return keras_layers.RNN(cell=rnn_cell_fn(), return_sequences=return_sequences)
    def testReturnsProvider_KerasRNNLayerWithKerasRNNCells(self):
        """Provider exposes weights/activations of every cell in the RNN."""
        lstm = l.LSTMCell(3)
        gru = l.GRUCell(2)
        model = keras.Sequential([l.RNN([lstm, gru], input_shape=(3, 2))])
        rnn_layer = model.layers[0]

        provider = self.quantize_registry.get_quantize_provider(rnn_layer)

        weights, weight_quantizers = self._convert_list(
            provider.get_weights_and_quantizers(rnn_layer))
        activations, activation_quantizers = self._convert_list(
            provider.get_activations_and_quantizers(rnn_layer))

        self._assert_weight_quantizers(weight_quantizers)
        # Weights are reported per cell, in cell order.
        expected_weights = [
            lstm.kernel, lstm.recurrent_kernel,
            gru.kernel, gru.recurrent_kernel,
        ]
        self.assertEqual(expected_weights, weights)

        self._assert_activation_quantizers(activation_quantizers)
        # Activations follow the same per-cell ordering as the weights.
        expected_activations = [
            lstm.activation, lstm.recurrent_activation,
            gru.activation, gru.recurrent_activation,
        ]
        self.assertEqual(expected_activations, activations)
    def setUp(self):
        """Builds a two-cell RNN layer and a matching quantize provider."""
        super(TFLiteQuantizeProviderRNNTest, self).setUp()

        self.cell1 = l.LSTMCell(3)
        self.cell2 = l.GRUCell(2)
        self.layer = l.RNN([self.cell1, self.cell2])
        self.layer.build(input_shape=(3, 2))

        # Both cells quantize the same attribute names; copy the lists so
        # each cell gets its own independent list instance.
        weight_attrs = ['kernel', 'recurrent_kernel']
        activation_attrs = ['activation', 'recurrent_activation']
        self.quantize_provider = tflite_quantize_registry.TFLiteQuantizeProviderRNN(
            [list(weight_attrs), list(weight_attrs)],
            [list(activation_attrs), list(activation_attrs)], False)
 def _rnn_input(apply_wrapper):
   """Creates a RNN layer with/without wrapper and returns built rnn cell.

   Args:
     apply_wrapper: If True, the stacked cell is wrapped in DropoutWrapperV2
       before being handed to the Keras RNN layer.

   Returns:
     The first BasicRNNCell inside the MultiRNNCell, after the layer has been
     called once so its variables are built.
   """
   with base_layer.keras_style_scope():
     stacked_cell = rnn_cell_impl.MultiRNNCell(
         [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)])
   cell = (rnn_cell_impl.DropoutWrapperV2(stacked_cell)
           if apply_wrapper else stacked_cell)
   layer = keras_layers.RNN(cell)
   # One forward pass forces variable creation on the inner cells.
   layer(ops.convert_to_tensor([[[1]]], dtype=dtypes.float32))
   return stacked_cell._cells[0]
  def testWrapperWeights(self, wrapper):
    """Tests that wrapper weights contain wrapped cells weights."""
    inner_cell = layers.SimpleRNNCell(1, name="basic_rnn_cell")
    wrapped_cell = wrapper(inner_cell)
    rnn_layer = layers.RNN(wrapped_cell)
    # A single forward pass builds the variables of the wrapped cell.
    rnn_layer(ops.convert_to_tensor([[[1]]], dtype=dtypes.float32))

    expected_weights = ["rnn/kernel:0", "rnn/recurrent_kernel:0", "rnn/bias:0"]
    self.assertLen(wrapped_cell.weights, 3)
    # The wrapper must surface the inner cell's variables through every
    # variable-collection property, and report nothing as non-trainable.
    self.assertCountEqual(
        [v.name for v in wrapped_cell.weights], expected_weights)
    self.assertCountEqual(
        [v.name for v in wrapped_cell.trainable_variables], expected_weights)
    self.assertCountEqual(
        [v.name for v in wrapped_cell.non_trainable_variables], [])
    self.assertCountEqual(
        [v.name for v in wrapped_cell.cell.weights], expected_weights)
# Beispiel #7 (example separator left over from the original source dump)
 def testSupports_KerasRNNLayerWithKerasRNNCells(self):
   """RNN layers built purely from built-in Keras cells are supported."""
   supported_layers = (
       l.RNN(cell=l.LSTMCell(10)),
       l.RNN(cell=[l.LSTMCell(10), l.GRUCell(10)]),
   )
   for rnn_layer in supported_layers:
     self.assertTrue(self.quantize_registry.supports(rnn_layer))
# Beispiel #8 (example separator left over from the original source dump)
 def _rnn_layer_fn():  # pylint: disable=function-redefined
     # Build a fresh cell via the closed-over factory each time the layer
     # function is invoked.
     cell = rnn_cell_fn()
     return keras_layers.RNN(cell=cell, return_sequences=return_sequences)