def test_is_dataarray():
    """is_dataarray must accept an xarray DataArray and reject a bare ndarray."""
    values = np.random.random((100, 10))
    assert is_dataarray(xr.DataArray(values))

    plain_array = np.random.random((100, 10))
    assert not is_dataarray(plain_array)
def test_is_dataset():
    """is_dataset must accept an xarray Dataset and reject a bare ndarray."""
    X_ds = xr.Dataset({'var_1': 1})
    assert is_dataset(X_ds)

    X_not_a_ds = np.random.random((100, 10))
    # The negative case has to go through is_dataset itself; checking
    # is_dataarray here would leave is_dataset's rejection path untested.
    assert not is_dataset(X_not_a_ds)
def _call_fitted(self, method, X):
    """ Call a method of a fitted estimator (predict, transform, ...).

    The result is re-wrapped in the same container type the wrapper was
    fitted on.

    Parameters
    ----------
    method : str
        Name of the wrapped estimator's method to call.

    X : xarray DataArray, Dataset or other array-like
        Input samples; the container type must match ``self.type_``.

    Returns
    -------
    xarray DataArray, xarray Dataset, or the raw return value of the
    wrapped method when ``type_`` is ``"other"``.

    Raises
    ------
    ValueError
        If X's container type does not match ``self.type_`` or ``type_``
        holds an unexpected value.
    """
    check_is_fitted(self, ["type_"])
    fitted_type = self.type_

    if fitted_type == "DataArray":
        if not is_dataarray(X):
            raise ValueError(
                "This wrapper was fitted for DataArray inputs, but the "
                "provided X does not seem to be a DataArray.")
        check_is_fitted(self, ["estimator_"])
        if self.reshapes is None:
            # No reshaping configured: reuse X's own coords and dims.
            result = getattr(self.estimator_, method)(X.data)
            return xr.DataArray(result, coords=X.coords, dims=X.dims)
        # With reshapes set, dims come from _call_array_method and coords
        # from _update_coords instead of being copied from X.
        data, new_dims = self._call_array_method(self.estimator_, method, X)
        return xr.DataArray(
            data, coords=self._update_coords(X), dims=new_dims)

    if fitted_type == "Dataset":
        if not is_dataset(X):
            raise ValueError(
                "This wrapper was fitted for Dataset inputs, but the "
                "provided X does not seem to be a Dataset.")
        check_is_fitted(self, ["estimator_dict_"])
        if self.reshapes is None:
            # One estimator per data variable; each keeps its input dims.
            return xr.Dataset(
                {
                    name: (X[name].dims, getattr(est, method)(X[name].data))
                    for name, est in self.estimator_dict_.items()
                },
                coords=X.coords)
        new_vars = {}
        for name, est in self.estimator_dict_.items():
            values, var_dims = self._call_array_method(est, method, X[name])
            new_vars[name] = (var_dims, values)
        return xr.Dataset(new_vars, coords=self._update_coords(X))

    if fitted_type == "other":
        check_is_fitted(self, ["estimator_"])
        return getattr(self.estimator_, method)(X)

    raise ValueError("Unexpected type_.")
def fit(self, X, y=None, **fit_params):
    """ A wrapper around the fitting function.

    Parameters
    ----------
    X : xarray DataArray, Dataset or other array-like
        The training input samples.

    y : xarray DataArray, Dataset or other array-like
        The target values. If ``is_target(y)`` holds, ``y`` is called
        with X to obtain the actual target values.

    Returns
    -------
    Returns self.

    Raises
    ------
    ValueError
        If no estimator instance was given to wrap.
    """
    if self.estimator is None:
        raise ValueError('You must specify an estimator instance to wrap.')

    if is_target(y):
        y = y(X)

    if is_dataarray(X):
        self.type_ = 'DataArray'
        self.estimator_ = self._fit(X, y, **fit_params)
        return self

    if is_dataset(X):
        self.type_ = 'Dataset'
        # _fit is invoked once per data variable of the Dataset.
        self.estimator_dict_ = {
            name: self._fit(X[name], y, **fit_params)
            for name in X.data_vars
        }
        return self

    # Plain array-like input: validate, fit a clone of the template
    # estimator, then mirror its public trailing-underscore attributes.
    self.type_ = 'other'
    if y is not None:
        X, y = check_X_y(X, y)
    else:
        X = check_array(X)
    self.estimator_ = clone(self.estimator).fit(X, y, **fit_params)
    fitted_attrs = [
        name for name in vars(self.estimator_)
        if name.endswith('_') and not name.startswith('_')
    ]
    for name in fitted_attrs:
        setattr(self, name, getattr(self.estimator_, name))
    return self
