Пример #1
0
def test_is_target():
    """is_target must accept a Target instance and reject a plain int."""
    # Positive case: a freshly constructed Target is recognised.
    assert is_target(Target())

    # Negative case: an arbitrary non-Target value is not.
    assert not is_target(1)
Пример #2
0
    def fit(self, X, y=None, **fit_params):
        """ A wrapper around the fitting function.

        The fitting strategy depends on the type of ``X``: DataArray inputs
        are fitted via ``self._fit`` and the result is stored in
        ``estimator_``; Dataset inputs get one ``self._fit`` result per data
        variable, stored in ``estimator_dict_``; any other input is validated
        with ``check_array``/``check_X_y`` and fitted with a clone of
        ``self.estimator``.

        Parameters
        ----------
        X : xarray DataArray, Dataset or other array-like
            The training input samples.

        y : xarray DataArray, Dataset, Target or other array-like, optional
            The target values. If ``is_target(y)`` holds, ``y`` is first
            resolved against ``X`` by calling ``y(X)``.

        **fit_params
            Additional keyword arguments passed on to the underlying fit.

        Returns
        -------
        Returns self.

        Raises
        ------
        ValueError
            If ``self.estimator`` is None.
        """

        if self.estimator is None:
            raise ValueError('You must specify an estimator instance to wrap.')

        # Resolve a Target into concrete target values using X.
        if is_target(y):
            y = y(X)

        if is_dataarray(X):

            self.type_ = 'DataArray'
            self.estimator_ = self._fit(X, y, **fit_params)

        elif is_dataset(X):

            self.type_ = 'Dataset'
            # One fitted estimator per data variable; the same y is used
            # for every variable.
            self.estimator_dict_ = {
                v: self._fit(X[v], y, **fit_params)
                for v in X.data_vars
            }

        else:

            self.type_ = 'other'
            if y is None:
                X = check_array(X)
            else:
                X, y = check_X_y(X, y)

            self.estimator_ = clone(self.estimator).fit(X, y, **fit_params)

            # Copy the wrapped estimator's public fitted attributes
            # (trailing '_', no leading '_') onto the wrapper itself.
            for v in vars(self.estimator_):
                if v.endswith('_') and not v.startswith('_'):
                    setattr(self, v, getattr(self.estimator_, v))

        return self
Пример #3
0
def score(self, X, y, sample_weight=None):
    """ Returns the score of the prediction.

    The scoring path is selected by ``self.type_``, i.e. by the kind of
    input the wrapper was fitted on.

    Parameters
    ----------
    X : xarray DataArray, Dataset or other array-like
        The input samples. Must be of the same kind as the data the wrapper
        was fitted on.

    y : xarray DataArray, Dataset, Target or other array-like
        The target values. If ``is_target(y)`` holds, ``y`` is resolved
        against ``X`` by calling ``y(X)`` before scoring (as ``fit`` does).

    sample_weight : array-like, shape = [n_samples], optional
        Sample weights.

    Returns
    -------
    score : float
        Score of self.predict(X) wrt. y. For Dataset inputs this is the
        mean of the per-variable scores.

    Raises
    ------
    ValueError
        If ``X`` does not match the fitted input kind, or if ``type_`` has
        an unexpected value.
    """

    if self.type_ == 'DataArray':

        if not is_dataarray(X):
            raise ValueError(
                'This wrapper was fitted for DataArray inputs, but the '
                'provided X does not seem to be a DataArray.')

        check_is_fitted(self, ['estimator_'])

        if is_target(y):
            y = y(X)

        return self.estimator_.score(X, y, sample_weight)

    elif self.type_ == 'Dataset':

        if not is_dataset(X):
            raise ValueError(
                'This wrapper was fitted for Dataset inputs, but the '
                'provided X does not seem to be a Dataset.')

        check_is_fitted(self, ['estimator_dict_'])

        # TODO: this probably has to be done for each data_var individually
        if is_target(y):
            y = y(X)

        # Score every per-variable estimator and average the results.
        score_list = [
            e.score(X[v], y, sample_weight)
            for v, e in six.iteritems(self.estimator_dict_)
        ]

        return np.mean(score_list)

    elif self.type_ == 'other':

        check_is_fitted(self, ['estimator_'])

        # fit() resolves a Target for every input kind, so do the same
        # here; otherwise the unresolved Target object would be passed
        # straight to the wrapped estimator's score().
        if is_target(y):
            y = y(X)

        return self.estimator_.score(X, y, sample_weight)

    else:
        raise ValueError('Unexpected type_.')
