Exemple #1
0
    def testRNNLayersWithRNNCellParams(self):
        """Prunes a Keras RNN built from a list of heterogeneous cells.

        The model must report 0.0 sparsity before training and 0.5 after one
        `fit` call with the `UpdatePruningStep` callback; it is then passed to
        `_check_strip_pruning_matches_original` with the same 0.5 target.
        """
        model = keras.Sequential()
        stacked_cells = [
            layers.LSTMCell(10),
            layers.GRUCell(10),
            layers.PeepholeLSTMCell(10),
            layers.SimpleRNNCell(10),
        ]
        pruned_rnn = prune.prune_low_magnitude(
            keras.layers.RNN(stacked_cells), input_shape=(3, 4), **self.params)
        model.add(pruned_rnn)

        model.compile(
            loss='categorical_crossentropy',
            optimizer='sgd',
            metrics=['accuracy'])
        # Nothing has been pruned yet.
        test_utils.assert_model_sparsity(self, 0.0, model)

        # Random features/targets shaped like the model I/O with batch 32.
        # Features are drawn before targets so the RNG sequence is unchanged.
        features = np.random.randn(
            *self._batch(model.input.get_shape().as_list(), 32))
        targets = np.random.randn(
            *self._batch(model.output.get_shape().as_list(), 32))
        model.fit(features, targets,
                  callbacks=[pruning_callbacks.UpdatePruningStep()])

        test_utils.assert_model_sparsity(self, 0.5, model)

        self._check_strip_pruning_matches_original(model, 0.5)
  def testDropoutWrapperSerialization(self):
    """Round-trips DropoutWrapper configs, including state-filter callables.

    A GRUCell is used as the wrapped cell because DropoutWrapper raises a
    ValueError ("does not work with ...") when given a Keras LSTM cell, as
    asserted by testDropoutWrapperWithKerasLSTMCell.
    """
    wrapper_cls = rnn_cell_wrapper_v2.DropoutWrapper
    cell = layers.GRUCell(10)
    wrapper = wrapper_cls(cell)
    config = wrapper.get_config()

    # The default config must rebuild an identical wrapper.
    reconstructed_wrapper = wrapper_cls.from_config(config)
    self.assertDictEqual(config, reconstructed_wrapper.get_config())
    self.assertIsInstance(reconstructed_wrapper, wrapper_cls)

    # A lambda state filter must survive serialization.
    wrapper = wrapper_cls(cell, dropout_state_filter_visitor=lambda s: True)
    config = wrapper.get_config()

    reconstructed_wrapper = wrapper_cls.from_config(config)
    self.assertTrue(reconstructed_wrapper._dropout_state_filter(None))

    # So must a locally defined named function.
    def dropout_state_filter_visitor(unused_state):
      return False

    wrapper = wrapper_cls(
        cell, dropout_state_filter_visitor=dropout_state_filter_visitor)
    config = wrapper.get_config()

    reconstructed_wrapper = wrapper_cls.from_config(config)
    self.assertFalse(reconstructed_wrapper._dropout_state_filter(None))
    def __init__(self, state_size, action_size):
        """Create the layers of the actor-critic network.

        Args:
            state_size: Size of the state/observation; stored on the instance.
            action_size: Number of action dimensions; sets the width of the
                `actions_mean` and `actions_sigma` heads.
        """
        super(ActorCriticModel, self).__init__()
        self.state_size = state_size
        self.action_size = action_size

        # Symmetric uniform init for the action-mean head; He-normal elsewhere.
        w_uniform_init = keras.initializers.RandomUniform(
            minval=-0.99, maxval=0.99, seed=None)
        w_he = keras.initializers.he_normal(seed=None)

        # MLP trunk. No activation is set on these Dense layers here;
        # presumably one is applied in call() -- TODO confirm.
        self.fc1 = layers.Dense(256, kernel_initializer=w_he)
        self.fc2 = layers.Dense(256, kernel_initializer=w_he)
        self.fc3 = layers.Dense(128, kernel_initializer=w_he)
        self.fc4 = layers.Dense(128, kernel_initializer=w_he)

        # Recurrent core. Its initial state is created for a batch of one.
        self.lstm = layers.LSTMCell(128)
        self.state = self.lstm.get_initial_state(batch_size=1, dtype=tf.float32)

        # Policy heads: softsign keeps the mean in (-1, 1); softplus keeps
        # sigma strictly positive.
        self.actions_mean = layers.Dense(
            action_size, activation=tf.nn.softsign,
            kernel_initializer=w_uniform_init)
        self.actions_sigma = layers.Dense(
            action_size, activation=tf.nn.softplus, kernel_initializer=w_he)

        # Value head: a single linear output.
        self.values = layers.Dense(1, kernel_initializer=w_he)
  def testResidualWrapperSerialization(self):
    """Checks that ResidualWrapper configs round-trip, including residual_fn.

    The default wrapper's config must rebuild an identical wrapper. For a
    custom residual function (a lambda and a nested def), the rebuilt
    `_residual_fn` must produce the same value on inputs (1, 2).
    """
    wrapper_cls = rnn_cell_wrapper_v2.ResidualWrapper
    cell = layers.LSTMCell(10)

    base_config = wrapper_cls(cell).get_config()
    restored = wrapper_cls.from_config(base_config)
    self.assertDictEqual(base_config, restored.get_config())
    self.assertIsInstance(restored, wrapper_cls)

    def residual_fn(inputs, outputs):
      return inputs * 3 + outputs

    # Each case: (custom residual function, expected fn(1, 2) after rebuild).
    cases = (
        (lambda i, o: i + i + o, 4),
        (residual_fn, 5),
    )
    for custom_fn, expected in cases:
      custom_config = wrapper_cls(cell, residual_fn=custom_fn).get_config()
      restored = wrapper_cls.from_config(custom_config)
      # The deserialized function must still do the same arithmetic.
      self.assertEqual(restored._residual_fn(1, 2), expected)
    def testReturnsProvider_KerasRNNLayerWithKerasRNNCells(self):
        """The provider for a stacked LSTM+GRU RNN exposes per-cell tensors.

        Expects each cell's kernel and recurrent_kernel as weights, and each
        cell's activation and recurrent_activation as activations, in cell
        order.
        """
        lstm_cell = l.LSTMCell(3)
        gru_cell = l.GRUCell(2)
        model = keras.Sequential(
            [l.RNN([lstm_cell, gru_cell], input_shape=(3, 2))])
        rnn_layer = model.layers[0]

        provider = self.quantize_registry.get_quantize_provider(rnn_layer)
        weights, weight_quantizers = self._convert_list(
            provider.get_weights_and_quantizers(rnn_layer))
        activations, activation_quantizers = self._convert_list(
            provider.get_activations_and_quantizers(rnn_layer))

        expected_weights = [
            lstm_cell.kernel,
            lstm_cell.recurrent_kernel,
            gru_cell.kernel,
            gru_cell.recurrent_kernel,
        ]
        expected_activations = [
            lstm_cell.activation,
            lstm_cell.recurrent_activation,
            gru_cell.activation,
            gru_cell.recurrent_activation,
        ]

        self._assert_weight_quantizers(weight_quantizers)
        self.assertEqual(expected_weights, weights)

        self._assert_activation_quantizers(activation_quantizers)
        self.assertEqual(expected_activations, activations)
  def testDeviceWrapperSerialization(self):
    """A DeviceWrapper's config must rebuild an equivalent DeviceWrapper."""
    wrapper_cls = rnn_cell_wrapper_v2.DeviceWrapper
    pinned = wrapper_cls(layers.LSTMCell(10), "/cpu:0")
    pinned_config = pinned.get_config()

    rebuilt = wrapper_cls.from_config(pinned_config)
    self.assertDictEqual(pinned_config, rebuilt.get_config())
    self.assertIsInstance(rebuilt, wrapper_cls)
