def test_lineage_send(self, atlas_mock):
    """Sending lineage for one operator creates typedefs once and posts entities.

    The Atlas client is replaced by ``atlas_mock``; its ``typedefs`` and
    ``entity_post`` members are inspected after ``send_lineage`` runs.
    """
    typedefs_mock = mock.MagicMock()
    entity_post_mock = mock.MagicMock()
    atlas_mock.return_value = mock.Mock(typedefs=typedefs_mock, entity_post=entity_post_mock)

    dag = DAG(dag_id='test_prepare_lineage', start_date=DEFAULT_DATE)

    inlets_d = [File("/tmp/does_not_exist_1")]
    outlets_d = [File("/tmp/does_not_exist_2")]

    with dag:
        op1 = DummyOperator(
            task_id='leave1',
            inlets={"datasets": inlets_d},
            outlets={"datasets": outlets_d},
        )

    ctx = {"ti": TI(task=op1, execution_date=DEFAULT_DATE)}

    self.atlas.send_lineage(operator=op1, inlets=inlets_d, outlets=outlets_d, context=ctx)

    self.assertEqual(typedefs_mock.create.call_count, 1)
    self.assertTrue(entity_post_mock.create.called)
    self.assertEqual(len(entity_post_mock.mock_calls), 3)
def _setup_connections(get_connections, uri):
    """Queue up the connections returned by successive ``get_connections`` calls.

    The first call yields a mocked GCP connection whose ``extra_dejson.get``
    always returns ``'empty_project'``. The second and third calls each yield a
    separate ``Connection`` parsed from ``uri``. Every call result is a
    single-element list.
    """
    gcp_connection = mock.MagicMock()
    gcp_connection.extra_dejson = mock.MagicMock()
    gcp_connection.extra_dejson.get.return_value = 'empty_project'

    def _cloudsql_from_uri():
        # Build a fresh Connection object each time so the two entries are independent.
        parsed = Connection()
        parsed.parse_from_uri(uri)
        return parsed

    cloudsql_connection = _cloudsql_from_uri()
    cloudsql_connection2 = _cloudsql_from_uri()

    get_connections.side_effect = [[gcp_connection], [cloudsql_connection], [cloudsql_connection2]]
def setUp(self):
    """Create process mocks and the GunicornMonitor instance exercised by the tests."""
    self.gunicorn_master_proc = mock.Mock(pid=2137)
    self.children = mock.MagicMock()
    self.child = mock.MagicMock()
    self.process = mock.MagicMock()

    monitor_settings = {
        "gunicorn_master_pid": 1,
        "num_workers_expected": 4,
        "master_timeout": 60,
        "worker_refresh_interval": 60,
        "worker_refresh_batch_size": 2,
        "reload_on_plugin_change": True,
    }
    self.monitor = cli.GunicornMonitor(**monitor_settings)
def setUp(self):
    """Prepare a PigCliHook subclass whose get_connection returns a mocked connection.

    The mocked connection's ``extra_dejson.get`` returns ``None``.
    """
    super().setUp()
    self.extra_dejson = mock.MagicMock()
    self.extra_dejson.get.return_value = None
    self.conn = mock.MagicMock()
    self.conn.extra_dejson = self.extra_dejson
    fake_conn = self.conn

    class SubPigCliHook(PigCliHook):
        def get_connection(self, id):
            return fake_conn

    # The class itself (not an instance) is stored; tests instantiate it via self.pig_hook().
    self.pig_hook = SubPigCliHook
def setUp(self):
    """Install an OracleHook subclass whose get_conn returns a mocked connection.

    ``self.conn.cursor()`` returns ``self.cur`` so tests can inspect executed SQL.
    """
    super(TestOracleHook, self).setUp()
    self.cur = mock.MagicMock()
    self.conn = mock.MagicMock()
    self.conn.cursor.return_value = self.cur
    shared_conn = self.conn

    class UnitTestOracleHook(OracleHook):
        conn_name_attr = 'test_conn_id'

        def get_conn(self):
            return shared_conn

    self.db_hook = UnitTestOracleHook()
def test_flush(self):
    """flush() emits the buffered text through logger.log and empties the buffer."""
    logger = mock.MagicMock()
    logger.log = mock.MagicMock()
    level = 1
    message = "test_message"
    writer = StreamLogWriter(logger, level)

    writer.write(message)
    self.assertEqual(writer._buffer, message)

    writer.flush()
    logger.log.assert_called_once_with(level, message)
    self.assertEqual(writer._buffer, "")
