Exemplo n.º 1
0
def main():
    """Plot an equivariant kernel basis on spheres and save it as a PNG.

    Command-line arguments select the input/output representation orders,
    the SOFT grid size, the plot layout, and a rotation r given by the Euler
    angles (alpha, beta, gamma).

    The sampled kernel f(r^-1 x) is conjugated to
    rho_out(r) f(r^-1 x) rho_in(r^-1). For an equivariant kernel this equals
    the unrotated kernel. The result is rescaled jointly to [0, 1] and drawn
    as a grid of spheres: one group of dim_in columns per basis element and
    dim_out rows.

    The figure is written to ``kernels{order_in}{order_out}.png``.
    """
    parser = argparse.ArgumentParser()

    # required
    parser.add_argument("--order_in", type=int, required=True)
    parser.add_argument("--order_out", type=int, required=True)
    parser.add_argument("--n",
                        type=int,
                        default=50,
                        help="size of the SOFT grid")
    parser.add_argument("--scale",
                        type=float,
                        default=1.5,
                        help="plot size of a sphere")
    parser.add_argument("--sep",
                        type=float,
                        default=1,
                        help="plot separation size")
    parser.add_argument("--alpha", type=float, default=0)
    parser.add_argument("--beta", type=float, default=0)
    parser.add_argument("--gamma", type=float, default=0)

    args = parser.parse_args()

    f = _sample_sh_sphere(args.n, args.order_in, args.order_out, args.alpha,
                          args.beta, args.gamma)
    # f(r^-1 x)

    f = np.einsum(
        "ij,zjkba,kl->zilba",
        irr_repr(args.order_out, args.alpha, args.beta, args.gamma), f,
        irr_repr(args.order_in, -args.gamma, -args.beta, -args.alpha))
    # rho_out(r) f(r^-1 x) rho_in(r^-1)

    beta, alpha = beta_alpha(args.n)
    # shift alpha by half of pi / n (kept from the original plotting code)
    alpha = alpha - np.pi / (2 * args.n)

    nbase, dim_out, dim_in = f.shape[:3]

    # Rescale jointly into [0, 1] for plotting. A constant kernel would make
    # the range zero and turn the whole plot into NaNs, so map it to zeros.
    f_min = np.min(f)
    f_range = np.max(f) - f_min
    if f_range == 0:
        f = np.zeros_like(f)
    else:
        f = (f - f_min) / f_range

    # Layout in figure-relative units: nbase groups of dim_in columns,
    # consecutive groups separated by `sep` column widths. Every subplot has
    # the same size, so compute it once outside the loops.
    n_columns = nbase * dim_in + (nbase - 1) * args.sep
    width = 1 / n_columns
    height = 1 / dim_out

    fig = plt.figure(figsize=(args.scale * n_columns, args.scale * dim_out))

    for base in range(nbase):
        for i in range(dim_out):
            for j in range(dim_in):
                rect = [(base * (dim_in + args.sep) + j) * width,
                        (dim_out - i - 1) * height, width, height]
                ax = fig.add_axes(rect, projection='3d', aspect=1)
                ax.patch.set_visible(False)
                plot_sphere(beta, alpha, f[base, i, j])

    plt.savefig("kernels{}{}.png".format(args.order_in, args.order_out),
                transparent=True)
