def __init__(self, Rs, acts):
    """
    Pair each block of ``Rs`` with an activation and derive the output parities.

    Can be used only with scalar fields.

    :param Rs: representation list; every entry unpacks as ``(mul, l, p)``
    :param acts: list of tuple (multiplicity, activation); a multiplicity
        of ``-1`` is replaced by the part of the total multiplicity of
        ``Rs`` not claimed by the positive multiplicities
    :raises ValueError: if an input with non-zero parity would end up with
        parity 0 after its activation
    """
    super().__init__()

    Rs = o3.simplify(Rs)

    total_mul = sum(mul for mul, _, _ in Rs)
    claimed_mul = sum(mul for mul, _ in acts if mul > 0)

    # normalize the second moment
    acts = [(mul, normalize2mom(fn)) for mul, fn in acts]

    # resolve the `-1` placeholder multiplicities
    for idx, (mul, fn) in enumerate(acts):
        if mul == -1:
            acts[idx] = (total_mul - claimed_mul, fn)
            assert total_mul - claimed_mul >= 0

    assert total_mul == sum(mul for mul, _ in acts)

    # split entries of `Rs` and `acts` until both lists have matching
    # multiplicities position by position
    idx = 0
    while idx < len(Rs):
        mul_r, l, p_r = Rs[idx]
        mul_a, fn = acts[idx]

        if mul_r < mul_a:
            acts[idx] = (mul_r, fn)
            acts.insert(idx + 1, (mul_a - mul_r, fn))
        elif mul_a < mul_r:
            Rs[idx] = (mul_a, l, p_r)
            Rs.insert(idx + 1, (mul_r - mul_a, l, p_r))
        idx += 1

    # sample points used to test numerically whether an activation is even or odd
    samples = torch.linspace(0, 10, 256)

    Rs_out = []
    for (mul, _l, p_in), (mul_a, fn) in zip(Rs, acts):
        assert mul == mul_a

        at_pos, at_neg = fn(samples), fn(-samples)
        tolerance = at_pos.abs().max() * 1e-10
        if (at_pos - at_neg).abs().max() < tolerance:
            p_act = 1  # act(x) == act(-x)
        elif (at_pos + at_neg).abs().max() < tolerance:
            p_act = -1  # act(x) == -act(-x)
        else:
            p_act = 0  # neither even nor odd

        # only inputs of parity -1 take over the parity of the activation
        p_out = p_act if p_in == -1 else p_in
        Rs_out.append((mul, 0, p_out))

        if p_in != 0 and p_out == 0:
            raise ValueError("warning! the parity is violated")

    self.Rs_out = o3.simplify(Rs_out)
    self.acts = acts
def __repr__(self):
    """
    Describe the module as ``Name(Rs_in1 x Rs_in2 -> Rs_out N weights)``.

    The three representations are simplified and formatted through ``o3``;
    ``N`` is ``self.nweight``.
    """
    def pretty(Rs):
        return o3.format_Rs(o3.simplify(Rs))

    fields = dict(
        name=self.__class__.__name__,
        Rs_in1=pretty(self.Rs_in1),
        Rs_in2=pretty(self.Rs_in2),
        Rs_out=pretty(self.Rs_out),
        nw=self.nweight,
    )
    return "{name}({Rs_in1} x {Rs_in2} -> {Rs_out} {nw} weights)".format(**fields)
def __init__(self, *Rs_outs):
    """
    Store the simplified output representations and build the merged input one.

    :param Rs_outs: any number of representation lists; every entry
        unpacks as ``(mul, l, p)``

    ``self.Rs_in`` is the simplification of all entries of all outputs,
    ordered by ``(l, p)``.
    """
    super().__init__()

    self.Rs_outs = tuple(o3.simplify(Rs) for Rs in Rs_outs)

    def l_then_p(entry):
        _mul, l, p = entry
        return (l, p)

    merged = [entry for Rs in self.Rs_outs for entry in Rs]
    # list.sort is stable, like sorted(): entries with equal (l, p) keep their order
    merged.sort(key=l_then_p)
    self.Rs_in = o3.simplify(merged)
