def _validate_stats_output(self, stats_path):
  """Asserts that `stats_path` holds a binary stats proto with real content.

  Args:
    stats_path: Path to a serialized DatasetFeatureStatisticsList file.
  """
  self.assertTrue(fileio.exists(stats_path))
  loaded_stats = tfdv.load_stats_binary(stats_path)
  # Exactly one dataset is expected per split-level stats file.
  self.assertLen(loaded_stats.datasets, 1)
  dataset = loaded_stats.datasets[0]
  self.assertGreater(dataset.num_examples, 0)
  self.assertNotEmpty(dataset.features)
def display(self, artifact: types.Artifact):
  """Renders per-split statistics of `artifact` in an IPython notebook.

  Args:
    artifact: An ExampleStatistics artifact whose splits are visualized.
  """
  from IPython.core.display import HTML  # pylint: disable=g-import-not-at-top
  from IPython.core.display import display  # pylint: disable=g-import-not-at-top
  for split_name in artifact_utils.decode_split_names(artifact.split_names):
    display(HTML('<div><b>%r split:</b></div><br/>' % split_name))
    stats_path = io_utils.get_only_uri_in_dir(
        artifact_utils.get_split_uri([artifact], split_name))
    # Older artifact versions serialized stats in the legacy TFRecord layout;
    # newer ones use a single binary proto file.
    is_legacy_version = artifact_utils.is_artifact_version_older_than(
        artifact, artifact_utils._ARTIFACT_VERSION_FOR_STATS_UPDATE)  # pylint: disable=protected-access
    if is_legacy_version:
      split_stats = tfdv.load_statistics(stats_path)
    else:
      split_stats = tfdv.load_stats_binary(stats_path)
    tfdv.visualize_statistics(split_stats)
