def run_fn(fn_args: tfx.components.FnArgs):
    """Train the model based on given args.

    Args:
        fn_args: Holds args used to train the model as name/value pairs.
    """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    # Materialize train/eval datasets from the transformed examples.
    train_dataset = _input_fn(
        fn_args.train_files,
        fn_args.data_accessor,
        tf_transform_output,
        batch_size=_TRAIN_BATCH_SIZE)
    eval_dataset = _input_fn(
        fn_args.eval_files,
        fn_args.data_accessor,
        tf_transform_output,
        batch_size=_EVAL_BATCH_SIZE)

    # Build the model under a mirrored strategy so variables are replicated
    # across all available devices.
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = _build_keras_model()

    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        verbose=2)

    # Serving signature takes the raw int64 page/session features.
    signatures = {
        'serving_default':
            _get_inference_fn(model, tf_transform_output).get_concrete_function(
                tf.TensorSpec(
                    shape=[None], dtype=tf.int64, name=_CUR_PAGE_FEATURE_KEY),
                tf.TensorSpec(
                    shape=[None], dtype=tf.int64,
                    name=_SESSION_INDEX_FEATURE_KEY)),
    }

    # Create the saved_model in a temporary directory.
    temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
    model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

    # Convert the saved_model to a tfjs model and store it in the final
    # directory.
    tfjs_rewriter = rewriter_factory.create_rewriter(
        rewriter_factory.TFJS_REWRITER, name='tfjs_rewriter')
    converters.rewrite_saved_model(temp_saving_model_dir,
                                   fn_args.serving_model_dir, tfjs_rewriter,
                                   rewriter.ModelType.TFJS_MODEL)

    # Copy the vocabulary computed by transform to the final directory.
    # The vocabulary is not included in the original savedmodel because vocab
    # lookups are currently not supported in TFJS and are expected to be done
    # independently by client code.
    fileio.copy(
        tf_transform_output.vocabulary_file_by_name(_VOCAB_FILENAME),
        os.path.join(fn_args.serving_model_dir, _VOCAB_FILENAME))

    fileio.rmtree(temp_saving_model_dir)
def run_fn(fn_args: TrainerFnArgs):
    """Train the model based on given args.

    Args:
        fn_args: Holds args used to train the model as name/value pairs.
    """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = base.input_fn(fn_args.train_files, tf_transform_output, 40)
    eval_dataset = base.input_fn(fn_args.eval_files, tf_transform_output, 40)

    # Replicate the model across available devices.
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = base.build_keras_model()

    # NOTE(review): attribute access on TrainerFnArgs appears to raise
    # KeyError (not AttributeError) when the field is absent -- verify
    # against the TrainerFnArgs implementation before changing this.
    try:
        log_dir = fn_args.model_run_dir
    except KeyError:
        # TODO(b/158106209): use ModelRun instead of Model artifact for logging.
        log_dir = os.path.join(
            os.path.dirname(fn_args.serving_model_dir), 'logs')

    # Write logs to path
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir, update_freq='batch')

    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        callbacks=[tensorboard_callback])

    # Serving signature takes flattened 28x28=784 float images.
    signatures = {
        'serving_default':
            _get_serve_tf_examples_fn(
                model, tf_transform_output).get_concrete_function(
                    tf.TensorSpec(
                        shape=[None, 784],
                        dtype=tf.float32,
                        name='image_floats'))
    }

    # Save to a temporary directory; the rewriter emits the final TFLite
    # artifact into fn_args.serving_model_dir.
    temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
    model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

    tflite_rewriter = rewriter_factory.create_rewriter(
        rewriter_factory.TFLITE_REWRITER,
        name='tflite_rewriter',
        enable_experimental_new_converter=True)
    converters.rewrite_saved_model(temp_saving_model_dir,
                                   fn_args.serving_model_dir, tflite_rewriter,
                                   rewriter.ModelType.TFLITE_MODEL)

    tf.io.gfile.rmtree(temp_saving_model_dir)