def test_get_conn_uri_engine_version_1(self, mock_hvac):
    """With kv_engine_version=1 the URI is read through the KV v1 ``read_secret`` call."""
    expected_uri = 'postgresql://*****:*****@host:5432/airflow'

    mock_client = mock.MagicMock()
    mock_hvac.Client.return_value = mock_client
    read_secret = mock_client.secrets.kv.v1.read_secret
    read_secret.return_value = {
        'request_id': '182d0673-618c-9889-4cba-4e1f4cfe4b4b',
        'lease_id': '',
        'renewable': False,
        'lease_duration': 2764800,
        'data': {'conn_uri': expected_uri},
        'wrap_info': None,
        'warnings': None,
        'auth': None,
    }

    test_client = VaultBackend(
        connections_path="connections",
        mount_point="airflow",
        auth_type="token",
        url="http://127.0.0.1:8200",
        token="s.7AU0I51yv1Q1lxOIg1F3ZRAS",
        kv_engine_version=1,
    )
    returned_uri = test_client.get_conn_uri(conn_id="test_postgres")

    read_secret.assert_called_once_with(mount_point='airflow', path='connections/test_postgres')
    self.assertEqual(expected_uri, returned_uri)
def test_get_variable_value(self, mock_hvac):
    """get_variable returns the ``value`` field of the KV v2 secret ('world')."""
    mock_client = mock.MagicMock()
    mock_hvac.Client.return_value = mock_client

    secret_metadata = {
        'created_time': '2020-03-28T02:10:54.301784Z',
        'deletion_time': '',
        'destroyed': False,
        'version': 1,
    }
    mock_client.secrets.kv.v2.read_secret_version.return_value = {
        'request_id': '2d48a2ad-6bcb-e5b6-429d-da35fdf31f56',
        'lease_id': '',
        'renewable': False,
        'lease_duration': 0,
        'data': {'data': {'value': 'world'}, 'metadata': secret_metadata},
        'wrap_info': None,
        'warnings': None,
        'auth': None,
    }

    test_client = VaultBackend(
        variables_path="variables",
        mount_point="airflow",
        auth_type="token",
        url="http://127.0.0.1:8200",
        token="s.7AU0I51yv1Q1lxOIg1F3ZRAS",
    )
    returned_value = test_client.get_variable("hello")
    self.assertEqual('world', returned_value)
def test_should_copy_single_file(self, mock_named_temporary_file, mock_gdrive, mock_gcs_hook):
    """A single object is downloaded to a temp file ("TMP1") and then uploaded to Drive."""
    temp_file = mock_named_temporary_file.return_value.__enter__.return_value
    type(temp_file).name = mock.PropertyMock(side_effect=["TMP1"])

    task = GcsToGDriveOperator(
        task_id="copy_single_file",
        source_bucket="data",
        source_object="sales/sales-2017/january.avro",
        destination_object="copied_sales/2017/january-backup.avro",
    )
    task.execute(mock.MagicMock())

    expected_gcs_calls = [
        mock.call(delegate_to=None, google_cloud_storage_conn_id="google_cloud_default"),
        mock.call().download(
            bucket="data", filename="TMP1", object="sales/sales-2017/january.avro"
        ),
    ]
    mock_gcs_hook.assert_has_calls(expected_gcs_calls)

    expected_gdrive_calls = [
        mock.call(delegate_to=None, gcp_conn_id="google_cloud_default"),
        mock.call().upload_file(
            local_location="TMP1", remote_location="copied_sales/2017/january-backup.avro"
        ),
    ]
    mock_gdrive.assert_has_calls(expected_gdrive_calls)
def test_launch_wait(self):
    """Launching with wait=True invokes the resource's wait() method."""
    launch_url = '/workflow_job_templates/1/launch/'
    with client.test_mode as t:
        t.register_json(launch_url, {'id': 1}, method='POST')
        with mock.patch.object(self.res, 'wait', mock.MagicMock()) as wait_mock:
            self.res.launch(workflow_job_template=1, wait=True)
            assert wait_mock.called