def __init__(self, Rs_in, Rs_out):
    """
    Store two representations that must be equal once simplified.

    :param Rs_in: input representation; entries unpack as ``(mul, l, p)``
    :param Rs_out: output representation, same format

    Registers the buffer ``output_mask``: a vector of ones with
    ``mul * (2 * l + 1)`` entries per element of ``Rs_out``.
    """
    super().__init__()

    self.Rs_in = o3.simplify(Rs_in)
    self.Rs_out = o3.simplify(Rs_out)
    assert self.Rs_in == self.Rs_out

    segments = []
    for mul, l, _p in self.Rs_out:
        segments.append(torch.ones(mul * (2 * l + 1)))
    self.register_buffer('output_mask', torch.cat(segments))
def WeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, normalization='component', own_weight=True, weight_batch=False):
    """
    Build a `CustomWeightedTensorProduct` connecting every compatible triplet.

    A triplet (entry of ``Rs_in1``, entry of ``Rs_in2``, entry of ``Rs_out``)
    gets an instruction when ``|l_1 - l_2| <= l_out <= l_1 + l_2`` and
    ``p_1 * p_2 == p_out``.

    :param Rs_in1: first input representation; entries unpack as ``(mul, l, p)``
    :param Rs_in2: second input representation, same format
    :param Rs_out: output representation, same format
    :param normalization: forwarded unchanged
    :param own_weight: forwarded unchanged
    :param weight_batch: forwarded unchanged
    :return: the `CustomWeightedTensorProduct` instance
    """
    Rs_in1 = o3.simplify(Rs_in1)
    Rs_in2 = o3.simplify(Rs_in2)
    Rs_out = o3.simplify(Rs_out)

    instr = []
    for i_1, (_, l_1, p_1) in enumerate(Rs_in1):
        for i_2, (_, l_2, p_2) in enumerate(Rs_in2):
            for i_out, (_, l_out, p_out) in enumerate(Rs_out):
                if not abs(l_1 - l_2) <= l_out <= l_1 + l_2:
                    continue
                if p_1 * p_2 != p_out:
                    continue
                instr.append((i_1, i_2, i_out, 'uvw', True))

    return CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr, normalization, own_weight, weight_batch)
def spherical_harmonics_alpha_z_y(Rs, alpha, z, y):
    """
    spherical harmonics

    Combines the alpha part and the z part (both evaluated on flattened
    inputs) with ``mul_m_lm`` and restores the leading shape of ``alpha``.

    :param Rs: representation list, simplified before use
    :param alpha: tensor; its shape is the leading shape of the result
    :param z: tensor, flattened before use
    :param y: tensor, flattened before use
    :return: tensor of shape ``alpha.shape + (shz.shape[1],)``
    """
    Rs = o3.simplify(Rs)
    lmax = o3.lmax(Rs)

    sha = spherical_harmonics_alpha(lmax, alpha.flatten())  # [z, m]
    shz = spherical_harmonics_z(Rs, z.flatten(), y.flatten())  # [z, l * m]

    flat = mul_m_lm(Rs, sha, shz)
    target_shape = alpha.shape + (shz.shape[1], )
    return flat.reshape(target_shape)
def spherical_harmonics_z(Rs, z, y=None):
    """
    the z component of the spherical harmonics
    (useful to perform fourier transform)

    Every entry of ``Rs`` must have parity 0 or ``(-1)**l``.

    :param Rs: representation list; entries unpack as ``(mul, l, p)``
    :param z: tensor of shape [...]
    :param y: forwarded to ``legendre`` as is (may be None)
    :return: tensor of shape [..., l * m]
    """
    Rs = o3.simplify(Rs)

    degrees = []
    for _mul, l, p in Rs:
        assert p in [0, (-1)**l]
        degrees.append(l)

    return legendre(degrees, z, y)  # [..., l * m]
