def main(_):
    """Register a pickled model with the league manager.

    Reads the object pickled at ``FLAGS.model_path``. If it is not already a
    ``Model``, it is wrapped as ``Model(obj, None, None)``. The model is then
    frozen if it is not frozen yet. Its key is overridden by
    ``FLAGS.model_key`` when that flag is non-empty. Finally it is submitted
    through ``request_add_model`` to the league manager at
    ``FLAGS.league_mgr_addr``.
    """
    apis = LeagueMgrAPIs(FLAGS.league_mgr_addr)
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point model_path at trusted files.
    with open(FLAGS.model_path, 'rb') as model_file:
        loaded = pickle.load(model_file)
    model = loaded if isinstance(loaded, Model) else Model(loaded, None, None)
    if not model.is_freezed():
        model.freeze()
    if FLAGS.model_key:
        model.key = FLAGS.model_key
    apis.request_add_model(model)
class BaseActor(metaclass=ABCMeta):
    """Abstract actor: request a task, roll out one episode, report result.

    Subclasses must implement ``_rollout_an_episode``, which returns an
    ``(outcome, info)`` pair for the current ``self.task``.
    """

    def __init__(self, league_mgr_addr, model_pool_addrs, learner_addr=None,
                 verbose=0, log_interval_steps=51):
        """Create the service clients and configure stdout logging.

        Args:
          league_mgr_addr: address passed to ``LeagueMgrAPIs``.
          model_pool_addrs: addresses passed to ``ModelPoolAPIs``.
          learner_addr: optional learner address. When truthy, a
            ``LearnerAPIs`` client is created and the learner id is requested
            from it; otherwise ``self._learner_id`` stays ``None``.
          verbose: level passed to ``logger.set_level``.
          log_interval_steps: stored as ``self._log_interval_steps``.
        """
        ip, hostname = get_ip_hostname()
        # Actor id has the form "<hostname>@<ip>:<first 8 chars of a uuid1>".
        uid_prefix = str(uuid.uuid1())[:8]
        self._actor_id = ''.join([hostname, '@', ip, ':', uid_prefix])
        self._learner_id = None
        self._league_mgr_apis = LeagueMgrAPIs(league_mgr_addr)
        self._model_pool_apis = ModelPoolAPIs(model_pool_addrs)
        if learner_addr:
            learner_client = LearnerAPIs(learner_addr)
            self._learner_apis = learner_client
            self._learner_id = learner_client.request_learner_id()
        self._log_interval_steps = log_interval_steps
        logger.configure(dir=None, format_strs=['stdout'])
        logger.set_level(verbose)
        self.task = None
        self._steps = 0

    def run(self):
        """Loop forever, handling exactly one task (one episode) per pass."""
        while True:
            self.task = self._request_task()
            result, extra = self._rollout_an_episode()
            self._finish_task(self.task, result, extra)

    @abstractmethod
    def _rollout_an_episode(self):
        pass

    def _request_task(self):
        """Ask the league manager for this actor's next task and return it."""
        trace_level = logger.DEBUG + 5
        logger.log('entering _request_task',
                   'steps: {}'.format(self._steps), level=trace_level)
        next_task = self._league_mgr_apis.request_actor_task(
            self._actor_id, self._learner_id)
        logger.log('leaving _request_task', level=trace_level)
        return next_task

    def _finish_task(self, task, outcome, info=None):
        """Send a ``MatchResult`` built from ``task`` to the league manager.

        A falsy ``info`` is replaced by an empty dict.
        """
        info = info if info else {}
        trace_level = logger.DEBUG + 5
        logger.log('entering _finish_task',
                   'steps: {}'.format(self._steps), level=trace_level)
        result = MatchResult(task.model_key1, task.model_key2, outcome, info)
        self._league_mgr_apis.notify_actor_task_end(self._actor_id, result)
        logger.log('leaving _finish_task', level=trace_level)
def test_learner_task(self):
    """Round-trip a learner task through the league manager at localhost:11007.

    Requests a task for a fresh learner id and checks it is a
    ``LearnerTask``. Querying the task back must return the same model keys.
    The task is then marked as begun and ended.
    """
    league_client = LeagueMgrAPIs(league_mgr_addr="localhost:11007")
    learner_id = str(uuid.uuid1())
    task = league_client.request_learner_task(learner_id=learner_id)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(task, LearnerTask)
    query_task = league_client.query_learner_task(learner_id=learner_id)
    self.assertEqual(task.model_key, query_task.model_key)
    self.assertEqual(task.parent_model_key, query_task.parent_model_key)
    league_client.notify_learner_task_begin(learner_id=learner_id,
                                            learner_task=task)
    league_client.notify_learner_task_end(learner_id=learner_id)
