def check_all_parameters_present(
    self,
    parameterization: TParameterization,
    raise_error: bool = False,
) -> bool:
    """Check that a parameterization names exactly the search space's parameters.

    Only parameter *names* are compared; values are not validated here. The
    check fails if the parameterization is missing any search space parameter
    or contains a parameter name the search space does not define.

    Args:
        parameterization: Dict from parameter name to value to validate.
        raise_error: If True and the parameter names do not match, raise a
            ``ValueError`` describing both sets of names instead of returning
            False.

    Returns:
        True if the set of names in ``parameterization`` equals the set of
        parameter names in the search space, False otherwise.

    Raises:
        ValueError: If ``raise_error`` is True and the names do not match.
    """
    given_names = set(parameterization)
    expected_names = set(self._parameters)
    names_match = given_names == expected_names
    if not names_match and raise_error:
        raise ValueError(
            f"Parameterization has parameters: {given_names}, "
            f"but search space has parameters: {expected_names}."
        )
    return names_match
def check_membership(
    self,
    parameterization: TParameterization,
    raise_error: bool = False,
    check_all_parameters_present: bool = True,
) -> bool:
    """Whether the given parameterization belongs in the hierarchical search space.

    First runs the parent class's membership check (parameter names/types
    match the search space parameters, values are contained in the search
    space domain, and the parameter constraints are satisfied) without
    requiring every parameter to be present. Then verifies that the
    parameterization is consistent with the hierarchical structure: casting
    it onto the hierarchy must yield exactly the same set of parameter names.

    Args:
        parameterization: Dict from parameter name to value to validate.
        raise_error: If True and the parameterization does not belong, raise
            an error with a detailed explanation of why instead of returning
            False.
        check_all_parameters_present: Ensure that parameterization specifies
            values for all parameters as expected by the search space and its
            hierarchical structure.

    Returns:
        Whether the parameterization is contained in the search space.

    Raises:
        ValueError: If ``raise_error`` is True and the parameterization
            violates the hierarchical structure. ``raise_error`` is also
            forwarded to the parent check, which may raise its own error.
    """
    # The parent check runs with ``check_all_parameters_present=False``:
    # completeness with respect to the hierarchy is instead enforced by
    # ``_cast_parameterization`` below, which receives the caller's flag.
    # With ``raise_error=False`` the parent reports failure only through its
    # return value, so that value must be propagated rather than ignored.
    if not super().check_membership(
        parameterization=parameterization,
        raise_error=raise_error,
        check_all_parameters_present=False,
    ):
        return False

    # Check that each arm "belongs" in the hierarchical search space; ensure
    # that it only has the parameters that make sense with each other (and
    # does not contain dependent parameters if the parameter they depend on
    # does not have the correct value). Any difference in parameter names
    # between the cast and the original indicates a structural violation.
    cast_to_hss_params = set(
        self._cast_parameterization(
            parameters=parameterization,
            check_all_parameters_present=check_all_parameters_present,
        ).keys()
    )
    parameterization_params = set(parameterization.keys())
    if cast_to_hss_params != parameterization_params:
        if raise_error:
            raise ValueError(
                "Parameterization violates the hierarchical structure of the search "
                f"space; cast version would have parameters: {cast_to_hss_params},"
                f" but full version contains parameters: {parameterization_params}."
            )
        return False
    return True
def _get_sum(parameterization: TParameterization) -> float:
    """Return the sum of the ``x1`` and ``x2`` values in ``parameterization``.

    Any other entries in the parameterization are ignored.

    Args:
        parameterization: Dict from parameter name to value; must contain the
            keys ``"x1"`` and ``"x2"`` whose values support ``+``.

    Returns:
        ``parameterization["x1"] + parameterization["x2"]``.

    Raises:
        ValueError: If either ``"x1"`` or ``"x2"`` is missing.
    """
    # Test membership against the mapping itself (hash lookup) instead of
    # first copying its keys into a list and scanning it.
    if "x1" not in parameterization or "x2" not in parameterization:
        raise ValueError("Parametrization does not contain x1 or x2")
    return parameterization["x1"] + parameterization["x2"]