コード例 #1
0
def test_quaternion():
    """Check that quaternion composition agrees with angle composition.

    Two random angle triples are composed once with ``o3.compose_angles`` and
    once, after conversion, with ``o3.compose_quaternion``. The two results
    must coincide up to an overall sign.
    """
    torch.set_default_dtype(torch.float64)
    first_angles = o3.rand_angles()
    second_angles = o3.rand_angles()
    first_quat = o3.angles_to_quaternion(*first_angles)
    second_quat = o3.angles_to_quaternion(*second_angles)

    composed_angles = o3.compose_angles(*first_angles, *second_angles)
    composed_quat = o3.compose_quaternion(first_quat, second_quat)

    expected_quat = o3.angles_to_quaternion(*composed_angles)

    # Either sign is accepted, so keep the smaller of the two residuals.
    residual_same_sign = (composed_quat - expected_quat).abs().max()
    residual_flipped_sign = (composed_quat + expected_quat).abs().max()
    assert min(residual_same_sign, residual_flipped_sign) < 1e-10
コード例 #2
0
def test_sh_is_in_irrep():
    """Check that rescaled spherical harmonics equal column ``l`` of ``o3.irrep``.

    For each degree 0..4, the output of ``o3.spherical_harmonics_alpha_beta``
    is multiplied by ``sqrt(4*pi) / sqrt(2*l + 1) * (-1)**l`` and compared with
    column ``l`` of ``o3.irrep(l, a, b, 0)``.
    """
    torch.set_default_dtype(torch.float64)
    for degree in range(5):
        alpha, beta, _ = o3.rand_angles()

        # Apply the scaling factors one at a time, in the original order.
        harmonics = o3.spherical_harmonics_alpha_beta([degree], alpha, beta)
        harmonics = harmonics * math.sqrt(4 * math.pi)
        harmonics = harmonics / math.sqrt(2 * degree + 1)
        harmonics = harmonics * (-1)**degree

        wigner = o3.irrep(degree, alpha, beta, 0)
        column = wigner[:, degree]
        assert (harmonics - column).abs().max() < 1e-10
コード例 #3
0
def test_sh_equivariance1():
    """Check equivariance of ``spherical_harmonics_alpha_beta`` for degrees 0..7.

    Exercises ``o3.compose``, ``o3.spherical_harmonics_alpha_beta`` and
    ``o3.irrep``: the harmonics evaluated at the composed angles must equal
    ``o3.irrep(l, alpha, beta, gamma)`` applied to the harmonics evaluated at
    the original angles, within a tolerance relative to the largest harmonic.
    """
    torch.set_default_dtype(torch.float64)
    for degree in range(8):
        a, b, _ = o3.rand_angles()
        alpha, beta, gamma = o3.rand_angles()

        # Harmonics evaluated at the composed angles.
        composed_a, composed_b, _ = o3.compose(alpha, beta, gamma, a, b, 0)
        harmonics_at_composed = o3.spherical_harmonics_alpha_beta(
            [degree], composed_a, composed_b)

        # Harmonics at the original angles, then transformed by the irrep matrix.
        harmonics = o3.spherical_harmonics_alpha_beta([degree], a, b)
        wigner = o3.irrep(degree, alpha, beta, gamma)
        transformed_harmonics = wigner @ harmonics

        error = (harmonics_at_composed - transformed_harmonics).abs().max()
        tolerance = 1e-10 * harmonics.abs().max()
        assert error < tolerance
コード例 #4
0
def test_reduce_tensor_Levi_Civita_symbol():
    """Check the tensor reduced from ``'ijk=-ikj=-jik'`` is unchanged by ``o3.irrep(1, ...)``.

    ``reduce_tensor`` is expected to return ``[(1, 0, 0)]`` as its first
    output. Its second output, reshaped to 3x3x3, must equal itself after
    every index is contracted with the same ``o3.irrep(1, ...)`` matrix.
    """
    torch.set_default_dtype(torch.float64)

    reduced_rs, basis = reduce_tensor('ijk=-ikj=-jik', i=[(1, 1)])
    assert reduced_rs == [(1, 0, 0)]

    angles = o3.rand_angles()
    rotation = o3.irrep(1, *angles)
    tensor = basis.reshape(3, 3, 3)

    # Contract each of the three tensor indices with the same matrix.
    transformed = torch.einsum('li,mj,nk,ijk', rotation, rotation, rotation, tensor)
    deviation = (transformed - tensor).abs().max()
    assert deviation < 1e-10
コード例 #5
0
def test_reduce_tensor_Levi_Civita_symbol():
    """Check the tensor reduced from ``'ijk=-ikj=-jik'`` is unchanged by ``o3.wigner_D(1, ...)``.

    ``o3.reduce_tensor`` is expected to return irreps equal to
    ``((1, (0, 1)),)``. Its second output, reshaped to 3x3x3, must equal itself
    after every index is contracted with the same ``o3.wigner_D(1, ...)``
    matrix.
    """
    torch.set_default_dtype(torch.float64)

    reduced_irreps, basis = o3.reduce_tensor('ijk=-ikj=-jik', i='1e')
    assert reduced_irreps == ((1, (0, 1)),)

    angles = o3.rand_angles()
    rotation = o3.wigner_D(1, *angles)
    tensor = basis.reshape(3, 3, 3)

    # Contract each of the three tensor indices with the same matrix.
    transformed = torch.einsum('li,mj,nk,ijk', rotation, rotation, rotation, tensor)
    deviation = (transformed - tensor).abs().max()
    assert deviation < 1e-10
