def _check_consolidated_weights(weights, param_names):
    """Check the rank condition on the linear weights.

    Args:
        weights (pd.DataFrame): Weighting matrix of consolidated linear constraints.
            Each row is one constraint; the column labels are positions of the
            involved parameters in ``param_names``.
        param_names (list): Names of the flattened parameters. Only used to build
            error messages.

    Raises:
        InvalidConstraintError: If there are more constraints than involved
            parameters, or if the rows of ``weights`` are linearly dependent.

    """
    n_constraints, n_params = weights.shape

    msg_too_many = (
        "Too many linear constraints. There can be at most as many linear constraints "
        "as involved parameters with non-zero weights.\n"
    )
    msg_rank = "The weights for linear constraints must be linearly independent.\n"
    msg_general = (
        "The error occurred for constraints on the following parameters:\n{}\n with "
        "weighting matrix:\n{}\nIt is possible that you did not specify those "
        "constraints as linear constraints but as bounds, fixes, increasing or "
        "decreasing constraints."
    )

    relevant_names = [param_names[i] for i in weights.columns]

    if n_constraints > n_params:
        raise InvalidConstraintError(
            msg_too_many + msg_general.format(relevant_names, weights)
        )

    # Linearly dependent constraint rows are either redundant or contradictory.
    if np.linalg.matrix_rank(weights) < n_constraints:
        raise InvalidConstraintError(
            msg_rank + msg_general.format(relevant_names, weights)
        )
def check_fixes_and_bounds(constr_info, transformations, parnames):
    """Check that fixes and bounds are compatible with transforming constraints.

    - Parameters of "covariance" and "sdcorr" constraints, except the first one,
      must be neither fixed nor have a finite lower or upper bound.
    - Parameters of "probability" constraints must be neither fixed nor have a
      finite lower or upper bound.
    - Every lower bound must be strictly smaller than its upper bound.

    Args:
        constr_info (dict): Dict of 1d numpy arrays with info about constraints.
            The entries "is_fixed_to_value", "lower_bounds" and "upper_bounds"
            are read here.
        transformations (list): Processed transforming constraints. Each has a
            "type" and an "index" entry with integer positions of its parameters.
        parnames (list): List of parameter names.

    Raises:
        InvalidConstraintError: If any of the conditions above is violated.

    """
    df = pd.DataFrame(constr_info, index=parnames)
    bound_cols = ["lower_bounds", "upper_bounds"]

    prob_msg = (
        "{} constraints are incompatible with fixes or bounds. "
        "This is violated for:\n{}"
    )
    cov_msg = (
        "{} constraints are incompatible with fixes or bounds except for the first "
        "parameter. This is violated for:\n{}"
    )

    for constr in transformations:
        if constr["type"] in ["covariance", "sdcorr"]:
            # The first parameter of a covariance / sdcorr block may be fixed or
            # bounded, so it is excluded from the check.
            subset = df.iloc[constr["index"][1:]]
            msg = cov_msg
        elif constr["type"] == "probability":
            subset = df.iloc[constr["index"]]
            msg = prob_msg
        else:
            continue

        if subset["is_fixed_to_value"].any():
            problematic = subset[subset["is_fixed_to_value"]].index
            raise InvalidConstraintError(msg.format(constr["type"], problematic))

        # Only inspect the bound columns: looking at all columns (as a dropna over
        # the full frame would) reports every parameter, not just the bounded ones.
        is_bounded = np.isfinite(subset[bound_cols]).any(axis=1)
        if is_bounded.any():
            problematic = subset[is_bounded].index
            raise InvalidConstraintError(msg.format(constr["type"], problematic))

    invalid = df.query("lower_bounds >= upper_bounds")[bound_cols]
    if len(invalid) > 0:
        msg = (
            "lower_bound must be strictly smaller than upper_bound. "
            f"This is violated for:\n{invalid}"
        )
        raise InvalidConstraintError(msg)
