def test3(self):
    """Check rotation equivariance of GatedBlock built on Convolution/Kernel.

    Applying the block and then rotating the output must match rotating the
    input features and point positions first and then applying the block.
    """
    with torch_default_dtype(torch.float64):
        rs_in = [(2, 0), (0, 1), (2, 2)]
        rs_out = [(2, 0), (2, 1), (2, 2)]

        kernel = partial(Kernel, RadialModel=ConstantRadialModel)
        conv = partial(Convolution, kernel)
        block = GatedBlock(partial(conv, rs_in), rs_out,
                           scalar_activation=sigmoid, gate_activation=sigmoid)

        angles = torch.randn(3)
        rotation = rot(*angles)

        def block_diag_repr(irreps):
            # one irreducible block per copy of each (mul, l)
            return direct_sum(*[irr_repr(l, *angles) for mul, l in irreps for _ in range(mul)])

        d_in = block_diag_repr(rs_in)
        d_out = block_diag_repr(rs_out)

        features = torch.randn(1, 4, sum(mul * (2 * l + 1) for mul, l in rs_in))
        positions = torch.randn(1, 4, 3)

        apply_then_rotate = torch.einsum("ij,zaj->zai", d_out, block(features, positions))
        rotate_then_apply = block(
            torch.einsum("ij,zaj->zai", d_in, features),
            torch.einsum("ij,zaj->zai", rotation, positions),
        )
        self.assertLess((apply_then_rotate - rotate_then_apply).norm(),
                        10e-5 * apply_then_rotate.norm())
def __clebsch_gordan(l1, l2, l3, _version=4):
    """
    Computes the Clebsch–Gordan coefficients Q for the triple (l1, l2, l3).

    Q is the (unique up to sign) tensor with
    D(l1)_il D(l2)_jm D(l3)_kn Q_lmn == Q_ijk.
    The sign is fixed so that the first non-negligible entry (in flattened
    order) is positive.

    :param _version: not read in this body (presumably a cache-busting key) — TODO confirm
    :return: float64 tensor of shape [2*l1+1, 2*l2+1, 2*l3+1]
    """
    # triangle condition; the three forms are equivalent
    assert abs(l2 - l3) <= l1 <= l2 + l3
    assert abs(l3 - l1) <= l2 <= l3 + l1
    assert abs(l1 - l2) <= l3 <= l1 + l2

    with torch_default_dtype(torch.float64):
        solutions = _get_d_null_space(l1, l2, l3)
        # the solution subspace must be one-dimensional
        assert solutions.size(0) == 1, solutions.size()
        Q = solutions[0].view(2 * l1 + 1, 2 * l2 + 1, 2 * l3 + 1)

        threshold = 1e-10 * Q.abs().max()
        first_significant = next(v for v in Q.flatten() if v.abs() > threshold)
        if first_significant < 0:
            Q.neg_()

        # sanity check: Q is invariant under a random rotation
        angles = torch.rand(3)
        D1, D2, D3 = (irr_repr(l, *angles) for l in (l1, l2, l3))
        Q_rotated = torch.einsum("il,jm,kn,lmn", (D1, D2, D3, Q))
        assert torch.allclose(Q, Q_rotated)

    assert Q.dtype == torch.float64
    return Q  # [m1, m2, m3]
def test5(self):
    """Check that GatedBlockParity (with its Convolution/Kernel) commutes with parity."""
    with torch_default_dtype(torch.float64):
        mul = 2
        rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

        conv = partial(Convolution, partial(Kernel, RadialModel=ConstantRadialModel))

        scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu), (mul, absolute)]
        nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]
        n_gates = 3 * mul
        gates = [(n_gates, 0, +1), (n_gates, 0, -1)], [(n_gates, sigmoid), (n_gates, tanh)]

        block = GatedBlockParity(conv, rs_in, *scalars, *gates, nonscalars)

        def parity_matrix(irreps):
            # inversion acts as the scalar p on each (2l+1)-dimensional block
            return direct_sum(*[p * torch.eye(2 * l + 1) for m, l, p in irreps for _ in range(m)])

        d_in = parity_matrix(rs_in)
        d_out = parity_matrix(block.Rs_out)

        features = torch.randn(1, 4, sum(m * (2 * l + 1) for m, l, p in rs_in))
        positions = torch.randn(1, 4, 3)

        expected = torch.einsum("ij,zaj->zai", d_out, block(features, positions))
        actual = block(torch.einsum("ij,zaj->zai", d_in, features), -positions)
        self.assertLess((expected - actual).norm(), 10e-5 * expected.norm())
