Ejemplo n.º 1
0
def test__xla_dist_model_create_from_backend():
    """Create an ``_XlaDistModel`` from the "xla-tpu" backend without spawning.

    The resulting model is checked against a single-process configuration
    (rank 0, world size 1, one node, one process per node) and is finalized
    once the check has passed.
    """
    # No process spawning here: the model is created directly in this process.
    dist_model = _XlaDistModel.create_from_backend("xla-tpu")

    # Imported inside the test, after the model has been created, exactly as
    # in the original ordering.
    import torch_xla.core.xla_model as xm

    expected_config = {
        "device": xm.xla_device(),
        "local_rank": 0,
        "rank": 0,
        "world_size": 1,
        "node_index": 0,
        "nnodes": 1,
        "nproc_per_node": 1,
    }
    _assert_model(dist_model, expected_config)

    dist_model.finalize()
Ejemplo n.º 2
0
def test__xla_model():
    """Check the advertised backends of ``_XlaDistModel`` and bad-backend handling.

    "xla-tpu" must be listed in ``available_backends``, and asking for an
    unknown backend name must raise ``ValueError`` with a message matching
    "Backend should be one of".
    """
    assert "xla-tpu" in _XlaDistModel.available_backends

    # An unknown backend name is rejected with a ValueError.
    with pytest.raises(ValueError, match=r"Backend should be one of"):
        _XlaDistModel.create_from_backend("abc")