class FollowerTrainerMaster(TrainerMaster):
    """Trainer master for the follower role.

    Data blocks are loaded into a ``DataBlockSet`` and handed out only when
    the caller names the block id it wants.
    """

    def __init__(self, application_id, data_source,
                 start_time, end_time, online_training):
        super(FollowerTrainerMaster, self).__init__(
            application_id, None, online_training)
        self._data_block_set = DataBlockSet()
        self._data_block_visitor = DataBlockVisitor(
            data_source, ETCD_NAME, ETCD_BASE_DIR, ETCD_ADDR)
        self._start_time = start_time
        self._end_time = end_time

    def _load_data(self):
        """Add every block in the time frame that is not in the checkpoint."""
        finished_ids = self._get_checkpoint()
        reps_by_id = self._data_block_visitor.LoadDataBlockRepByTimeFrame(
            self._start_time, self._end_time)
        for rep_id, rep in reps_by_id.items():
            if rep_id in finished_ids:
                # Already recorded in the checkpoint; do not serve it again.
                continue
            logging.debug('load data block id %s path %s',
                          rep_id, rep.data_block_fpath)
            self._data_block_set.add(rep)
        logging.debug("FollowerTrainerMaster: get all block %s",
                      self._data_block_set)

    def _alloc_data_block(self, block_id=None):
        """Return the block registered under ``block_id``.

        Raises:
            Exception: if ``block_id`` is empty or None.
        """
        logging.debug("FollowerTrainerMaster is getting block %s", block_id)
        if block_id:
            return self._data_block_set.get(block_id)
        raise Exception('follower tm need block_id to alloc.')
class LeaderTrainerMaster(TrainerMaster):
    """Trainer master for the leader role.

    Data blocks are queued in a ``DataBlockQueue`` and handed out in queue
    order; the requested block id is ignored.
    """

    def __init__(self, application_id, data_source,
                 start_time, end_time, online_training):
        super(LeaderTrainerMaster, self).__init__(
            application_id, None, online_training)
        self._data_block_queue = DataBlockQueue()
        self._data_block_visitor = DataBlockVisitor(
            data_source, ETCD_NAME, ETCD_BASE_DIR, ETCD_ADDR)
        self._start_time = start_time
        self._end_time = end_time

    def _load_data(self):
        """Enqueue every block in the time frame not in the checkpoint."""
        finished_ids = self._get_checkpoint()
        reps_by_id = self._data_block_visitor.LoadDataBlockRepByTimeFrame(
            self._start_time, self._end_time)
        for rep_id, rep in reps_by_id.items():
            if rep_id in finished_ids:
                continue
            logging.debug('load data block id %s path %s',
                          rep_id, rep.data_block_fpath)
            self._data_block_queue.put(rep)

    def _alloc_data_block(self, block_id=None):
        """Return the next queued block, or None when the queue is empty."""
        # block_id is unused in leader role
        if self._data_block_queue.empty():
            return None
        return self._data_block_queue.get()
class FollowerTrainerMaster(TrainerMaster):
    """Follower-side trainer master that serves blocks by explicit id.

    The data block visitor is configured from the module-level ``db_*``
    settings; the ``KVSTORE_USE_MOCK`` environment variable (value ``"on"``)
    is forwarded to it as a boolean flag.
    """

    def __init__(self, application_id, data_source,
                 start_time, end_time, online_training):
        super(FollowerTrainerMaster, self).__init__(
            application_id, None, online_training)
        self._data_block_set = DataBlockSet()
        mock_flag = os.environ.get('KVSTORE_USE_MOCK', "off")
        kvstore_use_mock = mock_flag == "on"
        self._data_block_visitor = DataBlockVisitor(
            data_source, db_database, db_base_dir, db_addr,
            db_username, db_password, kvstore_use_mock)
        self._start_time = start_time
        self._end_time = end_time

    def _load_data(self):
        """Register every block in the time frame absent from the checkpoint."""
        done = self._get_checkpoint()
        found = self._data_block_visitor.LoadDataBlockRepByTimeFrame(
            self._start_time, self._end_time)
        for key, item in found.items():
            if key in done:
                continue
            logging.debug('load data block id %s path %s',
                          key, item.data_block_fpath)
            self._data_block_set.add(item)
        logging.debug("FollowerTrainerMaster: get all block %s",
                      self._data_block_set)

    def _alloc_data_block(self, block_id=None):
        """Look up ``block_id`` in the loaded block set.

        Raises:
            Exception: if ``block_id`` is empty or None.
        """
        logging.debug("FollowerTrainerMaster is getting block %s", block_id)
        if block_id:
            return self._data_block_set.get(block_id)
        raise Exception('follower tm need block_id to alloc.')
