def test_batch_prediction_job_remote_runner_retries_to_get_status_on_non_completed_job(
        self, mock_time_sleep, mock_path_exists, mock_job_service_client):
    """Check that the runner polls again while the job is still running.

    The first status lookup reports JOB_STATE_RUNNING and the second reports
    JOB_STATE_SUCCEEDED, so the runner is expected to sleep exactly once for
    the polling interval and to look the job up twice.

    NOTE(review): the mock arguments are presumably injected by patch
    decorators that are not visible in this chunk -- confirm.
    """
    mock_path_exists.return_value = False

    # Status returned by the first poll: the job has not finished yet.
    running_job = mock.Mock()
    running_job.state = gca_job_state.JobState.JOB_STATE_RUNNING

    # Status returned by the second poll: the job finished successfully.
    succeeded_job = mock.Mock()
    succeeded_job.name = 'job1'
    succeeded_job.state = gca_job_state.JobState.JOB_STATE_SUCCEEDED
    succeeded_output = succeeded_job.output_info
    succeeded_output.bigquery_output_table = 'bigquery_output_table'
    succeeded_output.bigquery_output_dataset = 'bigquery_output_dataset'
    succeeded_output.gcs_output_directory = 'gcs_output_directory'

    created_job = mock.Mock()
    created_job.name = self._batch_prediction_job_name

    client = mock.Mock()
    client.create_batch_prediction_job.return_value = created_job
    # side_effect hands out one response per call, in order.
    client.get_batch_prediction_job.side_effect = [running_job, succeeded_job]
    mock_job_service_client.return_value = client

    batch_prediction_job_remote_runner.create_batch_prediction_job(
        self._job_type,
        self._project,
        self._location,
        self._payload,
        self._gcp_resources,
        self._executor_input,
    )

    mock_time_sleep.assert_called_once_with(
        job_remote_runner._POLLING_INTERVAL_IN_SECONDS)
    self.assertEqual(client.get_batch_prediction_job.call_count, 2)
def test_batch_prediction_job_remote_runner_succeeded_output_bq_table(
        self, mock_path_exists, mock_job_service_client):
    """Check the outputs written for a succeeded job with BigQuery output.

    The mocked job reports JOB_STATE_SUCCEEDED with a BigQuery output table
    and dataset and an empty GCS output directory. The test then verifies:

    * the job service client is built for the regional API endpoint;
    * the create call receives the expected parent and the parsed payload;
    * the gcp_resources file holds exactly one resource whose URI, minus the
      job URI prefix, is the created job name;
    * the executor output file holds the batch prediction job artifact and
      the BigQuery table artifact.

    NOTE(review): the mock arguments are presumably injected by patch
    decorators that are not visible in this chunk -- confirm.
    """
    job_client = mock.Mock()
    mock_job_service_client.return_value = job_client

    create_batch_prediction_job_response = mock.Mock()
    create_batch_prediction_job_response.name = self._batch_prediction_job_name
    job_client.create_batch_prediction_job.return_value = (
        create_batch_prediction_job_response)

    get_batch_prediction_job_response = mock.Mock()
    get_batch_prediction_job_response.state = (
        gca_job_state.JobState.JOB_STATE_SUCCEEDED)
    get_batch_prediction_job_response.name = 'job1'
    output_info = get_batch_prediction_job_response.output_info
    output_info.bigquery_output_table = 'bigquery_output_table'
    output_info.bigquery_output_dataset = (
        'bq://bq_project.bigquery_output_dataset')
    output_info.gcs_output_directory = ''
    job_client.get_batch_prediction_job.return_value = (
        get_batch_prediction_job_response)

    mock_path_exists.return_value = False

    batch_prediction_job_remote_runner.create_batch_prediction_job(
        self._job_type,
        self._project,
        self._location,
        self._payload,
        self._gcp_resources,
        self._executor_input,
    )

    mock_job_service_client.assert_called_once_with(
        client_options={
            'api_endpoint': 'test_region-aiplatform.googleapis.com'
        },
        client_info=mock.ANY)

    expected_parent = f'projects/{self._project}/locations/{self._location}'
    expected_job_spec = json.loads(self._payload, strict=False)
    job_client.create_batch_prediction_job.assert_called_once_with(
        parent=expected_parent, batch_prediction_job=expected_job_spec)

    with open(self._gcp_resources) as f:
        serialized_gcp_resources = f.read()

    # Instantiate GCPResources Proto
    batch_prediction_job_resources = json_format.Parse(
        serialized_gcp_resources, GcpResources())

    self.assertEqual(len(batch_prediction_job_resources.resources), 1)
    # Strip the URI prefix so only the job resource name is compared.
    batch_prediction_job_name = batch_prediction_job_resources.resources[
        0].resource_uri[len(self._batch_prediction_job_uri_prefix):]
    self.assertEqual(batch_prediction_job_name,
                     self._batch_prediction_job_name)

    with open(self._output_file_path) as f:
        executor_output = json.load(f, strict=False)

    # Adjacent literals are joined at compile time, so the JSON text needs no
    # backslash continuations (which are fragile inside a string literal).
    expected_executor_output = json.loads(
        '{"artifacts": {'
        '"batchpredictionjob": {"artifacts": [{"metadata": {'
        '"resourceName": "job1", '
        '"bigqueryOutputDataset": "bq://bq_project.bigquery_output_dataset",'
        '"bigqueryOutputTable": "bigquery_output_table",'
        '"gcsOutputDirectory": ""}, '
        '"name": "foobar", '
        '"type": {"schemaTitle": "google.VertexBatchPredictionJob"}, '
        '"uri": "https://test_region-aiplatform.googleapis.com/v1/job1"}]},'
        '"bigquery_output_table": {"artifacts": [{"metadata": {'
        '"projectId": "bq_project", '
        '"datasetId": "bigquery_output_dataset", '
        '"tableId": "bigquery_output_table"}, '
        '"name": "bq_table", '
        '"type": {"schemaTitle": "google.BQTable"}, '
        '"uri": "https://www.googleapis.com/bigquery/v2/projects/bq_project'
        '/datasets/bigquery_output_dataset/tables/bigquery_output_table"}]}}}')
    self.assertEqual(executor_output, expected_executor_output)
