def simple_energy(conf, charge_params, exclusion_idxs, charge_scales, cutoff):
    """
    Non-periodic Coulomb energy with scaled exclusions.

    Computes

        E = 1/2 * sum_{i != j} q_i*q_j/d_ij - sum_{(s, d) in exclusions} scale_sd * q_s*q_d/d_sd

    The all-pairs sum is halved because every unordered pair appears twice in
    the N x N matrix. The exclusion sum is not halved: each listed pair is
    subtracted once. Assuming each excluded pair is listed once, a scale of 1
    removes its interaction entirely and a scale of 0 leaves it untouched.

    Parameters
    ----------
    conf: np.array
        coordinates, one row per atom (N = conf.shape[0])
    charge_params: np.array [N,]
        per-atom charges
    exclusion_idxs: int np.array [E, 2]
        (src, dst) atom pairs whose interaction is (partially) removed
    charge_scales: np.array [E,]
        fraction of each exclusion pair's Coulomb term to subtract
    cutoff: float or None
        if not None, pairs with d_ij > cutoff contribute nothing to either sum

    Returns
    -------
    float
        the total energy
    """
    box = None  # this implementation is non-periodic

    charges = charge_params
    qi = np.expand_dims(charges, 0)  # (1, N)
    qj = np.expand_dims(charges, 1)  # (N, 1)
    qij = np.multiply(qi, qj)

    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj, box)

    # (ytz): trick used to avoid nans in the diagonal due to the 1/dij term.
    # Self-pairs have dij == 0, so they are masked before and after dividing.
    keep_mask = 1 - np.eye(conf.shape[0])
    qij = np.where(keep_mask, qij, np.zeros_like(qij))
    dij = np.where(keep_mask, dij, np.zeros_like(dij))
    eij = np.where(keep_mask, qij / dij, np.zeros_like(dij))  # zero out diagonals
    if cutoff is not None:
        eij = np.where(dij > cutoff, np.zeros_like(eij), eij)

    # Exclusions: recompute the listed pairs' terms so they can be subtracted.
    src_idxs = exclusion_idxs[:, 0]
    dst_idxs = exclusion_idxs[:, 1]
    dij_exc = distance(conf[src_idxs], conf[dst_idxs], box)
    qij_exc = np.multiply(charges[src_idxs], charges[dst_idxs])
    eij_exc = charge_scales * qij_exc / dij_exc
    if cutoff is not None:
        eij_exc = np.where(dij_exc > cutoff, np.zeros_like(eij_exc), eij_exc)
    # A degenerate (i, i) exclusion has dij == 0 (inf/nan above); it must never
    # contribute, with or without a cutoff.
    eij_exc = np.where(src_idxs == dst_idxs, np.zeros_like(eij_exc), eij_exc)

    # The all-pairs matrix counts every pair twice; the exclusion list once.
    return np.sum(eij / 2) - np.sum(eij_exc)
def harmonic_bond(conf, params, box, bond_idxs, param_idxs):
    """
    Compute the harmonic bond energy given a collection of molecules.

    This implements a harmonic bond potential:

        V = sum_b kb_b / 2 * (d_b - r0_b)^2

    where d_b is the distance between the two atoms of bond b.

    Parameters
    ----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    params: shape [num_params,] np.array
        unique parameters

    box: shape [3, 3] np.array
        periodic boundary vectors, if not None (forwarded to ``distance``)

    bond_idxs: [num_bonds, 2] np.array
        each element (src, dst) is a unique bond in the conformation

    param_idxs: [num_bonds, 2] np.array
        each element (k_idx, r_idx) maps into params for bond constants and ideal lengths

    Returns
    -------
    float
        total bond energy summed over all bonds
    """
    ci = conf[bond_idxs[:, 0]]
    cj = conf[bond_idxs[:, 1]]
    dij = distance(ci, cj, box)
    kbs = params[param_idxs[:, 0]]
    r0s = params[param_idxs[:, 1]]
    energy = np.sum(kbs / 2 * np.power(dij - r0s, 2.0))
    return energy
def pairwise_energy(conf, box, charges, cutoff):
    """
    Return the matrix of pairwise Coulomb terms e_ij = q_i * q_j / d_ij.

    Despite the name, nothing is summed: the full N x N matrix is returned,
    so each unordered pair appears twice (at [i, j] and at [j, i]). The
    diagonal is zero. When ``cutoff`` is given, every entry whose distance
    exceeds it is zero as well.

    Parameters
    ----------
    conf: np.array
        coordinates, one row per atom (N = conf.shape[0])
    box:
        forwarded unchanged to ``distance``
    charges: np.array [N,]
        per-atom charges
    cutoff: float or None
        optional distance cutoff

    Returns
    -------
    np.array [N, N]
        the pairwise term matrix
    """
    n_atoms = conf.shape[0]
    charge_prod = np.expand_dims(charges, 0) * np.expand_dims(charges, 1)
    dist = distance(np.expand_dims(conf, 0), np.expand_dims(conf, 1), box)

    # Self-pairs have zero distance. Blank them out on both operands and on
    # the quotient so that no 0/0 leaks into the result.
    off_diagonal = 1 - np.eye(n_atoms)
    charge_prod = np.where(off_diagonal, charge_prod, np.zeros_like(charge_prod))
    dist = np.where(off_diagonal, dist, np.zeros_like(dist))
    energies = np.where(off_diagonal, charge_prod / dist, np.zeros_like(dist))

    if cutoff is None:
        return energies
    return np.where(dist > cutoff, np.zeros_like(energies), energies)
