def cov_estimate(*, optimization_path: Sequence[jnp.ndarray],
                 optimization_path_grads: Sequence[jnp.ndarray],
                 history: int):
    """Estimate covariance approximations along an optimization path.

    For every step ``j`` of the path this records the current diagonal
    estimate together with the thin factors and scaling outer product that
    ``bfgs_inverse_hessian`` builds from the difference pairs accepted
    before step ``j``; the pair of step ``j`` is then used to update both
    the low-rank history and the diagonal estimate.

    Args:
      optimization_path: positions visited by the optimizer; each entry's
        first dimension has length ``dim``.
      optimization_path_grads: gradients of the log density at those
        positions (same length as ``optimization_path``).
      history: maximum number of (position, gradient) difference pairs kept
        for the low-rank part.

    Returns:
      A list of ``len(optimization_path) - 1`` tuples
      ``(diagonal_estimate, thin_factors, scaling_outer_product)``.
    """
    dim = optimization_path[0].shape[0]
    position_diffs = jnp.empty((dim, 0))
    gradient_diffs = jnp.empty((dim, 0))
    approximations: List[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]] = []
    diagonal_estimate = jnp.ones(dim)
    for j in range(len(optimization_path) - 1):
        _, thin_factors, scaling_outer_product = bfgs_inverse_hessian(
            updates_of_position_differences=position_diffs,
            updates_of_gradient_differences=gradient_diffs,
        )
        position_diff = optimization_path[j + 1] - optimization_path[j]
        # The gradients are of the log density, so the change is taken
        # "backwards" to get the gradient change of the negative log density.
        gradient_diff = (optimization_path_grads[j] -
                         optimization_path_grads[j + 1])
        b = position_diff @ gradient_diff
        gradient_diff_norm = gradient_diff**2
        new_diagonal_estimate = diagonal_estimate
        # Accept the pair only if it satisfies the curvature condition
        # s.y > eps * |y|^2; the update below divides by `b`, which must be
        # safely positive.
        if b > 1e-12 * jnp.sum(gradient_diff_norm):
            # Keep the newest `history - 1` old pairs plus the new one.  A
            # plain `[:, -history + 1:]` slice would keep *all* columns when
            # history == 1.
            keep_from = max(position_diffs.shape[1] - (history - 1), 0)
            position_diffs = jnp.column_stack(
                (position_diffs[:, keep_from:], position_diff))
            gradient_diffs = jnp.column_stack(
                (gradient_diffs[:, keep_from:], gradient_diff))
            a = gradient_diff @ (diagonal_estimate * gradient_diff)
            c = position_diff @ (position_diff / diagonal_estimate)
            new_diagonal_estimate = 1.0 / (
                a / (b * diagonal_estimate)
                + gradient_diff_norm / b
                - (a * position_diff**2) / (b * c * diagonal_estimate**2))
        approximations.append(
            (diagonal_estimate, thin_factors, scaling_outer_product))
        diagonal_estimate = new_diagonal_estimate
    return approximations
def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
    """Build a COO matrix that stores no elements.

    Backs the public ``sparse.empty()``. ``shape`` must have exactly two
    entries; ``dtype`` applies to the data buffer and ``index_dtype`` to the
    (shared) row/column index buffer.
    """
    shape = tuple(shape)
    ndim = len(shape)
    if ndim != 2:
        raise ValueError(f"COO must have ndim=2; got shape={shape}")
    no_data = jnp.empty(0, dtype)
    no_indices = jnp.empty(0, index_dtype)
    return cls((no_data, no_indices, no_indices), shape=shape)
def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
    """Build a CSR matrix that stores no elements.

    Backs the public ``sparse.empty()``. ``shape`` must have exactly two
    entries. The row-pointer array has ``shape[0] + 1`` zeros, so every
    row is empty.
    """
    shape = tuple(shape)
    ndim = len(shape)
    if ndim != 2:
        raise ValueError(f"CSR must have ndim=2; got shape={shape}")
    n_rows = shape[0]
    arrays = (
        jnp.empty(0, dtype),                 # data
        jnp.empty(0, index_dtype),           # column indices
        jnp.zeros(n_rows + 1, index_dtype),  # row pointers
    )
    return cls(arrays, shape=shape)
