Example #1
0
    def _inverse_ops(self, Yl, Yh, gain_mask=None):
        """Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D
        reconstruction.

        :param Yl: The lowpass output from a forward transform. Should be a
            tensorflow variable.
        :param Yh: The tuple of highpass outputs from a forward transform.
            Should be tensorflow variables.
        :param gain_mask: Gain to be applied to each subband.

        :returns: A tf.Variable holding the output

        The (*d*, *l*)-th element of *gain_mask* is gain for subband with
        direction *d* at level *l*. If gain_mask[d,l] == 0, no computation is
        performed for band (d,l). Default *gain_mask* is all ones. Note that
        both *d* and *l* are zero-indexed.

        .. codeauthor:: Fergal Cotter <*****@*****.**>, Feb 2017
        .. codeauthor:: Rich Wareham <*****@*****.**>, Aug 2013
        .. codeauthor:: Nick Kingsbury, Cambridge University, May 2002
        .. codeauthor:: Cian Shaffrey, Cambridge University, May 2002

        """
        a = len(Yh)  # No of levels.

        if gain_mask is None:
            gain_mask = np.ones((6, a))  # Default gain_mask.

        gain_mask = np.array(gain_mask)

        # If biort has 6 elements instead of 4, then it's a modified
        # rotationally symmetric wavelet
        # FIXME: there's probably a nicer way to do this
        if len(self.biort) == 4:
            h0o, g0o, h1o, g1o = self.biort
        elif len(self.biort) == 6:
            h0o, g0o, h1o, g1o, h2o, g2o = self.biort
        else:
            raise ValueError('Biort wavelet must have 6 or 4 components.')

        # If qshift has 12 elements instead of 8, then it's a modified
        # rotationally symmetric wavelet
        # FIXME: there's probably a nicer way to do this
        if len(self.qshift) == 8:
            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
        elif len(self.qshift) == 12:
            h0a, h0b, g0a, g0b, h1a, h1b, \
                g1a, g1b, h2a, h2b, g2a, g2b = self.qshift
        else:
            raise ValueError('Qshift wavelet must have 12 or 8 components.')

        current_level = a
        Z = Yl

        # This ensures that for level 1 we never do the following
        while current_level >= 2:
            lh = c2q(Yh[current_level - 1][:, :, :, 0:6:5],
                     gain_mask[[0, 5],
                     current_level - 1])
            hl = c2q(Yh[current_level - 1][:, :, :, 2:4:1],
                     gain_mask[[2, 3],
                     current_level - 1])
            hh = c2q(Yh[current_level - 1][:, :, :, 1:5:3],
                     gain_mask[[1, 4],
                     current_level - 1])

            # Do even Qshift filters on columns.
            y1 = colifilt(Z, g0b, g0a) + colifilt(lh, g1b, g1a)

            if len(self.qshift) >= 12:
                y2 = colifilt(hl, g0b, g0a)
                y2bp = colifilt(hh, g2b, g2a)

                # Do even Qshift filters on rows.
                y1T = tf.transpose(y1, perm=[0, 2, 1])
                y2T = tf.transpose(y2, perm=[0, 2, 1])
                y2bpT = tf.transpose(y2bp, perm=[0, 2, 1])
                Z = tf.transpose(
                    colifilt(y1T, g0b, g0a) +
                    colifilt(y2T, g1b, g1a) +
                    colifilt(y2bpT, g2b, g2a),
                    perm=[0, 2, 1])
            else:
                y2 = colifilt(hl, g0b, g0a) + colifilt(hh, g1b, g1a)

                # Do even Qshift filters on rows.
                y1T = tf.transpose(y1, perm=[0, 2, 1])
                y2T = tf.transpose(y2, perm=[0, 2, 1])
                Z = tf.transpose(
                    colifilt(y1T, g0b, g0a) +
                    colifilt(y2T, g1b, g1a),
                    perm=[0, 2, 1])

            # Check size of Z and crop as required
            Z_r, Z_c = Z.get_shape().as_list()[1:3]
            S_r, S_c = Yh[current_level - 2].get_shape().as_list()[1:3]
            # check to see if this result needs to be cropped for the rows
            if Z_r != S_r * 2:
                Z = Z[:, 1:-1, :]
            # check to see if this result needs to be cropped for the cols
            if Z_c != S_c * 2:
                Z = Z[:, :, 1:-1]

            # Assert that the size matches at this stage
            Z_r, Z_c = Z.get_shape().as_list()[1:3]
            if Z_r != S_r * 2 or Z_c != S_c * 2:
                raise ValueError(
                    'Sizes of highpasses {}x{} are not '.format(Z_r, Z_c) +
                    'compatible with {}x{} from next level'.format(S_r, S_c))

            current_level = current_level - 1

        if current_level == 1:
            lh = c2q(Yh[current_level - 1][:, :, :, 0:6:5],
                     gain_mask[[0, 5],
                     current_level - 1])
            hl = c2q(Yh[current_level - 1][:, :, :, 2:4:1],
                     gain_mask[[2, 3],
                     current_level - 1])
            hh = c2q(Yh[current_level - 1][:, :, :, 1:5:3],
                     gain_mask[[1, 4],
                     current_level - 1])

            # Do odd top-level filters on columns.
            y1 = colfilter(Z, g0o) + colfilter(lh, g1o)

            if len(self.biort) >= 6:
                y2 = colfilter(hl, g0o)
                y2bp = colfilter(hh, g2o)

                # Do odd top-level filters on rows.
                Z = rowfilter(y1, g0o) + rowfilter(y2, g1o) + \
                    rowfilter(y2bp, g2o)
            else:
                y2 = colfilter(hl, g0o) + colfilter(hh, g1o)

                # Do odd top-level filters on rows.
                Z = rowfilter(y1, g0o) + rowfilter(y2, g1o)

        return Z
