def __init__(self, irreps: o3.Irreps, act, res, normalization='component', lmax_out=None, random_rot=False):
    """Set up a pointwise activation evaluated on a spherical grid.

    Builds a ``ToS2Grid`` (coefficients up to ``lmax`` -> grid of resolution
    ``res``), a ``FromS2Grid`` (grid -> coefficients up to ``lmax_out``) and a
    second-moment-normalized copy of ``act``. The forward pass is not shown
    here; presumably it chains these three.

    Args:
        irreps: input irreps. After ``simplify()`` they must contain exactly one
            copy of each degree ``l = 0..lmax``. Either every parity equals the
            parity of ``l = 0`` (``p_val``), or every parity equals
            ``p_val * (-1)**l``.
        act: scalar activation. When ``p_val == -1`` it must be even or odd.
        res: grid resolution forwarded to both grid transforms.
        normalization: forwarded to both grid transforms.
        lmax_out: maximum output degree; defaults to the input ``lmax``.
        random_rot: stored as ``self.random_rot`` (its use is not visible here).

    Raises:
        ValueError: if ``irreps`` is empty, has a multiplicity other than 1,
            skips a degree, has an ill-defined parity pattern, or if ``act`` is
            neither even nor odd while the input scalar is odd.
    """
    super().__init__()

    irreps = o3.Irreps(irreps).simplify()
    if len(irreps) == 0:
        raise ValueError("irreps must not be empty")
    _, (_, p_val) = irreps[0]
    _, (lmax, _) = irreps[-1]

    # Explicit raises rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let invalid input through (and previously left
    # ``p_arg`` unbound, surfacing as an UnboundLocalError).
    if not all(mul == 1 for mul, _ in irreps):
        raise ValueError(f"every irrep must have multiplicity 1, got {irreps}")
    if irreps.ls != list(range(lmax + 1)):
        raise ValueError(f"irreps must contain each degree 0..{lmax} exactly once, got {irreps}")

    if all(p == p_val for _, (_, p) in irreps):
        p_arg = 1
    elif all(p == p_val * (-1) ** l for _, (l, p) in irreps):
        p_arg = -1
    else:
        raise ValueError("the parity of the input is not well defined")

    self.irreps_in = irreps
    # the input transforms as        : A_l  ---> p_val * (p_arg)^l * A_l
    # the sphere signal transforms as : f(r) ---> p_val * f(p_arg * r)

    if lmax_out is None:
        lmax_out = lmax

    # Output parity of degree l is ``p_factor * p_arg**l``.
    if p_val == -1:
        # act(-f) must equal +act(f) (even) or -act(f) (odd) for the output
        # to have a well-defined parity.
        x = torch.linspace(0, 10, 256)
        a1, a2 = act(x), act(-x)
        # NOTE(review): a 1e-10 *relative* tolerance is below float32 epsilon,
        # so in practice only exactly (anti)symmetric implementations pass.
        # Confirm this strictness is intended.
        tol = a1.abs().max() * 1e-10
        if (a1 - a2).abs().max() < tol:
            p_factor = 1  # p_act = 1 (even activation)
        elif (a1 + a2).abs().max() < tol:
            p_factor = -1  # p_act = -1 (odd activation)
        else:
            # p_act = 0
            raise ValueError("warning! the parity is violated")
    else:
        # p_val is 0 or +1: the activation's symmetry does not matter.
        p_factor = p_val

    self.irreps_out = o3.Irreps([(1, (l, p_factor * p_arg**l)) for l in range(lmax_out + 1)])

    self.to_s2 = o3.ToS2Grid(lmax, res, normalization=normalization)
    self.from_s2 = o3.FromS2Grid(res, lmax_out, normalization=normalization, lmax_in=lmax)
    self.act = normalize2mom(act)
    self.random_rot = random_rot
