def __init__(self, server_config) -> None:
    """Connect to (and optionally launch) the inference server.

    Steps, in order:
      * reset model/classifier bookkeeping attributes;
      * resolve ``server_config.address`` to an IP address;
      * choose a launcher: ``_NullLauncher`` when ``autostart`` is off,
        ``LocalServerLauncher`` when the address resolves to 127.0.0.1,
        ``RemoteSSHServerLauncher`` (authenticated with the configured
        username and SSH key) otherwise, then start it;
      * open an insecure gRPC channel with 100 MB send/receive limits and
        build an ``InferenceStub`` on it;
      * remember the ids of the devices marked as enabled.
    """
    self._tikTorchClassifier = None
    self._train_model = None
    self._shutdown_sent = False

    message_limit = 100 * 1024 * 1024  # 100 MB, used for both directions

    host_ip = socket.gethostbyname(server_config.address)
    host_port = server_config.port
    conn_conf = ConnConf(host_ip, host_port, timeout=20)

    if not server_config.autostart:
        # autostart disabled: presumably the server is managed externally.
        self.launcher = _NullLauncher()
    elif host_ip == "127.0.0.1":
        self.launcher = LocalServerLauncher(conn_conf, path=server_config.path)
    else:
        ssh_cred = SSHCred(server_config.username, key_path=server_config.ssh_key)
        self.launcher = RemoteSSHServerLauncher(conn_conf, cred=ssh_cred, path=server_config.path)
    self.launcher.start()

    channel_options = [
        ("grpc.max_send_message_length", message_limit),
        ("grpc.max_receive_message_length", message_limit),
    ]
    self._chan = grpc.insecure_channel(f"{host_ip}:{host_port}", options=channel_options)
    self._tikTorchClient = inference_pb2_grpc.InferenceStub(self._chan)
    self._devices = [device.id for device in server_config.devices if device.enabled]
def test_start_remote_server(srv_port):
    """Launch a server over SSH, check it answers a ping, then stop it.

    Skipped unless the SSH host and user environment variables are set,
    together with either a password or a key path.
    """
    host, ssh_port = os.getenv(SSH_HOST_VAR), os.getenv(SSH_PORT_VAR, 22)
    user, pwd = os.getenv(SSH_USER_VAR), os.getenv(SSH_PWD_VAR)
    key = os.getenv(SSH_KEY_VAR)
    # NOTE(review): ssh_port only takes part in this skip check; it is not
    # passed to SSHCred or RemoteSSHServerLauncher below, so a non-default
    # SSH port has no visible effect here — confirm whether it should be.
    if not all([host, ssh_port, user, key or pwd]):
        pytest.skip(
            "To run this test specify "
            f"{SSH_HOST_VAR}, {SSH_USER_VAR}, {SSH_PWD_VAR} or {SSH_KEY_VAR} and optionally {SSH_PORT_VAR}"
        )
    conn_conf = ConnConf(socket.gethostbyname(host), srv_port, timeout=20)
    cred = SSHCred(user=user, password=pwd, key_path=key)
    launcher = RemoteSSHServerLauncher(conn_conf, cred=cred)
    client = client_factory(conn_conf)
    try:
        launcher.start()
        assert launcher.is_server_running()
        assert client.ping()
    finally:
        # Always stop the remote server, even when an assertion above failed.
        launcher.stop()
    # Checked outside `finally` so a failure here can never replace the
    # original exception raised inside the `try` block.
    assert not launcher.is_server_running()
def test_start_local_server(srv_port):
    """Launch a local server, check it answers a ping, and always stop it."""
    conn_conf = ConnConf("127.0.0.1", srv_port, timeout=5)
    launcher = LocalServerLauncher(conn_conf)
    try:
        launcher.start()
        assert launcher.is_server_running()
        client = client_factory(conn_conf)
        assert client.ping()
    finally:
        # Stop even when an assertion fails so the server process/port is not
        # leaked into subsequent tests.
        launcher.stop()
