def _sanity_check():
    """Assert that the active ignite distributed model reports a consistent topology.

    Checks, in order:
      * world size == number of nodes * processes per node
      * local rank  <  processes per node
      * global rank <  world size
      * node rank   <  number of nodes

    NOTE: uses bare ``assert`` statements, so every check (and every getter
    call inside it) is skipped when Python runs with ``-O``.
    """
    from ignite.distributed.utils import _model as dist_model

    # The total process count must factor exactly into nodes x procs-per-node.
    assert dist_model.get_world_size() == dist_model.get_nnodes() * dist_model.get_nproc_per_node()

    # Each rank index must fall strictly inside the range it indexes into.
    assert dist_model.get_local_rank() < dist_model.get_nproc_per_node()
    assert dist_model.get_rank() < dist_model.get_world_size()
    assert dist_model.get_node_rank() < dist_model.get_nnodes()
def _test_xla_spawn_fn(local_rank, world_size, device):
    """Assert that the current ignite distributed model is an XLA model with the expected layout.

    Presumably used as the per-process function of a spawned XLA run (inferred
    from the name; confirm against callers). The checks assume a single-node
    setup: global rank equals ``local_rank``, node rank is 0, and there is one node.

    Args:
        local_rank: expected local rank; also expected to equal the global rank.
        world_size: expected world size; also expected to equal the number of
            processes per node.
        device: expected ``torch.device.type`` string of ``_model.device()``.
    """
    from ignite.distributed.utils import _model

    # The active backend model must be the XLA implementation.
    assert isinstance(_model, _XlaDistModel), f"{type(_model)} vs _XlaDistModel"
    assert _model.get_local_rank() == local_rank
    assert _model.get_world_size() == world_size

    # device() is queried unconditionally, outside any assert.
    dev = _model.device()
    assert isinstance(dev, torch.device)
    assert dev.type == device

    # Single-node expectations: global rank == local rank, one node, node 0.
    assert _model.get_rank() == local_rank
    assert _model.get_nproc_per_node() == world_size
    assert _model.get_node_rank() == 0
    assert _model.get_nnodes() == 1