Ejemplo n.º 1
0
        def blk_matvec(x, blks):
            """Return the product of the block grid `blks` with the vector `x`.

            `blks[r][c]` is the block in block-row r and block-column c. The
            slice of `x` fed to block-column c has the width of `blks[0][c]`;
            the slice of the result written by block-row r has the height of
            `blks[r][0]`. The contributions of all blocks in a block-row are
            summed into that slice.

            Raises ShapeError if len(x) differs from the total block width.
            """
            widths = [[blk.shape[-1] for blk in blkrow] for blkrow in blks]
            heights = [[blk.shape[0] for blk in blkrow] for blkrow in blks]
            col_sizes = widths[0]
            row_sizes = [h[0] for h in heights]
            nargin = sum(col_sizes)
            nargout = sum(row_sizes)
            nx = len(x)
            self.logger.debug('Multiplying with a vector of size %d' % nx)
            self.logger.debug('nargin=%d, nargout=%d' % (nargin, nargout))
            if nx != nargin:
                raise ShapeError('Multiplying with vector of wrong shape.')

            # Zero-initialized because each block-row slice is accumulated.
            y = np.zeros(nargout, dtype=np.result_type(self.dtype, x.dtype))

            lo = 0
            for r, blkrow in enumerate(blks):
                hi = lo + row_sizes[r]
                yout = y[lo:hi]  # view into y
                left = 0
                for c, width in enumerate(col_sizes):
                    right = left + width
                    xin = x[left:right]
                    yout[:] += blkrow[c] * xin
                    left = right
                lo = hi

            return y
Ejemplo n.º 2
0
        def blk_matvec(x, blks):
            """Multiply the block-diagonal operator given by `blks` with `x`.

            `blks` is a sequence of diagonal blocks: block k consumes the k-th
            consecutive slice of `x` and produces the k-th consecutive slice
            of the result. Block sizes are read from `blks` itself (not from
            the enclosing scope), so a caller passing a different block list,
            e.g. transposed blocks, gets sizes consistent with that list.

            Raises ShapeError if len(x) differs from the sum of the blocks'
            column counts.
            """
            nx = len(x)
            nargins = [blk.shape[-1] for blk in blks]
            nargin = sum(nargins)
            nargouts = [blk.shape[0] for blk in blks]
            nargout = sum(nargouts)
            self.logger.debug('Multiplying with a vector of size %d' % nx)
            self.logger.debug('nargin=%d, nargout=%d' % (nargin, nargout))
            if nx != nargin:
                raise ShapeError('Multiplying with vector of wrong shape.')

            result_type = np.result_type(self.dtype, x.dtype)
            # np.empty is safe: the output slices below are consecutive and
            # their lengths sum to nargout, so every entry is written once.
            y = np.empty(nargout, dtype=result_type)

            row_start = col_start = 0
            for B, nout, nin in zip(blks, nargouts, nargins):
                row_end = row_start + nout
                col_end = col_start + nin
                y[row_start:row_end] = B * x[col_start:col_end]
                row_start, col_start = row_end, col_end

            return y
Ejemplo n.º 3
0
 def update_block(self, row, column, new_block):
     """Replace block (row, column) of the block grid with `new_block`.

     `new_block` must have exactly the shape of the block it replaces,
     otherwise ShapeError is raised. When the operator is symmetric, the
     mirrored slot (column, row) is then set to `new_block.T`; for a diagonal
     position (row == column) this overwrites the slot just written with
     `new_block.T`. Returns None.

     NOTE(review): only ``self._blocks`` is modified; if a transpose operator
     keeps its own list of transposed blocks, it is not refreshed here —
     confirm whether that is intended.
     """
     grid = self._blocks
     if new_block.shape != grid[row][column].shape:
         raise ShapeError(
             'The new block should have the same shape as the block you are trying to replace'
         )
     grid[row][column] = new_block
     if not self.symmetric:
         return
     # Mirror the update across the diagonal to keep the grid symmetric.
     grid[column][row] = new_block.T
