def test_rand(self):
    # Sample a point and check that X^H X matches multieye(self.k, self.n)
    # (presumably k stacked n x n identities), that a second sample differs
    # from the first, and that every entry of the first sample has a
    # nonzero imaginary part (np.iscomplex is an element-wise value test).
    first = self.man.rand()
    gram = multiprod(multihconj(first), first)
    np_testing.assert_allclose(gram, multieye(self.k, self.n), atol=1e-10)

    second = self.man.rand()
    assert la.norm(first - second) > 1e-6
    assert np.iscomplex(first).all()
 def test_randvec(self):
     # Draw a base point and two random vectors at it.  The first vector must
     # make multisym(X^H U) vanish (the tangent-space check used here), the
     # two draws must differ, and every entry of the first vector must have a
     # nonzero imaginary part.
     point = self.man.rand()
     tangent = self.man.randvec(point)
     sym_part = multisym(multiprod(multihconj(point), tangent))
     expected = np.zeros((self.k, self.n, self.n))
     np_testing.assert_allclose(sym_part, expected, atol=1e-10)

     other = self.man.randvec(point)
     assert la.norm(tangent - other) > 1e-6
     assert np.iscomplex(tangent).all()
 def test_random_point(self):
     # A sampled point must satisfy X^H X == I (n x n), two independent
     # samples must differ, and every entry of the first sample must have a
     # nonzero imaginary part.
     sample = self.manifold.random_point()
     gram = multihconj(sample) @ sample
     np_testing.assert_allclose(gram, np.eye(self.n), atol=1e-10)

     another = self.manifold.random_point()
     assert np.linalg.norm(sample - another) > 1e-6
     assert np.iscomplex(sample).all()
 def test_randvec(self):
     # A random vector G at X should lie in the horizontal space of the
     # complex Stiefel manifold, checked here as X^H G == 0 (n x n).  Two
     # draws at the same point must differ, and every entry of G must have a
     # nonzero imaginary part.
     base = self.man.rand()
     vec = self.man.randvec(base)
     overlap = multiprod(multihconj(base), vec)
     np_testing.assert_allclose(overlap, np.zeros((self.n, self.n)),
                                atol=1e-10)

     vec_again = self.man.randvec(base)
     assert la.norm(vec - vec_again) > 1e-6
     assert np.iscomplex(vec).all()
Exemple #5
0
def test_unimplemented_falseyness():
    """Run check_grads on a branch over ``np.iscomplex`` with its VJP removed.

    Any entry registered for ``np.iscomplex`` in ``primitive_vjps`` is taken
    out for the duration of the checks and put back afterwards, even if a
    check fails.
    """
    @contextmanager
    def remove_grad_definitions(fun):
        # pop(..., None) tolerates functions with no registered entry.
        vjpmaker = primitive_vjps.pop(fun, None)
        try:
            yield
        finally:
            # Restore even when the body raises; otherwise the removal of a
            # global registry entry would leak into every later test.
            if vjpmaker is not None:
                primitive_vjps[fun] = vjpmaker

    with remove_grad_definitions(np.iscomplex):
        def fun(x):
            return np.real(x**2 if np.iscomplex(x) else np.sum(x))

        check_grads(fun)(5.)
        check_grads(fun)(2. + 1j)
