def testDoWithTuneArgs(self):
    """Checks num_parallel_trials is reflected in the AIP training input.

    With 3 parallel trials, the job is launched with one master plus two
    workers (workerCount == 2), mirroring the sibling override test.
    """
    executor = ai_platform_tuner_executor.Executor()
    # Use the project's proto_utils helper to serialize TuneArgs, consistent
    # with testDoWithTuneArgsAndTrainingInputOverride; it preserves proto
    # field names just like the previous json_format.MessageToJson call.
    self._exec_properties['tune_args'] = proto_utils.proto_to_json(
        tuner_pb2.TuneArgs(num_parallel_trials=3))

    executor.Do(self._inputs, self._outputs,
                self._serialize_custom_config_under_test())

    self.mock_runner.start_aip_training.assert_called_with(
        self._inputs, self._outputs,
        self._serialize_custom_config_under_test(),
        self._executor_class_path, {
            'project': self._project_id,
            'jobDir': self._job_dir,
            'scaleTier': 'CUSTOM',
            'masterType': 'standard',
            'workerType': 'standard',
            # 3 parallel trials -> master runs one, two extra workers.
            'workerCount': 2,
        }, mock.ANY)
def testDoWithTuneArgsAndTrainingInputOverride(self):
    """User-supplied machine config is kept; only workerCount is adjusted."""
    executor = ai_platform_tuner_executor.Executor()
    self._exec_properties['tune_args'] = proto_utils.proto_to_json(
        tuner_pb2.TuneArgs(num_parallel_trials=6))

    # User-provided training-input override in custom_config.
    user_training_input = {
        'scaleTier': 'CUSTOM',
        'masterType': 'n1-highmem-16',
        'workerType': 'n1-highmem-16',
        'workerCount': 2,
    }
    self._exec_properties['custom_config'][
        ai_platform_tuner_executor.TUNING_ARGS_KEY].update(user_training_input)

    executor.Do(self._inputs, self._outputs,
                self._serialize_custom_config_under_test())

    expected_training_input = {
        'project': self._project_id,
        'jobDir': self._job_dir,
        # Confirm scale tier and machine types are not overwritten.
        'scaleTier': 'CUSTOM',
        'masterType': 'n1-highmem-16',
        'workerType': 'n1-highmem-16',
        # Confirm workerCount has been adjusted to num_parallel_trials.
        'workerCount': 5,
    }
    self.mock_runner.start_cloud_training.assert_called_with(
        self._inputs, self._outputs,
        self._serialize_custom_config_under_test(),
        self._executor_class_path, expected_training_input, self._job_id,
        False, None)
def testDo(self):
    """Smoke test: Do() completes without raising on the default setup."""
    ai_platform_tuner_executor.Executor().Do(
        self._inputs, self._outputs,
        self._serialize_custom_config_under_test())
def testDoWithoutCustomCaipTuneArgs(self):
    """Do() raises ValueError when CAIP tuning args are absent."""
    self._exec_properties = {'custom_config': {}}
    executor = ai_platform_tuner_executor.Executor()
    with self.assertRaises(ValueError):
        executor.Do(self._inputs, self._outputs,
                    self._serialize_custom_config_under_test())