def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watcher):
    """A pod-creation ApiException (e.g. exceeded quota) keeps the task queued.

    While ``create_namespaced_pod`` raises, two ``sync()`` calls must leave the
    task in ``task_queue``. Once the error is removed, the next ``sync()`` must
    call ``create_namespaced_pod`` again and drain the queue.
    """
    # When a quota is exceeded this is the ApiException we get
    response = HTTPResponse(
        body='{"kind": "Status", "apiVersion": "v1", "metadata": {}, "status": "Failure", '
             '"message": "pods \\"podname\\" is forbidden: exceeded quota: compute-resources, '
             'requested: limits.memory=4Gi, used: limits.memory=6508Mi, limited: limits.memory=10Gi", '
             '"reason": "Forbidden", "details": {"name": "podname", "kind": "pods"}, "code": 403}'
    )
    response.status = 403
    response.reason = "Forbidden"

    # A mock kube_client that throws errors when making a pod.
    # This must be a real mock object: ``mock.patch(...)`` only returns an
    # (unstarted) patcher, which is not a stand-in for a CoreV1Api client.
    mock_kube_client = mock.MagicMock()
    mock_kube_client.create_namespaced_pod = mock.MagicMock(
        side_effect=ApiException(http_resp=response))
    mock_get_kube_client.return_value = mock_kube_client
    mock_api_client = mock.MagicMock()
    mock_api_client.sanitize_for_serialization.return_value = {}
    mock_kube_client.api_client = mock_api_client

    kubernetes_executor = KubernetesExecutor()
    kubernetes_executor.start()

    # Execute a task while the Api Throws errors
    try_number = 1
    kubernetes_executor.execute_async(key=('dag', 'task', datetime.utcnow(), try_number),
                                      queue=None,
                                      command='command',
                                      executor_config={})
    kubernetes_executor.sync()
    kubernetes_executor.sync()

    assert mock_kube_client.create_namespaced_pod.called
    self.assertFalse(kubernetes_executor.task_queue.empty())

    # Disable the ApiException and forget the failed attempts, so the assertion
    # below proves that the final sync() really launched the pod.
    mock_kube_client.create_namespaced_pod.side_effect = None
    mock_kube_client.create_namespaced_pod.reset_mock()

    # Execute the task without errors should empty the queue
    kubernetes_executor.sync()
    assert mock_kube_client.create_namespaced_pod.called
    self.assertTrue(kubernetes_executor.task_queue.empty())
def setUp(self) -> None:
    """Build the hook while its base __init__ is replaced (no default project id)."""
    base_init_target = "airflow.gcp.hooks.bigquery_dts.GoogleCloudBaseHook.__init__"
    with mock.patch(base_init_target, new=mock_base_gcp_hook_no_default_project_id):
        self.hook = BiqQueryDataTransferServiceHook()
        # Skip real credential resolution by returning the CREDENTIALS constant.
        self.hook._get_credentials = mock.MagicMock(return_value=CREDENTIALS)  # type: ignore
def test_wait_for_job(self, mock_get_job):
    """wait_for_job queries the job (RUNNING, then ERROR) and raises AirflowException.

    ``get_job`` must have been called twice with the same location/job/project.
    """
    mock_get_job.side_effect = [
        mock.MagicMock(status=mock.MagicMock(state=JobStatus.RUNNING)),
        mock.MagicMock(status=mock.MagicMock(state=JobStatus.ERROR)),
    ]
    with self.assertRaises(AirflowException):
        self.hook.wait_for_job(
            job_id=JOB_ID,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT,
            wait_time=0,
        )
    calls = [
        mock.call(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT),
        mock.call(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT),
    ]
    # ``has_calls`` is not an assertion: on a MagicMock it just creates and calls a
    # child mock, so it always "passes". ``assert_has_calls`` verifies the calls.
    mock_get_job.assert_has_calls(calls)
def create_post_side_effect(exception, status_code=500):
    """Return a side-effect value for a mocked HTTP post call (presumably ``requests.post``).

    :param exception: exception class to simulate.
    :param status_code: status code given to the fake response (HTTPError only).
    :return: for ``requests_exceptions.HTTPError``, a mock response whose
        ``raise_for_status()`` raises that error bound to the response; for any
        other exception class, a fresh instance of that class.
    """
    if exception == requests_exceptions.HTTPError:
        response = mock.MagicMock()
        response.status_code = status_code
        response.raise_for_status.side_effect = exception(response=response)
        return response
    return exception()
def setUp(self) -> None:
    """Build the AutoML hook while its base __init__ is replaced (no default project id)."""
    base_init_target = "airflow.gcp.hooks.automl.GoogleCloudBaseHook.__init__"
    with mock.patch(base_init_target, new=mock_base_gcp_hook_no_default_project_id):
        self.hook = CloudAutoMLHook()
        # Skip real credential resolution by returning the CREDENTIALS constant.
        self.hook._get_credentials = mock.MagicMock(return_value=CREDENTIALS)  # type: ignore
