def get_link(self, operator, dttm):
    """Build the Cloud Console URL of the AI Platform job recorded by a task run.

    :param operator: operator whose task instance published the metadata
    :param dttm: execution date identifying the task instance
    :return: the console URL, or an empty string when the ``gcp_metadata``
        XCom value is missing or empty
    """
    ti = TaskInstance(task=operator, execution_date=dttm)
    metadata = ti.xcom_pull(task_ids=operator.task_id, key="gcp_metadata")
    if metadata:
        job_id = metadata['job_id']
        project_id = metadata['project_id']
        return f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}"
    return ''
Пример #2
0
    def test_mark_success_on_success_callback(self):
        """
        Test that ensures that when a task is marked success in the UI,
        on_success_callback gets executed
        """
        # Mutable flag flipped by the callback so the test can observe the call.
        data = {'called': False}

        def success_callback(context):
            self.assertEqual(context['dag_run'].dag_id, 'test_mark_success')
            data['called'] = True

        dag = DAG(dag_id='test_mark_success',
                  start_date=DEFAULT_DATE,
                  default_args={'owner': 'owner1'})

        task = DummyOperator(task_id='test_state_succeeded1',
                             dag=dag,
                             on_success_callback=success_callback)

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti,
                            ignore_ti_state=True,
                            executor=SequentialExecutor())
        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner

        job1.task_runner = StandardTaskRunner(job1)
        # Run the job in a separate process so this test can change the TI
        # state in the database while the job is in flight.
        process = multiprocessing.Process(target=job1.run)
        process.start()
        ti.refresh_from_db()
        # Poll (up to 50 times, 0.1 s apart) until the TI is seen as RUNNING.
        for _ in range(0, 50):
            if ti.state == State.RUNNING:
                break
            time.sleep(0.1)
            ti.refresh_from_db()
        self.assertEqual(State.RUNNING, ti.state)
        # Mark the TI as succeeded directly in the database.
        ti.state = State.SUCCESS
        session.merge(ti)
        session.commit()

        # Call the heartbeat callback in this process; the assertion below
        # expects it to have invoked success_callback.
        job1.heartbeat_callback(session=None)
        self.assertTrue(data['called'])
        process.join(timeout=10)
        self.assertFalse(process.is_alive())
Пример #3
0
    def test_mark_failure_on_failure_callback(self):
        """
        Test that ensures that mark_failure in the UI fails
        the task, and executes on_failure_callback
        """
        # Mutable record shared with the callbacks defined below.
        data = {'called': False}

        def check_failure(context):
            self.assertEqual(context['dag_run'].dag_id, 'test_mark_failure')
            data['called'] = True

        def task_function(ti):
            # print() does not perform %-interpolation, so format the pid explicitly.
            print(f"python_callable run in pid {os.getpid()}")
            with create_session() as session:
                self.assertEqual(State.RUNNING, ti.state)
                ti.log.info("Marking TI as failed 'externally'")
                ti.state = State.FAILED
                session.merge(ti)
                session.commit()

            time.sleep(60)
            # This should not happen -- the state change should be noticed and the task should get killed
            data['reached_end_of_sleep'] = True

        with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
            task = PythonOperator(task_id='test_state_succeeded1',
                                  python_callable=task_function,
                                  on_failure_callback=check_failure)

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(run_id="test",
                          state=State.RUNNING,
                          execution_date=DEFAULT_DATE,
                          start_date=DEFAULT_DATE,
                          session=session)
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()

        job1 = LocalTaskJob(task_instance=ti,
                            ignore_ti_state=True,
                            executor=SequentialExecutor())
        with timeout(30):
            # This should be _much_ shorter to run.
            # If you change this limit, make the sleep in the callable above bigger
            job1.run()

        ti.refresh_from_db()
        self.assertEqual(ti.state, State.FAILED)
        self.assertTrue(data['called'])
        self.assertNotIn(
            'reached_end_of_sleep', data,
            'Task should not have been allowed to run to completion')
    def test_init_with_template_cluster_label(self):
        """A templated ``cluster_label`` is rendered from the operator's ``params``."""
        dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
        operator = QuboleOperator(
            task_id=TASK_ID,
            dag=dag,
            cluster_label='{{ params.cluster_label }}',
            params={'cluster_label': 'default'},
        )

        task_instance = TaskInstance(operator, DEFAULT_DATE)
        task_instance.render_templates()

        # Rendering replaces the template on the operator itself.
        self.assertEqual(operator.cluster_label, 'default')
Пример #5
0
def test_get_dated_main_runner_handles_zero_shift():
    """Run the dated main-runner operator once and check the date string it forwards."""
    date_format = '%Y-%m-%d'
    dag = DAG(dag_id='test_dag',
              start_date=datetime.strptime('2019-01-01', date_format))
    # The execution date is the same day, but timezone-aware (UTC).
    naive_day = datetime.strptime('2019-01-01', date_format)
    execution_date = naive_day.replace(tzinfo=timezone.utc)
    main_func = PickleMock()
    shift = timedelta(minutes=1)
    runner = op_util.get_dated_main_runner_operator(dag, main_func, shift)
    task_instance = TaskInstance(runner, execution_date)
    task_instance.run(ignore_task_deps=True,
                      ignore_ti_state=True,
                      test_mode=True)
    main_func.assert_called_with('2019-01-01')
