示例#1
0
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        """Apply one optimizer step to a single parameter.

        Args:
          step: current step counter; enters the update as ``sqrt(step + 1)``.
          hyper_params: object exposing ``beta``, ``weight_decay``,
            ``learning_rate``, ``eps`` and ``use_adamWStyle_weightDecay``.
          param: current parameter value.
          state: per-parameter state exposing ``initial_param``, ``grad_sum``
            and ``grad_sum_sq``.
          grad: gradient of the loss with respect to ``param``.

        Returns:
          A ``(new_param, new_state)`` pair, where ``new_state`` is a
          ``_MadgradParamState`` holding the unchanged initial parameter and
          the updated gradient accumulators.
        """
        hp = hyper_params
        decoupled_decay = hp.use_adamWStyle_weightDecay

        # Classic (coupled) weight decay is folded into the gradient itself.
        if not decoupled_decay:
            grad += hp.weight_decay * param

        # The 3/2 power makes learning_rate the actual step size rather than
        # lr / cbrt(lr); the weight grows with the square root of the step.
        step_weight = jnp.power(hp.learning_rate, 3./2.) * jnp.sqrt(step + 1)
        accumulated_grad = state.grad_sum + step_weight * grad
        accumulated_sq = state.grad_sum_sq + step_weight * lax.square(grad)

        # Step away from the *initial* parameter along the accumulated
        # gradient, scaled by the cube root of the accumulated squares.
        denominator = jnp.cbrt(accumulated_sq) + hp.eps
        target = state.initial_param - accumulated_grad / denominator

        # Momentum: interpolate between the current parameter and the target.
        new_param = hp.beta*param + (1. - hp.beta)*target

        # Decoupled (AdamW-style) weight decay acts on the parameter directly.
        if decoupled_decay:
            new_param = new_param - (
                (1. - hp.beta) * hp.learning_rate * hp.weight_decay * param)

        updated_state = _MadgradParamState(
            state.initial_param, accumulated_grad, accumulated_sq)
        return new_param, updated_state
示例#2
0
def random_points_in_sphere(radius,
                            num,
                            rng):
  """Samples `num` random points inside a sphere.

  Points are first drawn with `random_points_on_sphere` and then each one is
  shrunk towards the origin by the cube root of a uniform [0, 1) variate.

  Args:
    radius: forwarded to `random_points_on_sphere` (presumably the sphere
      radius; confirm against that helper).
    num: number of points to draw.
    rng: JAX PRNG key; it is split once, one half per random draw.

  Returns:
    The scaled coordinates, with a leading axis of length `num`.
  """
  surface_rng, radial_rng = jax.random.split(rng)

  surface_points = random_points_on_sphere(radius, num, surface_rng)
  # Cube root of a uniform variate gives the radial fraction for each point.
  radial_fraction = jnp.cbrt(jax.random.uniform(radial_rng, shape=(num,)))
  return surface_points * radial_fraction[:, jnp.newaxis]
示例#3
0
    def body_fun(state):
        """Run one iteration and return the updated loop state.

        `state` is `(u, l, iter_idx, is_unconverged, is_not_max_iteration)`;
        the two flags of the incoming state are ignored and recomputed.
        Relies on `m`, `n`, `is_hermitian`, `tol_l`, `tol_norm` and
        `max_iterations` from the enclosing scope.
        """
        u_prev, l_prev, iter_idx, _, _ = state

        # Dynamically weighted coefficients (a, b, c) derived from l.
        l_sq = l_prev**2
        dd = jnp.cbrt(4.0 * (1.0 / l_sq - 1.0) / l_sq)
        sqd = jnp.sqrt(1.0 + dd)
        radicand = 8.0 - 4.0 * dd + 8.0 * (2.0 - l_sq) / (l_sq * sqd)
        a = jnp.real(sqd + jnp.sqrt(radicand) / 2)
        b = (a - 1.0)**2 / 4.0
        c = a + b - 1.0
        coefficients = (a, b, c)

        # Advance l with the same rational map.
        l_next = l_prev * (a + b * l_sq) / (1.0 + c * l_sq)

        # The QR-based update is used when c is large, Cholesky otherwise.
        def qr_branch(v):
            return _use_qr(v, m, n, params=coefficients)

        def cholesky_branch(v):
            return _use_cholesky(v, m, n, params=coefficients)

        u_next = jax.lax.cond(
            c > 100, qr_branch, cholesky_branch, operand=u_prev)

        if is_hermitian:
            # Re-symmetrize the iterate.
            u_next = (u_next + u_next.T.conj()) / 2.0

        # Keep iterating while either l or u is still moving.
        l_still_moving = jnp.abs(1.0 - l_next) > tol_l
        u_still_moving = jnp.linalg.norm(u_next - u_prev) > tol_norm
        still_unconverged = jnp.logical_or(l_still_moving, u_still_moving)

        below_iteration_cap = iter_idx < max_iterations

        return (u_next, l_next, iter_idx + 1, still_unconverged,
                below_iteration_cap)
