Exemplo n.º 1
0
def wavelet_inv(Yl, Yh, biort='near_sym_b_bp', qshift='qshift_b_bp',
                data_format="nhwc"):
    """ Perform an nlevel inverse dtcwt on the input data.

    Parameters
    ----------
    Yl : :py:class:`tf.Tensor`
        Real tensor holding the lowpass input, laid out according to
        ``data_format``, e.g. (batch, h, w) for "nhw" or (batch, h, w, c) for
        "nhwc" (the default). If the shape has a channel dimension, then c
        inverse dtcwt's will be performed (the other inputs need to also match
        this shape).
    Yh : list(:py:class:`tf.Tensor`)
        A list of length nlevels. Each entry has the high pass for the scales.
        Shape has to match Yl, with a 6 on the end.
    biort : str
        Which biorthogonal filters to use. 'near_sym_b_bp' are my favourite, as
        they have 45° and 135° filters with the same period as the others.
    qshift : str
        Which quarter shift filters to use. These should match up with the
        biorthogonal used. 'qshift_b_bp' are my favourite for the same reason.
    data_format : str
        An optional string of the form "nchw" or "nhwc" (for 4D data), or "nhw"
        or "hwn" (for 3D data). This specifies the data format of the input.
        E.g. If format is "nhwc" (the default), then data is in the form
        [batch, h, w, c]. If the format is "nchw", then the data is in the
        form [batch, channels, h, w].

    Returns
    -------
    X : :py:class:`tf.Tensor`
        The reconstructed output, e.g. of size [batch, h', w'] for 3D data,
        where h' and w' will be larger than h and w by a factor of
        2**nlevels. The layout is produced by
        ``Transform2d.inverse_channels`` for the given ``data_format``.
    """

    with tf.variable_scope('wavelet_inv'):
        # Apply the inverse (inv=True) phase correction to the highpasses
        # before reconstruction.
        Yh = _dtcwt_correct_phases(Yh, inv=True)
        transform = Transform2d(biort=biort, qshift=qshift)
        pyramid = Pyramid(Yl, Yh)
        X = transform.inverse_channels(pyramid, data_format=data_format)

    return X
Exemplo n.º 2
0
    def forward(self, X, nlevels=3, include_scale=False):
        """ Perform a forward transform on an image.

        Can provide the forward transform with either an np array (naive
        usage), or a tensorflow variable or placeholder (designed usage). To
        transform batches of images, use the :py:meth:`forward_channels` method.

        :param ndarray X: Input image which you wish to transform. Can be a
            numpy array, tensorflow Variable or tensorflow placeholder. Inputs
            with fewer than 2 dimensions are promoted to a single-row image
            (numpy inputs via :py:func:`np.atleast_2d`, tensorflow inputs via
            an equivalent reshape). See comments below.
        :param int nlevels: Number of levels of the dtcwt transform to
            calculate.
        :param bool include_scale: Whether or not to return the lowpass results
            at each scale of the transform, or only at the highest scale (as is
            custom for multi-resolution analysis)

        :returns: A :py:class:`dtcwt.tf.Pyramid` object

        :raises ValueError: if *X* is not of a supported numpy/tensorflow
            dtype, or if it has 3 or more dimensions (use
            :py:meth:`forward_channels` for those).

        .. note::

            If a numpy array is provided, the forward function will create a
            tensorflow variable to hold the input image, and then create the
            graph of the right size to match the input, and then feed the
            input into the graph and evaluate it.  This operation will
            return a :py:class:`Pyramid` object similar to how running
            the numpy version would.

        .. codeauthor:: Fergal Cotter <*****@*****.**>, Feb 2017
        .. codeauthor:: Rich Wareham <*****@*****.**>, Aug 2013
        .. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
        .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
        """

        # Check if a numpy array was provided
        numpy = False
        try:
            dtype = X.dtype
        except AttributeError:
            # Objects without a dtype (e.g. nested lists) are converted to a
            # float numpy array first.
            X = asfarray(X)
            dtype = X.dtype

        if dtype in np_dtypes:
            numpy = True
            X = np.atleast_2d(X)
            X = tf.Variable(X, dtype=tf.float32, trainable=False)

        if X.dtype not in tf_dtypes:
            raise ValueError('I cannot handle the variable you have ' +
                             'provided of type ' + str(X.dtype) + '. ' +
                             'Inputs should be a numpy or tf array')

        X_shape = tuple(X.get_shape().as_list())
        if len(X_shape) < 2:
            # Mirror np.atleast_2d (used for numpy inputs above) so 0-d/1-d
            # tensors become a single-row image rather than failing with an
            # IndexError when the scope name is built below.
            X = tf.reshape(X, [1, -1])
        elif len(X_shape) >= 3:
            raise ValueError(
                'The entered variable has too many ' +
                'dimensions - ' + str(X_shape) + '. For batches of ' +
                'images with multiple channels (i.e. 3 or 4 dimensions), ' +
                'please either enter each channel separately, or use ' +
                'the forward_channels method.')

        # Need to make it a batch for tensorflow
        X = tf.expand_dims(X, axis=0)

        X_shape = tuple(X.get_shape().as_list())
        original_size = X_shape[1:]
        size = '{}x{}'.format(original_size[0], original_size[1])
        name = 'dtcwt_fwd_{}'.format(size)
        with tf.variable_scope(name):
            Yl, Yh, Yscale = self._forward_ops(X, nlevels)

        # Strip the batch dimension added above.
        Yl = Yl[0]
        Yh = tuple(x[0] for x in Yh)

        if include_scale:
            # Only build the per-scale slicing ops when they are returned.
            Yscale = tuple(x[0] for x in Yscale)
            return Pyramid(Yl, Yh, Yscale, numpy)
        return Pyramid(Yl, Yh, None, numpy)
