import copy
import math

import torch

from e3nn import SO3
# S2Activation, SphericalHarmonicsProject and irr_repr are assumed to be
# provided by the surrounding e3nn modules.


def test(self, Rs, act):
    x = torch.randn(55, sum(2 * l + 1 for _, l, _ in Rs))
    ac = S2Activation(Rs, act, 1000)

    y1 = ac(x, dim=-1) @ SO3.rep(ac.Rs_out, 0, 0, 0, -1).T
    y2 = ac(x @ SO3.rep(Rs, 0, 0, 0, -1).T, dim=-1)
    self.assertLess((y1 - y2).abs().max(), 1e-10)
def precompute(self, R):
    a = torch.linspace(0, 2 * math.pi, 2 * self.n)
    b = torch.linspace(0, math.pi, self.n)[2:-2]
    a, b = torch.meshgrid(a, b)

    xyz = torch.stack(SO3.angles_to_xyz(a, b), dim=-1) @ R.t()
    a, b = SO3.xyz_to_angles(xyz)

    proj = SphericalHarmonicsProject(a, b, self.lmax)
    return xyz, proj
def __init__(self, n, lmax):
    super().__init__()
    self.n = n
    self.lmax = lmax

    R = SO3.rot(math.pi / 2, math.pi / 2, math.pi / 2)
    self.xyz1, self.proj1 = self.precompute(R)

    R = SO3.rot(0, 0, 0)
    self.xyz2, self.proj2 = self.precompute(R)
def check_basis_equivariance(basis, order_in, order_out, alpha, beta, gamma):
    from e3nn import SO3
    from scipy.ndimage import affine_transform
    import numpy as np

    n = basis.size(0)
    dim_in = 2 * order_in + 1
    dim_out = 2 * order_out + 1
    size = basis.size(-1)
    assert basis.size() == (n, dim_out, dim_in, size, size, size), basis.size()

    basis = basis / basis.view(n, -1).norm(dim=1).view(-1, 1, 1, 1, 1, 1)

    x = basis.view(-1, size, size, size)
    y = torch.empty_like(x)

    invrot = SO3.rot(-gamma, -beta, -alpha).numpy()
    center = (np.array(x.size()[1:]) - 1) / 2

    for k in range(y.size(0)):
        y[k] = torch.tensor(affine_transform(x[k].numpy(), matrix=invrot, offset=center - np.dot(invrot, center)))

    y = y.view(*basis.size())

    y = torch.einsum(
        "ij,bjkxyz,kl->bilxyz",
        (
            irr_repr(order_out, alpha.item(), beta.item(), gamma.item(), dtype=y.dtype),
            y,
            irr_repr(order_in, -gamma.item(), -beta.item(), -alpha.item(), dtype=y.dtype)
        )
    )

    return torch.tensor([(basis[i] * y[i]).sum() for i in range(n)])
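# The returned vector contains the overlap of each normalized basis element
# with its rotated-and-conjugated copy; entries close to 1 for every i mean
# the basis commutes with rotations, i.e. it is equivariant.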
def random_rotate_translate(positions, rotation=True, translation=1):
    # rejection-sample a translation uniformly from the unit ball
    while True:
        trans = torch.rand(3) * 2 - 1
        if trans.norm() <= 1:
            break
    # apply a shared random rotation unless disabled
    rot = SO3.rot(*torch.rand(3) * 2 * math.pi).type(torch.float32) if rotation else torch.eye(3)
    return [rot @ pos + translation * trans for pos in positions]
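# Hypothetical usage (the point list below is made up): move a whole point
# cloud rigidly with one rotation and one bounded translation.
#
#   pos = [torch.randn(3) for _ in range(10)]
#   moved = random_rotate_translate(pos)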
def __init__(self, Rs, normalization='norm'):
    super().__init__()

    Rs = SO3.normalizeRs(Rs)
    n = sum(mul for mul, _, _ in Rs)

    self.Rs_in = Rs
    self.Rs_out = [(n, 0, +1)]
    self.normalization = normalization
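# For example, Rs = [(2, 0, 1), (1, 1, -1)] gives Rs_out = [(3, 0, 1)]:
# one even scalar per multiplicity channel (presumably the norm of each
# field, matching the default normalization='norm').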
def __init__(self, Rs, acts):
    '''
    Can be used only with scalar fields

    :param acts: list of tuple (multiplicity, activation)
    '''
    super().__init__()

    Rs = SO3.normalizeRs(Rs)
    acts = copy.deepcopy(acts)

    n1 = sum(mul for mul, _, _ in Rs)
    n2 = sum(mul for mul, _ in acts if mul > 0)

    # a multiplicity of -1 means "all the remaining channels"
    for i, (mul, act) in enumerate(acts):
        if mul == -1:
            acts[i] = (n1 - n2, act)
            assert n1 - n2 >= 0
    assert n1 == sum(mul for mul, _ in acts)

    # split the entries of Rs and acts so that they align one to one
    i = 0
    while i < len(Rs):
        mul_r, l, p_r = Rs[i]
        mul_a, act = acts[i]

        if mul_r < mul_a:
            acts[i] = (mul_r, act)
            acts.insert(i + 1, (mul_a - mul_r, act))

        if mul_a < mul_r:
            Rs[i] = (mul_a, l, p_r)
            Rs.insert(i + 1, (mul_r - mul_a, l, p_r))
        i += 1

    x = torch.linspace(0, 10, 256)

    Rs_out = []
    for (mul, l, p_in), (mul_a, act) in zip(Rs, acts):
        assert mul == mul_a

        # classify the activation as even, odd or neither on a sample grid
        a1, a2 = act(x), act(-x)
        if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
            p_act = 1
        elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
            p_act = -1
        else:
            p_act = 0

        p = p_act if p_in == -1 else p_in
        Rs_out.append((mul, 0, p))

        if p_in != 0 and p == 0:
            raise ValueError("parity violated: the activation is neither even nor odd")

    self.Rs_out = Rs_out
    self.acts = acts
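# Self-contained sketch of the parity test used above (helper name is made up):
# an activation keeps a well-defined parity only if it is even or odd.
def _detect_parity(act, eps=1e-10):
    x = torch.linspace(0, 10, 256)
    a1, a2 = act(x), act(-x)
    if (a1 - a2).abs().max() < a1.abs().max() * eps:
        return 1   # even: act(-x) == act(x)
    if (a1 + a2).abs().max() < a1.abs().max() * eps:
        return -1  # odd: act(-x) == -act(x)
    return 0       # neither: the parity information is lost

assert _detect_parity(torch.tanh) == -1  # tanh is odd
assert _detect_parity(torch.cos) == 1    # cos is even
assert _detect_parity(torch.relu) == 0   # relu is neither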
def __init__(self, Rs_1, Rs_2):
    super().__init__()

    Rs_1 = SO3.normalizeRs(Rs_1)
    Rs_2 = SO3.normalizeRs(Rs_2)
    assert sum(mul for mul, _, _ in Rs_1) == sum(mul for mul, _, _ in Rs_2)

    # split the entries so that the multiplicities of Rs_1 and Rs_2 align
    i = 0
    while i < len(Rs_1):
        mul_1, l_1, p_1 = Rs_1[i]
        mul_2, l_2, p_2 = Rs_2[i]

        if mul_1 < mul_2:
            Rs_2[i] = (mul_1, l_2, p_2)
            Rs_2.insert(i + 1, (mul_2 - mul_1, l_2, p_2))

        if mul_2 < mul_1:
            Rs_1[i] = (mul_2, l_1, p_1)
            Rs_1.insert(i + 1, (mul_1 - mul_2, l_1, p_1))
        i += 1

    self.Rs_1 = Rs_1
    self.Rs_2 = Rs_2

    Rs_out = []
    for (mul, l_1, p_1), (mul_2, l_2, p_2) in zip(Rs_1, Rs_2):
        assert mul == mul_2
        for l in range(abs(l_1 - l_2), l_1 + l_2 + 1):
            Rs_out.append((mul, l, p_1 * p_2))

            C = SO3.clebsch_gordan(l, l_1, l_2).type(torch.get_default_dtype()) * (2 * l + 1) ** 0.5

            if l_1 == 0 or l_2 == 0:
                # C is the identity when one of the factors is a scalar
                m = C.view(2 * l + 1, 2 * l + 1)
                eps = 1e-7 if C.dtype == torch.float else 1e-10
                assert (m - torch.eye(2 * l + 1, dtype=C.dtype)).abs().max() < eps, m.numpy().round(3)
            else:
                self.register_buffer("cg_{}_{}_{}".format(l, l_1, l_2), C)

    self.Rs_out = Rs_out
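# Illustration of the selection rule applied above (helper name is made up):
# a pair (l_1, p_1) x (l_2, p_2) yields every l in |l_1 - l_2| .. l_1 + l_2,
# each with parity p_1 * p_2.
def _product_Rs(mul, l_1, p_1, l_2, p_2):
    return [(mul, l, p_1 * p_2) for l in range(abs(l_1 - l_2), l_1 + l_2 + 1)]

# two odd l=1 fields combine into even fields of order 0, 1 and 2
assert _product_Rs(1, 1, -1, 1, -1) == [(1, 0, 1), (1, 1, 1), (1, 2, 1)]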
def __init__(self, Rs):
    super().__init__()

    self.Rs_in = SO3.normalizeRs(Rs)

    xs = []
    j = 0  # input offset
    for mul, l, p in self.Rs_in:
        d = mul * (2 * l + 1)
        xs.append((l, p, mul, j, d))
        j += d

    mixing_matrix = torch.zeros(j, j)

    Rs_out = []
    i = 0  # output offset
    for l, p, mul, j, d in sorted(xs):
        Rs_out.append((mul, l, p))
        mixing_matrix[i:i + d, j:j + d] = torch.eye(d)
        i += d

    self.Rs_out = SO3.normalizeRs(Rs_out)
    self.register_buffer('mixing_matrix', mixing_matrix)
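# Sketch of the effect of the sort: with Rs = [(1, 1, 1), (2, 0, 1)] the keys
# (l, p, mul) put the scalars first, so Rs_out = [(2, 0, 1), (1, 1, 1)] and
# mixing_matrix is the 5x5 permutation that moves the two scalar channels in
# front of the three components of the l = 1 field.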
def __init__(self, Rs_in, Rs_out, RadialModel,
             get_l_filters=None,
             sh=SO3.spherical_harmonics_xyz,
             normalization='norm'):
    '''
    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param get_l_filters: function of signature (l_in, l_out) -> [l_filter]
    :param sh: spherical harmonics function of signature ([l_filter], xyz[..., 3]) -> Y[m, ...]
    :param normalization: either 'norm' or 'component'

    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    '''
    super().__init__()

    self.Rs_in = SO3.normalizeRs(Rs_in)
    self.Rs_out = SO3.normalizeRs(Rs_out)

    def filters_with_parity(l_in, p_in, l_out, p_out):
        def filters(l_in, l_out):
            return list(range(abs(l_in - l_out), l_in + l_out + 1))

        nonlocal get_l_filters
        fn = filters if get_l_filters is None else get_l_filters
        return [l for l in fn(l_in, l_out) if p_out == 0 or p_in * (-1) ** l == p_out]

    self.get_l_filters = filters_with_parity

    self.check_input_output()

    self.sh = sh

    assert isinstance(normalization, str), "normalization should be passed as a string value"
    assert normalization in ['norm', 'component'], "normalization needs to be 'norm' or 'component'"
    self.normalization = normalization

    def lm_normalization(l_out, l_in):
        # put 2l_in+1 to keep the norm of the m vector constant
        # put 2l_out+1 to keep the variance of each m component constant
        # sum_m Y_m^2 = (2l+1)/(4pi) and norm(Q) = 1 implies that norm(QY) = sqrt(1/4pi)
        lm_norm = None
        if normalization == 'norm':
            lm_norm = math.sqrt(2 * l_in + 1) * math.sqrt(4 * math.pi)
        elif normalization == 'component':
            lm_norm = math.sqrt(2 * l_out + 1) * math.sqrt(4 * math.pi)
        return lm_norm

    norm_coef = torch.zeros((len(self.Rs_out), len(self.Rs_in), 2))

    n_path = 0
    set_of_l_filters = set()

    for i, (mul_out, l_out, p_out) in enumerate(self.Rs_out):
        # consider that we sum a bunch of [lambda_(m_out)] vectors
        # we need to count how many of them we sum in order to normalize the network
        num_summed_elements = 0
        for mul_in, l_in, p_in in self.Rs_in:
            l_filters = self.get_l_filters(l_in, p_in, l_out, p_out)
            num_summed_elements += mul_in * len(l_filters)

        for j, (mul_in, l_in, p_in) in enumerate(self.Rs_in):
            # normalization assuming that each term is of order 1 and uncorrelated
            norm_coef[i, j, 0] = lm_normalization(l_out, l_in) / math.sqrt(num_summed_elements)
            norm_coef[i, j, 1] = lm_normalization(l_out, l_in) / math.sqrt(mul_in)

            l_filters = self.get_l_filters(l_in, p_in, l_out, p_out)
            assert l_filters == sorted(set(l_filters)), "get_l_filters must return a sorted list of unique values"

            # compute the number of degrees of freedom
            n_path += mul_out * mul_in * len(l_filters)

            # create the set of all spherical harmonics orders needed
            set_of_l_filters = set_of_l_filters.union(l_filters)

    # create the radial model: R+ -> R^n_path
    # it contains the learned parameters
    self.R = RadialModel(n_path)
    self.set_of_l_filters = sorted(set_of_l_filters)
    self.register_buffer('norm_coef', norm_coef)
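# Hypothetical usage sketch; the enclosing class name (`Kernel` here) and the
# radial model are assumptions for illustration, not taken from this file:
#
#   RadialModel = lambda d: torch.nn.Linear(1, d)  # maps a radius to d weights
#   k = Kernel(Rs_in=[(1, 0, 1)], Rs_out=[(1, 1, -1)], RadialModel=RadialModel)
#
# With parities set, only filter orders with p_in * (-1)**l == p_out survive:
# for l_in = 0 and l_out = 1 the only candidate order is l = 1, and
# 1 * (-1)**1 == -1 matches p_out, so that single path is kept.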
def backward(ctx, grad_kernel):
    Y, R, norm_coef = ctx.saved_tensors

    grad_Y = grad_R = None

    if ctx.needs_input_grad[0]:
        grad_Y = grad_kernel.new_zeros(*ctx.Y_shape)  # [l_filter * m_filter, batch]
    if ctx.needs_input_grad[1]:
        grad_R = grad_kernel.new_zeros(*ctx.R_shape)  # [batch, l_out * l_in * mul_out * mul_in * l_filter]

    begin_R = 0
    begin_out = 0
    for i, (mul_out, l_out, p_out) in enumerate(ctx.Rs_out):
        s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
        begin_out += mul_out * (2 * l_out + 1)

        begin_in = 0
        for j, (mul_in, l_in, p_in) in enumerate(ctx.Rs_in):
            s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
            begin_in += mul_in * (2 * l_in + 1)

            l_filters = ctx.get_l_filters(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            n = mul_out * mul_in * len(l_filters)
            if grad_Y is not None:
                sub_R = R[:, begin_R:begin_R + n].view(-1, mul_out, mul_in, len(l_filters))  # [batch, mul_out, mul_in, l_filter]
            if grad_R is not None:
                sub_grad_R = grad_R[:, begin_R:begin_R + n].view(-1, mul_out, mul_in, len(l_filters))  # [batch, mul_out, mul_in, l_filter]
            begin_R += n

            grad_K = grad_kernel[:, s_out, s_in].view(-1, mul_out, 2 * l_out + 1, mul_in, 2 * l_in + 1)

            sub_norm_coef = norm_coef[i, j]  # [batch]

            for k, l_filter in enumerate(l_filters):
                tmp = sum(2 * l + 1 for l in ctx.set_of_l_filters if l < l_filter)
                C = SO3.clebsch_gordan(l_out, l_in, l_filter, cached=True, like=grad_kernel)  # [m_out, m_in, m]

                if grad_Y is not None:
                    grad_Y[tmp:tmp + 2 * l_filter + 1] += torch.einsum("zuivj,ijk,zuv,z->kz", grad_K, C, sub_R[..., k], sub_norm_coef)
                if grad_R is not None:
                    sub_Y = Y[tmp:tmp + 2 * l_filter + 1]  # [m, batch]
                    sub_grad_R[..., k] = torch.einsum("zuivj,ijk,kz,z->zuv", grad_K, C, sub_Y, sub_norm_coef)

    del ctx
    return grad_Y, grad_R, None, None, None, None, None
def forward(ctx, Y, R, norm_coef, Rs_in, Rs_out, get_l_filters, set_of_l_filters):
    """
    :param Y: tensor [l_filter * m_filter, batch]
    :param R: tensor [batch, l_out * l_in * mul_out * mul_in * l_filter]
    :param norm_coef: tensor [l_out, l_in, batch]
    :return: tensor [batch, l_out * mul_out * m_out, l_in * mul_in * m_in]
    """
    ctx.Rs_in = Rs_in
    ctx.Rs_out = Rs_out
    ctx.get_l_filters = get_l_filters
    ctx.set_of_l_filters = set_of_l_filters

    # save necessary tensors for backward
    saved_Y = saved_R = None
    if Y.requires_grad:
        ctx.Y_shape = Y.shape
        saved_R = R
    if R.requires_grad:
        ctx.R_shape = R.shape
        saved_Y = Y
    ctx.save_for_backward(saved_Y, saved_R, norm_coef)

    batch = Y.shape[1]
    n_in = sum(mul * (2 * l + 1) for mul, l, _ in ctx.Rs_in)
    n_out = sum(mul * (2 * l + 1) for mul, l, _ in ctx.Rs_out)

    kernel = Y.new_zeros(batch, n_out, n_in)

    # note: for the normalization we assume that the variance of R[i] is one
    begin_R = 0

    begin_out = 0
    for i, (mul_out, l_out, p_out) in enumerate(ctx.Rs_out):
        s_out = slice(begin_out, begin_out + mul_out * (2 * l_out + 1))
        begin_out += mul_out * (2 * l_out + 1)

        begin_in = 0
        for j, (mul_in, l_in, p_in) in enumerate(ctx.Rs_in):
            s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))
            begin_in += mul_in * (2 * l_in + 1)

            l_filters = ctx.get_l_filters(l_in, p_in, l_out, p_out)
            if not l_filters:
                continue

            # extract the subset of the `R` that corresponds to the couple (l_out, l_in)
            n = mul_out * mul_in * len(l_filters)
            sub_R = R[:, begin_R:begin_R + n].contiguous().view(batch, mul_out, mul_in, -1)  # [batch, mul_out, mul_in, l_filter]
            begin_R += n

            sub_norm_coef = norm_coef[i, j]  # [batch]

            # note: I don't know if we can vectorize this for loop because [l_filter * m_filter] cannot be put into [l_filter, m_filter]
            K = 0
            for k, l_filter in enumerate(l_filters):
                tmp = sum(2 * l + 1 for l in ctx.set_of_l_filters if l < l_filter)
                sub_Y = Y[tmp:tmp + 2 * l_filter + 1]  # [m, batch]

                C = SO3.clebsch_gordan(l_out, l_in, l_filter, cached=True, like=kernel)  # [m_out, m_in, m]

                # note: The multiplication with `sub_R` could also be done outside of the for loop
                K += torch.einsum("ijk,kz,zuv,z->zuivj", (C, sub_Y, sub_R[..., k], sub_norm_coef))  # [batch, mul_out, m_out, mul_in, m_in]

            if isinstance(K, torch.Tensor):
                kernel[:, s_out, s_in] = K.contiguous().view_as(kernel[:, s_out, s_in])

    return kernel
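# Per batch element z, the einsum in forward assembles
#   K[z, u, i, v, j] = norm[z] * sum_k C[i, j, k] * Y[k, z] * R[z, u, v]
# i.e. the Clebsch-Gordan coefficients C couple the filter harmonics Y with the
# learned radial weights R into one (m_out, m_in) block per pair (u, v) of
# output and input multiplicity indices.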
def __repr__(self):
    return "{name} ({Rs_in} -> {Rs_out})".format(
        name=self.__class__.__name__,
        Rs_in=SO3.formatRs(self.Rs_in),
        Rs_out=SO3.formatRs(self.Rs_out),
    )
def __init__(self, alpha, beta, lmax):
    super().__init__()
    sh = torch.cat([SO3.spherical_harmonics(l, alpha, beta) for l in range(lmax + 1)])
    self.register_buffer("sh", sh)
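# The concatenated buffer stacks the (2l + 1) components of each order, so it
# holds sum_{l=0}^{lmax} (2l + 1) = (lmax + 1)**2 coefficients per (alpha, beta).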