def __init__(self, n, lmax):
    """Store the sizes and precompute data for two fixed sampling frames.

    Args:
        n: stored as ``self.n``.
        lmax: stored as ``self.lmax``.

    ``self.precompute(R)`` is called twice, and each call returns a pair.
    The first frame uses the rotation with Euler angles (pi/2, pi/2, pi/2)
    and is stored in ``xyz1``/``proj1``. The second uses the identity
    rotation (0, 0, 0) and is stored in ``xyz2``/``proj2``.
    NOTE(review): ``precompute`` is defined elsewhere and may read
    ``self.n``/``self.lmax``, so those are assigned first.
    """
    super().__init__()
    self.n = n
    self.lmax = lmax

    quarter_turn = math.pi / 2
    rotated_frame = o3.rot(quarter_turn, quarter_turn, quarter_turn)
    self.xyz1, self.proj1 = self.precompute(rotated_frame)

    identity_frame = o3.rot(0, 0, 0)
    self.xyz2, self.proj2 = self.precompute(identity_frame)
def find_peaks(self, res=100):
    """Locate local maxima of the signal by sampling it on two rotated grids.

    The signal is sampled with ``signal_on_grid(res)`` in the original frame.
    It is also sampled after being rotated by the Euler angles
    (pi/2, pi/2, pi/2). Grid points of the rotated sampling are mapped back
    to the original frame with ``R.T``. Peaks from the two samplings are then
    merged:
      * every peak from the rotated sampling is kept;
      * a peak from the original sampling is kept only if no rotated-sampling
        peak lies within ``2 * pi / res`` (Euclidean distance) of it.

    Args:
        res: grid resolution forwarded to ``signal_on_grid``.

    Returns:
        (x, f): ``x`` holds the peak positions, one 3-vector per row.
        ``f`` holds the signal values at those grid points, in the same
        order. If neither sampling yields a peak, both are empty. Previously
        this case raised, because ``torch.stack`` was called on an empty list.
    """
    def gather(points, values):
        # Turn the (i, j) pairs from _find_peaks_2d into an (N, 2) index
        # tensor. Advanced indexing then gives the same result as stacking
        # points[i, j] element by element, and it also works when N == 0.
        idx = torch.tensor(_find_peaks_2d(values), dtype=torch.long).reshape(-1, 2)
        rows, cols = idx[:, 0], idx[:, 1]
        return points[rows, cols], values[rows, cols]

    x1, f1 = self.signal_on_grid(res)

    abc = pi / 2, pi / 2, pi / 2
    R = o3.rot(*abc)
    D = rs.rep(self.Rs, *abc)
    rtensor = SphericalTensor(D @ self.signal)
    rx2, f2 = rtensor.signal_on_grid(res)
    # Express the rotated sampling's grid points in the original frame.
    x2 = torch.einsum('ij,baj->bai', R.T, rx2)

    x1p, f1p = gather(x1, f1)
    x2p, f2p = gather(x2, f2)

    # Union of the results: drop original-grid peaks that duplicate a
    # rotated-grid peak (closer than one grid spacing).
    if len(x1p) > 0 and len(x2p) > 0:
        duplicate = (torch.cdist(x1p, x2p) < 2 * pi / res).any(dim=1)
    else:
        duplicate = torch.zeros(len(x1p), dtype=torch.bool, device=x1p.device)
    keep = ~duplicate

    x = torch.cat([x1p[keep], x2p])
    f = torch.cat([f1p[keep], f2p])
    return x, f
