コード例 #1
0
ファイル: examples_test.py プロジェクト: xueeinstein/jax
 def testKernelRegressionTrainAndPredict(self):
     """Kernel least squares with a linear kernel should fit noiseless linear data.

     Targets are an exact linear function of random inputs, so the predictor
     returned by ``kernel_lsq.train`` is expected to reproduce the training
     targets within ``atol=rtol=1e-3``.
     """
     num_points, dim = 100, 20
     # Draw order (weights first, then inputs) matches the shared self.rng stream.
     true_weights = self.rng.normal(size=dim)
     inputs = self.rng.normal(size=(num_points, dim))
     targets = jnp.dot(inputs, true_weights)

     def linear_kernel(x, y):
         # Requests higher-than-default matmul precision; presumably so
         # accelerator matmuls stay accurate enough for the 1e-3 tolerance
         # (TODO confirm).
         return jnp.dot(x, y, precision=lax.Precision.HIGH)

     predictor = kernel_lsq.train(linear_kernel, inputs, targets)
     np.testing.assert_allclose(predictor(inputs), targets, atol=1e-3, rtol=1e-3)
コード例 #2
0
 def testKernelRegressionTrainAndPredict(self):
     """Kernel least squares with a linear kernel should fit noiseless linear data.

     Targets are an exact linear function of seeded random inputs, so the
     predictor returned by ``kernel_lsq.train`` is expected to reproduce the
     training targets within the stated tolerance.
     """
     n, d = 100, 20
     # Fixed seed keeps the inputs, and therefore the test, deterministic.
     rng = onp.random.RandomState(0)
     truth = rng.randn(d)
     xs = rng.randn(n, d)
     ys = np.dot(xs, truth)
     kernel = lambda x, y: np.dot(x, y)
     predict = kernel_lsq.train(kernel, xs, ys)
     # A bare ``assert`` is removed under ``python -O`` (making the test
     # vacuous) and gives no diagnostics on failure. assert_allclose always
     # runs and reports the mismatching elements. rtol=1e-5 reproduces the
     # default relative tolerance of ``allclose`` that was used before.
     onp.testing.assert_allclose(predict(xs), ys, atol=1e-3, rtol=1e-5)
コード例 #3
0
ファイル: examples_test.py プロジェクト: zhangyixun3433/jax
 def testKernelRegressionTrainAndPredict(self):
     """Kernel least squares with a linear kernel should fit noiseless linear data.

     Currently skipped unconditionally; the body below documents the intended
     check for when the test is re-enabled.
     """
     # TODO(frostig): reenable this test.
     self.skipTest("Test is broken")
     n, d = 100, 20
     # Fixed seed keeps the inputs, and therefore the test, deterministic.
     rng = onp.random.RandomState(0)
     truth = rng.randn(d)
     xs = rng.randn(n, d)
     ys = np.dot(xs, truth)
     kernel = lambda x, y: np.dot(x, y)
     predict = kernel_lsq.train(kernel, xs, ys)
     # A bare ``assert`` is removed under ``python -O`` and gives no failure
     # diagnostics; assert_allclose always runs and reports mismatches.
     # rtol=1e-5 reproduces the default relative tolerance of ``allclose``.
     onp.testing.assert_allclose(predict(xs), ys, atol=1e-3, rtol=1e-5)
コード例 #4
0
ファイル: examples_test.py プロジェクト: gnecula/jax
 def testKernelRegressionTrainAndPredict(self):
     """Kernel least squares with a linear kernel should fit noiseless linear data.

     Inputs come from a seeded NumPy generator and targets are their exact
     linear image, so predictions on the training inputs should match the
     targets within ``atol=rtol=1e-3`` (dtypes are not compared).
     """
     num_points, dim = 100, 20
     seeded = np.random.RandomState(0)
     # Draw order (weights first, then inputs) is kept fixed for reproducibility.
     weights = seeded.randn(dim)
     inputs = seeded.randn(num_points, dim)
     targets = jnp.dot(inputs, weights)

     def linear_kernel(x, y):
         return jnp.dot(x, y)

     predictor = kernel_lsq.train(linear_kernel, inputs, targets)
     self.assertAllClose(
         predictor(inputs), targets, check_dtypes=False, atol=1e-3, rtol=1e-3)