Exemple #6
0
def test_unimplemented_falseyness():
    """Run check_grads on a branch over ``np.iscomplex`` with its VJP removed.

    Any entry registered for ``np.iscomplex`` in ``primitive_vjps`` is taken
    out for the duration of the checks and put back afterwards, even if a
    check fails.
    """
    @contextmanager
    def remove_grad_definitions(fun):
        # pop(..., None) tolerates functions with no registered entry.
        vjpmaker = primitive_vjps.pop(fun, None)
        try:
            yield
        finally:
            # Restore even when the body raises; otherwise the removal of a
            # global registry entry would leak into every later test.
            if vjpmaker is not None:
                primitive_vjps[fun] = vjpmaker

    with remove_grad_definitions(np.iscomplex):
        def fun(x):
            return np.real(x**2 if np.iscomplex(x) else np.sum(x))

        check_grads(fun)(5.)
        check_grads(fun)(2. + 1j)
 def test_random_tangent_vector(self):
     # Two vectors drawn at the same random point: the first must satisfy
     # multisym(X^H U) == 0 (the tangent-space condition checked here), the
     # two must differ, and every entry of the first must have a nonzero
     # imaginary part.
     base_point = self.manifold.random_point()
     first_vec = self.manifold.random_tangent_vector(base_point)
     residual = multisym(multihconj(base_point) @ first_vec)
     np_testing.assert_allclose(
         residual, np.zeros((self.k, self.n, self.n)), atol=1e-10
     )

     second_vec = self.manifold.random_tangent_vector(base_point)
     assert np.linalg.norm(first_vec - second_vec) > 1e-6
     assert np.iscomplex(first_vec).all()
def test_unimplemented_falseyness():
    """Run check_grads on a branch over ``np.iscomplex`` with its VJP removed.

    The ``primitive_vjps`` entry for ``np.iscomplex`` (if any) is removed
    before the checks and restored in a ``finally`` clause, so a failing
    check cannot leave the global registry modified.
    """
    def remove_grad_definitions(fun):
        # Returns the removed entry, or None if fun had none registered.
        return primitive_vjps.pop(fun, None)

    def restore_grad_definitions(fun, grads):
        if grads is not None:
            primitive_vjps[fun] = grads

    def fun(x):
        return np.real(x**2 if np.iscomplex(x) else np.sum(x))

    grad_defs = remove_grad_definitions(np.iscomplex)
    try:
        check_grads(fun)(5.)
        check_grads(fun)(2. + 1j)
    finally:
        restore_grad_definitions(np.iscomplex, grad_defs)
def test_unimplemented_falseyness():
    """Run check_grads on a branch over ``np.iscomplex`` with its VJPs cleared.

    ``np.iscomplex.vjps`` and ``np.iscomplex.zero_vjps`` are replaced by an
    empty dict and an empty set for the duration of the checks and restored
    in a ``finally`` clause, so a failing check cannot leave them cleared.
    """
    def remove_grad_definitions(fun):
        # Swap in empty definitions and hand back the originals.
        grads, zero_vjps = fun.vjps, fun.zero_vjps
        fun.vjps, fun.zero_vjps = {}, set()
        return grads, zero_vjps

    def restore_grad_definitions(fun, grad_defs):
        fun.vjps, fun.zero_vjps = grad_defs

    def fun(x):
        return x**2 if np.iscomplex(x) else np.sum(x)

    grad_defs = remove_grad_definitions(np.iscomplex)
    try:
        check_grads(fun, 5.)
        check_grads(fun, 2. + 1j)
    finally:
        restore_grad_definitions(np.iscomplex, grad_defs)
Exemple #10
0
def test_unimplemented_falseyness():
    """Run check_grads on a branch over ``np.iscomplex`` with its VJPs cleared.

    ``np.iscomplex.vjps`` and ``np.iscomplex.zero_vjps`` are replaced by an
    empty dict and an empty set for the duration of the checks and restored
    in a ``finally`` clause, so a failing check cannot leave them cleared.
    """
    def remove_grad_definitions(fun):
        # Swap in empty definitions and hand back the originals.
        grads, zero_vjps = fun.vjps, fun.zero_vjps
        fun.vjps, fun.zero_vjps = {}, set()
        return grads, zero_vjps

    def restore_grad_definitions(fun, grad_defs):
        fun.vjps, fun.zero_vjps = grad_defs

    def fun(x):
        return x**2 if np.iscomplex(x) else np.sum(x)

    grad_defs = remove_grad_definitions(np.iscomplex)
    try:
        check_grads(fun, 5.)
        check_grads(fun, 2. + 1j)
    finally:
        restore_grad_definitions(np.iscomplex, grad_defs)
