def testPatcher(self, mock_run):
    """Runs a pipeline under AirflowDagRunnerPatcher and checks what it records.

    Inside the patch context, running a pipeline through AirflowDagRunner must
    trigger `mock_run` exactly once, and the context must expose the pipeline's
    name under `patcher.PIPELINE_NAME`.

    Args:
      mock_run: Mock injected by a decorator that is not visible here;
        presumably it replaces the underlying run call — TODO confirm.
    """
    patcher = airflow_dag_runner_patcher.AirflowDagRunnerPatcher()
    with patcher.patch() as ctx:
        # Runner is built before the pipeline, matching the original order.
        runner = airflow_dag_runner.AirflowDagRunner()
        pipeline_obj = tfx_pipeline.Pipeline(_PIPELINE_NAME, '')
        runner.run(pipeline_obj)

        mock_run.assert_called_once()
        recorded_name = ctx[patcher.PIPELINE_NAME]
        self.assertEqual(recorded_name, _PIPELINE_NAME)
Exemplo n.º 2
0
    def testAirflowDagRunnerInitBackwardCompatible(self):
        """A raw Airflow config dict is still accepted by the constructor.

        Other tests wrap the dict in `AirflowPipelineConfig`. This one passes the
        bare dict and checks it ends up stored, unchanged, as
        `_config.airflow_dag_config`.
        """
        expected_config = dict(
            schedule_interval='* * * * *',
            start_date=datetime.datetime(2019, 1, 1),
        )

        self.assertEqual(
            expected_config,
            airflow_dag_runner.AirflowDagRunner(
                expected_config)._config.airflow_dag_config)
Exemplo n.º 3
0
 def testRuntimeParamIntError(self):
     """Building or running a DAG with an int RuntimeParameter raises RuntimeError.

     Only the exception type is checked, not its message. Both the runner's
     construction and `run` sit inside the `assertRaises` context.
     """
     int_param = RuntimeParameter('name', int, 1)
     fake_component = _FakeComponent(_FakeComponentSpecG(a=int_param))
     dag_config = dict(
         schedule_interval='* * * * *',
         start_date=datetime.datetime(2019, 1, 1),
     )
     test_pipeline = pipeline.Pipeline(
         pipeline_name='x',
         pipeline_root='y',
         metadata_connection_config=None,
         components=[fake_component],
     )

     with self.assertRaises(RuntimeError):
         runner = airflow_dag_runner.AirflowDagRunner(
             airflow_dag_runner.AirflowPipelineConfig(
                 airflow_dag_config=dag_config))
         runner.run(test_pipeline)
Exemplo n.º 4
0
    def testRuntimeParamTemplated(self):
        """A templated RuntimeParameter default becomes a dag_run.conf lookup.

        The parameter 'a' defaults to '{{execution_date}}'. The first generated
        task's `op_kwargs` must carry it as a Jinja expression that reads "a"
        from `dag_run.conf`, falling back to `execution_date`.
        """
        templated_param = RuntimeParameter('a', str, '{{execution_date}}')
        fake_component = _FakeComponent(_FakeComponentSpecF(a=templated_param))
        dag_config = dict(
            schedule_interval='* * * * *',
            start_date=datetime.datetime(2019, 1, 1),
        )
        test_pipeline = pipeline.Pipeline(
            pipeline_name='x',
            pipeline_root='y',
            metadata_connection_config=None,
            components=[fake_component],
        )

        pipeline_config = airflow_dag_runner.AirflowPipelineConfig(
            airflow_dag_config=dag_config)
        dag = airflow_dag_runner.AirflowDagRunner(pipeline_config).run(
            test_pipeline)

        expected_kwargs = {
            'exec_properties': {
                'a': '{{ dag_run.conf.get("a", execution_date) }}'
            }
        }
        self.assertDictEqual(expected_kwargs, dag.tasks[0].op_kwargs)