def __init__(self, Rs_in, Rs_out, Rs_sh, rad_hs, normalization='component'):
    """
    Build the sub-modules of the convolution.

    :param Rs_in: input representation; entries unpack as ``(mul, l, p)``
    :param Rs_out: output representation, same format
    :param Rs_sh: representation of the second tensor-product input, same format
    :param rad_hs: tuple; ``(self.tp.nweight,)`` is appended to it to form
        the first argument of ``FC``
    :param normalization: stored on the instance as ``self.normalization``

    Sub-modules created, in this order: ``si``, ``lin1``, ``tp``, ``nn``, ``lin2``.
    """
    super().__init__(aggr='add')

    self.Rs_in = o3.simplify(Rs_in)
    self.Rs_out = o3.simplify(Rs_out)
    self.Rs_sh = o3.simplify(Rs_sh)

    self.si = Linear(self.Rs_in, self.Rs_out)
    self.lin1 = Linear(self.Rs_in, self.Rs_in)

    # (l, p) pairs present in the output; only those are worth producing
    wanted = [(l, p) for _, l, p in self.Rs_out]

    instr = []
    Rs = []  # intermediate representation produced by the tensor product
    for i_1, (mul_1, l_1, p_1) in enumerate(self.Rs_in):
        for i_2, (_, l_2, p_2) in enumerate(self.Rs_sh):
            p_out = p_1 * p_2
            for l_out in range(abs(l_1 - l_2), l_1 + l_2 + 1):
                if (l_out, p_out) not in wanted:
                    continue

                # reuse the slot of an identical entry if there is one already
                entry = (mul_1, l_out, p_out)
                if entry not in Rs:
                    Rs.append(entry)
                i_out = Rs.index(entry)

                instr.append((i_1, i_2, i_out, 'uvu', True))

    self.tp = CustomWeightedTensorProduct(self.Rs_in, self.Rs_sh, Rs, instr, own_weight=False, weight_batch=True)
    self.nn = FC(rad_hs + (self.tp.nweight, ), swish)
    self.lin2 = Linear(Rs, self.Rs_out)
    self.normalization = normalization
def rep(Rs, alpha, beta, gamma, parity=None):
    """
    Representation of O(3). Parity applied (-1)**parity times.

    :param Rs: representation list; entries unpack as ``(mul, l, p)``
    :param alpha: forwarded to ``irrep``
    :param beta: forwarded to ``irrep``
    :param gamma: forwarded to ``irrep``
    :param parity: when None the blocks are the bare ``irrep`` matrices;
        otherwise every block is multiplied by ``p**parity`` and all
        entries of ``Rs`` must have a non-zero parity
    :return: ``direct_sum`` of one block per copy (``mul`` copies per entry)
    """
    Rs = o3.simplify(Rs)

    blocks = []
    if parity is None:
        for mul, l, _p in Rs:
            for _ in range(mul):
                blocks.append(irrep(l, alpha, beta, gamma))
        return direct_sum(*blocks)

    # the check is on the parity of each entry of Rs, not on the argument
    assert all(p != 0 for _, _, p in Rs)
    for mul, l, p in Rs:
        for _ in range(mul):
            blocks.append((p**parity) * irrep(l, alpha, beta, gamma))
    return direct_sum(*blocks)
def ElementwiseTensorProduct(Rs_in1, Rs_in2, normalization='component'):
    """
    Build a `CustomWeightedTensorProduct` that multiplies the inputs copy by copy.

    Both inputs must have the same total multiplicity. Their entries are
    split until they line up one to one, then each aligned pair produces
    every ``l`` from ``|l_1 - l_2|`` to ``l_1 + l_2`` with parity ``p_1 * p_2``.

    :param Rs_in1: first input representation; entries unpack as ``(mul, l, p)``
    :param Rs_in2: second input representation, same format
    :param normalization: forwarded unchanged
    :return: the `CustomWeightedTensorProduct` instance (built with ``own_weight=False``)
    """
    Rs_in1 = o3.simplify(Rs_in1)
    Rs_in2 = o3.simplify(Rs_in2)

    assert sum(mul for mul, _, _ in Rs_in1) == sum(mul for mul, _, _ in Rs_in2)

    # split the larger entry of each pair so that multiplicities match position by position
    pos = 0
    while pos < len(Rs_in1):
        mul_1, l_1, p_1 = Rs_in1[pos]
        mul_2, l_2, p_2 = Rs_in2[pos]

        if mul_1 < mul_2:
            Rs_in2[pos:pos + 1] = [(mul_1, l_2, p_2), (mul_2 - mul_1, l_2, p_2)]
        elif mul_2 < mul_1:
            Rs_in1[pos:pos + 1] = [(mul_2, l_1, p_1), (mul_1 - mul_2, l_1, p_1)]
        pos += 1

    Rs_out = []
    instr = []
    for i, ((mul, l_1, p_1), (mul_2, l_2, p_2)) in enumerate(zip(Rs_in1, Rs_in2)):
        assert mul == mul_2
        for l in range(abs(l_1 - l_2), l_1 + l_2 + 1):
            # the new output entry takes the next free index
            instr.append((i, i, len(Rs_out), 'uuu', False))
            Rs_out.append((mul, l, p_1 * p_2))

    return CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr, normalization, own_weight=False)
