Exemplo n.º 1
0
    def get_and_fit_model(
        self,
        Xs: List[Tensor],
        Ys: List[Tensor],
        Yvars: List[Tensor],
        task_features: List[int],
        fidelity_features: List[int],
        metric_names: List[str],
        state_dict: Optional[Dict[str, Tensor]] = None,
        fidelity_model_id: Optional[int] = None,
        **kwargs: Any,
    ) -> ModelListGP:
        """Get a fitted multi-task contextual GP model for each outcome.

        One sub-model is built per entry of ``Xs``. When every observed noise
        variance of an outcome is NaN an ``LCEMGP`` (no ``train_Yvar``) is
        used; otherwise a ``FixedNoiseLCEMGP`` is built with the variances
        clamped from below at ``MIN_OBSERVED_NOISE_LEVEL``. All sub-models are
        wrapped in a ``ModelListGP`` and fitted jointly through a
        ``SumMarginalLogLikelihood``.

        The tensors in ``Yvars`` are not modified; clamping is done on a copy.

        Args:
            Xs: List of X data, one tensor per outcome.
            Ys: List of Y data, one tensor per outcome.
            Yvars: List of noise variances of Y, one tensor per outcome.
            task_features: ``task_features[i]`` is passed as ``task_feature``
                to the sub-model of outcome ``i``.
            fidelity_features: Not used by this method.
            metric_names: Not used by this method.
            state_dict: Not used by this method.
            fidelity_model_id: Not used by this method.
            **kwargs: Not used by this method.

        Returns:
            Fitted multi-task contextual GP model (a ``ModelListGP``).
        """
        models = []
        for i, X in enumerate(Xs):
            raw_Yvar = Yvars[i]
            # Inspect the raw variances *before* clamping so the "no noise
            # observed" decision does not depend on how clamp treats NaN.
            if torch.isnan(raw_Yvar).all():
                gp_m = LCEMGP(
                    train_X=X,
                    train_Y=Ys[i],
                    task_feature=task_features[i],
                    context_cat_feature=self.context_cat_feature,
                    context_emb_feature=self.context_emb_feature,
                    embs_dim_list=self.embs_dim_list,
                )
            else:
                gp_m = FixedNoiseLCEMGP(
                    train_X=X,
                    train_Y=Ys[i],
                    # Out-of-place clamp: leave the caller's tensor untouched.
                    train_Yvar=raw_Yvar.clamp_min(MIN_OBSERVED_NOISE_LEVEL),
                    task_feature=task_features[i],
                    context_cat_feature=self.context_cat_feature,
                    context_emb_feature=self.context_emb_feature,
                    embs_dim_list=self.embs_dim_list,
                )
            models.append(gp_m)
        # Use a ModelListGP
        model = ModelListGP(*models)
        # Match the dtype/device of the training inputs.
        model.to(Xs[0])
        mll = SumMarginalLogLikelihood(model.likelihood, model)
        fit_gpytorch_model(mll)
        return model
Exemplo n.º 2
0
def _get_single_task_gpytorch_model(
    Xs: List[Tensor],
    Ys: List[Tensor],
    Yvars: List[Tensor],
    task_features: List[int],
    fidelity_features: List[int],
    state_dict: Optional[Dict[str, Tensor]] = None,
    num_samples: int = 512,
    thinning: int = 16,
    use_input_warping: bool = False,
    gp_kernel: str = "matern",
    **kwargs: Any,
) -> ModelListGP:
    r"""Build an (unfitted) batched ``ModelListGP`` for MCMC-based fitting.

    Every training tensor is given a leading batch dimension of size
    ``num_samples // thinning`` so that each batch entry can later hold one
    MCMC sample. Fitting itself is not performed here; per the original
    design note, MCMC is run separately with pyro and the samples are loaded
    into the returned model afterwards.

    Args:
        Xs: List of X data, one tensor per outcome.
        Ys: List of Y data, one tensor per outcome.
        Yvars: List of observed variances, one tensor per outcome.
        task_features: Must be empty; multi-task models are rejected.
        fidelity_features: Must be empty; multi-fidelity models are rejected.
        state_dict: Not used by this function.
        num_samples: Total number of MCMC samples.
        thinning: Thinning factor; ``num_samples // thinning`` samples are kept.
        use_input_warping: Forwarded to ``_get_model``.
        gp_kernel: When equal to ``"rbf"`` an RBF kernel from
            ``_get_rbf_kernel`` is supplied to each sub-model; otherwise
            ``covar_module=None`` is passed.
        **kwargs: Forwarded to ``_get_model``.

    Returns:
        A ModelListGP.

    Raises:
        NotImplementedError: If ``task_features`` or ``fidelity_features``
            is non-empty.
    """
    if len(task_features) > 0:
        raise NotImplementedError(
            "Currently do not support MT-GP models with MCMC!")
    if len(fidelity_features) > 0:
        raise NotImplementedError(
            "Fidelity MF-GP models are not currently supported with MCMC!")

    n_kept = num_samples // thinning

    def _add_sample_dim(t: Tensor) -> Tensor:
        # Prepend a batch dimension of size `n_kept` (as an expanded view).
        return t.unsqueeze(0).expand(n_kept, t.shape[0], -1)

    # One kernel (or None) per outcome, all created before any sub-model.
    want_rbf = gp_kernel == "rbf"
    kernels = []
    for _ in Xs:
        if want_rbf:
            kernels.append(
                _get_rbf_kernel(num_samples=n_kept, dim=Xs[0].shape[-1]))
        else:
            kernels.append(None)

    submodels = []
    for X, Y, Yvar, kernel in zip(Xs, Ys, Yvars, kernels):
        submodels.append(
            _get_model(
                X=_add_sample_dim(X),
                Y=_add_sample_dim(Y),
                Yvar=_add_sample_dim(Yvar),
                fidelity_features=fidelity_features,
                use_input_warping=use_input_warping,
                covar_module=kernel,
                **kwargs,
            ))

    model = ModelListGP(*submodels)
    model.to(Xs[0])
    return model
Exemplo n.º 3
0
    def get_and_fit_model(
        self,
        Xs: List[Tensor],
        Ys: List[Tensor],
        Yvars: List[Tensor],
        task_features: List[int],
        fidelity_features: List[int],
        metric_names: List[str],
        state_dict: Optional[Dict[str, Tensor]] = None,
        fidelity_model_id: Optional[int] = None,
        **kwargs: Any,
    ) -> GPyTorchModel:
        """Get a fitted LCEAGP model for each outcome.

        One model is obtained from ``get_map_model`` per entry of ``Xs``. A
        single outcome yields that model directly; several outcomes are
        wrapped in a ``ModelListGP``.

        The tensors in ``Yvars`` are not modified; the noise variances are
        clamped from below at ``MIN_OBSERVED_NOISE_LEVEL`` on a copy.

        Args:
            Xs: X for each outcome.
            Ys: Y for each outcome.
            Yvars: Noise variance of Y for each outcome.
            task_features: Not used by this method.
            fidelity_features: Not used by this method.
            metric_names: Not used by this method.
            state_dict: Not used by this method.
            fidelity_model_id: Not used by this method.
            **kwargs: Not used by this method.

        Returns:
            Fitted LCEAGP model, or a ``ModelListGP`` of them.
        """
        # generate model space decomposition dict
        decomp_index = generate_model_space_decomposition(
            decomposition=self.decomposition, feature_names=self.feature_names
        )

        models = []
        for i, X in enumerate(Xs):
            # Out-of-place clamp so the caller's Yvars are left untouched.
            Yvar = Yvars[i].clamp_min(MIN_OBSERVED_NOISE_LEVEL)
            # Only the model (first element of the returned pair) is needed.
            gp_m, _ = get_map_model(
                train_X=X,
                train_Y=Ys[i],
                train_Yvar=Yvar,
                decomposition=decomp_index,
                train_embedding=self.train_embedding,
                cat_feature_dict=self.cat_feature_dict,
                embs_feature_dict=self.embs_feature_dict,
                embs_dim_list=self.embs_dim_list,
                context_weight_dict=self.context_weight_dict,
            )
            models.append(gp_m)

        if len(models) == 1:
            model = models[0]
        else:
            model = ModelListGP(*models)
        # Match the dtype/device of the training inputs.
        model.to(Xs[0])
        return model