def test1(self):
    """Convolution followed by GatedBlock must commute with a random rotation.

    The rotated-input path and the unrotated path both apply the gate with
    ``dim=2``, the feature axis of the (1, 5, features) tensors. Previously
    only one of the two calls passed ``dim=2``, so the two sides were not
    guaranteed to be evaluated the same way.
    """
    with torch_default_dtype(torch.float64):
        Rs_in = [(3, 0), (3, 1), (2, 0), (1, 2)]
        Rs_out = [(3, 0), (3, 1), (1, 2), (3, 0)]
        f = GatedBlock(Rs_out, rescaled_act.Softplus(beta=5), rescaled_act.sigmoid)
        c = Convolution(Kernel(Rs_in, f.Rs_in, ConstantRadialModel))

        abc = torch.randn(3)
        D_in = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
        D_out = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)])

        x = torch.randn(1, 5, sum(mul * (2 * l + 1) for mul, l in Rs_in))
        geo = torch.randn(1, 5, 3)

        # Rotate the inputs: features with D_in, positions with the 3x3 rotation.
        rx = torch.einsum("ij,zaj->zai", (D_in, x))
        rgeo = geo @ o3.rot(*abc).t()

        # Rotate the output of the unrotated computation with D_out.
        y = f(c(x, geo), dim=2)
        ry = torch.einsum("ij,zaj->zai", (D_out, y))

        y_from_rotated = f(c(rx, rgeo), dim=2)
        self.assertLess((y_from_rotated - ry).norm(), 1e-10 * ry.norm())
def test_equivariance_wtp(Rs_in, Rs_out, n_source, n_target, n_edge):
    """WTPConv must satisfy f(x D_in^T, r R^T) D_out == f(x, r) on a random graph.

    Sets the global default dtype to float64 (not restored afterwards).
    """
    torch.set_default_dtype(torch.float64)
    conv = WTPConv(Rs_in, Rs_out, 3, ConstantRadialModel)

    features = rs.randn(n_target, Rs_in)
    src = torch.randint(n_source, size=(n_edge, ))
    dst = torch.randint(n_target, size=(n_edge, ))
    edge_index = torch.stack([src, dst])
    size = (n_target, n_source)

    edge_r = torch.randn(n_edge, 3)
    if n_edge > 1:
        # Give the first edge zero length (presumably to exercise r == 0).
        edge_r[0] = 0

    reference = conv(features, edge_index, edge_r, size=size)

    angles = o3.rand_angles()
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)
    R = o3.rot(*angles)

    # Rows are vectors: `@ M.T` applies M, and `@ D_out` applies D_out^T.
    rotated_out = conv(features @ D_in.T, edge_index, edge_r @ R.T, size=size)
    mapped_back = rotated_out @ D_out

    assert (reference - mapped_back).abs().max() < 1e-10
def check_rotation(batch: int = 10, n_atoms: int = 25):
    """Report whether Convolution + GatedBlock commutes with a random rotation.

    Args:
        batch: number of independent samples.
        n_atoms: number of points per sample.

    Returns:
        The result of comparing the relative error
        ||D_out F(x, g) - F(x D_in^T, g R^T)|| / ||D_out F(x, g)|| against 1e-4.
    """
    Rs_in = [(1, 0), (1, 1)]
    Rs_out = [(1, 0), (1, 1), (1, 2)]

    kernel = partial(Kernel, RadialModel=ConstantRadialModel)
    gate = GatedBlock(Rs_out, scalar_activation=sigmoid, gate_activation=absolute)
    conv = Convolution(kernel, Rs_in, gate.Rs_in)

    # Random Euler angles define the rotation: R (3x3) acts on the positions,
    # and the Wigner D matrices act on the input/output features.
    abc = torch.randn(3)
    R = o3.rot(*abc)
    D_in = rs.rep(Rs_in, *abc)
    D_out = rs.rep(Rs_out, *abc)

    features = torch.randn(batch, n_atoms, rs.dim(Rs_in))
    positions = torch.randn(batch, n_atoms, 3)

    rotate_after = torch.einsum("ij,zkj->zki", D_out, gate(conv(features, positions)))
    rotate_before = gate(conv(features @ D_in.t(), positions @ R.t()))

    return (rotate_after - rotate_before).norm() < 1e-4 * rotate_after.norm()
