def spherical_harmonics_xyz_cuda(ls, xyz):
    """
    cuda version of spherical_harmonics_xyz

    Evaluates real spherical harmonics with the ``e3nn.cuda_rsh`` kernel for
    every degree up to ``max(ls)``, applies a per-degree sign correction and
    keeps only the degrees requested in ``ls`` (in the order given, repeats
    allowed).

    :param ls: iterable of non-negative integer degrees; must not be empty
    :param xyz: tensor of shape ``[..., 3]``. It is handed to the kernel as-is;
        this function does not normalize it.
        NOTE(review): presumably expected to be unit vectors living on a CUDA
        device -- confirm against the kernel.
    :return: tensor of shape ``[..., sum(2 * l + 1 for l in ls)]``
    :raises ValueError: if ``ls`` is empty
    """
    from e3nn import cuda_rsh  # pylint: disable=no-name-in-module, import-outside-toplevel

    # Materialize once: ``ls`` is traversed several times below, and the
    # "all degrees requested" shortcut compares it against a list.
    ls = list(ls)
    if not ls:
        raise ValueError("ls must contain at least one degree")

    *size, _ = xyz.size()
    xyz = xyz.reshape(-1, 3)

    lmax = max(ls)
    out = xyz.new_empty(((lmax + 1) ** 2, xyz.size(0)))  # [filters, batch_size]
    cuda_rsh.real_spherical_harmonics(out, xyz)

    # (-1)^L same as (pi-theta) -> (-1)^(L+m) and 'quantum' norm (-1)^m combined
    # lh is l halved: each step emits the (2l+1) rows of an even degree l = 2*lh
    # with sign +1, followed by the rows of the odd degree l = 2*lh + 1 with sign -1.
    norm_coef = [
        elem
        for lh in range((lmax + 1) // 2)
        for elem in [1.] * (4 * lh + 1) + [-1.] * (4 * lh + 3)
    ]
    if lmax % 2 == 0:
        # The loop above only covers complete (even, odd) pairs; an even lmax
        # leaves one trailing block of 2*lmax + 1 rows with sign +1.
        norm_coef.extend([1.] * (2 * lmax + 1))
    norm_coef = out.new_tensor(norm_coef).unsqueeze(1)  # same dtype/device as out
    out.mul_(norm_coef)

    if ls != list(range(lmax + 1)):
        # Rows l**2 .. (l + 1)**2 - 1 hold the 2l + 1 components of degree l.
        out = torch.cat([out[l ** 2:(l + 1) ** 2] for l in ls])

    # Use the explicit feature count rather than -1: reshape cannot infer a
    # -1 dimension when the batch is empty.
    return out.T.reshape(*size, out.shape[0])
def spherical_harmonics_xyz_cuda(Rs, xyz):  # pragma: no cover
    """
    cuda version of spherical_harmonics_xyz

    The input vectors are divided by their 2-norm, the ``e3nn.cuda_rsh``
    kernel is evaluated for every degree up to ``rs.lmax(Rs)``, a per-degree
    sign correction is applied, and the degree blocks are finally arranged as
    described by ``Rs`` (each degree repeated ``mul`` times).

    :param Rs: representation list; it is passed through ``rs.simplify`` and
        then iterated as ``(mul, l, parity)`` triples
    :param xyz: tensor of shape ``[..., 3]``
    :return: tensor of shape ``[..., sum(mul * (2 * l + 1))]``
    """
    from e3nn import cuda_rsh  # pylint: disable=no-name-in-module, import-outside-toplevel

    Rs = rs.simplify(Rs)
    lmax = rs.lmax(Rs)

    # Flatten all leading dimensions into one batch axis and normalize.
    *batch_shape, _ = xyz.size()
    points = xyz.reshape(-1, 3)
    points = points / torch.norm(points, 2, -1, keepdim=True)

    # The kernel fills a [filters, batch_size] buffer in place.
    out = points.new_empty(((lmax + 1) ** 2, points.size(0)))
    cuda_rsh.real_spherical_harmonics(out, points)

    # (-1)^L same as (pi-theta) -> (-1)^(L+m) and 'quantum' norm (-1)^m combined:
    # every one of the 2l + 1 rows of degree l is scaled by +1 for even l and
    # by -1 for odd l.
    signs = []
    for degree in range(lmax + 1):
        sign = 1. if degree % 2 == 0 else -1.
        signs.extend([sign] * (2 * degree + 1))
    out.mul_(out.new_tensor(signs).unsqueeze(1))

    # Nothing to gather when Rs is exactly one copy of each degree 0..lmax.
    wants_all_degrees = rs.are_equal(Rs, list(range(lmax + 1)))
    if not wants_all_degrees:
        blocks = []
        for mul, degree, _ in Rs:
            # Rows degree**2 .. (degree + 1)**2 - 1 belong to this degree.
            block = out[degree ** 2:(degree + 1) ** 2]
            blocks.extend([block] * mul)
        out = torch.cat(blocks)

    n_features = out.shape[0]
    return out.T.reshape(*batch_shape, n_features)