Beispiel #1
0
def diffusion_forward(*, x, logsnr):
    """Return the parameters of the forward marginal q(z_t | x).

    The mean is ``x`` scaled by ``sqrt(sigmoid(logsnr))`` and the variance is
    ``sigmoid(-logsnr)``, where ``logsnr`` is the log signal-to-noise ratio.

    Args:
      x: value being diffused.
      logsnr: log signal-to-noise ratio; combined with ``x`` by ordinary
        broadcasting.

    Returns:
      A dict with the keys 'mean', 'std', 'var' and 'logvar'.
    """
    signal_scale = jnp.sqrt(nn.sigmoid(logsnr))
    noise_var = nn.sigmoid(-logsnr)
    stats = dict(
        mean=x * signal_scale,
        std=jnp.sqrt(noise_var),
        var=noise_var,
        # Taken directly in log space rather than as log(noise_var).
        logvar=nn.log_sigmoid(-logsnr),
    )
    return stats
Beispiel #2
0
def diffusion_reverse(*, x, z_t, logsnr_s, logsnr_t, x_logvar):
    """Return the parameters of q(z_s | z_t, x).

    Requires logsnr_s > logsnr_t (i.e. s < t).

    Args:
      x: value the posterior is conditioned on.
      z_t: latent at the noisier step t.
      logsnr_s: log signal-to-noise ratio at step s.
      logsnr_t: log signal-to-noise ratio at step t.
      x_logvar: either one of the strings 'small', 'large' or
        'medium:<frac>' (with 0 <= frac <= 1), or an array with the same
        shape as ``x`` holding a log-variance for ``x``.

    Returns:
      A dict with the keys 'mean', 'std', 'var' and 'logvar'.

    Raises:
      NotImplementedError: if ``x_logvar`` is a string that is not one of
        the recognised options.
    """
    # SNR(t)/SNR(s), 1-SNR(t)/SNR(s) and log(1-SNR(t)/SNR(s)).
    snr_ratio = jnp.exp(logsnr_t - logsnr_s)
    ratio_complement = -jnp.expm1(logsnr_t - logsnr_s)
    log_ratio_complement = utils.log1mexp(logsnr_s - logsnr_t)

    # Weights applied to z_t and x respectively when forming the mean.
    z_weight = snr_ratio * jnp.sqrt(
        (1. + jnp.exp(-logsnr_t)) / (1. + jnp.exp(-logsnr_s)))
    x_weight = ratio_complement * jnp.sqrt(nn.sigmoid(logsnr_s))
    mean = z_weight * z_t + x_weight * x

    if not isinstance(x_logvar, str):
        assert isinstance(x_logvar, (jnp.ndarray, onp.ndarray))
        assert x_logvar.shape == x.shape
        # Begin from the "small" variance ...
        base_var = ratio_complement * nn.sigmoid(-logsnr_s)
        base_logvar = log_ratio_complement + nn.log_sigmoid(-logsnr_s)
        # ... and add the variance of x, weighted by
        # (ratio_complement * alpha_s)**2 with alpha_s**2 = sigmoid(logsnr_s).
        added_var = jnp.square(ratio_complement) * nn.sigmoid(
            logsnr_s) * jnp.exp(x_logvar)
        added_logvar = (2. * log_ratio_complement +
                        nn.log_sigmoid(logsnr_s) + x_logvar)
        var = base_var + added_var
        logvar = jnp.logaddexp(base_logvar, added_logvar)
    elif x_logvar == 'small':
        # Equivalent to an x_logvar of -infinity.
        var = ratio_complement * nn.sigmoid(-logsnr_s)
        logvar = log_ratio_complement + nn.log_sigmoid(-logsnr_s)
    elif x_logvar == 'large':
        # Equivalent to an x_logvar of nn.log_sigmoid(-logsnr_t).
        var = ratio_complement * nn.sigmoid(-logsnr_t)
        logvar = log_ratio_complement + nn.log_sigmoid(-logsnr_t)
    elif x_logvar.startswith('medium:'):
        _, frac_text = x_logvar.split(':')
        frac = float(frac_text)
        logging.info('logvar frac=%f', frac)
        assert 0 <= frac <= 1
        # Interpolate in log space between the "small" and "large" choices.
        lower_logvar = log_ratio_complement + nn.log_sigmoid(-logsnr_s)
        upper_logvar = log_ratio_complement + nn.log_sigmoid(-logsnr_t)
        logvar = frac * upper_logvar + (1 - frac) * lower_logvar
        var = jnp.exp(logvar)
    else:
        raise NotImplementedError(x_logvar)

    return {'mean': mean, 'std': jnp.sqrt(var), 'var': var, 'logvar': logvar}
Beispiel #3
0
def binary_cross_entropy_with_logits(logits, labels):
    """Return the binary cross-entropy of ``labels`` given ``logits``, summed.

    Evaluates
    ``-sum(labels * log(sigmoid(logits))
           + (1 - labels) * log(1 - sigmoid(logits)))``
    over every element.

    Args:
      logits: pre-sigmoid scores.
      labels: targets combined element-wise with ``logits`` (presumably
        values in [0, 1]; not checked here).

    Returns:
      The scalar sum of the per-element losses.
    """
    log_p = nn.log_sigmoid(logits)
    # log(1 - sigmoid(z)) == log_sigmoid(-z). Deriving it from log_p via
    # log(-expm1(log_p)) yields log(0) = -inf once log_p underflows to zero
    # for large positive logits, which in turn makes 0 * -inf = NaN whenever
    # the label is 1.
    log_not_p = nn.log_sigmoid(-logits)
    return -jnp.sum(labels * log_p + (1. - labels) * log_not_p)
Beispiel #4
0
def loglikelihood_fn(params, Phi, y, predict_fn):
    """Return the summed Bernoulli log-likelihood of ``y`` under a logit model.

    Args:
      params: model parameters, forwarded unchanged to ``predict_fn``.
      Phi: model inputs, forwarded unchanged to ``predict_fn``.
      y: targets combined element-wise with the predicted logits
        (presumably values in {0, 1}; not checked here).
      predict_fn: callable ``predict_fn(params, Phi)`` returning logits.

    Returns:
      ``sum(y * log(sigmoid(a)) + (1 - y) * log(1 - sigmoid(a)))`` where
      ``a = predict_fn(params, Phi)``.
    """
    an = predict_fn(params, Phi)
    log_an = nn.log_sigmoid(an)
    # log(1 - sigmoid(a)) == log_sigmoid(-a). The literal form rounds
    # sigmoid(a) to exactly 1 for moderately large logits, giving
    # log(0) = -inf and then 0 * -inf = NaN for targets equal to 1.
    log_one_minus_an = nn.log_sigmoid(-an)
    log_likelihood_term = y * log_an + (1 - y) * log_one_minus_an
    return log_likelihood_term.sum()