Пример #6
0
    def test_try_adopt_task_instances(self):
        """Both task instances handed to the executor are adopted and tracked."""
        exec_date = timezone.utcnow() - timedelta(minutes=2)
        start_date = timezone.utcnow() - timedelta(days=2)
        queued_dttm = timezone.utcnow() - timedelta(minutes=1)
        try_number = 1

        with DAG("test_try_adopt_task_instances_none") as dag:
            first_task = BaseOperator(task_id="task_1", start_date=start_date)
            second_task = BaseOperator(task_id="task_2", start_date=start_date)

        def queued_ti(task, external_id):
            # Build a task instance that carries an external executor id and a queue time.
            instance = TaskInstance(task=task, execution_date=exec_date)
            instance.external_executor_id = external_id
            instance.queued_dttm = queued_dttm
            return instance

        candidates = [queued_ti(first_task, '231'), queued_ti(second_task, '232')]

        executor = celery_executor.CeleryExecutor()
        # A freshly created executor tracks nothing yet.
        self.assertEqual(executor.running, set())
        self.assertEqual(executor.adopted_task_timeouts, {})
        self.assertEqual(executor.tasks, {})

        leftover = executor.try_adopt_task_instances(candidates)

        first_key = TaskInstanceKey(dag.dag_id, first_task.task_id, exec_date,
                                    try_number)
        second_key = TaskInstanceKey(dag.dag_id, second_task.task_id, exec_date,
                                     try_number)
        deadline = queued_dttm + executor.task_adoption_timeout

        self.assertEqual(executor.running, {first_key, second_key})
        self.assertEqual(dict(executor.adopted_task_timeouts),
                         {first_key: deadline, second_key: deadline})
        self.assertEqual(executor.tasks,
                         {first_key: AsyncResult("231"),
                          second_key: AsyncResult("232")})
        self.assertEqual(leftover, [])
Пример #7
0
    def verify_integrity(self, session=None):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        :param session: database session used to add the new rows and to commit.
            It is used unconditionally below, so it must not be ``None`` when the
            body runs (presumably injected by a decorator outside this view --
            TODO confirm).
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()  # set: only used for membership tests below
        for ti in tis:
            task_ids.add(ti.task_id)
            task = None
            try:
                task = dag.get_task(ti.task_id)
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                # Compare state values with ``!=``: an identity test (``is not``)
                # on equal-but-distinct objects would wrongly report "not running".
                elif self.state != State.RUNNING and not dag.partial:
                    self.log.warning("Failed to get task '{}' for dag '{}'. "
                                     "Marking it as removed.".format(ti, dag))
                    Stats.incr(
                        "task_removed_from_dag.{}".format(dag.dag_id), 1, 1)
                    ti.state = State.REMOVED

            is_task_in_dag = task is not None
            should_restore_task = is_task_in_dag and ti.state == State.REMOVED
            if should_restore_task:
                self.log.info("Restoring task '{}' which was previously "
                              "removed from DAG '{}'".format(ti, dag))
                Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1)
                ti.state = State.NONE

        # check for missing tasks
        for task in six.itervalues(dag.task_dict):
            # Tasks that start after this run's execution date are skipped,
            # unless the run is a backfill.
            if task.start_date > self.execution_date and not self.is_backfill:
                continue

            if task.task_id not in task_ids:
                Stats.incr(
                    "task_instance_created-{}".format(task.__class__.__name__),
                    1, 1)
                # add TaskState to db
                ti = TaskInstance(task, self.execution_date)
                ts = TaskState(ti)
                # Look the handler up once and reuse the result.
                event_handler = task.event_met_handler()
                if event_handler is not None:
                    ts.event_handler = event_handler
                session.add(ti)
                session.add(ts)

        session.commit()
Пример #8
0
    def test_try_adopt_task_instances_none(self):
        """A freshly built task instance is handed back as not adopted."""
        execution_date = datetime.utcnow()
        task_start = datetime.utcnow() - timedelta(days=2)

        with DAG("test_try_adopt_task_instances_none"):
            task = BaseOperator(task_id="task_1", start_date=task_start)

        candidates = [TaskInstance(task=task, execution_date=execution_date)]
        executor = celery_executor.CeleryExecutor()

        # Everything passed in comes back unchanged.
        not_adopted = executor.try_adopt_task_instances(candidates)
        self.assertEqual(not_adopted, candidates)
Пример #9
0
    def test_localtaskjob_maintain_heart_rate(self):
        """The job heartbeats once, exits on the second loop and never really sleeps."""
        dagbag = DagBag(dag_folder=TEST_DAG_FOLDER, include_examples=False)
        dag = dagbag.dags.get('test_localtaskjob_double_trigger')
        task = dag.get_task('test_localtaskjob_double_trigger_task')

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(run_id="test",
                          state=State.SUCCESS,
                          execution_date=DEFAULT_DATE,
                          start_date=DEFAULT_DATE,
                          session=session)

        ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti_run.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor())

        # First poll reports "still running" (None), second reports exit code 0,
        # so _execute() heartbeats once and leaves on its second loop.
        pending_codes = [None, 0]

        def next_return_code():
            return pending_codes.pop(0)

        started_at = time.time()
        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner

        with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_start, \
                patch.object(StandardTaskRunner, 'return_code') as mock_ret_code:
            mock_ret_code.side_effect = next_return_code
            job1.run()
            self.assertEqual(mock_start.call_count, 1)
            self.assertEqual(mock_ret_code.call_count, 2)
        finished_at = time.time()

        self.assertEqual(self.mock_base_job_sleep.call_count, 1)
        self.assertEqual(job1.state, State.SUCCESS)

        # The sleep call is patched and already checked to be called exactly
        # once, so the whole run must take less than one heart rate: nothing
        # unpatched may sleep to keep up with the heart rate.
        self.assertLess(finished_at - started_at, job1.heartrate)
        session.close()
