Example #1
    def forward(self, x):
        a = self.a(x)
        b = F.softplus(self.b(x))
        pdf = (torch._standard_gamma(a + b) * x**(a - 1) * (1 - x)**(b - 1)
               / (torch._standard_gamma(a) + torch._standard_gamma(b)))

        return pdf
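
Note that torch._standard_gamma(t) draws a sample from Gamma(t, 1); it does not evaluate the gamma function Γ(t), so the normalizer above is itself random. If the intent is the deterministic Beta(a, b) density x**(a-1) * (1-x)**(b-1) * Γ(a+b)/(Γ(a)Γ(b)), that can be written with torch.lgamma. A minimal sketch, with beta_pdf as an illustrative standalone name rather than this module's own method:

import torch

def beta_pdf(x, a, b):
    # Log of the normalizer Γ(a + b) / (Γ(a) Γ(b)), computed stably via lgamma.
    log_norm = torch.lgamma(a + b) - torch.lgamma(a) - torch.lgamma(b)
    # Beta density assembled in log space, then exponentiated.
    log_pdf = log_norm + (a - 1) * torch.log(x) + (b - 1) * torch.log1p(-x)
    return log_pdf.exp()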
Example #2
    def rsample(self, sample_shape=torch.Size()):
        """
        References
        ----------
        - Sawyer, S. (2007). Wishart Distributions and Inverse-Wishart Sampling.
          https://www.math.wustl.edu/~sawyer/hmhandouts/Wishart.pdf
        - Anderson, T. W. (2003). An Introduction to Multivariate Statistical Analysis (3rd ed.).
          John Wiley & Sons, Inc.
        - Odell, P. L. & Feiveson, A. H. (1966). A Numerical Procedure to Generate a Sample
          Covariance Matrix. Journal of the American Statistical Association, 61(313):199-203.
        - Ku, Y.-C. & Bloomfield, P. (2010). Generating Random Wishart Matrices with Fractional
          Degrees of Freedom in OX.
        """
        shape = torch.Size(sample_shape) + self.batch_shape
        dtype, device = self.concentration.dtype, self.concentration.device
        D = self.event_shape[-1]
        df = 2. * self.concentration  # type: torch.Tensor
        i = torch.arange(D, dtype=dtype, device=device)
        concentration = .5 * (df.unsqueeze(-1) - i).expand(shape + (D, ))
        V = 2. * torch._standard_gamma(concentration)
        N = torch.randn(*shape, D * (D - 1) // 2, dtype=dtype, device=device)
        T = torch.diag_embed(V.sqrt())  # T is lower-triangular
        i, j = torch.tril_indices(D, D, offset=-1)
        T[..., i, j] = N
        M = self.scale_tril @ T
        W = M @ M.transpose(-2, -1)
        return W
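
This is the Bartlett decomposition: chi-square variates on the diagonal of a lower-triangular factor T, standard normals strictly below the diagonal, and W = (L T)(L T)^T with L = scale_tril. A standalone sketch of the same construction; wishart_sample and its arguments are illustrative stand-ins for the class attributes used above:

import torch

def wishart_sample(df, scale_tril):
    # Bartlett decomposition: T has sqrt(chi-square(df - i)) entries on the
    # diagonal and standard normal entries strictly below it.
    dim = scale_tril.shape[-1]
    i = torch.arange(dim, dtype=scale_tril.dtype, device=scale_tril.device)
    chi2 = 2. * torch._standard_gamma(.5 * (df - i))
    T = torch.diag_embed(chi2.sqrt())
    r, c = torch.tril_indices(dim, dim, offset=-1)
    T[r, c] = torch.randn(dim * (dim - 1) // 2,
                          dtype=scale_tril.dtype, device=scale_tril.device)
    M = scale_tril @ T
    return M @ M.transpose(-2, -1)

# e.g. wishart_sample(torch.tensor(5.), torch.linalg.cholesky(2. * torch.eye(3)))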
Example #3
    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        # Direction: a standard normal vector normalized onto the unit sphere.
        dir_var = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        dir_var /= dir_var.norm(p=2, dim=-1, keepdim=True)
        # Radius: a Gamma(k, 1) draw scaled by self.scale.
        norms = torch._standard_gamma(
            self.k * torch.ones(size=shape[:-1], dtype=self.loc.dtype, device=self.loc.device)
        ).reshape((-1, 1)) * self.scale
        return self.loc + norms * dir_var
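
Here the direction is uniform on the unit sphere (a normalized Gaussian) and the radius is an independent Gamma(k, 1) draw scaled by self.scale. A standalone sketch of the same idea, using unsqueeze(-1) instead of reshape((-1, 1)) so the radius broadcasts over arbitrary batch shapes; radial_gamma_sample and its arguments are illustrative, not the class's attributes:

import torch

def radial_gamma_sample(loc, scale, k, sample_shape=torch.Size()):
    shape = torch.Size(sample_shape) + loc.shape
    # Uniform direction on the unit sphere.
    direction = torch.randn(shape, dtype=loc.dtype, device=loc.device)
    direction = direction / direction.norm(p=2, dim=-1, keepdim=True)
    # Gamma(k, 1)-distributed radius.
    radius = torch._standard_gamma(
        k * torch.ones(shape[:-1], dtype=loc.dtype, device=loc.device))
    return loc + scale * radius.unsqueeze(-1) * direction

# e.g. radial_gamma_sample(torch.zeros(3), 1.0, 2.0, sample_shape=(5,))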
Example #4
    def rsample(self, sample_shape=torch.Size()):
        odx = Categorical(logits=self.offset_logits).expand(
            self.batch_shape + self.event_shape).sample()
        offset = self.offset_samples[odx]
        shape = self._extended_shape(sample_shape)
        value = torch._standard_gamma(
            self.concentration.expand(shape)) / self.rate.expand(shape)
        # do not record in autograd graph
        value.detach().clamp_(min=torch.finfo(value.dtype).tiny)
        return value + offset
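
The core of this sampler is the standard reparameterized Gamma draw, torch._standard_gamma(concentration) / rate, followed by an in-place clamp on a detached view so the clamp itself is not recorded by autograd (the same pattern torch.distributions.Gamma.rsample uses). A minimal sketch of just that part, with gamma_rsample as an illustrative helper name:

import torch

def gamma_rsample(concentration, rate):
    # Reparameterized Gamma(concentration, rate) draw; gradients flow through
    # the implicit reparameterization of _standard_gamma.
    value = torch._standard_gamma(concentration) / rate
    # Clamp away exact zeros without adding the clamp to the autograd graph.
    value.detach().clamp_(min=torch.finfo(value.dtype).tiny)
    return value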
Example #5
    def forward(self, input):
        concentration = self._get_concentration(input.state)
        # The backward pass of the Dirichlet distribution is not implemented in PyTorch,
        # so sample using the Gamma distribution as outlined here:
        # https://en.wikipedia.org/wiki/Dirichlet_distribution#Random_number_generation
        gamma_samples = torch._standard_gamma(concentration)
        action = gamma_samples / torch.sum(gamma_samples, dim=1, keepdim=True)

        if not self.training:
            # ONNX doesn't like reshape either.
            return rlt.ActorOutput(action=action)

        log_prob = self.get_log_prob(input.state, action)
        return rlt.ActorOutput(action=action,
                               log_prob=log_prob.unsqueeze(dim=1))
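
The comment above refers to the standard construction: independent Gamma(alpha_i, 1) draws normalized by their sum are Dirichlet(alpha)-distributed. Current PyTorch versions do support reparameterized Dirichlet sampling, so the same draw is also available through the public API. A short sketch of both, outside the actor class:

import torch
from torch.distributions import Dirichlet

concentration = torch.tensor([[1.0, 2.0, 3.0]])

# Manual gamma-normalization trick, as in the forward pass above.
gamma_samples = torch._standard_gamma(concentration)
action = gamma_samples / gamma_samples.sum(dim=1, keepdim=True)

# A draw from the same distribution via the public, reparameterized API.
action_api = Dirichlet(concentration).rsample()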
Example #6
def _standard_wishart_tril(df: torch.Tensor, dim: int, shape: torch.Size):
    """
    References
    ----------
    - Sawyer, S. (2007). Wishart Distributions and Inverse-Wishart Sampling.
      https://www.math.wustl.edu/~sawyer/hmhandouts/Wishart.pdf
    - Anderson, T. W. (2003). An Introduction to Multivariate Statistical Analysis (3rd ed.).
      John Wiley & Sons, Inc.
    - Odell, P. L. & Feiveson, A. H. (1966). A Numerical Procedure to Generate a Sample
      Covariance Matrix. Journal of the American Statistical Association, 61(313):199-203.
    - Ku, Y.-C. & Bloomfield, P. (2010). Generating Random Wishart Matrices with Fractional
      Degrees of Freedom in OX.
    """
    dtype, device = df.dtype, df.device
    i = torch.arange(dim, dtype=dtype, device=device)
    concentration = .5 * (df.unsqueeze(-1) - i).expand(shape + (dim,))
    V = 2. * torch._standard_gamma(concentration)
    N = torch.randn(*shape, dim * (dim - 1) // 2, dtype=dtype, device=device)
    T = torch.diag_embed(V.sqrt())  # T is lower-triangular
    i, j = torch.tril_indices(dim, dim, offset=-1)
    T[..., i, j] = N
    return T
Example #7
def _dirichlet_sample_nograd(concentration):
    probs = torch._standard_gamma(concentration)
    probs /= probs.sum(-1, True)
    return clamp_probs(probs)
Example #8
def _standard_gamma(concentration):
    return torch._standard_gamma(concentration)
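
This thin wrapper only re-exports the private op. Where the public API is preferred, the same distribution is available as a reparameterized sample from torch.distributions.Gamma with unit rate; a usage sketch:

import torch
from torch.distributions import Gamma

concentration = torch.tensor([0.5, 1.0, 2.0])

x_private = torch._standard_gamma(concentration)  # Gamma(concentration, 1) via the private op
x_public = Gamma(concentration, torch.ones_like(concentration)).rsample()  # same distribution, public API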
Example #9
def _dirichlet_sample_nograd(concentration):
    probs = torch._standard_gamma(concentration)
    probs /= probs.sum(-1, True)
    eps = _finfo(probs).eps
    return probs.clamp_(min=eps, max=1 - eps)
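
_finfo here is an older helper from torch.distributions.utils; current PyTorch exposes the same machine epsilon through torch.finfo, which is essentially what clamp_probs in Example #7 does. The same function written against the current API, as a sketch:

import torch

def _dirichlet_sample_nograd(concentration):
    probs = torch._standard_gamma(concentration)
    probs /= probs.sum(-1, keepdim=True)
    eps = torch.finfo(probs.dtype).eps  # replaces the old _finfo(probs).eps
    return probs.clamp_(min=eps, max=1 - eps)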