def test_lineage_send(self, atlas_mock):
        """Check that sending lineage creates typedefs and posts entities.

        One input file and one output file are attached to a single task.
        The object returned by ``atlas_mock`` must then record exactly one
        ``typedefs.create`` call and three calls on ``entity_post``, with
        ``entity_post.create`` among them.
        """
        typedefs_mock = mock.MagicMock()
        entity_post_mock = mock.MagicMock()
        atlas_mock.return_value = mock.Mock(typedefs=typedefs_mock,
                                            entity_post=entity_post_mock)

        dag = DAG(dag_id='test_prepare_lineage', start_date=DEFAULT_DATE)

        source_files = [File("/tmp/does_not_exist_1")]
        target_files = [File("/tmp/does_not_exist_2")]

        with dag:
            task = DummyOperator(
                task_id='leave1',
                inlets={"datasets": source_files},
                outlets={"datasets": target_files},
            )

        context = {"ti": TI(task=task, execution_date=DEFAULT_DATE)}
        self.atlas.send_lineage(
            operator=task,
            inlets=source_files,
            outlets=target_files,
            context=context,
        )

        self.assertEqual(typedefs_mock.create.call_count, 1)
        self.assertTrue(entity_post_mock.create.called)
        self.assertEqual(len(entity_post_mock.mock_calls), 3)
Beispiel #2
0
 def _setup_connections(get_connections, uri):
     """Prime ``get_connections`` with one GCP and two Cloud SQL connections.

     Successive calls of the mock return, in order: a list holding a mocked
     GCP connection whose ``extra_dejson.get`` yields ``'empty_project'``,
     then two single-element lists, each holding a distinct ``Connection``
     parsed from ``uri``.
     """
     gcp_conn = mock.MagicMock()
     gcp_conn.extra_dejson = mock.MagicMock()
     gcp_conn.extra_dejson.get.return_value = 'empty_project'

     # Two separate Connection objects built from the same URI.
     sql_conns = []
     for _ in range(2):
         sql_conn = Connection()
         sql_conn.parse_from_uri(uri)
         sql_conns.append(sql_conn)

     get_connections.side_effect = [[gcp_conn]] + [[c] for c in sql_conns]
 def setUp(self):
     """Create the process mocks and the ``GunicornMonitor`` under test."""
     self.gunicorn_master_proc = mock.Mock(pid=2137)
     self.children = mock.MagicMock()
     self.child = mock.MagicMock()
     self.process = mock.MagicMock()

     monitor_settings = {
         "gunicorn_master_pid": 1,
         "num_workers_expected": 4,
         "master_timeout": 60,
         "worker_refresh_interval": 60,
         "worker_refresh_batch_size": 2,
         "reload_on_plugin_change": True,
     }
     self.monitor = cli.GunicornMonitor(**monitor_settings)
    def setUp(self):
        """Expose a ``PigCliHook`` subclass that returns a mocked connection.

        ``self.pig_hook`` is the subclass itself (not an instance); its
        ``get_connection`` always returns ``self.conn``, whose
        ``extra_dejson.get`` yields ``None``.
        """
        super().setUp()

        extras_mock = mock.MagicMock()
        extras_mock.get.return_value = None
        connection_mock = mock.MagicMock()
        connection_mock.extra_dejson = extras_mock

        self.extra_dejson = extras_mock
        self.conn = connection_mock

        class SubPigCliHook(PigCliHook):
            # Bypass the real connection lookup and hand back the mock.
            def get_connection(self, id):
                return connection_mock

        self.pig_hook = SubPigCliHook
Beispiel #5
0
    def setUp(self):
        """Build an ``OracleHook`` subclass instance backed by mocks.

        ``self.conn`` is the mocked connection returned by the hook's
        ``get_conn`` and ``self.cur`` is the cursor that connection yields.
        """
        super(TestOracleHook, self).setUp()

        cursor_mock = mock.MagicMock()
        connection_mock = mock.MagicMock()
        connection_mock.cursor.return_value = cursor_mock

        self.cur = cursor_mock
        self.conn = connection_mock

        class UnitTestOracleHook(OracleHook):
            conn_name_attr = 'test_conn_id'

            # Return the mocked connection instead of opening a real one.
            def get_conn(self):
                return connection_mock

        self.db_hook = UnitTestOracleHook()
