Exemplo n.º 1
0
def simple_energy(conf, charge_params, exclusion_idxs, charge_scales, cutoff):
    """
    Non-periodic pairwise Coulomb energy with per-pair exclusion rescaling.

    Computes

        E = 1/2 * sum_{i != j} q_i q_j / d_ij
            - sum_{(s, d) in exclusion_idxs} scale_sd * q_s q_d / d_sd

    Parameters
    ----------
    conf: [N, D] np.array
        atomic coordinates (N = conf.shape[0])

    charge_params: [N,] np.array
        per-atom charges

    exclusion_idxs: [E, 2] integer np.array
        (src, dst) atom pairs whose interaction is partially or fully removed.
        Self-pairs (src == dst), e.g. padding rows, always contribute zero.

    charge_scales: [E,] np.array
        fraction of each exclusion pair's interaction that is subtracted

    cutoff: float or None
        if not None, pair terms (all-pairs and exclusions) with d > cutoff
        are dropped

    Returns
    -------
    scalar energy
    """
    box = None  # this potential is non-periodic
    charges = charge_params
    n_atoms = conf.shape[0]

    # ---- all pairs (each unordered pair appears twice in the N x N grid) ----
    qij = np.multiply(np.expand_dims(charges, 0), np.expand_dims(charges, 1))
    dij = distance(np.expand_dims(conf, 0), np.expand_dims(conf, 1), box)

    keep_mask = 1 - np.eye(n_atoms)
    dij = np.where(keep_mask, dij, np.zeros_like(dij))
    # (ytz): divide by 1 on the diagonal so the discarded branch of np.where is
    # not 0/0 (a nan in that branch would poison gradients under JAX).
    safe_dij = np.where(keep_mask, dij, np.ones_like(dij))
    eij = np.where(keep_mask, qij / safe_dij, np.zeros_like(dij))

    if cutoff is not None:
        eij = np.where(dij > cutoff, np.zeros_like(eij), eij)

    # ---- exclusions ----
    src_idxs = exclusion_idxs[:, 0]
    dst_idxs = exclusion_idxs[:, 1]
    dij_exc = distance(conf[src_idxs], conf[dst_idxs], box)
    qij_exc = np.multiply(charges[src_idxs], charges[dst_idxs])

    # Self-pairs have zero distance: previously they were only masked when a
    # cutoff was set, so cutoff=None yielded inf/nan. Mask them always, and
    # divide by 1 there so the masked branch stays finite.
    is_self = src_idxs == dst_idxs
    safe_dij_exc = np.where(is_self, np.ones_like(dij_exc), dij_exc)
    eij_exc = charge_scales * qij_exc / safe_dij_exc

    drop = is_self
    if cutoff is not None:
        drop = np.logical_or(drop, dij_exc > cutoff)
    eij_exc = np.where(drop, np.zeros_like(eij_exc), eij_exc)

    return np.sum(eij / 2) - np.sum(eij_exc)
Exemplo n.º 2
0
def harmonic_bond(conf, params, box, bond_idxs, param_idxs):
    """
    Compute the harmonic bond energy given a collection of molecules.

    This implements a harmonic bond potential summed over all bonds:

        V(conf) = sum_b kb_b / 2 * (d_b - r0_b)^2

    where d_b is the (periodic, if box is not None) distance between the two
    atoms of bond b.

    Parameters:
    -----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    params: shape [num_params,] np.array
        unique parameters

    box: shape [3, 3] np.array
        periodic boundary vectors, if not None

    bond_idxs: [num_bonds, 2] np.array
        each element (src, dst) is a unique bond in the conformation

    param_idxs: [num_bonds, 2] np.array
        each element (k_idx, r_idx) maps into params for bond constants and ideal lengths

    Returns:
    --------
    scalar
        total bond energy
    """
    src_coords = conf[bond_idxs[:, 0]]
    dst_coords = conf[bond_idxs[:, 1]]
    dij = distance(src_coords, dst_coords, box)
    kbs = params[param_idxs[:, 0]]
    r0s = params[param_idxs[:, 1]]
    energy = np.sum(kbs / 2 * np.power(dij - r0s, 2.0))
    return energy
Exemplo n.º 3
0
def pairwise_energy(conf, box, charges, cutoff):
    """
    Return the (N, N) matrix of pairwise Coulomb terms q_i * q_j / d_ij.

    The diagonal (i == j) is zero. When ``cutoff`` is not None, entries whose
    distance exceeds ``cutoff`` are zeroed as well. Nothing is summed; the
    caller receives every ordered pair.
    """
    raw_products = np.multiply(np.expand_dims(charges, 0),
                               np.expand_dims(charges, 1))
    raw_dists = distance(np.expand_dims(conf, 0), np.expand_dims(conf, 1), box)

    off_diagonal = 1 - np.eye(conf.shape[0])
    products = np.where(off_diagonal, raw_products, np.zeros_like(raw_products))
    dists = np.where(off_diagonal, raw_dists, np.zeros_like(raw_dists))

    # (ytz): the diagonal evaluates 0/0 here, but np.where discards that branch.
    energies = np.where(off_diagonal, products / dists, np.zeros_like(dists))

    if cutoff is None:
        return energies
    return np.where(dists > cutoff, np.zeros_like(energies), energies)
