def GetSuggestions(self, request, context):
    """Return new hyperparameter suggestions for the experiment in ``request``.

    On the first call, the experiment's search space is converted and a
    ``BaseSkoptService`` is constructed from the algorithm settings; later
    calls reuse ``self.base_service``.

    Args:
        request: Suggestion request. ``request.experiment``,
            ``request.trials`` and ``request.request_number`` are read.
        context: gRPC servicer context (not used by this method).

    Returns:
        ``api_pb2.GetSuggestionsReply`` whose ``parameter_assignments`` are
        generated from the trials returned by ``base_service.getSuggestions``.

    Raises:
        ValueError: if the configured algorithm name is not
            "bayesianoptimization".
    """
    # Use the snake_case converter, consistent with SkoptService's
    # convert_algorithm_spec / validate_algorithm_spec calls.
    algorithm_name, config = OptimizerConfiguration.convert_algorithm_spec(
        request.experiment.spec.algorithm)
    if algorithm_name != "bayesianoptimization":
        # ValueError subclasses Exception, so callers that catch Exception
        # keep working while getting a more specific error type.
        raise ValueError(
            "Failed to create the algorithm: {}".format(algorithm_name))
    if self.is_first_run:
        # Build BaseSkoptService once, from the first request's experiment.
        search_space = HyperParameterSearchSpace.convert(request.experiment)
        self.base_service = BaseSkoptService(
            base_estimator=config.base_estimator,
            n_initial_points=config.n_initial_points,
            acq_func=config.acq_func,
            acq_optimizer=config.acq_optimizer,
            random_state=config.random_state,
            search_space=search_space)
        self.is_first_run = False
    trials = Trial.convert(request.trials)
    new_trials = self.base_service.getSuggestions(
        trials, request.request_number)
    return api_pb2.GetSuggestionsReply(
        parameter_assignments=Assignment.generate(new_trials))
class SkoptService(api_pb2_grpc.SuggestionServicer, HealthServicer):
    """Suggestion servicer backed by a lazily constructed ``BaseSkoptService``.

    The backing service is created from the first ``GetSuggestions`` request
    and reused afterwards. Health checking comes from ``HealthServicer``.
    """

    def __init__(self):
        super().__init__()
        # Flips to False once base_service has been constructed.
        self.base_service = None
        self.is_first_run = True

    def GetSuggestions(self, request, context):
        """Return suggested parameter assignments for ``request``.

        The first call converts the experiment's search space and builds
        ``self.base_service`` from the algorithm settings. Subsequent calls
        reuse that service. The algorithm name returned by the converter is
        ignored here.
        """
        _, config = OptimizerConfiguration.convert_algorithm_spec(
            request.experiment.spec.algorithm)

        if self.is_first_run:
            space = HyperParameterSearchSpace.convert(request.experiment)
            self.base_service = BaseSkoptService(
                base_estimator=config.base_estimator,
                n_initial_points=config.n_initial_points,
                acq_func=config.acq_func,
                acq_optimizer=config.acq_optimizer,
                random_state=config.random_state,
                search_space=space,
            )
            self.is_first_run = False

        observed = Trial.convert(request.trials)
        suggested = self.base_service.getSuggestions(
            observed, request.request_number)
        assignments = Assignment.generate(suggested)
        return api_pb2.GetSuggestionsReply(parameter_assignments=assignments)

    def ValidateAlgorithmSettings(self, request, context):
        """Validate the experiment's algorithm settings.

        When validation fails, the RPC status is set to INVALID_ARGUMENT with
        the validator's message, which is also logged. An empty reply is
        returned in both cases.
        """
        ok, detail = OptimizerConfiguration.validate_algorithm_spec(
            request.experiment.spec.algorithm)
        if ok:
            return api_pb2.ValidateAlgorithmSettingsReply()

        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details(detail)
        logger.error(detail)
        return api_pb2.ValidateAlgorithmSettingsReply()