def _sample_n_and_log_prob(self, key: PRNGKey, n: int) -> Tuple[Array, Array]:
    """Draw `n` samples and pair them with an all-zeros log-probability.

    See `Distribution._sample_n_and_log_prob`.

    Args:
      key: PRNG key forwarded to `self._sample_n`.
      n: number of samples to draw.

    Returns:
      A `(samples, log_prob)` tuple; `log_prob` is zeros with the same shape
      and dtype as `samples`.
    """
    draws = self._sample_n(key, n)
    return draws, jnp.zeros_like(draws)
def get_grads(bijector, x, *cond):
    """Return the VJPs of `bijector` seeded separately on each of its outputs.

    `bijector(x, *cond)` must return a pair `(y, ldj)`. The pullback is
    jit-compiled and applied twice:
      1. with cotangent ones on `y` and zeros on `ldj`;
      2. with cotangent zeros on `y` and ones on `ldj`.

    Returns:
      A 2-tuple. Each entry is the tuple of cotangents w.r.t. `(x, *cond)`
      from the corresponding pullback call.
    """
    (y, ldj), pullback = jax.vjp(bijector, x, *cond)
    pullback = jax.jit(pullback)
    seed_on_y = (jnp.ones_like(y), jnp.zeros_like(ldj))
    seed_on_ldj = (jnp.zeros_like(y), jnp.ones_like(ldj))
    return tuple(pullback(seed) for seed in (seed_on_y, seed_on_ldj))
def gmres(A, b, x0=None, n=5, M=identity, record=False):
    """Call `_gmres` with a default all-zeros initial guess.

    When `x0` is None, the initial guess becomes `np.zeros_like(b)`.
    `A`, `b`, `n`, `M` and `record` are forwarded unchanged to `_gmres`,
    whose return value is returned as-is.
    """
    initial_guess = np.zeros_like(b) if x0 is None else x0
    return _gmres(A, b, initial_guess, n, M, record)
def init_param_state(self, param):
    """Create the initial `_AdamParamState` for `param`.

    Both state fields start as all-zeros arrays with `param`'s shape and dtype.
    """
    first_slot = jnp.zeros_like(param)
    second_slot = jnp.zeros_like(param)
    return _AdamParamState(first_slot, second_slot)
def init(x):
    """Build the initial state tuple `(x, s, nu, x0)` for parameters `x`.

    `s` and `nu` are independent all-zeros arrays shaped like `x`.
    `x0` is `x` itself (the same object, not a copy).
    """
    zeros = jnp.zeros_like
    return x, zeros(x), zeros(x), x
def _nan_to_inf(x): return jnp.where(jnp.isnan(x), jnp.inf + jnp.zeros_like(x), x)
def predict_fn(
    get: Get = None,
    k_test_train=None,
    nngp_test_test: np.ndarray = None
) -> Dict[str, Union[np.ndarray, Gaussian]]:
  """`test`-set posterior given respective covariance matrices.

  Args:
    get:
      string, the mode of the Gaussian process, either "nngp" or "ntk", or a
      tuple, or `None`. If `None` then both `nngp` and `ntk` predictions are
      returned.

    k_test_train:
      test-train kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple, (c)
      `Kernel` object. Must contain the necessary `nngp` and/or `ntk` kernels
      for arguments provided to the returned `predict_fn` function. For
      example, if you request to compute posterior test [only] NTK covariance,
      `k_test_train` must contain both `ntk` and `nngp` kernels. If `None`,
      returns predictions on the training set. Note that train-set outputs are
      always `N(y_train, 0)` and mostly returned for API consistency.

    nngp_test_test:
      A test-test NNGP array. Provide if you want to compute test-test
      posterior covariance. `nngp_test_test=None` means to not compute it. If
      `k_test_train is None`, pass any non-`None` value (e.g. `True`) if you
      want to get non-regularized (`diag_reg=0`) train-train posterior
      covariance. Note that non-regularized train-set outputs will always be
      the zero-variance Gaussian `N(y_train, 0)` and mostly returned for API
      consistency. For regularized train-set posterior outputs according to a
      positive `diag_reg`, pass `k_test_train=k_train_train`, and, optionally,
      `nngp_test_test=nngp_train_train`.

  Returns:
    Either a `Gaussian('mean', 'variance')` namedtuple or `mean` of the GP
    posterior on the `test` set.

    NOTE(review): this body itself builds and returns a dict mapping each
    entry of `get` to such a value. The single-value form described above
    presumably comes from a wrapper/decorator not visible here; confirm.
  """
  # Default: compute both posteriors.
  # NOTE(review): a bare string `get` would be iterated character by
  # character below; presumably a wrapper normalizes it to a tuple first.
  if get is None:
    get = ('nngp', 'ntk')

  out = {}

  for g in get:
    # Train-train kernel, and the test-train kernel if test points were given.
    k_dd = _get_attr(k_train_train, g)
    k_td = None if k_test_train is None else _get_attr(k_test_train, g)

    if k_td is None:
      # Train set predictions.
      y = y_train.astype(k_dd.dtype)
    else:
      # Test set predictions.
      # Presumably k_inv_y(g) is K_dd^{-1} y_train for kernel `g`; contract it
      # with the test-train kernel over the training axes.
      y = np.tensordot(k_td, k_inv_y(g), (odd, first))
      # Move the trailing len(trace_axes) axes back to positions `trace_axes`.
      y = np.moveaxis(y, range(-len(trace_axes), 0), trace_axes)

    if nngp_test_test is not None:
      if k_td is None:
        # Train-set posterior has zero variance, so the covariance is zeros.
        out[g] = Gaussian(y, np.zeros_like(k_dd, k_dd.dtype))
      else:
        if (g == 'ntk' and
            (not hasattr(k_train_train, 'nngp') or
             not hasattr(k_test_train, 'nngp'))):
          raise ValueError(
              'If `"ntk" in get`, and `nngp_test_test is not None`, '
              'and `k_test_train is not None`, i.e. you request the '
              'NTK posterior covariance on the test set, you need '
              'both NTK and NNGP train-train and test-train matrices '
              'contained in `k_test_train` and `k_train_train`. '
              'Hence they must be `namedtuple`s with `nngp` and '
              '`ntk` attributes.')

        # Assuming solve(g)(m, even) applies the inverse of the `g`
        # train-train kernel to `m`, this is K_g_dd^{-1} K_nngp_dt.
        k_td_nngp_inv_y = solve(g)(_get_attr(k_test_train, 'nngp'), even)

        if g == 'nngp':
          # cov = K_tt - K_td K_dd^{-1} K_dt (standard GP posterior).
          cov = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
          cov = nngp_test_test - utils.zip_axes(cov)
          out[g] = Gaussian(y, cov)

        elif g == 'ntk':
          # term_1 ~ Theta_dd^{-1} Theta_dt (under the `solve` assumption).
          term_1 = solve(g)(k_td, even)
          # cov ~ term_1^T K_dd term_1.
          cov = np.tensordot(_get_attr(k_train_train, 'nngp'), term_1,
                             (odd, first))
          cov = np.tensordot(term_1, cov, (first, first))

          # term_2 ~ Theta_td Theta_dd^{-1} K_dt, symmetrized by adding
          # its transpose.
          term_2 = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
          term_2 += np.moveaxis(term_2, first, last)
          cov = utils.zip_axes(cov - term_2) + nngp_test_test
          out[g] = Gaussian(y, cov)

        else:
          raise ValueError(g)

    else:
      # No covariance requested: return only the posterior mean.
      out[g] = y

  return out
def init_param_state(self, param: jnp.ndarray) -> _RMSPropParamState:
    """Initializes parameter state. See base class.

    The first state field starts as all-ones and the second as all-zeros,
    both with `param`'s shape and dtype.
    """
    ones_slot = jnp.ones_like(param)
    zeros_slot = jnp.zeros_like(param)
    return _RMSPropParamState(ones_slot, zeros_slot)