def parity_rotation_gated_block_parity(self, K):
    """Check that Convolution(K) + GatedBlockParity is equivariant under rotation combined with inversion."""
    with torch_default_dtype(torch.float64):
        mul = 2
        rs_in = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]

        kernel = partial(K, RadialModel=ConstantRadialModel)

        scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu), (mul, absolute)]
        nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]
        n_gates = 3 * mul
        gates = [(n_gates, 0, +1), (n_gates, 0, -1)], [(n_gates, sigmoid), (n_gates, tanh)]

        gated = GatedBlockParity(*scalars, *gates, nonscalars)
        conv = Convolution(kernel(rs_in, gated.Rs_in))

        angles = torch.randn(3)
        # rotation followed by inversion of the positions
        improper_rotation = -o3.rot(*angles)
        d_in = rs.rep(rs_in, *angles, 1)
        d_out = rs.rep(gated.Rs_out, *angles, 1)

        features = torch.randn(1, 4, rs.dim(rs_in))
        positions = torch.randn(1, 4, 3)

        expected = torch.einsum("ij,zaj->zai", d_out, gated(conv(features, positions)))
        actual = gated(conv(
            torch.einsum("ij,zaj->zai", d_in, features),
            torch.einsum("ij,zaj->zai", improper_rotation, positions),
        ))
        self.assertLess((expected - actual).norm(), 10e-5 * expected.norm())
def xyz3x3_to_irreducible_basis(check=True):
    """
    Change of basis splitting a 3x3 tensor into its irreducible parts.

    A 3x3 tensor, flattened row-major to 9 entries and transforming with
    ``xyz3x3_repr(a, b, c)``, is mapped to its 1 + 3 + 5 components, which
    transform with ``irr_repr(0, a, b, c)``, ``irr_repr(1, a, b, c)`` and
    ``irr_repr(2, a, b, c)`` respectively.

    :param check: if True (default), verify ``irr_repr(l) @ to == to @ xyz3x3_repr``
        for each part on 10 random angle triples
    :return: (to1, to3, to5) of shapes [1, 9], [3, 9] and [5, 9], cast to the
        caller's default dtype
    """
    with torch_default_dtype(torch.float64):
        # l = 0: trace
        to1 = torch.tensor([
            [1, 0, 0, 0, 1, 0, 0, 0, 1],
        ], dtype=torch.get_default_dtype())
        if check:
            assert all(torch.allclose(irr_repr(0, a, b, c) @ to1, to1 @ xyz3x3_repr(a, b, c))
                       for a, b, c in torch.rand(10, 3))

        # l = 1: antisymmetric part
        to3 = torch.tensor([
            [0, 0, -1, 0, 0, 0, 1, 0, 0],
            [0, 1, 0, -1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, -1, 0],
        ], dtype=torch.get_default_dtype())
        if check:
            assert all(torch.allclose(irr_repr(1, a, b, c) @ to3, to3 @ xyz3x3_repr(a, b, c))
                       for a, b, c in torch.rand(10, 3))

        # l = 2: symmetric traceless part
        to5 = torch.tensor([
            [0, 1, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 1, 0],
            [-3**.5 / 3, 0, 0, 0, -3**.5 / 3, 0, 0, 0, 12**.5 / 3],
            [0, 0, 1, 0, 0, 0, 1, 0, 0],
            [1, 0, 0, 0, -1, 0, 0, 0, 0],
        ], dtype=torch.get_default_dtype())
        if check:
            assert all(torch.allclose(irr_repr(2, a, b, c) @ to5, to5 @ xyz3x3_repr(a, b, c))
                       for a, b, c in torch.rand(10, 3))

    # outside the float64 context: cast to the caller's default dtype
    dtype = torch.get_default_dtype()
    return to1.type(dtype), to3.type(dtype), to5.type(dtype)
def test1(self):
    """Check rotation equivariance of Convolution followed by GatedBlock.

    Both sides of the comparison apply the gated block along ``dim=2`` (the
    channel axis of the [batch, points, channels] features), so the only
    difference between them is the rotation of the inputs.
    """
    with torch_default_dtype(torch.float64):
        Rs_in = [(3, 0), (3, 1), (2, 0), (1, 2)]
        Rs_out = [(3, 0), (3, 1), (1, 2), (3, 0)]

        f = GatedBlock(Rs_out, rescaled_act.Softplus(beta=5), rescaled_act.sigmoid)
        c = Convolution(Kernel(Rs_in, f.Rs_in, ConstantRadialModel))

        abc = torch.randn(3)
        D_in = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
        D_out = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)])

        x = torch.randn(1, 5, sum(mul * (2 * l + 1) for mul, l in Rs_in))
        geo = torch.randn(1, 5, 3)

        rx = torch.einsum("ij,zaj->zai", (D_in, x))
        rgeo = geo @ o3.rot(*abc).t()

        y = f(c(x, geo), dim=2)
        ry = torch.einsum("ij,zaj->zai", (D_out, y))
        # same explicit axis as above rather than relying on f's default `dim`
        y_from_rotated = f(c(rx, rgeo), dim=2)

        self.assertLess((y_from_rotated - ry).norm(), 1e-10 * ry.norm())
