コード例 #1
0
ファイル: test_dashboard.py プロジェクト: zzmcdc/ray
def test_dashboard_module_decorator(enable_test_module):
    """Test modules are only discovered while RAY_DASHBOARD_MODULE_TEST is set.

    In this process (where the ``enable_test_module`` fixture is presumably
    setting the variable) the ``TestHead``/``TestAgent`` modules must be
    found. A separate driver that removes the variable must not see them.
    """
    head_names = {
        module_cls.__name__
        for module_cls in dashboard_utils.get_all_modules(
            dashboard_utils.DashboardHeadModule)
    }
    agent_names = {
        module_cls.__name__
        for module_cls in dashboard_utils.get_all_modules(
            dashboard_utils.DashboardAgentModule)
    }

    assert "TestHead" in head_names
    assert "TestAgent" in agent_names

    # Executed in a fresh driver process so module discovery starts clean.
    driver_script = """
import os
import ray.new_dashboard.utils as dashboard_utils

os.environ.pop("RAY_DASHBOARD_MODULE_TEST")
head_cls_list = dashboard_utils.get_all_modules(
        dashboard_utils.DashboardHeadModule)
agent_cls_list = dashboard_utils.get_all_modules(
        dashboard_utils.DashboardAgentModule)
print(head_cls_list)
print(agent_cls_list)
assert all(cls.__name__ != "TestHead" for cls in head_cls_list)
assert all(cls.__name__ != "TestAgent" for cls in agent_cls_list)
print("success")
"""
    run_string_as_driver(driver_script)
コード例 #2
0
ファイル: agent.py プロジェクト: xiaorancs/ray
 def __init__(self,
              redis_address,
              redis_password=None,
              temp_dir=None,
              log_dir=None,
              node_manager_port=None,
              object_store_name=None,
              raylet_name=None):
     """Initialize the DashboardAgent object.

     Args:
         redis_address: Redis address in ``"host:port"`` form. Stored as an
             ``(ip, int(port))`` tuple in ``self.redis_address``.
         redis_password: Redis password, stored as-is.
         temp_dir: Stored as ``self.temp_dir`` for agent modules.
         log_dir: Stored as ``self.log_dir`` for agent modules.
         node_manager_port: Port of the local raylet; used together with
             this node's IP to open ``self.aiogrpc_raylet_channel``.
         object_store_name: Stored as ``self.object_store_name``.
         raylet_name: Stored as ``self.raylet_name``.
     """
     self._agent_cls_list = dashboard_utils.get_all_modules(
         dashboard_utils.DashboardAgentModule)
     # Split on the last ':' only, so a host part that itself contains ':'
     # (e.g. an IPv6 literal) does not break the two-value unpacking.
     ip, port = redis_address.rsplit(":", 1)
     # Public attributes are accessible for all agent modules.
     self.redis_address = (ip, int(port))
     self.redis_password = redis_password
     self.temp_dir = temp_dir
     self.log_dir = log_dir
     self.node_manager_port = node_manager_port
     self.object_store_name = object_store_name
     self.raylet_name = raylet_name
     self.ip = ray.services.get_node_ip_address()
     self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
     listen_address = "[::]:0"
     # Port 0 lets gRPC choose a free port; add_insecure_port returns the
     # port actually bound, so log that instead of the ":0" placeholder.
     self.port = self.server.add_insecure_port(listen_address)
     logger.info("Dashboard agent listen at: %s",
                 "[::]:{}".format(self.port))
     self.aioredis_client = None
     self.aiogrpc_raylet_channel = aiogrpc.insecure_channel("{}:{}".format(
         self.ip, self.node_manager_port))
     self.http_session = aiohttp.ClientSession(
         loop=asyncio.get_event_loop())
コード例 #3
0
ファイル: head.py プロジェクト: nikitavemuri/ray
 def _load_modules(self):
     """Instantiate every discovered dashboard head module.

     Each module class is constructed with this head object as its only
     argument, and its routes are bound via
     ``dashboard_utils.ClassMethodRouteTable.bind``.

     Returns:
         list: The module instances, in the order they were discovered.
     """
     module_classes = dashboard_utils.get_all_modules(
         dashboard_utils.DashboardHeadModule)
     base_name = dashboard_utils.DashboardHeadModule.__name__
     loaded = []
     for module_cls in module_classes:
         logger.info("Loading %s: %s", base_name, module_cls)
         instance = module_cls(self)
         dashboard_utils.ClassMethodRouteTable.bind(instance)
         loaded.append(instance)
     logger.info("Loaded %d modules.", len(loaded))
     return loaded
コード例 #4
0
ファイル: agent.py プロジェクト: zhengpw/ray
 def _load_modules(self):
     """Instantiate every discovered dashboard agent module.

     Each module class is constructed with this agent object as its only
     argument, and its routes are bound via
     ``dashboard_utils.ClassMethodRouteTable.bind``.

     Returns:
         list: The module instances, in the order they were discovered.
     """
     modules = []
     agent_cls_list = dashboard_utils.get_all_modules(
         dashboard_utils.DashboardAgentModule)
     for cls in agent_cls_list:
         logger.info("Loading %s: %s",
                     dashboard_utils.DashboardAgentModule.__name__, cls)
         c = cls(self)
         dashboard_utils.ClassMethodRouteTable.bind(c)
         modules.append(c)
     # Lazy %-style args (as in the head loader): the message is only
     # formatted if INFO logging is enabled.
     logger.info("Loaded %d modules.", len(modules))
     return modules