def check_rotation_parity(batch: int = 10, n_atoms: int = 25):
    """Report whether Convolution + GatedBlockParity commutes with an improper rotation.

    Positions are transformed by -R (rotation composed with inversion). The
    features are transformed by representations built with ``parity=1``.

    Args:
        batch: number of independent samples.
        n_atoms: number of points per sample.

    Returns:
        The result of comparing the relative equivariance error against 1e-4.
    """
    kernel = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(1, 0, +1)]
    gate = GatedBlockParity(
        Rs_scalars=[(4, 0, +1)],
        act_scalars=[(-1, relu)],
        Rs_gates=[(8, 0, +1)],
        act_gates=[(-1, tanh)],
        Rs_nonscalars=[(4, 1, -1), (4, 2, +1)],
    )
    conv = Convolution(kernel, Rs_in, gate.Rs_in)
    Rs_out = gate.Rs_out

    abc = torch.randn(3)
    improper_R = -o3.rot(*abc)  # negated: positions are odd under inversion
    D_in = rs.rep(Rs_in, *abc, parity=1)
    D_out = rs.rep(Rs_out, *abc, parity=1)

    features = torch.randn(batch, n_atoms, rs.dim(Rs_in))
    positions = torch.randn(batch, n_atoms, 3)

    rotate_after = torch.einsum("ij,zkj->zki", D_out, gate(conv(features, positions)))
    rotate_before = gate(conv(features @ D_in.t(), positions @ improper_R.t()))

    return (rotate_after - rotate_before).norm() < 1e-4 * rotate_after.norm()
def test3(self):
    """Test rotation equivariance on GatedBlock and dependencies."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]

        kernel = partial(Kernel, RadialModel=ConstantRadialModel)
        gate = GatedBlock(Rs_out, scalar_activation=sigmoid, gate_activation=sigmoid)
        conv = Convolution(kernel, Rs_in, gate.Rs_in)

        abc = torch.randn(3)
        R = o3.rot(*abc)

        def block_rep(Rs):
            # Block-diagonal matrix: one irr_repr(l) block per copy, in Rs order.
            return o3.direct_sum(*[o3.irr_repr(l, *abc) for m, l in Rs for _ in range(m)])

        D_in = block_rep(Rs_in)
        D_out = block_rep(Rs_out)

        n_features = sum(m * (2 * l + 1) for m, l in Rs_in)
        fea = torch.randn(1, 4, n_features)
        geo = torch.randn(1, 4, 3)

        rotate_after = torch.einsum("ij,zaj->zai", (D_out, gate(conv(fea, geo))))
        fea_rot = torch.einsum("ij,zaj->zai", (D_in, fea))
        geo_rot = torch.einsum("ij,zaj->zai", R, geo)
        rotate_before = gate(conv(fea_rot, geo_rot))

        self.assertLess((rotate_after - rotate_before).norm(), 1e-4 * rotate_after.norm())
def rotation_gated_block(self, K):
    """Test rotation equivariance on GatedBlock and dependencies.

    Args:
        K: kernel factory; it is bound to ``RadialModel=ConstantRadialModel``
            and then called as ``K(Rs_in, Rs_out)``.
    """
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]

        make_kernel = partial(K, RadialModel=ConstantRadialModel)
        gate = GatedBlock(Rs_out, scalar_activation=sigmoid, gate_activation=sigmoid)
        conv = Convolution(make_kernel(Rs_in, gate.Rs_in))

        abc = torch.randn(3)
        R = o3.rot(*abc)
        D_in, D_out = rs.rep(Rs_in, *abc), rs.rep(Rs_out, *abc)

        fea = torch.randn(1, 4, rs.dim(Rs_in))
        geo = torch.randn(1, 4, 3)

        rotate_after = torch.einsum("ij,zaj->zai", (D_out, gate(conv(fea, geo))))
        fea_rot = torch.einsum("ij,zaj->zai", (D_in, fea))
        geo_rot = torch.einsum("ij,zaj->zai", R, geo)
        rotate_before = gate(conv(fea_rot, geo_rot))

        self.assertLess((rotate_after - rotate_before).norm(), 1e-4 * rotate_after.norm())
def test_rot_to_abc(self):
    """Rebuilding a random rotation from o3.rot_to_abc's angles must recover it."""
    with o3.torch_default_dtype(torch.float64):
        original = o3.rand_rot()
        rebuilt = o3.rot(*o3.rot_to_abc(original))
        rel_err = (original - rebuilt).norm() / original.norm()
        self.assertTrue(rel_err < 1e-10, rel_err)
