def test_check_for_stalled_adopted_tasks(self):
    """Check that ``sync()`` fails the adopted tasks and clears their bookkeeping.

    Two task keys are registered on a fresh ``CeleryExecutor`` as adopted,
    each with a timeout of ``queued_dttm + executor.task_adoption_timeout``
    where ``queued_dttm`` lies 30 minutes in the past. After a single
    ``sync()`` both keys must be reported as FAILED in the event buffer, and
    both ``tasks`` and ``adopted_task_timeouts`` must be empty.
    """
    exec_date = timezone.utcnow() - timedelta(minutes=40)
    start_date = timezone.utcnow() - timedelta(days=2)
    queued_dttm = timezone.utcnow() - timedelta(minutes=30)
    try_number = 1

    # Operators are created inside the DAG context so they attach to `dag`.
    with DAG("test_check_for_stalled_adopted_tasks") as dag:
        operators = [
            BaseOperator(task_id=name, start_date=start_date)
            for name in ("task_1", "task_2")
        ]

    first_key, second_key = [
        TaskInstanceKey(dag.dag_id, operator.task_id, exec_date, try_number)
        for operator in operators
    ]

    executor = celery_executor.CeleryExecutor()
    executor.adopted_task_timeouts = {
        key: queued_dttm + executor.task_adoption_timeout
        for key in (first_key, second_key)
    }
    executor.tasks = {
        first_key: AsyncResult("231"),
        second_key: AsyncResult("232"),
    }

    executor.sync()

    expected_events = {
        first_key: (State.FAILED, None),
        second_key: (State.FAILED, None),
    }
    self.assertEqual(executor.event_buffer, expected_events)
    self.assertEqual(executor.tasks, {})
    self.assertEqual(executor.adopted_task_timeouts, {})
def _labels_to_key(self, labels: Dict[str, str]) -> Optional[TaskInstanceKey]:
    """Rebuild a ``TaskInstanceKey`` from the labels attached to a pod.

    The lookup proceeds in two stages:

    1. Query ``TaskInstance`` for an exact match on the labelled ``dag_id``,
       ``task_id`` and execution date.
    2. If nothing matches, fetch every task instance with that execution
       date and compare ``pod_generator.make_safe_label_value`` of each
       instance's ids against the labels, returning the instance's real
       (unsanitised) ids on a match.

    :param labels: pod labels; ``dag_id``, ``task_id`` and ``execution_date``
        are required, ``try_number`` is optional and defaults to 1.
    :return: the matching key, or ``None`` when a required label is missing
        or unparsable, or when no task instance matches.
    """
    # A malformed try_number is not fatal: warn and keep the default of 1.
    try_num = 1
    try:
        try_num = int(labels.get('try_number', '1'))
    except ValueError:
        self.log.warning("could not get try_number as an int: %s", labels.get('try_number', '1'))

    try:
        dag_id = labels['dag_id']
        task_id = labels['task_id']
        ex_time = self._label_safe_datestring_to_datetime(labels['execution_date'])
    except Exception as e:  # pylint: disable=broad-except
        self.log.warning('Error while retrieving labels; labels: %s; exception: %s', labels, e)
        return None

    with create_session() as session:
        task = (
            session.query(TaskInstance)
            .filter_by(task_id=task_id, dag_id=dag_id, execution_date=ex_time)
            .one_or_none()
        )
        if task:
            self.log.info(
                'Found matching task %s-%s (%s) with current state of %s',
                task.dag_id, task.task_id, task.execution_date, task.state,
            )
            return TaskInstanceKey(dag_id, task_id, ex_time, try_num)

        # Fix: the placeholders read "dag_id ..., task_id ..." so the values
        # must be passed in that order (they were previously swapped).
        self.log.warning(
            'task_id/dag_id are not safe to use as Kubernetes labels. This can cause '
            'severe performance regressions. Please see '
            '<https://kubernetes.io/docs/concepts/overview/working-with-objects'
            '/labels/#syntax-and-character-set>. '
            'Given dag_id: %s, task_id: %s',
            dag_id, task_id,
        )

        # Fallback: scan all task instances for this execution date and
        # compare their label-safe ids against the labels.
        tasks = session.query(TaskInstance).filter_by(execution_date=ex_time).all()
        self.log.info('Checking %s task instances.', len(tasks))
        for candidate in tasks:
            if (
                pod_generator.make_safe_label_value(candidate.dag_id) == dag_id
                and pod_generator.make_safe_label_value(candidate.task_id) == task_id
                and candidate.execution_date == ex_time
            ):
                self.log.info(
                    'Found matching task %s-%s (%s) with current state of %s',
                    candidate.dag_id, candidate.task_id, candidate.execution_date, candidate.state,
                )
                # Return the real ids, not the sanitised label values.
                return TaskInstanceKey(candidate.dag_id, candidate.task_id, ex_time, try_num)

    self.log.warning('Failed to find and match task details to a pod; labels: %s', labels)
    return None