def _fail_if_selections_are_incompatible(selected, constraint):
    """Check the selections of a pairwise equality constraint.

    Args:
        selected (list): One selection (sequence of parameter positions) per
            selector of the constraint.
        constraint (dict): The constraint being processed. Only used in error
            messages.

    Raises:
        InvalidConstraintError: If there are fewer than two selections or the
            selections differ in length.

    """
    if len(selected) <= 1:
        msg = (
            "pairwise equality constraints require multiple sets of selected "
            "parameters but there is just one in the following constraint:\n"
            f"{constraint}"
        )
        raise InvalidConstraintError(msg)

    lengths = [len(sel) for sel in selected]
    if len(set(lengths)) != 1:
        msg = (
            "All sets of selected parameters for pairwise equality constraints need "
            f"to have the same length. You have lengths {lengths} in constraint:\n"
            f"{constraint}"
        )
        raise InvalidConstraintError(msg)
def check_for_incompatible_overlaps(transformations, parnames):
    """Check that there are no overlaps between constraints that transform parameters.

    Since the constraints are already consolidated such that only those that
    transform a parameter are left and all equality constraints are already plugged
    in, this boils down to checking that no parameter appears more than once.

    Args:
        transformations (list): Processed transforming constraints. Each has an
            "index" entry with integer positions of the parameters it involves.
        parnames (list): List of parameter names.

    Raises:
        InvalidConstraintError: If any parameter position appears in more than one
            transforming constraint.

    """
    all_indices = [idx for constr in transformations for idx in constr["index"]]

    if len(set(all_indices)) < len(all_indices):
        unique, counts = np.unique(all_indices, return_counts=True)
        invalid_names = [parnames[i] for i in unique[counts >= 2]]
        msg = (
            "Transforming constraints such as 'covariance', 'sdcorr', 'probability' "
            "and 'linear' cannot overlap. This includes overlaps induced by equality "
            "constraints. This was violated for the following parameters:\n{}"
        )
        raise InvalidConstraintError(msg.format(invalid_names))
def check_types(constraints):
    """Check that no invalid constraint types are requested.

    Args:
        constraints (list): List of constraints (dicts with a "type" entry).

    Raises:
        InvalidConstraintError: If a constraint has no "type" entry or its type is
            not one of the supported constraint types.

    """
    valid_types = frozenset(
        {
            "covariance",
            "sdcorr",
            "linear",
            "probability",
            "increasing",
            "decreasing",
            "equality",
            "pairwise_equality",
            "fixed",
        }
    )
    for constr in constraints:
        # A missing "type" would otherwise surface as an opaque KeyError.
        if "type" not in constr:
            raise InvalidConstraintError(
                f"Constraint has no 'type' entry: {constr}"
            )
        if constr["type"] not in valid_types:
            raise InvalidConstraintError(
                "Invalid constraint_type: {}".format(constr["type"]),
            )
def _fail_if_duplicates(selected, constraint, param_names):
    """Raise if a single selection contains the same parameter more than once.

    Args:
        selected: Parameter positions selected by one selector of ``constraint``.
        constraint (dict): The constraint being processed; shown in the error.
        param_names (list): Flattened parameter names, used to translate the
            duplicated positions into readable names.

    Raises:
        InvalidConstraintError: If ``_find_duplicates`` reports any duplicates.

    """
    duplicated_positions = _find_duplicates(selected)
    if not duplicated_positions:
        return

    duplicated_names = [param_names[pos] for pos in duplicated_positions]
    raise InvalidConstraintError(
        "Error while processing constraints. There are duplicates in selected "
        "parameters. The parameters that were selected more than once are "
        f"{duplicated_names}. The problematic constraint is:\n{constraint}"
    )
