def testPruneSequentialModelPreservesBuiltState(self):
  # No InputLayer.
  model = keras.Sequential([
      layers.Dense(10),
      layers.Dense(10),
  ])
  self.assertEqual(model.built, False)
  pruned_model = prune.prune_low_magnitude(model, **self.params)
  self.assertEqual(model.built, False)

  # Test that the built state is preserved across serialization.
  with prune.prune_scope():
    loaded_model = keras.models.model_from_config(
        json.loads(pruned_model.to_json()))
  self.assertEqual(loaded_model.built, False)

  # With InputLayer.
  model = keras.Sequential([
      layers.Dense(10, input_shape=(10,)),
      layers.Dense(10),
  ])
  self.assertEqual(model.built, True)
  pruned_model = prune.prune_low_magnitude(model, **self.params)
  self.assertEqual(model.built, True)

  # Test that the built state is preserved across serialization.
  with prune.prune_scope():
    loaded_model = keras.models.model_from_config(
        json.loads(pruned_model.to_json()))
  self.assertEqual(loaded_model.built, True)
def testFunctionalModelForLatencyOnDoubleXNNPackPolicy(self):
  i = keras.Input(shape=(8, 8, 3))
  x = layers.ZeroPadding2D(padding=1)(i)
  x = layers.Conv2D(
      filters=16,
      kernel_size=(3, 3),
      strides=(2, 2),
      padding='valid',
  )(x)
  x = layers.Conv2D(filters=16, kernel_size=(1, 1))(x)
  o = layers.GlobalAveragePooling2D()(x)
  model = keras.Model(inputs=[i], outputs=[o])

  pruned_model = prune.prune_low_magnitude(
      model,
      pruning_policy=pruning_policy.PruneForLatencyOnXNNPack(),
      **self.params,
  )
  self.assertEqual(self._count_pruned_layers(pruned_model), 1)

  double_pruned_model = prune.prune_low_magnitude(
      pruned_model,
      pruning_policy=pruning_policy.PruneForLatencyOnXNNPack(),
      **self.params,
  )
  self.assertEqual(self._count_pruned_layers(double_pruned_model), 1)
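# Several tests in this section call self._count_pruned_layers(...), which is
# defined elsewhere in the test class. A minimal sketch of such a helper,
# assuming the wrapper class is pruning_wrapper.PruneLowMagnitude (an
# assumption about the helper's body, not the verbatim test code):
def _count_pruned_layers(self, model):
  count = 0
  for layer in model.submodules:
    if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
      count += 1
  return count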
def testPruneModelCustomNonPrunableLayerRaisesError(self):
  with self.assertRaises(ValueError):
    prune.prune_low_magnitude(
        keras.Sequential([
            self.keras_prunable_layer,
            self.keras_non_prunable_layer,
            self.custom_prunable_layer,
            self.custom_non_prunable_layer,
        ]), **self.params)
def testPrunesMnist_ReachesTargetSparsity(self, model_type):
  model = test_utils.build_mnist_model(model_type, self.params)
  if model_type == 'layer_list':
    model = keras.Sequential(
        prune.prune_low_magnitude(model, **self.params))
  elif model_type in ['sequential', 'functional']:
    model = prune.prune_low_magnitude(model, **self.params)

  model.compile(
      loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

  test_utils.assert_model_sparsity(self, 0.0, model, rtol=1e-4, atol=1e-4)
  model.fit(
      np.random.rand(32, 28, 28, 1),
      keras.utils.to_categorical(np.random.randint(10, size=(32, 1)), 10),
      callbacks=[pruning_callbacks.UpdatePruningStep()])

  test_utils.assert_model_sparsity(self, 0.5, model, rtol=1e-4, atol=1e-4)

  self._check_strip_pruning_matches_original(model, 0.5)
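# Many tests here end with self._check_strip_pruning_matches_original(...),
# which is not shown in this section. A rough sketch of what it presumably
# does: strip the pruning wrappers, confirm the sparsity survives, and check
# that predictions are unchanged. The body below is an assumption; it relies
# on self._batch, which is itself sketched further below.
def _check_strip_pruning_matches_original(self, model, sparsity):
  stripped_model = prune.strip_pruning(model)
  test_utils.assert_model_sparsity(self, sparsity, stripped_model)

  input_data = np.random.randn(
      *self._batch(model.input.get_shape().as_list(), 1))
  model_result = model.predict(input_data)
  stripped_model_result = stripped_model.predict(input_data)
  np.testing.assert_almost_equal(model_result, stripped_model_result)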
def testPruneCheckpoints_CheckpointsNotSparse(self):
  is_model_sparsity_not_list = []

  # Run multiple times since the problem doesn't always happen.
  for _ in range(3):
    model = keras_test_utils.build_simple_dense_model()
    pruned_model = prune.prune_low_magnitude(model, **self.params)

    checkpoint_dir = tempfile.mkdtemp()
    checkpoint_path = checkpoint_dir + '/weights.{epoch:02d}.tf'

    callbacks = [
        pruning_callbacks.UpdatePruningStep(),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_path, save_weights_only=True, save_freq=1)
    ]

    # Train one step. Sparsity reaches final sparsity.
    self._train_model(pruned_model, epochs=1, callbacks=callbacks)
    test_utils.assert_model_sparsity(self, 0.5, pruned_model)

    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)

    same_architecture_model = keras_test_utils.build_simple_dense_model()
    pruned_model = prune.prune_low_magnitude(same_architecture_model,
                                             **self.params)

    # Sanity check.
    test_utils.assert_model_sparsity(self, 0, pruned_model)

    pruned_model.load_weights(latest_checkpoint)
    is_model_sparsity_not_list.append(
        test_utils.is_model_sparsity_not(0.5, pruned_model))

  self.assertTrue(any(is_model_sparsity_not_list))
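# self._train_model(...) above is another helper defined elsewhere in the
# test class. A minimal sketch, assuming it mirrors the inline model.fit(...)
# calls used by the other tests in this section (the exact signature and data
# shapes are assumptions):
def _train_model(self, model, epochs=1, callbacks=None):
  if callbacks is None:
    callbacks = [pruning_callbacks.UpdatePruningStep()]
  model.fit(
      np.random.rand(20, 10),
      np_utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
      epochs=epochs,
      batch_size=20,
      callbacks=callbacks)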
def testPruneModelDoesNotWrapAlreadyWrappedLayer(self):
  model = keras.Sequential([
      layers.Dense(10),
      prune.prune_low_magnitude(layers.Dense(10), **self.params),
  ])
  pruned_model = prune.prune_low_magnitude(model, **self.params)
  pruned_model.build(input_shape=(10, 1))

  self.assertEqual(len(model.layers), len(pruned_model.layers))
  self._validate_pruned_layer(model.layers[0], pruned_model.layers[0])
  # Second layer is used as-is since it's already a pruned layer.
  self.assertEqual(model.layers[1], pruned_model.layers[1])
def testPruneScope_NotNeededForTFCheckpoint(self):
  model = keras_test_utils.build_simple_dense_model()
  pruned_model = prune.prune_low_magnitude(model)
  _, tf_weights = tempfile.mkstemp('.tf')
  pruned_model.save_weights(tf_weights)

  same_architecture_model = keras_test_utils.build_simple_dense_model()
  same_architecture_model = prune.prune_low_magnitude(same_architecture_model)

  # Would error if `prune_scope` was needed.
  same_architecture_model.load_weights(tf_weights)
def main(unused_argv):
  # Input image dimensions.
  img_rows, img_cols = 28, 28

  # The data, shuffled and split between train and test sets.
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

  if tf.keras.backend.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
  else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

  x_train = x_train.astype('float32')
  x_test = x_test.astype('float32')
  x_train /= 255
  x_test /= 255
  print('x_train shape:', x_train.shape)
  print(x_train.shape[0], 'train samples')
  print(x_test.shape[0], 'test samples')

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes)

  pruning_params = {
      'pruning_schedule':
          PolynomialDecay(
              initial_sparsity=0.1,
              final_sparsity=0.75,
              begin_step=1000,
              end_step=5000,
              frequency=100)
  }

  layerwise_model = build_layerwise_model(input_shape, **pruning_params)
  sequential_model = build_sequential_model(input_shape)
  sequential_model = prune.prune_low_magnitude(
      sequential_model, **pruning_params)
  functional_model = build_functional_model(input_shape)
  functional_model = prune.prune_low_magnitude(
      functional_model, **pruning_params)

  models = [layerwise_model, sequential_model, functional_model]
  train_and_save(models, x_train, y_train, x_test, y_test)
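# build_sequential_model and build_functional_model are defined elsewhere in
# this example. A minimal sketch of the sequential variant, assuming it
# mirrors the topology of build_layerwise_model later in this section (the
# exact layer sizes here are assumptions):
def build_sequential_model(input_shape):
  return tf.keras.Sequential([
      l.Conv2D(
          32, 5, padding='same', activation='relu', input_shape=input_shape),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.BatchNormalization(),
      l.Conv2D(64, 5, padding='same', activation='relu'),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.Flatten(),
      l.Dense(1024, activation='relu'),
      l.Dropout(0.4),
      l.Dense(num_classes, activation='softmax')
  ])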
def testPruneSequentialModel(self):
  # No InputLayer.
  model = keras.Sequential([
      layers.Dense(10),
      layers.Dense(10),
  ])
  pruned_model = prune.prune_low_magnitude(model, **self.params)
  self.assertEqual(self._count_pruned_layers(pruned_model), 2)

  # With InputLayer.
  model = keras.Sequential([
      layers.Dense(10, input_shape=(10,)),
      layers.Dense(10),
  ])
  pruned_model = prune.prune_low_magnitude(model, **self.params)
  self.assertEqual(self._count_pruned_layers(pruned_model), 2)
def testPruneSubclassModel(self):
  model = TestSubclassedModel()
  with self.assertRaises(ValueError) as e:
    _ = prune.prune_low_magnitude(model, **self.params)
  self.assertEqual(
      str(e.exception),
      self.INVALID_TO_PRUNE_PARAM_ERROR.format(input='TestSubclassedModel'))
def testPrunesPreviouslyUnprunedModel(self):
  model = keras_test_utils.build_simple_dense_model()
  model.compile(
      loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

  # Simple unpruned model. No sparsity.
  model.fit(
      np.random.rand(20, 10),
      np_utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
      epochs=2,
      batch_size=20)
  test_utils.assert_model_sparsity(self, 0.0, model)

  # Apply pruning to the model.
  model = prune.prune_low_magnitude(model, **self.params)
  model.compile(
      loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

  # Since the model was newly compiled, `iterations` starts from 0.
  model.fit(
      np.random.rand(20, 10),
      np_utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
      batch_size=20,
      callbacks=[pruning_callbacks.UpdatePruningStep()])
  test_utils.assert_model_sparsity(self, 0.5, model)

  self._check_strip_pruning_matches_original(model, 0.5)
def testPruneStopAndRestart_PreservesSparsity(self, save_restore_fn):
  # TODO(tfmot): re-enable once SavedModel preserves the step again.
  # This existed in TF 2.0 and 2.1 and should be re-enabled in
  # TF 2.3. b/151755698
  if save_restore_fn.__name__ == '_save_restore_tf_model':
    return

  begin_step, end_step = 1, 4
  params = self.params
  params['pruning_schedule'] = pruning_schedule.PolynomialDecay(
      0.2, 0.6, begin_step, end_step, 3, 1)
  model = prune.prune_low_magnitude(
      keras_test_utils.build_simple_dense_model(), **params)
  model.compile(
      loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

  # Model hasn't been trained yet. Sparsity 0.0.
  test_utils.assert_model_sparsity(self, 0.0, model)

  # Train only 1 step. Sparsity 0.2 (initial_sparsity).
  self._train_model(model, epochs=1)
  test_utils.assert_model_sparsity(self, 0.2, model)

  model = save_restore_fn(model)

  # Training has run all 4 steps. Sparsity 0.6 (final_sparsity).
  self._train_model(model, epochs=3)
  test_utils.assert_model_sparsity(self, 0.6, model)

  self._check_strip_pruning_matches_original(model, 0.6)
def testPrunesSimpleDenseModel(self, distribution):
  with distribution.scope():
    model = prune.prune_low_magnitude(
        keras_test_utils.build_simple_dense_model(), **self.params)
    model.compile(
        loss='categorical_crossentropy',
        optimizer='sgd',
        metrics=['accuracy'])

  # Model hasn't been trained yet. Sparsity 0.0.
  test_utils.assert_model_sparsity(self, 0.0, model)

  # Train the pruned model; sparsity should reach the target of 0.5.
  model.fit(
      np.random.rand(20, 10),
      keras.utils.np_utils.to_categorical(
          np.random.randint(5, size=(20, 1)), 5),
      epochs=2,
      callbacks=[pruning_callbacks.UpdatePruningStep()],
      batch_size=20)
  model.predict(np.random.rand(20, 10))
  test_utils.assert_model_sparsity(self, 0.5, model)

  _, keras_file = tempfile.mkstemp('.h5')
  keras.models.save_model(model, keras_file)

  with prune.prune_scope():
    loaded_model = keras.models.load_model(keras_file)

  test_utils.assert_model_sparsity(self, 0.5, loaded_model)
def testPruneWithPolynomialDecayPastEndStep_PreservesSparsity(
    self, save_restore_fn):
  begin_step, end_step = 0, 2
  params = self.params
  params['pruning_schedule'] = pruning_schedule.PolynomialDecay(
      0.2, 0.6, begin_step, end_step, 3, 1)
  model = prune.prune_low_magnitude(
      keras_test_utils.build_simple_dense_model(), **params)
  model.compile(
      loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

  # Model hasn't been trained yet. Sparsity 0.0.
  test_utils.assert_model_sparsity(self, 0.0, model)

  # Train 3 steps, past end_step. Sparsity 0.6 (final_sparsity).
  self._train_model(model, epochs=3)
  test_utils.assert_model_sparsity(self, 0.6, model)

  model = save_restore_fn(model)

  # Ensure sparsity is preserved.
  test_utils.assert_model_sparsity(self, 0.6, model)

  # Train one more step to ensure nothing happens that brings sparsity
  # back below 0.6.
  self._train_model(model, epochs=1)
  test_utils.assert_model_sparsity(self, 0.6, model)

  self._check_strip_pruning_matches_original(model, 0.6)
def testPrunesModel_CustomTrainingLoop_ReachesTargetSparsity(self):
  pruned_model = prune.prune_low_magnitude(
      keras_test_utils.build_simple_dense_model())

  batch_size = 20
  x_train = np.random.rand(20, 10)
  y_train = keras.utils.to_categorical(
      np.random.randint(5, size=(batch_size, 1)), 5)

  loss = keras.losses.categorical_crossentropy
  optimizer = keras.optimizers.SGD()

  unused_arg = -1

  step_callback = pruning_callbacks.UpdatePruningStep()
  step_callback.set_model(pruned_model)
  pruned_model.optimizer = optimizer

  step_callback.on_train_begin()
  # 2 epochs.
  for _ in range(2):
    step_callback.on_train_batch_begin(batch=unused_arg)
    inp = np.reshape(x_train, [batch_size, 10])  # original shape: from [10].
    with tf.GradientTape() as tape:
      logits = pruned_model(inp, training=True)
      loss_value = loss(y_train, logits)
    grads = tape.gradient(loss_value, pruned_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, pruned_model.trainable_variables))
    step_callback.on_epoch_end(batch=unused_arg)

  test_utils.assert_model_sparsity(self, 0.5, pruned_model)
def testRNNLayersWithRNNCellParams(self):
  model = keras.Sequential()
  model.add(
      prune.prune_low_magnitude(
          keras.layers.RNN([
              layers.LSTMCell(10),
              layers.GRUCell(10),
              layers.PeepholeLSTMCell(10),
              layers.SimpleRNNCell(10),
          ]),
          input_shape=(3, 4),
          **self.params))

  model.compile(
      loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
  test_utils.assert_model_sparsity(self, 0.0, model)
  model.fit(
      np.random.randn(*self._batch(model.input.get_shape().as_list(), 32)),
      np.random.randn(*self._batch(model.output.get_shape().as_list(), 32)),
      callbacks=[pruning_callbacks.UpdatePruningStep()])

  test_utils.assert_model_sparsity(self, 0.5, model)

  self._check_strip_pruning_matches_original(model, 0.5)
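# self._batch(...) above replaces the None batch dimension in a Keras shape
# list with a concrete batch size, so np.random.randn can generate matching
# data. A minimal sketch of such a helper (the body is an assumption, not the
# verbatim test code):
def _batch(self, dims, batch_size):
  if dims[0] is None:
    dims[0] = batch_size
  return dims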
def testPruneStopAndRestart_PreservesSparsity(self, save_restore_fn):
  begin_step, end_step = 0, 4
  params = self.params
  params['pruning_schedule'] = pruning_schedule.PolynomialDecay(
      0.2, 0.6, begin_step, end_step, 3, 1)
  model = prune.prune_low_magnitude(
      keras_test_utils.build_simple_dense_model(), **params)
  model.compile(
      loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

  # Model hasn't been trained yet. Sparsity 0.0.
  test_utils.assert_model_sparsity(self, 0.0, model)

  # Train only 1 step. Sparsity 0.2 (initial_sparsity).
  self._train_model(model, epochs=1)
  test_utils.assert_model_sparsity(self, 0.2, model)

  model = save_restore_fn(model)

  # Training has run all 4 steps. Sparsity 0.6 (final_sparsity).
  self._train_model(model, epochs=3)
  test_utils.assert_model_sparsity(self, 0.6, model)

  self._check_strip_pruning_matches_original(model, 0.6)
def testPruneStopAndRestartOnModel(self, save_restore_fn):
  params = {
      'pruning_schedule':
          pruning_schedule.PolynomialDecay(0.2, 0.6, 0, 4, 3, 1),
      'block_size': (1, 1),
      'block_pooling_type': 'AVG'
  }
  model = prune.prune_low_magnitude(
      keras_test_utils.build_simple_dense_model(), **params)
  model.compile(
      loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

  # Model hasn't been trained yet. Sparsity 0.0.
  test_utils.assert_model_sparsity(self, 0.0, model)

  model.fit(
      np.random.rand(20, 10),
      np_utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
      batch_size=20,
      callbacks=[pruning_callbacks.UpdatePruningStep()])
  # Training has run only 1 step. Sparsity 0.2 (initial_sparsity).
  test_utils.assert_model_sparsity(self, 0.2, model)

  model = save_restore_fn(model)

  model.fit(
      np.random.rand(20, 10),
      np_utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
      batch_size=20,
      epochs=3,
      callbacks=[pruning_callbacks.UpdatePruningStep()])
  # Training has run all 4 steps. Sparsity 0.6 (final_sparsity).
  test_utils.assert_model_sparsity(self, 0.6, model)

  self._check_strip_pruning_matches_original(model, 0.6)
def testMbyNSparsityPruning_SupportedLayers(self,
                                            layer_type,
                                            layer_arg,
                                            input_shape,
                                            m_by_n=(2, 4),
                                            sparsity_ratio=0.50):
  """Check that we prune supported layers with m-by-n sparsity."""
  self.params.update({'sparsity_m_by_n': m_by_n})

  model = keras.Sequential()
  model.add(
      prune.prune_low_magnitude(
          layer_type(*layer_arg), input_shape=input_shape, **self.params))
  model.compile(
      loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

  test_utils.assert_model_sparsity(self, 0.0, model)
  model.fit(
      np.random.randn(*self._batch(model.input.get_shape().as_list(), 32)),
      np.random.randn(*self._batch(model.output.get_shape().as_list(), 32)),
      callbacks=[pruning_callbacks.UpdatePruningStep()])

  test_utils.assert_model_sparsity_m_by_n(self, model, m_by_n)

  self._check_strip_pruning_matches_original(model, sparsity_ratio)
def float_cnn_allPrune(name_, Inputs, nclasses, filters, kernel, strides,
                       pooling, dropout, activation, pruning_params=None):
  """Builds a float CNN whose Conv2D/Dense layers are all wrapped for pruning."""
  # Avoid a mutable default argument; an empty dict means default pruning.
  pruning_params = pruning_params or {}
  print("Building model: float_cnn")
  length = len(filters)
  if any(
      len(lst) != length
      for lst in [filters, kernel, strides, pooling, dropout]):
    sys.exit("filters, kernel, strides, pooling and dropout must all have "
             "the same length! Exiting")

  x = x_in = Inputs
  x = BatchNormalization()(x)
  x = ZeroPadding2D(padding=(1, 1), data_format="channels_last")(x)
  # Note: the dropout value `d` is unpacked but currently unused in the loop.
  for i, (f, k, s, p, d) in enumerate(
      zip(filters, kernel, strides, pooling, dropout)):
    print(("Adding layer with {} filters, kernel_size=({},{}), "
           "strides=({},{})").format(f, k, k, s, s))
    x = prune.prune_low_magnitude(
        Conv2D(
            int(f),
            kernel_size=(int(k), int(k)),
            strides=(int(s), int(s)),
            use_bias=False,
            name='conv_%i' % i), **pruning_params)(x)
    if float(p) != 0:
      x = MaxPooling2D(pool_size=(int(p), int(p)))(x)
    x = BatchNormalization()(x)
    x = Activation(activation, name='conv_act_%i' % i)(x)
  x = Flatten()(x)
  x = prune.prune_low_magnitude(
      Dense(
          128,
          kernel_initializer='lecun_uniform',
          use_bias=False,
          name='dense_1'), **pruning_params)(x)
  x = BatchNormalization()(x)
  x = Activation(activation, name='dense_act')(x)
  x_out = Dense(nclasses, activation='softmax', name='output')(x)
  model = Model(inputs=[x_in], outputs=[x_out], name=name_)
  return model
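# A hypothetical invocation of the builder above, assuming `Input` is
# imported alongside the other Keras layers and `pruning_params` is a dict
# like the one built in main() earlier in this section. All hyperparameter
# values here are made up for illustration:
#
#   inputs = Input(shape=(32, 32, 3))
#   model = float_cnn_allPrune(
#       'float_cnn', inputs, nclasses=10,
#       filters=[16, 32], kernel=[3, 3], strides=[1, 1],
#       pooling=[2, 0], dropout=[0.0, 0.0], activation='relu',
#       pruning_params=pruning_params)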
def testPruneMiscObject(self):
  model = object()
  with self.assertRaises(ValueError) as e:
    _ = prune.prune_low_magnitude(model, **self.params)
  self.assertEqual(
      str(e.exception),
      self.INVALID_TO_PRUNE_PARAM_ERROR.format(input='object'))
def testPruneFunctionalModel(self):
  i1 = keras.Input(shape=(10,))
  i2 = keras.Input(shape=(10,))
  x1 = layers.Dense(10)(i1)
  x2 = layers.Dense(10)(i2)
  outputs = layers.Add()([x1, x2])
  model = keras.Model(inputs=[i1, i2], outputs=outputs)
  pruned_model = prune.prune_low_magnitude(model, **self.params)
  self.assertEqual(self._count_pruned_layers(pruned_model), 3)
def testPruneModelRecursively(self):
  internal_model = keras.Sequential(
      [keras.layers.Dense(10, input_shape=(10,))])
  original_model = keras.Sequential([
      internal_model,
      layers.Dense(10),
  ])
  pruned_model = prune.prune_low_magnitude(original_model, **self.params)
  self.assertEqual(self._count_pruned_layers(pruned_model), 2)
def build_layerwise_model(input_shape, **pruning_params):
  return tf.keras.Sequential([
      prune.prune_low_magnitude(
          l.Conv2D(32, 5, padding='same', activation='relu'),
          input_shape=input_shape,
          **pruning_params),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.BatchNormalization(),
      prune.prune_low_magnitude(
          l.Conv2D(64, 5, padding='same', activation='relu'),
          **pruning_params),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.Flatten(),
      prune.prune_low_magnitude(
          l.Dense(1024, activation='relu'), **pruning_params),
      l.Dropout(0.4),
      prune.prune_low_magnitude(
          l.Dense(num_classes, activation='softmax'), **pruning_params)
  ])
def testPruneFunctionalModelWithLayerReused(self):
  # The model reuses the Dense() layer. Make sure it's only pruned once.
  inp = keras.Input(shape=(10,))
  dense_layer = layers.Dense(10)
  x = dense_layer(inp)
  x = dense_layer(x)
  model = keras.Model(inputs=[inp], outputs=[x])
  pruned_model = prune.prune_low_magnitude(model, **self.params)
  self.assertEqual(self._count_pruned_layers(pruned_model), 1)
def testPrunePretrainedModel_SameInferenceWithoutTraining(self):
  model = self._get_pretrained_model()
  pruned_model = prune.prune_low_magnitude(model, **self.params)

  input_data = np.random.rand(10, 10)
  out = model.predict(input_data)
  pruned_out = pruned_model.predict(input_data)

  self.assertTrue((out == pruned_out).all())
def testStripPruningSequentialModel(self):
  model = keras.Sequential([
      layers.Dense(10),
      layers.Dense(10),
  ])

  pruned_model = prune.prune_low_magnitude(model, **self.params)
  stripped_model = prune.strip_pruning(pruned_model)

  self.assertEqual(self._count_pruned_layers(stripped_model), 0)
  self.assertEqual(model.get_config(), stripped_model.get_config())
def testPruneValidLayersListSuccessful(self):
  model_layers = [
      self.keras_prunable_layer,
      self.keras_non_prunable_layer,
      self.custom_prunable_layer,
  ]
  pruned_layers = prune.prune_low_magnitude(model_layers, **self.params)

  self.assertEqual(len(model_layers), len(pruned_layers))
  for layer, pruned_layer in zip(model_layers, pruned_layers):
    self._validate_pruned_layer(layer, pruned_layer)
def testPruneWithHighSparsity_Fails(self):
  params = self.params
  params['pruning_schedule'] = pruning_schedule.ConstantSparsity(
      target_sparsity=0.99, begin_step=0, frequency=1)
  model = prune.prune_low_magnitude(
      keras_test_utils.build_simple_dense_model(), **params)
  with self.assertRaises(tf.errors.InvalidArgumentError):
    self._train_model(model, epochs=1)
def testPruneModelValidLayersSuccessful(self):
  model = keras.Sequential([
      self.keras_prunable_layer,
      self.keras_non_prunable_layer,
      self.custom_prunable_layer,
  ])
  pruned_model = prune.prune_low_magnitude(model, **self.params)
  pruned_model.build(input_shape=(1, 28, 28, 1))

  self.assertEqual(len(model.layers), len(pruned_model.layers))
  for layer, pruned_layer in zip(model.layers, pruned_model.layers):
    self._validate_pruned_layer(layer, pruned_layer)