コード例 #1
0
    def test_sample_all_priors(self, cuda=False):
        """Exercise `sample_all_priors` on a `SingleTaskGP`.

        Checks that:
        * sampling changes the noise hyperparameter and yields distinct ARD
          lengthscales;
        * a prior whose sampling is not implemented (`SmoothedBoxPrior`)
          leaves its hyperparameter untouched and emits exactly one warning
          mentioning "rsample";
        * a prior registered without a setting closure raises RuntimeError.

        Args:
            cuda: If True, run on the CUDA device instead of the CPU.
        """
        device = torch.device("cuda" if cuda else "cpu")
        for dtype in (torch.float, torch.double):
            train_X = torch.rand(3, 5, device=device, dtype=dtype)
            train_Y = torch.rand(3, 1, device=device, dtype=dtype)
            model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
            mll = ExactMarginalLogLikelihood(model.likelihood, model)
            mll.to(device=device, dtype=dtype)
            original_state_dict = dict(deepcopy(mll.model.state_dict()))
            sample_all_priors(model)

            # make sure one of the hyperparameters changed
            # (torch.equal avoids relying on the truthiness of a bool
            # tensor, which is only defined for single-element tensors)
            raw_noise_key = "likelihood.noise_covar.raw_noise"
            self.assertFalse(
                torch.equal(
                    model.state_dict()[raw_noise_key],
                    original_state_dict[raw_noise_key],
                ))
            # check that lengthscales are all different. The previous
            # `assertTrue(all(...) for ...)` passed a generator to
            # assertTrue, which is always truthy, so it never checked anything.
            ls = model.covar_module.base_kernel.raw_lengthscale.view(
                -1).tolist()
            self.assertEqual(len(set(ls)), len(ls))

            # change one of the priors to SmoothedBoxPrior
            model.covar_module = ScaleKernel(
                MaternKernel(
                    nu=2.5,
                    ard_num_dims=model.train_inputs[0].shape[-1],
                    batch_shape=model._aug_batch_shape,
                    lengthscale_prior=SmoothedBoxPrior(3.0, 6.0),
                ),
                batch_shape=model._aug_batch_shape,
                outputscale_prior=GammaPrior(2.0, 0.15),
            )
            original_state_dict = dict(deepcopy(mll.model.state_dict()))
            with warnings.catch_warnings(
                    record=True) as ws, settings.debug(True):
                sample_all_priors(model)
                self.assertEqual(len(ws), 1)
                self.assertTrue("rsample" in str(ws[0].message))

            # the lengthscale should not have changed because sampling is
            # not implemented for SmoothedBoxPrior
            self.assertTrue(
                torch.equal(
                    dict(model.state_dict())
                    ["covar_module.base_kernel.raw_lengthscale"],
                    original_state_dict[
                        "covar_module.base_kernel.raw_lengthscale"],
                ))

            # set setting_closure to None and make sure RuntimeError is raised
            prior_tuple = model.likelihood.noise_covar._priors["noise_prior"]
            model.likelihood.noise_covar._priors["noise_prior"] = (
                prior_tuple[0],
                prior_tuple[1],
                None,
            )
            with self.assertRaises(RuntimeError):
                sample_all_priors(model)
