def test__xla_dist_model_spawn_one_proc():
    """Spawn ``_test_xla_spawn_fn`` on a single XLA process via ``_XlaDistModel.spawn``.

    The worker receives ``(1, "xla")`` as positional args and no keyword args.
    A ``SystemExit`` escaping the spawn call is tolerated only when it signals a
    clean exit (code ``None`` or ``0``). Any other exit code is reported as a
    test failure instead of being silently ignored.
    """
    try:
        _XlaDistModel.spawn(
            _test_xla_spawn_fn,
            args=(1, "xla"),
            kwargs_dict={},
            nproc_per_node=1,
        )
    except SystemExit as exc:
        # NOTE(review): the original swallowed every SystemExit; presumably the
        # spawn machinery may call sys.exit() on success. Keep tolerating a
        # clean exit, but surface a non-zero code as a failure.
        if exc.code not in (None, 0):
            raise AssertionError(
                f"XLA spawn (nproc_per_node=1) exited with code {exc.code!r}"
            ) from exc
def test__xla_dist_model_spawn_n_procs():
    """Spawn ``_test_xla_spawn_fn`` on ``NUM_TPU_WORKERS`` XLA processes.

    The worker count is read from the ``NUM_TPU_WORKERS`` environment variable
    and passed both as ``nproc_per_node`` and as the first positional worker
    argument (followed by ``"xla"``). A ``SystemExit`` escaping the spawn call is
    tolerated only when it signals a clean exit (code ``None`` or ``0``). Any
    other exit code is reported as a test failure.

    Raises:
        KeyError: if ``NUM_TPU_WORKERS`` is not set in the environment.
        ValueError: if ``NUM_TPU_WORKERS`` is not an integer string.
    """
    # NOTE(review): assumes a skip guard outside this view ensures the env var
    # is present; otherwise this line raises KeyError.
    n = int(os.environ["NUM_TPU_WORKERS"])
    try:
        _XlaDistModel.spawn(
            _test_xla_spawn_fn,
            args=(n, "xla"),
            kwargs_dict={},
            nproc_per_node=n,
        )
    except SystemExit as exc:
        # NOTE(review): the original swallowed every SystemExit; presumably the
        # spawn machinery may call sys.exit() on success. Keep tolerating a
        # clean exit, but surface a non-zero code as a failure.
        if exc.code not in (None, 0):
            raise AssertionError(
                f"XLA spawn (nproc_per_node={n}) exited with code {exc.code!r}"
            ) from exc