def testAirflowComponent(self, mock_python_operator_init):
    """Verifies how `AirflowComponent` configures its underlying operator.

    Constructs an `AirflowComponent` with caching enabled, then inspects the
    single call recorded on `mock_python_operator_init`: the task id, DAG and
    context flag must match, and the `python_callable` that was handed over
    must wrap `_airflow_component_launcher` with the constructor arguments
    bound as keywords.

    Args:
      mock_python_operator_init: Mock whose recorded call is inspected.
        Judging by its name it stands in for the Python operator's
        initializer; the patching itself happens outside this block.
    """
    launcher_cls = mock.Mock()

    airflow_component.AirflowComponent(
        parent_dag=self._parent_dag,
        component=self._component,
        component_launcher_class=launcher_cls,
        pipeline_info=self._pipeline_info,
        enable_cache=True,
        metadata_connection_config=self._metadata_connection_config,
        beam_pipeline_args=[],
        additional_pipeline_args={},
        component_config=None,
    )

    mock_python_operator_init.assert_called_once_with(
        task_id=self._component.id,
        provide_context=True,
        python_callable=mock.ANY,
        dag=self._parent_dag,
    )

    # Index 0 selects the first recorded call, index 1 its keyword arguments.
    first_call = mock_python_operator_init.call_args_list[0]
    launcher_partial = first_call[1]['python_callable']
    self.assertEqual(
        launcher_partial.func, airflow_component._airflow_component_launcher
    )

    # `driver_args` is removed from the bound keywords before the exact-match
    # comparison below; only its `enable_cache` flag is checked.
    driver_args = launcher_partial.keywords.pop('driver_args')
    self.assertTrue(driver_args.enable_cache)

    expected_keywords = {
        'component': self._component,
        'component_launcher_class': launcher_cls,
        'pipeline_info': self._pipeline_info,
        'metadata_connection_config': self._metadata_connection_config,
        'beam_pipeline_args': [],
        'additional_pipeline_args': {},
        'component_config': None,
    }
    self.assertEqual(launcher_partial.keywords, expected_keywords)
def testAirflowComponent(self, mock_functools_partial):
    """Verifies the arguments `AirflowComponent` binds to its launcher.

    Constructs an `AirflowComponent` with caching enabled and checks the first
    call recorded on `mock_functools_partial`: its only positional argument
    must be `_airflow_component_launcher` and its keyword arguments must
    mirror the constructor arguments.

    Args:
      mock_functools_partial: Mock recording calls to `partial`. Per the
        original author's note it wraps the real function rather than
        replacing it; the patching itself happens outside this block.
    """
    launcher_cls = mock.Mock()

    airflow_component.AirflowComponent(
        parent_dag=self._parent_dag,
        component=self._component,
        component_launcher_class=launcher_cls,
        pipeline_info=self._pipeline_info,
        enable_cache=True,
        metadata_connection_config=self._metadata_connection_config,
        beam_pipeline_args=[],
        additional_pipeline_args={},
        component_config=None,
    )

    # Airflow complained when `partial` was mocked out completely, so the mock
    # only wraps the real function. Code other than AirflowComponent may call
    # `partial` as well, hence only the first recorded call is examined.
    mock_functools_partial.assert_called()
    first_call = mock_functools_partial.call_args_list[0]
    positional = first_call[0]
    keywords = first_call[1]

    self.assertCountEqual(
        positional, (airflow_component._airflow_component_launcher,)
    )

    # `driver_args` is removed before the exact-match comparison below; only
    # its `enable_cache` flag is checked.
    driver_args = keywords.pop('driver_args')
    self.assertTrue(driver_args.enable_cache)

    expected_keywords = {
        'component': self._component,
        'component_launcher_class': launcher_cls,
        'pipeline_info': self._pipeline_info,
        'metadata_connection_config': self._metadata_connection_config,
        'beam_pipeline_args': [],
        'additional_pipeline_args': {},
        'component_config': None,
    }
    self.assertEqual(keywords, expected_keywords)