def parity_rotation_gated_block_parity(self, K):
    """Test parity and rotation equivariance on GatedBlockParity and dependencies.

    Args:
        K: kernel factory; it is bound to ``RadialModel=ConstantRadialModel``
            and then called as ``K(Rs_in, Rs_out)``.
    """
    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(4) for p in (-1, 1)]
        make_kernel = partial(K, RadialModel=ConstantRadialModel)

        Rs_scalars = [(mul, 0, +1), (mul, 0, -1)]
        act_scalars = [(mul, relu), (mul, absolute)]
        Rs_nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]

        n_gates = 3 * mul
        Rs_gates = [(n_gates, 0, +1), (n_gates, 0, -1)]
        act_gates = [(n_gates, sigmoid), (n_gates, tanh)]

        gate = GatedBlockParity(Rs_scalars, act_scalars, Rs_gates, act_gates, Rs_nonscalars)
        conv = Convolution(make_kernel(Rs_in, gate.Rs_in))

        abc = torch.randn(3)
        improper_R = -o3.rot(*abc)  # rotation composed with inversion
        D_in = rs.rep(Rs_in, *abc, 1)
        D_out = rs.rep(gate.Rs_out, *abc, 1)

        fea = torch.randn(1, 4, rs.dim(Rs_in))
        geo = torch.randn(1, 4, 3)

        rotate_after = torch.einsum("ij,zaj->zai", (D_out, gate(conv(fea, geo))))
        fea_rot = torch.einsum("ij,zaj->zai", (D_in, fea))
        geo_rot = torch.einsum("ij,zaj->zai", improper_R, geo)
        rotate_before = gate(conv(fea_rot, geo_rot))

        self.assertLess((rotate_after - rotate_before).norm(), 1e-4 * rotate_after.norm())
def random_rotate_translate(positions, rotation=True, translation=1):
    """Apply one shared random rigid motion to every tensor in ``positions``.

    The translation vector is drawn uniformly from the unit ball, by
    rejection sampling from the cube [-1, 1]^3, and scaled by
    ``translation``. If ``rotation`` is true, a rotation matrix is built
    from three Euler angles, each drawn uniformly from [0, 2*pi), and cast
    to float32. This is not a uniform (Haar) distribution over rotations.
    If ``rotation`` is false, no rotation is applied. Previously the flag
    was ignored and a rotation was always applied.

    Args:
        positions: iterable of tensors. Each ``pos`` becomes
            ``rot @ pos + translation * trans``, or ``pos + translation * trans``
            when ``rotation`` is false. Presumably each ``pos`` is a length-3
            float32 vector; TODO confirm against callers.
        rotation: whether to apply the random rotation.
        translation: scale of the random translation (0 disables it).

    Returns:
        List of transformed tensors, in the same order as ``positions``.
    """
    import math

    while True:
        trans = torch.rand(3) * 2 - 1
        if trans.norm() <= 1:
            break
    shift = translation * trans

    if not rotation:
        return [pos + shift for pos in positions]

    rot = o3.rot(*torch.rand(3) * (2 * math.pi)).type(torch.float32)
    return [rot @ pos + shift for pos in positions]
