Ejemplo n.º 1
0
    def recv_local(self):
        """Receive local commands and route them to a local queue or to slaves.

        Loops forever, blocking on ``self.recv_local_q.recv()``. For every
        message the ``cmd`` field is read and the message is dispatched:

        * ``"close"`` -> ``self.close(recv_data)`` is called. The loop is not
          exited here and the message still goes through the routing below
          (presumably ``close`` ends the thread/process -- TODO confirm).
        * ``cmd`` is a key of ``self.send_local_q`` -> the message is sent to
          that local queue.
        * anything else -> the message is sent to the slave selected by its
          ``broker_id`` field, or to every slave when ``broker_id == -1``.
          The time spent sending is appended to ``self.metric`` as ``send``.
        """
        while True:
            recv_data = self.recv_local_q.recv()
            cmd = get_msg_info(recv_data, "cmd")
            if cmd == "close":
                self.close(recv_data)

            # Membership must be tested against the mapping itself. The former
            # ``cmd in [self.send_local_q.keys()]`` compared ``cmd`` with a
            # one-element list holding the keys view, which is never equal to
            # a command name, so this branch could not be reached.
            if cmd in self.send_local_q:
                self.send_local_q[cmd].send(recv_data)
                logging.debug("recv: {} with cmd-{}".format(type(recv_data["data"]), cmd))
                continue

            send_start = time.time()
            broker_id = get_msg_info(recv_data, "broker_id")
            logging.debug("master recv:{} with cmd:'{}' to broker_id: <{}>".format(
                type(recv_data["data"]), cmd, broker_id))

            if broker_id == -1:
                # Broadcast. The zip is kept so the number of sends stays
                # bounded by the shorter of the two sequences, as before.
                for slave, _ in zip(self.send_slave, self.node_config_list):
                    slave.send(recv_data)
            else:
                self.send_slave[broker_id].send(recv_data)
            self.metric.append(send=time.time() - send_start)
Ejemplo n.º 2
0
    def record_reward(self, train_data):
        """Update step/reward bookkeeping from one train message and log it.

        The message's ``(broker_id, explorer_id, agent_id)`` triple is used as
        the bookkeeping key and is reported to
        ``self.alg.dist_model_policy.add_processed_ctr_info``.

        * ``ppo_share_weights``: the step counter advances by ``len(done)``
          and the summed reward of the batch is recorded.
        * ``QMixAlg``: the step counter advances by ``sum(filled)``,
          ``battle_won`` is popped into ``self.won_in_episodes`` and the
          running win rate is pushed to the logger before the summed reward
          is recorded.
        * any other algorithm: rewards are accumulated step by step per key
          and recorded each time an episode is flagged done.
        """
        key = tuple(
            get_msg_info(train_data, field)
            for field in ('broker_id', 'explorer_id', 'agent_id')
        )
        self.alg.dist_model_policy.add_processed_ctr_info(key)
        payload = get_msg_data(train_data)

        alg_name = self.alg.alg_name
        if alg_name in ("ppo_share_weights", "QMixAlg"):
            # Batch-level recording: no per-step done flag is consulted.
            if alg_name == "QMixAlg":
                self.actual_step += np.sum(payload["filled"])
                self.won_in_episodes.append(payload.pop("battle_won"))
                self.logger.update(
                    explore_won_rate=np.nanmean(self.won_in_episodes))
            else:
                self.actual_step += len(payload["done"])

            self.logger.record(
                step=self.actual_step,
                train_reward=np.sum(payload["reward"]),
                train_count=self.train_count,
            )
            return

        if key not in self.actor_reward:
            self.actor_reward[key] = 0.0
            self.actor_trajectory[key] = 0

        # The length of the "done" sequence defines how many steps are read.
        for idx in range(len(payload["done"])):
            step_reward = payload["reward"][idx]
            step_done = payload["done"][idx]
            step_info = payload["info"][idx]
            self.actual_step += 1

            # Reward and trajectory length only accumulate when info is a
            # dict; its "eval_reward"/"real_done" entries override the raw
            # reward/done values when present.
            if isinstance(step_info, dict):
                self.actor_reward[key] += step_info.get("eval_reward", step_reward)
                self.actor_trajectory[key] += 1
                step_done = step_info.get("real_done", step_done)

            if not step_done:
                continue

            self.logger.record(
                step=self.actual_step,
                train_count=self.train_count,
                train_reward=self.actor_reward[key],
            )
            logging.debug("{} epi reward-{}. with len-{}".format(
                key, self.actor_reward[key], self.actor_trajectory[key]))

            # Episode finished: restart the accumulators for this key.
            self.actor_reward[key] = 0.0
            self.actor_trajectory[key] = 0