Пример #10
0
def test_should_continue_with_cp(load_dag):
    """Execute the ``continue_if_data`` branch task with a ``has_data`` XCom in place."""
    dag_id = 'bq_to_wrench'
    table = 'staging.users'

    dag = load_dag(dag_id).get_dag(dag_id)
    task = dag.get_task(f'continue_if_data_{table}')
    assert isinstance(task, BranchPythonOperator)

    task_instance = TaskInstance(task=task, execution_date=datetime.now())
    # Store the XCom under this task's own id, keyed by the table name.
    XCom.set(
        key=table,
        value={'has_data': True},
        task_id=task.task_id,
        dag_id=dag.dag_id,
        execution_date=task_instance.execution_date,
    )

    task.execute(task_instance.get_template_context())
Пример #11
0
def test_sets_initial_checkpoint(load_dag, env, bigquery_helper):
    """Run the get-checkpoint task after wiping the table's existing checkpoints."""
    table = 'staging.users'
    dag_id = 'bq_to_wrench'

    # Start from a clean slate: drop every checkpoint row for the table.
    delete_sql = f"DELETE FROM `{env['project']}.system.checkpoint` WHERE table = '{table}'"
    bigquery_helper.query(delete_sql)

    # The task is expected to create an initial checkpoint when executed.
    dag = load_dag(dag_id).get_dag(dag_id)
    task = dag.get_task(f'get_checkpoint_{table}')
    assert isinstance(task, GetCheckpointOperator)
    task_instance = TaskInstance(task=task, execution_date=datetime.now())
    task.execute(task_instance.get_template_context())
Пример #12
0
    def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes):
        """Check the number of database queries issued by one ``LocalTaskJob.run()``.

        :param mock_get_task_runner: mock supplied by a patch decorator that is
            outside this view; its ``return_value`` stands in for the task runner
        :param return_codes: values the runner's ``return_code()`` should yield,
            one per call
        """
        unique_prefix = str(uuid.uuid4())
        dag = DAG(dag_id=f'{unique_prefix}_test_number_of_queries', start_date=DEFAULT_DATE)
        task = DummyOperator(task_id='test_state_succeeded1', dag=dag)

        dag.clear()
        dag.create_dagrun(run_id=unique_prefix, state=State.NONE)

        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)

        # Mock honours ``side_effect`` (singular). The previous ``side_effects``
        # only created an unused attribute, so ``return_codes`` was never applied.
        # NOTE(review): if this changes how many loops the job runs, the query
        # budget below may need to be re-baselined.
        mock_get_task_runner.return_value.return_code.side_effect = return_codes

        job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
        with assert_queries_count(13):
            job.run()
Пример #13
0
    def get_link(self, operator, dttm):
        """
        Build the link to the Qubole command result page of a task run.

        :param operator: operator
        :param dttm: datetime
        :return: url link, or an empty string when no ``qbol_cmd_id`` XCom
            value is available
        """
        task_instance = TaskInstance(task=operator, execution_date=dttm)
        connection = BaseHook.get_connection(operator.kwargs['qubole_conn_id'])
        if not (connection and connection.host):
            # No usable connection host: fall back to the public API endpoint.
            base_url = 'https://api.qubole.com/v2/analyze?command_id='
        else:
            # Swap a trailing "api" in the host for the analyze page path.
            base_url = re.sub(r'api$', 'v2/analyze?command_id=', connection.host)
        command_id = task_instance.xcom_pull(task_ids=operator.task_id,
                                             key='qbol_cmd_id')
        if not command_id:
            return ''
        return base_url + str(command_id)
Пример #14
0
    def run(self,
            start_date=None,
            end_date=None,
            ignore_first_depends_on_past=False,
            ignore_ti_state=False,
            mark_success=False):
        """
        Run one task instance of this task for every date in a range.

        :param start_date: first date; falls back to ``self.start_date``
        :param end_date: last date; falls back to ``self.end_date``, then to
            ``timezone.utcnow()``
        :param ignore_first_depends_on_past: passed on as
            ``ignore_depends_on_past`` for the date equal to the start date only
        :param ignore_ti_state: forwarded to every task instance run
        :param mark_success: forwarded to every task instance run
        """
        if not start_date:
            start_date = self.start_date
        if not end_date:
            end_date = self.end_date or timezone.utcnow()

        # The dates come from the DAG's own date_range().
        for run_date in self.dag.date_range(start_date, end_date=end_date):
            task_instance = TaskInstance(self, run_date)
            skip_past_check = (run_date == start_date
                               and ignore_first_depends_on_past)
            task_instance.run(mark_success=mark_success,
                              ignore_depends_on_past=skip_past_check,
                              ignore_ti_state=ignore_ti_state)
