def test_train_and_save_full(self, model="mnist_flax", serving_batch_size=-1):
  """Trains a model and saves it including its classifier layer."""
  FLAGS.model = model
  FLAGS.model_classifier_layer = True
  FLAGS.serving_batch_size = serving_batch_size
  saved_model_main.train_and_save()
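# A minimal sketch (not from the original file) of consuming the full
# SavedModel that the test above produces. The directory helper is the one
# used elsewhere in this file; calling the restored object directly on a
# batch of images is an assumption about the export's serving interface.
#
#   loaded = tf.saved_model.load(saved_model_main.savedmodel_dir())
#   dummy_batch = tf.zeros((1,) + mnist_lib.input_shape)
#   predictions = loaded(dummy_batch)  # expected: (1, 10) class probabilities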
def main(_):
  FLAGS.model_classifier_layer = False  # We only need the features.

  # Train the model and save the feature extractor.
  saved_model_main.train_and_save()
  tf_accelerator, _ = saved_model_main.tf_accelerator_and_tolerances()
  feature_model_dir = saved_model_main.savedmodel_dir()

  # With Keras, we use tf.distribute.OneDeviceStrategy as the high-level
  # analogue of the tf.device(...) placement seen above. It works on CPU,
  # GPU, and TPU. Actual high-performance training would use the
  # appropriately replicated TF distribution strategy.
  strategy = tf.distribute.OneDeviceStrategy(tf_accelerator)
  with strategy.scope():
    images = tf.keras.layers.Input(mnist_lib.input_shape,
                                   batch_size=mnist_lib.train_batch_size)
    # We do not yet train the SavedModel, due to b/123499169.
    keras_feature_extractor = hub.KerasLayer(feature_model_dir,
                                             trainable=False)
    features = keras_feature_extractor(images)
    predictor = tf.keras.layers.Dense(10, activation="softmax")
    predictions = predictor(features)
    keras_model = tf.keras.Model(images, predictions)
    keras_model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
        metrics=["accuracy"])
  # Model.summary() prints and returns None, so route its output through
  # logging instead of logging the None return value.
  keras_model.summary(print_fn=logging.info)

  train_ds = mnist_lib.load_mnist(tfds.Split.TRAIN,
                                  batch_size=mnist_lib.train_batch_size)
  test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
                                 batch_size=mnist_lib.test_batch_size)
  keras_model.fit(train_ds, epochs=FLAGS.num_epochs, validation_data=test_ds)

  if FLAGS.show_images:
    mnist_lib.plot_images(
        test_ds, 1, 5,
        f"Keras inference with reuse of {saved_model_main.model_description()}",
        inference_fn=lambda images: keras_model(tf.convert_to_tensor(images)))
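# A sketch of how this example would typically be launched, assuming the
# standard absl.app pattern implied by the FLAGS usage and the main(_)
# signature above (the module footer itself is not part of this section):
#
#   from absl import app
#
#   if __name__ == "__main__":
#     app.run(main)
#
# and then, with flag names taken from this file (script name hypothetical):
#   python keras_reuse_example.py --model=mnist_flax --num_epochs=2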
def test_train_and_save_features(self, model="mnist_flax"):
  """Trains a model and saves only its feature-extractor part."""
  FLAGS.model = model
  FLAGS.model_classifier_layer = False
  saved_model_main.train_and_save()
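# A hypothetical harness for the two test methods above, assuming they belong
# to an absl parameterized.TestCase (the class name and the driver below are
# illustrative assumptions, not taken from this file):
#
#   from absl.testing import absltest
#   from absl.testing import parameterized
#
#   class SavedModelMainTest(parameterized.TestCase):
#     ...  # test_train_and_save_full / test_train_and_save_features above
#
#   if __name__ == "__main__":
#     absltest.main()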