class TrainWorker(object):
    """TrainWorker process manages the trajectory data set and the optimizer."""

    def __init__(
            self,
            train_q,
            alg,
            lock,
            model_path,
            model_q,
            s3_path,
            max_step,
            stats_deliver,
            eval_adapter=None,
    ):
        self.train_q = train_q
        self.alg = alg
        self.lock = lock
        self.model_path = model_path
        self.model_q = model_q
        self.actor_reward = dict()
        self.actor_trajectory = dict()
        self.rewards = []
        self.s3_path = s3_path
        self.max_step = max_step
        self.actual_step = 0
        self.won_in_episodes = deque(maxlen=256)
        self.train_count = 0
        self.stats_deliver = stats_deliver
        self.e_adapter = eval_adapter
        self.logger = Logger(os.path.dirname(model_path))
        self._metric = TimerRecorder("leaner_model", maxlen=50,
                                     fields=("fix_weight", "send"))

    def _dist_policy(self, weight=None, save_index=-1, dist_cmd="explore"):
        """Distribute the model."""
        ctr_info = self.alg.dist_model_policy.get_dist_info(save_index)
        if isinstance(ctr_info, dict):
            ctr_info = [ctr_info]

        for _ctr in ctr_info:
            to_send_data = message(weight, cmd=dist_cmd, **_ctr)
            self.model_q.send(to_send_data)

    def _handle_eval_process(self, loss):
        if self.e_adapter and self.e_adapter.if_eval(self.train_count):
            weights = self.alg.get_weights()
            self.e_adapter.to_eval(weights, self.train_count, self.actual_step,
                                   self.logger.elapsed_time,
                                   self.logger.train_reward, loss)
            eval_ret = self.e_adapter.fetch_eval_result()
            if eval_ret:
                logging.debug("eval stats: {}".format(eval_ret))
                self.stats_deliver.send({"data": eval_ret, "is_bm": True}, block=True)

    def train(self):
        """Train model."""
        total_count = 0  # with off-policy algorithms, total count > train count
        save_count = 0

        if not self.alg.async_flag:
            policy_weight = self.alg.get_weights()
            self._dist_policy(weight=policy_weight)

        while True:
            for _tf_val in range(self.alg.prepare_data_times):
                logging.debug("wait data for preparing-{}...".format(_tf_val))
                with self.logger.wait_sample_timer:
                    data = self.train_q.recv()
                with self.logger.prepare_data_timer:
                    data = bytes_to_str(data)
                    self.record_reward(data)
                    self.alg.prepare_data(data["data"], ctr_info=data["ctr_info"])
                logging.debug("Prepared data-{}.".format(_tf_val))

            # support sync model before
            if self.max_step and self.actual_step >= self.max_step:
                break

            total_count += 1
            if not self.alg.train_ready(total_count, dist_dummy_model=self._dist_policy):
                continue

            with self.lock, self.logger.train_timer:
                logging.debug("start train process-{}.".format(self.train_count))
                loss = self.alg.train(episode_num=total_count)
            self.train_count += 1

            if isinstance(loss, (float, np.floating)):
                self.logger.record(train_loss=loss)

            # Distributing the model requires a ready checkpoint.
            # if self.alg.checkpoint_ready(self.train_count):
            with self.lock:
                if self.alg.if_save(self.train_count):
                    _name = self.alg.save(self.model_path, self.train_count)
                    # logging.debug("to save model: {}".format(_name))

            self._handle_eval_process(loss)

            if not self.alg.async_flag and self.alg.checkpoint_ready(
                    self.train_count):
                _save_t1 = time()
                policy_weight = self.alg.get_weights()
                self._metric.append(fix_weight=time() - _save_t1)

                _dist_st = time()
                self._dist_policy(policy_weight, save_count)
                self._metric.append(send=time() - _dist_st)
                self._metric.report_if_need()

                save_count += 1
                if save_count % 5 == 1:
                    self.stats_deliver.send(self.logger.get_new_info(), block=True)

    def record_reward(self, train_data):
        """Record reward in train."""
        broker_id = get_msg_info(train_data, 'broker_id')
        explorer_id = get_msg_info(train_data, 'explorer_id')
        agent_id = get_msg_info(train_data, 'agent_id')
        key = (broker_id, explorer_id, agent_id)

        self.alg.dist_model_policy.add_processed_ctr_info(key)
        data_dict = get_msg_data(train_data)

        # update the multi-agent train reward without a done flag
        if self.alg.alg_name in ("ppo_share_weights",):
            self.actual_step += len(data_dict["done"])
            self.logger.record(
                step=self.actual_step,
                train_reward=np.sum(data_dict["reward"]),
                train_count=self.train_count,
            )
            return
        elif self.alg.alg_name in ("QMixAlg",):  # fixme: unify the record op
            self.actual_step += np.sum(data_dict["filled"])
            self.won_in_episodes.append(data_dict.pop("battle_won"))
            self.logger.update(explore_won_rate=np.nanmean(self.won_in_episodes))
            self.logger.record(
                step=self.actual_step,
                train_reward=np.sum(data_dict["reward"]),
                train_count=self.train_count,
            )
            return

        if key not in self.actor_reward.keys():
            self.actor_reward[key] = 0.0
            self.actor_trajectory[key] = 0

        data_length = len(data_dict["done"])  # fetch the train data length
        for data_index in range(data_length):
            reward = data_dict["reward"][data_index]
            done = data_dict["done"][data_index]
            info = data_dict["info"][data_index]
            self.actual_step += 1
            if isinstance(info, dict):
                self.actor_reward[key] += info.get("eval_reward", reward)
                self.actor_trajectory[key] += 1
                done = info.get("real_done", done)

            if done:
                self.logger.record(
                    step=self.actual_step,
                    train_count=self.train_count,
                    train_reward=self.actor_reward[key],
                )
                logging.debug("{} epi reward-{}. with len-{}".format(
                    key, self.actor_reward[key], self.actor_trajectory[key]))
                self.actor_reward[key] = 0.0
                self.actor_trajectory[key] = 0
class TrainWorker(object):
    """TrainWorker process manages the trajectory data set and the optimizer."""

    def __init__(
            self,
            train_q,
            alg,
            lock,
            model_path,
            model_q,
            s3_path,
            max_step,
            max_episode,
            stats_deliver,
            eval_adapter=None,
            **kwargs,
    ):
        self.train_q = train_q
        self.alg = alg
        self.lock = lock
        self.model_path = model_path
        self.model_q = model_q
        self.actor_reward = defaultdict(float)
        self.actor_trajectory = defaultdict(int)
        self.rewards = []
        self.s3_path = s3_path
        self.max_step = max_step
        self.actual_step = 0
        self.name = kwargs.get('name', 'T0')
        self.max_episode = max_episode
        self.elapsed_episode = 0  # off policy, elapsed_episode > train count
        self.won_in_episodes = deque(maxlen=256)
        self.train_count = 0
        self.stats_deliver = stats_deliver
        self.e_adapter = eval_adapter
        self.logger = Logger(os.path.dirname(model_path))
        self._metric = TimerRecorder("leaner_model", maxlen=50,
                                     fields=("fix_weight", "send"))
        self._log_interval = kwargs["log_interval"]
        self._explorer_ids = None
        self._pbt_aid = None
        self._train_data_counter = defaultdict(int)

    @property
    def explorer_ids(self):
        return self._explorer_ids

    @explorer_ids.setter
    def explorer_ids(self, val):
        self._explorer_ids = val

    @property
    def pbt_aid(self):
        return self._pbt_aid

    @pbt_aid.setter
    def pbt_aid(self, val):
        self._pbt_aid = val

    def _dist_policy(self, weight=None, save_index=-1, dist_cmd="explore"):
        """Distribute the model."""
        explorer_set = self.explorer_ids
        ctr_info = self.alg.dist_model_policy.get_dist_info(save_index, explorer_set)
        if isinstance(ctr_info, dict):
            ctr_info = [ctr_info]

        for _ctr in ctr_info:
            to_send_data = message(weight, cmd=dist_cmd, **_ctr)
            self.model_q.send(to_send_data)

    def _handle_eval_process(self, loss):
        if not self.e_adapter:
            return

        if self.e_adapter.if_eval(self.train_count):
            weights = self.alg.get_weights()
            self.e_adapter.to_eval(weights, self.train_count, self.actual_step,
                                   self.logger.elapsed_time,
                                   self.logger.train_reward, loss)
        elif not self.e_adapter.eval_result_empty:
            eval_ret = self.e_adapter.fetch_eval_result()
            if eval_ret:
                logging.debug("eval stats: {}".format(eval_ret))
                self.stats_deliver.send({"data": eval_ret, "is_bm": True}, block=True)

    def _meet_stop(self):
        if self.max_step and self.actual_step > self.max_step:
            return True
        # Under the pbt setting, max_episode needs to be set in pbt_config,
        # owing to the reset of the episode count after each pbt.exploit.
        if self.max_episode and self.elapsed_episode > self.max_episode:
            return True
        return False

    def train(self):
        """Train model."""
        if not self.alg.async_flag:
            policy_weight = self.alg.get_weights()
            self._dist_policy(weight=policy_weight)

        loss = 0
        while True:
            for _tf_val in range(self.alg.prepare_data_times):
                # logging.debug("wait data for preparing-{}...".format(_tf_val))
                with self.logger.wait_sample_timer:
                    data = self.train_q.recv()
                with self.logger.prepare_data_timer:
                    data = bytes_to_str(data)
                    self.record_reward(data)
                    self.alg.prepare_data(data["data"], ctr_info=data["ctr_info"])
                # dqn-series algorithms count one 'SARSA' tuple as an episode,
                # and the episodic count is used as the train-ready flag.
                # each pbt exploit step resets the episodic count.
                self.elapsed_episode += 1
                # logging.debug("Prepared data-{}.".format(_tf_val))

            # support sync model before
            # run pbt if needed.
            if self.pbt_aid:
                if self.pbt_aid.meet_stop(self.elapsed_episode):
                    break

                cur_info = dict(
                    episodic_reward_mean=self.logger.train_reward_avg,
                    elapsed_step=self.actual_step,
                    elapsed_episode=self.elapsed_episode)
                new_alg = self.pbt_aid.step(cur_info, cur_alg=self.alg)
                if new_alg:  # re-assign the algorithm if needed
                    self.alg = new_alg
                    if not self.alg.async_flag:
                        policy_weight = self.alg.get_weights()
                        self._dist_policy(weight=policy_weight)
                    continue

            if self._meet_stop():
                self.stats_deliver.send(self.logger.get_new_info(), block=True)
                break

            if not self.alg.train_ready(self.elapsed_episode,
                                        dist_dummy_model=self._dist_policy):
                continue

            with self.lock, self.logger.train_timer:
                # logging.debug("start train process-{}.".format(self.train_count))
                loss = self.alg.train(episode_num=self.elapsed_episode)

            if isinstance(loss, (float, np.floating)):
                self.logger.record(train_loss=loss)

            with self.lock:
                if self.alg.if_save(self.train_count):
                    _name = self.alg.save(self.model_path, self.train_count)
                    # logging.debug("to save model: {}".format(_name))

            self._handle_eval_process(loss)

            # Distributing the model requires a ready checkpoint.
            if not self.alg.async_flag and self.alg.checkpoint_ready(
                    self.train_count):
                _save_t1 = time()
                policy_weight = self.alg.get_weights()
                self._metric.append(fix_weight=time() - _save_t1)

                _dist_st = time()
                self._dist_policy(policy_weight, self.train_count)
                self._metric.append(send=time() - _dist_st)
                self._metric.report_if_need()
            else:
                if self.alg.checkpoint_ready(self.train_count):
                    policy_weight = self.alg.get_weights()
                    weight_msg = message(policy_weight,
                                         cmd="predict{}".format(self.name),
                                         sub_cmd='sync_weights')
                    self.model_q.send(weight_msg)

            if self.train_count % self._log_interval == 0:
                self.stats_deliver.send(self.logger.get_new_info(), block=True)

            self.train_count += 1

    def record_reward(self, train_data):
        """Record reward in train."""
        broker_id = get_msg_info(train_data, 'broker_id')
        explorer_id = get_msg_info(train_data, 'explorer_id')
        agent_id = get_msg_info(train_data, 'agent_id')
        key = (broker_id, explorer_id, agent_id)
        # key = learner_stats_id(train_data["ctr_info"])

        # record the train_data received
        self._train_data_counter[key] += 1

        self.alg.dist_model_policy.add_processed_ctr_info(key)
        data_dict = get_msg_data(train_data)

        # update the multi-agent train reward without a done flag
        if self.alg.alg_name in ("QMixAlg",):  # fixme: unify the record op
            self.actual_step += np.sum(data_dict["filled"])
            self.won_in_episodes.append(data_dict.pop("battle_won"))
            self.logger.update(explore_won_rate=np.nanmean(self.won_in_episodes))
            self.logger.record(
                step=self.actual_step,
                train_reward=np.sum(data_dict["reward"]),
                train_count=self.train_count,
            )
            return
        elif self.alg.alg_config['api_type'] == "unified":
            self.actual_step += len(data_dict["done"])
            self.logger.record(
                step=self.actual_step,
                train_reward=np.sum(data_dict["reward"]),
                train_count=self.train_count,
            )
            return

        data_length = len(data_dict["done"])  # fetch the train data length
        for data_index in range(data_length):
            reward = data_dict["reward"][data_index]
            done = data_dict["done"][data_index]
            info = data_dict["info"][data_index]
            self.actual_step += 1
            if isinstance(info, dict):
                self.actor_reward[key] += info.get("eval_reward", reward)
                self.actor_trajectory[key] += 1
                done = info.get("real_done", done)

            if done:
                self.logger.record(
                    step=self.actual_step,
                    train_count=self.train_count,
                    train_reward=self.actor_reward[key],
                    trajectory_length=self.actor_trajectory[key],
                )
                # logging.debug("{} epi reward-{}. with len-{}".format(
                #     key, self.actor_reward[key], self.actor_trajectory[key]))
                self.actor_reward[key] = 0.0
                self.actor_trajectory[key] = 0
class Controller(object):
    """Controller manages the brokers within the learner."""

    def __init__(self, node_config_list):
        self.node_config_list = node_config_list
        self.node_num = len(node_config_list)
        self.recv_broker = UniComm("CommByZmq", type="PULL")
        self.send_broker = [
            UniComm("CommByZmq", type="PUSH") for _i in range(self.node_num)
        ]

        self.port_info = {
            "recv": ast.literal_eval(self.recv_broker.info),
            "send": [ast.literal_eval(_s.info) for _s in self.send_broker]
        }
        port_info_h = pprint.pformat(self.port_info, indent=0, width=1)
        logging.info("Init Broker server info:\n{}\n".format(port_info_h))

        self.recv_local_q = dict()  # UniComm("LocalMsg")
        self.send_local_q = dict()
        self.data_manager = Manager()
        self._data_store = dict()

        self._main_task = list()
        self.metric = TimerRecorder("Controller", maxlen=50, fields=("send", "recv"))
        self.stats = BrokerStats()

        if DebugConf.trace:
            tracemalloc.start()

    def start_data_transfer(self):
        """Start the data transfer and other threads."""
        data_transfer_thread = threading.Thread(target=self.recv_broker_task)
        data_transfer_thread.setDaemon(True)
        data_transfer_thread.start()

        data_transfer_thread = threading.Thread(target=self.recv_local)
        data_transfer_thread.setDaemon(True)
        data_transfer_thread.start()

        # alloc_thread = threading.Thread(target=self.alloc_actor)
        # alloc_thread.setDaemon(True)
        # alloc_thread.start()

    @property
    def tasks(self):
        return self._main_task

    def recv_broker_task(self):
        """Receive remote train data in sync mode."""
        while True:
            ctr_info, recv_data = self.recv_broker.recv_bytes()
            _t0 = time.time()
            ctr_info = deserialize(ctr_info)
            compress_flag = ctr_info.get('compress_flag', False)
            if compress_flag:
                recv_data = lz4.frame.decompress(recv_data)
            recv_data = deserialize(recv_data)
            recv_data = {'ctr_info': ctr_info, 'data': recv_data}
            self.metric.append(recv=time.time() - _t0)

            cmd = get_msg_info(recv_data, "cmd")
            send_cmd = self.send_local_q.get(cmd)
            if send_cmd:
                send_cmd.send(recv_data)
            else:
                logging.warning("invalid cmd: {}, with data: {}".format(cmd, recv_data))

            # report log
            self.metric.report_if_need(field_sets=("send", "recv"))

    def _yield_local_msg(self):
        """Yield the local msg received."""
        while True:
            # poll the whole learner set without waiting.
            for cmd, recv_q in self.recv_local_q.items():
                if 'predict' in cmd:
                    try:
                        recv_data = recv_q.get(block=False)
                    except Exception:
                        recv_data = None
                else:
                    recv_data = recv_q.recv(block=False)

                if recv_data:
                    yield recv_data
                else:
                    time.sleep(0.002)

    def recv_local(self):
        """Receive local cmd."""
        # split the case between a single receive queue and pbt
        if len(self.recv_local_q) == 1:
            single_stub, = self.recv_local_q.values()
        else:
            single_stub = None

        kwargs = init_main_broker_debug_kwargs()
        yield_func = self._yield_local_msg()

        while True:
            if single_stub:
                recv_data = single_stub.recv(block=True)
            else:
                recv_data = next(yield_func)

            cmd = get_msg_info(recv_data, "cmd")
            if cmd in ["close"]:
                self.close(recv_data)
                break

            if cmd in self.send_local_q.keys():
                # print(self.send_local_q.keys())
                self.send_local_q[cmd].send(recv_data)
                # logging.debug("recv: {} with cmd-{}".format(
                #     recv_data["data"], cmd))
            else:
                _t1 = time.time()
                broker_id = get_msg_info(recv_data, "broker_id")
                _cmd = get_msg_info(recv_data, "cmd")
                # logging.debug("ctr recv:{} with cmd:'{}' to broker_id: <{}>".format(
                #     type(recv_data["data"]), _cmd, broker_id))
                # self.metric.append(debug=time.time() - _t1)

                if broker_id == -1:
                    for broker_item, node_info in zip(self.send_broker,
                                                      self.node_config_list):
                        broker_item.send(recv_data['ctr_info'], recv_data['data'])
                else:
                    self.send_broker[broker_id].send(recv_data['ctr_info'],
                                                     recv_data['data'])
                self.metric.append(send=time.time() - _t1)

            debug_within_interval(**kwargs)

    def add_task(self, learner_obj):
        """Add a learner task into the broker."""
        self._main_task.append(learner_obj)

    def register(self, cmd, direction, comm_q=None):
        """Register a cmd, varying with its type.

        :param cmd: name to register.
        :type cmd: str
        :param direction: direction relative to the broker.
        :type direction: str
        :return: stub of the registered local queue.
        :rtype: option
        """
        if not comm_q:
            comm_q = UniComm("LocalMsg")

        if direction == "send":
            self.send_local_q.update({cmd: comm_q})
            return self.send_local_q[cmd]
        elif direction == "recv":
            self.recv_local_q.update({cmd: comm_q})
            return self.recv_local_q[cmd]
        elif direction == "store":
            self._data_store.update({cmd: self.data_manager.dict()})
            return self._data_store[cmd]
        else:
            raise KeyError("invalid register key: {}".format(direction))

    def alloc_actor(self):
        while True:
            time.sleep(10)
            if not self.send_local_q.get("train"):
                continue
            train_list = self.send_local_q["train"].comm.data_list
            if len(train_list) > 200:
                self.send_alloc_msg("decrease")
            elif len(train_list) < 10:
                self.send_alloc_msg("increase")

    def send_alloc_msg(self, actor_status):
        alloc_cmd = {
            "ctr_info": {"cmd": actor_status, "actor_id": -1, "explorer_id": -1},
            "data": None,
        }
        for q in self.send_broker:
            q.send(alloc_cmd['ctr_info'], alloc_cmd['data'])

    def close(self, close_cmd):
        for broker_item in self.send_broker:
            broker_item.send(close_cmd['ctr_info'], close_cmd['data'])

        # closing the ctx may mismatch the socket; rely on the final os._exit.
        # self.recv_broker.close()
        # for _send_stub in self.send_broker:
        #     _send_stub.close()

    def start(self):
        """Start the whole system."""
        setproctitle.setproctitle("xt_main")
        self.start_data_transfer()

    def tasks_loop(self):
        """
        Create the tasks_loop after the messy setup work is ready.

        The foreground task of the Controller.
        """
        if not self._main_task:
            logging.fatal("No learning process is ready!")

        train_thread = [
            threading.Thread(target=task.main_loop) for task in self.tasks
        ]
        for task in train_thread:
            task.start()
            self.stats.add_relation_task(task)

        # check broker stats.
        self.stats.loop()

        # wait for the jobs to end.
        for task in train_thread:
            task.join()

    def stop(self):
        """Stop the whole system."""
        close_cmd = message(None, cmd="close")
        for _learner_id, recv_q in self.recv_local_q.items():
            if 'predict' in _learner_id:
                continue
            recv_q.send(close_cmd)
        time.sleep(0.1)
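A minimal wiring sketch for the Controller above, relying on the class and helpers defined in this file. The register() directions and the start()/tasks_loop() sequence are taken from the class; DummyLearner and the node_config_list entry are hypothetical stand-ins.

# Hypothetical Controller wiring; DummyLearner only models the main_loop
# attribute that tasks_loop() runs in one thread per registered task.
class DummyLearner(object):
    def __init__(self, train_q):
        self.train_q = train_q

    def main_loop(self):
        # blocks until recv_broker_task() routes a "train" message here
        data = self.train_q.recv()
        logging.info("got cmd: {}".format(get_msg_info(data, "cmd")))


controller = Controller(node_config_list=[{"node_ip": "127.0.0.1"}])
train_q = controller.register("train", direction="send")  # broker -> learner
controller.add_task(DummyLearner(train_q))
controller.start()        # background receive threads
controller.tasks_loop()   # foreground: learner main loops + broker stats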
class Broker(object):
    """Broker manages the explorers on each node."""

    def __init__(self, ip_addr, broker_id, push_port, pull_port):
        self.broker_id = broker_id
        self.send_controller_q = UniComm("CommByZmq", type="PUSH",
                                         addr=ip_addr, port=push_port)
        self.recv_controller_q = UniComm("CommByZmq", type="PULL",
                                         addr=ip_addr, port=pull_port)

        self.recv_explorer_q_ready = False
        # record the information between explorer and learner
        # {"learner_id": UniComm("ShareByPlasma")}
        # add {"default_eval": UniComm("ShareByPlasma")}
        self.explorer_share_qs = {"EVAL0": None}
        # {"recv_id": receive_count} --> {("recv_id", "explorer_id"): count}
        self.explorer_stats = defaultdict(int)

        self.send_explorer_q = dict()
        self.explore_process = dict()
        self.processes_suspend = 0
        logging.info("init broker with id-{}".format(self.broker_id))
        self._metric = TimerRecorder("broker", maxlen=50, fields=("send",))

        # Note: needs re-checking if explorers are added dynamically.
        # buffer size varies with env_num & algorithm;
        # ~4MB for an impala atari model.
        self._buf = ShareBuf(live=0, size=400000000, max_keep=94, start=True)

    def start_data_transfer(self):
        """Start the data transfer and other threads."""
        data_transfer_thread = threading.Thread(target=self.recv_controller_task)
        data_transfer_thread.start()

        data_transfer_thread = threading.Thread(target=self.recv_explorer)
        data_transfer_thread.start()

    def _setup_share_qs_firstly(self, config_info):
        """Set up the share queues only once."""
        if self.recv_explorer_q_ready:
            return

        _use_pbt, pbt_size, env_num, _ = get_pbt_set(config_info)
        plasma_size = config_info.get("plasma_size", 100000000)
        for i in range(pbt_size):
            plasma_path = "/tmp/plasma{}T{}".format(os.getpid(), i)
            self.explorer_share_qs["T{}".format(i)] = UniComm(
                "ShareByPlasma", size=plasma_size, path=plasma_path)
            # print("broker pid:", os.getpid())
            # self.explorer_share_qs["T{}".format(i)].comm.connect()

        # with pbt, the eval process shares a single server;
        # else, it shares with T0.
        if not _use_pbt:
            self.explorer_share_qs["EVAL0"] = self.explorer_share_qs["T0"]
        else:
            # with pbt, the server will be set within create_evaluator
            self.explorer_share_qs["EVAL0"] = None

        self.recv_explorer_q_ready = True

    def recv_controller_task(self):
        """Recv remote train data in sync mode."""
        while True:
            # received data deserializes with the pyarrow default
            # recv_data = self.recv_controller_q.recv()
            ctr_info, data = self.recv_controller_q.recv_bytes()
            recv_data = {'ctr_info': deserialize(ctr_info), 'data': deserialize(data)}
            cmd = get_msg_info(recv_data, "cmd")
            if cmd in ["close"]:
                self.close(recv_data)

            if cmd in ["create_explorer"]:
                config_set = recv_data["data"]
                # set up plasma only once!
                self._setup_share_qs_firstly(config_set)

                config_set.update({"share_path": self._buf.get_path()})
                self.create_explorer(config_set)
                # update the buffer live attribute, explorer num as default.
                # self._buf.plus_one_live()
                continue

            if cmd in ["create_evaluator"]:
                # the evaluator shares a single plasma.
                config_set = recv_data["data"]
                if not self.explorer_share_qs["EVAL0"]:
                    use_pbt, _, _, _ = get_pbt_set(config_set)
                    if use_pbt:
                        plasma_size = config_set.get("plasma_size", 100000000)
                        plasma_path = "/tmp/plasma{}EVAL0".format(os.getpid())
                        self.explorer_share_qs["EVAL0"] = UniComm(
                            "ShareByPlasma", size=plasma_size, path=plasma_path)

                config_set.update({"share_path": self._buf.get_path()})
                logging.debug("create evaluator with config:{}".format(config_set))
                self.create_evaluator(config_set)
                # self._buf.plus_one_live()
                continue

            if cmd in ["increase", "decrease"]:
                self.alloc(cmd)
                continue

            if cmd in ("eval",):  # fixme: merge into explore
                test_id = get_msg_info(recv_data, "test_id")
                self.send_explorer_q[test_id].put(recv_data)
                continue

            # last job: distribute weights/model_name from the controller
            explorer_id = get_msg_info(recv_data, "explorer_id")
            if not isinstance(explorer_id, list):
                explorer_id = [explorer_id]

            _t0 = time.time()
            if cmd in ("explore",):  # todo: could move to first priority.
                # here, only handle explore weights
                buf_id = self._buf.put(data)
                # replace the weight with its buffer id
                recv_data.update({"data": buf_id})

            # predict_reply
            # e.g. {'ctr_info': {'broker_id': 0, 'explorer_id': 4, 'agent_id': -1,
            #       'cmd': 'predict_reply'}, 'data': 0}
            for _eid in explorer_id:
                if _eid > -1:
                    self.send_explorer_q[_eid].put(recv_data)
                elif _eid > -2:  # -1: the whole explorer set, contains the evaluator!
                    for qid, send_q in self.send_explorer_q.items():
                        if isinstance(qid, str) and "test" in qid:
                            # logging.info("continue test: ", qid, send_q)
                            continue
                        send_q.put(recv_data)
                else:
                    raise KeyError("invalid explore id: {}".format(_eid))

            self._metric.append(send=time.time() - _t0)
            self._metric.report_if_need()

    @staticmethod
    def _handle_data(ctr_info, data, explorer_stub, broker_stub):
        object_id = ctr_info['object_id']
        ctr_info_data = ctr_info['ctr_info_data']
        broker_stub.send_bytes(ctr_info_data, data)
        explorer_stub.delete(object_id)

    def _step_explorer_msg(self, use_single_stub):
        """Yield the local msg received."""
        if use_single_stub:
            # the whole explorer set shares a single plasma
            recv_id = "T0"
            single_stub = self.explorer_share_qs[recv_id]
        else:
            single_stub, recv_id = None, None

        while True:
            if use_single_stub:
                ctr_info, data = single_stub.recv_bytes(block=True)
                self._handle_data(ctr_info, data, single_stub, self.send_controller_q)
                yield recv_id, ctr_info
            else:
                # poll the whole learner set without waiting.
                for recv_id, recv_q in self.explorer_share_qs.items():
                    if not recv_q:  # handle eval dummy q
                        continue
                    ctr_info, data = recv_q.recv_bytes(block=False)
                    if not ctr_info:  # this receive q is not ready yet!
                        time.sleep(0.002)
                        continue
                    self._handle_data(ctr_info, data, recv_q, self.send_controller_q)
                    yield recv_id, ctr_info

    def recv_explorer(self):
        """Recv explorer cmd."""
        while not self.recv_explorer_q_ready:
            # wait for the explorer share buffer to be ready
            time.sleep(0.05)

        # the whole explorer set shares a single plasma
        use_single_flag = len(self.explorer_share_qs) == 2
        yield_func = self._step_explorer_msg(use_single_flag)

        while True:
            recv_id, _info = next(yield_func)
            _id = stats_id(_info)
            self.explorer_stats[_id] += 1
            debug_within_interval(logs=dict(self.explorer_stats),
                                  interval=DebugConf.interval_s,
                                  human_able=True)

    def create_explorer(self, config_info):
        """Create explorer."""
        env_para = config_info.get("env_para")
        env_num = config_info.get("env_num")
        speedup = config_info.get("speedup", True)
        start_core = config_info.get("start_core", 1)
        env_id = env_para.get("env_id")  # used as the explorer id.
        ref_learner_id = config_info.get("learner_postfix")

        send_explorer = Queue()
        explorer = Explorer(
            config_info,
            self.broker_id,
            recv_broker=send_explorer,
            send_broker=self.explorer_share_qs[ref_learner_id],
        )

        p = Process(target=explorer.start)
        p.start()

        cpu_count = psutil.cpu_count()
        if speedup and cpu_count > (env_num + start_core):
            _p = psutil.Process(p.pid)
            _p.cpu_affinity([start_core + env_id])

        self.send_explorer_q.update({env_id: send_explorer})
        self.explore_process.update({env_id: p})

    def create_evaluator(self, config_info):
        """Create evaluator."""
        test_id = config_info.get("test_id")
        send_evaluator = Queue()
        evaluator = Evaluator(
            config_info,
            self.broker_id,
            recv_broker=send_evaluator,
            send_broker=self.explorer_share_qs["EVAL0"],
        )

        p = Process(target=evaluator.start)
        p.start()

        speedup = config_info.get("speedup", False)
        start_core = config_info.get("start_core", 1)
        eval_num = config_info.get("benchmark", {}).get("eval", {}).get("evaluator_num", 1)
        env_num = config_info.get("env_num")
        core_set = env_num + start_core

        cpu_count = psutil.cpu_count()
        if speedup and cpu_count > (env_num + eval_num + start_core):
            _p = psutil.Process(p.pid)
            _p.cpu_affinity([core_set])
            core_set += 1

        self.send_explorer_q.update({test_id: send_evaluator})
        self.explore_process.update({test_id: p})

    def alloc(self, actor_status):
        """Monitor the system and adjust resources."""
        p_id = [_p.pid for _, _p in self.explore_process.items()]
        p = [psutil.Process(_pid) for _pid in p_id]

        if actor_status == "decrease":
            if self.processes_suspend < len(p):
                p[self.processes_suspend].suspend()
                self.processes_suspend += 1
        elif actor_status == "increase":
            if self.processes_suspend >= 1:
                p[self.processes_suspend - 1].resume()
                self.processes_suspend -= 1
        elif actor_status == "reset":
            # resume all suspended processes
            for _, resume_p in enumerate(p):
                resume_p.resume()

    def close(self, close_cmd):
        """Close broker."""
        for _, send_q in self.send_explorer_q.items():
            send_q.put(close_cmd)
        time.sleep(2)

        for _, p in self.explore_process.items():
            if p.exitcode is None:
                p.terminate()

        # self.send_controller_q.close()
        # self.recv_controller_q.close()
        os.system("pkill plasma -g " + str(os.getpgid(0)))
        os._exit(0)

    def start(self):
        """Start the whole system."""
        setproctitle.setproctitle("xt_broker")
        self.start_data_transfer()
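The increase/decrease path above throttles rollouts by freezing whole explorer processes rather than messaging them. Below is a self-contained sketch of the same psutil suspend/resume mechanic; the busy-loop worker is illustrative, not from the source.

# Self-contained demo of the suspend/resume mechanic used by alloc().
import time
from multiprocessing import Process

import psutil


def busy_worker():
    while True:
        time.sleep(0.1)


if __name__ == "__main__":
    p = Process(target=busy_worker)
    p.start()
    handle = psutil.Process(p.pid)
    handle.suspend()        # "decrease": freeze the explorer (SIGSTOP on POSIX)
    assert handle.status() == psutil.STATUS_STOPPED
    handle.resume()         # "increase": thaw it again (SIGCONT)
    p.terminate()
    p.join()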
class BrokerMaster(object):
    """BrokerMaster manages the brokers within the learner."""

    def __init__(self, node_config_list, start_port=None):
        self.node_config_list = node_config_list
        self.node_num = len(node_config_list)
        comm_conf = None
        if not start_port:
            comm_conf = CommConf()
            start_port = comm_conf.get_start_port()
        self.start_port = start_port
        logging.info("master broker init on port: {}".format(start_port))
        self.comm_conf = comm_conf

        recv_port, send_port = get_port(start_port)
        self.recv_slave = UniComm("CommByZmq", type="PULL", port=recv_port)
        self.send_slave = [
            UniComm("CommByZmq", type="PUSH", port=send_port + i)
            for i in range(self.node_num)
        ]

        self.recv_local_q = UniComm("LocalMsg")
        self.send_local_q = dict()
        self.main_task = None
        self.metric = TimerRecorder("master", maxlen=50, fields=("send", "recv"))

    def start_data_transfer(self):
        """Start the data transfer and other threads."""
        data_transfer_thread = threading.Thread(target=self.recv_broker_slave)
        data_transfer_thread.setDaemon(True)
        data_transfer_thread.start()

        data_transfer_thread = threading.Thread(target=self.recv_local)
        data_transfer_thread.setDaemon(True)
        data_transfer_thread.start()

        # alloc_thread = threading.Thread(target=self.alloc_actor)
        # alloc_thread.setDaemon(True)
        # alloc_thread.start()

    def recv_broker_slave(self):
        """Receive remote train data in sync mode."""
        while True:
            recv_data = self.recv_slave.recv_bytes()
            _t0 = time.time()
            recv_data = deserialize(lz4.frame.decompress(recv_data))
            self.metric.append(recv=time.time() - _t0)
            cmd = get_msg_info(recv_data, "cmd")
            send_cmd = self.send_local_q.get(cmd)
            if send_cmd:
                send_cmd.send(recv_data)

            # report log
            self.metric.report_if_need()

    def recv_local(self):
        """Receive local cmd."""
        while True:
            recv_data = self.recv_local_q.recv()
            cmd = get_msg_info(recv_data, "cmd")
            if cmd in ["close"]:
                self.close(recv_data)

            if cmd in self.send_local_q.keys():
                self.send_local_q[cmd].send(recv_data)
                logging.debug("recv: {} with cmd-{}".format(
                    type(recv_data["data"]), cmd))
            else:
                _t1 = time.time()
                broker_id = get_msg_info(recv_data, "broker_id")
                _cmd = get_msg_info(recv_data, "cmd")
                logging.debug("master recv:{} with cmd:'{}' to broker_id: <{}>".format(
                    type(recv_data["data"]), _cmd, broker_id))
                # self.metric.append(debug=time.time() - _t1)
                if broker_id == -1:
                    for slave, node_info in zip(self.send_slave, self.node_config_list):
                        slave.send(recv_data)
                else:
                    self.send_slave[broker_id].send(recv_data)
                self.metric.append(send=time.time() - _t1)

    def register(self, cmd):
        self.send_local_q.update({cmd: UniComm("LocalMsg")})
        return self.send_local_q[cmd]

    def alloc_actor(self):
        while True:
            time.sleep(10)
            if not self.send_local_q.get("train"):
                continue
            train_list = self.send_local_q["train"].comm.data_list
            if len(train_list) > 200:
                self.send_alloc_msg("decrease")
            elif len(train_list) < 10:
                self.send_alloc_msg("increase")

    def send_alloc_msg(self, actor_status):
        alloc_cmd = {
            "ctr_info": {"cmd": actor_status, "actor_id": -1, "explorer_id": -1}
        }
        for q in self.send_slave:
            q.send(alloc_cmd)

    def close(self, close_cmd):
        for slave in self.send_slave:
            slave.send(close_cmd)
        time.sleep(1)
        try:
            self.comm_conf.release_start_port(self.start_port)
        except BaseException:
            pass
        os._exit(0)

    def start(self):
        """Start the whole system."""
        self.start_data_transfer()

    def main_loop(self):
        """
        Create the main_loop after the messy setup work is ready.

        The foreground task of the broker master.
        """
        if not self.main_task:
            logging.fatal("learning process isn't ready!")
        self.main_task.main_loop()

    def stop(self):
        """Stop the whole system."""
        close_cmd = message(None, cmd="close")
        self.recv_local_q.send(close_cmd)
class BrokerSlave(object):
    """BrokerSlave manages the explorers on each node."""

    def __init__(self, ip_addr, broker_id, start_port):
        self.broker_id = broker_id
        train_port, predict_port = get_port(start_port)
        self.send_master_q = UniComm(
            "CommByZmq", type="PUSH", addr=ip_addr, port=train_port
        )
        self.recv_master_q = UniComm(
            "CommByZmq", type="PULL", addr=ip_addr, port=predict_port + broker_id
        )
        self.recv_explorer_q = UniComm("ShareByPlasma")

        self.send_explorer_q = dict()
        self.explore_process = dict()
        self.processes_suspend = 0
        logging.info("init broker slave with id-{}".format(self.broker_id))
        self._metric = TimerRecorder("broker_slave", maxlen=50, fields=("send",))

        # Note: needs re-checking if explorers are added dynamically.
        self._buf = ShareBuf(live=0, start=True)

    def start_data_transfer(self):
        """Start the data transfer and other threads."""
        data_transfer_thread = threading.Thread(target=self.recv_master)
        data_transfer_thread.start()

        data_transfer_thread = threading.Thread(target=self.recv_explorer)
        data_transfer_thread.start()

    def recv_master(self):
        """Recv remote train data in sync mode."""
        while True:
            # received data deserializes with the pyarrow default
            # recv_data = self.recv_master_q.recv()
            recv_bytes = self.recv_master_q.recv_bytes()
            recv_data = deserialize(recv_bytes)
            cmd = get_msg_info(recv_data, "cmd")
            if cmd in ["close"]:
                self.close(recv_data)

            if cmd in ["create_explorer"]:
                config_set = recv_data["data"]
                config_set.update({"share_path": self._buf.get_path()})
                self.create_explorer(config_set)
                # update the buffer live attribute, explorer num as default.
                self._buf.plus_one_live()
                continue

            if cmd in ["create_evaluator"]:
                config_set = recv_data["data"]
                config_set.update({"share_path": self._buf.get_path()})
                logging.debug("create evaluator with config:{}".format(config_set))
                self.create_evaluator(config_set)
                self._buf.plus_one_live()
                continue

            if cmd in ["increase", "decrease"]:
                self.alloc(cmd)
                continue

            if cmd in ("eval",):  # fixme: merge into explore
                test_id = get_msg_info(recv_data, "test_id")
                self.send_explorer_q[test_id].put(recv_data)
                continue

            # last job: distribute weights/model_name from the broker master
            explorer_id = get_msg_info(recv_data, "explorer_id")
            if not isinstance(explorer_id, list):
                explorer_id = [explorer_id]

            _t0 = time.time()
            if "explore" in cmd:
                # here, only handle explore weights
                buf_id = self._buf.put(recv_bytes)
                # replace the weight with its buffer id
                recv_data.update({"data": buf_id})

            # predict_reply
            for _eid in explorer_id:
                if _eid > -1:
                    self.send_explorer_q[_eid].put(recv_data)
                elif _eid > -2:  # -1: the whole explorer set, contains the evaluator!
                    for qid, send_q in self.send_explorer_q.items():
                        if isinstance(qid, str) and "test" in qid:
                            # logging.info("continue test: ", qid, send_q)
                            continue
                        send_q.put(recv_data)
                else:
                    raise KeyError("invalid explore id: {}".format(_eid))

            self._metric.append(send=time.time() - _t0)
            self._metric.report_if_need()

    def recv_explorer(self):
        """Recv explorer cmd."""
        while True:
            data, object_info = self.recv_explorer_q.recv_bytes()
            object_id, data_type = object_info
            if data_type == "data":
                self.send_master_q.send_bytes(data)
            elif data_type == "buf_reduce":
                recv_data = deserialize(data)
                to_reduce_id = get_msg_data(recv_data)
                self._buf.reduce_once(to_reduce_id)
            else:
                raise KeyError("unknown data type: {}".format(data_type))

            self.recv_explorer_q.delete(object_id)

    def create_explorer(self, config_info):
        """Create explorer."""
        env_para = config_info.get("env_para")
        env_id = env_para.get("env_id")

        send_explorer = Queue()
        explorer = Explorer(
            config_info,
            self.broker_id,
            recv_broker=send_explorer,
            send_broker=self.recv_explorer_q,
        )

        p = Process(target=explorer.start)
        p.daemon = True
        p.start()

        self.send_explorer_q.update({env_id: send_explorer})
        self.explore_process.update({env_id: p})

    def create_evaluator(self, config_info):
        """Create evaluator."""
        test_id = config_info.get("test_id")

        send_evaluator = Queue()
        evaluator = Evaluator(
            config_info,
            self.broker_id,
            recv_broker=send_evaluator,
            send_broker=self.recv_explorer_q,
        )

        p = Process(target=evaluator.start)
        p.daemon = True
        p.start()

        self.send_explorer_q.update({test_id: send_evaluator})
        self.explore_process.update({test_id: p})

    def alloc(self, actor_status):
        """Monitor the system and adjust resources."""
        p_id = [_p.pid for _, _p in self.explore_process.items()]
        p = [psutil.Process(_pid) for _pid in p_id]

        if actor_status == "decrease":
            if self.processes_suspend < len(p):
                p[self.processes_suspend].suspend()
                self.processes_suspend += 1
        elif actor_status == "increase":
            if self.processes_suspend >= 1:
                p[self.processes_suspend - 1].resume()
                self.processes_suspend -= 1
        elif actor_status == "reset":
            # resume all suspended processes
            for _, resume_p in enumerate(p):
                resume_p.resume()

    def close(self, close_cmd):
        """Close broker."""
        for _, send_q in self.send_explorer_q.items():
            send_q.put(close_cmd)
        time.sleep(5)

        for _, p in self.explore_process.items():
            if p.exitcode is None:
                p.terminate()

        os.system("pkill plasma -g " + str(os.getpgid(0)))
        os._exit(0)

    def start(self):
        """Start the whole system."""
        self.start_data_transfer()
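BrokerMaster and BrokerSlave derive their sockets from the same start_port via get_port, so each slave's PUSH pairs with the master's PULL, and the master's i-th PUSH pairs with the slave whose broker_id is i. A sketch of that pairing, assuming for illustration only that get_port(start) returns (start, start + 1); the real layout lives in the framework's get_port.

# The (start, start + 1) layout of get_port() is an assumption made for this
# sketch; only the pairing of the two call sites comes from the classes above.
def get_port(start_port):
    return start_port, start_port + 1


start = 20000
recv_port, send_port = get_port(start)             # BrokerMaster side
for broker_id in range(3):                         # three BrokerSlave nodes
    train_port, predict_port = get_port(start)     # BrokerSlave side
    # slave PUSH -> master PULL: one shared train channel
    assert train_port == recv_port
    # slave PULL on predict_port + broker_id <- master PUSH on send_port + i
    assert predict_port + broker_id == send_port + broker_id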