def sdcorr_params_to_sds_and_corr_jax(sdcorr_params):
    """Split sdcorr parameters into standard deviations and a correlation matrix.

    The first ``dim`` entries of ``sdcorr_params`` are the standard deviations.
    The remaining entries are the strictly-lower-triangular correlations in the
    row-major order of ``jnp.tril_indices(dim, k=-1)``.

    Returns:
        ``(sds, corr)``: a length-``dim`` array and a symmetric ``dim x dim``
        matrix with ones on the diagonal.
    """
    dim = number_of_triangular_elements_to_dimension_jax(len(sdcorr_params))
    sds = jnp.array(sdcorr_params[:dim])
    # ``jax.ops.index_update`` / ``jax.ops.index`` no longer exist in JAX; the
    # functional ``.at[...].set(...)`` update is the supported equivalent.
    corr = jnp.eye(dim).at[jnp.tril_indices(dim, k=-1)].set(sdcorr_params[dim:])
    # Mirror the strictly lower triangle into the upper triangle.
    corr = corr + jnp.tril(corr, k=-1).T
    return sds, corr
Example #2
0
def dot_interact(concat_features, keep_diags=True):
    """Performs feature interaction operation between dense or sparse features.

    Input tensors represent dense or sparse features.
    Pre-condition: The tensors have been stacked along dimension 1.

    Args:
      concat_features: Array of features with shape [B, n_features, feature_dim].
      keep_diags: Whether to keep the diagonal terms of x @ x.T.

    Returns:
      activations: Array of shape [B, n_pairs] holding, for each example, the
        upper-triangular entries (row-major order) of ``x @ x.T``.
        ``n_pairs`` is ``n_features * (n_features + 1) // 2`` when
        ``keep_diags`` is True and ``n_features * (n_features - 1) // 2``
        otherwise.
    """
    batch_size = concat_features.shape[0]

    # Pairwise dot products between the feature vectors of each example:
    # shape [B, n_features, n_features].
    xactions = jnp.matmul(concat_features,
                          jnp.transpose(concat_features, [0, 2, 1]))
    num_features = xactions.shape[-1]

    # x @ x.T is symmetric, so the upper triangle holds every distinct pair.
    # k=1 excludes the diagonal terms when they are not wanted.
    rows, cols = jnp.triu_indices(num_features, k=0 if keep_diags else 1)
    # Same (row, col) selection applied to every example in the batch.
    activations = xactions[:, rows, cols]
    return jnp.reshape(activations, [batch_size, -1])
Example #3
0
def logf(r, sig, p):
    """Return a Gaussian term plus a pairwise log-sinh term for the vector ``r``.

    The value is ``-sum(r**2) / (2 * sig**2)`` plus the sum, over the
    strictly-lower-triangular positions of a ``p x p`` grid, of
    ``log(sinh(|r_i - r_j|))``. ``p`` is expected to equal ``len(r)``.
    """
    diffs = jnp.abs(r[:, jnp.newaxis] - r[jnp.newaxis, :])
    below_diag = jnp.tril_indices(p, k=-1)
    pair_term = jnp.log(jnp.sinh(diffs)[below_diag]).sum()
    gauss_term = -(r ** 2).sum() / (2 * sig ** 2)
    return gauss_term + pair_term
Example #4
0
def lo_tri_from_elements(elements, n):
    """Scatter ``elements`` into the lower triangle of an ``n x n`` zero matrix.

    Positions follow ``np.tril_indices(n)``: row-major order, diagonal included.

    NOTE(review): assuming ``index_update`` is ``jax.ops.index_update``, note
    that it is not available in current JAX releases. Switch to
    ``.at[...].set(...)`` once ``np`` is confirmed to be ``jax.numpy``.
    """
    lower = np.tril_indices(n)
    zeros = np.zeros((n, n))
    return index_update(zeros, lower, elements)
