Esempio n. 1
0
def _legendre_genjit(ls):
    """Generate, compile and return a function that fills `out[..., :]` with polynomial values.

    For every degree `l` in `ls`, one local `l{m}` is emitted for each m in 0..l.
    Its expression is built from `poly_legendre(l, m)`, a mapping whose keys are
    unpacked as (power of z, power of y) and whose values are coefficients.
    Presumably this is the associated Legendre polynomial P_l^m; TODO confirm
    against poly_legendre.

    It then emits `out[..., i] = l{|m|}` for m in -l..l, so each degree occupies
    2l + 1 consecutive output columns. The snippets are spliced into the
    `_legendre_code` template:
    - "lsize" is replaced by the total number of columns;
    - "# fill out" is replaced by the generated statements.
    The template is compiled with `eval_code` and its `main` is returned.
    """
    degrees = list(ls)  # materialize: iterated twice below
    statements = []
    column = 0
    for degree in degrees:
        for order in range(degree + 1):
            terms = poly_legendre(degree, order)
            expression = " + ".join(
                "{:.25f} * z**{} * y**{}".format(coeff, z_power, y_power)
                for (z_power, y_power), coeff in terms.items()
            )
            statements.append("    l{} = {}\n".format(order, expression))

        # +m and -m share the same |m| polynomial
        for order in range(-degree, degree + 1):
            statements.append("    out[..., {}] = l{}\n".format(column, abs(order)))
            column += 1

    total_columns = sum(2 * degree + 1 for degree in degrees)
    source = (
        _legendre_code
        .replace("lsize", str(total_columns))
        .replace("# fill out", "".join(statements))
    )
    return eval_code(source).main