Exemplo n.º 2
0
def check_basis_equivariance(basis, order_in, order_out, alpha, beta, gamma):
    """Measure how equivariant each kernel of ``basis`` is under one rotation.

    Each kernel K is first normalised to unit Frobenius norm. The function
    then builds rho_out(u) K(u^-1 x) rho_in(u^-1), where u is the rotation
    with Euler angles (alpha, beta, gamma). The voxel grid is resampled with
    ``scipy.ndimage.affine_transform`` around the grid centre. The overlap
    (elementwise product, summed) between K and the transformed kernel is
    returned. If the transformed kernel equals K, the overlap is 1;
    interpolation and boundary effects lower it.

    :param basis: tensor of shape (n, 2*order_out+1, 2*order_in+1, size, size, size)
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param alpha: first Euler angle of the rotation
    :param beta: second Euler angle of the rotation
    :param gamma: third Euler angle of the rotation
    :return: 1-D tensor of length n, one overlap per basis kernel
    """
    from se3cnn import SO3
    from scipy.ndimage import affine_transform
    import numpy as np

    n = basis.size(0)
    dim_in = 2 * order_in + 1
    dim_out = 2 * order_out + 1
    size = basis.size(-1)
    assert basis.size() == (n, dim_out, dim_in, size, size, size), basis.size()

    # Normalise every kernel. Use reshape rather than view so that
    # non-contiguous inputs (e.g. a transposed or flipped basis) are accepted.
    basis = basis / basis.reshape(n, -1).norm(dim=1).view(-1, 1, 1, 1, 1, 1)

    x = basis.reshape(-1, size, size, size)
    y = torch.empty_like(x)

    # affine_transform maps each output voxel o to the input coordinate
    # invrot @ o + offset. This offset makes the rotation pivot on the grid
    # centre. It is the same for every channel, so compute it once.
    invrot = SO3.rot(-gamma, -beta, -alpha).numpy()
    center = (np.array(x.size()[1:]) - 1) / 2
    offset = center - np.dot(invrot, center)

    for k in range(y.size(0)):
        y[k] = torch.tensor(
            affine_transform(x[k].numpy(), matrix=invrot, offset=offset))

    y = y.view(*basis.size())  # K(u^-1 x)

    y = torch.einsum(
        "ij,bjkxyz,kl->bilxyz",
        (irr_repr(order_out, alpha, beta, gamma, dtype=y.dtype), y,
         irr_repr(order_in, -gamma, -beta, -alpha, dtype=y.dtype)))
    # rho_out(u) K(u^-1 x) rho_in(u^-1)

    # One overlap per kernel, as a single vectorised reduction.
    return (basis * y).reshape(n, -1).sum(dim=1)
Exemplo n.º 3
0
def cube_basis_kernels(size, order_in, order_out, radial_window):
    '''
    Build an equivariant kernel basis between capsules of order_in and order_out.

    The raw solutions sampled by _sample_cube are passed positionally, as
    (solutions, r_field, order_irreps), to ``radial_window``. The windowed
    basis it returns is sanity-checked against a quarter turn about the
    y axis.

    :param size: side length of the cubic filter kernel
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param radial_window: callable applied to (solutions, r_field, order_irreps);
        it returns the windowed basis or None
    :return: tensor of shape (N_basis, 2 * order_out + 1, 2 * order_in + 1, size, size, size),
        or None when radial_window returns None
    '''
    sampled = _sample_cube(size, order_in, order_out)
    kernel_basis = radial_window(*sampled)

    if kernel_basis is not None:
        # Invariant: rho_out(u) K(u^-1 x) rho_in(u^-1) == K(x),
        # with u a rotation by +pi/2 around the y axis.
        quarter_turn = math.pi / 2
        rotated = kernel_basis.transpose(3, 5).flip(5)  # K(u^-1 x)
        rho_out = irr_repr(order_out, 0, quarter_turn, 0, dtype=kernel_basis.dtype)
        rho_in_inv = irr_repr(order_in, 0, -quarter_turn, 0, dtype=kernel_basis.dtype)
        rotated = torch.einsum("ij,bjkxyz,kl->bilxyz", (rho_out, rotated, rho_in_inv))
        assert torch.allclose(rotated, kernel_basis)

    return kernel_basis
