Ejemplo n.º 1
0
    def test_get_tf_dataset_factory_from_artifact(self):
        """Tests the dataset factory obtained for a TF-Example payload.

        The returned factory must be callable, and its signature's return
        annotation must be `tf.data.Dataset`.
        """
        examples_artifact = standard_artifacts.Examples()
        examples_utils.set_payload_format(
            examples_artifact,
            example_gen_pb2.PayloadFormat.FORMAT_TF_EXAMPLE)

        factory = tfxio_utils.get_tf_dataset_factory_from_artifact(
            [examples_artifact], _TELEMETRY_DESCRIPTORS)
        self.assertIsInstance(factory, Callable)

        declared_return_type = inspect.signature(factory).return_annotation
        self.assertEqual(tf.data.Dataset, declared_return_type)
Ejemplo n.º 2
0
def _get_split_file_patterns(examples: List[types.Artifact],
                             splits) -> List[Text]:
  """Returns a file pattern for every URI of `examples` in each split.

  Args:
    examples: Examples artifacts whose per-split URIs are looked up via
      `artifact_utils.get_split_uris`.
    splits: Iterable of split names (e.g. a repeated proto `splits` field),
      processed in order.

  Returns:
    A flat list of `io_utils.all_files_pattern(uri)` strings, grouped by split
    in the order given by `splits`.
  """
  return [
      io_utils.all_files_pattern(uri)
      for split in splits
      for uri in artifact_utils.get_split_uris(examples, split)
  ]


def get_common_fn_args(input_dict: Dict[Text, List[types.Artifact]],
                       exec_properties: Dict[Text, Any],
                       working_dir: Text = None) -> FnArgs:
  """Get common args of training and tuning.

  Args:
    input_dict: Input artifacts keyed by component-spec key. EXAMPLES_KEY is
      required; TRANSFORM_GRAPH_KEY and SCHEMA_KEY are optional.
    exec_properties: Execution properties. TRAIN_ARGS_KEY and EVAL_ARGS_KEY
      must hold JSON-serialized TrainArgs / EvalArgs protos. CUSTOM_CONFIG_KEY
      optionally holds the JSON-serialized custom config.
    working_dir: Passed through unchanged to the returned FnArgs.

  Returns:
    An FnArgs with train/eval file patterns, step counts, schema and
    transform-graph paths (None when the artifact is absent), a DataAccessor
    and the deserialized custom config.

  Raises:
    KeyError: if EXAMPLES_KEY is missing from `input_dict`, or TRAIN_ARGS_KEY /
      EVAL_ARGS_KEY is missing from `exec_properties`.
  """
  if input_dict.get(standard_component_specs.TRANSFORM_GRAPH_KEY):
    transform_graph_path = artifact_utils.get_single_uri(
        input_dict[standard_component_specs.TRANSFORM_GRAPH_KEY])
  else:
    transform_graph_path = None

  if input_dict.get(standard_component_specs.SCHEMA_KEY):
    schema_path = io_utils.get_only_uri_in_dir(
        artifact_utils.get_single_uri(
            input_dict[standard_component_specs.SCHEMA_KEY]))
  else:
    schema_path = None

  train_args = trainer_pb2.TrainArgs()
  eval_args = trainer_pb2.EvalArgs()
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.TRAIN_ARGS_KEY], train_args)
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.EVAL_ARGS_KEY], eval_args)

  # Default behavior is train on `train` split (when splits is empty in train
  # args) and evaluate on `eval` split (when splits is empty in eval args).
  if not train_args.splits:
    train_args.splits.append('train')
    absl.logging.info("Train on the 'train' split when train_args.splits is "
                      'not set.')
  if not eval_args.splits:
    eval_args.splits.append('eval')
    absl.logging.info("Evaluate on the 'eval' split when eval_args.splits is "
                      'not set.')

  examples = input_dict[standard_component_specs.EXAMPLES_KEY]
  train_files = _get_split_file_patterns(examples, train_args.splits)
  eval_files = _get_split_file_patterns(examples, eval_args.splits)

  data_accessor = DataAccessor(
      tf_dataset_factory=tfxio_utils.get_tf_dataset_factory_from_artifact(
          examples, _TELEMETRY_DESCRIPTORS),
      record_batch_factory=tfxio_utils.get_record_batch_factory_from_artifact(
          examples, _TELEMETRY_DESCRIPTORS),
      data_view_decode_fn=tfxio_utils.get_data_view_decode_fn_from_artifact(
          examples, _TELEMETRY_DESCRIPTORS))

  # https://github.com/tensorflow/tfx/issues/45: Replace num_steps=0 with
  # num_steps=None.  Conversion of the proto to python will set the default
  # value of an int as 0 so modify the value here.  Tensorflow will raise an
  # error if num_steps <= 0.
  train_steps = train_args.num_steps or None
  eval_steps = eval_args.num_steps or None

  # Load and deserialize custom config from execution properties. In the
  # component interface an unset custom config is serialized as 'null' rather
  # than '{}'. We also fall back to 'null' when the property is absent, None
  # or empty, since a JSON decoder would reject the latter two. JSON 'null'
  # decodes to None, and no empty dict is substituted here, so consumers of
  # FnArgs.custom_config must handle None.
  custom_config = json_utils.loads(
      exec_properties.get(standard_component_specs.CUSTOM_CONFIG_KEY) or
      'null')

  return FnArgs(
      working_dir=working_dir,
      train_files=train_files,
      eval_files=eval_files,
      train_steps=train_steps,
      eval_steps=eval_steps,
      schema_path=schema_path,
      transform_graph_path=transform_graph_path,
      data_accessor=data_accessor,
      custom_config=custom_config,
  )