Exemplo n.º 4
0
def find_protein_pocket_atoms(conf, nha, search_radius):
    """
    Find host (protein) atoms close to the ligand, i.e. the binding pocket.

    A host atom is part of the pocket if it lies strictly within
    ``search_radius`` of at least one ligand atom.

    Parameters
    ----------
    conf: np.array [N, 3]
        conformation of the full system: host atoms first (indices
        ``0 .. nha-1``), ligand atoms after them (indices ``nha .. N-1``).

    nha: int
        number of host atoms

    search_radius: float
        how far we search into the binding pocket, in the same length units
        as ``conf`` (the original note says nm).

    Returns
    -------
    list of int
        host atom indices in the pocket, sorted ascending (previously the
        order was arbitrary set-iteration order).
    """
    ri = np.expand_dims(conf, axis=0)
    rj = np.expand_dims(conf, axis=1)
    dij = jax_utils.distance(ri, rj)

    # rows: ligand atoms, columns: host atoms
    ligand_to_host = dij[nha:, :nha]
    in_pocket = np.any(ligand_to_host < search_radius, axis=0)

    return [int(p_idx) for p_idx in np.nonzero(in_pocket)[0]]
Exemplo n.º 5
0
def lennard_jones_exclusion(conf,
                            lj_params,
                            box,
                            exclusion_idxs,
                            lj_scales,
                            cutoff,
                            groups=None):
    """
    Scaled LJ 12-6 energy of the excluded pairs (Lorentz-Berthelot rules).

    Computes sum_e lj_scales[e] * 4 eps_e ((sig_e/d_e)^12 - (sig_e/d_e)^6)
    with sig_e = (sig_s + sig_d)/2 and eps_e = sqrt(eps_s * eps_d). The
    result is not divided by two; each exclusion row is counted once.

    Parameters
    ----------
    conf: [N, D] np.array
        coordinates
    lj_params: [N, 2] np.array
        column 0 is sigma, column 1 is epsilon
    box: passed through to distance()
    exclusion_idxs: [E, 2] integer np.array
        (src, dst) pairs. Self-pairs (src == dst), e.g. padding rows, always
        contribute zero.
    lj_scales: [E,] np.array
        per-pair scale factors
    cutoff: float or None
        if not None, pairs with d >= cutoff contribute zero
    groups: unused; kept for interface compatibility

    Returns
    -------
    scalar energy
    """
    assert exclusion_idxs.shape[1] == 2
    assert exclusion_idxs.shape[0] == lj_scales.shape[0]

    src_idxs = exclusion_idxs[:, 0]
    dst_idxs = exclusion_idxs[:, 1]
    dij = distance(conf[src_idxs], conf[dst_idxs], box)

    sig_params = lj_params[:, 0]
    sig_ij = (sig_params[src_idxs] + sig_params[dst_idxs]) / 2

    eps_params = lj_params[:, 1]
    eps_prod = eps_params[src_idxs] * eps_params[dst_idxs]
    # (ytz): sqrt has an infinite derivative at 0. Feed it 1 where the product
    # is 0 and select 0 afterwards so the gradient does not become nan. (The
    # previous np.where(eps_ij != 0, eps_ij, 0) was a no-op.)
    nonzero = eps_prod != 0
    eps_ij = np.where(nonzero,
                      np.sqrt(np.where(nonzero, eps_prod, np.ones_like(eps_prod))),
                      np.zeros_like(eps_prod))

    # Self-pairs have zero distance: previously they were only masked when a
    # cutoff was set, so cutoff=None yielded inf/nan. Divide by 1 there and
    # mask them out unconditionally.
    is_self = src_idxs == dst_idxs
    safe_dij = np.where(is_self, np.ones_like(dij), dij)

    sig2 = sig_ij / safe_dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    eij_exc = lj_scales * 4 * eps_ij * (sig6 - 1.0) * sig6

    drop = is_self
    if cutoff is not None:
        # matches the previous combined effect of zeroing eps where
        # dij >= cutoff and the energy where dij > cutoff
        drop = np.logical_or(drop, dij >= cutoff)
    eij_exc = np.where(drop, np.zeros_like(eij_exc), eij_exc)

    # the exclusion energy is not divided by two.
    return np.sum(eij_exc)
