Example no. 1
def test_start_remote_server(srv_port):
    host, ssh_port = os.getenv(SSH_HOST_VAR), os.getenv(SSH_PORT_VAR, 22)
    user, pwd = os.getenv(SSH_USER_VAR), os.getenv(SSH_PWD_VAR)
    key = os.getenv(SSH_KEY_VAR)

    if not all([host, ssh_port, user, key or pwd]):
        pytest.skip(
            "To run this test specify "
            f"{SSH_HOST_VAR}, {SSH_USER_VAR} {SSH_PWD_VAR} or {SSH_KEY_VAR} and optionaly {SSH_PORT_VAR}"
        )

    conn_conf = ConnConf(socket.gethostbyname(host), srv_port, timeout=20)
    cred = SSHCred(user=user, password=pwd, key_path=key)
    launcher = RemoteSSHServerLauncher(conn_conf, cred=cred)

    client = client_factory(conn_conf)
    try:
        launcher.start()

        assert launcher.is_server_running()

        assert client.ping()
    finally:
        # Stop the remote server even if the assertions above fail.
        launcher.stop()

    assert not launcher.is_server_running()
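
The test above skips itself unless SSH connection details are available in the environment. A minimal sketch of how they might be provided, e.g. from a conftest.py or a small wrapper script, is shown below; the concrete variable names and values are placeholders, since the SSH_*_VAR constants are defined elsewhere in the test module.

import os

# Hypothetical variable names: the test reads whatever names SSH_HOST_VAR,
# SSH_USER_VAR, SSH_KEY_VAR / SSH_PWD_VAR and SSH_PORT_VAR resolve to.
os.environ["TEST_SSH_HOST"] = "gpu-box.example.org"       # host running sshd
os.environ["TEST_SSH_USER"] = "ci-user"                    # SSH login user
os.environ["TEST_SSH_KEY"] = "/home/ci-user/.ssh/id_rsa"   # or set a password instead
# Either the key or the password variable must be set; the SSH port defaults to 22.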
Example no. 2
class TikTorchLazyflowClassifierFactory:
    # The version is used to determine compatibility of pickled classifier factories.
    # You must bump this if any instance members are added/removed/renamed.
    VERSION = 1

    def create_model_session(self, model_str: bytes, devices: List[str]):
        session = self._tikTorchClient.CreateModelSession(
            inference_pb2.CreateModelSessionRequest(
                model_blob=inference_pb2.Blob(content=model_str),
                deviceIds=devices))
        return ModelSession(session, self)

    def __init__(self, server_config) -> None:
        _100_MB = 100 * 1024 * 1024
        self._tikTorchClassifier = None
        self._train_model = None
        self._shutdown_sent = False

        # Resolve the configured server address to an IP and build the connection settings.
        addr = socket.gethostbyname(server_config.address)
        port = server_config.port
        conn_conf = ConnConf(addr, port, timeout=20)

        # Depending on the configuration, start the server locally, over SSH,
        # or skip launching entirely (_NullLauncher).
        if server_config.autostart:
            if addr == "127.0.0.1":
                self.launcher = LocalServerLauncher(conn_conf,
                                                    path=server_config.path)
            else:
                self.launcher = RemoteSSHServerLauncher(
                    conn_conf,
                    cred=SSHCred(server_config.username,
                                 key_path=server_config.ssh_key),
                    path=server_config.path,
                )
        else:
            self.launcher = _NullLauncher()

        self.launcher.start()

        # Raise the default gRPC message size limits so large model blobs fit in a single message.
        self._chan = grpc.insecure_channel(
            f"{addr}:{port}",
            options=[("grpc.max_send_message_length", _100_MB),
                     ("grpc.max_receive_message_length", _100_MB)],
        )
        self._tikTorchClient = inference_pb2_grpc.InferenceStub(self._chan)
        self._devices = [d.id for d in server_config.devices if d.enabled]

    def shutdown(self):
        self._shutdown_sent = True
        self.launcher.stop()

    @property
    def tikTorchClient(self):
        return self._tikTorchClient

    @property
    def description(self):
        if self.tikTorchClient:
            return "TikTorch classifier (client available)"
        else:
            return "TikTorch classifier (client missing)"

    def __eq__(self, other):
        return isinstance(other, type(self))

    def __ne__(self, other):
        return not self.__eq__(other)

    def __del__(self):
        if not self._shutdown_sent:
            try:
                self.launcher.stop()
            except AttributeError:
                pass
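
For orientation, here is a hedged sketch of how this factory might be exercised end to end. The configuration object below is an assumption built with types.SimpleNamespace; it only carries the attributes __init__ actually reads (address, port, autostart, path, username, ssh_key, devices), and the port, server path and model archive are placeholders.

from types import SimpleNamespace

# Hypothetical stand-in for the project's server configuration object.
server_config = SimpleNamespace(
    address="127.0.0.1",          # local address -> LocalServerLauncher is used
    port=5567,                    # placeholder port
    autostart=True,
    path="/opt/tiktorch/bin/tiktorch-server",  # placeholder server executable
    username=None,
    ssh_key=None,
    devices=[SimpleNamespace(id="cpu", enabled=True)],
)

factory = TikTorchLazyflowClassifierFactory(server_config)
try:
    with open("model.zip", "rb") as f:        # placeholder model archive
        session = factory.create_model_session(f.read(), devices=["cpu"])
finally:
    factory.shutdown()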
Example no. 3
class TiktorchConnectionFactory(_base.IConnectionFactory):
    def ensure_connection(self, config):
        # Reuse the cached connection if one has already been established.
        if self._connection:
            return self._connection

        _100_MB = 100 * 1024 * 1024
        server_config = config
        # Resolve the configured server address to an IP and build the connection settings.
        addr = socket.gethostbyname(server_config.address)
        port = server_config.port
        conn_conf = ConnConf(addr, port, timeout=20)

        # Depending on the configuration, start the server locally, over SSH,
        # or skip launching entirely (_NullLauncher).
        if server_config.autostart:
            if addr == "127.0.0.1":
                self.launcher = LocalServerLauncher(conn_conf,
                                                    path=server_config.path)
            else:
                self.launcher = RemoteSSHServerLauncher(
                    conn_conf,
                    cred=SSHCred(server_config.username,
                                 key_path=server_config.ssh_key),
                    path=server_config.path,
                )
        else:
            self.launcher = _NullLauncher()

        self.launcher.start()

        # Raise the default gRPC message size limits so large model blobs fit in a single message.
        self._chan = grpc.insecure_channel(
            f"{addr}:{port}",
            options=[("grpc.max_send_message_length", _100_MB),
                     ("grpc.max_receive_message_length", _100_MB)],
        )
        client = inference_pb2_grpc.InferenceStub(self._chan)
        upload_client = data_store_pb2_grpc.DataStoreStub(self._chan)
        self._devices = [d.id for d in server_config.devices if d.enabled]
        self._connection = Connection(client, upload_client)
        return self._connection

    def __init__(self) -> None:
        self._tikTorchClassifier = None
        self._train_model = None
        self._shutdown_sent = False
        self._connection = None

    def shutdown(self):
        self._shutdown_sent = True
        self.launcher.stop()

    @property
    def tikTorchClient(self):
        return self._tikTorchClient

    @property
    def description(self):
        if self.tikTorchClient:
            return "TikTorch classifier (client available)"
        else:
            return "TikTorch classifier (client missing)"

    def __eq__(self, other):
        return isinstance(other, type(self))

    def __ne__(self, other):
        return not self.__eq__(other)

    def __del__(self):
        if not self._shutdown_sent:
            try:
                self.launcher.stop()
            except AttributeError:
                pass
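
As with the factory above, a minimal usage sketch under the same assumptions: the config object is a hypothetical stand-in exposing only the attributes ensure_connection reads, and it points at a server that is assumed to be already running (autostart=False, so the _NullLauncher branch is taken).

from types import SimpleNamespace

# Hypothetical configuration for connecting to an already running server.
config = SimpleNamespace(
    address="127.0.0.1",
    port=5567,                    # placeholder port
    autostart=False,              # do not launch anything; _NullLauncher is used
    path=None,
    username=None,
    ssh_key=None,
    devices=[SimpleNamespace(id="cpu", enabled=True)],
)

factory = TiktorchConnectionFactory()
connection = factory.ensure_connection(config)
# Repeated calls return the cached Connection instance.
assert connection is factory.ensure_connection(config)
factory.shutdown()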