Ejemplo n.º 1
0
def test_getitem():
    """Integer indexing yields ``(mul, Irrep)`` pairs; slicing yields a new ``Irreps``."""
    irreps = o3.Irreps("16x1e + 3e + 2e + 5o")

    # index -> expected (multiplicity, irrep) entry
    expected_entries = {
        0: (16, o3.Irrep("1e")),
        3: (1, o3.Irrep("5o")),
        -1: (1, o3.Irrep("5o")),
    }
    for index, entry in expected_entries.items():
        assert irreps[index] == entry

    tail = irreps[2:]
    assert isinstance(tail, o3.Irreps)
    assert tail == o3.Irreps("2e + 5o")
Ejemplo n.º 2
0
def test_wigner_3j(float_tolerance):
    """Rotating all three indices of ``wigner_3j(1, 2, 3)`` by the same random rotations leaves it unchanged."""
    angles = o3.rand_angles(10)

    degrees = (1, 2, 3)
    C = o3.wigner_3j(*degrees)
    D1, D2, D3 = [o3.Irrep(degree, 1).D_from_angles(*angles) for degree in degrees]

    rotated = torch.einsum("ijk,zil,zjm,zkn->zlmn", C, D1, D2, D3)
    max_error = (C - rotated).abs().max()
    assert max_error < float_tolerance
Ejemplo n.º 3
0
    def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr,
                 irreps_node_output, fc_neurons, num_neighbors) -> None:
        r"""Create the sub-modules of the convolution.

        Parameters
        ----------
        irreps_node_input : `e3nn.o3.Irreps` or str
            irreps of the input node features
        irreps_node_attr : `e3nn.o3.Irreps` or str
            irreps of the node attributes
        irreps_edge_attr : `e3nn.o3.Irreps` or str
            irreps of the edge attributes
        irreps_node_output : `e3nn.o3.Irreps` or str
            irreps of the output node features
        fc_neurons : list of int
            layer sizes of ``self.fc``; ``tp.weight_numel`` is appended as the final size
        num_neighbors : float
            stored as-is; presumably a normalization constant used in ``forward`` (not shown here)

        Raises
        ------
        AssertionError
            if no input/edge product lands in ``irreps_node_output`` (or is ``0e``), or if
            ``irreps_mid`` combined with ``irreps_node_attr`` cannot produce a scalar
        """
        super().__init__()
        self.irreps_node_input = o3.Irreps(irreps_node_input)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_node_output = o3.Irreps(irreps_node_output)
        self.num_neighbors = num_neighbors

        # input (x) node_attr -> output; named ``sc``, presumably the self-connection branch
        self.sc = FullyConnectedTensorProduct(self.irreps_node_input,
                                              self.irreps_node_attr,
                                              self.irreps_node_output)

        # input (x) node_attr -> input, applied before the edge tensor product (presumably)
        self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input,
                                                self.irreps_node_attr,
                                                self.irreps_node_input)

        # Enumerate every (input irrep, edge irrep) -> product irrep path whose product is
        # either wanted in the output or a scalar 0e (``o3.Irrep(0, 1)``). Each path gets its
        # own slot in irreps_mid carrying the input multiplicity ('uvu' mode).
        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_node_input):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_node_output or ir_out == o3.Irrep(
                            0, 1):
                        k = len(irreps_mid)
                        irreps_mid.append((mul, ir_out))
                        instructions.append((i, j, k, 'uvu', True))
        irreps_mid = o3.Irreps(irreps_mid)
        # sort() returns the permutation p mapping old slot -> new slot; remap the
        # instructions' output indices accordingly below
        irreps_mid, p, _ = irreps_mid.sort()

        assert irreps_mid.dim > 0, f"irreps_node_input={self.irreps_node_input} times irreps_edge_attr={self.irreps_edge_attr} produces nothing in irreps_node_output={self.irreps_node_output}"

        instructions = [(i_1, i_2, p[i_out], mode, train)
                        for i_1, i_2, i_out, mode, train in instructions]

        # Weights are not stored in ``tp`` (internal_weights=False) and not shared
        # (shared_weights=False); ``self.fc`` outputs exactly ``tp.weight_numel`` values,
        # presumably one weight vector per edge.
        tp = TensorProduct(
            self.irreps_node_input,
            self.irreps_edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel],
                                    torch.nn.functional.silu)
        self.tp = tp

        self.lin2 = FullyConnectedTensorProduct(irreps_mid,
                                                self.irreps_node_attr,
                                                self.irreps_node_output)

        # inspired by https://arxiv.org/pdf/2002.10444.pdf
        # Scalar (0e) output from irreps_mid (x) node_attr; its weights start at zero so
        # this branch outputs zero at initialization.
        self.alpha = FullyConnectedTensorProduct(irreps_mid,
                                                 self.irreps_node_attr, "0e")
        with torch.no_grad():
            self.alpha.weight.zero_()
        assert self.alpha.output_mask[
            0] == 1.0, f"irreps_mid={irreps_mid} and irreps_node_attr={self.irreps_node_attr} are not able to generate scalars"
