Example #1
 # a unittest.TestCase method; a runnable wrapper is sketched after the example
 def test_learner_task(self):
   league_client = LeagueMgrAPIs(league_mgr_addr="localhost:11007")
   learner_id = str(uuid.uuid1())
   task = league_client.request_learner_task(learner_id=learner_id)
   self.assertIsInstance(task, LearnerTask)
   query_task = league_client.query_learner_task(learner_id=learner_id)
   self.assertEqual(task.model_key, query_task.model_key)
   self.assertEqual(task.parent_model_key, query_task.parent_model_key)
   league_client.notify_learner_task_begin(learner_id=learner_id,
                                           learner_task=task)
   league_client.notify_learner_task_end(learner_id=learner_id)
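
The snippet above is a method of a unittest.TestCase. A minimal wrapper to run it standalone might look as follows, assuming LeagueMgrAPIs and LearnerTask are importable from the surrounding code base and a league manager is already serving on localhost:11007:

import unittest
import uuid

class TestLeagueMgrAPIs(unittest.TestCase):
  def test_learner_task(self):
    # body exactly as in the example above
    ...

if __name__ == '__main__':
  unittest.main()
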
Example #2
import os
import time

import tensorflow as tf
# nest (the TF structure utility), tp_utils, InferDataServer, LeagueMgrAPIs
# and ModelPoolAPIs are provided by the surrounding code base.


class InfServer(object):
    def __init__(self,
                 league_mgr_addr,
                 model_pool_addrs,
                 port,
                 ds,
                 batch_size,
                 ob_space,
                 ac_space,
                 policy,
                 outputs=['a'],
                 policy_config=None,
                 gpu_id=0,
                 compress=True,
                 batch_worker_num=4,
                 update_model_seconds=60,
                 learner_id=None,
                 log_seconds=60,
                 model_key="",
                 task_attr='model_key',
                 **kwargs):
        self._update_model_seconds = update_model_seconds
        self._log_seconds = log_seconds
        self._learner_id = learner_id
        self._task_attr = task_attr.split('.')
        if model_key:
            # If model_key is given, the infserver serves a fixed model
            # for inference
            self._league_mgr_apis = None
            self.is_rl = False
            self.model_key = model_key
        else:
            # If model_key is absent, the infserver serves a varying policy;
            # its model_key is obtained by querying the league_mgr
            self._league_mgr_apis = LeagueMgrAPIs(league_mgr_addr)
            self.is_rl = True
            self.model_key = None
        self.model = None
        self._model_pool_apis = ModelPoolAPIs(model_pool_addrs)
        assert hasattr(policy, 'net_config_cls')
        assert hasattr(policy, 'net_build_fun')
        # bookkeeping
        self.ob_space = ob_space
        self.ac_space = ac_space
        self.batch_size = batch_size
        self._ac_structure = tp_utils.template_structure_from_gym_space(
            ac_space)
        self.outputs = outputs
        # build the net
        # copy to avoid mutating a caller-provided dict
        policy_config = dict(policy_config) if policy_config else {}
        policy_config['batch_size'] = batch_size
        use_gpu = (gpu_id >= 0)  # a negative gpu_id means CPU-only inference
        self.data_server = InferDataServer(
            port=port,
            batch_size=batch_size,
            ds=ds,
            batch_worker_num=batch_worker_num,
            use_gpu=use_gpu,
            compress=compress,
        )
        config = tf.ConfigProto(allow_soft_placement=True)
        if use_gpu:
            config.gpu_options.visible_device_list = str(gpu_id)
            config.gpu_options.allow_growth = True
            if policy_config.get('use_xla'):
                config.graph_options.optimizer_options.global_jit_level = (
                    tf.OptimizerOptions.ON_1)
        self._sess = tf.Session(config=config)
        self.nc = policy.net_config_cls(ob_space, ac_space, **policy_config)
        self.net_out = policy.net_build_fun(self.data_server._batch_input,
                                            self.nc,
                                            scope='Inf_server')
        # saving/loading ops
        self.params = self.net_out.vars.all_vars
        self.params_ph = [
            tf.placeholder(p.dtype, shape=p.get_shape()) for p in self.params
        ]
        self.params_assign_ops = [
            p.assign(np_p) for p, np_p in zip(self.params, self.params_ph)
        ]
        # initialize the net params
        tf.global_variables_initializer().run(session=self._sess)
        self.setup_fetches(outputs)
        self.id_and_fetches = [self.data_server._batch_data_id, self.fetches]
        self._update_model()

    # copy numpy parameter values into the TF variables; a shorter list
    # loads only a leading subset of the parameters
    def load_model(self, loaded_params):
        self._sess.run(
            self.params_assign_ops[:len(loaded_params)],
            feed_dict={p: v
                       for p, v in zip(self.params_ph, loaded_params)})

    def setup_fetches(self, outputs):
        def split_batch(template, tf_structure):
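            # split each flattened batch tensor along axis 0, then repack
            # every per-sample slice into the same structure as `template`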
            split_flatten = zip(*[
                tf.split(t, self.batch_size)
                for t in nest.flatten_up_to(template, tf_structure)
            ])
            return [
                nest.pack_sequence_as(template, flatten)
                for flatten in split_flatten
            ]

        if self.nc.use_self_fed_heads:
            a = nest.map_structure_up_to(self._ac_structure,
                                         lambda head: head.sam,
                                         self.net_out.self_fed_heads)
            neglogp = nest.map_structure_up_to(self._ac_structure,
                                               lambda head: head.neglogp,
                                               self.net_out.self_fed_heads)
            flatparam = nest.map_structure_up_to(self._ac_structure,
                                                 lambda head: head.flatparam,
                                                 self.net_out.self_fed_heads)
            self.all_outputs = {
                'a':
                split_batch(self._ac_structure, a),
                'neglogp':
                split_batch(self._ac_structure, neglogp),
                'flatparam':
                split_batch(self._ac_structure, flatparam),
                'v':
                tf.split(self.net_out.value_head, self.batch_size)
                if self.net_out.value_head is not None else [[]] *
                self.batch_size,
                'state':
                tf.split(self.net_out.S, self.batch_size)
                if self.net_out.S is not None else [[]] * self.batch_size
            }
        else:
            flatparam = nest.map_structure_up_to(self._ac_structure,
                                                 lambda head: head.flatparam,
                                                 self.net_out.outer_fed_heads)
            self.all_outputs = {
                'flatparam':
                split_batch(self._ac_structure, flatparam),
                'state':
                tf.split(self.net_out.S, self.batch_size)
                if self.net_out.S is not None else [[]] * self.batch_size
            }
        if self.nc.use_lstm and 'state' not in outputs:
            # copy before appending so a caller-provided (or shared default)
            # list is not mutated
            outputs = list(outputs) + ['state']
        self.fetches = [
            dict(zip(outputs, pred))
            for pred in zip(*[self.all_outputs[o] for o in outputs])
        ]

    def _update_model(self):
        if self.is_rl:
            # if (self.model_key is None or
            #     (self.model is not None and self.model.is_freezed())):
            self._query_task()
        if self._should_update_model(self.model, self.model_key):
            self.model = self._model_pool_apis.pull_model(self.model_key)
            self.load_model(self.model.model)

    def _query_task(self):
        assert self.is_rl, '_query_task can only be used in RL!'
        task = self._league_mgr_apis.query_learner_task(self._learner_id)
        while task is None:
            print('Learner has not requested a task yet, waiting...')
            time.sleep(5)
            task = self._league_mgr_apis.query_learner_task(self._learner_id)
        self.last_model_key = self.model_key
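        # walk the attribute path from task_attr (default 'model_key') to
        # extract the new model key from the task object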
        self.model_key = task
        for attr in self._task_attr:
            self.model_key = getattr(self.model_key, attr)
        return task

    # pull a fresh model when the key has changed; for an unchanged,
    # unfrozen model, pull only if the model pool has a newer update time
    def _should_update_model(self, model, model_key):
        if model is None or model_key != model.key:
            return True
        elif model.is_freezed():
            return False
        else:
            return self._model_pool_apis.pull_attr(
                'updatetime', model_key) > model.updatetime

    def run(self):
        while not self.data_server.ready:
            time.sleep(10)
            print('Waiting for at least {} actors to '
                  'connect ...'.format(self.batch_size),
                  flush=True)
        last_update_time = time.time()
        last_log_time = last_update_time
        batch_num = 0
        last_log_batch_num = 0
        pid = os.getpid()
        while True:
            # input is pre-fetched in self.data_server
            data_ids, outputs = self._sess.run(self.id_and_fetches, {})
            self.data_server.response(data_ids, outputs)
            batch_num += 1
            t0 = time.time()
            if t0 > last_update_time + self._update_model_seconds:
                self._update_model()
                last_update_time = t0
            t0 = time.time()
            if t0 > last_log_time + self._log_seconds:
                cost = t0 - last_log_time
                sam_num = self.batch_size * (batch_num - last_log_batch_num)
                print(
                    'Process {} predicted {} samples in {} seconds, fps {}'.
                    format(pid, sam_num, cost, sam_num / cost),
                    flush=True)
                last_log_batch_num = batch_num
                last_log_time = t0
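
For the fixed-model case (model_key given, so the league manager is never contacted), construction could look like the sketch below. MyPolicy, ds, ob_space, ac_space and the address formats are hypothetical placeholders for whatever the surrounding code base provides:

# minimal sketch; MyPolicy must expose net_config_cls / net_build_fun
server = InfServer(
    league_mgr_addr=None,                        # unused for a fixed model
    model_pool_addrs=['localhost:10003:10004'],  # assumed address format
    port=10005,
    ds=ds,
    batch_size=32,
    ob_space=ob_space,
    ac_space=ac_space,
    policy=MyPolicy,
    model_key='some_model_key',                  # fixed key => is_rl=False
    gpu_id=-1,                                   # CPU-only inference
)
server.run()  # blocks, serving batched inference to connected actors
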
Example #3
import pickle
import uuid
from abc import abstractmethod
from threading import Lock, Thread

import zmq
# LeagueMgrAPIs and ModelPoolAPIs are provided by the surrounding code base.


class BaseLearner(object):
  """Base learner class.

  Define the basic workflow for a learner."""
  def __init__(self, league_mgr_addr, model_pool_addrs, learner_ports,
               learner_id=''):
    self._learner_id = learner_id if learner_id else str(uuid.uuid1())

    self._zmq_context = zmq.Context()
    self._rep_socket = self._zmq_context.socket(zmq.REP)
    self._rep_socket.bind("tcp://*:%s" % learner_ports[0])
    self._pull_socket = self._zmq_context.socket(zmq.PULL)
    self._pull_socket.setsockopt(zmq.RCVHWM, 1)
    self._pull_socket.bind("tcp://*:%s" % learner_ports[1])
    self._message_thread = Thread(target=self._message_worker)
    self._message_thread.daemon = True
    self._message_thread.start()
    self._league_mgr_apis = LeagueMgrAPIs(league_mgr_addr)
    self._model_pool_apis = ModelPoolAPIs(model_pool_addrs)

    self.task = None
    self.model_key = None
    self.last_model_key = None
    self._lrn_period_count = 0  # learning period count
    self._pull_lock = Lock()

  def run(self):
    while True:
      self.task = self._request_task()
      self._init_task()
      self._train()
      self._finish_task()
      self._lrn_period_count += 1

  @abstractmethod
  def _train(self, **kwargs):
    pass

  @abstractmethod
  def _init_task(self):
    pass

  def _request_task(self):
    task = self._league_mgr_apis.request_learner_task(self._learner_id)
    self.last_model_key = self.model_key
    self.model_key = task.model_key
    # lazily freeze the model of the last learning period (lp), so that
    # actors stop working against the last lp's model
    if self.last_model_key and self.model_key != self.last_model_key:
      self._model_pool_apis.freeze_model(self.last_model_key)
    return task

  def _query_task(self):
    task = self._league_mgr_apis.query_learner_task(self._learner_id)
    if task is not None:
      self.last_model_key = self.model_key
      self.model_key = task.model_key
    return task

  def _finish_task(self):
    self._notify_task_end()

  def _pull_data(self):
    # the PULL socket is shared across callers; the context manager also
    # releases the lock if recv raises
    with self._pull_lock:
      data = self._pull_socket.recv(copy=False)
    return pickle.loads(data)

  def _message_worker(self):
    while True:
      msg = self._rep_socket.recv_string()
      if msg == 'learner_id':
        self._rep_socket.send_pyobj(self._learner_id)
      else:
        raise RuntimeError("message not recognized")

  def _notify_task_begin(self, task):
    self._league_mgr_apis.notify_learner_task_begin(self._learner_id, task)

  def _notify_task_end(self):
    self._league_mgr_apis.notify_learner_task_end(self._learner_id)
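
A concrete learner only needs to fill in the two abstract hooks. A minimal sketch; the class and its method bodies are hypothetical, not part of the code base:

class MyLearner(BaseLearner):
  """Hypothetical concrete learner."""

  def _init_task(self):
    # e.g. build or reload the model for the task assigned by the league,
    # then announce that work on it has started
    self._notify_task_begin(self.task)

  def _train(self, **kwargs):
    # e.g. consume actor data and run updates for one learning period;
    # a real learner would update and push its model here
    batch = self._pull_data()
    del batch

# MyLearner('localhost:10005', ['localhost:10003:10004'],
#           learner_ports=(10010, 10011)).run()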