def rotation_gated_block(self, K):
    """Check that Convolution(K) followed by GatedBlock commutes with rotations."""
    with torch_default_dtype(torch.float64):
        rs_in = [(2, 0), (0, 1), (2, 2)]
        rs_out = [(2, 0), (2, 1), (2, 2)]
        kernel = partial(K, RadialModel=ConstantRadialModel)

        gated = GatedBlock(rs_out, scalar_activation=sigmoid, gate_activation=sigmoid)
        conv = Convolution(kernel(rs_in, gated.Rs_in))

        angles = torch.randn(3)
        rotation = o3.rot(*angles)
        d_in = rs.rep(rs_in, *angles)
        d_out = rs.rep(rs_out, *angles)

        features = torch.randn(1, 4, rs.dim(rs_in))
        positions = torch.randn(1, 4, 3)

        expected = torch.einsum("ij,zaj->zai", d_out, gated(conv(features, positions)))
        actual = gated(conv(
            torch.einsum("ij,zaj->zai", d_in, features),
            torch.einsum("ij,zaj->zai", rotation, positions),
        ))
        self.assertLess((expected - actual).norm(), 10e-5 * expected.norm())
def xyz_to_irreducible_basis(check=True):
    """
    Matrix A converting a vector [x, y, z] that transforms with rot(a, b, c)
    into a vector that transforms with irr_repr(1, a, b, c).

    A is the permutation (x, y, z) -> (y, z, x).

    :param check: if True, verify irr_repr(1) @ A == A @ rot on 10 random angle triples
    :return: 3x3 tensor cast to the caller's default dtype
    """
    with torch_default_dtype(torch.float64):
        A = torch.tensor([
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 0],
        ], dtype=torch.float64)
        if check:
            assert all(
                torch.allclose(irr_repr(1, alpha, beta, gamma) @ A, A @ rot(alpha, beta, gamma))
                for alpha, beta, gamma in torch.rand(10, 3)
            )
    return A.type(torch.get_default_dtype())
def spherical_basis_vector_to_xyz_basis(check=True):
    """
    Matrix A converting a vector that transforms with irr_repr(1, a, b, c)
    into a vector [x, y, z] that transforms with rot(a, b, c).

    A is the permutation (v0, v1, v2) -> (v2, v0, v1); it is the transpose,
    hence the inverse, of the matrix built by xyz_to_irreducible_basis.

    :param check: if True, verify A @ irr_repr(1) == rot @ A on 10 random angle triples
    :return: 3x3 tensor cast to the caller's default dtype
    """
    with torch_default_dtype(torch.float64):
        A = torch.tensor([
            [0, 0, 1],
            [1, 0, 0],
            [0, 1, 0],
        ], dtype=torch.float64)
        if check:
            assert all(
                torch.allclose(A @ irr_repr(1, alpha, beta, gamma), rot(alpha, beta, gamma) @ A)
                for alpha, beta, gamma in torch.rand(10, 3)
            )
    return A.type(torch.get_default_dtype())
def parity_kernel(self, K):
    """Check that the kernel built by K commutes with parity: D_out k(r) == k(-r) D_in."""
    with torch_default_dtype(torch.float64):
        rs_in = [(2, 0, 1), (2, 1, 1), (2, 2, -1)]
        rs_out = [(2, 0, -1), (2, 1, 1), (2, 2, 1)]
        kernel = K(rs_in, rs_out, ConstantRadialModel)
        r = torch.randn(3)

        # zero rotation angles with parity flag 1, matched below by r -> -r
        d_in = rs.rep(rs_in, 0, 0, 0, 1)
        d_out = rs.rep(rs_out, 0, 0, 0, 1)

        lhs = d_out @ kernel(r)  # [i, j]
        rhs = kernel(-r) @ d_in  # [i, j]
        self.assertLess((lhs - rhs).norm(), 10e-5 * lhs.norm())
