def diffusion_forward(*, x, logsnr):
  """q(z_t | x)."""
  return {
      'mean': x * jnp.sqrt(nn.sigmoid(logsnr)),
      'std': jnp.sqrt(nn.sigmoid(-logsnr)),
      'var': nn.sigmoid(-logsnr),
      'logvar': nn.log_sigmoid(-logsnr)
  }
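# Hedged usage sketch (added, not from the original source): draw a sample
# z_t ~ q(z_t | x) from the moments returned by diffusion_forward above.
# Shapes and the logsnr value are placeholders; `nn` (flax.linen) is assumed
# to be imported at module level, as diffusion_forward requires.
import jax
import jax.numpy as jnp


def _example_sample_forward():
  x = jnp.zeros((8, 32, 32, 3))         # hypothetical clean images in [-1, 1]
  logsnr = jnp.full((8, 1, 1, 1), 2.0)  # log-SNR, broadcastable against x
  q_t = diffusion_forward(x=x, logsnr=logsnr)
  eps = jax.random.normal(jax.random.PRNGKey(0), x.shape)
  return q_t['mean'] + q_t['std'] * eps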
def ddim_step(self, i, z_t, num_steps, logsnr_schedule_fn, clip_x):
  shape, dtype = z_t.shape, z_t.dtype
  logsnr_t = logsnr_schedule_fn((i + 1.).astype(dtype) / num_steps)
  logsnr_s = logsnr_schedule_fn(i.astype(dtype) / num_steps)
  model_out = self._run_model(
      z=z_t,
      logsnr=jnp.full((shape[0],), logsnr_t),
      model_fn=self.model_fn,
      clip_x=clip_x)
  x_pred_t = model_out['model_x']
  eps_pred_t = model_out['model_eps']
  stdv_s = jnp.sqrt(nn.sigmoid(-logsnr_s))
  alpha_s = jnp.sqrt(nn.sigmoid(logsnr_s))
  z_s_pred = alpha_s * x_pred_t + stdv_s * eps_pred_t
  return jnp.where(i == 0, x_pred_t, z_s_pred)
def to_valid_rgb(t, decorrelated=False, sigmoid=True):
  """Transform inner dimension of t to valid rgb colors.

  In practice this consists of two parts:
  (1) If requested, transform the colors from a decorrelated color space to
      RGB.
  (2) Constrain the color channels to be in [0,1], either using a sigmoid
      function or clipping.

  Args:
    t: Input tensor, trailing dimension will be interpreted as colors and
      transformed/constrained.
    decorrelated: If True, the input tensor's colors are interpreted as coming
      from a whitened space.
    sigmoid: If True, the colors are constrained elementwise using sigmoid. If
      False, colors are constrained by clipping infinity norm.

  Returns:
    t with the innermost dimension transformed.
  """
  if decorrelated:
    t = _linear_correlate_color(t)
  if decorrelated and not sigmoid:
    t += color_mean
  if sigmoid:
    return nn.sigmoid(t)
  return constrain_l_inf(2 * t - 1) / 2 + 0.5
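# Hedged usage sketch (added, not from the original source): constrain an
# unconstrained image parameterization to valid RGB via the sigmoid path of
# to_valid_rgb above. The shape is a placeholder; the decorrelated path would
# additionally need _linear_correlate_color / color_mean from this module.
import jax.numpy as jnp


def _example_to_valid_rgb():
  t = jnp.zeros((1, 64, 64, 3))  # hypothetical unconstrained image tensor
  return to_valid_rgb(t, decorrelated=False, sigmoid=True)  # values in (0, 1)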
def __call__(self, image):
  x = self.backbone(image)
  x = self.feature_conv(x)
  B, H, W, EMB = x.shape  # batch, height, width, features
  col_embeds = jnp.repeat(self.col_embed[:W][jnp.newaxis, :, :], H,
                          0)  # H, W, embedding_size//2
  row_embeds = jnp.repeat(self.col_embed[:H][:, jnp.newaxis, :], W,
                          1)  # H, W, embedding_size//2
  positional_embeds = jnp.concatenate([col_embeds, row_embeds],
                                      -1)  # H, W, embedding_size
  positional_embeds_as_seq = jnp.reshape(
      positional_embeds, (1, H * W, EMB))  # 1, H*W, embedding_size
  image_tiles_as_seq = jnp.reshape(x, (B, H * W, -1))
  queries = jnp.repeat(self.query_pos[jnp.newaxis, :, :], B, 0)
  x = self.transformer(
      positional_embeds_as_seq + 0.1 * image_tiles_as_seq, queries)
  pred_logits = self.linear_class(x)
  pred_bbox = nn.sigmoid(self.linear_bbox(x))  # TODO maybe chuck an MLP on here
  return {'logits': pred_logits, 'pred_boxes': pred_bbox}
def render_samples(self,
                   level,
                   points,
                   z_vals,
                   directions,
                   viewdirs,
                   metadata,
                   warp_extra,
                   use_warp=True,
                   use_warp_jacobian=False,
                   metadata_encoded=False,
                   return_points=False,
                   return_weights=False):
  trunk_condition, alpha_condition, rgb_condition = (
      self.get_condition_inputs(viewdirs, metadata, metadata_encoded))

  out = {}
  if return_points:
    out['points'] = points

  # Apply the deformation field to the samples.
  if use_warp:
    metadata_channels = self.num_warp_features if metadata_encoded else 1
    warp_metadata = (
        metadata['time'] if self.warp_metadata_encoder_type == 'time' else
        metadata['warp'])
    warp_metadata = jnp.broadcast_to(
        warp_metadata[:, jnp.newaxis, :],
        shape=(*points.shape[:2], metadata_channels))
    warp_out = self.warp_field(
        points, warp_metadata, warp_extra, use_warp_jacobian, metadata_encoded)
    points = warp_out['warped_points']
    if 'jacobian' in warp_out:
      out['warp_jacobian'] = warp_out['jacobian']
    if return_points:
      out['warped_points'] = warp_out['warped_points']

  points_embed = self.point_encoder(points)

  raw = self.nerf_mlps[level](
      points_embed, trunk_condition, alpha_condition, rgb_condition)
  raw = model_utils.noise_regularize(
      self.make_rng(level), raw, self.noise_std, self.use_stratified_sampling)

  rgb = nn.sigmoid(raw['rgb'])
  sigma = self.sigma_activation(jnp.squeeze(raw['alpha'], axis=-1))

  out.update(model_utils.volumetric_rendering(
      rgb,
      sigma,
      z_vals,
      directions,
      return_weights=return_weights,
      use_white_background=self.use_white_background,
      sample_at_infinity=self.use_sample_at_infinity))
  return out
def volumetric_rendering(raw,
                         z_vals,
                         dirs,
                         use_white_background,
                         sigma_activation=nn.relu,
                         sample_at_infinity=True,
                         eps=1e-10):
  """Volumetric Rendering Function.

  Args:
    raw: jnp.ndarray(float32), [batch_size, num_coarse_samples, 4].
    z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
    dirs: jnp.ndarray(float32), [batch_size, 3].
    use_white_background: bool.
    sigma_activation: the activation function to apply to the sigma values.
    sample_at_infinity: if True adds a sample at infinity.
    eps: a small number to prevent numerical issues.

  Returns:
    rgb: jnp.ndarray(float32), [batch_size, 3].
    exp_depth: jnp.ndarray(float32), [batch_size], expected depth.
    med_depth: jnp.ndarray(float32), [batch_size], median depth.
    disp: jnp.ndarray(float32), [batch_size], disparity (inverse depth).
    acc: jnp.ndarray(float32), [batch_size], accumulated opacity.
    weights: jnp.ndarray(float32), [batch_size, num_coarse_samples].
  """
  rgb = nn.sigmoid(raw['rgb'])
  sigma = sigma_activation(jnp.squeeze(raw['alpha'], axis=-1))
  # TODO(keunhong): remove this hack.
  last_sample_z = 1e10 if sample_at_infinity else 1e-19
  dists = jnp.concatenate([
      z_vals[..., 1:] - z_vals[..., :-1],
      jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
  ], -1)
  dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1)
  alpha = 1.0 - jnp.exp(-sigma * dists)
  accum_prod = jnp.concatenate([
      jnp.full_like(alpha[..., :1], 1., alpha.dtype),
      jnp.cumprod(1.0 - alpha[..., :-1] + eps, axis=-1)
  ], axis=-1)
  weights = alpha * accum_prod

  rgb = (weights[..., None] * rgb).sum(axis=-2)
  exp_depth = (weights * z_vals).sum(axis=-1)
  med_depth = compute_depth_map(weights, z_vals)
  acc = weights.sum(axis=-1)
  if use_white_background:
    rgb = rgb + (1. - acc[..., None])

  inv_eps = 1.0 / eps
  disp = 1.0 / exp_depth
  disp = jnp.where((disp > 0) & (disp < inv_eps) & (acc > eps), disp, inv_eps)

  if sample_at_infinity:
    acc = weights[..., :-1].sum(axis=-1)

  return rgb, exp_depth, med_depth, disp, acc, weights
def diffusion_reverse(*, x, z_t, logsnr_s, logsnr_t, x_logvar):
  """q(z_s | z_t, x) (requires logsnr_s > logsnr_t (i.e. s < t))."""
  alpha_st = jnp.sqrt((1. + jnp.exp(-logsnr_t)) / (1. + jnp.exp(-logsnr_s)))
  alpha_s = jnp.sqrt(nn.sigmoid(logsnr_s))
  r = jnp.exp(logsnr_t - logsnr_s)  # SNR(t)/SNR(s)
  one_minus_r = -jnp.expm1(logsnr_t - logsnr_s)  # 1-SNR(t)/SNR(s)
  log_one_minus_r = utils.log1mexp(logsnr_s - logsnr_t)  # log(1-SNR(t)/SNR(s))

  mean = r * alpha_st * z_t + one_minus_r * alpha_s * x

  if isinstance(x_logvar, str):
    if x_logvar == 'small':
      # same as setting x_logvar to -infinity
      var = one_minus_r * nn.sigmoid(-logsnr_s)
      logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
    elif x_logvar == 'large':
      # same as setting x_logvar to nn.log_sigmoid(-logsnr_t)
      var = one_minus_r * nn.sigmoid(-logsnr_t)
      logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_t)
    elif x_logvar.startswith('medium:'):
      _, frac = x_logvar.split(':')
      frac = float(frac)
      logging.info('logvar frac=%f', frac)
      assert 0 <= frac <= 1
      min_logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
      max_logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_t)
      logvar = frac * max_logvar + (1 - frac) * min_logvar
      var = jnp.exp(logvar)
    else:
      raise NotImplementedError(x_logvar)
  else:
    assert isinstance(x_logvar, jnp.ndarray) or isinstance(
        x_logvar, onp.ndarray)
    assert x_logvar.shape == x.shape
    # start with "small" variance
    var = one_minus_r * nn.sigmoid(-logsnr_s)
    logvar = log_one_minus_r + nn.log_sigmoid(-logsnr_s)
    # extra variance weight is (one_minus_r*alpha_s)**2
    var += jnp.square(one_minus_r) * nn.sigmoid(logsnr_s) * jnp.exp(x_logvar)
    logvar = jnp.logaddexp(
        logvar, 2. * log_one_minus_r + nn.log_sigmoid(logsnr_s) + x_logvar)
  return {'mean': mean, 'std': jnp.sqrt(var), 'var': var, 'logvar': logvar}
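# Hedged usage sketch (added, not from the original source): one step of the
# forward-process posterior q(z_s | z_t, x) with the 'small' variance choice,
# reusing diffusion_forward above. logsnr_s > logsnr_t encodes s < t. Shapes
# and values are placeholders; `nn`, `utils.log1mexp` and `logging` are
# assumed to be available at module level, as diffusion_reverse requires.
import jax
import jax.numpy as jnp


def _example_reverse_step():
  x = jnp.zeros((4, 32, 32, 3))
  logsnr_t = jnp.full((4, 1, 1, 1), -1.0)
  logsnr_s = jnp.full((4, 1, 1, 1), 0.5)  # larger log-SNR, i.e. less noisy
  z_t = diffusion_forward(x=x, logsnr=logsnr_t)['mean']
  q_s = diffusion_reverse(
      x=x, z_t=z_t, logsnr_s=logsnr_s, logsnr_t=logsnr_t, x_logvar='small')
  eps = jax.random.normal(jax.random.PRNGKey(1), x.shape)
  return q_s['mean'] + q_s['std'] * eps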
def __call__(self, inputs):
  weights = self.param(
      'weights',
      self.kernel_init,  # Initialization function
      (inputs.shape[-1], 1))  # shape info.
  bias = self.param('bias', self.bias_init, (1,))
  logit = jnp.dot(inputs, weights) + bias
  return nn.sigmoid(logit)
def __call__(self, x: spec.Tensor, train: bool):
  del train
  input_size = 28 * 28
  num_hidden = 128
  num_classes = 10
  x = x.reshape((x.shape[0], input_size))  # Flatten.
  x = nn.Dense(features=num_hidden, use_bias=True)(x)
  x = nn.sigmoid(x)
  x = nn.Dense(features=num_classes, use_bias=True)(x)
  x = nn.log_softmax(x)
  return x
def get_joint(self, s_logits):
  """Returns a joint that agrees with s_logits when axis 0 is marginalized out."""
  value = jax.nn.softmax(s_logits)
  for layer in self.hidden_layers:
    value = nn.sigmoid(layer(value))
  log_joint = self.output_layer(value)
  # Fix the marginals for the output.
  log_joint = (
      log_joint -
      jax.scipy.special.logsumexp(log_joint, axis=0, keepdims=True) +
      s_logits[None, :])
  return log_joint
def __call__(self, inputs, aux=None):
  c = inputs.shape[-1]
  y = self.conv_module(c)(self.nonlinearity(inputs))
  if aux is not None:
    y = self.nonlinearity(y + ConvOneByOne(c)(self.nonlinearity(aux)))
  if self.dropout_p > 0:
    y = nn.Dropout(rate=self.dropout_p)(y)
  # Set init_scale=0.1 so that the res block is close to the identity at
  # initialization.
  a, b = jnp.split(self.conv_module(2 * c, init_scale=0.1)(y), 2, axis=-1)
  return inputs + a * nn.sigmoid(b)
def get_joint(self, s_logits):
  """Returns a joint that agrees with s_logits when axis 0 is marginalized out."""
  z_logits = self.get_prior()
  value = jax.nn.softmax(s_logits)
  for layer in self.hidden_layers:
    value = nn.sigmoid(layer(value))
  log_joint = self.output_layer(value)
  # Sinkhorn has nicer gradients...
  log_joint = fix_coupling_sinkhorn(log_joint, z_logits, s_logits,
                                    self.sinkhorn_iterations)
  # ... but rejection is harder to exploit inaccuracy for.
  log_joint = fix_coupling_rejection(log_joint, z_logits, s_logits)
  return log_joint
def __call__(self, z):
  shape_before_flattening, flatten_out_size = self.flatten_enc_shape()

  x = nn.Dense(flatten_out_size, name='fc1')(z)
  x = x.reshape((x.shape[0], *shape_before_flattening[1:]))

  hidden_dims = self.hidden_dims[::-1]
  # Build Decoder
  for h_dim in range(len(hidden_dims) - 1):
    x = nn.ConvTranspose(
        features=hidden_dims[h_dim], kernel_size=(3, 3), strides=(2, 2))(x)
    x = nn.GroupNorm()(x)
    x = nn.gelu(x)

  x = nn.ConvTranspose(features=3, kernel_size=(3, 3), strides=(2, 2))(x)
  x = nn.sigmoid(x)
  return x
def setUp(self):
  super(TrainProxyTest, self).setUp()
  np.random.seed(0)
  self.data_x = np.array(
      np.random.binomial(1, 0.5, size=(1000, 2)), dtype=float)
  self.data_y = np.random.binomial(
      1, p=nn.sigmoid(self.data_x[:, 0] - self.data_x[:, 1]))
  self.data_a = np.random.choice(3, size=self.data_y.shape)
  self.data_t = np.random.choice(3, size=self.data_a.shape)
  self.data = {
      'a': self.data_a,
      'm': self.data_x,
      'y': self.data_y,
      't': self.data_t,
  }
def __call__(self, inputs):
  in_ch = inputs.shape[-1]
  if self.se_ratio is None:
    if self.hidden_ch is None:
      raise ValueError('Must provide one of se_ratio or hidden_ch')
    hidden_ch = self.hidden_ch
  else:
    hidden_ch = max(1, int(in_ch * self.se_ratio))
  dense = partial(nn.Dense, use_bias=True, dtype=self.dtype)

  # Squeeze: global average pool over the spatial dimensions.
  x = jnp.mean(inputs, axis=(1, 2), dtype=self.dtype, keepdims=True)
  x = dense(features=hidden_ch)(x)
  x = self.activation_fn(x)
  x = dense(features=in_ch)(x)
  # Excite: gate the input channels with the learned sigmoid weights.
  output = nn.sigmoid(x) * inputs
  return output
def __call__(self, z, a, key):
  kernel_initializer = jax.nn.initializers.glorot_uniform()
  x = nn.Dense(
      features=self.layer_width,
      kernel_init=kernel_initializer)(jnp.concatenate([z, a]))
  x = nn.LayerNorm()(x)
  x = nn.relu(x)
  mu = nn.Dense(features=self.embedding_dim, kernel_init=kernel_initializer)(x)
  if self.probabilistic:
    sigma = nn.Dense(
        features=self.embedding_dim, kernel_init=kernel_initializer)(x)
    sigma = nn.sigmoid(sigma)
    sigma = self.min_sigma + (self.max_sigma - self.min_sigma) * sigma
    eps = jax.random.normal(key, shape=sigma.shape)
    sample = mu + sigma * eps
  else:
    sigma = jnp.zeros(self.embedding_dim)
    sample = mu
  return DynamicsModelType(mu, sigma, sample)
def unnormalized_sigmoid_mean_squared_error(logits, targets, weights=None):
  """Computes the sigmoid mean squared error per example.

  Args:
    logits: float array of shape (batch, output_shape).
    targets: float array of shape (batch, output_shape).
    weights: None or float array of shape (batch,).

  Returns:
    Sigmoid mean squared error computed per example, shape (batch,).
  """
  losses = jnp.square(nn.sigmoid(logits) - targets)
  if weights is not None:
    weights = conform_weights_to_targets(weights, targets)
    weighted_losses = losses * weights
  else:
    weighted_losses = losses

  return jnp.sum(weighted_losses.reshape(losses.shape[0], -1), axis=-1)
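# Hedged usage sketch (added, not from the original source): per-example
# sigmoid MSE for a tiny batch, with and without per-example weights. The
# weighted call assumes conform_weights_to_targets from this module broadcasts
# (batch,) weights against the targets.
import jax.numpy as jnp


def _example_sigmoid_mse():
  logits = jnp.array([[0.0, 2.0], [-1.0, 1.0]])
  targets = jnp.array([[0.5, 1.0], [0.0, 1.0]])
  per_example = unnormalized_sigmoid_mean_squared_error(logits, targets)
  weights = jnp.array([1.0, 0.0])  # mask out the second example
  masked = unnormalized_sigmoid_mean_squared_error(logits, targets, weights)
  return per_example, masked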
def _run_model(self, *, z, logsnr, model_fn, clip_x):
  model_output = model_fn(z, logsnr)
  if self.mean_type == 'eps':
    model_eps = model_output
  elif self.mean_type == 'x':
    model_x = model_output
  elif self.mean_type == 'v':
    model_v = model_output
  elif self.mean_type == 'both':
    _model_x, _model_eps = jnp.split(model_output, 2, axis=-1)  # pylint: disable=invalid-name
  else:
    raise NotImplementedError(self.mean_type)

  # get prediction of x at t=0
  if self.mean_type == 'both':
    # reconcile the two predictions
    model_x_eps = predict_x_from_eps(z=z, eps=_model_eps, logsnr=logsnr)
    wx = utils.broadcast_from_left(nn.sigmoid(-logsnr), z.shape)
    model_x = wx * _model_x + (1. - wx) * model_x_eps
  elif self.mean_type == 'eps':
    model_x = predict_x_from_eps(z=z, eps=model_eps, logsnr=logsnr)
  elif self.mean_type == 'v':
    model_x = predict_x_from_v(z=z, v=model_v, logsnr=logsnr)

  # clipping
  if clip_x:
    model_x = jnp.clip(model_x, -1., 1.)

  # get eps prediction if clipping or if mean_type != eps
  if self.mean_type != 'eps' or clip_x:
    model_eps = predict_eps_from_x(z=z, x=model_x, logsnr=logsnr)

  # get v prediction if clipping or if mean_type != v
  if self.mean_type != 'v' or clip_x:
    model_v = predict_v_from_x_and_eps(x=model_x, eps=model_eps, logsnr=logsnr)

  return {'model_x': model_x, 'model_eps': model_eps, 'model_v': model_v}
def __call__(self, inputs):
  """Passes the input through a squeeze and excite block.

  Arguments:
    inputs: [batch_size, height, width, dim]

  Returns:
    output: [batch_size, height, width, dim]
  """
  cfg = self.config
  out_dim = inputs.shape[-1]
  se_features = max(1, int(out_dim * cfg.se_ratio))
  dense = partial(
      nn.Dense,
      dtype=cfg.dtype,
      precision=cfg.precision,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init)
  y = jnp.mean(inputs, axis=(1, 2), dtype=cfg.dtype, keepdims=True)
  y = dense(features=se_features)(y)
  y = cfg.activation_fn(y)
  y = dense(features=out_dim)(y)
  y = nn.sigmoid(y) * inputs
  return y
def __call__(self, word):
  word_features = self.vocab_layer(word)
  word_features_act = nn.sigmoid(word_features)
  embed_features = self.embed_layer(word_features_act)
  embed_act = nn.softmax(embed_features)
  return embed_act
def raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, rng=None):
  """Transforms model's predictions to semantically meaningful values.

  Args:
    raw: (num_rays, num_samples || num_importance, 4) prediction from model
    z_vals: (num_rays, num_samples || num_importance) integration time
    rays_d: (num_rays, 3) direction of each ray
    raw_noise_std: std of noise added for regularization
    white_bkgd: whether to use the alpha channel for white background
    rng: random key

  Returns:
    acc_map: (num_rays) sum of weights along each ray
    depth_map: (num_rays) estimated distance to object
    disp_map: (num_rays) disparity map (inverse of depth map)
    rgb_map: (num_rays, 3) estimated RGB color of a ray
    weights: (num_rays, num_samples || num_importance) weights assigned to
      each sampled color
  """
  # compute 'distance' (in time) between each integration time along a ray
  dists = z_vals[..., 1:] - z_vals[..., :-1]

  # the 'distance' from the last integration time is infinity
  dists = jnp.concatenate(
      [dists, jnp.broadcast_to([1e10], dists[..., :1].shape)], axis=-1)
  dists = dists.astype(z_vals.dtype)  # [num_rays, num_samples]

  # multiply each distance by the norm of its corresponding direction ray
  # to convert to real world distance (accounts for non-unit directions)
  dists = dists * jnp.linalg.norm(rays_d[..., None, :], axis=-1)

  # extract RGB of each sample position along each ray
  # (this value is strictly between [0, 1])
  rgb = nn.sigmoid(raw[..., :3])  # [num_rays, num_samples, 3]

  # add noise to predictions for density, can be used to
  # regularize network during training (prevents floater artifacts)
  noise = 0.0
  if raw_noise_std > 0.0 and rng is not None:
    noise = random.normal(rng, raw[..., 3].shape) * raw_noise_std

  # predict density of each sample along each ray (alpha channel)
  # higher values imply higher likelihood of being absorbed at this point
  alpha = 1.0 - jnp.exp(-nn.relu(raw[..., 3] + noise) * dists)

  # compute weight for RGB of each sample along each ray
  # cumprod() is used to express the idea of the ray not having reflected
  # up to this sample yet
  # weights = alpha * tf.math.cumprod(1.0 - alpha + 1e-10, axis=-1, exclusive=True)
  alpha_ = jnp.clip(1.0 - alpha, 1e-5, 1.0)
  weights = jnp.concatenate(
      [jnp.ones_like(alpha_[..., :1]), alpha_[..., :-1]], -1)
  weights = alpha * jnp.cumprod(weights, -1)  # [num_rays, num_samples]

  # compute weighted color of each sample along each ray
  rgb_map = jnp.einsum("ij,ijk->ik", weights, rgb)  # [num_rays, 3]

  # estimated depth map is expected distance
  depth_map = jnp.einsum("ij,ij->i", weights, z_vals)  # [num_rays]

  # sum of weights along each ray (this value is in [0, 1] up to numerical error)
  acc_map = jnp.einsum("ij->i", weights)  # [num_rays]

  # disparity map is inverse depth
  i_depth = depth_map / jnp.clip(acc_map, 1e-5)
  disp_map = 1.0 / jnp.clip(i_depth, 1e-5)

  # to composite onto a white background, use the accumulated alpha map
  if white_bkgd:
    rgb_map += 1.0 - acc_map[..., None]

  return {
      "rgb": rgb_map.astype(jnp.float32),
      "disp": disp_map.astype(jnp.float32),
      "acc": acc_map.astype(jnp.float32),
      "depth": depth_map.astype(jnp.float32),
  }, weights
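# Hedged usage sketch (added, not from the original source): render a couple
# of random rays with raw2outputs above. Values are placeholders; `jnp`, `nn`
# and `random` (jax.random) are assumed to be imported at module level, as
# raw2outputs requires.
import jax.numpy as jnp
from jax import random


def _example_raw2outputs():
  num_rays, num_samples = 2, 8
  raw = random.normal(random.PRNGKey(0), (num_rays, num_samples, 4))
  z_vals = jnp.broadcast_to(
      jnp.linspace(2.0, 6.0, num_samples), (num_rays, num_samples))
  rays_d = jnp.ones((num_rays, 3))
  outputs, weights = raw2outputs(
      raw, z_vals, rays_d, raw_noise_std=0.0, white_bkgd=False)
  return outputs["rgb"], weights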
def generate(self, z):
  return nn.sigmoid(self.decoder(z))
def training_losses(self, *, x, rng, logsnr_schedule_fn, num_steps,
                    mean_loss_weight_type):
  assert x.dtype in [jnp.float32, jnp.float64]
  assert isinstance(num_steps, int)
  rng = utils.RngGen(rng)
  eps = jax.random.normal(next(rng), shape=x.shape, dtype=x.dtype)
  bc = lambda z: utils.broadcast_from_left(z, x.shape)

  # sample logsnr
  if num_steps > 0:
    logging.info('Discrete time training: num_steps=%d', num_steps)
    assert num_steps >= 1
    i = jax.random.randint(
        next(rng), shape=(x.shape[0],), minval=0, maxval=num_steps)
    u = (i + 1).astype(x.dtype) / num_steps
  else:
    logging.info('Continuous time training')
    # continuous time
    u = jax.random.uniform(next(rng), shape=(x.shape[0],), dtype=x.dtype)
  logsnr = logsnr_schedule_fn(u)
  assert logsnr.shape == (x.shape[0],)

  # sample z ~ q(z_logsnr | x)
  z_dist = diffusion_forward(x=x, logsnr=bc(logsnr))
  z = z_dist['mean'] + z_dist['std'] * eps

  # get denoising target
  if self.target_model_fn is not None:  # distillation
    assert num_steps >= 1

    # two forward steps of DDIM from z_t using teacher
    teach_out_start = self._run_model(
        z=z, logsnr=logsnr, model_fn=self.target_model_fn, clip_x=False)
    x_pred = teach_out_start['model_x']
    eps_pred = teach_out_start['model_eps']

    u_mid = u - 0.5 / num_steps
    logsnr_mid = logsnr_schedule_fn(u_mid)
    stdv_mid = bc(jnp.sqrt(nn.sigmoid(-logsnr_mid)))
    a_mid = bc(jnp.sqrt(nn.sigmoid(logsnr_mid)))
    z_mid = a_mid * x_pred + stdv_mid * eps_pred

    teach_out_mid = self._run_model(
        z=z_mid, logsnr=logsnr_mid, model_fn=self.target_model_fn,
        clip_x=False)
    x_pred = teach_out_mid['model_x']
    eps_pred = teach_out_mid['model_eps']

    u_s = u - 1. / num_steps
    logsnr_s = logsnr_schedule_fn(u_s)
    stdv_s = bc(jnp.sqrt(nn.sigmoid(-logsnr_s)))
    a_s = bc(jnp.sqrt(nn.sigmoid(logsnr_s)))
    z_teacher = a_s * x_pred + stdv_s * eps_pred

    # get x-target implied by z_teacher (!= x_pred)
    a_t = bc(jnp.sqrt(nn.sigmoid(logsnr)))
    stdv_frac = bc(
        jnp.exp(0.5 * (nn.softplus(logsnr) - nn.softplus(logsnr_s))))
    x_target = (z_teacher - stdv_frac * z) / (a_s - stdv_frac * a_t)
    x_target = jnp.where(bc(i == 0), x_pred, x_target)
    eps_target = predict_eps_from_x(z=z, x=x_target, logsnr=logsnr)

  else:  # denoise to original data
    x_target = x
    eps_target = eps

  # also get v-target
  v_target = predict_v_from_x_and_eps(
      x=x_target, eps=eps_target, logsnr=logsnr)

  # denoising loss
  model_output = self._run_model(
      z=z, logsnr=logsnr, model_fn=self.model_fn, clip_x=False)
  x_mse = utils.meanflat(jnp.square(model_output['model_x'] - x_target))
  eps_mse = utils.meanflat(jnp.square(model_output['model_eps'] - eps_target))
  v_mse = utils.meanflat(jnp.square(model_output['model_v'] - v_target))
  if mean_loss_weight_type == 'constant':  # constant weight on x_mse
    loss = x_mse
  elif mean_loss_weight_type == 'snr':  # SNR * x_mse = eps_mse
    loss = eps_mse
  elif mean_loss_weight_type == 'snr_trunc':  # x_mse * max(SNR, 1)
    loss = jnp.maximum(x_mse, eps_mse)
  elif mean_loss_weight_type == 'v_mse':
    loss = v_mse
  else:
    raise NotImplementedError(mean_loss_weight_type)
  return {'loss': loss}
    initial_std=0.1,
    initial_scale=0.3,
    initial_mean=-0.8 + 0.1 * random.normal(key, (nfeatures, 1)),
    optimizer=optax.adafactor(1e-4))
w_lowrank = w_lowrank.squeeze()
cov_lowrank = b @ b.T + jnp.diag(c**2)

# *** Plotting surface predictive distribution ***
colors = ["black" if el else "white" for el in y]
key = random.PRNGKey(31415)
nsamples = 5000

# FFVB surface predictive distribution
ffvb_samples = random.multivariate_normal(key, w_ffvb, cov_ffvb, (nsamples,))
Z_ffvb = nn.sigmoid(jnp.einsum("mij,sm->sij", Phispace, ffvb_samples))
Z_ffvb = Z_ffvb.mean(axis=0)

# Variational Bayes Low Rank surface predictive distribution
lowrank_samples = random.multivariate_normal(key, w_lowrank, cov_lowrank,
                                             (nsamples,))
Z_lowrank = nn.sigmoid(jnp.einsum("mij,sm->sij", Phispace, lowrank_samples))
Z_lowrank = Z_lowrank.mean(axis=0)

fig_ffvb, ax = plt.subplots()
title = "FFVB Predictive Distribution"
plot_posterior_predictive(ax, X, Xspace, Z_ffvb, title, colors)
pml.savefig('ffvb_predictive_distribution.pdf')
pml.savefig('ffvb_predictive_distribution.png')
def __call__(self, x, logsnr, y, *, train):
  B, H, W, _ = x.shape  # pylint: disable=invalid-name
  assert H == W
  assert x.dtype in (jnp.float32, jnp.float64)
  assert logsnr.shape == (B,) and logsnr.dtype in (jnp.float32, jnp.float64)
  num_resolutions = len(self.ch_mult)
  ch = self.ch
  emb_ch = self.emb_ch

  # Timestep embedding
  if self.logsnr_input_type == 'linear':
    logging.info('LogSNR representation: linear')
    logsnr_input = (logsnr - self.logsnr_scale_range[0]) / (
        self.logsnr_scale_range[1] - self.logsnr_scale_range[0])
  elif self.logsnr_input_type == 'sigmoid':
    logging.info('LogSNR representation: sigmoid')
    logsnr_input = nn.sigmoid(logsnr)
  elif self.logsnr_input_type == 'inv_cos':
    logging.info('LogSNR representation: inverse cosine')
    logsnr_input = (jnp.arctan(jnp.exp(-0.5 * jnp.clip(logsnr, -20., 20.)))
                    / (0.5 * jnp.pi))
  else:
    raise NotImplementedError(self.logsnr_input_type)

  emb = get_timestep_embedding(logsnr_input, embedding_dim=ch, max_time=1.)
  emb = nn.Dense(features=emb_ch, name='dense0')(emb)
  emb = nn.Dense(features=emb_ch, name='dense1')(nonlinearity(emb))
  assert emb.shape == (B, emb_ch)

  # Class embedding
  assert self.num_classes >= 1
  if self.num_classes > 1:
    logging.info('conditional: num_classes=%d', self.num_classes)
    assert y.shape == (B,) and y.dtype == jnp.int32
    y_emb = jax.nn.one_hot(y, num_classes=self.num_classes, dtype=x.dtype)
    y_emb = nn.Dense(features=emb_ch, name='class_emb')(y_emb)
    assert y_emb.shape == emb.shape == (B, emb_ch)
    emb += y_emb
  else:
    logging.info('unconditional: num_classes=%d', self.num_classes)
    del y

  # Downsampling
  hs = [nn.Conv(
      features=ch, kernel_size=(3, 3), strides=(1, 1), name='conv_in')(x)]
  for i_level in range(num_resolutions):
    # Residual blocks for this resolution
    for i_block in range(self.num_res_blocks):
      h = ResnetBlock(
          out_ch=ch * self.ch_mult[i_level],
          dropout=self.dropout,
          name=f'down_{i_level}.block_{i_block}')(
              hs[-1], emb=emb, deterministic=not train)
      if h.shape[1] in self.attn_resolutions:
        h = AttnBlock(
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            name=f'down_{i_level}.attn_{i_block}')(h)
      hs.append(h)
    # Downsample
    if i_level != num_resolutions - 1:
      hs.append(self._downsample(
          hs[-1], name=f'down_{i_level}.downsample', emb=emb, train=train))

  # Middle
  h = hs[-1]
  h = ResnetBlock(dropout=self.dropout, name='mid.block_1')(
      h, emb=emb, deterministic=not train)
  h = AttnBlock(
      num_heads=self.num_heads, head_dim=self.head_dim, name='mid.attn_1')(h)
  h = ResnetBlock(dropout=self.dropout, name='mid.block_2')(
      h, emb=emb, deterministic=not train)

  # Upsampling
  for i_level in reversed(range(num_resolutions)):
    # Residual blocks for this resolution
    for i_block in range(self.num_res_blocks + 1):
      h = ResnetBlock(
          out_ch=ch * self.ch_mult[i_level],
          dropout=self.dropout,
          name=f'up_{i_level}.block_{i_block}')(
              jnp.concatenate([h, hs.pop()], axis=-1),
              emb=emb, deterministic=not train)
      if h.shape[1] in self.attn_resolutions:
        h = AttnBlock(
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            name=f'up_{i_level}.attn_{i_block}')(h)
    # Upsample
    if i_level != 0:
      h = self._upsample(
          h, name=f'up_{i_level}.upsample', emb=emb, train=train)
  assert not hs

  # End
  h = nonlinearity(Normalize(name='norm_out')(h))
  h = nn.Conv(
      features=self.out_ch,
      kernel_size=(3, 3),
      strides=(1, 1),
      kernel_init=nn.initializers.zeros,
      name='conv_out')(h)
  assert h.shape == (*x.shape[:3], self.out_ch)
  return h
def __call__(self, x):
  R = self.R_module(hidden_dim=self.hidden_dim, output_dim=self.ode_dim)
  h = nn.Dense(self.ode_dim)(x)
  h = ContinuousBlock(R, self.n_step, self.basis, self.n_basis)(h)
  return nn.sigmoid(nn.Dense(1)(h))
def decode(self, z):
  z = self.decoder(z)
  x = nn.sigmoid(z)
  x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
  return x
def loglikelihood_fn(params, Phi, y, predict_fn):
  """Bernoulli log-likelihood of labels y given logits predict_fn(params, Phi)."""
  an = predict_fn(params, Phi)
  log_an = nn.log_sigmoid(an)
  # log(1 - sigmoid(an)) == log_sigmoid(-an), computed in a numerically
  # stable way.
  log_likelihood_term = y * log_an + (1 - y) * nn.log_sigmoid(-an)
  return log_likelihood_term.sum()
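# Hedged usage sketch (added, not from the original source): evaluate the
# Bernoulli log-likelihood above for a plain linear-logistic predictor. The
# design matrix, labels and parameters are placeholders, and linear_predict is
# a hypothetical predict_fn.
import jax.numpy as jnp


def _example_loglikelihood():
  Phi = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, -1.0]])  # bias column + feature
  y = jnp.array([1.0, 1.0, 0.0])
  params = jnp.array([0.1, -0.2])
  linear_predict = lambda w, features: features @ w
  return loglikelihood_fn(params, Phi, y, linear_predict)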