Esempio n. 1
0
 def testEnableCache(self):
   """Checks `enable_cache` is None when omitted and True when passed."""
   train_stats = standard_artifacts.ExampleStatistics()
   train_stats.split_names = artifact_utils.encode_split_names(['train'])

   # First component leaves the flag unset; the second switches it on.
   default_gen = component.SchemaGen(
       statistics=channel_utils.as_channel([train_stats]))
   cached_gen = component.SchemaGen(
       statistics=channel_utils.as_channel([train_stats]), enable_cache=True)

   for expected, schema_gen in ((None, default_gen), (True, cached_gen)):
     self.assertEqual(expected, schema_gen.enable_cache)
Esempio n. 2
0
 def testConstruct(self):
     """Builds SchemaGen from 'train' statistics and checks its defaults.

     The 'schema' output must carry the Schema artifact type name, and
     `infer_feature_shape` must be falsy when it is not supplied.
     """
     train_stats = standard_artifacts.ExampleStatistics(split='train')
     stats_channel = channel_utils.as_channel([train_stats])
     schema_gen = component.SchemaGen(statistics=stats_channel)

     expected_type = standard_artifacts.Schema.TYPE_NAME
     self.assertEqual(expected_type, schema_gen.outputs['schema'].type_name)

     shape_flag = schema_gen.spec.exec_properties['infer_feature_shape']
     self.assertFalse(shape_flag)
Esempio n. 3
0
 def testConstruct(self):
   """Builds SchemaGen with `infer_feature_shape=True` and checks the result.

   The `output` channel must report type 'SchemaPath' and the stored
   `infer_feature_shape` exec property must be truthy.
   """
   train_stats = standard_artifacts.ExampleStatistics(split='train')
   schema_gen = component.SchemaGen(
       stats=channel_utils.as_channel([train_stats]),
       infer_feature_shape=True)

   output_type = schema_gen.outputs.output.type_name
   self.assertEqual('SchemaPath', output_type)

   shape_flag = schema_gen.spec.exec_properties['infer_feature_shape']
   self.assertTrue(shape_flag)
Esempio n. 4
0
 def testConstructWithParameter(self):
   """Builds SchemaGen with a RuntimeParameter as `infer_feature_shape`.

   Checks the 'schema' output type name and that the stored exec property
   stringifies to the same JSON as the parameter that was passed in.
   """
   shape_param = data_types.RuntimeParameter(name='infer-shape', ptype=bool)
   train_stats = standard_artifacts.ExampleStatistics(split='train')
   schema_gen = component.SchemaGen(
       statistics=channel_utils.as_channel([train_stats]),
       infer_feature_shape=shape_param)

   self.assertEqual('SchemaPath', schema_gen.outputs['schema'].type_name)

   stored_param = schema_gen.spec.exec_properties['infer_feature_shape']
   self.assertJsonEqual(str(stored_param), str(shape_param))
Esempio n. 5
0
    def test_airflow_runner(self, mock_airflow_pipeline_class,
                            mock_airflow_component_class):
        """Runs a two-component pipeline through AirflowDAGRunner.

        Verifies that the mocked pipeline class is called with the pipeline
        and Airflow configs merged, and that the mocked component class is
        called once per component with that component's attributes.

        Args:
            mock_airflow_pipeline_class: mock whose return value is set to
                'DAG' here (presumably injected by a patch decorator outside
                this view).
            mock_airflow_component_class: mock whose calls are checked
                against the expected per-component calls (presumably
                injected the same way).
        """
        mock_airflow_pipeline_class.return_value = 'DAG'

        stats_gen = stats_gen_component.StatisticsGen(
            input_data=channel.Channel(type_name='ExamplesPath'))
        schema_gen = schema_component.SchemaGen(
            stats=channel.Channel(type_name='ExampleStatisticsPath'))

        airflow_config = {
            'schedule_interval': '* * * * *',
            'start_date': datetime.datetime(2019, 1, 1),
        }
        pipeline_config = {
            'pipeline_name': 'chicago_taxi_gcp',
            'log_root': '/var/tmp/tfx/logs',
            'metadata_db_root': 'var/tmp/tfx//metadata',
            'pipeline_root': '/var/tmp/tfx/pipelines',
        }

        # Expected keyword arguments for the pipeline class: the pipeline
        # settings overlaid with the Airflow settings.
        expected_config = dict(pipeline_config)
        expected_config.update(airflow_config)

        tfx_pipeline = pipeline.Pipeline(a='a', b='b', **pipeline_config)
        tfx_pipeline.components = [stats_gen, schema_gen]

        airflow_runner.AirflowDAGRunner(airflow_config).run(tfx_pipeline)

        mock_airflow_pipeline_class.assert_called_with(
            a='a', b='b', **expected_config)

        # One expected call per component; input/output dicts are not pinned.
        expected_calls = [
            mock.call(
                'DAG',
                component_name=tfx_component.component_name,
                unique_name=tfx_component.unique_name,
                driver=tfx_component.driver,
                executor=tfx_component.executor,
                input_dict=mock.ANY,
                output_dict=mock.ANY,
                exec_properties=tfx_component.exec_properties)
            for tfx_component in (stats_gen, schema_gen)
        ]
        mock_airflow_component_class.assert_has_calls(
            expected_calls, any_order=True)
