コード例 #1
0
class FollowerTrainerMaster(TrainerMaster):
    """Follower-side trainer master (etcd-backed data block visitor).

    Every data block of the configured time frame that is not already in the
    checkpoint is kept in a DataBlockSet; blocks are then served by explicit
    id only.
    """

    def __init__(self, application_id, data_source, start_time, end_time,
                 online_training):
        super(FollowerTrainerMaster, self).__init__(
            application_id, None, online_training)
        self._data_block_set = DataBlockSet()
        self._data_block_visitor = DataBlockVisitor(
            data_source, ETCD_NAME, ETCD_BASE_DIR, ETCD_ADDR)
        self._start_time = start_time
        self._end_time = end_time

    def _load_data(self):
        """Add every non-checkpointed block in [start_time, end_time]."""
        already_allocated = self._get_checkpoint()
        visitor = self._data_block_visitor
        block_reps = visitor.LoadDataBlockRepByTimeFrame(
            self._start_time, self._end_time)
        for rep_id, rep in block_reps.items():
            if rep_id in already_allocated:
                # Already handed out before the restore; skip it.
                continue
            logging.debug('load data block id %s path %s', rep_id,
                          rep.data_block_fpath)
            self._data_block_set.add(rep)
        logging.debug("FollowerTrainerMaster: get all block %s",
                      self._data_block_set)

    def _alloc_data_block(self, block_id=None):
        """Return the block with ``block_id`` from the set.

        Raises:
            Exception: ``block_id`` is falsy; the follower cannot choose a
                block on its own.
        """
        logging.debug("FollowerTrainerMaster is getting block %s", block_id)
        if block_id:
            return self._data_block_set.get(block_id)
        raise Exception('follower tm need block_id to alloc.')
コード例 #2
0
class LeaderTrainerMaster(TrainerMaster):
    """Leader-side trainer master (etcd-backed data block visitor).

    Non-checkpointed data blocks of the configured time frame are pushed onto
    a DataBlockQueue and handed out one at a time, in the order the queue
    yields them.
    """

    def __init__(self, application_id, data_source, start_time, end_time,
                 online_training):
        import threading  # local import keeps this class self-contained
        super(LeaderTrainerMaster, self).__init__(application_id, None,
                                                  online_training)
        self._data_block_queue = DataBlockQueue()
        self._data_block_visitor = DataBlockVisitor(data_source, ETCD_NAME,
                                                    ETCD_BASE_DIR, ETCD_ADDR)
        self._start_time = start_time
        self._end_time = end_time
        # Makes the empty()/get() pair in _alloc_data_block atomic when
        # several request threads allocate at the same time.
        self._alloc_lock = threading.Lock()

    def _load_data(self):
        """Enqueue every block in [start_time, end_time] not yet checkpointed."""
        checkpoint = self._get_checkpoint()
        # pylint: disable=line-too-long
        for block_id, block_item in self._data_block_visitor.LoadDataBlockRepByTimeFrame(
                self._start_time, self._end_time).items():
            if block_id not in checkpoint:
                logging.debug('load data block id %s path %s', block_id,
                              block_item.data_block_fpath)
                self._data_block_queue.put(block_item)

    def _alloc_data_block(self, block_id=None):
        """Pop the next queued data block, or return None when none is left.

        ``block_id`` is unused in the leader role. The emptiness check and
        the pop run under one lock. Without it, two concurrent callers could
        both see the last block and one would call ``get()`` on an empty
        queue (whether that blocks depends on DataBlockQueue).
        """
        with self._alloc_lock:
            if self._data_block_queue.empty():
                return None
            return self._data_block_queue.get()
