Example #1
0
def cov_estimate(*, optimization_path: Sequence[jnp.ndarray],
                 optimization_path_grads: Sequence[jnp.ndarray], history: int):
    """Estimate covariance from an optimization path.

    For every consecutive pair of points on the path this records the
    inverse-Hessian factors that are current *before* the pair is processed.

    Args:
        optimization_path: Visited positions, each a 1-D array of equal length.
        optimization_path_grads: Gradients at those positions, same length and
            order as `optimization_path`.
        history: Maximum number of difference columns retained for the
            low-rank part.

    Returns:
        A list with one `(diagonal_estimate, thin_factors,
        scaling_outer_product)` tuple per step of the path
        (`len(optimization_path) - 1` entries).
    """
    dim = optimization_path[0].shape[0]
    position_diffs = jnp.empty((dim, 0))
    gradient_diffs = jnp.empty((dim, 0))
    approximations: List[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]] = []
    diagonal_estimate = jnp.ones(dim)
    for j in range(len(optimization_path) - 1):
        _, thin_factors, scaling_outer_product = bfgs_inverse_hessian(
            updates_of_position_differences=position_diffs,
            updates_of_gradient_differences=gradient_diffs,
        )
        position_diff = optimization_path[j + 1] - optimization_path[j]
        gradient_diff = optimization_path_grads[j] - optimization_path_grads[
            j + 1]
        b = position_diff @ gradient_diff
        gradient_diff_sq = gradient_diff**2
        new_diagonal_estimate = diagonal_estimate
        # Only accept the pair when the curvature term `b` is safely positive.
        # The update below divides by `b`, so the guard must be `>` (the
        # previous `<` admitted exactly the tiny/negative values it is meant to
        # reject, and disagreed with the check used in `lbfgs`).
        if b > 1e-12 * jnp.sum(gradient_diff_sq):
            position_diffs = jnp.column_stack(
                (position_diffs[:, -history + 1:], position_diff))
            gradient_diffs = jnp.column_stack(
                (gradient_diffs[:, -history + 1:], gradient_diff))
            a = gradient_diff @ (diagonal_estimate * gradient_diff)
            c = position_diff @ (position_diff / diagonal_estimate)
            new_diagonal_estimate = 1.0 / (a / (b * diagonal_estimate) +
                                           gradient_diff_sq / b -
                                           (a * position_diff**2) /
                                           (b * c * diagonal_estimate**2))
        # The stored diagonal is the one in effect before this step's update.
        approximations.append(
            (diagonal_estimate, thin_factors, scaling_outer_product))
        diagonal_estimate = new_diagonal_estimate
    return approximations
Example #2
0
 def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
     """Create an empty COO instance. Public method is sparse.empty()."""
     shape = tuple(shape)
     if len(shape) != 2:
         raise ValueError(f"COO must have ndim=2; got shape={shape}")
     data = jnp.empty(0, dtype)
     row = col = jnp.empty(0, index_dtype)
     return cls((data, row, col), shape=shape)
Example #3
0
File: csr.py Project: jbampton/jax
 def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
     """Create an empty CSR instance. Public method is sparse.empty()."""
     shape = tuple(shape)
     if len(shape) != 2:
         raise ValueError(f"CSR must have ndim=2; got shape={shape}")
     data = jnp.empty(0, dtype)
     indices = jnp.empty(0, index_dtype)
     indptr = jnp.zeros(shape[0] + 1, index_dtype)
     return cls((data, indices, indptr), shape=shape)
Example #4
0
def lbfgs(
    *,
    log_target_density: Callable[[jnp.ndarray], jnp.ndarray],
    initial_value: jnp.ndarray,  # theta_init
    inverse_hessian_history: int = 6,  # J
    relative_tolerance: float = 1e-13,  # tau_rel
    max_iters: int = 1000,  # L
    wolfe_bounds: Tuple[float, float] = (1e-4, 0.9),
    positivity_threshold: float = 2.2e-16,
):
    """LBFGS implementation which returns the optimization path and gradients.

    Args:
        log_target_density: Scalar function being maximised; differentiated
            with `jax.grad`.
        initial_value: 1-D starting point.
        inverse_hessian_history: Maximum number of difference columns kept for
            the inverse-Hessian approximation.
        relative_tolerance: Stop once the relative gain in the objective drops
            below this value.
        max_iters: Upper bound on the number of outer iterations.
        wolfe_bounds: The two constants used in the sufficient-increase and
            curvature checks of the backtracking line search.
        positivity_threshold: A difference pair is stored only when
            `position_diff @ grad_diff` exceeds this multiple of
            `sum(grad_diff**2)`.

    Returns:
        `(optimization_path, grad_optimization_path)`: two lists of equal
        length holding every visited point and the gradient at that point.
    """
    dim = initial_value.shape[0]
    grad_log_density = jax.grad(log_target_density)
    optimization_path = [initial_value]
    current_lp = log_target_density(initial_value)
    grad_optimization_path = [grad_log_density(initial_value)]
    position_diffs = jnp.empty((dim, 0))
    gradient_diffs = jnp.empty((dim, 0))
    for _ in range(max_iters):
        diagonal_estimate, thin_factors, scaling_outer_product = bfgs_inverse_hessian(
            updates_of_position_differences=position_diffs,
            updates_of_gradient_differences=gradient_diffs,
        )
        grad_lp = grad_optimization_path[-1]
        search_direction = diagonal_estimate * grad_lp + thin_factors @ (
            scaling_outer_product @ (jnp.transpose(thin_factors) @ grad_lp))
        # Backtracking line search: halve the step until both checks pass.
        step_size = 1.0
        while step_size > 1e-8:
            proposed = optimization_path[-1] + step_size * search_direction
            proposed_lp = log_target_density(proposed)
            if proposed_lp >= current_lp + (wolfe_bounds[0] * grad_lp) @ (
                    step_size * search_direction):
                proposed_grad = grad_log_density(proposed)
                if (proposed_grad @ search_direction <=
                        wolfe_bounds[1] * grad_lp @ search_direction):
                    break
            step_size = 0.5 * step_size
        else:
            # The search ran out of step sizes without satisfying both checks.
            # `proposed` is the last point tried; make sure the gradient that
            # gets recorded belongs to that point rather than to an earlier
            # candidate (or being undefined on the first iteration).
            proposed_grad = grad_log_density(proposed)
        optimization_path.append(proposed)
        grad_optimization_path.append(proposed_grad)
        if (proposed_lp -
                current_lp) / jnp.abs(current_lp) < relative_tolerance:
            return optimization_path, grad_optimization_path
        current_lp = proposed_lp

        position_diff: jnp.ndarray = optimization_path[-1] - optimization_path[
            -2]
        grad_diff = -grad_optimization_path[-1] + grad_optimization_path[-2]
        # Keep only pairs with safely positive curvature, dropping the oldest
        # column once the history window is full.
        if position_diff @ grad_diff > positivity_threshold * jnp.sum(grad_diff
                                                                      **2):
            position_diffs = jnp.column_stack(
                (position_diffs[:,
                                -inverse_hessian_history + 1:], position_diff))
            gradient_diffs = jnp.column_stack(
                (gradient_diffs[:, -inverse_hessian_history + 1:], grad_diff))
    return optimization_path, grad_optimization_path