def test_actor_task(self):
    """Exercise the actor-task lifecycle against local league/model services.

    Uses the league manager at localhost:11007 and the model pool at
    localhost:11001:11006. It first requests and begins a learner task,
    pushes a model (with ``None`` data) to the pool, then requests an actor
    task for that learner. The actor task is checked to be an ``ActorTask``,
    then begun and ended with outcome 1.
    """
    actor_id = str(uuid.uuid1())
    learner_id = str(uuid.uuid1())
    league_client = LeagueMgrAPIs(league_mgr_addr="localhost:11007")
    learner_task = league_client.request_learner_task(learner_id=learner_id)
    league_client.notify_learner_task_begin(learner_id=learner_id,
                                            learner_task=learner_task)
    model_client = ModelPoolAPIs(model_pool_addrs=["localhost:11001:11006"])
    hyperparam = MutableHyperparam()
    model_client.push_model(None, hyperparam, str(uuid.uuid1()))
    task = league_client.request_actor_task(actor_id=actor_id,
                                            learner_id=learner_id)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(task, ActorTask)
    league_client.notify_actor_task_begin(actor_id=actor_id)
    league_client.notify_actor_task_end(
        actor_id=actor_id,
        match_result=MatchResult(task.model_key1, task.model_key2, 1))
def test_checkpoint(self):
    """Check that a league manager restored from a checkpoint serves the
    same models as the original model pool.

    Models are pushed to the pool at localhost:11001:11006 and one is added
    to the league at localhost:11007. After waiting for a checkpoint under
    ./checkpoints, a second LeagueMgr (port 11008, pool
    localhost:11011:11016) is started in a child process from the newest
    checkpoint. Both pools must then return identical models for both keys.
    """
    league_client = LeagueMgrAPIs(league_mgr_addr="localhost:11007")
    model_client1 = ModelPoolAPIs(model_pool_addrs=["localhost:11001:11006"])
    hyperparam = MutableHyperparam()
    model_key1 = str(uuid.uuid1())
    model_key2 = str(uuid.uuid1())
    model_client1.push_model("model_data1", hyperparam, model_key1)
    model_client1.push_model("model_data2", hyperparam, model_key2)
    time.sleep(4)
    league_client.request_add_model(
        Model("model_data1", hyperparam, model_key1))
    model_client1.push_model("model_data3", hyperparam, model_key2)
    time.sleep(3)

    checkpoints = [name for name in os.listdir("./checkpoints")
                   if name.startswith("checkpoint")]
    self.assertGreater(len(checkpoints), 0)
    # os.listdir returns entries in arbitrary order, so checkpoints[-1] was
    # not guaranteed to be the newest. Pick the most recently modified one.
    # NOTE(review): assumes each checkpoint dir's mtime reflects when it was
    # written — confirm against the checkpoint writer.
    checkpoint_dir = max(
        (os.path.join("./checkpoints", name) for name in checkpoints),
        key=os.path.getmtime)

    league_process = Process(
        target=lambda: LeagueMgr(
            port="11008",
            model_pool_addrs=["localhost:11011:11016"],
            mutable_hyperparam_type='MutableHyperparam',
            restore_checkpoint_dir=checkpoint_dir).run())
    league_process.start()
    # Always stop the child process, even if an assertion below fails.
    try:
        model_client2 = ModelPoolAPIs(
            model_pool_addrs=["localhost:11011:11016"])
        time.sleep(2)
        keys = model_client2.pull_keys()
        self.assertIn(model_key1, keys)
        self.assertIn(model_key2, keys)
        for key in (model_key1, model_key2):
            model1 = model_client1.pull_model(key)
            model2 = model_client2.pull_model(key)
            self.assertEqual(model1.model, model2.model)
            self.assertEqual(model1.key, model2.key)
            self.assertEqual(model1.createtime, model2.createtime)
    finally:
        league_process.terminate()
