Ejemplo n.º 1
0
    def __init__(
        self,
        train_q,
        alg,
        lock,
        model_path,
        model_q,
        s3_path,
        max_step,
        stats_deliver,
        eval_adapter=None,
    ):
        """Build a training worker around a single algorithm instance.

        :param train_q: channel the worker receives training data from.
        :param alg: algorithm object being trained.
        :param lock: lock object stored for the training loop.
        :param model_path: checkpoint path; its directory is given to Logger.
        :param model_q: channel used to send model messages.
        :param s3_path: remote storage path (only stored here).
        :param max_step: step budget, stored for the training loop.
        :param stats_deliver: channel for statistics records.
        :param eval_adapter: optional evaluation adapter (``None`` disables).
        """
        # Collaborators provided by the caller.
        self.train_q, self.alg, self.lock = train_q, alg, lock
        self.model_path, self.model_q = model_path, model_q
        self.s3_path = s3_path
        self.max_step = max_step
        self.stats_deliver = stats_deliver
        self.e_adapter = eval_adapter

        # Per-actor accumulators and global progress counters.
        self.actor_reward = {}
        self.actor_trajectory = {}
        self.rewards = []
        self.actual_step = 0
        self.train_count = 0
        # Bounded window holding the latest 256 win flags.
        self.won_in_episodes = deque(maxlen=256)

        self.logger = Logger(os.path.dirname(model_path))
        self._metric = TimerRecorder(
            "leaner_model", maxlen=50, fields=("fix_weight", "send"))
Ejemplo n.º 2
0
    def __init__(self, node_config_list):
        """Open the broker sockets and create the local registries.

        :param node_config_list: one entry per remote node; a PUSH socket is
            created for each entry.
        """
        self.node_config_list = node_config_list
        self.node_num = len(node_config_list)

        # A single PULL socket for incoming data, one PUSH socket per node.
        self.recv_broker = UniComm("CommByZmq", type="PULL")
        self.send_broker = [UniComm("CommByZmq", type="PUSH")
                            for _node in range(self.node_num)]

        recv_info = ast.literal_eval(self.recv_broker.info)
        send_info = [ast.literal_eval(stub.info) for stub in self.send_broker]
        self.port_info = {"recv": recv_info, "send": send_info}

        readable = pprint.pformat(self.port_info, indent=0, width=1)
        logging.info("Init Broker server info:\n{}\n".format(readable))

        # Local queues keyed by command name.
        self.recv_local_q = {}
        self.send_local_q = {}

        # Manager used to hand out shared dicts; the dicts are kept here.
        self.data_manager = Manager()
        self._data_store = {}

        self._main_task = []
        self.metric = TimerRecorder("Controller", maxlen=50,
                                    fields=("send", "recv"))
        self.stats = BrokerStats()

        # Optional allocation tracing for debugging.
        if DebugConf.trace:
            tracemalloc.start()
Ejemplo n.º 3
0
    def __init__(self, ip_addr, broker_id, push_port, pull_port):
        """Connect to the controller and prepare explorer-side bookkeeping.

        :param ip_addr: controller address for both sockets.
        :param broker_id: id of this broker (used in the init log line).
        :param push_port: port of the PUSH socket towards the controller.
        :param pull_port: port of the PULL socket from the controller.
        """
        self.broker_id = broker_id

        def _zmq(kind, port):
            return UniComm("CommByZmq", type=kind, addr=ip_addr, port=port)

        self.send_controller_q = _zmq("PUSH", push_port)
        self.recv_controller_q = _zmq("PULL", pull_port)

        # Share queues between explorer and learner, keyed by learner id
        # (e.g. {"learner_id": UniComm("ShareByPlasma")}); pre-seeded with an
        # "EVAL0" slot, which the original notes describe as the default eval.
        self.recv_explorer_q_ready = False
        self.explorer_share_qs = {"EVAL0": None}

        # Receive counters, keyed by ("recv_id", "explorer_id").
        self.explorer_stats = defaultdict(int)

        self.send_explorer_q = {}
        self.explore_process = {}
        self.processes_suspend = 0
        logging.info("init broker with id-{}".format(self.broker_id))
        self._metric = TimerRecorder("broker", maxlen=50, fields=("send", ))
        # NOTE: revisit if explorers can be added dynamically. The buffer
        # size depends on env_num and the algorithm; the original sizing note
        # mentions an impala atari model of roughly 4M.
        self._buf = ShareBuf(live=0, size=400000000, max_keep=94, start=True)
Ejemplo n.º 4
0
    def __init__(
        self,
        train_q,
        alg,
        lock,
        model_path,
        model_q,
        s3_path,
        max_step,
        max_episode,
        stats_deliver,
        eval_adapter=None,
        **kwargs,
    ):
        """Build a training worker with both step and episode budgets.

        :param max_step: step budget, stored for the training loop.
        :param max_episode: episode budget, stored for the training loop.
        :param kwargs: must contain ``log_interval``; may contain ``name``
            (defaults to ``'T0'``).
        :raises KeyError: when ``log_interval`` is absent from ``kwargs``.
        """
        # Channels and collaborators.
        self.train_q, self.model_q = train_q, model_q
        self.alg, self.lock = alg, lock
        self.model_path, self.s3_path = model_path, s3_path
        self.stats_deliver = stats_deliver
        self.e_adapter = eval_adapter
        self.name = kwargs.get('name', 'T0')

        # Budgets and progress counters.
        self.max_step = max_step
        self.max_episode = max_episode
        self.actual_step = 0
        # For off-policy algorithms elapsed_episode outgrows train_count.
        self.elapsed_episode = 0
        self.train_count = 0

        # Per-actor accumulators; unseen keys start at zero.
        self.actor_reward = defaultdict(float)
        self.actor_trajectory = defaultdict(int)
        self.rewards = []
        self.won_in_episodes = deque(maxlen=256)

        self.logger = Logger(os.path.dirname(model_path))
        self._metric = TimerRecorder(
            "leaner_model", maxlen=50, fields=("fix_weight", "send"))

        # Looked up after the logger is built, exactly as before.
        self._log_interval = kwargs["log_interval"]

        # Set later by the owner.
        self._explorer_ids = None
        self._pbt_aid = None
        self._train_data_counter = defaultdict(int)
Ejemplo n.º 5
0
    def __init__(self, ip_addr, broker_id, start_port):
        """Connect this slave broker to the master.

        :param ip_addr: master address.
        :param broker_id: id of this slave; also offsets its PULL port.
        :param start_port: base port passed to ``get_port``.
        """
        self.broker_id = broker_id
        train_port, predict_port = get_port(start_port)

        # PUSH towards the master on the first port from get_port ...
        self.send_master_q = UniComm("CommByZmq", type="PUSH",
                                     addr=ip_addr, port=train_port)
        # ... and PULL on the second port shifted by this broker's id, so
        # every slave gets its own listening port.
        own_port = predict_port + broker_id
        self.recv_master_q = UniComm("CommByZmq", type="PULL",
                                     addr=ip_addr, port=own_port)

        self.recv_explorer_q = UniComm("ShareByPlasma")
        self.send_explorer_q = {}
        self.explore_process = {}
        self.processes_suspend = 0
        logging.info("init broker slave with id-{}".format(self.broker_id))
        self._metric = TimerRecorder("broker_slave", maxlen=50,
                                     fields=("send",))
        # NOTE: revisit if explorers can be added dynamically.
        self._buf = ShareBuf(live=0, start=True)