def __init__(self, irreps_in, acts):
    """Attach an optional scalar activation to each irrep of ``irreps_in``.

    Args:
        irreps_in: input irreps, converted with ``o3.Irreps``.
        acts: one entry per irrep of ``irreps_in``. ``None`` passes that irrep
            through unchanged. A callable is wrapped with ``normalize2mom`` and
            may only target a scalar (``l == 0``) irrep.

    Raises:
        AssertionError: if ``irreps_in`` and ``acts`` differ in length.
        ValueError: if an activation targets a non-scalar irrep, or if an odd
            scalar is given an activation that is neither even nor odd.
    """
    super().__init__()
    irreps_in = o3.Irreps(irreps_in)
    assert len(irreps_in) == len(acts), (irreps_in, acts)

    # Rescale every activation to unit second moment.
    acts = [None if act is None else normalize2mom(act) for act in acts]

    from e3nn.util._argtools import _get_device

    out = []
    for (mul, (l_in, p_in)), act in zip(irreps_in, acts):
        if act is None:
            out.append((mul, (l_in, p_in)))
            continue

        if l_in != 0:
            raise ValueError(
                "Activation: cannot apply an activation function to a non-scalar input."
            )

        # Probe the activation on +x and -x to classify it as even, odd or neither.
        probe = torch.linspace(0, 10, 256, device=_get_device(act))
        pos, neg = act(probe), act(-probe)
        if (pos - neg).abs().max() < 1e-5:
            p_act = 1
        elif (pos + neg).abs().max() < 1e-5:
            p_act = -1
        else:
            p_act = 0

        # Only an odd input scalar inherits the activation's parity.
        p_out = p_in if p_in != -1 else p_act
        if p_out == 0:
            raise ValueError(
                "Activation: the parity is violated! The input scalar is odd but the activation is neither even nor odd."
            )
        out.append((mul, (0, p_out)))

    self.irreps_in = irreps_in
    self.irreps_out = o3.Irreps(out)
    self.acts = torch.nn.ModuleList(acts)
    assert len(self.irreps_in) == len(self.acts)
def __init__(self, lmax_in, lmax_out, act, resolution, *, normalization='component', aspect_ratio=2):
    """Create input/output ``SO3Grid``s and a second-moment-normalized activation.

    Args:
        lmax_in: maximum degree of the input grid.
        lmax_out: maximum degree of the output grid.
        act: activation, wrapped with ``normalize2mom``.
        resolution: grid resolution shared by both grids.
        normalization: forwarded to both grids.
        aspect_ratio: forwarded to both grids.
    """
    super().__init__()
    # Both grids share every option except their maximum degree.
    grid_options = {"normalization": normalization, "aspect_ratio": aspect_ratio}
    self.grid_in = SO3Grid(lmax_in, resolution, **grid_options)
    self.grid_out = SO3Grid(lmax_out, resolution, **grid_options)
    self.act = normalize2mom(act)
    self.lmax_in, self.lmax_out = lmax_in, lmax_out
def __init__(self, hs, act=None, variance_in=1, variance_out=1, out_act=False):
    """Stack ``_Layer`` modules, one per consecutive pair of widths in ``hs``.

    Layers are registered as ``layer0``, ``layer1``, ...

    Args:
        hs: sequence of layer widths, stored as ``self.hs``.
        act: optional activation, wrapped with ``normalize2mom``. It is used in
            every layer except the last, which uses it only if ``out_act``.
        variance_in: input variance passed to the first layer.
        variance_out: output variance passed to the last layer. Intermediate
            layers use 1.
        out_act: whether the last layer also applies ``act``.
    """
    super().__init__()
    self.hs = list(hs)
    if act is not None:
        act = normalize2mom(act)

    final_index = len(self.hs) - 2
    incoming_var = variance_in
    for index, (width_in, width_out) in enumerate(zip(self.hs[:-1], self.hs[1:])):
        is_final = index == final_index
        outgoing_var = variance_out if is_final else 1
        layer_act = act if (not is_final or out_act) else None
        setattr(self, f"layer{index}", _Layer(width_in, width_out, layer_act, incoming_var, outgoing_var))
        # Each layer's input variance is the previous layer's output variance.
        incoming_var = outgoing_var
def test_identity():
    """Re-normalizing an already normalized activation must not change its outputs."""
    once = normalize2mom(torch.relu)
    twice = normalize2mom(once)
    sample = torch.randn(10)
    assert torch.equal(once(sample), twice(sample))
def test_device():
    """normalize2mom must wrap a ``torch.nn.Module`` activation and keep the input's device.

    The original test only constructed the wrapper. It never ran it and never
    checked a device, so it passed even if the forward pass was broken.
    """
    act = normalize2mom(torch.nn.ReLU())
    devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    for device in devices:
        act = act.to(device)
        x = torch.randn(10, device=device)
        y = act(x)
        assert y.shape == x.shape
        assert y.device == x.device
def test_deterministic():
    """Two independently built normalize2mom(torch.tanh) wrappers must agree exactly."""
    first, second = (normalize2mom(torch.tanh) for _ in range(2))
    probe = torch.randn(10)
    assert torch.equal(first(probe), second(probe))