Exemplo n.º 1
0
    def test_tensor_square_norm(self):
        """Check that the tensor-square basis change ``Q`` is equivariant and has orthonormal rows.

        For each input representation ``Rs_in`` this verifies:

        * equivariance: transforming the first index of ``Q`` with ``D_out``
          matches transforming its two other indices with ``D_in``
          (relative RMS error below 1e-10);
        * normalization: ``Q`` flattened to ``(n, -1)`` satisfies
          ``M @ M.T == I`` (RMS error below 1e-10).
        """
        for Rs_in in [[(1, 0), (2, 1), (4, 3)]]:
            with o3.torch_default_dtype(torch.float64):
                Rs_out, Q = rs.tensor_square(Rs_in,
                                             o3.selection_rule,
                                             normalization='component',
                                             sorted=True)

                abc = o3.rand_angles()

                D_in = rs.rep(Rs_in, *abc)
                D_out = rs.rep(Rs_out, *abc)

                # The first index of Q is contracted with D_out (it runs over
                # Rs_out); the other two are contracted with D_in.
                Q1 = torch.einsum("ijk,il->ljk", (Q, D_out))
                Q2 = torch.einsum("li,mj,kij->klm", (D_in, D_in, Q))

                d = (Q1 - Q2).pow(2).mean().sqrt() / Q1.pow(2).mean().sqrt()
                self.assertLess(d, 1e-10)

                n = Q.size(0)
                M = Q.reshape(n, -1)
                # Build the identity with Q's dtype/device explicitly instead of
                # relying on the ambient default dtype.
                I = torch.eye(n, dtype=M.dtype, device=M.device)

                # M is generally not square, so only the rows are required to be
                # orthonormal.
                d = ((M @ M.t()) - I).pow(2).mean().sqrt()
                self.assertLess(d, 1e-10)
Exemplo n.º 2
0
        def test(Rs, act):
            """Check that ``S2Activation`` commutes with ``rs.rep(..., 0, 0, 0, -1)``.

            The representation is taken at zero Euler angles with the parity
            argument set, presumably a pure parity transformation. Applying it
            before or after the activation must agree to 1e-10.
            """
            # rs.dim accounts for multiplicities. The former
            # ``sum(2 * l + 1 for _, l, _ in Rs)`` ignored ``mul`` and produced
            # the wrong width whenever a multiplicity differed from 1.
            x = torch.randn(55, rs.dim(Rs))
            ac = S2Activation(Rs, act, 1000)

            y1 = ac(x, dim=-1) @ rs.rep(ac.Rs_out, 0, 0, 0, -1).T
            y2 = ac(x @ rs.rep(Rs, 0, 0, 0, -1).T, dim=-1)
            self.assertLess((y1 - y2).abs().max(), 1e-10)
Exemplo n.º 3
0
    def test_tensor_product_norm(self):
        """Check that the tensor-product basis change ``Q`` is equivariant and orthogonal.

        For each pair of input representations this verifies:

        * equivariance: transforming the first index of ``Q`` with ``D_out``
          matches transforming the other two indices with ``D_in1`` and
          ``D_in2`` (relative RMS error below 1e-10);
        * ``Q`` reshaped to a square ``(n, n)`` matrix ``M`` satisfies both
          ``M @ M.T == I`` and ``M.T @ M == I``.
        """
        cases = [
            ([(1, 0)], [(2, 0)]),
            ([(3, 1), (2, 2)], [(2, 0), (1, 1), (1, 3)]),
        ]
        for Rs_in1, Rs_in2 in cases:
            with o3.torch_default_dtype(torch.float64):
                Rs_out, Q = rs.tensor_product(Rs_in1, Rs_in2,
                                              o3.selection_rule)

                # Use o3.rand_angles() rather than torch.rand(3). The latter only
                # draws Euler angles in [0, 1) radians and leaves most rotations
                # untested.
                abc = o3.rand_angles()

                D_in1 = rs.rep(Rs_in1, *abc)
                D_in2 = rs.rep(Rs_in2, *abc)
                D_out = rs.rep(Rs_out, *abc)

                Q1 = torch.einsum("ijk,il->ljk", (Q, D_out))
                Q2 = torch.einsum("li,mj,kij->klm", (D_in1, D_in2, Q))

                d = (Q1 - Q2).pow(2).mean().sqrt() / Q1.pow(2).mean().sqrt()
                self.assertLess(d, 1e-10)

                n = Q.size(0)
                # The full tensor product is a change of basis, so Q must
                # flatten to a square matrix.
                M = Q.reshape(n, n)
                I = torch.eye(n, dtype=M.dtype)

                d = ((M @ M.t()) - I).pow(2).mean().sqrt()
                self.assertLess(d, 1e-10)

                d = ((M.t() @ M) - I).pow(2).mean().sqrt()
                self.assertLess(d, 1e-10)
