def __init__(self, num_classes):
    """Assemble the stack of gated parity blocks and the classification head.

    Three ``GatedBlockParity`` blocks are chained, each one consuming the
    ``Rs_out`` of the previous block, followed by a fourth block that only
    outputs even and odd scalars.  The result is stored in
    ``self.firstlayers``; ``self.lastlayers`` applies ``AvgSpacial`` and a
    linear layer producing ``num_classes`` outputs.

    Args:
        num_classes: Output width of the final ``torch.nn.Linear`` layer.
    """
    super().__init__()

    radial_model = partial(
        CosineBasisModel,
        max_radius=3.0,
        number_of_basis=3,
        h=100,
        L=3,
        act=relu,
    )
    kernel = partial(Kernel, RadialModel=radial_model)
    convolution = partial(Convolution, kernel)

    mul = 7
    rs = [(1, 0, +1)]  # representation fed to the first block
    blocks = []

    for index in range(3):
        # Only request the (l, parity) outputs for which haspath(rs, l, p) holds.
        scalars = [(mul, 0, parity) for parity in (+1, -1) if haspath(rs, 0, parity)]
        act_scalars = [
            (count, relu if parity == 1 else tanh)
            for count, _, parity in scalars
        ]
        nonscalars = [(mul, 1, parity) for parity in (+1, -1) if haspath(rs, 1, parity)]

        # One even scalar gate per non-scalar multiplicity.
        gate_count = sum(count for count, _, _ in nonscalars)
        gates = [(gate_count, 0, +1)]
        act_gates = [(-1, sigmoid)]

        print("layer {}: from {} to {}".format(index, formatRs(rs), formatRs(scalars + nonscalars)))

        block = GatedBlockParity(
            convolution, rs, scalars, act_scalars, gates, act_gates, nonscalars
        )
        blocks.append(block)
        rs = block.Rs_out

    # Closing block: scalars of both parities only, no gates and no non-scalars.
    final_scalars = [(mul, 0, +1), (mul, 0, -1)]
    final_acts = [(mul, relu), (mul, tanh)]
    blocks.append(GatedBlockParity(convolution, rs, final_scalars, final_acts, [], [], []))
    self.firstlayers = torch.nn.ModuleList(blocks)

    # The head is not equivariant: it is allowed to mix even and odd scalars.
    self.lastlayers = torch.nn.Sequential(
        AvgSpacial(),
        torch.nn.Linear(mul + mul, num_classes),
    )
def parity_rotation_gated_block_parity(self, K):
    """Check parity and rotation equivariance of a convolution followed by GatedBlockParity.

    The output of ``act(conv(...))`` transformed by ``D_out`` is compared
    with the output computed from inputs transformed by ``D_in`` and a
    negated rotation of the geometry.

    Args:
        K: Kernel factory.  It is partially applied with
            ``RadialModel=ConstantRadialModel`` and then called as
            ``K(Rs_in, Rs_out)``.
    """
    def apply_matrix(matrix, tensor):
        # Contract the last axis of ``tensor`` with ``matrix``.
        return torch.einsum("ij,zaj->zai", matrix, tensor)

    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]

        kernel = partial(K, RadialModel=ConstantRadialModel)

        rs_scalars = [(mul, 0, +1), (mul, 0, -1)]
        act_scalars = [(mul, relu), (mul, absolute)]
        rs_nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]
        n = 3 * mul  # three non-scalar orders per parity, ``mul`` copies each
        rs_gates = [(n, 0, +1), (n, 0, -1)]
        act_gates = [(n, sigmoid), (n, tanh)]

        act = GatedBlockParity(rs_scalars, act_scalars, rs_gates, act_gates, rs_nonscalars)
        conv = Convolution(kernel(Rs_in, act.Rs_in))

        abc = torch.randn(3)
        # Negated rotation matrix applied to the geometry.
        rot_geo = -o3.rot(*abc)
        D_in = rs.rep(Rs_in, *abc, 1)
        D_out = rs.rep(act.Rs_out, *abc, 1)

        fea = torch.randn(1, 4, rs.dim(Rs_in))
        geo = torch.randn(1, 4, 3)

        # Transform after the network ...
        x1 = apply_matrix(D_out, act(conv(fea, geo)))
        # ... versus transform before the network.
        transformed_fea = apply_matrix(D_in, fea)
        transformed_geo = apply_matrix(rot_geo, geo)
        x2 = act(conv(transformed_fea, transformed_geo))

        self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
