from absl import logging

import jax
import jax.numpy as jnp
from jax import nn
import numpy as onp

import utils  # local helper module; assumed to provide log1mexp (see note below)


def diffusion_forward(*, x, logsnr):
  """q(z_t | x)."""
  return {
      'mean': x * jnp.sqrt(nn.sigmoid(logsnr)),
      'std': jnp.sqrt(nn.sigmoid(-logsnr)),
      'var': nn.sigmoid(-logsnr),
      'logvar': nn.log_sigmoid(-logsnr)
  }
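# Note: `utils.log1mexp` is assumed to follow the convention
# log1mexp(a) = log(1 - exp(-a)) for a > 0. If the local module is
# unavailable, a standard stable stand-in (Maechler, "Accurately computing
# log(1 - exp(-|a|))", 2012) would be:
#
#   def log1mexp(a):
#     return jnp.where(a > jnp.log(2.),
#                      jnp.log1p(-jnp.exp(-a)),
#                      jnp.log(-jnp.expm1(-a)))


def sample_q_t(rng, *, x, logsnr):
  """Hedged usage sketch: draw z_t ~ q(z_t | x) via diffusion_forward.

  Not part of the original API. Assumes `rng` is a jax.random key and
  `logsnr` broadcasts against `x` (e.g. shape [batch, 1, 1, 1] for images).
  """
  q_t = diffusion_forward(x=x, logsnr=logsnr)
  eps = jax.random.normal(rng, x.shape)
  # Reparameterized sample: mean + std * standard normal noise.
  return q_t['mean'] + q_t['std'] * eps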
def diffusion_reverse(*, x, z_t, logsnr_s, logsnr_t, x_logvar):
  """q(z_s | z_t, x); requires logsnr_s > logsnr_t (i.e. s < t)."""
  alpha_st = jnp.sqrt((1. + jnp.exp(-logsnr_t)) / (1. + jnp.exp(-logsnr_s)))
  alpha_s = jnp.sqrt(nn.sigmoid(logsnr_s))
  r = jnp.exp(logsnr_t - logsnr_s)  # SNR(t)/SNR(s)
  one_minus_r = -jnp.expm1(logsnr_t - logsnr_s)  # 1 - SNR(t)/SNR(s)
  log_one_minus_r = utils.log1mexp(
      logsnr_s - logsnr_t)  # log(1 - SNR(t)/SNR(s))

  mean = r * alpha_st * z_t + one_minus_r * alpha_s * x

  if isinstance(x_logvar, str):
    if x_logvar == 'small':
      # same as setting x_logvar to -infinity
      var = one_minus_r * nn.sigmoid(-logsnr_s)
      logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
    elif x_logvar == 'large':
      # same as setting x_logvar to nn.log_sigmoid(-logsnr_t)
      var = one_minus_r * nn.sigmoid(-logsnr_t)
      logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_t)
    elif x_logvar.startswith('medium:'):
      _, frac = x_logvar.split(':')
      frac = float(frac)
      logging.info('logvar frac=%f', frac)
      assert 0 <= frac <= 1
      # interpolate between the 'small' and 'large' log-variances
      min_logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
      max_logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_t)
      logvar = frac * max_logvar + (1 - frac) * min_logvar
      var = jnp.exp(logvar)
    else:
      raise NotImplementedError(x_logvar)
  else:
    assert isinstance(x_logvar, (jnp.ndarray, onp.ndarray))
    assert x_logvar.shape == x.shape
    # start with the 'small' variance
    var = one_minus_r * nn.sigmoid(-logsnr_s)
    logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
    # add the variance contributed by uncertainty in x;
    # its weight is (one_minus_r * alpha_s)**2
    var += jnp.square(one_minus_r) * nn.sigmoid(logsnr_s) * jnp.exp(x_logvar)
    logvar = jnp.logaddexp(
        logvar, 2. * log_one_minus_r + nn.log_sigmoid(logsnr_s) + x_logvar)
  return {'mean': mean, 'std': jnp.sqrt(var), 'var': var, 'logvar': logvar}
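def ddpm_reverse_step(rng, *, z_t, x_pred, logsnr_s, logsnr_t):
  """Hedged usage sketch: one ancestral sampling step z_t -> z_s.

  A minimal sketch, not part of the original API. Assumes `x_pred` is a
  model's estimate of x (e.g. an eps- or v-prediction converted to x-space)
  and uses the 'small' fixed-variance choice for illustration.
  """
  q_s = diffusion_reverse(
      x=x_pred, z_t=z_t, logsnr_s=logsnr_s, logsnr_t=logsnr_t,
      x_logvar='small')
  eps = jax.random.normal(rng, z_t.shape)
  # Sample z_s from the Gaussian reverse-process posterior.
  return q_s['mean'] + q_s['std'] * eps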
def binary_cross_entropy_with_logits(logits, labels):
  """Total binary cross-entropy, computed stably from logits."""
  log_p = nn.log_sigmoid(logits)
  # log(1 - sigmoid(x)) = log(-expm1(log_sigmoid(x))); expm1 keeps this
  # accurate when sigmoid(x) is close to 1.
  log_not_p = jnp.log(-jnp.expm1(log_p))
  return -jnp.sum(labels * log_p + (1. - labels) * log_not_p)
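def _check_bce_against_naive():
  """Hedged sanity check (illustrative, not part of the original module).

  Compares the stable BCE above to the naive probability-space formula on
  moderate logits, where both computations are accurate.
  """
  logits = jnp.array([-2., 0., 3.])
  labels = jnp.array([0., 1., 1.])
  p = nn.sigmoid(logits)
  naive = -jnp.sum(labels * jnp.log(p) + (1. - labels) * jnp.log(1. - p))
  stable = binary_cross_entropy_with_logits(logits, labels)
  assert jnp.allclose(naive, stable, atol=1e-5)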
def loglikelihood_fn(params, Phi, y, predict_fn):
  """Bernoulli log-likelihood of labels y under logits predict_fn(params, Phi)."""
  an = predict_fn(params, Phi)
  log_an = nn.log_sigmoid(an)
  # log(1 - sigmoid(a)) = log_sigmoid(-a); this avoids jnp.log(1 - sigmoid(a)),
  # which underflows to log(0) for large positive logits.
  log_likelihood_term = y * log_an + (1 - y) * nn.log_sigmoid(-an)
  return log_likelihood_term.sum()
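def _logistic_regression_step(params, Phi, y, lr=0.1):
  """Hedged usage sketch: one gradient-ascent step on the log-likelihood.

  Not from the original code: `linear_predict` is a hypothetical predict_fn
  (plain logistic regression on features Phi), and `params` is assumed to be
  a flat weight vector rather than a pytree.
  """
  def linear_predict(params, Phi):
    return Phi @ params  # logits a_n = phi_n^T w

  # jax.grad differentiates w.r.t. the first argument (params) by default.
  grads = jax.grad(loglikelihood_fn)(params, Phi, y, linear_predict)
  return params + lr * grads  # ascend the log-likelihood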