Example #1
0
 def tex_apply_fn(params, xs, **kwargs):
   """Applies the wrapped layer and names its output with tex_var.

   `apply_fn`, `name`, `param_names` and `explicit_depends` are free names
   resolved from the enclosing scope.

   Args:
     params: The layer's parameters. When `param_names` is non-empty it must
       have exactly one entry per name, and each entry is wrapped with
       tex_var under its corresponding name.
     xs: The layer input, forwarded to `apply_fn`.
     **kwargs: Forwarded unchanged to `apply_fn`.

   Returns:
     `apply_fn(params, xs, **kwargs)` wrapped as the tex_var `name`; it is
     declared as depending on `xs` only when `explicit_depends` is truthy.
   """
   if param_names:
     assert len(param_names) == len(params), (
         'expected one name per parameter, got %d names for %d params' %
         (len(param_names), len(params)))
     # Use `param_name`, not `name`, so the comprehension variable cannot be
     # confused with the enclosing `name` used for the output just below.
     params = tuple(tex_var(p, param_name, True) for p, param_name in
                    zip(params, param_names))
   return tex_var(apply_fn(params, xs, **kwargs),
                  name,
                  depends_on=xs if explicit_depends else ())
Example #2
0
def main(unused_argv):
  """Prints a gallery of jax2tex renderings, one per `# EX n` section.

  Each section builds a small function, calls `jax2tex` on it with sample
  arguments, and passes the result to `print_ex`. The names `f` and `fn` are
  rebound on purpose from one example to the next.

  Several closures (`L`, `nngp`, the EX 26 lambda) read `f`, `apply_fn` or
  `params` from this scope at call time. Reordering sections can therefore
  change which definition they see.

  Args:
    unused_argv: Command-line arguments; ignored.
  """
  # EX 1
  print_ex(jax2tex(lambda a, b: a + b, 1, 2))

  # EX 2
  print_ex(jax2tex(lambda a, b: a + b / a, 1, 2))

  # EX 3: `grad` differentiates with respect to the first argument `a`.
  f = lambda a, b: a + b / a
  print_ex(jax2tex(grad(f), 1., 2.))

  # EX 4: `a` is a (3, 2) integer array; `b` and `c` are Python ints.
  def fn(a, b, c):
    return a + a * (b + c) / a
  print_ex(jax2tex(fn, np.array([[1, 2], [2, 4], [3, 7]]), 2, 3))

  # EX 5
  # pylint: disable=function-redefined
  # pylint: disable=invalid-name
  def fn(a, b, c):
    return a + a * (b + c)
  print_ex(jax2tex(grad(fn), 4., 2., 3.))

  # EX 6
  def fn(a, b):
    return a * (a - b) / (a + b) + b
  print_ex(jax2tex(grad(fn), 1., 1.))

  # EX 7: matrix-vector product.
  print_ex(jax2tex(lambda W, x: W @ x, np.ones((3, 3)), np.ones((3,))))

  # EX 8: matrix-matrix product.
  print_ex(jax2tex(lambda W, x: W @ x, np.ones((3, 2)), np.ones((2, 3))))

  # EX 9
  def fn(W, x):
    return (W + W) @ (x * x)
  print_ex(jax2tex(fn, np.ones((3, 2)), np.ones((2, 3))))

  # EX 10: with 1-D inputs the result is a scalar, so `grad` (w.r.t. W) applies.
  def fn(W, x):
    return (W + W) @ (x * x)
  print_ex(jax2tex(grad(fn), np.ones((2,)), np.ones((2,))))

  # EX 11: name the intermediate `W @ x` as `z`.
  def fn(W, x):
    z = tex_var(W @ x, 'z')
    return z * z
  print_ex(jax2tex(fn, np.ones((4, 2,)), np.ones((2,))))

  # EX 12: two named intermediates, differentiated w.r.t. W.
  def fn(W, x):
    z1 = tex_var(W @ x, 'z^1')
    z2 = tex_var(W @ z1, 'z^2')
    return z2 @ z2
  print_ex(jax2tex(grad(fn), np.ones((2, 2,)), np.ones((2,))))

  # EX 13
  def fn(W, x):
    z1 = tex_var(W @ x, 'z^1')
    z2 = tex_var(W @ z1, 'z^2')
    return np.sqrt(z2 @ z2)
  print_ex(jax2tex(fn, np.ones((2, 2,)), np.ones((2,))))

  # EX 14: operand axis 0 maps to output axis 1, so x (shape (3,)) is
  # repeated along the new leading axis of a (2, 3) result.
  def fn(x):
    return lax.broadcast_in_dim(x, (2, 3), (1,))
  print_ex(jax2tex(fn, np.ones((3,))))

  # EX 15: elementwise select with a boolean array condition.
  def fn(c, x, y):
    return np.where(c, x, y)
  print_ex(jax2tex(fn, np.ones((3,), bool), np.ones((3,)), np.ones((3,))))

  # EX 16: same select, but with a scalar Python bool condition.
  def fn(c, x, y):
    return np.where(c, x, y)
  print_ex(jax2tex(fn, True, np.ones((3,)), np.ones((3,))))

  # EX 17
  def fn(x):
    return np.transpose(x)
  print_ex(jax2tex(fn, np.ones((3, 2))))

  # EX 18: Lennard-Jones form 4*eps*((sigma/dr)**12 - (sigma/dr)**6), with
  # sigma and epsilon as named constants whose value is 1.
  def E(dr):
    idr = (tex_var(1, '\\sigma') / dr)
    idr6 = idr ** 6
    idr12 = idr ** 12
    return 4 * tex_var(1, '\\epsilon') * (idr12 - idr6)
  print_ex(jax2tex(E, np.ones((3, 3))))

  # Stax Examples
  def TexVar(layer, name, param_names=(), explicit_depends=False):
    """Wraps a stax `(init_fn, apply_fn)` layer so its output is a tex_var.

    Args:
      layer: A stax layer, i.e. an `(init_fn, apply_fn)` pair.
      name: LaTeX name given to the layer's output.
      param_names: Optional LaTeX names, one per entry of the layer's params.
      explicit_depends: If True, the output is declared (via `depends_on`)
        as depending on the layer input `xs`.

    Returns:
      `(init_fn, tex_apply_fn)`; `init_fn` is passed through unchanged.
    """
    init_fn, apply_fn = layer
    def tex_apply_fn(params, xs, **kwargs):
      if param_names:
        assert len(param_names) == len(params), (
            'expected one name per parameter, got %d names for %d params' %
            (len(param_names), len(params)))
        # `param_name`, not `name`: avoid confusing the comprehension variable
        # with the enclosing `name` used for the output just below.
        params = tuple(tex_var(p, param_name, True) for p, param_name in
                       zip(params, param_names))
      return tex_var(apply_fn(params, xs, **kwargs),
                     name,
                     depends_on=xs if explicit_depends else ())
    return init_fn, tex_apply_fn
  # Dense(256) -> Relu -> Dense(3). Each layer's output is named, and the
  # dense layers' (W, b) parameters are named too.
  init_fn, apply_fn = stax.serial(
      TexVar(stax.Dense(256), 'z^1', ('W^1', 'b^1')),
      TexVar(stax.Relu, 'y^1'),
      TexVar(stax.Dense(3), 'z^2', ('W^2', 'b^2')))

  # EX 19
  def f(params, x):
    return apply_fn(params, tex_var(x, 'x', True))
  # stax init_fn returns (output_shape, params); -1 leaves the batch size open.
  _, params = init_fn(random.PRNGKey(0), (-1, 5))
  print_ex(jax2tex(f, params, np.ones((3, 5))))

  # `L` calls the EX 19 `f` through this scope. It is only traced (EX 20/21)
  # before `f` is rebound for EX 22.
  # pylint: disable=too-many-function-args
  def L(params, x, y_hat):
    y_hat = tex_var(y_hat, '\\hat y', True)
    return tex_var(-np.sum(y_hat * jax.nn.log_softmax(f(params, x))), 'L')
  # EX 20
  print_ex(jax2tex(L, params, np.ones((3, 5)), np.ones((3, 3))))
  # EX 21
  print_ex(jax2tex(grad(L), params, np.ones((3, 5)), np.ones((3, 3))))

  # EX 22: same network, but the first layer's output is declared as
  # depending on its input (explicit_depends=True).
  init_fn, apply_fn = stax.serial(
      TexVar(stax.Dense(256), 'z^1', ('W^1', 'b^1'), True),
      TexVar(stax.Relu, 'y^1'),
      TexVar(stax.Dense(3), 'z^2', ('W^2', 'b^2')))
  def f(params, x):
    return apply_fn(params, tex_var(x, 'x', True))
  _, params = init_fn(random.PRNGKey(0), (-1, 5))
  print_ex(jax2tex(f, params, np.ones((3, 5))))

  # EX 23: uses the EX 22 `apply_fn`.
  def nngp(params, x1, x2):
    x1 = tex_var(x1, 'x^1', True)
    x2 = tex_var(x2, 'x^2', True)
    return tex_var(apply_fn(params, x1) @ apply_fn(params, x2).T, '\\mathcal K')
  _, params = init_fn(random.PRNGKey(0), (-1, 5))
  print_ex(jax2tex(nngp, params, np.ones((3, 5)), np.ones((3, 5))))

  # Forward Mode vs Reverse Mode
  f = lambda a, b: a + tex_var(b / a, 'z')
  # EX 24
  print_ex(jax2tex(f, 1., 1.))
  # EX 25
  print_ex(jax2tex(grad(f), 1., 1.))
  # EX 26: jvp returns (primal_out, tangent_out); [1] keeps the tangent, i.e.
  # the forward-mode derivative w.r.t. `a` (seed 1.0) with `b` held fixed.
  # pylint: disable=g-long-lambda
  print_ex(jax2tex(lambda a, b:
                   jvp(lambda a: f(a, b), (a,), (1.,))[1], 1., 1.))

  # EX 27: `g` is applied twice; `depends_on=r` ties each 'z' to its argument.
  def f(x, y):
    def g(r):
      return tex_var(r ** 2, 'z', depends_on=r)
    return g(x) + g(y)
  print_ex(jax2tex(f, 1., 1.))

  # EX 28: a single tuple argument unpacked inside the function.
  def f(x_and_y):
    x, y = x_and_y
    return x * y
  print_ex(jax2tex(f, (1., 1.)))

  # EX 29
  def f(x_and_y):
    x, y = x_and_y
    return tex_var(x, 'x') * tex_var(y, 'y')
  print_ex(jax2tex(f, (1., 1.)))

  # EX 30
  def f(x_and_y):
    x, y = x_and_y
    return tex_var(x, 'x', True) * tex_var(y, 'y', True)
  print_ex(jax2tex(f, (1., 1.)))

  def f(x):
    return np.sin(x)
  # EX 31: same gradient as EX 32, but with `f` wrapped by `bind_names`.
  print_ex(jax2tex(grad(bind_names(f)), 1.))
  # EX 32
  print_ex(jax2tex(grad(f), 1.))
