def rsample(self, sample_shape=torch.Size()):
    """Draw reparameterized samples from a low-rank-plus-diagonal Gaussian.

    Each sample is ``loc + cov_factor @ eps_W + sqrt(cov_diag) * eps_D`` where
    ``eps_W`` (one entry per column of ``cov_factor``) and ``eps_D`` are
    independent standard-normal draws.
    """
    full_shape = self._extended_shape(sample_shape)
    rank_shape = self.cov_factor.shape[-1:]
    # Draw order matters for seeded reproducibility: low-rank noise first.
    noise_w = _standard_normal(full_shape[:-1] + rank_shape,
                               dtype=self.loc.dtype, device=self.loc.device)
    noise_d = _standard_normal(full_shape,
                               dtype=self.loc.dtype, device=self.loc.device)
    low_rank_term = _batch_mv(self._unbroadcasted_cov_factor, noise_w)
    diag_term = self._unbroadcasted_cov_diag.sqrt() * noise_d
    return self.loc + low_rank_term + diag_term
 def rsample(self, sample_shape=torch.Size()):
     """Draw reparameterized samples ``loc + cov_factor @ eps_W + sqrt(cov_diag) * eps_D``.

     ``sample_shape`` may be any sequence of ints; it is converted to
     ``torch.Size`` before being prepended to the batch and event shapes.
     """
     leading = (sample_shape if isinstance(sample_shape, torch.Size)
                else torch.Size(sample_shape))
     full_shape = leading + self._batch_shape + self._event_shape
     factor_shape = full_shape[:-1] + self.cov_factor.shape[-1:]
     # The low-rank noise is drawn before the diagonal noise.
     eps_factor = _standard_normal(factor_shape, dtype=self.loc.dtype,
                                   device=self.loc.device)
     eps_diag = _standard_normal(full_shape, dtype=self.loc.dtype,
                                 device=self.loc.device)
     mean_plus_low_rank = self.loc + _batch_mv(self.cov_factor, eps_factor)
     return mean_plus_low_rank + self.cov_diag.sqrt() * eps_diag
Ejemplo n.º 3
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw reparameterized samples ``mu + L @ eps`` with ``L = self.cholesky``.

     Args:
         sample_shape: leading sample dimensions. Defaults to ``torch.Size()``
             (a single sample), matching the signature of
             ``torch.distributions.Distribution.rsample``; previously calling
             ``rsample()`` with no argument raised TypeError.

     Returns:
         Tensor of shape ``self._extended_shape(sample_shape)``.
     """
     L = self.cholesky
     shape = self._extended_shape(sample_shape)
     eps = _standard_normal(shape,
                            dtype=self.mu.dtype,
                            device=self.mu.device)
     return self.mu + _batch_mv(L, eps)
Ejemplo n.º 4
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw reparameterized samples ``mu + sigma * eps`` with ``sigma = softplus(rho)``.

     ``softplus(rho) = log(1 + exp(rho))`` is evaluated with
     ``torch.nn.functional.softplus``, which is numerically stable. The naive
     ``torch.log1p(torch.exp(rho))`` overflows to ``inf`` once ``exp(rho)``
     exceeds the dtype's range (``rho`` above ~88 in float32). That would
     turn every sample into +/-inf or nan.
     """
     shape = self._extended_shape(sample_shape)
     sigma = torch.nn.functional.softplus(self.rho)
     eps = _standard_normal(shape,
                            dtype=self.mu.dtype,
                            device=self.mu.device)
     return self.mu + sigma * eps
