Example #1
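Unit test checking that StatisticsGen leaves enable_cache unset (None) by default and stores the flag when it is passed explicitly.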
 def testEnableCache(self):
     examples = standard_artifacts.Examples()
     examples.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     statistics_gen_1 = component.StatisticsGen(
         examples=channel_utils.as_channel([examples]))
     self.assertEqual(None, statistics_gen_1.enable_cache)
     statistics_gen_2 = component.StatisticsGen(
         examples=channel_utils.as_channel([examples]), enable_cache=True)
     self.assertEqual(True, statistics_gen_2.enable_cache)
Example #2
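Test fixture that wraps a StatisticsGen step in a Kubeflow base_component.BaseComponent, supplying a pipeline, a metadata connection config, and a pre-built TFX IR proto.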
    def setUp(self):
        super(BaseComponentTest, self).setUp()
        example_gen = csv_example_gen_component.CsvExampleGen(
            input_base='data_input')
        statistics_gen = statistics_gen_component.StatisticsGen(
            examples=example_gen.outputs['examples']).with_id('foo')

        pipeline = tfx_pipeline.Pipeline(
            pipeline_name=self._test_pipeline_name,
            pipeline_root='test_pipeline_root',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[example_gen, statistics_gen],
        )

        test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')

        self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
        self._tfx_ir = pipeline_pb2.Pipeline()
        with dsl.Pipeline('test_pipeline'):
            self.component = base_component.BaseComponent(
                component=statistics_gen,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                tfx_ir=self._tfx_ir,
            )
        self.tfx_component = statistics_gen
Example #3
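A construction test against an older TFX API (types.TfxType with per-split artifacts); it asserts that the output artifact type name is ExampleStatisticsPath.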
 def test_construct(self):
     train_examples = types.TfxType(type_name='ExamplesPath', split='train')
     eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
     statistics_gen = component.StatisticsGen(
         input_data=channel.as_channel([train_examples, eval_examples]))
     self.assertEqual('ExampleStatisticsPath',
                      statistics_gen.outputs.output.type_name)
Example #4
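A similar construction test on a slightly newer API, reading the output channel through the 'statistics' key.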
 def testConstruct(self):
   train_examples = standard_artifacts.Examples(split='train')
   eval_examples = standard_artifacts.Examples(split='eval')
   statistics_gen = component.StatisticsGen(
       examples=channel_utils.as_channel([train_examples, eval_examples]))
   self.assertEqual('ExampleStatisticsPath',
                    statistics_gen.outputs['statistics'].type_name)
Example #5
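Kubeflow fixture in which the ExampleGen split configuration takes its hash-bucket count from a RuntimeParameter; both wrapped components share the same pipeline, image, and metadata config.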
  def setUp(self):
    super(BaseComponentWithPipelineParamTest, self).setUp()

    test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
    example_gen_buckets = data_types.RuntimeParameter(
        name='example-gen-buckets', ptype=int, default=10)

    examples = standard_artifacts.ExternalArtifact()
    example_gen = csv_example_gen_component.CsvExampleGen(
        input=channel_utils.as_channel([examples]),
        output_config={
            'split_config': {
                'splits': [{
                    'name': 'examples',
                    'hash_buckets': example_gen_buckets
                }]
            }
        })
    statistics_gen = statistics_gen_component.StatisticsGen(
        examples=example_gen.outputs['examples'], instance_name='foo')

    pipeline = tfx_pipeline.Pipeline(
        pipeline_name=self._test_pipeline_name,
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )

    self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
    self._tfx_ir = pipeline_pb2.Pipeline()
    with dsl.Pipeline('test_pipeline'):
      self.example_gen = base_component.BaseComponent(
          component=example_gen,
          component_launcher_class=in_process_component_launcher
          .InProcessComponentLauncher,
          depends_on=set(),
          pipeline=pipeline,
          pipeline_name=self._test_pipeline_name,
          pipeline_root=test_pipeline_root,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
          component_config=None,
          tfx_ir=self._tfx_ir)
      self.statistics_gen = base_component.BaseComponent(
          component=statistics_gen,
          component_launcher_class=in_process_component_launcher
          .InProcessComponentLauncher,
          depends_on=set(),
          pipeline=pipeline,
          pipeline_name=self._test_pipeline_name,
          pipeline_root=test_pipeline_root,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
          component_config=None,
          tfx_ir=self._tfx_ir
      )

    self.tfx_example_gen = example_gen
    self.tfx_statistics_gen = statistics_gen
Example #6
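A simpler variant of the Kubeflow fixture, without runtime parameters or a TFX IR proto.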
  def setUp(self):
    super(BaseComponentTest, self).setUp()
    examples = standard_artifacts.ExternalArtifact()
    example_gen = csv_example_gen_component.CsvExampleGen(
        input=channel_utils.as_channel([examples]))
    statistics_gen = statistics_gen_component.StatisticsGen(
        examples=example_gen.outputs['examples'], instance_name='foo')

    pipeline = tfx_pipeline.Pipeline(
        pipeline_name=self._test_pipeline_name,
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )

    test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')

    self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
    with dsl.Pipeline('test_pipeline'):
      self.component = base_component.BaseComponent(
          component=statistics_gen,
          component_launcher_class=in_process_component_launcher
          .InProcessComponentLauncher,
          depends_on=set(),
          pipeline=pipeline,
          pipeline_name=self._test_pipeline_name,
          pipeline_root=test_pipeline_root,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
          component_config=None,
      )
    self.tfx_component = statistics_gen
Example #7
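An earlier form of the same fixture, using the old input_base/input_data keyword names and attribute-style output access.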
  def setUp(self):
    super(BaseComponentTest, self).setUp()
    examples = standard_artifacts.ExternalArtifact()
    example_gen = csv_example_gen_component.CsvExampleGen(
        input_base=channel_utils.as_channel([examples]))
    statistics_gen = statistics_gen_component.StatisticsGen(
        input_data=example_gen.outputs.examples, instance_name='foo')

    pipeline = tfx_pipeline.Pipeline(
        pipeline_name='test_pipeline',
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )

    self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
    with dsl.Pipeline('test_pipeline'):
      self.component = base_component.BaseComponent(
          component=statistics_gen,
          depends_on=set(),
          pipeline=pipeline,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
      )
    self.tfx_component = statistics_gen
Example #8
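Construction test that compares the output type against standard_artifacts.ExampleStatistics.TYPE_NAME instead of a hard-coded string.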
 def testConstruct(self):
     examples = standard_artifacts.Examples()
     examples.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     statistics_gen = component.StatisticsGen(
         examples=channel_utils.as_channel([examples]))
     self.assertEqual(standard_artifacts.ExampleStatistics.TYPE_NAME,
                      statistics_gen.outputs['statistics'].type_name)
Example #9
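A two-step BigQuery-to-statistics pipeline helper written against the old input_data attribute API.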
def _two_step_pipeline() -> tfx_pipeline.Pipeline:
    example_gen = big_query_example_gen_component.BigQueryExampleGen(
        query='SELECT * FROM TABLE')
    statistics_gen = statistics_gen_component.StatisticsGen(
        input_data=example_gen.outputs.examples)
    return tfx_pipeline.Pipeline(
        pipeline_name='two_step_pipeline',
        pipeline_root='pipeline_root',
        components=[example_gen, statistics_gen],
    )
Example #10
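The same two-step pipeline on the newer keyword API, with an explicit metadata connection config.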
def _two_step_pipeline() -> tfx_pipeline.Pipeline:
  example_gen = big_query_example_gen_component.BigQueryExampleGen(
      query='SELECT * FROM TABLE')
  statistics_gen = statistics_gen_component.StatisticsGen(
      examples=example_gen.outputs['examples'])
  return tfx_pipeline.Pipeline(
      pipeline_name='two_step_pipeline',
      pipeline_root='pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[example_gen, statistics_gen],
  )
Example #11
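Kubeflow fixture in which the ExampleGen output split name is supplied by a RuntimeStringParameter.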
    def setUp(self):
        super(BaseComponentWithPipelineParamTest, self).setUp()

        test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
        example_gen_output_name = runtime_string_parameter.RuntimeStringParameter(
            name='example-gen-output-name', default='default-to-be-discarded')

        examples = standard_artifacts.ExternalArtifact()
        example_gen = csv_example_gen_component.CsvExampleGen(
            input=channel_utils.as_channel([examples]),
            output_config=example_gen_pb2.Output(
                split_config=example_gen_pb2.SplitConfig(splits=[
                    example_gen_pb2.SplitConfig.Split(
                        name=example_gen_output_name, hash_buckets=10)
                ])))
        statistics_gen = statistics_gen_component.StatisticsGen(
            examples=example_gen.outputs['examples'], instance_name='foo')

        pipeline = tfx_pipeline.Pipeline(
            pipeline_name=self._test_pipeline_name,
            pipeline_root='test_pipeline_root',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[example_gen, statistics_gen],
        )

        self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
        with dsl.Pipeline('test_pipeline'):
            self.example_gen = base_component.BaseComponent(
                component=example_gen,
                component_launcher_class=in_process_component_launcher.
                InProcessComponentLauncher,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_name=self._test_pipeline_name,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                component_config=None)
            self.statistics_gen = base_component.BaseComponent(
                component=statistics_gen,
                component_launcher_class=in_process_component_launcher.
                InProcessComponentLauncher,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_name=self._test_pipeline_name,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                component_config=None,
            )

        self.tfx_example_gen = example_gen
        self.tfx_statistics_gen = statistics_gen
Example #12
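Verifies that the exclude_splits argument is serialized into the component's exec_properties as a JSON list.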
 def testConstruct(self):
     examples = standard_artifacts.Examples()
     examples.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     exclude_splits = ['eval']
     statistics_gen = component.StatisticsGen(
         examples=channel_utils.as_channel([examples]),
         exclude_splits=exclude_splits)
     self.assertEqual(standard_artifacts.ExampleStatistics.TYPE_NAME,
                      statistics_gen.outputs['statistics'].type_name)
     self.assertEqual(statistics_gen.spec.exec_properties['exclude_splits'],
                      '["eval"]')
Example #13
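Airflow runner test: it mocks the Airflow pipeline and component classes, runs a two-component pipeline, and asserts that the mocks were invoked with the merged pipeline and Airflow configs.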
    def test_airflow_runner(self, mock_airflow_pipeline_class,
                            mock_airflow_component_class):
        mock_airflow_pipeline_class.return_value = 'DAG'

        c1 = stats_gen_component.StatisticsGen(input_data=channel.Channel(
            type_name='ExamplesPath'))
        c2 = schema_component.SchemaGen(stats=channel.Channel(
            type_name='ExampleStatisticsPath'))
        airflow_config = {
            'schedule_interval': '* * * * *',
            'start_date': datetime.datetime(2019, 1, 1)
        }
        pipeline_config = {
            'pipeline_name': 'chicago_taxi_gcp',
            'log_root': '/var/tmp/tfx/logs',
            'metadata_db_root': '/var/tmp/tfx/metadata',
            'pipeline_root': '/var/tmp/tfx/pipelines'
        }
        # Simulate the runner's call to pipeline
        combined_config = pipeline_config.copy()
        combined_config.update(airflow_config)

        tfx_pipeline = pipeline.Pipeline(a='a', b='b', **pipeline_config)
        tfx_pipeline.components = [c1, c2]

        tfx_runner = airflow_runner.AirflowDAGRunner(airflow_config)
        tfx_runner.run(tfx_pipeline)

        mock_airflow_pipeline_class.assert_called_with(a='a',
                                                       b='b',
                                                       **combined_config)

        component_calls = [
            mock.call('DAG',
                      component_name=c1.component_name,
                      unique_name=c1.unique_name,
                      driver=c1.driver,
                      executor=c1.executor,
                      input_dict=mock.ANY,
                      output_dict=mock.ANY,
                      exec_properties=c1.exec_properties),
            mock.call('DAG',
                      component_name=c2.component_name,
                      unique_name=c2.unique_name,
                      driver=c2.driver,
                      executor=c2.executor,
                      input_dict=mock.ANY,
                      output_dict=mock.ANY,
                      exec_properties=c2.exec_properties)
        ]
        mock_airflow_component_class.assert_has_calls(component_calls,
                                                      any_order=True)
Example #14
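Two-step pipeline whose BigQuery table name is a RuntimeParameter interpolated into the query string.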
def _two_step_pipeline() -> tfx_pipeline.Pipeline:
  table_name = data_types.RuntimeParameter(
      name='table-name', ptype=Text, default='default-table')
  example_gen = big_query_example_gen_component.BigQueryExampleGen(
      query='SELECT * FROM %s' % str(table_name))
  statistics_gen = statistics_gen_component.StatisticsGen(
      examples=example_gen.outputs['examples'])
  return tfx_pipeline.Pipeline(
      pipeline_name='two_step_pipeline',
      pipeline_root='pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[example_gen, statistics_gen],
  )
Example #15
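Constructs StatisticsGen with both a schema channel and tfdv.StatsOptions, reading the output via the standard_component_specs.STATISTICS_KEY constant.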
 def testConstructWithSchemaAndStatsOptions(self):
     examples = standard_artifacts.Examples()
     examples.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     schema = standard_artifacts.Schema()
     stats_options = tfdv.StatsOptions(weight_feature='weight')
     statistics_gen = component.StatisticsGen(
         examples=channel_utils.as_channel([examples]),
         schema=channel_utils.as_channel([schema]),
         stats_options=stats_options)
     self.assertEqual(
         standard_artifacts.ExampleStatistics.TYPE_NAME,
         statistics_gen.outputs[
             standard_component_specs.STATISTICS_KEY].type_name)
Example #16
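Kubeflow fixture that passes a RuntimeParameter as the ExampleGen output config and forwards it through the runtime_parameters argument of the wrapped component.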
    def setUp(self):
        super().setUp()

        example_gen_output_config = data_types.RuntimeParameter(
            name='example-gen-output-config', ptype=str)

        example_gen = csv_example_gen_component.CsvExampleGen(
            input_base='data_root', output_config=example_gen_output_config)
        statistics_gen = statistics_gen_component.StatisticsGen(
            examples=example_gen.outputs['examples']).with_id('foo')

        test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
        pipeline = tfx_pipeline.Pipeline(
            pipeline_name=self._test_pipeline_name,
            pipeline_root='test_pipeline_root',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[example_gen, statistics_gen],
        )

        self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
        self._tfx_ir = pipeline_pb2.Pipeline()
        with dsl.Pipeline('test_pipeline'):
            self.example_gen = base_component.BaseComponent(
                component=example_gen,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                tfx_ir=self._tfx_ir,
                pod_labels_to_attach={},
                runtime_parameters=[example_gen_output_config])
            self.statistics_gen = base_component.BaseComponent(
                component=statistics_gen,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                tfx_ir=self._tfx_ir,
                pod_labels_to_attach={},
                runtime_parameters=[])

        self.tfx_example_gen = example_gen
        self.tfx_statistics_gen = statistics_gen
Example #17
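Like Example #15, but the StatsOptions carries a LiftStatsGenerator; per the inline comment, generators are expected to be dropped during serialization.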
 def testConstructWithSchemaAndStatsOptions(self):
     examples = standard_artifacts.Examples()
     examples.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     schema = standard_artifacts.Schema()
     stats_options = tfdv.StatsOptions(
         weight_feature='weight',
         generators=[  # generators should be dropped
             tfdv.LiftStatsGenerator(
                 schema=None,
                 y_path=tfdv.FeaturePath(['label']),
                 x_paths=[tfdv.FeaturePath(['feature'])])
         ])
     statistics_gen = component.StatisticsGen(
         examples=channel_utils.as_channel([examples]),
         schema=channel_utils.as_channel([schema]),
         stats_options=stats_options)
     self.assertEqual(standard_artifacts.ExampleStatistics.TYPE_NAME,
                      statistics_gen.outputs['statistics'].type_name)
Example #18
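Two-step pipeline whose BigQuery input config is a RuntimeParameter carrying a JSON-encoded split specification.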
def _two_step_pipeline() -> tfx_pipeline.Pipeline:
    default_input_config = json.dumps({
        'splits': [{
            'name': 'single_split',
            'pattern': 'SELECT * FROM default-table'
        }]
    })
    input_config = data_types.RuntimeParameter(name='input_config',
                                               ptype=str,
                                               default=default_input_config)
    example_gen = big_query_example_gen_component.BigQueryExampleGen(
        input_config=input_config, output_config=example_gen_pb2.Output())
    statistics_gen = statistics_gen_component.StatisticsGen(
        examples=example_gen.outputs['examples'])
    return tfx_pipeline.Pipeline(
        pipeline_name='two_step_pipeline',
        pipeline_root='pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )
Example #19
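A runner wrapper that forwards a parsed command-line input_data argument to a StatisticsGen component.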
 def __init__(self, args):
   component = statistics_gen_component.StatisticsGen(
       channel.Channel('ExamplesPath'))
   super(StatisticsGenRunner, self).__init__(args, component,
                                             {"input_data": args.input_data})
Example #20
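The same wrapper pattern taking a dsl.PipelineParam rather than parsed command-line arguments.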
 def __init__(self, input_data: dsl.PipelineParam):
     component = statistics_gen_component.StatisticsGen(
         channel.Channel('ExamplesPath'))
     super().__init__(component, {"input_data": input_data})
Example #21
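A minimal two-step pipeline helper on the old attribute API that returns the component list directly rather than a Pipeline object.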
def _two_step_pipeline():
    example_gen = big_query_example_gen_component.BigQueryExampleGen(
        query='SELECT * FROM TABLE')
    statistics_gen = statistics_gen_component.StatisticsGen(
        input_data=example_gen.outputs.examples)
    return [example_gen, statistics_gen]
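
Distilled from the newer snippets above, the pattern reduces to the following minimal sketch. It is not taken verbatim from any one example: it assumes a recent TFX release exposing the tfx.v1 API, and the pipeline name, data root, and pipeline root are placeholder values.

from tfx import v1 as tfx

def _stats_pipeline(data_root: str, pipeline_root: str) -> tfx.dsl.Pipeline:
  # CsvExampleGen reads CSV files under data_root and emits Examples artifacts.
  example_gen = tfx.components.CsvExampleGen(input_base=data_root)
  # StatisticsGen consumes the Examples channel and emits ExampleStatistics.
  statistics_gen = tfx.components.StatisticsGen(
      examples=example_gen.outputs['examples'])
  return tfx.dsl.Pipeline(
      pipeline_name='stats_demo',  # hypothetical name
      pipeline_root=pipeline_root,
      components=[example_gen, statistics_gen],
  )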