def setup_s2_local_ft(b, grid, cuda_device=None):
    """Build the real-valued Fourier matrix used by a local S2 transform.

    For every grid point, the conjugated Wigner-D matrices D^l (l = 0 .. b-1),
    scaled by 2b, are evaluated with gamma = 0. Only their middle column
    (index l, i.e. n = 0 under 'centered' ordering) is kept, and these
    columns are concatenated into one row.

    :param b: bandwidth; degrees l = 0 .. b-1 are used.
    :param grid: indexable sequence of points; point[0] is passed to
        wigner_D_matrix as alpha and point[1] as beta.
    :param cuda_device: if not None, the tensor is moved to this CUDA device.
    :return: float32 torch tensor of shape (len(grid), 2 * sum_l (2l + 1)).
        Each consecutive pair of columns holds the (real, imag) parts of one
        complex coefficient.
    """
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix

    # TODO: optionally get quadrature weights for the chosen grid and use them
    # to weigh the D matrices below. This is optional because the filter
    # coefficients can be viewed as having absorbed the weights already.
    # NOTE(review): point[0] is used as alpha and point[1] as beta here;
    # confirm this matches the ordering of the grids passed in by callers.

    degrees = range(b)
    n_spectral = np.sum([(2 * l + 1) for l in degrees])
    fourier = np.zeros((len(grid), n_spectral), dtype=complex)

    for row, point in enumerate(grid):
        alpha, beta = point[0], point[1]
        columns = []
        for l in degrees:
            scaled = (2 * b) * wigner_D_matrix(l, alpha, beta, 0,
                                               field='complex', normalization='quantum',
                                               order='centered', condon_shortley='cs').conj()
            columns.append(scaled[:, l])
        fourier[row] = np.hstack(columns)

    # Viewing the complex (n_spatial, n_spectral) matrix as float yields a real
    # (n_spatial, 2 * n_spectral) matrix. A real batch x of shape (..., n_spatial)
    # multiplied as x @ F gives (..., 2 * n_spectral), readable as complex values.
    result = torch.from_numpy(fourier.view('float').astype(np.float32))
    if cuda_device is not None:
        result = result.cuda(cuda_device)
    return result
def __setup_so3_ft(b, grid):
    """Sample the conjugated Wigner-D matrices on an SO(3) grid.

    :param b: bandwidth; degrees l = 0 .. b-1 are used.
    :param grid: sized iterable of (beta, alpha, gamma) triples.
    :return: float64 numpy array of shape (len(grid), n_spectral, 2), where
        n_spectral = sum_l (2l + 1)^2. For each point, the conjugated D^l
        matrices are flattened (row-major) and concatenated by increasing l.
        The last axis holds (real, imag).
    """
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix

    # Note: quadrature weights for the grid could optionally be applied to the
    # D matrices. This is optional because the filter coefficients can be
    # viewed as having absorbed them. The weights depend on the spacing between
    # the points of the grid; only the sin(beta) factor could be added without
    # knowing that spacing.

    n_spectral = np.sum([(2 * l + 1)**2 for l in range(b)])
    fourier = np.zeros((len(grid), n_spectral), dtype=complex)

    for row, (beta, alpha, gamma) in enumerate(grid):
        flat_blocks = []
        for l in range(b):
            D = wigner_D_matrix(l, alpha, beta, gamma,
                                field='complex', normalization='quantum',
                                order='centered', condon_shortley='cs').conj()
            flat_blocks.append(D.flatten())
        fourier[row] = np.hstack(flat_blocks)

    # Reinterpret each complex entry as a (real, imag) float pair.
    return fourier.view('float').reshape((-1, n_spectral, 2))
def _setup_s2_ft(b, grid, device_type, device_index):
    """Build the S2 Fourier matrix on a grid as a float32 torch tensor.

    For each grid point (beta, alpha), the conjugated Wigner-D matrices
    D^l(alpha, beta, 0), l = 0 .. b-1, are scaled by 2b. Their middle column
    (index l) is kept, and these columns are concatenated.

    :param b: bandwidth.
    :param grid: sized iterable of (beta, alpha) pairs.
    :param device_type: passed to torch.device for the result.
    :param device_index: passed to torch.device for the result.
    :return: float32 tensor of shape (len(grid), n_spectral, 2), with
        n_spectral = sum_l (2l + 1). The last axis holds (real, imag).
    """
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix

    # Note: quadrature weights for the grid could optionally be applied to the
    # D matrices. This is optional because the filter coefficients can be
    # viewed as having absorbed them.

    n_spectral = np.sum([(2 * l + 1) for l in range(b)])
    fourier = np.zeros((len(grid), n_spectral), dtype=complex)

    for row, (beta, alpha) in enumerate(grid):
        columns = [
            ((2 * b) * wigner_D_matrix(l, alpha, beta, 0,
                                       field='complex', normalization='quantum',
                                       order='centered', condon_shortley='cs').conj())[:, l]
            for l in range(b)
        ]
        fourier[row] = np.hstack(columns)

    # Complex (n_spatial, n_spectral) -> float pairs (n_spatial, n_spectral, 2)
    pairs = fourier.view('float').reshape((-1, n_spectral, 2))
    target = torch.device(device_type, device_index)
    return torch.tensor(pairs.astype(np.float32), dtype=torch.float32, device=target)  # pylint: disable=E1102