Exemplo n.º 4
0
def _sample_cube(size, order_in, order_out):
    '''
    Sample the equivariant kernel solution for every admissible J on a cubic grid.

    :param size: side length of the kernel
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :return: tuple (solutions, r_field, order_irreps) where
        - order_irreps lists J = |order_in - order_out|, ..., order_in + order_out;
        - solutions[i] is the kernel for J = order_irreps[i], of shape
          (2*order_out+1, 2*order_in+1, size, size, size);
        - r_field has shape (size, size, size) and holds each voxel's distance
          from the grid centre, in voxel units.
    '''
    dim_in = 2 * order_in + 1
    dim_out = 2 * order_out + 1
    quarter_turn = math.pi / 2

    order_irreps = list(range(abs(order_in - order_out), order_in + order_out + 1))

    solutions = []
    for J in order_irreps:
        harmonics = _sample_sh_cube(size, J)  # [m, x, y, z]
        change_of_basis = _basis_transformation_Q_J(J, order_in, order_out)  # [m_out * m_in, m]

        # [m_out * m_in, x, y, z] -> [m_out, m_in, x, y, z]
        kernel = torch.einsum('mn,nxyz->mxyz', (change_of_basis, harmonics)).view(
            dim_out, dim_in, size, size, size)
        solutions.append(kernel)

        # Invariant: rho_out(u) K(u^-1 x) rho_in(u^-1) == K(x),
        # with u a rotation by +pi/2 around the y axis.
        rotated = kernel.transpose(2, 4).flip(4)  # K(u^-1 x)
        rho_out = irr_repr(order_out, 0, quarter_turn, 0, dtype=kernel.dtype)
        rho_in_inv = irr_repr(order_in, 0, -quarter_turn, 0, dtype=kernel.dtype)
        rotated = torch.einsum("ij,jkxyz,kl->ilxyz", (rho_out, rotated, rho_in_inv))
        assert torch.allclose(rotated, kernel)

    # Symmetric 1-D coordinates centred on 0, combined into a radial distance field.
    half_extent = (size - 1) / 2
    coords = torch.linspace(-half_extent, half_extent, steps=size, dtype=torch.float64)
    r_field = (coords.view(-1, 1, 1) ** 2
               + coords.view(1, -1, 1) ** 2
               + coords.view(1, 1, -1) ** 2).sqrt()  # [x, y, z]

    return solutions, r_field, order_irreps
Exemplo n.º 5
0
def _basis_transformation_Q_J(J, order_in, order_out, version=3):  # pylint: disable=W0613
    """
    Compute the change of basis Q_J from the order-J irrep into the
    order_out x order_in tensor-product representation.

    Q_J is taken from the null space shared by the Sylvester-type systems

        kron(R_tensor(g), I) - kron(I, D_J(g)^T)

    evaluated at five fixed rotations g. Here
    R_tensor(g) = kron(irr_repr(order_out, g), irr_repr(order_in, g)) and
    D_J(g) = irr_repr(J, g), so the solution satisfies
    R_tensor(g) @ Q_J == Q_J @ D_J(g). The null space is asserted to be
    one-dimensional. The relation is then spot-checked at four rotations
    whose angles come from torch.rand.

    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param version: unused; kept so existing call sites keep working
    :return: one part of the Q^-1 matrix of the article, a float64 tensor of shape
        [(2 * order_out + 1) * (2 * order_in + 1), 2 * J + 1]
    """
    n_rows = (2 * order_out + 1) * (2 * order_in + 1)  # m_out * m_in
    n_cols = 2 * J + 1  # m

    with torch_default_dtype(torch.float64):
        def _R_tensor(a, b, c):
            # rotation acting on the output (x) input tensor-product space
            rho_out = irr_repr(order_out, a, b, c)
            rho_in = irr_repr(order_in, a, b, c)
            return kron(rho_out, rho_in)

        def _sylvester_submatrix(J, a, b, c):
            ''' generate Kronecker product matrix for solving the Sylvester equation in subspace J '''
            tensor_rep = _R_tensor(a, b, c)  # [m_out * m_in, m_out * m_in]
            irrep = irr_repr(J, a, b, c)  # [m, m]
            lhs = kron(tensor_rep, torch.eye(irrep.size(0)))
            rhs = kron(torch.eye(tensor_rep.size(0)), irrep.t())
            return lhs - rhs  # [(m_out * m_in) * m, (m_out * m_in) * m]

        # Hard-coded Euler angles, so the solved subspace is deterministic.
        fixed_angles = (
            (4.41301023, 5.56684102, 4.59384642),
            (4.93325116, 6.12697327, 4.14574096),
            (0.53878964, 4.09050444, 5.36539036),
            (2.16017393, 3.48835314, 5.55174441),
            (2.52385107, 0.2908958, 3.90040975),
        )
        systems = [_sylvester_submatrix(J, a, b, c) for a, b, c in fixed_angles]
        null_space = get_matrices_kernel(systems)
        assert null_space.size(0) == 1, null_space.size()  # unique subspace solution

        Q_J = null_space[0].view(n_rows, n_cols)  # [m_out * m_in, m]

        # Spot check with angles drawn from torch.rand (in [0, 1)). torch.rand
        # stays inside the assert so no random numbers are drawn under `python -O`.
        assert all(
            torch.allclose(_R_tensor(a.item(), b.item(), c.item()) @ Q_J,
                           Q_J @ irr_repr(J, a.item(), b.item(), c.item()))
            for a, b, c in torch.rand(4, 3)
        )

    assert Q_J.dtype == torch.float64
    return Q_J  # [m_out * m_in, m]
