def test_some_resources_specified(self):
    """Explicit cpus/disk are kept as given; ram and gpus match the [operators] config."""
    res = Resources(cpus=0, disk=1)

    def _configured(option):
        # Read the integer default for `option` from the [operators] section.
        return configuration.getint('operators', option)

    self.assertEqual(res.cpus.qty, 0)
    self.assertEqual(res.ram.qty, _configured('default_ram'))
    self.assertEqual(res.disk.qty, 1)
    self.assertEqual(res.gpus.qty, _configured('default_gpus'))
def start(self):
    """Start the task process, wrapping it in fresh cpu/memory cgroups when possible.

    If ``self._get_cgroup_names()`` reports a non-root (``!= "/"``) cpu or
    memory cgroup, the command is started directly with ``run_command()``
    and no cgroups are created.

    Otherwise a unique ``airflow/<YYYY-MM-DD>/<uuid4>`` cgroup is created
    under both the ``memory`` and ``cpu`` hierarchies via ``_create_cgroup``.
    A memory limit (``resources.ram.qty`` MB) is applied when positive.
    CPU shares (``resources.cpus.qty * 1024``) are applied when positive.
    The command is then launched through ``cgexec``.

    Sets ``self.process`` in both paths; in the cgroup path it also sets
    ``mem_cgroup_name``, ``cpu_cgroup_name``, ``_cpu_shares``,
    ``_mem_mb_limit``, ``_created_mem_cgroup`` and ``_created_cpu_cgroup``.
    """
    # Don't nest: if we are already inside a non-root cgroup, run as-is.
    cgroups = self._get_cgroup_names()
    cpu_cgroup = cgroups.get("cpu")
    mem_cgroup = cgroups.get("memory")
    if (cpu_cgroup and cpu_cgroup != "/") or (mem_cgroup and mem_cgroup != "/"):
        self.log.debug(
            "Already running in a cgroup (cpu: %s memory: %s) so not "
            "creating another one",
            cpu_cgroup, mem_cgroup
        )
        self.process = self.run_command()
        return

    # Create a unique cgroup name. datetime.utcnow() is deprecated since
    # Python 3.12; an aware UTC "now" formats to the same calendar date.
    day = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
    cgroup_name = f"airflow/{day}/{uuid.uuid4()}"

    self.mem_cgroup_name = f"memory/{cgroup_name}"
    self.cpu_cgroup_name = f"cpu/{cgroup_name}"

    # Get the resource requirements from the task; a task without explicit
    # resources falls back to a default Resources().
    task = self._task_instance.task
    resources = task.resources if task.resources is not None else Resources()
    cpus = resources.cpus.qty
    self._cpu_shares = cpus * 1024
    self._mem_mb_limit = resources.ram.qty

    # Create the memory cgroup
    mem_cgroup_node = self._create_cgroup(self.mem_cgroup_name)
    self._created_mem_cgroup = True
    if self._mem_mb_limit > 0:
        self.log.debug(
            "Setting %s with %s MB of memory",
            self.mem_cgroup_name, self._mem_mb_limit
        )
        # The controller limit is expressed in bytes, the resource in MB.
        mem_cgroup_node.controller.limit_in_bytes = self._mem_mb_limit * 1024 * 1024

    # Create the CPU cgroup
    cpu_cgroup_node = self._create_cgroup(self.cpu_cgroup_name)
    self._created_cpu_cgroup = True
    if self._cpu_shares > 0:
        self.log.debug(
            "Setting %s with %s CPU shares",
            self.cpu_cgroup_name, self._cpu_shares
        )
        cpu_cgroup_node.controller.shares = self._cpu_shares

    # Start the process w/ cgroups
    self.log.debug(
        "Starting task process with cgroups cpu,memory: %s",
        cgroup_name
    )
    self.process = self.run_command(
        ['cgexec', '-g', f'cpu,memory:{cgroup_name}']
    )
