コード例 #1
0
ファイル: transform2d.py プロジェクト: rjw57/dtcwt
def q2c(y):
    """
    Combine the 2x2 pixel quads of *y* into pairs of complex subband values.

    Each quad of *y* is labelled::

        a----b
        |    |
        |    |
        c----d

    and is turned into two complex values, p = (a + jb) / sqrt(2) and
    q = (d - jc) / sqrt(2). The result stacks p - q and p + q along a new
    third axis.

    :param y: 2D array of quads. The even/odd slicing below assumes an even
        number of rows and columns; with odd sizes the sliced sub-images
        differ in shape.
    :returns: array of shape (rows / 2, cols / 2, 2) holding (p - q, p + q),
        with the complex dtype chosen by ``appropriate_complex_type_for(y)``.
    """
    # Scale factors 1/sqrt(2) and j/sqrt(2), cast to a complex dtype that
    # matches y's precision.
    weights = (np.sqrt(0.5) * np.array([1, 1j])).astype(
        appropriate_complex_type_for(y))
    w_real = weights[0]
    w_imag = weights[1]

    # The four corners of every quad, each as its own sub-image.
    top_left = y[0::2, 0::2]      # a
    top_right = y[0::2, 1::2]     # b
    bottom_left = y[1::2, 0::2]   # c
    bottom_right = y[1::2, 1::2]  # d

    p = top_left * w_real + top_right * w_imag       # (a + jb) / sqrt(2)
    q = bottom_right * w_real - bottom_left * w_imag  # (d - jc) / sqrt(2)

    # The two highpass subbands, stacked along the third axis.
    return np.dstack((p - q, p + q))
コード例 #2
0
def q2c(y):
    """
    Convert the quads of the real array *y* into two complex subimages.

    With each 2x2 quad labelled a (top-left), b (top-right), c (bottom-left)
    and d (bottom-right), the function forms p = (a + jb) / sqrt(2) and
    q = (d - jc) / sqrt(2), then returns ``np.dstack((p - q, p + q))``.

    :param y: 2D array; presumably of even height and width, since the rows
        and columns are split into even/odd halves.
    :returns: 3D array whose last axis has length 2 (p - q first, then p + q).
    """
    scale = (np.sqrt(0.5) * np.array([1, 1j])).astype(
        appropriate_complex_type_for(y))
    one, j = scale[0], scale[1]

    a, b = y[0::2, 0::2], y[0::2, 1::2]
    c, d = y[1::2, 0::2], y[1::2, 1::2]

    p = a * one + b * j  # (a + jb) / sqrt(2)
    q = d * one - c * j  # (d - jc) / sqrt(2)
    return np.dstack((p - q, p + q))
