def test3(self):
        """Check that Convolution followed by GatedBlock commutes with rotations.

        Rotating the input features with D_in and the geometry with the 3x3
        rotation matrix must give the same output as rotating the original
        output with D_out.
        """
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            kernel_factory = partial(Kernel, RadialModel=ConstantRadialModel)

            gated = GatedBlock(Rs_out,
                               scalar_activation=sigmoid,
                               gate_activation=sigmoid)
            conv = Convolution(kernel_factory, Rs_in, gated.Rs_in)

            angles = torch.randn(3)
            R = o3.rot(*angles)

            def rep_matrix(Rs):
                # one irr_repr block per copy of every (mul, l) entry
                return o3.direct_sum(*[
                    o3.irr_repr(l, *angles) for mul, l in Rs for _ in range(mul)
                ])

            D_in = rep_matrix(Rs_in)
            D_out = rep_matrix(Rs_out)

            features = torch.randn(1, 4, sum(mul * (2 * l + 1) for mul, l in Rs_in))
            geometry = torch.randn(1, 4, 3)

            def rotate(M, t):
                # apply M to the last axis of every [z, a, :] slice
                return torch.einsum("ij,zaj->zai", M, t)

            expected = rotate(D_out, gated(conv(features, geometry)))
            actual = gated(conv(rotate(D_in, features), rotate(R, geometry)))
            self.assertLess((expected - actual).norm(), 10e-5 * expected.norm())
