Ejemplo n.º 1
0
def _ray_start_cluster(**kwargs):
    """Generator fixture body: build a local ``Cluster``, yield it, then tear down.

    Two keyword arguments are consumed here and never reach the nodes:

    * ``num_nodes`` -- number of nodes to add (default ``0``).
    * ``do_init`` -- whether the driver connects to the head node. When it is
      not given, it defaults to true if ``num_nodes > 0`` and false otherwise.

    Every other keyword argument overrides the defaults returned by
    ``get_default_fixture_ray_kwargs()`` and is passed to ``Cluster.add_node``
    for each node; ``_system_config`` is only passed to the first (head) node.

    Once the consumer resumes the generator, the Ray client test server is
    stopped (in client-test mode), then Ray and the cluster are shut down.
    """
    node_options = get_default_fixture_ray_kwargs()
    # Neither num_nodes nor do_init is a ray.init argument, so strip them.
    num_nodes = kwargs.pop("num_nodes", 0)
    if "do_init" in kwargs:
        do_init = kwargs.pop("do_init")
    else:
        do_init = num_nodes > 0
    node_options.update(kwargs)

    cluster = Cluster()
    remote_nodes = []
    for index in range(num_nodes):
        if index > 0:
            # Only the head node receives the custom system config.
            node_options.pop("_system_config", None)
        remote_nodes.append(cluster.add_node(**node_options))
        # The driver is expected to attach to the head (first) node, so
        # connect exactly once, right after that node has been added.
        if index == 0 and do_init:
            if client_test_enabled():
                ray_client.ray.init(address=cluster.address)
            else:
                ray.init(address=cluster.address)

    yield cluster

    # Teardown: runs once the generator is resumed after the yield.
    if client_test_enabled():
        ray_client.ray.disconnect()
        ray_client._stop_test_server(1)
        ray_client.reset_api()
    ray.shutdown()
    cluster.shutdown()
Ejemplo n.º 2
0
def _ray_start(**kwargs):
    """Generator fixture body: start Ray, yield its address info, then tear down.

    ``kwargs`` are layered on top of ``get_default_fixture_ray_kwargs()`` and
    passed to ``ray_client.ray.init`` when client tests are enabled, or to
    ``ray.init`` otherwise. Once the generator is resumed, the client test
    server is stopped (client-test mode only) and Ray is shut down.
    """
    init_options = get_default_fixture_ray_kwargs()
    init_options.update(kwargs)
    # Start the Ray processes, going through the client in client-test mode.
    start_ray = ray_client.ray.init if client_test_enabled() else ray.init
    address_info = start_ray(**init_options)

    yield address_info

    # Teardown: runs once the generator is resumed after the yield.
    if client_test_enabled():
        ray_client.ray.disconnect()
        ray_client._stop_test_server(1)
        ray_client.reset_api()
    ray.shutdown()
Ejemplo n.º 3
0
def shutdown_only():
    """Fixture body with no setup: yield ``None``, then clean up on teardown.

    Teardown disconnects from and stops the Ray client test server when
    client tests are enabled, and always calls ``ray.shutdown()``.
    """
    yield None

    # Teardown: runs once the generator is resumed after the yield.
    client_mode = client_test_enabled()
    if client_mode:
        client = ray_client
        client.ray.disconnect()
        client._stop_test_server(1)
        client.reset_api()
    ray.shutdown()
Ejemplo n.º 4
0
import sys
import time

import numpy as np
import pytest

import ray.cluster_utils
from ray.test_utils import (client_test_enabled)

import ray

# The import block above does not import `logging`; without this line the
# module-level logger raises NameError at import time.
import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# https://github.com/ray-project/ray/issues/6662
@pytest.mark.skipif(client_test_enabled(), reason="interferes with grpc")
def test_ignore_http_proxy(shutdown_only):
    """Ray tasks must still run after ``http(s)_proxy`` point at a bogus proxy.

    The proxy variables are set after ``ray.init`` and restored afterwards,
    so the fake proxy does not leak into other tests in the same process.
    """
    import os  # not imported by this module's import block

    ray.init(num_cpus=1)
    proxy_vars = ("http_proxy", "https_proxy")
    # Remember the caller's settings (None means "was unset").
    saved = {name: os.environ.get(name) for name in proxy_vars}
    try:
        for name in proxy_vars:
            os.environ[name] = "http://example.com"

        @ray.remote
        def f():
            return 1

        assert ray.get(f.remote()) == 1
    finally:
        for name, value in saved.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value