示例#4
0
def _qdwh(x, m, n, is_hermitian, max_iterations, eps):
    """QR-based dynamically weighted Halley iteration for polar decomposition.

    Args:
      x: input matrix.
      m: forwarded to `_use_qr` / `_use_cholesky` (presumably the row count).
      n: forwarded to `_use_qr` / `_use_cholesky` (presumably the column
        count).
      is_hermitian: Python bool; when true every iterate is re-symmetrized.
      max_iterations: cap on the number of iterations.
      eps: tolerance base; defaults to the machine epsilon of `x.dtype`.

    Returns:
      A tuple `(u, h, num_iters, is_converged)`: the refined iterate `u`, the
      symmetrized product `h = u^H x`, the value of the loop counter minus one
      at exit, and whether the loop exited with the convergence flag cleared.
    """
    if eps is None:
        eps = float(jnp.finfo(x.dtype).eps)

    # `alpha` estimates norm(x, 2) from above via sqrt(|x|_1 * |x|_inf);
    # `l` starts at eps and plays the role of a lower bound on the smallest
    # singular value of the scaled matrix.
    one_norm = jnp.linalg.norm(x, ord=1)
    inf_norm = jnp.linalg.norm(x, ord=jnp.inf)
    alpha = (jnp.sqrt(one_norm) * jnp.sqrt(inf_norm)).astype(x.dtype)
    l0 = eps

    u0 = x / alpha

    # Stopping thresholds for l and for the change in u.
    tol_l = 10.0 * eps / 2.0
    tol_norm = jnp.cbrt(tol_l)

    def cond_fun(state):
        # Continue only while unconverged AND below the iteration cap.
        still_unconverged, below_cap = state[3], state[4]
        return jnp.logical_and(still_unconverged, below_cap)

    def body_fun(state):
        u_prev, l_prev, iter_idx, _, _ = state

        # Dynamically weighted coefficients (a, b, c) derived from l.
        l_sq = l_prev**2
        dd = jnp.cbrt(4.0 * (1.0 / l_sq - 1.0) / l_sq)
        sqd = jnp.sqrt(1.0 + dd)
        radicand = 8.0 - 4.0 * dd + 8.0 * (2.0 - l_sq) / (l_sq * sqd)
        a = jnp.real(sqd + jnp.sqrt(radicand) / 2)
        b = (a - 1.0)**2 / 4.0
        c = a + b - 1.0
        coefficients = (a, b, c)

        # Advance l with the same rational map.
        l_next = l_prev * (a + b * l_sq) / (1.0 + c * l_sq)

        # The QR-based update is used when c is large, Cholesky otherwise.
        def qr_branch(v):
            return _use_qr(v, m, n, params=coefficients)

        def cholesky_branch(v):
            return _use_cholesky(v, m, n, params=coefficients)

        u_next = jax.lax.cond(
            c > 100, qr_branch, cholesky_branch, operand=u_prev)

        if is_hermitian:
            u_next = (u_next + u_next.T.conj()) / 2.0

        # Keep iterating while either l or u is still moving.
        l_still_moving = jnp.abs(1.0 - l_next) > tol_l
        u_still_moving = jnp.linalg.norm(u_next - u_prev) > tol_norm
        still_unconverged = jnp.logical_or(l_still_moving, u_still_moving)

        below_cap = iter_idx < max_iterations

        return u_next, l_next, iter_idx + 1, still_unconverged, below_cap

    # The counter starts at 1, hence the `- 1` in the returned count.
    initial_state = (u0, l0, 1, True, True)
    u, _, num_iters, is_unconverged, _ = jax.lax.while_loop(
        cond_fun=cond_fun,
        body_fun=body_fun,
        init_val=initial_state)

    # One Newton-Schulz refinement step for better accuracy.
    u = 1.5 * u - 0.5 * u @ (u.T.conj() @ u)

    # Second factor, symmetrized.
    h = u.T.conj() @ x
    h = (h + h.T.conj()) / 2.0

    # True when the loop stopped because the convergence test passed.
    is_converged = jnp.logical_not(is_unconverged)

    return u, h, num_iters - 1, is_converged
