def testEnableCache(self):
    """enable_cache defaults to None and is stored when passed as True."""
    statistics_artifact = standard_artifacts.ExampleStatistics()
    statistics_artifact.split_names = artifact_utils.encode_split_names(
        ['train'])
    schema_gen_1 = component.SchemaGen(
        statistics=channel_utils.as_channel([statistics_artifact]))
    schema_gen_2 = component.SchemaGen(
        statistics=channel_utils.as_channel([statistics_artifact]),
        enable_cache=True)
    # assertIsNone / assertTrue are the idiomatic unittest assertions and
    # produce clearer failure messages than assertEqual(None, ...) and
    # assertEqual(True, ...).
    self.assertIsNone(schema_gen_1.enable_cache)
    self.assertTrue(schema_gen_2.enable_cache)
def testConstruct(self):
    """Default construction exposes a Schema output; shape inference is off."""
    stats_channel = channel_utils.as_channel(
        [standard_artifacts.ExampleStatistics(split='train')])
    schema_gen = component.SchemaGen(statistics=stats_channel)
    self.assertEqual(standard_artifacts.Schema.TYPE_NAME,
                     schema_gen.outputs['schema'].type_name)
    self.assertFalse(schema_gen.spec.exec_properties['infer_feature_shape'])
def testConstruct(self):
    """infer_feature_shape=True is recorded in the component's exec_properties."""
    stats_channel = channel_utils.as_channel(
        [standard_artifacts.ExampleStatistics(split='train')])
    schema_gen = component.SchemaGen(
        stats=stats_channel, infer_feature_shape=True)
    self.assertEqual('SchemaPath', schema_gen.outputs.output.type_name)
    self.assertTrue(schema_gen.spec.exec_properties['infer_feature_shape'])
def testConstructWithParameter(self):
    """A RuntimeParameter can be passed for infer_feature_shape."""
    shape_param = data_types.RuntimeParameter(name='infer-shape', ptype=bool)
    stats_channel = channel_utils.as_channel(
        [standard_artifacts.ExampleStatistics(split='train')])
    schema_gen = component.SchemaGen(
        statistics=stats_channel, infer_feature_shape=shape_param)
    self.assertEqual('SchemaPath', schema_gen.outputs['schema'].type_name)
    # The parameter is stored as-is; compare serialized representations.
    self.assertJsonEqual(
        str(schema_gen.spec.exec_properties['infer_feature_shape']),
        str(shape_param))
def test_airflow_runner(self, mock_airflow_pipeline_class,
                        mock_airflow_component_class):
    # Verifies that AirflowDAGRunner merges the pipeline config with the
    # airflow config when building the DAG, and wraps each TFX component
    # with the (mocked) airflow component class exactly once.
    mock_airflow_pipeline_class.return_value = 'DAG'
    c1 = stats_gen_component.StatisticsGen(input_data=channel.Channel(
        type_name='ExamplesPath'))
    c2 = schema_component.SchemaGen(stats=channel.Channel(
        type_name='ExampleStatisticsPath'))
    airflow_config = {
        'schedule_interval': '* * * * *',
        'start_date': datetime.datetime(2019, 1, 1)
    }
    pipeline_config = {
        'pipeline_name': 'chicago_taxi_gcp',
        'log_root': '/var/tmp/tfx/logs',
        'metadata_db_root': 'var/tmp/tfx//metadata',
        'pipeline_root': '/var/tmp/tfx/pipelines'
    }
    # Simulate the runner's call to pipeline
    combined_config = pipeline_config.copy()
    combined_config.update(airflow_config)
    tfx_pipeline = pipeline.Pipeline(a='a', b='b', **pipeline_config)
    tfx_pipeline.components = [c1, c2]
    tfx_runner = airflow_runner.AirflowDAGRunner(airflow_config)
    tfx_runner.run(tfx_pipeline)
    # The DAG must be constructed with the union of both configs; the extra
    # kwargs a/b must pass through unchanged.
    mock_airflow_pipeline_class.assert_called_with(a='a', b='b',
                                                   **combined_config)
    # input_dict/output_dict are opaque at this level, so only their
    # presence is asserted via mock.ANY.
    component_calls = [
        mock.call('DAG',
                  component_name=c1.component_name,
                  unique_name=c1.unique_name,
                  driver=c1.driver,
                  executor=c1.executor,
                  input_dict=mock.ANY,
                  output_dict=mock.ANY,
                  exec_properties=c1.exec_properties),
        mock.call('DAG',
                  component_name=c2.component_name,
                  unique_name=c2.unique_name,
                  driver=c2.driver,
                  executor=c2.executor,
                  input_dict=mock.ANY,
                  output_dict=mock.ANY,
                  exec_properties=c2.exec_properties)
    ]
    # any_order=True: the runner is free to wrap components in any order.
    mock_airflow_component_class.assert_has_calls(component_calls,
                                                  any_order=True)
