Example #1
def statistics(net_params: List[jnp.ndarray], deq_params: List[jnp.ndarray],
               rng: random.PRNGKey):
    # Split pseudo-random number key.
    rng, rng_sample, rng_xobs, rng_kl = random.split(rng, 4)
    # Compute comparison statistics.
    _, xsph, _ = ode_forward(rng_sample, net_params, 10000, 4)
    xobs = rejection_sampling(rng_xobs, len(xsph), 4, embedded_sphere_density)
    mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
    cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
    approx = importance_density(rng_kl, net_params, deq_params, 10000, xsph)
    target = embedded_sphere_density(xsph)
    w = target / approx
    Z = jnp.nanmean(w)
    log_approx = jnp.log(approx)
    log_target = jnp.log(target)
    klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
    ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
    ress = 100 * ess / len(w)
    del w, Z, log_approx, approx, log_target, target, xsph
    approx = importance_density(rng_kl, net_params, deq_params, 10000, xobs)
    log_approx = jnp.log(approx)
    target = embedded_sphere_density(xobs)
    w = approx / target
    Z = jnp.nanmean(w)
    log_target = jnp.log(target)
    klpq = jnp.nanmean(log_target - log_approx) + jnp.log(Z)
    del w, Z, log_approx, approx, log_target, target
    method = 'deqode ({})'.format('ELBO' if args.elbo_loss else 'KL')
    print(
        r'{} - Mean MSE: {:.5f} - Covariance MSE: {:.5f} - KL$(q\Vert p)$ = {:.5f} - KL$(p\Vert q)$ = {:.5f} - Rel. ESS: {:.2f}%'
        .format(method, mean_mse, cov_mse, klqp, klpq, ress))
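The pattern in this example recurs throughout the listing: draw samples from the learned distribution q, form importance weights w = p/q against the target density, and estimate both the KL divergence and the effective sample size with NaN-robust reductions. A minimal sketch of that estimator, with hypothetical `log_q` and `log_p` arrays standing in for the log-densities evaluated at samples from q:

import jax.numpy as jnp

def snis_kl_and_ress(log_q, log_p):
    # Unnormalized importance weights w = p(x) / q(x) for x ~ q.
    w = jnp.exp(log_p - log_q)
    # Self-normalization: E_q[p/q] estimates the (possibly unknown) constant Z.
    log_Z = jnp.log(jnp.nanmean(w))
    # KL(q || p) corrected by the normalizer, matching `klqp` above.
    klqp = jnp.nanmean(log_q - log_p) + log_Z
    # Effective sample size as a percentage, matching `ress` above.
    ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
    return klqp, 100.0 * ess / len(w)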
Example #2
    def update(i, opt_state, batch, acc_grad, nsamples, base_rng):
        params = get_params(opt_state)
        bsz = int(math.ceil(batch[0].shape[0] / acc_grad))
        first_batch = (batch[0][:bsz], batch[1][:bsz])

        rngs = jax.random.split(base_rng, nsamples)
        grads = loss_grad_fn(params, first_batch, rngs,
                             args.kl_coef / train_size)

        grad_std = tree_map(lambda bg: jnp.std(bg, 0), grads)
        avg_std = jnp.nanmean(ravel_pytree(grad_std)[0])

        grads = tree_map(lambda bg: jnp.mean(bg, 0), grads)

        grad_snr = tree_multimap(lambda m, sd: jnp.abs(m / sd), grads,
                                 grad_std)
        avg_snr = jnp.nanmean(ravel_pytree(grad_snr)[0])

        # Accumulate the remaining chunks with a running mean; the loop index
        # must not shadow the step `i` consumed by opt_update below, and chunk
        # j covers rows [j * bsz, (j + 1) * bsz) since chunk 0 was done above.
        for j in range(1, acc_grad):
            batch_j = (batch[0][j * bsz:(j + 1) * bsz],
                       batch[1][j * bsz:(j + 1) * bsz])
            grads_j = loss_grad_fn(params, batch_j, rngs,
                                   args.kl_coef / train_size)
            grads_j = tree_map(lambda bg: jnp.mean(bg, 0), grads_j)
            grads = tree_multimap(lambda g, g_new: (g * j + g_new) / (j + 1),
                                  grads, grads_j)

        pre_update = get_params(opt_state)
        post_update = jit(opt_update)(i, grads, opt_state)
        assert jnp.not_equal(
            ravel_pytree(pre_update)[0],
            ravel_pytree(get_params(post_update))[0]).any()
        return post_update, avg_std, avg_snr
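The loop above keeps a running average so that after processing k chunks the accumulated pytree equals the plain mean of the k per-chunk gradients. A toy check of that recurrence on scalar leaves (hypothetical values):

import jax.numpy as jnp
from jax import tree_util

chunks = [jnp.array(1.0), jnp.array(2.0), jnp.array(6.0)]
acc = chunks[0]
for k, g_new in enumerate(chunks[1:], start=1):
    # (g * k + g_new) / (k + 1) keeps `acc` equal to the mean of chunks[:k + 1].
    acc = tree_util.tree_map(lambda g, gn: (g * k + gn) / (k + 1), acc, g_new)
print(acc)  # 3.0, the mean of the three chunks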
Example #3
def get_centre(model):
    centre_mux = jnp.nanmean(
        jnp.array([
            model.get(k, {}).get('mux', jnp.nan) for k in ('bulge', 'bar')
            if model.get(k, None) is not None
        ]))
    centre_muy = jnp.nanmean(
        jnp.array([
            model.get(k, {}).get('muy', jnp.nan) for k in ('bulge', 'bar')
            if model.get(k, None) is not None
        ]))
    return dict(mux=float(centre_mux), muy=float(centre_muy))
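A hypothetical usage: components that are absent or explicitly None drop out of the list entirely, while components that are present but lack the key contribute a NaN that jnp.nanmean ignores.

model = {'bulge': {'mux': 50.0, 'muy': 48.0}, 'bar': {'muy': 52.0}}
print(get_centre(model))
# -> {'mux': 50.0, 'muy': 50.0}: 'bar' has no 'mux', so only the bulge
#    contributes to mux, while muy averages both components.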
Example #4
def main():
    # Set pseudo-random number generator keys.
    rng = random.PRNGKey(args.seed)
    rng, rng_net = random.split(rng, 2)
    rng, rng_sample, rng_xobs, rng_basis = random.split(rng, 4)
    rng, rng_fwd, rng_rev = random.split(rng, 3)
    rng, rng_kl = random.split(rng, 2)

    # Initialize the parameters of the ambient vector field network.
    _, params = net_init(rng_net, (-1, 4))
    opt_state = opt_init(params)

    for it in range(args.num_steps):
        opt_state, kl = step(opt_state, it, args.num_samples)
        print('iter.: {} - kl: {:.4f}'.format(it, kl))

    params = get_params(opt_state)
    count = lambda x: jnp.prod(jnp.array(x.shape))
    num_params = jnp.array(
        tree_util.tree_map(count,
                           tree_util.tree_flatten(params)[0])).sum()
    print('number of parameters: {}'.format(num_params))

    # Compute comparison statistics.
    xsph, log_approx = manifold_ode_log_prob(params, rng_sample, 10000)
    xobs = rejection_sampling(rng_xobs, len(xsph), 3, embedded_sphere_density)
    mean_mse = jnp.square(jnp.linalg.norm(xsph.mean(0) - xobs.mean(0)))
    cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xsph.T) - jnp.cov(xobs.T)))
    approx = jnp.exp(log_approx)
    target = embedded_sphere_density(xsph)
    w = target / approx
    Z = jnp.nanmean(w)
    log_approx = jnp.log(approx)
    log_target = jnp.log(target)
    klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
    ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
    ress = 100 * ess / len(w)
    del w, Z, log_approx, approx, log_target, target
    log_approx = manifold_reverse_ode_log_prob(params, rng_kl, xobs)
    approx = jnp.exp(log_approx)
    target = embedded_sphere_density(xobs)
    w = approx / target
    Z = jnp.nanmean(w)
    log_target = jnp.log(target)
    klpq = jnp.nanmean(log_target - log_approx) + jnp.log(Z)
    del w, Z, log_approx, approx, log_target, target
    print(
        r'manode - Mean MSE: {:.5f} - Covariance MSE: {:.5f} - KL$(q\Vert p)$ = {:.5f} - KL$(p\Vert q)$ = {:.5f} - Rel. ESS: {:.2f}%'
        .format(mean_mse, cov_mse, klqp, klpq, ress))
Example #5
    def negative_log_predictive_density(self, X, Y, R=None):
        predict_mean, predict_var = self.predict(X, R)
        if predict_mean.ndim > 1:  # multi-latent case
            pred_mean, pred_var, Y = predict_mean[..., None], diag(predict_var), Y.reshape(-1, 1)
        else:
            pred_mean, pred_var, Y = predict_mean.reshape(-1, 1, 1), predict_var.reshape(-1, 1, 1), Y.reshape(-1, 1)
        log_density = vmap(self.likelihood.log_density)(Y, pred_mean, pred_var)
        return -np.nanmean(log_density)
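The method delegates the pointwise scoring to `self.likelihood.log_density`; for a Gaussian likelihood the same quantity can be written directly (a sketch with hypothetical inputs, not this class's API):

import jax.numpy as jnp

def gaussian_nlpd(y, mean, var):
    # Pointwise Gaussian log density, averaged NaN-robustly and negated.
    log_density = -0.5 * (jnp.log(2.0 * jnp.pi * var) + (y - mean) ** 2 / var)
    return -jnp.nanmean(log_density)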
Example #6
def bayesian_regression(x: np.ndarray, y: Optional[np.ndarray] = None) -> None:
    batch, x_dim = jnp.shape(x)
    theta = numpyro.sample(
        "theta", dist.Normal(jnp.zeros(x_dim),
                             jnp.ones(x_dim) * 100))
    sigma = numpyro.sample("sigma", dist.Gamma(1.0, 1.0))

    x_mu = numpyro.sample(
        "x_mu", dist.Normal(jnp.nanmean(x, axis=0), jnp.nanmean(x, axis=0)))
    x_std = numpyro.sample("x_std", dist.Gamma(1.0, 1.0))
    with numpyro.plate("batch", batch, dim=-2):
        mask = ~np.isnan(x)
        numpyro.sample("x", dist.Normal(x_mu, x_std).mask(mask), obs=x)

        index = (~mask).astype(int).nonzero()
        x_sample = numpyro.sample("x_sample", dist.Normal(x_mu, x_std))
        x_filled = ops.index_update(x, index, x_sample[index])

    numpyro.sample("y", dist.Normal(jnp.matmul(x_filled, theta), sigma), obs=y)
Example #7
def _fill_nans(numerical_points, numerical_clusters, categorical_points,
               categorical_clusters):
    filled_num_nan = jnp.isnan(numerical_points) * jnp.nanmean(
        numerical_clusters, axis=0)
    numerical_points = jnp.nan_to_num(numerical_points) + filled_num_nan

    filled_cat_nan = (categorical_points == FILL_VAL) * mode(
        categorical_clusters, axis=0)
    categorical_points = jnp.where(
        categorical_points == FILL_VAL,
        filled_cat_nan,
        categorical_points,
    )

    return numerical_points, categorical_points
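The numeric branch relies on `jnp.isnan(...) * column_means` being zero wherever the data is observed, so adding it to `nan_to_num(points)` only touches the missing entries. In isolation, with toy values:

import jax.numpy as jnp

points = jnp.array([[1.0, jnp.nan],
                    [jnp.nan, 4.0]])
clusters = jnp.array([[1.0, 2.0],
                      [3.0, 6.0]])
fill = jnp.isnan(points) * jnp.nanmean(clusters, axis=0)  # column means [2., 4.]
print(jnp.nan_to_num(points) + fill)  # [[1. 4.] [2. 4.]]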
Example #8
    def cost(self, p, extra=None, precomputed=None):
        """Mean squared error, with optional elastic-net and prior penalties."""

        y = self.y if extra is None else extra['y']
        yhat = self.forward_pass(p,
                                 extra) if precomputed is None else precomputed

        mse = np.nanmean((y - yhat)**2)

        if self.beta and extra is None:

            l1 = np.linalg.norm(p['b'], 1)
            l2 = np.linalg.norm(p['b'], 2)
            mse += self.beta * ((1 - self.alpha) * l2 + self.alpha * l1)

        if hasattr(self, 'Cinv'):
            mse += 0.5 * p['b'] @ self.Cinv @ p['b']

        return mse
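The `beta`/`alpha` term above is a standard elastic-net penalty: `alpha` interpolates between a pure L2 (ridge) penalty at 0 and a pure L1 (lasso) penalty at 1, and `beta` scales the whole term. Evaluated in isolation on toy weights:

import jax.numpy as jnp

b = jnp.array([0.5, -2.0, 0.0])
alpha, beta = 0.5, 0.1
l1, l2 = jnp.linalg.norm(b, 1), jnp.linalg.norm(b, 2)
print(beta * ((1 - alpha) * l2 + alpha * l1))  # 0.1 * (0.5 * ~2.062 + 0.5 * 2.5)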
Example #9
# Sample from the learned distribution.
num_samples = 100000
num_dims = 2
xamb = random.normal(rng_xamb, [num_samples, num_dims])
xamb = forward(bij_params, bij_fns, xamb)
xtor = jnp.mod(xamb, 2.0 * jnp.pi)
lp = induced_torus_log_density(bij_params, bij_fns, xtor)
xobs = rejection_sampling(rng_xobs, len(xtor), torus_density, args.beta)

# Compute comparison statistics.
mean_mse = jnp.square(jnp.linalg.norm(xtor.mean(0) - xobs.mean(0)))
cov_mse = jnp.square(jnp.linalg.norm(jnp.cov(xtor.T) - jnp.cov(xobs.T)))
approx = jnp.exp(lp)
target = torus_density(xtor)
w = target / approx
Z = jnp.nanmean(w)
log_approx = jnp.log(approx)
log_target = jnp.log(target)
klqp = jnp.nanmean(log_approx - log_target) + jnp.log(Z)
ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
ress = 100 * ess / len(w)
del w, Z, log_approx, log_target
log_approx = induced_torus_log_density(bij_params, bij_fns, xobs)
approx = jnp.exp(log_approx)
target = torus_density(xobs)
log_target = jnp.log(target)
w = approx / target
Z = jnp.mean(w)
klpq = jnp.mean(log_target - log_approx) + jnp.log(Z)
del w, Z, log_approx, log_target
print(
    r'Mean MSE: {:.5f} - Covariance MSE: {:.5f} - KL$(q\Vert p)$ = {:.5f} - KL$(p\Vert q)$ = {:.5f} - Rel. ESS: {:.2f}%'
    .format(mean_mse, cov_mse, klqp, klpq, ress))
Example #10
def _parallel_train_step(
    optimizer,
    batched_examples,
    static_batch_metadata,
    loss_fn,
    max_global_norm=None,
    **optimizer_hyper_params,
):
    """Train the model for one step in parallel across devices.

  Args:
    optimizer: Optimizer that tracks the model and parameter state. Should be
      replicated to each device, i.e. should contain ShardedDeviceArrays with a
      leading axis (num_devices, ...) but with the same content on each device.
    batched_examples: A structure of NDArrays representing a batch of examples.
      Should have two leading batch dimensions: (num_devices,
        batch_size_per_device, ...)
    static_batch_metadata: Metadata about this batch, which will be shared
      across all batched examples. Each value of this results in a separate
      XLA-compiled module.
    loss_fn: Task-specific non-batched loss function to apply. Should take the
      current model (optimizer.target) and an example from batched_examples, and
      return a tuple of the current loss (as a scalar) and a dictionary from
      string names to metric values (also scalars, or RatioMetrics).
    max_global_norm: Maximum global norm to clip gradients to. Should be a
      scalar, which will be broadcast automatically.
    **optimizer_hyper_params: Hyperparameters to pass to the optimizer's
      `apply_gradient` function, which will be broadcast across devices
      automatically.

  Returns:
    Tuple (updated_optimizer, grads_ok, metrics, grads). Metrics will be as
    returned by loss_fn, with an extra element "loss", and will be averaged
    across all elements of the batch. Both the optimizer and the metrics
    contain ShardedDeviceArrays that are identical across devices. grads_ok
    is a replicated bool ndarray that is True if all gradients were finite.
  """
    def batched_loss_fn(model):
        """Apply loss function across a batch of examples."""
        loss, metrics = jax.vmap(loss_fn,
                                 (None, 0, None))(model, batched_examples,
                                                  static_batch_metadata)
        return jnp.mean(loss), metrics

    # Compute gradients of loss, along with metrics.
    (loss, metrics), grads = jax.value_and_grad(batched_loss_fn,
                                                has_aux=True)(optimizer.target)
    metrics["loss"] = loss

    # Exchange average gradients and metrics across devices.
    agg_grads = jax.lax.pmean(grads, "devices")
    agg_metrics = {}
    for k, v in metrics.items():
        if isinstance(v, RatioMetric):
            num = jax.lax.psum(jnp.sum(v.numerator), "devices")
            denom = jax.lax.psum(jnp.sum(v.denominator), "devices")
            new_value = num / denom
        else:
            # Use nanmean to aggregate bare floats.
            new_value = jnp.nanmean(jax.lax.all_gather(v, "devices"))
        agg_metrics[k] = new_value

    # Compute global norm and possibly clip.
    global_norm = optax.global_norm(agg_grads)
    agg_metrics["gradient_global_norm"] = global_norm
    if max_global_norm is not None:
        should_clip = global_norm > max_global_norm
        agg_grads = jax.tree_map(
            lambda g: jnp.where(should_clip, g * max_global_norm / global_norm,
                                g), agg_grads)
        agg_metrics["gradient_was_clipped"] = should_clip.astype("float32")

    # Check for non-finite gradients.
    grads_ok = jnp.all(
        jnp.stack(
            [jnp.all(jnp.isfinite(x)) for x in jax.tree_leaves(agg_grads)]))

    # Apply updates.
    updated_optimizer = optimizer.apply_gradient(agg_grads,
                                                 **optimizer_hyper_params)

    return updated_optimizer, grads_ok, agg_metrics, agg_grads
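The two aggregation paths above differ deliberately: RatioMetrics sum numerators and denominators across devices before dividing, while bare floats are gathered and reduced with nanmean so that a NaN on one device does not poison the average. A minimal runnable sketch of the second path (toy scalar metric; works even with a single local device):

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name="devices")
def agg_metric(v):
    # Gather every device's value, then average NaN-robustly.
    return jnp.nanmean(jax.lax.all_gather(v, "devices"))

n = jax.local_device_count()
print(agg_metric(jnp.arange(n, dtype=jnp.float32)))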
Example #11
File: eval.py Project: google/nerfies
def process_batch(*,
                  batch: Dict[str, jnp.ndarray],
                  rng: types.PRNGKey,
                  state: model_utils.TrainState,
                  tag: str,
                  item_id: str,
                  step: int,
                  summary_writer: tensorboard.SummaryWriter,
                  render_fn: Any,
                  save_dir: Optional[gpath.GPath],
                  datasource: datasets.DataSource):
  """Process and plot a single batch."""
  item_id = item_id.replace('/', '_')
  render = render_fn(state, batch, rng=rng)
  out = {}
  if jax.process_index() != 0:
    return out

  rgb = render['rgb']
  acc = render['acc']
  depth_exp = render['depth']
  depth_med = render['med_depth']
  colorize_depth = functools.partial(viz.colorize,
                                     cmin=datasource.near,
                                     cmax=datasource.far,
                                     invert=True)

  depth_exp_viz = colorize_depth(depth_exp)
  depth_med_viz = colorize_depth(depth_med)
  disp_exp_viz = viz.colorize(1.0 / depth_exp)
  disp_med_viz = viz.colorize(1.0 / depth_med)
  acc_viz = viz.colorize(acc, cmin=0.0, cmax=1.0)
  if save_dir:
    save_dir.mkdir(parents=True, exist_ok=True)
    image_utils.save_image(save_dir / f'rgb_{item_id}.png',
                           image_utils.image_to_uint8(rgb))
    image_utils.save_image(save_dir / f'depth_expected_viz_{item_id}.png',
                           image_utils.image_to_uint8(depth_exp_viz))
    image_utils.save_depth(save_dir / f'depth_expected_{item_id}.png',
                           depth_exp)
    image_utils.save_image(save_dir / f'depth_median_viz_{item_id}.png',
                           image_utils.image_to_uint8(depth_med_viz))
    image_utils.save_depth(save_dir / f'depth_median_{item_id}.png',
                           depth_med)

  summary_writer.image(f'rgb/{tag}/{item_id}', rgb, step)
  summary_writer.image(f'depth-expected/{tag}/{item_id}', depth_exp_viz, step)
  summary_writer.image(f'depth-median/{tag}/{item_id}', depth_med_viz, step)
  summary_writer.image(f'disparity-expected/{tag}/{item_id}', disp_exp_viz,
                       step)
  summary_writer.image(f'disparity-median/{tag}/{item_id}', disp_med_viz, step)
  summary_writer.image(f'acc/{tag}/{item_id}', acc_viz, step)

  if 'rgb' in batch:
    rgb_target = batch['rgb']
    mse = ((rgb - batch['rgb'])**2).mean()
    psnr = utils.compute_psnr(mse)
    ssim = compute_multiscale_ssim(rgb_target, rgb)
    out['mse'] = mse
    out['psnr'] = psnr
    out['ssim'] = ssim
    logging.info('\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f',
                 mse, psnr, ssim)

    rgb_abs_error = viz.colorize(
        abs(rgb_target - rgb).sum(axis=-1), cmin=0, cmax=1)
    rgb_sq_error = viz.colorize(
        ((rgb_target - rgb)**2).sum(axis=-1), cmin=0, cmax=1)
    summary_writer.image(f'rgb-target/{tag}/{item_id}', rgb_target, step)
    summary_writer.image(f'rgb-abs-error/{tag}/{item_id}', rgb_abs_error, step)
    summary_writer.image(f'rgb-sq-error/{tag}/{item_id}', rgb_sq_error, step)

  if 'depth' in batch:
    depth_target = batch['depth']
    depth_target_viz = colorize_depth(depth_target[..., 0])
    out['depth_abs'] = jnp.nanmean(jnp.abs(depth_target - depth_med))
    summary_writer.image(
        f'depth-target/{tag}/{item_id}', depth_target_viz, step)
    depth_med_error = viz.colorize(
        abs(depth_target - depth_med).squeeze(axis=-1), cmin=0, cmax=1)
    summary_writer.image(
        f'depth-median-error/{tag}/{item_id}', depth_med_error, step)
    depth_exp_error = viz.colorize(
        abs(depth_target - depth_exp).squeeze(axis=-1), cmin=0, cmax=1)
    summary_writer.image(
        f'depth-expected-error/{tag}/{item_id}', depth_exp_error, step)

  return out
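For reference, PSNR for images in [0, 1] is conventionally the decibel transform of the MSE; a hypothetical one-liner with that convention (the project's `utils.compute_psnr` may differ in details):

import jax.numpy as jnp

def compute_psnr(mse):
    return -10.0 * jnp.log10(mse)  # e.g. mse = 0.001 -> 30 dB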
Example #12
def nanmean(a, axis=None, dtype=None, keepdims=None):
  if isinstance(a, JaxArray): a = a.value
  r = jnp.nanmean(a, axis=axis, dtype=dtype, keepdims=keepdims)
  return r if axis is None else JaxArray(r)
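A note on the wrapper's contract: a full reduction (axis=None) returns a bare scalar, while an axis-wise reduction is re-wrapped so downstream code keeps receiving JaxArray values. The underlying jnp.nanmean behaves as follows (shown directly on jnp arrays):

import jax.numpy as jnp

x = jnp.array([[1.0, jnp.nan], [3.0, 4.0]])
print(jnp.nanmean(x))          # scalar: 8/3, the NaN is skipped
print(jnp.nanmean(x, axis=0))  # per-column means: [2., 4.]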
Example #13
File: eval.py Project: dukebw/nerfies
def process_batch(
    *,
    batch: Dict[str, jnp.ndarray],
    rng: types.PRNGKey,
    state: model_utils.TrainState,
    tag: str,
    item_id: str,
    step: int,
    summary_writer: tensorboard.SummaryWriter,
    render_fn: Any,
    save_dir: Optional[gpath.GPath],
    datasource: datasets.DataSource,
):
    """Process and plot a single batch."""
    rgb, depth_exp, depth_med, acc = render_fn(state, batch, rng=rng)
    out = {}
    if jax.host_id() != 0:
        return out

    colorize_depth = functools.partial(viz.colorize,
                                       cmin=datasource.near,
                                       cmax=datasource.far,
                                       invert=True)

    depth_exp_viz = colorize_depth(depth_exp[..., 0])
    depth_med_viz = colorize_depth(depth_med[..., 0])
    if save_dir:
        save_dir.mkdir(parents=True, exist_ok=True)
        image_utils.save_image(save_dir / f"rgb_{item_id}.png",
                               image_utils.image_to_uint8(rgb))
        image_utils.save_image(
            save_dir / f"depth_expected_viz_{item_id}.png",
            image_utils.image_to_uint8(depth_exp_viz),
        )
        image_utils.save_depth(save_dir / f"depth_expected_{item_id}.png",
                               depth_exp[..., 0])
        image_utils.save_image(
            save_dir / f"depth_median_viz_{item_id}.png",
            image_utils.image_to_uint8(depth_med_viz),
        )
        image_utils.save_depth(save_dir / f"depth_median_{item_id}.png",
                               depth_med[..., 0])

    summary_writer.image(f"rgb/{tag}/{item_id}", rgb, step)
    summary_writer.image(f"depth-expected/{tag}/{item_id}", depth_exp_viz,
                         step)
    summary_writer.image(f"depth-median/{tag}/{item_id}", depth_med_viz, step)
    summary_writer.image(f"acc/{tag}/{item_id}", acc, step)

    if "rgb" in batch:
        rgb_target = batch["rgb"]
        mse = ((rgb - batch["rgb"])**2).mean()
        psnr = utils.compute_psnr(mse)
        ssim = compute_multiscale_ssim(rgb_target, rgb)
        out["mse"] = mse
        out["psnr"] = psnr
        out["ssim"] = ssim
        logging.info("\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f", mse, psnr,
                     ssim)

        rgb_abs_error = viz.colorize(abs(rgb_target - rgb).sum(axis=-1),
                                     cmin=0,
                                     cmax=1)
        rgb_sq_error = viz.colorize(((rgb_target - rgb)**2).sum(axis=-1),
                                    cmin=0,
                                    cmax=1)
        summary_writer.image(f"rgb-target/{tag}/{item_id}", rgb_target, step)
        summary_writer.image(f"rgb-abs-error/{tag}/{item_id}", rgb_abs_error,
                             step)
        summary_writer.image(f"rgb-sq-error/{tag}/{item_id}", rgb_sq_error,
                             step)

    if "depth" in batch:
        depth_target = batch["depth"]
        depth_target_viz = colorize_depth(depth_target[..., 0])
        out["depth_abs"] = jnp.nanmean(jnp.abs(depth_target - depth_med))
        summary_writer.image(f"depth-target/{tag}/{item_id}", depth_target_viz,
                             step)
        depth_med_error = viz.colorize(abs(depth_target -
                                           depth_med).squeeze(axis=-1),
                                       cmin=0,
                                       cmax=1)
        summary_writer.image(f"depth-median-error/{tag}/{item_id}",
                             depth_med_error, step)
        depth_exp_error = viz.colorize(abs(depth_target -
                                           depth_exp).squeeze(axis=-1),
                                       cmin=0,
                                       cmax=1)
        summary_writer.image(f"depth-expected-error/{tag}/{item_id}",
                             depth_exp_error, step)
        rel_disp_pred = viz.colorize(1.0 / depth_exp[..., 0])
        summary_writer.image(f"relative-disparity/{tag}/{item_id}",
                             rel_disp_pred, step)

    return out
Example #14
File: utils.py Project: google/nerfies
def robust_whiten(x):
    median = jnp.nanmedian(x)
    mad = jnp.nanmean(jnp.abs(x - median))
    return (x - median) / mad
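This whitens with outlier-robust statistics: centering by the median and scaling by the mean absolute deviation about the median, with NaNs ignored by both reductions and propagated through the output. A toy check:

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0, 100.0])
# median = 3.0, mean absolute deviation = (2 + 1 + 0 + 1 + 97) / 5 = 20.2
print(robust_whiten(x))  # [-0.099, -0.0495, 0., 0.0495, 4.802]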
Example #15
# Train dequantization networks.
(bij_params, deq_params), trace = train(rng_train, bij_params, bij_fns,
                                        deq_params, deq_fn, args.num_steps,
                                        args.lr)

# Compute an estimate of the KL divergence.
num_is = 150
_, xon = ambient.sample(rng_mse, 10000, bij_params, bij_fns, num_dims)
log_approx = importance_log_density(rng_kl, bij_params, bij_fns, deq_params,
                                    deq_fn, num_is, xon)
log_approx = jnp.clip(log_approx, -10., 10.)
log_target = log_density(xon)
approx, target = jnp.exp(log_approx), jnp.exp(log_target)
w = jnp.exp(log_target - log_approx)
Z = jnp.nanmean(w)
logZ = jspsp.logsumexp(log_target - log_approx) - jnp.log(len(xon))
klqp = jnp.nanmean(log_approx - log_target) + logZ
ess = jnp.square(jnp.nansum(w)) / jnp.nansum(jnp.square(w))
ress = 100 * ess / len(w)
del w, Z, logZ, log_approx, approx, log_target, target
xobs = xobs[:1000]
log_approx = importance_log_density(rng_kl, bij_params, bij_fns, deq_params,
                                    deq_fn, num_is, xobs)
log_approx = jnp.clip(log_approx, -10., 10.)
log_target = log_density(xobs)
approx, target = jnp.exp(log_approx), jnp.exp(log_target)
logZ = jspsp.logsumexp(log_approx - log_target) - jnp.log(len(xobs))
klpq = jnp.nanmean(log_target - log_approx) + logZ
del logZ, log_approx, approx, log_target, target
mean_mse = jnp.square(jnp.linalg.norm(xon.mean(0) - xobs.mean(0)))