def __init__(self, league_mgr_addr, model_pool_addrs, learner_addr=None,
             verbose=0, log_interval_steps=51):
    """Create the actor's service clients and configure stdout logging.

    Args:
      league_mgr_addr: address passed to ``LeagueMgrAPIs``.
      model_pool_addrs: addresses passed to ``ModelPoolAPIs``.
      learner_addr: optional learner address. When truthy, a ``LearnerAPIs``
        client is created and the learner id is requested from it; otherwise
        ``self._learner_id`` stays ``None``.
      verbose: level passed to ``logger.set_level``.
      log_interval_steps: stored as ``self._log_interval_steps``.
    """
    ip, hostname = get_ip_hostname()
    # Actor id has the form "<hostname>@<ip>:<first 8 chars of a uuid1>".
    uid_prefix = str(uuid.uuid1())[:8]
    self._actor_id = ''.join([hostname, '@', ip, ':', uid_prefix])
    self._learner_id = None
    self._league_mgr_apis = LeagueMgrAPIs(league_mgr_addr)
    self._model_pool_apis = ModelPoolAPIs(model_pool_addrs)
    if learner_addr:
        learner_client = LearnerAPIs(learner_addr)
        self._learner_apis = learner_client
        self._learner_id = learner_client.request_learner_id()
    self._log_interval_steps = log_interval_steps
    logger.configure(dir=None, format_strs=['stdout'])
    logger.set_level(verbose)
    self.task = None
    self._steps = 0
def __init__(self, league_mgr_addr, model_pool_addrs, learner_ports,
             learner_id=''):
    """Set up the learner's sockets, message thread and service clients.

    Args:
      league_mgr_addr: address passed to ``LeagueMgrAPIs``.
      model_pool_addrs: addresses passed to ``ModelPoolAPIs``.
      learner_ports: the REP socket binds to ``learner_ports[0]``, the PULL
        socket to ``learner_ports[1]``.
      learner_id: learner id; a fresh uuid1 string is used when empty.
    """
    self._learner_id = learner_id if learner_id else str(uuid.uuid1())

    ctx = zmq.Context()
    self._zmq_context = ctx
    # REP socket answers control requests (see _message_worker).
    self._rep_socket = rep = ctx.socket(zmq.REP)
    rep.bind("tcp://*:%s" % learner_ports[0])
    # PULL socket receives data; receive high-water mark is set to 1.
    self._pull_socket = pull = ctx.socket(zmq.PULL)
    pull.setsockopt(zmq.RCVHWM, 1)
    pull.bind("tcp://*:%s" % learner_ports[1])

    self._message_thread = Thread(target=self._message_worker, daemon=True)
    self._message_thread.start()

    self._league_mgr_apis = LeagueMgrAPIs(league_mgr_addr)
    self._model_pool_apis = ModelPoolAPIs(model_pool_addrs)
    self.task = None
    self.model_key = None
    self.last_model_key = None
    self._lrn_period_count = 0  # number of learning periods completed
    self._pull_lock = Lock()
