Example #1
class LeaderTrainerMaster(TrainerMaster):
    def __init__(self, application_id, data_source, start_time, end_time,
                 online_training):
        super(LeaderTrainerMaster, self).__init__(application_id, None,
                                                  online_training)
        self._data_block_queue = DataBlockQueue()
        self._data_block_visitor = DataBlockVisitor(data_source, ETCD_NAME,
                                                    ETCD_BASE_DIR, ETCD_ADDR)
        self._start_time = start_time
        self._end_time = end_time

    def _load_data(self):
        checkpoint = self._get_checkpoint()
        # pylint: disable=line-too-long
        for block_id, block_item in self._data_block_visitor.LoadDataBlockRepByTimeFrame(
                self._start_time, self._end_time).items():
            if block_id not in checkpoint:
                logging.debug('load data block id %s path %s', block_id,
                              block_item.data_block_fpath)
                self._data_block_queue.put(block_item)

    def _alloc_data_block(self, block_id=None):
        # block_id is unused in leader role
        data_blocks_resp = None
        if not self._data_block_queue.empty():
            data_blocks_resp = self._data_block_queue.get()
        return data_blocks_resp
Example #2
class LeaderTrainerMaster(TrainerMaster):
    def __init__(self, application_id, data_source,
                 start_time, end_time, online_training,
                 shuffle_data_block, epoch_num):
        super(LeaderTrainerMaster, self).__init__(application_id,
                                                  None, online_training)
        kvstore_use_mock = os.environ.get('KVSTORE_USE_MOCK', "off") == "on"
        self._data_block_queue = DataBlockQueue()
        self._data_block_visitor = DataBlockVisitor(
            data_source, db_database, db_base_dir, db_addr,
            db_username, db_password, kvstore_use_mock)
        self._start_time = start_time
        self._end_time = end_time
        self._epoch_num = epoch_num
        self._shuffle_data_block = shuffle_data_block
        self._visited_data_blocks = set()
        self._lock = threading.Lock()
        if online_training:
            assert self._epoch_num == 1 and not self._shuffle_data_block, \
                "epoch_num must be 1 and shuffle_data_block must be False " \
                "when online_training is set"
        assert self._epoch_num >= 1, \
                "epoch_num {} must be >= 1".format(self._epoch_num)

    def _load_data(self):
        checkpoint = self._get_checkpoint()
        # pylint: disable=line-too-long
        data_block_reps = [
            dbr for dbr in self._data_block_visitor.LoadDataBlockRepByTimeFrame(
                self._start_time, self._end_time).values()
            if dbr.block_id not in checkpoint and
               dbr.block_id not in self._visited_data_blocks]

        self._visited_data_blocks.update([i.block_id for i in data_block_reps])

        if self._online_training:
            data_block_reps.sort(key=lambda x: x.data_block_index)
        for rnd in range(self._epoch_num):
            if self._shuffle_data_block:
                random.shuffle(data_block_reps)
            for dbr in data_block_reps:
                logging.debug('epoch round-%d: add data block id %s path %s',
                              rnd, dbr.block_id, dbr.data_block_fpath)
                self._data_block_queue.put(dbr)

    def _alloc_data_block(self, block_id=None):
        # block_id is unused in leader role
        with self._lock:
            if self._data_block_queue.empty() and self._online_training:
                self._load_data()

            if self._data_block_queue.empty():
                return None

            data_blocks_resp = self._data_block_queue.get()
            with self._checkpoint_mutex:
                self._allocated_data_blockids.add(data_blocks_resp.block_id)
            return data_blocks_resp
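The epoch loop in _load_data above is easy to misread: the same list is reshuffled in place each round, so every epoch enqueues a fresh permutation, and all epoch_num passes land back to back in a single queue. A minimal standalone sketch of that loop, using hypothetical block ids in place of real DataBlock representations:

import random

blocks = ['blk-0', 'blk-1', 'blk-2']   # hypothetical block ids
epoch_num, shuffle_data_block = 2, True
queue = []
for rnd in range(epoch_num):
    if shuffle_data_block:
        random.shuffle(blocks)         # a fresh in-place permutation per epoch
    queue.extend(blocks)               # epochs are enqueued back to back
print(queue)  # e.g. ['blk-2', 'blk-0', 'blk-1', 'blk-1', 'blk-0', 'blk-2']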
Example #3
    def test_trainer_master(self):
        db1 = DataBlock('1', 'data_path1', 'meta_path1')
        db2 = DataBlock('2', 'data_path2', 'meta_path2')
        db_queue = DataBlockQueue()
        db_queue.put(db1)
        db_queue.put(db2)

        self.assertEqual(db_queue.get(), db1)
        self.assertEqual(db_queue.get(), db2)
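The test above relies only on FIFO put/get semantics. A minimal sketch of a queue with that contract, assuming the real DataBlockQueue is a thin wrapper over the standard library (the actual fedlearner implementation may differ):

import queue

class DataBlockQueueSketch(object):
    """FIFO sketch matching the put/get/empty usage in these examples."""
    def __init__(self):
        self._queue = queue.Queue()  # thread-safe FIFO

    def put(self, data_block):
        self._queue.put(data_block)

    def get(self):
        return self._queue.get()

    def empty(self):
        return self._queue.empty()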
Example #4
class LeaderTrainerMaster(TrainerMaster):
    def __init__(self, application_id, data_source, start_time, end_time,
                 online_training, shuffle_data_block, epoch_num):
        super(LeaderTrainerMaster, self).__init__(application_id, None,
                                                  online_training)
        self._data_block_queue = DataBlockQueue()
        self._data_block_visitor = DataBlockVisitor(data_source, db_database,
                                                    db_base_dir, db_addr,
                                                    db_username, db_password)
        self._start_time = start_time
        self._end_time = end_time
        self._epoch_num = epoch_num
        self._shuffle_data_block = shuffle_data_block
        if online_training:
            self._epoch_num = 1
            self._shuffle_data_block = False
            logging.warning("online training ignores the epoch_num "
                            "and shuffle_data_block args")
        assert self._epoch_num >= 1, \
                "epoch_num {} must be >= 1".format(self._epoch_num)

    def _load_data(self):
        checkpoint = self._get_checkpoint()
        # pylint: disable=line-too-long
        visitor = self._data_block_visitor
        data_block_reps = [
            dbr for dbr in visitor.LoadDataBlockRepByTimeFrame(
                self._start_time, self._end_time).values()
            if dbr.block_id not in checkpoint
        ]
        if self._online_training:
            data_block_reps.sort(key=lambda dbr: dbr.data_block_index)
        for rnd in range(self._epoch_num):
            if self._shuffle_data_block:
                random.shuffle(data_block_reps)
            for dbr in data_block_reps:
                logging.debug('epoch round-%d: add data block id %s path %s',
                              rnd, dbr.block_id, dbr.data_block_fpath)
                self._data_block_queue.put(dbr)

    def _alloc_data_block(self, block_id=None):
        # block_id is unused in leader role
        data_blocks_resp = None
        if not self._data_block_queue.empty():
            data_blocks_resp = self._data_block_queue.get()
        return data_blocks_resp
Example #5
import logging
import os
import random
import threading
from concurrent import futures

import grpc

# Project-specific names used below (tm_pb, tm_grpc, common_pb,
# DataBlockQueue, DataBlockVisitor, TrainerMasterServer, and the db_*
# connection constants) are assumed to come from the surrounding
# fedlearner module.

class LeaderTrainerMaster(object):
    def __init__(self, application_id, data_source,
                 start_time, end_time, online_training,
                 shuffle_data_block, epoch_num):
        self._application_id = application_id
        self._online_training = online_training
        self._checkpoint_mutex = threading.Lock()
        self._allocated_data_blockids = None
        self._status_mutex = threading.Lock()
        self._status = tm_pb.MasterStatus.CREATED

        kvstore_use_mock = os.environ.get('KVSTORE_USE_MOCK', "off") == "on"
        self._data_block_queue = DataBlockQueue()
        self._data_block_visitor = DataBlockVisitor(
            data_source, db_database, db_base_dir, db_addr,
            db_username, db_password, kvstore_use_mock)
        self._start_time = start_time
        self._end_time = end_time
        self._epoch_num = epoch_num
        self._shuffle_data_block = shuffle_data_block
        self._visited_data_blocks = set()
        self._lock = threading.Lock()
        if online_training:
            assert self._epoch_num == 1 and not self._shuffle_data_block, \
                "epoch_num must be 1 and shuffle_data_block must be False " \
                "when online_training is set"
        assert self._epoch_num >= 1, \
                "epoch_num {} must be >= 1".format(self._epoch_num)

    def run(self, listen_port):
        self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        tm_grpc.add_TrainerMasterServiceServicer_to_server(
            TrainerMasterServer(self._data_block_response,
                                self._get_checkpoint_fn,
                                self._restore_checkpoint_fn), self._server)
        self._server.add_insecure_port('[::]:%d' % listen_port)
        self._server.start()
        logging.info('Trainer Master Server start on port[%d].', listen_port)
        self._transfer_status(tm_pb.MasterStatus.CREATED,
                              tm_pb.MasterStatus.INITIALING)
        self._server.wait_for_termination()

    def _transfer_status(self, frm, to, callback_fn=lambda *args: True):
        with self._status_mutex:
            if self._status == frm:
                self._status = to
                return callback_fn()
            logging.warning("%s invalid status transfer, from %d to %d, "
                          "while status is %d", self.__class__.__name__,
                          frm, to, self._status)
            self._status = tm_pb.MasterStatus.ERROR
        return False

    def _check_status(self, callback_fn):
        with self._status_mutex:
            return callback_fn(self._status)

    def _get_checkpoint_fn(self, request):
        assert request.application_id == self._application_id, \
                "Application id not matched"
        response = tm_pb.GetDataBlockCheckpointResponse()
        ckpt_not_ready_fn = lambda status: status not in \
                (tm_pb.MasterStatus.RUNNING, tm_pb.MasterStatus.FINISHED)
        if self._check_status(ckpt_not_ready_fn):
            response.status.code = common_pb.STATUS_WAIT_FOR_SYNCING_CHECKPOINT
            response.status.error_message = \
                    "master is not ready for querying data checkpoint"
            return response
        response.status.code = common_pb.STATUS_SUCCESS
        response.status.error_message = 'success'
        response.block_ids.extend(list(self._allocated_data_blockids))
        return response

    def _restore_checkpoint_fn(self, request):
        assert request.application_id == self._application_id, \
                "Application id not matched: %s vs %s" % (
                    request.application_id, self._application_id)
        response = tm_pb.RestoreDataBlockCheckpointResponse()
        no_need_restore_fn = lambda status: status in (
            tm_pb.MasterStatus.RUNNING,
            tm_pb.MasterStatus.FINISHED,
            tm_pb.MasterStatus.ERROR)
        if self._check_status(no_need_restore_fn):
            logging.info("No need to restore %s", self.__class__.__name__)
            response.status.code = common_pb.STATUS_SUCCESS
            response.status.error_message = "success"
            return response

        # In case of a race, load data after filling the data checkpoint and
        #   before transferring state to RUNNING.
        with self._checkpoint_mutex:
            self._allocated_data_blockids = set(request.block_ids)
        self._load_data()

        trans_ok = self._transfer_status(tm_pb.MasterStatus.INITIALING,
                                         tm_pb.MasterStatus.RUNNING)
        if not trans_ok:
            response.status.code = common_pb.STATUS_WAIT_FOR_SYNCING_CHECKPOINT
            response.status.error_message = \
                    "must sync data checkpoint before alloc"
            return response

        response.status.code = common_pb.STATUS_SUCCESS
        response.status.error_message = "success"
        return response

    def _get_checkpoint(self):
        return self._allocated_data_blockids

    def _data_block_response(self, request):
        response = tm_pb.DataBlockResponse()
        def status_check_fn(status):
            # Deliberately shadows the outer `response`: error replies built
            # here are returned straight through _check_status.
            response = tm_pb.DataBlockResponse()
            if status in (tm_pb.MasterStatus.FINISHED,
                          tm_pb.MasterStatus.ERROR):
                response.status.code = common_pb.STATUS_DATA_FINISHED
                response.status.error_message = 'datablock finished'
                return response
            if status != tm_pb.MasterStatus.RUNNING:
                response.status.code = \
                       common_pb.STATUS_WAIT_FOR_SYNCING_CHECKPOINT
                response.status.error_message = \
                        "must sync data checkpoint before alloc"
                return response
            # only reached when status is RUNNING
            return True

        ready = self._check_status(status_check_fn)
        if ready is not True:
            return ready
        data_block = self._alloc_data_block(block_id=request.block_id)
        if data_block:
            logging.debug("%s allocated worker_%d with block id %s",
                          self.__class__.__name__,
                          request.worker_rank,
                          data_block.block_id)
            response.status.code = common_pb.STATUS_SUCCESS
            response.status.error_message = 'success'
            response.data_block_info.data_path = \
                str(data_block.data_block_fpath)
            response.data_block_info.meta_path = ''
            response.data_block_info.block_id = str(data_block.block_id)
        elif self._online_training:
            logging.debug("%s allocated worker_%d with empty data block. "
                          "wait for new data block since online training",
                          self.__class__.__name__, request.worker_rank)
            response.status.code = common_pb.STATUS_NO_MORE_DATA
            response.status.error_message = 'please wait for datablock ready'
        else:
            logging.debug("%s allocated worker_%d with empty data block. "
                          "exit running since batch training",
                          self.__class__.__name__, request.worker_rank)
            response.status.code = common_pb.STATUS_DATA_FINISHED
            response.status.error_message = 'datablock finished'
            response.status.code = common_pb.STATUS_DATA_FINISHED
            response.status.error_message = 'datablock finished'
        if response.status.code == common_pb.STATUS_DATA_FINISHED:
            self._transfer_status(tm_pb.MasterStatus.RUNNING,
                                  tm_pb.MasterStatus.FINISHED)
        return response

    def _load_data(self):
        checkpoint = self._get_checkpoint()
        # pylint: disable=line-too-long
        logging.info("load_data, checkpoint: %s", checkpoint)
        data_block_reps = [
            dbr for dbr in self._data_block_visitor.LoadDataBlockRepByTimeFrame(
                self._start_time, self._end_time).values()
            if dbr.block_id not in checkpoint and
               dbr.block_id not in self._visited_data_blocks]

        self._visited_data_blocks.update([i.block_id for i in data_block_reps])

        if self._online_training:
            data_block_reps.sort(key=lambda x: x.data_block_index)
        for rnd in range(self._epoch_num):
            if self._shuffle_data_block:
                random.shuffle(data_block_reps)
            for dbr in data_block_reps:
                logging.debug('epoch round-%d: add data block id %s path %s',
                              rnd, dbr.block_id, dbr.data_block_fpath)
                self._data_block_queue.put(dbr)

    def _alloc_data_block(self, block_id=None):
        # block_id is unused in leader role
        with self._lock:
            if self._data_block_queue.empty() and self._online_training:
                logging.info("Load data when queue empty and online training")
                self._load_data()

            if self._data_block_queue.empty():
                logging.info("Allocate when data_block_queue is empty")
                return None

            data_blocks_resp = self._data_block_queue.get()
            with self._checkpoint_mutex:
                self._allocated_data_blockids.add(data_blocks_resp.block_id)
            return data_blocks_resp
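Putting Example #5 together, a hedged usage sketch; every argument value below is an illustrative assumption, not fedlearner's documented entrypoint:

master = LeaderTrainerMaster(
    application_id='app-demo',        # hypothetical application id
    data_source='demo_data_source',   # hypothetical data source name
    start_time=0,
    end_time=2 ** 63 - 1,             # widest possible time frame
    online_training=False,
    shuffle_data_block=True,
    epoch_num=2)
master.run(listen_port=50051)         # blocks in wait_for_termination()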