def find_protein_pocket_atoms(conf, nha, search_radius):
    """
    Find host (protein) atoms that are close to the binding pocket.

    A host atom is selected if it lies strictly within ``search_radius`` of
    at least one ligand atom.

    Parameters
    ----------
    conf: np.array [N, 3]
        conformation of the whole host-guest system. The first ``nha`` rows
        are host atoms; the remaining rows are ligand atoms.

    nha: int
        number of host atoms

    search_radius: float
        how far (nm) we search into the binding pocket

    Returns
    -------
    list of int
        sorted, de-duplicated indices (in [0, nha)) of the selected host atoms
    """
    ri = np.expand_dims(conf, axis=0)
    rj = np.expand_dims(conf, axis=1)
    dij = jax_utils.distance(ri, rj)

    # rows: ligand atoms; columns: host atoms
    ligand_to_host = dij[nha:, :nha]
    # a host atom is in the pocket if any ligand atom is within the radius
    is_pocket = np.any(ligand_to_host < search_radius, axis=0)
    return [int(idx) for idx in np.nonzero(is_pocket)[0]]
def lennard_jones_exclusion(conf, lj_params, box, exclusion_idxs, lj_scales, cutoff, groups=None):
    """
    Scaled Lennard-Jones (12-6) energy of the listed exclusion pairs.

    For each row k = (src, dst) of ``exclusion_idxs`` this evaluates

        lj_scales[k] * 4 * eps_ij * ((sig_ij / d_ij)^12 - (sig_ij / d_ij)^6)

    with Lorentz-Berthelot combining (sig_ij = (sig_i + sig_j) / 2,
    eps_ij = sqrt(eps_i * eps_j)), and returns the sum. The result is not
    halved, because every pair is listed explicitly rather than appearing
    twice in a symmetric matrix.

    Parameters
    ----------
    conf: np.array
        coordinates, one row per atom
    lj_params: np.array
        per-atom parameters; column 0 is sigma, column 1 is epsilon
    box:
        forwarded unchanged to ``distance``
    exclusion_idxs: int np.array [E, 2]
        (src, dst) atom index pairs
    lj_scales: np.array [E,]
        per-pair scale factor
    cutoff: float or None
        if not None, pairs with d_ij >= cutoff contribute nothing
    groups:
        unused

    Returns
    -------
    float
        sum of the scaled exclusion-pair energies. Degenerate pairs with
        src == dst contribute zero.
    """
    assert exclusion_idxs.shape[1] == 2
    assert exclusion_idxs.shape[0] == lj_scales.shape[0]
    src_idxs = exclusion_idxs[:, 0]
    dst_idxs = exclusion_idxs[:, 1]
    ri = conf[src_idxs]
    rj = conf[dst_idxs]
    dij = distance(ri, rj, box)
    sig_params = lj_params[:, 0]
    sig_i = sig_params[src_idxs]
    sig_j = sig_params[dst_idxs]
    sig_ij = (sig_i + sig_j) / 2
    eps_params = lj_params[:, 1]
    eps_i = eps_params[src_idxs]
    eps_j = eps_params[dst_idxs]
    eps_ij = np.sqrt(eps_i * eps_j)
    # (ytz): avoids nans
    # NOTE(review): as a value this is a no-op (zeros stay zero). It was
    # presumably meant as a gradient guard for sqrt at 0; verify that it has
    # that effect.
    eps_ij = np.where(eps_ij != 0, eps_ij, 0)
    if cutoff is not None:
        # zeroing eps removes the pair's LJ energy beyond the cutoff
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))
    sig2 = sig_ij / dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2
    scale_ij = lj_scales
    # (sig6 - 1) * sig6 == (sig/d)^12 - (sig/d)^6
    eij_exc = scale_ij * 4 * eps_ij * (sig6 - 1.0) * sig6
    if cutoff is not None:
        # hard cutoff (no switching function)
        eij_exc = np.where(dij > cutoff, np.zeros_like(eij_exc), eij_exc)
    # src == dst gives dij == 0 (inf/nan above); discard such pairs
    eij_exc = np.where(src_idxs == dst_idxs, np.zeros_like(eij_exc), eij_exc)
    # the exclusion energy is not divided by two.
    return np.sum(eij_exc)