def reduce_tensor_product(Rs_i, Rs_j):
    """
    Compute the orthonormal change of basis Q from Rs_reduced to Rs_i tensor product with Rs_j,
    where Rs_reduced is a direct sum of irreducible representations.

    For every pair of blocks (mul_i, l_i, p_i) in Rs_i and (mul_j, l_j, p_j) in Rs_j,
    Rs_reduced gets (mul_i * mul_j, l, p_i * p_j) for each l in |l_i - l_j| .. l_i + l_j.

    :return: Rs_reduced, Q  with Q a float64 tensor of shape [dim_i, dim_j, dim_i * dim_j]
    """
    with torch_default_dtype(torch.float64):
        Rs_i = normalizeRs(Rs_i)
        Rs_j = normalizeRs(Rs_j)

        # total dimensions (kept distinct from the per-block sizes used below)
        dim_i = sum(mul * (2 * l + 1) for mul, l, p in Rs_i)
        dim_j = sum(mul * (2 * l + 1) for mul, l, p in Rs_j)

        out = torch.zeros(dim_i, dim_j, dim_i * dim_j, dtype=torch.float64)

        Rs_reduced = []
        beg = 0  # offset along the reduced (last) axis

        beg_i = 0
        for mul_i, l_i, p_i in Rs_i:
            size_i = mul_i * (2 * l_i + 1)
            beg_j = 0
            for mul_j, l_j, p_j in Rs_j:
                size_j = mul_j * (2 * l_j + 1)
                mul = mul_i * mul_j
                for l in range(abs(l_i - l_j), l_i + l_j + 1):
                    Rs_reduced.append((mul, l, p_i * p_j))
                    size = mul * (2 * l + 1)

                    # put sqrt(2l+1) to get an orthonormal output
                    cg = math.sqrt(2 * l + 1) * clebsch_gordan(l_i, l_j, l)  # [m_i, m_j, m]
                    # each pair of multiplicity indices (u, v) maps to its own output copy
                    eye = torch.eye(mul).view(mul_i, mul_j, mul)  # [mul_i, mul_j, mul_i * mul_j]
                    block = torch.einsum("ijk,mno->imjnko", (eye, cg))

                    view = out[beg_i:beg_i + size_i, beg_j:beg_j + size_j, beg:beg + size]
                    view.add_(block.view_as(view))

                    beg += size
                beg_j += size_j
            beg_i += size_i
        return Rs_reduced, out
def rotation_kernel(self, K):
    """Check that the kernel built by K intertwines rotations: D_out k(r) == k(R r) D_in."""
    with torch_default_dtype(torch.float64):
        rs_in = [(2, 0), (0, 1), (2, 2)]
        rs_out = [(2, 0), (2, 1), (2, 2)]
        kernel = K(rs_in, rs_out, ConstantRadialModel)

        r = torch.randn(3)
        angles = torch.randn(3)
        d_in = rs.rep(rs_in, *angles)
        d_out = rs.rep(rs_out, *angles)

        lhs = d_out @ kernel(r)  # [i, j]
        rhs = kernel(o3.rot(*angles) @ r) @ d_in  # [i, j]
        self.assertLess((lhs - rhs).norm(), 10e-5 * lhs.norm())
def test_equivariance_s2network(self):
    """Check rotation equivariance of S2Network on random features."""
    with torch_default_dtype(torch.float64):
        mul = 3
        rs_in = [(mul, l) for l in range(3 + 1)]
        rs_out = [(mul, l) for l in range(3 + 1)]

        net = S2Network(rs_in, mul, lmax=4, Rs_out=rs_out)

        angles = o3.rand_angles()
        d_in = rs.rep(rs_in, *angles)
        d_out = rs.rep(rs_out, *angles)

        features = torch.randn(10, rs.dim(rs_in))

        rotated_output = torch.einsum("ij,zj->zi", d_out, net(features))
        output_of_rotated = net(torch.einsum("ij,zj->zi", d_in, features))
        self.assertLess((rotated_output - output_of_rotated).norm(), 1e-3 * rotated_output.norm())
def test2(self):
    """Check the statistics of Kernel outputs with normalization='norm'.

    With a large input multiplicity, kernel entries should have mean close to
    zero and variance close to 1 / (mul * (2 * l_out + 1)).
    """
    with torch_default_dtype(torch.float64):
        mul = 100000
        for l_in in range(4):
            Rs_in = [(mul, l_in)]
            for l_out in range(4):
                Rs_out = [(1, l_out)]

                k = Kernel(Rs_in, Rs_out, ConstantRadialModel, normalization='norm')
                k = k(torch.randn(1, 3))

                # bound the magnitude: a one-sided check would accept any negative mean
                self.assertLess(abs(k.mean().item()), 1e-3)
                self.assertAlmostEqual(k.var().item() * mul, 1 / (2 * l_out + 1), places=1)
def test2(self):
    """Check that Kernel intertwines rotations: D_out k(r) == k(R r) D_in."""
    with torch_default_dtype(torch.float64):
        rs_in = [(2, 0), (0, 1), (2, 2)]
        rs_out = [(2, 0), (2, 1), (2, 2)]
        kernel = Kernel(rs_in, rs_out, ConstantRadialModel)

        r = torch.randn(3)
        angles = torch.randn(3)

        def block_diag_repr(irreps):
            # one irreducible block per copy of each (mul, l)
            return direct_sum(*[irr_repr(l, *angles) for mul, l in irreps for _ in range(mul)])

        d_in = block_diag_repr(rs_in)
        d_out = block_diag_repr(rs_out)

        lhs = d_out @ kernel(r)  # [i, j]
        rhs = kernel(rot(*angles) @ r) @ d_in  # [i, j]
        self.assertLess((lhs - rhs).norm(), 10e-5 * lhs.norm())
