Example #1
0
    def test_request_dict_without_user_args(self, MockDiscovery):
        # When the entry-point args argument is None, the body passed to
        # jobs().create() is expected to have no trainingInput["args"] entry.
        self.setup(MockDiscovery)
        # The returned job name is not inspected by this test.
        _ = deploy.deploy_job(
            self.region,
            self.docker_img,
            self.chief_config,
            self.worker_count,
            self.worker_config,
            None,
            self.stream_logs,
        )
        build_ret_val = MockDiscovery.build.return_value
        proj_ret_val = build_ret_val.projects.return_value
        jobs_ret_val = proj_ret_val.jobs.return_value

        # No user args were supplied, so the expected body omits "args".
        del self.expected_request_dict["trainingInput"]["args"]

        # Verify job creation args
        _, kwargs = jobs_ret_val.create.call_args
        self.assertDictEqual(
            kwargs,
            {
                "parent": "projects/" + self.mock_project_name,
                "body": self.expected_request_dict,
            },
        )
Example #2
0
    def test_request_dict_without_workers(self):
        # With zero workers, the request body should report workerCount "0"
        # and carry neither a workerType nor a workerConfig entry.
        worker_count = 0

        _ = deploy.deploy_job(
            self.docker_img,
            self.chief_config,
            worker_count,
            None,
            self.entry_point_args,
            self.stream_logs,
            service_account=self.service_account,
        )
        jobs_api = (self._mock_discovery_build.return_value
                    .projects.return_value
                    .jobs.return_value)

        training_input = self.expected_request_dict["trainingInput"]
        training_input["workerCount"] = str(worker_count)
        for absent_key in ("workerType", "workerConfig"):
            del training_input[absent_key]

        # jobs().create() must receive the project parent and the adjusted
        # request body as keyword arguments.
        _, create_kwargs = jobs_api.create.call_args
        expected_kwargs = {
            "parent": "projects/" + self.mock_project_name,
            "body": self.expected_request_dict,
        }
        self.assertDictEqual(create_kwargs, expected_kwargs)
Example #3
0
    def test_deploy_job_with_default_service_account_has_no_serviceaccount_key(
            self):
        # If the user does not provide a service account (i.e.
        # service_account=None), the "serviceAccount" key should not be
        # included in the request dict, as AI Platform would treat None as the
        # name of the service account.
        _ = deploy.deploy_job(
            self.docker_img,
            self.chief_config,
            self.worker_count,
            self.worker_config,
            self.entry_point_args,
            self.stream_logs,
        )
        build_ret_val = self._mock_discovery_build.return_value
        proj_ret_val = build_ret_val.projects.return_value
        jobs_ret_val = proj_ret_val.jobs.return_value

        # The fixture's expected body carries a serviceAccount entry; drop it
        # to match a call made without service_account.
        del self.expected_request_dict["trainingInput"]["serviceAccount"]

        # Verify job creation args
        _, kwargs = jobs_ret_val.create.call_args
        self.assertDictEqual(
            kwargs,
            {
                "parent": "projects/" + self.mock_project_name,
                "body": self.expected_request_dict,
            },
        )
Example #4
0
    def test_logs_streaming_error(self, mock_subprocess_popen):
        # With log streaming enabled, an exception raised by the patched
        # Popen mock must surface from deploy_job as a ValueError.
        chief_config = machine_config.COMMON_MACHINE_CONFIGS["CPU"]
        worker_config = machine_config.COMMON_MACHINE_CONFIGS["TPU"]
        worker_count = 1

        # NOTE(review): mock_subprocess_popen is presumably subprocess.Popen
        # patched by a decorator outside this view; make every call fail.
        mock_subprocess_popen.side_effect = ValueError("error")
        # Use a local instead of overwriting the shared self.stream_logs
        # fixture attribute, so the override is scoped to this call only.
        stream_logs = True

        with self.assertRaises(ValueError):
            deploy.deploy_job(
                self.docker_img,
                chief_config,
                worker_count,
                worker_config,
                self.entry_point_args,
                stream_logs,
            )
