def _write_worker(self):
    """Handle exactly one write request received on the subscribe socket.

    The first frame is a string naming the request:

    * ``'model'``: a pickled model object follows; it is stored in the model
      pool under ``model.key``. The creation time of an already stored model
      with the same key is carried over; a missing creation time falls back
      to the model's update time.
    * ``'freeze'``: a key follows; the model stored under that key (if any)
      is frozen.
    * ``'learner_meta'``: a key and a pickled object follow; the object is
      stored as the learner meta data for that key.

    Raises:
      RuntimeError: if the request name is none of the above.
    """
    request = self._sub_socket.recv_string()

    if request == 'model':
        model = self._sub_socket.recv_pyobj()
        # keep the original creation time when overwriting an existing model
        if model.key in self._model_pool:
            model.createtime = self._model_pool[model.key].createtime
        # first time seen without a creation time: use the update time
        if model.createtime is None:
            model.createtime = model.updatetime
        self._model_pool[model.key] = model
        key_text = 'key: {},'.format(model.key)
        createtime_text = 'create time: {}'.format(model.createtime)
        logger.log(now() + 'on msg write model.', key_text, createtime_text)
        return

    if request == 'freeze':
        # freeze one model; unknown keys are silently ignored
        key = self._sub_socket.recv_string()
        if key in self._model_pool:
            self._model_pool[key].freeze()
        logger.log(now() + 'on msg write freeze', 'key: {}'.format(key))
        return

    if request == 'learner_meta':
        # store learner meta data
        key = self._sub_socket.recv_string()
        self._learner_meta[key] = self._sub_socket.recv_pyobj()
        logger.log(now() + 'on msg write learner_meta', 'key: {}'.format(key))
        return

    raise RuntimeError("message {} not recognized".format(request))
def _save_model_checkpoint(self, checkpoint_root, checkpoint_name):
    """Dump the models of the model pool to disk and mark the checkpoint ready.

    Model files live directly under `checkpoint_root` and are named
    ``<model_key>_<updatetime>.model``; a model whose file already exists is
    not pulled or written again. The names of all model files belonging to
    this checkpoint are written to ``filename.list`` inside
    ``checkpoint_root/checkpoint_name``, followed by a ``.ready`` marker file.

    Args:
      checkpoint_root: directory that receives the model files.
      checkpoint_name: name of the checkpoint sub-directory under
        `checkpoint_root`.

    Raises:
      AssertionError: if a pulled model's key differs from the requested key.
    """
    checkpoint_dir = os.path.join(checkpoint_root, checkpoint_name)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    logger.log(now() + 'Pulling updatetime')
    updatetime_dict = self._model_pool_apis.pull_all_attr('updatetime')
    logger.log(
        now() + 'Done pulling updatetime, no.={}'.format(len(updatetime_dict)))

    filenames = []
    for model_key, updatetime in updatetime_dict.items():
        filename = "%s_%s.model" % (model_key, updatetime)
        filepath = os.path.join(checkpoint_root, filename)
        filenames.append(filename + '\n')
        if os.path.isfile(filepath):
            # this version of the model has already been saved
            continue

        logger.log(now() + 'Pulling model {}'.format(model_key))
        model = self._model_pool_apis.pull_model(model_key)
        logger.log(now() + 'Done pulling model {}'.format(model_key))
        assert model_key == model.key

        # Write to a temporary sibling first and rename it into place only
        # when the dump has completed. Writing straight to `filepath` would
        # leave a truncated file behind if the process died mid-dump, and the
        # isfile() check above would then skip it forever.
        tmp_filepath = filepath + '.tmp'
        with open(tmp_filepath, 'wb') as f:
            pickle.dump(model, f)
            if self._save_learner_meta:
                # learner meta data is appended as a second pickle record
                learner_meta = self._model_pool_apis.pull_learner_meta(
                    model_key)
                pickle.dump(learner_meta, f)
        os.rename(tmp_filepath, filepath)
        logger.log(now() + 'Saved model to {}'.format(filepath))

    filelistpath = os.path.join(checkpoint_dir, 'filename.list')
    with open(filelistpath, 'w') as f:
        f.writelines(filenames)
    # the marker is written last, after the file list is complete
    with open(os.path.join(checkpoint_dir, '.ready'), 'w') as f:
        f.write('ready')
        f.flush()