def test_pairwise_periodic_distance(self):
    """All-pairs periodic distances must match the per-pair reference."""
    # 53-atom conformation whose coordinates extend far beyond the box below
    conf = np.array(
        [
            [-3.7431, 0.0007, 3.3896],
            [-2.2513, 0.2400, 3.0656],
            [-1.8353, -0.3114, 1.6756],
            [-0.3457, -0.0141, 1.3424],
            [0.0913, -0.5912, -0.0304],
            [1.5230, -0.3150, -0.3496],
            [2.4800, -1.3537, -0.3510],
            [3.8242, -1.1091, -0.6908],
            [4.2541, 0.2018, -1.0168],
            [3.2918, 1.2461, -1.0393],
            [1.9507, 0.9887, -0.6878],
            [3.6069, 2.5311, -1.4305],
            [4.6870, 2.6952, -2.3911],
            [5.9460, 1.9841, -1.7744],
            [5.6603, 0.4771, -1.4483],
            [6.7153, 0.0454, -0.5274],
            [8.0153, 0.3238, -0.7754],
            [8.3940, 1.0806, -1.9842],
            [7.3027, 2.0609, -2.5505],
            [9.0311, -0.1319, 0.1662],
            [4.2434, 2.1598, -3.7921],
            [4.9088, 4.2364, -2.4878],
            [4.6917, -2.1552, -0.7266],
            [-3.9733, 0.4081, 4.3758],
            [-3.9690, -1.0674, 3.3947],
            [-4.3790, 0.4951, 2.6522],
            [-1.6465, -0.2405, 3.8389],
            [-2.0559, 1.3147, 3.1027],
            [-1.9990, -1.3929, 1.6608],
            [-2.4698, 0.1371, 0.9054],
            [-0.1921, 1.0695, 1.3452],
            [0.2880, -0.4452, 2.1239],
            [-0.0916, -1.6686, -0.0213],
            [-0.5348, -0.1699, -0.8201],
            [2.2004, -2.3077, -0.1063],
            [1.2776, 1.7607, -0.6991],
            [6.1198, 2.5014, -0.8189],
            [5.7881, -0.1059, -2.3685],
            [6.4691, -0.4538, 0.2987],
            [9.3048, 1.6561, -1.8023],
            [8.6369, 0.3417, -2.7516],
            [7.6808, 3.0848, -2.5117],
            [7.1355, 1.8275, -3.6048],
            [8.8403, 0.2961, 1.1526],
            [10.0386, 0.1617, -0.1353],
            [9.0076, -1.2205, 0.2406],
            [4.1653, 1.0737, -3.8000],
            [4.9494, 2.4548, -4.5696],
            [3.2631, 2.5647, -4.0515],
            [3.9915, 4.7339, -2.8073],
            [5.1949, 4.6493, -1.5175],
            [5.6935, 4.4750, -3.2076],
            [4.1622, -2.9559, -0.5467],
        ],
        dtype=np.float64,
    )
    # triclinic (non-orthogonal) box
    box = np.array(
        [[1.3, 0.5, 0.6], [0.6, 1.2, 0.45], [0.4, 0.3, 1.2]], dtype=np.float64
    )

    # all-pairs distances via broadcasting (1, N, 3) against (N, 1, 3)
    dij = jax_utils.distance(np.expand_dims(conf, 0), np.expand_dims(conf, 1), box)

    n_atoms = conf.shape[0]
    for i, j in np.ndindex(n_atoms, n_atoms):
        expected = reference_periodic_distance(conf[i], conf[j], box)
        np.testing.assert_array_almost_equal(dij[i][j], expected)
def harmonic_bond(conf, params, box, lamb, bond_idxs):
    """
    Compute the harmonic bond energy given a collection of molecules.

    This implements a harmonic bond potential:

        V(conf) = sum_bond kbs[bond] / 2 * (distance[bond] - r0s[bond])^2

    Parameters
    ----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    params: shape [num_bonds, 2] np.array
        per-bond parameters. Column 0 is the force constant kb; column 1 is
        the ideal length r0. Must have the same shape as bond_idxs.

    box: shape [3, 3] np.array
        periodic boundary vectors, if not None (forwarded to ``distance``)

    lamb: float
        unused

    bond_idxs: [num_bonds, 2] np.array
        each element (src, dst) is a unique bond in the conformation

    Returns
    -------
    float
        total bond energy

    Notes
    -----
    * lamb argument is unused
    """
    assert params.shape == bond_idxs.shape
    ci = conf[bond_idxs[:, 0]]
    cj = conf[bond_idxs[:, 1]]
    dij = distance(ci, cj, box)
    kbs = params[:, 0]
    r0s = params[:, 1]
    energy = np.sum(kbs / 2 * np.power(dij - r0s, 2.0))
    return energy
def test_bonded_periodic_distance(self):
    """Per-bond periodic distances must match the per-pair reference."""
    # one carbon followed by four hydrogens
    conf = np.array(
        [
            [0.0637, 0.0126, 0.2203],  # C
            [1.0573, -0.2011, 1.2864],  # H
            [2.3928, 1.2209, -0.2230],  # H
            [-0.6891, 1.6983, 0.0780],  # H
            [-0.6312, -1.6261, -0.2601],  # H
        ],
        dtype=np.float64,
    )
    # triclinic (non-orthogonal) box
    box = np.array(
        [[1.3, 0.5, 0.6], [0.6, 1.2, 0.45], [0.4, 0.3, 1.2]], dtype=np.float64
    )
    bonds = [(0, 1), (1, 2), (3, 0), (2, 1)]
    # lists (not tuples) so that numpy performs row gathering
    src_idxs = [src for src, _ in bonds]
    dst_idxs = [dst for _, dst in bonds]

    computed = jax_utils.distance(conf[src_idxs], conf[dst_idxs], box)
    for (src, dst), bond_dist in zip(bonds, computed):
        reference = reference_periodic_distance(conf[src], conf[dst], box)
        np.testing.assert_array_almost_equal(reference, bond_dist)
