def testMainEmptyInputs(self):
  """Test executor class import under empty inputs/outputs."""
  inputs = {
      'x': [
          standard_artifacts.ExternalArtifact(),
          standard_artifacts.ExternalArtifact()
      ]
  }
  outputs = {'y': [standard_artifacts.Examples()]}
  exec_properties = {'a': 'b'}
  args = [
      '--executor_class_path=%s.%s' %
      (FakeExecutor.__module__, FakeExecutor.__name__),
      '--inputs=%s' % artifact_utils.jsonify_artifact_dict(inputs),
      '--outputs=%s' % artifact_utils.jsonify_artifact_dict(outputs),
      '--exec-properties=%s' % json.dumps(exec_properties),
  ]
  with ArgsCapture() as args_capture:
    run_executor.main(args)
    # TODO(b/131417512): Add equal comparison to types.Artifact class so we
    # can use asserters.
    self.assertSetEqual(
        set(args_capture.input_dict.keys()), set(inputs.keys()))
    self.assertSetEqual(
        set(args_capture.output_dict.keys()), set(outputs.keys()))
    self.assertDictEqual(args_capture.exec_properties, exec_properties)
def setUp(self):
  super(BaseComponentWithPipelineParamTest, self).setUp()

  test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
  example_gen_buckets = data_types.RuntimeParameter(
      name='example-gen-buckets', ptype=int, default=10)

  examples = standard_artifacts.ExternalArtifact()
  example_gen = csv_example_gen_component.CsvExampleGen(
      input=channel_utils.as_channel([examples]),
      output_config={
          'split_config': {
              'splits': [{
                  'name': 'examples',
                  'hash_buckets': example_gen_buckets
              }]
          }
      })
  statistics_gen = statistics_gen_component.StatisticsGen(
      examples=example_gen.outputs['examples'], instance_name='foo')

  pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[example_gen, statistics_gen],
  )

  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
  self._tfx_ir = pipeline_pb2.Pipeline()
  with dsl.Pipeline('test_pipeline'):
    self.example_gen = base_component.BaseComponent(
        component=example_gen,
        component_launcher_class=in_process_component_launcher
        .InProcessComponentLauncher,
        depends_on=set(),
        pipeline=pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=test_pipeline_root,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
        tfx_ir=self._tfx_ir)
    self.statistics_gen = base_component.BaseComponent(
        component=statistics_gen,
        component_launcher_class=in_process_component_launcher
        .InProcessComponentLauncher,
        depends_on=set(),
        pipeline=pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=test_pipeline_root,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
        tfx_ir=self._tfx_ir)

  self.tfx_example_gen = example_gen
  self.tfx_statistics_gen = statistics_gen
def setUp(self):
  super(DriverTest, self).setUp()

  # Create input splits.
  test_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  self._input_base_path = os.path.join(test_dir, 'input_base')
  tf.io.gfile.makedirs(self._input_base_path)

  # Mock metadata.
  self._mock_metadata = tf.test.mock.Mock()
  self._example_gen_driver = driver.Driver(self._mock_metadata)

  # Create input dict.
  input_base = standard_artifacts.ExternalArtifact()
  input_base.uri = self._input_base_path
  self._input_channels = {
      'input_base': channel_utils.as_channel([input_base])
  }

  # Create exec properties.
  self._exec_properties = {
      'input_config':
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='s1', pattern='span{SPAN}/split1/*'),
                  example_gen_pb2.Input.Split(
                      name='s2', pattern='span{SPAN}/split2/*')
              ])),
  }
def setUp(self):
  super(BaseComponentTest, self).setUp()
  examples = standard_artifacts.ExternalArtifact()
  example_gen = csv_example_gen_component.CsvExampleGen(
      input=channel_utils.as_channel([examples]))
  statistics_gen = statistics_gen_component.StatisticsGen(
      examples=example_gen.outputs['examples'], instance_name='foo')
  pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[example_gen, statistics_gen],
  )

  test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
  with dsl.Pipeline('test_pipeline'):
    self.component = base_component.BaseComponent(
        component=statistics_gen,
        component_launcher_class=in_process_component_launcher
        .InProcessComponentLauncher,
        depends_on=set(),
        pipeline=pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=test_pipeline_root,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
    )
  self.tfx_component = statistics_gen
def setUp(self):
  super(BaseComponentTest, self).setUp()
  examples = standard_artifacts.ExternalArtifact()
  example_gen = csv_example_gen_component.CsvExampleGen(
      input_base=channel_utils.as_channel([examples]))
  statistics_gen = statistics_gen_component.StatisticsGen(
      input_data=example_gen.outputs.examples, instance_name='foo')
  pipeline = tfx_pipeline.Pipeline(
      pipeline_name='test_pipeline',
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[example_gen, statistics_gen],
  )

  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
  with dsl.Pipeline('test_pipeline'):
    self.component = base_component.BaseComponent(
        component=statistics_gen,
        depends_on=set(),
        pipeline=pipeline,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
    )
  self.tfx_component = statistics_gen
def testCsvExampleGenWrapper(self):
  input_base = standard_artifacts.ExternalArtifact(split='')
  input_base.uri = '/path/to/dataset'

  with patch.object(executor, 'Executor', autospec=True) as _:
    wrapper = executor_wrappers.CsvExampleGenWrapper(
        argparse.Namespace(
            exec_properties=json.dumps(self.exec_properties),
            outputs=artifact_utils.jsonify_artifact_dict(
                {'examples': self.examples}),
            executor_class_path=(
                'tfx.components.example_gen.csv_example_gen.executor.Executor'
            ),
            input_base=json.dumps([input_base.json_dict()])),
    )
    wrapper.run(output_basedir=self.output_basedir)

    # TODO(b/133011207): Validate arguments for executor and Do() method.

    metadata_file = os.path.join(self.output_basedir,
                                 'output/ml_metadata/examples')

    expected_output_examples = standard_artifacts.Examples(split='dummy')
    # Expect that span and path are resolved.
    expected_output_examples.span = 1
    expected_output_examples.uri = (
        '/path/to/output/csv_example_gen/examples/mock_workflow_id/dummy/')

    with tf.gfile.GFile(metadata_file) as f:
      self.assertEqual([expected_output_examples.json_dict()],
                       json.loads(f.read()))
def setUp(self):
  input_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
      'testdata')

  # Create input dict.
  input_base = standard_artifacts.ExternalArtifact()
  input_base.uri = os.path.join(input_data_dir, 'external')
  self._input_dict = {'input_base': [input_base]}
def testEnableCache(self):
  input_base = standard_artifacts.ExternalArtifact()
  import_example_gen_1 = component.ImportExampleGen(
      input=channel_utils.as_channel([input_base]))
  self.assertEqual(None, import_example_gen_1.enable_cache)

  import_example_gen_2 = component.ImportExampleGen(
      input=channel_utils.as_channel([input_base]), enable_cache=True)
  self.assertEqual(True, import_example_gen_2.enable_cache)
def testConstruct(self):
  input_base = standard_artifacts.ExternalArtifact()
  import_example_gen = component.ImportExampleGen(
      input=channel_utils.as_channel([input_base]))
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   import_example_gen.outputs['examples'].type_name)
  artifact_collection = import_example_gen.outputs['examples'].get()
  self.assertEqual('train', artifact_collection[0].split)
  self.assertEqual('eval', artifact_collection[1].split)
def setUp(self):
  super(ExecutorTest, self).setUp()
  input_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
      'testdata')

  # Create input dict.
  input_base = standard_artifacts.ExternalArtifact()
  input_base.uri = os.path.join(input_data_dir, 'external')
  self._input_dict = {INPUT_KEY: [input_base]}
def testConstruct(self):
  input_base = standard_artifacts.ExternalArtifact()
  csv_example_gen = component.CsvExampleGen(
      input_base=channel_utils.as_channel([input_base]))
  self.assertEqual('ExamplesPath',
                   csv_example_gen.outputs['examples'].type_name)
  artifact_collection = csv_example_gen.outputs['examples'].get()
  self.assertEqual('train', artifact_collection[0].split)
  self.assertEqual('eval', artifact_collection[1].split)
def testConstruct(self):
  input_base = standard_artifacts.ExternalArtifact()
  csv_example_gen = component.CsvExampleGen(
      input=channel_utils.as_channel([input_base]))
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   csv_example_gen.outputs['examples'].type_name)
  artifact_collection = csv_example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def testConstructCustomExecutor(self):
  input_base = standard_artifacts.ExternalArtifact()
  example_gen = component.FileBasedExampleGen(
      input_base=channel_utils.as_channel([input_base]),
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          TestExampleGenExecutor))
  self.assertEqual(driver.Driver, example_gen.driver_class)
  self.assertEqual('ExamplesPath',
                   example_gen.outputs['examples'].type_name)
  artifact_collection = example_gen.outputs['examples'].get()
  self.assertEqual('train', artifact_collection[0].split)
  self.assertEqual('eval', artifact_collection[1].split)
def testConstructSubclassFileBased(self):
  input_base = standard_artifacts.ExternalArtifact()
  example_gen = TestFileBasedExampleGenComponent(
      input_base=channel_utils.as_channel([input_base]))
  self.assertIn('input_base', example_gen.inputs.get_all())
  self.assertEqual(driver.Driver, example_gen.driver_class)
  self.assertEqual('ExamplesPath',
                   example_gen.outputs['examples'].type_name)
  self.assertIsNone(example_gen.exec_properties.get('custom_config'))
  artifact_collection = example_gen.outputs['examples'].get()
  self.assertEqual('train', artifact_collection[0].split)
  self.assertEqual('eval', artifact_collection[1].split)
def setUp(self):
  super(BaseComponentWithPipelineParamTest, self).setUp()

  test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
  example_gen_output_name = runtime_string_parameter.RuntimeStringParameter(
      name='example-gen-output-name', default='default-to-be-discarded')

  examples = standard_artifacts.ExternalArtifact()
  example_gen = csv_example_gen_component.CsvExampleGen(
      input=channel_utils.as_channel([examples]),
      output_config=example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(splits=[
              example_gen_pb2.SplitConfig.Split(
                  name=example_gen_output_name, hash_buckets=10)
          ])))
  statistics_gen = statistics_gen_component.StatisticsGen(
      examples=example_gen.outputs['examples'], instance_name='foo')

  pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[example_gen, statistics_gen],
  )

  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
  with dsl.Pipeline('test_pipeline'):
    self.example_gen = base_component.BaseComponent(
        component=example_gen,
        component_launcher_class=in_process_component_launcher
        .InProcessComponentLauncher,
        depends_on=set(),
        pipeline=pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=test_pipeline_root,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None)
    self.statistics_gen = base_component.BaseComponent(
        component=statistics_gen,
        component_launcher_class=in_process_component_launcher
        .InProcessComponentLauncher,
        depends_on=set(),
        pipeline=pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=test_pipeline_root,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
    )

  self.tfx_example_gen = example_gen
  self.tfx_statistics_gen = statistics_gen
def external_input(uri: Text) -> types.Channel:
  """Helper function to declare external input.

  Args:
    uri: external path.

  Returns:
    input channel.
  """
  instance = standard_artifacts.ExternalArtifact()
  instance.uri = uri
  return channel_utils.as_channel([instance])
def external_input(uri: Any) -> types.Channel:
  """Helper function to declare external input.

  Args:
    uri: external path; can be a RuntimeParameter.

  Returns:
    input channel.
  """
  instance = standard_artifacts.ExternalArtifact()
  instance.uri = str(uri)
  return channel_utils.as_channel([instance])
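# A minimal usage sketch for the external_input helper above, assuming the
# older-style CsvExampleGen that accepts an `input` Channel (as used elsewhere
# in this file); the data path is purely illustrative.
from tfx.components import CsvExampleGen

examples_channel = external_input('/tmp/my_csv_data')  # hypothetical path
example_gen = CsvExampleGen(input=examples_channel)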
def testConstructWithCustomConfig(self):
  input_base = standard_artifacts.ExternalArtifact()
  custom_config = example_gen_pb2.CustomConfig(custom_config=any_pb2.Any())
  example_gen = component.FileBasedExampleGen(
      input_base=channel_utils.as_channel([input_base]),
      custom_config=custom_config,
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          TestExampleGenExecutor))

  stored_custom_config = example_gen_pb2.CustomConfig()
  json_format.Parse(example_gen.exec_properties['custom_config'],
                    stored_custom_config)
  self.assertEqual(custom_config, stored_custom_config)
def testConstructWithInputConfig(self):
  input_base = standard_artifacts.ExternalArtifact()
  example_gen = TestFileBasedExampleGenComponent(
      input_base=channel_utils.as_channel([input_base]),
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='train', pattern='train/*'),
          example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
          example_gen_pb2.Input.Split(name='test', pattern='test/*')
      ]))
  self.assertEqual('ExamplesPath',
                   example_gen.outputs['examples'].type_name)
  artifact_collection = example_gen.outputs['examples'].get()
  self.assertEqual('train', artifact_collection[0].split)
  self.assertEqual('eval', artifact_collection[1].split)
  self.assertEqual('test', artifact_collection[2].split)
def testConstructWithInputConfig(self):
  input_base = standard_artifacts.ExternalArtifact()
  example_gen = TestFileBasedExampleGenComponent(
      input=channel_utils.as_channel([input_base]),
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='train', pattern='train/*'),
          example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
          example_gen_pb2.Input.Split(name='test', pattern='test/*')
      ]))
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   example_gen.outputs['examples'].type_name)
  artifact_collection = example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval', 'test'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def testConstructWithOutputConfig(self):
  input_base = standard_artifacts.ExternalArtifact()
  example_gen = TestFileBasedExampleGenComponent(
      input_base=channel_utils.as_channel([input_base]),
      output_config=example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(splits=[
              example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
              example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
              example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1)
          ])))
  self.assertEqual('ExamplesPath', example_gen.outputs.examples.type_name)
  artifact_collection = example_gen.outputs.examples.get()
  self.assertEqual('train', artifact_collection[0].split)
  self.assertEqual('eval', artifact_collection[1].split)
  self.assertEqual('test', artifact_collection[2].split)
def testEnableCache(self):
  input_base = standard_artifacts.ExternalArtifact()
  custom_config = example_gen_pb2.CustomConfig(custom_config=any_pb2.Any())
  example_gen_1 = component.FileBasedExampleGen(
      input=channel_utils.as_channel([input_base]),
      custom_config=custom_config,
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          TestExampleGenExecutor))
  self.assertEqual(None, example_gen_1.enable_cache)

  example_gen_2 = component.FileBasedExampleGen(
      input=channel_utils.as_channel([input_base]),
      custom_config=custom_config,
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          TestExampleGenExecutor),
      enable_cache=True)
  self.assertEqual(True, example_gen_2.enable_cache)
def __init__(self,
             input_data: types.Channel = None,
             output_data: Optional[types.Channel] = None,
             instance_name: Optional[Text] = None):
  """Construct a HelloComponent.

  Args:
    input_data: A Channel of type `standard_artifacts.InferenceResult`.
    output_data: A Channel of type `standard_artifacts.ExternalArtifact`.
    instance_name: Optional unique name. Necessary if multiple Hello
      components are declared in the same pipeline.
  """
  if not output_data:
    examples_artifact = standard_artifacts.ExternalArtifact()
    output_data = channel_utils.as_channel([examples_artifact])

  spec = HelloComponentSpec(input_data=input_data, output_data=output_data)
  super(HelloComponent, self).__init__(spec=spec, instance_name=instance_name)
def CsvExampleGen_GCS(  #
    # Inputs
    #input_base_path: InputPath('ExternalPath'),
    input_base_path: 'ExternalPath',  # A Channel of 'ExternalPath' type, which includes one artifact whose uri is an external directory with csv files inside (required).

    # Outputs
    #example_artifacts_path: OutputPath('ExamplesPath'),
    example_artifacts_path: 'ExamplesPath',

    # Execution properties
    #input_config_splits: {'List': {'item_type': 'ExampleGen.Input.Split'}},
    input_config: 'ExampleGen.Input' = None,  # = '{"splits": []}', # JSON-serialized example_gen_pb2.Input instance, providing input configuration. If unset, the files under input_base will be treated as a single split.
    #output_config_splits: {'List': {'item_type': 'ExampleGen.SplitConfig'}},
    output_config: 'ExampleGen.Output' = None,  # = '{"splitConfig": {"splits": []}}', # JSON-serialized example_gen_pb2.Output instance, providing output configuration. If unset, default splits will be 'train' and 'eval' with size 2:1.
    #custom_config: 'ExampleGen.CustomConfig' = None,
) -> NamedTuple('Outputs', [
    ('example_artifacts', 'ExamplesPath'),
]):
  """Executes the CsvExampleGen component.

  Args:
    input_base: A Channel of 'ExternalPath' type, which includes one artifact
      whose uri is an external directory with csv files inside (required).
    input_config: An example_gen_pb2.Input instance, providing input
      configuration. If unset, the files under input_base will be treated as a
      single split.
    output_config: An example_gen_pb2.Output instance, providing output
      configuration. If unset, default splits will be 'train' and 'eval' with
      size 2:1.
    ??? input: Forwards compatibility alias for the 'input_base' argument.

  Returns:
    example_artifacts: Artifact of type 'ExamplesPath' for output train and
      eval examples.
  """
  import json
  import os
  from google.protobuf import json_format
  from tfx.components.example_gen.csv_example_gen.component import CsvExampleGen
  from tfx.proto import example_gen_pb2
  from tfx.types import standard_artifacts
  from tfx.types import channel_utils

  # Create input dict.
  input_base = standard_artifacts.ExternalArtifact()
  input_base.uri = input_base_path
  input_base_channel = channel_utils.as_channel([input_base])

  input_config_obj = None
  if input_config:
    input_config_obj = example_gen_pb2.Input()
    json_format.Parse(input_config, input_config_obj)

  output_config_obj = None
  if output_config:
    output_config_obj = example_gen_pb2.Output()
    json_format.Parse(output_config, output_config_obj)

  component_class_instance = CsvExampleGen(
      input=input_base_channel,
      input_config=input_config_obj,
      output_config=output_config_obj,
  )

  # component_class_instance.inputs/outputs are wrappers that do not behave
  # like real dictionaries. The underlying dict can be accessed using
  # .get_all(). Channel artifacts can be accessed by calling .get().
  input_dict = {
      name: channel.get()
      for name, channel in component_class_instance.inputs.get_all().items()
  }
  output_dict = {
      name: channel.get()
      for name, channel in component_class_instance.outputs.get_all().items()
  }
  exec_properties = component_class_instance.exec_properties

  # Generating paths for output artifacts.
  for output_artifact in output_dict['examples']:
    output_artifact.uri = example_artifacts_path
    if output_artifact.split:
      output_artifact.uri = os.path.join(output_artifact.uri,
                                         output_artifact.split)

  print('component instance: ' + str(component_class_instance))

  executor = CsvExampleGen.EXECUTOR_SPEC.executor_class()
  executor.Do(
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties,
  )

  return (example_artifacts_path,)
def test_prepare_input_for_processing(self):
  # Create input splits.
  test_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  input_base_path = os.path.join(test_dir, 'input_base')
  split1 = os.path.join(input_base_path, 'split1', 'data')
  io_utils.write_string_file(split1, 'testing')
  os.utime(split1, (0, 1))
  split2 = os.path.join(input_base_path, 'split2', 'data')
  io_utils.write_string_file(split2, 'testing2')
  os.utime(split2, (0, 3))

  # Mock metadata.
  mock_metadata = tf.test.mock.Mock()
  example_gen_driver = driver.Driver(mock_metadata)

  # Mock artifact.
  artifacts = []
  for i in [4, 3, 2, 1]:
    artifact = metadata_store_pb2.Artifact()
    artifact.id = i
    artifact.uri = input_base_path
    # Only odd ids will be matched.
    if i % 2 == 1:
      artifact.custom_properties['input_fingerprint'].string_value = (
          'split:s1,num_files:1,total_bytes:7,xor_checksum:1,sum_checksum:1\n'
          'split:s2,num_files:1,total_bytes:8,xor_checksum:3,sum_checksum:3')
    else:
      artifact.custom_properties[
          'input_fingerprint'].string_value = 'not_match'
    artifacts.append(artifact)

  # Create input dict.
  input_base = standard_artifacts.ExternalArtifact()
  input_base.uri = input_base_path
  input_dict = {'input_base': [input_base]}
  # Create exec properties.
  exec_properties = {
      'input_config':
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(name='s1', pattern='split1/*'),
                  example_gen_pb2.Input.Split(name='s2', pattern='split2/*')
              ])),
  }

  # Cache not hit.
  mock_metadata.get_artifacts_by_uri.return_value = [artifacts[0]]
  mock_metadata.publish_artifacts.return_value = [artifacts[3]]
  updated_input_dict = example_gen_driver._prepare_input_for_processing(
      copy.deepcopy(input_dict), exec_properties)
  self.assertEqual(1, len(updated_input_dict))
  self.assertEqual(1, len(updated_input_dict['input_base']))
  updated_input_base = updated_input_dict['input_base'][0]
  self.assertEqual(1, updated_input_base.id)
  self.assertEqual(input_base_path, updated_input_base.uri)

  # Cache hit.
  mock_metadata.get_artifacts_by_uri.return_value = artifacts
  mock_metadata.publish_artifacts.return_value = []
  updated_input_dict = example_gen_driver._prepare_input_for_processing(
      copy.deepcopy(input_dict), exec_properties)
  self.assertEqual(1, len(updated_input_dict))
  self.assertEqual(1, len(updated_input_dict['input_base']))
  updated_input_base = updated_input_dict['input_base'][0]
  self.assertEqual(3, updated_input_base.id)
  self.assertEqual(input_base_path, updated_input_base.uri)