def _build_tf_graph(self):
    """Build the TensorFlow graph for Richardson-Lucy deconvolution.

    The graph takes a data tensor and a kernel (PSF) tensor of rank
    ``self.n_dims`` plus optional padding / start-mode arguments, validates
    them with assertion ops, pads both to a common (optionally FFT-friendly)
    shape and runs the Richardson-Lucy update inside a ``tf.while_loop`` for
    ``niter`` iterations (loop counter runs from 1 while ``i <= niter``).

    Returns:
        tuple: ``(inputs, outputs)`` where ``inputs`` maps argument names to
        the graph's placeholders / tensors and ``outputs`` maps names to the
        result tensor and several diagnostic shape / mode tensors.
    """
    niter = self._get_niter()

    # Create argument placeholders with same defaults as those used at graph construction time
    padmodh = tf.compat.v1.placeholder_with_default(DEFAULT_PAD_MODE, (), name='pad_mode')
    smodeh = tf.compat.v1.placeholder_with_default(DEFAULT_START_MODE, (), name='start_mode')
    padminh = tf.compat.v1.placeholder_with_default(
        tf.zeros(self.n_dims, dtype=tf.int32), self.n_dims, name='pad_min')

    # Data and kernel should have shapes (z, height, width)
    dataph = tf.compat.v1.placeholder(self.dtype, shape=[None] * self.n_dims, name='data')
    kernph = tf.compat.v1.placeholder(self.dtype, shape=[None] * self.n_dims, name='kernel')
    datah, kernh = self._wrap_input(dataph), self._wrap_input(kernph)

    # Add assertion operations to validate padding mode, start mode, and data/kernel dimensions
    flag_pad_mode = tf.stack([
        tf.equal(padmodh, OPM_LOG2),
        tf.equal(padmodh, OPM_2357),
        tf.equal(padmodh, OPM_NONE),
    ], axis=0)
    assert_pad_mode = tf.compat.v1.assert_greater(
        tf.reduce_sum(tf.cast(flag_pad_mode, tf.int32)), 0,
        message='Pad mode not valid', data=[padmodh])

    flag_start_mode = tf.stack(
        [tf.equal(smodeh, SMODE_CONSTANT), tf.equal(smodeh, SMODE_INPUT)], axis=0)
    assert_start_mode = tf.compat.v1.assert_greater(
        tf.reduce_sum(tf.cast(flag_start_mode, tf.int32)), 0,
        message='Start mode not valid', data=[smodeh])

    # The kernel is padded up to the (padded) data shape further down, so EVERY axis of the
    # data must be at least as large as the kernel's. Check the minimum per-axis difference:
    # a sum would let a deficit along one axis be masked by a surplus along another.
    flag_shapes = tf.shape(datah) - tf.shape(kernh)
    assert_shapes = tf.compat.v1.assert_greater_equal(
        tf.reduce_min(flag_shapes), 0,
        message='Data shape must be >= kernel shape',
        data=[tf.shape(datah), tf.shape(kernh)])

    with tf.control_dependencies([assert_pad_mode, assert_start_mode, assert_shapes]):
        # If configured to do so, expand dimensions of data array to power of 2 or
        # prime factor multiples (after adding a minimum padding as well, if given)
        # to avoid use of Bluestein algorithm in favor of significantly faster Cooley-Tukey FFT
        pad_shape = tf.shape(datah) + padminh
        datat = tf.cond(
            tf.equal(padmodh, OPM_2357),
            lambda: pad_around_center(
                datah, optimize_dims(pad_shape, OPM_2357), mode=self.pad_fill),
            lambda: tf.cond(
                tf.equal(padmodh, OPM_LOG2),
                lambda: pad_around_center(
                    datah, optimize_dims(pad_shape, OPM_LOG2), mode=self.pad_fill),
                lambda: pad_around_center(datah, pad_shape, mode=self.pad_fill)))

    # Pad kernel (with zeros only) to equal dimensions of data tensor and run "circular"
    # transformation as this algorithm is based on circular convolutions and the results
    # will have half spaces swapped otherwise
    kernt = tf.cast(ifftshift(pad_around_center(kernh, tf.shape(datat))), self.fft_dtype)

    # Infer available TF FFT functions based on predefined number of data dimensions
    # TODO: Find a way to determine dimensionality of images separately from batch dimension and
    # update the rank used to get fft fns excluding batch dim
    fft_fwd, fft_rev = fft_utils_tf.get_fft_tf_fns(
        min(self.n_dims, 3), real_domain_only=self.real_domain_fft)

    # Determine intermediate kernel representation necessary based on domain specified to
    # carry out computations: in the real domain the "conjugate" is the transform of the
    # kernel reversed along every axis, otherwise it is the complex conjugate.
    kern_fft = fft_fwd(kernt)
    if self.real_domain_fft:
        kern_fft_conj = fft_fwd(tf.reverse(kernt, axis=tf.range(0, self.n_dims)))
    else:
        kern_fft_conj = tf.math.conj(kern_fft)

    # Initialize resulting deconvolved image -- there are several sensible choices for this like the
    # original image or constant arrays, but some experiments show this to be better, and other
    # implementations doing the same are "Basic Matlab" and "Scikit-Image" (see class notes for links)
    decon = tf.cond(
        tf.equal(smodeh, SMODE_CONSTANT),
        lambda: tf.identity(.5 * tf.ones_like(datat, dtype=self.dtype), name='deconvolution'),
        # Multiplication used here to avoid https://github.com/tensorflow/tensorflow/issues/11186
        lambda: tf.identity(datat * tf.ones_like(datat, dtype=self.dtype), name='deconvolution'))

    def cond(i, decon):
        return i <= niter

    def conv(image, kernel_fft):
        """Circularly convolve ``image`` with the kernel whose transform is ``kernel_fft``."""
        return tf.math.real(fft_rev(fft_fwd(tf.cast(image, self.fft_dtype)) * kernel_fft))

    def body(i, decon):
        # Richardson-Lucy Iteration - logic taken largely from a combination of
        # the scikit-image (real domain) and DeconvolutionLab2 implementations (complex domain)
        conv1 = conv(decon, kern_fft)

        # Zero the ratio wherever the re-blurred estimate is below epsilon to avoid
        # division by very small numbers (see DeconvolutionLab2)
        blur1 = tf.where(conv1 < self.epsilon, tf.zeros_like(datat), datat / conv1, name='blur1')
        conv2 = conv(blur1, kern_fft_conj)

        # Positivity constraint on result for iteration
        decon = tf.maximum(decon * conv2, 0.)

        # If given an "observer", pass the current image restoration and iteration counter to it
        if self.observer_fn is not None:
            # Remove any cropping that may have been added as this is usually not desirable in observers
            decon_crop = unpad_around_center(decon, tf.shape(datah))
            _, i, decon = tf_observer([decon_crop, i, decon], self.observer_fn)

        return i + 1, decon

    result = tf.while_loop(cond, body, [1, decon], parallel_iterations=1)[1]

    # Crop off any padding that may have been added to reach more efficient dimension sizes
    result = unpad_around_center(result, tf.shape(datah))

    # Wrap output in configured post-processing functions (if any)
    result = tf.identity(
        self._wrap_output(result, {'data': datah, 'kernel': kernh}), name='result')

    inputs = {
        'niter': niter,
        'data': dataph,
        'kernel': kernph,
        'pad_mode': padmodh,
        'pad_min': padminh,
        'start_mode': smodeh,
    }
    outputs = {
        'result': result,
        'data_shape': tf.shape(datah),
        'kern_shape': tf.shape(kernh),
        'pad_shape': pad_shape,
        'pad_mode': padmodh,
        'datat_shape': tf.shape(datat),
        'pad_min': padminh,
        'start_mode': smodeh,
    }
    return inputs, outputs
