def __init__(self, kube_config, task_queue, result_queue, kube_client,
             worker_uuid):
    """Wire up the scheduler's queues, pod launcher and watcher process.

    :param kube_config: executor configuration; ``kube_namespace`` is read here
    :param kube_config: executor configuration object
    :param task_queue: queue of tasks waiting to be run
    :param result_queue: queue that finished-pod results are pushed onto
    :param kube_client: Kubernetes API client
    :param worker_uuid: unique id labelling pods owned by this scheduler
    """
    self.log.debug("Creating Kubernetes executor")
    self.kube_config = kube_config
    self.task_queue = task_queue
    self.result_queue = result_queue
    self.namespace = self.kube_config.kube_namespace
    self.log.debug("Kubernetes using namespace %s", self.namespace)
    self.kube_client = kube_client
    self.launcher = PodLauncher(kube_client=self.kube_client)
    self.worker_configuration = WorkerConfiguration(kube_config=self.kube_config)
    # Cross-process queue: the watcher runs in a separate process and
    # reports pod events back through it.
    self.watcher_queue = multiprocessing.Queue()
    self.worker_uuid = worker_uuid
    # Started last so every attribute the watcher needs already exists.
    self.kube_watcher = self._make_kube_watcher()
def get_image_dag_info(self):
    """Run a one-off sync pod to completion and collect its logs.

    :return: tuple of ``(status, log bytes, pod)`` where ``status`` is the
        result of ``launcher.run_pod`` and the log payload is the raw
        response body from the Kubernetes API.
    """
    # Fall back to a fresh client when none was injected.
    client = self.kube_client or get_kube_client()
    launcher = PodLauncher(kube_client=client)
    pod = self.create_sync_pod()
    # get_logs=False: logs are fetched explicitly below instead.
    status, result = launcher.run_pod(pod, get_logs=False)
    # _preload_content=False keeps the raw urllib3 response so `.data`
    # below yields undecoded bytes.
    logs = client.read_namespaced_pod_log(
        name=pod.name, namespace=pod.namespace, container='base',
        follow=True, _preload_content=False)
    # NOTE(review): the pod is deleted before `logs.data` is consumed in the
    # return statement — confirm the streamed log response survives deletion.
    launcher.delete_pod(pod)
    return status, logs.data, pod
def __init__(self, kube_config, task_queue, result_queue, kube_client,
             worker_uuid):
    """Initialise the Kubernetes scheduler: store config/queues, build the
    pod launcher and worker configuration, then spawn the watcher process.

    :param kube_config: executor configuration; ``kube_namespace`` is read here
    :param task_queue: queue of tasks waiting to be run
    :param result_queue: queue that finished-pod results are pushed onto
    :param kube_client: Kubernetes API client
    :param worker_uuid: unique id labelling pods owned by this scheduler
    """
    self.log.debug("Creating Kubernetes executor")
    self.kube_config = kube_config
    self.task_queue = task_queue
    self.result_queue = result_queue
    self.namespace = self.kube_config.kube_namespace
    self.log.debug("Kubernetes using namespace %s", self.namespace)
    self.kube_client = kube_client
    self.launcher = PodLauncher(kube_client=self.kube_client)
    self.worker_configuration = WorkerConfiguration(kube_config=self.kube_config)
    # Cross-process queue shared with the watcher subprocess.
    self.watcher_queue = multiprocessing.Queue()
    self.worker_uuid = worker_uuid
    # The watcher is created last so all attributes it needs are in place.
    self.kube_watcher = self._make_kube_watcher()
