예제 #1
0
 def testEnableCache(self):
   """Verifies `enable_cache` is None when omitted and True when passed."""
   external_artifact = standard_artifacts.ExternalArtifact()

   # Constructed without `enable_cache`: the attribute should come back None.
   default_channel = channel_utils.as_channel([external_artifact])
   default_gen = component.CsvExampleGen(input=default_channel)
   self.assertEqual(None, default_gen.enable_cache)

   # Constructed with `enable_cache=True`: the attribute should come back True.
   cached_channel = channel_utils.as_channel([external_artifact])
   cached_gen = component.CsvExampleGen(input=cached_channel, enable_cache=True)
   self.assertEqual(True, cached_gen.enable_cache)
예제 #2
0
  def setUp(self):
    """Wraps a StatisticsGen in a Kubeflow `BaseComponent` for the tests.

    A two-node pipeline (CsvExampleGen feeding StatisticsGen) is assembled and
    the StatisticsGen node is wrapped inside a `dsl.Pipeline` context. The
    wrapper is stored on `self.component`, the wrapped TFX component on
    `self.tfx_component`, and the metadata config on `self._metadata_config`.
    """
    super(BaseComponentTest, self).setUp()

    # TFX components that make up the pipeline.
    external_artifact = standard_artifacts.ExternalArtifact()
    external_channel = channel_utils.as_channel([external_artifact])
    csv_gen = csv_example_gen_component.CsvExampleGen(
        input_base=external_channel)
    stats_gen = statistics_gen_component.StatisticsGen(
        input_data=csv_gen.outputs.examples, instance_name='foo')

    connection_config = metadata_store_pb2.ConnectionConfig()
    test_pipeline = tfx_pipeline.Pipeline(
        pipeline_name='test_pipeline',
        pipeline_root='test_pipeline_root',
        metadata_connection_config=connection_config,
        components=[csv_gen, stats_gen],
    )

    # Point the MySQL service host at the 'MYSQL_SERVICE_HOST' env variable.
    metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    host_field = metadata_config.mysql_db_service_host
    host_field.environment_variable = 'MYSQL_SERVICE_HOST'
    self._metadata_config = metadata_config

    with dsl.Pipeline('test_pipeline'):
      self.component = base_component.BaseComponent(
          component=stats_gen,
          depends_on=set(),
          pipeline=test_pipeline,
          tfx_image='container_image',
          kubeflow_metadata_config=metadata_config,
      )
    self.tfx_component = stats_gen
예제 #3
0
  def setUp(self):
    """Wraps a StatisticsGen in a Kubeflow `BaseComponent` for the tests.

    Builds a CsvExampleGen -> StatisticsGen pipeline named after
    `self._test_pipeline_name`, then wraps the StatisticsGen node using the
    in-process launcher class and a `dsl.PipelineParam` as the pipeline root.
    Results are stored on `self.component`, `self.tfx_component` and
    `self._metadata_config`.
    """
    super(BaseComponentTest, self).setUp()

    # TFX components that make up the pipeline.
    external_channel = channel_utils.as_channel(
        [standard_artifacts.ExternalArtifact()])
    csv_gen = csv_example_gen_component.CsvExampleGen(input=external_channel)
    stats_gen = statistics_gen_component.StatisticsGen(
        examples=csv_gen.outputs['examples'], instance_name='foo')

    test_pipeline = tfx_pipeline.Pipeline(
        pipeline_name=self._test_pipeline_name,
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[csv_gen, stats_gen],
    )

    # The wrapper receives the root as a pipeline parameter, not the literal
    # string given to the TFX pipeline above.
    root_param = dsl.PipelineParam(name='pipeline-root-param')

    # Point the MySQL service host at the 'MYSQL_SERVICE_HOST' env variable.
    metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    host_field = metadata_config.mysql_db_service_host
    host_field.environment_variable = 'MYSQL_SERVICE_HOST'
    self._metadata_config = metadata_config

    launcher_class = in_process_component_launcher.InProcessComponentLauncher
    with dsl.Pipeline('test_pipeline'):
      self.component = base_component.BaseComponent(
          component=stats_gen,
          component_launcher_class=launcher_class,
          depends_on=set(),
          pipeline=test_pipeline,
          pipeline_name=self._test_pipeline_name,
          pipeline_root=root_param,
          tfx_image='container_image',
          kubeflow_metadata_config=metadata_config,
          component_config=None,
      )
    self.tfx_component = stats_gen
