def main(_):
  FLAGS.model_classifier_layer = False  # We only need the features.

  # Train the model and save the feature extractor.
  saved_model_main.train_and_save()
  tf_accelerator, _ = saved_model_main.tf_accelerator_and_tolerances()
  feature_model_dir = saved_model_main.savedmodel_dir()

  # With Keras, we use tf.distribute.OneDeviceStrategy as the high-level
  # analogue of the tf.device(...) placement seen above. It works on CPU,
  # GPU, and TPU. Actual high-performance training would use an
  # appropriately replicated TF distribution strategy.
  strategy = tf.distribute.OneDeviceStrategy(tf_accelerator)
  with strategy.scope():
    images = tf.keras.layers.Input(
        mnist_lib.input_shape, batch_size=mnist_lib.train_batch_size)
    # We do not yet train the SavedModel, due to b/123499169.
    keras_feature_extractor = hub.KerasLayer(feature_model_dir,
                                             trainable=False)
    features = keras_feature_extractor(images)
    predictor = tf.keras.layers.Dense(10, activation="softmax")
    predictions = predictor(features)
    keras_model = tf.keras.Model(images, predictions)
    keras_model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
        metrics=["accuracy"])
  # summary() prints and returns None, so route its output through logging.
  keras_model.summary(print_fn=logging.info)

  train_ds = mnist_lib.load_mnist(
      tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
  test_ds = mnist_lib.load_mnist(
      tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
  keras_model.fit(train_ds, epochs=FLAGS.num_epochs, validation_data=test_ds)

  if FLAGS.show_images:
    mnist_lib.plot_images(
        test_ds, 1, 5,
        f"Keras inference with reuse of {saved_model_main.model_description()}",
        inference_fn=lambda images: keras_model(tf.convert_to_tensor(images)))
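# Side note (not part of the original example): hub.KerasLayer is essentially
# a Keras-friendly wrapper around tf.saved_model.load, so the saved feature
# extractor can also be sanity-checked directly. A minimal sketch, assuming
# the SavedModel's __call__ was traced at train_batch_size (as done by
# saved_model_main.train_and_save() above); check_feature_extractor is a
# hypothetical helper name.
def check_feature_extractor(feature_model_dir):
  restored = tf.saved_model.load(feature_model_dir)
  dummy_batch = tf.zeros(
      (mnist_lib.train_batch_size,) + mnist_lib.input_shape, dtype=tf.float32)
  # With the classifier layer disabled, the output is the feature vector.
  features = restored(dummy_batch)
  logging.info("Feature shape: %s", features.shape)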
def main(_):
  if FLAGS.count_images % FLAGS.serving_batch_size != 0:
    raise ValueError(f"The count_images ({FLAGS.count_images}) must be a "
                     "multiple of "
                     f"serving_batch_size ({FLAGS.serving_batch_size})")
  test_ds = mnist_lib.load_mnist(
      tfds.Split.TEST, batch_size=FLAGS.serving_batch_size)
  images_and_labels = tfds.as_numpy(
      test_ds.take(FLAGS.count_images // FLAGS.serving_batch_size))

  accurate_count = 0
  for batch_idx, (images, labels) in enumerate(images_and_labels):
    predictions_one_hot = serving_call_mnist(images)
    predictions_digit = np.argmax(predictions_one_hot, axis=1)
    labels_digit = np.argmax(labels, axis=1)
    accurate_count += np.sum(labels_digit == predictions_digit)
    running_accuracy = (
        100. * accurate_count / (1 + batch_idx) / FLAGS.serving_batch_size)
    logging.info(
        f" predicted digits = {predictions_digit} labels {labels_digit}. "
        f"Running accuracy {running_accuracy:.3f}%")
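# serving_call_mnist is defined elsewhere in this example. A minimal sketch of
# what such a call could look like against the TensorFlow Serving REST predict
# API is below; the server URL, the model name "mnist", and the use of the
# `requests` package are assumptions, not part of the original code.
import json

import numpy as np
import requests


def serving_call_mnist_sketch(images):
  # TF Serving's REST predict endpoint expects a JSON body with an
  # "instances" list holding one input per row.
  data = json.dumps({"instances": images.tolist()})
  response = requests.post(
      "http://localhost:8501/v1/models/mnist:predict",
      data=data,
      headers={"content-type": "application/json"})
  response.raise_for_status()
  # The reply carries one vector of per-class probabilities per input image.
  return np.array(json.loads(response.text)["predictions"])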
def main(_):
  logging.info('Loading the MNIST TensorFlow dataset')
  train_ds = mnist_lib.load_mnist(
      tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
  test_ds = mnist_lib.load_mnist(
      tfds.Split.TEST, batch_size=FLAGS.serving_batch_size)

  (flax_predict, flax_params) = mnist_lib.FlaxMNIST.train(
      train_ds, test_ds, FLAGS.num_epochs)

  def predict(image):
    return flax_predict(flax_params, image)

  # Convert the Flax model to a TF function.
  tf_predict = tf.function(
      jax2tf.convert(predict, enable_xla=False),
      input_signature=[
          tf.TensorSpec(
              shape=[FLAGS.serving_batch_size, 28, 28, 1],
              dtype=tf.float32,
              name='input')
      ],
      autograph=False)

  # Convert the TF function to the TF Lite format.
  converter = tf.lite.TFLiteConverter.from_concrete_functions(
      [tf_predict.get_concrete_function()])
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
      tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
  ]
  tflite_float_model = converter.convert()

  # Show the model size in KB.
  float_model_size = len(tflite_float_model) / 1024
  print('Float model size = %dKB.' % float_model_size)

  # Re-convert the model to TF Lite using quantization.
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  tflite_quantized_model = converter.convert()

  # Show the model size in KB.
  quantized_model_size = len(tflite_quantized_model) / 1024
  print('Quantized model size = %dKB,' % quantized_model_size)
  print('which is about %d%% of the float model size.' %
        (quantized_model_size * 100 / float_model_size))

  # Evaluate the TF Lite float model. Its accuracy is identical to that of
  # the original Flax model, because they are essentially the same model
  # stored in different formats.
  float_accuracy = evaluate_tflite_model(tflite_float_model, test_ds)
  print('Float model accuracy = %.4f' % float_accuracy)

  # Evaluate the TF Lite quantized model.
  # Don't be surprised if the quantized model's accuracy is higher than
  # that of the original float model. It happens sometimes :)
  quantized_accuracy = evaluate_tflite_model(tflite_quantized_model, test_ds)
  print('Quantized model accuracy = %.4f' % quantized_accuracy)
  print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))

  with open(FLAGS.tflite_file_path, 'wb') as f:
    f.write(tflite_quantized_model)
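# evaluate_tflite_model is defined elsewhere in this example. A minimal sketch
# of such an evaluation loop with tf.lite.Interpreter is shown below;
# evaluate_tflite_model_sketch is a hypothetical name, and the loop assumes
# every batch in test_ds has exactly FLAGS.serving_batch_size images, matching
# the input signature the model was converted with.
def evaluate_tflite_model_sketch(tflite_model, test_ds):
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  input_index = interpreter.get_input_details()[0]['index']
  output_index = interpreter.get_output_details()[0]['index']

  accurate_count = 0
  total_count = 0
  for images, labels in tfds.as_numpy(test_ds):
    interpreter.set_tensor(input_index, images.astype(np.float32))
    interpreter.invoke()
    predictions = interpreter.get_tensor(output_index)
    accurate_count += np.sum(
        np.argmax(predictions, axis=1) == np.argmax(labels, axis=1))
    total_count += images.shape[0]
  return accurate_count / total_count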
def train_and_save():
  logging.info("Loading the MNIST TensorFlow dataset")
  train_ds = mnist_lib.load_mnist(
      tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
  test_ds = mnist_lib.load_mnist(
      tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)

  if FLAGS.show_images:
    mnist_lib.plot_images(train_ds, 1, 5, "Training images",
                          inference_fn=None)

  the_model_class = pick_model_class()
  model_dir = savedmodel_dir(with_version=True)
  # Compute the description up front: it is also needed below when only
  # --test_savedmodel is set.
  model_descr = model_description()

  if FLAGS.generate_model:
    logging.info(f"Generating model for {model_descr}")
    (predict_fn, predict_params) = the_model_class.train(
        train_ds,
        test_ds,
        FLAGS.num_epochs,
        with_classifier=FLAGS.model_classifier_layer)
    input_signatures = [
        # The first one will be the serving signature.
        tf.TensorSpec((FLAGS.serving_batch_size,) + mnist_lib.input_shape,
                      tf.float32),
        tf.TensorSpec((mnist_lib.train_batch_size,) + mnist_lib.input_shape,
                      tf.float32),
        tf.TensorSpec((mnist_lib.test_batch_size,) + mnist_lib.input_shape,
                      tf.float32),
    ]
    logging.info(f"Saving model for {model_descr}")
    saved_model_lib.convert_and_save_model(
        predict_fn,
        predict_params,
        model_dir,
        input_signatures=input_signatures,
        compile_model=FLAGS.compile_model)

  if FLAGS.test_savedmodel:
    # Note: the test compares against predict_fn from the training step above,
    # so --test_savedmodel requires --generate_model in the same run.
    tf_accelerator, tolerances = tf_accelerator_and_tolerances()
    with tf.device(tf_accelerator):
      logging.info("Testing savedmodel")
      pure_restored_model = tf.saved_model.load(model_dir)

      if FLAGS.show_images and FLAGS.model_classifier_layer:
        mnist_lib.plot_images(
            test_ds, 1, 5,
            f"Inference results for {model_descr}",
            inference_fn=pure_restored_model)

      test_input = np.ones(
          (mnist_lib.test_batch_size,) + mnist_lib.input_shape,
          dtype=np.float32)
      np.testing.assert_allclose(
          pure_restored_model(tf.convert_to_tensor(test_input)),
          predict_fn(predict_params, test_input), **tolerances)

  if FLAGS.show_model:

    def print_model(model_dir: str):
      cmd = f"saved_model_cli show --all --dir {model_dir}"
      print(cmd)
      os.system(cmd)

    print_model(model_dir)
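# saved_model_lib.convert_and_save_model is defined in a sibling module. The
# core idea, sketched below under the assumption that predict_params is a
# pytree of arrays, is: close over the parameters as tf.Variables, wrap the
# jax2tf-converted function in a tf.function, trace one concrete function per
# supported batch size, and save everything as a SavedModel. This is a
# hypothetical sketch, not the library's implementation; the real helper also
# makes the saved object directly callable, which is omitted here.
def convert_and_save_model_sketch(jax_fn, params, model_dir,
                                  input_signatures, compile_model=True):
  # Capture the trained parameters as variables so they are saved with the
  # model instead of being baked into the graph as constants.
  param_vars = tf.nest.map_structure(
      lambda p: tf.Variable(p, trainable=False), params)
  tf_fn = tf.function(
      lambda inputs: jax2tf.convert(jax_fn)(param_vars, inputs),
      autograph=False,
      jit_compile=compile_model)
  # The first signature becomes the serving signature; the others are traced
  # so that those batch sizes also work after loading.
  signatures = {
      tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          tf_fn.get_concrete_function(input_signatures[0])
  }
  for spec in input_signatures[1:]:
    tf_fn.get_concrete_function(spec)
  wrapper = tf.Module()
  wrapper.param_vars = tf.nest.flatten(param_vars)  # Track the variables.
  wrapper.call = tf_fn
  tf.saved_model.save(wrapper, model_dir, signatures=signatures)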