def _build_tf_graph(self):
    """Build the TensorFlow graph for Gold-algorithm (ratio method) deconvolution.

    The graph takes a data tensor and a kernel (PSF) tensor of rank
    ``self.n_dims`` plus optional padding / start-mode arguments, validates
    them with assertion ops, pads both to a common (optionally FFT-friendly)
    shape and runs ``niter`` Gold iterations inside a ``tf.while_loop``.

    Each iteration multiplies the current estimate by the ratio
    ``(data_norm + delta) / (blurred_estimate_norm + delta)`` of unit-sum
    normalised images (delta-regularised, after Stephan Ludwig et al. 2019)
    and renormalises the estimate to unit sum. For 3D data the estimate is
    additionally smoothed with a Gaussian (sigma 1) every 5th iteration to
    control the noise build-up the Gold method is susceptible to.

    Returns:
        tuple: ``(inputs, outputs)`` where ``inputs`` maps argument names to
        the graph's placeholders / tensors and ``outputs`` maps names to the
        result tensor and several diagnostic shape / mode tensors.
    """
    niter = self._get_niter()

    # Create argument placeholders with same defaults as those used at graph construction time
    padmodh = tf.compat.v1.placeholder_with_default(DEFAULT_PAD_MODE, (), name='pad_mode')
    smodeh = tf.compat.v1.placeholder_with_default(DEFAULT_START_MODE, (), name='start_mode')
    padminh = tf.compat.v1.placeholder_with_default(
        tf.zeros(self.n_dims, dtype=tf.int32), self.n_dims, name='pad_min')

    # Data and kernel should have shapes (z, height, width)
    dataph = tf.compat.v1.placeholder(self.dtype, shape=[None] * self.n_dims, name='data')
    kernph = tf.compat.v1.placeholder(self.dtype, shape=[None] * self.n_dims, name='kernel')
    datah, kernh = self._wrap_input(dataph), self._wrap_input(kernph)

    # Add assertion operations to validate padding mode, start mode, and data/kernel dimensions
    flag_pad_mode = tf.stack([
        tf.equal(padmodh, OPM_LOG2),
        tf.equal(padmodh, OPM_2357),
        tf.equal(padmodh, OPM_NONE),
    ], axis=0)
    assert_pad_mode = tf.compat.v1.assert_greater(
        tf.reduce_sum(tf.cast(flag_pad_mode, tf.int32)), 0,
        message='Pad mode not valid', data=[padmodh])

    flag_start_mode = tf.stack(
        [tf.equal(smodeh, SMODE_CONSTANT), tf.equal(smodeh, SMODE_INPUT)], axis=0)
    assert_start_mode = tf.compat.v1.assert_greater(
        tf.reduce_sum(tf.cast(flag_start_mode, tf.int32)), 0,
        message='Start mode not valid', data=[smodeh])

    # The kernel is padded up to the (padded) data shape further down, so EVERY axis of the
    # data must be at least as large as the kernel's. Check the minimum per-axis difference:
    # a sum would let a deficit along one axis be masked by a surplus along another.
    flag_shapes = tf.shape(datah) - tf.shape(kernh)
    assert_shapes = tf.compat.v1.assert_greater_equal(
        tf.reduce_min(flag_shapes), 0,
        message='Data shape must be >= kernel shape',
        data=[tf.shape(datah), tf.shape(kernh)])

    with tf.control_dependencies([assert_pad_mode, assert_start_mode, assert_shapes]):
        # If configured to do so, expand dimensions of data array to power of 2 or
        # prime factor multiples (after adding a minimum padding as well, if given)
        # to avoid use of Bluestein algorithm in favor of significantly faster Cooley-Tukey FFT
        pad_shape = tf.shape(datah) + padminh
        datat = tf.cond(
            tf.equal(padmodh, OPM_2357),
            lambda: pad_around_center(
                datah, optimize_dims(pad_shape, OPM_2357), mode=self.pad_fill),
            lambda: tf.cond(
                tf.equal(padmodh, OPM_LOG2),
                lambda: pad_around_center(
                    datah, optimize_dims(pad_shape, OPM_LOG2), mode=self.pad_fill),
                lambda: pad_around_center(datah, pad_shape, mode=self.pad_fill)))

    # Pad kernel (with zeros only) to equal dimensions of data tensor and run "circular"
    # transformation as this algorithm is based on circular convolutions and the results
    # will have half spaces swapped otherwise
    kernt = tf.cast(ifftshift(pad_around_center(kernh, tf.shape(datat))), self.fft_dtype)

    # Infer available TF FFT functions based on predefined number of data dimensions
    # TODO: Find a way to determine dimensionality of images separately from batch dimension and
    # update the rank used to get fft fns excluding batch dim
    fft_fwd, fft_rev = fft_utils_tf.get_fft_tf_fns(
        min(self.n_dims, 3), real_domain_only=self.real_domain_fft)

    # The Gold update only blurs with the kernel itself; no flipped/conjugate kernel is needed
    kern_fft = fft_fwd(kernt)

    # Initialize resulting deconvolved image -- there are several sensible choices for this like the
    # original image or constant arrays, but some experiments show this to be better, and other
    # implementations doing the same are "Basic Matlab" and "Scikit-Image" (see class notes for links)
    decon = tf.cond(
        tf.equal(smodeh, SMODE_CONSTANT),
        lambda: tf.identity(.5 * tf.ones_like(datat, dtype=self.dtype), name='deconvolution'),
        # Multiplication used here to avoid https://github.com/tensorflow/tensorflow/issues/11186
        lambda: tf.identity(datat * tf.ones_like(datat, dtype=self.dtype), name='deconvolution'))

    def cond(i, decon):
        return i <= niter

    def conv(image, kernel_fft):
        """Circularly convolve ``image`` with the kernel whose transform is ``kernel_fft``."""
        return tf.math.real(fft_rev(fft_fwd(tf.cast(image, self.fft_dtype)) * kernel_fft))

    def gaussian_kernel(size: int, mean: float, std: float):
        """Make a unit-sum 3D Gaussian kernel with ``2 * size + 1`` taps per axis."""
        # The Gaussian's normalising constant cancels in the final division, so the bare
        # exponential suffices (tf.distributions, used previously, is not available in TF2).
        x = tf.range(start=-size, limit=size + 1, dtype=tf.float32)
        vals = tf.exp(-0.5 * tf.square((x - mean) / std))
        gauss_kernel = tf.einsum('i,j,k->ijk', vals, vals, vals)
        return gauss_kernel / tf.reduce_sum(gauss_kernel)

    # Periodic Gaussian smoothing of the intermediate estimate. The kernel above and
    # tf.nn.conv3d are 3D-only, so smoothing is skipped for data of any other rank.
    smooth_every = 5
    smooth_filter = None
    if self.n_dims == 3:
        # Zero mean (a non-zero mean would shift the image each time it is smoothed),
        # sigma 1, radius 4 (9 taps per axis, ~4 sigma)
        smooth_filter = gaussian_kernel(4, 0.0, 1.0)
        # conv3d filters have shape [depth, height, width, in_channels, out_channels]
        smooth_filter = tf.cast(smooth_filter[:, :, :, tf.newaxis, tf.newaxis], self.dtype)

    def smooth(image):
        """Convolve a 3D ``image`` with ``smooth_filter`` using 'SAME' (zero) padding."""
        # conv3d operates on [batch, depth, height, width, channels] tensors
        image5d = tf.expand_dims(tf.expand_dims(image, 0), -1)
        smoothed = tf.nn.conv3d(
            image5d, smooth_filter, strides=[1, 1, 1, 1, 1], padding='SAME')
        return smoothed[0, :, :, :, 0]

    # Unit-sum data; it does not change between iterations so it is computed once here
    datat_norm = datat / tf.math.reduce_sum(datat)
    # Offset added to numerator and denominator of the ratio instead of thresholding small
    # denominators (as per Stephan Ludwig et al. 2019); this value seems to work well for
    # images normalised to a sum of 1
    delta_param = 1e-4

    def body(i, decon):
        # Gold algorithm, ratio method: simpler than Richardson-Lucy and does not use the
        # flipped kernel. conv1 is the current estimate blurred with the kernel.
        conv1 = conv(decon, kern_fft)
        conv1_norm = conv1 / tf.math.reduce_sum(conv1)
        ratio = (datat_norm + delta_param) / (conv1_norm + delta_param)

        # Positivity constraint on result for iteration
        decon = tf.maximum(decon * ratio, 0.)

        # Smooth every `smooth_every`-th iteration. `i` is a loop tensor, so the branch has
        # to be a graph-level tf.cond rather than a Python `if`.
        if smooth_filter is not None:
            unsmoothed = decon
            decon = tf.cond(
                tf.equal(i % smooth_every, 0),
                lambda: smooth(unsmoothed),
                lambda: tf.identity(unsmoothed))

        # Renormalise to unit sum every iteration so the estimate's scale stays bounded
        # TODO: rescale back to the input data's sum intensity - probably need to adjust delta_param too.
        decon = decon / tf.math.reduce_sum(decon)

        # If given an "observer", pass the current image restoration and iteration counter to it
        if self.observer_fn is not None:
            # Remove any cropping that may have been added as this is usually not desirable in observers
            decon_crop = unpad_around_center(decon, tf.shape(datah))
            # normalise the result so the sum of the data is 1
            decon_crop = decon_crop / tf.math.reduce_sum(decon_crop)
            # conv1 is handed to the observer too, e.g. to evaluate convergence
            _, i, decon, _ = tf_observer([decon_crop, i, decon, conv1], self.observer_fn)

        return i + 1, decon

    result = tf.while_loop(cond, body, [1, decon], parallel_iterations=1)[1]

    # Crop off any padding that may have been added to reach more efficient dimension sizes
    result = unpad_around_center(result, tf.shape(datah))

    # Wrap output in configured post-processing functions (if any)
    result = tf.identity(
        self._wrap_output(result, {'data': datah, 'kernel': kernh}), name='result')

    inputs = {
        'niter': niter,
        'data': dataph,
        'kernel': kernph,
        'pad_mode': padmodh,
        'pad_min': padminh,
        'start_mode': smodeh,
    }
    outputs = {
        'result': result,
        'data_shape': tf.shape(datah),
        'kern_shape': tf.shape(kernh),
        'pad_shape': pad_shape,
        'pad_mode': padmodh,
        'datat_shape': tf.shape(datat),
        'pad_min': padminh,
        'start_mode': smodeh,
    }
    return inputs, outputs