예제 #1
0
def elementwise_tensor_product(Rs_1, Rs_2, get_l_output=o3.selection_rule):
    """
    Element-wise (channel-by-channel) tensor product of two representations.

    Channel ``u`` of ``Rs_1`` is only coupled with channel ``u`` of ``Rs_2``,
    hence both representations must have the same total multiplicity.

    :param Rs_1: representation of the first operand, list of ``(mul, l, p)``
    :param Rs_2: representation of the second operand, list of ``(mul, l, p)``
    :param get_l_output: callable ``(l_1, l_2) -> iterable of l`` giving the
        output orders produced by each pair of input orders
    :return: Rs_out, matrix

    ``matrix`` has shape ``[dim(Rs_out), dim(Rs_1), dim(Rs_2)]`` and is used
    as ``m_kij A_i B_j``.
    """
    # Private copies: the alignment loop below rewrites and inserts entries,
    # which must not leak into a list the caller may still hold.
    Rs_1 = list(simplify(Rs_1))
    Rs_2 = list(simplify(Rs_2))

    assert sum(mul for mul, _, _ in Rs_1) == sum(mul for mul, _, _ in Rs_2)

    # Split entries until Rs_1[i] and Rs_2[i] have the same multiplicity for
    # every i. Equal total multiplicities guarantee that Rs_2[i] exists
    # whenever Rs_1[i] does.
    i = 0
    while i < len(Rs_1):
        mul_1, l_1, p_1 = Rs_1[i]
        mul_2, l_2, p_2 = Rs_2[i]

        if mul_1 < mul_2:
            Rs_2[i] = (mul_1, l_2, p_2)
            Rs_2.insert(i + 1, (mul_2 - mul_1, l_2, p_2))
        elif mul_2 < mul_1:
            Rs_1[i] = (mul_2, l_1, p_1)
            Rs_1.insert(i + 1, (mul_1 - mul_2, l_1, p_1))
        i += 1

    Rs_out = []
    for (mul, l_1, p_1), (mul_2, l_2, p_2) in zip(Rs_1, Rs_2):
        assert mul == mul_2
        for l in get_l_output(l_1, l_2):
            Rs_out.append((mul, l, p_1 * p_2))

    Rs_out = simplify(Rs_out)

    clebsch_gordan_tensor = torch.zeros(dim(Rs_out), dim(Rs_1), dim(Rs_2))

    index_out = 0
    index_1 = 0
    index_2 = 0
    for (mul, l_1, _p_1), (mul_2, l_2, _p_2) in zip(Rs_1, Rs_2):
        assert mul == mul_2
        dim_1 = mul * (2 * l_1 + 1)
        dim_2 = mul * (2 * l_2 + 1)

        # I[w, u, v] = 1 iff w == u == v: output channel w only combines
        # channel w of both inputs. It depends on `mul` only, so build it once.
        I = torch.einsum("uv,wu->wuv", torch.eye(mul), torch.eye(mul))

        for l in get_l_output(l_1, l_2):
            dim_out = mul * (2 * l + 1)
            C = o3.clebsch_gordan(l, l_1, l_2, cached=True) * (2 * l + 1)**0.5
            m = torch.einsum("wuv,kij->wkuivj", I,
                             C).view(dim_out, dim_1, dim_2)
            clebsch_gordan_tensor[index_out:index_out + dim_out,
                                  index_1:index_1 + dim_1,
                                  index_2:index_2 + dim_2] = m
            index_out += dim_out

        index_1 += dim_1
        index_2 += dim_2

    return Rs_out, clebsch_gordan_tensor
