示例#1
0
def spherical_harmonics_xyz_cuda(Rs, xyz):  # pragma: no cover
    """
    CUDA variant of spherical_harmonics_xyz.

    :param Rs: representation list; it is passed through ``rs.simplify``
        before use and is iterated as ``(mul, l, parity)`` triples
    :param xyz: tensor of vectors; it is flattened with ``reshape(-1, 3)``,
        so the last dimension is expected to hold the 3 coordinates
    :return: tensor whose leading dimensions are those of ``xyz`` (all but
        the last) and whose last dimension indexes the selected harmonics

    Each vector is divided by its 2-norm without any guard, so zero-length
    vectors are not special-cased here.
    """
    from e3nn import cuda_rsh  # pylint: disable=no-name-in-module, import-outside-toplevel

    Rs = rs.simplify(Rs)
    max_degree = rs.lmax(Rs)

    # Remember every dimension but the last so the result can be reshaped back.
    *leading_shape, _ = xyz.size()
    points = xyz.reshape(-1, 3)
    lengths = torch.norm(points, 2, -1, keepdim=True)
    directions = points / lengths

    # Buffer laid out as [filters, batch_size]; handed to the CUDA kernel as
    # its first argument (presumably filled in place -- the call's return
    # value is not used).
    harmonics = directions.new_empty(((max_degree + 1) ** 2, directions.size(0)))
    cuda_rsh.real_spherical_harmonics(harmonics, directions)

    # Sign correction per degree: (-1)^L, i.e. +1 for even L and -1 for odd L,
    # repeated over the 2L + 1 rows of that degree.  Per the original note this
    # combines the (pi - theta) factor (-1)^(L+m) with the 'quantum' norm (-1)^m.
    signs = []
    for degree in range(max_degree + 1):
        sign = 1. if degree % 2 == 0 else -1.
        signs.extend([sign] * (2 * degree + 1))
    harmonics.mul_(harmonics.new_tensor(signs).unsqueeze(1))

    # When Rs is not simply one copy of every degree 0..lmax, pick the rows of
    # each requested degree and repeat them according to its multiplicity.
    if not rs.are_equal(Rs, list(range(max_degree + 1))):
        blocks = []
        for mul, degree, _ in Rs:
            rows = harmonics[degree ** 2: (degree + 1) ** 2]
            blocks.extend([rows] * mul)
        harmonics = torch.cat(blocks)

    n_channels = harmonics.shape[0]
    return harmonics.T.reshape(*leading_shape, n_channels)
示例#2
0
def test_mul_and_dot():
    """Check ``*`` and ``dot`` on two SphericalTensor objects built from one-hot signals."""
    lmax = 4
    n_coeffs = (lmax + 1)**2

    # Two one-hot signals of length (lmax + 1)^2: index 0 and index 3.
    first_signal = torch.zeros(n_coeffs)
    second_signal = torch.zeros(n_coeffs)
    first_signal[0] = 1.
    second_signal[3] = 1.

    left = SphericalTensor(first_signal)
    right = SphericalTensor(second_signal)

    # The product's Rs must equal a single entry with l = 0 and parity 0 whose
    # multiplicity is rs.mul_dim of the left operand's Rs.
    product = left * right
    product_Rs = product.Rs
    expected_Rs = [(rs.mul_dim(left.Rs), 0, 0)]
    assert rs.are_equal(product_Rs, expected_Rs)

    # Result is discarded: this only checks that dot() runs without raising.
    left.dot(right)