def test_top_k_dist_warper(self):
    input_ids = None
    vocab_size = 10
    batch_size = 2

    # create ramp distribution
    ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy()
    ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size

    top_k_warp = FlaxTopKLogitsWarper(3)

    scores = top_k_warp(input_ids, ramp_logits, cur_len=None)

    # check that correct tokens are filtered
    self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
    self.assertListEqual(jnp.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True])

    # check special case
    length = 5

    top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)

    ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
    scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len=None)

    # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
    self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2])
def _normalize(x, y):
    z = jnp.where(jnp.isinf(x), x, x + y)
    zz = jnp.where(
        lax.abs(x) > lax.abs(y),
        x - z + y,
        y - z + x,
    )
    return z, jnp.where(jnp.isinf(z), 0, zz)
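# _normalize above appears to be the classic two-sum error-free transformation
# used in double-double arithmetic: z is the rounded sum, zz the rounding error.
# A minimal usage sketch, assuming the jnp/lax imports the snippet relies on:
import jax.numpy as jnp
from jax import lax

head = jnp.float32(1.0)
tail = jnp.float32(1e-8)  # too small to register in a float32 add
z, zz = _normalize(head, tail)
print(z, zz)  # z == 1.0 (rounded sum); zz recovers the 1e-8 lost to rounding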
def adam(gradval, params, eta=0.01, beta1=0.9, beta2=0.9, eps=1e-7, c=0.01,
         R=500, per=100, disp=None, log=False):
    params = {k: np.array(v) for k, v in params.items()}
    gavg = {k: np.zeros_like(v) for k, v in params.items()}
    grms = {k: np.zeros_like(v) for k, v in params.items()}
    n = len(params)

    if log:
        hist = {k: [] for k in params}

    # iterate to max
    for j in range(R + 1):
        val, grad = gradval(params)

        # test for nans
        vnan = np.isnan(val) or np.isinf(val)
        gnan = np.array(
            [np.isnan(g).any() or np.isinf(g).any() for g in grad.values()])
        if vnan or gnan.any():
            print('Encountered nan/inf!')
            disp(j, val, params)
            summary_naninf(grad)
            return params

        # clip gradients
        gradc = {k: clip2(grad[k], c) for k in params}

        # early bias correction
        etat = eta * (np.sqrt(1 - beta2**(j + 1))) / (1 - beta1**(j + 1))

        for k in params:
            gavg[k] += (1 - beta1) * (gradc[k] - gavg[k])
            grms[k] += (1 - beta2) * (gradc[k]**2 - grms[k])
            params[k] += etat * gavg[k] / (np.sqrt(grms[k]) + eps)
            if log:
                # append a copy: += mutates params[k] in place
                hist[k].append(params[k].copy())

        if disp is not None and j % per == 0:
            disp(j, val, params)

    if log:
        hist = {k: np.stack(v) for k, v in hist.items()}
        return hist

    return params
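# Minimal usage sketch for the adam routine above, which ascends (it adds the
# step), so we hand it the gradient of -x^2 to drive x toward 0. clip2 is an
# external helper; the elementwise clip below is an assumed stand-in, not the
# original definition.
import numpy as np

def clip2(g, c):
    # assumed helper: clip each gradient entry to [-c, c]
    return np.clip(g, -c, c)

def gradval(params):
    x = params['x']
    # value and gradient of -x^2
    return -np.sum(x**2), {'x': -2 * x}

result = adam(gradval, {'x': np.array([3.0, -2.0])}, c=1.0, R=200)
print(result['x'])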
def _contains_query(vals, query):
    if isinstance(query, tuple):
        return map(partial(_contains_query, vals), query)

    if np.isnan(query):
        if np.any(np.isnan(vals)):
            raise FoundValue('NaN')
    elif np.isinf(query):
        if np.any(np.isinf(vals)):
            raise FoundValue('Found Inf')
    elif np.isscalar(query):
        if np.any(vals == query):
            raise FoundValue(str(query))
    else:
        raise ValueError('Malformed Query: {}'.format(query))
def infer(model, state, unroll=1, eta=0.001, jit=True):
    init, update, params = adam(eta)
    opt = init(state.xs)
    window = collections.deque(maxlen=windowlen(5000, unroll))

    def step(rng, opt, t):
        keys = random.split(rng, unroll)
        for i in range(0, unroll):
            xs = params(opt)
            L, dxs = jax.value_and_grad(energy, argnums=1)(model, xs)
            opt = update(t + i, dxs, opt)
        return L, opt

    if jit:
        step = jax.jit(step)

    try:
        while True:
            state.rng, key = random.split(state.rng)
            L, opt = step(key, opt, state.t)
            if np.isnan(L) or np.isinf(L):
                raise Exception("Invalid loss")
            state.xs = params(opt)
            state.t += unroll
            window.append(-L.item())
            print("\r[ %e, %.2e iterations " % (statistics.mean(window), state.t), end="")
    except KeyboardInterrupt:
        pass

    return state
def test_min_length_dist_processor(self):
    vocab_size = 20
    batch_size = 4
    eos_token_id = 0

    min_dist_processor = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)

    # check that min length is applied at length 5
    input_ids = ids_tensor((batch_size, 20), vocab_size=20)
    cur_len = 5
    scores = self._get_uniform_logits(batch_size, vocab_size)
    scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=cur_len)
    self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")])

    # check that min length is not applied anymore at length 15
    scores = self._get_uniform_logits(batch_size, vocab_size)
    cur_len = 15
    scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=cur_len)
    self.assertFalse(jnp.isinf(scores_before_min_length).any())
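# The logits-processor tests above and below rely on two harness helpers.
# Minimal stand-ins with the obvious semantics (assumptions, not the original
# test utilities; _get_uniform_logits is a method on the test class there):
import numpy as np
import jax.numpy as jnp

def ids_tensor(shape, vocab_size):
    # random token ids in [0, vocab_size)
    return jnp.asarray(np.random.randint(0, vocab_size, size=shape), dtype=jnp.int32)

def _get_uniform_logits(batch_size, length):
    # identical score for every token, so only the processor's edits show up
    return jnp.ones((batch_size, length), dtype=jnp.float32) / length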
def tree_count_infs_nans(tree, psum_axis_name=None):
    leaves = jax.tree_leaves(tree)
    num_infs = sum(jnp.sum(jnp.isinf(x)) for x in leaves)
    num_nans = sum(jnp.sum(jnp.isnan(x)) for x in leaves)
    if psum_axis_name:
        num_infs, num_nans = jax.lax.psum((num_infs, num_nans),
                                          axis_name=psum_axis_name)
    return num_infs, num_nans
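# Usage sketch: count non-finite entries across an arbitrary pytree of arrays.
import jax
import jax.numpy as jnp

tree = {'w': jnp.array([1.0, jnp.inf]), 'b': jnp.array([jnp.nan, 0.0])}
num_infs, num_nans = tree_count_infs_nans(tree)
print(int(num_infs), int(num_nans))  # 1 1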
def create_model(yT, yC, num_components):
    # Constants
    nC = yC.shape[0]
    nT = yT.shape[0]
    zC = jnp.isinf(yC).sum().item()
    zT = jnp.isinf(yT).sum().item()
    yT_finite = yT[~jnp.isinf(yT)]
    yC_finite = yC[~jnp.isinf(yC)]
    K = num_components

    p = numpyro.sample('p', dist.Beta(.5, .5))
    gammaC = numpyro.sample('gammaC', dist.Beta(1, 1))
    gammaT = numpyro.sample('gammaT', dist.Beta(1, 1))
    etaC = numpyro.sample('etaC', dist.Dirichlet(jnp.ones(K) / K))
    etaT = numpyro.sample('etaT', dist.Dirichlet(jnp.ones(K) / K))

    with numpyro.plate('mixture_components', K):
        nu = numpyro.sample('nu', dist.LogNormal(3.5, 0.5))
        mu = numpyro.sample('mu', dist.Normal(0, 3))
        sigma = numpyro.sample('sigma', dist.LogNormal(0, .5))
        phi = numpyro.sample('phi', dist.Normal(0, 3))

    gammaT_star = simulate_data.compute_gamma_T_star(gammaC, gammaT, p)
    etaT_star = simulate_data.compute_eta_T_star(etaC, etaT, p, gammaC,
                                                 gammaT, gammaT_star)

    with numpyro.plate('y_C', nC - zC):
        numpyro.sample('finite_obs_C',
                       Mix(nu[None, :], mu[None, :], sigma[None, :],
                           phi[None, :], etaC[None, :]),
                       obs=yC_finite[:, None])

    with numpyro.plate('y_T', nT - zT):
        numpyro.sample('finite_obs_T',
                       Mix(nu[None, :], mu[None, :], sigma[None, :],
                           phi[None, :], etaT_star[None, :]),
                       obs=yT_finite[:, None])

    numpyro.sample('N_C', dist.Binomial(nC, gammaC), obs=zC)
    numpyro.sample('N_T', dist.Binomial(nT, gammaT_star), obs=zT)
def run_hmc_steps(theta, eps, Lmax, key, log_posterior,
                  log_posterior_grad_theta, diagonal_mass_matrix):
    # Diagonal mass matrix: diagonal entries of M (a vector)
    inverse_diag_mass = 1. / diagonal_mass_matrix

    key, subkey = random.split(key)

    # Location-scale transform to get the right variance
    # TODO: Check!
    phi = random.normal(
        subkey, shape=(theta.shape[0], )) * np.sqrt(diagonal_mass_matrix)

    start_theta = theta
    start_phi = phi

    cur_grad = log_posterior_grad_theta(theta)

    key, subkey = random.split(key)
    L = np_classic.random.randint(1, Lmax)

    for cur_l in range(L):
        phi = phi + 0.5 * eps * cur_grad
        theta = theta + eps * inverse_diag_mass * phi
        cur_grad = log_posterior_grad_theta(theta)
        phi = phi + 0.5 * eps * cur_grad

    # Compute (log) acceptance probability
    proposed_log_post = log_posterior(theta)
    previous_log_post = log_posterior(start_theta)

    proposed_log_phi = np.sum(
        norm.logpdf(phi, scale=np.sqrt(diagonal_mass_matrix)))
    previous_log_phi = np.sum(
        norm.logpdf(start_phi, scale=np.sqrt(diagonal_mass_matrix)))

    print(f'Proposed log posterior is: {proposed_log_post}. '
          f'Previous was {previous_log_post}.')

    if (np.isinf(proposed_log_post) or np.isnan(proposed_log_post)
            or np.isneginf(proposed_log_post)):
        # Reject
        was_accepted = False
        new_theta = start_theta
        # FIXME: What number to put here?
        log_r = -10
        return was_accepted, log_r, new_theta

    log_r = (proposed_log_post + proposed_log_phi - previous_log_post -
             previous_log_phi)

    was_accepted, new_theta = acceptance_step(log_r, theta, start_theta, key)

    return was_accepted, log_r, new_theta
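# run_hmc_steps delegates the Metropolis accept/reject decision to an external
# acceptance_step. A minimal sketch of what such a helper could look like (an
# assumption; the original is not shown here). The np alias follows the
# snippet's convention of np = jax.numpy, np_classic = numpy.
import jax.numpy as np
from jax import random

def acceptance_step(log_r, proposed_theta, start_theta, key):
    # accept the proposal with probability min(1, exp(log_r))
    u = random.uniform(key)
    was_accepted = np.log(u) < log_r
    new_theta = np.where(was_accepted, proposed_theta, start_theta)
    return was_accepted, new_theta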
def __call__(self, controller_state, obs):
    state, t = obs.predicted_pressure, obs.time
    errs, waveform = controller_state.errs, controller_state.waveform
    fwd_targets = controller_state.fwd_targets
    target = waveform.at(t)

    fwd_t = t + self.fwd_history_len * DEFAULT_DT
    if self.fwd_history_len > 0:
        fwd_target = jax.lax.cond(fwd_t >= self.horizon * DEFAULT_DT,
                                  lambda x: fwd_targets[-1],
                                  lambda x: waveform.at(fwd_t), None)

    if self.normalize:
        target_normalized = self.p_normalizer(target).squeeze()
        state_normalized = self.p_normalizer(state).squeeze()
        next_errs = jnp.roll(errs, shift=-1)
        next_errs = next_errs.at[-1].set(target_normalized - state_normalized)

        if self.fwd_history_len > 0:
            fwd_target_normalized = self.p_normalizer(fwd_target).squeeze()
            next_fwd_targets = jnp.roll(fwd_targets, shift=-1)
            next_fwd_targets = next_fwd_targets.at[-1].set(fwd_target_normalized)
        else:
            next_fwd_targets = jnp.array([])
    else:
        next_errs = jnp.roll(errs, shift=-1)
        next_errs = next_errs.at[-1].set(target - state)

        if self.fwd_history_len > 0:
            next_fwd_targets = jnp.roll(fwd_targets, shift=-1)
            next_fwd_targets = next_fwd_targets.at[-1].set(fwd_target)
        else:
            next_fwd_targets = jnp.array([])

    controller_state = controller_state.replace(errs=next_errs,
                                                fwd_targets=next_fwd_targets)

    decay = self.decay(waveform, t)

    def true_func(null_arg):
        trajectory = jnp.hstack([next_errs, next_fwd_targets])
        u_in = self.model_apply({"params": self.params}, trajectory)
        return u_in.squeeze().astype(jnp.float64)

    # changed decay compare from None to float(inf) due to cond requirements
    u_in = jax.lax.cond(jnp.isinf(decay), true_func,
                        lambda x: jnp.array(decay), None)
    u_in = jax.lax.clamp(0.0, u_in.astype(jnp.float64), self.clip).squeeze()

    # update controller_state
    new_dt = jnp.max(
        jnp.array([DEFAULT_DT, t - proper_time(controller_state.time)]))
    new_time = t
    new_steps = controller_state.steps + 1
    controller_state = controller_state.replace(time=new_time,
                                                steps=new_steps,
                                                dt=new_dt)
    return controller_state, u_in
def log_normal_with_outliers(x, mean, cov, sigma):
    """
    Computes log-Normal density with outliers removed.

    Args:
        x: RV value
        mean: mean of Gaussian
        cov: covariance of underlying, minus the obs. covariance
        sigma: stddev's of obs. error, inf encodes an outlier.

    Returns:
        a normal density for all points not of inf stddev obs. error.
    """
    C = cov / (sigma[:, None] * sigma[None, :]) + jnp.eye(cov.shape[0])
    L = jnp.linalg.cholesky(C)
    Ls = sigma[:, None] * L
    log_det = jnp.sum(jnp.where(jnp.isinf(sigma), 0., jnp.log(jnp.diag(Ls))))
    dx = (x - mean)
    dx = solve_triangular(L, dx / sigma, lower=True)
    maha = dx @ dx
    log_likelihood = -0.5 * jnp.sum(~jnp.isinf(sigma)) * jnp.log(2. * jnp.pi) \
                     - log_det \
                     - 0.5 * maha
    return log_likelihood
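# Usage sketch: an infinite obs. stddev effectively drops that point from the
# density (it contributes nothing to log_det or the Mahalanobis term).
# solve_triangular is assumed imported as the function above requires.
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

x = jnp.array([0.1, -0.2, 5.0])
mean = jnp.zeros(3)
cov = 0.5 * jnp.eye(3)
sigma = jnp.array([0.1, 0.1, jnp.inf])  # third point flagged as an outlier
print(log_normal_with_outliers(x, mean, cov, sigma))  # finite despite the 5.0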
def __call__(self, controller_state, obs):
    state, t = obs.predicted_pressure, obs.time
    errs, waveform = controller_state.errs, controller_state.waveform
    target = waveform.at(t)

    if self.normalize:
        target_normalized = self.p_normalizer(target).squeeze()
        state_normalized = self.p_normalizer(state).squeeze()
        next_errs = jnp.roll(errs, shift=-1)
        next_errs = next_errs.at[-1].set(target_normalized - state_normalized)
    else:
        next_errs = jnp.roll(errs, shift=-1)
        next_errs = next_errs.at[-1].set(target - state)

    controller_state = controller_state.replace(errs=next_errs)

    decay = self.decay(waveform, t)

    def true_func(null_arg):
        trajectory = jnp.expand_dims(next_errs[-self.history_len:], axis=(0, 1))
        input_val = jnp.reshape((trajectory @ self.featurizer),
                                (1, self.history_len, 1))
        u_in = self.model_apply({"params": self.params}, input_val)
        return u_in.squeeze().astype(jnp.float32)

    # changed decay compare from None to float(inf) due to cond requirements
    u_in = jax.lax.cond(jnp.isinf(decay), true_func,
                        lambda x: jnp.array(decay), None)

    # Implementing "leaky" clamp to solve the zero gradient problem
    if self.use_leaky_clamp:
        u_in = jax.lax.cond(u_in < 0.0, lambda x: x * 0.01, lambda x: x, u_in)
        u_in = jax.lax.cond(u_in > self.clip,
                            lambda x: self.clip + x * 0.01, lambda x: x, u_in)
    else:
        u_in = jax.lax.clamp(0.0, u_in.astype(jnp.float32), self.clip).squeeze()

    # update controller_state
    new_dt = jnp.max(
        jnp.array([DEFAULT_DT, t - proper_time(controller_state.time)]))
    new_time = t
    new_steps = controller_state.steps + 1
    controller_state = controller_state.replace(time=new_time,
                                                steps=new_steps,
                                                dt=new_dt)
    return controller_state, u_in
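# Why the "leaky" clamp above exists: a hard clamp has zero gradient outside
# [0, clip], which stalls gradient-based controller tuning; the leaky variant
# keeps a small slope everywhere. A self-contained comparison:
import jax
import jax.numpy as jnp

clip = 1.0

def hard(u):
    return jax.lax.clamp(0.0, u, clip)

def leaky(u):
    u = jnp.where(u < 0.0, u * 0.01, u)
    return jnp.where(u > clip, clip + u * 0.01, u)

print(jax.grad(hard)(2.0), jax.grad(leaky)(2.0))  # 0.0 vs 0.01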
def annotated_with_grad(flat_theta, summary):
    flat_theta = jnp.array(flat_theta)
    obj, grad = with_grad(flat_theta)
    print(obj, jnp.linalg.norm(grad))
    if jnp.isnan(obj) or jnp.isinf(obj) or jnp.any(jnp.isnan(grad)):
        import ipdb
        problem = reconstruct(flat_theta, summary, jnp.reshape)
        ipdb.set_trace()
    return np.array(obj).astype(np.float64), np.array(grad).astype(np.float64)
def loss(params, obs):
    diags = model(params, obs)
    # obs, _, _, info = env.step(diags)
    # norm_res = jax.vmap(lambda x: jnp.linalg.norm(x[1, :], jnp.inf))(obs)
    # return jnp.mean(norm_res)
    _, _, _, info = env.step(diags)
    # assert jnp.allclose(info['residual'] * info['niter'],
    #                     jax.vmap(lambda x, y: x * y)(
    #                         info['residual'], info['niter']))
    norm_res = jnp.mean(info['residual'] * info['niter'])
    # norm_res = jnp.mean(jnp.array(list(map(
    #     lambda i: info[0]['residual' + str(i)],
    #     range(args.batch_size)
    # ))))
    if jnp.isnan(norm_res):
        raise ValueError('encountered NaNs')
    if jnp.isinf(norm_res):
        raise ValueError('encountered infs')
    return norm_res
def test_forced_bos_token_logits_processor(self):
    vocab_size = 20
    batch_size = 4
    bos_token_id = 0

    logits_processor = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)

    # check that all scores are -inf except the bos_token_id score
    input_ids = ids_tensor((batch_size, 1), vocab_size=20)
    cur_len = 1
    scores = self._get_uniform_logits(batch_size, vocab_size)
    scores = logits_processor(input_ids, scores, cur_len=cur_len)
    self.assertTrue(jnp.isneginf(scores[:, bos_token_id + 1 :]).all())
    self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0])  # score for bos_token_id should be zero

    # check that bos_token_id is not forced if current length is greater than 1
    cur_len = 3
    scores = self._get_uniform_logits(batch_size, vocab_size)
    scores = logits_processor(input_ids, scores, cur_len=cur_len)
    self.assertFalse(jnp.isinf(scores).any())
def test_forced_eos_token_logits_processor(self):
    vocab_size = 20
    batch_size = 4
    eos_token_id = 0
    max_length = 5

    logits_processor = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)

    # check that all scores are -inf except the eos_token_id when max_length is reached
    input_ids = ids_tensor((batch_size, 4), vocab_size=20)
    cur_len = 4
    scores = self._get_uniform_logits(batch_size, vocab_size)
    scores = logits_processor(input_ids, scores, cur_len=cur_len)
    self.assertTrue(jnp.isneginf(scores[:, eos_token_id + 1 :]).all())
    self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0])  # score for eos_token_id should be zero

    # check that eos_token_id is not forced if max_length is not reached
    cur_len = 3
    scores = self._get_uniform_logits(batch_size, vocab_size)
    scores = logits_processor(input_ids, scores, cur_len=cur_len)
    self.assertFalse(jnp.isinf(scores).any())
def metric(points, centers, K):
    # N
    cluster_id = masked_cluster_id(points, centers, K)
    # N,N
    dist = jnp.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    # N,N
    in_group = cluster_id[:, None] == cluster_id[None, :]
    in_group_dist, w = jnp.average(dist, weights=in_group, axis=-1, returned=True)
    in_group_dist *= w / (w - 1.)
    # max_K, N, N
    out_group = (~in_group) & (jnp.arange(max_K)[:, None, None] == cluster_id[None, None, :])
    # max_K, N
    out_group_dist = jnp.sum(dist * out_group, axis=-1) / jnp.sum(out_group, axis=-1)
    out_group_dist = jnp.where(jnp.isnan(out_group_dist), jnp.inf, out_group_dist)
    # N
    out_group_dist = jnp.min(out_group_dist, axis=0)
    out_group_dist = jnp.where(jnp.isinf(out_group_dist), jnp.max(in_group_dist), out_group_dist)
    silhouette = (out_group_dist - in_group_dist) / jnp.maximum(in_group_dist, out_group_dist)
    # condition for pos def cov
    silhouette = jnp.where(w < points.shape[1], -jnp.inf, silhouette)
    return jnp.mean(silhouette), cluster_id
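# The score above is the silhouette coefficient, s = (b - a) / max(a, b), with
# a the mean intra-cluster distance and b the mean distance to the nearest
# other cluster (masked_cluster_id and max_K are external to this snippet).
# A tiny standalone check of the formula itself:
import jax.numpy as jnp

a = jnp.array(0.2)  # mean distance within own cluster
b = jnp.array(0.9)  # mean distance to nearest other cluster
print((b - a) / jnp.maximum(a, b))  # ~0.78, i.e. well separated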
def log_prob(self, value: Array) -> Array:
    """Calculates the log probability of an event.

    This implementation differs slightly from the one in TFP, as it returns
    `-jnp.inf` on non-integer values instead of returning the log prob of the
    floor of the input. In addition, this implementation also returns
    `-jnp.inf` on inputs that are outside the support of the distribution (as
    opposed to `nan`, like TFP does). On other integer values, both
    implementations are identical.

    Args:
        value: An event.

    Returns:
        The log probability log P(value).
    """
    is_integer = jnp.where(value > jnp.floor(value), False, True)
    log_cdf = self.log_cdf(value)
    log_cdf_m1 = self.log_cdf(value - 1.)
    log_probs = math.log_expbig_minus_expsmall(log_cdf, log_cdf_m1)
    return jnp.where(jnp.isinf(log_cdf), -jnp.inf,
                     jnp.where(is_integer, log_probs, -jnp.inf))
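# The helper name suggests log_expbig_minus_expsmall(big, small) evaluates
# log(exp(big) - exp(small)) without underflow. The generic identity such a
# routine can be built on, checked numerically (plain jnp here, not the
# library's code):
import jax.numpy as jnp

big, small = jnp.float32(-3.0), jnp.float32(-4.0)
stable = big + jnp.log1p(-jnp.exp(small - big))
naive = jnp.log(jnp.exp(big) - jnp.exp(small))
print(stable, naive)  # both ~-3.4587; only the stable form survives big << 0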
def isinf(self: TensorType) -> TensorType:
    return type(self)(np.isinf(self.raw))
def to_array(self, dtype=None):
    head, tail = self._tup
    if dtype is not None:
        head = head.astype(dtype)
        tail = tail.astype(dtype)
    return head + jnp.where(jnp.isinf(head), 0, tail)
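# Usage sketch of the double-double collapse above: the tail is only added
# where the head is finite, so an overflowed head yields inf rather than the
# nan that inf + nan would produce. Shown on the raw (head, tail) expression,
# assuming _tup holds such a pair:
import jax.numpy as jnp

head = jnp.array([1.0, jnp.inf])
tail = jnp.array([1e-20, jnp.nan])  # tail is meaningless once head overflows
print(head + jnp.where(jnp.isinf(head), 0, tail))  # [1.0, inf]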
imag = utils.copy_docstring(tf.math.imag,
                            lambda input, name=None: np.imag(input))

# in_top_k = utils.copy_docstring(
#     tf.math.in_top_k,
#     lambda targets, predictions, k, name=None: np.in_top_k)

# TODO(b/256095991): Add unit-test.
invert_permutation = utils.copy_docstring(
    tf.math.invert_permutation,
    lambda x, name=None: np.argsort(x))

is_finite = utils.copy_docstring(tf.math.is_finite,
                                 lambda x, name=None: np.isfinite(x))

is_inf = utils.copy_docstring(tf.math.is_inf,
                              lambda x, name=None: np.isinf(x))

is_nan = utils.copy_docstring(tf.math.is_nan,
                              lambda x, name=None: np.isnan(x))

is_non_decreasing = utils.copy_docstring(
    tf.math.is_non_decreasing,
    lambda x, name=None: np.all(x[1:] >= x[:-1]))

is_strictly_increasing = utils.copy_docstring(
    tf.math.is_strictly_increasing,
    lambda x, name=None: np.all(x[1:] > x[:-1]))

l2_normalize = utils.copy_docstring(
    tf.math.l2_normalize,
    lambda x, axis=None, epsilon=1e-12, name=None: (  # pylint: disable=g-long-lambda
        x / np.linalg.norm(x, ord=2, axis=axis, keepdims=True)))
def solve_and_smooth(gain_outliers, phase_obs, times, freqs):
    logger.info("Performing solve for tec and const from phases.")
    Nd, Na, Nf, Nt = phase_obs.shape

    logger.info("Number of nan: {}".format(jnp.sum(jnp.isnan(phase_obs))))
    logger.info("Number of inf: {}".format(jnp.sum(jnp.isinf(phase_obs))))

    # blocksize chosen to maximise Fisher information, which is 2 for
    # tec+const, and 3 for tec+const+clock
    blocksize = 2

    remainder = Nt % blocksize
    if remainder != 0:
        if blocksize > Nt:
            raise ValueError(
                f"Block size {blocksize} too big for number of timesteps {Nt}.")
        (gain_outliers, phase_obs) = tree_map(
            lambda x: jnp.concatenate(
                [x, jnp.repeat(x[..., -1:], remainder, axis=-1)], axis=-1),
            (gain_outliers, phase_obs))
        Nt = Nt + remainder
        times = jnp.concatenate([
            times,
            times[-1] + jnp.arange(1, 1 + remainder) * jnp.mean(jnp.diff(times))
        ])

    size_dict = dict(d=Nd, a=Na, f=Nf, b=blocksize)
    # [Nd*Na*(Nt//blocksize), blocksize, Nf]
    gain_outliers = axes_move(gain_outliers, ['d', 'a', 'f', 'tb'],
                              ['dat', 'b', 'f'], size_dict=size_dict)
    phase_obs = axes_move(phase_obs, ['d', 'a', 'f', 'tb'], ['dat', 'b', 'f'],
                          size_dict=size_dict)

    T = Nd * Na * (Nt // blocksize)
    keys = random.split(random.PRNGKey(int(1000 * default_timer())), T)

    # [Nd*Na*(Nt//blocksize), blocksize] each
    tec_mean, tec_std, const_mean, const_std, uncert_mean = chunked_pmap(
        lambda *args: unconstrained_solve(freqs, *args), keys, phase_obs,
        gain_outliers)

    const_weights = 1. / const_std**2

    def smooth(y, weights):
        y = axes_move(y, ['dat', 'b'], ['da', 'tb'], size_dict=size_dict)
        weights = axes_move(weights, ['dat', 'b'], ['da', 'tb'],
                            size_dict=size_dict)
        y = chunked_pmap(
            lambda y, weights: poly_smooth(times, y, deg=5, weights=weights),
            y, weights)
        y = axes_move(y, ['da', 'tb'], ['dat', 'b'], size_dict=size_dict)
        return y

    logger.info("Smoothing and outlier rejection of const (a weak prior).")
    # Nd*Na*(Nt//blocksize), blocksize
    const_real_mean = smooth(jnp.cos(const_mean), const_weights)
    const_imag_mean = smooth(jnp.sin(const_mean), const_weights)
    const_mean_smoothed = jnp.arctan2(const_imag_mean, const_real_mean)

    # empirically determined uncertainty point where sigma(tec - tec_true) > 6 mTECU
    which_reprocess = jnp.any(uncert_mean > 0., axis=1)  # Nd*Na*(Nt//blocksize)
    replace_map = jnp.where(which_reprocess)

    logger.info("Performing refined tec-only solve, with fixed const.")
    keys = random.split(random.PRNGKey(int(1000 * default_timer())),
                        jnp.sum(which_reprocess))
    # [Nd*Na*(Nt//blocksize), blocksize]
    (tec_mean_constrained, tec_std_constrained, const_mean_constrained,
     const_std_constrained) = chunked_pmap(
         lambda *args: constrained_solve(freqs, *args), keys,
         phase_obs[which_reprocess], gain_outliers[which_reprocess],
         const_mean_smoothed[which_reprocess], const_std[which_reprocess])

    tec_mean = tec_mean.at[replace_map].set(tec_mean_constrained)
    tec_std = tec_std.at[replace_map].set(tec_std_constrained)
    const_std = const_std.at[replace_map].set(const_std_constrained)
    const_mean = const_mean.at[replace_map].set(const_mean_constrained)

    (tec_mean, tec_std, const_mean, const_std) = tree_map(
        lambda x: x.reshape((Nd, Na, Nt)),
        (tec_mean, tec_std, const_mean, const_std))  # Nd, Na, Nt

    logger.info("Performing outlier detection on tec values.")
    tec_est, tec_outliers = detect_tec_outliers(times, tec_mean, tec_std)
    tec_std = jnp.where(tec_outliers, jnp.inf, tec_std)

    # remove remainder at the end
    if remainder != 0:
        (tec_mean, tec_std, const_mean, const_std) = tree_map(
            lambda x: x[..., :Nt - remainder],
            (tec_mean, tec_std, const_mean, const_std))

    # compute phase mean with outlier-suppressed tec.
    phase_mean = tec_mean[..., None, :] * (TEC_CONV / freqs[:, None]) \
                 + const_mean[..., None, :]
    phase_uncert = jnp.sqrt((tec_std[..., None, :] *
                             (TEC_CONV / freqs[:, None]))**2 +
                            (const_std[..., None, :])**2)

    return (phase_mean, phase_uncert, tec_mean, tec_std, tec_outliers,
            const_mean, const_std)
def summary_naninf(d):
    for k, v in d.items():
        if (na := np.isnan(v)).any():
            print(f'{k} nan: {na.sum()}/{na.size}')
        if (nf := np.isinf(v)).any():
            print(f'{k} inf: {nf.sum()}/{nf.size}')
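# Usage sketch: per-key report of non-finite entries in a gradient dict.
import numpy as np

grads = {'w': np.array([1.0, np.nan, np.inf]), 'b': np.array([0.0])}
summary_naninf(grads)  # prints "w nan: 1/3" and "w inf: 1/3"; b stays silent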
def single_acceptance(args):
    """Draws a proposal, simulates and compresses, checks distance

    A new proposal is drawn from a truncated multivariate normal distribution
    whose mean is centred on the parameter to move and the covariance is set
    by the population. From this proposed parameter value a simulation is made
    and compressed and the distance from the target is calculated. If this
    distance is less than the current position then the proposal is accepted.

    Parameters
    ----------
    args : tuple
        see loop variable in `single_iteration`

    Returns
    -------
    bool:
        True if proposal not accepted and number of attempts to get an
        accepted proposal not yet reached

    Todo
    ----
    Parallel sampling is currently commented out
    """
    (rng, loc, scale, summ, dis, draws, accepted, acceptance_counter) = args
    rng, key = jax.random.split(rng)
    proposed, summaries = self.get_samples(
        key, None,
        dist=tmvn(loc, scale, self.prior.low, self.prior.high,
                  max_counter=max_samples))
    distances = np.squeeze(
        self.distance_measure(np.expand_dims(summaries, 0), target, F))
    # if n_parallel_simulations is not None:
    #     min_distance_index = np.argmin(distances)
    #     min_distance = distances[min_distance_index]
    #     closer = np.less(min_distance, ϵ)
    #     loc = jax.lax.cond(
    #         closer,
    #         lambda _ : proposed[min_distance_index],
    #         lambda _ : loc,
    #         None)
    #     summ = jax.lax.cond(
    #         closer,
    #         lambda _ : summaries[min_distance_index],
    #         lambda _ : summ,
    #         None)
    #     dis = jax.lax.cond(
    #         closer,
    #         lambda _ : distances[min_distance_index],
    #         lambda _ : dis,
    #         None)
    #     iteration_draws = n_parallel_simulations \
    #         - np.isinf(distances).sum()
    #     draws += iteration_draws
    #     accepted = closer.sum()
    # else:
    closer = np.less(distances, np.min(dis))
    loc = jax.lax.cond(closer, lambda _: proposed, lambda _: loc, None)
    summ = jax.lax.cond(closer, lambda _: summaries, lambda _: summ, None)
    dis = jax.lax.cond(closer, lambda _: distances, lambda _: dis, None)
    iteration_draws = 1 - np.isinf(distances).sum()
    draws += iteration_draws
    accepted = closer.sum()
    return (rng, loc, scale, summ, dis, draws, accepted,
            acceptance_counter + 1)
def main(data_dir, working_dir, obs_num, ncpu, plot_results):
    os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={ncpu}"
    logger.info("Performing data smoothing via tec+const+clock inference.")
    dds4_h5parm = os.path.join(data_dir, 'L{}_DDS4_full_merged.h5'.format(obs_num))
    dds5_h5parm = os.path.join(working_dir, 'L{}_DDS5_full_merged.h5'.format(obs_num))
    linked_dds5_h5parm = os.path.join(data_dir, 'L{}_DDS5_full_merged.h5'.format(obs_num))
    logger.info("Looking for {}".format(dds4_h5parm))
    link_overwrite(dds5_h5parm, linked_dds5_h5parm)
    prepare_soltabs(dds4_h5parm, dds5_h5parm)
    gain_outliers, phase_obs, amp, times, freqs = get_data(solution_file=dds4_h5parm)
    phase_mean, phase_uncert, tec_mean, tec_std, tec_outliers, const_mean, const_std = \
        solve_and_smooth(gain_outliers, phase_obs, times, freqs)
    # exit(0)
    logger.info("Storing smoothed phase, amplitudes, tec, const, and clock")
    with DataPack(dds5_h5parm, readonly=False) as h:
        h.current_solset = 'sol000'
        # h.select(pol=slice(0, 1, 1), ant=slice(50, 51), dir=slice(0, None, 1), time=slice(0, 100, 1))
        h.select(pol=slice(0, 1, 1))
        h.phase = np.asarray(phase_mean)[None, ...]
        h.weights_phase = np.asarray(phase_uncert)[None, ...]
        h.amplitude = np.asarray(amp)[None, ...]
        h.tec = np.asarray(tec_mean)[None, ...]
        h.tec_outliers = np.asarray(tec_outliers)[None, ...]
        h.weights_tec = np.asarray(tec_std)[None, ...]
        h.const = np.asarray(const_mean)[None, ...]
        axes = h.axes_phase
        patch_names, _ = h.get_directions(axes['dir'])
        antenna_labels, _ = h.get_antennas(axes['ant'])

    if plot_results:
        diagnostic_data_dir = os.path.join(working_dir, 'diagnostic')
        os.makedirs(diagnostic_data_dir, exist_ok=True)
        logger.info("Plotting results.")
        data_plot_dir = os.path.join(working_dir, 'data_plots')
        os.makedirs(data_plot_dir, exist_ok=True)
        Nd, Na, Nf, Nt = phase_mean.shape
        for ia in range(Na):
            for id in range(Nd):
                fig, axs = plt.subplots(3, 1, sharex=True)
                axs[0].plot(times, tec_mean[id, ia, :], c='black', label='tec')
                ylim = axs[0].get_ylim()
                axs[0].vlines(times[tec_outliers[id, ia, :]], *ylim,
                              colors='red', label='outliers', alpha=0.5)
                axs[0].set_ylim(*ylim)
                axs[1].plot(times, const_mean[id, ia, :], c='black', label='const')
                axs[1].fill_between(times,
                                    const_mean[id, ia, :] - const_std[id, ia, :],
                                    const_mean[id, ia, :] + const_std[id, ia, :],
                                    color='black', alpha=0.2)
                ylim = axs[1].get_ylim()
                axs[1].vlines(times[tec_outliers[id, ia, :]], *ylim,
                              colors='red', label='outliers', alpha=0.5)
                axs[1].set_ylim(*ylim)
                axs[2].plot(times, tec_std[id, ia, :], c='black', label='tec_std')
                ylim = axs[2].get_ylim()
                axs[2].vlines(times[tec_outliers[id, ia, :]], *ylim,
                              colors='red', label='outliers', alpha=0.5)
                axs[2].set_ylim(*ylim)
                axs[0].legend()
                axs[1].legend()
                axs[2].legend()
                axs[0].set_ylabel("DTEC [mTECU]")
                axs[1].set_ylabel("const [rad]")
                axs[2].set_ylabel("DTEC uncert [mTECU]")
                axs[2].set_xlabel("time [s]")
                fig.savefig(os.path.join(
                    data_plot_dir,
                    'solutions_ant{:02d}_dir{:02d}.png'.format(ia, id)))
                plt.close("all")

                fig, axs = plt.subplots(4, 1, sharex=True, sharey=True)
                # phase data with input outliers
                # phase posterior with tec outliers
                # dphase with no outliers
                # phase uncertainty
                axs[0].imshow(phase_obs[id, ia, :, :], vmin=-jnp.pi, vmax=jnp.pi,
                              cmap='twilight', aspect='auto', origin='lower',
                              interpolation='nearest')
                axs[0].imshow(jnp.where(gain_outliers[id, ia, :, :], 1., jnp.nan),
                              vmin=0., vmax=1., cmap='bone', aspect='auto',
                              origin='lower', interpolation='nearest')
                add_colorbar_to_axes(axs[0], "twilight", vmin=-jnp.pi, vmax=jnp.pi)

                axs[1].imshow(phase_mean[id, ia, :, :], vmin=-jnp.pi, vmax=jnp.pi,
                              cmap='twilight', aspect='auto', origin='lower',
                              interpolation='nearest')
                axs[1].imshow(jnp.where(jnp.isinf(phase_uncert[id, ia, :, :]), 1., jnp.nan),
                              vmin=0., vmax=1., cmap='bone', aspect='auto',
                              origin='lower', interpolation='nearest')
                add_colorbar_to_axes(axs[1], "twilight", vmin=-jnp.pi, vmax=jnp.pi)

                dphase = wrap(wrap(phase_mean) - phase_obs)
                vmin = -0.5
                vmax = 0.5
                axs[2].imshow(dphase[id, ia, :, :], vmin=vmin, vmax=vmax,
                              cmap='PuOr', aspect='auto', origin='lower',
                              interpolation='nearest')
                add_colorbar_to_axes(axs[2], "PuOr", vmin=vmin, vmax=vmax)

                vmin = 0.
                vmax = 0.8
                axs[3].imshow(phase_uncert[id, ia, :, :], vmin=vmin, vmax=vmax,
                              cmap='PuOr', aspect='auto', origin='lower',
                              interpolation='nearest')
                add_colorbar_to_axes(axs[3], "PuOr", vmin=vmin, vmax=vmax)

                axs[0].set_ylabel("freq [MHz]")
                axs[1].set_ylabel("freq [MHz]")
                axs[2].set_ylabel("freq [MHz]")
                axs[3].set_ylabel("freq [MHz]")
                axs[3].set_xlabel("time [s]")
                axs[0].set_title("phase data [rad]")
                axs[1].set_title("phase model [rad]")
                axs[2].set_title("phase diff. [rad]")
                axs[3].set_title("phase uncert [rad]")
                fig.savefig(os.path.join(
                    data_plot_dir,
                    'data_comparison_ant{:02d}_dir{:02d}.png'.format(ia, id)))
                plt.close("all")
        # exit(0)

        d = os.path.join(working_dir, 'tec_plots')
        animate_datapack(dds5_h5parm, d, num_processes=(ncpu * 2) // 3,
                         vmin=-60, vmax=60., observable='tec',
                         phase_wrap=False, plot_crosses=False,
                         plot_facet_idx=True, labels_in_radec=True,
                         per_timestep_scale=True, solset='sol000',
                         cmap=plt.cm.PuOr)
        # os.makedirs(d, exist_ok=True)
        # DatapackPlotter(dds5_h5parm).plot(
        #     fignames=[os.path.join(d, "fig-{:04d}.png".format(j)) for j in range(Nt)],
        #     vmin=-60,
        #     vmax=60., observable='tec', phase_wrap=False, plot_crosses=False,
        #     plot_facet_idx=True, labels_in_radec=True, per_timestep_scale=True,
        #     solset='sol000', cmap=plt.cm.PuOr)
        # make_animation(d, prefix='fig', fps=4)

        d = os.path.join(working_dir, 'const_plots')
        animate_datapack(dds5_h5parm, d, num_processes=(ncpu * 2) // 3,
                         vmin=-np.pi, vmax=np.pi, observable='const',
                         phase_wrap=False, plot_crosses=False,
                         plot_facet_idx=True, labels_in_radec=True,
                         per_timestep_scale=True, solset='sol000',
                         cmap=plt.cm.PuOr)
        # os.makedirs(d, exist_ok=True)
        # DatapackPlotter(dds5_h5parm).plot(
        #     fignames=[os.path.join(d, "fig-{:04d}.png".format(j)) for j in range(Nt)],
        #     vmin=-np.pi,
        #     vmax=np.pi, observable='const', phase_wrap=False, plot_crosses=False,
        #     plot_facet_idx=True, labels_in_radec=True, per_timestep_scale=True,
        #     solset='sol000', cmap=plt.cm.PuOr)
        # make_animation(d, prefix='fig', fps=4)

        d = os.path.join(working_dir, 'clock_plots')
        animate_datapack(dds5_h5parm, d, num_processes=(ncpu * 2) // 3,
                         vmin=None, vmax=None, observable='clock',
                         phase_wrap=False, plot_crosses=False,
                         plot_facet_idx=True, labels_in_radec=True,
                         per_timestep_scale=True, solset='sol000',
                         cmap=plt.cm.PuOr)
        # os.makedirs(d, exist_ok=True)
        # DatapackPlotter(dds5_h5parm).plot(
        #     fignames=[os.path.join(d, "fig-{:04d}.png".format(j)) for j in range(Nt)],
        #     vmin=None,
        #     vmax=None,
        #     observable='clock', phase_wrap=False, plot_crosses=False,
        #     plot_facet_idx=True, labels_in_radec=True, per_timestep_scale=True,
        #     solset='sol000', cmap=plt.cm.PuOr)
        # make_animation(d, prefix='fig', fps=4)

        d = os.path.join(working_dir, 'amplitude_plots')
        animate_datapack(dds5_h5parm, d, num_processes=(ncpu * 2) // 3,
                         log_scale=True, observable='amplitude',
                         phase_wrap=False, plot_crosses=False,
                         plot_facet_idx=True, labels_in_radec=True,
                         per_timestep_scale=True, solset='sol000',
                         cmap=plt.cm.PuOr)
def has_inf_or_nan(x):
    return jnp.isinf(x).any() or jnp.isnan(x).any()
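# Usage sketch: a cheap guard before committing an update.
import jax.numpy as jnp

grad = jnp.array([0.5, jnp.inf])
if has_inf_or_nan(grad):
    print("skipping step: non-finite gradient")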
def isinf(self, a):
    return jnp.isinf(a)
def isinf(x):
    if isinstance(x, JaxArray):
        return JaxArray(jnp.isinf(x.value))
    else:
        return jnp.isinf(x)
def linesearch(cost_and_grad, x, d, f0, df0, g0, aold=None, dfold=None,
               fold=None, ls_pars=None):
    if ls_pars is None:
        ls_pars = LineSearchParameter()

    # Wolfe conditions
    def wolfe_one(ai, fi):
        return fi > f0 + ls_pars.ls_suff_decr * ai * df0

    def wolfe_two(dfi):
        return jnp.abs(dfi) <= -ls_pars.ls_curvature * df0

    state = _LineSearchState(
        done=False,
        failed=False,
        i=1,
        a_i1=0.,
        phi_i1=f0,
        dphi_i1=df0,
        nfev=0,
        ngev=0,
        a_star=0,
        phi_star=f0,
        dphi_star=df0,
        g_star=g0,
    )

    if ls_pars.ls_verbosity >= 1:
        print('\tStarting linesearch...')

    if (aold is not None) and (dfold is not None):
        alpha_0 = aold * dfold / df0
        initial_step_length = jnp.where(alpha_0 > ls_pars.ls_initial_step,
                                        ls_pars.ls_initial_step, alpha_0)
    elif (fold is not None) and (~jnp.isinf(fold)):
        candidate = 1.01 * 2 * jnp.abs((f0 - fold) / df0)
        candidate = jnp.where(candidate < 1e-8, 1e-8, candidate)
        initial_step_length = jnp.where(candidate > 1.2 * ls_pars.ls_initial_step,
                                        ls_pars.ls_initial_step, candidate)
        if ls_pars.ls_verbosity >= 3:
            print(f'\tcandidate: {candidate:.2e}, accepted: {initial_step_length:.2e}')
    else:
        initial_step_length = ls_pars.ls_initial_step

    while ((~state.done) & (state.i <= ls_pars.ls_maxiter) & (~state.failed)):
        # no amax in this version, we just double as in scipy.
        # unlike the original algorithm we do our choice at the start of this loop
        ai = jnp.where(
            state.i == 1,
            initial_step_length,
            state.a_i1 * ls_pars.ls_optimism
        )

        fi, gri, dfi = cost_and_grad(ai)
        state = state._replace(
            nfev=state.nfev + 1,
            ngev=state.ngev + 1
        )
        # if dfi > 0:
        #     state._replace(
        #         failed=True,
        #     )
        #     break
        while jnp.isnan(fi):
            ai = ai / 10.
            fi, gri, dfi = cost_and_grad(ai)
            state = state._replace(
                nfev=state.nfev + 1,
                ngev=state.ngev + 1
            )

        if ls_pars.ls_verbosity >= 2:
            print("\titer: {}\n\t\talpha: {:.2e} "
                  "f(alpha): {:.5e}".format(state.i, ai, fi))

        if wolfe_one(ai, fi) or ((fi > state.phi_i1) and state.i > 1):
            if ls_pars.ls_verbosity >= 2:
                print('\t\tEntering zoom1...')
            zoom1 = _zoom(cost_and_grad, wolfe_one, wolfe_two,
                          state.a_i1, state.phi_i1, state.dphi_i1,
                          ai, fi, dfi, gri, ls_pars)
            state = state._replace(
                done=(zoom1.done or state.done),
                failed=(zoom1.failed or state.failed),
                a_star=zoom1.a_star,
                phi_star=zoom1.phi_star,
                dphi_star=zoom1.dphi_star,
                g_star=zoom1.g_star,
                nfev=state.nfev + zoom1.nfev,
                ngev=state.ngev + zoom1.ngev
            )
        elif wolfe_two(dfi):
            if ls_pars.ls_verbosity >= 2:
                print('\t\tWolfe two condition met, stopping')
            state = state._replace(
                done=(True or state.done),
                a_star=ai,
                phi_star=fi,
                dphi_star=dfi,
                g_star=gri
            )
        elif dfi >= 0:
            if ls_pars.ls_verbosity >= 2:
                print('\t\tEntering zoom2')
            zoom2 = _zoom(cost_and_grad, wolfe_one, wolfe_two,
                          ai, fi, dfi,
                          state.a_i1, state.phi_i1, state.dphi_i1,
                          gri, ls_pars)
            state = state._replace(
                done=(zoom2.done or state.done),
                failed=(zoom2.failed or state.failed),
                a_star=zoom2.a_star,
                phi_star=zoom2.phi_star,
                dphi_star=zoom2.dphi_star,
                g_star=zoom2.g_star,
                nfev=state.nfev + zoom2.nfev,
                ngev=state.ngev + zoom2.ngev
            )

        state = state._replace(
            i=state.i + 1,
            a_i1=ai,
            phi_i1=fi,
            dphi_i1=dfi)

    status = jnp.where(
        state.failed,
        jnp.array(2),  # zoom failed
        jnp.where(
            state.i > ls_pars.ls_maxiter,
            jnp.array(1),  # maxiter reached
            jnp.array(0),  # passed (should be)
        ),
    )
    result = _LineSearchResult(
        failed=state.failed,
        nit=state.i - 1,  # because iterations started at 1
        nfev=state.nfev,
        ngev=state.ngev,
        k=state.i,
        a_k=state.a_i1 if status == 1 else state.a_star,
        f_k=state.phi_star,
        g_k=state.g_star,
        status=status,
    )

    if ls_pars.ls_verbosity >= 1:
        print('\tLinesearch {}, alpha star = {:.2e}'.format(
            'failed' if state.failed else 'done', result.a_k))

    return result
def check():
    if np.isnan(lls[-1]) or np.isinf(lls[-1]):
        raise Exception("LL check failed.")