示例#1
0
文件: s2_tests.py 项目: Nanco-L/e3nn
        def test(Rs, act):
            """Check that S2Activation commutes with SO3.rep(..., 0, 0, 0, -1).

            The last argument -1 presumably selects the parity (inversion) part of
            the representation — TODO confirm against SO3.rep.
            """
            # One column per m-component; multiplicities are ignored, so this
            # assumes every entry of Rs has mul == 1 (TODO confirm).
            width = sum(2 * l + 1 for _, l, _ in Rs)
            x = torch.randn(55, width)
            ac = S2Activation(Rs, act, 1000)

            # activation first, then transform the output
            activated = ac(x, dim=-1)
            rep_out = SO3.rep(ac.Rs_out, 0, 0, 0, -1).T
            act_then_rep = activated @ rep_out

            # transform the input first, then activation
            rep_in = SO3.rep(Rs, 0, 0, 0, -1).T
            rep_then_act = ac(x @ rep_in, dim=-1)

            self.assertLess((act_then_rep - rep_then_act).abs().max(), 1e-10)
示例#2
0
    def precompute(self, R):
        """Build a rotated (alpha, beta) sampling grid and its harmonics projector.

        :param R: matrix applied to the grid points as ``xyz @ R.t()``
                  (presumably a 3x3 rotation — TODO confirm)
        :return: tuple ``(xyz, proj)``: the transformed grid points and a
                 ``SphericalHarmonicsProject`` built from their angles and ``self.lmax``
        """
        size = self.n
        alphas = torch.linspace(0, 2 * math.pi, 2 * size)
        # the first two and last two beta samples are dropped
        betas = torch.linspace(0, math.pi, size)[2:-2]
        alpha_grid, beta_grid = torch.meshgrid(alphas, betas)

        points = torch.stack(SO3.angles_to_xyz(alpha_grid, beta_grid), dim=-1)
        xyz = points @ R.t()

        alpha_rot, beta_rot = SO3.xyz_to_angles(xyz)
        projector = SphericalHarmonicsProject(alpha_rot, beta_rot, self.lmax)
        return xyz, projector
示例#3
0
    def __init__(self, n, lmax):
        """Store the grid resolution and precompute two sampling grids.

        :param n: grid resolution passed on to ``precompute``
        :param lmax: maximum spherical-harmonics order used by the projectors
        """
        super().__init__()
        self.n = n
        self.lmax = lmax

        # grid 1 is transformed by SO3.rot(pi/2, pi/2, pi/2), grid 2 by the identity rotation
        quarter_turn = math.pi / 2
        self.xyz1, self.proj1 = self.precompute(SO3.rot(quarter_turn, quarter_turn, quarter_turn))
        self.xyz2, self.proj2 = self.precompute(SO3.rot(0, 0, 0))
示例#4
0
文件: kernel.py 项目: Nanco-L/e3nn
def check_basis_equivariance(basis, order_in, order_out, alpha, beta, gamma):
    """Measure how well each element of a voxelized kernel basis is equivariant.

    Every basis element is normalized to unit norm, then:
      1. each voxel grid is rotated about the grid center with
         ``scipy.ndimage.affine_transform`` using ``SO3.rot(-gamma, -beta, -alpha)``;
      2. its output/input indices are transformed with
         ``irr_repr(order_out, alpha, beta, gamma)`` and
         ``irr_repr(order_in, -gamma, -beta, -alpha)``.
    The inner product between the normalized element and its transformed version
    is returned per element.

    :param basis: tensor of shape [n, 2*order_out+1, 2*order_in+1, size, size, size]
    :param order_in: representation order of the input
    :param order_out: representation order of the output
    :param alpha, beta, gamma: rotation angles; must support ``.item()`` (e.g. 0-d tensors)
    :return: tensor with one inner product per basis element (n entries)
    """
    from e3nn import SO3
    from scipy.ndimage import affine_transform
    import numpy as np

    n = basis.size(0)
    size = basis.size(-1)
    expected_shape = (n, 2 * order_out + 1, 2 * order_in + 1, size, size, size)
    assert basis.size() == expected_shape, basis.size()

    # normalize every basis element (all its channels and voxels together)
    norms = basis.view(n, -1).norm(dim=1)
    basis = basis / norms.view(-1, 1, 1, 1, 1, 1)

    volumes = basis.view(-1, size, size, size)
    rotated = torch.empty_like(volumes)

    inv_rot = SO3.rot(-gamma, -beta, -alpha).numpy()
    center = (np.array(volumes.size()[1:]) - 1) / 2
    # affine_transform maps output coordinate o to input coordinate inv_rot @ o + offset,
    # so this offset keeps the grid center fixed
    offset = center - np.dot(inv_rot, center)

    for idx in range(rotated.size(0)):
        moved = affine_transform(volumes[idx].numpy(), matrix=inv_rot, offset=offset)
        rotated[idx] = torch.tensor(moved)

    rotated = rotated.view(*basis.size())

    rep_out = irr_repr(order_out, alpha.item(), beta.item(), gamma.item(), dtype=rotated.dtype)
    rep_in = irr_repr(order_in, -gamma.item(), -beta.item(), -alpha.item(), dtype=rotated.dtype)
    rotated = torch.einsum("ij,bjkxyz,kl->bilxyz", (rep_out, rotated, rep_in))

    return torch.tensor([(basis[i] * rotated[i]).sum() for i in range(n)])
