Exemplo n.º 1
0
 def execute(self, context):
     """Pause the configured queue through a ``CloudTasksHook``.

     A hook is created with this object's ``gcp_conn_id`` and its
     ``pause_queue`` method is called with the location, queue name,
     project id, retry, timeout and metadata stored on this object.
     ``context`` is accepted but not used here.

     :return: whatever ``CloudTasksHook.pause_queue`` returns.
     """
     tasks_hook = CloudTasksHook(gcp_conn_id=self.gcp_conn_id)
     pause = tasks_hook.pause_queue
     # Collect the call arguments up front so the call itself stays short.
     call_kwargs = {
         "location": self.location,
         "queue_name": self.queue_name,
         "project_id": self.project_id,
         "retry": self.retry,
         "timeout": self.timeout,
         "metadata": self.metadata,
     }
     return pause(**call_kwargs)
Exemplo n.º 2
0
class TestCloudTasksHook(unittest.TestCase):
    """Unit tests for ``CloudTasksHook``.

    Apart from the client-creation test, every test patches
    ``CloudTasksHook.get_conn`` so the hook talks to a mock client. The
    dotted keyword passed to ``mock.patch`` (for example
    ``return_value.create_queue.return_value``) configures what the mocked
    client method returns. Each test calls one hook method, checks the
    value it returns, and checks the keyword arguments that reached the
    mocked client.
    """

    def setUp(self):
        """Build the hook with ``CloudBaseHook.__init__`` patched out."""
        # The base-class constructor is replaced only while the hook is
        # being instantiated; the patch is undone when the block exits.
        with mock.patch(
                "airflow.gcp.hooks.base.CloudBaseHook.__init__",
                new=mock_base_gcp_hook_no_default_project_id,
        ):
            self.hook = CloudTasksHook(gcp_conn_id="test")

    # Patches are applied bottom-up, so the mock arguments arrive in the
    # reverse of the decorator order.
    @mock.patch("airflow.gcp.hooks.tasks.CloudTasksHook.client_info",
                new_callable=mock.PropertyMock)
    @mock.patch("airflow.gcp.hooks.tasks.CloudTasksHook._get_credentials")
    @mock.patch("airflow.gcp.hooks.tasks.CloudTasksClient")
    def test_cloud_tasks_client_creation(self, mock_client, mock_get_creds,
                                         mock_client_info):
        """get_conn builds the client once, returns it and caches it.

        The client must be constructed with the hook's credentials and
        ``client_info``, and the same object must be stored on ``_client``.
        """
        result = self.hook.get_conn()
        mock_client.assert_called_once_with(
            credentials=mock_get_creds.return_value,
            client_info=mock_client_info.return_value)
        self.assertEqual(mock_client.return_value, result)
        self.assertEqual(self.hook._client, result)

    @mock.patch(  # type: ignore
        "airflow.gcp.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.create_queue.return_value":
           API_RESPONSE},  # type: ignore
    )
    def test_create_queue(self, get_conn):
        """create_queue returns the client's response object.

        The client must receive the full location path as ``parent`` and a
        ``Queue`` whose name is the full queue path.
        """
        result = self.hook.create_queue(
            location=LOCATION,
            task_queue=Queue(),
            queue_name=QUEUE_ID,
            project_id=PROJECT_ID,
        )

        self.assertIs(result, API_RESPONSE)

        get_conn.return_value.create_queue.assert_called_once_with(
            parent=FULL_LOCATION_PATH,
            queue=Queue(name=FULL_QUEUE_PATH),
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(  # type: ignore
        "airflow.gcp.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.update_queue.return_value":
           API_RESPONSE},  # type: ignore
    )
    def test_update_queue(self, get_conn):
        """update_queue returns the client's response object.

        The ``Queue`` passed to the client must keep the given state and
        carry the full queue path as its name.
        """
        result = self.hook.update_queue(
            task_queue=Queue(state=3),
            location=LOCATION,
            queue_name=QUEUE_ID,
            project_id=PROJECT_ID,
        )

        self.assertIs(result, API_RESPONSE)

        get_conn.return_value.update_queue.assert_called_once_with(
            queue=Queue(name=FULL_QUEUE_PATH, state=3),
            update_mask=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(  # type: ignore
        "airflow.gcp.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.get_queue.return_value":
           API_RESPONSE},  # type: ignore
    )
    def test_get_queue(self, get_conn):
        """get_queue returns the client's response for the full queue path."""
        result = self.hook.get_queue(location=LOCATION,
                                     queue_name=QUEUE_ID,
                                     project_id=PROJECT_ID)

        self.assertIs(result, API_RESPONSE)

        get_conn.return_value.get_queue.assert_called_once_with(
            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None)

    @mock.patch(  # type: ignore
        "airflow.gcp.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.list_queues.return_value":
           API_RESPONSE},  # type: ignore
    )
    def test_list_queues(self, get_conn):
        """list_queues returns the client's response converted to a list."""
        result = self.hook.list_queues(location=LOCATION,
                                       project_id=PROJECT_ID)

        self.assertEqual(result, list(API_RESPONSE))

        get_conn.return_value.list_queues.assert_called_once_with(
            parent=FULL_LOCATION_PATH,
            filter_=None,
            page_size=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(  # type: ignore
        "airflow.gcp.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.delete_queue.return_value":
           API_RESPONSE},  # type: ignore
    )
    def test_delete_queue(self, get_conn):
        """delete_queue returns None even though the client returns a value."""
        result = self.hook.delete_queue(location=LOCATION,
                                        queue_name=QUEUE_ID,
                                        project_id=PROJECT_ID)

        # Identity check: the hook must return None itself, not merely
        # something that compares equal to it.
        self.assertIsNone(result)

        get_conn.return_value.delete_queue.assert_called_once_with(
            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None)

    @mock.patch(  # type: ignore
        "airflow.gcp.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.purge_queue.return_value":
           API_RESPONSE},  # type: ignore
    )
    def test_purge_queue(self, get_conn):
        """purge_queue returns the client's response for the queue path."""
        result = self.hook.purge_queue(location=LOCATION,
                                       queue_name=QUEUE_ID,
                                       project_id=PROJECT_ID)

        self.assertEqual(result, API_RESPONSE)

        get_conn.return_value.purge_queue.assert_called_once_with(
            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None)

    @mock.patch(  # type: ignore
        "airflow.gcp.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.pause_queue.return_value":
           API_RESPONSE},  # type: ignore
    )
    def test_pause_queue(self, get_conn):
        """pause_queue returns the client's response for the queue path."""
        result = self.hook.pause_queue(location=LOCATION,
                                       queue_name=QUEUE_ID,
                                       project_id=PROJECT_ID)

        self.assertEqual(result, API_RESPONSE)

        get_conn.return_value.pause_queue.assert_called_once_with(
            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None)

    @mock.patch(  # type: ignore
        "airflow.gcp.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.resume_queue.return_value":
           API_RESPONSE},  # type: ignore
    )
    def test_resume_queue(self, get_conn):
        """resume_queue returns the client's response for the queue path."""
        result = self.hook.resume_queue(location=LOCATION,
                                        queue_name=QUEUE_ID,
                                        project_id=PROJECT_ID)

        self.assertEqual(result, API_RESPONSE)

        get_conn.return_value.resume_queue.assert_called_once_with(
            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None)

    @mock.patch(  # type: ignore
        "airflow.gcp.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.create_task.return_value":
           API_RESPONSE},  # type: ignore
    )
    def test_create_task(self, get_conn):
        """create_task returns the client's response.

        The client must receive the full queue path as ``parent`` and a
        ``Task`` whose name is the full task path.
        """
        result = self.hook.create_task(
            location=LOCATION,
            queue_name=QUEUE_ID,
            task=Task(),
            project_id=PROJECT_ID,
            task_name=TASK_NAME,
        )

        self.assertEqual(result, API_RESPONSE)

        get_conn.return_value.create_task.assert_called_once_with(
            parent=FULL_QUEUE_PATH,
            task=Task(name=FULL_TASK_PATH),
            response_view=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(  # type: ignore
        "airflow.gcp.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.get_task.return_value": API_RESPONSE},  # type: ignore
    )
    def test_get_task(self, get_conn):
        """get_task returns the client's response for the full task path."""
        result = self.hook.get_task(
            location=LOCATION,
            queue_name=QUEUE_ID,
            task_name=TASK_NAME,
            project_id=PROJECT_ID,
        )

        self.assertEqual(result, API_RESPONSE)

        get_conn.return_value.get_task.assert_called_once_with(
            name=FULL_TASK_PATH,
            response_view=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(  # type: ignore
        "airflow.gcp.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.list_tasks.return_value":
           API_RESPONSE},  # type: ignore
    )
    def test_list_tasks(self, get_conn):
        """list_tasks returns the client's response converted to a list."""
        result = self.hook.list_tasks(location=LOCATION,
                                      queue_name=QUEUE_ID,
                                      project_id=PROJECT_ID)

        self.assertEqual(result, list(API_RESPONSE))

        get_conn.return_value.list_tasks.assert_called_once_with(
            parent=FULL_QUEUE_PATH,
            response_view=None,
            page_size=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(  # type: ignore
        "airflow.gcp.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.delete_task.return_value":
           API_RESPONSE},  # type: ignore
    )
    def test_delete_task(self, get_conn):
        """delete_task returns None even though the client returns a value."""
        result = self.hook.delete_task(
            location=LOCATION,
            queue_name=QUEUE_ID,
            task_name=TASK_NAME,
            project_id=PROJECT_ID,
        )

        # Identity check: the hook must return None itself, not merely
        # something that compares equal to it.
        self.assertIsNone(result)

        get_conn.return_value.delete_task.assert_called_once_with(
            name=FULL_TASK_PATH, retry=None, timeout=None, metadata=None)

    @mock.patch(  # type: ignore
        "airflow.gcp.hooks.tasks.CloudTasksHook.get_conn",
        **{"return_value.run_task.return_value": API_RESPONSE},  # type: ignore
    )
    def test_run_task(self, get_conn):
        """run_task returns the client's response for the full task path."""
        result = self.hook.run_task(
            location=LOCATION,
            queue_name=QUEUE_ID,
            task_name=TASK_NAME,
            project_id=PROJECT_ID,
        )

        self.assertEqual(result, API_RESPONSE)

        get_conn.return_value.run_task.assert_called_once_with(
            name=FULL_TASK_PATH,
            response_view=None,
            retry=None,
            timeout=None,
            metadata=None,
        )