def testRNNLayersWithRNNCellParams(self):
  model = keras.Sequential()
  model.add(
      prune.prune_low_magnitude(
          keras.layers.RNN([
              layers.LSTMCell(10),
              layers.GRUCell(10),
              layers.PeepholeLSTMCell(10),
              layers.SimpleRNNCell(10)
          ]),
          input_shape=(3, 4),
          **self.params))

  model.compile(
      loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

  test_utils.assert_model_sparsity(self, 0.0, model)
  model.fit(
      np.random.randn(*self._batch(model.input.get_shape().as_list(), 32)),
      np.random.randn(*self._batch(model.output.get_shape().as_list(), 32)),
      callbacks=[pruning_callbacks.UpdatePruningStep()])

  test_utils.assert_model_sparsity(self, 0.5, model)

  self._check_strip_pruning_matches_original(model, 0.5)
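# Hedged sketch (assumption, not shown in the original source): the _batch
# helper referenced above presumably swaps the unknown leading (batch)
# dimension of a shape list for a concrete size, along these lines:
def _batch(self, dims, batch_size):
  # E.g. [None, 3, 4] with batch_size=32 becomes [32, 3, 4].
  if dims[0] is None:
    dims[0] = batch_size
  return dims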
def testDropoutWrapperSerialization(self):
  wrapper_cls = rnn_cell_wrapper_v2.DropoutWrapper
  cell = layers.LSTMCell(10)
  wrapper = wrapper_cls(cell)
  config = wrapper.get_config()

  reconstructed_wrapper = wrapper_cls.from_config(config)
  self.assertDictEqual(config, reconstructed_wrapper.get_config())
  self.assertIsInstance(reconstructed_wrapper, wrapper_cls)

  wrapper = wrapper_cls(cell, dropout_state_filter_visitor=lambda s: True)
  config = wrapper.get_config()

  reconstructed_wrapper = wrapper_cls.from_config(config)
  self.assertTrue(reconstructed_wrapper._dropout_state_filter(None))

  def dropout_state_filter_visitor(unused_state):
    return False

  wrapper = wrapper_cls(
      cell, dropout_state_filter_visitor=dropout_state_filter_visitor)
  config = wrapper.get_config()

  reconstructed_wrapper = wrapper_cls.from_config(config)
  self.assertFalse(reconstructed_wrapper._dropout_state_filter(None))
def __init__(self, state_size, action_size):
  super(ActorCriticModel, self).__init__()
  self.state_size = state_size
  self.action_size = action_size  # self.action_size = 14

  # RandomNormal(mean, stddev) is the idiomatic spelling of the lowercase
  # `normal` alias used originally.
  w_init = keras.initializers.RandomNormal(mean=0.5, stddev=0.1)
  w_uniform_init = keras.initializers.RandomUniform(
      minval=-0.99, maxval=0.99, seed=None)
  w_pos_init = keras.initializers.RandomUniform(
      minval=0, maxval=0.99, seed=None)
  w_he = keras.initializers.he_normal(seed=None)

  # Conv model
  # self.conv1 = layers.Conv1D(32, 3, strides=1, padding="same")
  # self.conv2 = layers.Conv1D(32, 3, strides=1, padding="same")
  # self.conv3 = layers.Conv1D(64, 2, strides=1, padding="same")
  # self.conv4 = layers.Conv1D(64, 1, strides=1)

  # MLP trunk
  self.fc1 = layers.Dense(256, kernel_initializer=w_he)
  self.fc2 = layers.Dense(256, kernel_initializer=w_he)
  self.fc3 = layers.Dense(128, kernel_initializer=w_he)
  self.fc4 = layers.Dense(128, kernel_initializer=w_he)

  # Single-step LSTM cell; its state is kept on the model between calls.
  self.lstm = layers.LSTMCell(128)
  self.state = self.lstm.get_initial_state(batch_size=1, dtype=tf.float32)

  # self.dense1 = layers.Dense(200, activation=tf.nn.relu6)
  # Actor head: bounded means (softsign) and positive sigmas (softplus).
  self.actions_mean = layers.Dense(
      action_size,
      activation=tf.nn.softsign,
      kernel_initializer=w_uniform_init)
  self.actions_sigma = layers.Dense(
      action_size, activation=tf.nn.softplus, kernel_initializer=w_he)

  # self.dense2 = layers.Dense(100, activation=tf.nn.relu6)
  # Critic head: scalar state-value estimate.
  self.values = layers.Dense(1, kernel_initializer=w_he)
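# Hedged sketch (assumption, not part of the original source): a call()
# override showing how the fields above would typically be wired together.
# layers.LSTMCell processes one timestep and returns (output, new_states),
# so the stored state is threaded through successive calls.
def call(self, inputs):
  x = self.fc1(inputs)
  x = self.fc2(x)
  x = self.fc3(x)
  x = self.fc4(x)
  # LSTMCell call signature: cell(inputs, states) -> (output, [h, c]).
  x, self.state = self.lstm(x, self.state)
  mean = self.actions_mean(x)    # bounded action means
  sigma = self.actions_sigma(x)  # positive standard deviations
  value = self.values(x)         # critic value estimate
  return mean, sigma, value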
def testResidualWrapperSerialization(self):
  wrapper_cls = rnn_cell_wrapper_v2.ResidualWrapper
  cell = layers.LSTMCell(10)
  wrapper = wrapper_cls(cell)
  config = wrapper.get_config()

  reconstructed_wrapper = wrapper_cls.from_config(config)
  self.assertDictEqual(config, reconstructed_wrapper.get_config())
  self.assertIsInstance(reconstructed_wrapper, wrapper_cls)

  wrapper = wrapper_cls(cell, residual_fn=lambda i, o: i + i + o)
  config = wrapper.get_config()

  reconstructed_wrapper = wrapper_cls.from_config(config)
  # Assert the reconstructed function will perform the math correctly.
  self.assertEqual(reconstructed_wrapper._residual_fn(1, 2), 4)

  def residual_fn(inputs, outputs):
    return inputs * 3 + outputs

  wrapper = wrapper_cls(cell, residual_fn=residual_fn)
  config = wrapper.get_config()

  reconstructed_wrapper = wrapper_cls.from_config(config)
  # Assert the reconstructed function will perform the math correctly.
  self.assertEqual(reconstructed_wrapper._residual_fn(1, 2), 5)
def testReturnsProvider_KerasRNNLayerWithKerasRNNCells(self):
  lstm_cell = l.LSTMCell(3)
  gru_cell = l.GRUCell(2)
  model = keras.Sequential(
      [l.RNN([lstm_cell, gru_cell], input_shape=(3, 2))])
  layer = model.layers[0]

  quantize_provider = self.quantize_registry.get_quantize_provider(layer)

  (weights, weight_quantizers) = self._convert_list(
      quantize_provider.get_weights_and_quantizers(layer))
  (activations, activation_quantizers) = self._convert_list(
      quantize_provider.get_activations_and_quantizers(layer))

  self._assert_weight_quantizers(weight_quantizers)
  self.assertEqual([
      lstm_cell.kernel, lstm_cell.recurrent_kernel, gru_cell.kernel,
      gru_cell.recurrent_kernel
  ], weights)

  self._assert_activation_quantizers(activation_quantizers)
  self.assertEqual([
      lstm_cell.activation, lstm_cell.recurrent_activation,
      gru_cell.activation, gru_cell.recurrent_activation
  ], activations)
def testDeviceWrapperSerialization(self):
  wrapper_cls = rnn_cell_wrapper_v2.DeviceWrapper
  cell = layers.LSTMCell(10)
  wrapper = wrapper_cls(cell, "/cpu:0")
  config = wrapper.get_config()

  reconstructed_wrapper = wrapper_cls.from_config(config)
  self.assertDictEqual(config, reconstructed_wrapper.get_config())
  self.assertIsInstance(reconstructed_wrapper, wrapper_cls)
def testDropoutWrapperWithKerasLSTMCell(self):
  wrapper_cls = rnn_cell_wrapper_v2.DropoutWrapper
  cell = layers.LSTMCell(10)

  with self.assertRaisesRegexp(ValueError, "does not work with "):
    wrapper_cls(cell)

  cell = layers.LSTMCellV2(10)
  with self.assertRaisesRegexp(ValueError, "does not work with "):
    wrapper_cls(cell)
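# Hedged note (assumption, based on the incompatibility tested above): for
# Keras LSTM cells, dropout is configured on the cell itself rather than via
# DropoutWrapper, e.g.:
#   cell = layers.LSTMCellV2(10, dropout=0.2, recurrent_dropout=0.1)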
def setUp(self):
  super(TFLiteQuantizeProviderRNNTest, self).setUp()

  self.cell1 = l.LSTMCell(3)
  self.cell2 = l.GRUCell(2)
  self.layer = l.RNN([self.cell1, self.cell2])
  self.layer.build(input_shape=(3, 2))

  self.quantize_provider = tflite_quantize_registry.TFLiteQuantizeProviderRNN(
      [['kernel', 'recurrent_kernel'], ['kernel', 'recurrent_kernel']],
      [['activation', 'recurrent_activation'],
       ['activation', 'recurrent_activation']], False)
def testDoesNotSupport_RNNLayerWithCustomRNNCell(self):

  class MinimalRNNCell(l.Layer):

    def __init__(self, units, **kwargs):
      self.units = units
      self.state_size = units
      super(MinimalRNNCell, self).__init__(**kwargs)

  self.assertFalse(
      self.quantize_registry.supports(l.RNN(cell=MinimalRNNCell(10))))
  self.assertFalse(
      self.quantize_registry.supports(
          l.RNN(cell=[l.LSTMCell(10), MinimalRNNCell(10)])))
def testSupports_KerasRNNLayerWithKerasRNNCells(self):
  self.assertTrue(self.quantize_registry.supports(l.RNN(cell=l.LSTMCell(10))))
  self.assertTrue(
      self.quantize_registry.supports(
          l.RNN(cell=[l.LSTMCell(10), l.GRUCell(10)])))
def _testDynamicDecodeRNN(self, time_major, has_attention,
                          with_alignment_history=False):
  encoder_sequence_length = np.array([3, 2, 3, 1, 1])
  decoder_sequence_length = np.array([2, 0, 1, 2, 3])
  batch_size = 5
  decoder_max_time = 4
  input_depth = 7
  cell_depth = 9
  attention_depth = 6
  vocab_size = 20
  end_token = vocab_size - 1
  start_token = 0
  embedding_dim = 50
  max_out = max(decoder_sequence_length)
  output_layer = layers.Dense(vocab_size, use_bias=True, activation=None)
  beam_width = 3

  with self.cached_session():
    batch_size_tensor = tf.constant(batch_size)
    embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
    cell = layers.LSTMCell(cell_depth)
    initial_state = cell.get_initial_state(
        batch_size=batch_size, dtype=tf.float32)
    coverage_penalty_weight = 0.0
    if has_attention:
      coverage_penalty_weight = 0.2
      inputs = tf.compat.v1.placeholder_with_default(
          np.random.randn(batch_size, decoder_max_time,
                          input_depth).astype(np.float32),
          shape=(None, None, input_depth))
      tiled_inputs = beam_search_decoder.tile_batch(
          inputs, multiplier=beam_width)
      tiled_sequence_length = beam_search_decoder.tile_batch(
          encoder_sequence_length, multiplier=beam_width)
      attention_mechanism = attention_wrapper.BahdanauAttention(
          units=attention_depth,
          memory=tiled_inputs,
          memory_sequence_length=tiled_sequence_length)
      initial_state = beam_search_decoder.tile_batch(
          initial_state, multiplier=beam_width)
      cell = attention_wrapper.AttentionWrapper(
          cell=cell,
          attention_mechanism=attention_mechanism,
          attention_layer_size=attention_depth,
          alignment_history=with_alignment_history)
    cell_state = cell.get_initial_state(
        batch_size=batch_size_tensor * beam_width, dtype=tf.float32)
    if has_attention:
      cell_state = cell_state.clone(cell_state=initial_state)
    bsd = beam_search_decoder.BeamSearchDecoder(
        cell=cell,
        beam_width=beam_width,
        output_layer=output_layer,
        length_penalty_weight=0.0,
        coverage_penalty_weight=coverage_penalty_weight,
        output_time_major=time_major,
        maximum_iterations=max_out)
    final_outputs, final_state, final_sequence_lengths = bsd(
        embedding,
        start_tokens=tf.fill([batch_size_tensor], start_token),
        end_token=end_token,
        initial_state=cell_state)

    def _t(shape):
      if time_major:
        return (shape[1], shape[0]) + shape[2:]
      return shape

    self.assertIsInstance(
        final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
    self.assertIsInstance(
        final_state, beam_search_decoder.BeamSearchDecoderState)

    beam_search_decoder_output = final_outputs.beam_search_decoder_output
    expected_seq_length = 3 if context.executing_eagerly() else None
    self.assertEqual(
        _t((batch_size, expected_seq_length, beam_width)),
        tuple(beam_search_decoder_output.scores.get_shape().as_list()))
    self.assertEqual(
        _t((batch_size, expected_seq_length, beam_width)),
        tuple(final_outputs.predicted_ids.get_shape().as_list()))

    self.evaluate(tf.compat.v1.global_variables_initializer())
    eval_results = self.evaluate({
        'final_outputs': final_outputs,
        'final_sequence_lengths': final_sequence_lengths
    })

    max_sequence_length = np.max(eval_results['final_sequence_lengths'])

    # A smoke test
    self.assertEqual(
        _t((batch_size, max_sequence_length, beam_width)),
        eval_results['final_outputs'].beam_search_decoder_output.scores.shape)
    self.assertEqual(
        _t((batch_size, max_sequence_length, beam_width)),
        eval_results['final_outputs'].beam_search_decoder_output
        .predicted_ids.shape)
def __init__(self, lstm_cell_num, keep_prob, initializer='orthogonal'):
  super(BasicLSTMCell, self).__init__()
  # Keras LSTMCell has no TF1-style float `forget_bias` argument;
  # unit_forget_bias=False reproduces forget_bias=0.0 (no extra +1 on the
  # forget-gate bias). The initializer is passed as a Keras identifier
  # rather than a bare class.
  self.lstm_cell = layers.LSTMCell(
      lstm_cell_num, kernel_initializer=initializer, unit_forget_bias=False)
  # There is no layers.DropoutWrapper, and the RNN-cell DropoutWrapper
  # rejects Keras LSTM cells (see testDropoutWrapperWithKerasLSTMCell above),
  # so the output dropout that output_keep_prob would have provided is
  # applied with a plain Dropout layer on the cell's outputs instead.
  self.lstm_wrapper = layers.Dropout(rate=1.0 - keep_prob)
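# Hedged usage sketch (assumption, not part of the original source): the
# cell built above would typically be unrolled over a sequence with
# keras.layers.RNN, with the output dropout applied to the hidden states.
def demo_basic_lstm_cell():
  wrapper = BasicLSTMCell(lstm_cell_num=128, keep_prob=0.8)
  rnn = keras.layers.RNN(wrapper.lstm_cell, return_sequences=True)
  outputs = rnn(tf.zeros([4, 10, 32]))  # (batch, time, features)
  return wrapper.lstm_wrapper(outputs, training=True)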