def train(model, x_train, y_train, x_test, y_test, *, batch_size=None,
          epochs=None, log_dir='/tmp/logs'):
    """Compile, fit and evaluate a pruning-wrapped model, then strip pruning.

    Args:
      model: Keras model to train. It is presumably already wrapped for
        pruning, since the pruning-step callback is attached. It is compiled
        in place.
      x_train: Training inputs.
      y_train: Training labels.
      x_test: Held-out inputs, used both for per-epoch validation and for the
        final evaluation.
      y_test: Held-out labels matching ``x_test``.
      batch_size: Training batch size. When None, the module-level
        ``batch_size`` is used, which was the previous (implicit) behaviour.
      epochs: Number of training epochs. When None, the module-level
        ``epochs`` is used.
      log_dir: Directory that receives the pruning TensorBoard summaries.

    Returns:
      The trained model with its pruning wrappers removed by
      ``prune.strip_pruning``.
    """
    # Fall back to the module globals this function used to read implicitly,
    # so callers that do not pass these arguments behave exactly as before.
    if batch_size is None:
        batch_size = globals()['batch_size']
    if epochs is None:
        epochs = globals()['epochs']

    model.compile(loss=tf.keras.losses.categorical_crossentropy,
                  optimizer='adam',
                  metrics=['accuracy'])

    # Print the model summary.
    model.summary()

    # UpdatePruningStep pegs the pruning step to the optimizer's step;
    # PruningSummaries writes pruning statistics for TensorBoard.
    callbacks = [
        pruning_callbacks.UpdatePruningStep(),
        pruning_callbacks.PruningSummaries(log_dir=log_dir),
    ]

    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              callbacks=callbacks,
              validation_data=(x_test, y_test))
    score = model.evaluate(x_test, y_test, verbose=0)
    # With metrics=['accuracy'], evaluate returns [loss, accuracy].
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])

    model = prune.strip_pruning(model)
    return model
def testStripPruningSequentialModel(self):
    """Stripping a pruned Sequential model yields the original configuration."""
    original = keras.Sequential([
        layers.Dense(10),
        layers.Dense(10),
    ])

    wrapped = prune.prune_low_magnitude(original, **self.params)
    unwrapped = prune.strip_pruning(wrapped)

    # No pruning wrapper may survive, and the layer graph must round-trip.
    self.assertEqual(self._count_pruned_layers(unwrapped), 0)
    self.assertEqual(original.get_config(), unwrapped.get_config())
def _check_strip_pruning_matches_original(self, model, sparsity):
    """Assert that stripping pruning keeps sparsity and predictions intact.

    Args:
      model: Pruning-wrapped model that is stripped and compared against.
      sparsity: Expected sparsity, checked on the stripped model via
        ``test_utils.assert_model_sparsity``.
    """
    unwrapped = prune.strip_pruning(model)
    test_utils.assert_model_sparsity(self, sparsity, unwrapped)

    # self._batch presumably prepends a batch dimension of 1 to the model's
    # input shape; a single random sample is enough to compare outputs.
    input_shape = model.input.get_shape().as_list()
    sample = np.random.randn(*self._batch(input_shape, 1))

    expected = model.predict(sample)
    actual = unwrapped.predict(sample)
    np.testing.assert_almost_equal(expected, actual)
def save(self, id):
    """Save the current model as Keras HDF5 and as a converted TFLite file.

    Writes ``./models/<id>.h5`` and then ``./models/<id>.tflite``. The paths
    are relative to the working directory, and the ``models`` directory is
    not created here. When ``self.prune_params`` is set, a copy with pruning
    wrappers stripped is saved. ``self.model`` itself is left untouched, so
    pruning continues if training resumes after a checkpoint.

    Args:
      id: Identifier used as the file stem. It shadows the builtin ``id``,
        but the name is kept because callers may pass it by keyword.
    """
    with self.graph.as_default(), self.session.as_default():
        keras_file_path = './models/{}.h5'.format(id)

        model_to_save = self.model
        if self.prune_params:
            # Strip into a separate object rather than reassigning self.model,
            # which would silently disable pruning for any further training.
            model_to_save = prune.strip_pruning(self.model)
        save_keras_model(model_to_save, keras_file_path)

        tflite_path = './models/{}.tflite'.format(id)
        convert_keras_file_to_tflite(keras_file_path, tflite_path,
                                     quantized=self.quantized)
