def fit_predict_ours(data: dict, random_seed: int, optimization_config: OptimizationConfig, test_intermediates: Optional[dict] = None) -> dict: # Create surrogate model num_dims = len(data['ss_limits']) _gpmodel = GaussianProcessRegression( kernel=Matern52(num_dims, ARD=True), mean=ZeroMeanFunction(), # Instead of ScalarMeanFunction optimization_config=optimization_config, random_seed=random_seed, test_intermediates=test_intermediates) model = GPMXNetModel(data['state'], DEFAULT_METRIC, random_seed, _gpmodel, fit_parameters=True, num_fantasy_samples=20) model_params = model.get_params() print('Hyperparameters: {}'.format(model_params)) # Prediction means, stddevs = model.predict(data['test_inputs'])[0] return {'means': means, 'stddevs': stddevs}
class GPMXNetPendingCandidateStateTransformer(PendingCandidateStateTransformer): """ This class maintains the TuningJobState along an asynchronous GP-based HPO experiment, and manages the reaction to changes of this state. In particular, it provides a GPMXNetModel on demand, which encapsulates the GP posterior. Note: The GPMXNetModel can be accessed only once the state has at least one labeled case, since otherwise no posterior can be computed. skip_optimization is a predicate depending on TuningJobState, determining what is done at the next recent GPMXNetModel computation. If False, the GP hyperparameters are optimized. Otherwise, the current ones are not changed. Safeguard against multiple GP hyperparameter optimization while labeled data does not change: The posterior has to be recomputed every time the state changes, even if this only concerns pending evaluations. The expensive part of this is refitting the GP hyperparameters, which makes sense only when the labeled data in the state changes. We put a safeguard in place to avoid refitting when the labeled data is unchanged. """ def __init__( self, gpmodel: GPModel, init_state: TuningJobState, model_args: GPMXNetModelArgs, skip_optimization: SkipOptimizationPredicate = None, profiler: GPMXNetSimpleProfiler = None, debug_log: Optional[DebugLogPrinter] = None): self._gpmodel = gpmodel self._state = copy.copy(init_state) self._model_args = model_args if skip_optimization is None: self.skip_optimization = NeverSkipPredicate() else: self.skip_optimization = skip_optimization self._profiler = profiler self._debug_log = debug_log # GPMXNetModel computed on demand self._model: GPMXNetModel = None self._candidate_evaluations = None # _model_params is returned by get_params. Careful: This is not just # self._gpmodel.get_params(), since the current GPMXNetModel may append # additional parameters self._model_params = gpmodel.get_params() @property def state(self) -> TuningJobState: return self._state def model(self, **kwargs) -> GPMXNetModel: """ If skip_optimization is given, it overrides the self.skip_optimization predicate. :return: GPMXNetModel for current state """ if self._model is None: skip_optimization = kwargs.get('skip_optimization') self._compute_model(skip_optimization=skip_optimization) return self._model def get_params(self): return self._model_params def set_params(self, param_dict): self._gpmodel.set_params(param_dict) self._model_params = self._gpmodel.get_params() def append_candidate(self, candidate: Candidate): """ Appends new pending candidate to the state. :param candidate: New pending candidate """ self._model = None # Invalidate self._state.pending_evaluations.append(PendingEvaluation(candidate)) @staticmethod def _find_candidate(candidate: Candidate, lst: List): try: pos = next( i for i, x in enumerate(lst) if x.candidate == candidate) except StopIteration: pos = -1 return pos def drop_candidate(self, candidate: Candidate): """ Drop candidate (labeled or pending) from state. :param candidate: Candidate to be dropped """ # Candidate may be labeled or pending. First, try labeled pos = self._find_candidate( candidate, self._state.candidate_evaluations) if pos != -1: self._model = None # Invalidate self._state.candidate_evaluations.pop(pos) if self._debug_log is not None: deb_msg = "[GPMXNetAsyncPendingCandidateStateTransformer.drop_candidate]\n" deb_msg += ("- len(candidate_evaluations) afterwards = {}".format( len(self.state.candidate_evaluations))) logger.info(deb_msg) else: # Try pending pos = self._find_candidate( candidate, self._state.pending_evaluations) assert pos != -1, \ "Candidate {} not registered (neither labeled, nor pending)".format( candidate) self._model = None # Invalidate self._state.pending_evaluations.pop(pos) if self._debug_log is not None: deb_msg = "[GPMXNetAsyncPendingCandidateStateTransformer.drop_candidate]\n" deb_msg += ("- len(pending_evaluations) afterwards = {}\n".format( len(self.state.pending_evaluations))) logger.info(deb_msg) def label_candidate(self, data: CandidateEvaluation): """ Adds a labeled candidate. If it was pending before, it is removed as pending candidate. :param data: New labeled candidate """ pos = self._find_candidate( data.candidate, self._state.pending_evaluations) if pos != -1: self._state.pending_evaluations.pop(pos) self._state.candidate_evaluations.append(data) self._model = None # Invalidate def filter_pending_evaluations( self, filter_pred: Callable[[PendingEvaluation], bool]): """ Filters state.pending_evaluations with filter_pred. :param filter_pred Filtering predicate """ new_pending_evaluations = list(filter( filter_pred, self._state.pending_evaluations)) if len(new_pending_evaluations) != len(self._state.pending_evaluations): if self._debug_log is not None: deb_msg = "[GPMXNetAsyncPendingCandidateStateTransformer.filter_pending_evaluations]\n" deb_msg += ("- from len {} to {}".format( len(self.state.pending_evaluations), len(new_pending_evaluations))) logger.info(deb_msg) self._model = None # Invalidate del self._state.pending_evaluations[:] self._state.pending_evaluations.extend(new_pending_evaluations) def mark_candidate_failed(self, candidate: Candidate): self._state.failed_candidates.append(candidate) def _compute_model(self, skip_optimization: bool = None): args = self._model_args if skip_optimization is None: skip_optimization = self.skip_optimization(self._state) fit_parameters = not skip_optimization if fit_parameters and self._candidate_evaluations: # Did the labeled data really change since the last recent refit? # If not, skip the refitting if self._state.candidate_evaluations == self._candidate_evaluations: fit_parameters = False logger.warning( "Skipping the refitting of GP hyperparameters, since the " "labeled data did not change since the last recent fit") self._model = GPMXNetModel( state=self._state, active_metric=args.active_metric, random_seed=args.random_seed, gpmodel=self._gpmodel, fit_parameters=fit_parameters, num_fantasy_samples=args.num_fantasy_samples, normalize_targets=args.normalize_targets, profiler=self._profiler, debug_log=self._debug_log) # Note: This may be different than self._gpmodel.get_params(), since # the GPMXNetModel may append additional info self._model_params = self._model.get_params() if fit_parameters: # Keep copy of labeled data in order to avoid unnecessary # refitting self._candidate_evaluations = copy.copy( self._state.candidate_evaluations)