def testEnableCache(self):
  """Checks that StatisticsGen stores the `enable_cache` constructor flag.

  Without an explicit value the attribute is None; when passed True it is
  stored as True.
  """
  examples = standard_artifacts.Examples()
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  # Default construction: no explicit cache setting.
  statistics_gen_1 = component.StatisticsGen(
      examples=channel_utils.as_channel([examples]))
  # assertIsNone is the idiomatic unittest check (was assertEqual(None, ...)).
  self.assertIsNone(statistics_gen_1.enable_cache)
  # Explicitly enabled cache.
  statistics_gen_2 = component.StatisticsGen(
      examples=channel_utils.as_channel([examples]), enable_cache=True)
  # assertTrue is the idiomatic unittest check (was assertEqual(True, ...)).
  self.assertTrue(statistics_gen_2.enable_cache)
def setUp(self):
  """Builds a two-component TFX pipeline and wraps StatisticsGen for Kubeflow."""
  super(BaseComponentTest, self).setUp()
  csv_gen = csv_example_gen_component.CsvExampleGen(input_base='data_input')
  stats_gen = statistics_gen_component.StatisticsGen(
      examples=csv_gen.outputs['examples']).with_id('foo')
  test_pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[csv_gen, stats_gen],
  )
  root_param = dsl.PipelineParam(name='pipeline-root-param')
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = (
      'MYSQL_SERVICE_HOST')
  self._tfx_ir = pipeline_pb2.Pipeline()
  # The KFP BaseComponent must be constructed inside an active dsl.Pipeline.
  with dsl.Pipeline('test_pipeline'):
    self.component = base_component.BaseComponent(
        component=stats_gen,
        depends_on=set(),
        pipeline=test_pipeline,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        tfx_ir=self._tfx_ir,
    )
  self.tfx_component = stats_gen
def test_construct(self):
  """Constructs StatisticsGen from train/eval examples and checks output type."""
  train_examples = types.TfxType(type_name='ExamplesPath', split='train')
  eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
  statistics_gen = component.StatisticsGen(
      input_data=channel.as_channel([train_examples, eval_examples]))
  self.assertEqual('ExampleStatisticsPath',
                   statistics_gen.outputs.output.type_name)
def testConstruct(self):
  """Verifies the 'statistics' output channel carries the expected type name."""
  train_examples = standard_artifacts.Examples(split='train')
  eval_examples = standard_artifacts.Examples(split='eval')
  statistics_gen = component.StatisticsGen(
      examples=channel_utils.as_channel([train_examples, eval_examples]))
  self.assertEqual('ExampleStatisticsPath',
                   statistics_gen.outputs['statistics'].type_name)
def setUp(self):
  """Builds a pipeline with a RuntimeParameter and wraps both components."""
  super(BaseComponentWithPipelineParamTest, self).setUp()
  root_param = dsl.PipelineParam(name='pipeline-root-param')
  # Runtime-configurable hash bucket count for the example-gen split.
  buckets_param = data_types.RuntimeParameter(
      name='example-gen-buckets', ptype=int, default=10)
  external_examples = standard_artifacts.ExternalArtifact()
  csv_gen = csv_example_gen_component.CsvExampleGen(
      input=channel_utils.as_channel([external_examples]),
      output_config={
          'split_config': {
              'splits': [{
                  'name': 'examples',
                  'hash_buckets': buckets_param
              }]
          }
      })
  stats_gen = statistics_gen_component.StatisticsGen(
      examples=csv_gen.outputs['examples'], instance_name='foo')
  test_pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[csv_gen, stats_gen],
  )
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = (
      'MYSQL_SERVICE_HOST')
  self._tfx_ir = pipeline_pb2.Pipeline()
  # KFP BaseComponents must be built inside an active dsl.Pipeline context.
  with dsl.Pipeline('test_pipeline'):
    self.example_gen = base_component.BaseComponent(
        component=csv_gen,
        component_launcher_class=(
            in_process_component_launcher.InProcessComponentLauncher),
        depends_on=set(),
        pipeline=test_pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
        tfx_ir=self._tfx_ir)
    self.statistics_gen = base_component.BaseComponent(
        component=stats_gen,
        component_launcher_class=(
            in_process_component_launcher.InProcessComponentLauncher),
        depends_on=set(),
        pipeline=test_pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
        tfx_ir=self._tfx_ir)
  self.tfx_example_gen = csv_gen
  self.tfx_statistics_gen = stats_gen