Ejemplo n.º 3
0
    def recv_local(self):
        """Receive local commands and route them until a close arrives.

        When ``self.recv_local_q`` holds exactly one entry, messages are read
        from it with a blocking ``recv``; otherwise they are pulled from the
        generator returned by ``self._yield_local_msg()``.

        Routing per message:

        * ``"close"`` -> ``self.close(recv_data)`` and the loop ends.
        * ``cmd`` is a key of ``self.send_local_q`` -> sent to that queue.
        * otherwise -> ``ctr_info`` and ``data`` are sent to the broker
          selected by ``broker_id`` (all brokers when it is ``-1``); the send
          time is appended to ``self.metric`` and
          ``debug_within_interval`` is invoked.
        """
        # Distinguish the single receive queue case from the multi-queue one.
        only_stub = None
        if len(self.recv_local_q) == 1:
            (only_stub,) = self.recv_local_q.values()

        debug_kwargs = init_main_broker_debug_kwargs()
        local_msgs = self._yield_local_msg()

        while True:
            recv_data = (only_stub.recv(block=True)
                         if only_stub else next(local_msgs))

            cmd = get_msg_info(recv_data, "cmd")
            if cmd == "close":
                self.close(recv_data)
                break

            if cmd in self.send_local_q:
                self.send_local_q[cmd].send(recv_data)
                continue

            send_start = time.time()
            broker_id = get_msg_info(recv_data, "broker_id")
            # Re-read inside the timed section as before; the value is unused.
            get_msg_info(recv_data, "cmd")

            if broker_id != -1:
                self.send_broker[broker_id].send(recv_data['ctr_info'],
                                                 recv_data['data'])
            else:
                # Broadcast, bounded by the shorter of the two sequences.
                for broker_item, _ in zip(self.send_broker,
                                          self.node_config_list):
                    broker_item.send(recv_data['ctr_info'],
                                     recv_data['data'])

            self.metric.append(send=time.time() - send_start)
            debug_within_interval(**debug_kwargs)
Ejemplo n.º 4
0
    def predict(self, recv_data):
        """Run one prediction request and queue the reply.

        The message payload is passed to ``self.alg.predict``; the result is
        wrapped in a ``predict_reply`` message carrying the request's
        ``broker_id`` and ``explorer_id`` and put on ``self.reply_q``.
        The elapsed time is added to ``self._stats.inference_time``. Once
        ``self._stats.iters`` exceeds ``self._report_period`` a statistics
        message is queued as well.
        """
        began = time()
        observation = get_msg_data(recv_data)
        source_broker = get_msg_info(recv_data, 'broker_id')
        source_explorer = get_msg_info(recv_data, 'explorer_id')
        action = self.alg.predict(observation)
        elapsed = time() - began
        self._stats.inference_time += elapsed

        self.reply_q.put(
            message(action,
                    cmd="predict_reply",
                    broker_id=source_broker,
                    explorer_id=source_explorer))

        self._stats.iters += 1
        if self._stats.iters > self._report_period:
            report = self._stats.get()
            stats_cmd = "stats_msg{}".format(self.predictor_name)
            self.reply_q.put(message(report, cmd=stats_cmd))
Ejemplo n.º 5
0
    def transfer_to_agent(self):
        """Forward message payloads from the broker queue to the agent.

        Loops forever on ``self.recv_broker.get()``. A ``"close"`` command
        triggers ``self.close()`` and is not forwarded; for every other
        message only its data part is sent through ``self.send_agent``.
        """
        while True:
            msg = self.recv_broker.get()
            if get_msg_info(msg, "cmd") != "close":
                self.send_agent.send(get_msg_data(msg))
                continue

            logging.debug("enter explore close")
            self.close()
