def explicit_device_get_scope() -> Iterator[None]:
  """Marks the current context as an explicit device_get() call.

  Generator intended to back a ``with`` block (it yields exactly once inside
  try/finally). While suspended at the ``yield``, the ``explicit_device_get``
  attribute of ``transfer_guard_lib.thread_local_state()`` is True. On exit,
  normal or via an exception, the attribute is put back to whatever value it
  held on entry.

  Yields:
    None.
  """
  tls = transfer_guard_lib.thread_local_state()
  saved_flag = tls.explicit_device_get
  tls.explicit_device_get = True
  try:
    yield
  finally:
    # Restore the saved value instead of forcing False, so an inner scope
    # exiting does not clear the flag set by an enclosing scope.
    tls.explicit_device_get = saved_flag
assert False, f'Invalid transfer guard level {val}'


def _define_directional_transfer_guard(
    flag_name: str, key: str, description: str):
  """Defines one per-direction transfer guard option via config.define_enum_state.

  Args:
    flag_name: Name of the config option to define.
    key: Attribute name handed to ``_update_transfer_guard``. The global hook
      passes it with ``transfer_guard_lib.global_state()`` and the
      thread-local hook passes it with
      ``transfer_guard_lib.thread_local_state()``.
    description: Human-readable direction interpolated into the help text.

  Returns:
    Whatever ``config.define_enum_state`` returns for the new option.
  """
  return config.define_enum_state(
      name=flag_name,
      # Built fresh on every call, so each option gets its own list object.
      enum_values=[
          'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
      ],
      # The default is applied by transfer_guard_lib. Use None here to avoid
      # accidentally overriding --jax_transfer_guard.
      default=None,
      help=(f'Select the transfer guard level for {description} transfers. '
            'Default is "allow".'),
      # `key` is a parameter of this helper, so each lambda closes over its
      # own value (no late-binding surprise across the two options).
      update_global_hook=lambda val: _update_transfer_guard(
          transfer_guard_lib.global_state(), key, val),
      update_thread_local_hook=lambda val: _update_transfer_guard(
          transfer_guard_lib.thread_local_state(), key, val))


# Flag names are spelled out in full here so they stay greppable.
transfer_guard_host_to_device = _define_directional_transfer_guard(
    'jax_transfer_guard_host_to_device', 'host_to_device', 'host-to-device')

transfer_guard_device_to_device = _define_directional_transfer_guard(
    'jax_transfer_guard_device_to_device', 'device_to_device',
    'device-to-device')