Пример #15
0
    def test_heartbeat_failed_fast(self):
        """
        Test that the task heartbeat sleeps when it fails fast, so that
        consecutive heartbeats stay one heart rate apart
        """
        # Let the patched sleep really sleep for this test.
        self.mock_base_job_sleep.side_effect = time.sleep

        with create_session() as session:
            dagbag = DagBag(dag_folder=TEST_DAG_FOLDER, include_examples=False)
            dag = dagbag.get_dag('test_heartbeat_failed_fast')
            task = dag.get_task('test_heartbeat_failed_fast_op')

            dag.create_dagrun(run_id="test_heartbeat_failed_fast_run",
                              state=State.RUNNING,
                              execution_date=DEFAULT_DATE,
                              start_date=DEFAULT_DATE,
                              session=session)
            ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
            ti.refresh_from_db()
            ti.state = State.RUNNING
            ti.hostname = get_hostname()
            ti.pid = 1
            session.commit()

            job = LocalTaskJob(task_instance=ti,
                               executor=MockExecutor(do_update=False))
            job.heartrate = 2
            heartbeat_records = []

            def record_heartbeat(session):
                # Remember when each heartbeat happened.
                heartbeat_records.append(job.latest_heartbeat)

            job.heartbeat_callback = record_heartbeat
            job._execute()
            self.assertGreater(len(heartbeat_records), 2)
            # Each pair of consecutive heartbeats must be one heart rate apart
            # (within 0.05 s).
            for earlier, later in zip(heartbeat_records, heartbeat_records[1:]):
                gap = (later - earlier).total_seconds()
                self.assertAlmostEqual(gap, job.heartrate, delta=0.05)
Пример #16
0
    def run(self,
            start_date: Optional[datetime] = None,
            end_date: Optional[datetime] = None,
            ignore_first_depends_on_past: bool = False,
            ignore_ti_state: bool = False,
            mark_success: bool = False) -> None:
        """
        Run one task instance of this task for every date in a range.

        :param start_date: first date; defaults to ``self.start_date``
        :param end_date: last date; defaults to ``self.end_date`` and, failing
            that, ``timezone.utcnow()``
        :param ignore_first_depends_on_past: passed on as
            ``ignore_depends_on_past`` only for the date equal to the start date
        :param ignore_ti_state: forwarded to every task instance run
        :param mark_success: forwarded to every task instance run
        """
        first_date = start_date or self.start_date
        last_date = end_date or self.end_date or timezone.utcnow()
        dates = self.dag.date_range(first_date, end_date=last_date)

        for logical_date in dates:
            task_instance = TaskInstance(self, logical_date)
            # Only the run on the start date may skip the depends_on_past check.
            ignore_past = (logical_date == first_date
                           and ignore_first_depends_on_past)
            task_instance.run(
                mark_success=mark_success,
                ignore_depends_on_past=ignore_past,
                ignore_ti_state=ignore_ti_state,
            )
    def test_PyIdempaOp_idempotent(self):
        """Executing the operator twice leaves the output file written only once."""
        self.dag = DAG(
            TEST_DAG_ID,
            schedule_interval="@daily",
            default_args={"start_date": datetime.now()},
        )

        with TemporaryDirectory() as tempdir:
            output_path = f"{tempdir}/test_file.txt"

            def file_text():
                # Current content of the output file.
                with open(output_path, "r") as handle:
                    return handle.read()

            # The output must not exist before the first run.
            self.assertFalse(os.path.exists(output_path))

            self.op = PythonIdempatomicFileOperator(
                dag=self.dag,
                task_id="test",
                output_pattern=output_path,
                python_callable=self.f,
            )
            self.ti = TaskInstance(task=self.op, execution_date=datetime.now())

            # First execution: the file is produced.
            first_result = self.op.execute(self.ti.get_template_context())
            self.assertEqual(first_result, output_path)
            self.assertFalse(self.op.previously_completed)
            self.assertTrue(os.path.exists(output_path))
            self.assertEqual(file_text(), "test")

            # Second execution: the path is still returned, but the operator
            # reports the work as already done.
            second_result = self.op.execute(self.ti.get_template_context())
            self.assertEqual(second_result, output_path)
            self.assertTrue(self.op.previously_completed)

            # Had the callable run again, the content would be 'testtest'.
            self.assertEqual(file_text(), "test")

            # Calling the function directly does write a second time.
            self.f(output_path)
            self.assertEqual(file_text(), "testtest")
