Ejemplo n.º 1
0
    def test_cached_execution(self, mock_metadata_class, mock_driver_class,
                              mock_executor_class, mock_docker_operator_class,
                              mock_get_logger):
        """Check that a cache hit makes the adapter take the cached branch.

        The mocked driver's `prepare_execution` returns an `ExecutionDecision`
        built without an `execution_id`. The adapter is expected to:
          * pull the upstream input from XCom,
          * push the cached output to XCom under 'output_one_key', and
          * return 'cached_branch'.
        """
        self._setup_mocks(mock_metadata_class, mock_driver_class,
                          mock_executor_class, mock_docker_operator_class,
                          mock_get_logger)
        (adapter, input_dict, output_dict, exec_properties,
         driver_options) = self._setup_adapter_and_args()

        # A one-element side_effect list: exactly one xcom_pull is expected;
        # a second call would raise StopIteration.
        self.mock_task_instance.xcom_pull.side_effect = [self.input_one_json]

        # No execution_id is supplied, which is the "cached" case this test
        # exercises.
        self.mock_driver.prepare_execution.return_value = base_driver.ExecutionDecision(
            input_dict, output_dict, exec_properties)

        check_result = adapter.check_cache_and_maybe_prepare_execution(
            'cached_branch', 'uncached_branch', ti=self.mock_task_instance)

        mock_get_logger.assert_called_with(self._logger_config)
        mock_driver_class.assert_called_with(
            logger=mock.ANY, metadata_handler=self.mock_metadata)
        # Must be assert_called_with: a bare `called_with` is auto-created by
        # Mock and verifies nothing (and raises AttributeError on 3.12+).
        self.mock_driver.prepare_execution.assert_called_with(
            input_dict, output_dict, exec_properties, driver_options)
        self.mock_task_instance.xcom_pull.assert_called_with(
            dag_id='input_one_component_id', key='input_one_key')
        self.mock_task_instance.xcom_push.assert_called_with(
            key='output_one_key', value=self.output_one_json)

        self.assertEqual(check_result, 'cached_branch')
Ejemplo n.º 2
0
    def test_new_execution(self, mock_metadata_class, mock_driver_class,
                           mock_executor_class, mock_get_logger):
        """Check that a cache miss makes the adapter prepare a new execution.

        The mocked driver's `prepare_execution` returns an `ExecutionDecision`
        carrying `execution_id=12345`. The adapter is expected to push the
        serialized inputs, outputs, exec properties and the execution id to
        XCom, and to return 'uncached_branch'.
        """
        self._setup_mocks(mock_metadata_class, mock_driver_class,
                          mock_executor_class, mock_get_logger)
        (adapter, input_dict, output_dict, exec_properties,
         driver_options) = self._setup_adapter_and_args()

        # A one-element side_effect list: exactly one xcom_pull is expected;
        # a second call would raise StopIteration.
        self.mock_task_instance.xcom_pull.side_effect = [self.input_one_json]

        execution_id = 12345
        # Supplying an execution_id is the "new execution" case this test
        # exercises.
        self.mock_driver.prepare_execution.return_value = base_driver.ExecutionDecision(
            input_dict, output_dict, exec_properties, execution_id=execution_id)

        check_result = adapter.check_cache_and_maybe_prepare_execution(
            'cached_branch', 'uncached_branch', ti=self.mock_task_instance)

        mock_driver_class.assert_called_with(
            logger=mock.ANY, metadata_handler=self.mock_metadata)
        # Must be assert_called_with: a bare `called_with` is auto-created by
        # Mock and verifies nothing (and raises AttributeError on 3.12+).
        self.mock_driver.prepare_execution.assert_called_with(
            input_dict, output_dict, exec_properties, driver_options)
        self.mock_task_instance.xcom_pull.assert_called_with(
            dag_id='input_one_component_id', key='input_one_key')

        # The execution payload the downstream executor task reads from XCom.
        expected_pushes = [
            mock.call(key='_exec_inputs',
                      value=types.jsonify_tfx_type_dict(input_dict)),
            mock.call(key='_exec_outputs',
                      value=types.jsonify_tfx_type_dict(output_dict)),
            mock.call(key='_exec_properties',
                      value=json.dumps(exec_properties)),
            mock.call(key='_execution_id', value=execution_id),
        ]
        self.mock_task_instance.xcom_push.assert_has_calls(expected_pushes)

        self.assertEqual(check_result, 'uncached_branch')