コード例 #5
0
ファイル: head.py プロジェクト: xiaorancs/ray
 def __init__(self, redis_address, redis_password):
     """Initialize the dashboard head state.

     Args:
         redis_address: Redis address in ``"host:port"`` form. Stored as an
             ``(ip, int(port))`` tuple in ``self.redis_address``.
         redis_password: Redis password, stored as-is.
     """
     # Scan and import head modules for collecting http routes.
     self._head_cls_list = dashboard_utils.get_all_modules(
         dashboard_utils.DashboardHeadModule)
     # Split on the last ':' only, so a host part that itself contains ':'
     # (e.g. an IPv6 literal) does not break the two-value unpacking.
     ip, port = redis_address.rsplit(":", 1)
     # NodeInfoGcsService
     self._gcs_node_info_stub = None
     self._gcs_rpc_error_counter = 0
     # Public attributes are accessible for all head modules.
     self.redis_address = (ip, int(port))
     self.redis_password = redis_password
     # Left as None here; presumably connected later during startup.
     self.aioredis_client = None
     self.aiogrpc_gcs_channel = None
     self.http_session = aiohttp.ClientSession(
         loop=asyncio.get_event_loop())
     self.ip = ray.services.get_node_ip_address()
コード例 #6
0
ファイル: test_dashboard.py プロジェクト: zzmcdc/ray
def test_class_method_route_table(enable_test_module):
    """Exercise ``dashboard_utils.ClassMethodRouteTable``.

    Covers route registration for every HTTP method, ``bind()``, static
    routes, duplicate-route detection, and error reporting from a handler
    that raises.
    """
    head_cls_list = dashboard_utils.get_all_modules(
        dashboard_utils.DashboardHeadModule)
    agent_cls_list = dashboard_utils.get_all_modules(
        dashboard_utils.DashboardAgentModule)
    test_head_cls = next(
        (cls for cls in head_cls_list if cls.__name__ == "TestHead"), None)
    assert test_head_cls is not None
    test_agent_cls = next(
        (cls for cls in agent_cls_list if cls.__name__ == "TestAgent"), None)
    assert test_agent_cls is not None

    def _has_route(route, method, path):
        return (isinstance(route, aiohttp.web.RouteDef)
                and route.method == method and route.path == path)

    def _has_static(route, path, prefix):
        return (isinstance(route, aiohttp.web.StaticDef)
                and route.path == path and route.prefix == prefix)

    all_routes = dashboard_utils.ClassMethodRouteTable.routes()
    assert any(_has_route(r, "HEAD", "/test/route_head") for r in all_routes)
    assert any(_has_route(r, "GET", "/test/route_get") for r in all_routes)
    assert any(_has_route(r, "POST", "/test/route_post") for r in all_routes)
    assert any(_has_route(r, "PUT", "/test/route_put") for r in all_routes)
    assert any(_has_route(r, "PATCH", "/test/route_patch") for r in all_routes)
    assert any(
        _has_route(r, "DELETE", "/test/route_delete") for r in all_routes)
    assert any(_has_route(r, "*", "/test/route_view") for r in all_routes)

    # Test bind(): nothing is bound until an instance is bound.
    bound_routes = dashboard_utils.ClassMethodRouteTable.bound_routes()
    assert len(bound_routes) == 0
    # __new__ yields an instance without running __init__.
    dashboard_utils.ClassMethodRouteTable.bind(
        test_agent_cls.__new__(test_agent_cls))
    bound_routes = dashboard_utils.ClassMethodRouteTable.bound_routes()
    assert any(_has_route(r, "POST", "/test/route_post") for r in bound_routes)
    assert all(not _has_route(r, "PUT", "/test/route_put")
               for r in bound_routes)

    # Static def should be in bound routes.
    routes.static("/test/route_static", "/path")
    bound_routes = dashboard_utils.ClassMethodRouteTable.bound_routes()
    assert any(
        _has_static(r, "/path", "/test/route_static") for r in bound_routes)

    # Test duplicated routes should raise exception. The failure is raised
    # in the ``else`` branch so it is not swallowed by the ``except`` below.
    try:

        @routes.get("/test/route_get")
        def _duplicated_route(req):
            pass

    except Exception as ex:
        message = str(ex)
        assert "/test/route_get" in message
        assert "test_head.py" in message
    else:
        raise AssertionError("Duplicated routes should raise exception.")

    # Test exception in handler
    post_handler = next(
        (r.handler for r in bound_routes
         if _has_route(r, "POST", "/test/route_post")), None)
    assert post_handler is not None

    loop = asyncio.get_event_loop()
    response = loop.run_until_complete(post_handler())
    assert response.status == 200
    resp = json.loads(response.body)
    assert resp["result"] is False
    assert "Traceback" in resp["msg"]