def testDo(self):
    """Run the executor on per-split artifacts and validate the stats files."""
    testdata_root = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    outputs_base = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
    out_dir = os.path.join(outputs_base, self._testMethodName)
    tf.io.gfile.makedirs(out_dir)

    # Input examples: one artifact per split, read from the test data tree.
    examples_train = standard_artifacts.Examples(split='train')
    examples_train.uri = os.path.join(testdata_root, 'csv_example_gen/train/')
    examples_eval = standard_artifacts.Examples(split='eval')
    examples_eval.uri = os.path.join(testdata_root, 'csv_example_gen/eval/')

    # Output statistics: one artifact per split, written under out_dir.
    stats_train = standard_artifacts.ExampleStatistics(split='train')
    stats_train.uri = os.path.join(out_dir, 'train', '')
    stats_eval = standard_artifacts.ExampleStatistics(split='eval')
    stats_eval.uri = os.path.join(out_dir, 'eval', '')

    inputs = {'input_data': [examples_train, examples_eval]}
    outputs = {'output': [stats_train, stats_eval]}

    # Run executor.
    stats_executor = executor.Executor()
    stats_executor.Do(inputs, outputs, exec_properties={})

    # Check statistics_gen outputs, train first and then eval.
    for stats_artifact in (stats_train, stats_eval):
        self._validate_stats_output(
            os.path.join(stats_artifact.uri, 'stats_tfrecord'))
def testConstruct(self):
    """Construct SchemaGen with shape inference enabled and inspect it."""
    train_stats_channel = channel_utils.as_channel(
        [standard_artifacts.ExampleStatistics(split='train')])
    schema_gen = component.SchemaGen(
        stats=train_stats_channel, infer_feature_shape=True)

    output_type = schema_gen.outputs.output.type_name
    self.assertEqual('SchemaPath', output_type)

    properties = schema_gen.spec.exec_properties
    self.assertTrue(properties['infer_feature_shape'])
def testConstruct(self):
    """Construct SchemaGen with defaults and check type and shape flag."""
    train_stats_channel = channel_utils.as_channel(
        [standard_artifacts.ExampleStatistics(split='train')])
    schema_gen = component.SchemaGen(statistics=train_stats_channel)

    schema_type = schema_gen.outputs['schema'].type_name
    self.assertEqual(standard_artifacts.Schema.TYPE_NAME, schema_type)

    # 'infer_feature_shape' was not passed, and the test expects it falsy.
    properties = schema_gen.spec.exec_properties
    self.assertFalse(properties['infer_feature_shape'])
def testDo(self):
    """Run the executor on a train/eval Examples artifact and check outputs."""
    testdata_root = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    outputs_base = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
    out_dir = os.path.join(outputs_base, self._testMethodName)
    tf.io.gfile.makedirs(out_dir)

    # Single input artifact carrying both splits.
    examples_in = standard_artifacts.Examples()
    examples_in.uri = os.path.join(testdata_root, 'csv_example_gen')
    examples_in.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])

    # Single output artifact rooted at the per-test output directory.
    stats_out = standard_artifacts.ExampleStatistics()
    stats_out.uri = out_dir
    stats_out.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])

    inputs = {executor.EXAMPLES_KEY: [examples_in]}
    outputs = {executor.STATISTICS_KEY: [stats_out]}

    # Run executor.
    runner = executor.Executor()
    runner.Do(inputs, outputs, exec_properties={})

    # Check statistics_gen outputs.
    train_stats_file = os.path.join(stats_out.uri, 'train', 'stats_tfrecord')
    self._validate_stats_output(train_stats_file)
    eval_stats_file = os.path.join(stats_out.uri, 'eval', 'stats_tfrecord')
    self._validate_stats_output(eval_stats_file)
def testDo(self):
    """Run the executor on train statistics and expect a non-empty output dir."""
    testdata_root = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')

    # Input statistics artifact with only a 'train' split.
    stats_in = standard_artifacts.ExampleStatistics()
    stats_in.uri = os.path.join(testdata_root, 'statistics_gen')
    stats_in.split_names = artifact_utils.encode_split_names(['train'])

    outputs_base = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
    out_dir = os.path.join(outputs_base, self._testMethodName)

    schema_out = standard_artifacts.Schema()
    schema_out.uri = os.path.join(out_dir, 'schema_output')

    inputs = {'stats': [stats_in]}
    outputs = {'output': [schema_out]}
    properties = {'infer_feature_shape': False}

    runner = executor.Executor()
    runner.Do(inputs, outputs, properties)

    # The executor must have written at least one entry under the schema uri.
    written_entries = tf.io.gfile.listdir(schema_out.uri)
    self.assertNotEqual(0, len(written_entries))
