def _is_representation(D, eps):
    """Return True if ``D`` numerically satisfies the homomorphism property.

    ``D`` is called with three Euler angles and must return a square matrix.
    Two checks are made:

    * the image of the identity angles, ``D(0, 0, 0)``, is idempotent
      (``I == I @ I`` according to ``torch.allclose``);
    * for two random rotations ``g1`` and ``g2``, the largest absolute entry of
      ``D(g1 * g2) - D(g1) @ D(g2)`` is below ``eps`` times the largest
      absolute entry of ``D(g1 * g2)``.
    """
    identity_image = D(0, 0, 0)
    if not torch.allclose(identity_image, identity_image @ identity_image):
        return False

    first, second = o3.rand_angles(), o3.rand_angles()
    product_image = D(*o3.compose(*first, *second))
    composed = D(*first) @ D(*second)

    worst_error = (product_image - composed).abs().max().item()
    scale = product_image.abs().max().item()
    return worst_error < eps * scale
def _test_is_representation(self, R):
    """Assert that ``R`` composes like a representation of SO(3).

    For two random elements checks, in float64,
    ``R(Z(a1) Y(b1) Z(c1) Z(a2) Y(b2) Z(c2)) = R(Z(a1) Y(b1) Z(c1)) R(Z(a2) Y(b2) Z(c2))``
    up to a relative tolerance of 1e-10 (relative to the largest entry of the
    left-hand side).
    """
    with o3.torch_default_dtype(torch.float64):
        angles_a = o3.rand_angles()
        angles_b = o3.rand_angles()
        lhs = R(*o3.compose(*angles_a, *angles_b))
        rhs = R(*angles_a) @ R(*angles_b)
        max_abs_diff = (lhs - rhs).abs().max()
        self.assertLess(max_abs_diff, 1e-10 * lhs.abs().max())
def test_tensor_square_norm(self):
    """Check that the tensor-square change of basis ``Q`` is equivariant and orthonormal.

    For each input representation, ``rs.tensor_square`` (component
    normalization, sorted output) returns ``Rs_out`` and ``Q``.  In float64 the
    test verifies that:

    1. ``Q`` intertwines ``D_in x D_in`` with ``D_out`` (relative RMS error < 1e-10);
    2. the rows of ``Q`` flattened to a matrix ``M`` satisfy ``M M^T = I``
       (RMS error < 1e-10).
    """
    def _rms(t):
        # root-mean-square of all entries
        return t.pow(2).mean().sqrt()

    cases = [[(1, 0), (2, 1), (4, 3)]]
    for Rs_in in cases:
        with o3.torch_default_dtype(torch.float64):
            Rs_out, Q = rs.tensor_square(Rs_in, o3.selection_rule, normalization='component', sorted=True)

            abc = o3.rand_angles()
            D_in = rs.rep(Rs_in, *abc)
            D_out = rs.rep(Rs_out, *abc)

            # Both sides of the intertwining relation, written with D^T on each side.
            Q1 = torch.einsum("ijk,il->ljk", (Q, D_out))
            Q2 = torch.einsum("li,mj,kij->klm", (D_in, D_in, Q))
            self.assertLess(_rms(Q1 - Q2) / _rms(Q1), 1e-10)

            rows = Q.size(0)
            M = Q.reshape(rows, -1)
            self.assertLess(_rms((M @ M.t()) - torch.eye(rows)), 1e-10)
def _is_representation(D, eps, with_parity=False):
    """Numerically check that ``D`` behaves like a group representation.

    ``D`` takes Euler angles ``(alpha, beta, gamma)``, plus a parity flag when
    ``with_parity`` is True, and returns a square matrix.

    Checks performed:

    * ``D`` evaluated at the identity element is idempotent
      (``I == I @ I`` according to ``torch.allclose``);
    * for two random group elements ``g1`` and ``g2``, the largest absolute
      entry of ``D(g1 * g2) - D(g1) @ D(g2)`` is below ``eps`` times the largest
      absolute entry of ``D(g1 * g2)``.

    When ``with_parity`` is True, the parity of each random element is drawn
    uniformly from {0, 1}. This exercises the parity-dependent part of ``D``
    as well as the rotation part.

    Returns ``True`` when both checks pass, ``False`` otherwise.
    """
    if with_parity:
        identity = (0, 0, 0, 0)
    else:
        identity = (0, 0, 0)
    I = D(*identity)
    if not torch.allclose(I, I @ I):
        return False

    def random_element():
        # Random Euler angles, extended with a random parity bit when requested.
        angles = o3.rand_angles()
        if not with_parity:
            return angles
        parity = int(torch.randint(2, ()))
        return angles + (parity,)

    g1 = random_element()
    g2 = random_element()
    if with_parity:
        g12 = o3.compose_with_parity(*g1, *g2)
    else:
        g12 = o3.compose(*g1, *g2)

    D12 = D(*g12)
    D1D2 = D(*g1) @ D(*g2)
    return (D12 - D1D2).abs().max().item() < eps * D12.abs().max().item()
