def apply_param_gradient(self, step, hyper_params, param, state, grad):
    """Compute the updated value and optimizer state for one parameter.

    Args:
        step: Current step counter; ``step + 1`` enters the step weighting
            under a square root.
        hyper_params: Object exposing ``beta``, ``weight_decay``,
            ``learning_rate``, ``eps`` and the boolean
            ``use_adamWStyle_weightDecay``.
        param: Current parameter value.
        state: Per-parameter state exposing ``initial_param``, ``grad_sum``
            and ``grad_sum_sq``.
        grad: Gradient with respect to ``param``.

    Returns:
        A ``(new_param, new_state)`` pair. ``new_state`` is a
        ``_MadgradParamState`` holding the unchanged ``initial_param`` and
        the two updated accumulators.
    """
    lr = hyper_params.learning_rate
    decay = hyper_params.weight_decay
    momentum = hyper_params.beta
    decoupled_decay = hyper_params.use_adamWStyle_weightDecay

    if not decoupled_decay:
        # Coupled weight decay: the decay term is folded into the gradient,
        # so it also flows into both accumulators below.
        grad += decay * param

    # Per the original author's note, the 3/2 power is there so that
    # `learning_rate` is the actual step size rather than lr / cbrt(lr).
    step_weight = jnp.power(lr, 1.5) * jnp.sqrt(step + 1)
    accum = state.grad_sum + step_weight * grad
    accum_sq = state.grad_sum_sq + step_weight * lax.square(grad)

    # The step is always taken from the stored initial parameter, scaled by
    # the cube root of the accumulated squared gradients.
    target = state.initial_param - accum / (
        jnp.cbrt(accum_sq) + hyper_params.eps)

    # Momentum: blend the current parameter with the freshly computed target.
    updated = momentum * param + (1. - momentum) * target

    if decoupled_decay:
        # AdamW-style decay acts on the parameter directly instead of on
        # the gradient.
        updated -= (1. - momentum) * lr * decay * param

    return updated, _MadgradParamState(state.initial_param, accum, accum_sq)
def random_points_in_sphere(radius, num, rng):
    """Sample ``num`` random points inside a sphere.

    Points are drawn with ``random_points_on_sphere(radius, num, ...)`` and
    each one is then scaled by the cube root of an independent uniform
    sample. The cube root compensates for volume growing with the cube of
    the radius, so the scaled points are spread uniformly over the volume
    (assuming ``random_points_on_sphere`` is uniform on the surface).

    Args:
        radius: Forwarded unchanged to ``random_points_on_sphere``.
        num: Number of points to draw.
        rng: JAX PRNG key; it is split so the surface points and the scale
            factors use independent randomness.

    Returns:
        The surface points scaled per point (first axis) by the sampled
        factors.
    """
    surface_key, radial_key = jax.random.split(rng)
    points = random_points_on_sphere(radius, num, surface_key)
    # One scale factor per point.
    radial = jnp.cbrt(jax.random.uniform(radial_key, shape=(num,)))
    # Trailing axis added so each factor broadcasts over a point's coordinates.
    points *= radial[:, jnp.newaxis]
    return points
