def test_poke_raise_exception_on_terminal_state(self, mock_hook):
        """Check that ``poke`` fails when the job is already in a terminal state.

        The sensor is created with ``fail_on_terminal_state=True`` and the mocked
        hook reports the job as ``JOB_STATE_DONE``. ``poke`` must then raise
        ``AirflowException`` with the expected message, without fetching any
        autoscaling events and without invoking the user callback.

        :param mock_hook: mock standing in for the hook class (presumably injected
            by a ``mock.patch`` decorator that is outside this excerpt).
        """
        hook_instance = mock_hook.return_value
        events_callback = mock.MagicMock()
        # Connection settings are passed to the sensor and later expected,
        # unchanged, in the hook constructor call.
        connection_kwargs = {
            "gcp_conn_id": TEST_GCP_CONN_ID,
            "delegate_to": TEST_DELEGATE_TO,
            "impersonation_chain": TEST_IMPERSONATION_CHAIN,
        }
        # The hook reports a job that has already finished.
        hook_instance.get_job.return_value = {
            "id": TEST_JOB_ID,
            "currentState": DataflowJobStatus.JOB_STATE_DONE,
        }

        sensor = DataflowJobAutoScalingEventsSensor(
            task_id=TEST_TASK_ID,
            job_id=TEST_JOB_ID,
            callback=events_callback,
            fail_on_terminal_state=True,
            location=TEST_LOCATION,
            project_id=TEST_PROJECT_ID,
            **connection_kwargs,
        )

        expected_message = (
            f"Job with id '{TEST_JOB_ID}' is already in terminal state: "
            f"{DataflowJobStatus.JOB_STATE_DONE}"
        )
        with self.assertRaisesRegex(AirflowException, expected_message):
            sensor.poke(mock.MagicMock())

        mock_hook.assert_called_once_with(**connection_kwargs)
        hook_instance.fetch_job_autoscaling_events_by_id.assert_not_called()
        events_callback.assert_not_called()
    def test_poke(self, job_current_state, fail_on_terminal_state, mock_hook):
        """Check the normal ``poke`` path of the autoscaling events sensor.

        With the mocked hook reporting ``job_current_state``, ``poke`` must
        fetch the job's autoscaling events once, hand them to the callback and
        return whatever the callback returns.

        :param job_current_state: value placed in the mocked job's ``currentState``.
        :param fail_on_terminal_state: flag forwarded to the sensor.
        :param mock_hook: mock standing in for the hook class.

        NOTE(review): the arguments are presumably supplied by parametrizing and
        patching decorators that are outside this excerpt -- confirm.
        """
        hook_instance = mock_hook.return_value
        fetch_events = hook_instance.fetch_job_autoscaling_events_by_id
        events_callback = mock.MagicMock()
        # Connection settings are passed to the sensor and later expected,
        # unchanged, in the hook constructor call.
        connection_kwargs = {
            "gcp_conn_id": TEST_GCP_CONN_ID,
            "delegate_to": TEST_DELEGATE_TO,
            "impersonation_chain": TEST_IMPERSONATION_CHAIN,
        }
        hook_instance.get_job.return_value = {
            "id": TEST_JOB_ID,
            "currentState": job_current_state,
        }

        sensor = DataflowJobAutoScalingEventsSensor(
            task_id=TEST_TASK_ID,
            job_id=TEST_JOB_ID,
            callback=events_callback,
            fail_on_terminal_state=fail_on_terminal_state,
            location=TEST_LOCATION,
            project_id=TEST_PROJECT_ID,
            **connection_kwargs,
        )

        poke_result = sensor.poke(mock.MagicMock())

        # The sensor's answer is exactly the callback's verdict.
        self.assertEqual(events_callback.return_value, poke_result)

        mock_hook.assert_called_once_with(**connection_kwargs)
        fetch_events.assert_called_once_with(
            job_id=TEST_JOB_ID,
            project_id=TEST_PROJECT_ID,
            location=TEST_LOCATION,
        )
        events_callback.assert_called_once_with(fetch_events.return_value)
Example #3
0
    # [END howto_sensor_wait_for_job_message]

    # [START howto_sensor_wait_for_job_autoscaling_event]
    def check_autoscaling_event(autoscaling_events: List[dict]) -> bool:
        """Tell whether any autoscaling event reports that the worker pool started.

        :param autoscaling_events: event dictionaries; the text is read from
            ``event["description"]["messageText"]``, and a missing key is
            treated as an empty message.
        :return: ``True`` if at least one message contains
            ``"Worker pool started."``, otherwise ``False``.
        """
        expected_text = "Worker pool started."
        return any(
            expected_text in event.get("description", {}).get("messageText", "")
            for event in autoscaling_events
        )

    # The job id is a Jinja expression that pulls ``job_id`` from the XCom of
    # the "start-python-job-async" task (assumes ``job_id`` is a templated
    # field of the sensor -- TODO confirm).
    wait_for_python_job_async_autoscaling_event = DataflowJobAutoScalingEventsSensor(
        task_id="wait-for-python-job-async-autoscaling-event",
        job_id="{{task_instance.xcom_pull('start-python-job-async')['job_id']}}",
        location="europe-west3",
        callback=check_autoscaling_event,
    )
    # [END howto_sensor_wait_for_job_autoscaling_event]

    # Every sensor is wired directly downstream of the task that starts the
    # job; the sensors have no dependencies among themselves.
    start_python_job_async >> [
        wait_for_python_job_async_done,
        wait_for_python_job_async_metric,
        wait_for_python_job_async_message,
        wait_for_python_job_async_autoscaling_event,
    ]

with models.DAG(
        "example_gcp_dataflow_template",
        default_args=default_args,
        start_date=days_ago(1),
        schedule_interval=None,  # Override to match your needs