# Example 1 (score: 0)
    def execute(self, context):
        """Create the configured Cloud Tasks queue, or fetch it if it already exists.

        ``CloudTasksHook.create_queue`` is tried first. If the service reports
        ``AlreadyExists``, the method falls back to ``CloudTasksHook.get_queue``
        for the same project, location and queue name. An existing queue is
        therefore returned instead of causing a failure.

        :param context: execution context; not used by this method.
        :return: the created (or already existing) queue, converted with
            ``MessageToDict``.
        :raises RuntimeError: if the queue already exists but ``self.queue_name``
            is not set. The fallback lookup addresses the queue by
            ``queue_name``, so without it the existing queue cannot be fetched.
        """
        hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id)
        try:
            queue = hook.create_queue(
                location=self.location,
                task_queue=self.task_queue,
                project_id=self.project_id,
                queue_name=self.queue_name,
                retry=self.retry,
                timeout=self.timeout,
                metadata=self.metadata,
            )
        except AlreadyExists as err:
            # get_queue identifies the queue only via queue_name. Without a name,
            # the lookup would fail with an unrelated-looking error, so fail
            # explicitly and keep the original AlreadyExists as the cause.
            if not self.queue_name:
                raise RuntimeError(
                    "The queue already exists but queue_name is not set, "
                    "so the existing queue cannot be retrieved."
                ) from err
            queue = hook.get_queue(
                location=self.location,
                project_id=self.project_id,
                queue_name=self.queue_name,
                retry=self.retry,
                timeout=self.timeout,
                metadata=self.metadata,
            )

        return MessageToDict(queue)
# Example 2 (score: 0)
 def setUp(self):
     """Build the ``CloudTasksHook`` under test with its base initialiser stubbed.

     ``GoogleBaseHook.__init__`` is swapped for
     ``mock_base_gcp_hook_no_default_project_id`` only while the hook is being
     constructed, so the real base-hook initialiser does not run.
     """
     base_init_target = (
         "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__"
     )
     base_init_patcher = mock.patch(
         base_init_target, new=mock_base_gcp_hook_no_default_project_id
     )
     with base_init_patcher:
         self.hook = CloudTasksHook(gcp_conn_id="test")
