コード例 #1
0
def is_old_model_artifact(model_artifact: artifact.Artifact) -> bool:
  """Returns True if the model artifact was produced by an older TFX version."""
  # Only Model artifacts carry the version marker this check relies on.
  assert model_artifact.type == standard_artifacts.Model, ('Wrong artifact '
                                                           'type, only accept '
                                                           'Model.')
  threshold = artifact_utils._ARTIFACT_VERSION_FOR_MODEL_UPDATE  # pylint: disable=protected-access
  return artifact_utils.is_artifact_version_older_than(model_artifact,
                                                       threshold)
コード例 #2
0
ファイル: path_utils.py プロジェクト: suryaavala/tfx
def is_old_model_artifact(model_artifact: artifact.Artifact) -> bool:
    """Check whether the model artifact is generated by old TFX version.

    Args:
      model_artifact: The artifact to inspect. Expected to be of type
        `standard_artifacts.Model`; a mismatch is logged as a warning but
        does not abort the check.

    Returns:
      True if the artifact's recorded TFX version is older than the version
      that introduced the model directory layout update, False otherwise.
    """
    if model_artifact.type != standard_artifacts.Model:
        # Lazy %-style args: the message is only formatted if the warning
        # is actually emitted (avoids eager f-string construction).
        absl.logging.warning('Artifact type is not Model: %s.',
                             model_artifact.type)
    return artifact_utils.is_artifact_version_older_than(
        model_artifact, artifact_utils._ARTIFACT_VERSION_FOR_MODEL_UPDATE)  # pylint: disable=protected-access
コード例 #3
0
 def display(self, artifact: types.Artifact):
   """Renders a TFDV statistics visualization for every split of `artifact`."""
   from IPython.core.display import display  # pylint: disable=g-import-not-at-top
   from IPython.core.display import HTML  # pylint: disable=g-import-not-at-top
   for split_name in artifact_utils.decode_split_names(artifact.split_names):
     display(HTML('<div><b>%r split:</b></div><br/>' % split_name))
     split_uri = artifact_utils.get_split_uri([artifact], split_name)
     stats_path = io_utils.get_only_uri_in_dir(split_uri)
     # File format of the stored statistics depends on the TFX version that
     # produced the artifact, so pick the matching loader.
     is_legacy = artifact_utils.is_artifact_version_older_than(
         artifact, artifact_utils._ARTIFACT_VERSION_FOR_STATS_UPDATE)  # pylint: disable=protected-access
     if is_legacy:
       stats = tfdv.load_statistics(stats_path)
     else:
       stats = tfdv.load_stats_binary(stats_path)
     tfdv.visualize_statistics(stats)
コード例 #4
0
 def display(self, artifact: types.Artifact):
   """Renders a TFDV anomalies report for every split of `artifact`."""
   from IPython.core.display import display  # pylint: disable=g-import-not-at-top
   from IPython.core.display import HTML  # pylint: disable=g-import-not-at-top
   for split_name in artifact_utils.decode_split_names(artifact.split_names):
     display(HTML('<div><b>%r split:</b></div><br/>' % split_name))
     anomalies_path = io_utils.get_only_uri_in_dir(
         artifact_utils.get_split_uri([artifact], split_name))
     # Older artifacts store anomalies as text; newer ones as a serialized
     # binary proto.
     legacy_format = artifact_utils.is_artifact_version_older_than(
         artifact, artifact_utils._ARTIFACT_VERSION_FOR_ANOMALIES_UPDATE)  # pylint: disable=protected-access
     if legacy_format:
       anomalies = tfdv.load_anomalies_text(anomalies_path)
     else:
       raw_bytes = io_utils.read_bytes_file(anomalies_path)
       anomalies = anomalies_pb2.Anomalies()
       anomalies.ParseFromString(raw_bytes)
     tfdv.display_anomalies(anomalies)
