def __init__(self, res=None, lmax=None, normalization='component', lmax_in=None):
    """Build the buffers that project a signal sampled on an S2 grid back to
    spherical-harmonic coefficients (the inverse of ToS2Grid).

    :param res: resolution of the input as a tuple (beta resolution, alpha resolution);
        an ``int`` (or ``None``) is passed to ``complete_lmax_res`` in the beta slot
        with the alpha resolution left as ``None``
    :param lmax: maximum l of the output
    :param normalization: either 'norm', 'component', 'none' or a custom tensor of
        per-l scale factors
    :param lmax_in: maximum l of the input of ToS2Grid in order to be the inverse;
        defaults to the resolved ``lmax``
    """
    super().__init__()

    # Test for a tensor first so a custom normalization is never compared
    # against the string options.
    assert torch.is_tensor(normalization) or normalization in [
        'norm', 'component', 'none'
    ], "normalization needs to be 'norm', 'component' or 'none'"

    if isinstance(res, int) or res is None:
        lmax, res_beta, res_alpha = complete_lmax_res(lmax, res, None)
    else:
        lmax, res_beta, res_alpha = complete_lmax_res(lmax, *res)

    # The quadrature weights below are built from res_beta // 2, so an odd
    # res_beta is rejected before doing any grid evaluation.
    assert res_beta % 2 == 0

    if lmax_in is None:
        lmax_in = lmax

    betas, alphas, shb, sha = spherical_harmonics_s2_grid(lmax, res_beta, res_alpha)

    with torch_default_dtype(torch.float64):
        # normalize such that it is the inverse of ToS2Grid
        if torch.is_tensor(normalization):
            # Custom per-l factors. Moved next to the grid tensors so the einsum
            # below does not mix devices (assumes m and shb share a device).
            n = normalization.to(dtype=torch.float64, device=shb.device)
        elif normalization == 'component':
            n = math.sqrt(4 * math.pi) * torch.tensor(
                [math.sqrt(2 * l + 1) for l in range(lmax + 1)]) * math.sqrt(lmax_in + 1)
        elif normalization == 'norm':
            n = math.sqrt(4 * math.pi) * torch.ones(lmax + 1) * math.sqrt(lmax_in + 1)
        elif normalization == 'none':
            n = 4 * math.pi * torch.ones(lmax + 1)
        else:
            # Only reachable when asserts are stripped (python -O).
            raise ValueError("normalization needs to be 'norm', 'component' or 'none'")

        m = rsh.spherical_harmonics_expand_matrix(range(lmax + 1))  # [l, m, i]
        qw = torch.tensor(S3.quadrature_weights(res_beta // 2)) * res_beta**2 / res_alpha  # [b]
        # Fold expansion matrix, beta harmonics, per-l scale and quadrature
        # weights into a single [m, b, i] tensor.
        shb = torch.einsum('lmj,bj,lmi,l,b->mbi', m, shb, m, n, qw)  # [m, b, i]

    self.register_buffer('alphas', alphas)
    self.register_buffer('betas', betas)
    self.register_buffer('sha', sha)
    self.register_buffer('shb', shb)
    # Buffers were computed in float64; cast them to the current default dtype.
    self.to(torch.get_default_dtype())
def __init__(self, lmax=None, res=None, normalization='component'):
    """Build the buffers that map a signal of maximum degree ``lmax`` onto an
    output S2 grid.

    :param lmax: lmax of the input signal
    :param res: resolution of the output as a tuple (beta resolution, alpha resolution);
        an ``int`` (or ``None``) is passed to ``complete_lmax_res`` in the beta slot
        with the alpha resolution left as ``None``
    :param normalization: either 'norm', 'component', 'none' or a custom tensor of
        per-l scale factors
    """
    super().__init__()

    # Test for a tensor first so a custom normalization is never compared
    # against the string options.
    assert torch.is_tensor(normalization) or normalization in [
        'norm', 'component', 'none'
    ], "normalization needs to be 'norm', 'component' or 'none'"

    if isinstance(res, int) or res is None:
        lmax, res_beta, res_alpha = complete_lmax_res(lmax, res, None)
    else:
        lmax, res_beta, res_alpha = complete_lmax_res(lmax, *res)

    betas, alphas, shb, sha = spherical_harmonics_s2_grid(lmax, res_beta, res_alpha)

    with torch_default_dtype(torch.float64):
        if torch.is_tensor(normalization):
            # Custom per-l factors. Moved next to the grid tensors so the einsum
            # below does not mix devices (assumes m and shb share a device).
            n = normalization.to(dtype=torch.float64, device=shb.device)
        elif normalization == 'component':
            # normalize such that all l have the same variance on the sphere,
            # assuming every component has mean 0 and variance 1
            n = math.sqrt(4 * math.pi) * torch.tensor(
                [1 / math.sqrt(2 * l + 1) for l in range(lmax + 1)]) / math.sqrt(lmax + 1)
        elif normalization == 'norm':
            # normalize such that all l have the same variance on the sphere,
            # assuming every component has mean 0 and variance 1/(2l+1)
            n = math.sqrt(4 * math.pi) * torch.ones(lmax + 1) / math.sqrt(lmax + 1)
        elif normalization == 'none':
            n = torch.ones(lmax + 1)
        else:
            # Only reachable when asserts are stripped (python -O).
            raise ValueError("normalization needs to be 'norm', 'component' or 'none'")

        m = rsh.spherical_harmonics_expand_matrix(range(lmax + 1))  # [l, m, i]
        # Fold expansion matrix, beta harmonics and per-l scale into [m, b, i].
        shb = torch.einsum('lmj,bj,lmi,l->mbi', m, shb, m, n)  # [m, b, i]

    self.register_buffer('alphas', alphas)
    self.register_buffer('betas', betas)
    self.register_buffer('sha', sha)
    self.register_buffer('shb', shb)
    # Buffers were computed in float64; cast them to the current default dtype.
    self.to(torch.get_default_dtype())
def __init__(self, lmax, res=None, normalization='component'):
    """Precompute the S2 grid and spherical-harmonic buffers for degree ``lmax``.

    :param lmax: lmax of the input signal
    :param res: resolution of the output as a tuple (beta resolution, alpha resolution);
        a single ``int`` is used for both, and ``None`` selects
        ``res_beta = 2 * (lmax + 1)`` with ``res_alpha = 2 * res_beta``
    :param normalization: either 'norm' or 'component'
    """
    super().__init__()

    assert normalization in ('norm', 'component'), "normalization needs to be 'norm' or 'component'"

    # Resolve the (beta, alpha) resolution of the output grid.
    if res is None:
        res_beta = 2 * (lmax + 1)
        res_alpha = 2 * res_beta
    elif isinstance(res, int):
        res_beta = res_alpha = res
    else:
        res_beta, res_alpha = res
    del res

    assert res_beta % 2 == 0
    assert res_beta >= 2 * (lmax + 1)

    grid_alpha, grid_beta, sh_alpha, sh_beta = spherical_harmonics_s2_grid(lmax, res_alpha, res_beta)

    with torch_default_dtype(torch.float64):
        # Per-degree factor chosen so that every l has the same variance on the sphere.
        if normalization == 'component':
            per_l = torch.tensor([1 / math.sqrt(2 * l + 1) for l in range(lmax + 1)])
        elif normalization == 'norm':
            per_l = torch.ones(lmax + 1)
        scale = math.sqrt(4 * math.pi) * per_l / math.sqrt(lmax + 1)

        expand = rsh.spherical_harmonics_expand_matrix(lmax)  # [l, m, i]
        sh_beta = torch.einsum('lmj,bj,lmi,l->mbi', expand, sh_beta, expand, scale)  # [m, b, i]

    # Registration order is kept: alphas, betas, sha, shb.
    for buffer_name, tensor in (
        ('alphas', grid_alpha),
        ('betas', grid_beta),
        ('sha', sh_alpha),
        ('shb', sh_beta),
    ):
        self.register_buffer(buffer_name, tensor)
    # Buffers were built in float64; cast them to the current default dtype.
    self.to(torch.get_default_dtype())