예제 #4
0
    def setUp(self):
        """Wraps a StatisticsGen in a Kubeflow `BaseComponent` for the tests.

        Builds a CsvExampleGen -> StatisticsGen pipeline (StatisticsGen is
        given the id 'foo'), then wraps the StatisticsGen node with an empty
        `pipeline_pb2.Pipeline` as its IR and a `dsl.PipelineParam` as the
        pipeline root. Results are stored on `self.component`,
        `self.tfx_component`, `self._metadata_config` and `self._tfx_ir`.
        """
        super(BaseComponentTest, self).setUp()

        # TFX components that make up the pipeline.
        csv_gen = csv_example_gen_component.CsvExampleGen(
            input_base='data_input')
        stats_gen = statistics_gen_component.StatisticsGen(
            examples=csv_gen.outputs['examples'])
        stats_gen = stats_gen.with_id('foo')

        test_pipeline = tfx_pipeline.Pipeline(
            pipeline_name=self._test_pipeline_name,
            pipeline_root='test_pipeline_root',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[csv_gen, stats_gen],
        )

        # The wrapper receives the root as a pipeline parameter, not the
        # literal string given to the TFX pipeline above.
        root_param = dsl.PipelineParam(name='pipeline-root-param')

        # Point the MySQL service host at the 'MYSQL_SERVICE_HOST' env var.
        metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        host_field = metadata_config.mysql_db_service_host
        host_field.environment_variable = 'MYSQL_SERVICE_HOST'
        self._metadata_config = metadata_config

        self._tfx_ir = pipeline_pb2.Pipeline()
        with dsl.Pipeline('test_pipeline'):
            self.component = base_component.BaseComponent(
                component=stats_gen,
                depends_on=set(),
                pipeline=test_pipeline,
                pipeline_root=root_param,
                tfx_image='container_image',
                kubeflow_metadata_config=metadata_config,
                tfx_ir=self._tfx_ir,
            )
        self.tfx_component = stats_gen
예제 #5
0
  def setUp(self):
    """Wraps both pipeline nodes in Kubeflow `BaseComponent`s for the tests.

    The CsvExampleGen output config carries a `RuntimeParameter` as the
    `hash_buckets` value of its single split. Both CsvExampleGen and
    StatisticsGen are wrapped inside one `dsl.Pipeline` context; the wrappers
    are stored on `self.example_gen` / `self.statistics_gen` and the wrapped
    TFX components on `self.tfx_example_gen` / `self.tfx_statistics_gen`.
    """
    super(BaseComponentWithPipelineParamTest, self).setUp()

    root_param = dsl.PipelineParam(name='pipeline-root-param')
    bucket_param = data_types.RuntimeParameter(
        name='example-gen-buckets', ptype=int, default=10)

    # Output config given as a plain dict, with the runtime parameter embedded
    # as the split's bucket count.
    external_artifact = standard_artifacts.ExternalArtifact()
    split_spec = {'name': 'examples', 'hash_buckets': bucket_param}
    output_config = {'split_config': {'splits': [split_spec]}}
    csv_gen = csv_example_gen_component.CsvExampleGen(
        input=channel_utils.as_channel([external_artifact]),
        output_config=output_config)
    stats_gen = statistics_gen_component.StatisticsGen(
        examples=csv_gen.outputs['examples'], instance_name='foo')

    test_pipeline = tfx_pipeline.Pipeline(
        pipeline_name=self._test_pipeline_name,
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[csv_gen, stats_gen],
    )

    # Point the MySQL service host at the 'MYSQL_SERVICE_HOST' env variable.
    metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    host_field = metadata_config.mysql_db_service_host
    host_field.environment_variable = 'MYSQL_SERVICE_HOST'
    self._metadata_config = metadata_config
    self._tfx_ir = pipeline_pb2.Pipeline()

    # Keyword arguments that are identical for both wrappers. `depends_on` is
    # left out so that each wrapper still receives its own fresh set.
    launcher_class = in_process_component_launcher.InProcessComponentLauncher
    shared_kwargs = dict(
        component_launcher_class=launcher_class,
        pipeline=test_pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=root_param,
        tfx_image='container_image',
        kubeflow_metadata_config=metadata_config,
        component_config=None,
        tfx_ir=self._tfx_ir,
    )
    with dsl.Pipeline('test_pipeline'):
      self.example_gen = base_component.BaseComponent(
          component=csv_gen, depends_on=set(), **shared_kwargs)
      self.statistics_gen = base_component.BaseComponent(
          component=stats_gen, depends_on=set(), **shared_kwargs)

    self.tfx_example_gen = csv_gen
    self.tfx_statistics_gen = stats_gen