def score(self, X, y, sample_weight=None):
    """ Returns the score of the prediction.

    Parameters
    ----------
    X : xarray DataArray, Dataset or other array-like
        The test samples; the container type must match ``self.type_``.

    y : xarray DataArray, Dataset or other array-like
        The target values. For DataArray/Dataset wrappers, a value for
        which ``is_target(y)`` holds is called with X first.

    sample_weight : array-like, shape = [n_samples], optional
        Sample weights.

    Returns
    -------
    score : float
        Score of self.predict(X) wrt. y. For Dataset inputs this is the
        mean of the per-data-variable scores.

    Raises
    ------
    ValueError
        If X's container type does not match ``self.type_`` or ``type_``
        holds an unexpected value.
    """
    fitted_type = self.type_

    if fitted_type == 'DataArray':
        if not is_dataarray(X):
            raise ValueError(
                'This wrapper was fitted for DataArray inputs, but the '
                'provided X does not seem to be a DataArray.')
        check_is_fitted(self, ['estimator_'])
        target = y(X) if is_target(y) else y
        return self.estimator_.score(X, target, sample_weight)

    if fitted_type == 'Dataset':
        if not is_dataset(X):
            raise ValueError(
                'This wrapper was fitted for Dataset inputs, but the '
                'provided X does not seem to be a Dataset.')
        check_is_fitted(self, ['estimator_dict_'])
        # TODO: this probably has to be done for each data_var individually
        target = y(X) if is_target(y) else y
        per_var_scores = [
            est.score(X[name], target, sample_weight)
            for name, est in six.iteritems(self.estimator_dict_)
        ]
        return np.mean(per_var_scores)

    if fitted_type == 'other':
        check_is_fitted(self, ['estimator_'])
        return self.estimator_.score(X, y, sample_weight)

    raise ValueError('Unexpected type_.')
def fit_transform(self, X, y=None, **fit_params):
    """ A wrapper around the fit_transform function.

    Parameters
    ----------
    X : xarray DataArray, Dataset or other array-like
        The input samples.

    y : xarray DataArray, Dataset or other array-like
        The target values. If ``is_target(y)`` holds, ``y`` is called
        with X to obtain the actual target values.

    Returns
    -------
    Xt : xarray DataArray, Dataset or other array-like
        The transformed output, in the same container type as X.

    Raises
    ------
    ValueError
        If no estimator instance was given to wrap.
    """
    if self.estimator is None:
        raise ValueError('You must specify an estimator instance to wrap.')

    if is_target(y):
        y = y(X)

    if is_dataarray(X):
        self.type_ = 'DataArray'
        self.estimator_ = clone(self.estimator)
        if self.reshapes is None:
            # No reshaping configured: output keeps X's coords and dims.
            transformed = self.estimator_.fit_transform(
                X.data, y, **fit_params)
            return xr.DataArray(transformed, coords=X.coords, dims=X.dims)
        data, new_dims = self._fit_transform(
            self.estimator_, X, y, **fit_params)
        return xr.DataArray(
            data, coords=self._update_coords(X), dims=new_dims)

    if is_dataset(X):
        self.type_ = 'Dataset'
        # A fresh clone of the template estimator per data variable.
        self.estimator_dict_ = {
            name: clone(self.estimator) for name in X.data_vars
        }
        if self.reshapes is None:
            new_vars = {
                name: (X[name].dims,
                       est.fit_transform(X[name].data, y, **fit_params))
                for name, est in six.iteritems(self.estimator_dict_)
            }
            return xr.Dataset(new_vars, coords=X.coords)
        new_vars = {}
        for name, est in six.iteritems(self.estimator_dict_):
            values, var_dims = self._fit_transform(
                est, X[name], y, **fit_params)
            new_vars[name] = (var_dims, values)
        return xr.Dataset(new_vars, coords=self._update_coords(X))

    # Plain array-like input: validate, fit_transform a clone, and mirror
    # the clone's public trailing-underscore attributes onto the wrapper.
    self.type_ = 'other'
    if y is not None:
        X, y = check_X_y(X, y)
    else:
        X = check_array(X)
    self.estimator_ = clone(self.estimator)
    Xt = self.estimator_.fit_transform(X, y, **fit_params)
    for name in vars(self.estimator_):
        if name.endswith('_') and not name.startswith('_'):
            setattr(self, name, getattr(self.estimator_, name))
    return Xt
