# Alternative syntax, which assumes cv receives 'groups' by default, and that a # method-based API is provided on meta-estimators: # lr = LogisticRegressionCV( # cv=GroupKFold(), # scoring='accuracy', # ).add_prop_route(scoring='sample_weight') cross_validate( lr, X, y, cv=GroupKFold(), props={ 'sample_weight': my_weights, 'groups': my_groups }, scoring='accuracy', prop_routing={ 'estimator': '*', # pass all props 'cv': ['groups'], 'scoring': ['sample_weight'], }) # Error handling: if props={'sample_eight': my_weights, ...} was passed # instead, LogisticRegressionCV would have to identify that a key was passed # that could not be routed nor used, in order to raise an error. # %% # Case B: weighted scoring and unweighted fitting
def fit(self, X, y): return super().fit(X, y, sample_weight=MY_WEIGHTS.loc[X.index]) acc_scorer = get_scorer('accuracy') def wrapped_weighted_acc(est, X, y, sample_weight=None): return acc_scorer(est, X, y, sample_weight=MY_WEIGHTS.loc[X.index]) lr = WeightedLogisticRegressionCV( cv=wrapped_group_cv, scoring=wrapped_weighted_acc, ).set_props_request(['sample_weight']) cross_validate(lr, X, y, cv=wrapped_group_cv, scoring=wrapped_weighted_acc) # %% # Case B: weighted scoring and unweighted fitting lr = LogisticRegressionCV( cv=wrapped_group_cv, scoring=wrapped_weighted_acc, ).set_props_request(['sample_weight']) cross_validate(lr, X, y, cv=wrapped_group_cv, scoring=wrapped_weighted_acc) # %% # Case C: unweighted feature selection lr = WeightedLogisticRegressionCV( cv=wrapped_group_cv,
# Here we presume that GroupKFold requests `groups` by default. # We need to explicitly request weights in make_scorer and for # LogisticRegressionCV. Both of these consumers understand the meaning # of the key "sample_weight". weighted_acc = make_scorer(accuracy, request_props=['sample_weight']) lr = LogisticRegressionCV( cv=group_cv, scoring=weighted_acc, ).request_sample_weight(fit=['sample_weight']) cross_validate(lr, X, y, cv=group_cv, props={ 'sample_weight': my_weights, 'groups': my_groups }, scoring=weighted_acc) # Error handling: if props={'sample_eight': my_weights, ...} was passed, # cross_validate would raise an error, since 'sample_eight' was not requested # by any of its children. # %% # Case B: weighted scoring and unweighted fitting # Since LogisticRegressionCV requires that weights explicitly be requested, # removing that request means the fitting is unweighted.
acc_scorer = get_scorer('accuracy') def wrapped_weighted_acc(est, X, y, sample_weight=None): return acc_scorer(est, unwrap_X(X), y, sample_weight=X[:, WEIGHT_IDX]) lr = WrappedLogisticRegressionCV( cv=wrapped_group_cv, scoring=wrapped_weighted_acc, ).set_props_request(['sample_weight']) cross_validate(lr, np.hstack([X, my_weights, my_groups]), y, cv=wrapped_group_cv, scoring=wrapped_weighted_acc) # %% # Case B: weighted scoring and unweighted fitting class UnweightedWrappedLogisticRegressionCV(LogisticRegressionCV): def fit(self, X, y): return super().fit(unwrap_X(X), y) lr = UnweightedWrappedLogisticRegressionCV( cv=wrapped_group_cv, scoring=wrapped_weighted_acc,
from defs import (accuracy_score, GroupKFold, make_scorer, SelectKBest, LogisticRegressionCV, cross_validate, make_pipeline, X, y, my_groups, my_weights, my_other_weights) # %% # Case A: weighted scoring and fitting lr = LogisticRegressionCV( cv=GroupKFold(), scoring='accuracy', ) cross_validate(lr, X, y, cv=GroupKFold(), props={ 'sample_weight': my_weights, 'groups': my_groups }, scoring='accuracy') # Error handling: if props={'sample_eight': my_weights, ...} was passed # instead, the estimator would fit and score without weight, silently failing. # %% # Case B: weighted scoring and unweighted fitting class MyLogisticRegressionCV(LogisticRegressionCV): def fit(self, X, y, props=None): props = props.copy()
my_weights, my_other_weights) # %% # Case A: weighted scoring and fitting lr = LogisticRegressionCV( cv=GroupKFold(), scoring='accuracy', ) props = {'cv__groups': my_groups, 'estimator__cv__groups': my_groups, 'estimator__sample_weight': my_weights, 'scoring__sample_weight': my_weights, 'estimator__scoring__sample_weight': my_weights} cross_validate(lr, X, y, cv=GroupKFold(), props=props, scoring='accuracy') # error handling: if props={'estimator__sample_eight': my_weights, ...} was # passed instead, the estimator would raise an error. # %% # Case B: weighted scoring and unweighted fitting lr = LogisticRegressionCV( cv=GroupKFold(), scoring='accuracy', ) props = {'cv__groups': my_groups, 'estimator__cv__groups': my_groups, 'scoring__sample_weight': my_weights,
# %% # Case A: weighted scoring and fitting lr = LogisticRegressionCV( cv=group_cv, scoring='accuracy', ) props = { 'cv__groups': my_groups, 'estimator__cv__groups': my_groups, 'estimator__sample_weight': my_weights, 'scoring__sample_weight': my_weights, 'estimator__scoring__sample_weight': my_weights } cross_validate(lr, X, y, cv=group_cv, props=props, scoring='accuracy') # error handling: if props={'estimator__sample_eight': my_weights, ...} was # passed instead, the estimator would raise an error. # %% # Case B: weighted scoring and unweighted fitting lr = LogisticRegressionCV( cv=group_cv, scoring='accuracy', ) props = { 'cv__groups': my_groups, 'estimator__cv__groups': my_groups, 'scoring__sample_weight': my_weights,