def test_map_mul_to_Rs():
    """map_mul_to_Rs should scatter each multiplicity channel onto the dims of its irrep."""
    with o3.torch_default_dtype(torch.float64):
        # Three scalar channels: one dimension per channel, i.e. the identity.
        scalars_only = rs.map_mul_to_Rs([(3, 0)])
        assert torch.allclose(scalars_only, torch.eye(3))

        # One copy each of L = 0, 1, 2 -> 1 + 3 + 5 dims, one channel per irrep.
        mixed = rs.map_mul_to_Rs([(1, 0), (1, 1), (1, 2)])
        expected = torch.zeros(1 + 3 + 5, 3)
        for channel, (lo, hi) in enumerate([(0, 1), (1, 4), (4, 9)]):
            expected[lo:hi, channel] = 1.
        assert torch.allclose(mixed, expected)
def from_geometry(cls, vectors, radial_model, lmax, sum_points=True):
    """
    Build a signal from point vectors: radial channels times spherical harmonics.

    :param vectors: tensor of shape [..., xyz]
    :param radial_model: function of signature R+ -> R^mul, applied to the vector norms
    :param lmax: maximal order of the signal
    :param sum_points: if True, sum the per-point signals; otherwise restore the
        leading shape of ``vectors`` as ``[..., signal]``
    :return: ``cls(signal, mul, lmax)`` with ``radial_model`` attached
    """
    batch_shape = vectors.shape[:-1]
    points = vectors.reshape(-1, 3)  # [N, 3]
    norms = points.norm(2, -1)
    radial = radial_model(norms)
    n_radial = radial.shape[-1]
    Rs = [(n_radial, L) for L in range(lmax + 1)]

    # Every L shares the same radial channels: tile them once per L, then scatter
    # each channel onto the 2L+1 components of its irrep.
    tiled = radial.repeat(1, lmax + 1)
    radial = torch.einsum('nr,dr->nd', tiled, rs.map_mul_to_Rs(Rs))  # [N, signal]

    # NOTE(review): zero-length vectors divide by zero here.
    directions = points / norms.unsqueeze(-1)
    angular = projection(directions, lmax)  # [N, l * m]
    angular = torch.einsum('nc,dc->nd', angular, rs.map_irrep_to_Rs(Rs))  # [N, l * mul * m]

    signal = angular * radial  # [N, l * mul * m]
    signal = signal.sum(0) if sum_points else signal.reshape(*batch_shape, -1)

    result = cls(signal, n_radial, lmax)
    result.radial_model = radial_model
    return result
def plot_data_on_grid(box_length, radial, Rs, sh=o3.spherical_harmonics_xyz, n=30):
    """
    Evaluate a radial x spherical-harmonic basis on a cubic grid.

    :param box_length: edge length of the cube (centered on the origin) to sample
    :param radial: callable mapping a tensor of radii to ``[points, mul_dim(Rs)]``
    :param Rs: list of ``(mul, L)`` pairs describing the basis
    :param sh: spherical harmonics function, called as ``sh(ls, xyz)``
    :param n: number of grid points along each axis
    :return: ``(r, all_f)`` with ``r`` of shape ``[n**3, 3]`` and ``all_f`` of
        shape ``[n**3, dim(Rs)]``
    """
    # Sorted L values give a deterministic row layout for Ys; L_to_index records
    # the [start, stop) rows belonging to each L in that same layout.
    ls = sorted({L for _, L in Rs})
    L_to_index = {}
    start = 0
    for L in ls:
        L_to_index[L] = (start, start + 2 * L + 1)
        start += 2 * L + 1

    r = np.mgrid[-1:1:n * 1j, -1:1:n * 1j, -1:1:n * 1j].reshape(3, -1)
    r = r.transpose(1, 0)
    r *= box_length / 2.
    r = torch.from_numpy(r)

    Ys = sh(ls, r)
    R = radial(r.norm(2, -1)).detach()  # [r_values, n_r_filters]
    assert R.shape[-1] == rs.mul_dim(Rs), "radial must output rs.mul_dim(Rs) channels"

    # Repeat each irrep's Y rows once per multiplicity copy, in Rs order.
    Ys_indices = []
    for mul, L in Rs:
        lo, hi = L_to_index[L]
        Ys_indices += list(range(lo, hi)) * mul

    R_helper = rs.map_mul_to_Rs(Rs)  # [dim(Rs), mul_dim(Rs)]
    full_Ys = Ys[Ys_indices]  # [dim(Rs), values]
    full_Ys = full_Ys.reshape(full_Ys.shape[0], -1)

    all_f = torch.einsum('xn,dn,dx->xd', R, R_helper, full_Ys)
    return r, all_f
