Exemplo n.º 1
0
def export_model(model, train_input_fn, eval_input_fn, tf_feat_cols, base_dir):
    """Export a TensorFlow estimator as SavedModels for train, eval and predict.

    A receiver function is built for each mode key. TRAIN and EVAL receivers
    are derived from the given input functions. The PREDICT receiver is a
    parsing serving receiver built from the feature columns' parse spec.
    All three are handed to the estimator's multi-mode export.

    Args:
        model (tf.estimator.Estimator): Model to export.
        train_input_fn (function): Training input function used to build the
            TRAIN receiver.
        eval_input_fn (function): Evaluation input function used to build the
            EVAL receiver.
        tf_feat_cols (list(tf.feature_column)): Feature columns whose parse
            spec defines the PREDICT (serving) receiver.
        base_dir (str): Base directory under which the model is exported.

    Returns:
        str: Exported model path, decoded as UTF-8 from the bytes value
        returned by the estimator.
    """
    # Quiet TF v1 logging during export. This is process-wide and is
    # intentionally not restored when the function returns.
    v1_logging = tf.compat.v1.logging
    v1_logging.set_verbosity(v1_logging.ERROR)

    modes = tf.estimator.ModeKeys
    # Dict-literal evaluation order (train, eval, parse spec, serving)
    # mirrors the order in which the receivers were originally built.
    receiver_by_mode = {
        modes.TRAIN: build_supervised_input_receiver_fn_from_input_fn(
            train_input_fn),
        modes.EVAL: build_supervised_input_receiver_fn_from_input_fn(
            eval_input_fn),
        modes.PREDICT: tf.estimator.export.build_parsing_serving_input_receiver_fn(
            tf.feature_column.make_parse_example_spec(tf_feat_cols)),
    }

    export_dir = model.experimental_export_all_saved_models(
        export_dir_base=base_dir,
        input_receiver_fn_map=receiver_by_mode,
    )
    return export_dir.decode("utf-8")
Exemplo n.º 2
0
  def test_build_supervised_input_receiver_fn_from_input_fn(self):
    """A single label tensor is exposed under the "label" receiver key."""

    def dummy_input_fn():
      feature_dict = {
          "x": constant_op.constant([[1], [1]]),
          "y": constant_op.constant(["hello", "goodbye"]),
      }
      label_tensor = constant_op.constant([[1], [1]])
      return feature_dict, label_tensor

    receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(
        dummy_input_fn)

    # The input_fn only runs when the receiver fn is called, so its
    # constants are created inside this fresh graph.
    with ops.Graph().as_default():
      receiver = receiver_fn()
      feature_names = set(receiver.features.keys())
      self.assertEqual({"x", "y"}, feature_names)
      self.assertIsInstance(receiver.labels, ops.Tensor)
      receiver_names = set(receiver.receiver_tensors.keys())
      self.assertEqual({"x", "y", "label"}, receiver_names)
Exemplo n.º 3
0
  def test_build_supervised_input_receiver_fn_from_input_fn_args(self):
    """Extra kwargs are forwarded to the input_fn; dict labels keep their keys."""

    def dummy_input_fn(feature_key="x"):
      feature_dict = {
          feature_key: constant_op.constant([[1], [1]]),
          "y": constant_op.constant(["hello", "goodbye"]),
      }
      label_dict = {"my_label": constant_op.constant([[1], [1]])}
      return feature_dict, label_dict

    # feature_key="z" should reach dummy_input_fn and rename the "x" feature.
    receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(
        dummy_input_fn, feature_key="z")

    with ops.Graph().as_default():
      receiver = receiver_fn()
      feature_names = set(receiver.features.keys())
      self.assertEqual({"z", "y"}, feature_names)
      label_names = set(receiver.labels.keys())
      self.assertEqual({"my_label"}, label_names)
      receiver_names = set(receiver.receiver_tensors.keys())
      self.assertEqual({"z", "y", "my_label"}, receiver_names)
Exemplo n.º 4
0
def dummy_supervised_receiver_fn():
  """Return a supervised input receiver fn built from `dummy_input_fn`."""
  receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(
      dummy_input_fn)
  return receiver_fn