Beispiel #1
0
    def testAirflowComponent(self, mock_python_operator_init):
        """Checks how AirflowComponent constructs its underlying operator.

        `mock_python_operator_init` stands in for the operator constructor.
        The test asserts it is called exactly once with the component id as
        task id, `provide_context=True`, the parent DAG, and a
        `python_callable` whose `.func` is
        `airflow_component._airflow_component_launcher` and whose bound
        keyword arguments carry the component's settings (with caching
        enabled in `driver_args`).
        """
        mock_component_launcher_class = mock.Mock()
        airflow_component.AirflowComponent(
            parent_dag=self._parent_dag,
            component=self._component,
            component_launcher_class=mock_component_launcher_class,
            pipeline_info=self._pipeline_info,
            enable_cache=True,
            metadata_connection_config=self._metadata_connection_config,
            beam_pipeline_args=[],
            additional_pipeline_args={},
            component_config=None)

        mock_python_operator_init.assert_called_once_with(
            task_id=self._component.id,
            provide_context=True,
            python_callable=mock.ANY,
            dag=self._parent_dag)

        # assert_called_once_with above guarantees a single call, so
        # call_args is that call; index 1 holds its keyword arguments.
        python_callable = mock_python_operator_init.call_args[1][
            'python_callable']
        self.assertEqual(python_callable.func,
                         airflow_component._airflow_component_launcher)
        # Copy before popping. The callable exposes `.func`/`.keywords`
        # (presumably a functools.partial), and in CPython `.keywords` is the
        # partial's own dict. Popping from it in place would strip
        # `driver_args` from the callable under test.
        bound_kwargs = dict(python_callable.keywords)
        self.assertTrue(bound_kwargs.pop('driver_args').enable_cache)
        self.assertEqual(
            bound_kwargs, {
                'component': self._component,
                'component_launcher_class': mock_component_launcher_class,
                'pipeline_info': self._pipeline_info,
                'metadata_connection_config': self._metadata_connection_config,
                'beam_pipeline_args': [],
                'additional_pipeline_args': {},
                'component_config': None,
            })
 def testAirflowComponent(self, mock_functools_partial):
   """Checks the functools.partial call that AirflowComponent makes.

   `mock_functools_partial` wraps the real `functools.partial`. The first
   recorded call must bind `airflow_component._airflow_component_launcher`
   as its only positional argument. Its keyword arguments must have
   `driver_args.enable_cache` set, and the rest must match the component's
   settings.
   """
   mock_component_launcher_class = mock.Mock()
   airflow_component.AirflowComponent(
       parent_dag=self._parent_dag,
       component=self._component,
       component_launcher_class=mock_component_launcher_class,
       pipeline_info=self._pipeline_info,
       enable_cache=True,
       metadata_connection_config=self._metadata_connection_config,
       beam_pipeline_args=[],
       additional_pipeline_args={},
       component_config=None)
   # Airflow complained when this function was mocked out completely, so the
   # mock wraps the real function instead. `partial` can also be called from
   # code other than AirflowComponent, so only the first call is checked.
   mock_functools_partial.assert_called()
   first_call = mock_functools_partial.call_args_list[0]
   args = first_call[0]
   # Copy so that the pop() below does not mutate the mock's recorded call.
   kwargs = dict(first_call[1])
   # Positional order matters, so compare exactly rather than as a multiset.
   self.assertEqual(tuple(args),
                    (airflow_component._airflow_component_launcher,))
   self.assertTrue(kwargs.pop('driver_args').enable_cache)
   self.assertEqual(
       kwargs, {
           'component': self._component,
           'component_launcher_class': mock_component_launcher_class,
           'pipeline_info': self._pipeline_info,
           'metadata_connection_config': self._metadata_connection_config,
           'beam_pipeline_args': [],
           'additional_pipeline_args': {},
           'component_config': None
       })
