def spherical_harmonics_xyz_cuda(Rs, xyz):  # pragma: no cover
    """
    CUDA version of spherical_harmonics_xyz.

    :param Rs: representation list; it is simplified with ``rs.simplify`` and
        then read as ``(mul, l, _)`` triples when assembling the output.
    :param xyz: tensor whose last dimension has size 3; all leading dimensions
        are treated as batch dimensions.
    :return: tensor of shape ``(*xyz.shape[:-1], dim)``, where ``dim`` is the sum
        of ``mul * (2 * l + 1)`` over ``Rs``. Input vectors are divided by their
        norm first, so a zero-length vector yields NaN entries.
    """
    from e3nn import cuda_rsh  # pylint: disable=no-name-in-module, import-outside-toplevel

    Rs = rs.simplify(Rs)
    *batch_shape, _ = xyz.size()

    # Flatten the batch dimensions and project every point onto the unit sphere.
    points = xyz.reshape(-1, 3)
    points = points / points.norm(p=2, dim=-1, keepdim=True)

    lmax = rs.lmax(Rs)
    # The kernel writes into a [(lmax + 1)**2, n_points] buffer (components first).
    harmonics = points.new_empty(((lmax + 1) ** 2, points.size(0)))
    cuda_rsh.real_spherical_harmonics(harmonics, points)

    # Sign fix-up of (-1)^L per degree, repeated over its 2L+1 components.
    # The theta -> pi - theta change contributes (-1)^(L+m), and the 'quantum'
    # normalization contributes (-1)^m. Combined, they reduce to (-1)^L.
    parity_sign = [(-1.) ** l for l in range(lmax + 1) for _ in range(2 * l + 1)]
    harmonics.mul_(harmonics.new_tensor(parity_sign).unsqueeze(1))

    # Reorder/duplicate degree blocks only when Rs is not already [0, 1, ..., lmax].
    if not rs.are_equal(Rs, list(range(lmax + 1))):
        harmonics = torch.cat([
            harmonics[l ** 2:(l + 1) ** 2]
            for mul, l, _ in Rs
            for _ in range(mul)
        ])

    # Back to points-first layout and restore the original batch dimensions.
    return harmonics.T.reshape(*batch_shape, harmonics.shape[0])
def test_mul_and_dot():
    """Multiply and dot two SphericalTensors that each have one non-zero entry.

    The product's Rs must equal ``[(rs.mul_dim(a.Rs), 0, 0)]``. The ``dot`` call
    is only exercised (it must not raise); its result is not checked.
    """
    lmax = 4
    dim = (lmax + 1) ** 2

    first = torch.zeros(dim)
    first[0] = 1.  # presumably the l=0 component
    second = torch.zeros(dim)
    second[3] = 1.  # presumably one of the l=1 components (indices 1..3)

    a = SphericalTensor(first)
    b = SphericalTensor(second)

    product = a * b
    assert rs.are_equal(product.Rs, [(rs.mul_dim(a.Rs), 0, 0)])

    a.dot(b)