Пример #4
0
def fit_transform(self, X, y=None, **fit_params):
    """ Fit the wrapped estimator and transform ``X`` in a single step.

    Parameters
    ----------
    X : xarray DataArray, Dataset or other array-like
        The input samples.

    y : xarray DataArray, Dataset or other array-like
        The target values. A Target is resolved against ``X`` first.

    Returns
    -------
    Xt : xarray DataArray, Dataset or other array-like
        The transformed output. DataArray and Dataset inputs yield an xarray
        object of the same type; other inputs yield whatever the wrapped
        estimator's ``fit_transform`` returns.
    """
    if self.estimator is None:
        raise ValueError('You must specify an estimator instance to wrap.')

    if is_target(y):
        y = y(X)

    if is_dataarray(X):
        self.type_ = 'DataArray'
        self.estimator_ = clone(self.estimator)

        if self.reshapes is None:
            # No reshaping: transform the raw values and keep X's layout.
            values = self.estimator_.fit_transform(X.data, y, **fit_params)
            return xr.DataArray(values, coords=X.coords, dims=X.dims)

        # With reshaping the helper reports the resulting dims.
        values, new_dims = self._fit_transform(
            self.estimator_, X, y, **fit_params)
        return xr.DataArray(
            values, coords=self._update_coords(X), dims=new_dims)

    if is_dataset(X):
        self.type_ = 'Dataset'
        # One independent clone of the estimator per data variable.
        self.estimator_dict_ = {
            name: clone(self.estimator) for name in X.data_vars}

        out_vars = {}
        if self.reshapes is None:
            for name, est in six.iteritems(self.estimator_dict_):
                out_vars[name] = (
                    X[name].dims,
                    est.fit_transform(X[name].data, y, **fit_params))
            return xr.Dataset(out_vars, coords=X.coords)

        for name, est in six.iteritems(self.estimator_dict_):
            values, new_dims = self._fit_transform(
                est, X[name], y, **fit_params)
            out_vars[name] = (new_dims, values)
        return xr.Dataset(out_vars, coords=self._update_coords(X))

    # Plain array-like input: validate with sklearn's helpers first.
    self.type_ = 'other'
    if y is None:
        X = check_array(X)
    else:
        X, y = check_X_y(X, y)

    self.estimator_ = clone(self.estimator)
    Xt = self.estimator_.fit_transform(X, y, **fit_params)

    # Mirror the wrapped estimator's public fitted attributes on self.
    for attr in vars(self.estimator_):
        if attr.endswith('_') and not attr.startswith('_'):
            setattr(self, attr, getattr(self.estimator_, attr))

    return Xt
Пример #5
0
    def fit(self, X, y=None, **fit_params):
        """ A wrapper around the fitting function.

        Any previous fit state is cleared via ``self._reset()`` first. The
        fitting strategy then depends on the type of ``X``: DataArray inputs
        are fitted via ``self._fit`` (stored in ``estimator_``); Dataset
        inputs get one ``self._fit`` result per data variable (stored in
        ``estimator_dict_``); any other input is validated with
        ``check_array``/``check_X_y`` and fitted with
        ``self._make_estimator()``.

        Parameters
        ----------
        X : xarray DataArray, Dataset or other array-like
            The training input samples.

        y : xarray DataArray, Dataset, Target or other array-like, optional
            The target values. If ``is_target(y)`` holds, ``y`` is first
            resolved against ``X`` by calling ``y(X)``.

        **fit_params
            Additional keyword arguments passed on to the underlying fit.

        Returns
        -------
        Returns self.

        Raises
        ------
        ValueError
            If ``self.estimator`` is None.
        """

        if self.estimator is None:
            raise ValueError("You must specify an estimator instance to wrap.")

        self._reset()

        # Resolve a Target into concrete target values using X.
        if is_target(y):
            y = y(X)

        if is_dataarray(X):

            self.type_ = "DataArray"
            self.estimator_ = self._fit(X, y, **fit_params)

            # Copy the wrapped estimator's public fitted attributes
            # (trailing '_', no leading '_') onto the wrapper itself.
            # TODO: check if this needs to be removed for compat wrappers
            for v in vars(self.estimator_):
                if v.endswith("_") and not v.startswith("_"):
                    setattr(self, v, getattr(self.estimator_, v))

        elif is_dataset(X):

            self.type_ = "Dataset"
            # One fitted estimator per data variable; the same y is used
            # for every variable.
            self.estimator_dict_ = {
                v: self._fit(X[v], y, **fit_params)
                for v in X.data_vars
            }

            # Collect each public fitted attribute into a dict keyed by
            # data variable name, e.g. self.coef_ = {var_name: coef_, ...}.
            # TODO: check if this needs to be removed for compat wrappers
            for e_name, e in six.iteritems(self.estimator_dict_):
                for v in vars(e):
                    if v.endswith("_") and not v.startswith("_"):
                        if hasattr(self, v):
                            getattr(self, v).update({e_name: getattr(e, v)})
                        else:
                            setattr(self, v, {e_name: getattr(e, v)})

        else:

            self.type_ = "other"
            if y is None:
                X = check_array(X)
            else:
                X, y = check_X_y(X, y)

            self.estimator_ = self._make_estimator().fit(X, y, **fit_params)

            # Copy the wrapped estimator's public fitted attributes
            # onto the wrapper itself.
            # TODO: check if this needs to be removed for compat wrappers
            for v in vars(self.estimator_):
                if v.endswith("_") and not v.startswith("_"):
                    setattr(self, v, getattr(self.estimator_, v))

        return self