def fit(self, X, y=None, **fit_params):
    """ A wrapper around the fitting function.

    Calls ``self._reset()`` first (presumably clearing state from an
    earlier fit -- confirm against ``_reset``), fits the wrapped
    estimator(s), and copies their public trailing-underscore attributes
    onto the wrapper. For Dataset input each such attribute is collected
    into a dict keyed by data-variable name.

    Parameters
    ----------
    X : xarray DataArray, Dataset or other array-like
        The training input samples.

    y : xarray DataArray, Dataset or other array-like
        The target values. If ``is_target(y)`` holds, ``y`` is called
        with X to obtain the actual target values.

    Returns
    -------
    Returns self.

    Raises
    ------
    ValueError
        If no estimator instance was given to wrap.
    """
    if self.estimator is None:
        raise ValueError("You must specify an estimator instance to wrap.")

    self._reset()

    if is_target(y):
        y = y(X)

    def _is_fitted_attr(name):
        # Public names with a trailing underscore (sklearn fitted-attr style).
        return name.endswith("_") and not name.startswith("_")

    if is_dataarray(X):
        self.type_ = "DataArray"
        self.estimator_ = self._fit(X, y, **fit_params)
        # TODO: check if this needs to be removed for compat wrappers
        for name in filter(_is_fitted_attr, vars(self.estimator_)):
            setattr(self, name, getattr(self.estimator_, name))
        return self

    if is_dataset(X):
        self.type_ = "Dataset"
        self.estimator_dict_ = {
            var: self._fit(X[var], y, **fit_params) for var in X.data_vars
        }
        # TODO: check if this needs to be removed for compat wrappers
        for var, est in six.iteritems(self.estimator_dict_):
            for name in filter(_is_fitted_attr, vars(est)):
                if hasattr(self, name):
                    getattr(self, name).update({var: getattr(est, name)})
                else:
                    setattr(self, name, {var: getattr(est, name)})
        return self

    self.type_ = "other"
    if y is None:
        X = check_array(X)
    else:
        X, y = check_X_y(X, y)
    self.estimator_ = self._make_estimator().fit(X, y, **fit_params)
    # TODO: check if this needs to be removed for compat wrappers
    for name in filter(_is_fitted_attr, vars(self.estimator_)):
        setattr(self, name, getattr(self.estimator_, name))
    return self
def partial_fit(self, X, y=None, **fit_params):
    """ A wrapper around the partial_fit function.

    The first call records the input container type in ``type_``. Later
    calls must use the same container type. Every call, including the
    first, forwards to the wrapped estimator's ``partial_fit``, so
    partial_fit-only keyword arguments (e.g. ``classes`` for sklearn's
    incremental classifiers) are accepted from the start.

    Parameters
    ----------
    X : xarray DataArray, Dataset or other array-like
        The input samples.

    y : xarray DataArray, Dataset or other array-like
        The target values. If ``is_target(y)`` holds, ``y`` is called
        with X to obtain the actual target values.

    **fit_params
        Forwarded to the wrapped estimator's ``partial_fit``.

    Returns
    -------
    Returns self.

    Raises
    ------
    ValueError
        If no estimator instance was given to wrap, or X's container type
        differs from the one seen in earlier calls.
    """
    if self.estimator is None:
        raise ValueError('You must specify an estimator instance to wrap.')

    if is_target(y):
        y = y(X)

    def _is_fitted_attr(name):
        # Public names with a trailing underscore (sklearn fitted-attr style).
        return name.endswith('_') and not name.startswith('_')

    first_call = not hasattr(self, 'type_')

    if is_dataarray(X):
        if first_call:
            self.type_ = 'DataArray'
            # Start from an unfitted clone and call partial_fit (not fit),
            # so the first increment behaves like every later one.
            self.estimator_ = self._partial_fit(
                clone(self.estimator), X, y, **fit_params)
        elif self.type_ == 'DataArray':
            self.estimator_ = self._partial_fit(
                self.estimator_, X, y, **fit_params)
        else:
            raise ValueError(
                'This wrapper was not fitted for DataArray inputs.')
        # TODO: check if this needs to be removed for compat wrappers
        for name in filter(_is_fitted_attr, vars(self.estimator_)):
            setattr(self, name, getattr(self.estimator_, name))

    elif is_dataset(X):
        if first_call:
            self.type_ = 'Dataset'
            self.estimator_dict_ = {
                v: self._partial_fit(
                    clone(self.estimator), X[v], y, **fit_params)
                for v in X.data_vars
            }
        elif self.type_ == 'Dataset':
            self.estimator_dict_ = {
                v: self._partial_fit(
                    self.estimator_dict_[v], X[v], y, **fit_params)
                for v in X.data_vars
            }
        else:
            raise ValueError('This wrapper was not fitted for Dataset inputs.')
        # TODO: check if this needs to be removed for compat wrappers
        for e_name, e in six.iteritems(self.estimator_dict_):
            for name in filter(_is_fitted_attr, vars(e)):
                if hasattr(self, name):
                    getattr(self, name).update({e_name: getattr(e, name)})
                else:
                    setattr(self, name, {e_name: getattr(e, name)})

    else:
        if not first_call and self.type_ != 'other':
            raise ValueError('This wrapper was not fitted for other inputs.')
        # Validate on every call, not just the first, so every increment
        # reaches the estimator with the same preprocessing.
        if y is None:
            X = check_array(X)
        else:
            X, y = check_X_y(X, y)
        if first_call:
            self.type_ = 'other'
            self.estimator_ = clone(self.estimator).partial_fit(
                X, y, **fit_params)
        else:
            self.estimator_ = self.estimator_.partial_fit(X, y, **fit_params)
        # TODO: check if this needs to be removed for compat wrappers
        for name in filter(_is_fitted_attr, vars(self.estimator_)):
            setattr(self, name, getattr(self.estimator_, name))

    return self