def main(args):
    _, fetch_train = load_dataset(UCBADMIT, split="train", shuffle=False)
    dept, male, applications, admit = fetch_train()
    rng_key, rng_key_predict = random.split(random.PRNGKey(1))
    zs = run_inference(dept, male, applications, admit, rng_key, args)
    pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male, applications)["probs"]
    header = "=" * 30 + "glmm - TRAIN" + "=" * 30
    print_results(header, pred_probs, dept, male, admit / applications)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

    ax.plot(range(1, 13), admit / applications, "o", ms=7, label="actual rate")
    ax.errorbar(
        range(1, 13),
        jnp.mean(pred_probs, 0),
        jnp.std(pred_probs, 0),
        fmt="o",
        c="k",
        mfc="none",
        ms=7,
        elinewidth=1,
        label=r"mean $\pm$ std",
    )
    ax.plot(range(1, 13), jnp.percentile(pred_probs, 5, 0), "k+")
    ax.plot(range(1, 13), jnp.percentile(pred_probs, 95, 0), "k+")
    ax.set(
        xlabel="cases",
        ylabel="admit rate",
        title="Posterior Predictive Check with 90% CI",
    )
    ax.legend()

    plt.savefig("ucbadmit_plot.pdf")
def summarize_posterior(preds, ci=96):
    ci_lower = (100 - ci) / 2
    ci_upper = (100 + ci) / 2
    preds_mean = preds.mean(0)
    preds_lower = jnp.percentile(preds, ci_lower, axis=0)
    preds_upper = jnp.percentile(preds, ci_upper, axis=0)
    return preds_mean, preds_lower, preds_upper
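# A minimal usage sketch for summarize_posterior on synthetic data (names are
# illustrative): draws are stacked along axis 0, so the returned mean and
# percentile bounds have one entry per observed case.
import jax.numpy as jnp
from jax import random

fake_preds = random.normal(random.PRNGKey(0), (1000, 12))  # 1000 draws x 12 cases
preds_mean, preds_lower, preds_upper = summarize_posterior(fake_preds, ci=90)
assert preds_mean.shape == preds_lower.shape == preds_upper.shape == (12,)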
def summary(
    samples: Dict[str, jnp.DeviceArray], poisson: bool
) -> Dict[str, Dict[str, jnp.DeviceArray]]:
    """Generate a nice summary given the samples from `get_samples` of the Poisson model.

    Args:
        samples: samples as acquired from `predictive`.get_samples
        poisson: exponentiate the values if Poisson was used

    Returns:
        dict: summary over the samples
    """
    site_stats = {}
    for k, v in samples.items():
        if poisson:
            # Poisson model, thus we exponentiate!
            v = jnp.exp(v)
        else:
            v = v.astype(jnp.float32)  # for percentile to work
            v = ops.index_update(v, v < 0., 0.)  # avoid -1.
        site_stats[k] = {
            "mean": jnp.mean(v, 0),
            "std": jnp.std(v, 0),
            "5%": jnp.percentile(v, 5., axis=0),
            "25%": jnp.percentile(v, 25., axis=0),
            "75%": jnp.percentile(v, 75., axis=0),
            "95%": jnp.percentile(v, 95., axis=0),
        }
    return site_stats
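# Hypothetical usage sketch for summary() above: a dict of posterior draws keyed by
# site name, with shape (num_samples, event_shape). With poisson=True the draws are
# exponentiated before the quantiles are taken.
import jax.numpy as jnp
from jax import random

fake_samples = {"rate": random.normal(random.PRNGKey(0), (500, 3))}
stats = summary(fake_samples, poisson=True)
print(stats["rate"]["mean"].shape, stats["rate"]["5%"].shape)  # (3,) (3,)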
def plot_iter_daily_shade(soln_inc, n, ymax=1, scale=1, int=0, Tint=1,
                          loCI=5, upCI=95, plotThis=False, plotName="test"):
    """
    Plots the output (daily incidence) from multiple simulations, with or without an
    intervention. Shows the mean and the loCI-upCI percentile band (a 90% interval by default).
    soln_inc: 3D array of values for each iteration for each variable at each timepoint
    n: total population size
    ymax: highest value on y axis, relative to "scale" value (e.g. 0.5 makes ymax=0.5 or 50% for scale=1 or N)
    scale: amount to multiply all frequency values by (e.g. "1" keeps as frequency, "N" turns to absolute values)
    int: Optional, 1 or 0 for whether or not there was an intervention. Defaults to 0
    Tint: Optional, timepoint (days) at which intervention was started
    loCI, upCI: Optional, lower and upper percentiles for confidence intervals. Defaults to a 90% interval
    plotThis: True or False, whether a plot will be saved as pdf
    plotName: string, name of the plot to be saved
    """

    tvec = np.arange(1, np.shape(soln_inc)[1] + 1)
    soln_avg = np.average(soln_inc, axis=0)
    soln_loCI = np.percentile(soln_inc, loCI, axis=0)
    soln_upCI = np.percentile(soln_inc, upCI, axis=0)

    # linear scale
    plt.figure(figsize=(2 * 6.4, 4.0))
    plt.subplot(121)
    # add averages
    plt.plot(tvec, soln_avg * scale)
    plt.legend(['S', 'E', 'I1', 'I2', 'I3', 'D', 'R'], frameon=False, framealpha=0.0,
               bbox_to_anchor=(1.04, 1), loc="upper left")
    # add ranges
    plt.gca().set_prop_cycle(None)
    for i in range(0, 7):
        plt.fill_between(tvec, soln_loCI[:, i] * scale, soln_upCI[:, i] * scale, alpha=0.3)
    if int == 1:
        plt.plot([Tint, Tint], [0, ymax * scale], 'k--')
    plt.ylim([0, ymax * scale])
    plt.xlabel("Time (days)")
    plt.ylabel("Daily incidence")

    # log scale
    plt.subplot(122)
    # add averages
    plt.plot(tvec, soln_avg * scale)
    plt.legend(['S', 'E', 'I1', 'I2', 'I3', 'D', 'R'], frameon=False, framealpha=0.0,
               bbox_to_anchor=(1.04, 1), loc="upper left")
    # add ranges
    plt.gca().set_prop_cycle(None)
    for i in range(0, 7):
        plt.fill_between(tvec, soln_loCI[:, i] * scale, soln_upCI[:, i] * scale, alpha=0.3)
    if int == 1:
        plt.plot([Tint, Tint], [scale / n, ymax * scale], 'k--')
    plt.ylim([scale / n, ymax * scale])
    plt.xlabel("Time (days)")
    plt.ylabel("Daily incidence")
    plt.semilogy()

    plt.tight_layout()
    if plotThis:
        plt.savefig(plotName + '.pdf', bbox_inches='tight')
    plt.show()
def test_quick_select():
    arr = jnp.asarray([10, 4, 5, 8, 11, 6, 26, 7])  # even number of points
    assert quick_sort_median(arr) == jnp.percentile(arr, 50, interpolation='higher')

    arr = jnp.asarray([10, 4, 5, 8, 11, 6, 26])  # odd number of points
    assert quick_sort_median(arr) == jnp.percentile(arr, 50, interpolation='higher')

    arr = jnp.asarray([10, 4, 5, 8, 11, 6, 26])  # odd number of points
    assert quick_sort_median(arr) == jnp.median(arr)

    arr = jnp.asarray([10, 4, 5, 8, 11, 6, 26, 7])  # even number of points
    try:
        assert quick_sort_median(arr) == jnp.median(arr)
    except AssertionError:
        print("Not equivalent to median when array has an even size.")
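# Side note (a standalone sketch, not part of the test): for an even-length array the
# two conventions genuinely differ, which is what the try/except above demonstrates.
import jax.numpy as jnp

arr = jnp.asarray([10, 4, 5, 8, 11, 6, 26, 7])  # sorted: 4 5 6 7 8 10 11 26
print(jnp.percentile(arr, 50, interpolation='higher'))  # 8   (upper middle value)
print(jnp.median(arr))                                  # 7.5 (average of 7 and 8)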
def main(args):
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # use dense_mass for better mixing rate
    mcmc = MCMC(NUTS(model, dense_mass=True),
                args.num_warmup, args.num_samples,
                num_chains=args.num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(PRNGKey(1), N=data.shape[0], y=jnp.log(data))
    mcmc.print_summary()

    # predict populations
    y_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"]
    pop_pred = jnp.exp(y_pred)
    mu, pi = jnp.mean(pop_pred, 0), jnp.percentile(pop_pred, (10, 90), 0)
    plt.plot(year, data[:, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67)
    plt.plot(year, data[:, 1], "bx", label="true lynx")
    plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    plt.plot(year, mu[:, 1], "b--", label="pred lynx")
    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
    plt.gca().set(ylim=(0, 160), xlabel="year", ylabel="population (in thousands)")
    plt.title("Posterior predictive (80% CI) with predator-prey pattern.")
    plt.legend()

    # apply tight_layout before saving so the saved figure reflects it
    plt.tight_layout()
    plt.savefig("ode_plot.pdf")
def _find_binning_thresholds(data, max_bins=256, subsample=200000, random_state=None):
    if 2 > max_bins or max_bins > 256:
        raise ValueError(f'max_bins={max_bins} should be no smaller than 2 '
                         f'and no larger than 256.')
    if random_state is None:
        random_state = int(time.time())
    rng = random.PRNGKey(random_state)
    if subsample is not None and data.shape[0] > subsample:
        subset = random.shuffle(rng, np.arange(data.shape[0]))[:subsample]
        data = data[subset]

    dtype = data.dtype
    if dtype.kind != 'f':
        dtype = np.float32

    percentiles = np.linspace(0, 100, num=max_bins + 1)[1:-1]
    binning_thresholds = []
    for f_idx in range(data.shape[1]):
        col_data = np.array(data[:, f_idx], dtype=dtype, order='C')
        distinct_values = onp.unique(col_data)
        if len(distinct_values) <= max_bins:
            midpoints = (distinct_values[:-1] + distinct_values[1:])
            midpoints *= 0.5
        else:
            midpoints = np.percentile(col_data, percentiles,
                                      interpolation='midpoint').astype(dtype)
        binning_thresholds.append(midpoints)
    return tuple(binning_thresholds)
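# Hedged standalone sketch of the thresholding idea used above (not calling the
# helper itself): bin cut points are interior percentiles of a feature column.
import jax.numpy as jnp
from jax import random

col = random.normal(random.PRNGKey(0), (1000,))
max_bins = 16
percentiles = jnp.linspace(0, 100, num=max_bins + 1)[1:-1]
thresholds = jnp.percentile(col, percentiles, interpolation='midpoint')
print(thresholds.shape)  # (15,) -> max_bins - 1 cut points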
def plot_samples(self, samples, plot_fields=['y'], start='2020-03-04', T=None, ax=None, legend=True, forecast=False, n_samples=0, intervals=[50, 80, 95]): ''' Plotting method for SIR-type models. ''' ax = plt.axes(ax) T_data = self.horizon(samples, forecast=forecast) T = T_data if T is None else min(T, T_data) fields = {f: 0.0 + self.get(samples, f, forecast=forecast)[:,:T] for f in plot_fields} names = {f: self.names[f] for f in plot_fields} medians = {names[f]: np.median(v, axis=0) for f, v in fields.items()} t = pd.date_range(start=start, periods=T, freq='D') ax.set_prop_cycle(None) colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] # Plot medians df = pd.DataFrame(index=t, data=medians) df.plot(ax=ax, legend=legend) median_max = df.max().values # Plot samples if requested if n_samples > 0: for i, f in enumerate(fields): df = pd.DataFrame(index=t, data=fields[f][:n_samples,:].T) df.plot(ax=ax, legend=False, alpha=0.1) # Plot prediction intervals pi_max = 10 handles = [] for interval in intervals: low=(100.-interval)/2 high=100.-low pred_intervals = {names[f]: np.percentile(v, (low, high), axis=0) for f, v in fields.items()} for i, pi in enumerate(pred_intervals.values()): h = ax.fill_between(t, pi[0,:], pi[1,:], alpha=0.1, color=colors[i], label=interval) handles.append(h) pi_max = np.maximum(pi_max, np.nanmax(pi[1,:])) return median_max, pi_max
def _describe_fit(self, dist):
    description = super()._describe_fit(dist)

    def entry_distance_fn(x, density):
        return abs(1.0 - density / dist.pdf(x))

    distances = vmap(entry_distance_fn)(self.xs, self.densities)
    description["max_distance"] = np.max(distances)
    description["90th_distance"] = np.percentile(distances, 90)
    description["mean_distance"] = np.mean(distances)
    return description
def estimate_clipping_thresholds(apply_fn, params, l2_norm_clip_percentile, graph,
                                 labels, subgraphs, estimation_indices,
                                 adjacency_normalization):
    """Estimates gradient clipping thresholds."""
    dummy_state = train_state.TrainState.create(
        apply_fn=apply_fn, params=params, tx=optax.identity())
    grads = compute_updates_for_dp(dummy_state, graph, labels, subgraphs,
                                   estimation_indices, adjacency_normalization)
    grad_norms = jax.tree_map(jax.vmap(jnp.linalg.norm), grads)
    get_percentile = lambda norms: jnp.percentile(norms, l2_norm_clip_percentile)
    l2_norms_threshold = jax.tree_map(get_percentile, grad_norms)
    return l2_norms_threshold
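# Toy sketch of the same percentile-based threshold idea with synthetic per-example
# gradients (no flax/optax state involved); the names here are illustrative only.
import jax
import jax.numpy as jnp

fake_grads = {"kernel": jnp.ones((8, 4, 3)), "bias": 2.0 * jnp.ones((8, 3))}  # 8 examples
per_example_norms = jax.tree_map(jax.vmap(jnp.linalg.norm), fake_grads)
thresholds = jax.tree_map(lambda n: jnp.percentile(n, 75.0), per_example_norms)
print(thresholds)  # one clipping norm per parameter group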
def prep_utility_fn(params, x_b1, y_b1, utility_type, steps=1000, percentile=0,
                    y_incumb=None):
    """Prepare the utility functions for the score."""
    gp_util = gp.GPUtils()
    params = gp_util.fit_gp(x_b1, y_b1, params, steps=steps)
    if y_incumb is None:
        y_incumb = jnp.percentile(y_b1, percentile, axis=0)
    utility_measure = scores.UtilityMeasure(incumbent=y_incumb, params=params)
    utility_measure_fn = lambda key, x_batch, y_batch: getattr(
        utility_measure, utility_type)(y_batch)
    return utility_measure_fn, params
def plot_R0(mcmc_samples, start, ax=None):
    ax = plt.axes(ax)

    # Compute R0 over time
    gamma = mcmc_samples['gamma'][:, None]
    beta = mcmc_samples['beta']
    t = pd.date_range(start=start, periods=beta.shape[1], freq='D')
    R0 = beta / gamma

    pi = np.percentile(R0, (10, 90), axis=0)
    df = pd.DataFrame(index=t, data={'R0': np.median(R0, axis=0)})
    df.plot(style='-o', ax=ax)
    ax.fill_between(t, pi[0, :], pi[1, :], alpha=0.1)

    ax.axhline(1, linestyle='--')
def plot_samples(self, samples, plot_fields=['y'], start='2020-03-04', T=None, ax=None, legend=True, forecast=False): ''' Plotting method for SIR-type models. ''' ax = plt.axes(ax) T_data = self.horizon(samples, forecast=forecast) T = T_data if T is None else min(T, T_data) fields = { f: self.get(samples, f, forecast=forecast)[:, :T] for f in plot_fields } names = {f: self.names[f] for f in plot_fields} medians = {names[f]: np.median(v, axis=0) for f, v in fields.items()} pred_intervals = { names[f]: np.percentile(v, (10, 90), axis=0) for f, v in fields.items() } t = pd.date_range(start=start, periods=T, freq='D') ax.set_prop_cycle(None) # Plot medians df = pd.DataFrame(index=t, data=medians) df.plot(ax=ax, legend=legend) median_max = df.max().values # Plot prediction intervals pi_max = 10 for pi in pred_intervals.values(): ax.fill_between(t, pi[0, :], pi[1, :], alpha=0.1, label='CI') pi_max = np.maximum(pi_max, np.nanmax(pi[1, :])) return median_max, pi_max
def plot_growth_rate(mcmc_samples, start, model=SEIRModel, ax=None):
    ax = plt.axes(ax)

    # Compute growth rate over time
    beta = mcmc_samples['beta']
    sigma = mcmc_samples['sigma'][:, None]
    gamma = mcmc_samples['gamma'][:, None]
    t = pd.date_range(start=start, periods=beta.shape[1], freq='D')

    # use the passed-in model rather than hard-coding SEIRModel
    growth_rate = model.growth_rate((beta, sigma, gamma))

    pi = np.percentile(growth_rate, (10, 90), axis=0)
    df = pd.DataFrame(index=t, data={'growth_rate': np.median(growth_rate, axis=0)})
    df.plot(style='-o', ax=ax)
    ax.fill_between(t, pi[0, :], pi[1, :], alpha=0.1)

    ax.axhline(0, linestyle='--')
def plot_R0(mcmc_samples, start):
    fig = plt.figure(figsize=(5, 3))

    # Compute average R0 over time
    gamma = mcmc_samples['gamma'][:, None]
    beta = mcmc_samples['beta']
    t = pd.date_range(start=start, periods=beta.shape[1], freq='D')
    R0 = beta / gamma

    pi = np.percentile(R0, (10, 90), axis=0)
    df = pd.DataFrame(index=t, data={'R0': np.median(R0, axis=0)})
    df.plot(style='-o')
    plt.fill_between(t, pi[0, :], pi[1, :], alpha=0.1)
    plt.axhline(1, linestyle='--')

    # plt.tight_layout()
    return fig
def percentile(self, tensor_in, q, axis=None, interpolation="linear"): r""" Compute the :math:`q`-th percentile of the tensor along the specified axis. Example: >>> import pyhf >>> import jax.numpy as jnp >>> pyhf.set_backend("jax") >>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]]) >>> pyhf.tensorlib.percentile(a, 50) DeviceArray(3.5, dtype=float64) >>> pyhf.tensorlib.percentile(a, 50, axis=1) DeviceArray([7., 2.], dtype=float64) Args: tensor_in (`tensor`): The tensor containing the data q (:obj:`float` or `tensor`): The :math:`q`-th percentile to compute axis (`number` or `tensor`): The dimensions along which to compute interpolation (:obj:`str`): The interpolation method to use when the desired percentile lies between two data points ``i < j``: - ``'linear'``: ``i + (j - i) * fraction``, where ``fraction`` is the fractional part of the index surrounded by ``i`` and ``j``. - ``'lower'``: ``i``. - ``'higher'``: ``j``. - ``'midpoint'``: ``(i + j) / 2``. - ``'nearest'``: ``i`` or ``j``, whichever is nearest. Returns: JAX ndarray: The value of the :math:`q`-th percentile of the tensor along the specified axis. """ return jnp.percentile(tensor_in, q, axis=axis, interpolation=interpolation)
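# A short standalone sketch (not the pyhf backend itself) of how the interpolation
# modes documented above behave when q falls between two data points; note that
# newer JAX versions spell this keyword `method` rather than `interpolation`.
import jax.numpy as jnp

a = jnp.asarray([1.0, 2.0, 3.0, 4.0])
for mode in ("linear", "lower", "higher", "midpoint", "nearest"):
    print(mode, jnp.percentile(a, 30, interpolation=mode))
# linear 1.9, lower 1.0, higher 2.0, midpoint 1.5, nearest 2.0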
def test_tomographic_weight_rel_err(): import pylab as plt from jax import jit for S in range(5, 30, 5): @jit def tomo_weight(gamma, x1, x2, p1, p2): return tomographic_weight_function(gamma, x1, x2, p1, p2, S=S) @jit def tomo_weight_ref(gamma, x1, x2, p1, p2): return tomographic_weight_function(gamma, x1, x2, p1, p2, S=150) rel_error = [] for i in range(400): keys = random.split(random.PRNGKey(i), 6) x1 = jnp.concatenate( [4. * random.uniform(keys[0], shape=(2, )), jnp.zeros((1, ))], axis=-1) p1 = jnp.concatenate([ 4. * jnp.pi / 180. * random.uniform(keys[1], shape=(2, ), minval=-1, maxval=1), jnp.ones((1, )) ], axis=-1) p1 = 4 * p1 / jnp.linalg.norm(p1, axis=-1, keepdims=True) x2 = jnp.concatenate( [4. * random.uniform(keys[2], shape=(2, )), jnp.zeros((1, ))], axis=-1) p2 = jnp.concatenate([ 4. * jnp.pi / 180. * random.uniform(keys[3], shape=(2, ), minval=-1, maxval=1), jnp.ones((1, )) ], axis=-1) p2 = 4 * p2 / jnp.linalg.norm(p2, axis=-1, keepdims=True) # x1 = random.normal(keys[0], shape_dict=(2,)) # p1 = random.normal(keys[1], shape_dict=(2,)) # x2 = random.normal(keys[2], shape_dict=(2,)) # p2 = random.normal(keys[3], shape_dict=(2,)) t1 = random.uniform(keys[4], shape=(10000, )) t2 = random.uniform(keys[5], shape=(10000, )) u1 = x1 + t1[:, None] * p1 u2 = x2 + t2[:, None] * p2 gamma = jnp.linalg.norm(u1 - u2, axis=1)**2 hist, bins = jnp.histogram(gamma.flatten(), density=True, bins=100) bins = jnp.linspace(bins.min(), bins.max(), 20) w = tomo_weight(bins, x1, x2, p1, p2) w_ref = tomo_weight_ref(bins, x1, x2, p1, p2) rel_error.append(jnp.max(jnp.abs(w - w_ref)) / jnp.max(w_ref)) rel_error = jnp.array(rel_error) plt.hist(rel_error, bins='auto') plt.title("{} : {:.2f}|{:.2f}|{:.2f}".format( S, *jnp.percentile(rel_error, [5, 50, 95]))) plt.show()
_, fetch_train = load_dataset(UCBADMIT, split='train', shuffle=False)
dept, male, applications, admit = fetch_train()
rng_key, rng_key_predict = random.split(random.PRNGKey(1))

kernel = NUTS(glmm)
mcmc = MCMC(kernel, args.num_warmup, args.num_samples, args.num_chains,
            progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
mcmc.run(rng_key, dept, male, applications, admit)
zs = mcmc.get_samples()

pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male, applications)['probs']

fig, ax = plt.subplots(1, 1)
ax.plot(range(1, 13), admit / applications, "o", ms=7, label="actual rate")
ax.errorbar(range(1, 13), np.mean(pred_probs, 0), np.std(pred_probs, 0),
            fmt="o", c="k", mfc="none", ms=7, elinewidth=1, label=r"mean $\pm$ std")
ax.plot(range(1, 13), np.percentile(pred_probs, 5, 0), "k+")
ax.plot(range(1, 13), np.percentile(pred_probs, 95, 0), "k+")
ax.set(xlabel="cases", ylabel="admit rate",
       title="Posterior Predictive Check with 90% CI")
ax.legend()

# apply tight_layout before saving so the layout fix is reflected in the saved pdf
plt.tight_layout()
plt.savefig("ucbadmit_plot.pdf")
df_data.index = year # use dense_mass for better mixing rate mcmc = MCMC( NUTS(model, dense_mass=True), num_warmup, num_samples, num_chains=num_chains, progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True) mcmc.run(PRNGKey(1), N=data.shape[0], y=jnp.log(data)) mcmc.print_summary() # predict populations y_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"] pop_pred = jnp.exp(y_pred) mu, pi = jnp.mean(pop_pred, 0), jnp.percentile(pop_pred, (10, 90), 0) plt.plot(year, data[:, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67) plt.plot(year, data[:, 1], "bx", label="true lynx") plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67) plt.plot(year, mu[:, 1], "b--", label="pred lynx") plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2) plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3) plt.gca().set(ylim=(0, 160), xlabel="year", ylabel="population (in thousands)") plt.title("Posterior predictive (80% CI) with predator-prey pattern.") plt.legend()
def sample_answer(self, rng_key, distribution: dist.Distribution):
    num_samples = 100
    samples = distribution.sample(key=rng_key, sample_shape=(num_samples,))
    return jnp.percentile(samples, self.percentile())
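# What sample_answer computes, written out directly (a sketch, not the class API):
# 100 draws from the distribution, then the empirical percentile of those draws.
import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random

draws = dist.Normal(0.0, 1.0).sample(key=random.PRNGKey(0), sample_shape=(100,))
print(jnp.percentile(draws, 50.0))  # near 0, up to Monte Carlo error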
def get_polynomial_form(): """ The polynomial form of log P(|n + t1*w1 - t2*w2|^2 < lambda') is assumed to be: c_i = Q_ij p_j log_cdf = c_i g_i = g_i Q_ij p_j = Tr(Q @ (p g)) log_cdf_k = g_ki Q_ij p_kj Returns: """ from jax.scipy.optimize import minimize from jax import jit, value_and_grad from jax.lax import scan import pylab as plt def generate_single_data(key): """ Generate a physical set of: n = x1-x2/|x1-x2| is a unit vector. w1 = p1 / |x1-x2| w2 = p2 / |x1-x2| lambda' = lambda / |x1-x2|^2 Args: key: Returns: """ keys = random.split(key, 6) n = random.normal(keys[0], shape=(3, )) n = n / jnp.linalg.norm(n) w1 = random.normal(keys[1], shape=(3, )) w1 = w1 / jnp.linalg.norm(w1) w1 = w1 * random.uniform(keys[2], minval=0., maxval=10.) w2 = random.normal(keys[3], shape=(3, )) w2 = w2 / jnp.linalg.norm(w2) w2 = w2 * random.uniform(keys[4], minval=0., maxval=10.) gamma_prime = jnp.linspace( 0., 10., 100) # random.uniform(keys[5],minval=0.,maxval=10.)**2 cdf_ref = cumulative_tomographic_weight_dimensionless_function( gamma_prime, n, w1, w2, S=150) # /h**2 return n, w1, w2, gamma_prime, cdf_ref data = jit(vmap(generate_single_data))(random.split( random.PRNGKey(12340985), 100)) # print(data[-1]) def loss(Q): def single_loss(single_datum): n, w1, w2, gamma_prime, cdf_ref = single_datum return ( vmap(lambda gamma_prime: cumulative_tomographic_weight_dimensionless_polynomial( Q, gamma_prime, n, w1, w2))(gamma_prime) - cdf_ref)**2 return jnp.mean(vmap(single_loss)(data)) K = 3 Q0 = 0.01 * random.normal(random.PRNGKey(0), shape=(K * 7, )) print(jit(loss)(Q0)) @jit def do_minimize(): results = minimize(loss, Q0, method='BFGS', options=dict(gtol=1e-8, line_search_maxiter=100)) print(results.message) return results.x.reshape( (K, 7) ), results.status, results.fun, results.nfev, results.nit, results.jac @jit def do_sgd(key): def body(state, X): (Q, ) = state (key, ) = X n, w1, w2, gamma_prime, cdf_ref = generate_single_data(key) def loss(Q): return jnp.mean((vmap( lambda gamma_prime: cumulative_tomographic_weight_dimensionless_polynomial( Q, gamma_prime, n, w1, w2))(gamma_prime) - cdf_ref)**2) # + 0.1*jnp.mean(Q**2) f, g = value_and_grad(loss)(Q) Q = Q - 0.00000001 * g return (Q, ), (f, ) (Q, ), (values, ) = scan(body, (Q0, ), (random.split(key, 1000), )) return Q.reshape((-1, 7)), values # results = do_minimize() Q, values = vmap(do_sgd)(random.split(random.PRNGKey(12456), 100)) print('Qmean', Q.mean(0)) print('Qstd', Q.std(0)) f = values.mean(0) fstd = values.std(0) plt.plot(jnp.percentile(values, 50, axis=0)) plt.plot(jnp.percentile(values, 85, axis=0), ls='dotted', c='black') plt.plot(jnp.percentile(values, 15, axis=0), ls='dotted', c='black') plt.show() # print(results) return Q.mean(0)
def get_peaks_iter_daily(soln_inc,int=0,Tint=0,loCI=5,upCI=95): """ calculates the peak daily incidence for a multiple runs, with or without an intervention soln_inc: 3D array of values for each iteration for each variable at each timepoint ymax : highest value on y axis, relative to "scale" value (e.g. 0.5 makes ymax=0.5 or 50% for scale=1 or N) scale: amount to multiple all frequency values by (e.g. "1" keeps as frequency, "N" turns to absolute values) int: Optional, 1 or 0 for whether or not there was an intervention. Defaults to 0 Tint: Optional, timepoint (days) at which intervention was started loCI,upCI: Optional, upper and lower percentiles for confidence intervals. Defaults to 90% interval """ if int==0: time_int=0 else: time_int=Tint # Peak incidence peaks=np.amax(soln_inc[:,:,2],axis=1) print('Peak daily I1: {:4.2f}% [{:4.2f}, {:4.2f}]'.format( 100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI))) peaks=np.amax(soln_inc[:,:,3],axis=1) print('Peak daily I2: {:4.2f}% [{:4.2f}, {:4.2f}]'.format( 100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI))) peaks=np.amax(soln_inc[:,:,4],axis=1) print('Peak daily I3: {:4.2f}% [{:4.2f}, {:4.2f}]'.format( 100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI))) peaks=np.amax(soln_inc[:,:,5],axis=1) print('Peak daily deaths: {:4.2f}% [{:4.2f}, {:4.2f}]'.format( 100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI))) # Timing of peak incidence tpeak=np.argmax(soln_inc[:,:,2],axis=1)+1.0-time_int print('Time of peak I1: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format( np.average(tpeak),np.median(tpeak),np.percentile(tpeak,5.0),np.percentile(tpeak,95.0))) tpeak=np.argmax(soln_inc[:,:,3],axis=1)+1.0-time_int print('Time of peak I2: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format( np.average(tpeak),np.median(tpeak),np.percentile(tpeak,5.0),np.percentile(tpeak,95.0))) tpeak=np.argmax(soln_inc[:,:,4],axis=1)+1.0-time_int print('Time of peak I3: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format( np.average(tpeak),np.median(tpeak),np.percentile(tpeak,5.0),np.percentile(tpeak,95.0))) tpeak=np.argmax(soln_inc[:,:,5],axis=1)+1.0-time_int print('Time of peak deaths: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format( np.average(tpeak),np.median(tpeak),np.percentile(tpeak,5.0),np.percentile(tpeak,95.0))) return
def get_peaks_iter(soln,tvec,int=0,Tint=0,loCI=5,upCI=95): """ calculates the peak prevalence for a multiple runs, with or without an intervention soln: 3D array of values for each iteration for each variable at each timepoint tvec: 1D vector of timepoints ymax : highest value on y axis, relative to "scale" value (e.g. 0.5 makes ymax=0.5 or 50% for scale=1 or N) scale: amount to multiple all frequency values by (e.g. "1" keeps as frequency, "N" turns to absolute values) int: Optional, 1 or 0 for whether or not there was an intervention. Defaults to 0 Tint: Optional, timepoint (days) at which intervention was started loCI,upCI: Optional, upper and lower percentiles for confidence intervals. Defaults to 90% interval """ delta_t=tvec[1]-tvec[0] if int==0: time_int=0 else: time_int=Tint all_cases=soln[:,:,1]+soln[:,:,2]+soln[:,:,3]+soln[:,:,4] # Final values print('Final recovered: {:4.2f}% [{:4.2f}, {:4.2f}]'.format( 100 * np.average(soln[:,-1,6]), 100*np.percentile(soln[:,-1,6],loCI), 100*np.percentile(soln[:,-1,6],upCI))) print('Final deaths: {:4.2f}% [{:4.2f}, {:4.2f}]'.format( 100 * np.average(soln[:,-1,5]), 100*np.percentile(soln[:,-1,5],loCI), 100*np.percentile(soln[:,-1,5],upCI))) print('Remaining infections: {:4.2f}% [{:4.2f}, {:4.2f}]'.format( 100*np.average(all_cases[:,-1]),100*np.percentile(all_cases[:,-1],loCI),100*np.percentile(all_cases[:,-1],upCI))) # Peak prevalence peaks=np.amax(soln[:,:,2],axis=1) print('Peak I1: {:4.2f}% [{:4.2f}, {:4.2f}]'.format( 100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI))) peaks=np.amax(soln[:,:,3],axis=1) print('Peak I2: {:4.2f}% [{:4.2f}, {:4.2f}]'.format( 100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI))) peaks=np.amax(soln[:,:,4],axis=1) print('Peak I3: {:4.2f}% [{:4.2f}, {:4.2f}]'.format( 100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI))) # Timing of peaks tpeak=np.argmax(soln[:,:,2],axis=1)*delta_t-time_int print('Time of peak I1: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format( np.average(tpeak),np.median(tpeak), np.percentile(tpeak,loCI),np.percentile(tpeak,upCI))) tpeak=np.argmax(soln[:,:,3],axis=1)*delta_t-time_int print('Time of peak I2: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format( np.average(tpeak),np.median(tpeak),np.percentile(tpeak,loCI),np.percentile(tpeak,upCI))) tpeak=np.argmax(soln[:,:,4],axis=1)*delta_t-time_int print('Time of peak I3: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format( np.average(tpeak),np.median(tpeak),np.percentile(tpeak,loCI),np.percentile(tpeak,upCI))) # Time when all the infections go extinct time_all_extinct = np.array(get_extinction_time(all_cases,0))*delta_t-time_int print('Time of extinction of all infections post intervention: {:4.2f} days [{:4.2f}, {:4.2f}]'.format( np.average(time_all_extinct),np.percentile(time_all_extinct,loCI),np.percentile(time_all_extinct,upCI))) return
def fit_STC(self, prewhiten=False, n_repeats=10, percentile=100., random_seed=2046, verbose=5):
    """
    Spike-triggered Covariance Analysis.

    Parameters
    ==========
    prewhiten: bool
        Whether to pre-whiten X before computing the STC.
    n_repeats: int
        Number of repeats for STC significance test.
    percentile: float
        Valid range of STC significance test.
    random_seed: int
    verbose: int
    """

    def get_stc(X, y, w):
        n = len(X)
        ste = X[y != 0]
        proj = ste - ste * w * w.T
        stc = proj.T @ proj / (n - 1)
        eigvec, eigval, _ = np.linalg.svd(stc)
        return eigvec, eigval

    key = random.PRNGKey(random_seed)

    y = self.y

    if prewhiten:
        if self.compute_mle is False:
            self.XtX = self.X.T @ self.X
            self.w_mle = np.linalg.solve(self.XtX, self.XtY)
        X = np.linalg.solve(self.XtX, self.X.T).T
        w = uvec(self.w_mle)
    else:
        X = self.X
        w = uvec(self.w_sta)

    eigvec, eigval = get_stc(X, y, w)

    if n_repeats:
        print('STC significance test: ')
        eigval_null = []
        for counter in range(n_repeats):
            if verbose:
                if counter % int(verbose) == 0:
                    print(f'  {counter+1:}/{n_repeats}')
            y_randomize = random.permutation(key, y)
            _, eigval_randomize = get_stc(X, y_randomize, w)
            eigval_null.append(eigval_randomize)
        else:
            if verbose:
                print(f'Done.')
        eigval_null = np.vstack(eigval_null)
        max_null, min_null = np.percentile(eigval_null, percentile), np.percentile(
            eigval_null, 100 - percentile)
        mask_sig_pos = eigval > max_null
        mask_sig_neg = eigval < min_null
        mask_sig = np.logical_or(mask_sig_pos, mask_sig_neg)

        self.w_stc = eigvec
        self.w_stc_pos = eigvec[:, mask_sig_pos]
        self.w_stc_neg = eigvec[:, mask_sig_neg]
        self.w_stc_eigval = eigval
        self.w_stc_eigval_mask = mask_sig
        self.w_stc_eigval_pos_mask = mask_sig_pos
        self.w_stc_eigval_neg_mask = mask_sig_neg
        self.w_stc_max_null = max_null
        self.w_stc_min_null = min_null
    else:
        self.w_stc = eigvec
        self.w_stc_eigval = eigval
        self.w_stc_eigval_mask = np.ones_like(eigval).astype(bool)
def plot_cornerplot(results, vars=None, save_name=None): rkey0 = random.PRNGKey(123496) vars = _get_vars(results, vars) ndims = _get_ndims(results, vars) figsize = min(20, max(4, int(2 * ndims))) fig, axs = plt.subplots(ndims, ndims, figsize=(figsize, figsize)) if ndims == 1: axs = [[axs]] nsamples = results.num_samples log_p = results.log_p[:results.num_samples] nbins = int(jnp.sqrt(results.ESS)) + 1 lims = {} dim = 0 for key in vars: # sorted(results.samples.keys()): n1 = tuple_prod(results.samples[key].shape[1:]) for i in range(n1): samples1 = results.samples[key][:results.num_samples, ...].reshape( (nsamples, -1))[:, i] if jnp.std(samples1) == 0.: dim += 1 continue weights = jnp.where(jnp.isfinite(samples1), jnp.exp(log_p), 0.) log_weights = jnp.where(jnp.isfinite(samples1), log_p, -jnp.inf) samples1 = jnp.where(jnp.isfinite(samples1), samples1, 0.) # kde1 = gaussian_kde(samples1, weights=weights, bw_method='silverman') # samples1_resampled = kde1.resample(size=int(results.ESS)) rkey0, rkey = random.split(rkey0, 2) samples1_resampled = resample(rkey, samples1, log_weights, S=int(results.ESS)) binsx = jnp.linspace(*jnp.percentile(samples1_resampled, [0, 100]), 2 * nbins) dim2 = 0 for key2 in vars: # sorted(results.samples.keys()): n2 = tuple_prod(results.samples[key2].shape[1:]) for i2 in range(n2): ax = axs[dim][dim2] if dim2 > dim: dim2 += 1 ax.set_xticks([]) ax.set_xticklabels([]) ax.set_yticks([]) ax.set_yticklabels([]) continue if n2 > 1: title2 = "{}[{}]".format(key2, i2) else: title2 = "{}".format(key2) if n1 > 1: title1 = "{}[{}]".format(key, i) else: title1 = "{}".format(key) ax.set_title('{} {}'.format(title1, title2)) if dim == dim2: # ax.plot(binsx, kde1(binsx)) ax.hist(samples1_resampled, bins='auto', fc='None', edgecolor='black', density=True) sample_mean = jnp.average(samples1, weights=weights) sample_std = jnp.sqrt( jnp.average((samples1 - sample_mean)**2, weights=weights)) ax.set_title( "{:.2f}:{:.2f}:{:.2f}\n{:.2f}+-{:.2f}".format( *jnp.percentile(samples1_resampled, [5, 50, 95]), sample_mean, sample_std)) ax.vlines(sample_mean, *ax.get_ylim(), linestyles='solid', colors='red') ax.vlines([ sample_mean - sample_std, sample_mean + sample_std ], *ax.get_ylim(), linestyles='dotted', colors='red') ax.set_xlim(binsx.min(), binsx.max()) lims[dim] = ax.get_xlim() else: samples2 = results.samples[key2][:results.num_samples, ...].reshape( (nsamples, -1))[:, i2] if jnp.std(samples2) == 0.: dim2 += 1 continue weights = jnp.where(jnp.isfinite(samples2), jnp.exp(log_p), 0.) log_weights = jnp.where(jnp.isfinite(samples2), log_p, -jnp.inf) samples2 = jnp.where(jnp.isfinite(samples2), samples2, 0.) 
# kde2 = gaussian_kde(jnp.stack([samples1, samples2], axis=0), # weights=weights, # bw_method='silverman') # samples2_resampled = kde2.resample(size=int(results.ESS)) rkey0, rkey = random.split(rkey0, 2) samples2_resampled = resample(rkey, jnp.stack( [samples1, samples2], axis=-1), log_weights, S=int(results.ESS)) # norm = plt.Normalize(log_weights.min(), log_weights.max()) # color = jnp.atleast_2d(plt.cm.jet(norm(log_weights))) ax.hist2d(samples2_resampled[:, 1], samples2_resampled[:, 0], bins=(nbins, nbins), density=True, cmap=plt.cm.bone_r) # ax.scatter(samples2_resampled[:, 1], samples2_resampled[:, 0], marker='+', c='black', alpha=0.5) # binsy = jnp.linspace(*jnp.percentile(samples2_resampled[:, 1], [0, 100]), 2 * nbins) # X, Y = jnp.meshgrid(binsx, binsy, indexing='ij') # ax.contour(kde2(jnp.stack([X.flatten(), Y.flatten()], axis=0)).reshape((2 * nbins, 2 * nbins)), # extent=(binsy.min(), binsy.max(), # binsx.min(), binsx.max()), # origin='lower') if dim == ndims - 1: ax.set_xlabel("{}".format(title2)) if dim2 == 0: ax.set_ylabel("{}".format(title1)) dim2 += 1 dim += 1 for dim in range(ndims): for dim2 in range(ndims): if dim == dim2: continue ax = axs[dim][dim2] if ndims > 1 else axs[0] if dim in lims.keys(): ax.set_ylim(lims[dim]) if dim2 in lims.keys(): ax.set_xlim(lims[dim2]) if save_name is not None: fig.savefig(save_name) plt.show()
# Train only gains, keeping all other weights fixed. net = make_net([2048] * 6) only_gains_final_params = train( net, init_params=net.init(random.PRNGKey(0), jnp.zeros((1, 28 * 28))), trainable_predicate=lambda k: kmatch("**/gain", k), log_prefix="only_gain") only_gains_final_params_flat = flatten_params(only_gains_final_params) print(" full model params:") print(tree_map(jnp.shape, only_gains_final_params_flat)) gain_params = { k: v for k, v in only_gains_final_params_flat.items() if kmatch("**/gain", k) } gain_params_flat, unravel = ravel_pytree(gain_params) cutoff = jnp.percentile(jnp.abs(gain_params_flat), config.remove_percentile) # A mask that identifies only those gains that have the largest absolute # value. Only keep the top `(100 - config.remove_percentile)%` gains. Note # that this isn't the traditional LTH approach. It's more accurately described # as a structural lottery ticket. gain_mask = binarize(unravel(jnp.abs(gain_params_flat) > cutoff)) print(tree_map(jnp.sum, gain_mask)) def _lotteryify(k, v): """Take a parameter (key, value pair) and return a new parameter value with only the units in the `gain_mask` retained. This requires removing rows and columns across weight matrices in adjacent layers.""" if kmatch("**/gain", k): return v[gain_mask[k]]
def fit_STC(self, prewhiten=False, n_repeats=10, percentile=100., random_seed=2046, verbose=5): """ Spike-triggered Covariance Analysis. Parameters ========== prewhiten: bool n_repeats: int Number of repeats for STC significance test. percentile: float Valid range of STC significance test. verbose: int random_seed: int """ def get_stc(_X, _y, _w): n = len(_X) ste = _X[_y != 0] proj = ste - ste * _w * _w.T stc = proj.T @ proj / (n - 1) _eigvec, _eigval, _ = jnp.linalg.svd(stc) return _eigvec, _eigval key = random.PRNGKey(random_seed) y = self.y if prewhiten: if self.compute_mle is False: self.XtX = self.X.T @ self.X self.w_mle = jnp.linalg.solve(self.XtX, self.XtY) X = jnp.linalg.solve(self.XtX, self.X.T).T w = uvec(self.w_mle) else: X = self.X w = uvec(self.w_sta) eigvec, eigval = get_stc(X, y, w) self.w_stc = dict() if n_repeats: print('STC significance test: ') eigval_null = [] for counter in range(n_repeats): if verbose: if counter % int(verbose) == 0: print(f' {counter + 1:}/{n_repeats}') y_randomize = random.permutation(key, y) _, eigval_randomize = get_stc(X, y_randomize, w) eigval_null.append(eigval_randomize) else: if verbose: print(f'Done.') eigval_null = jnp.vstack(eigval_null) max_null, min_null = jnp.percentile(eigval_null, percentile), jnp.percentile( eigval_null, 100 - percentile) mask_sig_pos = eigval > max_null mask_sig_neg = eigval < min_null mask_sig = jnp.logical_or(mask_sig_pos, mask_sig_neg) self.w_stc['eigvec'] = eigvec self.w_stc['pos'] = eigvec[:, mask_sig_pos] self.w_stc['neg'] = eigvec[:, mask_sig_neg] self.w_stc['eigval'] = eigval self.w_stc['eigval_mask'] = mask_sig self.w_stc['eigval_pos_mask'] = mask_sig_pos self.w_stc['eigval_neg_mask'] = mask_sig_neg self.w_stc['max_null'] = max_null self.w_stc['min_null'] = min_null else: self.w_stc['eigvec'] = eigvec self.w_stc['eigval'] = eigval self.w_stc['eigval_mask'] = jnp.ones_like(eigval).astype(bool)
def percentile(a, q, axis=None, interpolation='linear', keepdims=False):
    if isinstance(a, JaxArray):
        a = a.value
    if isinstance(q, JaxArray):
        q = q.value
    r = jnp.percentile(a=a, q=q, axis=axis, interpolation=interpolation, keepdims=keepdims)
    return r if axis is None else JaxArray(r)
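# Hedged usage sketch for the wrapper above, assuming JaxArray is the array wrapper
# whose raw data lives in `.value` (as the unwrapping logic suggests).
import jax.numpy as jnp

x = jnp.arange(10.0)
print(percentile(x, 50))                # plain scalar when axis is None
print(percentile(x, [25, 75], axis=0))  # wrapped result when an axis is given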
def hist_bin_fd(x):
    iqr = jnp.subtract(*jnp.percentile(x, [75, 25]))
    return 2.0 * iqr * x.size ** (-1.0 / 3.0)
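# Sketch: hist_bin_fd returns the Freedman-Diaconis bin *width* 2 * IQR * n**(-1/3);
# a bin count can be derived from it as below (a hypothetical follow-up, not part of
# the original helper).
import jax.numpy as jnp
from jax import random

x = random.normal(random.PRNGKey(0), (10000,))
width = hist_bin_fd(x)
n_bins = int(jnp.ceil((x.max() - x.min()) / width))
print(width, n_bins)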
print(f'posterior for {p}') print_summary(post[p], 0.95, False) # PPC # call predictive without specifying new data # so it uses original data post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (int(1e4), )) post_pred = Predictive(m5_3.model, post)(random.PRNGKey(2), M=d.M.values, A=d.A.values) mu = post_pred["mu"] # summarize samples across cases mu_mean = jnp.mean(mu, 0) mu_PI = jnp.percentile(mu, q=(5.5, 94.5), axis=0) ax = plt.subplot(ylim=(float(mu_PI.min()), float(mu_PI.max())), xlabel="Observed divorce", ylabel="Predicted divorce") plt.plot(d.D, mu_mean, "o") x = jnp.linspace(mu_PI.min(), mu_PI.max(), 101) plt.plot(x, x, "--") for i in range(d.shape[0]): plt.plot([d.D[i]] * 2, mu_PI[:, i], "b") fig = plt.gcf() for i in range(d.shape[0]): if d.Loc[i] in ["ID", "UT", "RI", "ME"]: ax.annotate(d.Loc[i], (d.D[i], mu_mean[i]), xytext=(1, 0),