def lbfgs(
    *,
    log_target_density: Callable[[jnp.ndarray], jnp.ndarray],
    initial_value: jnp.ndarray,  # theta_init
    inverse_hessian_history: int = 6,  # J
    relative_tolerance: float = 1e-13,  # tau_rel
    max_iters: int = 1000,  # L
    wolfe_bounds: Tuple[float, float] = (1e-4, 0.9),
    positivity_threshold: float = 2.2e-16,
):
    """LBFGS implementation which returns the optimization path and gradients.

    Maximizes ``log_target_density`` starting at ``initial_value`` using an
    inverse-Hessian approximation from ``bfgs_inverse_hessian`` and a
    halving backtracking line search on the Wolfe conditions
    (``wolfe_bounds = (c1, c2)``).

    Iteration stops when either:
      * the relative improvement of the log density drops below
        ``relative_tolerance``;
      * no step size above 1e-8 satisfies the Wolfe conditions;
      * ``max_iters`` iterations have run.

    Returns:
      ``(optimization_path, grad_optimization_path)``: lists of visited
      positions and the log-density gradients at those positions.
    """
    dim = initial_value.shape[0]
    grad_log_density = jax.grad(log_target_density)
    optimization_path = [initial_value]
    current_lp = log_target_density(initial_value)
    grad_optimization_path = [grad_log_density(initial_value)]
    position_diffs = jnp.empty((dim, 0))
    gradient_diffs = jnp.empty((dim, 0))
    for _ in range(max_iters):
        diagonal_estimate, thin_factors, scaling_outer_product = bfgs_inverse_hessian(
            updates_of_position_differences=position_diffs,
            updates_of_gradient_differences=gradient_diffs,
        )
        grad_lp = grad_optimization_path[-1]
        # Apply the (diagonal + low-rank) inverse Hessian approximation to
        # the gradient to get the ascent direction.
        search_direction = diagonal_estimate * grad_lp + thin_factors @ (
            scaling_outer_product @ (jnp.transpose(thin_factors) @ grad_lp))

        # Backtracking line search: halve the step until both Wolfe
        # conditions hold.
        step_size = 1.0
        found_step = False
        while step_size > 1e-8:
            proposed = optimization_path[-1] + step_size * search_direction
            proposed_lp = log_target_density(proposed)
            # Sufficient-increase (Armijo) condition.
            if proposed_lp >= current_lp + (wolfe_bounds[0] * grad_lp) @ (
                    step_size * search_direction):
                proposed_grad = grad_log_density(proposed)
                # Curvature condition.
                if (proposed_grad @ search_direction <=
                        wolfe_bounds[1] * grad_lp @ search_direction):
                    found_step = True
                    break
            step_size = 0.5 * step_size
        if not found_step:
            # No acceptable step exists.  Appending the last trial point
            # would pair it with a stale (or undefined) gradient, so stop.
            return optimization_path, grad_optimization_path

        optimization_path.append(proposed)
        grad_optimization_path.append(proposed_grad)
        if (proposed_lp - current_lp) / jnp.abs(current_lp) < relative_tolerance:
            return optimization_path, grad_optimization_path
        current_lp = proposed_lp
        position_diff = optimization_path[-1] - optimization_path[-2]
        grad_diff = -grad_optimization_path[-1] + grad_optimization_path[-2]
        # Only keep pairs satisfying the curvature condition so the BFGS
        # approximation stays positive definite.
        if position_diff @ grad_diff > positivity_threshold * jnp.sum(grad_diff**2):
            # Keep the newest `inverse_hessian_history - 1` old pairs plus
            # the new one.  A plain `[:, -history + 1:]` slice keeps *all*
            # columns when the history is 1.
            keep_from = max(
                position_diffs.shape[1] - (inverse_hessian_history - 1), 0)
            position_diffs = jnp.column_stack(
                (position_diffs[:, keep_from:], position_diff))
            gradient_diffs = jnp.column_stack(
                (gradient_diffs[:, keep_from:], grad_diff))
    return optimization_path, grad_optimization_path
def _reset(self):
    """Reset the step counter and allocate a fresh trajectory buffer.

    Observations get ``n_steps + 1`` leading slots; actions, rewards,
    discounts and trace decays get ``n_steps``. Every buffer's second
    dimension is ``batch_size``.
    """
    self._t = 0
    steps, batch = self.n_steps, self.batch_size
    # jnp.empty takes the whole shape as ONE tuple; passing the dimensions
    # as separate positional arguments makes the second one the dtype.
    self.trajectory = Trajectory(
        observations=jnp.empty((steps + 1, batch, *self.observation_spec.shape)),
        actions=jnp.empty((steps, batch, 1)),
        rewards=jnp.empty((steps, batch, 1)),
        discounts=jnp.empty((steps, batch, 1)),
        trace_decays=jnp.empty((steps, batch, 1)),
    )
def new(cls, in_features: int, out_features: int, use_bias=True):
    """Construct the layer with ``jnp.empty`` parameter buffers.

    The weight has shape ``(out_features, in_features)``. The bias has shape
    ``(out_features,)`` when ``use_bias`` is truthy and is ``None``
    otherwise.
    """
    weight = jnp.empty([out_features, in_features])
    bias = jnp.empty([out_features]) if use_bias else None
    fields = dict(
        in_features=in_features,
        out_features=out_features,
        use_bias=use_bias,
        weight=weight,
        bias=bias,
    )
    return cls(**fields)
def reset(self):
    """Return the state to its initial, empty condition.

    Group and result histories become zero-row boolean arrays, the cycle
    counter restarts at 0, and method-specific fields (particles, marginals,
    clearing bookkeeping) are cleared.
    """
    group_table_shape = (0, self.num_patients)
    self.curr_cycle = 0
    self.past_groups = np.empty(group_table_shape, dtype=bool)
    self.past_test_results = np.empty((0, ), dtype=bool)
    self.groups_to_test = np.empty(group_table_shape, dtype=bool)
    # Only some selection methods use or fill these.
    self.particle_weights, self.particles = None, None
    self.to_clear_positives = np.empty((0, ), dtype=bool)
    self.all_cleared = False
    # Marginals may be computed in several ways; keyed by method.
    self.marginals = {}
def apply_fun(params, inputs, **kwargs):
    """Append to each sequence a copy of all but its last element.

    For ``inputs`` of shape ``(batch, n, channels)``, return a complex array
    of shape ``(batch, 2 * n - 1, channels)``. Positions ``[0, n)`` hold
    ``inputs`` and positions ``[n, 2n - 1)`` repeat ``inputs[:, :n - 1, :]``.
    ``params`` and ``kwargs`` are unused.
    """
    input_size = inputs.shape[1]
    # jax.ops.index_update has been removed from JAX; a single
    # concatenation builds the same array without scatter updates.
    outputs = jnp.concatenate((inputs, inputs[:, :input_size - 1, :]), axis=1)
    return outputs.astype(jnp.complex128)
def apply_fun(params, inputs, **kwargs):
    """Flatten ``(batch, n, channels)`` into ``(batch, channels * n)``.

    The layout is channel-major: channel ``i`` occupies columns
    ``[i * n, (i + 1) * n)``. The result is cast to complex.
    ``params`` and ``kwargs`` are unused.
    """
    batch, input_size, num_channels = inputs.shape[0], inputs.shape[1], inputs.shape[2]
    # Move channels ahead of positions so a row-major reshape lays each
    # channel out contiguously (replaces the removed jax.ops.index_update loop).
    outputs = jnp.transpose(inputs, (0, 2, 1)).reshape(batch, num_channels * input_size)
    return outputs.astype(jnp.complex128)