Example #3
0
 def f(x_and_y):
   """Returns the product of the two entries of `x_and_y`, named x and y.

   Args:
     x_and_y: An iterable of exactly two values (the unpacking below
       requires two).

   Returns:
     `tex_var(x, 'x', True) * tex_var(y, 'y', True)`.
   """
   x, y = x_and_y
   # NOTE(review): the third positional argument `True` appears to make the
   # name stand in for the input directly, with no separate `x &= ...`
   # definition line. Confirm against tex_var's signature.
   return tex_var(x, 'x', True) * tex_var(y, 'y', True)
Example #4
0
 def g(r):
   """Returns `r ** 2` wrapped as the tex_var named 'z'.

   `depends_on=r` declares 'z' as depending on the argument `r`. This is
   presumably what makes it render as a function of its argument, e.g.
   `z(x) &= {x}^{2}`; confirm against tex_var's documentation.
   """
   # Keep `r ** 2` (not `r * r`): jax2tex renders the traced operation.
   return tex_var(r ** 2, 'z', depends_on=r)
Example #5
0
 def nngp(params, x1, x2):
   """Returns `apply_fn(params, x1) @ apply_fn(params, x2).T`, named.

   The inputs are first wrapped as the tex_vars 'x^1' and 'x^2'. The final
   product is wrapped as the tex_var '\\mathcal K'. `apply_fn` is a free name
   resolved from the enclosing scope when this function is called.
   """
   x1 = tex_var(x1, 'x^1', True)
   x2 = tex_var(x2, 'x^2', True)
   # assumes apply_fn returns 2-D (batch, features) outputs so `.T` and `@`
   # give a (batch1, batch2) matrix — TODO confirm
   return tex_var(apply_fn(params, x1) @ apply_fn(params, x2).T, '\\mathcal K')