コード例 #3
0
ファイル: follower_tm.py プロジェクト: zyenggook/fedlearner
class FollowerTrainerMaster(TrainerMaster):
    """Follower-side trainer master (kv-store-backed data block visitor).

    Every data block of the configured time frame that is not already in the
    checkpoint is kept in a DataBlockSet; blocks are then served by explicit
    id only.
    """

    def __init__(self, application_id, data_source,
                 start_time, end_time, online_training):
        super(FollowerTrainerMaster, self).__init__(
            application_id, None, online_training)
        self._data_block_set = DataBlockSet()
        # The mocked kv store is used only when KVSTORE_USE_MOCK is "on".
        use_mock_kvstore = os.environ.get('KVSTORE_USE_MOCK', "off") == "on"
        self._data_block_visitor = DataBlockVisitor(
            data_source, db_database, db_base_dir, db_addr,
            db_username, db_password, use_mock_kvstore)
        self._start_time = start_time
        self._end_time = end_time

    def _load_data(self):
        """Add every non-checkpointed block in [start_time, end_time]."""
        already_allocated = self._get_checkpoint()
        visitor = self._data_block_visitor
        block_reps = visitor.LoadDataBlockRepByTimeFrame(
            self._start_time, self._end_time)
        for rep_id, rep in block_reps.items():
            if rep_id in already_allocated:
                # Already handed out before the restore; skip it.
                continue
            logging.debug('load data block id %s path %s',
                          rep_id, rep.data_block_fpath)
            self._data_block_set.add(rep)
        logging.debug("FollowerTrainerMaster: get all block %s",
                      self._data_block_set)

    def _alloc_data_block(self, block_id=None):
        """Return the block with ``block_id`` from the set.

        Raises:
            Exception: ``block_id`` is falsy; the follower cannot choose a
                block on its own.
        """
        logging.debug("FollowerTrainerMaster is getting block %s", block_id)
        if block_id:
            return self._data_block_set.get(block_id)
        raise Exception('follower tm need block_id to alloc.')
コード例 #4
0
ファイル: leader_tm.py プロジェクト: zhangqixun/fedlearner
class LeaderTrainerMaster(TrainerMaster):
    """Leader-side trainer master with multi-epoch and shuffle support.

    Non-checkpointed data blocks of the configured time frame are enqueued
    ``epoch_num`` times (optionally reshuffled each round) and handed out one
    at a time. In online training the queue is refilled with newly appeared
    blocks whenever it runs empty.
    """

    def __init__(self, application_id, data_source,
                 start_time, end_time, online_training,
                 shuffle_data_block, epoch_num):
        super(LeaderTrainerMaster, self).__init__(application_id,
                                                  None, online_training)
        kvstore_use_mock = os.environ.get('KVSTORE_USE_MOCK', "off") == "on"
        self._data_block_queue = DataBlockQueue()
        self._data_block_visitor = DataBlockVisitor(
            data_source, db_database, db_base_dir, db_addr,
            db_username, db_password, kvstore_use_mock)
        self._start_time = start_time
        self._end_time = end_time
        self._epoch_num = epoch_num
        self._shuffle_data_block = shuffle_data_block
        # Ids already enqueued; online reloads skip them.
        self._visited_data_blocks = set()
        # Serializes reload / empty() / get() in _alloc_data_block.
        self._lock = threading.Lock()
        if online_training:
            assert self._epoch_num == 1 and not self._shuffle_data_block, \
                "epoch_num must be 1 and shuffle_data_block must be False " \
                "when online_training is set"
        assert self._epoch_num >= 1, \
            "epoch_num {} must be >= 1".format(self._epoch_num)

    def _load_data(self):
        """Enqueue new, non-checkpointed blocks for every epoch round.

        Blocks already in the checkpoint, or enqueued by an earlier call,
        are skipped. In online training the blocks are ordered by
        ``data_block_index``; otherwise they are optionally shuffled per
        round.
        """
        checkpoint = self._get_checkpoint()
        # pylint: disable=line-too-long
        data_block_reps = [
            dbr for dbr in self._data_block_visitor.LoadDataBlockRepByTimeFrame(
                self._start_time, self._end_time).values()
            if dbr.block_id not in checkpoint and
               dbr.block_id not in self._visited_data_blocks]

        self._visited_data_blocks.update(dbr.block_id for dbr in data_block_reps)

        if self._online_training:
            data_block_reps.sort(key=lambda x: x.data_block_index)
        for rnd in range(self._epoch_num):
            if self._shuffle_data_block:
                random.shuffle(data_block_reps)
            for dbr in data_block_reps:
                logging.debug('epoch round-%d: add data block id %s path %s',
                              rnd, dbr.block_id, dbr.data_block_fpath)
                self._data_block_queue.put(dbr)

    def _alloc_data_block(self, block_id=None):
        """Pop the next data block and record it as allocated.

        ``block_id`` is unused in the leader role. Returns None when no
        block is available, even after an online-training reload.
        """
        with self._lock:
            if self._data_block_queue.empty() and self._online_training:
                self._load_data()

            if self._data_block_queue.empty():
                return None

            data_blocks_resp = self._data_block_queue.get()
            with self._checkpoint_mutex:
                self._allocated_data_blockids.add(data_blocks_resp.block_id)
            return data_blocks_resp
