Example #1
File: sftp.py Project: leahecole/airflow
    def poke(self, context: 'Context') -> bool:
        self.hook = SFTPHook(self.sftp_conn_id)
        self.log.info(f"Poking for {self.path}, with pattern {self.file_pattern}")

        if self.file_pattern:
            file_from_pattern = self.hook.get_file_by_pattern(self.path, self.file_pattern)
            if file_from_pattern:
                actual_file_to_check = file_from_pattern
            else:
                return False
        else:
            actual_file_to_check = self.path

        try:
            mod_time = self.hook.get_mod_time(actual_file_to_check)
            self.log.info('Found File %s last modified: %s', str(actual_file_to_check), str(mod_time))
        except OSError as e:
            if e.errno != SFTP_NO_SUCH_FILE:
                raise e
            return False
        self.hook.close_conn()
        if self.newer_than:
            _mod_time = convert_to_utc(datetime.strptime(mod_time, '%Y%m%d%H%M%S'))
            _newer_than = convert_to_utc(self.newer_than)
            return _newer_than <= _mod_time
        else:
            return True
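The comparison above hinges on normalizing both timestamps to UTC before ordering them. A minimal sketch of that step, assuming a '%Y%m%d%H%M%S' mtime string like the one SFTPHook.get_mod_time returns and a default (UTC) Airflow timezone configuration; the concrete values are made up:

from datetime import datetime

from airflow.utils.timezone import convert_to_utc

mod_time = "20230706050603"  # hypothetical mtime string from get_mod_time
_mod_time = convert_to_utc(datetime.strptime(mod_time, '%Y%m%d%H%M%S'))
_newer_than = convert_to_utc(datetime(2023, 7, 1))
print(_newer_than <= _mod_time)  # True: the file was modified after the cutoff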
Example #2
File: base.py Project: karankale/airflow
    def map(self, **kwargs: "MapArgument") -> XComArg:
        self._validate_arg_names("map", kwargs)

        partial_kwargs = self.kwargs.copy()

        dag = partial_kwargs.pop("dag", DagContext.get_current_dag())
        task_group = partial_kwargs.pop(
            "task_group", TaskGroupContext.get_current_task_group(dag))
        task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag,
                                     task_group)
        params = partial_kwargs.pop("params", None)

        # Logic here should be kept in sync with BaseOperatorMeta.partial().
        if "task_concurrency" in partial_kwargs:
            raise TypeError("unexpected argument: task_concurrency")
        if partial_kwargs.get("wait_for_downstream"):
            partial_kwargs["depends_on_past"] = True
        start_date = timezone.convert_to_utc(
            partial_kwargs.pop("start_date", None))
        end_date = timezone.convert_to_utc(partial_kwargs.pop(
            "end_date", None))
        if partial_kwargs.get("pool") is None:
            partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
        partial_kwargs["retries"] = parse_retries(
            partial_kwargs.get("retries", DEFAULT_RETRIES))
        partial_kwargs["retry_delay"] = coerce_retry_delay(
            partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY), )
        partial_kwargs["resources"] = coerce_resources(
            partial_kwargs.get("resources"))
        partial_kwargs.setdefault("executor_config", {})
        partial_kwargs.setdefault("op_args", [])
        partial_kwargs.setdefault("op_kwargs", {})

        # Mypy does not work well with a subclassed attrs class :(
        _MappedOperator = cast(Any, DecoratedMappedOperator)
        operator = _MappedOperator(
            operator_class=self.operator_class,
            mapped_kwargs={},
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_dummy=False,
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            multiple_outputs=self.multiple_outputs,
            python_callable=self.function,
            mapped_op_kwargs=kwargs,
        )
        return XComArg(operator=operator)
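Note that convert_to_utc tolerates the None defaults popped above: in the Airflow versions these examples come from, falsy input is returned unchanged before any timezone handling, which is why start_date and end_date can be passed straight through without a guard. A minimal check:

from airflow.utils import timezone

assert timezone.convert_to_utc(None) is None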
Example #3
    def test_convert_to_utc(self):
        naive = datetime.datetime(2011, 9, 1, 13, 20, 30)
        utc = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=UTC)
        assert utc == timezone.convert_to_utc(naive)

        eat = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)
        utc = datetime.datetime(2011, 9, 1, 10, 20, 30, tzinfo=UTC)
        assert utc == timezone.convert_to_utc(eat)
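The assertions above pin down both halves of convert_to_utc's contract: naive input is treated as local time in the configured default timezone (UTC in the test setup), and aware input is shifted to UTC. The EAT case can be reproduced without the test fixtures, since EAT is UTC+3:

import datetime

from airflow.utils import timezone

eat = datetime.datetime(2011, 9, 1, 13, 20, 30,
                        tzinfo=datetime.timezone(datetime.timedelta(hours=3)))
print(timezone.convert_to_utc(eat))  # 2011-09-01 10:20:30+00:00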
Example #4
    def test_convert_to_utc(self):
        naive = datetime.datetime(2011, 9, 1, 13, 20, 30)
        utc = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=UTC)
        self.assertEqual(utc, timezone.convert_to_utc(naive))

        eat = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)
        utc = datetime.datetime(2011, 9, 1, 10, 20, 30, tzinfo=UTC)
        self.assertEqual(utc, timezone.convert_to_utc(eat))
Example #5
    def test_convert_to_utc(self):
        naive = datetime.datetime(2011, 9, 1, 13, 20, 30)
        utc = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=UTC)
        self.assertEqual(utc, timezone.convert_to_utc(naive))

        eat = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)
        utc = datetime.datetime(2011, 9, 1, 10, 20, 30, tzinfo=UTC)
        self.assertEqual(utc, timezone.convert_to_utc(eat))
Example #6
    def _get_prev(self, current: DateTime) -> DateTime:
        """Get the first schedule before specified time, with DST fixed."""
        naive = make_naive(current, self._timezone)
        cron = croniter(self._expression, start_time=naive)
        scheduled = cron.get_prev(datetime.datetime)
        if not self._should_fix_dst:
            return convert_to_utc(make_aware(scheduled, self._timezone))
        delta = naive - scheduled
        return convert_to_utc(current.in_timezone(self._timezone) - delta)
Example #7
def secondary_training_status_message(job_description, prev_description):
    """
    Returns a string containing the start time and the secondary training job status message.

    :param job_description: Returned response from DescribeTrainingJob call
    :type job_description: dict
    :param prev_description: Previous job description from DescribeTrainingJob call
    :type prev_description: dict

    :return: Job status string to be printed.
    """

    if job_description is None or job_description.get('SecondaryStatusTransitions') is None\
            or len(job_description.get('SecondaryStatusTransitions')) == 0:
        return ''

    prev_description_secondary_transitions = prev_description.get('SecondaryStatusTransitions')\
        if prev_description is not None else None
    prev_transitions_num = len(prev_description['SecondaryStatusTransitions'])\
        if prev_description_secondary_transitions is not None else 0
    current_transitions = job_description['SecondaryStatusTransitions']

    transitions_to_print = current_transitions[-1:] if len(current_transitions) == prev_transitions_num else \
        current_transitions[prev_transitions_num - len(current_transitions):]

    status_strs = []
    for transition in transitions_to_print:
        message = transition['StatusMessage']
        time_str = timezone.convert_to_utc(job_description['LastModifiedTime']).strftime('%Y-%m-%d %H:%M:%S')
        status_strs.append('{} {} - {}'.format(time_str, transition['Status'], message))

    return '\n'.join(status_strs)
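A hypothetical invocation of the helper above, with a fabricated DescribeTrainingJob-style payload (the field names follow the function; all values are made up):

import datetime

job_description = {
    'SecondaryStatusTransitions': [
        {'Status': 'Starting', 'StatusMessage': 'Launching requested ML instances'},
        {'Status': 'Training', 'StatusMessage': 'Downloading the training image'},
    ],
    'LastModifiedTime': datetime.datetime(2018, 7, 6, 5, 6, 3,
                                          tzinfo=datetime.timezone.utc),
}
print(secondary_training_status_message(job_description, prev_description=None))
# 2018-07-06 05:06:03 Starting - Launching requested ML instances
# 2018-07-06 05:06:03 Training - Downloading the training image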
Example #8
    def test_following_previous_schedule_daily_dag_CET_to_CEST(self):
        """
        Make sure DST transitions are properly observed
        """
        local_tz = pendulum.timezone('Europe/Zurich')
        start = local_tz.convert(datetime.datetime(2018, 3, 25, 2),
                                 dst_rule=pendulum.PRE_TRANSITION)

        utc = timezone.convert_to_utc(start)

        dag = DAG('tz_dag', start_date=start, schedule_interval='0 3 * * *')

        prev = dag.previous_schedule(utc)
        prev_local = local_tz.convert(prev)

        self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00")
        self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00")

        _next = dag.following_schedule(utc)
        next_local = local_tz.convert(_next)

        self.assertEqual(next_local.isoformat(), "2018-03-25T03:00:00+02:00")
        self.assertEqual(_next.isoformat(), "2018-03-25T01:00:00+00:00")

        prev = dag.previous_schedule(_next)
        prev_local = local_tz.convert(prev)

        self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00")
        self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00")
Example #9
    def test_following_previous_schedule(self):
        """
        Make sure DST transitions are properly observed
        """
        local_tz = pendulum.timezone('Europe/Zurich')
        start = local_tz.convert(datetime.datetime(2018, 10, 28, 2, 55),
                                 dst_rule=pendulum.PRE_TRANSITION)
        self.assertEqual(start.isoformat(), "2018-10-28T02:55:00+02:00",
                         "Pre-condition: start date is in DST")

        utc = timezone.convert_to_utc(start)

        dag = DAG('tz_dag', start_date=start, schedule_interval='*/5 * * * *')
        _next = dag.following_schedule(utc)
        next_local = local_tz.convert(_next)

        self.assertEqual(_next.isoformat(), "2018-10-28T01:00:00+00:00")
        self.assertEqual(next_local.isoformat(), "2018-10-28T02:00:00+01:00")

        prev = dag.previous_schedule(utc)
        prev_local = local_tz.convert(prev)

        self.assertEqual(prev_local.isoformat(), "2018-10-28T02:50:00+02:00")

        prev = dag.previous_schedule(_next)
        prev_local = local_tz.convert(prev)

        self.assertEqual(prev_local.isoformat(), "2018-10-28T02:55:00+02:00")
        self.assertEqual(prev, utc)
Example #10
def secondary_training_status_message(job_description, prev_description):
    """
    Returns a string containing the start time and the secondary training job status message.

    :param job_description: Returned response from DescribeTrainingJob call
    :type job_description: dict
    :param prev_description: Previous job description from DescribeTrainingJob call
    :type prev_description: dict

    :return: Job status string to be printed.
    """

    if job_description is None or job_description.get('SecondaryStatusTransitions') is None\
            or len(job_description.get('SecondaryStatusTransitions')) == 0:
        return ''

    prev_description_secondary_transitions = prev_description.get('SecondaryStatusTransitions')\
        if prev_description is not None else None
    prev_transitions_num = len(prev_description['SecondaryStatusTransitions'])\
        if prev_description_secondary_transitions is not None else 0
    current_transitions = job_description['SecondaryStatusTransitions']

    transitions_to_print = current_transitions[-1:] if len(current_transitions) == prev_transitions_num else \
        current_transitions[prev_transitions_num - len(current_transitions):]

    status_strs = []
    for transition in transitions_to_print:
        message = transition['StatusMessage']
        time_str = timezone.convert_to_utc(
            job_description['LastModifiedTime']).strftime('%Y-%m-%d %H:%M:%S')
        status_strs.append('{} {} - {}'.format(time_str, transition['Status'],
                                               message))

    return '\n'.join(status_strs)
Example #11
    def test_following_previous_schedule(self):
        """
        Make sure DST transitions are properly observed
        """
        local_tz = pendulum.timezone('Europe/Zurich')
        start = local_tz.convert(datetime.datetime(2018, 10, 28, 2, 55),
                                 dst_rule=pendulum.PRE_TRANSITION)
        self.assertEqual(start.isoformat(), "2018-10-28T02:55:00+02:00",
                         "Pre-condition: start date is in DST")

        utc = timezone.convert_to_utc(start)

        dag = DAG('tz_dag', start_date=start, schedule_interval='*/5 * * * *')
        _next = dag.following_schedule(utc)
        next_local = local_tz.convert(_next)

        self.assertEqual(_next.isoformat(), "2018-10-28T01:00:00+00:00")
        self.assertEqual(next_local.isoformat(), "2018-10-28T02:00:00+01:00")

        prev = dag.previous_schedule(utc)
        prev_local = local_tz.convert(prev)

        self.assertEqual(prev_local.isoformat(), "2018-10-28T02:50:00+02:00")

        prev = dag.previous_schedule(_next)
        prev_local = local_tz.convert(prev)

        self.assertEqual(prev_local.isoformat(), "2018-10-28T02:55:00+02:00")
        self.assertEqual(prev, utc)
Example #12
def secondary_training_status_message(job_description: Dict[str, List[Any]],
                                      prev_description: Optional[dict]) -> str:
    """
    Returns a string containing the start time and the secondary training job status message.

    :param job_description: Returned response from DescribeTrainingJob call
    :param prev_description: Previous job description from DescribeTrainingJob call

    :return: Job status string to be printed.
    """
    current_transitions = job_description.get('SecondaryStatusTransitions')
    if current_transitions is None or len(current_transitions) == 0:
        return ''

    prev_transitions_num = 0
    if prev_description is not None:
        if prev_description.get('SecondaryStatusTransitions') is not None:
            prev_transitions_num = len(
                prev_description['SecondaryStatusTransitions'])

    transitions_to_print = (current_transitions[-1:]
                            if len(current_transitions) == prev_transitions_num
                            else
                            current_transitions[prev_transitions_num -
                                                len(current_transitions):])

    status_strs = []
    for transition in transitions_to_print:
        message = transition['StatusMessage']
        time_str = timezone.convert_to_utc(
            cast(datetime, job_description['LastModifiedTime'])).strftime(
                '%Y-%m-%d %H:%M:%S')
        status_strs.append(f"{time_str} {transition['Status']} - {message}")

    return '\n'.join(status_strs)
Example #13
    def test_following_previous_schedule_daily_dag_CET_to_CEST(self):
        """
        Make sure DST transitions are properly observed
        """
        local_tz = pendulum.timezone('Europe/Zurich')
        start = local_tz.convert(datetime.datetime(2018, 3, 25, 2),
                                 dst_rule=pendulum.PRE_TRANSITION)

        utc = timezone.convert_to_utc(start)

        dag = DAG('tz_dag', start_date=start, schedule_interval='0 3 * * *')

        prev = dag.previous_schedule(utc)
        prev_local = local_tz.convert(prev)

        self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00")
        self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00")

        _next = dag.following_schedule(utc)
        next_local = local_tz.convert(_next)

        self.assertEqual(next_local.isoformat(), "2018-03-25T03:00:00+02:00")
        self.assertEqual(_next.isoformat(), "2018-03-25T01:00:00+00:00")

        prev = dag.previous_schedule(_next)
        prev_local = local_tz.convert(prev)

        self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00")
        self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00")
Example #14
    def __init__(self, *, target_time, **kwargs):
        super().__init__(**kwargs)
        self.target_time = target_time

        aware_time = timezone.coerce_datetime(
            datetime.datetime.combine(datetime.datetime.today(),
                                      self.target_time))

        self.target_datetime = timezone.convert_to_utc(aware_time)
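The sensor above combines today's date with the configured target time, coerces the result to an aware pendulum DateTime, and only then converts to UTC. A standalone sketch of the same two steps (the target time is arbitrary):

import datetime

from airflow.utils import timezone

target_time = datetime.time(16, 30)
aware_time = timezone.coerce_datetime(
    datetime.datetime.combine(datetime.datetime.today(), target_time))
target_datetime = timezone.convert_to_utc(aware_time)  # aware DateTime in UTC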
Example #15
    def __init__(self):
        super().__init__()
        self.dag = self.args['dag_id']
        self.task = self.args['task_id']
        self.execution_date = self.args['execution_date']
        self.execution_date = datetime.strptime(self.execution_date,
                                                "%Y-%m-%dT%H:%M:%S")
        self.execution_date = timezone.convert_to_utc(self.execution_date)
        self.path = "{}/{}/{}/{}/".format(LOG_PATH, self.dag, self.task,
                                          self.execution_date)
Example #16
File: sftp.py Project: subkanthi/airflow
    def poke(self, context: 'Context') -> bool:
        self.hook = SFTPHook(self.sftp_conn_id)
        self.log.info('Poking for %s', self.path)
        try:
            mod_time = self.hook.get_mod_time(self.path)
            self.log.info('Found File %s last modified: %s', str(self.path),
                          str(mod_time))
        except OSError as e:
            if e.errno != SFTP_NO_SUCH_FILE:
                raise e
            return False
        self.hook.close_conn()
        if self.newer_than:
            _mod_time = convert_to_utc(
                datetime.strptime(mod_time, '%Y%m%d%H%M%S'))
            _newer_than = convert_to_utc(self.newer_than)
            return _newer_than <= _mod_time
        else:
            return True
Example #17
    def __init__(self):
        super().__init__()
        self.dagbag = DagBag()
        self.dag = self.args['dag_id']
        self.task = self.args['task_id']
        self.execution_date = self.args['execution_date']
        self.action = self.args['action']
        self.options = self.args['option']
        self.execution_date = datetime.strptime(self.execution_date,
                                                "%Y-%m-%dT%H:%M:%S")
        self.execution_date = timezone.convert_to_utc(self.execution_date)
        assert timezone.is_localized(self.execution_date)
Example #18
    def _build_lifecycle_config(self, cluster_data):
        if self.idle_delete_ttl:
            cluster_data['config']['lifecycleConfig']['idleDeleteTtl'] = \
                "{}s".format(self.idle_delete_ttl)

        if self.auto_delete_time:
            utc_auto_delete_time = timezone.convert_to_utc(self.auto_delete_time)
            cluster_data['config']['lifecycleConfig']['autoDeleteTime'] = \
                utc_auto_delete_time.format('%Y-%m-%dT%H:%M:%S.%fZ', formatter='classic')
        elif self.auto_delete_ttl:
            cluster_data['config']['lifecycleConfig']['autoDeleteTtl'] = \
                "{}s".format(self.auto_delete_ttl)

        return cluster_data
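The .format(..., formatter='classic') call above is pendulum 1.x API, whose 'classic' formatter takes strftime-style tokens; the resulting autoDeleteTime string has the shape sketched below (made-up time, plain strftime used for illustration):

import datetime

utc_auto_delete_time = datetime.datetime(2018, 7, 6, 5, 6, 3,
                                         tzinfo=datetime.timezone.utc)
print(utc_auto_delete_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ'))
# 2018-07-06T05:06:03.000000Z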
Example #19
def test_list_dagrun_includes_conf(session, admin_client):
    data = {
        "state": "running",
        "dag_id": "example_bash_operator",
        "execution_date": "2018-07-06 05:06:03",
        "run_id": "test_list_dagrun_includes_conf",
        "conf": '{"include": "me"}',
    }
    admin_client.post('/dagrun/add', data=data, follow_redirects=True)
    dr = session.query(DagRun).one()

    expect_date = timezone.convert_to_utc(
        timezone.datetime(2018, 7, 6, 5, 6, 3))
    assert dr.execution_date == expect_date
    assert dr.conf == {"include": "me"}

    resp = admin_client.get('/dagrun/list', follow_redirects=True)
    check_content_in_response("{&#34;include&#34;: &#34;me&#34;}", resp)
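Since timezone.datetime builds an already-aware datetime in the default (UTC) timezone, convert_to_utc here only normalizes the value; it does not shift the wall time. A quick check, assuming default Airflow settings:

from airflow.utils import timezone

dt = timezone.datetime(2018, 7, 6, 5, 6, 3)
assert timezone.convert_to_utc(dt) == dt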
Example #20
    def test_timezone_awareness(self):
        NAIVE_DATETIME = DEFAULT_DATE.replace(tzinfo=None)

        # check ti without dag (just for bw compat)
        op_no_dag = DummyOperator(task_id='op_no_dag')
        ti = TI(task=op_no_dag, execution_date=NAIVE_DATETIME)

        self.assertEqual(ti.execution_date, DEFAULT_DATE)

        # check with dag without localized execution_date
        dag = DAG('dag', start_date=DEFAULT_DATE)
        op1 = DummyOperator(task_id='op_1')
        dag.add_task(op1)
        ti = TI(task=op1, execution_date=NAIVE_DATETIME)

        self.assertEqual(ti.execution_date, DEFAULT_DATE)

        # with dag and localized execution_date
        tz = pendulum.timezone("Europe/Amsterdam")
        execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tz)
        utc_date = timezone.convert_to_utc(execution_date)
        ti = TI(task=op1, execution_date=execution_date)
        self.assertEqual(ti.execution_date, utc_date)
Example #21
    def test_timezone_awareness(self):
        NAIVE_DATETIME = DEFAULT_DATE.replace(tzinfo=None)

        # check ti without dag (just for bw compat)
        op_no_dag = DummyOperator(task_id='op_no_dag')
        ti = TI(task=op_no_dag, execution_date=NAIVE_DATETIME)

        self.assertEqual(ti.execution_date, DEFAULT_DATE)

        # check with dag without localized execution_date
        dag = DAG('dag', start_date=DEFAULT_DATE)
        op1 = DummyOperator(task_id='op_1')
        dag.add_task(op1)
        ti = TI(task=op1, execution_date=NAIVE_DATETIME)

        self.assertEqual(ti.execution_date, DEFAULT_DATE)

        # with dag and localized execution_date
        tz = pendulum.timezone("Europe/Amsterdam")
        execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tz)
        utc_date = timezone.convert_to_utc(execution_date)
        ti = TI(task=op1, execution_date=execution_date)
        self.assertEqual(ti.execution_date, utc_date)
Example #22
    def __init__(
            self,
            task_id: str,
            owner: str = conf.get('operators', 'DEFAULT_OWNER'),
            email: Optional[Union[str, Iterable[str]]] = None,
            email_on_retry: bool = True,
            email_on_failure: bool = True,
            retries: Optional[int] = conf.getint('core',
                                                 'default_task_retries',
                                                 fallback=0),
            retry_delay: timedelta = timedelta(seconds=300),
            retry_exponential_backoff: bool = False,
            max_retry_delay: Optional[datetime] = None,
            start_date: Optional[datetime] = None,
            end_date: Optional[datetime] = None,
            depends_on_past: bool = False,
            wait_for_downstream: bool = False,
            dag=None,
            params: Optional[Dict] = None,
            default_args: Optional[Dict] = None,  # pylint: disable=unused-argument
            priority_weight: int = 1,
            weight_rule: str = WeightRule.DOWNSTREAM,
            queue: str = conf.get('celery', 'default_queue'),
            pool: str = Pool.DEFAULT_POOL_NAME,
            sla: Optional[timedelta] = None,
            execution_timeout: Optional[timedelta] = None,
            on_execute_callback: Optional[Callable] = None,
            on_failure_callback: Optional[Callable] = None,
            on_success_callback: Optional[Callable] = None,
            on_retry_callback: Optional[Callable] = None,
            trigger_rule: str = TriggerRule.ALL_SUCCESS,
            resources: Optional[Dict] = None,
            run_as_user: Optional[str] = None,
            task_concurrency: Optional[int] = None,
            executor_config: Optional[Dict] = None,
            do_xcom_push: bool = True,
            inlets: Optional[Any] = None,
            outlets: Optional[Any] = None,
            *args,
            **kwargs):
        from airflow.models.dag import DagContext
        super().__init__()
        if args or kwargs:
            if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
                raise AirflowException(
                    "Invalid arguments were passed to {c} (task_id: {t}). Invalid "
                    "arguments were:\n*args: {a}\n**kwargs: {k}".format(
                        c=self.__class__.__name__, a=args, k=kwargs,
                        t=task_id), )
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'future. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(c=self.__class__.__name__,
                                                     a=args,
                                                     k=kwargs,
                                                     t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3)
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_triggers=TriggerRule.all_triggers(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_execute_callback = on_execute_callback
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback

        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            # noinspection PyTypeChecker
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_weight_rules=WeightRule.all_weight_rules,
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=weight_rule))
        self.weight_rule = weight_rule
        self.resources: Optional[Resources] = Resources(
            **resources) if resources else None
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids: Set[str] = set()
        self._downstream_task_ids: Set[str] = set()
        self._dag = None

        self.dag = dag or DagContext.get_current_dag()

        # subdag parameter is only set for SubDagOperator.
        # Setting it to None by default as other Operators do not have that field
        from airflow.models.dag import DAG
        self.subdag: Optional[DAG] = None

        self._log = logging.getLogger("airflow.task.operators")

        # Lineage
        self.inlets: List = []
        self.outlets: List = []

        self._inlets: List = []
        self._outlets: List = []

        if inlets:
            self._inlets = inlets if isinstance(inlets, list) else [
                inlets,
            ]

        if outlets:
            self._outlets = outlets if isinstance(outlets, list) else [
                outlets,
            ]
Example #23
    def __init__(
        self,
        task_id,  # type: str
        owner=configuration.conf.get('operators', 'DEFAULT_OWNER'),  # type: str
        email=None,  # type: Optional[str]
        email_on_retry=True,  # type: bool
        email_on_failure=True,  # type: bool
        retries=0,  # type: int
        retry_delay=timedelta(seconds=300),  # type: timedelta
        retry_exponential_backoff=False,  # type: bool
        max_retry_delay=None,  # type: Optional[datetime]
        start_date=None,  # type: Optional[datetime]
        end_date=None,  # type: Optional[datetime]
        schedule_interval=None,  # not hooked as of now
        depends_on_past=False,  # type: bool
        wait_for_downstream=False,  # type: bool
        dag=None,  # type: Optional[DAG]
        params=None,  # type: Optional[Dict]
        default_args=None,  # type: Optional[Dict]
        priority_weight=1,  # type: int
        weight_rule=WeightRule.DOWNSTREAM,  # type: str
        queue=configuration.conf.get('celery', 'default_queue'),  # type: str
        pool=None,  # type: Optional[str]
        sla=None,  # type: Optional[timedelta]
        execution_timeout=None,  # type: Optional[timedelta]
        on_failure_callback=None,  # type: Optional[Callable]
        on_success_callback=None,  # type: Optional[Callable]
        on_retry_callback=None,  # type: Optional[Callable]
        trigger_rule=TriggerRule.ALL_SUCCESS,  # type: str
        resources=None,  # type: Optional[Dict]
        run_as_user=None,  # type: Optional[str]
        task_concurrency=None,  # type: Optional[int]
        executor_config=None,  # type: Optional[Dict]
        do_xcom_push=True,  # type: bool
        inlets=None,  # type: Optional[Dict]
        outlets=None,  # type: Optional[Dict]
        *args,
        **kwargs
    ):

        if args or kwargs:
            # TODO remove *args and **kwargs in Airflow 2.0
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'Airflow 2.0. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(
                    c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3
            )
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'."
                .format(all_triggers=TriggerRule.all_triggers(),
                        d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        if schedule_interval:
            self.log.warning(
                "schedule_interval is used for %s, though it has "
                "been deprecated as a task parameter, you need to "
                "specify it as a DAG parameter instead",
                self
            )
        self._schedule_interval = schedule_interval
        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback
        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'."
                .format(all_weight_rules=WeightRule.all_weight_rules,
                        d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
        self.weight_rule = weight_rule

        self.resources = Resources(**(resources or {}))
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids = set()  # type: Set[str]
        self._downstream_task_ids = set()  # type: Set[str]

        if not dag and settings.CONTEXT_MANAGER_DAG:
            dag = settings.CONTEXT_MANAGER_DAG
        if dag:
            self.dag = dag

        self._log = logging.getLogger("airflow.task.operators")

        # lineage
        self.inlets = []  # type: Iterable[DataSet]
        self.outlets = []  # type: Iterable[DataSet]
        self.lineage_data = None

        self._inlets = {
            "auto": False,
            "task_ids": [],
            "datasets": [],
        }

        self._outlets = {
            "datasets": [],
        }  # type: Dict

        if inlets:
            self._inlets.update(inlets)

        if outlets:
            self._outlets.update(outlets)

        self._comps = {
            'task_id',
            'dag_id',
            'owner',
            'email',
            'email_on_retry',
            'retry_delay',
            'retry_exponential_backoff',
            'max_retry_delay',
            'start_date',
            'schedule_interval',
            'depends_on_past',
            'wait_for_downstream',
            'priority_weight',
            'sla',
            'execution_timeout',
            'on_failure_callback',
            'on_success_callback',
            'on_retry_callback',
            'do_xcom_push',
        }
Example #24
    def _build_cluster_data(self):
        if self.zone:
            master_type_uri = \
                "https://www.googleapis.com/compute/v1/projects/{}/zones/{}/machineTypes/{}"\
                .format(self.project_id, self.zone, self.master_machine_type)
            worker_type_uri = \
                "https://www.googleapis.com/compute/v1/projects/{}/zones/{}/machineTypes/{}"\
                .format(self.project_id, self.zone, self.worker_machine_type)
        else:
            master_type_uri = self.master_machine_type
            worker_type_uri = self.worker_machine_type

        cluster_data = {
            'projectId': self.project_id,
            'clusterName': self.cluster_name,
            'config': {
                'gceClusterConfig': {},
                'masterConfig': {
                    'numInstances': 1,
                    'machineTypeUri': master_type_uri,
                    'diskConfig': {
                        'bootDiskType': self.master_disk_type,
                        'bootDiskSizeGb': self.master_disk_size
                    }
                },
                'workerConfig': {
                    'numInstances': self.num_workers,
                    'machineTypeUri': worker_type_uri,
                    'diskConfig': {
                        'bootDiskType': self.worker_disk_type,
                        'bootDiskSizeGb': self.worker_disk_size
                    }
                },
                'secondaryWorkerConfig': {},
                'softwareConfig': {},
                'lifecycleConfig': {},
                'encryptionConfig': {},
                'autoscalingConfig': {},
            }
        }
        if self.num_preemptible_workers > 0:
            cluster_data['config']['secondaryWorkerConfig'] = {
                'numInstances': self.num_preemptible_workers,
                'machineTypeUri': worker_type_uri,
                'diskConfig': {
                    'bootDiskType': self.worker_disk_type,
                    'bootDiskSizeGb': self.worker_disk_size
                },
                'isPreemptible': True
            }

        cluster_data['labels'] = self.labels if self.labels else {}
        # Dataproc labels must conform to the following regex:
        # [a-z]([-a-z0-9]*[a-z0-9])? (current airflow version string follows
        # semantic versioning spec: x.y.z).
        cluster_data['labels'].update({
            'airflow-version':
            'v' + version.replace('.', '-').replace('+', '-')
        })
        if self.storage_bucket:
            cluster_data['config']['configBucket'] = self.storage_bucket
        if self.zone:
            zone_uri = \
                'https://www.googleapis.com/compute/v1/projects/{}/zones/{}'.format(
                    self.project_id, self.zone
                )
            cluster_data['config']['gceClusterConfig']['zoneUri'] = zone_uri
        if self.metadata:
            cluster_data['config']['gceClusterConfig'][
                'metadata'] = self.metadata
        if self.network_uri:
            cluster_data['config']['gceClusterConfig'][
                'networkUri'] = self.network_uri
        if self.subnetwork_uri:
            cluster_data['config']['gceClusterConfig']['subnetworkUri'] = \
                self.subnetwork_uri
        if self.internal_ip_only:
            if not self.subnetwork_uri:
                raise AirflowException("Set internal_ip_only to true only when"
                                       " you pass a subnetwork_uri.")
            cluster_data['config']['gceClusterConfig']['internalIpOnly'] = True
        if self.tags:
            cluster_data['config']['gceClusterConfig']['tags'] = self.tags
        if self.image_version:
            cluster_data['config']['softwareConfig'][
                'imageVersion'] = self.image_version
        elif self.custom_image:
            custom_image_url = 'https://www.googleapis.com/compute/beta/projects/' \
                               '{}/global/images/{}'.format(self.project_id,
                                                            self.custom_image)
            cluster_data['config']['masterConfig'][
                'imageUri'] = custom_image_url
            if not self.single_node:
                cluster_data['config']['workerConfig'][
                    'imageUri'] = custom_image_url

        if self.single_node:
            self.properties["dataproc:dataproc.allow.zero.workers"] = "true"

        if self.properties:
            cluster_data['config']['softwareConfig'][
                'properties'] = self.properties
        if self.idle_delete_ttl:
            cluster_data['config']['lifecycleConfig']['idleDeleteTtl'] = \
                "{}s".format(self.idle_delete_ttl)
        if self.auto_delete_time:
            utc_auto_delete_time = timezone.convert_to_utc(
                self.auto_delete_time)
            cluster_data['config']['lifecycleConfig']['autoDeleteTime'] = \
                utc_auto_delete_time.format('%Y-%m-%dT%H:%M:%S.%fZ', formatter='classic')
        elif self.auto_delete_ttl:
            cluster_data['config']['lifecycleConfig']['autoDeleteTtl'] = \
                "{}s".format(self.auto_delete_ttl)
        if self.init_actions_uris:
            init_actions_dict = [{
                'executableFile':
                uri,
                'executionTimeout':
                self._get_init_action_timeout()
            } for uri in self.init_actions_uris]
            cluster_data['config']['initializationActions'] = init_actions_dict
        if self.service_account:
            cluster_data['config']['gceClusterConfig']['serviceAccount'] =\
                self.service_account
        if self.service_account_scopes:
            cluster_data['config']['gceClusterConfig']['serviceAccountScopes'] =\
                self.service_account_scopes
        if self.customer_managed_key:
            cluster_data['config']['encryptionConfig'] =\
                {'gcePdKmsKeyName': self.customer_managed_key}
        if self.autoscaling_policy:
            cluster_data['config']['autoscalingConfig'] = {
                'policyUri': self.autoscaling_policy
            }

        return cluster_data
Example #25
    def send_lineage(operator, inlets, outlets, context):  # pylint:disable=signature-differs
        client = Atlas(_host,
                       port=_port,
                       username=_username,
                       password=_password)
        try:
            client.typedefs.create(data=operator_typedef)
        except HttpError:
            client.typedefs.update(data=operator_typedef)

        _execution_date = convert_to_utc(context['ti'].execution_date)
        _start_date = convert_to_utc(context['ti'].start_date)
        _end_date = convert_to_utc(context['ti'].end_date)

        inlet_list = []
        if inlets:
            for entity in inlets:
                if entity is None:
                    continue

                entity.set_context(context)
                client.entity_post.create(data={"entity": entity.as_dict()})
                inlet_list.append({
                    "typeName": entity.type_name,
                    "uniqueAttributes": {
                        "qualifiedName": entity.qualified_name
                    }
                })

        outlet_list = []
        if outlets:
            for entity in outlets:
                if not entity:
                    continue

                entity.set_context(context)
                client.entity_post.create(data={"entity": entity.as_dict()})
                outlet_list.append({
                    "typeName": entity.type_name,
                    "uniqueAttributes": {
                        "qualifiedName": entity.qualified_name
                    }
                })

        operator_name = operator.__class__.__name__
        name = "{} {} ({})".format(operator.dag_id, operator.task_id,
                                   operator_name)
        qualified_name = "{}_{}_{}@{}".format(operator.dag_id,
                                              operator.task_id,
                                              _execution_date, operator_name)

        data = {
            "dag_id": operator.dag_id,
            "task_id": operator.task_id,
            "execution_date":
            _execution_date.strftime(SERIALIZED_DATE_FORMAT_STR),
            "name": name,
            "inputs": inlet_list,
            "outputs": outlet_list,
            "command": operator.lineage_data,
        }

        if _start_date:
            data["start_date"] = _start_date.strftime(
                SERIALIZED_DATE_FORMAT_STR)
        if _end_date:
            data["end_date"] = _end_date.strftime(SERIALIZED_DATE_FORMAT_STR)

        process = datasets.Operator(qualified_name=qualified_name, data=data)
        client.entity_post.create(data={"entity": process.as_dict()})
Example #26
    def _build_cluster_data(self):
        zone_uri = \
            'https://www.googleapis.com/compute/v1/projects/{}/zones/{}'.format(
                self.project_id, self.zone
            )
        master_type_uri = \
            "https://www.googleapis.com/compute/v1/projects/{}/zones/{}/machineTypes/{}".format(
                self.project_id, self.zone, self.master_machine_type
            )
        worker_type_uri = \
            "https://www.googleapis.com/compute/v1/projects/{}/zones/{}/machineTypes/{}".format(
                self.project_id, self.zone, self.worker_machine_type
            )
        cluster_data = {
            'projectId': self.project_id,
            'clusterName': self.cluster_name,
            'config': {
                'gceClusterConfig': {
                    'zoneUri': zone_uri
                },
                'masterConfig': {
                    'numInstances': 1,
                    'machineTypeUri': master_type_uri,
                    'diskConfig': {
                        'bootDiskSizeGb': self.master_disk_size
                    }
                },
                'workerConfig': {
                    'numInstances': self.num_workers,
                    'machineTypeUri': worker_type_uri,
                    'diskConfig': {
                        'bootDiskSizeGb': self.worker_disk_size
                    }
                },
                'secondaryWorkerConfig': {},
                'softwareConfig': {},
                'lifecycleConfig': {}
            }
        }
        if self.num_preemptible_workers > 0:
            cluster_data['config']['secondaryWorkerConfig'] = {
                'numInstances': self.num_preemptible_workers,
                'machineTypeUri': worker_type_uri,
                'diskConfig': {
                    'bootDiskSizeGb': self.worker_disk_size
                },
                'isPreemptible': True
            }

        cluster_data['labels'] = self.labels if self.labels else {}
        # Dataproc labels must conform to the following regex:
        # [a-z]([-a-z0-9]*[a-z0-9])? (current airflow version string follows
        # semantic versioning spec: x.y.z).
        cluster_data['labels'].update({
            'airflow-version':
            'v' + version.replace('.', '-').replace('+', '-')
        })
        if self.storage_bucket:
            cluster_data['config']['configBucket'] = self.storage_bucket
        if self.metadata:
            cluster_data['config']['gceClusterConfig'][
                'metadata'] = self.metadata
        if self.network_uri:
            cluster_data['config']['gceClusterConfig'][
                'networkUri'] = self.network_uri
        if self.subnetwork_uri:
            cluster_data['config']['gceClusterConfig'][
                'subnetworkUri'] = self.subnetwork_uri
        if self.tags:
            cluster_data['config']['gceClusterConfig']['tags'] = self.tags
        if self.image_version:
            cluster_data['config']['softwareConfig'][
                'imageVersion'] = self.image_version
        if self.properties:
            cluster_data['config']['softwareConfig'][
                'properties'] = self.properties
        if self.idle_delete_ttl:
            cluster_data['config']['lifecycleConfig']['idleDeleteTtl'] = \
                "{}s".format(self.idle_delete_ttl)
        if self.auto_delete_time:
            utc_auto_delete_time = timezone.convert_to_utc(
                self.auto_delete_time)
            cluster_data['config']['lifecycleConfig']['autoDeleteTime'] = \
                utc_auto_delete_time.format('%Y-%m-%dT%H:%M:%S.%fZ', formatter='classic')
        elif self.auto_delete_ttl:
            cluster_data['config']['lifecycleConfig']['autoDeleteTtl'] = \
                "{}s".format(self.auto_delete_ttl)
        if self.init_actions_uris:
            init_actions_dict = [{
                'executableFile':
                uri,
                'executionTimeout':
                self._get_init_action_timeout()
            } for uri in self.init_actions_uris]
            cluster_data['config']['initializationActions'] = init_actions_dict
        if self.service_account:
            cluster_data['config']['gceClusterConfig']['serviceAccount'] =\
                    self.service_account
        if self.service_account_scopes:
            cluster_data['config']['gceClusterConfig']['serviceAccountScopes'] =\
                    self.service_account_scopes
        return cluster_data
Example #27
    def __init__(
            self,
            task_id,  # type: str
            owner=configuration.conf.get('operators',
                                         'DEFAULT_OWNER'),  # type: str
            email=None,  # type: Optional[str]
            email_on_retry=True,  # type: bool
            email_on_failure=True,  # type: bool
            retries=0,  # type: int
            retry_delay=timedelta(seconds=300),  # type: timedelta
            retry_exponential_backoff=False,  # type: bool
            max_retry_delay=None,  # type: Optional[datetime]
            start_date=None,  # type: Optional[datetime]
            end_date=None,  # type: Optional[datetime]
            schedule_interval=None,  # not hooked as of now
            depends_on_past=False,  # type: bool
            wait_for_downstream=False,  # type: bool
            dag=None,  # type: Optional[DAG]
            params=None,  # type: Optional[Dict]
            default_args=None,  # type: Optional[Dict]
            priority_weight=1,  # type: int
            weight_rule=WeightRule.DOWNSTREAM,  # type: str
            queue=configuration.conf.get('celery',
                                         'default_queue'),  # type: str
            pool=None,  # type: Optional[str]
            sla=None,  # type: Optional[timedelta]
            execution_timeout=None,  # type: Optional[timedelta]
            on_failure_callback=None,  # type: Optional[Callable]
            on_success_callback=None,  # type: Optional[Callable]
            on_retry_callback=None,  # type: Optional[Callable]
            trigger_rule=TriggerRule.ALL_SUCCESS,  # type: str
            resources=None,  # type: Optional[Dict]
            run_as_user=None,  # type: Optional[str]
            task_concurrency=None,  # type: Optional[int]
            executor_config=None,  # type: Optional[Dict]
            do_xcom_push=True,  # type: bool
            inlets=None,  # type: Optional[Dict]
            outlets=None,  # type: Optional[Dict]
            *args,
            **kwargs):

        if args or kwargs:
            # TODO remove *args and **kwargs in Airflow 2.0
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'Airflow 2.0. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(c=self.__class__.__name__,
                                                     a=args,
                                                     k=kwargs,
                                                     t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3)
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_triggers=TriggerRule.all_triggers(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        if schedule_interval:
            self.log.warning(
                "schedule_interval is used for %s, though it has "
                "been deprecated as a task parameter, you need to "
                "specify it as a DAG parameter instead", self)
        self._schedule_interval = schedule_interval
        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback
        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_weight_rules=WeightRule.all_weight_rules,
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=weight_rule))
        self.weight_rule = weight_rule

        self.resources = Resources(**(resources or {}))
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids = set()  # type: Set[str]
        self._downstream_task_ids = set()  # type: Set[str]

        if not dag and settings.CONTEXT_MANAGER_DAG:
            dag = settings.CONTEXT_MANAGER_DAG
        if dag:
            self.dag = dag

        self._log = logging.getLogger("airflow.task.operators")

        # lineage
        self.inlets = []  # type: List[DataSet]
        self.outlets = []  # type: List[DataSet]
        self.lineage_data = None

        self._inlets = {
            "auto": False,
            "task_ids": [],
            "datasets": [],
        }

        self._outlets = {
            "datasets": [],
        }  # type: Dict

        if inlets:
            self._inlets.update(inlets)

        if outlets:
            self._outlets.update(outlets)

        self._comps = {
            'task_id',
            'dag_id',
            'owner',
            'email',
            'email_on_retry',
            'retry_delay',
            'retry_exponential_backoff',
            'max_retry_delay',
            'start_date',
            'schedule_interval',
            'depends_on_past',
            'wait_for_downstream',
            'priority_weight',
            'sla',
            'execution_timeout',
            'on_failure_callback',
            'on_success_callback',
            'on_retry_callback',
            'do_xcom_push',
        }
Example #28
    def create_lineage_meta(operator, inlets, outlets, context):
        _execution_date = convert_to_utc(context['ti'].execution_date)  # noqa
        _start_date = convert_to_utc(context['ti'].start_date)
        _end_date = convert_to_utc(context['ti'].end_date)

        # Creating input entities
        inlet_list = []  # type: List[dict]
        inlet_ref_list = []  # type: List[dict]
        if inlets:
            for entity in inlets:
                if entity is None:
                    continue

                entity.set_context(context)
                try:
                    entity_dict = entity.as_nested_dict()
                except Exception as e:  # noqa: F841
                    entity_dict = entity.as_dict()

                inlet_list.append(entity_dict)
                inlet_ref_list.append({"typeName": entity.type_name,
                                       "uniqueAttributes": {
                                           "qualifiedName":
                                               entity.qualified_name
                                               }
                                       })

        # Creating output entities
        outlet_list = []  # type: List[dict]
        outlet_ref_list = []  # type: List[dict]
        if outlets:
            for entity in outlets:
                if not entity:
                    continue

                entity.set_context(context)
                try:
                    entity_dict = entity.as_nested_dict()
                except Exception as e:  # noqa: F841
                    entity_dict = entity.as_dict()

                log.info("Outlets: {}".format(entity_dict))

                outlet_list.append(entity_dict)
                outlet_ref_list.append({"typeName": entity.type_name,
                                        "uniqueAttributes": {
                                            "qualifiedName":
                                                entity.qualified_name
                                        }})

        # Creating dag and operator entities
        dag_op_list = []  # type: List[dict]

        # Creating source meta
        airflow_host = get_hostname()

        # qualified name format ':AIRFLOW//:{}'.format(airflow_host)
        data = {
            "qualifiedName": ':AIRFLOW//:{}'.format(airflow_host),
            "name": airflow_host,
            "host": get_host_ip_address(),
            "port": conf.get("webserver", "web_server_port"),
            "type": "airflow",
            "sourceType": "AIRFLOW"
        }

        airflow_source = Source(data=data)
        log.info("Airflow Source: {}".format(airflow_source.as_dict()))
        dag_op_list.append(airflow_source.as_dict())

        # Creating dag meta
        qualified_name = "{}/{}".format(airflow_source.qualified_name,
                                        operator.dag_id)
        data = {
            "name": operator.dag_id,
            "source": airflow_source.as_reference(),
            "extra": Backend._get_dag_meta(context),
            "schedule": [{'cron': str(context['dag'].schedule_interval)}],
            "jobCreatedAt": None,
            "jobUpdatedAt": convert_to_utc(
                context['dag'].last_loaded).strftime(
                    SERIALIZED_DATE_FORMAT_STR),
            "sourceType": "AIRFLOW"
        }

        if context['dag']._description:
            data["description"] = context['dag'].description

        dag = AtlanJob(qualified_name=qualified_name, data=data)
        log.info("Dag: {}".format(dag.as_dict()))
        dag_op_list.append(dag.as_dict())

        # Creating dag run meta
        qualified_name = "{}/{}".format(dag.qualified_name,
                                        context['dag_run'].run_id)

        assets = copy.deepcopy(inlet_ref_list)
        assets.extend(outlet_ref_list)

        data = {
            "name": operator.dag_id,
            "runId": context['dag_run'].run_id,
            "job": dag.as_reference(),
            "source": airflow_source.as_reference(),
            "runStatus": _DAG_RUN_STATUS_MAP.get(
                context['dag_run']._state, None),
            "extra": Backend._get_dag_run_meta(context),
            "sourceType": "AIRFLOW"
        }

        if context['dag_run'].external_trigger:
            data.update({'runType': 'manual'})
        else:
            data.update({'runType': 'scheduled'})

        if assets:
            data.update({"assets": assets})

        if context['dag_run'].start_date:
            data['runStartedAt'] = convert_to_utc(
                context['dag_run'].start_date).strftime(
                    SERIALIZED_DATE_FORMAT_STR)
        if context['dag_run'].end_date:
            data['runEndedAt'] = convert_to_utc(
                context['dag_run'].end_date).strftime(
                    SERIALIZED_DATE_FORMAT_STR)

        dag_run = AtlanJobRun(qualified_name=qualified_name, data=data)
        log.info("Dag Run: {}".format(dag_run.as_dict()))

        dag_op_list.append(dag_run.as_dict())

        # Creating task meta
        operator_name = operator.__class__.__name__
        qualified_name = '{}/{}'.format(dag_run.qualified_name,
                                        operator.task_id)

        data = {
            "name": operator.task_id,
            "description": operator_name,
            "inputs": inlet_ref_list,
            "outputs": outlet_ref_list,
            "job_run": dag_run.as_reference(),
            "extra": Backend._get_task_meta(context),
            "processStatus": _TASK_RUN_STATUS_MAP.get(
                context['task_instance'].state, None),
            "sourceType": "AIRFLOW"
        }

        if _start_date:
            data["processStartedAt"] = _start_date.strftime(
                SERIALIZED_DATE_FORMAT_STR)
        if _end_date:
            data["processEndedAt"] = _end_date.strftime(
                SERIALIZED_DATE_FORMAT_STR)

        process = AtlanProcess(qualified_name=qualified_name, data=data)
        log.info("Process: {}".format(process.as_dict()))

        dag_op_list.append(process.as_dict())

        return inlet_list, outlet_list, dag_op_list
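`create_lineage_meta` only needs a small contract from each inlet/outlet entity: `set_context()`, `as_dict()` / `as_nested_dict()`, `type_name`, and `qualified_name`. A minimal stand-in satisfying that contract might look like this (the class name and field layout are assumptions for illustration):

class StubEntity:
    """Bare-bones entity honoring the interface create_lineage_meta expects."""
    type_name = "dataset"

    def __init__(self, qualified_name):
        self.qualified_name = qualified_name
        self.context = None

    def set_context(self, context):
        self.context = context

    def as_dict(self):
        return {"typeName": self.type_name,
                "attributes": {"qualifiedName": self.qualified_name}}

    # The caller falls back to as_dict() if as_nested_dict() raises.
    as_nested_dict = as_dict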
Example #29
    def _get_prev(self, current: DateTime) -> DateTime:
        return convert_to_utc(current - self._delta)
Example #30
    def _get_next(self, current: DateTime) -> DateTime:
        return convert_to_utc(current + self._delta)
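Examples #29 and #30 are the two halves of a fixed-delta timetable: stepping backwards and forwards by `self._delta` and normalizing the result to UTC. A rough standalone sketch of the same idea, assuming the `DeltaStepper` wrapper below (the class is hypothetical; `convert_to_utc` is the real Airflow helper):

from datetime import timedelta

import pendulum
from airflow.utils.timezone import convert_to_utc


class DeltaStepper:
    """Steps a schedule backwards/forwards by a fixed timedelta, in UTC."""

    def __init__(self, delta: timedelta):
        self._delta = delta

    def _get_prev(self, current: pendulum.DateTime) -> pendulum.DateTime:
        return convert_to_utc(current - self._delta)

    def _get_next(self, current: pendulum.DateTime) -> pendulum.DateTime:
        return convert_to_utc(current + self._delta)


stepper = DeltaStepper(timedelta(days=1))
# stepper._get_next(2024-01-01T00:00+00:00) -> 2024-01-02T00:00+00:00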
Example #31
    def send_lineage(operator, inlets, outlets, context):
        client = Atlas(_host, port=_port, username=_username, password=_password)
        try:
            client.typedefs.create(data=operator_typedef)
        except HttpError:
            client.typedefs.update(data=operator_typedef)

        _execution_date = convert_to_utc(context['ti'].execution_date)
        _start_date = convert_to_utc(context['ti'].start_date)
        _end_date = convert_to_utc(context['ti'].end_date)

        inlet_list = []
        if inlets:
            for entity in inlets:
                if entity is None:
                    continue

                entity.set_context(context)
                client.entity_post.create(data={"entity": entity.as_dict()})
                inlet_list.append({"typeName": entity.type_name,
                                   "uniqueAttributes": {
                                       "qualifiedName": entity.qualified_name
                                   }})

        outlet_list = []
        if outlets:
            for entity in outlets:
                if not entity:
                    continue

                entity.set_context(context)
                client.entity_post.create(data={"entity": entity.as_dict()})
                outlet_list.append({"typeName": entity.type_name,
                                    "uniqueAttributes": {
                                        "qualifiedName": entity.qualified_name
                                    }})

        operator_name = operator.__class__.__name__
        name = "{} {} ({})".format(operator.dag_id, operator.task_id, operator_name)
        qualified_name = "{}_{}_{}@{}".format(operator.dag_id,
                                              operator.task_id,
                                              _execution_date,
                                              operator_name)

        data = {
            "dag_id": operator.dag_id,
            "task_id": operator.task_id,
            "execution_date": _execution_date.strftime(SERIALIZED_DATE_FORMAT_STR),
            "name": name,
            "inputs": inlet_list,
            "outputs": outlet_list,
            "command": operator.lineage_data,
        }

        if _start_date:
            data["start_date"] = _start_date.strftime(SERIALIZED_DATE_FORMAT_STR)
        if _end_date:
            data["end_date"] = _end_date.strftime(SERIALIZED_DATE_FORMAT_STR)

        process = datasets.Operator(qualified_name=qualified_name, data=data)
        client.entity_post.create(data={"entity": process.as_dict()})
Example #32
    def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
        ensure_xcomarg_return_value(expand_input.value)

        task_kwargs = self.kwargs.copy()
        dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
        task_group = task_kwargs.pop(
            "task_group", None) or TaskGroupContext.get_current_task_group(dag)

        partial_kwargs, default_params = get_merged_defaults(
            dag=dag,
            task_group=task_group,
            task_params=task_kwargs.pop("params", None),
            task_default_args=task_kwargs.pop("default_args", None),
        )
        partial_kwargs.update(task_kwargs)

        task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag,
                                     task_group)
        params = partial_kwargs.pop("params", None) or default_params

        # Logic here should be kept in sync with BaseOperatorMeta.partial().
        if "task_concurrency" in partial_kwargs:
            raise TypeError("unexpected argument: task_concurrency")
        if partial_kwargs.get("wait_for_downstream"):
            partial_kwargs["depends_on_past"] = True
        start_date = timezone.convert_to_utc(
            partial_kwargs.pop("start_date", None))
        end_date = timezone.convert_to_utc(partial_kwargs.pop(
            "end_date", None))
        if partial_kwargs.get("pool") is None:
            partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
        partial_kwargs["retries"] = parse_retries(
            partial_kwargs.get("retries", DEFAULT_RETRIES))
        partial_kwargs["retry_delay"] = coerce_timedelta(
            partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
            key="retry_delay",
        )
        max_retry_delay = partial_kwargs.get("max_retry_delay")
        partial_kwargs["max_retry_delay"] = (
            max_retry_delay if max_retry_delay is None else coerce_timedelta(
                max_retry_delay, key="max_retry_delay"))
        partial_kwargs["resources"] = coerce_resources(
            partial_kwargs.get("resources"))
        partial_kwargs.setdefault("executor_config", {})
        partial_kwargs.setdefault("op_args", [])
        partial_kwargs.setdefault("op_kwargs", {})

        # Mypy does not work well with a subclassed attrs class :(
        _MappedOperator = cast(Any, DecoratedMappedOperator)

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        operator = _MappedOperator(
            operator_class=self.operator_class,
            # Don't use this; mapped values go to op_kwargs_expand_input.
            expand_input=EXPAND_INPUT_EMPTY,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=False,
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            multiple_outputs=self.multiple_outputs,
            python_callable=self.function,
            op_kwargs_expand_input=expand_input,
            disallow_kwargs_override=strict,
            # Different from classic operators, kwargs passed to a taskflow
            # task's expand() contribute to the op_kwargs operator argument, not
            # the operator arguments themselves, and should expand against it.
            expand_input_attr="op_kwargs_expand_input",
        )
        return XComArg(operator=operator)
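`_expand` is the machinery behind the TaskFlow `expand()` call: the mapped kwargs become `op_kwargs_expand_input`, and the caller gets back an `XComArg` to chain further. A usage sketch, assuming a recent Airflow 2.x with dynamic task mapping (DAG and task names are made up):

import pendulum
from airflow.decorators import dag, task


@dag(start_date=pendulum.datetime(2024, 1, 1, tz="UTC"), schedule=None)
def mapping_demo():
    @task
    def add_one(x: int) -> int:
        return x + 1

    @task
    def total(values) -> int:
        return sum(values)

    # expand() returns an XComArg; each list element becomes one mapped
    # task instance whose op_kwargs contain that x.
    total(add_one.expand(x=[1, 2, 3]))


mapping_demo()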
Example #33
    def expand(self, **map_kwargs: "Mappable") -> XComArg:
        self._validate_arg_names("expand", map_kwargs)
        prevent_duplicates(self.kwargs,
                           map_kwargs,
                           fail_reason="mapping already partial")
        ensure_xcomarg_return_value(map_kwargs)

        task_kwargs = self.kwargs.copy()
        dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
        task_group = task_kwargs.pop(
            "task_group", None) or TaskGroupContext.get_current_task_group(dag)

        partial_kwargs, default_params = get_merged_defaults(
            dag=dag,
            task_group=task_group,
            task_params=task_kwargs.pop("params", None),
            task_default_args=task_kwargs.pop("default_args", None),
        )
        partial_kwargs.update(task_kwargs)

        task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag,
                                     task_group)
        params = partial_kwargs.pop("params", None) or default_params

        # Logic here should be kept in sync with BaseOperatorMeta.partial().
        if "task_concurrency" in partial_kwargs:
            raise TypeError("unexpected argument: task_concurrency")
        if partial_kwargs.get("wait_for_downstream"):
            partial_kwargs["depends_on_past"] = True
        start_date = timezone.convert_to_utc(
            partial_kwargs.pop("start_date", None))
        end_date = timezone.convert_to_utc(partial_kwargs.pop(
            "end_date", None))
        if partial_kwargs.get("pool") is None:
            partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
        partial_kwargs["retries"] = parse_retries(
            partial_kwargs.get("retries", DEFAULT_RETRIES))
        partial_kwargs["retry_delay"] = coerce_retry_delay(
            partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY), )
        partial_kwargs["resources"] = coerce_resources(
            partial_kwargs.get("resources"))
        partial_kwargs.setdefault("executor_config", {})
        partial_kwargs.setdefault("op_args", [])
        partial_kwargs.setdefault("op_kwargs", {})

        # Mypy does not work well with a subclassed attrs class :(
        _MappedOperator = cast(Any, DecoratedMappedOperator)
        operator = _MappedOperator(
            operator_class=self.operator_class,
            mapped_kwargs={},
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=False,
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            multiple_outputs=self.multiple_outputs,
            python_callable=self.function,
            mapped_op_kwargs=map_kwargs,
            # Different from classic operators, kwargs passed to a taskflow
            # task's expand() contribute to the op_kwargs operator argument, not
            # the operator arguments themselves, and should expand against it.
            expansion_kwargs_attr="mapped_op_kwargs",
        )
        return XComArg(operator=operator)
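This variant of `expand()` also guards, via `prevent_duplicates`, against supplying the same argument to both `partial()` and `expand()`. A sketch of the failure mode it protects against (task name is illustrative; run inside a DAG definition):

from airflow.decorators import task


@task
def add(x: int, y: int) -> int:
    return x + y


add.partial(y=10).expand(x=[1, 2, 3])   # fine: y fixed once, x mapped
# add.partial(x=1).expand(x=[1, 2, 3])  # TypeError: x is "mapping already partial"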
Example #34
    def __init__(
            self,
            task_id: str,
            owner: str = conf.get('operators', 'DEFAULT_OWNER'),
            email: Optional[str] = None,
            email_on_retry: bool = True,
            email_on_failure: bool = True,
            retries: Optional[int] = conf.getint('core',
                                                 'default_task_retries',
                                                 fallback=0),
            retry_delay: timedelta = timedelta(seconds=300),
            retry_exponential_backoff: bool = False,
            max_retry_delay: Optional[timedelta] = None,  # a delay is a duration, not a point in time
            start_date: Optional[datetime] = None,
            end_date: Optional[datetime] = None,
            depends_on_past: bool = False,
            wait_for_downstream: bool = False,
            dag: Optional[DAG] = None,
            params: Optional[Dict] = None,
            default_args: Optional[Dict] = None,  # pylint: disable=unused-argument
            priority_weight: int = 1,
            weight_rule: str = WeightRule.DOWNSTREAM,
            queue: str = conf.get('celery', 'default_queue'),
            pool: str = Pool.DEFAULT_POOL_NAME,
            sla: Optional[timedelta] = None,
            execution_timeout: Optional[timedelta] = None,
            on_failure_callback: Optional[Callable] = None,
            on_success_callback: Optional[Callable] = None,
            on_retry_callback: Optional[Callable] = None,
            trigger_rule: str = TriggerRule.ALL_SUCCESS,
            resources: Optional[Dict] = None,
            run_as_user: Optional[str] = None,
            task_concurrency: Optional[int] = None,
            executor_config: Optional[Dict] = None,
            do_xcom_push: bool = True,
            inlets: Optional[Dict] = None,
            outlets: Optional[Dict] = None,
            *args,
            **kwargs):

        if args or kwargs:
            # TODO remove *args and **kwargs in Airflow 2.0
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'Airflow 2.0. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(c=self.__class__.__name__,
                                                     a=args,
                                                     k=kwargs,
                                                     t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3)
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers}; "
                "received '{tr}' for task '{d}.{t}'.".format(
                    all_triggers=TriggerRule.all_triggers(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback

        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules}; "
                "received '{tr}' for task '{d}.{t}'.".format(
                    all_weight_rules=WeightRule.all_weight_rules(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=weight_rule))
        self.weight_rule = weight_rule

        # `resources` is a dict of keyword arguments, so unpack it with **.
        self.resources = Resources(
            **resources) if resources is not None else None
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids = set()  # type: Set[str]
        self._downstream_task_ids = set()  # type: Set[str]

        if not dag and settings.CONTEXT_MANAGER_DAG:
            dag = settings.CONTEXT_MANAGER_DAG
        if dag:
            self.dag = dag

        self._log = logging.getLogger("airflow.task.operators")

        # lineage
        self.inlets = []  # type: List[DataSet]
        self.outlets = []  # type: List[DataSet]
        self.lineage_data = None

        self._inlets = {
            "auto": False,
            "task_ids": [],
            "datasets": [],
        }

        self._outlets = {
            "datasets": [],
        }  # type: Dict

        if inlets:
            self._inlets.update(inlets)

        if outlets:
            self._outlets.update(outlets)
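Note how `start_date` and `end_date` are normalized with `timezone.convert_to_utc` at construction time, so naive datetimes are interpreted in the configured default timezone. A quick illustration of that behavior (assuming an Airflow install with the default `utc` timezone setting):

from datetime import datetime

from airflow.utils import timezone

naive = datetime(2024, 1, 1, 12, 0)      # no tzinfo
aware = timezone.convert_to_utc(naive)   # localized with the default tz, then UTC
print(aware.tzinfo)                      # UTC, given default_timezone = utc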
Example #35
    def _default(obj):
        """Convert dates and numpy objects in a json serializable format."""
        if isinstance(obj, datetime):
            if is_naive(obj):
                obj = convert_to_utc(obj)
            return obj.isoformat()
        elif isinstance(obj, date):
            return obj.strftime('%Y-%m-%d')
        elif isinstance(obj, Decimal):
            _, _, exponent = obj.as_tuple()
            if exponent >= 0:  # No digits after the decimal point.
                return int(obj)
            # Technically lossy due to floating point errors, but the best we
            # can do without implementing a custom encode function.
            return float(obj)
        elif np is not None and isinstance(
                obj,
                (np.int_, np.intc, np.intp, np.int8, np.int16, np.int32,
                 np.int64, np.uint8, np.uint16, np.uint32, np.uint64)):
            return int(obj)
        elif np is not None and isinstance(obj, np.bool_):
            return bool(obj)
        elif np is not None and isinstance(
                obj, (np.float_, np.float16, np.float32, np.float64,
                      np.complex_, np.complex64, np.complex128)):
            return float(obj)
        elif k8s is not None and isinstance(
                obj, (k8s.V1Pod, k8s.V1ResourceRequirements)):
            from airflow.kubernetes.pod_generator import PodGenerator

            def safe_get_name(pod):
                """
                We're running this in an except block, so we don't want it to
                fail under any circumstances, e.g. by accessing an attribute that isn't there
                """
                try:
                    return pod.metadata.name
                except Exception:
                    return None

            try:
                return PodGenerator.serialize_pod(obj)
            except Exception:
                log.warning("JSON encoding failed for pod %s",
                            safe_get_name(obj))
                log.debug("traceback for pod JSON encode error", exc_info=True)
                return {}

        raise TypeError(
            f"Object of type '{obj.__class__.__name__}' is not JSON serializable"
        )
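`_default` is designed to be plugged into the standard `json` machinery as the `default` hook, so anything the encoder cannot serialize natively is routed through it. Typical usage (payload values are illustrative; the naive-datetime conversion assumes Airflow's default timezone configuration):

import json
from datetime import datetime
from decimal import Decimal

payload = {
    "ts": datetime(2024, 1, 1, 12, 0),  # naive -> converted to UTC, ISO format
    "price": Decimal("19.99"),          # negative exponent -> float (lossy)
    "count": Decimal("3"),              # exponent >= 0 -> int
}
print(json.dumps(payload, default=_default))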