コード例 #1
0
def test_gwtp():
    """Smoke test: build a GroupedWeightedTensorProduct and run it once."""
    spec = "1e + 2e + 3x3o"
    irreps_in1 = o3.Irreps(spec)
    irreps_in2 = o3.Irreps(spec)
    irreps_out = o3.Irreps(spec)

    module = GroupedWeightedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    print(module)

    # one 1-D random input per operand
    x1 = torch.randn(irreps_in1.dim)
    x2 = torch.randn(irreps_in2.dim)
    module(x1, x2)
コード例 #2
0
def GroupedWeightedTensorProduct(irreps_in1,
                                 irreps_in2,
                                 irreps_out,
                                 groups=math.inf,
                                 normalization='component',
                                 internal_weights=True,
                                 shared_weights=True):
    """Fully connected ('uvw') weighted tensor product restricted to groups.

    Every entry of ``irreps_in1`` and ``irreps_out`` is cut into ``groups``
    consecutive chunks: chunk ``g`` gets ``mul // groups`` copies, plus one
    extra copy for the first ``mul % groups`` chunks. A path from chunk
    ``i_1`` of input 1 to chunk ``i_out`` of the output is created only when
    both belong to the same group (``i_1 % groups == i_out % groups``) and
    the selection rules ``|l_1 - l_2| <= l_out <= l_1 + l_2`` and
    ``p_1 * p_2 == p_out`` hold. ``irreps_in2`` is not split.

    ``groups`` is clamped to the smallest multiplicity found in
    ``irreps_in1`` and ``irreps_out`` (the default ``math.inf`` therefore
    means "as many groups as that allows"). Every input/output gets variance
    1.0 and every path has weights.

    Raises ValueError (from ``min``) if ``irreps_in1`` or ``irreps_out`` is
    empty.
    """
    irreps_in1 = o3.Irreps(irreps_in1)
    irreps_in2 = o3.Irreps(irreps_in2)
    irreps_out = o3.Irreps(irreps_out)

    smallest_in1 = min(mul for mul, _ in irreps_in1)
    smallest_out = min(mul for mul, _ in irreps_out)
    groups = min(groups, smallest_in1, smallest_out)

    def split(irreps):
        # One entry per (irrep, group), spreading mul as evenly as possible.
        chunks = []
        for mul, (l, p) in irreps:
            base, extra = divmod(mul, groups)
            chunks.extend((base + (g < extra), (l, p)) for g in range(groups))
        return chunks

    irreps_in1 = split(irreps_in1)
    irreps_out = split(irreps_out)

    in1 = [(mul, ir, 1.0) for mul, ir in irreps_in1]
    in2 = [(mul, ir, 1.0) for mul, ir in irreps_in2]
    out = [(mul, ir, 1.0) for mul, ir in irreps_out]

    instr = []
    for i_1, (_, (l_1, p_1)) in enumerate(irreps_in1):
        for i_2, (_, (l_2, p_2)) in enumerate(irreps_in2):
            for i_out, (_, (l_out, p_out)) in enumerate(irreps_out):
                if not abs(l_1 - l_2) <= l_out <= l_1 + l_2:
                    continue
                if p_1 * p_2 != p_out:
                    continue
                if i_1 % groups != i_out % groups:
                    continue
                instr.append((i_1, i_2, i_out, 'uvw', True, 1.0))

    return WeightedTensorProduct(in1, in2, out, instr, normalization,
                                 internal_weights, shared_weights)
コード例 #3
0
def reduce_tensor(formula, eps=1e-9, has_parity=True, **kw_irreps):
    """Reduce a tensor with index symmetries into irreducible representations.

    Usage::

        irreps, Q = reduce_tensor('ijkl=jikl=ikjl=ijlk', i='1o')
        # irreps contains l = 0, 2, 4
        # Q is a tensor of shape [15, 3, 3, 3, 3]

    Parameters
    ----------
    formula : str
        Index symmetries, e.g. ``'ijkl=jikl=ikjl=ijlk'``.
    eps : float
        Passed through to ``group.reduce_tensor``.
    has_parity : bool
        Reduce under ``group.O3()`` if True, otherwise under ``group.SO3()``.
        In the latter case every returned irrep is given parity +1.
    **kw_irreps
        One entry per index name. The value is either a callable that takes
        the unpacked group element and returns its representation matrix, or
        anything accepted by ``o3.Irreps``, in which case its ``D`` matrix is
        used.

    Returns
    -------
    irreps : o3.Irreps
    Q : the change-of-basis tensor returned by ``group.reduce_tensor``
    """
    gr = group.O3() if has_parity else group.SO3()

    # Factory closures bind each representation at creation time. A lambda
    # written directly in the loop would late-bind the loop variable, and
    # every index would end up using the *last* representation.
    def _from_callable(fn):
        return lambda g: fn(*g)

    def _from_irreps(irreps):
        return lambda g: irreps.D(*g)

    kw_representations = {}
    for name, rep in kw_irreps.items():
        if callable(rep):
            kw_representations[name] = _from_callable(rep)
        else:
            kw_representations[name] = _from_irreps(o3.Irreps(rep))

    irreps, Q = group.reduce_tensor(gr, formula, eps, **kw_representations)

    if has_parity:
        irreps = o3.Irreps(irreps)
    else:
        # SO(3) results carry no parity: label everything as even.
        irreps = o3.Irreps([(mul, l, 1) for mul, l in irreps])

    return irreps, Q