def from_geometry_with_radial(cls, vectors, radial_model, lmax, sum_points=True):
    """
    Build a signal from point vectors, keeping the radial model on the result.

    Vectors are flattened to ``[N, 3]``. When ``sum_points`` is False the signal
    stays flattened as ``[N, signal]``; otherwise it is summed over points.
    """
    points = vectors.reshape(-1, 3)  # [N, 3]
    norms = points.norm(2, -1)
    radial = radial_model(norms)
    _, n_radial = radial.shape
    Rs = [(n_radial, L) for L in range(lmax + 1)]

    # Same radial channels for every L: tile, then scatter onto irrep components.
    tiled = radial.repeat(1, lmax + 1)
    radial = torch.einsum('nr,dr->nd', tiled, rs.map_mul_to_Rs(Rs))  # [N, signal]

    # NOTE(review): zero-length vectors divide by zero here.
    angular = projection(points / norms.unsqueeze(-1), lmax)  # [N, l * m]
    angular = torch.einsum('nc,dc->nd', angular, rs.map_irrep_to_Rs(Rs))  # [N, l * mul * m]

    signal = angular * radial  # [N, l * mul * m]
    signal = signal.sum(0) if sum_points else signal

    result = cls(signal, n_radial, lmax)
    result.radial_model = radial_model
    return result
def kernel_geometric(Rs_in, Rs_out, selection_rule=o3.selection_rule_in_out_sh, normalization='component'):
    """
    Build the geometric part of a convolution kernel.

    :param Rs_in: input representation
    :param Rs_out: output representation
    :param selection_rule: rule passed to ``rs.tensor_product`` to pick filter irreps
    :param normalization: ``'component'`` or ``'norm'``; scales the spherical harmonics
    :return: ``(Rs_f, mixing_matrix)`` where ``mixing_matrix`` has shape
        ``[out, in, Y, R]``
    :raises ValueError: if ``normalization`` is not ``'component'`` or ``'norm'``
    """
    # Previously an unknown value left `diag` unbound further down; fail early instead.
    if normalization not in ('component', 'norm'):
        raise ValueError(
            "normalization must be 'component' or 'norm', got {!r}".format(normalization))

    # Compute Clebsch-Gordan coefficients
    Rs_f, Q = rs.tensor_product(Rs_in, selection_rule, Rs_out, normalization)  # [out, in, Y]

    # Sort filters representation
    Rs_f, perm = rs.sort(Rs_f)
    Rs_f = rs.simplify(Rs_f)
    Q = torch.einsum('ijk,lk->ijl', Q, perm)
    del perm

    # Normalize the spherical harmonics
    if normalization == 'component':
        diag = torch.ones(rs.irrep_dim(Rs_f))
    else:  # 'norm': scale each irrep's 2l+1 components by 1/sqrt(2l+1)
        diag = torch.cat(
            [torch.ones(2 * l + 1) / math.sqrt(2 * l + 1) for _, l, _ in Rs_f])
    norm_Y = math.sqrt(4 * math.pi) * torch.diag(diag)  # [Y, Y]

    # Matrix to dispatch the spherical harmonics
    mat_Y = rs.map_irrep_to_Rs(Rs_f)  # [Rs_f, Y]
    mat_Y = mat_Y @ norm_Y

    # Create the radial model: R+ -> R^n_path
    mat_R = rs.map_mul_to_Rs(Rs_f)  # [Rs_f, R]

    mixing_matrix = torch.einsum('ijk,ky,kw->ijyw', Q, mat_Y, mat_R)  # [out, in, Y, R]
    return Rs_f, mixing_matrix