示例#5
0
def cbrt(x):
  """Element-wise cube root via `jnp.cbrt`, returned wrapped in a `JaxArray`.

  A `JaxArray` argument is unwrapped through its `.value` first; any other
  argument is handed to `jnp.cbrt` unchanged.
  """
  raw = x.value if isinstance(x, JaxArray) else x
  return JaxArray(jnp.cbrt(raw))
示例#6
0
def _qdwh(matrix, eps, maxiter):
    """Compute the unitary factor of the polar decomposition of A via QDWH.

    QDWH implements a 3rd order Pade approximation to the matrix sign
    function,

    X' = X * (aI + b X^H X)(I + c X^H X)^-1, X0 = A / ||A||_2.          (1)

    The coefficients a, b, and c are chosen dynamically based on an evolving
    estimate of the matrix condition number. Specifically,

    a = h(l), b = g(a), c = a + b - 1, h(x) = x g(x^2),
    g(x) = (a + bx) / (1 + cx)

    where l is initially a lower bound on the smallest singular value of X0,
    and subsequently evolves according to l' = l (a + bl^2) / (1 + c l^2).

    For poorly conditioned matrices (c > 100) the iteration (1) is rewritten
    in QR form,

    X' = (b / c) X + (1 / c)(a - b/c) Q1 Q2^H,   [Q1] R = [sqrt(c) X]   (2)
                                                 [Q2]     [I        ].

    For well-conditioned matrices it is instead formulated using cheaper
    Cholesky iterations,

    X' = (b / c) X + (a - b/c) (X W^-1) W^-H,   W = chol(I + c X^H X).  (3)

    The QR iterations rapidly improve the condition number, and typically
    only 1 or 2 are required. A maximum of 6 iterations total are required
    for backwards stability to double precision.

    A matrix with more columns than rows is transposed on entry and the
    result is transposed back on exit.

    Args:
      matrix: The m x n input matrix.
      eps: Tolerance base; defaults to the machine epsilon of the working
        dtype. The helpers receive `5 * eps` as the tolerance on l and
        `cbrt(5 * eps)` as the tolerance on the change between iterates.
      maxiter: Iteration cap handed to the helpers; also the length of the
        convergence history.
    Returns:
      matrix: The unitary factor (m x n).
      j_qr: The number of QR iterations (2).
      j_chol: The number of Cholesky iterations (3).
      errs: Convergence history.
    """
    n_rows, n_cols = matrix.shape
    # Work on the tall orientation; undo the transpose before returning.
    transposed = n_cols > n_rows
    work = matrix.T if transposed else matrix
    work, q_factor, l0 = _initialize_qdwh(work)

    if eps is None:
        eps = jnp.finfo(work.dtype).eps
    tol_lk = 5 * eps  # stop when lk differs from 1 by less
    tol_delta = jnp.cbrt(tol_lk)  # stop when the iterates change by less

    coefs = _qdwh_coefs(l0)
    errs = jnp.zeros(maxiter, dtype=work.real.dtype)

    # QR iterations run first; the Cholesky phase picks up from the
    # coefficients and iteration count the QR phase left behind.
    work, j_qr, coefs, errs = _qdwh_qr(
        work, coefs, errs, tol_lk, tol_delta, maxiter)
    work, j_chol, errs = _qdwh_cholesky(
        work, coefs, errs, tol_lk, tol_delta, j_qr, maxiter)

    unitary = _dot(q_factor, work)
    if transposed:
        unitary = unitary.T
    return unitary, j_qr, j_chol, errs