class AirflowKubernetesScheduler(LoggingMixin):
    """Schedules Airflow tasks as Kubernetes worker pods.

    Pods are launched asynchronously through ``PodLauncher``; a separate
    ``KubernetesJobWatcher`` process pushes pod state changes onto
    ``watcher_queue``, which ``sync()`` drains into ``result_queue`` for
    the executor to consume.
    """

    def __init__(self, kube_config, task_queue, result_queue, kube_client,
                 worker_uuid):
        """
        :param kube_config: executor configuration; ``kube_namespace`` is read
        :param task_queue: queue of tasks waiting to be run
        :param result_queue: queue that finished-pod results are pushed onto
        :param kube_client: Kubernetes API client
        :param worker_uuid: unique id labelling pods owned by this scheduler
        """
        self.log.debug("Creating Kubernetes executor")
        self.kube_config = kube_config
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.namespace = self.kube_config.kube_namespace
        self.log.debug("Kubernetes using namespace %s", self.namespace)
        self.kube_client = kube_client
        self.launcher = PodLauncher(kube_client=self.kube_client)
        self.worker_configuration = WorkerConfiguration(
            kube_config=self.kube_config)
        self.watcher_queue = multiprocessing.Queue()
        self.worker_uuid = worker_uuid
        # Start the watcher process last so everything it needs exists.
        self.kube_watcher = self._make_kube_watcher()

    def _make_kube_watcher(self):
        # Resume from the stored resource version so previously-seen pod
        # events are not replayed after a restart.
        resource_version = KubeResourceVersion.get_current_resource_version()
        watcher = KubernetesJobWatcher(self.namespace, self.watcher_queue,
                                       resource_version, self.worker_uuid)
        watcher.start()
        return watcher

    def _health_check_kube_watcher(self):
        # Restart the watcher subprocess if it has died for any reason.
        if self.kube_watcher.is_alive():
            pass
        else:
            self.log.error('Error while health checking kube watcher process. '
                           'Process died for unknown reasons')
            self.kube_watcher = self._make_kube_watcher()

    def run_next(self, next_job):
        """
        The run_next command will check the task_queue for any un-run jobs.
        It will then create a unique job-id, launch that job in the cluster,
        and store relevant info in the current_jobs map so we can track the
        job's status
        """
        self.log.info('Kubernetes job is %s', str(next_job))
        key, command, kube_executor_config = next_job
        dag_id, task_id, execution_date, try_number = key
        self.log.debug("Kubernetes running for command %s", command)
        self.log.debug("Kubernetes launching image %s",
                       self.kube_config.kube_image)
        # dag_id/task_id pass through _make_safe_label_value because they are
        # stored as Kubernetes label values, which have a restricted charset.
        pod = self.worker_configuration.make_pod(
            namespace=self.namespace, worker_uuid=self.worker_uuid,
            pod_id=self._create_pod_id(dag_id, task_id),
            dag_id=self._make_safe_label_value(dag_id),
            task_id=self._make_safe_label_value(task_id),
            try_number=try_number,
            execution_date=self._datetime_to_label_safe_datestring(
                execution_date),
            airflow_command=command, kube_executor_config=kube_executor_config)
        # the watcher will monitor pods, so we do not block.
        self.launcher.run_pod_async(pod)
        self.log.debug("Kubernetes Job created!")

    def delete_pod(self, pod_id):
        """Delete a worker pod, honouring the delete_worker_pods setting."""
        if self.kube_config.delete_worker_pods:
            try:
                self.kube_client.delete_namespaced_pod(
                    pod_id, self.namespace, body=client.V1DeleteOptions())
            except ApiException as e:
                # If the pod is already deleted
                if e.status != 404:
                    raise

    def sync(self):
        """
        The sync function checks the status of all currently running
        kubernetes jobs. If a job is completed, its status is placed in the
        result queue to be sent back to the scheduler.

        :return:
        """
        self._health_check_kube_watcher()
        while not self.watcher_queue.empty():
            self.process_watcher_task()

    def process_watcher_task(self):
        """Pop one watcher event and forward it if it maps to a task key."""
        pod_id, state, labels, resource_version = self.watcher_queue.get()
        self.log.info(
            'Attempting to finish pod; pod_id: %s; state: %s; labels: %s',
            pod_id, state, labels)
        key = self._labels_to_key(labels=labels)
        if key:
            self.log.debug('finishing job %s - %s (%s)', key, state, pod_id)
            self.result_queue.put((key, state, pod_id, resource_version))

    @staticmethod
    def _strip_unsafe_kubernetes_special_chars(string):
        """
        Kubernetes only supports lowercase alphanumeric characters and "-"
        and "." in the pod name. However, there are special rules about how
        "-" and "." can be used so let's only keep alphanumeric chars. See
        here for detail:
        https://kubernetes.io/docs/concepts/overview/working-with-objects/names/

        :param string: The requested Pod name
        :return: ``str`` Pod name stripped of any unsafe characters
        """
        return ''.join(ch.lower() for ind, ch in enumerate(string)
                       if ch.isalnum())

    @staticmethod
    def _make_safe_pod_id(safe_dag_id, safe_task_id, safe_uuid):
        r"""
        Kubernetes pod names must be <= 253 chars and must pass the following
        regex for validation
        "^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$"

        :param safe_dag_id: a dag_id with only alphanumeric characters
        :param safe_task_id: a task_id with only alphanumeric characters
        :param safe_uuid: a uuid
        :return: ``str`` valid Pod name of appropriate length
        """
        MAX_POD_ID_LEN = 253
        safe_key = safe_dag_id + safe_task_id
        # Truncate the key so "<key>-<uuid>" never exceeds the limit.
        safe_pod_id = safe_key[:MAX_POD_ID_LEN - len(safe_uuid) - 1] \
            + "-" + safe_uuid
        return safe_pod_id

    @staticmethod
    def _make_safe_label_value(string):
        """
        Valid label values must be 63 characters or less and must be empty
        or begin and end with an alphanumeric character ([a-z0-9A-Z]) with
        dashes (-), underscores (_), dots (.), and alphanumerics between.

        If the label value is then greater than 63 chars once made safe, or
        differs in any way from the original value sent to this function,
        then we need to truncate to 53 chars, and append it with a unique
        hash.
        """
        MAX_LABEL_LEN = 63
        safe_label = re.sub(r'^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$',
                            '', string)
        if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
            safe_hash = hashlib.md5(string.encode()).hexdigest()[:9]
            safe_label = safe_label[:MAX_LABEL_LEN - len(safe_hash) - 1] \
                + "-" + safe_hash
        return safe_label

    @staticmethod
    def _create_pod_id(dag_id, task_id):
        # Compose a unique, charset-safe pod id from the task coordinates.
        safe_dag_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            dag_id)
        safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            task_id)
        safe_uuid = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            uuid4().hex)
        return AirflowKubernetesScheduler._make_safe_pod_id(
            safe_dag_id, safe_task_id, safe_uuid)

    @staticmethod
    def _label_safe_datestring_to_datetime(string):
        """
        Kubernetes doesn't permit ":" in labels. ISO datetime format uses
        ":" but not "_", let's replace ":" with "_"

        :param string: str
        :return: datetime.datetime object
        """
        return parser.parse(string.replace('_plus_', '+').replace("_", ":"))

    @staticmethod
    def _datetime_to_label_safe_datestring(datetime_obj):
        """
        Kubernetes doesn't like ":" in labels, since ISO datetime format uses
        ":" but not "_" let's replace ":" with "_"

        :param datetime_obj: datetime.datetime object
        :return: ISO-like string representing the datetime
        """
        return datetime_obj.isoformat().replace(":", "_").replace('+', '_plus_')

    def _labels_to_key(self, labels):
        """Recover the executor task key from pod labels via a DB lookup.

        Label values were sanitized before launch, so the original
        dag_id/task_id are recovered by matching the sanitized values against
        TaskInstance rows at the same execution_date.

        :return: ``(dag_id, task_id, execution_date, try_number)`` or ``None``
        """
        try_num = 1
        try:
            try_num = int(labels.get('try_number', '1'))
        except ValueError:
            self.log.warning("could not get try_number as an int: %s",
                             labels.get('try_number', '1'))

        try:
            dag_id = labels['dag_id']
            task_id = labels['task_id']
            ex_time = self._label_safe_datestring_to_datetime(
                labels['execution_date'])
        except Exception as e:
            self.log.warning(
                'Error while retrieving labels; labels: %s; exception: %s',
                labels, e)
            return None

        with create_session() as session:
            tasks = (session.query(TaskInstance).filter_by(
                execution_date=ex_time).all())
            self.log.info('Checking %s task instances.', len(tasks))
            for task in tasks:
                if (self._make_safe_label_value(task.dag_id) == dag_id and
                        self._make_safe_label_value(task.task_id) == task_id and
                        task.execution_date == ex_time):
                    self.log.info(
                        'Found matching task %s-%s (%s) with current state of %s',
                        task.dag_id, task.task_id, task.execution_date,
                        task.state)
                    dag_id = task.dag_id
                    task_id = task.task_id
                    return (dag_id, task_id, ex_time, try_num)
        self.log.warning(
            'Failed to find and match task details to a pod; labels: %s',
            labels)
        return None
