Example #1
    def __call__(self, inputs, is_training):
        h = nn.softplus(nn.Dense(self.hidden)(inputs))
        h = nn.softplus(nn.Dense(self.hidden)(h))
        h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(h)
        h = nn.Dense(self.num_topics)(h)

        log_concentration = nn.BatchNorm(
            use_bias=False,
            use_scale=False,
            momentum=0.9,
            use_running_average=not is_training,
        )(h)
        return jnp.exp(log_concentration)
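
For context, a minimal sketch of how a `__call__` like this might be wrapped in a complete Flax module and invoked. The attribute names (`hidden`, `num_topics`, `dropout_rate`) and the bag-of-words input shape are assumptions, not taken from the original source:

import jax
import jax.numpy as jnp
from flax import linen as nn

class DirichletEncoder(nn.Module):  # hypothetical wrapper for the __call__ above
    hidden: int = 100
    num_topics: int = 20
    dropout_rate: float = 0.2

    @nn.compact
    def __call__(self, inputs, is_training):
        h = nn.softplus(nn.Dense(self.hidden)(inputs))
        h = nn.softplus(nn.Dense(self.hidden)(h))
        h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(h)
        h = nn.Dense(self.num_topics)(h)
        log_concentration = nn.BatchNorm(
            use_bias=False, use_scale=False, momentum=0.9,
            use_running_average=not is_training)(h)
        return jnp.exp(log_concentration)  # strictly positive concentrations

docs = jnp.ones((8, 1000))  # dummy bag-of-words batch
model = DirichletEncoder()
variables = model.init(
    {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)},
    docs, is_training=True)
concentration, new_state = model.apply(
    variables, docs, is_training=True,
    rngs={'dropout': jax.random.PRNGKey(2)},
    mutable=['batch_stats'])  # BatchNorm running stats are updated in training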
Example #2
 def __call__(self, dtype=jnp.float32):
     value = self.param(
         "value",
         lambda key, shape: jnp.full(shape, self.start_value, dtype), (1, ))
     if self.absolute:
         value = nn.softplus(value)
     return jnp.asarray(value, dtype)
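
A hedged sketch of the surrounding module and its use, assuming attributes named `start_value` and `absolute` as referenced above:

import jax
import jax.numpy as jnp
from flax import linen as nn

class LearnedScalar(nn.Module):  # hypothetical wrapper for the __call__ above
    start_value: float = 1.0
    absolute: bool = True  # if True, softplus keeps the returned value positive

    @nn.compact
    def __call__(self, dtype=jnp.float32):
        value = self.param(
            'value',
            lambda key, shape: jnp.full(shape, self.start_value, dtype), (1,))
        if self.absolute:
            value = nn.softplus(value)
        return jnp.asarray(value, dtype)

scale = LearnedScalar(start_value=-2.0)
params = scale.init(jax.random.PRNGKey(0))
print(scale.apply(params))  # softplus(-2.0) ~= 0.127, strictly positive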
Example #3
    def __call__(self, x):

        #x = nn.tanh(nn.Dense(features=128)(x))
        x = nn.tanh(nn.Dense(features=64)(x))
        x = nn.Dense(features=1)(x)
        sp = -nn.softplus(x)

        return jnp.concatenate([sp, sp + x], -1)  # log p(z|x), log(1 - p(z|x))
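
The two output channels rely on the identities `-softplus(x) == log(1 - sigmoid(x))` and `x - softplus(x) == log(sigmoid(x))`, so the module returns the two class log-probabilities in a numerically stable way. A standalone check of those identities (not part of the original module):

import jax.numpy as jnp
from flax import linen as nn

x = jnp.linspace(-5.0, 5.0, 11)
sp = -nn.softplus(x)
# -softplus(x) == log(1 - sigmoid(x)) and x - softplus(x) == log(sigmoid(x))
assert jnp.allclose(sp, jnp.log1p(-nn.sigmoid(x)), atol=1e-4)
assert jnp.allclose(sp + x, jnp.log(nn.sigmoid(x)), atol=1e-4)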
Example #4
    def __call__(self, z, train: bool = True):
        # Common arguments
        conv_kwargs = {
            'kernel_size': (4, 4),
            'strides': (2, 2),
            'padding': 'SAME',
            'use_bias': False,
            'kernel_init': he_normal()
        }
        norm_kwargs = {
            'use_running_average': not train,
            'momentum': 0.99,
            'epsilon': 0.001,
            'use_scale': True,
            'use_bias': True
        }

        z = jnp.reshape(z, (1, 1, self.zdim))  # lay the latent out as a 1x1 feature map

        # Layer 1
        z = nn.ConvTranspose(features=512,
                             kernel_size=(4, 4),
                             strides=(1, 1),
                             padding='VALID',
                             use_bias=False,
                             kernel_init=he_normal())(z)
        z = nn.BatchNorm(**norm_kwargs)(z)
        z = nn.leaky_relu(z, 0.2)

        # Layer 2
        z = nn.ConvTranspose(features=256, **conv_kwargs)(z)
        z = nn.BatchNorm(**norm_kwargs)(z)
        z = nn.leaky_relu(z, 0.2)

        # Layer 3
        z = nn.ConvTranspose(features=128, **conv_kwargs)(z)
        z = nn.BatchNorm(**norm_kwargs)(z)
        z = nn.leaky_relu(z, 0.2)

        # Layer 4
        z = nn.ConvTranspose(features=64, **conv_kwargs)(z)
        z = nn.BatchNorm(**norm_kwargs)(z)
        z = nn.leaky_relu(z, 0.2)

        # Layer 5
        z = nn.ConvTranspose(features=1,
                             kernel_size=(4, 4),
                             strides=(2, 2),
                             padding='SAME',
                             use_bias=False,
                             kernel_init=nn.initializers.xavier_normal())(z)
        # x = nn.sigmoid(z)
        x = nn.softplus(z)

        return jnp.rot90(jnp.squeeze(x), k=2)  # Rotate to match TF output
Example #5
def logistic_preprocess(nn_out):
    """Splits raw network output into per-pixel mixture-of-logistics parameters:
    means, tanh-squashed coefficients, inverse scales (clamped away from zero),
    and mixture logit weights."""
    *batch, h, w, _ = nn_out.shape
    assert nn_out.shape[-1] % 10 == 0
    k = nn_out.shape[-1] // 10
    logit_weights, nn_out = jnp.split(nn_out, [k], -1)
    m, s, t = jnp.moveaxis(jnp.reshape(nn_out,
                                       tuple(batch) + (h, w, 3, k, 3)),
                           (-2, -1), (-4, 0))
    assert m.shape == tuple(batch) + (k, h, w, 3)
    inv_scales = jnp.maximum(nn.softplus(s), 1e-7)
    return m, jnp.tanh(t), inv_scales, jnp.moveaxis(logit_weights, -1, -3)
Example #6
def conditional_params_from_outputs(theta, img):
    """Maps an image `img` and the PixelCNN++ convnet output `theta` to
  conditional parameters for a mixture of k logistics over each pixel.

  Returns a tuple `(means, inverse_scales, logit_weights)` where `means` and
  `inverse_scales` are the conditional means and inverse scales of each mixture
  component (for each pixel-channel) and `logit_weights` are the logits of the
  mixture weights (for each pixel). These have the following shapes:

    means.shape == inv_scales.shape == (batch..., k, h, w, c)
    logit_weights.shape == (batch..., k, h, w)

  Args:
    theta: outputs of PixelCNN++ neural net with shape
      (batch..., h, w, (1 + 3 * c) * k)
    img: an image with shape (batch..., h, w, c)

  Returns:
    The tuple `(means, inverse_scales, logit_weights)`.
  """
    *batch, h, w, c = img.shape
    assert theta.shape[-1] % (3 * c + 1) == 0
    k = theta.shape[-1] // (3 * c + 1)

    logit_weights, theta = theta[..., :k], theta[..., k:]
    assert theta.shape[-3:] == (h, w, 3 * c * k)

    # Each of m, s and t must have shape (batch..., k, h, w, c), we effectively
    # spread the last dimension of theta out into c, k, 3, move the k dimension to
    # after batch and split along the 3 dimension.
    m, s, t = jnp.moveaxis(jnp.reshape(theta,
                                       tuple(batch) + (h, w, c, k, 3)),
                           (-2, -1), (-4, 0))
    assert m.shape[-4:] == (k, h, w, c)
    t = jnp.tanh(t)

    # Add a mixture dimension to images
    img = jnp.expand_dims(img, -4)

    # Ensure inv_scales cannot be zero (zeros cause nans in sampling)
    inv_scales = jnp.maximum(nn.softplus(s), 1e-7)

    # now condition the means for the last 2 channels (assuming c == 3)
    mean_red = m[..., 0]
    mean_green = m[..., 1] + t[..., 0] * img[..., 0]
    mean_blue = m[..., 2] + t[..., 1] * img[..., 0] + t[..., 2] * img[..., 1]
    means = jnp.stack((mean_red, mean_green, mean_blue), axis=-1)
    return means, inv_scales, jnp.moveaxis(logit_weights, -1, -3)
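
A quick shape check against the docstring, assuming the function above is in scope together with the usual `import jax.numpy as jnp` and `from flax import linen as nn`; the batch size, resolution, and number of mixture components below are arbitrary:

import jax
import jax.numpy as jnp

k, h, w, c = 5, 4, 4, 3
img = jnp.zeros((2, h, w, c))
theta = jax.random.normal(jax.random.PRNGKey(0), (2, h, w, (1 + 3 * c) * k))
means, inv_scales, logit_weights = conditional_params_from_outputs(theta, img)
assert means.shape == inv_scales.shape == (2, k, h, w, c)
assert logit_weights.shape == (2, k, h, w)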
Example #7
def smooth_softplus(x, smooth=50):
    """becomes relu when smooth tends to infinity"""
    return linen.softplus(x * smooth) / smooth
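
A standalone illustration of the limit stated in the docstring, assuming `smooth_softplus` is in scope: for large `smooth` the curve stays within `log(2)/smooth` of ReLU.

import jax.numpy as jnp
from flax import linen

x = jnp.linspace(-3.0, 3.0, 7)
# The maximum gap to ReLU is log(2)/smooth, attained at x = 0.
print(jnp.max(jnp.abs(smooth_softplus(x, smooth=500) - linen.relu(x))))  # ~0.0014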
Example #8
 def __call__(self, x):
     l1 = self.groupnorm1(nn.softplus(self.straight1(x)))
     l2 = self.groupnorm2(nn.softplus(self.straight2(l1)))
     return nn.softplus(self.straight3(l2))
Example #9
 def __call__(self, x):
     out_l1 = nn.softplus(self.group_l1(self.layer1(x)))
     out_l1 = nn.softplus(self.group_l1(self.layer12(out_l1)))
     out_1 = nn.softplus(self.group1(self.down1(out_l1)))
     out_1 = nn.softplus(self.group12(self.down12(out_1)))
     out_2 = nn.softplus(self.group2(self.down2(out_1)))
     out_2 = nn.softplus(self.group22(self.down22(out_2)))
     out_3 = nn.softplus(self.group3(self.down3(out_2)))
     out_3 = nn.softplus(self.group32(self.down32(out_3)))
     out_4 = nn.softplus(self.group4(self.down4(out_3)))
     out_4 = nn.softplus(self.group42(self.down42(out_4)))
     out_latent = nn.softplus(self.group_latent(self.latent(out_4)))
     in_up4 = jnp.concatenate((out_4, out_latent), axis=-1)
     # out_up4 = nn.softplus(self.group_up4(self.up4(self.deconv(out_4))))
     out_up4 = nn.softplus(self.group_up4(self.up4(self.deconv(in_up4))))
     out_up4 = nn.softplus(self.group_up42(self.up42(out_up4)))
     in_up3 = jnp.concatenate((out_3, out_up4), axis=-1)
     out_up3 = nn.softplus(self.group_up3(self.up3(self.deconv(in_up3))))
     out_up3 = nn.softplus(self.group_up32(self.up32(out_up3)))
     in_up2 = jnp.concatenate((out_2, out_up3), axis=-1)
     out_up2 = nn.softplus(self.group_up2(self.up2(self.deconv(in_up2))))
     out_up2 = nn.softplus(self.group_up22(self.up22(out_up2)))
     in_up1 = jnp.concatenate((out_1, out_up2), axis=-1)
     out_up1 = nn.softplus(self.group_up1(self.up1(self.deconv(in_up1))))
     out_up1 = nn.softplus(self.group_up12(self.up12(out_up1)))
     in_straight1 = jnp.concatenate((out_l1, out_up1), axis=-1)
     out_straight1 = nn.softplus(
         self.group_straight1(self.straight1(in_straight1)))
     out_straight1 = nn.softplus(
         self.group_straight12(self.straight12(out_straight1)))
     return nn.tanh(self.group_straight2(self.straight2(out_straight1)))
Example #10
    def training_losses(self, *, x, rng, logsnr_schedule_fn, num_steps,
                        mean_loss_weight_type):
        assert x.dtype in [jnp.float32, jnp.float64]
        assert isinstance(num_steps, int)
        rng = utils.RngGen(rng)
        eps = jax.random.normal(next(rng), shape=x.shape, dtype=x.dtype)
        bc = lambda z: utils.broadcast_from_left(z, x.shape)

        # sample logsnr
        if num_steps > 0:
            logging.info('Discrete time training: num_steps=%d', num_steps)
            assert num_steps >= 1
            i = jax.random.randint(next(rng),
                                   shape=(x.shape[0], ),
                                   minval=0,
                                   maxval=num_steps)
            u = (i + 1).astype(x.dtype) / num_steps
        else:
            logging.info('Continuous time training')
            # continuous time
            u = jax.random.uniform(next(rng),
                                   shape=(x.shape[0], ),
                                   dtype=x.dtype)
        logsnr = logsnr_schedule_fn(u)
        assert logsnr.shape == (x.shape[0], )

        # sample z ~ q(z_logsnr | x)
        z_dist = diffusion_forward(x=x, logsnr=bc(logsnr))
        z = z_dist['mean'] + z_dist['std'] * eps

        # get denoising target
        if self.target_model_fn is not None:  # distillation
            assert num_steps >= 1

            # two forward steps of DDIM from z_t using teacher
            teach_out_start = self._run_model(z=z,
                                              logsnr=logsnr,
                                              model_fn=self.target_model_fn,
                                              clip_x=False)
            x_pred = teach_out_start['model_x']
            eps_pred = teach_out_start['model_eps']

            u_mid = u - 0.5 / num_steps
            logsnr_mid = logsnr_schedule_fn(u_mid)
            stdv_mid = bc(jnp.sqrt(nn.sigmoid(-logsnr_mid)))
            a_mid = bc(jnp.sqrt(nn.sigmoid(logsnr_mid)))
            z_mid = a_mid * x_pred + stdv_mid * eps_pred

            teach_out_mid = self._run_model(z=z_mid,
                                            logsnr=logsnr_mid,
                                            model_fn=self.target_model_fn,
                                            clip_x=False)
            x_pred = teach_out_mid['model_x']
            eps_pred = teach_out_mid['model_eps']

            u_s = u - 1. / num_steps
            logsnr_s = logsnr_schedule_fn(u_s)
            stdv_s = bc(jnp.sqrt(nn.sigmoid(-logsnr_s)))
            a_s = bc(jnp.sqrt(nn.sigmoid(logsnr_s)))
            z_teacher = a_s * x_pred + stdv_s * eps_pred

            # get x-target implied by z_teacher (!= x_pred)
            a_t = bc(jnp.sqrt(nn.sigmoid(logsnr)))
            stdv_frac = bc(
                jnp.exp(0.5 * (nn.softplus(logsnr) - nn.softplus(logsnr_s))))
            x_target = (z_teacher - stdv_frac * z) / (a_s - stdv_frac * a_t)
            x_target = jnp.where(bc(i == 0), x_pred, x_target)
            eps_target = predict_eps_from_x(z=z, x=x_target, logsnr=logsnr)

        else:  # denoise to original data
            x_target = x
            eps_target = eps

        # also get v-target
        v_target = predict_v_from_x_and_eps(x=x_target,
                                            eps=eps_target,
                                            logsnr=logsnr)

        # denoising loss
        model_output = self._run_model(z=z,
                                       logsnr=logsnr,
                                       model_fn=self.model_fn,
                                       clip_x=False)
        x_mse = utils.meanflat(jnp.square(model_output['model_x'] - x_target))
        eps_mse = utils.meanflat(
            jnp.square(model_output['model_eps'] - eps_target))
        v_mse = utils.meanflat(jnp.square(model_output['model_v'] - v_target))
        if mean_loss_weight_type == 'constant':  # constant weight on x_mse
            loss = x_mse
        elif mean_loss_weight_type == 'snr':  # SNR * x_mse = eps_mse
            loss = eps_mse
        elif mean_loss_weight_type == 'snr_trunc':  # x_mse * max(SNR, 1)
            loss = jnp.maximum(x_mse, eps_mse)
        elif mean_loss_weight_type == 'v_mse':
            loss = v_mse
        else:
            raise NotImplementedError(mean_loss_weight_type)
        return {'loss': loss}
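
For reference on the 'snr' weighting above: with `z = a*x + s*eps`, `a**2 = sigmoid(logsnr)` and `s**2 = sigmoid(-logsnr)`, any error in the predicted x maps to an `a/s`-times-larger error in the implied eps, so `eps_mse == exp(logsnr) * x_mse`. A small standalone check of that identity (not using the class above):

import jax
import jax.numpy as jnp
from flax import linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4))
eps = jax.random.normal(jax.random.PRNGKey(1), (2, 4))
logsnr = jnp.float32(0.5)
a, s = jnp.sqrt(nn.sigmoid(logsnr)), jnp.sqrt(nn.sigmoid(-logsnr))
z = a * x + s * eps

# Any x-prediction implies an eps-prediction through z = a * x_pred + s * eps_pred.
x_pred = x + 0.1
eps_pred = (z - a * x_pred) / s
x_mse = jnp.mean(jnp.square(x_pred - x))
eps_mse = jnp.mean(jnp.square(eps_pred - eps))
assert jnp.allclose(eps_mse, jnp.exp(logsnr) * x_mse, rtol=1e-3)

This is also why the 'snr_trunc' branch can compute `x_mse * max(SNR, 1)` simply as `jnp.maximum(x_mse, eps_mse)`.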