Esempio n. 1
0
    def testDoWithTuneArgs(self):
        """Asserts the training input sent when tune_args requests 3 trials.

        With num_parallel_trials=3, the runner is expected to receive a CUSTOM
        scale tier with 'standard' master/worker machines and workerCount=2.
        """
        tuner = ai_platform_tuner_executor.Executor()
        tune_args_json = json_format.MessageToJson(
            message=tuner_pb2.TuneArgs(num_parallel_trials=3),
            preserving_proto_field_name=True)
        self._exec_properties['tune_args'] = tune_args_json

        tuner.Do(self._inputs, self._outputs,
                 self._serialize_custom_config_under_test())

        # Training input the executor is expected to forward to the runner.
        expected_training_input = {
            'project': self._project_id,
            'jobDir': self._job_dir,
            'scaleTier': 'CUSTOM',
            'masterType': 'standard',
            'workerType': 'standard',
            'workerCount': 2,
        }
        self.mock_runner.start_aip_training.assert_called_with(
            self._inputs,
            self._outputs,
            self._serialize_custom_config_under_test(),
            self._executor_class_path,
            expected_training_input,
            mock.ANY)
Esempio n. 2
0
    def testDoWithTuneArgsAndTrainingInputOverride(self):
        """Asserts user training-input overrides interact correctly with tuning.

        The custom config pre-sets a CUSTOM scale tier, n1-highmem-16 machines
        and two workers, while tune_args asks for 6 parallel trials. The scale
        tier and machine types must reach the runner unchanged; workerCount is
        expected to be 5.
        """
        tuner = ai_platform_tuner_executor.Executor()
        self._exec_properties['tune_args'] = proto_utils.proto_to_json(
            tuner_pb2.TuneArgs(num_parallel_trials=6))

        # User-specified training input merged into the tuning args section.
        training_overrides = {
            'scaleTier': 'CUSTOM',
            'masterType': 'n1-highmem-16',
            'workerType': 'n1-highmem-16',
            'workerCount': 2,
        }
        tuning_args = self._exec_properties['custom_config'][
            ai_platform_tuner_executor.TUNING_ARGS_KEY]
        tuning_args.update(training_overrides)

        tuner.Do(self._inputs, self._outputs,
                 self._serialize_custom_config_under_test())

        expected_training_input = {
            'project': self._project_id,
            'jobDir': self._job_dir,
            # Scale tier and machine types must not be overwritten.
            'scaleTier': 'CUSTOM',
            'masterType': 'n1-highmem-16',
            'workerType': 'n1-highmem-16',
            # workerCount is recomputed from num_parallel_trials (6 trials ->
            # 5 workers) instead of keeping the override of 2; presumably one
            # trial runs on the master — confirm against the executor.
            'workerCount': 5,
        }
        self.mock_runner.start_cloud_training.assert_called_with(
            self._inputs,
            self._outputs,
            self._serialize_custom_config_under_test(),
            self._executor_class_path,
            expected_training_input,
            self._job_id,
            False,
            None)
Esempio n. 3
0
 def testDo(self):
     """Smoke test: Do() must complete without raising.

     No assertions are made; the inputs/outputs/config are whatever the test
     case prepared beforehand (presumably in setUp — not visible here).
     """
     tuner = ai_platform_tuner_executor.Executor()
     serialized_config = self._serialize_custom_config_under_test()
     tuner.Do(self._inputs, self._outputs, serialized_config)
Esempio n. 4
0
 def testDoWithoutCustomCaipTuneArgs(self):
     """Do() is expected to raise ValueError when custom_config is empty."""
     tuner = ai_platform_tuner_executor.Executor()
     # Replace all exec properties so custom_config carries no tuning args.
     self._exec_properties = {'custom_config': {}}
     with self.assertRaises(ValueError):
         # Serialization stays inside the context so any ValueError it raises
         # is treated the same way as one raised by Do().
         serialized_config = self._serialize_custom_config_under_test()
         tuner.Do(self._inputs, self._outputs, serialized_config)