def lennard_jones(conf, lj_params, box, cutoff):
    """
    Lennard-Jones 12-6 pair energy using the Lorentz-Berthelot combining
    rules, where sig_ij = (sig_i + sig_j)/2 and eps_ij = sqrt(eps_i * eps_j).

    Self-interactions (i == j) and pairs whose combined eps_ij is zero are
    excluded. The symmetric N x N pair-energy matrix is summed and halved, so
    each unordered pair is counted once.

    Parameters
    ----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    lj_params: shape [num_atoms, 2] np.array
        per-atom parameters; column 0 is sig, column 1 is eps

    box: shape [3, 3] np.array or None
        passed unchanged to `distance`; presumably periodic boundary vectors
        when not None (TODO confirm against `distance`)

    cutoff: float or None
        if not None, pairs with distance >= cutoff contribute zero energy

    Returns
    -------
    scalar
        total LJ energy
    """
    sig = lj_params[:, 0]
    eps = lj_params[:, 1]

    # Lorentz-Berthelot combining rules on the full N x N pair grid.
    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = (sig_i + sig_j) / 2

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)
    eps_ij = np.sqrt(eps_i * eps_j)

    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj, box)

    N = conf.shape[0]
    # Exclude self-interactions and pairs with no LJ well depth.
    keep_mask = np.ones((N, N)) - np.eye(N)
    keep_mask = np.where(eps_ij != 0, keep_mask, 0)

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    # NOTE(review): on the diagonal sig_ij is zeroed, but dij is presumably
    # 0 there too, so sig_ij / dij below is 0/0 in that entry. The final
    # where() hides it in the forward pass; confirm gradients are NaN-free.
    sig_ij = np.where(keep_mask, sig_ij, np.zeros_like(sig_ij))
    eps_ij = np.where(keep_mask, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij / dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    eij = 4 * eps_ij * (sig6 - 1.0) * sig6
    eij = np.where(keep_mask, eij, np.zeros_like(eij))

    # The pair matrix is symmetric; halve the sum to count each pair once.
    return np.sum(eij / 2)
def init(x0):
    """Return the initial state `(x0, m0, v0)`.

    `m0` and `v0` are two independent all-zeros arrays with `x0`'s shape and
    dtype.
    """
    buffers = [np.zeros_like(x0) for _ in range(2)]
    return (x0, *buffers)
def init(x0):
    """Return `(x0, zeros_like(x0), per_axis)`.

    `per_axis` is a list with one 1-D zero array per dimension of `x0`. Entry
    `i` has length `x0.shape[i]` and dtype `x0.dtype`.
    """
    per_axis = []
    for size in x0.shape:
        per_axis.append(np.zeros(size, dtype=x0.dtype))
    return x0, np.zeros_like(x0), per_axis
def init(x0):
    """Return `(x0, v0)`, where `v0` is zeros with `x0`'s shape and dtype."""
    return x0, np.zeros_like(x0)
def isoneutral_diffusion_pre(maskT, maskU, maskV, maskW, dxt, dxu, dyt, dyu, dzt, dzw, cost, cosu, salt, temp, zt, K_iso, K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by):
    """
    Isopycnal diffusion for tracer
    following functional formulation by Griffies et al
    Code adapted from MOM2.1

    Computes isoneutral slope tensors and diffusivities, returning updated
    (K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by). `jax.ops.index_update`
    is functional, so the input arrays are not mutated; new arrays are
    returned.

    Only time level `tau = 0` (last axis) of `salt` and `temp` is read, so
    those are 4-D; K_* are 3-D, and Ai_* carry two trailing size-2 axes
    indexed by (ip/jp, kr).

    NOTE(review): `jax.ops.index_update` / `jax.ops.index` were removed in
    later JAX releases; porting requires `x.at[idx].set(v)`.
    """
    epsln = 1e-20
    # Slope-tapering parameters (see dm_taper) and the minimum diffusivity
    # used where the taper is applied in K_11 / K_22.
    iso_slopec = 1e-3
    iso_dslope = 1e-3
    K_iso_steep = 50.
    tau = 0

    dTdx = np.zeros_like(K_11)
    dSdx = np.zeros_like(K_11)
    dTdy = np.zeros_like(K_11)
    dSdy = np.zeros_like(K_11)
    dTdz = np.zeros_like(K_11)
    dSdz = np.zeros_like(K_11)

    """ drho_dt and drho_ds at centers of T cells """
    drdT = maskT * get_drhodT(
        salt[:, :, :, tau], temp[:, :, :, tau], np.abs(zt)
    )
    drdS = maskT * get_drhodS(
        salt[:, :, :, tau], temp[:, :, :, tau], np.abs(zt)
    )

    """ gradients at top face of T cells """
    # Forward differences along axis 2; the last level stays zero.
    dTdz = jax.ops.index_update(
        dTdz, jax.ops.index[:, :, :-1],
        maskW[:, :, :-1] * \
        (temp[:, :, 1:, tau] - temp[:, :, :-1, tau]) / \
        dzw[np.newaxis, np.newaxis, :-1]
    )
    dSdz = jax.ops.index_update(
        dSdz, jax.ops.index[:, :, :-1],
        maskW[:, :, :-1] * \
        (salt[:, :, 1:, tau] - salt[:, :, :-1, tau]) / \
        dzw[np.newaxis, np.newaxis, :-1]
    )

    """ gradients at eastern face of T cells """
    # Forward differences along axis 0, scaled by dxu * cost.
    dTdx = jax.ops.index_update(
        dTdx, jax.ops.index[:-1, :, :],
        maskU[:-1, :, :] * (temp[1:, :, :, tau] - temp[:-1, :, :, tau]) \
        / (dxu[:-1, np.newaxis, np.newaxis] * cost[np.newaxis, :, np.newaxis])
    )
    dSdx = jax.ops.index_update(
        dSdx, jax.ops.index[:-1, :, :],
        maskU[:-1, :, :] * (salt[1:, :, :, tau] - salt[:-1, :, :, tau])
        / (dxu[:-1, np.newaxis, np.newaxis] * cost[np.newaxis, :, np.newaxis]))

    """ gradients at northern face of T cells """
    # Forward differences along axis 1, scaled by dyu.
    dTdy = jax.ops.index_update(
        dTdy, jax.ops.index[:, :-1, :],
        maskV[:, :-1, :] * \
        (temp[:, 1:, :, tau] - temp[:, :-1, :, tau]) \
        / dyu[np.newaxis, :-1, np.newaxis]
    )
    dSdy = jax.ops.index_update(dSdy, jax.ops.index[:, :-1, :],
        maskV[:, :-1, :] * \
        (salt[:, 1:, :, tau] - salt[:, :-1, :, tau]) \
        / dyu[np.newaxis, :-1, np.newaxis]
    )

    def dm_taper(sx):
        """
        tapering function for isopycnal slopes

        Returns ~1 where |sx| is well below iso_slopec and ~0 where it is
        well above, with transition width iso_dslope.
        """
        return 0.5 * (1. + np.tanh((-np.abs(sx) + iso_slopec) / iso_dslope))

    """ Compute Ai_ez and K11 on center of east face of T cell. """
    # diffloc: K_iso averaged onto the east face (4-point average; 2-point
    # on the top level).
    diffloc = np.zeros_like(K_11)
    diffloc = jax.ops.index_update(
        diffloc, jax.ops.index[1:-2, 2:-2, 1:],
        0.25 * (K_iso[1:-2, 2:-2, 1:] + K_iso[1:-2, 2:-2, :-1]
                + K_iso[2:-1, 2:-2, 1:] + K_iso[2:-1, 2:-2, :-1]))
    diffloc = jax.ops.index_update(
        diffloc, jax.ops.index[1:-2, 2:-2, 0],
        0.5 * (K_iso[1:-2, 2:-2, 0] + K_iso[2:-1, 2:-2, 0]))

    # Accumulate over the four (kr, ip) sub-positions; divided by 4 below.
    sumz = np.zeros_like(K_11)[1:-2, 2:-2]
    for kr in range(2):
        ki = 0 if kr == 1 else 1
        for ip in range(2):
            drodxe = drdT[1 + ip:-2 + ip, 2:-2, ki:] * dTdx[1:-2, 2:-2, ki:] \
                + drdS[1 + ip:-2 + ip, 2:-2, ki:] * dSdx[1:-2, 2:-2, ki:]
            drodze = drdT[1 + ip:-2 + ip, 2:-2, ki:] * dTdz[1 + ip:-2 + ip, 2:-2, :-1 + kr or None] \
                + drdS[1 + ip:-2 + ip, 2:-2, ki:] * \
                dSdz[1 + ip:-2 + ip, 2:-2, :-1 + kr or None]
            # min(0, .) - epsln keeps the denominator strictly negative.
            sxe = -drodxe / (np.minimum(0., drodze) - epsln)
            taper = dm_taper(sxe)
            sumz = jax.ops.index_update(
                sumz, jax.ops.index[:, :, ki:],
                sumz[..., ki:] + dzw[np.newaxis, np.newaxis, :-1 + kr or None] * maskU[1:-2, 2:-2, ki:]
                * np.maximum(K_iso_steep, diffloc[1:-2, 2:-2, ki:] * taper))
            Ai_ez = jax.ops.index_update(
                Ai_ez, jax.ops.index[1:-2, 2:-2, ki:, ip, kr], taper * sxe * maskU[1:-2, 2:-2, ki:])

    K_11 = jax.ops.index_update(K_11, jax.ops.index[1:-2, 2:-2, :], sumz / (4. * dzt[np.newaxis, np.newaxis, :]))

    """ Compute Ai_nz and K_22 on center of north face of T cell. """
    # Reset diffloc, then average K_iso onto the north face.
    diffloc = jax.ops.index_update(diffloc, jax.ops.index[...], 0)
    diffloc = jax.ops.index_update(
        diffloc, jax.ops.index[2:-2, 1:-2, 1:],
        0.25 * (K_iso[2:-2, 1:-2, 1:] + K_iso[2:-2, 1:-2, :-1]
                + K_iso[2:-2, 2:-1, 1:] + K_iso[2:-2, 2:-1, :-1]))
    diffloc = jax.ops.index_update(
        diffloc, jax.ops.index[2:-2, 1:-2, 0],
        0.5 * (K_iso[2:-2, 1:-2, 0] + K_iso[2:-2, 2:-1, 0]))

    # Same four-sub-position accumulation as K_11, with jp along axis 1.
    sumz = np.zeros_like(K_11)[2:-2, 1:-2]
    for kr in range(2):
        ki = 0 if kr == 1 else 1
        for jp in range(2):
            drodyn = drdT[2:-2, 1 + jp:-2 + jp, ki:] * dTdy[2:-2, 1:-2, ki:] + \
                drdS[2:-2, 1 + jp:-2 + jp, ki:] * dSdy[2:-2, 1:-2, ki:]
            drodzn = drdT[2:-2, 1 + jp:-2 + jp, ki:] * dTdz[2:-2, 1 + jp:-2 + jp, :-1 + kr or None] \
                + drdS[2:-2, 1 + jp:-2 + jp, ki:] * \
                dSdz[2:-2, 1 + jp:-2 + jp, :-1 + kr or None]
            syn = -drodyn / (np.minimum(0., drodzn) - epsln)
            taper = dm_taper(syn)
            sumz = jax.ops.index_update(
                sumz, jax.ops.index[:, :, ki:],
                sumz[..., ki:] + dzw[np.newaxis, np.newaxis, :-1 + kr or None]
                * maskV[2:-2, 1:-2, ki:] * np.maximum(K_iso_steep, diffloc[2:-2, 1:-2, ki:] * taper))
            Ai_nz = jax.ops.index_update(
                Ai_nz, jax.ops.index[2:-2, 1:-2, ki:, jp, kr], taper * syn * maskV[2:-2, 1:-2, ki:])

    K_22 = jax.ops.index_update(K_22, jax.ops.index[2:-2, 1:-2, :], sumz / (4. * dzt[np.newaxis, np.newaxis, :]))

    """ compute Ai_bx, Ai_by and K33 on top face of T cell. """
    sumx = np.zeros_like(K_11)[2:-2, 2:-2, :-1]
    sumy = np.zeros_like(K_11)[2:-2, 2:-2, :-1]

    for kr in range(2):
        drodzb = drdT[2:-2, 2:-2, kr:-1 + kr or None] * dTdz[2:-2, 2:-2, :-1] \
            + drdS[2:-2, 2:-2, kr:-1 + kr or None] * dSdz[2:-2, 2:-2, :-1]

        # eastward slopes at the top of T cells
        for ip in range(2):
            drodxb = drdT[2:-2, 2:-2, kr:-1 + kr or None] * dTdx[1 + ip:-3 + ip, 2:-2, kr:-1 + kr or None] \
                + drdS[2:-2, 2:-2, kr:-1 + kr or None] * dSdx[1 + ip:-3 + ip, 2:-2, kr:-1 + kr or None]
            sxb = -drodxb / (np.minimum(0., drodzb) - epsln)
            taper = dm_taper(sxb)
            sumx += dxu[1 + ip:-3 + ip, np.newaxis, np.newaxis] * \
                K_iso[2:-2, 2:-2, :-1] * taper * \
                sxb**2 * maskW[2:-2, 2:-2, :-1]
            Ai_bx = jax.ops.index_update(
                Ai_bx, jax.ops.index[2:-2, 2:-2, :-1, ip, kr], taper * sxb * maskW[2:-2, 2:-2, :-1])

        # northward slopes at the top of T cells
        for jp in range(2):
            facty = cosu[1 + jp:-3 + jp] * dyu[1 + jp:-3 + jp]
            drodyb = drdT[2:-2, 2:-2, kr:-1 + kr or None] * dTdy[2:-2, 1 + jp:-3 + jp, kr:-1 + kr or None] \
                + drdS[2:-2, 2:-2, kr:-1 + kr or None] * dSdy[2:-2, 1 + jp:-3 + jp, kr:-1 + kr or None]
            syb = -drodyb / (np.minimum(0., drodzb) - epsln)
            taper = dm_taper(syb)
            sumy += facty[np.newaxis, :, np.newaxis] * K_iso[2:-2, 2:-2, :-1] \
                * taper * syb**2 * maskW[2:-2, 2:-2, :-1]
            Ai_by = jax.ops.index_update(
                Ai_by, jax.ops.index[2:-2, 2:-2, :-1, jp, kr], taper * syb * maskW[2:-2, 2:-2, :-1])

    K_33 = jax.ops.index_update(
        K_33, jax.ops.index[2:-2, 2:-2, :-1],
        sumx / (4 * dxt[2:-2, np.newaxis, np.newaxis]) + \
        sumy / (4 * dyt[np.newaxis, 2:-2, np.newaxis] * cost[np.newaxis, 2:-2, np.newaxis])
    )
    # The last vertical level has no top-face flux contribution; force zero.
    K_33 = jax.ops.index_update(K_33, jax.ops.index[2:-2, 2:-2, -1], 0.)

    return K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by
#%% Data for BNN idx = 0 X_selected = jnp.array(X_train_idx[idx * n_data:idx * n_data + n_data, :2]) y_selected = jnp.array(y_delta_train_idx[idx * n_data:idx * n_data + n_data, :]).squeeze() X = jnp.array(X_train[:, :2]) y = jnp.array(y_delta_train) X_test = jnp.array(X_test) y_test = jnp.array(y_delta_test) p_selected = X_selected[:, 0] t_selected = X_selected[:, 1] # / params['t_max'] F = jnp.zeros_like(y_selected) X_train = jnp.array(X_trainc[:, :2]) p_selected = X_train[:, 0] t_selected = X_train[:, 1] # / params['t_max'] y_selected = jnp.array(y_delta_trainc.squeeze()) mask = jnp.array(mask_data, dtype=bool) F = jnp.zeros_like(y_selected) #%% Data for BNN # X_train = torch.tensor(X_trainc[:,:2], dtype=torch.float32) # y_train = torch.tensor(y_delta_trainc, dtype=torch.float32) # X_test = torch.tensor(X_test, dtype=torch.float32) # y_test = torch.tensor(y_delta_test, dtype=torch.float32) # mask = torch.tensor(mask_data.reshape(-1,1), dtype=torch.bool)
def loss_fn(
    output_logits,
    targets,
    valid_mask,
    num_nodes,
    captured,
    negative_example_weight=1,
    focal_loss_gamma=0.0,
):
    """Compute loss and single-batch metrics for some outputs.

    Args:
      output_logits: Binary logits produced by the model.
      targets: Model targets.
      valid_mask: Mask determining which outputs are valid.
      num_nodes: How many nodes there are in each example.
      captured: Ignored
      negative_example_weight: Weight to assign to a negative example when
        computing the loss. Positive examples always get weight 1.
      focal_loss_gamma: Focusing parameter for the focal loss, as described in
        Lin et al. (2018). If zero, uses standard cross-entropy loss.

    Returns:
      Tuple (loss, metrics_dict).
    """
    del captured

    # Per-entry binary cross entropy. With the focal loss, scale it by
    # (1 - p_correct) ** gamma; since nll = -log(p_correct), that factor is
    # (-expm1(-nll)) ** gamma.
    nll = model_util.binary_logit_cross_entropy(output_logits, targets)
    if focal_loss_gamma:
        nll = nll * jnp.power(-jnp.expm1(-nll), focal_loss_gamma)

    # Only count entries belonging to nodes that exist.
    masked_nll = nll * valid_mask
    zeros = jnp.zeros_like(masked_nll)

    # Sum (rather than average) because most edges are easy negatives.
    positive_nll = jnp.sum(jnp.where(targets, masked_nll, zeros))
    negative_nll = jnp.sum(jnp.where(targets, zeros, masked_nll))
    loss = jnp.sum(positive_nll + negative_example_weight * negative_nll)

    # Learning-progress metrics.
    num_targets = jnp.count_nonzero(targets)
    num_non_targets = num_nodes**2 - num_targets
    avg_nll_per_target = positive_nll / num_targets
    avg_nll_per_non_target = negative_nll / num_non_targets

    # Precision / recall with predictions thresholded at probability 0.1.
    predicted = output_logits > jax.scipy.special.logit(0.1)
    true_positives = jnp.count_nonzero(predicted & targets)
    predicted_positives = jnp.count_nonzero(predicted & valid_mask.astype(bool))
    precision = true_positives / predicted_positives
    recall = true_positives / num_targets
    f1 = 2 * (precision * recall) / (precision + recall)

    metrics = {
        "avg_per_target": avg_nll_per_target,
        "avg_per_non_target": avg_nll_per_non_target,
        "worst": jnp.max(masked_nll),
        # When this equals negative_example_weight, positive and negative
        # examples carry the same total weight.
        "positive_per_negative": num_targets / num_non_targets,
        "effective_p_model_given_target": jnp.exp(-avg_nll_per_target),
        "effective_p_model_given_nontarget":
            1 - jnp.exp(-avg_nll_per_non_target),
        "batch_clf_thresh_at_0.1/precision": precision,
        "batch_clf_thresh_at_0.1/recall": recall,
        "batch_clf_thresh_at_0.1/f1": f1,
    }
    return loss, metrics
def nonbonded_v3(conf, params, box, lamb, charge_rescale_mask, lj_rescale_mask,
                 scales, beta, cutoff, lambda_plane_idxs, lambda_offset_idxs):
    """Pairwise LJ + erfc-screened Coulomb energy on lambda-lifted coordinates.

    Parameters
    ----------
    conf: coordinates, lifted to 4D by `convert_to_4d` (original shape
        assumed [N, 3]; TODO confirm)
    params: shape [N, 3]; columns are (charge, sig, eps)
    box: shape [3, 3] or None; when given, it is embedded into a 4x4 box
        whose 4th diagonal entry is 1000
    lamb, lambda_plane_idxs, lambda_offset_idxs: forwarded to `convert_to_4d`
    charge_rescale_mask, lj_rescale_mask: multiplicative per-pair scales
        applied elementwise to the N x N charge / LJ energy matrices
    scales: not used by this function
    beta: screening parameter in erfc(beta * r)
    cutoff: forwarded to `convert_to_4d`. If not None, LJ terms with
        dij >= cutoff and charge terms with dij > cutoff are zeroed.
        NOTE(review): the two boundary conditions differ at dij == cutoff.

    Returns
    -------
    scalar total energy (symmetric pair matrix summed and halved)

    NOTE(review): sig_ij = sig_i + sig_j and eps_ij = eps_i * eps_j are used
    directly, so `params` are presumably pre-scaled (half-sigma, sqrt-eps)
    to reproduce Lorentz-Berthelot; confirm with the caller.
    """
    conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff)

    # make 4th dimension of box large enough so it's roughly aperiodic
    if box is not None:
        box_4d = np.eye(4) * 1000
        box_4d = index_update(box_4d, index[:3, :3], box)
    else:
        box_4d = None

    box = box_4d

    charges = params[:, 0]
    sig = params[:, 1]
    eps = params[:, 2]

    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = sig_i + sig_j

    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)
    eps_ij = eps_i * eps_j

    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj, box)

    N = conf.shape[0]
    # Exclude self-interactions and pairs with no LJ well depth.
    keep_mask = np.ones((N, N)) - np.eye(N)
    keep_mask = np.where(eps_ij != 0, keep_mask, 0)

    if cutoff is not None:
        eps_ij = np.where(dij < cutoff, eps_ij, np.zeros_like(eps_ij))

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, np.zeros_like(sig_ij))
    eps_ij = np.where(keep_mask, eps_ij, np.zeros_like(eps_ij))

    sig2 = sig_ij / dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    eij_lj = 4 * eps_ij * (sig6 - 1.0) * sig6
    eij_lj = np.where(keep_mask, eij_lj, np.zeros_like(eij_lj))

    qi = np.expand_dims(charges, 0)  # (1, N)
    qj = np.expand_dims(charges, 1)  # (N, 1)
    qij = np.multiply(qi, qj)

    # (ytz): trick used to avoid nans in the diagonal due to the 1/dij term.
    # The charge term only excludes self-interactions (not eps_ij == 0 pairs).
    keep_mask = 1 - np.eye(conf.shape[0])
    qij = np.where(keep_mask, qij, np.zeros_like(qij))
    dij = np.where(keep_mask, dij, np.zeros_like(dij))

    # erfc(x)/x diverges as x -> 0 (erfc(0) = 1). After the masking above,
    # the diagonal evaluates 0 * erfc(0) / 0 = 0/0, so it is masked to zero
    # here explicitly.
    eij_charge = np.where(keep_mask, qij * erfc(beta * dij) / dij, np.zeros_like(dij))
    if cutoff is not None:
        eij_charge = np.where(dij > cutoff, np.zeros_like(eij_charge), eij_charge)

    eij_total = (eij_lj * lj_rescale_mask + eij_charge * charge_rescale_mask)

    # The pair matrix is symmetric; halve the sum to count each pair once.
    return np.sum(eij_total / 2)
