예제 #1
0
    def test_job_id_parameters_override(self, mock_load_json,
                                        mock_get_current_time):
        """Explicit job_id, pipeline_root and parameter_values override the spec.

        The submitted spec is expected to carry a resource name and display
        name built from the supplied job id, the new GCS output directory, and
        each overridden parameter rendered as a ``stringValue`` entry (the list
        value is expected in its string form).
        """
        mock_load_json.return_value = _load_test_data('pipeline_job.json')
        mock_get_current_time.return_value = datetime.date(2020, 10, 28)

        api_client = client.AIPlatformClient(
            project_id='test-project', region='us-central1')

        # Replace the network-facing submit call so only its arguments are checked.
        with mock.patch.object(
                api_client, '_submit_job', autospec=True) as mock_submit:
            overrides = {
                'text': 'Hello test!',
                'list': [1, 2, 3],
            }
            api_client.create_run_from_job_spec(
                job_spec_path='path/to/pipeline_job.json',
                job_id='my-new-id',
                pipeline_root='gs://bucket/new-blob',
                parameter_values=overrides)

            expected_spec = _load_test_data('pipeline_job.json')
            expected_spec.update({
                'name': ('projects/test-project/locations/us-central1/'
                         'pipelineJobs/my-new-id'),
                'displayName': 'my-new-id',
            })

            runtime_config = expected_spec['runtimeConfig']
            runtime_config['gcsOutputDirectory'] = 'gs://bucket/new-blob'
            runtime_config['parameters'].update({
                'text': {'stringValue': 'Hello test!'},
                'list': {'stringValue': '[1, 2, 3]'},
            })

            mock_submit.assert_called_once_with(
                job_spec=expected_spec, job_id='my-new-id')
예제 #2
0
    def test_client_init_with_defaults(self):
        """The constructor records project/region and the regional endpoint."""
        api_client = client.AIPlatformClient(
            project_id='test-project', region='us-central1')

        # Private attribute name -> value expected right after construction.
        expected_state = (
            ('_project_id', 'test-project'),
            ('_region', 'us-central1'),
            ('_endpoint', 'us-central1-aiplatform.googleapis.com'),
        )
        for attr_name, expected_value in expected_state:
            self.assertEqual(getattr(api_client, attr_name), expected_value)
예제 #3
0
    def test_get_job_success(self, mock_build_client, mock_get_gcp_credential,
                             mock_get_current_time):
        """get_job returns the stubbed API response.

        The credential helper mock is also required to have been invoked
        exactly once over the course of the test.
        """
        mock_get_current_time.return_value = datetime.date(2020, 10, 28)
        mock_build_client.return_value = _MockClient()

        api_client = client.AIPlatformClient(
            project_id='test-project', region='us-central1')
        job_response = api_client.get_job(job_id='test-job')

        self.assertEqual(_EXPECTED_GET_RESPONSE, job_response)
        mock_get_gcp_credential.assert_called_once()
예제 #4
0
    def test_create_run_from_pipeline_job(self, mock_load_json,
                                          mock_get_current_time):
        """A spec submitted without overrides is forwarded unchanged.

        The expected default job id embeds the mocked current date
        (2020-10-28).
        """
        mock_load_json.return_value = _load_test_data('pipeline_job.json')
        mock_get_current_time.return_value = datetime.date(2020, 10, 28)

        api_client = client.AIPlatformClient(
            project_id='test-project', region='us-central1')

        with mock.patch.object(api_client, '_submit_job',
                               autospec=True) as submit_stub:
            api_client.create_run_from_job_spec(
                job_spec_path='path/to/pipeline_job.json')

            submit_stub.assert_called_once_with(
                job_spec=_load_test_data('pipeline_job.json'),
                job_id='sample-test-pipeline-20201028000000')
예제 #5
0
    def test_disable_caching(self, mock_load_json, mock_get_current_time):
        """enable_caching=False is expected to turn every enableCache flag off.

        The expected spec is the golden spec with each ``enableCache: True``
        entry, at any nesting depth, switched to ``False``.
        """
        mock_load_json.return_value = _load_test_data('pipeline_job.json')
        mock_get_current_time.return_value = datetime.date(2020, 10, 28)

        def _disable_cache(node):
            """Recursively set every ``enableCache: True`` to False, in place.

            Walks the parsed structure instead of text-substituting on a
            ``json.dumps`` string, so the expectation does not depend on the
            serializer's separator/spacing format.
            """
            if isinstance(node, dict):
                for key, value in node.items():
                    # Re-assigning an existing key is safe during iteration.
                    if key == 'enableCache' and value is True:
                        node[key] = False
                    else:
                        _disable_cache(value)
            elif isinstance(node, list):
                for item in node:
                    _disable_cache(item)
            return node

        api_client = client.AIPlatformClient(
            project_id='test-project', region='us-central1')
        with mock.patch.object(
                api_client, '_submit_job', autospec=True) as mock_submit:
            api_client.create_run_from_job_spec(
                job_spec_path='path/to/pipeline_job.json', enable_caching=False)

            golden = _disable_cache(_load_test_data('pipeline_job.json'))
            mock_submit.assert_called_once_with(
                job_spec=golden, job_id='sample-test-pipeline-20201028000000')
예제 #6
0
    def test_advanced_settings(self, mock_load_json, mock_get_current_time):
        """cmek, service_account and network land as top-level spec fields.

        Expected placement: ``encryptionSpec.kmsKeyName``, ``serviceAccount``
        and ``network`` respectively; the rest of the spec is unchanged.
        """
        mock_load_json.return_value = _load_test_data('pipeline_job.json')
        mock_get_current_time.return_value = datetime.date(2020, 10, 28)

        api_client = client.AIPlatformClient(
            project_id='test-project', region='us-central1')

        with mock.patch.object(
                api_client, '_submit_job', autospec=True) as submit_stub:
            api_client.create_run_from_job_spec(
                job_spec_path='path/to/pipeline_job.json',
                cmek='custom-key',
                service_account='custom-sa',
                network='custom-network')

            expected_spec = _load_test_data('pipeline_job.json')
            expected_spec.update({
                'encryptionSpec': {'kmsKeyName': 'custom-key'},
                'serviceAccount': 'custom-sa',
                'network': 'custom-network',
            })
            submit_stub.assert_called_once_with(
                job_spec=expected_spec,
                job_id='sample-test-pipeline-20201028000000')