Exemplo n.º 1
0
    def test_reduce_tensor_product(self):
        """Check the change of basis `Q` returned by `rs.tensor_product`.

        For each pair of inputs, `Q` must intertwine the output matrix `D`
        with the pair (`D_i`, `D_j`), and `Q` reshaped to a square matrix
        must be orthogonal.
        """
        cases = [
            ([(1, 0)], [(2, 0)]),
            ([(3, 1), (2, 2)], [(2, 0), (1, 1), (1, 3)]),
        ]
        for Rs_i, Rs_j in cases:
            with o3.torch_default_dtype(torch.float64):
                Rs, Q = rs.tensor_product(Rs_i, Rs_j)

                angles = torch.rand(3, dtype=torch.float64)

                # One irr_repr block per copy (multiplicity) of each entry.
                blocks_i = [
                    o3.irr_repr(l, *angles)
                    for mul, l in Rs_i
                    for _ in range(mul)
                ]
                D_i = o3.direct_sum(*blocks_i)

                blocks_j = [
                    o3.irr_repr(l, *angles)
                    for mul, l in Rs_j
                    for _ in range(mul)
                ]
                D_j = o3.direct_sum(*blocks_j)

                # Entries of the returned Rs carry a third field, unused here.
                blocks_out = [
                    o3.irr_repr(l, *angles)
                    for mul, l, _p in Rs
                    for _ in range(mul)
                ]
                D = o3.direct_sum(*blocks_out)

                # Transforming the output index of Q must match transforming
                # its two input indices.
                lhs = torch.einsum("ijk,il->ljk", (Q, D))
                rhs = torch.einsum("li,mj,kij->klm", (D_i, D_j, Q))

                diff_rms = (lhs - rhs).pow(2).mean().sqrt()
                ref_rms = lhs.pow(2).mean().sqrt()
                self.assertLess(diff_rms / ref_rms, 1e-10)

                # Orthogonality of Q seen as an (n, n) matrix, on both sides.
                n = Q.size(0)
                M = Q.view(n, n)
                identity = torch.eye(n, dtype=M.dtype)

                for left, right in ((M, M.t()), (M.t(), M)):
                    err = ((left @ right) - identity).pow(2).mean().sqrt()
                    self.assertLess(err, 1e-10)
Exemplo n.º 2
0
    def test3(self):
        """Test rotation equivariance on GatedBlock and dependencies."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            kernel_factory = partial(Kernel, RadialModel=ConstantRadialModel)

            act = GatedBlock(
                Rs_out,
                scalar_activation=sigmoid,
                gate_activation=sigmoid,
            )
            conv = Convolution(kernel_factory, Rs_in, act.Rs_in)

            angles = torch.randn(3)
            rot_geo = o3.rot(*angles)

            # One irr_repr block per copy of each (mul, l) entry.
            blocks_in = [
                o3.irr_repr(l, *angles)
                for mul, l in Rs_in
                for _ in range(mul)
            ]
            D_in = o3.direct_sum(*blocks_in)
            blocks_out = [
                o3.irr_repr(l, *angles)
                for mul, l in Rs_out
                for _ in range(mul)
            ]
            D_out = o3.direct_sum(*blocks_out)

            n_features = sum(mul * (2 * l + 1) for mul, l in Rs_in)
            fea = torch.randn(1, 4, n_features)
            geo = torch.randn(1, 4, 3)

            # Path 1: apply the network, then transform its output.
            out = act(conv(fea, geo))
            x1 = torch.einsum("ij,zaj->zai", (D_out, out))

            # Path 2: transform features and geometry, then apply the network.
            fea_rot = torch.einsum("ij,zaj->zai", (D_in, fea))
            geo_rot = torch.einsum("ij,zaj->zai", rot_geo, geo)
            x2 = act(conv(fea_rot, geo_rot))

            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Exemplo n.º 3
0
    def test1(self):
        """Compare GatedBlock-after-Convolution on transformed vs. original inputs."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(3, 0), (3, 1), (2, 0), (1, 2)]
            Rs_out = [(3, 0), (3, 1), (1, 2), (3, 0)]

            f = GatedBlock(
                Rs_out,
                rescaled_act.Softplus(beta=5),
                rescaled_act.sigmoid,
            )
            kernel = Kernel(Rs_in, f.Rs_in, ConstantRadialModel)
            c = Convolution(kernel)

            angles = torch.randn(3)

            # One irr_repr block per copy of each (mul, l) entry.
            blocks_in = [
                o3.irr_repr(l, *angles)
                for mul, l in Rs_in
                for _ in range(mul)
            ]
            D_in = o3.direct_sum(*blocks_in)
            blocks_out = [
                o3.irr_repr(l, *angles)
                for mul, l in Rs_out
                for _ in range(mul)
            ]
            D_out = o3.direct_sum(*blocks_out)

            n_features = sum(mul * (2 * l + 1) for mul, l in Rs_in)
            x = torch.randn(1, 5, n_features)
            geo = torch.randn(1, 5, 3)

            # Transformed copies of the features and of the geometry.
            rx = torch.einsum("ij,zaj->zai", (D_in, x))
            rgeo = geo @ o3.rot(*angles).t()

            # Reference: network output on the original inputs, transformed.
            y = f(c(x, geo), dim=2)
            ry = torch.einsum("ij,zaj->zai", (D_out, y))

            # Network output on the transformed inputs must match it.
            residual = f(c(rx, rgeo)) - ry
            self.assertLess(residual.norm(), 1e-10 * ry.norm())