Example #5
0
def lo_tri_from_elements(elements, n):
    """Return an ``n x n`` matrix whose lower triangle is filled with ``elements``.

    ``elements`` fills the positions of ``jnp.tril_indices(n)`` in row-major
    order, diagonal included. Every entry above the diagonal is zero.
    """
    # ``.at[...].set(...)`` replaces the removed ``jax.ops.index_update``.
    return jnp.zeros((n, n)).at[jnp.tril_indices(n)].set(elements)
 def _kernel_matrix_without_gradients(kernel_fn, theta, X, Y):
     """Evaluate ``kernel_fn(theta, x, y)`` for all pairs of rows of X and Y.

     ``kernel_fn`` is partially applied with ``theta`` as its first argument.
     When ``Y`` is None or is the same object as ``X``, the result is the
     symmetric ``n x n`` Gram matrix of ``X``. The kernel is evaluated only on
     the lower triangle and mirrored into the upper triangle. Otherwise the
     result has shape ``(len(X), len(Y))``, assuming ``kernel_fn`` returns a
     scalar.

     The config flag ``KERNEL_MATRIX_USE_LOOP`` selects a ``loops.Scope``
     based loop instead of nested ``vmap``.

     NOTE(review): ``ops.index_update`` and ``loops`` (presumably
     ``jax.experimental.loops``) are not available in recent JAX releases.
     Verify the pinned JAX version.
     """
     kernel_fn = partial(kernel_fn, theta)
     if Y is None or (Y is X):
         if config_value('KERNEL_MATRIX_USE_LOOP'):
             # Loop path: evaluate every pair on or below the diagonal (k=0)
             # into a flat buffer, then scatter it to (i, j) and to (j, i).
             n = len(X)
             with loops.Scope() as s:
                 # s.scattered_values = np.empty((n, n))
                 s.index1, s.index2 = np.tril_indices(n, k=0)
                 s.output = np.empty(len(s.index1))
                 for i in s.range(s.index1.shape[0]):
                     i1, i2 = s.index1[i], s.index2[i]
                     s.output = ops.index_update(s.output, i,
                                                 kernel_fn(X[i1], X[i2]))
             # np.empty is safe here: the two scatters cover every entry.
             first_update = ops.index_update(np.empty((n, n)),
                                             (s.index1, s.index2), s.output)
             second_update = ops.index_update(first_update,
                                              (s.index2, s.index1),
                                              s.output)
             return second_update
         else:
             # Vectorised path: strictly-lower pairs (k=-1) are evaluated
             # once and mirrored; the diagonal is filled separately below.
             n = len(X)
             values_scattered = np.empty((n, n))
             index1, index2 = np.tril_indices(n, k=-1)
             inst1, inst2 = X[index1], X[index2]
             values = vmap(kernel_fn)(inst1, inst2)
             values_scattered = ops.index_update(values_scattered,
                                                 (index1, index2), values)
             values_scattered = ops.index_update(values_scattered,
                                                 (index2, index1), values)
             values_scattered = ops.index_update(
                 values_scattered, np.diag_indices(n),
                 vmap(lambda x: kernel_fn(x, x))(X))
             return values_scattered
     else:
         # Cross-kernel between distinct X and Y: no symmetry to exploit.
         if config_value('KERNEL_MATRIX_USE_LOOP'):
             with loops.Scope() as s:
                 s.output = np.empty((X.shape[0], Y.shape[0]))
                 for i in s.range(X.shape[0]):
                     x = X[i]
                     # Fill row i with k(X[i], y) for every y in Y.
                     s.output = ops.index_update(
                         s.output, i,
                         vmap(lambda y: kernel_fn(x, y))(Y))
             return s.output
         else:
             return vmap(lambda x: vmap(lambda y: kernel_fn(x, y))(Y))(X)
Example #7
0
    def estimate_lower_bound_grad(variational_params, grad_mu, grad_lower):
        """Average accumulated Monte Carlo gradient sums over ``num_samples``.

        ``grad_mu`` and ``grad_lower`` are divided by the enclosing
        ``num_samples``. Each lower-triangle gradient then gets the entries of
        ``diag(diag(L))`` at ``jnp.tril_indices(n)`` added to it, where ``L``
        is the current factor. That is the diagonal of ``L`` in the vectorised
        (row-major, diagonal included) layout, with zeros elsewhere.

        Returns:
            ``(grad_mu, grad_lower)`` with the same pytree structure as the
            inputs.
        """
        # jax.tree_util.tree_map accepts several trees, replacing the removed
        # jax.tree_map / jax.tree_multimap aliases.
        grad_mu = jax.tree_util.tree_map(lambda x: x / num_samples, grad_mu)

        _, std = variational_params
        diagonal = jax.tree_util.tree_map(lambda L: jnp.diag(jnp.diag(L)), std)
        grad_lower = jax.tree_util.tree_map(
            lambda dL, D, n: dL / num_samples + D[jnp.tril_indices(n)],
            grad_lower, diagonal, nfeatures)
        return grad_mu, grad_lower
 def _kernel_matrix_with_gradients(kernel_fn, theta, X, Y):
     """Evaluate the kernel matrix and its gradient with respect to ``theta``.

     ``kernel_fn`` is wrapped with ``value_and_grad``, which differentiates
     with respect to its first argument (``theta``), and is then partially
     applied with ``theta``. It returns ``(values, grads)``.

     When ``Y`` is None or is the same object as ``X``, the kernel is
     evaluated on the lower triangle only and mirrored. ``values`` is
     ``n x n`` and ``grads`` is ``n x n x len(theta)``, which presumes a 1-D
     ``theta`` (TODO confirm). Otherwise both are produced by nested ``vmap``
     over X and Y.

     NOTE(review): ``ops.index_update`` and ``loops`` are not available in
     recent JAX releases. Verify the pinned JAX version.
     """
     kernel_fn = value_and_grad(kernel_fn)
     kernel_fn = partial(kernel_fn, theta)
     if Y is None or (Y is X):
         if config_value('KERNEL_MATRIX_USE_LOOP'):
             n = len(X)
             with loops.Scope() as s:
                 s.scattered_values = np.empty((n, n))
                 s.scattered_grads = np.empty((n, n, len(theta)))
                 # Pairs on or below the diagonal (k=0).
                 index1, index2 = np.tril_indices(n, k=0)
                 for i in s.range(index1.shape[0]):
                     i1, i2 = index1[i], index2[i]
                     value, grads = kernel_fn(X[i1], X[i2])
                     # Row/col arrays addressing both (i1, i2) and (i2, i1).
                     indexes = (np.stack([i1, i2]), np.stack([i2, i1]))
                     s.scattered_values = ops.index_update(
                         s.scattered_values, indexes, value)
                     s.scattered_grads = ops.index_update(
                         s.scattered_grads, indexes, grads)
             return s.scattered_values, s.scattered_grads
         else:
             n = len(X)
             values_scattered = np.empty((n, n))
             grads_scattered = np.empty((n, n, len(theta)))
             # Strictly-lower pairs (k=-1); the diagonal is handled below.
             index1, index2 = np.tril_indices(n, k=-1)
             inst1, inst2 = X[index1], X[index2]
             values, grads = vmap(kernel_fn)(inst1, inst2)
             # Scatter computed values into matrix
             values_scattered = ops.index_update(values_scattered,
                                                 (index1, index2), values)
             values_scattered = ops.index_update(values_scattered,
                                                 (index2, index1), values)
             grads_scattered = ops.index_update(grads_scattered,
                                                (index1, index2), grads)
             grads_scattered = ops.index_update(grads_scattered,
                                                (index2, index1), grads)
             # Diagonal entries k(x, x) and their gradients.
             diag_values, diag_grads = vmap(lambda x: kernel_fn(x, x))(X)
             diag_indices = np.diag_indices(n)
             values_scattered = ops.index_update(values_scattered,
                                                 diag_indices, diag_values)
             grads_scattered = ops.index_update(grads_scattered,
                                                diag_indices, diag_grads)
             return values_scattered, grads_scattered
     else:
         # Cross-kernel: nested vmap returns the (values, grads) tuple.
         return vmap(lambda x: vmap(lambda y: kernel_fn(x, y))(Y))(X)