def test_equivariance(Rs_in, Rs_out, n_source, n_target, n_edge):
    """Check rotation equivariance of Convolution, with and without groups.

    Three variants run on the same random graph:
      * a plain Convolution on ``features`` (layout Rs_in);
      * the same Convolution called with ``groups=groups`` on ``features2``
        (layout Rs_in repeated ``groups`` times);
      * a Convolution over a GroupKernel, called with ``groups=groups``.
    For each variant, the output computed from rotated inputs and mapped
    back with the output representation must match the unrotated output to
    within 1e-10.

    Sets the global default dtype to float64 (not restored afterwards).
    """
    torch.set_default_dtype(torch.float64)
    groups = 4
    mp = Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))
    mp_group = Convolution(
        GroupKernel(Rs_in, Rs_out, partial(Kernel, RadialModel=ConstantRadialModel), groups))

    features = rs.randn(n_target, Rs_in)
    features2 = rs.randn(n_target, Rs_in * groups)
    r_source = torch.randn(n_source, 3)
    r_target = torch.randn(n_target, 3)
    edge_index = torch.stack([
        torch.randint(n_source, size=(n_edge, )),
        torch.randint(n_target, size=(n_edge, )),
    ])
    size = (n_target, n_source)
    # Edge vector = target position - source position, computed for all
    # edges at once. With n_edge == 0 this yields an empty (0, 3) tensor, so
    # no special case is needed.
    edge_r = r_target[edge_index[1]] - r_source[edge_index[0]]

    out1 = mp(features, edge_index, edge_r, size=size)
    out1_groups = mp(features2, edge_index, edge_r, size=size, groups=groups)
    out1_kernel_groups = mp_group(features2, edge_index, edge_r, size=size, groups=groups)

    angles = o3.rand_angles()
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)
    D_in_groups = rs.rep(Rs_in * groups, *angles)
    D_out_groups = rs.rep(Rs_out * groups, *angles)
    R = o3.rot(*angles)
    edge_r_rot = edge_r @ R.T

    # Rows are vectors: `@ D.T` applies D to the inputs, and `@ D_out` applies
    # D_out^T to the outputs.
    out2 = mp(features @ D_in.T, edge_index, edge_r_rot, size=size) @ D_out
    out2_groups = mp(features2 @ D_in_groups.T, edge_index, edge_r_rot,
                     size=size, groups=groups) @ D_out_groups
    out2_kernel_groups = mp_group(features2 @ D_in_groups.T, edge_index, edge_r_rot,
                                  size=size, groups=groups) @ D_out_groups

    assert (out1 - out2).abs().max() < 1e-10
    assert (out1_groups - out2_groups).abs().max() < 1e-10
    assert (out1_kernel_groups - out2_kernel_groups).abs().max() < 1e-10
def test_xyz_vector_basis_to_spherical_basis(self):
    """A^T D^1(a, b, c) A must equal the 3x3 rotation matrix rot(a, b, c)."""
    with o3.torch_default_dtype(torch.float64):
        change_of_basis = o3.xyz_vector_basis_to_spherical_basis()
        alpha, beta, gamma = torch.rand(3)
        via_irrep = change_of_basis.t() @ o3.irr_repr(1, alpha, beta, gamma) @ change_of_basis
        direct = o3.rot(alpha, beta, gamma)
        self.assertLess((via_irrep - direct).abs().max(), 1e-10)
def test_xyz_to_irreducible_basis(self):
    """A^T D^1(a, b, c) A must equal the 3x3 rotation matrix rot(a, b, c).

    Uses a unittest assertion instead of a bare ``assert``, because
    ``python -O`` strips bare asserts. The tolerance is 1e-10, suited to
    float64; ``torch.allclose``'s default rtol of 1e-5 is far looser.
    """
    with o3.torch_default_dtype(torch.float64):
        A = o3.xyz_to_irreducible_basis()
        a, b, c = torch.rand(3)
        r1 = A.t() @ o3.irr_repr(1, a, b, c) @ A
        r2 = o3.rot(a, b, c)
        self.assertLess((r1 - r2).abs().max(), 1e-10)