def __init__(self, Rs_in, Rs_out, normalization: str = 'component'):
    """
    Build a tensor product connecting inputs and outputs of equal ``(l, p)``.

    :param Rs_in: input representation; entries unpack as ``(mul, l, p)``
    :param Rs_out: output representation, same format
    :param normalization: forwarded to `CustomWeightedTensorProduct`

    Registers the buffer ``output_mask``: for every entry of ``Rs_out``,
    ``mul * (2 * l + 1)`` ones if some input entry has the same
    ``(l, p)``, zeros otherwise.
    """
    super().__init__()

    self.Rs_in = o3.simplify(Rs_in)
    self.Rs_out = o3.simplify(Rs_out)

    instr = []
    for i_in, (_, l_in, p_in) in enumerate(self.Rs_in):
        for i_out, (_, l_out, p_out) in enumerate(self.Rs_out):
            if l_in == l_out and p_in == p_out:
                instr.append((i_in, 0, i_out, 'uvw', True))

    # the second input is the single entry (1, 0, 1)
    self.tp = CustomWeightedTensorProduct(self.Rs_in, [(1, 0, 1)], self.Rs_out, instr, normalization, own_weight=True)

    segments = []
    for mul, l, p in self.Rs_out:
        size = mul * (2 * l + 1)
        fed = any(l_in == l and p_in == p for _, l_in, p_in in self.Rs_in)
        segments.append(torch.ones(size) if fed else torch.zeros(size))
    self.register_buffer('output_mask', torch.cat(segments))
def __init__(self, Rs_scalars, act_scalars, Rs_gates, act_gates, Rs_nonscalars):
    """
    Assemble the gated block: activated scalars plus non-scalars multiplied by activated gates.

    :param Rs_scalars: representation given to the scalar `Activation`
    :param act_scalars: activations given to the scalar `Activation`
    :param Rs_gates: representation given to the gate `Activation`
    :param act_gates: activations given to the gate `Activation`
    :param Rs_nonscalars: representation multiplied element-wise with the
        output of the gate activation

    ``self.Rs_in`` is the input of the `Sortcut` followed by the simplified
    non-scalars; ``self.Rs_out`` is the output of the scalar activation
    followed by the output of the element-wise product.
    """
    super().__init__()

    self.sc = Sortcut(Rs_scalars, Rs_gates)
    self.Rs_scalars, self.Rs_gates = self.sc.Rs_outs
    self.Rs_nonscalars = o3.simplify(Rs_nonscalars)

    self.Rs_in = self.sc.Rs_in + self.Rs_nonscalars

    # the sub-modules receive the arguments as passed in, not the simplified attributes
    self.act_scalars = Activation(Rs_scalars, act_scalars)
    self.act_gates = Activation(Rs_gates, act_gates)
    self.mul = nn.ElementwiseTensorProduct(Rs_nonscalars, self.act_gates.Rs_out)

    self.Rs_out = self.act_scalars.Rs_out + self.mul.Rs_out
