Example #1
    def testNormGrad(self):
        dtype = jnp.float32
        rng = jax.random.PRNGKey(42)
        tensor = random_.tensor(rng, (2, 1, 3, 4),
                                tt_rank=[1, 2, 4, 3, 1],
                                dtype=dtype)

        def f(x):
            return 0.5 * ops.flat_inner(x, x)

        grad = autodiff.grad(f)
        actual = grad(tensor)
        # The gradient of 0.5 * <x, x> is x itself.
        desired = tensor
        self.assertAllClose(ops.full(actual), ops.full(desired), rtol=1e-4)
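The identity this test relies on: the gradient of f(x) = 0.5 * <x, x> is x itself. A minimal dense-array sketch of the same fact, using plain jax.numpy instead of TT objects (illustrative only, not part of the test suite):

    import jax
    import jax.numpy as jnp

    x = jax.random.normal(jax.random.PRNGKey(0), (24,))
    f = lambda x: 0.5 * jnp.vdot(x, x)
    # d/dx [0.5 * <x, x>] = x, so jax.grad recovers the input.
    assert jnp.allclose(jax.grad(f)(x), x)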
Example #2
    def testToAndFromDeltas(self):
        rng1, rng2 = jax.random.split(jax.random.PRNGKey(0))
        dtype = jnp.float32
        what = random_.tensor(rng1, (4, 5, 6), tt_rank=4, dtype=dtype)
        where = random_.tensor(rng2, (4, 5, 6), tt_rank=3, dtype=dtype)
        projected = autodiff.project(what, where)

        deltas = riemannian.tangent_to_deltas(projected)
        reconstructed_projected = riemannian.deltas_to_tangent(deltas, where)
        # The deltas are mutually orthogonal, so the squared norm of a
        # tangent-space element is the sum of its deltas' squared norms.
        projected_normsq_desired = ops.flat_inner(projected, projected)
        projected_normsq_actual = sum([jnp.sum(c * c) for c in deltas])
        self.assertAllClose(ops.full(projected),
                            ops.full(reconstructed_projected))
        self.assertAllClose(projected_normsq_desired, projected_normsq_actual)
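The norm identity holds because tangent_to_deltas yields mutually orthogonal components. A dense sketch of the same fact with an orthonormal basis (plain jax.numpy, no TT structure assumed):

    import jax
    import jax.numpy as jnp

    rng1, rng2 = jax.random.split(jax.random.PRNGKey(0))
    # Columns of q form an orthonormal basis of R^6.
    q, _ = jnp.linalg.qr(jax.random.normal(rng1, (6, 6)))
    v = jax.random.normal(rng2, (6,))
    # Split v into mutually orthogonal components ("deltas").
    deltas = [jnp.vdot(q[:, i], v) * q[:, i] for i in range(6)]
    # The squared norm of v is the sum of the components' squared norms.
    normsq_from_deltas = sum(jnp.sum(d * d) for d in deltas)
    assert jnp.allclose(jnp.vdot(v, v), normsq_from_deltas, rtol=1e-4)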
Example #3
    def testHessianVectorProduct(self):
        rng1, rng2, rng3 = jax.random.split(jax.random.PRNGKey(0), 3)
        dtype = jnp.float32
        shape = (5, 5, 5)
        A = random_.matrix(rng1, (shape, shape), dtype=dtype)
        AT = ops.transpose(A)
        A_plus_AT = A + AT
        x = random_.matrix(rng2, (shape, None), dtype=dtype)
        vec = random_.matrix(rng3, (shape, None), dtype=dtype)
        proj_vec = autodiff.project(vec, x)

        func = lambda x: ops.flat_inner(x, A @ x)
        # The Hessian of <x, A x> is A + A^T, so the exact Riemannian HVP
        # is the projection of (A + A^T) applied to the projected vector.
        desired = autodiff.project(A_plus_AT @ proj_vec, x)
        desired = ops.full(desired)
        actual = ops.full(autodiff.hessian_vector_product(func)(x, vec))
        self.assertAllClose(desired, actual, rtol=1e-4)
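The reference value uses the fact that the Hessian of f(x) = <x, A x> is A + A^T. A dense jax sketch of the same Hessian-vector product, computed forward-over-reverse (illustrative, independent of the TT machinery):

    import jax
    import jax.numpy as jnp

    rng1, rng2, rng3 = jax.random.split(jax.random.PRNGKey(0), 3)
    A = jax.random.normal(rng1, (5, 5))
    x = jax.random.normal(rng2, (5,))
    v = jax.random.normal(rng3, (5,))
    f = lambda x: jnp.vdot(x, A @ x)
    # Forward-over-reverse HVP: differentiate grad(f) along direction v.
    hvp = jax.jvp(jax.grad(f), (x,), (v,))[1]
    assert jnp.allclose(hvp, (A + A.T) @ v, rtol=1e-4)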
Example #4
    def testDoubleProjection(self):
        """Compare P grad f(x) against P grad (<x, stop_grad(P grad f(x))>)."""
        dtype = jnp.float32
        # Use independent PRNG keys for the two random TT objects.
        rng1, rng2 = jax.random.split(jax.random.PRNGKey(42))
        vector = random_.matrix(rng1, ((2, 1, 3, 4), (1, 1, 1, 1)),
                                tt_rank=[1, 2, 4, 3, 1],
                                dtype=dtype)
        matrix = random_.matrix(rng2, ((2, 1, 3, 4), (2, 1, 3, 4)),
                                tt_rank=[1, 2, 4, 3, 1],
                                dtype=dtype)

        project = autodiff.project(matrix @ vector, vector)
        double_project = autodiff.project(project, vector)
        self.assertAllClose(ops.full(project),
                            ops.full(double_project),
                            rtol=1e-4)
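Projection onto a subspace is idempotent, which is exactly what the test asserts. A dense sketch with an explicit orthogonal projector (plain jax.numpy, not the library API):

    import jax
    import jax.numpy as jnp

    rng1, rng2 = jax.random.split(jax.random.PRNGKey(42))
    # Orthogonal projector onto a random 3-dimensional subspace of R^6.
    q, _ = jnp.linalg.qr(jax.random.normal(rng1, (6, 3)))
    project = lambda v: q @ (q.T @ v)
    v = jax.random.normal(rng2, (6,))
    # Projecting twice gives the same result as projecting once.
    assert jnp.allclose(project(project(v)), project(v), atol=1e-5)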
Example #5
  def testFuse(self, op_type):
    np.random.seed(1)
    rng1, rng2, rng3 = jax.random.split(jax.random.PRNGKey(0), 3)
    dtype = jnp.float32
    left_shape = (2, 3, 4)
    sum_shape = (4, 3, 5)
    right_shape = (4, 4, 4)
    tt_a = random_.matrix(rng1, (left_shape, sum_shape), tt_rank=3, dtype=dtype)
    tt_b = random_.matrix(rng2, (left_shape, sum_shape), tt_rank=3, dtype=dtype)
    tt_c = random_.matrix(rng3, (sum_shape, right_shape), tt_rank=[1, 4, 3, 1],
                          dtype=dtype)
    if op_type == '(a*b) @ c':
      def func(a, b, c):
        return (a * b) @ c
      fused_func = compile.fuse(func)
      res_desired = ops.full(func(tt_a, tt_b, tt_c))
      res_actual = ops.full(fused_func(tt_a, tt_b, tt_c))
      self.assertAllClose(res_actual, res_desired, rtol=1e-4)
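What fusion buys here, sketched on dense arrays: (a * b) @ c can be evaluated as one contraction without materializing a * b. The snippet below is only a dense analogy; compile.fuse is meant to do the corresponding rewrite on TT cores:

    import jax
    import jax.numpy as jnp

    rngs = jax.random.split(jax.random.PRNGKey(0), 3)
    a = jax.random.normal(rngs[0], (24, 60))
    b = jax.random.normal(rngs[1], (24, 60))
    c = jax.random.normal(rngs[2], (60, 64))
    # One einsum computes (a * b) @ c without forming a * b explicitly.
    fused = jnp.einsum('ij,ij,jk->ik', a, b, c)
    assert jnp.allclose((a * b) @ c, fused, rtol=1e-4, atol=1e-4)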
Example #6
  def testFuse(self, op_type):
    np.random.seed(1)
    rng1, rng2, rng3 = jax.random.split(jax.random.PRNGKey(0), 3)
    dtype = jnp.float32
    tt_a = random_.tensor(rng1, (1, 2, 3, 4), tt_rank=2, dtype=dtype)
    tt_b = random_.tensor(rng2, (1, 2, 3, 4), tt_rank=[1, 1, 4, 3, 1],
                          dtype=dtype)
    tt_c = random_.tensor(rng3, (1, 2, 3, 4), tt_rank=3, dtype=dtype)
    if op_type == 'a*b*c':
      func = lambda a, b, c: ops.multiply(ops.multiply(a, b), c)
      fused_func = compile.fuse(func)
      res_desired = ops.full(func(tt_a, tt_b, tt_c))
      res_actual = ops.full(fused_func(tt_a, tt_b, tt_c))
    elif op_type == '<a*b, c>':
      func = lambda a, b, c: ops.flat_inner(ops.multiply(a, b), c)
      fused_func = compile.fuse(func)
      res_desired = func(tt_a, tt_b, tt_c)
      res_actual = fused_func(tt_a, tt_b, tt_c)

    # Fusing must not change the value in either branch.
    self.assertAllClose(res_actual, res_desired)
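A dense analogue of the '<a*b, c>' branch: the inner product of an elementwise product with a third tensor is a single reduction over all entries (plain jax.numpy, shapes chosen to mirror the test):

    import jax
    import jax.numpy as jnp

    rngs = jax.random.split(jax.random.PRNGKey(0), 3)
    a = jax.random.normal(rngs[0], (1, 2, 3, 4))
    b = jax.random.normal(rngs[1], (1, 2, 3, 4))
    c = jax.random.normal(rngs[2], (1, 2, 3, 4))
    # <a * b, c> collapses to one sum over all entries of a * b * c.
    assert jnp.allclose(jnp.vdot(a * b, c), jnp.sum(a * b * c), rtol=1e-4)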
Example #7
  def testOrthogonalizeRightToLeft(self):
    dtype = jnp.float32
    rng = jax.random.PRNGKey(0)
    shape = (2, 4, 3, 3)
    tt_ranks = (1, 5, 2, 17, 1)
    updated_tt_ranks = (1, 5, 2, 3, 1)
    tens = random_.tensor(rng, shape, tt_rank=tt_ranks, dtype=dtype)
    orthogonal = decompositions.orthogonalize(tens, left_to_right=False)

    self.assertAllClose(ops.full(tens), ops.full(orthogonal), atol=1e-5,
                        rtol=1e-5)
    self.assertArraysEqual(updated_tt_ranks, orthogonal.tt_ranks)
    # Check that all TT-cores except the first are right-orthogonal.
    for core_idx in range(1, 4):
      core = orthogonal.tt_cores[core_idx]
      core = jnp.reshape(core, (updated_tt_ranks[core_idx], shape[core_idx] *
                                updated_tt_ranks[core_idx + 1]))
      should_be_eye = core @ core.T
      self.assertAllClose(np.eye(updated_tt_ranks[core_idx]), should_be_eye)
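The loop checks right-orthogonality: each core, reshaped to (r_i, n_i * r_{i+1}), should have orthonormal rows. A dense sketch of that property (jax.numpy only; QR is just a way to manufacture a matrix with orthonormal rows):

    import jax
    import jax.numpy as jnp

    # Build a 5 x 12 matrix with orthonormal rows via QR.
    q, _ = jnp.linalg.qr(jax.random.normal(jax.random.PRNGKey(0), (12, 5)))
    m = q.T
    # Orthonormal rows mean m @ m.T is the identity.
    assert jnp.allclose(m @ m.T, jnp.eye(5), atol=1e-5)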
Example #8
  def testRound2d(self):
    dtype = jnp.float32
    rank = 5
    np.random.seed(0)
    x = np.random.randn(10, 20).astype(dtype)
    u, s, vh = np.linalg.svd(x, full_matrices=False)
    # Build an exact rank-10 TT representation of x from its SVD factors.
    core_1 = (u @ np.diag(s)).reshape(1, 10, 10)
    core_2 = vh.reshape(10, 20, 1)
    tt = TT((core_1, core_2))
    # Rounding to rank 5 should reproduce the rank-5 truncated SVD.
    truncated_x = u[:, :rank] @ np.diag(s[:rank]) @ vh[:rank, :]
    rounded = decompositions.round(tt, rank)
    self.assertAllClose(truncated_x, ops.full(rounded), rtol=1e-5, atol=1e-5)
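For a 2D tensor the TT format is just a two-factor matrix decomposition, so rounding to rank 5 must coincide with the rank-5 truncated SVD (optimal by Eckart-Young). A dense sketch of how the two cores relate to the SVD factors (jax.numpy, mirroring the shapes above):

    import jax
    import jax.numpy as jnp

    x = jax.random.normal(jax.random.PRNGKey(0), (10, 20))
    u, s, vh = jnp.linalg.svd(x, full_matrices=False)
    # The two TT-cores are the SVD factors with rank-1 edge axes added.
    core_1 = (u @ jnp.diag(s)).reshape(1, 10, 10)
    core_2 = vh.reshape(10, 20, 1)
    # Contracting the cores over the shared rank axis recovers x.
    full = jnp.einsum('aib,bjc->ij', core_1, core_2)
    assert jnp.allclose(full, x, atol=1e-4)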