Ejemplo n.º 6
0
    def __init__(self, node_config_list, start_port=None):
        """Set up the master broker sockets.

        :param node_config_list: one entry per slave node.
        :param start_port: base port; when falsy it is read from ``CommConf``.
        """
        self.node_config_list = node_config_list
        self.node_num = len(node_config_list)

        # Only consult CommConf when no usable start port was given.
        comm_conf = None if start_port else CommConf()
        if comm_conf is not None:
            start_port = comm_conf.get_start_port()
        self.start_port = start_port
        logging.info("master broker init on port: {}".format(start_port))
        self.comm_conf = comm_conf

        recv_port, send_port = get_port(start_port)
        self.recv_slave = UniComm("CommByZmq", type="PULL", port=recv_port)
        # PUSH socket number ``offset`` uses port ``send_port + offset``.
        self.send_slave = [
            UniComm("CommByZmq", type="PUSH", port=send_port + offset)
            for offset in range(self.node_num)
        ]

        self.recv_local_q = UniComm("LocalMsg")
        self.send_local_q = {}

        self.main_task = None
        self.metric = TimerRecorder("master", maxlen=50,
                                    fields=("send", "recv"))
Ejemplo n.º 7
0
class TrainWorker(object):
    """TrainWorker Process manage the trajectory data set and optimizer.

    Trajectory messages are received from ``train_q`` and fed to ``alg``.
    The worker trains, saves checkpoints under ``model_path`` and sends
    policy weights through ``model_q``.
    """

    def __init__(
        self,
        train_q,
        alg,
        lock,
        model_path,
        model_q,
        s3_path,
        max_step,
        stats_deliver,
        eval_adapter=None,
    ):
        self.train_q = train_q
        self.alg = alg
        self.lock = lock
        self.model_path = model_path
        self.model_q = model_q
        # Running episode reward / length per (broker, explorer, agent) key.
        self.actor_reward = dict()
        self.actor_trajectory = dict()
        self.rewards = []
        self.s3_path = s3_path
        self.max_step = max_step  # falsy means "no step limit" in train()
        self.actual_step = 0
        # Recent QMix battle outcomes used for the explore win rate.
        self.won_in_episodes = deque(maxlen=256)
        self.train_count = 0

        self.stats_deliver = stats_deliver
        self.e_adapter = eval_adapter

        self.logger = Logger(os.path.dirname(model_path))
        self._metric = TimerRecorder("leaner_model",
                                     maxlen=50,
                                     fields=("fix_weight", "send"))

    @staticmethod
    def _is_float_loss(loss):
        """Return True when ``loss`` is a scalar float worth recording.

        Accepts Python floats and every NumPy floating scalar. The old check
        listed ``np.float``, an alias of the builtin ``float`` that was
        removed in NumPy 1.24 and raises AttributeError there.
        """
        return isinstance(loss, (float, np.floating))

    def _dist_policy(self, weight=None, save_index=-1, dist_cmd="explore"):
        """Distribute model tool.

        Wraps ``weight`` in a ``dist_cmd`` message for every target returned
        by the algorithm's distribution policy and sends it on ``model_q``.
        """
        ctr_info = self.alg.dist_model_policy.get_dist_info(save_index)

        # The policy may return one target dict or a list of them.
        if isinstance(ctr_info, dict):
            ctr_info = [ctr_info]

        for _ctr in ctr_info:
            to_send_data = message(weight, cmd=dist_cmd, **_ctr)
            self.model_q.send(to_send_data)

    def _handle_eval_process(self, loss):
        """Hand the current weights to the evaluator when it asks for them.

        Any result fetched right afterwards is forwarded to
        ``stats_deliver`` as a benchmark record (``is_bm=True``).
        """
        if not (self.e_adapter and self.e_adapter.if_eval(self.train_count)):
            return

        weights = self.alg.get_weights()
        self.e_adapter.to_eval(weights, self.train_count, self.actual_step,
                               self.logger.elapsed_time,
                               self.logger.train_reward, loss)
        eval_ret = self.e_adapter.fetch_eval_result()
        if eval_ret:
            logging.debug("eval stats: %s", eval_ret)
            self.stats_deliver.send({"data": eval_ret, "is_bm": True},
                                    block=True)

    def train(self):
        """Train model.

        Loops forever (or until ``actual_step`` reaches a truthy
        ``max_step``). Each round receives ``alg.prepare_data_times``
        messages, trains when the algorithm is ready, saves, evaluates and,
        for synchronous algorithms, distributes the new weights.
        """
        total_count = 0  # on off-policy algorithms total_count > train_count
        save_count = 0

        if not self.alg.async_flag:
            policy_weight = self.alg.get_weights()
            self._dist_policy(weight=policy_weight)

        while True:
            for _tf_val in range(self.alg.prepare_data_times):
                logging.debug("wait data for preparing-%s...", _tf_val)
                with self.logger.wait_sample_timer:
                    data = self.train_q.recv()
                with self.logger.prepare_data_timer:
                    data = bytes_to_str(data)
                    self.record_reward(data)
                    self.alg.prepare_data(data["data"],
                                          ctr_info=data["ctr_info"])
                logging.debug("Prepared data-%s.", _tf_val)

            if self.max_step and self.actual_step >= self.max_step:
                break

            total_count += 1
            if not self.alg.train_ready(total_count,
                                        dist_dummy_model=self._dist_policy):
                continue

            with self.lock, self.logger.train_timer:
                logging.debug("start train process-%s.", self.train_count)
                loss = self.alg.train(episode_num=total_count)
                self.train_count += 1

            if self._is_float_loss(loss):
                self.logger.record(train_loss=loss)

            with self.lock:
                if self.alg.if_save(self.train_count):
                    self.alg.save(self.model_path, self.train_count)

            self._handle_eval_process(loss)

            # Distribute the model only once a checkpoint is ready.
            if not self.alg.async_flag and self.alg.checkpoint_ready(
                    self.train_count):
                _save_t1 = time()
                policy_weight = self.alg.get_weights()
                self._metric.append(fix_weight=time() - _save_t1)

                _dist_st = time()
                self._dist_policy(policy_weight, save_count)
                self._metric.append(send=time() - _dist_st)
                self._metric.report_if_need()

            save_count += 1
            if save_count % 5 == 1:
                self.stats_deliver.send(self.logger.get_new_info(), block=True)

    def record_reward(self, train_data):
        """Record reward in train.

        Advances ``actual_step`` and logs rewards for one training message.
        ``ppo_share_weights`` and ``QMixAlg`` batches are summarised as a
        whole; for other algorithms reward is accumulated per
        (broker_id, explorer_id, agent_id) key and logged when an episode
        ends (``done``, or ``info["real_done"]`` when present).
        """
        broker_id = get_msg_info(train_data, 'broker_id')
        explorer_id = get_msg_info(train_data, 'explorer_id')
        agent_id = get_msg_info(train_data, 'agent_id')
        key = (broker_id, explorer_id, agent_id)

        self.alg.dist_model_policy.add_processed_ctr_info(key)
        data_dict = get_msg_data(train_data)

        # Multi-agent algorithms: record the batch reward without done flags.
        if self.alg.alg_name in ("ppo_share_weights", ):
            self.actual_step += len(data_dict["done"])
            self.logger.record(
                step=self.actual_step,
                train_reward=np.sum(data_dict["reward"]),
                train_count=self.train_count,
            )
            return
        elif self.alg.alg_name in ("QMixAlg", ):  # fixme: unify the record op
            self.actual_step += np.sum(data_dict["filled"])
            # pop() mutates data_dict (presumably the same dict that is later
            # passed to prepare_data — confirm against get_msg_data).
            self.won_in_episodes.append(data_dict.pop("battle_won"))
            self.logger.update(
                explore_won_rate=np.nanmean(self.won_in_episodes))

            self.logger.record(
                step=self.actual_step,
                train_reward=np.sum(data_dict["reward"]),
                train_count=self.train_count,
            )
            return

        if key not in self.actor_reward:
            self.actor_reward[key] = 0.0
            self.actor_trajectory[key] = 0

        data_length = len(data_dict["done"])  # fetch the train data length
        for data_index in range(data_length):
            reward = data_dict["reward"][data_index]
            done = data_dict["done"][data_index]
            info = data_dict["info"][data_index]
            self.actual_step += 1
            if isinstance(info, dict):
                # Prefer the evaluation reward / real episode end if given.
                self.actor_reward[key] += info.get("eval_reward", reward)
                self.actor_trajectory[key] += 1
                done = info.get("real_done", done)

            if done:
                self.logger.record(
                    step=self.actual_step,
                    train_count=self.train_count,
                    train_reward=self.actor_reward[key],
                )

                logging.debug("%s epi reward-%s. with len-%s", key,
                              self.actor_reward[key],
                              self.actor_trajectory[key])

                self.actor_reward[key] = 0.0
                self.actor_trajectory[key] = 0