コード例 #4
0
    def __init__(self, irreps_in, irreps_out):
        """Store matching input/output irreps and an all-ones output mask.

        Parameters
        ----------
        irreps_in, irreps_out
            Anything accepted by ``o3.Irreps``. Both are simplified and must
            then compare equal (checked with ``assert``).
        """
        super().__init__()

        self.irreps_in = o3.Irreps(irreps_in).simplify()
        self.irreps_out = o3.Irreps(irreps_out).simplify()

        assert self.irreps_in == self.irreps_out

        # One entry per output component (sum of mul * (2l + 1) == dim).
        # Built from ``dim`` directly, so an empty irreps yields an empty mask
        # instead of ``torch.cat`` failing on an empty list.
        output_mask = torch.ones(self.irreps_out.dim)
        self.register_buffer('output_mask', output_mask)
コード例 #5
0
def test():
    """Run a WeightedTensorProduct built from one hand-written 'uvw' path."""
    spec = "1e + 2e + 3x3o"
    irreps_in1 = o3.Irreps(spec)
    irreps_in2 = o3.Irreps(spec)
    irreps_out = o3.Irreps(spec)

    def with_unit_variance(irreps):
        return [(mul, ir, 1.0) for mul, ir in irreps]

    # single path in1[1] x in2[1] -> out[1], i.e. 2e x 2e -> 2e
    m = WeightedTensorProduct(
        with_unit_variance(irreps_in1),
        with_unit_variance(irreps_in2),
        with_unit_variance(irreps_out),
        [(1, 1, 1, 'uvw', True, 1.0)],
    )
    x1 = torch.randn(irreps_in1.dim)
    x2 = torch.randn(irreps_in2.dim)
    m(x1, x2)
コード例 #6
0
ファイル: qm9.py プロジェクト: yimaverickxia/e3nn_little
    def __init__(self, irreps_in, irreps_out, irreps_sh, rad_features):
        """Build the sub-modules of the convolution.

        - ``si``: Linear ``irreps_in -> irreps_out``
        - ``lin1``: Linear ``irreps_in -> irreps_in``
        - ``tp``: 'uvu' tensor product of ``irreps_in`` with ``irreps_sh``.
          Only (l, p) pairs present in ``irreps_out`` are kept. Its weights
          are supplied from outside (``internal_weights=False``,
          ``shared_weights=False``).
        - ``ws``: one ``FC`` per weight tensor of ``tp``, mapping
          ``rad_features`` inputs to that tensor's flattened size
        - ``lin2``: Linear from the output irreps of ``tp`` to ``irreps_out``

        ``irreps_in``, ``irreps_out`` and ``irreps_sh`` must already be Irreps
        objects: ``.simplify()`` is called on them directly.
        """
        super().__init__(aggr='add')
        self.irreps_in = irreps_in.simplify()
        self.irreps_out = irreps_out.simplify()
        self.irreps_sh = irreps_sh.simplify()

        self.si = Linear(self.irreps_in, self.irreps_out)
        self.lin1 = Linear(self.irreps_in, self.irreps_in)

        # (l, p) pairs present in the output irreps
        targets = {(l, p) for _, (l, p) in self.irreps_out}

        # Intermediate irreps as (mul, l, p) triples, one per distinct triple
        # in order of first appearance. 'uvu' keeps the multiplicity of input 1.
        mid = []
        instr = []
        for i_1, (mul_1, (l_1, p_1)) in enumerate(self.irreps_in):
            for i_2, (_, (l_2, p_2)) in enumerate(self.irreps_sh):
                p_out = p_1 * p_2
                for l_out in range(abs(l_1 - l_2), l_1 + l_2 + 1):
                    if (l_out, p_out) not in targets:
                        continue
                    triple = (mul_1, l_out, p_out)
                    if triple not in mid:
                        mid.append(triple)
                    instr.append((i_1, i_2, mid.index(triple), 'uvu', True, 1.0))
        mid = o3.Irreps(mid)

        def unit_var(irreps):
            return [(mul, ir, 1.0) for mul, ir in irreps]

        self.tp = WeightedTensorProduct(
            unit_var(self.irreps_in),
            unit_var(self.irreps_sh),
            unit_var(mid),
            instr,
            internal_weights=False,
            shared_weights=False,
        )
        self.ws = torch.nn.ModuleList([
            FC((rad_features, prod(shape)), variance_out=1 / prod(shape))
            for shape in self.tp.weight_shapes
        ])
        self.lin2 = Linear(mid, self.irreps_out)