def test_should_raise_exception_on_multiple_wildcard(
    self, mock_named_temporary_file, mock_gdrive, mock_gcs_hook
):
    """A ``source_object`` containing two ``*`` wildcards is rejected with AirflowException."""
    task = GcsToGDriveOperator(
        task_id="move_files",
        source_bucket="data",
        source_object="sales/*/*.avro",
        move_object=True,
    )
    # Use TestCase.assertRaisesRegex directly, like the other tests, instead of the six shim.
    with self.assertRaisesRegex(AirflowException, "Only one wildcard"):
        task.execute(mock.MagicMock())
def test_get_variable_non_existent_key(self, mock_client_callable, mock_get_creds):
    """get_variable returns None when accessing the secret raises NotFound."""
    mock_get_creds.return_value = CREDENTIALS, PROJECT_ID
    secret_client = mock.MagicMock()
    mock_client_callable.return_value = secret_client
    # The requested secret id or secret version does not exist
    secret_client.access_secret_version.side_effect = NotFound('test-msg')

    backend = CloudSecretsManagerBackend(variables_prefix=VARIABLES_PREFIX)
    self.assertIsNone(backend.get_variable(VAR_KEY))
def test_get_credentials_and_project_id_with_default_auth_and_delegate(self, mock_auth_default):
    """With no extras and a delegate, default credentials get the delegate as subject.

    Default auth must be requested with the instance scopes, and the returned
    credentials must be the ``with_subject("USER")`` result.
    """
    credentials = mock.MagicMock()
    mock_auth_default.return_value = (credentials, "PROJECT_ID")
    self.instance.extras = {}
    self.instance.delegate_to = "USER"

    result = self.instance._get_credentials_and_project_id()

    mock_auth_default.assert_called_once_with(scopes=self.instance.scopes)
    credentials.with_subject.assert_called_once_with("USER")
    expected = (credentials.with_subject.return_value, "PROJECT_ID")
    self.assertEqual(expected, result)
def test_run_cli_success(self, popen_mock):
    """Exit code 0 with no output lines makes run_cli return an empty string."""
    fake_proc = mock.MagicMock()
    fake_proc.returncode = 0
    fake_proc.stdout.readline.return_value = b''
    popen_mock.return_value = fake_proc

    output = self.pig_hook().run_cli("")

    self.assertEqual(output, "")
def test_correct_low_level_api_calls(self):
    """The uploader starts and completes the multipart upload with the expected sizes."""
    four_mib = 4 * 1024 * 1024
    eight_mib = 8 * 1024 * 1024

    api = mock.MagicMock()
    uploader = FakeThreadedConcurrentUploader(api, 'vault_name')
    uploader.upload('foofile')

    # The threads call the upload_part, so we're just verifying the
    # initiate/complete multipart API calls.
    api.initiate_multipart_upload.assert_called_with('vault_name', four_mib, None)
    api.complete_multipart_upload.assert_called_with('vault_name', mock.ANY, mock.ANY, eight_mib)
def test_work_queue_is_correctly_populated(self):
    """upload() enqueues one item per 4 MiB part plus one end sentinel per thread."""
    part_size = 4 * 1024 * 1024

    uploader = FakeThreadedConcurrentUploader(mock.MagicMock(), 'vault_name')
    uploader.upload('foofile')

    work_queue = uploader.worker_queue
    items = [work_queue.get() for _ in range(work_queue.qsize())]

    self.assertEqual(items[0], (0, part_size))
    self.assertEqual(items[1], (1, part_size))
    # 2 for the parts, 10 for the end sentinels (10 threads).
    self.assertEqual(len(items), 12)
def test_run_cli_fail(self, popen_mock):
    """A non-zero exit code makes run_cli raise AirflowException."""
    from airflow.exceptions import AirflowException

    fake_proc = mock.MagicMock()
    fake_proc.returncode = 1
    fake_proc.stdout.readline.return_value = b''
    popen_mock.return_value = fake_proc

    hook = self.pig_hook()
    with self.assertRaises(AirflowException):
        hook.run_cli("")
def test_downloader_work_queue_is_correctly_populated(self):
    """Downloading an 8 MiB archive enqueues two 4 MiB parts plus thread sentinels."""
    part_size = 4 * 1024 * 1024

    job = mock.MagicMock()
    job.archive_size = 2 * part_size
    downloader = FakeThreadedConcurrentDownloader(job)
    downloader.download('foofile')

    work_queue = downloader.worker_queue
    items = [work_queue.get() for _ in range(work_queue.qsize())]

    self.assertEqual(items[0], (0, part_size))
    self.assertEqual(items[1], (1, part_size))
    # 2 for the parts, 10 for the end sentinels (10 threads).
    self.assertEqual(len(items), 12)