Ejemplo n.º 6
0
    def transfer_to_broker(self):
        """Relay messages from the agent to the broker.

        Each message read from ``self.recv_agent`` gets this explorer's
        ``broker_id`` and ``explorer_id`` stamped into its info, and its
        ``cmd`` extended with ``self.learner_postfix``, before being sent
        through ``self.send_broker``. Loops forever.
        """
        while True:
            packet = self.recv_agent.recv()
            routed_cmd = get_msg_info(packet, "cmd") + self.learner_postfix
            routing = {
                "broker_id": self.broker_id,
                "explorer_id": self.explorer_id,
                "cmd": routed_cmd,
            }
            set_msg_info(packet, **routing)
            self.send_broker.send(packet)
Ejemplo n.º 7
0
    def transfer_to_broker(self):
        """Relay messages from the agent to the broker.

        Each message read from ``self.recv_agent`` gets this explorer's
        ``broker_id`` and ``explorer_id`` stamped into its info and is sent
        through ``self.send_broker``. Messages whose ``cmd`` is
        ``"buf_reduce"`` are sent with ``data_type="buf_reduce"``; all others
        with ``data_type="data"``. Loops forever.
        """
        while True:
            packet = self.recv_agent.recv()

            if get_msg_info(packet, "cmd") == "buf_reduce":
                send_kind = "buf_reduce"
            else:
                send_kind = "data"

            set_msg_info(packet,
                         broker_id=self.broker_id,
                         explorer_id=self.explorer_id)
            self.send_broker.send(packet, data_type=send_kind)
Ejemplo n.º 8
0
    def recv_broker_slave(self):
        """Receive compressed messages from the slave and hand them on locally.

        Loops forever: raw bytes from ``self.recv_slave`` are lz4-decompressed
        and deserialized (the time taken is appended to ``self.metric`` as
        ``recv``), then sent to the ``self.send_local_q`` entry matching the
        message's ``cmd``. Messages without a matching entry are dropped
        silently. The metric report is triggered after every message.
        """
        while True:
            raw = self.recv_slave.recv_bytes()
            decode_start = time.time()
            recv_data = deserialize(lz4.frame.decompress(raw))
            self.metric.append(recv=time.time() - decode_start)

            cmd = get_msg_info(recv_data, "cmd")
            target = self.send_local_q.get(cmd)
            if target:
                target.send(recv_data)

            # report log
            self.metric.report_if_need()
Ejemplo n.º 9
0
    def start(self):
        """Run the evaluator loop.

        Sets the process title, builds an ``AgentGroup`` in ``"evaluate"``
        scene and then serves messages from ``self.recv_broker``:

        * ``"close"`` ends the loop.
        * ``"eval"`` carries a mapping of ``train_count -> weights``; each
          entry is restored into the agent group, evaluated, and the result
          is sent back through ``self.send_broker`` as an ``eval_return``
          message.
        * any other command is printed and ignored.
        """
        setproctitle.setproctitle("xt_evaluator")

        agent_group = AgentGroup(self.env_para,
                                 self.alg_para,
                                 self.agent_para,
                                 scene="evaluate")
        while True:
            recv_data = self.recv_broker.get()
            cmd = get_msg_info(recv_data, "cmd")
            logging.debug("evaluator get meg: {}".format(type(recv_data)))

            if cmd == "close":
                break
            if cmd != "eval":
                print_immediately("eval get un-used data:{}".format(recv_data))
                continue

            for train_count, weights in recv_data["data"].items():
                agent_group.restore(weights, is_id=False)
                episodes = self.bm_eval.get("episodes_per_eval", 1)
                eval_data = agent_group.evaluate(episodes)

                # Pair the evaluation result with where it came from.
                origin = {
                    "train_count": train_count,
                    "broker_id": self.broker_id,
                    "test_id": self.test_id
                }
                result = (eval_data, origin)
                print_immediately(
                    "collect eval results: {}".format(result))

                reply = message(
                    result,
                    cmd="eval_return",
                    broker_id=self.broker_id,
                    test_id=self.test_id,
                )
                self.send_broker.send(reply)
