class TPESampler(Sampler):
    """Sampler whose decisions come from a ``HyperoptTuner`` created with ``'tpe'``.

    Usage visible in this class: ``update_sample_space`` registers the
    candidate lists with the tuner, ``generate_samples`` asks the tuner for
    one set of parameters per model, ``choice`` hands those parameters out one
    at a time, and ``receive_result`` reports the outcome back to the tuner.
    """

    def __init__(self, optimize_mode='minimize'):
        """Create the backing tuner and reset the sampling state.

        Parameters
        ----------
        optimize_mode
            Forwarded unchanged as the second argument of ``HyperoptTuner``.
        """
        self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
        # Parameters returned by the tuner for each model id seen so far.
        self.total_parameters = {}
        # Sample currently consumed by ``choice`` and the read position in it;
        # both stay ``None`` until ``generate_samples`` is called.
        self.cur_sample = None
        self.index = None

    def update_sample_space(self, sample_space):
        """Register ``sample_space`` with the tuner as a set of ``choice`` entries.

        Each element of ``sample_space`` becomes one entry keyed by its
        position (as a string), with the element itself stored under
        ``'_value'``.
        """
        space = {
            str(position): {'_type': 'choice', '_value': options}
            for position, options in enumerate(sample_space)
        }
        self.tpe_tuner.update_search_space(space)

    def generate_samples(self, model_id):
        """Fetch parameters for ``model_id`` from the tuner and make them current.

        The parameters are also recorded under ``model_id`` so that
        ``receive_result`` can pass them back to the tuner later, and the read
        position used by ``choice`` is reset to zero.
        """
        parameters = self.tpe_tuner.generate_parameters(model_id)
        self.cur_sample = parameters
        self.total_parameters[model_id] = parameters
        self.index = 0

    def receive_result(self, model_id, result):
        """Report ``result`` for ``model_id`` to the tuner.

        Raises ``KeyError`` if no parameters were recorded for ``model_id``.
        """
        parameters = self.total_parameters[model_id]
        self.tpe_tuner.receive_trial_result(model_id, parameters, result)

    def choice(self, candidates, mutator, model, index):
        """Return the next value of the current sample and advance the cursor.

        ``candidates``, ``mutator``, ``model`` and ``index`` are accepted but
        not read here; the lookup key is the internal ``self.index`` counter,
        not the ``index`` argument.
        """
        key = str(self.index)
        picked = self.cur_sample[key]
        # Advance only after a successful lookup.
        self.index += 1
        return picked
class TPESampler(Sampler):
    """Sampler whose decisions come from a ``HyperoptTuner`` created with ``'tpe'``.

    Usage visible in this class: ``update_sample_space`` registers the
    candidate lists with the tuner, ``generate_samples`` asks the tuner for
    one set of parameters per model, ``choice`` hands those parameters out one
    at a time, and ``receive_result`` reports the outcome back to the tuner.
    """

    def __init__(self, optimize_mode='minimize'):
        """Create the backing tuner and reset the sampling state.

        Parameters
        ----------
        optimize_mode
            Forwarded unchanged as the second argument of ``HyperoptTuner``.
        """
        # Imported inside the constructor rather than at module level to
        # eliminate some warning messages about dill.
        from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner

        self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
        # Parameters returned by the tuner for each model id seen so far.
        self.total_parameters = {}
        # Sample currently consumed by ``choice`` and the read position in it;
        # both stay ``None`` until ``generate_samples`` is called.
        self.cur_sample: Optional[dict] = None
        self.index: Optional[int] = None

    def update_sample_space(self, sample_space):
        """Register ``sample_space`` with the tuner as a set of ``choice`` entries.

        Each element of ``sample_space`` becomes one entry keyed by its
        position (as a string), with the element itself stored under
        ``'_value'``.
        """
        space = {
            str(position): {'_type': 'choice', '_value': options}
            for position, options in enumerate(sample_space)
        }
        self.tpe_tuner.update_search_space(space)

    def generate_samples(self, model_id):
        """Fetch parameters for ``model_id`` from the tuner and make them current.

        The parameters are also recorded under ``model_id`` so that
        ``receive_result`` can pass them back to the tuner later, and the read
        position used by ``choice`` is reset to zero.
        """
        parameters = self.tpe_tuner.generate_parameters(model_id)
        self.cur_sample = parameters
        self.total_parameters[model_id] = parameters
        self.index = 0

    def receive_result(self, model_id, result):
        """Report ``result`` for ``model_id`` to the tuner.

        Raises ``KeyError`` if no parameters were recorded for ``model_id``.
        """
        parameters = self.total_parameters[model_id]
        self.tpe_tuner.receive_trial_result(model_id, parameters, result)

    def choice(self, candidates, mutator, model, index):
        """Return the next value of the current sample and advance the cursor.

        ``candidates``, ``mutator``, ``model`` and ``index`` are accepted but
        not read here; the lookup key is the internal ``self.index`` counter,
        not the ``index`` argument.
        """
        # Narrows the Optional attributes; fails if no sample has been
        # generated yet (both attributes are still ``None``).
        assert isinstance(self.index, int) and isinstance(self.cur_sample, dict)
        key = str(self.index)
        picked = self.cur_sample[key]
        # Advance only after a successful lookup.
        self.index += 1
        return picked