Beispiel #1
0
def kernel_conv_fn_forward(F, Y, R, norm_coef, Rs_in, Rs_out, selection_rule,
                           set_of_l_filters):
    """
    Contract features, spherical harmonics and radial weights into the output.

    For every (output irrep, input irrep) pair allowed by `selection_rule`,
    the Wigner 3j coefficients, the matching slice of `Y`, the matching slice
    of `R` and the matching slice of `F` are contracted in a single einsum
    that also sums over the `b` axis. The dense kernel is never materialised.

    :param F: tensor [batch, b, l_in * mul_in * m_in]
    :param Y: tensor [batch, a, b, l_filter * m_filter]
    :param R: tensor [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]
    :param norm_coef: tensor [l_out, l_in], indexed by the position of the
        irrep in `Rs_out` and `Rs_in`
    :param Rs_in: iterable of (mul, l, parity) describing the last axis of `F`
    :param Rs_out: iterable of (mul, l, parity) describing the output
    :param selection_rule: callable (l_in, p_in, l_out, p_out) returning the
        sized sequence of `l_filter` values coupling the two irreps
    :param set_of_l_filters: every `l_filter` stored in `Y`; the block of a
        given `l_filter` starts after the blocks of all smaller ones
    :return: tensor [batch, a, l_out * mul_out * m_out]
    """
    batch, a, b, _ = Y.shape
    n_out = rs.dim(Rs_out)

    kernel_conv = Y.new_zeros(batch, a, n_out)

    # note: for the normalization we assume that the variance of R[i] is one
    begin_R = 0

    begin_out = 0
    for i, (mul_out, l_out, p_out) in enumerate(Rs_out):
        s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
        begin_out += mul_out * (2 * l_out + 1)

        begin_in = 0
        for j, (mul_in, l_in, p_in) in enumerate(Rs_in):
            s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
            begin_in += mul_in * (2 * l_in + 1)

            l_filters = selection_rule(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            # extract the subset of the `R` that corresponds to the couple (l_out, l_in)
            n = mul_out * mul_in * len(l_filters)
            sub_R = R[:, :, :, begin_R:begin_R + n].reshape(
                batch, a, b, mul_out, mul_in,
                -1)  # [batch, a, b, mul_out, mul_in, l_filter]
            begin_R += n

            # The feature slice does not depend on `l_filter`: reshape it once
            # instead of once per filter.
            sub_F = F[..., s_in].reshape(
                batch, b, mul_in, -1)  # [batch, b, mul_in, m_in]

            K = 0
            for k, l_filter in enumerate(l_filters):
                offset = sum(2 * l + 1 for l in set_of_l_filters
                             if l < l_filter)
                sub_Y = Y[...,
                          offset:offset + 2 * l_filter + 1]  # [batch, a, b, m]

                C = o3.wigner_3j(l_out,
                                 l_in,
                                 l_filter,
                                 cached=True,
                                 like=kernel_conv)  # [m_out, m_in, m]

                K += norm_coef[i, j] * torch.einsum(
                    "ijk,zabk,zabuv,zbvj->zaui", C,
                    sub_Y, sub_R[..., k],
                    sub_F)  # [batch, a, mul_out, m_out]

            # `K` is still the int 0 only if the loop above never ran
            if not isinstance(K, int):
                kernel_conv[:, :, s_out] += K.reshape(batch, a, -1)

    return kernel_conv
Beispiel #2
0
def kernel_fn_forward(Y, R, norm_coef, Rs_in, Rs_out, selection_rule,
                      set_of_l_filters):
    """
    Assemble the dense kernel, one (output irrep, input irrep) block at a time.

    Each block is the sum over the allowed `l_filter` of the Wigner 3j
    coefficients contracted with the spherical harmonics and the radial
    weights, scaled by `norm_coef[i, j]`.

    :param Y: tensor [batch, l_filter * m_filter]
    :param R: tensor [batch, l_out * l_in * mul_out * mul_in * l_filter]
    :param norm_coef: tensor [l_out, l_in]
    :return: tensor [batch, l_out * mul_out * m_out, l_in * mul_in * m_in]
    """
    batch = Y.shape[0]
    n_in = rs.dim(Rs_in)
    n_out = rs.dim(Rs_out)

    kernel = Y.new_zeros(batch, n_out, n_in)

    # note: for the normalization we assume that the variance of R[i] is one
    r_cursor = 0  # position of the next unread radial weight in `R`

    out_cursor = 0
    for i, (mul_out, l_out, p_out) in enumerate(Rs_out):
        out_width = mul_out * (2 * l_out + 1)
        rows = slice(out_cursor, out_cursor + out_width)
        out_cursor += out_width

        in_cursor = 0
        for j, (mul_in, l_in, p_in) in enumerate(Rs_in):
            in_width = mul_in * (2 * l_in + 1)
            cols = slice(in_cursor, in_cursor + in_width)
            in_cursor += in_width

            l_filters = selection_rule(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            # radial weights belonging to the couple (l_out, l_in)
            n_filters = len(l_filters)
            n_weights = mul_out * mul_in * n_filters
            radial = R[:, r_cursor:r_cursor + n_weights].reshape(
                batch, mul_out, mul_in,
                n_filters)  # [batch, mul_out, mul_in, l_filter]
            r_cursor += n_weights

            # note: the blocks of `Y` have different widths (2 l + 1), hence
            # the explicit loop over the filters
            block = 0
            for k, l_filter in enumerate(l_filters):
                start = sum(2 * l + 1 for l in set_of_l_filters
                            if l < l_filter)
                harmonics = Y[:, start:start + 2 * l_filter + 1]  # [batch, m]

                C = o3.wigner_3j(l_out,
                                 l_in,
                                 l_filter,
                                 cached=True,
                                 like=kernel)  # [m_out, m_in, m]

                contribution = torch.einsum(
                    "ijk,zk,zuv->zuivj",
                    (C, harmonics,
                     radial[..., k]))  # [batch, mul_out, m_out, mul_in, m_in]
                block += norm_coef[i, j] * contribution

            if isinstance(block, int):
                # nothing was accumulated, leave the zeros in place
                continue
            target = kernel[:, rows, cols]
            kernel[:, rows, cols] = block.reshape_as(target)
    return kernel
def _generate_spherical_harmonics(lmax, device=None):  # pragma: no cover
    r"""Print Python source that evaluates spherical harmonics up to `lmax`.

    The emitted code assigns one variable `sh_{l}_{m}` per component and, for
    every `l`, an `if lmax == l: return torch.stack([...], dim=-1)` branch.
    The polynomials of degree `l + 1` are built from those of degree `l`
    with the coefficients of `o3.wigner_3j(l + 1, 1, l)`.

    Side effect: the torch default dtype is switched to float64 and is not
    restored.

    :param lmax: highest degree to generate code for
    :param device: forwarded to `o3.wigner_3j`
    """
    torch.set_default_dtype(torch.float64)

    print("sh_0_0 = torch.ones(x.shape, dtype=x.dtype, device=x.device)")
    print("if lmax == 0:")
    print("    return sh_0_0")
    print()

    y, z, x = sympy.symbols('y z x')
    # degree-1 polynomials, in the component order (y, z, x)
    polynomials = [y, z, x]
    # value of each polynomial at the point x = 0, y = 0, z = 1
    polynormz = [0, 1, 0]

    for l in range(1, lmax+1):
        # symbols standing for the variables printed for this degree
        names = sympy.symbols(" ".join(f'sh_{l}_{m}' for m in range(2 * l + 1)))

        for n, p in zip(names, polynomials):
            p = sympy.simplify(p)
            p = sympy.N(p, n=20)
            # presumably `pycode` is sympy's Python code printer -- confirm at the import site
            print(f"{n} = {pycode(p)}")

        print(f"if lmax == {l}:")
        # all components of degree 0..l, one degree per printed line
        u = ",\n        ".join(", ".join(f"sh_{j}_{m}" for m in range(2 * j + 1)) for j in range(l + 1))
        print(f"    return torch.stack([\n        {u}\n    ], dim=-1)")
        print()

        if l == lmax:
            break

        # degree l + 1 expressed in terms of (y, z, x) and the `names` of degree l
        polynomials = [
            sum(
                c.item() * v * p
                for cj, v in zip(cij, [y, z, x])
                for c, p in zip(cj, names)
            )
            for cij in o3.wigner_3j(l+1, 1, l, device=device)
        ]

        def sub(p, names, polynormz):
            # evaluate `p` at x = 0, y = 0, z = 1, replacing every degree-l
            # symbol by its value at that same point
            p = p.subs(x, 0).subs(y, 0).subs(z, 1)
            for n, c in zip(names, polynormz):
                p = p.subs(n, c)
            return p

        # uses the previous degree's values, so it must run before they are overwritten
        polynormz = [
            sub(p, names, polynormz)
            for p in polynomials
        ]
        # rescale so that the values at that point have unit Euclidean norm
        norm = sum(p ** 2 for p in polynormz) ** 0.5
        polynomials = [p / norm for p in polynomials]
        polynormz = [p / norm for p in polynormz]

        polynomials = [
            sympy.nsimplify(p, full=True)
            for p in polynomials
        ]
Beispiel #4
0
def test_wigner_3j_sh_norm():
    """Check that sqrt(4 pi) * wigner_3j contracted with Y(l_f) has unit norm."""
    scale = math.sqrt(4 * math.pi)
    with o3.torch_default_dtype(torch.float64):
        for l_out in range(4):
            for l_in in range(l_out, 5):
                l_min = abs(l_out - l_in)
                l_max = l_out + l_in
                for l_f in range(l_min, l_max + 1):
                    coeffs = o3.wigner_3j(l_out, l_in, l_f)
                    direction = torch.randn(3)
                    harmonics = rsh.spherical_harmonics_xyz([l_f], direction)
                    contracted = scale * coeffs @ harmonics
                    deviation = abs(contracted.norm() - 1)
                    assert deviation < 1e-10
Beispiel #5
0
 def test_wigner_3j_orthogonal(self):
     """Check that (2 l_f + 1) * Q Q^T is the identity for the flattened 3j tensor."""
     with o3.torch_default_dtype(torch.float64):
         for l_out in range(4):
             for l_in in range(l_out, 5):
                 for l_f in range(abs(l_out - l_in), l_out + l_in + 1):
                     rows = 2 * l_f + 1
                     # one row per component of l_f, the two other indices merged
                     flat = o3.wigner_3j(l_f, l_in, l_out).reshape(rows, -1)
                     gram = rows * flat @ flat.t()
                     residual = gram - torch.eye(rows)
                     rms = residual.pow(2).mean().sqrt()
                     self.assertLess(rms, 1e-10)
Beispiel #6
0
def test_wigner_3j(float_tolerance):
    """Check that wigner_3j(1, 2, 3) is unchanged when all three indices are rotated."""
    abc = o3.rand_angles(10)

    ls = (1, 2, 3)
    C = o3.wigner_3j(*ls)
    # one batch of representation matrices per index of C
    D1, D2, D3 = [o3.Irrep(l, 1).D_from_angles(*abc) for l in ls]

    rotated = torch.einsum("ijk,zil,zjm,zkn->zlmn", C, D1, D2, D3)
    max_err = (C - rotated).abs().max()
    assert max_err < float_tolerance
Beispiel #7
0
def _wigner_nj(*irrepss, normalization='component', filter_ir_mid=None, dtype=None, device=None):
    """
    Recursively couple several representations, one more on the right at each level.

    With a single argument, every copy of every irrep yields one entry whose
    tensor is the matching rows of the identity matrix. With more arguments,
    every entry obtained for the left part is coupled through `o3.wigner_3j`
    with every copy of every irrep of the last argument.

    :param irrepss: the representations to couple, anything `o3.Irreps` accepts
    :param normalization: 'component' scales by sqrt(dim of the output irrep),
        'norm' by sqrt(dim left) * sqrt(dim right); anything else leaves the
        coefficients unscaled
    :param filter_ir_mid: if given, only these irreps are kept as results of a coupling
    :param dtype: dtype of the created tensors
    :param device: device of the created tensors
    :return: list of (irrep, path, tensor); at the recursive levels it is
        sorted by irrep, and `path` is an `_INPUT` leaf or a `_TP` node
    """
    irrepss = [o3.Irreps(irreps) for irreps in irrepss]
    if filter_ir_mid is not None:
        filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid]

    # base case: a single representation, one entry per copy of each irrep
    if len(irrepss) == 1:
        (irreps,) = irrepss
        identity = torch.eye(irreps.dim, dtype=dtype, device=device)
        leaves = []
        start = 0
        for mul, ir in irreps:
            for _ in range(mul):
                stop = start + ir.dim
                leaves.append((ir, _INPUT(0, start, stop), identity[start:stop]))
                start = stop
        return leaves

    *irrepss_left, irreps_right = irrepss
    left_dims = tuple(irreps.dim for irreps in irrepss_left)
    n_left = len(irrepss_left)

    left_entries = _wigner_nj(
        *irrepss_left,
        normalization=normalization,
        filter_ir_mid=filter_ir_mid,
        dtype=dtype,
        device=device,
    )

    paths = []
    for ir_left, path_left, C_left in left_entries:
        offset = 0  # first column of the current irrep inside `irreps_right`
        for mul, ir in irreps_right:
            for ir_out in ir_left * ir:
                if not (filter_ir_mid is None or ir_out in filter_ir_mid):
                    continue

                C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype, device=device)
                if normalization == 'component':
                    C *= ir_out.dim**0.5
                if normalization == 'norm':
                    C *= ir_left.dim**0.5 * ir.dim**0.5

                # chain the new coupling onto the one accumulated on the left
                C = torch.einsum('jk,ijl->ikl', C_left.flatten(1), C)
                C = C.reshape(ir_out.dim, *left_dims, ir.dim)

                for u in range(mul):
                    lo = offset + u * ir.dim
                    hi = lo + ir.dim
                    E = torch.zeros(ir_out.dim, *left_dims, irreps_right.dim, dtype=dtype, device=device)
                    E[..., lo:hi] = C
                    node = _TP(
                        op=(ir_left, ir, ir_out),
                        args=(path_left, _INPUT(n_left, lo, hi)),
                    )
                    paths.append((ir_out, node, E))
            offset += mul * ir.dim

    return sorted(paths, key=lambda entry: entry[0])