Пример #6
0
def partial_fit(self, X, y=None, **fit_params):
    """ A wrapper around the partial_fit function.

    On the first call (no ``type_`` attribute yet) the wrapper is
    initialised: the input kind is recorded in ``type_`` and the
    estimator(s) are fitted from scratch, via ``self._fit`` for xarray
    inputs or ``clone(self.estimator).fit`` for other inputs. Subsequent
    calls must use the same input kind and update the existing
    estimator(s) incrementally. If the initial fit raises, ``type_`` is
    removed again so the next call retries the initialisation.

    Parameters
    ----------
    X : xarray DataArray, Dataset or other array-like
        The input samples.

    y : xarray DataArray, Dataset, Target or other array-like, optional
        The target values. If ``is_target(y)`` holds, ``y`` is resolved
        against ``X`` by calling ``y(X)``.

    **fit_params
        Additional keyword arguments passed on to the underlying
        fit / partial_fit call.

    Returns
    -------
    Returns self.

    Raises
    ------
    ValueError
        If ``self.estimator`` is None, or if ``X`` is of a different kind
        than the data the wrapper was first fitted on.
    """

    if self.estimator is None:
        raise ValueError('You must specify an estimator instance to wrap.')

    if is_target(y):
        y = y(X)

    if is_dataarray(X):

        if not hasattr(self, 'type_'):
            self.type_ = 'DataArray'
            try:
                self.estimator_ = self._fit(X, y, **fit_params)
            except BaseException:
                # Roll back: a failed first call must not leave type_ set
                # without a fitted estimator, otherwise every later call
                # would take the incremental branch and fail.
                del self.type_
                raise
        elif self.type_ == 'DataArray':
            self.estimator_ = self._partial_fit(self.estimator_, X, y,
                                                **fit_params)
        else:
            raise ValueError(
                'This wrapper was not fitted for DataArray inputs.')

        # Copy the wrapped estimator's public fitted attributes
        # (trailing '_', no leading '_') onto the wrapper itself.
        # TODO: check if this needs to be removed for compat wrappers
        for v in vars(self.estimator_):
            if v.endswith('_') and not v.startswith('_'):
                setattr(self, v, getattr(self.estimator_, v))

    elif is_dataset(X):

        if not hasattr(self, 'type_'):
            self.type_ = 'Dataset'
            try:
                self.estimator_dict_ = {
                    v: self._fit(X[v], y, **fit_params)
                    for v in X.data_vars
                }
            except BaseException:
                # Roll back so the next call retries the initialisation.
                del self.type_
                raise
        elif self.type_ == 'Dataset':
            self.estimator_dict_ = {
                v: self._partial_fit(self.estimator_dict_[v], X[v], y,
                                     **fit_params)
                for v in X.data_vars
            }
        else:
            raise ValueError('This wrapper was not fitted for Dataset inputs.')

        # Collect each public fitted attribute into a dict keyed by data
        # variable name, e.g. self.coef_ = {var_name: coef_, ...}.
        # TODO: check if this needs to be removed for compat wrappers
        for e_name, e in six.iteritems(self.estimator_dict_):
            for v in vars(e):
                if v.endswith('_') and not v.startswith('_'):
                    if hasattr(self, v):
                        getattr(self, v).update({e_name: getattr(e, v)})
                    else:
                        setattr(self, v, {e_name: getattr(e, v)})

    else:

        if not hasattr(self, 'type_'):
            self.type_ = 'other'
            try:
                if y is None:
                    X = check_array(X)
                else:
                    X, y = check_X_y(X, y)
                self.estimator_ = clone(self.estimator).fit(X, y,
                                                            **fit_params)
            except BaseException:
                # Roll back so the next call retries the initialisation.
                del self.type_
                raise
        elif self.type_ == 'other':
            self.estimator_ = self.estimator_.partial_fit(X, y, **fit_params)
        else:
            raise ValueError('This wrapper was not fitted for other inputs.')

        # Copy the wrapped estimator's public fitted attributes onto the
        # wrapper itself.
        # TODO: check if this needs to be removed for compat wrappers
        for v in vars(self.estimator_):
            if v.endswith('_') and not v.startswith('_'):
                setattr(self, v, getattr(self.estimator_, v))

    return self