def test3(self):
    """Rotation equivariance of a Convolution followed by a GatedBlock.

    Rotating the inputs (features by D_in, positions by the rotation matrix)
    must give the same output as rotating the network output by D_out.
    """
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]

        kernel = partial(Kernel, RadialModel=ConstantRadialModel)
        gated = GatedBlock(Rs_out, scalar_activation=sigmoid, gate_activation=sigmoid)
        conv = Convolution(kernel, Rs_in, gated.Rs_in)

        angles = torch.randn(3)
        R = o3.rot(*angles)

        def block_rep(Rs):
            # Each (mul, l) entry contributes `mul` copies of D^l on the diagonal.
            return o3.direct_sum(*[
                o3.irr_repr(l, *angles)
                for mul, l in Rs
                for _ in range(mul)
            ])

        D_in = block_rep(Rs_in)
        D_out = block_rep(Rs_out)

        n_features = sum(mul * (2 * l + 1) for mul, l in Rs_in)
        features = torch.randn(1, 4, n_features)
        positions = torch.randn(1, 4, 3)

        # Rotate after the network ...
        rotated_after = torch.einsum("ij,zaj->zai", (D_out, gated(conv(features, positions))))
        # ... versus rotate the inputs before it.
        rotated_before = gated(conv(
            torch.einsum("ij,zaj->zai", (D_in, features)),
            torch.einsum("ij,zaj->zai", R, positions),
        ))

        self.assertLess((rotated_after - rotated_before).norm(), 1e-4 * rotated_after.norm())
