def test_custom_job_remote_runner_on_payload_deserializes_correctly(
        self, mock_path_exists, mock_job_service_client):
    """Checks the JSON payload is parsed and forwarded as the custom_job spec.

    The mocked service client reports the job as already SUCCEEDED, and the
    patched path-exists check returns False.
    """
    service = mock.Mock()
    mock_job_service_client.return_value = service

    created = mock.Mock()
    # `name` is special in the Mock constructor, so it is assigned afterwards.
    created.name = self._custom_job_name
    service.create_custom_job.return_value = created
    service.get_custom_job.return_value = mock.Mock(
        state=gca_job_state.JobState.JOB_STATE_SUCCEEDED)

    mock_path_exists.return_value = False

    custom_job_remote_runner.create_custom_job(
        self._type, self._project, self._location, self._payload,
        self._gcp_resources)

    service.create_custom_job.assert_called_once_with(
        parent=f"projects/{self._project}/locations/{self._location}",
        custom_job=json.loads(self._payload, strict=False))
def test_custom_job_remote_runner_cancel(self, mock_execution_context,
                                         mock_post_requests, _, mock_auth,
                                         mock_job_service_client):
    """Checks the on_cancel callback issues an authenticated :cancel POST.

    The runner is expected to pass an ``on_cancel`` keyword argument to the
    mocked execution context. Calling that callback must POST an empty body
    to ``<uri prefix><job name>:cancel`` with a bearer token from the mocked
    credentials.
    """
    fake_creds = mock.Mock()
    fake_creds.token = "fake_token"
    mock_auth.return_value = [fake_creds, "project"]

    service = mock.Mock()
    mock_job_service_client.return_value = service

    created = mock.Mock()
    created.name = self._custom_job_name
    service.create_custom_job.return_value = created
    service.get_custom_job.return_value = mock.Mock(
        state=gca_job_state.JobState.JOB_STATE_SUCCEEDED)

    mock_execution_context.return_value = None

    custom_job_remote_runner.create_custom_job(
        self._type, self._project, self._location, self._payload,
        self._gcp_resources)

    # Pull the cancellation handler out of the execution context's kwargs and fire it.
    cancel_handler = mock_execution_context.call_args[1]["on_cancel"]
    cancel_handler()

    expected_headers = {
        "Content-type": "application/json",
        "Authorization": "Bearer fake_token",
    }
    mock_post_requests.assert_called_once_with(
        url=f"{self._custom_job_uri_prefix}{self._custom_job_name}:cancel",
        data="",
        headers=expected_headers)
def test_custom_job_remote_runner_raises_exception_empty_URI_in_gcp_resources(
        self, mock_time_sleep, mock_job_service_client):
    """Checks a gcp_resources file whose single CustomJob URI is empty is rejected.

    ``mock_time_sleep`` is injected by a patch decorator and is not used here.
    """
    service = mock.Mock()
    mock_job_service_client.return_value = service

    created = mock.Mock()
    created.name = self._custom_job_name
    service.create_custom_job.return_value = created
    service.get_custom_job.side_effect = [
        mock.Mock(state=gca_job_state.JobState.JOB_STATE_SUCCEEDED)
    ]

    # Pre-populate the output file with one CustomJob entry that has a blank URI.
    stale_resources = GcpResources()
    blank_entry = stale_resources.resources.add()
    blank_entry.resource_type = "CustomJob"
    blank_entry.resource_uri = ""
    with open(self._gcp_resources, "w") as out:
        out.write(json_format.MessageToJson(stale_resources))

    with self.assertRaisesRegex(
            ValueError,
            "Job Name in gcp_resource is not formatted correctly or is empty."):
        custom_job_remote_runner.create_custom_job(
            self._type, self._project, self._location, self._payload,
            self._gcp_resources)
def test_custom_job_remote_runner_returns_gcp_resources(
        self, mock_time_sleep, mock_path_exists, mock_job_service_client):
    """Checks exactly one resource pointing at the created job is written out.

    The written URI, with ``self._custom_job_uri_prefix`` stripped, must equal
    the job name returned by the mocked create call.
    """
    service = mock.Mock()
    mock_job_service_client.return_value = service

    created = mock.Mock()
    created.name = self._custom_job_name
    service.create_custom_job.return_value = created
    service.get_custom_job.side_effect = [
        mock.Mock(state=gca_job_state.JobState.JOB_STATE_SUCCEEDED)
    ]

    mock_path_exists.return_value = False

    custom_job_remote_runner.create_custom_job(
        self._type, self._project, self._location, self._payload,
        self._gcp_resources)

    # Deserialize the JSON output back into a GcpResources proto.
    with open(self._gcp_resources) as src:
        written = json_format.Parse(src.read(), GcpResources())

    self.assertEqual(len(written.resources), 1)
    uri = written.resources[0].resource_uri
    self.assertEqual(uri[len(self._custom_job_uri_prefix):],
                     self._custom_job_name)