Exemple #7
0
  def testDropoutWrapperWithKerasLSTMCell(self):
    """DropoutWrapper must refuse to wrap Keras LSTM cells (v1 and v2).

    Uses assertRaisesRegex. The deprecated assertRaisesRegexp alias was
    removed from unittest in Python 3.12.
    """
    wrapper_cls = rnn_cell_wrapper_v2.DropoutWrapper
    cell = layers.LSTMCell(10)

    with self.assertRaisesRegex(ValueError, "does not work with "):
      wrapper_cls(cell)

    cell = layers.LSTMCellV2(10)
    with self.assertRaisesRegex(ValueError, "does not work with "):
      wrapper_cls(cell)
    def setUp(self):
        """Builds a two-cell (LSTM + GRU) RNN and a matching RNN provider.

        The provider receives, per cell, the weight attribute names and the
        activation attribute names. The trailing positional `False` is passed
        through unchanged (its meaning is defined by
        TFLiteQuantizeProviderRNN -- confirm there).
        """
        super(TFLiteQuantizeProviderRNNTest, self).setUp()

        self.cell1 = l.LSTMCell(3)
        self.cell2 = l.GRUCell(2)
        self.layer = l.RNN([self.cell1, self.cell2])
        self.layer.build(input_shape=(3, 2))

        rnn_cells = (self.cell1, self.cell2)
        # One fresh list per cell, so no list object is shared between cells.
        weight_attrs_per_cell = [
            ['kernel', 'recurrent_kernel'] for _ in rnn_cells]
        activation_attrs_per_cell = [
            ['activation', 'recurrent_activation'] for _ in rnn_cells]

        self.quantize_provider = tflite_quantize_registry.TFLiteQuantizeProviderRNN(
            weight_attrs_per_cell, activation_attrs_per_cell, False)
    def testDoesNotSupport_RNNLayerWithCustomRNNCell(self):
        """The registry rejects RNN layers that contain a custom cell.

        Covers a lone custom cell and a custom cell stacked after a Keras
        LSTMCell.
        """
        class MinimalRNNCell(l.Layer):
            # Bare-bones custom cell: only `units` and `state_size` are set.
            def __init__(self, units, **kwargs):
                self.units = units
                self.state_size = units
                super(MinimalRNNCell, self).__init__(**kwargs)

        registry = self.quantize_registry

        lone_custom_rnn = l.RNN(cell=MinimalRNNCell(10))
        self.assertFalse(registry.supports(lone_custom_rnn))

        mixed_stack_rnn = l.RNN(cell=[l.LSTMCell(10), MinimalRNNCell(10)])
        self.assertFalse(registry.supports(mixed_stack_rnn))
