예제 #1
0
class DataBlockLoader(object):
    """Hand out data blocks to one worker, keeping a peer in step via a bridge.

    The loader compares ``role`` against three values:

    * ``'leader'``: registers ``_data_block_handler`` on the bridge and
      serves the blocks the handler puts on an internal queue.
    * ``'follower'``: pulls blocks from the local trainer master and
      announces each one through ``bridge.load_data_block``.
    * ``'local'``: pulls blocks from the local trainer master without
      touching the bridge.
    """

    def __init__(self,
                 role,
                 bridge,
                 data_path,
                 ext,
                 worker_rank=0,
                 num_workers=1,
                 output_path=None):
        """Create the loader and, for bridge roles, exchange a 'barrier' message.

        Args:
            role: ``'leader'``, ``'follower'`` or ``'local'``.
            bridge: Communication object; it is used for
                ``register_data_block_handler``, ``start``, ``new_iter_id``,
                ``current_iter_id``, ``send``, ``receive``, ``commit`` and
                ``load_data_block``.
            data_path: Directory holding data blocks, or the path of a single
                file. When falsy, no trainer master client is created.
            ext: Passed through to ``LocalTrainerMasterClient`` as ``ext``.
            worker_rank: Position of this worker among ``num_workers``.
            num_workers: Number of workers sharing the block stream.
            output_path: Optional directory; a block is skipped when
                ``<output_path>/<block_id>.output`` already exists.
        """
        self._role = role
        self._bridge = bridge
        self._num_workers = num_workers
        self._worker_rank = worker_rank
        self._output_path = output_path

        # The trainer-master role is the opposite of the bridge role: a bridge
        # 'leader' becomes a 'follower' here; every other role becomes 'leader'.
        self._tm_role = 'follower' if role == 'leader' else 'leader'

        if data_path:
            files = None
            if not tf.io.gfile.isdir(data_path):
                # A single file was given: restrict the client to that file
                # inside its parent directory.
                files = [os.path.basename(data_path)]
                data_path = os.path.dirname(data_path)
            self._trainer_master = LocalTrainerMasterClient(self._tm_role,
                                                            data_path,
                                                            files=files,
                                                            ext=ext)
        else:
            self._trainer_master = None

        # Number of blocks exchanged with the peer so far.
        self._count = 0
        if self._role == 'leader':
            self._block_queue = queue.Queue()
            self._bridge.register_data_block_handler(self._data_block_handler)
            self._bridge.start(self._bridge.new_iter_id())
            self._bridge.send(self._bridge.current_iter_id, 'barrier',
                              np.asarray([1]))
            self._bridge.commit()
        elif self._role == 'follower':
            self._bridge.start(self._bridge.new_iter_id())
            self._bridge.receive(self._bridge.current_iter_id, 'barrier')
            self._bridge.commit()

    def _data_block_handler(self, msg):
        """Handle a data-block message delivered through the bridge.

        Args:
            msg: Message exposing ``block_id`` and ``count``; ``count`` must
                equal the number of blocks accepted so far.

        Returns:
            True when the message was accepted and a block (or ``None`` for an
            empty ``block_id``) was queued; False when a local trainer master
            exists but has no block with the announced id.
            NOTE(review): presumably the sender reacts to False by offering
            another block, as ``get_next_block`` does when
            ``load_data_block`` returns a falsy value; confirm against the
            bridge implementation.
        """
        logging.debug('DataBlock: recv "%s" at %d', msg.block_id, msg.count)
        assert self._count == msg.count
        if not msg.block_id:
            # An empty id is queued as None; get_next_block sends '' when its
            # trainer master has no more blocks.
            block = None
        elif self._trainer_master is not None:
            block = self._trainer_master.request_data_block(msg.block_id)
            if block is None:
                # Reject only when the lookup found nothing. Returning here
                # unconditionally would drop every announced block.
                return False
        else:
            # No local data: queue a block info that carries only the id.
            block = DataBlockInfo(msg.block_id, None)
        self._count += 1
        self._block_queue.put(block)
        return True

    def _request_data_block(self):
        """Fetch this worker's next block from the trainer master.

        Each round draws ``num_workers`` blocks and keeps the one at position
        ``worker_rank``, discarding the others. Rounds repeat while the kept
        block already has a ``<block_id>.output`` file under ``output_path``.

        Returns:
            The selected block, or ``None`` when the trainer master returned
            ``None`` for this worker's slot.
        """
        while True:
            for _ in range(self._worker_rank):
                self._trainer_master.request_data_block()
            block = self._trainer_master.request_data_block()
            for _ in range(self._num_workers - self._worker_rank - 1):
                self._trainer_master.request_data_block()
            if block is None or self._output_path is None or \
                    not tf.io.gfile.exists(os.path.join(
                        self._output_path, block.block_id) + '.output'):
                break
        return block

    def get_next_block(self):
        """Return the next data block for this worker.

        * ``'local'`` role: the block comes straight from the trainer master.
        * Trainer-master role ``'leader'``: a block is requested locally and
          announced through ``bridge.load_data_block``; when that call returns
          a falsy value the next block is tried. Once the trainer master
          returns ``None``, an empty id is announced instead.
        * Otherwise: blocks until ``_data_block_handler`` queues a block.

        Returns:
            The next block, or ``None`` when there are no more blocks.
        """
        if self._role == 'local':
            return self._request_data_block()

        if self._tm_role == 'leader':
            while True:
                block = self._request_data_block()
                if block is not None:
                    if not self._bridge.load_data_block(
                            self._count, block.block_id):
                        # The announcement was not accepted; try another block.
                        continue
                else:
                    self._bridge.load_data_block(self._count, '')
                break
            self._count += 1
        else:
            block = self._block_queue.get()
        return block