コード例 #3
0
ファイル: transform2d.py プロジェクト: rjw57/dtcwt
    def forward(self, X, nlevels=3, include_scale=False):
        """Perform a *n*-level DTCWT-2D decomposition on a 2D matrix *X*.

        :param X: 2D real array
        :param nlevels: Number of levels of wavelet decomposition; must be
            non-negative. With ``nlevels == 0`` the (possibly extended) input
            is returned as the lowpass image with no highpasses.
        :param include_scale: If true, the lowpass image computed at every
            level is also collected and passed to the returned pyramid as a
            third component.

        :returns: A :py:class:`dtcwt.Pyramid` compatible object representing the transform-domain signal

        :raises ValueError: if ``self.biort`` does not have 4 or 6 components,
            if ``self.qshift`` does not have 8 or 12 components, if *X* has
            three or more dimensions, or if *nlevels* is negative.

        If a dimension of *X* is odd, its last row and/or column is
        duplicated before filtering so that both dimensions are even. When at
        least one level is computed, a warning describing this is logged.

        .. codeauthor:: Rich Wareham <*****@*****.**>, Aug 2013
        .. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
        .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001

        """
        # If biort has 6 elements instead of 4, then it's a modified
        # rotationally symmetric wavelet whose extra h2o filter is used for
        # the level-1 diagonal subbands.
        # FIXME: there's probably a nicer way to do this
        if len(self.biort) == 4:
            h0o, g0o, h1o, g1o = self.biort
        elif len(self.biort) == 6:
            h0o, g0o, h1o, g1o, h2o, g2o = self.biort
        else:
            raise ValueError('Biort wavelet must have 6 or 4 components.')

        # If qshift has 12 elements instead of 8, then it's a modified
        # rotationally symmetric wavelet. Only the first 10 are needed for
        # analysis; h2a/h2b drive the diagonal subbands at levels >= 2.
        # FIXME: there's probably a nicer way to do this
        if len(self.qshift) == 8:
            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
        elif len(self.qshift) == 12:
            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift[:10]
        else:
            raise ValueError('Qshift wavelet must have 12 or 8 components.')

        # Loop-invariant choices of diagonal filters.
        use_biort_diag = len(self.biort) >= 6
        use_qshift_diag = len(self.qshift) >= 12

        X = np.atleast_2d(asfarray(X))
        original_size = X.shape

        if len(X.shape) >= 3:
            raise ValueError('The entered image is {0}, which is invalid '.
                             format('x'.join(str(s) for s in X.shape)) +
                             'for the 2D transform in a numpy backend. ' +
                             'Please enter each image slice separately.')

        # A negative level count would otherwise fall through every branch
        # and fail later with an UnboundLocalError on LoLo.
        if nlevels < 0:
            raise ValueError(
                'nlevels must be non-negative, got {0}.'.format(nlevels))

        # If either dimension is odd, duplicate the bottom row and/or the
        # rightmost column so both are even. Any further extension is done
        # level by level below.
        initial_row_extend = original_size[0] % 2 != 0
        initial_col_extend = original_size[1] % 2 != 0
        if initial_row_extend:
            X = np.vstack((X, X[[-1], :]))
        if initial_col_extend:
            X = np.hstack((X, X[:, [-1]]))

        extended_size = X.shape

        if nlevels == 0:
            if include_scale:
                return Pyramid(X, (), ())
            else:
                return Pyramid(X, ())

        Yh = [None] * nlevels
        if include_scale:
            # Only required when the caller asks for the per-level lowpasses.
            Yscale = [None] * nlevels

        complex_dtype = appropriate_complex_type_for(X)

        # From here on nlevels >= 1, so the first level is always computed.
        # Do odd top-level filters on cols.
        Lo = colfilter(X, h0o).T
        Hi = colfilter(X, h1o).T
        if use_biort_diag:
            Ba = colfilter(X, h2o).T

        # Do odd top-level filters on rows.
        LoLo = colfilter(Lo, h0o).T
        Yh[0] = np.zeros((LoLo.shape[0] >> 1, LoLo.shape[1] >> 1, 6), dtype=complex_dtype)
        # Each q2c call yields two subbands; the slices store them at
        # orientation indices (0, 5), (2, 3) and (1, 4).
        Yh[0][:, :, 0:6:5] = q2c(colfilter(Hi, h0o).T)  # Horizontal pair
        Yh[0][:, :, 2:4:1] = q2c(colfilter(Lo, h1o).T)  # Vertical pair
        if use_biort_diag:
            Yh[0][:, :, 1:5:3] = q2c(colfilter(Ba, h2o).T)  # Diagonal pair
        else:
            Yh[0][:, :, 1:5:3] = q2c(colfilter(Hi, h1o).T)  # Diagonal pair

        if include_scale:
            Yscale[0] = LoLo

        for level in range(1, nlevels):
            row_size, col_size = LoLo.shape
            if row_size % 4 != 0:
                # Extend by 2 rows (duplicating the first and last) if the
                # number of rows of LoLo is not divisible by 4.
                LoLo = np.vstack((LoLo[:1, :], LoLo, LoLo[-1:, :]))

            if col_size % 4 != 0:
                # Extend by 2 cols (duplicating the first and last) if the
                # number of cols of LoLo is not divisible by 4.
                LoLo = np.hstack((LoLo[:, :1], LoLo, LoLo[:, -1:]))

            # Do even Qshift filters on rows.
            Lo = coldfilt(LoLo, h0b, h0a).T
            Hi = coldfilt(LoLo, h1b, h1a).T
            if use_qshift_diag:
                Ba = coldfilt(LoLo, h2b, h2a).T

            # Do even Qshift filters on columns.
            LoLo = coldfilt(Lo, h0b, h0a).T

            Yh[level] = np.zeros((LoLo.shape[0] >> 1, LoLo.shape[1] >> 1, 6), dtype=complex_dtype)
            Yh[level][:, :, 0:6:5] = q2c(coldfilt(Hi, h0b, h0a).T)  # Horizontal
            Yh[level][:, :, 2:4:1] = q2c(coldfilt(Lo, h1b, h1a).T)  # Vertical
            if use_qshift_diag:
                Yh[level][:, :, 1:5:3] = q2c(coldfilt(Ba, h2b, h2a).T)  # Diagonal
            else:
                Yh[level][:, :, 1:5:3] = q2c(coldfilt(Hi, h1b, h1a).T)  # Diagonal

            if include_scale:
                Yscale[level] = LoLo

        Yl = LoLo

        # Tell the caller if the input had to be padded to even size.
        # logging.warning replaces the deprecated logging.warn alias.
        if initial_row_extend or initial_col_extend:
            if initial_row_extend and initial_col_extend:
                duplicated = 'The bottom row and rightmost column have been duplicated, prior to decomposition.'
            elif initial_row_extend:
                duplicated = 'The bottom row has been duplicated, prior to decomposition.'
            else:
                duplicated = 'The rightmost column has been duplicated, prior to decomposition.'
            logging.warning('The image entered is now a {0} NOT a {1}.'.format(
                'x'.join(str(s) for s in extended_size),
                'x'.join(str(s) for s in original_size)))
            logging.warning(duplicated)

        if include_scale:
            return Pyramid(Yl, tuple(Yh), tuple(Yscale))
        else:
            return Pyramid(Yl, tuple(Yh))
