def __init__(self,
             min_records,
             trials,
             num_halving_rounds=float('inf'),
             eta=0,
             bracket_id=None,
             goal_maximize=False,
             pending_trials=None,
             fix_quantiles=False):
     """Create a bracket and populate its ladder of rungs from ``trials``.

     Args:
       min_records: record (step) budget of rung 0; rung ``r`` corresponds to
         ``min_records * eta**r`` records.
       trials: trials used to populate the ladder. Only trials with
         measurements are placed, and only completed, feasible ones remain in
         the ranked rungs.
       num_halving_rounds: number of promotions in this bracket. The maximum
         budget is ``min_records * eta**num_halving_rounds``; once the ladder
         has more than ``num_halving_rounds`` rungs, trials in the top rung
         are not promoted further.
       eta: promotion factor. The top ``1/eta`` of a rung may be promoted, and
         each promotion multiplies the record budget by ``eta``.
       bracket_id: id stamped on fresh trials created by this bracket.
       goal_maximize: True if larger objective values are better.
       pending_trials: trials that are already requested. Their
         ``(hyperband_id, termination_record)`` pairs are not handed out
         again. ``None`` (the default) means there are no pending trials.
       fix_quantiles: if True, rung ``r`` keeps only trials that were in the
         top ``1/eta`` of rung ``r - 1``.
     """
     self._min_records = min_records
     self._max_records = min_records * eta**num_halving_rounds
     self._num_halving_rounds = num_halving_rounds
     self._eta = eta
     self._bracket_id = bracket_id
     self._goal_maximize = goal_maximize
     # Sum of records of populated trials plus records handed out later.
     self._total_records = 0
     self._ladder = [[]]
     self._fix_quantiles = fix_quantiles
     if pending_trials is None:
         # The default used to crash on iteration; treat it as "no pending".
         pending_trials = ()
     # Keys are (hyperband_id, termination_record) of requested evaluations;
     # membership means that evaluation must not be requested again.
     self._pending_trials = {
         (_get_metadata(t, 'hyperband_id'),
          _get_metadata(t, 'termination_record')): True
         for t in pending_trials
     }
     self._populate_bracket(trials)
 def GetNewSuggestions(self, num_suggestions_hint, completed_trials,
                       pending_trials):
     """Suggest ``num_suggestions_hint`` trials spread over all brackets.

     For every suggestion, each bracket's budget is computed as
     ``floor(num_fully_trained * (num_halving_rounds + 1) /
     (max_halving_rounds + 1))``. The bracket with the smallest budget is
     asked for a trial; on a tie the bracket with the larger index wins.
     Brand-new trials (``hyperband_id`` < 0) get their parameters from the
     delegate policy and the next unused hyperband id.

     Args:
       num_suggestions_hint: number of trials to return.
       completed_trials: finished trials, used to rebuild the brackets.
       pending_trials: requested-but-unfinished trials, used likewise.

     Returns:
       A list of ``num_suggestions_hint`` trials.
     """
     # Brackets are rebuilt from scratch on every call.
     brackets, next_hyperband_id = self._construct_brackets(
         completed_trials, pending_trials)
     trials = []
     for _ in range(num_suggestions_hint):
         bracket_budget = [
             (idx,
              numpy.floor(bracket.num_fully_trained *
                          (bracket.num_halving_rounds + 1) /
                          (self._max_halving_rounds + 1)))
             for idx, bracket in enumerate(brackets)
         ]
         # Smallest budget first, larger index on ties. A tuple key replaces
         # the old ``budget * 10 - idx`` encoding, which mis-ordered brackets
         # once their count exceeded 10. The stray debug print is gone.
         bracket_id, _ = min(bracket_budget,
                             key=lambda item: (item[1], -item[0]))
         trial = brackets[bracket_id].get_trial()
         if _get_metadata(trial, 'hyperband_id') < 0:
             # New trial: take parameters from the delegate policy and
             # assign a fresh hyperband id.
             parameters = self._delegate_policy.GetNewSuggestions(
                 1, completed_trials, pending_trials)[0].parameters
             trial.parameters = parameters
             _set_metadata(trial, 'hyperband_id', next_hyperband_id)
             next_hyperband_id += 1
         trials.append(trial)
     return trials
    def bracket_string(self):
        """Render this bracket's ladder as a human-readable table.

        A header is followed by one line per rung with four fields:
          * the rung's nominal budget ``round(min_records * eta**rung)``,
            followed by ``/`` and the mean ``steps`` of the last measurement of
            its trials (``nan`` if none);
          * the number of trials with measurements;
          * their best objective values, best first, with two decimals.
        If the top rung is non-empty, the ``hyperband_id`` of its first trial
        is appended as the best trial.
        """

        def best_of(trial):
            # Negate under maximisation so an ascending sort puts the best
            # value first; the sign is undone when formatting.
            values = (m.objective_value for m in trial.measurements)
            return -max(values) if self._goal_maximize else min(values)

        lines = ['r_i\tn_i\tbest objective_value\n---------------------\n']
        for depth, rung in enumerate(self._ladder):
            measured = [t for t in rung if t.measurements]
            ordered = sorted(best_of(t) for t in measured)
            labels = [
                '%.2f' % (-v if self._goal_maximize else v) for v in ordered
            ]
            last_steps = [t.measurements[-1].steps for t in measured]
            mean_steps = (numpy.mean(last_steps)
                          if last_steps else float('NaN'))
            nominal = round(self._min_records * self._eta**depth)
            lines.append('%.0f/%.0f, %d, %s\n' %
                         (nominal, mean_steps, len(labels), str(labels)))
        top_rung = self._ladder[-1]
        if top_rung:
            lines.append('best trial is: %d\n' %
                         (_get_metadata(top_rung[0], 'hyperband_id')))
        return ''.join(lines)
    def get_trial(self):
        """Return the next trial this bracket wants evaluated.

        Rungs are scanned from the top of the ladder down. In the first rung
        holding at least ``eta`` trials, the best trial among its top
        ``int(len(rung) / eta)`` is promoted if two conditions hold: it is not
        already present in a higher rung, and its promotion is not already
        pending. A promoted trial keeps its parameters and measurements, and
        its ``termination_record`` is multiplied by ``eta``.

        If nothing can be promoted, a blank trial is returned instead. It has
        ``hyperband_id`` -1, this bracket's ``bracket_id``, and
        ``termination_record`` equal to ``min_records``. The caller fills in
        its parameters and id.

        Side effects: the returned ``termination_record`` is added to
        ``self._total_records``. A promotion is recorded in
        ``self._pending_trials`` so that later calls on this bracket do not
        hand out the same promotion again.

        Returns:
          A new ``Trial.Trial`` with status ``Trial.REQUESTED``.
        """
        trial_to_advance = None
        # hyperband_ids of trials that already appear in a rung above the one
        # being scanned, i.e. trials that were promoted before.
        trials_already_advanced = set()
        if len(self._ladder) > self._num_halving_rounds:
            # The ladder reached its final rung; trials there cannot move up.
            trials_already_advanced.update(
                _get_metadata(trial, 'hyperband_id')
                for trial in self._ladder[-1])
        for rung in reversed(range(len(self._ladder))):
            rung_trials = self._ladder[rung]
            if len(rung_trials) >= self._eta:
                # Each rung is already sorted best-first when the bracket is
                # populated, so the promotable quantile is a prefix.
                quantile = rung_trials[0:int(len(rung_trials) / self._eta)]
                for candidate in quantile:
                    hyperband_id = _get_metadata(candidate, 'hyperband_id')
                    if hyperband_id in trials_already_advanced:
                        continue
                    promotion_key = (
                        hyperband_id,
                        _get_metadata(candidate, 'termination_record') *
                        self._eta)
                    if promotion_key not in self._pending_trials:
                        trial_to_advance = candidate
                        break
            if trial_to_advance is not None:
                break
            # Everything in this rung was promoted from the rung below it.
            trials_already_advanced.update(
                _get_metadata(trial, 'hyperband_id') for trial in rung_trials)

        # Advance the chosen trial, or start a new one at the base budget.
        if trial_to_advance is not None:
            parameters = trial_to_advance.parameters
            measurements = trial_to_advance.measurements
            bracket_id = _get_metadata(trial_to_advance, 'bracket_id')
            hyperband_id = _get_metadata(trial_to_advance, 'hyperband_id')
            termination_record = self._eta * _get_metadata(
                trial_to_advance, 'termination_record')
            # Mark the promotion as pending; otherwise a second call in the
            # same suggestion round would return this very promotion again.
            self._pending_trials[(hyperband_id, termination_record)] = True
        else:
            parameters = []
            measurements = []
            bracket_id = self._bracket_id
            hyperband_id = -1
            termination_record = self._min_records
        trial = Trial.Trial()
        trial.status = Trial.REQUESTED
        trial.parameters = parameters
        trial.measurements.extend(measurements)
        _set_metadata(trial, 'bracket_id', bracket_id)
        _set_metadata(trial, 'hyperband_id', hyperband_id)
        _set_metadata(trial, 'termination_record', termination_record)
        self._total_records += termination_record
        return trial
 def _construct_brackets(self, completed_trials, pending_trials):
     """Rebuild one bracket per halving depth from the known trials.

     Every completed or pending trial is routed to the bracket named by its
     ``bracket_id`` metadata. Bracket ``s`` starts at
     ``min_records * eta**(max_halving_rounds - s)`` records, uses ``s`` as
     its halving depth and id, and receives the pending trials whose
     ``bracket_id`` equals ``s``.

     Returns:
       A ``(brackets, next_hyperband_id)`` tuple, where
       ``next_hyperband_id`` is one past the largest ``hyperband_id`` seen
       (and at least 0).
     """
     trials_by_bracket = [[] for _ in range(self._max_halving_rounds + 1)]
     highest_id = -1
     for trial in set(completed_trials).union(set(pending_trials)):
         trials_by_bracket[_get_metadata(trial, 'bracket_id')].append(trial)
         highest_id = max(highest_id, _get_metadata(trial, 'hyperband_id'))

     brackets = []
     for s, bracket_trials in enumerate(trials_by_bracket):
         start_records = (self._min_records *
                          self._eta**(self._max_halving_rounds - s))
         pending_here = [
             t for t in pending_trials if _get_metadata(t, 'bracket_id') == s
         ]
         brackets.append(
             Bracket(start_records, bracket_trials, s, self._eta, s,
                     self._goal_maximize, pending_here))
     return brackets, highest_id + 1
    def _construct_brackets(self, completed_trials, pending_trials):
        """Rebuild the single bracket from the known trials.

        Returns:
          A ``(bracket, next_hyperband_id)`` tuple. The bracket is built from
          the union of completed and pending trials. ``self._halving_rounds``
          is passed in both the halving-depth and bracket-id positions of the
          ``Bracket`` constructor. ``next_hyperband_id`` is one past the
          largest ``hyperband_id`` seen, or 0 when there are no trials.
        """
        known = set(completed_trials) | set(pending_trials)
        seen_ids = [_get_metadata(t, 'hyperband_id') for t in known]
        next_hyperband_id = (max(seen_ids) if seen_ids else -1) + 1
        depth = self._halving_rounds
        bracket = Bracket(self._min_records, known, depth, self._eta, depth,
                          self._goal_maximize, pending_trials)
        return bracket, next_hyperband_id
 def _populate_bracket(self, trials):
     """Place ``trials`` on the ladder and rank every rung.

     Each trial with measurements is placed by ``records``, the ``steps`` of
     its last measurement:
       * ``records`` is added to ``self._total_records``;
       * the trial goes to rung ``round(log(records / min_records) /
         log(eta))``;
       * trials at or below ``min_records`` go to rung 0.

     Each rung is then reduced to its completed, feasible trials and sorted
     best-first by each trial's best objective value. With
     ``fix_quantiles``, rung ``r`` additionally keeps only the trials whose
     ``hyperband_id`` was in the top ``int(len / eta)`` of rung ``r - 1``.
     """
     for trial in trials:
         if not trial.measurements:
             continue
         records = trial.measurements[-1].steps
         self._total_records += records
         if records <= self._min_records:
             # Below the base budget the log ratio is negative (or undefined
             # for 0 records). A negative index would silently file the trial
             # in one of the top rungs, so clamp it to rung 0.
             rung = 0
         else:
             rung = int(
                 round(log(records / self._min_records) / log(self._eta)))
         while len(self._ladder) <= rung:
             self._ladder.append([])
         self._ladder[rung].append(trial)

     for rung in range(len(self._ladder)):
         ranked = [
             trial for trial in self._ladder[rung]
             if trial.status == Trial.COMPLETE and not trial.trial_infeasible
         ]
         if self._goal_maximize:
             ranked.sort(key=lambda x: -max(m.objective_value
                                            for m in x.measurements))
         else:
             ranked.sort(key=lambda x: min(m.objective_value
                                           for m in x.measurements))
         if self._fix_quantiles and rung >= 1:
             # The rung below was ranked (and filtered) on the previous pass.
             below = self._ladder[rung - 1]
             ids_to_advance = {
                 _get_metadata(trial, 'hyperband_id')
                 for trial in below[0:int(len(below) / self._eta)]
             }
             ranked = [
                 trial for trial in ranked
                 if _get_metadata(trial, 'hyperband_id') in ids_to_advance
             ]
         self._ladder[rung] = ranked
 def GetNewSuggestions(self, num_suggestions_hint, completed_trials,
                       pending_trials):
     """Suggest ``num_suggestions_hint`` trials from the single bracket.

     The bracket is rebuilt from ``completed_trials`` and ``pending_trials``
     on every call, then asked once per suggestion. A trial it returns with a
     negative ``hyperband_id`` is brand new: its parameters come from the
     delegate policy, and it gets the next unused hyperband id.

     Returns:
       A list of ``num_suggestions_hint`` trials.
     """
     bracket, next_hyperband_id = self._construct_brackets(
         completed_trials, pending_trials)
     suggestions = []
     for _ in range(num_suggestions_hint):
         suggestion = bracket.get_trial()
         is_new = _get_metadata(suggestion, 'hyperband_id') < 0
         if is_new:
             delegate_trial = self._delegate_policy.GetNewSuggestions(
                 1, completed_trials, pending_trials)[0]
             suggestion.parameters = delegate_trial.parameters
             _set_metadata(suggestion, 'hyperband_id', next_hyperband_id)
             next_hyperband_id += 1
         suggestions.append(suggestion)
     return suggestions