def testConstructWithParameter(self):
    """infer_feature_shape accepts a RuntimeParameter (newer artifact API)."""
    statistics_artifact = standard_artifacts.ExampleStatistics()
    statistics_artifact.split_names = artifact_utils.encode_split_names(
        ['train'])
    shape_param = data_types.RuntimeParameter(name='infer-shape', ptype=bool)
    schema_gen = component.SchemaGen(
        statistics=channel_utils.as_channel([statistics_artifact]),
        infer_feature_shape=shape_param)
    self.assertEqual(standard_artifacts.Schema.TYPE_NAME,
                     schema_gen.outputs['schema'].type_name)
    self.assertJsonEqual(
        str(schema_gen.spec.exec_properties['infer_feature_shape']),
        str(shape_param))
def testConstruct(self):
    """exclude_splits is JSON-encoded into exec_properties."""
    stats_artifact = standard_artifacts.ExampleStatistics()
    stats_artifact.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])
    schema_gen = component.SchemaGen(
        statistics=channel_utils.as_channel([stats_artifact]),
        exclude_splits=['eval'])
    self.assertEqual(standard_artifacts.Schema.TYPE_NAME,
                     schema_gen.outputs['schema'].type_name)
    self.assertTrue(schema_gen.spec.exec_properties['infer_feature_shape'])
    self.assertEqual(schema_gen.spec.exec_properties['exclude_splits'],
                     '["eval"]')
def testConstruct(self):
    """Outputs and exec_properties are addressable via standard_component_specs keys."""
    stats_artifact = standard_artifacts.ExampleStatistics()
    stats_artifact.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])
    schema_gen = component.SchemaGen(
        statistics=channel_utils.as_channel([stats_artifact]),
        exclude_splits=['eval'])
    self.assertEqual(
        standard_artifacts.Schema.TYPE_NAME,
        schema_gen.outputs[standard_component_specs.SCHEMA_KEY].type_name)
    self.assertTrue(schema_gen.spec.exec_properties[
        standard_component_specs.INFER_FEATURE_SHAPE_KEY])
    self.assertEqual(
        schema_gen.spec.exec_properties[
            standard_component_specs.EXCLUDE_SPLITS_KEY], '["eval"]')
def __init__(self, stats: dsl.PipelineParam):
    """Wraps the TFX SchemaGen component, binding the 'stats' pipeline input."""
    # Renamed from `component` to avoid shadowing the module name used
    # elsewhere in this codebase.
    schema_gen = schema_gen_component.SchemaGen(
        channel.Channel('ExampleStatisticsPath'))
    super().__init__(schema_gen, {'stats': stats})
def test_construct(self):
    """SchemaGen built from a TfxArtifact stats channel outputs SchemaPath."""
    stats_channel = channel.as_channel([
        types.TfxArtifact(type_name='ExampleStatisticsPath', split='train')
    ])
    schema_gen = component.SchemaGen(stats=stats_channel)
    self.assertEqual('SchemaPath', schema_gen.outputs.output.type_name)
def testConstructWithNeitherStatsNorSchema(self):
    """Constructing SchemaGen with no inputs at all must raise ValueError."""
    self.assertRaises(ValueError, component.SchemaGen)
def testConstructWithBothStatsAndSchema(self):
    """Supplying both stats and schema inputs is rejected with ValueError."""
    stats_channel = channel_utils.as_channel(
        [standard_artifacts.ExampleStatistics(split='train')])
    schema_channel = channel_utils.as_channel([standard_artifacts.Schema()])
    with self.assertRaises(ValueError):
        _ = component.SchemaGen(stats=stats_channel, schema=schema_channel)
def testConstructWithSchema(self):
    """A schema-only construction still exposes a SchemaPath output."""
    schema_channel = channel_utils.as_channel([standard_artifacts.Schema()])
    schema_gen = component.SchemaGen(schema=schema_channel)
    self.assertEqual('SchemaPath', schema_gen.outputs.output.type_name)
def test_construct(self):
    """Basic construction from a stats channel yields a SchemaPath output."""
    stats_channel = channel_utils.as_channel(
        [standard_artifacts.ExampleStatistics(split='train')])
    schema_gen = component.SchemaGen(stats=stats_channel)
    self.assertEqual('SchemaPath', schema_gen.outputs.output.type_name)