Exemplo n.º 4
0
    def test5(self):
        """Test parity equivariance on GatedBlockParity and dependencies."""
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

            kernel_factory = partial(Kernel, RadialModel=ConstantRadialModel)

            rs_scalars = [(mul, 0, +1), (mul, 0, -1)]
            act_scalars = [(mul, relu), (mul, absolute)]
            rs_nonscalars = [
                (mul, 1, +1), (mul, 1, -1),
                (mul, 2, +1), (mul, 2, -1),
                (mul, 3, +1), (mul, 3, -1),
            ]
            n = 3 * mul
            rs_gates = [(n, 0, +1), (n, 0, -1)]
            act_gates = [(n, sigmoid), (n, tanh)]

            act = GatedBlockParity(rs_scalars, act_scalars, rs_gates,
                                   act_gates, rs_nonscalars)
            conv = Convolution(kernel_factory, Rs_in, act.Rs_in)

            # Under the sign flip of `geo` below, each block of size 2l+1 is
            # just the identity scaled by its parity p.
            blocks_in = [
                p * torch.eye(2 * l + 1)
                for m, l, p in Rs_in
                for _ in range(m)
            ]
            D_in = o3.direct_sum(*blocks_in)
            blocks_out = [
                p * torch.eye(2 * l + 1)
                for m, l, p in act.Rs_out
                for _ in range(m)
            ]
            D_out = o3.direct_sum(*blocks_out)

            n_features = sum(m * (2 * l + 1) for m, l, p in Rs_in)
            fea = torch.randn(1, 4, n_features)
            geo = torch.randn(1, 4, 3)

            # Path 1: apply the network, then the output parity matrix.
            out = act(conv(fea, geo))
            x1 = torch.einsum("ij,zaj->zai", (D_out, out))

            # Path 2: apply the input parity matrix and flip the geometry.
            fea_flipped = torch.einsum("ij,zaj->zai", (D_in, fea))
            x2 = act(conv(fea_flipped, -geo))

            self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
Exemplo n.º 5
0
def rep(Rs, alpha, beta, gamma, parity=None):
    """
    Representation of O(3) as a block-diagonal matrix.

    For every entry ``(mul, l, p)`` of ``simplify(Rs)``, ``mul`` copies of
    ``o3.irr_repr(l, alpha, beta, gamma)`` are placed on the diagonal with
    ``o3.direct_sum``.

    If `parity` is None the blocks are used as they are. Otherwise each
    block is multiplied by ``p ** parity`` (parity applied ``parity``
    times), and every entry is asserted to have ``p != 0``.
    """
    angles = (alpha, beta, gamma)

    if parity is None:
        blocks = [
            o3.irr_repr(l, *angles)
            for mul, l, _p in simplify(Rs)
            for _ in range(mul)
        ]
        return o3.direct_sum(*blocks)

    # Every entry needs a non-zero parity for `p ** parity` to be meaningful.
    assert all(p != 0 for _mul, _l, p in simplify(Rs))
    blocks = [
        (p ** parity) * o3.irr_repr(l, *angles)
        for mul, l, p in simplify(Rs)
        for _ in range(mul)
    ]
    return o3.direct_sum(*blocks)