def _annotations_to_key(self, annotations: Dict[str, str]) -> Optional[TaskInstanceKey]:
    """Build a ``TaskInstanceKey`` from pod annotations.

    Reads the ``dag_id``, ``task_id``, ``try_number`` and ``execution_date``
    entries; ``try_number`` is converted with ``int`` and ``execution_date``
    is parsed with ``parser.parse``. A missing entry raises ``KeyError``.
    """
    dag_id, task_id = annotations['dag_id'], annotations['task_id']
    attempt = int(annotations['try_number'])
    run_date = parser.parse(annotations['execution_date'])
    return TaskInstanceKey(dag_id, task_id, run_date, attempt)
def annotations_to_key(annotations: Dict[str, str]) -> Optional[TaskInstanceKey]:
    """Build a TaskInstanceKey based on pod annotations.

    ``dag_id``, ``task_id`` and ``try_number`` are read directly from the
    annotations. ``run_id`` is taken from the annotations when present and
    non-empty; otherwise, if an ``execution_date`` annotation exists, the
    run id is looked up from the task instance table.
    """
    log.debug("Creating task key for annotations %s", annotations)

    dag_id = annotations['dag_id']
    task_id = annotations['task_id']
    try_number = int(annotations['try_number'])
    run_id = annotations.get('run_id')

    needs_lookup = not run_id and 'execution_date' in annotations
    if needs_lookup:
        # Compat: Look up the run_id from the TI table!
        from airflow.models.dagrun import DagRun
        from airflow.models.taskinstance import TaskInstance
        from airflow.settings import Session

        execution_date = pendulum.parse(annotations['execution_date'])

        # Do _not_ use create-session, we don't want to expunge
        session = Session()
        run_id_query = session.query(TaskInstance.run_id).join(TaskInstance.dag_run)
        run_id_query = run_id_query.filter(
            TaskInstance.dag_id == dag_id,
            TaskInstance.task_id == task_id,
            DagRun.execution_date == execution_date,
        )
        run_id = run_id_query.scalar()

    return TaskInstanceKey(dag_id, task_id, run_id, try_number)
def test_get_event_buffer(self):
    """Check filtered and unfiltered reads of the executor event buffer.

    Three SUCCESS events are buffered: one for ``my_dag1`` and two for
    ``my_dag2``. Reading with the ``("my_dag1",)`` filter must return one
    event, the following unfiltered read must return the other two, and the
    buffer must then be empty.
    """
    executor = BaseExecutor()
    date = datetime.utcnow()
    try_number = 1

    dag_task_pairs = (
        ("my_dag1", "my_task1"),
        ("my_dag2", "my_task1"),
        ("my_dag2", "my_task2"),
    )
    keys = [
        TaskInstanceKey(dag_id, task_id, date, try_number)
        for dag_id, task_id in dag_task_pairs
    ]

    state = State.SUCCESS
    for key in keys:
        executor.event_buffer[key] = (state, None)

    self.assertEqual(len(executor.get_event_buffer(("my_dag1", ))), 1)
    self.assertEqual(len(executor.get_event_buffer()), 2)
    self.assertEqual(len(executor.event_buffer), 0)
