def test_getitem():
    irreps = o3.Irreps("16x1e + 3e + 2e + 5o")
    assert irreps[0] == (16, o3.Irrep("1e"))
    assert irreps[3] == (1, o3.Irrep("5o"))
    assert irreps[-1] == (1, o3.Irrep("5o"))

    sliced = irreps[2:]
    assert isinstance(sliced, o3.Irreps)
    assert sliced == o3.Irreps("2e + 5o")

def test_wigner_3j(float_tolerance):
    abc = o3.rand_angles(10)

    l1, l2, l3 = 1, 2, 3
    C = o3.wigner_3j(l1, l2, l3)
    D1 = o3.Irrep(l1, 1).D_from_angles(*abc)
    D2 = o3.Irrep(l2, 1).D_from_angles(*abc)
    D3 = o3.Irrep(l3, 1).D_from_angles(*abc)

    C2 = torch.einsum("ijk,zil,zjm,zkn->zlmn", C, D1, D2, D3)
    assert (C - C2).abs().max() < float_tolerance

def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr, irreps_node_output,
             fc_neurons, num_neighbors) -> None:
    super().__init__()
    self.irreps_node_input = o3.Irreps(irreps_node_input)
    self.irreps_node_attr = o3.Irreps(irreps_node_attr)
    self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
    self.irreps_node_output = o3.Irreps(irreps_node_output)
    self.num_neighbors = num_neighbors

    self.sc = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_output)
    self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_input)

    irreps_mid = []
    instructions = []
    for i, (mul, ir_in) in enumerate(self.irreps_node_input):
        for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
            for ir_out in ir_in * ir_edge:
                if ir_out in self.irreps_node_output or ir_out == o3.Irrep(0, 1):
                    k = len(irreps_mid)
                    irreps_mid.append((mul, ir_out))
                    instructions.append((i, j, k, 'uvu', True))
    irreps_mid = o3.Irreps(irreps_mid)
    irreps_mid, p, _ = irreps_mid.sort()

    assert irreps_mid.dim > 0, (
        f"irreps_node_input={self.irreps_node_input} times irreps_edge_attr={self.irreps_edge_attr} "
        f"produces nothing in irreps_node_output={self.irreps_node_output}"
    )

    instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions]

    tp = TensorProduct(
        self.irreps_node_input,
        self.irreps_edge_attr,
        irreps_mid,
        instructions,
        internal_weights=False,
        shared_weights=False,
    )
    self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel], torch.nn.functional.silu)
    self.tp = tp

    self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_node_output)

    # inspired by https://arxiv.org/pdf/2002.10444.pdf
    self.alpha = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, "0e")
    with torch.no_grad():
        self.alpha.weight.zero_()
    assert self.alpha.output_mask[0] == 1.0, (
        f"irreps_mid={irreps_mid} and irreps_node_attr={self.irreps_node_attr} are not able to generate scalars"
    )

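# A minimal, hedged sketch of how the pieces created in the __init__ above fit
# together at runtime: a tensor product with external weights and a radial MLP
# (FullyConnectedNet) whose last layer emits exactly tp.weight_numel weights per
# edge. The irreps strings, layer widths, batch size, and the use of
# FullyConnectedTensorProduct (instead of the hand-built instruction list above)
# are illustrative assumptions, not taken from the source.
import torch
from e3nn import o3
from e3nn.nn import FullyConnectedNet

irreps_in = o3.Irreps("8x0e + 4x1o")     # node features
irreps_sh = o3.Irreps("0e + 1o + 2e")    # e.g. spherical harmonics of edge vectors
irreps_out = o3.Irreps("8x0e + 4x1o")

tp = o3.FullyConnectedTensorProduct(irreps_in, irreps_sh, irreps_out,
                                    internal_weights=False, shared_weights=False)
fc = FullyConnectedNet([16, 64, tp.weight_numel], torch.nn.functional.silu)

edge_scalars = torch.randn(32, 16)       # e.g. radial embedding of 32 edges
weights = fc(edge_scalars)               # [32, tp.weight_numel]
out = tp(irreps_in.randn(32, -1), irreps_sh.randn(32, -1), weights)
print(out.shape)                         # [32, irreps_out.dim]
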
def tp_path_exists(irreps_in1, irreps_in2, ir_out):
    irreps_in1 = o3.Irreps(irreps_in1).simplify()
    irreps_in2 = o3.Irreps(irreps_in2).simplify()
    ir_out = o3.Irrep(ir_out)

    for _, ir1 in irreps_in1:
        for _, ir2 in irreps_in2:
            if ir_out in ir1 * ir2:
                return True
    return False

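# Hedged usage sketch of tp_path_exists (defined above): it answers whether the
# O(3) selection rules allow ir_out to appear in the tensor product of any pair
# of irreps from the two inputs. The irreps strings are illustrative.
assert tp_path_exists("0e + 1o", "0e + 1o", "1e")   # 1o x 1o = 0e + 1e + 2e contains 1e
assert not tp_path_exists("0e", "0e", "1o")         # scalars alone cannot produce 1o
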
def _wigner_nj(*irrepss, normalization='component', filter_ir_mid=None, dtype=None, device=None):
    # Recursively decompose the product of several Irreps; returns a list of
    # (ir_out, path, change-of-basis) triples sorted by ir_out.
    irrepss = [o3.Irreps(irreps) for irreps in irrepss]
    if filter_ir_mid is not None:
        filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid]

    if len(irrepss) == 1:
        irreps, = irrepss
        ret = []
        e = torch.eye(irreps.dim, dtype=dtype, device=device)
        i = 0
        for mul, ir in irreps:
            for _ in range(mul):
                sl = slice(i, i + ir.dim)
                ret += [
                    (ir, _INPUT(0, sl.start, sl.stop), e[sl])
                ]
                i += ir.dim
        return ret

    *irrepss_left, irreps_right = irrepss
    ret = []
    for ir_left, path_left, C_left in _wigner_nj(*irrepss_left, normalization=normalization,
                                                 filter_ir_mid=filter_ir_mid, dtype=dtype, device=device):
        i = 0
        for mul, ir in irreps_right:
            for ir_out in ir_left * ir:
                if filter_ir_mid is not None and ir_out not in filter_ir_mid:
                    continue

                C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype, device=device)
                if normalization == 'component':
                    C *= ir_out.dim**0.5
                if normalization == 'norm':
                    C *= ir_left.dim**0.5 * ir.dim**0.5

                C = torch.einsum('jk,ijl->ikl', C_left.flatten(1), C)
                C = C.reshape(ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim)
                for u in range(mul):
                    E = torch.zeros(ir_out.dim, *(irreps.dim for irreps in irrepss_left), irreps_right.dim,
                                    dtype=dtype, device=device)
                    sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim)
                    E[..., sl] = C
                    ret += [
                        (
                            ir_out,
                            _TP(
                                op=(ir_left, ir, ir_out),
                                args=(path_left, _INPUT(len(irrepss_left), sl.start, sl.stop))
                            ),
                            E
                        )
                    ]
            i += mul * ir.dim
    return sorted(ret, key=lambda x: x[0])

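# Hedged sketch of what _wigner_nj (defined above) returns for two copies of the
# vector irreps: a list of (ir_out, path, change-of-basis) triples, where the
# change-of-basis tensor has shape [ir_out.dim, dim_1, dim_2]. Assumes the private
# helpers _INPUT and _TP (not shown in this excerpt) are in scope.
import torch
from e3nn import o3

for ir_out, path, C in _wigner_nj(o3.Irreps("1o"), o3.Irreps("1o"), dtype=torch.float64):
    print(ir_out, tuple(C.shape))   # expected irreps: 0e, 1e, 2e, since 1o x 1o = 0e + 1e + 2e
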
def test_arithmetic():
    assert 3 * o3.Irrep("6o") == o3.Irreps("3x6o")

    products = list(o3.Irrep("1o") * o3.Irrep("2e"))
    assert products == [o3.Irrep("1o"), o3.Irrep("2o"), o3.Irrep("3o")]

    assert o3.Irrep("4o") + o3.Irrep("7e") == o3.Irreps("4o + 7e")

    assert 2 * o3.Irreps("2x2e + 4x1o") == o3.Irreps("2x2e + 4x1o + 2x2e + 4x1o")
    assert o3.Irreps("2x2e + 4x1o") * 2 == o3.Irreps("2x2e + 4x1o + 2x2e + 4x1o")

    assert o3.Irreps("1o + 4o") + o3.Irreps("1o + 7e") == o3.Irreps("1o + 4o + 1o + 7e")

def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr, irreps_node_output,
             fc_neurons, num_neighbors) -> None:
    super().__init__()
    self.irreps_node_input = o3.Irreps(irreps_node_input)
    self.irreps_node_attr = o3.Irreps(irreps_node_attr)
    self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
    self.irreps_node_output = o3.Irreps(irreps_node_output)
    self.num_neighbors = num_neighbors

    self.sc = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_output)
    self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_input)

    irreps_mid = []
    instructions = []
    for i, (mul, ir_in) in enumerate(self.irreps_node_input):
        for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
            for ir_out in ir_in * ir_edge:
                if ir_out in self.irreps_node_output or ir_out == o3.Irrep(0, 1):
                    k = len(irreps_mid)
                    irreps_mid.append((mul, ir_out))
                    instructions.append((i, j, k, 'uvu', True))
    irreps_mid = o3.Irreps(irreps_mid)
    irreps_mid, p, _ = irreps_mid.sort()

    instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions]

    tp = TensorProduct(
        self.irreps_node_input,
        self.irreps_edge_attr,
        irreps_mid,
        instructions,
        internal_weights=False,
        shared_weights=False,
    )
    self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel], torch.nn.functional.silu)
    self.tp = tp

    self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_node_output)
    self.lin3 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, "0e")

def test_properties():
    irrep = o3.Irrep("3e")
    assert irrep.l == 3
    assert irrep.p == 1
    assert irrep.dim == 7
    assert o3.Irrep(repr(irrep)) == irrep

    l, p = o3.Irrep("5o")
    assert l == 5
    assert p == -1

    iterator = o3.Irrep.iterator(5)
    assert len(list(iterator)) == 12

    iterator = o3.Irrep.iterator()
    for x in range(100):
        irrep = next(iterator)
        assert irrep.l == x // 2
        assert irrep.p in (-1, 1)
        assert irrep.dim == 2 * (x // 2) + 1

    irreps = o3.Irreps("4x1e + 6x2e + 12x2o")
    assert o3.Irreps(repr(irreps)) == irreps

def __init__(
    self,
    irreps_in1,
    irreps_in2,
    filter_ir_out=None,
    **kwargs
):
    irreps_in1 = o3.Irreps(irreps_in1).simplify()
    irreps_in2 = o3.Irreps(irreps_in2).simplify()
    if filter_ir_out is not None:
        filter_ir_out = [o3.Irrep(ir) for ir in filter_ir_out]

    assert irreps_in1.num_irreps == irreps_in2.num_irreps

    irreps_in1 = list(irreps_in1)
    irreps_in2 = list(irreps_in2)

    i = 0
    while i < len(irreps_in1):
        mul_1, ir_1 = irreps_in1[i]
        mul_2, ir_2 = irreps_in2[i]

        if mul_1 < mul_2:
            irreps_in2[i] = (mul_1, ir_2)
            irreps_in2.insert(i + 1, (mul_2 - mul_1, ir_2))

        if mul_2 < mul_1:
            irreps_in1[i] = (mul_2, ir_1)
            irreps_in1.insert(i + 1, (mul_1 - mul_2, ir_1))
        i += 1

    out = []
    instr = []
    for i, ((mul, ir_1), (mul_2, ir_2)) in enumerate(zip(irreps_in1, irreps_in2)):
        assert mul == mul_2
        for ir in ir_1 * ir_2:
            if filter_ir_out is not None and ir not in filter_ir_out:
                continue

            i_out = len(out)
            out.append((mul, ir))
            instr += [
                (i, i, i_out, 'uuu', False)
            ]

    super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs)

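# Hedged usage sketch: the __init__ above matches e3nn's ElementwiseTensorProduct,
# which pairs the k-th irrep channel of the first input with the k-th of the second
# (splitting multiplicities so they line up) and uses unweighted 'uuu' paths.
# The irreps strings and batch size are illustrative.
from e3nn import o3

tp = o3.ElementwiseTensorProduct("2x0e + 1x1o", "1x0e + 2x1o")
print(tp.irreps_out)   # element-wise pairs: 0e*0e, 0e*1o, 1o*1o
x1 = o3.Irreps("2x0e + 1x1o").randn(10, -1)
x2 = o3.Irreps("1x0e + 2x1o").randn(10, -1)
out = tp(x1, x2)       # [10, tp.irreps_out.dim]
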
def test_creation():
    o3.Irrep(3, 1)
    ir = o3.Irrep("3e")
    o3.Irrep(ir)

    assert o3.Irrep('10o') == o3.Irrep(10, -1)
    assert o3.Irrep("1y") == o3.Irrep("1o")

    irreps = o3.Irreps(ir)
    o3.Irreps(irreps)

    o3.Irreps([(32, (4, -1))])
    o3.Irreps("11e")

    assert o3.Irreps("16x1e + 32 x 2o") == o3.Irreps([(16, (1, 1)), (32, (2, -1))])

    o3.Irreps(["1e", '2o'])
    o3.Irreps([(16, "3e"), '1e'])
    o3.Irreps([(16, "3e"), '1e', (256, (1, -1))])

def __init__(self, irreps_in, ir):
    r"""Extract ``ir`` from irreps

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        representation of the input

    ir : `e3nn.o3.Irrep`
        representation to extract
    """
    ir = o3.Irrep(ir)
    irreps_in = o3.Irreps(irreps_in)
    self.irreps_out = o3.Irreps([mul_ir for mul_ir in irreps_in if mul_ir.ir == ir])
    instructions = [
        tuple(i for i, mul_ir in enumerate(irreps_in) if mul_ir.ir == ir)
    ]

    super().__init__(irreps_in, [self.irreps_out], instructions, squeeze_out=True)

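# Hedged usage sketch: the __init__ above corresponds to e3nn's ExtractIr, which
# keeps only the channels carrying the requested irrep. The irreps and batch size
# below are illustrative.
from e3nn import o3
from e3nn.nn import ExtractIr

extract = ExtractIr("0e + 2x1o + 0e", "1o")
x = o3.Irreps("0e + 2x1o + 0e").randn(5, -1)
y = extract(x)
print(y.shape)   # [5, 6]: the two 1o blocks (3 components each), scalars dropped
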
def __init__(
    self,
    irreps_in1: o3.Irreps,
    irreps_in2: o3.Irreps,
    filter_ir_out: Iterator[o3.Irrep] = None,
    **kwargs
):
    irreps_in1 = o3.Irreps(irreps_in1).simplify()
    irreps_in2 = o3.Irreps(irreps_in2).simplify()
    if filter_ir_out is not None:
        filter_ir_out = [o3.Irrep(ir) for ir in filter_ir_out]

    out = []
    instr = []
    for i_1, (mul_1, ir_1) in enumerate(irreps_in1):
        for i_2, (mul_2, ir_2) in enumerate(irreps_in2):
            for ir_out in ir_1 * ir_2:
                if filter_ir_out is not None and ir_out not in filter_ir_out:
                    continue

                i_out = len(out)
                out.append((mul_1 * mul_2, ir_out))
                instr += [
                    (i_1, i_2, i_out, 'uvuv', False)
                ]

    out = o3.Irreps(out)
    out, p, _ = out.sort()

    instr = [
        (i_1, i_2, p[i_out], mode, train)
        for i_1, i_2, i_out, mode, train in instr
    ]

    super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs)

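# Hedged usage sketch: the __init__ above matches e3nn's FullTensorProduct, which
# keeps every irrep produced by every pair of input irreps (unweighted 'uvuv' paths),
# optionally restricted by filter_ir_out. The irreps strings are illustrative.
from e3nn import o3

tp = o3.FullTensorProduct("2x0e + 1x1o", "1x1o", filter_ir_out=["1e"])
print(tp.irreps_out)   # only the 1e part of 1o x 1o survives the filter
x1 = o3.Irreps("2x0e + 1x1o").randn(4, -1)
x2 = o3.Irreps("1x1o").randn(4, -1)
out = tp(x1, x2)
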
def test_contains():
    assert o3.Irrep("2e") in o3.Irreps("3x0e + 2x2e + 1x3o")
    assert o3.Irrep("2o") not in o3.Irreps("3x0e + 2x2e + 1x3o")

def __init__(self, formula, filter_ir_out=None, filter_ir_mid=None, eps=1e-9, **irreps):
    super().__init__(self, fx.Graph())

    if filter_ir_out is not None:
        filter_ir_out = [o3.Irrep(ir) for ir in filter_ir_out]

    f0, formulas = germinate_formulas(formula)

    irreps = {i: o3.Irreps(irs) for i, irs in irreps.items()}

    for i in irreps:
        if len(i) != 1:
            raise TypeError(f"got an unexpected keyword argument '{i}'")

    for _sign, p in formulas:
        f = "".join(f0[i] for i in p)
        for i, j in zip(f0, f):
            if i in irreps and j in irreps and irreps[i] != irreps[j]:
                raise RuntimeError(f'irreps of {i} and {j} should be the same')
            if i in irreps:
                irreps[j] = irreps[i]
            if j in irreps:
                irreps[i] = irreps[j]

    for i in f0:
        if i not in irreps:
            raise RuntimeError(f'index {i} has no irreps associated to it')

    for i in irreps:
        if i not in f0:
            raise RuntimeError(f'index {i} has an irreps but does not appear in the formula')

    base_perm, _ = reduce_permutation(
        f0, formulas, dtype=torch.float64,
        **{i: irs.dim for i, irs in irreps.items()}
    )

    Ps = collections.defaultdict(list)

    for ir, path, base_o3 in _wigner_nj(*[irreps[i] for i in f0], filter_ir_mid=filter_ir_mid, dtype=torch.float64):
        if filter_ir_out is None or ir in filter_ir_out:
            P = base_o3.flatten(1) @ base_perm.flatten(1).T
            if P.norm() > eps:  # if this irrep is present in the permutation basis we keep it
                Ps[ir].append((path, base_o3))

    outputs = []
    change_of_basis = []
    irreps_out = []

    for ir in Ps:
        mul = len(Ps[ir])
        paths = [path for path, _ in Ps[ir]]
        base_o3 = torch.stack([R for _, R in Ps[ir]])

        R = base_o3.flatten(2)    # [multiplicity, ir, input basis] (u,j,omega)
        P = base_perm.flatten(1)  # [permutation basis, input basis] (a,omega)

        Xs = []
        for j in range(ir.dim):
            RR = R[:, j] @ R[:, j].T  # (u,u)
            PP = P @ P.T              # (a,a)
            RP = R[:, j] @ P.T        # (u,a)
            prob = torch.cat([
                torch.cat([RR, -RP], dim=1),
                torch.cat([-RP.T, PP], dim=1)
            ], dim=0)
            eigenvalues, eigenvectors = torch.linalg.eigh(prob)
            X = eigenvectors[:, eigenvalues < eps][:mul].T  # [solutions, multiplicity]
            X = torch.linalg.qr(X, mode='r').R

            for i, x in enumerate(X):
                for j in range(i, mul):
                    if x[j] < eps:
                        x.neg_()
                    if x[j] > eps:
                        break
            X[X.abs() < eps] = 0
            X = sorted([[x.item() for x in line] for line in X])
            X = torch.tensor(X, dtype=torch.float64)
            Xs.append(X)

        for X in Xs:
            assert (X - Xs[0]).abs().max() < eps
        X = Xs[0]

        for x in X:
            C = torch.einsum("u,ui...->i...", x, base_o3)
            correction = (ir.dim / C.pow(2).sum())**0.5
            C = correction * C

            outputs.append([((correction * v).item(), p) for v, p in zip(x, paths) if v.abs() > eps])
            change_of_basis.append(C)
            irreps_out.append((1, ir))

    dtype, _ = explicit_default_types(None, None)
    self.register_buffer('change_of_basis', torch.cat(change_of_basis).to(dtype=dtype))

    tps = set()
    for vp_list in outputs:
        for v, p in vp_list:
            for op in _get_ops(p):
                tps.add(op)

    tps = list(tps)
    for i, op in enumerate(tps):
        tp = o3.TensorProduct(op[0], op[1], op[2], [(0, 0, 0, 'uuu', False)])
        setattr(self, f'tp{i}', tp)

    graph = fx.Graph()
    inputs = [
        fx.Proxy(graph.placeholder(f"x{i}", torch.Tensor))
        for i in f0
    ]

    self.irreps_in = [irreps[i] for i in f0]
    self.irreps_out = o3.Irreps(irreps_out).simplify()

    values = dict()

    def evaluate(path):
        if path in values:
            return values[path]

        if isinstance(path, _INPUT):
            out = inputs[path.tensor]
            if (path.start, path.stop) != (0, self.irreps_in[path.tensor].dim):
                out = out.narrow(-1, path.start, path.stop - path.start)
        if isinstance(path, _TP):
            x1 = evaluate(path.args[0]).node
            x2 = evaluate(path.args[1]).node
            out = fx.Proxy(graph.call_module(f'tp{tps.index(path.op)}', (x1, x2)))
        values[path] = out
        return out

    outs = []
    for vp_list in outputs:
        v, p = vp_list[0]
        out = evaluate(p)
        if abs(v - 1.0) > eps:
            out = v * out
        for v, p in vp_list[1:]:
            t = evaluate(p)
            if abs(v - 1.0) > eps:
                t = v * t
            out = out + t
        outs.append(out)

    out = torch.cat(outs, dim=-1)

    graph.output(out.node)

    self.graph = graph
    self.recompile()

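# Hedged usage sketch: the __init__ above matches e3nn's o3.ReducedTensorProducts,
# which reduces a tensor with index-permutation symmetries (given by `formula`)
# into irreps. The classic example is the antisymmetric product of two vectors.
import torch
from e3nn import o3

tp = o3.ReducedTensorProducts('ij=-ji', i='1o')
print(tp.irreps_out)            # 1x1e: the antisymmetric part of 1o x 1o
x, y = torch.randn(2, 3)
assert (tp(x, y) + tp(y, x)).abs().max() < 1e-6   # antisymmetric in the two inputs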