def parity_rotation_linear(self, L):
    """Check that the linear layer built by L commutes with rotation combined with inversion."""
    with torch_default_dtype(torch.float64):
        mul = 2
        rs_in = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]
        rs_out = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]
        layer = L(rs_in, rs_out)

        angles = torch.randn(3)
        d_in = rs.rep(layer.Rs_in, *angles, 1)
        d_out = rs.rep(layer.Rs_out, *angles, 1)

        features = torch.randn(rs.dim(rs_in))

        expected = torch.einsum("ij,j->i", d_out, layer(features))
        actual = layer(torch.einsum("ij,j->i", d_in, features))
        self.assertLess((expected - actual).norm(), 10e-5 * expected.norm())
def __init__(self, lmax, res=None, normalization='component'):
    """
    :param lmax: lmax of the input signal
    :param res: resolution of the output as a tuple (beta resolution, alpha resolution);
        an int is used for both, and None selects beta = 2 * (lmax + 1), alpha = 2 * beta
    :param normalization: either 'norm' or 'component'
    """
    super().__init__()

    assert normalization in ['norm', 'component'], "normalization needs to be 'norm' or 'component'"

    if res is None:
        res_beta = 2 * (lmax + 1)
        res_alpha = 2 * res_beta
    elif isinstance(res, int):
        res_beta = res_alpha = res
    else:
        res_beta, res_alpha = res
    del res

    # the beta grid must be even and have at least 2 * (lmax + 1) points
    assert res_beta % 2 == 0
    assert res_beta >= 2 * (lmax + 1)

    alphas, betas, sha, shb = spherical_harmonics_s2_grid(lmax, res_alpha, res_beta)

    with torch_default_dtype(torch.float64):
        # normalize such that all l has the same variance on the sphere
        if normalization == 'component':
            n = math.sqrt(4 * math.pi) * torch.tensor([
                1 / math.sqrt(2 * l + 1) for l in range(lmax + 1)
            ]) / math.sqrt(lmax + 1)
        elif normalization == 'norm':
            n = math.sqrt(4 * math.pi) * torch.ones(lmax + 1) / math.sqrt(lmax + 1)
        m = rsh.spherical_harmonics_expand_matrix(lmax)  # [l, m, i]
        shb = torch.einsum('lmj,bj,lmi,l->mbi', m, shb, m, n)  # [m, b, i]

    for buffer_name, buffer in [('alphas', alphas), ('betas', betas), ('sha', sha), ('shb', shb)]:
        self.register_buffer(buffer_name, buffer)

    # outside the float64 context: cast floating-point buffers to the caller's default dtype
    self.to(torch.get_default_dtype())
def test_equivariance_gatedconvnetwork(self):
    """Check rotation equivariance of GatedConvNetwork on a random point cloud."""
    with torch_default_dtype(torch.float64):
        mul = 3
        rs_in = [(mul, l) for l in range(3 + 1)]
        rs_out = [(mul, l) for l in range(3 + 1)]
        rs_hidden = [(10, 0), (1, 1), (1, 2), (1, 3)]
        net = GatedConvNetwork(rs_in, rs_hidden, rs_out)

        angles = torch.randn(3)
        rotation = o3.rot(*angles)
        d_in = rs.rep(rs_in, *angles)
        d_out = rs.rep(rs_out, *angles)

        features = torch.randn(1, 10, rs.dim(rs_in))
        positions = torch.randn(1, 10, 3)

        expected = torch.einsum("ij,zaj->zai", d_out, net(features, positions))
        actual = net(torch.einsum("ij,zaj->zai", d_in, features),
                     torch.einsum("ij,zaj->zai", rotation, positions))
        self.assertLess((expected - actual).norm(), 10e-5 * expected.norm())
def test4(self):
    """Check that Kernel commutes with parity: D_out k(r) == k(-r) D_in."""
    with torch_default_dtype(torch.float64):
        rs_in = [(2, 0, 1), (2, 1, 1), (2, 2, -1)]
        rs_out = [(2, 0, -1), (2, 1, 1), (2, 2, 1)]
        kernel = Kernel(rs_in, rs_out, ConstantRadialModel)
        r = torch.randn(3)

        def parity_matrix(irreps):
            # inversion acts as the scalar p on each (2l+1)-dimensional block
            return direct_sum(*[p * torch.eye(2 * l + 1) for mul, l, p in irreps for _ in range(mul)])

        d_in = parity_matrix(rs_in)
        d_out = parity_matrix(rs_out)

        lhs = d_out @ kernel(r)  # [i, j]
        rhs = kernel(-r) @ d_in  # [i, j]
        self.assertLess((lhs - rhs).norm(), 10e-5 * lhs.norm())
