def _resample(self, rng):
    """(Re)samples posterior/marginals given past test results.

    Produces and examines first the marginal computed by LBP. If that
    marginal is not valid because LBP did not converge, or if posterior
    samples are needed in the next iteration of the simulator by a group
    selector, we compute both marginals and posterior using the more
    expensive sampler.

    Args:
        rng: random key
    """
    # Reset marginals.
    self.state.marginals = []
    # Compute the marginal using the cheap LBP sampler.
    lbp_sampler = self._samplers[0]
    lbp_sampler.produce_sample(rng, self.state)
    # If the marginal is valid (i.e. LBP has converged), append it to state.
    if not np.any(np.isnan(lbp_sampler.marginal)):
        self.state.marginals.append(lbp_sampler.marginal)
        self.state.update_particles(lbp_sampler)
    # If LBP has not converged, or the expensive sampler is needed, use it.
    if (np.any(np.isnan(lbp_sampler.marginal)) or
            (self._policy.next_selector.NEEDS_POSTERIOR and
             self.state.extra_tests_needed > 0 and
             (self.state.curr_cycle == 0 or
              self.state.curr_cycle < self._max_test_cycles - 1))):
        sampler = self._samplers[1]
        sampler.produce_sample(rng, self.state)
        self.state.marginals.append(sampler.marginal)
        self.state.update_particles(sampler)
def test_gradient_sinkhorn_euclidean(self, lse_mode, momentum_strategy):
    """Test gradient w.r.t. locations x of reg-ot-cost."""
    d = 3
    n = 10
    m = 15
    keys = jax.random.split(self.rng, 4)
    x = jax.random.normal(keys[0], (n, d)) / 10
    y = jax.random.normal(keys[1], (m, d)) / 10
    a = jax.random.uniform(keys[2], (n,))
    b = jax.random.uniform(keys[3], (m,))
    # Add zero weights to test proper handling (pre-`.at[]` JAX API; newer
    # code would write `a.at[0].set(0.)`).
    a = jax.ops.index_update(a, 0, 0)
    b = jax.ops.index_update(b, 3, 0)
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)

    def loss_fn(x, y):
        geom = pointcloud.PointCloud(x, y, epsilon=0.01)
        f, g, regularized_transport_cost, _, _ = sinkhorn.sinkhorn(
            geom, a, b, momentum_strategy=momentum_strategy,
            lse_mode=lse_mode)
        return regularized_transport_cost, (geom, f, g)

    delta = jax.random.normal(keys[0], (n, d))
    delta = delta / jnp.sqrt(jnp.vdot(delta, delta))
    eps = 1e-5  # perturbation magnitude

    # First computation of the gradient: custom VJP through Sinkhorn.
    loss_and_grad = jax.value_and_grad(loss_fn, has_aux=True)
    (loss_value, aux), grad_loss = loss_and_grad(x, y)
    custom_grad = jnp.sum(delta * grad_loss)
    # NB: `assertIsNot(loss_value, jnp.nan)` always passes (identity check),
    # so test the value itself instead.
    self.assertFalse(jnp.isnan(loss_value))
    self.assertEqual(grad_loss.shape, x.shape)
    self.assertFalse(jnp.any(jnp.isnan(grad_loss)))

    # Second computation: analytic gradient via the transport matrix.
    tm = aux[0].transport_from_potentials(aux[1], aux[2])
    tmp = 2 * tm[:, :, None] * (x[:, None, :] - y[None, :, :])
    grad_x = jnp.sum(tmp, 1)
    other_grad = jnp.sum(delta * grad_x)

    # Third computation: central finite differences along delta.
    loss_delta_plus, _ = loss_fn(x + eps * delta, y)
    loss_delta_minus, _ = loss_fn(x - eps * delta, y)
    finite_diff_grad = (loss_delta_plus - loss_delta_minus) / (2 * eps)

    self.assertAllClose(custom_grad, other_grad, rtol=1e-02, atol=1e-02)
    self.assertAllClose(custom_grad, finite_diff_grad, rtol=1e-02, atol=1e-02)
    self.assertAllClose(other_grad, finite_diff_grad, rtol=1e-02, atol=1e-02)
    self.assertFalse(jnp.any(jnp.isnan(custom_grad)))
def training_step(H, data_input, target, optimizer, ema, rng):
    def loss_fun(params):
        stats = VAE(H).apply({'params': params}, data_input, target, rng)
        return -stats['elbo'] * np.log(2), stats

    gradval, stats = lax.pmean(
        grad(loss_fun, has_aux=True)(optimizer.target), 'batch')
    gradval, grad_norm = clip_grad_norm(gradval, H.grad_clip)
    ll_nans = jnp.any(jnp.isnan(stats['log_likelihood']))
    kl_nans = jnp.any(jnp.isnan(stats['kl']))
    stats.update(log_likelihood_nans=ll_nans, kl_nans=kl_nans)
    learning_rate = H.lr * linear_warmup(H.warmup_iters)(optimizer.state.step)

    # Only update if no rank has a NaN and if the grad norm is below a
    # specific threshold.
    def skip_update(_):
        # Only increment the step.
        return optimizer.replace(state=optimizer.state.replace(
            step=optimizer.state.step + 1)), ema

    def update(_):
        optimizer_ = optimizer.apply_gradient(
            gradval, learning_rate=learning_rate)
        e_decay = H.ema_rate
        ema_ = tree_multimap(lambda e, p: e * e_decay + (1 - e_decay) * p,
                             ema, optimizer.target)
        return optimizer_, ema_

    skip = (ll_nans | kl_nans
            | ((H.skip_threshold != -1) & ~(grad_norm < H.skip_threshold)))
    optimizer, ema = lax.cond(skip, skip_update, update, None)
    stats.update(skipped_updates=skip, grad_norm=grad_norm)
    return optimizer, ema, stats
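
# `clip_grad_norm` used above is an external helper. A hedged sketch of one
# common way such a helper is written (global-norm clipping that also returns
# the pre-clip norm, which training_step's skip logic relies on); this is an
# illustrative assumption, not the original implementation.
import jax
import jax.numpy as jnp


def clip_grad_norm(grads, max_norm):
    # Global L2 norm over all leaves of the gradient pytree.
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.vdot(g, g) for g in leaves))
    # Rescale only when the norm exceeds the threshold.
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-16))
    return jax.tree_util.tree_map(lambda g: g * scale, grads), norm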
def sample_kernel(sa_state, model_args=(), model_kwargs=None):
    pe_fn = potential_fn
    if potential_fn_gen:
        pe_fn = potential_fn_gen(*model_args, **model_kwargs)
    zs, pes, loc, scale = sa_state.adapt_state
    # We recompute loc/scale after each iteration to avoid precision loss.
    # XXX: consider exposing a setting to do this job periodically,
    # to save some computation.
    loc = jnp.mean(zs, 0)
    if scale.ndim == 2:
        cov = jnp.cov(zs, rowvar=False, bias=True)
        if cov.shape == ():  # JAX returns scalar for 1D input
            cov = cov.reshape((1, 1))
        cholesky = jnp.linalg.cholesky(cov)
        # Keep the previous scale if the Cholesky factor contains NaN.
        scale = jnp.where(jnp.any(jnp.isnan(cholesky)), scale, cholesky)
    else:
        scale = jnp.std(zs, 0)

    rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(
        sa_state.rng_key, 4)
    _, unravel_fn = ravel_pytree(sa_state.z)
    z = loc + _sample_proposal(scale, rng_key_z)
    pe = pe_fn(unravel_fn(z))
    pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
    diverging = (pe - sa_state.potential_energy) > max_delta_energy

    # NB: all terms having the pattern *s will have shape N x ...
    # and all terms having the pattern *s_ will have shape (N + 1) x ...
    locs, scales = _get_proposal_loc_and_scale(zs, loc, scale, z)
    zs_ = jnp.concatenate([zs, z[None, :]])
    pes_ = jnp.concatenate([pes, pe[None]])
    locs_ = jnp.concatenate([locs, loc[None, :]])
    scales_ = jnp.concatenate([scales, scale[None, ...]])
    if scale.ndim == 2:  # dense_mass
        log_weights_ = dist.MultivariateNormal(
            locs_, scale_tril=scales_).log_prob(zs_) + pes_
    else:
        log_weights_ = dist.Normal(locs_, scales_).log_prob(zs_).sum(-1) + pes_
    # Mask invalid values (NaN, +inf) by -inf.
    log_weights_ = jnp.where(jnp.isfinite(log_weights_), log_weights_, -jnp.inf)
    # Choose the index to reject.
    j = random.categorical(rng_key_reject, log_weights_)
    zs = _numpy_delete(zs_, j)
    pes = _numpy_delete(pes_, j)
    loc = locs_[j]
    scale = scales_[j]
    adapt_state = SAAdaptState(zs, pes, loc, scale)

    # NB: weights[-1] / sum(weights) is the probability of rejecting the new
    # sample `z`.
    accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
    itr = sa_state.i + 1
    n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
    mean_accept_prob = sa_state.mean_accept_prob + (
        accept_prob - sa_state.mean_accept_prob) / n

    # XXX: we modify the SA sampler of [1]: in [1], each MCMC state contains
    # N points `zs`; here we resample to pick one of those N points at random.
    k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
    z = unravel_fn(zs[k])
    pe = pes[k]
    return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging,
                   adapt_state, rng_key)
def rmsprop(gradval, params, eta=0.01, gamma=0.9, eps=1e-8, R=500, per=100,
            disp=None):
    params = {k: np.array(v) for k, v in params.items()}
    grms = {k: np.zeros_like(v) for k, v in params.items()}

    # iterate to max
    for j in range(R + 1):
        val, grad = gradval(params)

        # test for nans
        vnan = np.isnan(val)
        gnan = tree_map(lambda g: np.isnan(g).any(), grad)
        if vnan or np.any(tree_leaves(gnan)):
            print('Encountered nans!')
            if disp is not None:  # guard: disp may be None
                disp(j, val, params)
            return params

        # RMSProp update (gradient ascent: params move along +grad)
        for k in params:
            grms[k] += (1 - gamma) * (grad[k]**2 - grms[k])
            params[k] += eta * grad[k] / np.sqrt(grms[k] + eps)

        if disp is not None and j % per == 0:
            disp(j, val, params)

    return params
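
# A minimal usage sketch for the rmsprop loop above (names below are
# illustrative): `gradval` must return a (value, gradient-dict) pair, which
# jax.value_and_grad provides for a dict-valued parameter pytree. Note the
# loop *ascends* (params += eta * ...), so the objective is maximized.
import jax
import jax.numpy as jnp


def objective(params):
    # Toy concave objective with maximum at w = 3.
    return -jnp.sum((params['w'] - 3.0) ** 2)


gradval = jax.value_and_grad(objective)
params = rmsprop(gradval, {'w': jnp.zeros(4)},
                 disp=lambda j, v, p: print(j, float(v)))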
def update_fn(updates, opt_state, params=None):
    del params
    # Record, per leaf, whether any NaN was present in the incoming updates.
    opt_state = ZeroNansState(
        jax.tree_map(lambda p: jnp.any(jnp.isnan(p)), updates))
    # Replace NaNs by zeros so a single bad step cannot poison the params.
    updates = jax.tree_map(
        lambda p: jnp.where(jnp.isnan(p), jnp.zeros_like(p), p), updates)
    return updates, opt_state
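
# A hedged sketch of how an update_fn like the one above is typically exposed
# as an optax-style GradientTransformation. `ZeroNansState` and `zero_nans`
# below are illustrative assumptions (optax itself ships a ready-made
# optax.zero_nans transform with this behavior).
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
import optax


class ZeroNansState(NamedTuple):
    found_nan: Any  # pytree of per-leaf booleans


def zero_nans() -> optax.GradientTransformation:
    def init_fn(params):
        # No NaNs seen yet: one False flag per leaf.
        return ZeroNansState(
            jax.tree_map(lambda _: jnp.array(False), params))

    def update_fn(updates, opt_state, params=None):
        del params
        opt_state = ZeroNansState(
            jax.tree_map(lambda p: jnp.any(jnp.isnan(p)), updates))
        updates = jax.tree_map(
            lambda p: jnp.where(jnp.isnan(p), jnp.zeros_like(p), p), updates)
        return updates, opt_state

    return optax.GradientTransformation(init_fn, update_fn)


# Usage: chain it before the optimizer so NaN gradients are zeroed out first.
# tx = optax.chain(zero_nans(), optax.sgd(1e-2))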
def test_sigmoid(x, version):
    """Check for null gradient issues."""
    result = sigmoid(x, version=version)
    assert not np.isnan(result)

    dsigmoid = grad(sigmoid)
    dresult = dsigmoid(x, version=version)
    assert not np.isnan(dresult)
def test_regularized_unbalanced_bures(self):
    """Tests Regularized Unbalanced Bures."""
    x = jnp.concatenate((jnp.array([0.9]), self.x[0, :]))
    y = jnp.concatenate((jnp.array([1.1]), self.y[0, :]))

    rub = costs.UnbalancedBures(self.dim, 1, 0.8)
    # NB: `assertIsNot(..., True)` is an identity check that always passes on
    # jnp arrays; assert on the value instead.
    self.assertFalse(jnp.any(jnp.isnan(rub(x, y))))
    self.assertFalse(jnp.any(jnp.isnan(rub(y, x))))
    self.assertAllClose(rub(x, y), rub(y, x), rtol=1e-3, atol=1e-3)
def test_sample_momentum_resolution():
    N = 100
    mom = Momentum.from_ndarray(rjax.uniform(rng, (N, 3)))
    smom, cov = sample_momentum_resolution(rng, mom)
    assert cov.shape == (N, 3, 3)
    assert smom.as_array.shape == (N, 3)
    assert not np.isnan(cov).any()
    assert not np.isnan(smom.as_array).any()
def test_sample_position_resolution():
    N = 100
    pos = Position.from_ndarray(rjax.uniform(rng, (N, 3)))
    spos, cov = sample_position_resolution(rng, pos)
    assert cov.shape == (N, 3, 3)
    assert spos.as_array.shape == (N, 3)
    assert not np.isnan(cov).any()
    assert not np.isnan(spos.as_array).any()
def test_sample_helix_resolution():
    N = 100
    hel = Helix.from_ndarray(rjax.uniform(rng, (N, 5)))
    shel, cov = sample_helix_resolution(rng, hel)
    assert cov.shape == (N, 5, 5)
    assert shel.as_array.shape == (N, 5)
    assert not np.isnan(cov).any()
    assert not np.isnan(shel.as_array).any()
def test_sample_cluster_resolution():
    N = 100
    clu = Cluster.from_ndarray(rjax.uniform(rng, (N, 3)))
    sclu, cov = sample_cluster_resolution(rng, clu)
    assert cov.shape == (N, 3, 3)
    assert sclu.as_array.shape == (N, 3)
    assert not np.isnan(cov).any()
    assert not np.isnan(sclu.as_array).any()
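
# The four resolution tests above share one pattern: draw correlated Gaussian
# noise from a per-row covariance and add it to the values. A hedged sketch of
# that pattern (the actual sample_*_resolution helpers and their covariance
# models are external; `cov` here is a stand-in argument):
import jax
import jax.numpy as jnp


def smear(rng, vals, cov):
    """vals: (N, D); cov: (N, D, D) positive-definite covariances."""
    chol = jnp.linalg.cholesky(cov)            # (N, D, D) Cholesky factors
    eps = jax.random.normal(rng, vals.shape)   # (N, D) standard normals
    # x + L @ eps has covariance L @ L.T = cov around x.
    return vals + jnp.einsum('nij,nj->ni', chol, eps)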
def test_sin(self):
    """In [-1e10, 1e10] safe_sin and safe_cos are accurate."""
    for fn in ['sin', 'cos']:
        y_true, y = safe_trig_harness(fn, 10)
        self.assertLess(np.max(np.abs(y - y_true)), 1e-4)
        self.assertFalse(jnp.any(jnp.isnan(y)))
    # Beyond that range it's less accurate, but it should never be NaN.
    for fn in ['sin', 'cos']:
        y_true, y = safe_trig_harness(fn, 60)
        self.assertFalse(jnp.any(jnp.isnan(y)))
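
# A hedged sketch of what the safe_sin/safe_cos under test typically look
# like: range-reduce the argument modulo a large multiple of pi before calling
# the trig primitive, so that huge inputs stay finite instead of going NaN.
# This is an assumption consistent with the test, not the verified
# implementation.
import jax.numpy as jnp


def safe_trig_helper(x, fn, t=100 * jnp.pi):
    # Inputs already in (-t, t) pass through unchanged; larger magnitudes are
    # wrapped into [0, t) first.
    return fn(jnp.where(jnp.abs(x) < t, x, x % t))


def safe_sin(x):
    return safe_trig_helper(x, jnp.sin)


def safe_cos(x):
    return safe_trig_helper(x, jnp.cos)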
def adam(gradval, params, eta=0.01, beta1=0.9, beta2=0.9, eps=1e-7, c=0.01,
         R=500, per=100, disp=None, log=False):
    params = {k: np.array(v) for k, v in params.items()}
    gavg = {k: np.zeros_like(v) for k, v in params.items()}
    grms = {k: np.zeros_like(v) for k, v in params.items()}

    if log:
        hist = {k: [] for k in params}

    # iterate to max
    for j in range(R + 1):
        val, grad = gradval(params)

        # test for nans/infs
        vnan = np.isnan(val) or np.isinf(val)
        gnan = np.array(
            [np.isnan(g).any() or np.isinf(g).any() for g in grad.values()])
        if vnan or gnan.any():
            print('Encountered nan/inf!')
            if disp is not None:  # guard: disp may be None
                disp(j, val, params)
            summary_naninf(grad)
            return params

        # clip gradients
        gradc = {k: clip2(grad[k], c) for k in params}

        # early bias correction
        etat = eta * np.sqrt(1 - beta2**(j + 1)) / (1 - beta1**(j + 1))

        # Adam update (gradient ascent)
        for k in params:
            gavg[k] += (1 - beta1) * (gradc[k] - gavg[k])
            grms[k] += (1 - beta2) * (gradc[k]**2 - grms[k])
            params[k] += etat * gavg[k] / (np.sqrt(grms[k]) + eps)
            if log:
                # copy: params[k] is mutated in place, so appending the array
                # itself would make every history entry alias the final value.
                hist[k].append(params[k].copy())

        if disp is not None and j % per == 0:
            disp(j, val, params)

    if log:
        return {k: np.stack(v) for k, v in hist.items()}
    return params
def sample(self, key: jnp.ndarray) -> jnp.ndarray:
    """Sample from the distribution.

    Args:
        key: JAX random key.
    """
    sample = jax.random.beta(key, self.alpha, self.beta)
    # Propagate NaN parameters to the sample instead of returning garbage.
    is_nan = jnp.logical_or(jnp.isnan(self.alpha), jnp.isnan(self.beta))
    return jnp.where(is_nan, jnp.full(self.alpha.shape, jnp.nan), sample)
def test_scce_out_of_bounds():
    ypred = jnp.zeros([4, 10])
    ytrue0 = jnp.array([0, 0, -1, 0])  # label below range
    ytrue1 = jnp.array([0, 0, 10, 0])  # label above range

    # With bounds checking (the default), out-of-range labels yield NaN.
    scce = elegy.losses.SparseCategoricalCrossentropy()
    assert jnp.isnan(scce(ytrue0, ypred)).any()
    assert jnp.isnan(scce(ytrue1, ypred)).any()

    # Without bounds checking, no NaNs are produced.
    scce = elegy.losses.SparseCategoricalCrossentropy(check_bounds=False)
    assert not jnp.isnan(scce(ytrue0, ypred)).any()
    assert not jnp.isnan(scce(ytrue1, ypred)).any()
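
# A hedged sketch of the semantics this test pins down (a hypothetical
# stand-in, not elegy's implementation): with bounds checking on,
# out-of-range labels surface as NaN; with it off, indices are wrapped so no
# NaN is produced.
import jax
import jax.numpy as jnp


def sparse_cce(y_true, logits, check_bounds=True):
    logprobs = jax.nn.log_softmax(logits)
    n_classes = logits.shape[-1]
    # Wrap indices so the gather itself never reads out of bounds.
    ll = jnp.take_along_axis(
        logprobs, (y_true % n_classes)[:, None], axis=-1)[:, 0]
    loss = -ll
    if check_bounds:
        invalid = (y_true < 0) | (y_true >= n_classes)
        loss = jnp.where(invalid, jnp.nan, loss)
    return loss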
def sample(self, key: jnp.ndarray) -> jnp.ndarray:
    """Sample from the distribution.

    Args:
        key: JAX random key.
    """
    std_norm = jax.random.normal(key, shape=self.mu.shape, dtype=self.mu.dtype)
    # Propagate NaN parameters to the sample instead of returning garbage.
    is_nan = jnp.logical_or(jnp.isnan(self.mu), jnp.isnan(self.sigma))
    return jnp.where(is_nan, jnp.full(self.mu.shape, jnp.nan),
                     std_norm * self.sigma + self.mu)
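
# Both sample() methods above follow the same NaN-propagation contract: a NaN
# distribution parameter yields a NaN sample rather than an arbitrary value.
# A tiny illustrative check (hypothetical usage, mirroring the normal case):
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mu = jnp.array([0.0, 0.0])
sigma = jnp.array([1.0, jnp.nan])
is_nan = jnp.logical_or(jnp.isnan(mu), jnp.isnan(sigma))
draw = jnp.where(is_nan, jnp.nan,
                 jax.random.normal(key, mu.shape) * sigma + mu)
# draw[0] is a valid sample; draw[1] is NaN by construction.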
def test_behler_parrinello_network(self, N_types, dtype):
    key = random.PRNGKey(1)
    R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
    species = np.array([1, 1, N_types]) if N_types > 1 else None
    box_size = f32(1.5)
    displacement, _ = space.periodic(box_size)
    nn_init, nn_apply = energy.behler_parrinello(displacement, species)
    params = nn_init(key, R)
    nn_force_fn = grad(nn_apply, argnums=1)
    nn_force = jit(nn_force_fn)(params, R)
    nn_energy = jit(nn_apply)(params, R)
    self.assertAllClose(np.any(np.isnan(nn_energy)), False)
    self.assertAllClose(np.any(np.isnan(nn_force)), False)
    self.assertAllClose(nn_force.shape, [3, 3])
def visualize_normals(depth, acc, scaling=None):
    """Visualize fake normals of `depth` (optionally scaled to be isotropic)."""
    if scaling is None:
        # Match the variance of the (x, y) image coordinates to the variance
        # of the valid depths, so the normals are roughly isotropic.
        mask = ~jnp.isnan(depth)
        x, y = jnp.meshgrid(
            jnp.arange(depth.shape[1]), jnp.arange(depth.shape[0]),
            indexing='xy')
        xy_var = (jnp.var(x[mask]) + jnp.var(y[mask])) / 2
        z_var = jnp.var(depth[mask])
        scaling = jnp.sqrt(xy_var / z_var)

    scaled_depth = scaling * depth
    normals = depth_to_normals(scaled_depth)
    # NB: use the `nan=` keyword; the second positional argument of
    # nan_to_num is `copy`, not the fill value.
    return matte(
        jnp.isnan(normals) + jnp.nan_to_num((normals + 1) / 2, nan=0.0), acc)
def test_prior(self, key, num_samples, log_likelihood=None, **kwargs):
    keys = random.split(key, num_samples)
    for key in keys:
        U = random.uniform(key, shape=(self.U_ndims,))
        Y = self(U, **kwargs)
        for k, v in Y.items():
            if jnp.any(jnp.isnan(v)):
                raise ValueError('nan in prior transform', Y)
        if log_likelihood is not None:
            loglik = log_likelihood(**Y)
            if jnp.isnan(loglik):
                raise ValueError('Log likelihood is nan', Y)
            if loglik == 0.:
                print('Log likelihood is zero', loglik, Y)
            print(loglik)
def wrapper(self, x, *args):
    f, grad = self.fun(x, *args)
    if np.isnan(f) or np.any(np.isnan(grad)):
        print('nan detected')
        print(x)
        print(f)
        print(grad)
        if self.fundebug is not None:
            self.fundebug(x)
        assert 0
    self.f = f
    self.grad = grad
    return f, grad
def _contains_query(vals, query):
    if isinstance(query, tuple):
        # list() forces evaluation: map() is lazy in Python 3, so without it
        # the recursive checks would never actually run.
        return list(map(partial(_contains_query, vals), query))
    if np.isnan(query):
        if np.any(np.isnan(vals)):
            raise FoundValue('NaN')
    elif np.isinf(query):
        if np.any(np.isinf(vals)):
            raise FoundValue('Inf')
    elif np.isscalar(query):
        if np.any(vals == query):
            raise FoundValue(str(query))
    else:
        raise ValueError('Malformed Query: {}'.format(query))
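
# A minimal usage sketch for _contains_query: FoundValue is assumed to be a
# simple Exception subclass, and a query may be a scalar, NaN, Inf, or a
# nested tuple of those.
from functools import partial

import numpy as np


class FoundValue(Exception):
    pass


def find_values(vals, query):
    try:
        _contains_query(np.asarray(vals), query)
        print('no match')
    except FoundValue as e:
        print('found:', e)


# find_values([1.0, np.nan], (np.nan, np.inf))  -> found: NaN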
def test_gradient_sinkhorn_geometry(self, lse_mode):
    """Test gradient w.r.t. cost matrix."""
    for n, m in ((11, 13), (15, 9)):
        keys = jax.random.split(self.rng, 2)
        cost_matrix = jnp.abs(jax.random.normal(keys[0], (n, m)))
        delta = jax.random.normal(keys[1], (n, m))
        delta = delta / jnp.sqrt(jnp.vdot(delta, delta))
        eps = 1e-3  # perturbation magnitude

        def loss_fn(cm):
            a = jnp.ones(cm.shape[0]) / cm.shape[0]
            b = jnp.ones(cm.shape[1]) / cm.shape[1]
            geom = geometry.Geometry(cm, epsilon=0.5)
            out = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode)
            return out.reg_ot_cost, (geom, out.f, out.g)

        # First computation of the gradient: custom VJP through Sinkhorn.
        loss_and_grad = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))
        (loss_value, aux), grad_loss = loss_and_grad(cost_matrix)
        custom_grad = jnp.sum(delta * grad_loss)
        # NB: `assertIsNot(loss_value, jnp.nan)` always passes (identity
        # check), so test the value itself instead.
        self.assertFalse(jnp.isnan(loss_value))
        self.assertEqual(grad_loss.shape, cost_matrix.shape)
        self.assertFalse(jnp.any(jnp.isnan(grad_loss)))

        # Second computation: the analytic gradient w.r.t. the cost matrix
        # is the transport matrix itself.
        transport_matrix = aux[0].transport_from_potentials(aux[1], aux[2])
        grad_x = transport_matrix
        other_grad = jnp.sum(delta * grad_x)

        # Third computation: central finite differences along delta.
        loss_delta_plus, _ = loss_fn(cost_matrix + eps * delta)
        loss_delta_minus, _ = loss_fn(cost_matrix - eps * delta)
        finite_diff_grad = (loss_delta_plus - loss_delta_minus) / (2 * eps)

        self.assertAllClose(custom_grad, other_grad, rtol=1e-02, atol=1e-02)
        self.assertAllClose(custom_grad, finite_diff_grad,
                            rtol=1e-02, atol=1e-02)
        self.assertAllClose(other_grad, finite_diff_grad,
                            rtol=1e-02, atol=1e-02)
        self.assertFalse(jnp.any(jnp.isnan(custom_grad)))
def kernel(rng_key: jax.random.PRNGKey,
           state: RWMState) -> Tuple[RWMState, RWMInfo]:
    """Moves the chain by one step using the Random Walk Metropolis algorithm.

    Parameters
    ----------
    rng_key:
        The pseudo-random number generator key used to generate random
        numbers.
    state:
        The current state of the chain: position, log-probability and
        gradient of the log-probability.

    Returns
    -------
    The next state of the chain and additional information about the current
    step.
    """
    key_move, key_accept = jax.random.split(rng_key)

    position, log_prob = state
    move_proposal = proposal_generator(key_move)
    new_position = position + move_proposal
    new_log_prob = logpdf(new_position)
    new_state = RWMState(new_position, new_log_prob)

    delta = new_log_prob - log_prob
    # A NaN log-probability means the proposal must be rejected.
    delta = np.where(np.isnan(delta), -np.inf, delta)
    p_accept = np.clip(np.exp(delta), a_max=1)

    do_accept = jax.random.bernoulli(key_accept, p_accept)
    accept_state = (new_state, RWMInfo(new_state, p_accept, True))
    reject_state = (state, RWMInfo(new_state, p_accept, False))
    # NB: np.where cannot select between pytrees of states; use lax.cond to
    # pick the accepted or rejected (state, info) pair.
    return jax.lax.cond(
        do_accept, lambda _: accept_state, lambda _: reject_state, None)
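
# A hedged sketch of driving such a kernel (the kernel closes over
# `proposal_generator` and `logpdf`, so it is usually returned by a factory):
# split one key per step, then scan the transition over the keys.
import jax


def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, key):
        state, info = kernel(key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos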
def update(self, transition_batch, return_td_error=False):
    r"""Update the model parameters (weights) of the underlying function
    approximator.

    Parameters
    ----------
    transition_batch : TransitionBatch
        A batch of transitions.

    return_td_error : bool, optional
        Whether to return the TD-errors.

    Returns
    -------
    metrics : dict of scalar ndarrays
        The structure of the metrics dict is ``{name: score}``.

    td_error : ndarray, optional
        The non-aggregated TD-errors, :code:`shape == (batch_size,)`. This is
        only returned if we set :code:`return_td_error=True`.
    """
    grads, function_state, metrics, td_error = self.grads_and_metrics(
        transition_batch)
    if any(jnp.any(jnp.isnan(g)) for g in jax.tree_leaves(grads)):
        raise RuntimeError(f"found nan's in grads: {grads}")
    self.update_from_grads(grads, function_state)
    return (metrics, td_error) if return_td_error else metrics
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
    out = prim.bind(*in_vals, **params)
    if ErrorCategory.NAN not in enabled_errors:
        return out, error
    no_nans = jnp.logical_not(jnp.any(jnp.isnan(out)))
    msg = f"nan generated by primitive {prim.name} at {summary()}"
    return out, assert_func(error, no_nans, msg, None)
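
# nan_error_check above is the per-primitive hook; from the user side, JAX's
# checkify API surfaces it roughly as follows (real API, illustrative usage):
import jax.numpy as jnp
from jax.experimental import checkify


def f(x):
    return jnp.log(x)  # NaN for negative inputs


checked_f = checkify.checkify(f, errors=checkify.nan_checks)
err, out = checked_f(jnp.array(-1.0))
# err.throw() would raise here, naming the offending primitive and location.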
def train_step(self, summary: Summary, data: dict, progress: np.ndarray):
    kv = self.train_op(progress, data['image'].numpy(), data['label'].numpy())
    for k, v in kv.items():
        if jn.isnan(v):
            raise ValueError('NaN, try reducing learning rate', k)
        summary.scalar(k, float(v))
def _hmc_next(step_size, inverse_mass_matrix, vv_state,
              model_args, model_kwargs, rng_key):
    if potential_fn_gen:
        nonlocal vv_update
        pe_fn = potential_fn_gen(*model_args, **model_kwargs)
        _, vv_update = velocity_verlet(pe_fn, kinetic_fn)

    num_steps = _get_num_steps(step_size, trajectory_len)
    vv_state_new = fori_loop(
        0, num_steps,
        lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
        vv_state)
    energy_old = vv_state.potential_energy + kinetic_fn(
        inverse_mass_matrix, vv_state.r)
    energy_new = vv_state_new.potential_energy + kinetic_fn(
        inverse_mass_matrix, vv_state_new.r)
    delta_energy = energy_new - energy_old
    # Handle the NaN case: treat it as an infinite energy barrier.
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
    diverging = delta_energy > max_delta_energy
    transition = random.bernoulli(rng_key, accept_prob)
    vv_state, energy = cond(transition,
                            (vv_state_new, energy_new), identity,
                            (vv_state, energy_old), identity)
    return vv_state, energy, num_steps, accept_prob, diverging
def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix,
                    step_size, going_right, energy_current, max_delta_energy):
    step_size = np.where(going_right, step_size, -step_size)
    z_new, r_new, potential_energy_new, z_new_grad = vv_update(
        step_size,
        inverse_mass_matrix,
        (z, r, energy_current, z_grad),
    )

    energy_new = potential_energy_new + kinetic_fn(inverse_mass_matrix, r_new)
    delta_energy = energy_new - energy_current
    # Handles the NaN case.
    delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
    tree_weight = -delta_energy

    diverging = delta_energy > max_delta_energy
    accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
    return TreeInfo(z_new, r_new, z_new_grad, z_new, r_new, z_new_grad,
                    z_new, potential_energy_new, z_new_grad, energy_new,
                    depth=0, weight=tree_weight, r_sum=r_new, turning=False,
                    diverging=diverging, sum_accept_probs=accept_prob,
                    num_proposals=1)
def update_state_arm(state: StateArm, samples: List, grads: List,
                     error_fn: Callable, get_fb_grads=None) -> StateArm:
    last_sample = samples[-1] if samples is not None else state.last_sample
    if get_fb_grads:
        samples, grads = get_fb_grads(samples)
    samples = flatten_param_list(samples)
    grads = flatten_param_list(grads)

    # concatenate with the history carried in the state
    if state.samples is not None:
        all_samples = jnp.concatenate([state.samples, samples])
    else:
        all_samples = samples
    if state.grads is not None:
        all_grads = jnp.concatenate([state.grads, grads])
    else:
        all_grads = grads

    _metric = error_fn(samples, grads)
    metric = _metric if not jnp.isnan(_metric) else jnp.inf

    state = StateArm(hyperparameters=state.hyperparameters,
                     run_timed_sampler=state.run_timed_sampler,
                     last_sample=last_sample,
                     samples=all_samples,
                     grads=all_grads,
                     metric=metric)
    return state