Ejemplo n.º 1
0
    def __init__(self, server_config) -> None:
        """Start (or attach to) the inference server and open a gRPC client.

        ``server_config.address`` is resolved to an IP address.  When
        ``server_config.autostart`` is set, a launcher is chosen:
        ``LocalServerLauncher`` for ``127.0.0.1``, otherwise an SSH-based
        ``RemoteSSHServerLauncher``.  Without autostart a ``_NullLauncher``
        is used (presumably a no-op — confirm).  The launcher is started, then
        an insecure gRPC channel with 100 MiB send/receive message limits is
        opened and wrapped in an ``InferenceStub``.  The ids of the enabled
        entries in ``server_config.devices`` are remembered.
        """
        message_limit = 100 * 1024 * 1024  # 100 MiB, for both directions
        self._tikTorchClassifier = None
        self._train_model = None
        self._shutdown_sent = False

        host = socket.gethostbyname(server_config.address)
        port = server_config.port
        conn_conf = ConnConf(host, port, timeout=20)

        if not server_config.autostart:
            self.launcher = _NullLauncher()
        elif host == "127.0.0.1":
            self.launcher = LocalServerLauncher(conn_conf, path=server_config.path)
        else:
            self.launcher = RemoteSSHServerLauncher(
                conn_conf,
                cred=SSHCred(server_config.username, key_path=server_config.ssh_key),
                path=server_config.path,
            )
        self.launcher.start()

        grpc_options = [
            ("grpc.max_send_message_length", message_limit),
            ("grpc.max_receive_message_length", message_limit),
        ]
        self._chan = grpc.insecure_channel(f"{host}:{port}", options=grpc_options)
        self._tikTorchClient = inference_pb2_grpc.InferenceStub(self._chan)
        self._devices = [device.id for device in server_config.devices if device.enabled]
Ejemplo n.º 2
0
def test_start_remote_server(srv_port):
    """Start a server over SSH and check that a client can ping it.

    Skipped unless the SSH host, user and either a password or a key path are
    provided through environment variables.  The launcher is always stopped,
    and afterwards the server must no longer be reported as running.
    """
    host, ssh_port = os.getenv(SSH_HOST_VAR), os.getenv(SSH_PORT_VAR, 22)
    user, pwd = os.getenv(SSH_USER_VAR), os.getenv(SSH_PWD_VAR)
    key = os.getenv(SSH_KEY_VAR)

    # NOTE(review): ssh_port only participates in this check; it is never passed
    # to SSHCred / RemoteSSHServerLauncher — confirm whether a custom SSH port is honoured.
    if not all([host, ssh_port, user, key or pwd]):
        pytest.skip(
            "To run this test specify "
            f"{SSH_HOST_VAR}, {SSH_USER_VAR}, {SSH_PWD_VAR} or {SSH_KEY_VAR} and optionally {SSH_PORT_VAR}"
        )

    conn_conf = ConnConf(socket.gethostbyname(host), srv_port, timeout=20)
    cred = SSHCred(user=user, password=pwd, key_path=key)
    launcher = RemoteSSHServerLauncher(conn_conf, cred=cred)

    client = client_factory(conn_conf)
    try:
        launcher.start()

        assert launcher.is_server_running()

        assert client.ping()
    finally:
        # Stop the remote server even when an assertion above fails.
        launcher.stop()

    assert not launcher.is_server_running()
Ejemplo n.º 3
0
def test_start_local_server(srv_port):
    """Start a local server launcher and check that a client can ping it.

    The launcher is stopped in ``finally`` so that a failing assertion does not
    leave the launched server running (which may keep ``srv_port`` busy for
    later tests).
    """
    conn_conf = ConnConf("127.0.0.1", srv_port, timeout=5)
    launcher = LocalServerLauncher(conn_conf)
    try:
        launcher.start()

        assert launcher.is_server_running()

        client = client_factory(conn_conf)

        assert client.ping()
    finally:
        launcher.stop()