예제 #2
0
def tensor_product(Rs_1, Rs_2, get_l_output=o3.selection_rule, paths=False):
    """
    Build the orthonormal change of basis Q from Rs_out to Rs_1 (x) Rs_2,
    where Rs_out is a direct sum of irreducible representations.

    :param Rs_1: first representation, list of ``(mul, l, p)``
    :param Rs_2: second representation, list of ``(mul, l, p)``
    :param get_l_output: callable ``(l_1, l_2) -> iterable of l``
    :param paths: when true, also return one ``[l_1, l_2, l]`` entry per
        output multiplicity channel
    :return: Rs_out, Q  (or Rs_out, Q, path_list when ``paths`` is true)

    example:
    _, Q = tensor_product(Rs1, Rs2)
    torch.einsum('kij,i,j->k', Q, A, B)
    """
    Rs_1 = simplify(Rs_1)
    Rs_2 = simplify(Rs_2)

    # Enumerate every (l_a, l_b) -> l coupling; each one carries mul_a * mul_b channels.
    Rs_out = []
    path_list = []
    for mul_a, l_a, p_a in Rs_1:
        for mul_b, l_b, p_b in Rs_2:
            n_channels = mul_a * mul_b
            for l in get_l_output(l_a, l_b):
                if paths:
                    path_list.extend([[l_a, l_b, l]] * n_channels)
                Rs_out.append((n_channels, l, p_a * p_b))

    Rs_out = simplify(Rs_out)

    Q = torch.zeros(dim(Rs_out), dim(Rs_1), dim(Rs_2))

    # Fill Q block by block; rows advance per coupling, columns per input entry.
    row = 0
    col_a = 0
    for mul_a, l_a, _ in Rs_1:
        size_a = mul_a * (2 * l_a + 1)

        col_b = 0
        for mul_b, l_b, _ in Rs_2:
            size_b = mul_b * (2 * l_b + 1)
            for l in get_l_output(l_a, l_b):
                size_out = mul_a * mul_b * (2 * l + 1)
                cg = o3.clebsch_gordan(l, l_a, l_b,
                                       cached=True) * (2 * l + 1)**0.5
                eye = torch.eye(mul_a * mul_b).view(mul_a * mul_b, mul_a, mul_b)
                block = torch.einsum("wuv,kij->wkuivj", eye,
                                     cg).view(size_out, size_a, size_b)
                Q[row:row + size_out,
                  col_a:col_a + size_a,
                  col_b:col_b + size_b] = block
                row += size_out

            col_b += size_b
        col_a += size_a

    if paths:
        return Rs_out, Q, path_list
    return Rs_out, Q
예제 #3
0
 def test_clebsch_gordan_orthogonal(self):
     """Rows of each Clebsch-Gordan tensor, scaled by (2 l_f + 1), are orthonormal."""
     with o3.torch_default_dtype(torch.float64):
         for l_out in range(6):
             for l_in in range(6):
                 l_min, l_max = abs(l_out - l_in), l_out + l_in
                 for l_f in range(l_min, l_max + 1):
                     n = 2 * l_f + 1
                     Q = o3.clebsch_gordan(l_f, l_in, l_out).view(n, -1)
                     gram = n * Q @ Q.t()
                     rms = (gram - torch.eye(n)).pow(2).mean().sqrt()
                     self.assertLess(rms, 1e-10)
예제 #4
0
 def test_clebsch_gordan_sh_norm(self):
     """Contracting a Clebsch-Gordan tensor with sqrt(4 pi) * Y_l_f gives a unit-norm matrix."""
     with o3.torch_default_dtype(torch.float64):
         scale = math.sqrt(4 * math.pi)
         for l_out in range(6):
             for l_in in range(6):
                 l_min, l_max = abs(l_out - l_in), l_out + l_in
                 for l_f in range(l_min, l_max + 1):
                     Q = o3.clebsch_gordan(l_out, l_in, l_f)
                     point = torch.randn(1, 3)
                     Y = o3.spherical_harmonics_xyz(l_f, point).view(2 * l_f + 1)
                     QY = scale * Q @ Y
                     self.assertLess(abs(QY.norm() - 1), 1e-10)
