def test__xla_dist_model_create_from_backend():
    """Build an ``_XlaDistModel`` directly from the "xla-tpu" backend (no spawn).

    Checks that the resulting model reports the XLA device and a single-process,
    single-node configuration (rank 0, world size 1, one node, one process per
    node), then finalizes it.
    """
    # Import before creating the model so an import failure cannot leave a
    # created-but-never-finalized model behind.
    import torch_xla.core.xla_model as xm

    # without spawn
    model = _XlaDistModel.create_from_backend("xla-tpu")
    try:
        _assert_model(
            model,
            {
                "device": xm.xla_device(),
                "local_rank": 0,
                "rank": 0,
                "world_size": 1,
                "node_index": 0,
                "nnodes": 1,
                "nproc_per_node": 1,
            },
        )
    finally:
        # Run cleanup even when the assertions above fail, so the model is not
        # left un-finalized for subsequent tests.
        model.finalize()
def test__xla_model():
    """Check that "xla-tpu" is an advertised backend and that unknown names are rejected."""
    assert "xla-tpu" in _XlaDistModel.available_backends

    # An unrecognized backend name must be refused with a ValueError.
    with pytest.raises(ValueError, match=r"Backend should be one of"):
        _XlaDistModel.create_from_backend("abc")