コード例 #1
0
ファイル: bernoulli.py プロジェクト: lxlhh/pytorch
 def __init__(self, probs=None, logits=None):
     """Build a Bernoulli distribution from exactly one of ``probs`` or ``logits``.

     Args:
         probs: success probability, a Number or tensor.
         logits: log-odds of success, a Number or tensor.

     Raises:
         ValueError: if both or neither of ``probs`` / ``logits`` are given.
     """
     have_probs = probs is not None
     if have_probs == (logits is not None):
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")
     if have_probs:
         self.probs, = broadcast_all(probs)
     else:
         self.logits, = broadcast_all(logits)
     given = probs if have_probs else logits
     # A bare Python number describes a single, unbatched distribution.
     batch_shape = torch.Size() if isinstance(given, Number) else given.size()
     super(Bernoulli, self).__init__(batch_shape)
コード例 #2
0
ファイル: geometric.py プロジェクト: RichieMay/pytorch
 def __init__(self, probs=None, logits=None, validate_args=None):
     """Build a Geometric distribution from exactly one of ``probs`` or ``logits``.

     Args:
         probs: per-trial success probability; every element must be > 0.
         logits: log-odds of success.
         validate_args: forwarded to the parent constructor.

     Raises:
         ValueError: if both or neither parameter is given, or if any element
             of ``probs`` is not greater than 0.
     """
     if (probs is None) == (logits is None):
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")
     if logits is not None:
         self.logits, = broadcast_all(logits)
         raw = logits
     else:
         self.probs, = broadcast_all(probs)
         if not self.probs.gt(0).all():
             raise ValueError('All elements of probs must be greater than 0')
         raw = probs
     batch_shape = torch.Size() if isinstance(raw, Number) else raw.size()
     super(Geometric, self).__init__(batch_shape, validate_args=validate_args)
コード例 #3
0
ファイル: bernoulli.py プロジェクト: gtgalone/pytorch
 def __init__(self, probs=None, logits=None, validate_args=None):
     """Build a Bernoulli distribution from exactly one of ``probs`` or ``logits``.

     The broadcast parameter that was supplied is also kept on ``self._param``.

     Args:
         probs: success probability, a Number or tensor.
         logits: log-odds of success, a Number or tensor.
         validate_args: forwarded to the parent constructor.

     Raises:
         ValueError: if both or neither of ``probs`` / ``logits`` are given.
     """
     if (probs is None) == (logits is None):
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")
     raw = probs if probs is not None else logits
     if probs is not None:
         self.probs, = broadcast_all(probs)
         self._param = self.probs
     else:
         self.logits, = broadcast_all(logits)
         self._param = self.logits
     batch_shape = torch.Size() if isinstance(raw, Number) else self._param.size()
     super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)
コード例 #4
0
ファイル: uniform.py プロジェクト: bhuWenDongchao/pytorch
 def __init__(self, low, high):
     """Uniform distribution bounded by ``low`` and ``high``.

     Both bounds are broadcast together; if both are plain numbers the
     distribution is unbatched, otherwise the broadcast shape is the batch shape.
     """
     self.low, self.high = broadcast_all(low, high)
     scalar_bounds = all(isinstance(b, Number) for b in (low, high))
     batch_shape = torch.Size() if scalar_bounds else self.low.size()
     super(Uniform, self).__init__(batch_shape)
コード例 #5
0
ファイル: gamma.py プロジェクト: MaheshBhosale/pytorch
 def __init__(self, concentration, rate):
     """Gamma distribution with ``concentration`` and ``rate`` parameters.

     The two parameters are broadcast together; the distribution is unbatched
     only when both are plain numbers.
     """
     self.concentration, self.rate = broadcast_all(concentration, rate)
     if not (isinstance(concentration, Number) and isinstance(rate, Number)):
         batch_shape = self.concentration.size()
     else:
         batch_shape = torch.Size()
     super(Gamma, self).__init__(batch_shape)
コード例 #6
0
ファイル: laplace.py プロジェクト: MaheshBhosale/pytorch
 def __init__(self, loc, scale):
     """Laplace distribution with location ``loc`` and scale ``scale``.

     ``loc`` and ``scale`` are broadcast together; two plain numbers give an
     unbatched distribution.
     """
     self.loc, self.scale = broadcast_all(loc, scale)
     unbatched = all(isinstance(p, Number) for p in (loc, scale))
     super(Laplace, self).__init__(torch.Size() if unbatched else self.loc.size())
コード例 #7
0
ファイル: normal.py プロジェクト: RichieMay/pytorch
 def __init__(self, loc, scale, validate_args=None):
     """Normal distribution with mean ``loc`` and standard deviation ``scale``.

     Args:
         loc, scale: Numbers or tensors, broadcast together.
         validate_args: forwarded to the parent constructor.
     """
     self.loc, self.scale = broadcast_all(loc, scale)
     both_numbers = isinstance(loc, Number) and isinstance(scale, Number)
     batch_shape = torch.Size() if both_numbers else self.loc.size()
     super(Normal, self).__init__(batch_shape, validate_args=validate_args)
