def _inverse_ops(self, Yl, Yh, gain_mask=None): """Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D reconstruction. :param Yl: The lowpass output from a forward transform. Should be a tensorflow variable. :param Yh: The tuple of highpass outputs from a forward transform. Should be tensorflow variables. :param gain_mask: Gain to be applied to each subband. :returns: A tf.Variable holding the output The (*d*, *l*)-th element of *gain_mask* is gain for subband with direction *d* at level *l*. If gain_mask[d,l] == 0, no computation is performed for band (d,l). Default *gain_mask* is all ones. Note that both *d* and *l* are zero-indexed. .. codeauthor:: Fergal Cotter <*****@*****.**>, Feb 2017 .. codeauthor:: Rich Wareham <*****@*****.**>, Aug 2013 .. codeauthor:: Nick Kingsbury, Cambridge University, May 2002 .. codeauthor:: Cian Shaffrey, Cambridge University, May 2002 """ a = len(Yh) # No of levels. if gain_mask is None: gain_mask = np.ones((6, a)) # Default gain_mask. gain_mask = np.array(gain_mask) # If biort has 6 elements instead of 4, then it's a modified # rotationally symmetric wavelet # FIXME: there's probably a nicer way to do this if len(self.biort) == 4: h0o, g0o, h1o, g1o = self.biort elif len(self.biort) == 6: h0o, g0o, h1o, g1o, h2o, g2o = self.biort else: raise ValueError('Biort wavelet must have 6 or 4 components.') # If qshift has 12 elements instead of 8, then it's a modified # rotationally symmetric wavelet # FIXME: there's probably a nicer way to do this if len(self.qshift) == 8: h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift elif len(self.qshift) == 12: h0a, h0b, g0a, g0b, h1a, h1b, \ g1a, g1b, h2a, h2b, g2a, g2b = self.qshift else: raise ValueError('Qshift wavelet must have 12 or 8 components.') current_level = a Z = Yl # This ensures that for level 1 we never do the following while current_level >= 2: lh = c2q(Yh[current_level - 1][:, :, :, 0:6:5], gain_mask[[0, 5], current_level - 1]) hl = c2q(Yh[current_level - 1][:, :, :, 
2:4:1], gain_mask[[2, 3], current_level - 1]) hh = c2q(Yh[current_level - 1][:, :, :, 1:5:3], gain_mask[[1, 4], current_level - 1]) # Do even Qshift filters on columns. y1 = colifilt(Z, g0b, g0a) + colifilt(lh, g1b, g1a) if len(self.qshift) >= 12: y2 = colifilt(hl, g0b, g0a) y2bp = colifilt(hh, g2b, g2a) # Do even Qshift filters on rows. y1T = tf.transpose(y1, perm=[0, 2, 1]) y2T = tf.transpose(y2, perm=[0, 2, 1]) y2bpT = tf.transpose(y2bp, perm=[0, 2, 1]) Z = tf.transpose( colifilt(y1T, g0b, g0a) + colifilt(y2T, g1b, g1a) + colifilt(y2bpT, g2b, g2a), perm=[0, 2, 1]) else: y2 = colifilt(hl, g0b, g0a) + colifilt(hh, g1b, g1a) # Do even Qshift filters on rows. y1T = tf.transpose(y1, perm=[0, 2, 1]) y2T = tf.transpose(y2, perm=[0, 2, 1]) Z = tf.transpose( colifilt(y1T, g0b, g0a) + colifilt(y2T, g1b, g1a), perm=[0, 2, 1]) # Check size of Z and crop as required Z_r, Z_c = Z.get_shape().as_list()[1:3] S_r, S_c = Yh[current_level - 2].get_shape().as_list()[1:3] # check to see if this result needs to be cropped for the rows if Z_r != S_r * 2: Z = Z[:, 1:-1, :] # check to see if this result needs to be cropped for the cols if Z_c != S_c * 2: Z = Z[:, :, 1:-1] # Assert that the size matches at this stage Z_r, Z_c = Z.get_shape().as_list()[1:3] if Z_r != S_r * 2 or Z_c != S_c * 2: raise ValueError( 'Sizes of highpasses {}x{} are not '.format(Z_r, Z_c) + 'compatible with {}x{} from next level'.format(S_r, S_c)) current_level = current_level - 1 if current_level == 1: lh = c2q(Yh[current_level - 1][:, :, :, 0:6:5], gain_mask[[0, 5], current_level - 1]) hl = c2q(Yh[current_level - 1][:, :, :, 2:4:1], gain_mask[[2, 3], current_level - 1]) hh = c2q(Yh[current_level - 1][:, :, :, 1:5:3], gain_mask[[1, 4], current_level - 1]) # Do odd top-level filters on columns. y1 = colfilter(Z, g0o) + colfilter(lh, g1o) if len(self.biort) >= 6: y2 = colfilter(hl, g0o) y2bp = colfilter(hh, g2o) # Do odd top-level filters on rows. 
Z = rowfilter(y1, g0o) + rowfilter(y2, g1o) + \ rowfilter(y2bp, g2o) else: y2 = colfilter(hl, g0o) + colfilter(hh, g1o) # Do odd top-level filters on rows. Z = rowfilter(y1, g0o) + rowfilter(y2, g1o) return Z
def _inverse_ops(self, Yl, Yh, gain_mask=None):
    """Perform an *n*-level dual-tree complex wavelet (DTCWT) 1D reconstruction.

    :param Yl: The lowpass output from a forward transform. Should be a
        tensorflow variable.
    :param Yh: The tuple of highpass outputs from a forward transform. Should
        be tensorflow variables.
    :param gain_mask: Gain to be applied to each subband.

    :returns: A tf.Variable holding the output. If *Yh* is empty, *Yl* is
        returned unchanged.

    :raises ValueError: if the highpass sizes of adjacent levels are
        inconsistent.

    The *l*-th element of *gain_mask* is the gain for the wavelet subband at
    level *l* (0-indexed). Each highpass is multiplied by its gain before
    synthesis. Every level is still synthesised even when its gain is zero.
    Default *gain_mask* is all ones.

    .. codeauthor:: Fergal Cotter <*****@*****.**>, Sep 2017
    .. codeauthor:: Rich Wareham <*****@*****.**>, Aug 2013
    .. codeauthor:: Nick Kingsbury, Cambridge University, May 2002
    .. codeauthor:: Cian Shaffrey, Cambridge University, May 2002
    """
    # Which wavelets are to be used?
    biort = self.biort
    qshift = self.qshift

    nlevels = len(Yh)
    if gain_mask is None:
        gain_mask = np.ones(nlevels)  # Default gain_mask.
    gain_mask = np.array(gain_mask)

    # Try to load coefficients if biort is a string parameter; otherwise
    # treat it as an explicit tuple of filters.
    try:
        _, g0o, _, g1o = _biort(biort)
    except TypeError:
        _, g0o, _, g1o = biort

    # Try to load coefficients if qshift is a string parameter; otherwise
    # treat it as an explicit tuple of filters.
    try:
        _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
    except TypeError:
        _, _, g0a, g0b, _, _, g1a, g1b = qshift

    if nlevels == 0:
        # There are no levels in the input, so just return the Yl value.
        return Yl

    Lo = Yl
    # Reconstruct levels nlevels-1 .. 1 (0-indexed) with the Q-shift filters.
    for level in range(nlevels - 1, 0, -1):
        Hi = c2q1d(Yh[level] * gain_mask[level])
        Lo = colifilt(Lo, g0b, g0a) + colifilt(Hi, g1b, g1a)

        # The forward transform pads Lo by one sample at each end when its
        # length is not a multiple of 4. If Lo is not exactly twice the height
        # of the next (finer) highpass, clip that extension off.
        next_shape = Yh[level - 1].get_shape().as_list()
        if Lo.get_shape().as_list()[1] != 2 * next_shape[1]:
            Lo = Lo[:, 1:-1]

        # Lo must now be twice as tall as the next Yh, and match it in every
        # remaining dimension.
        expected = [2 * next_shape[1]] + list(next_shape[2:])
        if Lo.get_shape().as_list()[1:] != expected:
            raise ValueError('Yh sizes are not valid for DTWAVEIFM')

    # Reconstruct level 0 with the odd-length biorthogonal filters.
    Hi = c2q1d(Yh[0] * gain_mask[0])
    return colfilter(Lo, g0o) + colfilter(Hi, g1o)