Ejemplo n.º 4
0
    def __init__(self, blocks, symmetric=False, **kwargs):
        """Build a linear operator from a two-dimensional grid of blocks.

        `blocks[i][j]` is the block in block-row i and block-column j. Each
        block must expose `shape`, `dtype` and `T`, and support
        `block * vector`.

        If `symmetric` is True, only the upper triangle is given: row i lists
        the blocks from the diagonal onward, and the lower triangle is filled
        in with transposes of the corresponding upper blocks. The caller's
        row lists are copied, not modified.

        Keyword `logger` (optional) receives construction-time debug messages.

        Raises ShapeError if the grid is empty, not square (symmetric case),
        ragged, or if block shapes disagree along a block row or block column.
        Raises ValueError if a diagonal block is not symmetric (symmetric
        case).
        """
        if len(blocks) == 0 or len(blocks[0]) == 0:
            raise ShapeError('BlockLinearOperator needs at least one block.')

        # If building a symmetric operator, fill in the blanks.
        # They're just references to existing objects.
        if symmetric:
            nrow = len(blocks)
            ncol = len(blocks[0])
            if nrow != ncol:
                raise ShapeError('Inconsistent shape.')

            for block_row in blocks:
                if not block_row[0].symmetric:
                    raise ValueError('Blocks on diagonal must be symmetric.')

            # Copy every row so that completing the lower triangle does not
            # mutate the lists the caller passed in.
            self._blocks = [list(block_row) for block_row in blocks]
            for i in range(1, nrow):
                # Row i starts at the diagonal; prepend the transposes of the
                # blocks above the diagonal in column i. Rows 0..i-1 are
                # already complete at this point, so [j][i] is column i.
                lower = [self._blocks[j][i].T for j in range(i)]
                self._blocks[i] = lower + self._blocks[i]

        else:
            self._blocks = blocks

        log = kwargs.get('logger', null_log)
        log.debug('Building new BlockLinearOperator')

        nargins = [[blk.shape[-1] for blk in row] for row in self._blocks]
        log.debug('nargins = ' + repr(nargins))
        # Every block row must contain the same number of blocks.
        if len({len(row) for row in nargins}) > 1:
            raise ShapeError('Inconsistent block shapes')
        # All blocks in a block column must have the same number of columns.
        for column in zip(*nargins):
            if min(column) != max(column):
                raise ShapeError('Inconsistent block shapes')

        nargouts = [[blk.shape[0] for blk in row] for row in self._blocks]
        log.debug('nargouts = ' + repr(nargouts))
        # All blocks in a block row must have the same number of rows.
        for row in nargouts:
            if min(row) != max(row):
                raise ShapeError('Inconsistent block shapes')

        nargin = sum(nargins[0])
        nargout = sum(out[0] for out in nargouts)

        # Blocks of the transpose operator: block (i, j) is the transpose of
        # block (j, i). Materialized as a list of lists (not a lazy map
        # object) because blk_matvec calls len() on it and indexes it.
        blocksT = [[blk.T for blk in column] for column in zip(*self._blocks)]

        def blk_matvec(x, blks):
            """Return the product of the block grid `blks` with the vector `x`.

            Block-column widths come from the first block row; block-row
            heights come from the first block of each row. Raises ShapeError
            if len(x) differs from the total width.
            """
            nargins = [[blk.shape[-1] for blk in blkrow] for blkrow in blks]
            nargouts = [[blk.shape[0] for blk in blkrow] for blkrow in blks]
            nargin = sum(nargins[0])
            nargout = sum([out[0] for out in nargouts])
            nx = len(x)
            self.logger.debug('Multiplying with a vector of size %d' % nx)
            self.logger.debug('nargin=%d, nargout=%d' % (nargin, nargout))
            if nx != nargin:
                raise ShapeError('Multiplying with vector of wrong shape.')

            result_type = np.result_type(self.dtype, x.dtype)
            # Zero-initialized because each block-row slice is accumulated.
            y = np.zeros(nargout, dtype=result_type)

            nblk_row = len(blks)
            nblk_col = len(blks[0])

            row_start = col_start = 0
            for row in range(nblk_row):
                row_end = row_start + nargouts[row][0]
                yout = y[row_start:row_end]
                for col in range(nblk_col):
                    col_end = col_start + nargins[0][col]
                    xin = x[col_start:col_end]
                    B = blks[row][col]
                    yout[:] += B * xin
                    col_start = col_end
                row_start = row_end
                col_start = 0

            return y

        flat_blocks = list(itertools.chain.from_iterable(self._blocks))
        op_dtype = np.result_type(*[blk.dtype for blk in flat_blocks])

        # NOTE(review): kwargs (including 'logger') are not forwarded to the
        # base initializer, while blk_matvec logs through self.logger —
        # confirm this is intended.
        super(BlockLinearOperator,
              self).__init__(nargin,
                             nargout,
                             symmetric=symmetric,
                             matvec=lambda x: blk_matvec(x, self._blocks),
                             matvec_transp=lambda x: blk_matvec(x, blocksT),
                             dtype=op_dtype)

        self.T._blocks = blocksT