Ejemplo n.º 8
0
class TrainWorker(object):
    """TrainWorker Process manage the trajectory data set and optimizer.

    Trajectory messages are received from ``train_q`` and fed to ``alg``.
    The worker trains, saves checkpoints under ``model_path``, distributes
    weights through ``model_q`` and optionally drives population based
    training through ``pbt_aid``.
    """

    def __init__(
        self,
        train_q,
        alg,
        lock,
        model_path,
        model_q,
        s3_path,
        max_step,
        max_episode,
        stats_deliver,
        eval_adapter=None,
        **kwargs,
    ):
        self.train_q = train_q
        self.alg = alg
        self.lock = lock
        self.model_path = model_path
        self.model_q = model_q
        # Running episode reward / length per (broker, explorer, agent) key.
        self.actor_reward = defaultdict(float)
        self.actor_trajectory = defaultdict(int)
        self.rewards = []
        self.s3_path = s3_path
        self.max_step = max_step
        self.actual_step = 0
        self.name = kwargs.get('name', 'T0')

        self.max_episode = max_episode
        self.elapsed_episode = 0  # off policy, elapsed_episode > train count

        self.won_in_episodes = deque(maxlen=256)
        self.train_count = 0

        self.stats_deliver = stats_deliver
        self.e_adapter = eval_adapter

        self.logger = Logger(os.path.dirname(model_path))
        self._metric = TimerRecorder("leaner_model",
                                     maxlen=50,
                                     fields=("fix_weight", "send"))

        # Required: stats are sent every ``log_interval`` train rounds.
        self._log_interval = kwargs["log_interval"]

        self._explorer_ids = None
        self._pbt_aid = None
        self._train_data_counter = defaultdict(int)

    @property
    def explorer_ids(self):
        """Explorer ids passed to the distribution policy (or None)."""
        return self._explorer_ids

    @explorer_ids.setter
    def explorer_ids(self, val):
        self._explorer_ids = val

    @property
    def pbt_aid(self):
        """Population based training helper, or None when pbt is off."""
        return self._pbt_aid

    @pbt_aid.setter
    def pbt_aid(self, val):
        self._pbt_aid = val

    @staticmethod
    def _is_float_loss(loss):
        """Return True when ``loss`` is a scalar float worth recording.

        Accepts Python floats and every NumPy floating scalar. The old check
        listed ``np.float``, an alias of the builtin ``float`` that was
        removed in NumPy 1.24 and raises AttributeError there.
        """
        return isinstance(loss, (float, np.floating))

    def _dist_policy(self, weight=None, save_index=-1, dist_cmd="explore"):
        """Distribute model tool.

        Wraps ``weight`` in a ``dist_cmd`` message for every target the
        distribution policy selects among ``explorer_ids``.
        """
        explorer_set = self.explorer_ids

        ctr_info = self.alg.dist_model_policy.get_dist_info(
            save_index, explorer_set)

        # The policy may return one target dict or a list of them.
        if isinstance(ctr_info, dict):
            ctr_info = [ctr_info]

        for _ctr in ctr_info:
            to_send_data = message(weight, cmd=dist_cmd, **_ctr)
            self.model_q.send(to_send_data)

    def _handle_eval_process(self, loss):
        """Submit an evaluation, or forward a pending evaluation result.

        When the adapter wants an evaluation for this ``train_count`` the
        current weights are submitted; otherwise a pending result, if any,
        is forwarded to ``stats_deliver`` as a benchmark record.
        """
        if not self.e_adapter:
            return

        if self.e_adapter.if_eval(self.train_count):
            weights = self.alg.get_weights()
            self.e_adapter.to_eval(weights, self.train_count, self.actual_step,
                                   self.logger.elapsed_time,
                                   self.logger.train_reward, loss)
        elif not self.e_adapter.eval_result_empty:
            eval_ret = self.e_adapter.fetch_eval_result()
            if eval_ret:
                logging.debug("eval stats: %s", eval_ret)
                self.stats_deliver.send({"data": eval_ret, "is_bm": True},
                                        block=True)

    def _meet_stop(self):
        """Return True once the step or episode budget has been exceeded.

        A falsy ``max_step`` / ``max_episode`` disables that budget.
        """
        if self.max_step and self.actual_step > self.max_step:
            return True

        # Under pbt set, the max_episode need set into pbt_config
        # Owing to the reset of episode count after each pbt.exploit
        if self.max_episode and self.elapsed_episode > self.max_episode:
            return True

        return False

    def train(self):
        """Train model.

        Each round receives ``alg.prepare_data_times`` messages, optionally
        runs a pbt step, then trains, saves, evaluates and distributes
        weights. Stops when the pbt helper or ``_meet_stop`` says so.
        """
        if not self.alg.async_flag:
            policy_weight = self.alg.get_weights()
            self._dist_policy(weight=policy_weight)

        loss = 0
        while True:
            for _tf_val in range(self.alg.prepare_data_times):
                with self.logger.wait_sample_timer:
                    data = self.train_q.recv()
                with self.logger.prepare_data_timer:
                    data = bytes_to_str(data)
                    self.record_reward(data)
                    self.alg.prepare_data(data["data"],
                                          ctr_info=data["ctr_info"])

                # dqn series algorithm will count the 'SARSA' as one episode.
                # and, episodic count will used for train ready flag.
                # each pbt exploit.step will reset the episodic count.
                self.elapsed_episode += 1

            # run pbt if need.
            if self.pbt_aid:
                if self.pbt_aid.meet_stop(self.elapsed_episode):
                    break

                cur_info = dict(
                    episodic_reward_mean=self.logger.train_reward_avg,
                    elapsed_step=self.actual_step,
                    elapsed_episode=self.elapsed_episode)
                new_alg = self.pbt_aid.step(cur_info, cur_alg=self.alg)

                if new_alg:  # re-assign algorithm if need
                    self.alg = new_alg
                    if not self.alg.async_flag:
                        policy_weight = self.alg.get_weights()
                        self._dist_policy(weight=policy_weight)
                        # NOTE(review): only sync algorithms skip training in
                        # this round; async ones train right away with the
                        # new algorithm — confirm this is intended.
                        continue

            if self._meet_stop():
                self.stats_deliver.send(self.logger.get_new_info(), block=True)
                break

            if not self.alg.train_ready(self.elapsed_episode,
                                        dist_dummy_model=self._dist_policy):
                continue

            with self.lock, self.logger.train_timer:
                loss = self.alg.train(episode_num=self.elapsed_episode)

            if self._is_float_loss(loss):
                self.logger.record(train_loss=loss)

            with self.lock:
                if self.alg.if_save(self.train_count):
                    self.alg.save(self.model_path, self.train_count)

            self._handle_eval_process(loss)

            # The requirement of distribute model is checkpoint ready.
            if not self.alg.async_flag and self.alg.checkpoint_ready(
                    self.train_count):
                _save_t1 = time()
                policy_weight = self.alg.get_weights()
                self._metric.append(fix_weight=time() - _save_t1)

                _dist_st = time()
                self._dist_policy(policy_weight, self.train_count)
                self._metric.append(send=time() - _dist_st)
                self._metric.report_if_need()
            elif self.alg.checkpoint_ready(self.train_count):
                # Async algorithms: push weights to the predictor instead.
                policy_weight = self.alg.get_weights()
                weight_msg = message(policy_weight,
                                     cmd="predict{}".format(self.name),
                                     sub_cmd='sync_weights')
                self.model_q.send(weight_msg)

            if self.train_count % self._log_interval == 0:
                self.stats_deliver.send(self.logger.get_new_info(), block=True)

            self.train_count += 1

    def record_reward(self, train_data):
        """Record reward in train.

        Counts the message per (broker_id, explorer_id, agent_id) key and
        advances ``actual_step``. ``QMixAlg`` and ``api_type == "unified"``
        batches are summarised as a whole; otherwise reward is accumulated
        per key and logged when an episode ends (``done``, or
        ``info["real_done"]`` when present).
        """
        broker_id = get_msg_info(train_data, 'broker_id')
        explorer_id = get_msg_info(train_data, 'explorer_id')
        agent_id = get_msg_info(train_data, 'agent_id')
        key = (broker_id, explorer_id, agent_id)
        # record the train_data received
        self._train_data_counter[key] += 1

        self.alg.dist_model_policy.add_processed_ctr_info(key)
        data_dict = get_msg_data(train_data)

        # update multi agent train reward without done flag
        if self.alg.alg_name in ("QMixAlg", ):  # fixme: unify the record op
            self.actual_step += np.sum(data_dict["filled"])
            self.won_in_episodes.append(data_dict.pop("battle_won"))
            self.logger.update(
                explore_won_rate=np.nanmean(self.won_in_episodes))

            self.logger.record(
                step=self.actual_step,
                train_reward=np.sum(data_dict["reward"]),
                train_count=self.train_count,
            )
            return
        # .get(): configs without 'api_type' used to raise KeyError here.
        elif self.alg.alg_config.get('api_type') == "unified":
            self.actual_step += len(data_dict["done"])
            self.logger.record(
                step=self.actual_step,
                train_reward=np.sum(data_dict["reward"]),
                train_count=self.train_count,
            )
            return

        data_length = len(data_dict["done"])  # fetch the train data length
        for data_index in range(data_length):
            reward = data_dict["reward"][data_index]
            done = data_dict["done"][data_index]
            info = data_dict["info"][data_index]
            self.actual_step += 1
            if isinstance(info, dict):
                # Prefer the evaluation reward / real episode end if given.
                self.actor_reward[key] += info.get("eval_reward", reward)
                self.actor_trajectory[key] += 1
                done = info.get("real_done", done)

            if done:
                self.logger.record(
                    step=self.actual_step,
                    train_count=self.train_count,
                    train_reward=self.actor_reward[key],
                    trajectory_length=self.actor_trajectory[key],
                )
                self.actor_reward[key] = 0.0
                self.actor_trajectory[key] = 0