class AirflowKubernetesScheduler(LoggingMixin):
    """Schedules Airflow tasks as Kubernetes worker pods.

    This variant carries a DB ``session`` for resource-version lookups and
    maps pods back to task keys purely from labels (no DB matching).
    """

    def __init__(self, kube_config, task_queue, result_queue, session,
                 kube_client, worker_uuid):
        """
        :param kube_config: executor configuration; ``kube_namespace`` is read
        :param task_queue: queue of tasks waiting to be run
        :param result_queue: queue that finished-pod results are pushed onto
        :param session: DB session used for resource-version bookkeeping
        :param kube_client: Kubernetes API client
        :param worker_uuid: unique id labelling pods owned by this scheduler
        """
        self.log.debug("Creating Kubernetes executor")
        self.kube_config = kube_config
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.namespace = self.kube_config.kube_namespace
        self.log.debug("Kubernetes using namespace %s", self.namespace)
        self.kube_client = kube_client
        self.launcher = PodLauncher(kube_client=self.kube_client)
        self.worker_configuration = WorkerConfiguration(
            kube_config=self.kube_config)
        self.watcher_queue = multiprocessing.Queue()
        self._session = session
        self.worker_uuid = worker_uuid
        # Start the watcher process last so everything it needs exists.
        self.kube_watcher = self._make_kube_watcher()

    def _make_kube_watcher(self):
        # Resume from the stored resource version so previously-seen pod
        # events are not replayed after a restart.
        resource_version = KubeResourceVersion.get_current_resource_version(
            self._session)
        watcher = KubernetesJobWatcher(self.namespace, self.watcher_queue,
                                       resource_version, self.worker_uuid)
        watcher.start()
        return watcher

    def _health_check_kube_watcher(self):
        # Restart the watcher subprocess if it has died for any reason.
        if self.kube_watcher.is_alive():
            pass
        else:
            self.log.error(
                'Error while health checking kube watcher process. '
                'Process died for unknown reasons')
            self.kube_watcher = self._make_kube_watcher()

    def run_next(self, next_job):
        """
        The run_next command will check the task_queue for any un-run jobs.
        It will then create a unique job-id, launch that job in the cluster,
        and store relevant info in the current_jobs map so we can track the
        job's status
        """
        self.log.info('Kubernetes job is %s', str(next_job))
        key, command, kube_executor_config = next_job
        dag_id, task_id, execution_date, try_number = key
        self.log.debug("Kubernetes running for command %s", command)
        self.log.debug("Kubernetes launching image %s",
                       self.kube_config.kube_image)
        pod = self.worker_configuration.make_pod(
            namespace=self.namespace, worker_uuid=self.worker_uuid,
            pod_id=self._create_pod_id(dag_id, task_id),
            dag_id=dag_id, task_id=task_id,
            execution_date=self._datetime_to_label_safe_datestring(
                execution_date),
            airflow_command=command, kube_executor_config=kube_executor_config
        )
        # the watcher will monitor pods, so we do not block.
        self.launcher.run_pod_async(pod)
        self.log.debug("Kubernetes Job created!")

    def delete_pod(self, pod_id):
        """Delete a worker pod, honouring the delete_worker_pods setting."""
        if self.kube_config.delete_worker_pods:
            try:
                self.kube_client.delete_namespaced_pod(
                    pod_id, self.namespace, body=client.V1DeleteOptions())
            except ApiException as e:
                # If the pod is already deleted
                if e.status != 404:
                    raise

    def sync(self):
        """
        The sync function checks the status of all currently running
        kubernetes jobs. If a job is completed, its status is placed in the
        result queue to be sent back to the scheduler.

        :return:
        """
        self._health_check_kube_watcher()
        while not self.watcher_queue.empty():
            self.process_watcher_task()

    def process_watcher_task(self):
        """Pop one watcher event and forward it if it maps to a task key."""
        pod_id, state, labels, resource_version = self.watcher_queue.get()
        self.log.info(
            'Attempting to finish pod; pod_id: %s; state: %s; labels: %s',
            pod_id, state, labels
        )
        key = self._labels_to_key(labels=labels)
        if key:
            self.log.debug('finishing job %s - %s (%s)', key, state, pod_id)
            self.result_queue.put((key, state, pod_id, resource_version))

    @staticmethod
    def _strip_unsafe_kubernetes_special_chars(string):
        """
        Kubernetes only supports lowercase alphanumeric characters and "-"
        and "." in the pod name. However, there are special rules about how
        "-" and "." can be used so let's only keep alphanumeric chars. See
        here for detail:
        https://kubernetes.io/docs/concepts/overview/working-with-objects/names/

        :param string: The requested Pod name
        :return: ``str`` Pod name stripped of any unsafe characters
        """
        return ''.join(ch.lower() for ind, ch in enumerate(string)
                       if ch.isalnum())

    @staticmethod
    def _make_safe_pod_id(safe_dag_id, safe_task_id, safe_uuid):
        r"""
        Kubernetes pod names must be <= 253 chars and must pass the following
        regex for validation
        "^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$"

        :param safe_dag_id: a dag_id with only alphanumeric characters
        :param safe_task_id: a task_id with only alphanumeric characters
        :param safe_uuid: a uuid
        :return: ``str`` valid Pod name of appropriate length
        """
        MAX_POD_ID_LEN = 253
        safe_key = safe_dag_id + safe_task_id
        # Truncate the key so "<key>-<uuid>" never exceeds the limit.
        safe_pod_id = safe_key[:MAX_POD_ID_LEN - len(safe_uuid) - 1] \
            + "-" + safe_uuid
        return safe_pod_id

    @staticmethod
    def _create_pod_id(dag_id, task_id):
        # Compose a unique, charset-safe pod id from the task coordinates.
        safe_dag_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            dag_id)
        safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            task_id)
        safe_uuid = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            uuid4().hex)
        return AirflowKubernetesScheduler._make_safe_pod_id(safe_dag_id,
                                                            safe_task_id,
                                                            safe_uuid)

    @staticmethod
    def _label_safe_datestring_to_datetime(string):
        """
        Kubernetes doesn't permit ":" in labels. ISO datetime format uses
        ":" but not "_", let's replace ":" with "_"

        :param string: str
        :return: datetime.datetime object
        """
        return parser.parse(string.replace('_plus_', '+').replace("_", ":"))

    @staticmethod
    def _datetime_to_label_safe_datestring(datetime_obj):
        """
        Kubernetes doesn't like ":" in labels, since ISO datetime format uses
        ":" but not "_" let's replace ":" with "_"

        :param datetime_obj: datetime.datetime object
        :return: ISO-like string representing the datetime
        """
        return datetime_obj.isoformat().replace(":", "_").replace('+', '_plus_')

    def _labels_to_key(self, labels):
        """Rebuild the executor task key directly from pod labels.

        NOTE(review): try_number is returned as the raw label string while
        run_next receives it as part of the key — confirm callers compare
        the two consistently.

        :return: ``(dag_id, task_id, execution_date, try_number)`` or ``None``
        """
        try:
            return (
                labels['dag_id'], labels['task_id'],
                self._label_safe_datestring_to_datetime(
                    labels['execution_date']),
                labels['try_number'])
        except Exception as e:
            self.log.warning(
                'Error while converting labels to key; labels: %s; '
                'exception: %s', labels, e
            )
            return None