Ejemplo n.º 4
0
def tp_path_exists(irreps_in1, irreps_in2, ir_out):
    """Tell whether ``ir_out`` can appear in the tensor product of the two irreps.

    Returns ``True`` as soon as some irrep of ``irreps_in1`` times some irrep of
    ``irreps_in2`` contains ``ir_out``; multiplicities are ignored.
    """
    irreps_in1 = o3.Irreps(irreps_in1).simplify()
    irreps_in2 = o3.Irreps(irreps_in2).simplify()
    ir_out = o3.Irrep(ir_out)

    return any(
        ir_out in left * right
        for _, left in irreps_in1
        for _, right in irreps_in2
    )
Ejemplo n.º 5
0
def _wigner_nj(*irrepss, normalization='component', filter_ir_mid=None, dtype=None, device=None):
    r"""Recursively couple a sequence of irreps into change-of-basis tensors, one per output irrep.

    The last irreps is coupled onto every result obtained by recursing on the others.

    Parameters
    ----------
    *irrepss : `e3nn.o3.Irreps` (or anything ``o3.Irreps`` accepts)
        one irreps per input
    normalization : str
        ``'component'`` scales each Wigner 3j by ``sqrt(ir_out.dim)``; ``'norm'`` scales it by
        ``sqrt(ir_left.dim * ir.dim)``; any other value applies no scaling
    filter_ir_mid : list of `e3nn.o3.Irrep`, optional
        if given, every product irrep not in this list is discarded at every coupling step,
        including the final one (the single-input base case is not filtered)
    dtype, device
        forwarded to ``torch.eye``, ``o3.wigner_3j`` and ``torch.zeros``

    Returns
    -------
    list of ``(ir_out, path, E)``, sorted by ``ir_out``
        ``E`` has shape ``(ir_out.dim, irrepss[0].dim, ..., irrepss[-1].dim)``. ``path`` is an
        ``_INPUT(tensor_index, start, stop)`` for the base case, or a ``_TP(op=(ir_left, ir, ir_out),
        args=(left_path, right_input))`` record describing which slices were coupled.
    """
    irrepss = [o3.Irreps(irreps) for irreps in irrepss]
    if filter_ir_mid is not None:
        filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid]

    # Base case: a single input. Each copy of each irrep is its own "output", selected by
    # the matching rows of the identity matrix (shape (ir.dim, irreps.dim)).
    if len(irrepss) == 1:
        irreps, = irrepss
        ret = []
        e = torch.eye(irreps.dim, dtype=dtype, device=device)
        i = 0
        for mul, ir in irreps:
            for _ in range(mul):
                sl = slice(i, i + ir.dim)
                ret += [
                    (ir, _INPUT(0, sl.start, sl.stop), e[sl])
                ]
                i += ir.dim
        return ret

    *irrepss_left, irreps_right = irrepss
    ret = []
    for ir_left, path_left, C_left in _wigner_nj(*irrepss_left, normalization=normalization, filter_ir_mid=filter_ir_mid, dtype=dtype, device=device):
        # i is the offset of the current (mul, ir) block inside irreps_right
        i = 0
        for mul, ir in irreps_right:
            for ir_out in ir_left * ir:
                if filter_ir_mid is not None and ir_out not in filter_ir_mid:
                    continue

                # C[out, left, right]
                C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype, device=device)
                # Only one of these two branches can match; other normalization values leave C unscaled.
                if normalization == 'component':
                    C *= ir_out.dim**0.5
                if normalization == 'norm':
                    C *= ir_left.dim**0.5 * ir.dim**0.5

                # Contract the ir_left index of the left basis with the 3j symbol, then restore
                # the per-input axes: (ir_out.dim, *left dims, ir.dim)
                C = torch.einsum('jk,ijl->ikl', C_left.flatten(1), C)
                C = C.reshape(ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim)
                # Embed C into the full irreps_right axis once per multiplicity copy u
                for u in range(mul):
                    E = torch.zeros(ir_out.dim, *(irreps.dim for irreps in irrepss_left), irreps_right.dim, dtype=dtype, device=device)
                    sl = slice(i + u * ir.dim, i + (u+1) * ir.dim)
                    E[..., sl] = C
                    ret += [
                        (
                            ir_out,
                            _TP(
                                op=(ir_left, ir, ir_out),
                                args=(path_left, _INPUT(len(irrepss_left), sl.start, sl.stop))
                            ),
                            E
                        )
                    ]
            i += mul * ir.dim

    return sorted(ret, key=lambda x: x[0])