# https://github.com/ray-project/ray/issues/7263
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_grpc_message_size(shutdown_only):
Ejemplo n.º 5
0
import pytest

import ray.cluster_utils
from ray.test_utils import (
    client_test_enabled,
    dicts_equal,
    wait_for_pid_to_exit,
)

import ray

# The import block above does not import `logging`; without this line the
# module-level logger raises NameError at import time.
import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# https://github.com/ray-project/ray/issues/6662
@pytest.mark.skipif(client_test_enabled(), reason="interferes with grpc")
def test_ignore_http_proxy(shutdown_only):
    """Ray tasks must still run after ``http(s)_proxy`` point at a bogus proxy.

    The proxy variables are set after ``ray.init`` and restored afterwards,
    so the fake proxy does not leak into other tests in the same process.
    """
    import os  # not imported by this module's import block

    ray.init(num_cpus=1)
    proxy_vars = ("http_proxy", "https_proxy")
    # Remember the caller's settings (None means "was unset").
    saved = {name: os.environ.get(name) for name in proxy_vars}
    try:
        for name in proxy_vars:
            os.environ[name] = "http://example.com"

        @ray.remote
        def f():
            return 1

        assert ray.get(f.remote()) == 1
    finally:
        for name, value in saved.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value


# https://github.com/ray-project/ray/issues/7263
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_grpc_message_size(shutdown_only):
Ejemplo n.º 6
0
import logging
import random
import sys
import threading
import time

import numpy as np
import pytest

import ray.cluster_utils
import ray.test_utils

from ray.test_utils import client_test_enabled
from ray.test_utils import RayTestTimeoutException

if client_test_enabled():
    from ray.util.client import ray
else:
    import ray

# Logger for this test module, keyed by the module's import name.
logger = logging.getLogger(name=__name__)


# issue https://github.com/ray-project/ray/issues/7105
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_internal_free(shutdown_only):
    ray.init(num_cpus=1)

    @ray.remote
    class Sampler:
        def sample(self):
Ejemplo n.º 7
0
import sys
import threading
import time

import numpy as np
import pytest

from unittest.mock import MagicMock, patch

import ray.cluster_utils
from ray.test_utils import client_test_enabled
from ray.tests.client_test_utils import create_remote_signal_actor
from ray.exceptions import GetTimeoutError
from ray.exceptions import RayTaskError

if client_test_enabled():
    from ray.util.client import ray
else:
    import ray

# The import block above does not import `logging`; without this line the
# module-level logger raises NameError at import time.
import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# NOTE(review): with indirect=True, pytest passes each dict below to the
# `shutdown_only` fixture as `request.param`. The `shutdown_only` body shown
# elsewhere in this file takes no `request` argument, and the test itself calls
# ray.init(num_cpus=1) without local_mode, so "local_mode" may never be
# applied -- confirm the fixture actually in use consumes these params.
@pytest.mark.parametrize("shutdown_only", [{
    "local_mode": True
}, {
    "local_mode": False
}],
                         indirect=True)
def test_variable_number_of_args(shutdown_only):
    """Run once per `local_mode` parameter of the `shutdown_only` fixture.

    Starts Ray with a single CPU; the remainder of the test body is not
    visible in this excerpt.
    """
    ray.init(num_cpus=1)
Ejemplo n.º 8
0
import sys
import threading
import time

import numpy as np
import pytest

from unittest.mock import MagicMock, patch

import ray.cluster_utils
from ray.test_utils import client_test_enabled
from ray.tests.client_test_utils import create_remote_signal_actor
from ray.exceptions import GetTimeoutError
from ray.exceptions import RayTaskError

if client_test_enabled():
    from ray.util.client import ray
else:
    import ray

# The import block above does not import `logging`; without this line the
# module-level logger raises NameError at import time.
import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# NOTE(review): with indirect=True, pytest passes each dict below to the
# `shutdown_only` fixture as `request.param`. The `shutdown_only` body shown
# elsewhere in this file takes no `request` argument, and the test itself calls
# ray.init(num_cpus=1) without local_mode, so "local_mode" may never be
# applied -- confirm the fixture actually in use consumes these params.
@pytest.mark.parametrize("shutdown_only", [{
    "local_mode": True
}, {
    "local_mode": False
}],
                         indirect=True)
def test_variable_number_of_args(shutdown_only):
    """Run once per `local_mode` parameter of the `shutdown_only` fixture.

    Starts Ray with a single CPU; the remainder of the test body is not
    visible in this excerpt.
    """
    ray.init(num_cpus=1)
