Example #1
0
    def _build_tf_graph(self):
        """Build the TensorFlow graph for Richardson-Lucy deconvolution.

        The graph takes the data and kernel as placeholders of rank ``self.n_dims`` and
        dtype ``self.dtype``, plus optional run-time options (``pad_mode``, ``start_mode``,
        ``pad_min``) whose defaults match the values used at construction time. It
        validates those options with assertion ops, pads the data (optionally up to
        FFT-friendly sizes), pads the kernel to the padded data shape, runs ``niter``
        Richardson-Lucy iterations inside a ``tf.while_loop`` and finally crops the
        estimate back to the input data shape.

        Returns:
            tuple: ``(inputs, outputs)``. ``inputs`` maps ``'niter'``, ``'data'``,
            ``'kernel'``, ``'pad_mode'``, ``'pad_min'`` and ``'start_mode'`` to the
            iteration count and the placeholders; ``outputs`` maps ``'result'`` to the
            (post-processed) deconvolved image and the remaining keys to diagnostic
            shape/option tensors.
        """
        niter = self._get_niter()

        # Create argument placeholders with same defaults as those used at graph construction time
        padmodh = tf.compat.v1.placeholder_with_default(DEFAULT_PAD_MODE, (), name='pad_mode')
        smodeh = tf.compat.v1.placeholder_with_default(DEFAULT_START_MODE, (), name='start_mode')
        padminh = tf.compat.v1.placeholder_with_default(
            tf.zeros(self.n_dims, dtype=tf.int32), self.n_dims, name='pad_min')

        # Data and kernel should have shapes (z, height, width)
        dataph = tf.compat.v1.placeholder(self.dtype, shape=[None] * self.n_dims, name='data')
        kernph = tf.compat.v1.placeholder(self.dtype, shape=[None] * self.n_dims, name='kernel')
        datah, kernh = self._wrap_input(dataph), self._wrap_input(kernph)

        # Add assertion operations to validate padding mode, start mode, and data/kernel dimensions
        flag_pad_mode = tf.stack(
            [tf.equal(padmodh, OPM_LOG2), tf.equal(padmodh, OPM_2357), tf.equal(padmodh, OPM_NONE)],
            axis=0)
        assert_pad_mode = tf.compat.v1.assert_greater(
            tf.reduce_sum(tf.cast(flag_pad_mode, tf.int32)), 0,
            message='Pad mode not valid', data=[padmodh])

        flag_start_mode = tf.stack(
            [tf.equal(smodeh, SMODE_CONSTANT), tf.equal(smodeh, SMODE_INPUT)], axis=0)
        assert_start_mode = tf.compat.v1.assert_greater(
            tf.reduce_sum(tf.cast(flag_start_mode, tf.int32)), 0,
            message='Start mode not valid', data=[smodeh])

        # Every data dimension must be >= the matching kernel dimension. The comparison is
        # element-wise on purpose: summing the per-axis differences would let a kernel that
        # is too large along one axis pass whenever another axis had enough slack.
        data_shape, kern_shape = tf.shape(datah), tf.shape(kernh)
        assert_shapes = tf.compat.v1.assert_greater_equal(
            data_shape, kern_shape,
            message='Data shape must be >= kernel shape', data=[data_shape, kern_shape])

        with tf.control_dependencies([assert_pad_mode, assert_start_mode, assert_shapes]):

            # If configured to do so, expand dimensions of data array to power of 2 or
            # prime factor multiples (after adding a minimum padding as well, if given)
            # to avoid use of Bluestein algorithm in favor of significantly faster Cooley-Tukey FFT
            pad_shape = tf.shape(datah) + padminh

            def pad_data(target_shape):
                # Pad the (wrapped) data around its center to `target_shape` using the configured fill
                return pad_around_center(datah, target_shape, mode=self.pad_fill)

            datat = tf.cond(
                tf.equal(padmodh, OPM_2357),
                lambda: pad_data(optimize_dims(pad_shape, OPM_2357)),
                lambda: tf.cond(
                    tf.equal(padmodh, OPM_LOG2),
                    lambda: pad_data(optimize_dims(pad_shape, OPM_LOG2)),
                    lambda: pad_data(pad_shape)))

            # Pad kernel (with zeros only) to equal dimensions of data tensor and run "circular"
            # transformation as this algorithm is based on circular convolutions and the results
            # will have half spaces swapped otherwise
            kernt = tf.cast(ifftshift(pad_around_center(kernh, tf.shape(datat))), self.fft_dtype)

        # Infer available TF FFT functions based on predefined number of data dimensions
        # TODO: Find a way to determine dimensionality of images separately from batch dimension and
        # update the rank used to get fft fns excluding batch dim
        fft_fwd, fft_rev = fft_utils_tf.get_fft_tf_fns(min(self.n_dims, 3), real_domain_only=self.real_domain_fft)

        # Determine intermediate kernel representation necessary based on domain specified to
        # carry out computations: the "conjugate" kernel is the spatially flipped kernel, obtained
        # by reversing all axes in the real domain or by complex conjugation of the OTF otherwise
        kern_fft = fft_fwd(kernt)
        if self.real_domain_fft:
            kern_fft_conj = fft_fwd(tf.reverse(kernt, axis=tf.range(0, self.n_dims)))
        else:
            kern_fft_conj = tf.math.conj(kern_fft)

        # Initialize resulting deconvolved image -- there are several sensible choices for this like the
        # original image or constant arrays, but some experiments show this to be better, and other
        # implementations doing the same are "Basic Matlab" and "Scikit-Image" (see class notes for links)
        decon = tf.cond(
            tf.equal(smodeh, SMODE_CONSTANT),
            lambda: tf.identity(.5 * tf.ones_like(datat, dtype=self.dtype), name='deconvolution'),
            # Multiplication used here to avoid https://github.com/tensorflow/tensorflow/issues/11186
            lambda: tf.identity(datat * tf.ones_like(datat, dtype=self.dtype), name='deconvolution')
        )

        def cond(i, decon):
            # Iterations are counted from 1, so the loop runs `niter` times
            return i <= niter

        def conv(data, kernel_fft):
            # Circular convolution via the FFT; only the real part is kept
            return tf.math.real(fft_rev(fft_fwd(tf.cast(data, self.fft_dtype)) * kernel_fft))

        def body(i, decon):
            # Richardson-Lucy Iteration - logic taken largely from a combination of
            # the scikit-image (real domain) and DeconvolutionLab2 implementations (complex domain)
            conv1 = conv(decon, kern_fft)

            # High-pass filter to avoid division by very small numbers (see DeconvolutionLab2)
            blur1 = tf.where(conv1 < self.epsilon, tf.zeros_like(datat), datat / conv1, name='blur1')

            conv2 = conv(blur1, kern_fft_conj)

            # Positivity constraint on result for iteration
            decon = tf.maximum(decon * conv2, 0.)

            # If given an "observer", pass the current image restoration and iteration counter to it
            if self.observer_fn is not None:
                # Remove any cropping that may have been added as this is usually not desirable in observers
                decon_crop = unpad_around_center(decon, tf.shape(datah))
                _, i, decon = tf_observer([decon_crop, i, decon], self.observer_fn)

            return i + 1, decon

        result = tf.while_loop(cond, body, [1, decon], parallel_iterations=1)[1]

        # Crop off any padding that may have been added to reach more efficient dimension sizes
        result = unpad_around_center(result, tf.shape(datah))

        # Wrap output in configured post-processing functions (if any)
        result = tf.identity(self._wrap_output(result, {'data': datah, 'kernel': kernh}), name='result')

        inputs = {
            'niter': niter, 'data': dataph, 'kernel': kernph,
            'pad_mode': padmodh, 'pad_min': padminh, 'start_mode': smodeh
        }
        outputs = {
            'result': result,
            'data_shape': tf.shape(datah), 'kern_shape': tf.shape(kernh),
            'pad_shape': pad_shape, 'pad_mode': padmodh,
            'datat_shape': tf.shape(datat),
            'pad_min': padminh, 'start_mode': smodeh,
        }

        return inputs, outputs
