예제 #1
0
def ecdf(vals: np.ndarray, num_points: int = None) -> ArrayTup:
    """Return the empirical CDF of ``vals`` sampled at evenly spaced probabilities.

    Args:
        vals: Sample values; ``np.quantile`` flattens them since no axis is given.
        num_points: Number of probability levels, spread uniformly over [0, 1]
            with both endpoints included. Defaults to ``len(vals)``.

    Returns:
        A ``(t, cdf)`` pair. ``cdf`` holds the probability levels and ``t`` the
        matching sample quantiles, so each ``(t[i], cdf[i])`` is a point on the
        empirical distribution function.
    """
    n = len(vals) if num_points is None else num_points
    probs = np.linspace(0, 1, n)
    return np.quantile(vals, probs), probs
예제 #2
0
파일: _array.py 프로젝트: coax-dev/coax
def get_magnitude_quantiles(pytree, key_prefix=''):
    r"""

    Summarize a :doc:`pytree <pytrees>` by the quantiles of the magnitudes of all its entries.

    This is meant as a quick, high-level diagnostic. The leaves of the pytree are extracted and
    flattened, the element-wise absolute value is taken, and everything is treated as one long
    flat array. The minimum, the three quartiles and the maximum of that array are reported.

    Parameters
    ----------
    pytree : a pytree with ndarray leaves

        A typical example is a pytree of model params (weights) or gradients with respect to such
        model params.

    key_prefix : str, optional

        A prefix prepended to every key of the output dict.

    Returns
    -------
    magnitude_quantiles : dict

        A dict with keys ``['min', 'p25', 'p50', 'p75', 'max']`` (each preceded by
        ``key_prefix``). The values are the non-negative magnitude quantiles, as the scalar
        elements of the array returned by ``jnp.quantile``.

    """
    names = ('min', 'p25', 'p50', 'p75', 'max')
    levels = jnp.array([0, 0.25, 0.5, 0.75, 1])
    magnitudes = jnp.abs(tree_ravel(pytree))
    values = jnp.quantile(magnitudes, levels)
    return {f'{key_prefix}{name}': value for name, value in zip(names, values)}
예제 #3
0
 def clean_chain_ar(self, abc_scenario: ABCScenario, chain_state: cdict):
     """Keep only the closest-matching fraction of the chain and record the cutoff.

     The cutoff is the ``self.parameters.acceptance_rate`` quantile of
     ``chain_state.distance``; it is stored as a Python float on
     ``self.parameters.threshold``. Samples whose distance is strictly below the
     cutoff get log-weight 0, all others get -inf. ``chain_state`` is updated in
     place and returned; ``abc_scenario`` is not used here.
     """
     cutoff = jnp.quantile(chain_state.distance, self.parameters.acceptance_rate)
     self.parameters.threshold = float(cutoff)
     accepted = chain_state.distance < cutoff
     chain_state.log_weight = jnp.where(accepted, 0., -jnp.inf)
     return chain_state
예제 #4
0
def get_quantiles(x, n_quantiles: int = 1_000):
    """Evaluate column-wise quantiles of ``x`` on an evenly spaced probability grid.

    The grid spans [0, 1] inclusive and has ``max(n_quantiles, x.shape[0])``
    points, so it is never coarser than the number of rows of ``x``.

    Returns:
        ``(quantiles, references)``: ``references`` is the probability grid and
        ``quantiles`` is ``np.quantile(x, references, axis=0)``.
    """
    grid_size = np.maximum(n_quantiles, x.shape[0])
    probs = np.linspace(0, 1, num=grid_size)
    return np.quantile(x, probs, axis=0), probs