Example #2
0
    def _inverse_ops(self, Yl, Yh, gain_mask=None):
        """Reconstruct a 1D signal from an *n*-level dual-tree complex
        wavelet (DTCWT) decomposition.

        :param Yl: The lowpass output from a forward transform. Should be a
            tensorflow variable.
        :param Yh: The tuple of highpass outputs from a forward transform.
            Should be tensorflow variables.
        :param gain_mask: Gain to be applied to each subband.

        :returns: The reconstructed tensor, or *Yl* unchanged when *Yh* is
            empty.

        :raises ValueError: if an intermediate reconstruction does not have
            twice as many rows (axis 1) as the next-finer highpass, with the
            remaining non-batch dimensions equal.

        ``gain_mask[l]`` multiplies the highpass subband of level *l*
        (zero-indexed) before it enters the reconstruction. The default is
        all ones.

        .. codeauthor:: Fergal Cotter <*****@*****.**>, Sep 2017
        .. codeauthor:: Rich Wareham <*****@*****.**>, Aug 2013
        .. codeauthor:: Nick Kingsbury, Cambridge University, May 2002
        .. codeauthor:: Cian Shaffrey, Cambridge University, May 2002

        """
        biort_spec = self.biort
        qshift_spec = self.qshift
        nlevels = len(Yh)

        if gain_mask is None:
            gain_mask = np.ones(nlevels)
        gain_mask = np.array(gain_mask)

        # Each wavelet may be a name resolved by _biort/_qshift or an
        # explicit coefficient tuple. A TypeError from the lookup is taken
        # to mean the latter. Only the synthesis (g*) filters are used here.
        try:
            _, g0o, _, g1o = _biort(biort_spec)
        except TypeError:
            _, g0o, _, g1o = biort_spec

        try:
            _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift_spec)
        except TypeError:
            _, _, g0a, g0b, _, _, g1a, g1b = qshift_spec

        if not nlevels:
            # No highpass levels: there is nothing to add back.
            return Yl

        # Coarsest level down to level 2 use the even qshift filters.
        Lo = Yl
        for level in range(nlevels - 1, 0, -1):
            Hi = c2q1d(Yh[level] * gain_mask[level])
            Lo = colifilt(Lo, g0b, g0a) + colifilt(Hi, g1b, g1a)

            lo_shape = Lo.get_shape().as_list()
            next_shape = Yh[level - 1].get_shape().as_list()
            # If Lo has more rows than twice the next Yh, the signal was
            # extended during the forward transform. Trim one row from each
            # end so it matches.
            if lo_shape[1] != 2 * next_shape[1]:
                Lo = Lo[:, 1:-1]
                lo_shape = Lo.get_shape().as_list()

            expected = np.asanyarray(next_shape[1:] * np.array((2, 1)))
            if np.any(np.asanyarray(lo_shape[1:]) != expected):
                raise ValueError('Yh sizes are not valid for DTWAVEIFM')

        # The finest level uses the odd-length biorthogonal filters.
        Hi = c2q1d(Yh[0] * gain_mask[0])
        return colfilter(Lo, g0o) + colfilter(Hi, g1o)
