def jax2tf_to_tfjs(module: lib.ModuleToConvert):
  """Converts the given `module` using the TFjs converter.

  `module.apply`, together with `module.variables`, is first exported as a
  TF SavedModel. It has a single input signature named 'input', whose shape
  and dtype come from `module.input_shape` and `module.dtype`. That
  SavedModel is then passed to `tfjs_converter.convert`. Both the SavedModel
  and the TF.js output are written into `TempDir` scratch directories.
  Nothing is returned; presumably callers only rely on the conversion
  completing without raising (NOTE(review): confirm).

  Args:
    module: the module to convert. Its `apply`, `variables`, `input_shape`
      and `dtype` attributes are read.
  """
  with TempDir() as saved_model_path:
    with TempDir() as converted_model_path:
      input_spec = tf.TensorSpec(
          shape=module.input_shape, dtype=module.dtype, name='input')
      # "PreventGradient" is not supported by TF.js, so the SavedModel must
      # be exported with with_gradient=True; otherwise the TF.js conversion
      # below cannot handle it.
      saved_model_lib.convert_and_save_model(
          module.apply,
          module.variables,
          saved_model_path,
          input_signatures=[input_spec],
          with_gradient=True,
          compile_model=False,
          enable_xla=False)
      tfjs_converter.convert([saved_model_path, converted_model_path])
def main(*args):
  """Trains a QuickDraw classifier and exports it as a TF.js model.

  Pipeline:
    1. Download `NB_CLASSES` classes into /tmp/jax2tf/tf_js_quickdraw/data.
    2. Train Flax parameters on them with `train`.
    3. Save `predict` with those parameters as a TF SavedModel under
       .../saved_models.
    4. Convert that SavedModel into a TF.js model under .../tfjs_models.

  Args:
    *args: unused (presumably the argv passed by the app launcher).
  """
  root_dir = "/tmp/jax2tf/tf_js_quickdraw"
  data_dir = os.path.join(root_dir, "data")

  classes = utils.download_dataset(data_dir, NB_CLASSES)
  assert len(classes) == NB_CLASSES, classes
  print(f"Classes are: {classes}")

  train_ds, test_ds = utils.load_classes(data_dir, classes)
  flax_params = train(train_ds, test_ds, classes)

  saved_model_dir = os.path.join(root_dir, "saved_models")
  serving_spec = tf.TensorSpec([1, 28, 28, 1])
  # TF.js does not support "PreventGradient", so the SavedModel has to be
  # exported with with_gradient=True for the TF.js conversion to work.
  saved_model_lib.convert_and_save_model(
      predict,
      flax_params,
      saved_model_dir,
      input_signatures=[serving_spec],
      with_gradient=True,
      compile_model=False,
      enable_xla=False)

  tfjs_dir = os.path.join(root_dir, 'tfjs_models')
  convert_tf_saved_model(saved_model_dir, tfjs_dir)
def train_and_save():
  """Trains an MNIST model, saves it as a SavedModel, and optionally checks it.

  Behavior is controlled by the following `FLAGS`:

  * `show_images`: plot sample training images. When testing a model that
    has a classifier layer, it also plots sample inference results.
  * `generate_model`: train the class returned by `pick_model_class()` and
    convert/save it to `savedmodel_dir(with_version=True)`. The first input
    signature uses `FLAGS.serving_batch_size`; the other two use the
    train and test batch sizes.
  * `test_savedmodel`: applies only when a model was just generated. It
    reloads the SavedModel and compares its output on an all-ones test batch
    with the JAX `predict_fn`, using tolerances from
    `tf_accelerator_and_tolerances()`.
  * `show_model`: run `saved_model_cli show --all` on the model directory.
  """
  # Local imports keep this change self-contained within the function.
  import shlex
  import subprocess

  logging.info("Loading the MNIST TensorFlow dataset")
  train_ds = mnist_lib.load_mnist(
      tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
  test_ds = mnist_lib.load_mnist(
      tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)

  if FLAGS.show_images:
    mnist_lib.plot_images(train_ds, 1, 5, "Training images", inference_fn=None)

  the_model_class = pick_model_class()
  model_dir = savedmodel_dir(with_version=True)

  if FLAGS.generate_model:
    model_descr = model_description()
    logging.info("Generating model for %s", model_descr)
    (predict_fn, predict_params) = the_model_class.train(
        train_ds,
        test_ds,
        FLAGS.num_epochs,
        with_classifier=FLAGS.model_classifier_layer)

    input_signatures = [
        # The first one will be the serving signature
        tf.TensorSpec((FLAGS.serving_batch_size,) + mnist_lib.input_shape,
                      tf.float32),
        tf.TensorSpec((mnist_lib.train_batch_size,) + mnist_lib.input_shape,
                      tf.float32),
        tf.TensorSpec((mnist_lib.test_batch_size,) + mnist_lib.input_shape,
                      tf.float32),
    ]

    logging.info("Saving model for %s", model_descr)
    saved_model_lib.convert_and_save_model(
        predict_fn,
        predict_params,
        model_dir,
        input_signatures=input_signatures,
        compile_model=FLAGS.compile_model)

    # This check must stay inside the generate_model branch. It needs
    # predict_fn, predict_params and model_descr, which are only bound here.
    if FLAGS.test_savedmodel:
      tf_accelerator, tolerances = tf_accelerator_and_tolerances()
      with tf.device(tf_accelerator):
        logging.info("Testing savedmodel")
        pure_restored_model = tf.saved_model.load(model_dir)

        if FLAGS.show_images and FLAGS.model_classifier_layer:
          mnist_lib.plot_images(
              test_ds,
              1,
              5,
              f"Inference results for {model_descr}",
              inference_fn=pure_restored_model)

        test_input = np.ones(
            (mnist_lib.test_batch_size,) + mnist_lib.input_shape,
            dtype=np.float32)
        np.testing.assert_allclose(
            pure_restored_model(tf.convert_to_tensor(test_input)),
            predict_fn(predict_params, test_input), **tolerances)

  if FLAGS.show_model:
    # Pass an argv list instead of a shell string. A model_dir containing
    # spaces or shell metacharacters is then neither split nor interpreted
    # by a shell.
    cmd = ["saved_model_cli", "show", "--all", "--dir", str(model_dir)]
    # Echo a copy-pasteable, shell-quoted form. Flush so the echo appears
    # before the child's output even when stdout is a pipe.
    print(" ".join(shlex.quote(arg) for arg in cmd), flush=True)
    try:
      # Best-effort like the original os.system call: a non-zero exit status
      # is ignored.
      subprocess.run(cmd, check=False)
    except FileNotFoundError:
      logging.warning("saved_model_cli not found on PATH; cannot show model")