Example #1
0
    def __init__(self, repr_in, repr_out, radii, activation=(None, None)):
        '''
        Build a point convolution with optional scalar and gate activations.

        :param repr_in: input representation; entry ``l`` becomes the pair
            ``(repr_in[l], l)`` (presumably the multiplicity of order-``l``
            fields — confirm against SE3PointConvolution)
        :param repr_out: output representation, same layout as ``repr_in``
        :param radii: forwarded unchanged to ``SE3PointConvolution``
        :param activation: either a ``(scalar_activation, gate_activation)``
            tuple, or a single activation used for both; ``None`` disables
            the corresponding activation
        '''
        super().__init__()

        # A tuple (including tuple subclasses such as namedtuples) is read as
        # (scalar, gate); anything else is shared by both roles.
        if isinstance(activation, tuple):
            scalar_activation, gate_activation = activation
        else:
            scalar_activation, gate_activation = activation, activation

        self.repr_out = repr_out

        # Convert the per-order lists into (repr[l], l) pairs.
        Rs_in = [(m, l) for l, m in enumerate(repr_in)]
        Rs_out_with_gate = [(m, l) for l, m in enumerate(repr_out)]

        # Order-0 outputs get the scalar activation only if there are any.
        if scalar_activation is not None and repr_out[0] > 0:
            self.scalar_act = ScalarActivation(
                [(repr_out[0], scalar_activation)], bias=False)
        else:
            self.scalar_act = None

        # When gating is enabled, the convolution emits sum(repr_out[1:]) extra
        # order-0 channels (appended after the real outputs) to act as gates
        # for the non-scalar fields.
        num_non_scalar = sum(repr_out[1:])
        if gate_activation is not None and num_non_scalar > 0:
            Rs_out_with_gate.append((num_non_scalar, 0))
            self.gate_act = ScalarActivation(
                [(num_non_scalar, gate_activation)], bias=False)
        else:
            self.gate_act = None

        # The convolution is constructed (and thus initialised) in float64.
        with torch_default_dtype(torch.float64):
            self.conv = SE3PointConvolution(Rs_in,
                                            Rs_out_with_gate,
                                            radii=radii)
Example #2
0
def _sample_sh_cube(size, J, version=3):  # pylint: disable=W0613
    '''
    Evaluate the order-J spherical harmonics on a cubic grid.

    The grid has ``size`` points per axis with unit spacing, centred on the
    origin, i.e. coordinates run from -(size - 1) / 2 to (size - 1) / 2.
    No bandlimiting is applied; aliased regions must be removed by windowing.

    :param size: side length of the kernel (samples per axis)
    :param J: order of the spherical harmonics
    :param version: unused; kept for call-site compatibility
    :return: float64 tensor of shape [2 * J + 1, size, size, size]
    '''
    with torch_default_dtype(torch.float64):
        coords = torch.linspace(-((size - 1) / 2), (size - 1) / 2, steps=size)
        samples = torch.zeros(2 * J + 1, size, size, size, dtype=torch.float64)

        for i, x in enumerate(coords):
            for j, y in enumerate(coords):
                for k, z in enumerate(coords):
                    if not (x == y == z == 0):
                        # regular grid point: evaluate Y^J at its angles
                        alpha, beta = x_to_alpha_beta([x, y, z])
                        samples[:, i, j, k] = spherical_harmonics(J, alpha, beta)  # [m]
                    elif J == 0:
                        # angles are nan at the origin, but Y^0 does not depend
                        # on the angle, so any fixed angle gives the value
                        samples[:, i, j, k] = spherical_harmonics(0, 123, 321)  # [m]
                    # origin with J != 0: stays zero from the initialisation

    assert samples.dtype == torch.float64
    return samples  # [m, x, y, z]
Example #3
0
def _basis_transformation_Q_J(J, order_in, order_out, version=3):  # pylint: disable=W0613
    """
    Compute the order-J block of the change of basis for the tensor product
    of the output and input representations.

    Q_J is taken as the (unique) common null space of the Kronecker-form
    Sylvester equations
        kron(D_out(g), D_in(g)) @ Q_J - Q_J @ D_J(g) = 0
    over a fixed set of rotations g. It is then checked against a few random
    angle triples. That check lives inside an ``assert``, so it is skipped
    under ``python -O``.

    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param version: unused; kept for call-site compatibility
    :return: one part of the Q^-1 matrix of the article, a float64 tensor of
        shape [(2 * order_out + 1) * (2 * order_in + 1), 2 * J + 1]
    """
    with torch_default_dtype(torch.float64):
        def _product_repr(alpha, beta, gamma):
            # representation acting on the output (x) input product space
            return kron(irr_repr(order_out, alpha, beta, gamma),
                        irr_repr(order_in, alpha, beta, gamma))  # [m_out * m_in, m_out * m_in]

        def _sylvester_matrix(order, alpha, beta, gamma):
            ''' Kronecker-form Sylvester operator restricted to subspace `order` '''
            product = _product_repr(alpha, beta, gamma)  # [m_out * m_in, m_out * m_in]
            irrep = irr_repr(order, alpha, beta, gamma)  # [m, m]
            left = kron(product, torch.eye(irrep.size(0)))
            right = kron(torch.eye(product.size(0)), irrep.t())
            return left - right  # [(m_out * m_in) * m, (m_out * m_in) * m]

        # fixed, arbitrary rotation angles used to pin down the null space
        sample_angles = [
            [4.41301023, 5.56684102, 4.59384642],
            [4.93325116, 6.12697327, 4.14574096],
            [0.53878964, 4.09050444, 5.36539036],
            [2.16017393, 3.48835314, 5.55174441],
            [2.52385107, 0.2908958, 3.90040975]
        ]
        kernel = get_matrices_kernel([_sylvester_matrix(J, *angles) for angles in sample_angles])
        assert kernel.size(0) == 1, kernel.size()  # unique subspace solution

        dim_product = (2 * order_out + 1) * (2 * order_in + 1)
        Q_J = kernel[0].view(dim_product, 2 * J + 1)  # [m_out * m_in, m]

        # sanity check on random angles: Q_J intertwines the two representations
        assert all(
            torch.allclose(_product_repr(*angles) @ Q_J, Q_J @ irr_repr(J, *angles))
            for angles in (row.tolist() for row in torch.rand(4, 3))
        )

    assert Q_J.dtype == torch.float64
    return Q_J  # [m_out * m_in, m]
Example #4
0
def _test_basis_equivariance():
    '''
    Smoke test for cube basis kernels built with a Gaussian radial window.

    Calls ``check_basis_equivariance`` with three random angles. It asserts
    that every returned overlap exceeds 0.98, and the overlaps are included
    in the failure message.
    '''
    from functools import partial
    with torch_default_dtype(torch.float64):
        window = partial(gaussian_window, radii=[5], J_max_list=[999], sigma=2)
        basis = cube_basis_kernels(4 * 5, 2, 2, window)
        alpha, beta, gamma = torch.rand(3)
        overlaps = check_basis_equivariance(basis, 2, 2, alpha, beta, gamma)
        assert overlaps.gt(0.98).all(), overlaps