def top_k(x, k):
    """Select the top k slices from the last dimension.

    Args:
      x: array to select from; ranking happens along its last axis.
      k: static, non-negative number of entries to keep.

    Returns:
      A ``(values, indices)`` pair holding the last ``k`` entries of the
      sorted values and the positions along the last axis they came from.
      The entries are the tail of the sort, so the largest value comes last.
      ``k == 0`` yields arrays whose last dimension has size 0.
    """
    axis_size = x.shape[-1]
    bcast_idxs = jnp.broadcast_to(np.arange(axis_size), x.shape)
    sorted_vals, sorted_idxs = lax.sort_key_val(x, bcast_idxs)
    # ``-k`` only means "count from the end" when k is non-zero: ``-0`` is 0
    # and would select the whole axis. Start at the end for k == 0 so the
    # result is empty instead.
    start = -k if k else axis_size
    topk_vals = lax.slice_in_dim(sorted_vals, start, axis_size, axis=-1)
    topk_idxs = lax.slice_in_dim(sorted_idxs, start, axis_size, axis=-1)
    return topk_vals, topk_idxs
def top_k_classes(x, k):
    """Return the indices of the top k entries along the last dimension.

    Args:
      x: array to rank along its last axis.
      k: static, non-negative number of indices to keep.

    Returns:
      The positions (along the last axis of ``x``) of the last ``k`` entries
      of the sorted values, so the index of the largest value comes last.
      ``k == 0`` yields an array whose last dimension has size 0.
    """
    axis_size = x.shape[-1]
    bcast_idxs = jnp.broadcast_to(np.arange(axis_size), x.shape)
    # Only the permuted indices are needed; the sorted values are discarded.
    _, sorted_idxs = lax.sort_key_val(x, bcast_idxs)
    # ``-0`` is 0 and would select every index, so start at the end of the
    # axis when k == 0 to get an empty result.
    start = -k if k else axis_size
    return lax.slice_in_dim(sorted_idxs, start, axis_size, axis=-1)
def _roots_with_zeros(p, num_leading_zeros):
    """Compute roots of ``p`` when it has ``num_leading_zeros`` leading zeros.

    The leading zeros are rolled to the end before the roots are computed
    with ``_roots_no_zeros``. Roots equal to zero are then sorted to the
    end, and the final ``num_leading_zeros`` entries are replaced by
    ``nan + nan*j``.

    Args:
      p: coefficient array.
      num_leading_zeros: count of leading zero coefficients in ``p``.

    Returns:
      Array of roots with the trailing ``num_leading_zeros`` entries set to
      complex NaN.
    """
    # When every coefficient is zero, substitute 1.0 to avoid lapack errors.
    safe_p = _where(len(p) == num_leading_zeros, 1.0, p)

    # Move the leading zeros to the tail, then solve.
    shifted = roll(safe_p, -num_leading_zeros)
    found = _roots_no_zeros(shifted)

    # Sorting on the "is zero" flag pushes the zero roots to the end.
    zero_flag = found == 0
    _, ordered = lax.sort_key_val(zero_flag, found)

    # Entries past the valid count belong to the leading zeros: mark as NaN.
    valid_count = ordered.size - num_leading_zeros
    nan_root = complex(np.nan, np.nan)
    return _where(arange(ordered.size) < valid_count, ordered, nan_root)
def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable):
    """Check first- and second-order gradients of ``lax.sort_key_val``.

    Gradients are checked in forward and reverse mode for a key/value pair
    sorted along ``axis`` with the given ``is_stable`` setting.
    """
    rng = jtu.rand_default(self.rng())

    # The ordering of values that share a key is not guaranteed, so this
    # check only holds when tied keys also have tied values. Generating keys
    # that are unique across the whole key array sidesteps the issue.
    def args_maker():
        distinct_keys = np.arange(prod(shape), dtype=key_dtype)
        shuffled_keys = self.rng().permutation(distinct_keys)
        return shuffled_keys.reshape(shape), rng(shape, val_dtype)

    def sort_pairs(keys, values):
        return lax.sort_key_val(keys, values, axis, is_stable)

    keys, values = args_maker()
    check_grads(sort_pairs, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)