class AirflowKubernetesScheduler(LoggingMixin):
    """Schedules Airflow tasks as Kubernetes worker pods.

    This variant carries a DB ``session`` for resource-version lookups and
    uses a 3-tuple task key of ``(dag_id, task_id, execution_date)``.
    """

    def __init__(self, kube_config, task_queue, result_queue, session,
                 kube_client, worker_uuid):
        """
        :param kube_config: executor configuration; ``kube_namespace`` is read
        :param task_queue: queue of tasks waiting to be run
        :param result_queue: queue that finished-pod results are pushed onto
        :param session: DB session used for resource-version bookkeeping
        :param kube_client: Kubernetes API client
        :param worker_uuid: unique id labelling pods owned by this scheduler
        """
        self.log.debug("Creating Kubernetes executor")
        self.kube_config = kube_config
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.namespace = self.kube_config.kube_namespace
        self.log.debug("Kubernetes using namespace %s", self.namespace)
        self.kube_client = kube_client
        self.launcher = PodLauncher(kube_client=self.kube_client)
        self.worker_configuration = WorkerConfiguration(
            kube_config=self.kube_config)
        self.watcher_queue = multiprocessing.Queue()
        self._session = session
        self.worker_uuid = worker_uuid
        # Start the watcher process last so everything it needs exists.
        self.kube_watcher = self._make_kube_watcher()

    def _make_kube_watcher(self):
        # Resume from the stored resource version so previously-seen pod
        # events are not replayed after a restart.
        resource_version = KubeResourceVersion.get_current_resource_version(
            self._session)
        watcher = KubernetesJobWatcher(self.namespace, self.watcher_queue,
                                       resource_version, self.worker_uuid)
        watcher.start()
        return watcher

    def _health_check_kube_watcher(self):
        # Restart the watcher subprocess if it has died for any reason.
        if self.kube_watcher.is_alive():
            pass
        else:
            self.log.error(
                'Error while health checking kube watcher process. '
                'Process died for unknown reasons')
            self.kube_watcher = self._make_kube_watcher()

    def run_next(self, next_job):
        """
        The run_next command will check the task_queue for any un-run jobs.
        It will then create a unique job-id, launch that job in the cluster,
        and store relevant info in the current_jobs map so we can track the
        job's status
        """
        self.log.info('Kubernetes job is %s', str(next_job))
        key, command, kube_executor_config = next_job
        dag_id, task_id, execution_date = key
        self.log.debug("Kubernetes running for command %s", command)
        self.log.debug("Kubernetes launching image %s",
                       self.kube_config.kube_image)
        pod = self.worker_configuration.make_pod(
            namespace=self.namespace, worker_uuid=self.worker_uuid,
            pod_id=self._create_pod_id(dag_id, task_id),
            dag_id=dag_id, task_id=task_id,
            execution_date=self._datetime_to_label_safe_datestring(
                execution_date),
            airflow_command=command, kube_executor_config=kube_executor_config
        )
        # the watcher will monitor pods, so we do not block.
        self.launcher.run_pod_async(pod)
        self.log.debug("Kubernetes Job created!")

    def delete_pod(self, pod_id):
        """Delete a worker pod, honouring the delete_worker_pods setting."""
        if self.kube_config.delete_worker_pods:
            try:
                self.kube_client.delete_namespaced_pod(
                    pod_id, self.namespace, body=client.V1DeleteOptions())
            except ApiException as e:
                # If the pod is already deleted
                if e.status != 404:
                    raise

    def sync(self):
        """
        The sync function checks the status of all currently running
        kubernetes jobs. If a job is completed, its status is placed in the
        result queue to be sent back to the scheduler.

        :return:
        """
        self._health_check_kube_watcher()
        while not self.watcher_queue.empty():
            self.process_watcher_task()

    def process_watcher_task(self):
        """Pop one watcher event and forward it if it maps to a task key."""
        pod_id, state, labels, resource_version = self.watcher_queue.get()
        self.log.info(
            'Attempting to finish pod; pod_id: %s; state: %s; labels: %s',
            pod_id, state, labels
        )
        key = self._labels_to_key(labels=labels)
        if key:
            self.log.debug('finishing job %s - %s (%s)', key, state, pod_id)
            self.result_queue.put((key, state, pod_id, resource_version))

    @staticmethod
    def _strip_unsafe_kubernetes_special_chars(string):
        """
        Kubernetes only supports lowercase alphanumeric characters and "-"
        and "." in the pod name. However, there are special rules about how
        "-" and "." can be used so let's only keep alphanumeric chars. See
        here for detail:
        https://kubernetes.io/docs/concepts/overview/working-with-objects/names/

        :param string: The requested Pod name
        :return: ``str`` Pod name stripped of any unsafe characters
        """
        return ''.join(ch.lower() for ind, ch in enumerate(string)
                       if ch.isalnum())

    @staticmethod
    def _make_safe_pod_id(safe_dag_id, safe_task_id, safe_uuid):
        r"""
        Kubernetes pod names must be <= 253 chars and must pass the following
        regex for validation
        "^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$"

        :param safe_dag_id: a dag_id with only alphanumeric characters
        :param safe_task_id: a task_id with only alphanumeric characters
        :param safe_uuid: a uuid
        :return: ``str`` valid Pod name of appropriate length
        """
        MAX_POD_ID_LEN = 253
        safe_key = safe_dag_id + safe_task_id
        # Truncate the key so "<key>-<uuid>" never exceeds the limit.
        safe_pod_id = safe_key[:MAX_POD_ID_LEN - len(safe_uuid) - 1] \
            + "-" + safe_uuid
        return safe_pod_id

    @staticmethod
    def _create_pod_id(dag_id, task_id):
        # Compose a unique, charset-safe pod id from the task coordinates.
        safe_dag_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            dag_id)
        safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            task_id)
        safe_uuid = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            uuid4().hex)
        return AirflowKubernetesScheduler._make_safe_pod_id(safe_dag_id,
                                                            safe_task_id,
                                                            safe_uuid)

    @staticmethod
    def _label_safe_datestring_to_datetime(string):
        """
        Kubernetes doesn't permit ":" in labels. ISO datetime format uses
        ":" but not "_", let's replace ":" with "_"

        :param string: str
        :return: datetime.datetime object
        """
        return parser.parse(string.replace('_plus_', '+').replace("_", ":"))

    @staticmethod
    def _datetime_to_label_safe_datestring(datetime_obj):
        """
        Kubernetes doesn't like ":" in labels, since ISO datetime format uses
        ":" but not "_" let's replace ":" with "_"

        :param datetime_obj: datetime.datetime object
        :return: ISO-like string representing the datetime
        """
        return datetime_obj.isoformat().replace(":", "_").replace('+', '_plus_')

    def _labels_to_key(self, labels):
        """Rebuild the executor task key directly from pod labels.

        :return: ``(dag_id, task_id, execution_date)`` or ``None`` when any
            expected label is missing or unparsable.
        """
        try:
            return (
                labels['dag_id'], labels['task_id'],
                self._label_safe_datestring_to_datetime(
                    labels['execution_date']))
        except Exception as e:
            self.log.warning(
                'Error while converting labels to key; labels: %s; '
                'exception: %s', labels, e
            )
            return None