Exemplo n.º 4
0
    def construct(self, training_data: List[TrainingData],
                  **kwargs: Any) -> None:
        """Build the underlying ``ModelListGP`` from per-outcome training data.

        Args:
            training_data: List of ``TrainingData``, one per outcome, in the
                same order as the metrics in ``metric_names``.
            **kwargs: Keyword arguments, accepts:
                - ``metric_names`` (required): Names of metrics, in the same
                order as training data (so if training data is
                ``[tr_A, tr_B]``, the metrics are ``["A" and "B"]``). Used to
                pair each training data with its submodel,
                - ``fidelity_features``: Indices of columns in X that represent
                fidelity (defaults to an empty list),
                - ``task_features``: Indices of columns in X that represent
                tasks (defaults to an empty list).

        Raises:
            ValueError: If ``metric_names`` is missing, or no model class is
                configured for one of the metrics.
        """
        metric_names = kwargs.get(Keys.METRIC_NAMES)
        if metric_names is None:
            raise ValueError("Metric names are required.")
        fidelity_features = kwargs.get(Keys.FIDELITY_FEATURES, [])
        task_features = kwargs.get(Keys.TASK_FEATURES, [])

        # Pair each metric with its training data, positionally.
        self._training_data_per_outcome = dict(zip(metric_names, training_data))

        built = []
        for m in metric_names:
            # Per-outcome class wins; fall back to the shared default class.
            model_cls = self.botorch_submodel_class_per_outcome.get(
                m, self.botorch_submodel_class)
            if not model_cls:
                raise ValueError(f"No model class specified for outcome {m}.")

            if m not in self.training_data_per_outcome:  # pragma: no cover
                logger.info(f"Metric {m} not in training data.")
                continue

            # NOTE: a fresh dict is created on every iteration so the shared
            # `self.submodel_options` mapping itself is never modified. This
            # is a shallow copy: the values stored in it are still shared.
            options = dict(self.submodel_options)
            options.update(self.submodel_options_per_outcome.get(m, {}))

            model_inputs = model_cls.construct_inputs(
                training_data=self.training_data_per_outcome[m],
                fidelity_features=fidelity_features,
                task_features=task_features,
                **options,
            )
            # pyre-ignore[45]: Py raises informative error if model is abstract.
            built.append(model_cls(**model_inputs))
        self._model = ModelListGP(*built)
Exemplo n.º 5
0
def batched_to_model_list(
        batch_model: BatchedMultiOutputGPyTorchModel) -> ModelListGP:
    """Convert a BatchedMultiOutputGPyTorchModel to a ModelListGP.

    For every output ``i`` a new model of the same class as ``batch_model``
    is created from the ``i``-th slice of the training data, and the
    matching slice of the batched state dict is loaded into it.

    Args:
        batch_model: The `BatchedMultiOutputGPyTorchModel` to be converted to
            a `ModelListGP`.

    Returns:
        The model converted into a `ModelListGP`.

    Raises:
        NotImplementedError: If ``batch_model`` is a
            ``HeteroskedasticSingleTaskGP``.

    Example:
        >>> train_X = torch.rand(5, 2)
        >>> train_Y = torch.rand(5, 2)
        >>> batch_gp = SingleTaskGP(train_X, train_Y)
        >>> list_gp = batched_to_model_list(batch_gp)
    """
    # TODO: Add support for HeteroskedasticSingleTaskGP
    if isinstance(batch_model, HeteroskedasticSingleTaskGP):
        raise NotImplementedError(
            "Conversion of HeteroskedasticSingleTaskGP currently not supported."
        )
    full_state = batch_model.state_dict()
    # Slicing happens along the dimension that follows the input batch dims.
    slice_dim = len(batch_model._input_batch_shape)
    has_fixed_noise = isinstance(batch_model, FixedNoiseGP)
    is_multi_fidelity = isinstance(batch_model, SingleTaskMultiFidelityGP)

    def _take(tensor, index):
        # Independent copy of entry `index` along `slice_dim`.
        return tensor.select(slice_dim, index).clone()

    submodels = []
    for idx in range(batch_model._num_outputs):
        sub_state = {}
        for key, value in full_state.items():
            # Zero-dim entries and `active_dims` entries are copied whole;
            # every other entry is sliced for this output.
            if len(value.shape) == 0 or "active_dims" in key:
                sub_state[key] = value.clone()
            else:
                sub_state[key] = _take(value, idx)

        init_kwargs = {
            "train_X": _take(batch_model.train_inputs[0], idx),
            "train_Y": _take(batch_model.train_targets, idx).unsqueeze(-1),
        }
        if has_fixed_noise:
            observed_noise = batch_model.likelihood.noise_covar.noise
            init_kwargs["train_Yvar"] = _take(observed_noise,
                                              idx).unsqueeze(-1)
        if is_multi_fidelity:
            init_kwargs.update(batch_model._init_args)

        submodel = batch_model.__class__(**init_kwargs)
        submodel.load_state_dict(sub_state)
        submodels.append(submodel)

    return ModelListGP(*submodels)
Exemplo n.º 6
0
    def get_and_fit_model(
        self,
        Xs: List[Tensor],
        Ys: List[Tensor],
        Yvars: List[Tensor],
        state_dicts: Optional[List[MutableMapping[str, Tensor]]] = None,
    ) -> GPyTorchModel:
        """Get a fitted ALEBO model for each outcome.

        The tensors in ``Yvars`` are not modified; the noise variances are
        clamped from below at ``1e-7`` on a copy before fitting.

        Args:
            Xs: X for each outcome, already projected down.
            Ys: Y for each outcome.
            Yvars: Noise variance of Y for each outcome.
            state_dicts: State dicts to initialize model fitting. When given,
                fitting is treated as warm-started and uses a single restart;
                otherwise ``self.fit_restarts`` restarts are used.

        Returns:
            Fitted ALEBO model for a single outcome, or a ``ModelListGP`` of
            them for several outcomes.
        """
        if state_dicts is None:
            state_dicts = [None] * len(Xs)
            fit_restarts = self.fit_restarts
        else:
            fit_restarts = 1  # Warm-started
        # Out-of-place clamp so the caller's tensors are left untouched.
        clamped_Yvars = [Yvar.clamp_min(1e-7) for Yvar in Yvars]
        models = [
            get_fitted_model(
                B=self.B,
                train_X=X,
                train_Y=Ys[i],
                train_Yvar=clamped_Yvars[i],
                restarts=fit_restarts,
                nsamp=self.laplace_nsamp,
                # pyre-fixme[6]: Expected `Optional[Dict[str, Tensor]]` for 7th
                #  param but got `Optional[MutableMapping[str, Tensor]]`.
                init_state_dict=state_dicts[i],
            )
            for i, X in enumerate(Xs)
        ]
        if len(models) == 1:
            model = models[0]
        else:
            model = ModelListGP(*models)
        # Match the dtype/device of the training inputs.
        model.to(Xs[0])
        return model