コード例 #7
0
ファイル: gate.py プロジェクト: yimaverickxia/e3nn_little
    def __init__(self, irreps, acts):
        '''
        Apply pointwise activations to scalar fields.

        Can be used only with scalar fields: every output entry is recorded
        with l = 0, and the l of the input is not checked here.

        :param irreps: input irreps; ``.simplify()`` is called on it
        :param acts: list of tuple (multiplicity, activation). A multiplicity
            of -1 takes every channel not claimed by the positive entries.
            After that substitution, the multiplicities must add up to the
            total multiplicity of ``irreps``.
        :raises ValueError: if an input with a definite parity goes through
            an activation that is neither even nor odd.
        '''
        super().__init__()

        irreps = irreps.simplify()

        # n1: channels available in the input; n2: channels explicitly
        # claimed by ``acts`` (entries with a positive multiplicity).
        n1 = sum(mul for mul, _ in irreps)
        n2 = sum(mul for mul, _ in acts if mul > 0)

        # normalize the second moment
        acts = [(mul, normalize2mom(act)) for mul, act in acts]

        # A multiplicity of -1 takes whatever channels remain.
        for i, (mul, act) in enumerate(acts):
            if mul == -1:
                acts[i] = (n1 - n2, act)
                assert n1 - n2 >= 0

        assert n1 == sum(mul for mul, _ in acts)

        # Split entries of ``irreps`` and ``acts`` in place until they line up
        # one to one, i.e. irreps[i] and acts[i] have the same multiplicity.
        irreps = list(irreps)
        i = 0
        while i < len(irreps):
            mul_r, (l, p_r) = irreps[i]
            mul_a, act = acts[i]

            if mul_r < mul_a:
                acts[i] = (mul_r, act)
                acts.insert(i + 1, (mul_a - mul_r, act))

            if mul_a < mul_r:
                irreps[i] = (mul_a, (l, p_r))
                irreps.insert(i + 1, (mul_r - mul_a, (l, p_r)))
            i += 1

        # Sample points used to probe whether each activation is even or odd.
        x = torch.linspace(0, 10, 256)

        irreps_out = []
        for (mul, (l, p_in)), (mul_a, act) in zip(irreps, acts):
            assert mul == mul_a

            # p_act = 1 if act is even, -1 if odd, 0 if neither
            # (relative tolerance 1e-10 on the sampled values).
            a1, a2 = act(x), act(-x)
            if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
                p_act = 1
            elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
                p_act = -1
            else:
                p_act = 0

            # An odd input takes the parity of the activation; any other
            # input parity (1 or 0) is kept as is.
            p = p_act if p_in == -1 else p_in
            # NOTE(review): recorded with l = 0 even when ``l`` is not 0;
            # non-scalar inputs are not rejected here.
            irreps_out.append((mul, (0, p)))

            if p_in != 0 and p == 0:
                raise ValueError("warning! the parity is violated")

        self.irreps_out = o3.Irreps(irreps_out).simplify()
        self.acts = acts
コード例 #8
0
def ElementwiseTensorProduct(irreps_in1,
                             irreps_in2,
                             normalization='component'):
    """Copy-by-copy tensor product of two inputs with the same number of irreps.

    After simplification, the entries of both inputs are cut so that the
    k-th entry of each has the same multiplicity. The k-th copy of input 1 is
    then multiplied only with the k-th copy of input 2 ('uuu' paths, no
    weights). Every l from ``|l_1 - l_2|`` to ``l_1 + l_2`` is produced with
    parity ``p_1 * p_2``.

    ``irreps_in1.num_irreps == irreps_in2.num_irreps`` is checked with
    ``assert``.
    """
    irreps_in1 = o3.Irreps(irreps_in1).simplify()
    irreps_in2 = o3.Irreps(irreps_in2).simplify()

    assert irreps_in1.num_irreps == irreps_in2.num_irreps

    left = list(irreps_in1)
    right = list(irreps_in2)

    # Cut entries so that left[k] and right[k] share the same multiplicity.
    k = 0
    while k < len(left):
        mul_l, (l_l, p_l) = left[k]
        mul_r, (l_r, p_r) = right[k]
        if mul_l < mul_r:
            right[k:k + 1] = [(mul_l, (l_r, p_r)), (mul_r - mul_l, (l_r, p_r))]
        elif mul_r < mul_l:
            left[k:k + 1] = [(mul_r, (l_l, p_l)), (mul_l - mul_r, (l_l, p_l))]
        k += 1

    irreps_out = []
    instr = []
    for k, ((mul, (l_1, p_1)), (mul_2, (l_2, p_2))) in enumerate(zip(left, right)):
        assert mul == mul_2
        for l in range(abs(l_1 - l_2), l_1 + l_2 + 1):
            instr.append((k, k, len(irreps_out), 'uuu', False, 1.0))
            irreps_out.append((mul, (l, p_1 * p_2)))

    def unit_var(irreps):
        return [(mul, ir, 1.0) for mul, ir in irreps]

    return WeightedTensorProduct(unit_var(left),
                                 unit_var(right),
                                 unit_var(irreps_out),
                                 instr,
                                 normalization,
                                 internal_weights=False)
