def test_numpy_backend_delegation(self):
    """Check that the backend and `math.numpy` follow the gin-selected backend.

    Before the override in this test, both `math.backend()['np']` and the
    functions/constants exposed by `math.numpy` are expected to come from
    JAX's numpy (`jnp`). After overriding `backend.name` to 'numpy', they are
    expected to come from plain NumPy (`onp`).
    """
    # Assert that we are getting JAX's numpy backend.
    self._assert_numpy_delegates_to(jnp)

    # Assert that we will now get the pure numpy backend.
    self.override_gin("backend.name = 'numpy'")
    self._assert_numpy_delegates_to(onp)

def _assert_numpy_delegates_to(self, expected_np):
    """Assert that the current backend and `math.numpy` resolve to `expected_np`.

    Args:
      expected_np: the numpy-like module (e.g. `jnp` or `onp`) that the
        currently configured backend is expected to expose.
    """
    backend = math.backend()
    # Re-fetch `math.numpy` on every call instead of caching it, as the
    # original test did after changing the gin configuration.
    numpy = math.numpy
    # Modules compare by identity, so assertIs states the intent precisely.
    self.assertIs(expected_np, backend['np'])
    # Assert that `numpy` calls the appropriate gin configured functions and
    # properties.
    self.assertTrue(numpy.isinf(numpy.inf))
    self.assertEqual(expected_np.isinf, numpy.isinf)
    self.assertEqual(expected_np.inf, numpy.inf)
def test_backend_imports_correctly(self):
    """Verify which numpy module `math.backend()` exposes under 'np'.

    Initially the exposed module must be JAX's numpy and not plain NumPy.
    Once gin sets `backend.name` to 'numpy', the relationship must flip.
    """
    jax_phase = math.backend()
    jax_phase_np = jax_phase['np']
    self.assertEqual(jnp, jax_phase_np)
    self.assertNotEqual(onp, jax_phase_np)

    # Switch the gin-configured backend and fetch a fresh backend mapping.
    self.override_gin("backend.name = 'numpy'")
    numpy_phase = math.backend()
    numpy_phase_np = numpy_phase['np']
    self.assertNotEqual(jnp, numpy_phase_np)
    self.assertEqual(onp, numpy_phase_np)