Example #1
    def apply_kernel(self,
                     scaling: jnp.ndarray,
                     eps: float = None,
                     axis: int = None):
        """Applies grid kernel on scaling vector.

    See notes in parent class for use.

    Reshapes scaling vector as a grid, applies kernels onto each slice, and
    then ravels backs the output as a vector.

    More implementation details in https://arxiv.org/pdf/1708.01955.pdf

    Args:
      scaling: jnp.ndarray, a vector of scaling (>0) values.
      eps: float, regularization strength
      axis: axis (0 or 1) along which summation should be carried out.

    Returns:
      a vector, the result of kernel applied onto scaling.
    """
        scaling = jnp.reshape(scaling, self.grid_size)
        indices = list(range(1, self.grid_dimension))
        for dimension, kernel in enumerate(self.kernel_matrices):
            ind = indices.copy()
            ind.insert(dimension, 0)
            scaling = jnp.tensordot(kernel, scaling,
                                    axes=([0], [dimension])).transpose(ind)
        return scaling.ravel()
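The loop above applies one small kernel per grid axis instead of a single dense kernel over the whole grid. A minimal, self-contained sketch of that separable tensordot pattern, with made-up stand-ins for the class attributes grid_size, grid_dimension and kernel_matrices:

import jax.numpy as jnp

# Hypothetical stand-ins for self.grid_size, self.grid_dimension, self.kernel_matrices.
grid_size = (2, 3)
grid_dimension = len(grid_size)
kernel_matrices = [jnp.ones((n, n)) / n for n in grid_size]  # one kernel per axis

scaling = jnp.arange(1.0, 7.0)           # flat vector of 6 positive scaling values
grid = jnp.reshape(scaling, grid_size)   # view it as a 2 x 3 grid

# Contract each kernel with its own axis, then restore the original axis order.
indices = list(range(1, grid_dimension))
for dimension, kernel in enumerate(kernel_matrices):
    ind = indices.copy()
    ind.insert(dimension, 0)
    grid = jnp.tensordot(kernel, grid, axes=([0], [dimension])).transpose(ind)

result = grid.ravel()                    # back to a flat vector of length 6

Each per-axis contraction costs the axis length times the total grid size, instead of the grid-size-squared cost of applying one dense kernel, which is the point of the grid geometry.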
Example #2
def _save_results(
    x: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    num_train: int,
) -> None:

    root = pathlib.Path("./data/seasonal")
    root.mkdir(exist_ok=True)

    jnp.savez(root / "piror_samples.npz", **prior_samples)
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    x_pred = posterior_predictive["x"]

    x_pred_trn = x_pred[:, :num_train]
    x_hpdi_trn = diagnostics.hpdi(x_pred_trn)
    t_train = np.arange(num_train)

    x_pred_tst = x_pred[:, num_train:]
    x_hpdi_tst = diagnostics.hpdi(x_pred_tst)
    num_test = x_pred_tst.shape[1]
    t_test = np.arange(num_train, num_train + num_test)

    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]

    plt.figure(figsize=(12, 6))
    plt.plot(x.ravel(), label="ground truth", color=colors[0])

    plt.plot(t_train,
             x_pred_trn.mean(0)[:, 0],
             label="prediction",
             color=colors[1])
    plt.fill_between(t_train,
                     x_hpdi_trn[0, :, 0, 0],
                     x_hpdi_trn[1, :, 0, 0],
                     alpha=0.3,
                     color=colors[1])

    plt.plot(t_test,
             x_pred_tst.mean(0)[:, 0],
             label="forecast",
             color=colors[2])
    plt.fill_between(t_test,
                     x_hpdi_tst[0, :, 0, 0],
                     x_hpdi_tst[1, :, 0, 0],
                     alpha=0.3,
                     color=colors[2])

    plt.ylim(x.min() - 0.5, x.max() + 0.5)
    plt.legend()
    plt.tight_layout()
    plt.savefig(root / "prediction.png")
    plt.close()
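For reference, a minimal sketch of how numpyro's diagnostics.hpdi is used above: it reduces over the sample axis (axis 0 by default) and returns a (2, ...) array holding the lower and upper bounds of the 90% highest posterior density interval. The array below is synthetic; the real posterior_predictive["x"] carries extra trailing dimensions, hence the [0, :, 0, 0] style indexing in the plotting calls.

import numpy as np
from numpyro import diagnostics

x_pred = np.random.default_rng(0).normal(size=(500, 100))  # 500 fake draws of a length-100 series
x_hpdi = diagnostics.hpdi(x_pred)                           # shape (2, 100); prob=0.90 by default
lower, upper = x_hpdi[0], x_hpdi[1]                         # per-timestep interval bounds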
def plot_xhats_distr(x: np.ndarray, xhats: np.ndarray) -> plt.Figure:
  fig = plt.figure()
  plt.hist(xhats.ravel(), bins=50, density=True, alpha=0.3, label="sampled x")
  plt.hist(x, bins=50, density=True, alpha=0.3, label="actual x (data)")
  plt.legend()
  plt.xlabel("x")
  plt.ylabel("density")
  plt.title("Distribution of pre-sampled and actual x")
  return fig
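Unlike the helper above, plot_xhats_distr returns the Figure instead of saving it, so the caller decides where it goes. A short usage sketch with made-up arrays:

import numpy as np

x = np.random.default_rng(1).normal(size=1000)             # fake observed data
xhats = np.random.default_rng(2).normal(size=(50, 1000))   # fake pre-sampled x
fig = plot_xhats_distr(x, xhats)
fig.savefig("xhats_distr.png")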
Example #4
def _save_results(
    x: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
) -> None:

    root = pathlib.Path("./data/kalman")
    root.mkdir(parents=True, exist_ok=True)

    np.savez(root / "prior_samples.npz", **prior_samples)
    np.savez(root / "posterior_samples.npz", **posterior_samples)
    np.savez(root / "posterior_predictive.npz", **posterior_predictive)

    len_train = x.shape[0]

    x_pred_trn = posterior_predictive["x"][:, :len_train]
    x_hpdi_trn = diagnostics.hpdi(x_pred_trn)
    x_pred_tst = posterior_predictive["x"][:, len_train:]
    x_hpdi_tst = diagnostics.hpdi(x_pred_tst)

    len_test = x_pred_tst.shape[1]

    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]

    plt.figure(figsize=(12, 6))
    plt.plot(x.ravel(), label="ground truth", color=colors[0])

    t_train = np.arange(len_train)
    plt.plot(t_train,
             x_pred_trn.mean(0).ravel(),
             label="prediction",
             color=colors[1])
    plt.fill_between(t_train,
                     x_hpdi_trn[0].ravel(),
                     x_hpdi_trn[1].ravel(),
                     alpha=0.3,
                     color=colors[1])

    t_test = np.arange(len_train, len_train + len_test)
    plt.plot(t_test,
             x_pred_tst.mean(0).ravel(),
             label="forecast",
             color=colors[2])
    plt.fill_between(t_test,
                     x_hpdi_tst[0].ravel(),
                     x_hpdi_tst[1].ravel(),
                     alpha=0.3,
                     color=colors[2])

    plt.legend()
    plt.tight_layout()
    plt.savefig(root / "kalman.png")
    plt.close()
Example #5
    def apply_lse_kernel(self,
                         f: jnp.ndarray,
                         g: jnp.ndarray,
                         eps: float,
                         vec: Optional[jnp.ndarray] = None,
                         axis: int = 0):
        """Applies grid kernel in log space. See notes in parent class for use case.

    Reshapes vector inputs below as grids, applies kernels onto each slice, and
    then expands the outputs as vectors.

    More implementation details in https://arxiv.org/pdf/1708.01955.pdf

    Args:
      f: jnp.ndarray, a vector of potentials
      g: jnp.ndarray, a vector of potentials
      eps: float, regularization strength
      vec: jnp.ndarray, if needed, a vector onto which apply the kernel weighted
        by f and g.
      axis: axis (0 or 1) along which summation should be carried out.

    Returns:
      a vector, the result of kernel applied in lse space onto vec.
    """
        f, g = jnp.reshape(f, self.grid_size), jnp.reshape(g, self.grid_size)

        if vec is not None:
            vec = jnp.reshape(vec, self.grid_size)

        if axis == 0:
            f, g = g, f

        for dimension in range(self.grid_dimension):
            g, vec = self._apply_lse_kernel_one_dimension(
                dimension, f, g, eps, vec)
            g -= jnp.where(jnp.isfinite(f), f, 0)
        if vec is None:
            vec = jnp.array(1.0)
        return g.ravel(), vec.ravel()
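The method above delegates the per-axis work to self._apply_lse_kernel_one_dimension, which is not shown here. A minimal sketch of the identity that log-space kernel applications rely on, on a single toy dimension (the cost matrix c and potential g below are made up): applying the kernel exp(-c / eps) in probability space and taking eps * log of the result equals a logsumexp over (g - c) / eps, which is what keeps the computation stable for small eps.

import jax.numpy as jnp
from jax.scipy.special import logsumexp

eps = 0.1
c = jnp.array([[0.0, 1.0, 2.0],
               [1.0, 0.0, 1.0],
               [2.0, 1.0, 0.0]])   # toy cost matrix for one grid axis
g = jnp.array([0.3, -0.2, 0.5])    # toy potential

# Kernel application in probability space: eps * log(K @ exp(g / eps)).
kernel_space = eps * jnp.log(jnp.exp(-c / eps) @ jnp.exp(g / eps))

# Same quantity computed in log space, without exponentiating the potentials.
lse_space = eps * logsumexp((g[None, :] - c) / eps, axis=1)

print(jnp.allclose(kernel_space, lse_space, atol=1e-5))  # True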
Example #6
def externalize(x: jnp.ndarray,
                dim: int,
                was_jax: bool = False) -> Union[np.ndarray, jnp.ndarray]:
    """Takes internal representation of 2D (time, feature) and returns to user
    Args:
        x (jnp.ndarray): some input data to be converted
        dim (int): the original number of dimensions (-1 for a scalar, n for an
            nD NumPy array)
        was_jax (bool): whether the original array was a jax.numpy.DeviceArray

    Warning:
        * Transformations may not always be from R^n to R^n, so we have to make
            some guesses
        * We don't check whether the NumPy array is well-formed or not (e.g.,
            single data type, np.prod(x.shape) == len(x.ravel()))

    Notes:
        * We assume that if the original number of dimensions was less than 2,
            we should squeeze out as many dimensions as we can
        * We leave the input as is if the original dimension was 2
        * It's not clear when to use this, especially if users are not
            consistent with their inputs

    Returns:
        x (Union[np.ndarray, jnp.ndarray]): externalized version of input

    Raises:
        TypeError: x is not a jnp.ndarray or does not have two dimensions
        ValueError: dim is not -1, 0, 1, or 2, or input x does not match dim
    """
    if not isinstance(x, jnp.ndarray):
        raise TypeError("x has incorrect type {}".format(type(x)))

    if was_jax:
        x = np.asarray(x)
    else:
        x = jnp.asarray(x)

    if x.ndim != 2:
        raise TypeError("x does not have two dimensions")
    if dim < -1 or dim > 2:
        raise ValueError("Original dimension must be -1, 0, 1, or 2")

    # If the original dimension was 2, just return
    if dim == 2:
        return x

    # Need to be careful with x with one item
    if len(x.ravel()) == 1:

        # If the original was a scalar, return a scalar
        if dim == -1:
            return x.item()

        # If the original was 1D, return 1D
        if dim == 1:
            return x.ravel()

    # Otherwise, try to squeeze as many dimensions as we can
    if dim < 2:
        return x.squeeze()
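A hedged usage sketch of externalize, assuming the function is in scope and that the internal representation is always a 2D (time, feature) jnp array; the inputs below are made up:

import jax.numpy as jnp

internal = jnp.array([[3.14]])           # 2D internal view of what the user passed in

externalize(internal, dim=-1)            # original was a scalar   -> 3.14 (Python float)
externalize(internal, dim=1)             # original was 1D         -> one-element array
externalize(jnp.zeros((5, 2)), dim=2)    # original was already 2D -> returned unchanged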
def run_optim(key: np.ndarray, lhs: np.ndarray, tmp: np.ndarray,
              xhats: np.ndarray, tmp_c: np.ndarray, xhats_c: np.ndarray,
              xstar: float, bound: Text, out_dir: Text, x: np.ndarray,
              y: np.ndarray) -> Tuple[int, float, float, int, float, float]:
    """Run optimization (either lower or upper) for a single xstar."""
    # Directory setup
    # ---------------------------------------------------------------------------
    out_dir = os.path.join(out_dir, f"{bound}-xstar_{xstar}")
    if FLAGS.store_data:
        logging.info(f"Current run output directory: {out_dir}...")
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    # Init optim params
    # ---------------------------------------------------------------------------
    logging.info(
        f"Initialize parameters L, mu, log_sigma, lmbda, tau, slack...")
    key, subkey = random.split(key)
    params = init_params(subkey)

    for parname, param in zip(['L', 'mu', 'log_sigma'], params):
        logging.info(f"Parameter {parname}: {param.shape}")
        logging.info(f"  -> {parname}: {param}")

    tau = FLAGS.tau_init
    logging.info(f"Initial tau = {tau}")
    fin_tau = np.minimum(FLAGS.tau_factor**FLAGS.num_rounds * tau,
                         FLAGS.tau_max)
    logging.info(f"Final tau = {fin_tau}")

    # Set constraint approach and slacks
    # ---------------------------------------------------------------------------
    slack = FLAGS.slack * np.ones(FLAGS.num_z * 2)
    lmbda = np.zeros(FLAGS.num_z * 2)
    logging.info(f"Lambdas: {lmbda.shape}")

    logging.info(
        f"Fractional tolerance (slack) for constraints = {FLAGS.slack}")
    logging.info(f"Set relative slack variables...")
    slack *= np.abs(lhs.ravel())
    logging.info(f"Set minimum slack to {FLAGS.slack_abs}...")
    slack = np.maximum(FLAGS.slack_abs, slack)
    logging.info(f"Slack {slack.shape}")
    logging.info(f"Actual slack min: {np.min(slack)}, max: {np.max(slack)}")

    # Setup optimizer
    # ---------------------------------------------------------------------------
    logging.info(f"Vanilla SGD with init_lr={FLAGS.lr}...")
    logging.info(f"Set learning rate schedule")
    step_size = optim.inverse_time_decay(FLAGS.lr, FLAGS.decay_steps,
                                         FLAGS.decay_rate, FLAGS.staircase)
    init_fun, update_fun, get_params = optim.sgd(step_size)

    logging.info(
        f"Init state for JAX optimizer (including L, mu, log_sigma)...")
    state = init_fun(params)

    # Setup result dict
    # ---------------------------------------------------------------------------
    logging.info(f"Initialize dictionary for results...")
    results = {
        "mu": [],
        "sigma": [],
        "cholesky_factor": [],
        "tau": [],
        "lambda": [],
        "objective": [],
        "constraint_term": [],
        "rhs": []
    }
    if FLAGS.plot_intermediate:
        results["grad_norms"] = []
        results["lagrangian"] = []

    logging.info(f"Evaluate at xstar={xstar}...")

    logging.info(f"Evaluate {bound} bound...")
    sign = 1 if bound == "lower" else -1

    # ===========================================================================
    # OPTIMIZATION LOOP
    # ===========================================================================
    # One-time logging before first step
    # ---------------------------------------------------------------------------
    key, subkey = random.split(key)
    obj, rhs, psisum, constr = objective_rhs_psisum_constr(
        subkey, get_params(state), lmbda, tau, lhs, slack, xstar, tmp_c,
        xhats_c)
    results["objective"].append(obj)
    results["constraint_term"].append(psisum)
    results["rhs"].append(rhs)

    logging.info(f"Objective: scalar")
    logging.info(f"RHS: {rhs.shape}")
    logging.info(f"Sum over Psis: scalar")
    logging.info(f"Constraint: {constr.shape}")

    tril_idx = np.tril_indices(FLAGS.dim_theta + 1)
    count = 0
    logging.info(f"Start optimization loop...")
    for _ in tqdm(range(FLAGS.num_rounds)):
        # log current parameters
        # -------------------------------------------------------------------------
        results["lambda"].append(lmbda)
        results["tau"].append(tau)
        cur_L, cur_mu, cur_logsigma = get_params(state)
        cur_chol = make_cholesky_factor(cur_L)[tril_idx].ravel()[1:]
        results["mu"].append(cur_mu)
        results["sigma"].append(np.exp(cur_logsigma))
        results["cholesky_factor"].append(cur_chol)

        subkeys = random.split(key, num=FLAGS.opt_steps + 1)
        key = subkeys[0]
        # inner optimization for subproblem
        # -------------------------------------------------------------------------
        for j in range(FLAGS.opt_steps):
            v, grads = lagrangian_value_and_grad(subkeys[j + 1],
                                                 get_params(state), lmbda, tau,
                                                 lhs, slack, xstar, tmp, xhats,
                                                 sign)
            state = update_fun(count, grads, state)
            count += 1
            if FLAGS.plot_intermediate:
                results["lagrangian"].append(v)
                results["grad_norms"].append(
                    [np.linalg.norm(grad) for grad in grads])

        # post inner optimization logging
        # -------------------------------------------------------------------------
        key, subkey = random.split(key)
        obj, rhs, psisum, constr = objective_rhs_psisum_constr(
            subkey, get_params(state), lmbda, tau, lhs, slack, xstar, tmp_c,
            xhats_c)
        results["objective"].append(obj)
        results["constraint_term"].append(psisum)
        results["rhs"].append(rhs)

        # update lambda, tau
        # -------------------------------------------------------------------------
        lmbda = update_lambda(constr, lmbda, tau)
        tau = np.minimum(tau * FLAGS.tau_factor, FLAGS.tau_max)

    # Convert and store results
    # ---------------------------------------------------------------------------
    logging.info(f"Finished optimization loop...")

    logging.info(f"Convert all results to numpy arrays...")
    results = {k: np.array(v) for k, v in results.items()}

    logging.info(f"Add final parameters and lhs to results...")
    L, mu, log_sigma = get_params(state)
    results["final_L"] = L
    results["final_mu"] = mu
    results["final_log_sigma"] = log_sigma
    results["lhs"] = lhs

    if FLAGS.store_data:
        logging.info(f"Save result data to...")
        result_path = os.path.join(out_dir, "results.npz")
        onp.savez(result_path, **results)

    # Generate and store plots
    # ---------------------------------------------------------------------------
    if FLAGS.plot_intermediate:
        fig_dir = os.path.join(out_dir, "figures")
        logging.info(f"Generate and save all plots at {fig_dir}...")
        plotting.plot_all(results, x, y, response, fig_dir)

    # Compute last valid and last satisfied
    # ---------------------------------------------------------------------------
    maxabsdiff = np.array([np.max(np.abs(lhs - r)) for r in results["rhs"]])
    fin_i = np.sum(~np.isnan(results["objective"])) - 1
    logging.info(f"Final non-nan objective at {fin_i}.")
    fin_obj = results["objective"][fin_i]
    fin_maxabsdiff = maxabsdiff[fin_i]

    sat_i = [
        np.all((np.abs((lhs - r) / lhs) < FLAGS.slack)
               | (np.abs(lhs - r) < FLAGS.slack_abs)) for r in results["rhs"]
    ]
    sat_i = np.where(sat_i)[0]

    if len(sat_i) > 0:
        sat_i = sat_i[-1]
        logging.info(f"Final satisfied constraint at {sat_i}.")
        sat_obj = results["objective"][sat_i]
        sat_maxabsdiff = maxabsdiff[sat_i]
    else:
        sat_i = -1
        logging.info(f"Constraints were never satisfied.")
        sat_obj, sat_maxabsdiff = np.nan, np.nan

    logging.info("Finished run.")
    return fin_i, fin_obj, fin_maxabsdiff, sat_i, sat_obj, sat_maxabsdiff
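run_optim leans on the functional optimizer triple from JAX (init_fun, update_fun, get_params). A minimal sketch of that pattern using jax.example_libraries.optimizers, with a made-up quadratic loss standing in for the Lagrangian:

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

# Inverse-time-decay schedule and plain SGD, mirroring optim.inverse_time_decay / optim.sgd above.
step_size = optimizers.inverse_time_decay(0.1, decay_steps=100, decay_rate=0.5, staircase=True)
opt_init, opt_update, get_params = optimizers.sgd(step_size)

def loss(params):
    return jnp.sum((params - 3.0) ** 2)   # stand-in for the augmented Lagrangian

state = opt_init(jnp.zeros(2))
for count in range(200):
    value, grads = jax.value_and_grad(loss)(get_params(state))
    state = opt_update(count, grads, state)   # same (step, grads, state) call order as update_fun above

print(get_params(state))                   # close to [3., 3.]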