コード例 #1
0
 def test_autodiff(self, shape, dtype, k, is_max_k):
     """Gradient-check the values output of approximate top-k / bottom-k.

     Builds a shuffled ``arange`` of ``prod(shape)`` elements in ``dtype``,
     reshaped to ``shape``. It then runs ``jtu.check_grads`` up to order 2 in
     both forward and reverse mode, with ``eps=1e-2``.

     Args:
       shape: shape of the input array.
       dtype: dtype of the input array.
       k: number of elements selected by the approx op.
       is_max_k: if true test ``lax.approx_max_k``, otherwise
         ``lax.approx_min_k``.
     """
     selector = lax.approx_max_k if is_max_k else lax.approx_min_k
     n_elems = prod(shape)
     # A permutation of arange gives mostly distinct entries; presumably this
     # is meant to avoid ties in the selection -- TODO confirm.
     shuffled = self.rng().permutation(np.arange(n_elems, dtype=dtype))
     vals = shuffled.reshape(shape)

     def fn(vs):
         # Differentiate only element [0] of the op's output; the remaining
         # outputs are ignored here.
         return selector(vs, k=k)[0]

     jtu.check_grads(fn, (vals,), 2, ["fwd", "rev"], eps=1e-2)
コード例 #2
0
 def test_approx_max_k(self, qy_shape, db_shape, dtype, k, recall):
     """Compare ``lax.approx_max_k`` against exact ``lax.top_k`` on random scores.

     Random ``qy`` and ``db`` arrays are drawn with ``jtu.rand_default`` and
     combined via ``lax.dot``. The test then asserts two things:
     * the approximate indices have the same shape as the exact top-k
       indices, with ``k`` entries along the last axis;
     * the recall of the approximate indices (per ``compute_recall``) is at
       least the requested ``recall`` target.

     Args:
       qy_shape: shape of the query operand passed to ``lax.dot``.
       db_shape: shape of the database operand passed to ``lax.dot``.
       dtype: dtype of both operands.
       k: number of results requested.
       recall: ``recall_target`` passed to ``lax.approx_max_k``, and the
         minimum recall the test accepts.
     """
     rng = jtu.rand_default(self.rng())
     qy = rng(qy_shape, dtype)
     db = rng(db_shape, dtype)
     scores = lax.dot(qy, db)
     _, gt_args = lax.top_k(scores, k)
     _, ann_args = lax.approx_max_k(scores, k, recall_target=recall)
     # Validate the whole index array, not only the length of its first row.
     self.assertEqual(gt_args.shape, ann_args.shape)
     self.assertEqual(k, ann_args.shape[-1])
     ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
     # Reaching the target exactly satisfies it. A strict ">" would reject
     # that case, and could never pass for recall == 1.0.
     self.assertGreaterEqual(ann_recall, recall)
コード例 #3
0
ファイル: ann_test.py プロジェクト: frederikwilde/jax
 def approx_max_k(qy, db):
     """Score ``qy`` against ``db`` and keep the approximate top ``k``.

     The scores are ``qy @ db.transpose()``; presumably ``db`` holds one
     database vector per row (TODO confirm against callers). ``k`` is not a
     parameter; it is resolved from the enclosing scope. Returns the
     ``lax.approx_max_k`` result unchanged.
     """
     return lax.approx_max_k(qy @ db.transpose(), k)