コード例 #8
0
ファイル: poisson.py プロジェクト: RichieMay/pytorch
 def __init__(self, rate, validate_args=None):
     """Poisson distribution with the given ``rate``.

     Args:
         rate: Number or tensor of rates; a Number gives an unbatched distribution.
         validate_args: forwarded to the parent constructor.
     """
     self.rate, = broadcast_all(rate)
     if not isinstance(rate, Number):
         batch_shape = self.rate.size()
     else:
         batch_shape = torch.Size()
     super(Poisson, self).__init__(batch_shape, validate_args=validate_args)
コード例 #9
0
ファイル: gamma.py プロジェクト: lxlhh/pytorch
 def __init__(self, alpha, beta):
     """Gamma distribution parametrised by ``alpha`` and ``beta``.

     The parameters are broadcast together; only two plain numbers give an
     unbatched distribution.
     """
     self.alpha, self.beta = broadcast_all(alpha, beta)
     scalars = all(isinstance(p, Number) for p in (alpha, beta))
     batch_shape = torch.Size() if scalars else self.alpha.size()
     super(Gamma, self).__init__(batch_shape)
コード例 #10
0
ファイル: pareto.py プロジェクト: bhuWenDongchao/pytorch
 def __init__(self, scale, alpha):
     """Pareto distribution with ``scale`` and shape ``alpha``.

     ``scale`` and ``alpha`` are broadcast together; two plain numbers give an
     unbatched distribution.
     """
     self.scale, self.alpha = broadcast_all(scale, alpha)
     if isinstance(scale, Number) and isinstance(alpha, Number):
         shape = torch.Size()
     else:
         shape = self.scale.size()
     super(Pareto, self).__init__(shape)
コード例 #11
0
 def __init__(self, concentration1, concentration0, validate_args=None):
     """Beta distribution implemented through a two-component Dirichlet.

     The two concentrations are stacked along a new last dimension and passed
     to ``Dirichlet``; the Beta's batch shape is the Dirichlet's batch shape.

     Args:
         concentration1, concentration0: Numbers or tensors.
         validate_args: forwarded to the parent constructor.
     """
     if not (isinstance(concentration1, Number) and isinstance(concentration0, Number)):
         c1, c0 = broadcast_all(concentration1, concentration0)
         stacked = torch.stack([c1, c0], -1)
     else:
         stacked = torch.tensor([float(concentration1), float(concentration0)])
     self._dirichlet = Dirichlet(stacked)
     super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)
コード例 #12
0
 def log_prob(self, value):
     """Zero-inflated Poisson log-probability of ``value``.

     For non-zero values this is ``log(1 - gate) + log Poisson(value; rate)``.
     At zero the structural zero is mixed in:
     ``log(gate + (1 - gate) * Poisson(0; rate))``.

     Args:
         value: observed counts, broadcast against ``gate`` and ``rate``.

     Returns:
         Tensor of log-probabilities with the broadcast shape.
     """
     if self._validate_args:
         self._validate_sample(value)
     gate, rate, value = broadcast_all(self.gate, self.rate, value)
     log_prob = (-gate).log1p() + (rate.log() * value) - rate - (value + 1).lgamma()
     # exp(log_prob) at value == 0 is (1 - gate) * Poisson(0; rate).
     # torch.where replaces the earlier in-place masked write: it is out-of-place
     # (autograd/JIT friendly) and matches the other ZIP implementations.
     log_prob = torch.where(value == 0, (gate + log_prob.exp()).log(), log_prob)
     return log_prob
コード例 #13
0
    def log_prob(self, value):
        """Log-probability of ``value`` under the zero-inflated base distribution.

        Non-zero values get ``log(1 - gate) + base_dist.log_prob(value)``; zeros
        get ``log(gate + exp(that))``, i.e. the gate mass plus the base mass.
        """
        if self._validate_args:
            self._validate_sample(value)

        gate, value = broadcast_all(self.gate, value)
        inflated = (-gate).log1p() + self.base_dist.log_prob(value)
        at_zero = (gate + inflated.exp()).log()
        return torch.where(value == 0, at_zero, inflated)
コード例 #14
0
ファイル: distributions.py プロジェクト: chipper1/torchbearer
 def __init__(self, l, k):
     """Weibull-style distribution parametrised by ``l`` and ``k``.

     NOTE(review): ``l``/``k`` are presumably scale and shape — confirm.
     Both are broadcast together; two plain numbers give an unbatched distribution.
     """
     self.l, self.k = broadcast_all(l, k)
     # Small constant; presumably a numerical-stability epsilon used elsewhere — confirm.
     self.const = 1e-8
     unbatched = all(isinstance(p, Number) for p in (k, l))
     super(SimpleWeibull, self).__init__(
         batch_shape=torch.Size() if unbatched else self.k.size())