예제 #5
0
파일: baseball.py 프로젝트: ucals/numpyro
def print_results(header, preds, player_names, at_bats, hits):
    """Print a per-player table of actual hits alongside predicted quartiles.

    The 25th, 50th and 75th percentiles are taken over axis 0 of ``preds``;
    column ``i`` of the result is shown next to ``player_names[i]``,
    ``at_bats[i]`` and ``hits[i]``.
    """
    header_format = '{:>20} {:>10} {:>10} {:>10} {:>10} {:>10}'
    row_format = '{:>20} {:>10.0f} {:>10.0f} {:>10.2f} {:>10.2f} {:>10.2f}'
    pred_q = jnp.quantile(preds, jnp.array([0.25, 0.5, 0.75]), axis=0)

    print('\n', header, '\n')
    print(header_format.format('', 'At-bats', 'ActualHits', 'Pred(p25)', 'Pred(p50)', 'Pred(p75)'))
    for idx, name in enumerate(player_names):
        q25, q50, q75 = pred_q[:, idx]
        print(row_format.format(name, at_bats[idx], hits[idx], q25, q50, q75), '\n')
예제 #6
0
def print_results(header, preds, dept, male, probs):
    """Print department, gender flag, observed probability and predicted quartiles per row.

    Quartiles (25th, 50th, 75th percentiles) of ``preds`` are taken over axis 0;
    column ``i`` of them is printed next to ``dept[i]``, ``male[i]`` and ``probs[i]``.
    """
    header_format = '{:>10} {:>10} {:>10} {:>10} {:>10} {:>10}'
    row_format = '{:>10.0f} {:>10.0f} {:>10.2f} {:>10.2f} {:>10.2f} {:>10.2f}'
    titles = ['Dept', 'Male', 'ActualProb', 'Pred(p25)', 'Pred(p50)', 'Pred(p75)']
    lower, median, upper = jnp.quantile(preds, jnp.array([0.25, 0.5, 0.75]), axis=0)

    print('\n', header, '\n')
    print(header_format.format(*titles))
    for row in range(len(dept)):
        print(row_format.format(dept[row], male[row], probs[row],
                                lower[row], median[row], upper[row]), '\n')
예제 #7
0
def print_results(header, preds, player_names, at_bats, hits):
    """Print a table comparing each player's actual hits with predicted quartiles.

    ``preds`` is reduced over axis 0 to its 25th, 50th and 75th percentiles;
    the ``i``-th resulting column is reported next to ``player_names[i]``,
    ``at_bats[i]`` and ``hits[i]``.
    """
    name_col = "{:>20}"
    header_format = name_col + " {:>10}" * 5
    row_format = name_col + " {:>10.0f}" * 2 + " {:>10.2f}" * 3
    levels = jnp.array([0.25, 0.5, 0.75])
    quantiles = jnp.quantile(preds, levels, axis=0)

    print("\n", header, "\n")
    print(header_format.format("", "At-bats", "ActualHits", "Pred(p25)", "Pred(p50)", "Pred(p75)"))
    for i, player in enumerate(player_names):
        cells = [player, at_bats[i], hits[i], *quantiles[:, i]]
        print(row_format.format(*cells), "\n")
예제 #8
0
 def _print_row(values, row_name=''):
     quantiles = jnp.array([0.2, 0.4, 0.5, 0.6, 0.8])
     row_name_fmt = '{:>8}'
     header_format = row_name_fmt + '{:>12}' * 5
     row_format = row_name_fmt + '{:>12.3f}' * 5
     columns = ['(p{})'.format(q * 100) for q in quantiles]
     q_values = jnp.quantile(values, quantiles, axis=0)
     print(header_format.format('', *columns))
     print(row_format.format(row_name, *q_values))
     print('\n')
예제 #9
0
 def _print_row(values, row_name=""):
     quantiles = jnp.array([0.2, 0.4, 0.5, 0.6, 0.8])
     row_name_fmt = "{:>8}"
     header_format = row_name_fmt + "{:>12}" * 5
     row_format = row_name_fmt + "{:>12.3f}" * 5
     columns = ["(p{})".format(int(q * 100)) for q in quantiles]
     q_values = jnp.quantile(values, quantiles, axis=0)
     print(header_format.format("", *columns))
     print(row_format.format(row_name, *q_values))
     print("\n")
