예제 #1
0
    def __init__(self,
                 irreps: o3.Irreps,
                 act,
                 res,
                 normalization='component',
                 lmax_out=None,
                 random_rot=False):
        """Set up an activation applied to a signal sampled on the sphere.

        Builds an ``o3.ToS2Grid`` (input coefficients -> grid), an
        ``o3.FromS2Grid`` (grid -> coefficients up to ``lmax_out``) and a
        second-moment-normalized copy of ``act``. Presumably ``forward``
        chains them; that part is not visible here.

        Parameters
        ----------
        irreps : o3.Irreps
            Input irreps. After ``simplify()`` every irrep must have
            multiplicity 1 and the ``l`` values must be exactly
            ``0, 1, ..., lmax``. The parities must follow either
            ``p_l = p_0`` or ``p_l = p_0 * (-1)**l``.
        act : callable
            Scalar activation function.
        res
            Grid resolution, forwarded to ``o3.ToS2Grid`` / ``o3.FromS2Grid``.
        normalization : str
            Normalization forwarded to both grid transforms.
        lmax_out : int, optional
            Largest ``l`` of the output; defaults to the input ``lmax``.
        random_rot : bool
            Stored as ``self.random_rot``; it is used outside this method.

        Raises
        ------
        ValueError
            If the irreps are not of the required form, their parity pattern
            is not well defined, or the ``l = 0`` component is odd and ``act``
            is neither even nor odd.
        """
        super().__init__()

        irreps = o3.Irreps(irreps).simplify()
        _, (_, p_val) = irreps[0]
        _, (lmax, _) = irreps[-1]
        # Explicit raises rather than ``assert``: asserts vanish under
        # ``python -O``, which would skip validation and, for the parity
        # check, leave ``p_arg`` unbound.
        if not all(mul == 1 for mul, _ in irreps):
            raise ValueError(
                f"every input irrep must have multiplicity 1, got {irreps}")
        if irreps.ls != list(range(lmax + 1)):
            raise ValueError(
                f"the input must contain each l = 0, ..., {lmax} exactly once, "
                f"got {irreps}")
        if all(p == p_val for _, (l, p) in irreps):
            p_arg = 1
        elif all(p == p_val * (-1)**l for _, (l, p) in irreps):
            p_arg = -1
        else:
            raise ValueError("the parity of the input is not well defined")
        self.irreps_in = irreps
        # the input transforms as : A_l ---> p_val * (p_arg)^l * A_l
        # the sphere signal transforms as : f(r) ---> p_val * f(p_arg * r)
        if lmax_out is None:
            lmax_out = lmax

        if p_val == -1:
            # act(-f) must equal +act(f) or -act(f) for the output to have a
            # definite parity; probe act on [0, 10] against [-10, 0].
            x = torch.linspace(0, 10, 256)
            a1, a2 = act(x), act(-x)
            # NOTE(review): a 1e-10 *relative* tolerance is below float32
            # resolution, so in practice the symmetry must hold exactly.
            if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
                out_sign = 1  # p_act = 1 (even activation)
            elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
                out_sign = -1  # p_act = -1 (odd activation)
            else:
                # p_act = 0
                raise ValueError("warning! the parity is violated")
        else:
            # p_val in (0, +1): act(f) picks up the same factor p_val.
            out_sign = p_val

        # Output parity for degree l is out_sign * p_arg**l in every case.
        self.irreps_out = o3.Irreps([(1, (l, out_sign * p_arg**l))
                                     for l in range(lmax_out + 1)])

        self.to_s2 = o3.ToS2Grid(lmax, res, normalization=normalization)
        self.from_s2 = o3.FromS2Grid(res,
                                     lmax_out,
                                     normalization=normalization,
                                     lmax_in=lmax)
        self.act = normalize2mom(act)
        self.random_rot = random_rot
