Ejemplo n.º 1
0
def initialize_edges(x,
                     Rs_in,
                     pos,
                     edge_index_dict,
                     lmax,
                     self_edge=1.,
                     symmetric_edges=False):
    """Initialize edge features of DataEdgeNeighbors using node features and SphericalTensor.

    For every edge, the target and source node features are contracted with
    the change-of-basis tensor ``Q`` from ``rs.reduce_tensor``. The result is
    then combined, through ``rs.TensorProduct``, with the spherical-harmonic
    signal of the displacement ``pos[source] - pos[target]``.

    Args:
        x (torch.tensor shape [N, rs.dim(Rs_in)]): Node features.
        Rs_in (rs.TY_RS_STRICT): Representation list of input.
        pos (torch.tensor shape [N, 3]): Cartesian coordinates of nodes.
        edge_index_dict (dict): Maps ``(target, source)`` node-index pairs to a
            sortable value. Edges are emitted in ascending order of that value
            (presumably an edge index -- confirm against callers).
        lmax (int > 0): Maximum L to use for SphericalTensor projection of radial distance vectors
        self_edge (float, optional): L=0 feature used for edges whose displacement
            vector is (numerically) zero. Defaults to 1.
        symmetric_edges (bool, optional): Constrain edge features to be symmetric in node index. Defaults to False

    Returns:
        edge_x: Edge features, one entry per edge stacked along dim 0 in the
            order described above. If ``edge_index_dict`` is empty, this is an
            empty tensor of shape ``[0, rs.dim(Rs_edge)]``.
        Rs_edge (rs.TY_RS_STRICT): Representation list of edge features.
    """
    from e3nn.tensor import SphericalTensor

    if symmetric_edges:
        Rs, Q = rs.reduce_tensor('ij=ji', i=Rs_in)
    else:
        Rs, Q = rs.reduce_tensor('ij', i=Rs_in, j=Rs_in)
    Q = Q.reshape(-1, rs.dim(Rs_in), rs.dim(Rs_in))
    Rs_sph = [(1, l, (-1)**l) for l in range(lmax + 1)]
    tp_kernel = rs.TensorProduct(Rs, Rs_sph, o3.selection_rule)

    if not edge_index_dict:
        # torch.stack([]) raises, so return an explicit empty batch instead.
        return x.new_zeros((0, rs.dim(tp_kernel.Rs_out))), tp_kernel.Rs_out

    edge_x = []
    # Emit edges in ascending order of their dict value (stable for ties).
    sorted_edges = sorted(edge_index_dict.items(), key=lambda item: item[1])
    for (target, source), _ in sorted_edges:
        Ia = x[target]
        Ib = x[source]
        vector = (pos[source] - pos[target]).reshape(-1, 3)
        # zeros_like keeps dtype/device consistent with `vector`; a default
        # torch.zeros can mismatch (float64 pos, or pos on a GPU).
        if torch.allclose(vector, torch.zeros_like(vector)):
            # Zero displacement has no direction, so only an L=0 signal is used.
            signal = torch.zeros(rs.dim(Rs_sph))
            signal[0] = self_edge
        else:
            signal = SphericalTensor.from_geometry(vector, lmax=lmax).signal
            if symmetric_edges:
                # Average the signals for v and -v so the feature does not
                # depend on edge orientation.
                flipped = SphericalTensor.from_geometry(-vector,
                                                        lmax=lmax).signal
                signal = 0.5 * (signal + flipped)
        output = torch.einsum('kij,i,j->k', Q, Ia, Ib)
        output = tp_kernel(output, signal)
        edge_x.append(output)
    edge_x = torch.stack(edge_x, dim=0)
    return edge_x, tp_kernel.Rs_out
Ejemplo n.º 2
0
def test_reduce_tensor_Levi_Civita_symbol():
    """A fully antisymmetric rank-3 tensor of L=1 inputs reduces to a single
    scalar, and its change-of-basis tensor is invariant under rotation."""
    Rs, Q = rs.reduce_tensor('ijk=-ikj=-jik', i=[(1, 1)])
    assert Rs == [(1, 0, 0)]

    angles = o3.rand_angles()
    rot = o3.irr_repr(1, *angles)
    epsilon = Q.reshape(3, 3, 3)
    # Apply the same rotation to every index; the output index order is explicit.
    rotated = torch.einsum('li,mj,nk,ijk->lmn', rot, rot, rot, epsilon)
    assert (rotated - epsilon).abs().max() < 1e-10
Ejemplo n.º 3
0
 def to_irrep_transformation(self):
     """Return ``(Rs_out, Q)`` relating ``self.tensor`` components to irreps.

     Every index letter in the left-most term of ``self.formula`` is assigned
     the representation ``[(1, 1)]`` (a vector). The matrix returned by
     ``rs.reduce_tensor`` is right-multiplied by the Kronecker product of one
     ``o3.xyz_to_irreducible_basis()`` per tensor dimension.
     """
     n_indices = self.tensor.dim()
     basis_change = o3.kron(*([o3.xyz_to_irreducible_basis()] * n_indices))
     vector_rs = [(1, 1)]
     leading_term = self.formula.split("=")[0]
     index_reps = {index: vector_rs for index in leading_term}
     Rs_out, Q = rs.reduce_tensor(self.formula, **index_reps)
     side = 3 ** n_indices
     return Rs_out, torch.mm(Q, basis_change.reshape(side, side))
Ejemplo n.º 4
0
def test_reduce_tensor_antisymmetric_L2():
    """The antisymmetric rank-3 tensor of L=2 inputs starts with an L=1 block
    that is rotation-equivariant and antisymmetric in every index pair."""
    Rs, Q = rs.reduce_tensor('ijk=-ikj=-jik', i=[(1, 2)])
    assert Rs[0] == (1, 1, 0)
    # The leading L=1 component occupies the first 3 rows of Q.
    block = Q[:3].reshape(3, 5, 5, 5)

    angles = o3.rand_angles()
    d_out = o3.irr_repr(1, *angles)
    d_in = o3.irr_repr(2, *angles)
    # Equivariance: rotating the three inputs matches rotating the output.
    rotated_inputs = torch.einsum('il,jm,kn,zijk->zlmn', d_in, d_in, d_in, block)
    rotated_output = torch.einsum('yz,zijk->yijk', d_out, block)
    assert (rotated_inputs - rotated_output).abs().max() < 1e-10

    # Swapping any pair of tensor indices flips the sign.
    for dim_a, dim_b in ((1, 2), (1, 3), (3, 2)):
        assert (block + block.transpose(dim_a, dim_b)).abs().max() < 1e-10
Ejemplo n.º 5
0
def test_reduce_tensor_elasticity_tensor():
    """The elasticity-tensor index symmetries leave 21 independent components."""
    irreps, _ = rs.reduce_tensor('ijkl=jikl=klij', i=[(1, 1)])
    assert rs.dim(irreps) == 21
Ejemplo n.º 6
0
def test_reduce_tensor_elasticity_tensor_parity():
    """With odd-parity vector inputs, every irrep of the 21-dimensional
    elasticity tensor has even parity."""
    irreps, _ = rs.reduce_tensor('ijkl=jikl=klij', i=[(1, 1, -1)])
    assert all(parity == 1 for _mul, _l, parity in irreps)
    assert rs.dim(irreps) == 21