def run_fn(fn_args: tfx.components.FnArgs):
    """Train the model based on given args.

    Args:
        fn_args: Holds args used to train the model as name/value pairs.
    """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = base.input_fn(fn_args.train_files, fn_args.data_accessor,
                                  tf_transform_output, 40)
    eval_dataset = base.input_fn(fn_args.eval_files, fn_args.data_accessor,
                                 tf_transform_output, 40)

    # Replicate the model across available devices.
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = base.build_keras_model()

    # Write logs to path
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=fn_args.model_run_dir, update_freq='batch')

    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        callbacks=[tensorboard_callback])

    # Serving signature takes flattened 28x28=784 float images.
    signatures = {
        'serving_default':
            _get_serve_tf_examples_fn(
                model, tf_transform_output).get_concrete_function(
                    tf.TensorSpec(
                        shape=[None, 784],
                        dtype=tf.float32,
                        name='image_floats'))
    }

    # Save to a temporary directory; the rewriter emits the final TFLite
    # artifact into fn_args.serving_model_dir.
    temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
    model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

    tflite_rewriter = rewriter_factory.create_rewriter(
        rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
    converters.rewrite_saved_model(temp_saving_model_dir,
                                   fn_args.serving_model_dir, tflite_rewriter,
                                   rewriter.ModelType.TFLITE_MODEL)

    tfx.dsl.io.fileio.rmtree(temp_saving_model_dir)
def run_fn(fn_args: TrainerFnArgs):
    """Train the model based on given args.

    Args:
        fn_args: Holds args used to train the model as name/value pairs.
    """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = base.input_fn(fn_args.train_files, tf_transform_output, 40)
    eval_dataset = base.input_fn(fn_args.eval_files, tf_transform_output, 40)

    # Replicate the model across available devices.
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = base.build_keras_model()

    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps)

    # Serving signature takes flattened 28x28=784 float images.
    signatures = {
        'serving_default':
            _get_serve_tf_examples_fn(
                model, tf_transform_output).get_concrete_function(
                    tf.TensorSpec(
                        shape=[None, 784],
                        dtype=tf.float32,
                        name='image_floats'))
    }

    # Save to a temporary directory; the rewriter emits the final TFLite
    # artifact into fn_args.serving_model_dir.
    temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
    model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

    tflite_rewriter = rewriter_factory.create_rewriter(
        rewriter_factory.TFLITE_REWRITER,
        name='tflite_rewriter',
        enable_experimental_new_converter=True)
    converters.rewrite_saved_model(temp_saving_model_dir,
                                   fn_args.serving_model_dir, tflite_rewriter,
                                   rewriter.ModelType.TFLITE_MODEL)

    tf.io.gfile.rmtree(temp_saving_model_dir)
def testRewriterFactorySuccessfullyCreatedTFJSRewriter(self):
    """The factory returns a TFJS rewriter with the requested name."""
    created = rewriter_factory.create_rewriter(
        rewriter_factory.TFJS_REWRITER, name='my_rewriter')
    self.assertTrue(created)
    # The instance's class name should match the factory constant.
    self.assertEqual(type(created).__name__, rewriter_factory.TFJS_REWRITER)
    self.assertEqual(created.name, 'my_rewriter')
def testRewriterFactorySuccessfullyCreated(self, rewriter_name):
    """The factory returns a rewriter of the parameterized type and name."""
    created = rewriter_factory.create_rewriter(rewriter_name,
                                               name='my_rewriter')
    self.assertTrue(created)
    # The instance's class name should match the requested rewriter type.
    self.assertEqual(type(created).__name__, rewriter_name)
    self.assertEqual(created.name, 'my_rewriter')