예제 #2
0
    def __init__(self, irreps_in, acts):
        """Apply an optional scalar activation to each irrep of the input.

        Parameters
        ----------
        irreps_in : o3.Irreps (or anything ``o3.Irreps`` accepts)
            Input representation.
        acts : sequence of callable or None
            One entry per irrep of ``irreps_in``. ``None`` leaves that irrep
            unchanged. A callable is wrapped with ``normalize2mom`` and may
            only be used on scalar (``l == 0``) irreps.

        Raises
        ------
        ValueError
            If ``acts`` does not have exactly one entry per irrep, an
            activation is given for a non-scalar irrep, or an odd scalar is
            paired with an activation that is neither even nor odd.
        """
        super().__init__()
        irreps_in = o3.Irreps(irreps_in)
        # Explicit raise rather than ``assert``: under ``python -O`` an assert
        # is stripped and ``zip`` below would silently drop the unmatched tail.
        if len(irreps_in) != len(acts):
            raise ValueError(
                f"Activation: expected one activation per irrep, got "
                f"{len(acts)} activations for {len(irreps_in)} irreps "
                f"({irreps_in}, {acts})"
            )

        # normalize the second moment
        acts = [
            normalize2mom(act) if act is not None else None for act in acts
        ]

        from e3nn.util._argtools import _get_device

        irreps_out = []
        for (mul, (l_in, p_in)), act in zip(irreps_in, acts):
            if act is None:
                irreps_out.append((mul, (l_in, p_in)))
                continue

            if l_in != 0:
                raise ValueError(
                    "Activation: cannot apply an activation function to a non-scalar input."
                )

            # Classify the (normalized) activation: even -> 1, odd -> -1,
            # neither -> 0, by comparing act(x) with act(-x) on [0, 10].
            x = torch.linspace(0, 10, 256, device=_get_device(act))
            a1, a2 = act(x), act(-x)
            if (a1 - a2).abs().max() < 1e-5:
                p_act = 1
            elif (a1 + a2).abs().max() < 1e-5:
                p_act = -1
            else:
                p_act = 0

            # A non-odd input keeps its parity; an odd input takes the
            # parity of the activation.
            p_out = p_act if p_in == -1 else p_in
            if p_out == 0:
                raise ValueError(
                    "Activation: the parity is violated! The input scalar is odd but the activation is neither even nor odd."
                )
            irreps_out.append((mul, (0, p_out)))

        self.irreps_in = irreps_in
        self.irreps_out = o3.Irreps(irreps_out)
        self.acts = torch.nn.ModuleList(acts)
        # Internal invariant (inputs were validated above).
        assert len(self.irreps_in) == len(self.acts)
예제 #3
0
    def __init__(self,
                 lmax_in,
                 lmax_out,
                 act,
                 resolution,
                 *,
                 normalization='component',
                 aspect_ratio=2):
        """Create the input/output ``SO3Grid`` objects and the activation.

        ``grid_in`` is built for ``lmax_in`` and ``grid_out`` for
        ``lmax_out``. Both use the same ``resolution``, ``normalization``
        and ``aspect_ratio``. ``act`` is stored wrapped by ``normalize2mom``.
        """
        super().__init__()

        self.lmax_in = lmax_in
        self.lmax_out = lmax_out

        # Keyword options common to both grids.
        grid_options = dict(normalization=normalization,
                            aspect_ratio=aspect_ratio)
        self.grid_in = SO3Grid(lmax_in, resolution, **grid_options)
        self.grid_out = SO3Grid(lmax_out, resolution, **grid_options)
        self.act = normalize2mom(act)
예제 #4
0
파일: _fc.py 프로젝트: Linux-cpp-lisp/e3nn
    def __init__(self,
                 hs,
                 act=None,
                 variance_in=1,
                 variance_out=1,
                 out_act=False):
        """Build one ``_Layer`` per consecutive pair of widths in ``hs``.

        Parameters
        ----------
        hs : iterable of int
            Feature widths, input first. Layer ``k`` maps ``hs[k]`` to
            ``hs[k + 1]`` and is stored as attribute ``layer{k}``.
        act : callable, optional
            Activation for the layers (wrapped by ``normalize2mom``). ``None``
            means no activation.
        variance_in, variance_out
            Forwarded to ``_Layer`` as the first layer's input-variance and
            the last layer's output-variance arguments. All in-between
            values are 1.
        out_act : bool
            Whether the last layer also receives ``act``.
        """
        super().__init__()
        self.hs = list(hs)
        if act is not None:
            act = normalize2mom(act)

        final_idx = len(self.hs) - 2
        prev_var = variance_in
        pairs = zip(self.hs[:-1], self.hs[1:])
        for idx, (width_in, width_out) in enumerate(pairs):
            is_last = idx == final_idx
            next_var = variance_out if is_last else 1
            # Hidden layers always get ``act``; the last one only if out_act.
            layer_act = act if (out_act or not is_last) else None

            setattr(self, f"layer{idx}",
                    _Layer(width_in, width_out, layer_act, prev_var, next_var))
            prev_var = next_var
예제 #5
0
def test_identity():
    """Normalizing an already-normalized activation must not change it."""
    once = normalize2mom(torch.relu)
    twice = normalize2mom(once)

    inputs = torch.randn(10)
    assert torch.equal(once(inputs), twice(inputs))
예제 #6
0
def test_device():
    """``normalize2mom`` accepts a ``torch.nn.Module`` and yields a usable activation."""
    act = normalize2mom(torch.nn.ReLU())

    # Call the wrapped module so the test checks more than construction.
    x = torch.randn(10)
    y = act(x)
    assert y.device == x.device
    assert y.shape == x.shape
예제 #7
0
def test_deterministic():
    """Two independent normalizations of the same function must agree exactly."""
    first, second = (normalize2mom(torch.tanh) for _ in range(2))

    sample = torch.randn(10)
    assert torch.equal(first(sample), second(sample))