def __init__(self, Rs, act, res, normalization='component', lmax_out=None, random_rot=False):
    """
    Map the signal to the sphere, apply the nonlinearity pointwise, and project back.

    A signal on the sphere is a quasiregular representation of O(3), so a pointwise
    operation applied to its grid values is well defined on these representations.

    :param Rs: input representation of the form [(1, l, p0 * u^l) for l in [0, ..., lmax]]
    :param act: activation function
    :param res: resolution of the grid on the sphere (the higher the more accurate)
    :param normalization: either 'norm' or 'component'
    :param lmax_out: maximum l of the output (defaults to the lmax of the input)
    :param random_rot: rotate randomly the grid
    :raises AssertionError: if an irrep of Rs has multiplicity != 1, if the l's of Rs are
        not exactly 0..lmax in order, or if the parities are not p0 * u^l for a single
        u in {+1, -1}
    :raises ValueError: if p0 == -1 and act is found to be neither even nor odd
    """
    super().__init__()

    Rs = rs.simplify(Rs)
    _, _, p0 = Rs[0]
    _, lmax, _ = Rs[-1]

    # Explicit raises instead of ``assert`` so that the validation still runs under
    # ``python -O``; AssertionError is kept so existing callers see the same type.
    if not all(mul == 1 for mul, _, _ in Rs):
        raise AssertionError("every irrep of Rs must have multiplicity 1")
    if [l for _, l, _ in Rs] != list(range(lmax + 1)):
        raise AssertionError("Rs must contain each l = 0..lmax exactly once, in order")

    # Find u such that the parity of the order-l irrep is p0 * u^l.
    if all(p == p0 for _, _, p in Rs):
        u = 1
    elif all(p == p0 * (-1) ** l for _, l, p in Rs):
        u = -1
    else:
        raise AssertionError("the parity of the input is not well defined")

    self.Rs_in = Rs

    # the input transforms as : A_l ---> p0 * u^l * A_l
    # the sphere signal transforms as : f(r) ---> p0 * f(u * r)
    if lmax_out is None:
        lmax_out = lmax

    if p0 == +1 or p0 == 0:
        self.Rs_out = [(1, l, p0 * u ** l) for l in range(lmax_out + 1)]
    elif p0 == -1:
        # When p0 == -1 the grid values change sign under parity, so the output parity
        # depends on whether act is even or odd. This is probed numerically on [0, 10].
        x = torch.linspace(0, 10, 256)
        a1, a2 = act(x), act(-x)
        if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
            # p_act = 1
            self.Rs_out = [(1, l, u ** l) for l in range(lmax_out + 1)]
        elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
            # p_act = -1
            self.Rs_out = [(1, l, -u ** l) for l in range(lmax_out + 1)]
        else:
            # p_act = 0
            raise ValueError("warning! the parity is violated")

    self.to_s2 = s2grid.ToS2Grid(lmax, res, normalization=normalization)
    self.from_s2 = s2grid.FromS2Grid(res, lmax_out, normalization=normalization, lmax_in=lmax)
    self.act = act
    self.random_rot = random_rot
def test_normalization(self):
    """Random coefficients of each normalization give a grid signal with variance ~1."""
    lmax = 5
    res = (20, 30)
    with o3.torch_default_dtype(torch.float64):
        for norm in ('component', 'norm'):
            to_grid = s2grid.ToS2Grid(lmax, res, normalization=norm)
            irreps = [(1, l) for l in range(lmax + 1)]
            coeffs = rs.randn(50, irreps, normalization=norm)
            signal = to_grid(coeffs)
            self.assertAlmostEqual(signal.var().item(), 1, delta=0.2)
def test_inverse(self):
    """ToS2Grid and FromS2Grid undo each other in both directions at lmax = 5."""
    lmax = 5
    res = (50, 75)
    tol = 1e-5
    with o3.torch_default_dtype(torch.float64):
        for norm in ('component', 'norm'):
            to_grid = s2grid.ToS2Grid(lmax, res, normalization=norm)
            from_grid = s2grid.FromS2Grid(res, lmax, normalization=norm)
            coeffs = rs.randn(10, [(1, l) for l in range(lmax + 1)])

            # coefficients -> grid -> coefficients
            coeff_err = (from_grid(to_grid(coeffs)) - coeffs).abs().max()
            self.assertLess(coeff_err, tol)

            # grid -> coefficients -> grid, on a signal produced by to_grid
            signal = to_grid(coeffs)
            grid_err = (to_grid(from_grid(signal)) - signal).abs().max()
            self.assertLess(grid_err, tol)
def test_inverse_different_ls(self):
    """Projecting to a larger lmax_out keeps the original low-l coefficients intact."""
    lin, lout = 5, 7
    res = (50, 60)
    with o3.torch_default_dtype(torch.float64):
        for norm in ('component', 'norm', 'none'):
            to_grid = s2grid.ToS2Grid(lin, res, normalization=norm)
            from_grid = s2grid.FromS2Grid(res, lout, lmax_in=lin, normalization=norm)
            coeffs_in = rs.randn(10, [(1, l) for l in range(lin + 1)])
            width = coeffs_in.shape[1]
            # only the first `width` output coefficients correspond to l <= lin
            coeffs_out = from_grid(to_grid(coeffs_in))[:, :width]
            self.assertLess((coeffs_out - coeffs_in).abs().max(), 1e-5)