Example 1
 def __init__(self, probs=None, logits=None):
     if (probs is None) == (logits is None):
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")
     if probs is not None:
         self.probs, = broadcast_all(probs)
     else:
         self.logits, = broadcast_all(logits)
     probs_or_logits = probs if probs is not None else logits
     if isinstance(probs_or_logits, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = probs_or_logits.size()
     super(Bernoulli, self).__init__(batch_shape)
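A minimal sketch (mine, not part of the listing) of the pattern used throughout these examples: broadcast_all from torch.distributions.utils promotes plain numbers to tensors and broadcasts all arguments to a common shape, which is why a Number argument yields an empty batch_shape.

 import torch
 from torch.distributions.utils import broadcast_all

 p, = broadcast_all(0.3)                       # number -> 0-dim tensor
 print(p.shape)                                # torch.Size([])
 loc, scale = broadcast_all(torch.zeros(3, 1), torch.ones(2))
 print(loc.shape, scale.shape)                 # torch.Size([3, 2]) twice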
Example 2
 def __init__(self, probs=None, logits=None, validate_args=None):
     if (probs is None) == (logits is None):
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")
     if probs is not None:
         self.probs, = broadcast_all(probs)
         if not self.probs.gt(0).all():
             raise ValueError('All elements of probs must be greater than 0')
     else:
         self.logits, = broadcast_all(logits)
     probs_or_logits = probs if probs is not None else logits
     if isinstance(probs_or_logits, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = probs_or_logits.size()
     super(Geometric, self).__init__(batch_shape, validate_args=validate_args)
Example 3
 def __init__(self, probs=None, logits=None, validate_args=None):
     if (probs is None) == (logits is None):
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")
     if probs is not None:
         is_scalar = isinstance(probs, Number)
         self.probs, = broadcast_all(probs)
     else:
         is_scalar = isinstance(logits, Number)
         self.logits, = broadcast_all(logits)
     self._param = self.probs if probs is not None else self.logits
     if is_scalar:
         batch_shape = torch.Size()
     else:
         batch_shape = self._param.size()
     super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)
Example 4
 def __init__(self, low, high):
     self.low, self.high = broadcast_all(low, high)
     if isinstance(low, Number) and isinstance(high, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.low.size()
     super(Uniform, self).__init__(batch_shape)
Example 5
 def __init__(self, concentration, rate):
     self.concentration, self.rate = broadcast_all(concentration, rate)
     if isinstance(concentration, Number) and isinstance(rate, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.concentration.size()
     super(Gamma, self).__init__(batch_shape)
Example 6
 def __init__(self, loc, scale):
     self.loc, self.scale = broadcast_all(loc, scale)
     if isinstance(loc, Number) and isinstance(scale, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.loc.size()
     super(Laplace, self).__init__(batch_shape)
Example 7
 def __init__(self, loc, scale, validate_args=None):
     self.loc, self.scale = broadcast_all(loc, scale)
     if isinstance(loc, Number) and isinstance(scale, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.loc.size()
     super(Normal, self).__init__(batch_shape, validate_args=validate_args)
Example 8
 def __init__(self, rate, validate_args=None):
     self.rate, = broadcast_all(rate)
     if isinstance(rate, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.rate.size()
     super(Poisson, self).__init__(batch_shape, validate_args=validate_args)
Example 9
 def __init__(self, alpha, beta):
     self.alpha, self.beta = broadcast_all(alpha, beta)
     if isinstance(alpha, Number) and isinstance(beta, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.alpha.size()
     super(Gamma, self).__init__(batch_shape)
Example 10
 def __init__(self, scale, alpha):
     self.scale, self.alpha = broadcast_all(scale, alpha)
     if isinstance(scale, Number) and isinstance(alpha, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.scale.size()
     super(Pareto, self).__init__(batch_shape)
Example 11
 def __init__(self, concentration1, concentration0, validate_args=None):
     if isinstance(concentration1, Number) and isinstance(concentration0, Number):
         concentration1_concentration0 = torch.tensor([float(concentration1), float(concentration0)])
     else:
         concentration1, concentration0 = broadcast_all(concentration1, concentration0)
         concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
     self._dirichlet = Dirichlet(concentration1_concentration0)
     super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)
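A quick sketch (mine) of what this construction yields: scalar concentrations give an empty batch shape, while tensor arguments are broadcast and stacked into a trailing Dirichlet dimension of size 2.

 import torch
 from torch.distributions import Beta

 print(Beta(0.5, 0.5).batch_shape)                       # torch.Size([])
 print(Beta(torch.ones(4), torch.ones(4)).batch_shape)   # torch.Size([4])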
Example 12
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     gate, rate, value = broadcast_all(self.gate, self.rate, value)
     log_prob = (-gate).log1p() + (rate.log() * value) - rate - (value + 1).lgamma()
     zeros = value == 0
     log_prob[zeros] = (gate[zeros] + log_prob[zeros].exp()).log()
     return log_prob
Example 13
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)

        gate, value = broadcast_all(self.gate, value)
        log_prob = (-gate).log1p() + self.base_dist.log_prob(value)
        log_prob = torch.where(value == 0, (gate + log_prob.exp()).log(), log_prob)
        return log_prob
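A numerical sanity check (my own sketch) of the zero-inflation rule above, using a Poisson base: the mixture puts mass gate on zero and (1 - gate) on the base distribution.

 import torch

 gate = torch.tensor(0.3)
 base = torch.distributions.Poisson(torch.tensor(2.0))
 value = torch.tensor([0., 1., 2.])
 log_prob = (-gate).log1p() + base.log_prob(value)
 log_prob = torch.where(value == 0, (gate + log_prob.exp()).log(), log_prob)
 # P(0) = gate + (1 - gate) * Poisson(2).pmf(0) = 0.3 + 0.7 * exp(-2) ~ 0.3947
 print(log_prob[0].exp())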
Example 14
 def __init__(self, l, k):
     self.l, self.k = broadcast_all(l, k)
     self.const = 1e-8
     if isinstance(k, Number) and isinstance(l, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.k.size()
     super(SimpleWeibull, self).__init__(batch_shape=batch_shape)
Example 15
 def __init__(self, rate, censoring, validate_args=None):
     self.rate, self.censoring = broadcast_all(rate, censoring)
     if isinstance(rate, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.rate.size()
     super(PyroCensoredPoison, self).__init__(batch_shape,
                                              validate_args=validate_args)
Example 16
 def __init__(self, loc, scale, beta, validate_args=None):
     self.loc, self.scale = broadcast_all(loc, scale)
     self.beta = beta
     if isinstance(loc, Number) and isinstance(scale, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.loc.size()
     super(GeneralisedNormal, self).__init__(batch_shape, validate_args=validate_args)
Example 17
 def log_prob(self, value):
     self._validate_log_prob_arg(value)
     logits, value = broadcast_all(self.logits.clone(), value)
     log_factorial_n = torch.lgamma(value.sum(-1) + 1)
     log_factorial_xs = torch.lgamma(value + 1).sum(-1)
     logits[(value == 0) & (logits == -float('inf'))] = 0
     log_powers = (logits * value).sum(-1)
     return log_factorial_n - log_factorial_xs + log_powers
Example 18
 def __init__(self, probs=None, logits=None):
     if (probs is None) == (logits is None):
         raise ValueError(
             "Either `probs` or `logits` must be specified, but not both.")
     if probs is not None:
         self.probs, = broadcast_all(probs)
         if not self.probs.gt(0).all():
             raise ValueError(
                 'All elements of probs must be greater than 0')
     else:
         self.logits, = broadcast_all(logits)
     probs_or_logits = probs if probs is not None else logits
     if isinstance(probs_or_logits, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = probs_or_logits.size()
     super(Geometric, self).__init__(batch_shape)
Example 19
 def __init__(self, alpha, beta):
     if isinstance(alpha, Number) and isinstance(beta, Number):
         alpha_beta = torch.Tensor([alpha, beta])
     else:
         alpha, beta = broadcast_all(alpha, beta)
         alpha_beta = torch.stack([alpha, beta], -1)
     self._dirichlet = Dirichlet(alpha_beta)
     super(Beta, self).__init__(self._dirichlet._batch_shape)
Example 20
 def __init__(self, concentration1, concentration0, validate_args=None):
     if isinstance(concentration1, Number) and isinstance(concentration0, Number):
         concentration1_concentration0 = torch.tensor([float(concentration1), float(concentration0)])
     else:
         concentration1, concentration0 = broadcast_all(concentration1, concentration0)
         concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
     self._dirichlet = Dirichlet(concentration1_concentration0)
     super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)
Example 21
 def __init__(self, mu, alpha, validate_args=None):
     self.mu, self.alpha = broadcast_all(mu, alpha)
     if isinstance(mu, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.mu.size()
     super(PyroNegBinomial, self).__init__(batch_shape,
                                           validate_args=validate_args)
Example 22
 def __init__(self, probs=None, logits=None, validate_args=None):
     if (probs is None) == (logits is None):
         raise ValueError(
             "Either `probs` or `logits` must be specified, but not both.")
     if probs is not None:
         is_scalar = isinstance(probs, Number)
         self.probs, = broadcast_all(probs)
     else:
         is_scalar = isinstance(logits, Number)
         self.logits, = broadcast_all(logits)
     self._param = self.probs if probs is not None else self.logits
     if is_scalar:
         batch_shape = torch.Size()
     else:
         batch_shape = self._param.size()
     super(Bernoulli, self).__init__(batch_shape,
                                     validate_args=validate_args)
Example 23
 def log_prob(self, value):
     self._validate_log_prob_arg(value)
     logits, value = broadcast_all(self.logits.clone(), value)
     log_factorial_n = torch.lgamma(value.sum(-1) + 1)
     log_factorial_xs = torch.lgamma(value + 1).sum(-1)
     logits[(value == 0) & (logits == -float('inf'))] = 0
     log_powers = (logits * value).sum(-1)
     return log_factorial_n - log_factorial_xs + log_powers
Example 24
 def __init__(self, concentration1, concentration0):
     if isinstance(concentration1, Number) and isinstance(concentration0, Number):
         concentration1_concentration0 = variable([concentration1, concentration0])
     else:
         concentration1, concentration0 = broadcast_all(concentration1, concentration0)
         concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
     self._dirichlet = Dirichlet(concentration1_concentration0)
     super(Beta, self).__init__(self._dirichlet._batch_shape)
Example 25
 def __init__(self, concentration1, concentration0):
     if isinstance(concentration1, Number) and isinstance(concentration0, Number):
         concentration1_concentration0 = torch.Tensor([concentration1, concentration0])
     else:
         concentration1, concentration0 = broadcast_all(concentration1, concentration0)
         concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
     self._dirichlet = Dirichlet(concentration1_concentration0)
     super(Beta, self).__init__(self._dirichlet._batch_shape)
Example 26
 def __init__(self, concentration, validate_args=None):
     self.concentration, = broadcast_all(concentration)
     batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
     super(Dirichlet, self).__init__(batch_shape, event_shape, validate_args=validate_args)
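A short sketch (mine) of the shape split above: the last dimension of concentration becomes the event shape, everything before it the batch shape.

 import torch
 from torch.distributions import Dirichlet

 d = Dirichlet(torch.ones(5, 3))
 print(d.batch_shape, d.event_shape)  # torch.Size([5]) torch.Size([3])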
Example 27
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     censored = self.censoring == 1
     rate, value = broadcast_all(self.rate, value)
     log_prob = (rate.log() * value) - rate - (value + 1).lgamma()
     # For censored observations use the log survival function, with a small
     # epsilon for numerical stability. self.cdf always returns a tensor, so
     # the original isinstance(..., Number) branch was dead code.
     log_prob[censored] = (1 - self.cdf(value)[censored] + 1e-6).log()
     return log_prob
Example 28
 def __init__(self, scale, concentration, validate_args=None):
     self.scale, self.concentration = broadcast_all(scale, concentration)
     self.concentration_reciprocal = self.concentration.reciprocal()
     base_dist = Exponential(torch.ones_like(self.scale), validate_args=validate_args)
     transforms = [PowerTransform(exponent=self.concentration_reciprocal),
                   AffineTransform(loc=0, scale=self.scale)]
     super(Weibull, self).__init__(base_dist,
                                   transforms,
                                   validate_args=validate_args)
Example 29
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        mu, alpha, value = broadcast_all(self.mu, self.alpha, value)
        value = value.float()

        d = torch.distributions.negative_binomial.NegativeBinomial(
            total_count=1 / alpha, logits=alpha * mu)
        return d.log_prob(value)
Example 30
 def __init__(self, pi, loc, scale, alpha=0.05, validate_args=None):
     self.loc, self.scale = broadcast_all(loc, scale)
     self.pi = pi
     self.alpha = torch.tensor(alpha).to(self.loc.device)
     if isinstance(scale, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.scale.size()
     super(Sparse, self).__init__(batch_shape)
Example 31
        def __init__(self, loc, validate_args=None):
            self.loc = broadcast_all(loc)[0]

            if isinstance(loc, Number):
                batch_shape = torch.Size()
            else:
                batch_shape = self.loc.size()
            super(TorchDeterministic, self).__init__(batch_shape, validate_args=validate_args)
Example 32
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     gate, rate, value = broadcast_all(self.gate, self.rate, value)
     log_prob = (-gate).log1p() + (rate.log() *
                                   value) - rate - (value + 1).lgamma()
     log_prob = torch.where(value == 0, (gate + log_prob.exp()).log(),
                            log_prob)
     return log_prob
Example 33
def test_broadcast_all(shapes):
    inputs, dim_to_symbol, symbol_to_dim = make_inputs(shapes)
    packed_inputs = [packed.pack(x, dim_to_symbol) for x in inputs]
    packed_outputs = packed.broadcast_all(*packed_inputs)
    actual = tuple(packed.unpack(x, symbol_to_dim) for x in packed_outputs)
    expected = broadcast_all(*inputs) if inputs else []
    assert len(actual) == len(expected)
    for a, e in zip(actual, expected):
        assert_equal(a, e)
Example 34
    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:
            self.total_count, self.probs, = broadcast_all(total_count, probs)
            self.total_count = self.total_count.type_as(self.probs)
            is_scalar = isinstance(self.probs, Number)
        else:
            self.total_count, self.logits, = broadcast_all(total_count, logits)
            self.total_count = self.total_count.type_as(self.logits)
            is_scalar = isinstance(self.logits, Number)

        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super(Binomial, self).__init__(batch_shape, validate_args=validate_args)
Example 35
 def __init__(self, loc, scale, censoring, validate_args=None):
     
     self.loc, self.scale, self.censoring = broadcast_all(loc, scale, censoring)
     
     if isinstance(loc, Number) and isinstance(scale, Number) and isinstance(censoring, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.loc.size()
     super(PyroCensoredNormal, self).__init__(batch_shape, validate_args=validate_args)
Example 36
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     logits, value = broadcast_all(self.logits.clone(memory_format=torch.contiguous_format), value)
     log_factorial_n = torch.lgamma(value.sum(-1) + 1)
     log_factorial_xs = torch.lgamma(value + 1).sum(-1)
     logits[(value == 0) & (logits == -inf)] = 0
     log_powers = (logits * value).sum(-1)
     return log_factorial_n - log_factorial_xs + log_powers
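A sketch (mine) of why the -inf patch above is needed: a category with logit -inf and count 0 would otherwise contribute 0 * (-inf) = nan to the sum.

 import torch

 logits = torch.tensor([0.0, float('-inf')])
 value = torch.tensor([2.0, 0.0])
 print((logits * value).sum(-1))    # tensor(nan)
 patched = logits.clone()
 patched[(value == 0) & (patched == float('-inf'))] = 0
 print((patched * value).sum(-1))   # tensor(0.)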
Example 37
 def __init__(self, gamma, loc, scale, validate_args=None):
     self.loc, self.scale = broadcast_all(loc, scale)
     self.gamma = gamma
     self.alpha = torch.tensor(0.05).to(self.loc.device)
     if isinstance(scale, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.scale.size()
     super(Sparse_torch, self).__init__(batch_shape)
Example 38
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)

        if "gate" in self.__dict__:
            gate, value = broadcast_all(self.gate, value)
            log_prob = (-gate).log1p() + self.base_dist.log_prob(value)
            log_prob = torch.where(value == 0, (gate + log_prob.exp()).log(),
                                   log_prob)
        else:
            gate_logits, value = broadcast_all(self.gate_logits, value)
            log_prob_minus_log_gate = -gate_logits + self.base_dist.log_prob(
                value)
            log_gate = -softplus(-gate_logits)
            log_prob = log_prob_minus_log_gate + log_gate
            zero_log_prob = softplus(log_prob_minus_log_gate) + log_gate
            log_prob = torch.where(value == 0, zero_log_prob, log_prob)
        return log_prob
Example 39
def gather(value, index, dim):
    """
    Broadcasted gather of indexed values along a named dim.
    """
    value, index = broadcast_all(value, index)
    with ignore_jit_warnings():
        zero = torch.zeros(1, dtype=torch.long, device=index.device)
    index = index.index_select(dim, zero)
    return value.gather(dim, index)
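An illustrative sketch (mine) of what the helper computes, written in plain torch: broadcast value and index together, then pick index 0 along dim for every batch element.

 import torch
 from torch.distributions.utils import broadcast_all

 value = torch.arange(12.).reshape(3, 4)
 index = torch.tensor([[1], [0], [3]])        # which column to pick per row
 value, index = broadcast_all(value, index)   # both become shape (3, 4)
 picked = value.gather(-1, index.index_select(-1, torch.tensor([0])))
 print(picked.squeeze(-1))                    # tensor([ 1.,  4., 11.])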
Example 40
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        diff = logits - value.mul(self.temperature)

        out = self.temperature.log() + diff - 2 * diff.exp().log1p()

        return out
Example 41
    def __init__(self, loc, logscale, validate_args=None):

        self.loc, self.logscale = broadcast_all(loc, logscale)

        if isinstance(loc, Number) and isinstance(logscale, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.loc.size()
        super(MyNormal, self).__init__(batch_shape, validate_args=validate_args)
Example 42
    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:
            self.total_count, self.probs, = broadcast_all(total_count, probs)
            self.total_count = self.total_count.type_as(self.probs)
            is_scalar = isinstance(self.probs, Number)
        else:
            self.total_count, self.logits, = broadcast_all(total_count, logits)
            self.total_count = self.total_count.type_as(self.logits)
            is_scalar = isinstance(self.logits, Number)

        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super(Binomial, self).__init__(batch_shape, validate_args=validate_args)
Example 43
 def __init__(self, mu, alpha, lam, validate_args=None):
     self.mu, self.alpha, self.lam = broadcast_all(mu, alpha, lam)
     if isinstance(mu, Number) and isinstance(alpha, Number) \
             and isinstance(lam, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.mu.size()
     super(TorchNegativeBinomialPoissonConvApprox,
           self).__init__(batch_shape, validate_args=validate_args)
Example 44
    def __init__(self, total_count=1, probs=None, logits=None):
        if not isinstance(total_count, Number):
            raise NotImplementedError('inhomogeneous total_count is not supported')
        self.total_count = total_count
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            self.probs, = broadcast_all(probs)
        else:
            is_scalar = isinstance(logits, Number)
            self.logits, = broadcast_all(logits)

        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super(Binomial, self).__init__(batch_shape)
Example 45
    def __init__(self, df1, df2, validate_args=None):
        self.df1, self.df2 = broadcast_all(df1, df2)
        self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
        self._gamma2 = Gamma(self.df2 * 0.5, self.df2)

        if isinstance(df1, Number) and isinstance(df2, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.df1.size()
        super(FisherSnedecor, self).__init__(batch_shape, validate_args=validate_args)
Example 46
 def log_prob(self, value):
     K = self._categorical._num_events
     if self._validate_args:
         self._validate_sample(value)
     logits, value = broadcast_all(self.logits, value)
     log_scale = (self.temperature.new(self.temperature.shape).fill_(K).lgamma() -
                  self.temperature.log().mul(-(K - 1)))
     score = logits - value.mul(self.temperature)
     score = (score - _log_sum_exp(score)).sum(-1)
     return score + log_scale
Example 47
    def __init__(self, df1, df2):
        self.df1, self.df2 = broadcast_all(df1, df2)
        self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
        self._gamma2 = Gamma(self.df2 * 0.5, self.df2)

        if isinstance(df1, Number) and isinstance(df2, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.df1.size()
        super(FisherSnedecor, self).__init__(batch_shape)
Example 48
    def __init__(self, low, high, validate_args=None):
        self.low, self.high = broadcast_all(low, high)

        if isinstance(low, Number) and isinstance(high, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.low.size()
        super(Uniform, self).__init__(batch_shape, validate_args=validate_args)

        if self._validate_args and not torch.lt(self.low, self.high).all():
            raise ValueError("Uniform is not defined when low>= high")
Example 49
 def __init__(self, loc, scale, validate_args=None):
     self.loc, self.scale = broadcast_all(loc, scale)
     finfo = _finfo(self.loc)
     if isinstance(loc, Number) and isinstance(scale, Number):
         batch_shape = torch.Size()
         base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
     else:
         batch_shape = self.scale.size()
         base_dist = Uniform(self.loc.new(self.loc.size()).fill_(finfo.tiny), 1 - finfo.eps)
     transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
                   ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
     super(Gumbel, self).__init__(base_dist, transforms, validate_args=validate_args)
Example 50
 def __init__(self, rate, validate_args=None):
     self.rate, = broadcast_all(rate)
     batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size()
     super(Exponential, self).__init__(batch_shape, validate_args=validate_args)
Example 51
 def __init__(self, concentration):
     self.concentration, = broadcast_all(concentration)
     batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
     super(Dirichlet, self).__init__(batch_shape, event_shape)
Example 52
 def log_prob(self, value):
     self._validate_log_prob_arg(value)
     logits, value = broadcast_all(self.logits, value)
     diff = logits - value.mul(self.temperature)
     return self.temperature.log() + diff - 2 * diff.exp().log1p()
Example 53
 def __init__(self, df, loc=0., scale=1.):
     self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
     self._chi2 = Chi2(df)
     batch_shape = torch.Size() if isinstance(df, Number) else self.df.size()
     super(StudentT, self).__init__(batch_shape)
Example 54
 def log_prob(self, value):
     self._validate_log_prob_arg(value)
     logits, value = broadcast_all(self.logits, value)
     return -binary_cross_entropy_with_logits(logits, value, reduce=False)
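A quick numerical check (my sketch): this is exactly the Bernoulli log-probability, i.e. the negative binary cross entropy with logits.

 import torch
 import torch.nn.functional as F

 logits = torch.randn(4)
 value = torch.tensor([0., 1., 1., 0.])
 lp = torch.distributions.Bernoulli(logits=logits).log_prob(value)
 bce = F.binary_cross_entropy_with_logits(logits, value, reduction='none')
 print(torch.allclose(lp, -bce))  # True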
Example 55
 def __init__(self, alpha):
     self.alpha, = broadcast_all(alpha)
     batch_shape, event_shape = alpha.shape[:-1], alpha.shape[-1:]
     super(Dirichlet, self).__init__(batch_shape, event_shape)
Example 56
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     rate, value = broadcast_all(self.rate, value)
     return (rate.log() * value) - rate - (value + 1).lgamma()
Example 57
 def __init__(self, loc, scale):
     loc, scale = broadcast_all(loc, scale)
     base_dist = Cauchy(0, scale)
     transforms = [AbsTransform(), AffineTransform(loc, 1)]
     super(HalfCauchy, self).__init__(base_dist, transforms)
Example 58
 def __init__(self, loc, scale, event_dim=0, cache_size=0):
     super(AffineTransform, self).__init__(cache_size=cache_size)
     self.loc, self.scale = broadcast_all(loc, scale)
     self.event_dim = event_dim
Example 59
 def __init__(self, scale, alpha):
     self.scale, self.alpha = broadcast_all(scale, alpha)
     base_dist = Exponential(self.alpha)
     transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
     super(Pareto, self).__init__(base_dist, transforms)
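A minimal check (my sketch) of the construction above: a transformed Exponential matches torch.distributions.Pareto on log_prob, since exp of an Exponential(alpha) sample is Pareto(1, alpha) and the affine transform rescales it.

 import torch
 from torch.distributions import Exponential, Pareto, TransformedDistribution
 from torch.distributions.transforms import AffineTransform, ExpTransform

 scale, alpha = torch.tensor(2.0), torch.tensor(3.0)
 td = TransformedDistribution(Exponential(alpha),
                              [ExpTransform(), AffineTransform(loc=0., scale=scale)])
 x = torch.tensor([2.5, 4.0])
 print(torch.allclose(td.log_prob(x), Pareto(scale, alpha).log_prob(x)))  # True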