def __init__(self,
             input_data: types.Channel = None,
             output: Optional[types.Channel] = None,
             examples: Optional[types.Channel] = None,
             name: Optional[Text] = None):
    """Construct a StatisticsGen component.

    Args:
      input_data: A Channel of 'ExamplesPath' type. This should contain two
        splits, 'train' and 'eval' (required).
      output: Optional 'ExampleStatisticsPath' channel for statistics of each
        split provided in input examples. When falsy, a channel holding one
        ExampleStatistics artifact per default example split is created.
      examples: Forwards compatibility alias for the 'input_data' argument;
        used only when 'input_data' is falsy.
      name: Optional unique name. Necessary iff multiple StatisticsGen
        components are declared in the same pipeline.
    """
    # Fall back to the alias only when the primary argument is not given.
    if not input_data:
        input_data = examples

    if not output:
        default_artifacts = [
            standard_artifacts.ExampleStatistics(split=split_name)
            for split_name in artifact.DEFAULT_EXAMPLE_SPLITS
        ]
        output = types.Channel(
            type=standard_artifacts.ExampleStatistics,
            artifacts=default_artifacts)

    component_spec = StatisticsGenSpec(input_data=input_data, output=output)
    super(StatisticsGen, self).__init__(spec=component_spec, name=name)
def __init__(self,
             examples: types.Channel = None,
             output: Optional[types.Channel] = None,
             input_data: Optional[types.Channel] = None,
             instance_name: Optional[Text] = None):
    """Construct a StatisticsGen component.

    Args:
      examples: A Channel of `ExamplesPath` type, likely generated by the
        [ExampleGen component](https://www.tensorflow.org/tfx/guide/examplegen).
        This needs to contain two splits labeled `train` and `eval`.
        _required_
      output: `ExampleStatisticsPath` channel for statistics of each split
        provided in the input examples. When falsy, a channel with a single
        ExampleStatistics artifact covering the default example splits is
        created.
      input_data: Backwards compatibility alias for the `examples` argument;
        used only when `examples` is falsy.
      instance_name: Optional name assigned to this specific instance of
        StatisticsGen. Required only if multiple StatisticsGen components are
        declared in the same pipeline.
    """
    # Fall back to the legacy alias only when `examples` is not given.
    if not examples:
        examples = input_data

    if not output:
        default_stats = standard_artifacts.ExampleStatistics()
        default_stats.split_names = artifact_utils.encode_split_names(
            artifact.DEFAULT_EXAMPLE_SPLITS)
        output = types.Channel(
            type=standard_artifacts.ExampleStatistics,
            artifacts=[default_stats])

    component_spec = StatisticsGenSpec(input_data=examples, output=output)
    super(StatisticsGen, self).__init__(
        spec=component_spec, instance_name=instance_name)
def setUp(self):
    """Prepare the artifacts, dicts and properties shared by the tests."""
    super(ExecutorTest, self).setUp()

    testdata_root = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    self.source_data_dir = testdata_root

    # Input: train-split statistics from the test data tree.
    train_stats = standard_artifacts.ExampleStatistics(split='train')
    train_stats.uri = os.path.join(testdata_root, 'statistics_gen/train/')
    self.train_stats_artifact = train_stats

    # Output location is unique per test method.
    outputs_base = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
    self.output_data_dir = os.path.join(outputs_base, self._testMethodName)

    schema_out = standard_artifacts.Schema()
    schema_out.uri = os.path.join(self.output_data_dir, 'schema_output')
    self.schema_output = schema_out

    # Schema artifact pointing at the 'fixed_schema/' test data.
    fixed_schema = standard_artifacts.Schema()
    fixed_schema.uri = os.path.join(testdata_root, 'fixed_schema/')
    self.schema = fixed_schema

    # Schema artifact pointing at the 'schema_gen/' test data.
    expected = standard_artifacts.Schema()
    expected.uri = os.path.join(testdata_root, 'schema_gen/')
    self.expected_schema = expected

    self.input_dict = {
        'stats': [self.train_stats_artifact],
        'schema': None
    }
    self.output_dict = {
        'output': [self.schema_output],
    }
    self.exec_properties = {'infer_feature_shape': False}