Exemple #10
0
 def testSupports_KerasRNNLayerWithKerasRNNCells(self):
   """The registry accepts RNN layers built only from Keras LSTM/GRU cells."""
   supports = self.quantize_registry.supports

   single_cell_rnn = l.RNN(cell=l.LSTMCell(10))
   self.assertTrue(supports(single_cell_rnn))

   stacked_rnn = l.RNN(cell=[l.LSTMCell(10), l.GRUCell(10)])
   self.assertTrue(supports(stacked_rnn))
Exemple #11
0
    def _testDynamicDecodeRNN(self,
                              time_major,
                              has_attention,
                              with_alignment_history=False):
        """Run a BeamSearchDecoder end to end and check output types/shapes.

        Builds an LSTMCell, optionally wrapped in an AttentionWrapper with
        Bahdanau attention over random, beam-tiled memory. It then decodes
        from a random embedding matrix and checks the result types, the
        static output shapes, and the evaluated output shapes.

        Args:
            time_major: Forwarded as `output_time_major`. When true, the
                expected shapes have their first two dimensions swapped.
            has_attention: If true, wrap the cell in an AttentionWrapper and
                use a coverage penalty weight of 0.2 (0.0 otherwise).
            with_alignment_history: Forwarded to AttentionWrapper as
                `alignment_history`. Only used when `has_attention` is true.
        """
        encoder_sequence_length = np.array([3, 2, 3, 1, 1])
        # Only used below to derive `max_out`, the cap on decoding steps.
        decoder_sequence_length = np.array([2, 0, 1, 2, 3])
        batch_size = 5
        decoder_max_time = 4
        input_depth = 7
        cell_depth = 9
        attention_depth = 6
        vocab_size = 20
        end_token = vocab_size - 1
        start_token = 0
        embedding_dim = 50
        # max([2, 0, 1, 2, 3]) == 3 -> passed as maximum_iterations.
        max_out = max(decoder_sequence_length)
        output_layer = layers.Dense(vocab_size, use_bias=True, activation=None)
        beam_width = 3

        with self.cached_session():
            batch_size_tensor = tf.constant(batch_size)
            # Random (vocab_size, embedding_dim) matrix, passed to the decoder
            # call as its embedding.
            embedding = np.random.randn(vocab_size,
                                        embedding_dim).astype(np.float32)
            cell = layers.LSTMCell(cell_depth)
            # Un-tiled (batch_size) LSTM state. It is only consumed in the
            # attention branch, after being tiled to batch_size * beam_width.
            initial_state = cell.get_initial_state(batch_size=batch_size,
                                                   dtype=tf.float32)
            coverage_penalty_weight = 0.0
            if has_attention:
                coverage_penalty_weight = 0.2
                # Random attention memory of shape
                # (batch_size, decoder_max_time, input_depth). The placeholder
                # leaves batch and time dimensions statically unknown.
                inputs = tf.compat.v1.placeholder_with_default(
                    np.random.randn(batch_size, decoder_max_time,
                                    input_depth).astype(np.float32),
                    shape=(None, None, input_depth))
                # Tile memory and its lengths beam_width times so they line up
                # with the beam-expanded batch.
                tiled_inputs = beam_search_decoder.tile_batch(
                    inputs, multiplier=beam_width)
                tiled_sequence_length = beam_search_decoder.tile_batch(
                    encoder_sequence_length, multiplier=beam_width)
                attention_mechanism = attention_wrapper.BahdanauAttention(
                    units=attention_depth,
                    memory=tiled_inputs,
                    memory_sequence_length=tiled_sequence_length)
                initial_state = beam_search_decoder.tile_batch(
                    initial_state, multiplier=beam_width)
                cell = attention_wrapper.AttentionWrapper(
                    cell=cell,
                    attention_mechanism=attention_mechanism,
                    attention_layer_size=attention_depth,
                    alignment_history=with_alignment_history)
            # Decoder state sized for the beam-expanded batch.
            cell_state = cell.get_initial_state(batch_size=batch_size_tensor *
                                                beam_width,
                                                dtype=tf.float32)
            if has_attention:
                # clone(cell_state=...) swaps in the tiled LSTM state computed
                # above as the wrapped cell's state.
                cell_state = cell_state.clone(cell_state=initial_state)
            bsd = beam_search_decoder.BeamSearchDecoder(
                cell=cell,
                beam_width=beam_width,
                output_layer=output_layer,
                length_penalty_weight=0.0,
                coverage_penalty_weight=coverage_penalty_weight,
                output_time_major=time_major,
                maximum_iterations=max_out)

            final_outputs, final_state, final_sequence_lengths = bsd(
                embedding,
                start_tokens=tf.fill([batch_size_tensor], start_token),
                end_token=end_token,
                initial_state=cell_state)

            def _t(shape):
                # Convert a batch-major expected shape tuple to time-major
                # when the decoder was asked for time-major outputs.
                if time_major:
                    return (shape[1], shape[0]) + shape[2:]
                return shape

            self.assertIsInstance(
                final_outputs,
                beam_search_decoder.FinalBeamSearchDecoderOutput)
            self.assertIsInstance(final_state,
                                  beam_search_decoder.BeamSearchDecoderState)

            beam_search_decoder_output = \
                final_outputs.beam_search_decoder_output
            # Static time dimension: 3 (== max_out) is expected when running
            # eagerly. In graph mode it is unknown (None).
            expected_seq_length = 3 if context.executing_eagerly() else None
            self.assertEqual(
                _t((batch_size, expected_seq_length, beam_width)),
                tuple(beam_search_decoder_output.scores.get_shape().as_list()))
            self.assertEqual(
                _t((batch_size, expected_seq_length, beam_width)),
                tuple(final_outputs.predicted_ids.get_shape().as_list()))

            self.evaluate(tf.compat.v1.global_variables_initializer())
            eval_results = self.evaluate({
                'final_outputs':
                final_outputs,
                'final_sequence_lengths':
                final_sequence_lengths
            })

            max_sequence_length = np.max(
                eval_results['final_sequence_lengths'])

            # A smoke test: evaluated shapes use the actual longest decoded
            # sequence length as the time dimension.
            self.assertEqual(
                _t((batch_size, max_sequence_length, beam_width)),
                eval_results['final_outputs'].beam_search_decoder_output.
                scores.shape)
            self.assertEqual(
                _t((batch_size, max_sequence_length, beam_width)),
                eval_results['final_outputs'].beam_search_decoder_output.
                predicted_ids.shape)
Exemple #12
0
 def __init__(self, lstm_cell_num, keep_prob, initializer=keras.initializers.orthogonal):
     """Create an LSTM cell and a dropout wrapper around it.

     Args:
         lstm_cell_num: Number of units in the LSTM cell.
         keep_prob: Forwarded to the dropout wrapper as `output_keep_prob`
             (presumably the probability of keeping each output unit).
         initializer: Kernel initializer for the LSTM cell.
     """
     super(BasicLSTMCell, self).__init__()
     # NOTE(review): `forget_bias` is accepted by the tf.compat.v1 RNN cells
     # but not by tf.keras.layers.LSTMCell -- confirm which `layers` this is.
     self.lstm_cell = layers.LSTMCell(lstm_cell_num, kernel_initializer=initializer, forget_bias=0.0)
     # DropoutWrapper takes the cell to wrap as its first argument. Without
     # it, construction fails and `lstm_cell` is never wrapped.
     # NOTE(review): recent TensorFlow versions raise ValueError when a Keras
     # LSTM cell is wrapped by DropoutWrapper -- confirm against the installed
     # version.
     self.lstm_wrapper = layers.DropoutWrapper(self.lstm_cell, output_keep_prob=keep_prob)