Ejemplo n.º 9
0
class Controller(object):
    """Controller Manage Broker within Learner.

    Bridges local queues registered with :meth:`register` and the remote
    brokers reached over ZMQ. Remote messages are routed to local queues by
    ``cmd``; local messages are forwarded to local queues or to brokers.
    """

    def __init__(self, node_config_list):
        self.node_config_list = node_config_list
        self.node_num = len(node_config_list)

        # One PULL socket for all brokers, one PUSH socket per node.
        self.recv_broker = UniComm("CommByZmq", type="PULL")
        self.send_broker = [
            UniComm("CommByZmq", type="PUSH") for _i in range(self.node_num)
        ]

        self.port_info = {
            "recv": ast.literal_eval(self.recv_broker.info),
            "send": [ast.literal_eval(_s.info) for _s in self.send_broker]
        }

        port_info_h = pprint.pformat(
            self.port_info,
            indent=0,
            width=1,
        )
        logging.info("Init Broker server info:\n{}\n".format(port_info_h))

        # Local queues keyed by cmd name, filled through register().
        self.recv_local_q = dict()
        self.send_local_q = dict()

        self.data_manager = Manager()
        self._data_store = dict()

        self._main_task = list()
        self.metric = TimerRecorder("Controller",
                                    maxlen=50,
                                    fields=("send", "recv"))
        self.stats = BrokerStats()

        if DebugConf.trace:
            tracemalloc.start()

    def start_data_transfer(self):
        """Start the remote-receive and local-receive daemon threads."""
        # daemon=True replaces the deprecated Thread.setDaemon(True).
        for target in (self.recv_broker_task, self.recv_local):
            threading.Thread(target=target, daemon=True).start()

    @property
    def tasks(self):
        """Learner tasks added through :meth:`add_task`."""
        return self._main_task

    def recv_broker_task(self):
        """Receive remote train data in sync mode.

        Each message is a (ctr_info, payload) pair of serialized bytes. The
        payload is lz4-decompressed when ``ctr_info['compress_flag']`` is set,
        then delivered to the local queue registered for its ``cmd``.
        """
        while True:
            ctr_info, recv_data = self.recv_broker.recv_bytes()
            _t0 = time.time()
            ctr_info = deserialize(ctr_info)
            if ctr_info.get('compress_flag', False):
                recv_data = lz4.frame.decompress(recv_data)
            recv_data = deserialize(recv_data)
            recv_data = {'ctr_info': ctr_info, 'data': recv_data}
            self.metric.append(recv=time.time() - _t0)

            cmd = get_msg_info(recv_data, "cmd")
            send_cmd = self.send_local_q.get(cmd)
            if send_cmd:
                send_cmd.send(recv_data)
            else:
                logging.warning("invalid cmd: %s, with data: %s",
                                cmd, recv_data)

            # report log
            self.metric.report_if_need(field_sets=("send", "recv"))

    def _yield_local_msg(self):
        """Yield local messages by polling every registered receive queue.

        Queues whose cmd name contains ``predict`` are polled with
        ``get(block=False)``, the others with ``recv(block=False)``. After an
        empty poll, or while no queue is registered, the loop sleeps 2 ms so
        it does not spin at full CPU.
        """
        while True:
            if not self.recv_local_q:
                # Nothing registered: the for-loop below would never sleep.
                time.sleep(0.002)
                continue
            # polling with whole learner with no wait.
            for cmd, recv_q in self.recv_local_q.items():
                if 'predict' in cmd:
                    try:
                        recv_data = recv_q.get(block=False)
                    except Exception:  # e.g. queue.Empty when nothing pending
                        recv_data = None
                else:
                    recv_data = recv_q.recv(block=False)
                if recv_data:
                    yield recv_data
                else:
                    time.sleep(0.002)

    def recv_local(self):
        """Receive local cmd and route it.

        Messages whose ``cmd`` is registered in ``send_local_q`` are delivered
        locally; others are sent to the broker selected by ``broker_id``
        (``-1`` broadcasts to every broker). A ``close`` cmd is forwarded to
        all brokers and ends the loop.
        """
        # A single receive queue is read blocking; several (e.g. pbt) are
        # polled through _yield_local_msg.
        if len(self.recv_local_q) == 1:
            single_stub, = self.recv_local_q.values()
        else:
            single_stub = None

        kwargs = init_main_broker_debug_kwargs()
        yield_func = self._yield_local_msg()

        while True:
            if single_stub:
                recv_data = single_stub.recv(block=True)
            else:
                recv_data = next(yield_func)

            cmd = get_msg_info(recv_data, "cmd")
            if cmd == "close":
                self.close(recv_data)
                break

            if cmd in self.send_local_q:
                self.send_local_q[cmd].send(recv_data)
                continue

            _t1 = time.time()
            broker_id = get_msg_info(recv_data, "broker_id")
            if broker_id == -1:
                targets = self.send_broker
            else:
                targets = [self.send_broker[broker_id]]
            for broker_item in targets:
                broker_item.send(recv_data['ctr_info'], recv_data['data'])
            self.metric.append(send=time.time() - _t1)
            debug_within_interval(**kwargs)

    def add_task(self, learner_obj):
        """Add learner task into Broker."""
        self._main_task.append(learner_obj)

    def register(self, cmd, direction, comm_q=None):
        """Register cmd vary with type.

        :param cmd: name to register.
        :type cmd: str
        :param direction: direction relate to broker: ``"send"``, ``"recv"``
            or ``"store"``.
        :type direction: str
        :param comm_q: queue to register; a new ``UniComm("LocalMsg")`` is
            created when omitted (ignored for ``"store"``).
        :return: stub of the local queue with registered (a manager dict for
            ``"store"``).
        :rtype: option
        :raises KeyError: for an unknown ``direction``.
        """
        if not comm_q:
            comm_q = UniComm("LocalMsg")

        if direction == "send":
            self.send_local_q.update({cmd: comm_q})
            return self.send_local_q[cmd]
        elif direction == "recv":
            self.recv_local_q.update({cmd: comm_q})
            return self.recv_local_q[cmd]
        elif direction == "store":
            self._data_store.update({cmd: self.data_manager.dict()})
            return self._data_store[cmd]
        else:
            raise KeyError("invalid register key: {}".format(direction))

    def alloc_actor(self):
        """Every 10 s ask brokers to scale actors by the "train" backlog.

        More than 200 pending items sends "decrease", fewer than 10 sends
        "increase". Relies on ``comm.data_list`` of the registered "train"
        queue (assumed to expose its pending items — confirm).
        """
        while True:
            time.sleep(10)
            if not self.send_local_q.get("train"):
                continue

            train_list = self.send_local_q["train"].comm.data_list
            if len(train_list) > 200:
                self.send_alloc_msg("decrease")
            elif len(train_list) < 10:
                self.send_alloc_msg("increase")

    def send_alloc_msg(self, actor_status):
        """Broadcast an actor allocation command to every broker.

        :param actor_status: cmd to send, e.g. ``"increase"``/``"decrease"``.
        """
        ctr_info = {"cmd": actor_status, "actor_id": -1, "explorer_id": -1}
        # The command carries no payload. The previous code read a missing
        # 'data' key and always raised KeyError before sending anything.
        for q in self.send_broker:
            q.send(ctr_info, None)

    def close(self, close_cmd):
        """Forward the close command to every remote broker."""
        for broker_item in self.send_broker:
            broker_item.send(close_cmd['ctr_info'], close_cmd['data'])

        # close ctx may mismatch the socket, use the os.exit last.

    def start(self):
        """Start all system."""
        setproctitle.setproctitle("xt_main")
        self.start_data_transfer()

    def tasks_loop(self):
        """
        Create the tasks_loop after ready the messy setup works.

        The foreground task of Controller: runs each task's ``main_loop`` in
        its own thread, runs the broker stats loop, then joins the threads.
        :return:
        """
        if not self._main_task:
            # Only reported; the stats loop still runs.
            logging.fatal("Without learning process ready!")

        train_thread = [
            threading.Thread(target=task.main_loop) for task in self.tasks
        ]
        for task in train_thread:
            task.start()
            self.stats.add_relation_task(task)

        # check broker stats.
        self.stats.loop()

        # wait to job end.
        for task in train_thread:
            task.join()

    def stop(self):
        """Stop all system by sending "close" to non-predict receive queues."""
        close_cmd = message(None, cmd="close")
        for _learner_id, recv_q in self.recv_local_q.items():
            if 'predict' not in _learner_id:
                recv_q.send(close_cmd)
        time.sleep(0.1)