Ejemplo n.º 6
0
def test_arithmetic():
    """Integer scaling, products and sums of ``Irrep`` / ``Irreps``."""
    # integer * Irrep gives an Irreps with that multiplicity
    assert 3 * o3.Irrep("6o") == o3.Irreps("3x6o")

    # Irrep * Irrep enumerates the irreps allowed in the product
    expected_products = [o3.Irrep("1o"), o3.Irrep("2o"), o3.Irrep("3o")]
    assert list(o3.Irrep("1o") * o3.Irrep("2e")) == expected_products

    # Irrep + Irrep concatenates into an Irreps
    assert o3.Irrep("4o") + o3.Irrep("7e") == o3.Irreps("4o + 7e")

    # integer scaling of an Irreps repeats it, from either side
    repeated = o3.Irreps("2x2e + 4x1o + 2x2e + 4x1o")
    assert 2 * o3.Irreps("2x2e + 4x1o") == repeated
    assert o3.Irreps("2x2e + 4x1o") * 2 == repeated

    # Irreps + Irreps concatenates without merging
    concatenated = o3.Irreps("1o + 4o + 1o + 7e")
    assert o3.Irreps("1o + 4o") + o3.Irreps("1o + 7e") == concatenated
Ejemplo n.º 7
0
    def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr,
                 irreps_node_output, fc_neurons, num_neighbors) -> None:
        r"""Create the sub-modules of the convolution.

        Parameters
        ----------
        irreps_node_input : `e3nn.o3.Irreps` or str
            irreps of the input node features
        irreps_node_attr : `e3nn.o3.Irreps` or str
            irreps of the node attributes
        irreps_edge_attr : `e3nn.o3.Irreps` or str
            irreps of the edge attributes
        irreps_node_output : `e3nn.o3.Irreps` or str
            irreps of the output node features
        fc_neurons : list of int
            layer sizes of ``self.fc``; ``tp.weight_numel`` is appended as the final size
        num_neighbors : float
            stored as-is; presumably a normalization constant used in ``forward`` (not shown here)

        NOTE(review): no check is made here that ``irreps_mid`` is non-empty, so an
        input/edge/output combination with no valid path is not reported at this point.
        """
        super().__init__()
        self.irreps_node_input = o3.Irreps(irreps_node_input)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_node_output = o3.Irreps(irreps_node_output)
        self.num_neighbors = num_neighbors

        # input (x) node_attr -> output; named ``sc``, presumably the self-connection branch
        self.sc = FullyConnectedTensorProduct(self.irreps_node_input,
                                              self.irreps_node_attr,
                                              self.irreps_node_output)

        # input (x) node_attr -> input, presumably applied before the edge tensor product
        self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input,
                                                self.irreps_node_attr,
                                                self.irreps_node_input)

        # Enumerate every (input irrep, edge irrep) -> product irrep path whose product is
        # either wanted in the output or a scalar 0e (``o3.Irrep(0, 1)``). Each path gets its
        # own slot in irreps_mid carrying the input multiplicity ('uvu' mode).
        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_node_input):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_node_output or ir_out == o3.Irrep(
                            0, 1):
                        k = len(irreps_mid)
                        irreps_mid.append((mul, ir_out))
                        instructions.append((i, j, k, 'uvu', True))
        irreps_mid = o3.Irreps(irreps_mid)
        # sort() returns the permutation p mapping old slot -> new slot; the instructions'
        # output indices are remapped with it below
        irreps_mid, p, _ = irreps_mid.sort()

        instructions = [(i_1, i_2, p[i_out], mode, train)
                        for i_1, i_2, i_out, mode, train in instructions]

        # Weights are not stored in ``tp`` (internal_weights=False) and not shared
        # (shared_weights=False); ``self.fc`` outputs exactly ``tp.weight_numel`` values,
        # presumably one weight vector per edge.
        tp = TensorProduct(
            self.irreps_node_input,
            self.irreps_edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel],
                                    torch.nn.functional.silu)
        self.tp = tp

        # irreps_mid (x) node_attr -> output
        self.lin2 = FullyConnectedTensorProduct(irreps_mid,
                                                self.irreps_node_attr,
                                                self.irreps_node_output)
        # irreps_mid (x) node_attr -> a single scalar (0e)
        self.lin3 = FullyConnectedTensorProduct(irreps_mid,
                                                self.irreps_node_attr, "0e")