Exemplo n.º 7
0
    def get_and_fit_model(
        self,
        Xs: List[Tensor],
        Ys: List[Tensor],
        Yvars: List[Tensor],
        task_features: List[int],
        fidelity_features: List[int],
        metric_names: List[str],
        state_dict: Optional[Dict[str, Tensor]] = None,
        fidelity_model_id: Optional[int] = None,
        **kwargs: Any,
    ) -> GPyTorchModel:
        """Get a fitted StructuralAdditiveContextualGP model for each outcome.

        Each outcome gets its own ``SACGP`` that is fitted individually with
        an ``ExactMarginalLogLikelihood``. A single outcome yields that model
        directly; several outcomes are wrapped in a ``ModelListGP``.

        The tensors in ``Yvars`` are not modified; the noise variances are
        clamped from below at ``MIN_OBSERVED_NOISE_LEVEL`` on a copy.

        Args:
            Xs: X for each outcome.
            Ys: Y for each outcome.
            Yvars: Noise variance of Y for each outcome.
            task_features: Not used by this method.
            fidelity_features: Not used by this method.
            metric_names: Not used by this method.
            state_dict: Not used by this method.
            fidelity_model_id: Not used by this method.
            **kwargs: Not used by this method.

        Returns:
            Fitted StructuralAdditiveContextualGP model, or a ``ModelListGP``
            of them.
        """
        # generate model space decomposition dict
        decomp_index = generate_model_space_decomposition(
            decomposition=self.decomposition, feature_names=self.feature_names)

        models = []
        for i, X in enumerate(Xs):
            # Out-of-place clamp so the caller's Yvars are left untouched.
            Yvar = Yvars[i].clamp_min(MIN_OBSERVED_NOISE_LEVEL)
            gp_m = SACGP(X, Ys[i], Yvar, decomp_index)
            mll = ExactMarginalLogLikelihood(gp_m.likelihood, gp_m)
            fit_gpytorch_model(mll)
            models.append(gp_m)

        if len(models) == 1:
            model = models[0]
        else:
            model = ModelListGP(*models)
        # Match the dtype/device of the training inputs.
        model.to(Xs[0])
        return model
Exemplo n.º 8
0
    def construct(self, training_data: List[TrainingData],
                  **kwargs: Any) -> None:
        """Build the underlying `ModelListGP` from per-outcome training data.

        A submodel is created for every entry of
        `self.botorch_model_class_per_outcome` that has training data.

        Args:
            training_data: List of `TrainingData`, one per outcome, in the
                same order as the metrics in `metric_names`.
            **kwargs: Keyword arguments, accepts:
                - `metric_names` (required): Names of metrics, in the same order
                as training data (so if training data is `[tr_A, tr_B]`, the
                metrics would be `["A" and "B"]`). Used to pair each training
                data with its submodel,
                - `fidelity_features`: Indices of columns in X that represent
                fidelity (defaults to an empty list),
                - `task_features`: Indices of columns in X that represent
                tasks (defaults to an empty list).

        Raises:
            ValueError: If `metric_names` is missing.
        """
        metric_names = kwargs.get(Keys.METRIC_NAMES)
        if metric_names is None:
            raise ValueError("Metric names are required.")
        fidelity_features = kwargs.get(Keys.FIDELITY_FEATURES, [])
        task_features = kwargs.get(Keys.TASK_FEATURES, [])

        # Pair each metric with its training data, positionally.
        self._training_data_per_outcome = dict(zip(metric_names, training_data))
        options_per_outcome = self.submodel_options_per_outcome or {}

        built = []
        class_per_outcome = self.botorch_model_class_per_outcome
        for metric_name, model_cls in class_per_outcome.items():
            if metric_name not in self.training_data_per_outcome:
                continue  # pragma: no cover
            model_inputs = model_cls.construct_inputs(
                training_data=self.training_data_per_outcome[metric_name],
                fidelity_features=fidelity_features,
                task_features=task_features,
            )
            # Extra constructor options configured for this outcome, if any.
            extra_options = options_per_outcome.get(metric_name, {})
            # pyre-ignore[45]: Py raises informative msg if `model_cls` abstract.
            built.append(model_cls(**model_inputs, **extra_options))
        self._model = ModelListGP(*built)
Exemplo n.º 9
0
def get_and_fit_model(
    Xs: List[Tensor],
    Ys: List[Tensor],
    Yvars: List[Tensor],
    task_features: List[int],
    fidelity_features: List[int],
    state_dict: Optional[Dict[str, Tensor]] = None,
    refit_model: bool = True,
    **kwargs: Any,
) -> GPyTorchModel:
    r"""Instantiates and fits a botorch model using the given data.

    Model selection:
      * no task feature and a single outcome -> one model from ``_get_model``;
      * no task feature and identical ``Xs`` for all outcomes -> one model
        built on ``Ys`` / ``Yvars`` concatenated along the last dimension;
      * otherwise -> a ``ModelListGP`` with one sub-model per outcome.

    Args:
        Xs: List of X data, one tensor per outcome
        Ys: List of Y data, one tensor per outcome
        Yvars: List of observed variance of Ys.
        task_features: List of columns of X that are tasks (at most one).
        fidelity_features: List of columns of X that are fidelity parameters
            (at most one, and not together with a task feature).
        state_dict: If provided, will set model parameters to this state
            dictionary. Otherwise, will fit the model.
        refit_model: Flag for refitting model. Only matters when
            ``state_dict`` is provided.

    Returns:
        A fitted GPyTorchModel.

    Raises:
        NotImplementedError: For unsupported combinations of task and
            fidelity features.
    """
    n_fidelity = len(fidelity_features)
    n_task = len(task_features)
    if n_fidelity > 0 and n_task > 0:
        raise NotImplementedError(
            "Currently do not support MF-GP models with task_features!"
        )
    if n_fidelity > 1:
        raise NotImplementedError(
            "Fidelity MF-GP models currently support only a single fidelity parameter!"
        )
    if n_task > 1:
        raise NotImplementedError(
            f"This model only supports 1 task feature (got {task_features})"
        )
    task_feature = task_features[0] if n_task == 1 else None

    model = None
    if task_feature is None:
        single_outcome = len(Xs) == 1
        if single_outcome or all(torch.equal(Xs[0], X) for X in Xs[1:]):
            if single_outcome:
                # Use single output, single task GP
                joint_Y, joint_Yvar = Ys[0], Yvars[0]
            else:
                # Use batched multioutput, single task GP
                joint_Y = torch.cat(Ys, dim=-1)
                joint_Yvar = torch.cat(Yvars, dim=-1)
            model = _get_model(
                X=Xs[0],
                Y=joint_Y,
                Yvar=joint_Yvar,
                task_feature=task_feature,
                fidelity_features=fidelity_features,
                **kwargs,
            )
    if model is None:
        # Use a ModelListGP
        # NOTE(review): `fidelity_features` is not forwarded on this path,
        # unlike the two calls above - confirm this is intended.
        submodels = []
        for X, Y, Yvar in zip(Xs, Ys, Yvars):
            submodels.append(
                _get_model(
                    X=X, Y=Y, Yvar=Yvar, task_feature=task_feature, **kwargs
                )
            )
        model = ModelListGP(*submodels)
    model.to(Xs[0])

    if state_dict is not None:
        model.load_state_dict(state_dict)
    if state_dict is None or refit_model:
        # TODO: Add bounds for optimization stability - requires revamp upstream
        if isinstance(model, ModelListGP):
            mll_cls = SumMarginalLogLikelihood
        else:
            mll_cls = ExactMarginalLogLikelihood
        # pyre-ignore: [16]
        mll = mll_cls(model.likelihood, model)
        fit_gpytorch_model(mll, bounds={})
    return model
