class TestKubernetesPodOperatorSystem(unittest.TestCase):
    def get_current_task_name(self):
        # reverse the test id so its unique suffix survives truncation of the (length-limited) pod name
        return "_" + unittest.TestCase.id(self).replace(".", "_")[::-1]

    def setUp(self):
        self.maxDiff = None  # pylint: disable=invalid-name
        self.api_client = ApiClient()
        self.expected_pod = {
            'apiVersion': 'v1',
            'kind': 'Pod',
            'metadata': {
                'namespace': 'default',
                'name': ANY,
                'annotations': {},
                'labels': {
                    'foo': 'bar',
                    'kubernetes_pod_operator': 'True',
                    'airflow_version': airflow_version.replace('+', '-'),
                    'execution_date': '2016-01-01T0100000100-a2f50a31f',
                    'dag_id': 'dag',
                    'task_id': ANY,
                    'try_number': '1'
                },
            },
            'spec': {
                'affinity': {},
                'containers': [{
                    'image': 'ubuntu:16.04',
                    'args': ["echo 10"],
                    'command': ["bash", "-cx"],
                    'env': [],
                    'envFrom': [],
                    'resources': {},
                    'name': 'base',
                    'ports': [],
                    'volumeMounts': [],
                }],
                'hostNetwork': False,
                'imagePullSecrets': [],
                'initContainers': [],
                'nodeSelector': {},
                'restartPolicy': 'Never',
                'securityContext': {},
                'serviceAccountName': 'default',
                'tolerations': [],
                'volumes': [],
            }
        }

    def tearDown(self) -> None:
        client = kube_client.get_kube_client(in_cluster=False)
        client.delete_collection_namespaced_pod(namespace="default")
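        # Give the API server a moment to finish the deletions before the next test.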
        import time
        time.sleep(1)

    def test_do_xcom_push_defaults_false(self):
        new_config_path = '/tmp/kube_config'
        old_config_path = os.path.expanduser('~/.kube/config')
        shutil.copy(old_config_path, new_config_path)

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            config_file=new_config_path,
        )
        self.assertFalse(k.do_xcom_push)

    def test_config_path_move(self):
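        # Copy the kubeconfig to a non-default location to verify the operator
        # honors an explicit config_file path.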
        new_config_path = '/tmp/kube_config'
        old_config_path = os.path.expanduser('~/.kube/config')
        shutil.copy(old_config_path, new_config_path)

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test1",
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            config_file=new_config_path,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.assertEqual(self.expected_pod, actual_pod)

    def test_working_pod(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
        self.assertEqual(self.expected_pod['metadata']['labels'],
                         actual_pod['metadata']['labels'])

    def test_delete_operator_pod(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            is_delete_operator_pod=True,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
        self.assertEqual(self.expected_pod['metadata']['labels'],
                         actual_pod['metadata']['labels'])

    def test_pod_hostnetwork(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            hostnetwork=True,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['hostNetwork'] = True
        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
        self.assertEqual(self.expected_pod['metadata']['labels'],
                         actual_pod['metadata']['labels'])

    def test_pod_dnspolicy(self):
        dns_policy = "ClusterFirstWithHostNet"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            hostnetwork=True,
            dnspolicy=dns_policy)
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['hostNetwork'] = True
        self.expected_pod['spec']['dnsPolicy'] = dns_policy
        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
        self.assertEqual(self.expected_pod['metadata']['labels'],
                         actual_pod['metadata']['labels'])

    def test_pod_schedulername(self):
        scheduler_name = "default-scheduler"
        k = KubernetesPodOperator(
            namespace="default",
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            schedulername=scheduler_name)
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['schedulerName'] = scheduler_name
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_node_selectors(self):
        node_selectors = {'beta.kubernetes.io/os': 'linux'}
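        # Note: beta.kubernetes.io/os is deprecated in newer Kubernetes releases
        # in favor of kubernetes.io/os.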
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            node_selectors=node_selectors,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['nodeSelector'] = node_selectors
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_resources(self):
        resources = k8s.V1ResourceRequirements(
            requests={
                'memory': '64Mi',
                'cpu': '250m',
                'ephemeral-storage': '1Gi',
            },
            limits={
                'memory': '64Mi',
                'cpu': 0.25,
                'nvidia.com/gpu': None,
                'ephemeral-storage': '2Gi',
            },
        )
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            resources=resources,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['resources'] = {
            'requests': {
                'memory': '64Mi',
                'cpu': '250m',
                'ephemeral-storage': '1Gi'
            },
            'limits': {
                'memory': '64Mi',
                'cpu': 0.25,
                'nvidia.com/gpu': None,
                'ephemeral-storage': '2Gi'
            }
        }
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_affinity(self):
        affinity = {
            'nodeAffinity': {
                'requiredDuringSchedulingIgnoredDuringExecution': {
                    'nodeSelectorTerms': [{
                        'matchExpressions': [{
                            'key': 'beta.kubernetes.io/os',
                            'operator': 'In',
                            'values': ['linux']
                        }]
                    }]
                }
            }
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            affinity=affinity,
        )
        context = create_context(k)
        k.execute(context=context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['affinity'] = affinity
        self.assertEqual(self.expected_pod, actual_pod)

    def test_port(self):
        port = k8s.V1ContainerPort(
            name='http',
            container_port=80,
        )

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            ports=[port],
        )
        context = create_context(k)
        k.execute(context=context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['ports'] = [{
            'name': 'http',
            'containerPort': 80,
        }]
        self.assertEqual(self.expected_pod, actual_pod)

    def test_volume_mount(self):
        with mock.patch.object(PodLauncher, 'log') as mock_logger:
            volume_mount = k8s.V1VolumeMount(
                name='test-volume',
                mount_path='/tmp/test_volume',
                sub_path=None,
                read_only=False)

            volume = k8s.V1Volume(
                name='test-volume',
                persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
                    claim_name='test-volume'))

            args = [
                "echo \"retrieved from mount\" > /tmp/test_volume/test.txt "
                "&& cat /tmp/test_volume/test.txt"
            ]
            k = KubernetesPodOperator(
                namespace='default',
                image="ubuntu:16.04",
                cmds=["bash", "-cx"],
                arguments=args,
                labels={"foo": "bar"},
                volume_mounts=[volume_mount],
                volumes=[volume],
                name="test-" + str(random.randint(0, 1000000)),
                task_id="task" + self.get_current_task_name(),
                in_cluster=False,
                do_xcom_push=False,
            )
            context = create_context(k)
            k.execute(context=context)
            mock_logger.info.assert_any_call(b"retrieved from mount\n")
            actual_pod = self.api_client.sanitize_for_serialization(k.pod)
            self.expected_pod['spec']['containers'][0]['args'] = args
            self.expected_pod['spec']['containers'][0]['volumeMounts'] = [{
                'name': 'test-volume',
                'mountPath': '/tmp/test_volume',
                'readOnly': False,
            }]
            self.expected_pod['spec']['volumes'] = [{
                'name': 'test-volume',
                'persistentVolumeClaim': {
                    'claimName': 'test-volume'
                }
            }]
            self.assertEqual(self.expected_pod, actual_pod)

    def test_run_as_user_root(self):
        security_context = {
            'securityContext': {
                'runAsUser': 0,
            }
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        self.assertEqual(self.expected_pod, actual_pod)

    def test_run_as_user_non_root(self):
        security_context = {
            'securityContext': {
                'runAsUser': 1000,
            }
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        self.assertEqual(self.expected_pod, actual_pod)

    def test_fs_group(self):
        security_context = {
            'securityContext': {
                'fsGroup': 1000,
            }
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-fs-group",
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        self.assertEqual(self.expected_pod, actual_pod)

    def test_faulty_image(self):
        bad_image_name = "foobar"
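        # "foobar" is not a pullable image, so the container never starts; the
        # short startup timeout below surfaces that as an AirflowException.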
        k = KubernetesPodOperator(
            namespace='default',
            image=bad_image_name,
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            startup_timeout_seconds=5,
        )
        with self.assertRaises(AirflowException):
            context = create_context(k)
            k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['image'] = bad_image_name
        self.assertEqual(self.expected_pod, actual_pod)

    def test_faulty_service_account(self):
        bad_service_account_name = "foobar"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            startup_timeout_seconds=5,
            service_account_name=bad_service_account_name,
        )
        with self.assertRaises(ApiException):
            context = create_context(k)
            k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['serviceAccountName'] = bad_service_account_name
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_failure(self):
        """
        Tests that the task fails when a pod reports a failure
        """
        bad_internal_command = ["foobar 10 "]
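        # "foobar" is not a valid command, so bash exits non-zero and the pod fails.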
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=bad_internal_command,
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
        )
        with self.assertRaises(AirflowException):
            context = create_context(k)
            k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['args'] = bad_internal_command
        self.assertEqual(self.expected_pod, actual_pod)

    def test_xcom_push(self):
        return_value = '{"foo": "bar"\n, "buzz": 2}'
        args = ['echo \'{}\' > /airflow/xcom/return.json'.format(return_value)]
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=args,
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=True,
        )
        context = create_context(k)
        self.assertEqual(k.execute(context), json.loads(return_value))
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
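        # do_xcom_push adds an xcom sidecar container plus a shared volume and
        # mount; fold those defaults into the expected pod before comparing.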
        volume = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME)
        volume_mount = self.api_client.sanitize_for_serialization(
            PodDefaults.VOLUME_MOUNT)
        container = self.api_client.sanitize_for_serialization(
            PodDefaults.SIDECAR_CONTAINER)
        self.expected_pod['spec']['containers'][0]['args'] = args
        self.expected_pod['spec']['containers'][0]['volumeMounts'].insert(
            0, volume_mount)  # noqa
        self.expected_pod['spec']['volumes'].insert(0, volume)
        self.expected_pod['spec']['containers'].append(container)
        self.assertEqual(self.expected_pod, actual_pod)

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
        # GIVEN
        from airflow.utils.state import State

        configmap_name = "test-config-map"
        env_from = [
            k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(
                name=configmap_name))
        ]
        # WHEN
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            env_from=env_from)
        # THEN
        mock_monitor.return_value = (State.SUCCESS, None)
        context = create_context(k)
        k.execute(context)
        self.assertEqual(
            mock_start.call_args[0][0].spec.containers[0].env_from, env_from)

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock):
        # GIVEN
        from airflow.utils.state import State
        secret_ref = 'secret_name'
        secrets = [Secret('env', None, secret_ref)]
        # WHEN
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            secrets=secrets,
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
        )
        # THEN
        monitor_mock.return_value = (State.SUCCESS, None)
        context = create_context(k)
        k.execute(context)
        self.assertEqual(
            start_mock.call_args[0][0].spec.containers[0].env_from, [
                k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(
                    name=secret_ref))
            ])

    def test_env_vars(self):
        # WHEN
        env_vars = [
            k8s.V1EnvVar(name="ENV1", value="val1"),
            k8s.V1EnvVar(name="ENV2", value="val2"),
            k8s.V1EnvVar(name="ENV3",
                         value_from=k8s.V1EnvVarSource(
                             field_ref=k8s.V1ObjectFieldSelector(
                                 field_path="status.podIP"))),
        ]

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            env_vars=env_vars,
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
        )

        context = create_context(k)
        k.execute(context)

        # THEN
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['env'] = [{
            'name': 'ENV1',
            'value': 'val1'
        }, {
            'name': 'ENV2',
            'value': 'val2'
        }, {
            'name': 'ENV3',
            'valueFrom': {
                'fieldRef': {
                    'fieldPath': 'status.podIP'
                }
            }
        }]
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_template_file_system(self):
        fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
        k = KubernetesPodOperator(
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            pod_template_file=fixture,
            do_xcom_push=True)

        context = create_context(k)
        result = k.execute(context)
        self.assertIsNotNone(result)
        self.assertDictEqual(result, {"hello": "world"})

    def test_pod_template_file_with_overrides_system(self):
        fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
        k = KubernetesPodOperator(
            task_id="task" + self.get_current_task_name(),
            labels={
                "foo": "bar",
                "fizz": "buzz"
            },
            env_vars=[k8s.V1EnvVar(name="env_name", value="value")],
            in_cluster=False,
            pod_template_file=fixture,
            do_xcom_push=True)

        context = create_context(k)
        result = k.execute(context)
        self.assertIsNotNone(result)
        self.assertEqual(k.pod.metadata.labels, {'fizz': 'buzz', 'foo': 'bar'})
        self.assertEqual(k.pod.spec.containers[0].env,
                         [k8s.V1EnvVar(name="env_name", value="value")])
        self.assertDictEqual(result, {"hello": "world"})

    def test_init_container(self):
        # GIVEN
        volume_mounts = [
            k8s.V1VolumeMount(mount_path='/etc/foo',
                              name='test-volume',
                              sub_path=None,
                              read_only=True)
        ]

        init_environments = [
            k8s.V1EnvVar(name='key1', value='value1'),
            k8s.V1EnvVar(name='key2', value='value2')
        ]

        init_container = k8s.V1Container(name="init-container",
                                         image="ubuntu:16.04",
                                         env=init_environments,
                                         volume_mounts=volume_mounts,
                                         command=["bash", "-cx"],
                                         args=["echo 10"])

        volume = k8s.V1Volume(
            name='test-volume',
            persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
                claim_name='test-volume'))
        expected_init_container = {
            'name': 'init-container',
            'image': 'ubuntu:16.04',
            'command': ['bash', '-cx'],
            'args': ['echo 10'],
            'env': [{
                'name': 'key1',
                'value': 'value1'
            }, {
                'name': 'key2',
                'value': 'value2'
            }],
            'volumeMounts': [{
                'mountPath': '/etc/foo',
                'name': 'test-volume',
                'readOnly': True
            }],
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            volumes=[volume],
            init_containers=[init_container],
            in_cluster=False,
            do_xcom_push=False,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['initContainers'] = [expected_init_container]
        self.expected_pod['spec']['volumes'] = [{
            'name': 'test-volume',
            'persistentVolumeClaim': {
                'claimName': 'test-volume'
            }
        }]
        self.assertEqual(self.expected_pod, actual_pod)

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_pod_template_file(
            self,
            mock_client,
            monitor_mock,
            start_mock  # pylint: disable=unused-argument
    ):
        from airflow.utils.state import State
        path = sys.path[0] + '/tests/kubernetes/pod.yaml'
        k = KubernetesPodOperator(
            task_id="task" + self.get_current_task_name(),
            pod_template_file=path,
            do_xcom_push=True)

        monitor_mock.return_value = (State.SUCCESS, None)
        context = create_context(k)
        with self.assertLogs(k.log, level=logging.DEBUG) as cm:
            k.execute(context)
            expected_line = textwrap.dedent("""\
            DEBUG:airflow.task.operators:Starting pod:
            api_version: v1
            kind: Pod
            metadata:
              annotations: {}
              cluster_name: null
              creation_timestamp: null
              deletion_grace_period_seconds: null\
            """).strip()
            self.assertTrue(
                any(line.startswith(expected_line) for line in cm.output))

        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        expected_dict = {
            'apiVersion': 'v1',
            'kind': 'Pod',
            'metadata': {
                'annotations': {},
                'labels': {},
                'name': 'memory-demo',
                'namespace': 'mem-example'
            },
            'spec': {
                'affinity': {},
                'containers': [{
                    'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'],
                    'command': ['stress'],
                    'env': [],
                    'envFrom': [],
                    'image': 'apache/airflow:stress-2020.07.10-1.0.4',
                    'name': 'base',
                    'ports': [],
                    'resources': {
                        'limits': {
                            'memory': '200Mi'
                        },
                        'requests': {
                            'memory': '100Mi'
                        }
                    },
                    'volumeMounts': [{
                        'mountPath': '/airflow/xcom',
                        'name': 'xcom'
                    }]
                }, {
                    'command': [
                        'sh', '-c',
                        'trap "exit 0" INT; while true; do sleep 30; done;'
                    ],
                    'image': 'alpine',
                    'name': 'airflow-xcom-sidecar',
                    'resources': {
                        'requests': {
                            'cpu': '1m'
                        }
                    },
                    'volumeMounts': [{
                        'mountPath': '/airflow/xcom',
                        'name': 'xcom'
                    }]
                }],
                'hostNetwork': False,
                'imagePullSecrets': [],
                'initContainers': [],
                'nodeSelector': {},
                'restartPolicy': 'Never',
                'securityContext': {},
                'serviceAccountName': 'default',
                'tolerations': [],
                'volumes': [{
                    'emptyDir': {},
                    'name': 'xcom'
                }]
            }
        }
        self.assertEqual(expected_dict, actual_pod)

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_pod_priority_class_name(
            self,
            mock_client,
            monitor_mock,
            start_mock  # pylint: disable=unused-argument
    ):
        """Tests the ability to assign a priorityClassName to the pod."""
        from airflow.utils.state import State

        priority_class_name = "medium-test"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            priority_class_name=priority_class_name,
        )

        monitor_mock.return_value = (State.SUCCESS, None)
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['priorityClassName'] = priority_class_name
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_name(self):
        pod_name_too_long = "a" * 221
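        # Kubernetes caps object names at 253 characters; 221 characters plus
        # the unique suffix the operator appends should trip name validation.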
        with self.assertRaises(AirflowException):
            KubernetesPodOperator(
                namespace='default',
                image="ubuntu:16.04",
                cmds=["bash", "-cx"],
                arguments=["echo 10"],
                labels={"foo": "bar"},
                name=pod_name_too_long,
                task_id="task" + self.get_current_task_name(),
                in_cluster=False,
                do_xcom_push=False,
            )

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    def test_on_kill(self, monitor_mock):
        from airflow.utils.state import State
        client = kube_client.get_kube_client(in_cluster=False)
        name = "test"
        namespace = "default"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["sleep 1000"],
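            # long-running command so the pod is still Running when on_kill fires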
            labels={"foo": "bar"},
            name="test",
            task_id=name,
            in_cluster=False,
            do_xcom_push=False,
            termination_grace_period=0,
        )
        context = create_context(k)
        monitor_mock.return_value = (State.SUCCESS, None)
        k.execute(context)
        name = k.pod.metadata.name
        pod = client.read_namespaced_pod(name=name, namespace=namespace)
        self.assertEqual(pod.status.phase, "Running")
        k.on_kill()
        with self.assertRaises(ApiException):
            pod = client.read_namespaced_pod(name=name, namespace=namespace)


class TestKubernetesWorkerConfiguration(unittest.TestCase):
    """
    Tests that the dags_volume_subpath/logs_volume_subpath configuration
    options are passed through to the worker pod config.
    """

    affinity_config = {
        'podAntiAffinity': {
            'requiredDuringSchedulingIgnoredDuringExecution': [{
                'topologyKey': 'kubernetes.io/hostname',
                'labelSelector': {
                    'matchExpressions': [{
                        'key': 'app',
                        'operator': 'In',
                        'values': ['airflow']
                    }]
                }
            }]
        }
    }

    tolerations_config = [{
        'key': 'dedicated',
        'operator': 'Equal',
        'value': 'airflow'
    }, {
        'key': 'prod',
        'operator': 'Exists'
    }]

    worker_annotations_config = {
        'iam.amazonaws.com/role': 'role-arn',
        'other/annotation': 'value'
    }

    def setUp(self):
        if AirflowKubernetesScheduler is None:
            self.skipTest("kubernetes python package is not installed")

        self.kube_config = mock.MagicMock()
        self.kube_config.airflow_home = '/'
        self.kube_config.airflow_dags = 'dags'
        self.kube_config.airflow_logs = 'logs'
        self.kube_config.dags_volume_subpath = None
        self.kube_config.logs_volume_subpath = None
        self.kube_config.dags_in_image = False
        self.kube_config.dags_folder = None
        self.kube_config.git_dags_folder_mount_point = None
        self.kube_config.kube_labels = {
            'dag_id': 'original_dag_id',
            'my_label': 'label_id'
        }
        self.api_client = ApiClient()

    def test_worker_configuration_no_subpaths(self):
        self.kube_config.dags_volume_claim = 'airflow-dags'
        self.kube_config.dags_folder = 'dags'
        worker_config = WorkerConfiguration(self.kube_config)
        volumes = worker_config._get_volumes()
        volume_mounts = worker_config._get_volume_mounts()
        for volume_or_mount in volumes + volume_mounts:
            if volume_or_mount.name != 'airflow-config':
                self.assertNotIn(
                    'subPath',
                    self.api_client.sanitize_for_serialization(
                        volume_or_mount), "subPath shouldn't be defined")

    @conf_vars({
        ('kubernetes', 'git_ssh_known_hosts_configmap_name'): 'airflow-configmap',
        ('kubernetes', 'git_ssh_key_secret_name'): 'airflow-secrets',
        ('kubernetes', 'git_user'): 'some-user',
        ('kubernetes', 'git_password'): 'some-password',
        ('kubernetes', 'git_repo'): 'git@github.com:apache/airflow.git',
        ('kubernetes', 'git_branch'): 'master',
        ('kubernetes', 'git_dags_folder_mount_point'): '/usr/local/airflow/dags',
        ('kubernetes', 'delete_worker_pods'): 'True',
        ('kubernetes', 'kube_client_request_args'): '{"_request_timeout" : [60,360]}',
    })
    def test_worker_configuration_auth_both_ssh_and_user(self):
        with self.assertRaisesRegex(
                AirflowConfigException,
                'either `git_user` and `git_password`.*'
                'or `git_ssh_key_secret_name`.*'
                'but not both$'):
            KubeConfig()

    def test_worker_with_subpaths(self):
        self.kube_config.dags_volume_subpath = 'dags'
        self.kube_config.logs_volume_subpath = 'logs'
        self.kube_config.dags_volume_claim = 'dags'
        self.kube_config.dags_folder = 'dags'
        worker_config = WorkerConfiguration(self.kube_config)
        volumes = worker_config._get_volumes()
        volume_mounts = worker_config._get_volume_mounts()

        for volume in volumes:
            self.assertNotIn(
                'subPath', self.api_client.sanitize_for_serialization(volume),
                "subPath isn't valid configuration for a volume")

        for volume_mount in volume_mounts:
            if volume_mount.name != 'airflow-config':
                self.assertIn(
                    'subPath',
                    self.api_client.sanitize_for_serialization(volume_mount),
                    "subPath should've been passed to volumeMount configuration"
                )

    def test_worker_generate_dag_volume_mount_path(self):
        self.kube_config.git_dags_folder_mount_point = '/root/airflow/git/dags'
        self.kube_config.dags_folder = '/root/airflow/dags'
        worker_config = WorkerConfiguration(self.kube_config)

        self.kube_config.dags_volume_claim = 'airflow-dags'
        self.kube_config.dags_volume_host = ''
        dag_volume_mount_path = worker_config.generate_dag_volume_mount_path()
        self.assertEqual(dag_volume_mount_path, self.kube_config.dags_folder)

        self.kube_config.dags_volume_claim = ''
        self.kube_config.dags_volume_host = '/host/airflow/dags'
        dag_volume_mount_path = worker_config.generate_dag_volume_mount_path()
        self.assertEqual(dag_volume_mount_path, self.kube_config.dags_folder)

        self.kube_config.dags_volume_claim = ''
        self.kube_config.dags_volume_host = ''
        dag_volume_mount_path = worker_config.generate_dag_volume_mount_path()
        self.assertEqual(dag_volume_mount_path,
                         self.kube_config.git_dags_folder_mount_point)

    def test_worker_environment_no_dags_folder(self):
        self.kube_config.airflow_configmap = ''
        self.kube_config.git_dags_folder_mount_point = ''
        self.kube_config.dags_folder = ''
        worker_config = WorkerConfiguration(self.kube_config)
        env = worker_config._get_environment()

        self.assertNotIn('AIRFLOW__CORE__DAGS_FOLDER', env)

    def test_worker_environment_when_dags_folder_specified(self):
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_dags_folder_mount_point = ''
        dags_folder = '/workers/path/to/dags'
        self.kube_config.dags_folder = dags_folder

        worker_config = WorkerConfiguration(self.kube_config)
        env = worker_config._get_environment()

        self.assertEqual(dags_folder, env['AIRFLOW__CORE__DAGS_FOLDER'])

    def test_worker_environment_dags_folder_using_git_sync(self):
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_sync_dest = 'repo'
        self.kube_config.git_subpath = 'dags'
        self.kube_config.git_dags_folder_mount_point = '/workers/path/to/dags'

        dags_folder = '{}/{}/{}'.format(
            self.kube_config.git_dags_folder_mount_point,
            self.kube_config.git_sync_dest, self.kube_config.git_subpath)

        worker_config = WorkerConfiguration(self.kube_config)
        env = worker_config._get_environment()

        self.assertEqual(dags_folder, env['AIRFLOW__CORE__DAGS_FOLDER'])

    def test_init_environment_using_git_sync_ssh_without_known_hosts(self):
        # Tests that the init-container environment created for git-sync SSH
        # authentication is correct when no known-hosts file is configured
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_ssh_key_secret_name = 'airflow-secrets'
        self.kube_config.git_ssh_known_hosts_configmap_name = None
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None

        worker_config = WorkerConfiguration(self.kube_config)
        init_containers = worker_config._get_init_containers()

        self.assertTrue(init_containers)  # check not empty
        env = init_containers[0].env

        self.assertIn(
            k8s.V1EnvVar(name='GIT_SSH_KEY_FILE', value='/etc/git-secret/ssh'),
            env)
        self.assertIn(k8s.V1EnvVar(name='GIT_KNOWN_HOSTS', value='false'), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_SYNC_SSH', value='true'), env)

    def test_init_environment_using_git_sync_ssh_with_known_hosts(self):
        # Tests that the init-container environment created for git-sync SSH
        # authentication is correct when a known-hosts file is configured
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_ssh_key_secret_name = 'airflow-secrets'
        self.kube_config.git_ssh_known_hosts_configmap_name = 'airflow-configmap'
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None

        worker_config = WorkerConfiguration(self.kube_config)
        init_containers = worker_config._get_init_containers()

        self.assertTrue(init_containers)  # check not empty
        env = init_containers[0].env

        self.assertIn(
            k8s.V1EnvVar(name='GIT_SSH_KEY_FILE', value='/etc/git-secret/ssh'),
            env)
        self.assertIn(k8s.V1EnvVar(name='GIT_KNOWN_HOSTS', value='true'), env)
        self.assertIn(
            k8s.V1EnvVar(name='GIT_SSH_KNOWN_HOSTS_FILE',
                         value='/etc/git-secret/known_hosts'), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_SYNC_SSH', value='true'), env)

    def test_init_environment_using_git_sync_user_without_known_hosts(self):
        # Tests that the init-container environment created for git-sync
        # user/password authentication is correct when no known-hosts file is configured
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_user = '******'
        self.kube_config.git_password = '******'
        self.kube_config.git_ssh_known_hosts_configmap_name = None
        self.kube_config.git_ssh_key_secret_name = None
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None

        worker_config = WorkerConfiguration(self.kube_config)
        init_containers = worker_config._get_init_containers()

        self.assertTrue(init_containers)  # check not empty
        env = init_containers[0].env

        self.assertNotIn(
            k8s.V1EnvVar(name='GIT_SSH_KEY_FILE', value='/etc/git-secret/ssh'),
            env)
        self.assertIn(k8s.V1EnvVar(name='GIT_SYNC_USERNAME', value='git_user'),
                      env)
        self.assertIn(
            k8s.V1EnvVar(name='GIT_SYNC_PASSWORD', value='git_password'), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_KNOWN_HOSTS', value='false'), env)
        self.assertNotIn(
            k8s.V1EnvVar(name='GIT_SSH_KNOWN_HOSTS_FILE',
                         value='/etc/git-secret/known_hosts'), env)
        self.assertNotIn(k8s.V1EnvVar(name='GIT_SYNC_SSH', value='true'), env)

    def test_init_environment_using_git_sync_user_with_known_hosts(self):
        # Tests that the init-container environment created for git-sync
        # user/password authentication is correct when a known-hosts file is configured
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_user = '******'
        self.kube_config.git_password = '******'
        self.kube_config.git_ssh_known_hosts_configmap_name = 'airflow-configmap'
        self.kube_config.git_ssh_key_secret_name = None
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None

        worker_config = WorkerConfiguration(self.kube_config)
        init_containers = worker_config._get_init_containers()

        self.assertTrue(init_containers)  # check not empty
        env = init_containers[0].env

        self.assertNotIn(
            k8s.V1EnvVar(name='GIT_SSH_KEY_FILE', value='/etc/git-secret/ssh'),
            env)
        self.assertIn(k8s.V1EnvVar(name='GIT_SYNC_USERNAME', value='git_user'),
                      env)
        self.assertIn(
            k8s.V1EnvVar(name='GIT_SYNC_PASSWORD', value='git_password'), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_KNOWN_HOSTS', value='true'), env)
        self.assertIn(
            k8s.V1EnvVar(name='GIT_SSH_KNOWN_HOSTS_FILE',
                         value='/etc/git-secret/known_hosts'), env)
        self.assertNotIn(k8s.V1EnvVar(name='GIT_SYNC_SSH', value='true'), env)

    def test_init_environment_using_git_sync_run_as_user_empty(self):
        # Tests that no securityContext is created in the init container when git_sync_run_as_user is empty

        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None
        self.kube_config.git_sync_run_as_user = ''

        worker_config = WorkerConfiguration(self.kube_config)
        init_containers = worker_config._get_init_containers()
        self.assertTrue(init_containers)  # check not empty

        self.assertIsNone(init_containers[0].security_context)

    def test_make_pod_run_as_user_0(self):
        # Tests that a pod created with run-as-user 0 actually gets that in its config
        self.kube_config.worker_run_as_user = 0
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None
        self.kube_config.worker_fs_group = None
        self.kube_config.git_dags_folder_mount_point = 'dags'
        self.kube_config.git_sync_dest = 'repo'
        self.kube_config.git_subpath = 'path'

        worker_config = WorkerConfiguration(self.kube_config)
        pod = worker_config.make_pod("default", str(uuid.uuid4()),
                                     "test_pod_id", "test_dag_id",
                                     "test_task_id", str(datetime.utcnow()), 1,
                                     "bash -c 'ls /'")

        self.assertEqual(0, pod.spec.security_context.run_as_user)

    def test_make_pod_git_sync_ssh_without_known_hosts(self):
        # Tests that the pod created for git-sync SSH authentication is correct when no known-hosts file is configured
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_ssh_key_secret_name = 'airflow-secrets'
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None
        self.kube_config.worker_fs_group = None
        self.kube_config.git_dags_folder_mount_point = 'dags'
        self.kube_config.git_sync_dest = 'repo'
        self.kube_config.git_subpath = 'path'

        worker_config = WorkerConfiguration(self.kube_config)

        pod = worker_config.make_pod("default", str(uuid.uuid4()),
                                     "test_pod_id", "test_dag_id",
                                     "test_task_id", str(datetime.utcnow()), 1,
                                     "bash -c 'ls /'")

        init_containers = worker_config._get_init_containers()
        git_ssh_key_file = next(
            (x.value
             for x in init_containers[0].env if x.name == 'GIT_SSH_KEY_FILE'),
            None)
        volume_mount_ssh_key = next(
            (x.mount_path for x in init_containers[0].volume_mounts
             if x.name == worker_config.git_sync_ssh_secret_volume_name), None)
        self.assertTrue(git_ssh_key_file)
        self.assertTrue(volume_mount_ssh_key)
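        # 65533 is the fs_group git-sync conventionally runs with, so the
        # mounted SSH secret stays readable by the sync container.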
        self.assertEqual(65533, pod.spec.security_context.fs_group)
        self.assertEqual(
            git_ssh_key_file, volume_mount_ssh_key,
            'The location where the git ssh secret is mounted'
            ' needs to be the same as the GIT_SSH_KEY_FILE path')

    def test_make_pod_git_sync_credentials_secret(self):
        # Tests that git_sync_credentials_secret is wired into the init container's environment
        self.kube_config.git_sync_credentials_secret = 'airflow-git-creds-secret'
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None
        self.kube_config.worker_fs_group = None
        self.kube_config.git_dags_folder_mount_point = 'dags'
        self.kube_config.git_sync_dest = 'repo'
        self.kube_config.git_subpath = 'path'

        worker_config = WorkerConfiguration(self.kube_config)

        pod = worker_config.make_pod("default", str(uuid.uuid4()),
                                     "test_pod_id", "test_dag_id",
                                     "test_task_id", str(datetime.utcnow()), 1,
                                     "bash -c 'ls /'")

        username_env = k8s.V1EnvVar(
            name='GIT_SYNC_USERNAME',
            value_from=k8s.V1EnvVarSource(
                secret_key_ref=k8s.V1SecretKeySelector(
                    name=self.kube_config.git_sync_credentials_secret,
                    key='GIT_SYNC_USERNAME')))
        password_env = k8s.V1EnvVar(
            name='GIT_SYNC_PASSWORD',
            value_from=k8s.V1EnvVarSource(
                secret_key_ref=k8s.V1SecretKeySelector(
                    name=self.kube_config.git_sync_credentials_secret,
                    key='GIT_SYNC_PASSWORD')))

        self.assertIn(
            username_env, pod.spec.init_containers[0].env,
            'The username env for git credentials did not get into the init container'
        )

        self.assertIn(
            password_env, pod.spec.init_containers[0].env,
            'The password env for git credentials did not get into the init container'
        )

    def test_make_pod_git_sync_ssh_with_known_hosts(self):
        # Tests that the pod created for git-sync SSH authentication is correct when a known-hosts file is configured
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_ssh_key_secret_name = 'airflow-secrets'
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None

        worker_config = WorkerConfiguration(self.kube_config)

        init_containers = worker_config._get_init_containers()
        git_ssh_known_hosts_file = next(
            (x.value for x in init_containers[0].env
             if x.name == 'GIT_SSH_KNOWN_HOSTS_FILE'), None)

        volume_mount_ssh_known_hosts_file = next(
            (x.mount_path for x in init_containers[0].volume_mounts
             if x.name == worker_config.git_sync_ssh_known_hosts_volume_name),
            None)
        self.assertTrue(git_ssh_known_hosts_file)
        self.assertTrue(volume_mount_ssh_known_hosts_file)
        self.assertEqual(
            git_ssh_known_hosts_file, volume_mount_ssh_known_hosts_file,
            'The location where the git known hosts file is mounted'
            ' needs to be the same as the GIT_SSH_KNOWN_HOSTS_FILE path')

    def test_make_pod_with_empty_executor_config(self):
        self.kube_config.kube_affinity = self.affinity_config
        self.kube_config.kube_tolerations = self.tolerations_config
        self.kube_config.kube_annotations = self.worker_annotations_config
        self.kube_config.dags_folder = 'dags'
        worker_config = WorkerConfiguration(self.kube_config)

        pod = worker_config.make_pod("default", str(uuid.uuid4()),
                                     "test_pod_id", "test_dag_id",
                                     "test_task_id", str(datetime.utcnow()), 1,
                                     "bash -c 'ls /'")

        self.assertTrue(pod.spec.affinity['podAntiAffinity'] is not None)
        self.assertEqual(
            'app', pod.spec.affinity['podAntiAffinity']
            ['requiredDuringSchedulingIgnoredDuringExecution'][0]
            ['labelSelector']['matchExpressions'][0]['key'])

        self.assertEqual(2, len(pod.spec.tolerations))
        self.assertEqual('prod', pod.spec.tolerations[1]['key'])
        self.assertEqual('role-arn',
                         pod.metadata.annotations['iam.amazonaws.com/role'])
        self.assertEqual('value', pod.metadata.annotations['other/annotation'])

    def test_make_pod_with_executor_config(self):
        self.kube_config.dags_folder = 'dags'
        worker_config = WorkerConfiguration(self.kube_config)
        config_pod = PodGenerator(
            image='',
            affinity=self.affinity_config,
            tolerations=self.tolerations_config,
        ).gen_pod()

        pod = worker_config.make_pod("default", str(uuid.uuid4()),
                                     "test_pod_id", "test_dag_id",
                                     "test_task_id", str(datetime.utcnow()), 1,
                                     "bash -c 'ls /'")

        result = PodGenerator.reconcile_pods(pod, config_pod)
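        # config_pod (the executor_config) is overlaid onto the base worker pod,
        # so its affinity and tolerations should win in the merged result.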

        self.assertTrue(result.spec.affinity['podAntiAffinity'] is not None)
        self.assertEqual(
            'app', result.spec.affinity['podAntiAffinity']
            ['requiredDuringSchedulingIgnoredDuringExecution'][0]
            ['labelSelector']['matchExpressions'][0]['key'])

        self.assertEqual(2, len(result.spec.tolerations))
        self.assertEqual('prod', result.spec.tolerations[1]['key'])

    def test_worker_pvc_dags(self):
        # Tests the persistent volume config created when `dags_volume_claim` is set
        self.kube_config.dags_volume_claim = 'airflow-dags'
        self.kube_config.dags_folder = 'dags'
        worker_config = WorkerConfiguration(self.kube_config)
        volumes = worker_config._get_volumes()
        volume_mounts = worker_config._get_volume_mounts()

        init_containers = worker_config._get_init_containers()

        dag_volume = [
            volume for volume in volumes if volume.name == 'airflow-dags'
        ]
        dag_volume_mount = [
            mount for mount in volume_mounts if mount.name == 'airflow-dags'
        ]

        self.assertEqual('airflow-dags',
                         dag_volume[0].persistent_volume_claim.claim_name)
        self.assertEqual(1, len(dag_volume_mount))
        self.assertTrue(dag_volume_mount[0].read_only)
        self.assertEqual(0, len(init_containers))

    def test_worker_git_dags(self):
        # Tests persistence volume config created when `git_repo` is set
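        # A git-sync init container clones the repo into a shared emptyDir
        # volume, which the worker container then mounts read-only.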
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_folder = '/usr/local/airflow/dags'
        self.kube_config.worker_dags_folder = '/usr/local/airflow/dags'

        self.kube_config.git_sync_container_repository = 'gcr.io/google-containers/git-sync-amd64'
        self.kube_config.git_sync_container_tag = 'v2.0.5'
        self.kube_config.git_sync_container = 'gcr.io/google-containers/git-sync-amd64:v2.0.5'
        self.kube_config.git_sync_init_container_name = 'git-sync-clone'
        self.kube_config.git_subpath = 'dags_folder'
        self.kube_config.git_sync_root = '/git'
        self.kube_config.git_sync_run_as_user = 65533
        self.kube_config.git_dags_folder_mount_point = '/usr/local/airflow/dags/repo/dags_folder'

        worker_config = WorkerConfiguration(self.kube_config)
        volumes = worker_config._get_volumes()
        volume_mounts = worker_config._get_volume_mounts()

        dag_volume = [
            volume for volume in volumes if volume.name == 'airflow-dags'
        ]
        dag_volume_mount = [
            mount for mount in volume_mounts if mount.name == 'airflow-dags'
        ]

        self.assertIsNotNone(dag_volume[0].empty_dir)
        self.assertEqual(self.kube_config.git_dags_folder_mount_point,
                         dag_volume_mount[0].mount_path)
        self.assertTrue(dag_volume_mount[0].read_only)

        init_container = worker_config._get_init_containers()[0]
        init_container_volume_mount = [
            mount for mount in init_container.volume_mounts
            if mount.name == 'airflow-dags'
        ]

        self.assertEqual('git-sync-clone', init_container.name)
        self.assertEqual('gcr.io/google-containers/git-sync-amd64:v2.0.5',
                         init_container.image)
        self.assertEqual(1, len(init_container_volume_mount))
        self.assertFalse(init_container_volume_mount[0].read_only)
        self.assertEqual(65533, init_container.security_context.run_as_user)

    def test_worker_container_dags(self):
        # Tests that the 'airflow-dags' persistence volume is NOT created when `dags_in_image` is set
        self.kube_config.dags_in_image = True
        self.kube_config.dags_folder = 'dags'
        worker_config = WorkerConfiguration(self.kube_config)
        volumes = worker_config._get_volumes()
        volume_mounts = worker_config._get_volume_mounts()

        dag_volume = [
            volume for volume in volumes if volume.name == 'airflow-dags'
        ]
        dag_volume_mount = [
            mount for mount in volume_mounts if mount.name == 'airflow-dags'
        ]

        init_containers = worker_config._get_init_containers()

        self.assertEqual(0, len(dag_volume))
        self.assertEqual(0, len(dag_volume_mount))
        self.assertEqual(0, len(init_containers))

    def test_kubernetes_environment_variables(self):
        # Tests the kubernetes environment variables get copied into the worker pods
        input_environment = {'ENVIRONMENT': 'prod', 'LOG_LEVEL': 'warning'}
        self.kube_config.kube_env_vars = input_environment
        worker_config = WorkerConfiguration(self.kube_config)
        env = worker_config._get_environment()
        for key in input_environment:
            self.assertIn(key, env)
            self.assertIn(input_environment[key], env.values())

        core_executor = 'AIRFLOW__CORE__EXECUTOR'
        input_environment = {core_executor: 'NotLocalExecutor'}
        self.kube_config.kube_env_vars = input_environment
        worker_config = WorkerConfiguration(self.kube_config)
        env = worker_config._get_environment()
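        # The executor setting cannot be overridden: each worker pod runs its
        # single task with LocalExecutor, so 'NotLocalExecutor' is discarded.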
        self.assertEqual(env[core_executor], 'LocalExecutor')

    def test_get_secrets(self):
        # Test when secretRef is None and kube_secrets is not empty
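        # (kube_secrets maps ENV_VAR -> 'secret_name=secret_key')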
        self.kube_config.kube_secrets = {
            'AWS_SECRET_KEY': 'airflow-secret=aws_secret_key',
            'POSTGRES_PASSWORD': 'airflow-secret=postgres_credentials'
        }
        self.kube_config.env_from_secret_ref = None
        worker_config = WorkerConfiguration(self.kube_config)
        secrets = worker_config._get_secrets()
        secrets.sort(key=lambda secret: secret.deploy_target)
        expected = [
            Secret('env', 'AWS_SECRET_KEY', 'airflow-secret',
                   'aws_secret_key'),
            Secret('env', 'POSTGRES_PASSWORD', 'airflow-secret',
                   'postgres_credentials')
        ]
        self.assertListEqual(expected, secrets)

        # Test when secret is not empty and kube_secrets is empty dict
        self.kube_config.kube_secrets = {}
        self.kube_config.env_from_secret_ref = 'secret_a,secret_b'
        worker_config = WorkerConfiguration(self.kube_config)
        secrets = worker_config._get_secrets()
        expected = [
            Secret('env', None, 'secret_a'),
            Secret('env', None, 'secret_b')
        ]
        self.assertListEqual(expected, secrets)

    def test_get_env_from(self):
        # Test when configmap is empty
        self.kube_config.env_from_configmap_ref = ''
        worker_config = WorkerConfiguration(self.kube_config)
        configmaps = worker_config._get_env_from()
        self.assertListEqual([], configmaps)

        # test when configmap is not empty
        self.kube_config.env_from_configmap_ref = 'configmap_a,configmap_b'
        self.kube_config.env_from_secret_ref = 'secretref_a,secretref_b'
        worker_config = WorkerConfiguration(self.kube_config)
        configmaps = worker_config._get_env_from()
        self.assertListEqual([
            k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(
                name='configmap_a')),
            k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(
                name='configmap_b')),
            k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(
                name='secretref_a')),
            k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(
                name='secretref_b'))
        ], configmaps)

    def test_get_labels(self):
        worker_config = WorkerConfiguration(self.kube_config)
        labels = worker_config._get_labels(
            {'my_kube_executor_label': 'kubernetes'}, {
                'dag_id': 'override_dag_id',
            })
        self.assertEqual(
            {
                'my_label': 'label_id',
                'dag_id': 'override_dag_id',
                'my_kube_executor_label': 'kubernetes'
            }, labels)


class TestKubernetesPodOperatorSystem(unittest.TestCase):
    def setUp(self):
        self.maxDiff = None  # pylint: disable=invalid-name
        self.api_client = ApiClient()
        self.expected_pod = {
            'apiVersion': 'v1',
            'kind': 'Pod',
            'metadata': {
                'namespace': 'default',
                'name': ANY,
                'annotations': {},
                'labels': {
                    'foo': 'bar',
                    'kubernetes_pod_operator': 'True',
                    'airflow_version': airflow_version.replace('+', '-'),
                    'execution_date': '2016-01-01T0100000100-a2f50a31f',
                    'dag_id': 'dag',
                    'task_id': 'task',
                    'try_number': '1'
                },
            },
            'spec': {
                'affinity': {},
                'containers': [{
                    'image': 'ubuntu:16.04',
                    'args': ["echo 10"],
                    'command': ["bash", "-cx"],
                    'env': [],
                    'imagePullPolicy': 'IfNotPresent',
                    'envFrom': [],
                    'name': 'base',
                    'ports': [],
                    'volumeMounts': [],
                }],
                'hostNetwork': False,
                'imagePullSecrets': [],
                'initContainers': [],
                'nodeSelector': {},
                'restartPolicy': 'Never',
                'securityContext': {},
                'serviceAccountName': 'default',
                'tolerations': [],
                'volumes': [],
            }
        }

    def tearDown(self) -> None:
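        # Delete every pod in the default namespace so one test's pods cannot
        # leak into the next.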
        client = kube_client.get_kube_client(in_cluster=False)
        client.delete_collection_namespaced_pod(namespace="default")

    def test_do_xcom_push_defaults_false(self):
        new_config_path = '/tmp/kube_config'
        old_config_path = os.path.expanduser('~/.kube/config')
        shutil.copy(old_config_path, new_config_path)

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            config_file=new_config_path,
        )
        self.assertFalse(k.do_xcom_push)

    def test_config_path_move(self):
        new_config_path = '/tmp/kube_config'
        old_config_path = os.path.expanduser('~/.kube/config')
        shutil.copy(old_config_path, new_config_path)

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test1",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            config_file=new_config_path,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.assertEqual(self.expected_pod, actual_pod)

    def test_working_pod(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
        self.assertEqual(self.expected_pod['metadata']['labels'],
                         actual_pod['metadata']['labels'])

    def test_delete_operator_pod(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            is_delete_operator_pod=True,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
        self.assertEqual(self.expected_pod['metadata']['labels'],
                         actual_pod['metadata']['labels'])

    def test_pod_hostnetwork(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            hostnetwork=True,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['hostNetwork'] = True
        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
        self.assertEqual(self.expected_pod['metadata']['labels'],
                         actual_pod['metadata']['labels'])

    def test_pod_dnspolicy(self):
        dns_policy = "ClusterFirstWithHostNet"
        k = KubernetesPodOperator(namespace='default',
                                  image="ubuntu:16.04",
                                  cmds=["bash", "-cx"],
                                  arguments=["echo 10"],
                                  labels={"foo": "bar"},
                                  name="test",
                                  task_id="task",
                                  in_cluster=False,
                                  do_xcom_push=False,
                                  hostnetwork=True,
                                  dnspolicy=dns_policy)
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['hostNetwork'] = True
        self.expected_pod['spec']['dnsPolicy'] = dns_policy
        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
        self.assertEqual(self.expected_pod['metadata']['labels'],
                         actual_pod['metadata']['labels'])

    def test_pod_schedulername(self):
        scheduler_name = "default-scheduler"
        k = KubernetesPodOperator(namespace="default",
                                  image="ubuntu:16.04",
                                  cmds=["bash", "-cx"],
                                  arguments=["echo 10"],
                                  labels={"foo": "bar"},
                                  name="test",
                                  task_id="task",
                                  in_cluster=False,
                                  do_xcom_push=False,
                                  schedulername=scheduler_name)
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['schedulerName'] = scheduler_name
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_node_selectors(self):
        node_selectors = {'beta.kubernetes.io/os': 'linux'}
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            node_selectors=node_selectors,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['nodeSelector'] = node_selectors
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_resources(self):
        resources = {
            'limit_cpu': 0.25,
            'limit_memory': '64Mi',
            'limit_ephemeral_storage': '2Gi',
            'request_cpu': '250m',
            'request_memory': '64Mi',
            'request_ephemeral_storage': '1Gi',
        }
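        # The operator translates these flat request_*/limit_* keys into the
        # container's resources (requests/limits), as asserted below.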
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            resources=resources,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['resources'] = {
            'requests': {
                'memory': '64Mi',
                'cpu': '250m',
                'ephemeral-storage': '1Gi'
            },
            'limits': {
                'memory': '64Mi',
                'cpu': 0.25,
                'nvidia.com/gpu': None,
                'ephemeral-storage': '2Gi'
            }
        }
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_affinity(self):
        affinity = {
            'nodeAffinity': {
                'requiredDuringSchedulingIgnoredDuringExecution': {
                    'nodeSelectorTerms': [{
                        'matchExpressions': [{
                            'key': 'beta.kubernetes.io/os',
                            'operator': 'In',
                            'values': ['linux']
                        }]
                    }]
                }
            }
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            affinity=affinity,
        )
        context = create_context(k)
        k.execute(context=context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['affinity'] = affinity
        self.assertEqual(self.expected_pod, actual_pod)

    def test_port(self):
        port = Port('http', 80)

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            ports=[port],
        )
        context = create_context(k)
        k.execute(context=context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['ports'] = [{
            'name': 'http',
            'containerPort': 80
        }]
        self.assertEqual(self.expected_pod, actual_pod)

    def test_volume_mount(self):
        with mock.patch.object(PodLauncher, 'log') as mock_logger:
            volume_mount = VolumeMount('test-volume',
                                       mount_path='/tmp/test_volume',
                                       sub_path=None,
                                       read_only=False)

            volume_config = {
                'persistentVolumeClaim': {
                    'claimName': 'test-volume'
                }
            }
            volume = Volume(name='test-volume', configs=volume_config)
            args = [
                "echo \"retrieved from mount\" > /tmp/test_volume/test.txt "
                "&& cat /tmp/test_volume/test.txt"
            ]
            k = KubernetesPodOperator(
                namespace='default',
                image="ubuntu:16.04",
                cmds=["bash", "-cx"],
                arguments=args,
                labels={"foo": "bar"},
                volume_mounts=[volume_mount],
                volumes=[volume],
                name="test",
                task_id="task",
                in_cluster=False,
                do_xcom_push=False,
            )
            context = create_context(k)
            k.execute(context=context)
            mock_logger.info.assert_any_call(b"retrieved from mount\n")
            actual_pod = self.api_client.sanitize_for_serialization(k.pod)
            self.expected_pod['spec']['containers'][0]['args'] = args
            self.expected_pod['spec']['containers'][0]['volumeMounts'] = [{
                'name': 'test-volume',
                'mountPath': '/tmp/test_volume',
                'readOnly': False
            }]
            self.expected_pod['spec']['volumes'] = [{
                'name': 'test-volume',
                'persistentVolumeClaim': {
                    'claimName': 'test-volume'
                }
            }]
            self.assertEqual(self.expected_pod, actual_pod)

    def test_run_as_user_root(self):
        security_context = {
            'securityContext': {
                'runAsUser': 0,
            }
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        self.assertEqual(self.expected_pod, actual_pod)

    def test_run_as_user_non_root(self):
        security_context = {
            'securityContext': {
                'runAsUser': 1000,
            }
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        self.assertEqual(self.expected_pod, actual_pod)

    def test_fs_group(self):
        security_context = {
            'securityContext': {
                'fsGroup': 1000,
            }
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        self.assertEqual(self.expected_pod, actual_pod)

    def test_faulty_image(self):
        bad_image_name = "foobar"
        k = KubernetesPodOperator(
            namespace='default',
            image=bad_image_name,
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            startup_timeout_seconds=5,
        )
        context = create_context(k)
        with self.assertRaises(AirflowException):
            k.execute(context)
        # The assertions below must run after execute() raises; the pod object
        # is still populated from the failed launch.
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['image'] = bad_image_name
        self.assertEqual(self.expected_pod, actual_pod)

    def test_faulty_service_account(self):
        bad_service_account_name = "foobar"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            startup_timeout_seconds=5,
            service_account_name=bad_service_account_name,
        )
        context = create_context(k)
        with self.assertRaises(ApiException):
            k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['serviceAccountName'] = bad_service_account_name
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_failure(self):
        """
            Tests that the task fails when a pod reports a failure
        """
        bad_internal_command = ["foobar 10 "]
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=bad_internal_command,
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )
        context = create_context(k)
        with self.assertRaises(AirflowException):
            k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['args'] = bad_internal_command
        self.assertEqual(self.expected_pod, actual_pod)

    def test_xcom_push(self):
        return_value = '{"foo": "bar"\n, "buzz": 2}'
        args = ['echo \'{}\' > /airflow/xcom/return.json'.format(return_value)]
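        # With do_xcom_push=True the operator adds an emptyDir mounted at
        # /airflow/xcom plus a sidecar container that reads return.json back
        # as the task's XCom result.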
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=args,
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=True,
        )
        context = create_context(k)
        self.assertEqual(k.execute(context), json.loads(return_value))
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        volume = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME)
        volume_mount = self.api_client.sanitize_for_serialization(
            PodDefaults.VOLUME_MOUNT)
        container = self.api_client.sanitize_for_serialization(
            PodDefaults.SIDECAR_CONTAINER)
        self.expected_pod['spec']['containers'][0]['args'] = args
        self.expected_pod['spec']['containers'][0]['volumeMounts'].insert(
            0, volume_mount)  # noqa
        self.expected_pod['spec']['volumes'].insert(0, volume)
        self.expected_pod['spec']['containers'].append(container)
        self.assertEqual(self.expected_pod, actual_pod)

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
        # GIVEN
        from airflow.utils.state import State

        configmap = 'test-configmap'
        # WHEN
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            configmaps=[configmap],
        )
        # THEN
        mock_monitor.return_value = (State.SUCCESS, None)
        context = create_context(k)
        k.execute(context)
        self.assertEqual(
            mock_start.call_args[0][0].spec.containers[0].env_from, [
                k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(
                    name=configmap))
            ])

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock):
        # GIVEN
        from airflow.utils.state import State
        secret_ref = 'secret_name'
        secrets = [Secret('env', None, secret_ref)]
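        # A Secret with deploy_type 'env' and no deploy_target becomes an
        # envFrom secretRef on the container, pulling in every key of the
        # named Kubernetes secret.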
        # WHEN
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            secrets=secrets,
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )
        # THEN
        monitor_mock.return_value = (State.SUCCESS, None)
        context = create_context(k)
        k.execute(context)
        self.assertEqual(
            start_mock.call_args[0][0].spec.containers[0].env_from, [
                k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(
                    name=secret_ref))
            ])

    def test_init_container(self):
        # GIVEN
        volume_mounts = [
            k8s.V1VolumeMount(mount_path='/etc/foo',
                              name='test-volume',
                              sub_path=None,
                              read_only=True)
        ]

        init_environments = [
            k8s.V1EnvVar(name='key1', value='value1'),
            k8s.V1EnvVar(name='key2', value='value2')
        ]

        init_container = k8s.V1Container(name="init-container",
                                         image="ubuntu:16.04",
                                         env=init_environments,
                                         volume_mounts=volume_mounts,
                                         command=["bash", "-cx"],
                                         args=["echo 10"])

        volume_config = {'persistentVolumeClaim': {'claimName': 'test-volume'}}
        volume = Volume(name='test-volume', configs=volume_config)

        expected_init_container = {
            'name': 'init-container',
            'image': 'ubuntu:16.04',
            'command': ['bash', '-cx'],
            'args': ['echo 10'],
            'env': [{
                'name': 'key1',
                'value': 'value1'
            }, {
                'name': 'key2',
                'value': 'value2'
            }],
            'volumeMounts': [{
                'mountPath': '/etc/foo',
                'name': 'test-volume',
                'readOnly': True
            }],
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            volumes=[volume],
            init_containers=[init_container],
            in_cluster=False,
            do_xcom_push=False,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['initContainers'] = [expected_init_container]
        self.expected_pod['spec']['volumes'] = [{
            'name': 'test-volume',
            'persistentVolumeClaim': {
                'claimName': 'test-volume'
            }
        }]
        self.assertEqual(self.expected_pod, actual_pod)

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_pod_template_file(self, mock_client, monitor_mock, start_mock):  # pylint: disable=unused-argument
        from airflow.utils.state import State
        k = KubernetesPodOperator(
            task_id='task',
            pod_template_file='tests/kubernetes/pod.yaml',
            do_xcom_push=True)
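        # do_xcom_push=True appends the airflow-xcom-sidecar container and the
        # shared 'xcom' emptyDir volume to the pod loaded from the template
        # file, as asserted below.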
        monitor_mock.return_value = (State.SUCCESS, None)
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.assertEqual(
            {
                'apiVersion': 'v1',
                'kind': 'Pod',
                'metadata': {
                    'name': ANY,
                    'namespace': 'mem-example'
                },
                'spec': {
                    'volumes': [{
                        'name': 'xcom',
                        'emptyDir': {}
                    }],
                    'containers': [{
                        'args': ['--vm', '1', '--vm-bytes', '150M',
                                 '--vm-hang', '1'],
                        'command': ['stress'],
                        'image': 'polinux/stress',
                        'name': 'memory-demo-ctr',
                        'resources': {
                            'limits': {
                                'memory': '200Mi'
                            },
                            'requests': {
                                'memory': '100Mi'
                            }
                        },
                        'volumeMounts': [{
                            'name': 'xcom',
                            'mountPath': '/airflow/xcom'
                        }]
                    }, {
                        'name': 'airflow-xcom-sidecar',
                        'image': 'alpine',
                        'command': ['sh', '-c', PodDefaults.XCOM_CMD],
                        'volumeMounts': [{
                            'name': 'xcom',
                            'mountPath': '/airflow/xcom'
                        }],
                        'resources': {
                            'requests': {
                                'cpu': '1m'
                            }
                        },
                    }],
                }
            }, actual_pod)

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_pod_priority_class_name(self, mock_client, monitor_mock,
                                     start_mock):  # pylint: disable=unused-argument
        """Test ability to assign priorityClassName to pod

        """
        from airflow.utils.state import State

        priority_class_name = "medium-test"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            priority_class_name=priority_class_name,
        )

        monitor_mock.return_value = (State.SUCCESS, None)
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['priorityClassName'] = priority_class_name
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_name(self):
        pod_name_too_long = "a" * 221
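        # A name beyond the operator's length limit should be rejected at
        # construction time, before any pod is launched.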
        with self.assertRaises(AirflowException):
            KubernetesPodOperator(
                namespace='default',
                image="ubuntu:16.04",
                cmds=["bash", "-cx"],
                arguments=["echo 10"],
                labels={"foo": "bar"},
                name=pod_name_too_long,
                task_id="task",
                in_cluster=False,
                do_xcom_push=False,
            )


class TestKubernetesPodOperator(unittest.TestCase):

    def setUp(self):
        self.maxDiff = None  # pylint: disable=invalid-name
        self.api_client = ApiClient()
        self.expected_pod = {
            'apiVersion': 'v1',
            'kind': 'Pod',
            'metadata': {
                'namespace': 'default',
                'name': ANY,
                'annotations': {},
                'labels': {
                    'foo': 'bar', 'kubernetes_pod_operator': 'True',
                    'airflow_version': airflow_version.replace('+', '-')
                }
            },
            'spec': {
                'affinity': {},
                'containers': [{
                    'image': 'ubuntu:16.04',
                    'args': ["echo 10"],
                    'command': ["bash", "-cx"],
                    'env': [],
                    'imagePullPolicy': 'IfNotPresent',
                    'envFrom': [],
                    'name': 'base',
                    'ports': [],
                    'resources': {'limits': {'cpu': None,
                                             'memory': None,
                                             'nvidia.com/gpu': None},
                                  'requests': {'cpu': None,
                                               'memory': None}},
                    'volumeMounts': [],
                }],
                'hostNetwork': False,
                'imagePullSecrets': [],
                'nodeSelector': {},
                'restartPolicy': 'Never',
                'securityContext': {},
                'serviceAccountName': 'default',
                'tolerations': [],
                'volumes': [],
            }
        }

    def test_config_path_move(self):
        new_config_path = '/tmp/kube_config'
        old_config_path = os.path.expanduser('~/.kube/config')
        shutil.copy(old_config_path, new_config_path)

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            config_file=new_config_path,
        )
        k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.assertEqual(self.expected_pod, actual_pod)

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_config_path(self, client_mock, launcher_mock):
        from airflow.utils.state import State

        file_path = "/tmp/fake_file"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            config_file=file_path,
            cluster_context='default',
        )
        launcher_mock.return_value = (State.SUCCESS, None)
        k.execute(None)
        client_mock.assert_called_once_with(
            in_cluster=False,
            cluster_context='default',
            config_file=file_path,
        )

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_image_pull_secrets_correctly_set(self, mock_client, launcher_mock):
        from airflow.utils.state import State

        fake_pull_secrets = "fakeSecret"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            image_pull_secrets=fake_pull_secrets,
            cluster_context='default',
        )
        launcher_mock.return_value = (State.SUCCESS, None)
        k.execute(None)
        self.assertEqual(
            launcher_mock.call_args[0][0].spec.image_pull_secrets,
            [k8s.V1LocalObjectReference(name=fake_pull_secrets)]
        )

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.delete_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_pod_delete_even_on_launcher_error(self, mock_client, delete_pod_mock, run_pod_mock):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            cluster_context='default',
            is_delete_operator_pod=True,
        )
        run_pod_mock.side_effect = AirflowException('fake failure')
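        # Even though run_pod raises, is_delete_operator_pod=True must still
        # trigger pod deletion in the operator's cleanup path.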
        with self.assertRaises(AirflowException):
            k.execute(None)
        assert delete_pod_mock.called

    def test_working_pod(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )
        k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.assertEqual(self.expected_pod, actual_pod)

    def test_delete_operator_pod(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            is_delete_operator_pod=True,
        )
        k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_hostnetwork(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            hostnetwork=True,
        )
        k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['hostNetwork'] = True
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_dnspolicy(self):
        dns_policy = "ClusterFirstWithHostNet"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            hostnetwork=True,
            dnspolicy=dns_policy
        )
        k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['hostNetwork'] = True
        self.expected_pod['spec']['dnsPolicy'] = dns_policy
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_node_selectors(self):
        node_selectors = {
            'beta.kubernetes.io/os': 'linux'
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            node_selectors=node_selectors,
        )
        k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['nodeSelector'] = node_selectors
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_resources(self):
        resources = {
            'limit_cpu': 0.25,
            'limit_memory': '64Mi',
            'request_cpu': '250m',
            'request_memory': '64Mi',
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            resources=resources,
        )
        k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['resources'] = {
            'requests': {
                'memory': '64Mi',
                'cpu': '250m'
            },
            'limits': {
                'memory': '64Mi',
                'cpu': 0.25,
                'nvidia.com/gpu': None
            }
        }
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_affinity(self):
        affinity = {
            'nodeAffinity': {
                'requiredDuringSchedulingIgnoredDuringExecution': {
                    'nodeSelectorTerms': [
                        {
                            'matchExpressions': [
                                {
                                    'key': 'beta.kubernetes.io/os',
                                    'operator': 'In',
                                    'values': ['linux']
                                }
                            ]
                        }
                    ]
                }
            }
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            affinity=affinity,
        )
        k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['affinity'] = affinity
        self.assertEqual(self.expected_pod, actual_pod)

    def test_port(self):
        port = Port('http', 80)

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            ports=[port],
        )
        k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['ports'] = [{
            'name': 'http',
            'containerPort': 80
        }]
        self.assertEqual(self.expected_pod, actual_pod)

    def test_volume_mount(self):
        with mock.patch.object(PodLauncher, 'log') as mock_logger:
            volume_mount = VolumeMount('test-volume',
                                       mount_path='/root/mount_file',
                                       sub_path=None,
                                       read_only=True)

            volume_config = {
                'persistentVolumeClaim':
                    {
                        'claimName': 'test-volume'
                    }
            }
            volume = Volume(name='test-volume', configs=volume_config)
            args = ["cat /root/mount_file/test.txt"]
            k = KubernetesPodOperator(
                namespace='default',
                image="ubuntu:16.04",
                cmds=["bash", "-cx"],
                arguments=args,
                labels={"foo": "bar"},
                volume_mounts=[volume_mount],
                volumes=[volume],
                name="test",
                task_id="task",
                in_cluster=False,
                do_xcom_push=False,
            )
            k.execute(None)
            mock_logger.info.assert_any_call(b"retrieved from mount\n")
            actual_pod = self.api_client.sanitize_for_serialization(k.pod)
            self.expected_pod['spec']['containers'][0]['args'] = args
            self.expected_pod['spec']['containers'][0]['volumeMounts'] = [{
                'name': 'test-volume',
                'mountPath': '/root/mount_file',
                'readOnly': True
            }]
            self.expected_pod['spec']['volumes'] = [{
                'name': 'test-volume',
                'persistentVolumeClaim': {
                    'claimName': 'test-volume'
                }
            }]
            self.assertEqual(self.expected_pod, actual_pod)

    def test_run_as_user_root(self):
        security_context = {
            'securityContext': {
                'runAsUser': 0,
            }
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        self.assertEqual(self.expected_pod, actual_pod)

    def test_run_as_user_non_root(self):
        security_context = {
            'securityContext': {
                'runAsUser': 1000,
            }
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        self.assertEqual(self.expected_pod, actual_pod)

    def test_fs_group(self):
        security_context = {
            'securityContext': {
                'fsGroup': 1000,
            }
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        self.assertEqual(self.expected_pod, actual_pod)

    def test_faulty_image(self):
        bad_image_name = "foobar"
        k = KubernetesPodOperator(
            namespace='default',
            image=bad_image_name,
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            startup_timeout_seconds=5,
        )
        with self.assertRaises(AirflowException):
            k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['image'] = bad_image_name
        self.assertEqual(self.expected_pod, actual_pod)

    def test_faulty_service_account(self):
        bad_service_account_name = "foobar"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            startup_timeout_seconds=5,
            service_account_name=bad_service_account_name,
        )
        with self.assertRaises(ApiException):
            k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['serviceAccountName'] = bad_service_account_name
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_failure(self):
        """
            Tests that the task fails when a pod reports a failure
        """
        bad_internal_command = ["foobar 10 "]
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=bad_internal_command,
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )
        with self.assertRaises(AirflowException):
            k.execute(None)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['args'] = bad_internal_command
        self.assertEqual(self.expected_pod, actual_pod)

    def test_xcom_push(self):
        return_value = '{"foo": "bar"\n, "buzz": 2}'
        args = ['echo \'{}\' > /airflow/xcom/return.json'.format(return_value)]
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=args,
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=True,
        )
        self.assertEqual(k.execute(None), json.loads(return_value))
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        volume = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME)
        volume_mount = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME_MOUNT)
        container = self.api_client.sanitize_for_serialization(PodDefaults.SIDECAR_CONTAINER)
        self.expected_pod['spec']['containers'][0]['args'] = args
        self.expected_pod['spec']['containers'][0]['volumeMounts'].insert(0, volume_mount)
        self.expected_pod['spec']['volumes'].insert(0, volume)
        self.expected_pod['spec']['containers'].append(container)
        self.assertEqual(self.expected_pod, actual_pod)

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_envs_from_configmaps(self, mock_client, mock_launcher):
        # GIVEN
        from airflow.utils.state import State

        configmap = 'test-configmap'
        # WHEN
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            configmaps=[configmap],
        )
        # THEN
        mock_launcher.return_value = (State.SUCCESS, None)
        k.execute(None)
        self.assertEqual(
            mock_launcher.call_args[0][0].spec.containers[0].env_from,
            [k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(
                name=configmap
            ))]
        )

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_envs_from_secrets(self, mock_client, launcher_mock):
        # GIVEN
        from airflow.utils.state import State
        secret_ref = 'secret_name'
        secrets = [Secret('env', None, secret_ref)]
        # WHEN
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            secrets=secrets,
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )
        # THEN
        launcher_mock.return_value = (State.SUCCESS, None)
        k.execute(None)
        self.assertEqual(
            launcher_mock.call_args[0][0].spec.containers[0].env_from,
            [k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(
                name=secret_ref
            ))]
        )

def serialize_pod(pod):
    """
    Converts a k8s.V1Pod into a JSON-serializable dict.
    """
    api_client = ApiClient()
    return api_client.sanitize_for_serialization(pod)
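

# Hedged usage sketch for serialize_pod (illustrative only, not part of the
# original suite; assumes the kubernetes client models are imported as `k8s`,
# as they are used elsewhere in this file):
def _serialize_pod_example():
    pod = k8s.V1Pod(
        metadata=k8s.V1ObjectMeta(name='demo', namespace='default'),
        spec=k8s.V1PodSpec(containers=[k8s.V1Container(name='base', image='ubuntu:16.04')]),
    )
    # sanitize_for_serialization yields the same plain-dict shape that the
    # expected_pod fixtures in this file are written in
    assert serialize_pod(pod)['metadata']['name'] == 'demo'
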
class TestKubernetesWorkerConfiguration(unittest.TestCase):
    """
    Tests that if dags_volume_subpath/logs_volume_subpath configuration
    options are passed to worker pod config
    """

    affinity_config = {
        'podAntiAffinity': {
            'requiredDuringSchedulingIgnoredDuringExecution': [
                {
                    'topologyKey': 'kubernetes.io/hostname',
                    'labelSelector': {
                        'matchExpressions': [
                            {
                                'key': 'app',
                                'operator': 'In',
                                'values': ['airflow']
                            }
                        ]
                    }
                }
            ]
        }
    }

    tolerations_config = [
        {
            'key': 'dedicated',
            'operator': 'Equal',
            'value': 'airflow'
        },
        {
            'key': 'prod',
            'operator': 'Exists'
        }
    ]

    worker_annotations_config = {
        'iam.amazonaws.com/role': 'role-arn',
        'other/annotation': 'value'
    }

    def setUp(self):
        if AirflowKubernetesScheduler is None:
            self.skipTest("kubernetes python package is not installed")

        self.kube_config = mock.MagicMock()
        self.kube_config.airflow_home = '/'
        self.kube_config.airflow_dags = 'dags'
        self.kube_config.airflow_logs = 'logs'
        self.kube_config.dags_volume_subpath = None
        self.kube_config.dags_volume_mount_point = None
        self.kube_config.logs_volume_subpath = None
        self.kube_config.dags_in_image = False
        self.kube_config.dags_folder = None
        self.kube_config.git_dags_folder_mount_point = None
        self.kube_config.kube_labels = {'dag_id': 'original_dag_id', 'my_label': 'label_id'}
        self.kube_config.pod_template_file = ''
        self.kube_config.restart_policy = ''
        self.kube_config.image_pull_policy = ''
        self.api_client = ApiClient()

    def tearDown(self) -> None:
        self.kube_config = None

    def test_worker_configuration_no_subpaths(self):
        self.kube_config.dags_volume_claim = 'airflow-dags'
        self.kube_config.dags_folder = 'dags'
        worker_config = WorkerConfiguration(self.kube_config)
        volumes = worker_config._get_volumes()
        volume_mounts = worker_config._get_volume_mounts()
        for volume_or_mount in volumes + volume_mounts:
            if volume_or_mount.name not in ['airflow-config', 'airflow-local-settings']:
                self.assertNotIn(
                    'subPath', self.api_client.sanitize_for_serialization(volume_or_mount),
                    "subPath shouldn't be defined"
                )

    @conf_vars({
        ('kubernetes', 'git_ssh_known_hosts_configmap_name'): 'airflow-configmap',
        ('kubernetes', 'git_ssh_key_secret_name'): 'airflow-secrets',
        ('kubernetes', 'git_user'): 'some-user',
        ('kubernetes', 'git_password'): 'some-password',
        ('kubernetes', 'git_repo'): 'git@github.com:apache/airflow.git',
        ('kubernetes', 'git_branch'): 'master',
        ('kubernetes', 'git_dags_folder_mount_point'): '/usr/local/airflow/dags',
        ('kubernetes', 'delete_worker_pods'): 'True',
        ('kubernetes', 'kube_client_request_args'): '{"_request_timeout" : [60,360]}',
    })
    def test_worker_configuration_auth_both_ssh_and_user(self):
        with self.assertRaisesRegex(AirflowConfigException,
                                    'either `git_user` and `git_password`.*'
                                    'or `git_ssh_key_secret_name`.*'
                                    'but not both$'):
            KubeConfig()

    @parameterized.expand([
        ('{"grace_period_seconds": 10}', {"grace_period_seconds": 10}),
        ("", {})
    ])
    def test_delete_option_kwargs_config(self, config, expected_value):
        with conf_vars({
            ('kubernetes', 'delete_option_kwargs'): config,
        }):
            self.assertEqual(KubeConfig().delete_option_kwargs, expected_value)

    def test_worker_with_subpaths(self):
        self.kube_config.dags_volume_subpath = 'dags'
        self.kube_config.logs_volume_subpath = 'logs'
        self.kube_config.dags_volume_claim = 'dags'
        self.kube_config.dags_folder = 'dags'
        worker_config = WorkerConfiguration(self.kube_config)
        volumes = worker_config._get_volumes()
        volume_mounts = worker_config._get_volume_mounts()

        for volume in volumes:
            self.assertNotIn(
                'subPath', self.api_client.sanitize_for_serialization(volume),
                "subPath isn't valid configuration for a volume"
            )

        for volume_mount in volume_mounts:
            if volume_mount.name != 'airflow-config':
                self.assertIn(
                    'subPath', self.api_client.sanitize_for_serialization(volume_mount),
                    "subPath should've been passed to volumeMount configuration"
                )

    def test_worker_generate_dag_volume_mount_path(self):
        self.kube_config.git_dags_folder_mount_point = '/root/airflow/git/dags'
        self.kube_config.dags_folder = '/root/airflow/dags'
        worker_config = WorkerConfiguration(self.kube_config)

        self.kube_config.dags_volume_claim = 'airflow-dags'
        self.kube_config.dags_volume_host = ''
        dag_volume_mount_path = worker_config.generate_dag_volume_mount_path()
        self.assertEqual(dag_volume_mount_path, self.kube_config.dags_folder)

        self.kube_config.dags_volume_mount_point = '/root/airflow/package'
        dag_volume_mount_path = worker_config.generate_dag_volume_mount_path()
        self.assertEqual(dag_volume_mount_path, '/root/airflow/package')
        self.kube_config.dags_volume_mount_point = ''

        self.kube_config.dags_volume_claim = ''
        self.kube_config.dags_volume_host = '/host/airflow/dags'
        dag_volume_mount_path = worker_config.generate_dag_volume_mount_path()
        self.assertEqual(dag_volume_mount_path, self.kube_config.dags_folder)

        self.kube_config.dags_volume_claim = ''
        self.kube_config.dags_volume_host = ''
        dag_volume_mount_path = worker_config.generate_dag_volume_mount_path()
        self.assertEqual(dag_volume_mount_path,
                         self.kube_config.git_dags_folder_mount_point)

    def test_worker_environment_no_dags_folder(self):
        self.kube_config.airflow_configmap = ''
        self.kube_config.git_dags_folder_mount_point = ''
        self.kube_config.dags_folder = ''
        worker_config = WorkerConfiguration(self.kube_config)
        env = worker_config._get_environment()

        self.assertNotIn('AIRFLOW__CORE__DAGS_FOLDER', env)

    def test_worker_environment_when_dags_folder_specified(self):
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_dags_folder_mount_point = ''
        dags_folder = '/workers/path/to/dags'
        self.kube_config.dags_folder = dags_folder

        worker_config = WorkerConfiguration(self.kube_config)
        env = worker_config._get_environment()

        self.assertEqual(dags_folder, env['AIRFLOW__CORE__DAGS_FOLDER'])

    def test_worker_environment_dags_folder_using_git_sync(self):
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_sync_dest = 'repo'
        self.kube_config.git_subpath = 'dags'
        self.kube_config.git_dags_folder_mount_point = '/workers/path/to/dags'

        dags_folder = '{}/{}/{}'.format(self.kube_config.git_dags_folder_mount_point,
                                        self.kube_config.git_sync_dest,
                                        self.kube_config.git_subpath)

        worker_config = WorkerConfiguration(self.kube_config)
        env = worker_config._get_environment()

        self.assertEqual(dags_folder, env['AIRFLOW__CORE__DAGS_FOLDER'])

    def test_init_environment_using_git_sync_ssh_without_known_hosts(self):
        # Tests the init environment created with git-sync SSH authentication option is correct
        # without known hosts file
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_ssh_secret_name = 'airflow-secrets'
        self.kube_config.git_ssh_known_hosts_configmap_name = None
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None

        worker_config = WorkerConfiguration(self.kube_config)
        init_containers = worker_config._get_init_containers()

        self.assertTrue(init_containers)  # check not empty
        env = init_containers[0].env

        self.assertIn(k8s.V1EnvVar(name='GIT_SSH_KEY_FILE', value='/etc/git-secret/ssh'), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_KNOWN_HOSTS', value='false'), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_SYNC_SSH', value='true'), env)

    def test_init_environment_using_git_sync_ssh_with_known_hosts(self):
        # Tests the init environment created with git-sync SSH authentication option is correct
        # with known hosts file
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_ssh_key_secret_name = 'airflow-secrets'
        self.kube_config.git_ssh_known_hosts_configmap_name = 'airflow-configmap'
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None

        worker_config = WorkerConfiguration(self.kube_config)
        init_containers = worker_config._get_init_containers()

        self.assertTrue(init_containers)  # check not empty
        env = init_containers[0].env

        self.assertIn(k8s.V1EnvVar(name='GIT_SSH_KEY_FILE', value='/etc/git-secret/ssh'), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_KNOWN_HOSTS', value='true'), env)
        self.assertIn(k8s.V1EnvVar(
            name='GIT_SSH_KNOWN_HOSTS_FILE',
            value='/etc/git-secret/known_hosts'
        ), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_SYNC_SSH', value='true'), env)

    def test_init_environment_using_git_sync_user_without_known_hosts(self):
        # Tests the init environment created with git-sync User authentication option is correct
        # without known hosts file
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_user = 'git_user'
        self.kube_config.git_password = 'git_password'
        self.kube_config.git_ssh_known_hosts_configmap_name = None
        self.kube_config.git_ssh_key_secret_name = None
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None

        worker_config = WorkerConfiguration(self.kube_config)
        init_containers = worker_config._get_init_containers()

        self.assertTrue(init_containers)  # check not empty
        env = init_containers[0].env

        self.assertNotIn(k8s.V1EnvVar(name='GIT_SSH_KEY_FILE', value='/etc/git-secret/ssh'), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_SYNC_USERNAME', value='git_user'), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_SYNC_PASSWORD', value='git_password'), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_KNOWN_HOSTS', value='false'), env)
        self.assertNotIn(k8s.V1EnvVar(
            name='GIT_SSH_KNOWN_HOSTS_FILE',
            value='/etc/git-secret/known_hosts'
        ), env)
        self.assertNotIn(k8s.V1EnvVar(name='GIT_SYNC_SSH', value='true'), env)

    def test_init_environment_using_git_sync_user_with_known_hosts(self):
        # Tests the init environment created with git-sync User authentication option is correct
        # with known hosts file
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_user = 'git_user'
        self.kube_config.git_password = 'git_password'
        self.kube_config.git_ssh_known_hosts_configmap_name = 'airflow-configmap'
        self.kube_config.git_ssh_key_secret_name = None
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None

        worker_config = WorkerConfiguration(self.kube_config)
        init_containers = worker_config._get_init_containers()

        self.assertTrue(init_containers)  # check not empty
        env = init_containers[0].env

        self.assertNotIn(k8s.V1EnvVar(name='GIT_SSH_KEY_FILE', value='/etc/git-secret/ssh'), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_SYNC_USERNAME', value='git_user'), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_SYNC_PASSWORD', value='git_password'), env)
        self.assertIn(k8s.V1EnvVar(name='GIT_KNOWN_HOSTS', value='true'), env)
        self.assertIn(k8s.V1EnvVar(
            name='GIT_SSH_KNOWN_HOSTS_FILE',
            value='/etc/git-secret/known_hosts'
        ), env)
        self.assertNotIn(k8s.V1EnvVar(name='GIT_SYNC_SSH', value='true'), env)
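
    # Quick reference for the git-sync auth matrix the four tests above
    # assert (derived directly from their expectations):
    #   ssh key, no known hosts  -> GIT_SYNC_SSH=true,  GIT_KNOWN_HOSTS=false
    #   ssh key + known hosts    -> GIT_SYNC_SSH=true,  GIT_KNOWN_HOSTS=true,
    #                               GIT_SSH_KNOWN_HOSTS_FILE=/etc/git-secret/known_hosts
    #   user/password, no hosts  -> GIT_SYNC_USERNAME/GIT_SYNC_PASSWORD set, GIT_KNOWN_HOSTS=false
    #   user/password + hosts    -> GIT_SYNC_USERNAME/GIT_SYNC_PASSWORD set, GIT_KNOWN_HOSTS=true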

    def test_init_environment_using_git_sync_run_as_user_empty(self):
        # Tests that if git_sync_run_as_user is empty, no securityContext is created in the init container

        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None
        self.kube_config.git_sync_run_as_user = ''

        worker_config = WorkerConfiguration(self.kube_config)
        init_containers = worker_config._get_init_containers()
        self.assertTrue(init_containers)  # check not empty

        self.assertIsNone(init_containers[0].security_context)

    def test_init_environment_using_git_sync_run_as_user_root(self):
        # Tests that if git_sync_run_as_user is 0, the securityContext is created with
        # the right uid

        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None
        self.kube_config.git_sync_run_as_user = 0

        worker_config = WorkerConfiguration(self.kube_config)
        init_containers = worker_config._get_init_containers()
        self.assertTrue(init_containers)  # check not empty

        self.assertEqual(0, init_containers[0].security_context.run_as_user)

    def test_make_pod_run_as_user_0(self):
        # Tests that the pod created with run-as-user 0 actually gets that in its config
        self.kube_config.worker_run_as_user = 0
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None
        self.kube_config.worker_fs_group = None
        self.kube_config.git_dags_folder_mount_point = 'dags'
        self.kube_config.git_sync_dest = 'repo'
        self.kube_config.git_subpath = 'path'

        worker_config = WorkerConfiguration(self.kube_config)
        pod = worker_config.as_pod()

        self.assertEqual(0, pod.spec.security_context.run_as_user)

    def test_make_pod_assert_labels(self):
        # Tests the pod created has all the expected labels set
        self.kube_config.dags_folder = 'dags'

        worker_config = WorkerConfiguration(self.kube_config)
        pod = PodGenerator.construct_pod(
            "test_dag_id",
            "test_task_id",
            "test_pod_id",
            1,
            "2019-11-21 11:08:22.920875",
            ["bash -c 'ls /'"],
            None,
            worker_config.as_pod(),
            "default",
            "sample-uuid",
        )
        expected_labels = {
            'airflow-worker': 'sample-uuid',
            'airflow_version': airflow_version.replace('+', '-'),
            'dag_id': 'test_dag_id',
            'execution_date': '2019-11-21 11:08:22.920875',
            'kubernetes_executor': 'True',
            'task_id': 'test_task_id',
            'try_number': '1'
        }
        self.assertEqual(pod.metadata.labels, expected_labels)

    def test_make_pod_git_sync_ssh_without_known_hosts(self):
        # Tests the pod created with git-sync SSH authentication option is correct without known hosts
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_ssh_key_secret_name = 'airflow-secrets'
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None
        self.kube_config.worker_fs_group = None
        self.kube_config.git_dags_folder_mount_point = 'dags'
        self.kube_config.git_sync_dest = 'repo'
        self.kube_config.git_subpath = 'path'

        worker_config = WorkerConfiguration(self.kube_config)

        pod = worker_config.as_pod()

        init_containers = worker_config._get_init_containers()
        git_ssh_key_file = next((x.value for x in init_containers[0].env
                                if x.name == 'GIT_SSH_KEY_FILE'), None)
        volume_mount_ssh_key = next((x.mount_path for x in init_containers[0].volume_mounts
                                    if x.name == worker_config.git_sync_ssh_secret_volume_name),
                                    None)
        self.assertTrue(git_ssh_key_file)
        self.assertTrue(volume_mount_ssh_key)
        self.assertEqual(65533, pod.spec.security_context.fs_group)
        self.assertEqual(git_ssh_key_file,
                         volume_mount_ssh_key,
                         'The location where the git ssh secret is mounted'
                         ' needs to be the same as the GIT_SSH_KEY_FILE path')

    def test_make_pod_git_sync_credentials_secret(self):
        # Tests that the credentials from git_sync_credentials_secret make it into the init container
        self.kube_config.git_sync_credentials_secret = 'airflow-git-creds-secret'
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None
        self.kube_config.worker_fs_group = None
        self.kube_config.git_dags_folder_mount_point = 'dags'
        self.kube_config.git_sync_dest = 'repo'
        self.kube_config.git_subpath = 'path'

        worker_config = WorkerConfiguration(self.kube_config)

        pod = worker_config.as_pod()

        username_env = k8s.V1EnvVar(
            name='GIT_SYNC_USERNAME',
            value_from=k8s.V1EnvVarSource(
                secret_key_ref=k8s.V1SecretKeySelector(
                    name=self.kube_config.git_sync_credentials_secret,
                    key='GIT_SYNC_USERNAME')
            )
        )
        password_env = k8s.V1EnvVar(
            name='GIT_SYNC_PASSWORD',
            value_from=k8s.V1EnvVarSource(
                secret_key_ref=k8s.V1SecretKeySelector(
                    name=self.kube_config.git_sync_credentials_secret,
                    key='GIT_SYNC_PASSWORD')
            )
        )

        self.assertIn(username_env, pod.spec.init_containers[0].env,
                      'The username env for git credentials did not get into the init container')

        self.assertIn(password_env, pod.spec.init_containers[0].env,
                      'The password env for git credentials did not get into the init container')

    def test_make_pod_git_sync_rev(self):
        # Tests that the git_sync_rev setting makes it into the init container
        self.kube_config.git_sync_rev = 'sampletag'
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None
        self.kube_config.worker_fs_group = None
        self.kube_config.git_dags_folder_mount_point = 'dags'
        self.kube_config.git_sync_dest = 'repo'
        self.kube_config.git_subpath = 'path'

        worker_config = WorkerConfiguration(self.kube_config)

        pod = worker_config.as_pod()

        rev_env = k8s.V1EnvVar(
            name='GIT_SYNC_REV',
            value=self.kube_config.git_sync_rev,
        )

        self.assertIn(rev_env, pod.spec.init_containers[0].env,
                      'The git_sync_rev env did not get into the init container')

    def test_make_pod_git_sync_ssh_with_known_hosts(self):
        # Tests the pod created with git-sync SSH authentication option is correct with known hosts
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.git_ssh_secret_name = 'airflow-secrets'
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None

        worker_config = WorkerConfiguration(self.kube_config)

        init_containers = worker_config._get_init_containers()
        git_ssh_known_hosts_file = next((x.value for x in init_containers[0].env
                                         if x.name == 'GIT_SSH_KNOWN_HOSTS_FILE'), None)

        volume_mount_ssh_known_hosts_file = next(
            (x.mount_path for x in init_containers[0].volume_mounts
             if x.name == worker_config.git_sync_ssh_known_hosts_volume_name),
            None)
        self.assertTrue(git_ssh_known_hosts_file)
        self.assertTrue(volume_mount_ssh_known_hosts_file)
        self.assertEqual(git_ssh_known_hosts_file,
                         volume_mount_ssh_known_hosts_file,
                         'The location where the git known hosts file is mounted'
                         ' needs to be the same as the GIT_SSH_KNOWN_HOSTS_FILE path')

    def test_make_pod_with_empty_executor_config(self):
        self.kube_config.kube_affinity = self.affinity_config
        self.kube_config.kube_tolerations = self.tolerations_config
        self.kube_config.kube_annotations = self.worker_annotations_config
        self.kube_config.dags_folder = 'dags'
        worker_config = WorkerConfiguration(self.kube_config)
        pod = worker_config.as_pod()

        self.assertTrue(pod.spec.affinity['podAntiAffinity'] is not None)
        self.assertEqual('app',
                         pod.spec.affinity['podAntiAffinity']
                         ['requiredDuringSchedulingIgnoredDuringExecution'][0]
                         ['labelSelector']
                         ['matchExpressions'][0]
                         ['key'])

        self.assertEqual(2, len(pod.spec.tolerations))
        self.assertEqual('prod', pod.spec.tolerations[1]['key'])
        self.assertEqual('role-arn', pod.metadata.annotations['iam.amazonaws.com/role'])
        self.assertEqual('value', pod.metadata.annotations['other/annotation'])

    def test_make_pod_with_executor_config(self):
        self.kube_config.dags_folder = 'dags'
        worker_config = WorkerConfiguration(self.kube_config)
        config_pod = PodGenerator(
            image='',
            affinity=self.affinity_config,
            tolerations=self.tolerations_config,
        ).gen_pod()

        pod = worker_config.as_pod()

        result = PodGenerator.reconcile_pods(pod, config_pod)

        self.assertTrue(result.spec.affinity['podAntiAffinity'] is not None)
        self.assertEqual('app',
                         result.spec.affinity['podAntiAffinity']
                         ['requiredDuringSchedulingIgnoredDuringExecution'][0]
                         ['labelSelector']
                         ['matchExpressions'][0]
                         ['key'])

        self.assertEqual(2, len(result.spec.tolerations))
        self.assertEqual('prod', result.spec.tolerations[1]['key'])

    def test_worker_pvc_dags(self):
        # Tests persistence volume config created when `dags_volume_claim` is set
        self.kube_config.dags_volume_claim = 'airflow-dags'
        self.kube_config.dags_folder = 'dags'
        worker_config = WorkerConfiguration(self.kube_config)
        volumes = worker_config._get_volumes()
        volume_mounts = worker_config._get_volume_mounts()

        init_containers = worker_config._get_init_containers()

        dag_volume = [volume for volume in volumes if volume.name == 'airflow-dags']
        dag_volume_mount = [mount for mount in volume_mounts if mount.name == 'airflow-dags']

        self.assertEqual('airflow-dags', dag_volume[0].persistent_volume_claim.claim_name)
        self.assertEqual(1, len(dag_volume_mount))
        self.assertTrue(dag_volume_mount[0].read_only)
        self.assertEqual(0, len(init_containers))

    def test_worker_git_dags(self):
        # Tests persistence volume config created when `git_repo` is set
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_folder = '/usr/local/airflow/dags'
        self.kube_config.worker_dags_folder = '/usr/local/airflow/dags'

        self.kube_config.git_sync_container_repository = 'gcr.io/google-containers/git-sync-amd64'
        self.kube_config.git_sync_container_tag = 'v2.0.5'
        self.kube_config.git_sync_container = 'gcr.io/google-containers/git-sync-amd64:v2.0.5'
        self.kube_config.git_sync_init_container_name = 'git-sync-clone'
        self.kube_config.git_subpath = 'dags_folder'
        self.kube_config.git_sync_root = '/git'
        self.kube_config.git_sync_run_as_user = 65533
        self.kube_config.git_dags_folder_mount_point = '/usr/local/airflow/dags/repo/dags_folder'

        worker_config = WorkerConfiguration(self.kube_config)
        volumes = worker_config._get_volumes()
        volume_mounts = worker_config._get_volume_mounts()

        dag_volume = [volume for volume in volumes if volume.name == 'airflow-dags']
        dag_volume_mount = [mount for mount in volume_mounts if mount.name == 'airflow-dags']

        self.assertIsNotNone(dag_volume[0].empty_dir)
        self.assertEqual(self.kube_config.git_dags_folder_mount_point, dag_volume_mount[0].mount_path)
        self.assertTrue(dag_volume_mount[0].read_only)

        init_container = worker_config._get_init_containers()[0]
        init_container_volume_mount = [mount for mount in init_container.volume_mounts
                                       if mount.name == 'airflow-dags']

        self.assertEqual('git-sync-clone', init_container.name)
        self.assertEqual('gcr.io/google-containers/git-sync-amd64:v2.0.5', init_container.image)
        self.assertEqual(1, len(init_container_volume_mount))
        self.assertFalse(init_container_volume_mount[0].read_only)
        self.assertEqual(65533, init_container.security_context.run_as_user)

    def test_worker_container_dags(self):
        # Tests that the 'airflow-dags' persistence volume is NOT created when `dags_in_image` is set
        self.kube_config.dags_in_image = True
        self.kube_config.dags_folder = 'dags'
        worker_config = WorkerConfiguration(self.kube_config)
        volumes = worker_config._get_volumes()
        volume_mounts = worker_config._get_volume_mounts()

        dag_volume = [volume for volume in volumes if volume.name == 'airflow-dags']
        dag_volume_mount = [mount for mount in volume_mounts if mount.name == 'airflow-dags']

        init_containers = worker_config._get_init_containers()

        self.assertEqual(0, len(dag_volume))
        self.assertEqual(0, len(dag_volume_mount))
        self.assertEqual(0, len(init_containers))

    def test_set_airflow_config_configmap(self):
        """
        Test that airflow.cfg can be set via configmap by
        checking volume & volume-mounts are set correctly.
        """
        self.kube_config.airflow_home = '/usr/local/airflow'
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.airflow_local_settings_configmap = None
        self.kube_config.dags_folder = '/workers/path/to/dags'

        worker_config = WorkerConfiguration(self.kube_config)
        pod = worker_config.as_pod()

        pod_spec_dict = pod.spec.to_dict()

        airflow_config_volume = [
            volume for volume in pod_spec_dict['volumes'] if volume["name"] == 'airflow-config'
        ]
        # Test that volume_name is found
        self.assertEqual(1, len(airflow_config_volume))

        # Test that config map exists
        self.assertEqual(
            {'default_mode': None, 'items': None, 'name': 'airflow-configmap', 'optional': None},
            airflow_config_volume[0]['config_map']
        )

        # Test that only 1 Volume Mounts exists with 'airflow-config' name
        # One for airflow.cfg
        volume_mounts = [
            volume_mount for volume_mount in pod_spec_dict['containers'][0]['volume_mounts']
            if volume_mount['name'] == 'airflow-config'
        ]

        self.assertEqual(
            [
                {
                    'mount_path': '/usr/local/airflow/airflow.cfg',
                    'mount_propagation': None,
                    'name': 'airflow-config',
                    'read_only': True,
                    'sub_path': 'airflow.cfg',
                    'sub_path_expr': None
                }
            ],
            volume_mounts
        )

    def test_set_airflow_local_settings_configmap(self):
        """
        Test that airflow_local_settings.py can be set via configmap by
        checking volume & volume-mounts are set correctly.
        """
        self.kube_config.airflow_home = '/usr/local/airflow'
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.airflow_local_settings_configmap = 'airflow-configmap'
        self.kube_config.dags_folder = '/workers/path/to/dags'

        worker_config = WorkerConfiguration(self.kube_config)
        pod = worker_config.as_pod()

        pod_spec_dict = pod.spec.to_dict()

        airflow_config_volume = [
            volume for volume in pod_spec_dict['volumes'] if volume["name"] == 'airflow-config'
        ]
        # Test that volume_name is found
        self.assertEqual(1, len(airflow_config_volume))

        # Test that config map exists
        self.assertEqual(
            {'default_mode': None, 'items': None, 'name': 'airflow-configmap', 'optional': None},
            airflow_config_volume[0]['config_map']
        )

        # Test that 2 Volume Mounts exists and has 2 different mount-paths
        # One for airflow.cfg
        # Second for airflow_local_settings.py
        volume_mounts = [
            volume_mount for volume_mount in pod_spec_dict['containers'][0]['volume_mounts']
            if volume_mount['name'] == 'airflow-config'
        ]
        self.assertEqual(2, len(volume_mounts))

        self.assertEqual(
            [
                {
                    'mount_path': '/usr/local/airflow/airflow.cfg',
                    'mount_propagation': None,
                    'name': 'airflow-config',
                    'read_only': True,
                    'sub_path': 'airflow.cfg',
                    'sub_path_expr': None
                },
                {
                    'mount_path': '/usr/local/airflow/config/airflow_local_settings.py',
                    'mount_propagation': None,
                    'name': 'airflow-config',
                    'read_only': True,
                    'sub_path': 'airflow_local_settings.py',
                    'sub_path_expr': None
                }
            ],
            volume_mounts
        )

    def test_set_airflow_configmap_different_for_local_setting(self):
        """
        Test that airflow_local_settings.py can be set via configmap by
        checking volume & volume-mounts are set correctly when using a different
        configmap than airflow_configmap (airflow.cfg)
        """
        self.kube_config.airflow_home = '/usr/local/airflow'
        self.kube_config.airflow_configmap = 'airflow-configmap'
        self.kube_config.airflow_local_settings_configmap = 'airflow-ls-configmap'
        self.kube_config.dags_folder = '/workers/path/to/dags'

        worker_config = WorkerConfiguration(self.kube_config)
        pod = worker_config.as_pod()

        pod_spec_dict = pod.spec.to_dict()

        airflow_local_settings_volume = [
            volume for volume in pod_spec_dict['volumes'] if volume["name"] == 'airflow-local-settings'
        ]
        # Test that volume_name is found
        self.assertEqual(1, len(airflow_local_settings_volume))

        # Test that config map exists
        self.assertEqual(
            {'default_mode': None, 'items': None, 'name': 'airflow-ls-configmap', 'optional': None},
            airflow_local_settings_volume[0]['config_map']
        )

        # Test that 2 Volume Mounts exists and has 2 different mount-paths
        # One for airflow.cfg
        # Second for airflow_local_settings.py
        airflow_cfg_volume_mount = [
            volume_mount for volume_mount in pod_spec_dict['containers'][0]['volume_mounts']
            if volume_mount['name'] == 'airflow-config'
        ]

        local_setting_volume_mount = [
            volume_mount for volume_mount in pod_spec_dict['containers'][0]['volume_mounts']
            if volume_mount['name'] == 'airflow-local-settings'
        ]
        self.assertEqual(1, len(airflow_cfg_volume_mount))
        self.assertEqual(1, len(local_setting_volume_mount))

        self.assertEqual(
            [
                {
                    'mount_path': '/usr/local/airflow/config/airflow_local_settings.py',
                    'mount_propagation': None,
                    'name': 'airflow-local-settings',
                    'read_only': True,
                    'sub_path': 'airflow_local_settings.py',
                    'sub_path_expr': None
                }
            ],
            local_setting_volume_mount
        )

        self.assertEqual(
            [
                {
                    'mount_path': '/usr/local/airflow/airflow.cfg',
                    'mount_propagation': None,
                    'name': 'airflow-config',
                    'read_only': True,
                    'sub_path': 'airflow.cfg',
                    'sub_path_expr': None
                }
            ],
            airflow_cfg_volume_mount
        )

    def test_kubernetes_environment_variables(self):
        # Tests the kubernetes environment variables get copied into the worker pods
        input_environment = {
            'ENVIRONMENT': 'prod',
            'LOG_LEVEL': 'warning'
        }
        self.kube_config.kube_env_vars = input_environment
        worker_config = WorkerConfiguration(self.kube_config)
        env = worker_config._get_environment()
        for key in input_environment:
            self.assertIn(key, env)
            self.assertIn(input_environment[key], env.values())

        core_executor = 'AIRFLOW__CORE__EXECUTOR'
        input_environment = {
            core_executor: 'NotLocalExecutor'
        }
        self.kube_config.kube_env_vars = input_environment
        worker_config = WorkerConfiguration(self.kube_config)
        env = worker_config._get_environment()
        self.assertEqual(env[core_executor], 'LocalExecutor')

    def test_get_secrets(self):
        # Test when secretRef is None and kube_secrets is not empty
        self.kube_config.kube_secrets = {
            'AWS_SECRET_KEY': 'airflow-secret=aws_secret_key',
            'POSTGRES_PASSWORD': 'airflow-secret=postgres_credentials'
        }
        self.kube_config.env_from_secret_ref = None
        worker_config = WorkerConfiguration(self.kube_config)
        secrets = worker_config._get_secrets()
        secrets.sort(key=lambda secret: secret.deploy_target)
        expected = [
            Secret('env', 'AWS_SECRET_KEY', 'airflow-secret', 'aws_secret_key'),
            Secret('env', 'POSTGRES_PASSWORD', 'airflow-secret', 'postgres_credentials')
        ]
        self.assertListEqual(expected, secrets)

        # Test when secret is not empty and kube_secrets is empty dict
        self.kube_config.kube_secrets = {}
        self.kube_config.env_from_secret_ref = 'secret_a,secret_b'
        worker_config = WorkerConfiguration(self.kube_config)
        secrets = worker_config._get_secrets()
        expected = [Secret('env', None, 'secret_a'), Secret('env', None, 'secret_b')]
        self.assertListEqual(expected, secrets)
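
    # Sketch of the mapping test_get_secrets exercises (illustrative, not the
    # library's own parser): each kube_secrets entry 'ENV': 'secret=key'
    # becomes a single-key env Secret, e.g.
    #
    #   name, key = 'airflow-secret=aws_secret_key'.split('=')
    #   Secret('env', 'AWS_SECRET_KEY', name, key)
    #   # -> Secret('env', 'AWS_SECRET_KEY', 'airflow-secret', 'aws_secret_key')
    #
    # while each env_from_secret_ref entry injects a whole secret:
    #   Secret('env', None, 'secret_a')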

    def test_get_env_from(self):
        # Test when configmap is empty
        self.kube_config.env_from_configmap_ref = ''
        worker_config = WorkerConfiguration(self.kube_config)
        configmaps = worker_config._get_env_from()
        self.assertListEqual([], configmaps)

        # test when configmap is not empty
        self.kube_config.env_from_configmap_ref = 'configmap_a,configmap_b'
        self.kube_config.env_from_secret_ref = 'secretref_a,secretref_b'
        worker_config = WorkerConfiguration(self.kube_config)
        configmaps = worker_config._get_env_from()
        self.assertListEqual([
            k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name='configmap_a')),
            k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name='configmap_b')),
            k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name='secretref_a')),
            k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name='secretref_b'))
        ], configmaps)

    def test_pod_template_file(self):
        fixture = 'tests/kubernetes/pod.yaml'
        self.kube_config.pod_template_file = fixture
        worker_config = WorkerConfiguration(self.kube_config)
        result = worker_config.as_pod()
        expected = PodGenerator.deserialize_model_file(fixture)
        expected.metadata.name = ANY
        self.assertEqual(expected, result)

    def test_get_labels(self):
        worker_config = WorkerConfiguration(self.kube_config)
        labels = worker_config._get_labels({'my_kube_executor_label': 'kubernetes'}, {
            'dag_id': 'override_dag_id',
        })
        self.assertEqual({
            'my_label': 'label_id',
            'dag_id': 'override_dag_id',
            'my_kube_executor_label': 'kubernetes',
        }, labels)

    def test_make_pod_with_image_pull_secrets(self):
        # Tests the pod created with image_pull_secrets actually gets that in it's config
        self.kube_config.dags_volume_claim = None
        self.kube_config.dags_volume_host = None
        self.kube_config.dags_in_image = None
        self.kube_config.git_dags_folder_mount_point = 'dags'
        self.kube_config.git_sync_dest = 'repo'
        self.kube_config.git_subpath = 'path'
        self.kube_config.image_pull_secrets = 'image_pull_secret1,image_pull_secret2'

        worker_config = WorkerConfiguration(self.kube_config)
        pod = worker_config.as_pod()

        self.assertEqual(2, len(pod.spec.image_pull_secrets))

    def test_get_resources(self):
        self.kube_config.worker_resources = {'limit_cpu': 0.25, 'limit_memory': '64Mi', 'request_cpu': '250m',
                                             'request_memory': '64Mi'}

        worker_config = WorkerConfiguration(self.kube_config)
        resources = worker_config._get_resources()
        self.assertEqual(resources.limits["cpu"], 0.25)
        self.assertEqual(resources.limits["memory"], "64Mi")
        self.assertEqual(resources.requests["cpu"], "250m")
        self.assertEqual(resources.requests["memory"], "64Mi")
class TestKubernetesPodOperatorSystem(unittest.TestCase):
    def get_current_task_name(self):
        # reverse test name to make pod name unique (it has limited length)
        return "_" + unittest.TestCase.id(self).replace(".", "_")[::-1]

    def setUp(self):
        self.maxDiff = None
        self.api_client = ApiClient()
        self.expected_pod = {
            'apiVersion': 'v1',
            'kind': 'Pod',
            'metadata': {
                'namespace': 'default',
                'name': mock.ANY,
                'annotations': {},
                'labels': {
                    'foo': 'bar',
                    'kubernetes_pod_operator': 'True',
                    'airflow_version': airflow_version.replace('+', '-'),
                    'execution_date': '2016-01-01T0100000100-a2f50a31f',
                    'dag_id': 'dag',
                    'task_id': 'task',
                    'try_number': '1',
                },
            },
            'spec': {
                'affinity': {},
                'containers': [
                    {
                        'image': 'ubuntu:16.04',
                        'args': ["echo 10"],
                        'command': ["bash", "-cx"],
                        'env': [],
                        'envFrom': [],
                        'resources': {},
                        'name': 'base',
                        'ports': [],
                        'volumeMounts': [],
                    }
                ],
                'hostNetwork': False,
                'imagePullSecrets': [],
                'initContainers': [],
                'nodeSelector': {},
                'restartPolicy': 'Never',
                'securityContext': {},
                'tolerations': [],
                'volumes': [],
            },
        }

    def tearDown(self):
        client = kube_client.get_kube_client(in_cluster=False)
        client.delete_collection_namespaced_pod(namespace="default")

    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod")
    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_image_pull_secrets_correctly_set(self, mock_client, await_pod_completion_mock, create_mock):
        fake_pull_secrets = "fakeSecret"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            image_pull_secrets=fake_pull_secrets,
            cluster_context='default',
        )
        mock_pod = MagicMock()
        mock_pod.status.phase = 'Succeeded'
        await_pod_completion_mock.return_value = mock_pod
        context = create_context(k)
        k.execute(context=context)
        assert create_mock.call_args[1]['pod'].spec.image_pull_secrets == [
            k8s.V1LocalObjectReference(name=fake_pull_secrets)
        ]

    def test_working_pod(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        assert self.expected_pod['spec'] == actual_pod['spec']
        assert self.expected_pod['metadata']['labels'] == actual_pod['metadata']['labels']

    def test_pod_node_selectors(self):
        node_selectors = {'beta.kubernetes.io/os': 'linux'}
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            node_selectors=node_selectors,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['nodeSelector'] = node_selectors
        assert self.expected_pod == actual_pod

    def test_pod_resources(self):
        resources = {
            'limit_cpu': 0.25,
            'limit_memory': '64Mi',
            'limit_ephemeral_storage': '2Gi',
            'request_cpu': '250m',
            'request_memory': '64Mi',
            'request_ephemeral_storage': '1Gi',
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            resources=resources,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['resources'] = {
            'requests': {'memory': '64Mi', 'cpu': '250m', 'ephemeral-storage': '1Gi'},
            'limits': {'memory': '64Mi', 'cpu': 0.25, 'ephemeral-storage': '2Gi'},
        }
        assert self.expected_pod == actual_pod

    def test_pod_affinity(self):
        affinity = {
            'nodeAffinity': {
                'requiredDuringSchedulingIgnoredDuringExecution': {
                    'nodeSelectorTerms': [
                        {
                            'matchExpressions': [
                                {'key': 'beta.kubernetes.io/os', 'operator': 'In', 'values': ['linux']}
                            ]
                        }
                    ]
                }
            }
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            affinity=affinity,
        )
        context = create_context(k)
        k.execute(context=context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['affinity'] = affinity
        assert self.expected_pod == actual_pod

    def test_port(self):
        port = Port('http', 80)

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            ports=[port],
        )
        context = create_context(k)
        k.execute(context=context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['ports'] = [{'name': 'http', 'containerPort': 80}]
        assert self.expected_pod == actual_pod

    def test_volume_mount(self):
        with patch.object(PodManager, 'log') as mock_logger:
            volume_mount = VolumeMount(
                'test-volume', mount_path='/tmp/test_volume', sub_path=None, read_only=False
            )

            volume_config = {'persistentVolumeClaim': {'claimName': 'test-volume'}}
            volume = Volume(name='test-volume', configs=volume_config)
            args = [
                "echo \"retrieved from mount\" > /tmp/test_volume/test.txt "
                "&& cat /tmp/test_volume/test.txt"
            ]
            k = KubernetesPodOperator(
                namespace='default',
                image="ubuntu:16.04",
                cmds=["bash", "-cx"],
                arguments=args,
                labels={"foo": "bar"},
                volume_mounts=[volume_mount],
                volumes=[volume],
                is_delete_operator_pod=False,
                name="test",
                task_id="task",
                in_cluster=False,
                do_xcom_push=False,
            )
            context = create_context(k)
            k.execute(context=context)
            mock_logger.info.assert_any_call('retrieved from mount')
            actual_pod = self.api_client.sanitize_for_serialization(k.pod)
            self.expected_pod['spec']['containers'][0]['args'] = args
            self.expected_pod['spec']['containers'][0]['volumeMounts'] = [
                {'name': 'test-volume', 'mountPath': '/tmp/test_volume', 'readOnly': False}
            ]
            self.expected_pod['spec']['volumes'] = [
                {'name': 'test-volume', 'persistentVolumeClaim': {'claimName': 'test-volume'}}
            ]
            assert self.expected_pod == actual_pod

    def test_run_as_user_root(self):
        security_context = {
            'securityContext': {
                'runAsUser': 0,
            }
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        assert self.expected_pod == actual_pod

    def test_run_as_user_non_root(self):
        security_context = {
            'securityContext': {
                'runAsUser': 1000,
            }
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        assert self.expected_pod == actual_pod

    def test_fs_group(self):
        security_context = {
            'securityContext': {
                'fsGroup': 1000,
            }
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        assert self.expected_pod == actual_pod

    def test_faulty_service_account(self):
        """pod creation should fail when service account does not exist"""
        service_account = "foobar"
        namespace = "default"
        k = KubernetesPodOperator(
            namespace=namespace,
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            startup_timeout_seconds=5,
            service_account_name=service_account,
        )
        context = create_context(k)
        pod = k.build_pod_request_obj(context)
        with pytest.raises(
            ApiException, match=f"error looking up service account {namespace}/{service_account}"
        ):
            k.get_or_create_pod(pod, context)

    def test_pod_failure(self):
        """
        Tests that the task fails when a pod reports a failure
        """
        bad_internal_command = ["foobar 10 "]
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=bad_internal_command,
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )
        context = create_context(k)
        with pytest.raises(AirflowException):
            k.execute(context)
        # these assertions run only after the expected failure above
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['args'] = bad_internal_command
        assert self.expected_pod == actual_pod

    def test_xcom_push(self):
        return_value = '{"foo": "bar"\n, "buzz": 2}'
        args = [f'echo \'{return_value}\' > /airflow/xcom/return.json']
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=args,
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=True,
        )
        context = create_context(k)
        assert k.execute(context) == json.loads(return_value)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        volume = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME)
        volume_mount = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME_MOUNT)
        container = self.api_client.sanitize_for_serialization(PodDefaults.SIDECAR_CONTAINER)
        self.expected_pod['spec']['containers'][0]['args'] = args
        self.expected_pod['spec']['containers'][0]['volumeMounts'].insert(0, volume_mount)
        self.expected_pod['spec']['volumes'].insert(0, volume)
        self.expected_pod['spec']['containers'].append(container)
        assert self.expected_pod == actual_pod
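
    # The XCom contract the test above relies on (wired up via the
    # PodDefaults volume, mount, and sidecar serialized a few lines up): the
    # main container writes JSON to /airflow/xcom/return.json and the sidecar
    # keeps the pod alive until the operator reads the file back. A hedged
    # sketch for building such arguments (illustrative values):
    #
    #   payload = {"foo": "bar"}
    #   args = [f"echo '{json.dumps(payload)}' > /airflow/xcom/return.json"]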

    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod")
    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
        # GIVEN
        configmap = 'test-configmap'
        # WHEN
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            configmaps=[configmap],
        )
        # THEN
        mock_pod = MagicMock()
        mock_pod.status.phase = 'Succeeded'
        mock_monitor.return_value = mock_pod
        context = create_context(k)
        k.execute(context)
        assert mock_start.call_args[1]['pod'].spec.containers[0].env_from == [
            k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmap))
        ]

    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod")
    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_envs_from_secrets(self, mock_client, await_pod_completion_mock, create_mock):
        # GIVEN
        secret_ref = 'secret_name'
        secrets = [Secret('env', None, secret_ref)]
        # WHEN
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            secrets=secrets,
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )
        # THEN

        mock_pod = MagicMock()
        mock_pod.status.phase = 'Succeeded'
        await_pod_completion_mock.return_value = mock_pod
        context = create_context(k)
        k.execute(context)
        assert create_mock.call_args[1]['pod'].spec.containers[0].env_from == [
            k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=secret_ref))
        ]

    def test_env_vars(self):
        # WHEN
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            env_vars={
                "ENV1": "val1",
                "ENV2": "val2",
            },
            pod_runtime_info_envs=[PodRuntimeInfoEnv("ENV3", "status.podIP")],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )

        context = create_context(k)
        k.execute(context)

        # THEN
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['env'] = [
            {'name': 'ENV1', 'value': 'val1'},
            {'name': 'ENV2', 'value': 'val2'},
            {'name': 'ENV3', 'valueFrom': {'fieldRef': {'fieldPath': 'status.podIP'}}},
        ]
        assert self.expected_pod == actual_pod
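
The dict-valued env_vars plus PodRuntimeInfoEnv used here is the legacy interface; a later example in this listing passes kubernetes client models directly. A sketch of the typed equivalent of these three variables:

from kubernetes.client import models as k8s

env_vars = [
    k8s.V1EnvVar(name="ENV1", value="val1"),
    k8s.V1EnvVar(name="ENV2", value="val2"),
    # Typed replacement for PodRuntimeInfoEnv("ENV3", "status.podIP").
    k8s.V1EnvVar(
        name="ENV3",
        value_from=k8s.V1EnvVarSource(
            field_ref=k8s.V1ObjectFieldSelector(field_path="status.podIP")),
    ),
]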

    def test_pod_template_file_with_overrides_system(self):
        fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
        k = KubernetesPodOperator(
            task_id="task" + self.get_current_task_name(),
            labels={"foo": "bar", "fizz": "buzz"},
            env_vars={"env_name": "value"},
            in_cluster=False,
            pod_template_file=fixture,
            do_xcom_push=True,
        )

        context = create_context(k)
        result = k.execute(context)
        assert result is not None
        assert k.pod.metadata.labels == {
            'fizz': 'buzz',
            'foo': 'bar',
            'airflow_version': mock.ANY,
            'dag_id': 'dag',
            'execution_date': mock.ANY,
            'kubernetes_pod_operator': 'True',
            'task_id': mock.ANY,
            'try_number': '1',
        }
        assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
        assert result == {"hello": "world"}

    def test_init_container(self):
        # GIVEN
        volume_mounts = [
            k8s.V1VolumeMount(mount_path='/etc/foo', name='test-volume', sub_path=None, read_only=True)
        ]

        init_environments = [
            k8s.V1EnvVar(name='key1', value='value1'),
            k8s.V1EnvVar(name='key2', value='value2'),
        ]

        init_container = k8s.V1Container(
            name="init-container",
            image="ubuntu:16.04",
            env=init_environments,
            volume_mounts=volume_mounts,
            command=["bash", "-cx"],
            args=["echo 10"],
        )

        volume_config = {'persistentVolumeClaim': {'claimName': 'test-volume'}}
        volume = Volume(name='test-volume', configs=volume_config)

        expected_init_container = {
            'name': 'init-container',
            'image': 'ubuntu:16.04',
            'command': ['bash', '-cx'],
            'args': ['echo 10'],
            'env': [{'name': 'key1', 'value': 'value1'}, {'name': 'key2', 'value': 'value2'}],
            'volumeMounts': [{'mountPath': '/etc/foo', 'name': 'test-volume', 'readOnly': True}],
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            volumes=[volume],
            init_containers=[init_container],
            in_cluster=False,
            do_xcom_push=False,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['initContainers'] = [expected_init_container]
        self.expected_pod['spec']['volumes'] = [
            {'name': 'test-volume', 'persistentVolumeClaim': {'claimName': 'test-volume'}}
        ]
        assert self.expected_pod == actual_pod
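
Volume(name=..., configs=...) is the backcompat wrapper around a raw volume dict. Newer provider versions, as in a later example in this listing, accept the kubernetes client model directly; the typed equivalent of this volume:

from kubernetes.client import models as k8s

volume = k8s.V1Volume(
    name='test-volume',
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
        claim_name='test-volume'),
)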
Example 8
class TestKubernetesPodOperatorSystem(unittest.TestCase):
    def get_current_task_name(self):
        # reverse test name to make pod name unique (it has limited length)
        return "_" + unittest.TestCase.id(self).replace(".", "_")[::-1]

    def setUp(self):
        self.maxDiff = None  # pylint: disable=invalid-name
        self.api_client = ApiClient()
        self.expected_pod = {
            'apiVersion': 'v1',
            'kind': 'Pod',
            'metadata': {
                'namespace': 'default',
                'name': mock.ANY,
                'annotations': {},
                'labels': {
                    'foo': 'bar',
                    'kubernetes_pod_operator': 'True',
                    'airflow_version': airflow_version.replace('+', '-'),
                    'execution_date': '2016-01-01T0100000100-a2f50a31f',
                    'dag_id': 'dag',
                    'task_id': 'task',
                    'try_number': '1',
                },
            },
            'spec': {
                'affinity': {},
                'containers': [{
                    'image': 'ubuntu:16.04',
                    'imagePullPolicy': 'IfNotPresent',
                    'args': ["echo 10"],
                    'command': ["bash", "-cx"],
                    'env': [],
                    'envFrom': [],
                    'resources': {},
                    'name': 'base',
                    'ports': [],
                    'volumeMounts': [],
                }],
                'hostNetwork': False,
                'imagePullSecrets': [],
                'initContainers': [],
                'restartPolicy': 'Never',
                'securityContext': {},
                'serviceAccountName': 'default',
                'tolerations': [],
                'volumes': [],
            },
        }

    def tearDown(self):
        client = kube_client.get_kube_client(in_cluster=False)
        client.delete_collection_namespaced_pod(namespace="default")

    def create_context(self, task):
        dag = DAG(dag_id="dag")
        tzinfo = pendulum.timezone("Europe/Amsterdam")
        execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
        task_instance = TaskInstance(task=task, execution_date=execution_date)
        return {
            "dag": dag,
            "ts": execution_date.isoformat(),
            "task": task,
            "ti": task_instance,
        }

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_image_pull_secrets_correctly_set(self, mock_client, monitor_mock,
                                              start_mock):
        fake_pull_secrets = "fakeSecret"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            image_pull_secrets=fake_pull_secrets,
            cluster_context='default',
        )
        monitor_mock.return_value = (State.SUCCESS, None)
        context = self.create_context(k)
        k.execute(context=context)
        self.assertEqual(
            start_mock.call_args[0][0].spec.image_pull_secrets,
            [k8s.V1LocalObjectReference(name=fake_pull_secrets)],
        )
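
Here a bare string is accepted for image_pull_secrets and normalized to a V1LocalObjectReference, as the assertion shows. In newer cncf provider releases the parameter takes the typed list directly (an assumption worth verifying against the installed provider version):

from kubernetes.client import models as k8s

# Typed form of the same pull secret reference.
image_pull_secrets = [k8s.V1LocalObjectReference(name="fakeSecret")]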

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.delete_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_pod_delete_even_on_launcher_error(self, mock_client,
                                               delete_pod_mock,
                                               monitor_pod_mock,
                                               start_pod_mock):  # pylint: disable=unused-argument
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            cluster_context='default',
            is_delete_operator_pod=True,
        )
        monitor_pod_mock.side_effect = AirflowException('fake failure')
        with self.assertRaises(AirflowException):
            context = self.create_context(k)
            k.execute(context=context)
        assert delete_pod_mock.called

    def test_working_pod(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
        self.assertEqual(self.expected_pod['metadata']['labels'],
                         actual_pod['metadata']['labels'])

    def test_pod_node_selectors(self):
        node_selectors = {'beta.kubernetes.io/os': 'linux'}
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            node_selectors=node_selectors,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['nodeSelector'] = node_selectors
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_resources(self):
        resources = {
            'limit_cpu': 0.25,
            'limit_memory': '64Mi',
            'limit_ephemeral_storage': '2Gi',
            'request_cpu': '250m',
            'request_memory': '64Mi',
            'request_ephemeral_storage': '1Gi',
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            resources=resources,
        )
        context = self.create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['resources'] = {
            'requests': {
                'memory': '64Mi',
                'cpu': '250m',
                'ephemeral-storage': '1Gi'
            },
            'limits': {
                'memory': '64Mi',
                'cpu': 0.25,
                'ephemeral-storage': '2Gi'
            },
        }
        self.assertEqual(self.expected_pod, actual_pod)
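
The flat limit_*/request_* dict is the legacy resources interface; a later example in this listing builds the same requirements with the kubernetes client model. The typed equivalent of this dict:

from kubernetes.client import models as k8s

resources = k8s.V1ResourceRequirements(
    requests={'memory': '64Mi', 'cpu': '250m', 'ephemeral-storage': '1Gi'},
    limits={'memory': '64Mi', 'cpu': 0.25, 'ephemeral-storage': '2Gi'},
)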

    def test_pod_affinity(self):
        affinity = {
            'nodeAffinity': {
                'requiredDuringSchedulingIgnoredDuringExecution': {
                    'nodeSelectorTerms': [{
                        'matchExpressions': [{
                            'key': 'beta.kubernetes.io/os',
                            'operator': 'In',
                            'values': ['linux']
                        }]
                    }]
                }
            }
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            affinity=affinity,
        )
        context = create_context(k)
        k.execute(context=context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['affinity'] = affinity
        self.assertEqual(self.expected_pod, actual_pod)
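
The affinity dict above is passed through verbatim, so it must already be in Kubernetes API camelCase form. The same requirement can be written with kubernetes client models, which serialize to an identical dict; a sketch:

from kubernetes.client import models as k8s

affinity = k8s.V1Affinity(
    node_affinity=k8s.V1NodeAffinity(
        required_during_scheduling_ignored_during_execution=k8s.V1NodeSelector(
            node_selector_terms=[
                k8s.V1NodeSelectorTerm(
                    match_expressions=[
                        k8s.V1NodeSelectorRequirement(
                            key='beta.kubernetes.io/os', operator='In', values=['linux'])
                    ]
                )
            ]
        )
    )
)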

    def test_port(self):
        port = Port('http', 80)

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            ports=[port],
        )
        context = self.create_context(k)
        k.execute(context=context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['ports'] = [{
            'name': 'http',
            'containerPort': 80,
        }]
        self.assertEqual(self.expected_pod, actual_pod)

    def test_volume_mount(self):
        with patch.object(PodLauncher, 'log') as mock_logger:
            volume_mount = VolumeMount('test-volume',
                                       mount_path='/tmp/test_volume',
                                       sub_path=None,
                                       read_only=False)

            volume_config = {
                'persistentVolumeClaim': {
                    'claimName': 'test-volume'
                }
            }
            volume = Volume(name='test-volume', configs=volume_config)
            args = [
                "echo \"retrieved from mount\" > /tmp/test_volume/test.txt "
                "&& cat /tmp/test_volume/test.txt"
            ]
            k = KubernetesPodOperator(
                namespace='default',
                image="ubuntu:16.04",
                cmds=["bash", "-cx"],
                arguments=args,
                labels={"foo": "bar"},
                volume_mounts=[volume_mount],
                volumes=[volume],
                is_delete_operator_pod=False,
                name="test",
                task_id="task",
                in_cluster=False,
                do_xcom_push=False,
            )
            context = create_context(k)
            k.execute(context=context)
            mock_logger.info.assert_any_call('retrieved from mount')
            actual_pod = self.api_client.sanitize_for_serialization(k.pod)
            self.expected_pod['spec']['containers'][0]['args'] = args
            self.expected_pod['spec']['containers'][0]['volumeMounts'] = [{
                'name': 'test-volume',
                'mountPath': '/tmp/test_volume',
                'readOnly': False,
            }]
            self.expected_pod['spec']['volumes'] = [{
                'name': 'test-volume',
                'persistentVolumeClaim': {
                    'claimName': 'test-volume'
                }
            }]
            self.assertEqual(self.expected_pod, actual_pod)

    def test_run_as_user_root(self):
        security_context = {
            'securityContext': {
                'runAsUser': 0,
            }
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        self.assertEqual(self.expected_pod, actual_pod)

    def test_run_as_user_non_root(self):
        security_context = {
            'securityContext': {
                'runAsUser': 1000,
            }
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        self.assertEqual(self.expected_pod, actual_pod)

    def test_fs_group(self):
        security_context = {
            'securityContext': {
                'fsGroup': 1000,
            }
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        self.assertEqual(self.expected_pod, actual_pod)

    def test_faulty_service_account(self):
        bad_service_account_name = "foobar"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            startup_timeout_seconds=5,
            service_account_name=bad_service_account_name,
        )
        with self.assertRaises(ApiException):
            context = create_context(k)
            k.execute(context)
            actual_pod = self.api_client.sanitize_for_serialization(k.pod)
            self.expected_pod['spec']['serviceAccountName'] = bad_service_account_name
            self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_failure(self):
        """
        Tests that the task fails when a pod reports a failure
        """
        bad_internal_command = ["foobar 10 "]
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=bad_internal_command,
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )
        with self.assertRaises(AirflowException):
            context = create_context(k)
            k.execute(context)
            actual_pod = self.api_client.sanitize_for_serialization(k.pod)
            self.expected_pod['spec']['containers'][0]['args'] = bad_internal_command
            self.assertEqual(self.expected_pod, actual_pod)
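
Note a pitfall in test_faulty_service_account and test_pod_failure above: k.execute() raises, so every statement after it inside the assertRaises block is dead code and the pod assertions never run. A later example in this listing uses the corrected shape, with only the raising call inside the block; a sketch of that pattern (the parameter names are placeholders standing in for the test fixtures):

import pytest
from airflow.exceptions import AirflowException

def _assert_pod_after_failure(k, api_client, expected_pod, create_context):
    # Only the raising call lives inside the block ...
    with pytest.raises(AirflowException):
        k.execute(create_context(k))
    # ... so these assertions actually execute.
    actual_pod = api_client.sanitize_for_serialization(k.pod)
    assert expected_pod == actual_pod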

    def test_xcom_push(self):
        return_value = '{"foo": "bar"\n, "buzz": 2}'
        args = [f'echo \'{return_value}\' > /airflow/xcom/return.json']
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=args,
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=True,
        )
        context = create_context(k)
        self.assertEqual(k.execute(context), json.loads(return_value))
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        volume = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME)
        volume_mount = self.api_client.sanitize_for_serialization(
            PodDefaults.VOLUME_MOUNT)
        container = self.api_client.sanitize_for_serialization(
            PodDefaults.SIDECAR_CONTAINER)
        self.expected_pod['spec']['containers'][0]['args'] = args
        self.expected_pod['spec']['containers'][0]['volumeMounts'].insert(0, volume_mount)
        self.expected_pod['spec']['volumes'].insert(0, volume)
        self.expected_pod['spec']['containers'].append(container)
        self.assertEqual(self.expected_pod, actual_pod)

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
        # GIVEN
        configmap = 'test-configmap'
        # WHEN
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            configmaps=[configmap],
        )
        # THEN
        mock_monitor.return_value = (State.SUCCESS, None)
        context = self.create_context(k)
        k.execute(context)
        self.assertEqual(
            mock_start.call_args[0][0].spec.containers[0].env_from,
            [
                k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(
                    name=configmap))
            ],
        )

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock):
        # GIVEN
        secret_ref = 'secret_name'
        secrets = [Secret('env', None, secret_ref)]
        # WHEN
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            secrets=secrets,
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )
        # THEN
        monitor_mock.return_value = (State.SUCCESS, None)
        context = self.create_context(k)
        k.execute(context)
        self.assertEqual(
            start_mock.call_args[0][0].spec.containers[0].env_from,
            [
                k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(
                    name=secret_ref))
            ],
        )

    def test_env_vars(self):
        # WHEN
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            env_vars={
                "ENV1": "val1",
                "ENV2": "val2",
            },
            pod_runtime_info_envs=[PodRuntimeInfoEnv("ENV3", "status.podIP")],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
        )

        context = create_context(k)
        k.execute(context)

        # THEN
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['env'] = [
            {
                'name': 'ENV1',
                'value': 'val1'
            },
            {
                'name': 'ENV2',
                'value': 'val2'
            },
            {
                'name': 'ENV3',
                'valueFrom': {
                    'fieldRef': {
                        'fieldPath': 'status.podIP'
                    }
                }
            },
        ]
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_template_file_with_overrides_system(self):
        fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
        k = KubernetesPodOperator(
            task_id="task" + self.get_current_task_name(),
            labels={
                "foo": "bar",
                "fizz": "buzz"
            },
            env_vars={"env_name": "value"},
            in_cluster=False,
            pod_template_file=fixture,
            do_xcom_push=True,
        )

        context = create_context(k)
        result = k.execute(context)
        self.assertIsNotNone(result)
        self.assertEqual(k.pod.metadata.labels, {'fizz': 'buzz', 'foo': 'bar'})
        self.assertEqual(k.pod.spec.containers[0].env,
                         [k8s.V1EnvVar(name="env_name", value="value")])
        self.assertDictEqual(result, {"hello": "world"})

    def test_init_container(self):
        # GIVEN
        volume_mounts = [
            k8s.V1VolumeMount(mount_path='/etc/foo',
                              name='test-volume',
                              sub_path=None,
                              read_only=True)
        ]

        init_environments = [
            k8s.V1EnvVar(name='key1', value='value1'),
            k8s.V1EnvVar(name='key2', value='value2'),
        ]

        init_container = k8s.V1Container(
            name="init-container",
            image="ubuntu:16.04",
            env=init_environments,
            volume_mounts=volume_mounts,
            command=["bash", "-cx"],
            args=["echo 10"],
        )

        volume_config = {'persistentVolumeClaim': {'claimName': 'test-volume'}}
        volume = Volume(name='test-volume', configs=volume_config)

        expected_init_container = {
            'name': 'init-container',
            'image': 'ubuntu:16.04',
            'command': ['bash', '-cx'],
            'args': ['echo 10'],
            'env': [{
                'name': 'key1',
                'value': 'value1'
            }, {
                'name': 'key2',
                'value': 'value2'
            }],
            'volumeMounts': [{
                'mountPath': '/etc/foo',
                'name': 'test-volume',
                'readOnly': True
            }],
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            volumes=[volume],
            init_containers=[init_container],
            in_cluster=False,
            do_xcom_push=False,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['initContainers'] = [expected_init_container]
        self.expected_pod['spec']['volumes'] = [{
            'name': 'test-volume',
            'persistentVolumeClaim': {
                'claimName': 'test-volume'
            }
        }]
        self.assertEqual(self.expected_pod, actual_pod)

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_pod_priority_class_name(self, mock_client, monitor_mock,
                                     start_mock):  # pylint: disable=unused-argument
        """Test ability to assign priorityClassName to pod"""
        priority_class_name = "medium-test"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test",
            task_id="task",
            in_cluster=False,
            do_xcom_push=False,
            priority_class_name=priority_class_name,
        )

        monitor_mock.return_value = (State.SUCCESS, None)
        context = self.create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['priorityClassName'] = priority_class_name
        self.assertEqual(self.expected_pod, actual_pod)

    def test_pod_name(self):
        pod_name_too_long = "a" * 221
        with self.assertRaises(AirflowException):
            KubernetesPodOperator(
                namespace='default',
                image="ubuntu:16.04",
                cmds=["bash", "-cx"],
                arguments=["echo 10"],
                labels={"foo": "bar"},
                name=pod_name_too_long,
                task_id="task",
                in_cluster=False,
                do_xcom_push=False,
            )
Example 9
class TestKubernetesPodOperatorSystem(unittest.TestCase):
    def get_current_task_name(self):
        # reverse test name to make pod name unique (it has limited length)
        return "_" + unittest.TestCase.id(self).replace(".", "_")[::-1]

    def setUp(self):
        self.maxDiff = None
        self.api_client = ApiClient()
        self.expected_pod = {
            'apiVersion': 'v1',
            'kind': 'Pod',
            'metadata': {
                'namespace': 'default',
                'name': ANY,
                'annotations': {},
                'labels': {
                    'foo': 'bar',
                    'kubernetes_pod_operator': 'True',
                    'airflow_version': airflow_version.replace('+', '-'),
                    'execution_date': '2016-01-01T0100000100-a2f50a31f',
                    'dag_id': 'dag',
                    'task_id': ANY,
                    'try_number': '1',
                },
            },
            'spec': {
                'affinity': {},
                'containers': [{
                    'image': 'ubuntu:16.04',
                    'args': ["echo 10"],
                    'command': ["bash", "-cx"],
                    'env': [],
                    'envFrom': [],
                    'resources': {},
                    'name': 'base',
                    'ports': [],
                    'volumeMounts': [],
                }],
                'hostNetwork': False,
                'imagePullSecrets': [],
                'initContainers': [],
                'nodeSelector': {},
                'restartPolicy': 'Never',
                'securityContext': {},
                'tolerations': [],
                'volumes': [],
            },
        }

    def tearDown(self) -> None:
        client = kube_client.get_kube_client(in_cluster=False)
        client.delete_collection_namespaced_pod(namespace="default")
        import time
        time.sleep(1)

    def test_do_xcom_push_defaults_false(self):
        new_config_path = '/tmp/kube_config'
        old_config_path = get_kubeconfig_path()
        shutil.copy(old_config_path, new_config_path)

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            config_file=new_config_path,
        )
        assert not k.do_xcom_push

    def test_config_path_move(self):
        new_config_path = '/tmp/kube_config'
        old_config_path = get_kubeconfig_path()
        shutil.copy(old_config_path, new_config_path)

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test1",
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            is_delete_operator_pod=False,
            config_file=new_config_path,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        assert self.expected_pod == actual_pod

    def test_working_pod(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        assert self.expected_pod['spec'] == actual_pod['spec']
        assert self.expected_pod['metadata']['labels'] == actual_pod['metadata']['labels']

    def test_delete_operator_pod(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            is_delete_operator_pod=True,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        assert self.expected_pod['spec'] == actual_pod['spec']
        assert self.expected_pod['metadata']['labels'] == actual_pod['metadata']['labels']

    def test_pod_hostnetwork(self):
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            hostnetwork=True,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['hostNetwork'] = True
        assert self.expected_pod['spec'] == actual_pod['spec']
        assert self.expected_pod['metadata']['labels'] == actual_pod['metadata']['labels']

    def test_pod_dnspolicy(self):
        dns_policy = "ClusterFirstWithHostNet"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            hostnetwork=True,
            dnspolicy=dns_policy,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['hostNetwork'] = True
        self.expected_pod['spec']['dnsPolicy'] = dns_policy
        assert self.expected_pod['spec'] == actual_pod['spec']
        assert self.expected_pod['metadata']['labels'] == actual_pod['metadata']['labels']

    def test_pod_schedulername(self):
        scheduler_name = "default-scheduler"
        k = KubernetesPodOperator(
            namespace="default",
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            schedulername=scheduler_name,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['schedulerName'] = scheduler_name
        assert self.expected_pod == actual_pod

    def test_pod_node_selectors(self):
        node_selectors = {'beta.kubernetes.io/os': 'linux'}
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            node_selectors=node_selectors,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['nodeSelector'] = node_selectors
        assert self.expected_pod == actual_pod
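
beta.kubernetes.io/os has been deprecated since Kubernetes 1.14 in favor of the GA kubernetes.io/os label (and the beta label was removed in later releases), so on current clusters the selector is better written as:

# Stable replacement for the deprecated beta label used above.
node_selectors = {'kubernetes.io/os': 'linux'}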

    def test_pod_resources(self):
        resources = k8s.V1ResourceRequirements(
            requests={
                'memory': '64Mi',
                'cpu': '250m',
                'ephemeral-storage': '1Gi'
            },
            limits={
                'memory': '64Mi',
                'cpu': 0.25,
                'nvidia.com/gpu': None,
                'ephemeral-storage': '2Gi'
            },
        )
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            resources=resources,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['resources'] = {
            'requests': {
                'memory': '64Mi',
                'cpu': '250m',
                'ephemeral-storage': '1Gi'
            },
            'limits': {
                'memory': '64Mi',
                'cpu': 0.25,
                'nvidia.com/gpu': None,
                'ephemeral-storage': '2Gi'
            },
        }
        assert self.expected_pod == actual_pod

    def test_pod_affinity(self):
        affinity = {
            'nodeAffinity': {
                'requiredDuringSchedulingIgnoredDuringExecution': {
                    'nodeSelectorTerms': [{
                        'matchExpressions': [{
                            'key': 'beta.kubernetes.io/os',
                            'operator': 'In',
                            'values': ['linux']
                        }]
                    }]
                }
            }
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            affinity=affinity,
        )
        context = create_context(k)
        k.execute(context=context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['affinity'] = affinity
        assert self.expected_pod == actual_pod

    def test_port(self):
        port = k8s.V1ContainerPort(
            name='http',
            container_port=80,
        )

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            ports=[port],
        )
        context = create_context(k)
        k.execute(context=context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['containers'][0]['ports'] = [{
            'name': 'http',
            'containerPort': 80,
        }]
        assert self.expected_pod == actual_pod

    def test_volume_mount(self):
        with mock.patch.object(PodManager, 'log') as mock_logger:
            volume_mount = k8s.V1VolumeMount(name='test-volume',
                                             mount_path='/tmp/test_volume',
                                             sub_path=None,
                                             read_only=False)

            volume = k8s.V1Volume(
                name='test-volume',
                persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
                    claim_name='test-volume'),
            )

            args = [
                "echo \"retrieved from mount\" > /tmp/test_volume/test.txt "
                "&& cat /tmp/test_volume/test.txt"
            ]
            k = KubernetesPodOperator(
                namespace='default',
                image="ubuntu:16.04",
                cmds=["bash", "-cx"],
                arguments=args,
                labels={"foo": "bar"},
                volume_mounts=[volume_mount],
                volumes=[volume],
                name="test-" + str(random.randint(0, 1000000)),
                task_id="task" + self.get_current_task_name(),
                in_cluster=False,
                do_xcom_push=False,
            )
            context = create_context(k)
            k.execute(context=context)
            mock_logger.info.assert_any_call('retrieved from mount')
            actual_pod = self.api_client.sanitize_for_serialization(k.pod)
            self.expected_pod['spec']['containers'][0]['args'] = args
            self.expected_pod['spec']['containers'][0]['volumeMounts'] = [{
                'name': 'test-volume',
                'mountPath': '/tmp/test_volume',
                'readOnly': False,
            }]
            self.expected_pod['spec']['volumes'] = [{
                'name': 'test-volume',
                'persistentVolumeClaim': {
                    'claimName': 'test-volume'
                }
            }]
            assert self.expected_pod == actual_pod

    def test_run_as_user_root(self):
        security_context = {
            'securityContext': {
                'runAsUser': 0,
            }
        }
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        assert self.expected_pod == actual_pod

    def test_run_as_user_non_root(self):
        security_context = {
            'securityContext': {
                'runAsUser': 1000,
            }
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        assert self.expected_pod == actual_pod

    def test_fs_group(self):
        security_context = {
            'securityContext': {
                'fsGroup': 1000,
            }
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-fs-group",
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            security_context=security_context,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['securityContext'] = security_context
        assert self.expected_pod == actual_pod

    def test_faulty_image(self):
        bad_image_name = "foobar"
        k = KubernetesPodOperator(
            namespace='default',
            image=bad_image_name,
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            startup_timeout_seconds=5,
        )
        with pytest.raises(AirflowException):
            context = create_context(k)
            k.execute(context)
            actual_pod = self.api_client.sanitize_for_serialization(k.pod)
            self.expected_pod['spec']['containers'][0]['image'] = bad_image_name
            assert self.expected_pod == actual_pod

    def test_faulty_service_account(self):
        bad_service_account_name = "foobar"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            startup_timeout_seconds=5,
            service_account_name=bad_service_account_name,
        )
        context = create_context(k)
        pod = k.build_pod_request_obj(context)
        with pytest.raises(
                ApiException,
                match="error looking up service account default/foobar"):
            k.get_or_create_pod(pod, context)
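
Splitting execute() into build_pod_request_obj() and get_or_create_pod() pins the failure to pod creation rather than the whole task, and the same split is handy for dry-running. A sketch that renders the pod the operator would submit, without touching a cluster (it reuses this listing's create_context helper and an operator instance k):

from kubernetes.client import ApiClient

context = create_context(k)
pod = k.build_pod_request_obj(context)
# Inspect the fully rendered pod as the Kubernetes API would receive it.
print(ApiClient().sanitize_for_serialization(pod))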

    def test_pod_failure(self):
        """
        Tests that the task fails when a pod reports a failure
        """
        bad_internal_command = ["foobar 10 "]
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=bad_internal_command,
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
        )
        with pytest.raises(AirflowException):
            context = create_context(k)
            k.execute(context)
            actual_pod = self.api_client.sanitize_for_serialization(k.pod)
            self.expected_pod['spec']['containers'][0]['args'] = bad_internal_command
            assert self.expected_pod == actual_pod

    @mock.patch("airflow.models.taskinstance.TaskInstance.xcom_push")
    def test_xcom_push(self, xcom_push):
        return_value = '{"foo": "bar"\n, "buzz": 2}'
        args = [f'echo \'{return_value}\' > /airflow/xcom/return.json']
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=args,
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=True,
        )
        context = create_context(k)
        k.execute(context)
        xcom_push.assert_called_once_with(key=XCOM_RETURN_KEY, value=json.loads(return_value))
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        volume = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME)
        volume_mount = self.api_client.sanitize_for_serialization(
            PodDefaults.VOLUME_MOUNT)
        container = self.api_client.sanitize_for_serialization(
            PodDefaults.SIDECAR_CONTAINER)
        self.expected_pod['spec']['containers'][0]['args'] = args
        self.expected_pod['spec']['containers'][0]['volumeMounts'].insert(
            0, volume_mount)
        self.expected_pod['spec']['volumes'].insert(0, volume)
        self.expected_pod['spec']['containers'].append(container)
        assert self.expected_pod == actual_pod

    @mock.patch(
        "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod"
    )
    @mock.patch(
        "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion"
    )
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_envs_from_secrets(self, mock_client, await_pod_completion_mock,
                               create_pod):
        # GIVEN

        secret_ref = 'secret_name'
        secrets = [Secret('env', None, secret_ref)]
        # WHEN
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            secrets=secrets,
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
        )
        # THEN
        await_pod_completion_mock.return_value = None
        context = create_context(k)
        with pytest.raises(AirflowException):
            k.execute(context)
        assert create_pod.call_args[1]['pod'].spec.containers[0].env_from == [
            k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(
                name=secret_ref))
        ]

    def test_env_vars(self):
        # WHEN
        env_vars = [
            k8s.V1EnvVar(name="ENV1", value="val1"),
            k8s.V1EnvVar(name="ENV2", value="val2"),
            k8s.V1EnvVar(
                name="ENV3",
                value_from=k8s.V1EnvVarSource(
                    field_ref=k8s.V1ObjectFieldSelector(
                        field_path="status.podIP")),
            ),
        ]

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            env_vars=env_vars,
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
        )
        # THEN
        context = create_context(k)
        actual_pod = self.api_client.sanitize_for_serialization(
            k.build_pod_request_obj(context))
        self.expected_pod['spec']['containers'][0]['env'] = [
            {
                'name': 'ENV1',
                'value': 'val1'
            },
            {
                'name': 'ENV2',
                'value': 'val2'
            },
            {
                'name': 'ENV3',
                'valueFrom': {
                    'fieldRef': {
                        'fieldPath': 'status.podIP'
                    }
                }
            },
        ]
        assert self.expected_pod == actual_pod

    def test_pod_template_file_system(self):
        fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
        k = KubernetesPodOperator(
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            pod_template_file=fixture,
            do_xcom_push=True,
        )

        context = create_context(k)
        result = k.execute(context)
        assert result is not None
        assert result == {"hello": "world"}

    def test_pod_template_file_with_overrides_system(self):
        fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
        k = KubernetesPodOperator(
            task_id="task" + self.get_current_task_name(),
            labels={
                "foo": "bar",
                "fizz": "buzz"
            },
            env_vars=[k8s.V1EnvVar(name="env_name", value="value")],
            in_cluster=False,
            pod_template_file=fixture,
            do_xcom_push=True,
        )

        context = create_context(k)
        result = k.execute(context)
        assert result is not None
        assert k.pod.metadata.labels == {
            'fizz': 'buzz',
            'foo': 'bar',
            'airflow_version': mock.ANY,
            'dag_id': 'dag',
            'execution_date': mock.ANY,
            'kubernetes_pod_operator': 'True',
            'task_id': mock.ANY,
            'try_number': '1',
        }
        assert k.pod.spec.containers[0].env == [
            k8s.V1EnvVar(name="env_name", value="value")
        ]
        assert result == {"hello": "world"}

    def test_pod_template_file_with_full_pod_spec(self):
        fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
        pod_spec = k8s.V1Pod(
            metadata=k8s.V1ObjectMeta(labels={"foo": "bar", "fizz": "buzz"}),
            spec=k8s.V1PodSpec(containers=[
                k8s.V1Container(
                    name="base",
                    env=[k8s.V1EnvVar(name="env_name", value="value")],
                )
            ]),
        )
        k = KubernetesPodOperator(
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            pod_template_file=fixture,
            full_pod_spec=pod_spec,
            do_xcom_push=True,
        )

        context = create_context(k)
        result = k.execute(context)
        assert result is not None
        assert k.pod.metadata.labels == {
            'fizz': 'buzz',
            'foo': 'bar',
            'airflow_version': mock.ANY,
            'dag_id': 'dag',
            'execution_date': mock.ANY,
            'kubernetes_pod_operator': 'True',
            'task_id': mock.ANY,
            'try_number': '1',
        }
        assert k.pod.spec.containers[0].env == [
            k8s.V1EnvVar(name="env_name", value="value")
        ]
        assert result == {"hello": "world"}

    def test_full_pod_spec(self):
        pod_spec = k8s.V1Pod(
            metadata=k8s.V1ObjectMeta(
                labels={"foo": "bar", "fizz": "buzz"},
                namespace="default",
                name="test-pod",
            ),
            spec=k8s.V1PodSpec(
                containers=[
                    k8s.V1Container(
                        name="base",
                        image="perl",
                        command=["/bin/bash"],
                        args=[
                            "-c",
                            'echo {\\"hello\\" : \\"world\\"} | cat > /airflow/xcom/return.json'
                        ],
                        env=[k8s.V1EnvVar(name="env_name", value="value")],
                    )
                ],
                restart_policy="Never",
            ),
        )
        k = KubernetesPodOperator(
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            full_pod_spec=pod_spec,
            do_xcom_push=True,
            is_delete_operator_pod=False,
        )

        context = create_context(k)
        result = k.execute(context)
        assert result is not None
        assert k.pod.metadata.labels == {
            'fizz': 'buzz',
            'foo': 'bar',
            'airflow_version': mock.ANY,
            'dag_id': 'dag',
            'execution_date': mock.ANY,
            'kubernetes_pod_operator': 'True',
            'task_id': mock.ANY,
            'try_number': '1',
        }
        assert k.pod.spec.containers[0].env == [
            k8s.V1EnvVar(name="env_name", value="value")
        ]
        assert result == {"hello": "world"}

    def test_init_container(self):
        # GIVEN
        volume_mounts = [
            k8s.V1VolumeMount(mount_path='/etc/foo',
                              name='test-volume',
                              sub_path=None,
                              read_only=True)
        ]

        init_environments = [
            k8s.V1EnvVar(name='key1', value='value1'),
            k8s.V1EnvVar(name='key2', value='value2'),
        ]

        init_container = k8s.V1Container(
            name="init-container",
            image="ubuntu:16.04",
            env=init_environments,
            volume_mounts=volume_mounts,
            command=["bash", "-cx"],
            args=["echo 10"],
        )

        volume = k8s.V1Volume(
            name='test-volume',
            persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
                claim_name='test-volume'),
        )
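        # This test runs against a real API server, so a PVC named "test-volume"
        # is presumably pre-provisioned by the test cluster's fixtures; without
        # it the pod could not schedule and the test would hang.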
        expected_init_container = {
            'name': 'init-container',
            'image': 'ubuntu:16.04',
            'command': ['bash', '-cx'],
            'args': ['echo 10'],
            'env': [{
                'name': 'key1',
                'value': 'value1'
            }, {
                'name': 'key2',
                'value': 'value2'
            }],
            'volumeMounts': [{
                'mountPath': '/etc/foo',
                'name': 'test-volume',
                'readOnly': True
            }],
        }

        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            volumes=[volume],
            init_containers=[init_container],
            in_cluster=False,
            do_xcom_push=False,
        )
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['initContainers'] = [expected_init_container]
        self.expected_pod['spec']['volumes'] = [{
            'name': 'test-volume',
            'persistentVolumeClaim': {
                'claimName': 'test-volume'
            }
        }]
        assert self.expected_pod == actual_pod

    @mock.patch(
        "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.extract_xcom"
    )
    @mock.patch(
        "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod"
    )
    @mock.patch(
        "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion"
    )
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_pod_template_file(self, mock_client, await_pod_completion_mock,
                               create_mock, extract_xcom_mock):
        extract_xcom_mock.return_value = '{}'
        path = sys.path[0] + '/tests/kubernetes/pod.yaml'
        k = KubernetesPodOperator(
            task_id="task" + self.get_current_task_name(),
            random_name_suffix=False,
            pod_template_file=path,
            do_xcom_push=True,
        )
        pod_mock = MagicMock()
        pod_mock.status.phase = 'Succeeded'
        await_pod_completion_mock.return_value = pod_mock
        context = create_context(k)
        with self.assertLogs(k.log, level=logging.DEBUG) as cm:
            k.execute(context)
            expected_line = textwrap.dedent("""\
            DEBUG:airflow.task.operators:Starting pod:
            api_version: v1
            kind: Pod
            metadata:
              annotations: {}
              cluster_name: null
              creation_timestamp: null
              deletion_grace_period_seconds: null\
            """).strip()
            assert any(line.startswith(expected_line) for line in cm.output)

        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        expected_dict = {
            'apiVersion': 'v1',
            'kind': 'Pod',
            'metadata': {
                'annotations': {},
                'labels': {
                    'dag_id': 'dag',
                    'execution_date': mock.ANY,
                    'kubernetes_pod_operator': 'True',
                    'task_id': mock.ANY,
                    'try_number': '1',
                },
                'name': 'memory-demo',
                'namespace': 'mem-example',
            },
            'spec': {
                'affinity': {},
                'containers': [
                    {
                        'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'],
                        'command': ['stress'],
                        'env': [],
                        'envFrom': [],
                        'image': 'ghcr.io/apache/airflow-stress:1.0.4-2021.07.04',
                        'name': 'base',
                        'ports': [],
                        'resources': {
                            'limits': {
                                'memory': '200Mi'
                            },
                            'requests': {
                                'memory': '100Mi'
                            }
                        },
                        'volumeMounts': [{
                            'mountPath': '/airflow/xcom',
                            'name': 'xcom'
                        }],
                    },
                    {
                        'command': [
                            'sh', '-c',
                            'trap "exit 0" INT; while true; do sleep 1; done;'
                        ],
                        'image': 'alpine',
                        'name': 'airflow-xcom-sidecar',
                        'resources': {
                            'requests': {
                                'cpu': '1m'
                            }
                        },
                        'volumeMounts': [{
                            'mountPath': '/airflow/xcom',
                            'name': 'xcom'
                        }],
                    },
                ],
                'hostNetwork': False,
                'imagePullSecrets': [],
                'initContainers': [],
                'nodeSelector': {},
                'restartPolicy': 'Never',
                'securityContext': {},
                'tolerations': [],
                'volumes': [{
                    'emptyDir': {},
                    'name': 'xcom'
                }],
            },
        }
        version = actual_pod['metadata']['labels']['airflow_version']
        assert version.startswith(airflow_version)
        del actual_pod['metadata']['labels']['airflow_version']
        assert expected_dict == actual_pod
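        # For reference, a template consistent with expected_dict would look
        # roughly like this (a sketch inferred from the assertions above, not
        # the verbatim contents of tests/kubernetes/pod.yaml; the xcom sidecar
        # container and its volume are injected by the operator because
        # do_xcom_push=True):
        #
        #   apiVersion: v1
        #   kind: Pod
        #   metadata:
        #     name: memory-demo
        #     namespace: mem-example
        #   spec:
        #     restartPolicy: Never
        #     containers:
        #       - name: base
        #         image: ghcr.io/apache/airflow-stress:1.0.4-2021.07.04
        #         command: ["stress"]
        #         args: ["--vm", "1", "--vm-bytes", "150M", "--vm-hang", "1"]
        #         resources:
        #           requests:
        #             memory: "100Mi"
        #           limits:
        #             memory: "200Mi"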

    @mock.patch(
        "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod"
    )
    @mock.patch(
        "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion"
    )
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_pod_priority_class_name(self, mock_client,
                                     await_pod_completion_mock, create_mock):
        """Test ability to assign priorityClassName to pod"""

        priority_class_name = "medium-test"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["echo 10"],
            labels={"foo": "bar"},
            name="test-" + str(random.randint(0, 1000000)),
            task_id="task" + self.get_current_task_name(),
            in_cluster=False,
            do_xcom_push=False,
            priority_class_name=priority_class_name,
        )

        pod_mock = MagicMock()
        pod_mock.status.phase = 'Succeeded'
        await_pod_completion_mock.return_value = pod_mock
        context = create_context(k)
        k.execute(context)
        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
        self.expected_pod['spec']['priorityClassName'] = priority_class_name
        assert self.expected_pod == actual_pod
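        # await_pod_completion is mocked, so the PriorityClass does not have to
        # exist here. Running unmocked against a live cluster would presumably
        # require a matching object first, e.g. (sketch; the value is arbitrary):
        #
        #   apiVersion: scheduling.k8s.io/v1
        #   kind: PriorityClass
        #   metadata:
        #     name: medium-test
        #   value: 1000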

    def test_pod_name(self):
        pod_name_too_long = "a" * 221
        with pytest.raises(AirflowException):
            KubernetesPodOperator(
                namespace='default',
                image="ubuntu:16.04",
                cmds=["bash", "-cx"],
                arguments=["echo 10"],
                labels={"foo": "bar"},
                name=pod_name_too_long,
                task_id="task" + self.get_current_task_name(),
                in_cluster=False,
                do_xcom_push=False,
            )
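        # Kubernetes caps object names at 253 characters (a DNS subdomain), and
        # the operator validates the base name at construction time, before any
        # random suffix is appended, hence the rejection above. A hedged sketch
        # of the presumed boundary (the exact limit is an assumption, not
        # asserted here):
        #
        #   KubernetesPodOperator(
        #       namespace='default',
        #       image="ubuntu:16.04",
        #       name="a" * 220,  # presumably accepted where 221 chars raise
        #       task_id="task",
        #       in_cluster=False,
        #       do_xcom_push=False,
        #   )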

    @mock.patch(
        "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion"
    )
    def test_on_kill(self, await_pod_completion_mock):
        client = kube_client.get_kube_client(in_cluster=False)
        name = "test"
        namespace = "default"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["sleep 1000"],
            labels={"foo": "bar"},
            name="test",
            task_id=name,
            in_cluster=False,
            do_xcom_push=False,
            get_logs=False,
            termination_grace_period=0,
        )
        context = create_context(k)
        with pytest.raises(AirflowException):
            k.execute(context)
        name = k.pod.metadata.name
        pod = client.read_namespaced_pod(name=name, namespace=namespace)
        assert pod.status.phase == "Running"
        k.on_kill()
        with pytest.raises(ApiException,
                           match=r'pods \\"test.[a-z0-9]+\\" not found'):
            client.read_namespaced_pod(name=name, namespace=namespace)

    def test_reattach_failing_pod_once(self):
        client = kube_client.get_kube_client(in_cluster=False)
        name = "test"
        namespace = "default"
        k = KubernetesPodOperator(
            namespace='default',
            image="ubuntu:16.04",
            cmds=["bash", "-cx"],
            arguments=["exit 1"],
            labels={"foo": "bar"},
            name="test",
            task_id=name,
            in_cluster=False,
            do_xcom_push=False,
            is_delete_operator_pod=False,
            termination_grace_period=0,
        )

        context = create_context(k)

        # launch pod
        with mock.patch(
                "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion"
        ) as await_pod_completion_mock:
            pod_mock = MagicMock()

            # mock a successful completion so the pod is not patched as "already_checked"
            pod_mock.status.phase = 'Succeeded'
            await_pod_completion_mock.return_value = pod_mock
            k.execute(context)
            name = k.pod.metadata.name
            pod = client.read_namespaced_pod(name=name, namespace=namespace)
            while pod.status.phase != "Failed":
                pod = client.read_namespaced_pod(name=name,
                                                 namespace=namespace)
            assert 'already_checked' not in pod.metadata.labels

        # `create_pod` should not be called: the operator should find the existing
        # pod, reattach to it, and patch it as "already_checked" in the failure block
        with mock.patch(
                "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod"
        ) as create_mock:
            with pytest.raises(AirflowException):
                k.execute(context)
            pod = client.read_namespaced_pod(name=name, namespace=namespace)
            assert pod.metadata.labels["already_checked"] == "True"
            create_mock.assert_not_called()

        # `create_pod` should now be called: although a pod is still there to be
        # found, it is labeled `already_checked`, so a fresh pod must be launched
        with mock.patch(
                "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod"
        ) as create_mock:
            with pytest.raises(AirflowException):
                k.execute(context)
            create_mock.assert_called_once()
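        # Taken together, the three blocks above exercise the reattach
        # lifecycle: a failed-but-unlabeled pod is adopted on retry instead of
        # re-created, the failure path patches it with already_checked=True,
        # and a labeled pod is ignored so that a fresh one is launched.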