Ejemplo n.º 4
0
def _fetch_devices(config: types.ServerConfig):
    """Query the server described by *config* for its device list.

    When ``config.autostart`` is set, a temporary server is launched (locally
    for ``127.0.0.1``, over SSH otherwise) on ``config.port - 20``.  The
    launcher is stopped again once the query is done.

    Returns:
        A list of ``(device_id, device_id)`` pairs, one per reported device.

    Raises:
        Any exception from address resolution, launcher setup, server start or
        the ``ListDevices`` call; each is logged once and re-raised.
    """
    try:
        port = config.port
        if config.autostart:
            # TODO: remove port hack. The temporary server uses a different port
            # so it does not block the address for the real server.
            port = str(int(config.port) - 20)

        addr = socket.gethostbyname(config.address)
        conn_conf = ConnConf(addr, port, timeout=10)

        if config.autostart:
            if addr == "127.0.0.1":
                launcher = LocalServerLauncher(conn_conf, path=config.path)
            else:
                launcher = RemoteSSHServerLauncher(
                    conn_conf, cred=SSHCred(user=config.username, key_path=config.ssh_key), path=config.path
                )
        else:
            launcher = _NullLauncher()
    except Exception as e:
        # Setup failed before launcher.start() was called.
        logger.error(e)
        raise

    try:
        launcher.start()
        with grpc.insecure_channel(f"{addr}:{port}") as chan:
            client = inference_pb2_grpc.InferenceStub(chan)
            resp = client.ListDevices(inference_pb2.Empty())
            return [(d.id, d.id) for d in resp.devices]
    except Exception:
        # Logged here (with traceback) only; the setup handler above never sees these errors.
        logger.exception("Failed to fetch devices")
        raise
    finally:
        # Best-effort cleanup: a failing stop() must not mask the result or the
        # original error, but it is logged rather than silently dropped.
        try:
            launcher.stop()
        except Exception:
            logger.warning("Failed to stop server launcher", exc_info=True)
Ejemplo n.º 5
0
    def ensure_connection(self, config):
        """Return the cached server connection, creating it on first use.

        If ``self._connection`` is already truthy it is returned unchanged.
        Otherwise ``config.address`` is resolved and a launcher is chosen:
        ``_NullLauncher`` without autostart, ``LocalServerLauncher`` for
        ``127.0.0.1``, ``RemoteSSHServerLauncher`` otherwise.  The launcher is
        stored on ``self.launcher`` and started.  An insecure gRPC channel with
        100 MiB send/receive limits is then opened, the enabled device ids
        are recorded, and a ``Connection`` wrapping inference and data-store
        stubs is cached and returned.
        """
        if self._connection:
            return self._connection

        message_limit = 100 * 1024 * 1024  # 100 MiB, for both directions
        host = socket.gethostbyname(config.address)
        port = config.port
        conn_conf = ConnConf(host, port, timeout=20)

        if not config.autostart:
            self.launcher = _NullLauncher()
        elif host == "127.0.0.1":
            self.launcher = LocalServerLauncher(conn_conf, path=config.path)
        else:
            ssh_cred = SSHCred(config.username, key_path=config.ssh_key)
            self.launcher = RemoteSSHServerLauncher(conn_conf, cred=ssh_cred, path=config.path)

        self.launcher.start()

        channel_options = [
            ("grpc.max_send_message_length", message_limit),
            ("grpc.max_receive_message_length", message_limit),
        ]
        self._chan = grpc.insecure_channel(f"{host}:{port}", options=channel_options)
        inference_stub = inference_pb2_grpc.InferenceStub(self._chan)
        data_store_stub = data_store_pb2_grpc.DataStoreStub(self._chan)
        self._devices = [device.id for device in config.devices if device.enabled]
        self._connection = Connection(inference_stub, data_store_stub)
        return self._connection