Exemplo n.º 3
0
    def forward_channels(self, X, data_format, nlevels=3,
                         include_scale=False):
        """ Perform a forward transform on an image with multiple channels.

        Will perform the DTCWT independently on each channel.

        :param X: Input image which you wish to transform. Can be a numpy
            array (or anything convertible with ``asfarray``) or a tensorflow
            tensor.
        :param int nlevels: Number of levels of the dtcwt transform to
            calculate.
        :param bool include_scale: Whether or not to return the lowpass results
            at each scale of the transform, or only at the highest scale (as is
            custom for multiresolution analysis)
        :param str data_format: An optional string of the form:
            "nhw" (or "chw"), "hwn" (or "hwc"), "nchw" or "nhwc". Note that for
            these strings, 'n' is used to indicate where the batch dimension is,
            'c' is used to indicate where the image channels are, 'h' is used to
            indicate where the row dimension is, and 'w' is used to indicate
            where the columns are. The string is case-insensitive. If the
            data_format is:

                - "nhw" : the input will be interpreted as a batch of 2D images,
                  with the batch dimension as the first.
                - "chw" : will function exactly the same as "nhw" but is offered
                  to indicate the input is a 2D image with channels.
                - "hwn" : the input will be interpreted as a batch of 2D images
                  with the batch dimension as the last.
                - "hwc" : will function exactly the same as "hwn" but is offered
                  to indicate the input is a 2D image with channels.
                - "nchw" : the input is a batch of images with channel dimension
                  as the second dimension. Batch dimension is first.
                - "nhwc" : the input is a batch of images with channel dimension
                  as the last dimension. Batch dimension is first.

        :returns: A :py:class:`dtcwt.tf.Pyramid` object whose tensors are laid
            out to match *data_format*.

        :raises ValueError: if *data_format* is not one of the formats above,
            if *X* is not of a supported numpy/tensorflow dtype, or if the
            number of dimensions of *X* does not match *data_format*.

        .. codeauthor:: Fergal Cotter <*****@*****.**>, Feb 2017
        .. codeauthor:: Rich Wareham <*****@*****.**>, Aug 2013
        .. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
        .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
        """
        data_format = data_format.lower()
        formats_3d = ("nhw", "chw", "hwn", "hwc")
        formats_4d = ("nchw", "nhwc")
        formats = formats_3d + formats_4d
        if data_format not in formats:
            raise ValueError('The data format must be one of: {}'.
                             format(formats))

        try:
            dtype = X.dtype
        except AttributeError:
            X = asfarray(X)
            dtype = X.dtype

        numpy = False
        if dtype in np_dtypes:
            numpy = True
            X = np.atleast_2d(X)
            X = tf.Variable(X, dtype=tf.float32, trainable=False)

        if X.dtype not in tf_dtypes:
            raise ValueError('I cannot handle the variable you have ' +
                             'provided of type ' + str(X.dtype) + '. ' +
                             'Inputs should be a numpy or tf array.')

        X_shape = X.get_shape().as_list()
        if not ((len(X_shape) == 3 and data_format in formats_3d) or
                (len(X_shape) == 4 and data_format in formats_4d)):
            raise ValueError(
                'The entered variable has incorrect shape - ' +
                str(X_shape) + ' for the specified data_format ' +
                data_format + '.')

        # Reshape the inputs to all be 3d inputs of shape (batch, h, w), and
        # record the image (h, w) used to name the transform scope.
        if data_format in formats_4d:
            # Move all of the channels into the batch dimension for the
            # input.  This may involve transposing, depending on the data
            # format
            with tf.variable_scope('ch_to_batch'):
                if data_format == 'nhwc':
                    h, w, nch = X_shape[1:]
                    X = tf.transpose(X, perm=[0, 3, 1, 2])
                else:
                    nch, h, w = X_shape[1:]
                X = tf.reshape(X, [-1, h, w])
        elif data_format in ("hwn", "hwc"):
            h, w = X_shape[:2]
            with tf.variable_scope('ch_to_start'):
                X = tf.transpose(X, perm=[2, 0, 1])
        else:
            h, w = X_shape[1:3]
        # Always name the scope by image size (previously "nchw" used the
        # channel count and height here by mistake).
        name = 'dtcwt_fwd_{}x{}'.format(h, w)

        # Do the dtcwt, now with a 3 dimensional input
        with tf.variable_scope(name):
            Yl, Yh, Yscale = self._forward_ops(X, nlevels)

        # Reshape it all again to match the input
        if data_format in formats_4d:
            # Put the channels back into their correct positions
            with tf.variable_scope('batch_to_ch'):
                # Reshape Yl
                s = Yl.get_shape().as_list()[1:]
                Yl = tf.reshape(Yl, [-1, nch, s[0], s[1]], name='Yl_reshape')
                if data_format == 'nhwc':
                    Yl = tf.transpose(Yl, [0, 2, 3, 1], name='Yl_ch_to_end')

                # Reshape Yh (each entry carries a trailing orientation axis)
                with tf.variable_scope('Yh'):
                    Yh_new = [None] * nlevels
                    for i in range(nlevels):
                        s = Yh[i].get_shape().as_list()[1:]
                        Yh_new[i] = tf.reshape(
                            Yh[i], [-1, nch, s[0], s[1], s[2]],
                            name='scale{}_reshape'.format(i))
                        if data_format == 'nhwc':
                            Yh_new[i] = tf.transpose(
                                Yh_new[i], [0, 2, 3, 1, 4],
                                name='scale{}_ch_to_end'.format(i))
                    Yh = tuple(Yh_new)

                # Reshape Yscale
                if include_scale:
                    with tf.variable_scope('Yscale'):
                        Yscale_new = [None] * nlevels
                        for i in range(nlevels):
                            s = Yscale[i].get_shape().as_list()[1:]
                            Yscale_new[i] = tf.reshape(
                                Yscale[i], [-1, nch, s[0], s[1]],
                                name='scale{}_reshape'.format(i))
                            if data_format == 'nhwc':
                                Yscale_new[i] = tf.transpose(
                                    Yscale_new[i], [0, 2, 3, 1],
                                    name='scale{}_ch_to_end'.format(i))
                        Yscale = tuple(Yscale_new)

        elif data_format in ("hwn", "hwc"):
            # Move the batch dimension back to the end
            with tf.variable_scope('ch_to_end'):
                Yl = tf.transpose(Yl, perm=[1, 2, 0], name='Yl')
                Yh = tuple(
                    tf.transpose(x, [1, 2, 0, 3], name='Yh{}'.format(i))
                    for i, x in enumerate(Yh))
                if include_scale:
                    Yscale = tuple(
                        tf.transpose(x, [1, 2, 0], name='Yscale{}'.format(i))
                        for i, x in enumerate(Yscale))

        # Return the pyramid
        if include_scale:
            return Pyramid(Yl, Yh, Yscale, numpy)
        return Pyramid(Yl, Yh, None, numpy)