Exemplo n.º 10
0
    def testALEBOGP(self):
        """Exercise the ALEBO GP helpers on a tiny 3-point data set.

        Covers, in order: the MAP model from ``get_map_model``, the batched
        model from ``get_batch_model``, the combined ``get_fitted_model``
        path, and ``extract_map_statedict`` for a single model and for a
        ``ModelListGP``.
        """
        n_points = torch.Size([3])
        n_batch = torch.Size([5])
        expected_keys = {
            "covar_module.base_kernel.Uvec",
            "covar_module.raw_outputscale",
            "mean_module.constant",
        }

        def check_output_shapes(dist):
            # Mean/variance have one entry per test point; covariance is square.
            self.assertEqual(dist.mean.shape, n_points)
            self.assertEqual(dist.variance.shape, n_points)
            self.assertEqual(dist.covariance_matrix.shape, torch.Size([3, 3]))

        def check_map_statedict(sd):
            self.assertEqual(len(sd), 3)
            self.assertEqual(set(sd), expected_keys)
            self.assertEqual(sd["covar_module.base_kernel.Uvec"].shape,
                             n_points)

        # Shared training data, passed by keyword to every helper below.
        data = {
            "B": torch.tensor(
                [[1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 3.0, 4.0, 5.0, 6.0]],
                dtype=torch.double),
            "train_X": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                                    dtype=torch.double),
            "train_Y": torch.tensor([[1.0], [2.0], [3.0]],
                                    dtype=torch.double),
            "train_Yvar": torch.full((3, 1), 0.1, dtype=torch.double),
        }

        # First non-batch
        map_mll = get_map_model(restarts=1, init_state_dict=None, **data)
        m = map_mll.model
        m.eval()
        self.assertIsInstance(m, ALEBOGP)
        self.assertIsInstance(m.covar_module.base_kernel, ALEBOKernel)

        X = torch.tensor([[2.0, 2.0], [3.0, 3.0], [4.0, 4.0]],
                         dtype=torch.double)
        check_output_shapes(m(X))

        # Batch
        kernel = m.covar_module.base_kernel
        m_b = get_batch_model(
            Uvec_batch=kernel.Uvec.repeat(5, 1),
            mean_constant_batch=m.mean_module.constant.repeat(5, 1),
            output_scale_batch=m.covar_module.raw_outputscale.repeat(5),
            **data,
        )
        self.assertEqual(m_b._aug_batch_shape, n_batch)
        check_output_shapes(m_b(X))
        self.assertEqual(
            m_b.posterior(X).mvn.covariance_matrix.shape, torch.Size([3, 3]))

        # The whole process in get_fitted_model
        m_b2 = get_fitted_model(
            restarts=1,
            nsamp=5,
            init_state_dict=m.state_dict(),
            **data,
        )
        self.assertEqual(m_b2._aug_batch_shape, n_batch)

        # Test extract_map_statedict
        single_sds = extract_map_statedict(m_b=m_b, num_outputs=1)
        self.assertEqual(len(single_sds), 1)
        check_map_statedict(single_sds[0])

        model_list = ModelListGP(m_b, m_b2)
        list_sds = extract_map_statedict(m_b=model_list, num_outputs=2)
        self.assertEqual(len(list_sds), 2)
        for position in range(2):
            check_map_statedict(list_sds[position])
0
def get_and_fit_model(
    Xs: List[Tensor],
    Ys: List[Tensor],
    Yvars: List[Tensor],
    task_features: List[int],
    fidelity_features: List[int],
    metric_names: List[str],
    state_dict: Optional[Dict[str, Tensor]] = None,
    refit_model: bool = True,
    **kwargs: Any,
) -> GPyTorchModel:
    r"""Instantiates and fits a botorch GPyTorchModel using the given data.

    N.B. The choice between ModelListGP and other models is currently made
    by the if-else statements in the body of this function. In the future,
    this logic should be taken care of by modular botorch.

    Model selection:
      * no task feature and a single outcome -> one model from ``_get_model``;
      * no task feature and identical ``Xs`` for all outcomes -> one model
        built on ``Ys`` / ``Yvars`` concatenated along the last dimension;
      * otherwise -> a ``ModelListGP`` with one sub-model per outcome, each
        receiving the ``rank`` looked up for its metric in the optional
        ``multitask_gp_ranks`` keyword argument (``None`` when absent).

    Args:
        Xs: List of X data, one tensor per outcome.
        Ys: List of Y data, one tensor per outcome.
        Yvars: List of observed variance of Ys.
        task_features: List of columns of X that are tasks (at most one).
        fidelity_features: List of columns of X that are fidelity parameters
            (at most one, and not together with a task feature).
        metric_names: Names of each outcome Y in Ys.
        state_dict: If provided, will set model parameters to this state
            dictionary. Otherwise, will fit the model.
        refit_model: Flag for refitting model. Only matters when
            ``state_dict`` is provided.

    Returns:
        A fitted GPyTorchModel.

    Raises:
        NotImplementedError: For unsupported combinations of task and
            fidelity features.
        ValueError: On the ModelListGP path, if ``Xs``, ``Ys``, ``Yvars`` and
            ``metric_names`` do not all have the same length.
    """
    n_fidelity = len(fidelity_features)
    n_task = len(task_features)
    if n_fidelity > 0 and n_task > 0:
        raise NotImplementedError(
            "Currently do not support MF-GP models with task_features!")
    if n_fidelity > 1:
        raise NotImplementedError(
            "Fidelity MF-GP models currently support only a single fidelity parameter!"
        )
    if n_task > 1:
        raise NotImplementedError(
            f"This model only supports 1 task feature (got {task_features})")
    task_feature = task_features[0] if n_task == 1 else None

    # TODO: Better logic for deciding when to use a ModelListGP. Currently the
    # logic is unclear. The two cases in which ModelListGP is used are
    # (i) the training inputs (Xs) are not the same for the different outcomes, and
    # (ii) a multi-task model is used
    model = None
    if task_feature is None:
        single_outcome = len(Xs) == 1
        if single_outcome or all(torch.equal(Xs[0], X) for X in Xs[1:]):
            if single_outcome:
                # Use single output, single task GP
                joint_Y, joint_Yvar = Ys[0], Yvars[0]
            else:
                # Use batched multioutput, single task GP
                joint_Y = torch.cat(Ys, dim=-1)
                joint_Yvar = torch.cat(Yvars, dim=-1)
            model = _get_model(
                X=Xs[0],
                Y=joint_Y,
                Yvar=joint_Yvar,
                task_feature=task_feature,
                fidelity_features=fidelity_features,
                **kwargs,
            )
    # TODO: Is this equivalent an "else:" here?

    if model is None:  # use multi-task GP
        rank_by_metric = kwargs.pop("multitask_gp_ranks", {})
        if len({len(Xs), len(Ys), len(Yvars), len(metric_names)}) > 1:
            raise ValueError(
                "Lengths of Xs, Ys, Yvars, and metric_names must match. Your "
                f"inputs have lengths {len(Xs)}, {len(Ys)}, {len(Yvars)}, and "
                f"{len(metric_names)}, respectively.")
        # NOTE(review): `fidelity_features` is not forwarded on this path,
        # unlike the two calls above - confirm this is intended.
        submodels = []
        for X, Y, Yvar, metric in zip(Xs, Ys, Yvars, metric_names):
            submodels.append(
                _get_model(X=X,
                           Y=Y,
                           Yvar=Yvar,
                           task_feature=task_feature,
                           # rank configured for this metric, or None
                           rank=rank_by_metric.get(metric, None),
                           **kwargs))
        model = ModelListGP(*submodels)
    model.to(Xs[0])

    if state_dict is not None:
        model.load_state_dict(state_dict)
    if state_dict is None or refit_model:
        # TODO: Add bounds for optimization stability - requires revamp upstream
        if isinstance(model, ModelListGP):
            mll_cls = SumMarginalLogLikelihood
        else:
            mll_cls = ExactMarginalLogLikelihood
        # pyre-ignore: [16]
        mll = mll_cls(model.likelihood, model)
        fit_gpytorch_model(mll, bounds={})
    return model