def gbsa_obc(
    coords,
    lamb,
    charge_params,
    gb_params,
    alpha,
    beta,
    gamma,
    cutoff_radii,
    cutoff_force,
    lambda_plane_idxs,
    lambda_offset_idxs,
    dielectric_offset=0.009,
    surface_tension=28.3919551,
    solute_dielectric=1.0,
    solvent_dielectric=78.5,
    probe_radius=0.14):
    """
    Non-periodic GBSA energy with OBC-style Born radii, evaluated on 4D coordinates.

    The energy is the sum of three contributions:

    * a nonpolar (ACE) surface term:
        sum_i surface_tension * (r_i + probe_radius)^2 * (r_i / B_i)^6
    * a self (on-diagonal) polarization term:
        sum_i -0.5 * (1/solute_dielectric - 1/solvent_dielectric) * q_i^2 / B_i
    * a pairwise polarization term over i < j:
        -(1/solute_dielectric - 1/solvent_dielectric) * q_i * q_j / f_ij
      with f_ij = sqrt(d_ij^2 + B_i B_j exp(-d_ij^2 / (4 B_i B_j)))

    Here r_i are the atomic radii and B_i are the Born radii computed below.

    Parameters
    ----------
    coords: np.array
        coordinates, lifted to 4D by ``convert_to_4d`` using lamb,
        lambda_plane_idxs, lambda_offset_idxs and cutoff_radii
    lamb: float
    charge_params: np.array [N,]
        per-atom charges (N is taken from its length)
    gb_params: np.array [N, 2]
        column 0: atomic radii; column 1: radius scale factors
    alpha, beta, gamma: float
        OBC coefficients in
        B = 1 / (1/or - tanh(alpha*psi - beta*psi^2 + gamma*psi^3) / r),
        with or = r - dielectric_offset
    cutoff_radii: float
        pairs with d_ij > cutoff_radii are dropped from the Born-radius sum
    cutoff_force: float
        pairs with d_ij > cutoff_force are dropped from the pairwise term.
        Must equal cutoff_radii (asserted).
    lambda_plane_idxs, lambda_offset_idxs:
        forwarded to ``convert_to_4d``
    dielectric_offset, surface_tension, solute_dielectric, solvent_dielectric, probe_radius: float
        physical constants; see the defaults in the signature

    Returns
    -------
    float
        total GBSA energy
    """

    box = None  # non-periodic

    assert cutoff_radii == cutoff_force

    coords_4d = convert_to_4d(coords, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff_radii)

    N = len(charge_params)

    radii = gb_params[:, 0]
    scales = gb_params[:, 1]

    ri = np.expand_dims(coords_4d, 0)
    rj = np.expand_dims(coords_4d, 1)

    dij = distance(ri, rj, box)
    eye = np.eye(N, dtype=dij.dtype)

    # Put 1 on the diagonal so the 1/r and log(L/U)/r terms below stay finite
    # for self-pairs. Self contributions are removed later: I's diagonal is
    # subtracted, and the pairwise sum uses only i < j.
    r = dij + eye
    or1 = radii.reshape((N, 1)) - dielectric_offset  # offset radius of atom i (column)
    or2 = radii.reshape((1, N)) - dielectric_offset  # offset radius of atom j (row)
    sr2 = scales.reshape((1, N)) * or2  # scaled offset radius of atom j

    # integration limits of the pairwise descreening integral
    L = np.maximum(or1, abs(r - sr2))
    U = r + sr2

    I = 1 / L - 1 / U + 0.25 * (r - sr2**2 / r) * (1 / (U**2) - 1 / (L**2)) + 0.5 * np.log(
        L / U) / r
    # handle the interior case (atom i entirely inside atom j's scaled sphere)
    I = np.where(or1 < (sr2 - r), I + 2 * (1 / or1 - 1 / L), I)
    # NOTE(review): `step` is defined elsewhere; presumably a Heaviside step
    # that drops pairs where r + sr2 <= or1. Confirm.
    I = step(r + sr2 - or1) * 0.5 * I  # note the extra 0.5 here
    I -= np.diag(np.diag(I))  # remove self (i == i) contributions

    # hard cutoff on the Born-radius integral (no switching function)
    I = np.where(dij > cutoff_radii, 0, I)
    I = np.sum(I, axis=1)  # sum over j: one value per atom i

    # okay, next compute born radii
    offset_radius = radii - dielectric_offset
    psi = I * offset_radius

    psi_coefficient = alpha
    psi2_coefficient = beta
    psi3_coefficient = gamma

    psi_term = (psi_coefficient * psi) - (psi2_coefficient * psi**2) + (psi3_coefficient * psi**3)

    B = 1 / (1 / offset_radius - np.tanh(psi_term) / radii)

    E = 0.0
    # single particle
    # ACE nonpolar surface term
    E += np.sum(surface_tension * (radii + probe_radius)**2 * (radii / B)**6)

    # on-diagonal (self) polarization term
    charges = charge_params

    E += np.sum(-0.5 * (1 / solute_dielectric - 1 / solvent_dielectric) * charges**2 / B)

    # particle pair
    f = np.sqrt(r**2 + np.outer(B, B) * np.exp(-r**2 / (4 * np.outer(B, B))))
    charge_products = np.outer(charges, charges)
    ixns = -(1 / solute_dielectric - 1 / solvent_dielectric) * charge_products / f

    # hard cutoff on the pairwise term (no switching function)
    ixns = np.where(dij > cutoff_force, 0, ixns)

    # strict upper triangle: each unordered pair once, diagonal excluded
    E += np.sum(np.triu(ixns, k=1))

    return E