def __init__(self, league_mgr_addr, model_pool_addrs, port, ds, batch_size,
             ob_space, ac_space, policy, outputs=['a'], policy_config={},
             gpu_id=0, compress=True, batch_worker_num=4,
             update_model_seconds=60, learner_id=None, log_seconds=60,
             model_key="", task_attr='model_key', **kwargs):
    """Build the inference graph and load the initial model.

    Args:
      league_mgr_addr: league manager address; only used when ``model_key``
        is empty (RL mode).
      model_pool_addrs: addresses passed to ``ModelPoolAPIs``.
      port, ds, batch_worker_num, compress: forwarded to ``InferDataServer``.
      batch_size: number of samples per inference batch.
      ob_space, ac_space: observation/action spaces passed to the policy's
        ``net_config_cls``.
      policy: object exposing ``net_config_cls`` and ``net_build_fun``.
      outputs: names of the per-sample outputs to fetch. Copied, never
        mutated.
      policy_config: extra kwargs for ``net_config_cls``. Copied, and
        ``batch_size`` is added to the copy. ``None`` means ``{}``.
      gpu_id: GPU index to use; a negative value disables GPU usage.
      update_model_seconds: stored as ``self._update_model_seconds``.
      learner_id: learner whose task is queried in RL mode.
      log_seconds: stored as ``self._log_seconds``.
      model_key: if non-empty, serve this fixed model (non-RL mode).
      task_attr: dot-separated attribute path, stored split on '.'.
      **kwargs: accepted and ignored.
    """
    self._update_model_seconds = update_model_seconds
    self._log_seconds = log_seconds
    self._learner_id = learner_id
    self._task_attr = task_attr.split('.')
    if model_key:
        # If model_key is given, this infserver works for a fixed model.
        self._league_mgr_apis = None
        self.is_rl = False
        self.model_key = model_key
    else:
        # If model_key is absent, this infserver performs varying policy
        # inference; model_key is assigned by querying the league manager.
        self._league_mgr_apis = LeagueMgrAPIs(league_mgr_addr)
        self.is_rl = True
        self.model_key = None
    self.model = None
    self._model_pool_apis = ModelPoolAPIs(model_pool_addrs)
    assert hasattr(policy, 'net_config_cls')
    assert hasattr(policy, 'net_build_fun')

    # bookkeeping
    self.ob_space = ob_space
    # Bug fix: this was previously assigned to self.ob_space, which
    # overwrote the observation space and left ac_space unset.
    self.ac_space = ac_space
    self.batch_size = batch_size
    self._ac_structure = tp_utils.template_structure_from_gym_space(ac_space)
    # Private copy: setup_fetches() may append 'state', which must not leak
    # into the caller's list or the shared default ['a'].
    self.outputs = list(outputs)

    # build the net
    # Private copy for the same reason: 'batch_size' is written into it.
    policy_config = {} if policy_config is None else dict(policy_config)
    policy_config['batch_size'] = batch_size
    use_gpu = gpu_id >= 0
    self.data_server = InferDataServer(
        port=port,
        batch_size=batch_size,
        ds=ds,
        batch_worker_num=batch_worker_num,
        use_gpu=use_gpu,
        compress=compress,
    )
    config = tf.ConfigProto(allow_soft_placement=True)
    if use_gpu:
        config.gpu_options.visible_device_list = str(gpu_id)
        config.gpu_options.allow_growth = True
    if policy_config.get('use_xla'):
        config.graph_options.optimizer_options.global_jit_level = (
            tf.OptimizerOptions.ON_1)
    self._sess = tf.Session(config=config)
    self.nc = policy.net_config_cls(ob_space, ac_space, **policy_config)
    self.net_out = policy.net_build_fun(self.data_server._batch_input,
                                        self.nc, scope='Inf_server')

    # saving/loading ops: one placeholder and one assign op per variable
    self.params = self.net_out.vars.all_vars
    self.params_ph = [
        tf.placeholder(p.dtype, shape=p.get_shape()) for p in self.params
    ]
    self.params_assign_ops = [
        p.assign(np_p) for p, np_p in zip(self.params, self.params_ph)
    ]
    # initialize the net params
    tf.global_variables_initializer().run(session=self._sess)

    self.setup_fetches(self.outputs)
    self.id_and_fetches = [self.data_server._batch_data_id, self.fetches]
    self._update_model()