Exemplo n.º 4
0
    def rotation_gated_block(self, K):
        """Test rotation equivariance on GatedBlock and dependencies.

        A Convolution followed by a GatedBlock is evaluated once on the
        original inputs (output then rotated by ``D_out``) and once on rotated
        inputs (features by ``D_in``, positions by ``o3.rot``). The two results
        must agree to a relative tolerance of ``10e-5``.
        """
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            kernel = partial(K, RadialModel=ConstantRadialModel)

            gated = GatedBlock(Rs_out,
                               scalar_activation=sigmoid,
                               gate_activation=sigmoid)
            conv = Convolution(kernel(Rs_in, gated.Rs_in))

            angles = torch.randn(3)
            rot = o3.rot(*angles)
            D_in = rs.rep(Rs_in, *angles)
            D_out = rs.rep(Rs_out, *angles)

            features = torch.randn(1, 4, rs.dim(Rs_in))
            positions = torch.randn(1, 4, 3)

            def transform(matrix, tensor):
                # Apply ``matrix`` along the last axis of a [z, a, j] tensor.
                return torch.einsum("ij,zaj->zai", (matrix, tensor))

            rotated_after = transform(D_out, gated(conv(features, positions)))
            rotated_before = gated(
                conv(transform(D_in, features), transform(rot, positions)))
            self.assertLess((rotated_after - rotated_before).norm(),
                            10e-5 * rotated_after.norm())
Exemplo n.º 5
0
def test_equivariance_wtp(Rs_in, Rs_out, n_source, n_target, n_edge):
    """Check that ``WTPConv`` is equivariant under a random rotation.

    The layer output on the original inputs must equal the output on rotated
    inputs (features by ``D_in``, edge vectors by ``R``) after rotating back
    with ``D_out``, to an absolute tolerance of 1e-10.

    float64 is scoped to this test with the ``o3.torch_default_dtype`` context
    manager. ``torch.set_default_dtype`` is process-wide and would otherwise
    change the default for every test that runs afterwards.
    """
    with o3.torch_default_dtype(torch.float64):
        mp = WTPConv(Rs_in, Rs_out, 3, ConstantRadialModel)

        features = rs.randn(n_target, Rs_in)

        edge_index = torch.stack([
            torch.randint(n_source, size=(n_edge, )),
            torch.randint(n_target, size=(n_edge, )),
        ])
        size = (n_target, n_source)

        edge_r = torch.randn(n_edge, 3)
        if n_edge > 1:
            # Include one zero-length edge vector in the batch.
            edge_r[0] = 0

        out1 = mp(features, edge_index, edge_r, size=size)

        angles = o3.rand_angles()
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)
        R = o3.rot(*angles)

        out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out

        assert (out1 - out2).abs().max() < 1e-10
Exemplo n.º 6
0
def test_custom_weighted_tensor_product():
    """Check that ``CustomWeightedTensorProduct`` is rotation equivariant and differentiable.

    float64 is scoped to this test with ``o3.torch_default_dtype`` instead of
    the process-wide ``torch.set_default_dtype``, which would leak into later
    tests.
    """
    with o3.torch_default_dtype(torch.float64):
        Rs_in1 = [(20, 1), (4, 2)]
        Rs_in2 = [(10, 0), (10, 1), (4, 2)]
        Rs_out = [(3, 0), (4, 1)]

        # Each entry presumably reads (index into Rs_in1, index into Rs_in2,
        # index into Rs_out, connection mode). TODO confirm against
        # CustomWeightedTensorProduct.
        instr = [
            (0, 1, 0, 'uvw'),
            (1, 2, 1, 'uuu'),
            (0, 1, 1, 'uvw'),
        ]

        tp = CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr)

        x1 = rs.randn(20, Rs_in1)
        x2 = rs.randn(20, Rs_in2)

        angles = o3.rand_angles()

        z1 = tp(x1, x2) @ rs.rep(Rs_out, *angles).T
        z2 = tp(x1 @ rs.rep(Rs_in1, *angles).T, x2 @ rs.rep(Rs_in2, *angles).T)

        # Smoke-test that a backward pass through the product does not raise.
        z1.sum().backward()

        assert torch.allclose(z1, z2)