コード例 #2
0
ファイル: fit.py プロジェクト: pytorch/botorch
def fit_gpytorch_model(mll: MarginalLogLikelihood,
                       optimizer: Callable = fit_gpytorch_scipy,
                       **kwargs: Any) -> MarginalLogLikelihood:
    r"""Fit hyperparameters of a GPyTorch model.

    On optimizer failures, a new initial condition is sampled from the
    hyperparameter priors and optimization is retried. The maximum number of
    retries can be passed in as a `max_retries` kwarg (default is 5).

    Optimizer functions are in botorch.optim.fit.

    Args:
        mll: MarginalLogLikelihood to be maximized.
        optimizer: The optimizer function.
        kwargs: `max_retries` (default 5) and `sequential` (default True;
            controls the fitting of `ModelListGP` and
            `BatchedMultiOutputGPyTorchModel` models) are consumed by this
            function. All remaining kwargs (e.g. `approx_mll`, whether to use
            gpytorch's approximate MLL computation) are passed along to the
            optimizer function, including on every fallback/recursive path.

    Returns:
        MarginalLogLikelihood with optimized parameters (in eval mode).

    Warns:
        BotorchWarning: if a batched multi-output model cannot be converted
            for sequential fitting (falls back to joint fitting).
        OptimizationWarning: if fitting failed on all retries.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        >>> fit_gpytorch_model(mll)
    """
    sequential = kwargs.pop("sequential", True)
    max_retries = kwargs.pop("max_retries", 5)
    if isinstance(mll, SumMarginalLogLikelihood) and sequential:
        # fit each component MLL independently
        for mll_ in mll.mlls:
            fit_gpytorch_model(mll=mll_,
                               optimizer=optimizer,
                               max_retries=max_retries,
                               **kwargs)
        return mll
    elif (isinstance(mll.model, BatchedMultiOutputGPyTorchModel)
          and mll.model._num_outputs > 1 and sequential):
        tf = None
        try:  # check if backwards-conversion is possible
            try:
                # remove the outcome transform since the training targets
                # are already transformed and the outcome transform cannot
                # currently be split.
                # TODO: support splitting outcome transforms.
                if hasattr(mll.model, "outcome_transform"):
                    tf = mll.model.outcome_transform
                    mll.model.outcome_transform = None
                model_list = batched_to_model_list(mll.model)
                mll_ = SumMarginalLogLikelihood(model_list.likelihood,
                                                model_list)
                fit_gpytorch_model(
                    mll=mll_,
                    optimizer=optimizer,
                    sequential=True,
                    max_retries=max_retries,
                    **kwargs,
                )
                model_ = model_list_to_batched(mll_.model)
                mll.model.load_state_dict(model_.state_dict())
                # setting the transformed inputs is necessary because gpytorch
                # stores the raw training inputs on the ExactGP in the
                # ExactGP.__init__ call. At evaluation time, the test inputs
                # will already be in the transformed space if some transforms
                # have transform_on_eval set to False. ExactGP.__call__ will
                # concatenate the test points with the training inputs.
                # Therefore, it is important to set the ExactGP's
                # train_inputs to also be transformed data using all
                # transforms (including the transforms with
                # transform_on_train set to True).
                mll.train()
            finally:
                # Always put the outcome transform back, also when an
                # exception not handled below propagates to the caller.
                if tf is not None:
                    mll.model.outcome_transform = tf
            return mll.eval()
        # NotImplementedError is omitted since it derives from RuntimeError
        except (UnsupportedError, RuntimeError, AttributeError):
            warnings.warn(FAILED_CONVERSION_MSG, BotorchWarning)
            # forward the remaining kwargs so optimizer options are not
            # silently dropped on the non-sequential fallback
            return fit_gpytorch_model(mll=mll,
                                      optimizer=optimizer,
                                      sequential=False,
                                      max_retries=max_retries,
                                      **kwargs)
    # retry with random samples from the priors upon failure
    mll.train()
    original_state_dict = deepcopy(mll.model.state_dict())
    retry = 0
    while retry < max_retries:
        with warnings.catch_warnings(record=True) as ws:
            # Record every OptimizationWarning regardless of the caller's
            # filters; otherwise an "ignore" filter would hide failures from
            # the check below and disable retries. Recorded warnings are
            # re-emitted afterwards, where the caller's filters apply again.
            warnings.simplefilter("always", category=OptimizationWarning)
            if retry > 0:  # use normal initial conditions on first try
                mll.model.load_state_dict(original_state_dict)
                sample_all_priors(mll.model)
            try:
                mll, _ = optimizer(mll, track_iterations=False, **kwargs)
            except NotPSDError:
                retry += 1
                logging.log(
                    logging.DEBUG,
                    f"Fitting failed on try {retry} due to a NotPSDError.",
                )
                continue
        has_optwarning = False
        for w in ws:
            # Do not count reaching `maxiter` as an optimization failure.
            if "ITERATIONS REACHED LIMIT" in str(w.message):
                logging.log(
                    logging.DEBUG,
                    "Fitting ended early due to reaching the iteration limit.",
                )
                continue
            has_optwarning |= issubclass(w.category, OptimizationWarning)
            warnings.warn(w.message, w.category)
        if not has_optwarning:
            mll.eval()
            return mll
        retry += 1
        logging.log(logging.DEBUG, f"Fitting failed on try {retry}.")

    warnings.warn("Fitting failed on all retries.", OptimizationWarning)
    return mll.eval()
