def forward(self, x):
    """Evaluate a Beta density at ``x`` with input-dependent shape parameters.

    The shape parameters are ``a = self.a(x)`` and
    ``b = softplus(self.b(x))``.  The returned value is

        Gamma(a + b) / (Gamma(a) * Gamma(b)) * x**(a - 1) * (1 - x)**(b - 1)

    where ``Gamma`` is the gamma *function*.

    Parameters
    ----------
    x : torch.Tensor
        Points at which the density is evaluated; also the input fed to
        ``self.a`` and ``self.b``.

    Returns
    -------
    torch.Tensor
        The density values, broadcast over ``x``, ``a`` and ``b``.
    """
    a = self.a(x)
    b = F.softplus(self.b(x))
    # NOTE(review): only ``b`` goes through softplus; ``a`` is used as
    # returned by ``self.a``.  ``lgamma`` yields log|Gamma|, so the result is
    # only a valid density when ``a`` is positive -- confirm upstream.

    # The normalising constant needs the gamma *function*.
    # ``torch._standard_gamma`` draws random Gamma samples, which would make
    # the output stochastic, and the denominator is a product, not a sum.
    # Working in log space keeps large shape parameters from overflowing.
    log_norm = torch.lgamma(a + b) - torch.lgamma(a) - torch.lgamma(b)
    pdf = torch.exp(log_norm) * (x ** (a - 1)) * ((1 - x) ** (b - 1))
    return pdf
def rsample(self, sample_shape=torch.Size()):
    """
    Draw a random matrix ``W = M @ M^T`` with ``M = scale_tril @ T``.

    ``T`` is lower-triangular: its diagonal holds the square roots of
    ``2 * Gamma(0.5 * (df - k))`` draws for ``k = 0 .. D - 1`` (with
    ``df = 2 * concentration``), and its strictly-lower entries are standard
    normal draws.  The result has shape
    ``sample_shape + batch_shape + (D, D)``.

    References
    ----------
    - Sawyer, S. (2007). Wishart Distributions and Inverse-Wishart Sampling.
      https://www.math.wustl.edu/~sawyer/hmhandouts/Wishart.pdf
    - Anderson, T. W. (2003). An Introduction to Multivariate Statistical
      Analysis (3rd ed.). John Wiley & Sons, Inc.
    - Odell, P. L. & Feiveson, A. H. (1966). A Numerical Procedure to
      Generate a Sample Covariance Matrix. Journal of the American
      Statistical Association, 61(313):199-203.
    - Ku, Y.-C. & Blomfield, P. (2010). Generating Random Wishart Matrices
      with Fractional Degrees of Freedom in OX.
    """
    dtype = self.concentration.dtype
    device = self.concentration.device
    D = self.event_shape[-1]
    shape = torch.Size(sample_shape) + self.batch_shape

    df = 2. * self.concentration  # type: torch.Tensor
    diag_offsets = torch.arange(D, dtype=dtype, device=device)
    concentration = .5 * (df.unsqueeze(-1) - diag_offsets).expand(shape + (D, ))

    # Random draws: the gamma variates come first, then the normals.
    V = 2. * torch._standard_gamma(concentration)
    num_lower = D * (D - 1) // 2
    N = torch.randn(*shape, num_lower, dtype=dtype, device=device)

    # Assemble the lower-triangular factor T.
    T = torch.diag_embed(torch.sqrt(V))
    rows, cols = torch.tril_indices(D, D, offset=-1)
    T[..., rows, cols] = N

    M = torch.matmul(self.scale_tril, T)
    return torch.matmul(M, M.transpose(-2, -1))
def rsample(self, sample_shape=torch.Size()):
    """Draw ``loc + radius * direction``.

    ``direction`` is a standard normal draw rescaled to unit L2 norm along
    the last dimension.  ``radius`` is a ``Gamma(self.k)`` draw (one per
    leading index) multiplied by ``self.scale``.

    Parameters
    ----------
    sample_shape : torch.Size, optional
        Leading shape of the draw; passed to ``self._extended_shape``.

    Returns
    -------
    torch.Tensor
        Tensor whose shape is ``self._extended_shape(sample_shape)``
        broadcast with ``self.loc`` and ``self.scale``.
    """
    shape = self._extended_shape(sample_shape)
    dtype, device = self.loc.dtype, self.loc.device

    direction = _standard_normal(shape, dtype=dtype, device=device)
    direction /= direction.norm(p=2, dim=-1, keepdim=True)

    concentration = self.k * torch.ones(size=shape[:-1], dtype=dtype, device=device)
    # Add a trailing singleton axis so the radius broadcasts over the last
    # dimension.  ``reshape((-1, 1))`` would flatten every leading
    # dimension: it fails to broadcast when there is more than one, and
    # turns an unbatched draw into ``(1, D)`` instead of ``(D,)``.
    radius = torch._standard_gamma(concentration).unsqueeze(-1) * self.scale
    return self.loc + radius * direction
