def initialize_edges(x, Rs_in, pos, edge_index_dict, lmax, self_edge=1., symmetric_edges=False):
    """Initialize edge features of DataEdgeNeighbors using node features and SphericalTensor.

    For every edge ``(target, source)`` the node features ``x[target]`` and
    ``x[source]`` are contracted with the change-of-basis tensor from
    ``rs.reduce_tensor`` (symmetrized over ``ij=ji`` when ``symmetric_edges``),
    and the result is combined through ``rs.TensorProduct`` with the spherical
    tensor projection (up to ``lmax``) of ``pos[source] - pos[target]``.

    Args:
        x (torch.tensor shape [N, rs.dim(Rs_in)]): Node features.
        Rs_in (rs.TY_RS_STRICT): Representation list of input.
        pos (torch.tensor shape [N, 3]): Cartesian coordinates of nodes.
        edge_index_dict (dict): Maps ``(target, source)`` node-index pairs to a
            sortable value; edges are emitted in ascending order of that value
            (presumably the edge's index -- confirm against callers).
        lmax (int > 0): Maximum L to use for SphericalTensor projection of radial distance vectors.
        self_edge (float, optional): L=0 feature used when the displacement vector
            between source and target is (all-close to) zero. Defaults to 1.
        symmetric_edges (bool, optional): Constrain edge features to be symmetric in
            node index. Defaults to False.

    Returns:
        edge_x (torch.tensor): Per-edge features stacked along dim 0, in the order
            described above. When ``edge_index_dict`` is empty, a tensor with zero
            rows and ``rs.dim(Rs_edge)`` columns is returned.
        Rs_edge (rs.TY_RS_STRICT): Representation list of edge features.
    """
    from e3nn.tensor import SphericalTensor

    if symmetric_edges:
        Rs, Q = rs.reduce_tensor('ij=ji', i=Rs_in)
    else:
        Rs, Q = rs.reduce_tensor('ij', i=Rs_in, j=Rs_in)
    Q = Q.reshape(-1, rs.dim(Rs_in), rs.dim(Rs_in))
    Rs_sph = [(1, l, (-1)**l) for l in range(lmax + 1)]
    tp_kernel = rs.TensorProduct(Rs, Rs_sph, o3.selection_rule)

    if not edge_index_dict:
        # No edges: return an empty feature matrix instead of failing on
        # dict unpacking / torch.stack([]). Assumes each per-edge output is a
        # 1-D vector of size rs.dim(Rs_out), as produced for non-empty inputs.
        empty = torch.zeros(0, rs.dim(tp_kernel.Rs_out), dtype=x.dtype, device=x.device)
        return empty, tp_kernel.Rs_out

    # Sort edges by their dict value (stable, so ties keep insertion order).
    sorted_edges = sorted(edge_index_dict.items(), key=lambda item: item[1])

    edge_x = []
    for (target, source), _ in sorted_edges:
        Ia = x[target]
        Ib = x[source]
        vector = (pos[source] - pos[target]).reshape(-1, 3)
        # zeros_like keeps dtype/device of `vector`; torch.allclose requires both to match.
        if torch.allclose(vector, torch.zeros_like(vector)):
            # Self edge: direction is undefined, so only an L=0 component is set.
            signal = torch.zeros(rs.dim(Rs_sph), dtype=vector.dtype, device=vector.device)
            signal[0] = self_edge
        else:
            signal = SphericalTensor.from_geometry(vector, lmax=lmax).signal
            if symmetric_edges:
                # Average the projections of +vector and -vector so the edge
                # feature does not depend on edge direction.
                signal += SphericalTensor.from_geometry(-vector, lmax=lmax).signal
                signal *= 0.5
        output = torch.einsum('kij,i,j->k', Q, Ia, Ib)
        output = tp_kernel(output, signal)
        edge_x.append(output)
    edge_x = torch.stack(edge_x, dim=0)
    return edge_x, tp_kernel.Rs_out