def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
    """Jit-able helper for intersect1d.

    Concatenates ``ar1`` and ``ar2``, sorts the result, and flags every
    position whose sorted value equals its right-hand neighbour.

    Args:
      ar1: first input array.
      ar2: second input array.
      return_indices: when True, also return, for each sorted element, its
        position in the concatenated array.

    Returns:
      ``(sorted_values, mask)``, or ``(sorted_values, mask, indices)`` when
      ``return_indices`` is True. ``mask`` is one element shorter than
      ``sorted_values``.
    """
    merged = concatenate((ar1, ar2))

    if not return_indices:
        ordered = sort(merged)
        return ordered, ordered[1:] == ordered[:-1]

    # Carry each element's original position through the sort.
    positions = lax.broadcasted_iota(np.int64, np.shape(merged), dimension=0)
    ordered, origin = lax.sort_key_val(merged, positions)
    return ordered, ordered[1:] == ordered[:-1], origin
def test_pmap(self, qy_shape, db_shape, dtype, k, recall):
    """Check recall of a device-sharded ``lax.approx_min_k`` search.

    The database is split evenly across all available devices, each device
    runs an approximate min-k over its shard, and the per-device candidates
    are merged by sorting. The merged result is compared with an exact
    ``lax.top_k`` over the full score matrix and must exceed ``recall``.

    NOTE(review): the reshape assumes the database row count is divisible
    by ``jax.device_count()`` -- confirm against the test parameters.
    """
    num_devices = jax.device_count()
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    db_size = db.shape[0]
    gt_scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
    _, gt_args = lax.top_k(-gt_scores, k)  # negate the score to get min-k
    db_per_device = db_size // num_devices
    # Take the feature dimension from the data rather than hard-coding 128,
    # so the test is valid for any ``db_shape``.
    sharded_db = db.reshape(num_devices, db_per_device, db.shape[1])
    # Offset that maps a shard-local row index back to a global row index.
    db_offsets = np.arange(num_devices, dtype=np.int32) * db_per_device

    def parallel_topk(qy, db, db_offset):
        scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
        ann_vals, ann_args = lax.approx_min_k(
            scores,
            k=k,
            reduction_dimension=1,
            recall_target=recall,
            reduction_input_size_override=db_size,
            aggregate_to_topk=False)
        return (ann_vals, ann_args + db_offset)

    # shape = qy_size, num_devices, approx_dp
    ann_vals, ann_args = jax.pmap(
        parallel_topk,
        in_axes=(None, 0, 0),
        out_axes=(1, 1))(qy, sharded_db, db_offsets)
    # collapse num_devices and approx_dp
    ann_vals = lax.collapse(ann_vals, 1, 3)
    ann_args = lax.collapse(ann_args, 1, 3)
    # Merge the per-device candidates: sort by score and keep the first k.
    ann_vals, ann_args = lax.sort_key_val(ann_vals, ann_args, dimension=1)
    ann_args = lax.slice_in_dim(ann_args, start_index=0, limit_index=k, axis=1)
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    self.assertGreater(ann_recall, recall)
def top_k(x, k):
    """Select the top k slices from the last dimension.

    Args:
      x: array to select from; ranking happens along its last axis.
      k: static, non-negative number of entries to keep.

    Returns:
      A ``(values, indices)`` pair holding the last ``k`` entries of the
      sorted values and the positions along the last axis they came from.
      The entries are the tail of the sort, so the largest value comes last.
      ``k == 0`` yields arrays whose last dimension has size 0.
    """
    axis_size = x.shape[-1]
    bcast_idxs = jnp.broadcast_to(np.arange(axis_size), x.shape)
    sorted_vals, sorted_idxs = lax.sort_key_val(x, bcast_idxs)
    # ``[..., -k:]`` degenerates to ``[..., 0:]`` (everything) when k == 0,
    # so start at the end of the axis in that case to get an empty result.
    start = -k if k else axis_size
    # TODO(levskaya): use lax.slice here instead to benefit from XLA optimization
    return sorted_vals[..., start:], sorted_idxs[..., start:]
def _reduce_any_error(errs, codes):
    """Reduce ``errs``/``codes`` along axis 0 to the pair that sorts last.

    ``codes`` is reordered by sorting ``errs`` along dimension 0, and the
    final entry of each sorted array is returned.

    NOTE(review): presumably ``errs`` holds boolean flags, so a set flag
    sorts last and its code is the one reported -- confirm against callers.

    Returns:
      A ``(err, code)`` tuple taken from the end of the sorted arrays.
    """
    ordered = lax.sort_key_val(errs, codes, dimension=0)
    return tuple(arr[-1] for arr in ordered)