예제 #10
0
def print_results(header, preds, dept, male, probs):
    """Print a fixed-width table of department rows with predicted quartiles.

    The 25th/50th/75th percentiles are taken over axis 0 of ``preds``; for each
    index ``i`` in ``range(len(dept))`` one row shows ``dept[i]``, ``male[i]``,
    ``probs[i]`` and column ``i`` of those percentiles.
    """
    header_format = " ".join(["{:>10}"] * 6)
    row_format = " ".join(["{:>10.0f}"] * 2 + ["{:>10.2f}"] * 4)
    quantiles = jnp.quantile(preds, jnp.array([0.25, 0.5, 0.75]), axis=0)

    print("\n", header, "\n")
    print(header_format.format("Dept", "Male", "ActualProb",
                               "Pred(p25)", "Pred(p50)", "Pred(p75)"))
    for i in range(len(dept)):
        cells = (dept[i], male[i], probs[i], *quantiles[:, i])
        print(row_format.format(*cells), "\n")
예제 #11
0
파일: hsgp.py 프로젝트: pyro-ppl/numpyro
def plot_trend(data, samples, ax=None):
    """Plot observed relative births together with the median of the modelled trend.

    The trend curve is ``samples["intercept"][:, None] + samples["trend/f"]``,
    mapped back to the data scale with the std and mean of
    ``data["births_relative"]`` and reduced to its median over axis 0.

    Returns the axes drawn on (``plt.gca()`` when ``ax`` is None).
    """
    births = data["births_relative"]
    dates = data["date"]
    standardized = samples["intercept"][:, None] + samples["trend/f"]
    rescaled = standardized * births.std() + births.mean()
    median_trend = jnp.quantile(rescaled, 0.50, axis=0)

    target = plt.gca() if ax is None else ax
    target.plot(dates, births, **DATA_STYLE)
    target.plot(dates, median_trend, **MODEL_STYLE)
    return target
예제 #12
0
 def test_softquantile(self, quantile, axis):
     """Check ``ops.softquantile`` against ``np.quantile`` on a fixed 2x3x5 array.

     The soft result must drop ``axis`` from the input shape and agree with the
     exact quantile up to a relative tolerance of 1e-2.
     """
     data = np.array([
         [[7.9, 1.2, 5.5, 9.8, 3.5],
          [7.9, 12.2, 45.5, 9.8, 3.5],
          [17.9, 14.2, 55.5, 9.8, 3.5]],
         [[4.9, 1.2, 15.5, 4.8, 3.5],
          [7.9, 1.2, 5.5, 7.8, 2.5],
          [1.9, 4.2, 55.5, 9.8, 1.5]],
     ])
     soft = ops.softquantile(data, quantile, axis=axis)
     expected_shape = list(data.shape)
     del expected_shape[axis]
     self.assertTupleEqual(soft.shape, tuple(expected_shape))
     exact = np.quantile(data, quantile, axis=axis)
     self.assertAllClose(soft, exact, True, rtol=1e-2)
예제 #13
0
 def test_softquantile(self, quantile, axis):
     """Compare ``ops.softquantile`` (threshold = epsilon = 1e-3) with ``jnp.quantile``.

     Checks that ``axis`` is removed from the output shape and that values match
     the exact quantile within ``rtol=1e-2``.
     """
     x = jnp.array([
         [[7.9, 1.2, 5.5, 9.8, 3.5],
          [7.9, 12.2, 45.5, 9.8, 3.5],
          [17.9, 14.2, 55.5, 9.8, 3.5]],
         [[4.9, 1.2, 15.5, 4.8, 3.5],
          [7.9, 1.2, 5.5, 7.8, 2.5],
          [1.9, 4.2, 55.5, 9.8, 1.5]],
     ])
     tol = 1e-3
     result = ops.softquantile(x, quantile, axis=axis, threshold=tol, epsilon=tol)
     remaining = list(x.shape)
     del remaining[axis]
     self.assertTupleEqual(result.shape, tuple(remaining))
     np.testing.assert_allclose(result, jnp.quantile(x, quantile, axis=axis), rtol=1e-2)
