Example 1
0
 def test_train_and_save_full(self,
                              model="mnist_flax",
                              serving_batch_size=-1):
     """Runs saved_model_main.train_and_save() with the classifier layer on.

     Args:
       model: value assigned to FLAGS.model before training.
       serving_batch_size: value assigned to FLAGS.serving_batch_size.
     """
     flag_overrides = dict(
         model=model,
         model_classifier_layer=True,
         serving_batch_size=serving_batch_size,
     )
     # Applied in the same order as listed above.
     for flag_name, flag_value in flag_overrides.items():
         setattr(FLAGS, flag_name, flag_value)
     saved_model_main.train_and_save()
Example 2
0
def main(_):
    """Trains a Keras classifier on top of a reused SavedModel feature extractor.

    Steps:
      1. Run ``saved_model_main.train_and_save()`` with
         ``FLAGS.model_classifier_layer = False`` and locate the result via
         ``saved_model_main.savedmodel_dir()``.
      2. Load that SavedModel as a non-trainable ``hub.KerasLayer`` and stack
         a 10-way softmax ``Dense`` layer on top of it.
      3. Fit the Keras model on the MNIST train split for ``FLAGS.num_epochs``
         epochs, validating on the MNIST test split.
      4. If ``FLAGS.show_images`` is set, plot Keras predictions on a few
         test images.

    Args:
      _: Unused (presumably the argv list passed by absl's ``app.run``).
    """
    FLAGS.model_classifier_layer = False  # We only need the features
    # Train the model and save the feature extractor
    saved_model_main.train_and_save()

    tf_accelerator, _ = saved_model_main.tf_accelerator_and_tolerances()
    feature_model_dir = saved_model_main.savedmodel_dir()

    # With Keras, tf.distribute.OneDeviceStrategy is the high-level analogue
    # of explicit tf.device(...) placement: variables created in its scope are
    # placed on `tf_accelerator`. It works on CPU, GPU and TPU.
    # Actual high-performance training would use the appropriately replicated
    # TF Distribution Strategy.
    strategy = tf.distribute.OneDeviceStrategy(tf_accelerator)
    with strategy.scope():
        images = tf.keras.layers.Input(mnist_lib.input_shape,
                                       batch_size=mnist_lib.train_batch_size)
        # We do not yet train the SavedModel, due to b/123499169.
        keras_feature_extractor = hub.KerasLayer(feature_model_dir,
                                                 trainable=False)
        features = keras_feature_extractor(images)
        predictor = tf.keras.layers.Dense(10, activation="softmax")
        predictions = predictor(features)
        keras_model = tf.keras.Model(images, predictions)

    keras_model.compile(loss=tf.keras.losses.categorical_crossentropy,
                        optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
                        metrics=["accuracy"])

    # Model.summary() prints to stdout and returns None, so passing its result
    # to logging.info() would only log "None". Capture the summary lines via
    # print_fn (tolerating any extra arguments Keras may pass) and log them.
    summary_lines = []
    keras_model.summary(
        print_fn=lambda line, *_args, **_kwargs: summary_lines.append(line))
    logging.info("Keras model summary:\n%s", "\n".join(summary_lines))

    train_ds = mnist_lib.load_mnist(tfds.Split.TRAIN,
                                    batch_size=mnist_lib.train_batch_size)
    test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
                                   batch_size=mnist_lib.test_batch_size)
    keras_model.fit(train_ds, epochs=FLAGS.num_epochs, validation_data=test_ds)

    if FLAGS.show_images:
        mnist_lib.plot_images(
            test_ds,
            1,
            5,
            f"Keras inference with reuse of {saved_model_main.model_description()}",
            inference_fn=lambda images: keras_model(
                tf.convert_to_tensor(images)))
Example 3
0
 def test_train_and_save_features(self, model="mnist_flax"):
     """Runs saved_model_main.train_and_save() with the classifier layer off.

     Args:
       model: value assigned to FLAGS.model before training.
     """
     # Set the model first, then disable the classifier layer so that only
     # the feature part is trained and saved.
     for flag_name, flag_value in (("model", model),
                                   ("model_classifier_layer", False)):
         setattr(FLAGS, flag_name, flag_value)
     saved_model_main.train_and_save()