Beispiel #3
0
    def run(self, tfx_pipeline: pipeline.Pipeline):
        """Builds an Airflow DAG from a logical TFX pipeline.

        Each TFX component becomes one `airflow_component.AirflowComponent`
        task in the DAG. Upstream edges mirror the components'
        `upstream_nodes`.

        Args:
          tfx_pipeline: Logical pipeline holding the pipeline args and the
            components to deploy. If `'tmp_dir'` is missing from its
            `additional_pipeline_args`, it is added in place as
            `<pipeline_root>/.temp/`.

        Returns:
          The constructed Airflow DAG.
        """
        info = tfx_pipeline.pipeline_info

        # Airflow-specific DAG settings come from the runner's config.
        dag_settings = typing.cast(AirflowPipelineConfig,
                                   self._config).airflow_dag_config
        airflow_dag = models.DAG(dag_id=info.pipeline_name, **dag_settings)

        # Default the scratch directory unless the caller already set one.
        # This writes into the pipeline's own additional_pipeline_args dict.
        if 'tmp_dir' not in tfx_pipeline.additional_pipeline_args:
            tfx_pipeline.additional_pipeline_args['tmp_dir'] = os.path.join(
                info.pipeline_root, '.temp', '')

        tasks_by_node = {}
        for node in tfx_pipeline.components:
            # TODO(b/187122662): Pass through pip dependencies as a first-class
            # component flag.
            if isinstance(node, base_component.BaseComponent):
                node._resolve_pip_dependencies(  # pylint: disable=protected-access
                    info.pipeline_root)

            node = self._replace_runtime_params(node)

            launcher_cls, launch_cfg = config_utils.find_component_launch_info(
                self._config, node)
            task = airflow_component.AirflowComponent(
                parent_dag=airflow_dag,
                component=node,
                component_launcher_class=launcher_cls,
                pipeline_info=info,
                enable_cache=tfx_pipeline.enable_cache,
                metadata_connection_config=(
                    tfx_pipeline.metadata_connection_config),
                beam_pipeline_args=tfx_pipeline.beam_pipeline_args,
                additional_pipeline_args=tfx_pipeline.additional_pipeline_args,
                component_config=launch_cfg)
            tasks_by_node[node] = task

            # Components must arrive in topological order, so every upstream
            # task has already been created by the time it is referenced.
            for parent in node.upstream_nodes:
                assert parent in tasks_by_node, (
                    'Components is not in topological order')
                task.set_upstream(tasks_by_node[parent])

        return airflow_dag
Beispiel #4
0
    def run(self, tfx_pipeline: pipeline.Pipeline):
        """Translates a logical TFX pipeline into an Airflow DAG.

        One `airflow_component.AirflowComponent` task is created per TFX
        component. Each task is linked to the tasks of its `upstream_nodes`.

        Args:
          tfx_pipeline: Logical pipeline holding the pipeline args and the
            components to deploy. If `'tmp_dir'` is missing from its
            `additional_pipeline_args`, it is added in place as
            `<pipeline_root>/.temp/`.

        Returns:
          The constructed Airflow DAG.
        """
        info = tfx_pipeline.pipeline_info

        # Airflow-specific DAG settings come from the runner's config.
        airflow_dag = models.DAG(dag_id=info.pipeline_name,
                                 **self._config.airflow_dag_config)

        # Default the scratch directory unless the caller already set one.
        # This writes into the pipeline's own additional_pipeline_args dict.
        if 'tmp_dir' not in tfx_pipeline.additional_pipeline_args:
            tfx_pipeline.additional_pipeline_args['tmp_dir'] = os.path.join(
                info.pipeline_root, '.temp', '')

        tasks_by_node = {}
        for node in tfx_pipeline.components:
            node = self._replace_runtime_params(node)

            launcher_cls, launch_cfg = config_utils.find_component_launch_info(
                self._config, node)
            task = airflow_component.AirflowComponent(
                parent_dag=airflow_dag,
                component=node,
                component_launcher_class=launcher_cls,
                pipeline_info=info,
                enable_cache=tfx_pipeline.enable_cache,
                metadata_connection_config=(
                    tfx_pipeline.metadata_connection_config),
                beam_pipeline_args=tfx_pipeline.beam_pipeline_args,
                additional_pipeline_args=tfx_pipeline.additional_pipeline_args,
                component_config=launch_cfg)
            tasks_by_node[node] = task

            # Components must arrive in topological order, so every upstream
            # task has already been created by the time it is referenced.
            for parent in node.upstream_nodes:
                assert parent in tasks_by_node, (
                    'Components is not in topological order')
                task.set_upstream(tasks_by_node[parent])

        return airflow_dag
Beispiel #5
0
 def test_airflow_component(self, mock_functools_partial):
   """Verifies that AirflowComponent binds its launcher through partial().

   Expects exactly one call over
   `airflow_component._airflow_component_launcher` with the component's
   settings. The bound `driver_args` must have caching enabled.
   """
   airflow_component.AirflowComponent(
       parent_dag=self._parent_dag,
       component=self._component,
       pipeline_info=self._pipeline_info,
       enable_cache=True,
       metadata_connection_config=self._metadata_connection_config,
       additional_pipeline_args={})

   mock_functools_partial.assert_called_once_with(
       airflow_component._airflow_component_launcher,
       component=self._component,
       pipeline_info=self._pipeline_info,
       driver_args=mock.ANY,
       metadata_connection_config=self._metadata_connection_config,
       additional_pipeline_args={})

   # There was exactly one call, so call_args holds its (args, kwargs).
   _, bound_kwargs = mock_functools_partial.call_args
   driver_args = bound_kwargs['driver_args']
   self.assertTrue(driver_args.enable_cache)