Exemplo n.º 6
0
    def test_pairwise_periodic_distance(self):
        """jax_utils.distance over all broadcast pairs must agree, entry by
        entry, with reference_periodic_distance under a triclinic box."""
        conf = np.array(
            [[-3.7431, 0.0007, 3.3896], [-2.2513, 0.2400, 3.0656],
             [-1.8353, -0.3114, 1.6756], [-0.3457, -0.0141, 1.3424],
             [0.0913, -0.5912, -0.0304], [1.5230, -0.3150, -0.3496],
             [2.4800, -1.3537, -0.3510], [3.8242, -1.1091, -0.6908],
             [4.2541, 0.2018, -1.0168], [3.2918, 1.2461, -1.0393],
             [1.9507, 0.9887, -0.6878], [3.6069, 2.5311, -1.4305],
             [4.6870, 2.6952, -2.3911], [5.9460, 1.9841, -1.7744],
             [5.6603, 0.4771, -1.4483], [6.7153, 0.0454, -0.5274],
             [8.0153, 0.3238, -0.7754], [8.3940, 1.0806, -1.9842],
             [7.3027, 2.0609, -2.5505], [9.0311, -0.1319, 0.1662],
             [4.2434, 2.1598, -3.7921], [4.9088, 4.2364, -2.4878],
             [4.6917, -2.1552, -0.7266], [-3.9733, 0.4081, 4.3758],
             [-3.9690, -1.0674, 3.3947], [-4.3790, 0.4951, 2.6522],
             [-1.6465, -0.2405, 3.8389], [-2.0559, 1.3147, 3.1027],
             [-1.9990, -1.3929, 1.6608], [-2.4698, 0.1371, 0.9054],
             [-0.1921, 1.0695, 1.3452], [0.2880, -0.4452, 2.1239],
             [-0.0916, -1.6686, -0.0213], [-0.5348, -0.1699, -0.8201],
             [2.2004, -2.3077, -0.1063], [1.2776, 1.7607, -0.6991],
             [6.1198, 2.5014, -0.8189], [5.7881, -0.1059, -2.3685],
             [6.4691, -0.4538, 0.2987], [9.3048, 1.6561, -1.8023],
             [8.6369, 0.3417, -2.7516], [7.6808, 3.0848, -2.5117],
             [7.1355, 1.8275, -3.6048], [8.8403, 0.2961, 1.1526],
             [10.0386, 0.1617, -0.1353], [9.0076, -1.2205, 0.2406],
             [4.1653, 1.0737, -3.8000], [4.9494, 2.4548, -4.5696],
             [3.2631, 2.5647, -4.0515], [3.9915, 4.7339, -2.8073],
             [5.1949, 4.6493, -1.5175], [5.6935, 4.4750, -3.2076],
             [4.1622, -2.9559, -0.5467]],
            dtype=np.float64)

        box = np.array([[1.3, 0.5, 0.6], [0.6, 1.2, 0.45], [0.4, 0.3, 1.2]],
                       dtype=np.float64)
        num_atoms = conf.shape[0]

        # (1, N, 3) against (N, 1, 3) broadcasts to every ordered pair
        all_dists = jax_utils.distance(np.expand_dims(conf, 0),
                                       np.expand_dims(conf, 1), box)

        for row in range(num_atoms):
            for col in range(num_atoms):
                ref = reference_periodic_distance(conf[row], conf[col], box)
                np.testing.assert_array_almost_equal(all_dists[row][col], ref)
Exemplo n.º 7
0
def harmonic_bond(conf, params, box, lamb, bond_idxs):
    """
    Compute the harmonic bond energy given a collection of molecules.

    This implements a harmonic bond potential:
        V(conf) = sum_bond kbs[bond] / 2 * (distance[bond] - r0s[bond])^2

    Parameters:
    -----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    params: shape [num_bonds, 2] np.array
        one (force constant, equilibrium length) row per bond, in the same
        order as bond_idxs

    box: shape [3, 3] np.array
        periodic boundary vectors, if not None

    lamb: float

    bond_idxs: [num_bonds, 2] np.array
        each element (src, dst) is a unique bond in the conformation

    Returns:
    --------
    scalar
        total bond energy

    Notes:
    ------
    * lamb argument is unused
    """
    # one parameter row per bond
    assert params.shape == bond_idxs.shape

    ci = conf[bond_idxs[:, 0]]
    cj = conf[bond_idxs[:, 1]]
    dij = distance(ci, cj, box)
    kbs = params[:, 0]
    r0s = params[:, 1]

    energy = np.sum(kbs / 2 * np.power(dij - r0s, 2.0))
    return energy
Exemplo n.º 8
0
    def test_bonded_periodic_distance(self):
        """jax_utils.distance on explicit (src, dst) index lists must match
        reference_periodic_distance pair by pair under a triclinic box."""
        conf = np.array(
            [
                [0.0637, 0.0126, 0.2203],  # C
                [1.0573, -0.2011, 1.2864],  # H
                [2.3928, 1.2209, -0.2230],  # H
                [-0.6891, 1.6983, 0.0780],  # H
                [-0.6312, -1.6261, -0.2601],  # H
            ],
            dtype=np.float64)

        box = np.array([[1.3, 0.5, 0.6], [0.6, 1.2, 0.45], [0.4, 0.3, 1.2]],
                       dtype=np.float64)

        pairs = [(0, 1), (1, 2), (3, 0), (2, 1)]
        src_idxs = [a for a, _ in pairs]
        dst_idxs = [b for _, b in pairs]

        dsts = jax_utils.distance(conf[src_idxs], conf[dst_idxs], box)

        for pos, (a, b) in enumerate(pairs):
            expected = reference_periodic_distance(conf[a], conf[b], box)
            np.testing.assert_array_almost_equal(expected, dsts[pos])