def setUp(self):
  """Builds a CsvExampleGen→StatisticsGen pipeline and wraps StatisticsGen."""
  super(BaseComponentTest, self).setUp()
  external_examples = standard_artifacts.ExternalArtifact()
  csv_gen = csv_example_gen_component.CsvExampleGen(
      input=channel_utils.as_channel([external_examples]))
  stats_gen = statistics_gen_component.StatisticsGen(
      examples=csv_gen.outputs['examples'], instance_name='foo')
  test_pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[csv_gen, stats_gen],
  )
  root_param = dsl.PipelineParam(name='pipeline-root-param')
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = (
      'MYSQL_SERVICE_HOST')
  # The KFP BaseComponent must be constructed inside an active dsl.Pipeline.
  with dsl.Pipeline('test_pipeline'):
    self.component = base_component.BaseComponent(
        component=stats_gen,
        component_launcher_class=(
            in_process_component_launcher.InProcessComponentLauncher),
        depends_on=set(),
        pipeline=test_pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
    )
  self.tfx_component = stats_gen
def setUp(self):
  """Builds a minimal pipeline and wraps StatisticsGen as a KFP component."""
  super(BaseComponentTest, self).setUp()
  external_examples = standard_artifacts.ExternalArtifact()
  csv_gen = csv_example_gen_component.CsvExampleGen(
      input_base=channel_utils.as_channel([external_examples]))
  stats_gen = statistics_gen_component.StatisticsGen(
      input_data=csv_gen.outputs.examples, instance_name='foo')
  test_pipeline = tfx_pipeline.Pipeline(
      pipeline_name='test_pipeline',
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[csv_gen, stats_gen],
  )
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = (
      'MYSQL_SERVICE_HOST')
  # The KFP BaseComponent must be constructed inside an active dsl.Pipeline.
  with dsl.Pipeline('test_pipeline'):
    self.component = base_component.BaseComponent(
        component=stats_gen,
        depends_on=set(),
        pipeline=test_pipeline,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
    )
  self.tfx_component = stats_gen
def testConstruct(self):
  """Checks the statistics output channel uses the ExampleStatistics type."""
  examples = standard_artifacts.Examples()
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  statistics_gen = component.StatisticsGen(
      examples=channel_utils.as_channel([examples]))
  self.assertEqual(standard_artifacts.ExampleStatistics.TYPE_NAME,
                   statistics_gen.outputs['statistics'].type_name)
def _two_step_pipeline() -> tfx_pipeline.Pipeline:
  """Returns a BigQueryExampleGen→StatisticsGen test pipeline."""
  bq_gen = big_query_example_gen_component.BigQueryExampleGen(
      query='SELECT * FROM TABLE')
  stats_gen = statistics_gen_component.StatisticsGen(
      input_data=bq_gen.outputs.examples)
  return tfx_pipeline.Pipeline(
      pipeline_name='two_step_pipeline',
      pipeline_root='pipeline_root',
      components=[bq_gen, stats_gen],
  )
def _two_step_pipeline() -> tfx_pipeline.Pipeline:
  """Returns a BigQueryExampleGen→StatisticsGen pipeline with metadata config."""
  bq_gen = big_query_example_gen_component.BigQueryExampleGen(
      query='SELECT * FROM TABLE')
  stats_gen = statistics_gen_component.StatisticsGen(
      examples=bq_gen.outputs['examples'])
  return tfx_pipeline.Pipeline(
      pipeline_name='two_step_pipeline',
      pipeline_root='pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[bq_gen, stats_gen],
  )
def setUp(self):
  """Builds a pipeline whose split name is a runtime string parameter."""
  super(BaseComponentWithPipelineParamTest, self).setUp()
  root_param = dsl.PipelineParam(name='pipeline-root-param')
  # The split name is resolved at runtime; the default is discarded.
  output_name_param = runtime_string_parameter.RuntimeStringParameter(
      name='example-gen-output-name', default='default-to-be-discarded')
  external_examples = standard_artifacts.ExternalArtifact()
  csv_gen = csv_example_gen_component.CsvExampleGen(
      input=channel_utils.as_channel([external_examples]),
      output_config=example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(splits=[
              example_gen_pb2.SplitConfig.Split(
                  name=output_name_param, hash_buckets=10)
          ])))
  stats_gen = statistics_gen_component.StatisticsGen(
      examples=csv_gen.outputs['examples'], instance_name='foo')
  test_pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[csv_gen, stats_gen],
  )
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = (
      'MYSQL_SERVICE_HOST')
  # KFP BaseComponents must be built inside an active dsl.Pipeline context.
  with dsl.Pipeline('test_pipeline'):
    self.example_gen = base_component.BaseComponent(
        component=csv_gen,
        component_launcher_class=(
            in_process_component_launcher.InProcessComponentLauncher),
        depends_on=set(),
        pipeline=test_pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None)
    self.statistics_gen = base_component.BaseComponent(
        component=stats_gen,
        component_launcher_class=(
            in_process_component_launcher.InProcessComponentLauncher),
        depends_on=set(),
        pipeline=test_pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
    )
  self.tfx_example_gen = csv_gen
  self.tfx_statistics_gen = stats_gen
