Example #1
0
    def __init__(self, num_classes):
        """Assemble the equivariant feature extractor and the classification head.

        Three gated blocks are stacked. At each one, only the scalar and l=1
        channels that ``haspath`` reports as reachable from the current input
        representation are kept. A final block maps everything to even and
        odd scalars, which the head turns into ``num_classes`` outputs.

        :param num_classes: output size of the final ``torch.nn.Linear``
        """
        super().__init__()

        radial_model = partial(CosineBasisModel, max_radius=3.0, number_of_basis=3, h=100, L=3, act=relu)
        kernel = partial(Kernel, RadialModel=radial_model)
        operation = partial(Convolution, kernel)

        mul = 7
        Rs = [(1, 0, +1)]  # the input is a single even scalar channel
        blocks = []

        for layer in range(3):
            Rs_scalars = [
                (m, l, p)
                for m, l, p in ((mul, 0, +1), (mul, 0, -1))
                if haspath(Rs, l, p)
            ]
            # even scalars (p == 1) use relu; odd scalars use tanh
            act_scalars = [(m, relu if p == 1 else tanh) for m, _, p in Rs_scalars]

            Rs_nonscalars = [
                (m, l, p)
                for m, l, p in ((mul, 1, +1), (mul, 1, -1))
                if haspath(Rs, l, p)
            ]
            # a single even-scalar gate group whose multiplicity equals the
            # total multiplicity of the kept non-scalar channels
            Rs_gates = [(sum(m for m, _, _ in Rs_nonscalars), 0, +1)]
            act_gates = [(-1, sigmoid)]

            print("layer {}: from {} to {}".format(layer, formatRs(Rs), formatRs(Rs_scalars + Rs_nonscalars)))

            block = GatedBlockParity(operation, Rs, Rs_scalars, act_scalars, Rs_gates, act_gates, Rs_nonscalars)
            blocks.append(block)
            Rs = block.Rs_out

        # final block: only even and odd scalars leave the equivariant part
        blocks.append(GatedBlockParity(operation, Rs, [(mul, 0, +1), (mul, 0, -1)], [(mul, relu), (mul, tanh)], [], [], []))

        self.firstlayers = torch.nn.ModuleList(blocks)

        # the last layer is not equivariant, it is allowed to mix even and odd scalars
        # (AvgSpacial presumably pools over the point dimension -- verify)
        self.lastlayers = torch.nn.Sequential(AvgSpacial(), torch.nn.Linear(2 * mul, num_classes))
Example #2
0
    def parity_rotation_gated_block_parity(self, K):
        """Check that Convolution followed by GatedBlockParity commutes with an improper rotation.

        The output computed from transformed inputs must match the transformed
        output, up to a relative tolerance of 10e-5. Inputs are features
        transformed by D_in and geometry transformed by a negated rotation
        matrix. This method has no ``test`` prefix, so it is presumably
        invoked by test methods that supply a concrete kernel -- verify.

        :param K: kernel factory; it is called as
            ``K(Rs_in, Rs_out, RadialModel=ConstantRadialModel)``
        """
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(4) for p in [-1, 1]]

            kernel = partial(K, RadialModel=ConstantRadialModel)

            Rs_scalars = [(mul, 0, +1), (mul, 0, -1)]
            act_scalars = [(mul, relu), (mul, absolute)]
            Rs_nonscalars = [(mul, 1, +1), (mul, 1, -1), (mul, 2, +1),
                             (mul, 2, -1), (mul, 3, +1), (mul, 3, -1)]
            n_gates = 3 * mul
            Rs_gates = [(n_gates, 0, +1), (n_gates, 0, -1)]
            act_gates = [(n_gates, sigmoid), (n_gates, tanh)]

            act = GatedBlockParity(Rs_scalars, act_scalars, Rs_gates, act_gates, Rs_nonscalars)
            conv = Convolution(kernel(Rs_in, act.Rs_in))

            angles = torch.randn(3)
            # the geometry gets the rotation combined with inversion (hence the minus sign)
            rot_geo = -o3.rot(*angles)
            D_in = rs.rep(Rs_in, *angles, 1)
            D_out = rs.rep(act.Rs_out, *angles, 1)

            features = torch.randn(1, 4, rs.dim(Rs_in))
            geometry = torch.randn(1, 4, 3)

            # transform the output of the layer ...
            transformed_output = torch.einsum("ij,zaj->zai", D_out, act(conv(features, geometry)))
            # ... versus running the layer on transformed inputs
            output_of_transformed = act(
                conv(torch.einsum("ij,zaj->zai", D_in, features),
                     torch.einsum("ij,zaj->zai", rot_geo, geometry)))
            self.assertLess((transformed_output - output_of_transformed).norm(),
                            10e-5 * transformed_output.norm())