Exemplo n.º 7
0
    def parity_rotation_gated_block_parity(self, K):
        """Test parity and rotation equivariance on GatedBlockParity and dependencies.

        Positions are transformed by ``-o3.rot(...)`` (rotation with inversion).
        Features are transformed by the parity-aware ``rs.rep(..., 1)``.
        Transforming before or after Convolution + GatedBlockParity must agree
        to a relative tolerance of ``10e-5``.
        """
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(4) for p in [-1, 1]]

            kernel = partial(K, RadialModel=ConstantRadialModel)

            scalars = ([(mul, 0, +1), (mul, 0, -1)],
                       [(mul, relu), (mul, absolute)])
            rs_nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]
            n_gates = 3 * mul
            gates = ([(n_gates, 0, +1), (n_gates, 0, -1)],
                     [(n_gates, sigmoid), (n_gates, tanh)])

            gated = GatedBlockParity(*scalars, *gates, rs_nonscalars)
            conv = Convolution(kernel(Rs_in, gated.Rs_in))

            angles = torch.randn(3)
            improper_rot = -o3.rot(*angles)
            D_in = rs.rep(Rs_in, *angles, 1)
            D_out = rs.rep(gated.Rs_out, *angles, 1)

            features = torch.randn(1, 4, rs.dim(Rs_in))
            positions = torch.randn(1, 4, 3)

            def transform(matrix, tensor):
                # Apply ``matrix`` along the last axis of a [z, a, j] tensor.
                return torch.einsum("ij,zaj->zai", (matrix, tensor))

            after = transform(D_out, gated(conv(features, positions)))
            before = gated(
                conv(transform(D_in, features), transform(improper_rot, positions)))
            self.assertLess((after - before).norm(), 10e-5 * after.norm())
Exemplo n.º 8
0
def check_rotation_parity(batch: int = 10, n_atoms: int = 25):
    """Report whether Convolution + GatedBlockParity is equivariant under an improper rotation.

    The geometry has odd parity, so it is transformed by ``-R``. The input and
    output features are transformed by their parity-aware representations
    (``parity=1``).

    :param batch: leading size of the random feature/geometry tensors.
    :param n_atoms: number of points per batch entry.
    :return: 0-dim boolean tensor, true when ``|RF - FR| < 10e-5 * |RF|``.
    """
    kernel = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(1, 0, +1)]
    gated = GatedBlockParity(
        Rs_scalars=[(4, 0, +1)],
        act_scalars=[(-1, relu)],
        Rs_gates=[(8, 0, +1)],
        act_gates=[(-1, tanh)],
        Rs_nonscalars=[(4, 1, -1), (4, 2, +1)]
    )
    conv = Convolution(kernel, Rs_in, gated.Rs_in)
    Rs_out = gated.Rs_out

    def network(features, geometry):
        return gated(conv(features, geometry))

    # Random Euler angles; the geometry uses the improper rotation -R.
    angles = torch.randn(3)
    improper_rot = -o3.rot(*angles)
    D_in = rs.rep(Rs_in, *angles, parity=1)
    D_out = rs.rep(Rs_out, *angles, parity=1)

    features = torch.randn(batch, n_atoms, rs.dim(Rs_in))
    geometry = torch.randn(batch, n_atoms, 3)

    plain_out = network(features, geometry)
    rotated_out = torch.einsum("ij,zkj->zki", D_out, plain_out)
    out_of_rotated = network(features @ D_in.t(), geometry @ improper_rot.t())
    return (rotated_out - out_of_rotated).norm() < 10e-5 * rotated_out.norm()
Exemplo n.º 9
0
def check_rotation(batch: int = 10, n_atoms: int = 25):
    """Report whether Convolution + GatedBlock is equivariant under a random rotation.

    Features are transformed by ``rs.rep`` and the geometry by ``o3.rot``.
    Applying the network before or after the rotation is compared.

    :param batch: leading size of the random feature/geometry tensors.
    :param n_atoms: number of points per batch entry.
    :return: 0-dim boolean tensor, true when ``|RF - FR| < 10e-5 * |RF|``.
    """
    kernel = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(1, 0), (1, 1)]
    Rs_out = [(1, 0), (1, 1), (1, 2)]
    gated = GatedBlock(
        Rs_out,
        scalar_activation=sigmoid,
        gate_activation=absolute,
    )
    conv = Convolution(kernel, Rs_in, gated.Rs_in)

    def network(features, geometry):
        return gated(conv(features, geometry))

    angles = torch.randn(3)
    rot = o3.rot(*angles)
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)

    features = torch.randn(batch, n_atoms, rs.dim(Rs_in))
    geometry = torch.randn(batch, n_atoms, 3)

    plain_out = network(features, geometry)
    rotated_out = torch.einsum("ij,zkj->zki", D_out, plain_out)
    out_of_rotated = network(features @ D_in.t(), geometry @ rot.t())
    return (rotated_out - out_of_rotated).norm() < 10e-5 * rotated_out.norm()