def _get_selection_field(constraint, selector_case, params_case):
    """Get the relevant selection field of a constraint.

    Args:
        constraint (dict): User provided constraint.
        selector_case (str): "one selector" or "multiple selectors", as determined
            by the caller.
        params_case (str): One of "dataframe", "numpy array", "pytree" or
            "series".

    Returns:
        str: The single selection field (e.g. "loc", "queries") present in the
            constraint.

    Raises:
        InvalidConstraintError: If the constraint contains no valid selection field
            or more than one.

    """
    # The caller already computed selector_case; it is used as passed instead of
    # being recomputed from the constraint.
    valid_fields = {
        "multiple selectors": {
            "dataframe": {"locs", "queries", "selectors"},
            "numpy array": {"locs", "selectors"},
            "pytree": {"selectors"},
            "series": {"locs", "selectors"},
        },
        "one selector": {
            "dataframe": {"loc", "query", "selector"},
            "numpy array": {"loc", "selector"},
            "pytree": {"selector"},
            "series": {"loc", "selector"},
        },
    }
    valid = valid_fields[selector_case][params_case]
    present = set(constraint).intersection(valid)

    if not present:
        msg = (
            "No valid parameter selection field in constraint. Valid selection fields "
            f"are {valid}. The constraint is:\n{constraint}"
        )
        raise InvalidConstraintError(msg)
    if len(present) > 1:
        msg = (
            f"Too many parameter selection fields in constraint: {present}. "
            "Constraints must have exactly one parameter selection field. The "
            f"constraint was:\n{constraint}"
        )
        raise InvalidConstraintError(msg)

    (field,) = present
    return field
def _check_validity_and_return_evaluation(c, params, skip_checks):
    """Check that a nonlinear constraint is valid and optionally evaluate it.

    Args:
        c (dict): Nonlinear constraint. Must contain a callable "func" and either
            "value" (equality constraint) or at least one of "lower_bounds" and
            "upper_bounds" (inequality constraint). May contain a callable
            "derivative" and one of the selection fields "selector", "loc" or
            "query".
        params (pytree): User provided params.
        skip_checks (bool): If True, the constraint function is not evaluated.

    Returns:
        constraint_eval: Evaluation of the constraint function at the selected
            params if skip_checks is False, else None.

    Raises:
        InvalidConstraintError: If the constraint is malformed.
        InvalidFunctionError: If the selector or the constraint function raises
            when called on params.

    """
    # ==================================================================================
    # check functions
    # ==================================================================================
    if "func" not in c:
        raise InvalidConstraintError(
            "Constraint needs to have entry 'func', representing the constraint "
            "function."
        )
    if not callable(c["func"]):
        raise InvalidConstraintError(
            "Entry 'func' in nonlinear constraints has to be callable."
        )
    if "derivative" in c and not callable(c["derivative"]):
        raise InvalidConstraintError(
            "Entry 'derivative' in nonlinear constraints has to be callable."
        )

    # ==================================================================================
    # check bounds
    # ==================================================================================
    has_lower = "lower_bounds" in c
    has_upper = "upper_bounds" in c

    if "value" in c:
        # Equality constraint: bounds would be contradictory / ambiguous.
        if has_lower or has_upper:
            raise InvalidConstraintError(
                "Only one of 'value' or ('lower_bounds', 'upper_bounds') can be "
                "passed to a nonlinear constraint."
            )
    elif not (has_lower or has_upper):
        raise InvalidConstraintError(
            "For inequality constraint at least one of ('lower_bounds', "
            "'upper_bounds') has to be passed to the nonlinear constraint."
        )

    if has_lower and has_upper:
        if not np.all(np.array(c["lower_bounds"]) <= np.array(c["upper_bounds"])):
            raise InvalidConstraintError(
                "Lower bounds need to be less than or equal to upper bounds."
            )

    # ==================================================================================
    # check selector
    # ==================================================================================
    if "selector" in c:
        if not callable(c["selector"]):
            raise InvalidConstraintError(
                f"'selector' entry needs to be callable in constraint {c}."
            )
        try:
            c["selector"](params)
        except Exception as e:
            raise InvalidFunctionError(
                f"Error when calling 'selector' function on params in constraint {c}"
            ) from e

    elif "loc" in c:
        if not isinstance(params, (pd.Series, pd.DataFrame)):
            raise InvalidConstraintError(
                "params needs to be pd.Series or pd.DataFrame to use 'loc' selector "
                f"in constraint {c}."
            )
        try:
            params.loc[c["loc"]]
        except (KeyError, IndexError) as e:
            raise InvalidConstraintError(
                f"'loc' string is invalid in constraint {c}."
            ) from e

    elif "query" in c:
        if not isinstance(params, pd.DataFrame):
            raise InvalidConstraintError(
                "params needs to be pd.DataFrame to use 'query' selector in "
                f"constraint {c}."
            )
        try:
            params.query(c["query"])
        except Exception as e:
            raise InvalidConstraintError(
                f"'query' string is invalid in constraint {c}."
            ) from e

    # ==================================================================================
    # check that constraints can be evaluated
    # ==================================================================================
    constraint_eval = None

    if not skip_checks:
        selector = _process_selector(c)
        try:
            constraint_eval = c["func"](selector(params))
        except Exception as e:
            raise InvalidFunctionError(
                f"Error when evaluating function of constraint {c}."
            ) from e

    return constraint_eval
