Example #1
    def test3(self):
        """Test rotation equivariance on GatedBlock and dependencies."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            K = partial(Kernel, RadialModel=ConstantRadialModel)
            C = partial(Convolution, K)

            f = GatedBlock(partial(C, Rs_in),
                           Rs_out,
                           scalar_activation=sigmoid,
                           gate_activation=sigmoid)

            abc = torch.randn(3)
            rot_geo = rot(*abc)
            D_in = direct_sum(
                *[irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
            D_out = direct_sum(
                *[irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)])

            fea = torch.randn(1, 4, sum(mul * (2 * l + 1) for mul, l in Rs_in))
            geo = torch.randn(1, 4, 3)

            x1 = torch.einsum("ij,zaj->zai", (D_out, f(fea, geo)))
            x2 = f(torch.einsum("ij,zaj->zai", (D_in, fea)),
                   torch.einsum("ij,zaj->zai", rot_geo, geo))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example #2
def __clebsch_gordan(l1, l2, l3, _version=4):
    """
    Computes the Clebsch–Gordan coefficients

    D(l1)_il D(l2)_jm D(l3)_kn Q_lmn == Q_ijk
    """
    # these three propositions are equivalent
    assert abs(l2 - l3) <= l1 <= l2 + l3
    assert abs(l3 - l1) <= l2 <= l3 + l1
    assert abs(l1 - l2) <= l3 <= l1 + l2

    with torch_default_dtype(torch.float64):
        null_space = _get_d_null_space(l1, l2, l3)

        assert null_space.size(0) == 1, null_space.size()  # unique subspace solution
        Q = null_space[0]
        Q = Q.view(2 * l1 + 1, 2 * l2 + 1, 2 * l3 + 1)

        if next(x for x in Q.flatten() if x.abs() > 1e-10 * Q.abs().max()) < 0:
            Q.neg_()

        abc = torch.rand(3)
        _Q = torch.einsum(
            "il,jm,kn,lmn",
            (irr_repr(l1, *abc), irr_repr(l2, *abc), irr_repr(l3, *abc), Q))
        assert torch.allclose(Q, _Q)

    assert Q.dtype == torch.float64
    return Q  # [m1, m2, m3]
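# A hedged usage sketch (an addition, not in the original source): couple two
# l=1 vectors into an l=2 output with the coefficients Q[m1, m2, m3] returned
# above; by the identity in the docstring, the result transforms with
# irr_repr(2, ...).
def _example_couple_two_vectors():
    Q = __clebsch_gordan(1, 1, 2)            # [3, 3, 5]
    u = torch.randn(3, dtype=torch.float64)  # transforms with irr_repr(1, ...)
    v = torch.randn(3, dtype=torch.float64)  # transforms with irr_repr(1, ...)
    return torch.einsum("ijk,i,j->k", (Q, u, v))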
Example #3
    def test5(self):
        """Test parity equivariance on GatedBlockParity and dependencies."""
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

            K = partial(Kernel, RadialModel=ConstantRadialModel)
            C = partial(Convolution, K)

            scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu),
                                                     (mul, absolute)]
            rs_nonscalars = [(mul, 1, +1), (mul, 1, -1), (mul, 2, +1),
                             (mul, 2, -1), (mul, 3, +1), (mul, 3, -1)]
            n = 3 * mul
            gates = [(n, 0, +1), (n, 0, -1)], [(n, sigmoid), (n, tanh)]

            f = GatedBlockParity(C, Rs_in, *scalars, *gates, rs_nonscalars)

            D_in = direct_sum(*[
                p * torch.eye(2 * l + 1) for mul, l, p in Rs_in
                for _ in range(mul)
            ])
            D_out = direct_sum(*[
                p * torch.eye(2 * l + 1) for mul, l, p in f.Rs_out
                for _ in range(mul)
            ])

            fea = torch.randn(1, 4,
                              sum(mul * (2 * l + 1) for mul, l, p in Rs_in))
            geo = torch.randn(1, 4, 3)

            x1 = torch.einsum("ij,zaj->zai", (D_out, f(fea, geo)))
            x2 = f(torch.einsum("ij,zaj->zai", (D_in, fea)), -geo)
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example #4
    def parity_rotation_gated_block_parity(self, K):
        """Test parity and rotation equivariance on GatedBlockParity and dependencies."""
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]

            K = partial(K, RadialModel=ConstantRadialModel)

            scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu),
                                                     (mul, absolute)]
            rs_nonscalars = [(mul, 1, +1), (mul, 1, -1), (mul, 2, +1),
                             (mul, 2, -1), (mul, 3, +1), (mul, 3, -1)]
            n = 3 * mul
            gates = [(n, 0, +1), (n, 0, -1)], [(n, sigmoid), (n, tanh)]

            act = GatedBlockParity(*scalars, *gates, rs_nonscalars)
            conv = Convolution(K(Rs_in, act.Rs_in))

            abc = torch.randn(3)
            rot_geo = -o3.rot(*abc)
            D_in = rs.rep(Rs_in, *abc, 1)
            D_out = rs.rep(act.Rs_out, *abc, 1)

            fea = torch.randn(1, 4, rs.dim(Rs_in))
            geo = torch.randn(1, 4, 3)

            x1 = torch.einsum("ij,zaj->zai", (D_out, act(conv(fea, geo))))
            x2 = act(
                conv(torch.einsum("ij,zaj->zai", (D_in, fea)),
                     torch.einsum("ij,zaj->zai", rot_geo, geo)))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example #5
def xyz3x3_to_irreducible_basis():
    """
    to convert a 3x3 tensor transforming with xyz3x3_repr(a, b, c)
    into its 1 + 3 + 5 components transforming with irr_repr(0, a, b, c), irr_repr(1, a, b, c), irr_repr(2, a, b, c)
    see assert for usage
    """
    with torch_default_dtype(torch.float64):
        to1 = torch.tensor([
            [1, 0, 0, 0, 1, 0, 0, 0, 1],
        ], dtype=torch.get_default_dtype())
        assert all(torch.allclose(irr_repr(0, a, b, c) @ to1, to1 @ xyz3x3_repr(a, b, c)) for a, b, c in torch.rand(10, 3))

        to3 = torch.tensor([
            [0, 0, -1, 0, 0, 0, 1, 0, 0],
            [0, 1, 0, -1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, -1, 0],
        ], dtype=torch.get_default_dtype())
        assert all(torch.allclose(irr_repr(1, a, b, c) @ to3, to3 @ xyz3x3_repr(a, b, c)) for a, b, c in torch.rand(10, 3))

        to5 = torch.tensor([
            [0, 1, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 1, 0],
            [-3**.5 / 3, 0, 0, 0, -3**.5 / 3, 0, 0, 0, 12**.5 / 3],
            [0, 0, 1, 0, 0, 0, 1, 0, 0],
            [1, 0, 0, 0, -1, 0, 0, 0, 0]
        ], dtype=torch.get_default_dtype())
        assert all(torch.allclose(irr_repr(2, a, b, c) @ to5, to5 @ xyz3x3_repr(a, b, c)) for a, b, c in torch.rand(10, 3))

    return to1.type(torch.get_default_dtype()), to3.type(torch.get_default_dtype()), to5.type(torch.get_default_dtype())
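# A hedged usage sketch (an addition): project a flattened 3x3 tensor onto its
# l = 0, 1, 2 components with the rows returned above; the flattening order is
# assumed to match xyz3x3_repr(a, b, c), as in the asserts.
def _example_decompose_3x3():
    to1, to3, to5 = xyz3x3_to_irreducible_basis()
    t = torch.randn(3, 3).flatten()   # transforms with xyz3x3_repr(a, b, c)
    return to1 @ t, to3 @ t, to5 @ t  # 1 + 3 + 5 irreducible components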
Example #6
    def test1(self):
        with torch_default_dtype(torch.float64):
            Rs_in = [(3, 0), (3, 1), (2, 0), (1, 2)]
            Rs_out = [(3, 0), (3, 1), (1, 2), (3, 0)]

            f = GatedBlock(Rs_out, rescaled_act.Softplus(beta=5),
                           rescaled_act.sigmoid)
            c = Convolution(Kernel(Rs_in, f.Rs_in, ConstantRadialModel))

            abc = torch.randn(3)
            D_in = o3.direct_sum(
                *[o3.irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
            D_out = o3.direct_sum(*[
                o3.irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)
            ])

            x = torch.randn(1, 5, sum(mul * (2 * l + 1) for mul, l in Rs_in))
            geo = torch.randn(1, 5, 3)

            rx = torch.einsum("ij,zaj->zai", (D_in, x))
            rgeo = geo @ o3.rot(*abc).t()

            y = f(c(x, geo), dim=2)
            ry = torch.einsum("ij,zaj->zai", (D_out, y))

            self.assertLess((f(c(rx, rgeo)) - ry).norm(), 1e-10 * ry.norm())
Example #7
    def rotation_gated_block(self, K):
        """Test rotation equivariance on GatedBlock and dependencies."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            K = partial(K, RadialModel=ConstantRadialModel)

            act = GatedBlock(Rs_out,
                             scalar_activation=sigmoid,
                             gate_activation=sigmoid)
            conv = Convolution(K(Rs_in, act.Rs_in))

            abc = torch.randn(3)
            rot_geo = o3.rot(*abc)
            D_in = rs.rep(Rs_in, *abc)
            D_out = rs.rep(Rs_out, *abc)

            fea = torch.randn(1, 4, rs.dim(Rs_in))
            geo = torch.randn(1, 4, 3)

            x1 = torch.einsum("ij,zaj->zai", (D_out, act(conv(fea, geo))))
            x2 = act(
                conv(torch.einsum("ij,zaj->zai", (D_in, fea)),
                     torch.einsum("ij,zaj->zai", rot_geo, geo)))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example #8
def xyz_to_irreducible_basis(check=True):
    """
    to convert a vector [x, y, z] transforming with rot(a, b, c)
    into a vector transforming with irr_repr(1, a, b, c)
    see assert for usage
    """
    with torch_default_dtype(torch.float64):
        A = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float64)
        if check:
            assert all(torch.allclose(irr_repr(1, a, b, c) @ A, A @ rot(a, b, c)) for a, b, c in torch.rand(10, 3))
    return A.type(torch.get_default_dtype())
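# A hedged usage sketch (an addition): move a vector from the xyz basis into
# the basis that transforms with irr_repr(1, a, b, c), per the assert above.
def _example_xyz_to_irr():
    A = xyz_to_irreducible_basis()
    v = torch.randn(3)  # transforms with rot(a, b, c)
    return A @ v        # transforms with irr_repr(1, a, b, c)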
Example #9
def spherical_basis_vector_to_xyz_basis(check=True):
    """
    to convert a vector transforming with irr_repr(1, a, b, c)
    into a vector [x, y, z] transforming with rot(a, b, c)
    see assert for usage

    Inverse of xyz_vector_basis_to_spherical_basis
    """
    with torch_default_dtype(torch.float64):
        A = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]], dtype=torch.float64)
        if check:
            assert all(torch.allclose(A @ irr_repr(1, a, b, c), rot(a, b, c) @ A) for a, b, c in torch.rand(10, 3))
    return A.type(torch.get_default_dtype())
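# A hedged usage sketch (an addition): the inverse direction, mapping an
# irr_repr(1, a, b, c) vector back to xyz coordinates.
def _example_irr_to_xyz():
    A = spherical_basis_vector_to_xyz_basis()
    v = torch.randn(3)  # transforms with irr_repr(1, a, b, c)
    return A @ v        # transforms with rot(a, b, c)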
Example #10
    def parity_kernel(self, K):
        """Test parity equivariance on Kernel."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0, 1), (2, 1, 1), (2, 2, -1)]
            Rs_out = [(2, 0, -1), (2, 1, 1), (2, 2, 1)]

            k = K(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)

            D_in = rs.rep(Rs_in, 0, 0, 0, 1)
            D_out = rs.rep(Rs_out, 0, 0, 0, 1)

            W1 = D_out @ k(r)  # [i, j]
            W2 = k(-r) @ D_in  # [i, j]
            self.assertLess((W1 - W2).norm(), 10e-5 * W1.norm())
Example #11
def reduce_tensor_product(Rs_i, Rs_j):
    """
    Compute the orthonormal change of basis Q
    from Rs_reduced to Rs_i tensor product with Rs_j
    where Rs_reduced is a direct sum of irreducible representations

    :return: Rs_reduced, Q
    """
    with torch_default_dtype(torch.float64):
        Rs_i = normalizeRs(Rs_i)
        Rs_j = normalizeRs(Rs_j)

        n_i = sum(mul * (2 * l + 1) for mul, l, p in Rs_i)
        n_j = sum(mul * (2 * l + 1) for mul, l, p in Rs_j)
        out = torch.zeros(n_i, n_j, n_i * n_j, dtype=torch.float64)

        Rs_reduced = []
        beg = 0

        beg_i = 0
        for mul_i, l_i, p_i in Rs_i:
            n_i = mul_i * (2 * l_i + 1)

            beg_j = 0
            for mul_j, l_j, p_j in Rs_j:
                n_j = mul_j * (2 * l_j + 1)

                for l in range(abs(l_i - l_j), l_i + l_j + 1):
                    Rs_reduced.append((mul_i * mul_j, l, p_i * p_j))
                    n = mul_i * mul_j * (2 * l + 1)

                    # put sqrt(2l+1) to get an orthonormal output
                    Q = math.sqrt(2 * l + 1) * clebsch_gordan(
                        l_i, l_j, l)  # [m_i, m_j, m]
                    I = torch.eye(mul_i * mul_j).view(
                        mul_i, mul_j,
                        mul_i * mul_j)  # [mul_i, mul_j, mul_i * mul_j]

                    Q = torch.einsum("ijk,mno->imjnko", (I, Q))

                    view = out[beg_i:beg_i + n_i, beg_j:beg_j + n_j,
                               beg:beg + n]
                    view.add_(Q.view_as(view))

                    beg += n
                beg_j += n_j
            beg_i += n_i
        return Rs_reduced, out
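# A hedged usage sketch (an addition, assuming normalizeRs fills in a missing
# parity): reduce the product of two l=1 vectors; Rs_reduced should hold one
# copy each of l = 0, 1, 2, i.e. 9 = 1 + 3 + 5 dimensions.
def _example_reduce_two_vectors():
    Rs_reduced, Q = reduce_tensor_product([(1, 1)], [(1, 1)])  # Q: [3, 3, 9]
    x = torch.randn(3, dtype=torch.float64)
    y = torch.randn(3, dtype=torch.float64)
    return torch.einsum("ijk,i,j->k", (Q, x, y))  # reduced components of x ⊗ y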
Example #12
    def rotation_kernel(self, K):
        """Test rotation equivariance on Kernel."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            k = K(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)

            abc = torch.randn(3)
            D_in = rs.rep(Rs_in, *abc)
            D_out = rs.rep(Rs_out, *abc)

            W1 = D_out @ k(r)  # [i, j]
            W2 = k(o3.rot(*abc) @ r) @ D_in  # [i, j]
            self.assertLess((W1 - W2).norm(), 10e-5 * W1.norm())
Example #13
    def test_equivariance_s2network(self):
        with torch_default_dtype(torch.float64):
            mul = 3
            Rs_in = [(mul, l) for l in range(3 + 1)]
            Rs_out = [(mul, l) for l in range(3 + 1)]

            net = S2Network(Rs_in, mul, lmax=4, Rs_out=Rs_out)

            abc = o3.rand_angles()
            D_in = rs.rep(Rs_in, *abc)
            D_out = rs.rep(Rs_out, *abc)

            fea = torch.randn(10, rs.dim(Rs_in))

            x1 = torch.einsum("ij,zj->zi", D_out, net(fea))
            x2 = net(torch.einsum("ij,zj->zi", D_in, fea))
            self.assertLess((x1 - x2).norm(), 1e-3 * x1.norm())
Example #14
    def test2(self):
        with torch_default_dtype(torch.float64):
            mul = 100000
            for l_in in range(4):
                Rs_in = [(mul, l_in)]
                for l_out in range(4):
                    Rs_out = [(1, l_out)]

                    k = Kernel(Rs_in,
                               Rs_out,
                               ConstantRadialModel,
                               normalization='norm')
                    k = k(torch.randn(1, 3))

                    self.assertLess(k.mean().item(), 1e-3)
                    self.assertAlmostEqual(k.var().item() * mul,
                                           1 / (2 * l_out + 1),
                                           places=1)
Example #15
    def test2(self):
        """Test rotation equivariance on Kernel."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            k = Kernel(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)

            abc = torch.randn(3)
            D_in = direct_sum(
                *[irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
            D_out = direct_sum(
                *[irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)])

            W1 = D_out @ k(r)  # [i, j]
            W2 = k(rot(*abc) @ r) @ D_in  # [i, j]
            self.assertLess((W1 - W2).norm(), 10e-5 * W1.norm())
Example #16
    def parity_rotation_linear(self, L):
        """Test parity and rotation equivariance on Linear."""
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]
            Rs_out = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]

            lin = L(Rs_in, Rs_out)

            abc = torch.randn(3)
            D_in = rs.rep(lin.Rs_in, *abc, 1)
            D_out = rs.rep(lin.Rs_out, *abc, 1)

            fea = torch.randn(rs.dim(Rs_in))

            x1 = torch.einsum("ij,j->i", D_out, lin(fea))
            x2 = lin(torch.einsum("ij,j->i", D_in, fea))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example #17
    def __init__(self, lmax, res=None, normalization='component'):
        """
        :param lmax: lmax of the input signal
        :param res: resolution of the output as a tuple (beta resolution, alpha resolution)
        :param normalization: either 'norm' or 'component'
        """
        super().__init__()

        assert normalization in [
            'norm', 'component'
        ], "normalization needs to be 'norm' or 'component'"

        if isinstance(res, int):
            res_beta, res_alpha = res, res
        elif res is None:
            res_beta = 2 * (lmax + 1)
            res_alpha = 2 * res_beta
        else:
            res_beta, res_alpha = res
        del res
        assert res_beta % 2 == 0
        assert res_beta >= 2 * (lmax + 1)

        alphas, betas, sha, shb = spherical_harmonics_s2_grid(
            lmax, res_alpha, res_beta)

        with torch_default_dtype(torch.float64):
            # normalize such that all l has the same variance on the sphere
            if normalization == 'component':
                n = math.sqrt(4 * math.pi) * torch.tensor(
                    [1 / math.sqrt(2 * l + 1)
                     for l in range(lmax + 1)]) / math.sqrt(lmax + 1)
            if normalization == 'norm':
                n = math.sqrt(
                    4 * math.pi) * torch.ones(lmax + 1) / math.sqrt(lmax + 1)
            m = rsh.spherical_harmonics_expand_matrix(lmax)  # [l, m, i]
        shb = torch.einsum('lmj,bj,lmi,l->mbi', m, shb, m, n)  # [m, b, i]

        self.register_buffer('alphas', alphas)
        self.register_buffer('betas', betas)
        self.register_buffer('sha', sha)
        self.register_buffer('shb', shb)
        self.to(torch.get_default_dtype())
Example #18
    def test_equivariance_gatedconvnetwork(self):
        with torch_default_dtype(torch.float64):
            mul = 3
            Rs_in = [(mul, l) for l in range(3 + 1)]
            Rs_out = [(mul, l) for l in range(3 + 1)]

            net = GatedConvNetwork(Rs_in, [(10, 0), (1, 1), (1, 2), (1, 3)],
                                   Rs_out)

            abc = torch.randn(3)
            rot_geo = o3.rot(*abc)
            D_in = rs.rep(Rs_in, *abc)
            D_out = rs.rep(Rs_out, *abc)

            fea = torch.randn(1, 10, rs.dim(Rs_in))
            geo = torch.randn(1, 10, 3)

            x1 = torch.einsum("ij,zaj->zai", D_out, net(fea, geo))
            x2 = net(torch.einsum("ij,zaj->zai", D_in, fea),
                     torch.einsum("ij,zaj->zai", rot_geo, geo))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example #19
    def test4(self):
        """Test parity equivariance on Kernel."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0, 1), (2, 1, 1), (2, 2, -1)]
            Rs_out = [(2, 0, -1), (2, 1, 1), (2, 2, 1)]

            k = Kernel(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)

            D_in = direct_sum(*[
                p * torch.eye(2 * l + 1) for mul, l, p in Rs_in
                for _ in range(mul)
            ])
            D_out = direct_sum(*[
                p * torch.eye(2 * l + 1) for mul, l, p in Rs_out
                for _ in range(mul)
            ])

            W1 = D_out @ k(r)  # [i, j]
            W2 = k(-r) @ D_in  # [i, j]
            self.assertLess((W1 - W2).norm(), 10e-5 * W1.norm())
Example #20
    def test6(self):
        """Test parity and rotation equivariance on GatedBlockParity and dependencies."""
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

            K = partial(Kernel, RadialModel=ConstantRadialModel)

            scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu),
                                                     (mul, absolute)]
            rs_nonscalars = [(mul, 1, +1), (mul, 1, -1), (mul, 2, +1),
                             (mul, 2, -1), (mul, 3, +1), (mul, 3, -1)]
            n = 3 * mul
            gates = [(n, 0, +1), (n, 0, -1)], [(n, sigmoid), (n, tanh)]

            act = GatedBlockParity(*scalars, *gates, rs_nonscalars)
            conv = Convolution(K, Rs_in, act.Rs_in)

            abc = torch.randn(3)
            rot_geo = -o3.rot(*abc)
            D_in = o3.direct_sum(*[
                p * o3.irr_repr(l, *abc) for mul, l, p in Rs_in
                for _ in range(mul)
            ])
            D_out = o3.direct_sum(*[
                p * o3.irr_repr(l, *abc) for mul, l, p in act.Rs_out
                for _ in range(mul)
            ])

            fea = torch.randn(1, 4,
                              sum(mul * (2 * l + 1) for mul, l, p in Rs_in))
            geo = torch.randn(1, 4, 3)

            x1 = torch.einsum("ij,zaj->zai", (D_out, act(conv(fea, geo))))
            x2 = act(
                conv(torch.einsum("ij,zaj->zai", (D_in, fea)),
                     torch.einsum("ij,zaj->zai", rot_geo, geo)))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example #21
    def test_irr_repr_wigner_3j(self):
        """Test irr_repr and wigner_3j equivariance."""
        with torch_default_dtype(torch.float64):
            l_in = 3
            l_out = 2

            for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
                r = torch.randn(100, 3)
                Q = o3.wigner_3j(l_out, l_in, l_f)

                abc = torch.randn(3)
                D_in = o3.irr_repr(l_in, *abc)
                D_out = o3.irr_repr(l_out, *abc)

                Y = rsh.spherical_harmonics_xyz([l_f], r @ o3.rot(*abc).t())
                W = torch.einsum("ijk,zk->zij", (Q, Y))
                W1 = torch.einsum("zij,jk->zik", (W, D_in))

                Y = rsh.spherical_harmonics_xyz([l_f], r)
                W = torch.einsum("ijk,zk->zij", (Q, Y))
                W2 = torch.einsum("ij,zjk->zik", (D_out, W))

                self.assertLess((W1 - W2).norm(), 1e-5 * W.norm(), l_f)
Example #22
    def test1(self):
        """Test irr_repr and clebsch_gordan equivariance."""
        with torch_default_dtype(torch.float64):
            l_in = 3
            l_out = 2

            for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
                r = torch.randn(100, 3)
                Q = clebsch_gordan(l_out, l_in, l_f)

                abc = torch.randn(3)
                D_in = irr_repr(l_in, *abc)
                D_out = irr_repr(l_out, *abc)

                Y = spherical_harmonics_xyz(l_f, r @ rot(*abc).t())
                W = torch.einsum("ijk,kz->zij", (Q, Y))
                W1 = torch.einsum("zij,jk->zik", (W, D_in))

                Y = spherical_harmonics_xyz(l_f, r)
                W = torch.einsum("ijk,kz->zij", (Q, Y))
                W2 = torch.einsum("ij,zjk->zik", (D_out, W))

                self.assertLess((W1 - W2).norm(), 1e-5 * W.norm(), l_f)
Example #23
    def test_basis_equivariance(self):
        with torch_default_dtype(torch.float64):
            basis = cube_basis_kernels(4 * 5, 2, 2, partial(gaussian_window, radii=[5], J_max_list=[999], sigma=2))
            overlaps = check_basis_equivariance(basis, 2, 2, *torch.rand(3))
            self.assertTrue(overlaps.gt(0.98).all(), overlaps)
Example #24
def spherical_harmonics_xyz(order,
                            xyz,
                            sph_last=False,
                            dtype=None,
                            device=None):
    """
    spherical harmonics

    :param order: int or list
    :param xyz: tensor of shape [..., 3]
    :param sph_last: return the spherical harmonics in the last channel
    :param dtype:
    :param device:
    :return: tensor of shape [m, ...] (or [..., m] if sph_last)
    """
    try:
        order = list(order)
    except TypeError:
        order = [order]

    if dtype is None and torch.is_tensor(xyz):
        dtype = xyz.dtype
    if dtype is None:
        dtype = torch.get_default_dtype()

    if device is None and torch.is_tensor(xyz):
        device = xyz.device

    if not torch.is_tensor(xyz):
        xyz = torch.tensor(xyz, dtype=torch.float64)
        if device is None:
            device = xyz.device  # device was never set when xyz was not a tensor

    with torch_default_dtype(torch.float64):
        if device.type == 'cuda' and max(order) <= 10:
            max_l = max(order)
            out = xyz.new_empty(((max_l + 1) * (max_l + 1),
                                 xyz.size(0)))  # [ filters, batch_size]
            xyz_unit = torch.nn.functional.normalize(xyz, p=2, dim=-1)
            real_spherical_harmonics.rsh(out, xyz_unit)
            # (-1)^L same as (pi-theta) -> (-1)^(L+m) and 'quantum' norm (-1)^m combined  # h - halved
            norm_coef = [
                elem for lh in range((max_l + 1) // 2)
                for elem in [1.] * (4 * lh + 1) + [-1.] * (4 * lh + 3)
            ]
            if max_l % 2 == 0:
                norm_coef.extend([1.] * (2 * max_l + 1))
            norm_coef = torch.tensor(norm_coef, device=device).unsqueeze(1)
            out.mul_(norm_coef)
            if order != list(range(max_l + 1)):
                keep_rows = torch.zeros(out.size(0), dtype=torch.bool)
                for l in order:
                    keep_rows[(l * l):((l + 1) * (l + 1))].fill_(True)
                out = out[keep_rows.to(device)]
        else:
            alpha, beta = xyz_to_angles(xyz)  # two tensors of shape [...]
            out = spherical_harmonics(order, alpha, beta)  # [m, ...]

            # fix values when xyz = 0
            val = xyz.new_tensor([1 / math.sqrt(4 * math.pi)])
            val = torch.cat([
                val if l == 0 else xyz.new_zeros(2 * l + 1) for l in order
            ])  # [m]
            out[:, xyz.norm(2, -1) == 0] = val.view(-1, 1)

        if sph_last:
            rank = len(out.shape)
            return out.to(dtype=dtype,
                          device=device).permute(*range(1, rank),
                                                 0).contiguous()
        else:
            return out.to(dtype=dtype, device=device)
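# A hedged usage sketch (an addition): evaluate Y_0 and Y_1 on a batch of CPU
# points; the 1 + 3 = 4 components stack first unless sph_last=True.
def _example_spherical_harmonics():
    pts = torch.randn(5, 3)
    Y = spherical_harmonics_xyz([0, 1], pts)                      # [4, 5]
    Y_last = spherical_harmonics_xyz([0, 1], pts, sph_last=True)  # [5, 4]
    return Y, Y_last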
Example #25
def reduce_tensor(formula, eps=1e-9, has_parity=None, **kw_Rs):
    """
    Usage
    Rs, Q = rs.reduce_tensor('ijkl=jikl=ikjl=ijlk', i=[(1, 1)])
    Rs = 0,2,4
    Q = tensor of shape [15, 81]
    """
    with torch_default_dtype(torch.float64):
        formulas = [(-1 if f.startswith('-') else 1, f.replace('-', ''))
                    for f in formula.split('=')]
        s0, f0 = formulas[0]
        assert s0 == 1

        for _s, f in formulas:
            if len(set(f)) != len(f) or set(f) != set(f0):
                raise RuntimeError(f'{f} is not a permutation of {f0}')
            if len(f0) != len(f):
                raise RuntimeError(
                    f'{f0} and {f} don\'t have the same number of indices')

        formulas = {(s, tuple(f.index(i) for i in f0))
                    for s, f in formulas}  # set of generators (permutations)

        # create the entire group
        while True:
            n = len(formulas)
            formulas = formulas.union([(s, perm.inverse(p))
                                       for s, p in formulas])
            formulas = formulas.union([(s1 * s2, perm.compose(p1, p2))
                                       for s1, p1 in formulas
                                       for s2, p2 in formulas])
            if len(formulas) == n:
                break

        for i in kw_Rs:
            if not callable(kw_Rs[i]):
                Rs = convention(kw_Rs[i])
                if has_parity is None:
                    has_parity = any(p != 0 for _, _, p in Rs)
                if not has_parity and not all(p == 0 for _, _, p in Rs):
                    raise RuntimeError(
                        f'{format_Rs(Rs)} parity has to be specified everywhere or nowhere'
                    )
                if has_parity and any(p == 0 for _, _, p in Rs):
                    raise RuntimeError(
                        f'{format_Rs(Rs)} parity has to be specified everywhere or nowhere'
                    )
                kw_Rs[i] = Rs

        if has_parity is None:
            raise RuntimeError('please specify the argument `has_parity`')

        for _s, p in formulas:
            f = "".join(f0[i] for i in p)
            for i, j in zip(f0, f):
                if i in kw_Rs and j in kw_Rs and kw_Rs[i] != kw_Rs[j]:
                    raise RuntimeError(
                        f'Rs of {i} (Rs={format_Rs(kw_Rs[i])}) and {j} (Rs={format_Rs(kw_Rs[j])}) should be the same'
                    )
                if i in kw_Rs:
                    kw_Rs[j] = kw_Rs[i]
                if j in kw_Rs:
                    kw_Rs[i] = kw_Rs[j]

        for i in f0:
            if i not in kw_Rs:
                raise RuntimeError(f'index {i} has no Rs associated to it')

        e = (0, 0, 0, 0) if has_parity else (0, 0, 0)
        full_base = list(
            itertools.product(*(range(
                len(kw_Rs[i](*e)) if callable(kw_Rs[i]) else dim(kw_Rs[i]))
                                for i in f0)))

        base = set()
        for x in full_base:
            xs = {(s, tuple(x[i] for i in p)) for s, p in formulas}
            # s * T[x] all equal for (s, x) in xs
            if not (-1, x) in xs:
                # the sign is arbitrary, put both possibilities
                base.add(
                    frozenset(
                        {frozenset(xs),
                         frozenset({(-s, x)
                                    for s, x in xs})}))

        base = sorted([
            sorted([sorted(xs) for xs in x]) for x in base
        ])  # required for python 3.7 but not for 3.8 (probably a bug in 3.7)

        d_sym = len(base)
        d = len(full_base)
        Q = torch.zeros(d_sym, d)

        for i, x in enumerate(base):
            x = max(x, key=lambda xs: sum(s for s, x in xs))
            for s, e in x:
                j = full_base.index(e)
                Q[i, j] = s / len(x)**0.5

        assert torch.allclose(Q @ Q.T, torch.eye(d_sym))

        if d_sym == 0:
            return [], torch.zeros(d_sym, d)

        def representation(alpha, beta, gamma, parity=None):
            def re(r):
                if callable(r):
                    if has_parity:
                        return r(alpha, beta, gamma, parity)
                    return r(alpha, beta, gamma)
                return rep(r, alpha, beta, gamma, parity)

            m = o3.kron(*(re(kw_Rs[i]) for i in f0))
            return Q @ m @ Q.T

        assert _is_representation(representation, eps, has_parity)

        Rs_out = []
        A = Q.clone()
        for l in range(int((d_sym - 1) // 2) + 1):
            for p in [-1, 1] if has_parity else [0]:
                if 2 * l + 1 > d_sym - dim(Rs_out):
                    break

                mul, B, representation = o3.reduce(representation,
                                                   partial(rep, [(1, l, p)]),
                                                   eps, has_parity)
                A = o3.direct_sum(torch.eye(d_sym - B.shape[0]), B) @ A
                A = _round_sqrt(A, eps)
                Rs_out += [(mul, l, p)]

                if dim(Rs_out) == d_sym:
                    break

        if dim(Rs_out) != d_sym:
            raise RuntimeError(
                'unable to decompose into irreducible representations')
        return simplify(Rs_out), A
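# A hedged sketch mirroring the docstring usage (an addition): a symmetric
# rank-2 tensor built from l=1 indices reduces to l = 0 plus l = 2, so Rs
# should span 6 = 1 + 5 dimensions and Q should have shape [6, 9].
def _example_symmetric_matrix():
    return reduce_tensor('ij=ji', i=[(1, 1)])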