示例#5
0
def random_rotate_translate(positions, rotation=True, translation=1):
    """Apply one shared random rigid motion to every tensor in ``positions``.

    A translation vector is drawn uniformly inside the unit ball (rejection
    sampling on the cube [-1, 1]^3) and scaled by ``translation``. When
    ``rotation`` is true, the matrix ``SO3.rot(a, b, c)`` with the three angles
    drawn from ``torch.rand(3) * 6.2832`` is applied before translating. When it
    is false, the positions are only translated.

    :param positions: iterable of tensors; each is left-multiplied by the float32
                      rotation matrix and has the 3-vector translation added
                      (presumably each is a float32 3-vector — TODO confirm with callers)
    :param rotation: whether to apply the random rotation
    :param translation: scale factor of the random translation (0 disables it)
    :return: list of transformed tensors, in the same order as ``positions``
    """
    # Rejection-sample a point uniformly inside the unit ball.
    while True:
        trans = torch.rand(3) * 2 - 1
        if trans.norm() <= 1:
            break
    shift = translation * trans

    if not rotation:
        return [pos + shift for pos in positions]

    rot = SO3.rot(*torch.rand(3) * 6.2832).type(torch.float32)
    return [rot @ pos + shift for pos in positions]
示例#6
0
    def __init__(self, Rs, normalization='norm'):
        """Declare one even scalar output per input channel.

        :param Rs: list of triplets (multiplicity, representation order, parity),
                   normalized with ``SO3.normalizeRs``
        :param normalization: stored as-is on the module (not validated here)
        """
        super().__init__()

        self.Rs_in = SO3.normalizeRs(Rs)
        channels = sum(mul for mul, _, _ in self.Rs_in)
        # a single entry: `channels` scalars (l = 0) with even parity
        self.Rs_out = [(channels, 0, +1)]
        self.normalization = normalization
