Example #1
    def test_get_extra_mll_args(self):
        train_X = torch.rand(3, 5)
        train_Y = torch.rand(3)
        model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
        # test ExactMarginalLogLikelihood
        exact_mll = ExactMarginalLogLikelihood(model.likelihood, model)
        exact_extra_args = _get_extra_mll_args(mll=exact_mll)
        self.assertEqual(len(exact_extra_args), 1)
        self.assertTrue(torch.equal(exact_extra_args[0], train_X))

        # test VariationalELBO
        elbo = VariationalELBO(model.likelihood, model, num_data=train_X.shape[0])
        elbo_extra_args = _get_extra_mll_args(mll=elbo)
        self.assertEqual(len(elbo_extra_args), 0)

        # test SumMarginalLogLikelihood
        model2 = ModelListGP(gp_models=[model])
        sum_mll = SumMarginalLogLikelihood(model2.likelihood, model2)
        sum_mll_extra_args = _get_extra_mll_args(mll=sum_mll)
        self.assertEqual(len(sum_mll_extra_args), 1)
        self.assertEqual(len(sum_mll_extra_args[0]), 1)
        self.assertTrue(torch.equal(sum_mll_extra_args[0][0], train_X))

        # test unsupported MarginalLogLikelihood type
        unsupported_mll = MarginalLogLikelihood(model.likelihood, model)
        with self.assertRaises(ValueError):
            _get_extra_mll_args(mll=unsupported_mll)
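The helper exercised by this test is not shown in the listing. A minimal sketch consistent with the assertions above (an inference from the test, not the actual botorch source) could look like this:

# Sketch only: behavior inferred from the assertions in the test above.
def _get_extra_mll_args(mll):
    """Return the extra positional args the MLL needs after the model output."""
    if isinstance(mll, ExactMarginalLogLikelihood):
        # exact MLLs are called as mll(output, targets, *train_inputs)
        return list(mll.model.train_inputs)
    elif isinstance(mll, SumMarginalLogLikelihood):
        # one list of training inputs per sub-model of the ModelListGP
        return [list(inputs) for inputs in mll.model.train_inputs]
    elif isinstance(mll, VariationalELBO):
        # variational MLLs only need the model output and the targets
        return []
    raise ValueError("Do not know how to optimize this MarginalLogLikelihood type.")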
Example #2
def _scipy_objective_and_grad(
        x: np.ndarray, mll: MarginalLogLikelihood,
        property_dict: Dict[str, TorchAttr]) -> Tuple[float, np.ndarray]:
    r"""Get objective and gradient in format that scipy expects.

    Args:
        x: The (flattened) input parameters.
        mll: The MarginalLogLikelihood module to evaluate.
        property_dict: The property dictionary required to "unflatten" the input
            parameter vector, as generated by `module_to_array`.

    Returns:
        2-element tuple containing

        - The objective value.
        - The gradient of the objective.
    """
    mll = set_params_with_array(mll, x, property_dict)
    train_inputs, train_targets = mll.model.train_inputs, mll.model.train_targets
    mll.zero_grad()
    output = mll.model(*train_inputs)
    args = [output, train_targets] + _get_extra_mll_args(mll)
    loss = -mll(*args).sum()
    loss.backward()
    param_dict = OrderedDict(mll.named_parameters())
    grad = []
    for p_name in property_dict:
        t = param_dict[p_name].grad
        if t is None:
            # this deals with parameters that do not affect the loss
            grad.append(np.zeros(property_dict[p_name].shape.numel()))
        else:
            grad.append(t.detach().view(-1).cpu().double().clone().numpy())
    mll.zero_grad()
    return loss.item(), np.concatenate(grad)
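A usage sketch showing how this objective/gradient pair can be handed to scipy's L-BFGS-B. The `module_to_array` and `set_params_with_array` helpers are assumed to come from `botorch.optim.numpy_converter`, and `mll` is assumed to be an `ExactMarginalLogLikelihood` already in train mode:

# Usage sketch (assumed wiring), not the full fit_gpytorch_scipy implementation.
from scipy.optimize import minimize
from botorch.optim.numpy_converter import module_to_array, set_params_with_array

x0, property_dict, _ = module_to_array(module=mll)
res = minimize(
    _scipy_objective_and_grad,
    x0,
    args=(mll, property_dict),
    method="L-BFGS-B",
    jac=True,  # the objective function returns (value, gradient)
)
mll = set_params_with_array(mll, res.x, property_dict)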
Example #3
def fit_gpytorch_model(mll: MarginalLogLikelihood,
                       optimizer: Callable = fit_gpytorch_scipy,
                       **kwargs: Any) -> MarginalLogLikelihood:
    r"""Fit hyperparameters of a gpytorch model.

    Optimizer functions are in botorch.optim.fit.

    Args:
        mll: MarginalLogLikelihood to be maximized.
        optimizer: The optimizer function.
        kwargs: Arguments passed along to the optimizer function.

    Returns:
        MarginalLogLikelihood with optimized parameters.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        >>> fit_gpytorch_model(mll)
    """
    mll.train()
    mll, _ = optimizer(mll, track_iterations=False, **kwargs)
    mll.eval()
    return mll
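A short usage sketch passing a non-default optimizer function; the data and option values are illustrative assumptions:

# Usage sketch: fit with the torch-based optimizer instead of the default scipy fit.
import torch
from botorch.models import SingleTaskGP
from botorch.optim.fit import fit_gpytorch_torch
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 2)
train_Y = train_X.sum(dim=-1, keepdim=True)
gp = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
# extra kwargs are forwarded to the optimizer function
fit_gpytorch_model(mll, optimizer=fit_gpytorch_torch, options={"maxiter": 50, "disp": False})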
Example #4
    def test_get_extra_mll_args(self):
        train_X = torch.rand(3, 5)
        train_Y = torch.rand(3, 1)
        model = SingleTaskGP(train_X=train_X, train_Y=train_Y)

        # test ExactMarginalLogLikelihood
        exact_mll = ExactMarginalLogLikelihood(model.likelihood, model)
        exact_extra_args = _get_extra_mll_args(mll=exact_mll)
        self.assertEqual(len(exact_extra_args), 1)
        self.assertTrue(torch.equal(exact_extra_args[0], train_X))

        # test SumMarginalLogLikelihood
        model2 = ModelListGP(model)
        sum_mll = SumMarginalLogLikelihood(model2.likelihood, model2)
        sum_mll_extra_args = _get_extra_mll_args(mll=sum_mll)
        self.assertEqual(len(sum_mll_extra_args), 1)
        self.assertEqual(len(sum_mll_extra_args[0]), 1)
        self.assertTrue(torch.equal(sum_mll_extra_args[0][0], train_X))

        # test unsupported MarginalLogLikelihood type
        unsupported_mll = MarginalLogLikelihood(model.likelihood, model)
        unsupported_mll_extra_args = _get_extra_mll_args(mll=unsupported_mll)
        self.assertEqual(unsupported_mll_extra_args, [])
Example #5
def fit_gpytorch_torch(
    mll: MarginalLogLikelihood,
    bounds: Optional[ParameterBounds] = None,
    optimizer_cls: Optimizer = Adam,
    options: Optional[Dict[str, Any]] = None,
    track_iterations: bool = True,
    approx_mll: bool = True,
) -> Tuple[MarginalLogLikelihood, Dict[str, Union[float, List[OptimizationIteration]]]]:
    r"""Fit a gpytorch model by maximizing MLL with a torch optimizer.

    The model and likelihood in mll must already be in train mode.
    Note: this method requires that the model has `train_inputs` and `train_targets`.

    Args:
        mll: MarginalLogLikelihood to be maximized.
        bounds: A ParameterBounds dictionary mapping parameter names to tuples
            of lower and upper bounds. Bounds specified here take precedence
            over bounds on the same parameters specified in the constraints
            registered with the module.
        optimizer_cls: Torch optimizer to use. Must not require a closure.
        options: options for model fitting. Relevant options will be passed to
            the `optimizer_cls`. Additionally, options can include: "disp"
            to specify whether to display model fitting diagnostics and "maxiter"
            to specify the maximum number of iterations.
        track_iterations: Track the function values and wall time for each
            iteration.
        approx_mll: If True, use gpytorch's approximate MLL computation
            (according to the gpytorch defaults based on the training data size).
            Unlike for the deterministic algorithms used in fit_gpytorch_scipy,
            this is not an issue for stochastic optimizers.

    Returns:
        2-element tuple containing

        - mll with parameters optimized in-place.
        - Dictionary with the following key/values:
          "fopt": Best mll value.
          "wall_time": Wall time of fitting.
          "iterations": List of OptimizationIteration objects with information on
          each iteration. If track_iterations is False, will be empty.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        >>> mll.train()
        >>> fit_gpytorch_torch(mll)
        >>> mll.eval()
    """
    optim_options = {"maxiter": 100, "disp": True, "lr": 0.05}
    optim_options.update(options or {})
    exclude = optim_options.pop("exclude", None)
    if exclude is not None:
        mll_params = [
            t for p_name, t in mll.named_parameters() if p_name not in exclude
        ]
    else:
        mll_params = list(mll.parameters())
    optimizer = optimizer_cls(
        params=[{"params": mll_params}],
        **_filter_kwargs(optimizer_cls, **optim_options),
    )

    # get bounds specified in model (if any)
    bounds_: ParameterBounds = {}
    if hasattr(mll, "named_parameters_and_constraints"):
        for param_name, _, constraint in mll.named_parameters_and_constraints():
            if constraint is not None and not constraint.enforced:
                bounds_[param_name] = constraint.lower_bound, constraint.upper_bound

    # update with user-supplied bounds (overwrites if already exists)
    if bounds is not None:
        bounds_.update(bounds)

    iterations = []
    t1 = time.time()

    param_trajectory: Dict[str, List[Tensor]] = {
        name: [] for name, param in mll.named_parameters()
    }
    loss_trajectory: List[float] = []
    i = 0
    converged = False
    convergence_criterion = ConvergenceCriterion(
        **_filter_kwargs(ConvergenceCriterion, **optim_options)
    )
    train_inputs, train_targets = mll.model.train_inputs, mll.model.train_targets
    while not converged:
        optimizer.zero_grad()
        with gpt_settings.fast_computations(log_prob=approx_mll):
            output = mll.model(*train_inputs)
            # we sum here to support batch mode
            args = [output, train_targets] + _get_extra_mll_args(mll)
            loss = -mll(*args).sum()
            loss.backward()
        loss_trajectory.append(loss.item())
        for name, param in mll.named_parameters():
            param_trajectory[name].append(param.detach().clone())
        if optim_options["disp"] and (
            (i + 1) % 10 == 0 or i == (optim_options["maxiter"] - 1)
        ):
            print(f"Iter {i + 1}/{optim_options['maxiter']}: {loss.item()}")
        if track_iterations:
            iterations.append(OptimizationIteration(i, loss.item(), time.time() - t1))
        optimizer.step()
        # project onto bounds:
        if bounds_:
            for pname, param in mll.named_parameters():
                if pname in bounds_:
                    param.data = param.data.clamp(*bounds_[pname])
        i += 1
        converged = convergence_criterion.evaluate(fvals=loss.detach())
    info_dict = {
        "fopt": loss_trajectory[-1],
        "wall_time": time.time() - t1,
        "iterations": iterations,
    }
    return mll, info_dict
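A usage sketch of the bounds and "exclude" handling above. The parameter names are assumptions and depend on the model; inspect `mll.named_parameters()` to find the right ones:

# Usage sketch: bound the raw lengthscale and keep the raw noise fixed while fitting.
mll.train()
mll, info = fit_gpytorch_torch(
    mll,
    bounds={"model.covar_module.base_kernel.raw_lengthscale": (-3.0, 3.0)},  # assumed name
    options={
        "maxiter": 150,
        "lr": 0.1,
        "disp": False,
        "exclude": ["model.likelihood.noise_covar.raw_noise"],  # assumed name
    },
    track_iterations=False,
)
mll.eval()
print(info["fopt"], info["wall_time"])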
Example #6
def fit_gpytorch_model(mll: MarginalLogLikelihood,
                       optimizer: Callable = fit_gpytorch_scipy,
                       **kwargs: Any) -> MarginalLogLikelihood:
    r"""Fit hyperparameters of a GPyTorch model.

    On optimizer failures, a new initial condition is sampled from the
    hyperparameter priors and optimization is retried. The maximum number of
    retries can be passed in as a `max_retries` kwarg (default is 5).

    Optimizer functions are in botorch.optim.fit.

    Args:
        mll: MarginalLogLikelihood to be maximized.
        optimizer: The optimizer function.
        kwargs: Arguments passed along to the optimizer function, including
            `max_retries` and `sequential` (controls the fitting of `ModelListGP`
            and `BatchedMultiOutputGPyTorchModel` models) or `approx_mll`
            (whether to use gpytorch's approximate MLL computation).

    Returns:
        MarginalLogLikelihood with optimized parameters.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        >>> fit_gpytorch_model(mll)
    """
    sequential = kwargs.pop("sequential", True)
    max_retries = kwargs.pop("max_retries", 5)
    if isinstance(mll, SumMarginalLogLikelihood) and sequential:
        for mll_ in mll.mlls:
            fit_gpytorch_model(mll=mll_,
                               optimizer=optimizer,
                               max_retries=max_retries,
                               **kwargs)
        return mll
    elif (isinstance(mll.model, BatchedMultiOutputGPyTorchModel)
          and mll.model._num_outputs > 1 and sequential):
        tf = None
        try:  # check if backwards-conversion is possible
            # remove the outcome transform since the training targets are already
            # transformed and the outcome transform cannot currently be split.
            # TODO: support splitting outcome transforms.
            if hasattr(mll.model, "outcome_transform"):
                tf = mll.model.outcome_transform
                mll.model.outcome_transform = None
            model_list = batched_to_model_list(mll.model)
            mll_ = SumMarginalLogLikelihood(model_list.likelihood, model_list)
            fit_gpytorch_model(
                mll=mll_,
                optimizer=optimizer,
                sequential=True,
                max_retries=max_retries,
                **kwargs,
            )
            model_ = model_list_to_batched(mll_.model)
            mll.model.load_state_dict(model_.state_dict())
            # setting the transformed inputs is necessary because gpytorch
            # stores the raw training inputs on the ExactGP in the
            # ExactGP.__init__ call. At evaluation time, the test inputs will
            # already be in the transformed space if some transforms have
            # transform_on_eval set to False. ExactGP.__call__ will
            # concatenate the test points with the training inputs. Therefore,
            # it is important to set the ExactGP's train_inputs to also be
            # transformed data using all transforms (including the transforms
            # with transform_on_train set to True).
            mll.train()
            if tf is not None:
                mll.model.outcome_transform = tf
            return mll.eval()
        # NotImplementedError is omitted since it derives from RuntimeError
        except (UnsupportedError, RuntimeError, AttributeError):
            warnings.warn(FAILED_CONVERSION_MSG, BotorchWarning)
            if tf is not None:
                mll.model.outcome_transform = tf
            return fit_gpytorch_model(mll=mll,
                                      optimizer=optimizer,
                                      sequential=False,
                                      max_retries=max_retries)
    # retry with random samples from the priors upon failure
    mll.train()
    original_state_dict = deepcopy(mll.model.state_dict())
    retry = 0
    while retry < max_retries:
        with warnings.catch_warnings(record=True) as ws:
            if retry > 0:  # use normal initial conditions on first try
                mll.model.load_state_dict(original_state_dict)
                sample_all_priors(mll.model)
            try:
                mll, _ = optimizer(mll, track_iterations=False, **kwargs)
            except NotPSDError:
                retry += 1
                logging.log(
                    logging.DEBUG,
                    f"Fitting failed on try {retry} due to a NotPSDError.",
                )
                continue
        has_optwarning = False
        for w in ws:
            # Do not count reaching `maxiter` as an optimization failure.
            if "ITERATIONS REACHED LIMIT" in str(w.message):
                logging.log(
                    logging.DEBUG,
                    "Fitting ended early due to reaching the iteration limit.",
                )
                continue
            has_optwarning |= issubclass(w.category, OptimizationWarning)
            warnings.warn(w.message, w.category)
        if not has_optwarning:
            mll.eval()
            return mll
        retry += 1
        logging.log(logging.DEBUG, f"Fitting failed on try {retry}.")

    warnings.warn("Fitting failed on all retries.", OptimizationWarning)
    return mll.eval()
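A sketch of the batched multi-output path described in the docstring: with `sequential=True` (the default), a multi-output model is converted to a model list, each output is fit separately, and the fitted state dict is loaded back. The training data here is illustrative:

# Sketch: a two-output SingleTaskGP fit sequentially, with extra retries on failure.
import torch
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 3)
train_Y = torch.stack([train_X.sum(-1), train_X.norm(dim=-1)], dim=-1)  # 20 x 2
gp = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_model(mll, max_retries=10)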
Example #7
def fit_gpytorch_torch(
    mll: MarginalLogLikelihood,
    bounds: Optional[ParameterBounds] = None,
    optimizer_cls: Optimizer = Adam,
    options: Optional[Dict[str, Any]] = None,
    track_iterations: bool = True,
) -> Tuple[MarginalLogLikelihood, List[OptimizationIteration]]:
    r"""Fit a gpytorch model by maximizing MLL with a torch optimizer.

    The model and likelihood in mll must already be in train mode.
    Note: this method requires that the model has `train_inputs` and `train_targets`.

    Args:
        mll: MarginalLogLikelihood to be maximized.
        bounds: A ParameterBounds dictionary mapping parameter names to tuples
            of lower and upper bounds. Bounds specified here take precedence
            over bounds on the same parameters specified in the constraints
            registered with the module.
        optimizer_cls: Torch optimizer to use. Must not require a closure.
        options: options for model fitting. Relevant options will be passed to
            the `optimizer_cls`. Additionally, options can include: "disp"
            to specify whether to display model fitting diagnostics and "maxiter"
            to specify the maximum number of iterations.
        track_iterations: Track the function values and wall time for each
            iteration.

    Returns:
        2-element tuple containing

        - mll with parameters optimized in-place.
        - List of OptimizationIteration objects with information on each
          iteration. If track_iterations is False, this will be an empty list.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        >>> mll.train()
        >>> fit_gpytorch_torch(mll)
        >>> mll.eval()
    """
    optim_options = {"maxiter": 100, "disp": True, "lr": 0.05}
    optim_options.update(options or {})
    optimizer = optimizer_cls(
        params=[{"params": mll.parameters()}],
        **_filter_kwargs(optimizer_cls, **optim_options),
    )

    # get bounds specified in model (if any)
    bounds_: ParameterBounds = {}
    if hasattr(mll, "named_parameters_and_constraints"):
        for param_name, _, constraint in mll.named_parameters_and_constraints():
            if constraint is not None and not constraint.enforced:
                bounds_[param_name] = constraint.lower_bound, constraint.upper_bound

    # update with user-supplied bounds (overwrites if already exists)
    if bounds is not None:
        bounds_.update(bounds)

    iterations = []
    t1 = time.time()

    param_trajectory: Dict[str, List[Tensor]] = {
        name: [] for name, param in mll.named_parameters()
    }
    loss_trajectory: List[float] = []
    i = 0
    converged = False
    train_inputs, train_targets = mll.model.train_inputs, mll.model.train_targets
    while not converged:
        optimizer.zero_grad()
        output = mll.model(*train_inputs)
        # we sum here to support batch mode
        args = [output, train_targets] + _get_extra_mll_args(mll)
        loss = -mll(*args).sum()
        loss.backward()
        loss_trajectory.append(loss.item())
        for name, param in mll.named_parameters():
            param_trajectory[name].append(param.detach().clone())
        if optim_options["disp"] and (
            (i + 1) % 10 == 0 or i == (optim_options["maxiter"] - 1)
        ):
            print(f"Iter {i + 1}/{optim_options['maxiter']}: {loss.item()}")
        if track_iterations:
            iterations.append(OptimizationIteration(i, loss.item(), time.time() - t1))
        optimizer.step()
        # project onto bounds:
        if bounds_:
            for pname, param in mll.named_parameters():
                if pname in bounds_:
                    param.data = param.data.clamp(*bounds_[pname])
        i += 1
        converged = check_convergence(
            loss_trajectory=loss_trajectory,
            param_trajectory=param_trajectory,
            options={"maxiter": optim_options["maxiter"]},
        )
    return mll, iterations
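A sketch of swapping in a different torch optimizer. Options understood by the optimizer class are forwarded via `_filter_kwargs`, so "lr" reaches SGD while "maxiter" and "disp" stay with the fitting loop; the values are illustrative:

# Usage sketch (assumed set-up): mll built as in the docstring example, in train mode.
from torch.optim import SGD

mll.train()
mll, iterations = fit_gpytorch_torch(
    mll,
    optimizer_cls=SGD,
    options={"lr": 0.01, "maxiter": 200, "disp": False},
    track_iterations=True,
)
mll.eval()
print(len(iterations))  # one OptimizationIteration per step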
Example #8
def fit_gpytorch_model(mll: MarginalLogLikelihood,
                       optimizer: Callable = fit_gpytorch_scipy,
                       **kwargs: Any) -> MarginalLogLikelihood:
    r"""Fit hyperparameters of a GPyTorch model.

    On optimizer failures, a new initial condition is sampled from the
    hyperparameter priors and optimization is retried. The maximum number of
    retries can be passed in as a `max_retries` kwarg (default is 5).

    Optimizer functions are in botorch.optim.fit.

    Args:
        mll: MarginalLogLikelihood to be maximized.
        optimizer: The optimizer function.
        kwargs: Arguments passed along to the optimizer function, including
            `max_retries` and `sequential` (controls the fitting of `ModelListGP`
            and `BatchedMultiOutputGPyTorchModel` models).

    Returns:
        MarginalLogLikelihood with optimized parameters.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        >>> fit_gpytorch_model(mll)
    """
    sequential = kwargs.pop("sequential", True)
    max_retries = kwargs.pop("max_retries", 5)
    if isinstance(mll, SumMarginalLogLikelihood) and sequential:
        for mll_ in mll.mlls:
            fit_gpytorch_model(mll=mll_,
                               optimizer=optimizer,
                               max_retries=max_retries,
                               **kwargs)
        return mll
    elif (isinstance(mll.model, BatchedMultiOutputGPyTorchModel)
          and mll.model._num_outputs > 1 and sequential):
        try:  # check if backwards-conversion is possible
            model_list = batched_to_model_list(mll.model)
            model_ = model_list_to_batched(model_list)
            mll_ = SumMarginalLogLikelihood(model_list.likelihood, model_list)
            fit_gpytorch_model(
                mll=mll_,
                optimizer=optimizer,
                sequential=True,
                max_retries=max_retries,
                **kwargs,
            )
            model_ = model_list_to_batched(mll_.model)
            mll.model.load_state_dict(model_.state_dict())
            return mll.eval()
        except (NotImplementedError, UnsupportedError, RuntimeError,
                AttributeError):
            warnings.warn(FAILED_CONVERSION_MSG, BotorchWarning)
            return fit_gpytorch_model(mll=mll,
                                      optimizer=optimizer,
                                      sequential=False,
                                      max_retries=max_retries)
    # retry with random samples from the priors upon failure
    mll.train()
    original_state_dict = deepcopy(mll.model.state_dict())
    retry = 0
    while retry < max_retries:
        with warnings.catch_warnings(record=True) as ws:
            if retry > 0:  # use normal initial conditions on first try
                mll.model.load_state_dict(original_state_dict)
                sample_all_priors(mll.model)
            mll, _ = optimizer(mll, track_iterations=False, **kwargs)
            if not any(
                    issubclass(w.category, OptimizationWarning) for w in ws):
                mll.eval()
                return mll
            retry += 1
            logging.log(logging.DEBUG, f"Fitting failed on try {retry}.")

    warnings.warn("Fitting failed on all retries.", OptimizationWarning)
    return mll.eval()
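A sketch of the sequential `ModelListGP` path handled by the `SumMarginalLogLikelihood` branch above; the training data is illustrative:

# Sketch: each sub-model's MLL is fit one after the other when sequential=True.
import torch
from botorch.models import ModelListGP, SingleTaskGP
from gpytorch.mlls import SumMarginalLogLikelihood

train_X = torch.rand(15, 2)
gp1 = SingleTaskGP(train_X, train_X.sum(-1, keepdim=True))
gp2 = SingleTaskGP(train_X, train_X.prod(-1, keepdim=True))
model_list = ModelListGP(gp1, gp2)
mll = SumMarginalLogLikelihood(model_list.likelihood, model_list)
fit_gpytorch_model(mll)  # fits gp1's and gp2's hyperparameters sequentially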
Example #9
def fit_gpytorch_torch(
    mll: MarginalLogLikelihood,
    bounds: Optional[ParameterBounds] = None,
    optimizer_cls: Optimizer = Adam,
    options: Optional[Dict[str, Any]] = None,
    track_iterations: bool = True,
) -> Tuple[MarginalLogLikelihood, List[OptimizationIteration]]:
    r"""Fit a gpytorch model by maximizing MLL with a torch optimizer.

    The model and likelihood in mll must already be in train mode.
    Note: this method requires that the model has `train_inputs` and `train_targets`.

    Args:
        mll: MarginalLogLikelihood to be maximized.
        bounds: A ParameterBounds dictionary mapping parameter names to tuples
            of lower and upper bounds. Bounds specified here take precedence
            over bounds on the same parameters specified in the constraints
            registered with the module.
        optimizer_cls: Torch optimizer to use. Must not require a closure.
        options: options for model fitting. Relevant options will be passed to
            the `optimizer_cls`. Additionally, options can include: "disp"
            to specify whether to display model fitting diagnostics and "maxiter"
            to specify the maximum number of iterations.
        track_iterations: Track the function values and wall time for each
            iteration.

    Returns:
        2-element tuple containing

        - mll with parameters optimized in-place.
        - List of OptimizationIteration objects with information on each
          iteration. If track_iterations is False, this will be an empty list.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        >>> mll.train()
        >>> fit_gpytorch_torch(mll)
        >>> mll.eval()
    """
    optim_options = {"maxiter": 100, "disp": True, "lr": 0.05}
    optim_options.update(options or {})
    optimizer = optimizer_cls(
        params=[{
            "params": mll.parameters()
        }],
        **_filter_kwargs(optimizer_cls, **optim_options),
    )

    # get bounds specified in model (if any)
    bounds_: ParameterBounds = {}
    if hasattr(mll, "named_parameters_and_constraints"):
        for param_name, _, constraint in mll.named_parameters_and_constraints():
            if constraint is not None and not constraint.enforced:
                bounds_[param_name] = constraint.lower_bound, constraint.upper_bound

    # update with user-supplied bounds (overwrites if already exists)
    if bounds is not None:
        bounds_.update(bounds)

    iterations = []
    t1 = time.time()

    param_trajectory: Dict[str, List[Tensor]] = {
        name: []
        for name, param in mll.named_parameters()
    }
    loss_trajectory: List[float] = []
    i = 0
    converged = False
    train_inputs, train_targets = mll.model.train_inputs, mll.model.train_targets
    while not converged:
        optimizer.zero_grad()
        output = mll.model(*train_inputs)
        # we sum here to support batch mode
        args = [output, train_targets] + _get_extra_mll_args(mll)
        loss = -mll(*args).sum()
        loss.backward()
        loss_trajectory.append(loss.item())
        for name, param in mll.named_parameters():
            param_trajectory[name].append(param.detach().clone())
        if optim_options["disp"] and ((i + 1) % 10 == 0
                                      or i == (optim_options["maxiter"] - 1)):
            print(f"Iter {i + 1}/{optim_options['maxiter']}: {loss.item()}")
        if track_iterations:
            iterations.append(
                OptimizationIteration(i, loss.item(),
                                      time.time() - t1))
        optimizer.step()
        # project onto bounds:
        if bounds_:
            for pname, param in mll.named_parameters():
                if pname in bounds_:
                    param.data = param.data.clamp(*bounds_[pname])
        i += 1
        converged = check_convergence(
            loss_trajectory=loss_trajectory,
            param_trajectory=param_trajectory,
            options={"maxiter": optim_options["maxiter"]},
        )
    return mll, iterations
Example #10
def fit_gpytorch_manifold(
    mll: MarginalLogLikelihood,
    bounds: Optional[ParameterBounds] = None,
    solver: Solver = pyman_solvers.ConjugateGradient(maxiter=500),
    nb_init_candidates: int = 200,
    last_x_as_candidate_prob: float = 0.9,
    options: Optional[Dict[str, Any]] = None,
    track_iterations: bool = True,
    approx_mll: bool = False,
    module_to_array_func: TModToArray = module_to_list_of_array,
    module_from_array_func: TArrayToMod = set_params_with_list_of_array,
) -> Tuple[MarginalLogLikelihood, Dict[str, Union[
        float, List[OptimizationIteration]]]]:
    """
    This function fits a gpytorch model by maximizing MLL with a pymanopt optimizer.

    The model and likelihood in mll must already be in train mode.
    This method requires that the model has `train_inputs` and `train_targets`.

    Parameters
    ----------
    :param mll: MarginalLogLikelihood to be maximized.

    Optional parameters
    -------------------
    :param nb_init_candidates: number of random initial candidates for the GP parameters
    :param last_x_as_candidate_prob: probability that the last set of parameters is among the initial candidates
    :param bounds: A dictionary mapping parameter names to tuples of lower and upper bounds.
    :param solver: Pymanopt solver.
    :param options: Dictionary of fitting options (e.g. "exclude", a list of parameter names to exclude from
        optimization).
    :param track_iterations: Track the function values and wall time for each iteration.
    :param approx_mll: If True, use gpytorch's approximate MLL computation. This is disabled by default since the
        stochasticity is an issue for deterministic optimizers. Enabling this is only recommended when working with
        large training data sets (n>2000).

    Returns
    -------
    :return: 2-element tuple containing
        - MarginalLogLikelihood with parameters optimized in-place.
        - Dictionary with the following key/values:
            "fopt": Best mll value.
            "wall_time": Wall time of fitting.
            "iterations": List of OptimizationIteration objects with information on each iteration.
                If track_iterations is False, will be empty.

    Example:
    >>> gp = SingleTaskGP(train_X, train_Y)
    >>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    >>> mll.train()
    >>> fit_gpytorch_manifold(mll)
    >>> mll.eval()
    """
    options = options or {}
    # Current parameters
    x0, property_dict, bounds = module_to_array_func(module=mll,
                                                     bounds=bounds,
                                                     exclude=options.pop(
                                                         "exclude", None))
    x0 = [x0i.astype(np.float64) for x0i in x0]
    if bounds is not None:
        warnings.warn(
            'Bounds handling not supported yet in fit_gpytorch_manifold')
        # bounds = Bounds(lb=bounds[0], ub=bounds[1], keep_feasible=True)

    t1 = time.time()

    # Define cost function
    def cost(x):
        param_dict = OrderedDict(mll.named_parameters())
        idx = 0
        for p_name, attrs in property_dict.items():
            # Construct the new tensor
            if len(attrs.shape) == 0:  # deal with scalar tensors
                # new_data = torch.tensor(x[0], dtype=attrs.dtype, device=attrs.device)
                new_data = torch.tensor(x[idx][0],
                                        dtype=attrs.dtype,
                                        device=attrs.device)
            else:
                # new_data = torch.tensor(x, dtype=attrs.dtype, device=attrs.device).view(*attrs.shape)
                new_data = torch.tensor(x[idx],
                                        dtype=attrs.dtype,
                                        device=attrs.device).view(*attrs.shape)
            param_dict[p_name].data = new_data
            idx += 1
        # mllx = set_params_with_array(mll, x, property_dict)
        train_inputs, train_targets = mll.model.train_inputs, mll.model.train_targets
        mll.zero_grad()
        output = mll.model(*train_inputs)
        args = [output, train_targets] + _get_extra_mll_args(mll)
        loss = -mll(*args).sum()
        return loss

    def egrad(x):
        loss = cost(x)
        loss.backward()
        param_dict = OrderedDict(mll.named_parameters())
        grad = []
        for p_name in property_dict:
            t = param_dict[p_name].grad
            if t is None:
                # this deals with parameters that do not affect the loss
                if len(property_dict[p_name].shape) > 1 and property_dict[p_name].shape[0] > 1:
                    # if the variable is a matrix, keep its shape
                    grad.append(np.zeros(property_dict[p_name].shape))
                else:
                    # vector/scalar case: use a flat gradient
                    grad.append(np.zeros(property_dict[p_name].shape.numel()))
            else:
                if t.ndim > 1 and t.shape[0] > 1:
                    # if the variable is a matrix, keep its shape
                    grad.append(t.detach().cpu().double().clone().numpy())
                else:
                    # vector/scalar case
                    grad.append(t.detach().view(-1).cpu().double().clone().numpy())
        return grad

    # Define the manifold (product of manifolds)
    manifolds_list = []
    for p_name, t in mll.named_parameters():
        try:
            # If a manifold is given add it
            manifolds_list.append(attrgetter(p_name + "_manifold")(mll))
        except AttributeError:
            # Otherwise, default: Euclidean
            manifolds_list.append(
                Euclidean(int(np.prod(property_dict[p_name].shape))))
    # Product of manifolds
    manifold = Product(manifolds_list)

    # Instantiate the problem on the manifold
    if track_iterations:
        verbosity = 2
    else:
        verbosity = 0

    problem = Problem(manifold=manifold,
                      cost=cost,
                      egrad=egrad,
                      verbosity=verbosity,
                      arg=torch.Tensor())  #, precon=precon)

    # For cases where the Hessian is hard/long to compute, we approximate it with finite differences of the gradient.
    # Typical cases: the Hessian can be hard to compute due to the 2nd derivative of the eigenvalue decomposition,
    # e.g. in the SPD affine-invariant distance.
    problem._hess = types.MethodType(get_hessianfd, problem)

    # Choose initial parameters
    # Do not always consider x0, to encourage variations of the parameters.
    if np.random.rand() < last_x_as_candidate_prob:
        x0_candidates = [x0]
        x0_candidates += [
            manifold.rand() for i in range(nb_init_candidates - 1)
        ]
    else:
        x0_candidates = []
        x0_candidates += [manifold.rand() for i in range(nb_init_candidates)]
    for i in range(int(3 * nb_init_candidates / 4)):
        x0_candidates[i][0:4] = x0[0:4]  #TODO remove hard-coding
    y0_candidates = [cost(x0_candidates[i]) for i in range(nb_init_candidates)]

    y_init, x_init_idx = torch.Tensor(y0_candidates).min(0)
    x_init = x0_candidates[x_init_idx]

    with gpt_settings.fast_computations(log_prob=approx_mll):
        # Logverbosity of the solver to 1
        solver._logverbosity = 1
        # Solve
        opt_x, opt_log = solver.solve(problem, x=x_init)

    # Construct info dict
    info_dict = {
        "fopt": float(cost(opt_x).detach().numpy()),
        "wall_time": time.time() - t1,
        "opt_log": opt_log,
    }
    # if not res.success:  # TODO update
    #     try:
    #         # Some res.message are bytes
    #         msg = res.message.decode("ascii")
    #     except AttributeError:
    #         # Others are str
    #         msg = res.message
    #     warnings.warn(
    #         f"Fitting failed with the optimizer reporting '{msg}'", OptimizationWarning
    #     )
    # Set to optimum
    mll = module_from_array_func(mll, opt_x, property_dict)
    return mll, info_dict
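A usage sketch with a different pymanopt solver. The import path assumes an older pymanopt release (matching the `pyman_solvers` alias in the signature), and `mll` is assumed to be built as in the other examples:

# Usage sketch (assumed set-up): fit on the product manifold with steepest descent.
from pymanopt.solvers import SteepestDescent

mll.train()
mll, info = fit_gpytorch_manifold(
    mll,
    solver=SteepestDescent(maxiter=200),
    nb_init_candidates=50,
    track_iterations=False,
)
mll.eval()
print(info["fopt"], info["wall_time"])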