def _setup_so3_rotation(b, alpha, beta, gamma, device_type, device_index):
    """Return the Wigner-D matrices of one rotation as float32 torch tensors.

    :param b: bandwidth; one matrix per degree l = 0 .. b-1.
    :param alpha: Euler angle passed to wigner_D_matrix.
    :param beta: Euler angle passed to wigner_D_matrix.
    :param gamma: Euler angle passed to wigner_D_matrix.
    :param device_type: passed to torch.device for every tensor.
    :param device_index: passed to torch.device for every tensor.
    :return: list whose l-th element has shape (2l + 1, 2l + 1, 2). The last
        axis holds the (real, imag) parts of U[m, n] = exp(i m alpha)
        d^l_mn(beta) exp(i n gamma).
    """
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix

    rotations = []
    for l in range(b):
        side = 2 * l + 1
        U = wigner_D_matrix(l, alpha, beta, gamma,
                            field='complex', normalization='quantum',
                            order='centered', condon_shortley='cs')
        # complex64 viewed as float32 interleaves (real, imag) along the last axis
        pairs = U.astype(np.complex64).view(np.float32).reshape((side, side, 2))
        rotations.append(
            torch.tensor(pairs, dtype=torch.float32,
                         device=torch.device(device_type, device_index)))  # pylint: disable=E1102
    return rotations
def __setup_so3_rotation(b, alpha, beta, gamma):
    """Return the Wigner-D matrices of one rotation as float32 numpy arrays.

    :param b: bandwidth; one matrix per degree l = 0 .. b-1.
    :param alpha: Euler angle passed to wigner_D_matrix.
    :param beta: Euler angle passed to wigner_D_matrix.
    :param gamma: Euler angle passed to wigner_D_matrix.
    :return: list whose l-th element has shape (2l + 1, 2l + 1, 2). The last
        axis holds the (real, imag) parts of U[m, n] = exp(i m alpha)
        d^l_mn(beta) exp(i n gamma).
    """
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix

    rotations = []
    for l in range(b):
        side = 2 * l + 1
        U = wigner_D_matrix(l, alpha, beta, gamma,
                            field='complex', normalization='quantum',
                            order='centered', condon_shortley='cs')
        rotations.append(U.astype(np.complex64).view(np.float32).reshape((side, side, 2)))
    return rotations
def irr_repr(order, alpha, beta, gamma, dtype=None):
    """
    Irreducible representation of SO(3) of the given order at Euler angles
    (alpha, beta, gamma), as a torch tensor. Compatible with compose and
    spherical_harmonics.

    :param dtype: tensor dtype; defaults to torch.get_default_dtype().
    """
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix

    # NOTE: no change of basis is applied. For order 1 the result is in
    # lie_learn's native basis, not the (x, y, z) component ordering.
    D = wigner_D_matrix(order, alpha, beta, gamma)
    if dtype is None:
        dtype = torch.get_default_dtype()
    return torch.tensor(D, dtype=dtype)
def _test_change_basis_wigner_to_rot():
    """Check that the order-1 Wigner-D matrix, conjugated by a fixed
    permutation, matches rot() for random Euler angles (float64)."""
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix

    with torch_default_dtype(torch.float64):
        # Permutation P such that P^T D^1 P == rot(alpha, beta, gamma).
        perm = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float64)
        alpha, beta, gamma = torch.rand(3)

        wigner = torch.tensor(wigner_D_matrix(1, alpha, beta, gamma), dtype=torch.float64)
        from_wigner = perm.t() @ wigner @ perm
        from_rot = rot(alpha, beta, gamma)

        max_err = (from_wigner - from_rot).abs().max()
        print(max_err.item())
        assert max_err < 1e-10