Exemplo n.º 10
0
        def test(Rs, act):
            """Check S2Activation with ``random_rot=True`` is equivariant up to a loose tolerance.

            Uses random angles from ``torch.rand(3)`` and parity flag 1. The
            maximum discrepancy must stay below ``3e-4`` times the output's
            maximum magnitude.
            """
            x = rs.randn(2, Rs)
            ac = S2Activation(Rs, act, 200, lmax_out=lmax + 1, random_rot=True)

            alpha, beta, gamma = torch.rand(3)
            parity = 1
            out_then_rotate = ac(x) @ rs.rep(ac.Rs_out, alpha, beta, gamma, parity).T
            rotate_then_out = ac(x @ rs.rep(Rs, alpha, beta, gamma, parity).T)
            gap = (out_then_rotate - rotate_then_out).abs().max()
            self.assertLess(gap, 3e-4 * out_then_rotate.abs().max())
Exemplo n.º 11
0
        def test(Rs, ac):
            """Check that ``ac`` matches the correct rotation better than a mismatched one.

            The reference output is rotated with angles ``(a, b, c)``. Rotating
            the input with the same angles must give a smaller discrepancy than
            rotating it with ``(-c, -b, -a)`` (a different rotation, presumably
            the inverse).
            """
            x = torch.randn(99, rs.dim(Rs))
            alpha, beta = torch.rand(2)
            gamma = 1

            reference = ac(x, dim=-1) @ rs.rep(ac.Rs_out, alpha, beta, gamma).T
            matched = ac(x @ rs.rep(Rs, alpha, beta, gamma).T, dim=-1)
            mismatched = ac(x @ rs.rep(Rs, -gamma, -beta, -alpha).T, dim=-1)
            self.assertLess((reference - matched).norm(),
                            (reference - mismatched).norm())
Exemplo n.º 12
0
def test_equivariance(Rs_in, Rs_out, n_source, n_target, n_edge):
    """Check rotation equivariance of Convolution, with and without groups.

    Three variants are compared against their rotated counterparts (absolute
    tolerance 1e-10):

    * ``mp``: a ``Kernel`` convolution on ``Rs_in`` features;
    * ``mp`` called with ``groups=groups`` on ``Rs_in * groups`` features;
    * ``mp_group``: a ``GroupKernel`` convolution with the same grouping.

    float64 is scoped to this test with ``o3.torch_default_dtype`` instead of
    the process-wide ``torch.set_default_dtype``, which would leak into later
    tests.
    """
    with o3.torch_default_dtype(torch.float64):
        mp = Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))
        groups = 4
        mp_group = Convolution(
            GroupKernel(Rs_in, Rs_out,
                        partial(Kernel, RadialModel=ConstantRadialModel), groups))

        features = rs.randn(n_target, Rs_in)
        features2 = rs.randn(n_target, Rs_in * groups)

        r_source = torch.randn(n_source, 3)
        r_target = torch.randn(n_target, 3)

        edge_index = torch.stack([
            torch.randint(n_source, size=(n_edge, )),
            torch.randint(n_target, size=(n_edge, )),
        ])
        size = (n_target, n_source)

        # Per-edge relative position: target point minus source point.
        # Vectorized indexing also yields an empty (0, 3) tensor when
        # n_edge == 0, so no special case is needed.
        edge_r = r_target[edge_index[1]] - r_source[edge_index[0]]

        out1 = mp(features, edge_index, edge_r, size=size)
        out1_groups = mp(features2, edge_index, edge_r, size=size, groups=groups)
        out1_kernel_groups = mp_group(features2,
                                      edge_index,
                                      edge_r,
                                      size=size,
                                      groups=groups)

        angles = o3.rand_angles()
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)
        D_in_groups = rs.rep(Rs_in * groups, *angles)
        D_out_groups = rs.rep(Rs_out * groups, *angles)
        R = o3.rot(*angles)

        # Rotate the inputs, then rotate the outputs back (right-multiplying
        # by D undoes the ``@ D.T`` rotation applied to the row vectors).
        out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out
        out2_groups = mp(features2 @ D_in_groups.T,
                         edge_index,
                         edge_r @ R.T,
                         size=size,
                         groups=groups) @ D_out_groups
        out2_kernel_groups = mp_group(features2 @ D_in_groups.T,
                                      edge_index,
                                      edge_r @ R.T,
                                      size=size,
                                      groups=groups) @ D_out_groups

        assert (out1 - out2).abs().max() < 1e-10
        assert (out1_groups - out2_groups).abs().max() < 1e-10
        assert (out1_kernel_groups - out2_kernel_groups).abs().max() < 1e-10