Exemplo n.º 12
0
    def test_sample_cached_cholesky(self):
        torch.manual_seed(0)
        tkwargs = {"device": self.device}
        for dtype in (torch.float, torch.double):
            tkwargs["dtype"] = dtype
            train_X = torch.rand(10, 2, **tkwargs)
            train_Y = torch.randn(10, 2, **tkwargs)
            for m in (1, 2):
                model_list_values = (True, False) if m == 2 else (False, )
                for use_model_list in model_list_values:
                    if use_model_list:
                        model = ModelListGP(
                            SingleTaskGP(
                                train_X,
                                train_Y[..., :1],
                            ),
                            SingleTaskGP(
                                train_X,
                                train_Y[..., 1:],
                            ),
                        )
                    else:
                        model = SingleTaskGP(
                            train_X,
                            train_Y[:, :m],
                        )
                    sampler = IIDNormalSampler(3)
                    base_sampler = IIDNormalSampler(3)
                    for q in (1, 3, 9):
                        # test batched baseline_L
                        for train_batch_shape in (
                                torch.Size([]),
                                torch.Size([3]),
                                torch.Size([3, 2]),
                        ):
                            # test batched test points
                            for test_batch_shape in (
                                    torch.Size([]),
                                    torch.Size([4]),
                                    torch.Size([4, 2]),
                            ):

                                if len(train_batch_shape) > 0:
                                    train_X_ex = train_X.unsqueeze(0).expand(
                                        train_batch_shape + train_X.shape)
                                else:
                                    train_X_ex = train_X
                                if len(test_batch_shape) > 0:
                                    test_X = train_X_ex.unsqueeze(0).expand(
                                        test_batch_shape + train_X_ex.shape)
                                else:
                                    test_X = train_X_ex
                                with torch.no_grad():
                                    base_posterior = model.posterior(
                                        train_X_ex[..., :-q, :])
                                    mvn = base_posterior.mvn
                                    lazy_covar = mvn.lazy_covariance_matrix
                                    if m == 2:
                                        lazy_covar = lazy_covar.base_lazy_tensor
                                    baseline_L = lazy_covar.root_decomposition(
                                    )
                                    baseline_L = baseline_L.root.evaluate()
                                test_X = test_X.clone().requires_grad_(True)
                                new_posterior = model.posterior(test_X)
                                samples = sampler(new_posterior)
                                samples[..., -q:, :].sum().backward()
                                test_X2 = test_X.detach().clone(
                                ).requires_grad_(True)
                                new_posterior2 = model.posterior(test_X2)
                                q_samples = sample_cached_cholesky(
                                    posterior=new_posterior2,
                                    baseline_L=baseline_L,
                                    q=q,
                                    base_samples=sampler.base_samples.detach().
                                    clone(),
                                    sample_shape=sampler.sample_shape,
                                )
                                q_samples.sum().backward()
                                all_close_kwargs = ({
                                    "atol": 1e-4,
                                    "rtol": 1e-2,
                                } if dtype == torch.float else {})
                                self.assertTrue(
                                    torch.allclose(
                                        q_samples.detach(),
                                        samples[..., -q:, :].detach(),
                                        **all_close_kwargs,
                                    ))
                                self.assertTrue(
                                    torch.allclose(
                                        test_X2.grad[..., -q:, :],
                                        test_X.grad[..., -q:, :],
                                        **all_close_kwargs,
                                    ))
                                # Test that adding a new point and base_sample
                                # did not change posterior samples for previous points.
                                # This tests that we properly account for not
                                # interleaving.
                                base_sampler.base_samples = (
                                    sampler.base_samples[
                                        ..., :-q, :].detach().clone())

                                baseline_samples = base_sampler(base_posterior)
                                new_batch_shape = samples.shape[
                                    1:-baseline_samples.ndim + 1]
                                expanded_baseline_samples = baseline_samples.view(
                                    baseline_samples.shape[0],
                                    *[1] * len(new_batch_shape),
                                    *baseline_samples.shape[1:],
                                ).expand(
                                    baseline_samples.shape[0],
                                    *new_batch_shape,
                                    *baseline_samples.shape[1:],
                                )
                                self.assertTrue(
                                    torch.allclose(
                                        expanded_baseline_samples,
                                        samples[..., :-q, :],
                                        **all_close_kwargs,
                                    ))
                            # test nans
                            with torch.no_grad():
                                test_posterior = model.posterior(test_X2)
                            test_posterior.mvn.loc = torch.full_like(
                                test_posterior.mvn.loc, float("nan"))
                            with self.assertRaises(NanError):
                                sample_cached_cholesky(
                                    posterior=test_posterior,
                                    baseline_L=baseline_L,
                                    q=q,
                                    base_samples=sampler.base_samples.detach().
                                    clone(),
                                    sample_shape=sampler.sample_shape,
                                )
                            # test infs
                            test_posterior.mvn.loc = torch.full_like(
                                test_posterior.mvn.loc, float("inf"))
                            with self.assertRaises(NanError):
                                sample_cached_cholesky(
                                    posterior=test_posterior,
                                    baseline_L=baseline_L,
                                    q=q,
                                    base_samples=sampler.base_samples.detach().
                                    clone(),
                                    sample_shape=sampler.sample_shape,
                                )
                            # test triangular solve raising RuntimeError
                            test_posterior.mvn.loc = torch.full_like(
                                test_posterior.mvn.loc, 0.0)
                            base_samples = sampler.base_samples.detach().clone(
                            )
                            with mock.patch(
                                    "botorch.utils.low_rank.torch.triangular_solve",
                                    side_effect=RuntimeError("singular"),
                            ):
                                with self.assertRaises(NotPSDError):
                                    sample_cached_cholesky(
                                        posterior=test_posterior,
                                        baseline_L=baseline_L,
                                        q=q,
                                        base_samples=base_samples,
                                        sample_shape=sampler.sample_shape,
                                    )
                            with mock.patch(
                                    "botorch.utils.low_rank.torch.triangular_solve",
                                    side_effect=RuntimeError(""),
                            ):
                                with self.assertRaises(RuntimeError):
                                    sample_cached_cholesky(
                                        posterior=test_posterior,
                                        baseline_L=baseline_L,
                                        q=q,
                                        base_samples=base_samples,
                                        sample_shape=sampler.sample_shape,
                                    )
Exemplo n.º 13
0
    def get_and_fit_model(
        self,
        Xs: List[Tensor],
        Ys: List[Tensor],
        Yvars: List[Tensor],
        task_features: List[int],
        fidelity_features: List[int],
        metric_names: List[str],
        state_dict: Optional[Dict[str, Tensor]] = None,
        fidelity_model_id: Optional[int] = None,
        **kwargs: Any,
    ) -> ModelListGP:
        """Get a fitted multi-task contextual GP model for each outcome.

        Args:
            Xs: List of X data, one tensor per outcome.
            Ys: List of Y data, one tensor per outcome.
            Yvars: List of observed noise variances, one tensor per outcome.
                An outcome whose variances are all NaN gets an `LCEMGP`;
                any other outcome gets a `FixedNoiseLCEMGP` that receives the
                variances clamped from below at `MIN_OBSERVED_NOISE_LEVEL`.
                The tensors passed in are left unmodified.
            task_features: List of columns of X that are tasks. Exactly one
                entry is required.
            fidelity_features: Not used by this method.
            metric_names: Not used by this method.
            state_dict: Not used by this method.
            fidelity_model_id: Not used by this method.
            **kwargs: Not used by this method.

        Returns:
            ModelListGP in which each model is a fitted LCEM GP model.

        Raises:
            NotImplementedError: If more than one task feature is given.
            ValueError: If no task feature is given.
        """

        if len(task_features) == 1:
            task_feature = task_features[0]
        elif len(task_features) > 1:
            raise NotImplementedError(
                f"LCEMBO only supports 1 task feature (got {task_features})")
        else:
            raise ValueError("LCEMBO requires context input as task features")

        models = []
        for i, X in enumerate(Xs):
            # Clamp out-of-place: the in-place variant would overwrite the
            # caller's observed-noise tensors as a hidden side effect.
            Yvar = Yvars[i].clamp_min(MIN_OBSERVED_NOISE_LEVEL)
            # The all-NaN test is evaluated on the clamped values, exactly as
            # before; only the mutation of the input has been removed.
            if torch.all(torch.isnan(Yvar)):
                gp_m = LCEMGP(
                    train_X=X,
                    train_Y=Ys[i],
                    task_feature=task_feature,
                    context_cat_feature=self.context_cat_feature,
                    context_emb_feature=self.context_emb_feature,
                    embs_dim_list=self.embs_dim_list,
                )
            else:
                gp_m = FixedNoiseLCEMGP(
                    train_X=X,
                    train_Y=Ys[i],
                    train_Yvar=Yvar,
                    task_feature=task_feature,
                    context_cat_feature=self.context_cat_feature,
                    context_emb_feature=self.context_emb_feature,
                    embs_dim_list=self.embs_dim_list,
                )
            models.append(gp_m)
        # Use a ModelListGP
        model = ModelListGP(*models)
        # Passing a tensor to `.to` moves the model to that tensor's
        # dtype and device.
        model.to(Xs[0])
        mll = SumMarginalLogLikelihood(model.likelihood, model)
        fit_gpytorch_model(mll)
        return model