Example #6
0
 def f(params, x):
   """Applies the enclosing-scope `apply_fn` to `x` after naming it 'x'.

   Args:
     params: Passed through unchanged as the first argument of `apply_fn`.
     x: The input; wrapped as the tex_var 'x' before the call.

   Returns:
     Whatever `apply_fn(params, ...)` returns.
   """
   return apply_fn(params, tex_var(x, 'x', True))
Example #7
0
 def L(params, x, y_hat):
   """Returns `-sum(y_hat * log_softmax(f(params, x)))`, named 'L'.

   `f` is a free name resolved from the enclosing scope at call time.
   `y_hat` is first wrapped as the tex_var '\\hat y', presumably so that it
   renders by that name rather than as an anonymous input.
   """
   y_hat = tex_var(y_hat, '\\hat y', True)
   return tex_var(-np.sum(y_hat * jax.nn.log_softmax(f(params, x))), 'L')
Example #8
0
 def E(dr):
   """Returns `4*eps*((sigma/dr)**12 - (sigma/dr)**6)` (Lennard-Jones form).

   sigma and epsilon are constants of value 1, wrapped as tex_vars so they
   render by name.
   """
   idr = (tex_var(1, '\\sigma') / dr)
   # The two powers are computed as separate statements in this order, and
   # idr ** 12 is not written as idr6 ** 2. jax2tex renders the traced
   # computation, so presumably changing either would change the output.
   idr6 = idr ** 6
   idr12 = idr ** 12
   return 4 * tex_var(1, '\\epsilon') * (idr12 - idr6)