def test_equivariance():
    """Check rotation equivariance of DepthwiseConvolution.

    The check is run once with plain Convolution kernels and once with
    GroupKernel-based convolutions. Each output computed from rotated inputs
    and mapped back with D_out must match the unrotated output to within
    1e-10.

    Sets the global default dtype to float64 (not restored afterwards).
    """
    torch.set_default_dtype(torch.float64)
    n_edge = 3
    n_source = 4
    n_target = 2
    groups = 4

    Rs_in = [(3, 0), (0, 1)]
    Rs_mid1 = [(5, 0), (1, 1)]
    Rs_mid2 = [(5, 0), (1, 1), (1, 2)]
    Rs_out = [(5, 1), (3, 2)]

    def convolution(Rs_in, Rs_out):
        return Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))

    def convolution_groups(Rs_in, Rs_out):
        return Convolution(
            GroupKernel(Rs_in, Rs_out, partial(Kernel, RadialModel=ConstantRadialModel), groups))

    mp = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution)
    mp_groups = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution_groups)

    features = rs.randn(n_target, Rs_in)
    r_source = torch.randn(n_source, 3)
    r_target = torch.randn(n_target, 3)
    edge_index = torch.stack([
        torch.randint(n_source, size=(n_edge,)),
        torch.randint(n_target, size=(n_edge,)),
    ])
    size = (n_target, n_source)
    # Edge vector = target position - source position, computed for all
    # edges at once. This also works when there are no edges.
    edge_r = r_target[edge_index[1]] - r_source[edge_index[0]]

    out1 = mp(features, edge_index, edge_r, size=size)
    out1_groups = mp_groups(features, edge_index, edge_r, size=size)

    angles = o3.rand_angles()
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)
    R = o3.rot(*angles)

    # Rows are vectors: `@ D_in.T` applies D_in, and `@ D_out` applies D_out^T.
    out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out
    out2_groups = mp_groups(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out

    assert (out1 - out2).abs().max() < 1e-10
    assert (out1_groups - out2_groups).abs().max() < 1e-10
def rotation_kernel(self, K):
    """Test rotation equivariance on Kernel: D_out k(r) == k(R r) D_in.

    Args:
        K: kernel class, called as ``K(Rs_in, Rs_out, ConstantRadialModel)``.
    """
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]
        kernel = K(Rs_in, Rs_out, ConstantRadialModel)

        r = torch.randn(3)
        abc = torch.randn(3)
        D_in, D_out = rs.rep(Rs_in, *abc), rs.rep(Rs_out, *abc)

        lhs = D_out @ kernel(r)  # [i, j]
        rhs = kernel(o3.rot(*abc) @ r) @ D_in  # [i, j]
        self.assertLess((lhs - rhs).norm(), 1e-4 * lhs.norm())
def test_s2conv_network():
    """S2ConvNetwork must be equivariant under an improper rotation.

    Features are transformed with ``rs.rep(Rs, *angles, 1)`` and positions
    with ``-rot(angles)``. The rotated-input output is mapped back with D^T
    and compared to the unrotated output.
    Sets the global default dtype to float64 (not restored afterwards).
    """
    torch.set_default_dtype(torch.float64)
    lmax = 3
    Rs = [(1, l, 1) for l in range(lmax + 1)]
    model = S2ConvNetwork(Rs, 4, Rs, lmax)

    features = rs.randn(1, 4, Rs)
    geometry = torch.randn(1, 4, 3)
    output = model(features, geometry)

    angles = o3.rand_angles()
    D = rs.rep(Rs, *angles, 1)
    improper_R = -o3.rot(*angles)

    def apply(matrix, batch):
        # Left-multiply every row vector in a (z, a, j) batch by `matrix`.
        return torch.einsum('ij,zaj->zai', matrix, batch)

    output2 = apply(D.T, model(apply(D, features), apply(improper_R, geometry)))
    assert (output - output2).abs().max() < 1e-10 * output.abs().max()