def _annotations_to_key(self, annotations: Dict[str, str]) -> Optional[TaskInstanceKey]:
    """Build a ``TaskInstanceKey`` from pod annotations, logging them at debug level.

    Uses the ``dag_id``, ``task_id``, ``try_number`` (converted with ``int``)
    and ``execution_date`` (parsed with ``parser.parse``) annotations. A
    missing annotation raises ``KeyError``.
    """
    self.log.debug("Creating task key for annotations %s", annotations)

    dag = annotations['dag_id']
    task = annotations['task_id']
    attempt_number = int(annotations['try_number'])
    parsed_date = parser.parse(annotations['execution_date'])

    return TaskInstanceKey(dag, task, parsed_date, attempt_number)
def mock_task_fail(self, dag_id, task_id, date, try_number=1):
    """
    Register FAILED as the mock outcome for one particular task instance.

    The entry is stored in ``self.mock_task_results`` under the key built
    from ``(dag_id, task_id, date, try_number)``, so that a task instance
    identified by that tuple gets the FAILED state.
    """
    key = TaskInstanceKey(dag_id, task_id, date, try_number)
    self.mock_task_results[key] = State.FAILED
def test_try_adopt_task_instances(self):
    """Check that ``try_adopt_task_instances`` adopts both queued task instances.

    Two task instances carrying external executor ids ``'231'`` and ``'232'``
    are handed to a fresh ``CeleryExecutor``. Afterwards both keys must be in
    ``running``, have an adoption timeout of
    ``queued_dttm + executor.task_adoption_timeout``, map to the matching
    ``AsyncResult`` in ``tasks``, and nothing may be returned as not adopted.
    """
    exec_date = timezone.utcnow() - timedelta(minutes=2)
    start_date = timezone.utcnow() - timedelta(days=2)
    queued_dttm = timezone.utcnow() - timedelta(minutes=1)
    try_number = 1

    with DAG("test_try_adopt_task_instances_none") as dag:
        task_1 = BaseOperator(task_id="task_1", start_date=start_date)
        task_2 = BaseOperator(task_id="task_2", start_date=start_date)

    # Build one task instance per operator, each tagged with its external id.
    tis = []
    for operator, external_id in ((task_1, '231'), (task_2, '232')):
        task_instance = TaskInstance(task=operator, execution_date=exec_date)
        task_instance.external_executor_id = external_id
        task_instance.queued_dttm = queued_dttm
        tis.append(task_instance)

    executor = celery_executor.CeleryExecutor()
    # A fresh executor starts with no tracked state.
    self.assertEqual(executor.running, set())
    self.assertEqual(executor.adopted_task_timeouts, {})
    self.assertEqual(executor.tasks, {})

    not_adopted_tis = executor.try_adopt_task_instances(tis)

    key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, exec_date, try_number)
    key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, exec_date, try_number)
    adopted_keys = (key_1, key_2)

    self.assertEqual(executor.running, set(adopted_keys))

    expected_timeouts = {
        key: queued_dttm + executor.task_adoption_timeout for key in adopted_keys
    }
    self.assertEqual(dict(executor.adopted_task_timeouts), expected_timeouts)

    expected_tasks = {key_1: AsyncResult("231"), key_2: AsyncResult("232")}
    self.assertEqual(executor.tasks, expected_tasks)
    self.assertEqual(not_adopted_tis, [])
def _schedule_task(self, scheduling_event: TaskSchedulingEvent):
    """Pass a scheduling event on to the executor.

    Builds a ``TaskInstanceKey`` from the event's ``dag_id``, ``task_id``,
    ``execution_date`` and ``try_number``, then calls
    ``self.executor.schedule_task`` with that key and the event's ``action``.
    """
    event = scheduling_event
    key = TaskInstanceKey(
        event.dag_id,
        event.task_id,
        event.execution_date,
        event.try_number,
    )
    self.executor.schedule_task(key, event.action)
def _send_message(self, ti):
    """Call ``self.send_message`` with the ``TaskInstanceKey`` of ``ti``.

    The key is built from the ``dag_id``, ``task_id``, ``execution_date``
    and ``try_number`` attributes of ``ti``.
    """
    send = self.send_message
    key = TaskInstanceKey(ti.dag_id, ti.task_id, ti.execution_date, ti.try_number)
    send(key)