Ejemplo n.º 8
0
def test_properties():
    """Attributes, tuple unpacking, repr round-trips and iteration of irreps."""
    three_even = o3.Irrep("3e")
    assert (three_even.l, three_even.p, three_even.dim) == (3, 1, 7)
    assert o3.Irrep(repr(three_even)) == three_even

    # an Irrep unpacks as (l, p)
    degree, parity = o3.Irrep("5o")
    assert degree == 5
    assert parity == -1

    # bounded iterator: l = 0..5, two parities each
    assert len(list(o3.Irrep.iterator(5))) == 12

    # unbounded iterator: the two parities of each l come before l + 1
    endless = o3.Irrep.iterator()
    for step in range(100):
        current = next(endless)
        expected_l = step // 2
        assert current.l == expected_l
        assert current.p in (-1, 1)
        assert current.dim == 2 * expected_l + 1

    irreps = o3.Irreps("4x1e + 6x2e + 12x2o")
    assert o3.Irreps(repr(irreps)) == irreps
Ejemplo n.º 9
0
    def __init__(
        self,
        irreps_in1,
        irreps_in2,
        filter_ir_out=None,
        **kwargs
    ):
        r"""Pair the two inputs entry by entry and add a ``'uuu'`` path for every product irrep.

        Both inputs are simplified and must hold the same total number of irreps. Entries are
        split so that, position by position, the two lists have equal multiplicities.

        Parameters
        ----------
        irreps_in1, irreps_in2 : `e3nn.o3.Irreps` (or anything ``o3.Irreps`` accepts)
            the two inputs
        filter_ir_out : list of `e3nn.o3.Irrep`, optional
            if given, only product irreps in this list produce a path
        **kwargs
            forwarded to the parent ``__init__``
        """
        irreps_in1 = o3.Irreps(irreps_in1).simplify()
        irreps_in2 = o3.Irreps(irreps_in2).simplify()
        if filter_ir_out is not None:
            filter_ir_out = [o3.Irrep(ir) for ir in filter_ir_out]

        assert irreps_in1.num_irreps == irreps_in2.num_irreps

        lhs = list(irreps_in1)
        rhs = list(irreps_in2)

        # Split the larger of the two entries at each position so multiplicities line up.
        pos = 0
        while pos < len(lhs):
            (mul_a, ir_a), (mul_b, ir_b) = lhs[pos], rhs[pos]
            if mul_a < mul_b:
                rhs[pos] = (mul_a, ir_b)
                rhs.insert(pos + 1, (mul_b - mul_a, ir_b))
            elif mul_b < mul_a:
                lhs[pos] = (mul_b, ir_a)
                lhs.insert(pos + 1, (mul_a - mul_b, ir_a))
            pos += 1

        out = []
        instr = []
        for idx, ((mul, ir_a), (mul_b, ir_b)) in enumerate(zip(lhs, rhs)):
            assert mul == mul_b
            for ir in ir_a * ir_b:
                if filter_ir_out is None or ir in filter_ir_out:
                    instr.append((idx, idx, len(out), 'uuu', False))
                    out.append((mul, ir))

        super().__init__(lhs, rhs, out, instr, **kwargs)
