Example #1
 def test_roundtrip(self):
     for dtype in (torch.float, torch.double):
         train_X = torch.rand(10, 2, device=self.device, dtype=dtype)
         train_Y1 = train_X.sum(dim=-1)
         train_Y2 = train_X[:, 0] - train_X[:, 1]
         train_Y = torch.stack([train_Y1, train_Y2], dim=-1)
         # SingleTaskGP
         batch_gp = SingleTaskGP(train_X, train_Y)
         list_gp = batched_to_model_list(batch_gp)
         batch_gp_recov = model_list_to_batched(list_gp)
         sd_orig = batch_gp.state_dict()
         sd_recov = batch_gp_recov.state_dict()
         self.assertTrue(set(sd_orig) == set(sd_recov))
         self.assertTrue(all(torch.equal(sd_orig[k], sd_recov[k]) for k in sd_orig))
         # FixedNoiseGP
         batch_gp = FixedNoiseGP(train_X, train_Y, torch.rand_like(train_Y))
         list_gp = batched_to_model_list(batch_gp)
         batch_gp_recov = model_list_to_batched(list_gp)
         sd_orig = batch_gp.state_dict()
         sd_recov = batch_gp_recov.state_dict()
         self.assertTrue(set(sd_orig) == set(sd_recov))
         self.assertTrue(all(torch.equal(sd_orig[k], sd_recov[k]) for k in sd_orig))
         # SingleTaskMultiFidelityGP
         for lin_trunc in (False, True):
             batch_gp = SingleTaskMultiFidelityGP(
                 train_X, train_Y, iteration_fidelity=1, linear_truncated=lin_trunc
             )
             list_gp = batched_to_model_list(batch_gp)
             batch_gp_recov = model_list_to_batched(list_gp)
             sd_orig = batch_gp.state_dict()
             sd_recov = batch_gp_recov.state_dict()
             self.assertTrue(set(sd_orig) == set(sd_recov))
             self.assertTrue(
                 all(torch.equal(sd_orig[k], sd_recov[k]) for k in sd_orig)
             )
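For context, a minimal standalone sketch of the round trip this test exercises; the data setup below is illustrative, not taken from the test suite:

import torch
from botorch.models import SingleTaskGP
from botorch.models.converter import batched_to_model_list, model_list_to_batched

train_X = torch.rand(10, 2)
train_Y = torch.stack([train_X.sum(dim=-1), train_X[:, 0] - train_X[:, 1]], dim=-1)

batch_gp = SingleTaskGP(train_X, train_Y)        # one batched model with 2 outputs
list_gp = batched_to_model_list(batch_gp)        # a ModelListGP of 2 single-output models
batch_gp_recov = model_list_to_batched(list_gp)  # back to a single batched model

# The conversion should be lossless: every hyperparameter survives the round trip.
sd_orig = batch_gp.state_dict()
sd_recov = batch_gp_recov.state_dict()
assert set(sd_orig) == set(sd_recov)
assert all(torch.equal(sd_orig[k], sd_recov[k]) for k in sd_orig)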
Example #2
 def test_batched_to_model_list(self):
     for dtype in (torch.float, torch.double):
         # test SingleTaskGP
         train_X = torch.rand(10, 2, device=self.device, dtype=dtype)
         train_Y1 = train_X.sum(dim=-1)
         train_Y2 = train_X[:, 0] - train_X[:, 1]
         train_Y = torch.stack([train_Y1, train_Y2], dim=-1)
         batch_gp = SingleTaskGP(train_X, train_Y)
         list_gp = batched_to_model_list(batch_gp)
         self.assertIsInstance(list_gp, ModelListGP)
         # test FixedNoiseGP
         batch_gp = FixedNoiseGP(train_X, train_Y, torch.rand_like(train_Y))
         list_gp = batched_to_model_list(batch_gp)
         self.assertIsInstance(list_gp, ModelListGP)
         # test SingleTaskMultiFidelityGP
         for lin_trunc in (False, True):
             batch_gp = SingleTaskMultiFidelityGP(
                 train_X, train_Y, iteration_fidelity=1, linear_truncated=lin_trunc
             )
             list_gp = batched_to_model_list(batch_gp)
             self.assertIsInstance(list_gp, ModelListGP)
         # test HeteroskedasticSingleTaskGP
         batch_gp = HeteroskedasticSingleTaskGP(
             train_X, train_Y, torch.rand_like(train_Y)
         )
         with self.assertRaises(NotImplementedError):
             batched_to_model_list(batch_gp)
Example #3
 def test_roundtrip(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         train_X = torch.rand(10, 2, device=device, dtype=dtype)
         train_Y1 = train_X.sum(dim=-1)
         train_Y2 = train_X[:, 0] - train_X[:, 1]
         train_Y = torch.stack([train_Y1, train_Y2], dim=-1)
         # SingleTaskGP
         batch_gp = SingleTaskGP(train_X, train_Y)
         list_gp = batched_to_model_list(batch_gp)
         batch_gp_recov = model_list_to_batched(list_gp)
         sd_orig = batch_gp.state_dict()
         sd_recov = batch_gp_recov.state_dict()
         self.assertTrue(set(sd_orig) == set(sd_recov))
         self.assertTrue(
             all(torch.equal(sd_orig[k], sd_recov[k]) for k in sd_orig))
         # FixedNoiseGP
         batch_gp = FixedNoiseGP(train_X, train_Y, torch.rand_like(train_Y))
         list_gp = batched_to_model_list(batch_gp)
         batch_gp_recov = model_list_to_batched(list_gp)
         sd_orig = batch_gp.state_dict()
         sd_recov = batch_gp_recov.state_dict()
         self.assertTrue(set(sd_orig) == set(sd_recov))
         self.assertTrue(
             all(torch.equal(sd_orig[k], sd_recov[k]) for k in sd_orig))
Example #4
 def test_get_gp_samples(self):
     # test multi-task model
     X = torch.stack([torch.rand(3), torch.tensor([1.0, 0.0, 1.0])], dim=-1)
     Y = torch.rand(3, 1)
     with self.assertRaises(NotImplementedError):
         gp_samples = get_gp_samples(
             model=MultiTaskGP(X, Y, task_feature=1),
             num_outputs=1,
             n_samples=20,
             num_rff_features=500,
         )
     tkwargs = {"device": self.device}
     for dtype, m in product((torch.float, torch.double), (1, 2)):
         tkwargs["dtype"] = dtype
         for mtype in range(2):
             model, X, Y = _get_model(**tkwargs, multi_output=m == 2)
             use_batch_model = mtype == 0 and m == 2
             gp_samples = get_gp_samples(
                 model=batched_to_model_list(model)
                 if use_batch_model else model,
                 num_outputs=m,
                 n_samples=20,
                 num_rff_features=500,
             )
             self.assertEqual(len(gp_samples), 20)
             self.assertIsInstance(gp_samples[0], DeterministicModel)
             Y_hat_rff = torch.stack(
                 [gp_sample(X) for gp_sample in gp_samples],
                 dim=0).mean(dim=0)
             with torch.no_grad():
                 Y_hat = model.posterior(X).mean
             self.assertTrue(torch.allclose(Y_hat_rff, Y_hat, atol=2e-1))
Example #5
 def test_batched_to_model_list(self, cuda=False):
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         # test SingleTaskGP
         train_X = torch.rand(10, 2, device=device, dtype=dtype)
         train_Y1 = train_X.sum(dim=-1)
         train_Y2 = train_X[:, 0] - train_X[:, 1]
         train_Y = torch.stack([train_Y1, train_Y2], dim=-1)
         batch_gp = SingleTaskGP(train_X, train_Y)
         list_gp = batched_to_model_list(batch_gp)
         self.assertIsInstance(list_gp, ModelListGP)
         # test FixedNoiseGP
         batch_gp = FixedNoiseGP(train_X, train_Y, torch.rand_like(train_Y))
         list_gp = batched_to_model_list(batch_gp)
         self.assertIsInstance(list_gp, ModelListGP)
         # test HeteroskedasticSingleTaskGP
         batch_gp = HeteroskedasticSingleTaskGP(train_X, train_Y,
                                                torch.rand_like(train_Y))
         with self.assertRaises(NotImplementedError):
             batched_to_model_list(batch_gp)
Example #6
def get_gp_samples(
    model: Model, num_outputs: int, n_samples: int, num_rff_features: int = 500
) -> List[GenericDeterministicModel]:
    r"""Sample functions from GP posterior using RFF.

    Args:
        model: The model.
        num_outputs: The number of outputs.
        n_samples: The number of sampled functions to draw.
        num_rff_features: The number of random Fourier features.

    Returns:
        A list of sampled functions.
    """
    if num_outputs > 1:
        if not isinstance(model, ModelListGP):
            models = batched_to_model_list(model).models
        else:
            models = model.models
    else:
        models = [model]
    if isinstance(models[0], MultiTaskGP):
        raise NotImplementedError

    weights = []
    bases = []
    for m in range(num_outputs):
        train_X = models[m].train_inputs[0]
        # get random Fourier features
        basis = RandomFourierFeatures(
            kernel=models[m].covar_module,
            input_dim=train_X.shape[-1],
            num_rff_features=num_rff_features,
        )
        bases.append(basis)
        phi_X = basis(train_X)
        # sample weights from the Bayesian linear model
        mvn = get_weights_posterior(
            X=phi_X,
            y=models[m].train_targets,
            sigma_sq=models[m].likelihood.noise.mean().item(),
        )
        weights.append(mvn.sample(torch.Size([n_samples])))
    # construct a deterministic, multi-output model for each sample
    models = [
        get_deterministic_model(
            weights=[weights[m][i] for m in range(num_outputs)],
            bases=bases,
        )
        for i in range(n_samples)
    ]
    return models
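A hedged usage sketch for this list-returning variant (get_gp_samples is the function defined above; the GP here is constructed but, in practice, would be fitted first):

import torch
from botorch.models import SingleTaskGP

train_X = torch.rand(10, 2)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)  # assumed: fit the MLL before sampling
gp_samples = get_gp_samples(
    model=model, num_outputs=1, n_samples=16, num_rff_features=500
)
# Each list entry approximates one function drawn from the posterior; stacking
# the evaluations gives a (16, 5, 1) tensor of sample paths.
X_test = torch.rand(5, 2)
Y_draws = torch.stack([f(X_test) for f in gp_samples], dim=0)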
Example #7
 def test_batched_to_model_list(self):
     for dtype in (torch.float, torch.double):
         # test SingleTaskGP
         train_X = torch.rand(10, 2, device=self.device, dtype=dtype)
         train_Y1 = train_X.sum(dim=-1)
         train_Y2 = train_X[:, 0] - train_X[:, 1]
         train_Y = torch.stack([train_Y1, train_Y2], dim=-1)
         batch_gp = SingleTaskGP(train_X, train_Y)
         list_gp = batched_to_model_list(batch_gp)
         self.assertIsInstance(list_gp, ModelListGP)
         # test FixedNoiseGP
         batch_gp = FixedNoiseGP(train_X, train_Y, torch.rand_like(train_Y))
         list_gp = batched_to_model_list(batch_gp)
         self.assertIsInstance(list_gp, ModelListGP)
         # test SingleTaskMultiFidelityGP
         for lin_trunc in (False, True):
             batch_gp = SingleTaskMultiFidelityGP(
                 train_X,
                 train_Y,
                 iteration_fidelity=1,
                 linear_truncated=lin_trunc)
             list_gp = batched_to_model_list(batch_gp)
             self.assertIsInstance(list_gp, ModelListGP)
         # test HeteroskedasticSingleTaskGP
         batch_gp = HeteroskedasticSingleTaskGP(train_X, train_Y,
                                                torch.rand_like(train_Y))
         with self.assertRaises(NotImplementedError):
             batched_to_model_list(batch_gp)
         # test with transforms
         input_tf = Normalize(
             d=2,
             bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]],
                                 device=self.device,
                                 dtype=dtype),
         )
         octf = Standardize(m=2)
         batch_gp = SingleTaskGP(train_X,
                                 train_Y,
                                 outcome_transform=octf,
                                 input_transform=input_tf)
         list_gp = batched_to_model_list(batch_gp)
         for i, m in enumerate(list_gp.models):
             self.assertIsInstance(m.input_transform, Normalize)
             self.assertTrue(
                 torch.equal(m.input_transform.bounds, input_tf.bounds))
             self.assertIsInstance(m.outcome_transform, Standardize)
             self.assertEqual(m.outcome_transform._m, 1)
             expected_octf = octf.subset_output(idcs=[i])
             for attr_name in ["means", "stdvs", "_stdvs_sq"]:
                 self.assertTrue(
                     torch.equal(
                         m.outcome_transform.__getattr__(attr_name),
                         expected_octf.__getattr__(attr_name),
                     ))
Example #8
 def test_get_gp_samples(self):
     # test multi-task model
     with torch.random.fork_rng():
         torch.manual_seed(0)
         X = torch.stack(
             [torch.rand(3), torch.tensor([1.0, 0.0, 1.0])], dim=-1)
         Y = torch.rand(3, 1)

     with self.assertRaises(NotImplementedError):
         gp_samples = get_gp_samples(
             model=MultiTaskGP(X, Y, task_feature=1),
             num_outputs=1,
             n_samples=20,
             num_rff_features=500,
         )
     tkwargs = {"device": self.device}
     for dtype, m in product((torch.float, torch.double), (1, 2)):
         tkwargs["dtype"] = dtype
         for mtype in [True, False]:
             model, X, Y = _get_model(**tkwargs, multi_output=m == 2)
             use_batch_model = mtype and m == 2
             with torch.random.fork_rng():
                 torch.manual_seed(0)
                 gp_samples = get_gp_samples(
                     model=batched_to_model_list(model)
                     if use_batch_model else model,
                     num_outputs=m,
                     n_samples=20,
                     num_rff_features=500,
                 )
             self.assertEqual(len(gp_samples(X)), 20)
             self.assertIsInstance(gp_samples, DeterministicModel)
             Y_hat_rff = gp_samples(X).mean(dim=0)
             with torch.no_grad():
                 Y_hat = model.posterior(X).mean
             self.assertTrue(torch.allclose(Y_hat_rff, Y_hat, atol=2e-1))

             # test batched evaluation
             Y_batched = gp_samples(
                 torch.randn(13, 20, 3, X.shape[-1], **tkwargs))
             self.assertEqual(Y_batched.shape, torch.Size([13, 20, 3, m]))

     # test incorrect batch shape check
     with self.assertRaises(ValueError):
         gp_samples(torch.randn(13, 23, 3, X.shape[-1], **tkwargs))
Example #9
 def test_batched_to_model_list(self):
     for dtype in (torch.float, torch.double):
         # test SingleTaskGP
         train_X = torch.rand(10, 2, device=self.device, dtype=dtype)
         train_Y1 = train_X.sum(dim=-1)
         train_Y2 = train_X[:, 0] - train_X[:, 1]
         train_Y = torch.stack([train_Y1, train_Y2], dim=-1)
         batch_gp = SingleTaskGP(train_X, train_Y)
         list_gp = batched_to_model_list(batch_gp)
         self.assertIsInstance(list_gp, ModelListGP)
         # test FixedNoiseGP
         batch_gp = FixedNoiseGP(train_X, train_Y, torch.rand_like(train_Y))
         list_gp = batched_to_model_list(batch_gp)
         self.assertIsInstance(list_gp, ModelListGP)
         # test SingleTaskMultiFidelityGP
         for lin_trunc in (False, True):
             batch_gp = SingleTaskMultiFidelityGP(
                 train_X,
                 train_Y,
                 iteration_fidelity=1,
                 linear_truncated=lin_trunc)
             list_gp = batched_to_model_list(batch_gp)
             self.assertIsInstance(list_gp, ModelListGP)
         # test HeteroskedasticSingleTaskGP
         batch_gp = HeteroskedasticSingleTaskGP(train_X, train_Y,
                                                torch.rand_like(train_Y))
         with self.assertRaises(NotImplementedError):
             batched_to_model_list(batch_gp)
         # test input transform
         input_tf = Normalize(
             d=2,
             bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]],
                                 device=self.device,
                                 dtype=dtype),
         )
         batch_gp = SingleTaskGP(train_X, train_Y, input_transform=input_tf)
         list_gp = batched_to_model_list(batch_gp)
         for m in list_gp.models:
             self.assertIsInstance(m.input_transform, Normalize)
             self.assertTrue(
                 torch.equal(m.input_transform.bounds, input_tf.bounds))
Example #10
def fit_gpytorch_model(mll: MarginalLogLikelihood,
                       optimizer: Callable = fit_gpytorch_scipy,
                       **kwargs: Any) -> MarginalLogLikelihood:
    r"""Fit hyperparameters of a GPyTorch model.

    On optimizer failures, a new initial condition is sampled from the
    hyperparameter priors and optimization is retried. The maximum number of
    retries can be passed in as a `max_retries` kwarg (default is 5).

    Optimizer functions are in botorch.optim.fit.

    Args:
        mll: MarginalLogLikelihood to be maximized.
        optimizer: The optimizer function.
        kwargs: Arguments passed along to the optimizer function, including
            `max_retries` and `sequential` (controls the fitting of `ModelListGP`
            and `BatchedMultiOutputGPyTorchModel` models) or `approx_mll`
            (whether to use gpytorch's approximate MLL computation).

    Returns:
        MarginalLogLikelihood with optimized parameters.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        >>> fit_gpytorch_model(mll)
    """
    sequential = kwargs.pop("sequential", True)
    max_retries = kwargs.pop("max_retries", 5)
    if isinstance(mll, SumMarginalLogLikelihood) and sequential:
        for mll_ in mll.mlls:
            fit_gpytorch_model(mll=mll_,
                               optimizer=optimizer,
                               max_retries=max_retries,
                               **kwargs)
        return mll
    elif (isinstance(mll.model, BatchedMultiOutputGPyTorchModel)
          and mll.model._num_outputs > 1 and sequential):
        tf = None
        try:  # check if backwards-conversion is possible
            # remove the outcome transform since the training targets are already
            # transformed and the outcome transform cannot currently be split.
            # TODO: support splitting outcome transforms.
            if hasattr(mll.model, "outcome_transform"):
                tf = mll.model.outcome_transform
                mll.model.outcome_transform = None
            model_list = batched_to_model_list(mll.model)
            mll_ = SumMarginalLogLikelihood(model_list.likelihood, model_list)
            fit_gpytorch_model(
                mll=mll_,
                optimizer=optimizer,
                sequential=True,
                max_retries=max_retries,
                **kwargs,
            )
            model_ = model_list_to_batched(mll_.model)
            mll.model.load_state_dict(model_.state_dict())
            # setting the transformed inputs is necessary because gpytorch
            # stores the raw training inputs on the ExactGP in the
            # ExactGP.__init__ call. At evaluation time, the test inputs will
            # already be in the transformed space if some transforms have
            # transform_on_eval set to False. ExactGP.__call__ will
            # concatenate the test points with the training inputs. Therefore,
            # it is important to set the ExactGP's train_inputs to also be
            # transformed data using all transforms (including the transforms
            # with transform_on_train set to True).
            mll.train()
            if tf is not None:
                mll.model.outcome_transform = tf
            return mll.eval()
        # NotImplementedError is omitted since it derives from RuntimeError
        except (UnsupportedError, RuntimeError, AttributeError):
            warnings.warn(FAILED_CONVERSION_MSG, BotorchWarning)
            if tf is not None:
                mll.model.outcome_transform = tf
            return fit_gpytorch_model(mll=mll,
                                      optimizer=optimizer,
                                      sequential=False,
                                      max_retries=max_retries)
    # retry with random samples from the priors upon failure
    mll.train()
    original_state_dict = deepcopy(mll.model.state_dict())
    retry = 0
    while retry < max_retries:
        with warnings.catch_warnings(record=True) as ws:
            if retry > 0:  # use normal initial conditions on first try
                mll.model.load_state_dict(original_state_dict)
                sample_all_priors(mll.model)
            try:
                mll, _ = optimizer(mll, track_iterations=False, **kwargs)
            except NotPSDError:
                retry += 1
                logging.log(
                    logging.DEBUG,
                    f"Fitting failed on try {retry} due to a NotPSDError.",
                )
                continue
        has_optwarning = False
        for w in ws:
            # Do not count reaching `maxiter` as an optimization failure.
            if "ITERATIONS REACHED LIMIT" in str(w.message):
                logging.log(
                    logging.DEBUG,
                    "Fitting ended early due to reaching the iteration limit.",
                )
                continue
            has_optwarning |= issubclass(w.category, OptimizationWarning)
            warnings.warn(w.message, w.category)
        if not has_optwarning:
            mll.eval()
            return mll
        retry += 1
        logging.log(logging.DEBUG, f"Fitting failed on try {retry}.")

    warnings.warn("Fitting failed on all retries.", OptimizationWarning)
    return mll.eval()
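A short usage sketch, mirroring the docstring example and showing the kwargs this implementation consumes (the training data is assumed):

from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

gp = SingleTaskGP(train_X, train_Y)  # train_Y may have several outcome columns
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
# sequential=True (the default) fits a multi-output model outcome-by-outcome via
# the batched-to-list conversion above; max_retries bounds the prior-resampling loop.
fit_gpytorch_model(mll, sequential=True, max_retries=10)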
Example #11
def get_gp_samples(
    model: Model, num_outputs: int, n_samples: int, num_rff_features: int = 500
) -> GenericDeterministicModel:
    r"""Sample functions from GP posterior using RFFs. The returned
    `GenericDeterministicModel` effectively wraps `num_outputs` models,
    each of which has a batch shape of `n_samples`. Refer to
    `get_deterministic_model_multi_samples` for more details.

    Args:
        model: The model.
        num_outputs: The number of outputs.
        n_samples: The number of functions to be sampled IID.
        num_rff_features: The number of random Fourier features.

    Returns:
        A batched `GenericDeterministicModel` that batch evaluates `n_samples`
        sampled functions.
    """
    if num_outputs > 1:
        if not isinstance(model, ModelListGP):
            models = batched_to_model_list(model).models
        else:
            models = model.models
    else:
        models = [model]
    if isinstance(models[0], MultiTaskGP):
        raise NotImplementedError

    weights = []
    bases = []
    for m in range(num_outputs):
        train_X = models[m].train_inputs[0]
        train_targets = models[m].train_targets
        # get random Fourier features
        # sample_shape controls the number of iid functions.
        basis = RandomFourierFeatures(
            kernel=models[m].covar_module,
            input_dim=train_X.shape[-1],
            num_rff_features=num_rff_features,
            sample_shape=torch.Size([n_samples]),
        )
        bases.append(basis)
        # TODO: when batched kernels are supported in RandomFourierFeatures,
        # the following code can be uncommented.
        # if train_X.ndim > 2:
        #    batch_shape_train_X = train_X.shape[:-2]
        #    dataset_shape = train_X.shape[-2:]
        #    train_X = train_X.unsqueeze(-3).expand(
        #        *batch_shape_train_X, n_samples, *dataset_shape
        #    )
        #    train_targets = train_targets.unsqueeze(-2).expand(
        #        *batch_shape_train_X, n_samples, dataset_shape[0]
        #    )
        phi_X = basis(train_X)
        # Sample weights from the Bayesian linear model
        # 1. When inputs are not batched, train_X.shape == (n, d)
        # weights.sample().shape == (n_samples, num_rff_features)
        # 2. When inputs are batched, train_X.shape == (batch_shape_input, n, d)
        # This is expanded to (batch_shape_input, n_samples, n, d)
        # to maintain compatibility with RFF forward semantics
        # weights.sample().shape == (batch_shape_input, n_samples, num_rff_features)
        mvn = get_weights_posterior(
            X=phi_X,
            y=train_targets,
            sigma_sq=models[m].likelihood.noise.mean().item(),
        )
        weights.append(mvn.sample())

    # TODO: Ideally support RFFs for multi-outputs instead of having to
    # generate a basis for each output serially.
    return get_deterministic_model_multi_samples(weights=weights, bases=bases)
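A usage sketch for this batched variant; per the tests earlier in this listing, evaluating the returned model on an (n, d) input produces an (n_samples, n, m) tensor, so the posterior mean can be estimated by averaging over the leading sample dimension. The fitted two-output `model` and its `train_X` are assumed:

gp_sample_model = get_gp_samples(
    model=model, num_outputs=2, n_samples=20, num_rff_features=500
)
X_test = torch.rand(7, train_X.shape[-1])
Y_draws = gp_sample_model(X_test)  # shape (20, 7, 2): n_samples x n x m
Y_hat = Y_draws.mean(dim=0)        # Monte Carlo estimate of the posterior mean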
Example #12
def fit_gpytorch_model(mll: MarginalLogLikelihood,
                       optimizer: Callable = fit_gpytorch_scipy,
                       **kwargs: Any) -> MarginalLogLikelihood:
    r"""Fit hyperparameters of a GPyTorch model.

    On optimizer failures, a new initial condition is sampled from the
    hyperparameter priors and optimization is retried. The maximum number of
    retries can be passed in as a `max_retries` kwarg (default is 5).

    Optimizer functions are in botorch.optim.fit.

    Args:
        mll: MarginalLogLikelihood to be maximized.
        optimizer: The optimizer function.
        kwargs: Arguments passed along to the optimizer function, including
            `max_retries` and `sequential` (controls the fitting of `ModelListGP`
            and `BatchedMultiOutputGPyTorchModel` models) or `approx_mll`
            (whether to use gpytorch's approximate MLL computation).

    Returns:
        MarginalLogLikelihood with optimized parameters.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        >>> fit_gpytorch_model(mll)
    """
    sequential = kwargs.pop("sequential", True)
    max_retries = kwargs.pop("max_retries", 5)
    if isinstance(mll, SumMarginalLogLikelihood) and sequential:
        for mll_ in mll.mlls:
            fit_gpytorch_model(mll=mll_,
                               optimizer=optimizer,
                               max_retries=max_retries,
                               **kwargs)
        return mll
    elif (isinstance(mll.model, BatchedMultiOutputGPyTorchModel)
          and mll.model._num_outputs > 1 and sequential):
        try:  # check if backwards-conversion is possible
            model_list = batched_to_model_list(mll.model)
            model_ = model_list_to_batched(model_list)
            mll_ = SumMarginalLogLikelihood(model_list.likelihood, model_list)
            fit_gpytorch_model(
                mll=mll_,
                optimizer=optimizer,
                sequential=True,
                max_retries=max_retries,
                **kwargs,
            )
            model_ = model_list_to_batched(mll_.model)
            mll.model.load_state_dict(model_.state_dict())
            return mll.eval()
        # NotImplementedError is omitted since it derives from RuntimeError
        except (UnsupportedError, RuntimeError, AttributeError):
            warnings.warn(FAILED_CONVERSION_MSG, BotorchWarning)
            return fit_gpytorch_model(mll=mll,
                                      optimizer=optimizer,
                                      sequential=False,
                                      max_retries=max_retries)
    # retry with random samples from the priors upon failure
    mll.train()
    original_state_dict = deepcopy(mll.model.state_dict())
    retry = 0
    while retry < max_retries:
        with warnings.catch_warnings(record=True) as ws:
            if retry > 0:  # use normal initial conditions on first try
                mll.model.load_state_dict(original_state_dict)
                sample_all_priors(mll.model)
            mll, _ = optimizer(mll, track_iterations=False, **kwargs)
            if not any(
                    issubclass(w.category, OptimizationWarning) for w in ws):
                mll.eval()
                return mll
            retry += 1
            logging.log(logging.DEBUG, f"Fitting failed on try {retry}.")

    warnings.warn("Fitting failed on all retries.", OptimizationWarning)
    return mll.eval()