def run(self, rng_key, num_steps, *args, return_last=True, progbar=True, **kwargs):
    """Run ``num_steps`` SVGD updates starting from ``self.init(rng_key, ...)``.

    Args:
      rng_key: key forwarded to ``self.init``.
      num_steps: number of ``self.update`` calls.
      *args, **kwargs: forwarded to ``self.init`` and ``self.update``.
      return_last: return only the final loss instead of all of them.
      progbar: drive the loop from Python with a tqdm bar; otherwise use
        ``fori_loop``.

    Returns:
      ``(svgd_state, loss)``, where ``loss`` is the last loss if
      ``return_last`` is true, else the array of ``num_steps`` losses.
    """
    def bodyfn(i, info):
        svgd_state, losses = info
        svgd_state, loss = self.update(svgd_state, *args, **kwargs)
        # `losses` is traced here (under fori_loop or jit), so the functional
        # `.at[].set()` replaces the removed jax.ops.index_update.
        losses = losses.at[i].set(loss)
        return svgd_state, losses

    svgd_state = self.init(rng_key, *args, **kwargs)
    losses = np.empty((num_steps, ))
    if not progbar:
        svgd_state, losses = fori_loop(0, num_steps, bodyfn, (svgd_state, losses))
    else:
        jit_bodyfn = jax.jit(bodyfn)  # wrap once, outside the loop
        with tqdm.trange(num_steps) as t:
            for i in t:
                svgd_state, losses = jit_bodyfn(i, (svgd_state, losses))
                # Iterating `t` already advances the bar; an extra
                # t.update() here would double-count every step.
                t.set_description('SVGD {:.5}'.format(losses[i]), refresh=False)
    loss_res = losses[-1] if return_last else losses
    return svgd_state, loss_res
def _ndim_coords_from_arrays(points, ndim):
    """
    Convert a tuple of coordinate arrays to a (..., ndim)-shaped array.

    A tuple of arrays is broadcast together and stacked along a new last
    axis (as floats). A 1-tuple is unwrapped first. Any other input is
    converted with ``jnp.asarray``; if it is 1-D it is reshaped to
    ``(-1, 1)`` (``ndim`` None) or ``(-1, ndim)``.
    """
    if isinstance(points, tuple) and len(points) == 1:
        # handle argument tuple
        points = points[0]
    if isinstance(points, tuple):
        p = jnp.broadcast_arrays(*points)
        for item in p[1:]:
            if item.shape != p[0].shape:
                raise ValueError(
                    "coordinate arrays do not have the same shape")
        # JAX arrays are immutable, so filling a jnp.empty buffer in place
        # (`points[..., j] = item`) raises TypeError; stack instead.
        points = jnp.stack(p, axis=-1).astype(float)
    else:
        points = jnp.asarray(points)
        if points.ndim == 1:
            if ndim is None:
                points = points.reshape(-1, 1)
            else:
                points = points.reshape(-1, ndim)
    return points
def get_groups(self, rng, state):
    """Produces random design matrix fixed number of 1s per line.

    Each of the ``state.extra_tests_needed`` rows marks ``group_size``
    distinct patients, chosen as the head of a random permutation.
    ``group_size`` is ``self.group_size`` if set. Otherwise it is computed
    from the prior for a uniform infection rate, capped at
    ``state.max_group_size``; for a non-uniform prior it is
    ``state.max_group_size``.

    Args:
      rng: np.ndarray<int>[2]: the random key.
      state: the current state.State of the system.

    Returns:
      A np.array<bool>[num_groups, patients].
    """
    if self.group_size is None:
        # if no size has been defined, we compute it adaptively
        # in the simple case where prior is uniform.
        if np.size(state.prior_infection_rate) == 1:
            group_size = np.ceil(
                (np.log(state.prior_sensitivity - .5) -
                 np.log(state.prior_sensitivity + state.prior_specificity - 1))
                / np.log(1 - state.prior_infection_rate))
            group_size = np.minimum(group_size, state.max_group_size)
        # if prior is not uniform, pick max size.
        else:
            # The cap lives on `state`, as in the uniform branch above.
            group_size = state.max_group_size
    else:
        group_size = self.group_size
    group_size = int(np.squeeze(group_size))

    patient_ids = np.arange(state.num_patients)
    # Collect rows and concatenate once (the loop concatenation was O(k^2)).
    rows = [np.empty((0, state.num_patients), dtype=bool)]
    for _ in range(state.extra_tests_needed):
        rng, rng_shuffle = jax.random.split(rng, 2)
        idx = jax.random.permutation(rng_shuffle, patient_ids)
        # Boolean mask of the first `group_size` permuted patients; replaces
        # the removed jax.ops.index_update scatter.
        rows.append(np.isin(patient_ids, idx[0:group_size])[np.newaxis, :])
    return np.concatenate(rows, axis=0)
def dtoq_reyes(data):
    """Encode each scalar ``d`` as the unit 2-vector ``(1, d) / sqrt(d**2 + 1)``.

    Args:
      data: sequence of real scalars.

    Returns:
      Float array of shape ``(len(data), 2)``; ``(0, 2)`` for empty input.
    """
    # Vectorized: the original grew the result with np.vstack per element,
    # which copies the whole array each time (O(n^2)).
    values = np.asarray(data, dtype=float)
    norms = np.sqrt(values**2 + 1.0)
    return np.column_stack((1 / norms, values / norms))