Exemplo n.º 9
0
def gbsa_obc(
        coords,
        # params,
        lamb,
        # box,
        charge_params,
        gb_params,
        # charge_idxs,
        # radii_idxs,
        # scale_idxs,
        alpha,
        beta,
        gamma,
        cutoff_radii,
        cutoff_force,
        lambda_plane_idxs,
        lambda_offset_idxs,
        dielectric_offset=0.009,
        surface_tension=28.3919551,
        solute_dielectric=1.0,
        solvent_dielectric=78.5,
        probe_radius=0.14):
    """
    Generalized-Born (OBC-style) polarization energy plus a non-polar ACE
    term, evaluated on 4D coordinates.

    Parameters
    ----------
    coords: atomic coordinates, lifted to 4D by convert_to_4d(coords, lamb,
        lambda_plane_idxs, lambda_offset_idxs, cutoff_radii).
        NOTE(review): presumably (N, 3); confirm.
    lamb: float, forwarded to convert_to_4d.
    charge_params: (N,) per-atom charges; N is taken from its length.
    gb_params: (N, 2) array. Column 0 is the atomic radius; column 1 is a
        scale factor applied to the partner atom's offset radius.
    alpha, beta, gamma: coefficients of psi, psi^2 and psi^3 inside the tanh
        that rescales the Born radii.
    cutoff_radii: pair contributions to the Born-radius integral with
        dij > cutoff_radii are dropped. Must equal cutoff_force.
    cutoff_force: pairwise GB interactions with dij > cutoff_force are
        dropped.
    lambda_plane_idxs, lambda_offset_idxs: forwarded to convert_to_4d.
    dielectric_offset: subtracted from every atomic radius.
    surface_tension: prefactor of the ACE term.
    solute_dielectric, solvent_dielectric: inner / outer dielectric constants.
    probe_radius: added to each atomic radius in the ACE term.

    Returns
    -------
    scalar: ACE term + diagonal (self) polarization + pairwise polarization.

    Notes
    -----
    * Non-periodic: box is hard-coded to None.
    * No Coulomb prefactor is applied to the charges here.
      NOTE(review): presumably charges are pre-scaled by the caller; confirm.
    """

    box = None  # non-periodic; distances below are computed without a box

    assert cutoff_radii == cutoff_force

    # lift to 4D; the extra coordinate depends on lamb and the lambda_* arrays
    coords_4d = convert_to_4d(coords, lamb, lambda_plane_idxs,
                              lambda_offset_idxs, cutoff_radii)

    N = len(charge_params)

    radii = gb_params[:, 0]
    scales = gb_params[:, 1]

    ri = np.expand_dims(coords_4d, 0)
    rj = np.expand_dims(coords_4d, 1)

    dij = distance(ri, rj, box)

    eye = np.eye(N, dtype=dij.dtype)

    r = dij + eye  # diagonal set to 1 (not 0) so the divisions below are finite
    or1 = radii.reshape((N, 1)) - dielectric_offset  # offset radius of atom i (rows)
    or2 = radii.reshape((1, N)) - dielectric_offset  # offset radius of atom j (columns)
    sr2 = scales.reshape((1, N)) * or2  # scaled offset radius of atom j

    # lower / upper limits of the pairwise descreening integral
    L = np.maximum(or1, abs(r - sr2))
    U = r + sr2

    # closed-form descreening integral of atom i by atom j, before the 0.5 factor
    I = 1 / L - 1 / U + 0.25 * (r - sr2**2 / r) * (1 / (U**2) - 1 /
                                                   (L**2)) + 0.5 * np.log(
                                                       L / U) / r
    # handle the interior case: extra term when or1 < sr2 - r
    I = np.where(or1 < (sr2 - r), I + 2 * (1 / or1 - 1 / L), I)
    # step() presumably a Heaviside step: keep only pairs with or1 < r + sr2
    I = step(r + sr2 - or1) * 0.5 * I  # note the extra 0.5 here
    I -= np.diag(np.diag(I))  # remove self (i == i) contributions

    # switch I only for now
    # inner = (np.pi*np.power(dij,8))/(2*cutoff_radii)
    # sw = np.power(np.cos(inner), 2)
    # I = I*sw

    I = np.where(dij > cutoff_radii, 0, I)
    I = np.sum(I, axis=1)  # total descreening of each atom i

    # okay, next compute born radii
    offset_radius = radii - dielectric_offset

    psi = I * offset_radius

    psi_coefficient = alpha
    psi2_coefficient = beta
    psi3_coefficient = gamma

    psi_term = (psi_coefficient * psi) - (psi2_coefficient *
                                          psi**2) + (psi3_coefficient * psi**3)

    # Born radii: B = 1 / (1/offset_radius - tanh(psi_term)/radius)
    B = 1 / (1 / offset_radius - np.tanh(psi_term) / radii)

    E = 0.0
    # single particle
    # ACE: surface_tension * (radius + probe_radius)^2 * (radius / B)^6
    E += np.sum(surface_tension * (radii + probe_radius)**2 * (radii / B)**6)

    # on-diagonal
    charges = charge_params

    E += np.sum(-0.5 * (1 / solute_dielectric - 1 / solvent_dielectric) *
                charges**2 / B)

    # particle pair
    # effective GB distance f_ij = sqrt(r^2 + B_i B_j exp(-r^2 / (4 B_i B_j)))
    f = np.sqrt(r**2 + np.outer(B, B) * np.exp(-r**2 / (4 * np.outer(B, B))))
    charge_products = np.outer(charges, charges)

    ixns = -(1 / solute_dielectric -
             1 / solvent_dielectric) * charge_products / f

    # sw = np.power(np.cos((np.pi*dij)/(2*cutoff_radii)), 2)
    # ixns = ixns*sw
    ixns = np.where(dij > cutoff_force, 0, ixns)

    # strict upper triangle: each unordered pair once, diagonal excluded
    E += np.sum(np.triu(ixns, k=1))

    return E
Exemplo n.º 10
0
 def urt(x, box):
     """Return the strictly-upper-triangular entries (k=1) of the distance
     matrix produced by ``distance(x, box)``, flattened in row-major order."""
     dmat = distance(x, box)
     rows, cols = np.triu_indices(dmat.shape[0], k=1)
     return dmat[rows, cols]