コード例 #4
0
    def forward(self, X, nlevels=3, include_scale=False):
        """Perform a *n*-level DTCWT-2D decomposition on a 2D matrix *X*.

        :param X: 2D real array
        :param nlevels: Number of levels of wavelet decomposition; must be
            non-negative. With ``nlevels == 0`` the (possibly extended) input
            is returned as the lowpass image with no highpasses.
        :param include_scale: If true, the lowpass image computed at every
            level is also collected and passed to the returned pyramid as a
            third component.

        :returns: A :py:class:`dtcwt.Pyramid` compatible object representing the transform-domain signal

        :raises ValueError: if ``self.biort`` does not have 4 or 6 components,
            if ``self.qshift`` does not have 8 or 12 components, if *X* has
            three or more dimensions, or if *nlevels* is negative.

        If a dimension of *X* is odd, its last row and/or column is
        duplicated before filtering so that both dimensions are even. When at
        least one level is computed, a warning describing this is logged.

        .. codeauthor:: Rich Wareham <*****@*****.**>, Aug 2013
        .. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
        .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001

        """
        # If biort has 6 elements instead of 4, then it's a modified
        # rotationally symmetric wavelet whose extra h2o filter is used for
        # the level-1 diagonal subbands.
        # FIXME: there's probably a nicer way to do this
        if len(self.biort) == 4:
            h0o, g0o, h1o, g1o = self.biort
        elif len(self.biort) == 6:
            h0o, g0o, h1o, g1o, h2o, g2o = self.biort
        else:
            raise ValueError('Biort wavelet must have 6 or 4 components.')

        # If qshift has 12 elements instead of 8, then it's a modified
        # rotationally symmetric wavelet. Only the first 10 are needed for
        # analysis; h2a/h2b drive the diagonal subbands at levels >= 2.
        # FIXME: there's probably a nicer way to do this
        if len(self.qshift) == 8:
            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
        elif len(self.qshift) == 12:
            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift[:10]
        else:
            raise ValueError('Qshift wavelet must have 12 or 8 components.')

        # Loop-invariant choices of diagonal filters.
        use_biort_diag = len(self.biort) >= 6
        use_qshift_diag = len(self.qshift) >= 12

        X = np.atleast_2d(asfarray(X))
        original_size = X.shape

        if len(X.shape) >= 3:
            raise ValueError(
                'The entered image is {0}, please enter each image slice separately.'
                .format('x'.join(str(s) for s in X.shape)))

        # A negative level count would otherwise fall through every branch
        # and fail later with an UnboundLocalError on LoLo.
        if nlevels < 0:
            raise ValueError(
                'nlevels must be non-negative, got {0}.'.format(nlevels))

        # If either dimension is odd, duplicate the bottom row and/or the
        # rightmost column so both are even. Any further extension is done
        # level by level below.
        initial_row_extend = original_size[0] % 2 != 0
        initial_col_extend = original_size[1] % 2 != 0
        if initial_row_extend:
            X = np.vstack((X, X[[-1], :]))
        if initial_col_extend:
            X = np.hstack((X, X[:, [-1]]))

        extended_size = X.shape

        if nlevels == 0:
            if include_scale:
                return Pyramid(X, (), ())
            else:
                return Pyramid(X, ())

        Yh = [None] * nlevels
        if include_scale:
            # Only required when the caller asks for the per-level lowpasses.
            Yscale = [None] * nlevels

        complex_dtype = appropriate_complex_type_for(X)

        # From here on nlevels >= 1, so the first level is always computed.
        # Do odd top-level filters on cols.
        Lo = colfilter(X, h0o).T
        Hi = colfilter(X, h1o).T
        if use_biort_diag:
            Ba = colfilter(X, h2o).T

        # Do odd top-level filters on rows.
        LoLo = colfilter(Lo, h0o).T
        Yh[0] = np.zeros((LoLo.shape[0] >> 1, LoLo.shape[1] >> 1, 6),
                         dtype=complex_dtype)
        # Each q2c call yields two subbands; the slices store them at
        # orientation indices (0, 5), (2, 3) and (1, 4).
        Yh[0][:, :, 0:6:5] = q2c(colfilter(Hi, h0o).T)  # Horizontal pair
        Yh[0][:, :, 2:4:1] = q2c(colfilter(Lo, h1o).T)  # Vertical pair
        if use_biort_diag:
            Yh[0][:, :, 1:5:3] = q2c(colfilter(Ba, h2o).T)  # Diagonal pair
        else:
            Yh[0][:, :, 1:5:3] = q2c(colfilter(Hi, h1o).T)  # Diagonal pair

        if include_scale:
            Yscale[0] = LoLo

        for level in range(1, nlevels):
            row_size, col_size = LoLo.shape
            if row_size % 4 != 0:
                # Extend by 2 rows (duplicating the first and last) if the
                # number of rows of LoLo is not divisible by 4.
                LoLo = np.vstack((LoLo[:1, :], LoLo, LoLo[-1:, :]))

            if col_size % 4 != 0:
                # Extend by 2 cols (duplicating the first and last) if the
                # number of cols of LoLo is not divisible by 4.
                LoLo = np.hstack((LoLo[:, :1], LoLo, LoLo[:, -1:]))

            # Do even Qshift filters on rows.
            Lo = coldfilt(LoLo, h0b, h0a).T
            Hi = coldfilt(LoLo, h1b, h1a).T
            if use_qshift_diag:
                Ba = coldfilt(LoLo, h2b, h2a).T

            # Do even Qshift filters on columns.
            LoLo = coldfilt(Lo, h0b, h0a).T

            Yh[level] = np.zeros((LoLo.shape[0] >> 1, LoLo.shape[1] >> 1, 6),
                                 dtype=complex_dtype)
            Yh[level][:, :, 0:6:5] = q2c(
                coldfilt(Hi, h0b, h0a).T)  # Horizontal
            Yh[level][:, :, 2:4:1] = q2c(coldfilt(Lo, h1b, h1a).T)  # Vertical
            if use_qshift_diag:
                Yh[level][:, :, 1:5:3] = q2c(
                    coldfilt(Ba, h2b, h2a).T)  # Diagonal
            else:
                Yh[level][:, :, 1:5:3] = q2c(
                    coldfilt(Hi, h1b, h1a).T)  # Diagonal

            if include_scale:
                Yscale[level] = LoLo

        Yl = LoLo

        # Tell the caller if the input had to be padded to even size.
        # logging.warning replaces the deprecated logging.warn alias.
        if initial_row_extend or initial_col_extend:
            if initial_row_extend and initial_col_extend:
                duplicated = (
                    'The bottom row and rightmost column have been duplicated, prior to decomposition.'
                )
            elif initial_row_extend:
                duplicated = (
                    'The bottom row has been duplicated, prior to decomposition.'
                )
            else:
                duplicated = (
                    'The rightmost column has been duplicated, prior to decomposition.'
                )
            logging.warning('The image entered is now a {0} NOT a {1}.'.format(
                'x'.join(str(s) for s in extended_size),
                'x'.join(str(s) for s in original_size)))
            logging.warning(duplicated)

        if include_scale:
            return Pyramid(Yl, tuple(Yh), tuple(Yscale))
        else:
            return Pyramid(Yl, tuple(Yh))