def testDo(self):
    """Run the executor with 'test' excluded and expect a non-empty output."""
    testdata_root = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')

    stats_in = standard_artifacts.ExampleStatistics()
    stats_in.uri = os.path.join(testdata_root, 'statistics_gen')
    stats_in.split_names = artifact_utils.encode_split_names(
        ['train', 'eval', 'test'])

    outputs_base = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
    out_dir = os.path.join(outputs_base, self._testMethodName)

    schema_out = standard_artifacts.Schema()
    schema_out.uri = os.path.join(out_dir, 'schema_output')

    inputs = {standard_component_specs.STATISTICS_KEY: [stats_in]}
    # List needs to be serialized before being passed into Do function.
    serialized_excludes = json_utils.dumps(['test'])
    properties = {
        standard_component_specs.EXCLUDE_SPLITS_KEY: serialized_excludes
    }
    outputs = {standard_component_specs.SCHEMA_KEY: [schema_out]}

    runner = executor.Executor()
    runner.Do(inputs, outputs, properties)

    # The executor must have written at least one entry under the schema uri.
    written_entries = fileio.listdir(schema_out.uri)
    self.assertNotEqual(0, len(written_entries))
def testDo(self):
    """Run the executor with 'test' excluded and inspect per-split anomalies."""
    testdata_root = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')

    # Statistics artifact declaring three splits.
    stats_in = standard_artifacts.ExampleStatistics()
    stats_in.uri = os.path.join(testdata_root, 'statistics_gen')
    stats_in.split_names = artifact_utils.encode_split_names(
        ['train', 'eval', 'test'])

    schema_in = standard_artifacts.Schema()
    schema_in.uri = os.path.join(testdata_root, 'schema_gen')

    outputs_base = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
    out_dir = os.path.join(outputs_base, self._testMethodName)

    validation_output = standard_artifacts.ExampleAnomalies()
    validation_output.uri = os.path.join(out_dir, 'output')

    inputs = {
        STATISTICS_KEY: [stats_in],
        SCHEMA_KEY: [schema_in],
    }
    # List needs to be serialized before being passed into Do function.
    properties = {EXCLUDE_SPLITS_KEY: json_utils.dumps(['test'])}
    outputs = {ANOMALIES_KEY: [validation_output]}

    runner = executor.Executor()
    runner.Do(inputs, outputs, properties)

    # Only the non-excluded splits should be recorded on the output artifact.
    self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                     validation_output.split_names)

    # Check example_validator outputs.
    train_path = os.path.join(validation_output.uri, 'Split-train',
                              'SchemaDiff.pb')
    eval_path = os.path.join(validation_output.uri, 'Split-eval',
                             'SchemaDiff.pb')
    for anomalies_path in (train_path, eval_path):
        self.assertTrue(fileio.exists(anomalies_path))

    train_anomalies = anomalies_pb2.Anomalies()
    train_anomalies.ParseFromString(io_utils.read_bytes_file(train_path))
    eval_anomalies = anomalies_pb2.Anomalies()
    eval_anomalies.ParseFromString(io_utils.read_bytes_file(eval_path))

    for parsed in (train_anomalies, eval_anomalies):
        self.assertEqual(0, len(parsed.anomaly_info))

    # Assert 'test' split is excluded.
    excluded_path = os.path.join(validation_output.uri, 'Split-test',
                                 'SchemaDiff.pb')
    self.assertFalse(fileio.exists(excluded_path))
def testConstruct(self):
    """Construct ExampleValidator and check the 'anomalies' output type."""
    stats_channel = channel_utils.as_channel(
        [standard_artifacts.ExampleStatistics(split='eval')])
    schema_channel = channel_utils.as_channel([standard_artifacts.Schema()])

    example_validator = component.ExampleValidator(
        statistics=stats_channel,
        schema=schema_channel,
    )

    anomalies_type = example_validator.outputs['anomalies'].type_name
    self.assertEqual(standard_artifacts.ExampleAnomalies.TYPE_NAME,
                     anomalies_type)
def testConstruct(self):
    """Construct ExampleValidator via `stats=` and check the output type."""
    stats_channel = channel_utils.as_channel(
        [standard_artifacts.ExampleStatistics(split='eval')])
    schema_channel = channel_utils.as_channel([standard_artifacts.Schema()])

    example_validator = component.ExampleValidator(
        stats=stats_channel,
        schema=schema_channel,
    )

    output_type = example_validator.outputs['output'].type_name
    self.assertEqual('ExampleValidationPath', output_type)
