Example #1
import jax.numpy as jnp
import scipy.linalg


def norm_f(x, norm_type):
  """Differentiable norm supporting l1/l2/linf, DFT-domain, and general Lp norms."""
  if norm_type == 'l2':
    return jnp.float_power(jnp.maximum(1e-7, jnp.sum(x**2)), 0.5)
  if norm_type == 'l1':
    return jnp.sum(jnp.abs(x))
  if norm_type == 'linf':
    return jnp.max(jnp.abs(x))
  if norm_type == 'dft1':
    # Unitary DFT matrix; scale='sqrtn' applies the 1/sqrt(n) normalization.
    dft = scipy.linalg.dft(x.shape[0], scale='sqrtn')
    return jnp.sum(jnp.abs(dft @ x))
    # return jnp.sum(jnp.abs(scipy.fft.fft(x, norm='ortho')))
  if norm_type == 'dftinf':
    dft = scipy.linalg.dft(x.shape[0], scale='sqrtn')
    return jnp.max(jnp.abs(dft @ x))
    # return jnp.max(jnp.abs(scipy.fft.fft(x, norm='ortho')))
  p = float(norm_type[1:].split('_')[0])
  q = float(norm_type[1:].split('_')[1]) if '_' in norm_type else 1
  # Taking the root directly gives nan because it is not differentiable at 0:
  # the subgradient of the norm is 0, but the directional derivative is inf
  # because 1/x^(p-1) is inf at 0.
  # https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html#Enforcing-a-differentiation-convention
  if q != 1:
    return jnp.float_power(
        jnp.maximum(1e-7, jnp.sum(jnp.float_power(jnp.abs(x), p))), 1 / q)
  return jnp.sum(jnp.float_power(jnp.abs(x), p))
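
A quick sanity check of the clamping trick above (a minimal sketch, assuming norm_f is in scope): the naive sqrt(sum(x**2)) yields nan gradients at x = 0, while norm_f stays finite.

import jax
import jax.numpy as jnp

g = jax.grad(lambda x: norm_f(x, 'l2'))(jnp.zeros(3))
print(g)  # [0. 0. 0.] -- finite, no nan

naive = jax.grad(lambda x: jnp.sqrt(jnp.sum(x**2)))(jnp.zeros(3))
print(naive)  # [nan nan nan]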
Example #2
def float_power_jvp(primals, tangents):
  x, p = primals
  x_dot, p_dot = tangents
  ans = float_power(x, p)
  # d/dx x**p = p * x**(p-1);  d/dp x**p = x**p * log(x), with x clamped away from 0.
  fdx_xdot = p * jnp.float_power(x, p - 1) * x_dot
  fdp_pdot = jnp.log(jnp.maximum(1e-7, x)) * jnp.float_power(x, p) * p_dot
  ans_dot = fdx_xdot + fdp_pdot
  return ans, ans_dot
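
Example #2 is only the JVP rule; presumably it is registered against a float_power primal like the one in Example #4 via jax.custom_jvp. A minimal sketch of that wiring (the registration shown here is an assumption, not part of the original snippet):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def float_power(x, p):
  return jnp.float_power(x, p)

# Register the hand-written rule from Example #2 as the forward-mode derivative.
float_power.defjvp(float_power_jvp)

primal, tangent = jax.jvp(float_power, (2.0, 3.0), (1.0, 0.0))
print(primal, tangent)  # 8.0 12.0 (d/dx x**3 at x = 2 is 3 * x**2 = 12)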
Example #3
  def test_value_and_grad_aux(self):
    o = object()

    def f(x):
      m = SquareModule()
      return m(x), o

    x = jnp.array(3.)
    (y, aux), g = stateful.value_and_grad(f, has_aux=True)(x)
    self.assertEqual(y, jnp.float_power(x, 2))
    np.testing.assert_allclose(g, 2 * x, rtol=1e-4)
    self.assertIs(aux, o)
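
The test above exercises Haiku's stateful.value_and_grad with has_aux=True; SquareModule is defined elsewhere in that test suite. For reference, a minimal self-contained sketch of the same aux-passing pattern with plain jax.value_and_grad:

import jax
import jax.numpy as jnp

def f(x):
  aux = 'anything returned alongside the value'
  return x ** 2, aux

(y, aux), g = jax.value_and_grad(f, has_aux=True)(jnp.array(3.0))
print(y, g, aux)  # 9.0 6.0 'anything returned alongside the value'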
Example #4
def float_power(x, p):
  return jnp.float_power(x, p)
Example #5
def float_power(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.float_power(x1, x2))
Example #6
  def lr_schedule(t):
    cur_epoch = jnp.minimum(num_anneals, jnp.sum(t > anneal_steps))
    return lr_init * jnp.float_power(anneal_factor, cur_epoch)
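
This schedule is a closure: lr_init, anneal_factor, num_anneals and anneal_steps come from an enclosing scope that is not shown. A self-contained sketch of the same step-decay schedule, with hypothetical values for those parameters:

import jax.numpy as jnp

def make_lr_schedule(lr_init=0.1, anneal_factor=0.5,
                     anneal_steps=jnp.array([1000, 2000, 3000])):
  num_anneals = anneal_steps.shape[0]

  def lr_schedule(t):
    # Count how many anneal boundaries step t has passed (capped at num_anneals)
    # and decay the initial rate by anneal_factor for each one.
    cur_epoch = jnp.minimum(num_anneals, jnp.sum(t > anneal_steps))
    return lr_init * jnp.float_power(anneal_factor, cur_epoch)

  return lr_schedule

lr = make_lr_schedule()
print(lr(0), lr(1500), lr(5000))  # ~0.1, 0.05, 0.0125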