Exemplo n.º 13
0
    def test_equiv(self):
        """Check that ``Linear`` commutes with the ``rs.rep`` matrices of a random rotation."""
        # Run in float64 but restore the previous default when the test ends.
        # torch.set_default_dtype is process-wide and would otherwise leak into
        # every test that runs afterwards.
        self.addCleanup(torch.set_default_dtype, torch.get_default_dtype())
        torch.set_default_dtype(torch.float64)

        Rs_in = [(5, 0), (15, 1), (5, 0), (10, 2)]
        Rs_out = [(2, 0), (1, 1), (1, 2), (3, 0)]

        lin = Linear(Rs_in, Rs_out)
        f_in = torch.randn(100, rs.dim(Rs_in))

        angles = torch.randn(3)
        y1 = lin(torch.einsum('ij,zj->zi', rs.rep(Rs_in, *angles), f_in))
        y2 = torch.einsum('ij,zj->zi', rs.rep(Rs_out, *angles), lin(f_in))

        self.assertLess((y1 - y2).abs().max(), 1e-10)
Exemplo n.º 14
0
        def test(act, normalization):
            """Check S2Activation equivariance for the given activation and normalization.

            Uses a random rotation from ``o3.rand_angles()`` with parity flag 1.
            The maximum discrepancy must stay below ``1e-10`` times the
            output's maximum magnitude.
            """
            x = rs.randn(2, Rs, normalization=normalization)
            ac = S2Activation(
                Rs, act, 120,
                normalization=normalization,
                lmax_out=6,
                random_rot=True,
            )

            alpha, beta, gamma = o3.rand_angles()
            out_then_rotate = ac(x) @ rs.rep(ac.Rs_out, alpha, beta, gamma, 1).T
            rotate_then_out = ac(x @ rs.rep(Rs, alpha, beta, gamma, 1).T)
            gap = (out_then_rotate - rotate_then_out).abs().max()
            self.assertLess(gap, 1e-10 * out_then_rotate.abs().max())
Exemplo n.º 15
0
def test_equivariance():
    """Check rotation equivariance of DepthwiseConvolution with plain and grouped kernels.

    Both variants must give the same output (absolute tolerance 1e-10) whether
    the inputs are rotated before the call or the output is rotated after.

    float64 is scoped to this test with ``o3.torch_default_dtype`` instead of
    the process-wide ``torch.set_default_dtype``, which would leak into later
    tests.
    """
    with o3.torch_default_dtype(torch.float64):
        n_edge = 3
        n_source = 4
        n_target = 2

        Rs_in = [(3, 0), (0, 1)]
        Rs_mid1 = [(5, 0), (1, 1)]
        Rs_mid2 = [(5, 0), (1, 1), (1, 2)]
        Rs_out = [(5, 1), (3, 2)]

        groups = 4

        def convolution(Rs_from, Rs_to):
            # Plain Kernel-based convolution factory.
            return Convolution(Kernel(Rs_from, Rs_to, ConstantRadialModel))

        def convolution_groups(Rs_from, Rs_to):
            # GroupKernel-based convolution factory using ``groups`` groups.
            return Convolution(
                GroupKernel(Rs_from, Rs_to,
                            partial(Kernel, RadialModel=ConstantRadialModel),
                            groups))

        mp = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution)
        mp_groups = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups,
                                         convolution_groups)

        features = rs.randn(n_target, Rs_in)

        r_source = torch.randn(n_source, 3)
        r_target = torch.randn(n_target, 3)

        edge_index = torch.stack([
            torch.randint(n_source, size=(n_edge,)),
            torch.randint(n_target, size=(n_edge,)),
        ])
        size = (n_target, n_source)

        # Per-edge relative position: target point minus source point. This
        # also yields an empty (0, 3) tensor if n_edge is ever 0.
        edge_r = r_target[edge_index[1]] - r_source[edge_index[0]]

        out1 = mp(features, edge_index, edge_r, size=size)
        out1_groups = mp_groups(features, edge_index, edge_r, size=size)

        angles = o3.rand_angles()
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)
        R = o3.rot(*angles)

        out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out
        out2_groups = mp_groups(features @ D_in.T, edge_index, edge_r @ R.T,
                                size=size) @ D_out

        assert (out1 - out2).abs().max() < 1e-10
        assert (out1_groups - out2_groups).abs().max() < 1e-10