def testConstructWithParameter(self):
    """Construct SchemaGen with a RuntimeParameter for shape inference."""
    shape_param = data_types.RuntimeParameter(name='infer-shape', ptype=bool)
    train_stats_channel = channel_utils.as_channel(
        [standard_artifacts.ExampleStatistics(split='train')])

    schema_gen = component.SchemaGen(
        statistics=train_stats_channel, infer_feature_shape=shape_param)

    self.assertEqual('SchemaPath', schema_gen.outputs['schema'].type_name)

    # The parameter stored in exec_properties is compared by its str() form.
    stored_param = schema_gen.spec.exec_properties['infer_feature_shape']
    self.assertJsonEqual(str(stored_param), str(shape_param))
def testConstruct(self):
    """Construct ExampleValidator from an 'eval' statistics artifact."""
    eval_stats = standard_artifacts.ExampleStatistics()
    eval_stats.split_names = artifact_utils.encode_split_names(['eval'])

    stats_channel = channel_utils.as_channel([eval_stats])
    schema_channel = channel_utils.as_channel([standard_artifacts.Schema()])
    example_validator = component.ExampleValidator(
        statistics=stats_channel,
        schema=schema_channel,
    )

    anomalies_type = example_validator.outputs['anomalies'].type_name
    self.assertEqual(standard_artifacts.ExampleAnomalies.TYPE_NAME,
                     anomalies_type)
def testGetStatusOutputPathsEntriesMissingArtifact(self):
    """Expect ValueError when only one stats output artifact is supplied."""
    lone_stats = standard_artifacts.ExampleStatistics()
    lone_stats.uri = '/pre_transform_stats'

    # Only the pre-transform stats key is populated; all others are missing.
    incomplete_outputs = {
        standard_component_specs.PRE_TRANSFORM_STATS_KEY: [lone_stats]
    }
    expected_error = 'all stats_output_paths should be specified or none'

    with self.assertRaisesRegex(ValueError, expected_error):
        executor_utils.GetStatsOutputPathEntries(False, incomplete_outputs)
def testGetStatusOutputPathsEntries(self):
    """Check GetStatsOutputPathEntries for the disabled and enabled cases."""
    # disabled.
    disabled_result = executor_utils.GetStatsOutputPathEntries(True, {})
    self.assertEmpty(disabled_result)

    # enabled.
    pre_stats = standard_artifacts.ExampleStatistics()
    pre_stats.uri = '/pre_transform_stats'
    pre_schema = standard_artifacts.Schema()
    pre_schema.uri = '/pre_transform_schema'
    post_anomalies = standard_artifacts.ExampleAnomalies()
    post_anomalies.uri = '/post_transform_anomalies'
    post_stats = standard_artifacts.ExampleStatistics()
    post_stats.uri = '/post_transform_stats'
    post_schema = standard_artifacts.Schema()
    post_schema.uri = '/post_transform_schema'

    all_outputs = {
        standard_component_specs.PRE_TRANSFORM_STATS_KEY: [pre_stats],
        standard_component_specs.PRE_TRANSFORM_SCHEMA_KEY: [pre_schema],
        standard_component_specs.POST_TRANSFORM_ANOMALIES_KEY:
            [post_anomalies],
        standard_component_specs.POST_TRANSFORM_STATS_KEY: [post_stats],
        standard_component_specs.POST_TRANSFORM_SCHEMA_KEY: [post_schema],
    }
    result = executor_utils.GetStatsOutputPathEntries(False, all_outputs)

    # Each label is expected to map to the uri of the matching artifact.
    expected_entries = {
        labels.PRE_TRANSFORM_OUTPUT_STATS_PATH_LABEL:
            '/pre_transform_stats',
        labels.PRE_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL:
            '/pre_transform_schema',
        labels.POST_TRANSFORM_OUTPUT_ANOMALIES_PATH_LABEL:
            '/post_transform_anomalies',
        labels.POST_TRANSFORM_OUTPUT_STATS_PATH_LABEL:
            '/post_transform_stats',
        labels.POST_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL:
            '/post_transform_schema',
    }
    self.assertEqual(expected_entries, result)
