Example #1
def main(args):
    _, fetch_train = load_dataset(UCBADMIT, split="train", shuffle=False)
    dept, male, applications, admit = fetch_train()
    rng_key, rng_key_predict = random.split(random.PRNGKey(1))
    zs = run_inference(dept, male, applications, admit, rng_key, args)
    pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male,
                                      applications)["probs"]
    header = "=" * 30 + "glmm - TRAIN" + "=" * 30
    print_results(header, pred_probs, dept, male, admit / applications)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

    ax.plot(range(1, 13), admit / applications, "o", ms=7, label="actual rate")
    ax.errorbar(
        range(1, 13),
        jnp.mean(pred_probs, 0),
        jnp.std(pred_probs, 0),
        fmt="o",
        c="k",
        mfc="none",
        ms=7,
        elinewidth=1,
        label=r"mean $\pm$ std",
    )
    ax.plot(range(1, 13), jnp.percentile(pred_probs, 5, 0), "k+")
    ax.plot(range(1, 13), jnp.percentile(pred_probs, 95, 0), "k+")
    ax.set(
        xlabel="cases",
        ylabel="admit rate",
        title="Posterior Predictive Check with 90% CI",
    )
    ax.legend()

    plt.savefig("ucbadmit_plot.pdf")
Example #2
def summarize_posterior(preds, ci=96):
    ci_lower = (100 - ci) / 2
    ci_upper = (100 + ci) / 2
    preds_mean = preds.mean(0)
    preds_lower = jnp.percentile(preds, ci_lower, axis=0)
    preds_upper = jnp.percentile(preds, ci_upper, axis=0)
    return preds_mean, preds_lower, preds_upper
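
A quick usage sketch (not from the original project; shapes assumed): `preds` stacks posterior draws along axis 0, so the helper returns per-point means and a central credible band.

import jax.numpy as jnp
from jax import random

preds = random.normal(random.PRNGKey(0), (1000, 50))  # fake posterior draws
mean, lower, upper = summarize_posterior(preds, ci=90)
print(mean.shape, lower.shape, upper.shape)  # (50,) (50,) (50,)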
Example #3
def summary(samples: Dict[str, jnp.ndarray], poisson: bool) -> Dict[str, Dict[str, jnp.ndarray]]:
    """Generate a summary of the samples drawn from the Poisson model.

    Args:
        samples: samples as acquired from `Predictive`'s `get_samples`
        poisson: if True, exponentiate the values (a Poisson model was used)

    Returns:
        dict: summary over the samples
    """
    site_stats = {}
    for k, v in samples.items():
        if poisson:  # Poisson model, thus we exponentiate!
            v = jnp.exp(v)
        else:
            v = v.astype(jnp.float32)  # for percentile to work
        v = jnp.where(v < 0., 0., v)  # clamp negatives to zero (avoid -1 placeholders)
        site_stats[k] = {
            "mean": jnp.mean(v, 0),
            "std": jnp.std(v, 0),
            "5%": jnp.percentile(v, 5., axis=0),
            "25%": jnp.percentile(v, 25., axis=0),
            "75%": jnp.percentile(v, 75., axis=0),
            "95%": jnp.percentile(v, 95., axis=0),
        }
    return site_stats
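
A hedged usage sketch with synthetic draws (names and shapes assumed, not from the original project):

import jax.numpy as jnp
from jax import random

samples = {"obs": random.poisson(random.PRNGKey(0), 3.0, (500, 10))}
stats = summary(samples, poisson=False)  # cast to float, clamp negatives
print(stats["obs"]["mean"].shape, stats["obs"]["95%"].shape)  # (10,) (10,)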
Example #4
def plot_iter_daily_shade(soln_inc,n,ymax=1,scale=1,int=0,Tint=1,loCI=5,upCI=95,plotThis=False,plotName="test"):

  """
  plots the output (cumulative prevalence) from a multiple simulation, with or without an intervention. Shows mean and 95% CI
  soln_inc: 3D array of values for each iteration for each variable at each timepoint
  tvec: 1D vector of timepoints
  n: total population size
  ymax : highest value on y axis, relative to "scale" value (e.g. 0.5 makes ymax=0.5 or 50% for scale=1 or N)
  scale: amount to multiple all frequency values by (e.g. "1" keeps as frequency, "N" turns to absolute values)
  int: Optional, 1 or 0 for whether or not there was an intervention. Defaults to 0
  Tint: Optional, timepoint (days) at which intervention was started
  loCI,upCI: Optional, upper and lower percentiles for confidence intervals. Defaults to 90% interval
  plotThis: True or False, whether a plot will be saved as pdf 
  plotName: string, name of the plot to be saved
  """

  tvec=np.arange(1,np.shape(soln_inc)[1]+1)

  soln_avg=np.average(soln_inc,axis=0)
  soln_loCI=np.percentile(soln_inc,loCI,axis=0)
  soln_upCI=np.percentile(soln_inc,upCI,axis=0)

  # linear scale
  # add averages
  plt.figure(figsize=(2*6.4, 4.0))
  plt.subplot(121)
  plt.plot(tvec,soln_avg*scale)
  plt.legend(['S', 'E', 'I1', 'I2', 'I3', 'D', 'R'],frameon=False,framealpha=0.0,bbox_to_anchor=(1.04,1), loc="upper left")
  # add ranges
  plt.gca().set_prop_cycle(None)
  for i in range(0,7):
    plt.fill_between(tvec,soln_loCI[:,i]*scale,soln_upCI[:,i]*scale,alpha=0.3)
  if int==1:
      plt.plot([Tint,Tint],[0,ymax*scale],'k--')
  plt.ylim([0,ymax*scale])
  plt.xlabel("Time (days)")
  plt.ylabel("Daily incidence")

  # log scale
  # add averages
  plt.subplot(122)
  plt.plot(tvec,soln_avg*scale)
  plt.legend(['S', 'E', 'I1', 'I2', 'I3', 'D', 'R'],frameon=False,framealpha=0.0,bbox_to_anchor=(1.04,1), loc="upper left")
  # add ranges
  plt.gca().set_prop_cycle(None)
  for i in range(0,7):
    plt.fill_between(tvec,soln_loCI[:,i]*scale,soln_upCI[:,i]*scale,alpha=0.3)
  if int==1:
    plt.plot([Tint,Tint],[scale/n,ymax*scale],'k--')
  plt.ylim([scale/n,ymax*scale])
  plt.xlabel("Time (days)")
  plt.ylabel("Daily incidence")
  plt.semilogy()
  plt.tight_layout()
  if plotThis:
    plt.savefig(plotName+'.pdf',bbox_inches='tight')
  plt.show()
Example #5
def test_quick_select():

    arr = jnp.asarray([10, 4, 5, 8, 11, 6, 26, 7]) # even number of points
    assert quick_sort_median(arr) == jnp.percentile(arr, 50, interpolation='higher')
    arr = jnp.asarray([10, 4, 5, 8, 11, 6, 26]) # odd number of points
    assert quick_sort_median(arr) == jnp.percentile(arr, 50, interpolation='higher')
    arr = jnp.asarray([10, 4, 5, 8, 11, 6, 26])  # odd number of points
    assert quick_sort_median(arr) == jnp.median(arr)

    arr = jnp.asarray([10, 4, 5, 8, 11, 6, 26, 7])  # even number of points
    try:
        assert quick_sort_median(arr) == jnp.median(arr)
    except AssertionError:
        print("Not equivalent to median when array has an even size.")
Example #6
def main(args):
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # use dense_mass for better mixing rate
    mcmc = MCMC(NUTS(model, dense_mass=True),
                args.num_warmup, args.num_samples, num_chains=args.num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(PRNGKey(1), N=data.shape[0], y=jnp.log(data))
    mcmc.print_summary()

    # predict populations
    y_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"]
    pop_pred = jnp.exp(y_pred)
    mu, pi = jnp.mean(pop_pred, 0), jnp.percentile(pop_pred, (10, 90), 0)
    plt.plot(year, data[:, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67)
    plt.plot(year, data[:, 1], "bx", label="true lynx")
    plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    plt.plot(year, mu[:, 1], "b--", label="pred lynx")
    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
    plt.gca().set(ylim=(0, 160), xlabel="year", ylabel="population (in thousands)")
    plt.title("Posterior predictive (80% CI) with predator-prey pattern.")
    plt.legend()

    plt.savefig("ode_plot.pdf")
    plt.tight_layout()
Example #7
def _find_binning_thresholds(data,
                             max_bins=256,
                             subsample=200000,
                             random_state=None):
    if 2 > max_bins or max_bins > 256:
        raise ValueError(f'max_bins={max_bins} should be no smaller than 2 '
                         f'and no larger than 256.')
    if random_state is None:
        random_state = int(time.time())
    rng = random.PRNGKey(random_state)
    if subsample is not None and data.shape[0] > subsample:
        subset = random.permutation(rng, data.shape[0])[:subsample]
        data = data[subset]
    dtype = data.dtype
    if dtype.kind != 'f':
        dtype = np.float32

    percentiles = np.linspace(0, 100, num=max_bins + 1)[1:-1]
    binning_thresholds = []
    for f_idx in range(data.shape[1]):
        col_data = np.array(data[:, f_idx], dtype=dtype, order='C')
        distinct_values = onp.unique(col_data)
        if len(distinct_values) <= max_bins:
            midpoints = (distinct_values[:-1] + distinct_values[1:])
            midpoints *= 0.5
        else:
            midpoints = np.percentile(col_data,
                                      percentiles,
                                      interpolation='midpoint').astype(dtype)
        binning_thresholds.append(midpoints)
    return tuple(binning_thresholds)
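
A minimal invocation sketch, assuming the snippet's aliases (`np` = jax.numpy, `onp` = NumPy, `random` = jax.random):

import numpy as onp

data = onp.random.RandomState(0).rand(500, 3).astype(onp.float32)
thresholds = _find_binning_thresholds(data, max_bins=16, random_state=42)
print([t.shape for t in thresholds])  # one (15,) threshold vector per feature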
Example #8
    def plot_samples(self,
                     samples, 
                     plot_fields=['y'],
                     start='2020-03-04',
                     T=None,
                     ax=None,          
                     legend=True,
                     forecast=False,
                     n_samples=0,
                     intervals=[50, 80, 95]):
        '''
        Plotting method for SIR-type models. 
        '''

        
        ax = plt.axes(ax)

        T_data = self.horizon(samples, forecast=forecast)        
        T = T_data if T is None else min(T, T_data) 
        
        fields = {f: 0.0 + self.get(samples, f, forecast=forecast)[:,:T] for f in plot_fields}
        names = {f: self.names[f] for f in plot_fields}
                
        medians = {names[f]: np.median(v, axis=0) for f, v in fields.items()}

        t = pd.date_range(start=start, periods=T, freq='D')

        ax.set_prop_cycle(None)
        colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

        # Plot medians
        df = pd.DataFrame(index=t, data=medians)
        df.plot(ax=ax, legend=legend)
        median_max = df.max().values

        # Plot samples if requested
        if n_samples > 0:
            for i, f in enumerate(fields):
                df = pd.DataFrame(index=t, data=fields[f][:n_samples,:].T)
                df.plot(ax=ax, legend=False, alpha=0.1)
                
        # Plot prediction intervals
        pi_max = 10
        handles = []
        for interval in intervals:
            low=(100.-interval)/2
            high=100.-low
            pred_intervals = {names[f]: np.percentile(v, (low, high), axis=0) for f, v in fields.items()}
            for i, pi in enumerate(pred_intervals.values()):
                h = ax.fill_between(t, pi[0,:], pi[1,:], alpha=0.1, color=colors[i], label=interval)
                handles.append(h)
                pi_max = np.maximum(pi_max, np.nanmax(pi[1,:]))

        
        return median_max, pi_max
Example #9
    def _describe_fit(self, dist):
        description = super()._describe_fit(dist)

        def entry_distance_fn(x, density):
            return abs(1.0 - density / dist.pdf(x))

        distances = vmap(entry_distance_fn)(self.xs, self.densities)
        description["max_distance"] = np.max(distances)
        description["90th_distance"] = np.percentile(distances, 90)
        description["mean_distance"] = np.mean(distances)
        return description
Example #10
def estimate_clipping_thresholds(apply_fn, params, l2_norm_clip_percentile,
                                 graph, labels, subgraphs, estimation_indices,
                                 adjacency_normalization):
    """Estimates gradient clipping thresholds."""
    dummy_state = train_state.TrainState.create(apply_fn=apply_fn,
                                                params=params,
                                                tx=optax.identity())
    grads = compute_updates_for_dp(dummy_state, graph, labels, subgraphs,
                                   estimation_indices, adjacency_normalization)
    grad_norms = jax.tree_map(jax.vmap(jnp.linalg.norm), grads)
    get_percentile = lambda norms: jnp.percentile(norms,
                                                  l2_norm_clip_percentile)
    l2_norms_threshold = jax.tree_map(get_percentile, grad_norms)
    return l2_norms_threshold
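
The same pattern in isolation (a hedged sketch with made-up shapes, not the project's code): compute per-example gradient norms for every pytree leaf, then reduce each leaf to a percentile threshold.

import jax
import jax.numpy as jnp

grads = {"w": jnp.ones((128, 32, 16)), "b": jnp.ones((128, 16))}  # per-example grads
grad_norms = jax.tree_map(jax.vmap(jnp.linalg.norm), grads)       # one (128,) vector per leaf
thresholds = jax.tree_map(lambda n: jnp.percentile(n, 75.0), grad_norms)
print(thresholds)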
Example #11
def prep_utility_fn(params,
                    x_b1,
                    y_b1,
                    utility_type,
                    steps=1000,
                    percentile=0,
                    y_incumb=None):
  """Prepare the utility functions for the score."""
  gp_util = gp.GPUtils()
  params = gp_util.fit_gp(x_b1, y_b1, params, steps=steps)
  if y_incumb is None:
    y_incumb = jnp.percentile(y_b1, percentile, axis=0)
  utility_measure = scores.UtilityMeasure(incumbent=y_incumb, params=params)
  utility_measure_fn = lambda key, x_batch, y_batch: getattr(
      utility_measure, utility_type)(y_batch)
  return utility_measure_fn, params
Example #12
def plot_R0(mcmc_samples, start, ax=None):

    ax = plt.axes(ax)

    # Compute R0 over time
    gamma = mcmc_samples['gamma'][:, None]
    beta = mcmc_samples['beta']
    t = pd.date_range(start=start, periods=beta.shape[1], freq='D')
    R0 = beta / gamma

    pi = np.percentile(R0, (10, 90), axis=0)
    df = pd.DataFrame(index=t, data={'R0': np.median(R0, axis=0)})
    df.plot(style='-o', ax=ax)
    ax.fill_between(t, pi[0, :], pi[1, :], alpha=0.1)

    ax.axhline(1, linestyle='--')
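
A self-contained call sketch with synthetic samples (shapes inferred from the code: `beta` is (num_samples, T), `gamma` is (num_samples,); the pandas/matplotlib imports from the project are assumed):

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
mcmc_samples = {
    "beta": rng.uniform(0.2, 0.5, size=(100, 30)),
    "gamma": rng.uniform(0.1, 0.2, size=(100,)),
}
plot_R0(mcmc_samples, start="2020-03-04")
plt.show()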
Example #13
File: base.py Project: elray1/covid
    def plot_samples(self,
                     samples,
                     plot_fields=['y'],
                     start='2020-03-04',
                     T=None,
                     ax=None,
                     legend=True,
                     forecast=False):
        '''
        Plotting method for SIR-type models. 
        '''

        ax = plt.axes(ax)

        T_data = self.horizon(samples, forecast=forecast)
        T = T_data if T is None else min(T, T_data)

        fields = {
            f: self.get(samples, f, forecast=forecast)[:, :T]
            for f in plot_fields
        }
        names = {f: self.names[f] for f in plot_fields}

        medians = {names[f]: np.median(v, axis=0) for f, v in fields.items()}
        pred_intervals = {
            names[f]: np.percentile(v, (10, 90), axis=0)
            for f, v in fields.items()
        }

        t = pd.date_range(start=start, periods=T, freq='D')

        ax.set_prop_cycle(None)

        # Plot medians
        df = pd.DataFrame(index=t, data=medians)
        df.plot(ax=ax, legend=legend)
        median_max = df.max().values

        # Plot prediction intervals
        pi_max = 10
        for pi in pred_intervals.values():
            ax.fill_between(t, pi[0, :], pi[1, :], alpha=0.1, label='CI')
            pi_max = np.maximum(pi_max, np.nanmax(pi[1, :]))

        return median_max, pi_max
Example #14
def plot_growth_rate(mcmc_samples, start, model=SEIRModel, ax=None):

    ax = plt.axes(ax)

    # Compute growth rate over time
    beta = mcmc_samples['beta']
    sigma = mcmc_samples['sigma'][:, None]
    gamma = mcmc_samples['gamma'][:, None]
    t = pd.date_range(start=start, periods=beta.shape[1], freq='D')

    growth_rate = model.growth_rate((beta, sigma, gamma))

    pi = np.percentile(growth_rate, (10, 90), axis=0)
    df = pd.DataFrame(index=t,
                      data={'growth_rate': np.median(growth_rate, axis=0)})
    df.plot(style='-o', ax=ax)
    ax.fill_between(t, pi[0, :], pi[1, :], alpha=0.1)

    ax.axhline(0, linestyle='--')
Example #15
File: util.py Project: elray1/covid
def plot_R0(mcmc_samples, start):

    fig = plt.figure(figsize=(5,3))
    
    # Compute average R0 over time
    gamma = mcmc_samples['gamma'][:,None]
    beta = mcmc_samples['beta']
    t = pd.date_range(start=start, periods=beta.shape[1], freq='D')
    R0 = beta/gamma

    pi = np.percentile(R0, (10, 90), axis=0)
    df = pd.DataFrame(index=t, data={'R0': np.median(R0, axis=0)})
    df.plot(style='-o')
    plt.fill_between(t, pi[0,:], pi[1,:], alpha=0.1)

    plt.axhline(1, linestyle='--')
    
    #plt.tight_layout()

    return fig
Example #16
    def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
        r"""
        Compute the :math:`q`-th percentile of the tensor along the specified axis.

        Example:

            >>> import pyhf
            >>> import jax.numpy as jnp
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
            >>> pyhf.tensorlib.percentile(a, 50)
            DeviceArray(3.5, dtype=float64)
            >>> pyhf.tensorlib.percentile(a, 50, axis=1)
            DeviceArray([7., 2.], dtype=float64)

        Args:
            tensor_in (`tensor`): The tensor containing the data
            q (:obj:`float` or `tensor`): The :math:`q`-th percentile to compute
            axis (`number` or `tensor`): The dimensions along which to compute
            interpolation (:obj:`str`): The interpolation method to use when the
             desired percentile lies between two data points ``i < j``:

                - ``'linear'``: ``i + (j - i) * fraction``, where ``fraction`` is the
                  fractional part of the index surrounded by ``i`` and ``j``.

                - ``'lower'``: ``i``.

                - ``'higher'``: ``j``.

                - ``'midpoint'``: ``(i + j) / 2``.

                - ``'nearest'``: ``i`` or ``j``, whichever is nearest.

        Returns:
            JAX ndarray: The value of the :math:`q`-th percentile of the tensor along the specified axis.

        """
        return jnp.percentile(tensor_in, q, axis=axis, interpolation=interpolation)
Example #17
def test_tomographic_weight_rel_err():
    import pylab as plt
    from jax import jit

    for S in range(5, 30, 5):

        @jit
        def tomo_weight(gamma, x1, x2, p1, p2):
            return tomographic_weight_function(gamma, x1, x2, p1, p2, S=S)

        @jit
        def tomo_weight_ref(gamma, x1, x2, p1, p2):
            return tomographic_weight_function(gamma, x1, x2, p1, p2, S=150)

        rel_error = []
        for i in range(400):
            keys = random.split(random.PRNGKey(i), 6)
            x1 = jnp.concatenate(
                [4. * random.uniform(keys[0], shape=(2, )),
                 jnp.zeros((1, ))],
                axis=-1)
            p1 = jnp.concatenate([
                4. * jnp.pi / 180. *
                random.uniform(keys[1], shape=(2, ), minval=-1, maxval=1),
                jnp.ones((1, ))
            ],
                                 axis=-1)
            p1 = 4 * p1 / jnp.linalg.norm(p1, axis=-1, keepdims=True)

            x2 = jnp.concatenate(
                [4. * random.uniform(keys[2], shape=(2, )),
                 jnp.zeros((1, ))],
                axis=-1)
            p2 = jnp.concatenate([
                4. * jnp.pi / 180. *
                random.uniform(keys[3], shape=(2, ), minval=-1, maxval=1),
                jnp.ones((1, ))
            ],
                                 axis=-1)
            p2 = 4 * p2 / jnp.linalg.norm(p2, axis=-1, keepdims=True)

            # x1 = random.normal(keys[0], shape_dict=(2,))
            # p1 = random.normal(keys[1], shape_dict=(2,))
            # x2 = random.normal(keys[2], shape_dict=(2,))
            # p2 = random.normal(keys[3], shape_dict=(2,))

            t1 = random.uniform(keys[4], shape=(10000, ))
            t2 = random.uniform(keys[5], shape=(10000, ))
            u1 = x1 + t1[:, None] * p1
            u2 = x2 + t2[:, None] * p2
            gamma = jnp.linalg.norm(u1 - u2, axis=1)**2
            hist, bins = jnp.histogram(gamma.flatten(), density=True, bins=100)
            bins = jnp.linspace(bins.min(), bins.max(), 20)
            w = tomo_weight(bins, x1, x2, p1, p2)
            w_ref = tomo_weight_ref(bins, x1, x2, p1, p2)
            rel_error.append(jnp.max(jnp.abs(w - w_ref)) / jnp.max(w_ref))
        rel_error = jnp.array(rel_error)
        plt.hist(rel_error, bins='auto')
        plt.title("{} : {:.2f}|{:.2f}|{:.2f}".format(
            S, *jnp.percentile(rel_error, [5, 50, 95])))
        plt.show()
Example #18
_, fetch_train = load_dataset(UCBADMIT, split='train', shuffle=False)
dept, male, applications, admit = fetch_train()
rng_key, rng_key_predict = random.split(random.PRNGKey(1))


kernel = NUTS(glmm)
mcmc = MCMC(kernel, args.num_warmup, args.num_samples, args.num_chains,
            progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)

mcmc.run(rng_key, dept, male, applications, admit)

zs = mcmc.get_samples()

pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male, applications)['probs']

fig, ax = plt.subplots(1, 1)

ax.plot(range(1, 13), admit / applications, "o", ms=7, label="actual rate")
ax.errorbar(range(1, 13), np.mean(pred_probs, 0), np.std(pred_probs, 0),
            fmt="o", c="k", mfc="none", ms=7, elinewidth=1, label=r"mean $\pm$ std")
ax.plot(range(1, 13), np.percentile(pred_probs, 5, 0), "k+")
ax.plot(range(1, 13), np.percentile(pred_probs, 95, 0), "k+")
ax.set(xlabel="cases", ylabel="admit rate", title="Posterior Predictive Check with 90% CI")
ax.legend()

plt.savefig("ucbadmit_plot.pdf")
plt.tight_layout()


Example #19
df_data.index = year
# use dense_mass for better mixing rate
mcmc = MCMC(
    NUTS(model, dense_mass=True),
    num_warmup,
    num_samples,
    num_chains=num_chains,
    progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
mcmc.run(PRNGKey(1), N=data.shape[0], y=jnp.log(data))
mcmc.print_summary()

# predict populations
y_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"]
pop_pred = jnp.exp(y_pred)
mu, pi = jnp.mean(pop_pred, 0), jnp.percentile(pop_pred, (10, 90), 0)
plt.plot(year,
         data[:, 0],
         "ko",
         mfc="none",
         ms=4,
         label="true hare",
         alpha=0.67)
plt.plot(year, data[:, 1], "bx", label="true lynx")
plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
plt.plot(year, mu[:, 1], "b--", label="pred lynx")
plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
plt.gca().set(ylim=(0, 160), xlabel="year", ylabel="population (in thousands)")
plt.title("Posterior predictive (80% CI) with predator-prey pattern.")
plt.legend()
Example #20
    def sample_answer(self, rng_key, distribution: dist.Distribution):
        num_samples = 100
        samples = distribution.sample(key=rng_key,
                                      sample_shape=(num_samples, ))

        return jnp.percentile(samples, self.percentile())
Example #21
def get_polynomial_form():
    """
    The polynomial form of log P(|n + t1*w1 - t2*w2|^2 < lambda') is assumed to be:

    c_i = Q_ij p_j
    log_cdf = c_i g_i = g_i Q_ij p_j = Tr(Q @ (p g))
    log_cdf_k = g_ki Q_ij p_kj
    Returns:
        The fitted coefficient matrix Q averaged over SGD restarts, shape (K, 7).
    """
    from jax.scipy.optimize import minimize
    from jax import jit, value_and_grad
    from jax.lax import scan
    import pylab as plt

    def generate_single_data(key):
        """
        Generate a physical set of:

        n = x1-x2/|x1-x2| is a unit vector.
        w1 = p1 / |x1-x2|
        w2 = p2 / |x1-x2|
        lambda' = lambda / |x1-x2|^2

        Args:
            key:

        Returns:

        """
        keys = random.split(key, 6)
        n = random.normal(keys[0], shape=(3, ))
        n = n / jnp.linalg.norm(n)

        w1 = random.normal(keys[1], shape=(3, ))
        w1 = w1 / jnp.linalg.norm(w1)
        w1 = w1 * random.uniform(keys[2], minval=0., maxval=10.)

        w2 = random.normal(keys[3], shape=(3, ))
        w2 = w2 / jnp.linalg.norm(w2)
        w2 = w2 * random.uniform(keys[4], minval=0., maxval=10.)

        gamma_prime = jnp.linspace(
            0., 10., 100)  # random.uniform(keys[5],minval=0.,maxval=10.)**2

        cdf_ref = cumulative_tomographic_weight_dimensionless_function(
            gamma_prime, n, w1, w2, S=150)  # /h**2
        return n, w1, w2, gamma_prime, cdf_ref

    data = jit(vmap(generate_single_data))(random.split(
        random.PRNGKey(12340985), 100))

    # print(data[-1])
    def loss(Q):
        def single_loss(single_datum):
            n, w1, w2, gamma_prime, cdf_ref = single_datum
            return (
                vmap(lambda gamma_prime:
                     cumulative_tomographic_weight_dimensionless_polynomial(
                         Q, gamma_prime, n, w1, w2))(gamma_prime) - cdf_ref)**2

        return jnp.mean(vmap(single_loss)(data))

    K = 3
    Q0 = 0.01 * random.normal(random.PRNGKey(0), shape=(K * 7, ))
    print(jit(loss)(Q0))

    @jit
    def do_minimize():
        results = minimize(loss,
                           Q0,
                           method='BFGS',
                           options=dict(gtol=1e-8, line_search_maxiter=100))
        print(results.message)
        return results.x.reshape(
            (K, 7)
        ), results.status, results.fun, results.nfev, results.nit, results.jac

    @jit
    def do_sgd(key):
        def body(state, X):
            (Q, ) = state
            (key, ) = X
            n, w1, w2, gamma_prime, cdf_ref = generate_single_data(key)

            def loss(Q):
                return jnp.mean((vmap(
                    lambda gamma_prime:
                    cumulative_tomographic_weight_dimensionless_polynomial(
                        Q, gamma_prime, n, w1, w2))(gamma_prime) -
                                 cdf_ref)**2)  # + 0.1*jnp.mean(Q**2)

            f, g = value_and_grad(loss)(Q)
            Q = Q - 0.00000001 * g
            return (Q, ), (f, )

        (Q, ), (values, ) = scan(body, (Q0, ), (random.split(key, 1000), ))
        return Q.reshape((-1, 7)), values

    # results = do_minimize()
    Q, values = vmap(do_sgd)(random.split(random.PRNGKey(12456), 100))
    print('Qmean', Q.mean(0))
    print('Qstd', Q.std(0))

    f = values.mean(0)
    fstd = values.std(0)
    plt.plot(jnp.percentile(values, 50, axis=0))
    plt.plot(jnp.percentile(values, 85, axis=0), ls='dotted', c='black')
    plt.plot(jnp.percentile(values, 15, axis=0), ls='dotted', c='black')
    plt.show()

    # print(results)
    return Q.mean(0)
Example #22
def get_peaks_iter_daily(soln_inc,int=0,Tint=0,loCI=5,upCI=95):

  """
  calculates the peak daily incidence for a multiple runs, with or without an intervention
  soln_inc: 3D array of values for each iteration for each variable at each timepoint
  ymax : highest value on y axis, relative to "scale" value (e.g. 0.5 makes ymax=0.5 or 50% for scale=1 or N)
  scale: amount to multiple all frequency values by (e.g. "1" keeps as frequency, "N" turns to absolute values)
  int: Optional, 1 or 0 for whether or not there was an intervention. Defaults to 0
  Tint: Optional, timepoint (days) at which intervention was started
  loCI,upCI: Optional, upper and lower percentiles for confidence intervals. Defaults to 90% interval
  """

  if int==0:
    time_int=0
  else:
    time_int=Tint

  # Peak incidence
  peaks=np.amax(soln_inc[:,:,2],axis=1)
  print('Peak daily I1: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI)))
  peaks=np.amax(soln_inc[:,:,3],axis=1)
  print('Peak daily I2: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI)))
  peaks=np.amax(soln_inc[:,:,4],axis=1)
  print('Peak daily I3: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI)))
  peaks=np.amax(soln_inc[:,:,5],axis=1)
  print('Peak daily deaths: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI)))

  # Timing of peak incidence
  tpeak=np.argmax(soln_inc[:,:,2],axis=1)+1.0-time_int
  print('Time of peak I1: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format(
      np.average(tpeak),np.median(tpeak),np.percentile(tpeak,loCI),np.percentile(tpeak,upCI)))
  tpeak=np.argmax(soln_inc[:,:,3],axis=1)+1.0-time_int
  print('Time of peak I2: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format(
      np.average(tpeak),np.median(tpeak),np.percentile(tpeak,loCI),np.percentile(tpeak,upCI)))
  tpeak=np.argmax(soln_inc[:,:,4],axis=1)+1.0-time_int
  print('Time of peak I3: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format(
      np.average(tpeak),np.median(tpeak),np.percentile(tpeak,loCI),np.percentile(tpeak,upCI)))
  tpeak=np.argmax(soln_inc[:,:,5],axis=1)+1.0-time_int
  print('Time of peak deaths: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format(
      np.average(tpeak),np.median(tpeak),np.percentile(tpeak,loCI),np.percentile(tpeak,upCI)))

  return
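
A hedged driver sketch: `soln_inc` is assumed to hold per-iteration fractions for the 7 model variables at each timepoint.

import numpy as np

soln_inc = np.random.default_rng(0).random((50, 120, 7)) * 0.01
get_peaks_iter_daily(soln_inc, int=1, Tint=30)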
Example #23
def get_peaks_iter(soln,tvec,int=0,Tint=0,loCI=5,upCI=95):

  """
  calculates the peak prevalence for a multiple runs, with or without an intervention
  soln: 3D array of values for each iteration for each variable at each timepoint
  tvec: 1D vector of timepoints
  ymax : highest value on y axis, relative to "scale" value (e.g. 0.5 makes ymax=0.5 or 50% for scale=1 or N)
  scale: amount to multiple all frequency values by (e.g. "1" keeps as frequency, "N" turns to absolute values)
  int: Optional, 1 or 0 for whether or not there was an intervention. Defaults to 0
  Tint: Optional, timepoint (days) at which intervention was started
  loCI,upCI: Optional, upper and lower percentiles for confidence intervals. Defaults to 90% interval
  """

  delta_t=tvec[1]-tvec[0]

  if int==0:
    time_int=0
  else:
    time_int=Tint

  all_cases=soln[:,:,1]+soln[:,:,2]+soln[:,:,3]+soln[:,:,4]

  # Final values
  print('Final recovered: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(soln[:,-1,6]), 100*np.percentile(soln[:,-1,6],loCI), 100*np.percentile(soln[:,-1,6],upCI)))
  print('Final deaths: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(soln[:,-1,5]), 100*np.percentile(soln[:,-1,5],loCI), 100*np.percentile(soln[:,-1,5],upCI)))
  print('Remaining infections: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100*np.average(all_cases[:,-1]),100*np.percentile(all_cases[:,-1],loCI),100*np.percentile(all_cases[:,-1],upCI)))

  # Peak prevalence
  peaks=np.amax(soln[:,:,2],axis=1)
  print('Peak I1: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI)))
  peaks=np.amax(soln[:,:,3],axis=1)
  print('Peak I2: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI)))
  peaks=np.amax(soln[:,:,4],axis=1)
  print('Peak I3: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI)))
  
  # Timing of peaks
  tpeak=np.argmax(soln[:,:,2],axis=1)*delta_t-time_int
  print('Time of peak I1: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format(
      np.average(tpeak),np.median(tpeak), np.percentile(tpeak,loCI),np.percentile(tpeak,upCI)))
  tpeak=np.argmax(soln[:,:,3],axis=1)*delta_t-time_int
  print('Time of peak I2: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format(
      np.average(tpeak),np.median(tpeak),np.percentile(tpeak,loCI),np.percentile(tpeak,upCI)))
  tpeak=np.argmax(soln[:,:,4],axis=1)*delta_t-time_int
  print('Time of peak I3: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format(
      np.average(tpeak),np.median(tpeak),np.percentile(tpeak,loCI),np.percentile(tpeak,upCI)))
  
  # Time when all the infections go extinct
  time_all_extinct = np.array(get_extinction_time(all_cases,0))*delta_t-time_int

  print('Time of extinction of all infections post intervention: {:4.2f} days  [{:4.2f}, {:4.2f}]'.format(
      np.average(time_all_extinct),np.percentile(time_all_extinct,loCI),np.percentile(time_all_extinct,upCI)))
  
  return
Example #24
    def fit_STC(self,
                prewhiten=False,
                n_repeats=10,
                percentile=100.,
                random_seed=2046,
                verbose=5):
        """

        Spike-triggered Covariance Analysis.

        Parameters
        ==========

        transform: None or Str
            * None - Original X is used
            * 'whiten' - pre-whiten X
            * 'spline' - pre-whiten and smooth X by spline

        n_repeats: int
            Number of repeats for STC significance test.

        percentile: float
            Valid range of STC significance test.

        """
        def get_stc(X, y, w):

            n = len(X)
            ste = X[y != 0]
            proj = ste - ste * w * w.T
            stc = proj.T @ proj / (n - 1)

            eigvec, eigval, _ = np.linalg.svd(stc)

            return eigvec, eigval

        key = random.PRNGKey(random_seed)

        y = self.y

        if prewhiten:

            if self.compute_mle is False:
                self.XtX = self.X.T @ self.X
                self.w_mle = np.linalg.solve(self.XtX, self.XtY)

            X = np.linalg.solve(self.XtX, self.X.T).T
            w = uvec(self.w_mle)

        else:
            X = self.X
            w = uvec(self.w_sta)

        eigvec, eigval = get_stc(X, y, w)
        if n_repeats:
            print('STC significance test: ')
            eigval_null = []
            for counter in range(n_repeats):
                if verbose:
                    if counter % int(verbose) == 0:
                        print(f'  {counter+1:}/{n_repeats}')

                y_randomize = random.permutation(key, y)
                _, eigval_randomize = get_stc(X, y_randomize, w)
                eigval_null.append(eigval_randomize)
            else:
                if verbose:
                    print(f'Done.')
            eigval_null = np.vstack(eigval_null)
            max_null, min_null = np.percentile(eigval_null,
                                               percentile), np.percentile(
                                                   eigval_null,
                                                   100 - percentile)
            mask_sig_pos = eigval > max_null
            mask_sig_neg = eigval < min_null
            mask_sig = np.logical_or(mask_sig_pos, mask_sig_neg)

            self.w_stc = eigvec
            self.w_stc_pos = eigvec[:, mask_sig_pos]
            self.w_stc_neg = eigvec[:, mask_sig_neg]

            self.w_stc_eigval = eigval
            self.w_stc_eigval_mask = mask_sig
            self.w_stc_eigval_pos_mask = mask_sig_pos
            self.w_stc_eigval_neg_mask = mask_sig_neg

            self.w_stc_max_null = max_null
            self.w_stc_min_null = min_null

        else:
            self.w_stc = eigvec
            self.w_stc_eigval = eigval
            self.w_stc_eigval_mask = np.ones_like(eigval).astype(bool)
Example #25
def plot_cornerplot(results, vars=None, save_name=None):
    rkey0 = random.PRNGKey(123496)
    vars = _get_vars(results, vars)
    ndims = _get_ndims(results, vars)
    figsize = min(20, max(4, int(2 * ndims)))
    fig, axs = plt.subplots(ndims, ndims, figsize=(figsize, figsize))
    if ndims == 1:
        axs = [[axs]]
    nsamples = results.num_samples
    log_p = results.log_p[:results.num_samples]
    nbins = int(jnp.sqrt(results.ESS)) + 1
    lims = {}
    dim = 0
    for key in vars:  # sorted(results.samples.keys()):
        n1 = tuple_prod(results.samples[key].shape[1:])
        for i in range(n1):
            samples1 = results.samples[key][:results.num_samples, ...].reshape(
                (nsamples, -1))[:, i]
            if jnp.std(samples1) == 0.:
                dim += 1
                continue
            weights = jnp.where(jnp.isfinite(samples1), jnp.exp(log_p), 0.)
            log_weights = jnp.where(jnp.isfinite(samples1), log_p, -jnp.inf)
            samples1 = jnp.where(jnp.isfinite(samples1), samples1, 0.)
            # kde1 = gaussian_kde(samples1, weights=weights, bw_method='silverman')
            # samples1_resampled = kde1.resample(size=int(results.ESS))
            rkey0, rkey = random.split(rkey0, 2)
            samples1_resampled = resample(rkey,
                                          samples1,
                                          log_weights,
                                          S=int(results.ESS))
            binsx = jnp.linspace(*jnp.percentile(samples1_resampled, [0, 100]),
                                 2 * nbins)
            dim2 = 0
            for key2 in vars:  # sorted(results.samples.keys()):
                n2 = tuple_prod(results.samples[key2].shape[1:])
                for i2 in range(n2):
                    ax = axs[dim][dim2]
                    if dim2 > dim:
                        dim2 += 1
                        ax.set_xticks([])
                        ax.set_xticklabels([])
                        ax.set_yticks([])
                        ax.set_yticklabels([])
                        continue
                    if n2 > 1:
                        title2 = "{}[{}]".format(key2, i2)
                    else:
                        title2 = "{}".format(key2)
                    if n1 > 1:
                        title1 = "{}[{}]".format(key, i)
                    else:
                        title1 = "{}".format(key)
                    ax.set_title('{} {}'.format(title1, title2))
                    if dim == dim2:
                        # ax.plot(binsx, kde1(binsx))
                        ax.hist(samples1_resampled,
                                bins='auto',
                                fc='None',
                                edgecolor='black',
                                density=True)
                        sample_mean = jnp.average(samples1, weights=weights)
                        sample_std = jnp.sqrt(
                            jnp.average((samples1 - sample_mean)**2,
                                        weights=weights))
                        ax.set_title(
                            "{:.2f}:{:.2f}:{:.2f}\n{:.2f}+-{:.2f}".format(
                                *jnp.percentile(samples1_resampled,
                                                [5, 50, 95]), sample_mean,
                                sample_std))
                        ax.vlines(sample_mean,
                                  *ax.get_ylim(),
                                  linestyles='solid',
                                  colors='red')
                        ax.vlines([
                            sample_mean - sample_std, sample_mean + sample_std
                        ],
                                  *ax.get_ylim(),
                                  linestyles='dotted',
                                  colors='red')
                        ax.set_xlim(binsx.min(), binsx.max())
                        lims[dim] = ax.get_xlim()
                    else:
                        samples2 = results.samples[key2][:results.num_samples,
                                                         ...].reshape(
                                                             (nsamples,
                                                              -1))[:, i2]
                        if jnp.std(samples2) == 0.:
                            dim2 += 1
                            continue
                        weights = jnp.where(jnp.isfinite(samples2),
                                            jnp.exp(log_p), 0.)
                        log_weights = jnp.where(jnp.isfinite(samples2), log_p,
                                                -jnp.inf)
                        samples2 = jnp.where(jnp.isfinite(samples2), samples2,
                                             0.)
                        # kde2 = gaussian_kde(jnp.stack([samples1, samples2], axis=0),
                        #                     weights=weights,
                        #                     bw_method='silverman')
                        # samples2_resampled = kde2.resample(size=int(results.ESS))
                        rkey0, rkey = random.split(rkey0, 2)
                        samples2_resampled = resample(rkey,
                                                      jnp.stack(
                                                          [samples1, samples2],
                                                          axis=-1),
                                                      log_weights,
                                                      S=int(results.ESS))
                        # norm = plt.Normalize(log_weights.min(), log_weights.max())
                        # color = jnp.atleast_2d(plt.cm.jet(norm(log_weights)))
                        ax.hist2d(samples2_resampled[:, 1],
                                  samples2_resampled[:, 0],
                                  bins=(nbins, nbins),
                                  density=True,
                                  cmap=plt.cm.bone_r)
                        # ax.scatter(samples2_resampled[:, 1], samples2_resampled[:, 0], marker='+', c='black', alpha=0.5)
                        # binsy = jnp.linspace(*jnp.percentile(samples2_resampled[:, 1], [0, 100]), 2 * nbins)
                        # X, Y = jnp.meshgrid(binsx, binsy, indexing='ij')
                        # ax.contour(kde2(jnp.stack([X.flatten(), Y.flatten()], axis=0)).reshape((2 * nbins, 2 * nbins)),
                        #            extent=(binsy.min(), binsy.max(),
                        #                    binsx.min(), binsx.max()),
                        #            origin='lower')
                    if dim == ndims - 1:
                        ax.set_xlabel("{}".format(title2))
                    if dim2 == 0:
                        ax.set_ylabel("{}".format(title1))

                    dim2 += 1
            dim += 1
    for dim in range(ndims):
        for dim2 in range(ndims):
            if dim == dim2:
                continue
            ax = axs[dim][dim2] if ndims > 1 else axs[0]
            if dim in lims.keys():
                ax.set_ylim(lims[dim])
            if dim2 in lims.keys():
                ax.set_xlim(lims[dim2])
    if save_name is not None:
        fig.savefig(save_name)
    plt.show()
Example #26
# Train only gains, keeping all other weights fixed.
net = make_net([2048] * 6)
only_gains_final_params = train(
    net,
    init_params=net.init(random.PRNGKey(0), jnp.zeros((1, 28 * 28))),
    trainable_predicate=lambda k: kmatch("**/gain", k),
    log_prefix="only_gain")
only_gains_final_params_flat = flatten_params(only_gains_final_params)
print("  full model params:")
print(tree_map(jnp.shape, only_gains_final_params_flat))
gain_params = {
    k: v
    for k, v in only_gains_final_params_flat.items() if kmatch("**/gain", k)
}
gain_params_flat, unravel = ravel_pytree(gain_params)
cutoff = jnp.percentile(jnp.abs(gain_params_flat), config.remove_percentile)
# A mask that identifies only those gains that have the largest absolute
# value. Only keep the top `(100 - config.remove_percentile)%` gains. Note
# that this isn't the traditional LTH approach. It's more accurately described
# as a structural lottery ticket.
gain_mask = binarize(unravel(jnp.abs(gain_params_flat) > cutoff))
print(tree_map(jnp.sum, gain_mask))


def _lotteryify(k, v):
    """Take a parameter (key, value pair) and return a new parameter value with
  only the units in the `gain_mask` retained. This requires removing rows and
  columns across weight matrices in adjacent layers."""
    if kmatch("**/gain", k):
        return v[gain_mask[k]]
Example #27
    def fit_STC(self,
                prewhiten=False,
                n_repeats=10,
                percentile=100.,
                random_seed=2046,
                verbose=5):
        """

        Spike-triggered Covariance Analysis.

        Parameters
        ==========

        prewhiten: bool

        n_repeats: int
            Number of repeats for STC significance test.

        percentile: float
            Valid range of STC significance test.

        verbose: int
        random_seed: int
        """
        def get_stc(_X, _y, _w):

            n = len(_X)
            ste = _X[_y != 0]
            proj = ste - ste * _w * _w.T
            stc = proj.T @ proj / (n - 1)

            _eigvec, _eigval, _ = jnp.linalg.svd(stc)

            return _eigvec, _eigval

        key = random.PRNGKey(random_seed)

        y = self.y

        if prewhiten:

            if self.compute_mle is False:
                self.XtX = self.X.T @ self.X
                self.w_mle = jnp.linalg.solve(self.XtX, self.XtY)

            X = jnp.linalg.solve(self.XtX, self.X.T).T
            w = uvec(self.w_mle)

        else:
            X = self.X
            w = uvec(self.w_sta)

        eigvec, eigval = get_stc(X, y, w)

        self.w_stc = dict()
        if n_repeats:
            print('STC significance test: ')
            eigval_null = []
            for counter in range(n_repeats):
                if verbose:
                    if counter % int(verbose) == 0:
                        print(f'  {counter + 1:}/{n_repeats}')

                y_randomize = random.permutation(key, y)
                _, eigval_randomize = get_stc(X, y_randomize, w)
                eigval_null.append(eigval_randomize)
            else:
                if verbose:
                    print(f'Done.')
            eigval_null = jnp.vstack(eigval_null)
            max_null, min_null = jnp.percentile(eigval_null,
                                                percentile), jnp.percentile(
                                                    eigval_null,
                                                    100 - percentile)
            mask_sig_pos = eigval > max_null
            mask_sig_neg = eigval < min_null
            mask_sig = jnp.logical_or(mask_sig_pos, mask_sig_neg)

            self.w_stc['eigvec'] = eigvec
            self.w_stc['pos'] = eigvec[:, mask_sig_pos]
            self.w_stc['neg'] = eigvec[:, mask_sig_neg]

            self.w_stc['eigval'] = eigval
            self.w_stc['eigval_mask'] = mask_sig
            self.w_stc['eigval_pos_mask'] = mask_sig_pos
            self.w_stc['eigval_neg_mask'] = mask_sig_neg

            self.w_stc['max_null'] = max_null
            self.w_stc['min_null'] = min_null

        else:
            self.w_stc['eigvec'] = eigvec
            self.w_stc['eigval'] = eigval
            self.w_stc['eigval_mask'] = jnp.ones_like(eigval).astype(bool)
Example #28
def percentile(a, q, axis=None, interpolation='linear', keepdims=False):
  if isinstance(a, JaxArray): a = a.value
  if isinstance(q, JaxArray): q = q.value
  r = jnp.percentile(a=a, q=q, axis=axis, interpolation=interpolation, keepdims=keepdims)
  return r if axis is None else JaxArray(r)
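
Usage sketch, assuming the project's `JaxArray` wrapper is in scope:

import jax.numpy as jnp

a = JaxArray(jnp.arange(12.).reshape(3, 4))
print(percentile(a, 50))          # global median, returned unwrapped
print(percentile(a, 50, axis=0))  # per-column medians, wrapped as JaxArray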
Example #29
def hist_bin_fd(x):
    iqr = jnp.subtract(*jnp.percentile(x, [75, 25]))
    return 2.0 * iqr * x.size**(-1.0 / 3.0)
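
hist_bin_fd implements the Freedman-Diaconis rule, bin width h = 2*IQR*n^(-1/3); a quick sketch turning it into a bin count:

from jax import random

x = random.normal(random.PRNGKey(0), (1000,))
h = hist_bin_fd(x)
n_bins = int((x.max() - x.min()) / h)
print(h, n_bins)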
# (truncated fragment: tail of a loop printing a summary for each posterior site p)
    print(f'posterior for {p}')
    print_summary(post[p], 0.95, False)

# PPC

# call predictive without specifying new data
# so it uses original data
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (int(1e4), ))
post_pred = Predictive(m5_3.model, post)(random.PRNGKey(2),
                                         M=d.M.values,
                                         A=d.A.values)
mu = post_pred["mu"]

# summarize samples across cases
mu_mean = jnp.mean(mu, 0)
mu_PI = jnp.percentile(mu, q=(5.5, 94.5), axis=0)

ax = plt.subplot(ylim=(float(mu_PI.min()), float(mu_PI.max())),
                 xlabel="Observed divorce",
                 ylabel="Predicted divorce")
plt.plot(d.D, mu_mean, "o")
x = jnp.linspace(mu_PI.min(), mu_PI.max(), 101)
plt.plot(x, x, "--")
for i in range(d.shape[0]):
    plt.plot([d.D[i]] * 2, mu_PI[:, i], "b")
fig = plt.gcf()

for i in range(d.shape[0]):
    if d.Loc[i] in ["ID", "UT", "RI", "ME"]:
        ax.annotate(d.Loc[i], (d.D[i], mu_mean[i]),
                    xytext=(1, 0),