Exemplo n.º 6
0
 def transform(x):
     """Voxelise ``x`` and compute two direction vectors under its orientation.

     Returns (voxels, (top, front)):
     - voxels is the converter output as a float32 tensor with a new leading axis.
     - top is rot @ [0, 0, 1] and front is rot @ [0, 1, 0].
     Here rot = irr_repr(1, *angles) @ xyz_vector_basis_to_spherical_basis(),
     using the angles returned by Obj2OrientedVoxel(32).
     """
     voxelizer = Obj2OrientedVoxel(32)
     voxels, angles = voxelizer(x)
     voxels = torch.from_numpy(voxels.astype(np.float32)).unsqueeze(0)

     rotation = irr_repr(1, *angles) @ xyz_vector_basis_to_spherical_basis()
     unit_z = torch.tensor([0, 0, 1.])
     unit_y = torch.tensor([0, 1., 0])
     return voxels, (rotation @ unit_z, rotation @ unit_y)
Exemplo n.º 7
0
def rotate(x, alpha, beta, gamma):
    """Return a copy of voxel field ``x`` rotated by Euler angles (alpha, beta, gamma).

    A 4-D input is treated as a stack of channels, each resampled independently
    with ``rotate_scalar``. If it has exactly 3 channels, those channels are
    additionally mixed with irr_repr(1, alpha, beta, gamma), i.e. the field is
    treated as a vector field. Any other input is resampled as a single field.

    The input tensor is never modified. The result is a new tensor with x's
    dtype and device.
    """
    t = time_logging.start()
    # For a CPU tensor, .numpy() shares memory with `x`. Without the copy, the
    # per-channel in-place writes below would overwrite the caller's tensor.
    y = x.detach().cpu().numpy().copy()
    R = rot(alpha, beta, gamma)
    if x.ndimension() == 4:
        # rotate each channel independently as a scalar field
        for i in range(y.shape[0]):
            y[i] = rotate_scalar(y[i], R)
    else:
        y = rotate_scalar(y, R)
    out = x.new_tensor(y)
    if out.ndimension() == 4 and out.size(0) == 3:
        # Three channels are the components of a vector field. Besides moving
        # the voxels, the components themselves must be rotated.
        rep = irr_repr(1, alpha, beta, gamma, out.dtype).to(out.device)
        out = torch.einsum("ij,jxyz->ixyz", (rep, out))
    time_logging.end("rotate", t)
    return out
Exemplo n.º 8
0
 def _sylvester_submatrix(J, a, b, c):
     '''
     Build the Sylvester-equation matrix for subspace J at rotation (a, b, c):
     kron(R_tensor, I_m) - kron(I_{m_out * m_in}, D_J^T), where R_tensor = _R_tensor(a, b, c)
     and D_J = irr_repr(J, a, b, c).
     Shape: [(m_out * m_in) * m, (m_out * m_in) * m].
     '''
     tensor_rep = _R_tensor(a, b, c)  # [m_out * m_in, m_out * m_in]
     irrep = irr_repr(J, a, b, c)  # [m, m]
     lhs = kron(tensor_rep, torch.eye(irrep.size(0)))
     rhs = kron(torch.eye(tensor_rep.size(0)), irrep.t())
     return lhs - rhs
Exemplo n.º 9
0
 def _R_tensor(a, b, c):
     """Return kron(rho_out, rho_in) for the rotation with Euler angles (a, b, c)."""
     rho_out = irr_repr(order_out, a, b, c)
     rho_in = irr_repr(order_in, a, b, c)
     return kron(rho_out, rho_in)
Exemplo n.º 10
0
        def _R_tensor(a, b, c):
            """Return kron(rho_out, rho_in) for the rotation with Euler angles (a, b, c)."""
            rho_out = irr_repr(order_out, a, b, c)
            rho_in = irr_repr(order_in, a, b, c)
            return kron(rho_out, rho_in)

        def _sylvester_submatrix(J, a, b, c):