def testConstruct(self):
  """Checks output type and that `exclude_splits` is serialized to JSON."""
  examples = standard_artifacts.Examples()
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  statistics_gen = component.StatisticsGen(
      examples=channel_utils.as_channel([examples]),
      exclude_splits=['eval'])
  self.assertEqual(standard_artifacts.ExampleStatistics.TYPE_NAME,
                   statistics_gen.outputs['statistics'].type_name)
  # The exclude_splits list is stored as a JSON-encoded exec property.
  self.assertEqual(statistics_gen.spec.exec_properties['exclude_splits'],
                   '["eval"]')
def test_airflow_runner(self, mock_airflow_pipeline_class,
                        mock_airflow_component_class):
  """Verifies AirflowDAGRunner forwards configs and wraps each component."""
  mock_airflow_pipeline_class.return_value = 'DAG'
  c1 = stats_gen_component.StatisticsGen(
      input_data=channel.Channel(type_name='ExamplesPath'))
  c2 = schema_component.SchemaGen(
      stats=channel.Channel(type_name='ExampleStatisticsPath'))
  airflow_config = {
      'schedule_interval': '* * * * *',
      'start_date': datetime.datetime(2019, 1, 1)
  }
  pipeline_config = {
      'pipeline_name': 'chicago_taxi_gcp',
      'log_root': '/var/tmp/tfx/logs',
      'metadata_db_root': 'var/tmp/tfx//metadata',
      'pipeline_root': '/var/tmp/tfx/pipelines'
  }
  # Simulate the runner's call to pipeline: the runner merges the airflow
  # config on top of the pipeline config.
  combined_config = dict(pipeline_config, **airflow_config)
  tfx_pipeline = pipeline.Pipeline(a='a', b='b', **pipeline_config)
  tfx_pipeline.components = [c1, c2]
  tfx_runner = airflow_runner.AirflowDAGRunner(airflow_config)
  tfx_runner.run(tfx_pipeline)
  mock_airflow_pipeline_class.assert_called_with(
      a='a', b='b', **combined_config)
  # One wrapper call per TFX component, each bound to the mocked 'DAG'.
  component_calls = [
      mock.call(
          'DAG',
          component_name=comp.component_name,
          unique_name=comp.unique_name,
          driver=comp.driver,
          executor=comp.executor,
          input_dict=mock.ANY,
          output_dict=mock.ANY,
          exec_properties=comp.exec_properties) for comp in (c1, c2)
  ]
  mock_airflow_component_class.assert_has_calls(
      component_calls, any_order=True)
def _two_step_pipeline() -> tfx_pipeline.Pipeline:
  """Returns a two-step pipeline whose BigQuery table is a runtime parameter."""
  table_name = data_types.RuntimeParameter(
      name='table-name', ptype=Text, default='default-table')
  # The parameter placeholder is interpolated into the query string.
  bq_gen = big_query_example_gen_component.BigQueryExampleGen(
      query='SELECT * FROM %s' % str(table_name))
  stats_gen = statistics_gen_component.StatisticsGen(
      examples=bq_gen.outputs['examples'])
  return tfx_pipeline.Pipeline(
      pipeline_name='two_step_pipeline',
      pipeline_root='pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[bq_gen, stats_gen],
  )
def testConstructWithSchemaAndStatsOptions(self):
  """Constructs StatisticsGen with a schema and StatsOptions; checks output."""
  examples = standard_artifacts.Examples()
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  schema = standard_artifacts.Schema()
  stats_options = tfdv.StatsOptions(weight_feature='weight')
  statistics_gen = component.StatisticsGen(
      examples=channel_utils.as_channel([examples]),
      schema=channel_utils.as_channel([schema]),
      stats_options=stats_options)
  self.assertEqual(
      standard_artifacts.ExampleStatistics.TYPE_NAME,
      statistics_gen.outputs[
          standard_component_specs.STATISTICS_KEY].type_name)