def body_fun(state):
    """Advance the iteration state by one step.

    Reads ``m``, ``n``, ``is_hermitian``, ``tol_l``, ``tol_norm`` and
    ``max_iterations`` from the enclosing scope.

    Args:
        state: Tuple ``(u, l, iter_idx, _, _)``; the last two entries are
            ignored on input.

    Returns:
        Tuple ``(u, l, iter_idx + 1, is_unconverged, is_not_max_iteration)``.
    """
    u_prev, l_prev, iter_idx, _, _ = state

    # Weights (a, b, c) derived from the current value of l.
    l_sq = l_prev**2
    dd = jnp.cbrt(4.0 * (1.0 / l_sq - 1.0) / l_sq)
    sqd = jnp.sqrt(1.0 + dd)
    a = jnp.real(
        sqd + jnp.sqrt(8.0 - 4.0 * dd + 8.0 * (2.0 - l_sq) / (l_sq * sqd)) / 2)
    b = (a - 1.0)**2 / 4.0
    c = a + b - 1.0
    weights = (a, b, c)

    # New value of l.
    l_next = l_prev * (a + b * l_sq) / (1.0 + c * l_sq)

    # QR-based update when c is large, Cholesky-based update otherwise.
    u_next = jax.lax.cond(
        c > 100,
        lambda mat: _use_qr(mat, m, n, params=weights),
        lambda mat: _use_cholesky(mat, m, n, params=weights),
        operand=u_prev)

    if is_hermitian:
        # Average with the conjugate transpose to keep the iterate Hermitian.
        u_next = (u_next + u_next.T.conj()) / 2.0

    # Keep iterating while either l is still away from 1 or the iterate is
    # still changing by more than the tolerance.
    still_iterating = jnp.logical_or(
        jnp.abs(1.0 - l_next) > tol_l,
        jnp.linalg.norm(u_next - u_prev) > tol_norm)
    under_cap = iter_idx < max_iterations

    return u_next, l_next, iter_idx + 1, still_iterating, under_cap
def _qdwh(x, m, n, is_hermitian, max_iterations, eps):
    """QR-based dynamically weighted Halley iteration for polar decomposition.

    Args:
        x: Input matrix.
        m: Forwarded to ``_use_qr`` / ``_use_cholesky`` (presumably the row
            count of ``x`` -- confirm against the caller).
        n: Forwarded to ``_use_qr`` / ``_use_cholesky`` (presumably the
            column count of ``x`` -- confirm against the caller).
        is_hermitian: When truthy, every iterate is replaced by the average
            of itself and its conjugate transpose.
        max_iterations: Cap on the number of loop iterations.
        eps: Base tolerance; ``None`` selects the machine epsilon of
            ``x.dtype``.

    Returns:
        Tuple ``(u, h, num_iters, is_converged)``: the refined iterate ``u``,
        the Hermitian part ``h`` of ``u^H x``, the number of loop iterations
        executed, and whether the loop stopped because the convergence test
        passed.
    """
    if eps is None:
        eps = float(jnp.finfo(x.dtype).eps)

    # `alpha` is an estimate of norm(x, 2) with alpha >= norm(x, 2), built
    # from the 1-norm and the inf-norm. Together with `l` it gives
    # beta = alpha * l, a lower bound for the smallest singular value of x.
    norm_one = jnp.linalg.norm(x, ord=1)
    norm_inf = jnp.linalg.norm(x, ord=jnp.inf)
    alpha = (jnp.sqrt(norm_one) * jnp.sqrt(norm_inf)).astype(x.dtype)
    u = x / alpha
    l = eps

    # Iteration tolerances.
    tol_l = 10.0 * eps / 2.0
    tol_norm = jnp.cbrt(tol_l)

    def keep_going(carry):
        # Continue only while unconverged AND below the iteration cap.
        return jnp.logical_and(carry[3], carry[4])

    def halley_step(carry):
        u_prev, l_prev, iter_idx, _, _ = carry

        # Weights (a, b, c) derived from the current value of l.
        l_sq = l_prev**2
        dd = jnp.cbrt(4.0 * (1.0 / l_sq - 1.0) / l_sq)
        sqd = jnp.sqrt(1.0 + dd)
        a = jnp.real(
            sqd
            + jnp.sqrt(8.0 - 4.0 * dd + 8.0 * (2.0 - l_sq) / (l_sq * sqd)) / 2)
        b = (a - 1.0)**2 / 4.0
        c = a + b - 1.0
        weights = (a, b, c)

        # New value of l.
        l_next = l_prev * (a + b * l_sq) / (1.0 + c * l_sq)

        # QR-based update when c is large, Cholesky-based update otherwise.
        def qr_branch(mat):
            return _use_qr(mat, m, n, params=weights)

        def cholesky_branch(mat):
            return _use_cholesky(mat, m, n, params=weights)

        u_next = jax.lax.cond(
            c > 100, qr_branch, cholesky_branch, operand=u_prev)

        if is_hermitian:
            u_next = (u_next + u_next.T.conj()) / 2.0

        # Convergence test on both l and the change in the iterate.
        still_iterating = jnp.logical_or(
            jnp.abs(1.0 - l_next) > tol_l,
            jnp.linalg.norm(u_next - u_prev) > tol_norm)
        under_cap = iter_idx < max_iterations

        return u_next, l_next, iter_idx + 1, still_iterating, under_cap

    # The counter starts at 1; both flags start True so the loop is entered.
    u, _, num_iters, is_unconverged, _ = jax.lax.while_loop(
        cond_fun=keep_going,
        body_fun=halley_step,
        init_val=(u, l, 1, True, True))

    # Applies Newton-Schulz refinement for better accuracy.
    u = 1.5 * u - 0.5 * u @ (u.T.conj() @ u)

    # Hermitian factor, symmetrized explicitly.
    h = u.T.conj() @ x
    h = (h + h.T.conj()) / 2.0

    # The counter started at 1, so subtract 1 to report executed iterations.
    return u, h, num_iters - 1, jnp.logical_not(is_unconverged)