コード例 #6
0
def test_reduce_tensor_equivariance():
    """Check that the basis from ``o3.reduce_tensor`` is equivariant.

    For the formula ``'ijkl=jikl=klij'`` over ``Irreps('1e')``, applying
    ``ir.D`` to the four trailing indices of the basis must give the same
    result as applying ``irreps.D`` to its leading index.
    """
    torch.set_default_dtype(torch.float64)

    input_irreps = o3.Irreps('1e')
    output_irreps, basis = o3.reduce_tensor('ijkl=jikl=klij', i=input_irreps)

    angles = o3.rand_angles()
    input_rep = input_irreps.D(*angles)
    output_rep = output_irreps.D(*angles)

    # Transform the four trailing indices with the input representation ...
    transformed_inputs = torch.einsum(
        'qmnop,mi,nj,ok,pl->qijkl',
        basis, input_rep, input_rep, input_rep, input_rep)
    # ... and the leading index with the output representation.
    transformed_output = torch.einsum('qa,aijkl->qijkl', output_rep, basis)

    mismatch = (transformed_inputs - transformed_output).abs().max()
    assert mismatch < 1e-10
コード例 #7
0
def test_reduce_tensor_antisymmetric_L2():
    """Check the first block of the ``'ijk=-ikj=-jik'`` reduction over ``[(1, 2)]``.

    The first entry of the returned ``Rs`` must be ``(1, 1, 0)``. The first
    three rows of ``Q``, reshaped to 3x5x5x5, must satisfy two properties:

    - contracting the three size-5 indices with ``o3.irrep(2, ...)`` matches
      contracting the leading index with ``o3.irrep(1, ...)``;
    - the block changes sign under each pairwise swap of its last three axes.
    """
    torch.set_default_dtype(torch.float64)

    reduced_rs, basis = reduce_tensor('ijk=-ikj=-jik', i=[(1, 2)])
    assert reduced_rs[0] == (1, 1, 0)
    block = basis[:3].reshape(3, 5, 5, 5)

    angles = o3.rand_angles()
    rep_l1 = o3.irrep(1, *angles)
    rep_l2 = o3.irrep(2, *angles)
    transformed_inputs = torch.einsum('il,jm,kn,zijk->zlmn', rep_l2, rep_l2, rep_l2, block)
    transformed_output = torch.einsum('yz,zijk->yijk', rep_l1, block)

    assert (transformed_inputs - transformed_output).abs().max() < 1e-10

    # Antisymmetry under each pairwise swap, checked in the original order.
    for axis_a, axis_b in ((1, 2), (1, 3), (3, 2)):
        assert (block + block.transpose(axis_a, axis_b)).abs().max() < 1e-10
コード例 #8
0
def test_sh_equivariance2():
    """Check equivariance of ``o3.spherical_harmonics`` on random points.

    Exercises ``o3.rot``, ``o3.rep`` and ``o3.spherical_harmonics``: the
    harmonics of the points multiplied by ``R.T`` must equal the harmonics of
    the original points multiplied by ``D.T``.
    """
    torch.set_default_dtype(torch.float64)
    degrees = list(range(7))

    angles = o3.rand_angles()
    rotation = o3.rot(*angles)
    representation = o3.rep(degrees, *angles)

    points = torch.randn(10, 3)

    rotated_points = points @ rotation.T
    harmonics_of_rotated = o3.spherical_harmonics(degrees, rotated_points)
    transformed_harmonics = o3.spherical_harmonics(degrees, points) @ representation.T

    mismatch = (harmonics_of_rotated - transformed_harmonics).abs().max()
    assert mismatch < 1e-10
コード例 #9
0
def test_sh_equivariance2():
    """Check equivariance of ``o3.spherical_harmonics`` on random points.

    Exercises ``o3.rot``, ``Irreps.D`` and ``o3.spherical_harmonics`` for
    ``"0e + 1o + 2e + 3o + 4e"``: the harmonics of the points multiplied by
    ``R.T`` must equal the harmonics of the original points multiplied by
    ``D.T``.
    """
    torch.set_default_dtype(torch.float64)
    sh_irreps = o3.Irreps("0e + 1o + 2e + 3o + 4e")

    angles = o3.rand_angles()
    rotation = o3.rot(*angles)
    representation = sh_irreps.D(*angles)

    points = torch.randn(10, 3)

    rotated_points = points @ rotation.T
    harmonics_of_rotated = o3.spherical_harmonics(sh_irreps, rotated_points)
    transformed_harmonics = o3.spherical_harmonics(sh_irreps, points) @ representation.T

    mismatch = (harmonics_of_rotated - transformed_harmonics).abs().max()
    assert mismatch < 1e-10
コード例 #10
0
ファイル: groups.py プロジェクト: blondegeek/e3nn_little
 def random(self):
     """Return ``o3.rand_angles()`` extended by one entry drawn from ``[0, 1]``."""
     angles = o3.rand_angles()
     # ``random`` here is the module-level name, not this method.
     extra = random.choice([0, 1])
     return angles + (extra, )
コード例 #11
0
ファイル: groups.py プロジェクト: blondegeek/e3nn_little
 def random(self):
     """Return the result of ``o3.rand_angles()`` unchanged."""
     angles = o3.rand_angles()
     return angles