def run(self, pipeline):
    """Translates a logical pipeline into an Airflow DAG.

    Args:
      pipeline: Logical pipeline providing `pipeline_args` and `components`.

    Returns:
      The Airflow pipeline DAG that the components were attached to.
    """
    # Pipeline args are merged into the stored config via `update`, so the
    # merged values stay on `self._config` after this call returns.
    self._config.update(pipeline.pipeline_args)
    dag = airflow_pipeline.AirflowPipeline(**self._config)

    # Each logical component gets a concrete Airflow component on the DAG.
    # The constructed object is not kept here; only the DAG is returned.
    for logical_component in pipeline.components:
        component_kwargs = dict(
            component_name=logical_component.component_name,
            unique_name=logical_component.unique_name,
            driver=logical_component.driver,
            executor=logical_component.executor,
            input_dict=self._prepare_input_dict(logical_component.input_dict),
            output_dict=self._prepare_output_dict(logical_component.outputs),
            exec_properties=logical_component.exec_properties,
        )
        airflow_component.Component(dag, **component_kwargs)

    return dag
def test_airflow_component(self, mock_airflow_adapter_class,
                           mock_python_operator_class,
                           mock_branch_python_operator_class):
    # Checks that constructing a Component grows the parent DAG's subdag
    # list by exactly one entry.
    #
    # The adapter class is stubbed to hand back a mock whose three entry
    # points are plain strings instead of callables.
    mock_airflow_adapter_class.return_value = mock.Mock(
        check_cache_and_maybe_prepare_execution='check_cache',
        python_exec='python_exec',
        publish_exec='publish_exec')

    # Snapshot the subdag count before the component is created.
    subdags_before = len(self.parent_dag.subdags)

    airflow_component.Component(
        parent_dag=self.parent_dag,
        component_name='test_component',
        unique_name='test_component_unique_name',
        driver=driver.Driver,
        executor=executor.Executor,
        input_dict=self.input_dict,
        output_dict=self.output_dict,
        exec_properties=self.exec_properties)

    self.assertEqual(len(self.parent_dag.subdags), subdags_before + 1)
def test_airflow_component(self, mock_tfx_worker_class,
                           mock_airflow_adapter_class):
    # Checks that Component forwards its arguments to the worker class with
    # the expected task id, and that the new component has no upstream tasks.
    #
    # The adapter class is stubbed to hand back a mock whose three entry
    # points are plain strings instead of callables.
    mock_airflow_adapter_class.return_value = mock.Mock(
        check_cache_and_maybe_prepare_execution='check_cache',
        python_exec='python_exec',
        publish_exec='publish_exec')

    # The patched worker class returns a real DAG object.
    worker_dag = models.DAG(
        dag_id='pipeline_name.component_name.unique_name',
        start_date=datetime.datetime(2019, 1, 1))
    mock_tfx_worker_class.return_value = worker_dag

    component = airflow_component.Component(
        parent_dag=self.parent_dag,
        component_name='component_name',
        unique_name='unique_name',
        driver=None,
        executor=None,
        input_dict=self.input_dict,
        output_dict=self.output_dict,
        exec_properties=self.exec_properties)

    # driver_options and logger_config are matched loosely with mock.ANY;
    # every other keyword must match exactly.
    expected_worker_kwargs = dict(
        component_name='component_name',
        task_id='pipeline_name.component_name.unique_name',
        parent_dag=self.parent_dag,
        input_dict=self.input_dict,
        output_dict=self.output_dict,
        exec_properties=self.exec_properties,
        driver_options=mock.ANY,
        driver_class=None,
        executor_class=None,
        additional_pipeline_args=None,
        metadata_connection_config=self.parent_dag.metadata_connection_config,
        logger_config=mock.ANY)
    mock_tfx_worker_class.assert_called_with(**expected_worker_kwargs)

    self.assertItemsEqual(component.upstream_list, [])