def testPrunesPreviouslyUnprunedModel(self):
    """Train a plain dense model, then wrap it for pruning and retrain.

    The model is expected to report 0.0 sparsity after the unpruned
    training phase and 0.5 sparsity after the pruned training phase.
    """

    def _compile(target):
        target.compile(loss='categorical_crossentropy',
                       optimizer='sgd',
                       metrics=['accuracy'])

    # Phase 1: ordinary training, no pruning wrappers involved.
    model = keras_test_utils.build_simple_dense_model()
    _compile(model)
    features = np.random.rand(20, 10)
    labels = np_utils.to_categorical(np.random.randint(5, size=(20, 1)), 5)
    model.fit(features, labels, epochs=2, batch_size=20)
    test_utils.assert_model_sparsity(self, 0.0, model)

    # Phase 2: wrap the trained model for pruning and compile it again.
    model = prune.prune_low_magnitude(model, **self.params)
    _compile(model)

    # Since newly compiled, iterations starts from 0.
    features = np.random.rand(20, 10)
    labels = np_utils.to_categorical(np.random.randint(5, size=(20, 1)), 5)
    step_updater = pruning_callbacks.UpdatePruningStep()
    model.fit(features, labels, batch_size=20, callbacks=[step_updater])
    test_utils.assert_model_sparsity(self, 0.5, model)

    self._check_strip_pruning_matches_original(model, 0.5)
def testMbyNSparsityPruning_SupportedLayers(self,
                                            layer_type,
                                            layer_arg,
                                            input_shape,
                                            m_by_n=(2, 4),
                                            sparsity_ratio=0.50):
    """Check that we prune supported layers with m by n sparsity.

    Args:
        layer_type: Layer class to instantiate and wrap for pruning.
        layer_arg: Positional arguments forwarded to `layer_type`.
        input_shape: Input shape handed to the pruning wrapper.
        m_by_n: The m-by-n sparsity pattern stored under 'sparsity_m_by_n'.
        sparsity_ratio: Sparsity expected of the stripped model.
    """
    self.params.update({'sparsity_m_by_n': m_by_n})

    def _random_batch(tensor):
        # Random data shaped like `tensor`, with a batch of 32.
        dims = self._batch(tensor.get_shape().as_list(), 32)
        return np.random.randn(*dims)

    wrapped_layer = prune.prune_low_magnitude(
        layer_type(*layer_arg), input_shape=input_shape, **self.params)
    model = keras.Sequential()
    model.add(wrapped_layer)
    model.compile(loss='categorical_crossentropy',
                  optimizer='sgd',
                  metrics=['accuracy'])
    test_utils.assert_model_sparsity(self, 0.0, model)

    inputs = _random_batch(model.input)
    targets = _random_batch(model.output)
    model.fit(inputs,
              targets,
              callbacks=[pruning_callbacks.UpdatePruningStep()])

    test_utils.assert_model_sparsity_m_by_n(self, model, m_by_n)
    self._check_strip_pruning_matches_original(model, sparsity_ratio)
def testPrunesMnist_ReachesTargetSparsity(self, model_type):
    """Prune an MNIST model of the given type and expect 0.5 sparsity.

    Args:
        model_type: One of 'layer_list', 'sequential' or 'functional';
            any other value leaves the built model unwrapped.
    """
    tolerances = {'rtol': 1e-4, 'atol': 1e-4}

    model = test_utils.build_mnist_model(model_type, self.params)
    if model_type == 'layer_list':
        # A list of layers is pruned first and then assembled into a model.
        pruned_layers = prune.prune_low_magnitude(model, **self.params)
        model = keras.Sequential(pruned_layers)
    elif model_type in ('sequential', 'functional'):
        model = prune.prune_low_magnitude(model, **self.params)

    model.compile(loss='categorical_crossentropy',
                  optimizer='sgd',
                  metrics=['accuracy'])
    test_utils.assert_model_sparsity(self, 0.0, model, **tolerances)

    images = np.random.rand(32, 28, 28, 1)
    labels = keras.utils.to_categorical(
        np.random.randint(10, size=(32, 1)), 10)
    model.fit(images,
              labels,
              callbacks=[pruning_callbacks.UpdatePruningStep()])

    test_utils.assert_model_sparsity(self, 0.5, model, **tolerances)
    self._check_strip_pruning_matches_original(model, 0.5)