Ejemplo n.º 3
0
def get_common_fn_args(input_dict: Dict[Text, List[types.Artifact]],
                       exec_properties: Dict[Text, Any],
                       working_dir: Text = None) -> FnArgs:
    """Builds the FnArgs shared by the training and tuning executors.

    Args:
      input_dict: Input artifacts keyed by `constants` key. EXAMPLES_KEY is
        required; TRANSFORM_GRAPH_KEY and SCHEMA_KEY are optional.
      exec_properties: Execution properties. TRAIN_ARGS_KEY and EVAL_ARGS_KEY
        hold JSON text parsed into TrainArgs / EvalArgs protos.
        CUSTOM_CONFIG_KEY optionally holds JSON-serialized custom config.
      working_dir: Forwarded unchanged into the returned FnArgs.

    Returns:
      An FnArgs describing file patterns, step counts, schema and transform
      graph locations (None when absent), a DataAccessor and custom config.
    """
    graph_artifacts = input_dict.get(constants.TRANSFORM_GRAPH_KEY)
    transform_graph_path = (
        artifact_utils.get_single_uri(graph_artifacts)
        if graph_artifacts else None)

    schema_artifacts = input_dict.get(constants.SCHEMA_KEY)
    schema_path = (
        io_utils.get_only_uri_in_dir(
            artifact_utils.get_single_uri(schema_artifacts))
        if schema_artifacts else None)

    train_args = trainer_pb2.TrainArgs()
    eval_args = trainer_pb2.EvalArgs()
    json_format.Parse(exec_properties[constants.TRAIN_ARGS_KEY], train_args)
    json_format.Parse(exec_properties[constants.EVAL_ARGS_KEY], eval_args)

    # An empty `splits` list selects the conventional split: `train` for
    # training and `eval` for evaluation.
    if not train_args.splits:
        train_args.splits.append('train')
        absl.logging.info(
            "Train on the 'train' split when train_args.splits is "
            'not set.')
    if not eval_args.splits:
        eval_args.splits.append('eval')
        absl.logging.info(
            "Evaluate on the 'eval' split when eval_args.splits is "
            'not set.')

    examples = input_dict[constants.EXAMPLES_KEY]
    train_files = [
        io_utils.all_files_pattern(uri)
        for split_name in train_args.splits
        for uri in artifact_utils.get_split_uris(examples, split_name)
    ]
    eval_files = [
        io_utils.all_files_pattern(uri)
        for split_name in eval_args.splits
        for uri in artifact_utils.get_split_uris(examples, split_name)
    ]

    data_accessor = DataAccessor(
        tf_dataset_factory=tfxio_utils.get_tf_dataset_factory_from_artifact(
            examples, _TELEMETRY_DESCRIPTORS))

    # Unset proto ints read back as 0, and TensorFlow errors on
    # num_steps <= 0, so 0 is mapped to None here
    # (https://github.com/tensorflow/tfx/issues/45).
    train_steps = train_args.num_steps or None
    eval_steps = eval_args.num_steps or None

    # TODO(b/156929910): Refactor Trainer to be consistent with empty or None
    #                    custom_config handling.
    custom_config = json_utils.loads(
        exec_properties.get(constants.CUSTOM_CONFIG_KEY, 'null'))

    return FnArgs(
        working_dir=working_dir,
        train_files=train_files,
        eval_files=eval_files,
        train_steps=train_steps,
        eval_steps=eval_steps,
        schema_path=schema_path,
        transform_graph_path=transform_graph_path,
        data_accessor=data_accessor,
        custom_config=custom_config,
    )