def urt(x, box):
    """
    Return the strict upper triangle of ``distance(x, box)`` as a flat array.

    Entries come in the row-major order produced by
    ``np.triu_indices(n, k=1)``: one value per pair (i, j) with i < j.
    This assumes the distance matrix is square (n x n).
    """
    dmat = distance(x, box)
    upper = np.triu_indices(len(dmat), k=1)
    return dmat[upper]
def setup_core_restraints(k, alpha, count, conf, nha, core_atoms, backbone_atoms, stage):
    """
    Setup core restraints.

    Each ligand core atom is restrained to its ``count`` nearest host
    backbone atoms. The current distance is used as the rest length.

    Parameters
    ----------
    k: float
        Force constant of each restraint

    alpha: float
        third value stored in every restraint row; its meaning is defined by
        the 'Restraint' potential that consumes these parameters

    count: int
        Number of host atoms we restrain each guest_mol to

    conf: np.array
        coordinates of the whole system. The first ``nha`` rows are host
        atoms; the rest are ligand atoms.

    nha: int
        Number of host atoms

    core_atoms: list of int
        atoms we're restraining. NOTE(review): earlier documentation said these
        are indexed by the total number of atoms in the system, but they are
        compared against ligand-local indices (0 = first ligand atom). Confirm
        which convention callers use.

    backbone_atoms: list of int
        host atom indices that may serve as restraint partners

    stage: 0, 1 or 2
        0 - attach restraint
        1 - decouple
        2 - detach restraint

    Returns
    -------
    ('Restraint', (bond_idxs, bond_params, lambda_flags))
        bond_idxs: int32 [B, 2] rows of (ligand atom, host atom) system indices
        bond_params: float64 [B, 3] rows of (k, current distance, alpha)
        lambda_flags: int32 [B]; 1 for stages 0 and 2, 0 for stage 1

    Raises
    ------
    Exception
        if a core atom has fewer than ``count`` backbone atoms to pair with
    ValueError
        if ``stage`` is not 0, 1 or 2
    """
    ri = np.expand_dims(conf, axis=0)
    rj = np.expand_dims(conf, axis=1)
    dij = jax_utils.distance(ri, rj)

    # build O(1) lookups once instead of scanning the lists inside the loops
    core_set = {int(a) for a in core_atoms}
    backbone_set = {int(a) for a in backbone_atoms}

    bond_idxs = []
    bond_params = []

    for l_idx, dists in enumerate(dij[nha:]):
        if l_idx not in core_set:
            continue
        # restrain to count nearby backbone atoms to enable
        # side-chain sampling
        counter = 0
        for p_idx in np.argsort(dists[:nha]):
            if counter == count:
                break
            p_idx = int(p_idx)
            if p_idx in backbone_set:
                bond_params.append((k, dists[p_idx], alpha))
                bond_idxs.append([l_idx + nha, p_idx])
                counter += 1

        # corner case where we haven't found sufficient candidates
        if counter != count:
            raise Exception(
                f"Failed to find {count} backbone neighbors for atom {l_idx + nha}")

    # reshape keeps the 2-D shapes even when no restraint was created
    bond_idxs = np.array(bond_idxs, dtype=np.int32).reshape(-1, 2)
    bond_params = np.array(bond_params, dtype=np.float64).reshape(-1, 3)
    B = bond_idxs.shape[0]

    # w = lambda*lambda_flags
    # w = 0 implies that restraints are on
    # w = +inf/-inf implies that restraints are off
    if stage in (0, 2):
        lambda_flags = np.ones(B, dtype=np.int32)
    elif stage == 1:
        # fully interacting
        lambda_flags = np.zeros(B, dtype=np.int32)
    else:
        raise ValueError(f"stage must be 0, 1 or 2, got {stage!r}")

    return ('Restraint', (bond_idxs, bond_params, lambda_flags))
