def lqr_continuous_time_infinite_horizon(A, B, Q, R, N):
    # Take the last dimension, in case we try to do some kind of broadcasting
    # thing in the future.
    x_dim = A.shape[-1]

    # pylint: disable=line-too-long
    # See https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic_regulator#Infinite-horizon,_continuous-time_LQR.
    A1 = A - B @ jp.linalg.solve(R, N.T)
    Q1 = Q - N @ jp.linalg.solve(R, N.T)

    # See https://en.wikipedia.org/wiki/Algebraic_Riccati_equation#Solution.
    H = jp.block([[A1, -B @ jp.linalg.solve(R, B.T)], [-Q1, -A1.T]])

    eigvals, eigvectors = jp.linalg.eig(H)
    # For large-ish systems (eg x_dim = 7), sometimes we find some values that
    # have an imaginary component. That's an unfortunate consequence of the
    # numerical instability in the eigendecomposition. Still,
    # assert (eigvals.imag == jp.zeros_like(eigvals, dtype=jp.float32)).all()
    # assert (eigvectors.imag == jp.zeros_like(eigvectors, dtype=jp.float32)).all()

    # Now it should be safe to take out only the real components.
    eigvals = eigvals.real
    eigvectors = eigvectors.real

    argsort = jp.argsort(eigvals)
    ix = argsort[:x_dim]
    U = eigvectors[:, ix]
    P = U[x_dim:, :] @ jp.linalg.inv(U[:x_dim, :])

    K = jp.linalg.solve(R, (B.T @ P + N.T))
    return K, P, eigvals[ix]
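# Hedged usage sketch for lqr_continuous_time_infinite_horizon, assuming the `jp`
# alias above is jax.numpy. The decoupled two-state system and the helper name
# `_example_lqr_usage` are illustrative only (not part of the original code); the
# check is the continuous-time algebraic Riccati residual
# A'P + PA - (PB + N) R^{-1} (B'P + N') + Q ~ 0.
def _example_lqr_usage():  # hypothetical demo helper
    import jax.numpy as jp
    A = jp.diag(jp.array([-1., -2.]))
    B = jp.eye(2)
    Q = jp.eye(2)
    R = jp.eye(2)
    N = jp.zeros((2, 2))
    K, P, closed_loop_eigvals = lqr_continuous_time_infinite_horizon(A, B, Q, R, N)
    residual = A.T @ P + P @ A - (P @ B + N) @ jp.linalg.solve(R, B.T @ P + N.T) + Q
    assert jp.allclose(residual, 0., atol=1e-4)
    # For this decoupled system P is diagonal: diag(sqrt(2) - 1, sqrt(5) - 2).
    return K, P, closed_loop_eigvals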
def compute_eigenvalue_decomposition(Ms, sort_by='magnitude', do_compute_lefts=True):
    """Compute the eigendecomposition of each matrix in Ms. No assumptions are made on the matrices.

    Arguments:
        Ms: 3D np.array of nmatrices x dim x dim matrices
        sort_by: 'magnitude' or 'real', criterion used to order the eigenvalues
        do_compute_lefts: Compute the left eigenvectors? Requires a pseudo-inverse call.

    Returns:
        list of dictionaries with eigenvalue components: sorted eigenvalues, sorted
        right eigenvectors, and sorted left eigenvectors (as column vectors).
    """
    if sort_by == 'magnitude':
        sort_fun = onp.abs
    elif sort_by == 'real':
        sort_fun = onp.real
    else:
        assert False, "Not implemented yet."

    decomps = []
    L = None
    for M in Ms:
        evals, R = onp.linalg.eig(M)
        indices = np.flipud(np.argsort(sort_fun(evals)))
        if do_compute_lefts:
            L = onp.linalg.pinv(R).T  # as columns
            L = L[:, indices]
        decomps.append({'evals': evals[indices], 'R': R[:, indices], 'L': L})
    return decomps
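# Illustrative usage sketch for compute_eigenvalue_decomposition. It assumes the
# module aliases are `import numpy as onp` (plus a second numpy-like alias `np`),
# as the code above suggests; the helper name `_example_eigendecomposition` is
# hypothetical. It checks the right-eigenvector identity M @ r = eval * r.
def _example_eigendecomposition():  # hypothetical demo helper
    import numpy as onp
    rng = onp.random.RandomState(0)
    Ms = rng.randn(2, 3, 3)  # nmatrices x dim x dim
    decomps = compute_eigenvalue_decomposition(Ms, sort_by='magnitude')
    d0 = decomps[0]
    onp.testing.assert_allclose(Ms[0] @ d0['R'][:, 0],
                                d0['evals'][0] * d0['R'][:, 0], atol=1e-8)
    return decomps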
def test_modelSelection(family, prior, method):
    p = 5
    n = 700
    key = random.PRNGKey(0)
    X = random.normal(key, (n, p))
    key, subkey1, subkey2 = random.split(key, 3)
    mu = 1.7 * X[:, 1] - 1.6 * X[:, 2]
    truth = jnp.full((p,), False)
    truth = truth.at[[1, 2]].set(True)
    if family == "logistic":
        y = (mu + random.normal(subkey1, (n,)) > 0).astype(jnp.int32)
    elif family == "poisson":
        y = random.poisson(subkey2, lam=jnp.exp(mu), shape=(n,))
    fmt = f"{{:0{p}b}}"
    gammes = np.array([list(fmt.format(i)) for i in range(2**p)])
    gammes = jnp.array(gammes == "0")[:-1, :]
    _, modprobs = modelSelection(X, y, gammes, family=family, prior=prior, method=method)
    order = jnp.argsort(modprobs)[::-1]
    assert np.all(np.isfinite(modprobs))
    if family == "logistic":
        assert np.all(gammes[order[0], :] == truth), gammes[order[0], :]
def eigh(H, precision=lax.Precision.HIGHEST, symmetrize=True, termination_size=128):
    """Computes the eigendecomposition of the symmetric/Hermitian matrix H.

    Args:
        H: The `n x n` Hermitian input.
        precision: The matmul precision.
        symmetrize: If True, `0.5 * (H + H.conj().T)` rather than `H` is used.
        termination_size: Recursion ends once the blocks reach this linear size.

    Returns:
        vals: The `n` eigenvalues of `H`, sorted from lowest to highest.
        vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
            of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
            to numerical error.
    """
    nrows, ncols = H.shape
    if nrows != ncols:
        raise TypeError(f"Input H of shape {H.shape} must be square.")

    if ncols <= termination_size:
        return jnp.linalg.eigh(H)

    evals, evecs = _eigh_work(H, precision=precision)
    sort_idxs = jnp.argsort(evals)
    evals = evals[sort_idxs]
    evecs = evecs[:, sort_idxs]
    return evals, evecs
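# Minimal usage sketch for the eigh wrapper above. Since `_eigh_work` is not shown
# here, this sticks to the small-matrix path (n <= termination_size), which defers
# to jnp.linalg.eigh; the helper name `_example_block_eigh` is hypothetical.
def _example_block_eigh():  # hypothetical demo helper
    import jax
    import jax.numpy as jnp
    A = jax.random.normal(jax.random.PRNGKey(0), (8, 8))
    H = 0.5 * (A + A.T)  # symmetrize by hand, mirroring the `symmetrize` docstring
    vals, vecs = eigh(H)
    # Eigenvalues come back sorted from lowest to highest, and H @ v = w * v.
    assert bool(jnp.all(jnp.diff(vals) >= 0))
    assert jnp.allclose(H @ vecs, vecs * vals, atol=1e-3)
    return vals, vecs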
def CM_eigenvectors_EVsorted(Z, R, N=0, cutoff=10):
    '''
    Returns the eigenvectors of the (unsorted) Coulomb matrix, ordered by their
    eigenvalues. A cutoff limits how many eigenvectors are returned.

    Parameters
    ----------
    Z : 1 x n dimensional array
        contains nuclear charges
    R : 3 x n dimensional array
        contains nuclear positions
    N : float
        number of electrons in system
        here: meaningless, can remain empty
    cutoff : int
        maximum number of eigenvectors returned

    Return
    ------
    M : Matrix
        eigenvectors of the Coulomb matrix, ordered by eigenvalue
        (at most `cutoff` of them, as columns)
    '''
    N = CM_full_unsorted_matrix(Z, R)
    ev, evec = jnp.linalg.eigh(N)
    order = jnp.argsort(ev)[:min(ev.size, cutoff)]
    # jnp.linalg.eigh returns eigenvectors as columns, so select columns here.
    sorted_evec = evec[:, order]
    return sorted_evec
def top_k_error_rate_metric(logits: jnp.ndarray,
                            one_hot_labels: jnp.ndarray,
                            k: int = 5,
                            mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
    """Returns the top-K error rate between some predictions and some labels.

    Args:
        logits: Output of the model.
        one_hot_labels: One-hot encoded labels. Dimensions should match the logits.
        k: Number of classes the model is allowed to predict for each example.
        mask: Mask to apply to the loss to ignore some samples (usually, the
            padding of the batch). Array of ones and zeros.

    Returns:
        The error rate (1 - accuracy), averaged over the first dimension (samples).
    """
    if mask is None:
        mask = jnp.ones([logits.shape[0]])
    # Reshape to [batch, 1] so the mask lines up with the per-example `hit`
    # column computed below.
    mask = mask.reshape([logits.shape[0], 1])
    true_labels = jnp.argmax(one_hot_labels, -1).reshape([-1, 1])
    top_k_preds = jnp.argsort(logits, axis=-1)[:, -k:]
    hit = jax.vmap(jnp.isin)(true_labels, top_k_preds)
    error_rate = 1 - ((hit * mask).sum() / mask.sum())
    # Set to zero if there are no non-masked samples.
    return jnp.nan_to_num(error_rate)
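# Small worked example for top_k_error_rate_metric, as a sketch; the helper name
# `_example_top_k_error_rate` is hypothetical. With k=2, the first row's true
# class (2) is among the two largest logits while the second row's is not, so the
# expected error rate is 0.5.
def _example_top_k_error_rate():  # hypothetical demo helper
    import jax.numpy as jnp
    logits = jnp.array([[0.1, 0.5, 0.4],
                        [0.9, 0.05, 0.03]])
    one_hot_labels = jnp.array([[0., 0., 1.],
                                [0., 0., 1.]])
    err = top_k_error_rate_metric(logits, one_hot_labels, k=2)
    assert jnp.isclose(err, 0.5)
    return err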
def compute(self, particles, particle_info, loss_fn):
    if self._random_weights is None:
        self._random_weights = jnp.array(npr.randn(*particles.shape))
        self._random_biases = jnp.array(npr.rand(*particles.shape) * 2 * np.pi)
    factor = self.bandwidth_factor(particles.shape[0])
    if self.bandwidth_subset is not None:
        particles = particles[npr.choice(particles.shape[0], self.bandwidth_subset)]
    diffs = jnp.expand_dims(particles, axis=0) - jnp.expand_dims(particles, axis=1)  # N x N x D
    if particles.ndim == 2:
        diffs = safe_norm(diffs, ord=2, axis=-1)  # N x N x D -> N x N
    diffs = jnp.reshape(diffs, (diffs.shape[0] * diffs.shape[1], -1))  # N * N x 1
    if diffs.ndim == 2:
        diff_norms = safe_norm(diffs, ord=2, axis=-1)
    else:
        diff_norms = diffs
    median = jnp.argsort(diff_norms)[int(diffs.shape[0] / 2)]
    bandwidth = jnp.abs(diffs)[median] ** 2 * factor + 1e-5

    def feature(x, w, b):
        return jnp.sqrt(2) * jnp.cos((x @ w + b) / bandwidth)

    def kernel(x, y):
        ws = self._random_weights if self.random_indices is None else self._random_weights[self.random_indices]
        bs = self._random_biases if self.random_indices is None else self._random_biases[self.random_indices]
        return jnp.sum(jax.vmap(lambda w, b: feature(x, w, b) * feature(y, w, b))(ws, bs))

    return kernel
def PCA_visual(p_final):
    '''
    Visualize matrix (N * M) as N points
    '''
    loss_p = lambda p: loss(p_final[0], p, omega0, sx)
    Hess = jacrev(jacrev(loss_p))(p_final[1])
    w, v = jnp.linalg.eig(Hess)
    v_real = v.real
    arglist = jnp.argsort(w)[-3:]
    pca = PCA(n_components=2)
    # scatter in 2d panel
    reduced = pca.fit_transform(v_real[:, arglist].transpose())
    t = reduced.transpose()
    # t_main = t[:, arglist]
    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    ax1.scatter(t[0], t[1], marker='v', label='main')
    # ax1.scatter(t[0], t[1], marker='o', label='full')
    # ax1.scatter(t_main[0], t_main[1], marker='v', label='main')
    plt.legend(loc='upper left')
    plt.show()
def sorted_topk_indicators(x, k, sort_by=SortBy.POSITION):
    """Finds the (sorted) positions of the topk values in x.

    Args:
        x: The input scores of dimension (d,).
        k: The number of top elements to find.
        sort_by: Strategy to order the extracted values. This is useful when this
            function is applied to many perturbed inputs and the results are
            averaged. As topk's output does not have a fixed order, the indicator
            vectors could be swapped and the average of the indicators would not
            be spiky.

    Returns:
        Indicator vectors in a tensor of shape (k, d)
    """
    n = x.shape[-1]
    values, ranks = jax.lax.top_k(x, k)

    if sort_by == SortBy.NONE:
        sorted_ranks = ranks
    if sort_by == SortBy.VALUES:
        sorted_ranks = ranks[jnp.argsort(values)]
    if sort_by == SortBy.POSITION:
        sorted_ranks = jnp.sort(ranks)

    one_hot_fn = jax.vmap(functools.partial(jax.nn.one_hot, num_classes=n))
    indicators = one_hot_fn(sorted_ranks)
    return indicators
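# Worked example for sorted_topk_indicators (sketch); `_example_topk_indicators`
# is a hypothetical name, and SortBy is assumed to be the enum defined alongside
# the function above. The top-2 values of x sit at positions 1 and 3, so with the
# default position ordering the rows are the one-hot vectors for 1 then 3.
def _example_topk_indicators():  # hypothetical demo helper
    import jax.numpy as jnp
    x = jnp.array([0.1, 0.9, 0.3, 0.7])
    ind = sorted_topk_indicators(x, k=2)
    expected = jnp.array([[0., 1., 0., 0.],
                          [0., 0., 0., 1.]])
    assert jnp.allclose(ind, expected)
    return ind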
def init_simplex_sampler_state(live_points_U):
    N, D = live_points_U.shape
    inter_point_distance = squared_norm(live_points_U, live_points_U)
    inter_point_distance = jnp.where(inter_point_distance == 0., jnp.inf, inter_point_distance)
    knn_indices = jnp.argsort(inter_point_distance, axis=-1)[:, :D + 1]
    return SimplexSamplerState(knn_indices=knn_indices)
def topk_mask_internal(value):
    assert value.ndim == 1
    indices = jnp.argsort(value)
    k = jnp.round(density_fraction * jnp.size(value)).astype(jnp.int32)
    mask = jnp.greater_equal(np.arange(value.size), value.size - k)
    mask = jnp.zeros_like(mask).at[indices].set(mask)
    return mask.astype(np.int32)
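# topk_mask_internal keeps the largest `density_fraction` of entries by scattering
# a sorted boolean ramp back through the argsort indices. A standalone sketch of
# the same trick, with `density_fraction` bound locally (above it comes from an
# enclosing scope); `_example_topk_mask` is a hypothetical name.
def _example_topk_mask(value, density_fraction=0.5):  # hypothetical demo helper
    import jax.numpy as jnp
    indices = jnp.argsort(value)
    k = jnp.round(density_fraction * jnp.size(value)).astype(jnp.int32)
    # True for the k largest ranks, then scattered back to the original positions.
    mask = jnp.greater_equal(jnp.arange(value.size), value.size - k)
    return jnp.zeros_like(mask).at[indices].set(mask).astype(jnp.int32)

# e.g. _example_topk_mask(jnp.array([0.3, 1.2, -0.5, 0.9])) -> [0, 1, 0, 1]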
def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
    out = lax.gather_p.bind(
        operand, start_indices,
        dimension_numbers=dimension_numbers,
        slice_sizes=slice_sizes,
        unique_indices=unique_indices,
        indices_are_sorted=indices_are_sorted,
        mode=mode,
        fill_value=fill_value)

    if ErrorCategory.OOB not in enabled_errors:
        return out, error

    # compare to OOB masking logic in lax._gather_translation_rule
    dnums = dimension_numbers
    operand_dims = np.array(operand.shape)
    num_batch_dims = len(start_indices.shape) - 1

    upper_bound = operand_dims[np.array(dnums.start_index_map)]
    upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)]
    upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims)))
    in_bounds = (start_indices >= 0) & (start_indices <= upper_bound)

    msg = f'out-of-bounds indexing at {summary()}: '
    msg += 'index {payload} is out of bounds for '
    msg += f'array of shape {operand.shape}.'
    start_indices, in_bounds = jnp.ravel(start_indices), jnp.ravel(in_bounds)
    # Report first index which is out-of-bounds (in row-major order).
    payload = start_indices[jnp.argsort(in_bounds, axis=0)[0]]
    return out, assert_func(error, jnp.all(in_bounds), msg, payload)
def split_top_k(split_queries):
    split_scores = jnp.einsum('qd,rvd->qrv', split_queries, table)

    # Find highest scoring vector for each row.
    top_id_by_row = jnp.argmax(split_scores, axis=-1)
    top_score_by_row = jnp.max(split_scores, axis=-1)

    # Take k highest scores among all rows.
    top_row_idx = jnp.argsort(top_score_by_row, axis=-1)[:, :-self.k_top - 1:-1]

    # Sub-select best indices for k best rows.
    ids_by_topk_row = jut.matmul_slice(top_id_by_row, top_row_idx)

    # Gather highest scoring vectors for k best rows.
    split_topk_values = table[top_row_idx, ids_by_topk_row]

    # Convert row indices to indices into flattened table.
    top_table_id_by_row = top_id_by_row + jnp.arange(0, table_size, scores_per_row)

    # Get best ids into flattened table.
    split_topk_ids = jut.matmul_slice(top_table_id_by_row, top_row_idx)
    split_topk_scores = jut.matmul_slice(top_score_by_row, top_row_idx)

    return split_topk_values, split_topk_scores, split_topk_ids
def nucleaus_filter(logits, top_p=0.9, top_k=None):
    sorted_logits = jnp.sort(logits)[:, ::-1]  # sort descending
    sorted_indices = jnp.argsort(logits)[:, ::-1]
    cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits), axis=-1)

    if top_k is not None:
        # Keep only top_k tokens
        indices_range = jnp.arange(len(sorted_indices[0]))
        indices_range = jnp.stack([indices_range] * len(sorted_indices), axis=0)

        sorted_indices_to_remove = jnp.where(indices_range > top_k, sorted_indices, 0)

        _, indices_to_remove = jax.lax.sort_key_val(sorted_indices, sorted_indices_to_remove)

        logit_mask = 1e10 * indices_to_remove
        logits -= logit_mask

    # Remove tokens with cumulative probability above a threshold
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove = jnp.concatenate(
        (jnp.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove),
        axis=-1)[:, :-1]

    _, indices_to_remove = jax.lax.sort_key_val(sorted_indices, sorted_indices_to_remove)

    logit_mask = 1e10 * indices_to_remove
    logits -= logit_mask

    return logits
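# Hedged usage sketch for nucleaus_filter; `_example_nucleus_filter` is a
# hypothetical name. For this row the cumulative softmax mass of the two largest
# logits is ~0.88 < 0.9, so only the smallest logit falls outside the nucleus and
# is pushed down by the 1e10 mask.
def _example_nucleus_filter():  # hypothetical demo helper
    import jax.numpy as jnp
    logits = jnp.array([[2.0, 1.0, 0.0, -1.0]])
    filtered = nucleaus_filter(logits, top_p=0.9)
    assert int(jnp.sum(filtered[0] < -1e9)) == 1
    assert bool(jnp.argmin(filtered[0]) == 3)
    return filtered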
def sample_permutation(key, coupling):
    """Samples a permutation matrix from a doubly stochastic coupling matrix.

    CAREFUL: the couplings that come out of the sinkhorn solver are not doubly
    stochastic but 1/dim * doubly_stochastic.

    See **Convex Relaxations for Permutation Problems** paper for a rough
    explanation of the algorithm. Best to use by drawing multiple samples and
    picking the permutation with lowest cost, as sometimes permutations seem to
    be drawn with high cost. The sample_best_permutation method does this.

    Args:
        key: jnp.ndarray that functions as a PRNG key.
        coupling: jnp.ndarray of shape [N, N] which must have marginals such that
            coupling.sum(0) == 1. and coupling.sum(1) == 1. Note that in sinkhorn
            we usually output couplings with marginals that sum to 1/N.

    Returns:
        permutation matrix: jnp.ndarray of shape [N, N] of floating dtype.
    """
    dim = coupling.shape[0]
    # random monotonic vector v without duplicates.
    v = jax.random.choice(key, 10 * dim, shape=(dim,), replace=False)
    v = jnp.sort(v) * 10.

    w = jnp.dot(coupling, v)
    # Sorting w will give the row indices of the permutation matrix.
    row_ind = jnp.argsort(w)
    col_ind = jnp.arange(0, dim)

    # Compute permutation matrix from row and column indices
    perm = idx2permutation(row_ind, col_ind)
    return perm
def batched_neighbor_inds(confs, inds_l, inds_r, cutoff, boxes):
    """Given candidate interacting pairs (inds_l, inds_r), with
    inds_l.shape == (n_interactions,), exclude most pairs whose distances are
    >= cutoff, returning (neighbor_inds_l, neighbor_inds_r) with
    neighbor_inds_l.shape == (len(confs), max_n_neighbors), where the total
    number of neighbors returned for each conf in confs is the same
    max_n_neighbors.

    This padding causes some amount of wasted effort, but keeps things nice and
    fixed-dimensional for later XLA steps.
    """
    assert len(confs.shape) == 3

    distances = vmap(distance_on_pairs)(confs[:, inds_l], confs[:, inds_r], boxes)
    assert distances.shape == (len(confs), len(inds_l))

    neighbor_masks = distances < cutoff
    # how many total neighbors?
    n_neighbors = np.sum(neighbor_masks, 1)
    max_n_neighbors = max(n_neighbors)
    assert max_n_neighbors > 0

    # sorting in order of [falses, ..., trues]
    keep_inds = np.argsort(neighbor_masks, axis=1)[:, -max_n_neighbors:]
    neighbor_inds_l = inds_l[keep_inds]
    neighbor_inds_r = inds_r[keep_inds]

    assert neighbor_inds_l.shape == (len(confs), max_n_neighbors)
    assert neighbor_inds_l.shape == neighbor_inds_r.shape

    return neighbor_inds_l, neighbor_inds_r
def sum_from_unique(cls, input: np.array, mean: bool = True) -> Tuple[np.array, np.array, "SparseReduce"]:
    un, cts = np.unique(input, return_counts=True)
    un_idx = [np.argwhere(input == un[i]).flatten() for i in range(un.size)]
    l_arr = np.array([i.size for i in un_idx])
    argsort = np.argsort(l_arr)
    un_sorted = un[argsort]
    cts_sorted = cts[argsort]
    un_idx_sorted = [un_idx[i] for i in argsort]

    change = list(np.argwhere(l_arr[argsort][:-1] - l_arr[argsort][1:] != 0).flatten() + 1)
    change.insert(0, 0)
    change.append(len(l_arr))
    change = np.array(change)

    el = []
    for i in range(len(change) - 1):
        el.append(np.array([un_idx_sorted[j] for j in range(change[i], change[i + 1])]))
    # assert False

    return un_sorted, cts_sorted, SparseReduce(el, mean)
def nystrom_inv(gram, n_comp, regul=0.) -> np.array:
    p = random.permutation(gram.shape[0])
    ip = np.argsort(p)
    (vec_in, vec_out, λ) = nystrom_eigh(gram[p, :][:, p], n_comp, regul)
    vec = np.vstack([vec_in, vec_out])
    rval = vec @ np.diag(1. / (λ + regul)) @ vec.T
    return rval[ip, :][:, ip]
def prior_sample(self, num_samps, t=None):
    """
    Sample from the model prior f~N(0,K) multiple times using a nested loop.
    :param num_samps: the number of samples to draw [scalar]
    :param t: the input locations at which to sample (defaults to train+test set) [N_samp, 1]
    :return:
        f_sample: the prior samples [S, N_samp]
    """
    self.update_model(softplus_list(self.prior.hyp))
    if t is None:
        t = self.t_all
    else:
        x_ind = np.argsort(t[:, 0])
        t = t[x_ind]
    dt = np.concatenate([np.array([0.0]), np.diff(t[:, 0])])
    N = dt.shape[0]
    with loops.Scope() as s:
        s.f_sample = np.zeros([N, self.func_dim, num_samps])
        s.m = np.linalg.cholesky(self.Pinf) @ random.normal(random.PRNGKey(99), shape=[self.state_dim, 1])
        for i in s.range(num_samps):
            s.m = np.linalg.cholesky(self.Pinf) @ random.normal(random.PRNGKey(i), shape=[self.state_dim, 1])
            for k in s.range(N):
                A = self.prior.state_transition(dt[k], self.prior.hyp)  # transition and noise process matrices
                Q = self.Pinf - A @ self.Pinf @ A.T
                C = np.linalg.cholesky(Q + 1e-6 * np.eye(self.state_dim))  # <--- can be a bit unstable
                # we need to provide a different PRNG seed every time:
                s.m = A @ s.m + C @ random.normal(random.PRNGKey(i * k + k), shape=[self.state_dim, 1])
                H = self.prior.measurement_model(t[k, 1:], softplus_list(self.prior.hyp))
                f = (H @ s.m).T
                s.f_sample = index_add(s.f_sample, index[k, ..., i], np.squeeze(f))
        return s.f_sample
def data_split(batch: Batch, num_workers: int, s: float):
    # s encodes the heterogeneity of the data split
    if num_workers == 1:
        return batch

    n_data = batch.x.shape[0]
    n_homo_data = int(n_data * s)
    assert 0 < n_homo_data < n_data

    data_homo, data_hetero = jnp.split(batch.x, [n_homo_data])
    label_homo, label_hetero = jnp.split(batch.y, [n_homo_data])

    data_homo_list = jnp.split(data_homo, num_workers)
    label_homo_list = jnp.split(label_homo, num_workers)

    index = jnp.argsort(label_hetero)
    label_hetero_sorted = label_hetero[index]
    data_hetero_sorted = data_hetero[index]

    data_hetero_list = jnp.split(data_hetero_sorted, num_workers)
    label_hetero_list = jnp.split(label_hetero_sorted, num_workers)

    data_list = [
        jnp.concatenate([data_homo, data_hetero], axis=0)
        for data_homo, data_hetero in zip(data_homo_list, data_hetero_list)
    ]
    label_list = [
        jnp.concatenate([label_homo, label_hetero], axis=0)
        for label_homo, label_hetero in zip(label_homo_list, label_hetero_list)
    ]
    return [Batch(x, y) for x, y in zip(data_list, label_list)]
def compute(self, particles, particle_info, loss_fn):
    diffs = jnp.expand_dims(particles, axis=0) - jnp.expand_dims(particles, axis=1)  # N x N (x D)
    if self._normed() and particles.ndim == 2:
        diffs = safe_norm(diffs, ord=2, axis=-1)  # N x D -> N
    diffs = jnp.reshape(diffs, (diffs.shape[0] * diffs.shape[1], -1))  # N * N (x D)
    factor = self.bandwidth_factor(particles.shape[0])
    if diffs.ndim == 2:
        diff_norms = safe_norm(diffs, ord=2, axis=-1)
    else:
        diff_norms = diffs
    median = jnp.argsort(diff_norms)[int(diffs.shape[0] / 2)]
    bandwidth = jnp.abs(diffs)[median] ** 2 * factor + 1e-5
    if self._normed():
        bandwidth = bandwidth[0]

    def kernel(x, y):
        diff = safe_norm(x - y, ord=2) if self._normed() and x.ndim >= 1 else x - y
        kernel_res = jnp.exp(-diff ** 2 / bandwidth)
        if self._mode == 'matrix':
            if self.matrix_mode == 'norm_diag':
                return kernel_res * jnp.identity(x.shape[0])
            else:
                return jnp.diag(kernel_res)
        else:
            return kernel_res

    return kernel
def evaluate_example(self, example: SingleExample, prediction: SinglePrediction) -> MeanStat:
    """Computes token top k accuracy for a single sequence example.

    Args:
        example: One example with target in range [0, num_classes) of shape [max_length].
        prediction: Unnormalized prediction for ``example`` of shape [max_length, num_classes]

    Returns:
        MeanStat for token top k accuracy for a single sequence example or at
        each token position if ``per_position`` is ``True``.
    """
    target = example[self.target_key]
    pred = prediction if self.pred_key is None else prediction[self.pred_key]

    if self.logits_mask is not None:
        logits_mask = jnp.array(self.logits_mask)
        pred += logits_mask

    target_weight = get_target_weight(target, self.masked_target_values)
    top_k_pred = jnp.argsort(-pred, axis=1)[:, :self.k]
    correct = jnp.any(jnp.transpose(top_k_pred) == target, axis=0).astype(jnp.float32)

    if self.per_position:
        return MeanStat.new(correct * target_weight, target_weight)
    return MeanStat.new(jnp.sum(correct * target_weight), jnp.sum(target_weight))
def sampling_loop_body_fn(state):
    """Sampling loop state update."""
    i, sequences, cache, cur_token, ended, rng, tokens_to_logits_state = state

    # Split RNG for sampling.
    rng1, rng2 = random.split(rng)

    # Call fast-decoder model on current tokens to get raw next-position logits.
    logits, new_cache, new_tokens_to_logits_state = tokens_to_logits(
        cur_token, cache, internal_state=tokens_to_logits_state)
    logits = logits / temperature

    # Mask out the BOS token.
    if masked_tokens is not None:
        mask = common_utils.onehot(
            jnp.array(masked_tokens), num_classes=logits.shape[-1], on_value=LARGE_NEGATIVE)
        mask = jnp.sum(mask, axis=0)[None, :]  # Combine multiple masks together
        logits = logits + mask

    # Apply the repetition penalty.
    if repetition_penalty != 1:
        logits = apply_repetition_penalty(
            sequences, logits, i,
            repetition_penalty=repetition_penalty,
            repetition_window=repetition_window,
            repetition_penalty_normalize=repetition_penalty_normalize)

    # Mask out everything but the top-k entries.
    if top_k is not None:
        # Compute top_k_index and top_k_threshold with shapes (batch_size, 1).
        top_k_index = jnp.argsort(logits, axis=-1)[:, ::-1][:, top_k - 1:top_k]
        top_k_threshold = jnp.take_along_axis(logits, top_k_index, axis=-1)
        logits = jnp.where(logits < top_k_threshold,
                           jnp.full_like(logits, LARGE_NEGATIVE),
                           logits)

    # Sample next token from logits.
    sample = multinomial(rng1, logits)
    next_token = sample.astype(jnp.int32)

    # Only use sampled tokens if we are past the out_of_prompt_marker.
    out_of_prompt = (sequences[:, i + 1] == out_of_prompt_marker)
    next_token = (next_token * out_of_prompt + sequences[:, i + 1] * ~out_of_prompt)

    # If end-marker reached for batch item, only emit padding tokens.
    next_token = next_token[:, None]
    next_token_or_endpad = jnp.where(ended, jnp.full_like(next_token, pad_token), next_token)
    ended |= (next_token_or_endpad == end_marker)

    # Add current sampled tokens to recorded sequences.
    new_sequences = lax.dynamic_update_slice(sequences, next_token_or_endpad, (0, i + 1))

    return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2,
            new_tokens_to_logits_state)
def _argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None):  # pylint: disable=unused-argument
    """Numpy implementation of `tf.argsort`."""
    if direction == 'ASCENDING':
        pass
    elif direction == 'DESCENDING':
        values = np.negative(values)
    else:
        raise ValueError('Unrecognized direction: {}.'.format(direction))
    return np.argsort(values, axis, kind='stable' if stable else 'quicksort')
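# Quick check of the direction handling in _argsort above (a sketch);
# `_example_argsort` is a hypothetical name.
def _example_argsort():  # hypothetical demo helper
    import numpy as np
    values = np.array([3.0, 1.0, 2.0])
    assert list(_argsort(values, direction='ASCENDING')) == [1, 2, 0]
    assert list(_argsort(values, direction='DESCENDING')) == [0, 2, 1]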
def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)
def top_k_approx(scores, k=100):
    """Returns approximate topk highest scores for each row.

    The api is the same as jax.lax.top_k, so this can be used as a drop-in
    replacement as long as the scores tensor has 2 dims. For more dimensions,
    please use one or more vmap(s) to be able to use it.

    In essence, we perform a jnp.max operation, which can be thought of as a
    lossy top 1, on fixed-length windows of items. We can control the amount of
    approximation by changing the window length: the smaller it gets, the better
    the approximation, at the cost of performance.

    Once we have the max for all the windows, we apply regular slow but exact
    jax.lax.top_k over the reduced set of items.

    Args:
        scores: [num_rows, num_cols] shaped tensor. Will return top K over last dim.
        k: How many top scores to return for each row.

    Returns:
        Topk scores, topk ids. Both shaped [num_rows, k]
    """
    num_queries = scores.shape[0]
    num_items = scores.shape[1]

    # Make this bigger to improve recall. Should be between [1, k].
    num_windows_multiplier = 5
    window_lengths = num_items // k // num_windows_multiplier + 1
    padded_num_items = k * num_windows_multiplier * window_lengths

    print(f"scores shape: {scores.shape}")
    print(f"padded_num_items: {padded_num_items}")
    print(f"num_items: {num_items}")

    scores = jnp.pad(scores, ((0, 0), (0, padded_num_items - num_items)),
                     mode="constant", constant_values=jnp.NINF)
    scores = jnp.reshape(scores, (num_queries, k * num_windows_multiplier, window_lengths))

    approx_top_local_scores = jnp.max(scores, axis=2)
    sorted_approx_top_scores_across_local = jnp.flip(
        jnp.sort(approx_top_local_scores, axis=1), axis=1)
    approx_top_ids_across_local = jnp.flip(
        jnp.argsort(approx_top_local_scores, axis=1), axis=1)[:, :k]

    approx_top_local_ids = jnp.argmax(scores, axis=2)
    offsets = jnp.arange(0, padded_num_items, window_lengths)
    approx_top_ids_with_offsets = approx_top_local_ids + offsets
    approx_top_ids = slice_2d(approx_top_ids_with_offsets, approx_top_ids_across_local)

    topk_scores = sorted_approx_top_scores_across_local[:, :k]
    topk_ids = approx_top_ids
    return topk_scores, topk_ids
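# The core idea above: reduce each fixed-length window to its max (a lossy top-1
# per window), then run exact top-k on the reduced set. A tiny standalone sketch
# of that reduction, independent of the slice_2d helper used above;
# `_example_windowed_topk` is a hypothetical name.
def _example_windowed_topk(scores, k=2, window=3):  # hypothetical demo helper
    import jax
    import jax.numpy as jnp
    num_rows, num_items = scores.shape
    assert num_items % window == 0
    windowed = scores.reshape(num_rows, num_items // window, window)
    window_max = jnp.max(windowed, axis=2)  # approximate top-1 per window
    # Exact top-k over the window maxima; indices refer to windows, not columns.
    return jax.lax.top_k(window_max, k)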
def _projector_subspace(P, H, rank, maxiter=2):
    """Decomposes the `n x n` rank `rank` Hermitian projector `P` into an
    `n x rank` isometry `Vm` such that `P = Vm @ Vm.conj().T` and an
    `n x (n - rank)` isometry `Vp` such that `I - P = Vp @ Vp.conj().T`.

    The subspaces are computed using the naive QR eigendecomposition algorithm,
    which converges very quickly due to the sharp separation between the
    relevant eigenvalues of the projector.

    Args:
        P: A rank-`rank` Hermitian projector into the space of `H`'s first `rank`
            eigenpairs.
        H: The aforementioned Hermitian matrix, which is used to track convergence.
        rank: Rank of `P`.
        maxiter: Maximum number of iterations.

    Returns:
        Vm, Vp: Isometries into the eigenspaces described in the docstring.
    """
    # Choose an initial guess: the `rank` largest-norm columns of P.
    column_norms = jnp.linalg.norm(P, axis=1)
    sort_idxs = jnp.argsort(column_norms)
    X = P[:, sort_idxs]
    X = X[:, :rank]

    H_norm = jnp.linalg.norm(H)
    thresh = 10 * jnp.finfo(X.dtype).eps * H_norm

    # First iteration skips the matmul.
    def body_f_after_matmul(X):
        Q, _ = jnp.linalg.qr(X, mode="complete")
        V1 = Q[:, :rank]
        V2 = Q[:, rank:]
        # TODO: might be able to get away with lower precision here
        error_matrix = jnp.dot(V2.conj().T, H, precision=lax.Precision.HIGHEST)
        error_matrix = jnp.dot(error_matrix, V1, precision=lax.Precision.HIGHEST)
        error = jnp.linalg.norm(error_matrix) / H_norm
        return V1, V2, error

    def cond_f(args):
        _, _, j, error = args
        still_counting = j < maxiter
        unconverged = error > thresh
        return jnp.logical_and(still_counting, unconverged)[0]

    def body_f(args):
        V1, _, j, _ = args
        X = jnp.dot(P, V1, precision=lax.Precision.HIGHEST)
        V1, V2, error = body_f_after_matmul(X)
        return V1, V2, j + 1, error

    V1, V2, error = body_f_after_matmul(X)
    one = jnp.ones(1, dtype=jnp.int32)
    V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
    return V1, V2
def weighted_percentile(x, w, ps, assume_sorted=False):
    """Compute the weighted percentile(s) of a single vector."""
    x = x.reshape([-1])
    w = w.reshape([-1])
    if not assume_sorted:
        sortidx = jnp.argsort(jax.lax.stop_gradient(x))
        x, w = x[sortidx], w[sortidx]
    acc_w = jnp.cumsum(w)
    return jnp.interp(jnp.array(ps) * (acc_w[-1] / 100), acc_w, x)
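# Worked example for weighted_percentile (sketch); `_example_weighted_percentile`
# is a hypothetical name. With equal weights, the cumulative weight reaches
# 25/50/75% of the total exactly at the 1st/2nd/3rd sorted values under this
# interpolation convention.
def _example_weighted_percentile():  # hypothetical demo helper
    import jax.numpy as jnp
    x = jnp.array([4., 1., 3., 2.])
    w = jnp.ones(4)
    out = weighted_percentile(x, w, [25., 50., 75.])
    assert jnp.allclose(out, jnp.array([1., 2., 3.]))
    return out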
def permuted_inverse_cdf_coupling(logits_1, logits_2, permutation_seed=1):
    """Constructs the matrix for an inverse CDF coupling under a permutation."""
    dim, = logits_1.shape
    perm = jnp.argsort(
        jax.random.uniform(jax.random.PRNGKey(permutation_seed), shape=[dim]))
    invperm = jnp.argsort(perm)

    p1 = jnp.exp(logits_1)[perm]
    p2 = jnp.exp(logits_2)[perm]

    p1_bins = jnp.concatenate([jnp.array([0.]), jnp.cumsum(p1)])
    p2_bins = jnp.concatenate([jnp.array([0.]), jnp.cumsum(p2)])

    # Value in bin (i, j): overlap between bin ranges
    def get(i, j):
        left = jnp.maximum(p1_bins[i], p2_bins[j])
        right = jnp.minimum(p1_bins[i + 1], p2_bins[j + 1])
        return jnp.where(left < right, right - left, 0.0)

    return jax.vmap(lambda i: jax.vmap(lambda j: get(i, j))(invperm))(invperm)
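# Sanity-check sketch for permuted_inverse_cdf_coupling; the helper name is
# hypothetical. A valid coupling's marginals recover the two input distributions,
# and the invperm indexing returns the matrix in the original (unpermuted) order.
def _example_inverse_cdf_coupling():  # hypothetical demo helper
    import jax.numpy as jnp
    probs_1 = jnp.array([0.2, 0.5, 0.3])
    probs_2 = jnp.array([0.6, 0.1, 0.3])
    coupling = permuted_inverse_cdf_coupling(jnp.log(probs_1), jnp.log(probs_2))
    assert jnp.allclose(coupling.sum(axis=1), probs_1, atol=1e-6)
    assert jnp.allclose(coupling.sum(axis=0), probs_2, atol=1e-6)
    return coupling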
def knn(self, instance):
    distances = [self.euclidean_distance(instance, x) for x in self.locations]
    # the above could be done in parallel
    # sort all distances, then keep the indices of the k nearest points
    k_neighbors = jnp.argsort(jnp.array(distances))[:self.k]
    vote = Counter(self.labels[k_neighbors])
    return vote.most_common()[0][0]
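# Hedged usage sketch for the knn method above. The real class that provides
# locations/labels/k/euclidean_distance is not shown here, so _TinyKNN below is a
# hypothetical stand-in used only to exercise the method.
class _TinyKNN:  # hypothetical demo container
    def __init__(self, locations, labels, k):
        self.locations = locations
        self.labels = labels  # expected to support fancy indexing, e.g. a numpy array
        self.k = k

    def euclidean_distance(self, a, b):
        import jax.numpy as jnp
        return jnp.sqrt(jnp.sum((a - b) ** 2))

# e.g. model = _TinyKNN(locations=jnp.array([[0., 0.], [1., 1.], [5., 5.]]),
#                       labels=np.array([0, 0, 1]), k=2)
#      knn(model, jnp.array([0.2, 0.1]))  # -> 0 (both nearest neighbours have label 0)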