def evolution_pepo_imag_time(g: float, dt: float, bc: str, dtype: np.dtype,
                             lx: Optional[int] = None,
                             ly: Optional[int] = None) -> Operator:
    """Build the PEPO for one imaginary-time step U(dt).

    Uses the splitting U(dt) ~ U_vert(dt/2) U_bond(dt) U_vert(dt/2) with
    H_bond = -XX and H_vert = -gZ. The half-bond tensor ``A`` and the
    single-site factor ``u_vert`` (both cast to ``dtype``) are passed to
    ``_build_evolution_pepo`` together with ``bc``, ``lx`` and ``ly``.
    """
    # PEPO for U(dt) ~ U_vert(dt/2) U_bond(dt) U_vert(dt/2)
    #
    # half bond operators:
    #
    #  |   |      |   |
    #  U_bond  =  A -- A
    #  |   |      |   |
    #
    # expm(- H_bond dt) = expm(- (-XX) dt) = expm(dt XX) = cosh(dt) + sinh(dt) XX = A_0 A_0 + A_1 A_1
    # with A_0 = (cosh(dt) ** 0.5) * 1 , A_1 = (sinh(dt) ** 0.5) * X
    # A & B legs: (p,p*,k)
    # jax.ops.index_update/index were removed from JAX.  Stacking the two
    # slices along the last (k) axis builds the same [2, 2, 2] tensor.
    # s0/sx were written into (2, 2) slices, so they are 2x2 operators.
    A = np.stack([(np.cosh(dt)**0.5) * s0, (np.sinh(dt)**0.5) * sx],
                 axis=-1).astype(dtype)
    # expm(- H_vert dt/2) = expm(- (-gZ) dt/2) = expm(g dt/2 Z)
    u_vert = np.asarray(expm(g * (dt / 2) * sz), dtype=dtype)
    return _build_evolution_pepo(u_vert, A, bc, lx, ly)
def __do_rank_regression(self):
    """Compute median-rank CDF estimates for failures, adjusting for censoring.

    The method:
      * builds a table of failure and censored times with an ``is_cens``
        flag (0 = failure, 1 = censored) and an order number ``fo``
        (0-based rank among sorted failures; 0 for censored rows);
      * sorts the table by time and computes an adjusted ``new_increment``
        per row;
      * drops censored rows and accumulates the increments into
        ``new_order_num``;
      * converts those with ``util.median_rank``.

    Sets ``self.N`` (total item count) and ``self.est_data`` (resulting
    DataFrame).
    """
    n_fail = self.failures.shape[0]
    n_cens = self.censored.shape[0]
    f = jnp.hstack((jnp.atleast_2d(self.failures).T, jnp.zeros((n_fail, 1))))
    f = f[f[:, 0].argsort()]
    f = jnp.hstack((f, jnp.reshape(jnp.arange(n_fail), (n_fail, -1))))
    # censored items will be having flag '1'
    c = jnp.hstack((jnp.atleast_2d(self.censored).T, jnp.ones((n_cens, 1))))
    # Censored rows carry order number 0 as a placeholder.  This was written
    # with jnp.empty, which JAX fills with zeros; zeros states that
    # explicitly and stays deterministic if the array backend changes.
    c = jnp.hstack((c, jnp.zeros((n_cens, 1))))
    d = jnp.concatenate((c, f), axis=0)
    d = d[d[:, 0].argsort()]
    df = pd.DataFrame(data=d, columns=['time', 'is_cens', 'fo'])
    self.N = len(df.index)
    df['new_increment'] = (self.N + 1 - df['fo']) / (self.N + 2 - df.index.values)
    m = 1.0 - df['new_increment'].min()
    df['new_increment'] = df['new_increment'] + m
    df = df.drop(df[df['is_cens'] == 1].index)
    df['new_order_num'] = df['new_increment'].cumsum()
    df['cdf'] = util.median_rank(self.N, df['new_order_num'], 0.5)
    self.est_data = df
def get_groups(self, rng, state):
    """A greedy forward-backward algorithm to pick groups with large utility.

    The particle approximation is first collapsed with
    ``mutual_information.collapse_particles``. Groups are then built one at a
    time until ``state.extra_tests_needed`` have been chosen. Each group
    starts empty and is refined by ``next_best_group``:
      * forward steps (``backtracking=False``), up to
        ``self.forward_iterations``;
      * backward steps (``backtracking=True``), up to
        ``self.backward_iterations``.
    A candidate is kept when its objective beats the best so far by more
    than 1e-6.

    Args:
      rng: random key passed to ``collapse_particles``.
      state: current state; reads particle_weights, particles,
        extra_tests_needed, max_group_size, prior_sensitivity and
        prior_specificity.

    Returns:
      An array with one row per chosen group (``state.extra_tests_needed``
      rows) and one column per patient; presumably boolean, but the row
      dtype comes from ``next_best_group`` -- confirm.
    """
    particle_weights, particles = mutual_information.collapse_particles(
        rng, state.particle_weights, state.particles)
    n_patients = particles.shape[1]
    # (step budget, backtracking flag) pairs: forward phase, then backward.
    iterations = [self.forward_iterations, self.backward_iterations]
    chosen_groups = np.empty((0, n_patients), dtype=bool)
    added_groups_counter = 0
    while added_groups_counter < state.extra_tests_needed:
        # start forming a new group, and improve it greedily
        proposed_group = np.zeros((n_patients,), dtype=bool)
        obj_old = -1
        # NOTE(review): `break` below leaves only the inner `for`, and
        # `proposed_group` is overwritten even by rejected candidates.  This
        # `while` ends only once that group reaches state.max_group_size --
        # confirm it always terminates when backward steps are enabled.
        while np.sum(proposed_group) < state.max_group_size:
            for steps, backtrack in zip(iterations, [False, True]):
                for _ in range(steps):
                    # Extract candidate with largest utility
                    proposed_group, obj_new = next_best_group(particle_weights,
                                                              particles,
                                                              chosen_groups,
                                                              proposed_group,
                                                              state.prior_sensitivity,
                                                              state.prior_specificity,
                                                              self.utility_fn,
                                                              backtracking=backtrack)
                    if obj_new > obj_old + 1e-6:
                        cur_group = proposed_group
                        obj_old = obj_new
                    else:
                        break
        # stop adding, form next group
        # NOTE(review): `cur_group` is bound only after an improving step.
        # If no step ever improves for the first group this raises NameError;
        # for later groups the previous group's `cur_group` would be
        # appended again -- confirm intended.
        chosen_groups = np.concatenate((chosen_groups, cur_group[np.newaxis, :]),
                                       axis=0)
        added_groups_counter += 1
    return chosen_groups