class InfServer(object):
    """Batched policy-inference server.

    The network is built on the batched input tensor of an
    ``InferDataServer``. ``run`` repeatedly evaluates it on pre-fetched
    batches and sends per-sample results back through the data server.
    Parameters are pulled from the model pool. If ``model_key`` is given,
    that fixed model is served. Otherwise (RL mode) the server serves the
    model the league manager reports for ``learner_id``.
    """

    def __init__(self, league_mgr_addr, model_pool_addrs, port, ds,
                 batch_size, ob_space, ac_space, policy, outputs=['a'],
                 policy_config={}, gpu_id=0, compress=True,
                 batch_worker_num=4, update_model_seconds=60,
                 learner_id=None, log_seconds=60, model_key="",
                 task_attr='model_key', **kwargs):
        """Build the inference graph and load the initial model.

        Args:
          league_mgr_addr: league manager address; only used when
            ``model_key`` is empty (RL mode).
          model_pool_addrs: addresses passed to ``ModelPoolAPIs``.
          port, ds, batch_worker_num, compress: forwarded to
            ``InferDataServer``.
          batch_size: number of samples per inference batch.
          ob_space, ac_space: observation/action spaces passed to the
            policy's ``net_config_cls``.
          policy: object exposing ``net_config_cls`` and ``net_build_fun``.
          outputs: names of the per-sample outputs to fetch (keys of
            ``self.all_outputs``). Copied, never mutated.
          policy_config: extra kwargs for ``net_config_cls``. Copied, and
            ``batch_size`` is added to the copy. ``None`` means ``{}``.
          gpu_id: GPU index to use; a negative value disables GPU usage.
          update_model_seconds: minimum interval between model refreshes in
            ``run``.
          learner_id: learner whose task is queried in RL mode.
          log_seconds: minimum interval between throughput log lines.
          model_key: if non-empty, serve this fixed model (non-RL mode).
          task_attr: dot-separated attribute path applied to the queried
            learner task to obtain the model key.
          **kwargs: accepted and ignored.
        """
        self._update_model_seconds = update_model_seconds
        self._log_seconds = log_seconds
        self._learner_id = learner_id
        self._task_attr = task_attr.split('.')
        if model_key:
            # If model_key is given, this infserver works for a fixed model.
            self._league_mgr_apis = None
            self.is_rl = False
            self.model_key = model_key
        else:
            # If model_key is absent, this infserver performs varying policy
            # inference; model_key is assigned by querying the league manager.
            self._league_mgr_apis = LeagueMgrAPIs(league_mgr_addr)
            self.is_rl = True
            self.model_key = None
        self.model = None
        self._model_pool_apis = ModelPoolAPIs(model_pool_addrs)
        assert hasattr(policy, 'net_config_cls')
        assert hasattr(policy, 'net_build_fun')

        # bookkeeping
        self.ob_space = ob_space
        # Bug fix: this was previously assigned to self.ob_space, which
        # overwrote the observation space and left ac_space unset.
        self.ac_space = ac_space
        self.batch_size = batch_size
        self._ac_structure = tp_utils.template_structure_from_gym_space(
            ac_space)
        # Private copy: setup_fetches() may append 'state', which must not
        # leak into the caller's list or the shared default ['a'].
        self.outputs = list(outputs)

        # build the net
        # Private copy for the same reason: 'batch_size' is written into it.
        policy_config = {} if policy_config is None else dict(policy_config)
        policy_config['batch_size'] = batch_size
        use_gpu = gpu_id >= 0
        self.data_server = InferDataServer(
            port=port,
            batch_size=batch_size,
            ds=ds,
            batch_worker_num=batch_worker_num,
            use_gpu=use_gpu,
            compress=compress,
        )
        config = tf.ConfigProto(allow_soft_placement=True)
        if use_gpu:
            config.gpu_options.visible_device_list = str(gpu_id)
            config.gpu_options.allow_growth = True
        if policy_config.get('use_xla'):
            config.graph_options.optimizer_options.global_jit_level = (
                tf.OptimizerOptions.ON_1)
        self._sess = tf.Session(config=config)
        self.nc = policy.net_config_cls(ob_space, ac_space, **policy_config)
        self.net_out = policy.net_build_fun(self.data_server._batch_input,
                                            self.nc, scope='Inf_server')

        # saving/loading ops: one placeholder and one assign op per variable
        self.params = self.net_out.vars.all_vars
        self.params_ph = [
            tf.placeholder(p.dtype, shape=p.get_shape()) for p in self.params
        ]
        self.params_assign_ops = [
            p.assign(np_p) for p, np_p in zip(self.params, self.params_ph)
        ]
        # initialize the net params
        tf.global_variables_initializer().run(session=self._sess)

        self.setup_fetches(self.outputs)
        self.id_and_fetches = [self.data_server._batch_data_id, self.fetches]
        self._update_model()

    def load_model(self, loaded_params):
        """Assign ``loaded_params`` to the net variables, in order.

        Only the first ``len(loaded_params)`` variables are assigned.
        """
        self._sess.run(
            self.params_assign_ops[:len(loaded_params)],
            feed_dict={p: v for p, v in zip(self.params_ph, loaded_params)})

    def setup_fetches(self, outputs):
        """Build ``self.all_outputs`` and the per-sample ``self.fetches``.

        Each output tensor is split along its first dimension into
        ``self.batch_size`` pieces. ``self.fetches`` is a list with one dict
        per sample, mapping each name in ``outputs`` to that sample's piece.
        If the net config uses an LSTM, 'state' is appended to ``outputs``
        (the list passed in is mutated).
        """

        def split_batch(template, tf_structure):
            # Split every leaf into batch_size pieces, then regroup the
            # i-th pieces of all leaves into one structure per sample.
            split_flatten = zip(*[
                tf.split(t, self.batch_size)
                for t in nest.flatten_up_to(template, tf_structure)
            ])
            return [
                nest.pack_sequence_as(template, flatten)
                for flatten in split_flatten
            ]

        if self.nc.use_self_fed_heads:
            heads = self.net_out.self_fed_heads
            a = nest.map_structure_up_to(self._ac_structure,
                                         lambda head: head.sam, heads)
            neglogp = nest.map_structure_up_to(self._ac_structure,
                                               lambda head: head.neglogp,
                                               heads)
            flatparam = nest.map_structure_up_to(self._ac_structure,
                                                 lambda head: head.flatparam,
                                                 heads)
            self.all_outputs = {
                'a': split_batch(self._ac_structure, a),
                'neglogp': split_batch(self._ac_structure, neglogp),
                'flatparam': split_batch(self._ac_structure, flatparam),
                'v': (tf.split(self.net_out.value_head, self.batch_size)
                      if self.net_out.value_head is not None
                      else [[]] * self.batch_size),
                'state': (tf.split(self.net_out.S, self.batch_size)
                          if self.net_out.S is not None
                          else [[]] * self.batch_size),
            }
        else:
            flatparam = nest.map_structure_up_to(
                self._ac_structure, lambda head: head.flatparam,
                self.net_out.outer_fed_heads)
            self.all_outputs = {
                'flatparam': split_batch(self._ac_structure, flatparam),
                'state': (tf.split(self.net_out.S, self.batch_size)
                          if self.net_out.S is not None
                          else [[]] * self.batch_size),
            }
        if self.nc.use_lstm and 'state' not in outputs:
            outputs.append('state')
        self.fetches = [
            dict(zip(outputs, pred))
            for pred in zip(*[self.all_outputs[o] for o in outputs])
        ]

    def _update_model(self):
        """Reload the served model from the model pool if it is stale.

        In RL mode the current model key is re-queried from the league
        manager first.
        """
        if self.is_rl:
            self._query_task()
        if self._should_update_model(self.model, self.model_key):
            self.model = self._model_pool_apis.pull_model(self.model_key)
            self.load_model(self.model.model)

    def _query_task(self):
        """Poll the league manager until a task exists for this learner.

        Sets ``self.model_key`` by applying the attribute path
        ``self._task_attr`` to the task. The previous key is kept in
        ``self.last_model_key``. Polls every 5 seconds while no task exists.

        Returns:
          The queried learner task.
        """
        assert self.is_rl, '_query_task can only be used in RL!'
        task = self._league_mgr_apis.query_learner_task(self._learner_id)
        while task is None:
            print('Learner has not request task! wait...')
            time.sleep(5)
            task = self._league_mgr_apis.query_learner_task(self._learner_id)
        # Resolve the key first so a failing getattr leaves state untouched.
        new_key = task
        for attr in self._task_attr:
            new_key = getattr(new_key, attr)
        self.last_model_key = self.model_key
        self.model_key = new_key
        return task

    def _should_update_model(self, model, model_key):
        """Return whether ``model`` must be (re)pulled for ``model_key``.

        True if there is no model or its key differs. False if the model is
        frozen. Otherwise, True iff the pool's 'updatetime' for the key is
        newer than the model's.
        """
        if model is None or model_key != model.key:
            return True
        elif model.is_freezed():
            return False
        else:
            return self._model_pool_apis.pull_attr(
                'updatetime', model_key) > model.updatetime

    def run(self):
        """Serve inference forever.

        Waits until the data server is ready. Then it loops: evaluate a batch
        and respond, refresh the model every ``_update_model_seconds``, and
        print throughput every ``_log_seconds``.
        """
        while not self.data_server.ready:
            time.sleep(10)
            print('Waiting at least {} actors to '
                  'connect ...'.format(self.batch_size), flush=True)
        last_update_time = time.time()
        last_log_time = last_update_time
        batch_num = 0
        last_log_batch_num = 0
        pid = os.getpid()
        while True:
            # input is pre-fetched in self.data_server
            data_ids, outputs = self._sess.run(self.id_and_fetches, {})
            self.data_server.response(data_ids, outputs)
            batch_num += 1
            t0 = time.time()
            if t0 > last_update_time + self._update_model_seconds:
                self._update_model()
                last_update_time = t0
            t0 = time.time()
            if t0 > last_log_time + self._log_seconds:
                cost = t0 - last_log_time
                sam_num = self.batch_size * (batch_num - last_log_batch_num)
                print('Process {} predicts {} samples costs {} seconds, fps {}'
                      .format(pid, sam_num, cost, sam_num / cost), flush=True)
                last_log_batch_num = batch_num
                last_log_time = t0