Example #2
0
    def _build_tf_graph(self):
        """Build the TensorFlow graph for Gold (ratio-method) deconvolution.

        The graph takes the data and kernel as placeholders of rank ``self.n_dims`` and
        dtype ``self.dtype``, plus optional run-time options (``pad_mode``, ``start_mode``,
        ``pad_min``) whose defaults match the values used at construction time. It
        validates those options with assertion ops, pads the data (optionally up to
        FFT-friendly sizes), pads the kernel to the padded data shape and then runs
        ``niter`` iterations of the multiplicative Gold update inside a ``tf.while_loop``::

            decon <- max(decon * (data_n + delta) / (blur(decon)_n + delta), 0)

        where ``_n`` denotes normalisation to unit sum. When ``self.n_dims == 3`` the
        estimate is additionally smoothed with a centred Gaussian (sigma 1) every 5th
        iteration to limit noise build-up. The estimate is renormalised to unit sum after
        every iteration, so the result is not on the input intensity scale (see TODO below).

        Returns:
            tuple: ``(inputs, outputs)``. ``inputs`` maps ``'niter'``, ``'data'``,
            ``'kernel'``, ``'pad_mode'``, ``'pad_min'`` and ``'start_mode'`` to the
            iteration count and the placeholders; ``outputs`` maps ``'result'`` to the
            (post-processed) deconvolved image and the remaining keys to diagnostic
            shape/option tensors.
        """
        niter = self._get_niter()

        # Create argument placeholders with same defaults as those used at graph construction time
        padmodh = tf.compat.v1.placeholder_with_default(DEFAULT_PAD_MODE, (), name='pad_mode')
        smodeh = tf.compat.v1.placeholder_with_default(DEFAULT_START_MODE, (), name='start_mode')
        padminh = tf.compat.v1.placeholder_with_default(
            tf.zeros(self.n_dims, dtype=tf.int32), self.n_dims, name='pad_min')

        # Data and kernel should have shapes (z, height, width)
        dataph = tf.compat.v1.placeholder(self.dtype, shape=[None] * self.n_dims, name='data')
        kernph = tf.compat.v1.placeholder(self.dtype, shape=[None] * self.n_dims, name='kernel')
        datah, kernh = self._wrap_input(dataph), self._wrap_input(kernph)

        # Add assertion operations to validate padding mode, start mode, and data/kernel dimensions
        flag_pad_mode = tf.stack(
            [tf.equal(padmodh, OPM_LOG2), tf.equal(padmodh, OPM_2357), tf.equal(padmodh, OPM_NONE)],
            axis=0)
        assert_pad_mode = tf.compat.v1.assert_greater(
            tf.reduce_sum(tf.cast(flag_pad_mode, tf.int32)), 0,
            message='Pad mode not valid', data=[padmodh])

        flag_start_mode = tf.stack(
            [tf.equal(smodeh, SMODE_CONSTANT), tf.equal(smodeh, SMODE_INPUT)], axis=0)
        assert_start_mode = tf.compat.v1.assert_greater(
            tf.reduce_sum(tf.cast(flag_start_mode, tf.int32)), 0,
            message='Start mode not valid', data=[smodeh])

        # Every data dimension must be >= the matching kernel dimension. The comparison is
        # element-wise on purpose: summing the per-axis differences would let a kernel that
        # is too large along one axis pass whenever another axis had enough slack.
        data_shape, kern_shape = tf.shape(datah), tf.shape(kernh)
        assert_shapes = tf.compat.v1.assert_greater_equal(
            data_shape, kern_shape,
            message='Data shape must be >= kernel shape', data=[data_shape, kern_shape])

        with tf.control_dependencies([assert_pad_mode, assert_start_mode, assert_shapes]):

            # If configured to do so, expand dimensions of data array to power of 2 or
            # prime factor multiples (after adding a minimum padding as well, if given)
            # to avoid use of Bluestein algorithm in favor of significantly faster Cooley-Tukey FFT
            pad_shape = tf.shape(datah) + padminh

            def pad_data(target_shape):
                # Pad the (wrapped) data around its center to `target_shape` using the configured fill
                return pad_around_center(datah, target_shape, mode=self.pad_fill)

            datat = tf.cond(
                tf.equal(padmodh, OPM_2357),
                lambda: pad_data(optimize_dims(pad_shape, OPM_2357)),
                lambda: tf.cond(
                    tf.equal(padmodh, OPM_LOG2),
                    lambda: pad_data(optimize_dims(pad_shape, OPM_LOG2)),
                    lambda: pad_data(pad_shape)))

            # Pad kernel (with zeros only) to equal dimensions of data tensor and run "circular"
            # transformation as this algorithm is based on circular convolutions and the results
            # will have half spaces swapped otherwise
            kernt = tf.cast(ifftshift(pad_around_center(kernh, tf.shape(datat))), self.fft_dtype)

        # Infer available TF FFT functions based on predefined number of data dimensions
        # TODO: Find a way to determine dimensionality of images separately from batch dimension and
        # update the rank used to get fft fns excluding batch dim
        fft_fwd, fft_rev = fft_utils_tf.get_fft_tf_fns(min(self.n_dims, 3), real_domain_only=self.real_domain_fft)

        # The Gold update only blurs with the forward OTF; unlike Richardson-Lucy it needs no
        # flipped/conjugate kernel
        kern_fft = fft_fwd(kernt)

        # Initialize resulting deconvolved image -- there are several sensible choices for this like the
        # original image or constant arrays, but some experiments show this to be better, and other
        # implementations doing the same are "Basic Matlab" and "Scikit-Image" (see class notes for links)
        decon = tf.cond(
            tf.equal(smodeh, SMODE_CONSTANT),
            lambda: tf.identity(.5 * tf.ones_like(datat, dtype=self.dtype), name='deconvolution'),
            # Multiplication used here to avoid https://github.com/tensorflow/tensorflow/issues/11186
            lambda: tf.identity(datat * tf.ones_like(datat, dtype=self.dtype), name='deconvolution'))

        # The unit-sum observed data does not change between iterations, so compute it once
        datat_norm = datat / tf.math.reduce_sum(datat)

        # Offset added to numerator and denominator of the ratio instead of a high-pass cut-off
        # (as per Stephan Ludwig et al. 2019); this value seems to work well for images normalised
        # to a sum of 1
        delta_param = 1e-4

        def cond(i, decon):
            # Iterations are counted from 1, so the loop runs `niter` times
            return i <= niter

        def conv(input_data, kernel_fft):
            # Circular convolution via the FFT; only the real part is kept
            return tf.math.real(fft_rev(fft_fwd(tf.cast(input_data, self.fft_dtype)) * kernel_fft))

        def gaussian_kernel(size: int, mean: float, std: float):
            """Return a 3D Gaussian kernel of shape (2*size+1,)*3 that sums to 1.

            The separable profile exp(-(x - mean)^2 / (2 std^2)) is sampled at the integer
            offsets -size..size and its outer product is divided by its sum, which also makes
            the usual 1/(std*sqrt(2*pi)) density factor unnecessary.
            """
            x = tf.range(start=-size, limit=size + 1, dtype=self.dtype)
            vals = tf.exp(-0.5 * tf.square((x - mean) / std))
            gauss_kernel = tf.einsum('i,j,k->ijk', vals, vals, vals)
            return gauss_kernel / tf.reduce_sum(gauss_kernel)

        # Periodic smoothing of the intermediate estimate with a Gaussian of sigma 1 controls the
        # noise build-up that the Gold method is susceptible to. The kernel must be centred
        # (mean 0) or every smoothing step would shift the image; a radius of 4 covers +/- 4 sigma.
        # tf.nn.conv3d only handles 3D volumes, so smoothing is skipped for other ranks.
        smooth_every = 5
        smooth_kernel = None
        if self.n_dims == 3:
            # conv3d filters are 5D: (depth, height, width, in_channels, out_channels)
            smooth_kernel = gaussian_kernel(4, 0.0, 1.0)[:, :, :, tf.newaxis, tf.newaxis]

        def smooth(image):
            """Convolve a 3D image with `smooth_kernel` (zero 'SAME' padding), keeping its shape."""
            # conv3d input is 5D: (batch, depth, height, width, channels)
            image5d = image[tf.newaxis, ..., tf.newaxis]
            smoothed = tf.nn.conv3d(image5d, smooth_kernel, strides=[1, 1, 1, 1, 1], padding='SAME')
            return smoothed[0, ..., 0]

        def body(i, decon):
            """Run one Gold iteration: ratio update, optional smoothing, renormalisation."""
            # conv1 is the current model blurred with the PSF
            conv1 = conv(decon, kern_fft)

            # Normalise the blurred model to unit sum (keeps values bounded so the repeated
            # multiplication cannot overflow) and form the regularised ratio to the data
            conv1_norm = conv1 / tf.math.reduce_sum(conv1)
            ratio = (datat_norm + delta_param) / (conv1_norm + delta_param)

            # Multiplicative update with positivity constraint on result for iteration
            decon = tf.maximum(decon * ratio, 0.)

            # `i` is a tensor inside the while loop, so the "every 5th iteration" decision must be
            # made in-graph with tf.cond; a Python `if` on it cannot select per iteration
            if smooth_kernel is not None:
                decon = tf.cond(tf.equal(i % smooth_every, 0),
                                lambda: smooth(decon),
                                lambda: decon)

            # normalise the result so the sum of the data is 1
            decon = decon / tf.math.reduce_sum(decon)

            # TODO rescale back to input data sum intensity -  probably need to adjust delta_param too.

            # If given an "observer", pass the current image restoration and iteration counter to it
            if self.observer_fn is not None:
                # Remove any cropping that may have been added as this is usually not desirable in observers
                decon_crop = unpad_around_center(decon, tf.shape(datah))
                # normalise the result so the sum of the data is 1
                decon_crop = decon_crop / tf.math.reduce_sum(decon_crop)
                # we can use these captured observed tensors to evaluate eg convergence
                # in eg. the observer function used.
                _, i, decon, conv1 = tf_observer([decon_crop, i, decon, conv1], self.observer_fn)

            return i + 1, decon

        result = tf.while_loop(cond, body, [1, decon], parallel_iterations=1)[1]

        # Crop off any padding that may have been added to reach more efficient dimension sizes
        result = unpad_around_center(result, tf.shape(datah))

        # Wrap output in configured post-processing functions (if any)
        result = tf.identity(self._wrap_output(result, {'data': datah, 'kernel': kernh}), name='result')

        inputs = {
            'niter': niter,
            'data': dataph,
            'kernel': kernph,
            'pad_mode': padmodh,
            'pad_min': padminh,
            'start_mode': smodeh
        }
        outputs = {
            'result': result,
            'data_shape': tf.shape(datah),
            'kern_shape': tf.shape(kernh),
            'pad_shape': pad_shape,
            'pad_mode': padmodh,
            'datat_shape': tf.shape(datat),
            'pad_min': padminh,
            'start_mode': smodeh,
        }

        return inputs, outputs