Example #9
0
def vec_to_tril_matrix(t, diagonal=0):
    """Unpack the last axis of ``t`` into (batched) lower-triangular matrices.

    The last axis of ``t`` holds the entries at ``jnp.tril_indices(n,
    diagonal)`` in row-major order. Leading axes are batch dimensions. The
    result has shape ``t.shape[:-1] + (n, n)`` and the dtype of ``t``.
    """
    # NB: the following formula only works for diagonal <= 0
    n = round((math.sqrt(1 + 8 * t.shape[-1]) - 1) / 2) - diagonal
    n2 = n * n
    # Flat (row-major) positions of the selected triangle inside an n*n vector.
    idx = jnp.reshape(jnp.arange(n2), (n, n))[jnp.tril_indices(n, diagonal)]
    # lax.scatter_add requires operand and updates to share a dtype, so the
    # zero buffer must follow ``t``. The default float dtype would reject,
    # for example, integer or float16 inputs.
    x = lax.scatter_add(
        jnp.zeros(t.shape[:-1] + (n2,), dtype=t.dtype),
        jnp.expand_dims(idx, axis=-1),
        t,
        lax.ScatterDimensionNumbers(
            update_window_dims=tuple(range(t.ndim - 1)),
            inserted_window_dims=(t.ndim - 1,),
            scatter_dims_to_operand_dims=(t.ndim - 1,)))
    return jnp.reshape(x, x.shape[:-1] + (n, n))
Example #10
0
 def update(grads_and_lb, key):
     """Accumulate one Monte Carlo sample into the running gradient sums.

     ``grads_and_lb`` is ``(grad_mu, grad_lower, lower_bound)``. A sample is
     drawn with ``sample(key, variational_params)``. Then ``grad_mu``
     accumulates the flattened log-joint gradient, ``grad_lower`` accumulates
     the lower triangle of ``outer(grad_logjoint, epsilon)``, and
     ``lower_bound`` accumulates ``logjoint(params, data)``.

     Returns ``(new_carry, None)``, the ``(carry, output)`` pair shape that a
     ``lax.scan`` body returns.
     """
     grad_mu, grad_lower, lower_bound = grads_and_lb
     params, epsilon = sample(key, variational_params)
     grad_logjoint = take_grad(params, data)
     # jax.tree_util.tree_map maps over several trees; jax.tree_multimap has
     # been removed from JAX.
     grad_mu = jax.tree_util.tree_map(lambda x, y: x + y.flatten(), grad_mu,
                                      grad_logjoint)
     outer_products = jax.tree_util.tree_map(jnp.outer, grad_logjoint, epsilon)
     grad_lower = jax.tree_util.tree_map(
         lambda x, y: x + y[jnp.tril_indices(len(y))], grad_lower,
         outer_products)
     lower_bound = lower_bound + logjoint(params, data)
     return (grad_mu, grad_lower, lower_bound), None
Example #11
0
    def grad(self, eta, phi):
        """Return nabla mu and nabla omega as one concatenated vector.

        The result is ``grad_mu`` followed by the lower triangle (diagonal
        included, row-major order of ``jnp.tril_indices(self.latent_dim)``)
        of the gradient with respect to the factor ``L``. ``jnp.append``
        flattens both pieces.
        """
        zeta = self.inv_S(eta, phi)
        theta = self.inv_T(zeta)

        joint_grad = self.grad_joint(theta)
        inv_transform_jac = self.jac_T(zeta)
        log_det_grad = self.grad_det_J(zeta)

        grad_mu = inv_transform_jac @ joint_grad + log_det_grad
        # NOTE(review): ``grad_mu * eta`` is an elementwise product. If both are
        # 1-D vectors, a full-rank gradient w.r.t. L would normally use
        # ``jnp.outer(grad_mu, eta)``. Confirm the intended shapes.
        lower = jnp.tril_indices(self.latent_dim)
        grad_L = (grad_mu * eta + self.inv_L(phi).T)[lower]

        return jnp.append(grad_mu, grad_L)
