def ResolveSplitsConfig(
    splits_config_str: Optional[str],
    examples: List[types.Artifact]) -> transform_pb2.SplitsConfig:
  """Resolve SplitsConfig proto for the transform request."""
  result = transform_pb2.SplitsConfig()
  if splits_config_str:
    proto_utils.json_to_proto(splits_config_str, result)
    if not result.analyze:
      raise ValueError(
          'analyze cannot be empty when splits_config is set.')
    return result

  result.analyze.append('train')

  # All input artifacts should have the same set of split names.
  split_names = set(
      artifact_utils.decode_split_names(examples[0].split_names))
  for artifact in examples:
    artifact_split_names = set(
        artifact_utils.decode_split_names(artifact.split_names))
    if split_names != artifact_split_names:
      raise ValueError(
          'Not all input artifacts have the same split names: (%s, %s)' %
          (split_names, artifact_split_names))

  result.transform.extend(split_names)
  logging.info("Analyze the 'train' split and transform all splits when "
               'splits_config is not set.')
  return result
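# A minimal usage sketch for the resolver above; the artifact setup is
# illustrative. With splits_config unset, 'train' is analyzed and every split
# on the input artifact is transformed.
examples_artifact = standard_artifacts.Examples()
examples_artifact.split_names = artifact_utils.encode_split_names(
    ['train', 'eval'])
config = ResolveSplitsConfig(None, [examples_artifact])
assert list(config.analyze) == ['train']
assert set(config.transform) == {'train', 'eval'}
# With an explicit JSON-serialized SplitsConfig, the string is parsed as-is
# and only validated for a non-empty analyze list.
config = ResolveSplitsConfig(
    '{"analyze": ["train"], "transform": ["train", "eval"]}',
    [examples_artifact])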
def testDoWithOutputExamplesTwoSplitsSampledOneSplitCopied(self):
  self._exec_properties[SPLITS_TO_TRANSFORM_KEY] = json.dumps(
      ['train', 'eval'])
  self._exec_properties[SPLITS_TO_COPY_KEY] = json.dumps(['unlabelled'])

  # Run executor.
  stratified_sampler = executor.Executor(self._context)
  stratified_sampler.Do(self._input_dict, self._output_dict_sr,
                        self._exec_properties)

  # Check outputs.
  self.assertTrue(fileio.exists(self._stratified_examples_dir))
  self.assertIn(
      'train',
      artifact_utils.decode_split_names(self._sampling_result.split_names))
  self.assertIn(
      'eval',
      artifact_utils.decode_split_names(self._sampling_result.split_names))
  self.assertIn(
      'unlabelled',
      artifact_utils.decode_split_names(self._sampling_result.split_names))
  self.assertLen(
      artifact_utils.decode_split_names(self._sampling_result.split_names), 3)
  self._verify_stratified_example_split('train')
  self._verify_copied_example_split('unlabelled')
  self._verify_stratified_example_split('eval')
def _run_model_inference(
    self, data_spec: bulk_inferrer_pb2.DataSpec,
    examples: List[types.Artifact], output_uri: Text,
    inference_endpoint: model_spec_pb2.InferenceSpecType) -> None:
  """Runs model inference on given example data.

  Args:
    data_spec: bulk_inferrer_pb2.DataSpec instance.
    examples: List of example artifacts.
    output_uri: Output artifact uri.
    inference_endpoint: Model inference endpoint.
  """
  example_uris = {}
  if data_spec.example_splits:
    for example in examples:
      for split in artifact_utils.decode_split_names(example.split_names):
        if split in data_spec.example_splits:
          example_uris[split] = os.path.join(example.uri, split)
  else:
    for example in examples:
      for split in artifact_utils.decode_split_names(example.split_names):
        example_uris[split] = os.path.join(example.uri, split)

  output_path = os.path.join(output_uri, _PREDICTION_LOGS_DIR_NAME)
  logging.info('BulkInferrer generates prediction log to %s', output_path)

  with self._make_beam_pipeline() as pipeline:
    data_list = []
    for split, example_uri in example_uris.items():
      data = (
          pipeline
          | 'ReadData[{}]'.format(split) >> beam.io.ReadFromTFRecord(
              file_pattern=io_utils.all_files_pattern(example_uri)))
      data_list.append(data)
    _ = (
        data_list
        | 'FlattenExamples' >> beam.Flatten(pipeline=pipeline)
        # TODO(b/131873699): Use the correct Example type here, which
        # is either Example or SequenceExample.
        | 'ParseExamples' >> beam.Map(tf.train.Example.FromString)
        | 'RunInference' >> run_inference.RunInference(inference_endpoint)
        | 'WritePredictionLogs' >> beam.io.WriteToTFRecord(
            output_path,
            file_name_suffix='.gz',
            coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog)))
  logging.info('Inference result written to %s.', output_path)
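# A minimal sketch of restricting inference to specific splits. DataSpec's
# example_splits field (read above) is a repeated string; leaving it empty
# makes the method fall through to the else-branch and run inference over
# every split of every input artifact.
data_spec = bulk_inferrer_pb2.DataSpec()
data_spec.example_splits.append('unlabelled')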
def testDoWithOutputExamplesAllSplits(self):
  self._exec_properties[SPLITS_TO_TRANSFORM_KEY] = json.dumps(
      ['eval', 'train'])

  # Run executor.
  stratified_sampler = executor.Executor(self._context)
  stratified_sampler.Do(self._input_dict, self._output_dict_sr,
                        self._exec_properties)

  # Check outputs.
  self.assertTrue(fileio.exists(self._filtered_examples_dir))
  self.assertIn(
      'train',
      artifact_utils.decode_split_names(self._filtering_result.split_names))
  self.assertIn(
      'eval',
      artifact_utils.decode_split_names(self._filtering_result.split_names))
  self.assertLen(
      artifact_utils.decode_split_names(self._filtering_result.split_names), 2)
  self._verify_filtered_example_split('train')
  self._verify_filtered_example_split('eval')
def testDoWithOutputExamplesEvalSplit(self):
  self._exec_properties[SPLITS_TO_TRANSFORM_KEY] = json.dumps(['eval'])

  # Run executor.
  stratified_sampler = executor.Executor(self._context)
  stratified_sampler.Do(self._input_dict, self._output_dict_sr,
                        self._exec_properties)

  # Check outputs.
  self.assertTrue(fileio.exists(self._filtered_examples_dir))
  self.assertNotIn(
      'train',
      artifact_utils.decode_split_names(self._filtering_result.split_names))
  self.assertIn(
      'eval',
      artifact_utils.decode_split_names(self._filtering_result.split_names))
  self.assertLen(
      artifact_utils.decode_split_names(self._filtering_result.split_names), 1)
  self._verify_filtered_example_split('eval')
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  self._log_startup(input_dict, output_dict, exec_properties)

  logging.info('Validating schema against the computed statistics.')
  # Collect the split names present on the statistics artifacts; only the
  # first split found is validated below.
  split_names: List[Text] = []
  for artifact in input_dict[executor.STATISTICS_KEY]:
    for split in artifact_utils.decode_split_names(artifact.split_names):
      split_names.append(split)

  label_inputs = {
      labels.STATS:
          tfdv.load_statistics(
              io_utils.get_only_uri_in_dir(
                  artifact_utils.get_split_uri(
                      input_dict[executor.STATISTICS_KEY], split_names[0]))),
      labels.SCHEMA:
          io_utils.SchemaReader().read(
              io_utils.get_only_uri_in_dir(
                  artifact_utils.get_single_uri(
                      input_dict[executor.SCHEMA_KEY])))
  }
  output_uri = artifact_utils.get_single_uri(
      output_dict[executor.ANOMALIES_KEY])
  label_outputs = {labels.SCHEMA_DIFF_PATH: output_uri}
  self._Validate(label_inputs, label_outputs)
  logging.info(
      'Validation complete. Anomalies written to {}.'.format(output_uri))
def SynthesizeGraph(identified_examples: InputArtifact[Examples],
                    synthesized_graph: OutputArtifact[SynthesizedGraph],
                    similarity_threshold: Parameter[float],
                    component_name: Parameter[str]) -> None:
  # Get a list of the splits in input_data.
  splits_list = artifact_utils.decode_split_names(
      split_names=identified_examples.split_names)

  # We build a graph only based on the 'train' split which includes both
  # labeled and unlabeled examples.
  train_input_examples_uri = os.path.join(identified_examples.uri, 'train')
  output_graph_uri = os.path.join(synthesized_graph.uri, 'train')
  os.mkdir(output_graph_uri)

  print('Creating embeddings...')
  create_embeddings(train_input_examples_uri, output_graph_uri)

  print('Synthesizing graph...')
  build_graph(output_graph_uri, similarity_threshold)

  synthesized_graph.split_names = artifact_utils.encode_split_names(
      splits=['train'])
def IdentifyExamples(orig_examples: InputArtifact[Examples],
                     identified_examples: OutputArtifact[Examples],
                     id_feature_name: Parameter[str],
                     component_name: Parameter[str]) -> None:
  # Get a list of the splits in input_data.
  splits_list = artifact_utils.decode_split_names(
      split_names=orig_examples.split_names)
  for split in splits_list:
    input_dir = os.path.join(orig_examples.uri, split)
    output_dir = os.path.join(identified_examples.uri, split)
    os.mkdir(output_dir)
    with beam.Pipeline() as pipeline:
      (pipeline
       | 'ReadExamples' >> beam.io.ReadFromTFRecord(
           os.path.join(input_dir, '*'),
           coder=beam.coders.coders.ProtoCoder(tf.train.Example))
       | 'AddUniqueId' >> beam.Map(make_example_with_unique_id,
                                   id_feature_name)
       | 'WriteIdentifiedExamples' >> beam.io.WriteToTFRecord(
           file_path_prefix=os.path.join(output_dir, 'data_tfrecord'),
           coder=beam.coders.coders.ProtoCoder(tf.train.Example),
           file_name_suffix='.gz'))

  identified_examples.split_names = artifact_utils.encode_split_names(
      splits=splits_list)
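# A short sketch of the split-name round trip used by the two components
# above: split_names is stored on the artifact as a JSON-encoded list of
# strings, and encode/decode are inverses.
encoded = artifact_utils.encode_split_names(['train', 'eval'])
assert encoded == '["train", "eval"]'
assert artifact_utils.decode_split_names(encoded) == ['train', 'eval']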
def _prepare_output_paths(artifact: types.Artifact):
  """Create output directories for output artifact."""
  if tf.io.gfile.exists(artifact.uri):
    msg = 'Output artifact uri %s already exists' % artifact.uri
    absl.logging.warning(msg)
    # TODO(b/158689199): We currently simply return as a short-term workaround
    # to unblock execution retries. A comprehensive solution to guarantee
    # idempotent executions is needed.
    return

  # TODO(b/147242148): Introduce principled artifact structure (directory
  # or file) definition.
  if isinstance(artifact, types.ValueArtifact):
    artifact_dir = os.path.dirname(artifact.uri)
  else:
    artifact_dir = artifact.uri

  # TODO(zhitaoli): Consider refactoring this out into something
  # which can handle permission bits.
  absl.logging.debug('Creating output artifact uri %s as directory',
                     artifact_dir)
  tf.io.gfile.makedirs(artifact_dir)
  # TODO(b/147242148): Avoid special-casing the "split_names" property.
  if artifact.type.PROPERTIES and 'split_names' in artifact.type.PROPERTIES:
    split_names = artifact_utils.decode_split_names(artifact.split_names)
    for split in split_names:
      split_dir = os.path.join(artifact.uri, split)
      absl.logging.debug('Creating output split %s as directory', split_dir)
      tf.io.gfile.makedirs(split_dir)
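# Hedged sketch of the resulting layout for a split-aware artifact; the uri
# below is hypothetical.
examples = standard_artifacts.Examples()
examples.uri = '/tmp/pipeline/example_gen/examples/1'
examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
_prepare_output_paths(examples)
# Creates /tmp/pipeline/example_gen/examples/1/train and .../eval, since
# Examples declares 'split_names' in its type PROPERTIES.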
def _callImporterDriver(self, reimport: bool):
  with metadata.Metadata(connection_config=self.connection_config) as m:
    m.publish_artifacts(self.existing_artifacts)
    driver = importer_node.ImporterDriver(metadata_handler=m)
    execution_result = driver.pre_execution(
        component_info=self.component_info,
        pipeline_info=self.pipeline_info,
        driver_args=self.driver_args,
        input_dict={},
        output_dict=self.output_dict,
        exec_properties={
            importer_node.SOURCE_URI_KEY: self.source_uri,
            importer_node.REIMPORT_OPTION_KEY: reimport,
            importer_node.SPLIT_KEY: self.split,
        })
    self.assertFalse(execution_result.use_cached_results)
    self.assertEmpty(execution_result.input_dict)
    self.assertEqual(
        execution_result.output_dict[importer_node.IMPORT_RESULT_KEY][0].uri,
        self.source_uri[0])
    self.assertEqual(
        execution_result.output_dict[importer_node.IMPORT_RESULT_KEY][0].id,
        3 if reimport else 1)

    self.assertNotEmpty(
        self.output_dict[importer_node.IMPORT_RESULT_KEY].get())

    results = self.output_dict[importer_node.IMPORT_RESULT_KEY].get()
    for res, uri, split in zip(results, self.source_uri, self.split):
      self.assertEqual(res.uri, uri)
      self.assertEqual(
          artifact_utils.decode_split_names(res.split_names)[0], split)
def ReadExamplesArtifact(self,
                         examples: types.Artifact,
                         num_examples: int,
                         split_name: Optional[Text] = None):
  """Read records from Examples artifact.

  Currently it assumes Examples artifact contains serialized tf.Example in
  gzipped TFRecord files.

  Args:
    examples: `Examples` artifact.
    num_examples: Number of examples to read. If the specified value is
      larger than the actual number of examples, all examples would be read.
    split_name: Name of the split to read from the Examples artifact.

  Raises:
    RuntimeError: If read twice.
  """
  if self._records:
    raise RuntimeError('Cannot read records twice.')

  if num_examples < 1:
    raise ValueError('num_examples < 1 (got {})'.format(num_examples))

  available_splits = artifact_utils.decode_split_names(examples.split_names)
  if not available_splits:
    raise ValueError(
        'No split_name is available in given Examples artifact.')
  if split_name is None:
    split_name = available_splits[0]
  if split_name not in available_splits:
    raise ValueError(
        'No split_name {}; available split names: {}'.format(
            split_name, ', '.join(available_splits)))

  # ExampleGen generates artifacts under each split_name directory.
  glob_pattern = os.path.join(examples.uri, split_name, '*')
  tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
      examples=[examples],
      telemetry_descriptors=_TELEMETRY_DESCRIPTORS,
      schema=None,
      read_as_raw_records=True,
      raw_record_column_name=_RAW_RECORDS_COLUMN)
  try:
    filenames = fileio.glob(glob_pattern)
  except tf.errors.NotFoundError:
    filenames = []
  if not filenames:
    raise ValueError(
        'Unable to find examples matching {}.'.format(glob_pattern))

  self._payload_format = examples_utils.get_payload_format(examples)
  tfxio = tfxio_factory(filenames)

  self._ReadFromDataset(
      tfxio.TensorFlowDataset(
          dataset_options.TensorFlowDatasetOptions(batch_size=num_examples)))
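# Hedged usage sketch; the class owning ReadExamplesArtifact is not shown in
# this section, so `reader` and `examples_artifact` are hypothetical. Reads
# up to 10 records from the 'eval' split; omitting split_name would fall
# back to the first available split.
reader.ReadExamplesArtifact(examples_artifact, num_examples=10,
                            split_name='eval')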
def display(self, artifact: types.Artifact):
  from IPython.core.display import display  # pylint: disable=g-import-not-at-top
  from IPython.core.display import HTML  # pylint: disable=g-import-not-at-top
  for split in artifact_utils.decode_split_names(artifact.split_names):
    display(HTML('<div><b>%r split:</b></div><br/>' % split))
    stats_path = os.path.join(artifact.uri, split, 'stats_tfrecord')
    stats = tfdv.load_statistics(stats_path)
    tfdv.visualize_statistics(stats)
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  """TensorFlow SchemaGen executor entrypoint.

  This infers the schema using tensorflow_data_validation on the precomputed
  statistics of each split not listed in exclude_splits.

  Args:
    input_dict: Input dict from input key to a list of artifacts, including:
      - 'stats': A list of 'ExampleStatistics' type. Statistics of all splits
        not in exclude_splits are used.
      - 'statistics': Synonym for 'stats'.
    output_dict: Output dict from key to a list of artifacts, including:
      - output: A list of 'Schema' artifact of size one.
    exec_properties: A dict of execution properties, includes:
      - infer_feature_shape: Whether or not to infer the shape of the
        feature.
      - exclude_splits: Names of splits that will not be taken into
        consideration when auto-generating a schema.

  Returns:
    None
  """
  # TODO(zhitaoli): Move constants between this file and component.py to a
  # constants.py.
  infer_feature_shape = exec_properties.get(INFER_FEATURE_SHAPE_KEY)

  # Load and deserialize exclude splits from execution properties.
  exclude_splits = json_utils.loads(
      exec_properties.get(EXCLUDE_SPLITS_KEY, 'null')) or []
  if not isinstance(exclude_splits, list):
    raise ValueError('exclude_splits in execution properties needs to be a '
                     'list. Got %s instead.' % type(exclude_splits))

  # Only one schema is generated for all splits.
  schema = None
  stats_artifact = artifact_utils.get_single_instance(
      input_dict[STATISTICS_KEY])
  for split in artifact_utils.decode_split_names(stats_artifact.split_names):
    if split in exclude_splits:
      continue
    logging.info('Processing schema from statistics for split %s.', split)
    stats_uri = io_utils.get_only_uri_in_dir(
        os.path.join(stats_artifact.uri, split))
    if not schema:
      schema = tfdv.infer_schema(
          tfdv.load_statistics(stats_uri), infer_feature_shape)
    else:
      schema = tfdv.update_schema(schema, tfdv.load_statistics(stats_uri),
                                  infer_feature_shape)

  output_uri = os.path.join(
      artifact_utils.get_single_uri(output_dict[SCHEMA_KEY]),
      _DEFAULT_FILE_NAME)
  io_utils.write_pbtxt_file(output_uri, schema)
  logging.info('Schema written to %s.', output_uri)
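# Hedged sketch of the execution properties the executor above consumes; the
# key constants are those referenced in its body, and the values here are
# illustrative.
exec_properties = {
    INFER_FEATURE_SHAPE_KEY: True,
    # exclude_splits arrives JSON-serialized, matching json_utils.loads above.
    EXCLUDE_SPLITS_KEY: json_utils.dumps(['eval']),
}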
def testConstruct(self):
  big_query_example_gen = component.BigQueryExampleGen(query='query')
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   big_query_example_gen.outputs['examples'].type_name)
  artifact_collection = big_query_example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def display(self, artifact: types.Artifact):
  from IPython.core.display import display  # pylint: disable=g-import-not-at-top
  from IPython.core.display import HTML  # pylint: disable=g-import-not-at-top
  for split in artifact_utils.decode_split_names(artifact.split_names):
    display(HTML('<div><b>%r split:</b></div><br/>' % split))
    anomalies_path = os.path.join(artifact.uri, split, 'anomalies.pbtxt')
    anomalies = tfdv.load_anomalies_text(anomalies_path)
    tfdv.display_anomalies(anomalies)
def testConstruct(self):
  import_example_gen = component.ImportExampleGen(input_base='path')
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   import_example_gen.outputs['examples'].type_name)
  artifact_collection = import_example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def __init__(self,
             statistics: types.Channel = None,
             schema: types.Channel = None,
             exclude_splits: Optional[List[Text]] = None,
             output: Optional[types.Channel] = None,
             stats: Optional[types.Channel] = None,
             instance_name: Optional[Text] = None):
  """Construct an ExampleValidator component.

  Args:
    statistics: A Channel of type `standard_artifacts.ExampleStatistics`.
      This should contain at least 'eval' split. Other splits are currently
      ignored.
    schema: A Channel of type `standard_artifacts.Schema`. _required_
    exclude_splits: Names of splits that the example validator should not
      validate. Default behavior (when exclude_splits is set to None) is
      excluding no splits.
    output: Output channel of type `standard_artifacts.ExampleAnomalies`.
    stats: Backwards compatibility alias for the 'statistics' argument.
    instance_name: Optional name assigned to this specific instance of
      ExampleValidator. Required only if multiple ExampleValidator
      components are declared in the same pipeline.

  Either `stats` or `statistics` must be present in the arguments.
  """
  if stats:
    logging.warning(
        'The "stats" argument to the ExampleValidator component has '
        'been renamed to "statistics" and is deprecated. Please update your '
        'usage as support for this argument will be removed soon.')
    statistics = stats
  if exclude_splits is None:
    exclude_splits = []
    logging.info('Excluding no splits because exclude_splits is not set.')
  anomalies = output
  if not anomalies:
    anomalies_artifact = standard_artifacts.ExampleAnomalies()
    statistics_split_names = artifact_utils.decode_split_names(
        artifact_utils.get_single_instance(list(
            statistics.get())).split_names)
    split_names = [
        split for split in statistics_split_names
        if split not in exclude_splits
    ]
    anomalies_artifact.split_names = artifact_utils.encode_split_names(
        split_names)
    anomalies = types.Channel(
        type=standard_artifacts.ExampleAnomalies,
        artifacts=[anomalies_artifact])
  spec = ExampleValidatorSpec(
      statistics=statistics,
      schema=schema,
      exclude_splits=json_utils.dumps(exclude_splits),
      anomalies=anomalies)
  super(ExampleValidator, self).__init__(
      spec=spec, instance_name=instance_name)
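# Hedged usage sketch: typical pipeline wiring for the component above,
# assuming statistics_gen and schema_gen are the usual upstream TFX
# components produced elsewhere in the pipeline.
example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'],
    exclude_splits=['train'])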
def testConstruct(self):
  input_base = standard_artifacts.ExternalArtifact()
  csv_example_gen = component.CsvExampleGen(
      input=channel_utils.as_channel([input_base]))
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   csv_example_gen.outputs['examples'].type_name)
  artifact_collection = csv_example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def testConstructSubclassFileBased(self):
  example_gen = TestFileBasedExampleGenComponent(input_base='path')
  self.assertIn('input_base', example_gen.exec_properties)
  self.assertEqual(driver.Driver, example_gen.driver_class)
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   example_gen.outputs['examples'].type_name)
  self.assertIsNone(example_gen.exec_properties.get('custom_config'))
  artifact_collection = example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def testConstruct(self):
  big_query_elwc_example_gen = component.BigQueryElwcExampleGen(
      elwc_config=example_gen_pb2.ElwcConfig(
          context_feature_fields=['query_id', 'query_content']),
      query='query')
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   big_query_elwc_example_gen.outputs['examples'].type_name)
  artifact_collection = big_query_elwc_example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def ReadExamplesArtifact(self,
                         examples: types.Artifact,
                         num_examples: int,
                         split_name: Optional[Text] = None):
  """Read records from Examples artifact.

  Currently it assumes Examples artifact contains serialized tf.Example in
  gzipped TFRecord files.

  Args:
    examples: `Examples` artifact.
    num_examples: Number of examples to read. If the specified value is
      larger than the actual number of examples, all examples would be read.
    split_name: Name of the split to read from the Examples artifact.

  Raises:
    RuntimeError: If read twice.
  """
  if self._records:
    raise RuntimeError('Cannot read records twice.')

  if num_examples < 1:
    raise ValueError('num_examples < 1 (got {})'.format(num_examples))

  available_splits = artifact_utils.decode_split_names(examples.split_names)
  if not available_splits:
    raise ValueError(
        'No split_name is available in given Examples artifact.')
  if split_name is None:
    split_name = available_splits[0]
  if split_name not in available_splits:
    raise ValueError(
        'No split_name {}; available split names: {}'.format(
            split_name, ', '.join(available_splits)))

  # ExampleGen generates artifacts under each split_name directory.
  glob_pattern = os.path.join(examples.uri, split_name, '*.gz')
  try:
    filenames = tf.io.gfile.glob(glob_pattern)
  except tf.errors.NotFoundError:
    filenames = []
  if not filenames:
    raise ValueError(
        'Unable to find examples matching {}.'.format(glob_pattern))

  # Assume we have a tf.Example logical format.
  self._record_format = _LogicalFormat.TF_EXAMPLE
  self._ReadFromDataset(
      tf.data.TFRecordDataset(filenames, compression_type='GZIP'),
      num_examples=num_examples)
def testConstructCustomExecutor(self):
  example_gen = component.FileBasedExampleGen(
      input_base='path',
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          TestExampleGenExecutor))
  self.assertEqual(driver.Driver, example_gen.driver_class)
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   example_gen.outputs['examples'].type_name)
  artifact_collection = example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def display(self, artifact: types.Artifact):
  from IPython.core.display import display  # pylint: disable=g-import-not-at-top
  from IPython.core.display import HTML  # pylint: disable=g-import-not-at-top
  for split in artifact_utils.decode_split_names(artifact.split_names):
    display(HTML('<div><b>%r split:</b></div><br/>' % split))
    stats_path = io_utils.get_only_uri_in_dir(
        artifact_utils.get_split_uri([artifact], split))
    if artifact_utils.is_artifact_version_older_than(
        artifact, artifact_utils._ARTIFACT_VERSION_FOR_STATS_UPDATE):  # pylint: disable=protected-access
      stats = tfdv.load_statistics(stats_path)
    else:
      stats = tfdv.load_stats_binary(stats_path)
    tfdv.visualize_statistics(stats)
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  source = exec_properties[StepKeys.SOURCE]
  args = exec_properties[StepKeys.ARGS]

  c = source_utils.load_source_path_class(source)
  tokenizer_step: BaseTokenizer = c(**args)

  tokenizer_location = artifact_utils.get_single_uri(
      output_dict["tokenizer"])

  split_uris, split_names, all_files = [], [], []
  for artifact in input_dict["examples"]:
    for split in artifact_utils.decode_split_names(artifact.split_names):
      split_names.append(split)
      uri = os.path.join(artifact.uri, split)
      split_uris.append((split, uri))
      all_files += path_utils.list_dir(uri)

  # Get output split path.
  output_examples = artifact_utils.get_single_instance(
      output_dict["output_examples"])
  output_examples.split_names = artifact_utils.encode_split_names(
      split_names)

  if not tokenizer_step.skip_training:
    tokenizer_step.train(files=all_files)
    tokenizer_step.save(output_dir=tokenizer_location)

  with self._make_beam_pipeline() as p:
    for split, uri in split_uris:
      input_uri = io_utils.all_files_pattern(uri)

      _ = (p
           | 'ReadData.' + split >> beam.io.ReadFromTFRecord(
               file_pattern=input_uri)
           | 'ParseTFExFromString.' + split >> beam.Map(
               tf.train.Example.FromString)
           | 'AddTokens.' + split >> beam.Map(
               append_tf_example, tokenizer_step=tokenizer_step)
           | 'Serialize.' + split >> beam.Map(
               lambda x: x.SerializeToString())
           | 'WriteSplit.' + split >> WriteSplit(
               get_split_uri(output_dict["output_examples"], split)))
def testConstruct(self):
  presto_example_gen = component.PrestoExampleGen(
      self.conn_config, query='query')
  self.assertEqual(
      self.conn_config,
      self._extract_conn_config(
          presto_example_gen.exec_properties['custom_config']))
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   presto_example_gen.outputs['examples'].type_name)
  artifact_collection = presto_example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def testConstructWithInputConfig(self):
  big_query_example_gen = component.BigQueryExampleGen(
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='train', pattern='query1'),
          example_gen_pb2.Input.Split(name='eval', pattern='query2'),
          example_gen_pb2.Input.Split(name='test', pattern='query3')
      ]))
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   big_query_example_gen.outputs['examples'].type_name)
  artifact_collection = big_query_example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval', 'test'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  """Computes stats for each split of input using tensorflow_data_validation.

  Args:
    input_dict: Input dict from input key to a list of Artifacts.
      - input_data: A list of type `standard_artifacts.Examples`. This
        should contain both 'train' and 'eval' split.
    output_dict: Output dict from output key to a list of Artifacts.
      - output: A list of type `standard_artifacts.ExampleStatistics`. This
        should contain both the 'train' and 'eval' splits.
    exec_properties: A dict of execution properties. Not used yet.

  Returns:
    None
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  split_uris = []
  for artifact in input_dict['input_data']:
    for split in artifact_utils.decode_split_names(artifact.split_names):
      uri = os.path.join(artifact.uri, split)
      split_uris.append((split, uri))
  with self._make_beam_pipeline() as p:
    # TODO(b/126263006): Support more stats_options through config.
    stats_options = options.StatsOptions()
    for split, uri in split_uris:
      absl.logging.info('Generating statistics for split {}'.format(split))
      input_uri = io_utils.all_files_pattern(uri)
      output_uri = artifact_utils.get_split_uri(output_dict['output'], split)
      output_path = os.path.join(output_uri, _DEFAULT_FILE_NAME)
      _ = (p
           | 'ReadData.' + split >> beam.io.ReadFromTFRecord(
               file_pattern=input_uri)
           | 'DecodeData.' + split >> tf_example_decoder.DecodeTFExample()
           | 'GenerateStatistics.' + split >> stats_api.GenerateStatistics(
               stats_options)
           | 'WriteStatsOutput.' + split >> beam.io.WriteToTFRecord(
               output_path,
               shard_name_template='',
               coder=beam.coders.ProtoCoder(
                   statistics_pb2.DatasetFeatureStatisticsList)))
      absl.logging.info('Statistics for split {} written to {}.'.format(
          split, output_uri))
def testConstructSubclassQueryBased(self):
  example_gen = TestQueryBasedExampleGenComponent(
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='single', pattern='query'),
      ]))
  self.assertEqual({}, example_gen.inputs.get_all())
  self.assertEqual(base_driver.BaseDriver, example_gen.driver_class)
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   example_gen.outputs['examples'].type_name)
  self.assertIsNone(example_gen.exec_properties.get('custom_config'))
  artifact_collection = example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def display(self, artifact: types.Artifact):
  from IPython.core.display import display  # pylint: disable=g-import-not-at-top
  from IPython.core.display import HTML  # pylint: disable=g-import-not-at-top
  for split in artifact_utils.decode_split_names(artifact.split_names):
    display(HTML('<div><b>%r split:</b></div><br/>' % split))
    anomalies_path = io_utils.get_only_uri_in_dir(
        artifact_utils.get_split_uri([artifact], split))
    if artifact_utils.is_artifact_version_older_than(
        artifact, artifact_utils._ARTIFACT_VERSION_FOR_ANOMALIES_UPDATE):  # pylint: disable=protected-access
      anomalies = tfdv.load_anomalies_text(anomalies_path)
    else:
      anomalies = anomalies_pb2.Anomalies()
      anomalies_bytes = io_utils.read_bytes_file(anomalies_path)
      anomalies.ParseFromString(anomalies_bytes)
    tfdv.display_anomalies(anomalies)
def testConstructWithInputConfig(self):
  example_gen = TestFileBasedExampleGenComponent(
      input_base='path',
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='train', pattern='train/*'),
          example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
          example_gen_pb2.Input.Split(name='test', pattern='test/*')
      ]))
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   example_gen.outputs['examples'].type_name)
  artifact_collection = example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval', 'test'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))