Example #9
0
 def fn(W, x):
   """Computes z1 = W @ x, z2 = W @ z1 (both named), and returns z2 @ z2."""
   z1 = tex_var(W @ x, 'z^1')
   # W is applied twice, so W must map z1's space back into itself
   # (square W when x is a vector).
   z2 = tex_var(W @ z1, 'z^2')
   # assumes z2 is 1-D, making this an inner product (a scalar) — confirm
   return z2 @ z2
Example #10
0
 def fn(W, x):
   """Computes z1 = W @ x, z2 = W @ z1 (both named), returns sqrt(z2 @ z2)."""
   z1 = tex_var(W @ x, 'z^1')
   z2 = tex_var(W @ z1, 'z^2')
   # If z2 is 1-D (assumed; confirm) this is its Euclidean norm.
   return np.sqrt(z2 @ z2)
Example #11
0
 def fn(W, x):
   """Names `W @ x` as 'z' and returns its elementwise square `z * z`."""
   z = tex_var(W @ x, 'z')
   return z * z
Example #12
0
 def f_(params, x):
     """Applies `apply` to `x` after naming it 'x' via tex_var.

     `apply` is a free name from the enclosing scope (presumably a stax-style
     apply function — confirm). `params` is passed through unchanged.
     """
     return apply(params, tex_var(x, 'x', True))
Example #13
0
def get_fwd_vs_rev_fns():
    """Builds the function pair used for the forward- vs reverse-mode examples.

    Returns:
      `(f_, jvp_fn_)`, where:
        * `f_(a, b) = a + z`, with `z = b / a` named via tex_var.
        * `jvp_fn_(a, b)` is the tangent output of `jax.jvp` applied to `f_`
          as a function of `a` alone, with tangent 1.0 and `b` held fixed.
    """
    # Both are kept as lambdas on purpose. In these examples lambdas render
    # under the name `f`, so turning them into named defs would presumably
    # change the output.
    f_ = lambda a, b: a + tex_var(b / a, 'z')
    # jvp returns (primal_out, tangent_out); [1] keeps the tangent. The inner
    # lambda's `a` shadows the outer `a` and receives the primal (a,).
    jvp_fn_ = lambda a, b: jvp(lambda a: f_(a, b), (a, ), (1., ))[1]
    return f_, jvp_fn_
