def GetSuggestions(self, request, context):
    """Produce the next parameter assignments for an experiment.

    The first call converts the experiment's search space and builds a
    ``BaseChocolateService`` for the configured algorithm. Every later call
    reuses that same service instance.

    Args:
        request: GetSuggestions request; its ``experiment``, ``trials`` and
            ``request_number`` fields are read.
        context: gRPC servicer context (not referenced here).

    Returns:
        ``api_pb2.GetSuggestionsReply`` carrying the generated assignments.
    """
    if self.is_first_run:
        experiment = request.experiment
        space = HyperParameterSearchSpace.convert(experiment)
        algorithm = experiment.spec.algorithm.algorithm_name
        self.base_service = BaseChocolateService(
            algorithm_name=algorithm, search_space=space)
        self.is_first_run = False

    trial_list = Trial.convert(request.trials)
    suggested = self.base_service.getSuggestions(
        trial_list, request.request_number)
    reply_assignments = Assignment.generate(suggested)
    return api_pb2.GetSuggestionsReply(parameter_assignments=reply_assignments)
class ChocolateService(api_pb2_grpc.SuggestionServicer, HealthServicer):
    """gRPC suggestion servicer that delegates to a Chocolate-based service.

    The ``BaseChocolateService`` is created lazily from the first
    ``GetSuggestions`` request and reused for all later requests.
    """

    def __init__(self):
        super(ChocolateService, self).__init__()
        # Built on the first GetSuggestions call.
        self.base_service = None
        self.is_first_run = True

    def ValidateAlgorithmSettings(self, request, context):
        """Check algorithm settings before the experiment starts.

        Only the ``grid`` algorithm is validated: each DOUBLE parameter must
        carry a non-empty ``step``. A failing parameter marks ``context`` as
        INVALID_ARGUMENT via ``_set_validate_context_error``.

        Returns:
            An empty ``api_pb2.ValidateAlgorithmSettingsReply``.
        """
        if request.experiment.spec.algorithm.algorithm_name != "grid":
            return api_pb2.ValidateAlgorithmSettingsReply()

        space = HyperParameterSearchSpace.convert(request.experiment)
        for hp in space.params:
            if hp.type != DOUBLE:
                continue
            # Grid search cannot enumerate a continuous range without a step.
            if hp.step == "" or hp.step is None:
                message = "param {} step is nil".format(hp.name)
                return self._set_validate_context_error(context, message)
        return api_pb2.ValidateAlgorithmSettingsReply()

    def GetSuggestions(self, request, context):
        """Produce the next parameter assignments for an experiment.

        The first call converts the search space and builds the backing
        ``BaseChocolateService``; later calls reuse it. ``context`` is not
        referenced.

        Returns:
            ``api_pb2.GetSuggestionsReply`` carrying the generated assignments.
        """
        if self.is_first_run:
            experiment = request.experiment
            space = HyperParameterSearchSpace.convert(experiment)
            algorithm = experiment.spec.algorithm.algorithm_name
            self.base_service = BaseChocolateService(
                algorithm_name=algorithm, search_space=space)
            self.is_first_run = False

        trial_list = Trial.convert(request.trials)
        suggested = self.base_service.getSuggestions(
            trial_list, request.request_number)
        reply_assignments = Assignment.generate(suggested)
        return api_pb2.GetSuggestionsReply(
            parameter_assignments=reply_assignments)

    def _set_validate_context_error(self, context, error_message):
        """Mark ``context`` INVALID_ARGUMENT with ``error_message``, log the
        message, and return an empty validation reply."""
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details(error_message)
        logger.info(error_message)
        return api_pb2.ValidateAlgorithmSettingsReply()