class LeaderTrainerMaster(TrainerMaster):
    """Leader-side trainer master with multi-epoch and shuffle support.

    Blocks are queued once per epoch (optionally reshuffled each epoch).
    In online-training mode the queue is refilled on demand whenever it
    runs dry.
    """

    def __init__(self, application_id, data_source,
                 start_time, end_time, online_training,
                 shuffle_data_block, epoch_num):
        super(LeaderTrainerMaster, self).__init__(
            application_id, None, online_training)
        mock_flag = os.environ.get('KVSTORE_USE_MOCK', "off")
        kvstore_use_mock = mock_flag == "on"
        self._data_block_queue = DataBlockQueue()
        self._data_block_visitor = DataBlockVisitor(
            data_source, db_database, db_base_dir, db_addr,
            db_username, db_password, kvstore_use_mock)
        self._start_time = start_time
        self._end_time = end_time
        self._epoch_num = epoch_num
        self._shuffle_data_block = shuffle_data_block
        # Ids of blocks already enqueued, so reloads never enqueue them twice.
        self._visited_data_blocks = set()
        self._lock = threading.Lock()
        if online_training:
            single_pass = self._epoch_num == 1
            assert single_pass and not self._shuffle_data_block, \
                "epoch_num must be 1 and shuffle_data_block must be False " \
                "online_training is set"
        assert self._epoch_num >= 1, \
            "epoch_num {} must >= 1".format(self._epoch_num)

    def _load_data(self):
        """Enqueue blocks that are neither checkpointed nor already visited."""
        finished_ids = self._get_checkpoint()
        all_reps = self._data_block_visitor.LoadDataBlockRepByTimeFrame(
            self._start_time, self._end_time).values()
        fresh = []
        for rep in all_reps:
            if rep.block_id in finished_ids:
                continue
            if rep.block_id in self._visited_data_blocks:
                continue
            fresh.append(rep)
        self._visited_data_blocks.update(rep.block_id for rep in fresh)

        if self._online_training:
            fresh.sort(key=lambda rep: rep.data_block_index)

        for epoch in range(self._epoch_num):
            if self._shuffle_data_block:
                random.shuffle(fresh)
            for rep in fresh:
                logging.debug('epoch round-%d: add data block id %s path %s',
                              epoch, rep.block_id, rep.data_block_fpath)
                self._data_block_queue.put(rep)

    def _alloc_data_block(self, block_id=None):
        """Pop the next block and record its id as allocated.

        Returns None when no block is available.
        """
        # block_id is unused in leader role
        pending = self._data_block_queue
        with self._lock:
            if pending.empty() and self._online_training:
                self._load_data()
            if pending.empty():
                return None
            allocated = pending.get()
            with self._checkpoint_mutex:
                self._allocated_data_blockids.add(allocated.block_id)
            return allocated
