def test_batch_prediction_job_remote_runner_retries_to_get_status_on_non_completed_job(
            self, mock_time_sleep, mock_path_exists, mock_job_service_client):
        """Polls the job status again after sleeping when it is not terminal.

        The first `get_batch_prediction_job` call reports JOB_STATE_RUNNING and
        the second reports JOB_STATE_SUCCEEDED. The runner is therefore
        expected to sleep exactly once for the polling interval and to query
        the job status twice.
        """
        client = mock.Mock()
        mock_job_service_client.return_value = client

        created_job = mock.Mock()
        created_job.name = self._batch_prediction_job_name
        client.create_batch_prediction_job.return_value = created_job

        running_job = mock.Mock()
        running_job.state = gca_job_state.JobState.JOB_STATE_RUNNING

        succeeded_job = mock.Mock()
        succeeded_job.name = 'job1'
        succeeded_job.state = gca_job_state.JobState.JOB_STATE_SUCCEEDED
        succeeded_output = succeeded_job.output_info
        succeeded_output.bigquery_output_table = 'bigquery_output_table'
        succeeded_output.bigquery_output_dataset = 'bigquery_output_dataset'
        succeeded_output.gcs_output_directory = 'gcs_output_directory'

        # Successive status polls return these in order: one non-terminal
        # state, then success.
        client.get_batch_prediction_job.side_effect = [running_job, succeeded_job]

        mock_path_exists.return_value = False

        batch_prediction_job_remote_runner.create_batch_prediction_job(
            self._job_type, self._project, self._location, self._payload,
            self._gcp_resources, self._executor_input)

        mock_time_sleep.assert_called_once_with(
            job_remote_runner._POLLING_INTERVAL_IN_SECONDS)
        self.assertEqual(client.get_batch_prediction_job.call_count, 2)
    def test_batch_prediction_job_remote_runner_succeeded_output_bq_table(
            self, mock_path_exists, mock_job_service_client):
        """A succeeded job with BigQuery output emits the expected artifacts.

        This test verifies four things:
        - how the JobService client is constructed;
        - the create request (parent and job spec);
        - the single resource written to the GcpResources file;
        - the executor output JSON.

        The executor output must hold a `batchpredictionjob` artifact and a
        `bigquery_output_table` artifact. The latter's project and dataset
        come from the `bq://bq_project.bigquery_output_dataset` value.
        """
        client = mock.Mock()
        mock_job_service_client.return_value = client

        created_job = mock.Mock()
        created_job.name = self._batch_prediction_job_name
        client.create_batch_prediction_job.return_value = created_job

        finished_job = mock.Mock()
        finished_job.name = 'job1'
        finished_job.state = gca_job_state.JobState.JOB_STATE_SUCCEEDED
        finished_job.output_info.bigquery_output_table = 'bigquery_output_table'
        finished_job.output_info.bigquery_output_dataset = (
            'bq://bq_project.bigquery_output_dataset')
        finished_job.output_info.gcs_output_directory = ''
        client.get_batch_prediction_job.return_value = finished_job

        mock_path_exists.return_value = False

        batch_prediction_job_remote_runner.create_batch_prediction_job(
            self._job_type, self._project, self._location, self._payload,
            self._gcp_resources, self._executor_input)

        mock_job_service_client.assert_called_once_with(
            client_options={
                'api_endpoint': 'test_region-aiplatform.googleapis.com'
            },
            client_info=mock.ANY)

        client.create_batch_prediction_job.assert_called_once_with(
            parent=f'projects/{self._project}/locations/{self._location}',
            batch_prediction_job=json.loads(self._payload, strict=False))

        # The GcpResources file should describe exactly one resource, namely
        # the created batch prediction job.
        with open(self._gcp_resources) as f:
            gcp_resources = json_format.Parse(f.read(), GcpResources())
        self.assertEqual(len(gcp_resources.resources), 1)
        resource_uri = gcp_resources.resources[0].resource_uri
        self.assertEqual(
            resource_uri[len(self._batch_prediction_job_uri_prefix):],
            self._batch_prediction_job_name)

        expected_executor_output = {
            'artifacts': {
                'batchpredictionjob': {
                    'artifacts': [{
                        'metadata': {
                            'resourceName': 'job1',
                            'bigqueryOutputDataset':
                                'bq://bq_project.bigquery_output_dataset',
                            'bigqueryOutputTable': 'bigquery_output_table',
                            'gcsOutputDirectory': '',
                        },
                        'name': 'foobar',
                        'type': {
                            'schemaTitle': 'google.VertexBatchPredictionJob'
                        },
                        'uri':
                            'https://test_region-aiplatform.googleapis.com/v1/job1',
                    }]
                },
                'bigquery_output_table': {
                    'artifacts': [{
                        'metadata': {
                            'projectId': 'bq_project',
                            'datasetId': 'bigquery_output_dataset',
                            'tableId': 'bigquery_output_table',
                        },
                        'name': 'bq_table',
                        'type': {
                            'schemaTitle': 'google.BQTable'
                        },
                        'uri': (
                            'https://www.googleapis.com/bigquery/v2/projects/bq_project/'
                            'datasets/bigquery_output_dataset/tables/bigquery_output_table'
                        ),
                    }]
                },
            }
        }
        with open(self._output_file_path) as f:
            executor_output = json.load(f, strict=False)
        self.assertEqual(executor_output, expected_executor_output)
    def test_batch_prediction_job_remote_runner_raises_exception_on_error(
            self, mock_path_exists, mock_job_service_client):
        """A job whose status is JOB_STATE_FAILED makes the runner raise RuntimeError."""
        client = mock.Mock()
        mock_job_service_client.return_value = client

        created_job = mock.Mock()
        created_job.name = self._batch_prediction_job_name
        client.create_batch_prediction_job.return_value = created_job

        # Every status poll reports a failed job.
        failed_job = mock.Mock()
        failed_job.state = gca_job_state.JobState.JOB_STATE_FAILED
        client.get_batch_prediction_job.return_value = failed_job

        mock_path_exists.return_value = False

        with self.assertRaises(RuntimeError):
            batch_prediction_job_remote_runner.create_batch_prediction_job(
                self._job_type, self._project, self._location, self._payload,
                self._gcp_resources, self._executor_input)
    def test_batch_prediction_job_remote_runner_cancel(
            self, mock_execution_context, mock_post_requests, _, mock_auth,
            mock_path_exists, mock_job_service_client):
        """The cancel handler the runner registers POSTs a `:cancel` request.

        The runner passes an `on_cancel` keyword argument to the mocked
        execution context. This test retrieves that callback from the mock's
        recorded call and invokes it directly. It then checks that exactly one
        POST is sent to the job's `:cancel` URL with the bearer token from the
        mocked credentials.
        """
        fake_creds = mock.Mock()
        fake_creds.token = 'fake_token'
        mock_auth.return_value = [fake_creds, "project"]

        client = mock.Mock()
        mock_job_service_client.return_value = client

        created_job = mock.Mock()
        created_job.name = self._batch_prediction_job_name
        client.create_batch_prediction_job.return_value = created_job

        finished_job = mock.Mock()
        finished_job.name = 'job1'
        finished_job.state = gca_job_state.JobState.JOB_STATE_SUCCEEDED
        finished_output = finished_job.output_info
        finished_output.bigquery_output_table = 'bigquery_output_table'
        finished_output.bigquery_output_dataset = 'bigquery_output_dataset'
        finished_output.gcs_output_directory = 'gcs_output_directory'
        client.get_batch_prediction_job.return_value = finished_job

        mock_path_exists.return_value = False
        mock_execution_context.return_value = None

        batch_prediction_job_remote_runner.create_batch_prediction_job(
            self._job_type, self._project, self._location, self._payload,
            self._gcp_resources, self._executor_input)

        # call_args[1] holds the keyword arguments of the recorded call.
        on_cancel = mock_execution_context.call_args[1]['on_cancel']
        on_cancel()

        cancel_url = (f'{self._batch_prediction_job_uri_prefix}'
                      f'{self._batch_prediction_job_name}:cancel')
        mock_post_requests.assert_called_once_with(
            url=cancel_url,
            data='',
            headers={
                'Content-type': 'application/json',
                'Authorization': 'Bearer fake_token',
            })