def test6(self):
    """Check that Convolution + GatedBlockParity is equivariant under rotation combined with inversion."""
    with torch_default_dtype(torch.float64):
        mul = 2
        rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]
        kernel = partial(Kernel, RadialModel=ConstantRadialModel)

        scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu), (mul, absolute)]
        nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]
        n_gates = 3 * mul
        gates = [(n_gates, 0, +1), (n_gates, 0, -1)], [(n_gates, sigmoid), (n_gates, tanh)]

        gated = GatedBlockParity(*scalars, *gates, nonscalars)
        conv = Convolution(kernel, rs_in, gated.Rs_in)

        angles = torch.randn(3)
        # rotation followed by inversion of the positions
        improper_rotation = -o3.rot(*angles)

        def signed_repr(irreps):
            # rotation block scaled by the parity p of each copy
            return o3.direct_sum(*[p * o3.irr_repr(l, *angles) for m, l, p in irreps for _ in range(m)])

        d_in = signed_repr(rs_in)
        d_out = signed_repr(gated.Rs_out)

        features = torch.randn(1, 4, sum(m * (2 * l + 1) for m, l, p in rs_in))
        positions = torch.randn(1, 4, 3)

        expected = torch.einsum("ij,zaj->zai", d_out, gated(conv(features, positions)))
        actual = gated(conv(
            torch.einsum("ij,zaj->zai", d_in, features),
            torch.einsum("ij,zaj->zai", improper_rotation, positions),
        ))
        self.assertLess((expected - actual).norm(), 10e-5 * expected.norm())
def test_irr_repr_wigner_3j(self):
    """Test irr_repr and wigner_3j equivariance.

    For each allowed l_f, the kernel W(r) = Q . Y(r) built from the Wigner 3j
    symbol must satisfy W(R r) D_in == D_out W(r).
    """
    with torch_default_dtype(torch.float64):
        l_in, l_out = 3, 2
        for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
            r = torch.randn(100, 3)
            Q = o3.wigner_3j(l_out, l_in, l_f)

            abc = torch.randn(3)
            D_in = o3.irr_repr(l_in, *abc)
            D_out = o3.irr_repr(l_out, *abc)

            # kernel evaluated at the rotated points, then acted on by D_in
            Y_rotated = rsh.spherical_harmonics_xyz([l_f], r @ o3.rot(*abc).t())
            W_rotated = torch.einsum("ijk,zk->zij", (Q, Y_rotated))
            W1 = torch.einsum("zij,jk->zik", (W_rotated, D_in))

            # kernel at the original points, then acted on by D_out
            Y = rsh.spherical_harmonics_xyz([l_f], r)
            W = torch.einsum("ijk,zk->zij", (Q, Y))
            W2 = torch.einsum("ij,zjk->zik", (D_out, W))

            self.assertLess((W1 - W2).norm(), 1e-5 * W.norm(), l_f)
def test1(self):
    """Test irr_repr and clebsch_gordan equivariance.

    For each allowed l_f, the kernel W(r) = Q . Y(r) must satisfy
    W(R r) D_in == D_out W(r).
    """
    with torch_default_dtype(torch.float64):
        l_in, l_out = 3, 2
        for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
            r = torch.randn(100, 3)
            Q = clebsch_gordan(l_out, l_in, l_f)

            abc = torch.randn(3)
            D_in = irr_repr(l_in, *abc)
            D_out = irr_repr(l_out, *abc)

            # kernel evaluated at the rotated points, then acted on by D_in
            Y_rotated = spherical_harmonics_xyz(l_f, r @ rot(*abc).t())  # [m, z]
            W_rotated = torch.einsum("ijk,kz->zij", (Q, Y_rotated))
            W1 = torch.einsum("zij,jk->zik", (W_rotated, D_in))

            # kernel at the original points, then acted on by D_out
            Y = spherical_harmonics_xyz(l_f, r)  # [m, z]
            W = torch.einsum("ijk,kz->zij", (Q, Y))
            W2 = torch.einsum("ij,zjk->zik", (D_out, W))

            self.assertLess((W1 - W2).norm(), 1e-5 * W.norm(), l_f)
def test_basis_equivariance(self):
    """Check that every cube basis kernel overlaps > 0.98 with its rotated version."""
    with torch_default_dtype(torch.float64):
        window = partial(gaussian_window, radii=[5], J_max_list=[999], sigma=2)
        basis = cube_basis_kernels(4 * 5, 2, 2, window)

        overlaps = check_basis_equivariance(basis, 2, 2, *torch.rand(3))
        self.assertTrue(overlaps.gt(0.98).all(), overlaps)
