def test_autodiff(self, shape, dtype, k, is_max_k):
    """Check 2nd-order fwd/rev gradients of the values from approx top-k.

    Builds an array of the given ``shape``/``dtype`` whose entries are a
    random permutation of ``0 .. prod(shape) - 1``. Then verifies, via
    ``jtu.check_grads``, the derivatives of the first output (the selected
    values) of ``lax.approx_max_k`` (when ``is_max_k``) or
    ``lax.approx_min_k`` (otherwise).
    """
    # A shuffled arange has no repeated entries (assuming ``dtype`` can
    # represent every integer in the range exactly), so the selected top-k
    # set is not ambiguous.
    ordered = np.arange(prod(shape), dtype=dtype)
    shuffled = self.rng().permutation(ordered).reshape(shape)

    topk_op = lax.approx_max_k if is_max_k else lax.approx_min_k

    def top_values(vs):
        # Index 0 of the (values, indices) pair: only the values are
        # differentiable.
        return topk_op(vs, k=k)[0]

    # order=2, both forward- and reverse-mode, finite-difference step 1e-2.
    jtu.check_grads(top_values, (shuffled,), 2, ["fwd", "rev"], eps=1e-2)
def test_approx_max_k(self, qy_shape, db_shape, dtype, k, recall):
    """Compare ``lax.approx_max_k`` against exact ``lax.top_k``.

    Random query/database arrays are drawn with ``jtu.rand_default`` and
    combined with ``lax.dot`` into a score matrix. The approximate top-k
    indices must have the same shape as the exact ones, with ``k`` entries
    in the last dimension. Their recall relative to the exact indices, as
    measured by ``compute_recall``, must meet ``recall``.
    """
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    scores = lax.dot(qy, db)

    _, gt_args = lax.top_k(scores, k)
    _, ann_args = lax.approx_max_k(scores, k, recall_target=recall)

    # Check the full result shape rather than only ``len(ann_args[0])``.
    # Every row must hold k indices, and there must be one row per row
    # returned by the exact top_k.
    self.assertEqual(k, ann_args.shape[-1])
    self.assertEqual(gt_args.shape, ann_args.shape)

    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    # Reaching the recall target exactly satisfies it, so compare with >=
    # instead of a strict >.
    self.assertGreaterEqual(ann_recall, recall)
def approx_max_k(qy, db):
    """Return ``lax.approx_max_k`` of the scores ``qy @ db.T``.

    ``k`` is not a parameter. It is a free name resolved from the enclosing
    scope, presumably the surrounding test's ``k`` argument (confirm at the
    definition site).
    """
    # ``.T`` reverses all axes, exactly like ``.transpose()`` with no args.
    return lax.approx_max_k(qy @ db.T, k)