def setUp(self):
  """Builds a pipeline with a runtime output_config and wraps both components."""
  super().setUp()
  # The example-gen output config is supplied at runtime as a string param.
  output_config_param = data_types.RuntimeParameter(
      name='example-gen-output-config', ptype=str)
  csv_gen = csv_example_gen_component.CsvExampleGen(
      input_base='data_root', output_config=output_config_param)
  stats_gen = statistics_gen_component.StatisticsGen(
      examples=csv_gen.outputs['examples']).with_id('foo')
  root_param = dsl.PipelineParam(name='pipeline-root-param')
  test_pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[csv_gen, stats_gen],
  )
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = (
      'MYSQL_SERVICE_HOST')
  self._tfx_ir = pipeline_pb2.Pipeline()
  # KFP BaseComponents must be built inside an active dsl.Pipeline context.
  with dsl.Pipeline('test_pipeline'):
    self.example_gen = base_component.BaseComponent(
        component=csv_gen,
        depends_on=set(),
        pipeline=test_pipeline,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        tfx_ir=self._tfx_ir,
        pod_labels_to_attach={},
        runtime_parameters=[output_config_param])
    self.statistics_gen = base_component.BaseComponent(
        component=stats_gen,
        depends_on=set(),
        pipeline=test_pipeline,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        tfx_ir=self._tfx_ir,
        pod_labels_to_attach={},
        runtime_parameters=[])
  self.tfx_example_gen = csv_gen
  self.tfx_statistics_gen = stats_gen
def testConstructWithSchemaAndStatsOptions(self):
  """Constructs StatisticsGen with StatsOptions that carry custom generators."""
  examples = standard_artifacts.Examples()
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  schema = standard_artifacts.Schema()
  stats_options = tfdv.StatsOptions(
      weight_feature='weight',
      generators=[  # generators should be dropped
          tfdv.LiftStatsGenerator(
              schema=None,
              y_path=tfdv.FeaturePath(['label']),
              x_paths=[tfdv.FeaturePath(['feature'])])
      ])
  statistics_gen = component.StatisticsGen(
      examples=channel_utils.as_channel([examples]),
      schema=channel_utils.as_channel([schema]),
      stats_options=stats_options)
  self.assertEqual(standard_artifacts.ExampleStatistics.TYPE_NAME,
                   statistics_gen.outputs['statistics'].type_name)
def _two_step_pipeline() -> tfx_pipeline.Pipeline:
  """Returns a two-step pipeline whose input_config is a runtime parameter."""
  # Default input config: a single split reading from a fixed table.
  default_input_config = json.dumps({
      'splits': [{
          'name': 'single_split',
          'pattern': 'SELECT * FROM default-table'
      }]
  })
  input_config = data_types.RuntimeParameter(
      name='input_config', ptype=str, default=default_input_config)
  bq_gen = big_query_example_gen_component.BigQueryExampleGen(
      input_config=input_config, output_config=example_gen_pb2.Output())
  stats_gen = statistics_gen_component.StatisticsGen(
      examples=bq_gen.outputs['examples'])
  return tfx_pipeline.Pipeline(
      pipeline_name='two_step_pipeline',
      pipeline_root='pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[bq_gen, stats_gen],
  )
def __init__(self, args):
  """Wraps a StatisticsGen component for the runner, wiring input_data from args."""
  stats_component = statistics_gen_component.StatisticsGen(
      channel.Channel('ExamplesPath'))
  super(StatisticsGenRunner, self).__init__(
      args, stats_component, {'input_data': args.input_data})
def __init__(self, input_data: dsl.PipelineParam):
  """Wraps a StatisticsGen component, binding input_data to a pipeline param."""
  stats_component = statistics_gen_component.StatisticsGen(
      channel.Channel('ExamplesPath'))
  super().__init__(stats_component, {'input_data': input_data})
def _two_step_pipeline():
  """Returns the component list for a BigQueryExampleGen→StatisticsGen pair."""
  bq_gen = big_query_example_gen_component.BigQueryExampleGen(
      query='SELECT * FROM TABLE')
  stats_gen = statistics_gen_component.StatisticsGen(
      input_data=bq_gen.outputs.examples)
  return [bq_gen, stats_gen]