예제 #6
0
 def testConstruct(self):
     """Checks the examples output type name and its train/eval split order."""
     external_artifact = standard_artifacts.ExternalArtifact()
     external_channel = channel_utils.as_channel([external_artifact])
     csv_gen = component.CsvExampleGen(input_base=external_channel)

     examples_channel = csv_gen.outputs['examples']
     self.assertEqual('ExamplesPath', examples_channel.type_name)

     # Artifacts are expected in order: index 0 is 'train', index 1 is 'eval'.
     produced = examples_channel.get()
     for position, split_name in enumerate(('train', 'eval')):
         self.assertEqual(split_name, produced[position].split)
예제 #7
0
 def test_construct(self):
     """Checks the examples output type name and its train/eval split order."""
     external_path = types.TfxType(type_name='ExternalPath')
     external_channel = channel.as_channel([external_path])
     csv_gen = component.CsvExampleGen(input_base=external_channel)

     examples_channel = csv_gen.outputs.examples
     self.assertEqual('ExamplesPath', examples_channel.type_name)

     # Artifacts are expected in order: index 0 is 'train', index 1 is 'eval'.
     produced = examples_channel.get()
     for position, split_name in enumerate(('train', 'eval')):
         self.assertEqual(split_name, produced[position].split)
예제 #8
0
 def testConstruct(self):
   """Checks the output type, artifact count and decoded split names."""
   csv_gen = component.CsvExampleGen(input_base='path')

   examples_channel = csv_gen.outputs['examples']
   self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                    examples_channel.type_name)

   # Exactly one artifact is expected; its `split_names` should decode to
   # the train and eval splits, in that order.
   produced = examples_channel.get()
   self.assertEqual(1, len(produced))
   decoded_splits = artifact_utils.decode_split_names(produced[0].split_names)
   self.assertEqual(['train', 'eval'], decoded_splits)
예제 #9
0
    def setUp(self):
        """Wraps both pipeline nodes in Kubeflow `BaseComponent`s for the tests.

        The CsvExampleGen output config is an `example_gen_pb2.Output` whose
        single split takes a `RuntimeStringParameter` as its name. Both
        CsvExampleGen and StatisticsGen are wrapped inside one `dsl.Pipeline`
        context; the wrappers are stored on `self.example_gen` /
        `self.statistics_gen` and the wrapped TFX components on
        `self.tfx_example_gen` / `self.tfx_statistics_gen`.
        """
        super(BaseComponentWithPipelineParamTest, self).setUp()

        root_param = dsl.PipelineParam(name='pipeline-root-param')
        split_name_param = runtime_string_parameter.RuntimeStringParameter(
            name='example-gen-output-name', default='default-to-be-discarded')

        # Output config built inside-out: split -> split config -> output.
        external_artifact = standard_artifacts.ExternalArtifact()
        external_channel = channel_utils.as_channel([external_artifact])
        named_split = example_gen_pb2.SplitConfig.Split(
            name=split_name_param, hash_buckets=10)
        split_config = example_gen_pb2.SplitConfig(splits=[named_split])
        csv_gen = csv_example_gen_component.CsvExampleGen(
            input=external_channel,
            output_config=example_gen_pb2.Output(split_config=split_config))
        stats_gen = statistics_gen_component.StatisticsGen(
            examples=csv_gen.outputs['examples'], instance_name='foo')

        test_pipeline = tfx_pipeline.Pipeline(
            pipeline_name=self._test_pipeline_name,
            pipeline_root='test_pipeline_root',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[csv_gen, stats_gen],
        )

        # Point the MySQL service host at the 'MYSQL_SERVICE_HOST' env var.
        metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        host_field = metadata_config.mysql_db_service_host
        host_field.environment_variable = 'MYSQL_SERVICE_HOST'
        self._metadata_config = metadata_config

        # Keyword arguments that are identical for both wrappers.
        # `depends_on` is left out so each wrapper still gets its own set.
        shared_kwargs = dict(
            component_launcher_class=(
                in_process_component_launcher.InProcessComponentLauncher),
            pipeline=test_pipeline,
            pipeline_name=self._test_pipeline_name,
            pipeline_root=root_param,
            tfx_image='container_image',
            kubeflow_metadata_config=metadata_config,
            component_config=None,
        )
        with dsl.Pipeline('test_pipeline'):
            self.example_gen = base_component.BaseComponent(
                component=csv_gen, depends_on=set(), **shared_kwargs)
            self.statistics_gen = base_component.BaseComponent(
                component=stats_gen, depends_on=set(), **shared_kwargs)

        self.tfx_example_gen = csv_gen
        self.tfx_statistics_gen = stats_gen
