예제 #1
0
def explicit_device_get_scope() -> Iterator[None]:
  """Marks the enclosed region as an explicit device_get() call.

  Sets the ``explicit_device_get`` flag on the state object returned by
  ``transfer_guard_lib.thread_local_state()`` to True, yields once, and then
  puts the flag back to whatever value it held on entry.

  NOTE(review): this is a bare generator as seen here; presumably it is
  wrapped with ``contextlib.contextmanager`` just above -- confirm.

  Yields:
    None, exactly once, while the flag is set.
  """
  guard_state = transfer_guard_lib.thread_local_state()
  saved_flag = guard_state.explicit_device_get
  guard_state.explicit_device_get = True
  try:
    yield
  finally:
    # Restore the saved value (instead of forcing False) so that an enclosing
    # scope which already set the flag keeps it after this one exits.
    guard_state.explicit_device_get = saved_flag
예제 #2
0
    assert False, f'Invalid transfer guard level {val}'

# Config state holding the transfer guard level for host-to-device transfers.
transfer_guard_host_to_device = config.define_enum_state(
    name='jax_transfer_guard_host_to_device',
    enum_values=[
        'allow',
        'log',
        'disallow',
        'log_explicit',
        'disallow_explicit',
    ],
    # Deliberately None: transfer_guard_lib applies the real default, and a
    # concrete value here could accidentally override --jax_transfer_guard.
    default=None,
    help=(
        'Select the transfer guard level for host-to-device transfers. '
        'Default is "allow".'
    ),
    # Both hooks hand the new value to _update_transfer_guard together with
    # the 'host_to_device' key; they differ only in which state object
    # (global vs. thread-local) is passed.
    update_global_hook=lambda val: _update_transfer_guard(
        transfer_guard_lib.global_state(), 'host_to_device', val
    ),
    update_thread_local_hook=lambda val: _update_transfer_guard(
        transfer_guard_lib.thread_local_state(), 'host_to_device', val
    ),
)

# Config state holding the transfer guard level for device-to-device transfers.
transfer_guard_device_to_device = config.define_enum_state(
    name='jax_transfer_guard_device_to_device',
    enum_values=[
        'allow',
        'log',
        'disallow',
        'log_explicit',
        'disallow_explicit',
    ],
    # Deliberately None: transfer_guard_lib applies the real default, and a
    # concrete value here could accidentally override --jax_transfer_guard.
    default=None,
    help=(
        'Select the transfer guard level for device-to-device transfers. '
        'Default is "allow".'
    ),
    # Both hooks hand the new value to _update_transfer_guard together with
    # the 'device_to_device' key; they differ only in which state object
    # (global vs. thread-local) is passed.
    update_global_hook=lambda val: _update_transfer_guard(
        transfer_guard_lib.global_state(), 'device_to_device', val
    ),
    update_thread_local_hook=lambda val: _update_transfer_guard(
        transfer_guard_lib.thread_local_state(), 'device_to_device', val
    ),
)