예제 #1
0
def run_fn(fn_args: TrainerFnArgs):
    """Train the Keras model and export it as a SavedModel.

    The model returned by `base.build_keras_model()` is created inside a
    `tf.distribute.MirroredStrategy` scope, fitted on the training dataset
    with the evaluation dataset used for validation, and then saved to
    `fn_args.serving_model_dir` with a `serving_default` signature.

    Args:
      fn_args: Holds args used to train the model as name/value pairs. The
        fields read here are `transform_output`, `train_files`, `eval_files`,
        `train_steps`, `eval_steps` and `serving_model_dir`.
    """
    transform_output = tft.TFTransformOutput(fn_args.transform_output)

    # The literal 40 is handed to base.input_fn for both splits; presumably it
    # is the batch size -- TODO confirm against base.input_fn.
    train_ds = base.input_fn(fn_args.train_files, transform_output, 40)
    eval_ds = base.input_fn(fn_args.eval_files, transform_output, 40)

    # Only model construction happens inside the strategy scope.
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        keras_model = base.build_keras_model()

    fit_options = {
        'steps_per_epoch': fn_args.train_steps,
        'validation_data': eval_ds,
        'validation_steps': fn_args.eval_steps,
    }
    keras_model.fit(train_ds, **fit_options)

    # The serving signature takes a 1-D string tensor named 'examples'.
    serve_fn = _get_serve_tf_examples_fn(keras_model, transform_output)
    examples_spec = tf.TensorSpec(
        shape=[None],
        dtype=tf.string,
        name='examples',
    )
    serving_default = serve_fn.get_concrete_function(examples_spec)

    keras_model.save(
        fn_args.serving_model_dir,
        save_format='tf',
        signatures={'serving_default': serving_default},
    )
예제 #2
0
def run_fn(fn_args: TrainerFnArgs):
    """Train the Keras model and export it as a TFLite model.

    The model returned by `base.build_keras_model()` is created inside a
    `tf.distribute.MirroredStrategy` scope and fitted with a TensorBoard
    callback attached. It is first saved as a SavedModel into a `temp`
    subdirectory of `fn_args.serving_model_dir`, then rewritten into
    `fn_args.serving_model_dir` with the TFLite rewriter, after which the
    `temp` subdirectory is deleted.

    Args:
      fn_args: Holds args used to train the model as name/value pairs. The
        fields read here are `transform_output`, `train_files`, `eval_files`,
        `model_run_dir`, `train_steps`, `eval_steps` and `serving_model_dir`.
    """
    transform_output = tft.TFTransformOutput(fn_args.transform_output)

    # The literal 40 is handed to base.input_fn for both splits; presumably it
    # is the batch size -- TODO confirm against base.input_fn.
    train_ds = base.input_fn(fn_args.train_files, transform_output, 40)
    eval_ds = base.input_fn(fn_args.eval_files, transform_output, 40)

    # Only model construction happens inside the strategy scope.
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        keras_model = base.build_keras_model()

    # Use the model-run directory for logs; when reading it raises KeyError,
    # fall back to a 'logs' directory under the parent of the serving dir.
    try:
        tensorboard_dir = fn_args.model_run_dir
    except KeyError:
        # TODO(b/158106209): use ModelRun instead of Model artifact for logging.
        serving_parent = os.path.dirname(fn_args.serving_model_dir)
        tensorboard_dir = os.path.join(serving_parent, 'logs')

    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=tensorboard_dir,
        update_freq='batch',
    )

    fit_options = {
        'steps_per_epoch': fn_args.train_steps,
        'validation_data': eval_ds,
        'validation_steps': fn_args.eval_steps,
        'callbacks': [tensorboard],
    }
    keras_model.fit(train_ds, **fit_options)

    # The serving signature takes a float32 tensor of shape [None, 784]
    # named 'image_floats'.
    serve_fn = _get_serve_tf_examples_fn(keras_model, transform_output)
    image_spec = tf.TensorSpec(
        shape=[None, 784],
        dtype=tf.float32,
        name='image_floats',
    )
    serving_default = serve_fn.get_concrete_function(image_spec)

    # Stage the SavedModel in a scratch subdirectory so the rewriter can read
    # it and write its output into the serving directory itself.
    scratch_dir = os.path.join(fn_args.serving_model_dir, 'temp')
    keras_model.save(
        scratch_dir,
        save_format='tf',
        signatures={'serving_default': serving_default},
    )

    tflite_rewriter = rewriter_factory.create_rewriter(
        rewriter_factory.TFLITE_REWRITER,
        name='tflite_rewriter',
        enable_experimental_new_converter=True,
    )
    converters.rewrite_saved_model(
        scratch_dir,
        fn_args.serving_model_dir,
        tflite_rewriter,
        rewriter.ModelType.TFLITE_MODEL,
    )

    # The staged SavedModel is no longer needed once the rewrite has finished.
    tf.io.gfile.rmtree(scratch_dir)
