def testEnableCache(self):
  """Verifies `enable_cache` defaults to None and is passed through when set.

  Uses the idiomatic unittest assertions (`assertIsNone`/`assertTrue`)
  instead of `assertEqual(None, ...)` / `assertEqual(True, ...)`.
  """
  input_base = standard_artifacts.ExternalArtifact()
  # Constructed without enable_cache: the attribute must be unset (None),
  # not False, so callers can distinguish "unspecified" from "disabled".
  csv_example_gen_1 = component.CsvExampleGen(
      input=channel_utils.as_channel([input_base]))
  self.assertIsNone(csv_example_gen_1.enable_cache)
  # Constructed with enable_cache=True: the flag must propagate.
  csv_example_gen_2 = component.CsvExampleGen(
      input=channel_utils.as_channel([input_base]), enable_cache=True)
  self.assertTrue(csv_example_gen_2.enable_cache)
def setUp(self):
  """Builds a two-step TFX pipeline and wraps StatisticsGen as a KFP op."""
  super(BaseComponentTest, self).setUp()
  source_artifact = standard_artifacts.ExternalArtifact()
  csv_gen = csv_example_gen_component.CsvExampleGen(
      input_base=channel_utils.as_channel([source_artifact]))
  stats_gen = statistics_gen_component.StatisticsGen(
      input_data=csv_gen.outputs.examples, instance_name='foo')
  tfx_test_pipeline = tfx_pipeline.Pipeline(
      pipeline_name='test_pipeline',
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[csv_gen, stats_gen],
  )
  # Metadata config that resolves the MySQL host from an env variable.
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = (
      'MYSQL_SERVICE_HOST')
  with dsl.Pipeline('test_pipeline'):
    self.component = base_component.BaseComponent(
        component=stats_gen,
        depends_on=set(),
        pipeline=tfx_test_pipeline,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
    )
  self.tfx_component = stats_gen
def setUp(self):
  """Creates a CsvExampleGen -> StatisticsGen pipeline and a KFP wrapper op."""
  super(BaseComponentTest, self).setUp()
  source_artifact = standard_artifacts.ExternalArtifact()
  csv_gen = csv_example_gen_component.CsvExampleGen(
      input=channel_utils.as_channel([source_artifact]))
  stats_gen = statistics_gen_component.StatisticsGen(
      examples=csv_gen.outputs['examples'], instance_name='foo')
  tfx_test_pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[csv_gen, stats_gen],
  )
  # The pipeline root is supplied at runtime through a KFP pipeline param.
  root_param = dsl.PipelineParam(name='pipeline-root-param')
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = (
      'MYSQL_SERVICE_HOST')
  with dsl.Pipeline('test_pipeline'):
    self.component = base_component.BaseComponent(
        component=stats_gen,
        component_launcher_class=in_process_component_launcher
        .InProcessComponentLauncher,
        depends_on=set(),
        pipeline=tfx_test_pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
    )
  self.tfx_component = stats_gen
def setUp(self):
  """Builds a minimal pipeline plus an IR proto and wraps StatisticsGen."""
  super(BaseComponentTest, self).setUp()
  csv_gen = csv_example_gen_component.CsvExampleGen(input_base='data_input')
  stats_gen = statistics_gen_component.StatisticsGen(
      examples=csv_gen.outputs['examples']).with_id('foo')
  tfx_test_pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[csv_gen, stats_gen],
  )
  # Pipeline root arrives as a KFP runtime parameter.
  root_param = dsl.PipelineParam(name='pipeline-root-param')
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = (
      'MYSQL_SERVICE_HOST')
  # An (empty) IR proto is passed through to the wrapped component.
  self._tfx_ir = pipeline_pb2.Pipeline()
  with dsl.Pipeline('test_pipeline'):
    self.component = base_component.BaseComponent(
        component=stats_gen,
        depends_on=set(),
        pipeline=tfx_test_pipeline,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        tfx_ir=self._tfx_ir,
    )
  self.tfx_component = stats_gen
