def __init__(self, Rs_in, Rs_out, RadialModel, selection_rule=o3.selection_rule_in_out_sh, normalization='component', allow_unused_inputs=False, allow_zero_outputs=False):
    """
    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param normalization: either 'norm' or 'component'
    :param allow_unused_inputs: when False, ``self.check_input(selection_rule)`` is run at construction
    :param allow_zero_outputs: when False, ``self.check_output(selection_rule)`` is run at construction
    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    if not allow_unused_inputs:
        self.check_input(selection_rule)
    if not allow_zero_outputs:
        self.check_output(selection_rule)

    self.normalization = normalization

    # Use the normalized self.Rs_out (not the raw argument), consistently with self.Rs_in.
    self.tp = rs.TensorProduct(self.Rs_in, selection_rule, self.Rs_out, normalization, sorted=True)
    # The filter representation is the second input of the tensor product.
    self.Rs_f = self.tp.Rs_in2
    self.Ls = [l for _, l, _ in self.Rs_f]
    self.R = RadialModel(rs.mul_dim(self.Rs_f))

    self.linear = KernelLinear(self.Rs_in, self.Rs_out)
def __init__(self, Rs_in, Rs_out, RadialModel, selection_rule=o3.selection_rule_in_out_sh, sh=rsh.spherical_harmonics_xyz, normalization='component'):
    """
    Build a kernel from a precomputed geometric tensor, harmonics and a radial model.

    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param sh: spherical harmonics function of signature ([l_filter], xyz[..., 3]) -> Y[m, ...]
    :param normalization: either 'norm' or 'component'
    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    self.check_input_output(selection_rule)

    filter_Rs, geometric = kernel_geometric(self.Rs_in, self.Rs_out, selection_rule, normalization)
    self.register_buffer('Q', geometric)  # [out, in, Y, R]

    self.sh = sh
    self.Ls = [order for _, order, _ in filter_Rs]
    self.R = RadialModel(rs.mul_dim(filter_Rs))

    self.linear = KernelLinear(self.Rs_in, self.Rs_out)