class AirflowKubernetesScheduler(LoggingMixin):
    """Schedules Airflow tasks as Kubernetes worker pods.

    This variant uses a ``SynchronizedQueue`` for watcher events and maps
    pods back to task keys via sanitized-label matching against TaskInstance
    rows in the DB.
    """

    def __init__(self, kube_config, task_queue, result_queue, kube_client,
                 worker_uuid):
        """
        :param kube_config: executor configuration; ``kube_namespace`` is read
        :param task_queue: queue of tasks waiting to be run
        :param result_queue: queue that finished-pod results are pushed onto
        :param kube_client: Kubernetes API client
        :param worker_uuid: unique id labelling pods owned by this scheduler
        """
        self.log.debug("Creating Kubernetes executor")
        self.kube_config = kube_config
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.namespace = self.kube_config.kube_namespace
        self.log.debug("Kubernetes using namespace %s", self.namespace)
        self.kube_client = kube_client
        self.launcher = PodLauncher(kube_client=self.kube_client)
        self.worker_configuration = WorkerConfiguration(
            kube_config=self.kube_config)
        self.watcher_queue = SynchronizedQueue()
        self.worker_uuid = worker_uuid
        # Start the watcher process last so everything it needs exists.
        self.kube_watcher = self._make_kube_watcher()

    def _make_kube_watcher(self):
        # Resume from the stored resource version so previously-seen pod
        # events are not replayed after a restart.
        resource_version = KubeResourceVersion.get_current_resource_version()
        watcher = KubernetesJobWatcher(self.namespace, self.watcher_queue,
                                       resource_version, self.worker_uuid)
        watcher.start()
        return watcher

    def _health_check_kube_watcher(self):
        # Restart the watcher subprocess if it has died for any reason.
        if self.kube_watcher.is_alive():
            pass
        else:
            self.log.error(
                'Error while health checking kube watcher process. '
                'Process died for unknown reasons')
            self.kube_watcher = self._make_kube_watcher()

    def run_next(self, next_job):
        """
        The run_next command will check the task_queue for any un-run jobs.
        It will then create a unique job-id, launch that job in the cluster,
        and store relevant info in the current_jobs map so we can track the
        job's status
        """
        self.log.info('Kubernetes job is %s', str(next_job))
        key, command, kube_executor_config = next_job
        dag_id, task_id, execution_date, try_number = key
        self.log.debug("Kubernetes running for command %s", command)
        self.log.debug("Kubernetes launching image %s",
                       self.kube_config.kube_image)
        # dag_id/task_id pass through _make_safe_label_value because they are
        # stored as Kubernetes label values, which have a restricted charset.
        pod = self.worker_configuration.make_pod(
            namespace=self.namespace, worker_uuid=self.worker_uuid,
            pod_id=self._create_pod_id(dag_id, task_id),
            dag_id=self._make_safe_label_value(dag_id),
            task_id=self._make_safe_label_value(task_id),
            try_number=try_number,
            execution_date=self._datetime_to_label_safe_datestring(
                execution_date),
            airflow_command=command, kube_executor_config=kube_executor_config
        )
        # the watcher will monitor pods, so we do not block.
        self.launcher.run_pod_async(pod)
        self.log.debug("Kubernetes Job created!")

    def delete_pod(self, pod_id):
        """Delete a worker pod, honouring the delete_worker_pods setting."""
        if self.kube_config.delete_worker_pods:
            try:
                self.kube_client.delete_namespaced_pod(
                    pod_id, self.namespace, body=client.V1DeleteOptions())
            except ApiException as e:
                # If the pod is already deleted
                if e.status != 404:
                    raise

    def sync(self):
        """
        The sync function checks the status of all currently running
        kubernetes jobs. If a job is completed, its status is placed in the
        result queue to be sent back to the scheduler.

        :return:
        """
        self._health_check_kube_watcher()
        while not self.watcher_queue.empty():
            self.process_watcher_task()

    def process_watcher_task(self):
        """Pop one watcher event and forward it if it maps to a task key."""
        pod_id, state, labels, resource_version = self.watcher_queue.get()
        self.log.info(
            'Attempting to finish pod; pod_id: %s; state: %s; labels: %s',
            pod_id, state, labels
        )
        key = self._labels_to_key(labels=labels)
        if key:
            self.log.debug('finishing job %s - %s (%s)', key, state, pod_id)
            self.result_queue.put((key, state, pod_id, resource_version))

    @staticmethod
    def _strip_unsafe_kubernetes_special_chars(string):
        """
        Kubernetes only supports lowercase alphanumeric characters and "-"
        and "." in the pod name. However, there are special rules about how
        "-" and "." can be used so let's only keep alphanumeric chars. See
        here for detail:
        https://kubernetes.io/docs/concepts/overview/working-with-objects/names/

        :param string: The requested Pod name
        :return: ``str`` Pod name stripped of any unsafe characters
        """
        return ''.join(ch.lower() for ind, ch in enumerate(string)
                       if ch.isalnum())

    @staticmethod
    def _make_safe_pod_id(safe_dag_id, safe_task_id, safe_uuid):
        r"""
        Kubernetes pod names must be <= 253 chars and must pass the following
        regex for validation
        "^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$"

        :param safe_dag_id: a dag_id with only alphanumeric characters
        :param safe_task_id: a task_id with only alphanumeric characters
        :param safe_uuid: a uuid
        :return: ``str`` valid Pod name of appropriate length
        """
        MAX_POD_ID_LEN = 253
        safe_key = safe_dag_id + safe_task_id
        # Truncate the key so "<key>-<uuid>" never exceeds the limit.
        safe_pod_id = safe_key[:MAX_POD_ID_LEN - len(safe_uuid) - 1] \
            + "-" + safe_uuid
        return safe_pod_id

    @staticmethod
    def _make_safe_label_value(string):
        """
        Valid label values must be 63 characters or less and must be empty
        or begin and end with an alphanumeric character ([a-z0-9A-Z]) with
        dashes (-), underscores (_), dots (.), and alphanumerics between.

        If the label value is then greater than 63 chars once made safe, or
        differs in any way from the original value sent to this function,
        then we need to truncate to 53 chars, and append it with a unique
        hash.
        """
        MAX_LABEL_LEN = 63
        safe_label = re.sub(r'^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$',
                            '', string)
        if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
            safe_hash = hashlib.md5(string.encode()).hexdigest()[:9]
            safe_label = safe_label[:MAX_LABEL_LEN - len(safe_hash) - 1] \
                + "-" + safe_hash
        return safe_label

    @staticmethod
    def _create_pod_id(dag_id, task_id):
        # Compose a unique, charset-safe pod id from the task coordinates.
        safe_dag_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            dag_id)
        safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            task_id)
        safe_uuid = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
            uuid4().hex)
        return AirflowKubernetesScheduler._make_safe_pod_id(safe_dag_id,
                                                            safe_task_id,
                                                            safe_uuid)

    @staticmethod
    def _label_safe_datestring_to_datetime(string):
        """
        Kubernetes doesn't permit ":" in labels. ISO datetime format uses
        ":" but not "_", let's replace ":" with "_"

        :param string: str
        :return: datetime.datetime object
        """
        return parser.parse(string.replace('_plus_', '+').replace("_", ":"))

    @staticmethod
    def _datetime_to_label_safe_datestring(datetime_obj):
        """
        Kubernetes doesn't like ":" in labels, since ISO datetime format uses
        ":" but not "_" let's replace ":" with "_"

        :param datetime_obj: datetime.datetime object
        :return: ISO-like string representing the datetime
        """
        return datetime_obj.isoformat().replace(":", "_").replace('+', '_plus_')

    def _labels_to_key(self, labels):
        """Recover the executor task key from pod labels via a DB lookup.

        Label values were sanitized before launch, so the original
        dag_id/task_id are recovered by matching the sanitized values against
        TaskInstance rows at the same execution_date.

        :return: ``(dag_id, task_id, execution_date, try_number)`` or ``None``
        """
        try_num = 1
        try:
            try_num = int(labels.get('try_number', '1'))
        except ValueError:
            self.log.warning("could not get try_number as an int: %s",
                             labels.get('try_number', '1'))

        try:
            dag_id = labels['dag_id']
            task_id = labels['task_id']
            ex_time = self._label_safe_datestring_to_datetime(
                labels['execution_date'])
        except Exception as e:
            self.log.warning(
                'Error while retrieving labels; labels: %s; exception: %s',
                labels, e
            )
            return None

        with create_session() as session:
            tasks = (
                session
                .query(TaskInstance)
                .filter_by(execution_date=ex_time).all()
            )
            self.log.info('Checking %s task instances.', len(tasks))
            for task in tasks:
                if (
                    self._make_safe_label_value(task.dag_id) == dag_id and
                    self._make_safe_label_value(task.task_id) == task_id and
                    task.execution_date == ex_time
                ):
                    self.log.info(
                        'Found matching task %s-%s (%s) with current state of %s',
                        task.dag_id, task.task_id, task.execution_date,
                        task.state
                    )
                    dag_id = task.dag_id
                    task_id = task.task_id
                    return (dag_id, task_id, ex_time, try_num)
        self.log.warning(
            'Failed to find and match task details to a pod; labels: %s',
            labels
        )
        return None