def spherical_harmonics_xyz(order, xyz, sph_last=False, dtype=None, device=None):
    """
    spherical harmonics

    :param order: int or list of ints; the degrees l, in the order their
        (2l+1) rows appear in the output
    :param xyz: tensor (or array-like) of shape [..., 3]
    :param sph_last: return the spherical harmonics in the last channel
    :param dtype: output dtype (default: dtype of xyz if it is a tensor, else the default dtype)
    :param device: output device, torch.device or string (default: device of xyz)
    :return: tensor of shape [m, ...] (or [..., m] if sph_last)
    """
    try:
        order = list(order)
    except TypeError:
        order = [order]

    if dtype is None and torch.is_tensor(xyz):
        dtype = xyz.dtype
    if dtype is None:
        dtype = torch.get_default_dtype()

    if not torch.is_tensor(xyz):
        xyz = torch.tensor(xyz, dtype=torch.float64)
    # resolve the output device only once xyz is a tensor, so that a non-tensor
    # xyz with device=None (or a device given as a string) is handled
    if device is None:
        device = xyz.device
    device = torch.device(device)

    with torch_default_dtype(torch.float64):
        # the CUDA kernel works on xyz's own device, so select it from there
        if xyz.device.type == 'cuda' and max(order) <= 10:
            out = _spherical_harmonics_xyz_cuda(order, xyz)  # [m, ...]
        else:
            alpha, beta = xyz_to_angles(xyz)  # two tensors of shape [...]
            out = spherical_harmonics(order, alpha, beta)  # [m, ...]

        # fix values when xyz = 0
        val = xyz.new_tensor([1 / math.sqrt(4 * math.pi)])
        val = torch.cat([val if l == 0 else xyz.new_zeros(2 * l + 1) for l in order])  # [m]
        out[:, xyz.norm(2, -1) == 0] = val.view(-1, 1)

        out = out.to(dtype=dtype, device=device)
        if sph_last:
            rank = len(out.shape)
            return out.permute(*range(1, rank), 0).contiguous()
        return out


def _spherical_harmonics_xyz_cuda(order, xyz):
    """Evaluate the requested degrees with the CUDA kernel; returns [m, ...] with rows ordered as in `order`."""
    max_l = max(order)
    batch_shape = xyz.shape[:-1]
    xyz_flat = xyz.reshape(-1, 3)  # the kernel expects a flat [batch_size, 3] input

    out = xyz_flat.new_empty(((max_l + 1) * (max_l + 1), xyz_flat.size(0)))  # [filters, batch_size]
    xyz_unit = torch.nn.functional.normalize(xyz_flat, p=2, dim=-1)
    real_spherical_harmonics.rsh(out, xyz_unit)

    # (-1)^L same as (pi-theta) -> (-1)^(L+m) and 'quantum' norm (-1)^m combined:
    # every row of degree l is multiplied by (-1)^l
    norm_coef = torch.tensor(
        [(-1.) ** l for l in range(max_l + 1) for _ in range(2 * l + 1)],
        device=out.device,
    ).unsqueeze(1)
    out.mul_(norm_coef)

    if order != list(range(max_l + 1)):
        # gather the requested degrees in the requested order (same layout as the CPU path)
        rows = torch.cat([torch.arange(l * l, (l + 1) * (l + 1)) for l in order])
        out = out[rows.to(out.device)]

    return out.reshape(-1, *batch_shape)
