def mps_model_handler(self, idx):
    try:
        # 1. Spawn the backend model process and the pipe used to talk to it.
        conn_backend, conn_model = mp.Pipe()
        proc = mp.Process(
            target=model_process,
            args=(self.name, self.model_type, self.model_path,
                  self.input_shm_queue, conn_model,
                  self.input_info, self.output_info, idx, self.metric_q))
        proc.start()

        # 2. Exchange shared memory: send the input segment names, then attach
        #    to the output segments created by the model process.
        conn_backend.send(self.input_shm_name_set)
        output_shm_name = conn_backend.recv()
        output_shm = []
        for shm_name, info in zip(output_shm_name, self.output_info):
            sh = ShmHandler(shm_name, info['max_shape'], info['dtype'])
            sh.load_shm()
            output_shm.append(sh)
    except Exception:
        logger.error('mps_model_handler initialize error')
        logger.error(traceback.format_exc())
        return

    def health_check():
        # Report liveness of the model process every 5 seconds.
        while True:
            sleep(5)
            tag = {'model_handler_name': '{}_{}'.format(self.name, idx)}
            if proc.is_alive():
                self.emit_metric({'model_handler_health_value': 1}, tag=tag)
            else:
                self.emit_metric({'model_handler_health_value': 0}, tag=tag)

    health_thread = threading.Thread(target=health_check, daemon=True)
    health_thread.start()

    # 3. Inference loop: take a batched request, forward it to the model
    #    process, and copy the results out of shared memory.
    while self.alive:
        start_ts = time()
        try:
            shm_idx, shapes, batch_index, batch_q_ts = \
                self.batched_tensor_queue.get(timeout=1)
        except queue.Empty:
            continue
        except Exception:
            logger.error('mps_model_handler error')
            logger.error(traceback.format_exc())
            continue

        batch_output = []
        try:
            model_start_ts = time()
            conn_backend.send((shm_idx, shapes))
            output_shapes = conn_backend.recv()
            self.emit_metric(
                {'backend_forward_model_cost': time() - model_start_ts})
            for shape, sh in zip(output_shapes, output_shm):
                shm_arr = sh.ndarray(shape)
                output_arr = np.empty(shape, shm_arr.dtype)
                output_arr[:] = shm_arr[:]
                batch_output.append(output_arr)

            fwd_cost = time() - start_ts
            self.emit_metric({'backend_forward_cost': fwd_cost})
            if self.adapt:
                # Adapt the batching timeout towards the observed forward cost.
                delta = fwd_cost / (0.5 + self.duplicate_num) - self.timeout
                if abs(delta) / self.timeout > 0.2:
                    self.io_queue_lock.acquire()
                    self.timeout = self.timeout * 0.8 + (self.timeout + delta) * 0.2
                    self.io_queue_lock.release()
                # print('forward cost : {}, timeout : {}'.format(
                #     fwd_cost, self.timeout))
        except Exception:
            logger.error('mps_model_handler error')
            logger.error(traceback.format_exc())
            self.emit_metric({'mps_model_handler_error_counter': 1})
        finally:
            self.output_tensor_queue.put((batch_output, batch_index))

    # 4. Shut down: tell the model process to exit and release shared memory.
    conn_backend.send(EXIT_SIG)
    stat = conn_backend.recv()
    for sh in output_shm:
        sh.close()
    conn_backend.send(True)
    proc.join()
def model_process(model_name, model_type, model_path, shm_queue, conn,
                  input_info, output_info, pid, metric_q):
    try:
        # 1. Initialize the backend model for the requested framework.
        if model_type == 'mock':
            from SimpleDBI.mock_model import MockModel
            model = MockModel(model_name, model_path)
        elif model_type == 'torch':
            from SimpleDBI.torch_model import TorchModel
            model = TorchModel(model_name, model_path)
        elif model_type == 'tf':
            from SimpleDBI.tf_model import TFModel
            model = TFModel(model_name, model_path, input_info, output_info)
        elif model_type == 'tensorrt':
            from SimpleDBI.tensorrt_model import TensorRTModel
            model = TensorRTModel(model_name, model_path)
        elif model_type == 'onnx2trt':
            from SimpleDBI.onnx2trt_model import TensorRTModel
            model = TensorRTModel(model_name, model_path)
        else:
            logger.error('ERROR MODEL TYPE : {}'.format(model_type))
            raise RuntimeError('ERROR MODEL TYPE : {}'.format(model_type))

        # 2. Create shared memory.
        # 2.1 Create the output shared memory segments and record their names.
        output_shm_name = []
        output_shm = []
        for info in output_info:
            shm_name = gen_name(info['name'])
            sh = ShmHandler(shm_name, info['max_shape'], info['dtype'])
            sh.create_shm()
            output_shm_name.append(shm_name)
            output_shm.append(sh)

        # 2.2 Attach to the input shared memory segments created by the backend.
        input_shm_name_list = conn.recv()
        input_shm_list = []
        for input_shm_name in input_shm_name_list:
            input_shm = []
            for shm_name, info in zip(input_shm_name, input_info):
                sh = ShmHandler(shm_name, info['max_shape'], info['dtype'])
                sh.load_shm()
                input_shm.append(sh)
            input_shm_list.append(input_shm)
        conn.send(output_shm_name)
    except Exception:
        logger.error('model_process initialize error')
        logger.error(traceback.format_exc())
        return

    logger.info('model_process <{}> initialize done'.format(model_name))
    tags = {'model': '{}_{}'.format(model_name, pid)}

    # 3. Inference loop: read input tensors from shared memory, run the model,
    #    and write the outputs back.
    while True:
        value = conn.recv()
        if value == EXIT_SIG:
            break
        shm_idx, input_shapes = value

        inputs = []
        output_shapes = []
        try:
            ts = time()
            # 3.1 Load input tensors.
            input_shm = input_shm_list[shm_idx]
            for shape, sh in zip(input_shapes, input_shm):
                shm_arr = sh.ndarray(shape)
                inputs.append(shm_arr)

            # 3.2 Forward.
            outputs = model.forward(*inputs)

            # 3.3 Write outputs into shared memory.
            for output, sh in zip(outputs, output_shm):
                shape = output.shape
                shm_arr = sh.ndarray(shape)
                shm_arr[:] = output[:]
                output_shapes.append(shape)

            if metric_q is not None:
                metric_q.put({
                    "tags": tags,
                    "fields": {'model_proc_cost': time() - ts},
                })
        except Exception:
            logger.error('model_process runtime error')
            logger.error(traceback.format_exc())
        finally:
            conn.send(output_shapes)
            shm_queue.put(shm_idx)  # return the shared memory slot to the available queue

    # 4. Clean up shared memory and the pipe.
    try:
        for input_shm in input_shm_list:
            for sh in input_shm:
                sh.close()
        conn.send(True)
        stat = conn.recv()
        assert stat
        for sh in output_shm:
            sh.close()
        conn.close()
    except Exception:
        logger.error('model_process destructor error')
        logger.error(traceback.format_exc())
    logger.info('Model process exit.')