Example #3
0
    def _forward_ops(self, X, nlevels=3):
        """ Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*.

        :param X: 3D real array of size [batch, h, w]
        :param nlevels: Number of levels of wavelet decomposition
        :param extended: True if a singleton dimension was added at the
            beginning of the input. Signal to remove afterwards.

        :returns: A tuple of Yl, Yh, Yscale
        """

        # If biort has 6 elements instead of 4, then it's a modified
        # rotationally symmetric wavelet
        # FIXME: there's probably a nicer way to do this
        if len(self.biort) == 4:
            h0o, g0o, h1o, g1o = self.biort
        elif len(self.biort) == 6:
            h0o, g0o, h1o, g1o, h2o, g2o = self.biort
        else:
            raise ValueError('Biort wavelet must have 6 or 4 components.')

        # If qshift has 12 elements instead of 8, then it's a modified
        # rotationally symmetric wavelet
        # FIXME: there's probably a nicer way to do this
        if len(self.qshift) == 8:
            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
        elif len(self.qshift) == 12:
            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift[:10]
        else:
            raise ValueError('Qshift wavelet must have 12 or 8 components.')

        # Check the shape and form of the input
        if X.dtype not in tf_dtypes:
            raise ValueError('X needs to be a tf variable or placeholder')

        original_size = X.get_shape().as_list()[1:]

        if len(original_size) >= 3:
            raise ValueError(
                """The entered variable has too many dimensions {}. If
                the final dimension are colour channels, please enter each
                channel separately.""".format(original_size))

        # ############################ Resize #################################
        # The next few lines of code check to see if the image is odd in size,
        # if so an extra ... row/column will be added to the bottom/right of the
        # image
        initial_row_extend = 0
        initial_col_extend = 0
        # If the row count of X is not divisible by 2 then we need to
        # extend X by adding a row at the bottom
        if original_size[0] % 2 != 0:
            bottom_row = tf.slice(X, [0, original_size[0] - 1, 0], [-1, 1, -1])
            X = tf.concat([X, bottom_row], axis=1)
            initial_row_extend = 1

        # If the col count of X is not divisible by 2 then we need to
        # extend X by adding a col to the right
        if original_size[1] % 2 != 0:
            right_col = tf.slice(X, [0, 0, original_size[1] - 1], [-1, -1, 1])
            X = tf.concat([X, right_col], axis=2)
            initial_col_extend = 1

        extended_size = X.get_shape().as_list()[1:3]

        if nlevels == 0:
            return X, (), ()

        # ########################### Initialise ###############################
        Yh = [None, ] * nlevels
        # This is only required if the user specifies a third output
        # component.
        Yscale = [None, ] * nlevels

        # ############################ Level 1 #################################
        # Uses the biorthogonal filters
        if nlevels >= 1:
            # Do odd top-level filters on cols.
            Lo = colfilter(X, h0o)
            Hi = colfilter(X, h1o)
            if len(self.biort) >= 6:
                Ba = colfilter(X, h2o)

            # Do odd top-level filters on rows.
            LoLo = rowfilter(Lo, h0o)
            LoLo_shape = LoLo.get_shape().as_list()[1:]

            # Horizontal wavelet pair (15 & 165 degrees)
            horiz = q2c(rowfilter(Hi, h0o))

            # Vertical wavelet pair (75 & 105 degrees)
            vertic = q2c(rowfilter(Lo, h1o))

            # Diagonal wavelet pair (45 & 135 degrees)
            if len(self.biort) >= 6:
                diag = q2c(rowfilter(Ba, h2o))
            else:
                diag = q2c(rowfilter(Hi, h1o))

            # Pack all 6 tensors into one
            Yh[0] = tf.stack(
                [horiz[0], diag[0], vertic[0], vertic[1], diag[1], horiz[1]],
                axis=3)

            Yscale[0] = LoLo

        # ############################ Level 2+ ################################
        # Uses the qshift filters
        for level in xrange(1, nlevels):
            row_size, col_size = LoLo_shape[0], LoLo_shape[1]
            # If the row count of LoLo is not divisible by 4 (it will be
            # divisible by 2), add 2 extra rows to make it so
            if row_size % 4 != 0:
                LoLo = tf.pad(LoLo, [[0, 0], [1, 1], [0, 0]], 'SYMMETRIC')

            # If the col count of LoLo is not divisible by 4 (it will be
            # divisible by 2), add 2 extra cols to make it so
            if col_size % 4 != 0:
                LoLo = tf.pad(LoLo, [[0, 0], [0, 0], [1, 1]], 'SYMMETRIC')

            # Do even Qshift filters on cols.
            Lo = coldfilt(LoLo, h0b, h0a)
            Hi = coldfilt(LoLo, h1b, h1a)
            if len(self.qshift) >= 12:
                Ba = coldfilt(LoLo, h2b, h2a)

            # Do even Qshift filters on rows.
            LoLo = rowdfilt(Lo, h0b, h0a)
            LoLo_shape = LoLo.get_shape().as_list()[1:3]

            # Horizontal wavelet pair (15 & 165 degrees)
            horiz = q2c(rowdfilt(Hi, h0b, h0a))

            # Vertical wavelet pair (75 & 105 degrees)
            vertic = q2c(rowdfilt(Lo, h1b, h1a))

            # Diagonal wavelet pair (45 & 135 degrees)
            if len(self.qshift) >= 12:
                diag = q2c(rowdfilt(Ba, h2b, h2a))
            else:
                diag = q2c(rowdfilt(Hi, h1b, h1a))

            # Pack all 6 tensors into one
            Yh[level] = tf.stack(
                [horiz[0], diag[0], vertic[0], vertic[1], diag[1], horiz[1]],
                axis=3)

            Yscale[level] = LoLo

        Yl = LoLo

        if initial_row_extend == 1 and initial_col_extend == 1:
            logging.warn('The image entered is now a {0} NOT a {1}.'.format(
                'x'.join(list(str(s) for s in extended_size)),
                'x'.join(list(str(s) for s in original_size))))
            logging.warn(
                """The bottom row and rightmost column have been duplicated,
                prior to decomposition.""")

        if initial_row_extend == 1 and initial_col_extend == 0:
            logging.warn('The image entered is now a {0} NOT a {1}.'.format(
                'x'.join(list(str(s) for s in extended_size)),
                'x'.join(list(str(s) for s in original_size))))
            logging.warn(
                'The bottom row has been duplicated, prior to decomposition.')

        if initial_row_extend == 0 and initial_col_extend == 1:
            logging.warn('The image entered is now a {0} NOT a {1}.'.format(
                'x'.join(list(str(s) for s in extended_size)),
                'x'.join(list(str(s) for s in original_size))))
            logging.warn(
                """The rightmost column has been duplicated, prior to
                decomposition.""")

        return Yl, tuple(Yh), tuple(Yscale)
