def test_learner_task(self):
    """Exercise the learner-task lifecycle against the league manager.

    A task is requested for a freshly generated learner id and its type is
    checked. A follow-up query must report the same ``model_key`` and
    ``parent_model_key``. Finally, begin/end notifications are sent.
    Expects a league manager reachable at localhost:11007.
    """
    league = LeagueMgrAPIs(league_mgr_addr="localhost:11007")
    lrn_id = str(uuid.uuid1())

    assigned = league.request_learner_task(learner_id=lrn_id)
    self.assertTrue(isinstance(assigned, LearnerTask))

    queried = league.query_learner_task(learner_id=lrn_id)
    # The queried task must carry the same keys as the one handed out.
    for key_attr in ("model_key", "parent_model_key"):
        self.assertEqual(getattr(assigned, key_attr),
                         getattr(queried, key_attr))

    league.notify_learner_task_begin(learner_id=lrn_id,
                                     learner_task=assigned)
    league.notify_learner_task_end(learner_id=lrn_id)
def test_actor_task(self):
    """Exercise the actor-task lifecycle.

    The test proceeds in four steps:
    1. Start a learner task.
    2. Push a model (under a fresh uuid key) to the model pool.
    3. Request an actor task for that learner and check its type.
    4. Send actor begin/end notifications, the end one carrying a
       ``MatchResult`` built from the task's two model keys.

    Expects a league manager at localhost:11007 and a model pool at
    localhost:11001:11006.
    """
    act_id = str(uuid.uuid1())
    lrn_id = str(uuid.uuid1())

    league = LeagueMgrAPIs(league_mgr_addr="localhost:11007")
    lrn_task = league.request_learner_task(learner_id=lrn_id)
    league.notify_learner_task_begin(learner_id=lrn_id,
                                     learner_task=lrn_task)

    pool = ModelPoolAPIs(model_pool_addrs=["localhost:11001:11006"])
    hp = MutableHyperparam()
    pushed_key = str(uuid.uuid1())
    pool.push_model(None, hp, pushed_key)

    act_task = league.request_actor_task(actor_id=act_id, learner_id=lrn_id)
    self.assertTrue(isinstance(act_task, ActorTask))

    league.notify_actor_task_begin(actor_id=act_id)
    # Third MatchResult argument is 1 -- presumably the outcome code;
    # confirm against MatchResult's definition.
    outcome = MatchResult(act_task.model_key1, act_task.model_key2, 1)
    league.notify_actor_task_end(actor_id=act_id, match_result=outcome)
