Example #1
import torch
from e3nn import rs

def spherical_harmonics_xyz(Rs, xyz):
    """
    spherical harmonics

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3]
    :return: tensor of shape [..., m]
    """
    Rs = rs.simplify(Rs)

    if xyz.device.type == 'cuda' and not xyz.requires_grad and rs.lmax(Rs) <= 10:  # pragma: no cover
        try:
            return spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            pass

    *size, _ = xyz.shape
    xyz = xyz.reshape(-1, 3)
    xyz = xyz / torch.norm(xyz, 2, dim=1, keepdim=True)

    # where |z| > |x|, swap the x- and z-axes so points stay away from the z-axis, where atan2(y, x) is ill-conditioned
    s = xyz[:, 2].abs() > xyz[:, 0].abs()
    xyz[s] = xyz[s] @ xyz.new_tensor([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])

    alpha = torch.atan2(xyz[:, 1], xyz[:, 0])
    z = xyz[:, 2]
    y = (xyz[:, 0].pow(2) + xyz[:, 1].pow(2)).sqrt()

    sh = spherical_harmonics_alpha_z_y(Rs, alpha, z, y)

    # rotate back
    sh[s] = sh[s] @ _rep_zx(tuple(Rs), xyz.dtype, xyz.device)
    return sh.reshape(*size, sh.shape[1])
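A minimal usage sketch for this variant, assuming torch, the e3nn rs module, and the helpers it calls (spherical_harmonics_alpha_z_y, _rep_zx) are available; the Rs value [0, 1, 2] is an arbitrary illustration:

# Hypothetical usage; Rs = [0, 1, 2] requests degrees l = 0, 1, 2.
points = torch.randn(16, 3)                     # batch of 16 arbitrary 3D points
sh = spherical_harmonics_xyz([0, 1, 2], points)
print(sh.shape)                                 # torch.Size([16, 9]); 1 + 3 + 5 = 9 components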
Example #2
import torch
from e3nn import rs

def spherical_harmonics_xyz_cuda(Rs, xyz):  # pragma: no cover
    """
    cuda version of spherical_harmonics_xyz
    """
    from e3nn import cuda_rsh  # pylint: disable=no-name-in-module, import-outside-toplevel

    Rs = rs.simplify(Rs)

    *size, _ = xyz.size()
    xyz = xyz.reshape(-1, 3)
    xyz = xyz / torch.norm(xyz, 2, -1, keepdim=True)

    lmax = rs.lmax(Rs)
    out = xyz.new_empty(((lmax + 1)**2, xyz.size(0)))  # [filters, batch_size]
    cuda_rsh.real_spherical_harmonics(out, xyz)

    # sign fix: the (pi - theta) convention gives (-1)^(l + m); combined with the 'quantum' norm factor (-1)^m this is (-1)^l per degree
    # lh runs over halved degrees ('h' = halved), each pass covering l = 2*lh (sign +1) and l = 2*lh + 1 (sign -1)
    norm_coef = [elem for lh in range((lmax + 1) // 2) for elem in [1.] * (4 * lh + 1) + [-1.] * (4 * lh + 3)]
    if lmax % 2 == 0:
        norm_coef.extend([1.] * (2 * lmax + 1))
    norm_coef = out.new_tensor(norm_coef).unsqueeze(1)
    out.mul_(norm_coef)

    if not rs.are_equal(Rs, list(range(lmax + 1))):
        # reorder and duplicate the (2l + 1)-sized blocks to match the requested Rs
        out = torch.cat([out[l**2: (l + 1)**2] for mul, l, _ in Rs for _ in range(mul)])

    return out.T.reshape(*size, out.shape[0])
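The sign layout built above can be checked in plain Python; a small sketch (standard library only) reproducing the per-degree (-1)^l pattern for lmax = 3:

lmax = 3
signs = [s for lh in range((lmax + 1) // 2) for s in [1.] * (4 * lh + 1) + [-1.] * (4 * lh + 3)]
if lmax % 2 == 0:
    signs.extend([1.] * (2 * lmax + 1))
# each degree l contributes a block of 2l + 1 entries with sign (-1)^l
expected = [(-1.) ** l for l in range(lmax + 1) for _ in range(2 * l + 1)]
assert signs == expected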
Example #3
import torch
from e3nn import rs

def spherical_harmonics_xyz(Rs, xyz, eps=0.):
    """
    spherical harmonics

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3]
    :param eps: epsilon for denominator of atan2
    :return: tensor of shape [..., m]

    The eps parameter should only be used when backpropagating to the coordinates xyz.
    To determine a stable value, we recommend benchmarking against numerical gradients
    and using the smallest epsilon that prevents NaNs. In some cases we have used 1e-10,
    but your case may require a different value. Use this option with care.
    """

    if xyz.device.type == 'cuda' and not xyz.requires_grad and rs.lmax(Rs) <= 10:  # pragma: no cover
        try:
            return spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            pass

    norm = torch.norm(xyz, 2, dim=-1, keepdim=True)
    xyz = xyz / (norm + eps)

    alpha = torch.atan2(xyz[..., 1], xyz[..., 0] + eps)  # [...]
    z = xyz[..., 2]  # [...]
    y = (xyz[..., 0].pow(2) + xyz[..., 1].pow(2) + eps).sqrt()  # [...]

    return spherical_harmonics_alpha_z_y(Rs, alpha, z, y)
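A sketch of the gradient-safe usage the docstring describes, assuming torch and the helpers above; eps = 1e-10 is the value the docstring mentions, not a universal recommendation:

# Hypothetical usage: backpropagate through the coordinates without NaNs.
points = torch.randn(8, 3, requires_grad=True)
sh = spherical_harmonics_xyz([0, 1], points, eps=1e-10)
sh.sum().backward()
assert not torch.isnan(points.grad).any()  # eps keeps the atan2/sqrt gradients finite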
Example #4
from e3nn import rs

def spherical_harmonics_alpha_z_y(Rs, alpha, z, y):
    """
    cpu version of spherical_harmonics_alpha_beta
    """
    Rs = rs.simplify(Rs)
    sha = spherical_harmonics_alpha(rs.lmax(Rs), alpha.flatten())  # [z, m]
    shz = spherical_harmonics_z(Rs, z.flatten(), y.flatten())  # [z, l * m]
    out = mul_m_lm(Rs, sha, shz)
    return out.reshape(alpha.shape + (shz.shape[1],))
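This helper relies on the separable form of the real spherical harmonics: an azimuthal factor depending only on m and a polar factor depending on (l, m), combined per channel by mul_m_lm. A shape-level sketch, assuming torch and the helpers above; Rs = [0, 1] is illustrative:

alpha = torch.rand(4, 5) * 6.28                   # arbitrary leading shape [4, 5]
beta = torch.rand(4, 5) * 3.14
out = spherical_harmonics_alpha_z_y([0, 1], alpha, beta.cos(), beta.sin())
print(out.shape)                                  # torch.Size([4, 5, 4]); (2*0 + 1) + (2*1 + 1) = 4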
Example #5
import math
import torch
from e3nn import rs

def spherical_harmonics_xyz(Rs, xyz, normalization='none'):
    """
    spherical harmonics

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3]
    :param normalization: one of 'none', 'component' or 'norm'
    :return: tensor of shape [..., m]
    """
    Rs = rs.simplify(Rs)
    *size, _ = xyz.shape
    xyz = xyz.reshape(-1, 3)
    d = torch.norm(xyz, 2, dim=1)
    xyz = xyz[d > 0]
    xyz = xyz / d[d > 0, None]

    sh = None

    if xyz.device.type == 'cuda' and not xyz.requires_grad and rs.lmax(Rs) <= 10:  # pragma: no cover
        try:
            sh = _spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            pass

    if sh is None:
        # where |z| > |x|, swap the x- and z-axes so points stay away from the z-axis, where atan2(y, x) is ill-conditioned
        s = xyz[:, 2].abs() > xyz[:, 0].abs()
        xyz[s] = xyz[s] @ xyz.new_tensor([[0, 0, 1], [0, 1, 0], [-1, 0, 0]])

        alpha = torch.atan2(xyz[:, 1], xyz[:, 0])
        z = xyz[:, 2]
        y = xyz[:, :2].norm(dim=1)

        sh = spherical_harmonics_alpha_z_y(Rs, alpha, z, y)

        # rotate back
        sh[s] = sh[s] @ _rep_zx(tuple(Rs), xyz.dtype, xyz.device)

    if len(d) > len(sh):
        # some inputs had zero norm: reinsert rows for them, nonzero only in the l = 0 channel
        out = sh.new_zeros(len(d), sh.shape[1])
        out[d == 0] = math.sqrt(1 / (4 * math.pi)) * torch.cat([
            sh.new_ones(1) if l == 0 else sh.new_zeros(2 * l + 1)
            for mul, l, p in Rs for _ in range(mul)
        ])
        out[d > 0] = sh
        sh = out

    if normalization == 'component':
        sh.mul_(math.sqrt(4 * math.pi))
    if normalization == 'norm':
        sh.mul_(
            torch.cat([
                math.sqrt(4 * math.pi / (2 * l + 1)) * sh.new_ones(2 * l + 1)
                for mul, l, p in Rs for _ in range(mul)
            ]))
    return sh.reshape(*size, sh.shape[1])
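A sketch exercising the two features this variant adds, assuming torch and the helpers above: zero vectors are accepted, and a normalization convention can be chosen (Rs = [0, 1] is illustrative):

points = torch.tensor([[0., 0., 0.], [1., 2., 3.]])
sh = spherical_harmonics_xyz([0, 1], points, normalization='component')
# row 0 comes from the d == 0 branch: only the l = 0 channel is nonzero,
# and sqrt(4*pi) * sqrt(1 / (4*pi)) = 1 under 'component' normalization
print(sh[0])  # tensor([1., 0., 0., 0.])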
Example #6
import torch
from e3nn import rs

def spherical_harmonics_xyz(Rs, xyz):
    """
    spherical harmonics

    :param Rs: list of L's
    :param xyz: tensor of shape [..., 3]
    :return: tensor of shape [..., m]
    """

    if xyz.device.type == 'cuda' and not xyz.requires_grad and rs.lmax(Rs) <= 10:  # pragma: no cover
        try:
            return spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            pass

    xyz = xyz / torch.norm(xyz, 2, dim=-1, keepdim=True)

    alpha = torch.atan2(xyz[..., 1], xyz[..., 0])  # [...]
    z = xyz[..., 2]  # [...]
    y = (xyz[..., 0].pow(2) + xyz[..., 1].pow(2)).sqrt()  # [...]

    return spherical_harmonics_alpha_z_y(Rs, alpha, z, y)
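This is the simplest CPU path: no axis-swap stabilization and no zero-norm handling, so inputs should be nonzero and, when gradients matter, the eps variant in Example #3 is preferable. A quick sketch, assuming torch and the helpers above:

# Hypothetical usage on a batch of directions; any leading shape [...] is allowed.
grid = torch.randn(2, 3, 4, 3)           # [..., 3] with ... = [2, 3, 4]
sh = spherical_harmonics_xyz([0, 1, 2], grid)
print(sh.shape)                          # torch.Size([2, 3, 4, 9])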
Example #7
import torch
from e3nn import rs

def spherical_harmonics_alpha_beta(Rs, alpha, beta):
    """
    spherical harmonics

    :param Rs: list of L's
    :param alpha: tensor of shape [...] (azimuthal angle)
    :param beta: tensor of shape [...] (polar angle)
    :return: tensor of shape [..., m]
    """
    if alpha.device.type == 'cuda' and beta.device.type == 'cuda' and not alpha.requires_grad and not beta.requires_grad and rs.lmax(Rs) <= 10:  # pragma: no cover
        # angles to unit vectors: x = sin(beta) cos(alpha), y = sin(beta) sin(alpha), z = cos(beta)
        xyz = torch.stack([beta.sin() * alpha.cos(), beta.sin() * alpha.sin(), beta.cos()], dim=-1)
        try:
            return spherical_harmonics_xyz_cuda(Rs, xyz)
        except ImportError:
            pass

    # CPU fallback: z = cos(beta), y = |sin(beta)| (y >= 0 since beta is a polar angle in [0, pi])
    return spherical_harmonics_alpha_z_y(Rs, alpha, beta.cos(), beta.sin().abs())
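A usage sketch, assuming torch and the helpers above; alpha is the azimuth in [0, 2*pi) and beta the polar angle in [0, pi]:

import math

alpha = torch.rand(10) * 2 * math.pi
beta = torch.rand(10) * math.pi
sh = spherical_harmonics_alpha_beta([0, 1], alpha, beta)
print(sh.shape)  # torch.Size([10, 4])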