def tensor3x3_repr_basis_to_spherical_basis():
    """
    Change-of-basis matrices from a flattened 3x3 tensor to its irreducible parts.

    A 3x3 tensor (flattened to 9 components) transforming with
    tensor3x3_repr(a, b, c) splits into 1 + 3 + 5 components that transform
    with irr_repr(0, a, b, c), irr_repr(1, a, b, c) and irr_repr(2, a, b, c)
    respectively.

    Each returned matrix ``to_l`` satisfies
    ``irr_repr(l, a, b, c) @ to_l == to_l @ tensor3x3_repr(a, b, c)``;
    this relation is asserted on random angles (see the assert for usage).

    :return: tuple ``(to1, to3, to5)`` of shapes [1, 9], [3, 9] and [5, 9],
        built in float64 and cast to the caller's default dtype
    """
    with torch_default_dtype(torch.float64):
        dtype = torch.get_default_dtype()  # float64 inside this context

        # l = 0: the trace (xx + yy + zz)
        to1 = torch.tensor([
            [1, 0, 0, 0, 1, 0, 0, 0, 1],
        ], dtype=dtype)

        # l = 1: antisymmetric combinations (T_ij - T_ji)
        to3 = torch.tensor([
            [0, 0, -1, 0, 0, 0, 1, 0, 0],
            [0, 1, 0, -1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, -1, 0],
        ], dtype=dtype)

        # l = 2: symmetric traceless combinations
        to5 = torch.tensor([
            [0, 1, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 1, 0],
            [-3**.5/3, 0, 0, 0, -3**.5/3, 0, 0, 0, 12**.5/3],
            [0, 0, 1, 0, 0, 0, 1, 0, 0],
            [1, 0, 0, 0, -1, 0, 0, 0, 0]
        ], dtype=dtype)

        # Self-check the intertwining relation for each order, in order l = 0, 1, 2.
        # torch.rand stays inside the assert so nothing is sampled under `python -O`.
        for order, to in ((0, to1), (1, to3), (2, to5)):
            assert all(
                torch.allclose(irr_repr(order, a, b, c) @ to, to @ tensor3x3_repr(a, b, c))
                for a, b, c in torch.rand(10, 3)
            )

    # cast from float64 to whatever default dtype the caller uses
    return to1.type(torch.get_default_dtype()), to3.type(torch.get_default_dtype()), to5.type(torch.get_default_dtype())
def basis_transformation_Q_J(J, order_in, order_out, version=3):  # pylint: disable=W0613
    """
    Solve the Sylvester equation for one J-block of the change of basis.

    Looks for Q_J such that, for every rotation (a, b, c),

        kron(irr_repr(order_out), irr_repr(order_in)) @ Q_J == Q_J @ irr_repr(J)

    by stacking the vectorized equation for five fixed rotations and taking the
    kernel of the stacked operators via get_matrices_kernel. The kernel must
    have exactly one element along dim 0 (presumably a one-dimensional
    solution space); the result is re-checked on four random rotations.

    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param version: unused (kept for interface compatibility)
    :return: one part of the Q^-1 matrix of the article, a float64 tensor of
        shape [m_out * m_in, m] with m_out = 2 order_out + 1,
        m_in = 2 order_in + 1 and m = 2 J + 1
    """
    with torch_default_dtype(torch.float64):
        def product_repr(alpha, beta, gamma):
            # representation on the tensor-product space, [m_out * m_in, m_out * m_in]
            return kron(irr_repr(order_out, alpha, beta, gamma), irr_repr(order_in, alpha, beta, gamma))

        def sylvester_operator(alpha, beta, gamma):
            '''Kronecker-product form of the Sylvester equation restricted to subspace J.'''
            rep_product = product_repr(alpha, beta, gamma)  # [m_out * m_in, m_out * m_in]
            rep_J = irr_repr(J, alpha, beta, gamma)  # [m, m]
            left = kron(rep_product, torch.eye(rep_J.size(0)))
            right = kron(torch.eye(rep_product.size(0)), rep_J.t())
            return left - right  # [(m_out * m_in) * m, (m_out * m_in) * m]

        # hard-coded angles keep the kernel computation deterministic
        fixed_angles = [
            (4.41301023, 5.56684102, 4.59384642),
            (4.93325116, 6.12697327, 4.14574096),
            (0.53878964, 4.09050444, 5.36539036),
            (2.16017393, 3.48835314, 5.55174441),
            (2.52385107, 0.2908958, 3.90040975),
        ]
        operators = [sylvester_operator(*angles) for angles in fixed_angles]
        kernel = get_matrices_kernel(operators)
        assert kernel.size(0) == 1, kernel.size()  # unique subspace solution

        m_product = (2 * order_out + 1) * (2 * order_in + 1)
        Q_J = kernel[0].view(m_product, 2 * J + 1)  # [(m_out * m_in) * m] -> [m_out * m_in, m]

        # sanity check on fresh random angles (passed as python floats)
        assert all(
            torch.allclose(product_repr(*angles) @ Q_J, Q_J @ irr_repr(J, *angles))
            for angles in (tuple(x.item() for x in row) for row in torch.rand(4, 3))
        )

    assert Q_J.dtype == torch.float64
    return Q_J  # [m_out * m_in, m]