Beispiel #8
0
    def backward(ctx, grad_kernel):  # pragma: no cover
        """
        Gradient of the kernel with respect to `Y` and `R`.

        Walks the same (output irrep, input irrep, l_filter) blocks as the
        forward pass and only fills the gradients flagged in
        `ctx.needs_input_grad`.

        :param grad_kernel: gradient with respect to the kernel; sliced as
            [batch, l_out * mul_out * m_out, l_in * mul_in * m_in]
        :return: (grad_Y, grad_R, None, None, None, None, None)
        """
        Y, R, norm_coef = ctx.saved_tensors

        grad_Y = None
        grad_R = None
        if ctx.needs_input_grad[0]:
            # [batch, l_filter * m_filter]
            grad_Y = grad_kernel.new_zeros(*ctx.Y_shape)
        if ctx.needs_input_grad[1]:
            # [batch, l_out * l_in * mul_out * mul_in * l_filter]
            grad_R = grad_kernel.new_zeros(*ctx.R_shape)

        r_cursor = 0  # position of the next radial-weight block

        out_cursor = 0
        for i, (mul_out, l_out, p_out) in enumerate(ctx.Rs_out):
            m_out = 2 * l_out + 1
            rows = slice(out_cursor, out_cursor + mul_out * m_out)
            out_cursor += mul_out * m_out

            in_cursor = 0
            for j, (mul_in, l_in, p_in) in enumerate(ctx.Rs_in):
                m_in = 2 * l_in + 1
                cols = slice(in_cursor, in_cursor + mul_in * m_in)
                in_cursor += mul_in * m_in

                l_filters = ctx.selection_rule(l_in, p_in, l_out, p_out)
                if not l_filters:
                    continue

                n_filters = len(l_filters)
                n_weights = mul_out * mul_in * n_filters
                weights = slice(r_cursor, r_cursor + n_weights)
                r_cursor += n_weights

                if grad_Y is not None:
                    # [batch, mul_out, mul_in, l_filter]
                    sub_R = R[:, weights].reshape(-1, mul_out, mul_in, n_filters)
                if grad_R is not None:
                    # [batch, mul_out, mul_in, l_filter]; written through below
                    sub_grad_R = grad_R[:, weights].reshape(-1, mul_out, mul_in, n_filters)

                grad_K = grad_kernel[:, rows, cols].reshape(-1, mul_out, m_out, mul_in, m_in)

                for k, l_filter in enumerate(l_filters):
                    start = sum(2 * l + 1 for l in ctx.set_of_l_filters if l < l_filter)
                    harmonic = slice(start, start + 2 * l_filter + 1)
                    C = o3.wigner_3j(l_out, l_in, l_filter, cached=True, like=grad_kernel)  # [m_out, m_in, m]

                    if grad_Y is not None:
                        grad_Y[:, harmonic] += norm_coef[i, j] * torch.einsum(
                            "zuivj,ijk,zuv->zk", grad_K, C, sub_R[..., k])
                    if grad_R is not None:
                        sub_Y = Y[:, harmonic]  # [batch, m]
                        sub_grad_R[..., k] = norm_coef[i, j] * torch.einsum(
                            "zuivj,ijk,zk->zuv", grad_K, C, sub_Y)

        del ctx
        return grad_Y, grad_R, None, None, None, None, None
def test_recurrence_relation(float_tolerance, l):
    """
    Check the recurrence between spherical harmonics of degree l and l + 1.

    First the direction of Y(l + 1) is compared with wigner_3j(l + 1, l, 1)
    contracted with Y(l) and x; then the Jacobian of Y(l + 1) is compared
    with the matching contraction of wigner_3j and Y(l).
    """
    if l > 6 and torch.get_default_dtype() != torch.float64:
        pytest.xfail('we expect this to fail for high l and single precision')

    x = torch.randn(3, requires_grad=True)

    sh_next = o3.spherical_harmonics(l + 1, x, False)

    recurrence = torch.einsum(
        'ijk,j,k->i',
        o3.wigner_3j(l + 1, l, 1),
        o3.spherical_harmonics(l, x, False),
        x,
    )

    # ratio between the recurrence output and the actual harmonics
    alpha = recurrence.norm() / sh_next.norm()

    direction_error = (sh_next / sh_next.norm() - recurrence / recurrence.norm()).abs().max()
    assert direction_error < 10 * float_tolerance

    def sh_of(point):
        return o3.spherical_harmonics(l + 1, point, False)

    jacobian = torch.autograd.functional.jacobian(sh_of, x)

    expected = (l + 1) / alpha * torch.einsum(
        'ijk,j->ik',
        o3.wigner_3j(l + 1, l, 1),
        o3.spherical_harmonics(l, x, False),
    )

    jacobian_error = (jacobian - expected).abs().max()
    assert jacobian_error < 100 * float_tolerance
Beispiel #10
0
    def test_irr_repr_wigner_3j(self):
        """Test irr_repr and wigner_3j equivariance."""
        l_in, l_out = 3, 2
        with torch_default_dtype(torch.float64):
            for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
                points = torch.randn(100, 3)
                Q = o3.wigner_3j(l_out, l_in, l_f)

                angles = torch.randn(3)
                D_in = o3.irr_repr(l_in, *angles)
                D_out = o3.irr_repr(l_out, *angles)

                # left-hand side: rotate the points, then act with D_in
                rotated_points = points @ o3.rot(*angles).t()
                Y_rot = rsh.spherical_harmonics_xyz([l_f], rotated_points)
                W_rot = torch.einsum("ijk,zk->zij", (Q, Y_rot))
                lhs = torch.einsum("zij,jk->zik", (W_rot, D_in))

                # right-hand side: keep the points, then act with D_out
                Y_ref = rsh.spherical_harmonics_xyz([l_f], points)
                W_ref = torch.einsum("ijk,zk->zij", (Q, Y_ref))
                rhs = torch.einsum("ij,zjk->zik", (D_out, W_ref))

                self.assertLess((lhs - rhs).norm(), 1e-5 * W_ref.norm(), l_f)
