def fit(x_train):
    """Set up a batched JAX K-NN estimator for the training set ``x_train``.

    ``x_train`` is converted with ``jax_tensor``, and that conversion is timed
    with ``timer(use_torch=False)``.

    Args:
        x_train: training points. The query below unpacks
            ``x_train.shape`` as ``(Ntrain, D)``, so a 2-D array is expected.

    Returns:
        ``(f, elapsed)``: ``f(x_test)`` runs the K-NN query against
        ``x_train``, and ``elapsed`` is the measured setup time.
    """
    # Setup the K-NN estimator:
    start = timer(use_torch=False)
    x_train = jax_tensor(x_train)
    elapsed = timer(use_torch=False) - start

    def f(x_test):
        """Query the ``K`` nearest training points of every row of ``x_test``.

        Test points are processed in batches, so that the estimated working
        memory per batch stays near ``av_mem`` bytes.

        Args:
            x_test: query points, indexed below as ``(Ntest, D)``.

        Returns:
            ``(indices, elapsed)``:

            - ``indices`` is a NumPy integer array of shape ``(Ntest, K)``
              filled batch by batch from ``knn_jax_fun``.
            - ``elapsed`` is the time spent in the query loop only, excluding
              the ``jax_tensor`` conversion.

            An empty ``x_test`` yields a ``(0, K)`` array.
        """
        x_test = jax_tensor(x_test)

        # Estimate the largest reasonable batch size
        av_mem = int(5e8)  # 500 Mb
        Ntrain, D = x_train.shape
        Ntest = x_test.shape[0]
        # Each test point is budgeted 4 * D * Ntrain bytes (presumably one
        # 4-byte float per train/test/coordinate triple; confirm against
        # knn_jax_fun).
        # Cap the batch size at Ntest *before* clamping it to at least 1.
        # The reverse order produced a batch size of 0 for an empty test set,
        # and the loop-count computation then divided by zero.
        Ntest_loop = max(1, min(av_mem // (4 * D * Ntrain), Ntest))
        indices = np.zeros((Ntest, K), dtype=int)

        start = timer(use_torch=False)
        # Actual K-NN query, one batch of test points at a time. Storing the
        # result into the NumPy array converts it on the host, so the timing
        # includes the computation (see JAX's asynchronous-dispatch notes).
        for lo in range(0, Ntest, Ntest_loop):
            hi = lo + Ntest_loop
            indices[lo:hi, :] = knn_jax_fun(x_train, x_test[lo:hi, :], K, metric)
        elapsed = timer(use_torch=False) - start
        return indices, elapsed

    return f, elapsed
def f(x_test):
    """Run a single, unbatched K-NN query of ``x_test`` against ``x_train``.

    ``x_train``, ``K`` and ``metric`` are not parameters. They are resolved
    from the surrounding scope, which is not visible here.

    Returns:
        ``(indices, elapsed)``:

        - ``indices`` is the result of ``knn_jax_fun`` converted to a NumPy
          array.
        - ``elapsed`` covers the query plus that conversion, but not the
          initial ``jax_tensor`` conversion of ``x_test``.
    """
    queries = jax_tensor(x_test)

    # The NumPy conversion is deliberately inside the timed region. It is
    # presumably what forces JAX's asynchronous work to finish, so the
    # measurement covers the real computation.
    t0 = timer(use_torch=False)
    neighbors = np.array(knn_jax_fun(x_train, queries, K, metric))
    duration = timer(use_torch=False) - t0

    return neighbors, duration
def fit(x_train):
    """Set up an unbatched JAX K-NN estimator for ``x_train``.

    ``x_train`` is converted with ``jax_tensor``, and that conversion is timed
    with ``timer(use_torch=False)``.

    Returns:
        ``(f, elapsed)``:

        - ``f(x_test)`` queries the whole test set in one ``knn_jax_fun``
          call and returns ``(indices, elapsed)``.
        - ``elapsed`` is the setup time.

        ``K`` and ``metric`` come from the surrounding scope.
    """
    t_setup = timer(use_torch=False)
    train = jax_tensor(x_train)
    setup_time = timer(use_torch=False) - t_setup

    def f(x_test):
        """Query all of ``x_test`` at once.

        Returns:
            ``(indices as a NumPy array, query time)``.
        """
        queries = jax_tensor(x_test)

        # Timing starts after input conversion. It ends after np.array, which
        # materialises the JAX result on the host.
        t_query = timer(use_torch=False)
        neighbors = np.array(knn_jax_fun(train, queries, K, metric))
        return neighbors, timer(use_torch=False) - t_query

    return f, setup_time