Exemplo n.º 4
0
    def forward(self, X, nlevels=3, include_scale=False):
        """Perform a *n*-level DTCWT decomposition on a 1D column vector *X* (or
        on the columns of a matrix *X*).

        Can provide the forward transform with either an np array (naive usage),
        or a tensorflow variable or placeholder (designed usage). To transform
        batches of vectors, use the :py:meth:`forward_channels` method.

        :param X: 1D real array or 2D real array whose columns are to be
            transformed. A 1D input is treated as a single column.
        :param nlevels: Number of levels of wavelet decomposition
        :param bool include_scale: Whether or not to return the lowpass results
            at each scale of the transform, or only at the highest scale (as is
            custom for multi-resolution analysis)

        :returns: A :py:class:`dtcwt.tf.Pyramid` object representing the
            transform result.

        :raises ValueError: if *X* is not of a supported numpy/tensorflow
            dtype, or if it has 3 or more dimensions.

        This method takes no *biort* or *qshift* arguments; the filter choice
        belongs to the transform object itself (the work is delegated to
        ``self._forward_ops``).

        .. codeauthor:: Fergal Cotter <*****@*****.**>, Sep 2017
        .. codeauthor:: Rich Wareham <*****@*****.**>, Aug 2013
        .. codeauthor:: Nick Kingsbury, Cambridge University, May 2002
        .. codeauthor:: Cian Shaffrey, Cambridge University, May 2002

        """
        # Check if a numpy array was provided
        numpy = False
        try:
            dtype = X.dtype
        except AttributeError:
            X = asfarray(X)
            dtype = X.dtype

        if dtype in np_dtypes:
            numpy = True
            # The column filters expect 2D input, so a 1D vector becomes a
            # single column.
            if len(X.shape) == 1:
                X = np.atleast_2d(X).T
            X = tf.Variable(X, dtype=tf.float32, trainable=False)
        elif dtype in tf_dtypes:
            if len(X.get_shape().as_list()) == 1:
                X = tf.expand_dims(X, axis=-1)
        else:
            raise ValueError('I cannot handle the variable you have ' +
                             'provided of type ' + str(X.dtype) + '. ' +
                             'Inputs should be a numpy or tf array')

        X_shape = tuple(X.get_shape().as_list())
        # Scope is named after the length of the transformed (column) axis.
        size = '{}'.format(X_shape[0])
        name = 'dtcwt_fwd_{}'.format(size)
        if len(X_shape) == 2:
            # Need to make it a batch for tensorflow
            X = tf.expand_dims(X, axis=0)
        elif len(X_shape) >= 3:
            raise ValueError(
                'The entered variable has too many ' +
                'dimensions - ' + str(X_shape) + '.')

        # Do the forward transform
        with tf.variable_scope(name):
            Yl, Yh, Yscale = self._forward_ops(X, nlevels)

        # Strip the batch dimension added above.
        Yl = Yl[0]
        Yh = tuple(x[0] for x in Yh)

        if include_scale:
            # Only build the per-scale slicing ops when they are returned.
            Yscale = tuple(x[0] for x in Yscale)
            return Pyramid(Yl, Yh, Yscale, numpy)
        return Pyramid(Yl, Yh, None, numpy)
