def spherical_harmonics_xyz(Rs, xyz):
    """
    Evaluate spherical harmonics along the directions of cartesian vectors.

    The vectors are divided by their euclidean norm before evaluation, so
    only their direction matters.

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3]
    :return: tensor of shape [..., m]
    """
    Rs = rs.simplify(Rs)

    # The compiled kernel is only attempted for CUDA tensors that do not
    # require grad and for lmax <= 10.
    cuda_eligible = (
        xyz.device.type == 'cuda'
        and not xyz.requires_grad
        and rs.lmax(Rs) <= 10
    )
    if cuda_eligible:  # pragma: no cover
        try:
            return spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            # extension not importable: fall through to the generic path
            pass

    *batch_dims, _ = xyz.shape
    unit = xyz.reshape(-1, 3)
    # the division creates a new tensor, so the in-place write below does
    # not touch the caller's data
    unit = unit / torch.norm(unit, 2, dim=1, keepdim=True)

    # if z > x, rotate x-axis with z-axis: (x, y, z) -> (-z, y, x)
    swapped = unit[:, 2].abs() > unit[:, 0].abs()
    zx_rotation = unit.new_tensor([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])
    unit[swapped] = unit[swapped] @ zx_rotation

    x_comp = unit[:, 0]
    y_comp = unit[:, 1]
    alpha = torch.atan2(y_comp, x_comp)
    z = unit[:, 2]
    y = (x_comp.pow(2) + y_comp.pow(2)).sqrt()

    sh = spherical_harmonics_alpha_z_y(Rs, alpha, z, y)

    # rotate back the rows that were rotated above
    # NOTE(review): _rep_zx is presumably the matching representation of
    # that rotation for Rs - confirm against its definition
    sh[swapped] = sh[swapped] @ _rep_zx(tuple(Rs), unit.dtype, unit.device)
    return sh.reshape(*batch_dims, sh.shape[1])
def spherical_harmonics_xyz_cuda(Rs, xyz):  # pragma: no cover
    """
    cuda version of spherical_harmonics_xyz

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3]; normalised to unit length here
    :return: tensor of shape [..., m]
    :raises ImportError: when ``e3nn.cuda_rsh`` cannot be imported
    """
    from e3nn import cuda_rsh  # pylint: disable=no-name-in-module, import-outside-toplevel

    Rs = rs.simplify(Rs)

    *batch_dims, _ = xyz.size()
    unit = xyz.reshape(-1, 3)
    unit = unit / torch.norm(unit, 2, -1, keepdim=True)

    lmax = rs.lmax(Rs)
    # the kernel fills every degree 0..lmax: [filters, batch_size]
    out = unit.new_empty(((lmax + 1) ** 2, unit.size(0)))
    cuda_rsh.real_spherical_harmonics(out, unit)

    # (-1)^L same as (pi-theta) -> (-1)^(L+m) and 'quantum' norm (-1)^m combined:
    # every one of the 2L+1 rows of degree L is scaled by +1 (L even) or -1 (L odd)
    signs = []
    for degree in range(lmax + 1):
        sign = 1. if degree % 2 == 0 else -1.
        signs.extend([sign] * (2 * degree + 1))
    out.mul_(out.new_tensor(signs).unsqueeze(1))

    # When Rs is not exactly the degrees 0..lmax, pick (and repeat, per
    # multiplicity) the row blocks that Rs asks for.
    if not rs.are_equal(Rs, list(range(lmax + 1))):
        blocks = []
        for mul, degree, _ in Rs:
            blocks.extend([out[degree ** 2:(degree + 1) ** 2]] * mul)
        out = torch.cat(blocks)

    return out.T.reshape(*batch_dims, out.shape[0])
def spherical_harmonics_xyz(Rs, xyz, eps=0.):
    """
    Evaluate spherical harmonics along the directions of cartesian vectors.

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3]
    :param eps: epsilon added to the norm, to the denominator of atan2 and
        under the square root
    :return: tensor of shape [..., m]

    The eps parameter is only to be used when backpropagating to the
    coordinates xyz. To determine a stable eps value, we recommend
    benchmarking against numerical gradients before setting this parameter.
    Use the smallest epsilon that prevents NaNs. For some cases, we have
    used 1e-10. Your case may require a different value. Use this option
    with care.

    eps is not forwarded to the CUDA fast path, which is only taken when
    xyz does not require grad.
    """
    cuda_eligible = (
        xyz.device.type == 'cuda'
        and not xyz.requires_grad
        and rs.lmax(Rs) <= 10
    )
    if cuda_eligible:  # pragma: no cover
        try:
            return spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            # extension not importable: use the generic path below
            pass

    length = torch.norm(xyz, 2, dim=-1, keepdim=True)
    unit = xyz / (length + eps)

    x_comp = unit[..., 0]
    y_comp = unit[..., 1]

    alpha = torch.atan2(y_comp, x_comp + eps)  # [...]
    z = unit[..., 2]  # [...]
    y = (x_comp.pow(2) + y_comp.pow(2) + eps).sqrt()  # [...]
    return spherical_harmonics_alpha_z_y(Rs, alpha, z, y)
def spherical_harmonics_alpha_z_y(Rs, alpha, z, y):
    """
    cpu version of spherical_harmonics_alpha_beta

    alpha, z and y are flattened before evaluation; the result is reshaped
    to ``alpha.shape`` followed by one trailing dimension.

    :param Rs: list of L's
    :param alpha: tensor of shape [...]
    :param z: tensor, flattened alongside alpha
    :param y: tensor, flattened alongside alpha
    :return: tensor of shape [..., m]
    """
    Rs = rs.simplify(Rs)
    lmax = rs.lmax(Rs)

    alpha_part = spherical_harmonics_alpha(lmax, alpha.flatten())  # [z, m]
    z_part = spherical_harmonics_z(Rs, z.flatten(), y.flatten())  # [z, l * m]

    combined = mul_m_lm(Rs, alpha_part, z_part)
    return combined.reshape(*alpha.shape, z_part.shape[1])
