Ejemplo n.º 1
0
def truncated_normal(location, scale, lower=-onp.inf, upper=onp.inf):
    """Construct truncated normal distribution with support on real interval.

    The (unnormalized) density is ``exp(-((x - location) / scale)**2 / 2)``
    restricted to the interval ``[lower, upper]``.

    Args:
        location (float): Location parameter (mean of untruncated normal
            distribution).
        scale (float): Scale parameter (std. of untruncated normal
            distribution).
        lower (float): Lower-bound of support.
        upper (float): Upper-bound of support.

    Returns:
        Distribution: Truncated normal distribution object.
    """
    # Bounds of the support in standardized (zero-location, unit-scale) units.
    std_lower = (lower - location) / scale
    std_upper = (upper - location) / scale
    # If the whole interval lies above zero then ndtr(std_lower) and
    # ndtr(std_upper) are both close to 1. Their difference (and the inverse
    # CDF applied to values between them) loses all precision: for example
    # ndtr(10.) rounds to 1. and ndtri(1.) is inf.
    # Using the symmetry Phi(u) - Phi(l) = Phi(-l) - Phi(-u), we instead work
    # with the reflected interval [-std_upper, -std_lower]. That interval lies
    # in the lower tail, where ndtr keeps relative precision.
    reflect = std_lower > 0
    tail_lower = np.where(reflect, -std_upper, std_lower)
    tail_upper = np.where(reflect, -std_lower, std_upper)

    def neg_log_dens(x):
        return ((x - location) / scale)**2 / 2

    # log(sqrt(2 pi) * scale * (Phi(std_upper) - Phi(std_lower))), with the
    # probability mass evaluated on the (possibly reflected) tail interval.
    log_normalizing_constant = (
        np.log(2 * np.pi) / 2 + np.log(scale) +
        log_diff_exp(log_ndtr(tail_upper), log_ndtr(tail_lower)))

    def sample(rng, shape=()):
        # Inverse-CDF sampling on the well-conditioned tail interval.
        a = ndtr(tail_lower)
        b = ndtr(tail_upper)
        z = ndtri(a + rng.uniform(size=shape) * (b - a))
        # Undo the reflection (if applied) before mapping to location/scale.
        return np.where(reflect, -z, z) * scale + location

    support = RealInterval(lower, upper)

    from_standard_normal_transform = transforms.standard_normal_to_truncated_normal(
        location, scale, lower, upper)

    return Distribution(
        neg_log_dens=neg_log_dens,
        log_normalizing_constant=log_normalizing_constant,
        sample=sample,
        support=support,
        from_standard_normal_transform=from_standard_normal_transform,
    )
Ejemplo n.º 2
0
 def log_prob(self, value):
     """Return the log-density of ``value`` under the truncated normal.

     The untruncated normal log-density is renormalized by the log of the
     probability mass kept by the truncation. With only a lower bound,
     log(cdf(high) - cdf(low)) = log(1 - cdf(low)) = log(cdf(-low)).
     ``self.base_loc`` is presumably the standardized ``-low``, i.e.
     ``(loc - low) / scale``. NOTE(review): confirm against the class.
     """
     untruncated = self._normal.log_prob(value)
     log_kept_mass = log_ndtr(self.base_loc)
     return untruncated - log_kept_mass
Ejemplo n.º 3
0
 def _logsf(self, x):
     return log_ndtr(-x)
Ejemplo n.º 4
0
 def _logcdf(self, x):
     return log_ndtr(x)
Ejemplo n.º 5
0
def logcdf(x, loc=0, scale=1):
    """Log CDF of a normal distribution with the given ``loc`` and ``scale``.

    Arguments are first promoted to a common inexact dtype. The value is then
    standardized as ``(x - loc) / scale`` and passed to ``log_ndtr``.
    """
    x, loc, scale = _promote_args_inexact("norm.logcdf", x, loc, scale)
    centered = lax.sub(x, loc)
    standardized = lax.div(centered, scale)
    return special.log_ndtr(standardized)
Ejemplo n.º 6
0
 def log_prob(self, value):
     """Return the log-density of ``value`` under the lower-truncated normal.

     When ``self._validate_args`` is truthy, ``value`` is first passed to
     ``self._validate_sample``. The normal log-density is then renormalized
     by the mass above the truncation point:
     log(cdf(high) - cdf(low)) = log(1 - cdf(low)) = log(cdf(-low)),
     where ``low`` is the standardized truncation point.
     """
     if self._validate_args:
         self._validate_sample(value)
     z_low = (self.low - self.loc) / self.scale
     base_log_density = self._normal.log_prob(value)
     return base_log_density - log_ndtr(-z_low)
Ejemplo n.º 7
0
 def log_prob(self, value):
     """Return the log-density of ``value`` under the truncated normal.

     When ``self._validate_args`` is truthy, ``value`` is first passed to
     ``self._validate_sample``. The result is the untruncated normal
     log-density minus ``log_ndtr(self.base_loc)``. Per the identity
     log(cdf(high) - cdf(low)) = log(1 - cdf(low)) = log(cdf(-low)),
     ``self.base_loc`` presumably holds the standardized ``-low``.
     NOTE(review): confirm against the class.
     """
     if self._validate_args:
         self._validate_sample(value)
     untruncated = self._normal.log_prob(value)
     log_kept_mass = log_ndtr(self.base_loc)
     return untruncated - log_kept_mass