Example #1
0
def jax2tf_to_tfjs(module: lib.ModuleToConvert):
    """Converts `module` to TF.js by way of an intermediate TF SavedModel.

    The module's `apply` function and `variables` are first exported with
    jax2tf as a SavedModel. That SavedModel is then handed to the TF.js
    converter. Both the SavedModel and the TF.js output are written into
    `TempDir` directories that only live for the duration of the call
    (presumably deleted by `TempDir` on exit — confirm). The function
    therefore only checks that the conversion succeeds and returns nothing.

    Args:
      module: the model to convert. Its `apply`, `variables`, `input_shape`
        and `dtype` attributes are read.
    """
    with TempDir() as tf_model_dir:
        with TempDir() as tfjs_model_dir:
            serving_spec = tf.TensorSpec(shape=module.input_shape,
                                         dtype=module.dtype,
                                         name='input')
            # TF.js has no support for the "PreventGradient" op, so the
            # SavedModel has to be exported with `with_gradient=True`.
            saved_model_lib.convert_and_save_model(
                module.apply,
                module.variables,
                tf_model_dir,
                input_signatures=[serving_spec],
                with_gradient=True,
                compile_model=False,
                enable_xla=False,
            )
            tfjs_converter.convert([tf_model_dir, tfjs_model_dir])
Example #2
0
def main(*args):
  """Trains the QuickDraw classifier and exports it as a TF.js model.

  Steps:
  1. Download the dataset.
  2. Train the Flax model.
  3. Save `predict` with the trained parameters as a TF SavedModel through
     jax2tf.
  4. Convert that SavedModel to TF.js.

  All artifacts are written under `/tmp/jax2tf/tf_js_quickdraw`.

  Args:
    *args: unused; presumably the command-line arguments passed by the app
      runner (e.g. `absl.app.run`) — TODO confirm.

  Raises:
    ValueError: if the downloaded dataset does not contain exactly
      `NB_CLASSES` classes.
  """
  del args  # Unused.
  base_model_path = "/tmp/jax2tf/tf_js_quickdraw"
  dataset_path = os.path.join(base_model_path, "data")
  classes = utils.download_dataset(dataset_path, NB_CLASSES)
  # Explicit check rather than `assert`: assert statements are removed when
  # Python runs with -O, which would silently skip this validation.
  if len(classes) != NB_CLASSES:
    raise ValueError(
        f"Expected {NB_CLASSES} classes, got {len(classes)}: {classes}")
  print(f"Classes are: {classes}")
  train_ds, test_ds = utils.load_classes(dataset_path, classes)
  flax_params = train(train_ds, test_ds, classes)

  model_dir = os.path.join(base_model_path, "saved_models")
  # The model must be converted with `with_gradient=True` for the SavedModel
  # to be convertible to TF.js, as "PreventGradient" is not supported there.
  saved_model_lib.convert_and_save_model(
      predict,
      flax_params,
      model_dir,
      # No dtype given, so `tf.TensorSpec` uses its default (tf.float32).
      input_signatures=[tf.TensorSpec([1, 28, 28, 1])],
      with_gradient=True,
      compile_model=False,
      enable_xla=False)
  conversion_dir = os.path.join(base_model_path, "tfjs_models")
  convert_tf_saved_model(model_dir, conversion_dir)
Example #3
0
def train_and_save():
  """Trains an MNIST model, saves it as a SavedModel and optionally checks it.

  Behavior is driven by flags:
    * `show_images`: plot sample training images and, when a model with a
      classifier layer is tested, its inference results.
    * `generate_model`: train the class returned by `pick_model_class()`.
      Then save its prediction function with jax2tf to
      `savedmodel_dir(with_version=True)`.
    * `test_savedmodel` (only with `generate_model`): reload the SavedModel.
      Compare its output on an all-ones test batch against the JAX
      prediction function.
    * `show_model`: print the SavedModel description using `saved_model_cli`.
  """
  logging.info("Loading the MNIST TensorFlow dataset")
  train_ds = mnist_lib.load_mnist(
      tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
  test_ds = mnist_lib.load_mnist(
      tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)

  if FLAGS.show_images:
    mnist_lib.plot_images(train_ds, 1, 5, "Training images", inference_fn=None)

  the_model_class = pick_model_class()
  model_dir = savedmodel_dir(with_version=True)

  if FLAGS.generate_model:
    model_descr = model_description()
    logging.info("Generating model for %s", model_descr)
    (predict_fn, predict_params) = the_model_class.train(
        train_ds,
        test_ds,
        FLAGS.num_epochs,
        with_classifier=FLAGS.model_classifier_layer)

    input_signatures = [
        # The first one will be the serving signature
        tf.TensorSpec((FLAGS.serving_batch_size,) + mnist_lib.input_shape,
                      tf.float32),
        tf.TensorSpec((mnist_lib.train_batch_size,) + mnist_lib.input_shape,
                      tf.float32),
        tf.TensorSpec((mnist_lib.test_batch_size,) + mnist_lib.input_shape,
                      tf.float32),
    ]

    logging.info("Saving model for %s", model_descr)
    saved_model_lib.convert_and_save_model(
        predict_fn,
        predict_params,
        model_dir,
        input_signatures=input_signatures,
        compile_model=FLAGS.compile_model)

    if FLAGS.test_savedmodel:
      tf_accelerator, tolerances = tf_accelerator_and_tolerances()
      with tf.device(tf_accelerator):
        logging.info("Testing savedmodel")
        pure_restored_model = tf.saved_model.load(model_dir)

        if FLAGS.show_images and FLAGS.model_classifier_layer:
          mnist_lib.plot_images(
              test_ds,
              1,
              5,
              f"Inference results for {model_descr}",
              inference_fn=pure_restored_model)

        # The restored TF model and the original JAX function must agree
        # (within the accelerator-specific tolerances) on the same input.
        test_input = np.ones(
            (mnist_lib.test_batch_size,) + mnist_lib.input_shape,
            dtype=np.float32)
        np.testing.assert_allclose(
            pure_restored_model(tf.convert_to_tensor(test_input)),
            predict_fn(predict_params, test_input), **tolerances)

  if FLAGS.show_model:
    import subprocess  # Local import: only needed for this optional step.

    # Pass an argument list (no shell) so a model_dir with spaces or shell
    # metacharacters is neither split nor interpreted by a shell.
    cli_args = ["saved_model_cli", "show", "--all", "--dir", str(model_dir)]
    # Flush so the command header is printed before the child's output.
    print(" ".join(cli_args), flush=True)
    try:
      # Best-effort, like the previous `os.system` call: exit status ignored.
      subprocess.run(cli_args, check=False)
    except FileNotFoundError:
      logging.warning("saved_model_cli not found; cannot show model in %s",
                      model_dir)