def test_batch_prediction_job_remote_runner_raises_exception_on_error(
        self, mock_path_exists, mock_job_service_client):
    """Check that a job reported as JOB_STATE_FAILED raises RuntimeError.

    NOTE(review): the mock arguments are presumably injected by patch
    decorators that are not visible in this chunk -- confirm.
    """
    mock_path_exists.return_value = False

    # Every status lookup reports the job as failed.
    failed_job = mock.Mock()
    failed_job.state = gca_job_state.JobState.JOB_STATE_FAILED

    created_job = mock.Mock()
    created_job.name = self._batch_prediction_job_name

    client = mock.Mock()
    client.create_batch_prediction_job.return_value = created_job
    client.get_batch_prediction_job.return_value = failed_job
    mock_job_service_client.return_value = client

    with self.assertRaises(RuntimeError):
        batch_prediction_job_remote_runner.create_batch_prediction_job(
            self._job_type,
            self._project,
            self._location,
            self._payload,
            self._gcp_resources,
            self._executor_input,
        )
def test_batch_prediction_job_remote_runner_cancel(
        self, mock_execution_context, mock_post_requests, _, mock_auth,
        mock_path_exists, mock_job_service_client):
    """Check that the cancellation handler posts to the job's cancel URL.

    After the runner finishes, the test invokes the ``on_cancel`` keyword
    argument that was passed to the execution context mock and expects a
    single POST to ``<uri prefix><job name>:cancel`` carrying the fake bearer
    token returned by the auth mock.

    NOTE(review): the mock arguments are presumably injected by patch
    decorators that are not visible in this chunk -- confirm.
    """
    mock_path_exists.return_value = False
    mock_execution_context.return_value = None

    # Credentials handed back by the auth mock; the token ends up in the
    # Authorization header asserted below.
    fake_credentials = mock.Mock()
    fake_credentials.token = 'fake_token'
    mock_auth.return_value = [fake_credentials, "project"]

    created_job = mock.Mock()
    created_job.name = self._batch_prediction_job_name

    succeeded_job = mock.Mock()
    succeeded_job.state = gca_job_state.JobState.JOB_STATE_SUCCEEDED
    succeeded_job.name = 'job1'
    succeeded_output = succeeded_job.output_info
    succeeded_output.bigquery_output_table = 'bigquery_output_table'
    succeeded_output.bigquery_output_dataset = 'bigquery_output_dataset'
    succeeded_output.gcs_output_directory = 'gcs_output_directory'

    client = mock.Mock()
    client.create_batch_prediction_job.return_value = created_job
    client.get_batch_prediction_job.return_value = succeeded_job
    mock_job_service_client.return_value = client

    batch_prediction_job_remote_runner.create_batch_prediction_job(
        self._job_type,
        self._project,
        self._location,
        self._payload,
        self._gcp_resources,
        self._executor_input,
    )

    # Call cancellation handler
    on_cancel = mock_execution_context.call_args[1]['on_cancel']
    on_cancel()

    expected_cancel_url = (
        f'{self._batch_prediction_job_uri_prefix}'
        f'{self._batch_prediction_job_name}:cancel')
    expected_headers = {
        'Content-type': 'application/json',
        'Authorization': 'Bearer fake_token',
    }
    mock_post_requests.assert_called_once_with(
        url=expected_cancel_url, data='', headers=expected_headers)