Example #1
def main(_):
    FLAGS.model_classifier_layer = False  # We only need the features
    # Train the model and save the feature extractor
    saved_model_main.train_and_save()

    tf_accelerator, _ = saved_model_main.tf_accelerator_and_tolerances()
    feature_model_dir = saved_model_main.savedmodel_dir()

    # With Keras, we use tf.distribute.OneDeviceStrategy as the high-level
    # analogue of the tf.device(...) placement used in the other examples.
    # It works on CPU, GPU, and TPU.
    # Actual high-performance training would use an appropriately replicated
    # distribution strategy.
    strategy = tf.distribute.OneDeviceStrategy(tf_accelerator)
    with strategy.scope():
        images = tf.keras.layers.Input(mnist_lib.input_shape,
                                       batch_size=mnist_lib.train_batch_size)
        # We do not yet train the SavedModel, due to b/123499169.
        keras_feature_extractor = hub.KerasLayer(feature_model_dir,
                                                 trainable=False)
        features = keras_feature_extractor(images)
        predictor = tf.keras.layers.Dense(10, activation="softmax")
        predictions = predictor(features)
        keras_model = tf.keras.Model(images, predictions)

    keras_model.compile(loss=tf.keras.losses.categorical_crossentropy,
                        optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
                        metrics=["accuracy"])
    # summary() prints and returns None, so route its output through logging
    # instead of logging its return value.
    keras_model.summary(print_fn=logging.info)

    train_ds = mnist_lib.load_mnist(tfds.Split.TRAIN,
                                    batch_size=mnist_lib.train_batch_size)
    test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
                                   batch_size=mnist_lib.test_batch_size)
    keras_model.fit(train_ds, epochs=FLAGS.num_epochs, validation_data=test_ds)

    if FLAGS.show_images:
        mnist_lib.plot_images(
            test_ds,
            1,
            5,
            f"Keras inference with reuse of {saved_model_main.model_description()}",
            inference_fn=lambda images: keras_model(
                tf.convert_to_tensor(images)))
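
These snippets come from a larger example file and rely on its module-level imports. A minimal sketch of what that preamble would look like, assuming mnist_lib and saved_model_main are the project-local helper modules referenced above (the exact import set is an assumption):

from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

import mnist_lib          # assumed project-local module with the MNIST helpers
import saved_model_main   # assumed project-local module with the model helpers

FLAGS = flags.FLAGS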
Example #2
def main(_):
    if FLAGS.count_images % FLAGS.serving_batch_size != 0:
        raise ValueError(f"The count_images ({FLAGS.count_images}) must be a "
                         "multiple of "
                         f"serving_batch_size ({FLAGS.serving_batch_size})")
    test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
                                   batch_size=FLAGS.serving_batch_size)
    images_and_labels = tfds.as_numpy(
        test_ds.take(FLAGS.count_images // FLAGS.serving_batch_size))

    accurate_count = 0
    for batch_idx, (images, labels) in enumerate(images_and_labels):
        predictions_one_hot = serving_call_mnist(images)
        predictions_digit = np.argmax(predictions_one_hot, axis=1)
        labels_digit = np.argmax(labels, axis=1)
        accurate_count += np.sum(labels_digit == predictions_digit)
        # Accuracy averaged over all the batches processed so far.
        running_accuracy = (100. * accurate_count / (1 + batch_idx) /
                            FLAGS.serving_batch_size)
        logging.info(
            f" predicted digits = {predictions_digit} labels {labels_digit}. "
            f"Running accuracy {running_accuracy:.3f}%")
Example #3
File: mnist.py Project: gnecula/jax
def main(_):
  logging.info('Loading the MNIST TensorFlow dataset')
  train_ds = mnist_lib.load_mnist(
      tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
  test_ds = mnist_lib.load_mnist(
      tfds.Split.TEST, batch_size=FLAGS.serving_batch_size)

  (flax_predict,
   flax_params) = mnist_lib.FlaxMNIST.train(train_ds, test_ds, FLAGS.num_epochs)

  def predict(image):
    return flax_predict(flax_params, image)

  # Convert Flax model to TF function.
  tf_predict = tf.function(
      jax2tf.convert(predict, enable_xla=False),
      input_signature=[
          tf.TensorSpec(
              shape=[FLAGS.serving_batch_size, 28, 28, 1],
              dtype=tf.float32,
              name='input')
      ],
      autograph=False)

  # Convert TF function to TF Lite format.
  converter = tf.lite.TFLiteConverter.from_concrete_functions(
      [tf_predict.get_concrete_function()])
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
      tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
  ]
  tflite_float_model = converter.convert()

  # Show the model size in KB.
  float_model_size = len(tflite_float_model) / 1024
  print('Float model size = %dKB.' % float_model_size)

  # Re-convert the model to TF Lite using quantization.
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  tflite_quantized_model = converter.convert()

  # Show the model size in KB.
  quantized_model_size = len(tflite_quantized_model) / 1024
  print('Quantized model size = %dKB,' % quantized_model_size)
  print('which is about %d%% of the float model size.' %
        (quantized_model_size * 100 / float_model_size))

  # Evaluate the TF Lite float model. You'll find that its accuracy is
  # identical to the original Flax model's, because they are essentially the
  # same model stored in different formats.
  float_accuracy = evaluate_tflite_model(tflite_float_model, test_ds)
  print('Float model accuracy = %.4f' % float_accuracy)

  # Evaluate the TF Lite quantized model.
  # Don't be surprised if the quantized model's accuracy is higher than the
  # float model's. That happens sometimes :)
  quantized_accuracy = evaluate_tflite_model(tflite_quantized_model, test_ds)
  print('Quantized model accuracy = %.4f' % quantized_accuracy)
  print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))

  with open(FLAGS.tflite_file_path, 'wb') as f:
    f.write(tflite_quantized_model)
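
evaluate_tflite_model is a helper defined elsewhere in the file. A minimal sketch of such a function using tf.lite.Interpreter; the signature, and the assumption that every test batch matches the converter's fixed batch size, are mine:

def evaluate_tflite_model(tflite_model, test_ds):
  # Runs the TF Lite model over the test set and returns the accuracy.
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  input_index = interpreter.get_input_details()[0]['index']
  output_index = interpreter.get_output_details()[0]['index']

  correct = total = 0
  for images, labels in tfds.as_numpy(test_ds):
    interpreter.set_tensor(input_index, images.astype(np.float32))
    interpreter.invoke()
    predictions = interpreter.get_tensor(output_index)
    correct += np.sum(
        np.argmax(predictions, axis=1) == np.argmax(labels, axis=1))
    total += labels.shape[0]
  return correct / total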
Example #4
def train_and_save():
  logging.info("Loading the MNIST TensorFlow dataset")
  train_ds = mnist_lib.load_mnist(
      tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
  test_ds = mnist_lib.load_mnist(
      tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)

  if FLAGS.show_images:
    mnist_lib.plot_images(train_ds, 1, 5, "Training images", inference_fn=None)

  the_model_class = pick_model_class()
  model_dir = savedmodel_dir(with_version=True)

  if FLAGS.generate_model:
    model_descr = model_description()
    logging.info(f"Generating model for {model_descr}")
    (predict_fn, predict_params) = the_model_class.train(
        train_ds,
        test_ds,
        FLAGS.num_epochs,
        with_classifier=FLAGS.model_classifier_layer)

    input_signatures = [
        # The first one will be the serving signature
        tf.TensorSpec((FLAGS.serving_batch_size,) + mnist_lib.input_shape,
                      tf.float32),
        tf.TensorSpec((mnist_lib.train_batch_size,) + mnist_lib.input_shape,
                      tf.float32),
        tf.TensorSpec((mnist_lib.test_batch_size,) + mnist_lib.input_shape,
                      tf.float32),
    ]

    logging.info(f"Saving model for {model_descr}")
    saved_model_lib.convert_and_save_model(
        predict_fn,
        predict_params,
        model_dir,
        input_signatures=input_signatures,
        compile_model=FLAGS.compile_model)

    if FLAGS.test_savedmodel:
      tf_accelerator, tolerances = tf_accelerator_and_tolerances()
      with tf.device(tf_accelerator):
        logging.info("Testing savedmodel")
        pure_restored_model = tf.saved_model.load(model_dir)

        if FLAGS.show_images and FLAGS.model_classifier_layer:
          mnist_lib.plot_images(
              test_ds,
              1,
              5,
              f"Inference results for {model_descr}",
              inference_fn=pure_restored_model)

        test_input = np.ones(
            (mnist_lib.test_batch_size,) + mnist_lib.input_shape,
            dtype=np.float32)
        np.testing.assert_allclose(
            pure_restored_model(tf.convert_to_tensor(test_input)),
            predict_fn(predict_params, test_input), **tolerances)

  if FLAGS.show_model:
    def print_model(model_dir: str):
      cmd = f"saved_model_cli show --all --dir {model_dir}"
      print(cmd)
      os.system(cmd)

    print_model(model_dir)
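
After train_and_save() completes, the SavedModel can be restored and called the same way the test above does. A minimal sketch of a smoke test, reusing the helpers from this example; 'serving_default' is TensorFlow's default signature key, assumed to be present here, and the 'input' keyword is a placeholder for the exported input name:

def smoke_test_restored_model():
  model_dir = savedmodel_dir(with_version=True)
  restored = tf.saved_model.load(model_dir)
  batch = np.zeros((FLAGS.serving_batch_size,) + mnist_lib.input_shape,
                   dtype=np.float32)
  # Call the restored object directly, as in the assert_allclose check above...
  predictions = restored(tf.convert_to_tensor(batch))
  # ...or go through the exported serving signature; signature functions take
  # keyword arguments named after the exported inputs.
  serving_fn = restored.signatures['serving_default']
  print(serving_fn(input=tf.convert_to_tensor(batch)))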