def __init__(self, role, path, files=None, ext='.tfrecord',
             start_time=None, end_time=None, from_data_source=False,
             skip_datablock_checkpoint=False, epoch_num=1):
    """Build the local block queue and block-id lookup map.

    Args:
        role: stored as-is on the instance.
        path: data source handed to ``DataBlockVisitor`` when
            ``from_data_source`` is true; otherwise the root directory that
            ``files`` are relative to (and that is walked when ``files`` is
            None).
        files: optional iterable of file names relative to ``path``. It is
            not modified.
        ext: extension filter applied while walking ``path``; a falsy value
            disables filtering. Not applied to an explicit ``files`` list.
        start_time: passed to ``LoadDataBlockRepByTimeFrame``.
        end_time: passed to ``LoadDataBlockRepByTimeFrame``.
        from_data_source: load blocks through ``DataBlockVisitor`` instead
            of from files.
        skip_datablock_checkpoint: start in RUNNING rather than INITIALING.
        epoch_num: number of times the file block list is appended to the
            queue. Only used in file mode.
    """
    self._role = role
    self._path = path
    self._block_queue = []
    self._block_map = {}
    self._allocated_data_blockids = set()
    self._status = tm_pb.MasterStatus.CREATED
    if from_data_source:
        data_block_visitor = DataBlockVisitor(
            path, db_database, db_base_dir, db_addr,
            db_username, db_password)
        loaded = data_block_visitor.LoadDataBlockRepByTimeFrame(
            start_time, end_time)
        for block_id, block_item in loaded.items():
            self._block_queue.append(block_item)
            self._block_map[block_id] = block_item
    else:
        if files is None:
            files = []
            for dirname, _, filenames in tf.io.gfile.walk(path):
                for filename in filenames:
                    _, fileext = os.path.splitext(filename)
                    if ext and fileext != ext:
                        continue
                    subdirname = os.path.relpath(dirname, path)
                    files.append(os.path.join(subdirname, filename))
        # sorted() works on a copy: the caller's list is left untouched and
        # any iterable (tuple, generator, ...) is accepted.
        files = sorted(files)

        # Hack way for supporting multiple epochs
        blocks = []
        for filename in files:
            block_id, _ = os.path.splitext(os.path.basename(filename))
            fullname = os.path.join(path, filename)
            blocks.append(DataBlockInfo(block_id, fullname))
        self._block_map = {block.block_id: block for block in blocks}
        for _ in range(epoch_num):
            self._block_queue.extend(blocks)

    self._status = tm_pb.MasterStatus.INITIALING
    if skip_datablock_checkpoint:
        self._status = tm_pb.MasterStatus.RUNNING
def __init__(self, role, path, files=None, ext='.tfrecord',
             start_time=None, end_time=None, from_data_source=False):
    """Build the local block queue and block-id lookup map.

    Args:
        role: stored as-is on the instance.
        path: data source handed to ``DataBlockVisitor`` when
            ``from_data_source`` is true; otherwise the root directory that
            ``files`` are relative to (and that is walked when ``files`` is
            None).
        files: optional iterable of file names relative to ``path``. It is
            not modified.
        ext: extension filter applied while walking ``path``; a falsy value
            disables filtering. Not applied to an explicit ``files`` list.
        start_time: passed to ``LoadDataBlockRepByTimeFrame``.
        end_time: passed to ``LoadDataBlockRepByTimeFrame``.
        from_data_source: load blocks through ``DataBlockVisitor`` instead
            of from files.

    Raises:
        AssertionError: if two files share the same base name (block id).
    """
    self._role = role
    self._path = path
    self._block_queue = []
    self._block_map = {}
    if from_data_source:
        data_block_visitor = DataBlockVisitor(
            path, db_database, db_base_dir, db_addr,
            db_username, db_password)
        loaded = data_block_visitor.LoadDataBlockRepByTimeFrame(
            start_time, end_time)
        for block_id, block_item in loaded.items():
            self._block_queue.append(block_item)
            self._block_map[block_id] = block_item
    else:
        if files is None:
            files = []
            for dirname, _, filenames in tf.io.gfile.walk(path):
                for filename in filenames:
                    _, fileext = os.path.splitext(filename)
                    if ext and fileext != ext:
                        continue
                    subdirname = os.path.relpath(dirname, path)
                    files.append(os.path.join(subdirname, filename))
        # sorted() works on a copy: the caller's list is left untouched and
        # any iterable (tuple, generator, ...) is accepted.
        files = sorted(files)

        # block id -> first file name seen, kept only for the error message.
        block_map = {}
        for filename in files:
            block_id, _ = os.path.splitext(os.path.basename(filename))
            assert block_id not in block_map, \
                "Duplicate file names: %s and %s"%(
                    filename, block_map[block_id])
            block_map[block_id] = filename
            fullname = os.path.join(path, filename)
            block = DataBlockInfo(block_id, fullname)
            self._block_queue.append(block)
            self._block_map[block_id] = block