Beispiel #6
0
    def test_flush(self):
        """Written text is buffered until ``flush`` logs it and empties the buffer."""
        level = 1
        message = "test_message"
        logger_mock = mock.MagicMock(log=mock.MagicMock())

        writer = StreamLogWriter(logger_mock, level)
        writer.write(message)

        # Nothing is logged yet; the text sits in the internal buffer.
        self.assertEqual(writer._buffer, message)

        writer.flush()

        logger_mock.log.assert_called_once_with(level, message)
        self.assertEqual(writer._buffer, "")
Beispiel #7
0
    def test_get_conn_uri_engine_version_1(self, mock_hvac):
        """With ``kv_engine_version=1`` the URI is read through the KV v1 API."""
        expected_uri = 'postgresql://*****:*****@host:5432/airflow'

        vault_client = mock.MagicMock()
        mock_hvac.Client.return_value = vault_client
        read_secret = vault_client.secrets.kv.v1.read_secret
        read_secret.return_value = {
            'request_id': '182d0673-618c-9889-4cba-4e1f4cfe4b4b',
            'lease_id': '',
            'renewable': False,
            'lease_duration': 2764800,
            'data': {
                'conn_uri': expected_uri
            },
            'wrap_info': None,
            'warnings': None,
            'auth': None
        }

        backend = VaultBackend(
            connections_path="connections",
            mount_point="airflow",
            auth_type="token",
            url="http://127.0.0.1:8200",
            token="s.7AU0I51yv1Q1lxOIg1F3ZRAS",
            kv_engine_version=1,
        )
        actual_uri = backend.get_conn_uri(conn_id="test_postgres")

        read_secret.assert_called_once_with(
            mount_point='airflow', path='connections/test_postgres')
        self.assertEqual(expected_uri, actual_uri)
Beispiel #8
0
    def test_get_variable_value(self, mock_hvac):
        """``get_variable`` returns the ``value`` field of the mocked KV v2 secret."""
        vault_client = mock.MagicMock()
        mock_hvac.Client.return_value = vault_client

        secret_payload = {
            'data': {
                'value': 'world'
            },
            'metadata': {
                'created_time': '2020-03-28T02:10:54.301784Z',
                'deletion_time': '',
                'destroyed': False,
                'version': 1
            }
        }
        vault_client.secrets.kv.v2.read_secret_version.return_value = {
            'request_id': '2d48a2ad-6bcb-e5b6-429d-da35fdf31f56',
            'lease_id': '',
            'renewable': False,
            'lease_duration': 0,
            'data': secret_payload,
            'wrap_info': None,
            'warnings': None,
            'auth': None
        }

        backend = VaultBackend(
            variables_path="variables",
            mount_point="airflow",
            auth_type="token",
            url="http://127.0.0.1:8200",
            token="s.7AU0I51yv1Q1lxOIg1F3ZRAS",
        )
        variable_value = backend.get_variable("hello")

        self.assertEqual('world', variable_value)
Beispiel #9
0
    def test_should_copy_single_file(self, mock_named_temporary_file,
                                     mock_gdrive, mock_gcs_hook):
        """A single object is downloaded to a temp file and uploaded from it."""
        # The temporary file reports "TMP1" as its name on first access.
        temp_file = mock_named_temporary_file.return_value.__enter__.return_value
        type(temp_file).name = mock.PropertyMock(side_effect=["TMP1"])

        operator = GcsToGDriveOperator(
            task_id="copy_single_file",
            source_bucket="data",
            source_object="sales/sales-2017/january.avro",
            destination_object="copied_sales/2017/january-backup.avro",
        )
        operator.execute(mock.MagicMock())

        expected_gcs_calls = [
            mock.call(delegate_to=None,
                      google_cloud_storage_conn_id="google_cloud_default"),
            mock.call().download(bucket="data",
                                 filename="TMP1",
                                 object="sales/sales-2017/january.avro"),
        ]
        mock_gcs_hook.assert_has_calls(expected_gcs_calls)

        expected_gdrive_calls = [
            mock.call(delegate_to=None, gcp_conn_id="google_cloud_default"),
            mock.call().upload_file(
                local_location="TMP1",
                remote_location="copied_sales/2017/january-backup.avro"),
        ]
        mock_gdrive.assert_has_calls(expected_gdrive_calls)