예제 #3
0
def run_fn(fn_args: tfx.components.FnArgs):
  """Train the Keras model and export it as a TFLite model.

  The model returned by `base.build_keras_model()` is created inside a
  `tf.distribute.MirroredStrategy` scope and fitted with a TensorBoard
  callback that logs to `fn_args.model_run_dir`. It is first saved as a
  SavedModel into a `temp` subdirectory of `fn_args.serving_model_dir`, then
  rewritten into `fn_args.serving_model_dir` with the TFLite rewriter, after
  which the `temp` subdirectory is deleted.

  Args:
    fn_args: Holds args used to train the model as name/value pairs. The
      fields read here are `transform_output`, `train_files`, `eval_files`,
      `data_accessor`, `model_run_dir`, `train_steps`, `eval_steps` and
      `serving_model_dir`.
  """
  transform_output = tft.TFTransformOutput(fn_args.transform_output)

  # The literal 40 is handed to base.input_fn for both splits; presumably it
  # is the batch size -- TODO confirm against base.input_fn.
  train_ds = base.input_fn(
      fn_args.train_files, fn_args.data_accessor, transform_output, 40)
  eval_ds = base.input_fn(
      fn_args.eval_files, fn_args.data_accessor, transform_output, 40)

  # Only model construction happens inside the strategy scope.
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    keras_model = base.build_keras_model()

  tensorboard = tf.keras.callbacks.TensorBoard(
      log_dir=fn_args.model_run_dir,
      update_freq='batch',
  )

  fit_options = {
      'steps_per_epoch': fn_args.train_steps,
      'validation_data': eval_ds,
      'validation_steps': fn_args.eval_steps,
      'callbacks': [tensorboard],
  }
  keras_model.fit(train_ds, **fit_options)

  # The serving signature takes a float32 tensor of shape [None, 784]
  # named 'image_floats'.
  serve_fn = _get_serve_tf_examples_fn(keras_model, transform_output)
  image_spec = tf.TensorSpec(
      shape=[None, 784],
      dtype=tf.float32,
      name='image_floats',
  )
  serving_default = serve_fn.get_concrete_function(image_spec)

  # Stage the SavedModel in a scratch subdirectory so the rewriter can read
  # it and write its output into the serving directory itself.
  scratch_dir = os.path.join(fn_args.serving_model_dir, 'temp')
  keras_model.save(
      scratch_dir,
      save_format='tf',
      signatures={'serving_default': serving_default},
  )

  tflite_rewriter = rewriter_factory.create_rewriter(
      rewriter_factory.TFLITE_REWRITER,
      name='tflite_rewriter',
  )
  converters.rewrite_saved_model(
      scratch_dir,
      fn_args.serving_model_dir,
      tflite_rewriter,
      rewriter.ModelType.TFLITE_MODEL,
  )

  # The staged SavedModel is no longer needed once the rewrite has finished.
  tfx.dsl.io.fileio.rmtree(scratch_dir)