def __init__(self,
             examples: types.Channel = None,
             schema: Optional[types.Channel] = None,
             stats_options: Optional[tfdv.StatsOptions] = None,
             exclude_splits: Optional[List[Text]] = None,
             output: Optional[types.Channel] = None,
             input_data: Optional[types.Channel] = None,
             instance_name: Optional[Text] = None):
    """Construct a StatisticsGen component.

    Args:
      examples: A Channel of `ExamplesPath` type, likely generated by the
        [ExampleGen component](https://www.tensorflow.org/tfx/guide/examplegen).
        This needs to contain two splits labeled `train` and `eval`.
        _required_
      schema: A `Schema` channel to use for automatically configuring the
        value of stats options passed to TFDV.
      stats_options: The StatsOptions instance to configure optional TFDV
        behavior. When stats_options.schema is set, it will be used instead
        of the `schema` channel input. Due to the requirement that
        stats_options be serialized, the slicer functions and custom stats
        generators are dropped and are therefore not usable.
      exclude_splits: Names of splits where statistics and sample should not
        be generated. Default behavior (when exclude_splits is set to None)
        is excluding no splits.
      output: `ExampleStatisticsPath` channel for statistics of each split
        provided in the input examples.
      input_data: Backwards compatibility alias for the `examples` argument.
        When truthy it overrides `examples` and a deprecation warning is
        logged.
      instance_name: Optional name assigned to this specific instance of
        StatisticsGen. Required only if multiple StatisticsGen components are
        declared in the same pipeline.
    """
    if input_data:
        logging.warning(
            'The "input_data" argument to the StatisticsGen component has '
            'been renamed to "examples" and is deprecated. Please update your '
            'usage as support for this argument will be removed soon.')
        examples = input_data

    if exclude_splits is None:
        exclude_splits = []
        logging.info(
            'Excluding no splits because exclude_splits is not set.')

    if not output:
        default_stats = standard_artifacts.ExampleStatistics()
        output = channel_utils.as_channel([default_stats])

    # TODO(b/150802589): Move jsonable interface to tfx_bsl and use json_utils.
    stats_options_json = None
    if stats_options:
        stats_options_json = stats_options.to_json()

    component_spec = StatisticsGenSpec(
        examples=examples,
        schema=schema,
        stats_options_json=stats_options_json,
        exclude_splits=json_utils.dumps(exclude_splits),
        statistics=output)
    super(StatisticsGen, self).__init__(
        spec=component_spec, instance_name=instance_name)
def testEnableCache(self):
    """Check enable_cache is None by default and True when requested."""
    train_stats = standard_artifacts.ExampleStatistics()
    train_stats.split_names = artifact_utils.encode_split_names(['train'])

    default_schema_gen = component.SchemaGen(
        statistics=channel_utils.as_channel([train_stats]))
    cached_schema_gen = component.SchemaGen(
        statistics=channel_utils.as_channel([train_stats]),
        enable_cache=True)

    self.assertEqual(None, default_schema_gen.enable_cache)
    self.assertEqual(True, cached_schema_gen.enable_cache)
def __init__(self,
             examples: types.Channel = None,
             schema: Optional[types.Channel] = None,
             stats_options: Optional[tfdv.StatsOptions] = None,
             output: Optional[types.Channel] = None,
             input_data: Optional[types.Channel] = None,
             instance_name: Optional[Text] = None,
             enable_cache: Optional[bool] = None):
    """Construct a StatisticsGen component.

    Args:
      examples: A Channel of `ExamplesPath` type, likely generated by the
        [ExampleGen component](https://www.tensorflow.org/tfx/guide/examplegen).
        This needs to contain two splits labeled `train` and `eval`.
        _required_
      schema: A `Schema` channel to use for automatically configuring the
        value of stats options passed to TFDV.
      stats_options: The StatsOptions instance to configure optional TFDV
        behavior. When stats_options.schema is set, it will be used instead
        of the `schema` channel input. Due to the requirement that
        stats_options be serialized, the slicer functions and custom stats
        generators are dropped and are therefore not usable.
      output: `ExampleStatisticsPath` channel for statistics of each split
        provided in the input examples. When falsy, a channel is created
        whose single artifact copies `split_names` from the single artifact
        in `examples`.
      input_data: Backwards compatibility alias for the `examples` argument.
        When truthy it overrides `examples` and a deprecation warning is
        logged.
      instance_name: Optional name assigned to this specific instance of
        StatisticsGen. Required only if multiple StatisticsGen components are
        declared in the same pipeline.
      enable_cache: Optional boolean to indicate if cache is enabled for the
        StatisticsGen component. If not specified, defaults to the value
        specified for pipeline's enable_cache parameter.
    """
    if input_data:
        absl.logging.warning(
            'The "input_data" argument to the StatisticsGen component has '
            'been renamed to "examples" and is deprecated. Please update your '
            'usage as support for this argument will be removed soon.')
        examples = input_data

    if not output:
        # Mirror the split names of the single input examples artifact.
        input_artifacts = list(examples.get())
        source_examples = artifact_utils.get_single_instance(input_artifacts)
        default_stats = standard_artifacts.ExampleStatistics()
        default_stats.split_names = source_examples.split_names
        output = types.Channel(
            type=standard_artifacts.ExampleStatistics,
            artifacts=[default_stats])

    # TODO(b/150802589): Move jsonable interface to tfx_bsl and use json_utils.
    stats_options_json = None
    if stats_options:
        stats_options_json = stats_options.to_json()

    component_spec = StatisticsGenSpec(
        examples=examples,
        schema=schema,
        stats_options_json=stats_options_json,
        statistics=output)
    super(StatisticsGen, self).__init__(
        spec=component_spec,
        instance_name=instance_name,
        enable_cache=enable_cache)