def test_custom_weighted_tensor_product():
    """CustomWeightedTensorProduct is rotation-equivariant and supports backward.

    The product is evaluated on random inputs, then on rotated inputs. The two
    outputs must agree once the first is rotated with the output representation.

    The test runs in float64 and restores the previous global default dtype
    afterwards, so the setting does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs_in1 = [(20, 1), (4, 2)]
        Rs_in2 = [(10, 0), (10, 1), (4, 2)]
        Rs_out = [(3, 0), (4, 1)]
        # NOTE(review): entries presumably are (index into Rs_in1, index into
        # Rs_in2, index into Rs_out, weight mode) -- confirm against
        # CustomWeightedTensorProduct.
        instr = [
            (0, 1, 0, 'uvw'),
            (1, 2, 1, 'uuu'),
            (0, 1, 1, 'uvw'),
        ]
        tp = CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr)

        x1 = rs.randn(20, Rs_in1)
        x2 = rs.randn(20, Rs_in2)

        angles = o3.rand_angles()
        z1 = tp(x1, x2) @ rs.rep(Rs_out, *angles).T
        z2 = tp(x1 @ rs.rep(Rs_in1, *angles).T, x2 @ rs.rep(Rs_in2, *angles).T)

        # smoke-test that gradients can be propagated through the product
        z1.sum().backward()

        assert torch.allclose(z1, z2)
    finally:
        torch.set_default_dtype(previous_dtype)
def test_sh_is_in_irrep(float_tolerance):
    """Spherical harmonics equal a column of the Wigner D matrix.

    For l = 0..4 and a random direction (alpha, beta), the rescaled
    ``Y_l(alpha, beta) * sqrt(4 pi) / sqrt(2l + 1)`` must match column ``l``
    of ``wigner_D(l, alpha, beta, 0)`` to within ``float_tolerance``.
    """
    for l in range(5):
        alpha, beta, _ = o3.rand_angles()
        Y = o3.spherical_harmonics_alpha_beta(l, alpha, beta) * math.sqrt(4 * math.pi) / math.sqrt(2 * l + 1)
        middle_column = o3.wigner_D(l, alpha, beta, torch.zeros(()))[:, l]
        worst = (Y - middle_column).abs().max()
        assert worst < float_tolerance
def test_equivariance_wtp(Rs_in, Rs_out, n_source, n_target, n_edge):
    """WTPConv output transforms with ``Rs_out`` when inputs and edges are rotated.

    Features are rotated with ``rs.rep(Rs_in, ...)`` and edge vectors with the
    matching 3x3 rotation. The rotated-input output, right-multiplied by
    ``D_out``, must reproduce the unrotated output to within 1e-10.

    The test runs in float64 and restores the previous global default dtype
    afterwards, so the setting does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        mp = WTPConv(Rs_in, Rs_out, 3, ConstantRadialModel)

        features = rs.randn(n_target, Rs_in)

        edge_index = torch.stack([
            torch.randint(n_source, size=(n_edge,)),
            torch.randint(n_target, size=(n_edge,)),
        ])
        size = (n_target, n_source)

        edge_r = torch.randn(n_edge, 3)
        if n_edge > 1:
            # force one zero-length edge vector into the batch
            edge_r[0] = 0

        out1 = mp(features, edge_index, edge_r, size=size)

        angles = o3.rand_angles()
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)
        R = o3.rot(*angles)

        out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out

        assert (out1 - out2).abs().max() < 1e-10
    finally:
        torch.set_default_dtype(previous_dtype)