def __init__(self, Rs_in, Rs_out, RadialModel, r, r_eps=0, selection_rule=o3.selection_rule_in_out_sh, normalization='component'):
    """
    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param tensor r: [..., 3]
    :param float r_eps: distance considered as zero
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param normalization: either 'norm' or 'component'
    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    self.check_input_output(selection_rule)

    *self.size, xyz = r.size()
    assert xyz == 3
    r = r.reshape(-1, 3)  # [batch, space]
    self.register_buffer('radii', r.norm(2, dim=1))  # [batch]
    self.r_eps = r_eps

    self.tp = rs.TensorProduct(self.Rs_in, selection_rule, self.Rs_out, normalization, sorted=True)
    self.Rs_f = self.tp.Rs_in2

    # Harmonics are evaluated only at the points farther than r_eps from the origin.
    Y = rsh.spherical_harmonics_xyz(
        [(1, l, p) for _, l, p in self.Rs_f], r[self.radii > self.r_eps])  # [batch, l_filter * m_filter]

    # Normalize the spherical harmonics
    if normalization == 'component':
        Y.mul_(math.sqrt(4 * math.pi))
    if normalization == 'norm':
        # Factor sqrt(4 pi) / sqrt(2l + 1) repeated over the 2l + 1 components of each order.
        # new_tensor gives Y's dtype/device (so CUDA or float64 `r` works) and copes with an empty Rs_f.
        diag = Y.new_tensor([
            math.sqrt(4 * math.pi) / math.sqrt(2 * l + 1)
            for _, l, _ in self.Rs_f
            for _m in range(2 * l + 1)
        ])
        Y.mul_(diag)
    self.register_buffer('Y', Y)

    self.R = RadialModel(rs.mul_dim(self.Rs_f))

    # A separate linear kernel is built only when some points lie within r_eps of the origin.
    if (self.radii <= self.r_eps).any():
        self.linear = KernelLinear(self.Rs_in, self.Rs_out)
    else:
        self.linear = None
def GroupedWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, groups=math.inf, normalization='component', own_weight=True):
    """
    Build a ``CustomWeightedTensorProduct`` whose paths only connect matching groups.

    The number of groups is capped by the smallest multiplicity of ``Rs_in1`` and
    ``Rs_out``. Every irrep of ``Rs_in1`` and ``Rs_out`` is split into that many
    consecutive parts; the first ``mul % groups`` parts get one extra copy. A 'uvw' path
    (i_1, i_2, i_out) is created when |l_1 - l_2| <= l_out <= l_1 + l_2,
    p_1 * p_2 == p_out, and part i_1 belongs to the same group as part i_out.
    """
    Rs_in1 = rs.convention(Rs_in1)
    Rs_in2 = rs.convention(Rs_in2)
    Rs_out = rs.convention(Rs_out)

    smallest = min(min(mul for mul, _, _ in Rs_in1), min(mul for mul, _, _ in Rs_out))
    groups = min(groups, smallest)

    def split(Rs):
        return [(mul // groups + (g < mul % groups), l, p) for mul, l, p in Rs for g in range(groups)]

    Rs_in1 = split(Rs_in1)
    Rs_out = split(Rs_out)

    instr = []
    for i_1, (_, l_1, p_1) in enumerate(Rs_in1):
        for i_2, (_, l_2, p_2) in enumerate(Rs_in2):
            for i_out, (_, l_out, p_out) in enumerate(Rs_out):
                allowed = abs(l_1 - l_2) <= l_out <= l_1 + l_2 and p_1 * p_2 == p_out
                if allowed and i_1 % groups == i_out % groups:
                    instr.append((i_1, i_2, i_out, 'uvw'))

    return CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr, normalization, own_weight)
def test_conventionRs(self):
    """rs.convention fills in parity 0 for (multiplicity, order) pairs."""
    cases = [
        ([(1, 0)], [(1, 0, 0)]),
        ([(1, 0), (2, 0)], [(1, 0, 0), (2, 0, 0)]),
    ]
    for loose, expected in cases:
        self.assertSequenceEqual(rs.convention(loose), expected)
def __init__(self, Rs_in, Rs_out):
    """
    Linear layer between two representations, backed by ``KernelLinear``.

    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in, self.Rs_out = rs.convention(Rs_in), rs.convention(Rs_out)
    self.check_input_output()
    self.kernel = KernelLinear(self.Rs_in, self.Rs_out)
def __init__(self, Rs, activation, normalization='component'):
    """
    Store ``Rs``, the activation, a ``Norm`` module and a zero-initialised bias.

    The bias has ``rs.mul_dim(self.Rs)`` entries.
    """
    super().__init__()
    self.Rs = rs.convention(Rs)
    self.activation = activation
    self.norm = Norm(self.Rs, normalization)
    bias_size = rs.mul_dim(self.Rs)
    self.bias = torch.nn.Parameter(torch.zeros(bias_size))
def __init__(self, signal, mul, lmax, p_val=0, p_arg=0):
    """
    Wrap ``mul`` channels of spherical-harmonic coefficients of orders 0..``lmax``.

    f: s2 x r -> R^N

    Rotations
    [D(g) f](x) = f(g^{-1} x)

    Parity
    [P f](x) = p_val f(p_arg x)

    f(x) = sum F^l . Y^l(x)

    This class contains the F^l

    Rotations
    [D(g) f](x) = sum [D^l(g) F^l] . Y^l(x)
    (using equiv. of Y and orthogonality of D)

    Parity
    [P f](x) = sum [p_val p_arg^l F^l] . Y^l(x)
    (using parity of Y)

    :raises ValueError: if ``signal.shape[-1] != mul * (lmax + 1) ** 2``
    """
    expected = mul * (lmax + 1) ** 2
    if signal.shape[-1] != expected:
        raise ValueError(
            "Last tensor dimension and Rs do not have same dimension.")

    self.signal = signal
    self.lmax = lmax
    self.mul = mul
    irreps = []
    for order in range(lmax + 1):
        irreps.append((mul, order, p_val * p_arg ** order))
    self.Rs = rs.convention(irreps)
    self.radial_model = None