def setUp(self):
  """Wraps both pipeline components as KFP ops, with a runtime parameter.

  The ExampleGen output config embeds a RuntimeParameter for the number of
  hash buckets, exercising parameter propagation through the wrapper.
  """
  super(BaseComponentWithPipelineParamTest, self).setUp()
  root_param = dsl.PipelineParam(name='pipeline-root-param')
  buckets_param = data_types.RuntimeParameter(
      name='example-gen-buckets', ptype=int, default=10)
  source_artifact = standard_artifacts.ExternalArtifact()
  csv_gen = csv_example_gen_component.CsvExampleGen(
      input=channel_utils.as_channel([source_artifact]),
      output_config={
          'split_config': {
              'splits': [{
                  'name': 'examples',
                  'hash_buckets': buckets_param
              }]
          }
      })
  stats_gen = statistics_gen_component.StatisticsGen(
      examples=csv_gen.outputs['examples'], instance_name='foo')
  tfx_test_pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[csv_gen, stats_gen],
  )
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = (
      'MYSQL_SERVICE_HOST')
  self._tfx_ir = pipeline_pb2.Pipeline()
  with dsl.Pipeline('test_pipeline'):
    self.example_gen = base_component.BaseComponent(
        component=csv_gen,
        component_launcher_class=in_process_component_launcher
        .InProcessComponentLauncher,
        depends_on=set(),
        pipeline=tfx_test_pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
        tfx_ir=self._tfx_ir)
    self.statistics_gen = base_component.BaseComponent(
        component=stats_gen,
        component_launcher_class=in_process_component_launcher
        .InProcessComponentLauncher,
        depends_on=set(),
        pipeline=tfx_test_pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
        tfx_ir=self._tfx_ir)
  self.tfx_example_gen = csv_gen
  self.tfx_statistics_gen = stats_gen
def testConstruct(self):
  """Checks default output type and the train/eval split artifacts."""
  external_input = standard_artifacts.ExternalArtifact()
  gen = component.CsvExampleGen(
      input_base=channel_utils.as_channel([external_input]))
  self.assertEqual('ExamplesPath', gen.outputs['examples'].type_name)
  produced = gen.outputs['examples'].get()
  # Default split config yields a train artifact followed by an eval one.
  self.assertEqual('train', produced[0].split)
  self.assertEqual('eval', produced[1].split)
def test_construct(self):
  """Checks output type and default train/eval splits (legacy TfxType API)."""
  external_input = types.TfxType(type_name='ExternalPath')
  gen = component.CsvExampleGen(
      input_base=channel.as_channel([external_input]))
  self.assertEqual('ExamplesPath', gen.outputs.examples.type_name)
  produced = gen.outputs.examples.get()
  # With no output_config, the component emits train then eval artifacts.
  self.assertEqual('train', produced[0].split)
  self.assertEqual('eval', produced[1].split)
def testConstruct(self):
  """Checks the single Examples artifact carries train/eval split names."""
  gen = component.CsvExampleGen(input_base='path')
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   gen.outputs['examples'].type_name)
  produced = gen.outputs['examples'].get()
  # One artifact holds both splits, encoded in its split_names property.
  self.assertEqual(1, len(produced))
  self.assertEqual(['train', 'eval'],
                   artifact_utils.decode_split_names(produced[0].split_names))