Example #12
0
def tril_factory(diag, tril):
    """Parameterizes a n x n lower-triangular matrix by its diagonal and
    off-diagonal entries, both presented as vectors.

    Args:
        diag: Vector of the diagonal entries of the matrix; ``n = diag.size``.
        tril: Vector of the strictly-lower entries, in the row-major order of
            ``jnp.tril_indices(n, -1)``.

    Returns:
        L: Lower triangular matrix.

    """
    n = diag.size
    # ``.at[...].set(...)`` replaces the removed ``jax.ops.index_update``.
    strictly_lower = jnp.zeros((n, n)).at[jnp.tril_indices(n, -1)].set(tril)
    return strictly_lower + jnp.diag(diag)
def chol_params_to_lower_triangular_matrix_jax(params):
    """Build the lower-triangular matrix whose lower triangle is ``params``.

    ``params`` fills ``jnp.tril_indices(dim)`` (row-major, diagonal
    included). ``dim`` is recovered from ``len(params)``.
    """
    dim = number_of_triangular_elements_to_dimension_jax(len(params))
    # ``.at[...].set(...)`` replaces the removed ``jax.ops.index_update``.
    return jnp.zeros((dim, dim)).at[jnp.tril_indices(dim)].set(params)
def sdcorr_to_internal(external_values):
    """Convert sdcorr params to a covariance matrix and Cholesky-reparametrize it.

    Returns the lower triangle (diagonal included, row-major) of the Cholesky
    factor of the covariance built by ``sdcorr_params_to_matrix_jax``.
    """
    covariance = sdcorr_params_to_matrix_jax(external_values)
    factor = jnp.linalg.cholesky(covariance)
    rows, cols = jnp.tril_indices(len(covariance))
    return factor[rows, cols]
def cov_matrix_to_sdcorr_params_jax(cov):
    """Convert a covariance matrix to sdcorr parameters.

    The result is the standard deviations followed by the strictly
    lower-triangular correlations in row-major ``tril_indices(k=-1)`` order.
    """
    sds, corr = cov_to_sds_and_corr_jax(cov)
    rows, cols = jnp.tril_indices(len(cov), k=-1)
    return jnp.hstack([sds, corr[rows, cols]])
def covariance_from_internal(internal_values):
    """Undo a cholesky reparametrization.

    Rebuilds the factor ``L`` from its lower-triangle parameters, forms
    ``L @ L.T`` and returns that matrix's lower triangle (diagonal included,
    row-major).
    """
    factor = chol_params_to_lower_triangular_matrix_jax(internal_values)
    rows, cols = jnp.tril_indices(len(factor))
    return (factor @ factor.T)[rows, cols]