예제 #14
0
def print_results(
    model_name: str,
    predictions: jnp.ndarray,
    at_bats: jnp.ndarray,
    hits: jnp.ndarray,
    player_names: np.ndarray,
    is_train: bool,
) -> None:
    """Print one comma-separated line per player with hits and predicted quartiles.

    A title ``"<model_name> - train"`` or ``"<model_name> - test"`` is printed
    first. ``predictions`` is reduced over axis 0 to its 25th, 50th and 75th
    percentiles; column ``i`` is reported for ``player_names[i]``.
    """
    suffix = " - train" if is_train else " - test"
    q = jnp.quantile(predictions, jnp.array([0.25, 0.5, 0.75]), axis=0)
    print("\n", model_name + suffix, "\n")
    for idx, name in enumerate(player_names):
        line = (
            f"{name}: {at_bats[idx]}, {hits[idx]}, "
            f"{q[0, idx]:.2f}, {q[1, idx]:.2f}, {q[2, idx]:.2f}"
        )
        print(line)
예제 #15
0
def create_spline_basis(x,
                        knot_list=None,
                        num_knots=None,
                        degree=3,
                        add_intercept=True):
    """Build a clamped B-spline design matrix for ``x`` and its first derivative.

    Exactly one of ``knot_list`` / ``num_knots`` must be given. With
    ``num_knots`` the knots are placed at evenly spaced quantiles of ``x``
    (its minimum and maximum included).

    Args:
        x: Points at which the basis is evaluated.
        knot_list: Explicit knots, boundary knots included (array or list).
        num_knots: Number of quantile-based knots to generate.
        degree: Spline degree (3 = cubic).
        add_intercept: If True, prepend a column of ones to the basis and a
            column of zeros to its derivative.

    Returns:
        ``(knot_list, B, Bdiff)``: the knots used, the basis matrix with
        ``num_knots + degree - 1`` spline columns (plus the optional intercept
        column), and the derivative of that basis with respect to ``x``.

    Raises:
        AssertionError: if both or neither of ``knot_list`` / ``num_knots`` are
            given (note: skipped when Python runs with ``-O``).
    """
    assert (knot_list is None) != (num_knots is None), \
        "Define knot_list OR num_knots"
    if knot_list is None:
        knot_list = jnp.quantile(x, q=jnp.linspace(0, 1, num=num_knots))
    else:
        num_knots = len(knot_list)

    # Clamp the spline: repeat each boundary knot ``degree`` extra times.
    # ``jnp.asarray`` lets callers pass a plain Python list of knots.
    knots = jnp.pad(jnp.asarray(knot_list), (degree, degree), mode="edge")
    # With len(knots) = num_knots + 2 * degree there are
    # len(knots) - degree - 1 = num_knots + degree - 1 basis functions.
    # (A hard-coded ``num_knots + 2`` is only right for cubic splines.)
    # An identity coefficient matrix makes each output column one basis function.
    num_basis = num_knots + degree - 1
    B0 = BSpline(knots, jnp.identity(num_basis), k=degree)
    B = B0(x)
    Bdiff = B0.derivative()(x)
    if add_intercept:
        n_rows = B.shape[0]
        B = jnp.hstack([jnp.ones((n_rows, 1)), B])
        Bdiff = jnp.hstack([jnp.zeros((n_rows, 1)), Bdiff])
    return knot_list, B, Bdiff