def test_custom_job_remote_runner_retries_to_get_status_on_non_completed_job(
        self, mock_time_sleep, mock_path_exists, mock_job_service_client):
    """Checks the runner polls again after a RUNNING status.

    The first status poll returns RUNNING and the second returns SUCCEEDED.
    The runner must sleep once for the polling interval and fetch the job
    status twice.
    """
    service = mock.Mock()
    mock_job_service_client.return_value = service

    created = mock.Mock()
    created.name = self._custom_job_name
    service.create_custom_job.return_value = created

    running = mock.Mock(state=gca_job_state.JobState.JOB_STATE_RUNNING)
    succeeded = mock.Mock(state=gca_job_state.JobState.JOB_STATE_SUCCEEDED)
    service.get_custom_job.side_effect = [running, succeeded]

    mock_path_exists.return_value = False

    custom_job_remote_runner.create_custom_job(
        self._type, self._project, self._location, self._payload,
        self._gcp_resources)

    # One non-terminal poll means exactly one sleep between the two status checks.
    mock_time_sleep.assert_called_once_with(
        job_remote_runner._POLLING_INTERVAL_IN_SECONDS)
    self.assertEqual(service.get_custom_job.call_count, 2)
def test_custom_job_remote_runner_raises_exception_on_internal_error(
        self, mock_path_exists, mock_job_service_client):
    """Checks a FAILED job state causes the runner to raise SystemExit."""
    service = mock.Mock()
    mock_job_service_client.return_value = service

    created = mock.Mock()
    created.name = self._custom_job_name
    service.create_custom_job.return_value = created
    service.get_custom_job.return_value = mock.Mock(
        state=gca_job_state.JobState.JOB_STATE_FAILED)

    mock_path_exists.return_value = False

    with self.assertRaises(SystemExit):
        custom_job_remote_runner.create_custom_job(
            self._type, self._project, self._location, self._payload,
            self._gcp_resources)
def test_custom_job_remote_runner_on_region_is_set_correctly_in_client_options(
        self, mock_job_service_client):
    """Checks the service client is built with the regional API endpoint."""
    service = mock.Mock()
    mock_job_service_client.return_value = service

    created = mock.Mock()
    created.name = self._custom_job_name
    service.create_custom_job.return_value = created
    service.get_custom_job.return_value = mock.Mock(
        state=gca_job_state.JobState.JOB_STATE_SUCCEEDED)

    custom_job_remote_runner.create_custom_job(
        self._type, self._project, self._location, self._payload,
        self._gcp_resources)

    # NOTE(review): "test_region" presumably equals self._location set in setUp — confirm.
    expected_options = {"api_endpoint": "test_region-aiplatform.googleapis.com"}
    mock_job_service_client.assert_called_once_with(
        client_options=expected_options, client_info=mock.ANY)
def test_custom_job_remote_runner_raises_exception_with_more_than_one_resources_in_gcp_resources(
        self, mock_time_sleep, mock_job_service_client):
    """Checks a gcp_resources file holding two resources is rejected.

    ``mock_time_sleep`` is injected by a patch decorator and is not used here.
    """
    service = mock.Mock()
    mock_job_service_client.return_value = service

    created = mock.Mock()
    created.name = self._custom_job_name
    service.create_custom_job.return_value = created
    service.get_custom_job.side_effect = [
        mock.Mock(state=gca_job_state.JobState.JOB_STATE_SUCCEEDED)
    ]

    # Pre-populate the output file with two identical CustomJob entries.
    job_uri = f"{self._custom_job_uri_prefix}{self._custom_job_name}"
    duplicated = GcpResources()
    for _ in range(2):
        entry = duplicated.resources.add()
        entry.resource_type = "CustomJob"
        entry.resource_uri = job_uri
    with open(self._gcp_resources, "w") as out:
        out.write(json_format.MessageToJson(duplicated))

    with self.assertRaisesRegex(
            ValueError, "gcp_resources should contain one resource, found 2"):
        custom_job_remote_runner.create_custom_job(
            self._type, self._project, self._location, self._payload,
            self._gcp_resources)