Beispiel #10
0
 def test_launch_wait(self):
     """Launching with ``wait=True`` must invoke the resource's ``wait``."""
     wait_mock = mock.MagicMock()
     with client.test_mode as session:
         session.register_json('/workflow_job_templates/1/launch/',
                               {'id': 1},
                               method='POST')
         with mock.patch.object(self.res, 'wait', wait_mock):
             self.res.launch(workflow_job_template=1, wait=True)
             assert wait_mock.called
Beispiel #11
0
    def test_run_next_exception(self, mock_get_kube_client,
                                mock_kubernetes_job_watcher):
        """Check that a task stays queued while pod creation raises ApiException.

        The client returned by ``mock_get_kube_client`` first raises a 403
        "Forbidden" ``ApiException`` from ``create_namespaced_pod``; after two
        ``sync()`` calls the executor's task queue must still be non-empty.
        Once the side effect is cleared, another ``sync()`` must leave the
        queue empty.
        """

        # When a quota is exceeded this is the ApiException we get
        response = HTTPResponse(
            body=
            '{"kind": "Status", "apiVersion": "v1", "metadata": {}, "status": "Failure", '
            '"message": "pods \\"podname\\" is forbidden: exceeded quota: compute-resources, '
            'requested: limits.memory=4Gi, used: limits.memory=6508Mi, limited: limits.memory=10Gi", '
            '"reason": "Forbidden", "details": {"name": "podname", "kind": "pods"}, "code": 403}'
        )
        response.status = 403
        response.reason = "Forbidden"

        # A mock kube_client that throws errors when making a pod
        # NOTE(review): mock.patch(...) is neither started nor used as a
        # context manager here, so the attributes below are set on the patcher
        # object itself rather than on an autospecced CoreV1Api mock - confirm
        # whether this is intended.
        mock_kube_client = mock.patch('kubernetes.client.CoreV1Api',
                                      autospec=True)
        mock_kube_client.create_namespaced_pod = mock.MagicMock(
            side_effect=ApiException(http_resp=response))
        mock_get_kube_client.return_value = mock_kube_client
        mock_api_client = mock.MagicMock()
        mock_api_client.sanitize_for_serialization.return_value = {}
        mock_kube_client.api_client = mock_api_client

        kubernetes_executor = KubernetesExecutor()
        kubernetes_executor.start()

        # Execute a task while the Api Throws errors
        try_number = 1
        kubernetes_executor.execute_async(key=('dag', 'task',
                                               datetime.utcnow(), try_number),
                                          queue=None,
                                          command='command',
                                          executor_config={})
        kubernetes_executor.sync()
        kubernetes_executor.sync()

        assert mock_kube_client.create_namespaced_pod.called
        self.assertFalse(kubernetes_executor.task_queue.empty())

        # Disable the ApiException
        mock_kube_client.create_namespaced_pod.side_effect = None

        # Execute the task without errors should empty the queue
        kubernetes_executor.sync()
        # NOTE(review): ``called`` was already True after the failing syncs
        # above, so this assert does not prove a new call happened here.
        assert mock_kube_client.create_namespaced_pod.called
        self.assertTrue(kubernetes_executor.task_queue.empty())