Example #5
0
 def _reset(self):
     """Rewind the step counter and allocate a fresh, empty trajectory buffer.

     Observations get one more leading slot than the per-step fields
     (`n_steps + 1` versus `n_steps`).
     """
     self._t = 0
     steps, batch = self.n_steps, self.batch_size
     # `jnp.empty` takes the whole shape as ONE argument; passing the
     # dimensions as separate positionals would bind them to `dtype` etc.
     per_step = (steps, batch, 1)
     self.trajectory = Trajectory(
         observations=jnp.empty(
             (steps + 1, batch, *self.observation_spec.shape)),
         actions=jnp.empty(per_step),
         rewards=jnp.empty(per_step),
         discounts=jnp.empty(per_step),
         trace_decays=jnp.empty(per_step),
     )
Example #6
0
 def new(cls, in_features: int, out_features: int, use_bias=True):
     """Construct an instance with freshly allocated (unset) parameters.

     The weight has shape `[out_features, in_features]`; the bias has shape
     `[out_features]`, or is None when `use_bias` is false.
     """
     return cls(
         in_features=in_features,
         out_features=out_features,
         use_bias=use_bias,
         weight=jnp.empty([out_features, in_features]),
         bias=jnp.empty([out_features]) if use_bias else None,
     )
Example #7
0
    def reset(self):
        """Put every tracked field back to its initial, empty value."""
        n_patients = self.num_patients
        group_shape = (0, n_patients)  # no rows yet, one column per patient
        flag_shape = (0, )

        self.curr_cycle = 0
        self.past_groups = np.empty(group_shape, dtype=bool)
        self.past_test_results = np.empty(flag_shape, dtype=bool)
        self.groups_to_test = np.empty(group_shape, dtype=bool)

        # Method-specific fields; some methods never use or fill them.
        self.particle_weights = None
        self.particles = None
        self.to_clear_positives = np.empty(flag_shape, dtype=bool)
        self.all_cleared = False

        # Marginals may be stored per computation method.
        self.marginals = {}
Example #8
0
 def apply_fun(params, inputs, **kwargs):
     """Extend `inputs` along axis 1 by wrapping around to its own start.

     For an input of shape `(batch, size, channels)` the result has shape
     `(batch, 2 * size - 1, channels)`: the input itself followed by its first
     `size - 1` entries along axis 1. `params` and `kwargs` are unused.

     Returns:
         The padded array, cast to `jnp.complex128`.
     """
     input_size = inputs.shape[1]
     wrapped = inputs[:, 0:input_size - 1, :]
     # Build the result directly instead of scattering into an uninitialised
     # buffer with `jax.ops.index_update`, which newer JAX no longer provides.
     outputs = jnp.concatenate((inputs, wrapped), axis=1)
     return outputs.astype(jnp.complex128)
Example #9
0
 def apply_fun(params, inputs, **kwargs):
     """Flatten a `(batch, size, channels)` array to `(batch, channels * size)`.

     The output is channel-major: columns `i * size:(i + 1) * size` hold
     `inputs[:, :, i]`. `params` and `kwargs` are unused.

     Returns:
         The flattened array, cast to `jnp.complex128`.
     """
     batch, input_size, num_channels = (inputs.shape[0], inputs.shape[1],
                                        inputs.shape[2])
     # Moving the channel axis in front of the size axis and reshaping yields
     # the same layout as writing one channel slab at a time, without the
     # Python loop over `jax.ops.index_update` (removed from newer JAX).
     channel_major = jnp.transpose(inputs, (0, 2, 1))
     outputs = channel_major.reshape((batch, num_channels * input_size))
     return outputs.astype(jnp.complex128)
