Example #1
    def __init__(self, Rs_in, Rs_out, RadialModel,
                 selection_rule=o3.selection_rule_in_out_sh,
                 normalization='component',
                 allow_unused_inputs=False,
                 allow_zero_outputs=False):
        """
        :param Rs_in: list of triplet (multiplicity, representation order, parity)
        :param Rs_out: list of triplet (multiplicity, representation order, parity)
        :param RadialModel: Class(d), trainable model: R -> R^d
        :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
        :param normalization: either 'norm' or 'component'
        representation order = nonnegative integer
        parity = 0 (no parity), 1 (even), -1 (odd)
        """
        super().__init__()

        self.Rs_in = rs.convention(Rs_in)
        self.Rs_out = rs.convention(Rs_out)
        if not allow_unused_inputs:
            self.check_input(selection_rule)
        if not allow_zero_outputs:
            self.check_output(selection_rule)

        self.normalization = normalization

        self.tp = rs.TensorProduct(self.Rs_in, selection_rule, Rs_out, normalization, sorted=True)
        self.Rs_f = self.tp.Rs_in2

        self.Ls = [l for _, l, _ in self.Rs_f]
        self.R = RadialModel(rs.mul_dim(self.Rs_f))

        self.linear = KernelLinear(self.Rs_in, self.Rs_out)
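`RadialModel` is only specified by its contract above (`Class(d)`, a trainable map R -> R^d). A minimal sketch of a module satisfying that contract; the class name and layer sizes are illustrative, not part of the library:

import torch

class MLPRadialModel(torch.nn.Module):
    """Hypothetical RadialModel: maps radii [batch] to features [batch, d]."""

    def __init__(self, d, hidden=16):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(1, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, d),
        )

    def forward(self, radii):
        # radii: [batch] -> [batch, 1], then project to [batch, d]
        return self.net(radii.unsqueeze(-1))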
Example #2
    def __init__(self,
                 Rs_in,
                 Rs_out,
                 RadialModel,
                 selection_rule=o3.selection_rule_in_out_sh,
                 sh=rsh.spherical_harmonics_xyz,
                 normalization='component'):
        """
        :param Rs_in: list of triplet (multiplicity, representation order, parity)
        :param Rs_out: list of triplet (multiplicity, representation order, parity)
        :param RadialModel: Class(d), trainable model: R -> R^d
        :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
        :param sh: spherical harmonics function of signature ([l_filter], xyz[..., 3]) -> Y[m, ...]
        :param normalization: either 'norm' or 'component'
        representation order = nonnegative integer
        parity = 0 (no parity), 1 (even), -1 (odd)
        """
        super().__init__()

        self.Rs_in = rs.convention(Rs_in)
        self.Rs_out = rs.convention(Rs_out)
        self.check_input_output(selection_rule)

        Rs_f, Q = kernel_geometric(self.Rs_in, self.Rs_out, selection_rule,
                                   normalization)
        self.register_buffer('Q', Q)  # [out, in, Y, R]

        self.sh = sh
        self.Ls = [l for _, l, _ in Rs_f]
        self.R = RadialModel(rs.mul_dim(Rs_f))

        self.linear = KernelLinear(self.Rs_in, self.Rs_out)
Example #3
    def __init__(self,
                 Rs_in,
                 Rs_out,
                 RadialModel,
                 r,
                 r_eps=0,
                 selection_rule=o3.selection_rule_in_out_sh,
                 normalization='component'):
        """
        :param Rs_in: list of triplet (multiplicity, representation order, parity)
        :param Rs_out: list of triplet (multiplicity, representation order, parity)
        :param RadialModel: Class(d), trainable model: R -> R^d
        :param tensor r: [..., 3]
        :param float r_eps: distance considered as zero
        :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
        :param normalization: either 'norm' or 'component'
        representation order = nonnegative integer
        parity = 0 (no parity), 1 (even), -1 (odd)
        """
        super().__init__()

        self.Rs_in = rs.convention(Rs_in)
        self.Rs_out = rs.convention(Rs_out)
        self.check_input_output(selection_rule)

        *self.size, xyz = r.size()
        assert xyz == 3
        r = r.reshape(-1, 3)  # [batch, space]
        self.register_buffer('radii', r.norm(2, dim=1))  # [batch]
        self.r_eps = r_eps

        self.tp = rs.TensorProduct(self.Rs_in,
                                   selection_rule,
                                   self.Rs_out,
                                   normalization,
                                   sorted=True)
        self.Rs_f = self.tp.Rs_in2

        Y = rsh.spherical_harmonics_xyz(
            [(1, l, p) for _, l, p in self.Rs_f],
            r[self.radii > self.r_eps])  # [batch, l_filter * m_filter]

        # Normalize the spherical harmonics
        if normalization == 'component':
            Y.mul_(math.sqrt(4 * math.pi))
        if normalization == 'norm':
            diag = math.sqrt(4 * math.pi) * torch.cat([
                torch.ones(2 * l + 1) / math.sqrt(2 * l + 1)
                for _, l, _ in self.Rs_f
            ])
            Y.mul_(diag)

        self.register_buffer('Y', Y)
        self.R = RadialModel(rs.mul_dim(self.Rs_f))

        if (self.radii <= self.r_eps).any():
            self.linear = KernelLinear(self.Rs_in, self.Rs_out)
        else:
            self.linear = None
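The two normalization branches can be checked in isolation: 'component' rescales all harmonics by sqrt(4*pi), while 'norm' additionally divides each order l by sqrt(2l + 1). A self-contained sketch of the 'norm' diagonal, with an assumed filter list:

import math
import torch

Rs_f = [(1, 0, 0), (1, 1, 0), (1, 2, 0)]  # assumed (mul, l, parity) filter list

# 'norm': one factor sqrt(4*pi) / sqrt(2l + 1) per m component of each order l
diag = math.sqrt(4 * math.pi) * torch.cat([
    torch.ones(2 * l + 1) / math.sqrt(2 * l + 1)
    for _, l, _ in Rs_f
])
print(diag.shape)  # torch.Size([9]): 1 + 3 + 5 components for l = 0, 1, 2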
Example #4
def GroupedWeightedTensorProduct(Rs_in1,
                                 Rs_in2,
                                 Rs_out,
                                 groups=math.inf,
                                 normalization='component',
                                 own_weight=True):
    Rs_in1 = rs.convention(Rs_in1)
    Rs_in2 = rs.convention(Rs_in2)
    Rs_out = rs.convention(Rs_out)

    groups = min(groups, min(mul for mul, _, _ in Rs_in1),
                 min(mul for mul, _, _ in Rs_out))

    Rs_in1 = [(mul // groups + (g < mul % groups), l, p)
              for mul, l, p in Rs_in1 for g in range(groups)]
    Rs_out = [(mul // groups + (g < mul % groups), l, p)
              for mul, l, p in Rs_out for g in range(groups)]

    instr = [(i_1, i_2, i_out, 'uvw')
             for i_1, (_, l_1, p_1) in enumerate(Rs_in1)
             for i_2, (_, l_2, p_2) in enumerate(Rs_in2)
             for i_out, (_, l_out, p_out) in enumerate(Rs_out)
             if abs(l_1 - l_2) <= l_out <= l_1 + l_2 and p_1 * p_2 == p_out
             if i_1 % groups == i_out % groups]
    return CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr,
                                       normalization, own_weight)
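The expression `mul // groups + (g < mul % groups)` splits a multiplicity as evenly as possible over the groups: the first `mul % groups` groups get one extra copy. A quick standalone check:

def split(mul, groups):
    # Same arithmetic as in GroupedWeightedTensorProduct above.
    return [mul // groups + (g < mul % groups) for g in range(groups)]

print(split(5, 3))  # [2, 2, 1]
print(split(4, 2))  # [2, 2]
assert sum(split(5, 3)) == 5  # nothing is lost in the split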
Example #5
    def test_conventionRs(self):
        Rs = [(1, 0)]
        Rs_out = rs.convention(Rs)
        self.assertSequenceEqual(Rs_out, [(1, 0, 0)])
        Rs = [(1, 0), (2, 0)]
        Rs_out = rs.convention(Rs)
        self.assertSequenceEqual(Rs_out, [(1, 0, 0), (2, 0, 0)])
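Together with Example #20 below, this test pins down what `rs.convention` does: it normalizes a "loose" Rs list into full (multiplicity, order, parity) triplets. A minimal reimplementation sketch consistent with these tests (the real function may accept more input forms):

def convention_sketch(Rs):
    """Normalize a loose Rs list into (mul, l, p) triplets (sketch)."""
    out = []
    for r in Rs:
        if isinstance(r, int):   # a bare order l: multiplicity 1, no parity
            out.append((1, r, 0))
        elif len(r) == 2:        # (mul, l): no parity
            mul, l = r
            out.append((mul, l, 0))
        else:                    # already a full (mul, l, p) triplet
            out.append(tuple(r))
    return out

assert convention_sketch([0]) == [(1, 0, 0)]
assert convention_sketch([(1, 0), (2, 0)]) == [(1, 0, 0), (2, 0, 0)]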
Example #6
    def __init__(self, Rs_in, Rs_out):
        """
        :param Rs_in: list of triplet (multiplicity, representation order, parity)
        :param Rs_out: list of triplet (multiplicity, representation order, parity)
        representation order = nonnegative integer
        parity = 0 (no parity), 1 (even), -1 (odd)
        """
        super().__init__()
        self.Rs_in = rs.convention(Rs_in)
        self.Rs_out = rs.convention(Rs_out)
        self.check_input_output()

        self.kernel = KernelLinear(self.Rs_in, self.Rs_out)
Example #7
    def __init__(self, Rs, activation, normalization='component'):
        super().__init__()

        self.Rs = rs.convention(Rs)
        self.activation = activation
        self.norm = Norm(self.Rs, normalization)
        self.bias = torch.nn.Parameter(torch.zeros(rs.mul_dim(self.Rs)))
Example #8
    def __init__(self, signal, mul, lmax, p_val=0, p_arg=0):
        """
        f: s2 x r -> R^N

        Rotations
        [D(g) f](x) = f(g^{-1} x)

        Parity
        [P f](x) = p_val f(p_arg x)

        f(x) = sum F^l . Y^l(x)

        This class contains the F^l

        Rotations
        [D(g) f](x) = sum [D^l(g) F^l] . Y^l(x)         (using equiv. of Y and orthogonality of D)

        Parity
        [P f](x) = sum [p_val p_arg^l F^l] . Y^l(x)     (using parity of Y)
        """
        if signal.shape[-1] != mul * (lmax + 1)**2:
            raise ValueError(
                "Last tensor dimension and Rs do not have same dimension.")

        self.signal = signal
        self.lmax = lmax
        self.mul = mul
        self.Rs = rs.convention([(mul, l, p_val * p_arg**l)
                                 for l in range(lmax + 1)])
        self.radial_model = None
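The parity assigned to each order, p_val * p_arg**l, alternates in sign when p_arg = -1, so even and odd orders transform differently under inversion. For example:

p_val, p_arg, lmax = 1, -1, 3
print([(l, p_val * p_arg**l) for l in range(lmax + 1)])
# [(0, 1), (1, -1), (2, 1), (3, -1)]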
Example #9
    def __init__(self, signal: torch.Tensor, p_val: int = 0, p_arg: int = 0):
        """
        f: s2 -> R

        Rotations
        [D(g) f](x) = f(g^{-1} x)

        Parity
        [P f](x) = p_val f(p_arg x)

        f(x) = sum F^l . Y^l(x)

        This class contains the F^l

        Rotations
        [D(g) f](x) = sum [D^l(g) F^l] . Y^l(x)         (using equiv. of Y and orthogonality of D)

        Parity
        [P f](x) = sum [p_val p_arg^l F^l] . Y^l(x)     (using parity of Y)
        """
        lmax = round(math.sqrt(signal.shape[-1]) - 1)

        if signal.shape[-1] != (lmax + 1)**2:
            raise ValueError(
                "Last tensor dimension and Rs do not have same dimension.")

        self.signal = signal
        self.lmax = lmax
        self.Rs = rs.convention([(1, l, p_val * p_arg**l)
                                 for l in range(lmax + 1)])
        self.p_val = p_val
        self.p_arg = p_arg
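Here lmax is recovered from the flat coefficient dimension: a full expansion up to lmax has sum over l of (2l + 1) = (lmax + 1)^2 components, and the subsequent shape check rejects any dimension that is not a perfect square:

import math

dim = 16  # e.g. signal.shape[-1]
lmax = round(math.sqrt(dim) - 1)
assert (lmax + 1)**2 == dim  # lmax == 3: orders 0..3 give 1 + 3 + 5 + 7 = 16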
Example #10
def WeightedTensorProduct(Rs_in1,
                          Rs_in2,
                          Rs_out,
                          normalization='component',
                          own_weight=True):
    Rs_in1 = rs.convention(Rs_in1)
    Rs_in2 = rs.convention(Rs_in2)
    Rs_out = rs.convention(Rs_out)

    instr = [(i_1, i_2, i_out, 'uvw')
             for i_1, (_, l_1, p_1) in enumerate(Rs_in1)
             for i_2, (_, l_2, p_2) in enumerate(Rs_in2)
             for i_out, (_, l_out, p_out) in enumerate(Rs_out)
             if abs(l_1 - l_2) <= l_out <= l_1 + l_2 and p_1 * p_2 == p_out]
    return CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr,
                                       normalization, own_weight)
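The comprehension enumerates every path allowed by the triangle inequality |l_1 - l_2| <= l_out <= l_1 + l_2 and the parity constraint p_1 * p_2 == p_out. Run in isolation on small assumed Rs lists:

Rs_in1 = [(1, 0, 1), (1, 1, 1)]                # assumed (mul, l, p) lists
Rs_in2 = [(1, 1, 1)]
Rs_out = [(1, 0, 1), (1, 1, 1), (1, 2, 1)]

instr = [(i_1, i_2, i_out, 'uvw')
         for i_1, (_, l_1, p_1) in enumerate(Rs_in1)
         for i_2, (_, l_2, p_2) in enumerate(Rs_in2)
         for i_out, (_, l_out, p_out) in enumerate(Rs_out)
         if abs(l_1 - l_2) <= l_out <= l_1 + l_2 and p_1 * p_2 == p_out]
print(instr)
# [(0, 0, 1, 'uvw'), (1, 0, 0, 'uvw'), (1, 0, 1, 'uvw'), (1, 0, 2, 'uvw')]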
Example #11
    def __init__(self, tensor, Rs):
        Rs = rs.convention(Rs)
        if tensor.shape[-1] != rs.dim(Rs):
            raise ValueError(
                "Last tensor dimension and Rs do not have same dimension.")
        self.tensor = tensor
        self.Rs = Rs
Example #12
    def __init__(self,
                 Rs_in,
                 Rs_out,
                 RadialModel,
                 r,
                 r_eps=0,
                 selection_rule=o3.selection_rule_in_out_sh,
                 sh=rsh.spherical_harmonics_xyz,
                 normalization='component'):
        """
        :param Rs_in: list of triplet (multiplicity, representation order, parity)
        :param Rs_out: list of triplet (multiplicity, representation order, parity)
        :param RadialModel: Class(d), trainable model: R -> R^d
        :param tensor r: [..., 3]
        :param float r_eps: distance considered as zero
        :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
        :param sh: spherical harmonics function of signature ([l_filter], xyz[..., 3]) -> Y[m, ...]
        :param normalization: either 'norm' or 'component'
        representation order = nonnegative integer
        parity = 0 (no parity), 1 (even), -1 (odd)
        """
        super().__init__()

        self.Rs_in = rs.convention(Rs_in)
        self.Rs_out = rs.convention(Rs_out)
        self.check_input_output(selection_rule)

        *self.size, xyz = r.size()
        assert xyz == 3
        r = r.reshape(-1, 3)  # [batch, space]
        self.radii = r.norm(2, dim=1)  # [batch]
        self.r_eps = r_eps

        Rs_f, Q = kernel_geometric(self.Rs_in, self.Rs_out, selection_rule,
                                   normalization)
        Y = sh([l for _, l, _ in Rs_f],
               r[self.radii > self.r_eps])  # [batch, l_filter * m_filter]
        Q = torch.einsum('ijyw,zy->zijw', Q, Y)
        self.register_buffer('Q', Q)  # [batch, out, in, R]

        self.R = RadialModel(rs.mul_dim(Rs_f))

        if (self.radii <= self.r_eps).any():
            self.linear = KernelLinear(self.Rs_in, self.Rs_out)
        else:
            self.linear = None
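Points with ||r|| <= r_eps are excluded from the spherical-harmonic evaluation (their direction is undefined) and are served by the KernelLinear fallback instead. The masking itself is plain tensor arithmetic:

import torch

r = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
r_eps = 0
radii = r.norm(2, dim=1)             # tensor([0., 1., 2.])
print(r[radii > r_eps])              # only nonzero points reach the SH call
print(bool((radii <= r_eps).any()))  # True -> the linear fallback is created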
Example #13
    def __init__(self,
                 Rs_in,
                 mul,
                 Rs_out,
                 lmax,
                 layers=3,
                 max_radius=1.0,
                 number_of_basis=3,
                 radial_layers=3,
                 kernel=Kernel,
                 convolution=Convolution,
                 min_radius=0.0):
        super().__init__()

        R = partial(GaussianRadialModel,
                    max_radius=max_radius,
                    number_of_basis=number_of_basis,
                    h=100,
                    L=radial_layers,
                    act=swish,
                    min_radius=min_radius)
        K = partial(kernel,
                    RadialModel=R,
                    selection_rule=partial(o3.selection_rule_in_out_sh,
                                           lmax=lmax))

        modules = []

        Rs = rs.convention(Rs_in)
        for _ in range(layers):
            scalars = [(mul, l, p)
                       for mul, l, p in [(mul, 0, +1), (mul, 0, -1)]
                       if rs.haslinearpath(Rs, l, p)]
            act_scalars = [(mul, swish if p == 1 else tanh)
                           for mul, l, p in scalars]

            nonscalars = [(mul, l, p) for l in range(1, lmax + 1)
                          for p in [+1, -1] if rs.haslinearpath(Rs, l, p)]
            if rs.haslinearpath(Rs, 0, +1):
                gates = [(rs.mul_dim(nonscalars), 0, +1)]
                act_gates = [(-1, sigmoid)]
            else:
                gates = [(rs.mul_dim(nonscalars), 0, -1)]
                act_gates = [(-1, tanh)]

            act = GatedBlockParity(scalars, act_scalars, gates, act_gates,
                                   nonscalars)
            conv = convolution(K(Rs, act.Rs_in))

            Rs = act.Rs_out

            block = torch.nn.ModuleList([conv, act])
            modules.append(block)

        self.layers = torch.nn.ModuleList(modules)

        K = partial(K, allow_unused_inputs=True)
        self.layers.append(convolution(K(Rs, Rs_out)))
Example #14
    def __init__(self, signal, mul, lmax):
        if signal.shape[-1] != mul * (lmax + 1)**2:
            raise ValueError(
                "Last tensor dimension and Rs do not have same dimension.")

        self.signal = signal
        self.lmax = lmax
        self.mul = mul
        self.Rs = rs.convention([(mul, l) for l in range(lmax + 1)])
        self.radial_model = None
Example #15
    def __init__(self, Rs_in, Rs_out, layer, activation, layers=3):
        super().__init__()

        modules = []

        Rs = rs.convention(Rs_in)
        for _ in range(layers):
            act = activation(Rs)
            lay = layer(Rs, act.Rs_in)

            Rs = act.Rs_out

            modules += [torch.nn.ModuleList([lay, act])]

        self.layers = torch.nn.ModuleList(modules)
        self.layers.append(layer(Rs, Rs_out))
Example #16
    def __init__(self, Rs_in, mul, Rs_out, lmax, size=5, layers=3):
        super().__init__()

        modules = []

        Rs = rs.convention(Rs_in)
        for _ in range(layers):
            scalars = [(mul, l, p)
                       for mul, l, p in [(mul, 0, +1), (mul, 0, -1)]
                       if rs.haslinearpath(Rs, l, p)]
            act_scalars = [(mul, swish if p == 1 else tanh)
                           for mul, l, p in scalars]

            nonscalars = [(mul, l, p) for l in range(1, lmax + 1)
                          for p in [+1, -1] if rs.haslinearpath(Rs, l, p)]
            if rs.haslinearpath(Rs, 0, +1):
                gates = [(rs.mul_dim(nonscalars), 0, +1)]
                act_gates = [(-1, sigmoid)]
            else:
                gates = [(rs.mul_dim(nonscalars), 0, -1)]
                act_gates = [(-1, tanh)]

            act = GatedBlockParity(scalars, act_scalars, gates, act_gates,
                                   nonscalars)
            conv = Convolution(Rs,
                               act.Rs_in,
                               size,
                               lmax=lmax,
                               fuzzy_pixels=True,
                               padding=size // 2)

            Rs = act.Rs_out

            block = torch.nn.Sequential(conv, act)
            modules.append(block)

        modules += [
            Convolution(Rs,
                        Rs_out,
                        size,
                        lmax=lmax,
                        fuzzy_pixels=True,
                        padding=size // 2,
                        allow_unused_inputs=True)
        ]
        self.layers = torch.nn.Sequential(*modules)
Example #17
    def __init__(self,
                 Rs_in1: rs.TY_RS_LOOSE,
                 Rs_in2: rs.TY_RS_LOOSE,
                 Rs_out: rs.TY_RS_LOOSE,
                 instr: List[Tuple[int, int, int, str]],
                 normalization: str = 'component',
                 own_weight: bool = True):
        """
        Create a Tensor Product operation that has each of his path weighted by a parameter.
        `instr` is a list of instructions.
        An instruction if of the form (i_1, i_2, i_out, mode)
        it means "Put `Rs_in1[i_1] otimes Rs_in2[i_2] into Rs_out[i_out]"
        `mode` determines the way the multiplicities are treated.
        The default mode should be 'uvw', meaning that all paths are created.
        """

        super().__init__()

        self.Rs_in1 = rs.convention(Rs_in1)
        self.Rs_in2 = rs.convention(Rs_in2)
        self.Rs_out = rs.convention(Rs_out)

        code = ""

        index_w = 0
        wigners = set()
        count = [0 for _ in range(rs.dim(self.Rs_out))]

        instr = sorted(instr)  # for optimization

        last_s1, last_s2, last_ss = None, None, None
        for i_1, i_2, i_out, mode in instr:
            mul_1, l_1, p_1 = self.Rs_in1[i_1]
            mul_2, l_2, p_2 = self.Rs_in2[i_2]
            mul_out, l_out, p_out = self.Rs_out[i_out]
            dim_1 = mul_1 * (2 * l_1 + 1)
            dim_2 = mul_2 * (2 * l_2 + 1)
            dim_out = mul_out * (2 * l_out + 1)
            index_1 = rs.dim(self.Rs_in1[:i_1])
            index_2 = rs.dim(self.Rs_in2[:i_2])
            index_out = rs.dim(self.Rs_out[:i_out])

            assert p_1 * p_2 == p_out
            assert abs(l_1 - l_2) <= l_out <= l_1 + l_2

            if dim_1 == 0 or dim_2 == 0 or dim_out == 0:
                continue

            if last_s1 != i_1:
                code += f"    s1 = x1[:, {index_1}:{index_1+dim_1}].reshape(batch, {mul_1}, {2 * l_1 + 1})\n"
                last_s1 = i_1

            if last_s2 != i_2:
                code += f"    s2 = x2[:, {index_2}:{index_2+dim_2}].reshape(batch, {mul_2}, {2 * l_2 + 1})\n"
                last_s2 = i_2

            assert mode in ['uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv']

            if last_ss != (i_1, i_2, mode[:2]):
                if mode[:2] == 'uv':
                    code += f"    ss = ein('zui,zvj->zuvij', s1, s2)\n"
                if mode[:2] == 'uu':
                    code += f"    ss = ein('zui,zuj->zuij', s1, s2)\n"
                last_ss = (i_1, i_2, mode[:2])

            wigners.add((l_1, l_2, l_out))

            if mode == 'uvw':
                dim_w = mul_1 * mul_2 * mul_out
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2}, {mul_out})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuvw,ijk,zuvij->zwk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += mul_1 * mul_2

            if mode == 'uvu':
                assert mul_1 == mul_out
                dim_w = mul_1 * mul_2
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zuk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += mul_2

            if mode == 'uvv':
                assert mul_2 == mul_out
                dim_w = mul_1 * mul_2
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zvk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += mul_1

            if mode == 'uuw':
                assert mul_1 == mul_2
                dim_w = mul_1 * mul_out
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_out})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuw,ijk,zuij->zwk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += mul_1

            if mode == 'uuu':
                assert mul_1 == mul_2 == mul_out
                dim_w = mul_1
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zu,ijk,zuij->zuk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += 1

            if mode == 'uvuv':
                assert mul_1 * mul_2 == mul_out
                dim_w = mul_1 * mul_2
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zuvk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += 1

            index_w += dim_w
            code += "\n"

        ilast = 0
        clast = count[0]
        for i, c in enumerate(count):
            if clast != c:
                if clast > 1:
                    code += f"    out[:, {ilast}:{i}].div_({clast ** 0.5})\n"
                clast = c
                ilast = i
        if clast > 1:
            code += f"    out[:, {ilast}:].div_({clast ** 0.5})\n"

        wigners = sorted(wigners)
        self.wigners_names = [
            f"C{l_1}_{l_2}_{l_3}" for l_1, l_2, l_3 in wigners
        ]
        args = ", ".join(f"{arg}: torch.Tensor" for arg in self.wigners_names)

        for arg, (l_1, l_2, l_out) in zip(self.wigners_names, wigners):
            wig = o3.wigner_3j(l_1, l_2, l_out)

            if normalization == 'component':
                wig *= (2 * l_out + 1)**0.5
            if normalization == 'norm':
                wig *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5

            self.register_buffer(arg, wig)

        x = _tensor_product_code
        x = x.replace("DIM", f"{rs.dim(self.Rs_out)}")
        x = x.replace("ARGS", args)
        x = x.replace("CODE", code)

        self.code = x
        self.main = eval_code(x).main
        self.nweight = index_w
        if own_weight:
            self.weight = torch.nn.Parameter(torch.randn(self.nweight))
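The `count` bookkeeping and the final `div_(count ** 0.5)` implement fan-in normalization: an output slot fed by N independent unit-variance paths is divided by sqrt(N), keeping its variance near 1. A quick numerical check of that identity:

import torch

n = 16                           # paths feeding one output slot
x = torch.randn(100000, n)       # independent unit-variance contributions
out = x.sum(dim=1) / n**0.5      # same rescaling as out.div_(count ** 0.5)
print(out.var().item())          # approximately 1.0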
Example #18
    def __init__(self,
                 Rs_in1,
                 Rs_in2,
                 Rs_out,
                 selection_rule=o3.selection_rule,
                 normalization='component',
                 groups=1):
        super().__init__()

        self.Rs_in1 = rs.convention(Rs_in1)
        self.Rs_in2 = rs.convention(Rs_in2)
        self.Rs_out = rs.convention(Rs_out)

        code = ""

        index_w = 0
        wigners = set()
        count = [0 for _ in range(rs.dim(self.Rs_out))]

        index_1 = 0
        for mul_1, l_1, p_1 in self.Rs_in1:
            dim_1 = mul_1 * (2 * l_1 + 1)

            index_2 = 0
            for mul_2, l_2, p_2 in self.Rs_in2:
                dim_2 = mul_2 * (2 * l_2 + 1)

                gmul_1s = [
                    mul_1 // groups + (g < mul_1 % groups)
                    for g in range(groups)
                ]
                gmul_2s = [
                    mul_2 // groups + (g < mul_2 % groups)
                    for g in range(groups)
                ]

                for g in range(groups):
                    if gmul_1s[g] * gmul_2s[g] == 0:
                        continue

                    code += f"    s1 = x1[:, {index_1+sum(gmul_1s[:g])*(2*l_1+1)}:{index_1+sum(gmul_1s[:g+1])*(2*l_1+1)}].reshape(batch, {gmul_1s[g]}, {2 * l_1 + 1})\n"
                    code += f"    s2 = x2[:, {index_2+sum(gmul_2s[:g])*(2*l_2+1)}:{index_2+sum(gmul_2s[:g+1])*(2*l_2+1)}].reshape(batch, {gmul_2s[g]}, {2 * l_2 + 1})\n"
                    code += f"    ss = ein('zui,zvj->zuvij', s1, s2)\n"

                    index_out = 0
                    for mul_out, l_out, p_out in self.Rs_out:
                        dim_out = mul_out * (2 * l_out + 1)

                        if l_out in selection_rule(l_1, p_1, l_2,
                                                   p_2) and p_out == p_1 * p_2:
                            wigners.add((l_out, l_1, l_2))

                            gmul_outs = [
                                mul_out // groups + (g < mul_out % groups)
                                for g in range(groups)
                            ]
                            dim_w = gmul_outs[g] * gmul_1s[g] * gmul_2s[g]

                            if gmul_outs[g] == 0:
                                continue

                            code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {gmul_outs[g]}, {gmul_1s[g]}, {gmul_2s[g]})\n"
                            i = index_out + sum(
                                gmul_outs[:g]) * (2 * l_out + 1)
                            j = index_out + sum(
                                gmul_outs[:g + 1]) * (2 * l_out + 1)
                            code += f"    out[:, {i}:{j}] += ein('zwuv,kij,zuvij->zwk', sw, C{l_out}_{l_1}_{l_2}, ss).reshape(batch, {gmul_outs[g]*(2*l_out+1)})\n"
                            code += "\n"

                            for k in range(i, j):
                                count[k] += gmul_1s[g] * gmul_2s[g]

                            index_w += dim_w

                        index_out += dim_out

                index_2 += dim_2
            index_1 += dim_1

        ilast = 0
        clast = count[0]
        for i, c in enumerate(count):
            if clast != c:
                if clast > 1:
                    code += f"    out[:, {ilast}:{i}].div_({clast ** 0.5})\n"
                clast = c
                ilast = i
        if clast > 1:
            code += f"    out[:, {ilast}:].div_({clast ** 0.5})\n"

        wigners = sorted(wigners)
        self.wigners_names = [
            f"C{l_out}_{l_1}_{l_2}" for l_out, l_1, l_2 in wigners
        ]
        args = ", ".join(f"{arg}: torch.Tensor" for arg in self.wigners_names)

        for arg, (l_out, l_1, l_2) in zip(self.wigners_names, wigners):
            C = o3.wigner_3j(l_out, l_1, l_2)

            if normalization == 'component':
                C *= (2 * l_out + 1)**0.5
            if normalization == 'norm':
                C *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5

            self.register_buffer(arg, C)

        x = _tensor_product_code
        x = x.replace("DIM", f"{rs.dim(self.Rs_out)}")
        x = x.replace("ARGS", args)
        x = x.replace("CODE", code)

        self.main = eval_code(x).main
        self.nweight = index_w
Example #19
    def __init__(self, Rs, p=0.5):
        super().__init__()
        self.Rs = rs.convention(Rs)
        self.p = p
Example #20
def test_convention():
    Rs = [0]
    assert rs.convention(Rs) == [(1, 0, 0)]
    Rs = [0, (2, 0)]
    assert rs.convention(Rs) == [(1, 0, 0), (2, 0, 0)]