num_classes = 12 # should probably draw this directly from the dataset. # FLAGS = None if __name__ == '__main__': fname = 'quant_cal_idxs.txt' num_files_per_label = 10 Flags, unparsed = kws_util.parse_command() np.random.seed(2) tf.random.set_seed(2) print('We will download data to {:}'.format(Flags.data_dir)) ds_train, ds_test, ds_val = kws_data.get_training_data(Flags) print("Done getting data") labels = np.array([]) for _, batch_labels in ds_val: labels = np.hstack((labels, batch_labels)) cal_idxs = np.array([], dtype=int) for l in np.unique(labels): all_label_idxs = np.nonzero(labels == l)[ 0] # nonzero => tuple of arrays; get the first/only one sel_label_idxs = np.random.choice(all_label_idxs, size=num_files_per_label, replace=False) cal_idxs = np.concatenate((cal_idxs, sel_label_idxs))
# Convert the trained Keras KWS model to a quantized TensorFlow Lite model,
# calibrating on the validation-set examples whose indices are listed in
# quant_cal_idxs.txt.
Flags, unparsed = kws_util.parse_command()

print(
    f"Converting trained model {Flags.saved_model_path} to TFL model at {Flags.tfl_file_name}"
)

model = tf.keras.models.load_model(Flags.saved_model_path)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# One integer index per line; together they name the calibration subset.
with open("quant_cal_idxs.txt") as fpi:
    cal_indices = [int(line) for line in fpi]
cal_indices.sort()
num_calibration_steps = len(cal_indices)

# val_cal_subset=True presumably restricts ds_val to the calibration subset
# selected above — TODO confirm against kws_data.get_training_data.
# Re-batch to single examples, the shape the converter's representative
# dataset is expected to yield.
_, _, ds_val = kws_data.get_training_data(Flags, val_cal_subset=True)
ds_val = ds_val.unbatch().batch(1)

if True:  # enable if you want to check the distribution of labels in the calibration set
    # Count how many calibration examples carry each of the 12 labels and
    # print the histogram as a sanity check on class balance.
    label_counts = {}
    for label in range(12):
        label_counts[label] = 0
    for _, label in ds_val.as_numpy_iterator():
        label_counts[label[0]] += 1
    for label in range(12):
        print(f"Cal set has {label_counts[label]} of label {label}")

# Single shared iterator: successive generator steps walk successive
# calibration examples instead of restarting from the first one.
ds_iter = ds_val.as_numpy_iterator()


def representative_dataset_gen():
    # Yields one calibration example per step for the quantizer.
    # NOTE(review): the loop body is truncated in this view; the function
    # continues past the visible span.
    for _ in range(num_calibration_steps):
import os
import argparse

import numpy as np

import keras_model as models
import get_dataset as aww_data
import aww_util

num_classes = 12 # should probably draw this directly from the dataset.

# FLAGS = None


def main():
    """Train the always-on wake-word model, save it, and optionally evaluate it.

    Command-line flags (data_dir, epochs, model_architecture, saved_model_path,
    run_test_set) are parsed by aww_util.parse_command().
    """
    Flags, unparsed = aww_util.parse_command()

    print('We will download data to {:}'.format(Flags.data_dir))
    print('We will train for {:} epochs'.format(Flags.epochs))

    # Datasets come pre-batched from the project data pipeline.
    ds_train, ds_test, ds_val = aww_data.get_training_data(Flags)
    print("Done getting data")

    net = models.get_model(model_name=Flags.model_architecture)
    net.summary()

    # Validation runs each epoch; the fit history is not used further here.
    fit_history = net.fit(ds_train,
                          validation_data=ds_val,
                          epochs=Flags.epochs)
    net.save(Flags.saved_model_path)

    if Flags.run_test_set:
        # evaluate() returns [loss, metric, ...] per the model's compile config.
        eval_results = net.evaluate(ds_test)
        print("Test loss:", eval_results[0])
        print("Test accuracy:", eval_results[1])


if __name__ == '__main__':
    main()
import os
import argparse

import numpy as np
# Bug fix: tf was used throughout (TFLiteConverter, Optimize, int8) but never
# imported, so the original script failed with NameError at runtime.
import tensorflow as tf

import get_dataset as aww_data
import aww_util

# Convert the trained AWW SavedModel to a fully-integer-quantized TFLite
# model, using a handful of validation examples for calibration.
if __name__ == '__main__':
    Flags, unparsed = aww_util.parse_command()

    num_calibration_steps = 10  # validation samples fed to the quantizer

    converter = tf.lite.TFLiteConverter.from_saved_model(
        Flags.saved_model_path)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    _, _, ds_val = aww_data.get_training_data(Flags)
    # ds_val = ds_val.batch(1)  # can we use a larger batch?

    # Bug fix: the original called ds_val.as_numpy_iterator() INSIDE the
    # generator loop, creating a fresh iterator every step, so all
    # num_calibration_steps steps calibrated on the same first sample.
    # Create the iterator once so each step sees the next example.
    val_iter = ds_val.as_numpy_iterator()

    def representative_dataset_gen():
        # Yield one (1, H, W, 1) float example per calibration step; axis 3 is
        # expanded to add the single-channel dimension the model expects
        # (presumably — confirm against the model's input signature).
        for _ in range(num_calibration_steps):
            next_input = np.expand_dims(next(val_iter)[0], 3)
            yield [next_input]

    # (The duplicate converter.optimizations assignment was removed.)
    converter.representative_dataset = representative_dataset_gen
    # Force full-integer quantization: int8 ops, int8 input/output tensors.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8  # or tf.uint8
    converter.inference_output_type = tf.int8  # or tf.uint8

    tflite_quant_model = converter.convert()