Example #17
0
def vb_gauss_chol(key,
                  loglikelihood_fn,
                  logprior_fn,
                  data,
                  optimizer,
                  mean,
                  lower_triangular=None,
                  num_samples=20,
                  window_size=10,
                  niters=500,
                  eps=0.1,
                  smooth=True):
    """Fit a Gaussian variational distribution with a Cholesky-factored covariance.

    Runs ``niters`` iterations inside ``jax.lax.scan``. Each iteration calls
    the step function from ``make_vb_gauss_chol_fns``, clips the gradients,
    and applies ``optimizer`` to ``(mean, vectorised lower triangle)``.

    Arguments:
      key: PRNG key, split into one subkey per iteration.
      loglikelihood_fn, logprior_fn: forwarded to ``make_vb_gauss_chol_fns``.
      data: forwarded to the step function at every iteration.
      optimizer: object exposing ``init`` and ``update`` (optax-style).
      mean: pytree of initial mean vectors. ``x.shape[0]`` of each leaf is
        that block's dimension.
      lower_triangular: optional pytree of initial lower-triangular factors.
        Defaults to ``eps * I`` per leaf.
      num_samples: number of Monte Carlo samples; also divides the
        accumulated lower bound.
      window_size: moving-average window used when ``smooth`` is True.
      niters: number of optimisation iterations.
      eps: scale of the default identity initialisation.
      smooth: whether to smooth the lower-bound trace before picking the best
        iterate.

    Returns:
      ``(best_params, lower_bounds)``. ``best_params`` is the
      ``(mean, lower_triangular)`` pytree from the iteration with the highest
      (possibly smoothed) lower bound. ``lower_bounds`` is the trace for the
      first leaf.
    """
    # jax.tree_util.* replaces the removed jax.tree_map / jax.tree_multimap /
    # jax.tree_leaves aliases; tree_map accepts several trees.
    nfeatures = jax.tree_util.tree_map(lambda x: x.shape[0], mean)

    if lower_triangular is None:
        # initializes the lower triangular matrices
        lower_triangular = jax.tree_util.tree_map(lambda n: eps * jnp.eye(n),
                                                  nfeatures)

    # Initialize parameters of the model + optimizer.
    variational_params = (mean, lower_triangular)

    # The optimizer sees each factor's lower triangle as a column vector.
    params = (mean,
              jax.tree_util.tree_map(
                  lambda L, n: L[jnp.tril_indices(n)][..., None],
                  lower_triangular, nfeatures))

    opt_state = optimizer.init(params)

    step_fn = make_vb_gauss_chol_fns(loglikelihood_fn, logprior_fn, nfeatures,
                                     num_samples)

    def iter_fn(all_params, key):
        variational_params, params, opt_state = all_params
        grads = init_grads(nfeatures)

        grads, lower_bound = step_fn(key, variational_params, grads, data)

        # Give 1-D gradients a trailing axis to match the column-shaped params.
        grads = jax.tree_util.tree_map(
            lambda x: x[..., None] if x.ndim == 1 else x, grads)
        grads = clip(grads)

        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        mean, std = params
        variational_params = (mean,
                              jax.tree_util.tree_map(vechinv, std, nfeatures))
        # log det(L @ L.T) for each factor.
        log_det_cov = jax.tree_util.tree_map(
            lambda L: jnp.log(jnp.linalg.det(L @ L.T)), variational_params[1])

        lb = jax.tree_util.tree_map(
            lambda chol, n: lower_bound / num_samples + 1 / 2 * chol + n / 2,
            log_det_cov, nfeatures)

        return (variational_params, params, opt_state), (variational_params,
                                                         lb)

    keys = jax.random.split(key, niters)
    _, (variational_params,
        lower_bounds) = jax.lax.scan(iter_fn,
                                     (variational_params, params, opt_state),
                                     keys)
    lower_bounds = jax.tree_util.tree_leaves(lower_bounds)[0]

    if smooth:

        def simple_moving_average(cur_sum, i):
            # Rolling mean: add the newest value, drop the oldest one.
            diff = (lower_bounds[i] -
                    lower_bounds[i - window_size]) / window_size
            cur_sum += diff
            return cur_sum, cur_sum

        indices = jnp.arange(window_size, niters)
        cur_sum = jnp.sum(lower_bounds[:window_size]) / window_size
        _, lower_bounds = jax.lax.scan(simple_moving_average, cur_sum, indices)
        lower_bounds = jnp.append(jnp.array([cur_sum]), lower_bounds)

    # Smoothed entry j averages iterations j .. j + window_size - 1; map it
    # back to the last iteration of that window.
    i = jnp.argmax(lower_bounds) + window_size - 1 if smooth else jnp.argmax(
        lower_bounds)
    best_params = jax.tree_util.tree_map(lambda x: x[i], variational_params)

    return best_params, lower_bounds
Example #18
0
def vechinv(v, d):
    """Inverse half-vectorisation: place ``v`` into a ``d x d`` lower triangle.

    ``v`` is squeezed first, so a column vector is accepted. Its entries fill
    ``jnp.tril_indices(d)`` (row-major, diagonal included); the rest is zero.
    """
    # ``.at[...].set(...)`` replaces the removed ``jax.ops.index_update``.
    return jnp.zeros((d, d)).at[jnp.tril_indices(d, k=0)].set(v.squeeze())
Example #19
0
 def L(self, phi: jnp.ndarray) -> jnp.ndarray:
     """Unpack the lower-triangular factor stored in ``phi``.

     The first ``latent_dim`` entries of ``phi`` are skipped. The remaining
     entries fill ``jnp.tril_indices(latent_dim)`` (row-major, diagonal
     included) of a ``latent_dim x latent_dim`` zero matrix.
     """
     # ``jnp.DeviceArray`` and ``index_update`` were removed from JAX;
     # ``jnp.ndarray`` and ``.at[...].set(...)`` are the supported forms.
     dim = self.latent_dim
     return jnp.zeros((dim, dim)).at[jnp.tril_indices(dim)].set(phi[dim:])
Example #20
0
def flatten_scale(scale):
    """Flatten a lower-triangular scale matrix into an unconstrained vector.

    The diagonal is replaced by its logarithm, then the lower triangle
    (diagonal included) is read out in row-major ``tril_indices`` order.
    """
    dim = scale.shape[-1]
    diag_idx = jnp.diag_indices(dim)
    log_scale = scale.at[diag_idx].set(jnp.log(jnp.diag(scale)))
    rows, cols = jnp.tril_indices(dim)
    return log_scale[rows, cols]
Example #21
0
def pair_vectors(vs):
    """Concatenate ``(vs[i], vs[j])`` for every index pair with ``i > j``.

    Pairs follow ``jnp.tril_indices(num_vs, k=-1)`` order. For ``vs`` of shape
    ``(num_vs, dim)`` the result has shape ``(num_vs*(num_vs-1)//2, 2*dim)``.
    """
    num_vs, _ = vs.shape  # unpacking also enforces a 2-D input
    first, second = jnp.tril_indices(num_vs, k=-1)
    return jnp.concatenate([vs[first], vs[second]], axis=1)