Example #10
0
    def run(self,
            rng_key,
            num_steps,
            *args,
            return_last=True,
            progbar=True,
            **kwargs):
        """Initialise the state and apply `self.update` for `num_steps` steps.

        Args:
            rng_key: Random key forwarded to `self.init`.
            num_steps: Number of update steps to perform.
            *args: Forwarded to both `self.init` and `self.update`.
            return_last: When true return only the final loss, otherwise the
                whole per-step loss array.
            progbar: When true iterate in Python with a progress bar; otherwise
                run the steps through `fori_loop`.
            **kwargs: Forwarded to both `self.init` and `self.update`.

        Returns:
            `(svgd_state, loss_res)` where `loss_res` is the last loss or the
            array of all losses depending on `return_last`.
        """
        def bodyfn(i, info):
            svgd_state, losses = info
            svgd_state, loss = self.update(svgd_state, *args, **kwargs)
            losses = ops.index_update(losses, i, loss)
            return svgd_state, losses

        svgd_state = self.init(rng_key, *args, **kwargs)
        losses = np.empty((num_steps, ))
        if not progbar:
            svgd_state, losses = fori_loop(0, num_steps, bodyfn,
                                           (svgd_state, losses))
        else:
            # Wrap the step once; re-wrapping it on every iteration is
            # loop-invariant work.
            jitted_step = jax.jit(bodyfn)
            with tqdm.trange(num_steps) as t:
                for i in t:
                    svgd_state, losses = jitted_step(i, (svgd_state, losses))
                    t.set_description('SVGD {:.5}'.format(losses[i]),
                                      refresh=False)
                    # NOTE(review): iterating `t` may already advance the bar;
                    # confirm this extra update is intended.
                    t.update()
        loss_res = losses[-1] if return_last else losses
        return svgd_state, loss_res
Example #11
0
def _ndim_coords_from_arrays(points, ndim):
    """
    Convert a tuple of coordinate arrays to a (..., ndim)-shaped array.
    """
    if isinstance(points, tuple) and len(points) == 1:
        # handle argument tuple
        points = points[0]
    if isinstance(points, tuple):
        p = jnp.broadcast_arrays(*points)
        n = len(p)
        for j in range(1, n):
            if p[j].shape != p[0].shape:
                raise ValueError(
                    "coordinate arrays do not have the same shape")
        points = jnp.empty(p[0].shape + (len(points), ), dtype=float)
        for j, item in enumerate(p):
            points[..., j] = item
    else:
        points = jnp.asarray(points)
        if points.ndim == 1:
            if ndim is None:
                points = points.reshape(-1, 1)
            else:
                points = points.reshape(-1, ndim)
    return points
Example #12
0
  def get_groups(self, rng, state):
    """Produces random design matrix fixed number of 1s per line.

    One row is produced per extra test requested by the state
    (`state.extra_tests_needed`). For each row a fresh permutation of the
    patient indices is drawn and its first `group_size` entries are marked.
    When `self.group_size` is None the size is computed from the prior if the
    prior infection rate is a single value, and set to a maximal group size
    otherwise.

    Args:
     rng: np.ndarray<int>[2]: the random key.
     state: the current state.State of the system.

    Returns:
     A np.array<bool>[num_groups, patients].
    """
    if self.group_size is None:
      # if no size has been defined, we compute it adaptively
      # in the simple case where prior is uniform.
      if np.size(state.prior_infection_rate) == 1:
        group_size = np.ceil(
            (np.log(state.prior_sensitivity - .5) -
             np.log(state.prior_sensitivity + state.prior_specificity - 1)) /
            np.log(1 - state.prior_infection_rate))
        # Cap the computed size at the maximum carried by the state.
        group_size = np.minimum(group_size, state.max_group_size)
      # if prior is not uniform, pick max size.
      else:
        # NOTE(review): this reads `self.max_group_size` while the branch above
        # reads `state.max_group_size` -- confirm which object owns the cap.
        group_size = self.max_group_size
    else:
      group_size = self.group_size
    # Collapse to a plain Python int so it can be used as a slice bound.
    group_size = int(np.squeeze(group_size))
    new_groups = np.empty((0, state.num_patients), dtype=bool)
    for _ in range(state.extra_tests_needed):
      # Split off a dedicated key for this row's shuffle.
      rng, rng_shuffle = jax.random.split(rng, 2)
      vec = np.zeros((1, state.num_patients), dtype=bool)
      idx = jax.random.permutation(rng_shuffle, np.arange(state.num_patients))
      # Mark the first `group_size` shuffled patients in the single row.
      vec = jax.ops.index_update(vec, [0, idx[0:group_size]], True)
      new_groups = np.concatenate((new_groups, vec), axis=0)
    return new_groups
Example #13
0
def dtoq_reyes(data):
    """Map each scalar `x` in `data` to the pair `(1, x) / sqrt(x**2 + 1)`.

    Args:
        data: 1-D sequence of real numbers.

    Returns:
        Array of shape `(len(data), 2)`; row `i` is the unit-norm pair for
        `data[i]`. Empty input yields shape `(0, 2)`.
    """
    values = np.asarray(data, dtype=float)
    # One vectorised pass replaces stacking a new row per element, which
    # re-copied the growing result on every iteration.
    norms = np.sqrt(values**2 + 1.0)
    return np.column_stack((1.0 / norms, values / norms))
Example #14
0
def evolution_pepo_imag_time(g: float,
                             dt: float,
                             bc: str,
                             dtype: np.dtype,
                             lx: Optional[int] = None,
                             ly: Optional[int] = None) -> Operator:
    """Assemble the imaginary-time evolution PEPO for one step `dt`.

    Args:
        g: Coefficient of the single-site (vertex) term.
        dt: Imaginary-time step.
        bc: Boundary-condition identifier, forwarded unchanged.
        dtype: Element type of the bond tensor and the vertex operator.
        lx: Optional lattice extent, forwarded unchanged.
        ly: Optional lattice extent, forwarded unchanged.

    Returns:
        Whatever `_build_evolution_pepo` builds from the vertex operator and
        the half-bond tensor.
    """
    # PEPO for U(dt) ~ U_vert(dt/2) U_bond(dt) U_vert(dt/2)
    #
    # half bond operators:
    #
    #      |    |           |    |
    #      U_bond     =     A -- A
    #      |    |           |    |
    #
    # expm(- H_bond dt) = expm(- (-XX) dt) = expm(dt XX) = cosh(dt) + sinh(dt) XX = A_0 A_0 + A_1 A_1
    # with A_0 = (cosh(dt) ** 0.5) * 1  ,  A_1 = (sinh(dt) ** 0.5) * X
    # A & B legs: (p,p*,k)

    # Stack the two slices along the last (k) leg directly rather than
    # scattering them into an uninitialised buffer with `index_update`.
    a_0 = (np.cosh(dt)**0.5) * s0
    a_1 = (np.sinh(dt)**0.5) * sx
    A = np.stack([a_0, a_1], axis=2).astype(dtype)
    # expm(- H_vert dt/2) = expm(- (-gZ) dt/2) = expm(g dt/2 Z)
    u_vert = np.asarray(expm(g * (dt / 2) * sz), dtype=dtype)

    return _build_evolution_pepo(u_vert, A, bc, lx, ly)