Ejemplo n.º 10
0
def test_creation():
    """``Irrep`` and ``Irreps`` accept each supported constructor form."""
    # Irrep from (l, p), from a string, and from another Irrep
    o3.Irrep(3, 1)
    three_even = o3.Irrep("3e")
    o3.Irrep(three_even)
    assert o3.Irrep('10o') == o3.Irrep(10, -1)
    # for l = 1 the 'y' suffix resolves to the same irrep as 'o'
    assert o3.Irrep("1y") == o3.Irrep("1o")

    # Irreps from an Irrep, another Irreps, (mul, (l, p)) pairs and strings
    from_irrep = o3.Irreps(three_even)
    o3.Irreps(from_irrep)
    o3.Irreps([(32, (4, -1))])
    o3.Irreps("11e")
    assert o3.Irreps("16x1e + 32 x 2o") == o3.Irreps([(16, (1, 1)), (32, (2, -1))])

    # mixed lists of strings and (mul, irrep-like) pairs
    mixed_specs = (
        ["1e", '2o'],
        [(16, "3e"), '1e'],
        [(16, "3e"), '1e', (256, (1, -1))],
    )
    for spec in mixed_specs:
        o3.Irreps(spec)
Ejemplo n.º 11
0
    def __init__(self, irreps_in, ir):
        r"""Extract every entry of ``irreps_in`` whose irrep is ``ir``

        Parameters
        ----------
        irreps_in : `e3nn.o3.Irreps`
            representation of the input

        ir : `e3nn.o3.Irrep`
            representation to extract
        """
        target = o3.Irrep(ir)
        source = o3.Irreps(irreps_in)

        # (position, entry) for each entry of the input that matches the requested irrep
        hits = [(index, mul_ir) for index, mul_ir in enumerate(source) if mul_ir.ir == target]

        self.irreps_out = o3.Irreps([mul_ir for _, mul_ir in hits])
        instructions = [tuple(index for index, _ in hits)]

        super().__init__(source, [self.irreps_out],
                         instructions,
                         squeeze_out=True)
Ejemplo n.º 12
0
    def __init__(
        self,
        irreps_in1: o3.Irreps,
        irreps_in2: o3.Irreps,
        filter_ir_out: Iterator[o3.Irrep] = None,
        **kwargs
    ):
        r"""Build every ``'uvuv'`` path between all entries of the two (simplified) inputs.

        Each path's output has multiplicity ``mul_1 * mul_2``. The collected outputs are
        sorted, and each instruction's output index follows its entry to its sorted position.

        Parameters
        ----------
        irreps_in1, irreps_in2 : `e3nn.o3.Irreps`
            the two inputs
        filter_ir_out : iterable of `e3nn.o3.Irrep`, optional
            if given, only product irreps in it produce a path
        **kwargs
            forwarded to the parent ``__init__``
        """
        irreps_in1 = o3.Irreps(irreps_in1).simplify()
        irreps_in2 = o3.Irreps(irreps_in2).simplify()
        if filter_ir_out is not None:
            filter_ir_out = [o3.Irrep(ir) for ir in filter_ir_out]

        # (input-1 index, input-2 index, output multiplicity, output irrep), in discovery order
        paths = [
            (i_1, i_2, mul_1 * mul_2, ir_out)
            for i_1, (mul_1, ir_1) in enumerate(irreps_in1)
            for i_2, (mul_2, ir_2) in enumerate(irreps_in2)
            for ir_out in ir_1 * ir_2
            if filter_ir_out is None or ir_out in filter_ir_out
        ]

        unsorted_out = o3.Irreps([(mul, ir_out) for _, _, mul, ir_out in paths])
        out, perm, _ = unsorted_out.sort()

        instr = [
            (i_1, i_2, perm[slot], 'uvuv', False)
            for slot, (i_1, i_2, _, _) in enumerate(paths)
        ]

        super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs)
