Example #1
    def __init__(self, n, lmax):
        super().__init__()
        self.n = n
        self.lmax = lmax

        R = o3.rot(math.pi / 2, math.pi / 2, math.pi / 2)
        self.xyz1, self.proj1 = self.precompute(R)

        R = o3.rot(0, 0, 0)
        self.xyz2, self.proj2 = self.precompute(R)
Example #2
    def find_peaks(self, res=100):
        x1, f1 = self.signal_on_grid(res)

        abc = pi / 2, pi / 2, pi / 2
        R = o3.rot(*abc)
        D = rs.rep(self.Rs, *abc)

        rtensor = SphericalTensor(D @ self.signal)
        rx2, f2 = rtensor.signal_on_grid(res)
        x2 = torch.einsum('ij,baj->bai', R.T, rx2)

        ij = _find_peaks_2d(f1)
        x1p = torch.stack([x1[i, j] for i, j in ij])
        f1p = torch.stack([f1[i, j] for i, j in ij])

        ij = _find_peaks_2d(f2)
        x2p = torch.stack([x2[i, j] for i, j in ij])
        f2p = torch.stack([f2[i, j] for i, j in ij])

        # Union of the results
        mask = torch.cdist(x1p, x2p) < 2 * pi / res
        x = torch.cat([x1p[mask.sum(1) == 0], x2p])
        f = torch.cat([f1p[mask.sum(1) == 0], f2p])

        return x, f
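
A note on the rotation trick above: under the usual convention that rotating the coefficient vector with the Wigner matrix D rotates the signal, i.e. (D(R) f)(x) = f(R^{-1} x) (an assumption here, since SphericalTensor.signal_on_grid is not shown in this snippet), the second pass simply samples the original signal on a rotated grid:

    f_2[a,b] \;=\; (D(R)\,f)\big(x_{\mathrm{grid}}[a,b]\big) \;=\; f\big(R^{-1} x_{\mathrm{grid}}[a,b]\big) \;=\; f\big(x_2[a,b]\big),
    \qquad x_2[a,b] = R^{\top} x_{\mathrm{grid}}[a,b],

so taking the union of the two peak sets (deduplicated with cdist at the grid spacing 2*pi/res) recovers peaks that fall near the poles or between the samples of the first grid.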
Example #3
    def test1(self):
        with torch_default_dtype(torch.float64):
            Rs_in = [(3, 0), (3, 1), (2, 0), (1, 2)]
            Rs_out = [(3, 0), (3, 1), (1, 2), (3, 0)]

            f = GatedBlock(Rs_out, rescaled_act.Softplus(beta=5),
                           rescaled_act.sigmoid)
            c = Convolution(Kernel(Rs_in, f.Rs_in, ConstantRadialModel))

            abc = torch.randn(3)
            D_in = o3.direct_sum(*[
                o3.irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)
            ])
            D_out = o3.direct_sum(*[
                o3.irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)
            ])

            x = torch.randn(1, 5, sum(mul * (2 * l + 1) for mul, l in Rs_in))
            geo = torch.randn(1, 5, 3)

            rx = torch.einsum("ij,zaj->zai", (D_in, x))
            rgeo = geo @ o3.rot(*abc).t()

            y = f(c(x, geo), dim=2)
            ry = torch.einsum("ij,zaj->zai", (D_out, y))

            self.assertLess((f(c(rx, rgeo)) - ry).norm(), 1e-10 * ry.norm())
Example #4
def test_equivariance_wtp(Rs_in, Rs_out, n_source, n_target, n_edge):
    torch.set_default_dtype(torch.float64)

    mp = WTPConv(Rs_in, Rs_out, 3, ConstantRadialModel)

    features = rs.randn(n_target, Rs_in)

    edge_index = torch.stack([
        torch.randint(n_source, size=(n_edge, )),
        torch.randint(n_target, size=(n_edge, )),
    ])
    size = (n_target, n_source)

    edge_r = torch.randn(n_edge, 3)
    if n_edge > 1:
        edge_r[0] = 0

    out1 = mp(features, edge_index, edge_r, size=size)

    angles = o3.rand_angles()
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)
    R = o3.rot(*angles)

    out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out

    assert (out1 - out2).abs().max() < 1e-10
Example #5
def check_rotation(batch: int = 10, n_atoms: int = 25):
    # Set up the network.
    K = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(1, 0), (1, 1)]
    Rs_out = [(1, 0), (1, 1), (1, 2)]
    act = GatedBlock(
        Rs_out,
        scalar_activation=sigmoid,
        gate_activation=absolute,
    )
    conv = Convolution(K, Rs_in, act.Rs_in)

    # Set up the data. The geometry, the input features, and the output features must all rotate.
    abc = torch.randn(3)  # Random Euler angles defining the rotation.
    rot_geo = o3.rot(*abc)
    D_in = rs.rep(Rs_in, *abc)
    D_out = rs.rep(Rs_out, *abc)

    feat = torch.randn(batch, n_atoms, rs.dim(Rs_in))  # Transforms with the Wigner D matrix.
    geo = torch.randn(batch, n_atoms, 3)  # Transforms with the rotation matrix.

    # Test equivariance.
    F = act(conv(feat, geo))
    RF = torch.einsum("ij,zkj->zki", D_out, F)
    FR = act(conv(feat @ D_in.t(), geo @ rot_geo.t()))
    return (RF - FR).norm() < 10e-5 * RF.norm()
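
Example #5 above, and most of the tests below, follow the same recipe: run the network, rotate the inputs with the Wigner D matrix from rs.rep and the rotation matrix from o3.rot, run again, and compare. A minimal standalone sketch of that recipe, using only the o3/rs helpers that already appear in these snippets; the name assert_equivariant, the batch shape, and the tolerance are mine, not part of any of the examples:

def assert_equivariant(net, Rs_in, Rs_out, n_points=4, tol=1e-10):
    # Features transform with the Wigner D matrix of Rs_in; geometry transforms with R.
    features = rs.randn(1, n_points, Rs_in)
    geometry = torch.randn(1, n_points, 3)

    angles = o3.rand_angles()
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)
    R = o3.rot(*angles)

    # Rotate-then-apply must equal apply-then-rotate.
    out1 = torch.einsum("ij,zaj->zai", D_out, net(features, geometry))
    out2 = net(torch.einsum("ij,zaj->zai", D_in, features), geometry @ R.T)
    assert (out1 - out2).abs().max() < tol * out1.abs().max()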
Example #6
def check_rotation_parity(batch: int = 10, n_atoms: int = 25):
    # Set up the network.
    K = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(1, 0, +1)]
    act = GatedBlockParity(
        Rs_scalars=[(4, 0, +1)],
        act_scalars=[(-1, relu)],
        Rs_gates=[(8, 0, +1)],
        act_gates=[(-1, tanh)],
        Rs_nonscalars=[(4, 1, -1), (4, 2, +1)]
    )
    conv = Convolution(K, Rs_in, act.Rs_in)
    Rs_out = act.Rs_out

    # Set up the data. The geometry, the input features, and the output features must all rotate and respect parity.
    abc = torch.randn(3)  # Random Euler angles defining the rotation.
    rot_geo = -o3.rot(*abc)  # Negated because the geometry has odd parity, i.e. this is an improper rotation.
    D_in = rs.rep(Rs_in, *abc, parity=1)
    D_out = rs.rep(Rs_out, *abc, parity=1)

    feat = torch.randn(batch, n_atoms, rs.dim(Rs_in))  # Transforms with the Wigner D matrix and parity.
    geo = torch.randn(batch, n_atoms, 3)  # Transforms with the rotation matrix and parity.

    # Test equivariance.
    F = act(conv(feat, geo))
    RF = torch.einsum("ij,zkj->zki", D_out, F)
    FR = act(conv(feat @ D_in.t(), geo @ rot_geo.t()))
    return (RF - FR).norm() < 10e-5 * RF.norm()
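
For reference, the parity-aware condition checked in Example #6 can be stated compactly; this is only a restatement of the code above, with \sigma denoting the parity of the transformation and rs.rep(..., parity=1) assumed to insert each irrep's parity factor:

    D_{\mathrm{out}}(R, \sigma)\, f(x, r) \;=\; f\big(D_{\mathrm{in}}(R, \sigma)\, x,\; \sigma R\, r\big), \qquad \sigma = -1,

which is why the geometry is transformed with the improper rotation -o3.rot(*abc) while the features are transformed with the parity-weighted Wigner D matrices.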
Example #7
    def test3(self):
        """Test rotation equivariance on GatedBlock and dependencies."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            K = partial(Kernel, RadialModel=ConstantRadialModel)

            act = GatedBlock(Rs_out,
                             scalar_activation=sigmoid,
                             gate_activation=sigmoid)
            conv = Convolution(K, Rs_in, act.Rs_in)

            abc = torch.randn(3)
            rot_geo = o3.rot(*abc)
            D_in = o3.direct_sum(*[
                o3.irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)
            ])
            D_out = o3.direct_sum(*[
                o3.irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)
            ])

            fea = torch.randn(1, 4, sum(mul * (2 * l + 1) for mul, l in Rs_in))
            geo = torch.randn(1, 4, 3)

            x1 = torch.einsum("ij,zaj->zai", (D_out, act(conv(fea, geo))))
            x2 = act(
                conv(torch.einsum("ij,zaj->zai", (D_in, fea)),
                     torch.einsum("ij,zaj->zai", rot_geo, geo)))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example #8
    def rotation_gated_block(self, K):
        """Test rotation equivariance on GatedBlock and dependencies."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            K = partial(K, RadialModel=ConstantRadialModel)

            act = GatedBlock(Rs_out,
                             scalar_activation=sigmoid,
                             gate_activation=sigmoid)
            conv = Convolution(K(Rs_in, act.Rs_in))

            abc = torch.randn(3)
            rot_geo = o3.rot(*abc)
            D_in = rs.rep(Rs_in, *abc)
            D_out = rs.rep(Rs_out, *abc)

            fea = torch.randn(1, 4, rs.dim(Rs_in))
            geo = torch.randn(1, 4, 3)

            x1 = torch.einsum("ij,zaj->zai", (D_out, act(conv(fea, geo))))
            x2 = act(
                conv(torch.einsum("ij,zaj->zai", (D_in, fea)),
                     torch.einsum("ij,zaj->zai", rot_geo, geo)))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example #9
    def test_rot_to_abc(self):
        with o3.torch_default_dtype(torch.float64):
            R = o3.rand_rot()
            abc = o3.rot_to_abc(R)
            R2 = o3.rot(*abc)
            d = (R - R2).norm() / R.norm()
            self.assertTrue(d < 1e-10, d)
Example #10
    def parity_rotation_gated_block_parity(self, K):
        """Test parity and rotation equivariance on GatedBlockParity and dependencies."""
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]

            K = partial(K, RadialModel=ConstantRadialModel)

            scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu),
                                                     (mul, absolute)]
            rs_nonscalars = [(mul, 1, +1), (mul, 1, -1), (mul, 2, +1),
                             (mul, 2, -1), (mul, 3, +1), (mul, 3, -1)]
            n = 3 * mul
            gates = [(n, 0, +1), (n, 0, -1)], [(n, sigmoid), (n, tanh)]

            act = GatedBlockParity(*scalars, *gates, rs_nonscalars)
            conv = Convolution(K(Rs_in, act.Rs_in))

            abc = torch.randn(3)
            rot_geo = -o3.rot(*abc)
            D_in = rs.rep(Rs_in, *abc, 1)
            D_out = rs.rep(act.Rs_out, *abc, 1)

            fea = torch.randn(1, 4, rs.dim(Rs_in))
            geo = torch.randn(1, 4, 3)

            x1 = torch.einsum("ij,zaj->zai", (D_out, act(conv(fea, geo))))
            x2 = act(
                conv(torch.einsum("ij,zaj->zai", (D_in, fea)),
                     torch.einsum("ij,zaj->zai", rot_geo, geo)))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example #11
def random_rotate_translate(positions, rotation=True, translation=1):
    # Rejection-sample a random translation inside the unit ball.
    while True:
        trans = torch.rand(3) * 2 - 1
        if trans.norm() <= 1:
            break
    # Random rotation from uniform Euler angles in [0, 2*pi); identity if rotation is disabled.
    rot = o3.rot(*torch.rand(3) * 6.2832).type(torch.float32) if rotation else torch.eye(3)
    return [rot @ pos + translation * trans for pos in positions]
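
A small usage sketch for the helper above; the shapes are an assumption (each entry of positions is taken to be a single 3-vector so that rot @ pos is well defined), and the explicit float32 dtype simply matches the cast inside the helper:

positions = [torch.randn(3, dtype=torch.float32) for _ in range(5)]
moved = random_rotate_translate(positions, rotation=True, translation=0.5)
assert all(p.shape == (3,) for p in moved)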
Example #12
def test_equivariance(Rs_in, Rs_out, n_source, n_target, n_edge):
    torch.set_default_dtype(torch.float64)

    mp = Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))
    groups = 4
    mp_group = Convolution(
        GroupKernel(Rs_in, Rs_out,
                    partial(Kernel, RadialModel=ConstantRadialModel), groups))

    features = rs.randn(n_target, Rs_in)
    features2 = rs.randn(n_target, Rs_in * groups)

    r_source = torch.randn(n_source, 3)
    r_target = torch.randn(n_target, 3)

    edge_index = torch.stack([
        torch.randint(n_source, size=(n_edge, )),
        torch.randint(n_target, size=(n_edge, )),
    ])
    size = (n_target, n_source)

    if n_edge == 0:
        edge_r = torch.zeros(0, 3)
    else:
        edge_r = torch.stack(
            [r_target[j] - r_source[i] for i, j in edge_index.T])
    print(features.shape, edge_index.shape, edge_r.shape, size)
    out1 = mp(features, edge_index, edge_r, size=size)
    out1_groups = mp(features2, edge_index, edge_r, size=size, groups=groups)
    out1_kernel_groups = mp_group(features2,
                                  edge_index,
                                  edge_r,
                                  size=size,
                                  groups=groups)

    angles = o3.rand_angles()
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)
    D_in_groups = rs.rep(Rs_in * groups, *angles)
    D_out_groups = rs.rep(Rs_out * groups, *angles)
    R = o3.rot(*angles)

    out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out
    out2_groups = mp(features2 @ D_in_groups.T,
                     edge_index,
                     edge_r @ R.T,
                     size=size,
                     groups=groups) @ D_out_groups
    out2_kernel_groups = mp_group(features2 @ D_in_groups.T,
                                  edge_index,
                                  edge_r @ R.T,
                                  size=size,
                                  groups=groups) @ D_out_groups

    assert (out1 - out2).abs().max() < 1e-10
    assert (out1_groups - out2_groups).abs().max() < 1e-10
    assert (out1_kernel_groups - out2_kernel_groups).abs().max() < 1e-10
Example #13
    def test_xyz_vector_basis_to_spherical_basis(self):
        with o3.torch_default_dtype(torch.float64):
            A = o3.xyz_vector_basis_to_spherical_basis()

            a, b, c = torch.rand(3)

            r1 = A.t() @ o3.irr_repr(1, a, b, c) @ A
            r2 = o3.rot(a, b, c)

            self.assertLess((r1 - r2).abs().max(), 1e-10)
Example #14
    def test_xyz_to_irreducible_basis(self):
        with o3.torch_default_dtype(torch.float64):
            A = o3.xyz_to_irreducible_basis()

            a, b, c = torch.rand(3)

            r1 = A.t() @ o3.irr_repr(1, a, b, c) @ A
            r2 = o3.rot(a, b, c)

            assert torch.allclose(r1, r2)
Example #15
def test_equivariance():
    torch.set_default_dtype(torch.float64)

    n_edge = 3
    n_source = 4
    n_target = 2

    Rs_in = [(3, 0), (0, 1)]
    Rs_mid1 = [(5, 0), (1, 1)]
    Rs_mid2 = [(5, 0), (1, 1), (1, 2)]
    Rs_out = [(5, 1), (3, 2)]

    groups = 4
    convolution = lambda Rs_in, Rs_out: Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))
    convolution_groups = lambda Rs_in, Rs_out: Convolution(
        GroupKernel(Rs_in, Rs_out, partial(Kernel, RadialModel=ConstantRadialModel), groups))
    mp = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution)
    mp_groups = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution_groups)

    features = rs.randn(n_target, Rs_in)

    r_source = torch.randn(n_source, 3)
    r_target = torch.randn(n_target, 3)

    edge_index = torch.stack([
        torch.randint(n_source, size=(n_edge,)),
        torch.randint(n_target, size=(n_edge,)),
    ])
    size = (n_target, n_source)

    if n_edge == 0:
        edge_r = torch.zeros(0, 3)
    else:
        edge_r = torch.stack([
            r_target[j] - r_source[i]
            for i, j in edge_index.T
        ])
    print(features.shape, edge_index.shape, edge_r.shape, size)
    out1 = mp(features, edge_index, edge_r, size=size)
    out1_groups = mp_groups(features, edge_index, edge_r, size=size)

    angles = o3.rand_angles()
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)
    R = o3.rot(*angles)

    out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out
    out2_groups = mp_groups(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out

    assert (out1 - out2).abs().max() < 1e-10
    assert (out1_groups - out2_groups).abs().max() < 1e-10
Example #16
    def rotation_kernel(self, K):
        """Test rotation equivariance on Kernel."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            k = K(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)

            abc = torch.randn(3)
            D_in = rs.rep(Rs_in, *abc)
            D_out = rs.rep(Rs_out, *abc)

            W1 = D_out @ k(r)  # [i, j]
            W2 = k(o3.rot(*abc) @ r) @ D_in  # [i, j]
            self.assertLess((W1 - W2).norm(), 10e-5 * W1.norm())
Example #17
def test_s2conv_network():
    torch.set_default_dtype(torch.float64)

    lmax = 3
    Rs = [(1, l, 1) for l in range(lmax + 1)]
    model = S2ConvNetwork(Rs, 4, Rs, lmax)

    features = rs.randn(1, 4, Rs)
    geometry = torch.randn(1, 4, 3)

    output = model(features, geometry)

    angles = o3.rand_angles()
    D = rs.rep(Rs, *angles, 1)
    R = -o3.rot(*angles)
    ein = torch.einsum
    output2 = ein('ij,zaj->zai', D.T, model(ein('ij,zaj->zai', D, features), ein('ij,zaj->zai', R, geometry)))

    assert (output - output2).abs().max() < 1e-10 * output.abs().max()
Example #18
    def test2(self):
        """Test rotation equivariance on Kernel."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            k = Kernel(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)

            abc = torch.randn(3)
            D_in = o3.direct_sum(*[
                o3.irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)
            ])
            D_out = o3.direct_sum(*[
                o3.irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)
            ])

            W1 = D_out @ k(r)  # [i, j]
            W2 = k(o3.rot(*abc) @ r) @ D_in  # [i, j]
            self.assertLess((W1 - W2).norm(), 10e-5 * W1.norm())
Example #19
    def test_equivariance_gatedconvnetwork(self):
        with torch_default_dtype(torch.float64):
            mul = 3
            Rs_in = [(mul, l) for l in range(3 + 1)]
            Rs_out = [(mul, l) for l in range(3 + 1)]

            net = GatedConvNetwork(Rs_in, [(10, 0), (1, 1), (1, 2), (1, 3)],
                                   Rs_out)

            abc = torch.randn(3)
            rot_geo = o3.rot(*abc)
            D_in = rs.rep(Rs_in, *abc)
            D_out = rs.rep(Rs_out, *abc)

            fea = torch.randn(1, 10, rs.dim(Rs_in))
            geo = torch.randn(1, 10, 3)

            x1 = torch.einsum("ij,zaj->zai", D_out, net(fea, geo))
            x2 = net(torch.einsum("ij,zaj->zai", D_in, fea),
                     torch.einsum("ij,zaj->zai", rot_geo, geo))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example #20
    def test6(self):
        """Test parity and rotation equivariance on GatedBlockParity and dependencies."""
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

            K = partial(Kernel, RadialModel=ConstantRadialModel)

            scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu),
                                                     (mul, absolute)]
            rs_nonscalars = [(mul, 1, +1), (mul, 1, -1), (mul, 2, +1),
                             (mul, 2, -1), (mul, 3, +1), (mul, 3, -1)]
            n = 3 * mul
            gates = [(n, 0, +1), (n, 0, -1)], [(n, sigmoid), (n, tanh)]

            act = GatedBlockParity(*scalars, *gates, rs_nonscalars)
            conv = Convolution(K, Rs_in, act.Rs_in)

            abc = torch.randn(3)
            rot_geo = -o3.rot(*abc)
            D_in = o3.direct_sum(*[
                p * o3.irr_repr(l, *abc) for mul, l, p in Rs_in
                for _ in range(mul)
            ])
            D_out = o3.direct_sum(*[
                p * o3.irr_repr(l, *abc) for mul, l, p in act.Rs_out
                for _ in range(mul)
            ])

            fea = torch.randn(1, 4,
                              sum(mul * (2 * l + 1) for mul, l, p in Rs_in))
            geo = torch.randn(1, 4, 3)

            x1 = torch.einsum("ij,zaj->zai", (D_out, act(conv(fea, geo))))
            x2 = act(
                conv(torch.einsum("ij,zaj->zai", (D_in, fea)),
                     torch.einsum("ij,zaj->zai", rot_geo, geo)))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Example #21
    def test_irr_repr_wigner_3j(self):
        """Test irr_repr and wigner_3j equivariance."""
        with torch_default_dtype(torch.float64):
            l_in = 3
            l_out = 2

            for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
                r = torch.randn(100, 3)
                Q = o3.wigner_3j(l_out, l_in, l_f)

                abc = torch.randn(3)
                D_in = o3.irr_repr(l_in, *abc)
                D_out = o3.irr_repr(l_out, *abc)

                Y = rsh.spherical_harmonics_xyz([l_f], r @ o3.rot(*abc).t())
                W = torch.einsum("ijk,zk->zij", (Q, Y))
                W1 = torch.einsum("zij,jk->zik", (W, D_in))

                Y = rsh.spherical_harmonics_xyz([l_f], r)
                W = torch.einsum("ijk,zk->zij", (Q, Y))
                W2 = torch.einsum("ij,zjk->zik", (D_out, W))

                self.assertLess((W1 - W2).norm(), 1e-5 * W.norm(), l_f)
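
The identity behind this last test combines two standard facts, restated here for reference rather than taken from the snippet: the Wigner 3j tensor is invariant under a simultaneous rotation of its three indices, and the (real) spherical harmonics are equivariant:

    \sum_{i'j'k'} D^{(l_{\mathrm{out}})}_{ii'}\, D^{(l_{\mathrm{in}})}_{jj'}\, D^{(l_f)}_{kk'}\, Q_{i'j'k'} \;=\; Q_{ijk},
    \qquad
    Y^{(l_f)}(R\,r) \;=\; D^{(l_f)}(R)\, Y^{(l_f)}(r).

Writing W(r)_{ij} = \sum_k Q_{ijk}\, Y^{(l_f)}_k(r) and using the orthogonality of the Wigner matrices, these give W(Rr)\, D_{\mathrm{in}} = D_{\mathrm{out}}\, W(r), which is exactly the W1 = W2 comparison made in the loop.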