def get_parameters(self):
    """Get variational parameters.

    For real nets, the leaves of ``self.net[0].params`` and then
    ``self.net[1].params`` are flattened, in ``tree_flatten`` order, into one
    ``global_defs.tReal`` vector of length ``self.numParameters``. For a
    complex net, the leaves of ``self.net.params`` are concatenated; when
    ``self.holomorphic`` is true, the real parts are followed by the
    imaginary parts.

    Returns:
        Array holding current values of all variational parameters.
    """
    if self.realNets:  # FOR REAL NETS
        paramOut = jnp.empty(self.numParameters, dtype=global_defs.tReal)
        start = 0
        for netId in [0, 1]:
            # Flatten parameters to give a single vector
            parameters, _ = tree_flatten(self.net[netId].params)
            for p in parameters:
                numParams = p.size
                # jax.ops.index_update was removed from JAX; `.at[].set()`
                # is its documented replacement (same casting semantics).
                paramOut = paramOut.at[start:start + numParams].set(p.reshape(-1))
                start += numParams
        return paramOut
    # FOR COMPLEX NET
    paramOut = jnp.concatenate([p.ravel() for p in tree_flatten(self.net.params)[0]])
    # NOTE(review): holomorphic nets are split into real/imag parts here --
    # confirm this orientation of the flag is intended.
    if self.holomorphic:
        paramOut = jnp.concatenate([paramOut.real, paramOut.imag])
    return paramOut
def _triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None):  # pylint: disable=redefined-outer-name
    """Scipy solve does not broadcast, so we must do so explicitly.

    Solves ``matrix @ x = rhs`` (or the conjugate-transposed system when
    ``adjoint``) for triangular ``matrix``, broadcasting batch dimensions.
    In JAX_MODE both operands are broadcast against each other's batch
    shape and solved in one ``solve_triangular`` call. Otherwise the inputs
    are broadcast with plain NumPy (``onp``) and solved one batch element at
    a time.

    Raises:
      ValueError: if the batch shapes of ``matrix`` and ``rhs`` cannot be
        broadcast.
    """
    del name
    if JAX_MODE:  # But JAX uses XLA, which can do a batched solve.
        # Adding zeros with the other operand's batch shape broadcasts each
        # operand to the common batch shape.
        matrix = matrix + np.zeros(rhs.shape[:-2] + (1, 1), dtype=matrix.dtype)
        rhs = rhs + np.zeros(matrix.shape[:-2] + (1, 1), dtype=rhs.dtype)
        return scipy_linalg.solve_triangular(matrix, rhs, lower=lower,
                                             trans='C' if adjoint else 'N')
    try:
        # matrix[..., :1] has the same trailing rows as rhs, so broadcasting
        # it against rhs yields the common batch shape + rhs's last two dims.
        bcast = onp.broadcast(matrix[..., :1], rhs)
    except ValueError as e:
        raise ValueError(
            'Error with inputs shaped `matrix`={}, rhs={}:\n{}'.format(
                matrix.shape, rhs.shape, str(e)))
    dim = matrix.shape[-1]
    matrix = onp.broadcast_to(matrix, bcast.shape[:-1] + (dim, ))
    rhs = onp.broadcast_to(rhs, bcast.shape)
    nbatch = int(np.prod(matrix.shape[:-2]))
    flat_mat = matrix.reshape(nbatch, dim, dim)
    flat_rhs = rhs.reshape(nbatch, dim, rhs.shape[-1])
    # NOTE(review): np.empty uses its default dtype (float64 with NumPy), so
    # complex solutions would lose their imaginary part and float32 inputs
    # come back as float64, unlike the JAX path -- confirm intended.
    # The in-place `result[i] = ...` below needs a mutable array, so `np` is
    # presumably plain NumPy outside JAX_MODE.
    result = np.empty(flat_rhs.shape)
    if np.size(result):
        # ValueError: On entry to STRTRS parameter number 7 had an illegal value.
        # (Empty inputs are skipped, presumably to avoid the LAPACK error above.)
        for i, (mat, rh) in enumerate(zip(flat_mat, flat_rhs)):
            result[i] = scipy_linalg.solve_triangular(
                mat, rh, lower=lower, trans='C' if adjoint else 'N')
    return result.reshape(*rhs.shape)
def _indices(key):
    """Draw ``nse`` index rows, one column per sparse dimension.

    Relies on the enclosing scope's ``sparse_shape``, ``sparse_size``,
    ``nse``, ``n_sparse`` and ``unique_indices``. Flat positions are sampled
    (without replacement when ``unique_indices``) and unravelled into
    per-dimension indices. With no sparse dimensions, an ``(nse, n_sparse)``
    integer ``jnp.empty`` array is returned.
    """
    if sparse_shape:
        flat_positions = random.choice(
            key, sparse_size, shape=(nse, ), replace=not unique_indices)
        per_axis = jnp.unravel_index(flat_positions, sparse_shape)
        return jnp.column_stack(per_axis)
    return jnp.empty((nse, n_sparse), dtype=int)
