def test_is_target():
    """``is_target`` accepts ``Target`` instances and rejects other objects."""
    # A freshly constructed Target must be recognised as a target.
    assert is_target(Target())

    # An arbitrary non-Target object (here a plain int) must not be.
    plain_value = 1
    assert not is_target(plain_value)
def fit(self, X, y=None, **fit_params):
    """
    A wrapper around the fitting function.

    Parameters
    ----------
    X : xarray DataArray, Dataset or other array-like
        The training input samples.

    y : xarray DataArray, Dataset or other array-like
        The target values. If ``is_target(y)`` is true, ``y`` is called
        with ``X`` and the result is used as the target.

    **fit_params
        Extra keyword arguments forwarded to the underlying fit call.

    Returns
    -------
    Returns self.

    Raises
    ------
    ValueError
        If ``self.estimator`` is None.
    """
    if self.estimator is None:
        raise ValueError('You must specify an estimator instance to wrap.')

    if is_target(y):
        y = y(X)

    if is_dataarray(X):
        self.type_ = 'DataArray'
        self.estimator_ = self._fit(X, y, **fit_params)
    elif is_dataset(X):
        self.type_ = 'Dataset'
        # One separately fitted estimator per data variable.
        self.estimator_dict_ = {
            v: self._fit(X[v], y, **fit_params) for v in X.data_vars
        }
    else:
        self.type_ = 'other'
        if y is None:
            X = check_array(X)
        else:
            X, y = check_X_y(X, y)
        fitted = clone(self.estimator).fit(X, y, **fit_params)
        self.estimator_ = fitted

        # Mirror the public fitted attributes (trailing underscore, no
        # leading underscore) onto the wrapper. Names the wrapper uses for
        # its own bookkeeping are skipped: a wrapped estimator exposing
        # e.g. ``estimator_`` would otherwise overwrite ``self.estimator_``
        # and the remaining attributes would be read from the wrong object.
        wrapper_owned = ('type_', 'estimator_', 'estimator_dict_')
        for v in vars(fitted):
            if v in wrapper_owned:
                continue
            if v.endswith('_') and not v.startswith('_'):
                setattr(self, v, getattr(fitted, v))

    return self
def score(self, X, y, sample_weight=None):
    """
    Returns the score of the prediction.

    Parameters
    ----------
    X : xarray DataArray, Dataset or other array-like
        The samples to score. Must match the kind of input the wrapper
        was fitted with when that was a DataArray or a Dataset.

    y : xarray DataArray, Dataset or other array-like
        The target values. In the DataArray and Dataset branches, if
        ``is_target(y)`` is true, ``y`` is called with ``X`` first.

    sample_weight : array-like, shape = [n_samples], optional
        Sample weights.

    Returns
    -------
    score : float
        Score of self.predict(X) wrt. y. For Dataset inputs this is the
        mean of the per-variable scores.

    Raises
    ------
    ValueError
        If ``X`` is not of the kind the wrapper was fitted for, or if
        ``self.type_`` holds an unexpected value.

    Whatever ``check_is_fitted`` raises when the wrapper has not been
    fitted (scikit-learn's ``NotFittedError``, assuming the scikit-learn
    helper is the one in scope here).
    """
    if not hasattr(self, 'type_'):
        # Unfitted wrapper: report it through check_is_fitted instead of
        # failing with a bare AttributeError on ``self.type_`` below.
        check_is_fitted(self, ['type_'])

    if self.type_ == 'DataArray':

        if not is_dataarray(X):
            raise ValueError(
                'This wrapper was fitted for DataArray inputs, but the '
                'provided X does not seem to be a DataArray.')

        check_is_fitted(self, ['estimator_'])

        if is_target(y):
            y = y(X)

        return self.estimator_.score(X, y, sample_weight)

    elif self.type_ == 'Dataset':

        if not is_dataset(X):
            raise ValueError(
                'This wrapper was fitted for Dataset inputs, but the '
                'provided X does not seem to be a Dataset.')

        check_is_fitted(self, ['estimator_dict_'])

        # TODO: this probably has to be done for each data_var individually
        if is_target(y):
            y = y(X)

        # Score every data variable with its own estimator, then average.
        score_list = [
            e.score(X[v], y, sample_weight)
            for v, e in six.iteritems(self.estimator_dict_)
        ]

        return np.mean(score_list)

    elif self.type_ == 'other':

        check_is_fitted(self, ['estimator_'])

        # NOTE(review): unlike the two xarray branches, a target object
        # passed as ``y`` is not resolved here - confirm this is intended.
        return self.estimator_.score(X, y, sample_weight)

    else:
        raise ValueError('Unexpected type_.')
