Example #1
import random

import torch


def manual_seed(seed: int) -> None:
    """Set up the random state from a seed for `torch`, `random`, and optionally `numpy` (if it can be imported).

    Args:
        seed: Random state seed

    .. versionchanged:: 0.4.3
        Added ``torch.cuda.manual_seed_all(seed)``.

    .. versionchanged:: 0.4.5
        Added ``torch_xla.core.xla_model.set_rng_state(seed)``.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    try:
        import torch_xla.core.xla_model as xm

        xm.set_rng_state(seed)
    except ImportError:
        pass

    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass
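
A minimal usage sketch (not part of the original snippet; assumes only the manual_seed helper above and torch): re-seeding reproduces the same draws.

# Hypothetical usage of the manual_seed helper above.
manual_seed(42)
a = torch.rand(3)
manual_seed(42)
b = torch.rand(3)
assert torch.equal(a, b)  # identical draws after re-seeding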
Example #2
def synchronize_rng_state(rng_type: Optional[RNGType] = None,
                          generator: Optional[torch.Generator] = None):
    # Get the proper rng state
    if rng_type == RNGType.TORCH:
        rng_state = torch.get_rng_state()
    elif rng_type == RNGType.CUDA:
        rng_state = torch.cuda.get_rng_state()
    elif rng_type == RNGType.XLA:
        assert is_tpu_available(), "Can't synchronize XLA seeds on an environment without TPUs."
        rng_state = torch.tensor(xm.get_rng_state())
    elif rng_type == RNGType.GENERATOR:
        assert generator is not None, "Need a generator to synchronize its seed."
        rng_state = generator.get_state()

    # Broadcast the rng state from device 0 to other devices
    state = AcceleratorState()
    if state.distributed_type == DistributedType.TPU:
        rng_state = xm.mesh_reduce("random_seed", rng_state, lambda x: x[0])
    elif state.distributed_type == DistributedType.MULTI_GPU:
        rng_state = rng_state.to(state.device)
        torch.distributed.broadcast(rng_state, 0)
        rng_state = rng_state.cpu()

    # Set the broadcast rng state
    if rng_type == RNGType.TORCH:
        torch.set_rng_state(rng_state)
    elif rng_type == RNGType.CUDA:
        torch.cuda.set_rng_state(rng_state)
    elif rng_type == RNGType.XLA:
        xm.set_rng_state(rng_state.item())
    elif rng_type == RNGType.GENERATOR:
        generator.set_state(rng_state)
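
The broadcast above needs an initialized process group, so as a self-contained illustration of what copying an RNG state achieves (a sketch, not the accelerate API), two local generators agree once one adopts the other's state:

# Illustration only: copying the serialized state makes two generators produce identical draws.
import torch

g0 = torch.Generator().manual_seed(0)    # stands in for the rank-0 generator
g1 = torch.Generator().manual_seed(123)  # stands in for another rank's generator
g1.set_state(g0.get_state())             # the "broadcast": overwrite g1's state with g0's
assert torch.equal(torch.rand(4, generator=g0), torch.rand(4, generator=g1))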
Example #3
import sys

import torch
import torch_xla
import torch_xla.core.functions as xf
import torch_xla.core.xla_model as xm


def _mp_fn(index):
    device = xm.xla_device()

    if xm.xla_device_hw(device) in ('TPU', 'GPU'):
        world_size = xm.xrt_world_size()
        torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
            use_full_mat_mul_precision=True)
        torch.manual_seed(11)
        xm.set_rng_state(11)

        N = 3
        M = 4
        KO = 2
        wsize = KO * world_size
        wg = torch.randn(N, wsize, device=device, requires_grad=True)
        w = torch.narrow(wg, 1, index * KO, KO)
        x = torch.randn(wsize, M, device=device)

        mm = wg @ x
        bmm = xf.distributed_mm(w, x, split=2)

        mm_cpu = mm.cpu()
        bmm_cpu = bmm.cpu()
        if not mm_cpu.allclose(bmm_cpu, rtol=1e-04, atol=1e-04):
            print('distributed_mm() produced wrong result', file=sys.stderr)
            print('[{}]\n{}\n{}'.format(index, mm_cpu, bmm_cpu),
                  file=sys.stderr)
            sys.exit(1)
    else:
        print('Default device {} is not a TPU or GPU device'.format(device),
              file=sys.stderr)
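
Independent of torch_xla, the comparison in this test reduces to standard block-matrix arithmetic: the full product wg @ x can be reassembled from column shards of wg and the matching row blocks of x. A CPU-only sketch of that identity (an illustration of the intent, not of xf.distributed_mm itself):

# CPU illustration of the sharded matmul equivalence the test asserts.
import torch

torch.manual_seed(11)
N, M, KO, world_size = 3, 4, 2, 4
wsize = KO * world_size
wg = torch.randn(N, wsize)
x = torch.randn(wsize, M)

full = wg @ x
parts = [wg[:, i * KO:(i + 1) * KO] @ x[i * KO:(i + 1) * KO, :] for i in range(world_size)]
assert torch.allclose(full, sum(parts), rtol=1e-4, atol=1e-4)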
Example #4
    def __init__(self, seed):
        assert isinstance(seed, int)
        self.rng_state = get_rng_state()

        torch.manual_seed(seed)
        if xm is not None:
            xm.set_rng_state(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
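
The __init__ above snapshots the RNG state before re-seeding, which suggests the surrounding class restores it later. A torch-only sketch of that save/seed/restore pattern (get_rng_state and xm in the original come from its enclosing module):

# Save -> seed -> restore, sketched with plain torch (not the original class).
import torch

saved = torch.get_rng_state()   # snapshot the current RNG state
torch.manual_seed(7)            # deterministic section
_ = torch.rand(2)
torch.set_rng_state(saved)      # put the snapshot back afterwards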
Example #5
def set_seed(seed: int):
    """
    Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch``.

    Args:
        seed (:obj:`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # ^^ safe to call this function even if cuda is not available
    if is_tpu_available():
        xm.set_rng_state(seed)
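
The inline comment matches PyTorch's documented behavior: torch.cuda.manual_seed_all is silently ignored when CUDA is unavailable, so the helper is safe on CPU-only builds.

# No-op (silently ignored) on a CPU-only build; does not raise.
import torch

torch.cuda.manual_seed_all(0)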
Example #6
import random

import numpy as np
import torch


def set_global_seed(seed: int) -> None:
    """Sets the random seed for Random, Numpy, and PyTorch (and torch_xla, if it can be imported).

    Args:
        seed: random seed
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        pass
    else:
        xm.set_rng_state(seed)
Example #7
def set_rng_state(state):
    torch.set_rng_state(state["torch_rng_state"])
    if xm is not None:
        xm.set_rng_state(state["xla_rng_state"])
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(state["cuda_rng_state"])
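
A sketch of the counterpart that would produce the state dict this function consumes (key names taken from the code above; the xm guard follows the same convention):

# Hypothetical matching get_rng_state; keys mirror set_rng_state above.
def get_rng_state():
    state = {"torch_rng_state": torch.get_rng_state()}
    if xm is not None:
        state["xla_rng_state"] = xm.get_rng_state()
    if torch.cuda.is_available():
        state["cuda_rng_state"] = torch.cuda.get_rng_state()
    return state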
Example #8
    def setUp(self):
        super().setUp()
        xm.set_rng_state(101)
Example #9
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    xm.set_rng_state(seed, device=xm.xla_device())