def manual_seed(seed: int) -> None:
    """Seed every random generator this environment provides.

    Always seeds Python's ``random`` and ``torch``; seeds all CUDA devices when
    CUDA is available; seeds the XLA generator and ``numpy`` only if
    ``torch_xla`` / ``numpy`` can be imported (missing packages are skipped).

    Args:
        seed: Random state seed

    .. versionchanged:: 0.4.3
        Added ``torch.cuda.manual_seed_all(seed)``.

    .. versionchanged:: 0.4.5
        Added ``torch_xla.core.xla_model.set_rng_state(seed)``.
    """
    # Generators that are always present.
    random.seed(seed)
    torch.manual_seed(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Seed every visible GPU, not just the current one.
        torch.cuda.manual_seed_all(seed)

    # Optional backends: an ImportError means "not installed", so skip quietly.
    try:
        import torch_xla.core.xla_model as xla_model
        xla_model.set_rng_state(seed)
    except ImportError:
        pass

    try:
        import numpy
        numpy.random.seed(seed)
    except ImportError:
        pass
def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional[torch.Generator] = None):
    """Make one RNG's state identical across processes by copying device 0's state.

    The state of the generator selected by ``rng_type`` is read locally. It is then
    replaced by device 0's state (``xm.mesh_reduce`` on TPU,
    ``torch.distributed.broadcast`` on multi-GPU) and written back. In any other
    distributed setup, the local state is simply re-applied.

    Args:
        rng_type: Which generator to synchronize. ``None`` (the default), or a value
            not handled below, means there is nothing to synchronize, and the call
            returns immediately.
        generator: The generator to synchronize when ``rng_type`` is
            ``RNGType.GENERATOR``; must not be ``None`` in that case.

    Raises:
        AssertionError: ``rng_type`` is ``RNGType.XLA`` but ``is_tpu_available()``
            is false, or ``rng_type`` is ``RNGType.GENERATOR`` and ``generator``
            is ``None``.
    """
    # Get the proper rng state (stays None when rng_type selects nothing).
    rng_state = None
    if rng_type == RNGType.TORCH:
        rng_state = torch.get_rng_state()
    elif rng_type == RNGType.CUDA:
        rng_state = torch.cuda.get_rng_state()
    elif rng_type == RNGType.XLA:
        assert is_tpu_available(), "Can't synchronize XLA seeds on an environment without TPUs."
        rng_state = torch.tensor(xm.get_rng_state())
    elif rng_type == RNGType.GENERATOR:
        assert generator is not None, "Need a generator to synchronize its seed."
        rng_state = generator.get_state()

    if rng_state is None:
        # Nothing to synchronize. Without this guard, the default rng_type=None
        # reached the broadcast below with rng_state unbound and raised
        # UnboundLocalError in TPU / multi-GPU runs.
        return

    # Broadcast the rng state from device 0 to other devices
    state = AcceleratorState()
    if state.distributed_type == DistributedType.TPU:
        # The reduce fn keeps the first gathered value, i.e. device 0's state.
        rng_state = xm.mesh_reduce("random_seed", rng_state, lambda x: x[0])
    elif state.distributed_type == DistributedType.MULTI_GPU:
        # Presumably the process-group backend needs the tensor on this process's
        # device; it is moved back to CPU afterwards for the set_*_state calls.
        rng_state = rng_state.to(state.device)
        torch.distributed.broadcast(rng_state, 0)
        rng_state = rng_state.cpu()

    # Set the broadcast rng state
    if rng_type == RNGType.TORCH:
        torch.set_rng_state(rng_state)
    elif rng_type == RNGType.CUDA:
        torch.cuda.set_rng_state(rng_state)
    elif rng_type == RNGType.XLA:
        xm.set_rng_state(rng_state.item())
    elif rng_type == RNGType.GENERATOR:
        generator.set_state(rng_state)