Exemplo n.º 14
0
    def test_multi_objective_max_value_entropy(self):
        """Check construction and evaluation of qMultiObjectiveMaxValueEntropy.

        Runs for float and double dtypes and for 2 and 3 outcomes, covering a
        batched SingleTaskGP, a ModelListGP, plain and batched evaluation, and
        construction with pending points.
        """
        for dtype, m in product((torch.float, torch.double), (2, 3)):
            torch.manual_seed(7)
            # test batched model
            train_X = torch.rand(1, 1, 2, dtype=dtype, device=self.device)
            train_Y = torch.rand(1, 1, m, dtype=dtype, device=self.device)
            model = SingleTaskGP(train_X, train_Y)
            with self.assertRaises(NotImplementedError):
                qMultiObjectiveMaxValueEntropy(model,
                                               dummy_sample_pareto_frontiers)
            # test initialization
            train_X = torch.rand(4, 2, dtype=dtype, device=self.device)
            train_Y = torch.rand(4, m, dtype=dtype, device=self.device)
            # test batched MO model
            model = SingleTaskGP(train_X, train_Y)
            mesmo = qMultiObjectiveMaxValueEntropy(
                model, dummy_sample_pareto_frontiers)
            self.assertEqual(mesmo.num_fantasies, 16)
            self.assertIsInstance(mesmo.sampler, SobolQMCNormalSampler)
            self.assertEqual(mesmo.sampler.sample_shape, torch.Size([512]))
            self.assertIsInstance(mesmo.fantasies_sampler,
                                  SobolQMCNormalSampler)
            self.assertEqual(mesmo.posterior_max_values.shape,
                             torch.Size([3, 1, m]))
            # test conversion to single-output model
            self.assertIs(mesmo.mo_model, model)
            self.assertEqual(mesmo.mo_model.num_outputs, m)
            self.assertIsInstance(mesmo.model, SingleTaskGP)
            self.assertEqual(mesmo.model.num_outputs, 1)
            self.assertEqual(mesmo.model._aug_batch_shape,
                             mesmo.model._input_batch_shape)
            # test ModelListGP
            model = ModelListGP(
                *
                [SingleTaskGP(train_X, train_Y[:, i:i + 1]) for i in range(m)])
            mock_sample_pfs = mock.Mock()
            mock_sample_pfs.return_value = dummy_sample_pareto_frontiers(
                model=model)
            mesmo = qMultiObjectiveMaxValueEntropy(model, mock_sample_pfs)
            self.assertEqual(mesmo.num_fantasies, 16)
            self.assertIsInstance(mesmo.sampler, SobolQMCNormalSampler)
            self.assertEqual(mesmo.sampler.sample_shape, torch.Size([512]))
            self.assertIsInstance(mesmo.fantasies_sampler,
                                  SobolQMCNormalSampler)
            self.assertEqual(mesmo.posterior_max_values.shape,
                             torch.Size([3, 1, m]))
            # test conversion to batched MO model
            self.assertIsInstance(mesmo.mo_model, SingleTaskGP)
            self.assertEqual(mesmo.mo_model.num_outputs, m)
            self.assertIs(mesmo.mo_model, mesmo._init_model)
            # test conversion to single-output model
            self.assertIsInstance(mesmo.model, SingleTaskGP)
            self.assertEqual(mesmo.model.num_outputs, 1)
            self.assertEqual(mesmo.model._aug_batch_shape,
                             mesmo.model._input_batch_shape)
            # test that we call sample_pareto_frontiers with the multi-output model
            mock_sample_pfs.assert_called_once_with(mesmo.mo_model)
            # test basic evaluation
            X = torch.rand(1, 2, device=self.device, dtype=dtype)
            with torch.no_grad():
                vals = mesmo(X)
                igs = qMaxValueEntropy.forward(mesmo, X=X.view(1, 1, 1, 2))
            self.assertEqual(vals.shape, torch.Size([1]))
            self.assertTrue(torch.equal(vals, igs.sum(dim=-1)))

            # test batched evaluation
            X = torch.rand(4, 1, 2, device=self.device, dtype=dtype)
            with torch.no_grad():
                vals = mesmo(X)
                igs = qMaxValueEntropy.forward(mesmo, X=X.view(4, 1, 1, 2))
            self.assertEqual(vals.shape, torch.Size([4]))
            self.assertTrue(torch.equal(vals, igs.sum(dim=-1)))

            # test set X pending to None
            mesmo.set_X_pending(None)
            self.assertIs(mesmo.mo_model, mesmo._init_model)
            # The extra fantasy points must live on the same device and use
            # the same dtype as the training data they are concatenated with;
            # default (CPU, float32) tensors break torch.cat on other devices.
            fant_X = torch.cat(
                [
                    train_X.expand(16, 4, 2),
                    torch.rand(16, 1, 2, device=self.device, dtype=dtype),
                ],
                dim=1,
            )
            fant_Y = torch.cat(
                [
                    train_Y.expand(16, 4, m),
                    torch.rand(16, 1, m, device=self.device, dtype=dtype),
                ],
                dim=1,
            )
            fantasy_model = SingleTaskGP(fant_X, fant_Y)

            # test with X_pending is not None
            with mock.patch.object(
                    SingleTaskGP, "fantasize",
                    return_value=fantasy_model) as mock_fantasize:
                qMultiObjectiveMaxValueEntropy(
                    model,
                    dummy_sample_pareto_frontiers,
                    X_pending=torch.rand(1, 2, device=self.device,
                                         dtype=dtype),
                )
                mock_fantasize.assert_called_once()