def reduce_tensor(formula, eps=1e-9, has_parity=None, **kw_Rs):
    """
    Decompose a tensor with index symmetries into irreducible representations.

    ``formula`` lists index orderings separated by '='; a leading '-' marks an
    ordering under which the tensor changes sign (e.g. 'ij=-ji').
    ``kw_Rs`` maps each index letter to its Rs, or to a callable returning the
    representation matrix of that index given the angles (and the parity when
    ``has_parity``).

    Usage
    Rs, Q = rs.reduce_tensor('ijkl=jikl=ikjl=ijlk', i=[(1, 1)])
    Rs = 0,2,4
    Q = tensor of shape [15, 81]

    :return: (Rs_out, A) with A of shape [dim(Rs_out), product of the index dimensions]
    :raises RuntimeError: on malformed formulas, inconsistent Rs or parity
        specifications, or when the decomposition does not complete
    """
    with torch_default_dtype(torch.float64):
        formulas = [(-1 if f.startswith('-') else 1, f.replace('-', '')) for f in formula.split('=')]
        s0, f0 = formulas[0]
        assert s0 == 1

        for _s, f in formulas:
            if len(set(f)) != len(f) or set(f) != set(f0):
                raise RuntimeError(f'{f} is not a permutation of {f0}')
            if len(f0) != len(f):
                raise RuntimeError(f'{f0} and {f} don\'t have the same number of indices')

        # set of generators (signed permutations of the indices of f0)
        formulas = {(s, tuple(f.index(i) for i in f0)) for s, f in formulas}

        # create the entire group: close under inverses and compositions
        while True:
            n = len(formulas)
            formulas = formulas.union([(s, perm.inverse(p)) for s, p in formulas])
            formulas = formulas.union([
                (s1 * s2, perm.compose(p1, p2))
                for s1, p1 in formulas
                for s2, p2 in formulas
            ])
            if len(formulas) == n:
                break

        for i in kw_Rs:
            if not callable(kw_Rs[i]):
                Rs = convention(kw_Rs[i])
                if has_parity is None:
                    has_parity = any(p != 0 for _, _, p in Rs)
                if not has_parity and not all(p == 0 for _, _, p in Rs):
                    raise RuntimeError(f'{format_Rs(Rs)} parity has to be specified everywhere or nowhere')
                if has_parity and any(p == 0 for _, _, p in Rs):
                    raise RuntimeError(f'{format_Rs(Rs)} parity has to be specified everywhere or nowhere')
                kw_Rs[i] = Rs

        if has_parity is None:
            raise RuntimeError('please specify the argument `has_parity`')

        # indices exchanged by a symmetry must share the same Rs; propagate it
        for _s, p in formulas:
            f = "".join(f0[i] for i in p)
            for i, j in zip(f0, f):
                if i in kw_Rs and j in kw_Rs and kw_Rs[i] != kw_Rs[j]:
                    raise RuntimeError(f'Rs of {i} (Rs={format_Rs(kw_Rs[i])}) and {j} (Rs={format_Rs(kw_Rs[j])}) should be the same')
                if i in kw_Rs:
                    kw_Rs[j] = kw_Rs[i]
                if j in kw_Rs:
                    kw_Rs[i] = kw_Rs[j]

        for i in f0:
            if i not in kw_Rs:
                raise RuntimeError(f'index {i} has not Rs associated to it')

        identity = (0, 0, 0, 0) if has_parity else (0, 0, 0)
        full_base = list(itertools.product(*(
            range(len(kw_Rs[i](*identity)) if callable(kw_Rs[i]) else dim(kw_Rs[i]))
            for i in f0
        )))
        # position of each multi-index in full_base: O(1) lookup instead of list.index
        position = {x: j for j, x in enumerate(full_base)}

        base = set()
        for x in full_base:
            # s * T[x] all equal for (s, x) in xs
            xs = {(s, tuple(x[i] for i in p)) for s, p in formulas}
            # components forced to equal their own negative vanish and are skipped
            if (-1, x) not in xs:
                # the sign is arbitrary, put both possibilities
                base.add(frozenset({
                    frozenset(xs),
                    frozenset({(-s, y) for s, y in xs}),
                }))

        base = sorted([sorted([sorted(xs) for xs in x]) for x in base])  # requested for python 3.7 but not for 3.8 (probably a bug in 3.7)

        d_sym = len(base)
        d = len(full_base)
        Q = torch.zeros(d_sym, d)

        for i, x in enumerate(base):
            # pick the sign convention with the most positive entries
            x = max(x, key=lambda xs: sum(s for s, _ in xs))
            for s, multi_index in x:
                Q[i, position[multi_index]] = s / len(x)**0.5

        assert torch.allclose(Q @ Q.T, torch.eye(d_sym))

        if d_sym == 0:
            return [], torch.zeros(d_sym, d)

        def representation(alpha, beta, gamma, parity=None):
            def rep_of(r):
                if callable(r):
                    if has_parity:
                        return r(alpha, beta, gamma, parity)
                    return r(alpha, beta, gamma)
                return rep(r, alpha, beta, gamma, parity)

            m = o3.kron(*(rep_of(kw_Rs[i]) for i in f0))
            return Q @ m @ Q.T

        assert _is_representation(representation, eps, has_parity)

        Rs_out = []
        A = Q.clone()
        for l in range(int((d_sym - 1) // 2) + 1):
            for p in [-1, 1] if has_parity else [0]:
                if 2 * l + 1 > d_sym - dim(Rs_out):
                    break

                mul, B, representation = o3.reduce(representation, partial(rep, [(1, l, p)]), eps, has_parity)
                A = o3.direct_sum(torch.eye(d_sym - B.shape[0]), B) @ A
                A = _round_sqrt(A, eps)
                Rs_out += [(mul, l, p)]

                if dim(Rs_out) == d_sym:
                    break

        if dim(Rs_out) != d_sym:
            raise RuntimeError('unable to decompose into irreducible representations')

        return simplify(Rs_out), A