예제 #5
0
def main():
    """
    Plot the (l_in -> l_out) kernel basis on a sphere and save it as a PNG.

    For every filter order ``l`` allowed between ``--l_in`` and ``--l_out``,
    each entry of ``clebsch_gordan(l_out, l_in, l) . Y_l`` is drawn as a
    colored sphere. The image is written to ``kernels{l_in}to{l_out}.png``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--l_in", type=int, required=True)
    parser.add_argument("--l_out", type=int, required=True)
    parser.add_argument("--n", type=int, default=30, help="size of the SOFT grid")
    parser.add_argument("--dpi", type=float, default=100)
    parser.add_argument("--sep", type=float, default=0.5, help="space between matrices")

    args = parser.parse_args()

    torch.set_default_dtype(torch.float64)
    x, y, z, alpha, beta = spherical_surface(args.n)

    out = []
    for l in range(abs(args.l_out - args.l_in), args.l_out + args.l_in + 1):
        C = o3.clebsch_gordan(args.l_out, args.l_in, l)
        Y = o3.spherical_harmonics(l, alpha, beta)
        out.append(torch.einsum("ijk,k...->ij...", C, Y))
    f = torch.stack(out)

    nf, dim_out, dim_in, *_ = f.size()

    # Rescale into [0, 1] for the colormap, with 0.5 corresponding to zero.
    f = 0.5 + 0.5 * f / f.abs().max()

    total_width = nf * dim_in + (nf - 1) * args.sep
    fig = plt.figure(figsize=(total_width, dim_out), dpi=args.dpi)

    # Loop invariants: axis sizes (in figure fractions), colormap, surface
    # coordinates and the color values, converted to numpy once.
    width = 1 / total_width
    height = 1 / dim_out
    cmap = plt.get_cmap("bwr")
    xs, ys, zs = x.numpy(), y.numpy(), z.numpy()
    values = f.detach().cpu().numpy()
    a = 0.6

    for index in range(nf):
        for i in range(dim_out):
            for j in range(dim_in):
                rect = [
                    (index * (dim_in + args.sep) + j) * width,
                    (dim_out - i - 1) * height,
                    width,
                    height
                ]
                ax = fig.add_axes(rect, projection='3d')

                fc = cmap(values[index, i, j])

                ax.plot_surface(xs, ys, zs, rstride=1, cstride=1, facecolors=fc)
                ax.set_axis_off()

                ax.set_xlim3d(-a, a)
                ax.set_ylim3d(-a, a)
                ax.set_zlim3d(-a, a)

                ax.view_init(90, 0)

    # Save this specific figure and release it instead of leaving it open.
    fig.savefig("kernels{}to{}.png".format(args.l_in, args.l_out), transparent=True)
    plt.close(fig)
예제 #6
0
    def test1(self):
        """Check that a kernel built from clebsch_gordan is equivariant under irr_repr."""
        with torch_default_dtype(torch.float64):
            l_in, l_out = 3, 2

            for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
                r = torch.randn(100, 3)
                Q = o3.clebsch_gordan(l_out, l_in, l_f)

                abc = torch.randn(3)
                D_in = o3.irr_repr(l_in, *abc)
                D_out = o3.irr_repr(l_out, *abc)

                # Kernel at the rotated points, multiplied on the right by D_in.
                Y_rot = o3.spherical_harmonics_xyz(l_f, r @ o3.rot(*abc).t())
                W_rot = torch.einsum("ijk,kz->zij", (Q, Y_rot))
                W1 = torch.einsum("zij,jk->zik", (W_rot, D_in))

                # Kernel at the original points, multiplied on the left by D_out.
                Y = o3.spherical_harmonics_xyz(l_f, r)
                W = torch.einsum("ijk,kz->zij", (Q, Y))
                W2 = torch.einsum("ij,zjk->zik", (D_out, W))

                self.assertLess((W1 - W2).norm(), 1e-5 * W.norm(), l_f)
예제 #7
0
def kernel_conv_fn_forward(F, Y, R, norm_coef, Rs_in, Rs_out, get_l_filters,
                           set_of_l_filters):
    """
    Forward pass of the kernel convolution.

    :param F: tensor [batch, b, l_in * mul_in * m_in]
    :param Y: tensor [l_filter * m_filter, batch, a, b]; rows are grouped by
        increasing filter order, 2 * l + 1 rows per order in ``set_of_l_filters``
    :param R: tensor [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]
    :param norm_coef: tensor [l_out, l_in, batch, a, b]
    :param Rs_in: input representation, list of ``(mul, l, p)``
    :param Rs_out: output representation, list of ``(mul, l, p)``
    :param get_l_filters: callable ``(l_in, p_in, l_out, p_out) -> list of l_filter``
    :param set_of_l_filters: filter orders whose harmonics are stacked in ``Y``
    :return: tensor [batch, a, l_out * mul_out * m_out]
    """
    batch, a, b = Y.shape[1:]
    n_out = rs.dim(Rs_out)

    kernel_conv = Y.new_zeros(batch, a, n_out)

    # note: for the normalization we assume that the variance of R[i] is one
    begin_R = 0

    begin_out = 0
    for i, (mul_out, l_out, p_out) in enumerate(Rs_out):
        dim_out = mul_out * (2 * l_out + 1)
        s_out = slice(begin_out, begin_out + dim_out)
        begin_out += dim_out

        begin_in = 0
        for j, (mul_in, l_in, p_in) in enumerate(Rs_in):
            dim_in = mul_in * (2 * l_in + 1)
            s_in = slice(begin_in, begin_in + dim_in)
            begin_in += dim_in

            l_filters = get_l_filters(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            # extract the subset of the `R` that corresponds to the couple (l_out, l_in)
            n = mul_out * mul_in * len(l_filters)
            sub_R = R[:, :, :, begin_R:begin_R + n].contiguous().view(
                batch, a, b, mul_out, mul_in,
                -1)  # [batch, a, b, mul_out, mul_in, l_filter]
            begin_R += n

            sub_norm_coef = norm_coef[i, j]  # [batch, a, b]

            # Same input slice for every filter order: build the view once.
            sub_F = F[..., s_in].view(batch, b, mul_in, -1)  # [batch, b, mul_in, m_in]

            # Explicit accumulator instead of `K = 0` + `K is not 0`
            # (identity comparison with an int literal).
            K = None
            for k, l_filter in enumerate(l_filters):
                offset = sum(2 * l + 1 for l in set_of_l_filters
                             if l < l_filter)
                sub_Y = Y[offset:offset + 2 * l_filter + 1,
                          ...]  # [m, batch, a, b]

                C = o3.clebsch_gordan(l_out,
                                      l_in,
                                      l_filter,
                                      cached=True,
                                      like=kernel_conv)  # [m_out, m_in, m]

                term = torch.einsum("ijk,kzab,zabuv,zab,zbvj->zaui", C, sub_Y,
                                    sub_R[..., k], sub_norm_coef,
                                    sub_F)  # [batch, a, mul_out, m_out]
                K = term if K is None else K + term

            if K is not None:
                kernel_conv[:, :, s_out] += K.view(batch, a, -1)

    return kernel_conv
예제 #8
0
    def backward(ctx, grad_kernel):
        """
        Backward pass of the kernel convolution.

        Uses the tensors saved on ``ctx`` (F, Y, R, norm_coef) and fills the
        gradients of F, Y and R for the inputs flagged in
        ``ctx.needs_input_grad``. No gradient is computed for norm_coef.

        :param grad_kernel: gradient w.r.t. the output, sliced per output
            block as [batch, a, mul_out * (2 * l_out + 1)]
        :return: (grad_F, grad_Y, grad_R, None, None, None, None, None);
            presumably the trailing Nones match the forward's non-tensor
            arguments — confirm against the forward signature
        """
        F, Y, R, norm_coef = ctx.saved_tensors
        batch, a, b = ctx.batch, ctx.a, ctx.b

        grad_F = grad_Y = grad_R = None

        if ctx.needs_input_grad[0]:
            grad_F = grad_kernel.new_zeros(
                *ctx.F_shape)  # [batch, b, l_in * mul_in * m_in]
        if ctx.needs_input_grad[1]:
            grad_Y = grad_kernel.new_zeros(
                *ctx.Y_shape)  # [l_filter * m_filter, batch, a, b]
        if ctx.needs_input_grad[2]:
            grad_R = grad_kernel.new_zeros(
                *ctx.R_shape
            )  # [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]

        begin_R = 0

        begin_out = 0
        for i, (mul_out, l_out, p_out) in enumerate(ctx.Rs_out):
            dim_out = mul_out * (2 * l_out + 1)
            s_out = slice(begin_out, begin_out + dim_out)
            begin_out += dim_out

            begin_in = 0
            for j, (mul_in, l_in, p_in) in enumerate(ctx.Rs_in):
                dim_in = mul_in * (2 * l_in + 1)
                s_in = slice(begin_in, begin_in + dim_in)
                begin_in += dim_in

                l_filters = ctx.get_l_filters(l_in, p_in, l_out, p_out)
                if not l_filters:
                    continue

                n = mul_out * mul_in * len(l_filters)
                if (grad_Y is not None) or (grad_F is not None):
                    sub_R = R[:, :, :, begin_R:begin_R + n].contiguous().view(
                        batch, a, b, mul_out, mul_in,
                        -1)  # [batch, a, b, mul_out, mul_in, l_filter]

                # Fresh buffers, never views of grad_R / grad_F: when the slice
                # is already contiguous, `.contiguous()` returns the same storage,
                # and writing into it then doing `grad_R[...] += sub_grad_R`
                # would double the gradient.
                if grad_R is not None:
                    sub_grad_R = grad_kernel.new_zeros(
                        batch, a, b, mul_out, mul_in,
                        len(l_filters))  # [batch, a, b, mul_out, mul_in, l_filter]
                if grad_F is not None:
                    sub_grad_F = grad_kernel.new_zeros(
                        batch, b, mul_in,
                        2 * l_in + 1)  # [batch, b, mul_in, 2 * l_in + 1]
                if (grad_Y is not None) or (grad_R is not None):
                    sub_F = F[..., s_in].view(batch, b, mul_in, 2 * l_in + 1)

                grad_K = grad_kernel[:, :, s_out].view(batch, a, mul_out,
                                                       2 * l_out + 1)

                sub_norm_coef = norm_coef[i, j]  # [batch, a, b]

                for k, l_filter in enumerate(l_filters):
                    # Row offset of this filter order inside Y.
                    offset = sum(2 * l + 1 for l in ctx.set_of_l_filters
                                 if l < l_filter)
                    C = o3.clebsch_gordan(l_out,
                                          l_in,
                                          l_filter,
                                          cached=True,
                                          like=grad_kernel)  # [m_out, m_in, m]

                    if (grad_F is not None) or (grad_R is not None):
                        sub_Y = Y[offset:offset + 2 * l_filter + 1,
                                  ...]  # [m, batch, a, b]

                    if grad_F is not None:
                        sub_grad_F += torch.einsum(
                            "zaui,ijk,kzab,zabuv,zab->zbvj", grad_K, C, sub_Y,
                            sub_R[..., k],
                            sub_norm_coef)  # [batch, b, mul_in, 2 * l_in + 1]
                    if grad_Y is not None:
                        grad_Y[offset:offset + 2 * l_filter + 1,
                               ...] += torch.einsum(
                                   "zaui,ijk,zabuv,zab,zbvj->kzab", grad_K, C,
                                   sub_R[..., k], sub_norm_coef,
                                   sub_F)  # [m, batch, a, b]
                    if grad_R is not None:
                        sub_grad_R[..., k] = torch.einsum(
                            "zaui,ijk,kzab,zab,zbvj->zabuv", grad_K, C, sub_Y,
                            sub_norm_coef,
                            sub_F)  # [batch, a, b, mul_out, mul_in]

                # The same input block receives contributions from every output block.
                if grad_F is not None:
                    grad_F[:, :, s_in] += sub_grad_F.view(batch, b, dim_in)
                if grad_R is not None:
                    grad_R[..., begin_R:begin_R + n] += sub_grad_R.view(
                        batch, a, b, -1)
                begin_R += n

        return grad_F, grad_Y, grad_R, None, None, None, None, None