def run_fn(fn_args: tfx.components.FnArgs):
    """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = _input_fn(fn_args.train_files,
                              fn_args.data_accessor,
                              tf_transform_output,
                              batch_size=_TRAIN_BATCH_SIZE)

    eval_dataset = _input_fn(fn_args.eval_files,
                             fn_args.data_accessor,
                             tf_transform_output,
                             batch_size=_EVAL_BATCH_SIZE)

    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = _build_keras_model()

    model.fit(train_dataset,
              steps_per_epoch=fn_args.train_steps,
              validation_data=eval_dataset,
              validation_steps=fn_args.eval_steps,
              verbose=2)

    signatures = {
        'serving_default':
        _get_inference_fn(model, tf_transform_output).get_concrete_function(
            tf.TensorSpec(shape=[None],
                          dtype=tf.int64,
                          name=_CUR_PAGE_FEATURE_KEY),
            tf.TensorSpec(shape=[None],
                          dtype=tf.int64,
                          name=_SESSION_INDEX_FEATURE_KEY)),
    }

    # Create the saved_model in a temporary directory.
    temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
    model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

    # Convert the saved_model to a tfjs model and store it in the final directory.
    tfrw = rewriter_factory.create_rewriter(rewriter_factory.TFJS_REWRITER,
                                            name='tfjs_rewriter')
    converters.rewrite_saved_model(temp_saving_model_dir,
                                   fn_args.serving_model_dir, tfrw,
                                   rewriter.ModelType.TFJS_MODEL)

    # Copy the vocabulary computed by transform to the final directory.
    # The vocabulary is not included in the original savedmodel because vocab
    # lookups are currently not supported in TFJS and are expected to be done
    # independently by client code.
    fileio.copy(tf_transform_output.vocabulary_file_by_name(_VOCAB_FILENAME),
                os.path.join(fn_args.serving_model_dir, _VOCAB_FILENAME))

    fileio.rmtree(temp_saving_model_dir)
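
# `_get_inference_fn` is referenced above but not shown. A minimal sketch of
# what such a helper could look like, assuming the raw int64 features are run
# through the Transform graph via transform_features_layer() before calling
# the model (this sketch is an assumption, not the original implementation):
def _get_inference_fn(model, tf_transform_output):
    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def inference_fn(cur_page, session_index):
        """Applies the Transform graph to raw features, then runs the model."""
        raw_features = {
            _CUR_PAGE_FEATURE_KEY: cur_page,
            _SESSION_INDEX_FEATURE_KEY: session_index,
        }
        transformed_features = model.tft_layer(raw_features)
        return model(transformed_features)

    return inference_fn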
Example #2
def run_fn(fn_args: TrainerFnArgs):
    """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = base.input_fn(fn_args.train_files, tf_transform_output, 40)
    eval_dataset = base.input_fn(fn_args.eval_files, tf_transform_output, 40)

    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = base.build_keras_model()

    try:
        log_dir = fn_args.model_run_dir
    except KeyError:
        # TODO(b/158106209): use ModelRun instead of Model artifact for logging.
        log_dir = os.path.join(os.path.dirname(fn_args.serving_model_dir),
                               'logs')

    # Write logs to path
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                          update_freq='batch')

    model.fit(train_dataset,
              steps_per_epoch=fn_args.train_steps,
              validation_data=eval_dataset,
              validation_steps=fn_args.eval_steps,
              callbacks=[tensorboard_callback])

    signatures = {
        'serving_default':
        _get_serve_tf_examples_fn(model,
                                  tf_transform_output).get_concrete_function(
                                      tf.TensorSpec(shape=[None, 784],
                                                    dtype=tf.float32,
                                                    name='image_floats'))
    }
    temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
    model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

    tfrw = rewriter_factory.create_rewriter(
        rewriter_factory.TFLITE_REWRITER,
        name='tflite_rewriter',
        enable_experimental_new_converter=True)
    converters.rewrite_saved_model(temp_saving_model_dir,
                                   fn_args.serving_model_dir, tfrw,
                                   rewriter.ModelType.TFLITE_MODEL)

    tf.io.gfile.rmtree(temp_saving_model_dir)
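
# `_get_serve_tf_examples_fn` is defined elsewhere in this module. Given the
# [None, 784] float signature above, a plausible sketch applies the Transform
# graph to the raw pixel values and then invokes the model (the 'image_floats'
# key comes from the signature; the rest of this sketch is an assumption):
def _get_serve_tf_examples_fn(model, tf_transform_output):
    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def serve_tf_examples_fn(image_floats):
        """Transforms raw image floats and returns model predictions."""
        transformed_features = model.tft_layer({'image_floats': image_floats})
        return model(transformed_features)

    return serve_tf_examples_fn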
Example #3
def run_fn(fn_args: tfx.components.FnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

  train_dataset = base.input_fn(fn_args.train_files, fn_args.data_accessor,
                                tf_transform_output, 40)
  eval_dataset = base.input_fn(fn_args.eval_files, fn_args.data_accessor,
                               tf_transform_output, 40)

  mirrored_strategy = tf.distribute.MirroredStrategy()
  with mirrored_strategy.scope():
    model = base.build_keras_model()

  # Write logs to path
  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=fn_args.model_run_dir, update_freq='batch')

  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps,
      callbacks=[tensorboard_callback])

  signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(
              model, tf_transform_output).get_concrete_function(
                  tf.TensorSpec(
                      shape=[None, 784],
                      dtype=tf.float32,
                      name='image_floats'))
  }
  temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
  model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

  tfrw = rewriter_factory.create_rewriter(
      rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
  converters.rewrite_saved_model(temp_saving_model_dir,
                                 fn_args.serving_model_dir,
                                 tfrw,
                                 rewriter.ModelType.TFLITE_MODEL)

  tfx.dsl.io.fileio.rmtree(temp_saving_model_dir)
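
# `base.input_fn` is defined in a shared module that is not part of this
# snippet. With a DataAccessor, as used here, the usual TFX pattern looks
# roughly like the sketch below; the tfx_bsl import and `_LABEL_KEY` name are
# assumptions, not the original code:
from tfx_bsl.public import tfxio


def input_fn(file_pattern, data_accessor, tf_transform_output, batch_size):
  """Builds a batched tf.data.Dataset of transformed features and labels."""
  return data_accessor.tf_dataset_factory(
      file_pattern,
      tfxio.TensorFlowDatasetOptions(
          batch_size=batch_size, label_key=_LABEL_KEY),
      tf_transform_output.transformed_metadata.schema).repeat()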
Example #4
def run_fn(fn_args: TrainerFnArgs):
    """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = base.input_fn(fn_args.train_files, tf_transform_output, 40)
    eval_dataset = base.input_fn(fn_args.eval_files, tf_transform_output, 40)

    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = base.build_keras_model()

    model.fit(train_dataset,
              steps_per_epoch=fn_args.train_steps,
              validation_data=eval_dataset,
              validation_steps=fn_args.eval_steps)

    signatures = {
        'serving_default':
        _get_serve_tf_examples_fn(model,
                                  tf_transform_output).get_concrete_function(
                                      tf.TensorSpec(shape=[None, 784],
                                                    dtype=tf.float32,
                                                    name='image_floats'))
    }
    temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
    model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

    tfrw = rewriter_factory.create_rewriter(
        rewriter_factory.TFLITE_REWRITER,
        name='tflite_rewriter',
        enable_experimental_new_converter=True)
    converters.rewrite_saved_model(temp_saving_model_dir,
                                   fn_args.serving_model_dir, tfrw,
                                   rewriter.ModelType.TFLITE_MODEL)

    tf.io.gfile.rmtree(temp_saving_model_dir)
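
# The rewrite above leaves a TensorFlow Lite flatbuffer in
# fn_args.serving_model_dir. A hedged smoke test using the TFLite interpreter;
# the 'tflite' file name is an assumed default output name of the rewriter,
# not something stated in this code:
def _smoke_test_tflite(serving_model_dir):
    interpreter = tf.lite.Interpreter(
        model_path=os.path.join(serving_model_dir, 'tflite'))
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    # Feed one all-zeros example matching the [None, 784] float signature.
    interpreter.set_tensor(input_details[0]['index'],
                           tf.zeros([1, 784], dtype=tf.float32).numpy())
    interpreter.invoke()
    output_details = interpreter.get_output_details()
    return interpreter.get_tensor(output_details[0]['index'])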
Example #5
def testRewriterFactorySuccessfullyCreatedTFJSRewriter(self):
  tfrw = rewriter_factory.create_rewriter(rewriter_factory.TFJS_REWRITER,
                                          name='my_rewriter')
  self.assertTrue(tfrw)
  self.assertEqual(type(tfrw).__name__, rewriter_factory.TFJS_REWRITER)
  self.assertEqual(tfrw.name, 'my_rewriter')
Example #6
def testRewriterFactorySuccessfullyCreated(self, rewriter_name):
  tfrw = rewriter_factory.create_rewriter(rewriter_name, name='my_rewriter')
  self.assertTrue(tfrw)
  self.assertEqual(type(tfrw).__name__, rewriter_name)
  self.assertEqual(tfrw.name, 'my_rewriter')
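
# The test above receives `rewriter_name` as an argument, so the original
# suite presumably parameterizes it. A self-contained sketch of how such a
# test can be driven with absl.testing.parameterized (assumes the same
# rewriter_factory import as the snippets above; the class name and case
# names are illustrative assumptions):
from absl.testing import parameterized
import tensorflow as tf


class RewriterFactoryTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      ('TFLite', rewriter_factory.TFLITE_REWRITER),
      ('TFJS', rewriter_factory.TFJS_REWRITER))
  def testRewriterFactorySuccessfullyCreated(self, rewriter_name):
    tfrw = rewriter_factory.create_rewriter(rewriter_name, name='my_rewriter')
    self.assertTrue(tfrw)
    self.assertEqual(type(tfrw).__name__, rewriter_name)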
Example #7
def run_fn(fn_args: FnArgs):
    """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.

  Raises:
    ValueError: if invalid inputs.
  """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = _input_fn(fn_args.train_files,
                              fn_args.data_accessor,
                              tf_transform_output,
                              is_train=True,
                              batch_size=_TRAIN_BATCH_SIZE)
    eval_dataset = _input_fn(fn_args.eval_files,
                             fn_args.data_accessor,
                             tf_transform_output,
                             is_train=False,
                             batch_size=_EVAL_BATCH_SIZE)

    model, base_model = _build_keras_model()

    absl.logging.info('TensorBoard logging to {}'.format(
        fn_args.model_run_dir))
    # Write logs to path
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=fn_args.model_run_dir, update_freq='batch')

    # Our training regime has two phases: we first freeze the backbone and train
    # only the newly added classifier head, then unfreeze part of the backbone
    # and fine-tune it jointly with the classifier.
    steps_per_epoch = int(_TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE)
    total_epochs = int(fn_args.train_steps / steps_per_epoch)
    if _CLASSIFIER_EPOCHS > total_epochs:
        raise ValueError('Classifier epochs exceed the total training epochs')

    absl.logging.info('Start training the top classifier')
    model.fit(train_dataset,
              epochs=_CLASSIFIER_EPOCHS,
              steps_per_epoch=steps_per_epoch,
              validation_data=eval_dataset,
              validation_steps=fn_args.eval_steps,
              callbacks=[tensorboard_callback])

    absl.logging.info('Start fine-tuning the model')
    # Unfreeze the top MobileNet layers and do joint fine-tuning
    _freeze_model_by_percentage(base_model, 0.9)

    # We need to recompile the model because layer properties have changed
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=tf.keras.optimizers.RMSprop(
            learning_rate=_FINETUNE_LEARNING_RATE),
        metrics=['sparse_categorical_accuracy'])
    model.summary(print_fn=absl.logging.info)

    model.fit(train_dataset,
              initial_epoch=_CLASSIFIER_EPOCHS,
              epochs=total_epochs,
              steps_per_epoch=steps_per_epoch,
              validation_data=eval_dataset,
              validation_steps=fn_args.eval_steps,
              callbacks=[tensorboard_callback])

    # Prepare the TFLite model used for serving in MLKit
    signatures = {
        'serving_default':
        _get_serve_image_fn(model).get_concrete_function(
            tf.TensorSpec(shape=[None, 224, 224, 3],
                          dtype=tf.float32,
                          name=_transformed_name(_IMAGE_KEY)))
    }

    temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
    model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

    tfrw = rewriter_factory.create_rewriter(rewriter_factory.TFLITE_REWRITER,
                                            name='tflite_rewriter')
    converters.rewrite_saved_model(temp_saving_model_dir,
                                   fn_args.serving_model_dir, tfrw,
                                   rewriter.ModelType.TFLITE_MODEL)

    # Add necessary TFLite metadata to the model in order to use it within MLKit
    # TODO(dzats@): Handle label map file path more properly, currently
    # hard-coded.
    tflite_model_path = os.path.join(fn_args.serving_model_dir,
                                     _TFLITE_MODEL_NAME)
    # TODO(dzats@): Extend the TFLite rewriter to be able to add TFLite metadata
    # to the model.
    _write_metadata(model_path=tflite_model_path,
                    label_map_path=fn_args.custom_config['labels_path'],
                    mean=[127.5],
                    std=[127.5])

    fileio.rmtree(temp_saving_model_dir)
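
# `_freeze_model_by_percentage(base_model, 0.9)` above freezes most of the
# backbone before fine-tuning. A minimal sketch of such a helper, assuming the
# percentage is the fraction of layers (counted from the bottom) that stays
# frozen; the real implementation is not shown in this example:
def _freeze_model_by_percentage(model, percentage):
    """Freezes the bottom `percentage` fraction of the model's layers."""
    num_frozen = int(len(model.layers) * percentage)
    for layer in model.layers[:num_frozen]:
        layer.trainable = False
    for layer in model.layers[num_frozen:]:
        layer.trainable = True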
Example #8
def testRewriterSuccessfullyCreatedTFLiteRewriter(self):
    tfrw = rewriter_factory.create_rewriter(
        rewriter_factory.TFLITE_REWRITER, name='my_rewriter')
    self.assertTrue(tfrw)
    self.assertEqual(tfrw.name, 'my_rewriter')