Example #1
 def __init__(self, kube_config, task_queue, result_queue, kube_client, worker_uuid):
     self.log.debug("Creating Kubernetes executor")
     self.kube_config = kube_config
     self.task_queue = task_queue
     self.result_queue = result_queue
     self.namespace = self.kube_config.kube_namespace
     self.log.debug("Kubernetes using namespace %s", self.namespace)
     self.kube_client = kube_client
     self.launcher = PodLauncher(kube_client=self.kube_client)
     self.worker_configuration = WorkerConfiguration(kube_config=self.kube_config)
     self._manager = multiprocessing.Manager()
     self.watcher_queue = self._manager.Queue()
     self.worker_uuid = worker_uuid
     self.kube_watcher = self._make_kube_watcher()
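
A minimal wiring sketch for this constructor (hypothetical glue code: it assumes kube_config is a populated Airflow KubeConfig and kube_client is a kubernetes.client.CoreV1Api instance):

 import multiprocessing
 from uuid import uuid4

 manager = multiprocessing.Manager()
 scheduler = AirflowKubernetesScheduler(
     kube_config=kube_config,        # assumed to exist
     task_queue=manager.Queue(),
     result_queue=manager.Queue(),
     kube_client=kube_client,        # assumed to exist
     worker_uuid=uuid4().hex,
 )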
Example #2
 def __init__(self, kube_config: Any,
              task_queue: 'Queue[KubernetesJobType]',
              result_queue: 'Queue[KubernetesResultsType]',
              kube_client: client.CoreV1Api, scheduler_job_id: str):
     super().__init__()
     self.log.debug("Creating Kubernetes executor")
     self.kube_config = kube_config
     self.task_queue = task_queue
     self.result_queue = result_queue
     self.namespace = self.kube_config.kube_namespace
     self.log.debug("Kubernetes using namespace %s", self.namespace)
     self.kube_client = kube_client
     self.launcher = PodLauncher(kube_client=self.kube_client)
     self._manager = multiprocessing.Manager()
     self.watcher_queue = self._manager.Queue()
     self.scheduler_job_id = scheduler_job_id
     self.kube_watcher = self._make_kube_watcher()
class AirflowKubernetesScheduler(LoggingMixin):
    """Airflow Scheduler for Kubernetes"""
    def __init__(self, kube_config: Any,
                 task_queue: 'Queue[KubernetesJobType]',
                 result_queue: 'Queue[KubernetesResultsType]',
                 kube_client: client.CoreV1Api, worker_uuid: str):
        super().__init__()
        self.log.debug("Creating Kubernetes executor")
        self.kube_config = kube_config
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.namespace = self.kube_config.kube_namespace
        self.log.debug("Kubernetes using namespace %s", self.namespace)
        self.kube_client = kube_client
        self.launcher = PodLauncher(kube_client=self.kube_client)
        self.worker_configuration_pod = WorkerConfiguration(
            kube_config=self.kube_config).as_pod()
        self._manager = multiprocessing.Manager()
        self.watcher_queue = self._manager.Queue()
        self.worker_uuid = worker_uuid
        self.kube_watcher = self._make_kube_watcher()

    def _make_kube_watcher(self) -> KubernetesJobWatcher:
        resource_version = KubeResourceVersion.get_current_resource_version()
        watcher = KubernetesJobWatcher(watcher_queue=self.watcher_queue,
                                       resource_version=resource_version,
                                       worker_uuid=self.worker_uuid,
                                       kube_config=self.kube_config)
        watcher.start()
        return watcher

    def _health_check_kube_watcher(self):
        if not self.kube_watcher.is_alive():
            self.log.error('Error while health checking kube watcher process. '
                           'Process died for unknown reasons')
            self.kube_watcher = self._make_kube_watcher()

    def run_next(self, next_job: KubernetesJobType) -> None:
        """
        The run_next command will check the task_queue for any un-run jobs.
        It will then create a unique job-id, launch that job in the cluster,
        and store relevant info in the current_jobs map so we can track the job's
        status
        """
        self.log.info('Kubernetes job is %s', str(next_job))
        key, command, kube_executor_config = next_job
        dag_id, task_id, execution_date, try_number = key

        if isinstance(command, str):
            command = [command]

        pod = PodGenerator.construct_pod(
            namespace=self.namespace,
            worker_uuid=self.worker_uuid,
            pod_id=self._create_pod_id(dag_id, task_id),
            dag_id=pod_generator.make_safe_label_value(dag_id),
            task_id=pod_generator.make_safe_label_value(task_id),
            try_number=try_number,
            date=self._datetime_to_label_safe_datestring(execution_date),
            command=command,
            kube_executor_config=kube_executor_config,
            worker_config=self.worker_configuration_pod)
        # Reconcile the pod generated by the Operator and the Pod
        # generated by the .cfg file
        self.log.debug("Kubernetes running for command %s", command)
        self.log.debug("Kubernetes launching image %s",
                       pod.spec.containers[0].image)

        # the watcher will monitor pods, so we do not block.
        self.launcher.run_pod_async(
            pod, **self.kube_config.kube_client_request_args)
        self.log.debug("Kubernetes Job created!")

    def delete_pod(self, pod_id: str, namespace: str) -> None:
        """Deletes POD"""
        try:
            self.kube_client.delete_namespaced_pod(
                pod_id,
                namespace,
                body=client.V1DeleteOptions(
                    **self.kube_config.delete_option_kwargs),
                **self.kube_config.kube_client_request_args)
        except ApiException as e:
            # If the pod is already deleted
            if e.status != 404:
                raise

    def sync(self) -> None:
        """
        The sync function checks the status of all currently running kubernetes jobs.
        If a job is completed, its status is placed in the result queue to
        be sent back to the scheduler.

        :return:

        """
        self._health_check_kube_watcher()
        while True:
            try:
                task = self.watcher_queue.get_nowait()
                try:
                    self.process_watcher_task(task)
                finally:
                    self.watcher_queue.task_done()
            except Empty:
                break

    def process_watcher_task(self, task: KubernetesWatchType) -> None:
        """Process the task by watcher."""
        pod_id, namespace, state, labels, resource_version = task
        self.log.info(
            'Attempting to finish pod; pod_id: %s; state: %s; labels: %s',
            pod_id, state, labels)
        key = self._labels_to_key(labels=labels)
        if key:
            self.log.debug('finishing job %s - %s (%s)', key, state, pod_id)
            self.result_queue.put(
                (key, state, pod_id, namespace, resource_version))

    @staticmethod
    def _strip_unsafe_kubernetes_special_chars(string: str) -> str:
        """
        Kubernetes only supports lowercase alphanumeric characters and "-" and "." in
        the pod name
        However, there are special rules about how "-" and "." can be used so let's
        only keep
        alphanumeric chars  see here for detail:
        https://kubernetes.io/docs/concepts/overview/working-with-objects/names/

        :param string: The requested Pod name
        :return: ``str`` Pod name stripped of any unsafe characters
        """
        return ''.join(ch.lower() for ind, ch in enumerate(string)
                       if ch.isalnum())
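
        # A quick sketch of the stripping behaviour (hypothetical input):
        #   _strip_unsafe_kubernetes_special_chars('Hello_World-1!') == 'helloworld1'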

    @staticmethod
    def _make_safe_pod_id(safe_dag_id: str, safe_task_id: str,
                          safe_uuid: str) -> str:
        """
        Kubernetes pod names must be <= 253 chars and must pass the following regex for
        validation
        ``^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$``

        :param safe_dag_id: a dag_id with only alphanumeric characters
        :param safe_task_id: a task_id with only alphanumeric characters
        :param safe_uuid: a uuid
        :return: ``str`` valid Pod name of appropriate length
        """
        safe_key = safe_dag_id + safe_task_id

        safe_pod_id = safe_key[:MAX_POD_ID_LEN - len(safe_uuid) -
                               1] + "-" + safe_uuid

        return safe_pod_id
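
        # A short sketch of the truncation rule (hypothetical inputs; MAX_POD_ID_LEN
        # is the module-level constant, 253):
        #   _make_safe_pod_id('mydag', 'mytask', 'abcd1234') == 'mydagmytask-abcd1234'
        # Longer ids are truncated so that key + '-' + uuid still fits in 253 chars.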

    @staticmethod
    def _create_pod_id(dag_id: str, task_id: str) -> str:
        safe_dag_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            dag_id)
        safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            task_id)
        return safe_dag_id + safe_task_id

    @staticmethod
    def _label_safe_datestring_to_datetime(string: str) -> datetime.datetime:
        """
        Kubernetes doesn't permit ":" in labels. ISO datetime format uses ":" but not
        "_", let's
        replace ":" with "_"

        :param string: str
        :return: datetime.datetime object
        """
        return parser.parse(string.replace('_plus_', '+').replace("_", ":"))

    @staticmethod
    def _datetime_to_label_safe_datestring(
            datetime_obj: datetime.datetime) -> str:
        """
        Kubernetes doesn't like ":" in labels, since ISO datetime format uses ":" but
        not "_" let's
        replace ":" with "_"

        :param datetime_obj: datetime.datetime object
        :return: ISO-like string representing the datetime
        """
        return datetime_obj.isoformat().replace(":",
                                                "_").replace('+', '_plus_')

    def _labels_to_key(
            self, labels: Dict[str, str]) -> Optional[TaskInstanceKeyType]:
        try_num = 1
        try:
            try_num = int(labels.get('try_number', '1'))
        except ValueError:
            self.log.warning("could not get try_number as an int: %s",
                             labels.get('try_number', '1'))

        try:
            dag_id = labels['dag_id']
            task_id = labels['task_id']
            ex_time = self._label_safe_datestring_to_datetime(
                labels['execution_date'])
        except Exception as e:  # pylint: disable=broad-except
            self.log.warning(
                'Error while retrieving labels; labels: %s; exception: %s',
                labels, e)
            return None

        with create_session() as session:
            task = (session.query(TaskInstance).filter_by(
                task_id=task_id, dag_id=dag_id,
                execution_date=ex_time).one_or_none())
            if task:
                self.log.info(
                    'Found matching task %s-%s (%s) with current state of %s',
                    task.dag_id, task.task_id, task.execution_date, task.state)
                return (dag_id, task_id, ex_time, try_num)
            else:
                self.log.warning(
                    'task_id/dag_id are not safe to use as Kubernetes labels. This can cause '
                    'severe performance regressions. Please see '
                    '<https://kubernetes.io/docs/concepts/overview/working-with-objects'
                    '/labels/#syntax-and-character-set>. '
                    'Given dag_id: %s, task_id: %s', dag_id, task_id)

            tasks = (session.query(TaskInstance).filter_by(
                execution_date=ex_time).all())
            self.log.info('Checking %s task instances.', len(tasks))
            for task in tasks:
                if (pod_generator.make_safe_label_value(task.dag_id) == dag_id
                        and pod_generator.make_safe_label_value(task.task_id)
                        == task_id and task.execution_date == ex_time):
                    self.log.info(
                        'Found matching task %s-%s (%s) with current state of %s',
                        task.dag_id, task.task_id, task.execution_date,
                        task.state)
                    dag_id = task.dag_id
                    task_id = task.task_id
                    return dag_id, task_id, ex_time, try_num
        self.log.warning(
            'Failed to find and match task details to a pod; labels: %s',
            labels)
        return None

    def _flush_watcher_queue(self) -> None:
        self.log.debug('Executor shutting down, watcher_queue approx. size=%d',
                       self.watcher_queue.qsize())
        while True:
            try:
                task = self.watcher_queue.get_nowait()
                # Ignoring it since it can only have either FAILED or SUCCEEDED pods
                self.log.warning(
                    'Executor shutting down, IGNORING watcher task=%s', task)
                self.watcher_queue.task_done()
            except Empty:
                break

    def terminate(self) -> None:
        """Terminates the watcher."""
        self.log.debug("Terminating kube_watcher...")
        self.kube_watcher.terminate()
        self.kube_watcher.join()
        self.log.debug("kube_watcher=%s", self.kube_watcher)
        self.log.debug("Flushing watcher_queue...")
        self._flush_watcher_queue()
        # Queue should be empty...
        self.watcher_queue.join()
        self.log.debug("Shutting down manager...")
        self._manager.shutdown()
Example #4
class AirflowKubernetesScheduler(LoggingMixin):
    def __init__(self, kube_config, task_queue, result_queue, kube_client,
                 worker_uuid):
        self.log.debug("Creating Kubernetes executor")
        self.kube_config = kube_config
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.namespace = self.kube_config.kube_namespace
        self.log.debug("Kubernetes using namespace %s", self.namespace)
        self.kube_client = kube_client
        self.launcher = PodLauncher(kube_client=self.kube_client)
        self.worker_configuration = WorkerConfiguration(
            kube_config=self.kube_config)
        self._manager = multiprocessing.Manager()
        self.watcher_queue = self._manager.Queue()
        self.worker_uuid = worker_uuid
        self.kube_watcher = self._make_kube_watcher()

    def _make_kube_watcher(self):
        resource_version = KubeResourceVersion.get_current_resource_version()
        watcher = KubernetesJobWatcher(self.namespace, self.watcher_queue,
                                       resource_version, self.worker_uuid,
                                       self.kube_config)
        watcher.start()
        return watcher

    def _health_check_kube_watcher(self):
        if not self.kube_watcher.is_alive():
            self.log.error('Error while health checking kube watcher process. '
                           'Process died for unknown reasons')
            self.kube_watcher = self._make_kube_watcher()

    def run_next(self, next_job):
        """

        The run_next command will check the task_queue for any un-run jobs.
        It will then create a unique job-id, launch that job in the cluster,
        and store relevant info in the current_jobs map so we can track the job's
        status
        """
        self.log.info('Kubernetes job is %s', str(next_job))
        key, command, kube_executor_config = next_job
        dag_id, task_id, execution_date, try_number = key
        self.log.debug("Kubernetes running for command %s", command)
        self.log.debug("Kubernetes launching image %s",
                       self.kube_config.kube_image)
        pod = self.worker_configuration.make_pod(
            namespace=self.namespace,
            worker_uuid=self.worker_uuid,
            pod_id=self._create_pod_id(dag_id, task_id),
            dag_id=self._make_safe_label_value(dag_id),
            task_id=self._make_safe_label_value(task_id),
            try_number=try_number,
            execution_date=self._datetime_to_label_safe_datestring(
                execution_date),
            airflow_command=command,
            kube_executor_config=kube_executor_config)
        # the watcher will monitor pods, so we do not block.
        self.launcher.run_pod_async(
            pod, **self.kube_config.kube_client_request_args)
        self.log.debug("Kubernetes Job created!")

    def delete_pod(self, pod_id):
        if self.kube_config.delete_worker_pods:
            try:
                self.kube_client.delete_namespaced_pod(
                    pod_id,
                    self.namespace,
                    body=client.V1DeleteOptions(),
                    **self.kube_config.kube_client_request_args)
            except ApiException as e:
                # If the pod is already deleted
                if e.status != 404:
                    raise

    def sync(self):
        """
        The sync function checks the status of all currently running kubernetes jobs.
        If a job is completed, it's status is placed in the result queue to
        be sent back to the scheduler.

        :return:

        """
        self._health_check_kube_watcher()
        while True:
            try:
                task = self.watcher_queue.get_nowait()
                try:
                    self.process_watcher_task(task)
                finally:
                    self.watcher_queue.task_done()
            except Empty:
                break

    def process_watcher_task(self, task):
        pod_id, state, labels, resource_version = task
        self.log.info(
            'Attempting to finish pod; pod_id: %s; state: %s; labels: %s',
            pod_id, state, labels)
        key = self._labels_to_key(labels=labels)
        if key:
            self.log.debug('finishing job %s - %s (%s)', key, state, pod_id)
            self.result_queue.put((key, state, pod_id, resource_version))

    @staticmethod
    def _strip_unsafe_kubernetes_special_chars(string):
        """
        Kubernetes only supports lowercase alphanumeric characters and "-" and "." in
        the pod name
        However, there are special rules about how "-" and "." can be used so let's
        only keep
        alphanumeric chars  see here for detail:
        https://kubernetes.io/docs/concepts/overview/working-with-objects/names/

        :param string: The requested Pod name
        :return: ``str`` Pod name stripped of any unsafe characters
        """
        return ''.join(ch.lower() for ind, ch in enumerate(string)
                       if ch.isalnum())

    @staticmethod
    def _make_safe_pod_id(safe_dag_id, safe_task_id, safe_uuid):
        """
        Kubernetes pod names must be <= 253 chars and must pass the following regex for
        validation
        "^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$"

        :param safe_dag_id: a dag_id with only alphanumeric characters
        :param safe_task_id: a task_id with only alphanumeric characters
        :param safe_uuid: a uuid
        :return: ``str`` valid Pod name of appropriate length
        """
        MAX_POD_ID_LEN = 253

        safe_key = safe_dag_id + safe_task_id

        safe_pod_id = safe_key[:MAX_POD_ID_LEN - len(safe_uuid) -
                               1] + "-" + safe_uuid

        return safe_pod_id

    @staticmethod
    def _make_safe_label_value(string):
        """
        Valid label values must be 63 characters or less and must be empty or begin and
        end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_),
        dots (.), and alphanumerics between.

        If the label value is then greater than 63 chars once made safe, or differs in any
        way from the original value sent to this function, then we need to truncate to
        53chars, and append it with a unique hash.
        """
        MAX_LABEL_LEN = 63

        safe_label = re.sub(r'^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$',
                            '', string)

        if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
            safe_hash = hashlib.md5(string.encode()).hexdigest()[:9]
            safe_label = safe_label[:MAX_LABEL_LEN - len(safe_hash) -
                                    1] + "-" + safe_hash

        return safe_label
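
        # A brief sketch of both branches (hypothetical inputs):
        #   _make_safe_label_value('my_dag')  == 'my_dag'       # already safe, unchanged
        #   _make_safe_label_value('my dag!') == 'mydag-<hash>' # sanitized, 9-char md5 suffix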

    @staticmethod
    def _create_pod_id(dag_id, task_id):
        safe_dag_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            dag_id)
        safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            task_id)
        safe_uuid = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            uuid4().hex)
        return AirflowKubernetesScheduler._make_safe_pod_id(
            safe_dag_id, safe_task_id, safe_uuid)

    @staticmethod
    def _label_safe_datestring_to_datetime(string):
        """
        Kubernetes doesn't permit ":" in labels. ISO datetime format uses ":" but not
        "_", let's
        replace ":" with "_"

        :param string: str
        :return: datetime.datetime object
        """
        return parser.parse(string.replace('_plus_', '+').replace("_", ":"))

    @staticmethod
    def _datetime_to_label_safe_datestring(datetime_obj):
        """
        Kubernetes doesn't like ":" in labels, since ISO datetime format uses ":" but
        not "_" let's
        replace ":" with "_"
        :param datetime_obj: datetime.datetime object
        :return: ISO-like string representing the datetime
        """
        return datetime_obj.isoformat().replace(":",
                                                "_").replace('+', '_plus_')

    def _labels_to_key(self, labels):
        try_num = 1
        try:
            try_num = int(labels.get('try_number', '1'))
        except ValueError:
            self.log.warning("could not get try_number as an int: %s",
                             labels.get('try_number', '1'))

        try:
            dag_id = labels['dag_id']
            task_id = labels['task_id']
            ex_time = self._label_safe_datestring_to_datetime(
                labels['execution_date'])
        except Exception as e:
            self.log.warning(
                'Error while retrieving labels; labels: %s; exception: %s',
                labels, e)
            return None

        with create_session() as session:
            tasks = (session.query(TaskInstance).filter_by(
                execution_date=ex_time).all())
            self.log.info('Checking %s task instances.', len(tasks))
            for task in tasks:
                if (self._make_safe_label_value(task.dag_id) == dag_id and
                        self._make_safe_label_value(task.task_id) == task_id
                        and task.execution_date == ex_time):
                    self.log.info(
                        'Found matching task %s-%s (%s) with current state of %s',
                        task.dag_id, task.task_id, task.execution_date,
                        task.state)
                    dag_id = task.dag_id
                    task_id = task.task_id
                    return (dag_id, task_id, ex_time, try_num)
        self.log.warning(
            'Failed to find and match task details to a pod; labels: %s',
            labels)
        return None

    def terminate(self):
        self.watcher_queue.join()
        self._manager.shutdown()

class AirflowKubernetesScheduler(LoggingMixin):
    """Airflow Scheduler for Kubernetes"""
    def __init__(
        self,
        kube_config: Any,
        task_queue: 'Queue[KubernetesJobType]',
        result_queue: 'Queue[KubernetesResultsType]',
        kube_client: client.CoreV1Api,
        scheduler_job_id: str,
    ):
        super().__init__()
        self.log.debug("Creating Kubernetes executor")
        self.kube_config = kube_config
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.namespace = self.kube_config.kube_namespace
        self.log.debug("Kubernetes using namespace %s", self.namespace)
        self.kube_client = kube_client
        self.launcher = PodLauncher(kube_client=self.kube_client)
        self._manager = multiprocessing.Manager()
        self.watcher_queue = self._manager.Queue()
        self.scheduler_job_id = scheduler_job_id
        self.kube_watcher = self._make_kube_watcher()

    def _make_kube_watcher(self) -> KubernetesJobWatcher:
        resource_version = ResourceVersion().resource_version
        watcher = KubernetesJobWatcher(
            watcher_queue=self.watcher_queue,
            namespace=self.kube_config.kube_namespace,
            multi_namespace_mode=self.kube_config.multi_namespace_mode,
            resource_version=resource_version,
            scheduler_job_id=self.scheduler_job_id,
            kube_config=self.kube_config,
        )
        watcher.start()
        return watcher

    def _health_check_kube_watcher(self):
        if self.kube_watcher.is_alive():
            self.log.debug("KubeJobWatcher alive, continuing")
        else:
            self.log.error(
                'Error while health checking kube watcher process. Process died for unknown reasons'
            )
            self.kube_watcher = self._make_kube_watcher()

    def run_next(self, next_job: KubernetesJobType) -> None:
        """
        The run_next command will check the task_queue for any un-run jobs.
        It will then create a unique job-id, launch that job in the cluster,
        and store relevant info in the current_jobs map so we can track the job's
        status
        """
        self.log.info('Kubernetes job is %s', str(next_job))
        key, command, kube_executor_config, pod_template_file = next_job
        dag_id, task_id, execution_date, try_number = key

        if command[0:3] != ["airflow", "tasks", "run"]:
            raise ValueError(
                'The command must start with ["airflow", "tasks", "run"].')
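
        # For illustration (hypothetical values), a well-formed command is a list
        # such as:
        #   ["airflow", "tasks", "run", "my_dag", "my_task", "2020-01-01T00:00:00+00:00"]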

        base_worker_pod = get_base_pod_from_template(pod_template_file,
                                                     self.kube_config)

        if not base_worker_pod:
            raise AirflowException(
                f"could not find a valid worker template yaml at {self.kube_config.pod_template_file}"
            )

        pod = PodGenerator.construct_pod(
            namespace=self.namespace,
            scheduler_job_id=self.scheduler_job_id,
            pod_id=create_pod_id(dag_id, task_id),
            dag_id=dag_id,
            task_id=task_id,
            kube_image=self.kube_config.kube_image,
            try_number=try_number,
            date=execution_date,
            args=command,
            pod_override_object=kube_executor_config,
            base_worker_pod=base_worker_pod,
        )
        # Reconcile the pod generated by the Operator and the Pod
        # generated by the .cfg file
        self.log.debug("Kubernetes running for command %s", command)
        self.log.debug("Kubernetes launching image %s",
                       pod.spec.containers[0].image)

        # the watcher will monitor pods, so we do not block.
        self.launcher.run_pod_async(
            pod, **self.kube_config.kube_client_request_args)
        self.log.debug("Kubernetes Job created!")

    def delete_pod(self, pod_id: str, namespace: str) -> None:
        """Deletes POD"""
        try:
            self.log.debug("Deleting pod %s in namespace %s", pod_id,
                           namespace)
            self.kube_client.delete_namespaced_pod(
                pod_id,
                namespace,
                body=client.V1DeleteOptions(
                    **self.kube_config.delete_option_kwargs),
                **self.kube_config.kube_client_request_args,
            )
        except ApiException as e:
            # If the pod is already deleted
            if e.status != 404:
                raise

    def sync(self) -> None:
        """
        The sync function checks the status of all currently running kubernetes jobs.
        If a job is completed, its status is placed in the result queue to
        be sent back to the scheduler.

        :return:

        """
        self.log.debug("Syncing KubernetesExecutor")
        self._health_check_kube_watcher()
        while True:
            try:
                task = self.watcher_queue.get_nowait()
                try:
                    self.log.debug("Processing task %s", task)
                    self.process_watcher_task(task)
                finally:
                    self.watcher_queue.task_done()
            except Empty:
                break

    def process_watcher_task(self, task: KubernetesWatchType) -> None:
        """Process the task by watcher."""
        pod_id, namespace, state, annotations, resource_version = task
        self.log.info(
            'Attempting to finish pod; pod_id: %s; state: %s; annotations: %s',
            pod_id, state, annotations)
        key = self._annotations_to_key(annotations=annotations)
        if key:
            self.log.debug('finishing job %s - %s (%s)', key, state, pod_id)
            self.result_queue.put(
                (key, state, pod_id, namespace, resource_version))

    def _annotations_to_key(
            self, annotations: Dict[str, str]) -> Optional[TaskInstanceKey]:
        self.log.debug("Creating task key for annotations %s", annotations)
        dag_id = annotations['dag_id']
        task_id = annotations['task_id']
        try_number = int(annotations['try_number'])
        execution_date = parser.parse(annotations['execution_date'])

        return TaskInstanceKey(dag_id, task_id, execution_date, try_number)
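
        # A sketch of the expected annotations (hypothetical values):
        #   {'dag_id': 'my_dag', 'task_id': 'my_task', 'try_number': '1',
        #    'execution_date': '2020-01-01T00:00:00+00:00'}
        # which yields TaskInstanceKey('my_dag', 'my_task', <parsed datetime>, 1).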

    @staticmethod
    def _make_safe_pod_id(safe_dag_id: str, safe_task_id: str,
                          safe_uuid: str) -> str:
        r"""
        Kubernetes pod names must be <= 253 chars and must pass the following regex for
        validation
        ``^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$``

        :param safe_dag_id: a dag_id with only alphanumeric characters
        :param safe_task_id: a task_id with only alphanumeric characters
        :param safe_uuid: a uuid
        :return: ``str`` valid Pod name of appropriate length
        """
        safe_key = safe_dag_id + safe_task_id

        safe_pod_id = safe_key[:MAX_POD_ID_LEN - len(safe_uuid) -
                               1] + "-" + safe_uuid

        return safe_pod_id

    def _flush_watcher_queue(self) -> None:
        self.log.debug('Executor shutting down, watcher_queue approx. size=%d',
                       self.watcher_queue.qsize())
        while True:
            try:
                task = self.watcher_queue.get_nowait()
                # Ignoring it since it can only have either FAILED or SUCCEEDED pods
                self.log.warning(
                    'Executor shutting down, IGNORING watcher task=%s', task)
                self.watcher_queue.task_done()
            except Empty:
                break

    def terminate(self) -> None:
        """Terminates the watcher."""
        self.log.debug("Terminating kube_watcher...")
        self.kube_watcher.terminate()
        self.kube_watcher.join()
        self.log.debug("kube_watcher=%s", self.kube_watcher)
        self.log.debug("Flushing watcher_queue...")
        self._flush_watcher_queue()
        # Queue should be empty...
        self.watcher_queue.join()
        self.log.debug("Shutting down manager...")
        self._manager.shutdown()

 def setUp(self):
     self.mock_kube_client = mock.Mock()
     self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)

class TestPodLauncher(unittest.TestCase):
    def setUp(self):
        self.mock_kube_client = mock.Mock()
        self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)
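
    # The retry tests below exercise the launcher's handling of transient HTTP
    # errors: one BaseHTTPError followed by a good response succeeds, while three
    # consecutive BaseHTTPErrors exhaust the retries and raise AirflowException.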

    def test_read_pod_logs_successfully_returns_logs(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod_log.return_value = mock.sentinel.logs
        logs = self.pod_launcher.read_pod_logs(mock.sentinel)
        self.assertEqual(mock.sentinel.logs, logs)

    def test_read_pod_logs_retries_successfully(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            BaseHTTPError('Boom'), mock.sentinel.logs
        ]
        logs = self.pod_launcher.read_pod_logs(mock.sentinel)
        self.assertEqual(mock.sentinel.logs, logs)
        self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([
            mock.call(_preload_content=False,
                      container='base',
                      follow=True,
                      name=mock.sentinel.metadata.name,
                      namespace=mock.sentinel.metadata.namespace,
                      tail_lines=10),
            mock.call(_preload_content=False,
                      container='base',
                      follow=True,
                      name=mock.sentinel.metadata.name,
                      namespace=mock.sentinel.metadata.namespace,
                      tail_lines=10)
        ])

    def test_read_pod_logs_retries_fails(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom')
        ]
        self.assertRaises(AirflowException, self.pod_launcher.read_pod_logs,
                          mock.sentinel)

    def test_read_pod_events_successfully_returns_events(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.list_namespaced_event.return_value = mock.sentinel.events
        events = self.pod_launcher.read_pod_events(mock.sentinel)
        self.assertEqual(mock.sentinel.events, events)

    def test_read_pod_events_retries_successfully(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.list_namespaced_event.side_effect = [
            BaseHTTPError('Boom'), mock.sentinel.events
        ]
        events = self.pod_launcher.read_pod_events(mock.sentinel)
        self.assertEqual(mock.sentinel.events, events)
        self.mock_kube_client.list_namespaced_event.assert_has_calls([
            mock.call(namespace=mock.sentinel.metadata.namespace,
                      field_selector="involvedObject.name={}".format(
                          mock.sentinel.metadata.name)),
            mock.call(namespace=mock.sentinel.metadata.namespace,
                      field_selector="involvedObject.name={}".format(
                          mock.sentinel.metadata.name))
        ])

    def test_read_pod_events_retries_fails(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.list_namespaced_event.side_effect = [
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom')
        ]
        self.assertRaises(AirflowException, self.pod_launcher.read_pod_events,
                          mock.sentinel)

    def test_read_pod_returns_logs(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod.return_value = mock.sentinel.pod_info
        pod_info = self.pod_launcher.read_pod(mock.sentinel)
        self.assertEqual(mock.sentinel.pod_info, pod_info)

    def test_read_pod_retries_successfully(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod.side_effect = [
            BaseHTTPError('Boom'), mock.sentinel.pod_info
        ]
        pod_info = self.pod_launcher.read_pod(mock.sentinel)
        self.assertEqual(mock.sentinel.pod_info, pod_info)
        self.mock_kube_client.read_namespaced_pod.assert_has_calls([
            mock.call(mock.sentinel.metadata.name,
                      mock.sentinel.metadata.namespace),
            mock.call(mock.sentinel.metadata.name,
                      mock.sentinel.metadata.namespace)
        ])

    def test_read_pod_retries_fails(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod.side_effect = [
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom')
        ]
        self.assertRaises(AirflowException, self.pod_launcher.read_pod,
                          mock.sentinel)

    def test_pod_mutation_v1_pod(self):
        with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD,
                             "airflow_local_settings"):
            from airflow import settings
            settings.import_local_settings()  # pylint: ignore
            from airflow.kubernetes.pod_launcher import PodLauncher

            self.mock_kube_client = Mock()
            self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)
            pod = pod_generator.PodGenerator(image="myimage",
                                             cmds=["foo"],
                                             namespace="baz",
                                             volume_mounts=[{
                                                 "name": "foo",
                                                 "mountPath": "/mnt",
                                                 "subPath": "/",
                                                 "readOnly": True
                                             }],
                                             volumes=[{
                                                 "name": "foo"
                                             }]).gen_pod()

            sanitized_pod_pre_mutation = api_client.sanitize_for_serialization(
                pod)

            self.assertEqual(
                sanitized_pod_pre_mutation, {
                    'apiVersion': 'v1',
                    'kind': 'Pod',
                    'metadata': {
                        'namespace': 'baz'
                    },
                    'spec': {
                        'containers': [{
                            'args': [],
                            'command': ['foo'],
                            'env': [],
                            'envFrom': [],
                            'image':
                            'myimage',
                            'name':
                            'base',
                            'ports': [],
                            'volumeMounts': [{
                                'mountPath': '/mnt',
                                'name': 'foo',
                                'readOnly': True,
                                'subPath': '/'
                            }]
                        }],
                        'hostNetwork':
                        False,
                        'imagePullSecrets': [],
                        'volumes': [{
                            'name': 'foo'
                        }]
                    }
                })

            # Apply Pod Mutation Hook
            pod = self.pod_launcher._mutate_pod_backcompat(pod)

            sanitized_pod_post_mutation = api_client.sanitize_for_serialization(
                pod)
            self.assertEqual(
                sanitized_pod_post_mutation, {
                    'apiVersion': 'v1',
                    'kind': 'Pod',
                    'metadata': {
                        'namespace': 'airflow-tests'
                    },
                    'spec': {
                        'containers': [{
                            'args': [],
                            'command': ['foo'],
                            'env': [{
                                'name': 'TEST_USER',
                                'value': 'ADMIN'
                            }],
                            'envFrom': [],
                            'image':
                            'test-image',
                            'name':
                            'base',
                            'ports': [{
                                'containerPort': 8080
                            }, {
                                'containerPort': 8081
                            }],
                            'volumeMounts': [{
                                'mountPath': '/mnt',
                                'name': 'foo',
                                'readOnly': True,
                                'subPath': '/'
                            }, {
                                'mountPath': '/opt/airflow/secrets/',
                                'name': 'airflow-secrets-mount',
                                'readOnly': True
                            }]
                        }],
                        'hostNetwork':
                        False,
                        'imagePullSecrets': [],
                        'volumes': [{
                            'name': 'foo'
                        }, {
                            'name': 'airflow-secrets-mount',
                            'secret': {
                                'secretName': 'airflow-test-secrets'
                            }
                        }]
                    }
                })

    def test_pod_mutation_to_k8s_pod(self):
        with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK,
                             "airflow_local_settings"):
            from airflow import settings
            settings.import_local_settings()  # pylint: ignore
            from airflow.kubernetes.pod_launcher import PodLauncher

            self.mock_kube_client = Mock()
            self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)
            init_container = k8s.V1Container(name="init-container",
                                             volume_mounts=[
                                                 k8s.V1VolumeMount(
                                                     mount_path="/tmp",
                                                     name="init-secret")
                                             ])
            pod = pod_generator.PodGenerator(
                image="foo",
                name="bar",
                namespace="baz",
                image_pull_policy="Never",
                init_containers=[init_container],
                cmds=["foo"],
                args=["/bin/sh", "-c", "touch /tmp/healthy"],
                tolerations=[{
                    'effect': 'NoSchedule',
                    'key': 'static-pods',
                    'operator': 'Equal',
                    'value': 'true'
                }],
                volume_mounts=[{
                    "name": "foo",
                    "mountPath": "/mnt",
                    "subPath": "/",
                    "readOnly": True
                }],
                security_context=k8s.V1PodSecurityContext(fs_group=0,
                                                          run_as_user=1),
                volumes=[k8s.V1Volume(name="foo")]).gen_pod()

            sanitized_pod_pre_mutation = api_client.sanitize_for_serialization(
                pod)
            self.assertEqual(
                sanitized_pod_pre_mutation,
                {
                    'apiVersion': 'v1',
                    'kind': 'Pod',
                    'metadata': {
                        'name': mock.ANY,
                        'namespace': 'baz'
                    },
                    'spec': {
                        'containers': [{
                            'args': ['/bin/sh', '-c', 'touch /tmp/healthy'],
                            'command': ['foo'],
                            'env': [],
                            'envFrom': [],
                            'image':
                            'foo',
                            'imagePullPolicy':
                            'Never',
                            'name':
                            'base',
                            'ports': [],
                            'volumeMounts': [{
                                'mountPath': '/mnt',
                                'name': 'foo',
                                'readOnly': True,
                                'subPath': '/'
                            }]
                        }],
                        'initContainers': [{
                            'name':
                            'init-container',
                            'volumeMounts': [{
                                'mountPath': '/tmp',
                                'name': 'init-secret'
                            }]
                        }],
                        'hostNetwork':
                        False,
                        'imagePullSecrets': [],
                        'tolerations': [{
                            'effect': 'NoSchedule',
                            'key': 'static-pods',
                            'operator': 'Equal',
                            'value': 'true'
                        }],
                        'volumes': [{
                            'name': 'foo'
                        }],
                        'securityContext': {
                            'fsGroup': 0,
                            'runAsUser': 1
                        }
                    }
                },
            )

            # Apply Pod Mutation Hook
            pod = self.pod_launcher._mutate_pod_backcompat(pod)

            sanitized_pod_post_mutation = api_client.sanitize_for_serialization(
                pod)

            self.assertEqual(
                sanitized_pod_post_mutation, {
                    "apiVersion": "v1",
                    "kind": "Pod",
                    'metadata': {
                        'labels': {
                            'test_label': 'test_value'
                        },
                        'name': mock.ANY,
                        'namespace': 'airflow-tests'
                    },
                    'spec': {
                        'affinity': {
                            'nodeAffinity': {
                                'requiredDuringSchedulingIgnoredDuringExecution':
                                {
                                    'nodeSelectorTerms': [{
                                        'matchExpressions': [{
                                            'key': 'test/dynamic-pods',
                                            'operator': 'In',
                                            'values': ['true']
                                        }]
                                    }]
                                }
                            }
                        },
                        'containers': [{
                            'args': ['/bin/sh', '-c', 'touch /tmp/healthy2'],
                            'command': ['foo'],
                            'env': [{
                                'name': 'TEST_USER',
                                'value': 'ADMIN'
                            }],
                            'image':
                            'my_image',
                            'imagePullPolicy':
                            'Never',
                            'name':
                            'base',
                            'ports': [{
                                'containerPort': 8080
                            }, {
                                'containerPort': 8081
                            }],
                            'resources': {
                                'limits': {
                                    'nvidia.com/gpu': '200G'
                                },
                                'requests': {
                                    'cpu': '200Mi',
                                    'memory': '2G'
                                }
                            },
                            'volumeMounts': [{
                                'mountPath': '/mnt',
                                'name': 'foo',
                                'readOnly': True,
                                'subPath': '/'
                            }, {
                                'mountPath': '/opt/airflow/secrets/',
                                'name': 'airflow-secrets-mount',
                                'readOnly': True
                            }]
                        }],
                        'hostNetwork':
                        False,
                        'imagePullSecrets': [],
                        'initContainers': [{
                            'name':
                            'init-container',
                            'securityContext': {
                                'runAsGroup': 50000,
                                'runAsUser': 50000
                            },
                            'volumeMounts': [{
                                'mountPath': '/tmp',
                                'name': 'init-secret'
                            }]
                        }],
                        'tolerations': [{
                            'effect': 'NoSchedule',
                            'key': 'static-pods',
                            'operator': 'Equal',
                            'value': 'true'
                        }, {
                            'effect': 'NoSchedule',
                            'key': 'dynamic-pods',
                            'operator': 'Equal',
                            'value': 'true'
                        }],
                        'volumes': [
                            {
                                'name': 'airflow-secrets-mount',
                                'secret': {
                                    'secretName': 'airflow-test-secrets'
                                }
                            },
                            {
                                'name': 'bar'
                            },
                            {
                                'name': 'foo'
                            },
                        ],
                        'securityContext': {
                            'runAsUser': 1
                        }
                    }
                })

class LocalSettingsTest(unittest.TestCase):
    # Make sure that configure_logging is not cached
    def setUp(self):
        self.old_modules = dict(sys.modules)
        self.maxDiff = None

    def tearDown(self):
        # Remove any new modules imported during the test run. This lets us
        # import the same source files for more than one test.
        for mod in [m for m in sys.modules if m not in self.old_modules]:
            del sys.modules[mod]

    @patch("airflow.settings.import_local_settings")
    @patch("airflow.settings.prepare_syspath")
    def test_initialize_order(self, prepare_syspath, import_local_settings):
        """
        Tests that import_local_settings is called after prepare_syspath
        """
        mock = Mock()
        mock.attach_mock(prepare_syspath, "prepare_syspath")
        mock.attach_mock(import_local_settings, "import_local_settings")

        import airflow.settings
        airflow.settings.initialize()

        mock.assert_has_calls(
            [call.prepare_syspath(),
             call.import_local_settings()])

    def test_import_with_dunder_all_not_specified(self):
        """
        Tests that if __all__ is specified in airflow_local_settings,
        only module attributes specified within are imported.
        """
        with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL,
                             "airflow_local_settings"):
            from airflow import settings
            settings.import_local_settings()  # pylint: ignore

            with self.assertRaises(AttributeError):
                settings.not_policy()

    def test_import_with_dunder_all(self):
        """
        Tests that if __all__ is specified in airflow_local_settings,
        only module attributes specified within are imported.
        """
        with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL,
                             "airflow_local_settings"):
            from airflow import settings
            settings.import_local_settings()  # pylint: ignore

            task_instance = MagicMock()
            settings.test_policy(task_instance)

            assert task_instance.run_as_user == "myself"

    @patch("airflow.settings.log.debug")
    def test_import_local_settings_without_syspath(self, log_mock):
        """
        Tests that an ImportError is raised in import_local_settings
        if there is no airflow_local_settings module on the syspath.
        """
        from airflow import settings
        settings.import_local_settings()
        log_mock.assert_called_with("Failed to import airflow_local_settings.",
                                    exc_info=True)

    def test_policy_function(self):
        """
        Tests that task instances are mutated by the policy
        function in airflow_local_settings.
        """
        with SettingsContext(SETTINGS_FILE_POLICY, "airflow_local_settings"):
            from airflow import settings
            settings.import_local_settings()  # pylint: ignore

            task_instance = MagicMock()
            settings.test_policy(task_instance)

            assert task_instance.run_as_user == "myself"

    def test_pod_mutation_hook(self):
        """
        Tests that pods are mutated by the pod_mutation_hook
        function in airflow_local_settings.
        """
        with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK,
                             "airflow_local_settings"):
            from airflow import settings
            settings.import_local_settings()  # pylint: ignore

            pod = MagicMock()
            pod.volumes = []
            settings.pod_mutation_hook(pod)

            assert pod.namespace == 'airflow-tests'
            self.assertEqual(pod.volumes[0].name, "bar")

    def test_pod_mutation_to_k8s_pod(self):
        with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK,
                             "airflow_local_settings"):
            from airflow import settings
            settings.import_local_settings()  # pylint: ignore
            from airflow.kubernetes.pod_launcher import PodLauncher

            self.mock_kube_client = Mock()
            self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)
            init_container = k8s.V1Container(name="init-container",
                                             volume_mounts=[
                                                 k8s.V1VolumeMount(
                                                     mount_path="/tmp",
                                                     name="init-secret")
                                             ])
            pod = pod_generator.PodGenerator(
                image="foo",
                name="bar",
                namespace="baz",
                image_pull_policy="Never",
                init_containers=[init_container],
                cmds=["foo"],
                args=["/bin/sh", "-c", "touch /tmp/healthy"],
                tolerations=[{
                    'effect': 'NoSchedule',
                    'key': 'static-pods',
                    'operator': 'Equal',
                    'value': 'true'
                }],
                volume_mounts=[{
                    "name": "foo",
                    "mountPath": "/mnt",
                    "subPath": "/",
                    "readOnly": True
                }],
                security_context=k8s.V1PodSecurityContext(fs_group=0,
                                                          run_as_user=1),
                volumes=[k8s.V1Volume(name="foo")]).gen_pod()

            sanitized_pod_pre_mutation = api_client.sanitize_for_serialization(
                pod)
            self.assertEqual(
                sanitized_pod_pre_mutation,
                {
                    'apiVersion': 'v1',
                    'kind': 'Pod',
                    'metadata': {
                        'name': mock.ANY,
                        'namespace': 'baz'
                    },
                    'spec': {
                        'containers': [{
                            'args': ['/bin/sh', '-c', 'touch /tmp/healthy'],
                            'command': ['foo'],
                            'env': [],
                            'envFrom': [],
                            'image':
                            'foo',
                            'imagePullPolicy':
                            'Never',
                            'name':
                            'base',
                            'ports': [],
                            'volumeMounts': [{
                                'mountPath': '/mnt',
                                'name': 'foo',
                                'readOnly': True,
                                'subPath': '/'
                            }]
                        }],
                        'initContainers': [{
                            'name':
                            'init-container',
                            'volumeMounts': [{
                                'mountPath': '/tmp',
                                'name': 'init-secret'
                            }]
                        }],
                        'hostNetwork': False,
                        'imagePullSecrets': [],
                        'tolerations': [{
                            'effect': 'NoSchedule',
                            'key': 'static-pods',
                            'operator': 'Equal',
                            'value': 'true'
                        }],
                        'volumes': [{
                            'name': 'foo'
                        }],
                        'securityContext': {
                            'fsGroup': 0,
                            'runAsUser': 1
                        }
                    }
                },
            )

            # Apply Pod Mutation Hook
            pod = self.pod_launcher._mutate_pod_backcompat(pod)

            sanitized_pod_post_mutation = api_client.sanitize_for_serialization(
                pod)

            self.assertEqual(
                sanitized_pod_post_mutation, {
                    "apiVersion": "v1",
                    "kind": "Pod",
                    'metadata': {
                        'labels': {
                            'test_label': 'test_value'
                        },
                        'name': mock.ANY,
                        'namespace': 'airflow-tests'
                    },
                    'spec': {
                        'affinity': {
                            'nodeAffinity': {
                                'requiredDuringSchedulingIgnoredDuringExecution':
                                {
                                    'nodeSelectorTerms': [{
                                        'matchExpressions': [{
                                            'key': 'test/dynamic-pods',
                                            'operator': 'In',
                                            'values': ['true']
                                        }]
                                    }]
                                }
                            }
                        },
                        'containers': [{
                            'args': ['/bin/sh', '-c', 'touch /tmp/healthy2'],
                            'command': ['foo'],
                            'env': [{
                                'name': 'TEST_USER',
                                'value': 'ADMIN'
                            }],
                            'image': 'my_image',
                            'imagePullPolicy': 'Never',
                            'name': 'base',
                            'ports': [{
                                'containerPort': 8080
                            }, {
                                'containerPort': 8081
                            }],
                            'resources': {
                                'limits': {
                                    'nvidia.com/gpu': '200G'
                                },
                                'requests': {
                                    'cpu': '200Mi',
                                    'memory': '2G'
                                }
                            },
                            'volumeMounts': [{
                                'mountPath': '/mnt',
                                'name': 'foo',
                                'readOnly': True,
                                'subPath': '/'
                            }, {
                                'mountPath': '/opt/airflow/secrets/',
                                'name': 'airflow-secrets-mount',
                                'readOnly': True
                            }]
                        }],
                        'hostNetwork': False,
                        'imagePullSecrets': [],
                        'initContainers': [{
                            'name': 'init-container',
                            'securityContext': {
                                'runAsGroup': 50000,
                                'runAsUser': 50000
                            },
                            'volumeMounts': [{
                                'mountPath': '/tmp',
                                'name': 'init-secret'
                            }]
                        }],
                        'tolerations': [{
                            'effect': 'NoSchedule',
                            'key': 'static-pods',
                            'operator': 'Equal',
                            'value': 'true'
                        }, {
                            'effect': 'NoSchedule',
                            'key': 'dynamic-pods',
                            'operator': 'Equal',
                            'value': 'true'
                        }],
                        'volumes': [
                            {
                                'name': 'airflow-secrets-mount',
                                'secret': {
                                    'secretName': 'airflow-test-secrets'
                                }
                            },
                            {
                                'name': 'bar'
                            },
                            {
                                'name': 'foo'
                            },
                        ],
                        'securityContext': {
                            'runAsUser': 1
                        }
                    }
                })
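
The fixture module referenced by SETTINGS_FILE_POD_MUTATION_HOOK is not reproduced in this listing, so the hook driving the two tests above has to be inferred from the asserted diff. Below is a minimal sketch covering only the first few mutations checked (namespace, the test_label, the extra "bar" volume); the function body and the legacy Volume import are assumptions reconstructed from the assertions, not the actual fixture, which evidently also rewrites the image, env, ports, resources, tolerations, affinity, and security contexts.

# Hypothetical airflow_local_settings.py content for these tests -- a sketch
# reconstructed from the assertions above, not the real fixture.
from airflow.kubernetes.volume import Volume  # legacy Volume class (assumed import)

def pod_mutation_hook(pod):
    # The backcompat path hands the hook the legacy Pod object, so plain
    # attribute mutation is enough here.
    pod.namespace = 'airflow-tests'
    pod.labels.update({'test_label': 'test_value'})
    pod.volumes.append(Volume(name='bar', configs={}))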

    def test_pod_mutation_v1_pod(self):
        with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD,
                             "airflow_local_settings"):
            from airflow import settings
            settings.import_local_settings()
            from airflow.kubernetes.pod_launcher import PodLauncher

            self.mock_kube_client = Mock()
            self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)
            pod = pod_generator.PodGenerator(image="myimage",
                                             cmds=["foo"],
                                             namespace="baz",
                                             volume_mounts=[{
                                                 "name": "foo",
                                                 "mountPath": "/mnt",
                                                 "subPath": "/",
                                                 "readOnly": True
                                             }],
                                             volumes=[{
                                                 "name": "foo"
                                             }]).gen_pod()

            sanitized_pod_pre_mutation = api_client.sanitize_for_serialization(
                pod)

            self.assertEqual(
                sanitized_pod_pre_mutation, {
                    'apiVersion': 'v1',
                    'kind': 'Pod',
                    'metadata': {
                        'namespace': 'baz'
                    },
                    'spec': {
                        'containers': [{
                            'args': [],
                            'command': ['foo'],
                            'env': [],
                            'envFrom': [],
                            'image': 'myimage',
                            'name': 'base',
                            'ports': [],
                            'volumeMounts': [{
                                'mountPath': '/mnt',
                                'name': 'foo',
                                'readOnly': True,
                                'subPath': '/'
                            }]
                        }],
                        'hostNetwork': False,
                        'imagePullSecrets': [],
                        'volumes': [{
                            'name': 'foo'
                        }]
                    }
                })

            # Apply Pod Mutation Hook
            pod = self.pod_launcher._mutate_pod_backcompat(pod)

            sanitized_pod_post_mutation = api_client.sanitize_for_serialization(
                pod)
            self.assertEqual(
                sanitized_pod_post_mutation, {
                    'apiVersion': 'v1',
                    'kind': 'Pod',
                    'metadata': {
                        'namespace': 'airflow-tests'
                    },
                    'spec': {
                        'containers': [{
                            'args': [],
                            'command': ['foo'],
                            'env': [{
                                'name': 'TEST_USER',
                                'value': 'ADMIN'
                            }],
                            'envFrom': [],
                            'image': 'test-image',
                            'name': 'base',
                            'ports': [{
                                'containerPort': 8080
                            }, {
                                'containerPort': 8081
                            }],
                            'volumeMounts': [{
                                'mountPath': '/mnt',
                                'name': 'foo',
                                'readOnly': True,
                                'subPath': '/'
                            }, {
                                'mountPath': '/opt/airflow/secrets/',
                                'name': 'airflow-secrets-mount',
                                'readOnly': True
                            }]
                        }],
                        'hostNetwork': False,
                        'imagePullSecrets': [],
                        'volumes': [{
                            'name': 'foo'
                        }, {
                            'name': 'airflow-secrets-mount',
                            'secret': {
                                'secretName': 'airflow-test-secrets'
                            }
                        }]
                    }
                })
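
For the V1Pod variant, the post-mutation assertions pin down the whole diff: the namespace, the image, one extra env var, two container ports, and a secret volume plus its mount. A hook consistent with that diff, written directly against the kubernetes client models, might look like the sketch below; it is reconstructed from the assertions, not taken from SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD itself.

# Hypothetical airflow_local_settings.py for the V1Pod test -- reconstructed
# from the post-mutation assertions, not the real fixture.
from kubernetes.client import models as k8s

def pod_mutation_hook(pod: k8s.V1Pod):
    pod.metadata.namespace = 'airflow-tests'
    container = pod.spec.containers[0]
    container.image = 'test-image'
    container.env.append(k8s.V1EnvVar(name='TEST_USER', value='ADMIN'))
    container.ports = [k8s.V1ContainerPort(container_port=8080),
                       k8s.V1ContainerPort(container_port=8081)]
    container.volume_mounts.append(
        k8s.V1VolumeMount(mount_path='/opt/airflow/secrets/',
                          name='airflow-secrets-mount',
                          read_only=True))
    pod.spec.volumes.append(
        k8s.V1Volume(name='airflow-secrets-mount',
                     secret=k8s.V1SecretVolumeSource(
                         secret_name='airflow-test-secrets')))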
Example #11
class TestPodLauncher(unittest.TestCase):
    def setUp(self):
        self.mock_kube_client = mock.Mock()
        self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)

    def test_read_pod_logs_successfully_returns_logs(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod_log.return_value = mock.sentinel.logs
        logs = self.pod_launcher.read_pod_logs(mock.sentinel)
        assert mock.sentinel.logs == logs

    def test_read_pod_logs_retries_successfully(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            BaseHTTPError('Boom'),
            mock.sentinel.logs,
        ]
        logs = self.pod_launcher.read_pod_logs(mock.sentinel)
        assert mock.sentinel.logs == logs
        self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([
            mock.call(
                _preload_content=False,
                container='base',
                follow=True,
                timestamps=False,
                name=mock.sentinel.metadata.name,
                namespace=mock.sentinel.metadata.namespace,
            ),
            mock.call(
                _preload_content=False,
                container='base',
                follow=True,
                timestamps=False,
                name=mock.sentinel.metadata.name,
                namespace=mock.sentinel.metadata.namespace,
            ),
        ])

    def test_read_pod_logs_retries_fails(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
        ]
        with pytest.raises(AirflowException):
            self.pod_launcher.read_pod_logs(mock.sentinel)

    def test_read_pod_logs_successfully_with_tail_lines(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            mock.sentinel.logs
        ]
        logs = self.pod_launcher.read_pod_logs(mock.sentinel, tail_lines=100)
        assert mock.sentinel.logs == logs
        self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([
            mock.call(
                _preload_content=False,
                container='base',
                follow=True,
                timestamps=False,
                name=mock.sentinel.metadata.name,
                namespace=mock.sentinel.metadata.namespace,
                tail_lines=100,
            ),
        ])

    def test_read_pod_logs_successfully_with_since_seconds(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            mock.sentinel.logs
        ]
        logs = self.pod_launcher.read_pod_logs(mock.sentinel, since_seconds=2)
        assert mock.sentinel.logs == logs
        self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([
            mock.call(
                _preload_content=False,
                container='base',
                follow=True,
                timestamps=False,
                name=mock.sentinel.metadata.name,
                namespace=mock.sentinel.metadata.namespace,
                since_seconds=2,
            ),
        ])

    def test_read_pod_events_successfully_returns_events(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.list_namespaced_event.return_value = mock.sentinel.events
        events = self.pod_launcher.read_pod_events(mock.sentinel)
        assert mock.sentinel.events == events

    def test_read_pod_events_retries_successfully(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.list_namespaced_event.side_effect = [
            BaseHTTPError('Boom'),
            mock.sentinel.events,
        ]
        events = self.pod_launcher.read_pod_events(mock.sentinel)
        assert mock.sentinel.events == events
        self.mock_kube_client.list_namespaced_event.assert_has_calls([
            mock.call(
                namespace=mock.sentinel.metadata.namespace,
                field_selector=f"involvedObject.name={mock.sentinel.metadata.name}",
            ),
            mock.call(
                namespace=mock.sentinel.metadata.namespace,
                field_selector=f"involvedObject.name={mock.sentinel.metadata.name}",
            ),
        ])

    def test_read_pod_events_retries_fails(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.list_namespaced_event.side_effect = [
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
        ]
        with pytest.raises(AirflowException):
            self.pod_launcher.read_pod_events(mock.sentinel)

    def test_read_pod_returns_logs(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod.return_value = mock.sentinel.pod_info
        pod_info = self.pod_launcher.read_pod(mock.sentinel)
        assert mock.sentinel.pod_info == pod_info

    def test_read_pod_retries_successfully(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod.side_effect = [
            BaseHTTPError('Boom'),
            mock.sentinel.pod_info,
        ]
        pod_info = self.pod_launcher.read_pod(mock.sentinel)
        assert mock.sentinel.pod_info == pod_info
        self.mock_kube_client.read_namespaced_pod.assert_has_calls([
            mock.call(mock.sentinel.metadata.name,
                      mock.sentinel.metadata.namespace),
            mock.call(mock.sentinel.metadata.name,
                      mock.sentinel.metadata.namespace),
        ])

    def test_read_pod_retries_fails(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod.side_effect = [
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
        ]
        with pytest.raises(AirflowException):
            self.pod_launcher.read_pod(mock.sentinel)

    def test_parse_log_line(self):
        timestamp, message = self.pod_launcher.parse_log_line(
            '2020-10-08T14:16:17.793417674Z Valid message\n')

        assert timestamp == '2020-10-08T14:16:17.793417674Z'
        assert message == 'Valid message'

        with pytest.raises(Exception):
            self.pod_launcher.parse_log_line(
                '2020-10-08T14:16:17.793417674ZInvalidmessage\n')
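
Taken together, these assertions fix the parse_log_line contract: the timestamp is everything before the first space, the remainder minus the trailing newline is the message, and a line whose prefix is not a parsable timestamp raises. A minimal implementation consistent with that contract is sketched below; validating the prefix with pendulum is an assumption about how the real PodLauncher rejects malformed lines.

# A sketch of parse_log_line consistent with the assertions above; the
# pendulum-based validation is an assumption, not the actual implementation.
import pendulum

def parse_log_line(line: str):
    timestamp, sep, message = line.partition(' ')
    if not sep:
        raise Exception(f'Log not in "timestamp message" format, got: {line!r}')
    pendulum.parse(timestamp)  # raises if the prefix is not a valid timestamp
    return timestamp, message.rstrip()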
Example #12
class TestPodLauncher(unittest.TestCase):
    def setUp(self):
        self.mock_kube_client = mock.Mock()
        self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)

    def test_read_pod_logs_successfully_returns_logs(self):
        self.mock_kube_client.read_namespaced_pod_log.return_value = mock.sentinel.logs
        logs = self.pod_launcher.read_pod_logs(mock.sentinel)
        self.assertEqual(mock.sentinel.logs, logs)

    def test_read_pod_logs_retries_successfully(self):
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            BaseHTTPError('Boom'), mock.sentinel.logs
        ]
        logs = self.pod_launcher.read_pod_logs(mock.sentinel)
        self.assertEqual(mock.sentinel.logs, logs)
        self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([
            mock.call(_preload_content=False,
                      _request_timeout=600,
                      container='base',
                      follow=True,
                      name=mock.sentinel.name,
                      namespace=mock.sentinel.namespace,
                      tail_lines=10),
            mock.call(_preload_content=False,
                      _request_timeout=600,
                      container='base',
                      follow=True,
                      name=mock.sentinel.name,
                      namespace=mock.sentinel.namespace,
                      tail_lines=10)
        ])

    def test_read_pod_logs_retries_fails(self):
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom')
        ]
        self.assertRaises(AirflowException, self.pod_launcher.read_pod_logs,
                          mock.sentinel)

    def test_read_pod_returns_logs(self):
        self.mock_kube_client.read_namespaced_pod.return_value = mock.sentinel.pod_info
        pod_info = self.pod_launcher.read_pod(mock.sentinel)
        self.assertEqual(mock.sentinel.pod_info, pod_info)

    def test_read_pod_retries_successfully(self):
        self.mock_kube_client.read_namespaced_pod.side_effect = [
            BaseHTTPError('Boom'), mock.sentinel.pod_info
        ]
        pod_info = self.pod_launcher.read_pod(mock.sentinel)
        self.assertEqual(mock.sentinel.pod_info, pod_info)
        self.mock_kube_client.read_namespaced_pod.assert_has_calls([
            mock.call(mock.sentinel.name, mock.sentinel.namespace),
            mock.call(mock.sentinel.name, mock.sentinel.namespace)
        ])

    def test_read_pod_retries_fails(self):
        self.mock_kube_client.read_namespaced_pod.side_effect = [
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom')
        ]
        self.assertRaises(AirflowException, self.pod_launcher.read_pod,
                          mock.sentinel)
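
The retry tests above encode the retry budget: one queued BaseHTTPError followed by a success must be absorbed, while three queued errors must surface as an AirflowException. That matches a guard that stops after three attempts, sketched below with tenacity; the decorator parameters and the wrapper class are assumptions inferred from the tests, not the actual PodLauncher internals.

# A sketch of the retry behaviour these tests exercise; stop_after_attempt(3)
# is inferred from the three queued BaseHTTPErrors, the rest is assumed.
import tenacity
from requests.exceptions import BaseHTTPError
from airflow.exceptions import AirflowException

class RetryingPodReader:
    def __init__(self, kube_client):
        self._client = kube_client

    @tenacity.retry(stop=tenacity.stop_after_attempt(3),
                    wait=tenacity.wait_exponential(),
                    reraise=True)
    def _read_pod_once(self, pod):
        # Each BaseHTTPError consumes one attempt; success returns immediately.
        return self._client.read_namespaced_pod(pod.name, pod.namespace)

    def read_pod(self, pod):
        try:
            return self._read_pod_once(pod)
        except BaseHTTPError as e:
            raise AirflowException(f'There was an error reading the kubernetes API: {e}')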
Example #13
class TestPodLauncher(unittest.TestCase):
    def setUp(self):
        self.mock_kube_client = mock.Mock()
        self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)

    def test_read_pod_logs_successfully_returns_logs(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod_log.return_value = mock.sentinel.logs
        logs = self.pod_launcher.read_pod_logs(mock.sentinel)
        self.assertEqual(mock.sentinel.logs, logs)

    def test_read_pod_logs_retries_successfully(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            BaseHTTPError('Boom'), mock.sentinel.logs
        ]
        logs = self.pod_launcher.read_pod_logs(mock.sentinel)
        self.assertEqual(mock.sentinel.logs, logs)
        self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([
            mock.call(_preload_content=False,
                      container='base',
                      follow=True,
                      timestamps=False,
                      name=mock.sentinel.metadata.name,
                      namespace=mock.sentinel.metadata.namespace),
            mock.call(_preload_content=False,
                      container='base',
                      follow=True,
                      timestamps=False,
                      name=mock.sentinel.metadata.name,
                      namespace=mock.sentinel.metadata.namespace)
        ])

    def test_read_pod_logs_retries_fails(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom')
        ]
        self.assertRaises(AirflowException, self.pod_launcher.read_pod_logs,
                          mock.sentinel)

    def test_read_pod_logs_successfully_with_tail_lines(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            mock.sentinel.logs
        ]
        logs = self.pod_launcher.read_pod_logs(mock.sentinel, tail_lines=100)
        self.assertEqual(mock.sentinel.logs, logs)
        self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([
            mock.call(_preload_content=False,
                      container='base',
                      follow=True,
                      timestamps=False,
                      name=mock.sentinel.metadata.name,
                      namespace=mock.sentinel.metadata.namespace,
                      tail_lines=100),
        ])

    def test_read_pod_logs_successfully_with_since_seconds(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            mock.sentinel.logs
        ]
        logs = self.pod_launcher.read_pod_logs(mock.sentinel, since_seconds=2)
        self.assertEqual(mock.sentinel.logs, logs)
        self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([
            mock.call(_preload_content=False,
                      container='base',
                      follow=True,
                      timestamps=False,
                      name=mock.sentinel.metadata.name,
                      namespace=mock.sentinel.metadata.namespace,
                      since_seconds=2),
        ])

    def test_read_pod_events_successfully_returns_events(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.list_namespaced_event.return_value = mock.sentinel.events
        events = self.pod_launcher.read_pod_events(mock.sentinel)
        self.assertEqual(mock.sentinel.events, events)

    def test_read_pod_events_retries_successfully(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.list_namespaced_event.side_effect = [
            BaseHTTPError('Boom'), mock.sentinel.events
        ]
        events = self.pod_launcher.read_pod_events(mock.sentinel)
        self.assertEqual(mock.sentinel.events, events)
        self.mock_kube_client.list_namespaced_event.assert_has_calls([
            mock.call(namespace=mock.sentinel.metadata.namespace,
                      field_selector="involvedObject.name={}".format(
                          mock.sentinel.metadata.name)),
            mock.call(namespace=mock.sentinel.metadata.namespace,
                      field_selector="involvedObject.name={}".format(
                          mock.sentinel.metadata.name))
        ])

    def test_read_pod_events_retries_fails(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.list_namespaced_event.side_effect = [
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom')
        ]
        self.assertRaises(AirflowException, self.pod_launcher.read_pod_events,
                          mock.sentinel)

    def test_read_pod_returns_logs(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod.return_value = mock.sentinel.pod_info
        pod_info = self.pod_launcher.read_pod(mock.sentinel)
        self.assertEqual(mock.sentinel.pod_info, pod_info)

    def test_read_pod_retries_successfully(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod.side_effect = [
            BaseHTTPError('Boom'), mock.sentinel.pod_info
        ]
        pod_info = self.pod_launcher.read_pod(mock.sentinel)
        self.assertEqual(mock.sentinel.pod_info, pod_info)
        self.mock_kube_client.read_namespaced_pod.assert_has_calls([
            mock.call(mock.sentinel.metadata.name,
                      mock.sentinel.metadata.namespace),
            mock.call(mock.sentinel.metadata.name,
                      mock.sentinel.metadata.namespace)
        ])

    def test_read_pod_retries_fails(self):
        mock.sentinel.metadata = mock.MagicMock()
        self.mock_kube_client.read_namespaced_pod.side_effect = [
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom')
        ]
        self.assertRaises(AirflowException, self.pod_launcher.read_pod,
                          mock.sentinel)

    def test_parse_log_line(self):
        timestamp, message = self.pod_launcher.parse_log_line(
            '2020-10-08T14:16:17.793417674Z Valid message\n')

        self.assertEqual(timestamp, '2020-10-08T14:16:17.793417674Z')
        self.assertEqual(message, 'Valid message')

        self.assertRaises(
            Exception,
            self.pod_launcher.parse_log_line,
            '2020-10-08T14:16:17.793417674ZInvalid message\n',
        )