def projection_halfspace(x, a, b):
  r"""Projection onto a halfspace defined by a pytree and a scalar.

  The output is:
    ``argmin_{y, dot(a, y) <= b} ||y - x||``,
  computed in closed form as
    ``y = x - max(dot(a, x) - b, 0) / dot(a, a) * a``.

  Args:
    x: pytree to project.
    a: pytree giving the normal vector of the halfspace boundary; it is
      combined leaf-wise with ``x``, so it is assumed to share ``x``'s
      structure.
    b: scalar offset of the halfspace.

  Returns:
    y: pytree with the same structure as ``x``.
  """
  # Amount by which x violates dot(a, x) <= b; relu makes it zero for points
  # already inside the halfspace, so those are returned unchanged.
  violation = jax.nn.relu(tree_util.tree_vdot(a, x) - b)
  scale = violation / tree_util.tree_vdot(a, a)
  return tree_util.tree_add_scalar_mul(x, -scale, a)
def projection_hyperplane(a, b, x=None):
  r"""Projection onto a hyperplane defined by a pytree and a scalar.

  The output is:
    ``argmin_{y, dot(a, y) = b} ||y - x||``,
  computed in closed form as
    ``y = x - (dot(a, x) - b) / dot(a, a) * a``.

  Args:
    a: pytree giving the normal vector of the hyperplane.
    b: scalar offset of the hyperplane.
    x: pytree to project, or ``None`` (default) to project the origin, in
      which case the result is ``b / dot(a, a) * a``.

  Returns:
    y: pytree with the structure of ``x`` (or of ``a`` when ``x`` is None).
  """
  sq_norm = tree_util.tree_vdot(a, a)
  if x is None:
    # The general formula with x = 0 reduces to b / <a, a> * a, so no zero
    # pytree has to be built.
    return tree_util.tree_scalar_mul(b / sq_norm, a)
  scale = (tree_util.tree_vdot(a, x) - b) / sq_norm
  return tree_util.tree_add_scalar_mul(x, -scale, a)
def least_square_regularizor_1d(a, b, delta):
  r"""Closed-form solution of a single-row ridge-regularized least squares.

  Returns the minimizer of::

    (dot(a, x) + b)**2 + delta * ||x||**2

  which is ``x = -b / (dot(a, a) + delta) * a`` (setting the gradient
  ``2 (dot(a, x) + b) a + 2 delta x`` to zero). For ``delta > 0`` this is the
  unique minimizer; for ``delta == 0`` it is the minimum-norm one, and the
  result is undefined (division by zero) if additionally ``a`` is all zeros.

  Args:
    a: pytree, the single row of the linear system.
    b: scalar right-hand side.
    delta: scalar regularization strength (presumably ``delta >= 0``).

  Returns:
    pytree with the same structure as ``a``.
  """
  # NOTE(review): this function was previously documented as solving
  # ||a^T x - b||^2 + delta ||x||^2, whose minimizer is +b / (<a,a> + delta) a.
  # The code has always returned the negative-scale solution, i.e. the
  # minimizer for "+ b" described above; behavior is kept as-is. Confirm the
  # sign convention callers rely on.
  scale = -b / (tree_util.tree_vdot(a, a) + delta)
  return tree_util.tree_scalar_mul(scale, a)