示例#7
0
    def __init__(self, Rs, acts):
        '''
        Can be used only with scalar fields

        Align the input channels of ``Rs`` with the activations in ``acts`` and
        derive the representation of the (scalar) outputs.

        :param Rs: list of triplet (multiplicity, representation order, parity),
                   normalized with ``SO3.normalizeRs``; every output entry is
                   declared with order 0 (the input order is not checked here)
        :param acts: list of tuple (multiplicity, activation)
                     a multiplicity of -1 stands for all channels not claimed by
                     the entries with a positive multiplicity
        :raises ValueError: when an odd input (parity -1) is given an activation
                            that is neither even nor odd
        :raises AssertionError: when the multiplicities of ``acts`` do not add up
                                to the total multiplicity of ``Rs``
        '''
        super().__init__()

        Rs = SO3.normalizeRs(Rs)
        # work on a copy: entries of `acts` are replaced and split below
        acts = copy.deepcopy(acts)

        n1 = sum(mul for mul, _, _ in Rs)  # total number of input channels
        n2 = sum(mul for mul, _ in acts if mul > 0)  # channels explicitly assigned

        # resolve the -1 wildcard to the channels left over
        for i, (mul, act) in enumerate(acts):
            if mul == -1:
                acts[i] = (n1 - n2, act)
                assert n1 - n2 >= 0

        assert n1 == sum(mul for mul, _ in acts)

        # Split entries of `Rs` and `acts` in place until both lists line up one
        # to one with equal multiplicities.
        # NOTE(review): `Rs` is mutated here; this assumes SO3.normalizeRs returns
        # a new list rather than the caller's list — confirm.
        i = 0
        while i < len(Rs):
            mul_r, l, p_r = Rs[i]
            mul_a, act = acts[i]

            if mul_r < mul_a:
                acts[i] = (mul_r, act)
                acts.insert(i + 1, (mul_a - mul_r, act))

            if mul_a < mul_r:
                Rs[i] = (mul_a, l, p_r)
                Rs.insert(i + 1, (mul_r - mul_a, l, p_r))
            i += 1

        # probe points used to test the symmetry of each activation (act(x) vs act(-x))
        x = torch.linspace(0, 10, 256)

        Rs_out = []
        for (mul, l, p_in), (mul_a, act) in zip(Rs, acts):
            assert mul == mul_a

            # classify the activation with a relative tolerance of 1e-10:
            # even (act(x) == act(-x)) -> 1, odd (act(x) == -act(-x)) -> -1,
            # otherwise 0. An activation that is zero on every probe point fails
            # both strict comparisons and is classified as 0.
            a1, a2 = act(x), act(-x)
            if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
                p_act = 1
            elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
                p_act = -1
            else:
                p_act = 0

            # even (1) and parity-less (0) inputs keep their parity;
            # an odd input (-1) takes the parity of the activation
            p = p_act if p_in == -1 else p_in
            Rs_out.append((mul, 0, p))

            # only reachable for an odd input with an activation of no definite parity
            if p_in != 0 and p == 0:
                raise ValueError("warning! the parity is violated")

        self.Rs_out = Rs_out
        self.acts = acts
示例#8
0
    def __init__(self, Rs_1, Rs_2):
        """Set up a channel-by-channel tensor product of two representations.

        Both lists are split until their entries line up with equal multiplicities.
        Each aligned pair (l_1, l_2) then yields every output order
        ``l`` in ``|l_1 - l_2| .. l_1 + l_2`` with parity ``p_1 * p_2``. For pairs
        where neither order is 0, the rescaled Clebsch-Gordan tensor is registered
        as buffer ``cg_{l}_{l_1}_{l_2}``. Otherwise it is only checked to be the identity.

        :param Rs_1: list of triplets (multiplicity, representation order, parity)
        :param Rs_2: list of triplets with the same total multiplicity as ``Rs_1``
        """
        super().__init__()

        Rs_1 = SO3.normalizeRs(Rs_1)
        Rs_2 = SO3.normalizeRs(Rs_2)
        assert sum(mul for mul, _, _ in Rs_1) == sum(mul for mul, _, _ in Rs_2)

        # split entries in place until both lists have equal multiplicities position by position
        pos = 0
        while pos < len(Rs_1):
            m1, order1, par1 = Rs_1[pos]
            m2, order2, par2 = Rs_2[pos]
            if m1 < m2:
                Rs_2[pos] = (m1, order2, par2)
                Rs_2.insert(pos + 1, (m2 - m1, order2, par2))
            elif m2 < m1:
                Rs_1[pos] = (m2, order1, par1)
                Rs_1.insert(pos + 1, (m1 - m2, order1, par1))
            pos += 1

        self.Rs_1 = Rs_1
        self.Rs_2 = Rs_2

        dtype = torch.get_default_dtype()
        Rs_out = []
        for (mul, l_1, p_1), (mul_2, l_2, p_2) in zip(Rs_1, Rs_2):
            assert mul == mul_2
            for l in range(abs(l_1 - l_2), l_1 + l_2 + 1):
                Rs_out.append((mul, l, p_1 * p_2))

                C = SO3.clebsch_gordan(l, l_1, l_2).type(dtype) * (2 * l + 1)**0.5
                if l_1 != 0 and l_2 != 0:
                    self.register_buffer("cg_{}_{}_{}".format(l, l_1, l_2), C)
                    continue

                # one factor is a scalar: the rescaled coefficients must be the identity
                m = C.view(2 * l + 1, 2 * l + 1)
                tol = 1e-7 if C.dtype == torch.float else 1e-10
                identity = torch.eye(2 * l + 1, dtype=C.dtype)
                assert (m - identity).abs().max() < tol, m.numpy().round(3)

        self.Rs_out = Rs_out
