def __init__(self, addr: str = 'localhost', port: int = 1708,
             bindaddr: str = 'localhost', bindport: int = 0,
             worker_n: int = 28, logger: WorkerLogger = None):
    """Connect to the control server and register this host as a worker.

    Opens a TCP connection to ``(addr, port)``. The server's first byte
    selects the transport: ``b'\\x00'`` keeps the plain socket, anything
    else switches to SSL. A ``WorkerPoolServer`` is then started on
    ``(bindaddr, bindport)``. Finally a ``HostInformation`` message is
    sent, advertising ``worker_n`` qubits and the pool's address/port.

    :param addr: control server host.
    :param port: control server port.
    :param bindaddr: address the local worker pool binds to.
    :param bindport: port the local worker pool binds to (0 presumably
        lets the OS pick one -- TODO confirm in WorkerPoolServer).
    :param worker_n: number of qubits this worker advertises.
    :param logger: logger to use; a ``PrintLogger`` is created when None.
    :raises ConnectionError: if the server closes the connection before
        sending the transport byte.
    """
    super().__init__()
    self.logger = PrintLogger() if logger is None else logger

    sock = socket.socket()
    try:
        sock.connect((addr, port))
        # First byte announces whether the rest of the stream is SSL.
        flag = sock.recv(1)
        if not flag:
            raise ConnectionError(
                "Control server {}:{} closed the connection before choosing "
                "a transport".format(addr, port))
        if flag != b'\x00':
            # ssl.wrap_socket() was removed in Python 3.12; this context
            # reproduces its defaults (no certificate / hostname checks).
            # NOTE(review): consider enabling certificate verification.
            context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
            sock = context.wrap_socket(sock)
    except Exception:
        # Do not leak the descriptor when the handshake fails.
        sock.close()
        raise
    self.socket = FormatSocket(sock)

    # Hand the resolved logger to the pool so it never receives None.
    self.pool = WorkerPoolServer(bindaddr, bindport, logger=self.logger)
    self.pool.start()

    workerinfo = HostInformation()
    workerinfo.worker_info.n_qubits = worker_n
    workerinfo.worker_info.address = self.pool.addr
    workerinfo.worker_info.port = self.pool.port
    self.socket.send(workerinfo.SerializeToString())

    self.serverapi = SocketServerBackend(self.socket)
    self.addr = addr
    self.port = port
    self.running = True
def make_connection(self, job_id: str, myinputstart: int, myinputend: int,
                    myoutputstart: int, myoutputend: int,
                    partner: WorkerPartner):
    """Dial a partner worker and register the connection for ``job_id``.

    Sends a ``WorkerPartner`` handshake describing *this* worker's state
    and output index ranges. The resulting socket is then recorded under
    the *partner's* ranges in ``self.workers``,
    ``self.inputrange_workers`` and ``self.outputrange_workers``. All
    three are mutated while holding ``self.workerlock``.

    :param job_id: job the connection belongs to.
    :param myinputstart: start of this worker's state index range.
    :param myinputend: end of this worker's state index range.
    :param myoutputstart: start of this worker's output index range.
    :param myoutputend: end of this worker's output index range.
    :param partner: partner descriptor; its ``addr``/``port`` are dialed
        and its index ranges form the registry keys.
    """
    self.logger.log_string("Connecting to: {}".format(str(partner)))

    wp = WorkerPartner()
    wp.job_id = job_id
    wp.state_index_start = myinputstart
    wp.state_index_end = myinputend
    wp.output_index_start = myoutputstart
    wp.output_index_end = myoutputend

    raw_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        raw_sock.connect((partner.addr, partner.port))
        sock = FormatSocket(raw_sock)
        sock.send(wp.SerializeToString())
    except Exception:
        # Release the descriptor if the partner is unreachable or the
        # handshake cannot be sent.
        raw_sock.close()
        raise

    fullkey = (job_id, partner.state_index_start, partner.state_index_end,
               partner.output_index_start, partner.output_index_end)
    inputkey = (job_id, partner.state_index_start, partner.state_index_end)
    outputkey = (job_id, partner.output_index_start, partner.output_index_end)
    with self.workerlock:
        self.workers[fullkey] = sock
        self.inputrange_workers.setdefault(inputkey, []).append(
            (sock, partner.output_index_start, partner.output_index_end))
        self.outputrange_workers.setdefault(outputkey, []).append(
            (sock, partner.state_index_start, partner.state_index_end))