Ejemplo n.º 4
0
  def testTrainerFn(self):
    """Exercises taxi_utils.trainer_fn end to end on the bundled test data.

    Builds the training spec, trains and evaluates for one step, exports an
    eval saved model, then loads the exported serving graph back.
    """
    output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                 self.get_temp_dir())
    temp_dir = os.path.join(output_root, self._testMethodName)

    testdata = self._testdata_path
    schema_file = os.path.join(testdata, 'schema_gen/schema.pbtxt')
    train_pattern = os.path.join(
        testdata, 'transform/transformed_examples/Split-train/*.gz')
    eval_pattern = os.path.join(
        testdata, 'transform/transformed_examples/Split-eval/*.gz')
    transform_output = os.path.join(testdata, 'transform/transform_graph')
    serving_model_dir = os.path.join(temp_dir, 'serving_model_dir')

    data_accessor = DataAccessor(
        tf_dataset_factory=tfxio_utils.get_tf_dataset_factory_from_artifact(
            [standard_artifacts.Examples()], []),
        record_batch_factory=None,
        data_view_decode_fn=None)
    trainer_fn_args = trainer_executor.TrainerFnArgs(
        train_files=train_pattern,
        transform_output=transform_output,
        serving_model_dir=serving_model_dir,
        eval_files=eval_pattern,
        schema_file=schema_file,
        train_steps=1,
        eval_steps=1,
        base_model=None,
        data_accessor=data_accessor)
    schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())
    training_spec = taxi_utils.trainer_fn(trainer_fn_args, schema)

    (estimator, train_spec, eval_spec, eval_input_receiver_fn) = (
        training_spec['estimator'], training_spec['train_spec'],
        training_spec['eval_spec'], training_spec['eval_input_receiver_fn'])

    self.assertIsInstance(estimator,
                          tf.estimator.DNNLinearCombinedClassifier)
    self.assertIsInstance(train_spec, tf.estimator.TrainSpec)
    self.assertIsInstance(eval_spec, tf.estimator.EvalSpec)
    self.assertIsInstance(eval_input_receiver_fn, types.FunctionType)

    # The RunConfig must keep more than one checkpoint.
    self.assertGreater(estimator._config.keep_checkpoint_max, 1)

    # One training step followed by one evaluation step.
    eval_result, exports = tf.estimator.train_and_evaluate(
        estimator, train_spec, eval_spec)
    self.assertGreater(eval_result['loss'], 0.0)
    self.assertEqual(len(exports), 1)
    serving_export_dir = exports[0]
    self.assertGreaterEqual(len(fileio.listdir(serving_export_dir)), 1)

    # Export the TFMA eval saved model next to the serving model.
    eval_savedmodel_path = tfma.export.export_eval_savedmodel(
        estimator=estimator,
        export_dir_base=path_utils.eval_model_dir(temp_dir),
        eval_input_receiver_fn=eval_input_receiver_fn)
    self.assertGreaterEqual(len(fileio.listdir(eval_savedmodel_path)), 1)

    # The exported serving graph must load under the SERVING tag.
    with tf.compat.v1.Session() as sess:
      metagraph_def = tf.compat.v1.saved_model.loader.load(
          sess, [tf.saved_model.SERVING], serving_export_dir)
      self.assertIsInstance(metagraph_def, tf.compat.v1.MetaGraphDef)