def statistics(net_params: List[jnp.ndarray], deq_params: List[jnp.ndarray],
               rng: random.PRNGKey):
    # Split pseudo-random number key.
    rng, rng_sample, rng_xobs, rng_kl = random.split(rng, 4)
    # Compute comparison statistics.
    _, xsph, _ = ode_forward(rng_sample, net_params, 10000, 4)
    xobs = rejection_sampling(rng_xobs, len(xsph), 4, embedded_sphere_density)
    mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
    cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
    # Forward KL(q||p) and effective sample size from importance weights.
    approx = importance_density(rng_kl, net_params, deq_params, 10000, xsph)
    target = embedded_sphere_density(xsph)
    w = target / approx
    Z = jnp.nanmean(w)
    log_approx = jnp.log(approx)
    log_target = jnp.log(target)
    klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
    ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
    ress = 100 * ess / len(w)
    del w, Z, log_approx, approx, log_target, target, xsph
    # Reverse KL(p||q) on samples from the target density.
    approx = importance_density(rng_kl, net_params, deq_params, 10000, xobs)
    log_approx = jnp.log(approx)
    target = embedded_sphere_density(xobs)
    w = approx / target
    Z = jnp.nanmean(w)
    log_target = jnp.log(target)
    klpq = jnp.nanmean(log_target - log_approx) + jnp.log(Z)
    del w, Z, log_approx, approx, log_target, target
    method = 'deqode ({})'.format('ELBO' if args.elbo_loss else 'KL')
    print(
        '{} - Mean MSE: {:.5f} - Covariance MSE: {:.5f} - '
        r'KL$(q\Vert p)$ = {:.5f} - KL$(p\Vert q)$ = {:.5f} - Rel. ESS: {:.2f}%'
        .format(method, mean_mse, cov_mse, klqp, klpq, ress))
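# A minimal self-contained sketch (toy weights, illustrative only) of the
# self-normalized importance-sampling diagnostics used above: with weights
# w = target / approx, Z = nanmean(w) estimates the normalizing constant,
# KL(q||p) is estimated as nanmean(log q - log p) + log(Z), and the relative
# effective sample size is 100 * (sum w)^2 / (n * sum w^2).
import jax.numpy as jnp

w = jnp.array([0.5, 1.0, 2.0, jnp.nan])  # a NaN weight is ignored by nan-ops
Z = jnp.nanmean(w)                       # ~1.1667
ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
ress = 100 * ess / len(w)                # ~58.3% (NaNs still count toward n)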
def update(i, opt_state, batch, acc_grad, nsamples, base_rng):
    params = get_params(opt_state)
    bsz = int(math.ceil(batch[0].shape[0] / acc_grad))
    first_batch = (batch[0][:bsz], batch[1][:bsz])
    rngs = jax.random.split(base_rng, nsamples)
    grads = loss_grad_fn(params, first_batch, rngs, args.kl_coef / train_size)
    # Per-parameter gradient noise statistics over the nsamples estimates.
    grad_std = tree_map(lambda bg: jnp.std(bg, 0), grads)
    avg_std = jnp.nanmean(ravel_pytree(grad_std)[0])
    grads = tree_map(lambda bg: jnp.mean(bg, 0), grads)
    grad_snr = tree_multimap(lambda m, sd: jnp.abs(m / sd), grads, grad_std)
    avg_snr = jnp.nanmean(ravel_pytree(grad_snr)[0])
    # Accumulate a running average of the gradients over the remaining
    # micro-batches; chunk j covers rows [j * bsz, (j + 1) * bsz). A separate
    # loop variable keeps the step index `i` intact for opt_update below.
    for j in range(1, acc_grad):
        batch_j = (batch[0][j * bsz:(j + 1) * bsz],
                   batch[1][j * bsz:(j + 1) * bsz])
        grads_j = loss_grad_fn(params, batch_j, rngs,
                               args.kl_coef / train_size)
        grads_j = tree_map(lambda bg: jnp.mean(bg, 0), grads_j)
        grads = tree_multimap(lambda g, g_new: (g * j + g_new) / (j + 1),
                              grads, grads_j)
    pre_update = get_params(opt_state)
    post_update = jit(opt_update)(i, grads, opt_state)
    assert jnp.not_equal(
        ravel_pytree(pre_update)[0],
        ravel_pytree(get_params(post_update))[0]).any()
    return post_update, avg_std, avg_snr
def get_centre(model):
    # Average the component centres, ignoring components that are absent.
    centre_mux = jnp.nanmean(
        jnp.array([
            model.get(k, {}).get('mux', jnp.nan)
            for k in ('bulge', 'bar')
            if model.get(k, None) is not None
        ]))
    centre_muy = jnp.nanmean(
        jnp.array([
            model.get(k, {}).get('muy', jnp.nan)
            for k in ('bulge', 'bar')
            if model.get(k, None) is not None
        ]))
    return dict(mux=float(centre_mux), muy=float(centre_muy))
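# Illustrative usage of get_centre (hypothetical model dict): the centre is
# the nanmean of the available component centres, so a component missing a
# 'mux'/'muy' entry simply drops out of the average.
import jax.numpy as jnp

model = {'bulge': {'mux': 10.0, 'muy': 12.0},
         'bar': {'mux': 14.0, 'muy': jnp.nan}}
get_centre(model)  # {'mux': 12.0, 'muy': 12.0}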
def main():
    # Set pseudo-random number generator keys.
    rng = random.PRNGKey(args.seed)
    rng, rng_net = random.split(rng, 2)
    rng, rng_sample, rng_xobs, rng_basis = random.split(rng, 4)
    rng, rng_fwd, rng_rev = random.split(rng, 3)
    rng, rng_kl = random.split(rng, 2)
    # Initialize the parameters of the ambient vector field network.
    _, params = net_init(rng_net, (-1, 4))
    opt_state = opt_init(params)
    for it in range(args.num_steps):
        opt_state, kl = step(opt_state, it, args.num_samples)
        print('iter.: {} - kl: {:.4f}'.format(it, kl))
    params = get_params(opt_state)
    count = lambda x: jnp.prod(jnp.array(x.shape))
    num_params = jnp.array(
        tree_util.tree_map(count, tree_util.tree_flatten(params)[0])).sum()
    print('number of parameters: {}'.format(num_params))
    # Compute comparison statistics.
    xsph, log_approx = manifold_ode_log_prob(params, rng_sample, 10000)
    xobs = rejection_sampling(rng_xobs, len(xsph), 3, embedded_sphere_density)
    mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
    cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
    # Forward KL(q||p) and effective sample size from importance weights.
    approx = jnp.exp(log_approx)
    target = embedded_sphere_density(xsph)
    w = target / approx
    Z = jnp.nanmean(w)
    log_approx = jnp.log(approx)
    log_target = jnp.log(target)
    klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
    ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
    ress = 100 * ess / len(w)
    del w, Z, log_approx, approx, log_target, target
    # Reverse KL(p||q) on samples from the target density.
    log_approx = manifold_reverse_ode_log_prob(params, rng_kl, xobs)
    approx = jnp.exp(log_approx)
    target = embedded_sphere_density(xobs)
    w = approx / target
    Z = jnp.nanmean(w)
    log_target = jnp.log(target)
    klpq = jnp.nanmean(log_target - log_approx) + jnp.log(Z)
    del w, Z, log_approx, approx, log_target, target
    print(
        'manode - Mean MSE: {:.5f} - Covariance MSE: {:.5f} - '
        r'KL$(q\Vert p)$ = {:.5f} - KL$(p\Vert q)$ = {:.5f} - Rel. ESS: {:.2f}%'
        .format(mean_mse, cov_mse, klqp, klpq, ress))
def negative_log_predictive_density(self, X, Y, R=None):
    predict_mean, predict_var = self.predict(X, R)
    if predict_mean.ndim > 1:  # multi-latent case
        pred_mean, pred_var, Y = (predict_mean[..., None],
                                  diag(predict_var),
                                  Y.reshape(-1, 1))
    else:
        pred_mean, pred_var, Y = (predict_mean.reshape(-1, 1, 1),
                                  predict_var.reshape(-1, 1, 1),
                                  Y.reshape(-1, 1))
    log_density = vmap(self.likelihood.log_density)(Y, pred_mean, pred_var)
    return -np.nanmean(log_density)
def bayesian_regression(x: np.ndarray, y: Optional[np.ndarray] = None) -> None:
    batch, x_dim = jnp.shape(x)
    theta = numpyro.sample(
        "theta", dist.Normal(jnp.zeros(x_dim), jnp.ones(x_dim) * 100))
    sigma = numpyro.sample("sigma", dist.Gamma(1.0, 1.0))
    # Empirical prior over the covariate means: centred at the column-wise
    # nanmean, with the column-wise nanstd as the (positive) scale.
    x_mu = numpyro.sample(
        "x_mu", dist.Normal(jnp.nanmean(x, axis=0), jnp.nanstd(x, axis=0)))
    x_std = numpyro.sample("x_std", dist.Gamma(1.0, 1.0))
    with numpyro.plate("batch", batch, dim=-2):
        # Observe the non-missing entries and impute the missing ones.
        mask = ~np.isnan(x)
        numpyro.sample("x", dist.Normal(x_mu, x_std).mask(mask), obs=x)
        index = (~mask).astype(int).nonzero()
        x_sample = numpyro.sample("x_sample", dist.Normal(x_mu, x_std))
        x_filled = ops.index_update(x, index, x_sample[index])
        numpyro.sample("y", dist.Normal(jnp.matmul(x_filled, theta), sigma),
                       obs=y)
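# A minimal sketch of running the model above with NUTS; `x_train` (which may
# contain NaNs to be imputed) and `y_train` are placeholder arrays.
from jax import random
from numpyro.infer import MCMC, NUTS

mcmc = MCMC(NUTS(bayesian_regression), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), x_train, y_train)
posterior = mcmc.get_samples()  # dict with 'theta', 'sigma', 'x_mu', ...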
def _fill_nans(numerical_points, numerical_clusters, categorical_points,
               categorical_clusters):
    # Replace NaN numerical entries with the per-column mean of the cluster
    # points, and FILL_VAL categorical entries with the per-column mode.
    filled_num_nan = jnp.isnan(numerical_points) * jnp.nanmean(
        numerical_clusters, axis=0)
    numerical_points = jnp.nan_to_num(numerical_points) + filled_num_nan
    filled_cat_nan = (categorical_points == FILL_VAL) * mode(
        categorical_clusters, axis=0)
    categorical_points = jnp.where(
        categorical_points == FILL_VAL,
        filled_cat_nan,
        categorical_points,
    )
    return numerical_points, categorical_points
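# Toy illustration (hypothetical data) of the numerical fill path above: NaN
# entries are replaced by the per-column nanmean of the cluster points.
import jax.numpy as jnp

pts = jnp.array([[1.0, jnp.nan], [3.0, 4.0]])
clusters = jnp.array([[0.0, 2.0], [2.0, 6.0]])
jnp.nan_to_num(pts) + jnp.isnan(pts) * jnp.nanmean(clusters, axis=0)
# -> [[1., 4.], [3., 4.]]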
def cost(self, p, extra=None, precomputed=None):
    """Mean squared error, with optional regularization penalties."""
    y = self.y if extra is None else extra['y']
    yhat = self.forward_pass(p, extra) if precomputed is None else precomputed
    mse = np.nanmean((y - yhat)**2)
    # Elastic-net penalty on the coefficients (applied at train time only).
    if self.beta and extra is None:
        l1 = np.linalg.norm(p['b'], 1)
        l2 = np.linalg.norm(p['b'], 2)
        mse += self.beta * ((1 - self.alpha) * l2 + self.alpha * l1)
    # Quadratic penalty from a Gaussian prior with precision Cinv, if set.
    if hasattr(self, 'Cinv'):
        mse += 0.5 * p['b'] @ self.Cinv @ p['b']
    return mse
# Sample from the learned distribution.
num_samples = 100000
num_dims = 2
xamb = random.normal(rng_xamb, [num_samples, num_dims])
xamb = forward(bij_params, bij_fns, xamb)
xtor = jnp.mod(xamb, 2.0 * jnp.pi)
lp = induced_torus_log_density(bij_params, bij_fns, xtor)
xobs = rejection_sampling(rng_xobs, len(xtor), torus_density, args.beta)

# Compute comparison statistics.
mean_mse = jnp.square(jnp.linalg.norm(xtor.mean(0) - xobs.mean(0)))
cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xtor.T) - jnp.cov(xobs.T)))
approx = jnp.exp(lp)
target = torus_density(xtor)
w = target / approx
Z = jnp.nanmean(w)
log_approx = jnp.log(approx)
log_target = jnp.log(target)
klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
ress = 100 * ess / len(w)
del w, Z, log_approx, log_target
log_approx = induced_torus_log_density(bij_params, bij_fns, xobs)
approx = jnp.exp(log_approx)
target = torus_density(xobs)
log_target = jnp.log(target)
w = approx / target
Z = jnp.mean(w)
klpq = jnp.mean(log_target - log_approx) + jnp.log(Z)
del w, Z, log_approx, log_target
print(
    'Mean MSE: {:.5f} - Covariance MSE: {:.5f} - '
    r'KL$(q\Vert p)$ = {:.5f} - KL$(p\Vert q)$ = {:.5f} - Rel. ESS: {:.2f}%'
    .format(mean_mse, cov_mse, klqp, klpq, ress))
def _parallel_train_step(
    optimizer,
    batched_examples,
    static_batch_metadata,
    loss_fn,
    max_global_norm=None,
    **optimizer_hyper_params,
):
  """Train the model for one step in parallel across devices.

  Args:
    optimizer: Optimizer that tracks the model and parameter state. Should be
      replicated to each device, i.e. should contain ShardedDeviceArrays with
      a leading axis (num_devices, ...) but with the same content on each
      device.
    batched_examples: A structure of NDArrays representing a batch of
      examples. Should have two leading batch dimensions:
      (num_devices, batch_size_per_device, ...)
    static_batch_metadata: Metadata about this batch, which will be shared
      across all batched examples. Each value of this results in a separate
      XLA-compiled module.
    loss_fn: Task-specific non-batched loss function to apply. Should take the
      current model (optimizer.target) and an example from batched_examples,
      and return a tuple of the current loss (as a scalar) and a dictionary
      from string names to metric values (also scalars, or RatioMetrics).
    max_global_norm: Maximum global norm to clip gradients to. Should be a
      scalar, which will be broadcast automatically.
    **optimizer_hyper_params: Hyperparameters to pass to the optimizer's
      `apply_gradient` function, which will be broadcast across devices
      automatically.

  Returns:
    Tuple (updated_optimizer, grads_ok, metrics, grads). Metrics will be as
    returned by loss_fn, with an extra element "loss"; all metrics will be
    averaged across all elements of the batch. grads_ok will be a replicated
    bool ndarray that is True if the gradients were finite, and grads are the
    aggregated gradients that were applied. Both optimizer and metrics will
    contain ShardedDeviceArrays that are identical across devices.
  """

  def batched_loss_fn(model):
    """Apply loss function across a batch of examples."""
    loss, metrics = jax.vmap(loss_fn, (None, 0, None))(model, batched_examples,
                                                       static_batch_metadata)
    return jnp.mean(loss), metrics

  # Compute gradients of loss, along with metrics.
  (loss, metrics), grads = jax.value_and_grad(
      batched_loss_fn, has_aux=True)(optimizer.target)
  metrics["loss"] = loss

  # Exchange average gradients and metrics across devices.
  agg_grads = jax.lax.pmean(grads, "devices")
  agg_metrics = {}
  for k, v in metrics.items():
    if isinstance(v, RatioMetric):
      num = jax.lax.psum(jnp.sum(v.numerator), "devices")
      denom = jax.lax.psum(jnp.sum(v.denominator), "devices")
      new_value = num / denom
    else:
      # Use nanmean to aggregate bare floats.
      new_value = jnp.nanmean(jax.lax.all_gather(v, "devices"))
    agg_metrics[k] = new_value

  # Compute global norm and possibly clip.
  global_norm = optax.global_norm(agg_grads)
  agg_metrics["gradient_global_norm"] = global_norm
  if max_global_norm is not None:
    should_clip = global_norm > max_global_norm
    agg_grads = jax.tree_map(
        lambda g: jnp.where(should_clip, g * max_global_norm / global_norm, g),
        agg_grads)
    agg_metrics["gradient_was_clipped"] = should_clip.astype("float32")

  # Check for non-finite gradients.
  grads_ok = jnp.all(
      jnp.stack(
          [jnp.all(jnp.isfinite(x)) for x in jax.tree_leaves(agg_grads)]))

  # Apply updates.
  updated_optimizer = optimizer.apply_gradient(agg_grads,
                                               **optimizer_hyper_params)

  return updated_optimizer, grads_ok, agg_metrics, agg_grads
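# A minimal sketch of how a step like this is typically driven: the
# lax.pmean/psum/all_gather collectives above only work under jax.pmap with a
# matching axis_name. `my_loss_fn` is a placeholder for the task loss, and
# static_batch_metadata must be hashable since it is marked static.
import functools
import jax

p_train_step = jax.pmap(
    functools.partial(_parallel_train_step, loss_fn=my_loss_fn),
    axis_name="devices",
    static_broadcasted_argnums=(2,),  # static_batch_metadata
)
optimizer, grads_ok, metrics, grads = p_train_step(
    optimizer, batched_examples, static_batch_metadata)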
def process_batch(*,
                  batch: Dict[str, jnp.ndarray],
                  rng: types.PRNGKey,
                  state: model_utils.TrainState,
                  tag: str,
                  item_id: str,
                  step: int,
                  summary_writer: tensorboard.SummaryWriter,
                  render_fn: Any,
                  save_dir: Optional[gpath.GPath],
                  datasource: datasets.DataSource):
  """Process and plot a single batch."""
  item_id = item_id.replace('/', '_')
  render = render_fn(state, batch, rng=rng)
  out = {}
  if jax.process_index() != 0:
    return out

  rgb = render['rgb']
  acc = render['acc']
  depth_exp = render['depth']
  depth_med = render['med_depth']
  colorize_depth = functools.partial(viz.colorize,
                                     cmin=datasource.near,
                                     cmax=datasource.far,
                                     invert=True)
  depth_exp_viz = colorize_depth(depth_exp)
  depth_med_viz = colorize_depth(depth_med)
  disp_exp_viz = viz.colorize(1.0 / depth_exp)
  disp_med_viz = viz.colorize(1.0 / depth_med)
  acc_viz = viz.colorize(acc, cmin=0.0, cmax=1.0)
  if save_dir:
    save_dir.mkdir(parents=True, exist_ok=True)
    image_utils.save_image(save_dir / f'rgb_{item_id}.png',
                           image_utils.image_to_uint8(rgb))
    image_utils.save_image(save_dir / f'depth_expected_viz_{item_id}.png',
                           image_utils.image_to_uint8(depth_exp_viz))
    image_utils.save_depth(save_dir / f'depth_expected_{item_id}.png',
                           depth_exp)
    image_utils.save_image(save_dir / f'depth_median_viz_{item_id}.png',
                           image_utils.image_to_uint8(depth_med_viz))
    image_utils.save_depth(save_dir / f'depth_median_{item_id}.png',
                           depth_med)

  summary_writer.image(f'rgb/{tag}/{item_id}', rgb, step)
  summary_writer.image(f'depth-expected/{tag}/{item_id}', depth_exp_viz, step)
  summary_writer.image(f'depth-median/{tag}/{item_id}', depth_med_viz, step)
  summary_writer.image(f'disparity-expected/{tag}/{item_id}', disp_exp_viz,
                       step)
  summary_writer.image(f'disparity-median/{tag}/{item_id}', disp_med_viz,
                       step)
  summary_writer.image(f'acc/{tag}/{item_id}', acc_viz, step)

  if 'rgb' in batch:
    rgb_target = batch['rgb']
    mse = ((rgb - rgb_target)**2).mean()
    psnr = utils.compute_psnr(mse)
    ssim = compute_multiscale_ssim(rgb_target, rgb)
    out['mse'] = mse
    out['psnr'] = psnr
    out['ssim'] = ssim
    logging.info('\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f',
                 mse, psnr, ssim)
    rgb_abs_error = viz.colorize(
        abs(rgb_target - rgb).sum(axis=-1), cmin=0, cmax=1)
    rgb_sq_error = viz.colorize(
        ((rgb_target - rgb)**2).sum(axis=-1), cmin=0, cmax=1)
    summary_writer.image(f'rgb-target/{tag}/{item_id}', rgb_target, step)
    summary_writer.image(f'rgb-abs-error/{tag}/{item_id}', rgb_abs_error, step)
    summary_writer.image(f'rgb-sq-error/{tag}/{item_id}', rgb_sq_error, step)

  if 'depth' in batch:
    depth_target = batch['depth']
    depth_target_viz = colorize_depth(depth_target[..., 0])
    out['depth_abs'] = jnp.nanmean(jnp.abs(depth_target - depth_med))
    summary_writer.image(f'depth-target/{tag}/{item_id}', depth_target_viz,
                         step)
    depth_med_error = viz.colorize(
        abs(depth_target - depth_med).squeeze(axis=-1), cmin=0, cmax=1)
    summary_writer.image(f'depth-median-error/{tag}/{item_id}',
                         depth_med_error, step)
    depth_exp_error = viz.colorize(
        abs(depth_target - depth_exp).squeeze(axis=-1), cmin=0, cmax=1)
    summary_writer.image(f'depth-expected-error/{tag}/{item_id}',
                         depth_exp_error, step)

  return out
def nanmean(a, axis=None, dtype=None, keepdims=None):
    if isinstance(a, JaxArray):
        a = a.value
    r = jnp.nanmean(a, axis=axis, dtype=dtype, keepdims=keepdims)
    # Full reductions return a bare scalar; axis reductions are re-wrapped.
    return r if axis is None else JaxArray(r)
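# Illustrative usage (JaxArray is assumed to be the surrounding library's
# array wrapper): plain jnp arrays pass straight through the isinstance check,
# a full reduction returns a bare scalar, and axis reductions come back
# wrapped.
import jax.numpy as jnp

x = jnp.array([[1.0, jnp.nan], [3.0, 4.0]])
nanmean(x)          # ~2.6667, bare scalar (not wrapped)
nanmean(x, axis=0)  # JaxArray([2., 4.])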
def process_batch(
    *,
    batch: Dict[str, jnp.ndarray],
    rng: types.PRNGKey,
    state: model_utils.TrainState,
    tag: str,
    item_id: str,
    step: int,
    summary_writer: tensorboard.SummaryWriter,
    render_fn: Any,
    save_dir: Optional[gpath.GPath],
    datasource: datasets.DataSource,
):
    """Process and plot a single batch."""
    rgb, depth_exp, depth_med, acc = render_fn(state, batch, rng=rng)
    out = {}
    if jax.host_id() != 0:
        return out

    colorize_depth = functools.partial(viz.colorize,
                                       cmin=datasource.near,
                                       cmax=datasource.far,
                                       invert=True)
    depth_exp_viz = colorize_depth(depth_exp[..., 0])
    depth_med_viz = colorize_depth(depth_med[..., 0])
    if save_dir:
        save_dir.mkdir(parents=True, exist_ok=True)
        image_utils.save_image(save_dir / f"rgb_{item_id}.png",
                               image_utils.image_to_uint8(rgb))
        image_utils.save_image(
            save_dir / f"depth_expected_viz_{item_id}.png",
            image_utils.image_to_uint8(depth_exp_viz),
        )
        image_utils.save_depth(save_dir / f"depth_expected_{item_id}.png",
                               depth_exp[..., 0])
        image_utils.save_image(
            save_dir / f"depth_median_viz_{item_id}.png",
            image_utils.image_to_uint8(depth_med_viz),
        )
        image_utils.save_depth(save_dir / f"depth_median_{item_id}.png",
                               depth_med[..., 0])

    summary_writer.image(f"rgb/{tag}/{item_id}", rgb, step)
    summary_writer.image(f"depth-expected/{tag}/{item_id}", depth_exp_viz,
                         step)
    summary_writer.image(f"depth-median/{tag}/{item_id}", depth_med_viz, step)
    summary_writer.image(f"acc/{tag}/{item_id}", acc, step)

    if "rgb" in batch:
        rgb_target = batch["rgb"]
        mse = ((rgb - rgb_target)**2).mean()
        psnr = utils.compute_psnr(mse)
        ssim = compute_multiscale_ssim(rgb_target, rgb)
        out["mse"] = mse
        out["psnr"] = psnr
        out["ssim"] = ssim
        logging.info("\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f",
                     mse, psnr, ssim)
        rgb_abs_error = viz.colorize(abs(rgb_target - rgb).sum(axis=-1),
                                     cmin=0, cmax=1)
        rgb_sq_error = viz.colorize(((rgb_target - rgb)**2).sum(axis=-1),
                                    cmin=0, cmax=1)
        summary_writer.image(f"rgb-target/{tag}/{item_id}", rgb_target, step)
        summary_writer.image(f"rgb-abs-error/{tag}/{item_id}", rgb_abs_error,
                             step)
        summary_writer.image(f"rgb-sq-error/{tag}/{item_id}", rgb_sq_error,
                             step)

    if "depth" in batch:
        depth_target = batch["depth"]
        depth_target_viz = colorize_depth(depth_target[..., 0])
        out["depth_abs"] = jnp.nanmean(jnp.abs(depth_target - depth_med))
        summary_writer.image(f"depth-target/{tag}/{item_id}", depth_target_viz,
                             step)
        depth_med_error = viz.colorize(
            abs(depth_target - depth_med).squeeze(axis=-1), cmin=0, cmax=1)
        summary_writer.image(f"depth-median-error/{tag}/{item_id}",
                             depth_med_error, step)
        depth_exp_error = viz.colorize(
            abs(depth_target - depth_exp).squeeze(axis=-1), cmin=0, cmax=1)
        summary_writer.image(f"depth-expected-error/{tag}/{item_id}",
                             depth_exp_error, step)
        rel_disp_pred = viz.colorize(1.0 / depth_exp[..., 0])
        summary_writer.image(f"relative-disparity/{tag}/{item_id}",
                             rel_disp_pred, step)

    return out
def robust_whiten(x):
    # Centre on the NaN-ignoring median and scale by the mean absolute
    # deviation about the median (a mean-based variant of the MAD).
    median = jnp.nanmedian(x)
    mad = jnp.nanmean(jnp.abs(x - median))
    return (x - median) / mad
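# Illustrative behaviour (toy data): the centre is the NaN-ignoring median,
# so the outlier does not shift it, and NaN entries pass through as NaN.
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, jnp.nan, 100.0])
robust_whiten(x)  # centred on 2.5, scaled by the mean |x - median| of 25.0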
# Train dequantization networks.
(bij_params, deq_params), trace = train(rng_train, bij_params, bij_fns,
                                        deq_params, deq_fn, args.num_steps,
                                        args.lr)

# Compute an estimate of the KL divergence.
num_is = 150
_, xon = ambient.sample(rng_mse, 10000, bij_params, bij_fns, num_dims)
log_approx = importance_log_density(rng_kl, bij_params, bij_fns, deq_params,
                                    deq_fn, num_is, xon)
log_approx = jnp.clip(log_approx, -10., 10.)
log_target = log_density(xon)
approx, target = jnp.exp(log_approx), jnp.exp(log_target)
w = jnp.exp(log_target - log_approx)
Z = jnp.nanmean(w)
logZ = jspsp.logsumexp(log_target - log_approx) - jnp.log(len(xon))
klqp = jnp.nanmean(log_approx - log_target) + logZ
ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
ress = 100 * ess / len(w)
del w, Z, logZ, log_approx, approx, log_target, target

xobs = xobs[:1000]
log_approx = importance_log_density(rng_kl, bij_params, bij_fns, deq_params,
                                    deq_fn, num_is, xobs)
log_approx = jnp.clip(log_approx, -10., 10.)
log_target = log_density(xobs)
approx, target = jnp.exp(log_approx), jnp.exp(log_target)
logZ = jspsp.logsumexp(log_approx - log_target) - jnp.log(len(xobs))
klpq = jnp.nanmean(log_target - log_approx) + logZ
del logZ, log_approx, approx, log_target, target
mean_mse = jnp.square(jnp.linalg.norm(xon.mean(0) - xobs.mean(0)))
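# Illustrative check of the log-space normalizer used above:
# logsumexp(log_w) - log(n) equals log(mean(w)) but never materializes
# exp(log_w), so it stays finite where the plain mean of the weights would
# underflow to zero.
log_w = jnp.array([-800.0, -802.0, -799.0])
logZ = jspsp.logsumexp(log_w) - jnp.log(len(log_w))  # ~ -799.75, finite
# jnp.log(jnp.nanmean(jnp.exp(log_w))) would return -inf here.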