def fit_transform(self, X, y=None, **fit_params):
    """
    A wrapper around the fit_transform function.

    Parameters
    ----------
    X : xarray DataArray, Dataset or other array-like
        The input samples.

    y : xarray DataArray, Dataset or other array-like
        The target values. If ``is_target(y)`` is true, ``y`` is called
        with ``X`` and the result is used as the target.

    **fit_params
        Extra keyword arguments forwarded to the underlying
        fit_transform call.

    Returns
    -------
    Xt : xarray DataArray, Dataset or other array-like
        The transformed output.

    Raises
    ------
    ValueError
        If ``self.estimator`` is None.
    """
    if self.estimator is None:
        raise ValueError('You must specify an estimator instance to wrap.')

    if is_target(y):
        y = y(X)

    if is_dataarray(X):

        self.type_ = 'DataArray'
        self.estimator_ = clone(self.estimator)

        if self.reshapes is not None:
            # The helper returns the new data together with its dims;
            # coordinates are rebuilt via _update_coords.
            data, dims = self._fit_transform(
                self.estimator_, X, y, **fit_params)
            coords = self._update_coords(X)
            return xr.DataArray(data, coords=coords, dims=dims)
        else:
            # Otherwise the input's coords and dims are reused unchanged.
            return xr.DataArray(
                self.estimator_.fit_transform(X.data, y, **fit_params),
                coords=X.coords, dims=X.dims)

    elif is_dataset(X):

        self.type_ = 'Dataset'
        # One fresh clone of the estimator per data variable.
        self.estimator_dict_ = {
            v: clone(self.estimator) for v in X.data_vars}

        if self.reshapes is not None:
            data_vars = {}
            for v, e in six.iteritems(self.estimator_dict_):
                Xt_v, dims = self._fit_transform(e, X[v], y, **fit_params)
                data_vars[v] = (dims, Xt_v)
            coords = self._update_coords(X)
            return xr.Dataset(data_vars, coords=coords)
        else:
            data_vars = {
                v: (X[v].dims, e.fit_transform(X[v].data, y, **fit_params))
                for v, e in six.iteritems(self.estimator_dict_)
            }
            return xr.Dataset(data_vars, coords=X.coords)

    else:

        self.type_ = 'other'
        if y is None:
            X = check_array(X)
        else:
            X, y = check_X_y(X, y)

        fitted = clone(self.estimator)
        self.estimator_ = fitted
        Xt = fitted.fit_transform(X, y, **fit_params)

        # Mirror the public fitted attributes (trailing underscore, no
        # leading underscore) onto the wrapper. Names the wrapper uses for
        # its own bookkeeping are skipped: a wrapped estimator exposing
        # e.g. ``estimator_`` would otherwise overwrite ``self.estimator_``
        # and the remaining attributes would be read from the wrong object.
        wrapper_owned = ('type_', 'estimator_', 'estimator_dict_')
        for v in vars(fitted):
            if v in wrapper_owned:
                continue
            if v.endswith('_') and not v.startswith('_'):
                setattr(self, v, getattr(fitted, v))

        return Xt
def fit(self, X, y=None, **fit_params):
    """
    A wrapper around the fitting function.

    Parameters
    ----------
    X : xarray DataArray, Dataset or other array-like
        The training input samples.

    y : xarray DataArray, Dataset or other array-like
        The target values. If ``is_target(y)`` is true, ``y`` is called
        with ``X`` and the result is used as the target.

    **fit_params
        Extra keyword arguments forwarded to the underlying fit call.

    Returns
    -------
    Returns self.

    Raises
    ------
    ValueError
        If ``self.estimator`` is None.
    """
    if self.estimator is None:
        raise ValueError("You must specify an estimator instance to wrap.")

    # NOTE(review): presumably discards state from an earlier fit - see
    # ``_reset`` for what exactly is cleared.
    self._reset()

    if is_target(y):
        y = y(X)

    # Names the wrapper uses for its own bookkeeping. They must never be
    # mirrored from a wrapped estimator: an estimator exposing e.g.
    # ``estimator_`` or ``type_`` would otherwise overwrite the wrapper's
    # state (or, for Dataset inputs, trigger ``.update`` on a string).
    wrapper_owned = ("type_", "estimator_", "estimator_dict_")

    def public_fitted_attrs(est):
        """List the public fitted attribute names of *est* safe to mirror."""
        return [
            name
            for name in vars(est)
            if name.endswith("_")
            and not name.startswith("_")
            and name not in wrapper_owned
        ]

    if is_dataarray(X):

        self.type_ = "DataArray"
        self.estimator_ = self._fit(X, y, **fit_params)

        # TODO: check if this needs to be removed for compat wrappers
        fitted = self.estimator_
        for name in public_fitted_attrs(fitted):
            setattr(self, name, getattr(fitted, name))

    elif is_dataset(X):

        self.type_ = "Dataset"
        # One separately fitted estimator per data variable.
        self.estimator_dict_ = {
            v: self._fit(X[v], y, **fit_params) for v in X.data_vars
        }

        # TODO: check if this needs to be removed for compat wrappers
        # Each mirrored attribute becomes a dict keyed by data variable.
        for e_name, e in six.iteritems(self.estimator_dict_):
            for name in public_fitted_attrs(e):
                if hasattr(self, name):
                    getattr(self, name)[e_name] = getattr(e, name)
                else:
                    setattr(self, name, {e_name: getattr(e, name)})

    else:

        self.type_ = "other"
        if y is None:
            X = check_array(X)
        else:
            X, y = check_X_y(X, y)

        self.estimator_ = self._make_estimator().fit(X, y, **fit_params)

        # TODO: check if this needs to be removed for compat wrappers
        fitted = self.estimator_
        for name in public_fitted_attrs(fitted):
            setattr(self, name, getattr(fitted, name))

    return self