def make_gated_block(Rs_in, muls, ps, Rs_sh):
    """
    Make a `GatedBlockParity` assuming many things

    :param Rs_in: input representation; entries unpack as ``(mul, l, p)``
    :param muls: sequence indexed by ``l`` giving the multiplicity requested for each degree
    :param ps: iterable of parities to try
    :param Rs_sh: representation combined with ``Rs_in``, same format
    :return: the `GatedBlockParity` instance
    """
    # every (l, p) reachable by combining an entry of Rs_in with an entry of Rs_sh
    Rs_available = []
    for _, l_in, p_in in o3.simplify(Rs_in):
        for _, l_sh, p_sh in Rs_sh:
            for l in range(abs(l_in - l_sh), l_in + l_sh + 1):
                Rs_available.append((l, p_in * p_sh))

    scalars = []
    act_scalars = []
    for p in ps:
        if (0, p) in Rs_available:
            scalars.append((muls[0], 0, p))
            # swish for parity +1, tanh for every other parity
            act_scalars.append((muls[0], swish if p == 1 else torch.tanh))

    nonscalars = []
    for l in range(1, len(muls)):
        for p in ps:
            p_l = p * (-1)**l
            if (l, p_l) in Rs_available:
                nonscalars.append((muls[l], l, p_l))

    # gates are even scalars when available, odd scalars otherwise
    if (0, +1) in Rs_available:
        gate_parity, gate_act = +1, torch.sigmoid
    else:
        gate_parity, gate_act = -1, torch.tanh
    gates = [(o3.mul_dim(nonscalars), 0, gate_parity)]
    act_gates = [(-1, gate_act)]

    return GatedBlockParity(scalars, act_scalars, gates, act_gates, nonscalars)
def __init__(self, muls=(128, 8, 0), ps=(1, -1), lmax=1, num_layers=6,
             cutoff=10.0, rad_gaussians=50, rad_hs=(128, ), num_neighbors=20,
             readout='add', dipole=False, mean=None, std=None, scale=None,
             atomref=None, options="res"):
    """
    Build the embedding, the radial model and the stack of layers.

    :param muls: multiplicities indexed by ``l``; ``muls[0]`` is also the embedding width
    :param ps: parities given to `make_gated_block` and used for the final ``Rs_out``
    :param lmax: ``Rs_sh`` contains ``(1, l, (-1)**l)`` for ``l`` in ``0..lmax``
    :param num_layers: number of (conv, act, shortcut) triplets
    :param cutoff: stored, and given to `GaussianRadialModel`
    :param rad_gaussians: given to `GaussianRadialModel`; also prepended to ``rad_hs``
    :param rad_hs: tuple appended to ``(rad_gaussians,)`` for every `Conv`
    :param num_neighbors: stored on the instance
    :param readout: one of ``'add'``, ``'sum'``, ``'mean'``
    :param dipole: stored on the instance
    :param mean: stored on the instance
    :param std: stored on the instance
    :param scale: stored on the instance
    :param atomref: registered as buffer ``initial_atomref``; when not None
        it is also copied into the weight of ``Embedding(100, 1)``
    :param options: a shortcut is created in each layer when it contains ``'res'``
    """
    super().__init__()

    assert readout in ['add', 'sum', 'mean']

    self.readout = readout
    self.cutoff = cutoff
    self.dipole = dipole
    self.mean = mean
    self.std = std
    self.scale = scale
    self.num_neighbors = num_neighbors
    self.options = options

    self.embedding = Embedding(100, muls[0])
    self.Rs_in = [(muls[0], 0, 1)]

    self.radial = GaussianRadialModel(rad_gaussians, cutoff)
    # spherical harmonics representation
    self.Rs_sh = [(1, l, (-1)**l) for l in range(lmax + 1)]

    Rs = self.Rs_in
    triplets = []
    for _ in range(num_layers):
        act = make_gated_block(Rs, muls, ps, self.Rs_sh)
        conv = Conv(Rs, act.Rs_in, self.Rs_sh, (rad_gaussians, ) + rad_hs)

        if 'res' not in self.options:
            shortcut = None
        elif Rs == act.Rs_out:
            shortcut = Identity(Rs, act.Rs_out)
        else:
            shortcut = Linear(Rs, act.Rs_out)

        # the output of this layer is the input of the next one
        Rs = o3.simplify(act.Rs_out)
        triplets.append(torch.nn.ModuleList([conv, act, shortcut]))

    self.layers = torch.nn.ModuleList(triplets)

    # the last element of `layers` is a bare convolution, not a triplet
    self.Rs_out = [(1, 0, p) for p in ps]
    self.layers.append(
        Conv(Rs, self.Rs_out, self.Rs_sh, (rad_gaussians, ) + rad_hs))

    self.register_buffer('initial_atomref', atomref)
    self.atomref = None
    if atomref is not None:
        self.atomref = Embedding(100, 1)
        self.atomref.weight.data.copy_(atomref)
