예제 #1
0
  def test_mixed_nesting(self):
    """Checks that a nested "allow" guard overrides an outer "disallow".

    The host-to-device-specific guard and the general transfer guard are
    nested in both orders. In each order, ``jnp.ones`` must be allowed while
    the inner "allow" guard is active, and disallowed again once that guard
    exits and only the outer "disallow" guard remains in effect.
    """
    orderings = [
        # (outer guard set to "disallow", inner guard set to "allow")
        (jax.transfer_guard_host_to_device, jax.transfer_guard),
        (jax.transfer_guard, jax.transfer_guard_host_to_device),
    ]
    for outer, inner in orderings:
      with outer("disallow"):
        with inner("allow"):
          with self.assertAllows("host_to_device_jnp_ones"):
            jnp.ones(1)
        with self.assertDisallows("host_to_device_jnp_ones"):
          jnp.ones(1)
예제 #2
0
def _device_to_device_funcs():
  """Builds the device-to-device transfer cases.

  Returns:
    A list of ``(name, is_explicit, fn)`` tuples. ``name`` labels the case.
    ``is_explicit`` is True when the case goes through an explicit transfer
    API (``jax.device_put``). ``fn`` is a zero-argument callable that moves a
    one-element array onto the second local device. The list is empty when
    fewer than two local devices exist, because there is no distinct
    destination device.
  """
  if len(jax.local_devices()) < 2:
    return []

  # Building the source arrays is itself a host-to-device transfer, so it is
  # explicitly allowed here. That way only the returned callables are
  # subject to whatever guard the caller has in place.
  with jax.transfer_guard_host_to_device("allow"):
    put_src, jit_src = (jnp.ones(1) for _ in range(2))

  def put_on_second_device():
    return jax.device_put(put_src, device=jax.local_devices()[1])

  def jit_on_second_device():
    identity = jax.jit(lambda x: x, device=jax.local_devices()[1])
    return identity(jit_src)

  return [
      ("device_to_device_jax_device_put", True, put_on_second_device),
      ("device_to_device_jax_jit", False, jit_on_second_device),
  ]
예제 #3
0
def _device_to_host_funcs():
  """Builds the device-to-host transfer cases.

  Returns:
    A list of ``(name, is_explicit, fn)`` tuples. ``name`` labels the case.
    ``is_explicit`` is True only for the ``jax.device_get`` case. ``fn`` is a
    zero-argument callable that exercises one device-to-host path on its own
    device array: ``jax.device_get``, ``np.asarray``,
    ``copy_to_host_async``, ``np.add``, ``str`` or ``pickle.dumps``. The
    list is empty on the CPU backend, where device-to-host does not incur a
    transfer.
  """
  if jax.default_backend() == "cpu":
    return []

  # (case name, is an explicit transfer?, operation applied to one array).
  # Each operation is wrapped in a lambda so the library function is looked
  # up when the case runs, not when the table is built.
  cases = [
      ("device_to_host_jax_device_get", True, lambda a: jax.device_get(a)),
      ("device_to_host_np_asarray", False, lambda a: np.asarray(a)),
      ("device_to_host_copy_to_host_async", False,
       lambda a: a.copy_to_host_async()),
      ("device_to_host_np_add", False, lambda a: np.add(a, 1)),
      ("device_to_host_str", False, lambda a: str(a)),
      ("device_to_host_pickle_dumps", False, lambda a: pickle.dumps(a)),
  ]

  # One fresh source array per case. Presumably this ensures a host copy
  # produced by one case cannot mask the transfer in another; confirm.
  # Creating the arrays is a host-to-device transfer, hence the explicit
  # "allow".
  with jax.transfer_guard_host_to_device("allow"):
    sources = [jnp.ones(1) for _ in cases]

  def bind(op, arr):
    # Freeze `op` and `arr` into a zero-argument callable.
    return lambda: op(arr)

  return [
      (name, explicit, bind(op, arr))
      for (name, explicit, op), arr in zip(cases, sources)
  ]