def forward(self, input):
    batch_shape = _mul_broadcast_shape(self.batch_shape, input.shape[:-2])
    mean = self.constant.unsqueeze(-1).expand(*batch_shape, input.size(-2), input.size(-1) + 1).contiguous()
    mean[..., 1:] = 0
    return mean
def batch_shape(self):
    kernels = list(self.sub_kernels())
    if len(kernels):
        return _mul_broadcast_shape(self._batch_shape, *[k.batch_shape for k in kernels])
    else:
        return self._batch_shape
def __add__(self, other):
    from .diag_lazy_tensor import DiagLazyTensor
    from .added_diag_lazy_tensor import AddedDiagLazyTensor

    if isinstance(other, ZeroLazyTensor):
        return self
    elif isinstance(other, DiagLazyTensor):
        return AddedDiagLazyTensor(self, other)
    elif isinstance(other, SumLazyTensor):
        return SumLazyTensor(*(list(self.lazy_tensors) + list(other.lazy_tensors)))
    elif isinstance(other, LazyTensor):
        return SumLazyTensor(*(list(self.lazy_tensors) + [other]))
    elif isinstance(other, Tensor):
        # get broadcast shape, assuming mul broadcasting the same as add broadcasting
        broadcasted_shape = _mul_broadcast_shape(self.shape, other.shape)
        # lazify + broadcast other
        broadcasted_other = lazify(other.expand(broadcasted_shape))
        # update the lazy tensors' shape as well
        new_self = self if broadcasted_shape == self.shape else self._expand_batch(broadcasted_shape[:-2])
        return SumLazyTensor(*(list(new_self.lazy_tensors) + [broadcasted_other]))
    else:
        raise AttributeError("other must be a LazyTensor")
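# Illustrative sketch (plain PyTorch, not part of the class above): the Tensor branch
# reuses `_mul_broadcast_shape` for addition because elementwise mul and add follow the
# same broadcasting rules, so the broadcast shapes coincide. Shapes below are arbitrary.
import torch

a = torch.randn(3, 1, 5, 5)  # e.g. a batch of 5 x 5 covariance blocks
b = torch.randn(4, 5, 5)     # a broadcastable addend
assert (a * b).shape == (a + b).shape == torch.Size([3, 4, 5, 5])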
def add_diag(self, added_diag: Tensor) -> "TriangularLazyTensor":
    from .added_diag_lazy_tensor import AddedDiagLazyTensor

    shape = _mul_broadcast_shape(self._diag.shape, added_diag.shape)
    added_diag_lt = AddedDiagLazyTensor(self._tensor.expand(shape), added_diag.expand(shape))
    return TriangularLazyTensor(added_diag_lt, upper=self.upper)
def forward(self, i1, i2, **params):
    covar_matrix = self._eval_covar_matrix()
    batch_shape = _mul_broadcast_shape(i1.shape[:-2], self.batch_shape)
    index_shape = batch_shape + i1.shape[-2:]

    res = InterpolatedLazyTensor(
        base_lazy_tensor=covar_matrix,
        left_interp_indices=i1.expand(index_shape),
        right_interp_indices=i2.expand(index_shape),
    )
    return res
def __init__(self, *lazy_tensors, preconditioner_override=None):
    lazy_tensors = list(lazy_tensors)
    super(AddedDiagLazyTensor, self).__init__(*lazy_tensors, preconditioner_override=preconditioner_override)
    if len(lazy_tensors) > 2:
        raise RuntimeError("An AddedDiagLazyTensor can only have two components")

    broadcasting._mul_broadcast_shape(lazy_tensors[0].shape, lazy_tensors[1].shape)

    if isinstance(lazy_tensors[0], DiagLazyTensor) and isinstance(lazy_tensors[1], DiagLazyTensor):
        raise RuntimeError("Trying to lazily add two DiagLazyTensors. Create a single DiagLazyTensor instead.")
    elif isinstance(lazy_tensors[0], DiagLazyTensor):
        self._diag_tensor = lazy_tensors[0]
        self._lazy_tensor = lazy_tensors[1]
    elif isinstance(lazy_tensors[1], DiagLazyTensor):
        self._diag_tensor = lazy_tensors[1]
        self._lazy_tensor = lazy_tensors[0]
    else:
        raise RuntimeError("One of the LazyTensors input to AddedDiagLazyTensor must be a DiagLazyTensor!")

    self.preconditioner_override = preconditioner_override

    # Placeholders
    self._constant_diag = None
    self._noise = None
    self._piv_chol_self = None  # <- Doesn't need to be an attribute, but used for testing purposes
    self._precond_lt = None
    self._precond_logdet_cache = None
    self._q_cache = None
    self._r_cache = None
def __init__(self, left_lazy_tensor, right_lazy_tensor):
    left_lazy_tensor = lazify(left_lazy_tensor)
    right_lazy_tensor = lazify(right_lazy_tensor)

    # Match batch dimensions
    batch_shape = _mul_broadcast_shape(left_lazy_tensor.batch_shape, right_lazy_tensor.batch_shape)
    if left_lazy_tensor.batch_shape != batch_shape:
        left_lazy_tensor = left_lazy_tensor._expand_batch(batch_shape)
    if right_lazy_tensor.batch_shape != batch_shape:
        right_lazy_tensor = right_lazy_tensor._expand_batch(batch_shape)

    super().__init__(left_lazy_tensor, right_lazy_tensor)
    # The local variables were already expanded to the common batch shape above,
    # so they can be stored directly without recomputing the broadcast.
    self.left_lazy_tensor = left_lazy_tensor
    self.right_lazy_tensor = right_lazy_tensor
def __init__(self, *lazy_tensors, **kwargs):
    try:
        lazy_tensors = tuple(lazify(lt) for lt in lazy_tensors)
    except TypeError:
        raise TypeError("All arguments of a SumLazyTensor should be LazyTensors or Tensors")

    batch_shape = _mul_broadcast_shape(*[lt.batch_shape for lt in lazy_tensors])
    lazy_tensors = tuple(
        lt._expand_batch(batch_shape) if lt.batch_shape != batch_shape else lt for lt in lazy_tensors
    )
    super(SumLazyTensor, self).__init__(*lazy_tensors, **kwargs)

    self.lazy_tensors = lazy_tensors
def _size(self):
    if settings.debug.on():
        if hasattr(self.kernel, "size"):
            raise RuntimeError("Kernels must define `num_outputs_per_input` and should not define `size`")

    x1 = self.x1
    x2 = self.x2

    num_outputs_per_input = self.kernel.num_outputs_per_input(x1, x2)
    num_rows = x1.size(-2) * num_outputs_per_input
    num_cols = x2.size(-2) * num_outputs_per_input

    # Default case - when we're not using broadcasting
    # We write this case specially for efficiency
    if x1.shape[:-2] == x2.shape[:-2] and x1.shape[:-2] == self.kernel.batch_shape:
        expected_size = self.kernel.batch_shape + torch.Size((num_rows, num_cols))
    # When we're using broadcasting
    else:
        expected_size = broadcasting._matmul_broadcast_shape(
            torch.Size([*x1.shape[:-2], num_rows, x1.size(-1)]),
            torch.Size([*x2.shape[:-2], x2.size(-1), num_cols]),
            error_msg="x1 and x2 were not broadcastable to a proper kernel shape. "
            "Got x1.shape = {} and x2.shape = {}".format(str(x1.shape), str(x2.shape)),
        )
        expected_size = (
            broadcasting._mul_broadcast_shape(
                expected_size[:-2],
                self.kernel.batch_shape,
                error_msg=(
                    f"x1 and x2 were not broadcastable with kernel of batch_shape {self.kernel.batch_shape}. "
                    f"Got x1.shape = {x1.shape} and x2.shape = {x2.shape}"
                ),
            )
            + expected_size[-2:]
        )

    # Handle when the last dim is batch
    if self.last_dim_is_batch:
        expected_size = expected_size[:-2] + x1.shape[-1:] + expected_size[-2:]

    return expected_size
def __call__(self, x, prior=False):
    # If we're in prior mode, then we're done!
    if prior:
        return self.model.forward(x)

    # Delete previously cached items from the training distribution
    if self.training:
        if hasattr(self, "_memoize_cache"):
            delattr(self, "_memoize_cache")
            self._memoize_cache = dict()

    # (Maybe) initialize variational distribution
    if not self.variational_params_initialized.item():
        prior_dist = self.prior_distribution
        self._variational_distribution.initialize_variational_distribution(prior_dist)
        self.variational_params_initialized.fill_(1)

    # Ensure inducing_points and x are the same size
    inducing_points = self.inducing_points
    if inducing_points.shape[:-2] != x.shape[:-2]:
        batch_shape = _mul_broadcast_shape(inducing_points.shape[:-2], x.shape[:-2])
        inducing_points = inducing_points.expand(*batch_shape, *inducing_points.shape[-2:])
        x = x.expand(*batch_shape, *x.shape[-2:])

    # Get p(u)/q(u)
    variational_dist_u = self.variational_distribution

    # Get q(f)
    if isinstance(variational_dist_u, MultivariateNormal):
        return super().__call__(
            x,
            inducing_points,
            inducing_values=variational_dist_u.mean,
            variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
        )
    elif isinstance(variational_dist_u, Delta):
        return super().__call__(
            x, inducing_points, inducing_values=variational_dist_u.mean, variational_inducing_covar=None
        )
    else:
        raise RuntimeError(
            f"Invalid variational distribution ({type(variational_dist_u)}). "
            "Expected a multivariate normal or a delta distribution."
        )
def _get_indices(self, row_index, col_index, *batch_indices):
    indices = [*batch_indices, row_index, col_index]
    target_shape = _mul_broadcast_shape(*[index.shape for index in indices])
    indices = [index.expand(target_shape).reshape(-1) for index in indices]
    cat_dim_indices = indices[self.cat_dim]

    # Find out for which indices we switch to different tensors
    target_tensors = self.idx_to_tensor_idx[cat_dim_indices]
    does_switch_tensor = torch.ones(target_tensors.numel() + 1, dtype=bool_compat, device=self.device)
    torch.ne(target_tensors[:-1], target_tensors[1:], out=does_switch_tensor[1:-1])

    # Get the LazyTensors that will comprise the new LazyTensor
    lazy_tensor_indices = target_tensors[does_switch_tensor[:-1]].tolist()
    lazy_tensors = [self.lazy_tensors[idx] for idx in lazy_tensor_indices]

    # Get the new set of indices for each of the LazyTensors
    switch_tensor = does_switch_tensor.nonzero(as_tuple=False).squeeze(-1)
    split_sizes = (switch_tensor[1:] - switch_tensor[:-1]).tolist()
    sub_indices = zip(
        *[
            list(index.split(split_sizes)) if torch.is_tensor(index) else [index] * len(split_sizes)
            for index in indices
        ]
    )
    # Make everything a list
    sub_indices = [list(sub_index) for sub_index in sub_indices]

    # Make sure that we have adjusted the start and ends of the indices that correspond to the cat dim
    for lazy_tensor_idx, sub_index in zip(lazy_tensor_indices, sub_indices):
        sub_index[self.cat_dim] = sub_index[self.cat_dim] - self.cat_dim_cum_sizes[lazy_tensor_idx]

    res_list = [
        lazy_tensor._get_indices(sub_index[-2], sub_index[-1], *sub_index[:-2])
        for lazy_tensor, sub_index in zip(lazy_tensors, sub_indices)
    ]
    if len(res_list) == 1:
        return res_list[0].view(target_shape).to(self.device)
    else:
        return torch.cat(res_list).view(target_shape).to(self.device)
def _inv_matmul(self, right_tensor, left_tensor=None):
    # Computes inv_matmul by exploiting the identity (A \kron B)^-1 = A^-1 \kron B^-1
    tsr_shapes = [q.size(-1) for q in self.lazy_tensors]
    n_rows = right_tensor.size(-2)
    batch_shape = _mul_broadcast_shape(self.shape[:-2], right_tensor.shape[:-2])
    perm_batch = tuple(range(len(batch_shape)))
    y = right_tensor.clone().expand(*batch_shape, *right_tensor.shape[-2:])
    for n, q in zip(tsr_shapes, self.lazy_tensors):
        # for KroneckerProductTriangularLazyTensor this inv_matmul is very cheap
        y = q.inv_matmul(y.reshape(*batch_shape, n, -1))
        y = y.reshape(*batch_shape, n, n_rows // n, -1).permute(*perm_batch, -2, -3, -1)
    res = y.reshape(*batch_shape, n_rows, -1)
    if left_tensor is not None:
        res = left_tensor @ res
    return res
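# Illustrative numerical check (plain PyTorch, independent of the LazyTensor API) of the
# identity exploited above: (A kron B)^-1 = A^-1 kron B^-1, for SPD factors A and B.
import torch

A = torch.randn(3, 3)
A = A @ A.T + 3 * torch.eye(3)  # make the factor symmetric positive definite
B = torch.randn(4, 4)
B = B @ B.T + 4 * torch.eye(4)

lhs = torch.kron(A, B).inverse()
rhs = torch.kron(A.inverse(), B.inverse())
assert torch.allclose(lhs, rhs, atol=1e-4)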
def _compute_grid(self, inputs):
    n_data, n_dimensions = inputs.size(-2), inputs.size(-1)
    batch_shape = inputs.shape[:-2]

    inputs = inputs.reshape(-1, n_dimensions)
    interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
    interp_indices = interp_indices.view(*batch_shape, n_data, -1)
    interp_values = interp_values.view(*batch_shape, n_data, -1)

    if (interp_indices.dim() - 2) != len(self._variational_distribution.batch_shape):
        batch_shape = _mul_broadcast_shape(interp_indices.shape[:-2], self._variational_distribution.batch_shape)
        interp_indices = interp_indices.expand(*batch_shape, *interp_indices.shape[-2:])
        interp_values = interp_values.expand(*batch_shape, *interp_values.shape[-2:])
    return interp_indices, interp_values
def forward(self, *params: Any, shape: Optional[torch.Size] = None, **kwargs: Any) -> DiagLazyTensor:
    """In the homoskedastic case, the parameters are only used to infer the required shape.
    Here are the possible scenarios:
    - non-batched noise, non-batched input, non-MT -> noise_diag shape is `n`
    - non-batched noise, non-batched input, MT -> noise_diag shape is `nt`
    - non-batched noise, batched input, non-MT -> noise_diag shape is `b x n`
    - non-batched noise, batched input, MT -> noise_diag shape is `b x nt`
    - batched noise, non-batched input, non-MT -> noise_diag shape is `b x n`
    - batched noise, non-batched input, MT -> noise_diag shape is `b x nt`
    - batched noise, batched input, non-MT -> noise_diag shape is `b' x n`
    - batched noise, batched input, MT -> noise_diag shape is `b' x nt`

    where `n` is the number of evaluation points, `t` is the number of tasks (i.e. `num_tasks` of self.noise),
    and `b'` is the broadcasted batch shape of the noise parameter and the input. So basically the shape is
    always `b' x nt`, with `b'` appropriately broadcast from the noise parameter and input batch shapes. `n`
    and the input batch shape are determined either from the shape arg or from the params input. For this it
    is sufficient to take in a single `shape` arg, with the convention that shape[:-1] is the batch shape of
    the input, and shape[-1] is `n`.

    If a "noise" kwarg (a Tensor) is provided, this noise is used directly.
    """
    if "noise" in kwargs:
        return DiagLazyTensor(kwargs.get("noise"))
    if shape is None:
        p = params[0] if torch.is_tensor(params[0]) else params[0][0]
        shape = p.shape if len(p.shape) == 1 else p.shape[:-1]
    noise = self.noise
    *batch_shape, n = shape
    noise_batch_shape = noise.shape[:-1] if noise.dim() > 1 else torch.Size()
    num_tasks = noise.shape[-1]
    batch_shape = _mul_broadcast_shape(noise_batch_shape, batch_shape)
    noise = noise.unsqueeze(-2)
    noise_diag = noise.expand(*batch_shape, n, num_tasks).contiguous()
    if num_tasks == 1:
        noise_diag = noise_diag.view(*batch_shape, n)
    return DiagLazyTensor(noise_diag)
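# Illustrative shape-only sketch (plain PyTorch) of the broadcasting described in the docstring
# above, assuming a batched noise parameter with num_tasks = 2; torch.broadcast_shapes stands in
# for the library's `_mul_broadcast_shape` helper, and all shapes below are hypothetical.
import torch

noise = torch.rand(3, 2)                   # noise batch shape b = 3, t = 2 tasks
input_batch_shape, n = torch.Size([3]), 5  # input batch shape and number of evaluation points

batch_shape = torch.broadcast_shapes(noise.shape[:-1], input_batch_shape)
noise_diag = noise.unsqueeze(-2).expand(*batch_shape, n, 2)
assert noise_diag.shape == torch.Size([3, 5, 2])  # i.e. b' x n x t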
def forward(self, input):
    if input.shape[:-2] == self.batch_shape:
        return self.constant.expand(input.shape[:-1])
    else:
        return self.constant.expand(_mul_broadcast_shape(input.shape[:-1], self.constant.shape))
def _size(self):
    return _mul_broadcast_shape(*[lt.shape for lt in self.lazy_tensors])
def forward(self, x, inducing_points, inducing_values, variational_inducing_covar=None):
    # If our points equal the inducing points, we're done
    if torch.equal(x, inducing_points):
        if variational_inducing_covar is None:
            raise RuntimeError
        else:
            return MultivariateNormal(inducing_values, variational_inducing_covar)

    # Otherwise, we have to marginalize
    num_induc = inducing_points.size(-2)
    full_inputs = torch.cat([inducing_points, x], dim=-2)
    full_output = self.model.forward(full_inputs)
    full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix

    # Mean terms
    test_mean = full_mean[..., num_induc:]
    induc_mean = full_mean[..., :num_induc]
    mean_diff = (inducing_values - induc_mean).unsqueeze(-1)

    # Covariance terms
    induc_induc_covar = full_covar[..., :num_induc, :num_induc].add_jitter()
    induc_data_covar = full_covar[..., :num_induc, num_induc:].evaluate()
    data_data_covar = full_covar[..., num_induc:, num_induc:]

    # If we're less than a certain size, we'll compute the Cholesky decomposition of induc_induc_covar
    cholesky = False
    if settings.fast_computations.log_prob.off() or (num_induc <= settings.max_cholesky_size.value()):
        induc_induc_covar = CholLazyTensor(self._cholesky_factor(induc_induc_covar))
        cholesky = True

    # If we are making predictions and don't need variances, we can do things very quickly.
    if not self.training and settings.skip_posterior_variances.on():
        if not hasattr(self, "_mean_cache"):
            # For now: run variational inference without a preconditioner
            # The preconditioner screws things up for some reason
            with settings.max_preconditioner_size(0):
                self._mean_cache = induc_induc_covar.inv_matmul(mean_diff).detach()
        predictive_mean = torch.add(
            test_mean, induc_data_covar.transpose(-2, -1).matmul(self._mean_cache).squeeze(-1)
        )
        predictive_covar = ZeroLazyTensor(test_mean.size(-1), test_mean.size(-1))
        return MultivariateNormal(predictive_mean, predictive_covar)

    # Expand everything to the right size
    shapes = [mean_diff.shape[:-1], induc_data_covar.shape[:-1], induc_induc_covar.shape[:-1]]
    if variational_inducing_covar is not None:
        root_variational_covar = variational_inducing_covar.root_decomposition().root.evaluate()
        shapes.append(root_variational_covar.shape[:-1])
    shape = _mul_broadcast_shape(*shapes)
    mean_diff = mean_diff.expand(*shape, mean_diff.size(-1))
    induc_data_covar = induc_data_covar.expand(*shape, induc_data_covar.size(-1))
    induc_induc_covar = induc_induc_covar.expand(*shape, induc_induc_covar.size(-1))
    if variational_inducing_covar is not None:
        root_variational_covar = root_variational_covar.expand(*shape, root_variational_covar.size(-1))

    # Cache the CG results
    # For now: run variational inference without a preconditioner
    # The preconditioner screws things up for some reason
    with settings.max_preconditioner_size(0):
        # Cache the CG results
        if variational_inducing_covar is None:
            left_tensors = mean_diff
        else:
            left_tensors = torch.cat([mean_diff, root_variational_covar], -1)
        with torch.no_grad():
            eager_rhs = torch.cat([left_tensors, induc_data_covar], -1)
            solve, probe_vecs, probe_vec_norms, probe_vec_solves, tmats = CachedCGLazyTensor.precompute_terms(
                induc_induc_covar,
                eager_rhs.detach(),
                logdet_terms=(not cholesky),
                include_tmats=(not settings.skip_logdet_forward.on() and not cholesky),
            )
            eager_rhss = [
                eager_rhs.detach(),
                eager_rhs[..., left_tensors.size(-1):].detach(),
                eager_rhs[..., :left_tensors.size(-1)].detach(),
            ]
            solves = [
                solve.detach(),
                solve[..., left_tensors.size(-1):].detach(),
                solve[..., :left_tensors.size(-1)].detach(),
            ]
            if settings.skip_logdet_forward.on():
                eager_rhss.append(torch.cat([probe_vecs, left_tensors], -1))
                solves.append(torch.cat([probe_vec_solves, solve[..., :left_tensors.size(-1)]], -1))

        induc_induc_covar = CachedCGLazyTensor(
            induc_induc_covar,
            eager_rhss=eager_rhss,
            solves=solves,
            probe_vectors=probe_vecs,
            probe_vector_norms=probe_vec_norms,
            probe_vector_solves=probe_vec_solves,
            probe_vector_tmats=tmats,
        )

    # Cache the kernel matrix with the cached CG calls
    if self.training:
        self._memoize_cache["prior_distribution_memo"] = MultivariateNormal(induc_mean, induc_induc_covar)

    # Compute predictive mean
    inv_products = induc_induc_covar.inv_matmul(induc_data_covar, left_tensors.transpose(-1, -2))
    predictive_mean = torch.add(test_mean, inv_products[..., 0, :])

    # Compute covariance
    if self.training:
        interp_data_data_var, _ = induc_induc_covar.inv_quad_logdet(
            induc_data_covar, logdet=False, reduce_inv_quad=False
        )
        data_covariance = DiagLazyTensor((data_data_covar.diag() - interp_data_data_var).clamp(0, math.inf))
    else:
        neg_induc_data_data_covar = torch.matmul(
            induc_data_covar.transpose(-1, -2).mul(-1), induc_induc_covar.inv_matmul(induc_data_covar)
        )
        data_covariance = data_data_covar + neg_induc_data_data_covar
    predictive_covar = PsdSumLazyTensor(
        RootLazyTensor(inv_products[..., 1:, :].transpose(-1, -2)), data_covariance
    )

    # Done!
    return MultivariateNormal(predictive_mean, predictive_covar)
def get_fantasy_model(self, inputs, targets, **kwargs):
    """
    Returns a new GP model that incorporates the specified inputs and targets as new training data.

    Using this method is more efficient than updating with `set_train_data` when the number of inputs is
    relatively small, because any computed test-time caches will be updated in linear time rather than
    computed from scratch.

    .. note::
        If `targets` is a batch (e.g. `b x m`), then the GP returned from this method will be a batch mode GP.
        If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy
        points are the same for each target batch.

    :param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
        observations.
    :param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
    :return: An `ExactGP` model with `n + m` training examples, where the `m` fantasy examples have been added
        and all test-time caches have been updated.
    :rtype: ~gpytorch.models.ExactGP
    """
    if self.prediction_strategy is None:
        raise RuntimeError(
            "Fantasy observations can only be added after making predictions with a model so that "
            "all test independent caches exist. Call the model on some data first!"
        )

    model_batch_shape = self.train_inputs[0].shape[:-2]

    if self.train_targets.dim() > len(model_batch_shape) + 1:
        raise RuntimeError("Cannot yet add fantasy observations to multitask GPs, but this is coming soon!")

    if not isinstance(inputs, list):
        inputs = [inputs]

    inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in inputs]

    target_batch_shape = targets.shape[:-1]
    input_batch_shape = inputs[0].shape[:-2]
    tbdim, ibdim = len(target_batch_shape), len(input_batch_shape)

    if not (tbdim == ibdim + 1 or tbdim == ibdim):
        raise RuntimeError(
            f"Unsupported batch shapes: The target batch shape ({target_batch_shape}) must have either the "
            f"same dimension as or one more dimension than the input batch shape ({input_batch_shape})"
        )

    # Check whether we can properly broadcast batch dimensions
    err_msg = (
        f"Model batch shape ({model_batch_shape}) and target batch shape "
        f"({target_batch_shape}) are not broadcastable."
    )
    _mul_broadcast_shape(model_batch_shape, target_batch_shape, error_msg=err_msg)

    if len(model_batch_shape) > len(input_batch_shape):
        input_batch_shape = model_batch_shape
    if len(model_batch_shape) > len(target_batch_shape):
        target_batch_shape = model_batch_shape

    # If input has no fantasy batch dimension but target does, we can save memory and computation by not
    # computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
    # size of the fantasy model here - this is done below, after the evaluation and fast fantasy update.
    train_inputs = [tin.expand(input_batch_shape + tin.shape[-2:]) for tin in self.train_inputs]
    train_targets = self.train_targets.expand(target_batch_shape + self.train_targets.shape[-1:])

    full_inputs = [
        torch.cat([train_input, input.expand(input_batch_shape + input.shape[-2:])], dim=-2)
        for train_input, input in zip(train_inputs, inputs)
    ]
    full_targets = torch.cat([train_targets, targets.expand(target_batch_shape + targets.shape[-1:])], dim=-1)

    try:
        fantasy_kwargs = {"noise": kwargs.pop("noise")}
    except KeyError:
        fantasy_kwargs = {}

    full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)

    # Copy model without copying training data or prediction strategy (since we'll overwrite those)
    old_pred_strat = self.prediction_strategy
    old_train_inputs = self.train_inputs
    old_train_targets = self.train_targets
    old_likelihood = self.likelihood
    self.prediction_strategy = None
    self.train_inputs = None
    self.train_targets = None
    self.likelihood = None
    new_model = deepcopy(self)
    self.prediction_strategy = old_pred_strat
    self.train_inputs = old_train_inputs
    self.train_targets = old_train_targets
    self.likelihood = old_likelihood

    new_model.likelihood = old_likelihood.get_fantasy_likelihood(**fantasy_kwargs)
    new_model.prediction_strategy = old_pred_strat.get_fantasy_strategy(
        inputs, targets, full_inputs, full_targets, full_output, **fantasy_kwargs
    )

    # if the fantasies are at the same points, we need to expand the inputs for the new model
    if tbdim == ibdim + 1:
        new_model.train_inputs = [fi.expand(target_batch_shape + fi.shape[-2:]) for fi in full_inputs]
    else:
        new_model.train_inputs = full_inputs
    new_model.train_targets = full_targets

    return new_model
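# Hedged, self-contained usage sketch of get_fantasy_model (assumes gpytorch is importable;
# the `TinyGP` model below is a generic single-task ExactGP written for illustration only,
# not taken from this file).
import torch
import gpytorch

class TinyGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.RBFKernel()

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

train_x, train_y = torch.randn(10, 2), torch.randn(10)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = TinyGP(train_x, train_y, likelihood)

model.eval()
with torch.no_grad():
    _ = model(torch.randn(5, 2))  # a posterior call populates the prediction strategy / caches

# Append 4 fantasy observations; the test-time caches are updated rather than recomputed.
fantasy_model = model.get_fantasy_model(torch.randn(4, 2), torch.randn(4))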
def mul(self, other):
    shape = _mul_broadcast_shape(self.shape, other.shape)
    return self.__class__(*shape, dtype=self._dtype, device=self._device)
def __call__(self, *args, **kwargs):
    train_inputs = list(self.train_inputs) if self.train_inputs is not None else []
    inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in args]

    # Training mode: optimizing
    if self.training:
        if self.train_inputs is None:
            raise RuntimeError(
                "train_inputs, train_targets cannot be None in training mode. "
                "Call .eval() for prior predictions, or call .set_train_data() to add training data."
            )
        if settings.debug.on():
            if not all(torch.equal(train_input, input) for train_input, input in zip(train_inputs, inputs)):
                raise RuntimeError("You must train on the training inputs!")
        res = super().__call__(*inputs, **kwargs)
        return res

    # Prior mode
    elif settings.prior_mode.on() or self.train_inputs is None or self.train_targets is None:
        full_inputs = args
        full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
        if settings.debug().on():
            if not isinstance(full_output, MultivariateNormal):
                raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
        return full_output

    # Posterior mode
    else:
        if settings.debug.on():
            if all(torch.equal(train_input, input) for train_input, input in zip(train_inputs, inputs)):
                warnings.warn(
                    "The input matches the stored training data. Did you forget to call model.train()?",
                    GPInputWarning,
                )

        # Get the terms that only depend on training data
        if self.prediction_strategy is None:
            train_output = super().__call__(*train_inputs, **kwargs)

            # Create the prediction strategy
            self.prediction_strategy = prediction_strategy(
                train_inputs=train_inputs,
                train_prior_dist=train_output,
                train_labels=self.train_targets,
                likelihood=self.likelihood,
            )

        # Concatenate the input to the training input
        full_inputs = []
        batch_shape = train_inputs[0].shape[:-2]
        for train_input, input in zip(train_inputs, inputs):
            # Make sure the batch shapes agree for training/test data
            if batch_shape != train_input.shape[:-2]:
                batch_shape = _mul_broadcast_shape(batch_shape, train_input.shape[:-2])
                train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
            if batch_shape != input.shape[:-2]:
                batch_shape = _mul_broadcast_shape(batch_shape, input.shape[:-2])
                train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
                input = input.expand(*batch_shape, *input.shape[-2:])
            full_inputs.append(torch.cat([train_input, input], dim=-2))

        # Get the joint distribution for training/test data
        full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
        if settings.debug().on():
            if not isinstance(full_output, MultivariateNormal):
                raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
        full_mean, full_covar = full_output.loc, full_output.lazy_covariance_matrix

        # Determine the shape of the joint distribution
        batch_shape = full_output.batch_shape
        joint_shape = full_output.event_shape
        tasks_shape = joint_shape[1:]  # For multitask learning
        test_shape = torch.Size([joint_shape[0] - self.prediction_strategy.train_shape[0], *tasks_shape])

        # Make the prediction
        with settings._use_eval_tolerance():
            predictive_mean, predictive_covar = self.prediction_strategy.exact_prediction(full_mean, full_covar)

        # Reshape predictive mean to match the appropriate event shape
        predictive_mean = predictive_mean.view(*batch_shape, *test_shape).contiguous()
        return full_output.__class__(predictive_mean, predictive_covar)
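# Hedged usage sketch of the train/posterior dispatch above, reusing the illustrative
# `TinyGP`, `train_x`, `train_y`, and `likelihood` objects from the fantasy-model sketch earlier.
model = TinyGP(train_x, train_y, likelihood)

model.train()
output = model(train_x)  # training mode: must be evaluated at the training inputs

model.eval()
with torch.no_grad():
    pred = model(torch.randn(5, 2))  # posterior mode: joint train/test distribution, then conditioned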
def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
    # See if we need to update the grid or not
    if self.grid_is_dynamic:  # This is true if a grid_bounds wasn't passed in
        if torch.equal(x1, x2):
            x = x1.reshape(-1, self.num_dims)
        else:
            x = torch.cat([x1.reshape(-1, self.num_dims), x2.reshape(-1, self.num_dims)])
        x_maxs = x.max(0)[0].tolist()
        x_mins = x.min(0)[0].tolist()

        # We need to update the grid if
        # 1) it hasn't ever been initialized, or
        # 2) if any of the grid points are "out of bounds"
        update_grid = (not self.has_initialized_grid.item()) or any(
            x_min < bound[0] or x_max > bound[1]
            for x_min, x_max, bound in zip(x_mins, x_maxs, self._tight_grid_bounds)
        )

        # Update the grid if needed
        if update_grid:
            grid_spacings = tuple(
                (x_max - x_min) / (gs - 4.02) for gs, x_min, x_max in zip(self.grid_sizes, x_mins, x_maxs)
            )
            self.grid_bounds = tuple(
                (x_min - 2.01 * spacing, x_max + 2.01 * spacing)
                for x_min, x_max, spacing in zip(x_mins, x_maxs, grid_spacings)
            )
            grid = create_grid(
                self.grid_sizes,
                self.grid_bounds,
                dtype=self.grid[0].dtype,
                device=self.grid[0].device,
            )
            self.update_grid(grid)

    base_lazy_tsr = lazify(self._inducing_forward(last_dim_is_batch=last_dim_is_batch, **params))
    if last_dim_is_batch and base_lazy_tsr.size(-3) == 1:
        base_lazy_tsr = base_lazy_tsr.repeat(*x1.shape[:-2], x1.size(-1), 1, 1)

    left_interp_indices, left_interp_values = self._compute_grid(x1, last_dim_is_batch)
    if torch.equal(x1, x2):
        right_interp_indices = left_interp_indices
        right_interp_values = left_interp_values
    else:
        right_interp_indices, right_interp_values = self._compute_grid(x2, last_dim_is_batch)

    batch_shape = _mul_broadcast_shape(
        base_lazy_tsr.batch_shape,
        left_interp_indices.shape[:-2],
        right_interp_indices.shape[:-2],
    )
    res = InterpolatedLazyTensor(
        base_lazy_tsr.expand(*batch_shape, *base_lazy_tsr.matrix_shape),
        left_interp_indices.detach().expand(*batch_shape, *left_interp_indices.shape[-2:]),
        left_interp_values.expand(*batch_shape, *left_interp_values.shape[-2:]),
        right_interp_indices.detach().expand(*batch_shape, *right_interp_indices.shape[-2:]),
        right_interp_values.expand(*batch_shape, *right_interp_values.shape[-2:]),
    )

    if diag:
        return res.diag()
    else:
        return res
def add_diag(self, added_diag):
    shape = _mul_broadcast_shape(self._diag.shape, added_diag.shape)
    return DiagLazyTensor(self._diag.expand(shape) + added_diag.expand(shape))