def testNormGrad(self):
    """Check that `autodiff.grad` of 0.5 * <x, x> evaluated at x returns x."""
    key = jax.random.PRNGKey(42)
    point = random_.tensor(
        key, (2, 1, 3, 4), tt_rank=[1, 2, 4, 3, 1], dtype=jnp.float32)

    def half_squared_norm(x):
        return 0.5 * ops.flat_inner(x, x)

    computed = autodiff.grad(half_squared_norm)(point)
    # The expected value is the evaluation point itself.
    self.assertAllClose(ops.full(computed), ops.full(point), rtol=1e-4)
def testToAndFromDeltas(self):
    """Round-trip a projected tensor through its deltas representation.

    Verifies that `deltas_to_tangent(tangent_to_deltas(p), where)` reproduces
    `p`, and that the squared norm of `p` equals the summed squares of the
    deltas.
    """
    key_what, key_where = jax.random.split(jax.random.PRNGKey(0))
    dtype = jnp.float32
    what = random_.tensor(key_what, (4, 5, 6), tt_rank=4, dtype=dtype)
    where = random_.tensor(key_where, (4, 5, 6), tt_rank=3, dtype=dtype)
    tangent = autodiff.project(what, where)

    deltas = riemannian.tangent_to_deltas(tangent)
    restored = riemannian.deltas_to_tangent(deltas, where)
    self.assertAllClose(ops.full(tangent), ops.full(restored))

    # Tangent space element norm can be computed from deltas norm.
    normsq_desired = ops.flat_inner(tangent, tangent)
    normsq_actual = sum(jnp.sum(delta * delta) for delta in deltas)
    self.assertAllClose(normsq_desired, normsq_actual)
def testHessianVectorProduct(self):
    """Compare `hessian_vector_product` of <x, A x> with a closed form.

    The expected value is `project((A + A^T) @ project(vec, x), x)`.
    """
    key_a, key_x, key_vec = jax.random.split(jax.random.PRNGKey(0), 3)
    dtype = jnp.float32
    mode_sizes = (5, 5, 5)
    operator = random_.matrix(key_a, (mode_sizes, mode_sizes), dtype=dtype)
    operator_t = ops.transpose(operator)
    symmetrized = operator + operator_t
    x = random_.matrix(key_x, (mode_sizes, None), dtype=dtype)
    vec = random_.matrix(key_vec, (mode_sizes, None), dtype=dtype)
    projected_vec = autodiff.project(vec, x)

    def quadratic_form(point):
        return ops.flat_inner(point, operator @ point)

    expected = ops.full(autodiff.project(symmetrized @ projected_vec, x))
    computed = ops.full(
        autodiff.hessian_vector_product(quadratic_form)(x, vec))
    self.assertAllClose(expected, computed, rtol=1e-4)
def testDoubleProjection(self):
    """Check that applying `autodiff.project` twice matches applying it once.

    `matrix @ vector` is projected with respect to `vector`, and the result
    is projected again; the dense forms of both results must agree.

    Original author's description: compare P grad f(x) against
    P grad (<x, stop_grad(P grad f(x))>).
    """
    dtype = jnp.float32
    # Draw the vector and the matrix from independent keys: passing the same
    # PRNG key to both calls would make the two samples correlated.
    vector_key, matrix_key = jax.random.split(jax.random.PRNGKey(42))
    vector = random_.matrix(vector_key, ((2, 1, 3, 4), (1, 1, 1, 1)),
                            tt_rank=[1, 2, 4, 3, 1], dtype=dtype)
    matrix = random_.matrix(matrix_key, ((2, 1, 3, 4), (2, 1, 3, 4)),
                            tt_rank=[1, 2, 4, 3, 1], dtype=dtype)
    project = autodiff.project(matrix @ vector, vector)
    double_project = autodiff.project(project, vector)
    self.assertAllClose(ops.full(project), ops.full(double_project),
                        rtol=1e-4)