def test_reduce_tensor_Levi_Civita_symbol():
    """Antisymmetric rank-3 tensor of L=1 inputs reduces to a single invariant irrep.

    Checks that the reduction yields exactly ``[(1, 0, 0)]`` and that the
    resulting 3x3x3 tensor is unchanged by a random rotation applied to
    all three indices.
    """
    Rs, Q = rs.reduce_tensor('ijk=-ikj=-jik', i=[(1, 1)])
    assert Rs == [(1, 0, 0)]

    angles = o3.rand_angles()
    rot = o3.irr_repr(1, *angles)
    levi_civita = Q.reshape(3, 3, 3)
    # Explicit '->lmn' matches einsum's implicit (alphabetical) output order.
    rotated = torch.einsum('li,mj,nk,ijk->lmn', rot, rot, rot, levi_civita)
    assert (rotated - levi_civita).abs().max() < 1e-10
def to_irrep_transformation(self):
    """Return the irrep decomposition of ``self.formula`` and its change of basis.

    Every index in the first term of ``self.formula`` is treated as an L=1
    representation (``[(1, 1)]``). The matrix from ``rs.reduce_tensor`` is
    composed with the Kronecker power (one factor per tensor dimension) of
    ``o3.xyz_to_irreducible_basis()``.

    Returns:
        (Rs_out, matrix): the reduced representation list and the composed
        transformation ``Q @ change`` with ``change`` of shape (3**dim, 3**dim).
    """
    order = self.tensor.dim()
    xyz_basis = o3.xyz_to_irreducible_basis()
    kron_change = o3.kron(*([xyz_basis] * order))

    vector_rs = [(1, 1)]  # vectors
    first_term = self.formula.partition("=")[0]
    index_reps = {index: vector_rs for index in first_term}
    Rs_out, Q = rs.reduce_tensor(self.formula, **index_reps)

    side = 3 ** order
    flat_change = kron_change.reshape(side, side)
    return Rs_out, torch.einsum('ab,bc->ac', Q, flat_change)
def test_reduce_tensor_antisymmetric_L2():
    """Antisymmetric rank-3 tensor of L=2 inputs: check the leading L=1 block.

    The first irrep must be ``(1, 1, 0)``. Its 3 rows of ``Q`` are checked for
    equivariance, comparing D2 applied to all three L=2 indices against D1
    applied to the output index. They are also checked for antisymmetry under
    each transposition of the L=2 indices.
    """
    Rs, Q = rs.reduce_tensor('ijk=-ikj=-jik', i=[(1, 2)])
    assert Rs[0] == (1, 1, 0)

    block = Q[:3].reshape(3, 5, 5, 5)
    angles = o3.rand_angles()
    d_l1 = o3.irr_repr(1, *angles)
    d_l2 = o3.irr_repr(2, *angles)
    lhs = torch.einsum('il,jm,kn,zijk->zlmn', d_l2, d_l2, d_l2, block)
    rhs = torch.einsum('yz,zijk->yijk', d_l1, block)
    assert (lhs - rhs).abs().max() < 1e-10

    # Antisymmetry, checked in the same order as the original assertions.
    for dim_a, dim_b in ((1, 2), (1, 3), (3, 2)):
        assert (block + block.transpose(dim_a, dim_b)).abs().max() < 1e-10
def test_reduce_tensor_elasticity_tensor():
    """Rank-4 tensor symmetric under i<->j and (ij)<->(kl) has 21 components."""
    formula = 'ijkl=jikl=klij'
    irreps, _ = rs.reduce_tensor(formula, i=[(1, 1)])
    assert rs.dim(irreps) == 21
def test_reduce_tensor_elasticity_tensor_parity():
    """With L=1, parity -1 inputs, every output irrep of the elasticity tensor has parity 1.

    The total dimension must still be 21.
    """
    irreps, _ = rs.reduce_tensor('ijkl=jikl=klij', i=[(1, 1, -1)])
    parities = [parity for _mul, _l, parity in irreps]
    assert all(parity == 1 for parity in parities)
    assert rs.dim(irreps) == 21