def start_explore(self):
    """Start explore process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(-1)
    explored_times = 0
    report_stats_interval = 20
    last_report_index = -999
    try:
        self.rl_agent = AgentGroup(self.env_para, self.alg_para,
                                   self.agent_para, self.send_agent,
                                   self.recv_agent, self._buf)
        explore_time = self.agent_para.get("agent_config", {}).get(
            "sync_model_interval", 1)
        logging.info("AgentGroup starts to explore with sync interval-{}".format(
            explore_time))

        while True:
            stats = self.rl_agent.explore(explore_time)
            explored_times += explore_time

            if self.explorer_id < 1:
                logging.debug("explore-{} ran {} times".format(
                    self.explorer_id, explored_times))

            if explored_times - last_report_index > report_stats_interval:
                stats_msg = message(stats, cmd="stats_msg")
                self.recv_agent.send(stats_msg)
                last_report_index = explored_times
    except BaseException as ex:
        logging.exception(ex)
        os._exit(4)
def stop(self):
    """Stop the whole system."""
    close_cmd = message(None, cmd="close")
    for _learner_id, recv_q in self.recv_local_q.items():
        # predictor queues are closed elsewhere; skip them here
        if 'predict' in _learner_id:
            continue
        recv_q.send(close_cmd)
    time.sleep(0.1)
def get_trajectory(self, last_pred=None):
    """Get trajectory."""
    # A copy is needed when running with explore time > 1;
    # otherwise the trajectory is cleared before it is sent.
    # A shallow copy is not enough, since the nested lists are shared:
    # trajectory = message(self.trajectory.copy())
    trajectory = message(deepcopy(self.trajectory))
    set_msg_info(trajectory, agent_id=self.id)
    return trajectory
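# A minimal standalone sketch (not part of the class above) illustrating why
# deepcopy is required here: dict.copy() is shallow, so the per-key lists are
# still shared with the original and get emptied when the trajectory is cleared.
from copy import deepcopy

traj = {"action": [1, 2, 3]}
shallow, deep = traj.copy(), deepcopy(traj)
traj["action"].clear()                # simulate clearing after an explore step
assert shallow["action"] == []        # shallow copy lost the data
assert deep["action"] == [1, 2, 3]    # deep copy kept it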
def send_create_evaluator_msg(self, broker_id, test_id):
    """Create evaluator."""
    config = deepcopy(self.config_info)
    config.update({"test_id": test_id})
    create_cmd = message(config, cmd="create_evaluator", broker_id=broker_id)
    self.send_broker.send(create_cmd)
def put_test_model(self, model_weights):
    """Send test model."""
    key = self.get_avail_node()
    ctr_info = {"cmd": "eval", "broker_id": key[0], "test_id": key[1]}
    eval_cmd = message(model_weights, **ctr_info)
    self.send_broker.send(eval_cmd)
    logging.debug("put evaluate model: {}".format(type(model_weights)))
    self.used_node[key] += 1
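# Hedged sketch of what get_avail_node could look like; the method itself is
# not shown in this section, so this is an assumption based on the used_node
# bookkeeping above: pick the (broker_id, test_id) node with the fewest
# models currently assigned to it.
def get_avail_node(self):
    # self.used_node maps (broker_id, test_id) -> count of pending evals
    return min(self.used_node, key=self.used_node.get)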
def _dist_policy(self, weight=None, save_index=-1, dist_cmd="explore"):
    """Distribute the model; helper shared by the train loop."""
    ctr_info = self.alg.dist_model_policy.get_dist_info(save_index)
    # get_dist_info may return a single ctr_info dict or a list of them;
    # normalize to a list so each target gets its own message.
    if isinstance(ctr_info, dict):
        ctr_info = [ctr_info]

    for _ctr in ctr_info:
        to_send_data = message(weight, cmd=dist_cmd, **_ctr)
        self.model_q.send(to_send_data)
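# A hypothetical dist policy, to illustrate the contract assumed above:
# get_dist_info(save_index) returns either one routing dict or a list of
# them, each expanded into the message's ctr_info keyword arguments.
class BroadcastDistPolicy:
    def __init__(self, explorer_ids):
        self.explorer_ids = explorer_ids

    def get_dist_info(self, save_index):
        # one message per explorer; save_index could be used for filtering
        return [{"broker_id": 0, "explorer_id": eid}
                for eid in self.explorer_ids]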
def predict(self, recv_data):
    """Predict an action for a received state and reply to its explorer."""
    start_t1 = time()
    state = get_msg_data(recv_data)
    broker_id = get_msg_info(recv_data, 'broker_id')
    explorer_id = get_msg_info(recv_data, 'explorer_id')
    action = self.alg.predict(state)
    self._stats.inference_time += time() - start_t1

    reply_data = message(action, cmd="predict_reply",
                         broker_id=broker_id, explorer_id=explorer_id)
    self.reply_q.put(reply_data)

    self._stats.iters += 1
    if self._stats.iters > self._report_period:
        _report = self._stats.get()
        reply_data = message(_report,
                             cmd="stats_msg{}".format(self.predictor_name))
        self.reply_q.put(reply_data)
def get_trajectory(self, last_pred=None):
    for _data_key in ("cur_state", "logit", "action"):
        self.trajectory[_data_key] = np.asarray(self.trajectory[_data_key])
    # astype returns a new array, so the result must be assigned back
    self.trajectory["action"] = self.trajectory["action"].astype(np.int32)

    trajectory = message(self.trajectory)
    set_msg_info(trajectory, agent_id=self.id)
    return trajectory
def get_trajectory(self, last_pred=None):
    # gather the per-env samples into a single trajectory
    for env_id in range(self.vector_env_size):
        for _data_key in ("cur_state", "logit", "action",
                          "reward", "done", "info"):
            self.trajectory[_data_key].extend(
                self.sample_vector[env_id][_data_key])

    # merge data into env_num * seq_len
    for _data_key in self.trajectory:
        self.trajectory[_data_key] = np.stack(self.trajectory[_data_key])
    # astype returns a new array, so the result must be assigned back
    self.trajectory["action"] = self.trajectory["action"].astype(np.int32)

    trajectory = message(self.trajectory.copy())
    set_msg_info(trajectory, agent_id=self.id)
    return trajectory
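# Minimal numpy sketch of the merge above, under assumed shapes: two vector
# envs each contribute a list of per-step actions; after extend, the flat
# list holds env_num * seq_len entries, and np.stack turns it into one array.
import numpy as np

sample_vector = [{"action": [0, 1, 2]}, {"action": [3, 4, 5]}]
merged = []
for env in sample_vector:
    merged.extend(env["action"])
actions = np.stack(merged).astype(np.int32)
assert actions.shape == (6,)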
def start(self):
    """Run evaluator."""
    setproctitle.setproctitle("xt_evaluator")
    _ags = AgentGroup(self.env_para, self.alg_para, self.agent_para,
                      scene="evaluate")
    while True:
        recv_data = self.recv_broker.get()
        cmd = get_msg_info(recv_data, "cmd")
        logging.debug("evaluator got msg: {}".format(type(recv_data)))
        if cmd in ("close",):
            break
        if cmd not in ["eval"]:
            print_immediately("eval got unused data: {}".format(recv_data))
            continue

        for train_count, weights in recv_data["data"].items():
            _ags.restore(weights, is_id=False)
            eval_data = _ags.evaluate(self.bm_eval.get("episodes_per_eval", 1))
            # return the rewards for each agent
            record_item = tuple([
                eval_data,
                {"train_count": train_count,
                 "broker_id": self.broker_id,
                 "test_id": self.test_id}
            ])
            print_immediately("collect eval results: {}".format(record_item))

            record_item = message(record_item,
                                  cmd="eval_return",
                                  broker_id=self.broker_id,
                                  test_id=self.test_id)
            self.send_broker.send(record_item)
def start_explore(self):
    """Start explore process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(-1)
    explored_times = 0
    try:
        self.rl_agent = AgentGroup(self.env_para, self.alg_para,
                                   self.agent_para, self.send_agent,
                                   self.recv_agent, self._buf)
        explore_time = self.agent_para.get("agent_config", {}).get(
            "sync_model_interval", 1)
        logging.info("explorer-{} start with sync interval-{}".format(
            self.explorer_id, explore_time))

        while True:
            model_type = self.rl_agent.update_model()
            stats = self.rl_agent.explore(explore_time)
            explored_times += explore_time

            # stagger stats reports across explorers by offsetting the
            # modulus with the explorer id; always report the first round
            if explored_times % self.report_stats_interval == self.explorer_id \
                    or explored_times == explore_time:
                stats_msg = message(stats, cmd="stats_msg",
                                    broker_id=self.broker_id,
                                    explorer_id=self.explorer_id)
                self.recv_agent.send(stats_msg)

            if self.explorer_id < 1:
                logging.debug("EXP{} ran {} ts, restore {} ts, last type: {}".format(
                    self.explorer_id, explored_times,
                    self.rl_agent.restore_count, model_type))
    except BaseException as ex:
        logging.exception(ex)
        os._exit(4)
def setup_explorer(broker_master, config_info, env_id):
    """Request creation of an explorer bound to one environment id."""
    config = deepcopy(config_info)
    config["env_para"].update({"env_id": env_id})
    msg = message(config, cmd="create_explorer")
    broker_master.recv_local_q.send(msg)
def get_trajectory(self):
    trajectory = message(self.trajectory.copy())
    set_msg_info(trajectory, agent_id=self.id)
    return trajectory
def get_trajectory(self):
    """Get trajectory from the batch's transition data."""
    transition = self.batch.data.transition_data
    transition.update(self._info.copy())  # record win rate within train

    trajectory = message(transition)
    set_msg_info(trajectory, agent_id=self.id)
    return trajectory
def setup_explorer(controller_recv_stub, config_info, env_id):
    """Request creation of an explorer bound to one environment id."""
    config = deepcopy(config_info)
    config["env_para"].update({"env_id": env_id})
    msg = message(config, cmd="create_explorer")
    controller_recv_stub.send(msg)
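# Usage sketch, assuming a stub with a send() method and a config_info dict
# that contains an "env_para" section: spawn one explorer per environment.
# The explorer count of 4 is illustrative only.
for env_id in range(4):
    setup_explorer(controller_recv_stub, config_info, env_id)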
def train(self):
    """Train model."""
    if not self.alg.async_flag:
        policy_weight = self.alg.get_weights()
        self._dist_policy(weight=policy_weight)

    loss = 0
    while True:
        for _tf_val in range(self.alg.prepare_data_times):
            with self.logger.wait_sample_timer:
                data = self.train_q.recv()
            with self.logger.prepare_data_timer:
                data = bytes_to_str(data)
                self.record_reward(data)
                self.alg.prepare_data(data["data"], ctr_info=data["ctr_info"])

            # dqn-series algorithms count each 'SARSA' as one episode,
            # and the episodic count is used as the train-ready flag.
            # Each pbt exploit.step resets the episodic count.
            self.elapsed_episode += 1

        # support syncing the model before running pbt, if needed
        if self.pbt_aid:
            if self.pbt_aid.meet_stop(self.elapsed_episode):
                break
            cur_info = dict(
                episodic_reward_mean=self.logger.train_reward_avg,
                elapsed_step=self.actual_step,
                elapsed_episode=self.elapsed_episode)
            new_alg = self.pbt_aid.step(cur_info, cur_alg=self.alg)
            if new_alg:  # re-assign the algorithm if needed
                self.alg = new_alg
                if not self.alg.async_flag:
                    policy_weight = self.alg.get_weights()
                    self._dist_policy(weight=policy_weight)
                continue

        if self._meet_stop():
            self.stats_deliver.send(self.logger.get_new_info(), block=True)
            break

        if not self.alg.train_ready(self.elapsed_episode,
                                    dist_dummy_model=self._dist_policy):
            continue

        with self.lock, self.logger.train_timer:
            loss = self.alg.train(episode_num=self.elapsed_episode)

        # np.float was removed in NumPy 1.24; np.floating covers all widths
        if isinstance(loss, (float, np.floating)):
            self.logger.record(train_loss=loss)

        with self.lock:
            if self.alg.if_save(self.train_count):
                _name = self.alg.save(self.model_path, self.train_count)

        self._handle_eval_process(loss)

        # distributing the model requires the checkpoint to be ready
        if not self.alg.async_flag and self.alg.checkpoint_ready(self.train_count):
            _save_t1 = time()
            policy_weight = self.alg.get_weights()
            self._metric.append(fix_weight=time() - _save_t1)

            _dist_st = time()
            self._dist_policy(policy_weight, self.train_count)
            self._metric.append(send=time() - _dist_st)
            self._metric.report_if_need()
        else:
            if self.alg.checkpoint_ready(self.train_count):
                policy_weight = self.alg.get_weights()
                weight_msg = message(policy_weight,
                                     cmd="predict{}".format(self.name),
                                     sub_cmd='sync_weights')
                self.model_q.send(weight_msg)

        if self.train_count % self._log_interval == 0:
            self.stats_deliver.send(self.logger.get_new_info(), block=True)

        self.train_count += 1
def restore(self, weights, is_id=True):
    """
    Restore the weights for all the agents.

    {"agent_id": {"prefix": "actor", "name": "YOUR/PATH/TO/MODEL/FILE.h5"}}
    First, find the prefix; second, find the name of the model file.

    :param weights:
    :param is_id:
    :return:
    """
    self.restore_count += 1
    # fixme: remove model name file, and make sense to multi-agent.
    if is_id:
        model_weights = self.buf_stub.get(weights)
        reduce_msg = message(weights, cmd="buf_reduce")
        self.send_explorer.send(reduce_msg)
    else:
        model_weights = {"data": weights}

    for _ag in self.agents:
        # weights as dict data: deliver the model by weights.
        # A dict would be useful for multi-agent.
        if isinstance(weights, (dict, bytes)):
            # the buffer may return weights with None
            if not model_weights["data"]:  # None, dummy model
                logging.debug("no data in dict, continue!")
                continue
            _ag.alg.restore(model_weights=model_weights["data"])
            if self.env_id < 1:
                logging.debug("ag-{} restore weights t-{}".format(
                    _ag.id, self.restore_count))
            continue

        # 0, default: without a weights map, agents share the same weights
        if not self.alg_weights_map:
            logging.debug("without weights map, use the first weights as default")
            model_name = weights[0]
        # 1, use the weight prefix
        elif self.alg_weights_map[_ag.id].get("prefix"):
            weight_prefix = self.alg_weights_map[_ag.id].get("prefix")
            model_candid = [
                _item for _item in weights
                if os.path.basename(_item).startswith(weight_prefix)
            ]
            model_name = model_candid[0] if len(model_candid) > 0 else None
        # 2, use the model name
        else:
            model_name = self.alg_weights_map[_ag.id].get("name")

        assert model_name is not None, "NO model weight for: {}".format(_ag.id)

        # restore model with agent.alg.restore()
        logging.debug("agent-{} trying to load model: {}".format(
            _ag.id, model_name))
        _ag.alg.restore(model_name)
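# Example alg_weights_map, following the shape given in the docstring above;
# the agent ids and file names are illustrative only.
alg_weights_map = {
    "agent_0": {"prefix": "actor"},                 # matched by file prefix
    "agent_1": {"name": "models/critic_final.h5"},  # matched by exact name
}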
def stop(self):
    """Stop the whole system."""
    close_cmd = message(None, cmd="close")
    self.recv_local_q.send(close_cmd)