Example #15
0
 def __do_rank_regression(self):
     """Build the table of adjusted order numbers and CDF estimates.

     Combines `self.failures` and `self.censored` into one time-sorted table,
     computes the rank increments, drops the censored rows and stores the
     result in `self.est_data`. Also sets `self.N` to the total row count.
     """
     # Failures: columns are (time, flag 0); sorted by time before the third
     # column (0..n-1, i.e. the position in sorted order) is attached.
     f = jnp.hstack((jnp.atleast_2d(self.failures).T,
                     jnp.zeros((self.failures.shape[0], 1))))
     f = f[f[:, 0].argsort()]
     f = jnp.hstack((f,
                     jnp.reshape(jnp.arange(self.failures.shape[0]),
                                 (self.failures.shape[0], -1))))
     # censored items will be having flag '1'
     c = jnp.hstack((jnp.atleast_2d(self.censored).T,
                     jnp.ones((self.censored.shape[0], 1))))
     # NOTE(review): the third column for censored rows comes from `jnp.empty`,
     # i.e. it is never assigned a meaningful value; these rows are dropped
     # further down, but `m` is computed before that -- confirm intended.
     c = jnp.hstack((c,
                     jnp.reshape(jnp.empty(self.censored.shape[0]),
                                 (self.censored.shape[0], -1))))
     # Merge both groups and order the combined table by time.
     d = jnp.concatenate((c, f), axis=0)
     d = d[d[:, 0].argsort()]
     df = pd.DataFrame(data=d, columns=['time', 'is_cens', 'fo'])
     self.N = len(df.index)
     df['new_increment'] = (self.N + 1 - df['fo']) / (self.N + 2 -
                                                      df.index.values)
     # Shift all increments so that the smallest one equals 1.
     m = 1.0 - df['new_increment'].min()
     df['new_increment'] = df['new_increment'] + m
     # Keep only the failure rows, then accumulate the increments.
     df = df.drop(df[df['is_cens'] == 1].index)
     df['new_order_num'] = df['new_increment'].cumsum()
     df['cdf'] = util.median_rank(self.N, df['new_order_num'], 0.5)
     self.est_data = df
  def get_groups(self, rng, state):
    """A greedy forward-backward algorithm to pick groups with large utility.

    Groups are formed one at a time until `state.extra_tests_needed` of them
    have been collected. Each group is grown by repeatedly asking
    `next_best_group` for an improved proposal, alternating a forward phase and
    a backtracking phase, and keeping a proposal only when it raises the
    objective by more than 1e-6.

    Args:
      rng: random key passed to the particle collapsing step.
      state: the current state; provides particles, their weights, the prior
        sensitivity/specificity, the maximal group size and the number of
        extra tests needed.

    Returns:
      A boolean array with one row per chosen group and one column per patient.
    """
    particle_weights, particles = mutual_information.collapse_particles(
        rng, state.particle_weights, state.particles)
    n_patients = particles.shape[1]
    # Step budgets for the forward phase and the backward phase, in that order.
    iterations = [self.forward_iterations, self.backward_iterations]

    chosen_groups = np.empty((0, n_patients), dtype=bool)
    added_groups_counter = 0
    while added_groups_counter < state.extra_tests_needed:
      # start forming a new group, and improve it greedily
      proposed_group = np.zeros((n_patients,), dtype=bool)
      obj_old = -1
      # NOTE(review): the `break` below only leaves the innermost `for`; this
      # `while` ends solely through its size condition -- confirm it always
      # terminates when no proposal improves the objective.
      while np.sum(proposed_group) < state.max_group_size:
        for steps, backtrack in zip(iterations, [False, True]):
          for _ in range(steps):
            # Extract candidate with largest utility
            proposed_group, obj_new = next_best_group(particle_weights,
                                                      particles,
                                                      chosen_groups,
                                                      proposed_group,
                                                      state.prior_sensitivity,
                                                      state.prior_specificity,
                                                      self.utility_fn,
                                                      backtracking=backtrack)
            if obj_new > obj_old + 1e-6:
              cur_group = proposed_group
              obj_old = obj_new
            else:
              break
      # stop adding, form next group
      # NOTE(review): `cur_group` is only bound once a proposal has been
      # accepted above.
      chosen_groups = np.concatenate((chosen_groups, cur_group[np.newaxis, :]),
                                     axis=0)
      added_groups_counter += 1
    return chosen_groups
Example #17
0
    def get_parameters(self):
        """Collect all variational parameters into one flat vector.

        Returns:
            1-D array with the current values of all variational parameters.
            With a single complex net and `self.holomorphic` set, the real
            parts are followed by the imaginary parts.
        """

        if not self.realNets:
            # Single complex net: flatten every leaf and join them.
            leaves = tree_flatten(self.net.params)[0]
            flat = jnp.concatenate([leaf.ravel() for leaf in leaves])
            if self.holomorphic:
                flat = jnp.concatenate([flat.real, flat.imag])
            return flat

        # Two real nets: write each flattened leaf into a preallocated vector,
        # first all leaves of net 0, then all leaves of net 1.
        flat = jnp.empty(self.numParameters, dtype=global_defs.tReal)
        offset = 0
        for net_index in (0, 1):
            leaves = tree_flatten(self.net[net_index].params)[0]
            for leaf in leaves:
                stop = offset + leaf.size
                flat = jax.ops.index_update(flat,
                                            jax.ops.index[offset:stop],
                                            leaf.reshape(-1))
                offset = stop
        return flat