Ejemplo n.º 9
0
import logging
import random
import sys
import threading
import time

import numpy as np
import pytest

import ray.cluster_utils
import ray.test_utils

from ray.test_utils import client_test_enabled
from ray.test_utils import RayTestTimeoutException

if client_test_enabled():
    from ray.util.client import ray
else:
    import ray

# Logger for this test module, keyed by the module's import name.
logger = logging.getLogger(name=__name__)


# issue https://github.com/ray-project/ray/issues/7105
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_internal_free(shutdown_only):
    ray.init(num_cpus=1)

    @ray.remote
    class Sampler:
        def sample(self):
Ejemplo n.º 10
0
        @ray.remote
        class Actor:
            def __init__(self):
                # This should use the last version of f.
                self.x = ray.get(f.remote())

            def get_val(self):
                return self.x

        actor = Actor.remote()
        return ray.get(actor.get_val.remote())

    assert ray.get(g.remote()) == num_remote_functions - 1


# Skipped under the Ray client because it inspects a private attribute
# (ActorClassMethodMetadata._cache).
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_actor_method_metadata_cache(ray_start_regular):
    """Check that actor method metadata is cached and reused.

    Asserts that making a plain class remote and instantiating it leaves
    exactly one entry in ``ActorClassMethodMetadata._cache``, then repeatedly
    pickles/unpickles the actor handle. The rest of the body (presumably
    comparing the cache against ``cached_data_id``) is not visible in this
    excerpt.
    """
    class Actor(object):
        pass

    # The cache of ActorClassMethodMetadata.
    cache = ray.actor.ActorClassMethodMetadata._cache
    # Start from an empty cache so the length check below is exact.
    cache.clear()

    # Check cache hit during ActorHandle deserialization.
    A1 = ray.remote(Actor)
    a = A1.remote()
    assert len(cache) == 1
    # ids of the single cached (key, value) pair; presumably compared after
    # the loop to prove the entry was not replaced -- not visible here.
    cached_data_id = [id(x) for x in list(cache.items())[0]]
    # NOTE(review): `pickle` is not imported in the visible excerpt.
    for x in range(10):
        a = pickle.loads(pickle.dumps(a))
Ejemplo n.º 11
0
import logging
import random
import sys
import threading
import time

import numpy as np
import pytest

import ray.cluster_utils
import ray.test_utils

from ray.test_utils import client_test_enabled
from ray.test_utils import RayTestTimeoutException

if client_test_enabled():
    from ray.util.client import ray
else:
    import ray

# Logger for this test module, keyed by the module's import name.
logger = logging.getLogger(name=__name__)


# issue https://github.com/ray-project/ray/issues/7105
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_internal_free(shutdown_only):
    ray.init(num_cpus=1)

    @ray.remote
    class Sampler:
        def sample(self):
Ejemplo n.º 12
0
        @ray.remote
        class Actor:
            def __init__(self):
                # This should use the last version of f.
                self.x = ray.get(f.remote())

            def get_val(self):
                return self.x

        actor = Actor.remote()
        return ray.get(actor.get_val.remote())

    assert ray.get(g.remote()) == num_remote_functions - 1


# Skipped under the Ray client because it inspects a private attribute
# (ActorClassMethodMetadata._cache).
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_actor_method_metadata_cache(ray_start_regular):
    """Check that actor method metadata is cached and reused.

    Asserts that making a plain class remote and instantiating it leaves
    exactly one entry in ``ActorClassMethodMetadata._cache``, then repeatedly
    pickles/unpickles the actor handle. The rest of the body (presumably
    comparing the cache against ``cached_data_id``) is not visible in this
    excerpt.
    """
    class Actor(object):
        pass

    # The cache of ActorClassMethodMetadata.
    cache = ray.actor.ActorClassMethodMetadata._cache
    # Start from an empty cache so the length check below is exact.
    cache.clear()

    # Check cache hit during ActorHandle deserialization.
    A1 = ray.remote(Actor)
    a = A1.remote()
    assert len(cache) == 1
    # ids of the single cached (key, value) pair; presumably compared after
    # the loop to prove the entry was not replaced -- not visible here.
    cached_data_id = [id(x) for x in list(cache.items())[0]]
    # NOTE(review): `pickle` is not imported in the visible excerpt.
    for x in range(10):
        a = pickle.loads(pickle.dumps(a))