Exemplo n.º 11
0
def setup_core_restraints(k, alpha, count, conf, nha, core_atoms,
                          backbone_atoms, stage):
    """
    Setup core restraints: tie each core ligand atom to its ``count`` nearest
    host backbone atoms, using the current distance as the rest length.

    Parameters
    ----------
    k: float
        Force constant of each restraint

    alpha: float
        Third per-restraint parameter, stored as-is in each bond_params row
        (k, b, alpha). NOTE(review): its meaning is defined by the consumer
        of the 'Restraint' tuple; confirm.

    count: int
        Number of host backbone atoms each core ligand atom is restrained to

    conf: np.array [N, 3]
        Coordinates of the full system: host atoms first (indices < nha),
        ligand atoms after them.

    nha: int
        Number of host atoms

    core_atoms: collection of int
        Ligand atoms to restrain. NOTE(review): the code compares these to
        ligand-local indices (0 == first ligand atom == global index nha),
        although older documentation described them as indexed over the whole
        system; confirm against callers.

    backbone_atoms: collection of int
        Host atom indices eligible as restraint partners (restraining only to
        backbone leaves side chains free to sample).

    stage: 0, 1 or 2
        0 - attach restraint
        1 - decouple
        2 - detach restraint

    Returns
    -------
    ('Restraint', (bond_idxs, bond_params, lambda_flags))
        bond_idxs: int32 [B, 2] rows (global ligand index, host index)
        bond_params: float64 [B, 3] rows (k, current distance, alpha)
        lambda_flags: int32 [B]; zeros for stage 1, ones otherwise

    Raises
    ------
    ValueError
        if stage is not 0, 1 or 2 (previously an UnboundLocalError)
    Exception
        if fewer than ``count`` backbone atoms exist for some core atom
    """
    if stage not in (0, 1, 2):
        raise ValueError(f"stage must be 0, 1 or 2, got {stage!r}")

    ri = np.expand_dims(conf, axis=0)
    rj = np.expand_dims(conf, axis=1)
    dij = jax_utils.distance(ri, rj)

    # O(1) membership tests inside the loops; int() also normalizes numpy /
    # jax scalar elements so they hash consistently.
    core_set = {int(a) for a in core_atoms}
    backbone_set = {int(a) for a in backbone_atoms}

    bond_idxs = []
    bond_params = []

    for l_idx, dists in enumerate(dij[nha:]):
        if l_idx not in core_set:
            continue

        # host atoms ordered nearest-first; restrain to the first `count`
        # backbone atoms to enable side-chain sampling
        found = 0
        for p_idx in np.argsort(dists[:nha]):
            if found == count:
                break
            host_idx = int(p_idx)
            if host_idx in backbone_set:
                bond_params.append((k, dists[p_idx], alpha))
                bond_idxs.append([l_idx + nha, host_idx])
                found += 1

        # corner case where we haven't found sufficient candidates
        if found != count:
            raise Exception(
                f"Failed to find {count} neighbors for core ligand atom {l_idx}")

    # reshape keeps the documented 2D shapes even when no restraints are made
    bond_idxs = np.array(bond_idxs, dtype=np.int32).reshape(-1, 2)
    bond_params = np.array(bond_params, dtype=np.float64).reshape(-1, 3)

    B = bond_idxs.shape[0]

    # w = lambda*lambda_flags
    # w = 0 implies that restraints are on
    # w = +inf/-inf implies that restraints are off
    if stage == 1:
        # fully interacting
        lambda_flags = np.zeros(B, dtype=np.int32)
    else:
        # stage 0 (attach) and stage 2 (detach)
        lambda_flags = np.ones(B, dtype=np.int32)

    return ('Restraint', (bond_idxs, bond_params, lambda_flags))
Exemplo n.º 12
0
def lennard_jones(conf, params, box, param_idxs, scale_matrix, cutoff=None):
    """
    Non-periodic LJ 12-6 energy with Lorentz-Berthelot combining rules:
    sig_ij = (sig_i + sig_j)/2 and eps_ij = sqrt(eps_i * eps_j).

    Parameters
    ----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    params: shape [num_params,] np.array
        unique parameters

    box: shape [3, 3] np.array
        periodic boundary vectors, if not None (forwarded to distance())

    param_idxs: shape [num_atoms, 2] np.array
        per-atom (sig_idx, eps_idx) into params

    scale_matrix: shape [num_atoms, num_atoms] np.array
        multiplies eps_ij; pairs with scale_matrix <= 0 are discarded entirely

    cutoff: float or None
        if not None, pairs with d_ij >= cutoff are discarded

    Returns
    -------
    scalar energy; the N x N sum is halved because each pair appears twice
    """
    sigmas = params[param_idxs[:, 0]]
    epsilons = params[param_idxs[:, 1]]

    sig_pair = (np.expand_dims(sigmas, 0) + np.expand_dims(sigmas, 1)) / 2
    eps_pair = scale_matrix * np.sqrt(
        np.expand_dims(epsilons, 0) * np.expand_dims(epsilons, 1))

    dists = distance(np.expand_dims(conf, 0), np.expand_dims(conf, 1), box)

    if cutoff is not None:
        eps_pair = np.where(dists < cutoff, eps_pair, np.zeros_like(eps_pair))

    included = scale_matrix > 0

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_pair = np.where(included, sig_pair, np.zeros_like(sig_pair))
    eps_pair = np.where(included, eps_pair, np.zeros_like(eps_pair))

    ratio = sig_pair / dists
    ratio2 = ratio * ratio
    ratio6 = ratio2 * ratio2 * ratio2

    pair_energy = 4 * eps_pair * (ratio6 - 1.0) * ratio6
    pair_energy = np.where(included, pair_energy, np.zeros_like(pair_energy))

    return np.sum(pair_energy) / 2