def setUp(self):
    """Build a fresh mocked Kubernetes client and a launcher wired to it."""
    fake_client = mock.Mock()
    self.mock_kube_client = fake_client
    self.pod_launcher = PodLauncher(kube_client=fake_client)
class TestPodLauncher(unittest.TestCase):
    """Tests for PodLauncher's retrying reads of pod info and pod logs."""

    def setUp(self):
        self.mock_kube_client = mock.Mock()
        self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)

    def test_read_pod_logs_successfully_returns_logs(self):
        # Happy path: the client's log response is returned unchanged.
        self.mock_kube_client.read_namespaced_pod_log.return_value = mock.sentinel.logs
        logs = self.pod_launcher.read_pod_logs(mock.sentinel)
        self.assertEqual(mock.sentinel.logs, logs)

    def test_read_pod_logs_retries_successfully(self):
        # One transient HTTP error is retried; both attempts are expected to
        # use identical call arguments.
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            BaseHTTPError('Boom'),
            mock.sentinel.logs
        ]
        logs = self.pod_launcher.read_pod_logs(mock.sentinel)
        self.assertEqual(mock.sentinel.logs, logs)
        self.mock_kube_client.read_namespaced_pod_log.assert_has_calls([
            mock.call(_preload_content=False, container='base', follow=True,
                      name=mock.sentinel.name,
                      namespace=mock.sentinel.namespace, tail_lines=10),
            mock.call(_preload_content=False, container='base', follow=True,
                      name=mock.sentinel.name,
                      namespace=mock.sentinel.namespace, tail_lines=10)
        ])

    def test_read_pod_logs_retries_fails(self):
        # After three consecutive errors, read_pod_logs is expected to raise
        # AirflowException rather than retry again.
        self.mock_kube_client.read_namespaced_pod_log.side_effect = [
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom')
        ]
        self.assertRaises(AirflowException, self.pod_launcher.read_pod_logs,
                          mock.sentinel)

    def test_read_pod_returns_logs(self):
        # Happy path: pod info from the client is passed through.
        self.mock_kube_client.read_namespaced_pod.return_value = mock.sentinel.pod_info
        pod_info = self.pod_launcher.read_pod(mock.sentinel)
        self.assertEqual(mock.sentinel.pod_info, pod_info)

    def test_read_pod_retries_successfully(self):
        # One transient HTTP error is retried with the same positional args.
        self.mock_kube_client.read_namespaced_pod.side_effect = [
            BaseHTTPError('Boom'),
            mock.sentinel.pod_info
        ]
        pod_info = self.pod_launcher.read_pod(mock.sentinel)
        self.assertEqual(mock.sentinel.pod_info, pod_info)
        self.mock_kube_client.read_namespaced_pod.assert_has_calls([
            mock.call(mock.sentinel.name, mock.sentinel.namespace),
            mock.call(mock.sentinel.name, mock.sentinel.namespace)
        ])

    def test_read_pod_retries_fails(self):
        # After three consecutive errors, read_pod is expected to raise
        # AirflowException rather than retry again.
        self.mock_kube_client.read_namespaced_pod.side_effect = [
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom'),
            BaseHTTPError('Boom')
        ]
        self.assertRaises(AirflowException, self.pod_launcher.read_pod,
                          mock.sentinel)