예제 #1
0
class FollowerTrainerMaster(TrainerMaster):
    """Trainer master variant that resolves data blocks by an explicit id."""

    def __init__(self, application_id, data_source, start_time, end_time,
                 online_training):
        """Initialise the base trainer master and the data block visitor.

        Args:
            application_id: forwarded to the TrainerMaster constructor.
            data_source: passed to DataBlockVisitor as its first argument.
            start_time: stored on the instance as ``_start_time``.
            end_time: stored on the instance as ``_end_time``.
            online_training: forwarded to the TrainerMaster constructor.
        """
        # The second positional argument of the base constructor is None.
        super(FollowerTrainerMaster, self).__init__(
            application_id, None, online_training)

        # The mock kvstore is used only when KVSTORE_USE_MOCK is exactly "on".
        mock_flag = os.environ.get('KVSTORE_USE_MOCK', "off")
        use_mock_kvstore = mock_flag == "on"

        # NOTE(review): the db_* names are not defined inside this class;
        # presumably they come from the enclosing module -- confirm.
        self._data_block_visitor = DataBlockVisitor(
            data_source, db_database, db_base_dir, db_addr, db_username,
            db_password, use_mock_kvstore)

        self._start_time, self._end_time = start_time, end_time

    def _load_data(self):
        """Do nothing; this class has no data to load up front."""
        return None

    def _alloc_data_block(self, block_id=None):
        """Look up the data block identified by ``block_id``.

        Args:
            block_id: identifier of the requested block; must be truthy.

        Returns:
            Whatever ``LoadDataBlockRepByBlockId`` returns for ``block_id``.

        Raises:
            Exception: if ``block_id`` is missing or empty.
        """
        logging.debug("FollowerTrainerMaster is getting block %s", block_id)
        if block_id:
            visitor = self._data_block_visitor
            return visitor.LoadDataBlockRepByBlockId(block_id)
        raise Exception('follower tm need block_id to alloc.')
예제 #2
0
class FollowerTrainerMaster(object):
    """gRPC trainer master that serves data blocks looked up by block id."""

    def __init__(self, application_id, data_source, online_training):
        """Store the configuration and build the data block visitor.

        Args:
            application_id: stored on the instance as ``_application_id``.
            data_source: passed to DataBlockVisitor as its first argument.
            online_training: when truthy, an empty lookup is answered with
                STATUS_NO_MORE_DATA instead of STATUS_DATA_FINISHED.
        """
        self._application_id = application_id
        self._online_training = online_training

        # The mock kvstore is used only when KVSTORE_USE_MOCK is exactly "on".
        mock_flag = os.environ.get('KVSTORE_USE_MOCK', "off")
        use_mock_kvstore = mock_flag == "on"

        # NOTE(review): kvstore_type is not defined inside this class;
        # presumably it comes from the enclosing module -- confirm.
        self._data_block_visitor = DataBlockVisitor(
            data_source, kvstore_type, use_mock_kvstore)

    def run(self, listen_port):
        """Start the gRPC server on ``listen_port`` and block until it stops.

        Args:
            listen_port: integer port; bound insecurely on ``[::]``.
        """
        executor = futures.ThreadPoolExecutor(max_workers=10)
        self._server = grpc.server(executor)

        servicer = TrainerMasterServer(
            self._data_block_response,
            self._get_checkpoint_fn,
            self._restore_checkpoint_fn)
        tm_grpc.add_TrainerMasterServiceServicer_to_server(
            servicer, self._server)

        address = '[::]:%d' % listen_port
        self._server.add_insecure_port(address)
        self._server.start()
        logging.info('Trainer Master Server start on port[%d].', listen_port)
        self._server.wait_for_termination()

    @staticmethod
    def _fill_status(response, code, message):
        """Write ``code`` and ``message`` into ``response.status``."""
        response.status.code = code
        response.status.error_message = message

    def _get_checkpoint_fn(self, request):
        """Answer a get-checkpoint request with a bare success response.

        The request is not inspected.
        """
        reply = tm_pb.GetDataBlockCheckpointResponse()
        self._fill_status(reply, common_pb.STATUS_SUCCESS, 'success')
        logging.info("Follower _get_checkpoint_fn, do nothing")
        return reply

    def _restore_checkpoint_fn(self, request):
        """Answer a restore-checkpoint request with a bare success response.

        The request is not inspected.
        """
        reply = tm_pb.RestoreDataBlockCheckpointResponse()
        self._fill_status(reply, common_pb.STATUS_SUCCESS, "success")
        logging.info("Follower _restore_checkpoint_fn, do nothing")
        return reply

    def _alloc_data_block(self, block_id=None):
        """Look up the data block identified by ``block_id``.

        Args:
            block_id: identifier of the requested block; must be truthy.

        Returns:
            Whatever ``LoadDataBlockRepByBlockId`` returns for ``block_id``.

        Raises:
            Exception: if ``block_id`` is missing or empty.
        """
        logging.info("FollowerTrainerMaster is getting block %s", block_id)
        if block_id:
            visitor = self._data_block_visitor
            return visitor.LoadDataBlockRepByBlockId(block_id)
        raise Exception('follower tm need block_id to alloc.')

    def _data_block_response(self, request):
        """Build the DataBlockResponse for a worker's block request.

        Args:
            request: object exposing ``block_id`` and ``worker_rank``.

        Returns:
            A response with STATUS_SUCCESS and the block's path and id when
            the lookup yields a truthy block; otherwise STATUS_NO_MORE_DATA
            when online training is enabled, else STATUS_DATA_FINISHED.
        """
        response = tm_pb.DataBlockResponse()
        block = self._alloc_data_block(block_id=request.block_id)

        if block:
            logging.info("%s allocated worker_%d with block id %s",
                         self.__class__.__name__, request.worker_rank,
                         block.block_id)
            self._fill_status(response, common_pb.STATUS_SUCCESS, 'success')
            block_info = response.data_block_info
            block_info.data_path = str(block.data_block_fpath)
            block_info.meta_path = ''
            block_info.block_id = str(block.block_id)
            return response

        if self._online_training:
            # Online mode: tell the worker to retry later.
            logging.info(
                "%s allocated worker_%d with empty data block. "
                "wait for new data block since online traning",
                self.__class__.__name__, request.worker_rank)
            self._fill_status(response, common_pb.STATUS_NO_MORE_DATA,
                              'please wait for datablock ready')
            return response

        # Batch mode: no block means the data is exhausted.
        logging.info(
            "%s allocated worker_%d with empty data block. "
            "exit running since since batch traning",
            self.__class__.__name__, request.worker_rank)
        self._fill_status(response, common_pb.STATUS_DATA_FINISHED,
                          'datablock finished')
        return response