Example #1
  def testFuseIsFaster(self):
    np.random.seed(1)
    rng1, rng2, rng3 = jax.random.split(jax.random.PRNGKey(0), 3)
    dtype = jnp.float32
    tt_a = random_.tensor(rng1, (10, 10, 10, 10), tt_rank=30, dtype=dtype)
    tt_b = random_.tensor(rng2, (10, 10, 10, 10), tt_rank=30, dtype=dtype)
    tt_c = random_.tensor(rng3, (10, 10, 10, 10), tt_rank=1, dtype=dtype)
    func = lambda a, b, c: ops.flat_inner(ops.multiply(a, b), c)
    fused_func = compile.fuse(func)

    func_speed = benchmark(func, tt_a, tt_b, tt_c)
    fused_func_speed = benchmark(fused_func, tt_a, tt_b, tt_c)
    # Check that fused version is at least 10x faster than non-fused.
    self.assertLess(fused_func_speed, 0.1 * func_speed)
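The `benchmark` helper is not part of the snippet. A minimal sketch of what it plausibly does, assuming it returns mean seconds per call and uses `jax.jit` with a warm-up call plus `block_until_ready` so compilation and asynchronous dispatch are excluded from the timing:

import time

import jax

def benchmark(func, *args, n_iters=100):
  # Hypothetical helper: mean wall-clock seconds per call of the jitted func.
  compiled = jax.jit(func)
  # Warm-up call runs tracing and compilation outside the timed region.
  jax.block_until_ready(compiled(*args))
  start = time.perf_counter()
  for _ in range(n_iters):
    result = compiled(*args)
  # Block on the final result so async dispatch does not hide the real cost.
  jax.block_until_ready(result)
  return (time.perf_counter() - start) / n_iters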
Example #2
    def testToAndFromDeltas(self):
        rng1, rng2 = jax.random.split(jax.random.PRNGKey(0))
        dtype = jnp.float32
        what = random_.tensor(rng1, (4, 5, 6), tt_rank=4, dtype=dtype)
        where = random_.tensor(rng2, (4, 5, 6), tt_rank=3, dtype=dtype)
        projected = autodiff.project(what, where)

        deltas = riemannian.tangent_to_deltas(projected)
        reconstructed_projected = riemannian.deltas_to_tangent(deltas, where)
        # Tangent space element norm can be computed from deltas norm.
        projected_normsq_desired = ops.flat_inner(projected, projected)
        projected_normsq_actual = sum([jnp.sum(c * c) for c in deltas])
        self.assertAllClose(ops.full(projected),
                            ops.full(reconstructed_projected))
        self.assertAllClose(projected_normsq_desired, projected_normsq_actual)
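The final assertion relies on a standard property of this representation: `tangent_to_deltas` returns the cores of an orthogonalized representation of the tangent-space element, so its squared norm decomposes core-wise, <p, p> = sum_k <delta_k, delta_k>, with each term a plain Frobenius inner product of a delta core. That is exactly what the `projected_normsq_actual` line computes.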
Example #3
    def testHessianVectorProduct(self):
        rng1, rng2, rng3 = jax.random.split(jax.random.PRNGKey(0), 3)
        dtype = jnp.float32
        shape = (5, 5, 5)
        A = random_.matrix(rng1, (shape, shape), dtype=dtype)
        AT = ops.transpose(A)
        A_plus_AT = A + AT
        x = random_.matrix(rng2, (shape, None), dtype=dtype)
        vec = random_.matrix(rng3, (shape, None), dtype=dtype)
        proj_vec = autodiff.project(vec, x)

        func = lambda x: ops.flat_inner(x, A @ x)
        desired = autodiff.project(A_plus_AT @ proj_vec, x)
        desired = ops.full(desired)
        actual = ops.full(autodiff.hessian_vector_product(func)(x, vec))
        self.assertAllClose(desired, actual, rtol=1e-4)
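For reference, the math behind `desired`: with f(x) = <x, A x>, the Euclidean gradient is (A + A^T) x, so the Euclidean Hessian is the constant operator A + A^T. The Hessian-vector product in the tangent space at x is then obtained by projecting vec onto that tangent space, applying A + A^T, and projecting back, i.e. P_x((A + A^T) P_x vec), which is exactly what the code builds from `proj_vec` and `A_plus_AT`.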
Example #4
  def testFuse(self, op_type):
    np.random.seed(1)
    rng1, rng2, rng3 = jax.random.split(jax.random.PRNGKey(0), 3)
    dtype = jnp.float32
    tt_a = random_.tensor(rng1, (1, 2, 3, 4), tt_rank=2, dtype=dtype)
    tt_b = random_.tensor(rng2, (1, 2, 3, 4), tt_rank=[1, 1, 4, 3, 1], dtype=dtype)
    tt_c = random_.tensor(rng3, (1, 2, 3, 4), tt_rank=3, dtype=dtype)
    if op_type == 'a*b*c':
      func = lambda a, b, c: ops.multiply(ops.multiply(a, b), c)
      fused_func = compile.fuse(func)
      res_actual = ops.full(fused_func(tt_a, tt_b, tt_c))
      res_desired = ops.full(func(tt_a, tt_b, tt_c))
    elif op_type == '<a*b, c>':
      func = lambda a, b, c: ops.flat_inner(ops.multiply(a, b), c)
      fused_func = compile.fuse(func)
      res_actual = fused_func(tt_a, tt_b, tt_c)
      res_desired = func(tt_a, tt_b, tt_c)

    # Runs for both op_type branches: the fused result must match the plain,
    # non-fused reference.
    self.assertAllClose(res_actual, res_desired)
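`op_type` is supplied by the test framework; a typical driver, assuming an absl-style parameterized suite (the decorator is not shown in the snippet):

from absl.testing import parameterized

class FuseTest(parameterized.TestCase):

  @parameterized.parameters('a*b*c', '<a*b, c>')
  def testFuse(self, op_type):
    ...  # body as in Example #4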
Example #5
# `what` is a TT-tensor captured from the enclosing scope; _f is the linear
# functional x -> <what, x>.
def _f(x):
    return ops.flat_inner(what, x)
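A possible usage sketch for `_f`: the Riemannian gradient of the linear functional <what, .> at x is the projection of `what` onto the tangent space at x. This assumes `autodiff.grad` curries the same way as `hessian_vector_product` in Example #3, and the import layout below is likewise an assumption:

import jax
import jax.numpy as jnp
from ttax import autodiff, ops, random_  # assumed import layout

rng1, rng2 = jax.random.split(jax.random.PRNGKey(0))
what = random_.tensor(rng1, (4, 5, 6), tt_rank=4, dtype=jnp.float32)
x = random_.tensor(rng2, (4, 5, 6), tt_rank=3, dtype=jnp.float32)

def _f(z):
    return ops.flat_inner(what, z)

# Gradient of <what, .> restricted to the manifold: the projection of `what`
# onto the tangent space at x.
riemannian_grad = autodiff.grad(_f)(x)  # assumed API, by analogy with Example #3
projected = autodiff.project(what, x)
print(jnp.allclose(ops.full(riemannian_grad), ops.full(projected), rtol=1e-4))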
Example #6
# Quadratic functional f(x) = 0.5 * <x, x>, i.e. half the squared norm of x.
def f(x):
    return 0.5 * ops.flat_inner(x, x)
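Similarly for `f`: the Euclidean gradient of 0.5 * <x, x> is x itself, and x lies in its own tangent space, so the Riemannian gradient at x should reproduce x. A sketch under the same `autodiff.grad` assumption as above:

import jax
import jax.numpy as jnp
from ttax import autodiff, ops, random_  # assumed import layout

x = random_.tensor(jax.random.PRNGKey(0), (4, 5, 6), tt_rank=3,
                   dtype=jnp.float32)

def f(z):
    return 0.5 * ops.flat_inner(z, z)

g = autodiff.grad(f)(x)  # assumed API, by analogy with Example #3
print(jnp.allclose(ops.full(g), ops.full(x), rtol=1e-4))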