def forward(self, features):
    r'''Apply the activation pointwise on the sphere.

    The coefficients are mapped to a signal on the sphere by ``self.to_s2``
    (shape ``[..., beta, alpha]``), passed through ``self.act`` and projected
    back with ``self.from_s2``.

    When ``self.random_rot`` is set, the input is first multiplied by
    ``self.irreps_in.D_from_angles`` for random angles. The output is then
    multiplied by the transpose of ``self.irreps_out.D_from_angles`` for the
    same angles.

    Parameters
    ----------
    features : `torch.Tensor`
        tensor :math:`\{A^l\}_l` of shape ``(..., self.irreps_in.dim)``

    Returns
    -------
    `torch.Tensor`
        tensor of shape ``(..., self.irreps_out.dim)``
    '''
    assert features.shape[-1] == self.irreps_in.dim

    if self.random_rot:
        angles = o3.rand_angles(dtype=features.dtype, device=features.device)
        D_in = self.irreps_in.D_from_angles(*angles)
        features = torch.einsum('ij,...j->...i', D_in, features)

    on_sphere = self.to_s2(features)  # [..., beta, alpha]
    activated = self.act(on_sphere)
    features = self.from_s2(activated)

    if self.random_rot:
        D_out_T = self.irreps_out.D_from_angles(*angles).T
        features = torch.einsum('ij,...j->...i', D_out_T, features)

    return features
def test_reduce_tensor_Levi_Civita_symbol():
    """The fully antisymmetric rank-3 tensor over l=1 is rotation invariant.

    ``rs.reduce_tensor`` with formula 'ijk=-ikj=-jik' must return the
    representation ``[(1, 0, 0)]``. The 3x3x3 basis tensor must be unchanged
    by applying the same random l=1 rotation to all three indices.
    """
    Rs, Q = rs.reduce_tensor('ijk=-ikj=-jik', i=[(1, 1)])
    assert Rs == [(1, 0, 0)]

    angles = o3.rand_angles()
    D = o3.irr_repr(1, *angles)
    levi_civita = Q.reshape(3, 3, 3)
    rotated = torch.einsum('li,mj,nk,ijk', D, D, D, levi_civita)
    assert (rotated - levi_civita).abs().max() < 1e-10
def _generate_wigner_3j(l1, l2, l3, dtype=None, device=None):  # pragma: no cover
    r"""Computes the 3-j symbol numerically.

    The symbol is obtained as the common fixed vector of ``D1 x D2 x D3`` for a
    fixed set of five rotations. It is the eigenvector of
    ``B = sum_g (D(g) - I)^T (D(g) - I)`` with the smallest eigenvalue. The
    result is reshaped to ``(2*l1+1, 2*l2+1, 2*l3+1)``, tiny entries are zeroed,
    and the global sign is fixed. It is then validated against 100 random
    rotations and checked to have unit norm.

    NOTE(review): all checks use absolute tolerances of 1e-10 / 1e-14, which
    presumably require ``dtype=torch.float64`` (when ``dtype`` is None the
    torch default dtype is used) -- confirm before calling with float32.
    """
    # these three propositions are equivalent
    assert abs(l2 - l3) <= l1 <= l2 + l3
    assert abs(l3 - l1) <= l2 <= l3 + l1
    assert abs(l1 - l2) <= l3 <= l1 + l2

    n = (2 * l1 + 1) * (2 * l2 + 1) * (2 * l3 + 1)  # dimension of the 3-j symbol

    def _DxDxD(a, b, c):
        # Kronecker product D1 (x) D2 (x) D3 for one rotation, as an (n, n) matrix
        D1 = wigner_D(l1, a, b, c)
        D2 = wigner_D(l2, a, b, c)
        D3 = wigner_D(l3, a, b, c)
        return torch.einsum('il,jm,kn->ijklmn', D1, D2, D3).reshape(n, n)

    # Fixed (hard-coded) angles keep the result deterministic across calls.
    random_angles = torch.tensor([
        [4.41301023, 5.56684102, 4.59384642],
        [4.93325116, 6.12697327, 4.14574096],
        [0.53878964, 4.09050444, 5.36539036],
        [2.16017393, 3.48835314, 5.55174441],
        [2.52385107, 0.29089583, 3.90040975],
    ], dtype=dtype, device=device)

    # B is positive semi-definite; its null space is the set of vectors left
    # unchanged by every sampled D(g).
    B = random_angles.new_zeros((n, n))
    for abc in random_angles:
        D = _DxDxD(*abc) - torch.eye(n, dtype=dtype, device=device)
        B += D.T @ D

    # torch.linalg.eigh returns eigenvalues in ascending order, so column 0 is
    # the (assumed one-dimensional) invariant direction.
    eigenvalues, eigenvectors = torch.linalg.eigh(B)
    assert eigenvalues[0] < 1e-10
    Q = eigenvectors[:, 0]
    assert (B @ Q).norm() < 1e-10
    Q = Q.reshape(2 * l1 + 1, 2 * l2 + 1, 2 * l3 + 1)

    # clean numerical noise
    Q[Q.abs() < 1e-14] = 0

    # Sign convention: make the central entry (l1, l2, l3) positive, or, if it
    # is zero, the first non-zero entry in flattened order.
    if Q[l1, l2, l3] != 0:
        if Q[l1, l2, l3] < 0:
            Q.neg_()
    else:
        if next(x for x in Q.flatten() if x != 0) < 0:
            Q.neg_()

    # Independent validation: Q must be invariant under 100 random rotations.
    abc = o3.rand_angles(100, dtype=dtype, device=device)
    Q2 = torch.einsum("zil,zjm,zkn,lmn->zijk", wigner_D(l1, *abc), wigner_D(l2, *abc), wigner_D(l3, *abc), Q)
    assert (Q - Q2).norm() < 1e-10
    assert abs(Q.norm() - 1) < 1e-10

    return Q
