def _legendre_genjit(ls):
    # Generate and compile (via eval_code) a function that evaluates the
    # Legendre polynomials for all degrees l in `ls` and all orders m at once.
    ls = list(ls)
    fill = ""
    i = 0
    for l in ls:
        # one assignment per order m, with the polynomial in z and y
        # written out with explicit coefficients
        for m in range(l + 1):
            p = poly_legendre(l, m)
            formula = " + ".join("{:.25f} * z**{} * y**{}".format(c, zn, yn) for (zn, yn), c in p.items())
            fill += "    l{} = {}\n".format(m, formula)

        # orders m and -m share the same polynomial, hence abs(m)
        for m in range(-l, l + 1):
            fill += "    out[..., {}] = l{}\n".format(i, abs(m))
            i += 1

    code = _legendre_code
    code = code.replace("lsize", str(sum(2 * l + 1 for l in ls)))
    code = code.replace("# fill out", fill)
    return eval_code(code).main
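# `_legendre_code` and `eval_code` are defined elsewhere in this module.
# A minimal sketch of what the template could look like (an assumption for
# illustration only; the real template may differ):
#
#     _legendre_code = """
#     import torch
#
#     @torch.jit.script
#     def main(z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#         out = z.new_zeros(z.shape + (lsize,))
#     # fill out
#         return out
#     """
#
# `_legendre_genjit` replaces the token `lsize` with the total output
# dimension sum(2l+1) and splices the generated assignments in place of the
# `# fill out` marker before compiling the result with `eval_code`.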
def __init__(self, Rs_in1: rs.TY_RS_LOOSE, Rs_in2: rs.TY_RS_LOOSE, Rs_out: rs.TY_RS_LOOSE, instr: List[Tuple[int, int, int, str]], normalization: str = 'component', own_weight: bool = True):
    """
    Create a Tensor Product operation that has each of its paths weighted by a parameter.

    `instr` is a list of instructions.
    An instruction is of the form (i_1, i_2, i_out, mode);
    it means "Put `Rs_in1[i_1] otimes Rs_in2[i_2]` into `Rs_out[i_out]`".
    `mode` determines the way the multiplicities are treated.
    The default mode should be 'uvw', meaning that all paths are created.
    """
    super().__init__()

    self.Rs_in1 = rs.convention(Rs_in1)
    self.Rs_in2 = rs.convention(Rs_in2)
    self.Rs_out = rs.convention(Rs_out)

    code = ""
    index_w = 0
    wigners = set()
    # number of paths contributing to each output component, used for normalization
    count = [0 for _ in range(rs.dim(self.Rs_out))]

    instr = sorted(instr)  # for optimization: consecutive instructions can reuse s1, s2 and ss

    last_s1, last_s2, last_ss = None, None, None

    for i_1, i_2, i_out, mode in instr:
        mul_1, l_1, p_1 = self.Rs_in1[i_1]
        mul_2, l_2, p_2 = self.Rs_in2[i_2]
        mul_out, l_out, p_out = self.Rs_out[i_out]
        dim_1 = mul_1 * (2 * l_1 + 1)
        dim_2 = mul_2 * (2 * l_2 + 1)
        dim_out = mul_out * (2 * l_out + 1)
        index_1 = rs.dim(self.Rs_in1[:i_1])
        index_2 = rs.dim(self.Rs_in2[:i_2])
        index_out = rs.dim(self.Rs_out[:i_out])
        assert p_1 * p_2 == p_out
        assert abs(l_1 - l_2) <= l_out <= l_1 + l_2

        if dim_1 == 0 or dim_2 == 0 or dim_out == 0:
            continue

        if last_s1 != i_1:
            code += f"    s1 = x1[:, {index_1}:{index_1+dim_1}].reshape(batch, {mul_1}, {2 * l_1 + 1})\n"
            last_s1 = i_1

        if last_s2 != i_2:
            code += f"    s2 = x2[:, {index_2}:{index_2+dim_2}].reshape(batch, {mul_2}, {2 * l_2 + 1})\n"
            last_s2 = i_2

        assert mode in ['uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv']

        if last_ss != (i_1, i_2, mode[:2]):
            if mode[:2] == 'uv':
                code += f"    ss = ein('zui,zvj->zuvij', s1, s2)\n"
            if mode[:2] == 'uu':
                code += f"    ss = ein('zui,zuj->zuij', s1, s2)\n"
            last_ss = (i_1, i_2, mode[:2])

        wigners.add((l_1, l_2, l_out))

        if mode == 'uvw':
            dim_w = mul_1 * mul_2 * mul_out
            code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2}, {mul_out})\n"
            code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuvw,ijk,zuvij->zwk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"
            for pos in range(index_out, index_out + dim_out):
                count[pos] += mul_1 * mul_2

        if mode == 'uvu':
            assert mul_1 == mul_out
            dim_w = mul_1 * mul_2
            code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
            code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zuk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"
            for pos in range(index_out, index_out + dim_out):
                count[pos] += mul_2

        if mode == 'uvv':
            assert mul_2 == mul_out
            dim_w = mul_1 * mul_2
            code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
            code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zvk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"
            for pos in range(index_out, index_out + dim_out):
                count[pos] += mul_1

        if mode == 'uuw':
            assert mul_1 == mul_2
            dim_w = mul_1 * mul_out
            code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_out})\n"
            code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuw,ijk,zuij->zwk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"
            for pos in range(index_out, index_out + dim_out):
                count[pos] += mul_1

        if mode == 'uuu':
            assert mul_1 == mul_2 == mul_out
            dim_w = mul_1
            code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1})\n"
            code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zu,ijk,zuij->zuk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"
            for pos in range(index_out, index_out + dim_out):
                count[pos] += 1

        if mode == 'uvuv':
            assert mul_1 * mul_2 == mul_out
            dim_w = mul_1 * mul_2
            code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
            code += f"    out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zuvk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"
            for pos in range(index_out, index_out + dim_out):
                count[pos] += 1

        index_w += dim_w
        code += "\n"

    # normalize each output component by the square root of the number of paths that sum into it
    ilast = 0
    clast = count[0]
    for i, c in enumerate(count):
        if clast != c:
            if clast > 1:
                code += f"    out[:, {ilast}:{i}].div_({clast ** 0.5})\n"
            clast = c
            ilast = i
    if clast > 1:
        code += f"    out[:, {ilast}:].div_({clast ** 0.5})\n"

    wigners = sorted(wigners)
    self.wigners_names = [f"C{l_1}_{l_2}_{l_3}" for l_1, l_2, l_3 in wigners]
    args = ", ".join(f"{arg}: torch.Tensor" for arg in self.wigners_names)

    for arg, (l_1, l_2, l_out) in zip(self.wigners_names, wigners):
        wig = o3.wigner_3j(l_1, l_2, l_out)

        if normalization == 'component':
            wig *= (2 * l_out + 1)**0.5
        if normalization == 'norm':
            wig *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5

        self.register_buffer(arg, wig)

    x = _tensor_product_code
    x = x.replace("DIM", f"{rs.dim(self.Rs_out)}")
    x = x.replace("ARGS", args)
    x = x.replace("CODE", code)

    self.code = x
    self.main = eval_code(x).main
    self.nweight = index_w
    if own_weight:
        self.weight = torch.nn.Parameter(torch.randn(self.nweight))
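# Illustration of the contraction emitted for mode 'uvw' (a standalone sketch
# in plain PyTorch, not the generated code itself). Shapes follow the einsum
# strings above; `o3.wigner_3j(l_1, l_2, l_out)` is assumed to return a tensor
# of shape [2*l_1+1, 2*l_2+1, 2*l_out+1]:
#
#     batch, mul_1, mul_2, mul_out = 16, 2, 3, 4
#     l_1, l_2, l_out = 1, 1, 2
#     s1 = torch.randn(batch, mul_1, 2 * l_1 + 1)
#     s2 = torch.randn(batch, mul_2, 2 * l_2 + 1)
#     sw = torch.randn(batch, mul_1, mul_2, mul_out)        # one weight per (u, v, w) path
#     C = o3.wigner_3j(l_1, l_2, l_out)                     # [3, 3, 5]
#     ss = torch.einsum('zui,zvj->zuvij', s1, s2)           # outer product of the inputs
#     out = torch.einsum('zuvw,ijk,zuvij->zwk', sw, C, ss)  # [batch, mul_out, 2*l_out+1]
#
# The other modes only change how the multiplicity indices are tied together,
# e.g. 'uvu' reuses u for the output instead of a free index w.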
def __init__(self, Rs_in1, Rs_in2, Rs_out, selection_rule=o3.selection_rule, normalization='component', groups=1):
    super().__init__()

    self.Rs_in1 = rs.convention(Rs_in1)
    self.Rs_in2 = rs.convention(Rs_in2)
    self.Rs_out = rs.convention(Rs_out)

    code = ""
    index_w = 0
    wigners = set()
    # number of paths contributing to each output component, used for normalization
    count = [0 for _ in range(rs.dim(self.Rs_out))]

    index_1 = 0
    for mul_1, l_1, p_1 in self.Rs_in1:
        dim_1 = mul_1 * (2 * l_1 + 1)
        index_2 = 0
        for mul_2, l_2, p_2 in self.Rs_in2:
            dim_2 = mul_2 * (2 * l_2 + 1)

            # split each multiplicity across the groups,
            # distributing the remainder to the leading groups
            gmul_1s = [mul_1 // groups + (g < mul_1 % groups) for g in range(groups)]
            gmul_2s = [mul_2 // groups + (g < mul_2 % groups) for g in range(groups)]

            for g in range(groups):
                if gmul_1s[g] * gmul_2s[g] == 0:
                    continue

                code += f"    s1 = x1[:, {index_1+sum(gmul_1s[:g])*(2*l_1+1)}:{index_1+sum(gmul_1s[:g+1])*(2*l_1+1)}].reshape(batch, {gmul_1s[g]}, {2 * l_1 + 1})\n"
                code += f"    s2 = x2[:, {index_2+sum(gmul_2s[:g])*(2*l_2+1)}:{index_2+sum(gmul_2s[:g+1])*(2*l_2+1)}].reshape(batch, {gmul_2s[g]}, {2 * l_2 + 1})\n"
                code += f"    ss = ein('zui,zvj->zuvij', s1, s2)\n"

                index_out = 0
                for mul_out, l_out, p_out in self.Rs_out:
                    dim_out = mul_out * (2 * l_out + 1)

                    if l_out in selection_rule(l_1, p_1, l_2, p_2) and p_out == p_1 * p_2:
                        wigners.add((l_out, l_1, l_2))

                        gmul_outs = [mul_out // groups + (g < mul_out % groups) for g in range(groups)]
                        dim_w = gmul_outs[g] * gmul_1s[g] * gmul_2s[g]

                        # only emit code when this group has output channels;
                        # index_out below must advance for every entry of Rs_out regardless
                        if gmul_outs[g] > 0:
                            code += f"    sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {gmul_outs[g]}, {gmul_1s[g]}, {gmul_2s[g]})\n"
                            i = index_out + sum(gmul_outs[:g]) * (2 * l_out + 1)
                            j = index_out + sum(gmul_outs[:g + 1]) * (2 * l_out + 1)
                            code += f"    out[:, {i}:{j}] += ein('zwuv,kij,zuvij->zwk', sw, C{l_out}_{l_1}_{l_2}, ss).reshape(batch, {gmul_outs[g]*(2*l_out+1)})\n"
                            code += "\n"

                            for k in range(i, j):
                                count[k] += gmul_1s[g] * gmul_2s[g]

                            index_w += dim_w

                    index_out += dim_out

            index_2 += dim_2
        index_1 += dim_1

    # normalize each output component by the square root of the number of paths that sum into it
    ilast = 0
    clast = count[0]
    for i, c in enumerate(count):
        if clast != c:
            if clast > 1:
                code += f"    out[:, {ilast}:{i}].div_({clast ** 0.5})\n"
            clast = c
            ilast = i
    if clast > 1:
        code += f"    out[:, {ilast}:].div_({clast ** 0.5})\n"

    wigners = sorted(wigners)
    self.wigners_names = [f"C{l_out}_{l_1}_{l_2}" for l_out, l_1, l_2 in wigners]
    args = ", ".join(f"{arg}: torch.Tensor" for arg in self.wigners_names)

    for arg, (l_out, l_1, l_2) in zip(self.wigners_names, wigners):
        C = o3.wigner_3j(l_out, l_1, l_2)

        if normalization == 'component':
            C *= (2 * l_out + 1)**0.5
        if normalization == 'norm':
            C *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5

        self.register_buffer(arg, C)

    x = _tensor_product_code
    x = x.replace("DIM", f"{rs.dim(self.Rs_out)}")
    x = x.replace("ARGS", args)
    x = x.replace("CODE", code)

    self.main = eval_code(x).main
    self.nweight = index_w
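# Worked example of the multiplicity split used above: with mul = 5 and
# groups = 2, `[mul // groups + (g < mul % groups) for g in range(groups)]`
# gives [3, 2] -- the remainder goes to the leading groups and the group
# sizes always sum back to mul. Each group g then couples only its own slice
# of x1 with its own slice of x2, so the number of weights drops by roughly
# a factor of `groups` compared to the ungrouped 'uvw' product:
#
#     >>> mul, groups = 5, 2
#     >>> [mul // groups + (g < mul % groups) for g in range(groups)]
#     [3, 2]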