Example #1
0
    def test_tfx_workflow_docker(self, mock_airflow_adapter_class,
                                 mock_dummy_operator_class,
                                 mock_python_operator_class,
                                 mock_branch_python_operator_class):
        """Verifies the task graph _TfxWorker builds in docker mode.

        The parent DAG is given a `docker_operator_cfg`, so the exec task is
        expected to come from the adapter's `docker_operator` factory rather
        than from a PythonOperator. The cached branch is expected to lead to a
        `publishcache` DummyOperator. The mocked classes are presumably
        injected by `mock.patch` decorators outside this view.
        """
        checkcache_id = 'my_component.checkcache'
        exec_id = 'my_component.exec'
        publishcache_id = 'my_component.publishcache'
        publishexec_id = 'my_component.publishexec'

        # Adapter double: its callables are replaced with sentinel strings so
        # the operator-constructor assertions below can compare them by value.
        adapter = mock.Mock()
        adapter.check_cache_and_maybe_prepare_execution = 'check_cache'
        adapter.python_exec = 'python_exec'
        adapter.publish_exec = 'publish_exec'
        adapter.docker_operator.return_value = self.tfx_docker_op
        mock_airflow_adapter_class.return_value = adapter

        # Each operator class yields exactly one prebuilt operator. A second
        # construction would exhaust the side_effect list and raise.
        mock_branch_python_operator_class.side_effect = [self.checkcache_op]
        mock_dummy_operator_class.side_effect = [self.publishcache_op]
        mock_python_operator_class.side_effect = [self.publishexec_op]

        self.parent_dag.docker_operator_cfg = {'volumes': ['test_volume']}

        tfx_worker = airflow_component._TfxWorker(
            component_name='component_name',
            task_id='my_component',
            parent_dag=self.parent_dag,
            input_dict={},
            output_dict={},
            exec_properties={},
            driver_options={},
            driver_class=None,
            executor_class=None,
            additional_pipeline_args=None,
            metadata_connection_config=None,
            logger_config=self._logger_config)

        # Expected upstream edges of every task in the built sub-DAG.
        expected_upstreams = (
            (self.checkcache_op, []),
            (self.tfx_docker_op, [self.checkcache_op]),
            (self.publishexec_op, [self.tfx_docker_op]),
            (self.publishcache_op, [self.checkcache_op]),
        )
        for operator, upstream in expected_upstreams:
            self.assertItemsEqual(operator.upstream_list, upstream)

        mock_branch_python_operator_class.assert_called_with(
            task_id=checkcache_id,
            provide_context=True,
            python_callable='check_cache',
            op_kwargs={
                'uncached_branch': exec_id,
                'cached_branch': publishcache_id,
            },
            dag=tfx_worker)
        mock_dummy_operator_class.assert_called_with(
            task_id=publishcache_id, dag=tfx_worker)
        mock_python_operator_class.assert_called_with(
            task_id=publishexec_id,
            provide_context=True,
            python_callable='publish_exec',
            op_kwargs={
                'cache_task_name': checkcache_id,
                'exec_task_name': exec_id,
            },
            dag=tfx_worker)
        adapter.docker_operator.assert_called_with(
            task_id=exec_id,
            pusher_task=checkcache_id,
            parent_dag=tfx_worker,
            docker_operator_cfg=self.parent_dag.docker_operator_cfg)
