Code Example #1
    def __init__(self, Rs_in, mul, lmax, Rs_out, size=5, layers=3):
        super().__init__()

        Rs = rs.simplify(Rs_in)
        Rs_out = rs.simplify(Rs_out)
        Rs_act = list(range(lmax + 1))

        self.mul = mul
        self.layers = []

        for _ in range(layers):
            conv = ImageConvolution(Rs,
                                    mul * Rs_act,
                                    size,
                                    lmax=lmax,
                                    fuzzy_pixels=True,
                                    padding=size // 2)

            # s2 nonlinearity
            act = S2Activation(Rs_act, swish, res=60)
            Rs = mul * act.Rs_out

            pool = LowPassFilter(scale=2.0, stride=2)

            self.layers += [torch.nn.ModuleList([conv, act, pool])]

        self.layers = torch.nn.ModuleList(self.layers)
        self.tail = LearnableTensorSquare(Rs, Rs_out)
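In this and the following examples, Rs_act is a plain Python list of L's, so mul * Rs_act is ordinary list repetition: it requests mul copies of each irrep. A tiny standalone illustration:

lmax = 2
Rs_act = list(range(lmax + 1))  # [0, 1, 2]
mul = 3
print(mul * Rs_act)             # [0, 1, 2, 0, 1, 2, 0, 1, 2] -- three copies of each l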
Code Example #2
    def __init__(self, Rs_in, mul, lmax, Rs_out, layers=3):
        super().__init__()

        Rs = self.Rs_in = rs.simplify(Rs_in)
        self.Rs_out = rs.simplify(Rs_out)
        self.act = S2Activation(list(range(lmax + 1)),
                                swish,
                                res=20 * (lmax + 1))

        self.layers = []

        for _ in range(layers):
            lin = LearnableTensorSquare(Rs,
                                        mul * self.act.Rs_in,
                                        linear=True,
                                        allow_zero_outputs=True)

            # s2 nonlinearity
            Rs = mul * self.act.Rs_out

            self.layers += [lin]

        self.layers = torch.nn.ModuleList(self.layers)

        self.tail = LearnableTensorSquare(Rs, self.Rs_out)
Code Example #3
    def __init__(self, Rs_in, mul, lmax, Rs_out, layers=3):
        super().__init__()

        Rs = self.Rs_in = rs.simplify(Rs_in)
        self.Rs_out = rs.simplify(Rs_out)

        def make_act(p_val, p_arg, act):
            Rs = [(1, l, p_val * p_arg**l) for l in range(lmax + 1)]
            return S2Activation(Rs, act, res=20 * (lmax + 1))

        self.act1, self.act2 = make_act(1, -1, swish), make_act(-1, -1, tanh)
        self.mul = mul

        self.layers = []

        for _ in range(layers):
            Rs_out = mul * (self.act1.Rs_in + self.act2.Rs_in)
            lin = LearnableTensorSquare(Rs,
                                        Rs_out,
                                        linear=True,
                                        allow_zero_outputs=True)

            # s2 nonlinearity
            Rs = mul * (self.act1.Rs_out + self.act2.Rs_out)

            self.layers += [lin]

        self.layers = torch.nn.ModuleList(self.layers)

        self.tail = LearnableTensorSquare(Rs, self.Rs_out)
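make_act builds the Rs lists that S2Activation expects, of the form [(1, l, p_val * p_arg**l) for l in range(lmax + 1)] (see the docstring in Code Example #11). For illustration, a standalone look at the two parity choices used above, for lmax = 2:

lmax = 2
# Rs for make_act(1, -1, swish):
print([(1, l, 1 * (-1)**l) for l in range(lmax + 1)])   # [(1, 0, 1), (1, 1, -1), (1, 2, 1)]
# Rs for make_act(-1, -1, tanh):
print([(1, l, -1 * (-1)**l) for l in range(lmax + 1)])  # [(1, 0, -1), (1, 1, 1), (1, 2, -1)]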
Code Example #4
    def __init__(self,
                 Rs_in,
                 Rs_out,
                 linear=True,
                 allow_change_output=False,
                 allow_zero_outputs=False):
        super().__init__()

        self.Rs_in = rs.simplify(Rs_in)
        self.Rs_out = rs.simplify(Rs_out)

        ls = [l for _, l, _ in self.Rs_out]
        selection_rule = partial(o3.selection_rule, lfilter=lambda l: l in ls)

        if linear:
            Rs_in = [(1, 0, 1)] + self.Rs_in
        else:
            Rs_in = self.Rs_in
        self.linear = linear

        Rs_ts, T = rs.tensor_square(Rs_in, selection_rule)
        register_sparse_buffer(self, 'T', T)  # [out, in1 * in2]

        ls = [l for _, l, _ in Rs_ts]
        if allow_change_output:
            self.Rs_out = [(mul, l, p) for mul, l, p in self.Rs_out if l in ls]
        elif not allow_zero_outputs:
            assert all(l in ls for _, l, _ in self.Rs_out)

        self.kernel = KernelLinear(Rs_ts, self.Rs_out)  # [out, in, w]
Code Example #5
File: tensor_product.py Project: zizai/e3nn
    def __init__(self, Rs_1, Rs_2, selection_rule=o3.selection_rule):
        super().__init__()

        self.Rs_1 = rs.simplify(Rs_1)
        self.Rs_2 = rs.simplify(Rs_2)

        Rs_out, mixing_matrix = rs.tensor_product(Rs_1, Rs_2, selection_rule)
        self.Rs_out = rs.simplify(Rs_out)
        self.register_buffer('mixing_matrix', mixing_matrix)
Code Example #6
File: tensor_product.py Project: zizai/e3nn
    def __init__(self, Rs_1, Rs_2, selection_rule=o3.selection_rule):
        super().__init__()

        Rs_1 = rs.simplify(Rs_1)
        Rs_2 = rs.simplify(Rs_2)
        assert sum(mul for mul, _, _ in Rs_1) == sum(mul for mul, _, _ in Rs_2)

        Rs_out, mixing_matrix = rs.elementwise_tensor_product(Rs_1, Rs_2, selection_rule)
        self.register_buffer("mixing_matrix", mixing_matrix)
        self.Rs_out = rs.simplify(Rs_out)
Code Example #7
File: message_passing.py Project: wendazhou/e3nn
    def __init__(self, Rs_in, Rs_out, Rs_sh, RadialModel, groups=math.inf, normalization='component'):
        super().__init__(aggr='add', flow='target_to_source')
        self.Rs_in = rs.simplify(Rs_in)
        self.Rs_out = rs.simplify(Rs_out)

        self.lin1 = Linear(Rs_in, Rs_out, allow_unused_inputs=True, allow_zero_outputs=True)
        self.tp = GroupedWeightedTensorProduct(Rs_in, Rs_sh, Rs_out, groups=groups, normalization=normalization, own_weight=False)
        self.rm = RadialModel(self.tp.nweight)
        self.lin2 = Linear(Rs_out, Rs_out)
        self.Rs_sh = Rs_sh
        self.normalization = normalization
Code Example #8
def WeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, normalization='component', own_weight=True):
    Rs_in1 = rs.simplify(Rs_in1)
    Rs_in2 = rs.simplify(Rs_in2)
    Rs_out = rs.simplify(Rs_out)

    instr = [
        (i_1, i_2, i_out, 'uvw')
        for i_1, (_, l_1, p_1) in enumerate(Rs_in1)
        for i_2, (_, l_2, p_2) in enumerate(Rs_in2)
        for i_out, (_, l_out, p_out) in enumerate(Rs_out)
        if abs(l_1 - l_2) <= l_out <= l_1 + l_2 and p_1 * p_2 == p_out
    ]
    return CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr, normalization, own_weight)
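The instr list above enumerates every (input1, input2, output) path allowed by the usual selection rule: |l_1 - l_2| <= l_out <= l_1 + l_2 with matching parity p_1 * p_2 == p_out. A minimal standalone sketch of that rule (plain Python, no e3nn required; the helper name allowed_paths is made up for illustration):

def allowed_paths(Rs_in1, Rs_in2, Rs_out):
    """Enumerate the (i_1, i_2, i_out) index triples that satisfy the
    triangle inequality and parity selection rule used above."""
    return [
        (i_1, i_2, i_out)
        for i_1, (_, l_1, p_1) in enumerate(Rs_in1)
        for i_2, (_, l_2, p_2) in enumerate(Rs_in2)
        for i_out, (_, l_out, p_out) in enumerate(Rs_out)
        if abs(l_1 - l_2) <= l_out <= l_1 + l_2 and p_1 * p_2 == p_out
    ]

# l=1 (odd) x l=1 (odd) can reach l = 0, 1, 2 with even parity:
print(allowed_paths([(1, 1, -1)], [(1, 1, -1)], [(1, 0, 1), (1, 1, 1), (1, 2, 1)]))
# -> [(0, 0, 0), (0, 0, 1), (0, 0, 2)]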
Code Example #9
File: message_passing.py Project: wendazhou/e3nn
    def __init__(self, Rs_in, Rs_out, Rs_sh, RadialModel, normalization='component'):
        """
        :param Rs_in:  input representation
        :param lmax:   spherical harmonic representation
        :param Rs_out: output representation
        :param RadialModel: model constructor
        """
        super().__init__(aggr='add', flow='target_to_source')
        self.Rs_in = rs.simplify(Rs_in)
        self.Rs_out = rs.simplify(Rs_out)

        self.tp = WeightedTensorProduct(Rs_in, Rs_sh, Rs_out, normalization, own_weight=False)
        self.rm = RadialModel(self.tp.nweight)
        self.Rs_sh = Rs_sh
        self.normalization = normalization
Code Example #10
def spherical_harmonics_xyz(Rs, xyz):
    """
    spherical harmonics

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3]
    :return: tensor of shape [..., m]
    """
    Rs = rs.simplify(Rs)

    if xyz.device.type == 'cuda' and not xyz.requires_grad and rs.lmax(Rs) <= 10:  # pragma: no cover
        try:
            return spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            pass

    *size, _ = xyz.shape
    xyz = xyz.reshape(-1, 3)
    xyz = xyz / torch.norm(xyz, 2, dim=1, keepdim=True)

    # if z > x, rotate x-axis with z-axis
    s = xyz[:, 2].abs() > xyz[:, 0].abs()
    xyz[s] = xyz[s] @ xyz.new_tensor([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])

    alpha = torch.atan2(xyz[:, 1], xyz[:, 0])
    z = xyz[:, 2]
    y = (xyz[:, 0].pow(2) + xyz[:, 1].pow(2)).sqrt()

    sh = spherical_harmonics_alpha_z_y(Rs, alpha, z, y)

    # rotate back
    sh[s] = sh[s] @ _rep_zx(tuple(Rs), xyz.dtype, xyz.device)
    return sh.reshape(*size, sh.shape[1])
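A usage sketch, assuming the old e3nn 0.x API these snippets come from (Code Example #29 calls the same function through the o3 namespace). Rs is just a list of L's, and the 2l+1 components of each harmonic are stacked along the last axis:

import torch
from e3nn import o3  # e3nn 0.x, as used throughout these examples

xyz = torch.randn(32, 3)
sh = o3.spherical_harmonics_xyz([0, 1, 2], xyz)
print(sh.shape)  # torch.Size([32, 9]): 1 + 3 + 5 components for l = 0, 1, 2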
Code Example #11
File: s2.py Project: zizai/e3nn
    def __init__(self,
                 Rs,
                 act,
                 res,
                 normalization='component',
                 lmax_out=None,
                 random_rot=False):
        '''
        Map to the sphere, apply the nonlinearity pointwise and project back.
        The signal on the sphere is a quasi-regular representation of O(3),
        so we can apply a pointwise operation on these representations.

        :param Rs: input representation of the form [(1, l, p0 * u^l) for l in [0, ..., lmax]]
        :param act: activation function
        :param res: resolution of the grid on the sphere (the higher, the more accurate)
        :param normalization: either 'norm' or 'component'
        :param lmax_out: maximum l of the output
        :param random_rot: randomly rotate the grid
        '''
        super().__init__()

        Rs = rs.simplify(Rs)
        _, _, p0 = Rs[0]
        _, lmax, _ = Rs[-1]
        assert all(mul == 1 for mul, _, _ in Rs)
        assert [l for _, l, _ in Rs] == [l for l in range(lmax + 1)]
        if all(p == p0 for _, l, p in Rs):
            u = 1
        elif all(p == p0 * (-1)**l for _, l, p in Rs):
            u = -1
        else:
            assert False, "the parity of the input is not well defined"
        self.Rs_in = Rs
        # the input transforms as : A_l ---> p0 * u^l * A_l
        # the sphere signal transforms as : f(r) ---> p0 * f(u * r)
        if lmax_out is None:
            lmax_out = lmax

        if p0 == +1 or p0 == 0:
            self.Rs_out = [(1, l, p0 * u**l) for l in range(lmax_out + 1)]
        if p0 == -1:
            x = torch.linspace(0, 10, 256)
            a1, a2 = act(x), act(-x)
            if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
                # p_act = 1
                self.Rs_out = [(1, l, u**l) for l in range(lmax_out + 1)]
            elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
                # p_act = -1
                self.Rs_out = [(1, l, -u**l) for l in range(lmax_out + 1)]
            else:
                # p_act = 0
                raise ValueError("warning! the parity is violated")

        self.to_s2 = s2grid.ToS2Grid(lmax, res, normalization=normalization)
        self.from_s2 = s2grid.FromS2Grid(res,
                                         lmax_out,
                                         normalization=normalization,
                                         lmax_in=lmax)
        self.act = act
        self.random_rot = random_rot
Code Example #12
def spherical_harmonics_xyz_cuda(Rs, xyz):  # pragma: no cover
    """
    cuda version of spherical_harmonics_xyz
    """
    from e3nn import cuda_rsh  # pylint: disable=no-name-in-module, import-outside-toplevel

    Rs = rs.simplify(Rs)

    *size, _ = xyz.size()
    xyz = xyz.reshape(-1, 3)
    xyz = xyz / torch.norm(xyz, 2, -1, keepdim=True)

    lmax = rs.lmax(Rs)
    out = xyz.new_empty(((lmax + 1)**2, xyz.size(0)))  # [ filters, batch_size]
    cuda_rsh.real_spherical_harmonics(out, xyz)

    # (-1)^L same as (pi-theta) -> (-1)^(L+m) and 'quantum' norm (-1)^m combined  # h - halved
    norm_coef = [elem for lh in range((lmax + 1) // 2) for elem in [1.] * (4 * lh + 1) + [-1.] * (4 * lh + 3)]
    if lmax % 2 == 0:
        norm_coef.extend([1.] * (2 * lmax + 1))
    norm_coef = out.new_tensor(norm_coef).unsqueeze(1)
    out.mul_(norm_coef)

    if not rs.are_equal(Rs, list(range(lmax + 1))):
        out = torch.cat([out[l**2: (l + 1)**2] for mul, l, _ in Rs for _ in range(mul)])

    return out.T.reshape(*size, out.shape[0])
Code Example #13
def kernel_geometric(Rs_in,
                     Rs_out,
                     selection_rule=o3.selection_rule_in_out_sh,
                     normalization='component'):
    # Compute Clebsch-Gordan coefficients
    Rs_f, Q = rs.tensor_product(Rs_in, selection_rule, Rs_out,
                                normalization)  # [out, in, Y]

    # Sort filters representation
    Rs_f, perm = rs.sort(Rs_f)
    Rs_f = rs.simplify(Rs_f)
    Q = torch.einsum('ijk,lk->ijl', Q, perm)
    del perm

    # Normalize the spherical harmonics
    if normalization == 'component':
        diag = torch.ones(rs.irrep_dim(Rs_f))
    if normalization == 'norm':
        diag = torch.cat(
            [torch.ones(2 * l + 1) / math.sqrt(2 * l + 1) for _, l, _ in Rs_f])
    norm_Y = math.sqrt(4 * math.pi) * torch.diag(diag)  # [Y, Y]

    # Matrix to dispatch the spherical harmonics
    mat_Y = rs.map_irrep_to_Rs(Rs_f)  # [Rs_f, Y]
    mat_Y = mat_Y @ norm_Y

    # Create the radial model: R+ -> R^n_path
    mat_R = rs.map_mul_to_Rs(Rs_f)  # [Rs_f, R]

    mixing_matrix = torch.einsum('ijk,ky,kw->ijyw', Q, mat_Y,
                                 mat_R)  # [out, in, Y, R]
    return Rs_f, mixing_matrix
Code Example #14
File: so3.py Project: truatpasteurdotfr/e3nn
    def __init__(self, Rs, act, n):
        '''
        Map to a signal on SO(3), apply the nonlinearity pointwise and project back.
        The signal on SO(3) is the regular representation of SO(3),
        so we can apply a pointwise operation on these representations.

        :param Rs: input representation
        :param act: activation function
        :param n: number of sampled rotations (the higher, the more accurate)
        '''
        super().__init__()

        Rs = rs.simplify(Rs)
        mul0, _, _ = Rs[0]
        assert all(mul0 * (2 * l + 1) == mul for mul, l, _ in Rs)
        assert [l for _, l, _ in Rs] == list(range(len(Rs)))
        assert all(p == 0 for _, l, p in Rs)

        self.Rs_out = Rs

        x = [o3.rand_rot() for _ in range(n)]
        Z = torch.stack([
            torch.cat([
                o3.irr_repr(l, *o3.rot_to_abc(R)).flatten() * (2 * l + 1)**0.5
                for l in range(len(Rs))
            ]) for R in x
        ])  # [z, lmn]
        Z.div_(Z.shape[1]**0.5)
        self.register_buffer('Z', Z)
        self.act = act
Code Example #15
File: tensor_product.py Project: zizai/e3nn
    def __init__(self, Rs_in, selection_rule=o3.selection_rule):
        super().__init__()

        self.Rs_in = rs.simplify(Rs_in)

        self.Rs_out, mixing_matrix = rs.tensor_square(Rs_in, selection_rule, sorted=True)
        self.register_buffer('mixing_matrix', mixing_matrix)
Code Example #16
    def __init__(self, Rs, normalization='component'):
        super().__init__()

        Rs = rs.simplify(Rs)
        n = sum(mul for mul, _, _ in Rs)
        self.Rs_in = Rs
        self.Rs_out = [(n, 0, +1)]
        self.normalization = normalization
Code Example #17
def spherical_harmonics_alpha_z_y(Rs, alpha, z, y):
    """
    cpu version of spherical_harmonics_alpha_beta
    """
    Rs = rs.simplify(Rs)
    sha = spherical_harmonics_alpha(rs.lmax(Rs), alpha.flatten())  # [z, m]
    shz = spherical_harmonics_z(Rs, z.flatten(), y.flatten())  # [z, l * m]
    out = mul_m_lm(Rs, sha, shz)
    return out.reshape(alpha.shape + (shz.shape[1],))
Code Example #18
    def __init__(self, Rs, acts):
        '''
        Can be used only with scalar fields.

        :param Rs: input representation
        :param acts: list of tuples (multiplicity, activation); a multiplicity of -1 takes the remaining channels
        '''
        super().__init__()

        Rs = rs.simplify(Rs)
        acts = copy.deepcopy(acts)

        n1 = sum(mul for mul, _, _ in Rs)
        n2 = sum(mul for mul, _ in acts if mul > 0)

        for i, (mul, act) in enumerate(acts):
            if mul == -1:
                acts[i] = (n1 - n2, act)
                assert n1 - n2 >= 0

        assert n1 == sum(mul for mul, _ in acts)

        i = 0
        while i < len(Rs):
            mul_r, l, p_r = Rs[i]
            mul_a, act = acts[i]

            if mul_r < mul_a:
                acts[i] = (mul_r, act)
                acts.insert(i + 1, (mul_a - mul_r, act))

            if mul_a < mul_r:
                Rs[i] = (mul_a, l, p_r)
                Rs.insert(i + 1, (mul_r - mul_a, l, p_r))
            i += 1

        x = torch.linspace(0, 10, 256)

        Rs_out = []
        for (mul, l, p_in), (mul_a, act) in zip(Rs, acts):
            assert mul == mul_a

            a1, a2 = act(x), act(-x)
            if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
                p_act = 1
            elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
                p_act = -1
            else:
                p_act = 0

            p = p_act if p_in == -1 else p_in
            Rs_out.append((mul, 0, p))

            if p_in != 0 and p == 0:
                raise ValueError("warning! the parity is violated")

        self.Rs_out = Rs_out
        self.acts = acts
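The loop at the end of this constructor determines each activation's parity empirically: it samples x on [0, 10] and checks whether act(-x) equals act(x) (even), -act(x) (odd), or neither. A standalone sketch of just that test (plain PyTorch; parity_of_activation is a made-up name):

import torch

def parity_of_activation(act):
    """Return +1 for an even activation, -1 for an odd one, 0 otherwise,
    using the same numerical test as the constructor above."""
    x = torch.linspace(0, 10, 256)
    a1, a2 = act(x), act(-x)
    if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
        return 1
    if (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
        return -1
    return 0

print(parity_of_activation(torch.tanh))  # -1: tanh is odd
print(parity_of_activation(torch.abs))   #  1: |x| is even
print(parity_of_activation(torch.relu))  #  0: relu is neither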
Code Example #19
def spherical_harmonics_xyz(Rs, xyz, normalization='none'):
    """
    spherical harmonics

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3]
    :return: tensor of shape [..., m]
    """
    Rs = rs.simplify(Rs)
    *size, _ = xyz.shape
    xyz = xyz.reshape(-1, 3)
    d = torch.norm(xyz, 2, dim=1)
    xyz = xyz[d > 0]
    xyz = xyz / d[d > 0, None]

    sh = None

    if xyz.device.type == 'cuda' and not xyz.requires_grad and rs.lmax(
            Rs) <= 10:  # pragma: no cover
        try:
            sh = _spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            pass

    if sh is None:
        # if z > x, rotate x-axis with z-axis
        s = xyz[:, 2].abs() > xyz[:, 0].abs()
        xyz[s] = xyz[s] @ xyz.new_tensor([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])

        alpha = torch.atan2(xyz[:, 1], xyz[:, 0])
        z = xyz[:, 2]
        y = xyz[:, :2].norm(dim=1)

        sh = spherical_harmonics_alpha_z_y(Rs, alpha, z, y)

        # rotate back
        sh[s] = sh[s] @ _rep_zx(tuple(Rs), xyz.dtype, xyz.device)

    if len(d) > len(sh):
        out = sh.new_zeros(len(d), sh.shape[1])
        out[d == 0] = math.sqrt(1 / (4 * math.pi)) * torch.cat([
            sh.new_ones(1) if l == 0 else sh.new_zeros(2 * l + 1)
            for mul, l, p in Rs for _ in range(mul)
        ])
        out[d > 0] = sh
        sh = out

    if normalization == 'component':
        sh.mul_(math.sqrt(4 * math.pi))
    if normalization == 'norm':
        sh.mul_(
            torch.cat([
                math.sqrt(4 * math.pi / (2 * l + 1)) * sh.new_ones(2 * l + 1)
                for mul, l, p in Rs for _ in range(mul)
            ]))
    return sh.reshape(*size, sh.shape[1])
Code Example #20
    def __init__(self, Rs_in1, Rs_in2, Rs_out, allow_change_output=False):
        super().__init__()

        self.Rs_in1 = rs.simplify(Rs_in1)
        self.Rs_in2 = rs.simplify(Rs_in2)
        self.Rs_out = rs.simplify(Rs_out)

        ls = [l for _, l, _ in self.Rs_out]
        selection_rule = partial(o3.selection_rule, lfilter=lambda l: l in ls)

        Rs_ts, T = rs.tensor_product(self.Rs_in1, self.Rs_in2, selection_rule)
        register_sparse_buffer(self, 'T', T)  # [out, in1 * in2]

        ls = [l for _, l, _ in Rs_ts]
        if allow_change_output:
            self.Rs_out = [(mul, l, p) for mul, l, p in self.Rs_out if l in ls]
        else:
            assert all(l in ls for _, l, _ in self.Rs_out)

        self.kernel = KernelLinear(Rs_ts, self.Rs_out)  # [out, in, w]
Code Example #21
File: tensor_product_test.py Project: sophiaas/e3nn
def test_weighted_tensor_product():
    torch.set_default_dtype(torch.float64)

    Rs_in1 = rs.simplify([1] * 20 + [2] * 4)
    Rs_in2 = rs.simplify([0] * 10 + [1] * 10 + [2] * 5)
    Rs_out = rs.simplify([0] * 3 + [1] * 4)

    tp = WeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, groups=2)

    x1 = rs.randn(20, Rs_in1)
    x2 = rs.randn(20, Rs_in2)

    angles = o3.rand_angles()

    z1 = tp(x1, x2) @ rs.rep(Rs_out, *angles).T
    z2 = tp(x1 @ rs.rep(Rs_in1, *angles).T, x2 @ rs.rep(Rs_in2, *angles).T)

    z1.sum().backward()

    assert torch.allclose(z1, z2)
Code Example #22
File: linear.py Project: zizai/e3nn
    def __init__(self, Rs_in, Rs_out):
        """
        :param Rs_in: list of triplet (multiplicity, representation order, parity)
        :param Rs_out: list of triplet (multiplicity, representation order, parity)
        representation order = nonnegative integer
        parity = 0 (no parity), 1 (even), -1 (odd)
        """
        super().__init__()
        self.Rs_in = rs.simplify(Rs_in)
        self.Rs_out = rs.simplify(Rs_out)

        n_path = 0

        for mul_out, l_out, p_out in self.Rs_out:
            for mul_in, l_in, p_in in self.Rs_in:
                if (l_out, p_out) == (l_in, p_in):
                    # compute the number of degrees of freedom
                    n_path += mul_out * mul_in

        self.weight = torch.nn.Parameter(torch.randn(n_path))
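n_path here is simply the number of weights needed to mix every input irrep into every output irrep with the same (l, p). A small standalone check of that count (plain Python; the Rs values are made up for illustration):

def count_paths(Rs_in, Rs_out):
    """Number of learnable weights in the Linear layer above."""
    return sum(
        mul_out * mul_in
        for mul_out, l_out, p_out in Rs_out
        for mul_in, l_in, p_in in Rs_in
        if (l_out, p_out) == (l_in, p_in)
    )

Rs_in = [(4, 0, 1), (2, 1, -1)]
Rs_out = [(8, 0, 1), (3, 1, -1), (1, 2, 1)]
print(count_paths(Rs_in, Rs_out))  # 4*8 + 2*3 = 38 (the l=2 output gets no path)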
Code Example #23
def spherical_harmonics_z(Rs, z, y=None):
    """
    the z component of the spherical harmonics
    (useful to perform the Fourier transform)

    :param z: tensor of shape [...]
    :return: tensor of shape [..., l * m]
    """
    Rs = rs.simplify(Rs)
    assert all(p in [0, (-1)**l] for _, l, p in Rs)
    ls = [l for mul, l, _ in Rs for _ in range(mul)]
    return legendre(ls, z, y)  # [..., l * m]
Code Example #24
    def __init__(self, Rs_in, Rs_out, lmax=3):
        super().__init__(aggr='add', flow='target_to_source')
        RadialModel = partial(
            GaussianRadialModel,
            max_radius=1.2,
            min_radius=0.0,
            number_of_basis=3,
            h=100,
            L=2,
            act=swish
        )

        Rs_sh = [(1, l, (-1)**l) for l in range(0, lmax + 1)]

        self.Rs_in = rs.simplify(Rs_in)
        self.Rs_out = rs.simplify(Rs_out)

        self.lin1 = Linear(Rs_in, Rs_out, allow_unused_inputs=True, allow_zero_outputs=True)
        self.tp = GroupedWeightedTensorProduct(Rs_in, Rs_sh, Rs_out, own_weight=False)
        self.rm = RadialModel(self.tp.nweight)
        self.lin2 = Linear(Rs_out, Rs_out)
        self.Rs_sh = Rs_sh
Code Example #25
File: linear_mod.py Project: zizai/e3nn
def kernel_linear(Rs_in, Rs_out):
    # Compute Clebsch-Gordan coefficients
    def selection_rule(l_in, p_in, l_out, p_out):
        if l_in == l_out and p_out in [0, p_in]:
            return [0]
        return []

    Rs_f, Q = rs.tensor_product(Rs_in, selection_rule, Rs_out)  # [out, in, w]
    Rs_f = rs.simplify(Rs_f)
    [(_n_path, l, p)] = Rs_f
    assert l == 0 and p in [0, 1]

    return Q
Code Example #26
File: networks.py Project: zizai/e3nn
    def __init__(self, Rs_in, mul, lmax, Rs_out, layers=3):
        super().__init__()

        Rs = rs.simplify(Rs_in)
        Rs_out = rs.simplify(Rs_out)

        self.layers = []

        for _ in range(layers):
            # tensor product: nonlinear and mixes the l's
            tp = TensorSquare(Rs,
                              selection_rule=partial(o3.selection_rule,
                                                     lmax=lmax))

            # direct sum
            Rs = Rs + tp.Rs_out

            # linear: learned but don't mix l's
            Rs_act = [(1, l) for l in range(lmax + 1)]
            lin = Linear(Rs, mul * Rs_act, allow_unused_inputs=True)

            # s2 nonlinearity
            act = S2Activation(Rs_act, swish, res=20 * (lmax + 1))
            Rs = mul * act.Rs_out

            self.layers += [torch.nn.ModuleList([tp, lin, act])]

        self.layers = torch.nn.ModuleList(self.layers)

        def lfilter(l):
            return l in [j for _, j, _ in Rs_out]

        tp = TensorSquare(Rs,
                          selection_rule=partial(o3.selection_rule,
                                                 lfilter=lfilter))
        Rs = Rs + tp.Rs_out
        lin = Linear(Rs, Rs_out, allow_unused_inputs=True)
        self.tail = torch.nn.ModuleList([tp, lin])
Code Example #27
    def __init__(self, Rs_out, scalar_activation, gate_activation):
        """
        :param Rs_out: list of triplet (multiplicity, representation order, parity)
        :param scalar_activation: nonlinear function applied on l=0 channels
        :param gate_activation: nonlinear function applied on the gates
        """
        super().__init__()

        Rs_out = rs.simplify(Rs_out)

        self.scalar_act = scalar_activation
        self.gate_act = gate_activation

        Rs = []
        Rs_gates = []
        for mul, l, p in Rs_out:
            if p != 0:
                raise ValueError("use GatedBlockParity instead")
            Rs.append((mul, l))
            if l != 0:
                Rs_gates.append((mul, 0))

        self.Rs = Rs
        self.Rs_in = rs.simplify(Rs + Rs_gates)
Code Example #28
    @classmethod
    def from_irrep_tensor(cls, irrep_tensor):
        Rs_remove_p = [(mul, L) for mul, L, p in irrep_tensor.Rs]
        Rs, perm = rs.sort(Rs_remove_p)
        Rs = rs.simplify(Rs)
        mul, Ls, _ = zip(*Rs)
        if max(mul) > 1:
            raise ValueError(
                "Cannot have multiplicity greater than 1 for any L. This tensor has a simplified Rs of {}".format(Rs)
            )
        Lmax = max(Ls)
        sorted_tensor = torch.einsum('ij,...j->...i', perm.to_dense(), irrep_tensor.tensor)
        signal = torch.zeros((Lmax + 1)**2)
        Rs_idx = 0
        for L in range(Lmax + 1):
            if Rs[Rs_idx][1] == L:
                ten_slice = slice(rs.dim(Rs[:Rs_idx]), rs.dim(Rs[:Rs_idx + 1]))
                signal[L ** 2: (L + 1) ** 2] = sorted_tensor[ten_slice]
                Rs_idx += 1
        return cls(signal)
Code Example #29
    def __init__(self, Rs, act, n):
        '''
        Map to the sphere, apply the nonlinearity pointwise and project back.
        The signal on the sphere is a quasi-regular representation of O(3),
        so we can apply a pointwise operation on these representations.

        :param Rs: input representation
        :param act: activation function
        :param n: number of points on the sphere (the higher, the more accurate)
        '''
        super().__init__()

        Rs = rs.simplify(Rs)
        mul0, _, p0 = Rs[0]
        assert all(mul0 == mul for mul, _, _ in Rs)
        assert [l for _, l, _ in Rs] == list(range(len(Rs)))
        assert all(p == p0 for _, l, p in Rs) or all(p == p0 * (-1)**l
                                                     for _, l, p in Rs)

        if p0 == +1 or p0 == 0:
            self.Rs_out = Rs
        if p0 == -1:
            x = torch.linspace(0, 10, 256)
            a1, a2 = act(x), act(-x)
            if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
                # p_act = 1
                self.Rs_out = [(mul, l, -p) for mul, l, p in Rs]
            elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
                # p_act = -1
                self.Rs_out = Rs
            else:
                # p_act = 0
                raise ValueError("warning! the parity is violated")

        x = torch.randn(n, 3)
        x = torch.cat([x, -x])
        Y = o3.spherical_harmonics_xyz(list(range(len(Rs))), x)  # [lm, z]
        self.register_buffer('Y', Y)
        self.act = act
Code Example #30
File: rs_test.py Project: soupwaylee/e3nn
def test_simplify():
    Rs = [(1, 0), 0, (1, 0)]
    assert rs.simplify(Rs) == [(3, 0, 0)]
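This test also shows the shorthand accepted throughout the examples above: a bare integer l stands for (1, l, 0) and a pair (mul, l) for (mul, l, 0). A rough re-implementation sketch of what rs.simplify appears to do, inferred from this test rather than from the library's documentation:

def simplify_sketch(Rs):
    """Normalize shorthand entries and merge consecutive runs with equal (l, p)."""
    out = []
    for entry in Rs:
        if isinstance(entry, int):
            mul, l, p = 1, entry, 0
        elif len(entry) == 2:
            mul, l = entry
            p = 0
        else:
            mul, l, p = entry
        if out and out[-1][1:] == (l, p):
            out[-1] = (out[-1][0] + mul, l, p)
        else:
            out.append((mul, l, p))
    return out

print(simplify_sketch([(1, 0), 0, (1, 0)]))  # [(3, 0, 0)], matching the test above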