Beispiel #12
0
 def setUp(self) -> None:
     """Create the hook with a patched base ``__init__`` and stubbed credentials."""
     base_init_patch = mock.patch(
         "airflow.gcp.hooks.bigquery_dts.GoogleCloudBaseHook.__init__",
         new=mock_base_gcp_hook_no_default_project_id,
     )
     with base_init_patch:
         hook = BiqQueryDataTransferServiceHook()
         credentials_stub = mock.MagicMock(return_value=CREDENTIALS)
         hook._get_credentials = credentials_stub  # type: ignore
         self.hook = hook
Beispiel #13
0
 def test_wait_for_job(self, mock_get_job):
     """Check that ``wait_for_job`` polls the job and raises when it errors.

     ``mock_get_job`` first reports a RUNNING job and then an ERROR job, so
     ``wait_for_job`` is expected to raise ``AirflowException`` after
     calling ``get_job`` with the same keyword arguments both times.
     """
     mock_get_job.side_effect = [
         mock.MagicMock(status=mock.MagicMock(state=JobStatus.RUNNING)),
         mock.MagicMock(status=mock.MagicMock(state=JobStatus.ERROR)),
     ]
     with self.assertRaises(AirflowException):
         self.hook.wait_for_job(
             job_id=JOB_ID,
             location=GCP_LOCATION,
             project_id=GCP_PROJECT,
             wait_time=0,
         )
     calls = [
         mock.call(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT),
         mock.call(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT),
     ]
     # ``has_calls`` is not a Mock assertion: on a MagicMock it is just an
     # auto-created attribute, so calling it verifies nothing. Use
     # ``assert_has_calls`` so the expected calls are actually checked.
     mock_get_job.assert_has_calls(calls)
Beispiel #14
0
def create_post_side_effect(exception, status_code=500):
    """Build a side-effect value for a mocked POST request.

    For any exception class other than ``requests_exceptions.HTTPError`` an
    instance of that class is returned. For ``HTTPError`` a mocked response
    is returned instead: it carries ``status_code`` and its
    ``raise_for_status`` raises the error built with that response.
    """
    if exception != requests_exceptions.HTTPError:
        return exception()

    http_response = mock.MagicMock()
    http_response.status_code = status_code
    http_error = exception(response=http_response)
    http_response.raise_for_status.side_effect = http_error
    return http_response
Beispiel #15
0
 def setUp(self) -> None:
     """Create the AutoML hook with a patched base ``__init__`` and stubbed credentials."""
     base_init_patch = mock.patch(
         "airflow.gcp.hooks.automl.GoogleCloudBaseHook.__init__",
         new=mock_base_gcp_hook_no_default_project_id,
     )
     with base_init_patch:
         hook = CloudAutoMLHook()
         credentials_stub = mock.MagicMock(return_value=CREDENTIALS)
         hook._get_credentials = credentials_stub  # type: ignore
         self.hook = hook
Beispiel #16
0
 def test_should_raise_exception_on_multiple_wildcard(
         self, mock_named_temporary_file, mock_gdrive, mock_gcs_hook):
     """A source object containing two wildcards must be rejected on execute."""
     operator_kwargs = {
         "task_id": "move_files",
         "source_bucket": "data",
         "source_object": "sales/*/*.avro",
         "move_object": True,
     }
     operator = GcsToGDriveOperator(**operator_kwargs)

     expected_message = "Only one wildcard"
     with six.assertRaisesRegex(self, AirflowException, expected_message):
         operator.execute(mock.MagicMock())
Beispiel #17
0
    def test_get_variable_non_existent_key(self, mock_client_callable, mock_get_creds):
        """``get_variable`` returns ``None`` when the secret lookup raises ``NotFound``."""
        mock_get_creds.return_value = CREDENTIALS, PROJECT_ID

        secrets_client = mock.MagicMock()
        # The requested secret id or secret version does not exist
        secrets_client.access_secret_version.side_effect = NotFound('test-msg')
        mock_client_callable.return_value = secrets_client

        backend = CloudSecretsManagerBackend(variables_prefix=VARIABLES_PREFIX)
        looked_up = backend.get_variable(VAR_KEY)

        self.assertIsNone(looked_up)
