import torch
from torch.distributions.utils import _standard_normal
from torch.distributions.multivariate_normal import _batch_mv  # private helpers used throughout these snippets

def rsample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    W_shape = shape[:-1] + self.cov_factor.shape[-1:]
    # Separate noise for the low-rank factor W and the diagonal D:
    # loc + W @ eps_W + sqrt(D) * eps_D has covariance W @ W.T + diag(D).
    eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
    eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
    return (self.loc + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
            + self._unbroadcasted_cov_diag.sqrt() * eps_D)
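
This is the reparameterized sampler of a low-rank multivariate normal: with W = cov_factor and D = cov_diag, the returned draw has covariance W @ W.T + diag(D). A minimal sketch of the same idea through the public torch.distributions.LowRankMultivariateNormal API; the parameter values are illustrative:

import torch

torch.manual_seed(0)
loc = torch.zeros(5)
cov_factor = torch.randn(5, 2)    # W: low-rank factor
cov_diag = torch.ones(5)          # D: diagonal term
d = torch.distributions.LowRankMultivariateNormal(loc, cov_factor, cov_diag)
x = d.rsample(torch.Size([4]))    # four differentiable draws, shape (4, 5)
print(torch.allclose(d.covariance_matrix,
                     cov_factor @ cov_factor.T + torch.diag(cov_diag)))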
Example No. 2
def rsample(self, sample_shape):
    # Reparameterized sampling: x = mu + L @ eps with eps ~ N(0, I),
    # where L is the Cholesky factor of the covariance matrix.
    L = self.cholesky
    shape = self._extended_shape(sample_shape)
    eps = _standard_normal(shape,
                           dtype=self.mu.dtype,
                           device=self.mu.device)
    return self.mu + _batch_mv(L, eps)
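
The snippet above is the standard reparameterization trick: x = mu + L @ eps with eps ~ N(0, I) has covariance L @ L.T. A quick empirical check of that identity, with made-up loc and scale_tril values:

import torch

torch.manual_seed(0)
loc = torch.zeros(3)
scale_tril = torch.tensor([[1.0, 0.0, 0.0],
                           [0.5, 1.0, 0.0],
                           [0.2, 0.3, 1.0]])
eps = torch.randn(100000, 3)
x = loc + eps @ scale_tril.T      # batched form of mu + L @ eps
print(torch.allclose(torch.cov(x.T), scale_tril @ scale_tril.T, atol=0.05))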
Example No. 3

def rsample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    # Same low-rank scheme as above, using new_empty().normal_() for the noise.
    eps_W = self.loc.new_empty(shape[:-1] +
                               (self.cov_factor.size(-1), )).normal_()
    eps_D = self.loc.new_empty(shape).normal_()
    return self.loc + _batch_mv(self.cov_factor,
                                eps_W) + self.cov_diag.sqrt() * eps_D
Example No. 4
def rsample(self, sample_shape):
    """Eigenvalue decomposition can also be used for sampling:
    https://stats.stackexchange.com/a/179275/79569
    """
    L = self.cholesky
    shape = self._extended_shape(sample_shape)
    eps = _standard_normal(shape,
                           dtype=self.mu.dtype,
                           device=self.mu.device)
    return self.mu + _batch_mv(L, eps)
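
As the linked answer explains, an eigendecomposition can replace the Cholesky factor: if cov = Q @ diag(evals) @ Q.T, then A = Q @ diag(sqrt(evals)) also satisfies A @ A.T = cov, so mu + A @ eps is an equally valid sample. A small sketch of that alternative, with an illustrative covariance matrix:

import torch

torch.manual_seed(0)
cov = torch.tensor([[2.0, 0.6],
                    [0.6, 1.0]])
evals, evecs = torch.linalg.eigh(cov)    # cov = Q @ diag(evals) @ Q.T
A = evecs * evals.sqrt()                 # scale each eigenvector column
mu = torch.zeros(2)
x = mu + torch.randn(100000, 2) @ A.T    # x ~ N(mu, cov)
print(torch.allclose(torch.cov(x.T), cov, atol=0.05))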
Example No. 5

def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Uses "Woodbury matrix identity"::
        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
    Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
    """
    Wt_Dinv = W.transpose(-1, -2) / D.unsqueeze(-2)
    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
    mahalanobis_term1 = (x.pow(2) / D).sum(-1)
    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
    return mahalanobis_term1 - mahalanobis_term2
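
A small sanity check makes the identity concrete: compute the quadratic form once with a dense solve and once via the two Woodbury terms. The sketch below re-derives both terms inline instead of calling the private _batch_mahalanobis helper; all values are illustrative:

import torch

torch.manual_seed(0)
n, k = 5, 2
W = torch.randn(n, k)
D = torch.rand(n) + 0.5                  # strictly positive diagonal
x = torch.randn(n)
cov = W @ W.T + torch.diag(D)
dense = x @ torch.linalg.solve(cov, x)   # x.T @ inv(W @ W.T + diag(D)) @ x
Wt_Dinv = W.T / D                        # W.T @ inv(diag(D))
C = torch.eye(k) + Wt_Dinv @ W           # capacitance matrix
term1 = (x.pow(2) / D).sum()
y = Wt_Dinv @ x
term2 = y @ torch.linalg.solve(C, y)
print(torch.allclose(dense, term1 - term2, atol=1e-4))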
Example No. 6

def rsample(self, sample_shape=torch.Size()):
    if not isinstance(sample_shape, torch.Size):
        sample_shape = torch.Size(sample_shape)
    # Assembles the extended shape by hand instead of via _extended_shape().
    shape = sample_shape + self._batch_shape + self._event_shape
    W_shape = shape[:-1] + self.cov_factor.shape[-1:]
    eps_W = _standard_normal(W_shape,
                             dtype=self.loc.dtype,
                             device=self.loc.device)
    eps_D = _standard_normal(shape,
                             dtype=self.loc.dtype,
                             device=self.loc.device)
    return self.loc + _batch_mv(self.cov_factor,
                                eps_W) + self.cov_diag.sqrt() * eps_D
Example No. 7
def rsample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    eps = _standard_normal(shape,
                           dtype=self.loc.dtype,
                           device=self.loc.device)

    r_inv = self.gamma.rsample(sample_shape=sample_shape)
    scale = ((self.df - 2) / r_inv).sqrt()
    # We want one gamma draw per event, not per event dimension; the size
    # of self.df together with this `.view` provides that.
    scale = scale.view(scale.size() +
                       torch.Size([1] * len(self._event_shape)))

    return self.loc + scale * _batch_mv(self._unbroadcasted_scale_tril, eps)
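
This builds a multivariate Student-t as a Gaussian scale mixture; the (df - 2) factor presumably normalizes the draw so that scale_tril factors the covariance directly, assuming self.gamma is the matching chi-square mixing distribution. The textbook form of the construction scales by sqrt(df / u) with u ~ Chi2(df), in which case the covariance picks up a df / (df - 2) factor, as this sketch with made-up parameters checks:

import torch

torch.manual_seed(0)
df = torch.tensor(7.0)
loc = torch.zeros(2)
scale_tril = torch.tensor([[1.0, 0.0],
                           [0.3, 0.8]])
u = torch.distributions.Chi2(df).sample((200000,))   # mixing variable
eps = torch.randn(200000, 2)
x = loc + (df / u).sqrt().unsqueeze(-1) * (eps @ scale_tril.T)
# For df > 2, Cov(x) = df / (df - 2) * scale_tril @ scale_tril.T
print(torch.allclose(torch.cov(x.T),
                     df / (df - 2) * scale_tril @ scale_tril.T, atol=0.1))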
Example No. 8
from typing import Optional
from torch import Tensor
from torch.distributions import MultivariateNormal

def deterministic_sample_mvnorm(distribution: MultivariateNormal,
                                eps: Optional[Tensor] = None) -> Tensor:
    if isinstance(eps, Tensor):
        # Caller-supplied noise must end in the distribution's event shape.
        if eps.shape[-len(distribution.event_shape):] != distribution.event_shape:
            raise RuntimeError(
                f"Expected shape ending in {distribution.event_shape}, got {eps.shape}."
            )
    else:
        shape = distribution.batch_shape + distribution.event_shape
        if eps is None:
            eps = 1.0
        eps *= _standard_normal(shape,
                                dtype=distribution.loc.dtype,
                                device=distribution.loc.device)
    return distribution.loc + _batch_mv(distribution._unbroadcasted_scale_tril,
                                        eps)
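
Passing the same eps to repeated calls yields identical draws, which is the point of a deterministic sampler (e.g., for common random numbers across parameter settings). A minimal usage sketch with illustrative parameters:

import torch
from torch.distributions import MultivariateNormal

mvn = MultivariateNormal(torch.zeros(3), scale_tril=torch.eye(3))
eps = torch.randn(3)                  # fixed noise, reused across calls
x1 = deterministic_sample_mvnorm(mvn, eps=eps)
x2 = deterministic_sample_mvnorm(mvn, eps=eps)
print(torch.equal(x1, x2))            # True: same eps gives the same draw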
Example No. 9
def deterministic_sample(self, eps=None) -> Tensor:
    expected_shape = self._batch_shape + self._event_shape
    if self.univariate:
        if eps is None:
            eps = self.loc.new(*expected_shape).normal_()
        else:
            assert eps.size() == expected_shape, \
                f"expected-shape:{expected_shape}, actual:{eps.size()}"
        # Univariate case: covariance_matrix holds the variance, so its sqrt is the std.
        std = torch.sqrt(torch.squeeze(self.covariance_matrix, -1))
        return std * eps + self.loc
    else:
        if eps is None:
            eps = _standard_normal(expected_shape,
                                   dtype=self.loc.dtype,
                                   device=self.loc.device)
        else:
            assert eps.size() == expected_shape, \
                f"expected-shape:{expected_shape}, actual:{eps.size()}"
        return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)