Exemplo n.º 1
0
def kl_normal1_normal1(mean1, var1, mean2, var2, eps=0.0):
    """
    Compute closed-form solution to the KL-divergence between two Gaussians parameterized
    with diagonal variance.

    With q = N(mean1, var1) and p = N(mean2, var2), this returns, element-wise,

        KL(q || p) = 0.5 * log(var2 / var1) + (var1 + (mean1 - mean2)**2) / (2 * var2) - 0.5

    Parameters
    ----------
    mean1 : torch tensor
        Mean of the q Gaussian.
    var1 : torch tensor
        Variance of the q Gaussian. Not modified by this function.
    mean2 : torch tensor
        Mean of the p Gaussian.
    var2 : torch tensor
        Variance of the p Gaussian. Not modified by this function.
    eps : float
        Small number added to variances to avoid NaNs.

    Returns
    -------
    torch tensor
        Element-wise KL-divergence, this has to be summed when the Gaussian distributions are multi-variate.

    See also
    --------
    kl_normal2_normal2 : using log variance parameterization
    """
    # Out-of-place addition on purpose: `var += eps` on a tensor writes into the
    # caller's tensor, and autograd rejects in-place ops on leaves that require grad.
    var1 = var1 + eps
    var2 = var2 + eps
    return 0.5 * T.log(var2 / var1) + (var1 + (mean1 - mean2) ** 2) / (2 * var2) - 0.5
Exemplo n.º 2
0
def log_normal(x, mean, std, eps=0.0):
    r"""
    Compute log pdf of a Gaussian distribution with diagonal covariance, at values x.
    Variance is parameterized as standard deviation.

        .. math:: \log p(x) = \log \mathcal{N}(x; \mu, \sigma^2I)

    Parameters
    ----------
    x : torch tensor
        Values at which to evaluate pdf.
    mean : torch tensor
        Mean of the Gaussian distribution.
    std : torch tensor
        Standard deviation of the diagonal covariance Gaussian. Not modified by
        this function.
    eps : float
        Small number added to standard deviation to avoid NaNs.

    Returns
    -------
    torch tensor
        Element-wise log probability, this has to be summed for multi-variate distributions.

    See also
    --------
    log_normal1 : using variance parameterization
    log_normal2 : using log variance parameterization
    """
    # Out-of-place ops on purpose: `std += eps` and `T.abs_` (trailing underscore =
    # in-place) would overwrite the caller's tensor and break autograd on leaves.
    std = std + eps
    # `c` is a module-level constant defined outside this block; presumably the
    # Gaussian normalizer -0.5 * log(2 * pi) — confirm against its definition.
    return c - T.log(T.abs(std)) - (x - mean) ** 2 / (2 * std ** 2)
Exemplo n.º 3
0
def kl_normal1_stdnormal(mean, var, eps=0.0):
    r"""
    Closed-form solution of the KL-divergence between a Gaussian parameterized
    with diagonal variance and a standard Gaussian.

    .. math::
        D_{KL}[\mathcal{N}(\mu, \sigma^2 I) || \mathcal{N}(0, I)]

    Parameters
    ----------
    mean : torch tensor
        Mean of the diagonal covariance Gaussian.
    var : torch tensor
        Variance of the diagonal covariance Gaussian. Not modified by this function.
    eps : float
        Small number added to variance to avoid NaNs.

    Returns
    -------
    torch tensor
        Element-wise KL-divergence, this has to be summed when the Gaussian distributions are multi-variate.

    See also
    --------
    kl_normal2_stdnormal : using log variance parameterization
    """
    # Out-of-place addition on purpose: `var += eps` would write into the caller's
    # tensor and fails under autograd for leaves that require grad.
    var = var + eps
    return -0.5 * (1 + T.log(var) - mean ** 2 - var)