def testPruneCheckpoints_CheckpointsNotSparse(self):
    """Expect a restored weights checkpoint to miss 0.5 sparsity at least once.

    Three independent attempts are made; the test passes when any of them
    yields a restored model whose sparsity is reported as not 0.5.
    """
    not_sparse_flags = []

    # Run multiple times since problem doesn't always happen.
    for _ in range(3):
        trained = prune.prune_low_magnitude(
            keras_test_utils.build_simple_dense_model(), **self.params)

        checkpoint_dir = tempfile.mkdtemp()
        checkpoint_path = checkpoint_dir + '/weights.{epoch:02d}.tf'
        checkpointer = tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_path, save_weights_only=True, save_freq=1)
        callbacks = [pruning_callbacks.UpdatePruningStep(), checkpointer]

        # Train one step. Sparsity reaches final sparsity.
        self._train_model(trained, epochs=1, callbacks=callbacks)
        test_utils.assert_model_sparsity(self, 0.5, trained)

        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)

        # Fresh model with the same architecture to load the checkpoint into.
        restored = prune.prune_low_magnitude(
            keras_test_utils.build_simple_dense_model(), **self.params)

        # Sanity check.
        test_utils.assert_model_sparsity(self, 0, restored)

        restored.load_weights(latest_checkpoint)
        not_sparse_flags.append(
            test_utils.is_model_sparsity_not(0.5, restored))

    self.assertTrue(any(not_sparse_flags))
def testRNNLayersWithRNNCellParams(self):
    """Prune an RNN layer built from a list of four cell types."""
    cells = [
        layers.LSTMCell(10),
        layers.GRUCell(10),
        layers.PeepholeLSTMCell(10),
        layers.SimpleRNNCell(10),
    ]
    stacked_rnn = keras.layers.RNN(cells)

    model = keras.Sequential()
    model.add(
        prune.prune_low_magnitude(
            stacked_rnn, input_shape=(3, 4), **self.params))
    model.compile(loss='categorical_crossentropy',
                  optimizer='sgd',
                  metrics=['accuracy'])
    test_utils.assert_model_sparsity(self, 0.0, model)

    # Random inputs/targets shaped like the model's input/output, batch 32.
    input_dims = self._batch(model.input.get_shape().as_list(), 32)
    inputs = np.random.randn(*input_dims)
    output_dims = self._batch(model.output.get_shape().as_list(), 32)
    targets = np.random.randn(*output_dims)
    model.fit(inputs,
              targets,
              callbacks=[pruning_callbacks.UpdatePruningStep()])

    test_utils.assert_model_sparsity(self, 0.5, model)
    self._check_strip_pruning_matches_original(model, 0.5)
def test_prune_model_recursively(self):
    """Checks that models are recursively pruned."""
    # Outer model whose first layer is itself a keras.Model.
    inner = keras.Sequential([
        keras.layers.Dense(10, input_shape=(10, )),
    ])
    outer = keras.Sequential([
        inner,
        keras.layers.Dense(20),
    ])

    pruned_model = sparsity_tooling.prune_for_benchmark(
        outer, target_sparsity=.8, block_size=(1, 1))

    test_utils.assert_model_sparsity(self, 0.8, pruned_model)

    # Collect every pruning wrapper reachable from the pruned model and
    # check both their count and their block size.
    wrappers = []
    for submodule in pruned_model.submodules:
        if isinstance(submodule, pruning_wrapper.PruneLowMagnitude):
            wrappers.append(submodule)

    self.assertEqual(2, len(wrappers))
    for wrapper in wrappers:
        self.assertEqual((1, 1), wrapper.block_size)
def testPrunesModel_CustomTrainingLoop_ReachesTargetSparsity(self):
    """Drive pruning from a hand-written training loop; expect 0.5 sparsity.

    The pruning callback hooks are invoked manually around each
    gradient step instead of going through `model.fit`.
    """
    pruned_model = prune.prune_low_magnitude(
        keras_test_utils.build_simple_dense_model())

    batch_size = 20
    x_train = np.random.rand(batch_size, 10)
    y_train = keras.utils.to_categorical(
        np.random.randint(5, size=(batch_size, 1)), 5)

    loss = keras.losses.categorical_crossentropy
    optimizer = keras.optimizers.SGD()

    # The callback hooks require a batch argument; its value is a placeholder.
    unused_arg = -1

    step_callback = pruning_callbacks.UpdatePruningStep()
    step_callback.set_model(pruned_model)
    pruned_model.optimizer = optimizer

    step_callback.on_train_begin()
    batch_inputs = np.reshape(x_train, [batch_size, 10])
    # 2 epochs
    for _ in range(2):
        step_callback.on_train_batch_begin(batch=unused_arg)

        with tf.GradientTape() as tape:
            logits = pruned_model(batch_inputs, training=True)
            loss_value = loss(y_train, logits)
        trainable = pruned_model.trainable_variables
        grads = tape.gradient(loss_value, trainable)
        optimizer.apply_gradients(
            zip(grads, pruned_model.trainable_variables))

        step_callback.on_epoch_end(batch=unused_arg)

    test_utils.assert_model_sparsity(self, 0.5, pruned_model)