def test_equivariance(Rs_in, Rs_out, n_source, n_target, n_edge):
    """Convolution and its grouped variants are rotation-equivariant.

    Three configurations are checked:

    * a plain ``Convolution(Kernel(...))``;
    * the same module called with ``groups=4`` on ``Rs_in * groups`` features;
    * a ``Convolution(GroupKernel(...))`` called with ``groups=4``.

    For each, rotating features and edge vectors and then right-multiplying the
    output by ``D_out`` must reproduce the unrotated output to within 1e-10.

    The test runs in float64 and restores the previous global default dtype
    afterwards, so the setting does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        mp = Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))
        groups = 4
        mp_group = Convolution(
            GroupKernel(Rs_in, Rs_out, partial(Kernel, RadialModel=ConstantRadialModel), groups))

        features = rs.randn(n_target, Rs_in)
        features2 = rs.randn(n_target, Rs_in * groups)

        r_source = torch.randn(n_source, 3)
        r_target = torch.randn(n_target, 3)

        edge_index = torch.stack([
            torch.randint(n_source, size=(n_edge,)),
            torch.randint(n_target, size=(n_edge,)),
        ])
        size = (n_target, n_source)

        if n_edge == 0:
            edge_r = torch.zeros(0, 3)
        else:
            edge_r = torch.stack([r_target[j] - r_source[i] for i, j in edge_index.T])
        print(features.shape, edge_index.shape, edge_r.shape, size)

        out1 = mp(features, edge_index, edge_r, size=size)
        out1_groups = mp(features2, edge_index, edge_r, size=size, groups=groups)
        out1_kernel_groups = mp_group(features2, edge_index, edge_r, size=size, groups=groups)

        angles = o3.rand_angles()
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)
        D_in_groups = rs.rep(Rs_in * groups, *angles)
        D_out_groups = rs.rep(Rs_out * groups, *angles)
        R = o3.rot(*angles)

        out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out
        out2_groups = mp(features2 @ D_in_groups.T, edge_index, edge_r @ R.T, size=size,
                         groups=groups) @ D_out_groups
        out2_kernel_groups = mp_group(features2 @ D_in_groups.T, edge_index, edge_r @ R.T, size=size,
                                      groups=groups) @ D_out_groups

        assert (out1 - out2).abs().max() < 1e-10
        assert (out1_groups - out2_groups).abs().max() < 1e-10
        assert (out1_kernel_groups - out2_kernel_groups).abs().max() < 1e-10
    finally:
        torch.set_default_dtype(previous_dtype)
def test_sh_equivariance1(float_tolerance):
    r"""Check ``Y(R x) == D(R) Y(x)`` for l = 0..7.

    Exercises ``o3.compose_angles`` (to obtain the rotated direction),
    ``o3.spherical_harmonics_alpha_beta`` and ``o3.wigner_D``. The error is
    measured relative to the largest entry of ``Y(x)``.
    """
    for l in range(8):
        a, b, _ = o3.rand_angles()
        alpha, beta, gamma = o3.rand_angles()
        # (alpha, beta) coordinates of the rotated direction
        ra, rb, _ = o3.compose_angles(alpha, beta, gamma, a, b, torch.tensor(0.0))

        Y = o3.spherical_harmonics_alpha_beta(l, a, b)
        Y_of_rotated = o3.spherical_harmonics_alpha_beta(l, ra, rb)
        D_times_Y = o3.wigner_D(l, alpha, beta, gamma) @ Y

        assert (Y_of_rotated - D_times_Y).abs().max() < float_tolerance * Y.abs().max()
def test_equivariance(float_tolerance):
    """Spherical harmonics up to lmax=5 commute with rotations.

    Evaluating the (unnormalized, ``False``) spherical harmonics on rotated
    points must equal rotating the harmonics of the original points with
    ``irreps.D_from_angles`` (tolerance ``10 * float_tolerance``).
    """
    irreps = o3.Irreps.spherical_harmonics(5)
    points = torch.randn(2, 3)
    abc = o3.rand_angles()

    rotation = o3.angles_to_matrix(*abc)
    rotate_then_embed = o3.spherical_harmonics(irreps, points @ rotation.T, False)

    wigner = irreps.D_from_angles(*abc)
    embed_then_rotate = o3.spherical_harmonics(irreps, points, False) @ wigner.T

    assert (rotate_then_embed - embed_then_rotate).abs().max() < 10 * float_tolerance
def test_wigner_3j(float_tolerance):
    """The Wigner 3j symbol for (l1, l2, l3) = (1, 2, 3) is rotation invariant.

    Contracting each index of ``C`` with the corresponding D matrix (for a
    batch of 10 random rotations) must give back ``C``.
    """
    abc = o3.rand_angles(10)
    l1, l2, l3 = 1, 2, 3
    C = o3.wigner_3j(l1, l2, l3)
    D1, D2, D3 = (o3.Irrep(l, 1).D_from_angles(*abc) for l in (l1, l2, l3))
    rotated = torch.einsum("ijk,zil,zjm,zkn->zlmn", C, D1, D2, D3)
    assert (C - rotated).abs().max() < float_tolerance
def test_irrep_closure_1(self):
    """Schur orthogonality for the irreps l = 0..3, checked by Monte Carlo.

    Averaging ``U1_ij * U2_kl`` over 10000 random rotations must give
    approximately ``delta_ik delta_jl / (2l + 1)`` when ``l1 == l2`` and
    approximately zero otherwise (loose tolerance 0.1).
    """
    rotations = [o3.rand_angles() for _ in range(10000)]
    matrices = [torch.stack([o3.irr_repr(l, *abc) for abc in rotations]) for l in range(4)]

    for l1, U1 in enumerate(matrices):
        rows = (2 * l1 + 1) ** 2
        for l2, U2 in enumerate(matrices):
            cols = (2 * l2 + 1) ** 2
            mean_product = torch.einsum('zij,zkl->zijkl', U1, U2).mean(0).reshape(rows, cols)
            if l1 != l2:
                self.assertLess(mean_product.abs().max(), 0.1)
                continue
            expected = torch.eye(rows)
            self.assertLess((mean_product.mul(2 * l1 + 1) - expected).abs().max(), 0.1)
def test_derivative_irr_repr(self):
    """Compare ``o3.derivative_irr_repr`` with forward finite differences.

    For l = 4 in float64, each analytic derivative (w.r.t. alpha, beta, gamma)
    must agree with ``(D(angle + h) - D(angle)) / h`` for h = 1e-7 to within 1e-5.
    """
    with o3.torch_default_dtype(torch.float64):
        l = 4
        angles = o3.rand_angles()
        da, db, dc = o3.derivative_irr_repr(l, *angles)

        h = 1e-7
        base = o3.irr_repr(l, *angles)
        for axis, analytic in enumerate((da, db, dc)):
            shifted = list(angles)
            shifted[axis] = shifted[axis] + h
            numeric = (o3.irr_repr(l, *shifted) - base) / h
            self.assertLess((numeric - analytic).abs().max(), 1e-5)
def test_inverse_angles(float_tolerance):
    """Composing random angles with their inverse gives the identity rotation.

    Also checks that angles from ``o3.identity_angles(requires_grad=True)``
    receive a gradient through ``o3.angles_to_matrix``.
    """
    angles = o3.rand_angles()
    inverse = o3.inverse_angles(*angles)
    composed = o3.compose_angles(*angles, *inverse)
    identity = o3.identity_angles(requires_grad=True)

    composed_matrix = o3.angles_to_matrix(*composed)
    identity_matrix = o3.angles_to_matrix(*identity)
    assert (composed_matrix - identity_matrix).abs().max() < float_tolerance

    # test `requires_grad`
    identity_matrix.sum().backward()
    assert identity[0].grad is not None
def test(act, normalization):
    """Check S2Activation (with ``random_rot=True``) for equivariance.

    Uses ``Rs`` and ``self`` from the enclosing scope. The rotation passes a
    trailing ``1`` to ``rs.rep`` (presumably the parity flag -- confirm). The
    two orderings must agree to a relative tolerance of 1e-10.
    """
    features = rs.randn(2, Rs, normalization=normalization)
    activation = S2Activation(Rs, act, 120, normalization=normalization, lmax_out=6, random_rot=True)
    alpha, beta, gamma = o3.rand_angles()

    activate_then_rotate = activation(features) @ rs.rep(activation.Rs_out, alpha, beta, gamma, 1).T
    rotate_then_activate = activation(features @ rs.rep(Rs, alpha, beta, gamma, 1).T)

    error = (activate_then_rotate - rotate_then_activate).abs().max()
    self.assertLess(error, 1e-10 * activate_then_rotate.abs().max())
def test_equivariance():
    """DepthwiseConvolution is rotation-equivariant, with plain and grouped kernels.

    Builds two DepthwiseConvolution modules: one using ``Convolution(Kernel(...))``
    and one using ``Convolution(GroupKernel(...))``. For each, rotating the
    features and edge vectors and then right-multiplying the output by
    ``D_out`` must reproduce the unrotated output to within 1e-10.

    The test runs in float64 and restores the previous global default dtype
    afterwards, so the setting does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        n_edge = 3
        n_source = 4
        n_target = 2
        # Defined before the factories below so they do not depend on late binding.
        groups = 4

        Rs_in = [(3, 0), (0, 1)]
        Rs_mid1 = [(5, 0), (1, 1)]
        Rs_mid2 = [(5, 0), (1, 1), (1, 2)]
        Rs_out = [(5, 1), (3, 2)]

        def convolution(Rs_in, Rs_out):
            return Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))

        def convolution_groups(Rs_in, Rs_out):
            return Convolution(
                GroupKernel(Rs_in, Rs_out, partial(Kernel, RadialModel=ConstantRadialModel), groups))

        mp = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution)
        mp_groups = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution_groups)

        features = rs.randn(n_target, Rs_in)

        r_source = torch.randn(n_source, 3)
        r_target = torch.randn(n_target, 3)

        edge_index = torch.stack([
            torch.randint(n_source, size=(n_edge,)),
            torch.randint(n_target, size=(n_edge,)),
        ])
        size = (n_target, n_source)

        # n_edge is fixed at 3 above, so there is always at least one edge.
        edge_r = torch.stack([r_target[j] - r_source[i] for i, j in edge_index.T])
        print(features.shape, edge_index.shape, edge_r.shape, size)

        out1 = mp(features, edge_index, edge_r, size=size)
        out1_groups = mp_groups(features, edge_index, edge_r, size=size)

        angles = o3.rand_angles()
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)
        R = o3.rot(*angles)

        out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out
        out2_groups = mp_groups(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out

        assert (out1 - out2).abs().max() < 1e-10
        assert (out1_groups - out2_groups).abs().max() < 1e-10
    finally:
        torch.set_default_dtype(previous_dtype)
def test_reduce_tensor_antisymmetric_L2():
    """Fully antisymmetric rank-3 tensors over l=2 contain a (1, 1, 0) component.

    Checks three things:

    * the first entry returned by ``rs.reduce_tensor`` is ``(1, 1, 0)``;
    * the corresponding three basis tensors transform with the l=1 matrix when
      all indices are rotated with the l=2 matrix;
    * each basis tensor is antisymmetric under every index swap tested.
    """
    Rs, Q = rs.reduce_tensor('ijk=-ikj=-jik', i=[(1, 2)])
    assert Rs[0] == (1, 1, 0)
    basis = Q[:3].reshape(3, 5, 5, 5)

    angles = o3.rand_angles()
    D_l1 = o3.irr_repr(1, *angles)
    D_l2 = o3.irr_repr(2, *angles)
    rotate_indices = torch.einsum('il,jm,kn,zijk->zlmn', D_l2, D_l2, D_l2, basis)
    rotate_copy = torch.einsum('yz,zijk->yijk', D_l1, basis)
    assert (rotate_indices - rotate_copy).abs().max() < 1e-10

    for first_axis, second_axis in ((1, 2), (1, 3), (3, 2)):
        assert (basis + basis.transpose(first_axis, second_axis)).abs().max() < 1e-10
def test_tensor_square_equivariance(self):
    """TensorSquare commutes with rotations: ``sq(D_in x) == D_out sq(x)``.

    Runs in float64 with a relative tolerance of 1e-7.
    """
    with o3.torch_default_dtype(torch.float64):
        Rs_in = [(3, 0), (2, 1), (5, 2)]
        square = TensorSquare(Rs_in, o3.selection_rule)

        x = rs.randn(Rs_in)
        abc = o3.rand_angles()
        D_in = rs.rep(Rs_in, *abc)
        D_out = rs.rep(square.Rs_out, *abc)

        rotate_first = square(D_in @ x)
        rotate_last = D_out @ square(x)
        self.assertLess((rotate_first - rotate_last).abs().max(), 1e-7 * rotate_first.abs().max())
def test_norm_activation(Rs, normalization, dtype):
    """NormActivation (with ``swish``) is equivariant under a random rotation.

    Tolerance is 1e-5 for float32 and 1e-10 for float64.
    """
    with o3.torch_default_dtype(dtype):
        activation = NormActivation(Rs, swish, normalization=normalization)
        D = rs.rep(Rs, *o3.rand_angles())
        x = rs.randn(2, Rs, normalization=normalization)

        rotated_output = torch.einsum('ij,zj->zi', D, activation(x))
        output_of_rotated = activation(torch.einsum('ij,zj->zi', D, x))

        tolerance = {
            torch.float32: 1e-5,
            torch.float64: 1e-10,
        }[dtype]
        assert (rotated_output - output_of_rotated).abs().max() < tolerance
def test_equivariance_s2network(self):
    """S2Network maps rotated inputs to rotated outputs.

    Runs in float64; the Frobenius-norm error must be below 1e-3 of the
    output norm.
    """
    with torch_default_dtype(torch.float64):
        mul = 3
        Rs_in = [(mul, l) for l in range(4)]
        Rs_out = list(Rs_in)

        net = S2Network(Rs_in, mul, lmax=4, Rs_out=Rs_out)

        abc = o3.rand_angles()
        D_in = rs.rep(Rs_in, *abc)
        D_out = rs.rep(Rs_out, *abc)

        inputs = torch.randn(10, rs.dim(Rs_in))
        rotated_after = torch.einsum("ij,zj->zi", D_out, net(inputs))
        rotated_before = net(torch.einsum("ij,zj->zi", D_in, inputs))
        self.assertLess((rotated_after - rotated_before).norm(), 1e-3 * rotated_after.norm())
def test_equivariance_s2parity_network():
    """S2ParityNetwork is equivariant under a rotation with parity flag 1.

    Inputs carry parity -1 and outputs parity +1. Rotating the input with
    ``rs.rep(Rs_in, *abc, 1)`` must correspond to rotating the output with
    ``rs.rep(Rs_out, *abc, 1)`` (relative Frobenius error < 1e-3).

    The test runs in float64 and restores the previous global default dtype
    afterwards, so the setting does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        mul = 3
        Rs_in = [(mul, l, -1) for l in range(3 + 1)]
        Rs_out = [(mul, l, 1) for l in range(3 + 1)]

        net = S2ParityNetwork(Rs_in, mul, lmax=3, Rs_out=Rs_out)

        abc = o3.rand_angles()
        D_in = rs.rep(Rs_in, *abc, 1)
        D_out = rs.rep(Rs_out, *abc, 1)

        fea = rs.randn(10, Rs_in)

        x1 = torch.einsum("ij,zj->zi", D_out, net(fea))
        x2 = net(torch.einsum("ij,zj->zi", D_in, fea))
        assert (x1 - x2).norm() < 1e-3 * x1.norm()
    finally:
        torch.set_default_dtype(previous_dtype)
def forward(self, features):
    '''
    Apply ``self.act`` to the signal on the sphere described by ``features``.

    The coefficients go through ``self.to_s2`` (giving ``[..., beta, alpha]``),
    then ``self.act``, then ``self.from_s2``. When ``self.random_rot`` is set,
    the input is first rotated with ``rs.rep(self.Rs_in, ...)`` for random
    angles. The result is rotated back with the transpose of
    ``rs.rep(self.Rs_out, ...)`` for the same angles.

    :param features: [..., l * m]
    '''
    def through_sphere(coefficients):
        signal = self.to_s2(coefficients)  # [..., beta, alpha]
        signal = self.act(signal)
        return self.from_s2(signal)

    if not self.random_rot:
        return through_sphere(features)

    abc = o3.rand_angles()
    rotated = torch.einsum('ij,...j->...i', rs.rep(self.Rs_in, *abc), features)
    result = through_sphere(rotated)
    return torch.einsum('ij,...j->...i', rs.rep(self.Rs_out, *abc).T, result)
def test_s2conv_network():
    """S2ConvNetwork is equivariant under an improper rotation.

    Features are transformed with ``rs.rep(Rs, *angles, 1)``. The geometry is
    transformed with ``-o3.rot(*angles)``, a rotation combined with inversion.
    Applying ``D^T`` to the output of the transformed inputs must reproduce the
    original output, relative to its largest entry, within 1e-10.

    The test runs in float64 and restores the previous global default dtype
    afterwards, so the setting does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        lmax = 3
        Rs = [(1, l, 1) for l in range(lmax + 1)]
        model = S2ConvNetwork(Rs, 4, Rs, lmax)

        features = rs.randn(1, 4, Rs)
        geometry = torch.randn(1, 4, 3)

        output = model(features, geometry)

        angles = o3.rand_angles()
        D = rs.rep(Rs, *angles, 1)
        R = -o3.rot(*angles)
        ein = torch.einsum
        output2 = ein('ij,zaj->zai', D.T, model(ein('ij,zaj->zai', D, features), ein('ij,zaj->zai', R, geometry)))

        assert (output - output2).abs().max() < 1e-10 * output.abs().max()
    finally:
        torch.set_default_dtype(previous_dtype)
def test_weighted_tensor_product():
    """WeightedTensorProduct (groups=2) is rotation-equivariant and supports backward.

    The test runs in float64 and restores the previous global default dtype
    afterwards, so the setting does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs_in1 = rs.simplify([1] * 20 + [2] * 4)
        Rs_in2 = rs.simplify([0] * 10 + [1] * 10 + [2] * 5)
        Rs_out = rs.simplify([0] * 3 + [1] * 4)

        tp = WeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, groups=2)

        x1 = rs.randn(20, Rs_in1)
        x2 = rs.randn(20, Rs_in2)

        angles = o3.rand_angles()
        z1 = tp(x1, x2) @ rs.rep(Rs_out, *angles).T
        z2 = tp(x1 @ rs.rep(Rs_in1, *angles).T, x2 @ rs.rep(Rs_in2, *angles).T)

        # smoke-test that gradients can be propagated through the product
        z1.sum().backward()

        assert torch.allclose(z1, z2)
    finally:
        torch.set_default_dtype(previous_dtype)
def test_cartesian(float_tolerance):
    """For l = 1 the Wigner D matrix equals the 3x3 rotation matrix.

    Checked on a batch of 10 random rotations.
    """
    angles = o3.rand_angles(10)
    difference = o3.angles_to_matrix(*angles) - o3.wigner_D(1, *angles)
    assert difference.abs().max() < float_tolerance
from e3nn.o3 import irr_repr data = np.random.randn(36).reshape((6,6)) et1 = ElasticTensor(data, 'voigt') # back and forth transformation to cartesian ensures that voigt is symmetric _ = et1.voigt_to_cartesian() _ = et1.cartesian_to_voigt() _ = et1.voigt_to_cartesian() _ = et1.cartesian_to_covariant() _ = et1.covariant_to_spherical() et2 = ElasticTensor(et1.covariant.copy(), 'covariant') a, b, c = rand_angles() D0 = irr_repr(0, a, b, c).numpy() D1 = irr_repr(1, a, b, c).numpy() D2 = irr_repr(2, a, b, c).numpy() D4 = irr_repr(4, a, b, c).numpy() # forward rotation et2.covariant = np.einsum('ijkn,ai,bj,ck,dn->abcd', et2.covariant, D1, D1, D1, D1) _ = et2.covariant_to_spherical() # inverse rotation et2.spherical[0:1] = D0.T @ et2.spherical[0:1] et2.spherical[1:6] = D2.T @ et2.spherical[1:6] et2.spherical[6:7] = D0.T @ et2.spherical[6:7] et2.spherical[7:12] = D2.T @ et2.spherical[7:12]