def __init__(self, num_classes, num_radial=size // 2 + 1, max_radius=size // 2):
    """
    Build a stack of SE3PointConvolution layers followed by a pooling/linear head.

    :param num_classes: output size of the final torch.nn.Linear
    :param num_radial: number of points in the ``radii`` grid given to each convolution
    :param max_radius: last value of the ``radii`` grid (grid starts at 0)
    """
    super(SE3Net, self).__init__()

    # one tuple per layer boundary; the position inside a tuple is the order index
    features = [(1, ), (2, 2, 2, 1), (4, 4, 4, 4), (6, 4, 4, 0), (64, )]
    self.num_features = len(features)

    radii = torch.linspace(0, max_radius, steps=num_radial, dtype=torch.float64)

    def as_rs(counts):
        # [(count, order), ...] -- presumably the (multiplicity, l) pairs SE3PointConvolution expects
        return [(count, order) for order, count in enumerate(counts)]

    self.layers = torch.nn.ModuleList([
        SE3PointConvolution(as_rs(feat_in), as_rs(feat_out), radii=radii)
        for feat_in, feat_out in zip(features[:-1], features[1:])
    ])

    with torch_default_dtype(torch.float64):
        pooling = AvgSpacial()
        dropout = torch.nn.Dropout(p=0.2)
        classifier = torch.nn.Linear(64, num_classes)
        self.layers.extend([pooling, dropout, classifier])
def _test_basis_equivariance():
    """
    Build cube basis kernels (size 20, orders 2 -> 2, gaussian window) and
    assert their equivariance overlaps under one random rotation all exceed 0.98.
    """
    from functools import partial

    with torch_default_dtype(torch.float64):
        window = partial(gaussian_window, radii=[5], J_max_list=[999], sigma=2)
        basis = cube_basis_kernels(4 * 5, 2, 2, window)
        alpha, beta, gamma = torch.rand(3)
        overlaps = check_basis_equivariance(basis, 2, 2, alpha, beta, gamma)
        assert overlaps.gt(0.98).all(), overlaps
def xyz_vector_basis_to_spherical_basis():
    """
    Permutation matrix mapping a Cartesian vector into the irr_repr(1) basis.

    A vector [x, y, z] transforming with rot(a, b, c) becomes, after
    left-multiplication by the returned matrix, [y, z, x], which transforms
    with irr_repr(1, a, b, c). The relation
    irr_repr(1, a, b, c) @ A == A @ rot(a, b, c) is asserted on random angles.

    :return: [3, 3] matrix cast to the caller's default dtype
    """
    with torch_default_dtype(torch.float64):
        # rows pick y, z, x respectively
        perm = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float64)
        assert all(
            torch.allclose(irr_repr(1, a, b, c) @ perm, perm @ rot(a, b, c))
            for a, b, c in torch.rand(10, 3)
        )
    return perm.type(torch.get_default_dtype())
def __init__(self, num_classes, num_radial=size // 2 + 1, max_radius=size // 2):
    """
    Build a stack of PointGatedBlock layers followed by a pooling/linear head.

    :param num_classes: output size of the final torch.nn.Linear
    :param num_radial: number of points in the ``radii`` grid given to each block
    :param max_radius: last value of the ``radii`` grid (grid starts at 0)
    """
    super(SE3Net, self).__init__()

    features = [(1,), (2, 2, 2, 1), (4, 4, 4, 4), (6, 4, 4, 0), (64,)]
    self.num_features = len(features)

    # shared by every gated block: radial grid plus a (relu, sigmoid) activation pair
    block_kwargs = dict(
        radii=torch.linspace(0, max_radius, steps=num_radial, dtype=torch.float64),
        activation=(torch.nn.functional.relu, torch.sigmoid),
    )
    self.layers = torch.nn.ModuleList([
        PointGatedBlock(feat_in, feat_out, **block_kwargs)
        for feat_in, feat_out in zip(features, features[1:])
    ])

    with torch_default_dtype(torch.float64):
        pooling = AvgSpacial()
        dropout = torch.nn.Dropout(p=0.2)
        classifier = torch.nn.Linear(64, num_classes)
        self.layers.extend([pooling, dropout, classifier])
def _test_change_basis_wigner_to_rot():
    """
    Check that lie_learn's l = 1 Wigner D matrix, conjugated by the (y, z, x)
    permutation, matches rot(a, b, c) to within 1e-10 for random angles.
    Prints the maximum absolute difference.
    """
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix

    with torch_default_dtype(torch.float64):
        perm = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float64)
        alpha, beta, gamma = torch.rand(3)

        wigner = torch.tensor(wigner_D_matrix(1, alpha, beta, gamma), dtype=torch.float64)
        from_wigner = perm.t() @ wigner @ perm
        from_rot = rot(alpha, beta, gamma)

        max_err = (from_wigner - from_rot).abs().max()
        print(max_err.item())
        assert max_err < 1e-10
