def test_reduce_tensor_product(self):
    for Rs_i, Rs_j in [([(1, 0)], [(2, 0)]),
                       ([(3, 1), (2, 2)], [(2, 0), (1, 1), (1, 3)])]:
        with o3.torch_default_dtype(torch.float64):
            Rs, Q = rs.tensor_product(Rs_i, Rs_j)

            abc = torch.rand(3, dtype=torch.float64)

            D_i = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_i for _ in range(mul)])
            D_j = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_j for _ in range(mul)])
            D = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l, _ in Rs for _ in range(mul)])

            # Q must intertwine the representations: contracting Q with the output
            # representation equals contracting it with the two input representations.
            Q1 = torch.einsum("ijk,il->ljk", (Q, D))
            Q2 = torch.einsum("li,mj,kij->klm", (D_i, D_j, Q))

            d = (Q1 - Q2).pow(2).mean().sqrt() / Q1.pow(2).mean().sqrt()
            self.assertLess(d, 1e-10)

            # Seen as a dim(Rs) x (dim(Rs_i) * dim(Rs_j)) matrix, Q must be orthogonal.
            n = Q.size(0)
            M = Q.view(n, n)
            I = torch.eye(n, dtype=M.dtype)

            d = ((M @ M.t()) - I).pow(2).mean().sqrt()
            self.assertLess(d, 1e-10)
            d = ((M.t() @ M) - I).pow(2).mean().sqrt()
            self.assertLess(d, 1e-10)
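# Hedged usage sketch (not part of the test suite): the change of basis Q returned by
# rs.tensor_product has shape [dim(Rs), dim(Rs_i), dim(Rs_j)], as implied by the einsums
# and the orthogonality check in the test above.  Contracting it with two feature vectors
# therefore produces a feature vector of type Rs; the helper name below is illustrative.
def couple_features(Q, x_i, x_j):
    # x_i: [dim(Rs_i)], x_j: [dim(Rs_j)]  ->  coupled features of shape [dim(Rs)]
    return torch.einsum("kij,i,j->k", Q, x_i, x_j)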
def test3(self):
    """Test rotation equivariance on GatedBlock and dependencies."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]

        K = partial(Kernel, RadialModel=ConstantRadialModel)

        act = GatedBlock(Rs_out, scalar_activation=sigmoid, gate_activation=sigmoid)
        conv = Convolution(K, Rs_in, act.Rs_in)

        abc = torch.randn(3)
        rot_geo = o3.rot(*abc)
        D_in = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
        D_out = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)])

        fea = torch.randn(1, 4, sum(mul * (2 * l + 1) for mul, l in Rs_in))
        geo = torch.randn(1, 4, 3)

        x1 = torch.einsum("ij,zaj->zai", (D_out, act(conv(fea, geo))))
        x2 = act(conv(torch.einsum("ij,zaj->zai", (D_in, fea)),
                      torch.einsum("ij,zaj->zai", (rot_geo, geo))))
        self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
def test1(self):
    with torch_default_dtype(torch.float64):
        Rs_in = [(3, 0), (3, 1), (2, 0), (1, 2)]
        Rs_out = [(3, 0), (3, 1), (1, 2), (3, 0)]

        f = GatedBlock(Rs_out, rescaled_act.Softplus(beta=5), rescaled_act.sigmoid)
        c = Convolution(Kernel(Rs_in, f.Rs_in, ConstantRadialModel))

        abc = torch.randn(3)
        D_in = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
        D_out = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)])

        x = torch.randn(1, 5, sum(mul * (2 * l + 1) for mul, l in Rs_in))
        geo = torch.randn(1, 5, 3)

        rx = torch.einsum("ij,zaj->zai", (D_in, x))
        rgeo = geo @ o3.rot(*abc).t()

        y = f(c(x, geo), dim=2)
        ry = torch.einsum("ij,zaj->zai", (D_out, y))
        self.assertLess((f(c(rx, rgeo), dim=2) - ry).norm(), 1e-10 * ry.norm())
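# Hypothetical helper (not part of the library) distilling the rotation-equivariance check
# shared by the tests above: `f` is assumed to be any map (features, geometry) -> features
# taking Rs_in-typed inputs to Rs_out-typed outputs, e.g. a Convolution followed by a
# GatedBlock built under float64, as in the tests.  The point count and tolerance are
# arbitrary choices for this sketch.
def assert_rotation_equivariant(f, Rs_in, Rs_out, n_points=5, tol=1e-5):
    with torch_default_dtype(torch.float64):
        abc = torch.randn(3)  # random Euler angles
        D_in = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
        D_out = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)])

        x = torch.randn(1, n_points, sum(mul * (2 * l + 1) for mul, l in Rs_in))
        geo = torch.randn(1, n_points, 3)

        y1 = torch.einsum("ij,zaj->zai", (D_out, f(x, geo)))                    # rotate the output
        y2 = f(torch.einsum("ij,zaj->zai", (D_in, x)), geo @ o3.rot(*abc).t())  # rotate the input
        assert (y1 - y2).norm() < tol * y1.norm()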
def test5(self):
    """Test parity equivariance on GatedBlockParity and dependencies."""
    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

        K = partial(Kernel, RadialModel=ConstantRadialModel)

        scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu), (mul, absolute)]
        rs_nonscalars = [(mul, 1, +1), (mul, 1, -1), (mul, 2, +1), (mul, 2, -1), (mul, 3, +1), (mul, 3, -1)]
        n = 3 * mul
        gates = [(n, 0, +1), (n, 0, -1)], [(n, sigmoid), (n, tanh)]

        act = GatedBlockParity(*scalars, *gates, rs_nonscalars)
        conv = Convolution(K, Rs_in, act.Rs_in)

        D_in = o3.direct_sum(*[p * torch.eye(2 * l + 1) for mul, l, p in Rs_in for _ in range(mul)])
        D_out = o3.direct_sum(*[p * torch.eye(2 * l + 1) for mul, l, p in act.Rs_out for _ in range(mul)])

        fea = torch.randn(1, 4, sum(mul * (2 * l + 1) for mul, l, p in Rs_in))
        geo = torch.randn(1, 4, 3)

        x1 = torch.einsum("ij,zaj->zai", (D_out, act(conv(fea, geo))))
        x2 = act(conv(torch.einsum("ij,zaj->zai", (D_in, fea)), -geo))
        self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
def rep(Rs, alpha, beta, gamma, parity=None):
    """
    Representation of O(3). Parity applied (-1)**parity times.
    """
    abc = [alpha, beta, gamma]
    if parity is None:
        return o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l, _ in simplify(Rs) for _ in range(mul)])
    else:
        assert all(parity != 0 for _, _, parity in simplify(Rs))
        return o3.direct_sum(*[(p ** parity) * o3.irr_repr(l, *abc) for mul, l, p in simplify(Rs) for _ in range(mul)])
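# Minimal usage sketch for `rep` (illustrative only): the returned matrix is the block-diagonal
# direct sum of one Wigner matrix per irrep copy, so its size is sum(mul * (2 * l + 1)).
# The Rs value and angle values below are arbitrary choices for this example.
def _rep_example():
    Rs = [(2, 0, 0), (1, 1, 0)]   # two scalars and one vector; parity 0 = unspecified
    D = rep(Rs, 0.1, 0.2, 0.3)    # SO(3) part only, since parity is None
    assert D.shape == (2 * 1 + 1 * 3, 2 * 1 + 1 * 3)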
def D(*angles):
    # Representation 1 + 1 + 2 conjugated by `A`, a change-of-basis matrix expected
    # to be defined in the enclosing scope.
    d = o3.direct_sum(
        o3.irr_repr(1, *angles),
        o3.irr_repr(1, *angles),
        o3.irr_repr(2, *angles),
    )
    return A @ d @ torch.inverse(A)
def test2(self):
    """Test rotation equivariance on Kernel."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]

        k = Kernel(Rs_in, Rs_out, ConstantRadialModel)
        r = torch.randn(3)

        abc = torch.randn(3)
        D_in = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
        D_out = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)])

        W1 = D_out @ k(r)  # [i, j]
        W2 = k(o3.rot(*abc) @ r) @ D_in  # [i, j]
        self.assertLess((W1 - W2).norm(), 10e-5 * W1.norm())
def test4(self):
    """Test parity equivariance on Kernel."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0, 1), (2, 1, 1), (2, 2, -1)]
        Rs_out = [(2, 0, -1), (2, 1, 1), (2, 2, 1)]

        k = Kernel(Rs_in, Rs_out, ConstantRadialModel)
        r = torch.randn(3)

        D_in = o3.direct_sum(*[p * torch.eye(2 * l + 1) for mul, l, p in Rs_in for _ in range(mul)])
        D_out = o3.direct_sum(*[p * torch.eye(2 * l + 1) for mul, l, p in Rs_out for _ in range(mul)])

        W1 = D_out @ k(r)  # [i, j]
        W2 = k(-r) @ D_in  # [i, j]
        self.assertLess((W1 - W2).norm(), 10e-5 * W1.norm())
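# Illustrative sketch of the parity action used in test4 and test5 above: under inversion,
# positions flip (r -> -r) and each irrep of parity p is multiplied by p, so the full matrix
# is block-diagonal with blocks p * I_(2l+1).  The Rs chosen here is arbitrary.
def _parity_example():
    Rs = [(1, 0, 1), (1, 1, -1)]   # one even scalar, one odd vector
    D = o3.direct_sum(*[p * torch.eye(2 * l + 1) for mul, l, p in Rs for _ in range(mul)])
    assert torch.allclose(D, torch.diag(torch.tensor([1., -1., -1., -1.])))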
def reduce_tensor(formula, eps=1e-9, has_parity=None, **kw_Rs):
    """
    Usage
    Rs, Q = rs.reduce_tensor('ijkl=jikl=ikjl=ijlk', i=[(1, 1)])
    Rs = 0,2,4
    Q = tensor of shape [15, 81]
    """
    dtype = torch.get_default_dtype()
    with torch_default_dtype(torch.float64):
        # reformat `formulas` and make checks
        formulas = [(-1 if f.startswith('-') else 1, f.replace('-', '')) for f in formula.split('=')]
        s0, f0 = formulas[0]
        assert s0 == 1

        for _s, f in formulas:
            if len(set(f)) != len(f) or set(f) != set(f0):
                raise RuntimeError(f'{f} is not a permutation of {f0}')
            if len(f0) != len(f):
                raise RuntimeError(f'{f0} and {f} don\'t have the same number of indices')

        # `formulas` is a list of (sign, permutation of indices)
        # each formula can be viewed as a permutation of the original formula
        formulas = {(s, tuple(f.index(i) for i in f0)) for s, f in formulas}  # set of generators (permutations)

        # the generators can be composed, for instance if you have ijk=jik=ikj
        # you also have ijk=jki
        # applying all possible compositions creates an entire group
        while True:
            n = len(formulas)
            formulas = formulas.union([(s, perm.inverse(p)) for s, p in formulas])
            formulas = formulas.union([(s1 * s2, perm.compose(p1, p2)) for s1, p1 in formulas for s2, p2 in formulas])
            if len(formulas) == n:
                break  # we break when the set is stable => it is now a group \o/

        # let's clean `kw_Rs` before checking that it is compatible with the formulas
        for i in kw_Rs:
            if not callable(kw_Rs[i]):
                Rs = convention(kw_Rs[i])
                if has_parity is None:
                    has_parity = any(p != 0 for _, _, p in Rs)
                if not has_parity and not all(p == 0 for _, _, p in Rs):
                    raise RuntimeError(f'{format_Rs(Rs)} parity has to be specified everywhere or nowhere')
                if has_parity and any(p == 0 for _, _, p in Rs):
                    raise RuntimeError(f'{format_Rs(Rs)} parity has to be specified everywhere or nowhere')
                kw_Rs[i] = Rs

        if has_parity is None:
            raise RuntimeError('please specify the argument `has_parity`')

        # here we check that each index has one and only one representation
        for _s, p in formulas:
            f = "".join(f0[i] for i in p)
            for i, j in zip(f0, f):
                if i in kw_Rs and j in kw_Rs and kw_Rs[i] != kw_Rs[j]:
                    raise RuntimeError(f'Rs of {i} (Rs={format_Rs(kw_Rs[i])}) and {j} (Rs={format_Rs(kw_Rs[j])}) should be the same')
                if i in kw_Rs:
                    kw_Rs[j] = kw_Rs[i]
                if j in kw_Rs:
                    kw_Rs[i] = kw_Rs[j]

        for i in f0:
            if i not in kw_Rs:
                raise RuntimeError(f'index {i} has no Rs associated to it')

        e = (0, 0, 0, 0) if has_parity else (0, 0, 0)
        dims = {i: len(kw_Rs[i](*e)) if callable(kw_Rs[i]) else dim(kw_Rs[i]) for i in f0}  # dimension of each index
        full_base = list(itertools.product(*(range(dims[i]) for i in f0)))  # (0, 0, 0), (0, 0, 1), (0, 0, 2), ... (3, 3, 3)
        # len(full_base) degrees of freedom in an unconstrained tensor
        # but there are constraints given by the group `formulas`
        # For instance if `ij=-ji`, then 00=-00, 01=-01 and so on

        base = set()
        for x in full_base:
            # T[x] is a coefficient of the tensor T and is related to other coefficients T[y]
            # if x and y are related by a formula
            xs = {(s, tuple(x[i] for i in p)) for s, p in formulas}
            # s * T[x] are all equal for all (s, x) in xs
            # if T[x] = -T[x] it is then equal to 0 and we lose this degree of freedom
            if (-1, x) not in xs:
                # the sign is arbitrary, put both possibilities
                base.add(frozenset({frozenset(xs), frozenset({(-s, x) for s, x in xs})}))

        # len(base) is the number of degrees of freedom in the tensor.
        # Now we want to decompose these degrees of freedom into irreps.
        base = sorted([sorted([sorted(xs) for xs in x]) for x in base])  # required for python 3.7 but not for 3.8 (probably a bug in 3.7)

        # First we compute the change of basis (projection) between full_base and base
        d_sym = len(base)
        d = len(full_base)
        Q = torch.zeros(d_sym, d)
        for i, x in enumerate(base):
            x = max(x, key=lambda xs: sum(s for s, x in xs))
            for s, e in x:
                j = full_base.index(e)
                Q[i, j] = s / len(x)**0.5

        assert torch.allclose(Q @ Q.T, torch.eye(d_sym))

        if d_sym == 0:
            return [], torch.zeros(d_sym, d).to(dtype=dtype)

        # We project the representation on the basis `base`
        def representation(alpha, beta, gamma, parity=None):
            def re(r):
                if callable(r):
                    if has_parity:
                        return r(alpha, beta, gamma, parity)
                    return r(alpha, beta, gamma)
                return rep(r, alpha, beta, gamma, parity)

            m = o3.kron(*(re(kw_Rs[i]) for i in f0))
            return Q @ m @ Q.T

        # And check that after this projection it is still a representation
        assert _is_representation(representation, eps, has_parity)

        # The rest of the code simply extracts the irreps present in this representation
        Rs_out = []
        A = Q.clone()
        for l in range(int((d_sym - 1) // 2) + 1):
            for p in [-1, 1] if has_parity else [0]:
                if 2 * l + 1 > d_sym - dim(Rs_out):
                    break

                mul, B, representation = o3.reduce(representation, partial(rep, [(1, l, p)]), eps, has_parity)
                A = o3.direct_sum(torch.eye(d_sym - B.shape[0]), B) @ A
                A = _round_sqrt(A, eps)
                Rs_out += [(mul, l, p)]

                if dim(Rs_out) == d_sym:
                    break

        if dim(Rs_out) != d_sym:
            raise RuntimeError('unable to decompose into irreducible representations')

    return simplify(Rs_out), A.to(dtype=dtype)
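# Hedged usage sketch for `reduce_tensor`, mirroring the docstring example: a symmetric
# rank-2 tensor built from two vector (l=1) indices has 6 independent components and
# decomposes into l=0 and l=2, so the change of basis maps the 9 unconstrained
# components onto the 6 symmetry-adapted ones.
def _reduce_tensor_example():
    Rs, Q = reduce_tensor('ij=ji', i=[(1, 1)])
    assert dim(Rs) == 1 + 5     # l=0 plus l=2
    assert Q.shape == (6, 9)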