def testConstruct(self):
    """Construct ExampleValidator with exclude_splits and check its spec."""
    stats = standard_artifacts.ExampleStatistics()
    stats.split_names = artifact_utils.encode_split_names(['train', 'eval'])
    excluded = ['eval']

    example_validator = component.ExampleValidator(
        statistics=channel_utils.as_channel([stats]),
        schema=channel_utils.as_channel([standard_artifacts.Schema()]),
        exclude_splits=excluded)

    anomalies_type = example_validator.outputs['anomalies'].type_name
    self.assertEqual(standard_artifacts.ExampleAnomalies.TYPE_NAME,
                     anomalies_type)

    # The excluded splits are expected in exec_properties as a JSON string.
    stored_excludes = example_validator.spec.exec_properties['exclude_splits']
    self.assertEqual(stored_excludes, '["eval"]')
def testConstructWithParameter(self):
    """Construct SchemaGen with a RuntimeParameter for shape inference."""
    train_stats = standard_artifacts.ExampleStatistics()
    train_stats.split_names = artifact_utils.encode_split_names(['train'])
    shape_param = data_types.RuntimeParameter(name='infer-shape', ptype=bool)

    schema_gen = component.SchemaGen(
        statistics=channel_utils.as_channel([train_stats]),
        infer_feature_shape=shape_param)

    schema_type = schema_gen.outputs['schema'].type_name
    self.assertEqual(standard_artifacts.Schema.TYPE_NAME, schema_type)

    # The parameter stored in exec_properties is compared by its str() form.
    stored_param = schema_gen.spec.exec_properties['infer_feature_shape']
    self.assertJsonEqual(str(stored_param), str(shape_param))
def setUp(self):
    """Create the example/statistics channels and custom config for tests."""
    super(ComponentTest, self).setUp()

    # Examples artifact declaring 'train' and 'eval' splits.
    examples_artifact = standard_artifacts.Examples()
    examples_artifact.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])

    # Statistics artifact declaring only the 'train' split.
    stats_artifact = standard_artifacts.ExampleStatistics()
    stats_artifact.split_names = artifact_utils.encode_split_names(['train'])

    self.examples = channel_utils.as_channel([examples_artifact])
    self.statistics = channel_utils.as_channel([stats_artifact])
    self.custom_config = {'some': 'thing', 'some other': 1, 'thing': 2}
def testConstruct(self):
    """Construct SchemaGen with exclude_splits and check its spec."""
    stats = standard_artifacts.ExampleStatistics()
    stats.split_names = artifact_utils.encode_split_names(['train', 'eval'])
    excluded = ['eval']

    schema_gen = component.SchemaGen(
        statistics=channel_utils.as_channel([stats]),
        exclude_splits=excluded)

    schema_type = schema_gen.outputs['schema'].type_name
    self.assertEqual(standard_artifacts.Schema.TYPE_NAME, schema_type)

    properties = schema_gen.spec.exec_properties
    # 'infer_feature_shape' was not passed, and the test expects it truthy.
    self.assertTrue(properties['infer_feature_shape'])
    # The excluded splits are expected as a JSON string.
    self.assertEqual(properties['exclude_splits'], '["eval"]')