def _check_strip_pruning_matches_original(self,
                                          model,
                                          sparsity,
                                          input_data=None):
    """Verify that stripping pruning keeps sparsity and predictions.

    Args:
        model: Pruned model to strip and compare against.
        sparsity: Sparsity the stripped model is expected to have.
        input_data: Optional input to run through both models. When
            omitted, a single random batch shaped like `model.input`
            is generated. Callers pass this when random normal data is
            not what they want to feed the model.

    Raises:
        AssertionError: If the sparsity check fails or the predictions
            of the two models differ.
    """
    stripped_model = prune.strip_pruning(model)
    test_utils.assert_model_sparsity(self, sparsity, stripped_model)

    if input_data is None:
        input_data = np.random.randn(
            *self._batch(model.input.get_shape().as_list(), 1))

    model_result = model.predict(input_data)
    stripped_model_result = stripped_model.predict(input_data)
    np.testing.assert_almost_equal(model_result, stripped_model_result)
def test_prune_model(self):
    """Prune a two-layer dense model for benchmarking and check the result.

    Every layer of the returned model must report a (1, 1) block size and
    the model must have 0.8 sparsity.
    """
    dense_stack = [
        keras.layers.Dense(10, input_shape=(10, )),
        keras.layers.Dense(2),
    ]
    model = keras.Sequential(dense_stack)

    pruned_model = sparsity_tooling.prune_for_benchmark(
        model, target_sparsity=.8, block_size=(1, 1))

    expected_block_size = (1, 1)
    for pruned_layer in pruned_model.layers:
        self.assertEqual(expected_block_size, pruned_layer.block_size)

    test_utils.assert_model_sparsity(self, 0.8, pruned_model)
def testPruneWithPolynomialDecayPastEndStep_PreservesSparsity(
        self, save_restore_fn):
    """Sparsity must stay at 0.6 after training past end_step and restoring.

    Args:
        save_restore_fn: Callable that takes the model and returns the
            model to continue with.
    """
    begin_step = 0
    end_step = 2

    # NOTE: `params` is the same object as self.params, so the schedule is
    # written into that shared mapping.
    params = self.params
    schedule = pruning_schedule.PolynomialDecay(
        0.2, 0.6, begin_step, end_step, 3, 1)
    params['pruning_schedule'] = schedule

    base_model = keras_test_utils.build_simple_dense_model()
    model = prune.prune_low_magnitude(base_model, **params)
    model.compile(loss='categorical_crossentropy',
                  optimizer='sgd',
                  metrics=['accuracy'])

    # Model hasn't been trained yet. Sparsity 0.0
    test_utils.assert_model_sparsity(self, 0.0, model)

    # Train 3 steps, past end_step. Sparsity 0.6 (final_sparsity)
    self._train_model(model, epochs=3)
    test_utils.assert_model_sparsity(self, 0.6, model)

    model = save_restore_fn(model)

    # Ensure sparsity is preserved.
    test_utils.assert_model_sparsity(self, 0.6, model)

    # Train one more step to ensure nothing happens that brings sparsity
    # back below 0.6.
    self._train_model(model, epochs=1)
    test_utils.assert_model_sparsity(self, 0.6, model)

    self._check_strip_pruning_matches_original(model, 0.6)
