def testPatcher(self, mock_run):
  """Runs a pipeline under the DAG-runner patcher and inspects its context.

  Verifies that the injected `mock_run` mock is invoked exactly once and that
  the patch context records the pipeline's name under
  `patcher.PIPELINE_NAME`.
  """
  dag_patcher = airflow_dag_runner_patcher.AirflowDagRunnerPatcher()
  with dag_patcher.patch() as patch_context:
    # The pipeline only needs a name here; pipeline_root is left empty.
    sample_pipeline = tfx_pipeline.Pipeline(_PIPELINE_NAME, '')
    airflow_dag_runner.AirflowDagRunner().run(sample_pipeline)

  mock_run.assert_called_once()
  recorded_name = patch_context[dag_patcher.PIPELINE_NAME]
  self.assertEqual(recorded_name, _PIPELINE_NAME)
def testAirflowDagRunnerInitBackwardCompatible(self):
  """Passing a plain dict to AirflowDagRunner still configures the DAG.

  The dict given positionally must end up as `_config.airflow_dag_config`.
  """
  dag_settings = dict(
      schedule_interval='* * * * *',
      start_date=datetime.datetime(2019, 1, 1),
  )
  legacy_runner = airflow_dag_runner.AirflowDagRunner(dag_settings)
  self.assertEqual(dag_settings, legacy_runner._config.airflow_dag_config)
def testRuntimeParamIntError(self):
  """An int-typed RuntimeParameter makes building/running the DAG fail.

  Both constructing the runner and calling `run()` happen inside
  `assertRaises(RuntimeError)`, so the error may come from either step.
  """
  int_param = RuntimeParameter('name', int, 1)
  fake_component = _FakeComponent(_FakeComponentSpecG(a=int_param))
  dag_settings = {
      'schedule_interval': '* * * * *',
      'start_date': datetime.datetime(2019, 1, 1),
  }
  failing_pipeline = pipeline.Pipeline(
      pipeline_name='x',
      pipeline_root='y',
      metadata_connection_config=None,
      components=[fake_component],
  )

  with self.assertRaises(RuntimeError):
    pipeline_config = airflow_dag_runner.AirflowPipelineConfig(
        airflow_dag_config=dag_settings)
    airflow_dag_runner.AirflowDagRunner(pipeline_config).run(failing_pipeline)
def testRuntimeParamTemplated(self):
  """A str RuntimeParameter with a Jinja default is rendered into op_kwargs.

  The generated task's `op_kwargs['exec_properties']['a']` must read the value
  from `dag_run.conf`, falling back to the templated `execution_date`.
  """
  templated_param = RuntimeParameter('a', str, '{{execution_date}}')
  fake_component = _FakeComponent(_FakeComponentSpecF(a=templated_param))
  dag_settings = {
      'schedule_interval': '* * * * *',
      'start_date': datetime.datetime(2019, 1, 1),
  }
  templated_pipeline = pipeline.Pipeline(
      pipeline_name='x',
      pipeline_root='y',
      metadata_connection_config=None,
      components=[fake_component],
  )
  pipeline_config = airflow_dag_runner.AirflowPipelineConfig(
      airflow_dag_config=dag_settings)

  built_dag = airflow_dag_runner.AirflowDagRunner(pipeline_config).run(
      templated_pipeline)
  first_task = built_dag.tasks[0]

  expected_kwargs = {
      'exec_properties': {
          'a': '{{ dag_run.conf.get("a", execution_date) }}'
      }
  }
  self.assertDictEqual(expected_kwargs, first_task.op_kwargs)
def testAirflowDagRunner(self, mock_airflow_dag_class,
                         mock_airflow_component_class):
  """Checks the set_upstream wiring between Airflow components.

  Five fake components form a diamond-like dependency graph and are handed to
  the pipeline deliberately out of order (d, c, a, b, e). Each Airflow
  component the runner creates is a separate Mock, and the test asserts
  which upstream mocks each one was linked to via `set_upstream`.
  """
  mock_airflow_dag_class.return_value = 'DAG'

  mock_a, mock_b, mock_c, mock_d, mock_e = (mock.Mock() for _ in range(5))
  # side_effect hands out the mocks in this fixed order. The assertions below
  # therefore presume the runner instantiates components in a, b, c, d, e
  # order, independent of the list order given to the pipeline.
  mock_airflow_component_class.side_effect = [
      mock_a, mock_b, mock_c, mock_d, mock_e
  ]

  dag_settings = {
      'schedule_interval': '* * * * *',
      'start_date': datetime.datetime(2019, 1, 1),
  }

  # Data dependencies: b<-a, c<-{a,b}, d<-{b,c}, e<-{a,b,d}.
  component_a = _FakeComponent(
      _FakeComponentSpecA(output=types.Channel(type=_ArtifactTypeA)))
  out_a = component_a.outputs['output']
  component_b = _FakeComponent(
      _FakeComponentSpecB(
          a=out_a, output=types.Channel(type=_ArtifactTypeB)))
  out_b = component_b.outputs['output']
  component_c = _FakeComponent(
      _FakeComponentSpecC(
          a=out_a, b=out_b, output=types.Channel(type=_ArtifactTypeC)))
  out_c = component_c.outputs['output']
  component_d = _FakeComponent(
      _FakeComponentSpecD(
          b=out_b, c=out_c, output=types.Channel(type=_ArtifactTypeD)))
  out_d = component_d.outputs['output']
  component_e = _FakeComponent(
      _FakeComponentSpecE(
          a=out_a,
          b=out_b,
          d=out_d,
          output=types.Channel(type=_ArtifactTypeE)))

  shuffled_pipeline = pipeline.Pipeline(
      pipeline_name='x',
      pipeline_root='y',
      metadata_connection_config=None,
      components=[
          component_d, component_c, component_a, component_b, component_e
      ],
  )
  airflow_dag_runner.AirflowDagRunner(
      airflow_dag_runner.AirflowPipelineConfig(
          airflow_dag_config=dag_settings)).run(shuffled_pipeline)

  # The root component has no upstream links.
  mock_a.set_upstream.assert_not_called()

  # Every other component must have been linked to each of its producers.
  expected_upstreams = (
      (mock_b, (mock_a,)),
      (mock_c, (mock_a, mock_b)),
      (mock_d, (mock_b, mock_c)),
      (mock_e, (mock_a, mock_b, mock_d)),
  )
  for downstream, producers in expected_upstreams:
    downstream.set_upstream.assert_has_calls(
        [mock.call(producer) for producer in producers], any_order=True)