Ejemplo n.º 10
0
    def recv_broker_task(self):
        """Receive two-part messages from the broker and hand them on locally.

        Loops forever: each message arrives as ``(ctr_info, data)`` bytes.
        ``ctr_info`` is deserialized first; when its ``compress_flag`` is
        truthy the data part is lz4-decompressed before deserialization.
        The rebuilt message is sent to the ``self.send_local_q`` entry
        matching its ``cmd``; an unknown ``cmd`` is logged as a warning.
        Decode time is appended to ``self.metric`` as ``recv``.
        """
        while True:
            ctr_bytes, body = self.recv_broker.recv_bytes()
            decode_start = time.time()
            ctr_info = deserialize(ctr_bytes)
            if ctr_info.get('compress_flag', False):
                body = lz4.frame.decompress(body)
            recv_data = {'ctr_info': ctr_info, 'data': deserialize(body)}
            self.metric.append(recv=time.time() - decode_start)

            cmd = get_msg_info(recv_data, "cmd")
            target = self.send_local_q.get(cmd)
            if not target:
                logging.warning("invalid cmd: {}, with date: {}".format(
                    cmd, recv_data))
            else:
                target.send(recv_data)

            # report log
            self.metric.report_if_need(field_sets=("send", "recv"))
Ejemplo n.º 11
0
    def recv_controller_task(self):
        """Receive controller messages and dispatch them by ``cmd``.

        Loops forever. Each message arrives from ``self.recv_controller_q``
        as a ``(ctr_info, data)`` pair of bytes, both of which are
        deserialized into ``recv_data``. Dispatch:

        * ``"close"``: ``self.close(recv_data)`` is called; there is no
          ``continue``, so the message also runs through the checks below
          and reaches the explorer fan-out at the end.
        * ``"create_explorer"`` / ``"create_evaluator"``: the config in
          ``recv_data["data"]`` gets a ``share_path`` entry and is passed to
          ``self.create_explorer`` / ``self.create_evaluator``.
        * ``"increase"`` / ``"decrease"``: forwarded to ``self.alloc``.
        * ``"eval"``: put on the explorer queue keyed by the message's
          ``test_id``.
        * anything else: put on the explorer queue(s) named by the message's
          ``explorer_id``; for ``"explore"`` the payload is first stored in
          ``self._buf`` and replaced by the returned buffer id.

        Raises:
            KeyError: if an ``explorer_id`` entry is below ``-1``.
        """
        while True:
            # Both parts arrive as raw bytes and are deserialized here.
            # recv_data = self.recv_controller_q.recv()
            ctr_info, data = self.recv_controller_q.recv_bytes()
            recv_data = {
                'ctr_info': deserialize(ctr_info),
                'data': deserialize(data)
            }

            cmd = get_msg_info(recv_data, "cmd")
            if cmd in ["close"]:
                # No ``continue`` here: execution falls through to the
                # remaining checks and the fan-out at the bottom.
                self.close(recv_data)

            if cmd in ["create_explorer"]:
                config_set = recv_data["data"]
                # setup plasma only one time!
                self._setup_share_qs_firstly(config_set)

                config_set.update({"share_path": self._buf.get_path()})
                self.create_explorer(config_set)

                # update the buffer live attribute, explorer num as default.
                # self._buf.plus_one_live()
                continue
            if cmd in ["create_evaluator"]:
                # evaluators share a single plasma store.
                config_set = recv_data["data"]

                # Create the "EVAL0" share queue only when it is still unset
                # and the config enables pbt.
                if not self.explorer_share_qs["EVAL0"]:
                    use_pbt, _, _, _ = get_pbt_set(config_set)
                    if use_pbt:
                        plasma_size = config_set.get("plasma_size", 100000000)
                        plasma_path = "/tmp/plasma{}EVAL0".format(os.getpid())
                        self.explorer_share_qs["EVAL0"] = UniComm(
                            "ShareByPlasma",
                            size=plasma_size,
                            path=plasma_path)

                config_set.update({"share_path": self._buf.get_path()})
                logging.debug(
                    "create evaluator with config:{}".format(config_set))

                self.create_evaluator(config_set)
                # self._buf.plus_one_live()
                continue

            if cmd in ["increase", "decrease"]:
                self.alloc(cmd)
                continue

            if cmd in ("eval", ):  # fixme: merge into explore
                # Eval messages are routed by test_id instead of explorer_id.
                test_id = get_msg_info(recv_data, "test_id")
                self.send_explorer_q[test_id].put(recv_data)
                continue

            # last job, distribute weights/model_name from controller
            explorer_id = get_msg_info(recv_data, "explorer_id")
            # Normalize a single id to a list so the fan-out loop is uniform.
            if not isinstance(explorer_id, list):
                explorer_id = [explorer_id]

            _t0 = time.time()

            if cmd in ("explore", ):  # todo: could mv to first priority.
                # here, only handle explore weights: the still-serialized
                # ``data`` bytes are stored in the buffer.
                buf_id = self._buf.put(data)
                # replace weight with id
                recv_data.update({"data": buf_id})
            # predict_reply
            # e.g, {'ctr_info': {'broker_id': 0, 'explorer_id': 4, 'agent_id': -1,
            # 'cmd': 'predict_reply'}, 'data': 0}

            for _eid in explorer_id:
                if _eid > -1:
                    # A specific explorer queue.
                    self.send_explorer_q[_eid].put(recv_data)
                elif _eid > -2:  # i.e. -1 for integer ids: broadcast
                    for qid, send_q in self.send_explorer_q.items():
                        # Queues whose id is a string containing "test" are
                        # skipped (presumably evaluator queues -- confirm).
                        if isinstance(qid, str) and "test" in qid:
                            # logging.info("continue test: ", qid, send_q)
                            continue

                        send_q.put(recv_data)
                else:
                    raise KeyError("invalid explore id: {}".format(_eid))

            self._metric.append(send=time.time() - _t0)
            self._metric.report_if_need()