class BaseLearner(object):
    """Base learner class. Define the basic workflow for a learner.

    Construction does the following:
    - binds a REP socket on ``learner_ports[0]`` and a PULL socket on
      ``learner_ports[1]``;
    - starts a daemon thread that answers ``'learner_id'`` queries on the
      REP socket;
    - creates clients for the league manager and the model pool.

    :meth:`run` then loops forever over learning periods: request a task,
    initialize it, train, finish.

    Subclasses are expected to implement :meth:`_train` and
    :meth:`_init_task`. NOTE(review): this class derives from ``object``
    rather than ``abc.ABC``, so ``@abstractmethod`` is not enforced at
    instantiation time. The base implementations simply return ``None``.
    """

    def __init__(self, league_mgr_addr, model_pool_addrs, learner_ports,
                 learner_id=''):
        """Set up sockets, the message thread and the service clients.

        Args:
            league_mgr_addr: address passed to ``LeagueMgrAPIs``.
            model_pool_addrs: address(es) passed to ``ModelPoolAPIs``.
            learner_ports: sequence whose element 0 is the REP port and
                element 1 is the PULL port. Both are bound on all
                interfaces.
            learner_id: explicit learner id. A fresh ``uuid1`` string is
                generated when it is empty/falsy.
        """
        if learner_id:
            self._learner_id = learner_id
        else:
            self._learner_id = str(uuid.uuid1())
        self._zmq_context = zmq.Context()
        # REP socket: serves control queries handled by _message_worker.
        self._rep_socket = self._zmq_context.socket(zmq.REP)
        self._rep_socket.bind("tcp://*:%s" % learner_ports[0])
        # PULL socket for incoming data, with a receive high-water mark of 1.
        self._pull_socket = self._zmq_context.socket(zmq.PULL)
        self._pull_socket.setsockopt(zmq.RCVHWM, 1)
        self._pull_socket.bind("tcp://*:%s" % learner_ports[1])
        # Daemon thread, so it does not keep the process alive at exit.
        self._message_thread = Thread(target=self._message_worker)
        self._message_thread.daemon = True
        self._message_thread.start()
        self._league_mgr_apis = LeagueMgrAPIs(league_mgr_addr)
        self._model_pool_apis = ModelPoolAPIs(model_pool_addrs)
        self.task = None
        self.model_key = None
        self.last_model_key = None
        self._lrn_period_count = 0  # learning period count
        self._pull_lock = Lock()

    def run(self):
        """Run learning periods forever: request, init, train, finish."""
        while True:
            self.task = self._request_task()
            self._init_task()
            self._train()
            self._finish_task()
            self._lrn_period_count += 1

    @abstractmethod
    def _train(self, **kwargs):
        """Train for one learning period (to be provided by subclasses)."""
        pass

    @abstractmethod
    def _init_task(self):
        """Prepare ``self.task`` before training (subclass hook)."""
        pass

    def _request_task(self):
        """Request a new learner task and update the model-key bookkeeping.

        ``last_model_key`` takes the previous ``model_key``, and
        ``model_key`` takes the new task's key. If a previous key exists
        and differs from the new one, that previous model is frozen in
        the model pool.

        Returns:
            The task returned by the league manager.
        """
        task = self._league_mgr_apis.request_learner_task(self._learner_id)
        self.last_model_key = self.model_key
        self.model_key = task.model_key
        # Lazily freeze the model of the last learning period. Per the
        # original author's note, actors then stop using the last period.
        if self.last_model_key and self.model_key != self.last_model_key:
            self._model_pool_apis.freeze_model(self.last_model_key)
        return task

    def _query_task(self):
        """Query the current learner task without requesting a new one.

        Updates ``last_model_key``/``model_key`` only when a task is
        returned.

        Returns:
            The queried task, or ``None`` if the league manager returned none.
        """
        task = self._league_mgr_apis.query_learner_task(self._learner_id)
        if task is not None:
            self.last_model_key = self.model_key
            self.model_key = task.model_key
        return task

    def _finish_task(self):
        """Finish the current learning period by notifying its end."""
        self._notify_task_end()

    def _pull_data(self):
        """Receive one message from the PULL socket and unpickle it.

        The lock serializes concurrent receivers. It is held via ``with``,
        so it is released even if ``recv`` raises. Without this, a single
        failed receive would deadlock every later call.

        Returns:
            The unpickled object.
        """
        with self._pull_lock:
            data = self._pull_socket.recv(copy=False)
        # SECURITY NOTE(review): pickle.loads on data received over the
        # network executes arbitrary code if the sender is untrusted; only
        # bind these ports on trusted networks.
        return pickle.loads(data)

    def _message_worker(self):
        """Serve control requests on the REP socket (runs in its own thread).

        Replies to ``'learner_id'`` with this learner's id. Any other
        message raises ``RuntimeError``, which ends this thread without
        sending a reply.
        """
        while True:
            msg = self._rep_socket.recv_string()
            if msg == 'learner_id':
                self._rep_socket.send_pyobj(self._learner_id)
            else:
                raise RuntimeError("message not recognized")

    def _notify_task_begin(self, task):
        """Tell the league manager that ``task`` has started."""
        self._league_mgr_apis.notify_learner_task_begin(self._learner_id, task)

    def _notify_task_end(self):
        """Tell the league manager that the current task has ended."""
        self._league_mgr_apis.notify_learner_task_end(self._learner_id)