Ejemplo n.º 1
0
 def log_prior(params):
     """Computes the (tempered) Gaussian prior log-density.

     Evaluates an isotropic Gaussian prior with precision ``weight_decay``
     (i.e. N(0, I / weight_decay)) over every leaf of ``params`` and divides
     the result by ``temperature``:

         log p(params) = -0.5 * weight_decay * ||params||^2
                         + 0.5 * n_params * log(weight_decay / (2 * pi))

     ``weight_decay`` and ``temperature`` are read from the enclosing scope
     (not visible here) -- presumably the arguments of the factory that
     builds this closure; verify against the caller.

     Args:
       params: a pytree of arrays; every leaf must expose ``.size``.

     Returns:
       Scalar prior log-density divided by ``temperature``.
     """
     # TODO(izmailovpavel): make temperature treatment the same as in the
     # gaussian likelihood function.
     # `jax.tree_leaves` is deprecated and removed in newer JAX releases; the
     # `jax.tree_util` spelling works across versions.
     n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
     sq_norm = tree_utils.tree_dot(params, params)
     # Per coordinate, the normalizer of N(0, 1/weight_decay) is
     # sqrt(weight_decay / (2*pi)), so its log enters with a PLUS sign.
     # Subtracting it (as an earlier version did) only shifts the value by a
     # constant -- gradients are unchanged -- but the result is then not a
     # properly normalized log-density.
     log_prob = (-0.5 * weight_decay * sq_norm +
                 0.5 * n_params * jnp.log(weight_decay / (2 * math.pi)))
     return log_prob / temperature
def get_u_v_o(params1, params2, params3):
  """Builds an orthonormal 2-D basis of the plane through three param points.

  The plane is anchored at ``params1`` (the origin ``o``):

    * ``u`` is ``tree_diff(params2, params1)`` scaled to unit norm;
    * ``v`` is ``tree_diff(params3, params1)`` with its component along
      ``u`` removed (one Gram-Schmidt step), then scaled to unit norm.

  Assumes ``tree_utils.tree_diff(a, b)`` computes ``a - b`` leaf-wise, so that
  ``u`` points from ``params1`` towards ``params2`` -- confirm against
  ``tree_utils``.

  Args:
    params1: origin of the plane (a pytree).
    params2: second point; defines the ``u`` direction.
    params3: third point; together with ``u`` defines the ``v`` direction.

  Returns:
    Tuple ``(u_params, u_norm, v_params, v_norm, params1)``. ``u_norm`` and
    ``v_norm`` are the norms of the ``u`` and (orthogonalized) ``v`` vectors
    before normalization, and ``params1`` is returned unchanged as the origin.

  Note:
    No guard is applied for degenerate inputs. If ``params2`` equals
    ``params1``, or the three points are collinear, the corresponding norm is
    presumably zero and the ``1 / norm`` scaling will produce non-finite
    values.
  """
  # Unit direction from params1 towards params2.
  u_params = tree_utils.tree_diff(params2, params1)
  u_norm = tree_utils.tree_norm(u_params)
  u_params = tree_utils.tree_scalarmul(u_params, 1 / u_norm)

  # Gram-Schmidt: project out the u component of params3 - params1.
  v_params = tree_utils.tree_diff(params3, params1)
  uv_dot = tree_utils.tree_dot(u_params, v_params)
  # `jax.tree_multimap` is deprecated and removed in newer JAX releases;
  # `jax.tree_util.tree_map` maps over several trees leaf-wise in the same way.
  v_params = jax.tree_util.tree_map(
      lambda v, u: v - uv_dot * u, v_params, u_params)
  v_norm = tree_utils.tree_norm(v_params)
  v_params = tree_utils.tree_scalarmul(v_params, 1 / v_norm)

  return u_params, u_norm, v_params, v_norm, params1