def testRNNLayersSingleCell_ReachesTargetSparsity(self, layer_type):
    """Prune a single recurrent layer and expect 0.5 sparsity.

    Args:
        layer_type: Layer class, instantiated here with a single
            positional argument of 10.
    """

    def _random_batch(tensor):
        # Random data shaped like `tensor`, with a batch of 32.
        dims = self._batch(tensor.get_shape().as_list(), 32)
        return np.random.randn(*dims)

    pruned_layer = prune.prune_low_magnitude(
        layer_type(10), input_shape=(3, 4), **self.params)
    model = keras.Sequential()
    model.add(pruned_layer)

    model.compile(loss='categorical_crossentropy',
                  optimizer='sgd',
                  metrics=['accuracy'])
    test_utils.assert_model_sparsity(self, 0.0, model)

    inputs = _random_batch(model.input)
    targets = _random_batch(model.output)
    model.fit(inputs,
              targets,
              callbacks=[pruning_callbacks.UpdatePruningStep()])

    test_utils.assert_model_sparsity(self, 0.5, model)
    self._check_strip_pruning_matches_original(model, 0.5)
def testPrunesSingleLayer_ReachesTargetSparsity(self, layer_type):
    """Prune a single layer of `layer_type` and expect 0.5 sparsity.

    Returns early, without asserting anything, when
    `_get_params_for_layer` yields no constructor arguments.

    Args:
        layer_type: Layer class to instantiate and wrap for pruning.
    """
    model = keras.Sequential()
    args, input_shape = self._get_params_for_layer(layer_type)
    if args is None:
        # Test for layer not supported yet.
        return

    pruned_layer = prune.prune_low_magnitude(
        layer_type(*args), input_shape=input_shape, **self.params)
    model.add(pruned_layer)

    model.compile(loss='categorical_crossentropy',
                  optimizer='sgd',
                  metrics=['accuracy'])
    test_utils.assert_model_sparsity(self, 0.0, model)

    # Random inputs/targets shaped like the model's input/output, batch 32.
    input_dims = self._batch(model.input.get_shape().as_list(), 32)
    inputs = np.random.randn(*input_dims)
    output_dims = self._batch(model.output.get_shape().as_list(), 32)
    targets = np.random.randn(*output_dims)
    model.fit(inputs,
              targets,
              callbacks=[pruning_callbacks.UpdatePruningStep()])

    test_utils.assert_model_sparsity(self, 0.5, model)
    self._check_strip_pruning_matches_original(model, 0.5)
def testSparsityPruningMbyN_SupportedSubclassLayers(self):
    """Check subclass layer that is supported for m by n sparsity."""
    m_by_n = (2, 4)
    self.params.update({'sparsity_m_by_n': m_by_n})

    class SubclassLayer(tf.keras.layers.Layer):
        """Custom layer chaining conv, depthwise conv, flatten and dense."""

        def __init__(self):
            super().__init__()
            self.conv1 = tf.keras.layers.Conv2D(2,
                                                3,
                                                activation='relu',
                                                padding='same',
                                                input_shape=[7, 7, 3])
            self.conv2 = tf.keras.layers.DepthwiseConv2D(3)
            self.flatten = keras.layers.Flatten()
            self.dense = layers.Dense(10, activation='sigmoid')

        def call(self, inputs):
            hidden = self.conv1(inputs)
            hidden = self.conv2(hidden)
            hidden = self.flatten(hidden)
            return self.dense(hidden)

    inputs = keras.Input(shape=(7, 7, 3))
    outputs = SubclassLayer()(inputs)
    model = keras.Model(inputs, outputs)

    # NOTE(review): the test expects a ValueError from somewhere in this
    # block. If prune_low_magnitude is what raises, the statements after it
    # never execute -- confirm the intended extent of the assertRaises block.
    with self.assertRaises(ValueError):
        model = prune.prune_low_magnitude(model, **self.params)

        model.compile(loss='categorical_crossentropy',
                      optimizer='sgd',
                      metrics=['accuracy'])
        test_utils.assert_model_sparsity(self, 0.0, model)

        input_dims = self._batch(model.input.get_shape().as_list(), 32)
        features = np.random.randn(*input_dims)
        output_dims = self._batch(model.output.get_shape().as_list(), 32)
        targets = np.random.randn(*output_dims)
        model.fit(features,
                  targets,
                  callbacks=[pruning_callbacks.UpdatePruningStep()])

        test_utils.assert_model_sparsity_m_by_n(self, model, m_by_n)
        self._check_strip_pruning_matches_original(model, 0.5)
