コード例 #1
0
ファイル: utils.py プロジェクト: sakuranew/skorch
def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
    """Checks whether the net is initialized.

    Note: This is a local re-implementation of the check performed by
    ``sklearn.utils.validation.check_is_fitted`` (attribute presence is
    tested with ``hasattr`` and combined with ``all_or_any``); it does
    NOT call into sklearn. It uses an adapted error message and raises a
    ``skorch.exception.NotInitializedError`` instead of an
    ``sklearn.exceptions.NotFittedError``.

    Parameters
    ----------
    estimator : object
      The object whose attributes are inspected.

    attributes : str or list/tuple of str
      Name(s) of the attribute(s) whose presence signals initialization.
      Any value that is not a list or tuple is wrapped into a one-element
      list.

    msg : str or None (default=None)
      Error message template. ``%(name)s`` is replaced with the class
      name of ``estimator``. If ``None``, a default message asking the
      user to call ``initialize`` or ``fit`` is used.

    all_or_any : callable (default=all)
      Reduces the list of per-attribute ``hasattr`` results; typically
      ``all`` (every attribute required) or ``any`` (one suffices).

    Returns
    -------
    None

    Raises
    ------
    NotInitializedError
      If ``all_or_any`` of the ``hasattr`` results is falsy.

    """
    if msg is None:
        msg = ("This %(name)s instance is not initialized yet. Call "
               "'initialize' or 'fit' with appropriate arguments "
               "before using this method.")

    # A single attribute name (e.g. a plain string) is treated as a
    # one-element list so the reduction below works uniformly.
    if not isinstance(attributes, (list, tuple)):
        attributes = [attributes]

    # Build a real list (not a generator) so that any custom
    # ``all_or_any`` callable receives a sequence, as before.
    present = [hasattr(estimator, attr) for attr in attributes]
    if not all_or_any(present):
        raise NotInitializedError(msg % {'name': type(estimator).__name__})
コード例 #2
0
ファイル: net.py プロジェクト: rain1024/skorch
    def load_params(self, f):
        """Load only the module's parameters, not the whole object.

        To save and load the whole object, use pickle.

        Parameters
        ----------
        f : file-like object or str
          See ``torch.load`` documentation.

        Raises
        ------
        NotInitializedError
          If the net has no ``module_`` attribute yet, i.e. it has been
          neither initialized nor fitted.

        Warns
        -----
        ResourceWarning
          If ``use_cuda`` is set but ``torch.cuda.is_available()`` is
          False. In that case ``self.use_cuda`` is set to ``False`` and
          the parameters are loaded onto the CPU.

        Example
        -------
        >>> before = NeuralNetClassifier(mymodule)
        >>> before.save_params('path/to/file')
        >>> after = NeuralNetClassifier(mymodule).initialize()
        >>> after.load_params('path/to/file')

        """
        if not hasattr(self, 'module_'):
            raise NotInitializedError(
                "Cannot load parameters of an un-initialized model. "
                "Please initialize first by calling .initialize() "
                "or by fitting the model with .fit(...).")

        cuda_req_not_met = (self.use_cuda and not torch.cuda.is_available())
        if not self.use_cuda or cuda_req_not_met:
            # Either the net is configured for the CPU, or it asked for
            # CUDA but no CUDA device is available. In both cases, load
            # with the identity ``map_location`` (storage, loc) -> storage,
            # which per the torch.load docs keeps every tensor on the CPU.
            # This works even if the parameters were saved from a GPU.
            if cuda_req_not_met:
                warnings.warn(
                    "Model configured to use CUDA but no CUDA devices "
                    "available. Loading on CPU instead.",
                    ResourceWarning)
                self.use_cuda = False
            state_dict = torch.load(f, lambda storage, loc: storage)
        else:
            state_dict = torch.load(f)

        self.module_.load_state_dict(state_dict)
コード例 #3
0
ファイル: utils.py プロジェクト: yonashub/skorch
def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
    """Checks whether the net is initialized.

    Note: This calls ``sklearn.utils.validation.check_is_fitted``
    under the hood, using exactly the same arguments and logic. The
    only difference is that this function has an adapted error message
    and raises a ``skorch.exception.NotInitializedError`` instead of
    an ``sklearn.exceptions.NotFittedError``.

    Parameters
    ----------
    estimator : object
      The net (or other estimator) to check.

    attributes : str or list/tuple of str
      Attribute name(s) whose presence indicates initialization;
      forwarded unchanged to sklearn.

    msg : str or None (default=None)
      Error message forwarded to sklearn. If ``None``, a skorch-specific
      message asking the user to call ``initialize`` or ``fit`` is used.
      Per sklearn's documentation, ``%(name)s`` in the message is
      substituted with the estimator's name.

    all_or_any : callable (default=all)
      Forwarded to sklearn; typically ``all`` or ``any``.

    Raises
    ------
    NotInitializedError
      Raised in place of sklearn's ``NotFittedError``, carrying the same
      message; the original exception is chained as ``__cause__``.

    """
    if msg is None:
        msg = ("This %(name)s instance is not initialized yet. Call "
               "'initialize' or 'fit' with appropriate arguments "
               "before using this method.")
    try:
        sklearn_check_is_fitted(
            estimator=estimator,
            attributes=attributes,
            msg=msg,
            all_or_any=all_or_any,
        )
    except NotFittedError as e:
        # Translate to skorch's exception type. Chain explicitly so the
        # underlying sklearn error is reported as the direct cause in
        # tracebacks.
        raise NotInitializedError(str(e)) from e
コード例 #4
0
ファイル: net.py プロジェクト: sriharsha0806/skorch
    def save_params(self, f):
        """Write the ``state_dict`` of ``module_`` to ``f``.

        Only the module's parameters are serialized, not the net object
        itself; pickle the net if the whole object should be persisted.

        Parameters
        ----------
        f : file-like object or str
          Destination handed directly to ``torch.save``; see its
          documentation for accepted values.

        Raises
        ------
        NotInitializedError
          If ``module_`` does not exist yet because the net was neither
          initialized nor fitted.

        Example
        -------
        >>> before = NeuralNetClassifier(mymodule)
        >>> before.save_params('path/to/file')
        >>> after = NeuralNetClassifier(mymodule).initialize()
        >>> after.load_params('path/to/file')

        """
        is_initialized = hasattr(self, 'module_')
        if is_initialized:
            params = self.module_.state_dict()
            torch.save(params, f)
            return
        raise NotInitializedError(
            "Cannot save parameters of an un-initialized model. "
            "Please initialize first by calling .initialize() "
            "or by fitting the model with .fit(...).")