def test1(self):
    """Rotation equivariance of a Convolution followed by a softplus/sigmoid GatedBlock."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(3, 0), (3, 1), (2, 0), (1, 2)]
        Rs_out = [(3, 0), (3, 1), (1, 2), (3, 0)]

        gate = GatedBlock(Rs_out, rescaled_act.Softplus(beta=5), rescaled_act.sigmoid)
        conv = Convolution(Kernel(Rs_in, gate.Rs_in, ConstantRadialModel))

        angles = torch.randn(3)
        D_in = o3.direct_sum(*[
            o3.irr_repr(l, *angles) for mul, l in Rs_in for _ in range(mul)
        ])
        D_out = o3.direct_sum(*[
            o3.irr_repr(l, *angles) for mul, l in Rs_out for _ in range(mul)
        ])

        n_in = sum(mul * (2 * l + 1) for mul, l in Rs_in)
        x = torch.randn(1, 5, n_in)
        geo = torch.randn(1, 5, 3)

        # Rotated inputs: features through D_in, positions through the rotation matrix.
        x_rot = torch.einsum("ij,zaj->zai", (D_in, x))
        geo_rot = geo @ o3.rot(*angles).t()

        out = gate(conv(x, geo), dim=2)
        out_rotated_after = torch.einsum("ij,zaj->zai", (D_out, out))
        out_rotated_before = gate(conv(x_rot, geo_rot))

        self.assertLess(
            (out_rotated_before - out_rotated_after).norm(),
            1e-10 * out_rotated_after.norm(),
        )
def D(*angles):
    """Block-diagonal D^1 + D^1 + D^2 for ``angles``, conjugated by ``A``.

    ``A`` is not defined in this function; presumably it is a change-of-basis
    matrix from the enclosing scope — confirm at the definition site.
    """
    blocks = [o3.irr_repr(l, *angles) for l in (1, 1, 2)]
    block_diag = o3.direct_sum(*blocks)
    return A @ block_diag @ torch.inverse(A)
def test_irrep_closure_2(self):
    """Summed over l <= 12, D^l of two different rotations are nearly orthogonal.

    The elementwise overlap sum_l <D^l(r1), D^l(r2)> must be at least 100 times
    smaller than the overlap of r1 with itself.
    """
    # Two fixed (not random) rotations, given as Euler angles.
    r1 = (0, 0.2, 0)
    r2 = (0.1, 0.4, 1.5)
    l_max = 12

    def overlap(angles_a, angles_b):
        return sum(
            (o3.irr_repr(l, *angles_a) * o3.irr_repr(l, *angles_b)).sum()
            for l in range(l_max + 1)
        )

    cross = overlap(r1, r2)
    self_overlap = overlap(r1, r1)
    self.assertLess(cross, self_overlap / 100)
def test_reduce_tensor_product(self):
    """rs.tensor_product returns an equivariant, orthogonal change of basis Q."""
    cases = [
        ([(1, 0)], [(2, 0)]),
        ([(3, 1), (2, 2)], [(2, 0), (1, 1), (1, 3)]),
    ]
    for Rs_i, Rs_j in cases:
        with o3.torch_default_dtype(torch.float64):
            Rs, Q = rs.tensor_product(Rs_i, Rs_j)

            abc = torch.rand(3, dtype=torch.float64)
            D_i = o3.direct_sum(*[
                o3.irr_repr(l, *abc) for mul, l in Rs_i for _ in range(mul)
            ])
            D_j = o3.direct_sum(*[
                o3.irr_repr(l, *abc) for mul, l in Rs_j for _ in range(mul)
            ])
            D = o3.direct_sum(*[
                o3.irr_repr(l, *abc) for mul, l, _ in Rs for _ in range(mul)
            ])

            # Equivariance: D^T applied on the output index of Q must match
            # D_i and D_j applied on its two input indices (relative RMS error).
            lhs = torch.einsum("ijk,il->ljk", (Q, D))
            rhs = torch.einsum("li,mj,kij->klm", (D_i, D_j, Q))
            rel_err = (lhs - rhs).pow(2).mean().sqrt() / lhs.pow(2).mean().sqrt()
            self.assertLess(rel_err, 1e-10)

            # Orthogonality: Q viewed as a square matrix satisfies M M^T = M^T M = I.
            n = Q.size(0)
            M = Q.view(n, n)
            eye = torch.eye(n, dtype=M.dtype)
            self.assertLess(((M @ M.t()) - eye).pow(2).mean().sqrt(), 1e-10)
            self.assertLess(((M.t() @ M) - eye).pow(2).mean().sqrt(), 1e-10)
def rep(Rs, alpha, beta, gamma, parity=None):
    """
    Representation of O(3).

    Builds the block-diagonal matrix that acts on features of type ``Rs`` for
    the rotation with Euler angles (alpha, beta, gamma). Each ``(mul, l, p)``
    entry of ``simplify(Rs)`` contributes ``mul`` copies of
    ``o3.irr_repr(l, alpha, beta, gamma)``.

    :param Rs: representation; it is passed through ``simplify`` first
    :param alpha, beta, gamma: Euler angles of the rotation
    :param parity: ``None`` for a pure rotation. Otherwise the number of times
        parity is applied: every block is multiplied by ``p ** parity``, so
        every ``p`` in ``Rs`` must be nonzero (checked by an assert).
    """
    abc = [alpha, beta, gamma]
    # Simplify once and reuse, instead of re-simplifying for the check and the build.
    Rs = simplify(Rs)
    if parity is None:
        return o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l, _ in Rs for _ in range(mul)])
    # Each entry must carry a definite parity. The loop variable is named `p`
    # so it does not shadow the `parity` argument.
    assert all(p != 0 for _, _, p in Rs)
    return o3.direct_sum(*[(p ** parity) * o3.irr_repr(l, *abc) for mul, l, p in Rs for _ in range(mul)])
def test_reduce_tensor_antisymmetric_L2():
    """Fully antisymmetric rank-3 tensor of l=2 inputs: equivariance and antisymmetry."""
    Rs, Q = rs.reduce_tensor('ijk=-ikj=-jik', i=[(1, 2)])
    assert Rs[0] == (1, 1, 0)
    # First output block (3 rows), each row reshaped into a 5x5x5 tensor.
    q = Q[:3].reshape(3, 5, 5, 5)

    angles = o3.rand_angles()
    D1 = o3.irr_repr(1, *angles)
    D2 = o3.irr_repr(2, *angles)
    lhs = torch.einsum('il,jm,kn,zijk->zlmn', D2, D2, D2, q)
    rhs = torch.einsum('yz,zijk->yijk', D1, q)
    assert (lhs - rhs).abs().max() < 1e-10

    # The tensor must change sign under every pairwise index swap.
    for dim_a, dim_b in [(1, 2), (1, 3), (3, 2)]:
        assert (q + q.transpose(dim_a, dim_b)).abs().max() < 1e-10
def test_sh_is_in_irrep():
    """Rescaled spherical harmonics equal column ``l`` of the Wigner D matrix D^l(a, b, 0)."""
    with o3.torch_default_dtype(torch.float64):
        for l in range(5):
            # The identity only holds for beta in [0, pi], so both angles are drawn from [0, 3.14).
            alpha, beta = 3.14 * torch.rand(2)
            Y = rsh.spherical_harmonics_alpha_beta([l], alpha, beta) * math.sqrt(4 * math.pi) / math.sqrt(2 * l + 1) * (-1) ** l
            D = o3.irr_repr(l, alpha, beta, 0)
            assert (Y - D[:, l]).norm() < 1e-10
def __init__(self, Rs, act, n):
    '''
    Pointwise nonlinearity on a signal over SO(3).

    The features are treated as a signal on SO(3), i.e. as coefficients of the
    regular representation. The signal is sampled on ``n`` random rotations,
    ``act`` is applied pointwise, and the result is projected back.

    :param Rs: input representation. After simplification it must list
        l = 0, 1, 2, ... in order, each with multiplicity ``mul0 * (2l + 1)``
        (``mul0`` being the multiplicity of l = 0) and parity 0
    :param act: activation function applied pointwise
    :param n: number of random rotations sampled (more gives a more accurate projection)
    '''
    super().__init__()

    Rs = rs.simplify(Rs)
    mul0, _, _ = Rs[0]
    assert all(mul0 * (2 * l + 1) == mul for mul, l, _ in Rs)
    assert [l for _, l, _ in Rs] == list(range(len(Rs)))
    assert all(p == 0 for _, _, p in Rs)
    self.Rs_out = Rs

    rotations = [o3.rand_rot() for _ in range(n)]

    def sample_row(R):
        # Flattened, sqrt(2l+1)-weighted D^l(R) for every l, concatenated.
        abc = o3.rot_to_abc(R)
        return torch.cat([
            o3.irr_repr(l, *abc).flatten() * (2 * l + 1)**0.5
            for l in range(len(Rs))
        ])

    Z = torch.stack([sample_row(R) for R in rotations])  # [z, lmn]
    Z.div_(Z.shape[1]**0.5)
    self.register_buffer('Z', Z)
    self.act = act
def rsh_surface(l, m, scale, tr, rot):
    """Plotly surface for component ``l + m`` of the degree-``l`` harmonics, rotated by ``rot``.

    The radius in each direction is ``scale * |f|``, the surface is translated
    by ``tr``, and it is colored by the signed value of ``f``.
    """
    n = 50
    alpha, beta = torch.meshgrid(
        torch.linspace(0, 2 * math.pi, 2 * n),
        torch.linspace(0, math.pi, n),
    )

    values = rsh.spherical_harmonics_alpha_beta([l], alpha, beta)
    values = torch.einsum('ij,...j->...i', o3.irr_repr(l, *rot), values)
    values = values[..., l + m]  # presumably m ranges over [-l, l]

    ux, uy, uz = o3.angles_to_xyz(alpha, beta)
    radius = scale * values.abs()

    max_value = 0.5
    colorscale = [[0, 'rgb(0,50,255)'], [0.5, 'rgb(200,200,200)'], [1, 'rgb(255,50,0)']]
    return go.Surface(
        x=(radius * ux + tr[0]).numpy(),
        y=(radius * uy + tr[1]).numpy(),
        z=(radius * uz + tr[2]).numpy(),
        surfacecolor=values.numpy(),
        showscale=False,
        cmin=-max_value,
        cmax=max_value,
        colorscale=colorscale,
    )
def test_reduce_tensor_Levi_Civita_symbol():
    """The fully antisymmetric rank-3 tensor of l=1 inputs is a single rotation invariant."""
    Rs, Q = rs.reduce_tensor('ijk=-ikj=-jik', i=[(1, 1)])
    assert Rs == [(1, 0, 0)]

    angles = o3.rand_angles()
    D = o3.irr_repr(1, *angles)
    eps = Q.reshape(3, 3, 3)
    # Explicit output '->lmn' matches einsum's implicit (alphabetical) output order.
    rotated = torch.einsum('li,mj,nk,ijk->lmn', D, D, D, eps)
    assert (rotated - eps).abs().max() < 1e-10
def test_xyz_to_irreducible_basis(self):
    """A^T D^1(a, b, c) A reproduces the Cartesian rotation o3.rot(a, b, c) (allclose)."""
    with o3.torch_default_dtype(torch.float64):
        A = o3.xyz_to_irreducible_basis()
        alpha, beta, gamma = torch.rand(3)
        from_irrep = A.t() @ o3.irr_repr(1, alpha, beta, gamma) @ A
        cartesian = o3.rot(alpha, beta, gamma)
        assert torch.allclose(from_irrep, cartesian)
def test_xyz_vector_basis_to_spherical_basis(self):
    """A^T D^1(a, b, c) A matches o3.rot(a, b, c) to 1e-10 (max absolute error)."""
    with o3.torch_default_dtype(torch.float64):
        A = o3.xyz_vector_basis_to_spherical_basis()
        alpha, beta, gamma = torch.rand(3)
        from_irrep = A.t() @ o3.irr_repr(1, alpha, beta, gamma) @ A
        cartesian = o3.rot(alpha, beta, gamma)
        self.assertLess((from_irrep - cartesian).abs().max(), 1e-10)
def test_irrep_closure_1(self):
    """Monte-Carlo check of Schur orthogonality for D^l with l <= 3.

    Averaged over 10000 random rotations, D^{l1}_{ij} D^{l2}_{kl} should be
    delta_{l1 l2} delta_{ik} delta_{jl} / (2 l1 + 1), within an absolute
    tolerance of 0.1.
    """
    samples = [o3.rand_angles() for _ in range(10000)]
    Ds = [torch.stack([o3.irr_repr(l, *abc) for abc in samples]) for l in range(4)]
    for l1, D1 in enumerate(Ds):
        d1 = 2 * l1 + 1
        for l2, D2 in enumerate(Ds):
            d2 = 2 * l2 + 1
            avg = torch.einsum('zij,zkl->zijkl', D1, D2).mean(0).reshape(d1**2, d2**2)
            if l1 != l2:
                self.assertLess(avg.abs().max(), 0.1)
            else:
                target = torch.eye(d1**2)
                self.assertLess((avg.mul(d1) - target).abs().max(), 0.1)
def test2(self):
    """Rotation equivariance of Kernel: D_out K(r) == K(R r) D_in."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]
        kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel)

        r = torch.randn(3)
        angles = torch.randn(3)
        D_in = o3.direct_sum(*[
            o3.irr_repr(l, *angles) for mul, l in Rs_in for _ in range(mul)
        ])
        D_out = o3.direct_sum(*[
            o3.irr_repr(l, *angles) for mul, l in Rs_out for _ in range(mul)
        ])

        W_rotate_output = D_out @ kernel(r)  # [i, j]
        W_rotate_input = kernel(o3.rot(*angles) @ r) @ D_in  # [i, j]
        self.assertLess((W_rotate_output - W_rotate_input).norm(), 1e-4 * W_rotate_output.norm())