Example #5
0
    def test_deploy_job_error(self):
        # An HttpError raised by the discovery client's projects() call must
        # surface from deploy_job as an errors.HttpError.
        discovery_client = self._mock_discovery_build.return_value
        discovery_client.projects.side_effect = errors.HttpError(
            mock.Mock(status=404), b"not found")

        chief_config = machine_config.COMMON_MACHINE_CONFIGS["CPU"]
        worker_config = machine_config.COMMON_MACHINE_CONFIGS["TPU"]
        worker_count = 1

        with self.assertRaises(errors.HttpError):
            deploy.deploy_job(
                self.docker_img,
                chief_config,
                worker_count,
                worker_config,
                self.entry_point_args,
                self.stream_logs,
            )
Example #6
0
    def test_deploy_job(self, mock_stdout):
        # End-to-end check of deploy_job against the mocked discovery client.
        # It covers the returned job id, the discovery build() arguments, the
        # projects().jobs().create() call chain and the printed message.
        job_name = deploy.deploy_job(
            self.docker_img,
            self.chief_config,
            self.worker_count,
            self.worker_config,
            self.entry_point_args,
            self.stream_logs,
        )

        self.assertEqual(job_name, self.mock_job_id)

        # The discovery client is built exactly once for the "ml" v1 API.
        build_mock = self._mock_discovery_build
        self.assertEqual(build_mock.call_count, 1)
        build_args, build_kwargs = build_mock.call_args
        self.assertListEqual(list(build_args), ["ml", "v1"])
        self.assertDictEqual(
            build_kwargs,
            {
                "cache_discovery": False,
                "requestBuilder": google_api_client.TFCloudHttpRequest,
            },
        )

        # Each link of projects().jobs().create() is invoked exactly once.
        projects_api = build_mock.return_value.projects
        self.assertEqual(projects_api.call_count, 1)

        jobs_api = projects_api.return_value.jobs
        self.assertEqual(jobs_api.call_count, 1)

        create_api = jobs_api.return_value.create
        self.assertEqual(create_api.call_count, 1)

        # create() receives the project parent and the expected request body.
        _, create_kwargs = create_api.call_args
        self.assertDictEqual(
            create_kwargs,
            {
                "parent": "projects/" + self.mock_project_name,
                "body": self.expected_request_dict,
            },
        )

        # The console output points the user at the job page and its logs.
        expected_stdout = (
            "\nJob submitted successfully."
            "\nYour job ID is:  {}\n"
            "\nPlease access your training job information here:\nhttps://"
            "console.cloud.google.com/mlengine/jobs/{}?project={}\n"
            "\nPlease access your training job logs here: "
            "https://console.cloud.google.com/logs/viewer?resource=ml_job%2F"
            "job_id%2F{}&interval=NO_LIMIT&project={}\n\n".format(
                self.mock_job_id, self.mock_job_id, self.mock_project_name,
                self.mock_job_id, self.mock_project_name)
        )
        self.assertEqual(mock_stdout.getvalue(), expected_stdout)
Example #7
0
    def test_deploy_job(self, mock_discovery, mock_stdout):
        # End-to-end check of deploy_job (region-first signature) against the
        # mocked discovery module. It covers the job id, the build() call, the
        # create() call chain and the printed message.
        self.setup(mock_discovery)

        job_name = deploy.deploy_job(
            self.region,
            self.docker_img,
            self.chief_config,
            self.worker_count,
            self.worker_config,
            self.entry_point_args,
            self.stream_logs,
        )

        self.assertEqual(job_name, self.mock_job_id)

        # discovery.build() is called once for the "ml" v1 API.
        build_fn = mock_discovery.build
        self.assertEqual(build_fn.call_count, 1)
        build_args, build_kwargs = build_fn.call_args
        self.assertListEqual(list(build_args), ["ml", "v1"])
        self.assertDictEqual(
            build_kwargs,
            {
                "cache_discovery": False,
                "requestBuilder": google_api_client.TFCloudHttpRequest,
            },
        )

        # Each link of projects().jobs().create() is invoked exactly once.
        projects_api = build_fn.return_value.projects
        self.assertEqual(projects_api.call_count, 1)
        jobs_api = projects_api.return_value.jobs
        self.assertEqual(jobs_api.call_count, 1)
        create_api = jobs_api.return_value.create
        self.assertEqual(create_api.call_count, 1)

        # create() receives the project parent and the expected request body.
        _, create_kwargs = create_api.call_args
        self.assertDictEqual(
            create_kwargs,
            {
                "parent": "projects/" + self.mock_project_name,
                "body": self.expected_request_dict,
            },
        )

        # The console output announces the job and links to its page.
        expected_stdout = (
            "Job submitted successfully.\nYour job ID is:  {}\nPlease access "
            "your job logs at the following URL:\nhttps://"
            "console.cloud.google.com/mlengine/jobs/{}?project={}\n".format(
                self.mock_job_id, self.mock_job_id, self.mock_project_name)
        )
        self.assertEqual(mock_stdout.getvalue(), expected_stdout)
