예제 #1
0
 def __init__(self, concentration, rate, validate_args=None):
     """Build the distribution as a Gamma pushed through ``x -> x ** -1``.

     :param concentration: concentration forwarded to ``Gamma``.
     :param rate: rate forwarded to ``Gamma``.
     :param validate_args: forwarded to the parent constructor only.
     """
     gamma = Gamma(concentration, rate)
     # Derive the -1 exponent from the rate tensor so it shares its dtype/device.
     minus_one = -gamma.rate.new_ones(())
     transform = PowerTransform(minus_one)
     super().__init__(gamma, transform, validate_args=validate_args)
예제 #2
0
 def __init__(self, concentration):
     """Prepare a Marsaglia & Tsang style rejection sampler.

     :param concentration: tensor of concentrations; every entry must be
         at least 1.
     :raises NotImplementedError: if any entry of ``concentration`` is
         below 1.
     """
     lowest = concentration.data.min()
     if lowest < 1:
         raise NotImplementedError('concentration < 1 is not supported')
     self.concentration = concentration
     # A rate of ones shaped like ``concentration``.
     unit_rate = concentration.new([1.]).squeeze().expand_as(concentration)
     self._standard_gamma = Gamma(concentration, unit_rate)
     # d and c follow the variable names used by Marsaglia & Tsang.
     shifted = self.concentration - 1.0 / 3.0
     self._d = shifted
     self._c = 1.0 / torch.sqrt(9.0 * shifted)
     # The log scale combines proposal, acceptance and self.log_prob terms
     # at a single point; the point is arbitrary, so the detached d is used.
     probe = self._d.detach()
     proposal_lp = self.propose_log_prob(probe)
     accept_lp = self.log_prob_accept(probe)
     target_lp = self.log_prob(probe)
     log_scale = proposal_lp + accept_lp - target_lp
     super(RejectionStandardGamma, self).__init__(
         self.propose, self.log_prob_accept, log_scale)
예제 #3
0
파일: hmm.py 프로젝트: youisbaby/pyro
    def filter(self, value):
        """
        Compute posteriors over the multiplier and the final state
        given a sequence of observations. The result is returned as a pair
        of Gamma and MultivariateNormal distributions (the two factors of a
        GammaGaussian posterior).

        :param ~torch.Tensor value: A sequence of observations.
        :return: A pair of posterior distributions: over the multiplier, and
            over the latent state at the final time step conditioned on a
            unit multiplier.
        :rtype: a tuple of ~pyro.distributions.Gamma and ~pyro.distributions.MultivariateNormal
        """
        # Combine observation and transition factors.
        logp = self._trans + self._obs.condition(value).event_pad(
            left=self.hidden_dim)

        # Eliminate time dimension.
        logp = _sequential_gamma_gaussian_tensordot(
            logp.expand(logp.batch_shape))

        # Combine initial factor.
        logp = gamma_gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

        # Posterior of the scale
        gamma_dist = logp.event_logsumexp()
        scale_post = Gamma(gamma_dist.concentration,
                           gamma_dist.rate,
                           validate_args=self._validate_args)

        # Conditional of last state on unit scale.
        # logp is in information form (precision P, info_vec h): the mean is
        # P^{-1} h, computed below via a Cholesky solve, and the covariance
        # is P^{-1}. The Cholesky factor of P is therefore not a valid
        # scale_tril (it would give covariance P), so P is passed as the
        # precision matrix instead.
        precision = logp.precision
        precision_tril = precision.cholesky()
        loc = logp.info_vec.unsqueeze(-1).cholesky_solve(
            precision_tril).squeeze(-1)
        mvn = MultivariateNormal(loc,
                                 precision_matrix=precision,
                                 validate_args=self._validate_args)
        return scale_post, mvn
예제 #4
0
파일: conjugate.py 프로젝트: www3cam/pyro
 def __init__(self, concentration, rate, validate_args=None):
     """Store a Gamma over the broadcast parameters and take its batch shape.

     :param concentration: concentration, broadcast against ``rate``.
     :param rate: rate, broadcast against ``concentration``.
     :param validate_args: forwarded to the parent constructor only.
     """
     alpha, beta = broadcast_all(concentration, rate)
     self._gamma = Gamma(alpha, beta)
     batch_shape = self._gamma._batch_shape
     super(GammaPoisson, self).__init__(batch_shape, validate_args=validate_args)
예제 #5
0
 def __init__(self, concentration1, concentration0, validate_args=None):
     """Initialise the parent and an auxiliary pair of unit-rate Gammas.

     :param concentration1: first concentration tensor, forwarded to the parent.
     :param concentration0: second concentration tensor, forwarded to the parent.
     :param validate_args: forwarded to the parent constructor.
     """
     super(NaiveBeta, self).__init__(concentration1,
                                     concentration0,
                                     validate_args=validate_args)
     # torch.stack requires equal shapes, while the parent presumably
     # broadcasts its arguments (torch's Beta does); broadcast first so
     # shape-compatible but unequal inputs do not fail here.
     concentration1, concentration0 = torch.broadcast_tensors(
         concentration1, concentration0)
     alpha_beta = torch.stack([concentration1, concentration0], -1)
     # Unit-rate Gammas whose last axis holds (concentration1, concentration0).
     self._gamma = Gamma(alpha_beta, torch.ones_like(alpha_beta))
예제 #6
0
 def __init__(self, concentration, validate_args=None):
     """Initialise the parent and an auxiliary set of unit-rate Gammas.

     :param concentration: concentration tensor, forwarded to the parent and
         used as the Gamma concentrations.
     :param validate_args: forwarded to both the parent and the auxiliary Gamma.
     """
     # Forward validate_args to the parent too, so the caller's choice applies
     # to the distribution itself and not only to the auxiliary Gamma.
     super(NaiveDirichlet, self).__init__(concentration,
                                          validate_args=validate_args)
     self._gamma = Gamma(concentration,
                         torch.ones_like(concentration),
                         validate_args=validate_args)
예제 #7
0
 def __init__(self, concentration, rate, validate_args=None):
     """Build the distribution as a Gamma pushed through ``x -> x ** -1``.

     :param concentration: concentration forwarded to ``Gamma``.
     :param rate: rate forwarded to ``Gamma``.
     :param validate_args: forwarded to the parent constructor only.
     """
     base_dist = Gamma(concentration, rate)
     # Create the -1 exponent from the rate tensor so it shares the
     # parameters' dtype and device, rather than relying on how
     # PowerTransform converts a bare Python float.
     exponent = -base_dist.rate.new_ones(())
     super(InverseGamma, self).__init__(base_dist, PowerTransform(exponent),
                                        validate_args=validate_args)
예제 #8
0
 def __init__(self, concentration, validate_args=None):
     """Initialise the parent and an auxiliary set of unit-rate Gammas.

     :param concentration: concentration tensor, forwarded to the parent and
         used as the Gamma concentrations.
     :param validate_args: optional; forwarded to both the parent and the
         auxiliary Gamma. The default ``None`` matches the previous behaviour
         of not passing it at all.
     """
     super(NaiveDirichlet, self).__init__(concentration,
                                          validate_args=validate_args)
     self._gamma = Gamma(concentration,
                         torch.ones_like(concentration),
                         validate_args=validate_args)