def test6(self):
    """Parity and rotation equivariance of a Convolution followed by GatedBlockParity.

    The spatial transformation is ``-o3.rot(...)`` (rotation composed with
    inversion), so each block of type (mul, l, p) is multiplied by p.
    """
    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]
        K = partial(Kernel, RadialModel=ConstantRadialModel)

        scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu), (mul, absolute)]
        rs_nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]
        # Even + odd gates together: 2 * n_gates == total multiplicity of rs_nonscalars.
        n_gates = 3 * mul
        gates = [(n_gates, 0, +1), (n_gates, 0, -1)], [(n_gates, sigmoid), (n_gates, tanh)]

        act = GatedBlockParity(*scalars, *gates, rs_nonscalars)
        conv = Convolution(K, Rs_in, act.Rs_in)

        abc = torch.randn(3)
        rot_geo = -o3.rot(*abc)

        def parity_rep(Rs):
            return o3.direct_sum(*[
                p * o3.irr_repr(l, *abc) for m, l, p in Rs for _ in range(m)
            ])

        D_in = parity_rep(Rs_in)
        D_out = parity_rep(act.Rs_out)

        fea = torch.randn(1, 4, sum(m * (2 * l + 1) for m, l, _ in Rs_in))
        geo = torch.randn(1, 4, 3)

        x1 = torch.einsum("ij,zaj->zai", (D_out, act(conv(fea, geo))))
        x2 = act(conv(
            torch.einsum("ij,zaj->zai", (D_in, fea)),
            torch.einsum("ij,zaj->zai", rot_geo, geo),
        ))
        self.assertLess((x1 - x2).norm(), 1e-4 * x1.norm())
