コード例 #1
0
ファイル: conftest.py プロジェクト: ConanoutlooklvTBS/mars
def ray_large_cluster(request):  # pragma: no cover
    """Start a multi-node local Ray cluster for the duration of a test.

    ``request.param`` (when the fixture is parametrized) may be a dict with
    ``num_nodes`` (default 3) and ``num_cpus`` per node (default 10).  The
    driver connects to the cluster right after the first node is added, and
    Ray serializers stay registered while the test runs.

    Ray and the cluster are always shut down on teardown, including when
    node start-up, ``ray.init`` or serializer registration fails part-way,
    so node processes are not leaked.
    """
    param = getattr(request, "param", {})
    num_nodes = param.get('num_nodes', 3)
    num_cpus = param.get('num_cpus', 10)
    try:
        from ray.cluster_utils import Cluster
    except ModuleNotFoundError:
        # Fall back to the private module path used by some Ray versions.
        from ray._private.cluster_utils import Cluster
    cluster = Cluster()
    try:
        for node_index in range(num_nodes):
            cluster.add_node(num_cpus=num_cpus)
            if node_index == 0:
                # Connect the driver as soon as the first node exists.
                ray.init(address=cluster.address)
        register_ray_serializers()
        try:
            yield
        finally:
            unregister_ray_serializers()
            Router.set_instance(None)
            RayServer.clear()
    finally:
        ray.shutdown()
        cluster.shutdown()
        if 'COV_CORE_SOURCE' in os.environ:
            # Remove this when https://github.com/ray-project/ray/issues/16802 got fixed
            subprocess.check_call(["ray", "stop", "--force"])
コード例 #2
0
ファイル: test_ray.py プロジェクト: fyrestone/mars
def _sync_web_session_test(web_address):
    """Execute a small tensor row-sum through a web session at *web_address*.

    Registers the Ray serializers, creates a default ``oscar`` session for the
    given address, then checks that ``execute`` hands back the very same
    tensor object.  Returns ``True`` once that check passes.
    """
    register_ray_serializers()
    new_session(web_address, backend='oscar', default=True)
    rng = np.random.RandomState(0)
    data = rng.rand(10, 5)
    row_sums = mt.tensor(data, chunk_size=5).sum(axis=1)
    result = row_sums.execute(show_progress=False)
    assert result is row_sums
    return True
コード例 #3
0
ファイル: conftest.py プロジェクト: haijohn/mars
def ray_start_regular(request):
    """Run a test against a local single-node Ray instance.

    ``request.param`` may be a dict.  ``{'enable': False}`` skips Ray
    start-up entirely and the fixture yields ``None``.  Otherwise
    ``num_cpus`` (default 10) is passed to ``ray.init`` and the fixture
    yields whatever ``ray.init`` returns.

    Cleanup runs in a ``finally`` so Ray is shut down and serializer/router
    state is reset even if ``ray.init`` raises or the generator is closed
    early.
    """
    param = getattr(request, "param", {})
    if not param.get('enable', True):
        yield
    else:
        register_ray_serializers()
        try:
            yield ray.init(num_cpus=param.get('num_cpus', 10))
        finally:
            ray.shutdown()
            unregister_ray_serializers()
            Router.set_instance(None)
コード例 #4
0
ファイル: conftest.py プロジェクト: ConanoutlooklvTBS/mars
def ray_start_regular(request):
    """Start a local Ray instance with 20 CPUs around a test.

    Parametrize with ``{'enable': False}`` to skip Ray entirely; the fixture
    then yields ``None``.  Otherwise it yields the return value of
    ``ray.init`` and, on teardown, shuts Ray down and resets the serializer,
    router and server state.
    """
    options = getattr(request, "param", {})
    if not options.get('enable', True):
        # Ray disabled for this test: nothing to set up or tear down.
        yield
        return

    register_ray_serializers()
    try:
        yield ray.init(num_cpus=20)
    finally:
        ray.shutdown()
        unregister_ray_serializers()
        Router.set_instance(None)
        RayServer.clear()
        collecting_coverage = 'COV_CORE_SOURCE' in os.environ
        if collecting_coverage:
            # Remove this when https://github.com/ray-project/ray/issues/16802 got fixed
            subprocess.check_call(["ray", "stop", "--force"])
コード例 #5
0
def actor_pool_context():
    """Start a ``RayMainPool`` actor and yield a synchronous proxy to it.

    The main pool is created as a named Ray actor.  Its name is the address
    returned by ``process_placement_to_address(pg_name, 0, process_index=0)``,
    and it is started with ``n_process`` processes (``pg_name`` and
    ``n_process`` come from the enclosing scope).  Attribute access on the
    yielded ``ProxyPool`` is forwarded to the actor's ``actor_pool`` method:
    plain functions defined on ``RayMainPool`` become blocking calls that
    return the remote result, and any other attribute is fetched by value.

    On teardown every pool actor is killed (best effort), the router
    singleton is reset and the Ray serializers are unregistered.  Teardown
    runs in a ``finally`` so it also happens if start-up fails.
    """
    from mars.serialization.ray import register_ray_serializers, unregister_ray_serializers
    register_ray_serializers()
    try:
        address = process_placement_to_address(pg_name, 0, process_index=0)
        # Hold actor_handle to avoid actor being freed.
        if hasattr(ray.util, "get_placement_group"):
            pg, bundle_index = ray.util.get_placement_group(pg_name), 0
        else:
            pg, bundle_index = None, -1
        actor_handle = ray.remote(RayMainPool).options(
            name=address,
            placement_group=pg,
            placement_group_bundle_index=bundle_index).remote()
        ray.get(actor_handle.start.remote(address, n_process))

        class ProxyPool:
            """Forward attribute access to the remote pool's ``actor_pool``."""

            def __init__(self, ray_pool_actor_handle):
                self.ray_pool_actor_handle = ray_pool_actor_handle

            def __getattr__(self, item):
                if inspect.isfunction(getattr(RayMainPool, item, None)):

                    def call(*args, **kwargs):
                        # Block on the remote call and hand back its result
                        # (previously the result was silently dropped).
                        return ray.get(
                            self.ray_pool_actor_handle.actor_pool.remote(
                                item, *args, **kwargs))

                    return call

                return ray.get(self.ray_pool_actor_handle.actor_pool.remote(item))

        yield ProxyPool(actor_handle)
    finally:
        for process_index in range(n_process):
            addr = process_placement_to_address(
                pg_name, 0, process_index=process_index)
            try:
                ray.kill(ray.get_actor(addr))
            except Exception:  # noqa: BLE001  # nosec
                # Best effort: the actor may never have started or may
                # already be gone; do not let that abort the cleanup.
                pass
        Router.set_instance(None)
        unregister_ray_serializers()
コード例 #6
0
ファイル: test_ray.py プロジェクト: fyrestone/mars
def _run_web_session(web_address):
    """Run ``test_local._run_web_session_test`` on a fresh event loop.

    Ray serializers are registered first.  The private event loop is
    created for this call only and is closed afterwards, even if the
    coroutine raises, instead of being left open.  Returns ``True`` once
    the coroutine completes.
    """
    register_ray_serializers()
    import asyncio
    loop = asyncio.new_event_loop()
    try:
        loop.run_until_complete(
            test_local._run_web_session_test(web_address))
    finally:
        loop.close()
    return True