Ejemplo n.º 10
0
class Broker(object):
    """Node-level broker between the Controller and local explorers.

    Receives commands/weights from the controller over ZMQ (PULL), spawns
    explorer and evaluator processes, fans controller messages out to their
    queues, and relays explorer output read from plasma stores back to the
    controller over ZMQ (PUSH).
    """

    def __init__(self, ip_addr, broker_id, push_port, pull_port):
        """Connect to the controller and allocate the shared weight buffer.

        :param ip_addr: controller address used for both ZMQ sockets.
        :param broker_id: id of this broker, passed on to explorers.
        :param push_port: port of the PUSH socket towards the controller.
        :param pull_port: port of the PULL socket from the controller.
        """
        self.broker_id = broker_id
        self.send_controller_q = UniComm("CommByZmq",
                                         type="PUSH",
                                         addr=ip_addr,
                                         port=push_port)

        self.recv_controller_q = UniComm("CommByZmq",
                                         type="PULL",
                                         addr=ip_addr,
                                         port=pull_port)

        # Set once _setup_share_qs_firstly has created the plasma stores;
        # recv_explorer waits on this flag before reading.
        self.recv_explorer_q_ready = False
        # learner id -> UniComm("ShareByPlasma") explorers write into;
        # "EVAL0" is the evaluator slot, None until it is assigned.
        self.explorer_share_qs = {"EVAL0": None}

        # stats_id(ctr_info) -> number of messages relayed from explorers.
        self.explorer_stats = defaultdict(int)

        # explorer/evaluator id -> Queue feeding that process.
        self.send_explorer_q = dict()
        # explorer/evaluator id -> Process started for it.
        self.explore_process = dict()
        # Number of processes currently suspended by alloc().
        self.processes_suspend = 0
        logging.info("init broker with id-{}".format(self.broker_id))
        self._metric = TimerRecorder("broker", maxlen=50, fields=("send", ))
        # Note: need check it if add explorer dynamic
        # buf size vary with env_num&algorithm
        # ~4M, impala atari model
        self._buf = ShareBuf(live=0, size=400000000, max_keep=94, start=True)

    def start_data_transfer(self):
        """Start the controller-receive and explorer-receive threads."""
        data_transfer_thread = threading.Thread(
            target=self.recv_controller_task)
        data_transfer_thread.start()

        data_transfer_thread = threading.Thread(target=self.recv_explorer)
        data_transfer_thread.start()

    def _setup_share_qs_firstly(self, config_info):
        """Create one plasma store per population member; runs only once.

        :param config_info: explorer config; read for the pbt settings and
            ``plasma_size`` (default 100000000).
        """
        if self.recv_explorer_q_ready:
            return

        _use_pbt, pbt_size, _, _ = get_pbt_set(config_info)
        plasma_size = config_info.get("plasma_size", 100000000)

        for i in range(pbt_size):
            plasma_path = "/tmp/plasma{}T{}".format(os.getpid(), i)
            self.explorer_share_qs["T{}".format(i)] = UniComm("ShareByPlasma",
                                                              size=plasma_size,
                                                              path=plasma_path)

        # Without pbt the evaluator shares T0's store; with pbt its own
        # store is created lazily in recv_controller_task.
        if not _use_pbt:
            self.explorer_share_qs["EVAL0"] = self.explorer_share_qs["T0"]
        else:
            self.explorer_share_qs["EVAL0"] = None

        self.recv_explorer_q_ready = True

    def recv_controller_task(self):
        """Receive controller messages forever and dispatch them by ``cmd``."""
        while True:
            ctr_info, data = self.recv_controller_q.recv_bytes()
            recv_data = {
                'ctr_info': deserialize(ctr_info),
                'data': deserialize(data)
            }

            cmd = get_msg_info(recv_data, "cmd")
            if cmd in ["close"]:
                # close() terminates the process via os._exit.
                self.close(recv_data)

            if cmd in ["create_explorer"]:
                config_set = recv_data["data"]
                # setup plasma only one time!
                self._setup_share_qs_firstly(config_set)

                config_set.update({"share_path": self._buf.get_path()})
                self.create_explorer(config_set)
                continue

            if cmd in ["create_evaluator"]:
                # evaluators share a single plasma store.
                config_set = recv_data["data"]

                if not self.explorer_share_qs["EVAL0"]:
                    use_pbt, _, _, _ = get_pbt_set(config_set)
                    if use_pbt:
                        plasma_size = config_set.get("plasma_size", 100000000)
                        plasma_path = "/tmp/plasma{}EVAL0".format(os.getpid())
                        self.explorer_share_qs["EVAL0"] = UniComm(
                            "ShareByPlasma",
                            size=plasma_size,
                            path=plasma_path)

                config_set.update({"share_path": self._buf.get_path()})
                logging.debug(
                    "create evaluator with config:{}".format(config_set))

                self.create_evaluator(config_set)
                continue

            if cmd in ["increase", "decrease"]:
                self.alloc(cmd)
                continue

            if cmd in ("eval", ):  # fixme: merge into explore
                test_id = get_msg_info(recv_data, "test_id")
                self.send_explorer_q[test_id].put(recv_data)
                continue

            # last job, distribute weights/model_name from controller
            explorer_id = get_msg_info(recv_data, "explorer_id")
            if not isinstance(explorer_id, list):
                explorer_id = [explorer_id]

            _t0 = time.time()

            if cmd in ("explore", ):
                # Store the still-serialized weights in the shared buffer and
                # pass only the buffer id to the explorers.
                buf_id = self._buf.put(data)
                recv_data.update({"data": buf_id})

            for _eid in explorer_id:
                if _eid > -1:
                    self.send_explorer_q[_eid].put(recv_data)
                elif _eid > -2:
                    # -1 broadcasts to every explorer, skipping queues whose
                    # string id contains "test".
                    for qid, send_q in self.send_explorer_q.items():
                        if isinstance(qid, str) and "test" in qid:
                            continue
                        send_q.put(recv_data)
                else:
                    raise KeyError("invalid explore id: {}".format(_eid))

            self._metric.append(send=time.time() - _t0)
            self._metric.report_if_need()

    @staticmethod
    def _handle_data(ctr_info, data, explorer_stub, broker_stub):
        """Forward one explorer message to the controller, then free it.

        :param ctr_info: dict with ``object_id`` (plasma object to delete) and
            ``ctr_info_data`` (serialized header sent on).
        :param data: payload bytes sent alongside the header.
        """
        object_id = ctr_info['object_id']
        ctr_info_data = ctr_info['ctr_info_data']
        broker_stub.send_bytes(ctr_info_data, data)
        explorer_stub.delete(object_id)

    def _step_explorer_msg(self, use_single_stub):
        """Yield ``(recv_id, ctr_info)`` for every explorer message relayed.

        :param use_single_stub: when True, block on the "T0" store only;
            otherwise poll every non-None store without blocking.
        """
        if use_single_stub:
            recv_id = "T0"
            single_stub = self.explorer_share_qs[recv_id]
        else:
            single_stub, recv_id = None, None

        while True:
            if use_single_stub:
                ctr_info, data = single_stub.recv_bytes(block=True)
                self._handle_data(ctr_info, data, single_stub,
                                  self.send_controller_q)
                yield recv_id, ctr_info
            else:
                for recv_id, recv_q in self.explorer_share_qs.items():
                    if not recv_q:  # evaluator slot not assigned yet
                        continue

                    ctr_info, data = recv_q.recv_bytes(block=False)
                    if not ctr_info:  # nothing ready in this store
                        time.sleep(0.002)
                        continue

                    self._handle_data(ctr_info, data, recv_q,
                                      self.send_controller_q)
                    yield recv_id, ctr_info

    def recv_explorer(self):
        """Relay explorer messages to the controller forever, counting them."""
        while not self.recv_explorer_q_ready:
            # wait until the explorer share stores exist
            time.sleep(0.05)
        # Two entries ("EVAL0" plus "T0") means a single shared store.
        # NOTE(review): with pbt and pbt_size == 1 an "EVAL0" store created
        # later would not be polled in single-stub mode — confirm.
        use_single_flag = len(self.explorer_share_qs) == 2
        yield_func = self._step_explorer_msg(use_single_flag)

        while True:
            _recv_id, _info = next(yield_func)
            _id = stats_id(_info)
            self.explorer_stats[_id] += 1
            debug_within_interval(logs=dict(self.explorer_stats),
                                  interval=DebugConf.interval_s,
                                  human_able=True)

    def create_explorer(self, config_info):
        """Spawn an explorer process and register its input queue.

        The process is pinned to core ``start_core + env_id`` when
        ``speedup`` is set and enough cores are available.
        """
        env_para = config_info.get("env_para")
        env_num = config_info.get("env_num")
        speedup = config_info.get("speedup", True)
        start_core = config_info.get("start_core", 1)
        env_id = env_para.get("env_id")  # used for explorer id.

        ref_learner_id = config_info.get("learner_postfix")

        send_explorer = Queue()
        explorer = Explorer(
            config_info,
            self.broker_id,
            recv_broker=send_explorer,
            send_broker=self.explorer_share_qs[ref_learner_id],
        )

        p = Process(target=explorer.start)
        p.start()

        cpu_count = psutil.cpu_count()
        if speedup and cpu_count > (env_num + start_core):
            _p = psutil.Process(p.pid)
            _p.cpu_affinity([start_core + env_id])

        self.send_explorer_q.update({env_id: send_explorer})
        self.explore_process.update({env_id: p})

    def create_evaluator(self, config_info):
        """Spawn an evaluator process writing into the "EVAL0" store.

        When ``speedup`` is set and enough cores exist, it is pinned to the
        core right after the explorers' cores.
        """
        test_id = config_info.get("test_id")
        send_evaluator = Queue()

        evaluator = Evaluator(
            config_info,
            self.broker_id,
            recv_broker=send_evaluator,
            send_broker=self.explorer_share_qs["EVAL0"],
        )
        p = Process(target=evaluator.start)
        p.start()

        speedup = config_info.get("speedup", False)
        start_core = config_info.get("start_core", 1)
        eval_num = config_info.get("benchmark",
                                   {}).get("eval", {}).get("evaluator_num", 1)
        env_num = config_info.get("env_num")

        core_set = env_num + start_core
        cpu_count = psutil.cpu_count()
        if speedup and cpu_count > (env_num + eval_num + start_core):
            _p = psutil.Process(p.pid)
            _p.cpu_affinity([core_set])

        self.send_explorer_q.update({test_id: send_evaluator})
        self.explore_process.update({test_id: p})

    def alloc(self, actor_status):
        """Suspend or resume explorer processes to adjust resource usage.

        :param actor_status: "decrease" suspends the next running process,
            "increase" resumes the most recently suspended one, "reset"
            resumes all of them.
        """
        p = [psutil.Process(_p.pid) for _p in self.explore_process.values()]

        if actor_status == "decrease":
            if self.processes_suspend < len(p):
                p[self.processes_suspend].suspend()
                self.processes_suspend += 1
        elif actor_status == "increase":
            if self.processes_suspend >= 1:
                p[self.processes_suspend - 1].resume()
                self.processes_suspend -= 1
        elif actor_status == "reset":
            for resume_p in p:
                resume_p.resume()
            # Every process is running again; keep the counter in sync so
            # later increase/decrease calls index the right process.
            self.processes_suspend = 0

    def close(self, close_cmd):
        """Close explorers, terminate leftovers, kill plasma and exit."""
        for send_q in self.send_explorer_q.values():
            send_q.put(close_cmd)
        time.sleep(2)

        for p in self.explore_process.values():
            if p.exitcode is None:
                p.terminate()

        os.system("pkill plasma -g " + str(os.getpgid(0)))
        os._exit(0)

    def start(self):
        """Start all system."""
        setproctitle.setproctitle("xt_broker")
        self.start_data_transfer()