コード例 #9
0
ファイル: qm9.py プロジェクト: yimaverickxia/e3nn_little
    def __init__(self, muls=(128, 12, 0), lmax=1,
                 num_layers=3, cutoff=10.0, rad_gaussians=50,
                 rad_hs=(512, 512), num_neighbors=20,
                 readout='add', mean=None, std=None, scale=None,
                 atomref=None):
        """Build the network: atom embedding, radial MLP and gated conv layers.

        :param muls: multiplicities passed to ``make_gated_block``;
            ``muls[0]`` is also the width of the atom embedding
        :param lmax: highest degree l of ``irreps_sh`` (parity ``(-1)**l``)
        :param num_layers: number of (Conv, gate) pairs before the final Conv
        :param cutoff: stored, and passed to ``GaussianBasis``
        :param rad_gaussians: first argument of ``GaussianBasis`` and input
            width of the radial FC
        :param rad_hs: hidden sizes of the radial FC; ``rad_hs[-1]`` is the
            radial feature size given to every Conv
        :param num_neighbors, mean, std, scale: only stored here (presumably
            used in ``forward``, which is not shown)
        :param readout: one of 'add', 'sum', 'mean' (checked with assert)
        :param atomref: optional tensor copied into a frozen
            ``Embedding(100, 1)``; also registered as the ``initial_atomref``
            buffer (None if not given)
        """
        super().__init__()

        assert readout in ['add', 'sum', 'mean']
        self.readout = readout
        self.cutoff = cutoff
        self.mean = mean
        self.std = std
        self.scale = scale
        self.num_neighbors = num_neighbors

        # 100-entry lookup table (presumably indexed by atomic number),
        # not trained: its weights keep their initial values.
        self.embedding = Embedding(100, muls[0])
        self.embedding.weight.requires_grad = False
        self.irreps_in = o3.Irreps([(muls[0], 0, 1)])

        # Radial features: Gaussian basis followed by a swish MLP.
        self.radial = torch.nn.Sequential(
            GaussianBasis(rad_gaussians, cutoff),
            FC((rad_gaussians, ) + rad_hs, swish, variance_in=1 / rad_gaussians, out_act=True)
        )
        self.irreps_sh = o3.Irreps([(1, l, (-1)**l) for l in range(lmax + 1)])  # spherical harmonics representation

        # Each layer: the gate is built first so that the Conv can target
        # exactly the gate's input irreps; the gate output feeds the next layer.
        irreps = self.irreps_in
        modules = []
        for _ in range(num_layers):
            act = make_gated_block(irreps, muls, self.irreps_sh)
            conv = Conv(irreps, act.irreps_in, self.irreps_sh, rad_hs[-1])
            irreps = act.irreps_out.simplify()

            modules += [torch.nn.ModuleList([conv, act])]

        self.layers = torch.nn.ModuleList(modules)

        # Final Conv to one even and one odd scalar.
        self.irreps_out = o3.Irreps("0e + 0o")
        self.layers.append(Conv(irreps, self.irreps_out, self.irreps_sh, rad_hs[-1]))

        self.register_buffer('initial_atomref', atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)
            self.atomref.weight.requires_grad = False
コード例 #10
0
def test_creation():
    """Irrep / Irreps accept the input formats they are meant to support."""
    # Irrep from (l, p), from a string, and from another Irrep
    o3.Irrep(3, 1)
    ir = o3.Irrep("3e")
    o3.Irrep(ir)
    assert o3.Irrep('10o') == o3.Irrep(10, -1)
    # the 'y' suffix is expected to give '1o' for l = 1
    assert o3.Irrep("1y") == o3.Irrep("1o")

    # Irreps from an Irrep and from another Irreps
    irreps = o3.Irreps(ir)
    o3.Irreps(irreps)

    # Irreps from (mul, (l, p)) lists, strings and mixed entry formats
    o3.Irreps([(32, (4, -1))])
    o3.Irreps("11e")
    assert o3.Irreps("16x1e + 32 x 2o") == o3.Irreps([(16, 1, 1), (32, 2, -1)])
    for spec in (["1e", '2o'],
                 [(16, "3e"), '1e'],
                 [(16, "3e"), '1e', (256, 1, -1)]):
        o3.Irreps(spec)
コード例 #11
0
ファイル: gate.py プロジェクト: yimaverickxia/e3nn_little
    def __init__(self, *irreps_outs):
        """Store each output irreps (simplified) and the merged input irreps.

        ``irreps_in`` is built from all entries of all outputs, stably sorted
        by (l, p) and then simplified.
        """
        super().__init__()
        self.irreps_outs = tuple(irreps.simplify() for irreps in irreps_outs)

        def by_irrep(mul_ir):
            _, (l, p) = mul_ir
            return l, p

        entries = [mul_ir for irreps in self.irreps_outs for mul_ir in irreps]
        entries.sort(key=by_irrep)
        self.irreps_in = o3.Irreps(entries).simplify()
