def __init__(self, predictor_id, config_info, request_q, reply_q, predictor_name):
    """Set up the state of a predictor worker.

    Args:
        predictor_id: identifier of this worker.
        config_info: configuration mapping; a deep copy is stored so later
            changes made by the caller are not seen by this worker.
        request_q: queue the worker receives requests from.
        reply_q: queue the worker sends replies to.
        predictor_name: name of this worker.
    """
    self.config_info = deepcopy(config_info)
    self.predictor_name = predictor_name
    self.predictor_id = predictor_id
    self.request_q, self.reply_q = request_q, reply_q
    # Counters for wait / inference timing; emitted every _report_period calls.
    self._stats = PredictStats()
    self._report_period = 200
def __init__(self, thread_id, alg, request_q, reply_q, stats_deliver, lock):
    """Store the collaborators of a prediction thread.

    Args:
        thread_id: identifier of this worker.
        alg: algorithm object used for prediction.
        request_q: channel requests are received from.
        reply_q: channel replies are sent to.
        stats_deliver: channel statistics reports are sent to.
        lock: lock guarding access to ``alg``.
    """
    self.thread_id = thread_id
    self.alg, self.lock = alg, lock
    self.request_q, self.reply_q = request_q, reply_q
    self.stats_deliver = stats_deliver
    self._report_period = 200
    self._stats = PredictStats()
class PredictThread(object):
    """Predict Worker for async algorithm.

    Repeatedly pulls a request message from ``request_q``, runs
    ``alg.predict`` on its payload while holding ``lock``, rewrites the same
    message into a ``predict_reply`` carrying the action, and sends it back
    through ``reply_q``. Timing statistics are pushed to ``stats_deliver``
    once the iteration counter exceeds ``_report_period``.
    """

    def __init__(self, thread_id, alg, request_q, reply_q, stats_deliver, lock):
        """Store the collaborators and create a fresh stats accumulator."""
        self.thread_id = thread_id
        self.alg, self.lock = alg, lock
        self.request_q, self.reply_q = request_q, reply_q
        self.stats_deliver = stats_deliver
        self._report_period = 200
        self._stats = PredictStats()

    def predict(self):
        """Serve prediction requests in an endless loop (never returns)."""
        while True:
            # Time spent waiting for (and unpacking) the next request.
            wait_begin = time()
            msg = self.request_q.recv()
            state = get_msg_data(msg)
            self._stats.obs_wait_time += time() - wait_begin

            # Time spent in the model itself; alg access is serialized by lock.
            infer_begin = time()
            with self.lock:
                action = self.alg.predict(state)
            self._stats.inference_time += time() - infer_begin

            # The request message object is reused as the reply.
            set_msg_info(msg, cmd="predict_reply")
            set_msg_data(msg, action)
            self.reply_q.send(msg)

            self._stats.iters += 1
            if self._stats.iters <= self._report_period:
                continue
            self.stats_deliver.send(self._stats.get(), block=True)
class Predictor(object):
    """Predict Worker for async algorithm.

    Receives ``(ctr_info, data)`` pairs from ``request_q`` and dispatches each
    one by ``ctr_info['sub_cmd']`` (default ``'predict'``) to a handler
    registered in ``start()``:

    * ``'predict'``      -- run inference and put a reply on ``reply_q``;
    * ``'sync_weights'`` -- load new model weights into the algorithm.
    """

    def __init__(self, predictor_id, config_info, request_q, reply_q, predictor_name):
        """Store configuration and queues.

        Args:
            predictor_id: identifier of this worker.
            config_info: configuration mapping; deep-copied, and its
                ``'alg_para'`` entry is used by ``start()`` to build the
                algorithm.
            request_q: queue whose ``recv()`` yields ``(ctr_info, data)``.
            reply_q: queue replies are ``put()`` on.
            predictor_name: suffix appended to the stats message command.
        """
        self.config_info = deepcopy(config_info)
        self.predictor_id = predictor_id
        self.request_q = request_q
        self.reply_q = reply_q
        self._report_period = 200
        self._stats = PredictStats()
        self.predictor_name = predictor_name

    def process(self):
        """Receive and dispatch messages forever.

        Requires ``start()`` to have populated ``self.process_fn``.

        Raises:
            KeyError: if a message's sub-command has no registered handler;
                the error message names that sub-command.
        """
        while True:
            start_t0 = time()
            ctr_info, data = self.request_q.recv()
            recv_data = {'ctr_info': ctr_info, 'data': data}
            self._stats.obs_wait_time += time() - start_t0

            cmd = ctr_info.get('sub_cmd', 'predict')
            proc_fn = self.process_fn.get(cmd)
            if proc_fn is None:
                # Report the key that was actually looked up (the sub-command).
                # ctr_info['cmd'] is a different field and may be absent, which
                # would raise an unrelated KeyError and hide the real problem.
                raise KeyError("invalid cmd: {}".format(cmd))
            proc_fn(recv_data)

    def sync_weights(self, recv_data):
        """Load the weights carried in ``recv_data['data']`` into the algorithm."""
        model_weights = recv_data['data']
        self.alg.set_weights(model_weights)

    def predict(self, recv_data):
        """Run inference on one request and reply through ``reply_q``.

        The reply echoes the request's ``broker_id`` and ``explorer_id``.
        Once more than ``_report_period`` predictions have been counted, an
        extra ``stats_msg<predictor_name>`` message holding
        ``self._stats.get()`` is also put on ``reply_q``.
        """
        start_t1 = time()
        state = get_msg_data(recv_data)
        broker_id = get_msg_info(recv_data, 'broker_id')
        explorer_id = get_msg_info(recv_data, 'explorer_id')
        action = self.alg.predict(state)
        self._stats.inference_time += time() - start_t1

        reply_data = message(action, cmd="predict_reply",
                             broker_id=broker_id, explorer_id=explorer_id)
        self.reply_q.put(reply_data)

        self._stats.iters += 1
        if self._stats.iters > self._report_period:
            _report = self._stats.get()
            stats_reply = message(_report,
                                  cmd="stats_msg{}".format(self.predictor_name))
            self.reply_q.put(stats_reply)

    def start(self):
        """Build the algorithm, register message handlers, then serve forever.

        Creates ``self.alg`` from ``config_info['alg_para']`` and the
        ``self.process_fn`` dispatch table that ``process()`` relies on.
        """
        # NOTE(review): presumably forces CPU-only inference in this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(-1)
        alg_para = self.config_info.get('alg_para')
        setproctitle.setproctitle("xt_predictor")
        self.alg = alg_builder(**alg_para)
        # Dispatch table keyed by ctr_info['sub_cmd'].
        self.process_fn = {
            'sync_weights': self.sync_weights,
            'predict': self.predict,
        }
        # Start the message-processing loop.
        self.process()