def __call__(self, inputs, is_training):
    h = nn.softplus(nn.Dense(self.hidden)(inputs))
    h = nn.softplus(nn.Dense(self.hidden)(h))
    h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(h)
    h = nn.Dense(self.num_topics)(h)
    log_concentration = nn.BatchNorm(
        use_bias=False,
        use_scale=False,
        momentum=0.9,
        use_running_average=not is_training,
    )(h)
    return jnp.exp(log_concentration)
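
# Hedged usage sketch: the __call__ above presumably lives in a flax.linen Module with
# `hidden`, `num_topics` and `dropout_rate` attributes (e.g. the encoder of a neural
# topic model that outputs positive concentration parameters). The class name,
# attribute defaults and toy batch below are illustrative assumptions only.
import jax
import jax.numpy as jnp
from flax import linen as nn


class ConcentrationEncoder(nn.Module):  # hypothetical wrapper around the __call__ above
    hidden: int = 64
    num_topics: int = 10
    dropout_rate: float = 0.2

    @nn.compact
    def __call__(self, inputs, is_training):
        h = nn.softplus(nn.Dense(self.hidden)(inputs))
        h = nn.softplus(nn.Dense(self.hidden)(h))
        h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(h)
        h = nn.Dense(self.num_topics)(h)
        log_concentration = nn.BatchNorm(
            use_bias=False, use_scale=False, momentum=0.9,
            use_running_average=not is_training)(h)
        return jnp.exp(log_concentration)


docs = jnp.ones((4, 200))  # toy batch: 4 documents, 200-dimensional features
model = ConcentrationEncoder()
variables = model.init(
    {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)},
    docs, is_training=True)
# Training-mode call: needs a dropout RNG and mutable batch stats for BatchNorm.
concentration, _ = model.apply(
    variables, docs, is_training=True,
    rngs={"dropout": jax.random.PRNGKey(2)}, mutable=["batch_stats"])
assert concentration.shape == (4, 10) and bool(jnp.all(concentration > 0.0))
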
def __call__(self, dtype=jnp.float32):
    value = self.param(
        "value", lambda key, shape: jnp.full(shape, self.start_value, dtype), (1,))
    if self.absolute:
        value = nn.softplus(value)
    return jnp.asarray(value, dtype)
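
# Hedged usage sketch: the __call__ above presumably belongs to a small flax.linen
# Module exposing one learned scalar; `LearnedScalar`, its defaults and the check
# below are illustrative assumptions only.
import jax
import jax.numpy as jnp
from flax import linen as nn


class LearnedScalar(nn.Module):  # hypothetical wrapper around the __call__ above
    start_value: float = 0.1
    absolute: bool = True  # if True, softplus keeps the returned value positive

    @nn.compact
    def __call__(self, dtype=jnp.float32):
        value = self.param(
            "value", lambda key, shape: jnp.full(shape, self.start_value, dtype),
            (1,))
        if self.absolute:
            value = nn.softplus(value)
        return jnp.asarray(value, dtype)


scalar = LearnedScalar()
variables = scalar.init(jax.random.PRNGKey(0))
out = scalar.apply(variables)
assert out.shape == (1,) and bool(out[0] > 0.0)
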
def __call__(self, x):
    # x = nn.tanh(nn.Dense(features=128)(x))
    x = nn.tanh(nn.Dense(features=64)(x))
    x = nn.Dense(features=1)(x)
    sp = -nn.softplus(x)
    # Log-probabilities of the two classes: [log sigmoid(-x), log sigmoid(x)],
    # i.e. [log(1 - p(z|x)), log p(z|x)].
    return jnp.concatenate([sp, sp + x], -1)
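
# Sanity check of the parameterization above (standalone, not part of the module):
# for any logit x, exp(-softplus(x)) + exp(x - softplus(x)) = (1 + e^x) / (1 + e^x) = 1,
# so the two outputs really are log-probabilities of complementary events.
import jax.numpy as jnp
from flax import linen as nn

x = jnp.linspace(-5.0, 5.0, 11)
sp = -nn.softplus(x)
probs = jnp.exp(jnp.stack([sp, sp + x], -1))
assert jnp.allclose(probs.sum(-1), 1.0, atol=1e-6)
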
def __call__(self, z, train: bool = True):
    # Common arguments
    conv_kwargs = {
        'kernel_size': (4, 4),
        'strides': (2, 2),
        'padding': 'SAME',
        'use_bias': False,
        'kernel_init': he_normal()
    }
    norm_kwargs = {
        'use_running_average': not train,
        'momentum': 0.99,
        'epsilon': 0.001,
        'use_scale': True,
        'use_bias': True
    }

    # Use jnp (not host-side numpy) so the reshape stays inside the traced computation.
    z = jnp.reshape(z, (1, 1, self.zdim))

    # Layer 1
    z = nn.ConvTranspose(features=512,
                         kernel_size=(4, 4),
                         strides=(1, 1),
                         padding='VALID',
                         use_bias=False,
                         kernel_init=he_normal())(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 2
    z = nn.ConvTranspose(features=256, **conv_kwargs)(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 3
    z = nn.ConvTranspose(features=128, **conv_kwargs)(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 4
    z = nn.ConvTranspose(features=64, **conv_kwargs)(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 5
    z = nn.ConvTranspose(features=1,
                         kernel_size=(4, 4),
                         strides=(2, 2),
                         padding='SAME',
                         use_bias=False,
                         kernel_init=nn.initializers.xavier_normal())(z)
    # x = nn.sigmoid(z)
    x = nn.softplus(z)

    return jnp.rot90(jnp.squeeze(x), k=2)  # Rotate to match TF output
def logistic_preprocess(nn_out):
    *batch, h, w, _ = nn_out.shape
    assert nn_out.shape[-1] % 10 == 0
    k = nn_out.shape[-1] // 10
    logit_weights, nn_out = jnp.split(nn_out, [k], -1)
    m, s, t = jnp.moveaxis(
        jnp.reshape(nn_out, tuple(batch) + (h, w, 3, k, 3)), (-2, -1), (-4, 0))
    assert m.shape == tuple(batch) + (k, h, w, 3)
    inv_scales = jnp.maximum(nn.softplus(s), 1e-7)
    return m, jnp.tanh(t), inv_scales, jnp.moveaxis(logit_weights, -1, -3)
def conditional_params_from_outputs(theta, img):
    """Maps an image `img` and the PixelCNN++ convnet output `theta` to
    conditional parameters for a mixture of k logistics over each pixel.

    Returns a tuple `(means, inverse_scales, logit_weights)` where `means` and
    `inverse_scales` are the conditional means and inverse scales of each mixture
    component (for each pixel-channel) and `logit_weights` are the logits of the
    mixture weights (for each pixel). These have the following shapes:

      means.shape == inv_scales.shape == (batch..., k, h, w, c)
      logit_weights.shape == (batch..., k, h, w)

    Args:
      theta: outputs of PixelCNN++ neural net with shape
        (batch..., h, w, (1 + 3 * c) * k)
      img: an image with shape (batch..., h, w, c)

    Returns:
      The tuple `(means, inverse_scales, logit_weights)`.
    """
    *batch, h, w, c = img.shape
    assert theta.shape[-1] % (3 * c + 1) == 0
    k = theta.shape[-1] // (3 * c + 1)

    logit_weights, theta = theta[..., :k], theta[..., k:]
    assert theta.shape[-3:] == (h, w, 3 * c * k)

    # Each of m, s and t must have shape (batch..., k, h, w, c); we effectively
    # spread the last dimension of theta out into c, k, 3, move the k dimension
    # to after batch and split along the 3 dimension.
    m, s, t = jnp.moveaxis(
        jnp.reshape(theta, tuple(batch) + (h, w, c, k, 3)), (-2, -1), (-4, 0))
    assert m.shape[-4:] == (k, h, w, c)
    t = jnp.tanh(t)

    # Add a mixture dimension to images
    img = jnp.expand_dims(img, -4)

    # Ensure inv_scales cannot be zero (zeros cause nans in sampling)
    inv_scales = jnp.maximum(nn.softplus(s), 1e-7)

    # Now condition the means for the last 2 channels (assuming c == 3)
    mean_red = m[..., 0]
    mean_green = m[..., 1] + t[..., 0] * img[..., 0]
    mean_blue = m[..., 2] + t[..., 1] * img[..., 0] + t[..., 2] * img[..., 1]
    means = jnp.stack((mean_red, mean_green, mean_blue), axis=-1)
    return means, inv_scales, jnp.moveaxis(logit_weights, -1, -3)
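
# Hedged shape check for conditional_params_from_outputs (toy values; assumes the
# `jax` / `jax.numpy as jnp` imports the snippets above rely on): with c = 3 channels
# and k = 10 mixture components, theta needs trailing dimension (1 + 3 * 3) * 10 = 100.
import jax
import jax.numpy as jnp

key_theta, key_img = jax.random.split(jax.random.PRNGKey(0))
theta = jax.random.normal(key_theta, (2, 32, 32, 100))
img = jax.random.uniform(key_img, (2, 32, 32, 3), minval=-1.0, maxval=1.0)
means, inv_scales, logit_weights = conditional_params_from_outputs(theta, img)
assert means.shape == inv_scales.shape == (2, 10, 32, 32, 3)
assert logit_weights.shape == (2, 10, 32, 32)
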
def smooth_softplus(x, smooth=50):
    """Smoothed ReLU; approaches relu as `smooth` tends to infinity."""
    return linen.softplus(x * smooth) / smooth
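
# Quick numeric check of the docstring's claim (standalone, toy values): with a large
# `smooth` the function is close to relu, and with smooth=1 it is ordinary softplus.
import jax.numpy as jnp
from flax import linen

x = jnp.linspace(-3.0, 3.0, 13)
assert jnp.allclose(smooth_softplus(x, smooth=1), linen.softplus(x))
# Worst case is at x = 0, where the gap is log(2) / smooth ~= 1.4e-3 for smooth=500.
assert jnp.max(jnp.abs(smooth_softplus(x, smooth=500) - jnp.maximum(x, 0.0))) < 2e-3
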
def __call__(self, x):
    l1 = self.groupnorm1(nn.softplus(self.straight1(x)))
    l2 = self.groupnorm2(nn.softplus(self.straight2(l1)))
    return nn.softplus(self.straight3(l2))
def __call__(self, x):
    # Encoder / downsampling path.
    out_l1 = nn.softplus(self.group_l1(self.layer1(x)))
    # Feed the previous activation (not `x`) so the first block is not discarded.
    out_l1 = nn.softplus(self.group_l1(self.layer12(out_l1)))

    out_1 = nn.softplus(self.group1(self.down1(out_l1)))
    out_1 = nn.softplus(self.group12(self.down12(out_1)))

    out_2 = nn.softplus(self.group2(self.down2(out_1)))
    out_2 = nn.softplus(self.group22(self.down22(out_2)))

    out_3 = nn.softplus(self.group3(self.down3(out_2)))
    out_3 = nn.softplus(self.group32(self.down32(out_3)))

    out_4 = nn.softplus(self.group4(self.down4(out_3)))
    out_4 = nn.softplus(self.group42(self.down42(out_4)))

    # Bottleneck.
    out_latent = nn.softplus(self.group_latent(self.latent(out_4)))

    # Decoder / upsampling path with U-Net style skip connections.
    in_up4 = jnp.concatenate((out_4, out_latent), axis=-1)
    # out_up4 = nn.softplus(self.group_up4(self.up4(self.deconv(out_4))))
    out_up4 = nn.softplus(self.group_up4(self.up4(self.deconv(in_up4))))
    out_up4 = nn.softplus(self.group_up42(self.up42(out_up4)))

    in_up3 = jnp.concatenate((out_3, out_up4), axis=-1)
    out_up3 = nn.softplus(self.group_up3(self.up3(self.deconv(in_up3))))
    out_up3 = nn.softplus(self.group_up32(self.up32(out_up3)))

    in_up2 = jnp.concatenate((out_2, out_up3), axis=-1)
    out_up2 = nn.softplus(self.group_up2(self.up2(self.deconv(in_up2))))
    out_up2 = nn.softplus(self.group_up22(self.up22(out_up2)))

    in_up1 = jnp.concatenate((out_1, out_up2), axis=-1)
    out_up1 = nn.softplus(self.group_up1(self.up1(self.deconv(in_up1))))
    out_up1 = nn.softplus(self.group_up12(self.up12(out_up1)))

    in_straight1 = jnp.concatenate((out_l1, out_up1), axis=-1)
    out_straight1 = nn.softplus(
        self.group_straight1(self.straight1(in_straight1)))
    out_straight1 = nn.softplus(
        self.group_straight12(self.straight12(out_straight1)))
    return nn.tanh(self.group_straight2(self.straight2(out_straight1)))
def training_losses(self, *, x, rng, logsnr_schedule_fn, num_steps,
                    mean_loss_weight_type):
    assert x.dtype in [jnp.float32, jnp.float64]
    assert isinstance(num_steps, int)
    rng = utils.RngGen(rng)
    eps = jax.random.normal(next(rng), shape=x.shape, dtype=x.dtype)
    bc = lambda z: utils.broadcast_from_left(z, x.shape)

    # sample logsnr
    if num_steps > 0:
        logging.info('Discrete time training: num_steps=%d', num_steps)
        assert num_steps >= 1
        i = jax.random.randint(
            next(rng), shape=(x.shape[0],), minval=0, maxval=num_steps)
        u = (i + 1).astype(x.dtype) / num_steps
    else:
        logging.info('Continuous time training')
        # continuous time
        u = jax.random.uniform(next(rng), shape=(x.shape[0],), dtype=x.dtype)
    logsnr = logsnr_schedule_fn(u)
    assert logsnr.shape == (x.shape[0],)

    # sample z ~ q(z_logsnr | x)
    z_dist = diffusion_forward(x=x, logsnr=bc(logsnr))
    z = z_dist['mean'] + z_dist['std'] * eps

    # get denoising target
    if self.target_model_fn is not None:  # distillation
        assert num_steps >= 1

        # two forward steps of DDIM from z_t using teacher
        teach_out_start = self._run_model(
            z=z, logsnr=logsnr, model_fn=self.target_model_fn, clip_x=False)
        x_pred = teach_out_start['model_x']
        eps_pred = teach_out_start['model_eps']

        u_mid = u - 0.5 / num_steps
        logsnr_mid = logsnr_schedule_fn(u_mid)
        stdv_mid = bc(jnp.sqrt(nn.sigmoid(-logsnr_mid)))
        a_mid = bc(jnp.sqrt(nn.sigmoid(logsnr_mid)))
        z_mid = a_mid * x_pred + stdv_mid * eps_pred

        teach_out_mid = self._run_model(
            z=z_mid, logsnr=logsnr_mid, model_fn=self.target_model_fn,
            clip_x=False)
        x_pred = teach_out_mid['model_x']
        eps_pred = teach_out_mid['model_eps']

        u_s = u - 1. / num_steps
        logsnr_s = logsnr_schedule_fn(u_s)
        stdv_s = bc(jnp.sqrt(nn.sigmoid(-logsnr_s)))
        a_s = bc(jnp.sqrt(nn.sigmoid(logsnr_s)))
        z_teacher = a_s * x_pred + stdv_s * eps_pred

        # get x-target implied by z_teacher (!= x_pred)
        a_t = bc(jnp.sqrt(nn.sigmoid(logsnr)))
        stdv_frac = bc(
            jnp.exp(0.5 * (nn.softplus(logsnr) - nn.softplus(logsnr_s))))
        x_target = (z_teacher - stdv_frac * z) / (a_s - stdv_frac * a_t)
        x_target = jnp.where(bc(i == 0), x_pred, x_target)
        eps_target = predict_eps_from_x(z=z, x=x_target, logsnr=logsnr)

    else:  # denoise to original data
        x_target = x
        eps_target = eps

    # also get v-target
    v_target = predict_v_from_x_and_eps(
        x=x_target, eps=eps_target, logsnr=logsnr)

    # denoising loss
    model_output = self._run_model(
        z=z, logsnr=logsnr, model_fn=self.model_fn, clip_x=False)
    x_mse = utils.meanflat(jnp.square(model_output['model_x'] - x_target))
    eps_mse = utils.meanflat(
        jnp.square(model_output['model_eps'] - eps_target))
    v_mse = utils.meanflat(jnp.square(model_output['model_v'] - v_target))
    if mean_loss_weight_type == 'constant':  # constant weight on x_mse
        loss = x_mse
    elif mean_loss_weight_type == 'snr':  # SNR * x_mse = eps_mse
        loss = eps_mse
    elif mean_loss_weight_type == 'snr_trunc':  # x_mse * max(SNR, 1)
        loss = jnp.maximum(x_mse, eps_mse)
    elif mean_loss_weight_type == 'v_mse':
        loss = v_mse
    else:
        raise NotImplementedError(mean_loss_weight_type)
    return {'loss': loss}
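
# Note on the `stdv_frac` term above: since the code uses stdv = sqrt(sigmoid(-logsnr)),
# we have log(stdv**2) = -softplus(logsnr), hence
#   exp(0.5 * (softplus(logsnr) - softplus(logsnr_s))) == stdv_s / stdv_t,
# the ratio of noise standard deviations at the two timesteps. A standalone numeric
# check of this identity (toy logsnr values, not part of the class above):
import jax.numpy as jnp
from flax import linen as nn

logsnr_t, logsnr_s = jnp.float32(-1.3), jnp.float32(0.4)
stdv_t = jnp.sqrt(nn.sigmoid(-logsnr_t))
stdv_s = jnp.sqrt(nn.sigmoid(-logsnr_s))
frac = jnp.exp(0.5 * (nn.softplus(logsnr_t) - nn.softplus(logsnr_s)))
assert jnp.allclose(frac, stdv_s / stdv_t, atol=1e-6)
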