コード例 #15
0
 def __init__(self, rate, censoring, validate_args=None):
     """Censored Poisson distribution.

     Args:
         rate: Poisson rate, a Number or tensor.
         censoring: censoring indicator, broadcast together with ``rate``.
         validate_args: forwarded to the parent constructor.
     """
     self.rate, self.censoring = broadcast_all(rate, censoring)
     # Unbatched only when *both* inputs are plain numbers. A scalar rate with a
     # tensor censoring broadcasts to the tensor's shape, so checking `rate`
     # alone produced an empty batch_shape for non-scalar parameters.
     if isinstance(rate, Number) and isinstance(censoring, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.rate.size()
     super(PyroCensoredPoison, self).__init__(batch_shape,
                                              validate_args=validate_args)
コード例 #16
0
 def __init__(self, loc, scale, beta, validate_args=None):
     """Generalised normal distribution.

     Args:
         loc, scale: Numbers or tensors, broadcast together.
         beta: extra parameter (presumably the shape) — stored as given and
             not broadcast with ``loc``/``scale``.
         validate_args: forwarded to the parent constructor.
     """
     self.loc, self.scale = broadcast_all(loc, scale)
     self.beta = beta
     unbatched = all(isinstance(p, Number) for p in (loc, scale))
     batch_shape = torch.Size() if unbatched else self.loc.size()
     super(GeneralisedNormal, self).__init__(batch_shape, validate_args=validate_args)
コード例 #17
0
ファイル: multinomial.py プロジェクト: bhuWenDongchao/pytorch
 def log_prob(self, value):
     """Multinomial log-probability of the count vector ``value``.

     Computes ``log n! - sum_i log x_i! + sum_i x_i * logits_i`` over the last
     dimension, where ``n = sum_i x_i``.

     Args:
         value: counts, broadcast against ``self.logits``.

     Returns:
         Tensor of log-probabilities (last dimension reduced).
     """
     self._validate_log_prob_arg(value)
     logits, value = broadcast_all(self.logits, value)
     # Clone *after* broadcasting: broadcast_all may return an expanded view whose
     # elements alias one another, and the masked write below must not target
     # such a view (it would fail, or overwrite other batch rows).
     logits = logits.clone()
     log_factorial_n = torch.lgamma(value.sum(-1) + 1)
     log_factorial_xs = torch.lgamma(value + 1).sum(-1)
     # 0 * -inf is NaN; a zero count on an impossible category contributes 0.
     logits[(value == 0) & (logits == -float('inf'))] = 0
     log_powers = (logits * value).sum(-1)
     return log_factorial_n - log_factorial_xs + log_powers
コード例 #18
0
ファイル: geometric.py プロジェクト: pelluru/pytorch-1
 def __init__(self, probs=None, logits=None):
     """Build a Geometric distribution from exactly one of ``probs`` or ``logits``.

     Raises:
         ValueError: if both or neither parameter is given, or if any element
             of ``probs`` is not greater than 0.
     """
     probs_given = probs is not None
     if probs_given is (logits is not None):
         raise ValueError(
             "Either `probs` or `logits` must be specified, but not both.")
     if not probs_given:
         self.logits, = broadcast_all(logits)
     else:
         self.probs, = broadcast_all(probs)
         if not self.probs.gt(0).all():
             raise ValueError(
                 'All elements of probs must be greater than 0')
     source = probs if probs_given else logits
     if isinstance(source, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = source.size()
     super(Geometric, self).__init__(batch_shape)
コード例 #19
0
 def __init__(self, alpha, beta):
     """Beta distribution implemented through a two-component Dirichlet.

     ``alpha`` and ``beta`` are stacked along a new last dimension; the Beta's
     batch shape is the resulting Dirichlet's batch shape.
     """
     scalar_pair = isinstance(alpha, Number) and isinstance(beta, Number)
     if scalar_pair:
         alpha_beta = torch.Tensor([alpha, beta])
     else:
         a, b = broadcast_all(alpha, beta)
         alpha_beta = torch.stack([a, b], -1)
     self._dirichlet = Dirichlet(alpha_beta)
     super(Beta, self).__init__(self._dirichlet._batch_shape)
コード例 #20
0
ファイル: beta.py プロジェクト: gtgalone/pytorch
 def __init__(self, concentration1, concentration0, validate_args=None):
     """Beta distribution backed by a two-component Dirichlet.

     Args:
         concentration1, concentration0: Numbers or tensors; stacked along a
             new last dimension to form the Dirichlet concentration.
         validate_args: forwarded to the parent constructor.
     """
     both_numbers = all(isinstance(c, Number) for c in (concentration1, concentration0))
     if both_numbers:
         pair = torch.tensor([float(concentration1), float(concentration0)])
     else:
         pair = torch.stack(list(broadcast_all(concentration1, concentration0)), -1)
     self._dirichlet = Dirichlet(pair)
     super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)
コード例 #21
0
 def __init__(self, mu, alpha, validate_args=None):
     """Negative binomial parametrised by ``mu`` and ``alpha``.

     NOTE(review): ``mu``/``alpha`` are presumably mean and dispersion — confirm.

     Args:
         mu, alpha: Numbers or tensors, broadcast together.
         validate_args: forwarded to the parent constructor.
     """
     self.mu, self.alpha = broadcast_all(mu, alpha)
     # Unbatched only when both parameters are plain numbers; a scalar mu with a
     # tensor alpha broadcasts to alpha's shape and must keep that batch shape.
     if isinstance(mu, Number) and isinstance(alpha, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.mu.size()
     super(PyroNegBinomial, self).__init__(batch_shape,
                                           validate_args=validate_args)
コード例 #22
0
 def __init__(self, probs=None, logits=None, validate_args=None):
     """Build a Bernoulli distribution from exactly one of ``probs`` or ``logits``.

     The supplied, broadcast parameter is also stored on ``self._param``.

     Raises:
         ValueError: if both or neither of ``probs`` / ``logits`` are given.
     """
     if (probs is None) == (logits is None):
         raise ValueError(
             "Either `probs` or `logits` must be specified, but not both.")
     if probs is None:
         scalar_input = isinstance(logits, Number)
         self.logits, = broadcast_all(logits)
         self._param = self.logits
     else:
         scalar_input = isinstance(probs, Number)
         self.probs, = broadcast_all(probs)
         self._param = self.probs
     batch_shape = torch.Size() if scalar_input else self._param.size()
     super(Bernoulli, self).__init__(batch_shape,
                                     validate_args=validate_args)
コード例 #23
0
 def log_prob(self, value):
     """Multinomial log-probability of the count vector ``value``.

     Computes ``log n! - sum_i log x_i! + sum_i x_i * logits_i`` over the last
     dimension, where ``n = sum_i x_i``.

     Args:
         value: counts, broadcast against ``self.logits``.

     Returns:
         Tensor of log-probabilities (last dimension reduced).
     """
     self._validate_log_prob_arg(value)
     logits, value = broadcast_all(self.logits, value)
     # Clone *after* broadcasting so the masked write below hits a private,
     # non-aliased buffer rather than an expanded view of self.logits.
     logits = logits.clone()
     log_factorial_n = torch.lgamma(value.sum(-1) + 1)
     log_factorial_xs = torch.lgamma(value + 1).sum(-1)
     # 0 * -inf is NaN; a zero count on an impossible category contributes 0.
     logits[(value == 0) & (logits == -float('inf'))] = 0
     log_powers = (logits * value).sum(-1)
     return log_factorial_n - log_factorial_xs + log_powers
コード例 #24
0
ファイル: beta.py プロジェクト: Jsmilemsj/pytorch
 def __init__(self, concentration1, concentration0):
     """Beta distribution backed by a two-component Dirichlet.

     Scalar concentrations are packed with ``variable``; tensor inputs are
     broadcast and stacked along a new last dimension.
     """
     if not (isinstance(concentration1, Number) and isinstance(concentration0, Number)):
         c1, c0 = broadcast_all(concentration1, concentration0)
         pair = torch.stack([c1, c0], -1)
     else:
         pair = variable([concentration1, concentration0])
     self._dirichlet = Dirichlet(pair)
     super(Beta, self).__init__(self._dirichlet._batch_shape)
コード例 #25
0
ファイル: beta.py プロジェクト: xianweilv/pytorch
 def __init__(self, concentration1, concentration0):
     """Beta distribution backed by a two-component Dirichlet.

     The two concentrations are combined along a new last dimension; the Beta's
     batch shape is the Dirichlet's batch shape.
     """
     scalars = all(isinstance(c, Number) for c in (concentration1, concentration0))
     if scalars:
         pair = torch.Tensor([concentration1, concentration0])
     else:
         c1, c0 = broadcast_all(concentration1, concentration0)
         pair = torch.stack([c1, c0], -1)
     self._dirichlet = Dirichlet(pair)
     super(Beta, self).__init__(self._dirichlet._batch_shape)
コード例 #26
0
 def __init__(self, concentration, validate_args=None):
     """Dirichlet distribution over the last dimension of ``concentration``.

     All leading dimensions of the input form the batch shape; the last
     dimension is the event shape.
     """
     self.concentration, = broadcast_all(concentration)
     shape = concentration.shape
     batch_shape, event_shape = shape[:-1], shape[-1:]
     super(Dirichlet, self).__init__(batch_shape, event_shape, validate_args=validate_args)
コード例 #27
0
ファイル: GP_likelihoods.py プロジェクト: Krollo12/GHMCGP
 def log_prob(self, value):
     """Censored Poisson log-likelihood of ``value``.

     Uncensored entries use the Poisson log-pmf. Entries whose ``censoring``
     flag equals 1 are replaced by ``log(1 - cdf(value) + 1e-6)``, the log
     survival probability with a small floor.

     Side effect: stores the censoring mask on ``self.threshold``.
     """
     self.threshold = self.censoring == 1
     if self._validate_args:
         self._validate_sample(value)
     rate, value = broadcast_all(self.rate, value)
     log_prob = (rate.log() * value) - rate - (value + 1).lgamma()
     # Evaluate the CDF once. Boolean-indexing a tensor always yields a tensor, so
     # an `isinstance(..., Number)` branch (with a different 0.01 floor) can never
     # run; only the tensor `.log()` path with the 1e-6 floor is needed.
     survival = 1 - self.cdf(value)[self.threshold]
     log_prob[self.threshold] = (survival + 1e-6).log()
     return log_prob
コード例 #28
0
ファイル: weibull.py プロジェクト: thomascong121/NCRF
 def __init__(self, scale, concentration, validate_args=None):
     """Weibull distribution as a transformed unit-rate Exponential.

     Samples from ``Exponential(1)`` are raised to ``1 / concentration`` and then
     multiplied by ``scale``.
     """
     self.scale, self.concentration = broadcast_all(scale, concentration)
     self.concentration_reciprocal = self.concentration.reciprocal()
     base_dist = Exponential(torch.ones_like(self.scale), validate_args=validate_args)
     power = PowerTransform(exponent=self.concentration_reciprocal)
     rescale = AffineTransform(loc=0, scale=self.scale)
     super(Weibull, self).__init__(base_dist, [power, rescale], validate_args=validate_args)
コード例 #29
0
    def log_prob(self, value):
        """Negative-binomial log-probability of ``value``.

        Delegates to torch's ``NegativeBinomial`` with ``total_count = 1 / alpha``
        and ``logits = alpha * mu``; ``value`` is cast to float first.
        """
        if self._validate_args:
            self._validate_sample(value)
        mu, alpha, value = broadcast_all(self.mu, self.alpha, value)
        # NOTE(review): torch's NegativeBinomial has mean total_count * exp(logits),
        # i.e. exp(alpha * mu) / alpha here. If `mu` is meant to be the mean (NB2
        # parametrisation), logits should be (alpha * mu).log() — confirm intent.
        nb = torch.distributions.negative_binomial.NegativeBinomial(
            total_count=1 / alpha, logits=alpha * mu)
        return nb.log_prob(value.float())
コード例 #30
0
 def __init__(self, pi, loc, scale, alpha=0.05, validate_args=None):
     """Sparse distribution over ``loc``/``scale``.

     Args:
         pi: stored as given (not broadcast).
         loc, scale: Numbers or tensors, broadcast together.
         alpha: converted to a tensor on ``loc``'s device.
         validate_args: forwarded to the parent constructor.
     """
     self.loc, self.scale = broadcast_all(loc, scale)
     self.pi = pi
     self.alpha = torch.tensor(alpha).to(self.loc.device)
     # Unbatched only if both loc and scale are plain numbers; a scalar scale with
     # a tensor loc broadcasts to loc's shape and must keep that batch shape.
     if isinstance(loc, Number) and isinstance(scale, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.scale.size()
     # validate_args was accepted but previously never passed on.
     super(Sparse, self).__init__(batch_shape, validate_args=validate_args)
コード例 #31
0
        def __init__(self, loc, validate_args=None):
            """Deterministic (point-mass) distribution at ``loc``.

            A plain number gives an unbatched distribution; a tensor's shape is
            used as the batch shape.
            """
            self.loc, = broadcast_all(loc)

            batch_shape = torch.Size() if isinstance(loc, Number) else self.loc.size()
            super(TorchDeterministic, self).__init__(
                batch_shape, validate_args=validate_args)
コード例 #32
0
 def log_prob(self, value):
     """Zero-inflated Poisson log-probability of ``value``.

     Non-zero values get ``log(1 - gate) + log Poisson(value; rate)``; zeros get
     ``log(gate + exp(that))``.
     """
     if self._validate_args:
         self._validate_sample(value)
     gate, rate, value = broadcast_all(self.gate, self.rate, value)
     lp = (-gate).log1p() + rate.log() * value - rate - (value + 1).lgamma()
     zero_lp = (gate + lp.exp()).log()
     return torch.where(value == 0, zero_lp, lp)
コード例 #33
0
ファイル: test_packed.py プロジェクト: pyro-ppl/pyro
def test_broadcast_all(shapes):
    """Packed broadcast_all must agree with ordinary broadcast_all after unpacking."""
    inputs, dim_to_symbol, symbol_to_dim = make_inputs(shapes)
    packed_outputs = packed.broadcast_all(
        *[packed.pack(x, dim_to_symbol) for x in inputs])
    actual = [packed.unpack(out, symbol_to_dim) for out in packed_outputs]
    expected = broadcast_all(*inputs) if inputs else []
    assert len(actual) == len(expected)
    for got, want in zip(actual, expected):
        assert_equal(got, want)
コード例 #34
0
ファイル: binomial.py プロジェクト: zsk423200/pytorch
    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        """Binomial distribution over ``total_count`` trials.

        Exactly one of ``probs`` / ``logits`` must be given. It is broadcast
        against ``total_count``, which is then cast to that parameter's type.
        The broadcast parameter is also stored on ``self._param``.

        Raises:
            ValueError: if both or neither of ``probs`` and ``logits`` are given.
        """
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:
            self.total_count, self.probs, = broadcast_all(total_count, probs)
            # Cast to the parameter actually supplied; `self.logits` is not
            # assigned in this branch.
            self.total_count = self.total_count.type_as(self.probs)
            self._param = self.probs
        else:
            self.total_count, self.logits, = broadcast_all(total_count, logits)
            self.total_count = self.total_count.type_as(self.logits)
            self._param = self.logits

        # broadcast_all returns tensors, so an isinstance(..., Number) test on the
        # broadcast parameter can never succeed; its own size is the batch shape.
        batch_shape = self._param.size()
        super(Binomial, self).__init__(batch_shape, validate_args=validate_args)
コード例 #35
0
ファイル: GP_likelihoods.py プロジェクト: Krollo12/GHMCGP
 def __init__(self, loc, scale, censoring, validate_args=None):
     """Censored normal distribution.

     Args:
         loc, scale, censoring: Numbers or tensors, all broadcast together.
         validate_args: forwarded to the parent constructor.
     """
     self.loc, self.scale, self.censoring = broadcast_all(loc, scale, censoring)
     all_numbers = all(isinstance(p, Number) for p in (loc, scale, censoring))
     batch_shape = torch.Size() if all_numbers else self.loc.size()
     super(PyroCensoredNormal, self).__init__(batch_shape, validate_args=validate_args)
コード例 #36
0
ファイル: multinomial.py プロジェクト: khabya/DeepStack
 def log_prob(self, value):
     """Multinomial log-probability of the count vector ``value``.

     Computes ``log n! - sum_i log x_i! + sum_i x_i * logits_i`` over the last
     dimension, where ``n = sum_i x_i``.
     """
     if self._validate_args:
         self._validate_sample(value)
     logits, value = broadcast_all(self.logits, value)
     # Clone *after* broadcasting: broadcast_all may return an expanded view whose
     # elements alias one another, and the masked write below cannot target it.
     logits = logits.clone(memory_format=torch.contiguous_format)
     log_factorial_n = torch.lgamma(value.sum(-1) + 1)
     log_factorial_xs = torch.lgamma(value + 1).sum(-1)
     # 0 * -inf is NaN; a zero count on an impossible category contributes 0.
     logits[(value == 0) & (logits == -inf)] = 0
     log_powers = (logits * value).sum(-1)
     return log_factorial_n - log_factorial_xs + log_powers
コード例 #37
0
 def __init__(self, gamma, loc, scale, validate_args=None, alpha=0.05):
     """Sparse distribution over ``loc``/``scale``.

     Args:
         gamma: stored as given (not broadcast).
         loc, scale: Numbers or tensors, broadcast together.
         validate_args: forwarded to the parent constructor.
         alpha: formerly hard-coded as 0.05 (still the default); converted to a
             tensor on ``loc``'s device.
     """
     self.loc, self.scale = broadcast_all(loc, scale)
     self.gamma = gamma
     self.alpha = torch.tensor(alpha).to(self.loc.device)
     # Unbatched only if both loc and scale are plain numbers; a scalar scale with
     # a tensor loc broadcasts to loc's shape and must keep that batch shape.
     if isinstance(loc, Number) and isinstance(scale, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.scale.size()
     # validate_args was accepted but previously never passed on.
     super(Sparse_torch, self).__init__(batch_shape, validate_args=validate_args)
コード例 #38
0
ファイル: zero_inflated.py プロジェクト: pyro-ppl/pyro
    def log_prob(self, value):
        """Zero-inflated log-probability of ``value``.

        Works from ``self.gate`` when it is stored on the instance, otherwise
        from ``self.gate_logits``. Non-zero values get
        ``log(1 - gate) + base_dist.log_prob(value)``; zeros get
        ``log(gate + (1 - gate) * base_prob)``.
        """
        if self._validate_args:
            self._validate_sample(value)

        if "gate" not in self.__dict__:
            gate_logits, value = broadcast_all(self.gate_logits, value)
            # -gate_logits = log(1 - gate) - log(gate), so this is
            # log(1 - gate) + base log-prob - log(gate).
            base_minus_log_gate = -gate_logits + self.base_dist.log_prob(value)
            log_gate = -softplus(-gate_logits)  # log(sigmoid(gate_logits))
            nonzero_lp = base_minus_log_gate + log_gate
            # softplus(x) + log(gate) = log(gate + (1 - gate) * base_prob)
            zero_lp = softplus(base_minus_log_gate) + log_gate
            return torch.where(value == 0, zero_lp, nonzero_lp)

        gate, value = broadcast_all(self.gate, value)
        lp = (-gate).log1p() + self.base_dist.log_prob(value)
        return torch.where(value == 0, (gate + lp.exp()).log(), lp)
コード例 #39
0
ファイル: util.py プロジェクト: yufengwa/pyro
def gather(value, index, dim):
    """
    Broadcast ``value`` and ``index`` together, then gather from ``value`` along
    ``dim`` using only the first slice (position 0) of the broadcast index along
    that dim, so the result has size 1 in ``dim``.
    """
    value, index = broadcast_all(value, index)
    with ignore_jit_warnings():
        first = torch.zeros(1, dtype=torch.long, device=index.device)
    return value.gather(dim, index.index_select(dim, first))
コード例 #40
0
    def log_prob(self, value):
        """Log-density of ``value`` with ``d = logits - temperature * value``.

        Returns ``log(temperature) + d - 2 * log1p(exp(d))``.
        """
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        d = logits - value.mul(self.temperature)
        return self.temperature.log() + d - 2 * d.exp().log1p()
コード例 #41
0
ファイル: core.py プロジェクト: eladsar/spinningup
    def __init__(self, loc, logscale, validate_args=None):
        """Normal-style distribution parametrised by ``loc`` and ``logscale``.

        NOTE(review): ``logscale`` is presumably the log of the scale — confirm.
        Both are broadcast together; two plain numbers give an unbatched
        distribution.
        """
        self.loc, self.logscale = broadcast_all(loc, logscale)

        if all(isinstance(p, Number) for p in (loc, logscale)):
            shape = torch.Size()
        else:
            shape = self.loc.size()
        super(MyNormal, self).__init__(shape, validate_args=validate_args)
コード例 #42
0
ファイル: binomial.py プロジェクト: RichieMay/pytorch
    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        """Binomial distribution over ``total_count`` trials.

        Exactly one of ``probs`` / ``logits`` must be given. It is broadcast
        against ``total_count``, which is then cast to that parameter's type.
        The broadcast parameter is also stored on ``self._param``.

        Raises:
            ValueError: if both or neither of ``probs`` and ``logits`` are given.
        """
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:
            self.total_count, self.probs, = broadcast_all(total_count, probs)
            # Cast to the parameter actually supplied; `self.logits` is not
            # assigned in this branch.
            self.total_count = self.total_count.type_as(self.probs)
            self._param = self.probs
        else:
            self.total_count, self.logits, = broadcast_all(total_count, logits)
            self.total_count = self.total_count.type_as(self.logits)
            self._param = self.logits

        # broadcast_all returns tensors, so an isinstance(..., Number) test on the
        # broadcast parameter can never succeed; its own size is the batch shape.
        batch_shape = self._param.size()
        super(Binomial, self).__init__(batch_shape, validate_args=validate_args)
コード例 #43
0
 def __init__(self, mu, alpha, lam, validate_args=None):
     """Negative-binomial / Poisson convolution approximation.

     Args:
         mu, alpha, lam: Numbers or tensors, all broadcast together.
         validate_args: forwarded to the parent constructor.
     """
     self.mu, self.alpha, self.lam = broadcast_all(mu, alpha, lam)
     scalars = all(isinstance(p, Number) for p in (mu, alpha, lam))
     batch_shape = torch.Size() if scalars else self.mu.size()
     super(TorchNegativeBinomialPoissonConvApprox,
           self).__init__(batch_shape, validate_args=validate_args)
コード例 #44
0
ファイル: binomial.py プロジェクト: Jsmilemsj/pytorch
    def __init__(self, total_count=1, probs=None, logits=None):
        """Binomial distribution with a homogeneous (scalar) ``total_count``.

        Raises:
            NotImplementedError: if ``total_count`` is not a Number.
            ValueError: if both or neither of ``probs`` / ``logits`` are given.
        """
        if not isinstance(total_count, Number):
            raise NotImplementedError('inhomogeneous total_count is not supported')
        self.total_count = total_count
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        use_probs = probs is not None
        raw = probs if use_probs else logits
        if use_probs:
            self.probs, = broadcast_all(probs)
            self._param = self.probs
        else:
            self.logits, = broadcast_all(logits)
            self._param = self.logits

        batch_shape = torch.Size() if isinstance(raw, Number) else self._param.size()
        super(Binomial, self).__init__(batch_shape)
コード例 #45
0
ファイル: fishersnedecor.py プロジェクト: RichieMay/pytorch
    def __init__(self, df1, df2, validate_args=None):
        """Fisher-Snedecor (F) distribution with degrees of freedom ``df1``, ``df2``.

        Two helper Gamma distributions are built as ``Gamma(df / 2, df)`` for
        each degree-of-freedom parameter.
        """
        self.df1, self.df2 = broadcast_all(df1, df2)
        self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
        self._gamma2 = Gamma(self.df2 * 0.5, self.df2)

        both_numbers = isinstance(df1, Number) and isinstance(df2, Number)
        batch_shape = torch.Size() if both_numbers else self.df1.size()
        super(FisherSnedecor, self).__init__(batch_shape, validate_args=validate_args)
コード例 #46
0
 def log_prob(self, value):
     """Log-density of ``value`` with ``K`` categories.

     Uses ``log Gamma(K) + (K - 1) * log(temperature)`` plus the sum over the last
     dim of the normalised score ``s - logsumexp(s)``, where
     ``s = logits - temperature * value``.
     """
     num_events = self._categorical._num_events
     if self._validate_args:
         self._validate_sample(value)
     logits, value = broadcast_all(self.logits, value)
     temp = self.temperature
     log_scale = (temp.new(temp.shape).fill_(num_events).lgamma() -
                  temp.log().mul(-(num_events - 1)))
     raw_score = logits - value.mul(temp)
     normalised = (raw_score - _log_sum_exp(raw_score)).sum(-1)
     return normalised + log_scale
コード例 #47
0
    def __init__(self, df1, df2):
        """Fisher-Snedecor (F) distribution with degrees of freedom ``df1``, ``df2``.

        Helper Gamma distributions are built as ``Gamma(df / 2, df)``.
        """
        self.df1, self.df2 = broadcast_all(df1, df2)
        self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
        self._gamma2 = Gamma(self.df2 * 0.5, self.df2)

        if not all(isinstance(d, Number) for d in (df1, df2)):
            batch_shape = self.df1.size()
        else:
            batch_shape = torch.Size()
        super(FisherSnedecor, self).__init__(batch_shape)
コード例 #48
0
ファイル: uniform.py プロジェクト: gtgalone/pytorch
    def __init__(self, low, high, validate_args=None):
        """Uniform distribution bounded by ``low`` and ``high``.

        Raises:
            ValueError: when argument validation is enabled and some ``low`` is
                not strictly less than the matching ``high``.
        """
        self.low, self.high = broadcast_all(low, high)

        scalar_bounds = isinstance(low, Number) and isinstance(high, Number)
        batch_shape = torch.Size() if scalar_bounds else self.low.size()
        super(Uniform, self).__init__(batch_shape, validate_args=validate_args)

        # Checked after the parent constructor, which receives validate_args.
        if self._validate_args and not (self.low < self.high).all():
            raise ValueError("Uniform is not defined when low>= high")
コード例 #49
0
ファイル: gumbel.py プロジェクト: gtgalone/pytorch
 def __init__(self, loc, scale, validate_args=None):
     """Gumbel distribution as a transformed Uniform.

     With ``u`` drawn from the base distribution, the transform chain gives
     ``loc - scale * log(-log(u))``. The base support starts at the dtype's
     ``tiny`` and ends at ``1 - eps``, so neither logarithm hits 0.
     """
     self.loc, self.scale = broadcast_all(loc, scale)
     finfo = _finfo(self.loc)
     upper = 1 - finfo.eps
     if isinstance(loc, Number) and isinstance(scale, Number):
         base_dist = Uniform(finfo.tiny, upper)
     else:
         base_dist = Uniform(self.loc.new(self.loc.size()).fill_(finfo.tiny), upper)
     transforms = [
         ExpTransform().inv,
         AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
         ExpTransform().inv,
         AffineTransform(loc=loc, scale=-self.scale),
     ]
     super(Gumbel, self).__init__(base_dist, transforms, validate_args=validate_args)
コード例 #50
0
ファイル: exponential.py プロジェクト: RichieMay/pytorch
 def __init__(self, rate, validate_args=None):
     """Exponential distribution with the given ``rate``.

     A plain number gives an unbatched distribution.
     """
     self.rate, = broadcast_all(rate)
     if isinstance(rate, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.rate.size()
     super(Exponential, self).__init__(batch_shape, validate_args=validate_args)
コード例 #51
0
ファイル: dirichlet.py プロジェクト: bhuWenDongchao/pytorch
 def __init__(self, concentration):
     """Dirichlet distribution over the last dimension of ``concentration``.

     Leading dimensions become the batch shape; the last one is the event shape.
     """
     self.concentration, = broadcast_all(concentration)
     shape = concentration.shape
     super(Dirichlet, self).__init__(shape[:-1], shape[-1:])
コード例 #52
0
 def log_prob(self, value):
     """Log-density of ``value`` with ``d = logits - temperature * value``.

     Returns ``log(temperature) + d - 2 * log1p(exp(d))``.
     """
     self._validate_log_prob_arg(value)
     logits, value = broadcast_all(self.logits, value)
     d = logits - value.mul(self.temperature)
     log_temp = self.temperature.log()
     return log_temp + d - 2 * d.exp().log1p()
コード例 #53
0
ファイル: studentT.py プロジェクト: lxlhh/pytorch
 def __init__(self, df, loc=0., scale=1.):
     """Student's t distribution with ``df`` degrees of freedom, ``loc`` and ``scale``.

     All three parameters are broadcast together. A chi-squared helper is built
     from ``df``.
     """
     self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
     self._chi2 = Chi2(df)
     # Unbatched only if every parameter is a plain number. A scalar df with a
     # tensor loc/scale broadcasts to the tensor's shape, so checking df alone
     # produced an empty batch_shape for batched parameters.
     if all(isinstance(p, Number) for p in (df, loc, scale)):
         batch_shape = torch.Size()
     else:
         batch_shape = self.df.size()
     super(StudentT, self).__init__(batch_shape)
コード例 #54
0
ファイル: bernoulli.py プロジェクト: lxlhh/pytorch
 def log_prob(self, value):
     """Bernoulli log-probability: the negated element-wise BCE-with-logits."""
     self._validate_log_prob_arg(value)
     logits, target = broadcast_all(self.logits, value)
     nll = binary_cross_entropy_with_logits(logits, target, reduce=False)
     return -nll
コード例 #55
0
ファイル: dirichlet.py プロジェクト: lxlhh/pytorch
 def __init__(self, alpha):
     """Dirichlet distribution with concentration ``alpha``.

     Leading dimensions of ``alpha`` form the batch shape; the last is the event shape.
     """
     self.alpha, = broadcast_all(alpha)
     dims = alpha.shape
     super(Dirichlet, self).__init__(dims[:-1], dims[-1:])
コード例 #56
0
ファイル: poisson.py プロジェクト: RichieMay/pytorch
 def log_prob(self, value):
     """Poisson log-pmf: ``value * log(rate) - rate - log(value!)``."""
     if self._validate_args:
         self._validate_sample(value)
     rate, value = broadcast_all(self.rate, value)
     log_numerator = rate.log() * value - rate
     return log_numerator - (value + 1).lgamma()
コード例 #57
0
ファイル: half_cauchy.py プロジェクト: lewisKit/pyro
 def __init__(self, loc, scale):
     """Half-Cauchy distribution: ``|X| + loc`` with ``X ~ Cauchy(0, scale)``.

     ``loc`` and ``scale`` are broadcast together but not stored on ``self`` here.
     """
     loc, scale = broadcast_all(loc, scale)
     super(HalfCauchy, self).__init__(
         Cauchy(0, scale), [AbsTransform(), AffineTransform(loc, 1)])
コード例 #58
0
ファイル: transforms.py プロジェクト: MaheshBhosale/pytorch
 def __init__(self, loc, scale, event_dim=0, cache_size=0):
     """Store broadcast ``loc``/``scale`` and ``event_dim`` on the transform.

     ``cache_size`` is passed to the parent constructor.
     """
     super(AffineTransform, self).__init__(cache_size=cache_size)
     broadcast_loc, broadcast_scale = broadcast_all(loc, scale)
     self.loc = broadcast_loc
     self.scale = broadcast_scale
     self.event_dim = event_dim
コード例 #59
0
ファイル: pareto.py プロジェクト: Jsmilemsj/pytorch
 def __init__(self, scale, alpha):
     """Pareto distribution as a transformed ``Exponential(alpha)``.

     A base sample ``X`` is mapped to ``scale * exp(X)`` (ExpTransform followed by
     AffineTransform with loc 0).
     """
     self.scale, self.alpha = broadcast_all(scale, alpha)
     super(Pareto, self).__init__(
         Exponential(self.alpha),
         [ExpTransform(), AffineTransform(loc=0, scale=self.scale)])