コード例 #5
0
    def __init__(self,
                 role,
                 path,
                 files=None,
                 ext='.tfrecord',
                 start_time=None,
                 end_time=None,
                 from_data_source=False,
                 skip_datablock_checkpoint=False,
                 epoch_num=1):
        """Collect the data blocks to serve and set the initial status.

        Args:
            role: stored as ``self._role``.
            path: data source name when ``from_data_source`` is true,
                otherwise the root directory holding the block files.
            files: optional file names relative to ``path``. When None,
                ``path`` is walked recursively for files ending in ``ext``.
                A caller-supplied list is never modified.
            ext: extension filter used while walking (falsy disables it).
            start_time, end_time: time frame passed to
                ``LoadDataBlockRepByTimeFrame`` (data source mode only).
            from_data_source: load block reps from the data source instead
                of the filesystem.
            skip_datablock_checkpoint: start directly in RUNNING instead of
                INITIALING.
            epoch_num: file mode only. Each block is queued this many times.

        Raises:
            AssertionError: two files map to the same block id (basename
                without extension).
        """
        self._role = role
        self._path = path
        self._block_queue = []
        self._block_map = {}
        self._allocated_data_blockids = set()
        self._status = tm_pb.MasterStatus.CREATED
        if from_data_source:
            data_block_visitor = DataBlockVisitor(path, db_database,
                                                  db_base_dir, db_addr,
                                                  db_username, db_password)
            # NOTE(review): epoch_num is not applied on this branch; each
            # block is queued once. Confirm this is intended.
            # pylint: disable=line-too-long
            for block_id, block_item in data_block_visitor.LoadDataBlockRepByTimeFrame(
                    start_time, end_time).items():
                self._block_queue.append(block_item)
                self._block_map[block_id] = block_item
        else:
            if files is None:
                files = []
                for dirname, _, filenames in tf.io.gfile.walk(path):
                    for filename in filenames:
                        _, fileext = os.path.splitext(filename)
                        if ext and fileext != ext:
                            continue
                        subdirname = os.path.relpath(dirname, path)
                        files.append(os.path.join(subdirname, filename))
            # Sort a copy so a caller-supplied list (or tuple) is untouched.
            files = sorted(files)

            # Hack way for supporting multiple epochs: the queue holds every
            # block once per epoch, the map holds each block id once.
            blocks = []
            seen_files = {}
            for filename in files:
                block_id, _ = os.path.splitext(os.path.basename(filename))
                # Same basename in two sub-directories would make the id ->
                # block map point at only one of them.
                assert block_id not in seen_files, \
                    "Duplicate file names: %s and %s" % (
                        filename, seen_files[block_id])
                seen_files[block_id] = filename
                fullname = os.path.join(path, filename)
                blocks.append(DataBlockInfo(block_id, fullname))
            self._block_map = {block.block_id: block for block in blocks}
            for _ in range(epoch_num):
                self._block_queue.extend(blocks)

        self._status = tm_pb.MasterStatus.INITIALING
        if skip_datablock_checkpoint:
            self._status = tm_pb.MasterStatus.RUNNING
コード例 #6
0
    def __init__(self,
                 role,
                 path,
                 files=None,
                 ext='.tfrecord',
                 start_time=None,
                 end_time=None,
                 from_data_source=False):
        """Collect the data blocks this master will serve.

        Args:
            role: stored as ``self._role``.
            path: data source name when ``from_data_source`` is true,
                otherwise the root directory holding the block files.
            files: optional file names relative to ``path``. When None,
                ``path`` is walked recursively for files ending in ``ext``.
                A caller-supplied list is never modified.
            ext: extension filter used while walking (falsy disables it).
            start_time, end_time: time frame passed to
                ``LoadDataBlockRepByTimeFrame`` (data source mode only).
            from_data_source: load block reps from the data source instead
                of the filesystem.

        Raises:
            AssertionError: two files map to the same block id (basename
                without extension).
        """
        self._role = role
        self._path = path
        self._block_queue = []
        self._block_map = {}
        if from_data_source:
            data_block_visitor = DataBlockVisitor(path, db_database,
                                                  db_base_dir, db_addr,
                                                  db_username, db_password)
            # pylint: disable=line-too-long
            for block_id, block_item in data_block_visitor.LoadDataBlockRepByTimeFrame(
                    start_time, end_time).items():
                self._block_queue.append(block_item)
                self._block_map[block_id] = block_item
        else:
            if files is None:
                files = []
                for dirname, _, filenames in tf.io.gfile.walk(path):
                    for filename in filenames:
                        _, fileext = os.path.splitext(filename)
                        if ext and fileext != ext:
                            continue
                        subdirname = os.path.relpath(dirname, path)
                        files.append(os.path.join(subdirname, filename))
            # Sort a copy so a caller-supplied list (or tuple) is untouched.
            files = sorted(files)

            # block id -> original relative filename, for duplicate detection
            block_map = {}
            for filename in files:
                block_id, _ = os.path.splitext(os.path.basename(filename))
                assert block_id not in block_map, \
                    "Duplicate file names: %s and %s"%(
                        filename, block_map[block_id])
                block_map[block_id] = filename
                fullname = os.path.join(path, filename)
                block = DataBlockInfo(block_id, fullname)
                self._block_queue.append(block)
                self._block_map[block_id] = block