예제 #16
0
def logit_transformer(logits,
                      temp=1.0,
                      confidence_quantile_threshold=1.0,
                      self_supervised_label_transformation='soft',
                      logit_indices=None):
    """Turn model logits into pseudo-labels plus per-example confidence weights.

    Args:
      logits: jnp float array; Prediction of a model. The last axis is treated
        as the class axis.
      temp: float; Softmax temperature for ``'soft'`` labels. A falsy value
        (e.g. 0 or None) is replaced by 1.0.
      confidence_quantile_threshold: float; Quantile level applied to the
        per-example confidence (max logit minus min logit), computed over all
        examples. Examples whose confidence is at least that quantile get weight
        1, the rest weight 0.
      self_supervised_label_transformation: str; ``'sharp'`` gives argmax class
        indices, ``'soft'`` gives softmax probabilities, and any other value
        returns ``logits`` unchanged.
      logit_indices: list(int); Only used for ``'sharp'``: when non-empty, the
        argmax is taken over these class columns only.

    Returns:
      A ``(new_labels, weights)`` tuple; ``weights`` is a float32 array with the
      shape of ``logits`` minus its last axis.
    """
    # Per-example confidence: spread between the largest and smallest logit.
    spread = jnp.amax(logits, axis=-1) - jnp.amin(logits, axis=-1)
    cutoff = jnp.quantile(spread, confidence_quantile_threshold)
    weights = jnp.float32(spread >= cutoff)

    if self_supervised_label_transformation == 'sharp':
        selected = logits[Ellipsis, logit_indices] if logit_indices else logits
        return jnp.argmax(selected, axis=-1), weights
    if self_supervised_label_transformation == 'soft':
        return nn.softmax(logits / (temp or 1.0), axis=-1), weights
    return logits, weights
예제 #17
0
파일: hsgp.py 프로젝트: pyro-ppl/numpyro
def plot_weektrend(data, samples, ax=None):
    """Plot detrended relative births with one median model curve per weekday.

    The baseline subtracted from the data is the mean over axis 0 of
    ``(intercept + trend/f + year/f) * std``, with ``std`` taken from
    ``data["births_relative"]``. Each weekday curve is the median over axis 0 of
    ``exp(week-trend/f) * week/beta[:, day]`` mapped back to the data scale with
    the data std and mean; it is labelled with the weekday name at the last date.

    Returns the axes drawn on (``plt.gca()`` when ``ax`` is None).
    """
    y = data["births_relative"]
    dates = data["date"]
    mean, sdev = y.mean(), y.std()
    trend_sum = samples["intercept"][:, None] + samples["trend/f"] + samples["year/f"]
    baseline = (trend_sum * sdev).mean(0)
    # Loop-invariant: the weekly trend factor is shared by every weekday curve.
    week_scale = jnp.exp(samples["week-trend/f"])

    if ax is None:
        ax = plt.gca()

    ax.plot(dates, y - baseline, **DATA_STYLE)
    day_names = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    for col, name in enumerate(day_names):
        beta = samples["week/beta"][:, col][:, None]
        curve = jnp.quantile((week_scale * beta) * sdev + mean, 0.50, axis=0)
        ax.plot(dates, curve, **MODEL_STYLE)
        ax.text(dates.iloc[-1], curve[-1], name)

    return ax
예제 #18
0

def get_problem(dim, n0, a, n):
    """Build a ``BananaModel(dim=dim, a=a)`` and return its ``get_problem(n=n, n0=n0)``."""
    model = BananaModel(dim=dim, a=a)
    return model.get_problem(n=n, n0=n0)


if __name__ == "__main__":
    # NOTE(review): T = 0.01 is passed as the third argument of the plotting
    # calls for the a=5 model; its meaning is defined by BananaModel.
    T = 1000 / 100000
    npr.seed(463728)

    banana = BananaModel(a=250)
    X = banana.generate_test_data()

    fig, ax = plt.subplots()
    # plt.subplots() without nrows/ncols returns a single Axes object, which is
    # not subscriptable (indexing it as ``ax[0]`` raises TypeError). Pass it
    # directly, exactly as the a=5 figure below does.
    banana.scatterplot_posterior(X, ax, 1)
    banana.plot_posterior(X, ax, 1)
    plt.show()

    banana = BananaModel(a=5)
    X = banana.generate_test_data()
    fig, ax = plt.subplots()
    banana.scatterplot_posterior(X, ax, T)
    banana.plot_posterior(X, ax, T, 10)
    plt.show()

    # Print the 1% and 99% quantiles of every posterior coordinate.
    post = banana.generate_posterior_samples(10000, X, 1)
    print(np.quantile(post[:, 0], np.array([0.01, 0.99])))
    print(np.quantile(post[:, 1], np.array([0.01, 0.99])))
    for i in range(banana.dim - 2):
        print(np.quantile(post[:, i + 2], np.array((0.01, 0.99))))