def testPruneRecursivelyReachesTargetSparsity(self):
    """Prune a model that nests another model and expect 0.5 sparsity.

    The outer Sequential model contains an inner Sequential model as its
    first layer. The whole model is passed to `prune.prune_low_magnitude`
    so that the sparsity assertions below exercise pruned layers; without
    that call the model is never wrapped for pruning.
    """
    internal_model = keras.Sequential(
        [keras.layers.Dense(10, input_shape=(10, ))])
    model = keras.Sequential([
        internal_model,
        layers.Flatten(),
        layers.Dense(1),
    ])
    # Apply pruning to the outer model, nested model included.
    model = prune.prune_low_magnitude(model, **self.params)
    model.compile(loss='binary_crossentropy',
                  optimizer='sgd',
                  metrics=['accuracy'])
    test_utils.assert_model_sparsity(self, 0.0, model)

    model.fit(np.random.randint(10, size=(32, 10)),
              np.random.randint(2, size=(32, 1)),
              callbacks=[pruning_callbacks.UpdatePruningStep()])

    test_utils.assert_model_sparsity(self, 0.5, model)

    # Integer inputs matching the training data are supplied explicitly
    # for the stripped-vs-pruned prediction comparison.
    input_data = np.random.randint(10, size=(32, 10))
    self._check_strip_pruning_matches_original(model, 0.5, input_data)
def testPrunesEmbedding(self):
    """Prune an Embedding layer feeding a small classifier; expect 0.5."""
    embedding = keras.layers.Embedding(input_dim=10, output_dim=3)
    pruned_embedding = prune.prune_low_magnitude(
        embedding, input_shape=(5, ), **self.params)

    model = keras.Sequential()
    for layer in (pruned_embedding,
                  keras.layers.Flatten(),
                  keras.layers.Dense(1, activation='sigmoid')):
        model.add(layer)

    model.compile(loss='binary_crossentropy',
                  optimizer='sgd',
                  metrics=['accuracy'])
    test_utils.assert_model_sparsity(self, 0.0, model)

    # Integer token ids in [0, 10) and binary labels.
    token_ids = np.random.randint(10, size=(32, 5))
    labels = np.random.randint(2, size=(32, 1))
    model.fit(token_ids,
              labels,
              callbacks=[pruning_callbacks.UpdatePruningStep()])

    test_utils.assert_model_sparsity(self, 0.5, model)
    self._check_strip_pruning_matches_original(model, 0.5)
def testPruneStopAndRestart_PreservesSparsity(self, save_restore_fn):
    """Pruning must resume correctly after the model is saved and restored.

    Args:
        save_restore_fn: Callable that takes the model and returns the
            model to continue training with.
    """
    begin_step = 0
    end_step = 4

    # NOTE: `params` is the same object as self.params, so the schedule is
    # written into that shared mapping.
    params = self.params
    schedule = pruning_schedule.PolynomialDecay(
        0.2, 0.6, begin_step, end_step, 3, 1)
    params['pruning_schedule'] = schedule

    base_model = keras_test_utils.build_simple_dense_model()
    model = prune.prune_low_magnitude(base_model, **params)
    model.compile(loss='categorical_crossentropy',
                  optimizer='sgd',
                  metrics=['accuracy'])

    # Model hasn't been trained yet. Sparsity 0.0
    test_utils.assert_model_sparsity(self, 0.0, model)

    # Train only 1 step. Sparsity 0.2 (initial_sparsity)
    self._train_model(model, epochs=1)
    test_utils.assert_model_sparsity(self, 0.2, model)

    model = save_restore_fn(model)

    # Training has run all 4 steps. Sparsity 0.6 (final_sparsity)
    self._train_model(model, epochs=3)
    test_utils.assert_model_sparsity(self, 0.6, model)

    self._check_strip_pruning_matches_original(model, 0.6)
