# Example #1 (0 votes)
# File: s2grid.py — Project: wendazhou/e3nn
    def __init__(self,
                 res=None,
                 lmax=None,
                 normalization='component',
                 lmax_in=None):
        """
        Precompute the grid points and transform matrices that map a signal
        sampled on an S2 grid back to spherical-harmonics coefficients
        (the inverse of ToS2Grid).

        :param res: resolution of the input as a tuple (beta resolution, alpha resolution)
        :param lmax: maximum l of the output
        :param normalization: either 'norm', 'component', 'none' or custom
            (a tensor of per-l scale factors)
        :param lmax_in: maximum l of the input of ToS2Grid in order to be the inverse
        """
        super().__init__()

        # Check the tensor case first: evaluating `normalization in [...]`
        # on a Tensor would compare Tensor == str, which can raise a
        # TypeError in some torch versions.
        assert torch.is_tensor(normalization) or normalization in [
            'norm', 'component', 'none'
        ], "normalization needs to be 'norm', 'component' or 'none'"

        # `res` may be a single int (square grid), a (beta, alpha) tuple, or
        # None; complete_lmax_res fills in whatever is missing from lmax.
        if isinstance(res, int) or res is None:
            lmax, res_beta, res_alpha = complete_lmax_res(lmax, res, None)
        else:
            lmax, res_beta, res_alpha = complete_lmax_res(lmax, *res)

        if lmax_in is None:
            lmax_in = lmax

        betas, alphas, shb, sha = spherical_harmonics_s2_grid(
            lmax, res_beta, res_alpha)

        with torch_default_dtype(torch.float64):
            # normalize such that it is the inverse of ToS2Grid
            # (tensor branch first, for the same reason as the assert above)
            if torch.is_tensor(normalization):
                n = normalization.to(dtype=torch.float64)
            elif normalization == 'component':
                n = math.sqrt(4 * math.pi) * torch.tensor(
                    [math.sqrt(2 * l + 1)
                     for l in range(lmax + 1)]) * math.sqrt(lmax_in + 1)
            elif normalization == 'norm':
                n = math.sqrt(
                    4 * math.pi) * torch.ones(lmax + 1) * math.sqrt(lmax_in +
                                                                    1)
            elif normalization == 'none':
                n = 4 * math.pi * torch.ones(lmax + 1)
            m = rsh.spherical_harmonics_expand_matrix(range(lmax +
                                                            1))  # [l, m, i]
            # the quadrature weights are defined for res_beta // 2 nodes,
            # so the beta resolution must be even
            assert res_beta % 2 == 0
            qw = torch.tensor(S3.quadrature_weights(
                res_beta // 2)) * res_beta**2 / res_alpha  # [b]
        # fold the normalization and the quadrature weights into shb
        shb = torch.einsum('lmj,bj,lmi,l,b->mbi', m, shb, m, n,
                           qw)  # [m, b, i]

        self.register_buffer('alphas', alphas)
        self.register_buffer('betas', betas)
        self.register_buffer('sha', sha)
        self.register_buffer('shb', shb)
        self.to(torch.get_default_dtype())
# Example #2 (0 votes)
# File: s2grid.py — Project: wendazhou/e3nn
    def __init__(self, lmax=None, res=None, normalization='component'):
        """
        Precompute the grid points and transform matrices that evaluate a
        spherical-harmonics signal on an S2 grid.

        :param lmax: lmax of the input signal
        :param res: resolution of the output as a tuple (beta resolution, alpha resolution)
        :param normalization: either 'norm', 'component', 'none' or custom
            (a tensor of per-l scale factors)
        """
        super().__init__()

        # Check the tensor case first: evaluating `normalization in [...]`
        # on a Tensor would compare Tensor == str, which can raise a
        # TypeError in some torch versions.
        assert torch.is_tensor(normalization) or normalization in [
            'norm', 'component', 'none'
        ], "normalization needs to be 'norm', 'component' or 'none'"

        # `res` may be a single int (square grid), a (beta, alpha) tuple, or
        # None; complete_lmax_res fills in whatever is missing from lmax.
        if isinstance(res, int) or res is None:
            lmax, res_beta, res_alpha = complete_lmax_res(lmax, res, None)
        else:
            lmax, res_beta, res_alpha = complete_lmax_res(lmax, *res)

        betas, alphas, shb, sha = spherical_harmonics_s2_grid(
            lmax, res_beta, res_alpha)

        with torch_default_dtype(torch.float64):
            # (tensor branch first, for the same reason as the assert above)
            if torch.is_tensor(normalization):
                n = normalization.to(dtype=torch.float64)
            elif normalization == 'component':
                # normalize such that all l has the same variance on the sphere
                # given that all componant has mean 0 and variance 1
                n = math.sqrt(4 * math.pi) * torch.tensor(
                    [1 / math.sqrt(2 * l + 1)
                     for l in range(lmax + 1)]) / math.sqrt(lmax + 1)
            elif normalization == 'norm':
                # normalize such that all l has the same variance on the sphere
                # given that all componant has mean 0 and variance 1/(2L+1)
                n = math.sqrt(
                    4 * math.pi) * torch.ones(lmax + 1) / math.sqrt(lmax + 1)
            elif normalization == 'none':
                n = torch.ones(lmax + 1)
            m = rsh.spherical_harmonics_expand_matrix(range(lmax +
                                                            1))  # [l, m, i]
        # fold the per-l normalization into shb
        shb = torch.einsum('lmj,bj,lmi,l->mbi', m, shb, m, n)  # [m, b, i]

        self.register_buffer('alphas', alphas)
        self.register_buffer('betas', betas)
        self.register_buffer('sha', sha)
        self.register_buffer('shb', shb)
        self.to(torch.get_default_dtype())
# Example #3 (0 votes)
    def __init__(self, lmax, res=None, normalization='component'):
        """
        Precompute the grid points and transform matrices that evaluate a
        spherical-harmonics signal on an S2 grid.

        :param lmax: lmax of the input signal
        :param res: resolution of the output as a tuple (beta resolution, alpha resolution)
        :param normalization: either 'norm' or 'component'
        """
        super().__init__()

        assert normalization in [
            'norm', 'component'
        ], "normalization needs to be 'norm' or 'component'"

        # Resolve the grid resolution: default from lmax, a single int means
        # a square grid, otherwise expect a (beta, alpha) pair.
        if res is None:
            res_beta = 2 * (lmax + 1)
            res_alpha = 2 * res_beta
        elif isinstance(res, int):
            res_beta = res
            res_alpha = res
        else:
            res_beta, res_alpha = res
        del res
        # quadrature constraints on the beta resolution
        assert res_beta % 2 == 0
        assert res_beta >= 2 * (lmax + 1)

        # NOTE(review): this example passes (lmax, res_alpha, res_beta) and
        # unpacks (alphas, betas, sha, shb) — a different revision of the
        # helper than examples #1/#2; verify against the project version.
        alphas, betas, sha, shb = spherical_harmonics_s2_grid(
            lmax, res_alpha, res_beta)

        with torch_default_dtype(torch.float64):
            # normalize such that all l has the same variance on the sphere
            if normalization == 'component':
                per_l = [1 / math.sqrt(2 * l + 1) for l in range(lmax + 1)]
                n = math.sqrt(4 * math.pi) * torch.tensor(
                    per_l) / math.sqrt(lmax + 1)
            if normalization == 'norm':
                n = math.sqrt(
                    4 * math.pi) * torch.ones(lmax + 1) / math.sqrt(lmax + 1)
            m = rsh.spherical_harmonics_expand_matrix(lmax)  # [l, m, i]
        # fold the per-l normalization into shb  [m, b, i]
        shb = torch.einsum('lmj,bj,lmi,l->mbi', m, shb, m, n)

        self.register_buffer('alphas', alphas)
        self.register_buffer('betas', betas)
        self.register_buffer('sha', sha)
        self.register_buffer('shb', shb)
        self.to(torch.get_default_dtype())