def pruning(epochs_prune, keras_file, training_data, batch_size,
            training_labels, validation_data, validation_labels, testing_data,
            testing_labels, Path, *, profile_batch=None):
    """Fine-tune a saved Keras model with magnitude pruning, then strip it.

    Loads the model from ``keras_file`` and wraps it with a polynomial-decay
    schedule (50% -> 80% sparsity, ending after ``epochs_prune`` epochs). It
    then trains it, prints the test loss and accuracy, and returns the model
    with its pruning wrappers removed.

    Args:
      epochs_prune: Number of fine-tuning epochs. It also fixes the
        schedule's end step.
      keras_file: Path of the saved Keras model to load.
      training_data: Training inputs; ``shape[0]`` is the sample count.
      batch_size: Training batch size.
      training_labels: Training labels.
      validation_data: Inputs for per-epoch validation.
      validation_labels: Labels for per-epoch validation.
      testing_data: Inputs for the final evaluation.
      testing_labels: Labels for the final evaluation.
      Path: Base directory string. Pruning summaries are written to
        ``Path + '/ThesisShrink/model/'``.
      profile_batch: Optional TensorBoard ``profile_batch`` (which batch or
        batches to profile) forwarded to ``PruningSummaries``. When None, the
        TensorBoard default is used.

    Returns:
      The trained Keras model with pruning wrappers stripped.
    """
    logdir = Path + '/ThesisShrink/model/'
    loaded_model = tf.keras.models.load_model(keras_file)

    # Total optimizer steps for the run; the sparsity ramp finishes here.
    num_train_samples = training_data.shape[0]
    end_step = np.ceil(1.0 * num_train_samples / batch_size).astype(
        np.int32) * epochs_prune
    print('End Step:')
    print(end_step)

    new_pruning_params = {
        'pruning_schedule': PolynomialDecay(initial_sparsity=0.50,
                                            final_sparsity=0.80,
                                            begin_step=0,
                                            end_step=end_step,
                                            frequency=50)
    }
    new_pruned_model = prune_low_magnitude(loaded_model, **new_pruning_params)
    new_pruned_model.summary()

    new_pruned_model.compile(loss=tf.keras.losses.categorical_crossentropy,
                             optimizer='adam',
                             metrics=['accuracy'])

    # profile_batch selects which batch(es) the TensorBoard profiler samples.
    # It is not the training batch size, so it is not derived from batch_size.
    summary_kwargs = {}
    if profile_batch is not None:
        summary_kwargs['profile_batch'] = profile_batch
    callbacks = [
        UpdatePruningStep(),
        PruningSummaries(log_dir=logdir, **summary_kwargs),
        WandbCallback(),
    ]

    new_pruned_model.fit(training_data, training_labels,
                         batch_size=batch_size,
                         epochs=epochs_prune,
                         verbose=1,
                         callbacks=callbacks,
                         validation_data=(validation_data, validation_labels))

    score = new_pruned_model.evaluate(testing_data, testing_labels, verbose=0)
    # With metrics=['accuracy'], evaluate returns [loss, accuracy].
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])

    final_model = strip_pruning(new_pruned_model)
    final_model.summary()
    return final_model
def _prune_model(original_model):
    """Wrap ``original_model`` for pruning, train it, and strip a copy.

    Uses a constant 50% sparsity schedule (updated every 10 steps) and
    delegates one epoch of training to ``_train_model``. That helper is
    presumably responsible for compiling and fitting the model.

    Returns:
      A ``(pruned, stripped)`` tuple: the trained model still carrying its
      pruning wrappers, and the same model with the wrappers removed.
    """
    schedule = pruning_schedule.ConstantSparsity(
        0.50, begin_step=0, frequency=10)
    wrapped = prune.prune_low_magnitude(original_model,
                                        pruning_schedule=schedule)

    num_epochs = 1
    trained = _train_model(wrapped,
                           [pruning_callbacks.UpdatePruningStep()],
                           num_epochs)

    stripped = prune.strip_pruning(trained)
    return trained, stripped
