Exemplo n.º 1
0
 def test_nesting(self):
   """Check that nested transfer guards follow the innermost setting.

   Inside an outer "disallow" guard, an inner "allow" guard is entered and
   `jnp.ones(1)` is run under `assertAllows`. After the inner guard exits,
   the same call is run under `assertDisallows` while only the outer
   "disallow" guard is active.
   """
   # NOTE(review): assertAllows/assertDisallows are helpers defined outside
   # this view; their exact semantics should be confirmed on the test class.
   label = "host_to_device_jnp_ones"
   with jax.transfer_guard("disallow"):
     # Inner guard and its assertion are entered in the same order as the
     # original nested `with` statements.
     with jax.transfer_guard("allow"), self.assertAllows(label):
       jnp.ones(1)
     # Back under the outer guard only.
     with self.assertDisallows(label):
       jnp.ones(1)
Exemplo n.º 2
0
 def test_simple(self):
   """Run `jnp.ones(1)` under each transfer guard level in turn.

   The levels "allow", "log" and "disallow" are paired with the
   `assertAllows`, `assertLogs` and `assertDisallows` helpers respectively.
   """
   # NOTE(review): the three assert* helpers are defined outside this view;
   # confirm their semantics on the enclosing test class.
   label = "host_to_device_jnp_ones"
   # Each multi-item `with` enters the guard first, then the assertion
   # context, matching the original nesting order.
   with jax.transfer_guard("allow"), self.assertAllows(label):
     jnp.ones(1)
   with jax.transfer_guard("log"), self.assertLogs(label):
     jnp.ones(1)
   with jax.transfer_guard("disallow"), self.assertDisallows(label):
     jnp.ones(1)