def run(self, tfx_pipeline: pipeline.Pipeline):
    """Builds an Airflow DAG from the given logical pipeline.

    One `AirflowComponent` is created per pipeline component and linked to the
    Airflow components of its upstream nodes. If the pipeline's
    `additional_pipeline_args` has no 'tmp_dir' entry, one pointing at
    `<pipeline_root>/.temp/` is added to that dict in place.

    Args:
      tfx_pipeline: Logical pipeline containing pipeline args and components.
        Components must be listed in topological order; this is asserted
        while wiring upstream dependencies.

    Returns:
      An Airflow DAG.
    """
    # Merge airflow-specific configs with pipeline args
    dag_name = tfx_pipeline.pipeline_info.pipeline_name
    dag_config = typing.cast(
        AirflowPipelineConfig, self._config
    ).airflow_dag_config
    airflow_dag = models.DAG(dag_id=dag_name, **dag_config)

    if 'tmp_dir' not in tfx_pipeline.additional_pipeline_args:
        # The trailing '' makes the joined path end with a separator.
        tfx_pipeline.additional_pipeline_args['tmp_dir'] = os.path.join(
            tfx_pipeline.pipeline_info.pipeline_root, '.temp', ''
        )

    # Maps each processed TFX component to the Airflow component built for it.
    airflow_nodes = {}
    for node in tfx_pipeline.components:
        # TODO(b/187122662): Pass through pip dependencies as a first-class
        # component flag.
        if isinstance(node, base_component.BaseComponent):
            node._resolve_pip_dependencies(  # pylint: disable=protected-access
                tfx_pipeline.pipeline_info.pipeline_root
            )
        tfx_component = self._replace_runtime_params(node)

        launch_info = config_utils.find_component_launch_info(
            self._config, tfx_component
        )
        component_launcher_class, component_config = launch_info

        component_kwargs = {
            'parent_dag': airflow_dag,
            'component': tfx_component,
            'component_launcher_class': component_launcher_class,
            'pipeline_info': tfx_pipeline.pipeline_info,
            'enable_cache': tfx_pipeline.enable_cache,
            'metadata_connection_config': (
                tfx_pipeline.metadata_connection_config
            ),
            'beam_pipeline_args': tfx_pipeline.beam_pipeline_args,
            'additional_pipeline_args': tfx_pipeline.additional_pipeline_args,
            'component_config': component_config,
        }
        airflow_node = airflow_component.AirflowComponent(**component_kwargs)
        airflow_nodes[tfx_component] = airflow_node

        # Every upstream node must already have been processed.
        for upstream_node in tfx_component.upstream_nodes:
            assert upstream_node in airflow_nodes, (
                'Components is not in topological order')
            airflow_node.set_upstream(airflow_nodes[upstream_node])

    return airflow_dag
def run(self, tfx_pipeline: pipeline.Pipeline):
    """Builds an Airflow DAG from the given logical pipeline.

    One `AirflowComponent` is created per pipeline component and linked to the
    Airflow components of its upstream nodes. If the pipeline's
    `additional_pipeline_args` has no 'tmp_dir' entry, one pointing at
    `<pipeline_root>/.temp/` is added to that dict in place.

    Args:
      tfx_pipeline: Logical pipeline containing pipeline args and components.
        Components must be listed in topological order; this is asserted
        while wiring upstream dependencies.

    Returns:
      An Airflow DAG.
    """
    # Merge airflow-specific configs with pipeline args
    dag_name = tfx_pipeline.pipeline_info.pipeline_name
    airflow_dag = models.DAG(
        dag_id=dag_name, **self._config.airflow_dag_config
    )

    if 'tmp_dir' not in tfx_pipeline.additional_pipeline_args:
        # The trailing '' makes the joined path end with a separator.
        tfx_pipeline.additional_pipeline_args['tmp_dir'] = os.path.join(
            tfx_pipeline.pipeline_info.pipeline_root, '.temp', ''
        )

    # Maps each processed TFX component to the Airflow component built for it.
    airflow_nodes = {}
    for node in tfx_pipeline.components:
        tfx_component = self._replace_runtime_params(node)

        launch_info = config_utils.find_component_launch_info(
            self._config, tfx_component
        )
        component_launcher_class, component_config = launch_info

        component_kwargs = {
            'component': tfx_component,
            'component_launcher_class': component_launcher_class,
            'pipeline_info': tfx_pipeline.pipeline_info,
            'enable_cache': tfx_pipeline.enable_cache,
            'metadata_connection_config': (
                tfx_pipeline.metadata_connection_config
            ),
            'beam_pipeline_args': tfx_pipeline.beam_pipeline_args,
            'additional_pipeline_args': tfx_pipeline.additional_pipeline_args,
            'component_config': component_config,
        }
        # The DAG is passed positionally, as in the original call.
        airflow_node = airflow_component.AirflowComponent(
            airflow_dag, **component_kwargs
        )
        airflow_nodes[tfx_component] = airflow_node

        # Every upstream node must already have been processed.
        for upstream_node in tfx_component.upstream_nodes:
            assert upstream_node in airflow_nodes, (
                'Components is not in topological order')
            airflow_node.set_upstream(airflow_nodes[upstream_node])

    return airflow_dag
def test_airflow_component(self, mock_functools_partial):
    """Verifies the single `partial` call made by `AirflowComponent`.

    Constructs an `AirflowComponent` with caching enabled and checks that
    `mock_functools_partial` was called exactly once with
    `_airflow_component_launcher` plus the expected keyword arguments, and
    that the `driver_args` passed along has `enable_cache` set.

    Args:
      mock_functools_partial: Mock recording calls to `partial`; the patching
        itself happens outside this block.
    """
    airflow_component.AirflowComponent(
        parent_dag=self._parent_dag,
        component=self._component,
        pipeline_info=self._pipeline_info,
        enable_cache=True,
        metadata_connection_config=self._metadata_connection_config,
        additional_pipeline_args={},
    )

    # `driver_args` is matched loosely here and inspected separately below.
    expected_keywords = {
        'component': self._component,
        'pipeline_info': self._pipeline_info,
        'driver_args': mock.ANY,
        'metadata_connection_config': self._metadata_connection_config,
        'additional_pipeline_args': {},
    }
    mock_functools_partial.assert_called_once_with(
        airflow_component._airflow_component_launcher, **expected_keywords
    )

    # Index 0 selects the first recorded call, index 1 its keyword arguments.
    recorded_calls = mock_functools_partial.call_args_list
    first_call_keywords = recorded_calls[0][1]
    self.assertTrue(first_call_keywords['driver_args'].enable_cache)