예제 #2
0
class DataBlockLoader(object):
    """Serve data blocks to one worker, coordinating with a peer over a bridge.

    ``role`` is compared against ``'leader'``, ``'follower'`` and ``'local'``.
    A ``'leader'`` receives block announcements through a handler registered
    on the bridge; a ``'follower'`` draws blocks from its local trainer master
    and announces them to the peer; ``'local'`` skips the bridge entirely.
    """

    def __init__(self, role, bridge, data_path, ext,
                 worker_rank=0, num_workers=1):
        """Store the configuration and exchange a 'barrier' message.

        Args:
            role: ``'leader'``, ``'follower'`` or ``'local'``.
            bridge: Communication object shared with the peer.
            data_path: Directory of data blocks or path of a single file;
                when falsy, no trainer master client is created.
            ext: Passed through to ``LocalTrainerMasterClient`` as ``ext``.
            worker_rank: Position of this worker among ``num_workers``.
            num_workers: Number of workers sharing the block stream.
        """
        self._role = role
        self._bridge = bridge
        self._worker_rank = worker_rank
        self._num_workers = num_workers
        # Number of blocks exchanged with the peer so far.
        self._count = 0

        # The trainer-master role is the opposite of the bridge role.
        if role == 'leader':
            self._tm_role = 'follower'
        else:
            self._tm_role = 'leader'

        if not data_path:
            self._trainer_master = None
        else:
            only_files = None
            if not tf.io.gfile.isdir(data_path):
                # Single file: serve just that file from its parent directory.
                data_path, file_name = os.path.split(data_path)
                only_files = [file_name]
            self._trainer_master = LocalTrainerMasterClient(
                self._tm_role, data_path, files=only_files, ext=ext)

        if role == 'leader':
            self._block_queue = queue.Queue()
            bridge.register_data_block_handler(self._data_block_handler)
            bridge.start(bridge.new_iter_id())
            iter_id = bridge.current_iter_id
            bridge.send(iter_id, 'barrier', np.asarray([1]))
            bridge.commit()
        elif role == 'follower':
            bridge.start(bridge.new_iter_id())
            bridge.receive(bridge.current_iter_id, 'barrier')
            bridge.commit()

    def _data_block_handler(self, msg):
        """Queue the block announced by ``msg``.

        Args:
            msg: Message exposing ``block_id`` and ``count``; ``count`` must
                equal the number of blocks accepted so far.

        Raises:
            ValueError: A local trainer master exists but does not know the
                announced block id.
        """
        logging.debug('DataBlock: recv "%s" at %d', msg.block_id, msg.count)
        assert self._count == msg.count

        announced_id = msg.block_id
        master = self._trainer_master
        if not announced_id:
            # An empty id is queued as None.
            resolved = None
        elif master is None:
            # No local data: carry only the id.
            resolved = DataBlockInfo(announced_id, None)
        else:
            resolved = master.request_data_block(announced_id)
            if resolved is None:
                raise ValueError("Block %s not found" % announced_id)

        self._count += 1
        self._block_queue.put(resolved)

    def _request_data_block(self):
        """Draw one round of blocks and keep the one for this worker.

        ``worker_rank`` blocks are discarded, the next one is kept, and the
        remaining ``num_workers - worker_rank - 1`` blocks are discarded.

        Returns:
            The kept block, or ``None`` if the trainer master returned
            ``None`` for it.
        """
        master = self._trainer_master
        skip_before = self._worker_rank
        skip_after = self._num_workers - self._worker_rank - 1

        for _ in range(skip_before):
            master.request_data_block()
        mine = master.request_data_block()
        for _ in range(skip_after):
            master.request_data_block()
        return mine

    def get_next_block(self):
        """Return the next block for this worker, or ``None`` when exhausted.

        ``'local'`` reads straight from the trainer master. With
        trainer-master role ``'leader'``, each block is announced through
        ``bridge.load_data_block``; a failed announcement is logged and the
        next block is tried. Otherwise the call waits on the queue filled by
        ``_data_block_handler``.
        """
        if self._role == 'local':
            return self._request_data_block()

        if self._tm_role != 'leader':
            return self._block_queue.get()

        while True:
            candidate = self._request_data_block()
            if candidate is None:
                # No more local blocks: announce an empty id.
                self._bridge.load_data_block(self._count, '')
                break
            try:
                self._bridge.load_data_block(self._count,
                                             candidate.block_id)
            except Exception as err:  # pylint: disable=broad-except
                logging.error('load data block %s with error: %s',
                              candidate.block_id, repr(err))
            else:
                break

        self._count += 1
        return candidate