def testPruneStopAndRestartOnModel(self, save_restore_fn):
    """Pruning must resume correctly when training restarts after restore.

    Args:
        save_restore_fn: Callable that takes the model and returns the
            model to continue training with.
    """
    schedule = pruning_schedule.PolynomialDecay(0.2, 0.6, 0, 4, 3, 1)
    params = {
        'pruning_schedule': schedule,
        'block_size': (1, 1),
        'block_pooling_type': 'AVG'
    }

    model = prune.prune_low_magnitude(
        keras_test_utils.build_simple_dense_model(), **params)
    model.compile(loss='categorical_crossentropy',
                  optimizer='sgd',
                  metrics=['accuracy'])

    # Model hasn't been trained yet. Sparsity 0.0
    test_utils.assert_model_sparsity(self, 0.0, model)

    features = np.random.rand(20, 10)
    labels = np_utils.to_categorical(np.random.randint(5, size=(20, 1)), 5)
    model.fit(features,
              labels,
              batch_size=20,
              callbacks=[pruning_callbacks.UpdatePruningStep()])

    # Training has run only 1 step. Sparsity 0.2 (initial_sparsity)
    test_utils.assert_model_sparsity(self, 0.2, model)

    model = save_restore_fn(model)

    features = np.random.rand(20, 10)
    labels = np_utils.to_categorical(np.random.randint(5, size=(20, 1)), 5)
    model.fit(features,
              labels,
              batch_size=20,
              epochs=3,
              callbacks=[pruning_callbacks.UpdatePruningStep()])

    # Training has run all 4 steps. Sparsity 0.6 (final_sparsity)
    test_utils.assert_model_sparsity(self, 0.6, model)

    self._check_strip_pruning_matches_original(model, 0.6)
def testPrunesSimpleDenseModel(self, distribution):
    """Prune a simple dense model under a distribution strategy scope.

    The model is trained to 0.5 sparsity, saved to an HDF5 file, loaded
    back inside `prune.prune_scope()`, and checked for 0.5 sparsity again.

    Args:
        distribution: Object whose `scope()` context manager wraps model
            creation, training, saving and loading.
    """
    import os  # Local import: only needed to manage the temporary file.

    with distribution.scope():
        model = prune.prune_low_magnitude(
            keras_test_utils.build_simple_dense_model(), **self.params)
        model.compile(loss='categorical_crossentropy',
                      optimizer='sgd',
                      metrics=['accuracy'])

        # Model hasn't been trained yet. Sparsity 0.0
        test_utils.assert_model_sparsity(self, 0.0, model)

        # Train with the pruning step callback so sparsity is applied.
        model.fit(np.random.rand(20, 10),
                  keras.utils.np_utils.to_categorical(
                      np.random.randint(5, size=(20, 1)), 5),
                  epochs=2,
                  callbacks=[pruning_callbacks.UpdatePruningStep()],
                  batch_size=20)
        model.predict(np.random.rand(20, 10))
        test_utils.assert_model_sparsity(self, 0.5, model)

        # mkstemp returns an already-open OS-level descriptor; close it
        # right away so it is not leaked, and remove the file when done.
        fd, keras_file = tempfile.mkstemp('.h5')
        os.close(fd)
        try:
            keras.models.save_model(model, keras_file)

            with prune.prune_scope():
                loaded_model = keras.models.load_model(keras_file)
        finally:
            if os.path.exists(keras_file):
                os.remove(keras_file)

        test_utils.assert_model_sparsity(self, 0.5, loaded_model)
def testPruneStopAndRestart_PreservesSparsity(self, save_restore_fn):
    """Pruning must resume correctly after the model is saved and restored.

    Returns early, without asserting anything, when `save_restore_fn` is
    named '_save_restore_tf_model'.

    Args:
        save_restore_fn: Callable that takes the model and returns the
            model to continue training with.
    """
    # TODO(tfmot): renable once SavedModel preserves step again.
    # This existed in TF 2.0 and 2.1 and should be reenabled in
    # TF 2.3. b/151755698
    restore_name = save_restore_fn.__name__
    if restore_name == '_save_restore_tf_model':
        return

    begin_step = 1
    end_step = 4

    # NOTE: `params` is the same object as self.params, so the schedule is
    # written into that shared mapping.
    params = self.params
    schedule = pruning_schedule.PolynomialDecay(
        0.2, 0.6, begin_step, end_step, 3, 1)
    params['pruning_schedule'] = schedule

    base_model = keras_test_utils.build_simple_dense_model()
    model = prune.prune_low_magnitude(base_model, **params)
    model.compile(loss='categorical_crossentropy',
                  optimizer='sgd',
                  metrics=['accuracy'])

    # Model hasn't been trained yet. Sparsity 0.0
    test_utils.assert_model_sparsity(self, 0.0, model)

    # Train only 1 step. Sparsity 0.2 (initial_sparsity)
    self._train_model(model, epochs=1)
    test_utils.assert_model_sparsity(self, 0.2, model)

    model = save_restore_fn(model)

    # Training has run all 4 steps. Sparsity 0.6 (final_sparsity)
    self._train_model(model, epochs=3)
    test_utils.assert_model_sparsity(self, 0.6, model)

    self._check_strip_pruning_matches_original(model, 0.6)