def test2(self):
    """Test rotation equivariance on Kernel."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]
        kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel)

        r = torch.randn(3)
        abc = torch.randn(3)

        def block_rep(Rs):
            # Block-diagonal matrix: one irr_repr(l) block per copy, in Rs order.
            return o3.direct_sum(*[o3.irr_repr(l, *abc) for m, l in Rs for _ in range(m)])

        D_in = block_rep(Rs_in)
        D_out = block_rep(Rs_out)

        lhs = D_out @ kernel(r)  # [i, j]
        rhs = kernel(o3.rot(*abc) @ r) @ D_in  # [i, j]
        self.assertLess((lhs - rhs).norm(), 1e-4 * lhs.norm())
def test_equivariance_gatedconvnetwork(self):
    """GatedConvNetwork must commute with a random rotation of features and positions."""
    with torch_default_dtype(torch.float64):
        mul = 3
        Rs_in = [(mul, l) for l in range(4)]
        Rs_out = [(mul, l) for l in range(4)]
        Rs_hidden = [(10, 0), (1, 1), (1, 2), (1, 3)]
        net = GatedConvNetwork(Rs_in, Rs_hidden, Rs_out)

        abc = torch.randn(3)
        R = o3.rot(*abc)
        D_in, D_out = rs.rep(Rs_in, *abc), rs.rep(Rs_out, *abc)

        fea = torch.randn(1, 10, rs.dim(Rs_in))
        geo = torch.randn(1, 10, 3)

        rotate_after = torch.einsum("ij,zaj->zai", D_out, net(fea, geo))
        fea_rot = torch.einsum("ij,zaj->zai", D_in, fea)
        geo_rot = torch.einsum("ij,zaj->zai", R, geo)
        rotate_before = net(fea_rot, geo_rot)

        self.assertLess((rotate_after - rotate_before).norm(), 1e-4 * rotate_after.norm())
def test6(self):
    """Test parity and rotation equivariance on GatedBlockParity and dependencies."""
    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(6) for p in (-1, 1)]
        kernel = partial(Kernel, RadialModel=ConstantRadialModel)

        Rs_scalars = [(mul, 0, +1), (mul, 0, -1)]
        act_scalars = [(mul, relu), (mul, absolute)]
        Rs_nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]

        n_gates = 3 * mul
        Rs_gates = [(n_gates, 0, +1), (n_gates, 0, -1)]
        act_gates = [(n_gates, sigmoid), (n_gates, tanh)]

        gate = GatedBlockParity(Rs_scalars, act_scalars, Rs_gates, act_gates, Rs_nonscalars)
        conv = Convolution(kernel, Rs_in, gate.Rs_in)

        abc = torch.randn(3)
        improper_R = -o3.rot(*abc)  # rotation composed with inversion

        def parity_rep(Rs):
            # Block-diagonal: one (p * irr_repr(l)) block per copy, in Rs order.
            return o3.direct_sum(*[
                p * o3.irr_repr(l, *abc) for m, l, p in Rs for _ in range(m)
            ])

        D_in = parity_rep(Rs_in)
        D_out = parity_rep(gate.Rs_out)

        n_features = sum(m * (2 * l + 1) for m, l, _ in Rs_in)
        fea = torch.randn(1, 4, n_features)
        geo = torch.randn(1, 4, 3)

        rotate_after = torch.einsum("ij,zaj->zai", (D_out, gate(conv(fea, geo))))
        fea_rot = torch.einsum("ij,zaj->zai", (D_in, fea))
        geo_rot = torch.einsum("ij,zaj->zai", improper_R, geo)
        rotate_before = gate(conv(fea_rot, geo_rot))

        self.assertLess((rotate_after - rotate_before).norm(), 1e-4 * rotate_after.norm())
def test_irr_repr_wigner_3j(self):
    """Test irr_repr and wigner_3j equivariance.

    For every allowed l_f, the kernel W(r)_ij = sum_k Q_ijk Y_k(r) must satisfy
    W(R r) D_in == D_out W(r).
    """
    with torch_default_dtype(torch.float64):
        l_in, l_out = 3, 2
        for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
            r = torch.randn(100, 3)
            Q = o3.wigner_3j(l_out, l_in, l_f)
            abc = torch.randn(3)
            D_in = o3.irr_repr(l_in, *abc)
            D_out = o3.irr_repr(l_out, *abc)

            def kernel(points):
                # Contract the 3j symbol with the degree-l_f spherical harmonics.
                harmonics = rsh.spherical_harmonics_xyz([l_f], points)
                return torch.einsum("ijk,zk->zij", (Q, harmonics))

            W_rotated_points = kernel(r @ o3.rot(*abc).t())
            W1 = torch.einsum("zij,jk->zik", (W_rotated_points, D_in))

            W = kernel(r)
            W2 = torch.einsum("ij,zjk->zik", (D_out, W))

            self.assertLess((W1 - W2).norm(), 1e-5 * W.norm(), l_f)