def new_zeros(x):
    """Return a zero-filled array shaped like `x`, in the active backend.

    If the module-level flag `npyro` is truthy, this uses `jnp.zeros_like(x)`.
    Otherwise it calls `x.new_zeros(x.shape)` (presumably a `torch.Tensor`
    method; confirm against callers).
    """
    if not npyro:
        return x.new_zeros(x.shape)
    return jnp.zeros_like(x)
def main(argv): global CFG CFG = FLAGS.config if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Guarantee that the JAX bfloat16 extension is used rather than TF bfloat16. _ = np.array(jnp.array([1.0], dtype=jnp.bfloat16)) # Use hardware RNG for bernoulli randoms in dropout mask creation. if CFG.hardware_rng: models.set_hardware_bernoulli() if 'module_import' in CFG and CFG.module_import: for module in CFG.module_import: importlib.import_module(module) if 'additional_task_cache_dirs' in CFG and CFG.additional_task_cache_dirs: t5.data.add_global_cache_dirs(CFG.additional_task_cache_dirs) num_partitions = CFG.num_partitions topology = train_lib.compute_multihost_topology(num_partitions) batch_size = CFG.batch_size eval_batch_size = CFG.eval_batch_size per_replica_set_eval_batch_size = eval_batch_size // topology.num_replica_sets if batch_size % topology.num_replicas: raise ValueError('Batch size must be divisible by the number of replicas.') steps_per_epoch = CFG.steps_per_epoch logging.info('steps per epoch: %d', steps_per_epoch) broadcast = functools.partial( train_lib.broadcast, num_replicas=topology.per_replica_set_num_replicas, num_partitions=topology.per_host_num_partitions, devices=topology.this_host_device_assignment) if jax.host_id() == 0: tf.io.gfile.makedirs(FLAGS.model_dir) tf.io.gfile.copy(FLAGS['config'].config_filename, os.path.join(FLAGS.model_dir, 'config.py'), overwrite=True) train_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'train')) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'eval')) else: train_summary_writer = None eval_summary_writer = None # Write summaries in background thread to avoid blocking on device sync if CFG.infeed: # Infeed is currently synchronous, so do it in a background thread too infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed') (train_ds, eval_ds), eval_cache = input_pipeline.get_datasets_and_cache( CFG, 
topology.num_replica_sets, topology.replica_set_id, topology.per_replica_set_host_id) vocab = input_pipeline.get_vocabulary(CFG.mixture_or_task_name) encoder = vocab.tf_tokenizer eos_id = vocab.tokenizer.eos_id() def decode_tokens(toks, eos_id = eos_id, max_id = 32000): """Decode tokens back to unicode.""" del eos_id # TODO(levskaya): T5 doesn't seem to emit EOS tokens? double check this # is the best decoding function or just switch to using tf_decode. # valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32) valid_toks = toks.astype(np.int32) valid_toks[valid_toks >= max_id] = 3 return encoder.detokenize(valid_toks).numpy().decode('utf-8') logging.info('Initializing model, optimizer, and step functions.') train_config, eval_config, predict_config = get_configs(CFG) rng = random.PRNGKey(CFG.random_seed) rng, init_rng = random.split(rng) # This is used for infeed conversion from feature dict <--> tuple train_keys = [ 'inputs', 'targets', 'inputs_position', 'targets_position', 'inputs_segmentation', 'targets_segmentation' ] device_train_input_shape = tuple([ (batch_size // topology.num_replicas, CFG.max_input_length if 'inputs' in k else CFG.max_target_length) for k in train_keys ]) learning_rate_fn = train_lib.create_learning_rate_scheduler( factors=CFG.schedule, base_learning_rate=CFG.learning_rate, warmup_steps=CFG.warmup_steps) # First, we only abstractly initialize the optimizer and model parameters, # since the parameters may not even fit in device memory! 
# TODO(jekbradbury): make optimizer_defs compare by value so it can be created # in get_initial_params without causing pytree incompatibility optimizer_def = optim.Adafactor( CFG.learning_rate, decay_rate=0.8, step_offset=CFG.step_offset) initialize_params_fn = functools.partial( get_initial_params, config=CFG, transformer_config=eval_config, optimizer_def=optimizer_def) optimizer = jax.eval_shape(initialize_params_fn, init_rng) # tuple-like pytree leaves for global_arg_shapes optimizer_shapes = jax.tree_map(lambda x: partitions.Spec(*x.shape), optimizer) # Build parameter partition annotations for preserving partitions from train # to eval. if num_partitions > 1: optimizer_partitions = optimizer.restore_state( partitions.set_partitions(num_partitions, optimizer.state_dict())) per_host_optimizer_partitions = optimizer.restore_state( partitions.set_partitions(topology.per_host_num_partitions, optimizer.state_dict())) # Restore unreplicated optimizer + model state from last checkpoint. # TODO(jekbradbury,levskaya): implement sharded native checkpoint/restore existing_checkpoint_found = False if CFG.restore_checkpoints: existing_checkpoint_found = train_lib.checkpoint_exists(FLAGS.model_dir) optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer) # Import a pretrained-T5 checkpoint only if we didn't import a local # "native" checkpoint (e.g. due to resuming a pre-empted finetuning run.) # TODO(jekbradbury,levskaya): implement sharded T5 checkpoint/restore if CFG.restore_t5_checkpoint and not existing_checkpoint_found: optimizer = checkpoint_importer.restore_from_t5_checkpoint( optimizer, CFG.restore_t5_checkpoint) if CFG.restore_t5_checkpoint or existing_checkpoint_found: if num_partitions > 1: # Until checkpoint/restore is sharded, the restored checkpoint is global # and we need to slice each sharded parameter into the chunk containing # only the partitions that are present on this host. 
def per_host_chunk(x, spec): if spec is None or spec is x: # unsharded or not a parameter return x if spec[0] == 1: dim_size = x.shape[1] elif spec[1] == 1: dim_size = x.shape[0] else: raise NotImplementedError() chunk_size = ( dim_size * topology.per_host_num_partitions // num_partitions) lower = topology.per_replica_set_host_id * chunk_size upper = (topology.per_replica_set_host_id + 1) * chunk_size if spec[0] == 1: return x[:, lower:upper] else: return x[lower:upper] optimizer = jax.tree_multimap(per_host_chunk, optimizer, optimizer_partitions) else: # If pretraining and no checkpoint imported, we jit the (sharded-) init # function to minimize fragmentation. We use the same pmap(sharded_jit) # setup as the training step/loop to initialize everything "in-place" and # avoid communication or OOM. if num_partitions > 1: initialize_params_fn = sharded_jit( initialize_params_fn, in_parts=None, local_in_parts=None, out_parts=optimizer_partitions, local_out_parts=per_host_optimizer_partitions, # devices=one_replica_device_assignment, ) initialize_params_fn = jax.pmap( initialize_params_fn, 'batch', in_axes=0, axis_size=topology.num_replicas, devices=topology.device_assignment) init_rng = broadcast(init_rng) optimizer = initialize_params_fn(init_rng) # We maintain the optimizer in unbroadcasted form (i.e. with no leading # replica axis). This is equivalent to the as-yet-nonexistent pmap kwarg # out_axes=None. optimizer = train_lib.unbroadcast(optimizer) else: optimizer = jax.jit(initialize_params_fn)(init_rng) # --------------------------------------------------------------------------- # Compile multidevice versions of train/eval/predict step and cache init fn. 
# --------------------------------------------------------------------------- # We can use either a single train-step for a host training loop: # train_step(optimizer, batch, prev_metrics, dropout_rng, **kwargs) # --> new_optimizer, metrics, new_dropout_rng def p_train_step(optimizer, batch, prev_metrics, dropout_rng): return train_lib.train_step( optimizer, batch, prev_metrics, dropout_rng, config=train_config, learning_rate_fn=learning_rate_fn, num_microbatches=CFG.microbatches, label_smoothing=CFG.label_smoothing, z_loss=CFG.z_loss, use_bfloat16=CFG.use_bfloat16) if num_partitions > 1: p_train_step = sharded_jit( p_train_step, in_parts=(optimizer_partitions, None, None, None), local_in_parts=(per_host_optimizer_partitions, None, None, None), out_parts=(optimizer_partitions, None, None), local_out_parts=(per_host_optimizer_partitions, None, None)) # TODO(levskaya): the in_axes spec below might be wrong, double-check. p_train_step = jax.pmap( p_train_step, axis_name='batch', in_axes=(None, 0, 0, 0), donate_argnums=(0,), global_arg_shapes=(optimizer_shapes, None, None, None), axis_size=topology.num_replicas, devices=topology.device_assignment) # pytype: disable=wrong-arg-types # OR, we use an on-device loop that feeds the training step via infeed queue. def device_train_loop_cond( args ): """Stopping criterion for on-device loop.""" _, _, _, _, step, epoch = args return step // steps_per_epoch == epoch def device_train_loop_body( args ): """On-device loop body.""" optimizer, dropout_rngs, metrics, token, step, epoch = args # Ordering input data from infeed requires threading a symbolic token # through the computation. input_data, token = lax.infeed( token, shape=tuple( [jax.ShapedArray(s, jnp.int32) for s in device_train_input_shape])) # Rebuild input dict from infeed data tuple. batch = {k: v for k, v in zip(train_keys, input_data)} # Run the train_step function and return the loop state. 
optimizer, metrics, dropout_rngs = train_lib.train_step( optimizer, batch, metrics, dropout_rngs, train_config, learning_rate_fn, num_microbatches=CFG.microbatches, label_smoothing=CFG.label_smoothing, z_loss=CFG.z_loss) step += 1 return optimizer, dropout_rngs, metrics, token, step, epoch def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch): # Create symbolic token for threading infeed data. token = lax.create_token(step) # Run on-device loop. optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop( device_train_loop_cond, device_train_loop_body, (optimizer, dropout_rngs, metrics, token, step, epoch)) return optimizer, dropout_rngs, metrics, step if num_partitions > 1: device_train_loop = sharded_jit( device_train_loop, in_parts=(optimizer_partitions, None, None, None, None), local_in_parts=(per_host_optimizer_partitions, None, None, None, None), out_parts=(optimizer_partitions, None, None, None), local_out_parts=(per_host_optimizer_partitions, None, None, None)) p_train_epoch = jax.pmap( device_train_loop, axis_name='batch', in_axes=(None, 0, 0, None, None), donate_argnums=(0,), global_arg_shapes=(optimizer_shapes, None, None, None, None), axis_size=topology.num_replicas, devices=topology.device_assignment) # pytype: disable=wrong-arg-types # Reduction psum for metric data. def p_allreduce_metrics(x): return lax.psum(x, axis_name='batch') if num_partitions > 1: p_allreduce_metrics = sharded_jit( p_allreduce_metrics, in_parts=None, local_in_parts=None, out_parts=None, local_out_parts=None, num_partitions=num_partitions, local_num_partitions=topology.per_host_num_partitions) p_allreduce_metrics = jax.pmap( p_allreduce_metrics, axis_name='batch', global_arg_shapes=None, axis_size=topology.num_replicas, devices=topology.device_assignment) # Training evaluation computation. 
# eval_step(params, batch, config, label_smoothing=0.0) --> metrics def p_eval_step(params, batch): return train_lib.eval_step( params, batch, config=eval_config, label_smoothing=CFG.label_smoothing) if num_partitions > 1: p_eval_step = sharded_jit( p_eval_step, in_parts=(optimizer_partitions.target, None), local_in_parts=(per_host_optimizer_partitions.target, None), out_parts=None, local_out_parts=None) p_eval_step = jax.pmap( p_eval_step, axis_name='batch', in_axes=(None, 0), global_arg_shapes=(optimizer_shapes.target, None), axis_size=topology.num_replicas, devices=topology.device_assignment) # pytype: disable=wrong-arg-types # Fast autoregressive decoding loop. # For inference and model evaluation. # predict_step(inputs, params, # eos_id, max_decode_len, config, beam_size=4) --> beam_seqs def p_pred_step(inputs, params): return train_lib.predict_step(inputs, params, eos_id, CFG.max_eval_target_length, predict_config, CFG.beam_size) if num_partitions > 1: p_pred_step = sharded_jit( p_pred_step, in_parts=(None, optimizer_partitions.target), local_in_parts=(None, per_host_optimizer_partitions.target), out_parts=None, local_out_parts=None) p_pred_step = jax.pmap( p_pred_step, axis_name='batch', in_axes=(0, None), global_arg_shapes=(None, optimizer_shapes.target), axis_size=topology.num_replicas, devices=topology.device_assignment) # pytype: disable=wrong-arg-types # --------------------------------------------------------------------------- # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. # There should be a unique dropout key for each replica represented on this # host, but the key should be the same for the same replica on other hosts. # Again, this is what the replica set abstraction is for. 
dropout_rngs = random.split( random.fold_in(rng, topology.replica_set_id), topology.per_replica_set_num_replicas) # restore step from last checkpoint host_step = int(optimizer.state.step) empty_metrics = broadcast({ 'loss': 0.0, 'accuracy': 0.0, 'learning_rate': 0.0, 'denominator': 0.0 }) if CFG.infeed: # TODO(jekbradbury): support something like this for the Python-loop case logging.info('Precompiling training loop and moving optimizer to device.') optimizer, _, metrics, _ = p_train_epoch(optimizer, dropout_rngs, empty_metrics, jnp.array(0, dtype=jnp.int32), 1) optimizer = train_lib.unbroadcast(optimizer) metrics['loss'].block_until_ready() logging.info('Starting training loop.') local_devices = jax.local_devices() device_step = broadcast(host_step) first_epoch = host_step // steps_per_epoch # Main Loop over "epochs". train_iter = train_ds.as_numpy_iterator() for epoch in range(first_epoch, first_epoch + CFG.num_epochs): metrics = empty_metrics # NOTE: 'optimizer' is unbroadcast by construction at initialization or # when loading a checkpoint. It is maintained in 'unbroadcast' state to # enable the XLA cross-replica sharding optimization. The broadcasting is # handled automatically by the pmap'd functions that use it. # Gather all task evaluation metrics. logging.info('Evaluating tasks.') if epoch == first_epoch + 1: train_lib.sync_devices() for task in eval_cache.tasks: logging.info('Evaluating task %s', task.name) all_predicted, all_bs = [], [] for pred_batch in eval_cache.preprocessed_examples[task.name]: # Handle final odd-sized batch by padding instead of dropping it. input_batch, unpadded_batch_size = train_lib.pad_batch_to_size( pred_batch['inputs'], per_replica_set_eval_batch_size) all_bs.append(unpadded_batch_size) # Split batch dimensions for pmap. input_batch = jax.tree_map( lambda x: x.reshape( (topology.per_replica_set_num_replicas, -1) + x.shape[1:]), input_batch) # Run fast inference on batch. 
all_predicted.append(p_pred_step(input_batch, optimizer.target)) # Pad out the number of batches so each host has the same number. max_host_batch_number = np.max( eval_cache.preprocessed_batch_sizes[task.name]) batch_shortfall = max_host_batch_number - len(all_predicted) if batch_shortfall > 0: # TODO(levskaya): Fix for case of entirely empty all_predicted. # To make sure the cross-host barriers work, we run the program the same # number of times on all hosts. The results of this call is ignored, and # the predictions are populated with zeros instead. p_pred_step(input_batch, optimizer.target) # Dummy call. all_predicted.extend([jnp.zeros_like(all_predicted[0])] * batch_shortfall) all_bs.extend([0] * batch_shortfall) all_predicted = jnp.concatenate(all_predicted) all_bs = jnp.array(all_bs) # Collect all batches from across hosts and reverse sharding. all_predicted = train_lib.host_allgather( all_predicted, topology.num_replica_sets, topology.replica_set_id, topology.per_replica_set_host_id == 0) seqlength = all_predicted.shape[-1] total_examples = np.sum( train_lib.host_allgather(all_bs, topology.num_replica_sets, topology.replica_set_id, topology.per_replica_set_host_id == 0)) del all_bs assert total_examples == len(eval_cache.examples[task.name]), ( 'Total number of batches incorrect for task %s.' % task.name) # De-shard the collected predicted tokens and remove padding. all_predicted = np.transpose(all_predicted, (1, 2, 0, 3)).reshape( -1, seqlength)[:total_examples] # We now run the post-processing and metric-fns on a single host. 
if jax.host_id() == 0: assert eval_summary_writer raw_predictions = [] for tokens in all_predicted: raw_predictions.append(decode_tokens(tokens)) # post-process predictions for metric fns predictions = [ task.postprocess_fn(p, example=ex) for p, ex in zip(raw_predictions, eval_cache.examples[task.name]) ] for metric_fn in task.metric_fns: scores = metric_fn(eval_cache.targets[task.name], predictions) for metric_name, metric_value in scores.items(): tag = f'eval/{task.name}/{metric_name}' eval_summary_writer.scalar(tag, metric_value, host_step) logging.info('EVAL %s at step %d: %.3f', tag, host_step, metric_value) eval_summary_writer.flush() # Save text samples for tensorboard. exemplars = '' for n in np.random.choice(np.arange(len(predictions)), 8): tgt_txt = tf.compat.as_text( eval_cache.examples[task.name][n]['targets_plaintext']) pred_txt = raw_predictions[n] exemplars += (f'{eval_cache.inputs[task.name][n]}\n\n' f'target: {tgt_txt}\n\n' f'prediction: {pred_txt}\n\n') eval_summary_writer.text(f'{task.name} samples', exemplars, host_step) eval_summary_writer.flush() # Take an Xprof trace after the first loop has compiled everything. if epoch == first_epoch + 1: train_lib.sync_devices() # For on-device loop, we launch the computation before feeding data. logging.info('BEGIN Train loop.') if CFG.infeed: optimizer, dropout_rngs, metrics, device_step = p_train_epoch( optimizer, dropout_rngs, metrics, train_lib.unbroadcast(device_step), epoch) optimizer = train_lib.unbroadcast(optimizer) # Epoch loop. while int(host_step // steps_per_epoch) == epoch: batch = next(train_iter) batch = jax.tree_map( lambda x: x.reshape( (topology.per_replica_set_num_replicas, -1) + x.shape[1:]), batch) # Feed the on-device training loop. if CFG.infeed: for i, device in enumerate(local_devices): # When using infeed to provide data to the computation, we're on our # own for feeding the right values to the right devices. 
Each device # should get the minibatch corresponding to its replica, a slice of # the larger batch corresponding to the host's replica set. if device.platform == 'tpu': device_coords = (*device.coords, device.id % 2) else: device_coords = (device.host_id, i) per_replica_set_device_coords = tuple( dc % prsm for dc, prsm in zip(device_coords, topology.per_replica_set_mesh)) per_replica_set_replica_coords = tuple( prsdc // prm for prsdc, prm in zip(per_replica_set_device_coords, topology.per_replica_mesh)) per_replica_set_replica_id = 0 for prsm, prm, prsrc in zip(topology.per_replica_set_mesh, topology.per_replica_mesh, per_replica_set_replica_coords): per_replica_set_replica_id = ( per_replica_set_replica_id * prsm // prm + prsrc) input_tuple = tuple( [batch[k][per_replica_set_replica_id] for k in train_keys]) # Safety check: infeed does not check shape or types but requires # them to agree with on-device spec, otherwise the queue and program # stalls. tuple_shapes = jax.tree_map(jnp.shape, input_tuple) tuple_dtypes = jax.tree_map(lambda x: x.dtype, input_tuple) assert tuple_shapes == device_train_input_shape, ( 'infeed shape error %s != %s' % (tuple_shapes, device_train_input_shape)) assert tuple(set(tuple_dtypes)) == (jnp.int32,), \ ('infeed dtype error %s not all of type %s' % ( tuple_dtypes, jnp.int32)) infeed_pool.submit( functools.partial(device.transfer_to_infeed, input_tuple)) # Host training loop. else: optimizer, metrics, dropout_rngs = p_train_step(optimizer, batch, metrics, dropout_rngs) optimizer = train_lib.unbroadcast(optimizer) host_step += 1 logging.info('END Train loop.') # Maybe save a checkpoint on one host. if (CFG.save_checkpoints and epoch % CFG.checkpoint_freq == CFG.checkpoint_freq - 1 and jax.host_id() == 0): checkpoints.save_checkpoint(FLAGS.model_dir, optimizer, host_step) # Gather training metrics. 
metrics = p_allreduce_metrics(metrics) metrics = jax.tree_map(lambda x: jax.device_get(x[0]), metrics) denominator = metrics.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics) # pylint: disable=cell-var-from-loop logging.info('train in step: %s, %s', host_step, summary) if jax.host_id() == 0: assert train_summary_writer for key, val in summary.items(): train_summary_writer.scalar(key, val, host_step) train_summary_writer.flush() # Gather training evaluation metrics. logging.info('Gathering training evaluation metrics.') eval_metrics = [] eval_iter = eval_ds.as_numpy_iterator() for _, eval_batch in zip(range(CFG.num_eval_steps), eval_iter): eval_batch = jax.tree_map( lambda x: x.reshape( (topology.per_replica_set_num_replicas, -1) + x.shape[1:]), eval_batch) metrics = p_eval_step(optimizer.target, eval_batch) eval_metrics.append(metrics) # average metrics across devices eval_metrics = p_allreduce_metrics(eval_metrics) eval_metrics = common_utils.get_metrics(eval_metrics) # average metrics across steps eval_metrics = jax.tree_map(np.sum, eval_metrics) eval_denominator = eval_metrics.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics) logging.info('eval in step: %s, %s', host_step, eval_summary) if jax.host_id() == 0: assert eval_summary_writer for key, val in eval_summary.items(): eval_summary_writer.scalar(key, val, host_step) eval_summary_writer.flush() # Wait until computations are done before exiting logging.info('Finished.') train_lib.sync_devices() # Shut down the infeed threadpool. if CFG.infeed: infeed_pool.shutdown()
def dist_sq(R):
    """Pairwise squared distances between the rows of `R`, with a wrap adjustment.

    Every component of every pairwise displacement whose magnitude is at least
    0.5 is moved 0.5 closer to zero before squaring.

    NOTE(review): a conventional minimum-image correction for a unit box shifts
    such components by 1.0, not 0.5; confirm the intended box size before
    treating this as a true periodic distance.

    Args:
      R: Array of points, one point per row (indexed as R[i, :]).

    Returns:
      Array `out` with `out[i, j] = sum_k adjusted(R[i, k] - R[j, k]) ** 2`.
    """
    displacement = R[:, None, :] - R[None, :, :]
    near = jnp.abs(displacement) < 0.5
    correction = jnp.where(
        near, jnp.zeros_like(displacement), 0.5 * jnp.sign(displacement))
    wrapped = displacement - correction
    return (wrapped ** 2).sum(axis=2)