Example #8
0
    def DISABLED_test_request_dict_with_TPU_worker(self, MockDiscovery):
        # TODO(psv): Fix broken test.
        # The DISABLED_ prefix keeps the default unittest loader (which only
        # collects methods starting with "test") from running this method.
        # Intent: a CPU chief plus one TPU worker should yield a request body
        # with a cloud_tpu worker and an accelerator-less master.
        self.setup(MockDiscovery)
        chief_config = machine_config.COMMON_MACHINE_CONFIGS["CPU"]
        worker_config = machine_config.COMMON_MACHINE_CONFIGS["TPU"]
        worker_count = 1

        # The returned job name is not inspected by this test.
        _ = deploy.deploy_job(
            self.region,
            self.docker_img,
            chief_config,
            worker_count,
            worker_config,
            self.entry_point_args,
            self.stream_logs,
        )
        build_ret_val = MockDiscovery.build.return_value
        proj_ret_val = build_ret_val.projects.return_value
        jobs_ret_val = proj_ret_val.jobs.return_value

        # Alias the nested dict; mutations below update
        # self.expected_request_dict in place.
        training_input = self.expected_request_dict["trainingInput"]
        training_input["workerCount"] = "1"
        training_input["workerType"] = "cloud_tpu"
        training_input["masterType"] = "n1-standard-4"
        training_input["workerConfig"]["acceleratorConfig"]["type"] = "TPU_V3"
        training_input["workerConfig"]["acceleratorConfig"]["count"] = "8"
        # NOTE(review): hard-coded TF version; the enabled sibling test derives
        # it from deploy.VERSION instead. This may be why the test broke.
        training_input["workerConfig"]["tpuTfVersion"] = "2.1"
        training_input["masterConfig"]["acceleratorConfig"]["type"] = (
            "ACCELERATOR_TYPE_UNSPECIFIED")
        training_input["masterConfig"]["acceleratorConfig"]["count"] = "0"

        # Verify job creation args
        _, kwargs = jobs_ret_val.create.call_args
        self.assertDictEqual(
            kwargs,
            {
                "parent": "projects/" + self.mock_project_name,
                "body": self.expected_request_dict,
            },
        )
Example #9
0
    def test_request_dict_with_tpu_worker(self):
        # A CPU chief plus one TPU worker should produce a body with a
        # cloud_tpu worker (TPU_V3 x8, TF version taken from deploy.VERSION)
        # and a master with no accelerator.
        chief_config = machine_config.COMMON_MACHINE_CONFIGS["CPU"]
        worker_config = machine_config.COMMON_MACHINE_CONFIGS["TPU"]
        worker_count = 1

        _ = deploy.deploy_job(
            self.region,
            self.docker_img,
            chief_config,
            worker_count,
            worker_config,
            self.entry_point_args,
            self.stream_logs,
        )
        jobs_api = (self._mock_discovery_build.return_value
                    .projects.return_value
                    .jobs.return_value)

        # Mutate the expected body in place through short aliases.
        training_input = self.expected_request_dict["trainingInput"]
        training_input["workerCount"] = "1"
        training_input["workerType"] = "cloud_tpu"
        training_input["masterType"] = "n1-standard-4"

        worker_cfg = training_input["workerConfig"]
        worker_cfg["acceleratorConfig"].update(type="TPU_V3", count="8")
        # Only the major.minor part of the package version is expected.
        version_parts = deploy.VERSION.split(".")
        worker_cfg["tpuTfVersion"] = "{}.{}".format(
            version_parts[0], version_parts[1])

        training_input["masterConfig"]["acceleratorConfig"].update(
            type="ACCELERATOR_TYPE_UNSPECIFIED", count="0")

        # jobs().create() must receive the project parent and adjusted body.
        _, create_kwargs = jobs_api.create.call_args
        expected_kwargs = {
            "parent": "projects/" + self.mock_project_name,
            "body": self.expected_request_dict,
        }
        self.assertDictEqual(create_kwargs, expected_kwargs)