Ejemplo n.º 5
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw a sample by pushing scaled Gaussian noise through ``exp_map_x_polar``.

     The noise ``scale * eps`` and its Euclidean norm along the last dimension
     are passed, together with ``loc`` expanded to the sample shape and the
     constant ``self.c``, to ``exp_map_x_polar``.
     """
     target_shape = self._extended_shape(sample_shape)
     noise = _standard_normal(target_shape, dtype=self.loc.dtype,
                              device=self.loc.device)
     tangent = self.scale * noise
     radius = tangent.norm(dim=-1, keepdim=True)
     anchor = self.loc.expand(target_shape)
     return exp_map_x_polar(anchor, radius, tangent, self.c)
Ejemplo n.º 6
0
    def rsample(self, sample_shape=torch.Size()):
        """Draw points on the unit sphere in R^(self._dim + 1).

        Standard-normal vectors of length ``self._dim + 1`` (float32, on
        ``self._device``) are divided by their Euclidean norm.
        """
        draw_shape = torch.Size(list(sample_shape) + [self._dim + 1])
        gaussian = _standard_normal(draw_shape, dtype=torch.float,
                                    device=self._device)
        lengths = gaussian.norm(dim=-1, keepdim=True)
        return gaussian / lengths
Ejemplo n.º 7
0
 def get_base_samples(self, sample_shape=torch.Size()):
     """Return i.i.d. standard-normal draws shaped like samples of this distribution.

     Intended to be passed back as ``rsample(base_samples=base_samples)``.
     Drawn under ``torch.no_grad()`` with the dtype and device of ``self.loc``.
     """
     with torch.no_grad():
         return _standard_normal(self._extended_shape(sample_shape),
                                 dtype=self.loc.dtype,
                                 device=self.loc.device)
Ejemplo n.º 8
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw samples ``loc + r * u`` with a random unit direction and a Gamma radius.

     ``u`` is a standard-normal vector normalized to unit L2 norm along the
     last dimension. ``r`` is a standard Gamma draw with concentration
     ``self.k`` (one per position of ``shape[:-1]``), multiplied by
     ``self.scale``.

     Returns:
         Tensor of shape ``self._extended_shape(sample_shape)``.
     """
     shape = self._extended_shape(sample_shape)
     dir_var = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
     dir_var /= dir_var.norm(p=2, dim=-1, keepdim=True)
     concentration = self.k * torch.ones(size=shape[:-1], dtype=self.loc.dtype,
                                         device=self.loc.device)
     # unsqueeze(-1) (rather than reshape((-1, 1))) keeps every sample/batch dim
     # aligned with dir_var. With reshape, multi-dim shapes failed to broadcast
     # and an empty sample shape produced an extra leading dimension.
     norms = torch._standard_gamma(concentration).unsqueeze(-1) * self.scale
     return self.loc + norms * dir_var
Ejemplo n.º 9
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw a reparameterized sample on ``self.manifold`` around ``self.loc``.

     Steps:
       1. Take scaled Gaussian noise ``scale * eps``.
       2. Check it as a tangent vector at ``manifold.zero``.
       3. Divide it by ``manifold.lambda_x`` evaluated there.
       4. Move it to ``loc`` with ``manifold.transp``.
       5. Map it onto the manifold with ``manifold.expmap``.
     """
     draw_shape = self._extended_shape(sample_shape)
     noise = _standard_normal(draw_shape, dtype=self.loc.dtype,
                              device=self.loc.device)
     tangent_at_origin = self.scale * noise
     origin = self.manifold.zero
     self.manifold.assert_check_vector_on_tangent(origin, tangent_at_origin)
     tangent_at_origin = tangent_at_origin / self.manifold.lambda_x(origin, keepdim=True)
     tangent_at_loc = self.manifold.transp(origin, self.loc, tangent_at_origin)
     return self.manifold.expmap(self.loc, tangent_at_loc)
Ejemplo n.º 10
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw reparameterized multivariate Student-t samples.

     Construction:
         X ~ Normal(0, I)
         Z ~ Chi2(df)
         Y = X / sqrt(Z / df) ~ MultivariateStudentT(df)
     The returned value is ``loc + scale_tril @ Y``.
     """
     gaussian = _standard_normal(self._extended_shape(sample_shape),
                                 dtype=self.df.dtype, device=self.df.device)
     chi2_draw = self._chi2.rsample(sample_shape)
     # Trailing axis so the chi-squared factor broadcasts over the last dimension.
     inv_radius = torch.rsqrt(chi2_draw / self.df).unsqueeze(-1)
     standardized = gaussian * inv_radius
     return self.loc + matvec(self.scale_tril, standardized)
Ejemplo n.º 11
0
 def rsample(self, sample_size=torch.Size()):
     """Draw a reparameterized sample on ``self.manifold`` centred at ``self.mu``.

     Steps:
       1. Draw Gaussian noise of shape ``sample_size``.
       2. Multiply it elementwise by ``self.Sigma``.
       3. Divide by ``manifold.lambda_x`` at ``self.zeros``.
       4. Move it from ``self.zeros`` to ``self.mu`` with ``manifold.transp``.
       5. Map it onto the manifold with ``manifold.expmap``.

     NOTE(review): the noise shape is exactly ``sample_size``; it is not
     extended with batch/event shape. With the default ``torch.Size()`` a
     single scalar draw is shared by every entry of ``Sigma``. Callers
     presumably pass the full shape. TODO confirm.
     """
     noise = _standard_normal(sample_size, dtype=self.mu.dtype,
                              device=self.mu.device)
     tangent = self.Sigma.mul(noise)
     tangent = tangent.div(self.manifold.lambda_x(self.zeros, keepdim=True))
     transported = self.manifold.transp(self.zeros, self.mu, tangent)
     return self.manifold.expmap(self.mu, transported)