def _fetch_devices(config: types.ServerConfig):
    """Query the server described by ``config`` for its device list.

    When ``config.autostart`` is set, a temporary server is launched
    (locally if the address resolves to 127.0.0.1, over SSH otherwise) on a
    port 20 below the configured one, and it is stopped again afterwards.
    Without autostart, a ``_NullLauncher`` is used and the configured port is
    queried directly.

    Returns:
        A list of ``(device_id, device_id)`` tuples, one per device reported
        by ``ListDevices``.

    Raises:
        Any exception from address resolution, launching, or the RPC. It is
        logged once and re-raised.
    """
    launcher = None
    try:
        port = config.port
        if config.autostart:
            # TODO: remove port hack. The temporary server uses a shifted port
            # so it does not block the address needed by the real server.
            port = str(int(config.port) - 20)
        addr = socket.gethostbyname(config.address)
        conn_conf = ConnConf(addr, port, timeout=10)

        if not config.autostart:
            launcher = _NullLauncher()
        elif addr == "127.0.0.1":
            launcher = LocalServerLauncher(conn_conf, path=config.path)
        else:
            launcher = RemoteSSHServerLauncher(
                conn_conf, cred=SSHCred(user=config.username, key_path=config.ssh_key), path=config.path
            )

        launcher.start()
        with grpc.insecure_channel(f"{addr}:{port}") as chan:
            client = inference_pb2_grpc.InferenceStub(chan)
            resp = client.ListDevices(inference_pb2.Empty())
            return [(d.id, d.id) for d in resp.devices]
    except Exception:
        logger.exception("Failed to fetch devices")
        raise
    finally:
        if launcher is not None:
            # Cleanup is best-effort: never let a stop failure mask the result
            # or the original error, but do not hide it silently either.
            try:
                launcher.stop()
            except Exception:
                logger.warning("Failed to stop server launcher after fetching devices", exc_info=True)
def ensure_connection(self, config):
    """Return the cached connection, creating it on first use.

    If ``self._connection`` is already set (truthy), it is returned
    unchanged and ``config`` is ignored. Otherwise:
      * ``config.address`` is resolved to an IP address;
      * a launcher is chosen (``_NullLauncher`` when ``config.autostart`` is
        off, ``LocalServerLauncher`` for 127.0.0.1, ``RemoteSSHServerLauncher``
        otherwise) and started;
      * an insecure gRPC channel with 100 MB send/receive limits is opened;
      * a ``Connection`` wrapping an inference stub and a data-store stub on
        that channel is cached and returned.

    Side effects on first use: sets ``self.launcher``, ``self._chan``,
    ``self._devices`` (ids of enabled devices in ``config``) and
    ``self._connection``.
    """
    if self._connection:
        return self._connection

    message_limit = 100 * 1024 * 1024  # 100 MB, used for both directions

    host_ip = socket.gethostbyname(config.address)
    host_port = config.port
    conn_conf = ConnConf(host_ip, host_port, timeout=20)

    if not config.autostart:
        # autostart disabled: presumably the server is managed externally.
        self.launcher = _NullLauncher()
    elif host_ip == "127.0.0.1":
        self.launcher = LocalServerLauncher(conn_conf, path=config.path)
    else:
        ssh_cred = SSHCred(config.username, key_path=config.ssh_key)
        self.launcher = RemoteSSHServerLauncher(conn_conf, cred=ssh_cred, path=config.path)
    self.launcher.start()

    channel_options = [
        ("grpc.max_send_message_length", message_limit),
        ("grpc.max_receive_message_length", message_limit),
    ]
    self._chan = grpc.insecure_channel(f"{host_ip}:{host_port}", options=channel_options)

    inference_stub = inference_pb2_grpc.InferenceStub(self._chan)
    data_store_stub = data_store_pb2_grpc.DataStoreStub(self._chan)
    self._devices = [device.id for device in config.devices if device.enabled]
    self._connection = Connection(inference_stub, data_store_stub)
    return self._connection