Пример #18
0
    def test_mark_success_no_kill(self):
        """
        Test that ensures that mark_success in the UI doesn't cause
        the task to fail, and that the task exits
        """
        dagbag = DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        dag = dagbag.dags.get('test_mark_success')
        task = dag.get_task('task1')

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
        # Run the job in a separate process so this test can change the TI
        # state in the database while the job is in flight.
        process = multiprocessing.Process(target=job1.run)
        process.start()
        ti.refresh_from_db()
        # Poll (up to 50 times, 0.1 s apart) until the TI is seen as RUNNING.
        for _ in range(0, 50):
            if ti.state == State.RUNNING:
                break
            time.sleep(0.1)
            ti.refresh_from_db()
        self.assertEqual(State.RUNNING, ti.state)
        # Mark the TI as succeeded directly in the database.
        ti.state = State.SUCCESS
        session.merge(ti)
        session.commit()

        # The job process must exit within 10 s, and the TI must still be
        # SUCCESS afterwards.
        process.join(timeout=10)
        self.assertFalse(process.is_alive())
        ti.refresh_from_db()
        self.assertEqual(State.SUCCESS, ti.state)
    def test_get_redirect_url(self):
        """The 'Go to QDS' extra link resolves only for the run that pushed a command id."""
        dag = DAG(DAG_ID, start_date=DEFAULT_DATE)

        with dag:
            operator = QuboleOperator(
                task_id=TASK_ID,
                qubole_conn_id=TEST_CONN,
                command_type='shellcmd',
                parameters="param1 param2",
                dag=dag,
            )

        task_instance = TaskInstance(task=operator, execution_date=DEFAULT_DATE)
        task_instance.xcom_push('qbol_cmd_id', 12345)

        # Positive case: the date that has a command id yields the full link.
        link = operator.get_extra_links(DEFAULT_DATE, 'Go to QDS')
        self.assertEqual(link, 'http://localhost/v2/analyze?command_id=12345')

        # Negative case: another date yields an empty link.
        other_date = datetime(2017, 1, 2)
        empty_link = operator.get_extra_links(other_date, 'Go to QDS')
        self.assertEqual(empty_link, '')
    def skip(self, dag_run, execution_date, tasks, session=None):
        """
        Mark the task instances of the given tasks as skipped.

        With a dag run, the matching rows are updated in one query. Without
        one, a task instance is built per task for ``execution_date`` and merged.

        :param dag_run: the DagRun for which to set the tasks to skipped
        :param execution_date: execution_date, only used when there is no dag run
        :param tasks: tasks to skip (not task_ids); nothing happens if empty
        :param session: db session to use
        :raises ValueError: if there is neither a dag run nor an execution date
        """
        if not tasks:
            return

        task_ids = [task.task_id for task in tasks]
        now = timezone.utcnow()

        if not dag_run:
            if execution_date is None:
                raise ValueError("Execution date is None and no dag run")

            self.log.warning("No DAG RUN present this should not happen")
            # this is defensive against dag runs that are not complete
            for task in tasks:
                ti = TaskInstance(task, execution_date=execution_date)
                ti.state = State.SKIPPED
                ti.start_date = now
                ti.end_date = now
                session.merge(ti)

            session.commit()
            return

        matching_tis = session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_run.dag_id,
            TaskInstance.execution_date == dag_run.execution_date,
            TaskInstance.task_id.in_(task_ids))
        skipped_values = {
            TaskInstance.state: State.SKIPPED,
            TaskInstance.start_date: now,
            TaskInstance.end_date: now,
        }
        matching_tis.update(skipped_values, synchronize_session=False)
        session.commit()
Пример #21
0
    def test_localtaskjob_double_trigger(self):
        """A second LocalTaskJob does not start a runner for a TI already marked RUNNING."""
        dagbag = DagBag(dag_folder=TEST_DAG_FOLDER, include_examples=False)
        dag = dagbag.dags.get('test_localtaskjob_double_trigger')
        task = dag.get_task('test_localtaskjob_double_trigger_task')

        session = settings.Session()

        dag.clear()
        dag_run = dag.create_dagrun(run_id="test",
                                    state=State.SUCCESS,
                                    execution_date=DEFAULT_DATE,
                                    start_date=DEFAULT_DATE,
                                    session=session)

        # Store the TI as RUNNING on this host with pid 1.
        stored_ti = dag_run.get_task_instance(task_id=task.task_id, session=session)
        stored_ti.state = State.RUNNING
        stored_ti.hostname = get_hostname()
        stored_ti.pid = 1
        session.merge(stored_ti)
        session.commit()

        ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti_run.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor())
        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner

        with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_start:
            job1.run()
            mock_start.assert_not_called()

        # The stored TI must be unchanged by the second job.
        stored_ti = dag_run.get_task_instance(task_id=task.task_id, session=session)
        self.assertEqual(stored_ti.pid, 1)
        self.assertEqual(stored_ti.state, State.RUNNING)

        session.close()
Пример #22
0
    def test_error_sending_task(self):
        """A queued task whose send fails is removed from the queue and reported FAILED."""
        def fake_execute_command():
            pass

        with _prepare_app(execute=fake_execute_command):
            # fake_execute_command takes no arguments while execute_command takes 1,
            # which will cause TypeError when calling task.apply_async()
            executor = celery_executor.CeleryExecutor()
            operator = BashOperator(
                task_id="test",
                bash_command="true",
                dag=DAG(dag_id='id'),
                start_date=datetime.now(),
            )
            when = datetime.now()
            simple_ti = SimpleTaskInstance(
                ti=TaskInstance(task=operator, execution_date=datetime.now()))
            key = ('fail', 'fake_simple_ti', when, 0)
            executor.queued_tasks[key] = ('command', 1, None, simple_ti)
            executor.heartbeat()

        self.assertEqual(0, len(executor.queued_tasks), "Task should no longer be queued")
        # The failure is recorded in the event buffer under the same key.
        self.assertEqual(executor.event_buffer[key][0], State.FAILED)