def __mul__(self, other):
    """
    Dot product of two tensors with matching representations.

    Both operands are first brought to a common ``lmax``. The elementwise product
    is then summed over the components of each multiplicity channel. The result
    is an ``IrrepTensor`` of scalars: one per channel, with parity ``p1 * p2``.

    :raises ValueError: if the (mul, L) structure of the two Rs differs
    """
    lmax = max(self.lmax, other.lmax)
    new_self = self.change_lmax(lmax)
    new_other = other.change_lmax(lmax)

    pairs = list(zip(new_self.Rs, new_other.Rs))
    if len(new_self.Rs) != len(new_other.Rs) or any(
            (mul1, l1) != (mul2, l2) for (mul1, l1, _), (mul2, l2, _) in pairs):
        raise ValueError("Representations do not match.")

    mult = new_self.signal * new_other.signal
    mapping_matrix = rs.map_mul_to_Rs(new_self.Rs)  # [dim, mul_dim]
    scalars = torch.einsum('rm,...r->...m', mapping_matrix, mult)  # [..., mul_dim]

    # One scalar per multiplicity channel. The mul must be carried over (not
    # hard-coded to 1) so the Rs dimension matches the size of `scalars`.
    Rs = [(mul, 0, p1 * p2) for (mul, _, p1), (_, _, p2) in pairs]
    return IrrepTensor(scalars, Rs)
def __mul__(self, other):
    """
    Dot product of two spherical tensors with equal multiplicity.

    Both operands are padded to a common ``lmax``. The product is then reduced to
    one scalar per multiplicity channel, i.e. ``mul * (lmax + 1)`` scalars.

    :raises ValueError: if the multiplicities differ
    """
    if self.mul != other.mul:
        raise ValueError("Multiplicities do not match.")

    common_lmax = max(self.lmax, other.lmax)
    lhs = self.change_lmax(common_lmax)
    rhs = other.change_lmax(common_lmax)

    products = lhs.signal * rhs.signal
    reducer = rs.map_mul_to_Rs(lhs.Rs)  # [dim, mul_dim]
    scalars = torch.einsum('rm,r->m', reducer, products)

    n_scalars = lhs.mul * (lhs.lmax + 1)
    return SphericalTensor(scalars, mul=n_scalars, lmax=0)
def from_geometry_with_radial(cls, vectors, radial_model, L_max, sum_points=True):
    """
    Build a signal from 2-D point vectors, keeping the radial model on the result.

    The radial outputs are unpacked as ``[N, R]``; the radial and angular parts are
    combined in a single einsum. The result is summed over points if ``sum_points``.
    """
    radii = vectors.norm(2, -1)
    radial = radial_model(radii)  # [N, R]
    _, n_radial = radial.shape
    Rs = [(n_radial, L) for L in range(L_max + 1)]

    Ys = projection(vectors, L_max, sum_points=False, radius=False)  # [channels, N]
    mul_map = rs.map_mul_to_Rs(Rs)
    irrep_map = rs.map_irrep_to_Rs(Rs)

    # Tile radial channels once per L, then contract radial and angular parts together.
    tiled = radial.repeat(1, len(Rs))
    signal = torch.einsum('nr,cn,dr,dc->nd', tiled, Ys, mul_map, irrep_map)
    signal = signal.sum(0) if sum_points else signal

    result = cls(signal, Rs)
    result.radial_model = radial_model
    return result