コード例 #3
0
ファイル: fit.py プロジェクト: wangsd01/botorch
def fit_gpytorch_model(mll: MarginalLogLikelihood,
                       optimizer: Callable = fit_gpytorch_scipy,
                       **kwargs: Any) -> MarginalLogLikelihood:
    r"""Fit hyperparameters of a GPyTorch model.

    On optimizer failures, a new initial condition is sampled from the
    hyperparameter priors and optimization is retried. The maximum number of
    retries can be passed in as a `max_retries` kwarg (default is 5).

    Optimizer functions are in botorch.optim.fit.

    Args:
        mll: MarginalLogLikelihood to be maximized.
        optimizer: The optimizer function.
        kwargs: `max_retries` (default 5) and `sequential` (default True;
            controls the fitting of `ModelListGP` and
            `BatchedMultiOutputGPyTorchModel` models) are consumed by this
            function. All remaining kwargs (e.g. `approx_mll`, whether to use
            gpytorch's approximate MLL computation) are passed along to the
            optimizer function, including on every fallback/recursive path.

    Returns:
        MarginalLogLikelihood with optimized parameters (in eval mode).

    Warns:
        BotorchWarning: if a batched multi-output model cannot be converted
            for sequential fitting (falls back to joint fitting).
        OptimizationWarning: if fitting failed on all retries. Warnings
            raised during each fitting attempt are re-emitted.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        >>> fit_gpytorch_model(mll)
    """
    sequential = kwargs.pop("sequential", True)
    max_retries = kwargs.pop("max_retries", 5)
    if isinstance(mll, SumMarginalLogLikelihood) and sequential:
        # fit each component MLL independently
        for mll_ in mll.mlls:
            fit_gpytorch_model(mll=mll_,
                               optimizer=optimizer,
                               max_retries=max_retries,
                               **kwargs)
        return mll
    elif (isinstance(mll.model, BatchedMultiOutputGPyTorchModel)
          and mll.model._num_outputs > 1 and sequential):
        try:  # check if backwards-conversion is possible
            model_list = batched_to_model_list(mll.model)
            # Round-trip once before fitting, presumably so that an
            # unsupported conversion fails before any fitting work is done;
            # the converted model itself is not needed here.
            model_list_to_batched(model_list)
            mll_ = SumMarginalLogLikelihood(model_list.likelihood, model_list)
            fit_gpytorch_model(
                mll=mll_,
                optimizer=optimizer,
                sequential=True,
                max_retries=max_retries,
                **kwargs,
            )
            model_ = model_list_to_batched(mll_.model)
            mll.model.load_state_dict(model_.state_dict())
            return mll.eval()
        # NotImplementedError is omitted since it derives from RuntimeError
        except (UnsupportedError, RuntimeError, AttributeError):
            warnings.warn(FAILED_CONVERSION_MSG, BotorchWarning)
            # forward the remaining kwargs so optimizer options are not
            # silently dropped on the non-sequential fallback
            return fit_gpytorch_model(mll=mll,
                                      optimizer=optimizer,
                                      sequential=False,
                                      max_retries=max_retries,
                                      **kwargs)
    # retry with random samples from the priors upon failure
    mll.train()
    original_state_dict = deepcopy(mll.model.state_dict())
    retry = 0
    while retry < max_retries:
        with warnings.catch_warnings(record=True) as ws:
            # Record every OptimizationWarning regardless of the caller's
            # filters; otherwise an "ignore" filter would hide failures from
            # the check below and disable retries.
            warnings.simplefilter("always", category=OptimizationWarning)
            if retry > 0:  # use normal initial conditions on first try
                mll.model.load_state_dict(original_state_dict)
                sample_all_priors(mll.model)
            mll, _ = optimizer(mll, track_iterations=False, **kwargs)
        # Re-emit what was recorded (outside the recording context, so the
        # caller's filters apply) instead of silently discarding it.
        has_optwarning = False
        for w in ws:
            has_optwarning |= issubclass(w.category, OptimizationWarning)
            warnings.warn(w.message, w.category)
        if not has_optwarning:
            mll.eval()
            return mll
        retry += 1
        logging.log(logging.DEBUG, f"Fitting failed on try {retry}.")

    warnings.warn("Fitting failed on all retries.", OptimizationWarning)
    return mll.eval()