def setUp(self):
  """Wraps both components as KFP ops with a runtime split-name parameter.

  The ExampleGen output split name is a RuntimeStringParameter, exercising
  string-parameter propagation through the KFP wrapper.
  """
  super(BaseComponentWithPipelineParamTest, self).setUp()
  root_param = dsl.PipelineParam(name='pipeline-root-param')
  split_name_param = runtime_string_parameter.RuntimeStringParameter(
      name='example-gen-output-name', default='default-to-be-discarded')
  source_artifact = standard_artifacts.ExternalArtifact()
  csv_gen = csv_example_gen_component.CsvExampleGen(
      input=channel_utils.as_channel([source_artifact]),
      output_config=example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(splits=[
              example_gen_pb2.SplitConfig.Split(
                  name=split_name_param, hash_buckets=10)
          ])))
  stats_gen = statistics_gen_component.StatisticsGen(
      examples=csv_gen.outputs['examples'], instance_name='foo')
  tfx_test_pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[csv_gen, stats_gen],
  )
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = (
      'MYSQL_SERVICE_HOST')
  with dsl.Pipeline('test_pipeline'):
    self.example_gen = base_component.BaseComponent(
        component=csv_gen,
        component_launcher_class=in_process_component_launcher.
        InProcessComponentLauncher,
        depends_on=set(),
        pipeline=tfx_test_pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None)
    self.statistics_gen = base_component.BaseComponent(
        component=stats_gen,
        component_launcher_class=in_process_component_launcher.
        InProcessComponentLauncher,
        depends_on=set(),
        pipeline=tfx_test_pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
    )
  self.tfx_example_gen = csv_gen
  self.tfx_statistics_gen = stats_gen
def test_construct_with_output_config(self):
  """Checks a custom three-way split config yields three split artifacts."""
  external_input = types.TfxType(type_name='ExternalPath')
  three_way_split = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
          example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1)
      ]))
  gen = component.CsvExampleGen(
      input_base=channel.as_channel([external_input]),
      output_config=three_way_split)
  self.assertEqual('ExamplesPath', gen.outputs.examples.type_name)
  produced = gen.outputs.examples.get()
  # Artifacts appear in the order the splits were declared.
  self.assertEqual('train', produced[0].split)
  self.assertEqual('eval', produced[1].split)
  self.assertEqual('test', produced[2].split)
def setUp(self):
  """Wraps both components as KFP ops; ExampleGen takes a runtime config.

  The whole output_config is a string RuntimeParameter, which is also
  forwarded to the wrapper's runtime_parameters list.
  """
  super().setUp()
  output_config_param = data_types.RuntimeParameter(
      name='example-gen-output-config', ptype=str)
  csv_gen = csv_example_gen_component.CsvExampleGen(
      input_base='data_root', output_config=output_config_param)
  stats_gen = statistics_gen_component.StatisticsGen(
      examples=csv_gen.outputs['examples']).with_id('foo')
  root_param = dsl.PipelineParam(name='pipeline-root-param')
  tfx_test_pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[csv_gen, stats_gen],
  )
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = (
      'MYSQL_SERVICE_HOST')
  self._tfx_ir = pipeline_pb2.Pipeline()
  with dsl.Pipeline('test_pipeline'):
    self.example_gen = base_component.BaseComponent(
        component=csv_gen,
        depends_on=set(),
        pipeline=tfx_test_pipeline,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        tfx_ir=self._tfx_ir,
        pod_labels_to_attach={},
        runtime_parameters=[output_config_param])
    self.statistics_gen = base_component.BaseComponent(
        component=stats_gen,
        depends_on=set(),
        pipeline=tfx_test_pipeline,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        tfx_ir=self._tfx_ir,
        pod_labels_to_attach={},
        runtime_parameters=[])
  self.tfx_example_gen = csv_gen
  self.tfx_statistics_gen = stats_gen
def test_construct(self):
  """Checks the component's examples output has the expected type name."""
  external_input = types.TfxType(type_name='ExternalPath')
  gen = component.CsvExampleGen(
      input_base=channel.as_channel([external_input]))
  self.assertEqual('ExamplesPath', gen.outputs.examples.type_name)
def testConstruct(self):
  """Checks the output type name and that exactly one artifact is emitted."""
  gen = component.CsvExampleGen(input_base='path')
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   gen.outputs['examples'].type_name)
  produced = gen.outputs['examples'].get()
  self.assertEqual(1, len(produced))