예제 #19
0
def quantile(a, q, axis=None, interpolation='linear', keepdims=False):
  """Compute the ``q``-th quantile(s) of ``a`` with ``jnp.quantile``.

  ``JaxArray`` inputs are unwrapped to their ``.value`` first. The raw
  ``jnp.quantile`` result is returned when ``axis`` is None; otherwise it is
  wrapped in a ``JaxArray``.
  """
  a = a.value if isinstance(a, JaxArray) else a
  q = q.value if isinstance(q, JaxArray) else q
  result = jnp.quantile(a=a, q=q, axis=axis, interpolation=interpolation, keepdims=keepdims)
  if axis is None:
    return result
  return JaxArray(result)
예제 #20
0
파일: handler.py 프로젝트: sagar87/numgp
 def quantiles(self, param, which="posterior", *args, **kwargs):
     """Return the 5th and 95th percentiles of ``param`` over axis 0.

     The samples are looked up as ``self._select(which)[param]``. Any extra
     positional or keyword arguments are accepted but ignored.
     """
     draws = self._select(which)[param]
     levels = jnp.array([0.05, 0.95])
     return jnp.quantile(draws, levels, axis=0)
예제 #21
0
    problem = util.Problem(
        model.log_likelihood_per_sample, model.log_prior, data, 1,
        model.true_mean, model.generate_true_posterior,
        lambda problem, ax: model.plot_posterior(ax, problem.data))
    return problem


if __name__ == "__main__":
    # Sanity check: draw the model's posterior plot for freshly generated data,
    # overlay draws from ``generate_true_posterior`` and print the 1% / 99%
    # quantiles of each coordinate of those draws.
    dim = 2
    model = GaussModel(dim)
    data = model.generate_data(100000)
    posterior = model.generate_true_posterior(1000, data)

    fig, ax = plt.subplots()
    model.plot_posterior(ax, data)
    plt.scatter(posterior[:, 0], posterior[:, 1])
    plt.show()

    tail_levels = np.array([0.01, 0.99])
    for i in range(dim):
        print(np.quantile(posterior[:, i], tail_levels))

    # Disabled scratch checks comparing the plain likelihood with the
    # Cholesky-factor variant, including rough timings:
    # x = np.repeat(1, dim)
    # theta = np.repeat(2, dim)
    # print(log_likelihood_per_sample(x, theta, model.cov))
    # L = np.linalg.cholesky(model.cov)
    # print(log_likelihood_per_sample_fast(x, theta, L))

    # log_likelihood_per_sample(np.zeros(dim), np.zeros(dim), np.eye(dim))
    # log_likelihood_per_sample_fast(np.zeros(dim), np.zeros(dim), np.eye(dim))
    # from timeit import timeit
    # print(timeit(lambda: log_likelihood_per_sample(x, theta, model.cov), number=1))
    # print(timeit(lambda: log_likelihood_per_sample_fast(x, theta, L), number=1))
예제 #22
0
파일: generic.py 프로젝트: wesselb/lab
def quantile(a: Numeric, q: Numeric, axis: Union[Int, None] = None):
    """Compute the ``q``-th quantile(s) of ``a`` along ``axis`` (all entries when None).

    Linear interpolation between data points is used, which is the default
    method of ``jnp.quantile``.
    """
    # NumPy 1.22 renamed ``interpolation=`` to ``method=`` and JAX followed,
    # deprecating the old keyword. Relying on the default ("linear") gives the
    # same result without depending on either keyword.
    return jnp.quantile(a, q, axis=axis)