Example #3
0
def check_rotation_parity(batch: int = 10, n_atoms: int = 25):
    """Report whether a GatedBlockParity convolution is equivariant under an improper rotation.

    The network is built with a ConstantRadialModel kernel. It is evaluated
    on random features and geometry, and also on the same inputs transformed
    by a random rotation with parity (features by ``D_in``, geometry by the
    negated rotation matrix). Rotating the first output by ``D_out`` must
    reproduce the second.

    :param batch: number of independent samples in the random input
    :param n_atoms: number of points per sample
    :return: 0-dim boolean tensor; true when the relative error is below 10e-5
    """
    # Network under test.
    operation = partial(Convolution, partial(Kernel, RadialModel=ConstantRadialModel))
    Rs_in = [(1, 0, +1)]
    f = GatedBlockParity(Operation=operation,
                         Rs_in=Rs_in,
                         Rs_scalars=[(4, 0, +1)],
                         act_scalars=[(-1, relu)],
                         Rs_gates=[(8, 0, +1)],
                         act_gates=[(-1, tanh)],
                         Rs_nonscalars=[(4, 1, -1), (4, 2, +1)])
    Rs_out = f.Rs_out

    # Random improper rotation from Euler angles. Geometry is odd under
    # parity, so it gets the negated rotation matrix.
    abc = torch.randn(3)
    rot_geo = -rot(*abc)
    D_in = rep(Rs_in, *abc, parity=1)
    D_out = rep(Rs_out, *abc, parity=1)

    dim_in = sum(mul * (2 * l + 1) for mul, l, _ in Rs_in)
    feat = torch.randn(batch, n_atoms, dim_in)
    geo = torch.randn(batch, n_atoms, 3)

    # Equivariance: transform-then-apply must equal apply-then-transform.
    out = f(feat, geo)
    rotated_out = torch.einsum("ij,zkj->zki", D_out, out)
    out_of_rotated = f(feat @ D_in.t(), geo @ rot_geo.t())
    return (rotated_out - out_of_rotated).norm() < 10e-5 * rotated_out.norm()
Example #4
0
    def test5(self):
        """Check that GatedBlockParity commutes with the parity (inversion) transformation."""
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

            operation = partial(Convolution, partial(Kernel, RadialModel=ConstantRadialModel))

            Rs_scalars = [(mul, 0, +1), (mul, 0, -1)]
            act_scalars = [(mul, relu), (mul, absolute)]
            Rs_nonscalars = [(mul, 1, +1), (mul, 1, -1), (mul, 2, +1),
                             (mul, 2, -1), (mul, 3, +1), (mul, 3, -1)]
            n_gates = 3 * mul
            Rs_gates = [(n_gates, 0, +1), (n_gates, 0, -1)]
            act_gates = [(n_gates, sigmoid), (n_gates, tanh)]

            f = GatedBlockParity(operation, Rs_in, Rs_scalars, act_scalars, Rs_gates, act_gates, Rs_nonscalars)

            def parity_matrix(Rs):
                # block-diagonal matrix: every copy of an irrep is scaled by its parity
                return direct_sum(*[
                    p * torch.eye(2 * l + 1)
                    for m, l, p in Rs
                    for _ in range(m)
                ])

            D_in = parity_matrix(Rs_in)
            D_out = parity_matrix(f.Rs_out)

            dim_in = sum(m * (2 * l + 1) for m, l, _ in Rs_in)
            fea = torch.randn(1, 4, dim_in)
            geo = torch.randn(1, 4, 3)

            # inversion flips the geometry sign and scales features by their parity
            x1 = torch.einsum("ij,zaj->zai", D_out, f(fea, geo))
            x2 = f(torch.einsum("ij,zaj->zai", D_in, fea), -geo)
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())

    def test6(self):
        """Check that Convolution followed by GatedBlockParity commutes with a rotation combined with parity."""
        # NOTE(review): test5 in this class builds GatedBlockParity as
        # (Operation, Rs_in, ...), whereas here it is called without those two
        # arguments and paired with a separate Convolution(K, Rs_in, Rs_out).
        # Confirm which signatures the installed library version exposes.
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

            kernel = partial(Kernel, RadialModel=ConstantRadialModel)

            Rs_scalars = [(mul, 0, +1), (mul, 0, -1)]
            act_scalars = [(mul, relu), (mul, absolute)]
            Rs_nonscalars = [(mul, 1, +1), (mul, 1, -1), (mul, 2, +1),
                             (mul, 2, -1), (mul, 3, +1), (mul, 3, -1)]
            n_gates = 3 * mul
            Rs_gates = [(n_gates, 0, +1), (n_gates, 0, -1)]
            act_gates = [(n_gates, sigmoid), (n_gates, tanh)]

            act = GatedBlockParity(Rs_scalars, act_scalars, Rs_gates, act_gates, Rs_nonscalars)
            conv = Convolution(kernel, Rs_in, act.Rs_in)

            abc = torch.randn(3)
            # geometry is odd under parity, so it gets the negated rotation matrix
            rot_geo = -o3.rot(*abc)

            def representation(Rs):
                # block-diagonal Wigner matrices, each scaled by the irrep's parity
                return o3.direct_sum(*[
                    p * o3.irr_repr(l, *abc)
                    for m, l, p in Rs
                    for _ in range(m)
                ])

            D_in = representation(Rs_in)
            D_out = representation(act.Rs_out)

            dim_in = sum(m * (2 * l + 1) for m, l, _ in Rs_in)
            fea = torch.randn(1, 4, dim_in)
            geo = torch.randn(1, 4, 3)

            x1 = torch.einsum("ij,zaj->zai", D_out, act(conv(fea, geo)))
            x2 = act(
                conv(torch.einsum("ij,zaj->zai", D_in, fea),
                     torch.einsum("ij,zaj->zai", rot_geo, geo)))
            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())