def germinate_formulas(formula):
    formulas = [(-1 if f.startswith('-') else 1, f.replace('-', '')) for f in formula.split('=')]
    s0, f0 = formulas[0]
    assert s0 == 1

    for _s, f in formulas:
        if len(set(f)) != len(f) or set(f) != set(f0):
            raise RuntimeError(f'{f} is not a permutation of {f0}')
        if len(f0) != len(f):
            raise RuntimeError(f'{f0} and {f} don\'t have the same number of indices')

    # `formulas` is a list of (sign, permutation of indices)
    # each formula can be viewed as a permutation of the original formula
    formulas = {(s, tuple(f.index(i) for i in f0)) for s, f in formulas}  # set of generators (permutations)

    # they can be composed, for instance if you have ijk=jik=ikj
    # you also have ijk=jki
    # applying all possible compositions creates an entire group
    while True:
        n = len(formulas)
        formulas = formulas.union([(s, perm.inverse(p)) for s, p in formulas])
        formulas = formulas.union([(s1 * s2, perm.compose(p1, p2)) for s1, p1 in formulas for s2, p2 in formulas])
        if len(formulas) == n:
            break  # we break when the set is stable => it is now a group \o/

    return f0, formulas

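# Hedged usage sketch (not from the original source): 'ij=ji' and 'ij=-ji' are
# hypothetical input formulas for a symmetric and an antisymmetric pair of indices.
# The expected outputs below were traced by hand from the function above.
def _example_germinate_formulas():
    # symmetric: both the identity and the transposition carry sign +1
    f0, group = germinate_formulas('ij=ji')
    assert f0 == 'ij'
    assert group == {(1, (0, 1)), (1, (1, 0))}

    # antisymmetric: the transposition carries sign -1
    _, group = germinate_formulas('ij=-ji')
    assert group == {(1, (0, 1)), (-1, (1, 0))}
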
def sort(self):
    r"""Sort the representations.

    Returns
    -------
    irreps : `e3nn.o3.Irreps`

    p : tuple of int

    inv : tuple of int

    Examples
    --------
    >>> Irreps("1e + 0e + 1e").sort().irreps
    1x0e+1x1e+1x1e

    >>> Irreps("2o + 1e + 0e + 1e").sort().p
    (3, 1, 0, 2)

    >>> Irreps("2o + 1e + 0e + 1e").sort().inv
    (2, 1, 3, 0)
    """
    Ret = collections.namedtuple("sort", ["irreps", "p", "inv"])
    out = [(ir, i, mul) for i, (mul, ir) in enumerate(self)]
    out = sorted(out)
    inv = tuple(i for _, i, _ in out)
    p = perm.inverse(inv)
    irreps = Irreps([(mul, ir) for ir, _, mul in out])
    return Ret(irreps, p, inv)

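# Hedged note on sort() above (a sketch reusing the doctest values, not original code):
# `inv[j]` is the original position of the j-th sorted entry, while `p[i]` is the
# sorted position of the i-th original entry, since p = perm.inverse(inv).
def _example_sort():
    irreps = Irreps("2o + 1e + 0e + 1e")
    sorted_irreps, p, inv = irreps.sort()
    assert p == (3, 1, 0, 2) and inv == (2, 1, 3, 0)
    # gathering the original entries with `inv` reproduces the sorted order
    assert Irreps([irreps[i] for i in inv]) == sorted_irreps
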
def test_natural_representation(float_tolerance, n):
    # two equivalent constructions of the permutation matrix:
    # gather the rows of the identity by the inverse permutation...
    p = perm.rand(n)
    a = torch.eye(n)[list(perm.inverse(p))]
    b = perm.natural_representation(p)
    assert torch.allclose(a, b, atol=float_tolerance)

    # ...or gather its columns by the permutation itself
    p = perm.rand(n)
    a = torch.eye(n)[:, list(p)]
    b = perm.natural_representation(p)
    assert torch.allclose(a, b, atol=float_tolerance)

    # orthogonal
    a = perm.natural_representation(perm.rand(n))
    assert torch.allclose(a @ a.T, torch.eye(n), atol=float_tolerance)

def test_standard_representation(float_tolerance, n):
    # identity
    e = perm.standard_representation(perm.identity(n))
    assert torch.allclose(e, torch.eye(n - 1), atol=float_tolerance)

    # inverse
    p = perm.rand(n)
    a = perm.standard_representation(p)
    b = perm.standard_representation(perm.inverse(p))
    assert torch.allclose(a, torch.inverse(b), atol=float_tolerance)

    # compose
    p1, p2 = perm.rand(n), perm.rand(n)
    a = perm.standard_representation(p1) @ perm.standard_representation(p2)
    b = perm.standard_representation(perm.compose(p1, p2))
    assert torch.allclose(a, b, atol=float_tolerance)

    # orthogonal
    a = perm.standard_representation(perm.rand(n))
    assert torch.allclose(a @ a.T, torch.eye(n - 1), atol=float_tolerance)

def test_inverse(n):
    for p in perm.group(n):
        ip = perm.inverse(p)
        assert perm.compose(p, ip) == perm.identity(n)
        assert perm.compose(ip, p) == perm.identity(n)