示例#9
0
    def __init__(self, Rs):
        """Reorder the entries of ``Rs`` by (order, parity, multiplicity).

        The reordering is stored as the buffer ``mixing_matrix``. Each block of rows
        of the sorted output holds an identity block at the columns of the
        matching input entry.

        :param Rs: list of triplets (multiplicity, representation order, parity)
        """
        super().__init__()

        self.Rs_in = SO3.normalizeRs(Rs)

        # (l, p, mul, input offset, block size) for each input entry
        blocks = []
        offset = 0
        for mul, l, p in self.Rs_in:
            size = mul * (2 * l + 1)
            blocks.append((l, p, mul, offset, size))
            offset += size
        total = offset

        mixing_matrix = torch.zeros(total, total)
        Rs_out = []
        row = 0
        for l, p, mul, col, size in sorted(blocks):
            Rs_out.append((mul, l, p))
            mixing_matrix[row:row + size, col:col + size] = torch.eye(size)
            row += size

        self.Rs_out = SO3.normalizeRs(Rs_out)
        self.register_buffer('mixing_matrix', mixing_matrix)
示例#10
0
文件: kernel.py 项目: Nanco-L/e3nn
    def __init__(self,
                 Rs_in,
                 Rs_out,
                 RadialModel,
                 get_l_filters=None,
                 sh=SO3.spherical_harmonics_xyz,
                 normalization='norm'):
        '''
        Set up an equivariant kernel: the parity-aware filter orders of every
        (input, output) pair, the normalization coefficients and the radial model.

        :param Rs_in: list of triplet (multiplicity, representation order, parity)
        :param Rs_out: list of triplet (multiplicity, representation order, parity)
        :param RadialModel: Class(d), trainable model: R -> R^d
        :param get_l_filters: function of signature (l_in, l_out) -> [l_filter];
                              defaults to every l with |l_in - l_out| <= l <= l_in + l_out.
                              Its result is filtered by parity and must be a sorted
                              list of unique values.
        :param sh: spherical harmonics function of signature ([l_filter], xyz[..., 3]) -> Y[m, ...]
        :param normalization: either 'norm' or 'component'
        representation order = nonnegative integer
        parity = 0 (no parity), 1 (even), -1 (odd)
        :raises AssertionError: if normalization is not 'norm'/'component', or if the
                                filtered l_filters are not sorted and unique
        '''
        super().__init__()

        self.Rs_in = SO3.normalizeRs(Rs_in)
        self.Rs_out = SO3.normalizeRs(Rs_out)

        def filters_with_parity(l_in, p_in, l_out, p_out):
            # default selection rule: every l allowed by the triangle inequality
            def filters(l_in, l_out):
                return list(range(abs(l_in - l_out), l_in + l_out + 1))

            # `get_l_filters` is only read here, so `nonlocal` is not strictly required
            nonlocal get_l_filters
            fn = filters if get_l_filters is None else get_l_filters
            # keep the orders whose parity p_in * (-1)**l matches p_out;
            # an output without parity (p_out == 0) accepts every l
            return [
                l for l in fn(l_in, l_out)
                if p_out == 0 or p_in * (-1)**l == p_out
            ]

        # NOTE: assigned before check_input_output(), which presumably relies on it
        # (check_input_output is defined elsewhere — confirm)
        self.get_l_filters = filters_with_parity
        self.check_input_output()
        self.sh = sh

        assert isinstance(
            normalization,
            str), "normalization should be passed as a string value"
        assert normalization in [
            'norm', 'component'
        ], "normalization needs to be 'norm' or 'component'"
        self.normalization = normalization

        def lm_normalization(l_out, l_in):
            # put 2l_in+1 to keep the norm of the m vector constant
            # put 2l_out+1 to keep the variance of each m component constant
            # sum_m Y_m^2 = (2l+1)/(4pi)  and  norm(Q) = 1  implies that norm(QY) = sqrt(1/4pi)
            # (the None fallback is unreachable because of the asserts above)
            lm_norm = None
            if normalization == 'norm':
                lm_norm = math.sqrt(2 * l_in + 1) * math.sqrt(4 * math.pi)
            elif normalization == 'component':
                lm_norm = math.sqrt(2 * l_out + 1) * math.sqrt(4 * math.pi)
            return lm_norm

        # norm_coef[i, j, :] holds two candidate coefficients for the pair
        # (Rs_out[i], Rs_in[j]): index 0 divides by sqrt(number of terms summed
        # into that output), index 1 divides by sqrt(mul_in). Which one is applied
        # is decided outside this method (not visible here).
        norm_coef = torch.zeros((len(self.Rs_out), len(self.Rs_in), 2))

        n_path = 0  # total number of radial weights the radial model must output
        set_of_l_filters = set()  # every filter order used by any (in, out) pair

        for i, (mul_out, l_out, p_out) in enumerate(self.Rs_out):
            # consider that we sum a bunch of [lambda_(m_out)] vectors
            # we need to count how many of them we sum in order to normalize the network
            num_summed_elements = 0
            for mul_in, l_in, p_in in self.Rs_in:
                l_filters = self.get_l_filters(l_in, p_in, l_out, p_out)
                num_summed_elements += mul_in * len(l_filters)

            # NOTE(review): num_summed_elements == 0 would raise ZeroDivisionError
            # below; presumably check_input_output() rules this out — confirm.
            for j, (mul_in, l_in, p_in) in enumerate(self.Rs_in):
                # normalization assuming that each terms are of order 1 and uncorrelated
                norm_coef[i, j, 0] = lm_normalization(
                    l_out, l_in) / math.sqrt(num_summed_elements)
                norm_coef[i, j, 1] = lm_normalization(l_out,
                                                      l_in) / math.sqrt(mul_in)

                l_filters = self.get_l_filters(l_in, p_in, l_out, p_out)
                assert l_filters == sorted(
                    set(l_filters)
                ), "get_l_filters must return a sorted list of unique values"

                # compute the number of degrees of freedom
                n_path += mul_out * mul_in * len(l_filters)

                # create the set of all spherical harmonics orders needed
                set_of_l_filters = set_of_l_filters.union(l_filters)

        # create the radial model: R+ -> R^n_path
        # it contains the learned parameters
        self.R = RadialModel(n_path)
        self.set_of_l_filters = sorted(set_of_l_filters)
        self.register_buffer('norm_coef', norm_coef)