Exemplo n.º 5
0
    def testAirflowDagRunner(self, mock_airflow_dag_class,
                             mock_airflow_component_class):
        """Airflow task dependencies mirror the pipeline's channel graph.

        Components are wired through output channels: b consumes a; c consumes
        a and b; d consumes b and c; e consumes a, b and d. Each Airflow task
        must therefore have exactly those components' tasks as upstreams.

        The mocked component class hands out mocks a..e in call order. The test
        therefore assumes components are instantiated as a, b, c, d, e, which
        is the only topological order of this graph.

        Upstream calls are compared as full call lists (order-insensitively).
        `assert_has_calls` alone would also pass if spurious or duplicated
        dependencies were added.

        Args:
          mock_airflow_dag_class: patched DAG class (the decorator is not
            visible here); configured to return the string 'DAG'.
          mock_airflow_component_class: patched Airflow component class
            (presumably, since the decorator is not visible); each call returns
            the next mock from `side_effect`.
        """
        mock_airflow_dag_class.return_value = 'DAG'
        mock_airflow_component_a = mock.Mock()
        mock_airflow_component_b = mock.Mock()
        mock_airflow_component_c = mock.Mock()
        mock_airflow_component_d = mock.Mock()
        mock_airflow_component_e = mock.Mock()
        # One mock per constructor call, in topological order (see docstring).
        mock_airflow_component_class.side_effect = [
            mock_airflow_component_a, mock_airflow_component_b,
            mock_airflow_component_c, mock_airflow_component_d,
            mock_airflow_component_e
        ]

        airflow_config = {
            'schedule_interval': '* * * * *',
            'start_date': datetime.datetime(2019, 1, 1)
        }
        component_a = _FakeComponent(
            _FakeComponentSpecA(output=types.Channel(type=_ArtifactTypeA)))
        component_b = _FakeComponent(
            _FakeComponentSpecB(a=component_a.outputs['output'],
                                output=types.Channel(type=_ArtifactTypeB)))
        component_c = _FakeComponent(
            _FakeComponentSpecC(a=component_a.outputs['output'],
                                b=component_b.outputs['output'],
                                output=types.Channel(type=_ArtifactTypeC)))
        component_d = _FakeComponent(
            _FakeComponentSpecD(b=component_b.outputs['output'],
                                c=component_c.outputs['output'],
                                output=types.Channel(type=_ArtifactTypeD)))
        component_e = _FakeComponent(
            _FakeComponentSpecE(a=component_a.outputs['output'],
                                b=component_b.outputs['output'],
                                d=component_d.outputs['output'],
                                output=types.Channel(type=_ArtifactTypeE)))

        # Deliberately listed out of topological order; the mock hand-out
        # order above only lines up if they get instantiated as a, b, c, d, e.
        test_pipeline = pipeline.Pipeline(pipeline_name='x',
                                          pipeline_root='y',
                                          metadata_connection_config=None,
                                          components=[
                                              component_d, component_c,
                                              component_a, component_b,
                                              component_e
                                          ])
        runner = airflow_dag_runner.AirflowDagRunner(
            airflow_dag_runner.AirflowPipelineConfig(
                airflow_dag_config=airflow_config))
        runner.run(test_pipeline)

        def assert_upstreams(mock_component, expected_upstreams):
            """Asserts set_upstream was called once per expected task, no more."""
            # assertCountEqual ignores order but, unlike assert_has_calls,
            # fails on missing, duplicated or extra calls.
            self.assertCountEqual(
                mock_component.set_upstream.call_args_list,
                [mock.call(upstream) for upstream in expected_upstreams])

        assert_upstreams(mock_airflow_component_a, [])
        assert_upstreams(mock_airflow_component_b, [mock_airflow_component_a])
        assert_upstreams(mock_airflow_component_c, [
            mock_airflow_component_a,
            mock_airflow_component_b,
        ])
        assert_upstreams(mock_airflow_component_d, [
            mock_airflow_component_b,
            mock_airflow_component_c,
        ])
        assert_upstreams(mock_airflow_component_e, [
            mock_airflow_component_a,
            mock_airflow_component_b,
            mock_airflow_component_d,
        ])