# Example 3 (score: 0)
class TestCloudTasksHook(unittest.TestCase):
    """Unit tests for ``CloudTasksHook``.

    Except for the client-creation test, each test patches
    ``CloudTasksHook.get_conn`` so that the hook talks to a mock client. The
    test then checks the hook's return value and the exact arguments forwarded
    to the client method. Those arguments are fully qualified resource paths,
    with ``retry``/``timeout``/``metadata`` left as ``None``.
    """

    def setUp(self):
        """Create the hook with ``GoogleBaseHook.__init__`` stubbed during construction."""
        with mock.patch(
            "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
            new=mock_base_gcp_hook_no_default_project_id,
        ):
            self.hook = CloudTasksHook(gcp_conn_id="test")

    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.client_info",
        new_callable=mock.PropertyMock,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.tasks.CloudTasksHook._get_credentials")
    @mock.patch("airflow.providers.google.cloud.hooks.tasks.CloudTasksClient")
    def test_cloud_tasks_client_creation(self, mock_client, mock_get_creds, mock_client_info):
        """get_conn builds CloudTasksClient from the hook's credentials and client info and stores it."""
        result = self.hook.get_conn()
        mock_client.assert_called_once_with(
            credentials=mock_get_creds.return_value, client_info=mock_client_info.return_value
        )
        self.assertEqual(mock_client.return_value, result)
        self.assertEqual(self.hook._client, result)

    # The dotted ``**{...}`` keyword arguments below configure the patched mock's
    # return values (mock.patch forwards extra kwargs to the created mock).
    # ``# type: ignore`` presumably silences the type checker on that expansion.
    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.create_queue.return_value": API_RESPONSE},  # type: ignore
    )
    def test_create_queue(self, get_conn):
        """create_queue stamps the full queue path onto the Queue and targets the location path."""
        result = self.hook.create_queue(
            location=LOCATION,
            task_queue=Queue(),
            queue_name=QUEUE_ID,
            project_id=PROJECT_ID,
        )

        self.assertIs(result, API_RESPONSE)

        get_conn.return_value.create_queue.assert_called_once_with(
            parent=FULL_LOCATION_PATH,
            queue=Queue(name=FULL_QUEUE_PATH),
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.update_queue.return_value": API_RESPONSE},  # type: ignore
    )
    def test_update_queue(self, get_conn):
        """update_queue keeps the caller's fields and adds the full queue path."""
        result = self.hook.update_queue(
            task_queue=Queue(state=3),
            location=LOCATION,
            queue_name=QUEUE_ID,
            project_id=PROJECT_ID,
        )

        self.assertIs(result, API_RESPONSE)

        get_conn.return_value.update_queue.assert_called_once_with(
            queue=Queue(name=FULL_QUEUE_PATH, state=3),
            update_mask=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.get_queue.return_value": API_RESPONSE},  # type: ignore
    )
    def test_get_queue(self, get_conn):
        """get_queue looks the queue up by its full path and returns the client result."""
        result = self.hook.get_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)

        self.assertIs(result, API_RESPONSE)

        get_conn.return_value.get_queue.assert_called_once_with(
            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None
        )

    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.list_queues.return_value": API_RESPONSE},  # type: ignore
    )
    def test_list_queues(self, get_conn):
        """list_queues returns the client's results as a list."""
        result = self.hook.list_queues(location=LOCATION, project_id=PROJECT_ID)

        self.assertEqual(result, list(API_RESPONSE))

        get_conn.return_value.list_queues.assert_called_once_with(
            parent=FULL_LOCATION_PATH,
            filter_=None,
            page_size=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.delete_queue.return_value": API_RESPONSE},  # type: ignore
    )
    def test_delete_queue(self, get_conn):
        """delete_queue returns None even though the mocked client returns API_RESPONSE."""
        result = self.hook.delete_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)

        self.assertIsNone(result)

        get_conn.return_value.delete_queue.assert_called_once_with(
            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None
        )

    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.purge_queue.return_value": API_RESPONSE},  # type: ignore
    )
    def test_purge_queue(self, get_conn):
        """purge_queue targets the full queue path and returns the client result."""
        result = self.hook.purge_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)

        self.assertEqual(result, API_RESPONSE)

        get_conn.return_value.purge_queue.assert_called_once_with(
            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None
        )

    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.pause_queue.return_value": API_RESPONSE},  # type: ignore
    )
    def test_pause_queue(self, get_conn):
        """pause_queue targets the full queue path and returns the client result."""
        result = self.hook.pause_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)

        self.assertEqual(result, API_RESPONSE)

        get_conn.return_value.pause_queue.assert_called_once_with(
            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None
        )

    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.resume_queue.return_value": API_RESPONSE},  # type: ignore
    )
    def test_resume_queue(self, get_conn):
        """resume_queue targets the full queue path and returns the client result."""
        result = self.hook.resume_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)

        self.assertEqual(result, API_RESPONSE)

        get_conn.return_value.resume_queue.assert_called_once_with(
            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None
        )

    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.create_task.return_value": API_RESPONSE},  # type: ignore
    )
    def test_create_task(self, get_conn):
        """create_task stamps the full task path onto the Task and targets the queue path."""
        result = self.hook.create_task(
            location=LOCATION,
            queue_name=QUEUE_ID,
            task=Task(),
            project_id=PROJECT_ID,
            task_name=TASK_NAME,
        )

        self.assertEqual(result, API_RESPONSE)

        get_conn.return_value.create_task.assert_called_once_with(
            parent=FULL_QUEUE_PATH,
            task=Task(name=FULL_TASK_PATH),
            response_view=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.get_task.return_value": API_RESPONSE},  # type: ignore
    )
    def test_get_task(self, get_conn):
        """get_task looks the task up by its full path and returns the client result."""
        result = self.hook.get_task(
            location=LOCATION,
            queue_name=QUEUE_ID,
            task_name=TASK_NAME,
            project_id=PROJECT_ID,
        )

        self.assertEqual(result, API_RESPONSE)

        get_conn.return_value.get_task.assert_called_once_with(
            name=FULL_TASK_PATH,
            response_view=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.list_tasks.return_value": API_RESPONSE},  # type: ignore
    )
    def test_list_tasks(self, get_conn):
        """list_tasks returns the client's results as a list."""
        result = self.hook.list_tasks(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)

        self.assertEqual(result, list(API_RESPONSE))

        get_conn.return_value.list_tasks.assert_called_once_with(
            parent=FULL_QUEUE_PATH,
            response_view=None,
            page_size=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.delete_task.return_value": API_RESPONSE},  # type: ignore
    )
    def test_delete_task(self, get_conn):
        """delete_task returns None even though the mocked client returns API_RESPONSE."""
        result = self.hook.delete_task(
            location=LOCATION,
            queue_name=QUEUE_ID,
            task_name=TASK_NAME,
            project_id=PROJECT_ID,
        )

        self.assertIsNone(result)

        get_conn.return_value.delete_task.assert_called_once_with(
            name=FULL_TASK_PATH, retry=None, timeout=None, metadata=None
        )

    @mock.patch(
        "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.run_task.return_value": API_RESPONSE},  # type: ignore
    )
    def test_run_task(self, get_conn):
        """run_task targets the full task path and returns the client result."""
        result = self.hook.run_task(
            location=LOCATION,
            queue_name=QUEUE_ID,
            task_name=TASK_NAME,
            project_id=PROJECT_ID,
        )

        self.assertEqual(result, API_RESPONSE)

        get_conn.return_value.run_task.assert_called_once_with(
            name=FULL_TASK_PATH,
            response_view=None,
            retry=None,
            timeout=None,
            metadata=None,
        )