def __init__(self, role, path, files=None, ext='.tfrecord',
             start_time=None, end_time=None, from_data_source=False):
    """Build the local block queue and block-id lookup map.

    Args:
        role: stored as-is on the instance.
        path: data source handed to ``DataBlockVisitor`` when
            ``from_data_source`` is true; otherwise the directory that
            ``files`` live in (and that is listed when ``files`` is None).
        files: optional iterable of file names inside ``path``. It is not
            modified.
        ext: extension filter applied while listing ``path``; a falsy value
            disables filtering. Not applied to an explicit ``files`` list.
        start_time: passed to ``LoadDataBlockRepByTimeFrame``.
        end_time: passed to ``LoadDataBlockRepByTimeFrame``.
        from_data_source: load blocks through ``DataBlockVisitor`` instead
            of from files.
    """
    self._role = role
    self._path = path
    self._block_queue = []
    self._block_map = {}
    if from_data_source:
        data_block_visitor = DataBlockVisitor(
            path, ETCD_NAME, ETCD_BASE_DIR, ETCD_ADDR)
        loaded = data_block_visitor.LoadDataBlockRepByTimeFrame(
            start_time, end_time)
        for block_id, block_item in loaded.items():
            self._block_queue.append(block_item)
            self._block_map[block_id] = block_item
    else:
        if files is None:
            files = []
            for filename in os.listdir(path):
                fullname = os.path.join(path, filename)
                # Only regular files directly inside `path` are considered.
                if not os.path.isfile(fullname):
                    continue
                _, fileext = os.path.splitext(filename)
                if ext and fileext != ext:
                    continue
                files.append(filename)
        # sorted() works on a copy: the caller's list is left untouched and
        # any iterable (tuple, generator, ...) is accepted.
        files = sorted(files)

        for filename in files:
            block_id, _ = os.path.splitext(filename)
            fullname = os.path.join(path, filename)
            block = DataBlockInfo(block_id, fullname)
            self._block_queue.append(block)
            self._block_map[block_id] = block
