def __init__(self, repr_in, repr_out, radii, activation=(None, None)):
    """
    Build a point convolution with optional scalar and gated activations.

    :param repr_in: input multiplicities, indexed by order l (entry l is the
        multiplicity of order-l fields)
    :param repr_out: output multiplicities, indexed by order l
    :param radii: forwarded unchanged to SE3PointConvolution
    :param activation: either a pair ``(scalar_activation, gate_activation)``
        or a single activation used for both roles. ``None`` in either role
        disables that activation.
    """
    super().__init__()

    # A tuple (including tuple subclasses such as namedtuples) supplies the two
    # activations separately; anything else is shared by both roles.
    if isinstance(activation, tuple):
        scalar_activation, gate_activation = activation
    else:
        scalar_activation = gate_activation = activation

    self.repr_out = repr_out

    # Convert multiplicity lists into (multiplicity, order) pairs.
    Rs_in = [(m, l) for l, m in enumerate(repr_in)]
    Rs_out_with_gate = [(m, l) for l, m in enumerate(repr_out)]

    # Activation applied to the order-0 (scalar) output fields, if any exist.
    if scalar_activation is not None and repr_out[0] > 0:
        self.scalar_act = ScalarActivation(
            [(repr_out[0], scalar_activation)], bias=False)
    else:
        self.scalar_act = None

    # One extra order-0 output channel per non-scalar field, passed through the
    # gate activation; the convolution must produce these extra channels.
    num_non_scalar = sum(repr_out[1:])
    if gate_activation is not None and num_non_scalar > 0:
        Rs_out_with_gate.append((num_non_scalar, 0))
        self.gate_act = ScalarActivation(
            [(num_non_scalar, gate_activation)], bias=False)
    else:
        self.gate_act = None

    # The convolution is constructed with float64 as the default dtype
    # (presumably so that anything it builds at construction time is computed
    # in double precision -- confirm against SE3PointConvolution).
    with torch_default_dtype(torch.float64):
        self.conv = SE3PointConvolution(Rs_in, Rs_out_with_gate, radii=radii)
def _sample_sh_cube(size, J, version=3):  # pylint: disable=W0613
    """
    Sample the order-J spherical harmonics at every voxel of a cube.

    No band-limiting is applied, so aliased regions must be removed
    afterwards by windowing.

    :param size: side length of the cube (number of samples per axis)
    :param J: order of the spherical harmonics
    :param version: unused by this function
    :return: float64 tensor of shape [2J+1, size, size, size], indexed [m, x, y, z]
    """
    with torch_default_dtype(torch.float64):
        half_extent = (size - 1) / 2
        # Sample coordinates along one axis, symmetric about 0.
        coords = torch.linspace(-half_extent, half_extent, steps=size)

        samples = torch.zeros(2 * J + 1, size, size, size, dtype=torch.float64)
        for i, x in enumerate(coords):
            for j, y in enumerate(coords):
                for k, z in enumerate(coords):
                    at_origin = x == y == z == 0
                    if not at_origin:
                        alpha, beta = x_to_alpha_beta([x, y, z])
                        samples[:, i, j, k] = spherical_harmonics(J, alpha, beta)  # [m]
                        continue
                    # The angles at the origin are nan, so it is handled apart:
                    # Y^0 does not depend on the angle (any angle will do),
                    # while Y^J for J != 0 is set to zero there.
                    samples[:, i, j, k] = spherical_harmonics(0, 123, 321) if J == 0 else 0

    assert samples.dtype == torch.float64
    return samples  # [m, x, y, z]
def _basis_transformation_Q_J(J, order_in, order_out, version=3):  # pylint: disable=W0613
    """
    Compute the order-J block of the change-of-basis matrix (one part of the
    Q^-1 matrix of the article).

    Q_J is obtained as the null space of a stacked Sylvester-type system and is
    checked to satisfy D_tensor(g) @ Q_J == Q_J @ D^J(g), where
    D_tensor = kron(D^{order_out}, D^{order_in}).

    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param version: unused by this function
    :return: float64 tensor of shape [m_out * m_in, m]
    """
    dim_in = 2 * order_in + 1
    dim_out = 2 * order_out + 1
    dim_J = 2 * J + 1

    with torch_default_dtype(torch.float64):
        def tensor_product_rep(a, b, c):
            # kron(D^{order_out}, D^{order_in}): [m_out * m_in, m_out * m_in]
            return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))

        def sylvester_block(a, b, c):
            # Operator whose null space holds the (flattened) Q satisfying
            # D_tensor Q = Q D^J at this one set of angles.
            rep_tensor = tensor_product_rep(a, b, c)  # [m_out * m_in, m_out * m_in]
            rep_J = irr_repr(J, a, b, c)  # [m, m]
            left = kron(rep_tensor, torch.eye(rep_J.size(0)))
            right = kron(torch.eye(rep_tensor.size(0)), rep_J.t())
            return left - right  # [(m_out * m_in) * m, (m_out * m_in) * m]

        # Fixed angles; the null space is taken jointly over all of them.
        fixed_angles = (
            (4.41301023, 5.56684102, 4.59384642),
            (4.93325116, 6.12697327, 4.14574096),
            (0.53878964, 4.09050444, 5.36539036),
            (2.16017393, 3.48835314, 5.55174441),
            (2.52385107, 0.2908958, 3.90040975),
        )
        kernel = get_matrices_kernel([sylvester_block(a, b, c) for a, b, c in fixed_angles])
        assert kernel.size(0) == 1, kernel.size()  # unique subspace solution

        # [(m_out * m_in) * m] -> [m_out * m_in, m]
        Q_J = kernel[0].view(dim_out * dim_in, dim_J)

        def intertwines(angles):
            a, b, c = (angle.item() for angle in angles)
            return torch.allclose(tensor_product_rep(a, b, c) @ Q_J,
                                  Q_J @ irr_repr(J, a, b, c))

        # Sanity check on fresh random angles.
        assert all(intertwines(angles) for angles in torch.rand(4, 3))

    assert Q_J.dtype == torch.float64
    return Q_J  # [m_out * m_in, m]
def _test_basis_equivariance():
    """
    Check numerically that cube_basis_kernels(20, 2, 2, ...) passes
    check_basis_equivariance.

    The check uses three random values in [0, 1) (presumably rotation angles)
    and requires every returned overlap to exceed 0.98.
    """
    from functools import partial

    with torch_default_dtype(torch.float64):
        window = partial(gaussian_window, radii=[5], J_max_list=[999], sigma=2)
        basis = cube_basis_kernels(20, 2, 2, window)
        a, b, c = torch.rand(3)
        overlaps = check_basis_equivariance(basis, 2, 2, a, b, c)
        assert overlaps.gt(0.98).all(), overlaps