def run(self):
    """Accept partner connections forever and register each one.

    Every accepted socket is wrapped in ``FormatSocket``. Its first
    message is parsed as a ``WorkerPartner`` naming the peer's job and
    index ranges. The socket is then filed under those ranges in the
    ``workers`` / ``inputrange_workers`` / ``outputrange_workers``
    registries while ``self.workerlock`` is held. The loop has no exit,
    so this method never returns normally.
    """
    self.logger.starting_server()
    self.sock.listen(5)
    self.logger.accepting_connections()
    while True:
        conn = FormatSocket(self.sock.accept()[0])
        peer = WorkerPartner.FromString(conn.recv())
        self.logger.accepted_connection()

        job = peer.job_id
        in_range = (peer.state_index_start, peer.state_index_end)
        out_range = (peer.output_index_start, peer.output_index_end)

        with self.workerlock:
            input_bucket = self.inputrange_workers.setdefault(
                (job,) + in_range, [])
            output_bucket = self.outputrange_workers.setdefault(
                (job,) + out_range, [])
            self.workers[(job,) + in_range + out_range] = conn
            input_bucket.append((conn,) + out_range)
            output_bucket.append((conn,) + in_range)
class WorkerRunner(Thread):
    """Thread that registers with the control server and runs worker jobs.

    The constructor connects to the control server, starts a local
    ``WorkerPoolServer`` for peer traffic and announces this worker.
    ``run`` then serves ``WorkerCommand`` messages until told to close.
    """

    def __init__(self, addr: str = 'localhost', port: int = 1708,
                 bindaddr: str = 'localhost', bindport: int = 0,
                 worker_n: int = 28, logger: WorkerLogger = None):
        """Connect to ``(addr, port)`` and register as a worker.

        :param addr: control server host.
        :param port: control server port.
        :param bindaddr: address the local worker pool binds to.
        :param bindport: port the local worker pool binds to (0 presumably
            lets the OS pick one -- TODO confirm in WorkerPoolServer).
        :param worker_n: number of qubits this worker advertises.
        :param logger: logger to use; a ``PrintLogger`` when None.
        :raises ConnectionError: if the server closes the connection
            before sending the transport byte.
        """
        super().__init__()
        self.logger = PrintLogger() if logger is None else logger

        self.socket = FormatSocket(self._open_control_socket(addr, port))

        # Hand the resolved logger to the pool so it never receives None.
        self.pool = WorkerPoolServer(bindaddr, bindport, logger=self.logger)
        self.pool.start()

        workerinfo = HostInformation()
        workerinfo.worker_info.n_qubits = worker_n
        workerinfo.worker_info.address = self.pool.addr
        workerinfo.worker_info.port = self.pool.port
        self.socket.send(workerinfo.SerializeToString())

        self.serverapi = SocketServerBackend(self.socket)
        self.addr = addr
        self.port = port
        self.running = True

    @staticmethod
    def _open_control_socket(addr, port):
        """Connect to the control server and apply the announced transport.

        The server's first byte is ``b'\\x00'`` for plain TCP; any other
        byte means the stream continues over SSL. The socket is closed if
        any step fails.
        """
        sock = socket.socket()
        try:
            sock.connect((addr, port))
            flag = sock.recv(1)
            if not flag:
                raise ConnectionError(
                    "Control server {}:{} closed the connection before "
                    "choosing a transport".format(addr, port))
            if flag != b'\x00':
                # ssl.wrap_socket() was removed in Python 3.12; this context
                # reproduces its defaults (no certificate / hostname checks).
                # NOTE(review): consider enabling certificate verification.
                context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
                context.check_hostname = False
                context.verify_mode = ssl.CERT_NONE
                sock = context.wrap_socket(sock)
        except Exception:
            sock.close()
            raise
        return sock

    def run(self):
        """Serve ``WorkerCommand`` messages until a ``close`` command.

        A ``setup`` command builds a ``WorkerInstance`` and runs it on this
        thread. A ``close`` command ends the loop. Other commands are
        ignored. The control socket is closed whenever the loop exits,
        whether normally or through an exception.
        """
        try:
            while self.running:
                self.logger.waiting_for_setup()
                cmd = WorkerCommand.FromString(self.socket.recv())
                if cmd.HasField('setup'):
                    setup = cmd.setup
                    self.logger.accepted_setup(setup)
                    worker = WorkerInstance(self.serverapi, self.pool, setup,
                                            logger=self.logger)
                    worker.run()
                elif cmd.HasField('close'):
                    self.running = False
        finally:
            self.socket.close()