示例#11
0
文件: kernel.py 项目: Nanco-L/e3nn
    def backward(ctx, grad_kernel):
        """
        Backpropagate through the kernel assembly performed by ``forward``.

        The loops mirror ``forward``. For every (output, input) pair and every
        filter order, the gradient of the corresponding kernel block is contracted
        with the Clebsch-Gordan coefficients to give the gradients of ``Y`` and ``R``.

        :param ctx: autograd context filled by ``forward``
        :param grad_kernel: tensor [batch, l_out * mul_out * m_out, l_in * mul_in * m_in]
        :return: (grad_Y, grad_R, None, None, None, None, None), one entry per
                 ``forward`` input; grad_Y / grad_R are None when not requested
        """
        # Y is only saved when R requires a gradient and R only when Y does,
        # so either may be None here
        Y, R, norm_coef = ctx.saved_tensors

        grad_Y = grad_R = None

        if ctx.needs_input_grad[0]:
            grad_Y = grad_kernel.new_zeros(
                *ctx.Y_shape)  # [l_filter * m_filter, batch]
        if ctx.needs_input_grad[1]:
            grad_R = grad_kernel.new_zeros(
                *ctx.R_shape
            )  # [batch, l_out * l_in * mul_out * mul_in * l_filter]

        begin_R = 0  # running column offset into R / grad_R

        begin_out = 0
        for i, (mul_out, l_out, p_out) in enumerate(ctx.Rs_out):
            s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
            begin_out += mul_out * (2 * l_out + 1)

            begin_in = 0
            for j, (mul_in, l_in, p_in) in enumerate(ctx.Rs_in):
                s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
                begin_in += mul_in * (2 * l_in + 1)

                l_filters = ctx.get_l_filters(l_in, p_in, l_out, p_out)
                if not l_filters:
                    continue

                n = mul_out * mul_in * len(l_filters)
                # sub_R is only needed (and only defined) when grad_Y is requested
                if grad_Y is not None:
                    sub_R = R[:, begin_R:begin_R + n].view(
                        -1, mul_out, mul_in,
                        len(l_filters))  # [batch, mul_out, mul_in, l_filter]
                # sub_grad_R is a view of grad_R: assignments to it fill grad_R in place
                if grad_R is not None:
                    sub_grad_R = grad_R[:, begin_R:begin_R + n].view(
                        -1, mul_out, mul_in,
                        len(l_filters))  # [batch, mul_out, mul_in, l_filter]
                begin_R += n

                grad_K = grad_kernel[:, s_out,
                                     s_in].view(-1, mul_out, 2 * l_out + 1,
                                                mul_in, 2 * l_in + 1)

                sub_norm_coef = norm_coef[i, j]  # [batch]

                for k, l_filter in enumerate(l_filters):
                    # offset of order `l_filter` inside the stacked harmonics Y
                    tmp = sum(2 * l + 1 for l in ctx.set_of_l_filters
                              if l < l_filter)
                    C = SO3.clebsch_gordan(
                        l_out, l_in, l_filter, cached=True,
                        like=grad_kernel)  # [m_out, m_in, m]

                    if grad_Y is not None:
                        # accumulate: several (out, in) pairs share the same rows of Y
                        grad_Y[tmp:tmp + 2 * l_filter + 1] += torch.einsum(
                            "zuivj,ijk,zuv,z->kz", grad_K, C, sub_R[..., k],
                            sub_norm_coef)
                    if grad_R is not None:
                        sub_Y = Y[tmp:tmp + 2 * l_filter + 1]  # [m, batch]
                        sub_grad_R[...,
                                   k] = torch.einsum("zuivj,ijk,kz,z->zuv",
                                                     grad_K, C, sub_Y,
                                                     sub_norm_coef)

        # only removes the local name `ctx`
        del ctx
        return grad_Y, grad_R, None, None, None, None, None