def process_selectors(constraints, params, tree_converter, param_names):
    """Process and harmonize the selector fields of constraints.

    By selector fields we mean loc, locs, query, queries, selector and selectors
    entries in constraints.

    The processed selector fields are called "index" (for constraints with one
    selector) or "indices" (for constraints with multiple selectors) and are integer
    numpy arrays with positions of parameters in a flattened parameter vector.

    Args:
        constraints (list or dict): User provided constraints. A single dict is
            treated as a list with one constraint.
        params (pytree): User provided params.
        tree_converter (TreeConverter): NamedTuple with methods to convert between
            flattend and unflattend parameters.
        param_names (list): Names of flattened parameters. Used for error messages.

    Returns:
        list: List of copied constraints with an additional "index" or "indices"
            entry. Constraints that select no parameters are dropped.

    Raises:
        InvalidConstraintError: If selecting parameters fails or the selection is
            invalid (duplicates, incompatible pairwise selections, bad fields).

    """
    # fast path
    if constraints in (None, []):
        return []

    if isinstance(constraints, dict):
        constraints = [constraints]

    registry = get_registry(extended=True)
    n_params = len(tree_converter.params_flatten(params))
    # A params-shaped tree whose leaves are the flat positions 0..n_params-1, so
    # evaluating a selector on it yields parameter positions.
    helper = tree_converter.params_unflatten(np.arange(n_params))
    params_case = _get_params_case(params)

    flat_constraints = []
    for constr in constraints:
        selector_case = _get_selector_case(constr)
        field = _get_selection_field(
            constraint=constr,
            selector_case=selector_case,
            params_case=params_case,
        )
        evaluator = _get_selection_evaluator(
            field=field,
            constraint=constr,
            params_case=params_case,
            registry=registry,
        )

        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
                selected = evaluator(helper)
        except Exception as e:
            # KeyboardInterrupt / SystemExit are not Exception subclasses and thus
            # propagate unchanged.
            msg = (
                "An error occurred when trying to select parameters for the following "
                f"constraint:\n{constr}"
            )
            raise InvalidConstraintError(msg) from e

        new_constr = constr.copy()
        if selector_case == "one selector":
            if np.isscalar(selected):
                selected = [selected]
            _fail_if_duplicates(selected, constr, param_names)
            selected = np.array(selected).astype(int)
            new_constr["index"] = selected
            n_selected = len(selected)
        else:
            selected = [[sel] if np.isscalar(sel) else sel for sel in selected]
            _fail_if_selections_are_incompatible(selected, constr)
            for sel in selected:
                _fail_if_duplicates(sel, constr, param_names)
            selected = [np.array(sel).astype(int) for sel in selected]
            new_constr["indices"] = selected
            # All selections have equal length, so checking the first suffices.
            n_selected = len(selected[0])

        if n_selected > 0:
            flat_constraints.append(new_constr)

    return flat_constraints