def test_irr_repr_wigner_3j(self):
    """Equivariance of the kernel W(r) = Q Y^{l_f}(r) built from o3.wigner_3j.

    Checks W(R r) D_in == D_out W(r) for every l_f allowed by l_in=3, l_out=2.
    """
    with torch_default_dtype(torch.float64):
        l_in, l_out = 3, 2
        for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
            r = torch.randn(100, 3)
            Q = o3.wigner_3j(l_out, l_in, l_f)
            abc = torch.randn(3)
            D_in = o3.irr_repr(l_in, *abc)
            D_out = o3.irr_repr(l_out, *abc)

            def kernel(points):
                # Called within the same iteration, so Q and l_f are the current ones.
                Y = rsh.spherical_harmonics_xyz([l_f], points)
                return torch.einsum("ijk,zk->zij", (Q, Y))

            W_rot = kernel(r @ o3.rot(*abc).t())
            W1 = torch.einsum("zij,jk->zik", (W_rot, D_in))
            W = kernel(r)
            W2 = torch.einsum("ij,zjk->zik", (D_out, W))
            self.assertLess((W1 - W2).norm(), 1e-5 * W.norm(), l_f)
def test_derivative_irr_repr(self):
    """o3.derivative_irr_repr matches forward finite differences of irr_repr (l=4)."""
    with o3.torch_default_dtype(torch.float64):
        l = 4
        angles = o3.rand_angles()
        da, db, dc = o3.derivative_irr_repr(l, *angles)
        h = 1e-7
        base = o3.irr_repr(l, *angles)

        numeric = []
        for k in range(3):
            # Perturb one Euler angle at a time.
            shifted = list(angles)
            shifted[k] = angles[k] + h
            numeric.append((o3.irr_repr(l, *shifted) - base) / h)

        for approx, exact in zip(numeric, (da, db, dc)):
            self.assertLess((approx - exact).abs().max(), 1e-5)