Exemplo n.º 6
0
 def D(*angles):
     """Return the direct sum of irr_repr blocks l = 1, 1, 2, conjugated by `A`.

     `A` is taken from the enclosing scope.
     """
     blocks = [o3.irr_repr(l, *angles) for l in (1, 1, 2)]
     block_diag = o3.direct_sum(*blocks)
     return A @ block_diag @ torch.inverse(A)
Exemplo n.º 7
0
    def test2(self):
        """Test rotation equivariance on Kernel."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            k = Kernel(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)

            angles = torch.randn(3)

            # One irr_repr block per copy of each (mul, l) entry.
            blocks_in = [
                o3.irr_repr(l, *angles)
                for mul, l in Rs_in
                for _ in range(mul)
            ]
            D_in = o3.direct_sum(*blocks_in)
            blocks_out = [
                o3.irr_repr(l, *angles)
                for mul, l in Rs_out
                for _ in range(mul)
            ]
            D_out = o3.direct_sum(*blocks_out)

            # Transforming the kernel output must equal evaluating the kernel
            # at the rotated point and transforming its input side.
            W1 = D_out @ k(r)  # [i, j]
            r_rotated = o3.rot(*angles) @ r
            W2 = k(r_rotated) @ D_in  # [i, j]

            self.assertLess((W1 - W2).norm(), 10e-5 * W1.norm())
Exemplo n.º 8
0
    def test4(self):
        """Test parity equivariance on Kernel."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0, 1), (2, 1, 1), (2, 2, -1)]
            Rs_out = [(2, 0, -1), (2, 1, 1), (2, 2, 1)]

            k = Kernel(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)

            # Each block of size 2l+1 is the identity scaled by its parity p.
            blocks_in = [
                p * torch.eye(2 * l + 1)
                for mul, l, p in Rs_in
                for _ in range(mul)
            ]
            D_in = o3.direct_sum(*blocks_in)
            blocks_out = [
                p * torch.eye(2 * l + 1)
                for mul, l, p in Rs_out
                for _ in range(mul)
            ]
            D_out = o3.direct_sum(*blocks_out)

            # Kernel at r, output side transformed, versus kernel at -r,
            # input side transformed.
            W1 = D_out @ k(r)  # [i, j]
            W2 = k(-r) @ D_in  # [i, j]

            self.assertLess((W1 - W2).norm(), 10e-5 * W1.norm())