コード例 #12
0
def FullyConnectedWeightedTensorProduct(irreps_in1,
                                        irreps_in2,
                                        irreps_out,
                                        normalization='component',
                                        internal_weights=True,
                                        shared_weights=True):
    """Weighted tensor product with a 'uvw' path for every allowed triple.

    All three irreps are simplified first. A weighted path
    ``(i_1, i_2, i_out)`` is created whenever
    ``|l_1 - l_2| <= l_out <= l_1 + l_2`` and ``p_1 * p_2 == p_out``.
    Every input/output gets variance 1.0.
    """
    irreps_in1, irreps_in2, irreps_out = [
        o3.Irreps(irreps).simplify()
        for irreps in (irreps_in1, irreps_in2, irreps_out)
    ]

    def unit_var(irreps):
        return [(mul, ir, 1.0) for mul, ir in irreps]

    def allowed(ir_1, ir_2, ir_out):
        (l_1, p_1), (l_2, p_2), (l_out, p_out) = ir_1, ir_2, ir_out
        return abs(l_1 - l_2) <= l_out <= l_1 + l_2 and p_1 * p_2 == p_out

    instr = [
        (i_1, i_2, i_out, 'uvw', True, 1.0)
        for i_1, (_, ir_1) in enumerate(irreps_in1)
        for i_2, (_, ir_2) in enumerate(irreps_in2)
        for i_out, (_, ir_out) in enumerate(irreps_out)
        if allowed(ir_1, ir_2, ir_out)
    ]
    return WeightedTensorProduct(unit_var(irreps_in1), unit_var(irreps_in2),
                                 unit_var(irreps_out), instr, normalization,
                                 internal_weights, shared_weights)
コード例 #13
0
def test_reduce_tensor_equivariance():
    """``Q`` from ``o3.reduce_tensor`` intertwines the two representations.

    Rotating all four tensor indices of ``Q`` with ``R`` must equal mixing its
    first axis with ``irreps.D``. Runs in float64 and restores the previous
    default dtype afterwards, so the global setting does not leak into other
    tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        ir = o3.Irreps('1e')
        irreps, Q = o3.reduce_tensor('ijkl=jikl=klij', i=ir)

        abc = o3.rand_angles()
        R = ir.D(*abc)
        D = irreps.D(*abc)

        q1 = torch.einsum('qmnop,mi,nj,ok,pl->qijkl', Q, R, R, R, R)
        q2 = torch.einsum('qa,aijkl->qijkl', D, Q)

        assert (q1 - q2).abs().max() < 1e-10
    finally:
        torch.set_default_dtype(previous_dtype)
コード例 #14
0
ファイル: qm9.py プロジェクト: yimaverickxia/e3nn_little
def make_gated_block(irreps_in, muls, irreps_sh):
    """Make a Gate whose inputs can be produced from ``irreps_in`` x ``irreps_sh``.

    Assumes ``muls[l]`` is the wanted multiplicity for degree ``l``. Every
    (l, p) reachable from the product of ``irreps_in`` (simplified) with
    ``irreps_sh`` is collected, and then:

    - scalars: ``muls[0]`` copies of each reachable 0e / 0o, activated with
      ``swish`` (even) or ``torch.tanh`` (odd);
    - non-scalars: ``muls[l]`` copies of each reachable parity of every
      ``l`` in ``1 .. len(muls) - 1``;
    - gates: one scalar per non-scalar channel. These are 0e with
      ``torch.sigmoid`` when 0e is reachable, otherwise 0o with ``torch.tanh``.
    """
    available = {
        (l, p_in * p_sh)
        for _, (l_in, p_in) in irreps_in.simplify()
        for _, (l_sh, p_sh) in irreps_sh
        for l in range(abs(l_in - l_sh), l_in + l_sh + 1)
    }

    scalars = o3.Irreps([(muls[0], 0, p) for p in (1, -1) if (0, p) in available])
    act_scalars = [(mul, swish if p == 1 else torch.tanh) for mul, (_, p) in scalars]

    nonscalars = o3.Irreps([
        (muls[l], l, p * (-1)**l)
        for l in range(1, len(muls))
        for p in (1, -1)
        if (l, p * (-1)**l) in available
    ])

    if (0, +1) in available:
        gate_parity, gate_act = +1, torch.sigmoid
    else:
        gate_parity, gate_act = -1, torch.tanh
    gates = o3.Irreps([(nonscalars.num_irreps, 0, gate_parity)])
    act_gates = [(-1, gate_act)]

    return Gate(scalars, act_scalars, gates, act_gates, nonscalars)
コード例 #15
0
    def __init__(self) -> None:
        """Two fully connected tensor products.

        ``tp1`` maps ``irreps_sh x irreps_sh`` to a hidden irreps; ``tp2`` maps
        ``hidden x irreps_sh`` to ``0o + 6x0e``. ``self.sh`` lists the degrees
        used for ``irreps_sh`` (multiplicity 1, parity ``(-1)**l``).
        """
        super().__init__()
        self.sh = [0, 1, 2, 3]
        irreps_sh = o3.Irreps([(1, l, (-1)**l) for l in self.sh])
        hidden = "64x0e + 24x1e + 24x1o + 16x2e + 16x2o"

        self.tp1 = FullyConnectedWeightedTensorProduct(
            irreps_in1=irreps_sh,
            irreps_in2=irreps_sh,
            irreps_out=hidden,
        )
        self.tp2 = FullyConnectedWeightedTensorProduct(
            irreps_in1=hidden,
            irreps_in2=irreps_sh,
            irreps_out="0o + 6x0e",
        )
コード例 #16
0
def test_sh_equivariance2():
    """Spherical harmonics commute with rotations.

    Evaluating on rotated points (``x @ R.T``) must equal transforming the
    unrotated result with ``irreps.D`` (``@ D.T``), for ``0e + 1o + 2e + 3o +
    4e``. Runs in float64 and restores the previous default dtype afterwards,
    so the global setting does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        irreps = o3.Irreps("0e + 1o + 2e + 3o + 4e")

        abc = o3.rand_angles()
        R = o3.rot(*abc)
        D = irreps.D(*abc)

        x = torch.randn(10, 3)

        y1 = o3.spherical_harmonics(irreps, x @ R.T)
        y2 = o3.spherical_harmonics(irreps, x) @ D.T

        assert (y1 - y2).abs().max() < 1e-10
    finally:
        torch.set_default_dtype(previous_dtype)