def covariance_to_internal(external_values):
    """Do a cholesky reparametrization.

    Builds the covariance matrix with ``cov_params_to_matrix_jax`` and returns
    the lower triangle (diagonal included, row-major) of its Cholesky factor.
    """
    covariance = cov_params_to_matrix_jax(external_values)
    factor = jnp.linalg.cholesky(covariance)
    rows, cols = jnp.tril_indices(len(covariance))
    return factor[rows, cols]
def cov_matrix_to_params_jax(cov):
    """Return the lower triangle of ``cov`` (diagonal included) as a vector."""
    rows, cols = jnp.tril_indices(len(cov))
    return cov[rows, cols]
Example #24
0
def vec_to_tril(vec):
    """Unpack a vector of ``d*(d+1)/2`` entries into a ``d x d`` lower triangle.

    Entries fill ``jnp.tril_indices(d)`` (row-major, diagonal included), and
    ``d`` is recovered from ``vec.shape[0]``. The buffer is requested as
    float64, which JAX downcasts to float32 unless 64-bit mode is enabled.
    """
    d = int((np.sqrt(1 + 8 * vec.shape[0]) - 1) / 2)
    idx_l = jnp.tril_indices(d)
    # ``.at[...].set(...)`` replaces the removed ``jax.ops.index_update``.
    return jnp.zeros((d, d), dtype=jnp.float64).at[idx_l].set(vec)
Example #25
0
def matrix_to_tril_vec(x, diagonal=0):
    """Read out the lower triangle of the last two axes of ``x`` as a vector.

    ``diagonal`` is passed to ``np.tril_indices`` as its offset. Leading axes
    are kept as batch dimensions.
    """
    rows, cols = np.tril_indices(x.shape[-1], diagonal)
    return x[..., rows, cols]
Example #26
0
def tril_to_vec(L):
    """Return the lower triangle of square matrix ``L`` (diagonal included)."""
    rows, cols = jnp.tril_indices(L.shape[0])
    return L[rows, cols]
Example #27
0
def unflatten_scale(flat_scale, original_dim):
    """Rebuild a lower-triangular scale matrix from its flattened form.

    ``flat_scale`` fills ``jnp.tril_indices(original_dim)`` (row-major,
    diagonal included) of a zero matrix with ``flat_scale``'s dtype. The
    diagonal is then exponentiated.
    """
    lower = jnp.tril_indices(original_dim)
    filled = jnp.zeros((original_dim, original_dim),
                       dtype=flat_scale.dtype).at[lower].set(flat_scale)
    diag_idx = jnp.diag_indices(original_dim)
    return filled.at[diag_idx].set(jnp.exp(filled[diag_idx]))
Example #28
0
#     k += 1
#     old_f0 = f0
#     toc_it = time()
#     lls.append(f0)
#     grs.append((gr_mu_norm, gr_sig_norm))

# toc = time()
# spent_riem = toc - tic
# lls = jnp.array(lls)
########################################

########################################
## Cholesky gradient descent
# Start the timer and build the initial point: the lower triangle of
# chol(startsig) (row-major, diagonal included) followed by startmu, the
# same layout func_chol splits with x[:-p] / x[-p:].
tic = time()
init_chol = jnp.append(
    jnp.linalg.cholesky(startsig)[jnp.tril_indices(p)], startmu)
# JIT-compiled gradient of the Cholesky-parameterised objective.
gra_chol = jit(grad(func_chol))

# Histories of objective values and gradient norms, seeded at init_chol.
chol_fun = [func_chol(init_chol)]
chol_gra = [jnp.linalg.norm(gra_chol(init_chol))]


def store(X):
    """Append the objective value and gradient norm at ``X`` to the histories.

    Extends the module-level ``chol_fun`` and ``chol_gra`` lists.
    """
    objective_value = func_chol(X)
    chol_fun.append(objective_value)
    gradient_norm = jnp.linalg.norm(gra_chol(X))
    chol_gra.append(gradient_norm)