Beispiel #11
0
def elementwise_tensor_product(
        Rs_in1: TY_RS_LOOSE,
        Rs_in2: TY_RS_LOOSE,
        selection_rule: o3.TY_SELECTION_RULE = o3.selection_rule,
        normalization: str = 'component') -> Tuple[TY_RS_STRICT, SparseTensor]:
    """
    Build the sparse matrix of the channel-wise tensor product of two representations.

    Both inputs must carry the same total multiplicity. Their entries are
    first split so that the two lists line up entry by entry with equal
    multiplicities; the u-th copy on one side is then coupled only with the
    u-th copy on the other side.

    :param Rs_in1: first representation
    :param Rs_in2: second representation
    :param selection_rule: callable (l_1, p_1, l_2, p_2) returning the
        output `l` values to create
    :param normalization: 'component' scales the coefficients by
        sqrt(2 l_out + 1), 'norm' by sqrt(2 l_1 + 1) * sqrt(2 l_2 + 1)
    :return: Rs_out, matrix

    m_kij A_i B_j
    """
    assert normalization in [
        'norm', 'component'
    ], "normalization needs to be 'norm' or 'component'"

    Rs_in1 = simplify(Rs_in1)
    Rs_in2 = simplify(Rs_in2)

    assert sum(mul for mul, _, _ in Rs_in1) == sum(mul for mul, _, _ in Rs_in2)

    # Split entries until both lists have pairwise equal multiplicities.
    i = 0
    while i < len(Rs_in1):
        mul_1, l_1, p_1 = Rs_in1[i]
        mul_2, l_2, p_2 = Rs_in2[i]

        if mul_1 < mul_2:
            Rs_in2[i] = (mul_1, l_2, p_2)
            Rs_in2.insert(i + 1, (mul_2 - mul_1, l_2, p_2))

        if mul_2 < mul_1:
            Rs_in1[i] = (mul_2, l_1, p_1)
            Rs_in1.insert(i + 1, (mul_1 - mul_2, l_1, p_1))
        i += 1

    Rs_out = []
    for (mul, l_1, p_1), (mul_2, l_2, p_2) in zip(Rs_in1, Rs_in2):
        assert mul == mul_2
        for l in selection_rule(l_1, p_1, l_2, p_2):
            Rs_out.append((mul, l, p_1 * p_2))

    Rs_out = simplify(Rs_out)

    dim_in2 = dim(Rs_in2)
    # COO triplets of the matrix; the two input indices are flattened as
    # i_1 * dim_in2 + i_2
    row = []
    col = []
    val = []

    index_out = 0
    index_1 = 0
    index_2 = 0
    for (mul, l_1, p_1), (mul_2, l_2, p_2) in zip(Rs_in1, Rs_in2):
        assert mul == mul_2
        dim_1 = mul * (2 * l_1 + 1)
        dim_2 = mul * (2 * l_2 + 1)

        for l_out in selection_rule(l_1, p_1, l_2, p_2):
            dim_out = mul * (2 * l_out + 1)
            C = o3.wigner_3j(l_out, l_1, l_2, cached=True)
            # Scale out of place: the tensor was requested with cached=True
            # and may be shared with that cache, in which case `C *= ...`
            # would alter the coefficients seen by every later caller.
            if normalization == 'component':
                C = C * (2 * l_out + 1)**0.5
            if normalization == 'norm':
                C = C * ((2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5)
            # I[w, u, v] is one only where w == u == v
            I = torch.einsum("uv,wu->wuv", torch.eye(mul), torch.eye(mul))
            m = torch.einsum("wuv,kij->wkuivj", I,
                             C).reshape(dim_out, dim_1, dim_2)
            i_out, i_1, i_2 = m.nonzero(as_tuple=False).T
            i_out += index_out
            i_1 += index_1
            i_2 += index_2
            row.append(i_out)
            col.append(i_1 * dim_in2 + i_2)
            val.append(m[m != 0])

            index_out += dim_out

        index_1 += dim_1
        index_2 += dim_2

    wigner_3j_tensor = SparseTensor(
        row=torch.cat(row) if row else torch.zeros(0, dtype=torch.long),
        col=torch.cat(col) if col else torch.zeros(0, dtype=torch.long),
        value=torch.cat(val) if val else torch.zeros(0),
        sparse_sizes=(dim(Rs_out), dim(Rs_in1) * dim(Rs_in2)))

    return Rs_out, wigner_3j_tensor
Beispiel #12
0
def _tensor_product_in_out(Rs_in1, selection_rule, Rs_out, normalization,
                           sorted):
    """
    Compute the matrix Q
    from Rs_out to Rs_in1 tensor product with Rs_in2
    where Rs_in2 is a direct sum of irreducible representations

    For normalization='component',
    The set of "lines" { Q[i] }_i is orthonormal

    :param Rs_in1: first input representation
    :param selection_rule: callable (l_1, p_1, l_out, p_out) returning the
        `l_2` values to create
    :param Rs_out: output representation
    :param normalization: 'component' scales the coefficients by
        sqrt(2 l_out + 1), 'norm' by sqrt(2 l_1 + 1) * sqrt(2 l_2 + 1)
    :param sorted: if true, `Rs_in2` is sorted and the matrix permuted to match
    :return: Rs_in2, Q

    example:
    _, Q = tensor_product_in_out(Rs_in1, Rs_out)
    torch.einsum('kij,i,j->k', Q, A, B)
    """
    assert normalization in [
        'norm', 'component'
    ], "normalization needs to be 'norm' or 'component'"

    Rs_in1 = simplify(Rs_in1)
    Rs_out = simplify(Rs_out)

    Rs_in2 = []

    for mul_out, l_out, p_out in Rs_out:
        for mul_1, l_1, p_1 in Rs_in1:
            for l_2 in selection_rule(l_1, p_1, l_out, p_out):
                Rs_in2.append((mul_1 * mul_out, l_2, p_1 * p_out))

    Rs_in2 = simplify(Rs_in2)

    dim_in2 = dim(Rs_in2)
    # COO triplets of the matrix; the two input indices are flattened as
    # i_1 * dim_in2 + i_2
    row = []
    col = []
    val = []

    index_2 = 0

    index_out = 0
    for mul_out, l_out, p_out in Rs_out:
        dim_out = mul_out * (2 * l_out + 1)

        # number of paths feeding one copy of this output irrep; every path
        # is divided by its square root
        n_path = sum(
            mul_1
            for mul_1, l_1, p_1 in Rs_in1
            for _ in selection_rule(l_1, p_1, l_out, p_out)
        )

        index_1 = 0
        for mul_1, l_1, p_1 in Rs_in1:
            dim_1 = mul_1 * (2 * l_1 + 1)
            for l_2 in selection_rule(l_1, p_1, l_out, p_out):
                if l_2 == 0:
                    # scalar second input: the indices are written directly
                    # instead of searching a dense tensor for non-zeros
                    assert l_out == l_1
                    l = l_1
                    dim_2 = mul_1 * mul_out
                    i_out = []
                    i_1 = []
                    i_2 = []
                    v = 0
                    for w in range(mul_out):
                        for u in range(mul_1):
                            i_out += [(2 * l + 1) * w + m
                                      for m in range(2 * l + 1)]
                            i_1 += [(2 * l + 1) * u + m
                                    for m in range(2 * l + 1)]
                            i_2 += (2 * l + 1) * [v]
                            v += 1
                    i_out = index_out + torch.tensor(i_out)
                    i_1 = index_1 + torch.tensor(i_1)
                    i_2 = index_2 + torch.tensor(i_2)
                    m = torch.ones((2 * l + 1) * dim_2) / n_path**0.5
                else:
                    dim_2 = mul_1 * mul_out * (2 * l_2 + 1)
                    C = o3.wigner_3j(l_out, l_1, l_2, cached=True)
                    # Scale out of place: the tensor was requested with
                    # cached=True and may be shared with that cache, in which
                    # case `C *= ...` would alter the coefficients seen by
                    # every later caller.
                    if normalization == 'component':
                        C = C * (2 * l_out + 1)**0.5
                    if normalization == 'norm':
                        C = C * ((2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5)
                    I = torch.eye(mul_out * mul_1).reshape(
                        mul_out, mul_1, mul_out * mul_1) / n_path**0.5
                    m = torch.einsum("wuv,kij->wkuivj", I,
                                     C).reshape(dim_out, dim_1, dim_2)
                    i_out, i_1, i_2 = m.nonzero(as_tuple=True)  # slow part
                    m = m[(i_out, i_1, i_2)]
                    i_out += index_out
                    i_1 += index_1
                    i_2 += index_2

                row.append(i_out)
                col.append(i_1 * dim_in2 + i_2)
                val.append(m)

                index_2 += dim_2
            index_1 += dim_1
        index_out += dim_out

    wigner_3j_tensor = SparseTensor(
        row=torch.cat(row) if row else torch.zeros(0, dtype=torch.long),
        col=torch.cat(col) if col else torch.zeros(0, dtype=torch.long),
        value=torch.cat(val) if val else torch.zeros(0),
        sparse_sizes=(dim(Rs_out), dim(Rs_in1) * dim(Rs_in2)))

    if sorted:
        Rs_in2, perm_mat = sort(Rs_in2)
        Rs_in2 = simplify(Rs_in2)
        # sorted = perm_mat @ unsorted
        wigner_3j_tensor = wigner_3j_tensor.sparse_reshape(-1, dim(Rs_in2))
        wigner_3j_tensor = wigner_3j_tensor @ perm_mat.t()  # slow part
        wigner_3j_tensor = wigner_3j_tensor.sparse_reshape(
            -1,
            dim(Rs_in1) * dim(Rs_in2))

    return Rs_in2, wigner_3j_tensor
Beispiel #13
0
def tensor_square(Rs_in: TY_RS_LOOSE,
                  selection_rule: o3.TY_SELECTION_RULE = o3.selection_rule,
                  normalization: str = 'component',
                  sorted: bool = False) -> Tuple[TY_RS_STRICT, SparseTensor]:
    """
    Compute the matrix Q
    from Rs_out to Rs_in tensor product with Rs_in
    where Rs_out is a direct sum of irreducible representations

    Each pair of distinct entries of `Rs_in` is coupled once (first entry
    before second). An entry coupled with itself keeps the channel pairs
    (u, v) with u <= v for even `l_out` and with u < v for odd `l_out`.

    For normalization='component',
    The set of "lines" { Q[i] }_i is orthonormal

    :param Rs_in: input representation
    :param selection_rule: callable (l_1, p_1, l_2, p_2) returning the
        output `l` values to create
    :param normalization: 'component' scales the coefficients by
        sqrt(2 l_out + 1), 'norm' by sqrt(2 l_1 + 1) * sqrt(2 l_2 + 1)
    :param sorted: if true, `Rs_out` is sorted and the rows permuted to match
    :return: Rs_out, Q

    example:
    _, Q = tensor_square(Rs_in)
    torch.einsum('kij,i,j->k', Q, A, A)
    """
    assert normalization in [
        'norm', 'component'
    ], "normalization needs to be 'norm' or 'component'"

    Rs_in = simplify(Rs_in)

    # First pass: only the output representation is built.
    Rs_out = []

    for i, (mul_1, l_1, p_1) in enumerate(Rs_in):
        # entry coupled with itself
        for l_out in selection_rule(l_1, p_1, l_1, p_1):
            if l_out % 2 == 0:
                # pairs (u, v) with u <= v
                Rs_out.append((mul_1 * (mul_1 + 1) // 2, l_out, p_1**2))
            else:
                # pairs (u, v) with u < v
                Rs_out.append((mul_1 * (mul_1 - 1) // 2, l_out, p_1**2))

        # entry coupled with every later entry
        for mul_2, l_2, p_2 in Rs_in[i + 1:]:
            for l_out in selection_rule(l_1, p_1, l_2, p_2):
                Rs_out.append((mul_1 * mul_2, l_out, p_1 * p_2))

    Rs_out = simplify(Rs_out)

    dim_in = dim(Rs_in)
    # COO triplets of the matrix; the two input indices are flattened as
    # i_1 * dim_in + i_2
    row = []
    col = []
    val = []

    # Second pass: same traversal order as above, now filling the matrix.
    index_out = 0

    index_1 = 0
    for i, (mul_1, l_1, p_1) in enumerate(Rs_in):
        dim_1 = mul_1 * (2 * l_1 + 1)

        for l_out in selection_rule(l_1, p_1, l_1, p_1):
            # one slice of I per channel pair (u, v), holding a single one
            I = torch.eye(mul_1**2).reshape(mul_1**2, mul_1, mul_1)
            uv = I.nonzero(as_tuple=False)[:, 1:]
            if l_out % 2 == 0:
                I = I[uv[:, 0] <= uv[:, 1]]
            else:
                I = I[uv[:, 0] < uv[:, 1]]

            # no pair left, e.g. odd l_out with a single channel
            if I.shape[0] == 0:
                continue

            C = o3.wigner_3j(l_out, l_1, l_1)
            if normalization == 'component':
                C *= (2 * l_out + 1)**0.5
            if normalization == 'norm':
                C *= (2 * l_1 + 1)**0.5 * (2 * l_1 + 1)**0.5
            dim_out = I.shape[0] * (2 * l_out + 1)
            m = torch.einsum("wuv,kij->wkuivj", I,
                             C).reshape(dim_out, dim_1, dim_1)
            i_out, i_1, i_2 = m.nonzero(as_tuple=False).T
            i_out += index_out
            i_1 += index_1
            # both inputs are the same entry, hence the same offset
            i_2 += index_1
            row.append(i_out)
            col.append(i_1 * dim_in + i_2)
            val.append(m[m != 0])

            index_out += dim_out

        # the later entries start right after the current one
        index_2 = index_1 + dim_1
        for mul_2, l_2, p_2 in Rs_in[i + 1:]:
            dim_2 = mul_2 * (2 * l_2 + 1)
            for l_out in selection_rule(l_1, p_1, l_2, p_2):
                # every channel pair (u, v) is kept
                I = torch.eye(mul_1 * mul_2).reshape(mul_1 * mul_2, mul_1,
                                                     mul_2)

                C = o3.wigner_3j(l_out, l_1, l_2)
                if normalization == 'component':
                    C *= (2 * l_out + 1)**0.5
                if normalization == 'norm':
                    C *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5
                dim_out = I.shape[0] * (2 * l_out + 1)
                m = torch.einsum("wuv,kij->wkuivj", I,
                                 C).reshape(dim_out, dim_1, dim_2)
                i_out, i_1, i_2 = m.nonzero(as_tuple=False).T
                i_out += index_out
                i_1 += index_1
                i_2 += index_2
                row.append(i_out)
                col.append(i_1 * dim_in + i_2)
                val.append(m[m != 0])

                index_out += dim_out
            index_2 += dim_2
        index_1 += dim_1

    wigner_3j_tensor = SparseTensor(
        row=torch.cat(row) if row else torch.zeros(0, dtype=torch.long),
        col=torch.cat(col) if col else torch.zeros(0, dtype=torch.long),
        value=torch.cat(val) if val else torch.zeros(0),
        sparse_sizes=(dim(Rs_out), dim(Rs_in) * dim(Rs_in)))

    if sorted:
        Rs_out, perm_mat = sort(Rs_out)
        Rs_out = simplify(Rs_out)
        # sorted = perm_mat @ unsorted
        wigner_3j_tensor = perm_mat @ wigner_3j_tensor

    return Rs_out, wigner_3j_tensor
Beispiel #14
0
def test_wigner_3j_symmetry():
    """Check that permuting the l arguments of wigner_3j permutes its axes accordingly."""
    reference = o3.wigner_3j(1, 2, 3)

    # (permuted arguments, axis swaps that bring the result back to (1, 2, 3) order)
    cases = [
        ((1, 3, 2), [(1, 2)]),
        ((2, 1, 3), [(0, 1)]),
        ((3, 2, 1), [(0, 2)]),
        ((3, 1, 2), [(0, 1), (1, 2)]),
        ((2, 3, 1), [(0, 2), (1, 2)]),
    ]
    for ls, swaps in cases:
        permuted = o3.wigner_3j(*ls)
        for axis_a, axis_b in swaps:
            permuted = permuted.transpose(axis_a, axis_b)
        assert torch.allclose(reference, permuted)
Beispiel #15
0
def _tensor_product_in_in(Rs_in1, Rs_in2, selection_rule, normalization,
                          sorted):
    """
    Compute the matrix Q
    from Rs_out to Rs_in1 tensor product with Rs_in2
    where Rs_out is a direct sum of irreducible representations

    For normalization='component',
    The set of "lines" { Q[i] }_i is orthonormal

    :param Rs_in1: first input representation
    :param Rs_in2: second input representation
    :param selection_rule: callable (l_1, p_1, l_2, p_2) returning the
        output `l` values to create
    :param normalization: 'component' scales the coefficients by
        sqrt(2 l_out + 1), 'norm' by sqrt(2 l_1 + 1) * sqrt(2 l_2 + 1)
    :param sorted: if true, `Rs_out` is sorted and the rows permuted to match
    :return: Rs_out, Q

    example:
    _, Q = tensor_product_in_in(Rs_in1, Rs_in2)
    torch.einsum('kij,i,j->k', Q, A, B)
    """
    assert normalization in [
        'norm', 'component'
    ], "normalization needs to be 'norm' or 'component'"

    Rs_in1 = simplify(Rs_in1)
    Rs_in2 = simplify(Rs_in2)

    Rs_out = []

    for mul_1, l_1, p_1 in Rs_in1:
        for mul_2, l_2, p_2 in Rs_in2:
            for l_out in selection_rule(l_1, p_1, l_2, p_2):
                Rs_out.append((mul_1 * mul_2, l_out, p_1 * p_2))

    Rs_out = simplify(Rs_out)

    dim_in2 = dim(Rs_in2)
    # COO triplets of the matrix; the two input indices are flattened as
    # i_1 * dim_in2 + i_2
    row = []
    col = []
    val = []

    index_out = 0

    index_1 = 0
    for mul_1, l_1, p_1 in Rs_in1:
        dim_1 = mul_1 * (2 * l_1 + 1)

        index_2 = 0
        for mul_2, l_2, p_2 in Rs_in2:
            dim_2 = mul_2 * (2 * l_2 + 1)
            for l_out in selection_rule(l_1, p_1, l_2, p_2):
                dim_out = mul_1 * mul_2 * (2 * l_out + 1)
                C = o3.wigner_3j(l_out, l_1, l_2, cached=True)
                # Scale out of place: the tensor was requested with
                # cached=True and may be shared with that cache, in which
                # case `C *= ...` would alter the coefficients seen by every
                # later caller.
                if normalization == 'component':
                    C = C * (2 * l_out + 1)**0.5
                if normalization == 'norm':
                    C = C * ((2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5)
                # one output channel per pair (u, v) of input channels
                I = torch.eye(mul_1 * mul_2).reshape(mul_1 * mul_2, mul_1,
                                                     mul_2)
                m = torch.einsum("wuv,kij->wkuivj", I,
                                 C).reshape(dim_out, dim_1, dim_2)
                i_out, i_1, i_2 = m.nonzero(as_tuple=False).T
                i_out += index_out
                i_1 += index_1
                i_2 += index_2
                row.append(i_out)
                col.append(i_1 * dim_in2 + i_2)
                val.append(m[m != 0])

                index_out += dim_out
            index_2 += dim_2
        index_1 += dim_1

    wigner_3j_tensor = SparseTensor(
        row=torch.cat(row) if row else torch.zeros(0, dtype=torch.long),
        col=torch.cat(col) if col else torch.zeros(0, dtype=torch.long),
        value=torch.cat(val) if val else torch.zeros(0),
        sparse_sizes=(dim(Rs_out), dim(Rs_in1) * dim(Rs_in2)))

    if sorted:
        Rs_out, perm_mat = sort(Rs_out)
        Rs_out = simplify(Rs_out)
        # sorted = perm_mat @ unsorted
        wigner_3j_tensor = perm_mat @ wigner_3j_tensor

    return Rs_out, wigner_3j_tensor
Beispiel #16
0
    def __init__(self,
                 Rs_in1: rs.TY_RS_LOOSE,
                 Rs_in2: rs.TY_RS_LOOSE,
                 Rs_out: rs.TY_RS_LOOSE,
                 instr: List[Tuple[int, int, int, str]],
                 normalization: str = 'component',
                 own_weight: bool = True):
        """
        Create a tensor product operation in which each path is weighted by a parameter.

        `instr` is a list of instructions.
        An instruction is of the form (i_1, i_2, i_out, mode);
        it means "put `Rs_in1[i_1]` otimes `Rs_in2[i_2]` into `Rs_out[i_out]`".
        `mode` determines the way the multiplicities are treated and must be
        one of 'uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv'.
        The default mode should be 'uvw', meaning that all paths are created.

        The constructor generates the source code of the computation from
        the instructions, compiles it with `eval_code` and stores the result
        in `self.main`. The Wigner 3j tensors it needs are registered as
        buffers, and `self.nweight` is the total number of weights consumed.

        :param Rs_in1: first input representation
        :param Rs_in2: second input representation
        :param Rs_out: output representation
        :param instr: list of instructions, see above
        :param normalization: 'component' scales the Wigner 3j tensors by
            sqrt(2 l_out + 1), 'norm' by sqrt(2 l_1 + 1) * sqrt(2 l_2 + 1)
        :param own_weight: if true, a random parameter of size `nweight` is
            created as `self.weight`
        """

        super().__init__()

        self.Rs_in1 = rs.convention(Rs_in1)
        self.Rs_in2 = rs.convention(Rs_in2)
        self.Rs_out = rs.convention(Rs_out)

        # body of the generated function, built line by line
        code = ""

        # position of the next unused weight
        index_w = 0
        # (l_1, l_2, l_out) triples whose Wigner 3j tensor the generated code uses
        wigners = set()
        # for every output component, how many paths are summed into it
        count = [0 for _ in range(rs.dim(self.Rs_out))]

        instr = sorted(instr)  # for optimization

        # what `s1`, `s2` and `ss` currently hold in the generated code, so
        # that consecutive instructions do not recompute them
        last_s1, last_s2, last_ss = None, None, None
        for i_1, i_2, i_out, mode in instr:
            mul_1, l_1, p_1 = self.Rs_in1[i_1]
            mul_2, l_2, p_2 = self.Rs_in2[i_2]
            mul_out, l_out, p_out = self.Rs_out[i_out]
            dim_1 = mul_1 * (2 * l_1 + 1)
            dim_2 = mul_2 * (2 * l_2 + 1)
            dim_out = mul_out * (2 * l_out + 1)
            index_1 = rs.dim(self.Rs_in1[:i_1])
            index_2 = rs.dim(self.Rs_in2[:i_2])
            index_out = rs.dim(self.Rs_out[:i_out])

            # the instruction must respect parity and the triangle inequality
            assert p_1 * p_2 == p_out
            assert abs(l_1 - l_2) <= l_out <= l_1 + l_2

            if dim_1 == 0 or dim_2 == 0 or dim_out == 0:
                continue

            if last_s1 != i_1:
                code += f"    s1 = x1[:, {index_1}:{index_1+dim_1}].reshape(batch, {mul_1}, {2 * l_1 + 1})\n"
                last_s1 = i_1

            if last_s2 != i_2:
                code += f"    s2 = x2[:, {index_2}:{index_2+dim_2}].reshape(batch, {mul_2}, {2 * l_2 + 1})\n"
                last_s2 = i_2

            assert mode in ['uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv']

            # the first two letters of the mode decide how the two inputs are paired
            if last_ss != (i_1, i_2, mode[:2]):
                if mode[:2] == 'uv':
                    code += f"    ss = ein('zui,zvj->zuvij', s1, s2)\n"
                if mode[:2] == 'uu':
                    code += f"    ss = ein('zui,zuj->zuij', s1, s2)\n"
                last_ss = (i_1, i_2, mode[:2])

            wigners.add((l_1, l_2, l_out))

            # 'uvw': one weight per (input 1 copy, input 2 copy, output copy)
            if mode == 'uvw':
                dim_w = mul_1 * mul_2 * mul_out
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2}, {mul_out})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuvw,ijk,zuvij->zwk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += mul_1 * mul_2

            # 'uvu': the output copy follows the copy of the first input
            if mode == 'uvu':
                assert mul_1 == mul_out
                dim_w = mul_1 * mul_2
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zuk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += mul_2

            # 'uvv': the output copy follows the copy of the second input
            if mode == 'uvv':
                assert mul_2 == mul_out
                dim_w = mul_1 * mul_2
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zvk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += mul_1

            # 'uuw': the inputs are paired copy by copy, every pair feeds every output copy
            if mode == 'uuw':
                assert mul_1 == mul_2
                dim_w = mul_1 * mul_out
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_out})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuw,ijk,zuij->zwk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += mul_1

            # 'uuu': inputs and output are all paired copy by copy
            if mode == 'uuu':
                assert mul_1 == mul_2 == mul_out
                dim_w = mul_1
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zu,ijk,zuij->zuk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += 1

            # 'uvuv': every pair of input copies gets its own output copy
            if mode == 'uvuv':
                assert mul_1 * mul_2 == mul_out
                dim_w = mul_1 * mul_2
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zuvk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += 1

            index_w += dim_w
            code += "\n"

        # Normalisation: each run of output components sharing the same path
        # count is divided by the square root of that count.
        # NOTE(review): `count[0]` raises IndexError when rs.dim(self.Rs_out)
        # is 0 -- confirm whether an empty output is a supported input.
        ilast = 0
        clast = count[0]
        for i, c in enumerate(count):
            if clast != c:
                if clast > 1:
                    code += f"    out[:, {ilast}:{i}].div_({clast ** 0.5})\n"
                clast = c
                ilast = i
        if clast > 1:
            code += f"    out[:, {ilast}:].div_({clast ** 0.5})\n"

        # sorted so that the buffer names and the generated argument list
        # come out in a fixed order
        wigners = sorted(wigners)
        self.wigners_names = [
            f"C{l_1}_{l_2}_{l_3}" for l_1, l_2, l_3 in wigners
        ]
        args = ", ".join(f"{arg}: torch.Tensor" for arg in self.wigners_names)

        for arg, (l_1, l_2, l_out) in zip(self.wigners_names, wigners):
            wig = o3.wigner_3j(l_1, l_2, l_out)

            if normalization == 'component':
                wig *= (2 * l_out + 1)**0.5
            if normalization == 'norm':
                wig *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5

            self.register_buffer(arg, wig)

        # fill the placeholders of the source template
        x = _tensor_product_code
        x = x.replace("DIM", f"{rs.dim(self.Rs_out)}")
        x = x.replace("ARGS", args)
        x = x.replace("CODE", code)

        self.code = x
        self.main = eval_code(x).main
        self.nweight = index_w
        if own_weight:
            self.weight = torch.nn.Parameter(torch.randn(self.nweight))
Beispiel #17
0
Datei: rs.py Projekt: zizai/e3nn
def _tensor_product_in_out(Rs_in1, selection_rule, Rs_out, normalization, sorted):
    """
    Compute the matrix Q
    from Rs_out to Rs_in1 tensor product with Rs_in2
    where Rs_in2 is a direct sum of irreducible representations

    For normalization='component',
    The set of "lines" { Q[i] }_i is orthonormal

    :param Rs_in1: first input representation; it is passed through `simplify`
        and then iterated as (mul, l, p) triples
    :param selection_rule: callable `(l_1, p_1, l_out, p_out)` returning an
        iterable of the allowed `l_2`
    :param Rs_out: output representation; it is passed through `simplify` and
        then iterated as (mul, l, p) triples
    :param normalization: 'norm' or 'component'
    :param sorted: when truthy, Rs_in2 is reordered with `sort` and the last
        axis of Q is permuted to match
    :return: Rs_in2, Q where Q has shape [dim(Rs_out), dim(Rs_in1), dim(Rs_in2)]

    example:
    _, Q = tensor_product_in_out(Rs_in1, Rs_out)
    torch.einsum('kij,i,j->k', Q, A, B)
    """
    assert normalization in ['norm', 'component'], "normalization needs to be 'norm' or 'component'"

    Rs_in1 = simplify(Rs_in1)
    Rs_out = simplify(Rs_out)

    # First pass: enumerate every (out, in1, l_2) path to build Rs_in2.
    # The second pass below walks the paths in the same order, so `index_2`
    # advances through Rs_in2 consistently.
    Rs_in2 = []

    for mul_out, l_out, p_out in Rs_out:
        for mul_1, l_1, p_1 in Rs_in1:
            for l_2 in selection_rule(l_1, p_1, l_out, p_out):
                Rs_in2.append((mul_1 * mul_out, l_2, p_1 * p_out))

    Rs_in2 = simplify(Rs_in2)

    wigner_3j_tensor = torch.zeros(dim(Rs_out), dim(Rs_in1), dim(Rs_in2))

    index_2 = 0

    index_out = 0
    for mul_out, l_out, p_out in Rs_out:
        dim_out = mul_out * (2 * l_out + 1)

        # Number of paths feeding this output block (mul_1 per allowed l_2);
        # every block written below is divided by sqrt(n_path).
        n_path = 0
        for mul_1, l_1, p_1 in Rs_in1:
            for l_2 in selection_rule(l_1, p_1, l_out, p_out):
                n_path += mul_1

        index_1 = 0
        for mul_1, l_1, p_1 in Rs_in1:
            dim_1 = mul_1 * (2 * l_1 + 1)
            for l_2 in selection_rule(l_1, p_1, l_out, p_out):
                dim_2 = mul_1 * mul_out * (2 * l_2 + 1)
                C = o3.wigner_3j(l_out, l_1, l_2, cached=True)
                # Scale out of place: C comes from a cached lookup, so an
                # in-place `*=` could write into a tensor that is shared with
                # later callers of the cache.
                if normalization == 'component':
                    C = C * (2 * l_out + 1) ** 0.5
                if normalization == 'norm':
                    C = C * ((2 * l_1 + 1) ** 0.5 * (2 * l_2 + 1) ** 0.5)
                # One-hot tensor pairing each (w, u) multiplicity index with a
                # distinct multiplicity slot v of the Rs_in2 block.
                I = torch.eye(mul_out * mul_1).reshape(mul_out, mul_1, mul_out * mul_1) / n_path ** 0.5
                m = torch.einsum("wuv,kij->wkuivj", I, C).reshape(dim_out, dim_1, dim_2)
                wigner_3j_tensor[index_out:index_out + dim_out, index_1:index_1 + dim_1, index_2:index_2 + dim_2] = m

                index_2 += dim_2
            index_1 += dim_1
        index_out += dim_out

    if sorted:
        Rs_in2, perm = sort(Rs_in2)
        Rs_in2 = simplify(Rs_in2)
        # Apply the permutation returned by `sort` to the Rs_in2 axis of Q.
        wigner_3j_tensor = torch.einsum('jl,kil->kij', perm, wigner_3j_tensor)

    return Rs_in2, wigner_3j_tensor
Beispiel #18
0
Datei: rs.py Projekt: zizai/e3nn
def tensor_square(Rs_in, selection_rule=o3.selection_rule, normalization='component', sorted=False):
    """
    Build the matrix Q that maps the tensor product of Rs_in with itself
    onto Rs_out, where Rs_out is a direct sum of irreducible representations.

    For normalization='component',
    The set of "lines" { Q[i] }_i is orthonormal

    :param Rs_in: input representation; it is passed through `simplify` and
        then iterated as (mul, l, p) triples
    :param selection_rule: callable `(l_1, p_1, l_2, p_2)` returning an
        iterable of the allowed `l_out`
    :param normalization: 'norm' or 'component'
    :param sorted: when truthy, Rs_out is reordered with `sort` and the first
        axis of Q is permuted to match
    :return: Rs_out, Q where Q has shape [dim(Rs_out), dim(Rs_in), dim(Rs_in)];
        only blocks pairing an irrep with itself or with a later irrep are filled

    example:
    _, Q = tensor_square(Rs_in)
    torch.einsum('kij,i,j->k', Q, A, A)
    """
    assert normalization in ['norm', 'component'], "normalization needs to be 'norm' or 'component'"

    def scaled_wigner(l_out, l_a, l_b):
        # Wigner 3j block rescaled according to the requested normalization.
        w3j = o3.wigner_3j(l_out, l_a, l_b)
        if normalization == 'component':
            w3j *= (2 * l_out + 1) ** 0.5
        if normalization == 'norm':
            w3j *= (2 * l_a + 1) ** 0.5 * (2 * l_b + 1) ** 0.5
        return w3j

    Rs_in = simplify(Rs_in)

    # ---- pass 1: list the output irreps, in the order pass 2 fills them ----
    Rs_out = []
    for i, (mul_1, l_1, p_1) in enumerate(Rs_in):
        for l_out in selection_rule(l_1, p_1, l_1, p_1):
            # even l_out keeps the pairs u <= v, odd l_out only the pairs u < v
            shift = 1 if l_out % 2 == 0 else -1
            Rs_out.append((mul_1 * (mul_1 + shift) // 2, l_out, p_1**2))

        Rs_out += [
            (mul_1 * mul_2, l_out, p_1 * p_2)
            for mul_2, l_2, p_2 in Rs_in[i + 1:]
            for l_out in selection_rule(l_1, p_1, l_2, p_2)
        ]

    Rs_out = simplify(Rs_out)

    Q = torch.zeros(dim(Rs_out), dim(Rs_in), dim(Rs_in))

    # ---- pass 2: fill Q block by block ----
    row = 0
    start_1 = 0
    for i, (mul_1, l_1, p_1) in enumerate(Rs_in):
        dim_1 = mul_1 * (2 * l_1 + 1)
        span_1 = slice(start_1, start_1 + dim_1)

        # irrep paired with itself
        for l_out in selection_rule(l_1, p_1, l_1, p_1):
            pairs = torch.eye(mul_1**2).reshape(mul_1**2, mul_1, mul_1)
            uv = pairs.nonzero()[:, 1:]
            keep = uv[:, 0] <= uv[:, 1] if l_out % 2 == 0 else uv[:, 0] < uv[:, 1]
            pairs = pairs[keep]

            if pairs.shape[0] == 0:
                continue

            C = scaled_wigner(l_out, l_1, l_1)
            dim_out = pairs.shape[0] * (2 * l_out + 1)
            block = torch.einsum("wuv,kij->wkuivj", pairs, C)
            Q[row:row + dim_out, span_1, span_1] = block.reshape(dim_out, dim_1, dim_1)

            row += dim_out

        # irrep paired with every later irrep
        start_2 = start_1 + dim_1
        for mul_2, l_2, p_2 in Rs_in[i + 1:]:
            dim_2 = mul_2 * (2 * l_2 + 1)
            span_2 = slice(start_2, start_2 + dim_2)

            for l_out in selection_rule(l_1, p_1, l_2, p_2):
                pairs = torch.eye(mul_1 * mul_2).reshape(mul_1 * mul_2, mul_1, mul_2)

                C = scaled_wigner(l_out, l_1, l_2)
                dim_out = pairs.shape[0] * (2 * l_out + 1)
                block = torch.einsum("wuv,kij->wkuivj", pairs, C)
                Q[row:row + dim_out, span_1, span_2] = block.reshape(dim_out, dim_1, dim_2)

                row += dim_out

            start_2 += dim_2

        start_1 += dim_1

    if sorted:
        Rs_out, perm = sort(Rs_out)
        Rs_out = simplify(Rs_out)
        # Apply the permutation returned by `sort` to the output axis of Q.
        Q = torch.einsum('ij,jkl->ikl', perm, Q)
    return Rs_out, Q
Beispiel #19
0
def codegen_tensor_product(
    irreps_in1: o3.Irreps,
    in1_var: List[float],
    irreps_in2: o3.Irreps,
    in2_var: List[float],
    irreps_out: o3.Irreps,
    out_var: List[float],
    instructions: List[Instruction],
    normalization: str = 'component',
    shared_weights: bool = False,
    specialized_code: bool = True,
    optimize_einsums: bool = True,
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    graph_out = fx.Graph()
    graph_right = fx.Graph()

    # = Function definitions =
    x1s_out = fx.Proxy(graph_out.placeholder('x1', torch.Tensor))
    x2s_out = fx.Proxy(graph_out.placeholder('x2', torch.Tensor))
    ws_out = fx.Proxy(graph_out.placeholder('w', torch.Tensor))

    x2s_right = fx.Proxy(graph_right.placeholder('x2', torch.Tensor))
    ws_right = fx.Proxy(graph_right.placeholder('w', torch.Tensor))

    empty_out = fx.Proxy(
        graph_out.call_function(torch.empty, ((), ), dict(device='cpu')))
    empty_right = fx.Proxy(
        graph_right.call_function(torch.empty, ((), ), dict(device='cpu')))
    if shared_weights:
        size_out = torch.broadcast_tensors(
            empty_out.expand(x1s_out.shape[:-1]),
            empty_out.expand(x2s_out.shape[:-1]))[0].shape
        size_right = x2s_right.shape[:-1]
    else:
        size_out = torch.broadcast_tensors(
            empty_out.expand(x1s_out.shape[:-1]),
            empty_out.expand(x2s_out.shape[:-1]),
            empty_out.expand(ws_out.shape[:-1]))[0].shape
        size_right = torch.broadcast_tensors(
            empty_right.expand(x2s_right.shape[:-1]),
            empty_right.expand(ws_right.shape[:-1]))[0].shape

    # = Short-circut for zero dimensional =
    # We produce no code for empty instructions
    instructions = [ins for ins in instructions if 0 not in ins.path_shape]

    if len(instructions) == 0:
        out_out = x1s_out.new_zeros(size_out + (irreps_out.dim, ))
        out_right = x2s_right.new_zeros(size_right + (
            irreps_in1.dim,
            irreps_out.dim,
        ))

        graph_out.output(out_out.node, torch.Tensor)
        graph_right.output(out_right.node, torch.Tensor)
        # Short circut
        return (fx.GraphModule({}, graph_out, "tp_forward"),
                fx.GraphModule({}, graph_right, "tp_right"))

    # = Broadcast inputs =
    if shared_weights:
        x1s_out, x2s_out = x1s_out.broadcast_to(
            size_out + (-1, )), x2s_out.broadcast_to(size_out + (-1, ))
    else:
        x1s_out, x2s_out, ws_out = x1s_out.broadcast_to(
            size_out + (-1, )), x2s_out.broadcast_to(
                size_out + (-1, )), ws_out.broadcast_to(size_out + (-1, ))
        x2s_right, ws_right = x2s_right.broadcast_to(
            size_right + (-1, )), ws_right.broadcast_to(size_right + (-1, ))

    outsize_out = size_out + (irreps_out.dim, )
    outsize_right = size_right + (
        irreps_in1.dim,
        irreps_out.dim,
    )

    x1s_out = x1s_out.reshape(-1, irreps_in1.dim)
    x2s_out = x2s_out.reshape(-1, irreps_in2.dim)
    x2s_right = x2s_right.reshape(-1, irreps_in2.dim)

    batch_out = x1s_out.shape[0]
    batch_right = x2s_right.shape[0]

    # = Determine number of weights and reshape weights ==
    weight_numel = sum(
        prod(ins.path_shape) for ins in instructions if ins.has_weight)
    if weight_numel > 0:
        ws_out = ws_out.reshape(-1, weight_numel)
        ws_right = ws_right.reshape(-1, weight_numel)
    del weight_numel

    # = book-keeping for wigners =
    w3j = []
    w3j_dict_out = dict()
    w3j_dict_right = dict()

    # = extract individual input irreps =
    # If only one input irrep, can avoid creating a view
    if len(irreps_in1) == 1:
        x1_list_out = [
            x1s_out.reshape(batch_out, irreps_in1[0].mul, irreps_in1[0].ir.dim)
        ]
    else:
        x1_list_out = [
            x1s_out[:, i].reshape(batch_out, mul_ir.mul, mul_ir.ir.dim)
            for i, mul_ir in zip(irreps_in1.slices(), irreps_in1)
        ]

    x2_list_out = []
    x2_list_right = []
    # If only one input irrep, can avoid creating a view
    if len(irreps_in2) == 1:
        x2_list_out.append(
            x2s_out.reshape(batch_out, irreps_in2[0].mul,
                            irreps_in2[0].ir.dim))
        x2_list_right.append(
            x2s_right.reshape(batch_right, irreps_in2[0].mul,
                              irreps_in2[0].ir.dim))
    else:
        for i, mul_ir in zip(irreps_in2.slices(), irreps_in2):
            x2_list_out.append(x2s_out[:, i].reshape(batch_out, mul_ir.mul,
                                                     mul_ir.ir.dim))
            x2_list_right.append(x2s_right[:,
                                           i].reshape(batch_right, mul_ir.mul,
                                                      mul_ir.ir.dim))

    # The einsum string index to prepend to the weights if the weights are not shared and have a batch dimension
    z = '' if shared_weights else 'z'

    # Cache of input irrep pairs whose outer products (xx) have already been computed
    xx_dict = dict()

    # Current index in the flat weight tensor
    flat_weight_index = 0

    out_list_out = []
    out_list_right = []

    for ins in instructions:
        mul_ir_in1 = irreps_in1[ins.i_in1]
        mul_ir_in2 = irreps_in2[ins.i_in2]
        mul_ir_out = irreps_out[ins.i_out]

        assert mul_ir_in1.ir.p * mul_ir_in2.ir.p == mul_ir_out.ir.p
        assert abs(mul_ir_in1.ir.l - mul_ir_in2.ir.l
                   ) <= mul_ir_out.ir.l <= mul_ir_in1.ir.l + mul_ir_in2.ir.l

        if mul_ir_in1.dim == 0 or mul_ir_in2.dim == 0 or mul_ir_out.dim == 0:
            continue

        alpha = ins.path_weight * out_var[ins.i_out] / sum(
            in1_var[i.i_in1] * in2_var[i.i_in2]
            for i in instructions if i.i_out == ins.i_out)

        # Open the profiler block
        name = f"{mul_ir_in1} x {mul_ir_in2} = {mul_ir_out} {ins.connection_mode} {ins.has_weight}"
        handle_out = graph_out.call_function(
            torch.ops.profiler._record_function_enter, (name, ))
        handle_right = graph_right.call_function(
            torch.ops.profiler._record_function_enter, (name, ))

        x1_out = x1_list_out[ins.i_in1]
        x2_out = x2_list_out[ins.i_in2]
        x2_right = x2_list_right[ins.i_in2]

        e1_right = fx.Proxy(
            graph_right.call_function(
                torch.eye, (mul_ir_in1.mul, ),
                dict(dtype=x2s_right.dtype.node,
                     device=x2s_right.device.node)))
        e2_right = fx.Proxy(
            graph_right.call_function(
                torch.eye, (mul_ir_in2.mul, ),
                dict(dtype=x2s_right.dtype.node,
                     device=x2s_right.device.node)))
        i1_right = fx.Proxy(
            graph_right.call_function(
                torch.eye, (mul_ir_in1.ir.dim, ),
                dict(dtype=x2s_right.dtype.node,
                     device=x2s_right.device.node)))

        assert ins.connection_mode in [
            'uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv'
        ]

        alpha = sqrt(
            alpha / {
                'uvw': (mul_ir_in1.mul * mul_ir_in2.mul),
                'uvu': mul_ir_in2.mul,
                'uvv': mul_ir_in1.mul,
                'uuw': mul_ir_in1.mul,
                'uuu': 1,
                'uvuv': 1,
            }[ins.connection_mode])

        if ins.has_weight:
            # Extract the weight from the flattened weight tensor
            w_out = ws_out[:, flat_weight_index:flat_weight_index +
                           prod(ins.path_shape)].reshape((
                               () if shared_weights else (-1, )) +
                                                         tuple(ins.path_shape))
            w_right = ws_right[:, flat_weight_index:flat_weight_index +
                               prod(ins.path_shape)].reshape(
                                   (() if shared_weights else (-1, )) +
                                   tuple(ins.path_shape))
            flat_weight_index += prod(ins.path_shape)

        # Construct the general xx in case this instruction isn't specialized
        # If this isn't used, the dead code will get removed
        key = (ins.i_in1, ins.i_in2, ins.connection_mode[:2])
        if key not in xx_dict:
            if ins.connection_mode[:2] == 'uv':
                xx_dict[key] = torch.einsum('zui,zvj->zuvij', x1_out, x2_out)
            if ins.connection_mode[:2] == 'uu':
                xx_dict[key] = torch.einsum('zui,zuj->zuij', x1_out, x2_out)
        xx = xx_dict[key]

        # Create a proxy & request for the relevant wigner w3j
        # If not used (because of specialized code), will get removed later.
        key = (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l)
        if key not in w3j:
            w3j_dict_out[key] = fx.Proxy(
                graph_out.get_attr(f"_w3j_{key[0]}_{key[1]}_{key[2]}"))
            w3j_dict_right[key] = fx.Proxy(
                graph_right.get_attr(f"_w3j_{key[0]}_{key[1]}_{key[2]}"))
            w3j.append(key)
        w3j_out = w3j_dict_out[key]
        w3j_right = w3j_dict_right[key]

        exp = {'component': 1, 'norm': -1}[normalization]

        if ins.connection_mode == 'uvw':
            assert ins.has_weight
            if specialized_code and key == (0, 0, 0):
                ein_out = torch.einsum(
                    f"{z}uvw,zu,zv->zw", w_out,
                    x1_out.reshape(batch_out, mul_ir_in1.dim),
                    x2_out.reshape(batch_out, mul_ir_in2.dim))
                ein_right = torch.einsum(
                    f"{z}uvw,zv->zuw", w_right,
                    x2_right.reshape(batch_right, mul_ir_in2.dim))
            elif specialized_code and mul_ir_in1.ir.l == 0:
                ein_out = torch.einsum(
                    f"{z}uvw,zu,zvj->zwj", w_out,
                    x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                ein_right = torch.einsum(f"{z}uvw,zvi->zuwi", w_right,
                                         x2_right)
            elif specialized_code and mul_ir_in2.ir.l == 0:
                ein_out = torch.einsum(
                    f"{z}uvw,zui,zv->zwi", w_out, x1_out,
                    x2_out.reshape(batch_out, mul_ir_in2.dim))
                ein_right = torch.einsum(
                    f"{z}uvw,ij,zv->zuiwj", w_right, i1_right,
                    x2_right.reshape(batch_right, mul_ir_in2.dim))
            elif specialized_code and mul_ir_out.ir.l == 0:
                ein_out = torch.einsum(f"{z}uvw,zui,zvi->zw", w_out, x1_out,
                                       x2_out) / sqrt(mul_ir_in1.ir.dim)**exp
                ein_right = torch.einsum(f"{z}uvw,zvi->zuiw", w_right,
                                         x2_right) / sqrt(
                                             mul_ir_in1.ir.dim)**exp
            else:
                ein_out = torch.einsum(f"{z}uvw,ijk,zuvij->zwk", w_out,
                                       w3j_out, xx)
                ein_right = torch.einsum(f"{z}uvw,ijk,zvj->zuiwk", w_right,
                                         w3j_right, x2_right)
        if ins.connection_mode == 'uvu':
            assert mul_ir_in1.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zv->zu", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,uw,zv->zuw", w_right, e1_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zvj->zuj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum(f"{z}uv,uw,zvi->zuwi", w_right,
                                             e1_right, x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zui,zv->zui", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,ij,uw,zv->zuiwj", w_right, i1_right, e1_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}uv,zui,zvi->zu", w_out, x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum(f"{z}uv,uw,zvi->zuiw", w_right,
                                             e1_right, x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zuk", w_out,
                                           w3j_out, xx)
                    ein_right = torch.einsum(f"{z}uv,ijk,uw,zvj->zuiwk",
                                             w_right, w3j_right, e1_right,
                                             x2_right)
            else:
                # not so useful operation because v is summed
                ein_out = torch.einsum("ijk,zuvij->zuk", w3j_out, xx)
                ein_right = torch.einsum("ijk,uw,zvj->zuiwk", w3j_right,
                                         e1_right, x2_right)
        if ins.connection_mode == 'uvv':
            assert mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zv->zv", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,vw,zv->zuw", w_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zvj->zvj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum(f"{z}uv,vw,zvi->zuwi", w_right,
                                             e2_right, x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zui,zv->zvi", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,ij,vw,zv->zuiwj", w_right, i1_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}uv,zui,zvi->zv", w_out, x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum(f"{z}uv,vw,zvi->zuiw", w_right,
                                             e2_right, x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zvk", w_out,
                                           w3j_out, xx)
                    ein_right = torch.einsum(f"{z}uv,ijk,zvj->zuivk", w_right,
                                             w3j_right, x2_right)
            else:
                # not so useful operation because u is summed
                # only specialize out for this path
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        "zu,zv->zv", x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        "zu,zvj->zvj", x1_out.reshape(batch_out,
                                                      mul_ir_in1.dim), x2_out)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        "zui,zv->zvi", x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum("zui,zvi->zv", x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum("ijk,zuvij->zvk", w3j_out, xx)
                s2ones = fx.Proxy(
                    graph_right.call_function(
                        torch.ones, (mul_ir_in1.mul, ),
                        dict(device=x2_right.device.node,
                             dtype=x2_right.dtype.node)))
                ein_right = torch.einsum("u,ijk,zvj->zuivk", s2ones, w3j_right,
                                         x2_right)
        if ins.connection_mode == 'uuw':
            assert mul_ir_in1.mul == mul_ir_in2.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}uw,zu,zu->zw", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uw,zu,zuj->zwj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uw,zui,zu->zwi", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}uw,zui,zui->zw", w_out, x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}uw,ijk,zuij->zwk", w_out,
                                           w3j_out, xx)
                # TODO: specialize right()
                ein_right = torch.einsum(f"{z}uw,ijk,zuj->zuiwk", w_right,
                                         w3j_right, x2_right)
            else:
                # equivalent to tp(x, y, 'uuu').sum('u')
                assert mul_ir_out.mul == 1
                ein_out = torch.einsum("ijk,zuij->zk", w3j_out, xx)
                ein_right = torch.einsum("ijk,zuj->zuik", w3j_right, x2_right)
        if ins.connection_mode == 'uuu':
            assert mul_ir_in1.mul == mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}u,zu,zu->zu", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}u,uw,zu->zuw", w_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and key == (
                        1, 1, 1) and normalization == "component":
                    ein_out = torch.einsum(f"{z}u,zui->zui", w_out,
                                           torch.cross(x1_out, x2_out,
                                                       dim=2)) / sqrt(2)
                    # For cross product, use the general case right()
                    ein_right = torch.einsum(f"{z}u,ijk,uw,zuj->zuiwk",
                                             w_right, w3j_right, e1_right,
                                             x2_right)
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}u,zu,zuj->zuj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum(f"{z}u,uw,zui->zuwi", w_right,
                                             e2_right, x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}u,zui,zu->zui", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}u,ij,uw,zu->zuiwj", w_right, i1_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}u,zui,zui->zu", w_out, x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum(f"{z}u,uw,zui->zuiw", w_right,
                                             e2_right, x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}u,ijk,zuij->zuk", w_out,
                                           w3j_out, xx)
                    ein_right = torch.einsum(f"{z}u,ijk,uw,zuj->zuiwk",
                                             w_right, w3j_right, e1_right,
                                             x2_right)
            else:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        "zu,zu->zu", x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        "uw,zu->zuw", e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and key == (
                        1, 1, 1) and normalization == "component":
                    ein_out = torch.cross(x1_out, x2_out,
                                          dim=2) * (1.0 / sqrt(2))
                    # For cross product, use the general case right()
                    ein_right = torch.einsum("ijk,uw,zuj->zuiwk", w3j_right,
                                             e1_right, x2_right)
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        "zu,zuj->zuj", x1_out.reshape(batch_out,
                                                      mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum("uw,zui->zuwi", e2_right,
                                             x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        "zui,zu->zui", x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        "ij,uw,zu->zuiwj", i1_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum("zui,zui->zu", x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum("uw,zui->zuiw", e2_right,
                                             x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum("ijk,zuij->zuk", w3j_out, xx)
                    ein_right = torch.einsum("ijk,uw,zuj->zuiwk", w3j_right,
                                             e1_right, x2_right)
        if ins.connection_mode == 'uvuv':
            assert mul_ir_in1.mul * mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                # TODO implement specialized code
                ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zuvk", w_out,
                                       w3j_out, xx)
                ein_right = torch.einsum(f"{z}uv,ijk,uw,zvj->zuiwvk", w_right,
                                         w3j_right, e1_right, x2_right)
            else:
                # TODO implement specialized code
                ein_out = torch.einsum("ijk,zuvij->zuvk", w3j_out, xx)
                ein_right = torch.einsum("ijk,uw,zvj->zuiwvk", w3j_right,
                                         e1_right, x2_right)

        ein_out = alpha * ein_out
        ein_right = alpha * ein_right

        out_list_out += [ein_out.reshape(batch_out, mul_ir_out.dim)]
        out_list_right += [
            ein_right.reshape(batch_right, mul_ir_in1.dim, mul_ir_out.dim)
        ]

        # Close the profiler block
        graph_out.call_function(torch.ops.profiler._record_function_exit,
                                (handle_out, ))
        graph_right.call_function(torch.ops.profiler._record_function_exit,
                                  (handle_right, ))

        # Remove unused w3js:
        if len(w3j_out.node.users) == 0 and len(w3j_right.node.users) == 0:
            del w3j[-1]
            # The w3j nodes are reshapes, so we have to remove them from the graph
            # Although they are dead code, they try to reshape to dimensions that don't exist
            # (since the corresponding w3js are not in w3j)
            # so they screw up the shape propagation, even though they would be removed later as dead code by TorchScript.
            graph_out.erase_node(w3j_dict_out.pop(key).node)
            graph_right.erase_node(w3j_dict_right.pop(key).node)

    # = Return the result =
    out_out = [
        _sum_tensors([
            out for ins, out in zip(instructions, out_list_out)
            if ins.i_out == i_out
        ],
                     shape=(batch_out, mul_ir_out.dim),
                     like=x1s_out)
        for i_out, mul_ir_out in enumerate(irreps_out) if mul_ir_out.mul > 0
    ]
    if len(out_out) > 1:
        out_out = torch.cat(out_out, dim=1)
    else:
        # Avoid an unnecessary copy in a size one torch.cat
        out_out = out_out[0]

    out_right = [
        torch.cat([
            _sum_tensors([
                out for ins, out in zip(instructions, out_list_right)
                if (ins.i_in1, ins.i_out) == (i_in1, i_out)
            ],
                         shape=(batch_right, mul_ir_in1.dim, mul_ir_out.dim),
                         like=x2s_right)
            for i_out, mul_ir_out in enumerate(irreps_out)
            if mul_ir_out.mul > 0
        ],
                  dim=2) for i_in1, mul_ir_in1 in enumerate(irreps_in1)
        if mul_ir_in1.mul > 0
    ]
    if len(out_right) > 1:
        out_right = torch.cat(out_right, dim=1)
    else:
        out_right = out_right[0]

    out_out = out_out.reshape(outsize_out)
    out_right = out_right.reshape(outsize_right)

    graph_out.output(out_out.node, torch.Tensor)
    graph_right.output(out_right.node, torch.Tensor)

    # check graphs
    graph_out.lint()
    graph_right.lint()

    # Make GraphModules
    wigner_mats = {}
    for l_1, l_2, l_out in w3j:
        wig = o3.wigner_3j(l_1, l_2, l_out)

        if normalization == 'component':
            wig *= (2 * l_out + 1)**0.5
        if normalization == 'norm':
            wig *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5

        wigner_mats[f"_w3j_{l_1}_{l_2}_{l_out}"] = wig

    # By putting the constants in a Module rather than a dict,
    # we force FX to copy them as buffers instead of as attributes.
    #
    # FX seems to have resolved this issue for dicts in 1.9, but we support all the way back to 1.8.0.
    constants_root = torch.nn.Module()
    for wkey, wmat in wigner_mats.items():
        constants_root.register_buffer(wkey, wmat)
    graphmod_out = fx.GraphModule(constants_root,
                                  graph_out,
                                  class_name="tp_forward")
    graphmod_right = fx.GraphModule(constants_root,
                                    graph_right,
                                    class_name="tp_right")

    # == Optimize ==
    # TODO: when eliminate_dead_code() is in PyTorch stable, use that
    if optimize_einsums:
        # Note that for our einsums, we can optimize _once_ for _any_ batch dimension
        # and still get the right path for _all_ batch dimensions.
        # This is because our einsums are essentially of the form:
        #    zuvw,ijk,zuvij->zwk    OR     uvw,ijk,zuvij->zwk
        # In the first case, all but one operands have the batch dimension
        #    => The first contraction gains the batch dimension
        #    => All following contractions have batch dimension
        #    => All possible contraction paths have cost that scales linearly in batch size
        #    => The optimal path is the same for all batch sizes
        # For the second case, this logic follows as long as the first contraction is not between the first two operands. Since those two operands do not share any indexes, contracting them first is a rare pathological case. See
        # https://github.com/dgasmith/opt_einsum/issues/158
        # for more details.
        #
        # TODO: consider the impact maximum intermediate result size on this logic
        #         \- this is the `memory_limit` option in opt_einsum
        # TODO: allow user to choose opt_einsum parameters?
        #
        # We use float32 and zeros to save memory and time, since opt_einsum_fx looks only at traced shapes, not values or dtypes.
        batchdim = 4
        example_inputs = (
            torch.zeros((batchdim, irreps_in1.dim)),
            torch.zeros((batchdim, irreps_in2.dim)),
            torch.zeros(
                1 if shared_weights else batchdim,
                flat_weight_index,
            ),
        )
        graphmod_out = jitable(
            optimize_einsums_full(graphmod_out, example_inputs))
        graphmod_right = jitable(
            optimize_einsums_full(graphmod_right, example_inputs[1:]))

    return graphmod_out, graphmod_right
Beispiel #20
0
Datei: rs.py Projekt: zizai/e3nn
def elementwise_tensor_product(Rs_1, Rs_2, selection_rule=o3.selection_rule, normalization='component'):
    """
    Build the mixing tensor of the element-wise tensor product of two
    representation lists.

    Both lists are simplified and then split so that entries at the same
    position carry the same multiplicity. Each aligned pair of entries is
    coupled through ``selection_rule`` and the (scaled) Wigner 3j symbols.

    :param Rs_1: list of (mul, l, p) entries
    :param Rs_2: list of (mul, l, p) entries with the same total multiplicity as Rs_1
    :param selection_rule: callable (l_1, p_1, l_2, p_2) -> iterable of output l
    :param normalization: 'norm' or 'component'
    :return: Rs_out, matrix
             matrix has shape [dim(Rs_out), dim(Rs_1), dim(Rs_2)]

    m_kij A_i B_j
    """
    assert normalization in ['norm', 'component'], "normalization needs to be 'norm' or 'component'"

    Rs_1 = simplify(Rs_1)
    Rs_2 = simplify(Rs_2)

    assert sum(mul for mul, _, _ in Rs_1) == sum(mul for mul, _, _ in Rs_2)

    # Align the two lists: whenever the entries at the same position disagree
    # on multiplicity, cut the larger one and push the remainder to the next
    # position, where it is examined on the following iteration.
    pos = 0
    while pos < len(Rs_1):
        mul_1, l_1, p_1 = Rs_1[pos]
        mul_2, l_2, p_2 = Rs_2[pos]

        if mul_1 < mul_2:
            Rs_2[pos] = (mul_1, l_2, p_2)
            Rs_2.insert(pos + 1, (mul_2 - mul_1, l_2, p_2))
        elif mul_2 < mul_1:
            Rs_1[pos] = (mul_2, l_1, p_1)
            Rs_1.insert(pos + 1, (mul_1 - mul_2, l_1, p_1))
        pos += 1

    # Output representation: one entry per allowed l for every aligned pair.
    Rs_out = []
    for (mul, l_1, p_1), (mul_2, l_2, p_2) in zip(Rs_1, Rs_2):
        assert mul == mul_2
        Rs_out.extend((mul, l, p_1 * p_2) for l in selection_rule(l_1, p_1, l_2, p_2))

    Rs_out = simplify(Rs_out)

    wigner_3j_tensor = torch.zeros(dim(Rs_out), dim(Rs_1), dim(Rs_2))

    index_out = 0
    index_1 = 0
    index_2 = 0
    for (mul, l_1, p_1), (mul_2, l_2, p_2) in zip(Rs_1, Rs_2):
        assert mul == mul_2
        dim_1 = mul * (2 * l_1 + 1)
        dim_2 = mul * (2 * l_2 + 1)
        block_1 = slice(index_1, index_1 + dim_1)
        block_2 = slice(index_2, index_2 + dim_2)

        # diag[w, u, v] is one only where w == u == v: output copy w is fed by
        # copy w of each input and nothing else.
        diag = torch.einsum("uv,wu->wuv", torch.eye(mul), torch.eye(mul))

        for l_out in selection_rule(l_1, p_1, l_2, p_2):
            dim_out = mul * (2 * l_out + 1)
            C = o3.wigner_3j(l_out, l_1, l_2, cached=True)
            if normalization == 'component':
                C *= (2 * l_out + 1) ** 0.5
            elif normalization == 'norm':
                C *= (2 * l_1 + 1) ** 0.5 * (2 * l_2 + 1) ** 0.5

            m = torch.einsum("wuv,kij->wkuivj", diag, C).reshape(dim_out, dim_1, dim_2)
            wigner_3j_tensor[index_out:index_out + dim_out, block_1, block_2] = m
            index_out += dim_out

        index_1 += dim_1
        index_2 += dim_2

    return Rs_out, wigner_3j_tensor
Beispiel #21
0
"""
Generate the .cache files
"""
from e3nn import o3

lmax = 10

# Walk over every (l1, l2, l3) with l1, l2 <= lmax and l3 within the triangle
# bounds |l1 - l2| <= l3 <= l1 + l2, capped at lmax, and evaluate the Wigner 3j
# symbol for it. Per the module docstring, these calls are what generate the
# .cache files.
for l1 in range(lmax + 1):
    for l2 in range(lmax + 1):
        l3_first = abs(l1 - l2)
        l3_last = min(l1 + l2, lmax)
        for l3 in range(l3_first, l3_last + 1):
            print(l1, l2, l3)
            o3.wigner_3j(l1, l2, l3)
Beispiel #22
0
def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor:
    """ Get the Q^{d_out,d_in}_J matrices from equation (8).

    The Wigner 3j tensor for ``(J, d_in, d_out)`` is requested from
    ``o3.wigner_3j`` in float64 on ``device`` and returned with its three
    axes in reversed order (any caching is up to ``o3.wigner_3j``).
    """
    w3j = o3.wigner_3j(J, d_in, d_out, dtype=torch.float64, device=device)
    return w3j.permute(2, 1, 0)
Beispiel #23
0
    def backward(ctx, grad_kernel):
        """
        Compute the gradients of the kernel convolution with respect to F, Y and R.

        Only the gradients flagged in ``ctx.needs_input_grad[0]``, ``[1]`` and
        ``[2]`` are allocated and filled; the others stay ``None``. The loop
        structure walks every (output, input) pair of ``ctx.Rs_out`` and
        ``ctx.Rs_in`` and every filter order returned by ``ctx.selection_rule``
        for that pair, advancing ``begin_R`` over the matching slice of R.

        :param ctx: context providing ``saved_tensors`` (F, Y, R, norm_coef),
                    ``batch``, ``a``, ``b``, ``F_shape``, ``Y_shape``, ``R_shape``,
                    ``Rs_in``, ``Rs_out``, ``selection_rule`` and ``set_of_l_filters``
        :param grad_kernel: tensor, sliced below as [batch, a, l_out * mul_out * m_out]
        :return: (grad_F, grad_Y, grad_R, None, None, None, None, None)
        """
        F, Y, R, norm_coef = ctx.saved_tensors
        batch, a, b = ctx.batch, ctx.a, ctx.b

        grad_F = grad_Y = grad_R = None

        # Allocate zero-filled gradients only for the inputs that require one.
        if ctx.needs_input_grad[0]:
            grad_F = grad_kernel.new_zeros(
                *ctx.F_shape)  # [batch, b, l_in * mul_in * m_in]
        if ctx.needs_input_grad[1]:
            grad_Y = grad_kernel.new_zeros(
                *ctx.Y_shape)  # NOTE(review): indexed below as [batch, a, b, l_filter * m_filter]
        if ctx.needs_input_grad[2]:
            grad_R = grad_kernel.new_zeros(
                *ctx.R_shape
            )  # [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]

        # Running offset into the last axis of R / grad_R.
        begin_R = 0

        begin_out = 0
        for i, (mul_out, l_out, p_out) in enumerate(ctx.Rs_out):
            s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
            begin_out += mul_out * (2 * l_out + 1)

            begin_in = 0
            for j, (mul_in, l_in, p_in) in enumerate(ctx.Rs_in):
                s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
                begin_in += mul_in * (2 * l_in + 1)

                l_filters = ctx.selection_rule(l_in, p_in, l_out, p_out)
                if not l_filters:
                    # No filter couples this pair, so it owns no slice of R either.
                    continue

                n = mul_out * mul_in * len(l_filters)
                # Each temporary is built only when one of the gradients that
                # reads it is requested; the guards mirror the einsums below.
                if (grad_Y is not None) or (grad_F is not None):
                    sub_R = R[:, :, :, begin_R:begin_R + n].reshape(
                        batch, a, b, mul_out, mul_in,
                        -1)  # [batch, a, b, mul_out, mul_in, l_filter]
                if grad_R is not None:
                    sub_grad_R = grad_R[:, :, :, begin_R:begin_R + n].clone(
                    ).reshape(batch, a, b, mul_out, mul_in,
                              -1)  # [batch, a, b, mul_out, mul_in, l_filter]

                if grad_F is not None:
                    # Working copy of the current values; written back after the filter loop.
                    sub_grad_F = grad_F[:, :, s_in].clone().reshape(
                        batch, b, mul_in,
                        2 * l_in + 1)  # [batch, b, mul_in, 2 * l_in + 1]
                if (grad_Y is not None) or (grad_R is not None):
                    sub_F = F[..., s_in].reshape(batch, b, mul_in,
                                                 2 * l_in + 1)

                grad_K = grad_kernel[:, :,
                                     s_out].reshape(batch, a, mul_out,
                                                    2 * l_out + 1)

                for k, l_filter in enumerate(l_filters):
                    # Offset of l_filter's components on the last axis of Y:
                    # total size of all filter orders smaller than l_filter.
                    tmp = sum(2 * l + 1 for l in ctx.set_of_l_filters
                              if l < l_filter)
                    C = o3.wigner_3j(l_out,
                                     l_in,
                                     l_filter,
                                     cached=True,
                                     like=grad_kernel)  # [m_out, m_in, m]

                    if (grad_F is not None) or (grad_R is not None):
                        sub_Y = Y[:, :, :, tmp:tmp + 2 * l_filter +
                                  1]  # [batch, a, b, m]

                    if grad_F is not None:
                        sub_grad_F += norm_coef[i, j] * torch.einsum(
                            "zaui,ijk,zabk,zabuv->zbvj", grad_K, C, sub_Y,
                            sub_R[..., k])  # [batch, b, mul_in, 2 * l_in + 1]
                    if grad_Y is not None:
                        grad_Y[..., tmp:tmp + 2 * l_filter +
                               1] += norm_coef[i, j] * torch.einsum(
                                   "zaui,ijk,zabuv,zbvj->zabk", grad_K, C,
                                   sub_R[..., k], sub_F)  # [batch, a, b, m]
                    if grad_R is not None:
                        sub_grad_R[..., k] = norm_coef[i, j] * torch.einsum(
                            "zaui,ijk,zabk,zbvj->zabuv", grad_K, C, sub_Y,
                            sub_F)  # [batch, a, b, mul_out, mul_in]
                if grad_F is not None:
                    # Overwrite (not add): sub_grad_F started from a clone of this slice.
                    grad_F[:, :,
                           s_in] = sub_grad_F.reshape(batch, b,
                                                      mul_in * (2 * l_in + 1))
                if grad_R is not None:
                    grad_R[..., begin_R:begin_R + n] += sub_grad_R.reshape(
                        batch, a, b, -1)
                begin_R += n

        # One entry per forward argument; the trailing five get no gradient.
        return grad_F, grad_Y, grad_R, None, None, None, None, None
Beispiel #24
0
    def __init__(self,
                 Rs_in1,
                 Rs_in2,
                 Rs_out,
                 selection_rule=o3.selection_rule,
                 normalization='component',
                 groups=1):
        """
        Generate and compile the code of a grouped, weighted tensor product.

        For every pair of input entries and every group, source lines are
        emitted that slice both inputs, form their outer product and contract
        it with a slice of the weights and the matching Wigner 3j buffer into
        the output. The lines are substituted into ``_tensor_product_code`` and
        compiled with ``eval_code``.

        :param Rs_in1: representation of the first input (passed to ``rs.convention``)
        :param Rs_in2: representation of the second input (passed to ``rs.convention``)
        :param Rs_out: representation of the output (passed to ``rs.convention``)
        :param selection_rule: callable (l_1, p_1, l_2, p_2) -> collection of allowed l_out
        :param normalization: 'component' scales each Wigner 3j by sqrt(2 l_out + 1),
                              'norm' by sqrt((2 l_1 + 1)(2 l_2 + 1)); any other
                              value leaves it unscaled
        :param groups: number of groups every multiplicity is split into

        Sets ``self.Rs_in1``, ``self.Rs_in2``, ``self.Rs_out``,
        ``self.wigners_names``, ``self.main`` and ``self.nweight`` (total number
        of weights consumed), and registers one buffer per Wigner 3j used.
        """
        super().__init__()

        self.Rs_in1 = rs.convention(Rs_in1)
        self.Rs_in2 = rs.convention(Rs_in2)
        self.Rs_out = rs.convention(Rs_out)

        def split_mul(mul):
            # Spread `mul` copies over the groups as evenly as possible: the
            # first `mul % groups` groups receive one extra copy.
            return [mul // groups + (g < mul % groups) for g in range(groups)]

        code = ""

        index_w = 0
        wigners = set()
        # count[k]: number of (u, v) input pairs summed into output component k.
        count = [0 for _ in range(rs.dim(self.Rs_out))]

        index_1 = 0
        for mul_1, l_1, p_1 in self.Rs_in1:
            m_1 = 2 * l_1 + 1
            dim_1 = mul_1 * m_1
            gmul_1s = split_mul(mul_1)

            index_2 = 0
            for mul_2, l_2, p_2 in self.Rs_in2:
                m_2 = 2 * l_2 + 1
                dim_2 = mul_2 * m_2
                gmul_2s = split_mul(mul_2)

                for g in range(groups):
                    if gmul_1s[g] * gmul_2s[g] == 0:
                        continue

                    begin_1 = index_1 + sum(gmul_1s[:g]) * m_1
                    end_1 = index_1 + sum(gmul_1s[:g + 1]) * m_1
                    begin_2 = index_2 + sum(gmul_2s[:g]) * m_2
                    end_2 = index_2 + sum(gmul_2s[:g + 1]) * m_2

                    code += f"    s1 = x1[:, {begin_1}:{end_1}].reshape(batch, {gmul_1s[g]}, {m_1})\n"
                    code += f"    s2 = x2[:, {begin_2}:{end_2}].reshape(batch, {gmul_2s[g]}, {m_2})\n"
                    code += "    ss = ein('zui,zvj->zuvij', s1, s2)\n"

                    index_out = 0
                    for mul_out, l_out, p_out in self.Rs_out:
                        dim_out = mul_out * (2 * l_out + 1)

                        if l_out in selection_rule(l_1, p_1, l_2,
                                                   p_2) and p_out == p_1 * p_2:
                            wigners.add((l_out, l_1, l_2))

                            gmul_outs = split_mul(mul_out)

                            # Emit code only when this group owns at least one
                            # output copy. This must be a guard rather than a
                            # `continue`: `index_out` below has to advance for
                            # every output entry, otherwise all later outputs
                            # of this group land at the wrong offset.
                            if gmul_outs[g] != 0:
                                dim_w = gmul_outs[g] * gmul_1s[g] * gmul_2s[g]

                                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {gmul_outs[g]}, {gmul_1s[g]}, {gmul_2s[g]})\n"
                                i = index_out + sum(
                                    gmul_outs[:g]) * (2 * l_out + 1)
                                j = index_out + sum(
                                    gmul_outs[:g + 1]) * (2 * l_out + 1)
                                code += f"    out[:, {i}:{j}] += ein('zwuv,kij,zuvij->zwk', sw, C{l_out}_{l_1}_{l_2}, ss).reshape(batch, {gmul_outs[g]*(2*l_out+1)})\n"
                                code += "\n"

                                for k in range(i, j):
                                    count[k] += gmul_1s[g] * gmul_2s[g]

                                index_w += dim_w

                        index_out += dim_out

                index_2 += dim_2
            index_1 += dim_1

        # Normalise the output: each run of consecutive components sharing the
        # same count c > 1 is divided by sqrt(c).
        ilast = 0
        clast = count[0]
        for i, c in enumerate(count):
            if clast != c:
                if clast > 1:
                    code += f"    out[:, {ilast}:{i}].div_({clast ** 0.5})\n"
                clast = c
                ilast = i
        if clast > 1:
            code += f"    out[:, {ilast}:].div_({clast ** 0.5})\n"

        # Sorted so that argument order and buffer names are deterministic.
        wigners = sorted(wigners)
        self.wigners_names = [
            f"C{l_out}_{l_1}_{l_2}" for l_out, l_1, l_2 in wigners
        ]
        args = ", ".join(f"{arg}: torch.Tensor" for arg in self.wigners_names)

        for arg, (l_out, l_1, l_2) in zip(self.wigners_names, wigners):
            C = o3.wigner_3j(l_out, l_1, l_2)

            if normalization == 'component':
                C *= (2 * l_out + 1)**0.5
            if normalization == 'norm':
                C *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5

            self.register_buffer(arg, C)

        # Fill the code template with the output size, the Wigner arguments and
        # the generated body, then compile it.
        x = _tensor_product_code
        x = x.replace("DIM", f"{rs.dim(self.Rs_out)}")
        x = x.replace("ARGS", args)
        x = x.replace("CODE", code)

        self.main = eval_code(x).main
        self.nweight = index_w
Beispiel #25
0
def main():
    """
    Plot the kernel basis between two orders as a grid of 3D surfaces.

    For every l between |l_out - l_in| and l_out + l_in, the Wigner 3j tensor is
    contracted with the spherical harmonics of order l evaluated on the
    surface returned by ``spherical_surface``. Each (l, i, j) component is
    drawn on its own axes, coloured with the "bwr" colormap, and the figure is
    saved as ``kernels{l_in}to{l_out}.png``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--l_in", type=int, required=True)
    parser.add_argument("--l_out", type=int, required=True)
    parser.add_argument("--n",
                        type=int,
                        default=30,
                        help="size of the SOFT grid")
    parser.add_argument("--dpi", type=float, default=100)
    parser.add_argument("--sep",
                        type=float,
                        default=0.5,
                        help="space between matrices")

    args = parser.parse_args()

    torch.set_default_dtype(torch.float64)
    x, y, z, alpha, beta = spherical_surface(args.n)

    l_first = abs(args.l_out - args.l_in)
    l_last = args.l_out + args.l_in
    f = torch.stack([
        torch.einsum("ijk,k...->ij...",
                     (o3.wigner_3j(args.l_out, args.l_in, l),
                      rsh.spherical_harmonics(l, alpha, beta)))
        for l in range(l_first, l_last + 1)
    ])

    nf, dim_out, dim_in, *_ = f.size()

    # Map the values into [0, 1] for the colormap, with zero landing on 0.5.
    f = 0.5 + 0.5 * f / f.abs().max()

    # Figure width in cell units: nf matrices of dim_in cells plus the gaps.
    total_width = nf * dim_in + (nf - 1) * args.sep
    fig = plt.figure(figsize=(total_width, dim_out), dpi=args.dpi)

    # Size of one cell in figure-relative coordinates.
    width = 1 / total_width
    height = 1 / dim_out

    colormap = plt.get_cmap("bwr")
    xs, ys, zs = x.numpy(), y.numpy(), z.numpy()
    lim = 0.6

    for index in range(nf):
        # Left edge of this matrix, in cell units.
        left = index * (dim_in + args.sep)
        for i in range(dim_out):
            # Row 0 is drawn at the top of the figure.
            bottom = (dim_out - i - 1) * height
            for j in range(dim_in):
                rect = [(left + j) * width, bottom, width, height]
                ax = fig.add_axes(rect, projection='3d')

                fc = colormap(f[index, i, j].detach().cpu().numpy())

                ax.plot_surface(xs,
                                ys,
                                zs,
                                rstride=1,
                                cstride=1,
                                facecolors=fc)
                ax.set_axis_off()

                ax.set_xlim3d(-lim, lim)
                ax.set_ylim3d(-lim, lim)
                ax.set_zlim3d(-lim, lim)

                ax.view_init(90, 0)

    plt.savefig("kernels{}to{}.png".format(args.l_in, args.l_out),
                transparent=True)
Beispiel #26
0
def kernel_conv_fn_forward(F, edge_index, Y, R, norm_coef, Rs_in, Rs_out, selection_rule, set_of_l_filters):
    """
    Edge-based kernel convolution: for every edge, combine the features of the
    atom ``edge_index[1]`` with the edge's spherical harmonics and radial
    weights, then sum the result into the atom ``edge_index[0]``.

    :param F: tensor [n_atoms, l_in * mul_in * m_in]
    :param edge_index: indexable pair; ``edge_index[0]`` are the receiving atoms
                       and ``edge_index[1]`` the atoms whose features are read
    :param Y: tensor [n_edges, l_filter * m_filter]
    :param R: tensor [n_edges, l_out * l_in * mul_out * mul_in * l_filter]
    :param norm_coef: tensor [l_out, l_in]
    :param Rs_in: iterable of (mul_in, l_in, p_in)
    :param Rs_out: iterable of (mul_out, l_out, p_out)
    :param selection_rule: callable (l_in, p_in, l_out, p_out) -> allowed filter orders
    :param set_of_l_filters: all filter orders stored in Y, used to locate each
                             order on the last axis of Y
    :return: tensor [n_atoms, l_out * mul_out * m_out]
    """
    n_edges = Y.shape[-2]
    n_atoms = F.shape[-2]

    kernel_conv = Y.new_zeros(n_atoms, rs.dim(Rs_out))

    # note: for the normalization we assume that the variance of R[i] is one
    begin_R = 0

    begin_out = 0
    for i, (mul_out, l_out, p_out) in enumerate(Rs_out):
        size_out = mul_out * (2 * l_out + 1)
        s_out = slice(begin_out, begin_out + size_out)
        begin_out += size_out

        begin_in = 0
        for j, (mul_in, l_in, p_in) in enumerate(Rs_in):
            size_in = mul_in * (2 * l_in + 1)
            s_in = slice(begin_in, begin_in + size_in)
            begin_in += size_in

            l_filters = selection_rule(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            # slice of `R` that belongs to the couple (l_out, l_in)
            n = mul_out * mul_in * len(l_filters)
            sub_R = R[..., begin_R: begin_R + n].reshape(
                n_edges, mul_out, mul_in, -1
            )  # [n_edges, mul_out, mul_in, l_filter]
            begin_R += n

            # Features of the source atom of every edge; the same for all filters.
            edge_F = F[edge_index[1], s_in].reshape(n_edges, mul_in, -1)  # [n_edges, mul_in, J]

            # Einsum indices:
            #   i - output component
            #   j - feature component (summed)
            #   k - spherical harmonic component of the edge (summed)
            #   u - output multiplicity
            #   v - input multiplicity (summed)
            #   e - edge
            acc = 0
            for k, l_filter in enumerate(l_filters):
                offset = sum(2 * l + 1 for l in set_of_l_filters if l < l_filter)
                sub_Y = Y[..., offset: offset + 2 * l_filter + 1]  # [n_edges, m]

                C = o3.wigner_3j(l_out, l_in, l_filter, cached=True, like=kernel_conv)  # [m_out, m_in, m]

                per_edge = norm_coef[i, j] * torch.einsum(
                    "ijk,ek,euv,evj->eui", C, sub_Y, sub_R[..., k], edge_F
                )  # [n_edges, mul_out, I]
                # Sum the edge contributions into their receiving atoms.
                acc += scatter_add(per_edge, edge_index[0], 0, n_atoms)  # [n_atoms, mul_out, I]

            # `acc` is still the integer 0 if no filter contributed.
            if not isinstance(acc, int):
                kernel_conv[..., s_out] += acc.reshape(n_atoms, -1)

    return kernel_conv