Beispiel #18
0
 def test_get_credentials_and_project_id_with_default_auth_and_delegate(self, mock_auth_default):
     """Default credentials are delegated to the configured user via ``with_subject``."""
     default_credentials = mock.MagicMock()
     mock_auth_default.return_value = (default_credentials, "PROJECT_ID")
     self.instance.extras = {}
     self.instance.delegate_to = "USER"

     actual = self.instance._get_credentials_and_project_id()

     mock_auth_default.assert_called_once_with(scopes=self.instance.scopes)
     default_credentials.with_subject.assert_called_once_with("USER")
     delegated_credentials = default_credentials.with_subject.return_value
     self.assertEqual((delegated_credentials, "PROJECT_ID"), actual)
    def test_run_cli_success(self, popen_mock):
        """A zero return code with no output yields an empty stdout string."""
        process = mock.MagicMock(returncode=0)
        process.stdout.readline.return_value = b''
        popen_mock.return_value = process

        output = self.pig_hook().run_cli("")

        self.assertEqual(output, "")
 def test_correct_low_level_api_calls(self):
     """Uploading must initiate and complete a multipart upload on the API."""
     four_mib = 4 * 1024 * 1024
     eight_mib = 8 * 1024 * 1024
     api = mock.MagicMock()

     FakeThreadedConcurrentUploader(api, 'vault_name').upload('foofile')

     # The threads call the upload_part, so we're just verifying the
     # initiate/complete multipart API calls.
     api.initiate_multipart_upload.assert_called_with(
         'vault_name', four_mib, None)
     api.complete_multipart_upload.assert_called_with(
         'vault_name', mock.ANY, mock.ANY, eight_mib)
Beispiel #21
0
 def test_work_queue_is_correctly_populated(self):
     """After an upload the worker queue holds two part entries plus sentinels."""
     four_mib = 4 * 1024 * 1024
     uploader = FakeThreadedConcurrentUploader(mock.MagicMock(),
                                               'vault_name')
     uploader.upload('foofile')

     queue = uploader.worker_queue
     queued = [queue.get() for _ in range(queue.qsize())]

     for index in (0, 1):
         self.assertEqual(queued[index], (index, four_mib))
     # 2 for the parts, 10 for the end sentinels (10 threads).
     self.assertEqual(len(queued), 12)
    def test_run_cli_fail(self, popen_mock):
        """A non-zero return code makes ``run_cli`` raise ``AirflowException``."""
        from airflow.exceptions import AirflowException

        process = mock.MagicMock(returncode=1)
        process.stdout.readline.return_value = b''
        popen_mock.return_value = process

        hook = self.pig_hook()

        with self.assertRaises(AirflowException):
            hook.run_cli("")
Beispiel #23
0
 def test_downloader_work_queue_is_correctly_populated(self):
     """After a download the worker queue holds two part entries plus sentinels."""
     four_mib = 4 * 1024 * 1024
     job = mock.MagicMock()
     job.archive_size = 8 * 1024 * 1024

     downloader = FakeThreadedConcurrentDownloader(job)
     downloader.download('foofile')

     queue = downloader.worker_queue
     queued = [queue.get() for _ in range(queue.qsize())]

     for index in (0, 1):
         self.assertEqual(queued[index], (index, four_mib))
     # 2 for the parts, 10 for the end sentinels (10 threads).
     self.assertEqual(len(queued), 12)
Beispiel #24
0
    def test_support_project_id_parameter(self):
        """An explicit ``project_id`` keyword is forwarded to the wrapped method."""
        call_recorder = mock.MagicMock()

        class FixtureFallback:
            @_fallback_to_project_id_from_variables
            def test_fn(self, *args, **kwargs):
                # Record exactly what the decorator passed through.
                call_recorder(*args, **kwargs)

        fixture = FixtureFallback()
        fixture.test_fn(project_id="TEST")

        call_recorder.assert_called_once_with(project_id="TEST")
