def get_rng_state():
    """Snapshot the current RNG state of each available backend.

    Returns:
        A dict that always contains ``"torch_rng_state"``. It additionally
        contains ``"xla_rng_state"`` when the module-level ``xm`` is not
        ``None``, and ``"cuda_rng_state"`` when ``torch.cuda.is_available()``
        reports ``True``.
    """
    rng_snapshot = {}
    rng_snapshot["torch_rng_state"] = torch.get_rng_state()

    # ``xm`` is a module-level name; it is ``None`` when XLA is not usable.
    if xm is not None:
        rng_snapshot.update(xla_rng_state=xm.get_rng_state())

    if torch.cuda.is_available():
        rng_snapshot.update(cuda_rng_state=torch.cuda.get_rng_state())

    return rng_snapshot
def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional[torch.Generator] = None):
    """Align one RNG source with the state held by process 0.

    The state selected by ``rng_type`` is read locally. Under
    ``DistributedType.TPU`` or ``DistributedType.MULTI_GPU`` it is replaced by
    the copy coming from process 0; under any other distributed type it is left
    as is. The resulting state is then written back to the same RNG source.

    Args:
        rng_type: Which RNG to synchronize. ``None`` (the default) selects
            nothing, so the call returns immediately.
        generator: Generator whose state is synchronized. Required when
            ``rng_type`` is ``RNGType.GENERATOR``; unused otherwise.

    Raises:
        AssertionError: If ``rng_type`` is ``RNGType.XLA`` while
            ``is_tpu_available()`` is false, or if ``rng_type`` is
            ``RNGType.GENERATOR`` and ``generator`` is ``None``.
        ValueError: If ``rng_type`` is neither ``None`` nor one of the RNG
            types handled here.
    """
    # Nothing was selected: return before touching any RNG. Previously this
    # path fell through with ``rng_state`` unbound, which crashed with an
    # UnboundLocalError in the TPU / multi-GPU branches below.
    if rng_type is None:
        return

    # Get the proper rng state
    if rng_type == RNGType.TORCH:
        rng_state = torch.get_rng_state()
    elif rng_type == RNGType.CUDA:
        rng_state = torch.cuda.get_rng_state()
    elif rng_type == RNGType.XLA:
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept unchanged.
        if not is_tpu_available():
            raise AssertionError("Can't synchronize XLA seeds on an environment without TPUs.")
        # Wrapped in a tensor; unwrapped again with ``.item()`` when set below.
        rng_state = torch.tensor(xm.get_rng_state())
    elif rng_type == RNGType.GENERATOR:
        if generator is None:
            raise AssertionError("Need a generator to synchronize its seed.")
        rng_state = generator.get_state()
    else:
        raise ValueError(f"Unsupported rng_type: {rng_type!r}.")

    # Broadcast the rng state from device 0 to other devices
    state = AcceleratorState()
    if state.distributed_type == DistributedType.TPU:
        # The reduce function keeps only the first gathered value.
        rng_state = xm.mesh_reduce("random_seed", rng_state, lambda x: x[0])
    elif state.distributed_type == DistributedType.MULTI_GPU:
        # Move to this process's device for the collective, then back to CPU
        # before handing the state to the setters below.
        rng_state = rng_state.to(state.device)
        torch.distributed.broadcast(rng_state, 0)
        rng_state = rng_state.cpu()

    # Set the broadcast rng state
    if rng_type == RNGType.TORCH:
        torch.set_rng_state(rng_state)
    elif rng_type == RNGType.CUDA:
        torch.cuda.set_rng_state(rng_state)
    elif rng_type == RNGType.XLA:
        xm.set_rng_state(rng_state.item())
    elif rng_type == RNGType.GENERATOR:
        generator.set_state(rng_state)