Esempio n. 6
0
 def testConstructWithParameter(self):
   """Builds SchemaGen with a RuntimeParameter as `infer_feature_shape`.

   Checks that the 'schema' output carries the Schema artifact type name
   and that the stored exec property stringifies to the same JSON as the
   parameter that was passed in.
   """
   train_stats = standard_artifacts.ExampleStatistics()
   train_stats.split_names = artifact_utils.encode_split_names(['train'])
   shape_param = data_types.RuntimeParameter(name='infer-shape', ptype=bool)

   schema_gen = component.SchemaGen(
       statistics=channel_utils.as_channel([train_stats]),
       infer_feature_shape=shape_param)

   expected_type = standard_artifacts.Schema.TYPE_NAME
   self.assertEqual(expected_type, schema_gen.outputs['schema'].type_name)

   stored_param = schema_gen.spec.exec_properties['infer_feature_shape']
   self.assertJsonEqual(str(stored_param), str(shape_param))
Esempio n. 7
0
 def testConstruct(self):
     """Builds SchemaGen with `exclude_splits` and checks its spec.

     The 'schema' output must carry the Schema artifact type name,
     `infer_feature_shape` must be truthy, and `exclude_splits` must be
     stored as the string '["eval"]'.
     """
     stats = standard_artifacts.ExampleStatistics()
     stats.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     splits_to_skip = ['eval']

     schema_gen = component.SchemaGen(
         statistics=channel_utils.as_channel([stats]),
         exclude_splits=splits_to_skip)

     expected_type = standard_artifacts.Schema.TYPE_NAME
     self.assertEqual(expected_type, schema_gen.outputs['schema'].type_name)

     shape_flag = schema_gen.spec.exec_properties['infer_feature_shape']
     self.assertTrue(shape_flag)

     stored_splits = schema_gen.spec.exec_properties['exclude_splits']
     self.assertEqual(stored_splits, '["eval"]')
Esempio n. 8
0
 def testConstruct(self):
     """Builds SchemaGen with `exclude_splits` and checks its spec.

     Looks entries up through the `standard_component_specs` key constants:
     the schema output must carry the Schema artifact type name, the
     infer-feature-shape property must be truthy, and the exclude-splits
     property must be stored as the string '["eval"]'.
     """
     stats = standard_artifacts.ExampleStatistics()
     stats.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     splits_to_skip = ['eval']

     schema_gen = component.SchemaGen(
         statistics=channel_utils.as_channel([stats]),
         exclude_splits=splits_to_skip)

     expected_type = standard_artifacts.Schema.TYPE_NAME
     schema_output = schema_gen.outputs[standard_component_specs.SCHEMA_KEY]
     self.assertEqual(expected_type, schema_output.type_name)

     shape_flag = schema_gen.spec.exec_properties[
         standard_component_specs.INFER_FEATURE_SHAPE_KEY]
     self.assertTrue(shape_flag)

     stored_splits = schema_gen.spec.exec_properties[
         standard_component_specs.EXCLUDE_SPLITS_KEY]
     self.assertEqual(stored_splits, '["eval"]')
Esempio n. 9
0
 def __init__(self, stats: dsl.PipelineParam):
     """Creates a SchemaGen component and hands it to the base initializer.

     Args:
         stats: pipeline parameter forwarded to the base class under the
             "stats" key.
     """
     stats_channel = channel.Channel('ExampleStatisticsPath')
     schema_gen = schema_gen_component.SchemaGen(stats_channel)
     super().__init__(schema_gen, {"stats": stats})
Esempio n. 10
0
 def test_construct(self):
     """Builds SchemaGen from a 'train' statistics artifact.

     The `output` channel must report type 'SchemaPath'.
     """
     train_stats = types.TfxArtifact(
         type_name='ExampleStatisticsPath', split='train')
     stats_channel = channel.as_channel([train_stats])
     schema_gen = component.SchemaGen(stats=stats_channel)

     output_type = schema_gen.outputs.output.type_name
     self.assertEqual('SchemaPath', output_type)
Esempio n. 11
0
 def testConstructWithNeitherStatsNorSchema(self):
     """Constructing SchemaGen with no arguments must raise ValueError."""
     self.assertRaises(ValueError, component.SchemaGen)
Esempio n. 12
0
 def testConstructWithBothStatsAndSchema(self):
     """Passing both `stats` and `schema` to SchemaGen must raise ValueError."""
     with self.assertRaises(ValueError):
         # Channels are built inside the context so any ValueError raised
         # while preparing the arguments is covered as well.
         stats_channel = channel_utils.as_channel(
             [standard_artifacts.ExampleStatistics(split='train')])
         schema_channel = channel_utils.as_channel(
             [standard_artifacts.Schema()])
         component.SchemaGen(stats=stats_channel, schema=schema_channel)
Esempio n. 13
0
 def testConstructWithSchema(self):
     """Builds SchemaGen from a Schema artifact only.

     The `output` channel must report type 'SchemaPath'.
     """
     schema_channel = channel_utils.as_channel([standard_artifacts.Schema()])
     schema_gen = component.SchemaGen(schema=schema_channel)

     output_type = schema_gen.outputs.output.type_name
     self.assertEqual('SchemaPath', output_type)
Esempio n. 14
0
 def test_construct(self):
     """Builds SchemaGen from 'train' example statistics.

     The `output` channel must report type 'SchemaPath'.
     """
     train_stats = standard_artifacts.ExampleStatistics(split='train')
     stats_channel = channel_utils.as_channel([train_stats])
     schema_gen = component.SchemaGen(stats=stats_channel)

     output_type = schema_gen.outputs.output.type_name
     self.assertEqual('SchemaPath', output_type)