Esempio n. 2
0
    def __init__(self,
                 Rs_in1: rs.TY_RS_LOOSE,
                 Rs_in2: rs.TY_RS_LOOSE,
                 Rs_out: rs.TY_RS_LOOSE,
                 instr: List[Tuple[int, int, int, str]],
                 normalization: str = 'component',
                 own_weight: bool = True):
        """
        Create a Tensor Product operation in which each path is weighted by a parameter.

        `instr` is a list of instructions.
        An instruction is of the form (i_1, i_2, i_out, mode);
        it means "put `Rs_in1[i_1]` otimes `Rs_in2[i_2]` into `Rs_out[i_out]`".
        `mode` determines the way the multiplicities are treated and must be one of
        'uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv'.
        The default mode should be 'uvw', meaning that all paths are created.

        The forward computation is generated as Python source and stored in
        `self.code`. It is compiled with `eval_code` into `self.main`.
        The generated code reads its weights from `w` with a leading batch axis.
        `self.nweight` is the number of weights the generated code consumes.

        :param normalization: 'component' scales each Wigner 3j block by
            sqrt(2 l_out + 1), 'norm' by sqrt((2 l_1 + 1)(2 l_2 + 1));
            any other value leaves the blocks unscaled.
        :param own_weight: if True, create `self.weight`, a `torch.nn.Parameter`
            of size `self.nweight` initialized with `torch.randn`.
        """

        super().__init__()

        self.Rs_in1 = rs.convention(Rs_in1)
        self.Rs_in2 = rs.convention(Rs_in2)
        self.Rs_out = rs.convention(Rs_out)

        code = ""

        index_w = 0
        wigners = set()
        # count[k]: number of weighted products accumulated into output channel k
        count = [0 for _ in range(rs.dim(self.Rs_out))]

        # sorting groups equal (i_1, i_2) so s1 / s2 / ss can be reused below
        instr = sorted(instr)

        last_s1, last_s2, last_ss = None, None, None
        for i_1, i_2, i_out, mode in instr:
            mul_1, l_1, p_1 = self.Rs_in1[i_1]
            mul_2, l_2, p_2 = self.Rs_in2[i_2]
            mul_out, l_out, p_out = self.Rs_out[i_out]
            dim_1 = mul_1 * (2 * l_1 + 1)
            dim_2 = mul_2 * (2 * l_2 + 1)
            dim_out = mul_out * (2 * l_out + 1)
            index_1 = rs.dim(self.Rs_in1[:i_1])
            index_2 = rs.dim(self.Rs_in2[:i_2])
            index_out = rs.dim(self.Rs_out[:i_out])

            assert p_1 * p_2 == p_out
            assert abs(l_1 - l_2) <= l_out <= l_1 + l_2

            if dim_1 == 0 or dim_2 == 0 or dim_out == 0:
                continue

            # only re-slice an input when its irrep index changed
            if last_s1 != i_1:
                code += f"    s1 = x1[:, {index_1}:{index_1+dim_1}].reshape(batch, {mul_1}, {2 * l_1 + 1})\n"
                last_s1 = i_1

            if last_s2 != i_2:
                code += f"    s2 = x2[:, {index_2}:{index_2+dim_2}].reshape(batch, {mul_2}, {2 * l_2 + 1})\n"
                last_s2 = i_2

            assert mode in ['uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv']

            # outer product of the two inputs: all (u, v) pairs, or diagonal u == v
            if last_ss != (i_1, i_2, mode[:2]):
                if mode[:2] == 'uv':
                    code += f"    ss = ein('zui,zvj->zuvij', s1, s2)\n"
                if mode[:2] == 'uu':
                    code += f"    ss = ein('zui,zuj->zuij', s1, s2)\n"
                last_ss = (i_1, i_2, mode[:2])

            wigners.add((l_1, l_2, l_out))

            if mode == 'uvw':
                # one weight per (u, v, w); sums over u and v
                dim_w = mul_1 * mul_2 * mul_out
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2}, {mul_out})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuvw,ijk,zuvij->zwk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += mul_1 * mul_2

            elif mode == 'uvu':
                # output multiplicity follows input 1; sums over v
                assert mul_1 == mul_out
                dim_w = mul_1 * mul_2
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zuk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += mul_2

            elif mode == 'uvv':
                # output multiplicity follows input 2; sums over u
                assert mul_2 == mul_out
                dim_w = mul_1 * mul_2
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zvk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += mul_1

            elif mode == 'uuw':
                # diagonal input pairs mixed into every output channel w
                assert mul_1 == mul_2
                dim_w = mul_1 * mul_out
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_out})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuw,ijk,zuij->zwk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += mul_1

            elif mode == 'uuu':
                # fully elementwise over the shared multiplicity
                assert mul_1 == mul_2 == mul_out
                dim_w = mul_1
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zu,ijk,zuij->zuk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += 1

            elif mode == 'uvuv':
                # every (u, v) pair gets its own output channel
                assert mul_1 * mul_2 == mul_out
                dim_w = mul_1 * mul_2
                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
                code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zuvk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"

                for pos in range(index_out, index_out + dim_out):
                    count[pos] += 1

            index_w += dim_w
            code += "\n"

        # Divide each contiguous run of output channels that share the same
        # contribution count c > 1 by sqrt(c). `count` is empty when Rs_out has
        # dimension 0; the guard avoids an IndexError on count[0].
        if count:
            ilast = 0
            clast = count[0]
            for i, c in enumerate(count):
                if clast != c:
                    if clast > 1:
                        code += f"    out[:, {ilast}:{i}].div_({clast ** 0.5})\n"
                    clast = c
                    ilast = i
            if clast > 1:
                code += f"    out[:, {ilast}:].div_({clast ** 0.5})\n"

        wigners = sorted(wigners)
        self.wigners_names = [
            f"C{l_1}_{l_2}_{l_3}" for l_1, l_2, l_3 in wigners
        ]
        args = ", ".join(f"{arg}: torch.Tensor" for arg in self.wigners_names)

        for arg, (l_1, l_2, l_out) in zip(self.wigners_names, wigners):
            wig = o3.wigner_3j(l_1, l_2, l_out)

            # Scale out-of-place so the tensor returned by o3.wigner_3j is never
            # mutated. NOTE(review): it may be shared, e.g. cached.
            if normalization == 'component':
                wig = wig * (2 * l_out + 1)**0.5
            if normalization == 'norm':
                wig = wig * ((2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5)

            self.register_buffer(arg, wig)

        x = _tensor_product_code
        x = x.replace("DIM", f"{rs.dim(self.Rs_out)}")
        x = x.replace("ARGS", args)
        x = x.replace("CODE", code)

        self.code = x
        self.main = eval_code(x).main
        self.nweight = index_w
        if own_weight:
            self.weight = torch.nn.Parameter(torch.randn(self.nweight))
Esempio n. 3
0
    def __init__(self,
                 Rs_in1,
                 Rs_in2,
                 Rs_out,
                 selection_rule=o3.selection_rule,
                 normalization='component',
                 groups=1):
        """
        Create a (grouped) tensor product connecting every allowed irrep triple.

        A path (Rs_in1 entry, Rs_in2 entry, Rs_out entry) is generated when both
        of these hold:
        - `l_out in selection_rule(l_1, p_1, l_2, p_2)`;
        - `p_out == p_1 * p_2`.

        Each multiplicity is split into `groups` near-equal chunks; the first
        `mul % groups` chunks get one extra copy. Chunk g of the inputs only feeds
        chunk g of the output.

        The forward computation is generated as Python source and stored in
        `self.code`. It is compiled with `eval_code` into `self.main`.
        The generated code reads its weights from `w` with a leading batch axis.
        `self.nweight` is the number of weights it consumes.

        :param normalization: 'component' scales each Wigner 3j block by
            sqrt(2 l_out + 1), 'norm' by sqrt((2 l_1 + 1)(2 l_2 + 1));
            any other value leaves the blocks unscaled.
        """
        super().__init__()

        self.Rs_in1 = rs.convention(Rs_in1)
        self.Rs_in2 = rs.convention(Rs_in2)
        self.Rs_out = rs.convention(Rs_out)

        def chunk_sizes(mul):
            # sizes of the `groups` chunks a multiplicity `mul` is split into
            return [mul // groups + (g < mul % groups) for g in range(groups)]

        code = ""

        index_w = 0
        wigners = set()
        # count[k]: number of weighted products accumulated into output channel k
        count = [0 for _ in range(rs.dim(self.Rs_out))]

        index_1 = 0
        for mul_1, l_1, p_1 in self.Rs_in1:
            dim_1 = mul_1 * (2 * l_1 + 1)
            gmul_1s = chunk_sizes(mul_1)  # depends only on mul_1: computed once

            index_2 = 0
            for mul_2, l_2, p_2 in self.Rs_in2:
                dim_2 = mul_2 * (2 * l_2 + 1)
                gmul_2s = chunk_sizes(mul_2)

                for g in range(groups):
                    if gmul_1s[g] * gmul_2s[g] == 0:
                        continue

                    code += f"    s1 = x1[:, {index_1+sum(gmul_1s[:g])*(2*l_1+1)}:{index_1+sum(gmul_1s[:g+1])*(2*l_1+1)}].reshape(batch, {gmul_1s[g]}, {2 * l_1 + 1})\n"
                    code += f"    s2 = x2[:, {index_2+sum(gmul_2s[:g])*(2*l_2+1)}:{index_2+sum(gmul_2s[:g+1])*(2*l_2+1)}].reshape(batch, {gmul_2s[g]}, {2 * l_2 + 1})\n"
                    code += "    ss = ein('zui,zvj->zuvij', s1, s2)\n"

                    index_out = 0
                    for mul_out, l_out, p_out in self.Rs_out:
                        dim_out = mul_out * (2 * l_out + 1)

                        if l_out in selection_rule(l_1, p_1, l_2,
                                                   p_2) and p_out == p_1 * p_2:
                            # Registered even if this group's output chunk is empty,
                            # to keep the same set of buffers.
                            wigners.add((l_out, l_1, l_2))

                            gmul_outs = chunk_sizes(mul_out)
                            if gmul_outs[g] > 0:
                                dim_w = gmul_outs[g] * gmul_1s[g] * gmul_2s[g]

                                code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {gmul_outs[g]}, {gmul_1s[g]}, {gmul_2s[g]})\n"
                                i = index_out + sum(
                                    gmul_outs[:g]) * (2 * l_out + 1)
                                j = index_out + sum(
                                    gmul_outs[:g + 1]) * (2 * l_out + 1)
                                code += f"    out[:, {i}:{j}] += ein('zwuv,kij,zuvij->zwk', sw, C{l_out}_{l_1}_{l_2}, ss).reshape(batch, {gmul_outs[g]*(2*l_out+1)})\n"
                                code += "\n"

                                for k in range(i, j):
                                    count[k] += gmul_1s[g] * gmul_2s[g]

                                index_w += dim_w

                        # Always advance past this output irrep, even when its chunk
                        # for group g is empty. Previously a `continue` skipped this
                        # and shifted every later output offset.
                        index_out += dim_out

                index_2 += dim_2
            index_1 += dim_1

        # Divide each contiguous run of output channels that share the same
        # contribution count c > 1 by sqrt(c). `count` is empty when Rs_out has
        # dimension 0; the guard avoids an IndexError on count[0].
        if count:
            ilast = 0
            clast = count[0]
            for i, c in enumerate(count):
                if clast != c:
                    if clast > 1:
                        code += f"    out[:, {ilast}:{i}].div_({clast ** 0.5})\n"
                    clast = c
                    ilast = i
            if clast > 1:
                code += f"    out[:, {ilast}:].div_({clast ** 0.5})\n"

        wigners = sorted(wigners)
        self.wigners_names = [
            f"C{l_out}_{l_1}_{l_2}" for l_out, l_1, l_2 in wigners
        ]
        args = ", ".join(f"{arg}: torch.Tensor" for arg in self.wigners_names)

        for arg, (l_out, l_1, l_2) in zip(self.wigners_names, wigners):
            C = o3.wigner_3j(l_out, l_1, l_2)

            # Scale out-of-place so the tensor returned by o3.wigner_3j is never
            # mutated. NOTE(review): it may be shared, e.g. cached.
            if normalization == 'component':
                C = C * (2 * l_out + 1)**0.5
            if normalization == 'norm':
                C = C * ((2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5)

            self.register_buffer(arg, C)

        x = _tensor_product_code
        x = x.replace("DIM", f"{rs.dim(self.Rs_out)}")
        x = x.replace("ARGS", args)
        x = x.replace("CODE", code)

        self.code = x
        self.main = eval_code(x).main
        self.nweight = index_w