Example #1
 def _sample_n_and_log_prob(self, key: PRNGKey, n: int) -> Tuple[Array, Array]:
   """See `Distribution._sample_n_and_log_prob`."""
   samples = self._sample_n(key, n)
   log_prob = jnp.zeros_like(samples)
   return samples, log_prob
Example #2
def get_grads(bijector, x, *cond):
    (out_y, out_ldj), vjp_fun = jax.vjp(bijector, x, *cond)
    vjp_fun = jax.jit(vjp_fun)
    return (vjp_fun((jnp.ones_like(out_y), jnp.zeros_like(out_ldj))),
            vjp_fun((jnp.zeros_like(out_y), jnp.ones_like(out_ldj))))
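A minimal usage sketch (with a hypothetical `scale_bijector`, written here only to match the `(y, log_det_jacobian)` output convention that `get_grads` expects):

import jax
import jax.numpy as jnp

def scale_bijector(x, log_scale):
    # hypothetical bijector: elementwise scaling, plus its summed log-det-Jacobian
    y = x * jnp.exp(log_scale)
    ldj = jnp.sum(jnp.broadcast_to(log_scale, x.shape))
    return y, ldj

x = jnp.arange(3.0)
log_scale = jnp.array(0.5)
grads_wrt_y, grads_wrt_ldj = get_grads(scale_bijector, x, log_scale)
# each result is a (d/dx, d/dlog_scale) cotangent tuple for the respective output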
Example #3
def gmres(A, b, x0=None, n=5, M=identity, record=False):
  if x0 is None:
    x0 = np.zeros_like(b)
  return _gmres(A, b, x0, n, M, record)
Example #4
 def init_param_state(self, param):
     return _AdamParamState(jnp.zeros_like(param), jnp.zeros_like(param))
Example #5
 def init(x):
     s = jnp.zeros_like(x)
     nu = jnp.zeros_like(x)
     x0 = x
     return x, s, nu, x0
Example #6
def _nan_to_inf(x):
    return jnp.where(jnp.isnan(x), jnp.inf + jnp.zeros_like(x), x)
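A quick usage sketch: replacing NaNs with +inf lets reductions such as `min` skip them.

import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 3.0])
print(_nan_to_inf(x))           # [ 1. inf  3.]
print(jnp.min(_nan_to_inf(x)))  # 1.0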
Example #7
    def predict_fn(
        get: Get = None,
        k_test_train=None,
        nngp_test_test: np.ndarray = None
    ) -> Dict[str, Union[np.ndarray, Gaussian]]:
        """`test`-set posterior given respective covariance matrices.

    Args:
      get:
        string, the mode of the Gaussian process, either "nngp" or "ntk", or a
        tuple, or `None`. If `None` then both `nngp` and `ntk` predictions are
        returned.
      k_test_train:
        test-train kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple, (c)
        `Kernel` object. Must contain the necessary `nngp` and/or `ntk` kernels
        for arguments provided to the returned `predict_fn` function. For
        example, if you request to compute posterior test [only] NTK covariance,
        `k_test_train` must contain both `ntk` and `nngp` kernels. If `None`,
        returns predictions on the training set. Note that train-set outputs are
        always `N(y_train, 0)` and mostly returned for API consistency.
      nngp_test_test:
        A test-test NNGP array. Provide if you want to compute test-test
        posterior covariance. `nngp_test_test=None` means it is not computed. If
        `k_test_train is None`, pass any non-`None` value (e.g. `True`) if you
        want to get non-regularized (`diag_reg=0`) train-train posterior
        covariance. Note that non-regularized train-set outputs will always be
        the zero-variance Gaussian `N(y_train, 0)` and mostly returned for API
        consistency. For regularized train-set posterior outputs according to a
        positive `diag_reg`, pass `k_test_train=k_train_train`, and, optionally,
        `nngp_test_test=nngp_train_train`.

    Returns:
      Either a `Gaussian('mean', 'variance')` namedtuple or `mean` of the GP
      posterior on the `test` set.
    """
        if get is None:
            get = ('nngp', 'ntk')

        out = {}

        for g in get:
            k_dd = _get_attr(k_train_train, g)
            k_td = None if k_test_train is None else _get_attr(k_test_train, g)

            if k_td is None:
                # Train set predictions.
                y = y_train.astype(k_dd.dtype)
            else:
                # Test set predictions.
                y = np.tensordot(k_td, k_inv_y(g), (odd, first))
                y = np.moveaxis(y, range(-len(trace_axes), 0), trace_axes)

            if nngp_test_test is not None:
                if k_td is None:
                    out[g] = Gaussian(y, np.zeros_like(k_dd, k_dd.dtype))
                else:
                    if (g == 'ntk' and (not hasattr(k_train_train, 'nngp')
                                        or not hasattr(k_test_train, 'nngp'))):
                        raise ValueError(
                            'If `"ntk" in get`, and `nngp_test_test is not None`, '
                            'and `k_test_train is not None`, i.e. you request the '
                            'NTK posterior covariance on the test set, you need '
                            'both NTK and NNGP train-train and test-train matrices '
                            'contained in `k_test_train` and `k_train_train`. '
                            'Hence they must be `namedtuple`s with `nngp` and '
                            '`ntk` attributes.')

                    k_td_nngp_inv_y = solve(g)(_get_attr(k_test_train, 'nngp'),
                                               even)

                    if g == 'nngp':
                        cov = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
                        cov = nngp_test_test - utils.zip_axes(cov)
                        out[g] = Gaussian(y, cov)

                    elif g == 'ntk':
                        term_1 = solve(g)(k_td, even)
                        cov = np.tensordot(_get_attr(k_train_train, 'nngp'),
                                           term_1, (odd, first))
                        cov = np.tensordot(term_1, cov, (first, first))

                        term_2 = np.tensordot(k_td, k_td_nngp_inv_y,
                                              (odd, first))
                        term_2 += np.moveaxis(term_2, first, last)
                        cov = utils.zip_axes(cov - term_2) + nngp_test_test
                        out[g] = Gaussian(y, cov)

                    else:
                        raise ValueError(g)

            else:
                out[g] = y

        return out
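A minimal usage sketch, assuming `predict_fn` was produced by its (not shown) enclosing closure and that `nngp_test_train`, `ntk_test_train`, and `nngp_test_test` are precomputed arrays (hypothetical names); per the docstring, the NTK test-set covariance needs both NNGP and NTK test-train blocks:

from collections import namedtuple

# hypothetical stand-in for the library's kernel container with `nngp`/`ntk` attributes
Kernel = namedtuple('Kernel', ['nngp', 'ntk'])

k_test_train = Kernel(nngp=nngp_test_train, ntk=ntk_test_train)
out = predict_fn(get=('ntk',),
                 k_test_train=k_test_train,
                 nngp_test_test=nngp_test_test)
mean, covariance = out['ntk']  # a `Gaussian(mean, covariance)` namedtuple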
Example #8
 def init_param_state(self, param: jnp.ndarray) -> _RMSPropParamState:
     """Initializes parameter state. See base class."""
     return _RMSPropParamState(jnp.ones_like(param), jnp.zeros_like(param))
Example #9
def lennard_jones(conf, lj_params, box, cutoff):
    """
    Implements a non-periodic LJ612 potential using the Lorentz−Berthelot combining
    rules, where sig_ij = (sig_i + sig_j)/2 and eps_ij = sqrt(eps_i * eps_j).

    Parameters
    ----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    params: shape [num_params,] np.array
        unique parameters

    box: shape [3, 3] np.array
        periodic boundary vectors, if not None

    param_idxs: shape [num_atoms, 2] np.array
        each tuple (sig, eps) is used as part of the combining rules

    scale_matrix: shape [num_atoms, num_atoms] np.array
        scale mask denoting how we should scale interaction e[i,j].
        The elements should be between [0, 1]. If e[i,j] is 1 then the interaction
        is fully included, 0 implies it is discarded.

    cutoff: float
        Whether or not we apply cutoffs to the system. Any interactions
        greater than cutoff is fully discarded.
    
    """
    # box = None
    # assert box is None

    sig = lj_params[:, 0]
    eps = lj_params[:, 1]

    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = (sig_i + sig_j) / 2
    sig_ij_raw = sig_ij

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)

    eps_ij = np.sqrt(eps_i * eps_j)

    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    # gi = np.expand_dims(groups, axis=0)
    # gj = np.expand_dims(groups, axis=1)
    # gij = np.bitwise_and(gi, gj) > 0

    # print(gij)

    # print("BOX", box)
    dij = distance(ri, rj, box)
    # print("DIJ", dij)

    N = conf.shape[0]
    keep_mask = np.ones((N, N)) - np.eye(N)
    keep_mask = np.where(eps_ij != 0, keep_mask, 0)

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, np.zeros_like(sig_ij))
    eps_ij = np.where(keep_mask, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij / dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    eij = 4 * eps_ij * (sig6 - 1.0) * sig6

    # if cutoff is not None:
    # sw = switch_fn(dij, cutoff)
    # eij = eij*sw

    eij = np.where(keep_mask, eij, np.zeros_like(eij))

    # print("eps_ij", eps_ij)
    # print("sig_ij", sig_ij)

    return np.sum(eij / 2)
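A worked sanity check under stated assumptions: the sketch lives in the same module as `lennard_jones`, uses a minimal non-periodic `distance` stand-in (the module's own helper is not shown above), and places two identical atoms at the LJ minimum r = 2**(1/6) * sig, where the pair energy should equal -eps:

import jax.numpy as np

def distance(ri, rj, box):
    # non-periodic stand-in for the module's `distance` helper (box ignored)
    return np.sqrt(np.sum((ri - rj) ** 2, axis=-1))

sig, eps = 0.3, 0.5
conf = np.array([[0.0, 0.0, 0.0],
                 [2.0 ** (1.0 / 6.0) * sig, 0.0, 0.0]])
lj_params = np.array([[sig, eps], [sig, eps]])
print(lennard_jones(conf, lj_params, box=None, cutoff=None))  # ~ -0.5 == -eps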
Example #10
 def init(x0):
   m0 = np.zeros_like(x0)
   v0 = np.zeros_like(x0)
   return x0, m0, v0
Example #11
 def init(x0):
   vs = [np.zeros(sz, dtype=x0.dtype) for sz in x0.shape]
   return x0, np.zeros_like(x0), vs
Example #12
 def init(x0):
   v0 = np.zeros_like(x0)
   return x0, v0
Example #13
def isoneutral_diffusion_pre(maskT, maskU, maskV, maskW, dxt, dxu, dyt, dyu,
                             dzt, dzw, cost, cosu, salt, temp, zt, K_iso, K_11,
                             K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by):
    """
    Isopycnal diffusion for tracer
    following functional formulation by Griffies et al
    Code adopted from MOM2.1
    """
    epsln = 1e-20
    iso_slopec = 1e-3
    iso_dslope = 1e-3
    K_iso_steep = 50.
    tau = 0

    dTdx = np.zeros_like(K_11)
    dSdx = np.zeros_like(K_11)
    dTdy = np.zeros_like(K_11)
    dSdy = np.zeros_like(K_11)
    dTdz = np.zeros_like(K_11)
    dSdz = np.zeros_like(K_11)
    """
    drho_dt and drho_ds at centers of T cells
    """
    drdT = maskT * get_drhodT(salt[:, :, :, tau], temp[:, :, :, tau],
                              np.abs(zt))
    drdS = maskT * get_drhodS(salt[:, :, :, tau], temp[:, :, :, tau],
                              np.abs(zt))
    """
    gradients at top face of T cells
    """
    dTdz = jax.ops.index_update(
        dTdz, jax.ops.index[:, :, :-1], maskW[:, :, :-1] * \
        (temp[:, :, 1:, tau] - temp[:, :, :-1, tau]) / \
        dzw[np.newaxis, np.newaxis, :-1]
    )
    dSdz = jax.ops.index_update(
        dSdz, jax.ops.index[:, :, :-1], maskW[:, :, :-1] * \
        (salt[:, :, 1:, tau] - salt[:, :, :-1, tau]) / \
        dzw[np.newaxis, np.newaxis, :-1]
    )
    """
    gradients at eastern face of T cells
    """
    dTdx = jax.ops.index_update(
        dTdx, jax.ops.index[:-1, :, :], maskU[:-1, :, :] * (temp[1:, :, :, tau] - temp[:-1, :, :, tau]) \
        / (dxu[:-1, np.newaxis, np.newaxis] * cost[np.newaxis, :, np.newaxis])
    )
    dSdx = jax.ops.index_update(
        dSdx, jax.ops.index[:-1, :, :],
        maskU[:-1, :, :] * (salt[1:, :, :, tau] - salt[:-1, :, :, tau]) /
        (dxu[:-1, np.newaxis, np.newaxis] * cost[np.newaxis, :, np.newaxis]))
    """
    gradients at northern face of T cells
    """
    dTdy = jax.ops.index_update(
        dTdy, jax.ops.index[:, :-1, :], maskV[:, :-1, :] * \
        (temp[:, 1:, :, tau] - temp[:, :-1, :, tau]) \
        / dyu[np.newaxis, :-1, np.newaxis]
    )
    dSdy = jax.ops.index_update(dSdy, jax.ops.index[:, :-1, :], maskV[:, :-1, :] * \
        (salt[:, 1:, :, tau] - salt[:, :-1, :, tau]) \
        / dyu[np.newaxis, :-1, np.newaxis]
    )

    def dm_taper(sx):
        """
        tapering function for isopycnal slopes
        """
        return 0.5 * (1. + np.tanh((-np.abs(sx) + iso_slopec) / iso_dslope))

    """
    Compute Ai_ez and K11 on center of east face of T cell.
    """
    diffloc = np.zeros_like(K_11)
    diffloc = jax.ops.index_update(
        diffloc, jax.ops.index[1:-2, 2:-2, 1:],
        0.25 * (K_iso[1:-2, 2:-2, 1:] + K_iso[1:-2, 2:-2, :-1] +
                K_iso[2:-1, 2:-2, 1:] + K_iso[2:-1, 2:-2, :-1]))
    diffloc = jax.ops.index_update(
        diffloc, jax.ops.index[1:-2, 2:-2, 0],
        0.5 * (K_iso[1:-2, 2:-2, 0] + K_iso[2:-1, 2:-2, 0]))

    sumz = np.zeros_like(K_11)[1:-2, 2:-2]
    for kr in range(2):
        ki = 0 if kr == 1 else 1
        for ip in range(2):
            drodxe = drdT[1 + ip:-2 + ip, 2:-2, ki:] * dTdx[1:-2, 2:-2, ki:] \
                + drdS[1 + ip:-2 + ip, 2:-2, ki:] * dSdx[1:-2, 2:-2, ki:]
            drodze = drdT[1 + ip:-2 + ip, 2:-2, ki:] * dTdz[1 + ip:-2 + ip, 2:-2, :-1 + kr or None] \
                + drdS[1 + ip:-2 + ip, 2:-2, ki:] * \
                dSdz[1 + ip:-2 + ip, 2:-2, :-1 + kr or None]
            sxe = -drodxe / (np.minimum(0., drodze) - epsln)
            taper = dm_taper(sxe)
            sumz = jax.ops.index_update(
                sumz, jax.ops.index[:, :, ki:], sumz[..., ki:] +
                dzw[np.newaxis, np.newaxis, :-1 + kr or None] *
                maskU[1:-2, 2:-2, ki:] *
                np.maximum(K_iso_steep, diffloc[1:-2, 2:-2, ki:] * taper))
            Ai_ez = jax.ops.index_update(
                Ai_ez, jax.ops.index[1:-2, 2:-2, ki:, ip, kr],
                taper * sxe * maskU[1:-2, 2:-2, ki:])

    K_11 = jax.ops.index_update(K_11, jax.ops.index[1:-2, 2:-2, :],
                                sumz / (4. * dzt[np.newaxis, np.newaxis, :]))
    """
    Compute Ai_nz and K_22 on center of north face of T cell.
    """
    diffloc = jax.ops.index_update(diffloc, jax.ops.index[...], 0)
    diffloc = jax.ops.index_update(
        diffloc, jax.ops.index[2:-2, 1:-2, 1:],
        0.25 * (K_iso[2:-2, 1:-2, 1:] + K_iso[2:-2, 1:-2, :-1] +
                K_iso[2:-2, 2:-1, 1:] + K_iso[2:-2, 2:-1, :-1]))
    diffloc = jax.ops.index_update(
        diffloc, jax.ops.index[2:-2, 1:-2, 0],
        0.5 * (K_iso[2:-2, 1:-2, 0] + K_iso[2:-2, 2:-1, 0]))

    sumz = np.zeros_like(K_11)[2:-2, 1:-2]
    for kr in range(2):
        ki = 0 if kr == 1 else 1
        for jp in range(2):
            drodyn = drdT[2:-2, 1 + jp:-2 + jp, ki:] * dTdy[2:-2, 1:-2, ki:] + \
                drdS[2:-2, 1 + jp:-2 + jp, ki:] * dSdy[2:-2, 1:-2, ki:]
            drodzn = drdT[2:-2, 1 + jp:-2 + jp, ki:] * dTdz[2:-2, 1 + jp:-2 + jp, :-1 + kr or None] \
                + drdS[2:-2, 1 + jp:-2 + jp, ki:] * \
                dSdz[2:-2, 1 + jp:-2 + jp, :-1 + kr or None]
            syn = -drodyn / (np.minimum(0., drodzn) - epsln)
            taper = dm_taper(syn)
            sumz = jax.ops.index_update(
                sumz, jax.ops.index[:, :, ki:], sumz[..., ki:] +
                dzw[np.newaxis, np.newaxis, :-1 + kr or None] *
                maskV[2:-2, 1:-2, ki:] *
                np.maximum(K_iso_steep, diffloc[2:-2, 1:-2, ki:] * taper))
            Ai_nz = jax.ops.index_update(
                Ai_nz, jax.ops.index[2:-2, 1:-2, ki:, jp, kr],
                taper * syn * maskV[2:-2, 1:-2, ki:])
    K_22 = jax.ops.index_update(K_22, jax.ops.index[2:-2, 1:-2, :],
                                sumz / (4. * dzt[np.newaxis, np.newaxis, :]))
    """
    compute Ai_bx, Ai_by and K33 on top face of T cell.
    """
    sumx = np.zeros_like(K_11)[2:-2, 2:-2, :-1]
    sumy = np.zeros_like(K_11)[2:-2, 2:-2, :-1]

    for kr in range(2):
        drodzb = drdT[2:-2, 2:-2, kr:-1 + kr or None] * dTdz[2:-2, 2:-2, :-1] \
            + drdS[2:-2, 2:-2, kr:-1 + kr or None] * dSdz[2:-2, 2:-2, :-1]

        # eastward slopes at the top of T cells
        for ip in range(2):
            drodxb = drdT[2:-2, 2:-2, kr:-1 + kr or None] * dTdx[1 + ip:-3 + ip, 2:-2, kr:-1 + kr or None] \
                + drdS[2:-2, 2:-2, kr:-1 + kr or None] * dSdx[1 + ip:-3 + ip, 2:-2, kr:-1 + kr or None]
            sxb = -drodxb / (np.minimum(0., drodzb) - epsln)
            taper = dm_taper(sxb)
            sumx += dxu[1 + ip:-3 + ip, np.newaxis, np.newaxis] * \
                K_iso[2:-2, 2:-2, :-1] * taper * \
                sxb**2 * maskW[2:-2, 2:-2, :-1]
            Ai_bx = jax.ops.index_update(
                Ai_bx, jax.ops.index[2:-2, 2:-2, :-1, ip, kr],
                taper * sxb * maskW[2:-2, 2:-2, :-1])

        # northward slopes at the top of T cells
        for jp in range(2):
            facty = cosu[1 + jp:-3 + jp] * dyu[1 + jp:-3 + jp]
            drodyb = drdT[2:-2, 2:-2, kr:-1 + kr or None] * dTdy[2:-2, 1 + jp:-3 + jp, kr:-1 + kr or None] \
                + drdS[2:-2, 2:-2, kr:-1 + kr or None] * dSdy[2:-2, 1 + jp:-3 + jp, kr:-1 + kr or None]
            syb = -drodyb / (np.minimum(0., drodzb) - epsln)
            taper = dm_taper(syb)
            sumy += facty[np.newaxis, :, np.newaxis] * K_iso[2:-2, 2:-2, :-1] \
                * taper * syb**2 * maskW[2:-2, 2:-2, :-1]
            Ai_by = jax.ops.index_update(
                Ai_by, jax.ops.index[2:-2, 2:-2, :-1, jp, kr],
                taper * syb * maskW[2:-2, 2:-2, :-1])

    K_33 = jax.ops.index_update(
        K_33, jax.ops.index[2:-2, 2:-2, :-1],
        sumx / (4 * dxt[2:-2, np.newaxis, np.newaxis]) + \
        sumy / (4 * dyt[np.newaxis, 2:-2, np.newaxis]
                * cost[np.newaxis, 2:-2, np.newaxis])
    )
    K_33 = jax.ops.index_update(K_33, jax.ops.index[2:-2, 2:-2, -1], 0.)

    return K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by
Example #14
#%% Data for BNN

idx = 0
X_selected = jnp.array(X_train_idx[idx * n_data:idx * n_data + n_data, :2])
y_selected = jnp.array(y_delta_train_idx[idx * n_data:idx * n_data +
                                         n_data, :]).squeeze()
X = jnp.array(X_train[:, :2])
y = jnp.array(y_delta_train)
X_test = jnp.array(X_test)
y_test = jnp.array(y_delta_test)

p_selected = X_selected[:, 0]
t_selected = X_selected[:, 1]  # / params['t_max']

F = jnp.zeros_like(y_selected)

X_train = jnp.array(X_trainc[:, :2])
p_selected = X_train[:, 0]
t_selected = X_train[:, 1]  # / params['t_max']
y_selected = jnp.array(y_delta_trainc.squeeze())
mask = jnp.array(mask_data, dtype=bool)
F = jnp.zeros_like(y_selected)

#%% Data for BNN

# X_train = torch.tensor(X_trainc[:,:2], dtype=torch.float32)
# y_train = torch.tensor(y_delta_trainc, dtype=torch.float32)
# X_test = torch.tensor(X_test, dtype=torch.float32)
# y_test = torch.tensor(y_delta_test, dtype=torch.float32)
# mask = torch.tensor(mask_data.reshape(-1,1), dtype=torch.bool)
Example #15
def loss_fn(
    output_logits,
    targets,
    valid_mask,
    num_nodes,
    captured,
    negative_example_weight = 1,
    focal_loss_gamma = 0.0,
):
  """Compute loss and single-batch metrics for some outputs.

  Args:
    output_logits: Binary logits produced by the model.
    targets: Model targets.
    valid_mask: Mask determining which outputs are valid.
    num_nodes: How many nodes there are in each example.
    captured: Ignored
    negative_example_weight: Weight to assign to a negative example when
      computing the loss. Positive examples always get weight 1.
    focal_loss_gamma: Focusing parameter for the focal loss, as described in Lin
      et al. (2018). If zero, uses standard cross-entropy loss.

  Returns:
    Tuple (loss, metrics_dict).
  """
  del captured
  num_targets = jnp.count_nonzero(targets)
  # Compute cross entropy.
  unmasked_nll = model_util.binary_logit_cross_entropy(output_logits, targets)
  if focal_loss_gamma:
    # (1-p_correct)**gamma = (-(p-1))**gamma = (-expm1(log(p)))**gamma
    focus_term = jnp.power(-jnp.expm1(-unmasked_nll), focal_loss_gamma)
    unmasked_nll = unmasked_nll * focus_term
  # Mask the results so that they only count nodes that exist.
  masked_nll = unmasked_nll * valid_mask
  # Primary loss: Sum of nll over all nodes. We use sum because most of the
  # edges are easy negatives.
  positive_nll = jnp.sum(
      jnp.where(targets, masked_nll, jnp.zeros_like(masked_nll)))
  negative_nll = jnp.sum(
      jnp.where(targets, jnp.zeros_like(masked_nll), masked_nll))
  reweighted_nll = positive_nll + negative_example_weight * negative_nll
  binary_nll = jnp.sum(reweighted_nll)
  # Compute additional metrics to track learning progress.
  # Average NLL of target edges:
  avg_nll_per_target = positive_nll / num_targets
  # Average NLL of non-target edges:
  num_non_targets = num_nodes**2 - num_targets
  avg_nll_per_non_target = negative_nll / num_non_targets
  # Max error for any edge prediction:
  worst_nll = jnp.max(masked_nll)

  loss = binary_nll

  # Ratio of positive to negative targets. If this is equal to
  # negative_example_weight, the positive and negative examples will have the
  # same total weight.
  positive_per_negative = num_targets / num_non_targets
  # Precision and recall at 0.1 threshold
  thresholded_preds = output_logits > jax.scipy.special.logit(0.1)
  count_target_pred = jnp.count_nonzero(thresholded_preds & targets)
  count_pred = jnp.count_nonzero(thresholded_preds & valid_mask.astype(bool))
  precision = count_target_pred / count_pred
  recall = count_target_pred / num_targets
  return loss, {
      "avg_per_target":
          avg_nll_per_target,
      "avg_per_non_target":
          avg_nll_per_non_target,
      "worst":
          worst_nll,
      "positive_per_negative":
          positive_per_negative,
      "effective_p_model_given_target":
          jnp.exp(-avg_nll_per_target),
      "effective_p_model_given_nontarget":
          1 - jnp.exp(-avg_nll_per_non_target),
      "batch_clf_thresh_at_0.1/precision":
          precision,
      "batch_clf_thresh_at_0.1/recall":
          recall,
      "batch_clf_thresh_at_0.1/f1":
          2 * (precision * recall) / (precision + recall),
  }
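As a side note, the focal-loss identity used in the comment inside `loss_fn` can be checked numerically; a minimal sketch:

import jax.numpy as jnp

# for nll = -log(p_correct):  (1 - p_correct)**gamma == (-expm1(-nll))**gamma
p_correct = jnp.array([0.9, 0.5, 0.1])
gamma = 2.0
nll = -jnp.log(p_correct)
print(jnp.allclose((1.0 - p_correct) ** gamma, (-jnp.expm1(-nll)) ** gamma))  # True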
Example #16
def nonbonded_v3(conf, params, box, lamb, charge_rescale_mask, lj_rescale_mask,
                 scales, beta, cutoff, lambda_plane_idxs, lambda_offset_idxs):

    N = conf.shape[0]

    conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs,
                         cutoff)

    # make 4th dimension of box large enough so its roughly aperiodic
    if box is not None:
        box_4d = np.eye(4) * 1000
        box_4d = index_update(box_4d, index[:3, :3], box)
    else:
        box_4d = None

    box = box_4d

    charges = params[:, 0]
    sig = params[:, 1]
    eps = params[:, 2]

    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = sig_i + sig_j
    sig_ij_raw = sig_ij

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)

    eps_ij = eps_i * eps_j

    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj, box)

    N = conf.shape[0]
    keep_mask = np.ones((N, N)) - np.eye(N)
    keep_mask = np.where(eps_ij != 0, keep_mask, 0)

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, np.zeros_like(sig_ij))
    eps_ij = np.where(keep_mask, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij / dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    eij_lj = 4 * eps_ij * (sig6 - 1.0) * sig6
    eij_lj = np.where(keep_mask, eij_lj, np.zeros_like(eij_lj))

    qi = np.expand_dims(charges, 0)  # (1, N)
    qj = np.expand_dims(charges, 1)  # (N, 1)
    qij = np.multiply(qi, qj)

    # (ytz): trick used to avoid nans in the diagonal due to the 1/dij term.
    keep_mask = 1 - np.eye(conf.shape[0])
    qij = np.where(keep_mask, qij, np.zeros_like(qij))
    dij = np.where(keep_mask, dij, np.zeros_like(dij))

    # the diagonal contribution is masked out, since erfc(beta * dij) / dij diverges as dij -> 0
    eij_charge = np.where(keep_mask,
                          qij * erfc(beta * dij) / dij,
                          np.zeros_like(dij))  # zero out diagonals
    if cutoff is not None:
        eij_charge = np.where(dij > cutoff, np.zeros_like(eij_charge),
                              eij_charge)

    eij_total = (eij_lj * lj_rescale_mask + eij_charge * charge_rescale_mask)

    return np.sum(eij_total / 2)
Example #17
 def new_zeros(x):
     if npyro:
         return jnp.zeros_like(x)
     else:
         return x.new_zeros(x.shape)
Example #18
def main(argv):
  global CFG
  CFG = FLAGS.config

  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Guarantee that the JAX bfloat16 extension is used rather than TF bfloat16.
  _ = np.array(jnp.array([1.0], dtype=jnp.bfloat16))

  # Use hardware RNG for bernoulli randoms in dropout mask creation.
  if CFG.hardware_rng:
    models.set_hardware_bernoulli()

  if 'module_import' in CFG and CFG.module_import:
    for module in CFG.module_import:
      importlib.import_module(module)

  if 'additional_task_cache_dirs' in CFG and CFG.additional_task_cache_dirs:
    t5.data.add_global_cache_dirs(CFG.additional_task_cache_dirs)

  num_partitions = CFG.num_partitions
  topology = train_lib.compute_multihost_topology(num_partitions)
  batch_size = CFG.batch_size
  eval_batch_size = CFG.eval_batch_size
  per_replica_set_eval_batch_size = eval_batch_size // topology.num_replica_sets
  if batch_size % topology.num_replicas:
    raise ValueError('Batch size must be divisible by the number of replicas.')

  steps_per_epoch = CFG.steps_per_epoch
  logging.info('steps per epoch: %d', steps_per_epoch)

  broadcast = functools.partial(
      train_lib.broadcast,
      num_replicas=topology.per_replica_set_num_replicas,
      num_partitions=topology.per_host_num_partitions,
      devices=topology.this_host_device_assignment)

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    tf.io.gfile.copy(FLAGS['config'].config_filename,
                     os.path.join(FLAGS.model_dir, 'config.py'),
                     overwrite=True)
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'eval'))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  # Write summaries in background thread to avoid blocking on device sync
  if CFG.infeed:
    # Infeed is currently synchronous, so do it in a background thread too
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  (train_ds, eval_ds), eval_cache = input_pipeline.get_datasets_and_cache(
      CFG, topology.num_replica_sets, topology.replica_set_id,
      topology.per_replica_set_host_id)

  vocab = input_pipeline.get_vocabulary(CFG.mixture_or_task_name)
  encoder = vocab.tf_tokenizer
  eos_id = vocab.tokenizer.eos_id()

  def decode_tokens(toks,
                    eos_id = eos_id,
                    max_id = 32000):
    """Decode tokens back to unicode."""
    del eos_id
    # TODO(levskaya): T5 doesn't seem to emit EOS tokens?  double check this
    # is the best decoding function or just switch to using tf_decode.
    # valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    valid_toks = toks.astype(np.int32)
    valid_toks[valid_toks >= max_id] = 3
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  train_config, eval_config, predict_config = get_configs(CFG)

  rng = random.PRNGKey(CFG.random_seed)
  rng, init_rng = random.split(rng)
  # This is used for infeed conversion from feature dict <--> tuple
  train_keys = [
      'inputs', 'targets', 'inputs_position', 'targets_position',
      'inputs_segmentation', 'targets_segmentation'
  ]
  device_train_input_shape = tuple([
      (batch_size // topology.num_replicas,
       CFG.max_input_length if 'inputs' in k else CFG.max_target_length)
      for k in train_keys
  ])

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      factors=CFG.schedule,
      base_learning_rate=CFG.learning_rate,
      warmup_steps=CFG.warmup_steps)

  # First, we only abstractly initialize the optimizer and model parameters,
  # since the parameters may not even fit in device memory!
  # TODO(jekbradbury): make optimizer_defs compare by value so it can be created
  # in get_initial_params without causing pytree incompatibility
  optimizer_def = optim.Adafactor(
      CFG.learning_rate, decay_rate=0.8, step_offset=CFG.step_offset)
  initialize_params_fn = functools.partial(
      get_initial_params,
      config=CFG,
      transformer_config=eval_config,
      optimizer_def=optimizer_def)
  optimizer = jax.eval_shape(initialize_params_fn, init_rng)
  # tuple-like pytree leaves for global_arg_shapes
  optimizer_shapes = jax.tree_map(lambda x: partitions.Spec(*x.shape),
                                  optimizer)

  # Build parameter partition annotations for preserving partitions from train
  # to eval.
  if num_partitions > 1:
    optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(num_partitions, optimizer.state_dict()))
    per_host_optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(topology.per_host_num_partitions,
                                  optimizer.state_dict()))

  # Restore unreplicated optimizer + model state from last checkpoint.
  # TODO(jekbradbury,levskaya): implement sharded native checkpoint/restore
  existing_checkpoint_found = False
  if CFG.restore_checkpoints:
    existing_checkpoint_found = train_lib.checkpoint_exists(FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)

  # Import a pretrained-T5 checkpoint only if we didn't import a local
  # "native" checkpoint (e.g. due to resuming a pre-empted finetuning run.)
  # TODO(jekbradbury,levskaya): implement sharded T5 checkpoint/restore
  if CFG.restore_t5_checkpoint and not existing_checkpoint_found:
    optimizer = checkpoint_importer.restore_from_t5_checkpoint(
        optimizer, CFG.restore_t5_checkpoint)

  if CFG.restore_t5_checkpoint or existing_checkpoint_found:
    if num_partitions > 1:
      # Until checkpoint/restore is sharded, the restored checkpoint is global
      # and we need to slice each sharded parameter into the chunk containing
      # only the partitions that are present on this host.
      def per_host_chunk(x, spec):
        if spec is None or spec is x:  # unsharded or not a parameter
          return x
        if spec[0] == 1:
          dim_size = x.shape[1]
        elif spec[1] == 1:
          dim_size = x.shape[0]
        else:
          raise NotImplementedError()
        chunk_size = (
            dim_size * topology.per_host_num_partitions // num_partitions)
        lower = topology.per_replica_set_host_id * chunk_size
        upper = (topology.per_replica_set_host_id + 1) * chunk_size
        if spec[0] == 1:
          return x[:, lower:upper]
        else:
          return x[lower:upper]

      optimizer = jax.tree_multimap(per_host_chunk, optimizer,
                                    optimizer_partitions)
  else:
    # If pretraining and no checkpoint imported, we jit the (sharded-) init
    # function to minimize fragmentation. We use the same pmap(sharded_jit)
    # setup as the training step/loop to initialize everything "in-place" and
    # avoid communication or OOM.
    if num_partitions > 1:
      initialize_params_fn = sharded_jit(
          initialize_params_fn,
          in_parts=None,
          local_in_parts=None,
          out_parts=optimizer_partitions,
          local_out_parts=per_host_optimizer_partitions,
          # devices=one_replica_device_assignment,
      )
      initialize_params_fn = jax.pmap(
          initialize_params_fn,
          'batch',
          in_axes=0,
          axis_size=topology.num_replicas,
          devices=topology.device_assignment)
      init_rng = broadcast(init_rng)
      optimizer = initialize_params_fn(init_rng)
      # We maintain the optimizer in unbroadcasted form (i.e. with no leading
      # replica axis). This is equivalent to the as-yet-nonexistent pmap kwarg
      # out_axes=None.
      optimizer = train_lib.unbroadcast(optimizer)
    else:
      optimizer = jax.jit(initialize_params_fn)(init_rng)

  # ---------------------------------------------------------------------------
  # Compile multidevice versions of train/eval/predict step and cache init fn.
  # ---------------------------------------------------------------------------

  # We can use either a single train-step for a host training loop:

  # train_step(optimizer, batch, prev_metrics, dropout_rng, **kwargs)
  #  --> new_optimizer, metrics, new_dropout_rng
  def p_train_step(optimizer, batch,
                   prev_metrics,
                   dropout_rng):
    return train_lib.train_step(
        optimizer,
        batch,
        prev_metrics,
        dropout_rng,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        num_microbatches=CFG.microbatches,
        label_smoothing=CFG.label_smoothing,
        z_loss=CFG.z_loss,
        use_bfloat16=CFG.use_bfloat16)

  if num_partitions > 1:
    p_train_step = sharded_jit(
        p_train_step,
        in_parts=(optimizer_partitions, None, None, None),
        local_in_parts=(per_host_optimizer_partitions, None, None, None),
        out_parts=(optimizer_partitions, None, None),
        local_out_parts=(per_host_optimizer_partitions, None, None))
  # TODO(levskaya): the in_axes spec below might be wrong, double-check.
  p_train_step = jax.pmap(
      p_train_step,
      axis_name='batch',
      in_axes=(None, 0, 0, 0),
      donate_argnums=(0,),
      global_arg_shapes=(optimizer_shapes, None, None, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # OR, we use an on-device loop that feeds the training step via infeed queue.
  def device_train_loop_cond(
      args
  ):
    """Stopping criterion for on-device loop."""
    _, _, _, _, step, epoch = args
    return step // steps_per_epoch == epoch

  def device_train_loop_body(
      args
  ):
    """On-device loop body."""
    optimizer, dropout_rngs, metrics, token, step, epoch = args
    # Ordering input data from infeed requires threading a symbolic token
    # through the computation.
    input_data, token = lax.infeed(
        token,
        shape=tuple(
            [jax.ShapedArray(s, jnp.int32) for s in device_train_input_shape]))
    # Rebuild input dict from infeed data tuple.
    batch = {k: v for k, v in zip(train_keys, input_data)}
    # Run the train_step function and return the loop state.
    optimizer, metrics, dropout_rngs = train_lib.train_step(
        optimizer,
        batch,
        metrics,
        dropout_rngs,
        train_config,
        learning_rate_fn,
        num_microbatches=CFG.microbatches,
        label_smoothing=CFG.label_smoothing,
        z_loss=CFG.z_loss)
    step += 1
    return optimizer, dropout_rngs, metrics, token, step, epoch

  def device_train_loop(optimizer, dropout_rngs,
                        metrics, step,
                        epoch):
    # Create symbolic token for threading infeed data.
    token = lax.create_token(step)
    # Run on-device loop.
    optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond, device_train_loop_body,
        (optimizer, dropout_rngs, metrics, token, step, epoch))
    return optimizer, dropout_rngs, metrics, step

  if num_partitions > 1:
    device_train_loop = sharded_jit(
        device_train_loop,
        in_parts=(optimizer_partitions, None, None, None, None),
        local_in_parts=(per_host_optimizer_partitions, None, None, None, None),
        out_parts=(optimizer_partitions, None, None, None),
        local_out_parts=(per_host_optimizer_partitions, None, None, None))
  p_train_epoch = jax.pmap(
      device_train_loop,
      axis_name='batch',
      in_axes=(None, 0, 0, None, None),
      donate_argnums=(0,),
      global_arg_shapes=(optimizer_shapes, None, None, None, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # Reduction psum for metric data.

  def p_allreduce_metrics(x):
    return lax.psum(x, axis_name='batch')

  if num_partitions > 1:
    p_allreduce_metrics = sharded_jit(
        p_allreduce_metrics,
        in_parts=None,
        local_in_parts=None,
        out_parts=None,
        local_out_parts=None,
        num_partitions=num_partitions,
        local_num_partitions=topology.per_host_num_partitions)
  p_allreduce_metrics = jax.pmap(
      p_allreduce_metrics,
      axis_name='batch',
      global_arg_shapes=None,
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)

  # Training evaluation computation.

  # eval_step(params, batch, config, label_smoothing=0.0) --> metrics
  def p_eval_step(params, batch):
    return train_lib.eval_step(
        params, batch, config=eval_config, label_smoothing=CFG.label_smoothing)

  if num_partitions > 1:
    p_eval_step = sharded_jit(
        p_eval_step,
        in_parts=(optimizer_partitions.target, None),
        local_in_parts=(per_host_optimizer_partitions.target, None),
        out_parts=None,
        local_out_parts=None)
  p_eval_step = jax.pmap(
      p_eval_step,
      axis_name='batch',
      in_axes=(None, 0),
      global_arg_shapes=(optimizer_shapes.target, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # Fast autoregressive decoding loop.
  # For inference and model evaluation.

  # predict_step(inputs, params,
  #              eos_id, max_decode_len, config, beam_size=4) --> beam_seqs
  def p_pred_step(inputs, params):
    return train_lib.predict_step(inputs, params, eos_id,
                                  CFG.max_eval_target_length, predict_config,
                                  CFG.beam_size)

  if num_partitions > 1:
    p_pred_step = sharded_jit(
        p_pred_step,
        in_parts=(None, optimizer_partitions.target),
        local_in_parts=(None, per_host_optimizer_partitions.target),
        out_parts=None,
        local_out_parts=None)
  p_pred_step = jax.pmap(
      p_pred_step,
      axis_name='batch',
      in_axes=(0, None),
      global_arg_shapes=(None, optimizer_shapes.target),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # ---------------------------------------------------------------------------
  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  # There should be a unique dropout key for each replica represented on this
  # host, but the key should be the same for the same replica on other hosts.
  # Again, this is what the replica set abstraction is for.
  dropout_rngs = random.split(
      random.fold_in(rng, topology.replica_set_id),
      topology.per_replica_set_num_replicas)
  # restore step from last checkpoint
  host_step = int(optimizer.state.step)
  empty_metrics = broadcast({
      'loss': 0.0,
      'accuracy': 0.0,
      'learning_rate': 0.0,
      'denominator': 0.0
  })
  if CFG.infeed:
    # TODO(jekbradbury): support something like this for the Python-loop case
    logging.info('Precompiling training loop and moving optimizer to device.')
    optimizer, _, metrics, _ = p_train_epoch(optimizer, dropout_rngs,
                                             empty_metrics,
                                             jnp.array(0, dtype=jnp.int32), 1)
    optimizer = train_lib.unbroadcast(optimizer)
    metrics['loss'].block_until_ready()

  logging.info('Starting training loop.')

  local_devices = jax.local_devices()
  device_step = broadcast(host_step)
  first_epoch = host_step // steps_per_epoch

  # Main Loop over "epochs".
  train_iter = train_ds.as_numpy_iterator()
  for epoch in range(first_epoch, first_epoch + CFG.num_epochs):
    metrics = empty_metrics

    # NOTE: 'optimizer' is unbroadcast by construction at initialization or
    # when loading a checkpoint. It is maintained in 'unbroadcast' state to
    # enable the XLA cross-replica sharding optimization.  The broadcasting is
    # handled automatically by the pmap'd functions that use it.

    # Gather all task evaluation metrics.
    logging.info('Evaluating tasks.')
    if epoch == first_epoch + 1:
      train_lib.sync_devices()
    for task in eval_cache.tasks:
      logging.info('Evaluating task %s', task.name)
      all_predicted, all_bs = [], []
      for pred_batch in eval_cache.preprocessed_examples[task.name]:
        # Handle final odd-sized batch by padding instead of dropping it.
        input_batch, unpadded_batch_size = train_lib.pad_batch_to_size(
            pred_batch['inputs'], per_replica_set_eval_batch_size)
        all_bs.append(unpadded_batch_size)
        # Split batch dimensions for pmap.
        input_batch = jax.tree_map(
            lambda x: x.reshape(
                (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
            input_batch)
        # Run fast inference on batch.
        all_predicted.append(p_pred_step(input_batch, optimizer.target))

      # Pad out the number of batches so each host has the same number.
      max_host_batch_number = np.max(
          eval_cache.preprocessed_batch_sizes[task.name])
      batch_shortfall = max_host_batch_number - len(all_predicted)
      if batch_shortfall > 0:
        # TODO(levskaya): Fix for case of entirely empty all_predicted.
        # To make sure the cross-host barriers work, we run the program the same
        # number of times on all hosts. The result of this call is ignored, and
        # the predictions are populated with zeros instead.
        p_pred_step(input_batch, optimizer.target)  # Dummy call.
        all_predicted.extend([jnp.zeros_like(all_predicted[0])] *
                             batch_shortfall)
        all_bs.extend([0] * batch_shortfall)
      all_predicted = jnp.concatenate(all_predicted)
      all_bs = jnp.array(all_bs)

      # Collect all batches from across hosts and reverse sharding.
      all_predicted = train_lib.host_allgather(
          all_predicted, topology.num_replica_sets, topology.replica_set_id,
          topology.per_replica_set_host_id == 0)
      seqlength = all_predicted.shape[-1]
      total_examples = np.sum(
          train_lib.host_allgather(all_bs, topology.num_replica_sets,
                                   topology.replica_set_id,
                                   topology.per_replica_set_host_id == 0))
      del all_bs
      assert total_examples == len(eval_cache.examples[task.name]), (
          'Total number of batches incorrect for task %s.' % task.name)
      # De-shard the collected predicted tokens and remove padding.
      all_predicted = np.transpose(all_predicted, (1, 2, 0, 3)).reshape(
          -1, seqlength)[:total_examples]

      # We now run the post-processing and metric-fns on a single host.
      if jax.host_id() == 0:
        assert eval_summary_writer
        raw_predictions = []
        for tokens in all_predicted:
          raw_predictions.append(decode_tokens(tokens))

        # post-process predictions for metric fns
        predictions = [
            task.postprocess_fn(p, example=ex)
            for p, ex in zip(raw_predictions, eval_cache.examples[task.name])
        ]

        for metric_fn in task.metric_fns:
          scores = metric_fn(eval_cache.targets[task.name], predictions)
          for metric_name, metric_value in scores.items():
            tag = f'eval/{task.name}/{metric_name}'
            eval_summary_writer.scalar(tag, metric_value, host_step)
            logging.info('EVAL %s at step %d: %.3f', tag, host_step,
                         metric_value)
          eval_summary_writer.flush()

        # Save text samples for tensorboard.
        exemplars = ''
        for n in np.random.choice(np.arange(len(predictions)), 8):
          tgt_txt = tf.compat.as_text(
              eval_cache.examples[task.name][n]['targets_plaintext'])
          pred_txt = raw_predictions[n]
          exemplars += (f'{eval_cache.inputs[task.name][n]}\n\n'
                        f'target: {tgt_txt}\n\n'
                        f'prediction: {pred_txt}\n\n')
        eval_summary_writer.text(f'{task.name} samples', exemplars, host_step)
        eval_summary_writer.flush()

    # Take an Xprof trace after the first loop has compiled everything.
    if epoch == first_epoch + 1:
      train_lib.sync_devices()

    # For on-device loop, we launch the computation before feeding data.
    logging.info('BEGIN Train loop.')
    if CFG.infeed:
      optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
          optimizer, dropout_rngs, metrics, train_lib.unbroadcast(device_step),
          epoch)
      optimizer = train_lib.unbroadcast(optimizer)

    # Epoch loop.
    while int(host_step // steps_per_epoch) == epoch:
      batch = next(train_iter)
      batch = jax.tree_map(
          lambda x: x.reshape(
              (topology.per_replica_set_num_replicas, -1) + x.shape[1:]), batch)
      # Feed the on-device training loop.
      if CFG.infeed:
        for i, device in enumerate(local_devices):
          # When using infeed to provide data to the computation, we're on our
          # own for feeding the right values to the right devices. Each device
          # should get the minibatch corresponding to its replica, a slice of
          # the larger batch corresponding to the host's replica set.
          if device.platform == 'tpu':
            device_coords = (*device.coords, device.id % 2)
          else:
            device_coords = (device.host_id, i)
          per_replica_set_device_coords = tuple(
              dc % prsm
              for dc, prsm in zip(device_coords, topology.per_replica_set_mesh))
          per_replica_set_replica_coords = tuple(
              prsdc // prm for prsdc, prm in zip(per_replica_set_device_coords,
                                                 topology.per_replica_mesh))
          per_replica_set_replica_id = 0
          for prsm, prm, prsrc in zip(topology.per_replica_set_mesh,
                                      topology.per_replica_mesh,
                                      per_replica_set_replica_coords):
            per_replica_set_replica_id = (
                per_replica_set_replica_id * prsm // prm + prsrc)
          input_tuple = tuple(
              [batch[k][per_replica_set_replica_id] for k in train_keys])
          # Safety check: infeed does not check shape or types but requires
          # them to agree with on-device spec, otherwise the queue and program
          # stalls.
          tuple_shapes = jax.tree_map(jnp.shape, input_tuple)
          tuple_dtypes = jax.tree_map(lambda x: x.dtype, input_tuple)
          assert tuple_shapes == device_train_input_shape, (
              'infeed shape error %s != %s' %
              (tuple_shapes, device_train_input_shape))
          assert tuple(set(tuple_dtypes)) == (jnp.int32,), \
              ('infeed dtype error %s not all of type %s' % (
                  tuple_dtypes, jnp.int32))
          infeed_pool.submit(
              functools.partial(device.transfer_to_infeed, input_tuple))
      # Host training loop.
      else:
        optimizer, metrics, dropout_rngs = p_train_step(optimizer, batch,
                                                        metrics, dropout_rngs)
        optimizer = train_lib.unbroadcast(optimizer)
      host_step += 1
    logging.info('END Train loop.')

    # Maybe save a checkpoint on one host.
    if (CFG.save_checkpoints and
        epoch % CFG.checkpoint_freq == CFG.checkpoint_freq - 1 and
        jax.host_id() == 0):
      checkpoints.save_checkpoint(FLAGS.model_dir, optimizer, host_step)

    # Gather training metrics.
    metrics = p_allreduce_metrics(metrics)
    metrics = jax.tree_map(lambda x: jax.device_get(x[0]), metrics)
    denominator = metrics.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics)  # pylint: disable=cell-var-from-loop
    logging.info('train in step: %s, %s', host_step, summary)
    if jax.host_id() == 0:
      assert train_summary_writer
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, host_step)
      train_summary_writer.flush()

    # Gather training evaluation metrics.
    logging.info('Gathering training evaluation metrics.')
    eval_metrics = []
    eval_iter = eval_ds.as_numpy_iterator()
    for _, eval_batch in zip(range(CFG.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(
          lambda x: x.reshape(
              (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
          eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    # average metrics across devices
    eval_metrics = p_allreduce_metrics(eval_metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    # average metrics across steps
    eval_metrics = jax.tree_map(np.sum, eval_metrics)
    eval_denominator = eval_metrics.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics)
    logging.info('eval in step: %s, %s', host_step, eval_summary)
    if jax.host_id() == 0:
      assert eval_summary_writer
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, host_step)
      eval_summary_writer.flush()

  # Wait until computations are done before exiting
  logging.info('Finished.')
  train_lib.sync_devices()
  # Shut down the infeed threadpool.
  if CFG.infeed:
    infeed_pool.shutdown()
Example #19
 def dist_sq(R):
   dR = R[:, jnp.newaxis, :] - R[jnp.newaxis, :, :]
   zero = jnp.zeros_like(dR)
   dR = dR - jnp.where(jnp.abs(dR) < 0.5, zero, 0.5 * jnp.sign(dR))
   return jnp.sum(dR ** 2, axis=2)
Example #20
 def dist_sq(R):
   dR = R[:, np.newaxis, :] - R[np.newaxis, :, :]
   zero = np.zeros_like(dR)
   dR = dR - np.where(np.abs(dR) < 0.5, zero, 0.5 * np.sign(dR))
   return np.sum(dR ** 2, axis=2)
Example #21
        if not no_nans:
            with open(os.path.join(str(output_dir), f"{epoch}_nan.pkl"),
                      'wb') as f:
                pickle.dump(get_params(opt_state), f)
                exit(
                    f"nan encountered in epoch_info and params pickled: {epoch_info}"
                )

        params = get_params(opt_state)
        train_acc, train_logits, train_labels, train_nll, train_kl, _ = evaluate(
            params, train_eval_loader, input_size, args.nsamples,
            rng_generator, args.kl_coef / train_size)
        train_loss = train_nll + args.kl_coef * train_kl
        if args.disable_test:
            # Placeholders when validation/test evaluation is disabled
            # (five values to match the unpacking).
            val_acc, val_logits, val_labels, val_nll, val_kl = (
                jnp.zeros(1), jnp.zeros_like(train_logits),
                jnp.zeros_like(train_labels), jnp.zeros(1), jnp.zeros(1))
            test_acc, test_logits, test_labels, test_nll, test_kl = (
                jnp.zeros(1), jnp.zeros_like(train_logits),
                jnp.zeros_like(train_labels), jnp.zeros(1), jnp.zeros(1))
        else:
            val_acc, val_logits, val_labels, val_nll, val_kl, _ = evaluate(
                ema_params, val_loader, input_size, args.nsamples,
                rng_generator, args.kl_coef / train_size)
            test_acc, test_logits, test_labels, test_nll, test_kl, test_ws = evaluate(
                ema_params, test_loader, input_size, args.nsamples,
                rng_generator, args.kl_coef / train_size)
        val_loss, test_loss = val_nll + args.kl_coef * val_kl, test_nll + args.kl_coef * test_kl

        cal_train = utils.get_calibration(
            train_labels, jax.device_get(jnp.exp(train_logits)))
        cal_val = utils.get_calibration(val_labels,
                                        jax.device_get(jnp.exp(val_logits)))
Example #22
 def _test_transformation(self, func, param, msg=None):
   primal, tangent = jax.jvp(func, (param,), (np.ones_like(param),))
   self.assertEqual(primal.shape, tangent.shape)
   if not FLAGS.execute_only:
     self.assertNotAllEqual(tangent, np.zeros_like(tangent), msg=msg)
Example #23
 def init_fn(fx_train=0.):
     fx_train = fl(fx_train)
     qx_train = np.zeros_like(fx_train)
     return np.concatenate((fx_train, qx_train), axis=0)
Example #24
 def _test_transformation(self, func, param, msg=None):
   out, f_vjp = jax.vjp(func, param)
   cotangent, = f_vjp(np.ones_like(out).astype(out.dtype))
   self.assertEqual(param.shape, cotangent.shape)
   if not FLAGS.execute_only:
     self.assertNotAllEqual(cotangent, np.zeros_like(cotangent), msg=msg)
Example #25
def convergence(args):
    epoch, lamb, lamb_idx = args

    suppl = Chem.SDMolSupplier("tests/data/ligands_40.sdf", removeHs=False)

    ligands = []
    for mol in suppl:
        ligands.append(mol)

    ligand_a = ligands[0]
    ligand_b = ligands[1]

    # print(ligand_a.GetNumAtoms())
    # print(ligand_b.GetNumAtoms())

    # ligand_a = Chem.AddHs(Chem.MolFromSmiles("CCCC1=NN(C2=C1N=C(NC2=O)C3=C(C=CC(=C3)S(=O)(=O)N4CCN(CC4)C)OCC)C"))
    # ligand_b = Chem.AddHs(Chem.MolFromSmiles("CCCC1=NN(C2=C1N=C(NC2=O)C3=C(C=CC(=C3)S(=O)(=O)N4CCN(CC4)C)OCC)C"))
    # ligand_a = Chem.AddHs(Chem.MolFromSmiles("c1ccccc1CC"))
    # ligand_b = Chem.AddHs(Chem.MolFromSmiles("c1ccccc1CC"))
    # AllChem.EmbedMolecule(ligand_a, randomSeed=2020)
    # AllChem.EmbedMolecule(ligand_b, randomSeed=2020)

    coords_a = get_conf(ligand_a, idx=0)
    coords_b = get_conf(ligand_b, idx=0)
    # coords_b = np.matmul(coords_b, special_ortho_group.rvs(3))

    coords_a = recenter(coords_a)
    coords_b = recenter(coords_b)

    coords = np.concatenate([coords_a, coords_b])

    a_idxs = get_heavy_atom_idxs(ligand_a)
    b_idxs = get_heavy_atom_idxs(ligand_b)

    a_full_idxs = np.arange(0, ligand_a.GetNumAtoms())
    b_full_idxs = np.arange(0, ligand_b.GetNumAtoms())

    b_idxs += ligand_a.GetNumAtoms()
    b_full_idxs += ligand_a.GetNumAtoms()

    nrg_fns = []

    forcefield = 'ff/params/smirnoff_1_1_0_ccc.py'
    ff_raw = open(forcefield, "r").read()
    ff_handlers = deserialize_handlers(ff_raw)

    combined_mol = Chem.CombineMols(ligand_a, ligand_b)

    for handler in ff_handlers:
        if isinstance(handler, handlers.HarmonicBondHandler):
            bond_idxs, (bond_params, _) = handler.parameterize(combined_mol)
            nrg_fns.append(
                functools.partial(bonded.harmonic_bond,
                    params=bond_params,
                    box=None,
                    bond_idxs=bond_idxs
                )
            )
        elif isinstance(handler, handlers.HarmonicAngleHandler):
            angle_idxs, (angle_params, _) = handler.parameterize(combined_mol)
            nrg_fns.append(
                functools.partial(bonded.harmonic_angle,
                    params=angle_params,
                    box=None,
                    angle_idxs=angle_idxs
                )
            )
        # elif isinstance(handler, handlers.ImproperTorsionHandler):
        #     torsion_idxs, (torsion_params, _) = handler.parameterize(combined_mol)
        #     print(torsion_idxs)
        #     assert 0
        #     nrg_fns.append(
        #         functools.partial(bonded.periodic_torsion,
        #             params=torsion_params,
        #             box=None,
        #             lamb=None,
        #             torsion_idxs=torsion_idxs
        #         )
        #     )
        # elif isinstance(handler, handlers.ProperTorsionHandler):
        #     torsion_idxs, (torsion_params, _) = handler.parameterize(combined_mol)
        #     # print(torsion_idxs)
        #     nrg_fns.append(
        #         functools.partial(bonded.periodic_torsion,
        #             params=torsion_params,
        #             box=None,
        #             lamb=None,
        #             torsion_idxs=torsion_idxs
        #         )
        #     )

    masses_a = onp.array([a.GetMass() for a in ligand_a.GetAtoms()]) * 10000
    masses_b = onp.array([a.GetMass() for a in ligand_b.GetAtoms()])

    combined_masses = np.concatenate([masses_a, masses_b])

    # com_restraint_fn = functools.partial(bonded.centroid_restraint,
    #     params=None,
    #     box=None,
    #     lamb=None,
    #     # masses=combined_masses, # try making this ones-like
    #     masses=np.ones_like(combined_masses),
    #     group_a_idxs=a_idxs,
    #     group_b_idxs=b_idxs,
    #     kb=50.0,
    #     b0=0.0)

    pmi_restraint_fn = functools.partial(pmi_restraints_new,
        params=None,
        box=None,
        lamb=None,
        # masses=np.ones_like(combined_masses),
        masses=combined_masses,
        # a_idxs=a_full_idxs,
        # b_idxs=b_full_idxs,
        a_idxs=a_idxs,
        b_idxs=b_idxs,
        angle_force=100.0,
        com_force=100.0
    )

    prefactor = 2.7 # unitless
    shape_lamb = (4*np.pi)/(3*prefactor) # unitless
    kappa = np.pi/(np.power(shape_lamb, 2/3)) # unitless
    sigma = 0.15 # 1 angstrom std, 95% coverage by 2 angstroms
    alpha = kappa/(sigma*sigma)

    alphas = np.zeros(combined_mol.GetNumAtoms())+alpha
    weights = np.zeros(combined_mol.GetNumAtoms())+prefactor

    shape_restraint_fn = functools.partial(
        shape.harmonic_overlap,
        box=None,
        lamb=None,
        params=None,
        a_idxs=a_idxs,
        b_idxs=b_idxs,
        alphas=alphas,
        weights=weights,
        k=150.0
    )

    # shape_restraint_4d_fn = functools.partial(
    #     shape.harmonic_4d_overlap,
    #     box=None,
    #     params=None,
    #     a_idxs=a_idxs,
    #     b_idxs=b_idxs,
    #     alphas=alphas,
    #     weights=weights,
    #     k=200.0
    # )

    def restraint_fn(conf, lamb):

        return pmi_restraint_fn(conf) + lamb*shape_restraint_fn(conf)
        # return (1-lamb)*pmi_restraint_fn(conf) + lamb*shape_restraint_fn(conf)


    nrg_fns.append(restraint_fn)

    def nrg_fn(conf, lamb):
        s = []
        for u in nrg_fns:
            s.append(u(conf, lamb=lamb))
        return np.sum(s)
 
    grad_fn = jax.grad(nrg_fn, argnums=(0,1))
    grad_fn = jax.jit(grad_fn)

    du_dx_fn = jax.grad(nrg_fn, argnums=(0))
    du_dx_fn = jax.jit(du_dx_fn)

    x_t = coords
    v_t = np.zeros_like(x_t)

    w = Chem.SDWriter('frames_heavy_'+str(epoch)+'_'+str(lamb_idx)+'.sdf')

    dt = 1.5e-3
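    # coefficients for the Langevin update below (v <- ca*v + cb*du_dx + cc*noise);
    # cb is negated and cb/cc are reshaped so they broadcast over the (N, 3) arrays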
    ca, cb, cc = langevin_coefficients(300.0, dt, 1.0, combined_masses)
    cb = -1*onp.expand_dims(cb, axis=-1)
    cc = onp.expand_dims(cc, axis=-1)

    du_dls = []

    # re-seed since forking 
    onp.random.seed(int.from_bytes(os.urandom(4), byteorder='little'))


    for step in range(100000):

        # if step % 1000 == 0:
        #     u = nrg_fn(x_t, lamb)
        #     print("step", step, "nrg", onp.asarray(u), "avg_du_dl",  onp.mean(du_dls))
        #     mol = make_conformer(combined_mol, x_t[:ligand_a.GetNumAtoms()], x_t[ligand_a.GetNumAtoms():])
        #     w.write(mol)
        #     w.flush()

        if step % 5 == 0 and step > 10000:
            du_dx, du_dl = grad_fn(x_t, lamb)
            du_dls.append(du_dl)
        else:
            du_dx = du_dx_fn(x_t, lamb)

        v_t = ca*v_t + cb*du_dx + cc*onp.random.normal(size=x_t.shape)
        x_t = x_t + v_t*dt

    return onp.mean(du_dls) # average du/dl over the sampled frames
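
The loop above is a thermodynamic-integration style estimator: jax.grad is taken over both the coordinates and the coupling parameter lambda, and du/dl is averaged over frames sampled after equilibration. A minimal sketch of that pattern on a toy potential (toy_energy and every constant below are illustrative, not taken from the code above):

import jax
import jax.numpy as jnp

def toy_energy(conf, lamb):
    # hypothetical potential: a lambda-scaled harmonic well
    return lamb * jnp.sum(conf ** 2)

# gradients w.r.t. coordinates (for the dynamics) and lambda (for the estimator)
grad_fn = jax.jit(jax.grad(toy_energy, argnums=(0, 1)))

conf = jnp.ones((5, 3))
du_dls = []
for _ in range(100):
    du_dx, du_dl = grad_fn(conf, 0.5)
    conf = conf - 1e-2 * du_dx  # crude gradient-descent "dynamics"
    du_dls.append(du_dl)

print(jnp.mean(jnp.stack(du_dls)))  # <du/dlambda> at lambda = 0.5
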
Example #26
0
 def init_fn(params):
     mu = jax.tree_map(  # First moment
         lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
     nu = jax.tree_map(jnp.zeros_like, params)  # Second moment
     return ScaleByAMSGradState(mu=mu, nu=nu)
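
For context, this init_fn follows the common optimizer-state pattern of building moment buffers with jnp.zeros_like over the parameter pytree. A self-contained sketch with a hypothetical container (MomentState stands in for ScaleByAMSGradState, whose definition and the mu_dtype argument live in the surrounding library):

from typing import Any, NamedTuple

import jax
import jax.numpy as jnp

class MomentState(NamedTuple):  # hypothetical stand-in for ScaleByAMSGradState
    mu: Any  # first moment
    nu: Any  # second moment

def init_fn(params):
    mu = jax.tree_util.tree_map(jnp.zeros_like, params)  # first moment
    nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # second moment
    return MomentState(mu=mu, nu=nu)

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(3)}
state = init_fn(params)  # zero buffers with the same pytree structure as params
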
Example #27
0
def fed_opt(
    client_objectives: List[StochasticObjective],
    client_update_fn: ClientUpdateFn,
    server_update_fn: ServerUpdateFn,
    sample_clients_fn: SampleClientsFn,
    prng_key: jnp.ndarray,
    init_state: jnp.ndarray,
    num_rounds: int,
    num_clients_per_round: int,
) -> Tuple[List[ServerState], List[RoundInfo]]:
    """Runs generalized federated averaging for the specified number of rounds.

    At each round, the algorithm does the following:
        1.  Samples a batch of clients using `sample_clients_fn`.
        2.  Runs `client_update_fn` on each sampled client objective that
            returns a `client_delta`.
        3.  Aggregates `client_deltas` using `server_update_fn`.

    Args:
        client_objectives: A list of client objective functions.
        client_update_fn: A function for computing local client updates.
        server_update_fn: A function for computing server updates.
        sample_clients_fn: A function for sampling indices of the clients.
        prng_key: A key for random number generation.
        init_state: The initial server parameter vector (used as `ServerState.x`).
        num_rounds: The number of training rounds to run.
        num_clients_per_round: The number of clients used at each round.

    Returns:
        A tuple `(trajectory, info)`, where `trajectory` is the list of server
        states over the course of training and `info` holds the per-round timing
        dictionaries (with a leading `None` entry for the initial state).
    """
    num_clients = len(client_objectives)
    server_state = ServerState(r=0, x=init_state, v=jnp.zeros_like(init_state))

    trajectory = [server_state]
    info = [None]
    for _ in range(num_rounds):
        round_info = {}

        # Select clients.
        prng_key, subkey = random.split(prng_key)
        with Timer("select_clients_time") as t:
            client_ids = sample_clients_fn(subkey, num_clients,
                                           num_clients_per_round)
            client_objectives_round = [
                client_objectives[i] for i in client_ids
            ]
            client_weights_round = jnp.asarray(
                [float(o.num_points) for o in client_objectives_round])
        round_info[t.description] = t.elapsed

        # Compute client updates.
        client_deltas_round = []
        # TODO: parallelize this loop.
        with Timer("client_updates_time") as t:
            for client_objective in client_objectives_round:
                prng_key, subkey = random.split(prng_key)
                client_delta = client_update_fn(client_objective,
                                                server_state.x, subkey)
                client_deltas_round.append(client_delta)
        round_info[t.description] = t.elapsed

        # Update server state.
        with Timer("server_update_time") as t:
            server_state = server_update_fn(client_deltas_round,
                                            client_weights_round, server_state)
        round_info[t.description] = t.elapsed
        trajectory.append(server_state)
        info.append(round_info)

    return trajectory, info
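
fed_opt defers the aggregation rule to server_update_fn. Below is a minimal FedAvg-style rule compatible with the ServerState(r, x, v) fields used above; the momentum and learning-rate handling is an assumption for illustration, not the actual update used by the surrounding code:

from typing import NamedTuple

import jax.numpy as jnp

class ServerState(NamedTuple):  # mirrors the fields used in fed_opt above
    r: int
    x: jnp.ndarray
    v: jnp.ndarray

def fed_avg_server_update(client_deltas, client_weights, state,
                          momentum=0.9, lr=1.0):
    # weight each client delta by its (normalized) number of data points
    deltas = jnp.stack(client_deltas)
    weights = client_weights / jnp.sum(client_weights)
    avg_delta = jnp.tensordot(weights, deltas, axes=1)
    # server momentum buffer, initialized with jnp.zeros_like(init_state) in fed_opt
    v = momentum * state.v + avg_delta
    return ServerState(r=state.r + 1, x=state.x + lr * v, v=v)
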
Example #28
0
        def inner_apply(edge_embeddings, node_embeddings):
            first_layer_dim = mlp_vtoe_dims[0]
            additional_layer_dims = mlp_vtoe_dims[1:]

            if allow_non_adjacent and edge_embeddings is not None:
                num_separate_mlps = 1 + edge_embeddings.shape[-1]
            elif allow_non_adjacent:
                num_separate_mlps = 1
            elif edge_embeddings is not None:
                num_separate_mlps = edge_embeddings.shape[-1]
            else:
                raise ValueError(
                    "Either allow_non_adjacent should be True, or "
                    "edge_embeddings should be provided")

            node_embedding_dim = node_embeddings.shape[-1]

            # First layer: process each node embedding.
            weight_from_source = self.param(
                "l0_weight_from_source",
                shape=(num_separate_mlps, node_embedding_dim, first_layer_dim),
                initializer=initializers.xavier_normal())
            weight_from_dest = self.param(
                "l0_weight_from_dest",
                shape=(num_separate_mlps, node_embedding_dim, first_layer_dim),
                initializer=initializers.xavier_normal())
            bias = self.param("l0_bias",
                              shape=(num_separate_mlps, first_layer_dim),
                              initializer=initializers.zeros)
            from_source = jnp.einsum("sx,kxy->sky", node_embeddings,
                                     weight_from_source)
            from_dest = jnp.einsum("dx,kxy->dky", node_embeddings,
                                   weight_from_dest)
            activations = jax.nn.relu(from_source[:, None, :, :] +
                                      from_dest[None, :, :, :] +
                                      bias[None, None, :, :])

            # Additional layers: MLP for each edge type.
            for i, layer_dim in enumerate(additional_layer_dims):
                weight = self.param(f"l{i+1}_weight",
                                    shape=(num_separate_mlps,
                                           activations.shape[-1], layer_dim),
                                    initializer=initializers.xavier_normal())
                bias = self.param(f"l{i+1}_bias",
                                  shape=(num_separate_mlps, layer_dim),
                                  initializer=initializers.zeros)
                activations = jax.nn.relu(
                    jnp.einsum("sdkx,kxy->sdky", activations, weight) +
                    bias[None, None, :, :])

            # Sum over edge types and possibly over source nodes.
            if edge_embeddings is None:
                result = activations.squeeze(axis=2)
                if mask is not None:
                    result = jnp.where(mask[:, :, None], result,
                                       jnp.zeros_like(result))
                if message_passing:
                    result = jnp.sum(result, axis=0)
            else:
                if allow_non_adjacent:
                    if mask is None:
                        pairwise = jnp.ones(edge_embeddings.shape[:2] + (1, ))
                    else:
                        pairwise = mask
                    mlp_weights = jnp.concatenate([
                        edge_embeddings,
                        pairwise.astype("float")[:, :, None]
                    ], -1)
                else:
                    mlp_weights = edge_embeddings

                if message_passing:
                    result = jnp.einsum("sdky,sdk->dy", activations,
                                        mlp_weights)
                else:
                    result = jnp.einsum("sdky,sdk->sdy", activations,
                                        mlp_weights)

            return result
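
The first layer above applies one weight matrix per edge type (the k axis) to every node embedding and then broadcasts the source and destination contributions against each other. A small shape check of that einsum pattern (all dimensions are arbitrary, for illustration only):

import jax
import jax.numpy as jnp

num_nodes, num_mlps, node_dim, hidden = 4, 3, 8, 16
key = jax.random.PRNGKey(0)
node_embeddings = jax.random.normal(key, (num_nodes, node_dim))
weight = jax.random.normal(key, (num_mlps, node_dim, hidden))

# one projection per edge type / MLP for every node: axes (source, mlp, feature)
from_source = jnp.einsum("sx,kxy->sky", node_embeddings, weight)
print(from_source.shape)  # (4, 3, 16)

# broadcasting source against destination yields one activation per
# (source, dest, mlp) triple, the shape the later layers consume
activations = jax.nn.relu(from_source[:, None, :, :] + from_source[None, :, :, :])
print(activations.shape)  # (4, 4, 3, 16)
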
Example #29
0
 def init_param_state(self, param):
     return _MomentumParamState(jnp.zeros_like(param))
Example #30
0
def zeros_like(x, dtype_str=None, dev=None):
    if dtype_str:
        dtype = _jnp.__dict__[dtype_str]
    else:
        dtype = x.dtype
    return _to_dev(_jnp.zeros_like(x, dtype=dtype), dev)
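
In plain JAX, the wrapper above reduces to a dtype override plus an explicit device placement. A minimal equivalent (the target device is just an example):

import jax
import jax.numpy as jnp

x = jnp.arange(6, dtype=jnp.int32).reshape(2, 3)

z = jnp.zeros_like(x, dtype=jnp.float32)   # dtype override
z = jax.device_put(z, jax.devices()[0])    # explicit device placement
print(z.dtype, z.shape)  # float32 (2, 3)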