def test_is_representation(rep):
    """
    Numerically check that ``rep`` is a homomorphism on random Euler angles:

        rep(Z(a1) Y(b1) Z(c1) Z(a2) Y(b2) Z(c2)) = rep(Z(a1) Y(b1) Z(c1)) rep(Z(a2) Y(b2) Z(c2))

    The composed rotation is converted back to Euler angles with compose().
    Prints the absolute error and the magnitude of the result, then asserts a
    relative error below 1e-10.

    :param rep: callable taking three angles and returning a matrix
    """
    with torch_default_dtype(torch.float64):
        a1, b1, c1, a2, b2, c2 = torch.rand(6)

        product_of_reps = rep(a1, b1, c1) @ rep(a2, b2, c2)
        rep_of_product = rep(*compose(a1, b1, c1, a2, b2, c2))

        err = (rep_of_product - product_of_reps).abs().max()
        scale = rep_of_product.abs().max()
        print(err.item(), scale.item())
        assert err < 1e-10 * scale, err / scale
def spherical_harmonics_xyz(order, xyz):
    """
    Spherical harmonics evaluated at Cartesian points.

    Points are converted to angles with x_to_alpha_beta and passed to
    spherical_harmonics. Points with zero norm have no well-defined angles, so
    their columns are overwritten: the J = 0 entry takes the value of
    spherical_harmonics(0, ...) at arbitrary angles, and every J > 0 entry is 0.

    :param order: int, or list/tuple of ints, giving the orders J to evaluate
    :param xyz: tensor of shape [A, 3]
    :return: tensor of shape [m, A], m = sum of (2 J + 1) over the orders
    """
    # accept a single order or any list/tuple of orders (tuples used to be
    # wrapped as one element, which broke the per-J loop below)
    if isinstance(order, (list, tuple)):
        order = list(order)
    else:
        order = [order]

    with torch_default_dtype(torch.float64):
        alpha, beta = x_to_alpha_beta(xyz)  # two tensors of shape [A]
        out = spherical_harmonics(order, alpha, beta)  # [m, A]

        # fix values when xyz = 0 (mask computed once and reused)
        is_origin = xyz.norm(2, -1) == 0
        if is_origin.any():  # this `if` is not needed with version 1.0 of pytorch
            val = torch.cat([
                spherical_harmonics(0, 123, 321) if J == 0 else torch.zeros(2 * J + 1)
                for J in order
            ])  # [m]
            # val is created on the CPU in float64; match out's device and dtype
            # so the masked assignment also works for e.g. CUDA inputs
            out[:, is_origin] = val.to(out).view(-1, 1)
        return out
def _test_spherical_harmonics(order):
    """
    Check that irr_repr, compose and spherical_harmonics agree:

        Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x)

    with x = Z(a) Y(b) eta for random a, b (eta being the reference direction)
    and random rotation angles. Prints the absolute error and the magnitude of
    Y(x), then asserts a relative error below 1e-10.

    :param order: order passed to spherical_harmonics and irr_repr
    """
    with torch_default_dtype(torch.float64):
        x_alpha, x_beta = torch.rand(2)
        alpha, beta, gamma = torch.rand(3)
        rot_alpha, rot_beta, _ = compose(alpha, beta, gamma, x_alpha, x_beta, 0)

        harmonics_of_rotated = spherical_harmonics(order, rot_alpha, rot_beta)
        harmonics = spherical_harmonics(order, x_alpha, x_beta)
        rotated_harmonics = irr_repr(order, alpha, beta, gamma) @ harmonics

        err = (harmonics_of_rotated - rotated_harmonics).abs().max()
        scale = harmonics.abs().max()
        print(err.item(), scale.item())
        assert err < 1e-10 * scale, err / scale