def test_sh_equivariance():
    """
    Check that irr_repr, compose and spherical_harmonics are mutually consistent:

        Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x)

    where x = Z(a) Y(b) eta, for l = 0..6.
    """
    for l in range(7):
        with o3.torch_default_dtype(torch.float64):
            a, b = torch.rand(2)
            alpha, beta, gamma = torch.rand(3)
            rotated_a, rotated_b, _ = o3.compose(alpha, beta, gamma, a, b, 0)

            Y_of_rotated = rsh.spherical_harmonics_alpha_beta([l], rotated_a, rotated_b)
            Y = rsh.spherical_harmonics_alpha_beta([l], a, b)
            D_times_Y = o3.irr_repr(l, alpha, beta, gamma) @ Y

            assert (Y_of_rotated - D_times_Y).abs().max() < 1e-10 * Y.abs().max()
def test_spherical_harmonics(self):
    """
    Check that irr_repr, compose and spherical_harmonics are mutually consistent:

        Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x)

    where x = Z(a) Y(b) eta, for orders 0..6.
    """
    for order in range(7):
        with o3.torch_default_dtype(torch.float64):
            a, b = torch.rand(2)
            alpha, beta, gamma = torch.rand(3)
            rotated_a, rotated_b, _ = o3.compose(alpha, beta, gamma, a, b, 0)

            Y_of_rotated = o3.spherical_harmonics(order, rotated_a, rotated_b)
            Y = o3.spherical_harmonics(order, a, b)
            D_times_Y = o3.irr_repr(order, alpha, beta, gamma) @ Y

            self.assertLess((Y_of_rotated - D_times_Y).abs().max(), 1e-10 * Y.abs().max())
