Ejemplo n.º 1
0
class PredictThread(object):
    """Worker that answers prediction requests for an async algorithm.

    Each request is read from ``request_q``, run through ``alg.predict``
    while ``lock`` is held, and sent back on ``reply_q``. Timing counters
    are accumulated in a ``PredictStats`` object.
    """

    def __init__(self, thread_id, alg, request_q, reply_q, stats_deliver,
                 lock):
        """Keep references to the collaborators handed in by the caller.

        Args:
            thread_id: identifier of this worker.
            alg: object exposing ``predict(state)``.
            request_q: source of requests; must provide ``recv()``.
            reply_q: sink for replies; must provide ``send(msg)``.
            stats_deliver: sink for stats reports; must provide
                ``send(report, block=True)``.
            lock: context manager held around every ``alg.predict`` call.
        """
        self.thread_id = thread_id
        self.alg = alg

        # Message channels.
        self.request_q = request_q
        self.reply_q = reply_q
        self.stats_deliver = stats_deliver

        self.lock = lock

        # A report is sent once the iteration counter exceeds this value.
        self._report_period = 200
        self._stats = PredictStats()

    def predict(self):
        """Serve prediction requests in an endless loop."""
        while True:
            # Time spent waiting for, and unpacking, the next request.
            wait_begin = time()
            msg = self.request_q.recv()
            observation = get_msg_data(msg)
            waited = time() - wait_begin
            self._stats.obs_wait_time += waited

            # Time spent in inference, including waiting on the lock.
            infer_begin = time()
            with self.lock:
                result = self.alg.predict(observation)
            inferred = time() - infer_begin
            self._stats.inference_time += inferred

            # The received message is reused as the reply envelope.
            set_msg_info(msg, cmd="predict_reply")
            set_msg_data(msg, result)
            self.reply_q.send(msg)

            self._stats.iters += 1
            if self._stats.iters <= self._report_period:
                continue

            report = self._stats.get()
            self.stats_deliver.send(report, block=True)
Ejemplo n.º 2
0
class Predictor(object):
    """Predict worker for async algorithm.

    Receives ``(ctr_info, data)`` pairs from ``request_q`` and dispatches
    each one to a handler selected by ``ctr_info['sub_cmd']``. The algorithm
    instance (``self.alg``) and the dispatch table (``self.process_fn``) are
    created in :meth:`start`, not in ``__init__``, so :meth:`start` is the
    intended entry point.
    """

    def __init__(self, predictor_id, config_info, request_q, reply_q,
                 predictor_name):
        """Store configuration and message queues.

        Args:
            predictor_id: identifier of this predictor.
            config_info: mapping read in :meth:`start` for its ``'alg_para'``
                entry; a deep copy is kept so the caller's object is not
                shared.
            request_q: source of requests; must provide ``recv()`` returning
                a ``(ctr_info, data)`` pair.
            reply_q: sink for replies; must provide ``put(msg)``.
            predictor_name: suffix appended to the ``stats_msg`` command of
                stats reports.
        """
        self.config_info = deepcopy(config_info)
        self.predictor_id = predictor_id
        self.request_q = request_q
        self.reply_q = reply_q

        # A stats report is sent once the iteration counter exceeds this.
        self._report_period = 200

        self._stats = PredictStats()
        self.predictor_name = predictor_name

    def process(self):
        """Receive messages forever and dispatch them to their handler.

        Raises:
            KeyError: if the message's sub-command has no registered handler.
        """
        while True:
            start_t0 = time()
            ctr_info, data = self.request_q.recv()
            recv_data = {'ctr_info': ctr_info, 'data': data}
            self._stats.obs_wait_time += time() - start_t0

            # Messages without an explicit sub-command are treated as
            # prediction requests.
            cmd = ctr_info.get('sub_cmd', 'predict')

            proc_fn = self.process_fn.get(cmd)
            if proc_fn is None:
                # Report the value that was actually looked up; 'cmd' is a
                # different field and may be absent from ctr_info.
                raise KeyError("invalid cmd: {}".format(cmd))
            proc_fn(recv_data)

    def sync_weights(self, recv_data):
        """Load the model weights carried in ``recv_data['data']``."""
        model_weights = recv_data['data']
        self.alg.set_weights(model_weights)

    def predict(self, recv_data):
        """Run inference for one request and queue the reply.

        The reply carries the ``broker_id`` and ``explorer_id`` read from the
        request. Once the iteration counter exceeds the report period, a
        stats message is queued on ``reply_q`` as well.

        Args:
            recv_data: dict with ``'ctr_info'`` and ``'data'`` entries as
                built by :meth:`process`.
        """
        start_t1 = time()
        state = get_msg_data(recv_data)
        broker_id = get_msg_info(recv_data, 'broker_id')
        explorer_id = get_msg_info(recv_data, 'explorer_id')
        action = self.alg.predict(state)
        self._stats.inference_time += time() - start_t1

        reply_data = message(action,
                             cmd="predict_reply",
                             broker_id=broker_id,
                             explorer_id=explorer_id)
        self.reply_q.put(reply_data)

        self._stats.iters += 1
        if self._stats.iters > self._report_period:
            # NOTE(review): presumably get() also resets the counters so the
            # next report covers a fresh window -- confirm in PredictStats.
            _report = self._stats.get()
            reply_data = message(_report,
                                 cmd="stats_msg{}".format(self.predictor_name))
            self.reply_q.put(reply_data)

    def start(self):
        """Build the algorithm, register handlers and enter the loop.

        Does not return under normal operation because :meth:`process` loops
        forever.
        """
        # NOTE(review): presumably hides all GPUs from this process so that
        # inference runs on CPU -- confirm.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(-1)
        alg_para = self.config_info.get('alg_para')
        setproctitle.setproctitle("xt_predictor")

        self.alg = alg_builder(**alg_para)

        # Dispatch table: sub-command name -> bound handler.
        self.process_fn = {
            'sync_weights': self.sync_weights,
            'predict': self.predict
        }

        # Start message processing.
        self.process()