def check_rotation_parity(batch: int = 10, n_atoms: int = 25):
    """Compare a GatedBlockParity layer's output under a random negated rotation.

    Random features and geometry are pushed through the layer twice: once
    with the transformation applied to the output, once with it applied to
    the inputs.

    Args:
        batch: Size of the leading dimension of the random inputs.
        n_atoms: Size of the second dimension of the random inputs.

    Returns:
        The result of comparing the norm of the difference between the two
        outputs against ``10e-5`` times the norm of the first.
    """
    # Network under test.
    kernel = partial(Kernel, RadialModel=ConstantRadialModel)
    operation = partial(Convolution, kernel)
    Rs_in = [(1, 0, +1)]
    layer = GatedBlockParity(
        Operation=operation,
        Rs_in=Rs_in,
        Rs_scalars=[(4, 0, +1)],
        act_scalars=[(-1, relu)],
        Rs_gates=[(8, 0, +1)],
        act_gates=[(-1, tanh)],
        Rs_nonscalars=[(4, 1, -1), (4, 2, +1)],
    )
    Rs_out = layer.Rs_out

    # Transformation: geometry, input features and output features must all
    # rotate and observe parity.
    angles = torch.randn(3)  # rotation seed of Euler angles
    # Negative because geometry has odd parity, i.e. an improper rotation.
    rot_geo = -rot(*angles)
    D_in = rep(Rs_in, *angles, parity=1)
    D_out = rep(Rs_out, *angles, parity=1)

    # Random data.
    n_channels = sum(mul * (2 * l + 1) for mul, l, _ in Rs_in)
    feat = torch.randn(batch, n_atoms, n_channels)  # transforms with D_in
    geo = torch.randn(batch, n_atoms, 3)  # transforms with rot_geo

    # Equivariance check: transform-after versus transform-before.
    output = layer(feat, geo)
    transformed_after = torch.einsum("ij,zkj->zki", D_out, output)
    transformed_before = layer(feat @ D_in.t(), geo @ rot_geo.t())

    residual = (transformed_after - transformed_before).norm()
    return residual < 10e-5 * transformed_after.norm()
def test5(self):
    """Check parity equivariance of GatedBlockParity and its dependencies."""
    def parity_matrix(Rs):
        # Block-diagonal matrix with +/- identity blocks, one per copy of each entry.
        blocks = [
            p * torch.eye(2 * l + 1)
            for count, l, p in Rs
            for _ in range(count)
        ]
        return direct_sum(*blocks)

    def apply_matrix(matrix, tensor):
        # Contract the last axis of ``tensor`` with ``matrix``.
        return torch.einsum("ij,zaj->zai", matrix, tensor)

    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

        kernel = partial(Kernel, RadialModel=ConstantRadialModel)
        operation = partial(Convolution, kernel)

        rs_scalars = [(mul, 0, +1), (mul, 0, -1)]
        act_scalars = [(mul, relu), (mul, absolute)]
        rs_nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]
        n = 3 * mul  # three non-scalar orders per parity, ``mul`` copies each
        rs_gates = [(n, 0, +1), (n, 0, -1)]
        act_gates = [(n, sigmoid), (n, tanh)]

        f = GatedBlockParity(
            operation, Rs_in, rs_scalars, act_scalars, rs_gates, act_gates, rs_nonscalars
        )

        D_in = parity_matrix(Rs_in)
        D_out = parity_matrix(f.Rs_out)

        n_channels = sum(count * (2 * l + 1) for count, l, _ in Rs_in)
        fea = torch.randn(1, 4, n_channels)
        geo = torch.randn(1, 4, 3)

        # Apply parity to the output, or to the inputs (geometry is negated).
        x1 = apply_matrix(D_out, f(fea, geo))
        x2 = f(apply_matrix(D_in, fea), -geo)

        self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
def test6(self):
    """Check parity and rotation equivariance of GatedBlockParity and its dependencies."""
    def apply_matrix(matrix, tensor):
        # Contract the last axis of ``tensor`` with ``matrix``.
        return torch.einsum("ij,zaj->zai", matrix, tensor)

    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

        K = partial(Kernel, RadialModel=ConstantRadialModel)

        rs_scalars = [(mul, 0, +1), (mul, 0, -1)]
        act_scalars = [(mul, relu), (mul, absolute)]
        rs_nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]
        n = 3 * mul  # three non-scalar orders per parity, ``mul`` copies each
        rs_gates = [(n, 0, +1), (n, 0, -1)]
        act_gates = [(n, sigmoid), (n, tanh)]

        act = GatedBlockParity(rs_scalars, act_scalars, rs_gates, act_gates, rs_nonscalars)
        conv = Convolution(K, Rs_in, act.Rs_in)

        abc = torch.randn(3)
        # Negated rotation matrix applied to the geometry.
        rot_geo = -o3.rot(*abc)

        def representation(Rs):
            # Block-diagonal matrix of signed irreducible representations,
            # one block per copy of each entry in ``Rs``.
            blocks = [
                p * o3.irr_repr(l, *abc)
                for count, l, p in Rs
                for _ in range(count)
            ]
            return o3.direct_sum(*blocks)

        D_in = representation(Rs_in)
        D_out = representation(act.Rs_out)

        n_channels = sum(count * (2 * l + 1) for count, l, _ in Rs_in)
        fea = torch.randn(1, 4, n_channels)
        geo = torch.randn(1, 4, 3)

        # Transform after the network ...
        x1 = apply_matrix(D_out, act(conv(fea, geo)))
        # ... versus transform before the network.
        transformed_fea = apply_matrix(D_in, fea)
        transformed_geo = apply_matrix(rot_geo, geo)
        x2 = act(conv(transformed_fea, transformed_geo))

        self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())