def test_cached_execution(self, mock_metadata_class, mock_driver_class,
                          mock_executor_class, mock_docker_operator_class,
                          mock_get_logger):
  """Cache hit: the adapter publishes outputs and returns the cached branch.

  The mocked driver returns an ExecutionDecision built without an
  execution_id. This test expects the adapter to treat that as a cache hit:
  it pushes the cached output artifacts to XCom and selects 'cached_branch'.
  """
  self._setup_mocks(mock_metadata_class, mock_driver_class,
                    mock_executor_class, mock_docker_operator_class,
                    mock_get_logger)
  adapter, input_dict, output_dict, exec_properties, _ = (
      self._setup_adapter_and_args())

  self.mock_task_instance.xcom_pull.side_effect = [self.input_one_json]
  # No execution_id in the decision -> expected to be handled as cached.
  self.mock_driver.prepare_execution.return_value = (
      base_driver.ExecutionDecision(input_dict, output_dict, exec_properties))

  check_result = adapter.check_cache_and_maybe_prepare_execution(
      'cached_branch', 'uncached_branch', ti=self.mock_task_instance)

  mock_get_logger.assert_called_with(self._logger_config)
  mock_driver_class.assert_called_with(
      logger=mock.ANY, metadata_handler=self.mock_metadata)
  # The previous `prepare_execution.called_with(...)` was not an assertion:
  # on a Mock it silently creates a child attribute (and raises
  # AttributeError on Python 3.12+). Verify the driver was actually consulted.
  # NOTE(review): arguments are not pinned because the adapter may pass its
  # own (XCom-updated) input dict rather than the objects created here.
  self.mock_driver.prepare_execution.assert_called_once()
  self.mock_task_instance.xcom_pull.assert_called_with(
      dag_id='input_one_component_id', key='input_one_key')
  self.mock_task_instance.xcom_push.assert_called_with(
      key='output_one_key', value=self.output_one_json)
  self.assertEqual(check_result, 'cached_branch')
def test_new_execution(self, mock_metadata_class, mock_driver_class,
                       mock_executor_class, mock_get_logger):
  """Cache miss: the adapter stages execution data and returns uncached branch.

  The mocked driver returns an ExecutionDecision carrying execution_id=12345.
  This test expects the adapter to push the serialized inputs, outputs,
  exec properties and execution id to XCom and select 'uncached_branch'.
  """
  self._setup_mocks(mock_metadata_class, mock_driver_class,
                    mock_executor_class, mock_get_logger)
  adapter, input_dict, output_dict, exec_properties, _ = (
      self._setup_adapter_and_args())

  self.mock_task_instance.xcom_pull.side_effect = [self.input_one_json]
  # A non-empty execution_id signals that the component must actually run.
  self.mock_driver.prepare_execution.return_value = (
      base_driver.ExecutionDecision(
          input_dict, output_dict, exec_properties, execution_id=12345))

  check_result = adapter.check_cache_and_maybe_prepare_execution(
      'cached_branch', 'uncached_branch', ti=self.mock_task_instance)

  mock_driver_class.assert_called_with(
      logger=mock.ANY, metadata_handler=self.mock_metadata)
  # The previous `prepare_execution.called_with(...)` was not an assertion:
  # on a Mock it silently creates a child attribute (and raises
  # AttributeError on Python 3.12+). Verify the driver was actually consulted.
  # NOTE(review): arguments are not pinned because the adapter may pass its
  # own (XCom-updated) input dict rather than the objects created here.
  self.mock_driver.prepare_execution.assert_called_once()
  self.mock_task_instance.xcom_pull.assert_called_with(
      dag_id='input_one_component_id', key='input_one_key')

  expected_pushes = [
      mock.call(key='_exec_inputs',
                value=types.jsonify_tfx_type_dict(input_dict)),
      mock.call(key='_exec_outputs',
                value=types.jsonify_tfx_type_dict(output_dict)),
      mock.call(key='_exec_properties', value=json.dumps(exec_properties)),
      mock.call(key='_execution_id', value=12345),
  ]
  self.mock_task_instance.xcom_push.assert_has_calls(expected_pushes)
  self.assertEqual(check_result, 'uncached_branch')