def __init__(self, addr: str = 'localhost', port: int = 6060):
    """Connect to the monitor at ``(addr, port)`` and announce this worker.

    A fresh UUID4 string becomes ``self.worker_id`` and is sent inside a
    ``LoggerHostInfo`` greeting. A reusable ``WorkerLog`` message is
    kept in ``self.proto_arena`` for subsequent log records.
    """
    super().__init__()
    self.addr, self.port = addr, port
    self.worker_id = str(uuid.uuid4())

    connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    connection.connect((addr, port))
    self.sock = FormatSocket(connection)

    greeting = LoggerHostInfo()
    greeting.worker.worker_id = self.worker_id
    self.sock.send(greeting.SerializeToString())

    self.proto_arena = WorkerLog()
def __init__(self, n, server_host: str = 'localhost',
             server_port: int = 1708):
    """Connect to the control server and introduce this client as 'backend'.

    The server's first byte selects the transport: ``b'\\x00'`` keeps the
    plain socket, anything else switches to SSL. All later communication
    is protobuf messages over ``FormatSocket``.

    :param n: forwarded to ``StateType.__init__`` (with ``None``).
    :param server_host: control server host.
    :param server_port: control server port.
    :raises ConnectionError: if the server closes the connection before
        sending the transport byte.
    """
    super().__init__(n, None)
    self.control_server_addr = (server_host, server_port)
    sock = socket.socket()
    try:
        sock.connect(self.control_server_addr)
        # First byte is whether to use ssl or not, all remaining
        # communication is proto based.
        flag = sock.recv(1)
        if not flag:
            raise ConnectionError(
                "Control server {}:{} closed the connection before choosing "
                "a transport".format(server_host, server_port))
        if flag != b'\x00':
            # ssl.wrap_socket() was removed in Python 3.12; this context
            # reproduces its defaults (no certificate / hostname checks).
            # NOTE(review): consider enabling certificate verification.
            context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
            sock = context.wrap_socket(sock)
    except Exception:
        # Do not leak the descriptor when the handshake fails.
        sock.close()
        raise
    self.socket = FormatSocket(sock)
    # Introduce yourself.
    host_info = HostInformation()
    host_info.client_info.name = 'backend'
    self.socket.send(host_info.SerializeToString())
