Example #1
    def log_prob(self, x: Tensor) -> Tensor:
        F = self.F
        alpha, beta = self.alpha, self.beta

        def gamma_log_prob(x, alpha, beta):
            return (
                alpha * F.log(beta)
                - F.gammaln(alpha)
                + (alpha - 1) * F.log(x)
                - beta * x
            )

        """
        gamma_log_prob(x) above returns NaN for x <= 0. Whenever either of the two
        branch tensors passed to F.where() contains a NaN at some entry, F.where()
        returns NaN at that entry as well, because it is implemented as an indicator
        multiplication: 1 * f(x) + NaN * 0 = NaN, since NaN * 0 evaluates to NaN.
        Replacing gamma_log_prob(x) with gamma_log_prob(abs(x)) therefore avoids NaN
        results for x <= 0 without changing the value for x > 0.
        This is a known issue in PyTorch as well: https://github.com/pytorch/pytorch/issues/12986.
        """
        # mask zeros to prevent NaN gradients for x==0
        x_masked = F.where(x == 0, x.ones_like() * 0.5, x)

        return F.where(
            x > 0,
            gamma_log_prob(F.abs(x_masked), alpha, beta),
            -(10.0 ** 15) * F.ones_like(x),
        )
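
The same pathology is easy to reproduce on the gradient side. Below is a minimal, self-contained PyTorch sketch (PyTorch is used only because the linked issue is filed against torch.where; the input values are made up, and a plain log stands in for gamma_log_prob) showing the NaN gradient at x == 0 and how masking the input before the log avoids it:

import torch

# Naive: the log branch is still evaluated at x == 0, and its backward pass
# (grad / x = 0 / 0) produces NaN even though where() never selects it.
x = torch.tensor([0.0, 2.0], requires_grad=True)
naive = torch.where(x > 0, torch.log(x), torch.full_like(x, -1e15))
naive.sum().backward()
print(x.grad)        # tensor([nan, 0.5000])

# Masked: feed log() only safe inputs, then select with where().
x2 = torch.tensor([0.0, 2.0], requires_grad=True)
x_masked = torch.where(x2 > 0, x2, torch.ones_like(x2) * 0.5)
fixed = torch.where(x2 > 0, torch.log(x_masked), torch.full_like(x2, -1e15))
fixed.sum().backward()
print(x2.grad)       # tensor([0.0000, 0.5000])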
Example #2
 def __init__(
     self,
     alpha: Tensor,
     beta: Tensor,
     zero_probability: Tensor,
     one_probability: Tensor,
 ) -> None:
     F = getF(alpha)
     self.alpha = alpha
     self.beta = beta
     self.zero_probability = zero_probability
     self.one_probability = one_probability
     self.beta_probability = 1 - zero_probability - one_probability
     self.beta_distribution = Beta(alpha=alpha, beta=beta)
     mixture_probs = F.stack(zero_probability,
                             one_probability,
                             self.beta_probability,
                             axis=-1)
     super().__init__(
         components=[
             Deterministic(alpha.zeros_like()),
             Deterministic(alpha.ones_like()),
             self.beta_distribution,
         ],
         mixture_probs=mixture_probs,
     )
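
For context, the constructor above assembles a three-component mixture: point masses at 0 and 1 plus a continuous Beta(alpha, beta) component, weighted by zero_probability, one_probability and beta_probability. A NumPy/SciPy sketch of sampling from that mixture (the parameter values below are made up, and this helper is not part of the library code):

import numpy as np
from scipy import stats

# Hypothetical parameters for illustration only.
alpha, beta = 2.0, 5.0
zero_probability, one_probability = 0.2, 0.1
beta_probability = 1.0 - zero_probability - one_probability

rng = np.random.default_rng(0)
n = 10_000
# First pick a component per sample, then draw from it.
component = rng.choice(3, size=n, p=[zero_probability, one_probability, beta_probability])
samples = np.where(
    component == 0, 0.0,
    np.where(component == 1, 1.0, stats.beta(alpha, beta).rvs(size=n, random_state=rng)),
)
print((samples == 0).mean(), (samples == 1).mean())  # roughly 0.2 and 0.1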
Example #3
    def log_prob(self, x: Tensor) -> Tensor:
        F = self.F

        # masking data NaNs with ones to prevent NaN gradients
        x_non_nan = F.where(x != x, F.ones_like(x), x)

        # calculate likelihood for values which are not NaN
        non_nan_dist_log_likelihood = F.where(
            x != x,
            -x.ones_like() / 0.0,
            self.components[0].log_prob(x_non_nan),
        )

        log_mix_weights = F.log(self.mixture_probs)

        # stack log probabilities of components
        component_log_likelihood = F.stack(
            *[non_nan_dist_log_likelihood, self.components[1].log_prob(x)],
            axis=-1,
        )
        # compute mixture log probability by log-sum-exp
        summands = log_mix_weights + component_log_likelihood
        max_val = F.max_axis(summands, axis=-1, keepdims=True)

        sum_exp = F.sum(F.exp(F.broadcast_minus(summands, max_val)),
                        axis=-1,
                        keepdims=True)

        log_sum_exp = F.log(sum_exp) + max_val
        return log_sum_exp.squeeze(axis=-1)
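
The final block is a hand-rolled, numerically stable log-sum-exp over components: log p(x) = logsumexp_k(log w_k + log p_k(x)), with the per-row maximum subtracted before exponentiating to avoid overflow. A small NumPy sketch (made-up numbers) checking the same arithmetic against scipy.special.logsumexp:

import numpy as np
from scipy.special import logsumexp

log_mix_weights = np.log(np.array([0.3, 0.7]))
component_log_likelihood = np.array([-2.5, -0.8])  # pretend per-component log_prob values

summands = log_mix_weights + component_log_likelihood
max_val = summands.max(axis=-1, keepdims=True)
manual = np.log(np.exp(summands - max_val).sum(axis=-1, keepdims=True)) + max_val

assert np.allclose(manual.squeeze(-1), logsumexp(summands, axis=-1))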
Example #4
 def s(mu: Tensor, b: Tensor) -> Tensor:
     ones = mu.ones_like()
     x = F.random.uniform(-0.5 * ones, 0.5 * ones, dtype=dtype)
     # inverse-transform sampling: X = mu - b * sign(U) * log(1 - 2|U|);
     # the clip guards against log(0) when |x| reaches 0.5
     laplace_samples = mu - b * F.sign(x) * F.log(
         (1.0 - 2.0 * F.abs(x)).clip(1.0e-30, 1.0e30)
     )
     return laplace_samples
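
This helper is inverse-transform sampling for the Laplace distribution: with U ~ Uniform(-1/2, 1/2), the quantity mu - b * sign(U) * log(1 - 2|U|) is distributed as Laplace(mu, b). A NumPy sketch of the same formula (parameters made up), checking the location and the mean absolute deviation, which equals b for a Laplace distribution:

import numpy as np

rng = np.random.default_rng(0)
mu, b = 1.0, 2.0
u = rng.uniform(-0.5, 0.5, size=100_000)
samples = mu - b * np.sign(u) * np.log(1.0 - 2.0 * np.abs(u))

print(samples.mean())                 # close to mu = 1.0
print(np.abs(samples - mu).mean())    # close to b = 2.0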
Example #5
    def quantile(self, level: Tensor) -> Tensor:
        F = self.F
        # we consider level to be an independent axis and so expand it
        # to shape (num_levels, 1, 1, ...)

        for _ in range(self.all_dim):
            level = level.expand_dims(axis=-1)

        quantiles = F.broadcast_mul(self.value, level.ones_like())
        level = F.broadcast_mul(quantiles.ones_like(), level)

        minus_inf = -quantiles.ones_like() / 0.0
        quantiles = F.where(
            F.broadcast_logical_or(level != 0, F.contrib.isnan(quantiles)),
            quantiles,
            minus_inf,
        )

        nans = level.zeros_like() / 0.0
        quantiles = F.where(level != level, nans, quantiles)

        return quantiles
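
Read off the branches, this quantile function of a deterministic distribution maps every level in (0, 1] to the stored value, level 0 to -inf, and NaN levels to NaN. A NumPy restatement of that behaviour (the value and levels below are made up, and the edge case of a NaN-valued `value` is ignored):

import numpy as np

v = 3.0
level = np.array([0.0, 0.1, 0.5, 1.0, np.nan])

quantiles = np.full_like(level, v)
quantiles = np.where(level == 0, -np.inf, quantiles)     # level 0 -> -inf
quantiles = np.where(np.isnan(level), np.nan, quantiles)  # NaN level -> NaN
print(quantiles)  # [-inf  3.  3.  3.  nan]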
Example #6
 def s(mu: Tensor, sigma: Tensor) -> Tensor:
     raw_samples = self.F.sample_normal(mu=mu.zeros_like(),
                                        sigma=sigma.ones_like(),
                                        dtype=dtype)
     return sigma * raw_samples + mu
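
The sampler draws standard normal noise and applies the affine map sigma * eps + mu, presumably so that mu and sigma stay on the differentiable path (the usual location-scale / reparameterization construction); example #7 below does the same for the uniform case with low + U * (high - low). A plain NumPy sketch with made-up parameters:

import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 2.0, 0.5
eps = rng.standard_normal(100_000)   # eps ~ N(0, 1)
samples = sigma * eps + mu

print(samples.mean(), samples.std())  # close to mu = 2.0 and sigma = 0.5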
Example #7
 def s(low: Tensor, high: Tensor) -> Tensor:
     raw_samples = self.F.sample_uniform(low=low.zeros_like(),
                                         high=high.ones_like(),
                                         dtype=dtype)
     return low + raw_samples * (high - low)