def dist_sq(R):
    """Pairwise squared distances between the rows of `R`, with a wrap adjustment.

    Every component of every pairwise displacement whose magnitude is at least
    0.5 is moved 0.5 closer to zero before squaring.

    NOTE(review): a conventional minimum-image correction for a unit box shifts
    such components by 1.0, not 0.5; confirm the intended box size before
    treating this as a true periodic distance.

    Args:
      R: Array of points, one point per row (indexed as R[i, :]).

    Returns:
      Array `out` with `out[i, j] = sum_k adjusted(R[i, k] - R[j, k]) ** 2`.
    """
    displacement = R[:, None, :] - R[None, :, :]
    near = np.abs(displacement) < 0.5
    correction = np.where(
        near, np.zeros_like(displacement), 0.5 * np.sign(displacement))
    wrapped = displacement - correction
    return (wrapped ** 2).sum(axis=2)
if not no_nans: with open(os.path.join(str(output_dir), f"{epoch}_nan.pkl"), 'wb') as f: pickle.dump(get_params(opt_state), f) exit( f"nan encountered in epoch_info and params pickled: {epoch_info}" ) params = get_params(opt_state) train_acc, train_logits, train_labels, train_nll, train_kl, _ = evaluate( params, train_eval_loader, input_size, args.nsamples, rng_generator, args.kl_coef / train_size) train_loss = train_nll + args.kl_coef * train_kl if args.disable_test: val_acc, val_logits, val_labels, val_nll, val_kl = jnp.zeros( 1), jnp.zeros_like(train_logits), jnp.zeros(1), jnp.zeros(1) test_acc, test_logits, test_labels, test_nll, test_kl = jnp.zeros( 1), jnp.zeros_like(train_logits), jnp.zeros(1), jnp.zeros(1) else: val_acc, val_logits, val_labels, val_nll, val_kl, _ = evaluate( ema_params, val_loader, input_size, args.nsamples, rng_generator, args.kl_coef / train_size) test_acc, test_logits, test_labels, test_nll, test_kl, test_ws = evaluate( ema_params, test_loader, input_size, args.nsamples, rng_generator, args.kl_coef / train_size) val_loss, test_loss = val_nll + args.kl_coef * val_kl, test_nll + args.kl_coef * test_kl cal_train = utils.get_calibration( train_labels, jax.device_get(jnp.exp(train_logits))) cal_val = utils.get_calibration(val_labels, jax.device_get(jnp.exp(val_logits)))
def _test_transformation(self, func, param, msg=None):
    """Forward-mode smoke check of `func` at `param`.

    Pushes an all-ones tangent through `jax.jvp`, checks that the output
    tangent has the same shape as the primal output and, unless
    `FLAGS.execute_only` is set, that the tangent is not identically zero.
    """
    tangent_in = np.ones_like(param)
    primal_out, tangent_out = jax.jvp(func, (param,), (tangent_in,))
    self.assertEqual(primal_out.shape, tangent_out.shape)
    if FLAGS.execute_only:
        return
    self.assertNotAllEqual(tangent_out, np.zeros_like(tangent_out), msg=msg)