def _mp_fn(index):
    """Check ``xf.distributed_mm`` against a plain matmul on one replica.

    ``index`` selects this replica's column slice of the weight matrix. If the
    default XLA device's hardware is neither TPU nor GPU, the check is skipped
    and a message is written to stderr. On a numeric mismatch, both results are
    printed to stderr and the process exits with status 1.
    """
    device = xm.xla_device()
    if xm.xla_device_hw(device) not in ('TPU', 'GPU'):
        print('Default device {} is not a TPU or GPU device'.format(device),
              file=sys.stderr)
        return

    world_size = xm.xrt_world_size()
    torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
        use_full_mat_mul_precision=True)
    # Fixed seeds on both the torch and XLA generators. Presumably this is so
    # every replica builds the same full matrices before slicing its own part.
    torch.manual_seed(11)
    xm.set_rng_state(11)

    n_rows, n_cols, cols_per_replica = 3, 4, 2
    total_cols = cols_per_replica * world_size
    full_w = torch.randn(n_rows, total_cols, device=device, requires_grad=True)
    # This replica's block of columns from the full weight matrix.
    local_w = torch.narrow(full_w, 1, index * cols_per_replica, cols_per_replica)
    rhs = torch.randn(total_cols, n_cols, device=device)

    # Build both results before moving either to CPU (same order as before).
    reference = full_w @ rhs
    distributed = xf.distributed_mm(local_w, rhs, split=2)
    reference_cpu, distributed_cpu = reference.cpu(), distributed.cpu()

    if reference_cpu.allclose(distributed_cpu, rtol=1e-04, atol=1e-04):
        return
    print('distributed_mm() produced wrong result', file=sys.stderr)
    print('[{}]\n{}\n{}'.format(index, reference_cpu, distributed_cpu),
          file=sys.stderr)
    sys.exit(1)
def __init__(self, seed):
    """Snapshot the current RNG state, then reseed torch (and XLA / CUDA if present).

    Args:
        seed: integer seed; a non-``int`` value fails the ``isinstance`` assertion.
    """
    assert isinstance(seed, int)
    # Captured before reseeding; presumably restored later by the owning
    # object (e.g. on context exit) -- confirm against the class.
    previous_state = get_rng_state()
    self.rng_state = previous_state

    torch.manual_seed(seed)
    xla_present = xm is not None
    if xla_present:
        xm.set_rng_state(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed(seed)
def set_seed(seed: int):
    """
    Seed ``random``, ``numpy`` and ``torch`` (every CUDA device, plus XLA when a TPU
    is available) so runs are reproducible.

    Args:
        seed (:obj:`int`): The seed to set.
    """
    # torch.cuda.manual_seed_all is safe to call even if cuda is not available.
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
    if is_tpu_available():
        xm.set_rng_state(seed)
def set_global_seed(seed: int) -> None:
    """Seed Python ``random``, NumPy and PyTorch, plus CUDA and XLA when available.

    CUDA devices are seeded only if ``torch.cuda.is_available()``. The XLA
    generator is seeded only if ``torch_xla`` can be imported. TensorFlow is not
    seeded by this function.

    Args:
        seed: random seed
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers every GPU, including the current one, so the
        # separate torch.cuda.manual_seed call was redundant.
        torch.cuda.manual_seed_all(seed)
    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        # torch_xla is optional; without it there is no XLA generator to seed.
        pass
    else:
        xm.set_rng_state(seed)
def set_rng_state(state):
    """Restore generator states from a dict of saved states.

    ``state["torch_rng_state"]`` is always required. ``"xla_rng_state"`` is read
    only when ``xm`` is not ``None``, and ``"cuda_rng_state"`` only when CUDA is
    available. A required key that is missing raises ``KeyError``.
    """
    cpu_state = state["torch_rng_state"]
    torch.set_rng_state(cpu_state)

    if xm is not None:
        xla_state = state["xla_rng_state"]
        xm.set_rng_state(xla_state)

    if torch.cuda.is_available():
        cuda_state = state["cuda_rng_state"]
        torch.cuda.set_rng_state(cuda_state)
def setUp(self):
    """Run the parent ``setUp``, then set the XLA RNG to the fixed seed 101."""
    super().setUp()
    xla_seed = 101
    xm.set_rng_state(xla_seed)
def seed_everything(seed):
    """Seed Python ``random``, NumPy, ``torch.manual_seed`` and the default XLA device's RNG."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    xla_device = xm.xla_device()
    xm.set_rng_state(seed, device=xla_device)