コード例 #7
0
    def __init__(self,
                 role,
                 path,
                 files=None,
                 ext='.tfrecord',
                 start_time=None,
                 end_time=None,
                 from_data_source=False):
        """Collect the data blocks this master will serve.

        Args:
            role: stored as ``self._role``.
            path: data source name when ``from_data_source`` is true,
                otherwise the directory holding the block files.
            files: optional file names relative to ``path``. When None, the
                regular files directly inside ``path`` (no recursion) whose
                extension equals ``ext`` are used. A caller-supplied list is
                never modified.
            ext: extension filter used when listing (falsy disables it).
            start_time, end_time: time frame passed to
                ``LoadDataBlockRepByTimeFrame`` (data source mode only).
            from_data_source: load block reps from the data source instead
                of the filesystem.
        """
        self._role = role
        self._path = path
        self._block_queue = []
        self._block_map = {}
        if from_data_source:
            data_block_visitor = DataBlockVisitor(path, ETCD_NAME,
                                                  ETCD_BASE_DIR, ETCD_ADDR)
            # pylint: disable=line-too-long
            for block_id, block_item in data_block_visitor.LoadDataBlockRepByTimeFrame(
                    start_time, end_time).items():
                self._block_queue.append(block_item)
                self._block_map[block_id] = block_item
        else:
            if files is None:
                files = []
                for filename in os.listdir(path):
                    fullname = os.path.join(path, filename)
                    if not os.path.isfile(fullname):
                        continue
                    _, fileext = os.path.splitext(filename)
                    if ext and fileext != ext:
                        continue
                    files.append(filename)
            # Sort a copy so a caller-supplied list (or tuple) is untouched.
            files = sorted(files)

            for filename in files:
                # Block id is the file name with its extension stripped.
                block_id, _ = os.path.splitext(filename)
                fullname = os.path.join(path, filename)
                block = DataBlockInfo(block_id, fullname)
                self._block_queue.append(block)
                self._block_map[block_id] = block