def test_no_resources_specified(self):
    """With no arguments, every Resources quantity equals its [operators] config value."""
    res = Resources()
    expectations = (
        ('cpus', 'default_cpus'),
        ('ram', 'default_ram'),
        ('disk', 'default_disk'),
        ('gpus', 'default_gpus'),
    )
    for attr, option in expectations:
        self.assertEqual(getattr(res, attr).qty, configuration.getint('operators', option))
def test_no_resources_specified(self):
    """A bare Resources() mirrors the 'operators' entries of configuration.defaults."""
    res = Resources()
    operator_defaults = configuration.defaults['operators']
    for attr in ('cpus', 'ram', 'disk', 'gpus'):
        self.assertEqual(getattr(res, attr).qty, operator_defaults['default_' + attr])
def __init__(
    self,
    task_id,  # type: str
    owner=configuration.conf.get('operators', 'DEFAULT_OWNER'),  # type: str
    email=None,  # type: Optional[str]
    email_on_retry=True,  # type: bool
    email_on_failure=True,  # type: bool
    retries=0,  # type: int
    retry_delay=timedelta(seconds=300),  # type: timedelta
    retry_exponential_backoff=False,  # type: bool
    max_retry_delay=None,  # type: Optional[timedelta]
    start_date=None,  # type: Optional[datetime]
    end_date=None,  # type: Optional[datetime]
    schedule_interval=None,  # not hooked as of now
    depends_on_past=False,  # type: bool
    wait_for_downstream=False,  # type: bool
    dag=None,  # type: Optional[DAG]
    params=None,  # type: Optional[Dict]
    default_args=None,  # type: Optional[Dict]
    priority_weight=1,  # type: int
    weight_rule=WeightRule.DOWNSTREAM,  # type: str
    queue=configuration.conf.get('celery', 'default_queue'),  # type: str
    pool=None,  # type: Optional[str]
    sla=None,  # type: Optional[timedelta]
    execution_timeout=None,  # type: Optional[timedelta]
    on_failure_callback=None,  # type: Optional[Callable]
    on_success_callback=None,  # type: Optional[Callable]
    on_retry_callback=None,  # type: Optional[Callable]
    trigger_rule=TriggerRule.ALL_SUCCESS,  # type: str
    resources=None,  # type: Optional[Dict]
    run_as_user=None,  # type: Optional[str]
    task_concurrency=None,  # type: Optional[int]
    executor_config=None,  # type: Optional[Dict]
    do_xcom_push=True,  # type: bool
    inlets=None,  # type: Optional[Dict]
    outlets=None,  # type: Optional[Dict]
    *args,
    **kwargs
):
    """Initialise the operator's attributes from the constructor arguments.

    Unknown positional/keyword arguments only trigger a
    PendingDeprecationWarning.

    :raises AirflowException: if ``trigger_rule`` or ``weight_rule`` is not
        one of the recognised values.
    """
    if args or kwargs:
        # TODO remove *args and **kwargs in Airflow 2.0
        warnings.warn(
            'Invalid arguments were passed to {c} (task_id: {t}). '
            'Support for passing such arguments will be dropped in '
            'Airflow 2.0. Invalid arguments were:'
            '\n*args: {a}\n**kwargs: {k}'.format(
                c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
            category=PendingDeprecationWarning,
            stacklevel=3
        )
    validate_key(task_id)
    self.task_id = task_id
    self.owner = owner
    self.email = email
    self.email_on_retry = email_on_retry
    self.email_on_failure = email_on_failure

    self.start_date = start_date
    if start_date and not isinstance(start_date, datetime):
        self.log.warning("start_date for %s isn't datetime.datetime", self)
    elif start_date:
        self.start_date = timezone.convert_to_utc(start_date)

    self.end_date = end_date
    if end_date:
        self.end_date = timezone.convert_to_utc(end_date)

    if not TriggerRule.is_valid(trigger_rule):
        raise AirflowException(
            "The trigger_rule must be one of {all_triggers},"
            "'{d}.{t}'; received '{tr}'.".format(
                all_triggers=TriggerRule.all_triggers(),
                d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))

    self.trigger_rule = trigger_rule
    self.depends_on_past = depends_on_past
    self.wait_for_downstream = wait_for_downstream
    if wait_for_downstream:
        # wait_for_downstream forces depends_on_past on.
        self.depends_on_past = True

    if schedule_interval:
        self.log.warning(
            "schedule_interval is used for %s, though it has "
            "been deprecated as a task parameter, you need to "
            "specify it as a DAG parameter instead",
            self)
    self._schedule_interval = schedule_interval
    self.retries = retries
    self.queue = queue
    self.pool = pool
    self.sla = sla
    self.execution_timeout = execution_timeout
    self.on_failure_callback = on_failure_callback
    self.on_success_callback = on_success_callback
    self.on_retry_callback = on_retry_callback

    if isinstance(retry_delay, timedelta):
        self.retry_delay = retry_delay
    else:
        # Numeric retry_delay values are interpreted as seconds.
        self.log.debug("Retry_delay isn't timedelta object, assuming secs")
        self.retry_delay = timedelta(seconds=retry_delay)
    self.retry_exponential_backoff = retry_exponential_backoff
    self.max_retry_delay = max_retry_delay
    self.params = params or {}  # Available in templates!
    self.priority_weight = priority_weight
    if not WeightRule.is_valid(weight_rule):
        raise AirflowException(
            "The weight_rule must be one of {all_weight_rules},"
            "'{d}.{t}'; received '{tr}'.".format(
                # all_weight_rules is a classmethod; call it (as with
                # TriggerRule.all_triggers above) so the message lists the
                # valid rules rather than a bound-method repr.
                all_weight_rules=WeightRule.all_weight_rules(),
                d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
    self.weight_rule = weight_rule

    self.resources = Resources(**(resources or {}))
    self.run_as_user = run_as_user
    self.task_concurrency = task_concurrency
    self.executor_config = executor_config or {}
    self.do_xcom_push = do_xcom_push

    # Private attributes
    self._upstream_task_ids = set()  # type: Set[str]
    self._downstream_task_ids = set()  # type: Set[str]

    # Fall back to the DAG of an enclosing `with DAG(...)` context, if any.
    if not dag and settings.CONTEXT_MANAGER_DAG:
        dag = settings.CONTEXT_MANAGER_DAG
    if dag:
        self.dag = dag

    self._log = logging.getLogger("airflow.task.operators")

    # lineage
    self.inlets = []  # type: List[DataSet]
    self.outlets = []  # type: List[DataSet]
    self.lineage_data = None

    self._inlets = {
        "auto": False,
        "task_ids": [],
        "datasets": [],
    }

    self._outlets = {
        "datasets": [],
    }  # type: Dict

    if inlets:
        self._inlets.update(inlets)

    if outlets:
        self._outlets.update(outlets)

    self._comps = {
        'task_id',
        'dag_id',
        'owner',
        'email',
        'email_on_retry',
        'retry_delay',
        'retry_exponential_backoff',
        'max_retry_delay',
        'start_date',
        'schedule_interval',
        'depends_on_past',
        'wait_for_downstream',
        'priority_weight',
        'sla',
        'execution_timeout',
        'on_failure_callback',
        'on_success_callback',
        'on_retry_callback',
        'do_xcom_push',
    }
def __init__(
    self,
    task_id: str,
    owner: str = conf.get('operators', 'DEFAULT_OWNER'),
    email: Optional[Union[str, Iterable[str]]] = None,
    email_on_retry: bool = True,
    email_on_failure: bool = True,
    retries: Optional[int] = conf.getint('core', 'default_task_retries', fallback=0),
    retry_delay: timedelta = timedelta(seconds=300),
    retry_exponential_backoff: bool = False,
    max_retry_delay: Optional[timedelta] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    depends_on_past: bool = False,
    wait_for_downstream: bool = False,
    dag=None,
    params: Optional[Dict] = None,
    default_args: Optional[Dict] = None,  # pylint: disable=unused-argument
    priority_weight: int = 1,
    weight_rule: str = WeightRule.DOWNSTREAM,
    queue: str = conf.get('celery', 'default_queue'),
    pool: str = Pool.DEFAULT_POOL_NAME,
    sla: Optional[timedelta] = None,
    execution_timeout: Optional[timedelta] = None,
    on_execute_callback: Optional[Callable] = None,
    on_failure_callback: Optional[Callable] = None,
    on_success_callback: Optional[Callable] = None,
    on_retry_callback: Optional[Callable] = None,
    trigger_rule: str = TriggerRule.ALL_SUCCESS,
    resources: Optional[Dict] = None,
    run_as_user: Optional[str] = None,
    task_concurrency: Optional[int] = None,
    executor_config: Optional[Dict] = None,
    do_xcom_push: bool = True,
    inlets: Optional[Any] = None,
    outlets: Optional[Any] = None,
    *args,
    **kwargs
):
    """Initialise the operator's attributes from the constructor arguments.

    Unknown positional/keyword arguments raise unless
    ``[operators] ALLOW_ILLEGAL_ARGUMENTS`` is enabled, in which case they
    only trigger a PendingDeprecationWarning.

    :raises AirflowException: for illegal extra arguments (see above), or if
        ``trigger_rule`` / ``weight_rule`` is not a recognised value.
    """
    from airflow.models.dag import DagContext
    super().__init__()
    if args or kwargs:
        if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
            raise AirflowException(
                "Invalid arguments were passed to {c} (task_id: {t}). Invalid "
                "arguments were:\n*args: {a}\n**kwargs: {k}".format(
                    c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
            )
        warnings.warn(
            'Invalid arguments were passed to {c} (task_id: {t}). '
            'Support for passing such arguments will be dropped in '
            'future. Invalid arguments were:'
            '\n*args: {a}\n**kwargs: {k}'.format(
                c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
            category=PendingDeprecationWarning,
            stacklevel=3
        )
    validate_key(task_id)
    self.task_id = task_id
    self.owner = owner
    self.email = email
    self.email_on_retry = email_on_retry
    self.email_on_failure = email_on_failure

    self.start_date = start_date
    if start_date and not isinstance(start_date, datetime):
        self.log.warning("start_date for %s isn't datetime.datetime", self)
    elif start_date:
        self.start_date = timezone.convert_to_utc(start_date)

    self.end_date = end_date
    if end_date:
        self.end_date = timezone.convert_to_utc(end_date)

    if not TriggerRule.is_valid(trigger_rule):
        raise AirflowException(
            "The trigger_rule must be one of {all_triggers},"
            "'{d}.{t}'; received '{tr}'.".format(
                all_triggers=TriggerRule.all_triggers(),
                d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))

    self.trigger_rule = trigger_rule
    self.depends_on_past = depends_on_past
    self.wait_for_downstream = wait_for_downstream
    if wait_for_downstream:
        # wait_for_downstream forces depends_on_past on.
        self.depends_on_past = True

    self.retries = retries
    self.queue = queue
    self.pool = pool
    self.sla = sla
    self.execution_timeout = execution_timeout
    self.on_execute_callback = on_execute_callback
    self.on_failure_callback = on_failure_callback
    self.on_success_callback = on_success_callback
    self.on_retry_callback = on_retry_callback

    if isinstance(retry_delay, timedelta):
        self.retry_delay = retry_delay
    else:
        # Numeric retry_delay values are interpreted as seconds.
        self.log.debug("Retry_delay isn't timedelta object, assuming secs")
        # noinspection PyTypeChecker
        self.retry_delay = timedelta(seconds=retry_delay)
    self.retry_exponential_backoff = retry_exponential_backoff
    self.max_retry_delay = max_retry_delay
    self.params = params or {}  # Available in templates!
    self.priority_weight = priority_weight
    if not WeightRule.is_valid(weight_rule):
        raise AirflowException(
            "The weight_rule must be one of {all_weight_rules},"
            "'{d}.{t}'; received '{tr}'.".format(
                # all_weight_rules is a classmethod; call it (as with
                # TriggerRule.all_triggers above) so the message lists the
                # valid rules rather than a bound-method repr.
                all_weight_rules=WeightRule.all_weight_rules(),
                d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
    self.weight_rule = weight_rule
    # An empty/omitted resources mapping leaves resources unset (None).
    self.resources: Optional[Resources] = Resources(**resources) if resources else None
    self.run_as_user = run_as_user
    self.task_concurrency = task_concurrency
    self.executor_config = executor_config or {}
    self.do_xcom_push = do_xcom_push

    # Private attributes
    self._upstream_task_ids: Set[str] = set()
    self._downstream_task_ids: Set[str] = set()
    self._dag = None

    # Use the explicitly passed DAG, else whatever DagContext reports as current.
    self.dag = dag or DagContext.get_current_dag()

    # subdag parameter is only set for SubDagOperator.
    # Setting it to None by default as other Operators do not have that field
    from airflow.models.dag import DAG
    self.subdag: Optional[DAG] = None

    self._log = logging.getLogger("airflow.task.operators")

    # Lineage
    self.inlets: List = []
    self.outlets: List = []

    self._inlets: List = []
    self._outlets: List = []

    # A single inlet/outlet object is wrapped into a one-element list.
    if inlets:
        self._inlets = inlets if isinstance(inlets, list) else [inlets]

    if outlets:
        self._outlets = outlets if isinstance(outlets, list) else [outlets]
def __init__(
    self,
    task_id: str,
    owner: str = conf.get('operators', 'DEFAULT_OWNER'),
    email: Optional[str] = None,
    email_on_retry: bool = True,
    email_on_failure: bool = True,
    retries: Optional[int] = conf.getint('core', 'default_task_retries', fallback=0),
    retry_delay: timedelta = timedelta(seconds=300),
    retry_exponential_backoff: bool = False,
    max_retry_delay: Optional[timedelta] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    depends_on_past: bool = False,
    wait_for_downstream: bool = False,
    dag: Optional[DAG] = None,
    params: Optional[Dict] = None,
    default_args: Optional[Dict] = None,  # pylint: disable=unused-argument
    priority_weight: int = 1,
    weight_rule: str = WeightRule.DOWNSTREAM,
    queue: str = conf.get('celery', 'default_queue'),
    pool: str = Pool.DEFAULT_POOL_NAME,
    sla: Optional[timedelta] = None,
    execution_timeout: Optional[timedelta] = None,
    on_failure_callback: Optional[Callable] = None,
    on_success_callback: Optional[Callable] = None,
    on_retry_callback: Optional[Callable] = None,
    trigger_rule: str = TriggerRule.ALL_SUCCESS,
    resources: Optional[Dict] = None,
    run_as_user: Optional[str] = None,
    task_concurrency: Optional[int] = None,
    executor_config: Optional[Dict] = None,
    do_xcom_push: bool = True,
    inlets: Optional[Dict] = None,
    outlets: Optional[Dict] = None,
    *args,
    **kwargs
):
    """Initialise the operator's attributes from the constructor arguments.

    Unknown positional/keyword arguments only trigger a
    PendingDeprecationWarning.

    :param resources: keyword arguments for ``Resources`` (e.g. ``cpus``,
        ``ram``); ``None`` leaves ``self.resources`` unset.
    :raises AirflowException: if ``trigger_rule`` or ``weight_rule`` is not
        one of the recognised values.
    """
    if args or kwargs:
        # TODO remove *args and **kwargs in Airflow 2.0
        warnings.warn(
            'Invalid arguments were passed to {c} (task_id: {t}). '
            'Support for passing such arguments will be dropped in '
            'Airflow 2.0. Invalid arguments were:'
            '\n*args: {a}\n**kwargs: {k}'.format(
                c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
            category=PendingDeprecationWarning,
            stacklevel=3
        )
    validate_key(task_id)
    self.task_id = task_id
    self.owner = owner
    self.email = email
    self.email_on_retry = email_on_retry
    self.email_on_failure = email_on_failure

    self.start_date = start_date
    if start_date and not isinstance(start_date, datetime):
        self.log.warning("start_date for %s isn't datetime.datetime", self)
    elif start_date:
        self.start_date = timezone.convert_to_utc(start_date)

    self.end_date = end_date
    if end_date:
        self.end_date = timezone.convert_to_utc(end_date)

    if not TriggerRule.is_valid(trigger_rule):
        raise AirflowException(
            "The trigger_rule must be one of {all_triggers},"
            "'{d}.{t}'; received '{tr}'.".format(
                all_triggers=TriggerRule.all_triggers(),
                d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))

    self.trigger_rule = trigger_rule
    self.depends_on_past = depends_on_past
    self.wait_for_downstream = wait_for_downstream
    if wait_for_downstream:
        # wait_for_downstream forces depends_on_past on.
        self.depends_on_past = True

    self.retries = retries
    self.queue = queue
    self.pool = pool
    self.sla = sla
    self.execution_timeout = execution_timeout
    self.on_failure_callback = on_failure_callback
    self.on_success_callback = on_success_callback
    self.on_retry_callback = on_retry_callback

    if isinstance(retry_delay, timedelta):
        self.retry_delay = retry_delay
    else:
        # Numeric retry_delay values are interpreted as seconds.
        self.log.debug("Retry_delay isn't timedelta object, assuming secs")
        self.retry_delay = timedelta(seconds=retry_delay)
    self.retry_exponential_backoff = retry_exponential_backoff
    self.max_retry_delay = max_retry_delay
    self.params = params or {}  # Available in templates!
    self.priority_weight = priority_weight
    if not WeightRule.is_valid(weight_rule):
        raise AirflowException(
            "The weight_rule must be one of {all_weight_rules},"
            "'{d}.{t}'; received '{tr}'.".format(
                # all_weight_rules is a classmethod; call it (as with
                # TriggerRule.all_triggers above) so the message lists the
                # valid rules rather than a bound-method repr.
                all_weight_rules=WeightRule.all_weight_rules(),
                d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
    self.weight_rule = weight_rule

    # `resources` is a mapping of Resources keyword arguments: unpack it with
    # `**`. (`*` would pass the dict's *keys* positionally, e.g. the string
    # 'cpus' as the cpu quantity.)
    self.resources = Resources(**resources) if resources is not None else None
    self.run_as_user = run_as_user
    self.task_concurrency = task_concurrency
    self.executor_config = executor_config or {}
    self.do_xcom_push = do_xcom_push

    # Private attributes
    self._upstream_task_ids = set()  # type: Set[str]
    self._downstream_task_ids = set()  # type: Set[str]

    # Fall back to the DAG of an enclosing `with DAG(...)` context, if any.
    if not dag and settings.CONTEXT_MANAGER_DAG:
        dag = settings.CONTEXT_MANAGER_DAG
    if dag:
        self.dag = dag

    self._log = logging.getLogger("airflow.task.operators")

    # lineage
    self.inlets = []  # type: List[DataSet]
    self.outlets = []  # type: List[DataSet]
    self.lineage_data = None

    self._inlets = {
        "auto": False,
        "task_ids": [],
        "datasets": [],
    }

    self._outlets = {
        "datasets": [],
    }  # type: Dict

    if inlets:
        self._inlets.update(inlets)

    if outlets:
        self._outlets.update(outlets)
def test_negative_resource_qty(self):
    """Constructing Resources with a negative cpu quantity raises AirflowException."""
    self.assertRaises(AirflowException, Resources, cpus=-1)
def test_all_resources_specified(self):
    """Every quantity passed to Resources is exposed unchanged via `.qty`."""
    expected = {'cpus': 1, 'ram': 2, 'disk': 3, 'gpus': 4}
    res = Resources(**expected)
    for attr, qty in expected.items():
        self.assertEqual(getattr(res, attr).qty, qty)
def populate_operator(cls, op: Operator, encoded_op: Dict[str, Any]) -> None:
    """Populate ``op`` in place from its serialized dict ``encoded_op``.

    Each key of ``encoded_op`` is renamed and/or deserialized as needed and
    assigned onto ``op`` with ``setattr``. Afterwards, serialized fields (other
    than constructor params) and template fields that are still missing on
    ``op`` are set to ``None``.

    :param op: operator instance to populate.
    :param encoded_op: serialized operator dict. If it has no ``"label"`` key,
        one is added from ``"task_id"`` (this mutates ``encoded_op``).
    :raises AirflowException: if extra-link plugins are to be loaded but
        ``plugins_manager.operator_extra_links`` is still ``None``.
    """
    if "label" not in encoded_op:
        # Handle deserialization of old data before the introduction of TaskGroup
        encoded_op["label"] = encoded_op["task_id"]

    # Extra Operator Links defined in Plugins
    op_extra_links_from_plugin = {}

    # We don't want to load Extra Operator links in Scheduler
    if cls._load_operator_extra_links:
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()

        if plugins_manager.operator_extra_links is None:
            raise AirflowException("Can not load plugins")

        for ope in plugins_manager.operator_extra_links:
            for operator in ope.operators:
                if (
                    operator.__name__ == encoded_op["_task_type"]
                    and operator.__module__ == encoded_op["_task_module"]
                ):
                    op_extra_links_from_plugin.update({ope.name: ope})

        # If OperatorLinks are defined in Plugins but not in the Operator that is being Serialized
        # set the Operator links attribute
        # The case for "If OperatorLinks are defined in the operator that is being Serialized"
        # is handled in the deserialization loop where it matches k == "_operator_extra_links"
        if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
            setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))

    for k, v in encoded_op.items():
        # TODO: Remove in Airflow 3.0 when dummy operator is removed
        if k == "_is_dummy":
            k = "_is_empty"

        if k == "_downstream_task_ids":
            # Upgrade from old format/name
            k = "downstream_task_ids"

        if k == "label":
            # Label shouldn't be set anymore -- it's computed from task_id now
            continue
        elif k == "downstream_task_ids":
            v = set(v)
        elif k == "subdag":
            v = SerializedDAG.deserialize_dag(v)
        elif k in {"retry_delay", "execution_timeout", "sla", "max_retry_delay"}:
            v = cls._deserialize_timedelta(v)
        elif k in encoded_op["template_fields"]:
            pass
        elif k == "resources":
            v = Resources.from_dict(v)
        elif k.endswith("_date"):
            v = cls._deserialize_datetime(v)
        elif k == "_operator_extra_links":
            if cls._load_operator_extra_links:
                op_predefined_extra_links = cls._deserialize_operator_extra_links(v)

                # If OperatorLinks with the same name exists, Links via Plugin have higher precedence
                op_predefined_extra_links.update(op_extra_links_from_plugin)
            else:
                op_predefined_extra_links = {}

            v = list(op_predefined_extra_links.values())
            k = "operator_extra_links"

        elif k == "deps":
            v = cls._deserialize_deps(v)
        elif k == "params":
            v = cls._deserialize_params_dict(v)
        elif k in ("mapped_kwargs", "partial_kwargs"):
            # op_kwargs is pulled out, deserialized separately, and re-attached
            # after the remaining entries have been deserialized.
            if "op_kwargs" not in v:
                op_kwargs: Optional[dict] = None
            else:
                op_kwargs = {arg: cls._deserialize(value) for arg, value in v.pop("op_kwargs").items()}
            v = {arg: cls._deserialize(value) for arg, value in v.items()}
            if op_kwargs is not None:
                v["op_kwargs"] = op_kwargs
        elif k == "mapped_op_kwargs":
            v = {arg: cls._deserialize(value) for arg, value in v.items()}
        elif k in cls._decorated_fields or k not in op.get_serialized_fields():
            v = cls._deserialize(v)
        # else use v as it is

        setattr(op, k, v)

    for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys():
        # TODO: refactor deserialization of BaseOperator and MappedOperator (split it out), then check
        # could go away.
        if not hasattr(op, k):
            setattr(op, k, None)

    # Set all the template_field to None that were not present in Serialized JSON
    for field in op.template_fields:
        if not hasattr(op, field):
            setattr(op, field, None)

    # Used to determine if an Operator is inherited from EmptyOperator.
    # Old payloads carry "_is_dummy" instead of "_is_empty"; the loop above
    # already translated it, so fall back to it here rather than resetting
    # the flag to False.
    setattr(
        op,
        "_is_empty",
        bool(encoded_op.get("_is_empty", encoded_op.get("_is_dummy", False))),
    )