def test_support_project_id_parameter(self):
    """A plain ``project_id`` keyword reaches the wrapped function as-is."""
    wrapped_spy = mock.MagicMock()

    class FixtureFallback:
        @_fallback_to_project_id_from_variables
        def test_fn(self, *args, **kwargs):
            wrapped_spy(*args, **kwargs)

    FixtureFallback().test_fn(project_id="TEST")

    wrapped_spy.assert_called_once_with(project_id="TEST")
def test_set_context(self):
    """set_context reaches the handlers of a propagating logger and of its parent.

    The parent has ``propagate = False``; both its handler and the child's
    handler must receive the context value.
    """
    handler1 = mock.MagicMock()
    handler2 = mock.MagicMock()

    parent = mock.MagicMock()
    parent.propagate = False
    parent.handlers = [handler1]

    log = mock.MagicMock()
    log.handlers = [handler2]
    log.parent = parent
    log.propagate = True

    value = "test"
    set_context(log, value)

    for handler in (handler1, handler2):
        handler.set_context.assert_called_with(value)
def test_get_conn_uri_non_existent_key(self, mock_client_callable, mock_get_creds):
    """A NotFound secret yields no URI and an empty connection list."""
    mock_get_creds.return_value = CREDENTIALS, PROJECT_ID
    secret_client = mock.MagicMock()
    mock_client_callable.return_value = secret_client
    # The requested secret id or secret version does not exist
    secret_client.access_secret_version.side_effect = NotFound('test-msg')

    backend = CloudSecretsManagerBackend(connections_prefix=CONNECTIONS_PREFIX)
    self.assertIsNone(backend.get_conn_uri(conn_id=CONN_ID))
    self.assertEqual([], backend.get_connections(conn_id=CONN_ID))
def test_get_queryset_with_kwargs(self):
    """Establish that our `get_queryset` method filters in the way we
    expect if we have unknown keyword arguments (which typically come
    from parent viewsets).

    The parent class's queryset must be filtered exactly once with those
    kwargs, and the filtered queryset must be what is returned.
    """
    mvs = ModelViewSet(kwargs={'foo__pk': 42})
    with mock.patch.object(ModelViewSet.mro()[1], 'get_queryset') as m:
        m.return_value = mock.MagicMock()
        qs = mvs.get_queryset()
        self.assertEqual(m.return_value.mock_calls, [mock.call.filter(foo__pk=42)])
        # Previously `qs` was unused; also check that the filtered queryset (not the
        # unfiltered parent one) is the result.
        self.assertIs(qs, m.return_value.filter.return_value)
def test_raise_exception_on_positional_argument(self):
    """Positional arguments are rejected and the wrapped function is never run."""
    mock_instance = mock.MagicMock()

    class FixtureFallback:
        @_fallback_to_project_id_from_variables
        def test_fn(self, *args, **kwargs):
            mock_instance(*args, **kwargs)

    with self.assertRaisesRegex(
        AirflowException,
        "You must use keyword arguments in this methods rather than positional"
    ):
        FixtureFallback().test_fn({'project': "TEST"}, "TEST2")
    # The decorator must fail before delegating to the wrapped function.
    mock_instance.assert_not_called()
def test_run_cli_verbose(self, popen_mock):
    """With verbose=True, run_cli returns the decoded stdout lines joined together."""
    test_stdout_lines = [b"one", b"two", b""]
    expected = "".join(line.decode('utf-8') for line in test_stdout_lines)

    fake_proc = mock.MagicMock()
    fake_proc.returncode = 0
    fake_proc.stdout.readline = mock.Mock(side_effect=test_stdout_lines)
    popen_mock.return_value = fake_proc

    stdout = self.pig_hook().run_cli("", verbose=True)

    self.assertEqual(stdout, expected)
def test_raise_exception_on_conflict(self):
    """Passing both ``project_id`` and ``variables['project']`` is rejected.

    The wrapped function must not be invoked when the conflict is detected.
    """
    mock_instance = mock.MagicMock()

    class FixtureFallback:
        @_fallback_to_project_id_from_variables
        def test_fn(self, *args, **kwargs):
            mock_instance(*args, **kwargs)

    with self.assertRaisesRegex(
        AirflowException,
        "The mutually exclusive parameter `project_id` and `project` key in `variables` parameters are "
        "both present\\. Please remove one\\."
    ):
        FixtureFallback().test_fn(variables={'project': "TEST"}, project_id="TEST2")
    # The decorator must fail before delegating to the wrapped function.
    mock_instance.assert_not_called()