Пример #1
0
def tensor3x3_repr_basis_to_spherical_basis():
    """
    Change of basis from a 3x3 tensor to its 1 + 3 + 5 irreducible components.

    A 3x3 tensor, flattened to 9 components (in the ordering used by
    ``tensor3x3_repr``), transforms with ``tensor3x3_repr(a, b, c)``.
    The three returned matrices project it onto components transforming with
    ``irr_repr(0, a, b, c)``, ``irr_repr(1, a, b, c)`` and ``irr_repr(2, a, b, c)``:

        irr_repr(l, a, b, c) @ to_l == to_l @ tensor3x3_repr(a, b, c)

    This identity is asserted on 10 random angle triples per order every call
    (skipped under ``python -O``); see the asserts for usage.

    :return: tuple ``(to1, to3, to5)`` of shapes [1, 9], [3, 9] and [5, 9].
        The cast happens after the float64 context is left, so presumably
        they come back in the caller's default dtype.
    """
    def _is_equivariant(order, to):
        # ``to`` intertwines the representations iff
        # irr_repr(order) @ to == to @ tensor3x3_repr for every rotation;
        # sample 10 random (a, b, c) triples.
        return all(
            torch.allclose(irr_repr(order, a, b, c) @ to, to @ tensor3x3_repr(a, b, c))
            for a, b, c in torch.rand(10, 3)
        )

    with torch_default_dtype(torch.float64):
        # order 0: sum of the three diagonal entries (the trace)
        to1 = torch.tensor([
            [1, 0, 0, 0, 1, 0, 0, 0, 1],
        ], dtype=torch.get_default_dtype())
        assert _is_equivariant(0, to1)

        # order 1: antisymmetric combinations T_ij - T_ji of off-diagonal pairs
        to3 = torch.tensor([
            [0, 0, -1, 0, 0, 0, 1, 0, 0],
            [0, 1, 0, -1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, -1, 0],
        ], dtype=torch.get_default_dtype())
        assert _is_equivariant(1, to3)

        # order 2: symmetric, traceless combinations
        to5 = torch.tensor([
            [0, 1, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 1, 0],
            [-3**.5/3, 0, 0, 0, -3**.5/3, 0, 0, 0, 12**.5/3],
            [0, 0, 1, 0, 0, 0, 1, 0, 0],
            [1, 0, 0, 0, -1, 0, 0, 0, 0]
        ], dtype=torch.get_default_dtype())
        assert _is_equivariant(2, to5)

    return to1.type(torch.get_default_dtype()), to3.type(torch.get_default_dtype()), to5.type(torch.get_default_dtype())
Пример #2
0
def basis_transformation_Q_J(J, order_in, order_out, version=3):  # pylint: disable=W0613
    """
    Solve for the block Q_J that maps the order-J irrep into the tensor product
    of the output and input representations.

    Q_J is the (required to be unique) null-space vector of the stacked systems

        kron(D_out ⊗ D_in, I_m) - kron(I_{m_out * m_in}, D_J^T)

    built at five fixed Euler angle triples. Reshaped to [m_out * m_in, m],
    it satisfies (D_out ⊗ D_in) @ Q_J == Q_J @ D_J, which is re-checked with an
    assert on 4 random angle triples.

    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param version: not used in this body (hence the pylint W0613 suppression)
    :return: one part of the Q^-1 matrix of the article, float64 tensor of
        shape [m_out * m_in, m]
    """
    def _product_rep(alpha, beta, gamma):
        # representation acting on the tensor-product space: D_out ⊗ D_in
        return kron(irr_repr(order_out, alpha, beta, gamma), irr_repr(order_in, alpha, beta, gamma))

    def _constraint(alpha, beta, gamma):
        ''' linear system (at one rotation) whose kernel contains vec(Q_J) '''
        rep_product = _product_rep(alpha, beta, gamma)  # [m_out * m_in, m_out * m_in]
        rep_J = irr_repr(J, alpha, beta, gamma)  # [m, m]
        lhs = kron(rep_product, torch.eye(rep_J.size(0)))
        rhs = kron(torch.eye(rep_product.size(0)), rep_J.t())
        return lhs - rhs  # [(m_out * m_in) * m, (m_out * m_in) * m]

    # fixed (not freshly sampled) angles so the solved basis is reproducible
    fixed_angles = (
        (4.41301023, 5.56684102, 4.59384642),
        (4.93325116, 6.12697327, 4.14574096),
        (0.53878964, 4.09050444, 5.36539036),
        (2.16017393, 3.48835314, 5.55174441),
        (2.52385107, 0.2908958, 3.90040975),
    )

    with torch_default_dtype(torch.float64):
        kernel = get_matrices_kernel([_constraint(*angles) for angles in fixed_angles])
        assert kernel.size(0) == 1, kernel.size()  # the solution must be unique

        dim_in, dim_out, dim_J = 2 * order_in + 1, 2 * order_out + 1, 2 * J + 1
        Q_J = kernel[0].view(dim_out * dim_in, dim_J)  # [m_out * m_in, m]

        assert all(torch.allclose(
            _product_rep(x.item(), y.item(), z.item()) @ Q_J,
            Q_J @ irr_repr(J, x.item(), y.item(), z.item())) for x, y, z in torch.rand(4, 3)
        )

    assert Q_J.dtype == torch.float64
    return Q_J  # [m_out * m_in, m]