예제 #4
0
def run_fn(fn_args: TrainerFnArgs):
    """Train the Keras model with TensorBoard logging and export it.

    The model returned by `base.build_keras_model()` is created inside a
    `tf.distribute.MirroredStrategy` scope, fitted on the training dataset
    with the evaluation dataset used for validation and a TensorBoard
    callback attached, and then saved to `fn_args.serving_model_dir` with a
    `serving_default` signature.

    Args:
      fn_args: Holds args used to train the model as name/value pairs. The
        fields read here are `transform_output`, `train_files`, `eval_files`,
        `model_run_dir`, `train_steps`, `eval_steps` and `serving_model_dir`.
    """
    transform_output = tft.TFTransformOutput(fn_args.transform_output)

    # The literal 40 is handed to base.input_fn for both splits; presumably it
    # is the batch size -- TODO confirm against base.input_fn.
    train_ds = base.input_fn(fn_args.train_files, transform_output, 40)
    eval_ds = base.input_fn(fn_args.eval_files, transform_output, 40)

    # Only model construction happens inside the strategy scope.
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        keras_model = base.build_keras_model()

    # Use the model-run directory for logs; when reading it raises KeyError,
    # fall back to a 'logs' directory under the parent of the serving dir.
    try:
        tensorboard_dir = fn_args.model_run_dir
    except KeyError:
        # TODO(b/158106209): use ModelRun instead of Model artifact for logging.
        serving_parent = os.path.dirname(fn_args.serving_model_dir)
        tensorboard_dir = os.path.join(serving_parent, 'logs')

    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=tensorboard_dir,
        update_freq='batch',
    )

    fit_options = {
        'steps_per_epoch': fn_args.train_steps,
        'validation_data': eval_ds,
        'validation_steps': fn_args.eval_steps,
        'callbacks': [tensorboard],
    }
    keras_model.fit(train_ds, **fit_options)

    # The serving signature takes a 1-D string tensor named 'examples'.
    serve_fn = _get_serve_tf_examples_fn(keras_model, transform_output)
    examples_spec = tf.TensorSpec(
        shape=[None],
        dtype=tf.string,
        name='examples',
    )
    serving_default = serve_fn.get_concrete_function(examples_spec)

    keras_model.save(
        fn_args.serving_model_dir,
        save_format='tf',
        signatures={'serving_default': serving_default},
    )
def run_fn(fn_args: TrainerFnArgs):
    """Train the Keras model and export it as a TFLite model.

    The model returned by `base.build_keras_model()` is created inside a
    `tf.distribute.MirroredStrategy` scope and fitted on the training
    dataset with the evaluation dataset used for validation. It is first
    saved as a SavedModel into a `temp` subdirectory of
    `fn_args.serving_model_dir`, then rewritten into
    `fn_args.serving_model_dir` with the TFLite rewriter, after which the
    `temp` subdirectory is deleted.

    Args:
      fn_args: Holds args used to train the model as name/value pairs. The
        fields read here are `transform_output`, `train_files`, `eval_files`,
        `train_steps`, `eval_steps` and `serving_model_dir`.
    """
    transform_output = tft.TFTransformOutput(fn_args.transform_output)

    # The literal 40 is handed to base.input_fn for both splits; presumably it
    # is the batch size -- TODO confirm against base.input_fn.
    train_ds = base.input_fn(fn_args.train_files, transform_output, 40)
    eval_ds = base.input_fn(fn_args.eval_files, transform_output, 40)

    # Only model construction happens inside the strategy scope.
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        keras_model = base.build_keras_model()

    fit_options = {
        'steps_per_epoch': fn_args.train_steps,
        'validation_data': eval_ds,
        'validation_steps': fn_args.eval_steps,
    }
    keras_model.fit(train_ds, **fit_options)

    # The serving signature takes a float32 tensor of shape [None, 784]
    # named 'image_floats'.
    serve_fn = _get_serve_tf_examples_fn(keras_model, transform_output)
    image_spec = tf.TensorSpec(
        shape=[None, 784],
        dtype=tf.float32,
        name='image_floats',
    )
    serving_default = serve_fn.get_concrete_function(image_spec)

    # Stage the SavedModel in a scratch subdirectory so the rewriter can read
    # it and write its output into the serving directory itself.
    scratch_dir = os.path.join(fn_args.serving_model_dir, 'temp')
    keras_model.save(
        scratch_dir,
        save_format='tf',
        signatures={'serving_default': serving_default},
    )

    tflite_rewriter = rewriter_factory.create_rewriter(
        rewriter_factory.TFLITE_REWRITER,
        name='tflite_rewriter',
        enable_experimental_new_converter=True,
    )
    converters.rewrite_saved_model(
        scratch_dir,
        fn_args.serving_model_dir,
        tflite_rewriter,
        rewriter.ModelType.TFLITE_MODEL,
    )

    # The staged SavedModel is no longer needed once the rewrite has finished.
    tf.io.gfile.rmtree(scratch_dir)