class ChocolateService(api_pb2_grpc.SuggestionServicer, HealthServicer):
    """gRPC suggestion servicer that delegates to a Chocolate-based service.

    The ``BaseChocolateService`` is created lazily from the first
    ``GetSuggestions`` request and reused for all later requests.
    """

    def __init__(self):
        super(ChocolateService, self).__init__()
        # Built on the first GetSuggestions call.
        self.base_service = None
        self.is_first_run = True

    def ValidateAlgorithmSettings(self, request, context):
        """Validate the experiment's algorithm settings.

        Only the ``grid`` algorithm is checked:

        * every INTEGER and DOUBLE parameter must have a non-empty, positive
          ``step``;
        * ``max_trial_count`` must not exceed the number of grid points, i.e.
          the product of the number of candidate values of each parameter.

        Args:
            request: validation request; ``request.experiment`` is read.
            context: gRPC servicer context; marked INVALID_ARGUMENT on failure.

        Returns:
            An empty ``api_pb2.ValidateAlgorithmSettingsReply`` on every path;
            failures are reported through ``context``.
        """
        algorithm_name = request.experiment.spec.algorithm.algorithm_name
        if algorithm_name != "grid":
            return api_pb2.ValidateAlgorithmSettingsReply()

        search_space = HyperParameterSearchSpace.convert(request.experiment)
        # Candidate values per parameter name.
        available_space = {}
        for param in search_space.params:
            if param.type in (INTEGER, DOUBLE):
                # Without a step a numeric range cannot be enumerated; for
                # INTEGER this previously crashed in int("") instead.
                if param.step == "" or param.step is None:
                    return self._set_validate_context_error(
                        context, "Param: {} step is nil".format(param.name))

            if param.type == INTEGER:
                step = int(param.step)
                # range() raises ValueError on a zero step.
                if step <= 0:
                    return self._set_validate_context_error(
                        context, "Param: {} step must be positive, got {}"
                        .format(param.name, param.step))
                available_space[param.name] = range(
                    int(param.min), int(param.max) + 1, step)

            elif param.type == DOUBLE:
                step = float(param.step)
                # np.arange fails on a zero step and yields nothing for a
                # negative one.
                if step <= 0:
                    return self._set_validate_context_error(
                        context, "Param: {} step must be positive, got {}"
                        .format(param.name, param.step))
                max_value = float(param.max)
                double_list = np.arange(
                    float(param.min), max_value + step, step)
                # Stopping at max + step can overshoot max by one element;
                # drop it. The length guard covers min > max, where the
                # array is empty and [-1] would raise IndexError.
                if len(double_list) > 0 and double_list[-1] > max_value:
                    double_list = double_list[:-1]
                available_space[param.name] = double_list

            elif param.type in (CATEGORICAL, DISCRETE):
                available_space[param.name] = param.list

        # The size of a Cartesian product is the product of the axis sizes.
        # Counting it this way avoids materialising every combination, which
        # grows exponentially with the number of parameters. An empty
        # space gives 1, matching itertools.product() of no iterables.
        num_combinations = 1
        for values in available_space.values():
            num_combinations *= len(values)

        max_trial_count = request.experiment.spec.max_trial_count
        if max_trial_count > num_combinations:
            return self._set_validate_context_error(
                context,
                "Max Trial Count: {} > all possible search space combinations: {}"
                .format(max_trial_count, num_combinations))

        return api_pb2.ValidateAlgorithmSettingsReply()

    def GetSuggestions(self, request, context):
        """Produce the next parameter assignments for an experiment.

        The first call converts the search space and builds the backing
        ``BaseChocolateService``; later calls reuse it. ``context`` is not
        referenced.

        Args:
            request: GetSuggestions request; ``experiment``, ``trials``,
                ``current_request_number`` and ``total_request_number`` are
                read.
            context: gRPC servicer context (unused).

        Returns:
            ``api_pb2.GetSuggestionsReply`` carrying the generated assignments.
        """
        if self.is_first_run:
            search_space = HyperParameterSearchSpace.convert(
                request.experiment)
            self.base_service = BaseChocolateService(
                algorithm_name=request.experiment.spec.algorithm.algorithm_name,
                search_space=search_space)
            self.is_first_run = False

        trials = Trial.convert(request.trials)
        new_assignments = self.base_service.getSuggestions(
            trials, request.current_request_number,
            request.total_request_number)
        return api_pb2.GetSuggestionsReply(
            parameter_assignments=Assignment.generate(new_assignments))

    def _set_validate_context_error(self, context, error_message):
        """Mark ``context`` INVALID_ARGUMENT with ``error_message``, log the
        message, and return an empty validation reply."""
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details(error_message)
        logger.info(error_message)
        return api_pb2.ValidateAlgorithmSettingsReply()