def reduce_tensor(formula, eps=1e-9, has_parity=None, **kw_Rs):
    """
    Decompose a tensor with index symmetries into irreducible representations.

    Usage
    Rs, Q = rs.reduce_tensor('ijkl=jikl=ikjl=ijlk', i=[(1, 1)])
    Rs = 0,2,4
    Q = tensor of shape [15, 81]

    :param formula: index strings separated by ``=``; a string starting
        with ``-`` carries the sign -1. Every string must be a permutation
        of the first one, which must have sign +1.
    :param eps: tolerance given to ``is_representation``, ``reduce`` and ``_round_sqrt``
    :param has_parity: whether the representations carry a parity; when
        None it is deduced from the first non-callable entry of ``kw_Rs``
    :param kw_Rs: one representation per index letter, either a callable
        or something accepted by ``o3.convention``; letters related by
        the formula share their representation
    :return: ``(Rs_out, A)``: the simplified output representation and the
        change of basis, cast to the default dtype in effect at call time.
        ``([], zeros(0, d))`` when no degree of freedom survives.
    :raises RuntimeError: if a formula is not a permutation of the first
        one, if parities are only partially specified, if ``has_parity``
        cannot be determined, if two related indices have different
        representations, if an index has no representation, or if the
        decomposition does not exhaust the symmetric space
    """
    dtype = torch.get_default_dtype()
    with torch_default_dtype(torch.float64):
        # reformat `formulas` and make checks
        formulas = [
            (-1 if f.startswith('-') else 1, f.replace('-', ''))
            for f in formula.split('=')
        ]
        s0, f0 = formulas[0]
        assert s0 == 1

        for _s, f in formulas:
            # `f0` itself goes through this test first, so it has no repeated
            # index; any `f` that passes then has the same length as `f0`
            # and no separate length check is needed
            if len(set(f)) != len(f) or set(f) != set(f0):
                raise RuntimeError(f'{f} is not a permutation of {f0}')

        # `formulas` is a list of (sign, permutation of indices)
        # each formula can be viewed as a permutation of the original formula
        formulas = {(s, tuple(f.index(i) for i in f0)) for s, f in formulas}  # set of generators (permutations)

        # they can be composed, for instance if you have ijk=jik=ikj
        # you also have ijk=jki
        # applying all possible compositions creates an entire group
        while True:
            n = len(formulas)
            formulas = formulas.union([(s, perm.inverse(p)) for s, p in formulas])
            formulas = formulas.union([
                (s1 * s2, perm.compose(p1, p2))
                for s1, p1 in formulas
                for s2, p2 in formulas
            ])
            if len(formulas) == n:
                break  # we break when the set is stable => it is now a group \o/

        # lets clean the `kw_Rs` before checking that they are compatible with the formulas
        for i in kw_Rs:
            if not callable(kw_Rs[i]):
                Rs = o3.convention(kw_Rs[i])
                if has_parity is None:
                    has_parity = any(p != 0 for _, _, p in Rs)
                if not has_parity and not all(p == 0 for _, _, p in Rs):
                    raise RuntimeError(f'{o3.format_Rs(Rs)} parity has to be specified everywhere or nowhere')
                if has_parity and any(p == 0 for _, _, p in Rs):
                    raise RuntimeError(f'{o3.format_Rs(Rs)} parity has to be specified everywhere or nowhere')
                kw_Rs[i] = Rs

        if has_parity is None:
            raise RuntimeError('please specify the argument `has_parity`')

        group = O3() if has_parity else SO3()

        # here we check that each index has one and only one representation
        for _s, p in formulas:
            f = "".join(f0[i] for i in p)
            for i, j in zip(f0, f):
                if i in kw_Rs and j in kw_Rs and kw_Rs[i] != kw_Rs[j]:
                    raise RuntimeError(f'Rs of {i} (Rs={o3.format_Rs(kw_Rs[i])}) and {j} (Rs={o3.format_Rs(kw_Rs[j])}) should be the same')
                if i in kw_Rs:
                    kw_Rs[j] = kw_Rs[i]
                if j in kw_Rs:
                    kw_Rs[i] = kw_Rs[j]

        for i in f0:
            if i not in kw_Rs:
                raise RuntimeError(f'index {i} has not Rs associated to it')

        ide = group.identity()
        dims = {i: len(kw_Rs[i](*ide)) if callable(kw_Rs[i]) else o3.dim(kw_Rs[i]) for i in f0}  # dimension of each index
        full_base = list(itertools.product(*(range(dims[i]) for i in f0)))  # (0, 0, 0), (0, 0, 1), (0, 0, 2), ... (3, 3, 3)
        # len(full_base) degrees of freedom in an unconstrained tensor

        # but there is constraints given by the group `formulas`
        # For instance if `ij=-ji`, then 00=-00, 01=-01 and so on
        base = set()
        for x in full_base:
            # T[x] is a coefficient of the tensor T and is related to other coefficient T[y]
            # if x and y are related by a formula
            xs = {(s, tuple(x[i] for i in p)) for s, p in formulas}
            # s * T[x] are all equal for all (s, x) in xs
            # if T[x] = -T[x] it is then equal to 0 and we lose this degree of freedom
            if (-1, x) not in xs:
                # the sign is arbitrary, put both possibilities
                base.add(frozenset({
                    frozenset(xs),
                    frozenset({(-s, y) for s, y in xs})
                }))

        # len(base) is the number of degrees of freedom in the tensor.
        # Now we want to decompose these degrees of freedom into irreps

        base = sorted([sorted([sorted(xs) for xs in x]) for x in base])  # requested for python 3.7 but not for 3.8 (probably a bug in 3.7)

        # First we compute the change of basis (projection) between full_base and base
        d_sym = len(base)
        d = len(full_base)
        Q = torch.zeros(d_sym, d)

        # the elements of `full_base` are distinct tuples, so a dict gives the
        # same position as `full_base.index` without a linear scan per lookup
        position = {e: j for j, e in enumerate(full_base)}

        for i, x in enumerate(base):
            # of the two sign conventions, keep the one with the larger sum of signs
            x = max(x, key=lambda xs: sum(s for s, _ in xs))
            for s, e in x:
                Q[i, position[e]] = s / len(x)**0.5

        assert torch.allclose(Q @ Q.T, torch.eye(d_sym))

        if d_sym == 0:
            return [], torch.zeros(d_sym, d).to(dtype=dtype)

        # We project the representation on the basis `base`
        def representation(g):
            def re(r):
                if callable(r):
                    return r(*g)
                return o3.rep(r, *g)

            m = kron(*(re(kw_Rs[i]) for i in f0))
            return Q @ m @ Q.T

        # And check that after this projection it is still a representation
        assert is_representation(group, representation, eps)

        # The rest of the code simply extract the irreps present in this representation
        Rs_out = []
        A = Q.clone()
        for r in group.irrep_indices():
            # stop once the next irrep is larger than what is left to decompose
            if group.irrep(r)(ide).shape[0] > d_sym - o3.dim(Rs_out):
                break

            # `reduce` also returns the representation to use for the next irrep
            mul, B, representation = reduce(group, representation, group.irrep(r), eps)
            A = direct_sum(torch.eye(d_sym - B.shape[0]), B) @ A
            A = _round_sqrt(A, eps)
            if has_parity:
                Rs_out += [(mul,) + r]
            else:
                Rs_out += [(mul, r)]

            if o3.dim(Rs_out) == d_sym:
                break

        if o3.dim(Rs_out) != d_sym:
            raise RuntimeError('unable to decompose into irreducible representations')
        return o3.simplify(Rs_out), A.to(dtype=dtype)