def _kernel_matrix_without_gradients(kernel_fn, theta, X, Y):
    """Evaluate ``kernel_fn(theta, x, y)`` for all pairs of rows of X and Y.

    If ``Y`` is None or is the very same object as ``X``, the symmetric
    ``(len(X), len(X))`` matrix is built: only the lower triangle is
    evaluated and then mirrored. Otherwise the result has one row per row of
    ``X`` and one column per row of ``Y`` (assuming ``kernel_fn`` returns a
    scalar -- confirm). ``config_value('KERNEL_MATRIX_USE_LOOP')`` chooses
    between an explicit ``loops.Scope`` loop and ``vmap``.

    NOTE(review): assuming ``loops`` is jax.experimental.loops and ``ops``
    is jax.ops, both have been removed from current JAX releases.
    """
    kernel_fn = partial(kernel_fn, theta)
    if Y is None or (Y is X):
        if config_value('KERNEL_MATRIX_USE_LOOP'):
            n = len(X)
            with loops.Scope() as s:
                # s.scattered_values = np.empty((n, n))
                # Lower triangle including the diagonal (k=0).
                s.index1, s.index2 = np.tril_indices(n, k=0)
                s.output = np.empty(len(s.index1))
                for i in s.range(s.index1.shape[0]):
                    i1, i2 = s.index1[i], s.index2[i]
                    s.output = ops.index_update(s.output, i, kernel_fn(X[i1], X[i2]))
                # Scatter into the lower triangle, then mirror to the upper.
                first_update = ops.index_update(np.empty((n, n)), (s.index1, s.index2), s.output)
                second_update = ops.index_update(first_update, (s.index2, s.index1), s.output)
                return second_update
        else:
            n = len(X)
            values_scattered = np.empty((n, n))
            # Strict lower triangle (k=-1); the diagonal is filled separately.
            index1, index2 = np.tril_indices(n, k=-1)
            inst1, inst2 = X[index1], X[index2]
            values = vmap(kernel_fn)(inst1, inst2)
            values_scattered = ops.index_update(values_scattered, (index1, index2), values)
            values_scattered = ops.index_update(values_scattered, (index2, index1), values)
            values_scattered = ops.index_update(
                values_scattered, np.diag_indices(n), vmap(lambda x: kernel_fn(x, x))(X))
            return values_scattered
    else:
        if config_value('KERNEL_MATRIX_USE_LOOP'):
            with loops.Scope() as s:
                s.output = np.empty((X.shape[0], Y.shape[0]))
                # One row per element of X, vectorized over Y.
                for i in s.range(X.shape[0]):
                    x = X[i]
                    s.output = ops.index_update(
                        s.output, i, vmap(lambda y: kernel_fn(x, y))(Y))
                return s.output
        else:
            return vmap(lambda x: vmap(lambda y: kernel_fn(x, y))(Y))(X)
def cost_func_jvp(bb, u):
    """Forward-mode derivatives of ``cost_func`` w.r.t. each entry of ``bb``.

    For each unit vector ``e_i`` (``i < bb.size``), takes the JVP of
    ``cost_func(bb, u)`` in direction ``(e_i, 0)``. The flattened tangent
    outputs are concatenated in order of ``i``. The zero tangent for ``u``
    has ``bb.size + 1`` entries, so ``u`` is assumed to have that many --
    TODO confirm against callers.

    Returns:
      1-D array of all directional derivatives (empty if ``bb.size == 0``).
    """
    n = bb.size
    # Row i is the seed e_i; replaces the removed jax.ops.index_update.
    basis = jnp.eye(n)
    zero_u_tangent = jnp.zeros(n + 1)
    # Start from an empty float array so dtype promotion and the n == 0
    # result match the former jnp.append accumulation.
    directions = [jnp.empty([0])]
    for i in range(n):
        _, res = jax.jvp(cost_func, (bb, u), (basis[i], zero_u_tangent))
        directions.append(jnp.ravel(res))
    # One concatenation instead of jnp.append per step, which copied the
    # growing array every iteration.
    return jnp.concatenate(directions)
def apply_fun(params, inputs, **kwargs):
    """Promote a 1-D input to a complex row matrix of shape ``(1, n)``.

    Inputs of any other rank are returned unchanged (same object).
    ``params`` and ``kwargs`` are unused.
    """
    if len(inputs.shape) == 1:
        # jax.ops.index_update has been removed from JAX; adding a leading
        # axis and casting builds the same (1, n) complex array directly.
        return jnp.asarray(inputs)[jnp.newaxis, :].astype(jnp.complex128)
    return inputs
def __init__(
    self,
    indices: Tuple[int, ...],
) -> None:
    """Store ``indices`` as int32, sorting every entry except the last.

    The final index keeps its place at the end. ``ncvecs`` starts as an
    empty ``(0, 3)`` int32 array. Fewer than two indices fail the assertion.
    """
    assert len(indices) >= 2, indices
    *leading, last = indices
    ordered = sorted(leading) + [last]
    self.indices = jnp.array(ordered, dtype=jnp.int32)
    self.ncvecs = jnp.empty((0, 3), dtype=jnp.int32)
def _kernel_matrix_with_gradients(kernel_fn, theta, X, Y):
    """Evaluate kernel values and their gradients w.r.t. ``theta`` for all pairs.

    ``value_and_grad`` differentiates w.r.t. the first argument, and
    ``partial`` binds ``theta`` there, so gradients are w.r.t. ``theta``.

    If ``Y`` is None or is the same object as ``X``, returns ``(values,
    grads)`` with shapes ``(n, n)`` and ``(n, n, len(theta))``; only the
    lower triangle is evaluated and mirrored. Otherwise returns the nested
    ``vmap`` of ``value_and_grad`` over rows of ``X`` and ``Y``.
    ``config_value('KERNEL_MATRIX_USE_LOOP')`` selects the explicit-loop
    variant in the symmetric case.

    NOTE(review): assuming ``loops`` is jax.experimental.loops and ``ops``
    is jax.ops, both have been removed from current JAX releases.
    """
    kernel_fn = value_and_grad(kernel_fn)
    kernel_fn = partial(kernel_fn, theta)
    if Y is None or (Y is X):
        if config_value('KERNEL_MATRIX_USE_LOOP'):
            n = len(X)
            with loops.Scope() as s:
                s.scattered_values = np.empty((n, n))
                s.scattered_grads = np.empty((n, n, len(theta)))
                # Lower triangle including the diagonal (k=0).
                index1, index2 = np.tril_indices(n, k=0)
                for i in s.range(index1.shape[0]):
                    i1, i2 = index1[i], index2[i]
                    value, grads = kernel_fn(X[i1], X[i2])
                    # Write both (i1, i2) and (i2, i1) in a single update.
                    indexes = (np.stack([i1, i2]), np.stack([i2, i1]))
                    s.scattered_values = ops.index_update(
                        s.scattered_values, indexes, value)
                    s.scattered_grads = ops.index_update(
                        s.scattered_grads, indexes, grads)
                return s.scattered_values, s.scattered_grads
        else:
            n = len(X)
            values_scattered = np.empty((n, n))
            grads_scattered = np.empty((n, n, len(theta)))
            # Strict lower triangle (k=-1); the diagonal is filled below.
            index1, index2 = np.tril_indices(n, k=-1)
            inst1, inst2 = X[index1], X[index2]
            values, grads = vmap(kernel_fn)(inst1, inst2)
            # Scatter computed values into matrix
            values_scattered = ops.index_update(values_scattered, (index1, index2), values)
            values_scattered = ops.index_update(values_scattered, (index2, index1), values)
            grads_scattered = ops.index_update(grads_scattered, (index1, index2), grads)
            grads_scattered = ops.index_update(grads_scattered, (index2, index1), grads)
            diag_values, diag_grads = vmap(lambda x: kernel_fn(x, x))(X)
            diag_indices = np.diag_indices(n)
            values_scattered = ops.index_update(values_scattered, diag_indices, diag_values)
            grads_scattered = ops.index_update(grads_scattered, diag_indices, diag_grads)
            return values_scattered, grads_scattered
    else:
        return vmap(lambda x: vmap(lambda y: kernel_fn(x, y))(Y))(X)