def __init__(self, signal: torch.Tensor, p_val: int = 0, p_arg: int = 0):
    """
    Wrap spherical-harmonic coefficients; ``lmax`` is inferred from the last dimension.

    f: s2 -> R

    Rotations
    [D(g) f](x) = f(g^{-1} x)

    Parity
    [P f](x) = p_val f(p_arg x)

    f(x) = sum F^l . Y^l(x)

    This class contains the F^l

    Rotations
    [D(g) f](x) = sum [D^l(g) F^l] . Y^l(x)
    (using equiv. of Y and orthogonality of D)

    Parity
    [P f](x) = sum [p_val p_arg^l F^l] . Y^l(x)
    (using parity of Y)

    :raises ValueError: if ``signal.shape[-1]`` is not a perfect square
    """
    n_coeffs = signal.shape[-1]
    lmax = round(math.sqrt(n_coeffs) - 1)
    if (lmax + 1) ** 2 != n_coeffs:
        raise ValueError(
            "Last tensor dimension and Rs do not have same dimension.")

    self.signal = signal
    self.lmax = lmax
    self.Rs = rs.convention([(1, order, p_val * p_arg ** order) for order in range(lmax + 1)])
    self.p_val = p_val
    self.p_arg = p_arg
def WeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, normalization='component', own_weight=True):
    """
    Build a ``CustomWeightedTensorProduct`` with a 'uvw' path for every allowed triple.

    A triple (i_1, i_2, i_out) is allowed when |l_1 - l_2| <= l_out <= l_1 + l_2
    and p_1 * p_2 == p_out.
    """
    Rs_in1 = rs.convention(Rs_in1)
    Rs_in2 = rs.convention(Rs_in2)
    Rs_out = rs.convention(Rs_out)

    instr = []
    for i_1, (_, l_1, p_1) in enumerate(Rs_in1):
        for i_2, (_, l_2, p_2) in enumerate(Rs_in2):
            for i_out, (_, l_out, p_out) in enumerate(Rs_out):
                if abs(l_1 - l_2) <= l_out <= l_1 + l_2 and p_1 * p_2 == p_out:
                    instr.append((i_1, i_2, i_out, 'uvw'))

    return CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr, normalization, own_weight)
def __init__(self, tensor, Rs):
    """
    Pair ``tensor`` with its representation ``Rs``.

    :raises ValueError: if ``tensor.shape[-1] != rs.dim(Rs)``
    """
    Rs = rs.convention(Rs)
    if tensor.shape[-1] == rs.dim(Rs):
        self.tensor = tensor
        self.Rs = Rs
        return
    raise ValueError(
        "Last tensor dimension and Rs do not have same dimension.")