# Rotation round trip for ElasticTensor: build a random tensor, rotate its
# covariant form with D^1 on every index, convert it to the spherical form,
# then apply the transposed Wigner matrices block by block.
# NOTE(review): any comparison of et2 against et1 is not visible in this fragment.

# Random 6x6 matrix interpreted in Voigt notation.
data = np.random.randn(36).reshape((6,6))
et1 = ElasticTensor(data, 'voigt')
# back and forth transformation to cartesian ensures that voigt is symmetric
_ = et1.voigt_to_cartesian()
_ = et1.cartesian_to_voigt()
_ = et1.voigt_to_cartesian()
_ = et1.cartesian_to_covariant()
_ = et1.covariant_to_spherical()

# Second tensor built from a copy of et1's covariant form, so rotating it leaves et1 untouched.
et2 = ElasticTensor(et1.covariant.copy(), 'covariant')

# Wigner matrices for l = 0, 1, 2, 4 at one random set of angles, as numpy arrays.
a, b, c = rand_angles()
D0 = irr_repr(0, a, b, c).numpy()
D1 = irr_repr(1, a, b, c).numpy()
D2 = irr_repr(2, a, b, c).numpy()
D4 = irr_repr(4, a, b, c).numpy()

# forward rotation
# D1 is applied to each of the four covariant indices.
et2.covariant = np.einsum('ijkn,ai,bj,ck,dn->abcd', et2.covariant, D1, D1, D1, D1)
# Presumably refreshes et2.spherical from the rotated covariant tensor — confirm in ElasticTensor.
_ = et2.covariant_to_spherical()

# inverse rotation
# From the slices, the spherical vector has blocks of sizes 1, 5, 1, 5, 9
# (l = 0, 2, 0, 2, 4). Each block is multiplied by D^T, which inverts D
# presumably because these Wigner matrices are orthogonal — confirm.
et2.spherical[0:1] = D0.T @ et2.spherical[0:1]
et2.spherical[1:6] = D2.T @ et2.spherical[1:6]
et2.spherical[6:7] = D0.T @ et2.spherical[6:7]
et2.spherical[7:12] = D2.T @ et2.spherical[7:12]
et2.spherical[12:21] = D4.T @ et2.spherical[12:21]