Ejemplo n.º 11
0
class BrokerMaster(object):
    """BrokerMaster manages the Broker within the Learner.

    Remote data from broker slaves is received over ZMQ (PULL) and routed to
    locally registered queues by ``cmd``. Local outgoing messages go to the
    registered local queue for their ``cmd`` if one exists; otherwise they
    are sent to the slaves over ZMQ (PUSH).
    """

    def __init__(self, node_config_list, start_port=None):
        """Open the slave sockets.

        :param node_config_list: one entry per node; one PUSH socket is
            opened per entry.
        :param start_port: base port. When falsy, one is obtained from
            ``CommConf`` and released again in ``close``.
        """
        self.node_config_list = node_config_list
        self.node_num = len(node_config_list)
        comm_conf = None
        if not start_port:
            comm_conf = CommConf()
            start_port = comm_conf.get_start_port()
        self.start_port = start_port
        logging.info("master broker init on port: {}".format(start_port))
        # None when the caller supplied start_port explicitly.
        self.comm_conf = comm_conf

        recv_port, send_port = get_port(start_port)
        self.recv_slave = UniComm("CommByZmq", type="PULL", port=recv_port)
        self.send_slave = [
            UniComm("CommByZmq", type="PUSH", port=send_port + i)
            for i in range(self.node_num)
        ]

        self.recv_local_q = UniComm("LocalMsg")
        # cmd -> UniComm("LocalMsg"), filled by register().
        self.send_local_q = dict()

        self.main_task = None
        self.metric = TimerRecorder("master", maxlen=50, fields=("send", "recv"))

    def start_data_transfer(self):
        """Start the slave-receive and local-receive daemon threads."""
        for target in (self.recv_broker_slave, self.recv_local):
            worker = threading.Thread(target=target, daemon=True)
            worker.start()

    def recv_broker_slave(self):
        """Receive remote train data and route it to local queues by cmd."""
        while True:
            recv_data = self.recv_slave.recv_bytes()
            _t0 = time.time()
            recv_data = deserialize(lz4.frame.decompress(recv_data))
            self.metric.append(recv=time.time() - _t0)

            cmd = get_msg_info(recv_data, "cmd")
            # Messages whose cmd has no registered local queue are dropped.
            send_cmd = self.send_local_q.get(cmd)
            if send_cmd:
                send_cmd.send(recv_data)

            # report log
            self.metric.report_if_need()

    def recv_local(self):
        """Receive local messages and route them locally or to the slaves."""
        while True:
            recv_data = self.recv_local_q.recv()
            cmd = get_msg_info(recv_data, "cmd")
            if cmd in ["close"]:
                # close() terminates the process via os._exit.
                self.close(recv_data)

            # Membership must be tested against the dict itself: the old
            # `cmd in [self.send_local_q.keys()]` compared cmd to the keys
            # view object and was therefore always False.
            if cmd in self.send_local_q:
                self.send_local_q[cmd].send(recv_data)
                logging.debug("recv: {} with cmd-{}".format(type(recv_data["data"]), cmd))
            else:
                _t1 = time.time()
                broker_id = get_msg_info(recv_data, "broker_id")
                logging.debug("master recv:{} with cmd:'{}' to broker_id: <{}>".format(
                    type(recv_data["data"]), cmd, broker_id))

                if broker_id == -1:
                    # -1 broadcasts to every slave.
                    for slave in self.send_slave:
                        slave.send(recv_data)
                else:
                    self.send_slave[broker_id].send(recv_data)
                self.metric.append(send=time.time() - _t1)

    def register(self, cmd):
        """Create and return the local queue receiving messages for ``cmd``."""
        self.send_local_q.update({cmd: UniComm("LocalMsg")})
        return self.send_local_q[cmd]

    def alloc_actor(self):
        """Every 10 s, ask slaves to shrink/grow based on the train backlog."""
        while True:
            time.sleep(10)
            if not self.send_local_q.get("train"):
                continue

            train_list = self.send_local_q["train"].comm.data_list
            if len(train_list) > 200:
                self.send_alloc_msg("decrease")
            elif len(train_list) < 10:
                self.send_alloc_msg("increase")

    def send_alloc_msg(self, actor_status):
        """Broadcast an allocation command (``actor_status``) to all slaves."""
        alloc_cmd = {
            "ctr_info": {"cmd": actor_status, "actor_id": -1, "explorer_id": -1}
        }
        for q in self.send_slave:
            q.send(alloc_cmd)

    def close(self, close_cmd):
        """Forward ``close_cmd`` to every slave, release the port and exit."""
        for slave in self.send_slave:
            slave.send(close_cmd)

        time.sleep(1)
        # Only a port obtained from CommConf needs releasing; comm_conf is
        # None when start_port was passed in explicitly.
        if self.comm_conf is not None:
            try:
                self.comm_conf.release_start_port(self.start_port)
            except Exception as err:  # best effort: exit regardless
                logging.warning("release start port %s failed: %s",
                                self.start_port, err)

        os._exit(0)

    def start(self):
        """Start all system."""
        self.start_data_transfer()

    def main_loop(self):
        """
        Run the main task's loop once setup is done.

        The foreground task of broker master.
        :return:
        """
        if not self.main_task:
            # NOTE: logging.fatal only logs; the call below then fails.
            logging.fatal("learning process isn't ready!")
        self.main_task.main_loop()

    def stop(self):
        """Stop all system by posting a close message to the local queue."""
        close_cmd = message(None, cmd="close")
        self.recv_local_q.send(close_cmd)