Пример #3
0
    def __init__(self,
                 num_classes,
                 num_radial=size // 2 + 1,
                 max_radius=size // 2):
        """
        Build four SE3PointConvolution layers followed by spatial averaging,
        dropout and a linear classifier.

        :param num_classes: output size of the final ``torch.nn.Linear``
        :param num_radial: number of evenly spaced radii in [0, max_radius]
            passed to each convolution as ``radii``
        :param max_radius: largest radius in that grid
        """
        super(SE3Net, self).__init__()

        # one tuple per layer boundary; entry k of a tuple is paired with index k
        features = [(1, ), (2, 2, 2, 1), (4, 4, 4, 4), (6, 4, 4, 0), (64, )]
        self.num_features = len(features)

        radii = torch.linspace(0, max_radius, steps=num_radial, dtype=torch.float64)
        kwargs = {'radii': radii}

        def _as_rs(counts):
            # [(count, index), ...] -- presumably (multiplicity, order l) as
            # SE3PointConvolution expects; verify against its signature
            return [(count, index) for index, count in enumerate(counts)]

        self.layers = torch.nn.ModuleList([
            SE3PointConvolution(_as_rs(feat_in), _as_rs(feat_out), **kwargs)
            for feat_in, feat_out in zip(features, features[1:])
        ])
        with torch_default_dtype(torch.float64):
            self.layers.extend([
                AvgSpacial(),
                torch.nn.Dropout(p=0.2),
                torch.nn.Linear(64, num_classes),
            ])
Пример #4
0
def _test_basis_equivariance():
    """
    Check equivariance of kernels from ``cube_basis_kernels(20, 2, 2, window)``
    with a gaussian radial window: every overlap reported by
    ``check_basis_equivariance`` at one random angle triple must exceed 0.98.
    """
    from functools import partial

    with torch_default_dtype(torch.float64):
        window = partial(gaussian_window, radii=[5], J_max_list=[999], sigma=2)
        basis = cube_basis_kernels(4 * 5, 2, 2, window)
        angles = torch.rand(3)
        overlaps = check_basis_equivariance(basis, 2, 2, *angles)
        assert overlaps.gt(0.98).all(), overlaps
Пример #5
0
def xyz_vector_basis_to_spherical_basis():
    """
    Change of basis from a cartesian vector to the order-1 spherical basis.

    The returned permutation matrix maps [x, y, z] to [y, z, x]. It satisfies
    ``irr_repr(1, a, b, c) @ A == A @ rot(a, b, c)``, which is asserted on 10
    random angle triples every call (skipped under ``python -O``); see the
    assert for usage.

    :return: [3, 3] tensor; the cast happens after the float64 context is left,
        so presumably it is in the caller's default dtype
    """
    with torch_default_dtype(torch.float64):
        change = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float64)
        assert all(
            torch.allclose(irr_repr(1, alpha, beta, gamma) @ change, change @ rot(alpha, beta, gamma))
            for alpha, beta, gamma in torch.rand(10, 3)
        )
    return change.type(torch.get_default_dtype())