# Compute density on a grid.
lin = jnp.linspace(-jnp.pi, jnp.pi)
xx, yy = jnp.meshgrid(lin, lin)
theta = jnp.vstack((xx.ravel(), yy.ravel())).T
ptor = pm.torus.ang2euclid(theta)
prob = importance_density(rng_is, bij_params, deq_params, 10000, ptor)
aprob = torus_density(theta)

# Visualize learned distribution.
fig, axes = plt.subplots(1, 4, figsize=(16, 5))
axes[0].plot(trace)
axes[0].grid(linestyle=':')
axes[0].set_ylabel('Combined Loss')
axes[0].set_xlabel('Gradient Descent Iteration')
num_plot = 10000
axes[1].plot(xobs[:num_plot, 0], xobs[:num_plot, 1], '.', alpha=0.2, label='Rejection Sampling')
axes[1].plot(xang[:num_plot, 0], xang[:num_plot, 1], '.', alpha=0.2, label='Dequantization Sampling')
axes[1].grid(linestyle=':')
leg = axes[1].legend()
# Show legend markers fully opaque although the points are drawn with alpha=0.2.
# Matplotlib 3.7 renamed ``Legend.legendHandles`` to ``legend_handles`` (the old
# name has since been removed), and the private ``_legmarker`` attribute only
# exists on older releases; newer ones draw the marker on the handle itself.
legend_handles = getattr(leg, 'legend_handles', None)
if legend_handles is None:
    legend_handles = leg.legendHandles
for lh in legend_handles:
    getattr(lh, '_legmarker', lh).set_alpha(1)

# Contour the importance-sampling estimate clipped to [0, its 95th percentile].
axes[2].contourf(xx, yy, jnp.clip(prob, 0., jnp.quantile(prob, 0.95)).reshape(xx.shape))
axes[2].set_title('Importance Sample Density Estimate')
axes[3].contourf(xx, yy, aprob.reshape(xx.shape))
axes[3].set_title('Analytic Density')
# Raw string: the ``\V`` of LaTeX ``\Vert`` is not a valid Python escape
# sequence (SyntaxWarning on Python 3.12+); the runtime text is unchanged.
plt.suptitle(r'Mean MSE: {:.5f} - Covariance MSE: {:.5f} - KL$(q\Vert p)$ = {:.5f} - Rel. ESS: {:.2f}%'.format(mean_mse, cov_mse, klqp, ress))
plt.tight_layout()
ln = 'elbo' if args.elbo_loss else 'kl'
fname = '{}-{}-num-batch-{}-num-importance-{}-num-steps-{}-seed-{}.png'.format(
    ln, args.density, args.num_batch, args.num_importance, args.num_steps, args.seed)
plt.savefig(os.path.join('images', fname))
예제 #24
0
 def next_threshold_adaptive(self,
                             state: cdict,
                             extra: cdict) -> float:
     """Pick the next distance threshold as a quantile of the current distances.

     The quantile level is
     ``extra.parameters.ess_threshold_retain * state.ess[0] / state.ess.size``.
     Despite the ``float`` annotation, the value is returned exactly as
     ``jnp.quantile`` produces it.
     """
     distances = state.distance
     retain = extra.parameters.ess_threshold_retain
     level = retain * state.ess[0] / state.ess.size
     return jnp.quantile(distances, level)
예제 #25
0
 def f(x):
     """Return the quantile(s) ``q`` of ``x`` along axis 0 (``q`` comes from the enclosing scope)."""
     return np.quantile(x, q, axis=0)
예제 #26
0
 def quantile(self, x, quantiles):
     """Return ``jnp.quantile(x, quantiles)`` computed along the last axis of ``x``."""
     last_axis = -1
     return jnp.quantile(x, quantiles, axis=last_axis)