Exemple #11
0
def unary_nd(f, x, eps=1e-4):
    """Numerically differentiate ``f`` at ``x`` with central differences.

    Containers are handled recursively. Each element of an ndarray, and each
    entry of a tuple, dict or list, is perturbed on its own through
    ``indexed_function``. The per-element results are reassembled into the
    same kind of container.

    For a complex-typed scalar, the result is the derivative along the real
    axis minus ``1j`` times the derivative along the imaginary axis. The test
    is on the *type* (``np.iscomplexobj``), so ``2+0j`` is also treated as
    complex.

    Args:
        f: function whose derivative at ``x`` is estimated.
        x: ndarray, tuple, dict, list or scalar at which to differentiate.
        eps: full width of the central-difference step. It is used at every
            level of the recursion.

    Returns:
        An object structured like ``x`` holding the numerical derivative.
    """
    if isinstance(x, np.ndarray):
        # A complex accumulator is required for complex input; assigning a
        # complex value into a float array would drop its imaginary part.
        dtype = complex if np.iscomplexobj(x) else float
        nd_grad = np.zeros(x.shape, dtype=dtype)
        for dims in it.product(*map(range, x.shape)):
            nd_grad[dims] = unary_nd(indexed_function(f, x, dims), x[dims],
                                     eps)
        return nd_grad
    elif isinstance(x, tuple):
        return tuple([unary_nd(indexed_function(f, tuple(x), i), x[i], eps)
                      for i in range(len(x))])
    elif isinstance(x, dict):
        # .items() works on both Python 2 and 3; .iteritems() is Py2-only.
        return {k: unary_nd(indexed_function(f, x, k), v, eps)
                for k, v in x.items()}
    elif isinstance(x, list):
        return [unary_nd(indexed_function(f, x, i), v, eps)
                for i, v in enumerate(x)]
    elif np.iscomplexobj(x):
        # Type-based test: a complex scalar with zero imaginary part still
        # needs the imaginary-direction difference.
        d_real = (f(x + eps / 2) - f(x - eps / 2)) / eps
        d_imag = (f(x + 1j * eps / 2) - f(x - 1j * eps / 2)) / eps
        return type(safe_type(x))(d_real - 1j * d_imag)
    else:
        return type(safe_type(x))((f(x + eps / 2) - f(x - eps / 2)) / eps)
Exemple #12
0
def test_falseyness():
    """Run check_grads on a function that branches on ``np.iscomplex(x)``."""
    def branchy(x):
        # Square inputs with a nonzero imaginary part; otherwise use np.sum.
        if np.iscomplex(x):
            value = x**2
        else:
            value = np.sum(x)
        return np.real(value)

    for point in (2., 2. + 1j):
        check_grads(branchy)(point)
Exemple #13
0
def test_falseyness():
    """Run check_grads on a function that branches on ``np.iscomplex(x)``."""
    def branchy(x):
        # Square inputs with a nonzero imaginary part; otherwise use np.sum.
        if np.iscomplex(x):
            value = x**2
        else:
            value = np.sum(x)
        return np.real(value)

    for point in (2., 2. + 1j):
        check_grads(branchy)(point)
def test_falseyness():
    """Run check_grads on a function that branches on ``np.iscomplex(x)``."""
    def branchy(x):
        # Square inputs with a nonzero imaginary part; otherwise use np.sum.
        if not np.iscomplex(x):
            return np.sum(x)
        return x**2

    for point in (2., 2. + 1j):
        check_grads(branchy, point)
Exemple #15
0
def test_falseyness():
    """Run check_grads on a function that branches on ``np.iscomplex(x)``."""
    def branchy(x):
        # Square inputs with a nonzero imaginary part; otherwise use np.sum.
        if not np.iscomplex(x):
            return np.sum(x)
        return x**2

    for point in (2., 2. + 1j):
        check_grads(branchy, point)