def run_fn(fn_args: FnArgs):
    """Train the model based on given args.

    Args:
        fn_args: Holds args used to train the model as name/value pairs.

    Raises:
        ValueError: if invalid inputs.
    """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = _input_fn(
        fn_args.train_files,
        fn_args.data_accessor,
        tf_transform_output,
        is_train=True,
        batch_size=_TRAIN_BATCH_SIZE)
    eval_dataset = _input_fn(
        fn_args.eval_files,
        fn_args.data_accessor,
        tf_transform_output,
        is_train=False,
        batch_size=_EVAL_BATCH_SIZE)

    model, base_model = _build_keras_model()

    absl.logging.info('Tensorboard logging to {}'.format(
        fn_args.model_run_dir))
    # Write logs to path
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=fn_args.model_run_dir, update_freq='batch')

    # Our training regime has two phases: we first freeze the backbone and
    # train the newly added classifier only, then unfreeze part of the
    # backbone and fine-tune with classifier jointly.
    steps_per_epoch = int(_TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE)
    total_epochs = int(fn_args.train_steps / steps_per_epoch)
    if _CLASSIFIER_EPOCHS > total_epochs:
        raise ValueError('Classifier epochs is greater than the total epochs')

    absl.logging.info('Start training the top classifier')
    model.fit(
        train_dataset,
        epochs=_CLASSIFIER_EPOCHS,
        steps_per_epoch=steps_per_epoch,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        callbacks=[tensorboard_callback])

    absl.logging.info('Start fine-tuning the model')
    # Unfreeze the top MobileNet layers and do joint fine-tuning
    _freeze_model_by_percentage(base_model, 0.9)

    # We need to recompile the model because layer properties have changed
    model.compile(
        loss='sparse_categorical_crossentropy',
        # Fix: `lr` is a deprecated alias in tf.keras optimizers; use the
        # canonical `learning_rate` keyword instead.
        optimizer=tf.keras.optimizers.RMSprop(
            learning_rate=_FINETUNE_LEARNING_RATE),
        metrics=['sparse_categorical_accuracy'])
    model.summary(print_fn=absl.logging.info)

    model.fit(
        train_dataset,
        initial_epoch=_CLASSIFIER_EPOCHS,
        epochs=total_epochs,
        steps_per_epoch=steps_per_epoch,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        callbacks=[tensorboard_callback])

    # Prepare the TFLite model used for serving in MLKit
    signatures = {
        'serving_default':
            _get_serve_image_fn(model).get_concrete_function(
                tf.TensorSpec(
                    shape=[None, 224, 224, 3],
                    dtype=tf.float32,
                    name=_transformed_name(_IMAGE_KEY)))
    }

    # Save to a temporary directory; the rewriter emits the final TFLite
    # artifact into fn_args.serving_model_dir.
    temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
    model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

    tfrw = rewriter_factory.create_rewriter(
        rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
    converters.rewrite_saved_model(temp_saving_model_dir,
                                   fn_args.serving_model_dir, tfrw,
                                   rewriter.ModelType.TFLITE_MODEL)

    # Add necessary TFLite metadata to the model in order to use it within
    # MLKit.
    # TODO(dzats@): Handle label map file path more properly, currently
    # hard-coded.
    tflite_model_path = os.path.join(fn_args.serving_model_dir,
                                     _TFLITE_MODEL_NAME)
    # TODO(dzats@): Extend the TFLite rewriter to be able to add TFLite
    # metadata to the model.
    _write_metadata(
        model_path=tflite_model_path,
        label_map_path=fn_args.custom_config['labels_path'],
        mean=[127.5],
        std=[127.5])

    fileio.rmtree(temp_saving_model_dir)
def testRewriterSuccessfullyCreatedTFLiteRewriter(self):
    """The factory returns a TFLite rewriter carrying the requested name."""
    created = rewriter_factory.create_rewriter(
        rewriter_factory.TFLITE_REWRITER, name='my_rewriter')
    self.assertTrue(created)
    self.assertEqual(created.name, 'my_rewriter')