Ejemplo n.º 12
0
    def recv_master(self):
        """Receive messages from the broker master and dispatch them by ``cmd``.

        Loops forever. Each message is read as bytes from
        ``self.recv_master_q`` and deserialized. Dispatch:

        * ``"close"``: ``self.close(recv_data)`` is called; the message then
          continues through the remaining checks and the explorer fan-out.
        * ``"create_explorer"`` / ``"create_evaluator"``: the config gets a
          ``share_path`` entry, the worker is created and the buffer's live
          count is increased by one.
        * ``"increase"`` / ``"decrease"``: forwarded to ``self.alloc``.
        * ``"eval"``: put on the explorer queue keyed by ``test_id``.
        * anything else: put on the explorer queue(s) named by
          ``explorer_id``; when ``cmd`` contains ``"explore"`` the raw
          message bytes are stored in ``self._buf`` and the payload is
          replaced by the returned buffer id.

        Raises:
            KeyError: if an ``explorer_id`` entry is below ``-1``.
        """
        while True:
            raw_msg = self.recv_master_q.recv_bytes()
            recv_data = deserialize(raw_msg)
            cmd = get_msg_info(recv_data, "cmd")

            if cmd == "close":
                # Deliberately no ``continue``: falls through to the fan-out.
                self.close(recv_data)

            if cmd in ("create_explorer", "create_evaluator"):
                config_set = recv_data["data"]
                config_set.update({"share_path": self._buf.get_path()})
                if cmd == "create_explorer":
                    self.create_explorer(config_set)
                else:
                    logging.debug("create evaluator with config:{}".format(config_set))
                    self.create_evaluator(config_set)

                # One more consumer of the shared buffer.
                self._buf.plus_one_live()
                continue

            if cmd in ("increase", "decrease"):
                self.alloc(cmd)
                continue

            if cmd == "eval":  # fixme: merge into explore
                eval_target = get_msg_info(recv_data, "test_id")
                self.send_explorer_q[eval_target].put(recv_data)
                continue

            # Last job: distribute weights/model_name from the broker master.
            targets = get_msg_info(recv_data, "explorer_id")
            if not isinstance(targets, list):
                targets = [targets]

            send_start = time.time()
            if "explore" in cmd:
                # Store the raw bytes once and pass only the id to explorers.
                stored_id = self._buf.put(raw_msg)
                recv_data.update({"data": stored_id})

            for eid in targets:
                if eid > -1:
                    self.send_explorer_q[eid].put(recv_data)
                    continue
                if not eid > -2:
                    raise KeyError("invalid explore id: {}".format(eid))

                # eid is -1 for integer ids: broadcast to every queue except
                # those whose id is a string containing "test".
                for qid, send_q in self.send_explorer_q.items():
                    if not (isinstance(qid, str) and "test" in qid):
                        send_q.put(recv_data)

            self._metric.append(send=time.time() - send_start)
            self._metric.report_if_need()