Example #18
0
def _triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None):  # pylint: disable=redefined-outer-name
    """Scipy solve does not broadcast, so we must do so explicitly.

    Args:
        matrix: Triangular coefficient matrices; the last two axes are the
            matrix axes, leading axes are batch axes.
        rhs: Right-hand sides, batched the same way.
        lower: Forwarded to `solve_triangular`.
        adjoint: When true solve with `trans='C'`, otherwise `trans='N'`.
        name: Unused; discarded immediately.

    Returns:
        The solutions, shaped like the broadcast `rhs`.

    Raises:
        ValueError: In the non-JAX path, when `matrix` and `rhs` cannot be
            broadcast together.
    """
    del name
    if JAX_MODE:  # But JAX uses XLA, which can do a batched solve.
        # Adding zeros shaped like the other operand's batch axes expands both
        # operands to a common batch shape.
        matrix = matrix + np.zeros(rhs.shape[:-2] + (1, 1), dtype=matrix.dtype)
        rhs = rhs + np.zeros(matrix.shape[:-2] + (1, 1), dtype=rhs.dtype)
        return scipy_linalg.solve_triangular(matrix,
                                             rhs,
                                             lower=lower,
                                             trans='C' if adjoint else 'N')
    try:
        # Only the first column of `matrix` is needed to work out the common
        # batch shape against `rhs`.
        bcast = onp.broadcast(matrix[..., :1], rhs)
    except ValueError as e:
        raise ValueError(
            'Error with inputs shaped `matrix`={}, rhs={}:\n{}'.format(
                matrix.shape, rhs.shape, str(e)))
    dim = matrix.shape[-1]
    matrix = onp.broadcast_to(matrix, bcast.shape[:-1] + (dim, ))
    rhs = onp.broadcast_to(rhs, bcast.shape)
    # Collapse all batch axes into one so the systems can be solved one by one.
    nbatch = int(np.prod(matrix.shape[:-2]))
    flat_mat = matrix.reshape(nbatch, dim, dim)
    flat_rhs = rhs.reshape(nbatch, dim, rhs.shape[-1])
    # NOTE(review): allocated with the default dtype rather than the dtype of
    # the inputs -- confirm that is intended.
    result = np.empty(flat_rhs.shape)
    # Skip the solver entirely for empty inputs; it rejects them (see below).
    if np.size(result):
        # ValueError: On entry to STRTRS parameter number 7 had an illegal value.
        for i, (mat, rh) in enumerate(zip(flat_mat, flat_rhs)):
            result[i] = scipy_linalg.solve_triangular(
                mat, rh, lower=lower, trans='C' if adjoint else 'N')
    return result.reshape(*rhs.shape)
Example #19
0
 def _indices(key):
     """Draw `nse` index rows into `sparse_shape` using random key `key`.

     Returns an `(nse, n_sparse)` integer array without sampling when
     `sparse_shape` is empty.
     """
     if sparse_shape:
         # Sample flat positions, then expand them to one column per dimension.
         flat_positions = random.choice(key,
                                        sparse_size,
                                        shape=(nse, ),
                                        replace=not unique_indices)
         per_dim = jnp.unravel_index(flat_positions, sparse_shape)
         return jnp.column_stack(per_dim)
     return jnp.empty((nse, n_sparse), dtype=int)
 def _kernel_matrix_without_gradients(kernel_fn, theta, X, Y):
     """Evaluate the kernel matrix between the instances in `X` and `Y`.

     `theta` is bound as the first argument of `kernel_fn`. When `Y` is None or
     is the same object as `X`, only one triangle and the diagonal are
     evaluated and mirrored into a square matrix; otherwise every pair
     `(X[i], Y[j])` is evaluated. The config flag `KERNEL_MATRIX_USE_LOOP`
     selects a loop-based implementation over a `vmap`-based one.

     Returns:
         Array of shape `(len(X), len(X))` in the symmetric case, else one row
         per element of `X` and one column per element of `Y`.
     """
     kernel_fn = partial(kernel_fn, theta)
     if Y is None or (Y is X):
         if config_value('KERNEL_MATRIX_USE_LOOP'):
             n = len(X)
             with loops.Scope() as s:
                 # s.scattered_values = np.empty((n, n))
                 # k=0: lower triangle including the diagonal.
                 s.index1, s.index2 = np.tril_indices(n, k=0)
                 s.output = np.empty(len(s.index1))
                 for i in s.range(s.index1.shape[0]):
                     i1, i2 = s.index1[i], s.index2[i]
                     s.output = ops.index_update(s.output, i,
                                                 kernel_fn(X[i1], X[i2]))
             # Write the lower triangle, then mirror it into the upper one.
             first_update = ops.index_update(np.empty((n, n)),
                                             (s.index1, s.index2), s.output)
             second_update = ops.index_update(first_update,
                                              (s.index2, s.index1),
                                              s.output)
             return second_update
         else:
             n = len(X)
             values_scattered = np.empty((n, n))
             # k=-1: strictly lower triangle; the diagonal is filled separately.
             index1, index2 = np.tril_indices(n, k=-1)
             inst1, inst2 = X[index1], X[index2]
             values = vmap(kernel_fn)(inst1, inst2)
             values_scattered = ops.index_update(values_scattered,
                                                 (index1, index2), values)
             values_scattered = ops.index_update(values_scattered,
                                                 (index2, index1), values)
             values_scattered = ops.index_update(
                 values_scattered, np.diag_indices(n),
                 vmap(lambda x: kernel_fn(x, x))(X))
             return values_scattered
     else:
         if config_value('KERNEL_MATRIX_USE_LOOP'):
             with loops.Scope() as s:
                 s.output = np.empty((X.shape[0], Y.shape[0]))
                 # One row per element of X, each row vectorised over Y.
                 for i in s.range(X.shape[0]):
                     x = X[i]
                     s.output = ops.index_update(
                         s.output, i,
                         vmap(lambda y: kernel_fn(x, y))(Y))
             return s.output
         else:
             return vmap(lambda x: vmap(lambda y: kernel_fn(x, y))(Y))(X)