def partial_fit(self, X, y=None, **fit_params):
    """
    A wrapper around the partial_fit function.

    The first call decides the input kind ('DataArray', 'Dataset' or
    'other') and stores it in ``self.type_``; later calls must pass the
    same kind of input.

    Parameters
    ----------
    X : xarray DataArray, Dataset or other array-like
        The input samples.

    y : xarray DataArray, Dataset or other array-like
        The target values. If ``is_target(y)`` is true, ``y`` is called
        with ``X`` and the result is used as the target.

    **fit_params
        Extra keyword arguments forwarded to the underlying fit or
        partial_fit call.

    Returns
    -------
    Returns self.

    Raises
    ------
    ValueError
        If ``self.estimator`` is None, or if the wrapper was previously
        fitted with a different kind of input than ``X``.
    """
    if self.estimator is None:
        raise ValueError('You must specify an estimator instance to wrap.')

    if is_target(y):
        y = y(X)

    # Names the wrapper uses for its own bookkeeping. They must never be
    # mirrored from a wrapped estimator: an estimator exposing e.g.
    # ``estimator_`` would otherwise replace ``self.estimator_``, and the
    # next partial_fit call would then update the wrong object.
    wrapper_owned = ('type_', 'estimator_', 'estimator_dict_')

    def public_fitted_attrs(est):
        """List the public fitted attribute names of *est* safe to mirror."""
        return [
            name
            for name in vars(est)
            if name.endswith('_')
            and not name.startswith('_')
            and name not in wrapper_owned
        ]

    if is_dataarray(X):

        if not hasattr(self, 'type_'):
            # First call: behaves like a regular fit.
            self.type_ = 'DataArray'
            self.estimator_ = self._fit(X, y, **fit_params)
        elif self.type_ == 'DataArray':
            self.estimator_ = self._partial_fit(
                self.estimator_, X, y, **fit_params)
        else:
            raise ValueError(
                'This wrapper was not fitted for DataArray inputs.')

        # TODO: check if this needs to be removed for compat wrappers
        fitted = self.estimator_
        for name in public_fitted_attrs(fitted):
            setattr(self, name, getattr(fitted, name))

    elif is_dataset(X):

        if not hasattr(self, 'type_'):
            # First call: one separately fitted estimator per variable.
            self.type_ = 'Dataset'
            self.estimator_dict_ = {
                v: self._fit(X[v], y, **fit_params) for v in X.data_vars
            }
        elif self.type_ == 'Dataset':
            self.estimator_dict_ = {
                v: self._partial_fit(
                    self.estimator_dict_[v], X[v], y, **fit_params)
                for v in X.data_vars
            }
        else:
            raise ValueError(
                'This wrapper was not fitted for Dataset inputs.')

        # TODO: check if this needs to be removed for compat wrappers
        # Each mirrored attribute becomes a dict keyed by data variable.
        for e_name, e in six.iteritems(self.estimator_dict_):
            for name in public_fitted_attrs(e):
                if hasattr(self, name):
                    getattr(self, name)[e_name] = getattr(e, name)
                else:
                    setattr(self, name, {e_name: getattr(e, name)})

    else:

        if not hasattr(self, 'type_'):
            self.type_ = 'other'
            if y is None:
                X = check_array(X)
            else:
                X, y = check_X_y(X, y)
            # NOTE(review): the first call uses ``fit`` rather than
            # ``partial_fit`` - confirm this is intended for estimators
            # whose partial_fit takes extra arguments.
            self.estimator_ = clone(self.estimator).fit(X, y, **fit_params)
        elif self.type_ == 'other':
            self.estimator_ = self.estimator_.partial_fit(
                X, y, **fit_params)
        else:
            raise ValueError(
                'This wrapper was not fitted for other inputs.')

        # TODO: check if this needs to be removed for compat wrappers
        fitted = self.estimator_
        for name in public_fitted_attrs(fitted):
            setattr(self, name, getattr(fitted, name))

    return self