Exemplo n.º 13
0
def lennard_jones(conf, lj_params, cutoff, groups=None):
    """
    Implements a non-periodic LJ612 potential using the Lorentz-Berthelot combining
    rules, where sig_ij = (sig_i + sig_j)/2 and eps_ij = sqrt(eps_i * eps_j).

    Parameters
    ----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    lj_params: shape [num_atoms, 2] np.array
        per-atom parameters; column 0 is sigma, column 1 is epsilon

    cutoff: float or None
        if not None, pairs with d_ij >= cutoff are discarded

    groups: shape [num_atoms,] integer np.array, optional
        per-atom bit flags. When given, gij = (g_i & g_j) > 0 is passed to
        distance() as its fourth argument. NOTE(review): the meaning of that
        argument is defined by distance(); confirm. When None (the default,
        which previously crashed), distance() is called without it.

    Returns
    -------
    scalar energy; the N x N sum is halved because each pair appears twice
    """
    box = None  # this potential is non-periodic

    sig = lj_params[:, 0]
    eps = lj_params[:, 1]

    sig_ij = (np.expand_dims(sig, 0) + np.expand_dims(sig, 1)) / 2
    eps_ij = np.sqrt(np.expand_dims(eps, 0) * np.expand_dims(eps, 1))

    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)

    if groups is None:
        dij = distance(ri, rj, box)
    else:
        gi = np.expand_dims(groups, axis=0)
        gj = np.expand_dims(groups, axis=1)
        gij = np.bitwise_and(gi, gj) > 0
        dij = distance(ri, rj, box, gij)

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    N = conf.shape[0]
    keep_mask = np.ones((N, N)) - np.eye(N)

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, np.zeros_like(sig_ij))
    eps_ij = np.where(keep_mask, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij / dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    eij = 4 * eps_ij * (sig6 - 1.0) * sig6

    eij = np.where(keep_mask, eij, np.zeros_like(eij))
    return np.sum(eij / 2)
Exemplo n.º 14
0
def nonbonded_v3(
    conf,
    params,
    box,
    lamb,
    charge_rescale_mask,
    lj_rescale_mask,
    beta,
    cutoff,
    lambda_plane_idxs,
    lambda_offset_idxs,
    runtime_validate=True,
):
    """Lennard-Jones + Coulomb, with a few important twists:
    * distances are computed in 4D, controlled by lambda, lambda_plane_idxs, lambda_offset_idxs
    * each pairwise LJ and Coulomb term can be multiplied by an adjustable rescale_mask parameter
    * Coulomb terms are multiplied by erfc(beta * distance)

    Parameters
    ----------
    conf : (N, 3) or (N, 4) np.array
        3D or 4D coordinates
        if 3D, will be converted to 4D using (x,y,z) -> (x,y,z,w)
            where w = cutoff * (lambda_plane_idxs + lambda_offset_idxs * lamb)
    params : (N, 3) np.array
        columns [charges, sigmas, epsilons], one row per particle.
        Pairs are combined as sig_i + sig_j (no /2) and eps_i * eps_j (no sqrt).
        NOTE(review): presumably the sigma column is pre-halved and the epsilon
        column pre-square-rooted so this reproduces Lorentz-Berthelot; confirm.
    box : Optional np.array
        if its last dimension is 3, the 3x3 box is embedded in a 4x4 box whose
        4th axis has length 1000 (roughly aperiodic in w); otherwise it is used as given
    lamb : float
    charge_rescale_mask : (N, N) np.array
        the Coulomb contribution of pair (i,j) will be multiplied by charge_rescale_mask[i,j]
    lj_rescale_mask : (N, N) np.array
        the Lennard-Jones contribution of pair (i,j) will be multiplied by lj_rescale_mask[i,j]
    beta : float
        the charge product q_ij will be multiplied by erfc(beta*d_ij)
    cutoff : Optional float
        a pair of particles (i,j) will be considered non-interacting if the distance d_ij
        between their 4D coordinates exceeds cutoff. (At exactly d_ij == cutoff the LJ
        term is already dropped while the Coulomb term is kept.)
    lambda_plane_idxs : Optional (N,) np.array
    lambda_offset_idxs : Optional (N,) np.array
    runtime_validate: bool
        if True, assert that both rescale masks are symmetric and, when cutoff is
        not None, call validate_coulomb_cutoff(cutoff, beta, threshold=1e-2)
        (if True, this function will currently not play nice with Jax JIT)
        TODO: is there a way to conditionally print a runtime warning inside
            of a Jax JIT-compiled function, without triggering a Jax ConcretizationTypeError?

    Returns
    -------
    energy : float
        half the sum of all (N, N) pair terms, since every unordered pair appears twice

    References
    ----------
    * Rodinger, Howell, Pomès, 2005, J. Chem. Phys. "Absolute free energy calculations by thermodynamic integration in four spatial
        dimensions" https://aip.scitation.org/doi/abs/10.1063/1.1946750
    * Darden, York, Pedersen, 1993, J. Chem. Phys. "Particle mesh Ewald: An N log(N) method for Ewald sums in large
    systems" https://aip.scitation.org/doi/abs/10.1063/1.470117
        * Coulomb interactions are treated using the direct-space contribution from eq 2
    """
    if runtime_validate:
        # the energy is halved at the end, which assumes symmetric masks
        assert (charge_rescale_mask == charge_rescale_mask.T).all()
        assert (lj_rescale_mask == lj_rescale_mask.T).all()

    N = conf.shape[0]

    if conf.shape[-1] == 3:
        conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs,
                             cutoff)

    # make 4th dimension of box large enough so its roughly aperiodic
    if box is not None:
        if box.shape[-1] == 3:
            box_4d = np.eye(4) * 1000
            box_4d = index_update(box_4d, index[:3, :3], box)
        else:
            box_4d = box
    else:
        box_4d = None

    box = box_4d

    charges = params[:, 0]
    sig = params[:, 1]
    eps = params[:, 2]

    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = sig_i + sig_j

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)

    eps_ij = eps_i * eps_j

    dij = distance(conf, box)

    # LJ mask: off-diagonal pairs with non-zero eps (computed before the cutoff)
    keep_mask = np.ones((N, N)) - np.eye(N)
    keep_mask = np.where(eps_ij != 0, keep_mask, 0)

    if cutoff is not None:
        if runtime_validate:
            validate_coulomb_cutoff(cutoff, beta, threshold=1e-2)
        eps_ij = np.where(dij < cutoff, eps_ij, 0)

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, 0)
    eps_ij = np.where(keep_mask, eps_ij, 0)

    # 1/dij is inf on the diagonal (dij == 0); overwrite those entries with 0
    inv_dij = 1 / dij
    inv_dij = np.where(np.eye(N), 0, inv_dij)

    sig2 = sig_ij * inv_dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    eij_lj = 4 * eps_ij * (sig6 - 1.0) * sig6
    eij_lj = np.where(keep_mask, eij_lj, 0)

    qi = np.expand_dims(charges, 0)  # (1, N)
    qj = np.expand_dims(charges, 1)  # (N, 1)
    qij = np.multiply(qi, qj)

    # (ytz): trick used to avoid nans in the diagonal due to the 1/dij term.
    # keep_mask is rebound here to the plain off-diagonal mask (no eps test).
    keep_mask = 1 - np.eye(N)
    qij = np.where(keep_mask, qij, 0)
    dij = np.where(keep_mask, dij, 0)

    # NOTE: erfc(x)/x diverges as x -> 0 (erfc(0) == 1); the diagonal is
    # handled by inv_dij == 0 there and by the keep_mask selection below.
    eij_charge = np.where(keep_mask,
                          qij * erfc(beta * dij) * inv_dij,
                          0)  # zero out diagonals
    if cutoff is not None:
        eij_charge = np.where(dij > cutoff, 0, eij_charge)

    eij_total = eij_lj * lj_rescale_mask + eij_charge * charge_rescale_mask

    # each unordered pair appears twice in the (N, N) matrix
    return np.sum(eij_total / 2)
