示例#1
0
def _test__xla_dist_model_create_from_context_in_child_proc(index):
    """Verify a context-created ``_XlaDistModel`` from inside a child process.

    Args:
        index: process index supplied by the caller; it must match the
            model's ``local_rank``.
    """
    dist_model = _XlaDistModel.create_from_context()
    assert dist_model.backend() == "xla-tpu"

    import torch_xla.core.xla_model as xm

    xla_device = xm.xla_device()
    global_rank = xm.get_ordinal()
    n_procs = xm.xrt_world_size()

    # Expected layout: a single node (node_index 0, nnodes 1) whose per-node
    # process count equals the global world size.
    expected = dict(
        device=xla_device,
        local_rank=index,
        rank=global_rank,
        world_size=n_procs,
        node_index=0,
        nnodes=1,
        nproc_per_node=n_procs,
    )
    _assert_model(dist_model, expected)
示例#2
0
def main_fold(fold):
    """Run a short training loop on one XLA device as a per-process worker.

    Checks that ``_XlaDistModel.create_from_context()`` resolves to the same
    device as ``xm.xla_device(fold)``. It then performs one ``max``
    all-reduce across processes and trains a small linear model on random
    data with an ignite ``Engine`` (100 iterations x 2 epochs).

    Args:
        fold: index of this worker. It selects the XLA device, sets the value
            contributed to the all-reduce (``fold + 1``) and staggers the
            start of training through a short sleep.
    """
    # All imports are function-local, presumably so the function is
    # self-contained when launched in a spawned process. `torch` is imported
    # explicitly because `import torch.nn as nn` does not bind `torch`.
    import time

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch_xla.core.xla_model as xm

    from ignite.engine import Engine

    device = xm.xla_device(fold)

    comp_model = _XlaDistModel.create_from_context()
    assert comp_model.device() == device

    model = nn.Linear(100, 10)

    # Move the model before creating the optimizer so that the optimizer
    # references the on-device parameters.
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    def training_step(engine, _):
        # The batch is allocated directly on `device`; no extra `.to()` needed.
        data = torch.rand(4, 100, device=device)
        model.train()
        optimizer.zero_grad()
        loss = model(data).sum()
        loss.backward()
        # No XLA ParallelLoader is used here, so optimizer_step must issue
        # the XLA barrier itself.
        xm.optimizer_step(optimizer, barrier=True)
        return loss.item()

    trainer = Engine(training_step)

    # NOTE: placing this tensor on a device other than `device` was observed
    # to crash the collective; comp_model.device() is asserted equal to
    # `device` above.
    tensor = torch.tensor([fold + 1.0], dtype=torch.float).to(comp_model.device())
    xm.all_reduce("max", [tensor])

    # Stagger workers slightly before training starts.
    time.sleep(0.01 * fold)

    trainer.run([0] * 100, max_epochs=2)
示例#3
0
def test__xla_dist_model_create_from_context():
    """Verify a context-created ``_XlaDistModel`` in the current process (no spawn).

    Because only one process runs, every rank/index is expected to be 0 and
    every count to be 1.
    """
    dist_model = _XlaDistModel.create_from_context()
    assert dist_model.backend() == "xla-tpu"

    import torch_xla.core.xla_model as xm

    single_process_layout = dict(
        device=xm.xla_device(),
        local_rank=0,
        rank=0,
        world_size=1,
        node_index=0,
        nnodes=1,
        nproc_per_node=1,
    )
    _assert_model(dist_model, single_process_layout)