def irr_repr(order, alpha, beta, gamma, dtype=None, device=None):
    """
    Irreducible representation of SO(3) of the given order at Euler angles
    (alpha, beta, gamma). Compatible with compose and spherical_harmonics.

    Angles may be Python numbers or single-element torch tensors. Tensor
    angles are converted with .item() before being passed to wigner_D_matrix.

    :param dtype: result dtype. If None, it is taken from the first
        floating-point tensor angle; failing that, torch.get_default_dtype().
        Integer tensor angles are deliberately ignored here, because casting
        the matrix to an integer dtype would truncate its entries.
    :param device: result device. If None, it is taken from the first tensor
        angle; failing that, torch's default device.
    :return: tensor of shape (2 * order + 1, 2 * order + 1).
    """
    angles = []
    for x in (alpha, beta, gamma):
        if torch.is_tensor(x):
            # Only float tensors may set the dtype: e.g. torch.tensor(0) would
            # otherwise make the whole rotation matrix int64.
            if dtype is None and x.is_floating_point():
                dtype = x.dtype
            if device is None:
                device = x.device
            x = x.item()
        angles.append(x)

    if dtype is None:
        dtype = torch.get_default_dtype()
    return torch.tensor(wigner_D_matrix(order, *angles), dtype=dtype, device=device)
def setup_so3_rotation(b, alpha, beta, gamma, cuda_device=None):
    """Return the Wigner-D matrices of one rotation as float32 torch tensors.

    :param b: bandwidth; one matrix per degree l = 0 .. b-1.
    :param alpha: Euler angle passed to wigner_D_matrix.
    :param beta: Euler angle passed to wigner_D_matrix.
    :param gamma: Euler angle passed to wigner_D_matrix.
    :param cuda_device: if not None, each tensor is moved to this CUDA device.
    :return: list whose l-th element has shape (2l + 1, 2l + 1, 2). The last
        axis holds the (real, imag) parts of U[m, n] = exp(i m alpha)
        d^l_mn(beta) exp(i n gamma).
    """
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix

    rotations = []
    for l in range(b):
        side = 2 * l + 1
        U = wigner_D_matrix(l, alpha, beta, gamma,
                            field='complex', normalization='quantum',
                            order='centered', condon_shortley='cs')
        tensor = torch.from_numpy(U.astype(np.complex64).view(np.float32).reshape((side, side, 2)))
        if cuda_device is not None:
            tensor = tensor.cuda(cuda_device)
        rotations.append(tensor)
    return rotations
def __init__(self, L_max, field='complex', normalization='quantum', order='centered', condon_shortley='cs'):
    """Precompute Wigner-D samples on the SOFT grid for degrees 0 .. L_max.

    With b = L_max + 1, this sets:
      self.D: list; self.D[l] has shape (2b, 2b, 2b, 2l + 1, 2l + 1) and holds
          D^l at (alpha_j1, beta_k, gamma_j2), where alpha_j1 = 2 pi j1 / (2b),
          beta_k = pi (2k + 1) / (4b) and gamma_j2 = 2 pi j2 / (2b).
      self.w: S3.quadrature_weights(b=b, grid_type='SOFT').
      self.F: ((2b)^3, sum_l (2l + 1)^2) matrix stacking the flattened self.D[l].
      self.l_weights: for every column of self.F, 2l + 1 of its degree l.
    """
    super().__init__()
    # TODO allow user to specify the grid (now using SOFT implicitly)

    b = L_max + 1
    n_side = 2 * b
    entry_dtype = complex if field == 'complex' else float

    # Explicitly construct the Wigner-D matrices at every SOFT grid point.
    self.D = []
    for l in range(b):
        dim = 2 * l + 1
        D_l = np.zeros((n_side, n_side, n_side, dim, dim), dtype=entry_dtype)
        self.D.append(D_l)
        for j1 in range(n_side):
            alpha = 2 * np.pi * j1 / (2. * b)
            for k in range(n_side):
                beta = np.pi * (2 * k + 1) / (4. * b)
                for j2 in range(n_side):
                    gamma = 2 * np.pi * j2 / (2. * b)
                    D_l[j1, k, j2, :, :] = wigner_D_matrix(l, alpha, beta, gamma, field,
                                                           normalization, order, condon_shortley)

    # Compute quadrature weights
    self.w = S3.quadrature_weights(b=b, grid_type='SOFT')

    # Stack D into one Fourier matrix. Rows are the (2b)^3 spatial samples;
    # columns are all spectral coefficients D^l[m, n], flattened in order of
    # increasing l (normally stored as (2l+1) x (2l+1) matrices).
    n_spatial = n_side ** 3
    self.F = np.hstack([D_l.reshape(n_spatial, D_l.shape[-1] ** 2) for D_l in self.D])

    # For the IFFT / synthesis transform, the order-l Fourier coefficients are
    # weighted by (2l + 1). The degree sequence is 1 * (0,) + 9 * (1,) + 25 * (2,) + ...
    degree_per_coeff = np.array([l for l in range(b) for _ in range((2 * l + 1) ** 2)])
    self.l_weights = 2 * degree_per_coeff + 1