def cost_func_jvp(bb, u):
    """Directional derivatives of `cost_func` along each unit direction of `bb`.

    For every index `i` a forward-mode JVP is taken with a tangent that is 1 at
    position `i` of `bb` and 0 elsewhere (the tangent for `u` is all zeros).

    Args:
        bb: 1-D array of `n` values.
        u: Second argument of `cost_func`; its tangent has length `n + 1`.

    Returns:
        1-D array with the flattened JVP results for `i = 0..n-1`, concatenated
        in order. Empty when `n` is 0.
    """
    n = bb.size
    u_tangent = jnp.zeros(n + 1)  # identical for every direction
    positions = jnp.arange(n)
    # Start from an empty float piece so the result is 1-D (and float-promoted)
    # even when there are no directions at all.
    pieces = [jnp.empty([0])]
    for i in range(n):
        # One-hot tangent built by comparison, with the same dtype as a
        # default `jnp.zeros` vector.
        seed = (positions == i).astype(u_tangent.dtype)
        _, res = jax.jvp(cost_func, (bb, u), (seed, u_tangent))
        pieces.append(jnp.ravel(res))
    # A single concatenation avoids re-copying the growing result each step.
    return jnp.concatenate(pieces)
Example #22
0
 def apply_fun(params, inputs, **kwargs):
     """Promote a 1-D input to a single-row 2-D complex array.

     A 1-D `inputs` of length `n` becomes shape `(1, n)` cast to
     `jnp.complex128`; any other input is returned unchanged. `params` and
     `kwargs` are unused.
     """
     if inputs.ndim == 1:
         # Add the leading axis by indexing instead of scattering into an
         # uninitialised buffer with `jax.ops.index_update`.
         return inputs[None, :].astype(jnp.complex128)
     return inputs
Example #23
0
 def __init__(
     self,
     indices: Tuple[int, ...],
 ) -> None:
     assert len(indices) >= 2, indices
     indices_sorted = sorted(indices[:-1])
     indices_sorted.append(indices[-1])
     self.indices = jnp.array(indices_sorted, dtype=jnp.int32)
     self.ncvecs = jnp.empty((0, 3), dtype=jnp.int32)
 def _kernel_matrix_with_gradients(kernel_fn, theta, X, Y):
     """Evaluate the kernel matrix together with gradients of each entry.

     `kernel_fn` is wrapped with `value_and_grad` and `theta` is bound as its
     first argument, so every evaluation yields a value and a gradient. When
     `Y` is None or is the same object as `X`, one triangle plus the diagonal
     is evaluated and mirrored; otherwise every pair `(X[i], Y[j])` is
     evaluated.

     Returns:
         In the symmetric case a pair `(values, grads)` with shapes `(n, n)`
         and `(n, n, len(theta))`; otherwise the nested `vmap` result of the
         wrapped kernel.
     """
     kernel_fn = value_and_grad(kernel_fn)
     kernel_fn = partial(kernel_fn, theta)
     if Y is None or (Y is X):
         if config_value('KERNEL_MATRIX_USE_LOOP'):
             n = len(X)
             with loops.Scope() as s:
                 s.scattered_values = np.empty((n, n))
                 s.scattered_grads = np.empty((n, n, len(theta)))
                 # k=0: lower triangle including the diagonal.
                 index1, index2 = np.tril_indices(n, k=0)
                 for i in s.range(index1.shape[0]):
                     i1, i2 = index1[i], index2[i]
                     value, grads = kernel_fn(X[i1], X[i2])
                     # Address (i1, i2) and its mirror (i2, i1) in one update.
                     indexes = (np.stack([i1, i2]), np.stack([i2, i1]))
                     s.scattered_values = ops.index_update(
                         s.scattered_values, indexes, value)
                     s.scattered_grads = ops.index_update(
                         s.scattered_grads, indexes, grads)
             return s.scattered_values, s.scattered_grads
         else:
             n = len(X)
             values_scattered = np.empty((n, n))
             grads_scattered = np.empty((n, n, len(theta)))
             # k=-1: strictly lower triangle; the diagonal is filled separately.
             index1, index2 = np.tril_indices(n, k=-1)
             inst1, inst2 = X[index1], X[index2]
             values, grads = vmap(kernel_fn)(inst1, inst2)
             # Scatter computed values into matrix
             values_scattered = ops.index_update(values_scattered,
                                                 (index1, index2), values)
             values_scattered = ops.index_update(values_scattered,
                                                 (index2, index1), values)
             grads_scattered = ops.index_update(grads_scattered,
                                                (index1, index2), grads)
             grads_scattered = ops.index_update(grads_scattered,
                                                (index2, index1), grads)
             diag_values, diag_grads = vmap(lambda x: kernel_fn(x, x))(X)
             diag_indices = np.diag_indices(n)
             values_scattered = ops.index_update(values_scattered,
                                                 diag_indices, diag_values)
             grads_scattered = ops.index_update(grads_scattered,
                                                diag_indices, diag_grads)
             return values_scattered, grads_scattered
     else:
         # NOTE(review): unlike the symmetric case there is no loop-based
         # variant here; the config flag is not consulted on this path.
         return vmap(lambda x: vmap(lambda y: kernel_fn(x, y))(Y))(X)
