def germinate_formulas(formula):
    r"""Parse an index-symmetry formula and close it into a group of signed permutations.

    Parameters
    ----------
    formula : str
        Index strings separated by ``=``, e.g. ``"ijk=-jik=ikj"``. A term may
        carry a single leading ``-`` giving it sign ``-1`` (otherwise ``+1``).
        The first term must be unsigned, and every term must be a permutation
        of the first one.

    Returns
    -------
    f0 : str
        The first (reference) index string.
    formulas : set of (int, tuple of int)
        The given ``(sign, permutation)`` generators closed under
        ``perm.inverse`` and ``perm.compose`` (signs are multiplied on
        composition). Each permutation ``p`` satisfies ``p[k] == f.index(f0[k])``
        for the term ``f`` it came from.

    Raises
    ------
    RuntimeError
        If the first term is negated, if a term does not have the same number
        of indices as the first, or if a term is not a permutation of the first.
    """
    formulas = []
    for term in formula.split('='):
        # Only a *leading* '-' is a sign. Any other '-' is kept so that the
        # permutation check below rejects the term instead of silently
        # dropping the character.
        if term.startswith('-'):
            formulas.append((-1, term[1:]))
        else:
            formulas.append((1, term))

    s0, f0 = formulas[0]
    if s0 != 1:
        # Explicit raise rather than `assert`, so the check survives `python -O`.
        raise RuntimeError(f'the first term of {formula} must not be negated')

    for _s, f in formulas:
        # Length is checked first. Done the other way round, any length
        # mismatch is already caught by the permutation check, so this more
        # specific message could never be reached.
        if len(f0) != len(f):
            raise RuntimeError(
                f'{f0} and {f} don\'t have the same number of indices')
        if len(set(f)) != len(f) or set(f) != set(f0):
            raise RuntimeError(f'{f} is not a permutation of {f0}')

    # `formulas` is a list of (sign, permutation of indices).
    # Each formula can be viewed as a permutation of the original formula.
    formulas = {(s, tuple(f.index(i) for i in f0))
                for s, f in formulas}  # set of generators (permutations)

    # They can be composed: for instance, if you have ijk=jik=ikj
    # you also have ijk=jki.
    # Applying all possible compositions creates an entire group.
    while True:
        n = len(formulas)
        formulas = formulas.union([(s, perm.inverse(p)) for s, p in formulas])
        formulas = formulas.union([(s1 * s2, perm.compose(p1, p2))
                                   for s1, p1 in formulas
                                   for s2, p2 in formulas])
        if len(formulas) == n:
            break  # the set is stable => it is now a group

    return f0, formulas
    def sort(self):
        r"""Sort the representations.

        Entries are ordered by irrep. Ties between equal irreps keep their
        original relative order, because the original index is part of the
        sort key.

        Returns
        -------
        irreps : `e3nn.o3.Irreps`
            The sorted representations, with multiplicities carried along.
        p : tuple of int
            ``perm.inverse(inv)``. As the examples show, ``p[j]`` is the
            position in the sorted result of original entry ``j``.
        inv : tuple of int
            ``inv[i]`` is the original index of the ``i``-th sorted entry.

        Examples
        --------

        >>> Irreps("1e + 0e + 1e").sort().irreps
        1x0e+1x1e+1x1e

        >>> Irreps("2o + 1e + 0e + 1e").sort().p
        (3, 1, 0, 2)

        >>> Irreps("2o + 1e + 0e + 1e").sort().inv
        (2, 1, 3, 0)
        """
        Ret = collections.namedtuple("sort", ["irreps", "p", "inv"])
        # The key is (ir, original index, mul). Including the index makes
        # ties between equal irreps resolve to their original order.
        out = [(ir, i, mul) for i, (mul, ir) in enumerate(self)]
        out = sorted(out)
        inv = tuple(i for _, i, _ in out)
        p = perm.inverse(inv)
        irreps = Irreps([(mul, ir) for ir, _, mul in out])
        return Ret(irreps, p, inv)
def test_natural_representation(float_tolerance, n):
    """Compare ``perm.natural_representation`` with explicit slices of the identity matrix."""
    # Row form: the identity's rows taken in the order of the inverse permutation.
    sigma = perm.rand(n)
    expected = torch.eye(n)[list(perm.inverse(sigma))]
    actual = perm.natural_representation(sigma)
    assert torch.allclose(expected, actual, atol=float_tolerance)

    # Column form: the identity's columns taken in the order of the permutation.
    sigma = perm.rand(n)
    expected = torch.eye(n)[:, list(sigma)]
    actual = perm.natural_representation(sigma)
    assert torch.allclose(expected, actual, atol=float_tolerance)

    # The matrix must be orthogonal: M @ M^T == I.
    mat = perm.natural_representation(perm.rand(n))
    gram = mat @ mat.T
    assert torch.allclose(gram, torch.eye(n), atol=float_tolerance)
def test_standard_representation(float_tolerance, n):
    """Check that ``perm.standard_representation`` respects identity, inverse, composition, and orthogonality."""
    rep = perm.standard_representation

    # The identity permutation maps to the (n - 1) x (n - 1) identity matrix.
    assert torch.allclose(rep(perm.identity(n)), torch.eye(n - 1), atol=float_tolerance)

    # The matrix of p^-1 is the matrix inverse of the matrix of p.
    sigma = perm.rand(n)
    forward = rep(sigma)
    backward = rep(perm.inverse(sigma))
    assert torch.allclose(forward, torch.inverse(backward), atol=float_tolerance)

    # A product of matrices equals the matrix of the composed permutation.
    sigma, tau = perm.rand(n), perm.rand(n)
    product = rep(sigma) @ rep(tau)
    composed = rep(perm.compose(sigma, tau))
    assert torch.allclose(product, composed, atol=float_tolerance)

    # The matrix must be orthogonal: M @ M^T == I.
    mat = rep(perm.rand(n))
    assert torch.allclose(mat @ mat.T, torch.eye(n - 1), atol=float_tolerance)
def test_inverse(n):
    """Composing each permutation of ``n`` with its inverse, in either order, yields the identity."""
    for sigma in perm.group(n):
        sigma_inv = perm.inverse(sigma)
        # Check the two orders p∘p^-1 and p^-1∘p in turn.
        for left, right in ((sigma, sigma_inv), (sigma_inv, sigma)):
            assert perm.compose(left, right) == perm.identity(n)