def visualise_weights(self):
    try:
        self.model.visualise_weights()
    except AttributeError:
        lprint(
            f"Method 'visualise_weights()' unavailable, cannot visualise weights. "
        )
def _train_loop(self, data_loader, optimizer, loss_function):
    try:
        self.model.kernel_continuity_regularisation = False
    except AttributeError:
        lprint("Cannot deactivate kernel continuity regularisation")

    super()._train_loop(data_loader, optimizer, loss_function)
def _debug_allocation(self, info):
    if self.__debug_allocation:
        if self.backend == 'cupy':
            import cupy

            lprint(
                f"CUDA memory usage {info}: {cupy.get_default_memory_pool().used_bytes() / 1e6} MB"
            )
def __init__(
    self,
    max_epochs=1024,
    patience=None,
    patience_epsilon=0.0,
    learning_rate=0.02,
    batch_size=64,
    model_class=FeedForward,
    masking=True,
    masking_density=0.01,
    loss='l1',
    normaliser_type='percentile',
    balance_training_data=None,
    keep_ratio=1,
    max_voxels_for_training=4e6,
    monitor=None,
    use_cuda=True,
    device_index=0,
):
    """
    Constructs an image translator using the PyTorch deep learning library.

    :param normaliser_type: normaliser type
    :param balance_training_data: balance data? (limits the number of training entries per target-value histogram bin)
    :param monitor: monitor to track progress of training externally (used by UI)
    """
    super().__init__(normaliser_type, monitor=monitor)

    use_cuda = use_cuda and (torch.cuda.device_count() > 0)
    self.device = torch.device(f"cuda:{device_index}" if use_cuda else "cpu")
    lprint(f"Using device: {self.device}")

    self.max_epochs = max_epochs
    self.patience = max_epochs if patience is None else patience
    self.patience_epsilon = patience_epsilon
    self.learning_rate = learning_rate
    self.batch_size = batch_size
    self.loss = loss
    self.max_voxels_for_training = max_voxels_for_training
    self.keep_ratio = keep_ratio
    self.balance_training_data = balance_training_data
    self.model_class = model_class

    self.l1_weight_regularisation = 1e-6
    self.l2_weight_regularisation = 1e-6
    self.training_noise = 0.1
    self.reload_best_model_period = max_epochs  # //2
    self.reduce_lr_patience = self.patience // 2  # use self.patience: the argument may be None
    self.reduce_lr_factor = 0.9
    self.masking = masking
    self.masking_density = masking_density
    self.optimiser_class = ESAdam
    self.max_tile_size = 1024  # TODO: adjust based on available memory

    self._stop_training_flag = False
def _additional_losses(self, translated_image, forward_model_image):

    loss = 0

    # Non-negativity (bounds) loss:
    if self.bounds_loss and self.bounds_loss != 0:
        epsilon = 0 * 1e-8
        bounds_loss = F.relu(-translated_image - epsilon)
        bounds_loss += F.relu(translated_image - 1 - epsilon)
        bounds_loss_value = bounds_loss.mean()
        lprint(f"bounds_loss_value = {bounds_loss_value}")
        loss += self.bounds_loss * bounds_loss_value**2

    # Sharpening loss:
    if self.sharpening and self.sharpening != 0:
        image_for_loss = translated_image
        num_elements = image_for_loss[0, 0].nelement()
        sharpening_loss = -torch.norm(
            image_for_loss, dim=(2, 3), keepdim=True, p=2
        ) / (
            num_elements**2
        )  # /torch.norm(image_for_loss, dim=(2, 3), keepdim=True, p=1)
        lprint(f"sharpening loss = {sharpening_loss}")
        loss += self.sharpening * sharpening_loss.mean()

    # Max entropy loss:
    if self.entropy and self.entropy != 0:
        entropy_value = entropy(translated_image)
        lprint(f"entropy_value = {entropy_value}")
        loss += -self.entropy * entropy_value

    return loss
def download_from_gdrive(
    id, name, dest_folder=datasets_folder, overwrite=False, unzip=False
):
    try:
        os.makedirs(dest_folder)
    except Exception:
        pass

    url = f'https://drive.google.com/uc?id={id}'
    output_path = join(dest_folder, name)

    if overwrite or not exists(output_path):
        lprint(f"Downloading file {output_path} as it does not exist yet.")
        gdown.download(url, output_path, quiet=False)

        if unzip:
            lprint(f"Unzipping file {output_path}...")
            zip_ref = zipfile.ZipFile(output_path, 'r')
            zip_ref.extractall(dest_folder)
            zip_ref.close()
            # os.remove(output_path)

        return output_path
    else:
        lprint(f"Not downloading file {output_path} as it already exists.")
        return None
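# Usage sketch for download_from_gdrive (the file id and name below are placeholders,
# not real dataset identifiers):
#
#   path = download_from_gdrive(id='<gdrive-file-id>', name='example_dataset.zip', unzip=True)
#   # returns the path to the downloaded file, or None if it was already present.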
def offcore_array(
    shape: Union[Tuple[int, ...], Generator[int, None, None]],
    dtype: numpy.dtype,
    force_memmap: bool = False,
    zarr_allowed: bool = False,
    no_memmap_limit: bool = True,
    max_memory_usage_ratio: float = 0.9,
):
    """
    Instantiates an array of given shape and dtype in 'off-core' fashion, i.e. not in main memory.
    Right now it simply uses memory mapping on a temp file that is deleted after the file is closed.

    Parameters
    ----------
    shape
    dtype
    force_memmap
    zarr_allowed
    no_memmap_limit
    max_memory_usage_ratio
    """
    with lsection(f"Array of shape: {shape} and dtype: {dtype} requested"):
        size_in_bytes = numpy.prod(shape) * numpy.dtype(dtype).itemsize
        lprint(f'Array requested will be {(size_in_bytes / 1E6)} MB.')

        total_physical_memory_in_bytes = psutil.virtual_memory().total
        total_swap_memory_in_bytes = psutil.swap_memory().total
        total_mem_in_bytes = total_physical_memory_in_bytes + total_swap_memory_in_bytes
        lprint(
            f'There is {int(psutil.virtual_memory().total / 1E6)} MB of physical memory'
        )
        lprint(f'There is {int(psutil.swap_memory().total / 1E6)} MB of swap memory')
        lprint(f'There is {int(total_mem_in_bytes / 1E6)} MB of total memory')

        is_enough_physical_memory = (
            size_in_bytes < max_memory_usage_ratio * total_physical_memory_in_bytes
        )

        is_enough_total_memory = (
            size_in_bytes < max_memory_usage_ratio * total_mem_in_bytes
        )

        if not force_memmap and is_enough_total_memory:
            lprint(
                'There is enough physical+swap memory -- we do not need to use a mem mapped array or zarr-backed array.'
            )
            array = numpy.zeros(shape, dtype=dtype)

        elif no_memmap_limit:
            lprint(
                'There is not enough physical+swap memory -- we will use a mem mapped array.'
            )
            temp_file = tempfile.NamedTemporaryFile(dir=OffCore.memmap_directory)
            lprint(
                f'The temporary memory mapped file is at: {temp_file.name} (but you might not be able to see it!)'
            )
            array = numpy.memmap(temp_file, dtype=dtype, mode='w+', shape=shape)

        elif zarr_allowed:
            lprint(
                'There is not enough physical+swap memory -- we will use a zarr-backed array.'
            )
            import zarr

            array = zarr.create(
                shape=shape, dtype=dtype, store=zarr.TempStore("output.zarr")
            )
            # from numcodecs import Blosc
            # compressor = Blosc(cname='zstd', clevel=3, shuffle=Blosc.BITSHUFFLE)
            # array = zarr.zeros((102_0, 200, 210), chunks=(100, 200, 210), compressor=compressor)

        return array
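# Usage sketch: the returned object supports basic indexing and assignment like a numpy
# array, regardless of whether it ended up in RAM, in a memory-mapped temp file, or
# (if allowed) in a zarr store. The shape below is illustrative only:
#
#   big = offcore_array(shape=(4, 1, 2048, 2048, 512), dtype=numpy.float32)
#   big[0, 0] = 1.0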
def _train(
    self,
    input_image,
    target_image,
    train_valid_ratio=0.1,
    callback_period=3,
    jinv=False,
):
    self._stop_training_flag = False

    if jinv is not None and not jinv:
        self.masking = False

    shape = input_image.shape
    num_batches = shape[0]
    num_input_channels = input_image.shape[1]
    num_output_channels = target_image.shape[1]
    num_spatiotemp_dim = input_image.ndim - 2

    # Tile size:
    tile_size = min(self.max_tile_size, min(shape[2:]))

    # Decide on how many voxels to be used for validation:
    num_val_voxels = int(train_valid_ratio * input_image.size)
    lprint(
        f"Number of voxels used for validation: {num_val_voxels} (train_valid_ratio={train_valid_ratio})"
    )

    # Generate random coordinates for these voxels:
    val_voxels = tuple(numpy.random.randint(d, size=num_val_voxels) for d in shape)
    lprint(f"Validation voxel coordinates: {val_voxels}")

    # Training tile size:
    lprint(f"Train tile dimensions: {tile_size}")

    # Prepare training dataset:
    dataset = self._get_dataset(
        input_image,
        target_image,
        self.self_supervised,
        tilesize=tile_size,
        mode='grid',
        validation_voxels=val_voxels,
        batch_size=self.batch_size,
    )
    lprint(f"Number of tiles for training: {len(dataset)}")

    # Training data loader:
    # num_workers = max(3, os.cpu_count() // 2)
    num_workers = 0  # faster if data is already in memory...
    lprint(f"Number of workers for loading training/validation data: {num_workers}")
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,  # self.batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    # Model:
    self.model = self.model_class(num_input_channels, num_output_channels).to(
        self.device
    )

    number_of_parameters = sum(
        p.numel() for p in self.model.parameters() if p.requires_grad
    )
    lprint(
        f"Number of trainable parameters in {self.model_class} model: {number_of_parameters}"
    )

    if self.masking:
        self.masked_model = Masking(self.model, density=0.5).to(self.device)

    lprint(f"Optimiser class: {self.optimiser_class}")
    lprint(f"Learning rate : {self.learning_rate}")

    # Optimiser:
    optimizer = self.optimiser_class(
        chain(self.model.parameters()),
        lr=self.learning_rate,
        start_noise_level=self.training_noise,
        weight_decay=self.l2_weight_regularisation,
    )
    lprint(f"Optimiser: {optimizer}")

    # Denoising loss function:
    loss_function = nn.L1Loss()
    if self.loss.lower() == 'l2':
        lprint(f"Training/Validation loss: L2")
        if self.masking:
            loss_function = (
                lambda u, v, m: (u - v)**2 if m is None else ((u - v) * m)**2
            )
        else:
            loss_function = lambda u, v: (u - v)**2

    elif self.loss.lower() == 'l1':
        if self.masking:
            loss_function = (
                lambda u, v, m: torch.abs(u - v)
                if m is None
                else torch.abs((u - v) * m)
            )
        else:
            loss_function = lambda u, v: torch.abs(u - v)
        lprint(f"Training/Validation loss: L1")

    # Start training:
    self._train_loop(data_loader, optimizer, loss_function)
def _train_loop(self, data_loader, optimizer, loss_function):

    # Scheduler:
    scheduler = ReduceLROnPlateau(
        optimizer,
        'min',
        factor=self.reduce_lr_factor,
        verbose=True,
        patience=self.reduce_lr_patience,
    )

    best_val_loss_value = math.inf
    best_model_state_dict = None
    patience_counter = 0

    with lsection(f"Training loop:"):
        lprint(f"Maximum number of epochs: {self.max_epochs}")
        lprint(
            f"Training type: {'self-supervised' if self.self_supervised else 'supervised'}"
        )

        for epoch in range(self.max_epochs):
            with lsection(f"Epoch {epoch}:"):

                if hasattr(self, 'masked_model'):
                    self.masked_model.density = (
                        0.005 * self.masking_density
                        + 0.995 * self.masked_model.density
                    )
                    lprint(f"masking density: {self.masked_model.density}")

                train_loss_value = 0
                val_loss_value = 0
                iteration = 0
                for i, (input_images, target_images, val_mask_images) in enumerate(
                    data_loader
                ):
                    lprint(f"index: {i}, shape: {input_images.shape}")

                    input_images_gpu = input_images.to(self.device, non_blocking=True)
                    target_images_gpu = target_images.to(
                        self.device, non_blocking=True
                    )
                    validation_mask_images_gpu = val_mask_images.to(
                        self.device, non_blocking=True
                    )

                    # Adding training noise to input:
                    if self.training_noise > 0:
                        with torch.no_grad():
                            alpha = self.training_noise / (
                                1 + (10000 * epoch / self.max_epochs)
                            )
                            lprint(f"Training noise level: {alpha}")
                            training_noise = alpha * torch.randn_like(input_images)
                            input_images_gpu += training_noise.to(
                                input_images_gpu.device
                            )

                    # Clear gradients w.r.t. parameters:
                    optimizer.zero_grad()

                    # Forward pass:
                    self.model.train()
                    if self.masking:
                        translated_images_gpu = self.masked_model(input_images_gpu)
                    else:
                        translated_images_gpu = self.model(input_images_gpu)

                    # Apply forward model:
                    forward_model_images_gpu = self._forward_model(
                        translated_images_gpu
                    )

                    # Validation masking:
                    u = forward_model_images_gpu * (1 - validation_mask_images_gpu)
                    v = target_images_gpu * (1 - validation_mask_images_gpu)

                    # with napari.gui_qt():
                    #     viewer = napari.Viewer()
                    #     viewer.add_image(to_numpy(validation_mask_images_gpu), name='validation_mask_images_gpu')
                    #     viewer.add_image(to_numpy(forward_model_images_gpu), name='forward_model_images_gpu')
                    #     viewer.add_image(to_numpy(target_images_gpu), name='target_images_gpu')

                    # Translation loss (per voxel):
                    if self.masking:
                        mask = self.masked_model.get_mask()
                        translation_loss = loss_function(u, v, mask)
                    else:
                        translation_loss = loss_function(u, v)

                    # Loss value (for all voxels):
                    translation_loss_value = translation_loss.mean()

                    # Additional losses:
                    additional_loss_value = self._additional_losses(
                        translated_images_gpu, forward_model_images_gpu
                    )
                    if additional_loss_value is not None:
                        translation_loss_value += additional_loss_value

                    # Backpropagation:
                    translation_loss_value.backward()

                    # Updating parameters:
                    optimizer.step()

                    # Post optimisation -- if needed:
                    self.model.post_optimisation()

                    # Update training loss for the whole image:
                    train_loss_value += translation_loss_value.item()
                    iteration += 1

                    # Validation:
                    with torch.no_grad():
                        # Forward pass:
                        self.model.eval()
                        if self.masking:
                            translated_images_gpu = self.masked_model(input_images_gpu)
                        else:
                            translated_images_gpu = self.model(input_images_gpu)

                        # Apply forward model:
                        forward_model_images_gpu = self._forward_model(
                            translated_images_gpu
                        )

                        # Validation masking:
                        u = forward_model_images_gpu * validation_mask_images_gpu
                        v = target_images_gpu * validation_mask_images_gpu

                        # Translation loss (per voxel):
                        if self.masking:
                            translation_loss = loss_function(u, v, None)
                        else:
                            translation_loss = loss_function(u, v)

                        # Loss values:
                        translation_loss_value = (
                            translation_loss.mean().cpu().item()
                        )

                        # Update validation loss for the whole image:
                        val_loss_value += translation_loss_value
                        iteration += 1

                train_loss_value /= iteration
                lprint(f"Training loss value: {train_loss_value}")

                val_loss_value /= iteration
                lprint(f"Validation loss value: {val_loss_value}")

                # Learning rate schedule:
                scheduler.step(val_loss_value)

                if val_loss_value < best_val_loss_value:
                    lprint(f"## New best val loss!")
                    if val_loss_value < best_val_loss_value - self.patience_epsilon:
                        lprint(f"## Good enough to reset patience!")
                        patience_counter = 0

                    # Update best val loss value:
                    best_val_loss_value = val_loss_value

                    # Save model:
                    best_model_state_dict = OrderedDict(
                        {k: v.to('cpu') for k, v in self.model.state_dict().items()}
                    )
                else:
                    if (
                        epoch % max(1, self.reload_best_model_period) == 0
                        and best_model_state_dict
                    ):
                        lprint(f"Reloading best model to date!")
                        self.model.load_state_dict(best_model_state_dict)

                    if patience_counter > self.patience:
                        lprint(f"Early stopping!")
                        break

                    # No improvement:
                    lprint(
                        f"No improvement of validation losses, patience = {patience_counter}/{self.patience} "
                    )
                    patience_counter += 1

                lprint(f"## Best val loss: {best_val_loss_value}")

                if self._stop_training_flag:
                    lprint(f"Training interrupted!")
                    break

        lprint(f"Reloading best model to date!")
        self.model.load_state_dict(best_model_state_dict)
def train(
    self,
    input_image,
    target_image=None,
    batch_dims=None,
    channel_dims=None,
    train_valid_ratio=0.1,
    callback_period=3,
    jinv=None,
):
    """Train to translate a given input image to a given output image.
    This has a lot of the machinery for batching and more...
    """
    if target_image is None:
        target_image = input_image

    with lsection(
        f"Learning to translate from image of dimensions {str(input_image.shape)} to {str(target_image.shape)} ."
    ):
        lprint('Running garbage collector...')
        gc.collect()

        # If we use the same image for input and output then we are in a self-supervised setting:
        self.self_supervised = input_image is target_image

        if self.self_supervised:
            lprint('Training is self-supervised.')
        else:
            lprint('Training is supervised.')

        if batch_dims is None:  # set default batch_dim value:
            batch_dims = (False,) * len(input_image.shape)

        self.input_normaliser = self.normalizer_class(
            transform=self.normaliser_transform, clip=self.normaliser_clip
        )
        self.target_normaliser = (
            self.input_normaliser
            if self.self_supervised
            else self.normalizer_class(
                transform=self.normaliser_transform, clip=self.normaliser_clip
            )
        )

        # Calibrates normaliser(s):
        self.input_normaliser.calibrate(input_image)
        if not self.self_supervised:
            self.target_normaliser.calibrate(target_image)

        # Intensity values normalisation:
        normalised_input_image = self.input_normaliser.normalise(
            input_image, batch_dims=batch_dims, channel_dims=channel_dims
        )
        normalised_target_image = (
            normalised_input_image
            if self.self_supervised
            else self.target_normaliser.normalise(
                target_image, batch_dims=batch_dims, channel_dims=channel_dims
            )
        )

        # Let's pad the images to avoid border effects:
        # If we do it for translation we also have to do it for training because of
        # location-aware features such as large-scale features or spatial-features.
        normalised_input_image = self._pad_norm_image(normalised_input_image)
        normalised_target_image = self._pad_norm_image(normalised_target_image)

        self._train(
            normalised_input_image,
            normalised_target_image,
            train_valid_ratio=train_valid_ratio,
            callback_period=callback_period,
            jinv=jinv,
        )
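# A minimal usage sketch of the public train()/translate() API (the class and variable
# names below are hypothetical, for illustration only -- the concrete translator
# subclass is not shown in this file):
#
#   translator = SomeImageTranslator(max_epochs=512)  # hypothetical subclass
#   translator.train(noisy_image)                     # self-supervised: target defaults to the input
#   denoised = translator.translate(noisy_image)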
def _get_tilling_strategy_and_margins(
    self,
    image,
    max_voxels_per_tile,
    min_margin,
    max_margin,
    suggested_tile_size=None,
):
    # We will store the batch strategy as a list of integers representing the number of chunks per dimension:
    with lsection(f"Determine tilling strategy:"):

        suggested_tile_size = (
            math.inf if suggested_tile_size is None else suggested_tile_size
        )

        # Image shape:
        shape = image.shape
        num_spatio_temp_dim = num_spatiotemp_dim = len(shape) - 2

        lprint(f"image shape = {shape}")
        lprint(f"max_voxels_per_tile = {max_voxels_per_tile}")

        # Estimated amount of memory needed for storing all features:
        (
            estimated_memory_needed,
            total_memory_available,
        ) = self._estimate_memory_needed_and_available(image)
        lprint(f"Estimated amount of memory needed: {estimated_memory_needed}")

        # Available physical memory:
        total_memory_available *= self.max_memory_usage_ratio
        lprint(
            f"Available memory (we reserve 10% for 'comfort'): {total_memory_available}"
        )

        # How much do we need to tile because of memory, if at all?
        split_factor_mem = estimated_memory_needed / total_memory_available
        lprint(
            f"How much do we need to tile because of memory? : {split_factor_mem} times."
        )

        # How much do we have to tile because of the limit on the number of voxels per tile?
        split_factor_max_voxels = image.size / max_voxels_per_tile
        lprint(
            f"How much do we need to tile because of the limit on the number of voxels per tile? : {split_factor_max_voxels} times."
        )

        # How much do we have to tile because of the suggested tile size?
        split_factor_suggested_tile_size = image.size / (
            suggested_tile_size**num_spatio_temp_dim
        )
        lprint(
            f"How much do we need to tile because of the suggested tile size? : {split_factor_suggested_tile_size} times."
        )

        # We keep the max:
        desired_split_factor = max(
            split_factor_mem,
            split_factor_max_voxels,
            split_factor_suggested_tile_size,
        )
        # We cannot split less than 1 time:
        desired_split_factor = max(1, int(math.ceil(desired_split_factor)))
        lprint(f"Desired split factor: {desired_split_factor}")

        # Number of batches:
        num_batches = shape[0]

        # Does the number of batches split the data enough?
        if num_batches < desired_split_factor:
            # Not enough splitting happening along the batch dimension, we need to split further. How much?
            rest_split_factor = desired_split_factor / num_batches
            lprint(
                f"Not enough splitting happening along the batch dimension, we need to split spatio-temp dims by: {rest_split_factor}"
            )

            # Let's split the dimensions in a way proportional to their lengths:
            split_per_dim = (rest_split_factor / numpy.prod(shape[2:])) ** (
                1 / num_spatio_temp_dim
            )

            spatiotemp_tilling_strategy = tuple(
                max(1, int(math.ceil(split_per_dim * s))) for s in shape[2:]
            )
            # correction_factor = numpy.prod(tuple(s for s in spatiotemp_tilling_strategy if s < 1))

            tilling_strategy = (num_batches, 1) + spatiotemp_tilling_strategy
            lprint(f"Preliminary tilling strategy is: {tilling_strategy}")

            # We correct for eventual oversplitting by favouring splitting over the front dimensions:
            current_splitting_factor = 1
            corrected_tilling_strategy = []
            split_factor_reached = False
            for i, s in enumerate(tilling_strategy):

                if split_factor_reached:
                    corrected_tilling_strategy.append(1)
                else:
                    corrected_tilling_strategy.append(s)
                    current_splitting_factor *= s

                if current_splitting_factor >= desired_split_factor:
                    split_factor_reached = True

            tilling_strategy = tuple(corrected_tilling_strategy)

        else:
            tilling_strategy = (desired_split_factor, 1) + tuple(
                1 for s in shape[2:]
            )

        lprint(f"Tilling strategy is: {tilling_strategy}")

        # Handles defaults:
        if max_margin is None:
            max_margin = math.inf
        if min_margin is None:
            min_margin = 0

        # First we estimate the shape of a tile:
        estimated_tile_shape = tuple(
            int(round(s / ts)) for s, ts in zip(shape[2:], tilling_strategy[2:])
        )
        lprint(f"The estimated tile shape is: {estimated_tile_shape}")

        # Limit margins:
        # We automatically set the margin of the tile size:
        # the max-margin factor guarantees that tilling will incur no more than a given max tiling overhead:
        margin_factor = 0.5 * (
            ((1 + self.max_tilling_overhead) ** (1 / num_spatiotemp_dim)) - 1
        )
        margins = tuple(int(s * margin_factor) for s in estimated_tile_shape)

        # Limit the margin to something reasonable (provided or automatically computed):
        margins = tuple(min(max_margin, m) for m in margins)
        margins = tuple(max(min_margin, m) for m in margins)

        # We add the batch and channel dimensions:
        margins = (0, 0) + margins

        # We only need margins if we split a dimension:
        margins = tuple(
            (0 if split == 1 else margin)
            for margin, split in zip(margins, tilling_strategy)
        )

        return tilling_strategy, margins
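# Worked example (illustrative numbers only, assuming memory is not the limiting factor
# and no suggested tile size): for an image of shape (1, 1, 2048, 2048) with
# max_voxels_per_tile = 1024**2, image.size = 4 * 1024**2, so split_factor_max_voxels = 4
# and desired_split_factor = 4. With a single batch, rest_split_factor = 4 and
# split_per_dim = (4 / (2048 * 2048)) ** (1 / 2) = 1 / 1024, so each spatial dimension is
# split in 2 and the tilling strategy is (1, 1, 2, 2): a 2 x 2 grid of ~1024 x 1024 tiles,
# each extended by the computed margins along the dimensions that are actually split.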
def translate(
    self,
    input_image,
    translated_image=None,
    batch_dims=None,
    channel_dims=None,
    tile_size=None,
    denormalise_values=True,
    leave_as_float=False,
    clip=True,
):
    """
    Translates an input image into an output image according to the learned function.

    :param input_image:
    :type input_image:
    :param clip:
    :type clip:
    :return:
    :rtype:
    """
    with lsection(
        f"Predicting output image from input image of dimension {input_image.shape}"
    ):
        # Set default batch_dim and channel_dim values:
        if batch_dims is None:
            batch_dims = (False,) * len(input_image.shape)
        if channel_dims is None:
            channel_dims = (False,) * len(input_image.shape)

        # Number of spatio-temporal dimensions:
        num_spatiotemp_dim = sum(
            0 if b or c else 1 for b, c in zip(batch_dims, channel_dims)
        )

        # First we normalise the input values:
        normalised_input_image = self.input_normaliser.normalise(
            input_image, batch_dims=batch_dims, channel_dims=channel_dims
        )

        # When we trained supervised we need to update the permutated image shape of the target_normaliser.
        # This way we can accommodate batch dimension sizes different from those used for training.
        if not self.self_supervised:
            (
                _,
                _,
                self.target_normaliser.permutated_image_shape,
            ) = self.target_normaliser.shape_normalize(
                input_image, batch_dims=batch_dims, channel_dims=channel_dims
            )

        # Let's pad the input array so we avoid annoying border-effects:
        normalised_input_image = self._pad_norm_image(normalised_input_image)

        # Spatio-temporal shape:
        spatiotemp_shape = normalised_input_image.shape[-num_spatiotemp_dim:]

        normalised_translated_image = None

        if tile_size == 0:
            # We _force_ no tilling, this is _not_ the default.

            # We translate:
            normalised_translated_image = self._translate(
                normalised_input_image,
                whole_image_shape=normalised_input_image.shape,
            )

        else:
            # We do need to do tiled inference because of a lack of memory
            # or because a small batch size was requested:

            normalised_input_shape = normalised_input_image.shape

            # We get the tilling strategy:
            # tile_size, shape, min_margin, max_margin
            tilling_strategy, margins = self._get_tilling_strategy_and_margins(
                normalised_input_image,
                self.max_voxels_per_tile,
                self.tile_min_margin,
                self.tile_max_margin,
                suggested_tile_size=tile_size,
            )
            lprint(f"Tilling strategy: {tilling_strategy}")
            lprint(f"Margins for tiles: {margins} .")

            # Tile slice objects (with and without margins):
            tile_slices_margins = list(
                nd_split_slices(
                    normalised_input_shape, tilling_strategy, margins=margins
                )
            )
            tile_slices = list(
                nd_split_slices(normalised_input_shape, tilling_strategy)
            )

            # Number of tiles:
            number_of_tiles = len(tile_slices)
            lprint(f"Number of tiles (slices): {number_of_tiles}")

            # We create the slice list:
            slicezip = zip(tile_slices_margins, tile_slices)

            counter = 1
            for slice_margin_tuple, slice_tuple in slicezip:
                with lsection(
                    f"Current tile: {counter}/{number_of_tiles}, slice: {slice_tuple} "
                ):
                    # We first extract the tile image:
                    input_image_tile = normalised_input_image[
                        slice_margin_tuple
                    ].copy()

                    # We do the actual translation:
                    lprint(f"Translating...")
                    translated_image_tile = self._translate(
                        input_image_tile,
                        image_slice=slice_margin_tuple,
                        whole_image_shape=normalised_input_image.shape,
                    )

                    # We compute the slice needed to cut out the margins:
                    lprint(f"Removing margins...")
                    remove_margin_slice_tuple = remove_margin_slice(
                        normalised_input_shape, slice_margin_tuple, slice_tuple
                    )

                    # We allocate -- just in time -- the translated array if needed:
                    # if the array is already provided, it must of course have the right dimensions...
                    if normalised_translated_image is None:
                        translated_image_shape = (
                            normalised_input_image.shape[:2] + spatiotemp_shape
                        )
                        normalised_translated_image = offcore_array(
                            shape=translated_image_shape,
                            dtype=translated_image_tile.dtype,
                            max_memory_usage_ratio=self.max_memory_usage_ratio,
                        )

                    # We plug in the batch without margins into the destination image:
                    lprint(f"Inserting translated batch into result image...")
                    normalised_translated_image[slice_tuple] = translated_image_tile[
                        remove_margin_slice_tuple
                    ]

                    counter += 1

        # Let's crop the padding:
        normalised_translated_image = self._crop_norm_image(
            normalised_translated_image
        )

        # Then we denormalise:
        denormalised_translated_image = self.target_normaliser.denormalise(
            normalised_translated_image,
            # denormalise_values=denormalise_values,
            leave_as_float=leave_as_float,
            clip=clip,
        )

        if translated_image is None:
            translated_image = denormalised_translated_image
        else:
            translated_image[...] = denormalised_translated_image

    return translated_image
def test_log():
    # This is required for this test to pass!
    Log.override_test_exclusion = True

    lprint('Test')
    with lsection('a section'):
        lprint('a line')
        lprint('another line')
        lprint('we are done')

        with lsection('a subsection'):
            lprint('another line')
            lprint('we are done')

            with lsection('a subsection'):
                lprint('another line')
                lprint('we are done')
                assert Log.depth == 3

                with lsection('a subsection'):
                    lprint('another line')
                    lprint('we are done')

                    with lsection('a subsection'):
                        lprint('another line')
                        lprint('we are done')
                        assert Log.depth == 5

                        with lsection('a subsection'):
                            lprint('another line')
                            lprint('we are done')

                            with lsection('a subsection'):
                                lprint('another line')
                                lprint('we are done')
                                assert Log.depth == 7

                                with lsection('a subsection'):
                                    lprint('another line')
                                    lprint('we are done')

                                    with lsection('a subsection'):
                                        lprint('another line')
                                        lprint('we are done')

                                        with lsection('a subsection'):
                                            lprint('another line')
                                            lprint('we are done')

    lprint('test is finished...')
    assert Log.depth == 0
def _translate(self, input_image, image_slice=None, whole_image_shape=None):
    """Internal method that translates an input image on the basis of the trained model.

    :param input_image: input image
    :param batch_dims: batch dimensions
    :return:
    """
    import numpy

    convolve_method = self._get_convolution_method(
        input_image, self.psf_kernel_numpy
    )
    pad_method = self._get_pad_method(input_image)

    self.psf_kernel = self._convert_array_format_in(
        self.psf_kernel_numpy.astype(numpy.float32)
    )
    self.psf_kernel_mirror = self._convert_array_format_in(
        self.psf_kernel[::-1, ::-1]
    )

    input_image = input_image.astype(numpy.float32, copy=False)

    deconvolved_image = offcore_array(
        shape=input_image.shape, dtype=input_image.dtype
    )

    lprint(f"Number of Lucy-Richardson iterations: {self.max_num_iterations}")

    for batch_index, batch_image in enumerate(input_image):
        for channel_index, channel_image in enumerate(batch_image):

            channel_image = channel_image.clip(0, math.inf)
            channel_image = self._convert_array_format_in(channel_image)

            candidate_deconvolved_image = numpy.full(
                channel_image.shape, float(numpy.mean(channel_image))
            )
            candidate_deconvolved_image = self._convert_array_format_in(
                candidate_deconvolved_image
            )

            kernel_shape = self.psf_kernel.shape
            pad_width = tuple(
                (max(self.padding, (s - 1) // 2), max(self.padding, (s - 1) // 2))
                for s in kernel_shape
            )

            for i in range(self.max_num_iterations):
                if self.padding > 0:
                    padded_candidate_deconvolved_image = pad_method(
                        candidate_deconvolved_image,
                        pad_width=pad_width,
                        mode=self.padding_mode,
                    )
                else:
                    padded_candidate_deconvolved_image = candidate_deconvolved_image

                convolved = convolve_method(
                    padded_candidate_deconvolved_image,
                    self.psf_kernel,
                    mode='valid' if self.padding else 'same',
                )

                convolved[convolved == 0] = 1

                relative_blur = channel_image / convolved

                self._debug_allocation(f"after division")

                if self.padding:
                    relative_blur = numpy.pad(
                        relative_blur, pad_width=pad_width, mode=self.padding_mode
                    )

                multiplicative_correction = convolve_method(
                    relative_blur,
                    self.psf_kernel_mirror,
                    mode='valid' if self.padding else 'same',
                )

                self._debug_allocation(f"after second convolution")

                candidate_deconvolved_image *= multiplicative_correction

            if self.clip:
                candidate_deconvolved_image[candidate_deconvolved_image > 1] = 1
                candidate_deconvolved_image[candidate_deconvolved_image < -1] = -1

            candidate_deconvolved_image = self._convert_array_format_out(
                candidate_deconvolved_image
            )

            deconvolved_image[batch_index, channel_index] = candidate_deconvolved_image

    return deconvolved_image
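# The iteration loop above is the classic Richardson-Lucy multiplicative update.
# A minimal NumPy/SciPy-only sketch of one iteration (no padding, no backend dispatch),
# for reference only:
#
#   from scipy.signal import convolve
#   estimate *= convolve(image / convolve(estimate, psf, mode='same'),
#                        psf[::-1, ::-1], mode='same')
#
# i.e. estimate_{k+1} = estimate_k * ((image / (estimate_k (*) psf)) (*) psf_mirror),
# where (*) denotes convolution and psf_mirror is the flipped kernel.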
def _get_convolution_method(self, input_image, psf_kernel):

    if self.backend == 'scipy':
        lprint("Using scipy backend.")
        from scipy.signal import convolve

        return convolve

    elif self.backend == "scipy-cupy":
        try:
            lprint("Attempting to use scipy-cupy backend.")
            import scipy
            import cupy

            scipy.fft.set_backend(cupy.fft)
            self.backend = 'scipy'
            lprint("Succeeded to use scipy-cupy backend.")
            return self._get_convolution_method(input_image, psf_kernel)
        except Exception:
            track = traceback.format_exc()
            lprint(track)
            lprint("Failed to use scipy-cupy backend.")
            self.backend = 'cupy'
            return self._get_convolution_method(input_image, psf_kernel)

    elif self.backend == 'gputools':
        try:
            lprint("Attempting to use gputools backend.")
            # Testing if gputools works:
            import gputools
            import numpy

            # Try something simple and see if it crashes...
            data = numpy.ones((30, 40, 50))
            h = numpy.ones((10, 11, 12))
            out = gputools.convolve(data, h)  # noqa: F841

            def gputools_convolve(in1, in2, mode=None, method=None):
                return gputools.convolve(in1, in2)

            # gputools backend does not need extra padding:
            self.padding = False

            lprint("Succeeded to use gputools backend.")
            return gputools_convolve

        except Exception:
            track = traceback.format_exc()
            lprint(track)
            lprint("Failed to use gputools backend.")
            pass

    elif self.backend == 'cupy':
        try:
            lprint("Attempting to use cupy backend.")
            # Testing if cupy works:
            import cupyx.scipy.ndimage
            import cupy

            # Try something simple and see if it crashes...
            data = cupy.ones((30, 40, 50))
            h = cupy.ones((10, 11, 12))
            cupyx.scipy.ndimage.convolve(data, h)

            # cupy backend does not need extra padding:
            self.padding = False

            def cupy_convolve(in1, in2, mode=None, method=None):
                return cupyx.scipy.ndimage.convolve(in1, in2, mode='reflect')

            lprint("Succeeded to use cupy backend.")
            if psf_kernel.size > 500:
                return self._cupy_convolve_fft
            else:
                return cupy_convolve

        except Exception:
            track = traceback.format_exc()
            lprint(track)
            lprint("Failed to use cupy backend, trying gputools.")
            self.backend = 'gputools'
            return self._get_convolution_method(input_image, psf_kernel)

    lprint("Falling back to scipy backend.")

    # This is scipy's convolve:
    from scipy.signal import convolve

    return convolve
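# Backend resolution summary (as implemented above): 'scipy-cupy' falls back to 'cupy'
# on failure, 'cupy' falls back to 'gputools', and a failed 'gputools' attempt drops
# through to the final plain-scipy fallback. Both the 'scipy' backend and the final
# fallback return scipy.signal.convolve directly.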