def init_fn(fx_train=0.):
    """Build the initial state: `fl(fx_train)` stacked on top of zeros of the same shape.

    The result concatenates the transformed `fx_train` with an all-zero block
    of identical shape along axis 0.
    """
    fx = fl(fx_train)
    zeros = np.zeros_like(fx)
    return np.concatenate((fx, zeros), axis=0)
def _test_transformation(self, func, param, msg=None):
    """Reverse-mode smoke check of `func` at `param`.

    Pulls an all-ones cotangent (cast to the output dtype) back through
    `jax.vjp`, checks that the input cotangent has `param`'s shape and, unless
    `FLAGS.execute_only` is set, that it is not identically zero.
    """
    primal_out, pullback = jax.vjp(func, param)
    seed = np.ones_like(primal_out).astype(primal_out.dtype)
    (cotangent,) = pullback(seed)
    self.assertEqual(param.shape, cotangent.shape)
    if FLAGS.execute_only:
        return
    self.assertNotAllEqual(cotangent, np.zeros_like(cotangent), msg=msg)
def convergence(args):
    """Run restrained Langevin dynamics on two ligands and return the mean dU/dlambda.

    The first two molecules of ``tests/data/ligands_40.sdf`` are recentred,
    parameterized with harmonic bond/angle terms from the SMIRNOFF forcefield
    file, and coupled by a PMI restraint plus a shape-overlap restraint that is
    scaled by ``lamb``. The system is integrated for 100000 steps; dU/dlambda
    is sampled every 5 steps once ``step > 10000``.

    Args:
      args: Tuple ``(epoch, lamb, lamb_idx)``. ``lamb`` scales the shape
        restraint; ``epoch`` and ``lamb_idx`` only name the SDF frame file.

    Returns:
      Mean of the sampled dU/dlambda values.
    """
    epoch, lamb, lamb_idx = args

    # Load the first two ligands (hydrogens kept).
    suppl = Chem.SDMolSupplier("tests/data/ligands_40.sdf", removeHs=False)
    ligands = list(suppl)
    ligand_a = ligands[0]
    ligand_b = ligands[1]

    # Alternative test systems:
    # ligand_a = Chem.AddHs(Chem.MolFromSmiles("CCCC1=NN(C2=C1N=C(NC2=O)C3=C(C=CC(=C3)S(=O)(=O)N4CCN(CC4)C)OCC)C"))
    # ligand_b = Chem.AddHs(Chem.MolFromSmiles("CCCC1=NN(C2=C1N=C(NC2=O)C3=C(C=CC(=C3)S(=O)(=O)N4CCN(CC4)C)OCC)C"))
    # ligand_a = Chem.AddHs(Chem.MolFromSmiles("c1ccccc1CC"))
    # ligand_b = Chem.AddHs(Chem.MolFromSmiles("c1ccccc1CC"))
    # AllChem.EmbedMolecule(ligand_a, randomSeed=2020)
    # AllChem.EmbedMolecule(ligand_b, randomSeed=2020)

    coords_a = get_conf(ligand_a, idx=0)
    coords_b = get_conf(ligand_b, idx=0)
    # coords_b = np.matmul(coords_b, special_ortho_group.rvs(3))
    coords_a = recenter(coords_a)
    coords_b = recenter(coords_b)
    coords = np.concatenate([coords_a, coords_b])

    a_idxs = get_heavy_atom_idxs(ligand_a)
    b_idxs = get_heavy_atom_idxs(ligand_b)
    a_full_idxs = np.arange(0, ligand_a.GetNumAtoms())
    b_full_idxs = np.arange(0, ligand_b.GetNumAtoms())
    # ligand_b's atoms follow ligand_a's in the concatenated coordinates.
    b_idxs += ligand_a.GetNumAtoms()
    b_full_idxs += ligand_a.GetNumAtoms()

    nrg_fns = []

    forcefield = 'ff/params/smirnoff_1_1_0_ccc.py'
    # Context manager so the forcefield file handle is closed promptly.
    with open(forcefield, "r") as ff_file:
        ff_raw = ff_file.read()
    ff_handlers = deserialize_handlers(ff_raw)

    combined_mol = Chem.CombineMols(ligand_a, ligand_b)

    for handler in ff_handlers:
        if isinstance(handler, handlers.HarmonicBondHandler):
            bond_idxs, (bond_params, _) = handler.parameterize(combined_mol)
            nrg_fns.append(
                functools.partial(bonded.harmonic_bond,
                    params=bond_params,
                    box=None,
                    bond_idxs=bond_idxs
                )
            )
        elif isinstance(handler, handlers.HarmonicAngleHandler):
            angle_idxs, (angle_params, _) = handler.parameterize(combined_mol)
            nrg_fns.append(
                functools.partial(bonded.harmonic_angle,
                    params=angle_params,
                    box=None,
                    angle_idxs=angle_idxs
                )
            )
        # Torsion terms (currently disabled):
        # elif isinstance(handler, handlers.ImproperTorsionHandler):
        #     torsion_idxs, (torsion_params, _) = handler.parameterize(combined_mol)
        #     nrg_fns.append(
        #         functools.partial(bonded.periodic_torsion,
        #             params=torsion_params,
        #             box=None,
        #             lamb=None,
        #             torsion_idxs=torsion_idxs
        #         )
        #     )
        # elif isinstance(handler, handlers.ProperTorsionHandler):
        #     torsion_idxs, (torsion_params, _) = handler.parameterize(combined_mol)
        #     nrg_fns.append(
        #         functools.partial(bonded.periodic_torsion,
        #             params=torsion_params,
        #             box=None,
        #             lamb=None,
        #             torsion_idxs=torsion_idxs
        #         )
        #     )

    # NOTE(review): ligand_a's masses are scaled by 10000, presumably to make it
    # nearly immobile relative to ligand_b — confirm intent.
    masses_a = onp.array([a.GetMass() for a in ligand_a.GetAtoms()]) * 10000
    masses_b = onp.array([a.GetMass() for a in ligand_b.GetAtoms()])
    combined_masses = np.concatenate([masses_a, masses_b])

    # Alternative: centroid restraint between the heavy-atom groups.
    # com_restraint_fn = functools.partial(bonded.centroid_restraint,
    #     params=None, box=None, lamb=None,
    #     masses=np.ones_like(combined_masses),
    #     group_a_idxs=a_idxs, group_b_idxs=b_idxs, kb=50.0, b0=0.0)

    pmi_restraint_fn = functools.partial(pmi_restraints_new,
        params=None,
        box=None,
        lamb=None,
        masses=combined_masses,
        a_idxs=a_idxs,
        b_idxs=b_idxs,
        angle_force=100.0,
        com_force=100.0
    )

    prefactor = 2.7  # unitless
    shape_lamb = (4*np.pi)/(3*prefactor)  # unitless
    kappa = np.pi/(np.power(shape_lamb, 2/3))  # unitless
    sigma = 0.15  # 1 angstrom std, 95% coverage by 2 angstroms
    alpha = kappa/(sigma*sigma)

    alphas = np.zeros(combined_mol.GetNumAtoms())+alpha
    weights = np.zeros(combined_mol.GetNumAtoms())+prefactor

    shape_restraint_fn = functools.partial(
        shape.harmonic_overlap,
        box=None,
        lamb=None,
        params=None,
        a_idxs=a_idxs,
        b_idxs=b_idxs,
        alphas=alphas,
        weights=weights,
        k=150.0
    )

    def restraint_fn(conf, lamb):
        # PMI restraint is always on; the shape restraint is scaled by lamb.
        return pmi_restraint_fn(conf) + lamb*shape_restraint_fn(conf)
        # return (1-lamb)*pmi_restraint_fn(conf) + lamb*shape_restraint_fn(conf)

    nrg_fns.append(restraint_fn)

    def nrg_fn(conf, lamb):
        # Total potential energy: sum of every registered term.
        s = []
        for u in nrg_fns:
            s.append(u(conf, lamb=lamb))
        return np.sum(s)

    # Gradient w.r.t. (coords, lamb) when sampling dU/dlambda; coords only otherwise.
    grad_fn = jax.jit(jax.grad(nrg_fn, argnums=(0, 1)))
    du_dx_fn = jax.jit(jax.grad(nrg_fn, argnums=(0)))

    x_t = coords
    v_t = np.zeros_like(x_t)

    w = Chem.SDWriter('frames_heavy_'+str(epoch)+'_'+str(lamb_idx)+'.sdf')
    try:
        dt = 1.5e-3
        # Presumably (temperature, dt, friction, masses) — confirm against
        # langevin_coefficients.
        ca, cb, cc = langevin_coefficients(300.0, dt, 1.0, combined_masses)
        cb = -1*onp.expand_dims(cb, axis=-1)
        cc = onp.expand_dims(cc, axis=-1)

        du_dls = []
        # Re-seed since this may run in a forked worker that inherited the
        # parent's RNG state.
        onp.random.seed(int.from_bytes(os.urandom(4), byteorder='little'))

        for step in range(100000):
            # if step % 1000 == 0:
            #     u = nrg_fn(x_t, lamb)
            #     print("step", step, "nrg", onp.asarray(u), "avg_du_dl", onp.mean(du_dls))
            #     mol = make_conformer(combined_mol, x_t[:ligand_a.GetNumAtoms()], x_t[ligand_a.GetNumAtoms():])
            #     w.write(mol)
            #     w.flush()
            if step % 5 == 0 and step > 10000:
                # Past the first 10000 steps: sample dU/dlambda every 5 steps.
                du_dx, du_dl = grad_fn(x_t, lamb)
                du_dls.append(du_dl)
            else:
                du_dx = du_dx_fn(x_t, lamb)

            v_t = ca*v_t + cb*du_dx + cc*onp.random.normal(size=x_t.shape)
            x_t = x_t + v_t*dt
    finally:
        # Release the frame writer even if the simulation raises.
        w.close()

    return np.mean(onp.mean(du_dls))