コード例 #17
0
def test_id():
    """Smoke test: build Identity on matching irreps and run it once."""
    spec = "1e + 2e + 3x3o"
    irreps_in, irreps_out = o3.Irreps(spec), o3.Irreps(spec)
    module = Identity(irreps_in, irreps_out)
    print(module)
    module(torch.randn(irreps_in.dim))
コード例 #18
0
def test_slice():
    """Slicing an Irreps returns an Irreps."""
    irreps = o3.Irreps("16x1e + 3e + 2e + 5o")
    tail = irreps[2:]
    assert isinstance(tail, o3.Irreps)
コード例 #19
0
def test_cat():
    """Adding two Irreps concatenates their entries without merging them."""
    left = o3.Irreps("4x1e + 6x2e + 12x2o")
    right = o3.Irreps("1x1e + 2x2e + 12x2o")
    irreps = left + right
    assert len(irreps) == 6
    assert irreps.num_irreps == sum([4, 6, 12, 1, 2, 12])
コード例 #20
0
def test_fail1():
    """``(mul, l)`` entries without a parity must be rejected by ``o3.Irreps``.

    The original body only built the object, so a correct rejection made the
    test error out instead of passing.
    NOTE(review): the exact exception type is not visible here; both
    AssertionError and ValueError are accepted — confirm against o3.Irreps.
    """
    import pytest

    with pytest.raises((AssertionError, ValueError)):
        o3.Irreps([(32, 1)])
