# Imports assumed by every example on this page. The original files also pull
# in a shared utils module as `base` (in the TFX penguin example this is
# penguin_utils_base), which provides input_fn and make_serving_signatures.
import absl.logging
import keras_tuner as kerastuner  # older releases: `import kerastuner`
import tensorflow as tf
import tensorflow_transform as tft
from tfx import v1 as tfx
from tfx.components.trainer.fn_args_utils import FnArgs


def run_fn(fn_args: FnArgs):
    """Train the model based on given args.

    Args:
      fn_args: Holds args used to train the model as name/value pairs.
    """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = base.input_fn(fn_args.train_files, fn_args.data_accessor,
                                  tf_transform_output, base.TRAIN_BATCH_SIZE)

    eval_dataset = base.input_fn(fn_args.eval_files, fn_args.data_accessor,
                                 tf_transform_output, base.EVAL_BATCH_SIZE)

    model = _make_trained_model(train_dataset,
                                eval_dataset,
                                num_epochs=1,
                                steps_per_epoch=fn_args.train_steps,
                                eval_steps_per_epoch=fn_args.eval_steps,
                                tensorboard_log_dir=fn_args.model_run_dir)
    # TODO(b/180721874): batch polymorphic model not yet supported.

    signatures = base.make_serving_signatures(model,
                                              tf_transform_output,
                                              serving_batch_size=1)
    tf.saved_model.save(model,
                        fn_args.serving_model_dir,
                        signatures=signatures)
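
All four examples feed data through base.input_fn, which turns the
materialized tf.Example files into a batched tf.data.Dataset of transformed
features. The helper itself is not shown on this page; the following is a
minimal sketch of what it plausibly looks like, assuming tfx_bsl's
TensorFlowDatasetOptions and a transformed label key of 'species_xf' (both
are assumptions, not code from the examples):

from typing import List

import tensorflow as tf
import tensorflow_transform as tft
from tfx import v1 as tfx
from tfx_bsl.public import tfxio


def input_fn(file_pattern: List[str],
             data_accessor: tfx.components.DataAccessor,
             tf_transform_output: tft.TFTransformOutput,
             batch_size: int) -> tf.data.Dataset:
  """Generates batched (features, label) tuples for training or eval."""
  return data_accessor.tf_dataset_factory(
      file_pattern,
      tfxio.TensorFlowDatasetOptions(
          batch_size=batch_size,
          label_key='species_xf'),  # assumed transformed label key
      tf_transform_output.transformed_metadata.schema).repeat()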
Example #2
def run_fn(fn_args: tfx.components.FnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

  train_dataset = base.input_fn(
      fn_args.train_files,
      fn_args.data_accessor,
      tf_transform_output,
      base.TRAIN_BATCH_SIZE)

  eval_dataset = base.input_fn(
      fn_args.eval_files,
      fn_args.data_accessor,
      tf_transform_output,
      base.EVAL_BATCH_SIZE)

  model = _make_trained_model(
      train_dataset,
      eval_dataset,
      num_epochs=1,
      steps_per_epoch=fn_args.train_steps,
      eval_steps_per_epoch=fn_args.eval_steps,
      tensorboard_log_dir=fn_args.model_run_dir)

  signatures = base.make_serving_signatures(model, tf_transform_output)
  tf.saved_model.save(model, fn_args.serving_model_dir, signatures=signatures)
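
Example #2 calls base.make_serving_signatures without the serving_batch_size=1
workaround used in Example #1, exporting a batch-polymorphic signature. That
helper is also off-page; here is a plausible sketch, assuming a single
'serving_default' signature, a raw label key of 'species', and the common TFT
pattern of attaching the transform graph to the model as a layer (all
assumptions):

import tensorflow as tf
import tensorflow_transform as tft


def make_serving_signatures(model: tf.keras.Model,
                            tf_transform_output: tft.TFTransformOutput):
  """Returns signatures that parse serialized tf.Examples and apply TFT."""
  # Attach the transform layer to the model so its assets and variables are
  # tracked by tf.saved_model.save.
  model.tft_layer = tf_transform_output.transform_features_layer()

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')
  ])
  def serve_tf_examples_fn(serialized_tf_examples):
    raw_feature_spec = tf_transform_output.raw_feature_spec()
    raw_feature_spec.pop('species', None)  # drop the assumed raw label key
    raw_features = tf.io.parse_example(serialized_tf_examples,
                                       raw_feature_spec)
    transformed_features = model.tft_layer(raw_features)
    return {'outputs': model(transformed_features)}

  return {'serving_default': serve_tf_examples_fn}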
Example #3
def tuner_fn(fn_args: tfx.components.FnArgs) -> tfx.components.TunerFnResult:
    """Build the tuner using the KerasTuner API.

  Args:
    fn_args: Holds args as name/value pairs.
      - working_dir: working dir for tuning.
      - train_files: List of file paths containing training tf.Example data.
      - eval_files: List of file paths containing eval tf.Example data.
      - train_steps: number of train steps.
      - eval_steps: number of eval steps.
      - schema_path: optional schema of the input data.
      - transform_graph_path: optional transform graph produced by TFT.

  Returns:
    A namedtuple contains the following:
      - tuner: A BaseTuner that will be used for tuning.
      - fit_kwargs: Args to pass to tuner's run_trial function for fitting the
                    model , e.g., the training and validation dataset. Required
                    args depend on the above tuner's implementation.
  """
    # RandomSearch is a subclass of kerastuner.Tuner, which inherits from
    # BaseTuner.
    tuner = kerastuner.RandomSearch(
        _make_keras_model,
        max_trials=6,
        hyperparameters=_get_hyperparameters(),
        allow_new_entries=False,
        objective=kerastuner.Objective('val_sparse_categorical_accuracy',
                                       'max'),
        directory=fn_args.working_dir,
        project_name='penguin_tuning')

    transform_graph = tft.TFTransformOutput(fn_args.transform_graph_path)

    train_dataset = base.input_fn(fn_args.train_files, fn_args.data_accessor,
                                  transform_graph, base.TRAIN_BATCH_SIZE)

    eval_dataset = base.input_fn(fn_args.eval_files, fn_args.data_accessor,
                                 transform_graph, base.EVAL_BATCH_SIZE)

    return tfx.components.TunerFnResult(
        tuner=tuner,
        fit_kwargs={
            'x': train_dataset,
            'validation_data': eval_dataset,
            'steps_per_epoch': fn_args.train_steps,
            'validation_steps': fn_args.eval_steps,
        })
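
Since RandomSearch is constructed with allow_new_entries=False, the search is
confined to the entries pre-registered by _get_hyperparameters. A minimal
sketch of that helper (the specific names, values, and ranges are
assumptions):

import keras_tuner as kerastuner


def _get_hyperparameters() -> kerastuner.HyperParameters:
  """Returns the fixed search space shared by the Tuner and the Trainer."""
  hp = kerastuner.HyperParameters()
  # Only entries declared here may be tuned (allow_new_entries=False above).
  hp.Choice('learning_rate', [1e-2, 1e-3], default=1e-3)
  hp.Int('num_layers', 1, 3, default=2)
  return hp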
Example #4
def run_fn(fn_args: tfx.components.FnArgs):
    """Train the model based on given args.

    Args:
      fn_args: Holds args used to train the model as name/value pairs.
    """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = base.input_fn(fn_args.train_files, fn_args.data_accessor,
                                  tf_transform_output, base.TRAIN_BATCH_SIZE)

    eval_dataset = base.input_fn(fn_args.eval_files, fn_args.data_accessor,
                                 tf_transform_output, base.EVAL_BATCH_SIZE)

    if fn_args.hyperparameters:
        hparams = kerastuner.HyperParameters.from_config(
            fn_args.hyperparameters)
    else:
        # This covers the case where the hyperparameters have already been
        # decided and the Tuner component has been removed from the pipeline.
        # Users can also inline the hyperparameters directly in
        # _make_keras_model.
        hparams = _get_hyperparameters()
    absl.logging.info('HyperParameters for training: %s', hparams.get_config())

    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = _make_keras_model(hparams)

    # Write TensorBoard logs to the directory the Trainer provides.
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=fn_args.model_run_dir, update_freq='batch')

    model.fit(train_dataset,
              steps_per_epoch=fn_args.train_steps,
              validation_data=eval_dataset,
              validation_steps=fn_args.eval_steps,
              callbacks=[tensorboard_callback])

    signatures = base.make_serving_signatures(model, tf_transform_output)
    model.save(fn_args.serving_model_dir,
               save_format='tf',
               signatures=signatures)
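
Example #4 hands the tuned (or default) hyperparameters to _make_keras_model
inside the MirroredStrategy scope so the model's variables are mirrored across
local devices. Below is a sketch of a builder consistent with the
'val_sparse_categorical_accuracy' objective used by the tuner, assuming the
penguin dataset's four numeric features and three classes (the feature names,
layer widths, and hyperparameter names are assumptions):

import keras_tuner as kerastuner
import tensorflow as tf

# Assumed transformed feature names for the penguin dataset.
_FEATURE_KEYS = [
    'culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g'
]


def _make_keras_model(hparams: kerastuner.HyperParameters) -> tf.keras.Model:
  """Builds a DNN classifier from the given hyperparameters."""
  inputs = [
      tf.keras.layers.Input(shape=(1,), name=key) for key in _FEATURE_KEYS
  ]
  d = tf.keras.layers.concatenate(inputs)
  for _ in range(int(hparams.get('num_layers'))):
    d = tf.keras.layers.Dense(8, activation='relu')(d)
  outputs = tf.keras.layers.Dense(3, activation='softmax')(d)
  model = tf.keras.Model(inputs=inputs, outputs=outputs)
  model.compile(
      optimizer=tf.keras.optimizers.Adam(hparams.get('learning_rate')),
      loss='sparse_categorical_crossentropy',
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model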