def init_fn(params):
    """Create zero-initialised AMSGrad moment accumulators for every leaf of `params`.

    Returns:
      `ScaleByAMSGradState` whose `mu` (first moment) leaves are zeros cast to
      `mu_dtype` and whose `nu` (second moment) leaves are zeros with each
      parameter's own dtype.
    """
    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias,
    # which newer JAX releases no longer provide.
    mu = jax.tree_util.tree_map(  # First moment
        lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
    nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
    return ScaleByAMSGradState(mu=mu, nu=nu)
def fed_opt(
    client_objectives: List[StochasticObjective],
    client_update_fn: ClientUpdateFn,
    server_update_fn: ServerUpdateFn,
    sample_clients_fn: SampleClientsFn,
    prng_key: jnp.ndarray,
    init_state: jnp.ndarray,
    num_rounds: int,
    num_clients_per_round: int,
) -> Tuple[List[ServerState], List[RoundInfo]]:
    """Runs generalized federated averaging for the specified number of rounds.

    At each round, the algorithm does the following:
      1. Samples a batch of clients using `sample_clients_fn`.
      2. Runs `client_update_fn` on each sampled client objective, which
         returns a `client_delta`.
      3. Aggregates the `client_deltas` into a new server state using
         `server_update_fn`.

    Args:
      client_objectives: A list of client objective functions. Each must have
        a `num_points` attribute; it is passed (as a float) to
        `server_update_fn` as that client's weight.
      client_update_fn: Computes a local client update; called as
        `client_update_fn(objective, server_state.x, prng_key)`.
      server_update_fn: Computes the next server state; called as
        `server_update_fn(client_deltas, client_weights, server_state)`.
      sample_clients_fn: Samples client indices; called as
        `sample_clients_fn(prng_key, num_clients, num_clients_per_round)`.
      prng_key: A key for random number generation.
      init_state: The initial server parameters. The initial `ServerState` has
        `r=0`, `x=init_state` and `v` set to zeros of the same shape.
      num_rounds: The number of training rounds to run.
      num_clients_per_round: The number of clients used at each round.

    Returns:
      A pair `(trajectory, info)`:
        trajectory: `num_rounds + 1` server states, starting with the initial
          state.
        info: Per-round timing dicts aligned with `trajectory`, so it also has
          `num_rounds + 1` entries. `info[0]` is `None` because the initial
          state was not produced by a round. Each later entry maps each
          `Timer`'s description to its elapsed time.
    """
    num_clients = len(client_objectives)
    server_state = ServerState(r=0, x=init_state, v=jnp.zeros_like(init_state))
    trajectory = [server_state]
    # Placeholder so that info[i] describes the round that produced trajectory[i].
    info = [None]
    for _ in range(num_rounds):
        round_info = {}

        # Select clients.
        prng_key, subkey = random.split(prng_key)
        with Timer("select_clients_time") as t:
            client_ids = sample_clients_fn(subkey, num_clients,
                                           num_clients_per_round)
            client_objectives_round = [
                client_objectives[i] for i in client_ids
            ]
            client_weights_round = jnp.asarray(
                [float(o.num_points) for o in client_objectives_round])
        round_info[t.description] = t.elapsed

        # Compute client updates.
        client_deltas_round = []
        # TODO: parallelize this loop.
        with Timer("client_updates_time") as t:
            for client_objective in client_objectives_round:
                prng_key, subkey = random.split(prng_key)
                client_delta = client_update_fn(client_objective,
                                                server_state.x, subkey)
                client_deltas_round.append(client_delta)
        round_info[t.description] = t.elapsed

        # Update server state.
        with Timer("server_update_time") as t:
            server_state = server_update_fn(client_deltas_round,
                                            client_weights_round, server_state)
        round_info[t.description] = t.elapsed

        trajectory.append(server_state)
        info.append(round_info)

    return trajectory, info
def inner_apply(edge_embeddings, node_embeddings):
    """Apply the vertex-to-edge MLPs to every (source, dest) pair of nodes.

    The first layer adds a source-side and a destination-side projection of
    the node embeddings (one weight slice per separate MLP) plus a bias and
    applies ReLU. Each entry of `mlp_vtoe_dims[1:]` adds another ReLU layer.
    The per-MLP outputs are then combined per pair, optionally weighted by
    the edge embeddings and optionally summed over source nodes.

    Args:
      edge_embeddings: Optional array indexed as [source, dest, edge_type]
        (rank 3, per the einsums below). One MLP is used per edge type. When
        `allow_non_adjacent` is set, an extra "any pair" MLP is added.
      node_embeddings: Array indexed as [node, feature].

    Returns:
      [dest, out_dim] if `message_passing`, else [source, dest, out_dim].

    Raises:
      ValueError: If `allow_non_adjacent` is False and no `edge_embeddings`
        are given.
    """
    first_layer_dim = mlp_vtoe_dims[0]
    additional_layer_dims = mlp_vtoe_dims[1:]

    # One MLP per edge type, plus one extra when non-adjacent pairs are allowed.
    if allow_non_adjacent and edge_embeddings is not None:
        num_separate_mlps = 1 + edge_embeddings.shape[-1]
    elif allow_non_adjacent:
        num_separate_mlps = 1
    elif edge_embeddings is not None:
        num_separate_mlps = edge_embeddings.shape[-1]
    else:
        raise ValueError(
            "Either allow_non_adjacent should be True, or "
            "edge_embeddings should be provided")

    node_embedding_dim = node_embeddings.shape[-1]

    # First layer: process each node embedding.
    weight_from_source = self.param(
        "l0_weight_from_source",
        shape=(num_separate_mlps, node_embedding_dim, first_layer_dim),
        initializer=initializers.xavier_normal())
    weight_from_dest = self.param(
        "l0_weight_from_dest",
        shape=(num_separate_mlps, node_embedding_dim, first_layer_dim),
        initializer=initializers.xavier_normal())
    bias = self.param("l0_bias",
                      shape=(num_separate_mlps, first_layer_dim),
                      initializer=initializers.zeros)
    from_source = jnp.einsum("sx,kxy->sky", node_embeddings, weight_from_source)
    from_dest = jnp.einsum("dx,kxy->dky", node_embeddings, weight_from_dest)
    # Broadcast to [source, dest, mlp, feature].
    activations = jax.nn.relu(from_source[:, None, :, :] +
                              from_dest[None, :, :, :] +
                              bias[None, None, :, :])

    # Additional layers: MLP for each edge type.
    for i, layer_dim in enumerate(additional_layer_dims):
        weight = self.param(
            f"l{i+1}_weight",
            shape=(num_separate_mlps, activations.shape[-1], layer_dim),
            initializer=initializers.xavier_normal())
        bias = self.param(f"l{i+1}_bias",
                          shape=(num_separate_mlps, layer_dim),
                          initializer=initializers.zeros)
        activations = jax.nn.relu(
            jnp.einsum("sdkx,kxy->sdky", activations, weight) +
            bias[None, None, :, :])

    # Sum over edge types and possibly over source nodes.
    if edge_embeddings is None:
        # Only reachable with allow_non_adjacent, so exactly one MLP exists and
        # its axis can be dropped.
        result = activations.squeeze(axis=2)
        if mask is not None:
            result = jnp.where(mask[:, :, None], result, jnp.zeros_like(result))
        if message_passing:
            result = jnp.sum(result, axis=0)
    else:
        if allow_non_adjacent:
            # Weight for the extra "any pair" MLP: the [source, dest] mask, or
            # all ones. It must be rank 2 like `mask`, because the trailing
            # axis is added below. An all-ones array of shape [S, D, 1] would
            # become rank 4 there, and concatenating it with the rank-3 edge
            # embeddings fails.
            if mask is None:
                pairwise = jnp.ones(edge_embeddings.shape[:2])
            else:
                pairwise = mask
            mlp_weights = jnp.concatenate([
                edge_embeddings, pairwise.astype("float")[:, :, None]
            ], -1)
        else:
            mlp_weights = edge_embeddings

        if message_passing:
            result = jnp.einsum("sdky,sdk->dy", activations, mlp_weights)
        else:
            result = jnp.einsum("sdky,sdk->sdy", activations, mlp_weights)

    return result
def init_param_state(self, param):
    """Create the per-parameter momentum state, starting from zeros shaped like `param`."""
    momentum = jnp.zeros_like(param)
    return _MomentumParamState(momentum)
def zeros_like(x, dtype_str=None, dev=None):
    """Return zeros with the shape of `x`, moved to `dev` via `_to_dev`.

    Args:
      x: Array whose shape (and, by default, dtype) is copied.
      dtype_str: Optional name of a `jax.numpy` attribute to use as the dtype
        (e.g. "float32"). When falsy, `x.dtype` is used.
      dev: Device specifier forwarded unchanged to `_to_dev`.

    Raises:
      KeyError: If `dtype_str` is not an attribute of `jax.numpy`.
    """
    if dtype_str:
        # `getattr` also resolves names that the module exposes through its
        # `__getattr__`, which a raw `__dict__` lookup misses.
        try:
            dtype = getattr(_jnp, dtype_str)
        except AttributeError as err:
            # Keep the KeyError that the previous `__dict__` lookup raised.
            raise KeyError(dtype_str) from err
    else:
        dtype = x.dtype
    return _to_dev(_jnp.zeros_like(x, dtype=dtype), dev)