Exemplo n.º 15
0
def nonbonded_v3(conf, params, box, lamb, charge_rescale_mask, lj_rescale_mask,
                 scales, beta, cutoff, lambda_plane_idxs, lambda_offset_idxs):
    """Lennard-Jones + erfc-damped Coulomb energy on 4D coordinates.

    Parameters
    ----------
    conf : (N, 3) np.array
        coordinates, lifted to 4D with convert_to_4d(conf, lamb,
        lambda_plane_idxs, lambda_offset_idxs, cutoff)
    params : (N, 3) np.array
        columns [charge, sig, eps]. Pairs are combined as sig_i + sig_j and
        eps_i * eps_j (no /2, no sqrt). NOTE(review): presumably the columns
        hold pre-halved sigmas and pre-square-rooted epsilons; confirm.
    box : None, (3, 3) or (4, 4) np.array
        A 3x3 box is embedded in a 4x4 box whose 4th axis has length 1000, so
        the 4th dimension is roughly aperiodic. A 4x4 box is used as given
        (previously it raised on the 3x3 slice assignment).
    lamb : float
    charge_rescale_mask, lj_rescale_mask : (N, N) np.array
        per-pair multipliers of the Coulomb / LJ terms
    scales : unused; kept for interface compatibility
    beta : float
        Coulomb pair terms are multiplied by erfc(beta * d_ij)
    cutoff : float or None
        LJ terms are dropped for d_ij >= cutoff, Coulomb terms for d_ij > cutoff
    lambda_plane_idxs, lambda_offset_idxs : (N,) np.array

    Returns
    -------
    float
        half the sum of all (N, N) pair terms, since each pair appears twice
    """
    N = conf.shape[0]

    conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs,
                         cutoff)

    # make 4th dimension of box large enough so its roughly aperiodic
    if box is not None and box.shape[-1] == 3:
        box_4d = np.eye(4) * 1000
        box = index_update(box_4d, index[:3, :3], box)

    charges = params[:, 0]
    sig = params[:, 1]
    eps = params[:, 2]

    sig_ij = np.expand_dims(sig, 0) + np.expand_dims(sig, 1)
    eps_ij = np.expand_dims(eps, 0) * np.expand_dims(eps, 1)

    dij = distance(np.expand_dims(conf, 0), np.expand_dims(conf, 1), box)

    off_diag = 1 - np.eye(N)
    # LJ mask: off-diagonal pairs with non-zero eps (evaluated before cutoff)
    lj_mask = np.where(eps_ij != 0, np.ones((N, N)) - np.eye(N), 0)

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(lj_mask, sig_ij, np.zeros_like(sig_ij))
    eps_ij = np.where(lj_mask, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij / dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    eij_lj = 4 * eps_ij * (sig6 - 1.0) * sig6
    eij_lj = np.where(lj_mask, eij_lj, np.zeros_like(eij_lj))

    qij = np.multiply(np.expand_dims(charges, 0), np.expand_dims(charges, 1))

    # (ytz): trick used to avoid nans in the diagonal due to the 1/dij term.
    qij = np.where(off_diag, qij, np.zeros_like(qij))
    dij = np.where(off_diag, dij, np.zeros_like(dij))

    # NOTE: erfc(x)/x diverges as x -> 0 (erfc(0) == 1); the diagonal, where
    # this evaluates 0/0, is discarded by the np.where selection.
    eij_charge = np.where(off_diag,
                          qij * erfc(beta * dij) / dij,
                          np.zeros_like(dij))  # zero out diagonals
    if cutoff is not None:
        eij_charge = np.where(dij > cutoff, np.zeros_like(eij_charge),
                              eij_charge)

    eij_total = eij_lj * lj_rescale_mask + eij_charge * charge_rescale_mask

    # each unordered pair appears twice in the (N, N) matrix
    return np.sum(eij_total / 2)
Exemplo n.º 16
0
def born_radii(conf, atomic_radii, scaled_radius_factor, dielectric_offset,
               alpha_obc, beta_obc, gamma_obc):
    """
    Compute the OBC-adjusted Born radius of each atom. This is the first part
    of the GBSA calculation.

    Parameters
    ----------
    conf: np.array
        shape Nx3 matrix of geometric coordinates

    atomic_radii: np.array
        shape [N,] array of radius of each atom

    scaled_radius_factor: np.array
        shape [N,] array of adjusted shape factors for each atom. Atom j's
        offset radius is multiplied by factor j when j descreens another atom.

    dielectric_offset: float
        subtracted from every atomic radius to obtain the offset radius

    alpha_obc, beta_obc, gamma_obc: float
        coefficients of S, S^2 and S^3 inside the tanh, where S is the
        per-atom descreening sum times 0.5 * offset radius

    Returns
    -------
    np.array
        shape [N,] array of Born radii:
        1 / (1/oR - tanh(alpha*S - beta*S^2 + gamma*S^3) / atomic_radii)
    """
    r_i = np.expand_dims(conf, axis=0)
    r_j = np.expand_dims(conf, axis=1)
    d_ij = distance(r_i, r_j)

    oR = atomic_radii - dielectric_offset
    oRI = np.expand_dims(oR, axis=1)  # rows: atom being descreened
    oRJ = np.expand_dims(oR, axis=0)  # columns: descreening atom
    sRJ = oRJ * scaled_radius_factor
    rSRJ = d_ij + sRJ

    keep_mask = 1 - np.eye(conf.shape[0])

    # Pair (i, j) contributes only if j's scaled sphere reaches i's offset
    # radius. Self-pairs are excluded explicitly rather than relying on
    # scaled_radius_factor < 1 to make oRI < rSRJ false on the diagonal.
    mask_final = np.logical_and(np.less(oRI, rSRJ), keep_mask > 0)

    # 1/d_ij: divide by 1 on the diagonal (d_ij == 0) and then zero it, so no
    # inf/nan is produced in either branch
    safe_d_ij = np.where(keep_mask, d_ij, np.ones_like(d_ij))
    d_ij_inv = np.where(keep_mask, 1 / safe_d_ij, np.zeros_like(d_ij))

    rfs = np.abs(d_ij - sRJ)
    l_ij = np.maximum(oRI, rfs)
    l_ij = 1 / l_ij  # inverse lower integration limit
    u_ij = 1 / rSRJ  # inverse upper integration limit
    l_ij2 = l_ij * l_ij
    u_ij2 = u_ij * u_ij
    ratio = np.log(u_ij / l_ij)
    term = l_ij - u_ij + 0.25 * d_ij * (u_ij2 - l_ij2) + (
        0.5 * d_ij_inv *
        ratio) + (0.25 * sRJ * sRJ * d_ij_inv) * (l_ij2 - u_ij2)
    # Interior case: atom i's offset sphere lies inside j's scaled sphere.
    # This correction was missing; gbsa_obc and the OBC reference apply it.
    term = np.where(oRI < sRJ - d_ij, term + 2 * (1 / oRI - l_ij), term)
    term_masked = np.where(mask_final, term, np.zeros_like(term))

    summ = np.sum(term_masked, axis=-1)

    summ *= 0.5 * oR
    sum2 = summ * summ
    sum3 = summ * sum2
    tanhSum = np.tanh(alpha_obc * summ - beta_obc * sum2 + gamma_obc * sum3)

    return 1.0 / (1.0 / oR - tanhSum / atomic_radii)