def _test__xla_dist_model_create_from_context_in_child_proc(index):
    model = _XlaDistModel.create_from_context()
    assert model.backend() == "xla-tpu"

    import torch_xla.core.xla_model as xm

    _assert_model(
        model,
        {
            "device": xm.xla_device(),
            "local_rank": index,
            "rank": xm.get_ordinal(),
            "world_size": xm.xrt_world_size(),
            "node_index": 0,
            "nnodes": 1,
            "nproc_per_node": xm.xrt_world_size(),
        },
    )
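# A minimal sketch of how the helper above could be launched across TPU
# processes. The NUM_TPU_WORKERS variable and the use of xmp.spawn are
# assumptions modeled on typical torch_xla multiprocessing tests, not part
# of the original section.
def test__xla_dist_model_create_from_context_in_child_proc():
    import os

    import torch_xla.distributed.xla_multiprocessing as xmp

    # Hypothetical env var giving the number of TPU processes to spawn
    n = int(os.environ.get("NUM_TPU_WORKERS", 1))
    xmp.spawn(_test__xla_dist_model_create_from_context_in_child_proc, args=(), nprocs=n)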
def main_fold(fold):
    import time

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch_xla.core.xla_model as xm

    from ignite.engine import Engine

    device = xm.xla_device(fold)

    comp_model = _XlaDistModel.create_from_context()
    assert comp_model.device() == device

    model = nn.Linear(100, 10)
    model.to(device)  # Move model before creating optimizer
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    def training_step(engine, _):
        # Data is created directly on the XLA device
        data = torch.rand(4, 100, device=device)
        model.train()
        optimizer.zero_grad()
        output = model(data)
        loss = output.sum()
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
        return loss.item()

    trainer = Engine(training_step)

    # This can crash if the tensor is on a device other than `device`
    tensor = torch.tensor([fold + 1.0], dtype=torch.float).to(comp_model.device())
    xm.all_reduce("max", [tensor])

    time.sleep(0.01 * fold)

    trainer.run([0] * 100, max_epochs=2)
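# A minimal sketch of how main_fold might be run on several folds in
# parallel threads, assuming joblib is available; the fold count and the
# device-ordinal offset below are illustrative assumptions, not part of the
# original section.
def _run_folds_in_parallel_threads():
    import torch_xla.core.xla_model as xm

    from joblib import Parallel, delayed

    devices = xm.get_xla_supported_devices()
    # Run multiple folds only when several XLA devices are available
    nfolds, offset = (5, 1) if len(devices) > 5 else (1, 0)
    Parallel(n_jobs=nfolds, backend="threading")(
        delayed(main_fold)(i + offset) for i in range(nfolds)
    )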
def test__xla_dist_model_create_from_context():
    # without spawn
    model = _XlaDistModel.create_from_context()
    assert model.backend() == "xla-tpu"

    import torch_xla.core.xla_model as xm

    _assert_model(
        model,
        {
            "device": xm.xla_device(),
            "local_rank": 0,
            "rank": 0,
            "world_size": 1,
            "node_index": 0,
            "nnodes": 1,
            "nproc_per_node": 1,
        },
    )
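# The _assert_model helper used above is not shown in this section. Below is
# a hypothetical sketch of what it likely checks, inferred from the keys
# passed to it; the accessor names follow ignite's ComputationModel API.
def _assert_model(model, true_conf):
    assert model.device() == true_conf["device"]
    assert model.get_local_rank() == true_conf["local_rank"]
    assert model.get_rank() == true_conf["rank"]
    assert model.get_world_size() == true_conf["world_size"]
    assert model.get_node_rank() == true_conf["node_index"]
    assert model.get_nnodes() == true_conf["nnodes"]
    assert model.get_nproc_per_node() == true_conf["nproc_per_node"]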