def _restore_checkpoint(self, checkpoint_dir):
    """Restore the league manager state from `checkpoint_dir`.

    The parent class restores its own state first. Then the game manager,
    the hyper-parameter manager and the pickled learner task table are
    loaded, in that order. The step numbers in the comments match the ones
    used in `_save_checkpoint`, so restoring walks them in reverse.

    Args:
      checkpoint_dir: directory of the checkpoint to load.
    """
    super(LeagueMgr, self)._restore_checkpoint(checkpoint_dir)

    start_msg = '{}loading league-mgr from {}'.format(now(), checkpoint_dir)
    logger.log(start_msg)

    # 3. game manager
    self.game_mgr.load(checkpoint_dir)

    # 2. hyper-parameter manager
    self._hyper_mgr.load(checkpoint_dir)

    # 1. learner task table
    table_path = os.path.join(checkpoint_dir, 'learner_task_table')
    with open(table_path, 'rb') as table_file:
        self._learner_task_table = pickle.load(table_file)

    done_msg = '{}done loading league-mgr'.format(now())
    logger.log(done_msg)
def _save_checkpoint(self, checkpoint_root, checkpoint_name):
    """Save the league manager state to `checkpoint_root/checkpoint_name`.

    The parent class saves its own state first. Then the learner task table
    is pickled to the file ``learner_task_table``, and the hyper-parameter
    manager and the game manager save themselves into the same directory.

    Args:
      checkpoint_root: root directory of all checkpoints.
      checkpoint_name: name of the checkpoint sub-directory.
    """
    target_dir = os.path.join(checkpoint_root, checkpoint_name)
    start_msg = '{}saving league-mgr to {}'.format(now(), target_dir)
    logger.log(start_msg)

    # NOTE(review): the directory is presumably created by the parent call,
    # since nothing below creates it -- confirm.
    super(LeagueMgr, self)._save_checkpoint(checkpoint_root, checkpoint_name)

    # 1. learner task table
    table_path = os.path.join(target_dir, 'learner_task_table')
    with open(table_path, 'wb') as table_file:
        pickle.dump(self._learner_task_table, table_file)

    # 2. hyper-parameter manager
    self._hyper_mgr.save(target_dir)

    # 3. game manager
    self.game_mgr.save(target_dir)

    done_msg = '{}done saving league-mgr'.format(now())
    logger.log(done_msg)
def _read_worker(self):
    """Answer read requests on the reply socket in an endless loop.

    Each request starts with a string frame that selects what is sent back:

    * ``'model'`` + key: the model stored under the key.
    * ``'keys'``: the list of all model keys.
    * ``'all_attr'`` + attribute name: a dict mapping every model key to
      that attribute of the model.
    * ``'learner_meta'`` + key: the learner meta data stored for the key,
      or None when the model exists but has no meta data.
    * anything else + key: the first frame is taken as an attribute name
      and that attribute of the model stored under the key is sent.

    A `ModelPoolErroMsg` is sent in place of the value whenever the key or
    the attribute does not exist. This method never returns.
    """
    while True:
        msg = self._rep_socket.recv_string()

        if msg == 'model':
            # get one Model
            key = self._rep_socket.recv_string()
            if key not in self._model_pool:
                reply = ModelPoolErroMsg('Key {} not exits'.format(key))
            else:
                reply = self._model_pool[key]
            self._rep_socket.send_pyobj(reply)
            logger.log(now() + 'on msg read model,', 'key: {}'.format(key))
            continue

        if msg == 'keys':
            # get all the model keys
            self._rep_socket.send_pyobj(list(self._model_pool.keys()))
            logger.log(now() + 'on msg read keys,')
            continue

        if msg == 'all_attr':
            # get the attr for all models. return {key: attr}
            attr = self._rep_socket.recv_string()
            attr_by_key = {}
            for model_key, model in self._model_pool.items():
                if not hasattr(model, attr):
                    attr_by_key[model_key] = ModelPoolErroMsg(
                        'Attribute not exits')
                else:
                    attr_by_key[model_key] = getattr(model, attr)
            self._rep_socket.send_pyobj(attr_by_key)
            logger.log(now() + 'on msg read all_attr,')
            continue

        if msg == 'learner_meta':
            # get learner meta data
            key = self._rep_socket.recv_string()
            if key not in self._model_pool:
                reply = ModelPoolErroMsg('Key not exits')
            elif key not in self._learner_meta:
                reply = None
            else:
                reply = self._learner_meta[key]
            self._rep_socket.send_pyobj(reply)
            logger.log(now() + 'on msg read learner_meta')
            continue

        # on pull_attr(), get the attr for one model
        # TODO(pengsun): too tricky, should use strict msg definition
        attr = msg
        key = self._rep_socket.recv_string()
        if key not in self._model_pool:
            reply = ModelPoolErroMsg('Key not exits')
        elif not hasattr(self._model_pool[key], attr):
            reply = ModelPoolErroMsg('Attribute not exits')
        else:
            reply = getattr(self._model_pool[key], attr)
        self._rep_socket.send_pyobj(reply)
        logger.log(now() + 'on msg read {}'.format(msg))