Exemplo n.º 15
0
def batched_to_model_list(
        batch_model: BatchedMultiOutputGPyTorchModel) -> ModelListGP:
    """Split a batched multi-output model into a `ModelListGP`.

    One model of the same class as `batch_model` is built per output, and its
    parameters are copied over from the batched model's state dict.

    Args:
        batch_model: The `BatchedMultiOutputGPyTorchModel` to be converted to a
            `ModelListGP`.

    Returns:
        The model converted into a `ModelListGP`.

    Raises:
        NotImplementedError: If `batch_model` is a
            `HeteroskedasticSingleTaskGP` or a `MixedSingleTaskGP`.

    Example:
        >>> train_X = torch.rand(5, 2)
        >>> train_Y = torch.rand(5, 2)
        >>> batch_gp = SingleTaskGP(train_X, train_Y)
        >>> list_gp = batched_to_model_list(batch_gp)
    """
    # TODO: Add support for HeteroskedasticSingleTaskGP.
    unsupported = (
        (
            HeteroskedasticSingleTaskGP,
            "Conversion of HeteroskedasticSingleTaskGP is currently not supported.",
        ),
        (
            MixedSingleTaskGP,
            "Conversion of MixedSingleTaskGP is currently not supported.",
        ),
    )
    for model_cls, message in unsupported:
        if isinstance(batch_model, model_cls):
            raise NotImplementedError(message)

    input_transform = getattr(batch_model, "input_transform", None)
    outcome_transform = getattr(batch_model, "outcome_transform", None)
    batch_sd = batch_model.state_dict()

    adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
        batch_state_dict=batch_sd,
        input_transform=input_transform,
        outcome_transform=outcome_transform,
    )
    # Per-output slices are taken along the dimension that directly follows
    # the input batch dimensions.
    input_bdims = len(batch_model._input_batch_shape)
    has_fixed_noise = isinstance(batch_model, FixedNoiseGP)
    is_multi_fidelity = isinstance(batch_model, SingleTaskMultiFidelityGP)

    converted = []
    for idx in range(batch_model._num_outputs):
        # Entries that need no adjustment are copied verbatim.
        state = {key: batch_sd[key].clone() for key in non_adjusted_batch_keys}
        # Adjusted entries are reduced to this output's slice, except for
        # "active_dims" entries, which are copied whole.
        for key in adjusted_batch_keys:
            tensor = batch_sd[key]
            if "active_dims" not in key:
                tensor = tensor.select(input_bdims, idx)
            state[key] = tensor.clone()

        train_X = batch_model.train_inputs[0].select(input_bdims, idx).clone()
        train_Y = batch_model.train_targets.select(input_bdims, idx).clone()
        init_kwargs = {"train_X": train_X, "train_Y": train_Y.unsqueeze(-1)}
        if has_fixed_noise:
            noise = batch_model.likelihood.noise_covar.noise
            init_kwargs["train_Yvar"] = (
                noise.select(input_bdims, idx).clone().unsqueeze(-1))
        if is_multi_fidelity:
            init_kwargs.update(batch_model._init_args)

        # NOTE: The outcome transform is written into the kwargs last so it
        # takes precedence over any entry coming from `_init_args`, avoiding
        # the multiple-values-for-same-kwarg issue with
        # SingleTaskMultiFidelityGP.
        sub_octf = None
        if outcome_transform is not None:
            sub_octf = outcome_transform.subset_output(idcs=[idx])
            # Update the outcome transform state dict entries.
            for name, value in sub_octf.state_dict().items():
                state["outcome_transform." + name] = value
        init_kwargs["outcome_transform"] = sub_octf

        single_model = batch_model.__class__(input_transform=input_transform,
                                             **init_kwargs)
        single_model.load_state_dict(state)
        converted.append(single_model)

    return ModelListGP(*converted)
Exemplo n.º 16
0
def get_and_fit_model_mcmc(
    Xs: List[Tensor],
    Ys: List[Tensor],
    Yvars: List[Tensor],
    task_features: List[int],
    fidelity_features: List[int],
    metric_names: List[str],
    state_dict: Optional[Dict[str, Tensor]] = None,
    refit_model: bool = True,
    use_input_warping: bool = False,
    use_loocv_pseudo_likelihood: bool = False,
    num_samples: int = 512,
    warmup_steps: int = 1024,
    thinning: int = 16,
    max_tree_depth: int = 6,
    use_saas: bool = False,
    disable_progbar: bool = False,
    **kwargs: Any,
) -> GPyTorchModel:
    """Build GP model(s) and set their hyperparameters from MCMC samples.

    A single model is built when there is one outcome, otherwise one model
    per outcome wrapped in a `ModelListGP`. The training data of each model is
    broadcast to a leading batch dimension of `num_samples // thinning`.

    Args:
        Xs: List of X data, one tensor per outcome.
        Ys: List of Y data, one tensor per outcome.
        Yvars: List of observed variances of Ys, one tensor per outcome.
        task_features: Must be empty; multi-task models are not supported.
        fidelity_features: Must be empty; fidelity models are not supported.
        metric_names: Not used by this function.
        state_dict: If provided, loaded into the model before any inference.
        refit_model: If False and `state_dict` is provided, inference is
            skipped and the model with the loaded state is returned.
        use_input_warping: Forwarded to `_get_model` and `run_inference`.
        use_loocv_pseudo_likelihood: Not used by this function.
        num_samples: Forwarded to `run_inference`.
        warmup_steps: Forwarded to `run_inference`.
        thinning: Forwarded to `run_inference`.
        max_tree_depth: Forwarded to `run_inference`.
        use_saas: Forwarded to `run_inference`.
        disable_progbar: Forwarded to `run_inference`.
        **kwargs: Forwarded to `_get_model`.

    Returns:
        The model (a `ModelListGP` when there are several outcomes).

    Raises:
        NotImplementedError: If task or fidelity features are given.
    """
    if len(task_features) > 0:
        raise NotImplementedError(
            "Currently do not support MT-GP models with MCMC!")
    if len(fidelity_features) > 0:
        raise NotImplementedError(
            "Fidelity MF-GP models are not currently supported with MCMC!")
    # TODO: Better logic for deciding when to use a ModelListGP. Currently the
    # logic is unclear. The two cases in which ModelListGP is used are
    # (i) the training inputs (Xs) are not the same for the different outcomes, and
    # (ii) a multi-task model is used

    # Presumably the number of draws kept after thinning -- confirm against
    # `run_inference`.
    num_mcmc_samples = num_samples // thinning

    def _batched(data: Tensor, num_rows: int) -> Tensor:
        # Prepend a batch dimension and broadcast it to `num_mcmc_samples`.
        return data.unsqueeze(0).expand(num_mcmc_samples, num_rows, -1)

    def _fresh(sample: Tensor, shape: torch.Size) -> Tensor:
        # Detached copy of a draw, viewed in the target parameter's shape.
        return sample.detach().clone().view(shape)

    if len(Xs) == 1:
        # Use single output, single task GP
        num_rows = Xs[0].shape[0]
        model = _get_model(
            X=_batched(Xs[0], num_rows),
            Y=_batched(Ys[0], num_rows),
            Yvar=_batched(Yvars[0], num_rows),
            fidelity_features=fidelity_features,
            use_input_warping=use_input_warping,
            **kwargs,
        )
    else:
        model = ModelListGP(*[
            _get_model(
                X=_batched(X, X.shape[0]).clone(),
                Y=_batched(Y, Y.shape[0]).clone(),
                Yvar=_batched(Yvar, Yvar.shape[0]).clone(),
                use_input_warping=use_input_warping,
                **kwargs,
            ) for X, Y, Yvar in zip(Xs, Ys, Yvars)
        ])
    model.to(Xs[0])
    models = model.models if isinstance(model, ModelListGP) else [model]

    if state_dict is not None:
        # pyre-fixme[6]: Expected `OrderedDict[typing.Any, typing.Any]` for 1st
        #  param but got `Dict[str, Tensor]`.
        model.load_state_dict(state_dict)
        if not refit_model:
            # A state was supplied and no refit was requested.
            return model

    for X, Y, Yvar, m in zip(Xs, Ys, Yvars, models):
        samples = run_inference(
            pyro_model=pyro_model,  # pyre-ignore [6]
            X=X,
            Y=Y,
            Yvar=Yvar,
            num_samples=num_samples,
            warmup_steps=warmup_steps,
            thinning=thinning,
            use_input_warping=use_input_warping,
            use_saas=use_saas,
            max_tree_depth=max_tree_depth,
            disable_progbar=disable_progbar,
        )
        if "noise" in samples:
            noise_covar = m.likelihood.noise_covar
            noise_covar.noise = _fresh(
                samples["noise"],
                noise_covar.noise.shape,
            ).clamp_min(MIN_INFERRED_NOISE_LEVEL)
        m.covar_module.base_kernel.lengthscale = _fresh(
            samples["lengthscale"],
            m.covar_module.base_kernel.lengthscale.shape,
        )
        m.covar_module.outputscale = _fresh(
            samples["outputscale"],
            m.covar_module.outputscale.shape,
        )
        m.mean_module.constant.data = _fresh(
            samples["mean"],
            m.mean_module.constant.shape,
        )
        if "c0" in samples:
            # "c1" is read whenever "c0" is present.
            warp = m.input_transform
            warp._set_concentration(
                i=0,
                value=_fresh(samples["c0"], warp.concentration0.shape),
            )
            warp._set_concentration(
                i=1,
                value=_fresh(samples["c1"], warp.concentration1.shape),
            )
    return model