def testEnableCache(self):
    """Check enable_cache is None by default and True when requested."""
    eval_stats = standard_artifacts.ExampleStatistics()
    eval_stats.split_names = artifact_utils.encode_split_names(['eval'])

    # Without enable_cache the attribute is expected to stay None.
    default_validator = component.ExampleValidator(
        statistics=channel_utils.as_channel([eval_stats]),
        schema=channel_utils.as_channel([standard_artifacts.Schema()]),
    )
    self.assertEqual(None, default_validator.enable_cache)

    # With enable_cache=True the attribute is expected to be True.
    cached_validator = component.ExampleValidator(
        statistics=channel_utils.as_channel([eval_stats]),
        schema=channel_utils.as_channel([standard_artifacts.Schema()]),
        enable_cache=True)
    self.assertEqual(True, cached_validator.enable_cache)
def testConstruct(self):
    """Construct SchemaGen with exclude_splits, using spec key constants."""
    stats = standard_artifacts.ExampleStatistics()
    stats.split_names = artifact_utils.encode_split_names(['train', 'eval'])
    excluded = ['eval']

    schema_gen = component.SchemaGen(
        statistics=channel_utils.as_channel([stats]),
        exclude_splits=excluded)

    schema_output = schema_gen.outputs[standard_component_specs.SCHEMA_KEY]
    self.assertEqual(standard_artifacts.Schema.TYPE_NAME,
                     schema_output.type_name)

    properties = schema_gen.spec.exec_properties
    # The shape-inference property was not passed; the test expects it truthy.
    self.assertTrue(
        properties[standard_component_specs.INFER_FEATURE_SHAPE_KEY])
    # The excluded splits are expected as a JSON string.
    self.assertEqual(
        properties[standard_component_specs.EXCLUDE_SPLITS_KEY], '["eval"]')
def testDo(self):
    """Run the executor with 'test' excluded and check the per-split outputs."""
    testdata_root = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    outputs_base = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
    out_dir = os.path.join(outputs_base, self._testMethodName)
    fileio.makedirs(out_dir)

    # Create input dict.
    examples_in = standard_artifacts.Examples()
    examples_in.uri = os.path.join(testdata_root, 'csv_example_gen')
    examples_in.split_names = artifact_utils.encode_split_names(
        ['train', 'eval', 'test'])
    inputs = {standard_component_specs.EXAMPLES_KEY: [examples_in]}

    # List needs to be serialized before being passed into Do function.
    properties = {
        standard_component_specs.EXCLUDE_SPLITS_KEY:
            json_utils.dumps(['test']),
    }

    # Create output dict.
    stats_out = standard_artifacts.ExampleStatistics()
    stats_out.uri = out_dir
    outputs = {standard_component_specs.STATISTICS_KEY: [stats_out]}

    # Run executor.
    runner = executor.Executor()
    runner.Do(inputs, outputs, properties)

    # Only the non-excluded splits should be recorded on the output artifact.
    self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                     stats_out.split_names)

    # Check statistics_gen outputs.
    train_stats_file = os.path.join(stats_out.uri, 'train', 'stats_tfrecord')
    self._validate_stats_output(train_stats_file)
    eval_stats_file = os.path.join(stats_out.uri, 'eval', 'stats_tfrecord')
    self._validate_stats_output(eval_stats_file)

    # Assert 'test' split is excluded.
    test_stats_file = os.path.join(stats_out.uri, 'test', 'stats_tfrecord')
    self.assertFalse(fileio.exists(test_stats_file))
def testDoWithSchemaAndStatsOptions(self):
    """Run the executor with a schema input plus serialized stats options."""
    testdata_root = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    outputs_base = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
    out_dir = os.path.join(outputs_base, self._testMethodName)
    tf.io.gfile.makedirs(out_dir)

    # Create input dict.
    examples_in = standard_artifacts.Examples()
    examples_in.uri = os.path.join(testdata_root, 'csv_example_gen')
    examples_in.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])

    schema_in = standard_artifacts.Schema()
    schema_in.uri = os.path.join(testdata_root, 'schema_gen')

    inputs = {
        executor.EXAMPLES_KEY: [examples_in],
        executor.SCHEMA_KEY: [schema_in]
    }

    # Stats options are passed to the executor in their JSON form.
    options = tfdv.StatsOptions(label_feature='company')
    properties = {
        executor.STATS_OPTIONS_JSON_KEY: options.to_json(),
    }

    # Create output dict.
    stats_out = standard_artifacts.ExampleStatistics()
    stats_out.uri = out_dir
    stats_out.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])
    outputs = {executor.STATISTICS_KEY: [stats_out]}

    # Run executor.
    runner = executor.Executor()
    runner.Do(inputs, outputs, exec_properties=properties)

    # Check statistics_gen outputs.
    train_stats_file = os.path.join(stats_out.uri, 'train', 'stats_tfrecord')
    self._validate_stats_output(train_stats_file)
    eval_stats_file = os.path.join(stats_out.uri, 'eval', 'stats_tfrecord')
    self._validate_stats_output(eval_stats_file)