Beispiel #25
0
    def test_set_context(self):
        """``set_context`` reaches the handlers of a logger and of its parent."""
        context_value = "test"
        parent_handler = mock.MagicMock()
        child_handler = mock.MagicMock()

        parent_logger = mock.MagicMock(propagate=False,
                                       handlers=[parent_handler])
        child_logger = mock.MagicMock(propagate=True,
                                      handlers=[child_handler])
        # Assigned as an attribute: ``parent`` is also a Mock constructor
        # argument with a different meaning.
        child_logger.parent = parent_logger

        set_context(child_logger, context_value)

        parent_handler.set_context.assert_called_with(context_value)
        child_handler.set_context.assert_called_with(context_value)
Beispiel #26
0
    def test_get_conn_uri_non_existent_key(self, mock_client_callable, mock_get_creds):
        """A ``NotFound`` lookup gives a ``None`` URI and an empty connection list."""
        mock_get_creds.return_value = CREDENTIALS, PROJECT_ID

        secrets_client = mock.MagicMock()
        # The requested secret id or secret version does not exist
        secrets_client.access_secret_version.side_effect = NotFound('test-msg')
        mock_client_callable.return_value = secrets_client

        backend = CloudSecretsManagerBackend(connections_prefix=CONNECTIONS_PREFIX)

        conn_uri = backend.get_conn_uri(conn_id=CONN_ID)
        self.assertIsNone(conn_uri)

        connections = backend.get_connections(conn_id=CONN_ID)
        self.assertEqual([], connections)
Beispiel #27
0
 def test_get_queryset_with_kwargs(self):
     """Check that ``get_queryset`` filters by the viewset's keyword arguments.

     Unknown keyword arguments (which typically come from parent viewsets)
     must be applied as a single ``filter`` call on the queryset returned
     by the next class in the MRO.
     """
     viewset = ModelViewSet(kwargs={'foo__pk': 42})
     parent_class = ModelViewSet.mro()[1]

     with mock.patch.object(parent_class, 'get_queryset') as parent_get_queryset:
         base_queryset = mock.MagicMock()
         parent_get_queryset.return_value = base_queryset

         viewset.get_queryset()

         expected_calls = [mock.call.filter(foo__pk=42)]
         self.assertEqual(parent_get_queryset.return_value.mock_calls,
                          expected_calls)
Beispiel #28
0
    def test_raise_exception_on_positional_argument(self):
        """Calling the decorated method with positional arguments must fail."""
        call_recorder = mock.MagicMock()

        class FixtureFallback:
            @_fallback_to_project_id_from_variables
            def test_fn(self, *args, **kwargs):
                call_recorder(*args, **kwargs)

        fixture = FixtureFallback()
        expected_message = (
            "You must use keyword arguments in this methods rather than positional"
        )

        with self.assertRaisesRegex(AirflowException, expected_message):
            fixture.test_fn({'project': "TEST"}, "TEST2")
    def test_run_cli_verbose(self, popen_mock):
        """In verbose mode ``run_cli`` returns the concatenated stdout lines."""
        raw_lines = [b"one", b"two", b""]
        expected_output = "".join(chunk.decode('utf-8') for chunk in raw_lines)

        process = mock.MagicMock(returncode=0)
        # Each readline() call hands back the next raw line in order.
        process.stdout.readline = mock.Mock(side_effect=raw_lines)
        popen_mock.return_value = process

        output = self.pig_hook().run_cli("", verbose=True)

        self.assertEqual(output, expected_output)
Beispiel #30
0
    def test_raise_exception_on_conflict(self):
        """Passing both ``project_id`` and ``variables['project']`` must fail."""
        call_recorder = mock.MagicMock()

        class FixtureFallback:
            @_fallback_to_project_id_from_variables
            def test_fn(self, *args, **kwargs):
                call_recorder(*args, **kwargs)

        fixture = FixtureFallback()
        expected_message = (
            "The mutually exclusive parameter `project_id` and `project` key in `variables` parameters are "
            "both present\\. Please remove one\\."
        )

        with self.assertRaisesRegex(AirflowException, expected_message):
            fixture.test_fn(variables={'project': "TEST"}, project_id="TEST2")