Exemplo n.º 17
0
def get_and_fit_model(
    Xs: List[Tensor],
    Ys: List[Tensor],
    Yvars: List[Tensor],
    task_features: List[int],
    state_dict: Optional[Dict[str, Tensor]] = None,
    **kwargs: Any,
) -> GPyTorchModel:
    r"""Instantiates a botorch model from the given data and fits or loads it.

    Without a task feature, a single model is used when there is one outcome
    or when all outcomes share identical X data (the Ys and Yvars are then
    concatenated along the last dimension). In every other case one model per
    outcome is wrapped in a ModelListGP.

    Args:
        Xs: List of X data, one tensor per outcome
        Ys: List of Y data, one tensor per outcome
        Yvars: List of observed variance of Ys.
        task_features: List of columns of X that are tasks. At most one entry
            is supported.
        state_dict: If provided, will set model parameters to this state
            dictionary. Otherwise, will fit the model.
        **kwargs: Not used by this function.

    Returns:
        The fitted (or state-loaded) model.

    Raises:
        ValueError: If more than one task feature is given.
    """
    if len(task_features) > 1:
        raise ValueError(
            f"This model only supports 1 task feature (got {task_features})")
    task_feature = task_features[0] if len(task_features) == 1 else None

    def _build(X: Tensor, Y: Tensor, Yvar: Tensor) -> GPyTorchModel:
        # Every model in this function is created with the same task feature.
        return _get_model(X=X, Y=Y, Yvar=Yvar, task_feature=task_feature)

    model = None
    if task_feature is None:
        if len(Xs) == 1:
            # Use single output, single task GP
            model = _build(Xs[0], Ys[0], Yvars[0])
        elif all(torch.equal(Xs[0], X) for X in Xs[1:]):
            # Use batched multioutput, single task GP
            model = _build(
                Xs[0],
                torch.cat(Ys, dim=-1),
                torch.cat(Yvars, dim=-1),
            )
    if model is None:
        # Use model list
        # NOTE(review): the `gp_models=` keyword depends on the installed
        # botorch version accepting it -- confirm.
        model = ModelListGP(gp_models=[
            _build(X, Y, Yvar) for X, Y, Yvar in zip(Xs, Ys, Yvars)
        ])
    model.to(dtype=Xs[0].dtype, device=Xs[0].device)  # pyre-ignore

    if state_dict is not None:
        model.load_state_dict(state_dict)
        return model

    # TODO: Add bounds for optimization stability - requires revamp upstream
    bounds = {}
    mll_cls = (SumMarginalLogLikelihood
               if isinstance(model, ModelListGP) else
               ExactMarginalLogLikelihood)
    # pyre-ignore: [16]
    fit_gpytorch_model(mll_cls(model.likelihood, model), bounds=bounds)
    return model
Exemplo n.º 18
0
def get_and_fit_model(
    Xs: List[Tensor],
    Ys: List[Tensor],
    Yvars: List[Tensor],
    task_features: List[int],
    fidelity_features: List[int],
    state_dict: Optional[Dict[str, Tensor]] = None,
    fidelity_model_id: Optional[int] = None,
    **kwargs: Any,
) -> GPyTorchModel:
    r"""Instantiates a botorch model from the given data and fits or loads it.

    Without a task feature, a single model is used when there is one outcome
    or when all outcomes share identical X data (the Ys and Yvars are then
    concatenated along the last dimension). In every other case one model per
    outcome is wrapped in a ModelListGP.

    Args:
        Xs: List of X data, one tensor per outcome
        Ys: List of Y data, one tensor per outcome
        Yvars: List of observed variance of Ys.
        task_features: List of columns of X that are tasks. At most one entry
            is supported.
        fidelity_features: List of columns of X that are fidelity parameters.
        state_dict: If provided, will set model parameters to this state
            dictionary. Otherwise, will fit the model.
        fidelity_model_id: set this if you want to use GP models from `model_list`
            defined above. The `SingleTaskGPLTKernel` model uses linear truncated
            kernel; the `SingleTaskMultiFidelityGP` model uses exponential decay
            kernel.
        **kwargs: Not used by this function.

    Returns:
        The fitted (or state-loaded) model: a single model when one suffices,
        otherwise a ModelListGP.

    Raises:
        NotImplementedError: If `fidelity_model_id` is combined with task
            features.
        UnsupportedError: If `fidelity_model_id` is combined with more than
            one fidelity feature.
        ValueError: If more than one task feature is given.
    """
    if fidelity_model_id is not None and len(task_features) > 0:
        raise NotImplementedError(
            "Currently do not support MF-GP models with task_features!")
    if fidelity_model_id is not None and len(fidelity_features) > 1:
        raise UnsupportedError(
            "Fidelity MF-GP models currently support only one fidelity parameter!"
        )
    model = None
    if len(task_features) > 1:
        raise ValueError(
            f"This model only supports 1 task feature (got {task_features})")
    elif len(task_features) == 1:
        task_feature = task_features[0]
    else:
        task_feature = None
    if task_feature is None:
        if len(Xs) == 1:
            # Use single output, single task GP
            model = _get_model(
                X=Xs[0],
                Y=Ys[0],
                Yvar=Yvars[0],
                task_feature=task_feature,
                fidelity_features=fidelity_features,
                fidelity_model_id=fidelity_model_id,
            )
        elif all(torch.equal(Xs[0], X) for X in Xs[1:]):
            # Use batched multioutput, single task GP
            Y = torch.cat(Ys, dim=-1)
            Yvar = torch.cat(Yvars, dim=-1)
            model = _get_model(
                X=Xs[0],
                Y=Y,
                Yvar=Yvar,
                task_feature=task_feature,
                fidelity_features=fidelity_features,
                fidelity_model_id=fidelity_model_id,
            )
    if model is None:
        # Use model list. Without a task feature the fidelity settings are
        # forwarded exactly as in the two branches above, so a requested
        # fidelity model is no longer silently dropped when the outcomes have
        # different X data. With a task feature the call is left as it was.
        # NOTE(review): `_get_model` is not visible from here -- confirm.
        if task_feature is None:
            fidelity_kwargs = {
                "fidelity_features": fidelity_features,
                "fidelity_model_id": fidelity_model_id,
            }
        else:
            fidelity_kwargs = {}
        models = [
            _get_model(
                X=X,
                Y=Y,
                Yvar=Yvar,
                task_feature=task_feature,
                **fidelity_kwargs,
            )
            for X, Y, Yvar in zip(Xs, Ys, Yvars)
        ]
        model = ModelListGP(*models)
    model.to(dtype=Xs[0].dtype, device=Xs[0].device)  # pyre-ignore
    if state_dict is None:
        # TODO: Add bounds for optimization stability - requires revamp upstream
        bounds = {}
        if isinstance(model, ModelListGP):
            mll = SumMarginalLogLikelihood(model.likelihood, model)
        else:
            # pyre-ignore: [16]
            mll = ExactMarginalLogLikelihood(model.likelihood, model)
        # The return value is not needed; `model` is what gets returned.
        fit_gpytorch_model(mll, bounds=bounds)
    else:
        model.load_state_dict(state_dict)
    return model