def test_gwtp():
    """Smoke test for GroupedWeightedTensorProduct: build, print, run once."""
    spec = "1e + 2e + 3x3o"
    irreps_in1, irreps_in2, irreps_out = (o3.Irreps(spec) for _ in range(3))
    tp = GroupedWeightedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    print(tp)
    tp(torch.randn(irreps_in1.dim), torch.randn(irreps_in2.dim))
def GroupedWeightedTensorProduct(irreps_in1, irreps_in2, irreps_out, groups=math.inf, normalization='component', internal_weights=True, shared_weights=True):
    """WeightedTensorProduct whose first input and output are split into groups.

    Each output group only receives paths from the input-1 slots of the same
    group (the second input is shared by all groups). ``groups`` is capped by
    the smallest multiplicity appearing in ``irreps_in1`` and ``irreps_out``.
    """
    irreps_in1 = o3.Irreps(irreps_in1)
    irreps_in2 = o3.Irreps(irreps_in2)
    irreps_out = o3.Irreps(irreps_out)

    # cap the number of groups so every irrep can be split into that many parts
    groups = min(groups, min(mul for mul, _ in irreps_in1), min(mul for mul, _ in irreps_out))

    def split_into_groups(irreps):
        # distribute each multiplicity as evenly as possible over the groups
        # (the first `mul % groups` groups get one extra copy)
        return [
            (mul // groups + (1 if g < mul % groups else 0), (l, p))
            for mul, (l, p) in irreps
            for g in range(groups)
        ]

    irreps_in1 = split_into_groups(irreps_in1)
    irreps_out = split_into_groups(irreps_out)

    in1 = [(mul, ir, 1.0) for mul, ir in irreps_in1]
    in2 = [(mul, ir, 1.0) for mul, ir in irreps_in2]
    out = [(mul, ir, 1.0) for mul, ir in irreps_out]

    instr = []
    for i_1, (_, (l_1, p_1)) in enumerate(irreps_in1):
        for i_2, (_, (l_2, p_2)) in enumerate(irreps_in2):
            for i_out, (_, (l_out, p_out)) in enumerate(irreps_out):
                # angular-momentum/parity selection rules + same-group constraint
                if abs(l_1 - l_2) <= l_out <= l_1 + l_2 and p_1 * p_2 == p_out and i_1 % groups == i_out % groups:
                    instr.append((i_1, i_2, i_out, 'uvw', True, 1.0))

    return WeightedTensorProduct(in1, in2, out, instr, normalization, internal_weights, shared_weights)
def reduce_tensor(formula, eps=1e-9, has_parity=True, **kw_irreps):
    """reduce a tensor with symmetries into irreducible representations

    Usage
    irreps, Q = rs.reduce_tensor('ijkl=jikl=ikjl=ijlk', i=[(1, 1)])
    irreps = 0,2,4
    Q = tensor of shape [15, 3, 3, 3, 3]
    """
    gr = group.O3() if has_parity else group.SO3()

    kw_representations = {}
    for i in kw_irreps:
        # BUG FIX: bind the current value through a default argument.
        # A plain `lambda g: kw_irreps[i](...)` captures `i` *late*, so once
        # the loop finished every representation would resolve to the last
        # keyword — wrong whenever more than one index is given.
        if callable(kw_irreps[i]):
            kw_representations[i] = lambda g, f=kw_irreps[i]: f(*g)
        else:
            # also hoists the Irreps construction out of the closure
            kw_representations[i] = lambda g, irs=o3.Irreps(kw_irreps[i]): irs.D(*g)

    irreps, Q = group.reduce_tensor(gr, formula, eps, **kw_representations)

    if has_parity:
        irreps = o3.Irreps(irreps)
    else:
        # without parity, tag every irrep with trivial parity +1
        irreps = o3.Irreps([(mul, l, 1) for mul, l in irreps])

    return irreps, Q
def __init__(self, irreps_in, irreps_out):
    """Identity operation; input and output irreps must simplify to the same thing."""
    super().__init__()
    self.irreps_in = o3.Irreps(irreps_in).simplify()
    self.irreps_out = o3.Irreps(irreps_out).simplify()
    assert self.irreps_in == self.irreps_out
    # every output component is produced, so the mask is all ones
    ones = [torch.ones(mul * (2 * l + 1)) for mul, (l, _p) in self.irreps_out]
    self.register_buffer('output_mask', torch.cat(ones))
def test():
    """Run a WeightedTensorProduct with a single hand-written 'uvw' path."""
    spec = "1e + 2e + 3x3o"
    irreps_in1 = o3.Irreps(spec)
    irreps_in2 = o3.Irreps(spec)
    irreps_out = o3.Irreps(spec)

    def with_unit_variance(irreps):
        # (multiplicity, (l, p), variance) triples expected by the product
        return [(mul, ir, 1.0) for mul, ir in irreps]

    instr = [(1, 1, 1, 'uvw', True, 1.0)]  # only 2e x 2e -> 2e
    m = WeightedTensorProduct(
        with_unit_variance(irreps_in1),
        with_unit_variance(irreps_in2),
        with_unit_variance(irreps_out),
        instr,
    )
    m(torch.randn(irreps_in1.dim), torch.randn(irreps_in2.dim))
def __init__(self, irreps_in, irreps_out, irreps_sh, rad_features):
    """Equivariant convolution layer (message passing, 'add' aggregation).

    Builds a WeightedTensorProduct between the node features and the
    spherical harmonics of the edges, whose weights are produced per-edge
    by radial networks of width ``rad_features``.
    """
    super().__init__(aggr='add')
    self.irreps_in = irreps_in.simplify()
    self.irreps_out = irreps_out.simplify()
    self.irreps_sh = irreps_sh.simplify()

    # self-interaction (skip connection) and pre-mixing linear layer
    self.si = Linear(self.irreps_in, self.irreps_out)
    self.lin1 = Linear(self.irreps_in, self.irreps_in)

    # Enumerate all tensor-product paths input x sh whose output (l, p)
    # appears in the requested output irreps.  The output slot list `irreps`
    # is built on the fly; identical (mul, l, p) slots are reused, so the
    # order of appends fixes the instruction indices.
    instr = []
    irreps = []
    for i_1, (mul_1, (l_1, p_1)) in enumerate(self.irreps_in):
        for i_2, (_, (l_2, p_2)) in enumerate(self.irreps_sh):
            for l_out in range(abs(l_1 - l_2), l_1 + l_2 + 1):
                p_out = p_1 * p_2
                # keep only channels present in the target representation
                if (l_out, p_out) in [(l, p) for _, (l, p) in self.irreps_out]:
                    r = (mul_1, l_out, p_out)
                    if r in irreps:
                        i_out = irreps.index(r)
                    else:
                        i_out = len(irreps)
                        irreps.append(r)
                    instr += [(i_1, i_2, i_out, 'uvu', True, 1.0)]
    irreps = o3.Irreps(irreps)

    # unit-variance triples expected by WeightedTensorProduct
    in1 = [(mul, ir, 1.0) for mul, ir in self.irreps_in]
    in2 = [(mul, ir, 1.0) for mul, ir in self.irreps_sh]
    out = [(mul, ir, 1.0) for mul, ir in irreps]
    # weights are supplied externally, per edge, by the radial networks below
    self.tp = WeightedTensorProduct(in1, in2, out, instr, internal_weights=False, shared_weights=False)

    # one fully-connected radial head per weight tensor of the tensor product
    self.ws = torch.nn.ModuleList([
        FC((rad_features, prod(shape)), variance_out=1 / prod(shape))
        for shape in self.tp.weight_shapes
    ])
    # mix the tensor-product output back into the requested irreps
    self.lin2 = Linear(irreps, self.irreps_out)
def __init__(self, irreps, acts):
    '''
    Can be used only with scalar fields

    :param irreps: input representation (scalars only, l = 0)
    :param acts: list of tuple (multiplicity, activation);
                 a multiplicity of -1 means "all remaining channels"
    '''
    super().__init__()

    irreps = irreps.simplify()

    n1 = sum(mul for mul, _ in irreps)
    n2 = sum(mul for mul, _ in acts if mul > 0)

    # normalize the second moment
    acts = [(mul, normalize2mom(act)) for mul, act in acts]

    # expand the -1 wildcard into the remaining multiplicity
    for i, (mul, act) in enumerate(acts):
        if mul == -1:
            acts[i] = (n1 - n2, act)
            assert n1 - n2 >= 0
    assert n1 == sum(mul for mul, _ in acts)

    # Split entries of `irreps` and `acts` so the two lists align
    # one-to-one in multiplicity.  Both lists grow while iterating; `i`
    # advances one aligned pair at a time.
    irreps = list(irreps)
    i = 0
    while i < len(irreps):
        mul_r, (l, p_r) = irreps[i]
        mul_a, act = acts[i]
        if mul_r < mul_a:
            acts[i] = (mul_r, act)
            acts.insert(i + 1, (mul_a - mul_r, act))
        if mul_a < mul_r:
            irreps[i] = (mul_a, (l, p_r))
            irreps.insert(i + 1, (mul_r - mul_a, (l, p_r)))
        i += 1

    # Detect each activation's parity numerically: sample it on [0, 10]
    # and compare f(x) with f(-x) (even -> +1, odd -> -1, neither -> 0).
    x = torch.linspace(0, 10, 256)
    irreps_out = []
    for (mul, (l, p_in)), (mul_a, act) in zip(irreps, acts):
        assert mul == mul_a
        a1, a2 = act(x), act(-x)
        if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
            p_act = 1
        elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
            p_act = -1
        else:
            p_act = 0
        # odd inputs pick up the activation's parity; even inputs keep theirs
        p = p_act if p_in == -1 else p_in
        irreps_out.append((mul, (0, p)))
        if p_in != 0 and p == 0:
            # a definite input parity went through a parity-breaking activation
            raise ValueError("warning! the parity is violated")

    self.irreps_out = o3.Irreps(irreps_out).simplify()
    self.acts = acts
def ElementwiseTensorProduct(irreps_in1, irreps_in2, normalization='component'):
    """Channel-wise tensor product: the u-th irrep of input 1 is multiplied
    with the u-th irrep of input 2 ('uuu' paths, no trainable weights).

    Both inputs must carry the same total number of irreps.
    """
    irreps_in1 = o3.Irreps(irreps_in1).simplify()
    irreps_in2 = o3.Irreps(irreps_in2).simplify()
    assert irreps_in1.num_irreps == irreps_in2.num_irreps

    # Split entries so both lists align one-to-one in multiplicity.
    # Both lists grow while iterating; `i` advances one aligned pair at a time.
    irreps_in1 = list(irreps_in1)
    irreps_in2 = list(irreps_in2)

    i = 0
    while i < len(irreps_in1):
        mul_1, (l_1, p_1) = irreps_in1[i]
        mul_2, (l_2, p_2) = irreps_in2[i]

        if mul_1 < mul_2:
            irreps_in2[i] = (mul_1, (l_2, p_2))
            irreps_in2.insert(i + 1, (mul_2 - mul_1, (l_2, p_2)))

        if mul_2 < mul_1:
            irreps_in1[i] = (mul_2, (l_1, p_1))
            irreps_in1.insert(i + 1, (mul_1 - mul_2, (l_1, p_1)))
        i += 1

    # one 'uuu' instruction per aligned pair and per allowed output l
    irreps_out = []
    instr = []
    for i, ((mul, (l_1, p_1)), (mul_2, (l_2, p_2))) in enumerate(zip(irreps_in1, irreps_in2)):
        assert mul == mul_2
        for l in list(range(abs(l_1 - l_2), l_1 + l_2 + 1)):
            i_out = len(irreps_out)
            irreps_out.append((mul, (l, p_1 * p_2)))
            instr += [(i, i, i_out, 'uuu', False, 1.0)]

    # unit-variance triples expected by WeightedTensorProduct
    in1 = [(mul, ir, 1.0) for mul, ir in irreps_in1]
    in2 = [(mul, ir, 1.0) for mul, ir in irreps_in2]
    out = [(mul, ir, 1.0) for mul, ir in irreps_out]

    return WeightedTensorProduct(in1, in2, out, instr, normalization, internal_weights=False)
def __init__(self, muls=(128, 12, 0), lmax=1, num_layers=3, cutoff=10.0, rad_gaussians=50, rad_hs=(512, 512), num_neighbors=20, readout='add', mean=None, std=None, scale=None, atomref=None):
    """Equivariant network for atomic systems.

    :param muls: multiplicities per l (muls[l] channels of order l)
    :param lmax: maximum spherical-harmonic order of the edge attributes
    :param num_layers: number of (conv, gate) blocks
    :param cutoff: radial cutoff of the Gaussian basis
    :param rad_gaussians: number of Gaussian radial basis functions
    :param rad_hs: hidden widths of the radial fully-connected network
    :param readout: pooling over atoms ('add'/'sum'/'mean')
    :param mean, std, scale, atomref: output normalization / per-species reference
    """
    super().__init__()

    assert readout in ['add', 'sum', 'mean']
    self.readout = readout
    self.cutoff = cutoff
    self.mean = mean
    self.std = std
    self.scale = scale
    self.num_neighbors = num_neighbors

    # frozen one-hot-like species embedding (up to 100 atomic numbers)
    self.embedding = Embedding(100, muls[0])
    self.embedding.weight.requires_grad = False
    self.irreps_in = o3.Irreps([(muls[0], 0, 1)])

    # radial profile: Gaussian basis followed by an FC net (with output activation)
    self.radial = torch.nn.Sequential(
        GaussianBasis(rad_gaussians, cutoff),
        FC((rad_gaussians, ) + rad_hs, swish, variance_in=1 / rad_gaussians, out_act=True)
    )
    self.irreps_sh = o3.Irreps([(1, l, (-1)**l) for l in range(lmax + 1)])  # spherical harmonics representation

    # stack of (convolution, gated nonlinearity) blocks
    irreps = self.irreps_in
    modules = []
    for _ in range(num_layers):
        act = make_gated_block(irreps, muls, self.irreps_sh)
        conv = Conv(irreps, act.irreps_in, self.irreps_sh, rad_hs[-1])
        irreps = act.irreps_out.simplify()
        modules += [torch.nn.ModuleList([conv, act])]
    self.layers = torch.nn.ModuleList(modules)

    # final convolution straight to the scalar outputs
    self.irreps_out = o3.Irreps("0e + 0o")
    self.layers.append(Conv(irreps, self.irreps_out, self.irreps_sh, rad_hs[-1]))

    # optional frozen per-species reference values added to the prediction
    self.register_buffer('initial_atomref', atomref)
    self.atomref = None
    if atomref is not None:
        self.atomref = Embedding(100, 1)
        self.atomref.weight.data.copy_(atomref)
        self.atomref.weight.requires_grad = False
def test_creation():
    """Exercise the many accepted constructor forms of Irrep and Irreps."""
    # Irrep: from (l, p), from string, from another Irrep
    o3.Irrep(3, 1)
    irrep3e = o3.Irrep("3e")
    o3.Irrep(irrep3e)
    assert o3.Irrep('10o') == o3.Irrep(10, -1)
    assert o3.Irrep("1y") == o3.Irrep("1o")

    # Irreps: from an Irrep, from another Irreps, from tuples and strings
    wrapped = o3.Irreps(irrep3e)
    o3.Irreps(wrapped)
    o3.Irreps([(32, (4, -1))])
    o3.Irreps("11e")
    assert o3.Irreps("16x1e + 32 x 2o") == o3.Irreps([(16, 1, 1), (32, 2, -1)])
    o3.Irreps(["1e", '2o'])
    o3.Irreps([(16, "3e"), '1e'])
    o3.Irreps([(16, "3e"), '1e', (256, 1, -1)])
def __init__(self, *irreps_outs):
    """Split a single input into several outputs; the input layout is the
    sorted concatenation of all the (simplified) output irreps."""
    super().__init__()
    self.irreps_outs = tuple(irreps.simplify() for irreps in irreps_outs)
    # gather every (mul, (l, p)) entry and order them canonically by (l, p)
    entries = [mul_ir for irreps in self.irreps_outs for mul_ir in irreps]
    entries.sort(key=lambda mul_ir: mul_ir[1])
    self.irreps_in = o3.Irreps(entries).simplify()
def FullyConnectedWeightedTensorProduct(irreps_in1, irreps_in2, irreps_out, normalization='component', internal_weights=True, shared_weights=True):
    """WeightedTensorProduct with one trainable 'uvw' path for every
    (in1, in2, out) triple allowed by the selection rules."""
    irreps_in1 = o3.Irreps(irreps_in1).simplify()
    irreps_in2 = o3.Irreps(irreps_in2).simplify()
    irreps_out = o3.Irreps(irreps_out).simplify()

    def with_unit_variance(irreps):
        # (multiplicity, (l, p), variance) triples expected by the product
        return [(mul, ir, 1.0) for mul, ir in irreps]

    instr = []
    for i_1, (_, (l_1, p_1)) in enumerate(irreps_in1):
        for i_2, (_, (l_2, p_2)) in enumerate(irreps_in2):
            for i_out, (_, (l_out, p_out)) in enumerate(irreps_out):
                # angular-momentum triangle inequality and parity rule
                if abs(l_1 - l_2) <= l_out <= l_1 + l_2 and p_1 * p_2 == p_out:
                    instr.append((i_1, i_2, i_out, 'uvw', True, 1.0))

    return WeightedTensorProduct(
        with_unit_variance(irreps_in1),
        with_unit_variance(irreps_in2),
        with_unit_variance(irreps_out),
        instr,
        normalization,
        internal_weights,
        shared_weights,
    )
def test_reduce_tensor_equivariance():
    """Q must intertwine the reduced representation with the tensor one."""
    torch.set_default_dtype(torch.float64)
    ir = o3.Irreps('1e')
    irreps, Q = o3.reduce_tensor('ijkl=jikl=klij', i=ir)

    angles = o3.rand_angles()
    R = ir.D(*angles)
    D = irreps.D(*angles)

    # rotate all four tensor indices of Q vs. act on the reduced index
    lhs = torch.einsum('qmnop,mi,nj,ok,pl->qijkl', Q, R, R, R, R)
    rhs = torch.einsum('qa,aijkl->qijkl', D, Q)
    assert (lhs - rhs).abs().max() < 1e-10
def make_gated_block(irreps_in, muls, irreps_sh):
    """Build a Gate nonlinearity for the given layer (assumes a lot about the layout)."""
    # every (l, parity) reachable by a tensor product of the input with the
    # spherical harmonics
    reachable = {
        (l, p_in * p_sh)
        for _, (l_in, p_in) in irreps_in.simplify()
        for _, (l_sh, p_sh) in irreps_sh
        for l in range(abs(l_in - l_sh), l_in + l_sh + 1)
    }

    # scalar channels pass through an activation directly
    scalars = o3.Irreps([(muls[0], 0, p) for p in (1, -1) if (0, p) in reachable])
    act_scalars = [(mul, swish if p == 1 else torch.tanh) for mul, (_, p) in scalars]

    # non-scalar channels (natural parity (-1)^l times an optional sign flip)
    nonscalars = o3.Irreps([
        (muls[l], l, p * (-1)**l)
        for l in range(1, len(muls))
        for p in (1, -1)
        if (l, p * (-1)**l) in reachable
    ])

    # prefer even gate scalars with sigmoid; fall back to odd ones with tanh
    if (0, +1) in reachable:
        gates = o3.Irreps([(nonscalars.num_irreps, 0, +1)])
        act_gates = [(-1, torch.sigmoid)]
    else:
        gates = o3.Irreps([(nonscalars.num_irreps, 0, -1)])
        act_gates = [(-1, torch.tanh)]

    return Gate(scalars, act_scalars, gates, act_gates, nonscalars)
def __init__(self) -> None:
    """Two fully-connected tensor products chained on spherical harmonics input."""
    super().__init__()
    self.sh = [0, 1, 2, 3]
    sh_irreps = o3.Irreps([(1, l, (-1)**l) for l in self.sh])
    mid_irreps = "64x0e + 24x1e + 24x1o + 16x2e + 16x2o"
    out_irreps = "0o + 6x0e"
    # sh x sh -> mid
    self.tp1 = FullyConnectedWeightedTensorProduct(
        irreps_in1=sh_irreps,
        irreps_in2=sh_irreps,
        irreps_out=mid_irreps,
    )
    # mid x sh -> out
    self.tp2 = FullyConnectedWeightedTensorProduct(
        irreps_in1=mid_irreps,
        irreps_in2=sh_irreps,
        irreps_out=out_irreps,
    )
def test_sh_equivariance2():
    """spherical_harmonics(R x) == spherical_harmonics(x) D(R)^T

    test
    - rot
    - rep
    - spherical_harmonics
    """
    torch.set_default_dtype(torch.float64)
    irreps = o3.Irreps("0e + 1o + 2e + 3o + 4e")

    angles = o3.rand_angles()
    rotation = o3.rot(*angles)
    D = irreps.D(*angles)

    points = torch.randn(10, 3)
    lhs = o3.spherical_harmonics(irreps, points @ rotation.T)
    rhs = o3.spherical_harmonics(irreps, points) @ D.T
    assert (lhs - rhs).abs().max() < 1e-10
def test_id():
    """Smoke test for Identity: build, print, run once."""
    spec = "1e + 2e + 3x3o"
    ident = Identity(o3.Irreps(spec), o3.Irreps(spec))
    print(ident)
    ident(torch.randn(o3.Irreps(spec).dim))
def test_slice():
    """Slicing an Irreps must yield an Irreps, not a plain sequence."""
    irreps = o3.Irreps("16x1e + 3e + 2e + 5o")
    tail = irreps[2:]
    assert isinstance(tail, o3.Irreps)
def test_cat():
    """Concatenation with + keeps all entries and sums the multiplicities."""
    left = o3.Irreps("4x1e + 6x2e + 12x2o")
    right = o3.Irreps("1x1e + 2x2e + 12x2o")
    combined = left + right
    assert len(combined) == 6
    assert combined.num_irreps == 4 + 6 + 12 + 1 + 2 + 12
def test_fail1():
    # NOTE(review): the name suggests this construction is *expected* to fail
    # ((32, 1) has no parity information), yet nothing here asserts that an
    # exception is raised — presumably a `pytest.raises` block or an
    # `@pytest.mark.xfail` decorator was intended (or lives outside this
    # view). TODO confirm against the test suite.
    o3.Irreps([(32, 1)])
def __init__(
    self,
    in1: List[Tuple[int, Any, float]],
    in2: List[Tuple[int, Any, float]],
    out: List[Tuple[int, Any, float]],
    instr: List[Tuple[int, int, int, str, bool, float]],
    normalization: str = 'component',
    internal_weights: bool = True,
    shared_weights: bool = True,
    _specialized_code=True,
):
    """Tensor Product with parametrizable paths

    Generates (and jit-compiles) python code computing the requested paths.

    Parameters
    ----------
    in1
        List of inputs (multiplicity, (l, p), variance).
    in2
        List of inputs (multiplicity, (l, p), variance).
    out
        List of outputs (multiplicity, (l, p), variance).
    instr
        List of instructions (i_1, i_2, i_out, mode, train, path_weight)
        it means: Put `in1[i_1]` otimes `in2[i_2]` into `out[i_out]`

        - mode: determines the way the multiplicities are treated, "uvw" is fully connected
        - train: is this path trained?
        - path weight: how much this path should contribute to the output
    """
    super().__init__()

    assert normalization in ['component', 'norm'], normalization

    self.irreps_in1 = o3.Irreps([(mul, ir) for mul, ir, _var in in1])
    self.irreps_in2 = o3.Irreps([(mul, ir) for mul, ir, _var in in2])
    self.irreps_out = o3.Irreps([(mul, ir) for mul, ir, _var in out])

    in1_var = [var for _, _, var in in1]
    in2_var = [var for _, _, var in in2]
    out_var = [var for _, _, var in out]

    self.shared_weights = shared_weights
    # einsum index for the batch dimension of the weights when not shared
    z = '' if self.shared_weights else 'z'

    code = f"""
from typing import List
import torch

@torch.jit.script
def main(x1: torch.Tensor, x2: torch.Tensor, ws: List[torch.Tensor], w3j: List[torch.Tensor]) -> torch.Tensor:
    batch = x1.shape[0]
    out = x1.new_zeros((batch, {self.irreps_out.dim}))
    ein = torch.einsum
"""

    wshapes = []
    wigners = []

    # slice the flat inputs into one (batch, mul, 2l+1) tensor per irrep
    for i_1, (mul_1, (l_1, p_1)) in enumerate(self.irreps_in1):
        index_1 = self.irreps_in1[:i_1].dim
        dim_1 = mul_1 * (2 * l_1 + 1)
        code += f"    x1_{i_1} = x1[:, {index_1}:{index_1+dim_1}].reshape(batch, {mul_1}, {2 * l_1 + 1})\n"
    code += "\n"

    for i_2, (mul_2, (l_2, p_2)) in enumerate(self.irreps_in2):
        index_2 = self.irreps_in2[:i_2].dim
        dim_2 = mul_2 * (2 * l_2 + 1)
        code += f"    x2_{i_2} = x2[:, {index_2}:{index_2+dim_2}].reshape(batch, {mul_2}, {2 * l_2 + 1})\n"
    code += "\n"

    last_ss = None

    for i_1, i_2, i_out, mode, weight, path_weight in instr:
        mul_1, (l_1, p_1) = self.irreps_in1[i_1]
        mul_2, (l_2, p_2) = self.irreps_in2[i_2]
        mul_out, (l_out, p_out) = self.irreps_out[i_out]
        dim_1 = mul_1 * (2 * l_1 + 1)
        dim_2 = mul_2 * (2 * l_2 + 1)
        dim_out = mul_out * (2 * l_out + 1)
        index_1 = self.irreps_in1[:i_1].dim
        index_2 = self.irreps_in2[:i_2].dim
        index_out = self.irreps_out[:i_out].dim

        assert p_1 * p_2 == p_out
        assert abs(l_1 - l_2) <= l_out <= l_1 + l_2

        if dim_1 == 0 or dim_2 == 0 or dim_out == 0:
            continue

        # variance-preserving normalization, shared by all paths into i_out
        alpha = out_var[i_out] / sum(path_weight_ * in1_var[i_1_] * in2_var[i_2_] for i_1_, i_2_, i_out_, _, _, path_weight_ in instr if i_out_ == i_out)

        code += (
            f"    with torch.autograd.profiler.record_function("
            f"'{self.irreps_in1[i_1:i_1+1]} x {self.irreps_in2[i_2:i_2+1]} "
            f"= {self.irreps_out[i_out:i_out+1]} {mode} {weight}'):\n"
        )
        code += f"        s1 = x1_{i_1}\n"
        code += f"        s2 = x2_{i_2}\n"

        assert mode in ['uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv']

        c = sqrt(alpha * path_weight / {
            'uvw': (mul_1 * mul_2),
            'uvu': mul_2,
            'uvv': mul_1,
            'uuw': mul_1,
            'uuu': 1,
            'uvuv': 1,
        }[mode])

        index_w = len(wshapes)
        if weight:
            wshapes.append({
                'uvw': (mul_1, mul_2, mul_out),
                'uvu': (mul_1, mul_2),
                'uvv': (mul_1, mul_2),
                'uuw': (mul_1, mul_out),
                'uuu': (mul_1,),
                'uvuv': (mul_1, mul_2),
            }[mode])

        if _specialized_code:
            # optimized code for special cases:
            # 0 x 0 = 0
            # 0 x L = L
            # L x 0 = L
            # L x L = 0
            # 1 x 1 = 1
            if (l_1, l_2, l_out) == (0, 0, 0) and mode in ['uvw', 'uvu'] and normalization in ['component', 'norm'] and weight:
                code += f"        s1 = s1.reshape(batch, {mul_1})\n"
                code += f"        s2 = s2.reshape(batch, {mul_2})\n"
                if mode == 'uvw':
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uvw,zu,zv->zw', ws[{index_w}], s1, s2)\n\n"
                if mode == 'uvu':
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,zu,zv->zu', ws[{index_w}], s1, s2)\n\n"
                continue

            if l_1 == 0 and l_2 == l_out and mode in ['uvw', 'uvu'] and normalization == 'component' and weight:
                code += f"        s1 = s1.reshape(batch, {mul_1})\n"
                if mode == 'uvw':
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uvw,zu,zvi->zwi', ws[{index_w}], s1, s2).reshape(batch, {dim_out})\n\n"
                if mode == 'uvu':
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,zu,zvi->zui', ws[{index_w}], s1, s2).reshape(batch, {dim_out})\n\n"
                continue

            if l_2 == 0 and l_1 == l_out and mode in ['uvw', 'uvu'] and normalization == 'component' and weight:
                code += f"        s2 = s2.reshape(batch, {mul_2})\n"
                if mode == 'uvw':
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uvw,zui,zv->zwi', ws[{index_w}], s1, s2).reshape(batch, {dim_out})\n\n"
                if mode == 'uvu':
                    code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,zui,zv->zui', ws[{index_w}], s1, s2).reshape(batch, {dim_out})\n\n"
                continue

            if l_1 == l_2 and l_out == 0 and mode == 'uvw' and normalization == 'component' and weight:
                # Cl_l_0 = eye(3) / sqrt(2L+1)
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uvw,zui,zvi->zw', ws[{index_w}] / {sqrt(2 * l_1 + 1)}, s1, s2).reshape(batch, {dim_out})\n\n"
                continue

            if l_1 == l_2 and l_out == 0 and mode == 'uvu' and normalization == 'component' and weight:
                # Cl_l_0 = eye(3) / sqrt(2L+1)
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,zui,zvi->zu', ws[{index_w}] / {sqrt(2 * l_1 + 1)}, s1, s2).reshape(batch, {dim_out})\n\n"
                continue

            if (l_1, l_2, l_out) == (1, 1, 1) and mode == 'uvw' and normalization == 'component' and weight:
                # C1_1_1 = levi-civita / sqrt(2)
                code += f"        s1 = s1.reshape(batch, {mul_1}, 1, {2 * l_1 + 1})\n"
                code += f"        s2 = s2.reshape(batch, 1, {mul_2}, {2 * l_2 + 1})\n"
                code += f"        s1, s2 = torch.broadcast_tensors(s1, s2)\n"
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uvw,zuvi->zwi', ws[{index_w}] / {sqrt(2)}, torch.cross(s1, s2, dim=3)).reshape(batch, {dim_out})\n\n"
                continue

            if (l_1, l_2, l_out) == (1, 1, 1) and mode == 'uvu' and normalization == 'component' and weight:
                # C1_1_1 = levi-civita / sqrt(2)
                code += f"        s1 = s1.reshape(batch, {mul_1}, 1, {2 * l_1 + 1})\n"
                code += f"        s2 = s2.reshape(batch, 1, {mul_2}, {2 * l_2 + 1})\n"
                code += f"        s1, s2 = torch.broadcast_tensors(s1, s2)\n"
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,zuvi->zui', ws[{index_w}] / {sqrt(2)}, torch.cross(s1, s2, dim=3)).reshape(batch, {dim_out})\n\n"
                continue

        # generic path: outer product of the inputs, reused across instructions
        # that share (i_1, i_2, uv/uu)
        if last_ss != (i_1, i_2, mode[:2]):
            if mode[:2] == 'uv':
                code += f"        ss = ein('zui,zvj->zuvij', s1, s2)\n"
            if mode[:2] == 'uu':
                code += f"        ss = ein('zui,zuj->zuij', s1, s2)\n"
            last_ss = (i_1, i_2, mode[:2])

        # deduplicate wigner 3j symbols across instructions
        if (l_1, l_2, l_out) in wigners:
            index_w3j = wigners.index((l_1, l_2, l_out))
        else:
            index_w3j = len(wigners)
            wigners += [(l_1, l_2, l_out)]

        # BUG FIX: the 'uuw', 'uuu' and 'uvuv' templates below used to read
        # `sw[...]`, but the generated `main` only defines `ws` — any
        # instruction with those modes produced code that failed with an
        # undefined name. All weight accesses now use `ws[{index_w}]`.
        if mode == 'uvw':
            assert weight
            code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uvw,ijk,zuvij->zwk', ws[{index_w}], w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
        if mode == 'uvu':
            assert mul_1 == mul_out
            if weight:
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,ijk,zuvij->zuk', ws[{index_w}], w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
            else:
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('ijk,zuvij->zuk', w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
        if mode == 'uvv':
            assert mul_2 == mul_out
            if weight:
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,ijk,zuvij->zvk', ws[{index_w}], w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
            else:
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('ijk,zuvij->zvk', w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
        if mode == 'uuw':
            assert mul_1 == mul_2
            assert weight
            code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uw,ijk,zuij->zwk', ws[{index_w}], w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
        if mode == 'uuu':
            assert mul_1 == mul_2 == mul_out
            if weight:
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}u,ijk,zuij->zuk', ws[{index_w}], w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
            else:
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('ijk,zuij->zuk', w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
        if mode == 'uvuv':
            assert mul_1 * mul_2 == mul_out
            if weight:
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('{z}uv,ijk,zuvij->zuvk', ws[{index_w}], w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
            else:
                code += f"        out[:, {index_out}:{index_out+dim_out}] += {c} * ein('ijk,zuvij->zuvk', w3j[{index_w3j}], ss).reshape(batch, {dim_out})\n"
        code += "\n"

    code += f"    return out"

    self.code = code
    self.main = eval_code(self.code).main

    # w3j: register the (normalized) wigner 3j symbols as buffers
    self.wigners = wigners
    for i, (l_1, l_2, l_out) in enumerate(self.wigners):
        wig = o3.wigner_3j(l_1, l_2, l_out)

        if normalization == 'component':
            wig *= (2 * l_out + 1)**0.5
        if normalization == 'norm':
            wig *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5

        self.register_buffer(f"C{i}", wig)

    # weights: one tensor per trained instruction, in instruction order
    self.weight_shapes = wshapes
    self.weight_numel = sum(math.prod(shape) for shape in self.weight_shapes)
    self.weight_infos = [
        (i_1, i_2, i_out, mode, path_weight, shape)
        for (i_1, i_2, i_out, mode, path_weight), shape in zip(
            [
                (i_1, i_2, i_out, mode, path_weight)
                for i_1, i_2, i_out, mode, weight, path_weight in instr
                if weight
            ],
            wshapes,
        )
    ]

    if internal_weights:
        assert self.shared_weights, "Having internal weights impose shared weights"
        self.weight = torch.nn.ParameterDict()
        for i, (i_1, i_2, i_out, mode, path_weight, shape) in enumerate(self.weight_infos):
            mul_1, (l_1, p_1) = self.irreps_in1[i_1]
            mul_2, (l_2, p_2) = self.irreps_in2[i_2]
            mul_out, (l_out, p_out) = self.irreps_out[i_out]
            self.weight[f'{i} l1={l_1} l2={l_2} lout={l_out}'] = torch.nn.Parameter(torch.randn(shape))

    self.to(dtype=torch.get_default_dtype())