def testStripPruningFunctionalModel(self):
    """Stripping a pruned two-input functional model restores its config."""
    left_input = keras.Input(shape=(10, ))
    right_input = keras.Input(shape=(10, ))

    left_branch = layers.Dense(10)(left_input)
    right_branch = layers.Dense(10)(right_input)
    merged = layers.Add()([left_branch, right_branch])

    original = keras.Model(inputs=[left_input, right_input], outputs=merged)

    wrapped = prune.prune_low_magnitude(original, **self.params)
    unwrapped = prune.strip_pruning(wrapped)

    # All wrappers gone, and the model config must match the unpruned one.
    self.assertEqual(self._count_pruned_layers(unwrapped), 0)
    self.assertEqual(original.get_config(), unwrapped.get_config())
def run(input_model_path, output_dir, target_sparsity, block_size):
    """Write unpruned and pruned TFLite versions of a model and compare sizes.

    Loads the Keras model at ``input_model_path`` and writes an unpruned
    TFLite file into ``output_dir``, which is created if needed. It then
    prunes the model to ``target_sparsity`` with ``block_size`` for
    benchmarking, strips the pruning wrappers, writes a second TFLite file,
    and prints the gzipped size of both files plus their percentage
    difference.
    """
    print(
        textwrap.dedent("""\
            Warning: The sparse models produced by this tool have poor accuracy.
            They are not intended to be served in production, but to be used for
            performance benchmarking."""))

    source_model = tf.keras.models.load_model(input_model_path)

    os.makedirs(output_dir, exist_ok=True)
    baseline_path = os.path.join(output_dir, 'unpruned_model.tflite')
    sparse_path = os.path.join(
        output_dir, pruned_model_filename(target_sparsity, block_size))

    # Baseline: convert the model as loaded.
    convert_to_tflite(source_model, baseline_path)

    # Sparse variant: prune, drop the pruning wrappers, then convert.
    sparse_model = prune.strip_pruning(
        sparsity_tooling.prune_for_benchmark(keras_model=source_model,
                                             target_sparsity=target_sparsity,
                                             block_size=block_size))
    convert_to_tflite(sparse_model, sparse_path)

    # Gzipped size approximates what the sparsity buys after compression.
    baseline_bytes = get_gzipped_size(baseline_path)
    sparse_bytes = get_gzipped_size(sparse_path)
    bytes_per_mib = 2.**20
    change_pct = 100. * (sparse_bytes - baseline_bytes) / baseline_bytes

    print('Size of gzipped TFLite models:')
    print(' * Unpruned : %.2fMiB' % (baseline_bytes / bytes_per_mib))
    print(' * Pruned : %.2fMiB' % (sparse_bytes / bytes_per_mib))
    print(' diff : %d%%' % change_pct)
def prune_preserve_quantize_model(pruned_model, train_images, train_labels):
    """Run sparsity-preserving quantization-aware training on a pruned model.

    Strips the pruning wrappers, annotates the model for quantization, and
    applies the Default8Bit prune-preserve quantize scheme. It prints the
    summary, then trains via ``compile_and_fit`` with batch size 256 for
    5 epochs.

    Returns:
      The trained quantization-aware model.
    """
    stripped = prune.strip_pruning(pruned_model)

    # Prune-preserving QAT: annotate first, then apply the scheme that keeps
    # the existing sparsity pattern.
    annotated = quantize.quantize_annotate_model(stripped)
    qat_model = quantize.quantize_apply(
        annotated,
        scheme=default_8bit_prune_preserve_quantize_scheme
        .Default8BitPrunePreserveQuantizeScheme())
    qat_model.summary()

    compile_and_fit(qat_model,
                    train_images,
                    train_labels,
                    compile_kwargs={},
                    fit_kwargs=dict(batch_size=256, epochs=5))
    return qat_model