コード例 #8
0
ファイル: leader_tm.py プロジェクト: flyfoxCI/fedlearner
class LeaderTrainerMaster(object):
    """Standalone leader trainer master serving data blocks over gRPC.

    Status transitions (guarded by ``_status_mutex``):
    CREATED -> INITIALING in ``run``; INITIALING -> RUNNING after the
    checkpoint restore loads data; RUNNING -> FINISHED when a response
    reports STATUS_DATA_FINISHED. Any transfer attempted from an unexpected
    status moves the master to ERROR.
    """

    def __init__(self, application_id, data_source,
                 start_time, end_time, online_training,
                 shuffle_data_block, epoch_num):
        self._application_id = application_id
        self._online_training = online_training
        # Guards _allocated_data_blockids, which stays None until the first
        # checkpoint restore fills it.
        self._checkpoint_mutex = threading.Lock()
        self._allocated_data_blockids = None
        self._status_mutex = threading.Lock()
        self._status = tm_pb.MasterStatus.CREATED

        kvstore_use_mock = os.environ.get('KVSTORE_USE_MOCK', "off") == "on"
        self._data_block_queue = DataBlockQueue()
        self._data_block_visitor = DataBlockVisitor(
            data_source, db_database, db_base_dir, db_addr,
            db_username, db_password, kvstore_use_mock)
        self._start_time = start_time
        self._end_time = end_time
        self._epoch_num = epoch_num
        self._shuffle_data_block = shuffle_data_block
        # Ids already enqueued; online reloads skip them.
        self._visited_data_blocks = set()
        # Serializes reload / empty() / get() in _alloc_data_block.
        self._lock = threading.Lock()
        if online_training:
            assert self._epoch_num == 1 and not self._shuffle_data_block, \
                "epoch_num must be 1 and shuffle_data_block must be False " \
                "when online_training is set"
        assert self._epoch_num >= 1, \
            "epoch_num {} must be >= 1".format(self._epoch_num)

    def run(self, listen_port):
        """Start the gRPC service on ``listen_port`` and block until it stops."""
        self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        tm_grpc.add_TrainerMasterServiceServicer_to_server(
            TrainerMasterServer(self._data_block_response,
                                self._get_checkpoint_fn,
                                self._restore_checkpoint_fn), self._server)
        self._server.add_insecure_port('[::]:%d' % listen_port)
        self._server.start()
        logging.info('Trainer Master Server start on port[%d].', listen_port)
        self._transfer_status(tm_pb.MasterStatus.CREATED,
                              tm_pb.MasterStatus.INITIALING)
        self._server.wait_for_termination()

    def _transfer_status(self, frm, to, callback_fn=lambda *args: True):
        """Move status from ``frm`` to ``to`` and return ``callback_fn()``.

        If the current status is not ``frm``, logs a warning, sets the status
        to ERROR and returns False.
        """
        with self._status_mutex:
            if self._status == frm:
                self._status = to
                return callback_fn()
            logging.warning("%s invalid status transfer, from %d to %d, "
                            "while status is %d", self.__class__.__name__,
                            frm, to, self._status)
            self._status = tm_pb.MasterStatus.ERROR
        return False

    def _check_status(self, callback_fn):
        """Return ``callback_fn(status)`` evaluated under the status mutex."""
        with self._status_mutex:
            return callback_fn(self._status)
        # Not reachable at runtime; presumably keeps linters' return-path
        # checks quiet.
        raise ValueError("unreachable")

    def _get_checkpoint_fn(self, request):
        """Answer a checkpoint query with the ids allocated so far."""
        assert request.application_id == self._application_id, \
            "Application id not matched: %s vs %s" % (
                request.application_id, self._application_id)
        response = tm_pb.GetDataBlockCheckpointResponse()
        ckpt_not_ready_fn = lambda status: status not in \
            (tm_pb.MasterStatus.RUNNING, tm_pb.MasterStatus.FINISHED)
        if self._check_status(ckpt_not_ready_fn):
            response.status.code = common_pb.STATUS_WAIT_FOR_SYNCING_CHECKPOINT
            response.status.error_message = \
                "master is not ready for querying data checkpoint"
            return response
        # Snapshot under the mutex: _alloc_data_block adds to this set from
        # other request threads.
        with self._checkpoint_mutex:
            block_ids = list(self._allocated_data_blockids)
        response.status.code = common_pb.STATUS_SUCCESS
        response.status.error_message = 'success'
        response.block_ids.extend(block_ids)
        return response

    def _restore_checkpoint_fn(self, request):
        """Adopt the caller's allocated ids, load data, then enter RUNNING."""
        assert request.application_id == self._application_id, \
            "Application id not matched: %s vs %s" % (
                request.application_id, self._application_id)
        response = tm_pb.RestoreDataBlockCheckpointResponse()
        no_need_restore_fn = lambda status: status in (
            tm_pb.MasterStatus.RUNNING,
            tm_pb.MasterStatus.FINISHED,
            tm_pb.MasterStatus.ERROR)
        if self._check_status(no_need_restore_fn):
            logging.info("No need to restore %s", self.__class__.__name__)
            response.status.code = common_pb.STATUS_SUCCESS
            response.status.error_message = "success"
            return response

        # In case of race, load data before transferring the state to RUNNING,
        # and after filling the data checkpoint.
        with self._checkpoint_mutex:
            self._allocated_data_blockids = set(request.block_ids)
        self._load_data()

        trans_ok = self._transfer_status(tm_pb.MasterStatus.INITIALING,
                                         tm_pb.MasterStatus.RUNNING)
        if not trans_ok:
            response.status.code = common_pb.STATUS_WAIT_FOR_SYNCING_CHECKPOINT
            response.status.error_message = \
                "must sync data checkpoint before alloc"
            return response

        response.status.code = common_pb.STATUS_SUCCESS
        response.status.error_message = "success"
        return response

    def _get_checkpoint(self):
        """Return the live set of allocated block ids (None before restore)."""
        return self._allocated_data_blockids

    def _data_block_response(self, request):
        """Allocate a data block for a worker request, if the master is ready."""
        response = tm_pb.DataBlockResponse()

        def status_check_fn(status):
            # Returns True when RUNNING, otherwise a ready-made response.
            status_response = tm_pb.DataBlockResponse()
            if status in (tm_pb.MasterStatus.FINISHED,
                          tm_pb.MasterStatus.ERROR):
                status_response.status.code = common_pb.STATUS_DATA_FINISHED
                status_response.status.error_message = 'datablock finished'
                return status_response
            if status != tm_pb.MasterStatus.RUNNING:
                status_response.status.code = \
                    common_pb.STATUS_WAIT_FOR_SYNCING_CHECKPOINT
                status_response.status.error_message = \
                    "must sync data checkpoint before alloc"
                return status_response
            # only if status is RUNNING
            return True

        ready = self._check_status(status_check_fn)
        if ready is not True:
            return ready
        data_block = self._alloc_data_block(block_id=request.block_id)
        if data_block:
            logging.debug("%s allocated worker_%d with block id %s",
                          self.__class__.__name__,
                          request.worker_rank,
                          data_block.block_id)
            response.status.code = common_pb.STATUS_SUCCESS
            response.status.error_message = 'success'
            response.data_block_info.data_path = \
                str(data_block.data_block_fpath)
            response.data_block_info.meta_path = ''
            response.data_block_info.block_id = str(data_block.block_id)
        elif self._online_training:
            logging.debug("%s allocated worker_%d with empty data block. "
                          "wait for new data block since online training",
                          self.__class__.__name__, request.worker_rank)
            response.status.code = common_pb.STATUS_NO_MORE_DATA
            response.status.error_message = 'please wait for datablock ready'
        else:
            logging.debug("%s allocated worker_%d with empty data block. "
                          "exit running since batch training",
                          self.__class__.__name__, request.worker_rank)
            response.status.code = common_pb.STATUS_DATA_FINISHED
            response.status.error_message = 'datablock finished'
        if response.status.code == common_pb.STATUS_DATA_FINISHED:
            self._transfer_status(tm_pb.MasterStatus.RUNNING,
                                  tm_pb.MasterStatus.FINISHED)
        return response

    def _load_data(self):
        """Enqueue new, non-checkpointed blocks for every epoch round.

        Blocks already in the checkpoint, or enqueued by an earlier call,
        are skipped. In online training the blocks are ordered by
        ``data_block_index``; otherwise they are optionally shuffled per
        round.
        """
        checkpoint = self._get_checkpoint()
        # pylint: disable=line-too-long
        logging.info("load_data, checkpoint: %s", checkpoint)
        data_block_reps = [
            dbr for dbr in self._data_block_visitor.LoadDataBlockRepByTimeFrame(
                self._start_time, self._end_time).values()
            if dbr.block_id not in checkpoint and
               dbr.block_id not in self._visited_data_blocks]

        self._visited_data_blocks.update(dbr.block_id for dbr in data_block_reps)

        if self._online_training:
            data_block_reps.sort(key=lambda x: x.data_block_index)
        for rnd in range(self._epoch_num):
            if self._shuffle_data_block:
                random.shuffle(data_block_reps)
            for dbr in data_block_reps:
                logging.debug('epoch round-%d: add data block id %s path %s',
                              rnd, dbr.block_id, dbr.data_block_fpath)
                self._data_block_queue.put(dbr)

    def _alloc_data_block(self, block_id=None):
        """Pop the next data block and record it as allocated.

        ``block_id`` is unused in the leader role. Returns None when no
        block is available, even after an online-training reload.
        """
        with self._lock:
            if self._data_block_queue.empty() and self._online_training:
                logging.info("Load data when queue empty and online training")
                self._load_data()

            if self._data_block_queue.empty():
                logging.info("Allocate when data_block_queue is empty")
                return None

            data_blocks_resp = self._data_block_queue.get()
            with self._checkpoint_mutex:
                self._allocated_data_blockids.add(data_blocks_resp.block_id)
            return data_blocks_resp