예제 #1
0
def test__xla_dist_model_spawn_one_proc():
    """Spawn ``_test_xla_spawn_fn`` via ``_XlaDistModel`` on a single process.

    The spawn call may finish by raising ``SystemExit``. A clean exit
    (``code`` of ``None`` or ``0``) is tolerated. Any other exit code is
    reported as a test failure rather than silently swallowed.
    """
    try:
        _XlaDistModel.spawn(
            _test_xla_spawn_fn, args=(1, "xla"), kwargs_dict={}, nproc_per_node=1,
        )
    except SystemExit as e:
        # NOTE(review): presumably the XLA launcher turns an exception raised
        # inside the worker into a non-zero sys.exit (and, for a single
        # process, runs it in this process). Swallowing every SystemExit
        # would then hide worker failures, so only a clean exit may pass.
        if e.code not in (None, 0):
            raise AssertionError(
                f"spawned process exited with non-zero code {e.code!r}"
            ) from e
예제 #2
0
def test__xla_dist_model_spawn_n_procs():
    """Spawn ``_test_xla_spawn_fn`` via ``_XlaDistModel`` on N processes.

    N is read from the ``NUM_TPU_WORKERS`` environment variable. A missing
    variable raises ``KeyError`` and a non-integer value raises
    ``ValueError``. The spawn call may finish by raising ``SystemExit``. A
    clean exit (``code`` of ``None`` or ``0``) is tolerated. Any other exit
    code is reported as a test failure rather than silently swallowed.
    """
    n = int(os.environ["NUM_TPU_WORKERS"])
    try:
        _XlaDistModel.spawn(
            _test_xla_spawn_fn, args=(n, "xla"), kwargs_dict={}, nproc_per_node=n,
        )
    except SystemExit as e:
        # NOTE(review): presumably a failing worker is surfaced as a non-zero
        # sys.exit. Only a clean exit is allowed to pass so that failures are
        # not hidden.
        if e.code not in (None, 0):
            raise AssertionError(
                f"spawned processes exited with non-zero code {e.code!r}"
            ) from e