def __init__(self, Rs_in, Rs_out, RadialModel, r, r_eps=0, selection_rule=o3.selection_rule_in_out_sh, sh=rsh.spherical_harmonics_xyz, normalization='component'):
    """
    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param tensor r: [..., 3]
    :param float r_eps: distance considered as zero
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param sh: spherical harmonics function of signature ([l_filter], xyz[..., 3]) -> Y[m, ...]
    :param normalization: either 'norm' or 'component'
    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    self.check_input_output(selection_rule)

    *self.size, xyz = r.size()
    assert xyz == 3
    r = r.reshape(-1, 3)  # [batch, space]
    # NOTE(review): plain attribute, not a registered buffer, so Module.to()/.cuda()
    # will not move it and it is not part of state_dict.
    self.radii = r.norm(2, dim=1)  # [batch]
    self.r_eps = r_eps

    Rs_f, Q = kernel_geometric(self.Rs_in, self.Rs_out, selection_rule, normalization)  # Q: [out, in, Y, R]
    # Harmonics are evaluated only at the points farther than r_eps from the origin.
    Y = sh([l for _, l, _ in Rs_f], r[self.radii > self.r_eps])  # [batch, l_filter * m_filter]
    # Contract the geometric tensor with the fixed harmonics once, at construction.
    Q = torch.einsum('ijyw,zy->zijw', Q, Y)
    self.register_buffer('Q', Q)  # [batch', out, in, R], batch' = points with radius > r_eps

    self.R = RadialModel(rs.mul_dim(Rs_f))

    # A separate linear kernel is built only when some points lie within r_eps of the origin.
    if (self.radii <= self.r_eps).any():
        self.linear = KernelLinear(self.Rs_in, self.Rs_out)
    else:
        self.linear = None
def __init__(self, Rs_in, mul, Rs_out, lmax, layers=3, max_radius=1.0, number_of_basis=3, radial_layers=3, kernel=Kernel, convolution=Convolution, min_radius=0.0):
    """
    Stack of ``layers`` gated convolution blocks followed by a convolution to ``Rs_out``.

    Kernels use a ``GaussianRadialModel`` (h=100, swish activation) and
    ``o3.selection_rule_in_out_sh`` truncated at ``lmax``. In each block, scalars
    (l = 0) are activated with ``swish`` (even) or ``tanh`` (odd). Non-scalars of orders
    1..``lmax`` are gated by scalars: even gates with ``sigmoid`` when an even scalar
    path exists, otherwise odd gates with ``tanh``. The last kernel is built with
    ``allow_unused_inputs=True``.
    """
    super().__init__()

    radial = partial(GaussianRadialModel, max_radius=max_radius, number_of_basis=number_of_basis, h=100, L=radial_layers, act=swish, min_radius=min_radius)
    make_kernel = partial(kernel, RadialModel=radial, selection_rule=partial(o3.selection_rule_in_out_sh, lmax=lmax))

    blocks = []
    Rs = rs.convention(Rs_in)
    for _ in range(layers):
        scalars = [(mul, 0, p) for p in (+1, -1) if rs.haslinearpath(Rs, 0, p)]
        act_scalars = [(m, swish if p == 1 else tanh) for m, _, p in scalars]

        nonscalars = [(mul, l, p) for l in range(1, lmax + 1) for p in [+1, -1] if rs.haslinearpath(Rs, l, p)]
        n_gates = rs.mul_dim(nonscalars)

        if rs.haslinearpath(Rs, 0, +1):
            gates, act_gates = [(n_gates, 0, +1)], [(-1, sigmoid)]
        else:
            gates, act_gates = [(n_gates, 0, -1)], [(-1, tanh)]

        act = GatedBlockParity(scalars, act_scalars, gates, act_gates, nonscalars)
        conv = convolution(make_kernel(Rs, act.Rs_in))
        Rs = act.Rs_out

        blocks.append(torch.nn.ModuleList([conv, act]))

    self.layers = torch.nn.ModuleList(blocks)

    last_kernel = partial(make_kernel, allow_unused_inputs=True)
    self.layers.append(convolution(last_kernel(Rs, Rs_out)))
def __init__(self, signal, mul, lmax):
    """
    Wrap ``mul`` channels of coefficients of orders 0..``lmax``.

    ``Rs`` is built from ``(mul, l)`` pairs passed through ``rs.convention``.

    :raises ValueError: if ``signal.shape[-1] != mul * (lmax + 1) ** 2``
    """
    if mul * (lmax + 1) ** 2 != signal.shape[-1]:
        raise ValueError(
            "Last tensor dimension and Rs do not have same dimension.")

    self.signal, self.lmax, self.mul = signal, lmax, mul
    self.Rs = rs.convention([(mul, order) for order in range(lmax + 1)])
    self.radial_model = None
def __init__(self, Rs_in, Rs_out, layer, activation, layers=3):
    """
    Alternate ``layer`` and ``activation`` ``layers`` times, then add a final ``layer`` to ``Rs_out``.

    ``activation(Rs)`` must expose ``Rs_in`` and ``Rs_out``. Each hidden layer maps the
    current representation to the activation's ``Rs_in``, and the activation's ``Rs_out``
    becomes the next representation.
    """
    super().__init__()
    Rs = rs.convention(Rs_in)
    blocks = []
    for _ in range(layers):
        act = activation(Rs)
        lin = layer(Rs, act.Rs_in)
        Rs = act.Rs_out
        blocks.append(torch.nn.ModuleList([lin, act]))

    self.layers = torch.nn.ModuleList(blocks)
    self.layers.append(layer(Rs, Rs_out))
def __init__(self, Rs_in, mul, Rs_out, lmax, size=5, layers=3):
    """
    Stack of ``layers`` gated convolution blocks followed by a convolution to ``Rs_out``.

    Each block is ``Sequential(Convolution, GatedBlockParity)``. All convolutions use
    kernel ``size``, ``lmax``, ``fuzzy_pixels=True`` and ``padding=size // 2``. Scalars
    (l = 0) get ``swish`` (even) or ``tanh`` (odd). Non-scalars of orders 1..``lmax`` are
    gated by even scalars with ``sigmoid`` when an even scalar path exists, otherwise by
    odd scalars with ``tanh``. The final convolution sets ``allow_unused_inputs=True``.

    :param mul: multiplicity of every scalar and non-scalar irrep in the hidden layers
    """
    super().__init__()

    padding = size // 2
    modules = []
    Rs = rs.convention(Rs_in)
    for _ in range(layers):
        scalars = [(mul, 0, p) for p in (+1, -1) if rs.haslinearpath(Rs, 0, p)]
        act_scalars = [(m, swish if p == 1 else tanh) for m, _, p in scalars]

        nonscalars = [(mul, l, p) for l in range(1, lmax + 1) for p in [+1, -1] if rs.haslinearpath(Rs, l, p)]

        if rs.haslinearpath(Rs, 0, +1):
            gates = [(rs.mul_dim(nonscalars), 0, +1)]
            act_gates = [(-1, sigmoid)]
        else:
            gates = [(rs.mul_dim(nonscalars), 0, -1)]
            act_gates = [(-1, tanh)]

        act = GatedBlockParity(scalars, act_scalars, gates, act_gates, nonscalars)
        conv = Convolution(Rs, act.Rs_in, size, lmax=lmax, fuzzy_pixels=True, padding=padding)
        Rs = act.Rs_out

        modules.append(torch.nn.Sequential(conv, act))

    modules.append(
        Convolution(Rs, Rs_out, size, lmax=lmax, fuzzy_pixels=True, padding=padding, allow_unused_inputs=True)
    )
    self.layers = torch.nn.Sequential(*modules)
def __init__(self, Rs_in1: rs.TY_RS_LOOSE, Rs_in2: rs.TY_RS_LOOSE, Rs_out: rs.TY_RS_LOOSE, instr: List[Tuple[int, int, int, str]], normalization: str = 'component', own_weight: bool = True):
    """
    Create a Tensor Product operation that has each of its paths weighted by a parameter.

    `instr` is a list of instructions.
    An instruction is of the form (i_1, i_2, i_out, mode);
    it means "Put `Rs_in1[i_1]` otimes `Rs_in2[i_2]` into `Rs_out[i_out]`".
    `mode` determines the way the multiplicities are treated.
    The default mode should be 'uvw', meaning that all paths are created.

    Supported modes (u, v, w index the multiplicities of in1, in2, out):
      'uvw'  : one weight per (u, v, w)
      'uvu'  : requires mul_1 == mul_out; one weight per (u, v)
      'uvv'  : requires mul_2 == mul_out; one weight per (u, v)
      'uuw'  : requires mul_1 == mul_2; one weight per (u, w)
      'uuu'  : requires mul_1 == mul_2 == mul_out; one weight per u
      'uvuv' : requires mul_1 * mul_2 == mul_out; one weight per (u, v)

    Each output component is divided by sqrt(number of paths feeding it).

    :param normalization: 'component' or 'norm'; selects the scaling of the Wigner 3j coefficients
    :param own_weight: if True, allocate ``self.weight`` with ``self.nweight`` parameters
    """
    super().__init__()

    self.Rs_in1 = rs.convention(Rs_in1)
    self.Rs_in2 = rs.convention(Rs_in2)
    self.Rs_out = rs.convention(Rs_out)

    code = ""
    index_w = 0
    wigners = set()
    # count[k] = number of paths accumulated into output component k
    count = [0 for _ in range(rs.dim(self.Rs_out))]

    instr = sorted(instr)  # for optimization: lets consecutive instructions reuse s1, s2, ss

    last_s1, last_s2, last_ss = None, None, None
    for i_1, i_2, i_out, mode in instr:
        mul_1, l_1, p_1 = self.Rs_in1[i_1]
        mul_2, l_2, p_2 = self.Rs_in2[i_2]
        mul_out, l_out, p_out = self.Rs_out[i_out]
        dim_1 = mul_1 * (2 * l_1 + 1)
        dim_2 = mul_2 * (2 * l_2 + 1)
        dim_out = mul_out * (2 * l_out + 1)
        index_1 = rs.dim(self.Rs_in1[:i_1])
        index_2 = rs.dim(self.Rs_in2[:i_2])
        index_out = rs.dim(self.Rs_out[:i_out])

        assert p_1 * p_2 == p_out
        assert abs(l_1 - l_2) <= l_out <= l_1 + l_2

        if dim_1 == 0 or dim_2 == 0 or dim_out == 0:
            continue

        if last_s1 != i_1:
            code += f" s1 = x1[:, {index_1}:{index_1+dim_1}].reshape(batch, {mul_1}, {2 * l_1 + 1})\n"
            last_s1 = i_1

        if last_s2 != i_2:
            code += f" s2 = x2[:, {index_2}:{index_2+dim_2}].reshape(batch, {mul_2}, {2 * l_2 + 1})\n"
            last_s2 = i_2

        assert mode in ['uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv']

        if last_ss != (i_1, i_2, mode[:2]):
            if mode[:2] == 'uv':
                code += f" ss = ein('zui,zvj->zuvij', s1, s2)\n"
            if mode[:2] == 'uu':
                code += f" ss = ein('zui,zuj->zuij', s1, s2)\n"
            last_ss = (i_1, i_2, mode[:2])

        wigners.add((l_1, l_2, l_out))

        if mode == 'uvw':
            dim_w = mul_1 * mul_2 * mul_out
            code += f" sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2}, {mul_out})\n"
            code += f" out[:, {index_out}:{index_out+dim_out}] += ein('zuvw,ijk,zuvij->zwk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"
            n_paths = mul_1 * mul_2
        elif mode == 'uvu':
            assert mul_1 == mul_out
            dim_w = mul_1 * mul_2
            code += f" sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
            code += f" out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zuk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"
            n_paths = mul_2
        elif mode == 'uvv':
            assert mul_2 == mul_out
            dim_w = mul_1 * mul_2
            code += f" sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
            code += f" out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zvk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"
            n_paths = mul_1
        elif mode == 'uuw':
            assert mul_1 == mul_2
            dim_w = mul_1 * mul_out
            code += f" sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_out})\n"
            code += f" out[:, {index_out}:{index_out+dim_out}] += ein('zuw,ijk,zuij->zwk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"
            n_paths = mul_1
        elif mode == 'uuu':
            assert mul_1 == mul_2 == mul_out
            dim_w = mul_1
            code += f" sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1})\n"
            code += f" out[:, {index_out}:{index_out+dim_out}] += ein('zu,ijk,zuij->zuk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"
            n_paths = 1
        else:  # 'uvuv'
            assert mul_1 * mul_2 == mul_out
            dim_w = mul_1 * mul_2
            code += f" sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {mul_1}, {mul_2})\n"
            code += f" out[:, {index_out}:{index_out+dim_out}] += ein('zuv,ijk,zuvij->zuvk', sw, C{l_1}_{l_2}_{l_out}, ss).reshape(batch, {dim_out})\n"
            n_paths = 1

        for pos in range(index_out, index_out + dim_out):
            count[pos] += n_paths

        index_w += dim_w

    code += "\n"

    # Divide each output component by sqrt(count), one slice per run of equal counts.
    # Guarded so an empty output representation does not raise IndexError on count[0].
    if count:
        ilast = 0
        clast = count[0]
        for i, c in enumerate(count):
            if clast != c:
                if clast > 1:
                    code += f" out[:, {ilast}:{i}].div_({clast ** 0.5})\n"
                clast = c
                ilast = i
        if clast > 1:
            code += f" out[:, {ilast}:].div_({clast ** 0.5})\n"

    wigners = sorted(wigners)
    self.wigners_names = [
        f"C{l_1}_{l_2}_{l_3}" for l_1, l_2, l_3 in wigners
    ]
    args = ", ".join(f"{arg}: torch.Tensor" for arg in self.wigners_names)

    for arg, (l_1, l_2, l_out) in zip(self.wigners_names, wigners):
        wig = o3.wigner_3j(l_1, l_2, l_out)
        # Scale out of place so the tensor returned by o3.wigner_3j is never modified.
        if normalization == 'component':
            wig = wig * (2 * l_out + 1)**0.5
        if normalization == 'norm':
            wig = wig * ((2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5)
        self.register_buffer(arg, wig)

    # NOTE: the leading indentation of the generated lines above must match the
    # body indentation of the _tensor_product_code template they are spliced into.
    x = _tensor_product_code
    x = x.replace("DIM", f"{rs.dim(self.Rs_out)}")
    x = x.replace("ARGS", args)
    x = x.replace("CODE", code)

    self.code = x
    self.main = eval_code(x).main
    self.nweight = index_w
    if own_weight:
        self.weight = torch.nn.Parameter(torch.randn(self.nweight))
def __init__(self, Rs_in1, Rs_in2, Rs_out, selection_rule=o3.selection_rule, normalization='component', groups=1):
    """
    Weighted tensor product of ``Rs_in1`` and ``Rs_in2`` into ``Rs_out`` with grouped multiplicities.

    Every multiplicity is split into ``groups`` consecutive parts; the first ``mul % groups``
    parts get one extra copy. Group ``g`` of ``Rs_in1`` couples only with group ``g`` of
    ``Rs_in2`` and writes only into group ``g`` of ``Rs_out``. An output irrep
    ``(mul_out, l_out, p_out)`` receives a path from ``(l_1, p_1) x (l_2, p_2)`` when
    ``l_out in selection_rule(l_1, p_1, l_2, p_2)`` and ``p_out == p_1 * p_2``. Each output
    component is divided by sqrt(number of paths feeding it).

    :param normalization: 'component' or 'norm'; selects the scaling of the Wigner 3j coefficients
    :param groups: number of groups the multiplicities are split into
    """
    super().__init__()

    self.Rs_in1 = rs.convention(Rs_in1)
    self.Rs_in2 = rs.convention(Rs_in2)
    self.Rs_out = rs.convention(Rs_out)

    def split(mul):
        # Multiplicity of each of the `groups` parts of an irrep of multiplicity `mul`.
        return [mul // groups + (g < mul % groups) for g in range(groups)]

    code = ""
    index_w = 0
    wigners = set()
    # count[k] = number of paths accumulated into output component k
    count = [0 for _ in range(rs.dim(self.Rs_out))]

    index_1 = 0
    for mul_1, l_1, p_1 in self.Rs_in1:
        dim_1 = mul_1 * (2 * l_1 + 1)
        gmul_1s = split(mul_1)  # hoisted: independent of the Rs_in2 loop

        index_2 = 0
        for mul_2, l_2, p_2 in self.Rs_in2:
            dim_2 = mul_2 * (2 * l_2 + 1)
            gmul_2s = split(mul_2)

            for g in range(groups):
                if gmul_1s[g] * gmul_2s[g] == 0:
                    continue

                code += f" s1 = x1[:, {index_1+sum(gmul_1s[:g])*(2*l_1+1)}:{index_1+sum(gmul_1s[:g+1])*(2*l_1+1)}].reshape(batch, {gmul_1s[g]}, {2 * l_1 + 1})\n"
                code += f" s2 = x2[:, {index_2+sum(gmul_2s[:g])*(2*l_2+1)}:{index_2+sum(gmul_2s[:g+1])*(2*l_2+1)}].reshape(batch, {gmul_2s[g]}, {2 * l_2 + 1})\n"
                code += f" ss = ein('zui,zvj->zuvij', s1, s2)\n"

                index_out = 0
                for mul_out, l_out, p_out in self.Rs_out:
                    dim_out = mul_out * (2 * l_out + 1)

                    if l_out in selection_rule(l_1, p_1, l_2, p_2) and p_out == p_1 * p_2:
                        wigners.add((l_out, l_1, l_2))
                        gmul_outs = split(mul_out)

                        # An empty output group contributes no path, but index_out must
                        # still advance past this irrep below. A `continue` here would skip
                        # that and shift the slices of every later output irrep.
                        if gmul_outs[g] > 0:
                            dim_w = gmul_outs[g] * gmul_1s[g] * gmul_2s[g]
                            code += f" sw = w[:, {index_w}:{index_w+dim_w}].reshape(batch, {gmul_outs[g]}, {gmul_1s[g]}, {gmul_2s[g]})\n"
                            start = index_out + sum(gmul_outs[:g]) * (2 * l_out + 1)
                            stop = index_out + sum(gmul_outs[:g + 1]) * (2 * l_out + 1)
                            code += f" out[:, {start}:{stop}] += ein('zwuv,kij,zuvij->zwk', sw, C{l_out}_{l_1}_{l_2}, ss).reshape(batch, {gmul_outs[g]*(2*l_out+1)})\n"
                            code += "\n"
                            for k in range(start, stop):
                                count[k] += gmul_1s[g] * gmul_2s[g]
                            index_w += dim_w

                    index_out += dim_out

            index_2 += dim_2
        index_1 += dim_1

    # Divide each output component by sqrt(count), one slice per run of equal counts.
    # Guarded so an empty output representation does not raise IndexError on count[0].
    if count:
        ilast = 0
        clast = count[0]
        for i, c in enumerate(count):
            if clast != c:
                if clast > 1:
                    code += f" out[:, {ilast}:{i}].div_({clast ** 0.5})\n"
                clast = c
                ilast = i
        if clast > 1:
            code += f" out[:, {ilast}:].div_({clast ** 0.5})\n"

    wigners = sorted(wigners)
    self.wigners_names = [
        f"C{l_out}_{l_1}_{l_2}" for l_out, l_1, l_2 in wigners
    ]
    args = ", ".join(f"{arg}: torch.Tensor" for arg in self.wigners_names)

    for arg, (l_out, l_1, l_2) in zip(self.wigners_names, wigners):
        C = o3.wigner_3j(l_out, l_1, l_2)
        # Scale out of place so the tensor returned by o3.wigner_3j is never modified.
        if normalization == 'component':
            C = C * (2 * l_out + 1)**0.5
        if normalization == 'norm':
            C = C * ((2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5)
        self.register_buffer(arg, C)

    # NOTE: the leading indentation of the generated lines above must match the
    # body indentation of the _tensor_product_code template they are spliced into.
    x = _tensor_product_code
    x = x.replace("DIM", f"{rs.dim(self.Rs_out)}")
    x = x.replace("ARGS", args)
    x = x.replace("CODE", code)

    self.main = eval_code(x).main
    self.nweight = index_w
def __init__(self, Rs, p=0.5):
    """Store the representation ``Rs`` (normalized by ``rs.convention``) and the parameter ``p``."""
    super().__init__()
    self.Rs, self.p = rs.convention(Rs), p
def test_convention():
    """rs.convention accepts bare orders and (mul, l) pairs and fills in mul=1 / parity 0."""
    cases = [
        ([0], [(1, 0, 0)]),
        ([0, (2, 0)], [(1, 0, 0), (2, 0, 0)]),
    ]
    for loose, expected in cases:
        assert rs.convention(loose) == expected