def lennard_jones(conf, params, box, param_idxs, scale_matrix, cutoff=None):
    """
    Implements a LJ612 potential using the Lorentz-Berthelot combining
    rules, where sig_ij = (sig_i + sig_j)/2 and eps_ij = sqrt(eps_i * eps_j).

    Parameters
    ----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    params: shape [num_params,] np.array
        unique parameters

    box: shape [3, 3] np.array
        periodic boundary vectors, if not None. It is forwarded unchanged to
        ``distance``; any periodic handling happens there.

    param_idxs: shape [num_atoms, 2] np.array
        each row (sig_idx, eps_idx) indexes into params for that atom's
        sigma and epsilon

    scale_matrix: shape [num_atoms, num_atoms] np.array
        scale mask denoting how we should scale interaction e[i,j].
        The elements should be between [0, 1]. If e[i,j] is 1 then the interaction
        is fully included; 0 implies it is discarded. The diagonal should be 0,
        otherwise the self-pair divides by a zero distance.

    cutoff: float or None
        if not None, interactions with d_ij >= cutoff are fully discarded.

    Returns
    -------
    float
        total LJ energy. The symmetric N x N sum is halved so that each pair
        counts once.
    """
    sig = params[param_idxs[:, 0]]
    eps = params[param_idxs[:, 1]]

    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = (sig_i + sig_j)/2

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)
    # scaling eps scales the whole pair energy linearly
    eps_ij = scale_matrix * np.sqrt(eps_i * eps_j)

    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)

    dij = distance(ri, rj, box)

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    keep_mask = scale_matrix > 0

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, np.zeros_like(sig_ij))
    eps_ij = np.where(keep_mask, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij/dij
    sig2 *= sig2
    sig6 = sig2*sig2*sig2

    # (sig6 - 1) * sig6 == (sig/d)^12 - (sig/d)^6
    energy = 4*eps_ij*(sig6-1.0)*sig6
    energy = np.where(keep_mask, energy, np.zeros_like(energy))

    # divide by two to deal with symmetry
    return np.sum(energy)/2
def lennard_jones(conf, lj_params, cutoff, groups=None):
    """
    Implements a non-periodic LJ612 potential using the Lorentz-Berthelot
    combining rules, where sig_ij = (sig_i + sig_j)/2 and
    eps_ij = sqrt(eps_i * eps_j).

    Parameters
    ----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    lj_params: shape [num_atoms, 2] np.array
        per-atom (sig, eps) rows

    cutoff: float or None
        if not None, pairs with d_ij >= cutoff are fully discarded

    groups: shape [num_atoms,] integer np.array
        per-atom bitmasks. The pair mask (groups[i] & groups[j]) > 0 is
        forwarded to ``distance``. Despite the ``None`` default, this argument
        is required: the bitwise AND cannot be computed on ``None``.

    Returns
    -------
    float
        total LJ energy, halved so that each unordered pair counts once

    Raises
    ------
    TypeError
        if ``groups`` is None
    """
    if groups is None:
        # np.bitwise_and would fail on None anyway; fail early with a clear message
        raise TypeError("lennard_jones requires a per-atom `groups` array")

    sig = lj_params[:, 0]
    eps = lj_params[:, 1]

    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = (sig_i + sig_j)/2

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)
    eps_ij = np.sqrt(eps_i * eps_j)

    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    gi = np.expand_dims(groups, axis=0)
    gj = np.expand_dims(groups, axis=1)
    gij = np.bitwise_and(gi, gj) > 0

    # box is always None: this variant is non-periodic
    dij = distance(ri, rj, None, gij)

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    N = conf.shape[0]
    keep_mask = np.ones((N, N)) - np.eye(N)

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, np.zeros_like(sig_ij))
    eps_ij = np.where(keep_mask, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij/dij
    sig2 *= sig2
    sig6 = sig2*sig2*sig2

    # (sig6 - 1) * sig6 == (sig/d)^12 - (sig/d)^6
    eij = 4*eps_ij*(sig6-1.0)*sig6
    eij = np.where(keep_mask, eij, np.zeros_like(eij))

    # the N x N matrix counts each pair twice
    return np.sum(eij/2)
def nonbonded_v3(
    conf,
    params,
    box,
    lamb,
    charge_rescale_mask,
    lj_rescale_mask,
    beta,
    cutoff,
    lambda_plane_idxs,
    lambda_offset_idxs,
    runtime_validate=True,
):
    """Lennard-Jones + Coulomb, with a few important twists:
    * distances are computed in 4D, controlled by lambda, lambda_plane_idxs, lambda_offset_idxs
    * each pairwise LJ and Coulomb term can be multiplied by an adjustable rescale_mask parameter
    * Coulomb terms are multiplied by erfc(beta * distance)

    Parameters
    ----------
    conf : (N, 3) or (N, 4) np.array
        3D or 4D coordinates.
        If 3D, they are converted to 4D using (x,y,z) -> (x,y,z,w)
            where w = cutoff * (lambda_plane_idxs + lambda_offset_idxs * lamb)

    params : (N, 3) np.array
        columns [charges, sigmas, epsilons], one row per particle.
        NOTE(review): pairs are combined as sig_i + sig_j and eps_i * eps_j
        (no /2, no sqrt), so the stored values are presumably pre-halved
        sigmas and sqrt-epsilons. Confirm.

    box : Optional 3x3 (or 4x4) np.array
        A 3x3 box is embedded in a 4x4 box whose 4th dimension has length
        1000, so the extra dimension is effectively aperiodic. A box whose
        last dimension is not 3 is used as-is.

    lamb : float

    charge_rescale_mask : (N, N) np.array
        the Coulomb contribution of pair (i,j) will be multiplied by charge_rescale_mask[i,j]

    lj_rescale_mask : (N, N) np.array
        the Lennard-Jones contribution of pair (i,j) will be multiplied by lj_rescale_mask[i,j]

    beta : float
        the charge product q_ij will be multiplied by erfc(beta*d_ij)

    cutoff : Optional float
        a pair of particles (i,j) will be considered non-interacting if the distance d_ij
        between their 4D coordinates exceeds cutoff (for LJ, pairs at exactly d_ij == cutoff
        are also dropped)

    lambda_plane_idxs : Optional (N,) np.array

    lambda_offset_idxs : Optional (N,) np.array

    runtime_validate: bool
        If True, assert that both rescale masks are symmetric, and check whether beta is
        compatible with cutoff (via validate_coulomb_cutoff, defined elsewhere).
        If True, this function will currently not play nice with Jax JIT.
        TODO: is there a way to conditionally print a runtime warning inside
            of a Jax JIT-compiled function, without triggering a Jax ConcretizationTypeError?

    Returns
    -------
    energy : float
        half the sum of the symmetric N x N pair-energy matrix

    References
    ----------
    * Rodinger, Howell, Pomès, 2005, J. Chem. Phys. "Absolute free energy calculations
      by thermodynamic integration in four spatial dimensions"
        https://aip.scitation.org/doi/abs/10.1063/1.1946750
    * Darden, York, Pedersen, 1993, J. Chem. Phys. "Particle mesh Ewald: An N log(N)
      method for Ewald sums in large systems"
        https://aip.scitation.org/doi/abs/10.1063/1.470117
        * Coulomb interactions are treated using the direct-space contribution from eq 2
    """
    if runtime_validate:
        assert (charge_rescale_mask == charge_rescale_mask.T).all()
        assert (lj_rescale_mask == lj_rescale_mask.T).all()

    N = conf.shape[0]

    if conf.shape[-1] == 3:
        conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff)

    # make 4th dimension of box large enough so its roughly aperiodic
    if box is not None:
        if box.shape[-1] == 3:
            box_4d = np.eye(4) * 1000
            box_4d = index_update(box_4d, index[:3, :3], box)
        else:
            box_4d = box
    else:
        box_4d = None

    box = box_4d

    charges = params[:, 0]
    sig = params[:, 1]
    eps = params[:, 2]

    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = sig_i + sig_j

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)

    eps_ij = eps_i * eps_j

    dij = distance(conf, box)

    # LJ mask: off-diagonal pairs whose eps product is non-zero
    keep_mask = np.ones((N, N)) - np.eye(N)
    keep_mask = np.where(eps_ij != 0, keep_mask, 0)

    if cutoff is not None:
        if runtime_validate:
            validate_coulomb_cutoff(cutoff, beta, threshold=1e-2)
        eps_ij = np.where(dij < cutoff, eps_ij, 0)

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, 0)
    eps_ij = np.where(keep_mask, eps_ij, 0)

    # 1/dij is infinite on the diagonal (dij == 0); replace those entries by 0
    inv_dij = 1 / dij
    inv_dij = np.where(np.eye(N), 0, inv_dij)

    sig2 = sig_ij * inv_dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    # (sig6 - 1) * sig6 == (sig/d)^12 - (sig/d)^6
    eij_lj = 4 * eps_ij * (sig6 - 1.0) * sig6
    eij_lj = np.where(keep_mask, eij_lj, 0)

    qi = np.expand_dims(charges, 0)  # (1, N)
    qj = np.expand_dims(charges, 1)  # (N, 1)
    qij = np.multiply(qi, qj)

    # (ytz): trick used to avoid nans in the diagonal due to the 1/dij term.
    # Coulomb mask: every off-diagonal pair (independent of eps).
    keep_mask = 1 - np.eye(N)
    qij = np.where(keep_mask, qij, 0)
    dij = np.where(keep_mask, dij, 0)

    # erfc(x)/x diverges as x -> 0 (erfc(0) == 1), so the diagonal must not
    # contribute: inv_dij is already 0 there, and keep_mask zeroes it again.
    eij_charge = np.where(keep_mask, qij * erfc(beta * dij) * inv_dij, 0)  # zero out diagonals
    if cutoff is not None:
        eij_charge = np.where(dij > cutoff, 0, eij_charge)

    eij_total = eij_lj * lj_rescale_mask + eij_charge * charge_rescale_mask

    # the symmetric N x N matrix counts every pair twice
    return np.sum(eij_total / 2)