class LeaderTrainerMaster(object):
    """Leader-side trainer master served over gRPC.

    The master moves through the states CREATED -> INITIALING -> RUNNING ->
    FINISHED (or ERROR on an invalid transition). Data blocks are only
    allocated in RUNNING, which is entered once a data-block checkpoint has
    been restored.

    Locks:
        _status_mutex guards ``_status``.
        _checkpoint_mutex guards ``_allocated_data_blockids``.
        _lock serialises block allocation (and on-demand reloads).
    """

    def __init__(self, application_id, data_source,
                 start_time, end_time, online_training,
                 shuffle_data_block, epoch_num):
        """Set up state, the block queue and the data block visitor.

        Raises:
            AssertionError: if ``epoch_num`` < 1, or if ``online_training``
                is set together with ``epoch_num != 1`` or shuffling.
        """
        self._application_id = application_id
        self._online_training = online_training
        self._checkpoint_mutex = threading.Lock()
        # Stays None until a checkpoint is restored.
        self._allocated_data_blockids = None
        self._status_mutex = threading.Lock()
        self._status = tm_pb.MasterStatus.CREATED

        kvstore_use_mock = os.environ.get('KVSTORE_USE_MOCK', "off") == "on"
        self._data_block_queue = DataBlockQueue()
        self._data_block_visitor = DataBlockVisitor(
            data_source, db_database, db_base_dir, db_addr,
            db_username, db_password, kvstore_use_mock)
        self._start_time = start_time
        self._end_time = end_time
        self._epoch_num = epoch_num
        self._shuffle_data_block = shuffle_data_block
        # Ids of blocks already enqueued, so reloads never enqueue them twice.
        self._visited_data_blocks = set()
        self._lock = threading.Lock()
        if online_training:
            assert self._epoch_num == 1 and not self._shuffle_data_block, \
                "epoch_num must be 1 and shuffle_data_block must be False " \
                "online_training is set"
        assert self._epoch_num >= 1, \
            "epoch_num {} must >= 1".format(self._epoch_num)

    def run(self, listen_port):
        """Start the gRPC server on ``listen_port`` and block until it ends."""
        self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        tm_grpc.add_TrainerMasterServiceServicer_to_server(
            TrainerMasterServer(self._data_block_response,
                                self._get_checkpoint_fn,
                                self._restore_checkpoint_fn),
            self._server)
        self._server.add_insecure_port('[::]:%d' % listen_port)
        self._server.start()
        logging.info('Trainer Master Server start on port[%d].', listen_port)
        self._transfer_status(tm_pb.MasterStatus.CREATED,
                              tm_pb.MasterStatus.INITIALING)
        self._server.wait_for_termination()

    def _transfer_status(self, frm, to, callback_fn=lambda *args: True):
        """Move from state ``frm`` to ``to`` if the current state is ``frm``.

        On success returns ``callback_fn()`` (called while the status lock is
        held). Otherwise the status is set to ERROR and False is returned.
        """
        with self._status_mutex:
            if self._status == frm:
                self._status = to
                return callback_fn()
            logging.warning("%s invalid status transfer, from %d to %d, "
                            "while status is %d", self.__class__.__name__,
                            frm, to, self._status)
            self._status = tm_pb.MasterStatus.ERROR
        return False

    def _check_status(self, callback_fn):
        """Return ``callback_fn(status)``, evaluated under the status lock."""
        with self._status_mutex:
            return callback_fn(self._status)

    def _get_checkpoint_fn(self, request):
        """Handle a checkpoint query: report the allocated block ids.

        Responds with STATUS_WAIT_FOR_SYNCING_CHECKPOINT unless the master is
        RUNNING or FINISHED.

        Raises:
            AssertionError: if the request's application id does not match.
        """
        assert request.application_id == self._application_id, \
            "Application id not matched"
        response = tm_pb.GetDataBlockCheckpointResponse()

        def ckpt_not_ready_fn(status):
            return status not in (tm_pb.MasterStatus.RUNNING,
                                  tm_pb.MasterStatus.FINISHED)

        if self._check_status(ckpt_not_ready_fn):
            response.status.code = common_pb.STATUS_WAIT_FOR_SYNCING_CHECKPOINT
            response.status.error_message = \
                "master is not ready for querying daya checkpoint"
            return response
        response.status.code = common_pb.STATUS_SUCCESS
        response.status.error_message = 'success'
        # Snapshot under the checkpoint lock: _alloc_data_block adds to this
        # set from other server threads, and iterating a set while it grows
        # raises RuntimeError.
        with self._checkpoint_mutex:
            allocated_ids = list(self._allocated_data_blockids)
        response.block_ids.extend(allocated_ids)
        return response

    def _restore_checkpoint_fn(self, request):
        """Handle a checkpoint restore: seed allocated ids and load data.

        A no-op success when the master is already RUNNING, FINISHED or
        ERROR. Otherwise the allocated-id set is replaced by the request's
        block ids, data is loaded, and the master moves INITIALING ->
        RUNNING.

        Raises:
            AssertionError: if the request's application id does not match.
        """
        assert request.application_id == self._application_id,\
                "Application id not matched: %s vs %s"%(
                    request.application_id, self._application_id)
        response = tm_pb.RestoreDataBlockCheckpointResponse()

        def no_need_restore_fn(status):
            return status in (tm_pb.MasterStatus.RUNNING,
                              tm_pb.MasterStatus.FINISHED,
                              tm_pb.MasterStatus.ERROR)

        if self._check_status(no_need_restore_fn):
            logging.info("No need to restore %s", self.__class__.__name__)
            response.status.code = common_pb.STATUS_SUCCESS
            response.status.error_message = "success"
            return response

        # In case of race, load data before state transfering to RUNNING, and
        #   after filling data checkpoint
        with self._checkpoint_mutex:
            self._allocated_data_blockids = set(request.block_ids)
            self._load_data()

        trans_ok = self._transfer_status(tm_pb.MasterStatus.INITIALING,
                                         tm_pb.MasterStatus.RUNNING)
        if not trans_ok:
            response.status.code = common_pb.STATUS_WAIT_FOR_SYNCING_CHECKPOINT
            response.status.error_message = \
                "must sync data checkpoint before alloc"
            return response

        response.status.code = common_pb.STATUS_SUCCESS
        response.status.error_message = "success"
        return response

    def _get_checkpoint(self):
        """Return the set of allocated block ids (None before a restore)."""
        return self._allocated_data_blockids

    def _data_block_response(self, request):
        """Handle a data block request.

        Returns a DATA_FINISHED response when the master is FINISHED/ERROR, a
        WAIT_FOR_SYNCING_CHECKPOINT response when it is not yet RUNNING, and
        otherwise tries to allocate a block.
        """
        def status_check_fn(status):
            # Returns True when RUNNING, else a ready-to-send response.
            response = tm_pb.DataBlockResponse()
            if status in (tm_pb.MasterStatus.FINISHED, \
                    tm_pb.MasterStatus.ERROR):
                response.status.code = common_pb.STATUS_DATA_FINISHED
                response.status.error_message = 'datablock finished'
                return response
            if status != tm_pb.MasterStatus.RUNNING:
                response.status.code = \
                       common_pb.STATUS_WAIT_FOR_SYNCING_CHECKPOINT
                response.status.error_message = \
                        "must sync data checkpoint before alloc"
                return response
            #only if status is RUNNING
            return True

        ready = self._check_status(status_check_fn)
        if ready is not True:
            return ready

        response = tm_pb.DataBlockResponse()
        data_block = self._alloc_data_block(block_id=request.block_id)
        if data_block:
            logging.debug("%s allocated worker_%d with block id %s",
                          self.__class__.__name__,
                          request.worker_rank,
                          data_block.block_id)
            response.status.code = common_pb.STATUS_SUCCESS
            response.status.error_message = 'success'
            response.data_block_info.data_path = \
                str(data_block.data_block_fpath)
            response.data_block_info.meta_path = ''
            response.data_block_info.block_id = str(data_block.block_id)
        elif self._online_training:
            logging.debug("%s allocated worker_%d with empty data block. "\
                          "wait for new data block since online traning",
                          self.__class__.__name__, request.worker_rank)
            response.status.code = common_pb.STATUS_NO_MORE_DATA
            response.status.error_message = 'please wait for datablock ready'
        else:
            logging.debug("%s allocated worker_%d with empty data block. "\
                          "exit running since since batch traning",
                          self.__class__.__name__, request.worker_rank)
            response.status.code = common_pb.STATUS_DATA_FINISHED
            response.status.error_message = 'datablock finished'
        if response.status.code == common_pb.STATUS_DATA_FINISHED:
            self._transfer_status(tm_pb.MasterStatus.RUNNING,
                                  tm_pb.MasterStatus.FINISHED)
        return response

    def _load_data(self):
        """Enqueue blocks that are neither checkpointed nor already visited.

        In online-training mode the new blocks are ordered by
        ``data_block_index``. The batch is enqueued ``epoch_num`` times,
        reshuffled before each round when shuffling is enabled.
        """
        checkpoint = self._get_checkpoint()
        logging.info("load_data, checkpoint: %s", checkpoint)
        all_reps = self._data_block_visitor.LoadDataBlockRepByTimeFrame(
            self._start_time, self._end_time).values()
        data_block_reps = [
            dbr for dbr in all_reps
            if dbr.block_id not in checkpoint and
            dbr.block_id not in self._visited_data_blocks]

        self._visited_data_blocks.update([i.block_id for i in data_block_reps])

        if self._online_training:
            data_block_reps.sort(key=lambda x: x.data_block_index)
        for rnd in range(self._epoch_num):
            if self._shuffle_data_block:
                random.shuffle(data_block_reps)
            for dbr in data_block_reps:
                logging.debug('epoch round-%d: add data block id %s path %s',
                              rnd, dbr.block_id, dbr.data_block_fpath)
                self._data_block_queue.put(dbr)

    def _alloc_data_block(self, block_id=None):
        """Pop the next block and record its id as allocated.

        In online-training mode an empty queue triggers a reload first.
        Returns None when no block is available.
        """
        # block_id is unused in leader role
        with self._lock:
            if self._data_block_queue.empty() and self._online_training:
                logging.info("Load data when queue empty and online training")
                self._load_data()

            if self._data_block_queue.empty():
                logging.info("Allocate when data_block_queue is empty")
                return None

            data_blocks_resp = self._data_block_queue.get()
            with self._checkpoint_mutex:
                self._allocated_data_blockids.add(data_blocks_resp.block_id)
            return data_blocks_resp