Beispiel #1
0
def run_fn(fn_args: tfx.components.FnArgs):
    """Train a Keras model and export it with serving signatures.

    When ``fn_args.transform_output`` is ``None`` the schema is parsed from
    ``fn_args.schema_file`` and the raw feature/label names are used.
    Otherwise the schema is taken from the Transform output and the feature
    and label names are mapped through ``features.transformed_name``.

    Args:
      fn_args: Holds args used to train the model as name/value pairs.
    """
    transform_dir = fn_args.transform_output
    if transform_dir is not None:
        tf_transform_output = tft.TFTransformOutput(transform_dir)
        schema = tf_transform_output.transformed_metadata.schema
        feature_list = list(
            map(features.transformed_name, features.FEATURE_KEYS))
        label_key = features.transformed_name(features.LABEL_KEY)
    else:
        # Transform is not used: fall back to the raw schema and raw names.
        tf_transform_output = None
        schema = tfx.utils.parse_pbtxt_file(fn_args.schema_file,
                                            schema_pb2.Schema())
        feature_list = features.FEATURE_KEYS
        label_key = features.LABEL_KEY

    # Scale the configured batch sizes by the number of replicas in sync.
    strategy = tf.distribute.MirroredStrategy()
    replica_count = strategy.num_replicas_in_sync
    train_batch_size = constants.TRAIN_BATCH_SIZE * replica_count
    eval_batch_size = constants.EVAL_BATCH_SIZE * replica_count

    def make_dataset(file_pattern, batch_size):
        # Both splits share the accessor, schema and label key.
        return _input_fn(file_pattern,
                         fn_args.data_accessor,
                         schema,
                         label_key,
                         batch_size=batch_size)

    train_dataset = make_dataset(fn_args.train_files, train_batch_size)
    eval_dataset = make_dataset(fn_args.eval_files, eval_batch_size)

    # The model is created inside the distribution strategy's scope.
    with strategy.scope():
        model = _build_keras_model(feature_list)

    # Write logs to the model run directory, updating on every batch.
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=fn_args.model_run_dir, update_freq='batch')

    fit_kwargs = dict(
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        callbacks=[tensorboard_callback],
    )
    model.fit(train_dataset, **fit_kwargs)

    # Export the trained model together with its two named signatures.
    serving_signature = _get_tf_examples_serving_signature(
        model, schema, tf_transform_output)
    transform_signature = _get_transform_features_signature(
        model, schema, tf_transform_output)
    signatures = {
        'serving_default': serving_signature,
        'transform_features': transform_signature,
    }
    model.save(fn_args.serving_model_dir,
               save_format='tf',
               signatures=signatures)
Beispiel #2
0
def preprocessing_fn(inputs):
  """Preprocessing callback invoked by tf.transform.

  Every key in `features.FEATURE_KEYS` is z-score scaled and stored under its
  transformed name; the label is copied through unchanged under its
  transformed name.

  Args:
    inputs: map from feature keys to raw not-yet-transformed features.

  Returns:
    Map from string feature key to transformed feature operations.
  """
  transformed = {}

  # Entry point for feature engineering with TensorFlow Transform when used
  # from the TFX Transform component.  The only transformation applied to the
  # features here is z-score scaling.
  for raw_key in features.FEATURE_KEYS:
    scaled = tft.scale_to_z_score(inputs[raw_key])
    transformed[features.transformed_name(raw_key)] = scaled

  # The label is passed through untouched: transforming it would result in
  # wrong evaluation.
  raw_label = inputs[features.LABEL_KEY]
  transformed[features.transformed_name(features.LABEL_KEY)] = raw_label

  return transformed