Exemplo n.º 16
0
    def parity_kernel(self, K):
        """Test parity equivariance on Kernel.

        Applying the output parity representation to ``k(r)`` must match
        ``k(-r)`` followed by the input parity representation, to a relative
        tolerance of ``10e-5``.
        """
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0, 1), (2, 1, 1), (2, 2, -1)]
            Rs_out = [(2, 0, -1), (2, 1, 1), (2, 2, 1)]

            kernel = K(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)

            # Zero Euler angles with the parity flag set: presumably the pure
            # parity action on each representation.
            P_in = rs.rep(Rs_in, 0, 0, 0, 1)
            P_out = rs.rep(Rs_out, 0, 0, 0, 1)

            lhs = P_out @ kernel(r)   # [i, j]
            rhs = kernel(-r) @ P_in   # [i, j]
            self.assertLess((lhs - rhs).norm(), 10e-5 * lhs.norm())
Exemplo n.º 17
0
    def test_tensor_square_equivariance(self):
        """Check that ``TensorSquare`` commutes with a random rotation (relative tolerance 1e-7)."""
        with o3.torch_default_dtype(torch.float64):
            Rs_in = [(3, 0), (2, 1), (5, 2)]

            square = TensorSquare(Rs_in, o3.selection_rule)

            x = rs.randn(Rs_in)

            angles = o3.rand_angles()
            D_in = rs.rep(Rs_in, *angles)
            D_out = rs.rep(square.Rs_out, *angles)

            rotated_first = square(D_in @ x)
            rotated_last = D_out @ square(x)

            gap = (rotated_first - rotated_last).abs().max()
            self.assertLess(gap, 1e-7 * rotated_first.abs().max())
Exemplo n.º 18
0
    def rotation_kernel(self, K):
        """Test rotation equivariance on Kernel.

        ``D_out @ k(r)`` must match ``k(R @ r) @ D_in`` for a random rotation
        ``R``, to a relative tolerance of ``10e-5``.
        """
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            kernel = K(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)

            angles = torch.randn(3)
            D_in = rs.rep(Rs_in, *angles)
            D_out = rs.rep(Rs_out, *angles)

            lhs = D_out @ kernel(r)                       # [i, j]
            rhs = kernel(o3.rot(*angles) @ r) @ D_in      # [i, j]
            self.assertLess((lhs - rhs).norm(), 10e-5 * lhs.norm())
Exemplo n.º 19
0
    def find_peaks(self, res=100):
        """Locate peaks of the signal, as reported by ``_find_peaks_2d``, on a grid of resolution ``res``.

        Peaks are searched twice:

        * on the original signal;
        * on a copy rotated by Euler angles ``(pi/2, pi/2, pi/2)``, whose grid
          points are mapped back to the original frame with ``R.T``.

        All rotated-grid peaks are kept. Original-grid peaks are kept only if
        no rotated-grid peak lies within ``2 * pi / res`` of them. Presumably
        this recovers peaks poorly resolved by one grid orientation; TODO
        confirm.

        :return: ``(x, f)``, the peak positions and the signal values there.
        """
        def gather_peaks(points, values):
            # Collect the grid point and value at every (i, j) peak index.
            ij = _find_peaks_2d(values)
            peak_points = torch.stack([points[i, j] for i, j in ij])
            peak_values = torch.stack([values[i, j] for i, j in ij])
            return peak_points, peak_values

        x1, f1 = self.signal_on_grid(res)

        abc = pi / 2, pi / 2, pi / 2
        R = o3.rot(*abc)
        D = rs.rep(self.Rs, *abc)

        rx2, f2 = SphericalTensor(D @ self.signal).signal_on_grid(res)
        # Express the rotated grid's points in the original frame.
        x2 = torch.einsum('ij,baj->bai', R.T, rx2)

        x1p, f1p = gather_peaks(x1, f1)
        x2p, f2p = gather_peaks(x2, f2)

        # Union: drop original-grid peaks that a rotated-grid peak already covers.
        close = torch.cdist(x1p, x2p) < 2 * pi / res
        unmatched = close.sum(1) == 0
        x = torch.cat([x1p[unmatched], x2p])
        f = torch.cat([f1p[unmatched], f2p])
        return x, f
Exemplo n.º 20
0
    def test_equivariance_s2network(self):
        """Check that ``S2Network`` commutes with a random rotation (relative tolerance 1e-3)."""
        with torch_default_dtype(torch.float64):
            mul = 3
            Rs_in = [(mul, l) for l in range(4)]
            Rs_out = [(mul, l) for l in range(4)]

            net = S2Network(Rs_in, mul, lmax=4, Rs_out=Rs_out)

            angles = o3.rand_angles()
            D_in = rs.rep(Rs_in, *angles)
            D_out = rs.rep(Rs_out, *angles)

            features = torch.randn(10, rs.dim(Rs_in))

            def transform(matrix, batch):
                # Apply ``matrix`` to each row of a [z, j] batch.
                return torch.einsum("ij,zj->zi", matrix, batch)

            after = transform(D_out, net(features))
            before = net(transform(D_in, features))
            self.assertLess((after - before).norm(), 1e-3 * after.norm())