def spherical_harmonics_xyz(Rs, xyz, normalization='none'):
    """
    Evaluate spherical harmonics along the directions of cartesian vectors.

    Vectors with zero norm have no direction; for those rows the output is
    ``sqrt(1 / (4 pi))`` in every l = 0 component and zero in all others
    (before the optional normalization below is applied).

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3]
    :param normalization: one of

        * ``'none'``: no extra scaling
        * ``'component'``: multiply everything by ``sqrt(4 pi)``
        * ``'norm'``: multiply each degree-l block by ``sqrt(4 pi / (2l + 1))``
    :return: tensor of shape [..., m]
    :raises ValueError: if normalization is not one of the names above
    """
    # Reject unknown names up front: a typo would otherwise silently behave
    # like 'none'.
    if normalization not in ('none', 'component', 'norm'):
        raise ValueError(
            "normalization must be 'none', 'component' or 'norm', got {!r}".format(normalization)
        )

    Rs = rs.simplify(Rs)

    *size, _ = xyz.shape
    xyz = xyz.reshape(-1, 3)
    d = torch.norm(xyz, 2, dim=1)
    nonzero = d > 0
    # only rows with a well-defined direction are evaluated
    xyz = xyz[nonzero]
    xyz = xyz / d[nonzero, None]

    sh = None

    if xyz.device.type == 'cuda' and not xyz.requires_grad and rs.lmax(Rs) <= 10:  # pragma: no cover
        try:
            sh = _spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            # extension not importable: use the generic path below
            pass

    if sh is None:
        # if z > x, rotate x-axis with z-axis: (x, y, z) -> (-z, y, x)
        s = xyz[:, 2].abs() > xyz[:, 0].abs()
        xyz[s] = xyz[s] @ xyz.new_tensor([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])

        alpha = torch.atan2(xyz[:, 1], xyz[:, 0])
        z = xyz[:, 2]
        y = xyz[:, :2].norm(dim=1)

        sh = spherical_harmonics_alpha_z_y(Rs, alpha, z, y)

        # rotate back
        sh[s] = sh[s] @ _rep_zx(tuple(Rs), xyz.dtype, xyz.device)

    if len(d) > len(sh):
        # Some rows were dropped above: scatter the computed rows back to
        # their original positions and fill the zero-norm rows.
        out = sh.new_zeros(len(d), sh.shape[1])
        out[d == 0] = math.sqrt(1 / (4 * math.pi)) * torch.cat([
            sh.new_ones(1) if l == 0 else sh.new_zeros(2 * l + 1)
            for mul, l, p in Rs
            for _ in range(mul)
        ])
        out[nonzero] = sh
        sh = out

    if normalization == 'component':
        sh.mul_(math.sqrt(4 * math.pi))
    elif normalization == 'norm':
        sh.mul_(torch.cat([
            math.sqrt(4 * math.pi / (2 * l + 1)) * sh.new_ones(2 * l + 1)
            for mul, l, p in Rs
            for _ in range(mul)
        ]))

    return sh.reshape(*size, sh.shape[1])
def spherical_harmonics_xyz(Rs, xyz):
    """
    Evaluate spherical harmonics along the directions of cartesian vectors.

    The vectors are divided by their euclidean norm, so only their
    direction is used.

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3]
    :return: tensor of shape [..., m]
    """
    cuda_eligible = (
        xyz.device.type == 'cuda'
        and not xyz.requires_grad
        and rs.lmax(Rs) <= 10
    )
    if cuda_eligible:  # pragma: no cover
        try:
            return spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            # extension not importable: use the generic path below
            pass

    unit = xyz / torch.norm(xyz, 2, dim=-1, keepdim=True)

    x_comp = unit[..., 0]
    y_comp = unit[..., 1]

    alpha = torch.atan2(y_comp, x_comp)  # [...]
    z = unit[..., 2]  # [...]
    y = (x_comp.pow(2) + y_comp.pow(2)).sqrt()  # [...]
    return spherical_harmonics_alpha_z_y(Rs, alpha, z, y)
def spherical_harmonics_alpha_beta(Rs, alpha, beta):
    """
    Evaluate spherical harmonics at the angles (alpha, beta).

    Plain Python floats are not handled here: the body reads ``.device``
    and ``.requires_grad`` and calls ``.cos()`` / ``.sin()`` on both
    arguments.

    :param Rs: list of L's
    :param alpha: tensor of shape [...]
    :param beta: tensor of shape [...]
    :return: tensor of shape [..., m]
    """
    cuda_eligible = (
        alpha.device.type == 'cuda'
        and beta.device.type == 'cuda'
        and not (alpha.requires_grad or beta.requires_grad)
        and rs.lmax(Rs) <= 10
    )
    if cuda_eligible:  # pragma: no cover
        # the CUDA entry point takes cartesian vectors, so build the
        # direction (sin b cos a, sin b sin a, cos b) from the angles
        sin_beta = beta.sin()
        xyz = torch.stack(
            [sin_beta * alpha.cos(), sin_beta * alpha.sin(), beta.cos()],
            dim=-1,
        )
        try:
            return spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            # extension not importable: use the generic path below
            pass

    z = beta.cos()
    y = beta.sin().abs()
    return spherical_harmonics_alpha_z_y(Rs, alpha, z, y)