def _phi_marginal(shape, rng_key, conc, corr, eig, b0, eigmin, phi_den):
    """Rejection-sample ``phi`` using Angular Central Gaussian proposals.

    The concentration/eigenvalue parameters are first broadcast to
    ``shape``. A ``while_loop`` then draws normalized proposals and accepts
    each entry with probability ``exp(lf + lg_inv)``. The loop stops once
    every entry has been accepted at least once, or after
    ``SineBivariateVonMises.max_sample_iter`` iterations.

    Returns:
      ``PhiMarginalState(i, done, phi, key)``. ``i`` is the iteration count
      and ``done`` marks entries accepted at least once. Entries never
      accepted keep their initial ``jnp.empty`` value (zeros in JAX), so
      check ``done``.
    """
    conc = jnp.broadcast_to(conc, shape)
    eig = jnp.broadcast_to(eig, shape)
    b0 = jnp.broadcast_to(b0, shape)
    eigmin = jnp.broadcast_to(eigmin, shape)
    phi_den = jnp.broadcast_to(phi_den, shape)

    def update_fn(curr):
        i, done, phi, key = curr
        # Key-split order fixes the random stream; do not reorder.
        phi_key, key = random.split(key)
        accept_key, acg_key, phi_key = random.split(phi_key, 3)

        x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape)
        x /= jnp.linalg.norm(
            x, axis=1, keepdims=True
        )  # Angular Central Gaussian distribution

        lf = (
            conc[:, :1] * (x[:, :1] - 1)
            + eigmin
            + log_I1(
                0, jnp.sqrt(conc[:, 1:] ** 2 + (corr * x[:, 1:]) ** 2)
            ).squeeze(0)
            - phi_den
        )
        assert lf.shape == shape

        lg_inv = (
            1.0 - b0 / 2 + jnp.log(b0 / 2 + (eig * x ** 2).sum(1, keepdims=True))
        )
        assert lg_inv.shape == lf.shape

        accepted = random.uniform(accept_key, shape) < jnp.exp(lf + lg_inv)

        # `done` is not consulted here, so a later acceptance replaces an
        # earlier accepted value for the same entry.
        phi = jnp.where(accepted, x, phi)
        return PhiMarginalState(i + 1, done | accepted, phi, key)

    def cond_fn(curr):
        # Continue while under the iteration cap and some entry is pending.
        return jnp.bitwise_and(
            curr.i < SineBivariateVonMises.max_sample_iter,
            jnp.logical_not(jnp.all(curr.done)),
        )

    phi_state = while_loop(
        cond_fn,
        update_fn,
        PhiMarginalState(
            i=jnp.array(0),
            done=jnp.zeros(shape, dtype=bool),
            phi=jnp.empty(shape, dtype=float),
            key=rng_key,
        ),
    )
    return PhiMarginalState(
        phi_state.i, phi_state.done, phi_state.phi, phi_state.key
    )
def __call__(self, rng, state):
    """Produces new groups and adds them to state's stack.

    Patients are kept if their marginal infection probability lies strictly
    between ``self.cut_off_low`` and ``self.cut_off_high``; kept patients
    are sorted by that marginal. Consecutive patients are then packed into
    groups, with each group size chosen as the argmin of ``exp_div_size``
    over sizes 1..min(remaining, state.max_group_size). The group rows are
    shuffled with ``rng``.

    With ``self.modified`` only the first ``state.extra_tests_needed``
    shuffled groups are added; otherwise all are added. Either way
    ``results_need_clearing=True`` is passed. If nobody lies between the
    cut-offs, ``state.all_cleared`` is set and no groups are added.

    Returns:
      The (mutated) ``state``.
    """
    p_weights, particles = state.particle_weights, state.particles
    # Particle-weighted average of the particles: per-patient marginal.
    marginal = onp.array(np.sum(p_weights[:, np.newaxis] * particles, axis=0))
    marginal = onp.squeeze(marginal)
    not_cut_ids, = onp.where(np.logical_and(
        marginal < self.cut_off_high, marginal > self.cut_off_low))
    marginal = marginal[not_cut_ids]
    sorted_ids = onp.argsort(marginal)
    sorted_marginal = onp.array(marginal[sorted_ids])
    # n_p: patients already placed in a group; n_r: patients remaining.
    n_p = 0
    n_r = marginal.size
    if n_r == 0:
        # no one left to test in between thresholds
        state.all_cleared = True
        return state

    all_new_groups = np.empty((0, state.num_patients), dtype=bool)
    while n_p < marginal.size:
        index_max = onp.amin((n_r, state.max_group_size))
        group_sizes = onp.arange(1, index_max + 1)
        # Probability that the next k sorted patients are all negative.
        cum_prod_prob = onp.cumprod(1 - sorted_marginal[n_p:(n_p + index_max)])
        # formula below is only valid for group_size > 1,
        # corrected below for a group of size 1.
        sensitivity = onp.array(
            utils.select_from_sizes(state.prior_sensitivity, group_sizes))
        specificity = onp.array(
            utils.select_from_sizes(state.prior_specificity, group_sizes))
        # Presumably the expected number of tests per patient for each
        # candidate size -- confirm.
        exp_div_size = (
            1 + group_sizes *
            (sensitivity + (1 - sensitivity - specificity) * cum_prod_prob)
        ) / group_sizes
        exp_div_size[0] = 1  # adjusted cost for one patient is one.
        opt_size_group = onp.argmin(exp_div_size) + 1
        # Float row (onp.zeros default dtype); cast to bool when added below.
        new_group = onp.zeros((1, state.num_patients))
        new_group[0, not_cut_ids[sorted_ids[n_p:n_p + opt_size_group]]] = True
        all_new_groups = np.concatenate((all_new_groups, new_group), axis=0)
        n_p = n_p + opt_size_group
        n_r = n_r - opt_size_group
    # sample randomly extra_tests_needed groups in modified case, all in
    # regular ID.
    # Because ID is a Dorfman type approach, it might be followed
    # by exhaustive splitting, which requires to keep track of groups
    # that tested positives to retest them.
    all_new_groups = jax.random.permutation(rng, all_new_groups)
    if self.modified:
        # in the case where we use modified ID, we only subsample a few groups.
        # one needs to take care of requesting to keep track of positives.
        new_groups = all_new_groups[0:state.extra_tests_needed].astype(bool)
        state.add_groups_to_test(new_groups, results_need_clearing=True)
    else:
        # with regular ID we add all groups at once.
        state.add_groups_to_test(all_new_groups.astype(bool),
                                 results_need_clearing=True)
    return state