Пример #23
0
    def _kill_zombies(self, dag, zombies, session):
        """
        Run failure handling for every zombie whose task belongs to ``dag``.

        Copy-pasted from airflow.models.dagbag.DagBag.kill_zombies.

        :param dag: DAG whose tasks are matched against the zombies
        :param zombies: zombie task instance records; ones whose task id is
            not in ``dag`` are ignored
        :param session: db session, committed once after all zombies are handled
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        for zombie in zombies:
            if zombie.task_id not in dag.task_ids:
                continue

            ti = TaskInstance(dag.get_task(zombie.task_id), zombie.execution_date)
            # Get properties needed for failure handling from SimpleTaskInstance.
            for attr in ('start_date', 'end_date', 'try_number', 'state'):
                setattr(ti, attr, getattr(zombie, attr))
            # test_mode is not overridden here; the freshly built instance's
            # value is passed through as-is.
            error = "{} detected as zombie".format(ti)
            ti.handle_failure(error, ti.test_mode, ti.get_template_context())
            self.log.info("Marked zombie job %s as %s", ti, ti.state)
        session.commit()
Пример #24
0
    def expand_mapped_task(self, run_id: str, *,
                           session: Session) -> Sequence["TaskInstance"]:
        """Create the mapped task instances for mapped task.

        The total number of instances is the product of the map lengths
        returned by ``self._get_map_lengths``. Any existing task instance of
        this task in the run whose ``map_index`` is at or beyond that number is
        set to REMOVED.

        :param run_id: id of the DAG run whose task instances are expanded
        :param session: database session used for all queries; it is flushed,
            not committed, here
        :return: The mapped task instances, in ascending order by map index.
        """
        from airflow.models.taskinstance import TaskInstance
        from airflow.settings import task_instance_mutation_hook

        total_length = functools.reduce(
            operator.mul,
            self._get_map_lengths(run_id, session=session).values())

        state: Optional[TaskInstanceState] = None
        # The not-yet-expanded placeholder (map_index == -1), if it is still unfinished.
        unmapped_ti: Optional[TaskInstance] = (
            session.query(TaskInstance).filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == -1,
                or_(TaskInstance.state.in_(State.unfinished),
                    TaskInstance.state.is_(None)),
            ).one_or_none())

        ret: List[TaskInstance] = []

        if unmapped_ti:
            # The unmapped task instance still exists and is unfinished, i.e. we
            # haven't tried to run it before.
            if total_length < 1:
                # If the upstream maps this to a zero-length value, simply mark the
                # unmapped task instance as SKIPPED (if needed).
                self.log.info(
                    "Marking %s as SKIPPED since the map has %d values to expand",
                    unmapped_ti,
                    total_length,
                )
                unmapped_ti.state = TaskInstanceState.SKIPPED
                session.flush()
                return ret
            # Otherwise convert this into the first mapped index, and create
            # TaskInstance for other indexes.
            unmapped_ti.map_index = 0
            state = unmapped_ti.state
            self.log.debug("Updated in place to become %s", unmapped_ti)
            ret.append(unmapped_ti)
            indexes_to_map = range(1, total_length)
        else:
            # Only create "missing" ones.
            current_max_mapping = (session.query(
                func.max(TaskInstance.map_index)).filter(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                ).scalar())
            # MAX() over zero rows yields None; treat that as "no index exists
            # yet" so the addition below cannot raise TypeError and every index
            # from 0 upwards gets created.
            if current_max_mapping is None:
                current_max_mapping = -1
            indexes_to_map = range(current_max_mapping + 1, total_length)

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
            # TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator
            ti = TaskInstance(self,
                              run_id=run_id,
                              map_index=index,
                              state=state)  # type: ignore
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            # merge() returns the session-bound instance; attach the task to it.
            ti.task = self
            ret.append(ti)

        # Set to "REMOVED" any (old) TaskInstances with map indices greater
        # than the current map value
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index >= total_length,
        ).update({TaskInstance.state: TaskInstanceState.REMOVED})

        session.flush()

        return ret
Пример #25
0
    def expand_mapped_task(
            self, run_id: str, *,
            session: Session) -> Tuple[Sequence["TaskInstance"], int]:
        """Create the mapped task instances for this mapped task in one DAG run.

        The expansion length is read from the task's expand input. Depending on
        that length, and on whether an unfinished unmapped task instance
        (``map_index == -1``) still exists, this either marks the unmapped
        instance ``UPSTREAM_FAILED`` / ``SKIPPED``, converts it in place into
        map index 0 and creates the remaining indexes, or creates only the
        indexes above the highest one already stored. Task instances whose map
        index is at or beyond the computed length are set to ``REMOVED``, and
        the session is flushed before returning.

        :param run_id: ID of the DAG run whose task instances are expanded.
        :param session: Database session used for every query, merge and update.
        :return: The newly created mapped TaskInstances (if any) in ascending order by map index, and the
            maximum map_index. The list also contains the former unmapped instance when it was
            converted to index 0; the second element is ``-1`` when the map length is zero or
            could not be computed.
        """
        # Imported inside the method rather than at module level.
        # NOTE(review): presumably to avoid a circular import -- confirm.
        from airflow.models.taskinstance import TaskInstance
        from airflow.settings import task_instance_mutation_hook

        # None means "length unknown" (upstream values are not all available).
        total_length: Optional[int]
        try:
            total_length = self._get_specified_expand_input(
            ).get_total_map_length(run_id, session=session)
        except NotFullyPopulated as e:
            self.log.info(
                "Cannot expand %r for run %s; missing upstream values: %s",
                self,
                run_id,
                sorted(e.missing),
            )
            total_length = None

        # State given to every newly created TI; it stays None unless an
        # unmapped TI is found below, in which case that TI's state is reused.
        state: Optional[TaskInstanceState] = None
        # Look up the placeholder row (map_index == -1) for this task and run,
        # but only if it is still unfinished or has no state at all.
        unmapped_ti: Optional[TaskInstance] = (
            session.query(TaskInstance).filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == -1,
                or_(TaskInstance.state.in_(State.unfinished),
                    TaskInstance.state.is_(None)),
            ).one_or_none())

        # Accumulates the TIs returned to the caller, in ascending map index.
        all_expanded_tis: List[TaskInstance] = []

        if unmapped_ti:
            # The unmapped task instance still exists and is unfinished, i.e. we
            # haven't tried to run it before.
            if total_length is None:
                # If the map length cannot be calculated (due to unavailable
                # upstream sources), fail the unmapped task.
                unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
                indexes_to_map: Iterable[int] = ()
            elif total_length < 1:
                # If the upstream maps this to a zero-length value, simply mark
                # the unmapped task instance as SKIPPED (if needed).
                self.log.info(
                    "Marking %s as SKIPPED since the map has %d values to expand",
                    unmapped_ti,
                    total_length,
                )
                unmapped_ti.state = TaskInstanceState.SKIPPED
                indexes_to_map = ()
            else:
                # Otherwise convert this into the first mapped index, and create
                # TaskInstance for other indexes.
                unmapped_ti.map_index = 0
                self.log.debug("Updated in place to become %s", unmapped_ti)
                all_expanded_tis.append(unmapped_ti)
                indexes_to_map = range(1, total_length)
            # Captured after the branches above, so it reflects any state that
            # was just assigned to the unmapped TI.
            state = unmapped_ti.state
        elif not total_length:
            # Nothing to fixup.
            # (Covers both an unknown length, None, and a zero length.)
            indexes_to_map = ()
        else:
            # Only create "missing" ones.
            # The highest map_index stored for this task/run decides where the
            # new indexes start.
            # NOTE(review): an aggregate over zero matching rows yields None, and
            # `None + 1` below would raise TypeError -- confirm that at least one
            # row always exists when this branch is reached.
            current_max_mapping = (session.query(
                func.max(TaskInstance.map_index)).filter(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                ).scalar())
            indexes_to_map = range(current_max_mapping + 1, total_length)

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
            ti = TaskInstance(self,
                              run_id=run_id,
                              map_index=index,
                              state=state)
            self.log.debug("Expanding TIs upserted %s", ti)
            # The mutation hook runs on the new TI before it is merged into the
            # session.
            task_instance_mutation_hook(ti)
            # merge() returns the session-attached instance; that one is kept.
            ti = session.merge(ti)
            ti.refresh_from_task(
                self)  # session.merge() loses task information.
            all_expanded_tis.append(ti)

        # Coerce the None case to 0 -- these two are almost treated identically,
        # except the unmapped ti (if exists) is marked to different states.
        total_expanded_ti_count = total_length or 0

        # Set to "REMOVED" any (old) TaskInstances with map indices greater
        # than the current map value
        # (the filter is `>=`, so the index equal to the count is included; the
        # unmapped row at -1 is never matched because the count is at least 0).
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index >= total_expanded_ti_count,
        ).update({TaskInstance.state: TaskInstanceState.REMOVED})

        session.flush()
        # Highest valid map index is count - 1, which is -1 when nothing is expanded.
        return all_expanded_tis, total_expanded_ti_count - 1
Пример #26
0
    def test_mark_success_on_success_callback(self):
        """
        Test that ensures that when a running task is marked success in the UI,
        on_success_callback gets executed.

        The task is run by a LocalTaskJob in a separate process; this test then
        sets the task instance's state to SUCCESS in the database and checks
        that the callback ran exactly once and that the task's callable never
        reached its end.
        """
        # use shared memory value so we can properly track value change even if
        # it's been updated across processes.
        success_callback_called = Value('i', 0)
        # Starts at 1 and is only reset to 0 if the task callable runs to
        # completion, so a final value of 1 means the callable was cut short.
        task_terminated_externally = Value('i', 1)
        shared_mem_lock = Lock()

        def success_callback(context):
            # Count invocations first, then sanity-check the context that was
            # passed in.
            with shared_mem_lock:
                success_callback_called.value += 1
            assert context['dag_run'].dag_id == 'test_mark_success'

        dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})

        def task_function(ti):
            # pylint: disable=unused-argument
            # Sleep far longer than the test waits, so the task is still running
            # when its state is changed from outside.
            time.sleep(60)
            # This should not happen -- the state change should be noticed and the task should get killed
            with shared_mem_lock:
                task_terminated_externally.value = 0

        task = PythonOperator(
            task_id='test_state_succeeded1',
            python_callable=task_function,
            on_success_callback=success_callback,
            dag=dag,
        )

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
        job1.task_runner = StandardTaskRunner(job1)

        # NOTE(review): presumably disposes the engine's pooled DB connections so
        # the child process does not reuse the parent's -- confirm.
        settings.engine.dispose()
        process = multiprocessing.Process(target=job1.run)
        process.start()

        # Poll the database (up to 25 times, 0.2 s apart) until the task
        # instance is reported as RUNNING.
        for _ in range(0, 25):
            ti.refresh_from_db()
            if ti.state == State.RUNNING:
                break
            time.sleep(0.2)
        assert ti.state == State.RUNNING
        # Change the state directly in the database to simulate the task being
        # marked as success from outside the running job.
        ti.state = State.SUCCESS
        session.merge(ti)
        session.commit()

        # Wait at most 10 s for the job process, then verify: the callback fired
        # exactly once, the task callable never finished, and the process exited.
        process.join(timeout=10)
        assert success_callback_called.value == 1
        assert task_terminated_externally.value == 1
        assert not process.is_alive()
Пример #27
0
 def get_link(self, operator, dttm):
     """Return the Airflow URL built from the operator's external DAG and ``dttm``.

     :param operator: Operator whose template fields are rendered and whose
         ``external_dag_id`` is placed in the query.
     :param dttm: Execution date; used for the task instance and, in ISO
         format, as the ``execution_date`` query value.
     :return: Whatever ``build_airflow_url_with_query`` produces for the query.
     """
     # Render the templated fields before reading ``external_dag_id``.
     # NOTE(review): presumably ``external_dag_id`` can itself be templated -- confirm.
     template_context = TaskInstance(task=operator, execution_date=dttm).get_template_context()
     operator.render_template_fields(template_context)
     return build_airflow_url_with_query({
         "dag_id": operator.external_dag_id,
         "execution_date": dttm.isoformat(),
     })
    def test_retry_on_error_sending_task(self):
        """Test that Airflow retries publishing tasks to Celery Broker at least 3 times.

        Every publish attempt is forced to time out. After each of the first
        three heartbeats the task must still be queued with an incremented
        retry counter and a matching log line; the fourth heartbeat must drop
        it from the queue and record it as FAILED in the event buffer.
        """

        with _prepare_app(), self.assertLogs(
                celery_executor.log) as captured_logs, mock.patch.object(
                    # Mock `with timeout()` to _instantly_ fail.
                    celery_executor.timeout,
                    "__enter__",
                    side_effect=AirflowTaskTimeout,
                ):
            executor = celery_executor.CeleryExecutor()
            assert executor.task_publish_retries == {}
            assert executor.task_publish_max_retries == 3, "Assert Default Max Retries is 3"

            dag = DAG(dag_id='id')
            task = BashOperator(
                task_id="test",
                bash_command="true",
                dag=dag,
                start_date=datetime.now(),
            )
            when = datetime.now()
            simple_ti = SimpleTaskInstance(
                ti=TaskInstance(task=task, execution_date=datetime.now()))
            queued_entry = ('command', 1, None, simple_ti)
            key = ('fail', 'fake_simple_ti', when, 0)
            executor.queued_tasks[key] = queued_entry

            # Test that when heartbeat is called again, task is published again to Celery Queue
            # First attempt: counter moves to 2, task stays queued.
            executor.heartbeat()
            assert dict(executor.task_publish_retries) == {key: 2}
            assert len(executor.queued_tasks) == 1, "Task should remain in queue"
            assert executor.event_buffer == {}
            assert ("INFO:airflow.executors.celery_executor.CeleryExecutor:"
                    f"[Try 1 of 3] Task Timeout Error for Task: ({key})."
                    in captured_logs.output)

            # Second attempt: counter moves to 3, task stays queued.
            executor.heartbeat()
            assert dict(executor.task_publish_retries) == {key: 3}
            assert len(executor.queued_tasks) == 1, "Task should remain in queue"
            assert executor.event_buffer == {}
            assert ("INFO:airflow.executors.celery_executor.CeleryExecutor:"
                    f"[Try 2 of 3] Task Timeout Error for Task: ({key})."
                    in captured_logs.output)

            # Third attempt: counter moves to 4, task stays queued.
            executor.heartbeat()
            assert dict(executor.task_publish_retries) == {key: 4}
            assert len(executor.queued_tasks) == 1, "Task should remain in queue"
            assert executor.event_buffer == {}
            assert ("INFO:airflow.executors.celery_executor.CeleryExecutor:"
                    f"[Try 3 of 3] Task Timeout Error for Task: ({key})."
                    in captured_logs.output)

            # Retries exhausted: the task leaves the queue and is reported as failed.
            executor.heartbeat()
            assert dict(executor.task_publish_retries) == {}
            assert len(executor.queued_tasks) == 0, "Task should no longer be in queue"
            assert executor.event_buffer[key][0] == State.FAILED
Пример #29
0
 def get_link(self, operator, dttm):
     """Return the BigQuery console URL for the job ID stored in XCom.

     :param operator: Operator whose ``task_id`` is used for the XCom lookup.
     :param dttm: Execution date of the task instance to read XCom from.
     :return: The console URL with the job ID as the ``j`` query value, or an
         empty string when no (truthy) ``job_id`` XCom value is found.
     """
     task_instance = TaskInstance(task=operator, execution_date=dttm)
     bq_job_id = task_instance.xcom_pull(task_ids=operator.task_id, key='job_id')
     # Without a job ID there is nothing to link to.
     if not bq_job_id:
         return ''
     return 'https://console.cloud.google.com/bigquery?j={job_id}'.format(job_id=bq_job_id)
Пример #30
0
 def __create_task_instance(self, task):
     """Return a new TaskInstance for ``task`` with ``datetime.now()`` as its execution date."""
     return TaskInstance(task=task, execution_date=datetime.now())