class MonitorWorkerLogger(WorkerLogger):
    """Worker logger that streams ``WorkerLog`` protobuf records to a monitor.

    On construction it connects to ``(addr, port)`` and sends a
    ``LoggerHostInfo`` greeting carrying a random worker id. Each hook
    fills fields of one reusable ``WorkerLog`` message
    (``self.proto_arena``). ``send`` then stamps it with the worker id,
    transmits it and clears it for the next record.
    """

    def __init__(self, addr: str = 'localhost', port: int = 6060):
        super().__init__()
        self.addr = addr
        self.port = port
        self.worker_id = str(uuid.uuid4())
        connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        connection.connect((addr, port))
        self.sock = FormatSocket(connection)
        hello = LoggerHostInfo()
        hello.worker.worker_id = self.worker_id
        self.sock.send(hello.SerializeToString())
        self.proto_arena = WorkerLog()

    def send(self):
        """Stamp, transmit, then reset the pending ``WorkerLog`` record."""
        record = self.proto_arena
        record.worker_id = self.worker_id
        payload = record.SerializeToString()
        self.sock.send(payload)
        record.Clear()

    def log_string(self, s, **kwargs):
        """Send ``s`` as a plain log line."""
        self.proto_arena.string_log = s
        self.send()

    def log_error(self, s, **kwargs):
        """Send ``s`` as an error line."""
        self.proto_arena.string_error = s
        self.send()

    def starting_server(self):
        self.log_string("Starting server")

    def accepting_connections(self):
        self.log_string("Accepting connections.")

    def accepted_connection(self):
        self.log_string("Accepted connection")

    def waiting_for_setup(self):
        self.log_string("Waiting for setup.")

    def accepted_setup(self, setup: WorkerSetup):
        message = "Setup: {}".format(setup)
        self.log_string(message)

    def making_state(self, handle: str, input_start: int, input_end: int,
                     output_start: int, output_end: int):
        """Report the job id and index ranges of a newly created state."""
        job = self.proto_arena.set_job
        job.job_id = handle
        job.input_start = input_start
        job.input_end = input_end
        job.output_start = output_start
        job.output_end = output_end
        self.send()

    def closing_state(self, handle: str):
        self.proto_arena.clear_job = handle
        self.send()

    def waiting_for_operation(self, handle: str):
        message = "Waiting for operation for job {}".format(handle)
        self.log_string(message)

    def running_operation(self, handle: str, op: WorkerOperation):
        """Report ``op`` (serialized) as running for ``handle``."""
        running = self.proto_arena.running_op
        running.handle = handle
        running.op = op.SerializeToString()
        self.send()

    def done_running_operation(self, handle: str, op: WorkerOperation):
        # Only the handle is reported; ``op`` is not included.
        self.proto_arena.done_with_op.handle = handle
        self.send()

    def sending_state(self, handle: str):
        self.proto_arena.sending_state = handle
        self.send()

    def receiving_state(self, handle: str):
        self.proto_arena.receiving_state = handle
        self.send()
class MonitorServerLogger(ServerLogger):
    """Server logger that streams ``ManagerLog`` protobuf records to a monitor.

    On construction it connects to ``(addr, port)`` and sends a
    ``LoggerHostInfo`` greeting with a random manager id. Each hook fills
    one reusable ``ManagerLog`` message (``self.proto_arena``). ``send``
    then stamps it with the manager id, transmits it and clears it.
    """

    def __init__(self, addr: str = 'localhost', port: int = 6060):
        super().__init__()
        self.addr = addr
        self.port = port
        self.server_id = str(uuid.uuid4())
        connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        connection.connect((addr, port))
        self.sock = FormatSocket(connection)
        hello = LoggerHostInfo()
        hello.manager.manager_id = self.server_id
        self.sock.send(hello.SerializeToString())
        self.proto_arena = ManagerLog()

    def send(self):
        """Stamp, transmit, then reset the pending ``ManagerLog`` record."""
        record = self.proto_arena
        record.manager_id = self.server_id
        payload = record.SerializeToString()
        self.sock.send(payload)
        record.Clear()

    def log_string(self, s, **kwargs):
        """Send ``s`` as a plain log line."""
        self.proto_arena.string_log = s
        self.send()

    def log_error(self, s, **kwargs):
        """Send ``s`` as an error line."""
        self.proto_arena.string_error = s
        self.send()

    def starting_server(self):
        self.log_string("Starting server")

    def accepting_connections(self):
        self.log_string("Accepting connections.")

    def accepted_connection(self, ssl: bool = False):
        # ``ssl`` is a flag here and shadows the module inside this method.
        state = "ON" if ssl else "OFF"
        self.log_string("Accepted connection (SSL: {})".format(state))

    def received_worker(self, host_info: HostInformation):
        message = "Received worker: {}".format(host_info)
        self.log_string(message)

    def received_client(self, host_info: HostInformation):
        message = "Received client: {}".format(host_info)
        self.log_string(message)

    def waiting_for_setup(self):
        self.log_string("Waiting for setup.")

    def making_state(self, handle: str, n: int):
        """Report a new job id together with ``n``."""
        job = self.proto_arena.set_job
        job.job_id = handle
        job.n = n
        self.send()

    def closing_state(self, handle: str):
        self.proto_arena.clear_job = handle
        self.send()

    def waiting_for_operation(self, handle: str):
        message = "Waiting for operation for job {}".format(handle)
        self.log_string(message)

    def running_operation(self, handle: str, op: WorkerOperation):
        """Report ``op`` (serialized) as running for ``handle``."""
        running = self.proto_arena.running_op
        running.handle = handle
        running.op = op.SerializeToString()
        self.send()

    def done_running_operation(self, handle: str, op: WorkerOperation):
        # Only the handle is reported; ``op`` is not included.
        self.proto_arena.done_with_op.handle = handle
        self.send()

    def allocating_workers(self, handle: str, n: int):
        message = "Allocating {} worker(s) for {}.".format(n, handle)
        self.log_string(message)

    def returning_workers(self, handle: str, n: int):
        message = "Returning {} worker(s) from {} to pool.".format(n, handle)
        self.log_string(message)