Пример #6
0
    def __init__(self, num_classes, num_radial=size // 2 + 1, max_radius=size // 2):
        """
        Build four PointGatedBlock layers followed by spatial averaging,
        dropout and a linear classifier.

        :param num_classes: output size of the final ``torch.nn.Linear``
        :param num_radial: number of evenly spaced radii in [0, max_radius]
            passed to each block as ``radii``
        :param max_radius: largest radius in that grid
        """
        super(SE3Net, self).__init__()

        features = [(1,), (2, 2, 2, 1), (4, 4, 4, 4), (6, 4, 4, 0), (64,)]
        self.num_features = len(features)

        kwargs = dict(
            radii=torch.linspace(0, max_radius, steps=num_radial, dtype=torch.float64),
            activation=(torch.nn.functional.relu, torch.sigmoid),
        )

        # one gated block per pair of consecutive feature tuples
        gated_blocks = [
            PointGatedBlock(feat_in, feat_out, **kwargs)
            for feat_in, feat_out in zip(features[:-1], features[1:])
        ]
        self.layers = torch.nn.ModuleList(gated_blocks)
        with torch_default_dtype(torch.float64):
            head = [AvgSpacial(), torch.nn.Dropout(p=0.2), torch.nn.Linear(64, num_classes)]
            self.layers.extend(head)
Пример #7
0
def _test_change_basis_wigner_to_rot():
    """
    Check that conjugating lie_learn's order-1 Wigner D matrix by a fixed
    permutation matrix reproduces ``rot(a, b, c)``.

    Prints the max-abs difference and asserts it is below 1e-10.
    """
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix

    with torch_default_dtype(torch.float64):
        # maps [x, y, z] to [y, z, x]
        perm = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float64)

        alpha, beta, gamma = torch.rand(3)

        wigner = torch.tensor(wigner_D_matrix(1, alpha, beta, gamma), dtype=torch.float64)
        from_wigner = perm.t() @ wigner @ perm
        expected = rot(alpha, beta, gamma)

        max_err = (from_wigner - expected).abs().max()
        print(max_err.item())
        assert max_err < 1e-10
Пример #8
0
def test_is_representation(rep):
    """
    Check numerically that ``rep`` is multiplicative under composition:

        rep(Z(a1) Y(b1) Z(c1) Z(a2) Y(b2) Z(c2)) = rep(Z(a1) Y(b1) Z(c1)) rep(Z(a2) Y(b2) Z(c2))

    for one random pair of angle triples; the composed angles come from
    ``compose``. Prints the max-abs error and scale, and asserts the error is
    below 1e-10 times the scale.

    :param rep: callable taking angles (a, b, c) and returning a matrix tensor
    """
    with torch_default_dtype(torch.float64):
        alpha1, beta1, gamma1, alpha2, beta2, gamma2 = torch.rand(6)

        first = rep(alpha1, beta1, gamma1)
        second = rep(alpha2, beta2, gamma2)

        alpha, beta, gamma = compose(alpha1, beta1, gamma1, alpha2, beta2, gamma2)
        composed = rep(alpha, beta, gamma)

        product = first @ second

        error = (composed - product).abs().max()
        scale = composed.abs().max()
        print(error.item(), scale.item())
        assert error < 1e-10 * scale, error / scale
Пример #9
0
def spherical_harmonics_xyz(order, xyz):
    """
    spherical harmonics evaluated at points given in cartesian coordinates

    :param order: int, or list/tuple of ints (the orders J to evaluate)
    :param xyz: tensor of shape [A, 3]
    :return: tensor of shape [m, A] with m = sum(2 * J + 1 for J in order)
    """
    if isinstance(order, tuple):
        # a tuple used to be wrapped as [tuple] and fail downstream
        order = list(order)
    elif not isinstance(order, list):
        order = [order]

    with torch_default_dtype(torch.float64):
        alpha, beta = x_to_alpha_beta(xyz)  # two tensors of shape [A]
        out = spherical_harmonics(order, alpha, beta)  # [m, A]

        # At xyz = 0 the angles are meaningless, so overwrite those columns:
        # order 0 is evaluated at arbitrary angles (123, 321); higher orders are set to zero.
        is_origin = xyz.norm(2, -1) == 0  # [A], computed once and reused below
        if is_origin.nonzero().numel() > 0:  # this `if` is not needed with version 1.0 of pytorch
            val = torch.cat([spherical_harmonics(0, 123, 321) if J == 0 else torch.zeros(2 * J + 1) for J in order])  # [m]
            out[:, is_origin] = val.view(-1, 1)
        return out
Пример #10
0
def _test_spherical_harmonics(order):
    """
    Check that irr_repr, compose and spherical_harmonics are mutually consistent:

        Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x)
        with x = Z(a) Y(b) eta

    for one random point (a, b) and one random rotation (alpha, beta, gamma).
    Prints the max-abs error and scale, and asserts the error is below 1e-10
    times the scale.
    """
    with torch_default_dtype(torch.float64):
        a, b = torch.rand(2)
        alpha, beta, gamma = torch.rand(3)

        # angles of the rotated point; the third composed angle is discarded
        rotated_a, rotated_b, _ = compose(alpha, beta, gamma, a, b, 0)
        Y_of_rotated = spherical_harmonics(order, rotated_a, rotated_b)

        Y_of_x = spherical_harmonics(order, a, b)
        D_times_Y = irr_repr(order, alpha, beta, gamma) @ Y_of_x

        error = (Y_of_rotated - D_times_Y).abs().max()
        scale = Y_of_x.abs().max()
        print(error.item(), scale.item())
        assert error < 1e-10 * scale, error / scale