def _forward_ops(self, X, nlevels=3): """ Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*. :param X: 3D real array of size [batch, h, w] :param nlevels: Number of levels of wavelet decomposition :param extended: True if a singleton dimension was added at the beginning of the input. Signal to remove afterwards. :returns: A tuple of Yl, Yh, Yscale """ # If biort has 6 elements instead of 4, then it's a modified # rotationally symmetric wavelet # FIXME: there's probably a nicer way to do this if len(self.biort) == 4: h0o, g0o, h1o, g1o = self.biort elif len(self.biort) == 6: h0o, g0o, h1o, g1o, h2o, g2o = self.biort else: raise ValueError('Biort wavelet must have 6 or 4 components.') # If qshift has 12 elements instead of 8, then it's a modified # rotationally symmetric wavelet # FIXME: there's probably a nicer way to do this if len(self.qshift) == 8: h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift elif len(self.qshift) == 12: h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift[:10] else: raise ValueError('Qshift wavelet must have 12 or 8 components.') # Check the shape and form of the input if X.dtype not in tf_dtypes: raise ValueError('X needs to be a tf variable or placeholder') original_size = X.get_shape().as_list()[1:] if len(original_size) >= 3: raise ValueError( """The entered variable has too many dimensions {}. If the final dimension are colour channels, please enter each channel separately.""".format(original_size)) # ############################ Resize ################################# # The next few lines of code check to see if the image is odd in size, # if so an extra ... 
row/column will be added to the bottom/right of the # image initial_row_extend = 0 initial_col_extend = 0 # If the row count of X is not divisible by 2 then we need to # extend X by adding a row at the bottom if original_size[0] % 2 != 0: bottom_row = tf.slice(X, [0, original_size[0] - 1, 0], [-1, 1, -1]) X = tf.concat([X, bottom_row], axis=1) initial_row_extend = 1 # If the col count of X is not divisible by 2 then we need to # extend X by adding a col to the right if original_size[1] % 2 != 0: right_col = tf.slice(X, [0, 0, original_size[1] - 1], [-1, -1, 1]) X = tf.concat([X, right_col], axis=2) initial_col_extend = 1 extended_size = X.get_shape().as_list()[1:3] if nlevels == 0: return X, (), () # ########################### Initialise ############################### Yh = [None, ] * nlevels # This is only required if the user specifies a third output # component. Yscale = [None, ] * nlevels # ############################ Level 1 ################################# # Uses the biorthogonal filters if nlevels >= 1: # Do odd top-level filters on cols. Lo = colfilter(X, h0o) Hi = colfilter(X, h1o) if len(self.biort) >= 6: Ba = colfilter(X, h2o) # Do odd top-level filters on rows. 
LoLo = rowfilter(Lo, h0o) LoLo_shape = LoLo.get_shape().as_list()[1:] # Horizontal wavelet pair (15 & 165 degrees) horiz = q2c(rowfilter(Hi, h0o)) # Vertical wavelet pair (75 & 105 degrees) vertic = q2c(rowfilter(Lo, h1o)) # Diagonal wavelet pair (45 & 135 degrees) if len(self.biort) >= 6: diag = q2c(rowfilter(Ba, h2o)) else: diag = q2c(rowfilter(Hi, h1o)) # Pack all 6 tensors into one Yh[0] = tf.stack( [horiz[0], diag[0], vertic[0], vertic[1], diag[1], horiz[1]], axis=3) Yscale[0] = LoLo # ############################ Level 2+ ################################ # Uses the qshift filters for level in xrange(1, nlevels): row_size, col_size = LoLo_shape[0], LoLo_shape[1] # If the row count of LoLo is not divisible by 4 (it will be # divisible by 2), add 2 extra rows to make it so if row_size % 4 != 0: LoLo = tf.pad(LoLo, [[0, 0], [1, 1], [0, 0]], 'SYMMETRIC') # If the col count of LoLo is not divisible by 4 (it will be # divisible by 2), add 2 extra cols to make it so if col_size % 4 != 0: LoLo = tf.pad(LoLo, [[0, 0], [0, 0], [1, 1]], 'SYMMETRIC') # Do even Qshift filters on cols. Lo = coldfilt(LoLo, h0b, h0a) Hi = coldfilt(LoLo, h1b, h1a) if len(self.qshift) >= 12: Ba = coldfilt(LoLo, h2b, h2a) # Do even Qshift filters on rows. 
LoLo = rowdfilt(Lo, h0b, h0a) LoLo_shape = LoLo.get_shape().as_list()[1:3] # Horizontal wavelet pair (15 & 165 degrees) horiz = q2c(rowdfilt(Hi, h0b, h0a)) # Vertical wavelet pair (75 & 105 degrees) vertic = q2c(rowdfilt(Lo, h1b, h1a)) # Diagonal wavelet pair (45 & 135 degrees) if len(self.qshift) >= 12: diag = q2c(rowdfilt(Ba, h2b, h2a)) else: diag = q2c(rowdfilt(Hi, h1b, h1a)) # Pack all 6 tensors into one Yh[level] = tf.stack( [horiz[0], diag[0], vertic[0], vertic[1], diag[1], horiz[1]], axis=3) Yscale[level] = LoLo Yl = LoLo if initial_row_extend == 1 and initial_col_extend == 1: logging.warn('The image entered is now a {0} NOT a {1}.'.format( 'x'.join(list(str(s) for s in extended_size)), 'x'.join(list(str(s) for s in original_size)))) logging.warn( """The bottom row and rightmost column have been duplicated, prior to decomposition.""") if initial_row_extend == 1 and initial_col_extend == 0: logging.warn('The image entered is now a {0} NOT a {1}.'.format( 'x'.join(list(str(s) for s in extended_size)), 'x'.join(list(str(s) for s in original_size)))) logging.warn( 'The bottom row has been duplicated, prior to decomposition.') if initial_row_extend == 0 and initial_col_extend == 1: logging.warn('The image entered is now a {0} NOT a {1}.'.format( 'x'.join(list(str(s) for s in extended_size)), 'x'.join(list(str(s) for s in original_size)))) logging.warn( """The rightmost column has been duplicated, prior to decomposition.""") return Yl, tuple(Yh), tuple(Yscale)
def _forward_ops(self, X, nlevels=3): """ Perform a *n*-level DTCWT-2D decompostion on a 2D matrix *X*. For column inputs, we still need the input shape to be 3D, but with 1 as the last dimension. :param X: 3D real array of size [batch, h, w] :param nlevels: Number of levels of wavelet decomposition :param extended: True if a singleton dimension was added at the beginning of the input. Signal to remove afterwards. :returns: A tuple of Yl, Yh, Yscale """ biort = self.biort qshift = self.qshift # Try to load coefficients if biort is a string parameter try: h0o, g0o, h1o, g1o = _biort(biort) except TypeError: h0o, g0o, h1o, g1o = biort # Try to load coefficients if qshift is a string parameter try: h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift) except TypeError: h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = qshift # Check the shape and form of the input if X.dtype not in tf_dtypes: raise ValueError('X needs to be a tf variable or placeholder') original_size = X.get_shape().as_list()[1:] # ############################ Resize ################################# # The next few lines of code check to see if the image is odd in size, # if so an extra ... row/column will be added to the bottom/right of the # image # initial_row_extend = 0 # initial_col_extend = 0 # If the row count of X is not divisible by 2 then we need to # extend X by adding a row at the bottom if original_size[0] % 2 != 0: # X = tf.pad(X, [[0, 0], [0, 1], [0, 0]], 'SYMMETRIC') raise ValueError('Size of input X must be a multiple of 2') # extended_size = X.get_shape().as_list()[1:] if nlevels == 0: return X, (), () # ########################### Initialise ############################### Yh = [None, ] * nlevels # This is only required if the user specifies a third output # component. Yscale = [None, ] * nlevels # ############################ Level 1 ################################# # Uses the biorthogonal filters if nlevels >= 1: # Do odd top-level filters on cols. 
Hi = colfilter(X, h1o) Lo = colfilter(X, h0o) # Convert Hi to complex form by taking alternate rows Yh[0] = tf.cast(Hi[:,::2,:], tf.complex64) + \ 1j*tf.cast(Hi[:,1::2,:], tf.complex64) Yscale[0] = Lo # ############################ Level 2+ ################################ # Uses the qshift filters for level in xrange(1, nlevels): # If the row count of Lo is not divisible by 4 (it will be # divisible by 2), add 2 extra rows to make it so if Lo.get_shape().as_list()[1] % 4 != 0: Lo = tf.pad(Lo, [[0, 0], [1, 1], [0, 0]], 'SYMMETRIC') # Do even Qshift filters on cols. Hi = coldfilt(Lo, h1b, h1a) Lo = coldfilt(Lo, h0b, h0a) # Convert Hi to complex form by taking alternate rows Yh[level] = tf.cast(Hi[:,::2,:], tf.complex64) + \ 1j * tf.cast(Hi[:,1::2,:], tf.complex64) Yscale[level] = Lo Yl = Lo return Yl, tuple(Yh), tuple(Yscale)