Exemplo n.º 21
0
def test_equivariance_s2parity_network():
    """Check that ``S2ParityNetwork`` is equivariant under a random improper rotation.

    Both representations are built with parity flag 1. The relative
    discrepancy must stay below 1e-3.

    float64 is scoped to this test with ``o3.torch_default_dtype`` instead of
    the process-wide ``torch.set_default_dtype``, which would leak into later
    tests.
    """
    with o3.torch_default_dtype(torch.float64):
        mul = 3
        Rs_in = [(mul, l, -1) for l in range(3 + 1)]
        Rs_out = [(mul, l, 1) for l in range(3 + 1)]

        net = S2ParityNetwork(Rs_in, mul, lmax=3, Rs_out=Rs_out)

        abc = o3.rand_angles()
        D_in = rs.rep(Rs_in, *abc, 1)
        D_out = rs.rep(Rs_out, *abc, 1)

        fea = rs.randn(10, Rs_in)

        x1 = torch.einsum("ij,zj->zi", D_out, net(fea))
        x2 = net(torch.einsum("ij,zj->zi", D_in, fea))
        assert (x1 - x2).norm() < 1e-3 * x1.norm()
Exemplo n.º 22
0
Arquivo: s2.py Projeto: zizai/e3nn
    def forward(self, features):
        '''
        Apply ``self.act`` pointwise on the sphere to spherical-harmonic features.

        The features are mapped to a sphere grid with ``self.to_s2``, passed
        through ``self.act``, and mapped back with ``self.from_s2``. When
        ``self.random_rot`` is set, the input is rotated by a random rotation
        first and the output is rotated back afterwards with the transpose of
        the ``Rs_out`` representation.

        :param features: tensor ``[..., dim]``. In the random-rotation branch
            ``dim`` must equal ``rs.dim(self.Rs_in)``.
        :return: tensor ``[..., dim_out]``. In the random-rotation branch
            ``dim_out`` must equal ``rs.dim(self.Rs_out)``.
        '''
        if self.random_rot:
            abc = o3.rand_angles()
            # o3.rand_angles() is called without a dtype/device here, so the
            # matrix from rs.rep need not match ``features``. Cast it to the
            # input's dtype/device; this is a no-op when they already agree.
            D_in = rs.rep(self.Rs_in, *abc).to(features)
            features = torch.einsum('ij,...j->...i', D_in, features)

        features = self.to_s2(features)  # [..., beta, alpha]
        features = self.act(features)
        features = self.from_s2(features)

        if self.random_rot:
            D_out = rs.rep(self.Rs_out, *abc).to(features)
            # Undo the rotation with the transpose of the output representation.
            features = torch.einsum('ij,...j->...i', D_out.T, features)
        return features
Exemplo n.º 23
0
    def parity_rotation_linear(self, L):
        """Test parity and rotation equivariance on Linear.

        ``L`` is the layer class under test. Features transformed by the
        parity-aware input representation (parity flag 1) must map to features
        transformed by the output representation, to a relative tolerance of
        ``10e-5``.
        """
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(4) for p in [-1, 1]]
            Rs_out = [(mul, l, p) for l in range(4) for p in [-1, 1]]

            layer = L(Rs_in, Rs_out)

            angles = torch.randn(3)
            D_in = rs.rep(layer.Rs_in, *angles, 1)
            D_out = rs.rep(layer.Rs_out, *angles, 1)

            features = torch.randn(rs.dim(Rs_in))

            after = torch.einsum("ij,j->i", D_out, layer(features))
            before = layer(torch.einsum("ij,j->i", D_in, features))
            self.assertLess((after - before).norm(), 10e-5 * after.norm())
Exemplo n.º 24
0
def test_weighted_tensor_product():
    """Check that a grouped ``WeightedTensorProduct`` is rotation equivariant and differentiable.

    float64 is scoped to this test with ``o3.torch_default_dtype`` instead of
    the process-wide ``torch.set_default_dtype``, which would leak into later
    tests.
    """
    with o3.torch_default_dtype(torch.float64):
        Rs_in1 = rs.simplify([1] * 20 + [2] * 4)
        Rs_in2 = rs.simplify([0] * 10 + [1] * 10 + [2] * 5)
        Rs_out = rs.simplify([0] * 3 + [1] * 4)

        tp = WeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, groups=2)

        x1 = rs.randn(20, Rs_in1)
        x2 = rs.randn(20, Rs_in2)

        angles = o3.rand_angles()

        z1 = tp(x1, x2) @ rs.rep(Rs_out, *angles).T
        z2 = tp(x1 @ rs.rep(Rs_in1, *angles).T, x2 @ rs.rep(Rs_in2, *angles).T)

        # Smoke-test that a backward pass through the product does not raise.
        z1.sum().backward()

        assert torch.allclose(z1, z2)