def nonbonded_v3(conf, params, box, lamb, charge_rescale_mask, lj_rescale_mask, scales, beta, cutoff, lambda_plane_idxs, lambda_offset_idxs):
    """
    Lennard-Jones + erfc-damped Coulomb energy, evaluated on 4D coordinates.

    Coordinates are always lifted to 4D with ``convert_to_4d`` (using lamb,
    lambda_plane_idxs, lambda_offset_idxs and cutoff). If a box is given, it
    is embedded in a 4x4 box whose 4th dimension has length 1000, so the
    extra dimension is effectively aperiodic.

    Parameters
    ----------
    conf: np.array
        coordinates passed to ``convert_to_4d`` (presumably [N, 3])
    params: np.array [N, 3]
        columns [charges, sigmas, epsilons]. NOTE(review): pairs are combined
        as sig_i + sig_j and eps_i * eps_j (no /2, no sqrt), so the stored
        values are presumably pre-halved sigmas and sqrt-epsilons. Confirm.
    box: 3x3 np.array or None
    lamb: float
    charge_rescale_mask: np.array [N, N]
        multiplies each Coulomb pair term
    lj_rescale_mask: np.array [N, N]
        multiplies each LJ pair term
    scales:
        unused; kept for interface compatibility
    beta: float
        Coulomb pair terms are multiplied by erfc(beta * d_ij)
    cutoff: float or None
        LJ terms with d_ij >= cutoff and Coulomb terms with d_ij > cutoff are dropped
    lambda_plane_idxs, lambda_offset_idxs:
        forwarded to ``convert_to_4d``

    Returns
    -------
    float
        half the sum of the symmetric N x N pair-energy matrix
    """
    conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff)
    N = conf.shape[0]

    # make 4th dimension of box large enough so its roughly aperiodic
    if box is not None:
        box_4d = np.eye(4) * 1000
        box_4d = index_update(box_4d, index[:3, :3], box)
    else:
        box_4d = None

    box = box_4d

    charges = params[:, 0]
    sig = params[:, 1]
    eps = params[:, 2]

    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = sig_i + sig_j

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)
    eps_ij = eps_i * eps_j

    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj, box)

    # LJ mask: off-diagonal pairs whose eps product is non-zero
    keep_mask = np.ones((N, N)) - np.eye(N)
    keep_mask = np.where(eps_ij != 0, keep_mask, 0)

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, np.zeros_like(sig_ij))
    eps_ij = np.where(keep_mask, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij / dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    # (sig6 - 1) * sig6 == (sig/d)^12 - (sig/d)^6
    eij_lj = 4 * eps_ij * (sig6 - 1.0) * sig6
    eij_lj = np.where(keep_mask, eij_lj, np.zeros_like(eij_lj))

    qi = np.expand_dims(charges, 0)  # (1, N)
    qj = np.expand_dims(charges, 1)  # (N, 1)
    qij = np.multiply(qi, qj)

    # (ytz): trick used to avoid nans in the diagonal due to the 1/dij term.
    # Coulomb mask: every off-diagonal pair (independent of eps).
    keep_mask = 1 - np.eye(N)
    qij = np.where(keep_mask, qij, np.zeros_like(qij))
    dij = np.where(keep_mask, dij, np.zeros_like(dij))

    # erfc(x)/x diverges as x -> 0 (erfc(0) == 1); the diagonal's 0/0 is
    # discarded by keep_mask.
    eij_charge = np.where(keep_mask, qij * erfc(beta * dij) / dij, np.zeros_like(dij))  # zero out diagonals
    if cutoff is not None:
        eij_charge = np.where(dij > cutoff, np.zeros_like(eij_charge), eij_charge)

    eij_total = (eij_lj * lj_rescale_mask + eij_charge * charge_rescale_mask)

    # the symmetric N x N matrix counts every pair twice
    return np.sum(eij_total / 2)
def born_radii(conf, atomic_radii, scaled_radius_factor, dielectric_offset, alpha_obc, beta_obc, gamma_obc):
    """
    Compute the adjusted (OBC) Born radius of each atom. This is the first
    part of the GBSA calculation.

    For each atom i, the descreening sum

        psi_i = 0.5 * or_i * sum_j I_ij,    or = atomic_radii - dielectric_offset

    is accumulated over atoms j whose scaled sphere reaches i. The Born radius is

        B_i = 1 / (1/or_i - tanh(alpha*psi_i - beta*psi_i^2 + gamma*psi_i^3) / atomic_radii_i)

    Parameters
    ----------
    conf: np.array
        shape Nx3 matrix of geometric coordinates

    atomic_radii: np.array
        shape [N,] array of radius of each atom

    scaled_radius_factor: np.array
        shape [N,] array of adjusted shape factors for each atom.

    dielectric_offset: float
        subtracted from every atomic radius before use

    alpha_obc, beta_obc, gamma_obc: float
        coefficients of psi, psi^2 and psi^3 inside the tanh

    Returns
    -------
    np.array
        shape [N,] array of Born radii

    Notes
    -----
    NOTE(review): many OBC implementations add an extra 2*(1/or_i - 1/L_ij)
    term when atom i lies entirely inside atom j's scaled sphere
    (or_i < s_j - d_ij). This function has no such term; confirm that the
    omission is intended.
    """
    r_i = np.expand_dims(conf, axis=0)
    r_j = np.expand_dims(conf, axis=1)
    d_ij = distance(r_i, r_j)

    oR = atomic_radii - dielectric_offset

    oRI = np.expand_dims(oR, axis=1)  # rows: atom i
    oRJ = np.expand_dims(oR, axis=0)  # columns: atom j
    sRJ = oRJ * scaled_radius_factor  # scaled offset radius of atom j
    rSRJ = d_ij + sRJ

    # Only pairs where j's scaled sphere reaches i (oR_i < d_ij + s_j) contribute.
    # On the diagonal d_ij == 0, so the mask is false there, assuming positive
    # offset radii and scaled_radius_factor <= 1.
    mask_final = np.less(oRI, rSRJ)

    d_ij_inv = 1 / d_ij
    # 1/d_ij is infinite along the diagonal (d_ii == 0), so zero it out
    keep_mask = 1 - np.eye(conf.shape[0])
    d_ij_inv = np.where(keep_mask, d_ij_inv, np.zeros_like(d_ij_inv))

    # l_ij = 1/L_ij and u_ij = 1/U_ij,
    # with L_ij = max(oR_i, |d_ij - s_j|) and U_ij = d_ij + s_j
    rfs = np.abs(d_ij - sRJ)
    l_ij = np.maximum(oRI, rfs)
    l_ij = 1 / l_ij
    u_ij = 1 / rSRJ
    l_ij2 = l_ij * l_ij
    u_ij2 = u_ij * u_ij
    ratio = np.log(u_ij / l_ij)  # == log(L_ij / U_ij)
    term = l_ij - u_ij + 0.25 * d_ij * (u_ij2 - l_ij2) + (
        0.5 * d_ij_inv * ratio) + (0.25 * sRJ * sRJ * d_ij_inv) * (l_ij2 - u_ij2)
    term_masked = np.where(mask_final, term, np.zeros_like(term))

    # psi_i = 0.5 * or_i * sum_j term_ij
    summ = np.sum(term_masked, axis=-1)
    summ *= 0.5 * oR
    sum2 = summ * summ
    sum3 = summ * sum2
    tanhSum = np.tanh(alpha_obc * summ - beta_obc * sum2 + gamma_obc * sum3)

    return 1.0 / (1.0 / oR - tanhSum / atomic_radii)