Example #4
0
    def _forward_ops(self, X, nlevels=3):
        """ Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*.

        For column inputs, we still need the input shape to be 3D, but with 1 as
        the last dimension.

        :param X: 3D real array of size [batch, h, w]
        :param nlevels: Number of levels of wavelet decomposition
        :param extended: True if a singleton dimension was added at the
            beginning of the input. Signal to remove afterwards.

        :returns: A tuple of Yl, Yh, Yscale
        """
        biort = self.biort
        qshift = self.qshift

        # Try to load coefficients if biort is a string parameter
        try:
            h0o, g0o, h1o, g1o = _biort(biort)
        except TypeError:
            h0o, g0o, h1o, g1o = biort

        # Try to load coefficients if qshift is a string parameter
        try:
            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
        except TypeError:
            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift

        # Check the shape and form of the input
        if X.dtype not in tf_dtypes:
            raise ValueError('X needs to be a tf variable or placeholder')

        original_size = X.get_shape().as_list()[1:]

        # ############################ Resize #################################
        # The next few lines of code check to see if the image is odd in size,
        # if so an extra ... row/column will be added to the bottom/right of the
        # image
        #  initial_row_extend = 0
        #  initial_col_extend = 0
        # If the row count of X is not divisible by 2 then we need to
        # extend X by adding a row at the bottom
        if original_size[0] % 2 != 0:
            #  X = tf.pad(X, [[0, 0], [0, 1], [0, 0]], 'SYMMETRIC')
            raise ValueError('Size of input X must be a multiple of 2')

        #  extended_size = X.get_shape().as_list()[1:]

        if nlevels == 0:
            return X, (), ()

        # ########################### Initialise ###############################
        Yh = [None, ] * nlevels
        # This is only required if the user specifies a third output
        # component.
        Yscale = [None, ] * nlevels

        # ############################ Level 1 #################################
        # Uses the biorthogonal filters
        if nlevels >= 1:
            # Do odd top-level filters on cols.
            Hi = colfilter(X, h1o)
            Lo = colfilter(X, h0o)

            # Convert Hi to complex form by taking alternate rows
            Yh[0] = tf.cast(Hi[:,::2,:], tf.complex64) + \
                1j*tf.cast(Hi[:,1::2,:], tf.complex64)
            Yscale[0] = Lo

        # ############################ Level 2+ ################################
        # Uses the qshift filters
        for level in xrange(1, nlevels):
            # If the row count of Lo is not divisible by 4 (it will be
            # divisible by 2), add 2 extra rows to make it so
            if Lo.get_shape().as_list()[1] % 4 != 0:
                Lo = tf.pad(Lo, [[0, 0], [1, 1], [0, 0]], 'SYMMETRIC')

            # Do even Qshift filters on cols.
            Hi = coldfilt(Lo, h1b, h1a)
            Lo = coldfilt(Lo, h0b, h0a)

            # Convert Hi to complex form by taking alternate rows
            Yh[level] = tf.cast(Hi[:,::2,:], tf.complex64) + \
                1j * tf.cast(Hi[:,1::2,:], tf.complex64)
            Yscale[level] = Lo

        Yl = Lo

        return Yl, tuple(Yh), tuple(Yscale)