Example #1
def diffusion_forward(*, x, logsnr):
    """q(z_t | x)."""
    return {
        'mean': x * jnp.sqrt(nn.sigmoid(logsnr)),
        'std': jnp.sqrt(nn.sigmoid(-logsnr)),
        'var': nn.sigmoid(-logsnr),
        'logvar': nn.log_sigmoid(-logsnr)
    }
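Illustrative usage only, not from the source file: the dict above parameterizes a variance-preserving forward marginal, so sigmoid(logsnr) + sigmoid(-logsnr) = 1 at every noise level. A minimal sketch of sampling z_t from it, assuming `diffusion_forward` as defined above (with its module-level jnp/nn imports):

import jax
import jax.numpy as jnp
from jax import nn

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4))      # toy "data"
logsnr = jnp.zeros(())                                     # SNR = 1 at this noise level
fwd = diffusion_forward(x=x, logsnr=logsnr)
eps = jax.random.normal(jax.random.PRNGKey(1), x.shape)
z_t = fwd['mean'] + fwd['std'] * eps                       # a sample from q(z_t | x)
print(jnp.allclose(nn.sigmoid(logsnr) + fwd['var'], 1.0))  # variance preserving: True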
Example #2
 def ddim_step(self, i, z_t, num_steps, logsnr_schedule_fn, clip_x):
     shape, dtype = z_t.shape, z_t.dtype
     logsnr_t = logsnr_schedule_fn((i + 1.).astype(dtype) / num_steps)
     logsnr_s = logsnr_schedule_fn(i.astype(dtype) / num_steps)
     model_out = self._run_model(z=z_t,
                                 logsnr=jnp.full((shape[0], ), logsnr_t),
                                 model_fn=self.model_fn,
                                 clip_x=clip_x)
     x_pred_t = model_out['model_x']
     eps_pred_t = model_out['model_eps']
     stdv_s = jnp.sqrt(nn.sigmoid(-logsnr_s))
     alpha_s = jnp.sqrt(nn.sigmoid(logsnr_s))
     z_s_pred = alpha_s * x_pred_t + stdv_s * eps_pred_t
     return jnp.where(i == 0, x_pred_t, z_s_pred)
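For context only: a full sampler would typically apply this step for i = num_steps - 1 down to 0. A sketch of such a loop, assuming `sampler` is an instance of the (unnamed) class that owns `ddim_step` above:

import jax
import jax.numpy as jnp

def ddim_sample(sampler, z_init, num_steps, logsnr_schedule_fn, clip_x=True):
    """Runs ddim_step for i = num_steps-1, ..., 0 (sketch under the assumptions above)."""
    def body(k, z):
        i = jnp.asarray(num_steps - 1 - k, dtype=z.dtype)
        return sampler.ddim_step(i, z, num_steps, logsnr_schedule_fn, clip_x)
    return jax.lax.fori_loop(0, num_steps, body, z_init)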
Example #3
def to_valid_rgb(t, decorrelated=False, sigmoid=True):
    """Transform inner dimension of t to valid rgb colors.

  In practice this consists of two parts:
  (1) If requested, transform the colors from a decorrelated color space to RGB.
  (2) Constrain the color channels to be in [0,1], either using a sigmoid
      function or clipping.

  Args:
    t: Input tensor, trailing dimension will be interpreted as colors and
      transformed/constrained.
    decorrelated: If True, the input tensor's colors are interpreted as coming
      from a whitened space.
    sigmoid: If True, the colors are constrained elementwise using a sigmoid. If
      False, colors are constrained by clipping the infinity norm.

  Returns:
    t with the innermost dimension transformed.
  """
    if decorrelated:
        t = _linear_correlate_color(t)
    if decorrelated and not sigmoid:
        t += color_mean

    if sigmoid:
        return nn.sigmoid(t)

    return constrain_l_inf(2 * t - 1) / 2 + 0.5
Example #4
    def __call__(self, image):
        x = self.backbone(image)

        x = self.feature_conv(x)
        B, H, W, EMB = x.shape  # batch, height, width, embedding size

        col_embeds = jnp.repeat(self.col_embed[:W][jnp.newaxis, :, :], H,
                                0)  #  H, W, embedding_size//2
        row_embeds = jnp.repeat(self.col_embed[:H][:, jnp.newaxis, :], W,
                                1)  # H, W, embedding_size//2

        positional_embeds = jnp.concatenate([col_embeds, row_embeds],
                                            -1)  # H, W, embedding_size

        positional_embeds_as_seq = jnp.reshape(
            positional_embeds, (1, H * W, EMB))  # 1, H*W, embedding_size

        image_tiles_as_seq = jnp.reshape(x, (B, H * W, -1))

        queries = jnp.repeat(self.query_pos[jnp.newaxis, :, :], B, 0)

        x = self.transformer(
            positional_embeds_as_seq + 0.1 * image_tiles_as_seq, queries)

        pred_logits = self.linear_class(x)
        pred_bbox = nn.sigmoid(
            self.linear_bbox(x))  # TODO maybe chuck an MLP on here

        return {'logits': pred_logits, 'pred_boxes': pred_bbox}
Example #5
  def render_samples(self,
                     level,
                     points,
                     z_vals,
                     directions,
                     viewdirs,
                     metadata,
                     warp_extra,
                     use_warp=True,
                     use_warp_jacobian=False,
                     metadata_encoded=False,
                     return_points=False,
                     return_weights=False):
    trunk_condition, alpha_condition, rgb_condition = (
        self.get_condition_inputs(viewdirs, metadata, metadata_encoded))

    out = {}
    if return_points:
      out['points'] = points
    # Apply the deformation field to the samples.
    if use_warp:
      metadata_channels = self.num_warp_features if metadata_encoded else 1
      warp_metadata = (
          metadata['time']
          if self.warp_metadata_encoder_type == 'time' else metadata['warp'])
      warp_metadata = jnp.broadcast_to(
          warp_metadata[:, jnp.newaxis, :],
          shape=(*points.shape[:2], metadata_channels))
      warp_out = self.warp_field(
          points,
          warp_metadata,
          warp_extra,
          use_warp_jacobian,
          metadata_encoded)
      points = warp_out['warped_points']
      if 'jacobian' in warp_out:
        out['warp_jacobian'] = warp_out['jacobian']
      if return_points:
        out['warped_points'] = warp_out['warped_points']

    points_embed = self.point_encoder(points)

    raw = self.nerf_mlps[level](
        points_embed, trunk_condition, alpha_condition, rgb_condition)
    raw = model_utils.noise_regularize(
        self.make_rng(level), raw, self.noise_std, self.use_stratified_sampling)
    rgb = nn.sigmoid(raw['rgb'])
    sigma = self.sigma_activation(jnp.squeeze(raw['alpha'], axis=-1))
    out.update(model_utils.volumetric_rendering(
        rgb,
        sigma,
        z_vals,
        directions,
        return_weights=return_weights,
        use_white_background=self.use_white_background,
        sample_at_infinity=self.use_sample_at_infinity))

    return out
Example #6
def volumetric_rendering(raw,
                         z_vals,
                         dirs,
                         use_white_background,
                         sigma_activation=nn.relu,
                         sample_at_infinity=True,
                         eps=1e-10):
    """Volumetric Rendering Function.

  Args:
    raw: jnp.ndarray(float32), [batch_size, num_coarse_samples, 4].
    z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
    dirs: jnp.ndarray(float32), [batch_size, 3].
    use_white_background: bool.
    sigma_activation: the activation functions to apply to the sigma values.
    sample_at_infinity: if True adds a sample at infinity.
    eps: a small number to prevent numerical issues.

  Returns:
    rgb: jnp.ndarray(float32), [batch_size, 3].
    depth: jnp.ndarray(float32), [batch_size].
    acc: jnp.ndarray(float32), [batch_size].
    weights: jnp.ndarray(float32), [batch_size, num_coarse_samples]
  """
    rgb = nn.sigmoid(raw['rgb'])
    sigma = sigma_activation(jnp.squeeze(raw['alpha'], axis=-1))
    # TODO(keunhong): remove this hack.
    last_sample_z = 1e10 if sample_at_infinity else 1e-19
    dists = jnp.concatenate([
        z_vals[..., 1:] - z_vals[..., :-1],
        jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
    ], -1)
    dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1)
    alpha = 1.0 - jnp.exp(-sigma * dists)
    accum_prod = jnp.concatenate([
        jnp.full_like(alpha[..., :1], 1., alpha.dtype),
        jnp.cumprod(1.0 - alpha[..., :-1] + eps, axis=-1)
    ], axis=-1)
    weights = alpha * accum_prod

    rgb = (weights[..., None] * rgb).sum(axis=-2)
    exp_depth = (weights * z_vals).sum(axis=-1)
    med_depth = compute_depth_map(weights, z_vals)
    acc = weights.sum(axis=-1)
    if use_white_background:
        rgb = rgb + (1. - acc[..., None])

    inv_eps = 1.0 / eps
    disp = 1.0 / exp_depth
    disp = jnp.where((disp > 0) & (disp < inv_eps) & (acc > eps), disp,
                     inv_eps)

    if sample_at_infinity:
        acc = weights[..., :-1].sum(axis=-1)

    return rgb, exp_depth, med_depth, disp, acc, weights
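For reference, a standalone toy check (not from the source file) of the compositing rule used above, w_i = alpha_i * prod_{j<i} (1 - alpha_j), whose per-ray sum gives the accumulated opacity `acc`:

import jax.numpy as jnp

alpha = jnp.array([[0.1, 0.5, 0.9]])    # opacities of 3 samples along a single ray
transmittance = jnp.concatenate([
    jnp.ones_like(alpha[..., :1]),
    jnp.cumprod(1.0 - alpha[..., :-1], axis=-1)
], axis=-1)
weights = alpha * transmittance         # [[0.1, 0.45, 0.405]]
print(weights.sum(axis=-1))             # 0.955; never exceeds 1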
Example #7
def diffusion_reverse(*, x, z_t, logsnr_s, logsnr_t, x_logvar):
    """q(z_s | z_t, x) (requires logsnr_s > logsnr_t (i.e. s < t))."""
    alpha_st = jnp.sqrt((1. + jnp.exp(-logsnr_t)) / (1. + jnp.exp(-logsnr_s)))
    alpha_s = jnp.sqrt(nn.sigmoid(logsnr_s))
    r = jnp.exp(logsnr_t - logsnr_s)  # SNR(t)/SNR(s)
    one_minus_r = -jnp.expm1(logsnr_t - logsnr_s)  # 1-SNR(t)/SNR(s)
    log_one_minus_r = utils.log1mexp(logsnr_s -
                                     logsnr_t)  # log(1-SNR(t)/SNR(s))

    mean = r * alpha_st * z_t + one_minus_r * alpha_s * x

    if isinstance(x_logvar, str):
        if x_logvar == 'small':
            # same as setting x_logvar to -infinity
            var = one_minus_r * nn.sigmoid(-logsnr_s)
            logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
        elif x_logvar == 'large':
            # same as setting x_logvar to nn.log_sigmoid(-logsnr_t)
            var = one_minus_r * nn.sigmoid(-logsnr_t)
            logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_t)
        elif x_logvar.startswith('medium:'):
            _, frac = x_logvar.split(':')
            frac = float(frac)
            logging.info('logvar frac=%f', frac)
            assert 0 <= frac <= 1
            min_logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
            max_logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_t)
            logvar = frac * max_logvar + (1 - frac) * min_logvar
            var = jnp.exp(logvar)
        else:
            raise NotImplementedError(x_logvar)
    else:
        assert isinstance(x_logvar, jnp.ndarray) or isinstance(
            x_logvar, onp.ndarray)
        assert x_logvar.shape == x.shape
        # start with "small" variance
        var = one_minus_r * nn.sigmoid(-logsnr_s)
        logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
        # extra variance weight is (one_minus_r*alpha_s)**2
        var += jnp.square(one_minus_r) * nn.sigmoid(logsnr_s) * jnp.exp(
            x_logvar)
        logvar = jnp.logaddexp(
            logvar, 2. * log_one_minus_r + nn.log_sigmoid(logsnr_s) + x_logvar)
    return {'mean': mean, 'std': jnp.sqrt(var), 'var': var, 'logvar': logvar}
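`utils.log1mexp` is not shown in this snippet; the comment implies it computes log(1 - exp(-x)) for x > 0. A common numerically stable sketch of such a helper, offered as an assumption rather than the repo's actual implementation:

import jax.numpy as jnp

def log1mexp(x):
    # Stable log(1 - exp(-x)) for x > 0: switch formulas at x = log(2).
    return jnp.where(x < jnp.log(2.0),
                     jnp.log(-jnp.expm1(-x)),
                     jnp.log1p(-jnp.exp(-x)))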
Example #8
    def __call__(self, inputs):
        weights = self.param(
            'weights',
            self.kernel_init,  # Initialization function
            (inputs.shape[-1], 1))  # shape info.
        bias = self.param('bias', self.bias_init, (1, ))

        logit = jnp.dot(inputs, weights) + bias

        return nn.sigmoid(logit)
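A minimal usage sketch, assuming the `__call__` above belongs to a flax.linen.Module, here given the hypothetical name `LogisticRegression` with `kernel_init` and `bias_init` fields:

import jax
import jax.numpy as jnp
import flax.linen as nn

model = LogisticRegression(kernel_init=nn.initializers.lecun_normal(),
                           bias_init=nn.initializers.zeros)    # hypothetical name/fields
x = jnp.ones((4, 8))                                           # batch of 4, 8 features
params = model.init(jax.random.PRNGKey(0), x)                  # creates 'weights' and 'bias'
probs = model.apply(params, x)                                 # shape (4, 1), values in (0, 1)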
Example #9
 def __call__(self, x: spec.Tensor, train: bool):
     del train
     input_size = 28 * 28
     num_hidden = 128
     num_classes = 10
     x = x.reshape((x.shape[0], input_size))  # Flatten.
     x = nn.Dense(features=num_hidden, use_bias=True)(x)
     x = nn.sigmoid(x)
     x = nn.Dense(features=num_classes, use_bias=True)(x)
     x = nn.log_softmax(x)
     return x
Example #10
    def get_joint(self, s_logits):
        """Returns a joint that agrees with s_logits when axis 0 is marginalized out."""
        value = jax.nn.softmax(s_logits)
        for layer in self.hidden_layers:
            value = nn.sigmoid(layer(value))

        log_joint = self.output_layer(value)

        # Fix the marginals for the output.
        log_joint = (log_joint - jax.scipy.special.logsumexp(
            log_joint, axis=0, keepdims=True) + s_logits[None, :])

        return log_joint
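A standalone toy check (arbitrary values, not the module's output) of the marginal-fixing step above: subtracting the axis-0 logsumexp and adding `s_logits` makes the axis-0 log-marginal equal `s_logits`:

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

s_logits = jnp.log(jnp.array([0.2, 0.3, 0.5]))
log_joint = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
fixed = log_joint - logsumexp(log_joint, axis=0, keepdims=True) + s_logits[None, :]
print(jnp.allclose(logsumexp(fixed, axis=0), s_logits, atol=1e-5))  # True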
Example #11
  def __call__(self, inputs, aux=None):
    c = inputs.shape[-1]
    y = self.conv_module(c)(self.nonlinearity(inputs))
    if aux is not None:
      y = self.nonlinearity(y + ConvOneByOne(c)(self.nonlinearity(aux)))

    if self.dropout_p > 0:
      y = nn.Dropout(rate=self.dropout_p)(y)

    # Set init_scale=0.1 so that the res block is close to the identity at
    # initialization.
    a, b = jnp.split(self.conv_module(2 * c, init_scale=0.1)(y), 2, axis=-1)
    return inputs + a * nn.sigmoid(b)
Example #12
    def get_joint(self, s_logits):
        """Returns a joint that agrees with s_logits when axis 0 is marginalized out."""
        z_logits = self.get_prior()
        value = jax.nn.softmax(s_logits)
        for layer in self.hidden_layers:
            value = nn.sigmoid(layer(value))

        log_joint = self.output_layer(value)
        # Sinkhorn has nicer gradients.
        log_joint = fix_coupling_sinkhorn(log_joint, z_logits, s_logits,
                                          self.sinkhorn_iterations)
        # ... but rejection is harder to exploit inaccuracy for
        log_joint = fix_coupling_rejection(log_joint, z_logits, s_logits)
        return log_joint
Example #13
  def __call__(self, z):
    shape_before_flattening, flatten_out_size = self.flatten_enc_shape()

    x = nn.Dense(flatten_out_size, name='fc1')(z)
    x = x.reshape((x.shape[0], *shape_before_flattening[1:]))
    
    hidden_dims = self.hidden_dims[::-1]
    # Build decoder: upsample through all but the last hidden dim, then project to RGB.
    for h_dim in hidden_dims[:-1]:
      x = nn.ConvTranspose(features=h_dim, kernel_size=(3, 3), strides=(2, 2))(x)
      x = nn.GroupNorm()(x)
      x = nn.gelu(x)

    x = nn.ConvTranspose(features=3, kernel_size=(3, 3), strides=(2,2))(x)
    x = nn.sigmoid(x)
    return x
Example #14
    def setUp(self):
        super(TrainProxyTest, self).setUp()
        np.random.seed(0)
        self.data_x = np.array(np.random.binomial(1, 0.5, size=(1000, 2)),
                               dtype=float)
        self.data_y = np.random.binomial(1,
                                         p=nn.sigmoid(self.data_x[:, 0] -
                                                      self.data_x[:, 1]))
        self.data_a = np.random.choice(3, size=self.data_y.shape)
        self.data_t = np.random.choice(3, size=self.data_a.shape)

        self.data = {
            'a': self.data_a,
            'm': self.data_x,
            'y': self.data_y,
            't': self.data_t,
        }
Example #15
    def __call__(self, inputs):
        in_ch = inputs.shape[-1]
        if self.se_ratio is None:
            if self.hidden_ch is None:
                raise ValueError('Must provide one of se_ratio or hidden_ch')
            hidden_ch = self.hidden_ch
        else:
            hidden_ch = max(1, int(in_ch * self.se_ratio))

        dense = partial(nn.Dense, use_bias=True, dtype=self.dtype)

        x = jnp.mean(inputs, axis=(1, 2), dtype=self.dtype, keepdims=True)
        x = dense(features=hidden_ch)(x)
        x = self.activation_fn(x)
        x = dense(features=in_ch)(x)
        output = nn.sigmoid(x) * inputs
        return output
Example #16
 def __call__(self, z, a, key):
     kernel_initializer = jax.nn.initializers.glorot_uniform()
     x = nn.Dense(features=self.layer_width,
                  kernel_init=kernel_initializer)(jnp.concatenate([z, a]))
     x = nn.LayerNorm()(x)
     x = nn.relu(x)
     mu = nn.Dense(features=self.embedding_dim,
                   kernel_init=kernel_initializer)(x)
     if self.probabilistic:
         sigma = nn.Dense(features=self.embedding_dim,
                          kernel_init=kernel_initializer)(x)
         sigma = nn.sigmoid(sigma)
         sigma = self.min_sigma + (self.max_sigma - self.min_sigma) * sigma
         eps = jax.random.normal(key, shape=sigma.shape)
         sample = mu + sigma * eps
     else:
         sigma = jnp.zeros(self.embedding_dim)
         sample = mu
     return DynamicsModelType(mu, sigma, sample)
Example #17
def unnormalized_sigmoid_mean_squared_error(logits, targets, weights=None):
  """Computes the sigmoid mean squared error per example.

  Args:
    logits: float array of shape (batch, output_shape).
    targets: float array of shape (batch, output_shape).
    weights: None or float array of shape (batch,).

  Returns:
    Sigmoid mean squared error computed per example, shape (batch,).
  """
  losses = jnp.square(nn.sigmoid(logits) - targets)

  if weights is not None:
    weights = conform_weights_to_targets(weights, targets)
    weighted_losses = losses * weights
  else:
    weighted_losses = losses

  return jnp.sum((weighted_losses).reshape(losses.shape[0], -1), axis=-1)
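A toy call of the loss above; with weights=None the helper `conform_weights_to_targets` (not shown here) is never reached:

import jax.numpy as jnp

logits = jnp.array([[0.0, 2.0], [-1.0, 1.0]])
targets = jnp.array([[0.5, 1.0], [0.0, 1.0]])
per_example = unnormalized_sigmoid_mean_squared_error(logits, targets)
print(per_example.shape)  # (2,) -- one summed squared error per example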
Example #18
    def _run_model(self, *, z, logsnr, model_fn, clip_x):
        model_output = model_fn(z, logsnr)
        if self.mean_type == 'eps':
            model_eps = model_output
        elif self.mean_type == 'x':
            model_x = model_output
        elif self.mean_type == 'v':
            model_v = model_output
        elif self.mean_type == 'both':
            _model_x, _model_eps = jnp.split(model_output, 2, axis=-1)  # pylint: disable=invalid-name
        else:
            raise NotImplementedError(self.mean_type)

        # get prediction of x at t=0
        if self.mean_type == 'both':
            # reconcile the two predictions
            model_x_eps = predict_x_from_eps(z=z,
                                             eps=_model_eps,
                                             logsnr=logsnr)
            wx = utils.broadcast_from_left(nn.sigmoid(-logsnr), z.shape)
            model_x = wx * _model_x + (1. - wx) * model_x_eps
        elif self.mean_type == 'eps':
            model_x = predict_x_from_eps(z=z, eps=model_eps, logsnr=logsnr)
        elif self.mean_type == 'v':
            model_x = predict_x_from_v(z=z, v=model_v, logsnr=logsnr)

        # clipping
        if clip_x:
            model_x = jnp.clip(model_x, -1., 1.)

        # get eps prediction if clipping or if mean_type != eps
        if self.mean_type != 'eps' or clip_x:
            model_eps = predict_eps_from_x(z=z, x=model_x, logsnr=logsnr)

        # get v prediction if clipping or if mean_type != v
        if self.mean_type != 'v' or clip_x:
            model_v = predict_v_from_x_and_eps(x=model_x,
                                               eps=model_eps,
                                               logsnr=logsnr)

        return {'model_x': model_x, 'model_eps': model_eps, 'model_v': model_v}
Example #19
    def __call__(self, inputs):
        """Passes the input through a squeeze and excite block.
        Arguments:
            inputs:     [batch_size, height, width, dim]
        Returns:
            output:     [batch_size, height, width, dim]
        """
        cfg = self.config
        out_dim = inputs.shape[-1]
        se_features = max(1, int(out_dim * cfg.se_ratio))

        dense = partial(nn.Dense,
                        dtype=cfg.dtype,
                        precision=cfg.precision,
                        kernel_init=cfg.kernel_init,
                        bias_init=cfg.bias_init)

        y = jnp.mean(inputs, axis=(1, 2), dtype=cfg.dtype, keepdims=True)
        y = dense(features=se_features)(y)
        y = cfg.activation_fn(y)
        y = dense(features=out_dim)(y)
        y = nn.sigmoid(y) * inputs
        return y
Example #20
 def __call__(self, word):
     word_features = self.vocab_layer(word)
     word_features_act = nn.sigmoid(word_features)
     embed_features = self.embed_layer(word_features_act)
     embed_act = nn.softmax(embed_features)
     return embed_act
Example #21
def raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, rng=None):
    """Transforms model's predictions to semantically meaningful values.
    Args:
        raw: (num_rays, num_samples || num_importance, 4) prediction from model
        z_vals: (num_rays, num_samples || num_importance) integration time
        rays_d: (num_rays, 3) direction of each ray
        raw_noise_std: std of noise added for regularization
        white_bkgd: whether to use the alpha channel for white background
        rng: random key
    Returns:
        acc_map: (num_rays) sum of weights along each ray
        depth_map: (num_rays) estimated distance to object
        disp_map: (num_rays) disparity map (inverse of depth map)
        rgb_map: (num_rays, 3) estimated RGB color of a ray
        weights: (num_rays, num_samples || num_importance) weights assigned to each sampled color
    """

    # compute 'distance' (in time) between each integration time along a ray
    dists = z_vals[..., 1:] - z_vals[..., :-1]

    # the 'distance' from the last integration time is infinity
    dists = jnp.concatenate(
        [dists, jnp.broadcast_to([1e10], dists[..., :1].shape)], axis=-1)
    dists = dists.astype(z_vals.dtype)  # [num_rays, num_samples]

    # multiply each distance by the norm of its corresponding direction ray
    # to convert to real world distance (accounts for non-unit directions)
    dists = dists * jnp.linalg.norm(rays_d[..., None, :], axis=-1)

    # extract RGB of each sample position along each ray; the sigmoid keeps
    # each channel strictly within [0, 1]
    rgb = nn.sigmoid(raw[..., :3])  # [num_rays, num_samples, 3]

    # add noise to the density predictions; can be used to regularize the
    # network during training (prevents floater artifacts)
    noise = 0.0
    if raw_noise_std > 0.0 and rng is not None:
        noise = random.normal(rng, raw[..., 3].shape) * raw_noise_std

    # predict density of each sample along each ray (alpha channel)
    # higher values imply higher likelihood of being absorbed at this point
    alpha = 1.0 - jnp.exp(-nn.relu(raw[..., 3] + noise) * dists)

    # compute weight for RGB of each sample along each ray
    # cumprod() is used to express the idea of the ray not having reflected up to this sample yet
    # weights = alpha * tf.math.cumprod(1.0 - alpha + 1e-10, axis=-1, exclusive=True)
    alpha_ = jnp.clip(1.0 - alpha, 1e-5, 1.0)
    weights = jnp.concatenate(
        [jnp.ones_like(alpha_[..., :1]), alpha_[..., :-1]], -1)
    weights = alpha * jnp.cumprod(weights, -1)  # [num_rays, num_samples]

    # computed weighted color of each sample along each ray
    rgb_map = jnp.einsum("ij,ijk->ik", weights, rgb)  # [num_rays, 3]

    # estimated depth map is expected distance
    depth_map = jnp.einsum("ij,ij->i", weights, z_vals)  # [num_rays]

    # sum of weights along each ray (this value is in [0, 1] up to numerical error)
    acc_map = jnp.einsum("ij->i", weights)  # [num_rays]

    # disparity map is inverse depth
    i_depth = depth_map / jnp.clip(acc_map, 1e-5)
    disp_map = 1.0 / jnp.clip(i_depth, 1e-5)

    # to composite onto a white background, use the accumulated alpha map
    if white_bkgd:
        rgb_map += 1.0 - acc_map[..., None]

    return {
        "rgb": rgb_map.astype(jnp.float32),
        "disp": disp_map.astype(jnp.float32),
        "acc": acc_map.astype(jnp.float32),
        "depth": depth_map.astype(jnp.float32),
    }, weights
Example #22
 def generate(self, z):
     return nn.sigmoid(self.decoder(z))
Example #23
    def training_losses(self, *, x, rng, logsnr_schedule_fn, num_steps,
                        mean_loss_weight_type):
        assert x.dtype in [jnp.float32, jnp.float64]
        assert isinstance(num_steps, int)
        rng = utils.RngGen(rng)
        eps = jax.random.normal(next(rng), shape=x.shape, dtype=x.dtype)
        bc = lambda z: utils.broadcast_from_left(z, x.shape)

        # sample logsnr
        if num_steps > 0:
            logging.info('Discrete time training: num_steps=%d', num_steps)
            assert num_steps >= 1
            i = jax.random.randint(next(rng),
                                   shape=(x.shape[0], ),
                                   minval=0,
                                   maxval=num_steps)
            u = (i + 1).astype(x.dtype) / num_steps
        else:
            logging.info('Continuous time training')
            # continuous time
            u = jax.random.uniform(next(rng),
                                   shape=(x.shape[0], ),
                                   dtype=x.dtype)
        logsnr = logsnr_schedule_fn(u)
        assert logsnr.shape == (x.shape[0], )

        # sample z ~ q(z_logsnr | x)
        z_dist = diffusion_forward(x=x, logsnr=bc(logsnr))
        z = z_dist['mean'] + z_dist['std'] * eps

        # get denoising target
        if self.target_model_fn is not None:  # distillation
            assert num_steps >= 1

            # two forward steps of DDIM from z_t using teacher
            teach_out_start = self._run_model(z=z,
                                              logsnr=logsnr,
                                              model_fn=self.target_model_fn,
                                              clip_x=False)
            x_pred = teach_out_start['model_x']
            eps_pred = teach_out_start['model_eps']

            u_mid = u - 0.5 / num_steps
            logsnr_mid = logsnr_schedule_fn(u_mid)
            stdv_mid = bc(jnp.sqrt(nn.sigmoid(-logsnr_mid)))
            a_mid = bc(jnp.sqrt(nn.sigmoid(logsnr_mid)))
            z_mid = a_mid * x_pred + stdv_mid * eps_pred

            teach_out_mid = self._run_model(z=z_mid,
                                            logsnr=logsnr_mid,
                                            model_fn=self.target_model_fn,
                                            clip_x=False)
            x_pred = teach_out_mid['model_x']
            eps_pred = teach_out_mid['model_eps']

            u_s = u - 1. / num_steps
            logsnr_s = logsnr_schedule_fn(u_s)
            stdv_s = bc(jnp.sqrt(nn.sigmoid(-logsnr_s)))
            a_s = bc(jnp.sqrt(nn.sigmoid(logsnr_s)))
            z_teacher = a_s * x_pred + stdv_s * eps_pred

            # get x-target implied by z_teacher (!= x_pred)
            a_t = bc(jnp.sqrt(nn.sigmoid(logsnr)))
            stdv_frac = bc(
                jnp.exp(0.5 * (nn.softplus(logsnr) - nn.softplus(logsnr_s))))
            x_target = (z_teacher - stdv_frac * z) / (a_s - stdv_frac * a_t)
            x_target = jnp.where(bc(i == 0), x_pred, x_target)
            eps_target = predict_eps_from_x(z=z, x=x_target, logsnr=logsnr)

        else:  # denoise to original data
            x_target = x
            eps_target = eps

        # also get v-target
        v_target = predict_v_from_x_and_eps(x=x_target,
                                            eps=eps_target,
                                            logsnr=logsnr)

        # denoising loss
        model_output = self._run_model(z=z,
                                       logsnr=logsnr,
                                       model_fn=self.model_fn,
                                       clip_x=False)
        x_mse = utils.meanflat(jnp.square(model_output['model_x'] - x_target))
        eps_mse = utils.meanflat(
            jnp.square(model_output['model_eps'] - eps_target))
        v_mse = utils.meanflat(jnp.square(model_output['model_v'] - v_target))
        if mean_loss_weight_type == 'constant':  # constant weight on x_mse
            loss = x_mse
        elif mean_loss_weight_type == 'snr':  # SNR * x_mse = eps_mse
            loss = eps_mse
        elif mean_loss_weight_type == 'snr_trunc':  # x_mse * max(SNR, 1)
            loss = jnp.maximum(x_mse, eps_mse)
        elif mean_loss_weight_type == 'v_mse':
            loss = v_mse
        else:
            raise NotImplementedError(mean_loss_weight_type)
        return {'loss': loss}
Example #24
    initial_std=0.1,
    initial_scale=0.3,
    initial_mean=-0.8 + 0.1 * random.normal(key, (nfeatures, 1)),
    optimizer=optax.adafactor(1e-4))

w_lowrank = w_lowrank.squeeze()
cov_lowrank = b @ b.T + jnp.diag(c**2)

# *** Plotting surface predictive distribution ***
colors = ["black" if el else "white" for el in y]
key = random.PRNGKey(31415)
nsamples = 5000

# FFVB surface predictive distribution
ffvb_samples = random.multivariate_normal(key, w_ffvb, cov_ffvb, (nsamples, ))
Z_ffvb = nn.sigmoid(jnp.einsum("mij,sm->sij", Phispace, ffvb_samples))
Z_ffvb = Z_ffvb.mean(axis=0)

# Variational Bayes Low Rank surface predictive distribution
lowrank_samples = random.multivariate_normal(key, w_lowrank, cov_lowrank,
                                             (nsamples, ))

Z_lowrank = nn.sigmoid(jnp.einsum("mij,sm->sij", Phispace, lowrank_samples))
Z_lowrank = Z_lowrank.mean(axis=0)

fig_ffvb, ax = plt.subplots()
title = "FFVB  Predictive Distribution"
plot_posterior_predictive(ax, X, Xspace, Z_ffvb, title, colors)
pml.savefig('ffvb_predictive_distribution.pdf')
pml.savefig('ffvb_predictive_distribution.png')
Example #25
  def __call__(self, x, logsnr, y, *, train):
    B, H, W, _ = x.shape  # pylint: disable=invalid-name
    assert H == W
    assert x.dtype in (jnp.float32, jnp.float64)
    assert logsnr.shape == (B,) and logsnr.dtype in (jnp.float32, jnp.float64)
    num_resolutions = len(self.ch_mult)
    ch = self.ch
    emb_ch = self.emb_ch

    # Timestep embedding
    if self.logsnr_input_type == 'linear':
      logging.info('LogSNR representation: linear')
      logsnr_input = (logsnr - self.logsnr_scale_range[0]) / (
          self.logsnr_scale_range[1] - self.logsnr_scale_range[0])
    elif self.logsnr_input_type == 'sigmoid':
      logging.info('LogSNR representation: sigmoid')
      logsnr_input = nn.sigmoid(logsnr)
    elif self.logsnr_input_type == 'inv_cos':
      logging.info('LogSNR representation: inverse cosine')
      logsnr_input = (jnp.arctan(jnp.exp(-0.5 * jnp.clip(logsnr, -20., 20.)))
                      / (0.5 * jnp.pi))
    else:
      raise NotImplementedError(self.logsnr_input_type)

    emb = get_timestep_embedding(logsnr_input, embedding_dim=ch, max_time=1.)
    emb = nn.Dense(features=emb_ch, name='dense0')(emb)
    emb = nn.Dense(features=emb_ch, name='dense1')(nonlinearity(emb))
    assert emb.shape == (B, emb_ch)

    # Class embedding
    assert self.num_classes >= 1
    if self.num_classes > 1:
      logging.info('conditional: num_classes=%d', self.num_classes)
      assert y.shape == (B,) and y.dtype == jnp.int32
      y_emb = jax.nn.one_hot(y, num_classes=self.num_classes, dtype=x.dtype)
      y_emb = nn.Dense(features=emb_ch, name='class_emb')(y_emb)
      assert y_emb.shape == emb.shape == (B, emb_ch)
      emb += y_emb
    else:
      logging.info('unconditional: num_classes=%d', self.num_classes)
    del y

    # Downsampling
    hs = [nn.Conv(
        features=ch, kernel_size=(3, 3), strides=(1, 1), name='conv_in')(x)]
    for i_level in range(num_resolutions):
      # Residual blocks for this resolution
      for i_block in range(self.num_res_blocks):
        h = ResnetBlock(
            out_ch=ch * self.ch_mult[i_level],
            dropout=self.dropout,
            name=f'down_{i_level}.block_{i_block}')(
                hs[-1], emb=emb, deterministic=not train)
        if h.shape[1] in self.attn_resolutions:
          h = AttnBlock(
              num_heads=self.num_heads,
              head_dim=self.head_dim,
              name=f'down_{i_level}.attn_{i_block}')(h)
        hs.append(h)
      # Downsample
      if i_level != num_resolutions - 1:
        hs.append(self._downsample(
            hs[-1], name=f'down_{i_level}.downsample', emb=emb, train=train))

    # Middle
    h = hs[-1]
    h = ResnetBlock(dropout=self.dropout, name='mid.block_1')(
        h, emb=emb, deterministic=not train)
    h = AttnBlock(
        num_heads=self.num_heads, head_dim=self.head_dim, name='mid.attn_1')(h)
    h = ResnetBlock(dropout=self.dropout, name='mid.block_2')(
        h, emb=emb, deterministic=not train)

    # Upsampling
    for i_level in reversed(range(num_resolutions)):
      # Residual blocks for this resolution
      for i_block in range(self.num_res_blocks + 1):
        h = ResnetBlock(
            out_ch=ch * self.ch_mult[i_level],
            dropout=self.dropout,
            name=f'up_{i_level}.block_{i_block}')(
                jnp.concatenate([h, hs.pop()], axis=-1),
                emb=emb, deterministic=not train)
        if h.shape[1] in self.attn_resolutions:
          h = AttnBlock(
              num_heads=self.num_heads,
              head_dim=self.head_dim,
              name=f'up_{i_level}.attn_{i_block}')(h)
      # Upsample
      if i_level != 0:
        h = self._upsample(
            h, name=f'up_{i_level}.upsample', emb=emb, train=train)
    assert not hs

    # End
    h = nonlinearity(Normalize(name='norm_out')(h))
    h = nn.Conv(
        features=self.out_ch,
        kernel_size=(3, 3),
        strides=(1, 1),
        kernel_init=nn.initializers.zeros,
        name='conv_out')(h)
    assert h.shape == (*x.shape[:3], self.out_ch)
    return h
Example #26
 def __call__(self, x):
     R = self.R_module(hidden_dim=self.hidden_dim, output_dim=self.ode_dim)
     h = nn.Dense(self.ode_dim)(x)
     h = ContinuousBlock(R, self.n_step, self.basis, self.n_basis)(h)
     return nn.sigmoid(nn.Dense(1)(h))
Example #27
 def decode(self, z):
     z = self.decoder(z)
     x = nn.sigmoid(z)
     x = jnp.reshape(x, (x.shape[0], ) + self.input_shape)
     return x
Example #28
def loglikelihood_fn(params, Phi, y, predict_fn):
    an = predict_fn(params, Phi)
    log_an = nn.log_sigmoid(an)
    # log(1 - sigmoid(an)) == log_sigmoid(-an); the latter is numerically stabler
    log_likelihood_term = y * log_an + (1 - y) * nn.log_sigmoid(-an)
    return log_likelihood_term.sum()
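A quick standalone check of the identity behind the last line: log(1 - sigmoid(a)) equals log_sigmoid(-a), the numerically safer form:

import jax.numpy as jnp
from jax import nn

a = jnp.linspace(-5.0, 5.0, 11)
lhs = jnp.log(1.0 - nn.sigmoid(a))
rhs = nn.log_sigmoid(-a)
print(jnp.allclose(lhs, rhs, atol=1e-5))  # True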