Ejemplo n.º 1
0
def top_k(x, k):
  """Select the top k slices from the last dimension.

  Sorts ``x`` along its last axis (carrying the original positions along) and
  keeps the ``k`` largest entries.

  Args:
    x: array whose last axis is searched.
    k: number of entries to keep; a static Python int with
      ``0 <= k <= x.shape[-1]``.

  Returns:
    ``(values, indices)``, each of shape ``x.shape[:-1] + (k,)``. Entries are in
    ascending order, so the largest value is last. ``indices`` are positions
    along the last axis of ``x``.

  Raises:
    ValueError: if ``k`` is negative or larger than ``x.shape[-1]``.
  """
  n = x.shape[-1]
  if not 0 <= k <= n:
    raise ValueError(f"k must be in [0, {n}], got {k}.")
  bcast_idxs = jnp.broadcast_to(np.arange(n), x.shape)
  sorted_vals, sorted_idxs = lax.sort_key_val(x, bcast_idxs)
  # Use an explicit non-negative start index. Slicing from ``-k`` breaks for
  # k == 0 (``-0 == 0`` selects the whole axis instead of nothing).
  start = n - k
  topk_vals = lax.slice_in_dim(sorted_vals, start, n, axis=-1)
  topk_idxs = lax.slice_in_dim(sorted_idxs, start, n, axis=-1)
  return topk_vals, topk_idxs
Ejemplo n.º 2
0
 def top_k_classes(x, k):
     """Return the positions of the k largest entries along the last axis.

     Args:
       x: array of scores; the last axis is searched.
       k: number of positions to keep; a static Python int with
         ``0 <= k <= x.shape[-1]``.

     Returns:
       Integer array of shape ``x.shape[:-1] + (k,)`` holding positions along
       the last axis of ``x``, ordered by ascending score (largest last).

     Raises:
       ValueError: if ``k`` is negative or larger than ``x.shape[-1]``.
     """
     n = x.shape[-1]
     if not 0 <= k <= n:
         raise ValueError(f"k must be in [0, {n}], got {k}.")
     bcast_idxs = jnp.broadcast_to(np.arange(n), x.shape)
     _, sorted_idxs = lax.sort_key_val(x, bcast_idxs)
     # Explicit start index: slicing from ``-k`` would select the whole axis
     # when k == 0 because ``-0 == 0``.
     return lax.slice_in_dim(sorted_idxs, n - k, n, axis=-1)
Ejemplo n.º 3
0
def _roots_with_zeros(p, num_leading_zeros):
    """Compute polynomial roots, masking the slots produced by leading zeros.

    The first ``num_leading_zeros`` coefficients of ``p`` are rotated to the
    end before ``_roots_no_zeros`` is called. Exact-zero roots are then moved
    to the end, and the last ``num_leading_zeros`` slots are replaced by
    complex NaN.

    NOTE(review): presumably ``p`` lists coefficients highest degree first
    (``np.roots`` convention) -- confirm against callers.
    """
    # If every coefficient is a leading zero, substitute ones so the root
    # solver is not handed an all-zero input (the original note cites LAPACK
    # errors).
    is_all_zero = len(p) == num_leading_zeros
    safe_p = _where(is_all_zero, 1.0, p)
    rotated = roll(safe_p, -num_leading_zeros)
    raw_roots = _roots_no_zeros(rotated)
    # Boolean sort keys: nonzero roots (False) come before zero roots (True).
    zero_flags = raw_roots == 0
    _, ordered_roots = lax.sort_key_val(zero_flags, raw_roots)
    total = ordered_roots.size
    keep_mask = arange(total) < total - num_leading_zeros
    return _where(keep_mask, ordered_roots, complex(np.nan, np.nan))
Ejemplo n.º 4
0
  def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable):
    """Check 2nd-order fwd/rev gradients of lax.sort_key_val.

    The ordering of values that share a key is not guaranteed, so the keys are
    a random permutation of ``0 .. prod(shape) - 1`` (unique across the whole
    array). NOTE(review): uniqueness assumes ``key_dtype`` can represent every
    value in that range exactly.
    """
    value_rng = jtu.rand_default(self.rng())
    key_pool = np.arange(prod(shape), dtype=key_dtype)
    unique_keys = self.rng().permutation(key_pool).reshape(shape)
    vals = value_rng(shape, val_dtype)

    def sort_pair(k, v):
      return lax.sort_key_val(k, v, axis, is_stable)

    check_grads(sort_pair, (unique_keys, vals), 2, ["fwd", "rev"],
                1e-2, 1e-2, 1e-2)