Example #25
0
    def _phi_marginal(shape, rng_key, conc, corr, eig, b0, eigmin, phi_den):
        """Draw `phi` by repeated propose/accept steps inside a `while_loop`.

        All of `conc`, `eig`, `b0`, `eigmin` and `phi_den` are first broadcast
        to `shape`. Each iteration proposes a normalised Gaussian draw `x`,
        accepts entries where a uniform draw falls below `exp(lf + lg_inv)`,
        and copies accepted proposals into `phi`. Iteration stops once every
        entry has been accepted at least once or the iteration counter reaches
        `SineBivariateVonMises.max_sample_iter`.

        Returns:
            A `PhiMarginalState` carrying the final counter, the `done` mask,
            the sampled `phi` and the advanced random key.
        """
        conc = jnp.broadcast_to(conc, shape)
        eig = jnp.broadcast_to(eig, shape)
        b0 = jnp.broadcast_to(b0, shape)
        eigmin = jnp.broadcast_to(eigmin, shape)
        phi_den = jnp.broadcast_to(phi_den, shape)

        def update_fn(curr):
            i, done, phi, key = curr
            # Reserve `key` for the next iteration; derive this iteration's
            # three independent keys from the other half of the split.
            phi_key, key = random.split(key)
            accept_key, acg_key, phi_key = random.split(phi_key, 3)

            x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape)
            x /= jnp.linalg.norm(
                x, axis=1, keepdims=True
            )  # Angular Central Gaussian distribution

            lf = (
                conc[:, :1] * (x[:, :1] - 1)
                + eigmin
                + log_I1(
                    0, jnp.sqrt(conc[:, 1:] ** 2 + (corr * x[:, 1:]) ** 2)
                ).squeeze(0)
                - phi_den
            )
            assert lf.shape == shape

            lg_inv = (
                1.0 - b0 / 2 + jnp.log(b0 / 2 + (eig * x ** 2).sum(1, keepdims=True))
            )
            assert lg_inv.shape == lf.shape

            accepted = random.uniform(accept_key, shape) < jnp.exp(lf + lg_inv)

            # NOTE(review): accepted proposals overwrite `phi` even for entries
            # already marked done in an earlier iteration -- confirm intended.
            phi = jnp.where(accepted, x, phi)
            return PhiMarginalState(i + 1, done | accepted, phi, key)

        def cond_fn(curr):
            # Continue while under the iteration cap and something is pending.
            return jnp.bitwise_and(
                curr.i < SineBivariateVonMises.max_sample_iter,
                jnp.logical_not(jnp.all(curr.done)),
            )

        phi_state = while_loop(
            cond_fn,
            update_fn,
            PhiMarginalState(
                i=jnp.array(0),
                done=jnp.zeros(shape, dtype=bool),
                # Placeholder contents; entries are filled in on acceptance.
                phi=jnp.empty(shape, dtype=float),
                key=rng_key,
            ),
        )
        # Re-wrap the final loop state field by field.
        return PhiMarginalState(
            phi_state.i, phi_state.done, phi_state.phi, phi_state.key
        )
  def __call__(self, rng, state):
    """Produces new groups and adds them to state's stack.

    Patients whose marginal lies strictly between `self.cut_off_low` and
    `self.cut_off_high` are sorted by marginal and split into consecutive
    groups; the size of each group is the one minimising `exp_div_size` below.
    The groups are shuffled and then either all of them (regular mode) or only
    the first `state.extra_tests_needed` (when `self.modified` is set) are
    handed to `state.add_groups_to_test`.

    Args:
      rng: random key used to shuffle the groups.
      state: the current state; read for particles, priors and sizes, and
        mutated in place.

    Returns:
      The same `state` object. If no patient lies between the thresholds,
      `state.all_cleared` is set and no groups are added.
    """
    p_weights, particles = state.particle_weights, state.particles
    # Weighted sum of the particles over the particle axis.
    marginal = onp.array(np.sum(p_weights[:, np.newaxis] * particles, axis=0))
    marginal = onp.squeeze(marginal)
    not_cut_ids, = onp.where(np.logical_and(
        marginal < self.cut_off_high, marginal > self.cut_off_low))
    marginal = marginal[not_cut_ids]
    sorted_ids = onp.argsort(marginal)
    sorted_marginal = onp.array(marginal[sorted_ids])
    n_p = 0  # number of sorted patients already placed in a group
    n_r = marginal.size  # number of sorted patients still to be placed
    if n_r == 0:  # no one left to test in between thresholds
      state.all_cleared = True
      return state

    all_new_groups = np.empty((0, state.num_patients), dtype=bool)
    while n_p < marginal.size:
      # Candidate sizes run from 1 up to whichever is smaller: the number of
      # remaining patients or the maximal group size.
      index_max = onp.amin((n_r, state.max_group_size))
      group_sizes = onp.arange(1, index_max + 1)
      cum_prod_prob = onp.cumprod(1 - sorted_marginal[n_p:(n_p + index_max)])
      # formula below is only valid for group_size > 1,
      # corrected below for a group of size 1.
      sensitivity = onp.array(
          utils.select_from_sizes(state.prior_sensitivity, group_sizes))
      specificity = onp.array(
          utils.select_from_sizes(state.prior_specificity, group_sizes))

      exp_div_size = (
          1 + group_sizes *
          (sensitivity + (1 - sensitivity - specificity) * cum_prod_prob)
          ) / group_sizes
      exp_div_size[0] = 1  # adjusted cost for one patient is one.
      opt_size_group = onp.argmin(exp_div_size) + 1
      new_group = onp.zeros((1, state.num_patients))
      # Map positions in the sorted, filtered list back to patient indices.
      new_group[0, not_cut_ids[sorted_ids[n_p:n_p + opt_size_group]]] = True
      all_new_groups = np.concatenate((all_new_groups, new_group), axis=0)
      n_p = n_p + opt_size_group
      n_r = n_r - opt_size_group
    # sample randomly extra_tests_needed groups in modified case, all in
    # regular ID.
    # Because ID is a Dorfman type approach, it might be followed
    # by exhaustive splitting, which requires to keep track of groups
    # that tested positives to retest them.
    all_new_groups = jax.random.permutation(rng, all_new_groups)
    if self.modified:
      # in the case where we use modified ID, we only subsample a few groups.
      # one needs to take care of requesting to keep track of positives.
      new_groups = all_new_groups[0:state.extra_tests_needed].astype(bool)
      state.add_groups_to_test(new_groups,
                               results_need_clearing=True)
    else:
      # with regular ID we add all groups at once.
      state.add_groups_to_test(all_new_groups.astype(bool),
                               results_need_clearing=True)
    return state