res = minimize(func_chol,
               init_chol,
               method='newton-cg',
               jac=gra_chol,
Example #29
0
def func_chol(x):
    """Evaluate ``func`` on parameters stored as (Cholesky triangle, mean).

    ``x[:-p]`` is the lower triangle (row-major, diagonal included) of a
    ``p x p`` factor ``L`` and ``x[-p:]`` is the mean. Returns
    ``func(mu, L @ L.T)``.
    """
    chol, mu = x[:-p], x[-p:]
    # ``.at[...].set(...)`` replaces the removed ``jax.ops.index_update``.
    lower = jnp.zeros(shape=(p, p)).at[jnp.tril_indices(p)].set(chol)
    # einsum 'ij,kj' computes lower @ lower.T.
    sig = jnp.einsum('ij,kj', lower, lower)
    return func(mu, sig)
Example #30
0
def run_optim(key: np.ndarray, lhs: np.ndarray, tmp: np.ndarray,
              xhats: np.ndarray, tmp_c: np.ndarray, xhats_c: np.ndarray,
              xstar: float, bound: Text, out_dir: Text, x: np.ndarray,
              y: np.ndarray) -> Tuple[int, float, float, int, float, float]:
    """Run optimization (either lower or upper) for a single xstar.

    Performs ``FLAGS.num_rounds`` outer rounds. Each round runs
    ``FLAGS.opt_steps`` SGD steps on ``lagrangian_value_and_grad`` with the
    current multipliers ``lmbda`` and parameter ``tau``. It then updates
    ``lmbda`` via ``update_lambda`` and multiplies ``tau`` by
    ``FLAGS.tau_factor``, capped at ``FLAGS.tau_max``.

    Args:
      key: PRNG key, split for every stochastic evaluation.
      lhs: constraint left-hand side. It scales the slack and is compared
        against each recorded ``rhs``.
      tmp, xhats: forwarded to ``lagrangian_value_and_grad`` (inner steps).
      tmp_c, xhats_c: forwarded to ``objective_rhs_psisum_constr`` (the
        evaluations that are recorded).
      xstar: evaluation point; also part of the output directory name.
      bound: ``"lower"`` selects sign +1; any other value selects -1.
      out_dir: base directory; results go to ``{bound}-xstar_{xstar}`` under it.
      x, y: only passed on to ``plotting.plot_all``.

    Returns:
      ``(fin_i, fin_obj, fin_maxabsdiff, sat_i, sat_obj, sat_maxabsdiff)``.
      These are the index, objective and max absolute ``lhs - rhs`` gap of
      the final non-NaN evaluation. The last three give the same for the last
      evaluation where every constraint was satisfied, or
      ``(-1, nan, nan)`` if none was.
    """
    # Directory setup
    # ---------------------------------------------------------------------------
    out_dir = os.path.join(out_dir, f"{bound}-xstar_{xstar}")
    if FLAGS.store_data:
        logging.info(f"Current run output directory: {out_dir}...")
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    # Init optim params
    # ---------------------------------------------------------------------------
    logging.info(
        f"Initialize parameters L, mu, log_sigma, lmbda, tau, slack...")
    key, subkey = random.split(key)
    params = init_params(subkey)

    for parname, param in zip(['L', 'mu', 'log_sigma'], params):
        logging.info(f"Parameter {parname}: {param.shape}")
        logging.info(f"  -> {parname}: {param}")

    tau = FLAGS.tau_init
    logging.info(f"Initial tau = {tau}")
    # Only logged: the tau value reached after num_rounds multiplications.
    fin_tau = np.minimum(FLAGS.tau_factor**FLAGS.num_rounds * tau,
                         FLAGS.tau_max)
    logging.info(f"Final tau = {fin_tau}")

    # Set constraint approach and slacks
    # ---------------------------------------------------------------------------
    # Two constraints per z (2 * num_z), each with its own slack and multiplier.
    slack = FLAGS.slack * np.ones(FLAGS.num_z * 2)
    lmbda = np.zeros(FLAGS.num_z * 2)
    logging.info(f"Lambdas: {lmbda.shape}")

    logging.info(
        f"Fractional tolerance (slack) for constraints = {FLAGS.slack}")
    logging.info(f"Set relative slack variables...")
    # Slack becomes relative to |lhs| and is floored at FLAGS.slack_abs.
    slack *= np.abs(lhs.ravel())
    logging.info(f"Set minimum slack to {FLAGS.slack_abs}...")
    slack = np.maximum(FLAGS.slack_abs, slack)
    logging.info(f"Slack {slack.shape}")
    logging.info(f"Actual slack min: {np.min(slack)}, max: {np.max(slack)}")

    # Setup optimizer
    # ---------------------------------------------------------------------------
    logging.info(f"Vanilla SGD with init_lr={FLAGS.lr}...")
    logging.info(f"Set learning rate schedule")
    step_size = optim.inverse_time_decay(FLAGS.lr, FLAGS.decay_steps,
                                         FLAGS.decay_rate, FLAGS.staircase)
    init_fun, update_fun, get_params = optim.sgd(step_size)

    logging.info(
        f"Init state for JAX optimizer (including L, mu, log_sigma)...")
    state = init_fun(params)

    # Setup result dict
    # ---------------------------------------------------------------------------
    # "objective", "constraint_term" and "rhs" receive num_rounds + 1 entries
    # (one before the loop, one per round); the per-round parameter logs
    # receive num_rounds entries.
    logging.info(f"Initialize dictionary for results...")
    results = {
        "mu": [],
        "sigma": [],
        "cholesky_factor": [],
        "tau": [],
        "lambda": [],
        "objective": [],
        "constraint_term": [],
        "rhs": []
    }
    if FLAGS.plot_intermediate:
        results["grad_norms"] = []
        results["lagrangian"] = []

    logging.info(f"Evaluate at xstar={xstar}...")

    logging.info(f"Evaluate {bound} bound...")
    sign = 1 if bound == "lower" else -1

    # ===========================================================================
    # OPTIMIZATION LOOP
    # ===========================================================================
    # One-time logging before first step
    # ---------------------------------------------------------------------------
    key, subkey = random.split(key)
    obj, rhs, psisum, constr = objective_rhs_psisum_constr(
        subkey, get_params(state), lmbda, tau, lhs, slack, xstar, tmp_c,
        xhats_c)
    results["objective"].append(obj)
    results["constraint_term"].append(psisum)
    results["rhs"].append(rhs)

    logging.info(f"Objective: scalar")
    logging.info(f"RHS: {rhs.shape}")
    logging.info(f"Sum over Psis: scalar")
    logging.info(f"Constraint: {constr.shape}")

    # Lower-triangle indices of the (dim_theta + 1)-square Cholesky factor;
    # the first selected entry is dropped when logging (``[1:]`` below).
    tril_idx = np.tril_indices(FLAGS.dim_theta + 1)
    # Global SGD step counter, passed to update_fun for the lr schedule.
    count = 0
    logging.info(f"Start optimization loop...")
    for _ in tqdm(range(FLAGS.num_rounds)):
        # log current parameters
        # -------------------------------------------------------------------------
        results["lambda"].append(lmbda)
        results["tau"].append(tau)
        cur_L, cur_mu, cur_logsigma = get_params(state)
        cur_chol = make_cholesky_factor(cur_L)[tril_idx].ravel()[1:]
        results["mu"].append(cur_mu)
        results["sigma"].append(np.exp(cur_logsigma))
        results["cholesky_factor"].append(cur_chol)

        # subkeys[0] carries the chain forward; one fresh subkey per inner step.
        subkeys = random.split(key, num=FLAGS.opt_steps + 1)
        key = subkeys[0]
        # inner optimization for subproblem
        # -------------------------------------------------------------------------
        for j in range(FLAGS.opt_steps):
            v, grads = lagrangian_value_and_grad(subkeys[j + 1],
                                                 get_params(state), lmbda, tau,
                                                 lhs, slack, xstar, tmp, xhats,
                                                 sign)
            state = update_fun(count, grads, state)
            count += 1
            if FLAGS.plot_intermediate:
                results["lagrangian"].append(v)
                results["grad_norms"].append(
                    [np.linalg.norm(grad) for grad in grads])

        # post inner optimization logging
        # -------------------------------------------------------------------------
        key, subkey = random.split(key)
        obj, rhs, psisum, constr = objective_rhs_psisum_constr(
            subkey, get_params(state), lmbda, tau, lhs, slack, xstar, tmp_c,
            xhats_c)
        results["objective"].append(obj)
        results["constraint_term"].append(psisum)
        results["rhs"].append(rhs)

        # update lambda, tau
        # -------------------------------------------------------------------------
        lmbda = update_lambda(constr, lmbda, tau)
        tau = np.minimum(tau * FLAGS.tau_factor, FLAGS.tau_max)

    # Convert and store results
    # ---------------------------------------------------------------------------
    logging.info(f"Finished optimization loop...")

    logging.info(f"Convert all results to numpy arrays...")
    results = {k: np.array(v) for k, v in results.items()}

    logging.info(f"Add final parameters and lhs to results...")
    L, mu, log_sigma = get_params(state)
    results["final_L"] = L
    results["final_mu"] = mu
    results["final_log_sigma"] = log_sigma
    results["lhs"] = lhs

    if FLAGS.store_data:
        # NOTE(review): the log message omits the path (result_path).
        logging.info(f"Save result data to...")
        result_path = os.path.join(out_dir, "results.npz")
        onp.savez(result_path, **results)

    # Generate and store plots
    # ---------------------------------------------------------------------------
    # NOTE(review): out_dir is only created when FLAGS.store_data is set, and
    # ``response`` is not defined in this function (presumably a module-level
    # global). Confirm both before relying on plot_intermediate alone.
    if FLAGS.plot_intermediate:
        fig_dir = os.path.join(out_dir, "figures")
        logging.info(f"Generate and save all plots at {fig_dir}...")
        plotting.plot_all(results, x, y, response, fig_dir)

    # Compute last valid and last satisfied
    # ---------------------------------------------------------------------------
    maxabsdiff = np.array([np.max(np.abs(lhs - r)) for r in results["rhs"]])
    # Count of non-NaN objectives minus one: this equals the index of the last
    # non-NaN objective only if NaNs appear exclusively at the end.
    fin_i = np.sum(~np.isnan(results["objective"])) - 1
    logging.info(f"Final non-nan objective at {fin_i}.")
    fin_obj = results["objective"][fin_i]
    fin_maxabsdiff = maxabsdiff[fin_i]

    # A constraint counts as satisfied if it is within the relative tolerance
    # OR the absolute tolerance. Where lhs == 0 the relative test yields
    # inf/nan (False), so only the absolute test can pass there.
    sat_i = [
        np.all((np.abs((lhs - r) / lhs) < FLAGS.slack)
               | (np.abs(lhs - r) < FLAGS.slack_abs)) for r in results["rhs"]
    ]
    sat_i = np.where(sat_i)[0]

    if len(sat_i) > 0:
        sat_i = sat_i[-1]
        logging.info(f"Final satisfied constraint at {sat_i}.")
        sat_obj = results["objective"][sat_i]
        sat_maxabsdiff = maxabsdiff[sat_i]
    else:
        sat_i = -1
        logging.info(f"Constraints were never satisfied.")
        sat_obj, sat_maxabsdiff = np.nan, np.nan

    logging.info("Finished run.")
    return fin_i, fin_obj, fin_maxabsdiff, sat_i, sat_obj, sat_maxabsdiff