Ejemplo n.º 13
0
def test_contains():
    """``in`` on an Irreps tests membership of an Irrep, parity included."""
    irreps = o3.Irreps("3x0e + 2x2e + 1x3o")
    assert o3.Irrep("2e") in irreps
    assert o3.Irrep("2o") not in irreps
Ejemplo n.º 14
0
    def __init__(self, formula, filter_ir_out=None, filter_ir_mid=None, eps=1e-9, **irreps):
        r"""Build an ``fx`` graph that maps the inputs to the reduced irreps of their tensor product.

        Parameters
        ----------
        formula : str
            index formula. It is parsed by ``germinate_formulas`` into the base index string
            ``f0`` and a list of ``(sign, permutation)`` pairs; presumably symmetries such as
            ``"ij=-ji"``, so confirm against ``germinate_formulas``.
        filter_ir_out : list of `e3nn.o3.Irrep`, optional
            if given, only these irreps are kept in the output
        filter_ir_mid : list of `e3nn.o3.Irrep`, optional
            forwarded to ``_wigner_nj`` to restrict the irreps produced along each coupling path
        eps : float
            threshold used for every "is this numerically zero" test below
        **irreps : `e3nn.o3.Irreps` (or anything ``o3.Irreps`` accepts)
            one keyword per single-character index of the formula

        Raises
        ------
        TypeError
            if a keyword name is not a single character
        RuntimeError
            if two indices linked by the formula get different irreps, if an index of the
            formula has no irreps, or if an irreps is given for an index absent from the formula
        """
        super().__init__(self, fx.Graph())

        if filter_ir_out is not None:
            filter_ir_out = [o3.Irrep(ir) for ir in filter_ir_out]

        f0, formulas = germinate_formulas(formula)

        irreps = {i: o3.Irreps(irs) for i, irs in irreps.items()}

        for i in irreps:
            if len(i) != 1:
                raise TypeError(f"got an unexpected keyword argument '{i}'")

        # Indices that a formula permutation maps onto each other must share the same irreps;
        # propagate any known irreps to the linked index.
        for _sign, p in formulas:
            f = "".join(f0[i] for i in p)
            for i, j in zip(f0, f):
                if i in irreps and j in irreps and irreps[i] != irreps[j]:
                    raise RuntimeError(f'irreps of {i} and {j} should be the same')
                if i in irreps:
                    irreps[j] = irreps[i]
                if j in irreps:
                    irreps[i] = irreps[j]

        for i in f0:
            if i not in irreps:
                raise RuntimeError(f'index {i} has no irreps associated to it')

        for i in irreps:
            if i not in f0:
                raise RuntimeError(f'index {i} has an irreps but does not appear in the formula')

        base_perm, _ = reduce_permutation(
            f0,
            formulas,
            dtype=torch.float64,
            **{i: irs.dim for i, irs in irreps.items()}
        )

        # ir -> list of (path, o3 basis tensor) for every coupling that overlaps the
        # permutation-symmetric basis
        Ps = collections.defaultdict(list)

        for ir, path, base_o3 in _wigner_nj(*[irreps[i] for i in f0], filter_ir_mid=filter_ir_mid, dtype=torch.float64):
            if filter_ir_out is None or ir in filter_ir_out:
                P = base_o3.flatten(1) @ base_perm.flatten(1).T
                if P.norm() > eps:  # if this Irrep is present in the permutation basis we keep it
                    Ps[ir].append((path, base_o3))

        outputs = []
        change_of_basis = []
        irreps_out = []

        # Loop-invariant: depends only on the permutation basis.
        P = base_perm.flatten(1)  # [permutation basis, input basis] (a,omega)
        PP = P @ P.T  # (a,a)

        for ir in Ps:
            mul = len(Ps[ir])
            paths = [path for path, _ in Ps[ir]]
            base_o3 = torch.stack([R for _, R in Ps[ir]])

            R = base_o3.flatten(2)  # [multiplicity, ir, input basis] (u,j,omega)

            Xs = []
            for j in range(ir.dim):
                RR = R[:, j] @ R[:, j].T  # (u,u)
                RP = R[:, j] @ P.T  # (u,a)

                # prob == M @ M.T with M = cat([R[:, j], -P]), so its null space holds the
                # coefficient pairs (x, y) with x @ R[:, j] == y @ P: combinations of the o3
                # basis that lie in the span of the permutation basis.
                prob = torch.cat([
                    torch.cat([RR, -RP], dim=1),
                    torch.cat([-RP.T, PP], dim=1)
                ], dim=0)
                eigenvalues, eigenvectors = torch.linalg.eigh(prob)
                # keep only the multiplicity (x) part of the null-space vectors
                X = eigenvectors[:, eigenvalues < eps][:mul].T  # [solutions, multiplicity]
                X = torch.linalg.qr(X, mode='r').R
                # Canonical sign: make each row's first entry above eps positive
                # (x is a view, so X is modified in place).
                for i, x in enumerate(X):
                    for k in range(i, mul):
                        if x[k] < eps:
                            x.neg_()
                        if x[k] > eps:
                            break

                X[X.abs() < eps] = 0
                X = sorted([[x.item() for x in line] for line in X])
                X = torch.tensor(X, dtype=torch.float64)

                Xs.append(X)

            # Every component j of the irrep must yield the same solution set.
            for X in Xs:
                assert (X - Xs[0]).abs().max() < eps

            X = Xs[0]
            for x in X:
                C = torch.einsum("u,ui...->i...", x, base_o3)
                # rescale so that the squared norm of C equals ir.dim
                correction = (ir.dim / C.pow(2).sum())**0.5
                C = correction * C

                outputs.append([((correction * v).item(), p) for v, p in zip(x, paths) if v.abs() > eps])
                change_of_basis.append(C)
                irreps_out.append((1, ir))

        dtype, _ = explicit_default_types(None, None)
        self.register_buffer('change_of_basis', torch.cat(change_of_basis).to(dtype=dtype))

        # One TensorProduct module per distinct (ir_left, ir_right, ir_out) operation used.
        tps = set()
        for vp_list in outputs:
            for v, p in vp_list:
                for op in _get_ops(p):
                    tps.add(op)

        tps = list(tps)
        for i, op in enumerate(tps):
            tp = o3.TensorProduct(op[0], op[1], op[2], [(0, 0, 0, 'uuu', False)])
            setattr(self, f'tp{i}', tp)

        graph = fx.Graph()
        inputs = [
            fx.Proxy(graph.placeholder(f"x{i}", torch.Tensor))
            for i in f0
        ]

        self.irreps_in = [irreps[i] for i in f0]
        self.irreps_out = o3.Irreps(irreps_out).simplify()

        # memoization so shared sub-paths are traced into the graph only once
        values = dict()

        def evaluate(path):
            if path in values:
                return values[path]

            if isinstance(path, _INPUT):
                out = inputs[path.tensor]
                # narrow only when the path covers a strict sub-range of the input
                if (path.start, path.stop) != (0, self.irreps_in[path.tensor].dim):
                    out = out.narrow(-1, path.start, path.stop - path.start)
            if isinstance(path, _TP):
                x1 = evaluate(path.args[0]).node
                x2 = evaluate(path.args[1]).node
                out = fx.Proxy(graph.call_module(f'tp{tps.index(path.op)}', (x1, x2)))
            values[path] = out
            return out

        # Each output is a linear combination of path results; skip multiplying by 1.
        outs = []
        for vp_list in outputs:
            v, p = vp_list[0]
            out = evaluate(p)
            if abs(v - 1.0) > eps:
                out = v * out
            for v, p in vp_list[1:]:
                t = evaluate(p)
                if abs(v - 1.0) > eps:
                    t = v * t
                out = out + t
            outs.append(out)

        out = torch.cat(outs, dim=-1)
        graph.output(out.node)

        self.graph = graph
        self.recompile()