class DistributedBackend(StateType):
    """``StateType`` whose state is held by a remote control server.

    Each instance owns one connection to the control server. State setup,
    operations and measurements are sent as protobuf messages. Each reply
    is checked for an ``error_message``, which is raised as an
    ``Exception``.
    """

    def __init__(self, n, server_host: str = 'localhost',
                 server_port: int = 1708):
        """Connect to the control server and introduce this client.

        :param n: forwarded to ``StateType.__init__`` (with ``None``).
        :param server_host: control server host.
        :param server_port: control server port.
        :raises ConnectionError: if the server closes the connection
            before sending the transport byte.
        """
        super().__init__(n, None)
        self.control_server_addr = (server_host, server_port)
        self.socket = FormatSocket(
            self._open_control_socket(self.control_server_addr))
        # Introduce yourself.
        host_info = HostInformation()
        host_info.client_info.name = 'backend'
        self.socket.send(host_info.SerializeToString())

    @staticmethod
    def _open_control_socket(address):
        """Connect to ``address`` and apply the transport the server announces.

        The first byte is whether to use SSL or not (``b'\\x00'`` means
        plain TCP). All remaining communication is protobuf based. The
        socket is closed if any step fails.
        """
        sock = socket.socket()
        try:
            sock.connect(address)
            flag = sock.recv(1)
            if not flag:
                raise ConnectionError(
                    "Control server {}:{} closed the connection before "
                    "choosing a transport".format(*address))
            if flag != b'\x00':
                # ssl.wrap_socket() was removed in Python 3.12; this context
                # reproduces its defaults (no certificate / hostname checks).
                # NOTE(review): consider enabling certificate verification.
                context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
                context.check_hostname = False
                context.verify_mode = ssl.CERT_NONE
                sock = context.wrap_socket(sock)
        except Exception:
            sock.close()
            raise
        return sock

    @staticmethod
    def make_state(n: int,
                   index_groups: Sequence[Sequence[int]],
                   feed_list: Sequence[InitialState],
                   statetype: type = numpy.complex128,
                   *,
                   server_host: str = 'localhost',
                   server_port: int = 1708) -> StateType:
        """Create a distributed state of ``n`` qubits on the control server.

        :param n: total qubit count.
        :param index_groups: qubit indices for each initial sub-state.
        :param feed_list: per group, either an integer basis index or a
            vector accepted by ``vec_to_pbvec``.
        :param statetype: accepted for interface compatibility; not used.
        :param server_host: control server host (keyword-only).
        :param server_port: control server port (keyword-only).
        :returns: a connected ``DistributedBackend`` holding the handle.
        :raises Exception: with the server's message if setup fails.
        """
        distbackend = DistributedBackend(n, server_host, server_port)
        setup_message = StateSetup()
        setup_message.n = n
        for index_group, initial_state in zip(index_groups, feed_list):
            pb_state = setup_message.states.add()
            pb_state.indices.CopyFrom(indices_to_pbindices(index_group))
            # Accept numpy integer scalars as basis indices too.
            if isinstance(initial_state, (int, numpy.integer)):
                pb_state.index = int(initial_state)
            else:
                pb_state.vector.CopyFrom(vec_to_pbvec(initial_state))

        distbackend.socket.send(setup_message.SerializeToString())
        resp = StateHandle.FromString(distbackend.socket.recv())
        if resp.HasField('error_message'):
            # The half-built backend is discarded; release its connection.
            distbackend.socket.close()
            raise Exception(resp.error_message)
        distbackend.state = resp.state_handle
        return distbackend

    def _send_operation(self, workerop):
        """Send ``workerop`` and return the server's ``WorkerConfirm`` reply."""
        self.socket.send(workerop.SerializeToString())
        return WorkerConfirm.FromString(self.socket.recv())

    def kronselect_dot(self,
                       mats: Mapping[IndexType, MatrixType],
                       input_offset: int = 0,
                       output_offset: int = 0) -> None:
        """Apply the Kronecker product of ``mats`` remotely.

        The offsets are accepted for interface compatibility and are not
        sent. On any error the connection is closed before raising.
        """
        workerop = WorkerOperation()
        for indices, mat in mats.items():
            matop_to_pbmatop(indices, mat, workerop.kronprod.matrices.add())
        workerop.job_id = self.state
        conf = self._send_operation(workerop)
        if conf.HasField('error_message'):
            self.close()
            raise Exception(conf.error_message)
        if conf.job_id != self.state:
            self.close()
            raise Exception("Server miscommunication: {} != {}".format(
                self.state, conf.job_id))

    def func_apply(self,
                   reg1_indices: IndexType,
                   reg2_indices: IndexType,
                   func: Callable[[int], int],
                   input_offset: int = 0,
                   output_offset: int = 0) -> None:
        """Not supported by the distributed backend."""
        raise NotImplementedError(
            "Function application not yet supported in distributed backends.")

    def _run_measurement(self, indices, reduce):
        """Send a measure op over ``indices``; return ``(bits, prob)``."""
        workerop = WorkerOperation()
        workerop.job_id = self.state
        workerop.measure.reduce = reduce
        indices_to_pbindices(indices, workerop.measure.indices)
        conf = self._send_operation(workerop)
        if conf.HasField('error_message'):
            raise Exception(conf.error_message)
        return (conf.measure_result.measured_bits,
                conf.measure_result.measured_prob)

    def measure(self,
                indices: Sequence[int],
                measured: Optional[int] = None,
                measured_prob: Optional[float] = None,
                input_offset: int = 0,
                output_offset: int = 0):
        """Measure ``indices`` without reducing the state.

        ``measured``, ``measured_prob`` and the offsets are not sent to
        the server.
        """
        return self._run_measurement(indices, reduce=False)

    def reduce_measure(self,
                       indices: IndexType,
                       measured: Optional[int] = None,
                       measured_prob: Optional[float] = None,
                       input_offset: int = 0,
                       output_offset: int = 0) -> Tuple[int, float]:
        """Measure ``indices`` and ask the server to reduce the state.

        ``measured``, ``measured_prob`` and the offsets are not sent to
        the server.
        """
        return self._run_measurement(indices, reduce=True)

    def soft_measure(self,
                     indices: Sequence[int],
                     measured: Optional[int] = None,
                     input_offset: int = 0):
        """Not supported by the distributed backend."""
        raise NotImplementedError(
            "Soft measurement not yet supported in distributed backends.")

    def measure_probabilities(
            self, indices: IndexType,
            top_k: int = 0) -> Tuple[Sequence[int], Sequence[float]]:
        """Return the ``top_k`` most likely outcomes and their probabilities.

        A falsy ``top_k`` requests all ``2 ** len(indices)`` outcomes.
        """
        if not top_k:
            top_k = 2 ** len(indices)
        workerop = WorkerOperation()
        workerop.job_id = self.state
        workerop.measure.top_k = top_k
        indices_to_pbindices(indices, workerop.measure.indices)
        conf = self._send_operation(workerop)
        if conf.HasField('error_message'):
            raise Exception(conf.error_message)
        return (list(conf.measure_result.top_k_indices.index),
                list(conf.measure_result.top_k_probs))

    def close(self):
        """Tell the server to close and shut the connection.

        Safe to call more than once. ``kronselect_dot`` already calls it
        on error, so a caller's later cleanup must not fail on a closed
        socket.
        """
        if getattr(self, '_closed', False):
            return
        self._closed = True
        try:
            close_op = WorkerOperation()
            close_op.close = True
            self.socket.send(close_op.SerializeToString())
        finally:
            self.socket.close()