Exemplo n.º 25
0
    def test_equivariance_gatedconvnetwork(self):
        """Check that ``GatedConvNetwork`` commutes with a random rotation.

        Features are transformed by ``rs.rep`` and positions by ``o3.rot``. The
        relative discrepancy must stay below ``10e-5``.
        """
        with torch_default_dtype(torch.float64):
            mul = 3
            Rs_in = [(mul, l) for l in range(4)]
            Rs_out = [(mul, l) for l in range(4)]

            net = GatedConvNetwork(Rs_in, [(10, 0), (1, 1), (1, 2), (1, 3)],
                                   Rs_out)

            angles = torch.randn(3)
            rot = o3.rot(*angles)
            D_in = rs.rep(Rs_in, *angles)
            D_out = rs.rep(Rs_out, *angles)

            features = torch.randn(1, 10, rs.dim(Rs_in))
            positions = torch.randn(1, 10, 3)

            def transform(matrix, tensor):
                # Apply ``matrix`` along the last axis of a [z, a, j] tensor.
                return torch.einsum("ij,zaj->zai", matrix, tensor)

            after = transform(D_out, net(features, positions))
            before = net(transform(D_in, features), transform(rot, positions))
            self.assertLess((after - before).norm(), 10e-5 * after.norm())
Exemplo n.º 26
0
def test_norm_activation(Rs, normalization, dtype):
    """Check that ``NormActivation`` (with swish) commutes with a random rotation.

    The tolerance depends on ``dtype``: 1e-5 for float32, 1e-10 for float64.
    Any other dtype raises ``KeyError`` at the tolerance lookup.
    """
    with o3.torch_default_dtype(dtype):
        activation = NormActivation(Rs, swish, normalization=normalization)

        D = rs.rep(Rs, *o3.rand_angles())
        x = rs.randn(2, Rs, normalization=normalization)

        def rotate(batch):
            # Apply D to each row of a [z, j] batch.
            return torch.einsum('ij,zj->zi', D, batch)

        rotated_after = rotate(activation(x))
        rotated_before = activation(rotate(x))

        error = (rotated_after - rotated_before).abs().max()
        tolerance = {
            torch.float32: 1e-5,
            torch.float64: 1e-10
        }[dtype]
        assert error < tolerance
Exemplo n.º 27
0
def test_s2conv_network():
    """Check that ``S2ConvNetwork`` is equivariant under a random improper rotation.

    Features are transformed by ``rs.rep(..., 1)`` and the geometry by
    ``-o3.rot(...)``. The transformed output is mapped back with ``D.T`` and
    compared to the plain output (relative tolerance 1e-10).

    float64 is scoped to this test with ``o3.torch_default_dtype`` instead of
    the process-wide ``torch.set_default_dtype``, which would leak into later
    tests.
    """
    with o3.torch_default_dtype(torch.float64):
        lmax = 3
        Rs = [(1, l, 1) for l in range(lmax + 1)]
        model = S2ConvNetwork(Rs, 4, Rs, lmax)

        features = rs.randn(1, 4, Rs)
        geometry = torch.randn(1, 4, 3)

        output = model(features, geometry)

        angles = o3.rand_angles()
        D = rs.rep(Rs, *angles, 1)
        # The geometry has odd parity, so the improper rotation acts as -R.
        R = -o3.rot(*angles)
        ein = torch.einsum
        output2 = ein('ij,zaj->zai', D.T,
                      model(ein('ij,zaj->zai', D, features),
                            ein('ij,zaj->zai', R, geometry)))

        assert (output - output2).abs().max() < 1e-10 * output.abs().max()
Exemplo n.º 28
0
def _rep_zx(Rs, dtype, device):
    """Return ``rs.rep(Rs, 0, -pi/2, 0)``, with the zero angles created on ``dtype``/``device``.

    Alpha and gamma are 0-dim zero tensors of the requested dtype/device,
    presumably so the resulting matrix lands there as well (TODO confirm
    against ``rs.rep``). Beta is the Python float ``-pi / 2``.
    """
    zero_angle = torch.zeros((), dtype=dtype, device=device)
    beta = -math.pi / 2
    return rs.rep(Rs, zero_angle, beta, zero_angle)