    def forward(self, input):
        batch_shape = _mul_broadcast_shape(self.batch_shape, input.shape[:-2])
        mean = self.constant.unsqueeze(-1).expand(
            *batch_shape, input.size(-2), input.size(-1) + 1
        ).contiguous()
        mean[..., 1:] = 0
        return mean
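All of these snippets revolve around `_mul_broadcast_shape`, which in the older `gpytorch.utils.broadcasting` module they come from returns the NumPy-style broadcast of several `torch.Size` objects and raises if they are incompatible. The stand-in below, built on `torch.broadcast_shapes` (PyTorch >= 1.8), is only an illustration of those semantics, not the library's implementation.

import torch

def mul_broadcast_shape(*shapes, error_msg=None):
    # Illustrative stand-in for gpytorch.utils.broadcasting._mul_broadcast_shape:
    # broadcast several torch.Size objects, failing loudly if they clash.
    try:
        return torch.broadcast_shapes(*shapes)
    except RuntimeError:
        raise RuntimeError(error_msg or f"Shapes are not broadcastable: {shapes}")

# A mean module with batch_shape 2 x 1 applied to an input with batch shape 3:
print(mul_broadcast_shape(torch.Size([2, 1]), torch.Size([3])))  # torch.Size([2, 3])
# Incompatible shapes raise instead of silently misaligning:
# mul_broadcast_shape(torch.Size([2]), torch.Size([3]))  -> RuntimeError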
    @property
    def batch_shape(self):
        kernels = list(self.sub_kernels())
        if len(kernels):
            return _mul_broadcast_shape(self._batch_shape, *[k.batch_shape for k in kernels])
        else:
            return self._batch_shape
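As a concrete illustration of the property above, composing kernels with different batch shapes should report the broadcast batch shape. The sketch assumes a GPyTorch version in which `Kernel.batch_shape` is defined as shown; the particular kernels are arbitrary examples.

import torch
import gpytorch

k1 = gpytorch.kernels.RBFKernel(batch_shape=torch.Size([3]))
k2 = gpytorch.kernels.MaternKernel(batch_shape=torch.Size([2, 1]))
prod = k1 * k2  # a ProductKernel whose sub_kernels() are k1 and k2

# The composite kernel reports the broadcast of its sub-kernels' batch shapes.
print(prod.batch_shape)  # expected: torch.Size([2, 3])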
Code Example #3
    def __add__(self, other):
        from .diag_lazy_tensor import DiagLazyTensor
        from .added_diag_lazy_tensor import AddedDiagLazyTensor

        if isinstance(other, ZeroLazyTensor):
            return self
        elif isinstance(other, DiagLazyTensor):
            return AddedDiagLazyTensor(self, other)
        elif isinstance(other, SumLazyTensor):
            return SumLazyTensor(*(list(self.lazy_tensors) +
                                   list(other.lazy_tensors)))
        elif isinstance(other, LazyTensor):
            return SumLazyTensor(*(list(self.lazy_tensors) + [other]))
        elif isinstance(other, Tensor):
            # get broadcast shape, assuming mul broadcasting the same as add broadcasting
            broadcasted_shape = _mul_broadcast_shape(self.shape, other.shape)

            # lazify + broadcast other
            broadcasted_other = lazify(other.expand(broadcasted_shape))

            # update the lazy tensors' shape as well
            new_self = self if broadcasted_shape == self.shape else self._expand_batch(
                broadcasted_shape[:-2])

            return SumLazyTensor(*(list(new_self.lazy_tensors) +
                                   [broadcasted_other]))
        else:
            raise AttributeError("other must be a LazyTensor or Tensor")
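A quick way to exercise the `Tensor` branch of the addition path is shown below; it assumes the older `gpytorch.lazy` namespace these snippets come from (recent releases moved the analogous objects to `linear_operator`).

import torch
from gpytorch.lazy import lazify

A = lazify(torch.randn(2, 4, 4))  # a batch of two 4 x 4 lazy tensors
B = torch.eye(4)                  # a plain tensor with no batch dimensions

summed = A + B                    # B is lazified and broadcast against A's batch shape
print(summed.shape)               # torch.Size([2, 4, 4])
print(summed.evaluate().shape)    # torch.Size([2, 4, 4])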
    def add_diag(self, added_diag: Tensor) -> "TriangularLazyTensor":
        from .added_diag_lazy_tensor import AddedDiagLazyTensor

        shape = _mul_broadcast_shape(self._diag.shape, added_diag.shape)
        added_diag_lt = AddedDiagLazyTensor(self._tensor.expand(shape),
                                            added_diag.expand(shape))
        return TriangularLazyTensor(added_diag_lt, upper=self.upper)
Code Example #5
    def forward(self, i1, i2, **params):
        covar_matrix = self._eval_covar_matrix()
        batch_shape = _mul_broadcast_shape(i1.shape[:-2], self.batch_shape)
        index_shape = batch_shape + i1.shape[-2:]

        res = InterpolatedLazyTensor(
            base_lazy_tensor=covar_matrix,
            left_interp_indices=i1.expand(index_shape),
            right_interp_indices=i2.expand(index_shape),
        )
        return res
    def __init__(self, *lazy_tensors, preconditioner_override=None):
        lazy_tensors = list(lazy_tensors)
        super(AddedDiagLazyTensor,
              self).__init__(*lazy_tensors,
                             preconditioner_override=preconditioner_override)
        if len(lazy_tensors) > 2:
            raise RuntimeError(
                "An AddedDiagLazyTensor can only have two components")

        broadcasting._mul_broadcast_shape(lazy_tensors[0].shape,
                                          lazy_tensors[1].shape)

        if isinstance(lazy_tensors[0], DiagLazyTensor) and isinstance(
                lazy_tensors[1], DiagLazyTensor):
            raise RuntimeError(
                "Trying to lazily add two DiagLazyTensors. Create a single DiagLazyTensor instead."
            )
        elif isinstance(lazy_tensors[0], DiagLazyTensor):
            self._diag_tensor = lazy_tensors[0]
            self._lazy_tensor = lazy_tensors[1]
        elif isinstance(lazy_tensors[1], DiagLazyTensor):
            self._diag_tensor = lazy_tensors[1]
            self._lazy_tensor = lazy_tensors[0]
        else:
            raise RuntimeError(
                "One of the LazyTensors input to AddedDiagLazyTensor must be a DiagLazyTensor!"
            )

        self.preconditioner_override = preconditioner_override

        # Placeholders
        self._constant_diag = None
        self._noise = None
        self._piv_chol_self = None  # <- Doesn't need to be an attribute, but used for testing purposes
        self._precond_lt = None
        self._precond_logdet_cache = None
        self._q_cache = None
        self._r_cache = None
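A minimal construction that satisfies the checks above (exactly one of the two components is a `DiagLazyTensor`); again, the older `gpytorch.lazy` namespace is assumed.

import torch
from gpytorch.lazy import NonLazyTensor, DiagLazyTensor, AddedDiagLazyTensor

base = NonLazyTensor(torch.randn(3, 5, 5))       # batch of dense 5 x 5 matrices
noise = DiagLazyTensor(torch.full((3, 5), 0.1))  # matching batch of diagonals

K = AddedDiagLazyTensor(base, noise)  # argument order doesn't matter; the diag part is detected
print(K.shape)                        # torch.Size([3, 5, 5])
# Passing two DiagLazyTensors, or more than two components, raises per the checks above.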
    def __init__(self, left_lazy_tensor, right_lazy_tensor):
        left_lazy_tensor = lazify(left_lazy_tensor)
        right_lazy_tensor = lazify(right_lazy_tensor)

        # Match batch dimensions
        batch_shape = _mul_broadcast_shape(left_lazy_tensor.batch_shape,
                                           right_lazy_tensor.batch_shape)
        if left_lazy_tensor.batch_shape != batch_shape:
            left_lazy_tensor = left_lazy_tensor._expand_batch(batch_shape)
        if right_lazy_tensor.batch_shape != batch_shape:
            right_lazy_tensor = right_lazy_tensor._expand_batch(batch_shape)

        super().__init__(left_lazy_tensor, right_lazy_tensor)
        self.left_lazy_tensor = left_lazy_tensor
        self.right_lazy_tensor = right_lazy_tensor
Code Example #8
    def __init__(self, *lazy_tensors, **kwargs):
        try:
            lazy_tensors = tuple(lazify(lt) for lt in lazy_tensors)
        except TypeError:
            raise TypeError(
                "All arguments of a SumLazyTensor should be LazyTensors or Tensors"
            )
        batch_shape = _mul_broadcast_shape(
            *[lt.batch_shape for lt in lazy_tensors])
        lazy_tensors = tuple(
            lt._expand_batch(batch_shape) if lt.batch_shape != batch_shape else lt
            for lt in lazy_tensors
        )
        super(SumLazyTensor, self).__init__(*lazy_tensors, **kwargs)

        self.lazy_tensors = lazy_tensors
Code Example #9
    def _size(self):
        if settings.debug.on():
            if hasattr(self.kernel, "size"):
                raise RuntimeError(
                    "Kernels must define `num_outputs_per_input` and should not define `size`"
                )

        x1 = self.x1
        x2 = self.x2
        num_outputs_per_input = self.kernel.num_outputs_per_input(x1, x2)
        num_rows = x1.size(-2) * num_outputs_per_input
        num_cols = x2.size(-2) * num_outputs_per_input

        # Default case - when we're not using broadcasting
        # We write this case special for efficiency
        if x1.shape[:-2] == x2.shape[:-2] and x1.shape[:-2] == self.kernel.batch_shape:
            expected_size = self.kernel.batch_shape + torch.Size(
                (num_rows, num_cols))

        # When we're using broadcasting
        else:
            expected_size = broadcasting._matmul_broadcast_shape(
                torch.Size([*x1.shape[:-2], num_rows, x1.size(-1)]),
                torch.Size([*x2.shape[:-2], x2.size(-1), num_cols]),
                error_msg="x1 and x2 were not broadcastable to a proper kernel shape. "
                "Got x1.shape = {} and x2.shape = {}".format(str(x1.shape), str(x2.shape)),
            )
            expected_size = (
                broadcasting._mul_broadcast_shape(
                    expected_size[:-2],
                    self.kernel.batch_shape,
                    error_msg=(
                        f"x1 and x2 were not broadcastable with kernel of batch_shape {self.kernel.batch_shape}. "
                        f"Got x1.shape = {x1.shape} and x2.shape = {x2.shape}"
                    ),
                )
                + expected_size[-2:]
            )

        # Handle when the last dim is batch
        if self.last_dim_is_batch:
            expected_size = expected_size[:-2] + x1.shape[-1:] + expected_size[-2:]
        return expected_size
    def __call__(self, x, prior=False):
        # If we're in prior mode, then we're done!
        if prior:
            return self.model.forward(x)

        # Delete previously cached items from the training distribution
        if self.training:
            if hasattr(self, "_memoize_cache"):
                delattr(self, "_memoize_cache")
                self._memoize_cache = dict()
        # (Maybe) initialize variational distribution
        if not self.variational_params_initialized.item():
            prior_dist = self.prior_distribution
            self._variational_distribution.initialize_variational_distribution(
                prior_dist)
            self.variational_params_initialized.fill_(1)

        # Ensure inducing_points and x are the same size
        inducing_points = self.inducing_points
        if inducing_points.shape[:-2] != x.shape[:-2]:
            batch_shape = _mul_broadcast_shape(inducing_points.shape[:-2],
                                               x.shape[:-2])
            inducing_points = inducing_points.expand(
                *batch_shape, *inducing_points.shape[-2:])
            x = x.expand(*batch_shape, *x.shape[-2:])

        # Get p(u)/q(u)
        variational_dist_u = self.variational_distribution

        # Get q(f)
        if isinstance(variational_dist_u, MultivariateNormal):
            return super().__call__(
                x,
                inducing_points,
                inducing_values=variational_dist_u.mean,
                variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
            )
        elif isinstance(variational_dist_u, Delta):
            return super().__call__(x,
                                    inducing_points,
                                    inducing_values=variational_dist_u.mean,
                                    variational_inducing_covar=None)
        else:
            raise RuntimeError(
                f"Invalid variational distribuition ({type(variational_dist_u)}). "
                "Expected a multivariate normal or a delta distribution.")
Code Example #11
    def _get_indices(self, row_index, col_index, *batch_indices):
        indices = [*batch_indices, row_index, col_index]
        target_shape = _mul_broadcast_shape(
            *[index.shape for index in indices])
        indices = [index.expand(target_shape).reshape(-1) for index in indices]
        cat_dim_indices = indices[self.cat_dim]

        # Find out for which indices we switch to different tensors
        target_tensors = self.idx_to_tensor_idx[cat_dim_indices]
        does_switch_tensor = torch.ones(target_tensors.numel() + 1,
                                        dtype=bool_compat,
                                        device=self.device)
        torch.ne(target_tensors[:-1],
                 target_tensors[1:],
                 out=does_switch_tensor[1:-1])

        # Get the LazyTensors that will comprise the new LazyTensor
        lazy_tensor_indices = target_tensors[does_switch_tensor[:-1]].tolist()
        lazy_tensors = [self.lazy_tensors[idx] for idx in lazy_tensor_indices]

        # Get the new set of indices for each of the LazyTensors
        switch_tensor = does_switch_tensor.nonzero(as_tuple=False).squeeze(-1)
        split_sizes = (switch_tensor[1:] - switch_tensor[:-1]).tolist()
        sub_indices = zip(*[
            list(index.split(split_sizes)) if torch.is_tensor(index) else [index] * len(split_sizes)
            for index in indices
        ])
        # Make everything a list
        sub_indices = [list(sub_index) for sub_index in sub_indices]

        # Make sure that we have adjusted the start and ends of the indices that correspond to the cat dim
        for lazy_tensor_idx, sub_index in zip(lazy_tensor_indices,
                                              sub_indices):
            sub_index[self.cat_dim] = sub_index[
                self.cat_dim] - self.cat_dim_cum_sizes[lazy_tensor_idx]

        res_list = [
            lazy_tensor._get_indices(sub_index[-2], sub_index[-1],
                                     *sub_index[:-2])
            for lazy_tensor, sub_index in zip(lazy_tensors, sub_indices)
        ]
        if len(res_list) == 1:
            return res_list[0].view(target_shape).to(self.device)
        else:
            return torch.cat(res_list).view(target_shape).to(self.device)
Code Example #12
    def _inv_matmul(self, right_tensor, left_tensor=None):
        # Computes inv_matmul by exploiting the identity (A \kron B)^-1 = A^-1 \kron B^-1
        tsr_shapes = [q.size(-1) for q in self.lazy_tensors]
        n_rows = right_tensor.size(-2)
        batch_shape = _mul_broadcast_shape(self.shape[:-2], right_tensor.shape[:-2])
        perm_batch = tuple(range(len(batch_shape)))
        y = right_tensor.clone().expand(*batch_shape, *right_tensor.shape[-2:])
        for n, q in zip(tsr_shapes, self.lazy_tensors):
            # for KroneckerProductTriangularLazyTensor this inv_matmul is very cheap
            y = q.inv_matmul(y.reshape(*batch_shape, n, -1))
            y = y.reshape(*batch_shape, n, n_rows // n, -1).permute(*perm_batch, -2, -3, -1)
        res = y.reshape(*batch_shape, n_rows, -1)
        if left_tensor is not None:
            res = left_tensor @ res
        return res
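The identity exploited above can be checked directly with dense tensors (this check assumes `torch.kron` and `torch.linalg`, i.e. PyTorch >= 1.8).

import torch

# Two small symmetric positive definite Kronecker factors.
A = torch.randn(3, 3); A = A @ A.T + 3 * torch.eye(3)
B = torch.randn(4, 4); B = B @ B.T + 4 * torch.eye(4)

K = torch.kron(A, B)
lhs = torch.linalg.inv(K)                                   # (A ⊗ B)^-1
rhs = torch.kron(torch.linalg.inv(A), torch.linalg.inv(B))  # A^-1 ⊗ B^-1
print(torch.allclose(lhs, rhs, atol=1e-4))                  # True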
Code Example #13
    def _compute_grid(self, inputs):
        n_data, n_dimensions = inputs.size(-2), inputs.size(-1)
        batch_shape = inputs.shape[:-2]

        inputs = inputs.reshape(-1, n_dimensions)
        interp_indices, interp_values = Interpolation().interpolate(
            self.grid, inputs)
        interp_indices = interp_indices.view(*batch_shape, n_data, -1)
        interp_values = interp_values.view(*batch_shape, n_data, -1)

        if (interp_indices.dim() - 2) != len(
                self._variational_distribution.batch_shape):
            batch_shape = _mul_broadcast_shape(
                interp_indices.shape[:-2],
                self._variational_distribution.batch_shape)
            interp_indices = interp_indices.expand(*batch_shape,
                                                   *interp_indices.shape[-2:])
            interp_values = interp_values.expand(*batch_shape,
                                                 *interp_values.shape[-2:])
        return interp_indices, interp_values
    def forward(self,
                *params: Any,
                shape: Optional[torch.Size] = None,
                **kwargs: Any) -> DiagLazyTensor:
        """In the homoskedastic case, the parameters are only used to infer the required shape.
        Here are the possible scenarios:
        - non-batched noise, non-batched input, non-MT -> noise_diag shape is `n`
        - non-batched noise, non-batched input, MT -> noise_diag shape is `nt`
        - non-batched noise, batched input, non-MT -> noise_diag shape is `b x n`
        - non-batched noise, batched input, MT -> noise_diag shape is `b x nt`
        - batched noise, non-batched input, non-MT -> noise_diag shape is `b x n`
        - batched noise, non-batched input, MT -> noise_diag shape is `b x nt`
        - batched noise, batched input, non-MT -> noise_diag shape is `b' x n`
        - batched noise, batched input, MT -> noise_diag shape is `b' x nt`
        where `n` is the number of evaluation points, `t` is the number of tasks (i.e. `num_tasks` of self.noise),
        `b` is the input batch shape, and `b'` is the broadcast of the noise batch shape and the input batch shape.
        So basically the shape is always `b' x nt`. `n` and the input batch shape are determined either from the
        shape arg or from the params
        input. For this it is sufficient to take in a single `shape` arg, with the convention that shape[:-1] is the
        batch shape of the input, and shape[-1] is `n`.

        If a "noise" kwarg (a Tensor) is provided, this noise is used directly.
        """
        if "noise" in kwargs:
            return DiagLazyTensor(kwargs.get("noise"))
        if shape is None:
            p = params[0] if torch.is_tensor(params[0]) else params[0][0]
            shape = p.shape if len(p.shape) == 1 else p.shape[:-1]
        noise = self.noise
        *batch_shape, n = shape
        noise_batch_shape = noise.shape[:-1] if noise.dim() > 1 else torch.Size()
        num_tasks = noise.shape[-1]
        batch_shape = _mul_broadcast_shape(noise_batch_shape, batch_shape)
        noise = noise.unsqueeze(-2)
        noise_diag = noise.expand(*batch_shape, n, num_tasks).contiguous()
        if num_tasks == 1:
            noise_diag = noise_diag.view(*batch_shape, n)
        return DiagLazyTensor(noise_diag)
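A small usage sketch for the shape handling the docstring describes, assuming this is GPyTorch's `HomoskedasticNoise` module (batched noise, batched input, single task):

import torch
from gpytorch.likelihoods.noise_models import HomoskedasticNoise

noise_model = HomoskedasticNoise(batch_shape=torch.Size([2]))  # batched noise parameter
covar = noise_model(shape=torch.Size([2, 10]))                 # input batch shape 2, n = 10

print(covar.shape)         # torch.Size([2, 10, 10])
print(covar.diag().shape)  # torch.Size([2, 10])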
    def forward(self, input):
        if input.shape[:-2] == self.batch_shape:
            return self.constant.expand(input.shape[:-1])
        else:
            return self.constant.expand(
                _mul_broadcast_shape(input.shape[:-1], self.constant.shape))
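A matching sketch for the constant mean above, where a non-batched input is broadcast against a batched constant (module name assumed to be GPyTorch's `ConstantMean`):

import torch
from gpytorch.means import ConstantMean

mean_module = ConstantMean(batch_shape=torch.Size([2]))
x = torch.randn(5, 3)  # a single, non-batched set of 5 inputs

print(mean_module(x).shape)  # torch.Size([2, 5]): the constant broadcasts over its batch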
Code Example #16
    def _size(self):
        return _mul_broadcast_shape(*[lt.shape for lt in self.lazy_tensors])
Code Example #17
    def forward(self,
                x,
                inducing_points,
                inducing_values,
                variational_inducing_covar=None):
        # If our points equal the inducing points, we're done
        if torch.equal(x, inducing_points):
            if variational_inducing_covar is None:
                raise RuntimeError
            else:
                return MultivariateNormal(inducing_values,
                                          variational_inducing_covar)

        # Otherwise, we have to marginalize
        num_induc = inducing_points.size(-2)
        full_inputs = torch.cat([inducing_points, x], dim=-2)
        full_output = self.model.forward(full_inputs)
        full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix

        # Mean terms
        test_mean = full_mean[..., num_induc:]
        induc_mean = full_mean[..., :num_induc]
        mean_diff = (inducing_values - induc_mean).unsqueeze(-1)

        # Covariance terms
        induc_induc_covar = full_covar[
            ..., :num_induc, :num_induc].add_jitter()
        induc_data_covar = full_covar[..., :num_induc, num_induc:].evaluate()
        data_data_covar = full_covar[..., num_induc:, num_induc:]

        # If we're less than a certain size, we'll compute the Cholesky decomposition of induc_induc_covar
        cholesky = False
        if settings.fast_computations.log_prob.off() or (
                num_induc <= settings.max_cholesky_size.value()):
            induc_induc_covar = CholLazyTensor(
                self._cholesky_factor(induc_induc_covar))
            cholesky = True

        # If we are making predictions and don't need variances, we can do things very quickly.
        if not self.training and settings.skip_posterior_variances.on():
            if not hasattr(self, "_mean_cache"):
                # For now: run variational inference without a preconditioner
                # The preconditioner screws things up for some reason
                with settings.max_preconditioner_size(0):
                    self._mean_cache = induc_induc_covar.inv_matmul(
                        mean_diff).detach()
            predictive_mean = torch.add(
                test_mean,
                induc_data_covar.transpose(-2, -1).matmul(
                    self._mean_cache).squeeze(-1))
            predictive_covar = ZeroLazyTensor(test_mean.size(-1),
                                              test_mean.size(-1))
            return MultivariateNormal(predictive_mean, predictive_covar)

        # Expand everything to the right size
        shapes = [
            mean_diff.shape[:-1], induc_data_covar.shape[:-1],
            induc_induc_covar.shape[:-1]
        ]
        if variational_inducing_covar is not None:
            root_variational_covar = variational_inducing_covar.root_decomposition().root.evaluate()
            shapes.append(root_variational_covar.shape[:-1])
        shape = _mul_broadcast_shape(*shapes)
        mean_diff = mean_diff.expand(*shape, mean_diff.size(-1))
        induc_data_covar = induc_data_covar.expand(*shape,
                                                   induc_data_covar.size(-1))
        induc_induc_covar = induc_induc_covar.expand(
            *shape, induc_induc_covar.size(-1))
        if variational_inducing_covar is not None:
            root_variational_covar = root_variational_covar.expand(
                *shape, root_variational_covar.size(-1))

        # Cache the CG results
        # For now: run variational inference without a preconditioner
        # The preconditioner screws things up for some reason
        with settings.max_preconditioner_size(0):
            # Cache the CG results
            if variational_inducing_covar is None:
                left_tensors = mean_diff
            else:
                left_tensors = torch.cat([mean_diff, root_variational_covar],
                                         -1)

            with torch.no_grad():
                eager_rhs = torch.cat([left_tensors, induc_data_covar], -1)
                solve, probe_vecs, probe_vec_norms, probe_vec_solves, tmats = CachedCGLazyTensor.precompute_terms(
                    induc_induc_covar,
                    eager_rhs.detach(),
                    logdet_terms=(not cholesky),
                    include_tmats=(not settings.skip_logdet_forward.on()
                                   and not cholesky),
                )
                eager_rhss = [
                    eager_rhs.detach(),
                    eager_rhs[..., left_tensors.size(-1):].detach(),
                    eager_rhs[..., :left_tensors.size(-1)].detach(),
                ]
                solves = [
                    solve.detach(),
                    solve[..., left_tensors.size(-1):].detach(),
                    solve[..., :left_tensors.size(-1)].detach(),
                ]
                if settings.skip_logdet_forward.on():
                    eager_rhss.append(torch.cat([probe_vecs, left_tensors],
                                                -1))
                    solves.append(
                        torch.cat([
                            probe_vec_solves,
                            solve[..., :left_tensors.size(-1)]
                        ], -1))
            induc_induc_covar = CachedCGLazyTensor(
                induc_induc_covar,
                eager_rhss=eager_rhss,
                solves=solves,
                probe_vectors=probe_vecs,
                probe_vector_norms=probe_vec_norms,
                probe_vector_solves=probe_vec_solves,
                probe_vector_tmats=tmats,
            )

        # Cache the kernel matrix with the cached CG calls
        if self.training:
            self._memoize_cache["prior_distribution_memo"] = MultivariateNormal(
                induc_mean, induc_induc_covar)

        # Compute predictive mean
        inv_products = induc_induc_covar.inv_matmul(
            induc_data_covar, left_tensors.transpose(-1, -2))
        predictive_mean = torch.add(test_mean, inv_products[..., 0, :])

        # Compute covariance
        if self.training:
            interp_data_data_var, _ = induc_induc_covar.inv_quad_logdet(
                induc_data_covar, logdet=False, reduce_inv_quad=False)
            data_covariance = DiagLazyTensor(
                (data_data_covar.diag() - interp_data_data_var).clamp(
                    0, math.inf))
        else:
            neg_induc_data_data_covar = torch.matmul(
                induc_data_covar.transpose(-1, -2).mul(-1),
                induc_induc_covar.inv_matmul(induc_data_covar))
            data_covariance = data_data_covar + neg_induc_data_data_covar
        predictive_covar = PsdSumLazyTensor(
            RootLazyTensor(inv_products[..., 1:, :].transpose(-1, -2)),
            data_covariance)

        # Done!
        return MultivariateNormal(predictive_mean, predictive_covar)
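For orientation, the distribution assembled by this forward pass is the usual unwhitened sparse-GP conditional. Writing K_zz, K_zx, K_xx for the inducing/data covariance blocks, mu_z and mu_x for the prior means, and m, S for `inducing_values` and `variational_inducing_covar`, the code computes (a sketch of the algebra, not text from the source):

q(f_x) = \mathcal{N}\Big( \mu_x + K_{xz} K_{zz}^{-1} (m - \mu_z),\;
                          K_{xx} - K_{xz} K_{zz}^{-1} K_{zx} + K_{xz} K_{zz}^{-1} S K_{zz}^{-1} K_{zx} \Big)

In training mode only the diagonal of the K_xx - K_xz K_zz^{-1} K_zx correction is kept (and clamped at zero), which is what the `DiagLazyTensor` branch above implements.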
    def get_fantasy_model(self, inputs, targets, **kwargs):
        """
        Returns a new GP model that incorporates the specified inputs and targets as new training data.

        Using this method is more efficient than updating with `set_train_data` when the number of inputs is relatively
        small, because any computed test-time caches will be updated in linear time rather than computed from scratch.

        .. note::
            If `targets` is a batch (e.g. `b x m`), then the GP returned from this method will be a batch mode GP.
            If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy points
            are the same for each target batch.

        :param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
            observations.
        :param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
        :return: An `ExactGP` model with `n + m` training examples, where the `m` fantasy examples have been added
            and all test-time caches have been updated.
        :rtype: ~gpytorch.models.ExactGP
        """
        if self.prediction_strategy is None:
            raise RuntimeError(
                "Fantasy observations can only be added after making predictions with a model so that "
                "all test independent caches exist. Call the model on some data first!"
            )

        model_batch_shape = self.train_inputs[0].shape[:-2]

        if self.train_targets.dim() > len(model_batch_shape) + 1:
            raise RuntimeError(
                "Cannot yet add fantasy observations to multitask GPs, but this is coming soon!"
            )

        if not isinstance(inputs, list):
            inputs = [inputs]

        inputs = [
            i.unsqueeze(-1) if i.ndimension() == 1 else i for i in inputs
        ]

        target_batch_shape = targets.shape[:-1]
        input_batch_shape = inputs[0].shape[:-2]
        tbdim, ibdim = len(target_batch_shape), len(input_batch_shape)

        if not (tbdim == ibdim + 1 or tbdim == ibdim):
            raise RuntimeError(
                f"Unsupported batch shapes: The target batch shape ({target_batch_shape}) must have either the "
                f"same dimension as or one more dimension than the input batch shape ({input_batch_shape})"
            )

        # Check whether we can properly broadcast batch dimensions
        err_msg = (
            f"Model batch shape ({model_batch_shape}) and target batch shape "
            f"({target_batch_shape}) are not broadcastable.")
        _mul_broadcast_shape(model_batch_shape,
                             target_batch_shape,
                             error_msg=err_msg)

        if len(model_batch_shape) > len(input_batch_shape):
            input_batch_shape = model_batch_shape
        if len(model_batch_shape) > len(target_batch_shape):
            target_batch_shape = model_batch_shape

        # If input has no fantasy batch dimension but target does, we can save memory and computation by not
        # computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
        # size of the fantasy model here - this is done below, after the evaluation and fast fantasy update
        train_inputs = [
            tin.expand(input_batch_shape + tin.shape[-2:])
            for tin in self.train_inputs
        ]
        train_targets = self.train_targets.expand(
            target_batch_shape + self.train_targets.shape[-1:])

        full_inputs = [
            torch.cat([
                train_input,
                input.expand(input_batch_shape + input.shape[-2:])
            ],
                      dim=-2)
            for train_input, input in zip(train_inputs, inputs)
        ]
        full_targets = torch.cat([
            train_targets,
            targets.expand(target_batch_shape + targets.shape[-1:])
        ],
                                 dim=-1)

        try:
            fantasy_kwargs = {"noise": kwargs.pop("noise")}
        except KeyError:
            fantasy_kwargs = {}

        full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)

        # Copy model without copying training data or prediction strategy (since we'll overwrite those)
        old_pred_strat = self.prediction_strategy
        old_train_inputs = self.train_inputs
        old_train_targets = self.train_targets
        old_likelihood = self.likelihood
        self.prediction_strategy = None
        self.train_inputs = None
        self.train_targets = None
        self.likelihood = None
        new_model = deepcopy(self)
        self.prediction_strategy = old_pred_strat
        self.train_inputs = old_train_inputs
        self.train_targets = old_train_targets
        self.likelihood = old_likelihood

        new_model.likelihood = old_likelihood.get_fantasy_likelihood(
            **fantasy_kwargs)
        new_model.prediction_strategy = old_pred_strat.get_fantasy_strategy(
            inputs, targets, full_inputs, full_targets, full_output,
            **fantasy_kwargs)

        # if the fantasies are at the same points, we need to expand the inputs for the new model
        if tbdim == ibdim + 1:
            new_model.train_inputs = [
                fi.expand(target_batch_shape + fi.shape[-2:])
                for fi in full_inputs
            ]
        else:
            new_model.train_inputs = full_inputs
        new_model.train_targets = full_targets

        return new_model
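A self-contained sketch of the intended workflow: predict once so the test-time caches exist, then condition on a handful of new observations. The toy model below is an assumption for illustration; only `get_fantasy_model` itself comes from the snippet.

import torch
import gpytorch

class TinyGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x))

train_x, train_y = torch.randn(20, 2), torch.randn(20)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = TinyGP(train_x, train_y, likelihood)

# The fantasy update needs the test-time caches, so make one prediction first.
model.eval(); likelihood.eval()
with torch.no_grad():
    _ = model(torch.randn(3, 2))

# Condition on 5 new observations without recomputing everything from scratch.
fantasy_model = model.get_fantasy_model(torch.randn(5, 2), torch.randn(5))
print(fantasy_model.train_inputs[0].shape)  # torch.Size([25, 2])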
    def mul(self, other):
        shape = _mul_broadcast_shape(self.shape, other.shape)
        return self.__class__(*shape, dtype=self._dtype, device=self._device)
    def __call__(self, *args, **kwargs):
        train_inputs = list(
            self.train_inputs) if self.train_inputs is not None else []
        inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in args]

        # Training mode: optimizing
        if self.training:
            if self.train_inputs is None:
                raise RuntimeError(
                    "train_inputs, train_targets cannot be None in training mode. "
                    "Call .eval() for prior predictions, or call .set_train_data() to add training data."
                )
            if settings.debug.on():
                if not all(
                        torch.equal(train_input, input)
                        for train_input, input in zip(train_inputs, inputs)):
                    raise RuntimeError(
                        "You must train on the training inputs!")
            res = super().__call__(*inputs, **kwargs)
            return res

        # Prior mode
        elif settings.prior_mode.on() or self.train_inputs is None or self.train_targets is None:
            full_inputs = args
            full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
            if settings.debug().on():
                if not isinstance(full_output, MultivariateNormal):
                    raise RuntimeError(
                        "ExactGP.forward must return a MultivariateNormal")
            return full_output

        # Posterior mode
        else:
            if settings.debug.on():
                if all(
                        torch.equal(train_input, input)
                        for train_input, input in zip(train_inputs, inputs)):
                    warnings.warn(
                        "The input matches the stored training data. Did you forget to call model.train()?",
                        GPInputWarning,
                    )

            # Get the terms that only depend on training data
            if self.prediction_strategy is None:
                train_output = super().__call__(*train_inputs, **kwargs)

                # Create the prediction strategy
                self.prediction_strategy = prediction_strategy(
                    train_inputs=train_inputs,
                    train_prior_dist=train_output,
                    train_labels=self.train_targets,
                    likelihood=self.likelihood,
                )

            # Concatenate the input to the training input
            full_inputs = []
            batch_shape = train_inputs[0].shape[:-2]
            for train_input, input in zip(train_inputs, inputs):
                # Make sure the batch shapes agree for training/test data
                if batch_shape != train_input.shape[:-2]:
                    batch_shape = _mul_broadcast_shape(batch_shape,
                                                       train_input.shape[:-2])
                    train_input = train_input.expand(*batch_shape,
                                                     *train_input.shape[-2:])
                if batch_shape != input.shape[:-2]:
                    batch_shape = _mul_broadcast_shape(batch_shape,
                                                       input.shape[:-2])
                    train_input = train_input.expand(*batch_shape,
                                                     *train_input.shape[-2:])
                    input = input.expand(*batch_shape, *input.shape[-2:])
                full_inputs.append(torch.cat([train_input, input], dim=-2))

            # Get the joint distribution for training/test data
            full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
            if settings.debug().on():
                if not isinstance(full_output, MultivariateNormal):
                    raise RuntimeError(
                        "ExactGP.forward must return a MultivariateNormal")
            full_mean, full_covar = full_output.loc, full_output.lazy_covariance_matrix

            # Determine the shape of the joint distribution
            batch_shape = full_output.batch_shape
            joint_shape = full_output.event_shape
            tasks_shape = joint_shape[1:]  # For multitask learning
            test_shape = torch.Size([
                joint_shape[0] - self.prediction_strategy.train_shape[0],
                *tasks_shape
            ])

            # Make the prediction
            with settings._use_eval_tolerance():
                predictive_mean, predictive_covar = self.prediction_strategy.exact_prediction(
                    full_mean, full_covar)

            # Reshape predictive mean to match the appropriate event shape
            predictive_mean = predictive_mean.view(*batch_shape,
                                                   *test_shape).contiguous()
            return full_output.__class__(predictive_mean, predictive_covar)
    def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
        # See if we need to update the grid or not
        if self.grid_is_dynamic:  # This is true if a grid_bounds wasn't passed in
            if torch.equal(x1, x2):
                x = x1.reshape(-1, self.num_dims)
            else:
                x = torch.cat([
                    x1.reshape(-1, self.num_dims),
                    x2.reshape(-1, self.num_dims)
                ])
            x_maxs = x.max(0)[0].tolist()
            x_mins = x.min(0)[0].tolist()

            # We need to update the grid if
            # 1) it hasn't ever been initialized, or
            # 2) if any of the grid points are "out of bounds"
            update_grid = (not self.has_initialized_grid.item()) or any(
                x_min < bound[0] or x_max > bound[1]
                for x_min, x_max, bound in zip(x_mins, x_maxs,
                                               self._tight_grid_bounds))

            # Update the grid if needed
            if update_grid:
                grid_spacings = tuple((x_max - x_min) / (gs - 4.02)
                                      for gs, x_min, x_max in zip(
                                          self.grid_sizes, x_mins, x_maxs))
                self.grid_bounds = tuple(
                    (x_min - 2.01 * spacing, x_max + 2.01 * spacing)
                    for x_min, x_max, spacing in zip(x_mins, x_maxs,
                                                     grid_spacings))
                grid = create_grid(
                    self.grid_sizes,
                    self.grid_bounds,
                    dtype=self.grid[0].dtype,
                    device=self.grid[0].device,
                )
                self.update_grid(grid)

        base_lazy_tsr = lazify(
            self._inducing_forward(last_dim_is_batch=last_dim_is_batch,
                                   **params))
        if last_dim_is_batch and base_lazy_tsr.size(-3) == 1:
            base_lazy_tsr = base_lazy_tsr.repeat(*x1.shape[:-2], x1.size(-1),
                                                 1, 1)

        left_interp_indices, left_interp_values = self._compute_grid(
            x1, last_dim_is_batch)
        if torch.equal(x1, x2):
            right_interp_indices = left_interp_indices
            right_interp_values = left_interp_values
        else:
            right_interp_indices, right_interp_values = self._compute_grid(
                x2, last_dim_is_batch)

        batch_shape = _mul_broadcast_shape(
            base_lazy_tsr.batch_shape,
            left_interp_indices.shape[:-2],
            right_interp_indices.shape[:-2],
        )
        res = InterpolatedLazyTensor(
            base_lazy_tsr.expand(*batch_shape, *base_lazy_tsr.matrix_shape),
            left_interp_indices.detach().expand(
                *batch_shape, *left_interp_indices.shape[-2:]),
            left_interp_values.expand(*batch_shape,
                                      *left_interp_values.shape[-2:]),
            right_interp_indices.detach().expand(
                *batch_shape, *right_interp_indices.shape[-2:]),
            right_interp_values.expand(*batch_shape,
                                       *right_interp_values.shape[-2:]),
        )

        if diag:
            return res.diag()
        else:
            return res
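A small usage sketch for the grid-interpolation (KISS-GP/SKI) kernel above, with names assumed from `gpytorch.kernels`:

import torch
import gpytorch

base = gpytorch.kernels.RBFKernel()
kernel = gpytorch.kernels.GridInterpolationKernel(base, grid_size=32, num_dims=1)

x = torch.linspace(0, 1, 10).unsqueeze(-1)  # 10 one-dimensional inputs
K = kernel(x, x)                            # lazily evaluated interpolated kernel matrix

print(K.shape)             # torch.Size([10, 10])
print(K.evaluate().shape)  # torch.Size([10, 10])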
Code Example #22
    def add_diag(self, added_diag):
        shape = _mul_broadcast_shape(self._diag.shape, added_diag.shape)
        return DiagLazyTensor(self._diag.expand(shape) + added_diag.expand(shape))
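Finally, the broadcast in this last snippet is what lets a single per-point noise vector be shared across a batch of diagonals (older `gpytorch.lazy` namespace assumed):

import torch
from gpytorch.lazy import DiagLazyTensor

d = DiagLazyTensor(torch.ones(2, 5))        # batch of two 5 x 5 diagonal matrices
bumped = d.add_diag(torch.full((5,), 0.1))  # the (5,) vector broadcasts against the (2, 5) diagonal

print(bumped.shape)   # torch.Size([2, 5, 5])
print(bumped.diag())  # values of 1.1 with shape (2, 5)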