Example #2
0
    def test_tfx_workflow(self, mock_airflow_adapter_class,
                          mock_dummy_operator_class,
                          mock_python_operator_class,
                          mock_branch_python_operator_class):
        """Verifies the task graph _TfxWorker builds for in-process execution.

        Expected wiring: the checkcache branch task leads either to the python
        exec task (followed by publishexec) or to a `noop_sink` DummyOperator.
        The mocked classes are presumably injected by `mock.patch` decorators
        outside this view.
        """
        checkcache_id = 'my_component.checkcache'
        exec_id = 'my_component.exec'
        noop_sink_id = 'my_component.noop_sink'
        publishexec_id = 'my_component.publishexec'

        # Adapter double: its callables are replaced with sentinel strings so
        # the operator-constructor assertions below can compare them by value.
        adapter = mock.Mock()
        adapter.check_cache_and_maybe_prepare_execution = 'check_cache'
        adapter.python_exec = 'python_exec'
        adapter.publish_exec = 'publish_exec'
        mock_airflow_adapter_class.return_value = adapter

        # PythonOperator is expected to be constructed twice: exec first, then
        # publishexec. The other operator classes are constructed once.
        mock_branch_python_operator_class.side_effect = [self.checkcache_op]
        mock_dummy_operator_class.side_effect = [self.noop_sink_op]
        mock_python_operator_class.side_effect = [
            self.tfx_python_op, self.publishexec_op
        ]

        # Arguments expected to be forwarded unchanged to the adapter class;
        # the same dict is reused in the assertion further down.
        adapter_kwargs = dict(
            component_name='component_name',
            input_dict=self.input_dict,
            output_dict=self.output_dict,
            exec_properties=self.exec_properties,
            driver_options=self.driver_options,
            driver_class=None,
            executor_class=None,
            additional_pipeline_args=None,
            metadata_connection_config=None,
            logger_config=self._logger_config)
        tfx_worker = airflow_component._TfxWorker(
            task_id='my_component',
            parent_dag=self.parent_dag,
            **adapter_kwargs)

        for operator, upstream in (
            (self.checkcache_op, []),
            (self.tfx_python_op, [self.checkcache_op]),
            (self.publishexec_op, [self.tfx_python_op]),
            (self.noop_sink_op, [self.checkcache_op]),
        ):
            self.assertItemsEqual(operator.upstream_list, upstream)

        mock_airflow_adapter_class.assert_called_with(**adapter_kwargs)

        mock_branch_python_operator_class.assert_called_with(
            task_id=checkcache_id,
            provide_context=True,
            python_callable='check_cache',
            op_kwargs={
                'uncached_branch': exec_id,
                'cached_branch': noop_sink_id,
            },
            dag=tfx_worker)

        mock_dummy_operator_class.assert_called_with(
            task_id=noop_sink_id, dag=tfx_worker)

        # assert_has_calls (any_order=False) requires these two calls to appear
        # consecutively and in this order.
        mock_python_operator_class.assert_has_calls([
            mock.call(task_id=exec_id,
                      provide_context=True,
                      python_callable='python_exec',
                      op_kwargs={'cache_task_name': checkcache_id},
                      dag=tfx_worker),
            mock.call(task_id=publishexec_id,
                      provide_context=True,
                      python_callable='publish_exec',
                      op_kwargs={
                          'cache_task_name': checkcache_id,
                          'exec_task_name': exec_id,
                      },
                      dag=tfx_worker),
        ])
  def test_tfx_workflow_non_docker(
      self, mock_airflow_adapter_class, mock_dummy_operator_class,
      mock_python_operator_class, mock_branch_python_operator_class):
    """Verifies the task graph _TfxWorker builds without a docker config.

    Expected wiring: the checkcache branch task leads either to the python
    exec task (followed by publishexec) or to a `publishcache` DummyOperator.
    This variant constructs _TfxWorker without a logger_config. The mocked
    classes are presumably injected by `mock.patch` decorators outside this
    view.
    """
    checkcache_id = 'my_component.checkcache'
    exec_id = 'my_component.exec'
    publishcache_id = 'my_component.publishcache'
    publishexec_id = 'my_component.publishexec'

    # Adapter double: its callables are replaced with sentinel strings so the
    # operator-constructor assertions below can compare them by value.
    adapter = mock.Mock()
    adapter.check_cache_and_maybe_prepare_execution = 'check_cache'
    adapter.python_exec = 'python_exec'
    adapter.publish_exec = 'publish_exec'
    mock_airflow_adapter_class.return_value = adapter

    # PythonOperator is expected to be constructed twice: exec first, then
    # publishexec. The other operator classes are constructed once.
    mock_branch_python_operator_class.side_effect = [self.checkcache_op]
    mock_dummy_operator_class.side_effect = [self.publishcache_op]
    mock_python_operator_class.side_effect = [
        self.tfx_python_op, self.publishexec_op
    ]

    # Arguments expected to be forwarded unchanged to the adapter class; the
    # same dict is reused in the assertion further down.
    adapter_kwargs = dict(
        component_name='component_name',
        input_dict=self.input_dict,
        output_dict=self.output_dict,
        exec_properties=self.exec_properties,
        driver_options=self.driver_options,
        driver_class=None,
        executor_class=None,
        additional_pipeline_args=None,
        metadata_connection_config=None)
    tfx_worker = airflow_component._TfxWorker(
        task_id='my_component',
        parent_dag=self.parent_dag,
        **adapter_kwargs)

    for operator, upstream in (
        (self.checkcache_op, []),
        (self.tfx_python_op, [self.checkcache_op]),
        (self.publishexec_op, [self.tfx_python_op]),
        (self.publishcache_op, [self.checkcache_op]),
    ):
      self.assertItemsEqual(operator.upstream_list, upstream)

    mock_airflow_adapter_class.assert_called_with(**adapter_kwargs)

    mock_branch_python_operator_class.assert_called_with(
        task_id=checkcache_id,
        provide_context=True,
        python_callable='check_cache',
        op_kwargs={
            'uncached_branch': exec_id,
            'cached_branch': publishcache_id,
        },
        dag=tfx_worker)

    mock_dummy_operator_class.assert_called_with(
        task_id=publishcache_id, dag=tfx_worker)

    # assert_has_calls (any_order=False) requires these two calls to appear
    # consecutively and in this order.
    mock_python_operator_class.assert_has_calls([
        mock.call(
            task_id=exec_id,
            provide_context=True,
            python_callable='python_exec',
            op_kwargs={'cache_task_name': checkcache_id},
            dag=tfx_worker),
        mock.call(
            task_id=publishexec_id,
            provide_context=True,
            python_callable='publish_exec',
            op_kwargs={
                'cache_task_name': checkcache_id,
                'exec_task_name': exec_id,
            },
            dag=tfx_worker),
    ])