Ejemplo n.º 5
0
def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
    """Sort the concatenation of ``ar1`` and ``ar2`` and flag adjacent repeats.

    Jit-compatible helper for intersect1d.

    Returns:
      ``(aux, mask)``, where ``aux`` is the sorted concatenation and
      ``mask[i]`` is True when ``aux[i + 1] == aux[i]``. When
      ``return_indices`` is true, a third element is appended: the positions
      in the concatenated input that each element of ``aux`` came from.
    """
    combined = concatenate((ar1, ar2))
    if not return_indices:
        ordered = sort(combined)
        return ordered, ordered[1:] == ordered[:-1]

    # Carry each element's original position through the sort.
    positions = lax.broadcasted_iota(np.int64, np.shape(combined), dimension=0)
    ordered, origin = lax.sort_key_val(combined, positions)
    return ordered, ordered[1:] == ordered[:-1], origin
Ejemplo n.º 6
0
    def test_pmap(self, qy_shape, db_shape, dtype, k, recall):
        """Check approx_min_k recall when the database is sharded with pmap.

        Each device scores the queries against its shard of ``db`` and returns
        unaggregated approximate candidates with globally offset indices. The
        host merges all candidates, keeps the k smallest scores, and compares
        the selected indices with exact min-k results from ``lax.top_k``.
        """
        num_devices = jax.device_count()
        rng = jtu.rand_default(self.rng())
        qy = rng(qy_shape, dtype)
        db = rng(db_shape, dtype)
        db_size = db.shape[0]
        gt_scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
        _, gt_args = lax.top_k(-gt_scores, k)  # negate the score to get min-k
        # NOTE(review): the reshape below requires db_size to be divisible by
        # num_devices.
        db_per_device = db_size // num_devices
        # Use the database's own feature width instead of a hard-coded 128, so
        # the test works for any db_shape (identical for 128-wide inputs).
        sharded_db = db.reshape(num_devices, db_per_device, db.shape[1])
        db_offsets = np.arange(num_devices, dtype=np.int32) * db_per_device

        def parallel_topk(qy, db, db_offset):
            scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
            ann_vals, ann_args = lax.approx_min_k(
                scores,
                k=k,
                reduction_dimension=1,
                recall_target=recall,
                reduction_input_size_override=db_size,
                aggregate_to_topk=False)
            # Shift shard-local indices into global database indices.
            return (ann_vals, ann_args + db_offset)

        # shape = qy_size, num_devices, approx_dp
        ann_vals, ann_args = jax.pmap(parallel_topk,
                                      in_axes=(None, 0, 0),
                                      out_axes=(1, 1))(qy, sharded_db,
                                                       db_offsets)
        # collapse num_devices and approx_dp
        ann_vals = lax.collapse(ann_vals, 1, 3)
        ann_args = lax.collapse(ann_args, 1, 3)
        # Sort all candidates by score (ascending, since this is min-k) and
        # keep the indices of the k best.
        _, ann_args = lax.sort_key_val(ann_vals, ann_args, dimension=1)
        ann_args = lax.slice_in_dim(ann_args,
                                    start_index=0,
                                    limit_index=k,
                                    axis=1)
        ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
        self.assertGreater(ann_recall, recall)
Ejemplo n.º 7
0
def top_k(x, k):
  """Select the top k slices from the last dimension.

  Sorts ``x`` along its last axis (carrying the original positions along) and
  keeps the ``k`` largest entries.

  Args:
    x: array whose last axis is searched.
    k: number of entries to keep; a static Python int with
      ``0 <= k <= x.shape[-1]``.

  Returns:
    ``(values, indices)``, each of shape ``x.shape[:-1] + (k,)``. Entries are in
    ascending order, so the largest value is last. ``indices`` are positions
    along the last axis of ``x``.

  Raises:
    ValueError: if ``k`` is negative or larger than ``x.shape[-1]``.
  """
  n = x.shape[-1]
  if not 0 <= k <= n:
    raise ValueError(f"k must be in [0, {n}], got {k}.")
  bcast_idxs = jnp.broadcast_to(np.arange(n), x.shape)
  sorted_vals, sorted_idxs = lax.sort_key_val(x, bcast_idxs)
  # Static lax slice with an explicit start. ``[..., -k:]`` would return the
  # whole axis for k == 0 because ``-0 == 0``.
  start = n - k
  topk_vals = lax.slice_in_dim(sorted_vals, start, n, axis=-1)
  topk_idxs = lax.slice_in_dim(sorted_idxs, start, n, axis=-1)
  return topk_vals, topk_idxs
Ejemplo n.º 8
0
def _reduce_any_error(errs, codes):
    errs_, codes_ = lax.sort_key_val(errs, codes, dimension=0)
    return errs_[-1], codes_[-1]