def cbrt(x):
    """Return the cube root of ``x`` wrapped in a ``JaxArray``.

    Args:
        x: Either a ``JaxArray`` (its ``.value`` is used) or anything
            ``jnp.cbrt`` accepts directly.

    Returns:
        A new ``JaxArray`` built from ``jnp.cbrt`` of the underlying value.
    """
    raw = x.value if isinstance(x, JaxArray) else x
    return JaxArray(jnp.cbrt(raw))
def _qdwh(matrix, eps, maxiter):
    """Compute the unitary polar factor of ``matrix`` with the QDWH method.

    QDWH implements a 3rd order Pade approximation to the matrix sign
    function,

        X' = X * (aI + b X^H X)(I + c X^H X)^-1,   X0 = A / ||A||_2.     (1)

    The coefficients a, b, and c are chosen dynamically based on an evolving
    estimate of the matrix condition number. Specifically,

        a = h(l),  b = g(a),  c = a + b - 1,
        h(x) = x g(x^2),  g(x) = a + bx / (1 + cx)

    where l is initially a lower bound on the smallest singular value of X0,
    and subsequently evolves according to

        l' = l (a + b l^2) / (1 + c l^2).

    For poorly conditioned matrices (c > 100) the iteration (1) is rewritten
    in QR form,

        X' = (b / c) X + (1 / c)(a - b/c) Q1 Q2^H,
        [Q1] R = [sqrt(c) X]                                             (2)
        [Q2]     [I        ].

    For well-conditioned matrices it is instead formulated using cheaper
    Cholesky iterations,

        X' = (b / c) X + (a - b/c) (X W^-1) W^-H,
        W = chol(I + c X^H X).                                           (3)

    The QR iterations rapidly improve the condition number, and typically
    only 1 or 2 are required. A maximum of 6 iterations total are required
    for backwards stability to double precision.

    A matrix with more columns than rows is transposed on entry and the
    result is transposed back before returning.

    Args:
        matrix: The m x n input matrix.
        eps: Base tolerance; ``None`` selects the machine epsilon of the
            working dtype. Two tolerances are derived from it and passed to
            both stages: ``5 * eps`` (how close l must get to 1) and
            ``cbrt(5 * eps)`` (how little successive iterates may change).
        maxiter: Iterations will terminate after this many steps even if the
            tolerances are unsatisfied; also the length of ``errs``.

    Returns:
        matrix: The unitary factor (m x n).
        j_qr: Iteration count returned by the QR stage, form (2).
        j_chol: Iteration count returned by the Cholesky stage, form (3).
        errs: Convergence history.
    """
    n_rows, n_cols = matrix.shape
    # Work on the transpose when the input has more columns than rows.
    transposed = n_cols > n_rows
    work = matrix.T if transposed else matrix

    work, q_factor, l0 = _initialize_qdwh(work)
    coefs = _qdwh_coefs(l0)

    # The default epsilon is taken from the dtype *after* initialization.
    if eps is None:
        eps = jnp.finfo(work.dtype).eps
    tol_lk = 5 * eps  # stop when lk differs from 1 by less
    tol_delta = jnp.cbrt(tol_lk)  # stop when the iterates change by less

    errs = jnp.zeros(maxiter, dtype=work.real.dtype)

    # QR stage first; its iteration count is handed on to the Cholesky stage.
    work, j_qr, coefs, errs = _qdwh_qr(
        work, coefs, errs, tol_lk, tol_delta, maxiter)
    work, j_chol, errs = _qdwh_cholesky(
        work, coefs, errs, tol_lk, tol_delta, j_qr, maxiter)

    # Re-apply the factor split off during initialization.
    work = _dot(q_factor, work)

    unitary = work.T if transposed else work
    return unitary, j_qr, j_chol, errs