class BaseLearner(object):
    """Base learner class. Define the basic workflow for a learner.

    ``run`` loops forever: request a task from the league manager, then call
    ``_init_task``, ``_train`` and ``_finish_task``. Subclasses implement
    ``_init_task`` and ``_train``.
    """

    # NOTE(review): this class does not use ABCMeta, so the @abstractmethod
    # decorators below are NOT enforced at instantiation time.

    def __init__(self, league_mgr_addr, model_pool_addrs, learner_ports,
                 learner_id=''):
        """Set up the learner's sockets, message thread and service clients.

        Args:
          league_mgr_addr: address passed to ``LeagueMgrAPIs``.
          model_pool_addrs: addresses passed to ``ModelPoolAPIs``.
          learner_ports: the REP socket binds to ``learner_ports[0]``, the
            PULL socket to ``learner_ports[1]``.
          learner_id: learner id; a fresh uuid1 string is used when empty.
        """
        if learner_id:
            self._learner_id = learner_id
        else:
            self._learner_id = str(uuid.uuid1())

        self._zmq_context = zmq.Context()
        # REP socket answers control requests (see _message_worker).
        self._rep_socket = self._zmq_context.socket(zmq.REP)
        self._rep_socket.bind("tcp://*:%s" % learner_ports[0])
        # PULL socket receives data; receive high-water mark is set to 1.
        self._pull_socket = self._zmq_context.socket(zmq.PULL)
        self._pull_socket.setsockopt(zmq.RCVHWM, 1)
        self._pull_socket.bind("tcp://*:%s" % learner_ports[1])

        self._message_thread = Thread(target=self._message_worker)
        self._message_thread.daemon = True
        self._message_thread.start()

        self._league_mgr_apis = LeagueMgrAPIs(league_mgr_addr)
        self._model_pool_apis = ModelPoolAPIs(model_pool_addrs)
        self.task = None
        self.model_key = None
        self.last_model_key = None
        self._lrn_period_count = 0  # learning period count
        self._pull_lock = Lock()

    def run(self):
        """Run learning periods forever, one task per period."""
        while True:
            self.task = self._request_task()
            self._init_task()
            self._train()
            self._finish_task()
            self._lrn_period_count += 1

    @abstractmethod
    def _train(self, **kwargs):
        pass

    @abstractmethod
    def _init_task(self):
        pass

    def _request_task(self):
        """Request a new learner task and update the model keys.

        If the model key changed from a previous non-empty one, the previous
        model is frozen in the model pool.
        """
        task = self._league_mgr_apis.request_learner_task(self._learner_id)
        self.last_model_key = self.model_key
        self.model_key = task.model_key
        # lazy freeze the model of last lp, then actors will stop the last lp.
        if self.last_model_key and self.model_key != self.last_model_key:
            self._model_pool_apis.freeze_model(self.last_model_key)
        return task

    def _query_task(self):
        """Query the current learner task; update model keys if one exists."""
        task = self._league_mgr_apis.query_learner_task(self._learner_id)
        if task is not None:
            self.last_model_key = self.model_key
            self.model_key = task.model_key
        return task

    def _finish_task(self):
        self._notify_task_end()

    def _pull_data(self):
        """Receive one message from the PULL socket and unpickle it.

        The receive is serialized by ``self._pull_lock``. A ``with`` block
        guarantees the lock is released even if ``recv`` raises; previously a
        failing receive left the lock held and deadlocked later callers.
        """
        with self._pull_lock:
            data = self._pull_socket.recv(copy=False)
        # NOTE(review): pickle.loads on bytes received from the network can
        # execute arbitrary code; only safe if every sender is trusted.
        return pickle.loads(data)

    def _message_worker(self):
        """Answer 'learner_id' requests on the REP socket forever.

        Any other message raises RuntimeError, which ends this thread.
        """
        while True:
            msg = self._rep_socket.recv_string()
            if msg == 'learner_id':
                self._rep_socket.send_pyobj(self._learner_id)
            else:
                raise RuntimeError("message not recognized")

    def _notify_task_begin(self, task):
        self._league_mgr_apis.notify_learner_task_begin(self._learner_id, task)

    def _notify_task_end(self):
        self._league_mgr_apis.notify_learner_task_end(self._learner_id)