示例#7
0
def d3(
    config: "D3Configuration",
    charges_: List[float],
    *coordinates: float,
):
    """Evaluate the D3 dispersion correction for a set of atoms.

    Implements the two-body R^-6 and R^-8 terms with either zero damping
    (``config.damp == "zero"``) or Becke-Johnson damping
    (``config.damp == "bj"``), plus the optional Axilrod-Teller-Muto
    three-body term. Optional switches can ignore or scale selected pair
    terms; without them the pair ``scalefactor`` is 1, which corresponds to
    the standard D3 terms.

    Args:
        config: settings object. The attributes read here are ``damp``,
            ``s6``, ``s8``, ``rs6`` (zero damping), ``a1``/``a2`` (BJ
            damping), ``threebody``, ``pairwise``, ``intermolecular`` and
            ``bond_index``.
        charges_: one entry per atom, 1-based; used only (after subtracting
            one) as an index into the parameter tables.
        *coordinates: flat sequence ``x0, y0, z0, x1, y1, z1, ...`` with
            three entries per atom. Units are not converted here
            (presumably they must match the parameter tables - confirm
            with the caller).

    Returns:
        The sum of the R^-6 term, the R^-8 term and the three-body term
        (the latter is zero unless ``config.threebody`` is set).

    Raises:
        RuntimeError: if ``config.damp`` is neither "zero" nor "bj"
            (compared case-insensitively) and at least one atom pair exists.
        ValueError: if ``config.intermolecular`` is set without
            ``config.bond_index`` and at least one atom pair exists.
    """
    # Running totals of the attractive R^-6 / R^-8 van der Waals terms and
    # of the repulsive Axilrod-Teller-Muto three-body term.
    attractive_r6_vdw = 0.0
    attractive_r8_vdw = 0.0
    repulsive_abc = 0.0

    # Radius scaling used by the R^-8 zero-damping function (the R^-6
    # counterpart is config.rs6); presumably fixed to 1 as in standard D3(0).
    rs8 = 1.0

    natom = len(charges_)
    # The charges are used ONLY for indexing, hence the shift to 0-based.
    charges = [x - 1 for x in charges_]

    # Fragment label per atom: 1 when the atom is listed by getMollist,
    # 0 otherwise. Only available when a bond index is supplied.
    mols = None
    if config.bond_index is not None:
        mol_a_atoms = getMollist(config.bond_index, 0)
        mols = [
            int(any(atom == j for atom in mol_a_atoms)) for j in range(natom)
        ]

    intermolecular_only = bool(config.intermolecular)
    if intermolecular_only and mols is None and natom > 1:
        # Previously this surfaced as a NameError inside the pair loop.
        raise ValueError(
            "config.intermolecular requires config.bond_index to be set")

    # Number of populated reference entries per element. The table is only
    # scanned when at least one atom carries a valid (non-negative) index.
    mxc = [0] * (MAX_ELEMENTS + 1)
    if any(charge > -1 for charge in charges):
        for j in range(MAX_ELEMENTS):
            for l in range(MAX_CONNECTIVITY):
                entry = C6AB[j][j][l][l]
                if isinstance(entry, (list, tuple)) and entry[0] > 0:
                    mxc[j] = mxc[j] + 1

    # Coordination number based on covalent radii
    cn = ncoord(charges, coordinates)

    # Per-pair data for the three-body term, keyed by the linear pair index
    # returned by lin(). Dicts avoid a fixed-size scratch buffer.
    icomp = {}
    cc6ab = {}
    r2ab = {}
    dmp = {}

    damp_scheme = config.damp.casefold()

    # Toggle that could be used to 'switch off' dispersion between bonded or
    # geminal atoms; it is disabled, so the block guarded by it never runs.
    scaling = False

    for j in range(natom):
        for k in range(j + 1, natom):
            scalefactor = 1.0

            if intermolecular_only and mols[j] == mols[k]:
                scalefactor = 0
                print(
                    f"   --- Ignoring interaction between atoms {j+1} and {k+1}"
                )

            if scaling and config.bond_index is not None:
                if config.bond_index[j][k] == 1:
                    scalefactor = 0
                for l in range(natom):
                    if (config.bond_index[j][l] != 0
                            and config.bond_index[k][l] != 0 and j != k
                            and config.bond_index[j][k] == 0):
                        scalefactor = 0
                    for m in range(natom):
                        if (config.bond_index[j][l] != 0
                                and config.bond_index[l][m] != 0
                                and config.bond_index[k][m] != 0 and j != m
                                and k != l and config.bond_index[j][m] == 0):
                            scalefactor = 1 / 1.2

            # Interatomic distance.
            dist = jnp.sqrt(
                (coordinates[3 * j] - coordinates[3 * k])**2 +
                (coordinates[3 * j + 1] - coordinates[3 * k + 1])**2 +
                (coordinates[3 * j + 2] - coordinates[3 * k + 2])**2)

            C6jk = getc6(C6AB, mxc, charges, cn, j, k)

            # C8 parameters depend on C6 recursively
            atomA = int(charges[j])
            atomB = int(charges[k])
            C8jk = 3.0 * C6jk * R2R4[atomA] * R2R4[atomB]

            # Ratio of the tabulated pair radius to the actual distance. It
            # drives the zero-damping functions and the three-body damping.
            rr = RAB[atomA][atomB] / dist

            # Evaluation of the attractive term dependent on R^-6 and R^-8
            if damp_scheme == "zero".casefold():
                tmp1 = config.rs6 * rr
                damp6 = 1 / (1 + 6 * jnp.power(tmp1, ALPHA6))
                tmp2 = rs8 * rr
                damp8 = 1 / (1 + 6 * jnp.power(tmp2, ALPHA8))

                attractive_r6_term = (-config.s6 * C6jk * damp6 /
                                      jnp.power(dist, 6) * scalefactor)
                attractive_r8_term = (-config.s8 * C8jk * damp8 /
                                      jnp.power(dist, 8) * scalefactor)
            elif damp_scheme == "bj".casefold():
                # The BJ radius gets its own name so that `rr` keeps the
                # radius/distance ratio needed by the three-body damping
                # below (it used to be overwritten here).
                bj_radius = jnp.sqrt(C8jk / C6jk)
                tmp1 = config.a1 * bj_radius + config.a2
                damp6 = jnp.power(tmp1, 6)
                damp8 = jnp.power(tmp1, 8)

                attractive_r6_term = (-config.s6 * C6jk /
                                      (jnp.power(dist, 6) + damp6) *
                                      scalefactor)
                attractive_r8_term = (-config.s8 * C8jk /
                                      (jnp.power(dist, 8) + damp8) *
                                      scalefactor)
            else:
                raise RuntimeError(
                    f"{config.damp} is an unknown damping scheme.")

            if config.pairwise and scalefactor != 0:
                print(
                    f"   --- Pairwise interaction between atoms {j+1} and {k+1}: Edisp = {attractive_r6_term+attractive_r8_term:.6f} kcal/mol",
                )

            attractive_r6_vdw += attractive_r6_term
            attractive_r8_vdw += attractive_r8_term

            if config.threebody:
                # Remember what the three-body pass needs for this pair.
                jk = int(lin(k, j))
                icomp[jk] = 1
                cc6ab[jk] = jnp.sqrt(C6jk)
                r2ab[jk] = jnp.power(dist, 2)
                dmp[jk] = jnp.cbrt(1.0 / rr)

    if config.threebody:
        e63 = 0.0
        # Visit every unordered triple iat < jat < kat exactly once.
        for iat in range(natom):
            for jat in range(iat + 1, natom):
                ij = int(lin(jat, iat))
                if icomp.get(ij, 0) != 1:
                    continue
                for kat in range(jat + 1, natom):
                    ik = int(lin(kat, iat))
                    jk = int(lin(kat, jat))
                    if icomp.get(ik, 0) == 0 or icomp.get(jk, 0) == 0:
                        continue

                    # Zero-damping factor built from the three pair ratios.
                    rav = (4.0 / 3.0) / (dmp[ik] * dmp[jk] * dmp[ij])
                    tmp = 1.0 / (1.0 + 6.0 * rav**ALPHA6)

                    c9 = cc6ab[ij] * cc6ab[ik] * cc6ab[jk]
                    d2 = [
                        r2ab[ij],
                        r2ab[jk],
                        r2ab[ik],
                    ]
                    # Law-of-cosines terms for the triangle's three angles.
                    t1 = (d2[0] + d2[1] - d2[2]) / jnp.sqrt(d2[0] * d2[1])
                    t2 = (d2[0] + d2[2] - d2[1]) / jnp.sqrt(d2[0] * d2[2])
                    t3 = (d2[2] + d2[1] - d2[0]) / jnp.sqrt(d2[1] * d2[2])
                    ang = 0.375 * t1 * t2 * t3 + 1.0
                    e63 = e63 + tmp * c9 * ang / (d2[0] * d2[1] *
                                                  d2[2])**1.50

        repulsive_abc += config.s6 * e63

    return attractive_r6_vdw + attractive_r8_vdw + repulsive_abc