def d3(
    config: "D3Configuration",
    charges_: List[float],
    *coordinates: float,
):
    """Compute the D3 dispersion correction (zero or BJ damping).

    The code has a faithful implementation of D3-zero and D3-BJ. There are
    also some optional parts that will selectively ignore or scale certain
    interatomic terms. Without these lines our 'scalefactor' is set to 1,
    which is equivalent to standard D3 terms.

    Args:
        config: Settings object. Fields read here: ``bond_index``,
            ``intermolecular``, ``damp``, ``rs6``, ``s6``, ``s8``, ``a1``,
            ``a2``, ``pairwise`` and ``threebody``.
        charges_: One entry per atom. Each value is decremented by one and
            then used purely as an index into the parameter tables
            (presumably atomic numbers -- confirm against the caller).
        *coordinates: Flat sequence of coordinates; atom ``j`` occupies
            positions ``3*j``, ``3*j + 1`` and ``3*j + 2``.

    Returns:
        The sum of the accumulated R^-6 term, the R^-8 term and the
        three-body term (the latter stays 0.0 unless ``config.threebody``).

    Raises:
        RuntimeError: If ``config.damp`` is neither "zero" nor "bj"
            (compared case-insensitively). Only raised once at least one
            atom pair is processed.
    """
    # van der Waals attractive R^-6
    attractive_r6_vdw = 0.0
    # van der Waals attractive R^-8
    attractive_r8_vdw = 0.0
    # Axilrod-Teller-Muto 3-body repulsive
    repulsive_abc = 0.0

    # Radius scaling used in the R^-8 zero-damping factor below; it is fixed
    # at 1.0 here, whereas the R^-6 counterpart comes from config.rs6.
    rs8 = 1.0

    natom = len(charges_)
    # the charges array is used ONLY for indexing, so we subtract 1 from the one we get as input
    charges = [x - 1 for x in charges_]

    # In case something clever needs to be done wrt inter and intramolecular interactions
    # mols[j] is 1 when atom j appears in the list returned by getMollist,
    # else 0.
    # NOTE(review): `mols` is only defined when bond_index is not None, but
    # it is read below whenever config.intermolecular is true -- that
    # combination would raise NameError.
    if config.bond_index is not None:
        molAatoms = getMollist(config.bond_index, 0)
        mols = []
        for j in range(natom):
            mols.append(0)
            for atom in molAatoms:
                if atom == j:
                    mols[j] = 1

    # mxc[j] counts the entries l for which C6AB[j][j][l][l] is a list/tuple
    # whose first element is positive. The list ends up with
    # MAX_ELEMENTS + 1 slots.
    mxc = [0]
    for j in range(MAX_ELEMENTS):
        mxc.append(0)
        for k in range(natom):
            # NOTE(review): this test does not involve j, so the first atom
            # with a non-negative index triggers the count for every j --
            # confirm this is intended.
            if charges[k] > -1:
                for l in range(MAX_CONNECTIVITY):
                    if isinstance(C6AB[j][j][l][l], (list, tuple)):
                        if C6AB[j][j][l][l][0] > 0:
                            mxc[j] = mxc[j] + 1
                # Count once per element: stop scanning atoms after the
                # first hit.
                break

    # Coordination number based on covalent radii
    cn = ncoord(charges, coordinates)

    # Fixed-size scratch tables for the three-body term, indexed by
    # lin(k, j).
    # NOTE(review): assumes lin() never returns an index >= 100000 -- confirm
    # for large systems.
    icomp = [0] * 100000  # 1 where a pair has been recorded
    cc6ab = [0] * 100000  # sqrt(C6) of the pair
    r2ab = [0] * 100000  # squared pair distance
    dmp = [0] * 100000  # cbrt(1 / rr) of the pair

    for j in range(natom):
        dist = 0.0
        rr = 0.0
        attractive_r6_term = 0.0
        attractive_r8_term = 0.0
        ## This could be used to 'switch off' dispersion between bonded or geminal atoms ##
        # Hard-coded off, so the bond-based rescaling block below never runs.
        scaling = False
        for k in range(j + 1, natom):
            scalefactor = 1.0

            # Optionally drop pairs that carry the same `mols` flag.
            if config.intermolecular == True:
                if mols[j] == mols[k]:
                    scalefactor = 0
                    print(
                        f" --- Ignoring interaction between atoms {j+1} and {k+1}"
                    )

            # Bond-topology rescaling: directly bonded pairs and pairs
            # sharing a bonded neighbour get 0; pairs linked through two
            # intermediate atoms get 1/1.2.
            if scaling and config.bond_index is not None:
                if config.bond_index[j][k] == 1:
                    scalefactor = 0
                for l in range(natom):
                    if (config.bond_index[j][l] != 0
                            and config.bond_index[k][l] != 0 and j != k
                            and config.bond_index[j][k] == 0):
                        scalefactor = 0
                    for m in range(natom):
                        if (config.bond_index[j][l] != 0
                                and config.bond_index[l][m] != 0
                                and config.bond_index[k][m] != 0 and j != m
                                and k != l
                                and config.bond_index[j][m] == 0):
                            scalefactor = 1 / 1.2

            # Always true here, because k starts at j + 1.
            if k > j:
                # compute distance
                totdist = jnp.sqrt(
                    (coordinates[3 * j] - coordinates[3 * k])**2 +
                    (coordinates[3 * j + 1] - coordinates[3 * k + 1])**2 +
                    (coordinates[3 * j + 2] - coordinates[3 * k + 2])**2)

                C6jk = getc6(C6AB, mxc, charges, cn, j, k)

                # C8 parameters depend on C6 recursively
                atomA = int(charges[j])
                atomB = int(charges[k])
                C8jk = 3.0 * C6jk * R2R4[atomA] * R2R4[atomB]

                # C10 parameters (unused)
                # C10jk = 49.0 / 40.0 * jnp.power(C8jk, 2) / C6jk

                # Evaluation of the attractive term dependent on R^-6 and R^-8
                if config.damp.casefold() == "zero".casefold():
                    # Zero damping: each term is multiplied by
                    # 1 / (1 + 6 * (scale * RAB / dist) ** ALPHA).
                    dist = totdist
                    rr = RAB[atomA][atomB] / dist
                    tmp1 = config.rs6 * rr
                    damp6 = 1 / (1 + 6 * jnp.power(tmp1, ALPHA6))
                    tmp2 = rs8 * rr
                    damp8 = 1 / (1 + 6 * jnp.power(tmp2, ALPHA8))

                    attractive_r6_term = (-config.s6 * C6jk * damp6 /
                                          jnp.power(dist, 6) * scalefactor)
                    attractive_r8_term = (-config.s8 * C8jk * damp8 /
                                          jnp.power(dist, 8) * scalefactor)
                elif config.damp.casefold() == "bj".casefold():
                    # BJ damping: (a1 * sqrt(C8 / C6) + a2) ** n is added to
                    # dist ** n in the denominator.
                    dist = totdist
                    # The RAB-based value is immediately replaced; the
                    # sqrt(C8 / C6) value is the one used here and by the
                    # three-body bookkeeping below.
                    rr = RAB[atomA][atomB] / dist
                    rr = jnp.sqrt(C8jk / C6jk)
                    tmp1 = config.a1 * rr + config.a2
                    damp6 = jnp.power(tmp1, 6)
                    damp8 = jnp.power(tmp1, 8)

                    attractive_r6_term = (-config.s6 * C6jk /
                                          (jnp.power(dist, 6) + damp6) *
                                          scalefactor)
                    attractive_r8_term = (-config.s8 * C8jk /
                                          (jnp.power(dist, 8) + damp8) *
                                          scalefactor)
                else:
                    raise RuntimeError(
                        f"{config.damp} is an unknown damping scheme.")

                # Optional per-pair report; skipped for pairs that were
                # scaled to zero.
                if config.pairwise and scalefactor != 0:
                    print(
                        f" --- Pairwise interaction between atoms {j+1} and {k+1}: Edisp = {attractive_r6_term+attractive_r8_term:.6f} kcal/mol",
                    )

                attractive_r6_vdw += attractive_r6_term
                attractive_r8_vdw += attractive_r8_term

                # Record the pair quantities needed by the three-body pass.
                if config.threebody:
                    jk = int(lin(k, j))
                    icomp[jk] = 1
                    cc6ab[jk] = jnp.sqrt(C6jk)
                    r2ab[jk] = jnp.power(dist, 2)
                    dmp[jk] = jnp.cbrt(1.0 / rr)

    # Three-body pass over triples iat < jat < kat whose three pairs were all
    # recorded above.
    if config.threebody:
        e63 = 0.0
        for iat in range(natom):
            for jat in range(natom):
                ij = int(lin(jat, iat))
                if icomp[ij] == 1:
                    for kat in range(jat, natom):
                        ik = int(lin(kat, iat))
                        jk = int(lin(kat, jat))

                        if (kat > jat and jat > iat and icomp[ik] != 0
                                and icomp[jk] != 0):
                            # Damping factor built from the product of the
                            # three stored cube roots.
                            rav = (4.0 / 3.0) / (dmp[ik] * dmp[jk] * dmp[ij])
                            tmp = 1.0 / (1.0 + 6.0 * rav**ALPHA6)

                            c9 = cc6ab[ij] * cc6ab[ik] * cc6ab[jk]
                            d2 = [
                                r2ab[ij],
                                r2ab[jk],
                                r2ab[ik],
                            ]
                            # t1..t3 are law-of-cosines ratios of the
                            # squared side lengths (each is twice the cosine
                            # of one triangle angle), so 0.375 * t1 * t2 * t3
                            # equals 3 * cos * cos * cos.
                            t1 = (d2[0] + d2[1] - d2[2]) / jnp.sqrt(
                                d2[0] * d2[1])
                            t2 = (d2[0] + d2[2] - d2[1]) / jnp.sqrt(
                                d2[0] * d2[2])
                            t3 = (d2[2] + d2[1] - d2[0]) / jnp.sqrt(
                                d2[1] * d2[2])
                            ang = 0.375 * t1 * t2 * t3 + 1.0
                            # The 1.5 power of the squared-distance product
                            # is the cube of the three-distance product.
                            e63 = e63 + tmp * c9 * ang / (d2[0] * d2[1] *
                                                          d2[2])**1.50

        repulsive_abc_term = config.s6 * e63
        repulsive_abc += repulsive_abc_term
    return attractive_r6_vdw + attractive_r8_vdw + repulsive_abc