Example #14
0
# `get_dep_fns` presumably returns the q(x, y) = z(x) + z(y) function whose
# expected rendering is registered below — confirm against its definition.
f = get_dep_fns()
EXAMPLES += [
    # EX 18: Jax2TexExample(fn, argument specs, expected LaTeX), presumably.
    # The expected string is align-style lines separated by `\\` + newline.
    Jax2TexExample(f, (Scalar, Scalar),
                   ('z(x) &= {x}^{2}\\\\\nz(y) &= {y}^{2}\\\\\n'
                    'q(x,y) &= z(x) + z(y)'))
]

EXAMPLES += [
    # EX 19: a single argument that is a pair of scalars. With no tex_var,
    # the expected output refers to the entries as theta^0 and theta^1.
    Jax2TexExample(lambda x_and_y: x_and_y[0] * x_and_y[1],
                   ((Scalar, Scalar), ), 'f &= \\theta^0\\theta^1'),
    # EX 20: tex_var without a third argument. The expected output has a
    # definition line per name (x &= theta^0, y &= theta^1) before f.
    Jax2TexExample(
        lambda x_and_y: tex_var(x_and_y[0], 'x') * tex_var(x_and_y[1], 'y'),
        ((Scalar, Scalar), ),
        'x &= \\theta^0\\\\\ny &= \\theta^1\\\\\nf &= xy'),
    # EX 21: with True as the third argument, the expected output has no
    # definition lines and uses x and y directly.
    Jax2TexExample(
        lambda x_and_y: tex_var(x_and_y[0], 'x', True) * tex_var(
            x_and_y[1], 'y', True), ((Scalar, Scalar), ), 'f &= xy'),
]


def get_ex22_fn():
    """Returns `q(W, x)`, which names `W @ x` as 'z' and returns `z * z`.

    NOTE(review): the inner name `q` is presumably significant. Expected
    outputs in these examples include the traced function's name (e.g.
    `q(x,y) &= ...`), so keep it unless the expectation changes too.
    """
    def q(W, x):
        z = tex_var(W @ x, 'z')
        return z * z

    return q
# NOTE(review): this repeats the EX 18 registration that appears earlier
# (same `f`, argument specs and expected LaTeX). If both copies end up in one
# module, the example is appended to EXAMPLES twice. Confirm this is an
# alternate copy of the file rather than an intended duplicate.
EXAMPLES += [
    # EX 18: q(x, y) = z(x) + z(y), with z tied to its argument.
    Jax2TexExample(f,
                   (Scalar, Scalar),
                   ('z(x) &= {x}^{2}\\\\\nz(y) &= {y}^{2}\\\\\n'
                    'q(x,y) &= z(x) + z(y)'))
]

# NOTE(review): these EX 19-21 entries repeat an earlier registration with
# different line wrapping. If both copies end up in one module, each example
# is appended twice — confirm this is an alternate copy of the file.
EXAMPLES += [
    # EX 19: untagged pair entries are expected to render as theta^0, theta^1.
    Jax2TexExample(lambda x_and_y: x_and_y[0] * x_and_y[1],
                   ((Scalar, Scalar),),
                   'f &= \\theta^0\\theta^1'),
    # EX 20: tex_var without a third argument adds per-name definition lines.
    Jax2TexExample(lambda x_and_y:
                   tex_var(x_and_y[0], 'x') * tex_var(x_and_y[1], 'y'),
                   ((Scalar, Scalar),),
                   'x &= \\theta^0\\\\\ny &= \\theta^1\\\\\nf &= xy'),
    # EX 21: with True as the third argument the names are used directly.
    Jax2TexExample(
        lambda x_and_y:
        tex_var(x_and_y[0], 'x', True) * tex_var(x_and_y[1], 'y', True),
        ((Scalar, Scalar),),
        'f &= xy'),
]


def get_ex22_fn():
  def q(W, x):
    z = tex_var(W @ x, 'z')
    return z * z