Beispiel #2
0
    def test1(self):
        """Check rotation equivariance of a Kernel-based Convolution followed by a GatedBlock.

        Rotating the input features with D_in and the geometry with the 3x3
        rotation matrix must reproduce the original output rotated with D_out.
        """
        with torch_default_dtype(torch.float64):
            Rs_in = [(3, 0), (3, 1), (2, 0), (1, 2)]
            Rs_out = [(3, 0), (3, 1), (1, 2), (3, 0)]

            f = GatedBlock(Rs_out, rescaled_act.Softplus(beta=5),
                           rescaled_act.sigmoid)
            c = Convolution(Kernel(Rs_in, f.Rs_in, ConstantRadialModel))

            abc = torch.randn(3)
            D_in = o3.direct_sum(
                *
                [o3.irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
            D_out = o3.direct_sum(*[
                o3.irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)
            ])

            # Tensors below have shape [1, 5, channels]; dim=2 is the channel axis.
            x = torch.randn(1, 5, sum(mul * (2 * l + 1) for mul, l in Rs_in))
            geo = torch.randn(1, 5, 3)

            rx = torch.einsum("ij,zaj->zai", (D_in, x))
            rgeo = geo @ o3.rot(*abc).t()

            y = f(c(x, geo), dim=2)
            ry = torch.einsum("ij,zaj->zai", (D_out, y))

            # Pass dim=2 on the rotated path too, so both sides of the comparison
            # are guaranteed to apply the gated block along the same axis.
            self.assertLess((f(c(rx, rgeo), dim=2) - ry).norm(), 1e-10 * ry.norm())
Beispiel #3
0
 def D(*angles):
     """Return the l=1 (+) l=1 (+) l=2 representation at ``angles``, conjugated by ``A``.

     ``A`` is a matrix defined outside this function; the result is
     ``A @ direct_sum(D^1, D^1, D^2) @ A^{-1}``.
     """
     blocks = [o3.irr_repr(l, *angles) for l in (1, 1, 2)]
     return A @ o3.direct_sum(*blocks) @ torch.inverse(A)
Beispiel #4
0
 def test_irrep_closure_2(self):
     """Compare the summed D-matrix overlap of two distinct rotations with a rotation's self-overlap.

     For each l in 0..12 the elementwise product of the two D^l matrices is
     summed; the cross term must be below 1/100 of the self term.
     """
     r1 = (0, 0.2, 0)
     r2 = (0.1, 0.4, 1.5)  # two fixed, distinct rotations (angle triples)
     l_max = 12

     def overlap(angles_x, angles_y):
         return sum(
             (o3.irr_repr(l, *angles_x) * o3.irr_repr(l, *angles_y)).sum()
             for l in range(l_max + 1))

     cross = overlap(r1, r2)
     self_term = overlap(r1, r1)
     self.assertLess(cross, self_term / 100)
Beispiel #5
0
    def test_reduce_tensor_product(self):
        """Check that rs.tensor_product yields an equivariant, orthogonal change of basis Q."""
        cases = [
            ([(1, 0)], [(2, 0)]),
            ([(3, 1), (2, 2)], [(2, 0), (1, 1), (1, 3)]),
        ]
        for Rs_i, Rs_j in cases:
            with o3.torch_default_dtype(torch.float64):
                Rs, Q = rs.tensor_product(Rs_i, Rs_j)

                abc = torch.rand(3, dtype=torch.float64)

                D_i = o3.direct_sum(*[
                    o3.irr_repr(l, *abc) for mul, l in Rs_i for _ in range(mul)
                ])
                D_j = o3.direct_sum(*[
                    o3.irr_repr(l, *abc) for mul, l in Rs_j for _ in range(mul)
                ])
                # Rs entries carry a third (parity) field that is ignored here
                D = o3.direct_sum(*[
                    o3.irr_repr(l, *abc) for mul, l, _p in Rs for _ in range(mul)
                ])

                # Equivariance: acting with D on the output index of Q must match
                # acting with D_i and D_j on its two input indices.
                lhs = torch.einsum("ijk,il->ljk", Q, D)
                rhs = torch.einsum("li,mj,kij->klm", D_i, D_j, Q)
                rel_err = (lhs - rhs).pow(2).mean().sqrt() / lhs.pow(2).mean().sqrt()
                self.assertLess(rel_err, 1e-10)

                # Orthogonality: Q viewed as an n x n matrix satisfies M M^T = M^T M = I.
                n = Q.size(0)
                M = Q.view(n, n)
                eye = torch.eye(n, dtype=M.dtype)
                for product in (M @ M.t(), M.t() @ M):
                    self.assertLess((product - eye).pow(2).mean().sqrt(), 1e-10)
Beispiel #6
0
def rep(Rs, alpha, beta, gamma, parity=None):
    """
    Block-diagonal representation matrix of O(3) for ``Rs``.

    Every ``(mul, l, p)`` entry of ``simplify(Rs)`` contributes ``mul`` copies of
    ``o3.irr_repr(l, alpha, beta, gamma)``. When ``parity`` is given, the
    inversion is applied ``parity`` times, i.e. each copy is scaled by
    ``p ** parity``; every ``p`` must then be nonzero (checked with ``assert``).
    """
    abc = [alpha, beta, gamma]
    if parity is not None:
        assert all(p != 0 for _, _, p in simplify(Rs))

    def block(l, p):
        D_l = o3.irr_repr(l, *abc)
        return D_l if parity is None else (p ** parity) * D_l

    return o3.direct_sum(*[block(l, p) for mul, l, p in simplify(Rs) for _ in range(mul)])
Beispiel #7
0
def test_reduce_tensor_antisymmetric_L2():
    """Reduce an antisymmetric rank-3 tensor of l=2 inputs and check its first l=1 component.

    The first output irrep must be one copy of l=1 (parity 0). Its three
    basis tensors must be equivariant and antisymmetric in every index pair.
    """
    Rs, Q = rs.reduce_tensor('ijk=-ikj=-jik', i=[(1, 2)])
    assert Rs[0] == (1, 1, 0)
    q = Q[:3].reshape(3, 5, 5, 5)

    angles = o3.rand_angles()
    D_out = o3.irr_repr(1, *angles)
    D_in = o3.irr_repr(2, *angles)

    via_inputs = torch.einsum('il,jm,kn,zijk->zlmn', D_in, D_in, D_in, q)
    via_output = torch.einsum('yz,zijk->yijk', D_out, q)
    assert (via_inputs - via_output).abs().max() < 1e-10

    for dim_a, dim_b in [(1, 2), (1, 3), (3, 2)]:
        assert (q + q.transpose(dim_a, dim_b)).abs().max() < 1e-10
Beispiel #8
0
def test_sh_is_in_irrep():
    """Rescaled spherical harmonics Y^l(a, b) equal column l of D^l(a, b, 0) for l = 0..4.

    The harmonics are multiplied by sqrt(4 pi) / sqrt(2l + 1) * (-1)^l before
    being compared with the middle column of the Wigner D-matrix.
    """
    with o3.torch_default_dtype(torch.float64):
        for l in range(5):
            # angles drawn from [0, 3.14): the identity only holds for beta in [0, pi]
            angles = 3.14 * torch.rand(2)
            a, b = angles
            Y = rsh.spherical_harmonics_alpha_beta([l], a, b) * math.sqrt(4 * math.pi) / math.sqrt(2 * l + 1) * (-1) ** l
            middle_column = o3.irr_repr(l, a, b, 0)[:, l]
            assert (Y - middle_column).norm() < 1e-10
Beispiel #9
0
    def __init__(self, Rs, act, n):
        """Pointwise activation applied through a sampled signal on SO(3).

        The input is read as a signal on SO(3) in its regular representation,
        so the activation can be applied pointwise on sampled rotations and the
        result projected back.

        :param Rs: input representation; after ``rs.simplify`` it must list
            l = 0, 1, 2, ... in order, each with multiplicity mul0 * (2l + 1)
            and parity 0
        :param act: activation function
        :param n: number of random rotations sampled on SO(3)
            (the higher, the more accurate)
        """
        super().__init__()

        Rs = rs.simplify(Rs)
        mul0, _, _ = Rs[0]
        assert all(mul == mul0 * (2 * l + 1) for mul, l, _ in Rs)
        assert [l for _, l, _ in Rs] == list(range(len(Rs)))
        assert all(p == 0 for _, _, p in Rs)

        self.Rs_out = Rs

        rotations = [o3.rand_rot() for _ in range(n)]
        rows = []
        for R in rotations:
            abc = o3.rot_to_abc(R)
            # concatenate the flattened D^l(R), weighted by sqrt(2l + 1), over all l
            rows.append(torch.cat([
                o3.irr_repr(l, *abc).flatten() * (2 * l + 1) ** 0.5
                for l in range(len(Rs))
            ]))
        Z = torch.stack(rows)  # [z, lmn]
        # normalize by the square root of the row length
        Z.div_(Z.shape[1] ** 0.5)
        self.register_buffer('Z', Z)
        self.act = act
Beispiel #10
0
def rsh_surface(l, m, scale, tr, rot):
    """Return a ``go.Surface`` showing one component of the rotated order-``l`` harmonics.

    The harmonics on an (alpha, beta) grid are rotated with
    ``o3.irr_repr(l, *rot)``, and component ``l + m`` is kept (presumably
    ``m`` ranges over -l..l). The radius is ``scale * |f|``, shifted by ``tr``.
    The colour is the signed value of ``f``, with the colour range fixed to
    [-0.5, 0.5].
    """
    n = 50
    alpha, beta = torch.meshgrid(
        torch.linspace(0, 2 * math.pi, 2 * n),
        torch.linspace(0, math.pi, n),
    )

    Y = rsh.spherical_harmonics_alpha_beta([l], alpha, beta)
    Y = torch.einsum('ij,...j->...i', o3.irr_repr(l, *rot), Y)
    f = Y[..., l + m]

    ux, uy, uz = o3.angles_to_xyz(alpha, beta)

    radius = scale * f.abs()
    x = radius * ux + tr[0]
    y = radius * uy + tr[1]
    z = radius * uz + tr[2]

    max_value = 0.5
    colors = [
        [0, 'rgb(0,50,255)'],
        [0.5, 'rgb(200,200,200)'],
        [1, 'rgb(255,50,0)'],
    ]

    return go.Surface(
        x=x.numpy(),
        y=y.numpy(),
        z=z.numpy(),
        surfacecolor=f.numpy(),
        showscale=False,
        cmin=-max_value,
        cmax=max_value,
        colorscale=colors,
    )
Beispiel #11
0
def test_reduce_tensor_Levi_Civita_symbol():
    """The antisymmetric rank-3 tensor of l=1 inputs reduces to one scalar that rotations leave unchanged."""
    Rs, Q = rs.reduce_tensor('ijk=-ikj=-jik', i=[(1, 1)])
    assert Rs == [(1, 0, 0)]
    D = o3.irr_repr(1, *o3.rand_angles())
    eps = Q.reshape(3, 3, 3)
    # '->lmn' is exactly the implicit output of 'li,mj,nk,ijk', spelled out
    rotated = torch.einsum('li,mj,nk,ijk->lmn', D, D, D, eps)
    assert (rotated - eps).abs().max() < 1e-10
Beispiel #12
0
    def test_xyz_to_irreducible_basis(self):
        """Conjugating D^1 by the xyz -> irreducible change of basis must give o3.rot."""
        with o3.torch_default_dtype(torch.float64):
            A = o3.xyz_to_irreducible_basis()
            angles = torch.rand(3)
            from_irrep = A.t() @ o3.irr_repr(1, *angles) @ A
            assert torch.allclose(from_irrep, o3.rot(*angles))
Beispiel #13
0
    def test_xyz_vector_basis_to_spherical_basis(self):
        """Conjugating D^1 by the xyz -> spherical change of basis must give o3.rot (to 1e-10)."""
        with o3.torch_default_dtype(torch.float64):
            A = o3.xyz_vector_basis_to_spherical_basis()
            angles = torch.rand(3)
            from_irrep = A.t() @ o3.irr_repr(1, *angles) @ A
            expected = o3.rot(*angles)
            self.assertLess((from_irrep - expected).abs().max(), 1e-10)
Beispiel #14
0
 def test_irrep_closure_1(self):
     """Monte Carlo check of orthogonality between irreps l = 0..3.

     Averaged over 10000 random rotations, D^{l1}_ij D^{l2}_kl must be close to
     delta_{l1 l2} delta_ik delta_jl / (2 l1 + 1), with tolerance 0.1.
     """
     samples = [o3.rand_angles() for _ in range(10000)]
     D_stacks = [torch.stack([o3.irr_repr(l, *abc) for abc in samples]) for l in range(4)]
     for l1, D1 in enumerate(D_stacks):
         d1 = 2 * l1 + 1
         for l2, D2 in enumerate(D_stacks):
             d2 = 2 * l2 + 1
             m = torch.einsum('zij,zkl->zijkl', D1, D2).mean(0).reshape(d1 ** 2, d2 ** 2)
             if l1 != l2:
                 self.assertLess(m.abs().max(), 0.1)
                 continue
             eye = torch.eye(d1 ** 2)
             self.assertLess((m.mul(d1) - eye).abs().max(), 0.1)
    def test2(self):
        """Check Kernel equivariance: K(R r) @ D_in must equal D_out @ K(r)."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)

            abc = torch.randn(3)
            D_in, D_out = (
                o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs for _ in range(mul)])
                for Rs in (Rs_in, Rs_out)
            )

            W1 = D_out @ kernel(r)  # [i, j]
            W2 = kernel(o3.rot(*abc) @ r) @ D_in  # [i, j]
            self.assertLess((W1 - W2).norm(), 10e-5 * W1.norm())
    def test6(self):
        """Check parity and rotation equivariance of Convolution followed by GatedBlockParity.

        The geometry is transformed by a rotation composed with inversion
        (``-o3.rot``). Each irrep copy (l, p) is therefore transformed by
        ``p * D^l``.
        """
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

            K = partial(Kernel, RadialModel=ConstantRadialModel)

            scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu), (mul, absolute)]
            rs_nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]
            n_gates = 3 * mul
            gates = [(n_gates, 0, +1), (n_gates, 0, -1)], [(n_gates, sigmoid), (n_gates, tanh)]

            act = GatedBlockParity(*scalars, *gates, rs_nonscalars)
            conv = Convolution(K, Rs_in, act.Rs_in)

            abc = torch.randn(3)
            rot_geo = -o3.rot(*abc)

            def parity_rep(Rs):
                # with the inversion included, each copy of irrep (l, p) picks up a factor p
                return o3.direct_sum(*[
                    p * o3.irr_repr(l, *abc) for m, l, p in Rs for _ in range(m)
                ])

            D_in = parity_rep(Rs_in)
            D_out = parity_rep(act.Rs_out)

            fea = torch.randn(1, 4, sum(m * (2 * l + 1) for m, l, _p in Rs_in))
            geo = torch.randn(1, 4, 3)

            def rotate(M, t):
                return torch.einsum("ij,zaj->zai", M, t)

            x1 = rotate(D_out, act(conv(fea, geo)))
            x2 = act(conv(rotate(D_in, fea), rotate(rot_geo, geo)))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Beispiel #17
0
    def test_irr_repr_wigner_3j(self):
        """Check that the Wigner-3j kernel W(r) satisfies W(R r) D_in == D_out W(r) for every coupling l_f."""
        with torch_default_dtype(torch.float64):
            l_in, l_out = 3, 2

            for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
                r = torch.randn(100, 3)
                Q = o3.wigner_3j(l_out, l_in, l_f)

                abc = torch.randn(3)
                D_in = o3.irr_repr(l_in, *abc)
                D_out = o3.irr_repr(l_out, *abc)

                def kernel(points):
                    # contract the 3j symbol with the order-l_f harmonics of each point
                    Y = rsh.spherical_harmonics_xyz([l_f], points)
                    return torch.einsum("ijk,zk->zij", Q, Y)

                W1 = torch.einsum("zij,jk->zik", kernel(r @ o3.rot(*abc).t()), D_in)
                W = kernel(r)
                W2 = torch.einsum("ij,zjk->zik", D_out, W)

                self.assertLess((W1 - W2).norm(), 1e-5 * W.norm(), l_f)
Beispiel #18
0
 def test_derivative_irr_repr(self):
     """Compare o3.derivative_irr_repr with forward differences of o3.irr_repr.

     Each of the three returned derivatives (one per angle) must match a
     forward difference with step 1e-7 to within 1e-5.
     """
     with o3.torch_default_dtype(torch.float64):
         l = 4
         angles = o3.rand_angles()
         da, db, dc = o3.derivative_irr_repr(l, *angles)
         h = 1e-7
         base = o3.irr_repr(l, *angles)

         def forward_difference(k):
             # perturb only the k-th angle by h
             shifted = [angle + h if i == k else angle for i, angle in enumerate(angles)]
             return (o3.irr_repr(l, *shifted) - base) / h

         for k, exact in enumerate((da, db, dc)):
             self.assertLess((forward_difference(k) - exact).abs().max(), 1e-5)
Beispiel #19
0
def test_sh_equivariance():
    """
    Check that irr_repr, compose and spherical_harmonics_alpha_beta are consistent.

    For a point x = Z(a) Y(b) eta and the rotation Z(alpha) Y(beta) Z(gamma),
    the identity Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x)
    must hold for every l in 0..6.
    """
    def sh(order, alpha_, beta_):
        return rsh.spherical_harmonics_alpha_beta([order], alpha_, beta_)

    for l in range(7):
        with o3.torch_default_dtype(torch.float64):
            a, b = torch.rand(2)
            alpha, beta, gamma = torch.rand(3)

            # the first two angles of the composed rotation locate the rotated point
            ra, rb, _ = o3.compose(alpha, beta, gamma, a, b, 0)
            sh_of_rotated = sh(l, ra, rb)

            Y = sh(l, a, b)
            rotated_sh = o3.irr_repr(l, alpha, beta, gamma) @ Y

            assert (sh_of_rotated - rotated_sh).abs().max() < 1e-10 * Y.abs().max()
Beispiel #20
0
    def test_spherical_harmonics(self):
        """
        Check that irr_repr, compose and o3.spherical_harmonics are consistent.

        For a point x = Z(a) Y(b) eta and the rotation Z(alpha) Y(beta) Z(gamma),
        the identity Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x)
        must hold for every order in 0..6.
        """
        for order in range(7):
            with o3.torch_default_dtype(torch.float64):
                a, b = torch.rand(2)
                alpha, beta, gamma = torch.rand(3)

                # the first two angles of the composed rotation locate the rotated point
                ra, rb, _ = o3.compose(alpha, beta, gamma, a, b, 0)
                sh_of_rotated = o3.spherical_harmonics(order, ra, rb)

                Y = o3.spherical_harmonics(order, a, b)
                rotated_sh = o3.irr_repr(order, alpha, beta, gamma) @ Y

                max_err = (sh_of_rotated - rotated_sh).abs().max()
                self.assertLess(max_err, 1e-10 * Y.abs().max())
# Random 6x6 elastic tensor given in Voigt notation.
data = np.random.randn(6 * 6).reshape((6, 6))
et1 = ElasticTensor(data, 'voigt')

# Populate every representation of et1. The voigt -> cartesian -> voigt
# round trip first makes the Voigt matrix symmetric.
for _step in (
    et1.voigt_to_cartesian,
    et1.cartesian_to_voigt,
    et1.voigt_to_cartesian,
    et1.cartesian_to_covariant,
    et1.covariant_to_spherical,
):
    _ = _step()

et2 = ElasticTensor(et1.covariant.copy(), 'covariant')

a, b, c = rand_angles()

D0, D1, D2, D4 = (irr_repr(l, a, b, c).numpy() for l in (0, 1, 2, 4))

# Forward rotation: apply irr_repr(1, ...) to each of the four covariant indices.
et2.covariant = np.einsum('ijkn,ai,bj,ck,dn->abcd', et2.covariant, D1, D1, D1, D1)
_ = et2.covariant_to_spherical()

# Inverse rotation, block by block in the spherical basis: the components are
# grouped as l = 0, 2, 0, 2, 4 at offsets 0, 1, 6, 7, 12.
for _start, _stop, _D in ((0, 1, D0), (1, 6, D2), (6, 7, D0), (7, 12, D2), (12, 21, D4)):
    et2.spherical[_start:_stop] = _D.T @ et2.spherical[_start:_stop]