def bounds(self):
    """Return the log-transformed bounds on the theta.

    This kernel has no hyperparameters, so the bounds array has zero rows.

    Returns:
        bounds : array, shape (n_dims, 2)
            The log-transformed bounds on the kernel's hyperparameters theta
            (here ``n_dims == 0``).
    """
    n_dims = 0
    return np.empty((n_dims, 2))
def get_coefficients(self):
    """Return six coefficient arrays.

    The first two are the same empty 1-D array. The last four are
    ``self.a``, ``self.b``, ``self.c`` and ``self.d``, each wrapped in a
    length-1 array.
    """
    nothing = np.empty(0)
    singletons = tuple(
        np.array([value]) for value in (self.a, self.b, self.c, self.d))
    return (nothing, nothing) + singletons
def runDiscretePSO_jax(user_options, algorithm_options):
    """Run a discrete particle swarm optimization.

    Args:
      user_options: dict with the PSO coefficients 'w', 'c1' and 'c2'.
      algorithm_options: dict with 'particles', 'dimensions', 'objective'
        and 'iterations'.

    Returns:
      ``(best_global, best_global_position)``: the lowest objective value
      seen and the particle position that produced it. If 'iterations' is 0
      this is ``None`` and a ``npj.empty`` placeholder buffer.
    """
    particles = algorithm_options['particles']
    dimensions = algorithm_options['dimensions']
    objective = algorithm_options['objective']
    shape = (particles, dimensions)

    # Seed once from the clock (different numbers on each run) and split
    # fresh keys from it.  Re-seeding from datetime.now().microsecond before
    # every draw can hand out the same key twice within one microsecond,
    # e.g. making r1 and r2 identical.
    key = random.PRNGKey(datetime.now().microsecond)

    # Initialize velocities; positions are derived from them.  (The former
    # uniform position draw was immediately overwritten, so it is omitted.)
    key, velocity_key = random.split(key)
    particles_velocity = random.uniform(velocity_key, shape, minval=-1, maxval=1)
    particles_position = toDiscrete(activation(particles_velocity))

    best_global = None  # Best swarm cost
    best_global_position = npj.empty(shape)  # Best swarm position
    best_particle_position = particles_position
    # NOTE(review): this personal-best cost is never updated inside the loop.
    best_particle_cost = objective(best_particle_position)

    for _ in range(0, algorithm_options['iterations']):
        # NOTE(review): evaluated on best_particle_position, which equals
        # particles_position at this point (reassigned at the end of each
        # iteration) -- confirm intended.
        objective_values = objective(best_particle_position)
        best_index = npj.argmin(objective_values)
        best_value = objective_values[best_index]
        # particles x dimensions
        best_particle_position = calculate_best_position(
            objective_values, best_particle_cost, particles_position,
            best_particle_position, particles, dimensions)
        if best_global is None or best_value < best_global:
            # Update best swarm cost and position
            best_global = best_value
            best_global_position = particles_position[best_index]
        key, r1_key, r2_key = random.split(key, 3)
        r1 = random.uniform(r1_key, shape, minval=0, maxval=1)
        r2 = random.uniform(r2_key, shape, minval=0, maxval=1)
        particles_velocity = calculate_velocity(
            user_options['w'], particles_velocity, user_options['c1'],
            user_options['c2'], r1, r2, best_particle_position,
            particles_position, best_global_position)
        particles_position = toDiscrete(
            activation(particles_position + particles_velocity))
        # NOTE(review): overwrites the personal bests computed above.
        best_particle_position = particles_position
    return best_global, best_global_position
def backward_pass(x_trj, u_trj, regu, target):
    """Run the backward sweep over ``TIME_STEPS - 1`` steps with ``lax.fori_loop``.

    The value-function derivatives are seeded from
    ``derivative_final(x_trj[-1], target)``. ``backward_pass_looper`` then
    updates the loop carry, which includes the gain buffers.

    Returns:
        ``(k_trj, K_trj, expected_cost_redu)`` as left in the final carry.
    """
    V_x, V_xx = derivative_final(x_trj[-1], target)
    # Kept as a list, as before: fori_loop requires the body's output to
    # have the same structure as the initial carry.
    carry = [
        V_x,
        V_xx,
        np.empty_like(u_trj),                  # k_trj
        np.empty((TIME_STEPS-1, N_U, N_X)),    # K_trj
        x_trj,
        u_trj,
        0.,                                    # expected_cost_redu
        regu,
        target,
    ]
    (_, _, k_trj, K_trj, _, _,
     expected_cost_redu, _, _) = lax.fori_loop(
        0, TIME_STEPS-1, backward_pass_looper, carry)
    return k_trj, K_trj, expected_cost_redu