def Do(self, input_dict: Dict[str, List[types.Artifact]],
       output_dict: Dict[str, List[types.Artifact]],
       exec_properties: Dict[str, Any]) -> None:
  """TensorFlow ExampleValidator executor entrypoint.

  This validates statistics against the schema.

  Args:
    input_dict: Input dict from input key to a list of artifacts, including:
      - statistics: A list of type `standard_artifacts.ExampleStatistics`
        generated by StatisticsGen.
      - schema: A list of type `standard_artifacts.Schema` which should
        contain a single schema artifact.
    output_dict: Output dict from key to a list of artifacts, including:
      - output: A list of 'standard_artifacts.ExampleAnomalies' of size one.
        It will include a single binary proto file which contains all
        anomalies found.
    exec_properties: A dict of execution properties.
      - exclude_splits: JSON-serialized list of names of splits that the
        example validator should not validate.

  Returns:
    None
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  # Deserialize the optional exclude_splits execution property; absent or
  # null deserializes to an empty list.
  exclude_splits = json_utils.loads(
      exec_properties.get(standard_component_specs.EXCLUDE_SPLITS_KEY,
                          'null')) or []
  if not isinstance(exclude_splits, list):
    raise ValueError('exclude_splits in execution properties needs to be a '
                     'list. Got %s instead.' % type(exclude_splits))

  # Determine the splits to validate and record them on the output artifact.
  stats_artifact = artifact_utils.get_single_instance(
      input_dict[standard_component_specs.STATISTICS_KEY])
  split_names = [
      name
      for name in artifact_utils.decode_split_names(stats_artifact.split_names)
      if name not in exclude_splits
  ]
  anomalies_artifact = artifact_utils.get_single_instance(
      output_dict[standard_component_specs.ANOMALIES_KEY])
  anomalies_artifact.split_names = artifact_utils.encode_split_names(
      split_names)

  # A single schema artifact is shared across all validated splits.
  schema = io_utils.SchemaReader().read(
      io_utils.get_only_uri_in_dir(
          artifact_utils.get_single_uri(
              input_dict[standard_component_specs.SCHEMA_KEY])))

  for split in split_names:
    logging.info(
        'Validating schema against the computed statistics for '
        'split %s.', split)
    stats_uri = io_utils.get_only_uri_in_dir(
        artifact_utils.get_split_uri([stats_artifact], split))
    # Older artifact versions use the legacy stats format; newer ones are a
    # single binary proto.
    if artifact_utils.is_artifact_version_older_than(
        stats_artifact,
        artifact_utils._ARTIFACT_VERSION_FOR_STATS_UPDATE):  # pylint: disable=protected-access
      stats = tfdv.load_statistics(stats_uri)
    else:
      stats = tfdv.load_stats_binary(stats_uri)
    output_uri = artifact_utils.get_split_uri(
        output_dict[standard_component_specs.ANOMALIES_KEY], split)
    self._Validate(
        {
            standard_component_specs.STATISTICS_KEY: stats,
            standard_component_specs.SCHEMA_KEY: schema
        },
        {labels.SCHEMA_DIFF_PATH: output_uri})
    logging.info(
        'Validation complete for split %s. Anomalies written to '
        '%s.', split, output_uri)
def Do(self, input_dict: Dict[str, List[types.Artifact]],
       output_dict: Dict[str, List[types.Artifact]],
       exec_properties: Dict[str, Any]) -> None:
  """TensorFlow SchemaGen executor entrypoint.

  This infers the schema using tensorflow_data_validation on the precomputed
  stats of 'train' split.

  Args:
    input_dict: Input dict from input key to a list of artifacts, including:
      - 'statistics': A list of 'ExampleStatistics' type which must contain
        split 'train'.
    output_dict: Output dict from key to a list of artifacts, including:
      - schema: A list of 'Schema' artifact of size one.
    exec_properties: A dict of execution properties, includes:
      - infer_feature_shape: Whether or not to infer the shape of the feature.
      - exclude_splits: Names of splits that will not be taken into
        consideration when auto-generating a schema.

  Returns:
    None
  """
  infer_feature_shape = bool(
      exec_properties.get(standard_component_specs.INFER_FEATURE_SHAPE_KEY,
                          True))

  # Deserialize the optional exclude_splits execution property; absent or
  # null deserializes to an empty list.
  exclude_splits = json_utils.loads(
      exec_properties.get(standard_component_specs.EXCLUDE_SPLITS_KEY,
                          'null')) or []
  if not isinstance(exclude_splits, list):
    raise ValueError('exclude_splits in execution properties needs to be a '
                     'list. Got %s instead.' % type(exclude_splits))

  # Only one schema is generated for all splits: the first split seeds the
  # schema and every subsequent split updates it.
  schema = None
  stats_artifact = artifact_utils.get_single_instance(
      input_dict[standard_component_specs.STATISTICS_KEY])
  for split in artifact_utils.decode_split_names(stats_artifact.split_names):
    if split in exclude_splits:
      continue
    logging.info('Processing schema from statistics for split %s.', split)
    stats_uri = io_utils.get_only_uri_in_dir(
        artifact_utils.get_split_uri([stats_artifact], split))
    # Older artifact versions use the legacy stats format; newer ones are a
    # single binary proto.
    if artifact_utils.is_artifact_version_older_than(
        stats_artifact,
        artifact_utils._ARTIFACT_VERSION_FOR_STATS_UPDATE):  # pylint: disable=protected-access
      stats = tfdv.load_statistics(stats_uri)
    else:
      stats = tfdv.load_stats_binary(stats_uri)
    # Fix: compare against None explicitly. Protobuf messages are falsy when
    # empty, so `if not schema:` would discard an empty-but-valid schema from
    # an earlier split and re-infer from scratch instead of updating it.
    if schema is None:
      schema = tfdv.infer_schema(stats, infer_feature_shape)
    else:
      schema = tfdv.update_schema(schema, stats, infer_feature_shape)

  output_uri = os.path.join(
      artifact_utils.get_single_uri(
          output_dict[standard_component_specs.SCHEMA_KEY]),
      DEFAULT_FILE_NAME)
  io_utils.write_pbtxt_file(output_uri, schema)
  logging.info('Schema written to %s.', output_uri)