Esempio n. 1
0
    def test_numpy_backend_delegation(self):
        """Check that `math.numpy` forwards lookups to the configured backend.

        Before any override the test expects JAX's numpy (`jnp`); once the gin
        binding `backend.name` is set to 'numpy' it expects plain numpy (`onp`).
        """
        # Phase 1: nothing overridden yet, so the 'np' entry should be jnp.
        self.assertEqual(jnp, math.backend()['np'])
        delegate = math.numpy

        # Function and constant lookups on the delegate must land on jnp.
        self.assertTrue(delegate.isinf(delegate.inf))
        self.assertEqual(jnp.isinf, delegate.isinf)
        self.assertEqual(jnp.inf, delegate.inf)

        # Phase 2: rebind the backend name through gin and repeat the checks
        # against plain numpy.
        self.override_gin("backend.name = 'numpy'")

        self.assertEqual(onp, math.backend()['np'])
        delegate = math.numpy

        # The same lookups must now land on onp instead.
        self.assertTrue(delegate.isinf(delegate.inf))
        self.assertEqual(onp.isinf, delegate.isinf)
        self.assertEqual(onp.inf, delegate.inf)
Esempio n. 2
0
    def test_backend_imports_correctly(self):
        """Check which numpy module `math.backend()` exposes under 'np'.

        Without an override the entry must be `jnp` and not `onp`; after the
        gin binding `backend.name` is set to 'numpy' the expectation flips.
        """
        # No override applied: JAX's numpy is expected.
        np_before = math.backend()['np']
        self.assertEqual(jnp, np_before)
        self.assertNotEqual(onp, np_before)

        self.override_gin("backend.name = 'numpy'")

        # Override applied: plain numpy is expected.
        np_after = math.backend()['np']
        self.assertNotEqual(jnp, np_after)
        self.assertEqual(onp, np_after)