def test_request_dict_without_user_args(self, MockDiscovery):
    """Check the request body when no entry-point args are supplied.

    ``deploy.deploy_job`` is called with ``None`` in the entry-point-args
    position. The body passed to ``jobs().create()`` should equal
    ``self.expected_request_dict`` with ``trainingInput.args`` removed.
    """
    self.setup(MockDiscovery)
    # The returned job name is not inspected in this test.
    deploy.deploy_job(
        self.region,
        self.docker_img,
        self.chief_config,
        self.worker_count,
        self.worker_config,
        None,
        self.stream_logs,
    )

    build_ret_val = MockDiscovery.build.return_value
    proj_ret_val = build_ret_val.projects.return_value
    jobs_ret_val = proj_ret_val.jobs.return_value

    # Mutates the shared expected dict in place; presumably self.setup()
    # rebuilds it for every test -- confirm.
    del self.expected_request_dict["trainingInput"]["args"]

    # Fail with a clear assertion if create() was never called, instead of
    # a TypeError from unpacking a ``None`` call_args below.
    self.assertEqual(jobs_ret_val.create.call_count, 1)

    # Verify job creation args
    _, kwargs = jobs_ret_val.create.call_args
    self.assertDictEqual(
        kwargs,
        {
            "parent": "projects/" + self.mock_project_name,
            "body": self.expected_request_dict,
        },
    )
def test_deploy_job(self, MockDiscovery, MockStdOut):
    """Deploy a job and verify the returned id, API calls, and console output.

    Checks three things:

    * ``deploy.deploy_job`` returns ``self.mock_job_id``;
    * the discovery client is built once for ``("ml", "v1")``, and the
      ``projects().jobs().create()`` chain is invoked exactly once with the
      expected parent and request body;
    * the confirmation message captured by ``MockStdOut`` matches exactly.
    """
    self.setup(MockDiscovery)
    returned_id = deploy.deploy_job(
        self.region,
        self.docker_img,
        self.chief_config,
        self.worker_count,
        self.worker_config,
        self.entry_point_args,
        self.stream_logs,
    )
    self.assertEqual(returned_id, self.mock_job_id)

    # The discovery client must be built exactly once, for the "ml" v1 API.
    discovery_build = MockDiscovery.build
    self.assertEqual(discovery_build.call_count, 1)
    build_args = discovery_build.call_args[0]
    self.assertListEqual(list(build_args), ["ml", "v1"])

    # Each link of projects() -> jobs() -> create() is called exactly once.
    projects_fn = discovery_build.return_value.projects
    self.assertEqual(projects_fn.call_count, 1)
    jobs_fn = projects_fn.return_value.jobs
    self.assertEqual(jobs_fn.call_count, 1)
    create_fn = jobs_fn.return_value.create
    self.assertEqual(create_fn.call_count, 1)

    # create() receives the project parent and the full request body.
    create_kwargs = create_fn.call_args[1]
    expected_kwargs = {
        "parent": "projects/" + self.mock_project_name,
        "body": self.expected_request_dict,
    }
    self.assertDictEqual(create_kwargs, expected_kwargs)

    # The user-facing confirmation printed to stdout.
    message_template = (
        "Job submitted successfully.\nYour job ID is: {}\nPlease access "
        "your job logs at the following URL:\nhttps://"
        "console.cloud.google.com/mlengine/jobs/{}?project={}\n"
    )
    expected_output = message_template.format(
        self.mock_job_id, self.mock_job_id, self.mock_project_name
    )
    self.assertEqual(MockStdOut.getvalue(), expected_output)
def test_request_dict_with_TPU_worker(self, MockDiscovery):
    """Check the request body for a CPU chief plus a single TPU worker.

    The chief uses ``COMMON_MACHINE_CONFIGS["CPU"]`` and one worker uses
    ``COMMON_MACHINE_CONFIGS["TPU"]``. The expected body overrides:

    * worker/master machine types;
    * the worker accelerator (TPU_V3 x 8, tpuTfVersion "2.1");
    * the master accelerator (unspecified, count 0).
    """
    self.setup(MockDiscovery)
    chief_config = machine_config.COMMON_MACHINE_CONFIGS["CPU"]
    worker_config = machine_config.COMMON_MACHINE_CONFIGS["TPU"]
    worker_count = 1
    # The returned job name is not inspected in this test.
    deploy.deploy_job(
        self.region,
        self.docker_img,
        chief_config,
        worker_count,
        worker_config,
        self.entry_point_args,
        self.stream_logs,
    )

    build_ret_val = MockDiscovery.build.return_value
    proj_ret_val = build_ret_val.projects.return_value
    jobs_ret_val = proj_ret_val.jobs.return_value

    # Aliases into the expected body; the assignments below mutate it in
    # place. Presumably self.setup() rebuilds it per test -- confirm.
    training_input = self.expected_request_dict["trainingInput"]
    training_input["workerCount"] = str(worker_count)
    training_input["workerType"] = "cloud_tpu"
    training_input["masterType"] = "n1-standard-4"

    worker_accelerator = training_input["workerConfig"]["acceleratorConfig"]
    worker_accelerator["type"] = "TPU_V3"
    worker_accelerator["count"] = "8"
    training_input["workerConfig"]["tpuTfVersion"] = "2.1"

    # With a CPU chief, the master carries no accelerator.
    master_accelerator = training_input["masterConfig"]["acceleratorConfig"]
    master_accelerator["type"] = "ACCELERATOR_TYPE_UNSPECIFIED"
    master_accelerator["count"] = "0"

    # Fail with a clear assertion if create() was never called, instead of
    # a TypeError from unpacking a ``None`` call_args below.
    self.assertEqual(jobs_ret_val.create.call_count, 1)

    # Verify job creation args
    _, kwargs = jobs_ret_val.create.call_args
    self.assertDictEqual(
        kwargs,
        {
            "parent": "projects/" + self.mock_project_name,
            "body": self.expected_request_dict,
        },
    )
def test_request_dict_without_workers(self, MockDiscovery):
    """Check the request body when the job has zero workers.

    ``deploy.deploy_job`` is called with ``worker_count=0`` and no worker
    config. The expected body has ``workerCount`` set to ``"0"`` and both
    ``workerType`` and ``workerConfig`` removed.
    """
    self.setup(MockDiscovery)
    worker_count = 0
    # The returned job name is not inspected in this test.
    deploy.deploy_job(
        self.region,
        self.docker_img,
        self.chief_config,
        worker_count,
        None,
        self.entry_point_args,
        self.stream_logs,
    )

    build_ret_val = MockDiscovery.build.return_value
    proj_ret_val = build_ret_val.projects.return_value
    jobs_ret_val = proj_ret_val.jobs.return_value

    # Mutates the shared expected dict in place; presumably self.setup()
    # rebuilds it for every test -- confirm.
    training_input = self.expected_request_dict["trainingInput"]
    training_input["workerCount"] = str(worker_count)
    del training_input["workerType"]
    del training_input["workerConfig"]

    # Fail with a clear assertion if create() was never called, instead of
    # a TypeError from unpacking a ``None`` call_args below.
    self.assertEqual(jobs_ret_val.create.call_count, 1)

    # Verify job creation args
    _, kwargs = jobs_ret_val.create.call_args
    self.assertDictEqual(
        kwargs,
        {
            "parent": "projects/" + self.mock_project_name,
            "body": self.expected_request_dict,
        },
    )