예제 #10
0
 def test_construct_with_output_config(self):
     """Checks output type and split order for a custom three-way split."""
     external_path = types.TfxType(type_name='ExternalPath')
     external_channel = channel.as_channel([external_path])

     # Three splits with hash buckets train=2, eval=1, test=1.
     split_specs = [
         example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
         example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
         example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1),
     ]
     custom_output = example_gen_pb2.Output(
         split_config=example_gen_pb2.SplitConfig(splits=split_specs))
     csv_gen = component.CsvExampleGen(
         input_base=external_channel, output_config=custom_output)

     examples_channel = csv_gen.outputs.examples
     self.assertEqual('ExamplesPath', examples_channel.type_name)

     # Output artifacts are expected in the same order as the configured
     # splits.
     produced = examples_channel.get()
     for position, split_name in enumerate(('train', 'eval', 'test')):
         self.assertEqual(split_name, produced[position].split)
예제 #11
0
    def setUp(self):
        """Wraps both pipeline nodes in Kubeflow `BaseComponent`s for the tests.

        The whole CsvExampleGen output config is a string-typed
        `RuntimeParameter`, which is also passed to the CsvExampleGen wrapper
        via `runtime_parameters`; the StatisticsGen wrapper gets an empty list.
        The wrappers are stored on `self.example_gen` / `self.statistics_gen`
        and the wrapped TFX components on `self.tfx_example_gen` /
        `self.tfx_statistics_gen`.
        """
        super().setUp()

        output_config_param = data_types.RuntimeParameter(
            name='example-gen-output-config', ptype=str)

        csv_gen = csv_example_gen_component.CsvExampleGen(
            input_base='data_root', output_config=output_config_param)
        stats_gen = statistics_gen_component.StatisticsGen(
            examples=csv_gen.outputs['examples'])
        stats_gen = stats_gen.with_id('foo')

        # The wrappers receive the root as a pipeline parameter, not the
        # literal string given to the TFX pipeline below.
        root_param = dsl.PipelineParam(name='pipeline-root-param')
        test_pipeline = tfx_pipeline.Pipeline(
            pipeline_name=self._test_pipeline_name,
            pipeline_root='test_pipeline_root',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[csv_gen, stats_gen],
        )

        # Point the MySQL service host at the 'MYSQL_SERVICE_HOST' env var.
        metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        host_field = metadata_config.mysql_db_service_host
        host_field.environment_variable = 'MYSQL_SERVICE_HOST'
        self._metadata_config = metadata_config
        self._tfx_ir = pipeline_pb2.Pipeline()

        # Keyword arguments that are identical for both wrappers. Mutable
        # containers (`depends_on`, `pod_labels_to_attach`,
        # `runtime_parameters`) stay per-call so no object is shared.
        shared_kwargs = dict(
            pipeline=test_pipeline,
            pipeline_root=root_param,
            tfx_image='container_image',
            kubeflow_metadata_config=metadata_config,
            tfx_ir=self._tfx_ir,
        )
        with dsl.Pipeline('test_pipeline'):
            self.example_gen = base_component.BaseComponent(
                component=csv_gen,
                depends_on=set(),
                pod_labels_to_attach={},
                runtime_parameters=[output_config_param],
                **shared_kwargs)
            self.statistics_gen = base_component.BaseComponent(
                component=stats_gen,
                depends_on=set(),
                pod_labels_to_attach={},
                runtime_parameters=[],
                **shared_kwargs)

        self.tfx_example_gen = csv_gen
        self.tfx_statistics_gen = stats_gen
예제 #12
0
 def test_construct(self):
     """Checks that the examples output channel has type 'ExamplesPath'."""
     external_path = types.TfxType(type_name='ExternalPath')
     external_channel = channel.as_channel([external_path])
     csv_gen = component.CsvExampleGen(input_base=external_channel)

     output_type_name = csv_gen.outputs.examples.type_name
     self.assertEqual('ExamplesPath', output_type_name)
예제 #13
0
 def testConstruct(self):
   """Checks the examples output type and that it holds one artifact."""
   csv_gen = component.CsvExampleGen(input_base='path')

   examples_channel = csv_gen.outputs['examples']
   self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                    examples_channel.type_name)

   # A single artifact is expected on the examples channel.
   produced = examples_channel.get()
   self.assertEqual(1, len(produced))