def rsample(self, sample_shape=torch.Size()):
    """Draw ``Gamma(concentration) / rate`` and add a randomly chosen offset.

    An index into ``self.offset_samples`` is drawn from a categorical
    distribution with logits ``self.offset_logits``, expanded to
    ``batch_shape + event_shape``.  The gamma part is drawn at
    ``self._extended_shape(sample_shape)`` and clamped from below to the
    smallest positive normal value of its dtype.

    NOTE(review): the offset index is drawn without ``sample_shape``, so one
    set of offsets is shared by all draws of a single call -- confirm that
    this is intended.
    """
    # Pick the offsets first (this is the first consumer of random numbers).
    index_dist = Categorical(logits=self.offset_logits)
    offset_index = index_dist.expand(self.batch_shape + self.event_shape).sample()
    offset = self.offset_samples[offset_index]

    full_shape = self._extended_shape(sample_shape)
    gamma_draw = torch._standard_gamma(self.concentration.expand(full_shape))
    value = gamma_draw / self.rate.expand(full_shape)

    # Clamp in place through a detached alias so that the clamp itself is
    # not recorded in the autograd graph.
    smallest = torch.finfo(value.dtype).tiny
    value.detach().clamp_(min=smallest)
    return value + offset
def forward(self, input):
    """Sample an action on the simplex from the state's concentration.

    The concentration comes from ``self._get_concentration(input.state)``.
    In training mode the log-probability from ``self.get_log_prob`` is
    attached to the output with an extra trailing axis at ``dim=1``;
    otherwise only the action is returned.
    """
    concentration = self._get_concentration(input.state)

    # Backwards pass of dirichlet distribution not implemented in PyTorch
    # so sample using Gamma distribution outlined here:
    # https://en.wikipedia.org/wiki/Dirichlet_distribution#Random_number_generation
    gamma_draws = torch._standard_gamma(concentration)
    action = gamma_draws / gamma_draws.sum(dim=1, keepdim=True)

    if self.training:
        log_prob = self.get_log_prob(input.state, action)
        return rlt.ActorOutput(action=action, log_prob=log_prob.unsqueeze(dim=1))

    # ONNX doesn't like reshape either..
    return rlt.ActorOutput(action=action)
def _standard_wishart_tril(df: torch.Tensor, dim: int, shape: torch.Size):
    """
    Sample a random lower-triangular factor of shape ``shape + (dim, dim)``.

    The diagonal holds the square roots of ``2 * Gamma(0.5 * (df - k))``
    draws for ``k = 0 .. dim - 1``; the strictly-lower entries are standard
    normal draws; everything above the diagonal is zero.

    References
    ----------
    - Sawyer, S. (2007). Wishart Distributions and Inverse-Wishart Sampling.
      https://www.math.wustl.edu/~sawyer/hmhandouts/Wishart.pdf
    - Anderson, T. W. (2003). An Introduction to Multivariate Statistical
      Analysis (3rd ed.). John Wiley & Sons, Inc.
    - Odell, P. L. & Feiveson, A. H. (1966). A Numerical Procedure to
      Generate a Sample Covariance Matrix. Journal of the American
      Statistical Association, 61(313):199-203.
    - Ku, Y.-C. & Blomfield, P. (2010). Generating Random Wishart Matrices
      with Fractional Degrees of Freedom in OX.
    """
    dtype = df.dtype
    device = df.device

    diag_offsets = torch.arange(dim, dtype=dtype, device=device)
    concentration = .5 * (df.unsqueeze(-1) - diag_offsets).expand(shape + (dim,))

    # Random draws: the gamma variates come first, then the normals.
    diag_sq = 2. * torch._standard_gamma(concentration)
    num_lower = dim * (dim - 1) // 2
    lower_vals = torch.randn(*shape, num_lower, dtype=dtype, device=device)

    tril = torch.diag_embed(torch.sqrt(diag_sq))
    rows, cols = torch.tril_indices(dim, dim, offset=-1)
    tril[..., rows, cols] = lower_vals
    return tril
def _dirichlet_sample_nograd(concentration):
    """Sample points on the simplex by normalising Gamma draws.

    One Gamma draw is taken per entry of ``concentration``; the draws are
    divided by their sum over the last dimension and the result is passed
    through ``clamp_probs``.
    """
    gamma_draws = torch._standard_gamma(concentration)
    normalizer = gamma_draws.sum(dim=-1, keepdim=True)
    gamma_draws.div_(normalizer)  # in place, on the freshly sampled tensor
    return clamp_probs(gamma_draws)
def _standard_gamma(concentration):
    """Forward ``concentration`` to ``torch._standard_gamma`` and return its result."""
    sample = torch._standard_gamma(concentration)
    return sample
def _dirichlet_sample_nograd(concentration):
    """Sample points on the simplex by normalising Gamma draws.

    One Gamma draw is taken per entry of ``concentration``; the draws are
    divided by their sum over the last dimension and then clamped in place
    to ``[eps, 1 - eps]``, where ``eps`` is read from ``_finfo`` for the
    sampled tensor.
    """
    gamma_draws = torch._standard_gamma(concentration)
    gamma_draws.div_(gamma_draws.sum(dim=-1, keepdim=True))
    eps = _finfo(gamma_draws).eps
    lower, upper = eps, 1 - eps
    return gamma_draws.clamp_(min=lower, max=upper)