Example #27
0
    def bounds(self):
        """Return the log-transformed bounds on theta.

        Returns:
            bounds : array, shape (n_dims, 2)
                The log-transformed bounds on the kernel's hyperparameters
                theta. This implementation always yields zero rows.

        """
        n_rows = 0
        return np.empty(shape=(n_rows, 2))
Example #28
0
 def get_coefficients(self):
     """Return six coefficient arrays.

     The first two are one shared empty array; the remaining four are
     length-1 arrays wrapping `self.a`, `self.b`, `self.c` and `self.d`.
     """
     nothing = np.empty(0)
     wrapped = tuple(
         np.array([value]) for value in (self.a, self.b, self.c, self.d))
     return (nothing, nothing) + wrapped
def runDiscretePSO_jax(user_options, algorithm_options):
    """Run a discrete particle-swarm search and return the best cost found.

    Args:
        user_options: Mapping providing the coefficients 'w', 'c1' and 'c2'
            that are forwarded to `calculate_velocity`.
        algorithm_options: Mapping providing 'particles', 'dimensions',
            'objective' and 'iterations'.

    Returns:
        `(best_global, best_global_position)`: the smallest objective value
        observed and the position of the particle that achieved it.
        `best_global` stays None when 'iterations' is 0.
    """
    particles = algorithm_options['particles']
    dimensions = algorithm_options['dimensions']
    objective = algorithm_options['objective']
    # For each particle, initialize position and velocity
    seed = random.PRNGKey(datetime.now().microsecond)
    particles_position = random.uniform(seed, (particles, dimensions), None,
                                        -1, 1)
    seed = random.PRNGKey(datetime.now().microsecond)
    particles_velocity = random.uniform(seed, (particles, dimensions), None,
                                        -1, 1)
    # Use of system microseconds as random seed to get different numbers each time

    # NOTE(review): this overwrites the random positions drawn above with a
    # discretisation of the velocities -- confirm intended.
    particles_position = toDiscrete(activation(particles_velocity))

    best_global = None  # Best swarm cost
    best_global_position = npj.empty(
        (particles, dimensions))  # Best swarm position
    best_particle_position = particles_position
    # NOTE(review): computed once here and never refreshed inside the loop.
    best_particle_cost = objective(
        best_particle_position)  # obj_fuction(best_particle_position)

    for k in range(0, algorithm_options['iterations']):
        # Don't replace with 'iterations' variable because it is called only once
        objective_values = objective(
            best_particle_position)  # obj_fuction(particles_position)
        best_index = npj.argmin(objective_values)
        best_value = objective_values[best_index]

        # particles x dimensions
        best_particle_position = calculate_best_position(
            objective_values, best_particle_cost, particles_position,
            best_particle_position, particles, dimensions)

        if best_global is None or best_value < best_global:
            # Update best swarm cost and position
            best_global = best_value
            best_global_position = particles_position[best_index]

        # Fresh random factors for this iteration's velocity update.
        seed = random.PRNGKey(datetime.now().microsecond)
        r1 = random.uniform(seed, (particles, dimensions), None, 0, 1)
        seed = random.PRNGKey(datetime.now().microsecond)
        r2 = random.uniform(seed, (particles, dimensions), None, 0, 1)

        particles_velocity = calculate_velocity(
            user_options['w'], particles_velocity, user_options['c1'],
            user_options['c2'], r1, r2, best_particle_position,
            particles_position, best_global_position)

        particles_position = toDiscrete(
            activation(particles_position + particles_velocity))

        # NOTE(review): this replaces the per-particle best computed above with
        # the current positions -- confirm intended.
        best_particle_position = particles_position

    return best_global, best_global_position
Example #30
0
def backward_pass(x_trj, u_trj, regu, target):
    """Run the backward sweep over the trajectory with `lax.fori_loop`.

    The loop carry starts from the terminal derivatives returned by
    `derivative_final` together with unfilled gain buffers, and is advanced
    `TIME_STEPS - 1` times by `backward_pass_looper`.

    Returns:
        `(k_trj, K_trj, expected_cost_redu)` taken from the final carry.
    """
    horizon = TIME_STEPS - 1
    terminal_grad, terminal_hess = derivative_final(x_trj[-1], target)
    # Carry layout: V_x, V_xx, k_trj, K_trj, x_trj, u_trj,
    # expected_cost_redu, regu, target.
    carry = [
        terminal_grad,
        terminal_hess,
        np.empty_like(u_trj),
        np.empty((horizon, N_U, N_X)),
        x_trj,
        u_trj,
        0.,
        regu,
        target,
    ]
    (_, _, k_trj, K_trj, _, _, expected_cost_redu, _,
     _) = lax.fori_loop(0, horizon, backward_pass_looper, carry)
    return k_trj, K_trj, expected_cost_redu