コード例 #5
0
 def testIsArtifactVersionOlderThan(self):
      """Exercises version comparison across state and version transitions."""
      examples = standard_artifacts.Examples()
      # Before the artifact state is LIVE, it is not reported as old.
      self.assertFalse(
          artifact_utils.is_artifact_version_older_than(examples, '0.1'))
      examples.mlmd_artifact.state = metadata_store_pb2.Artifact.LIVE
      # LIVE with no recorded TFX version: reported as older.
      self.assertTrue(
          artifact_utils.is_artifact_version_older_than(examples, '0.1'))
      examples.set_string_custom_property(
          artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY, '0.2')
      # '0.10' counts as newer than '0.2', so the comparison is evidently
      # per-component numeric rather than lexicographic.
      for newer_version in ('0.10', '0.3'):
        self.assertTrue(
            artifact_utils.is_artifact_version_older_than(
                examples, newer_version))
      for same_or_older_version in ('0.2', '0.1'):
        self.assertFalse(
            artifact_utils.is_artifact_version_older_than(
                examples, same_or_older_version))
コード例 #6
0
    def Do(self, input_dict: Dict[str, List[types.Artifact]],
           output_dict: Dict[str, List[types.Artifact]],
           exec_properties: Dict[str, Any]) -> None:
        """TensorFlow ExampleValidator executor entrypoint.

        Validates the computed statistics of every non-excluded split against
        the provided schema and writes the detected anomalies.

        Args:
          input_dict: Input dict from input key to a list of artifacts,
            including:
            - statistics: A list of type `standard_artifacts.ExampleStatistics`
              generated by StatisticsGen.
            - schema: A list of type `standard_artifacts.Schema` which should
              contain a single schema artifact.
          output_dict: Output dict from key to a list of artifacts, including:
            - output: A list of 'standard_artifacts.ExampleAnomalies' of size
              one. It will include a single binary proto file which contains
              all anomalies found.
          exec_properties: A dict of execution properties.
            - exclude_splits: JSON-serialized list of names of splits that the
              example validator should not validate.

        Returns:
          None

        Raises:
          ValueError: If `exclude_splits` deserializes to a non-list value.
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        # Load and deserialize exclude splits from execution properties.
        # A missing key deserializes to None, which `or []` normalizes.
        exclude_splits = json_utils.loads(
            exec_properties.get(standard_component_specs.EXCLUDE_SPLITS_KEY,
                                'null')) or []
        if not isinstance(exclude_splits, list):
            raise ValueError(
                'exclude_splits in execution properties needs to be a '
                'list. Got %s instead.' % type(exclude_splits))
        # Setup output splits: the output artifact records only the splits
        # that are actually validated (excluded splits are filtered out).
        stats_artifact = artifact_utils.get_single_instance(
            input_dict[standard_component_specs.STATISTICS_KEY])
        stats_split_names = artifact_utils.decode_split_names(
            stats_artifact.split_names)
        split_names = [
            split for split in stats_split_names if split not in exclude_splits
        ]
        anomalies_artifact = artifact_utils.get_single_instance(
            output_dict[standard_component_specs.ANOMALIES_KEY])
        anomalies_artifact.split_names = artifact_utils.encode_split_names(
            split_names)

        # One schema is shared across all splits; read it once up front.
        schema = io_utils.SchemaReader().read(
            io_utils.get_only_uri_in_dir(
                artifact_utils.get_single_uri(
                    input_dict[standard_component_specs.SCHEMA_KEY])))

        for split in artifact_utils.decode_split_names(
                stats_artifact.split_names):
            if split in exclude_splits:
                continue

            logging.info(
                'Validating schema against the computed statistics for '
                'split %s.', split)
            stats_uri = io_utils.get_only_uri_in_dir(
                artifact_utils.get_split_uri([stats_artifact], split))
            # The on-disk stats format depends on the TFX version that
            # produced the artifact, so the version check picks the loader.
            if artifact_utils.is_artifact_version_older_than(
                    stats_artifact,
                    artifact_utils._ARTIFACT_VERSION_FOR_STATS_UPDATE):  # pylint: disable=protected-access
                stats = tfdv.load_statistics(stats_uri)
            else:
                stats = tfdv.load_stats_binary(stats_uri)
            # NOTE(review): _Validate appears to consume label-keyed dicts
            # defined by the `labels` module — confirm expected keys there.
            label_inputs = {
                standard_component_specs.STATISTICS_KEY: stats,
                standard_component_specs.SCHEMA_KEY: schema
            }
            output_uri = artifact_utils.get_split_uri(
                output_dict[standard_component_specs.ANOMALIES_KEY], split)
            label_outputs = {labels.SCHEMA_DIFF_PATH: output_uri}
            self._Validate(label_inputs, label_outputs)
            logging.info(
                'Validation complete for split %s. Anomalies written to '
                '%s.', split, output_uri)
コード例 #7
0
    def Do(self, input_dict: Dict[str, List[types.Artifact]],
           output_dict: Dict[str, List[types.Artifact]],
           exec_properties: Dict[str, Any]) -> None:
        """TensorFlow SchemaGen executor entrypoint.

        Infers a single schema with tensorflow_data_validation from the
        precomputed statistics of each non-excluded split and writes it out
        as a text proto.

        Args:
          input_dict: Input dict from input key to a list of artifacts,
            including:
            - 'statistics': A list of 'ExampleStatistics' type which must
              contain split 'train'.
          output_dict: Output dict from key to a list of artifacts, including:
            - schema: A list of 'Schema' artifact of size one.
          exec_properties: A dict of execution properties, includes:
            - infer_feature_shape: Whether or not to infer the shape of the
              feature.
            - exclude_splits: Names of splits that will not be taken into
              consideration when auto-generating a schema.

        Returns:
          None

        Raises:
          ValueError: If `exclude_splits` deserializes to a non-list value.
        """
        infer_feature_shape = bool(
            exec_properties.get(
                standard_component_specs.INFER_FEATURE_SHAPE_KEY, True))

        # Load and deserialize exclude splits from execution properties.
        # A missing key deserializes to None, normalized to an empty list.
        raw_exclude_splits = exec_properties.get(
            standard_component_specs.EXCLUDE_SPLITS_KEY, 'null')
        exclude_splits = json_utils.loads(raw_exclude_splits) or []
        if not isinstance(exclude_splits, list):
            raise ValueError(
                'exclude_splits in execution properties needs to be a '
                'list. Got %s instead.' % type(exclude_splits))

        stats_artifact = artifact_utils.get_single_instance(
            input_dict[standard_component_specs.STATISTICS_KEY])
        all_splits = artifact_utils.decode_split_names(
            stats_artifact.split_names)

        # A single schema is accumulated across every processed split:
        # inferred from the first split, then updated with each later one.
        schema = None
        for split in all_splits:
            if split in exclude_splits:
                continue

            logging.info('Processing schema from statistics for split %s.',
                         split)
            stats_uri = io_utils.get_only_uri_in_dir(
                artifact_utils.get_split_uri([stats_artifact], split))
            # The stats file format depends on the TFX version that produced
            # the artifact; pick the matching loader.
            is_legacy_stats = artifact_utils.is_artifact_version_older_than(
                stats_artifact,
                artifact_utils._ARTIFACT_VERSION_FOR_STATS_UPDATE)  # pylint: disable=protected-access
            if is_legacy_stats:
                stats = tfdv.load_statistics(stats_uri)
            else:
                stats = tfdv.load_stats_binary(stats_uri)
            if not schema:
                schema = tfdv.infer_schema(stats, infer_feature_shape)
            else:
                schema = tfdv.update_schema(schema, stats, infer_feature_shape)

        output_uri = os.path.join(
            artifact_utils.get_single_uri(
                output_dict[standard_component_specs.SCHEMA_KEY]),
            DEFAULT_FILE_NAME)
        io_utils.write_pbtxt_file(output_uri, schema)
        logging.info('Schema written to %s.', output_uri)