コード例 #21
0
    def __init__(
        self,
        in1: List[Tuple[int, Any, float]],
        in2: List[Tuple[int, Any, float]],
        out: List[Tuple[int, Any, float]],
        instr: List[Tuple[int, int, int, str, bool, float]],
        normalization: str = 'component',
        internal_weights: bool = True,
        shared_weights: bool = True,
        _specialized_code=True,
    ):
        """Tensor Product with parametrizable paths

        Parameters
        ----------
        in1
            List of inputs (multiplicity, (l, p), variance).
        in2
            List of inputs (multiplicity, (l, p), variance).
        out
            List of outputs (multiplicity, (l, p), variance).
        instr
            List of instructions (i_1, i_2, i_out, mode, train, path_weight)
            it means: Put `in1[i_1]` otimes `in2[i_2]` into `out[i_out]`
            - mode: determines the way the multiplicities are treated, "uvw" is fully connected
            - train: is this path trained?
            - path weight: how much this path should contribute to the output
        """

        super().__init__()

        assert normalization in ['component', 'norm'], normalization
        self.irreps_in1 = o3.Irreps([(mul, ir) for mul, ir, _var in in1])
        self.irreps_in2 = o3.Irreps([(mul, ir) for mul, ir, _var in in2])
        self.irreps_out = o3.Irreps([(mul, ir) for mul, ir, _var in out])

        in1_var = [var for _, _, var in in1]
        in2_var = [var for _, _, var in in2]
        out_var = [var for _, _, var in out]

        self.shared_weights = shared_weights
        z = '' if self.shared_weights else 'z'

        code = f"""
from typing import List

import torch

@torch.jit.script
def main(x1: torch.Tensor, x2: torch.Tensor, ws: List[torch.Tensor], w3j: List[torch.Tensor]) -> torch.Tensor:
    batch = x1.shape[0]
    out = x1.new_zeros((batch, {self.irreps_out.dim}))
    ein = torch.einsum
"""

        wshapes = []
        wigners = []

        for i_1, (mul_1, (l_1, p_1)) in enumerate(self.irreps_in1):
            index_1 = self.irreps_in1[:i_1].dim
            dim_1 = mul_1 * (2 * l_1 + 1)
            code += f"    x1_{i_1} = x1[:, {index_1}:{index_1+dim_1}].reshape(batch, {mul_1}, {2 * l_1 + 1})\n"
        code += f"\n"

        for i_2, (mul_2, (l_2, p_2)) in enumerate(self.irreps_in2):
            index_2 = self.irreps_in2[:i_2].dim
            dim_2 = mul_2 * (2 * l_2 + 1)
            code += f"    x2_{i_2} = x2[:, {index_2}:{index_2+dim_2}].reshape(batch, {mul_2}, {2 * l_2 + 1})\n"
        code += f"\n"

        last_ss = None

        for i_1, i_2, i_out, mode, weight, path_weight in instr:
            mul_1, (l_1, p_1) = self.irreps_in1[i_1]
            mul_2, (l_2, p_2) = self.irreps_in2[i_2]
            mul_out, (l_out, p_out) = self.irreps_out[i_out]
            dim_1 = mul_1 * (2 * l_1 + 1)
            dim_2 = mul_2 * (2 * l_2 + 1)
            dim_out = mul_out * (2 * l_out + 1)
            index_1 = self.irreps_in1[:i_1].dim
            index_2 = self.irreps_in2[:i_2].dim
            index_out = self.irreps_out[:i_out].dim

            assert p_1 * p_2 == p_out
            assert abs(l_1 - l_2) <= l_out <= l_1 + l_2

            if dim_1 == 0 or dim_2 == 0 or dim_out == 0:
                continue

            alpha = out_var[i_out] / sum(
                path_weight_ * in1_var[i_1_] * in2_var[i_2_]
                for i_1_, i_2_, i_out_, _, _, path_weight_ in instr
                if i_out_ == i_out)

            code += (
                f"    with torch.autograd.profiler.record_function("
                f"'{self.irreps_in1[i_1:i_1+1]} x {self.irreps_in2[i_2:i_2+1]} "
                f"= {self.irreps_out[i_out:i_out+1]} {mode} {weight}'):\n")
            code += f"        s1 = x1_{i_1}\n"
            code += f"        s2 = x2_{i_2}\n"

            assert mode in ['uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv']

            c = sqrt(
                alpha * path_weight / {
                    'uvw': (mul_1 * mul_2),
                    'uvu': mul_2,
                    'uvv': mul_1,
                    'uuw': mul_1,
                    'uuu': 1,
                    'uvuv': 1,
                }[mode])

            index_w = len(wshapes)
            if weight:
                wshapes.append({
                    'uvw': (mul_1, mul_2, mul_out),
                    'uvu': (mul_1, mul_2),
                    'uvv': (mul_1, mul_2),
                    'uuw': (mul_1, mul_out),
                    'uuu': (mul_1, ),
                    'uvuv': (mul_1, mul_2),
                }[mode])

            if _specialized_code:
                # optimized code for special cases:
                # 0 x 0 = 0
                # 0 x L = L
                # L x 0 = L
                # L x L = 0
                # 1 x 1 = 1

                if (l_1, l_2, l_out) == (0, 0, 0) and mode in [
                        'uvw', 'uvu'
                ] and normalization in ['component', 'norm'] and weight:
                    code += f"        s1 = s1.reshape(batch, {mul_1})\n"
                    code += f"        s2 = s2.reshape(batch, {mul_2})\n"

                    if mode == 'uvw':
                        code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uvw,zu,zv->zw', ws[{index_w}], s1, s2)\n\n"
                    if mode == 'uvu':
                        code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,zu,zv->zu', ws[{index_w}], s1, s2)\n\n"

                    continue

                if l_1 == 0 and l_2 == l_out and mode in [
                        'uvw', 'uvu'
                ] and normalization == 'component' and weight:
                    code += f"        s1 = s1.reshape(batch, {mul_1})\n"

                    if mode == 'uvw':
                        code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uvw,zu,zvi->zwi', ws[{index_w}], s1, s2).reshape(batch, {dim_out})\n\n"
                    if mode == 'uvu':
                        code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,zu,zvi->zui', ws[{index_w}], s1, s2).reshape(batch, {dim_out})\n\n"

                    continue

                if l_2 == 0 and l_1 == l_out and mode in [
                        'uvw', 'uvu'
                ] and normalization == 'component' and weight:
                    code += f"        s2 = s2.reshape(batch, {mul_2})\n"

                    if mode == 'uvw':
                        code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uvw,zui,zv->zwi', ws[{index_w}], s1, s2).reshape(batch, {dim_out})\n\n"
                    if mode == 'uvu':
                        code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,zui,zv->zui', ws[{index_w}], s1, s2).reshape(batch, {dim_out})\n\n"

                    continue

                if l_1 == l_2 and l_out == 0 and mode == 'uvw' and normalization == 'component' and weight:
                    # Cl_l_0 = eye(3) / sqrt(2L+1)
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uvw,zui,zvi->zw', ws[{index_w}] / {sqrt(2 * l_1 + 1)}, s1, s2).reshape(batch, {dim_out})\n\n"
                    continue

                if l_1 == l_2 and l_out == 0 and mode == 'uvu' and normalization == 'component' and weight:
                    # Cl_l_0 = eye(3) / sqrt(2L+1)
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,zui,zvi->zu', ws[{index_w}] / {sqrt(2 * l_1 + 1)}, s1, s2).reshape(batch, {dim_out})\n\n"
                    continue

                if (l_1, l_2, l_out) == (
                        1, 1, 1
                ) and mode == 'uvw' and normalization == 'component' and weight:
                    # C1_1_1 = levi-civita / sqrt(2)
                    code += f"        s1 = s1.reshape(batch, {mul_1}, 1, {2 * l_1 + 1})\n"
                    code += f"        s2 = s2.reshape(batch, 1, {mul_2}, {2 * l_2 + 1})\n"
                    code += f"        s1, s2 = torch.broadcast_tensors(s1, s2)\n"
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uvw,zuvi->zwi', ws[{index_w}] / {sqrt(2)}, torch.cross(s1, s2, dim=3)).reshape(batch, {dim_out})\n\n"
                    continue

                if (l_1, l_2, l_out) == (
                        1, 1, 1
                ) and mode == 'uvu' and normalization == 'component' and weight:
                    # C1_1_1 = levi-civita / sqrt(2)
                    code += f"        s1 = s1.reshape(batch, {mul_1}, 1, {2 * l_1 + 1})\n"
                    code += f"        s2 = s2.reshape(batch, 1, {mul_2}, {2 * l_2 + 1})\n"
                    code += f"        s1, s2 = torch.broadcast_tensors(s1, s2)\n"
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,zuvi->zui', ws[{index_w}] / {sqrt(2)}, torch.cross(s1, s2, dim=3)).reshape(batch, {dim_out})\n\n"
                    continue

            if last_ss != (i_1, i_2, mode[:2]):
                if mode[:2] == 'uv':
                    code += f"        ss = ein('zui,zvj->zuvij', s1, s2)\n"
                if mode[:2] == 'uu':
                    code += f"        ss = ein('zui,zuj->zuij', s1, s2)\n"
                last_ss = (i_1, i_2, mode[:2])

            if (l_1, l_2, l_out) in wigners:
                index_w3j = wigners.index((l_1, l_2, l_out))
            else:
                index_w3j = len(wigners)
                wigners += [(l_1, l_2, l_out)]

            if mode == 'uvw':
                assert weight
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uvw,ijk,zuvij->zwk', ws[{index_w}], w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
            if mode == 'uvu':
                assert mul_1 == mul_out
                if weight:
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,ijk,zuvij->zuk', ws[{index_w}], w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
                else:
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('ijk,zuvij->zuk', w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
            if mode == 'uvv':
                assert mul_2 == mul_out
                if weight:
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,ijk,zuvij->zvk', ws[{index_w}], w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
                else:
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('ijk,zuvij->zvk', w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
            if mode == 'uuw':
                assert mul_1 == mul_2
                assert weight
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uw,ijk,zuij->zwk', sw[{index_w}], w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
            if mode == 'uuu':
                assert mul_1 == mul_2 == mul_out
                if weight:
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}u,ijk,zuij->zuk', sw[{index_w}], w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
                else:
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('ijk,zuij->zuk', w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
            if mode == 'uvuv':
                assert mul_1 * mul_2 == mul_out
                if weight:
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,ijk,zuvij->zuvk', sw[{index_w}], w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
                else:
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('ijk,zuvij->zuvk', w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
            code += "\n"

        code += f"    return out"

        self.code = code
        self.main = eval_code(self.code).main

        # w3j
        self.wigners = wigners
        for i, (l_1, l_2, l_out) in enumerate(self.wigners):
            wig = o3.wigner_3j(l_1, l_2, l_out)

            if normalization == 'component':
                wig *= (2 * l_out + 1)**0.5
            if normalization == 'norm':
                wig *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5

            self.register_buffer(f"C{i}", wig)

        # weights
        self.weight_shapes = wshapes
        self.weight_numel = sum(
            math.prod(shape) for shape in self.weight_shapes)
        self.weight_infos = [
            (i_1, i_2, i_out, mode, path_weight, shape)
            for (i_1, i_2, i_out, mode, path_weight), shape in zip(
                [(i_1, i_2, i_out, mode, path_weight)
                 for i_1, i_2, i_out, mode, weight, path_weight in instr
                 if weight], wshapes)
        ]

        if internal_weights:
            assert self.shared_weights, "Having internal weights impose shared weights"
            self.weight = torch.nn.ParameterDict()
            for i, (i_1, i_2, i_out, mode, path_weight,
                    shape) in enumerate(self.weight_infos):
                mul_1, (l_1, p_1) = self.irreps_in1[i_1]
                mul_2, (l_2, p_2) = self.irreps_in2[i_2]
                mul_out, (l_out, p_out) = self.irreps_out[i_out]
                self.weight[
                    f'{i} l1={l_1} l2={l_2} lout={l_out}'] = torch.nn.Parameter(
                        torch.randn(shape))

        self.to(dtype=torch.get_default_dtype())