Beispiel #1
0
def projection_halfspace(x, a, b):
  r"""Projection onto a halfspace defined by a pytree and scalar.

  The output is:
    ``argmin_{y, dot(a, y) <= b} ||y - x||``,
  which has the closed form
    ``y = x - max(dot(a, x) - b, 0) / dot(a, a) * a``.

  Args:
    x: pytree to project.
    a: pytree defining the normal direction of the halfspace boundary;
      presumably must have the same tree structure as ``x``. Should be
      nonzero: the update divides by ``dot(a, a)``.
    b: scalar offset of the halfspace.

  Returns:
    y: output pytree (same structure/shape as ``x``).
  """
  # relu clips the constraint violation at 0, so a point already satisfying
  # dot(a, x) <= b gets a zero-scaled update and is returned unchanged.
  violation = jax.nn.relu(tree_util.tree_vdot(a, x) - b)
  scale = violation / tree_util.tree_vdot(a, a)
  return tree_util.tree_add_scalar_mul(x, -scale, a)
Beispiel #2
0
def projection_hyperplane(a, b, x=None):
  r"""Projection onto a hyperplane defined by a pytree and scalar.

  The output is:
    ``argmin_{y, dot(a, y) = b} ||y - x||``,
  which has the closed form
    ``y = x - (dot(a, x) - b) / dot(a, a) * a``.

  Args:
    a: pytree defining the normal direction of the hyperplane. Should be
      nonzero: the formula divides by ``dot(a, a)``.
    b: scalar offset of the hyperplane.
    x: pytree to project (presumably same tree structure as ``a``). If
      ``None`` (the default), the origin is projected, which yields
      ``b / dot(a, a) * a``.

  Returns:
    y: output pytree (same structure as ``x``, or as ``a`` when ``x`` is
    ``None``).
  """
  norm_sq = tree_util.tree_vdot(a, a)
  if x is None:
    # The general formula with x = 0 reduces to b / <a, a> * a; computing it
    # directly avoids having to build a zero pytree shaped like ``a``.
    return tree_util.tree_scalar_mul(b / norm_sq, a)
  scale = (tree_util.tree_vdot(a, x) - b) / norm_sq
  return tree_util.tree_add_scalar_mul(x, -scale, a)
 def least_square_regularizor_1d(a, b, delta):
   # Computes the solution to min || a^Tx -b||^2 + delta ||x||^2
   scale = -b/(tree_vdot(a, a) + delta)
   return tree_scalar_mul(scale, a)