def testFuse(self, op_type): np.random.seed(1) rng1, rng2, rng3 = jax.random.split(jax.random.PRNGKey(0), 3) dtype = jnp.float32 left_shape = (2, 3, 4) sum_shape = (4, 3, 5) right_shape = (4, 4, 4) tt_a = random_.matrix(rng1, (left_shape, sum_shape), tt_rank=3, dtype=dtype) tt_b = random_.matrix(rng2, (left_shape, sum_shape), tt_rank=3, dtype=dtype) tt_c = random_.matrix(rng3, (sum_shape, right_shape), tt_rank=[1, 4, 3, 1], dtype=dtype) if op_type == '(a*b) @ c': def func(a, b, c): return (a * b) @ c fused_func = compile.fuse(func) res_actual = ops.full(func(tt_a, tt_b, tt_c)) res_desired = ops.full(fused_func(tt_a, tt_b, tt_c)) self.assertAllClose(res_actual, res_desired, rtol=1e-4)
def testFuse(self, op_type): np.random.seed(1) rng1, rng2, rng3 = jax.random.split(jax.random.PRNGKey(0), 3) dtype = jnp.float32 tt_a = random_.tensor(rng1, (1, 2, 3, 4), tt_rank=2, dtype=dtype) tt_b = random_.tensor(rng2, (1, 2, 3, 4), tt_rank=[1, 1, 4, 3, 1], dtype=dtype) tt_c = random_.tensor(rng3, (1, 2, 3, 4), tt_rank=3, dtype=dtype) if op_type == 'a*b*c': func = lambda a, b, c: ops.multiply(ops.multiply(a, b), c) fused_func = compile.fuse(func) res_actual = ops.full(func(tt_a, tt_b, tt_c)) res_desired = ops.full(fused_func(tt_a, tt_b, tt_c)) elif op_type == '<a*b, c>': func = lambda a, b, c: ops.flat_inner(ops.multiply(a, b), c) fused_func = compile.fuse(func) res_actual = func(tt_a, tt_b, tt_c) res_desired = fused_func(tt_a, tt_b, tt_c) self.assertAllClose(res_actual, res_desired)
def testOrthogonalizeRightToLeft(self):
    """Check right-to-left `decompositions.orthogonalize`.

    Verifies that the dense tensor is unchanged, that the TT-ranks become
    (1, 5, 2, 3, 1), and that each core from index 1 on, reshaped to
    (left_rank, mode_size * right_rank), has orthonormal rows.
    """
    shape = (2, 4, 3, 3)
    initial_ranks = (1, 5, 2, 17, 1)
    updated_tt_ranks = (1, 5, 2, 3, 1)
    tens = random_.tensor(jax.random.PRNGKey(0), shape,
                          tt_rank=initial_ranks, dtype=jnp.float32)
    orthogonal = decompositions.orthogonalize(tens, left_to_right=False)

    self.assertAllClose(ops.full(tens), ops.full(orthogonal),
                        atol=1e-5, rtol=1e-5)
    self.assertArraysEqual(updated_tt_ranks, orthogonal.tt_ranks)

    # Check that the TT-cores are orthogonal.
    for idx in range(1, len(shape)):
        left_rank = updated_tt_ranks[idx]
        right_rank = updated_tt_ranks[idx + 1]
        unfolded = jnp.reshape(orthogonal.tt_cores[idx],
                               (left_rank, shape[idx] * right_rank))
        gram = unfolded @ unfolded.T
        self.assertAllClose(np.eye(left_rank), gram)
def testRound2d(self):
    """Check `decompositions.round` on a two-core TT against truncated SVD.

    A random 10x20 matrix is stored as two TT-cores built from its SVD. The
    rounded result must match the SVD truncated to the `rank` largest
    singular values.
    """
    dtype = jnp.float32
    rank = 5
    np.random.seed(0)
    x = np.random.randn(10, 20).astype(dtype)
    u, s, v = np.linalg.svd(x, full_matrices=False)
    # Two cores whose contraction reproduces x: (u @ diag(s)) and v.
    core_1 = (u @ np.diag(s)).reshape(1, 10, 10)
    core_2 = v.reshape(10, 20, 1)
    tt = TT((core_1, core_2))
    truncated_x = u[:, :rank] @ np.diag(s[:rank]) @ v[:rank, :]
    # Pass `rank` instead of a duplicated literal so that the expected value
    # and the call under test cannot drift apart.
    rounded = decompositions.round(tt, rank)
    self.assertAllClose(truncated_x, ops.full(rounded), rtol=1e-5, atol=1e-5)