Ejemplo n.º 12
0
class BrokerSlave(object):
    """BrokerSlave manages the Broker within the Explorers of each node.

    It receives messages from the BrokerMaster over ZMQ (PULL), spawns
    explorer/evaluator processes, fans messages out to their queues, and
    relays explorer output from a plasma store to the master (PUSH).
    """

    def __init__(self, ip_addr, broker_id, start_port):
        """Connect to the master; ports are derived from ``start_port``."""
        self.broker_id = broker_id
        train_port, predict_port = get_port(start_port)

        self.send_master_q = UniComm(
            "CommByZmq", type="PUSH", addr=ip_addr, port=train_port
        )

        # Each slave listens on its own port: predict_port + broker_id.
        self.recv_master_q = UniComm(
            "CommByZmq", type="PULL", addr=ip_addr, port=predict_port + broker_id
        )

        self.recv_explorer_q = UniComm("ShareByPlasma")
        # explorer/evaluator id -> Queue feeding that process.
        self.send_explorer_q = dict()
        # explorer/evaluator id -> Process started for it.
        self.explore_process = dict()
        # Number of processes currently suspended by alloc().
        self.processes_suspend = 0
        logging.info("init broker slave with id-{}".format(self.broker_id))
        self._metric = TimerRecorder("broker_slave", maxlen=50, fields=("send",))
        # Note: need check it if add explorer dynamic
        self._buf = ShareBuf(live=0, start=True)

    def start_data_transfer(self):
        """Start the master-receive and explorer-receive threads."""
        data_transfer_thread = threading.Thread(target=self.recv_master)
        data_transfer_thread.start()

        data_transfer_thread = threading.Thread(target=self.recv_explorer)
        data_transfer_thread.start()

    def recv_master(self):
        """Receive master messages forever and dispatch them by ``cmd``."""
        while True:
            recv_bytes = self.recv_master_q.recv_bytes()
            recv_data = deserialize(recv_bytes)

            cmd = get_msg_info(recv_data, "cmd")
            if cmd in ["close"]:
                # close() terminates the process via os._exit.
                self.close(recv_data)

            if cmd in ["create_explorer"]:
                config_set = recv_data["data"]
                config_set.update({"share_path": self._buf.get_path()})
                self.create_explorer(config_set)

                # one more reader of the shared buffer.
                self._buf.plus_one_live()
                continue
            if cmd in ["create_evaluator"]:
                config_set = recv_data["data"]
                config_set.update({"share_path": self._buf.get_path()})
                logging.debug("create evaluator with config:{}".format(config_set))

                self.create_evaluator(config_set)
                self._buf.plus_one_live()
                continue

            if cmd in ["increase", "decrease"]:
                self.alloc(cmd)
                continue

            if cmd in ("eval",):  # fixme: merge into explore
                test_id = get_msg_info(recv_data, "test_id")
                self.send_explorer_q[test_id].put(recv_data)
                continue

            # last job, distribute weights/model_name from broker master
            explorer_id = get_msg_info(recv_data, "explorer_id")
            if not isinstance(explorer_id, list):
                explorer_id = [explorer_id]

            _t0 = time.time()
            # NOTE(review): substring match (any cmd containing "explore"),
            # unlike Broker's exact match — confirm this is intended.
            if "explore" in cmd:
                # Store the raw message bytes; explorers get only the id.
                buf_id = self._buf.put(recv_bytes)
                recv_data.update({"data": buf_id})

            for _eid in explorer_id:
                if _eid > -1:
                    self.send_explorer_q[_eid].put(recv_data)
                elif _eid > -2:
                    # -1 broadcasts to every explorer, skipping queues whose
                    # string id contains "test".
                    for qid, send_q in self.send_explorer_q.items():
                        if isinstance(qid, str) and "test" in qid:
                            continue
                        send_q.put(recv_data)
                else:
                    raise KeyError("invalid explore id: {}".format(_eid))

            self._metric.append(send=time.time() - _t0)
            self._metric.report_if_need()

    def recv_explorer(self):
        """Relay explorer output to the master or release buffer refs.

        "data" objects are forwarded as bytes; "buf_reduce" objects decrement
        a shared-buffer entry. Each plasma object is deleted after handling.
        """
        while True:
            data, object_info = self.recv_explorer_q.recv_bytes()
            object_id, data_type = object_info
            if data_type == "data":
                self.send_master_q.send_bytes(data)
            elif data_type == "buf_reduce":
                recv_data = deserialize(data)
                to_reduce_id = get_msg_data(recv_data)
                self._buf.reduce_once(to_reduce_id)
            else:
                raise KeyError("un-known data type: {}".format(data_type))
            self.recv_explorer_q.delete(object_id)

    def create_explorer(self, config_info):
        """Spawn a daemon explorer process and register its input queue."""
        env_para = config_info.get("env_para")
        env_id = env_para.get("env_id")
        send_explorer = Queue()

        explorer = Explorer(
            config_info,
            self.broker_id,
            recv_broker=send_explorer,
            send_broker=self.recv_explorer_q,
        )

        p = Process(target=explorer.start)
        p.daemon = True
        p.start()

        self.send_explorer_q.update({env_id: send_explorer})
        self.explore_process.update({env_id: p})

    def create_evaluator(self, config_info):
        """Spawn a daemon evaluator process and register its input queue."""
        test_id = config_info.get("test_id")
        send_evaluator = Queue()

        evaluator = Evaluator(
            config_info,
            self.broker_id,
            recv_broker=send_evaluator,
            send_broker=self.recv_explorer_q,
        )
        p = Process(target=evaluator.start)
        p.daemon = True
        p.start()

        self.send_explorer_q.update({test_id: send_evaluator})
        self.explore_process.update({test_id: p})

    def alloc(self, actor_status):
        """Suspend or resume explorer processes to adjust resource usage.

        :param actor_status: "decrease" suspends the next running process,
            "increase" resumes the most recently suspended one, "reset"
            resumes all of them.
        """
        p = [psutil.Process(_p.pid) for _p in self.explore_process.values()]

        if actor_status == "decrease":
            if self.processes_suspend < len(p):
                p[self.processes_suspend].suspend()
                self.processes_suspend += 1
        elif actor_status == "increase":
            if self.processes_suspend >= 1:
                p[self.processes_suspend - 1].resume()
                self.processes_suspend -= 1
        elif actor_status == "reset":
            for resume_p in p:
                resume_p.resume()
            # Every process is running again; keep the counter in sync so
            # later increase/decrease calls index the right process.
            self.processes_suspend = 0

    def close(self, close_cmd):
        """Close explorers, terminate leftovers, kill plasma and exit."""
        for send_q in self.send_explorer_q.values():
            send_q.put(close_cmd)
        time.sleep(5)

        for p in self.explore_process.values():
            if p.exitcode is None:
                p.terminate()

        os.system("pkill plasma -g " + str(os.getpgid(0)))
        os._exit(0)

    def start(self):
        """Start all system."""
        self.start_data_transfer()