def testDoWithTwoSchemas(self):
    """Expect ValueError when a schema comes from both input and stats options."""
    testdata_root = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    outputs_base = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
    out_dir = os.path.join(outputs_base, self._testMethodName)
    tf.io.gfile.makedirs(out_dir)

    # Create input dict.
    examples_in = standard_artifacts.Examples()
    examples_in.uri = os.path.join(testdata_root, 'csv_example_gen')
    examples_in.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])

    schema_in = standard_artifacts.Schema()
    schema_in.uri = os.path.join(testdata_root, 'schema_gen')

    inputs = {
        executor.EXAMPLES_KEY: [examples_in],
        executor.SCHEMA_KEY: [schema_in]
    }

    # The stats options carry their own schema in addition to the schema input.
    options = tfdv.StatsOptions(label_feature='company',
                                schema=schema_pb2.Schema())
    properties = {
        executor.STATS_OPTIONS_JSON_KEY: options.to_json(),
        executor.EXCLUDE_SPLITS_KEY: json_utils.dumps([])
    }

    # Create output dict.
    stats_out = standard_artifacts.ExampleStatistics()
    stats_out.uri = out_dir
    stats_out.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])
    outputs = {executor.STATISTICS_KEY: [stats_out]}

    # Run executor.
    runner = executor.Executor()
    with self.assertRaises(ValueError):
        runner.Do(inputs, outputs, properties)
def build(self, context: Context) -> BaseNode:
    """Create a StatisticsGen node from this builder's config.

    Args:
      context: Supplies the examples input (looked up by the configured input
        name) and the instance name; it also receives the component outputs.

    Returns:
      The constructed StatisticsGen component.
    """
    from tfx.components import StatisticsGen

    # Output artifact whose splits come from the configured split names.
    split_names = splits_or_example_defaults(self._config.params.split_names)
    stats_artifact = standard_artifacts.ExampleStatistics()
    stats_artifact.split_names = artifact_utils.encode_split_names(
        split_names)
    output = Channel(
        type=standard_artifacts.ExampleStatistics,
        artifacts=[stats_artifact])

    examples = context.get(self._config.inputs.examples)
    component = StatisticsGen(
        examples=examples,
        stats_options=None,
        output=output,
        instance_name=context.abs_current_url_friendly)

    put_outputs_to_context(context, self._config.outputs, component)
    return component
def setUp(self):
    """Prepare the input/output dicts and exec properties used by the tests."""
    super(ExecutorTest, self).setUp()

    package_root = os.path.dirname(os.path.dirname(__file__))
    testdata_root = os.path.join(package_root, 'testdata')

    # Input statistics artifact from the mock data tree.
    stats_in = standard_artifacts.ExampleStatistics()
    stats_in.uri = os.path.join(
        testdata_root, 'StatisticsGen.train_mockdata_1', 'statistics', '5')
    stats_in.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])

    # Input transformed-examples artifact from the mock data tree.
    examples_in = standard_artifacts.Examples()
    examples_in.uri = os.path.join(
        testdata_root, 'Transform.train_mockdata_1', 'transformed_examples',
        '10')
    examples_in.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])

    self._input_dict = {
        executor.EXAMPLES_KEY: [examples_in],
        executor.STATISTICS_KEY: [stats_in],
    }

    # The mkdtemp fallback is evaluated even when the env variable is set.
    outputs_base = os.environ.get(
        'TEST_UNDECLARED_OUTPUTS_DIR',
        tempfile.mkdtemp(dir=flags.FLAGS.test_tmpdir))
    out_dir = os.path.join(outputs_base, self._testMethodName)

    self._metafeatures = artifacts.MetaFeatures()
    self._metafeatures.uri = out_dir
    self._output_dict = {
        executor.METAFEATURES_KEY: [self._metafeatures],
    }

    self._exec_properties = {
        'custom_config': {
            'problem_statement_path': '/some/fake/path'
        }
    }