Exemplo n.º 5
0
    def forward_channels(self, X, nlevels=3, include_scale=False):
        """Perform a *n*-level DTCWT decomposition on a 3D array *X*.

        Can provide the forward transform with either an np array (naive usage),
        or a tensorflow variable or placeholder (designed usage).

        :param X: 3D real array. Batch of matrices whose columns are to be
            transformed (i.e. the second dimension).
        :param nlevels: Number of levels of wavelet decomposition
        :param bool include_scale: Whether or not to return the lowpass results
            at each scale of the transform, or only at the highest scale (as is
            custom for multi-resolution analysis)

        :returns: A :py:class:`dtcwt.tf.Pyramid` object representing the
            transform result.

        :raises ValueError: if *X* is not of a supported numpy/tensorflow
            dtype, or if it does not have exactly 3 dimensions (use
            :py:meth:`forward` for 1D/2D inputs).

        This method takes no *biort* or *qshift* arguments; the filter choice
        belongs to the transform object itself (the work is delegated to
        ``self._forward_ops``).

        .. codeauthor:: Fergal Cotter <*****@*****.**>, Sep 2017
        .. codeauthor:: Rich Wareham <*****@*****.**>, Aug 2013
        .. codeauthor:: Nick Kingsbury, Cambridge University, May 2002
        .. codeauthor:: Cian Shaffrey, Cambridge University, May 2002

        """
        # Check if a numpy array was provided
        numpy = False
        try:
            dtype = X.dtype
        except AttributeError:
            X = asfarray(X)
            dtype = X.dtype

        if dtype in np_dtypes:
            numpy = True
            X_shape = X.shape
        elif dtype in tf_dtypes:
            X_shape = X.get_shape().as_list()
        else:
            raise ValueError('I cannot handle the variable you have ' +
                             'provided of type ' + str(X.dtype) + '. ' +
                             'Inputs should be a numpy or tf array')

        # Validate the rank before creating any tensorflow variable.
        if len(X_shape) != 3:
            raise ValueError(
                'Incorrect input shape for the forward_channels ' +
                'method ' + str(X_shape) + '. For Inputs of 1 or 2 ' +
                'dimensions, use the forward method.')

        if numpy:
            X = tf.Variable(X, dtype=tf.float32, trainable=False)

        X_shape = tuple(X.get_shape().as_list())
        # Scope is named after the length of the transformed (second) axis.
        size = '{}'.format(X_shape[1])
        name = 'dtcwt_fwd_{}'.format(size)

        # Do the forward transform
        with tf.variable_scope(name):
            Yl, Yh, Yscale = self._forward_ops(X, nlevels)

        if include_scale:
            return Pyramid(Yl, Yh, Yscale, numpy)
        return Pyramid(Yl, Yh, None, numpy)