예제 #1
0
파일: s2.py 프로젝트: zizai/e3nn
    def __init__(self,
                 Rs,
                 act,
                 res,
                 normalization='component',
                 lmax_out=None,
                 random_rot=False):
        '''
        Map to the sphere, apply the nonlinearity pointwise and project back.

        The signal on the sphere is a quasiregular representation of O3,
        so a pointwise operation can be applied to it.

        :param Rs: input representation of the form [(1, l, p0 * u^l) for l in [0, ..., lmax]]
        :param act: activation function; when p0 == -1 it is also evaluated on a
                    1d tensor to decide whether it is even or odd
        :param res: resolution of the grid on the sphere (the higher the more accurate)
        :param normalization: either 'norm' or 'component'
        :param lmax_out: maximum l of the output (defaults to the lmax of ``Rs``)
        :param random_rot: rotate randomly the grid
        :raises AssertionError: if ``Rs`` is not of the form described above
        :raises ValueError: if p0 == -1 and ``act`` is neither even nor odd,
                            or if p0 is not one of -1, 0, 1
        '''
        super().__init__()

        Rs = rs.simplify(Rs)
        _, _, p0 = Rs[0]
        _, lmax, _ = Rs[-1]

        # Explicit raises instead of bare ``assert`` so the validation still runs
        # under ``python -O``. AssertionError is kept for backward compatibility.
        if not all(mul == 1 for mul, _, _ in Rs):
            raise AssertionError("every irrep of the input must have multiplicity 1")
        if [l for _, l, _ in Rs] != list(range(lmax + 1)):
            raise AssertionError("the input must contain each l = 0, ..., lmax exactly once")

        if all(p == p0 for _, _, p in Rs):
            u = 1
        elif all(p == p0 * (-1)**l for _, l, p in Rs):
            u = -1
        else:
            raise AssertionError("the parity of the input is not well defined")
        self.Rs_in = Rs
        # the input transforms as : A_l ---> p0 * u^l * A_l
        # the sphere signal transforms as : f(r) ---> p0 * f(u * r)
        if lmax_out is None:
            lmax_out = lmax

        if p0 in (0, 1):
            self.Rs_out = [(1, l, p0 * u**l) for l in range(lmax_out + 1)]
        elif p0 == -1:
            # act(-f(u * r)) = p_act * act(f(u * r)), so the output parity is p_act * u^l
            p_act = self._activation_parity(act)
            if p_act == 0:
                raise ValueError("warning! the parity is violated")
            self.Rs_out = [(1, l, p_act * u**l) for l in range(lmax_out + 1)]
        else:
            raise ValueError(f"unexpected parity p0={p0}, expected -1, 0 or 1")

        self.to_s2 = s2grid.ToS2Grid(lmax, res, normalization=normalization)
        self.from_s2 = s2grid.FromS2Grid(res,
                                         lmax_out,
                                         normalization=normalization,
                                         lmax_in=lmax)
        self.act = act
        self.random_rot = random_rot

    @staticmethod
    def _activation_parity(act):
        '''
        Numerically classify ``act`` as even (+1), odd (-1) or neither (0).

        ``act`` is sampled at ``x`` and ``-x`` for 256 points ``x`` in [0, 10].
        The comparison tolerance is relative to the largest ``|act(x)|``.
        '''
        x = torch.linspace(0, 10, 256)
        a1, a2 = act(x), act(-x)
        tol = a1.abs().max() * 1e-10
        # ``<=`` rather than ``<``: an identically-zero activation (tol == 0),
        # which is both even and odd, is then classified as even instead of rejected
        if (a1 - a2).abs().max() <= tol:
            return 1
        if (a1 + a2).abs().max() <= tol:
            return -1
        return 0
예제 #2
0
    def test_inverse(self):
        """ToS2Grid and FromS2Grid with the same lmax undo each other.

        Both directions are checked: coefficients -> grid -> coefficients, and
        grid -> coefficients -> grid for a grid produced by ToS2Grid. This is done
        for the 'component' and 'norm' normalizations, with float64 as the
        default dtype.
        """
        with o3.torch_default_dtype(torch.float64):
            lmax = 5
            grid_res = (50, 75)

            for norm in ('component', 'norm'):
                to_grid = s2grid.ToS2Grid(lmax, grid_res, normalization=norm)
                from_grid = s2grid.FromS2Grid(grid_res, lmax, normalization=norm)

                coeffs = rs.randn(10, [(1, l) for l in range(lmax + 1)])

                # coefficients -> grid -> coefficients
                err_coeffs = (from_grid(to_grid(coeffs)) - coeffs).abs().max()
                self.assertLess(err_coeffs, 1e-5)

                # grid -> coefficients -> grid
                grid = to_grid(coeffs)
                err_grid = (to_grid(from_grid(grid)) - grid).abs().max()
                self.assertLess(err_grid, 1e-5)
예제 #3
0
    def test_inverse_different_ls(self):
        """Project lmax_in=5 coefficients to the grid and back with lmax=7.

        For every normalization ('component', 'norm', 'none'), the leading
        coefficients of the lmax=7 output must match the lmax=5 input.
        """
        with o3.torch_default_dtype(torch.float64):
            l_in, l_out = 5, 7
            grid_res = (50, 60)

            for norm in ('component', 'norm', 'none'):
                forward = s2grid.ToS2Grid(l_in, grid_res, normalization=norm)
                backward = s2grid.FromS2Grid(grid_res,
                                             l_out,
                                             lmax_in=l_in,
                                             normalization=norm)

                original = rs.randn(10, [(1, l) for l in range(l_in + 1)])
                dim_in = original.shape[1]
                # keep only the components that also exist in the input
                recovered = backward(forward(original))[:, :dim_in]
                self.assertLess((recovered - original).abs().max(), 1e-5)