Ejemplo n.º 12
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw reparameterized samples ``mu + L @ eps`` with ``L = self.cholesky``.

     Eigenvalue decomposition can also be used for sampling:
     https://stats.stackexchange.com/a/179275/79569

     Args:
         sample_shape: leading sample dimensions. Defaults to ``torch.Size()``
             (a single sample), matching the signature of
             ``torch.distributions.Distribution.rsample``; previously calling
             ``rsample()`` with no argument raised TypeError.

     Returns:
         Tensor of shape ``self._extended_shape(sample_shape)``.
     """
     L = self.cholesky
     shape = self._extended_shape(sample_shape)
     eps = _standard_normal(shape,
                            dtype=self.mu.dtype,
                            device=self.mu.device)
     return self.mu + _batch_mv(L, eps)
Ejemplo n.º 13
0
 def rsample(self, sample_shape=None):
     """Draw samples from a 2-D density built from two standard normals.

     For independent standard normals ``e0, e1``:
     ``z[..., 1] = 3 * e1`` and ``z[..., 0] = exp(z[..., 1] / 2) * e0``
     (presumably Neal's funnel; ``self`` is not used).

     Args:
         sample_shape: leading sample dimensions. ``None`` or empty draws a
             single sample of shape ``(2,)``.

     Returns:
         float32 CPU tensor of shape ``sample_shape + (2,)``.
     """
     if not sample_shape:
         sample_shape = torch.Size()
     # Use the full sample shape. Indexing sample_shape[0] crashed on the
     # (default) empty shape and silently dropped any extra sample dims.
     shape = torch.Size(sample_shape) + torch.Size([2])
     eps = _standard_normal(shape,
                            dtype=torch.float,
                            device=torch.device("cpu"))
     second = 3.0 * eps[..., 1]
     first = torch.exp(second / 2.0) * eps[..., 0]
     return torch.stack((first, second), dim=-1)
Ejemplo n.º 14
0
    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized Student-t samples ``loc + scale * Y``.

        Construction:
            X ~ Normal(0, 1)
            Z ~ Chi2(df)
            Y = X / sqrt(Z / df) ~ StudentT(df)

        NOTE: This does not agree with the scipy implementation as closely as
        other distributions do
        (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb).
        Using DoubleTensor parameters seems to help.
        """
        out_shape = self._extended_shape(sample_shape)
        gauss = _standard_normal(out_shape, dtype=self.df.dtype,
                                 device=self.df.device)
        chi2 = self._chi2.rsample(sample_shape)
        standardized = gauss * torch.rsqrt(chi2 / self.df)
        return self.loc + self.scale * standardized
Ejemplo n.º 15
0
    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized multivariate Student-t samples.

        Steps:
          1. Correlate standard-normal noise with ``_unbroadcasted_scale_tril``.
          2. Scale it by ``sqrt((df - 2) / r)``, where ``r`` comes from
             ``self.gamma``; one scale is shared across each whole event.
          3. Shift the result by ``loc``.
        """
        noise = _standard_normal(self._extended_shape(sample_shape),
                                 dtype=self.loc.dtype,
                                 device=self.loc.device)
        mixing = self.gamma.rsample(sample_shape=sample_shape)
        per_event = ((self.df - 2) / mixing).sqrt()
        # Append one singleton axis per event dimension so a single mixing
        # scale broadcasts across each event.
        trailing_ones = torch.Size([1] * len(self._event_shape))
        per_event = per_event.view(per_event.size() + trailing_ones)
        correlated = _batch_mv(self._unbroadcasted_scale_tril, noise)
        return self.loc + per_event * correlated
    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized samples ``loc + eps * softplus(logscale)``.

        Follows the pattern of ``torch.distributions.Normal.rsample`` but with a
        softplus-parameterized scale. The standard-normal noise is stored on
        ``self.eps`` so it can be accessed after the call.

        Args:
            sample_shape: leading sample dimensions; defaults to ``torch.Size()``
                to match ``torch.distributions.Distribution.rsample``.

        Returns:
            Tensor of shape ``self._extended_shape(sample_shape)``.
        """
        # The previous `if True: ... else: self.dist().rsample(...)` had an
        # unreachable else-branch; only the live path is kept.
        shape = self._extended_shape(sample_shape)
        self.eps = _standard_normal(shape,
                                    dtype=self.loc.dtype,
                                    device=self.loc.device)
        return self.loc + self.eps * F.softplus(self.logscale)
Ejemplo n.º 17
0
def deterministic_sample_mvnorm(distribution: MultivariateNormal,
                                eps: Optional[Tensor] = None) -> Tensor:
    """Sample ``distribution`` using caller-supplied or freshly drawn noise.

    Args:
        distribution: the multivariate normal to sample from.
        eps: one of
            - a Tensor of noise whose trailing dimensions equal
              ``distribution.event_shape``;
            - ``None``, to draw fresh standard-normal noise of shape
              ``batch_shape + event_shape``;
            - any other non-Tensor value, which multiplies freshly drawn noise.

    Returns:
        ``loc + scale_tril @ eps`` as a batched matrix-vector product.

    Raises:
        RuntimeError: if a Tensor ``eps`` does not end in ``event_shape``.
    """
    if not isinstance(eps, Tensor):
        noise_shape = distribution.batch_shape + distribution.event_shape
        eps = 1.0 if eps is None else eps
        # Non-Tensor eps acts as a multiplier on fresh noise.
        eps *= _standard_normal(noise_shape,
                                dtype=distribution.loc.dtype,
                                device=distribution.loc.device)
    else:
        n_event_dims = len(distribution.event_shape)
        if eps.shape[-n_event_dims:] != distribution.event_shape:
            raise RuntimeError(
                f"Expected shape ending in {distribution.event_shape}, got {eps.shape}."
            )
    scale_tril = distribution._unbroadcasted_scale_tril
    return distribution.loc + _batch_mv(scale_tril, eps)
Ejemplo n.º 18
0
 def deterministic_sample(self, eps=None) -> Tensor:
     """Draw one sample of shape ``batch_shape + event_shape`` from fixed or fresh noise.

     Args:
         eps: standard-normal noise of exactly ``batch_shape + event_shape``.
             If ``None``, fresh noise is drawn with the dtype/device of ``loc``.

     Returns:
         ``sqrt(squeeze(covariance_matrix, -1)) * eps + loc`` when
         ``self.univariate``; otherwise ``loc + scale_tril @ eps``.

     Raises:
         AssertionError: if ``eps`` is given with the wrong shape.
     """
     expected_shape = self._batch_shape + self._event_shape
     if eps is None:
         # One noise source for both branches. ``loc.new(*shape).normal_()``
         # returned a size-(0,) tensor when expected_shape was empty.
         eps = _standard_normal(expected_shape,
                                dtype=self.loc.dtype,
                                device=self.loc.device)
     elif eps.size() != expected_shape:
         # Raised explicitly (not via `assert`) so the check survives `python -O`.
         raise AssertionError(
             f"expected-shape:{expected_shape}, actual:{eps.size()}")
     if self.univariate:
         std = torch.sqrt(torch.squeeze(self.covariance_matrix, -1))
         return std * eps + self.loc
     return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)
Ejemplo n.º 19
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw reparameterized samples ``loc + eps * scale`` with ``eps ~ N(0, 1)``."""
     noise = _standard_normal(self._extended_shape(sample_shape),
                              dtype=self.loc.dtype, device=self.loc.device)
     return self.loc + noise * self.scale
Ejemplo n.º 20
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw reparameterized samples ``loc + scale_tril @ eps`` with ``eps ~ N(0, I)``."""
     noise = _standard_normal(self._extended_shape(sample_shape),
                              dtype=self.loc.dtype, device=self.loc.device)
     correlated = _batch_mv(self._unbroadcasted_scale_tril, noise)
     return self.loc + correlated
Ejemplo n.º 21
0
 def perturb(self, sample_shape=torch.Size()):
     """Return zero-mean noise ``eps * stddev`` shaped like samples of this distribution.

     ``eps`` is standard normal with the dtype/device of ``self.loc``; ``loc``
     itself is not added.
     """
     unit_noise = _standard_normal(self._extended_shape(sample_shape),
                                   dtype=self.loc.dtype, device=self.loc.device)
     return unit_noise * self.stddev