Example 1
 def testGetFromSplitsMultipleArtifacts(self):
     """Test split retrieval utility on a multiple list of split Artifacts."""
     artifacts = [
         standard_artifacts.Examples(),
         standard_artifacts.Examples()
     ]
     artifacts[0].uri = '/tmp1'
     artifacts[0].split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     artifacts[1].uri = '/tmp2'
     artifacts[1].split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     self.assertEqual(['/tmp1/train', '/tmp2/train'],
                      artifact_utils.get_split_uris(artifacts, 'train'))
     self.assertEqual(['/tmp1/eval', '/tmp2/eval'],
                      artifact_utils.get_split_uris(artifacts, 'eval'))
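
Example 1 pins down the legacy contract: one URI per artifact, with each split living in a '<artifact uri>/<split name>' subdirectory. A minimal, self-contained sketch of that contract, using toy stand-ins rather than the TFX implementation:

import json
import os

def encode_split_names_sketch(splits):
    # TFX stores split names on the artifact as a single serialized string;
    # a JSON-encoded list is assumed here for illustration.
    return json.dumps(splits)

def get_split_uris_sketch(artifacts, split_name):
    # Legacy layout: one URI per artifact, '<uri>/<split_name>'.
    return [os.path.join(a.uri, split_name) for a in artifacts]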
Example 2
 def testGetFromSplitsMultipleArtifacts(self):
     """Test split retrieval utility on a multiple list of split Artifacts."""
     artifacts = [
         standard_artifacts.Examples(),
         standard_artifacts.Examples()
     ]
     artifacts[0].uri = '/tmp1'
     artifacts[0].split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     artifacts[1].uri = '/tmp2'
     artifacts[1].split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     # Newly created artifacts use the 'Split-<split_name>' directory format.
     self.assertEqual(['/tmp1/Split-train', '/tmp2/Split-train'],
                      artifact_utils.get_split_uris(artifacts, 'train'))
     self.assertEqual(['/tmp1/Split-eval', '/tmp2/Split-eval'],
                      artifact_utils.get_split_uris(artifacts, 'eval'))
     # When reading published artifacts without a recorded TFX version, fall
     # back to the legacy layout.
     artifacts[0].mlmd_artifact.state = metadata_store_pb2.Artifact.LIVE
     artifacts[1].mlmd_artifact.state = metadata_store_pb2.Artifact.LIVE
     self.assertEqual(['/tmp1/train', '/tmp2/train'],
                      artifact_utils.get_split_uris(artifacts, 'train'))
     self.assertEqual(['/tmp1/eval', '/tmp2/eval'],
                      artifact_utils.get_split_uris(artifacts, 'eval'))
     # When reading artifacts written by an old TFX version, keep the legacy
     # layout.
     artifacts[0].set_string_custom_property(
         artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY, '0.1')
     artifacts[1].set_string_custom_property(
         artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY, '0.1')
     self.assertEqual(['/tmp1/train', '/tmp2/train'],
                      artifact_utils.get_split_uris(artifacts, 'train'))
     self.assertEqual(['/tmp1/eval', '/tmp2/eval'],
                      artifact_utils.get_split_uris(artifacts, 'eval'))
     # When reading artifacts written by a new TFX version, use the
     # 'Split-<split_name>' layout.
     artifacts[0].set_string_custom_property(
         artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY,
         artifact_utils._ARTIFACT_VERSION_FOR_SPLIT_UPDATE)
     artifacts[1].set_string_custom_property(
         artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY,
         artifact_utils._ARTIFACT_VERSION_FOR_SPLIT_UPDATE)
     self.assertEqual(['/tmp1/Split-train', '/tmp2/Split-train'],
                      artifact_utils.get_split_uris(artifacts, 'train'))
     self.assertEqual(['/tmp1/Split-eval', '/tmp2/Split-eval'],
                      artifact_utils.get_split_uris(artifacts, 'eval'))
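
Example 2 extends the same test across the layout change: the split directory name depends on the artifact's state and on the TFX version recorded in its custom properties. A hedged sketch of that dispatch (the property key, the cutoff version, and is_published_sketch are assumptions inferred only from the names used in the test, not the real artifact_utils internals):

import os

_SPLIT_UPDATE_VERSION = '0.30'  # assumed cutoff; the real constant is private

def is_published_sketch(artifact):
    # Stand-in for checking mlmd_artifact.state (0 is the proto's UNKNOWN).
    return artifact.mlmd_artifact.state != 0

def split_uri_sketch(artifact, split_name):
    version = artifact.get_string_custom_property('tfx_version')  # assumed key
    # Unpublished artifacts and artifacts tagged with a new-enough TFX version
    # use the 'Split-<name>' layout; otherwise fall back to the legacy layout.
    # Naive string comparison, for illustration only.
    new_layout = (not is_published_sketch(artifact)) or (
        version >= _SPLIT_UPDATE_VERSION if version else False)
    subdir = 'Split-%s' % split_name if new_layout else split_name
    return os.path.join(artifact.uri, subdir)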
Example 3
def GetSplitPaths(
        transformed_examples: Optional[List[types.Artifact]]) -> List[str]:
    """Gets all paths for splits in the input artifacts."""
    result = []
    if not transformed_examples:
        return result
    splits = artifact_utils.decode_split_names(
        transformed_examples[0].split_names)

    for split in splits:
        transformed_example_uris = artifact_utils.get_split_uris(
            transformed_examples, split)
        for output_uri in transformed_example_uris:
            result.append(
                os.path.join(output_uri, _DEFAULT_TRANSFORMED_EXAMPLES_PREFIX))

    return result
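
A hedged usage sketch for GetSplitPaths, assuming _DEFAULT_TRANSFORMED_EXAMPLES_PREFIX resolves to 'transformed_examples' (the constant is module-private, so that literal is an assumption):

from tfx.types import artifact_utils, standard_artifacts

examples = [standard_artifacts.Examples(), standard_artifacts.Examples()]
examples[0].uri, examples[1].uri = '/tmp1', '/tmp2'
for ex in examples:
    ex.split_names = artifact_utils.encode_split_names(['train', 'eval'])

# Depending on the TFX version (see Example 2), this yields paths like
# '/tmp1/train/transformed_examples' or '/tmp1/Split-train/transformed_examples'.
print(GetSplitPaths(examples))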
Example 4
def get_common_fn_args(input_dict: Dict[Text, List[types.Artifact]],
                       exec_properties: Dict[Text, Any],
                       working_dir: Optional[Text] = None) -> FnArgs:
  """Get common args of training and tuning."""
  if input_dict.get(standard_component_specs.TRANSFORM_GRAPH_KEY):
    transform_graph_path = artifact_utils.get_single_uri(
        input_dict[standard_component_specs.TRANSFORM_GRAPH_KEY])
  else:
    transform_graph_path = None

  if input_dict.get(standard_component_specs.SCHEMA_KEY):
    schema_path = io_utils.get_only_uri_in_dir(
        artifact_utils.get_single_uri(
            input_dict[standard_component_specs.SCHEMA_KEY]))
  else:
    schema_path = None

  train_args = trainer_pb2.TrainArgs()
  eval_args = trainer_pb2.EvalArgs()
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.TRAIN_ARGS_KEY], train_args)
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.EVAL_ARGS_KEY], eval_args)

  # Default behavior is train on `train` split (when splits is empty in train
  # args) and evaluate on `eval` split (when splits is empty in eval args).
  if not train_args.splits:
    train_args.splits.append('train')
    absl.logging.info("Train on the 'train' split when train_args.splits is "
                      'not set.')
  if not eval_args.splits:
    eval_args.splits.append('eval')
    absl.logging.info("Evaluate on the 'eval' split when eval_args.splits is "
                      'not set.')

  train_files = []
  for train_split in train_args.splits:
    train_files.extend([
        io_utils.all_files_pattern(uri)
        for uri in artifact_utils.get_split_uris(
            input_dict[standard_component_specs.EXAMPLES_KEY], train_split)
    ])

  eval_files = []
  for eval_split in eval_args.splits:
    eval_files.extend([
        io_utils.all_files_pattern(uri)
        for uri in artifact_utils.get_split_uris(
            input_dict[standard_component_specs.EXAMPLES_KEY], eval_split)
    ])

  data_accessor = DataAccessor(
      tf_dataset_factory=tfxio_utils.get_tf_dataset_factory_from_artifact(
          input_dict[standard_component_specs.EXAMPLES_KEY],
          _TELEMETRY_DESCRIPTORS),
      record_batch_factory=tfxio_utils.get_record_batch_factory_from_artifact(
          input_dict[standard_component_specs.EXAMPLES_KEY],
          _TELEMETRY_DESCRIPTORS),
      data_view_decode_fn=tfxio_utils.get_data_view_decode_fn_from_artifact(
          input_dict[standard_component_specs.EXAMPLES_KEY],
          _TELEMETRY_DESCRIPTORS)
      )

  # https://github.com/tensorflow/tfx/issues/45: Replace num_steps=0 with
  # num_steps=None.  Conversion of the proto to python will set the default
  # value of an int as 0 so modify the value here.  Tensorflow will raise an
  # error if num_steps <= 0.
  train_steps = train_args.num_steps or None
  eval_steps = eval_args.num_steps or None

  # Load and deserialize custom config from execution properties.
  # Note that in the component interface the default serialization of custom
  # config is 'null' instead of '{}'. Therefore we need to default the
  # json_utils.loads to 'null' then populate it with an empty dict when
  # needed.
  custom_config = json_utils.loads(
      exec_properties.get(standard_component_specs.CUSTOM_CONFIG_KEY, 'null'))

  return FnArgs(
      working_dir=working_dir,
      train_files=train_files,
      eval_files=eval_files,
      train_steps=train_steps,
      eval_steps=eval_steps,
      schema_path=schema_path,
      transform_graph_path=transform_graph_path,
      data_accessor=data_accessor,
      custom_config=custom_config,
  )
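
Two normalizations above are easy to miss: empty splits fall back to 'train'/'eval', and num_steps == 0 (the proto int default) becomes None so TensorFlow does not reject it. A minimal sketch of just that normalization, runnable with TFX installed:

from tfx.proto import trainer_pb2
from tfx.utils import proto_utils

train_args = trainer_pb2.TrainArgs()
proto_utils.json_to_proto('{}', train_args)  # neither splits nor num_steps set

if not train_args.splits:
    train_args.splits.append('train')        # default split
train_steps = train_args.num_steps or None   # 0 -> None

assert list(train_args.splits) == ['train'] and train_steps is None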
Example 5
def get_common_fn_args(input_dict: Dict[Text, List[types.Artifact]],
                       exec_properties: Dict[Text, Any],
                       working_dir: Optional[Text] = None) -> FnArgs:
    """Get common args of training and tuning."""
    if input_dict.get(constants.TRANSFORM_GRAPH_KEY):
        transform_graph_path = artifact_utils.get_single_uri(
            input_dict[constants.TRANSFORM_GRAPH_KEY])
    else:
        transform_graph_path = None

    if input_dict.get(constants.SCHEMA_KEY):
        schema_path = io_utils.get_only_uri_in_dir(
            artifact_utils.get_single_uri(input_dict[constants.SCHEMA_KEY]))
    else:
        schema_path = None

    train_args = trainer_pb2.TrainArgs()
    eval_args = trainer_pb2.EvalArgs()
    json_format.Parse(exec_properties[constants.TRAIN_ARGS_KEY], train_args)
    json_format.Parse(exec_properties[constants.EVAL_ARGS_KEY], eval_args)

    # Default behavior is train on `train` split (when splits is empty in train
    # args) and evaluate on `eval` split (when splits is empty in eval args).
    if not train_args.splits:
        train_args.splits.append('train')
        absl.logging.info(
            "Train on the 'train' split when train_args.splits is "
            'not set.')
    if not eval_args.splits:
        eval_args.splits.append('eval')
        absl.logging.info(
            "Evaluate on the 'eval' split when eval_args.splits is "
            'not set.')

    train_files = []
    for train_split in train_args.splits:
        train_files.extend([
            io_utils.all_files_pattern(uri)
            for uri in artifact_utils.get_split_uris(
                input_dict[constants.EXAMPLES_KEY], train_split)
        ])

    eval_files = []
    for eval_split in eval_args.splits:
        eval_files.extend([
            io_utils.all_files_pattern(uri)
            for uri in artifact_utils.get_split_uris(
                input_dict[constants.EXAMPLES_KEY], eval_split)
        ])

    # https://github.com/tensorflow/tfx/issues/45: Replace num_steps=0 with
    # num_steps=None.  Conversion of the proto to python will set the default
    # value of an int as 0 so modify the value here.  Tensorflow will raise an
    # error if num_steps <= 0.
    train_steps = train_args.num_steps or None
    eval_steps = eval_args.num_steps or None

    # TODO(b/156929910): Refactor Trainer to be consistent with empty or None
    #                    custom_config handling.
    custom_config = json_utils.loads(
        exec_properties.get(constants.CUSTOM_CONFIG_KEY, 'null'))

    return FnArgs(
        working_dir=working_dir,
        train_files=train_files,
        eval_files=eval_files,
        train_steps=train_steps,
        eval_steps=eval_steps,
        schema_path=schema_path,
        transform_graph_path=transform_graph_path,
        custom_config=custom_config,
    )
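
This variant differs from Example 4 mainly in parsing the args protos with google.protobuf.json_format directly instead of TFX's proto_utils wrapper; for this purpose the two are interchangeable, as this small sketch illustrates:

from google.protobuf import json_format
from tfx.proto import trainer_pb2

train_args = trainer_pb2.TrainArgs()
json_format.Parse('{"splits": ["train"], "num_steps": 100}', train_args)
assert list(train_args.splits) == ['train'] and train_args.num_steps == 100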
Example 6
  def Do(self, input_dict: Dict[str, List[types.Artifact]],
         output_dict: Dict[str, List[types.Artifact]],
         exec_properties: Dict[str, Any]) -> None:
    """Runs a batch job to evaluate the eval_model against the given input.

    Args:
      input_dict: Input dict from input key to a list of Artifacts.
        - model: exported model.
        - examples: examples for evaluating the model.
      output_dict: Output dict from output key to a list of Artifacts.
        - evaluation: model evaluation results.
      exec_properties: A dict of execution properties.
        - eval_config: JSON string of tfma.EvalConfig.
        - feature_slicing_spec: JSON string of evaluator_pb2.FeatureSlicingSpec
          instance, providing the way to slice the data. Deprecated, use
          eval_config.slicing_specs instead.
        - example_splits: JSON-serialized list of names of splits on which the
          metrics are computed. Default behavior (when example_splits is set to
          None) is using the 'eval' split.

    Returns:
      None
    """
    if standard_component_specs.EXAMPLES_KEY not in input_dict:
      raise ValueError('EXAMPLES_KEY is missing from input dict.')
    if standard_component_specs.EVALUATION_KEY not in output_dict:
      raise ValueError('EVALUATION_KEY is missing from output dict.')
    if standard_component_specs.MODEL_KEY in input_dict and len(
        input_dict[standard_component_specs.MODEL_KEY]) > 1:
      raise ValueError('There can be only one candidate model; got %d.' %
                       len(input_dict[standard_component_specs.MODEL_KEY]))
    if standard_component_specs.BASELINE_MODEL_KEY in input_dict and len(
        input_dict[standard_component_specs.BASELINE_MODEL_KEY]) > 1:
      raise ValueError(
          'There can be only one baseline model; got %d.' %
          len(input_dict[standard_component_specs.BASELINE_MODEL_KEY]))

    self._log_startup(input_dict, output_dict, exec_properties)

    # Add fairness indicator metric callback if necessary.
    fairness_indicator_thresholds = json_utils.loads(
        exec_properties.get(
            standard_component_specs.FAIRNESS_INDICATOR_THRESHOLDS_KEY, 'null'))
    add_metrics_callbacks = None
    if fairness_indicator_thresholds:
      add_metrics_callbacks = [
          tfma.post_export_metrics.fairness_indicators(  # pytype: disable=module-attr
              thresholds=fairness_indicator_thresholds),
      ]

    output_uri = artifact_utils.get_single_uri(
        output_dict[constants.EVALUATION_KEY])

    # Make sure user packages get propagated to the remote Beam worker.
    unused_module_path, extra_pip_packages = udf_utils.decode_user_module_key(
        exec_properties.get(standard_component_specs.MODULE_PATH_KEY, None))
    for pip_package_path in extra_pip_packages:
      local_pip_package_path = io_utils.ensure_local(pip_package_path)
      self._beam_pipeline_args.append('--extra_package=%s' %
                                      local_pip_package_path)

    eval_shared_model_fn = udf_utils.try_get_fn(
        exec_properties=exec_properties,
        fn_name='custom_eval_shared_model') or tfma.default_eval_shared_model

    run_validation = False
    models = []
    if (standard_component_specs.EVAL_CONFIG_KEY in exec_properties
        and exec_properties[standard_component_specs.EVAL_CONFIG_KEY]):
      slice_spec = None
      has_baseline = bool(
          input_dict.get(standard_component_specs.BASELINE_MODEL_KEY))
      eval_config = tfma.EvalConfig()
      proto_utils.json_to_proto(
          exec_properties[standard_component_specs.EVAL_CONFIG_KEY],
          eval_config)
      # rubber_stamp is always assumed true, i.e., change thresholds will
      # always be ignored when a baseline model is missing.
      if hasattr(tfma, 'utils'):
        eval_config = tfma.utils.update_eval_config_with_defaults(
            eval_config, has_baseline=has_baseline, rubber_stamp=True)
        tfma.utils.verify_eval_config(eval_config)
      else:
        # TODO(b/171992041): Replaced by tfma.utils.
        eval_config = tfma.update_eval_config_with_defaults(
            eval_config, has_baseline=has_baseline, rubber_stamp=True)
        tfma.verify_eval_config(eval_config)
      # Do not validate the model when no thresholds are configured. This
      # avoids accidentally blessing models when users forget to set thresholds.
      run_validation = bool(
          tfma.metrics.metric_thresholds_from_metrics_specs(
              eval_config.metrics_specs, eval_config=eval_config))
      if len(eval_config.model_specs) > 2:
        raise ValueError(
            'Cannot support more than two models. There are %d models in this '
            'eval_config.' % len(eval_config.model_specs))
      # Extract model artifacts.
      for model_spec in eval_config.model_specs:
        if standard_component_specs.MODEL_KEY not in input_dict:
          if not model_spec.prediction_key:
            raise ValueError(
                'model_spec.prediction_key required if model not provided')
          continue
        if model_spec.is_baseline:
          model_artifact = artifact_utils.get_single_instance(
              input_dict[standard_component_specs.BASELINE_MODEL_KEY])
        else:
          model_artifact = artifact_utils.get_single_instance(
              input_dict[standard_component_specs.MODEL_KEY])
        # TODO(b/171992041): tfma.get_model_type replaced by tfma.utils.
        if ((hasattr(tfma, 'utils') and
             tfma.utils.get_model_type(model_spec) == tfma.TF_ESTIMATOR) or
            hasattr(tfma, 'get_model_type') and
            tfma.get_model_type(model_spec) == tfma.TF_ESTIMATOR):
          model_path = path_utils.eval_model_path(
              model_artifact.uri,
              path_utils.is_old_model_artifact(model_artifact))
        else:
          model_path = path_utils.serving_model_path(
              model_artifact.uri,
              path_utils.is_old_model_artifact(model_artifact))
        logging.info('Using %s as %s model.', model_path, model_spec.name)
        models.append(
            eval_shared_model_fn(
                eval_saved_model_path=model_path,
                model_name=model_spec.name,
                eval_config=eval_config,
                add_metrics_callbacks=add_metrics_callbacks))
    else:
      eval_config = None
      assert (standard_component_specs.FEATURE_SLICING_SPEC_KEY
              in exec_properties and
              exec_properties[standard_component_specs.FEATURE_SLICING_SPEC_KEY]
             ), 'both eval_config and feature_slicing_spec are unset.'
      feature_slicing_spec = evaluator_pb2.FeatureSlicingSpec()
      proto_utils.json_to_proto(
          exec_properties[standard_component_specs.FEATURE_SLICING_SPEC_KEY],
          feature_slicing_spec)
      slice_spec = self._get_slice_spec_from_feature_slicing_spec(
          feature_slicing_spec)
      model_artifact = artifact_utils.get_single_instance(
          input_dict[standard_component_specs.MODEL_KEY])
      model_path = path_utils.eval_model_path(
          model_artifact.uri, path_utils.is_old_model_artifact(model_artifact))
      logging.info('Using %s for model eval.', model_path)
      models.append(
          eval_shared_model_fn(
              eval_saved_model_path=model_path,
              model_name='',
              eval_config=None,
              add_metrics_callbacks=add_metrics_callbacks))

    eval_shared_model = models[0] if len(models) == 1 else models
    schema = None
    if standard_component_specs.SCHEMA_KEY in input_dict:
      schema = io_utils.SchemaReader().read(
          io_utils.get_only_uri_in_dir(
              artifact_utils.get_single_uri(
                  input_dict[standard_component_specs.SCHEMA_KEY])))

    # Load and deserialize example splits from execution properties.
    example_splits = json_utils.loads(
        exec_properties.get(standard_component_specs.EXAMPLE_SPLITS_KEY,
                            'null'))
    if not example_splits:
      example_splits = ['eval']
      logging.info("The 'example_splits' parameter is not set, using 'eval' "
                   'split.')

    logging.info('Evaluating model.')
    # TempPipInstallContext is needed here so that subprocesses (which
    # may be created by the Beam multi-process DirectRunner) can find the
    # needed dependencies.
    # TODO(b/187122662): Move this to the ExecutorOperator or Launcher.
    with udf_utils.TempPipInstallContext(extra_pip_packages):
      with self._make_beam_pipeline() as pipeline:
        examples_list = []
        tensor_adapter_config = None
        # pylint: disable=expression-not-assigned
        if tfma.is_batched_input(eval_shared_model, eval_config):
          tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
              examples=input_dict[standard_component_specs.EXAMPLES_KEY],
              telemetry_descriptors=_TELEMETRY_DESCRIPTORS,
              schema=schema,
              raw_record_column_name=tfma_constants.ARROW_INPUT_COLUMN)
          # TODO(b/161935932): refactor after TFXIO supports multiple patterns.
          for split in example_splits:
            split_uris = artifact_utils.get_split_uris(
                input_dict[standard_component_specs.EXAMPLES_KEY], split)
            for index, split_uri in enumerate(split_uris):
              file_pattern = io_utils.all_files_pattern(split_uri)
              tfxio = tfxio_factory(file_pattern)
              data = (
                  pipeline
                  | f'ReadFromTFRecordToArrow[{split}][{index}]' >>
                  tfxio.BeamSource())
              examples_list.append(data)
          if schema is not None:
            # Use last tfxio as TensorRepresentations and ArrowSchema are fixed.
            tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
                arrow_schema=tfxio.ArrowSchema(),
                tensor_representations=tfxio.TensorRepresentations())
        else:
          for split in example_splits:
            split_uris = artifact_utils.get_split_uris(
                input_dict[standard_component_specs.EXAMPLES_KEY], split)
            for index, split_uri in enumerate(split_uris):
              file_pattern = io_utils.all_files_pattern(split_uri)
              data = (
                  pipeline
                  | f'ReadFromTFRecord[{split}][{index}]' >>
                  beam.io.ReadFromTFRecord(file_pattern=file_pattern))
              examples_list.append(data)

        custom_extractors = udf_utils.try_get_fn(
            exec_properties=exec_properties, fn_name='custom_extractors')
        extractors = None
        if custom_extractors:
          extractors = custom_extractors(
              eval_shared_model=eval_shared_model,
              eval_config=eval_config,
              tensor_adapter_config=tensor_adapter_config)

        (examples_list | 'FlattenExamples' >> beam.Flatten()
         | 'ExtractEvaluateAndWriteResults' >>
         (tfma.ExtractEvaluateAndWriteResults(
             eval_shared_model=eval_shared_model,
             eval_config=eval_config,
             extractors=extractors,
             output_path=output_uri,
             slice_spec=slice_spec,
             tensor_adapter_config=tensor_adapter_config)))
    logging.info('Evaluation complete. Results written to %s.', output_uri)

    if not run_validation:
      # TODO(jinhuang): delete the BLESSING_KEY from output_dict when supported.
      logging.info('No threshold configured, will not validate model.')
      return
    # Set up blessing artifact
    blessing = artifact_utils.get_single_instance(
        output_dict[standard_component_specs.BLESSING_KEY])
    blessing.set_string_custom_property(
        constants.ARTIFACT_PROPERTY_CURRENT_MODEL_URI_KEY,
        artifact_utils.get_single_uri(
            input_dict[standard_component_specs.MODEL_KEY]))
    blessing.set_int_custom_property(
        constants.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY,
        input_dict[standard_component_specs.MODEL_KEY][0].id)
    if input_dict.get(standard_component_specs.BASELINE_MODEL_KEY):
      baseline_model = input_dict[
          standard_component_specs.BASELINE_MODEL_KEY][0]
      blessing.set_string_custom_property(
          constants.ARTIFACT_PROPERTY_BASELINE_MODEL_URI_KEY,
          baseline_model.uri)
      blessing.set_int_custom_property(
          constants.ARTIFACT_PROPERTY_BASELINE_MODEL_ID_KEY, baseline_model.id)
    if 'current_component_id' in exec_properties:
      blessing.set_string_custom_property(
          'component_id', exec_properties['current_component_id'])
    # Check validation result and write BLESSED file accordingly.
    logging.info('Checking validation results.')
    validation_result = tfma.load_validation_result(output_uri)
    if validation_result.validation_ok:
      io_utils.write_string_file(
          os.path.join(blessing.uri, constants.BLESSED_FILE_NAME), '')
      blessing.set_int_custom_property(constants.ARTIFACT_PROPERTY_BLESSED_KEY,
                                       constants.BLESSED_VALUE)
    else:
      io_utils.write_string_file(
          os.path.join(blessing.uri, constants.NOT_BLESSED_FILE_NAME), '')
      blessing.set_int_custom_property(constants.ARTIFACT_PROPERTY_BLESSED_KEY,
                                       constants.NOT_BLESSED_VALUE)
    logging.info('Blessing result %s written to %s.',
                 validation_result.validation_ok, blessing.uri)
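
The blessing step at the end reduces to: load the TFMA validation result from the evaluation output, write an empty marker file into the blessing artifact, and record the verdict as a custom property. A hedged sketch (the marker file names and property key mirror the constants referenced above but are assumptions here):

import os
import tensorflow_model_analysis as tfma
from tfx.utils import io_utils

def write_blessing_sketch(blessing, output_uri):
    # Marker names and the 'blessed' property key are assumed stand-ins for
    # the constants module used by the executor above.
    result = tfma.load_validation_result(output_uri)
    marker = 'BLESSED' if result.validation_ok else 'NOT_BLESSED'
    io_utils.write_string_file(os.path.join(blessing.uri, marker), '')
    blessing.set_int_custom_property('blessed', int(result.validation_ok))
    return result.validation_ok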