Exemplo n.º 9
0
def reduce_tensor(formula, eps=1e-9, has_parity=None, **kw_Rs):
    """
    Reduce a tensor with index-permutation symmetries into irreps.

    Usage
    Rs, Q = rs.reduce_tensor('ijkl=jikl=ikjl=ijlk', i=[(1, 1)])
    Rs = 0,2,4
    Q = tensor of shape [15, 81]

    Parameters
    ----------
    formula : str
        Index strings separated by '=', e.g. 'ij=-ji'. A leading '-' marks
        a sign of -1 for that term. The first term must not be negated and
        every term must be a permutation of the first one.
    eps : float
        Tolerance forwarded to `_is_representation`, `o3.reduce` and
        `_round_sqrt`.
    has_parity : bool or None
        If None, it is inferred from the first non-callable entry of
        `kw_Rs` (True when any of its parities is non-zero).
    **kw_Rs
        Maps an index letter either to an Rs (normalised with `convention`)
        or to a callable. A callable is invoked as
        ``r(alpha, beta, gamma, parity)`` when `has_parity` is true and as
        ``r(alpha, beta, gamma)`` otherwise. Indices exchanged by the
        formula share their entry, so it is enough to give one of them.
        Note that this dict is updated in place.

    Returns
    -------
    (Rs_out, A)
        `Rs_out` is ``simplify`` applied to the list of extracted
        ``(mul, l, p)`` and `A` is the accumulated change of basis,
        converted to the default dtype that was active when the function
        was called. If the symmetries leave no degree of freedom, returns
        ``([], zeros(0, d))``.

    Raises
    ------
    RuntimeError
        If a term is not a permutation of the first one, if parity is only
        partially specified or cannot be determined, if two related indices
        were given different Rs, if an index has no Rs, or if the
        decomposition does not account for all degrees of freedom.
    """
    # remember the caller's default dtype: the outputs are converted back to it
    dtype = torch.get_default_dtype()
    with torch_default_dtype(torch.float64):
        # reformat `formulas` into a list of (sign, indices) and make checks
        formulas = [(-1 if f.startswith('-') else 1, f.replace('-', ''))
                    for f in formula.split('=')]
        s0, f0 = formulas[0]
        assert s0 == 1

        for _s, f in formulas:
            if len(set(f)) != len(f) or set(f) != set(f0):
                raise RuntimeError(f'{f} is not a permutation of {f0}')
            if len(f0) != len(f):
                raise RuntimeError(
                    f'{f0} and {f} don\'t have the same number of indices')

        # `formulas` is a list of (sign, permutation of indices)
        # each formula can be viewed as a permutation of the original formula
        formulas = {(s, tuple(f.index(i) for i in f0))
                    for s, f in formulas}  # set of generators (permutations)

        # they can be composed, for instance if you have ijk=jik=ikj
        # you also have ijk=jki
        # applying all possible compositions creates an entire group
        while True:
            n = len(formulas)
            formulas = formulas.union([(s, perm.inverse(p))
                                       for s, p in formulas])
            formulas = formulas.union([(s1 * s2, perm.compose(p1, p2))
                                       for s1, p1 in formulas
                                       for s2, p2 in formulas])
            if len(formulas) == n:
                break  # the set stopped growing, so it is closed: it is now a group

        # normalise the non-callable entries of `kw_Rs` and check that parity
        # is specified consistently before comparing them with the formulas
        for i in kw_Rs:
            if not callable(kw_Rs[i]):
                Rs = convention(kw_Rs[i])
                if has_parity is None:
                    # infer it from the first explicit Rs encountered
                    has_parity = any(p != 0 for _, _, p in Rs)
                if not has_parity and not all(p == 0 for _, _, p in Rs):
                    raise RuntimeError(
                        f'{format_Rs(Rs)} parity has to be specified everywhere or nowhere'
                    )
                if has_parity and any(p == 0 for _, _, p in Rs):
                    raise RuntimeError(
                        f'{format_Rs(Rs)} parity has to be specified everywhere or nowhere'
                    )
                kw_Rs[i] = Rs

        # still None when every entry is callable (or `kw_Rs` is empty)
        if has_parity is None:
            raise RuntimeError(f'please specify the argument `has_parity`')

        # here we check that each index has one and only one representation
        for _s, p in formulas:
            f = "".join(f0[i] for i in p)
            for i, j in zip(f0, f):
                if i in kw_Rs and j in kw_Rs and kw_Rs[i] != kw_Rs[j]:
                    raise RuntimeError(
                        f'Rs of {i} (Rs={format_Rs(kw_Rs[i])}) and {j} (Rs={format_Rs(kw_Rs[j])}) should be the same'
                    )
                # indices exchanged by a permutation share the same entry
                if i in kw_Rs:
                    kw_Rs[j] = kw_Rs[i]
                if j in kw_Rs:
                    kw_Rs[i] = kw_Rs[j]

        for i in f0:
            if i not in kw_Rs:
                raise RuntimeError(f'index {i} has not Rs associated to it')

        # all-zero arguments, used only to measure the size of what a callable returns
        e = (0, 0, 0, 0) if has_parity else (0, 0, 0)
        dims = {
            i: len(kw_Rs[i](*e)) if callable(kw_Rs[i]) else dim(kw_Rs[i])
            for i in f0
        }  # dimension of each index
        full_base = list(itertools.product(
            *(range(dims[i])
              for i in f0)))  # (0, 0, 0), (0, 0, 1), (0, 0, 2), ... (3, 3, 3)
        # len(full_base) degrees of freedom in an unconstrained tensor

        # but there are constraints given by the group `formulas`
        # For instance if `ij=-ji`, then 00=-00, 01=-10 and so on
        base = set()
        for x in full_base:
            # T[x] is a coefficient of the tensor T and is related to other coefficients T[y]
            # if x and y are related by a formula
            xs = {(s, tuple(x[i] for i in p)) for s, p in formulas}
            # s * T[x] are all equal for all (s, x) in xs
            # if T[x] = -T[x] it is then equal to 0 and we lose this degree of freedom
            if not (-1, x) in xs:
                # the sign is arbitrary, put both possibilities
                base.add(
                    frozenset(
                        {frozenset(xs),
                         frozenset({(-s, x)
                                    for s, x in xs})}))

        # len(base) is the number of degrees of freedom in the tensor.
        # Now we want to decompose these degrees of freedom into irreps

        # turn the set of frozensets into nested sorted lists (fixed ordering)
        base = sorted([
            sorted([sorted(xs) for xs in x]) for x in base
        ])  # requested for python 3.7 but not for 3.8 (probably a bug in 3.7)

        # First we compute the change of basis (projection) between full_base and base
        d_sym = len(base)
        d = len(full_base)
        Q = torch.zeros(d_sym, d)

        for i, x in enumerate(base):
            # of the two sign variants, keep the one whose signs sum highest
            x = max(x, key=lambda xs: sum(s for s, x in xs))
            # note: `e` is rebound here to a multi-index of `full_base`
            for s, e in x:
                j = full_base.index(e)
                Q[i, j] = s / len(x)**0.5  # len(x) entries of magnitude 1/sqrt(len(x))

        # the rows of Q must be orthonormal
        assert torch.allclose(Q @ Q.T, torch.eye(d_sym))

        # nothing survives the constraints: there is nothing to decompose
        if d_sym == 0:
            return [], torch.zeros(d_sym, d).to(dtype=dtype)

        # We project the representation on the basis `base`
        def representation(alpha, beta, gamma, parity=None):
            def re(r):
                # matrix for one index: call user callables, else build it with `rep`
                if callable(r):
                    if has_parity:
                        return r(alpha, beta, gamma, parity)
                    return r(alpha, beta, gamma)
                return rep(r, alpha, beta, gamma, parity)

            # combine the per-index matrices with o3.kron, in the index order of f0
            m = o3.kron(*(re(kw_Rs[i]) for i in f0))
            return Q @ m @ Q.T

        # And check that after this projection it is still a representation
        assert _is_representation(representation, eps, has_parity)

        # The rest of the code simply extracts the irreps present in this representation
        Rs_out = []
        A = Q.clone()
        # an irrep of order l has dimension 2l+1, which cannot exceed d_sym
        for l in range(int((d_sym - 1) // 2) + 1):
            for p in [-1, 1] if has_parity else [0]:
                # not enough remaining dimensions for this l
                # (this only leaves the loop over p)
                if 2 * l + 1 > d_sym - dim(Rs_out):
                    break

                # NOTE(review): `o3.reduce` presumably splits the copies of
                # irrep (l, p) off `representation` and returns their
                # multiplicity, a change of basis B and the remaining
                # representation - confirm against its definition.
                mul, B, representation = o3.reduce(representation,
                                                   partial(rep, [(1, l, p)]),
                                                   eps, has_parity)
                # pad B with an identity block so that it has size d_sym
                A = o3.direct_sum(torch.eye(d_sym - B.shape[0]), B) @ A
                A = _round_sqrt(A, eps)
                Rs_out += [(mul, l, p)]

                if dim(Rs_out) == d_sym:
                    break

        if dim(Rs_out) != d_sym:
            raise RuntimeError(
                f'unable to decompose into irreducible representations')

        return simplify(Rs_out), A.to(dtype=dtype)
Exemplo n.º 10
0
def reduce_tensor(formula, eps=1e-9, has_parity=None, **kw_Rs):
    """
    Reduce a tensor with index-permutation symmetries into irreps.

    Usage
    Rs, Q = rs.reduce_tensor('ijkl=jikl=ikjl=ijlk', i=[(1, 1)])
    Rs = 0,2,4
    Q = tensor of shape [15, 81]

    Parameters
    ----------
    formula : str
        Index strings separated by '=', e.g. 'ij=-ji'. A leading '-' marks
        a sign of -1 for that term. The first term must not be negated and
        every term must be a permutation of the first one.
    eps : float
        Tolerance forwarded to `_is_representation`, `o3.reduce` and
        `_round_sqrt`.
    has_parity : bool or None
        If None, it is inferred from the first non-callable entry of
        `kw_Rs` (True when any of its parities is non-zero).
    **kw_Rs
        Maps an index letter either to an Rs (normalised with `convention`)
        or to a callable. A callable is invoked as
        ``r(alpha, beta, gamma, parity)`` when `has_parity` is true and as
        ``r(alpha, beta, gamma)`` otherwise. Indices exchanged by the
        formula share their entry, so it is enough to give one of them.
        Note that this dict is updated in place.

    Returns
    -------
    (Rs_out, A)
        `Rs_out` is ``simplify`` applied to the list of extracted
        ``(mul, l, p)`` and `A` is the accumulated change of basis. The
        tensor is created inside ``torch_default_dtype(torch.float64)`` and
        returned without any dtype conversion. If the symmetries leave no
        degree of freedom, returns ``([], zeros(0, d))``.

    Raises
    ------
    RuntimeError
        If a term is not a permutation of the first one, if parity is only
        partially specified or cannot be determined, if two related indices
        were given different Rs, if an index has no Rs, or if the
        decomposition does not account for all degrees of freedom.
    """
    with torch_default_dtype(torch.float64):
        # parse the formula into a list of (sign, indices) and validate it
        formulas = [(-1 if f.startswith('-') else 1, f.replace('-', ''))
                    for f in formula.split('=')]
        s0, f0 = formulas[0]
        assert s0 == 1

        for _s, f in formulas:
            if len(set(f)) != len(f) or set(f) != set(f0):
                raise RuntimeError(f'{f} is not a permutation of {f0}')
            if len(f0) != len(f):
                raise RuntimeError(
                    f'{f0} and {f} don\'t have the same number of indices')

        # express each term as (sign, permutation of the positions of f0)
        formulas = {(s, tuple(f.index(i) for i in f0))
                    for s, f in formulas}  # set of generators (permutations)

        # create the entire group: add inverses and pairwise compositions
        # until the set stops growing
        while True:
            n = len(formulas)
            formulas = formulas.union([(s, perm.inverse(p))
                                       for s, p in formulas])
            formulas = formulas.union([(s1 * s2, perm.compose(p1, p2))
                                       for s1, p1 in formulas
                                       for s2, p2 in formulas])
            if len(formulas) == n:
                break

        # normalise the non-callable entries of `kw_Rs` and check that parity
        # is specified consistently
        for i in kw_Rs:
            if not callable(kw_Rs[i]):
                Rs = convention(kw_Rs[i])
                if has_parity is None:
                    # infer it from the first explicit Rs encountered
                    has_parity = any(p != 0 for _, _, p in Rs)
                if not has_parity and not all(p == 0 for _, _, p in Rs):
                    raise RuntimeError(
                        f'{format_Rs(Rs)} parity has to be specified everywhere or nowhere'
                    )
                if has_parity and any(p == 0 for _, _, p in Rs):
                    raise RuntimeError(
                        f'{format_Rs(Rs)} parity has to be specified everywhere or nowhere'
                    )
                kw_Rs[i] = Rs

        # still None when every entry is callable (or `kw_Rs` is empty)
        if has_parity is None:
            raise RuntimeError(f'please specify the argument `has_parity`')

        # check that indices related by a permutation agree on their entry,
        # and copy the entry to the indices that were not given one
        for _s, p in formulas:
            f = "".join(f0[i] for i in p)
            for i, j in zip(f0, f):
                if i in kw_Rs and j in kw_Rs and kw_Rs[i] != kw_Rs[j]:
                    raise RuntimeError(
                        f'Rs of {i} (Rs={format_Rs(kw_Rs[i])}) and {j} (Rs={format_Rs(kw_Rs[j])}) should be the same'
                    )
                if i in kw_Rs:
                    kw_Rs[j] = kw_Rs[i]
                if j in kw_Rs:
                    kw_Rs[i] = kw_Rs[j]

        for i in f0:
            if i not in kw_Rs:
                raise RuntimeError(f'index {i} has not Rs associated to it')

        # all-zero arguments, used only to measure the size of what a callable returns
        e = (0, 0, 0, 0) if has_parity else (0, 0, 0)
        # every multi-index of the unconstrained tensor, one range per index of f0
        full_base = list(
            itertools.product(*(range(
                len(kw_Rs[i](*e)) if callable(kw_Rs[i]) else dim(kw_Rs[i]))
                                for i in f0)))

        # group the multi-indices that the formulas relate to each other
        base = set()
        for x in full_base:
            xs = {(s, tuple(x[i] for i in p)) for s, p in formulas}
            # s * T[x] all equal for (s, x) in xs
            # if (-1, x) is in xs then T[x] = -T[x] = 0: no degree of freedom
            if not (-1, x) in xs:
                # the sign is arbitrary, put both possibilities
                base.add(
                    frozenset(
                        {frozenset(xs),
                         frozenset({(-s, x)
                                    for s, x in xs})}))

        # turn the set of frozensets into nested sorted lists (fixed ordering)
        base = sorted([
            sorted([sorted(xs) for xs in x]) for x in base
        ])  # requested for python 3.7 but not for 3.8 (probably a bug in 3.7)

        # change of basis (projection) from full_base to base
        d_sym = len(base)
        d = len(full_base)
        Q = torch.zeros(d_sym, d)

        for i, x in enumerate(base):
            # of the two sign variants, keep the one whose signs sum highest
            x = max(x, key=lambda xs: sum(s for s, x in xs))
            # note: `e` is rebound here to a multi-index of `full_base`
            for s, e in x:
                j = full_base.index(e)
                Q[i, j] = s / len(x)**0.5  # len(x) entries of magnitude 1/sqrt(len(x))

        # the rows of Q must be orthonormal
        assert torch.allclose(Q @ Q.T, torch.eye(d_sym))

        # nothing survives the constraints: there is nothing to decompose
        if d_sym == 0:
            return [], torch.zeros(d_sym, d)

        # the representation projected on the basis `base`
        def representation(alpha, beta, gamma, parity=None):
            def re(r):
                # matrix for one index: call user callables, else build it with `rep`
                if callable(r):
                    if has_parity:
                        return r(alpha, beta, gamma, parity)
                    return r(alpha, beta, gamma)
                return rep(r, alpha, beta, gamma, parity)

            # combine the per-index matrices with o3.kron, in the index order of f0
            m = o3.kron(*(re(kw_Rs[i]) for i in f0))
            return Q @ m @ Q.T

        # the projection must still be a representation
        assert _is_representation(representation, eps, has_parity)

        # extract the irreps present in this representation
        Rs_out = []
        A = Q.clone()
        # an irrep of order l has dimension 2l+1, which cannot exceed d_sym
        for l in range(int((d_sym - 1) // 2) + 1):
            for p in [-1, 1] if has_parity else [0]:
                # not enough remaining dimensions for this l
                # (this only leaves the loop over p)
                if 2 * l + 1 > d_sym - dim(Rs_out):
                    break

                # NOTE(review): `o3.reduce` presumably splits the copies of
                # irrep (l, p) off `representation` and returns their
                # multiplicity, a change of basis B and the remaining
                # representation - confirm against its definition.
                mul, B, representation = o3.reduce(representation,
                                                   partial(rep, [(1, l, p)]),
                                                   eps, has_parity)
                # pad B with an identity block so that it has size d_sym
                A = o3.direct_sum(torch.eye(d_sym - B.shape[0]), B) @ A
                A = _round_sqrt(A, eps)
                Rs_out += [(mul, l, p)]

                if dim(Rs_out) == d_sym:
                    break

        if dim(Rs_out) != d_sym:
            raise RuntimeError(
                f'unable to decompose into irreducible representations')
        return simplify(Rs_out), A