示例#12
0
文件: kernel.py 项目: Nanco-L/e3nn
    def forward(ctx, Y, R, norm_coef, Rs_in, Rs_out, get_l_filters,
                set_of_l_filters):
        """
        Assemble the kernel from the stacked spherical harmonics and radial weights.

        For each (output, input) pair, the block of the kernel is the sum over the
        allowed filter orders of Clebsch-Gordan coefficients contracted with the
        matching harmonics, radial weights and normalization coefficients.

        :param ctx: autograd context; receives the representation metadata and
                    the tensors needed by ``backward``
        :param Y: tensor [l_filter * m_filter, batch]
        :param R: tensor [batch, l_out * l_in * mul_out * mul_in * l_filter]
        :param norm_coef: tensor [l_out, l_in, batch]
        :param Rs_in: list of triplets (multiplicity, representation order, parity)
        :param Rs_out: list of triplets (multiplicity, representation order, parity)
        :param get_l_filters: callable (l_in, p_in, l_out, p_out) -> list of filter orders
        :param set_of_l_filters: filter orders stacked in ``Y`` (assumed ascending,
                                 since offsets are computed as sums over smaller orders)
        :return: tensor [batch, l_out * mul_out * m_out, l_in * mul_in * m_in]
        """
        ctx.Rs_in = Rs_in
        ctx.Rs_out = Rs_out
        ctx.get_l_filters = get_l_filters
        ctx.set_of_l_filters = set_of_l_filters

        # save necessary tensors for backward: the gradient w.r.t. Y needs R and
        # the gradient w.r.t. R needs Y, so each is kept only when required
        saved_Y = saved_R = None
        if Y.requires_grad:
            ctx.Y_shape = Y.shape
            saved_R = R
        if R.requires_grad:
            ctx.R_shape = R.shape
            saved_Y = Y
        ctx.save_for_backward(saved_Y, saved_R, norm_coef)

        batch = Y.shape[1]
        n_in = sum(mul * (2 * l + 1) for mul, l, _ in ctx.Rs_in)
        n_out = sum(mul * (2 * l + 1) for mul, l, _ in ctx.Rs_out)

        kernel = Y.new_zeros(batch, n_out, n_in)

        # note: for the normalization we assume that the variance of R[i] is one
        begin_R = 0  # running column offset into R

        begin_out = 0
        for i, (mul_out, l_out, p_out) in enumerate(ctx.Rs_out):
            s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
            begin_out += mul_out * (2 * l_out + 1)

            begin_in = 0
            for j, (mul_in, l_in, p_in) in enumerate(ctx.Rs_in):
                s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
                begin_in += mul_in * (2 * l_in + 1)

                l_filters = ctx.get_l_filters(l_in, p_in, l_out, p_out)
                if not l_filters:
                    # no allowed filter: this kernel block stays zero and uses no columns of R
                    continue

                # extract the subset of the `R` that corresponds to the couple (l_out, l_in)
                n = mul_out * mul_in * len(l_filters)
                sub_R = R[:, begin_R:begin_R + n].contiguous().view(
                    batch, mul_out, mul_in,
                    -1)  # [batch, mul_out, mul_in, l_filter]
                begin_R += n

                sub_norm_coef = norm_coef[i, j]  # [batch]

                # note: I don't know if we can vectorize this for loop because [l_filter * m_filter] cannot be put into [l_filter, m_filter]
                K = 0
                for k, l_filter in enumerate(l_filters):
                    # offset of order `l_filter` inside the stacked harmonics Y
                    tmp = sum(2 * l + 1 for l in ctx.set_of_l_filters
                              if l < l_filter)
                    sub_Y = Y[tmp:tmp + 2 * l_filter + 1]  # [m, batch]

                    C = SO3.clebsch_gordan(l_out,
                                           l_in,
                                           l_filter,
                                           cached=True,
                                           like=kernel)  # [m_out, m_in, m]

                    # note: The multiplication with `sub_R` could also be done outside of the for loop
                    K += torch.einsum(
                        "ijk,kz,zuv,z->zuivj",
                        (C, sub_Y, sub_R[..., k], sub_norm_coef
                         ))  # [batch, mul_out, m_out, mul_in, m_in]

                # `K` starts as the int 0 and becomes a tensor after the first term.
                # Check its type instead of `K is not 0`, which compares identity
                # with a literal (SyntaxWarning since Python 3.8).
                if torch.is_tensor(K):
                    kernel[:, s_out,
                           s_in] = K.contiguous().view_as(kernel[:, s_out,
                                                                 s_in])

        return kernel
示例#13
0
文件: kernel.py 项目: Nanco-L/e3nn
 def __repr__(self):
     """Return "<class name> (<formatted Rs_in> -> <formatted Rs_out>)"."""
     formatted_in = SO3.formatRs(self.Rs_in)
     formatted_out = SO3.formatRs(self.Rs_out)
     class_name = self.__class__.__name__
     return "{name} ({Rs_in} -> {Rs_out})".format(
         name=class_name, Rs_in=formatted_in, Rs_out=formatted_out)
示例#14
0
 def __init__(self, alpha, beta, lmax):
     """Evaluate SO3.spherical_harmonics(l, alpha, beta) for l = 0..lmax.

     The per-order results are concatenated along dim 0 and stored in the
     buffer ``sh``.
     """
     super().__init__()
     per_order = [SO3.spherical_harmonics(l, alpha, beta) for l in range(lmax + 1)]
     sh = torch.cat(per_order)
     self.register_buffer("sh", sh)