def test_log():
    """Exercise nested ``lsection`` blocks and check ``Log.depth`` at depths 3, 5, 7 and 0."""

    def _emit_subsection_lines():
        # The two lines printed inside every (sub)section of this test.
        lprint('another line')
        lprint('we are done')

    # This is required for this test to pass!
    Log.override_test_exclusion = True

    lprint('Test')

    # NOTE(review): the source was whitespace-collapsed; the nesting of the
    # three subsections that follow the depth-7 check is reconstructed and
    # only affects log indentation, not the assertions -- confirm.
    with lsection('a section'):
        lprint('a line')
        _emit_subsection_lines()

        with lsection('a subsection'):
            _emit_subsection_lines()

            with lsection('a subsection'):
                _emit_subsection_lines()

                assert Log.depth == 3

                with lsection('a subsection'):
                    _emit_subsection_lines()

                    with lsection('a subsection'):
                        _emit_subsection_lines()

                        assert Log.depth == 5

                        with lsection('a subsection'):
                            _emit_subsection_lines()

                            with lsection('a subsection'):
                                _emit_subsection_lines()

                                assert Log.depth == 7

                    with lsection('a subsection'):
                        _emit_subsection_lines()

                with lsection('a subsection'):
                    _emit_subsection_lines()

        with lsection('a subsection'):
            _emit_subsection_lines()

    lprint('test is finished...')

    # Every section has been exited, so the depth must be back to zero.
    assert Log.depth == 0
def test_log():
    """Check that ``Log.depth`` follows the nesting of ``lsection`` context managers."""

    def _two_lines():
        # Body shared by every subsection below.
        lprint("another line")
        lprint("we are done")

    # This is required for this test to pass!
    Log.override_test_exclusion = True

    lprint("Test")

    # NOTE(review): indentation was rebuilt from a whitespace-collapsed source;
    # the depth of the three subsections after the depth-7 assertion is a
    # best guess and influences only how the log is indented -- confirm.
    with lsection("a section"):
        lprint("a line")
        _two_lines()

        with lsection("a subsection"):
            _two_lines()

            with lsection("a subsection"):
                _two_lines()

                # Three sections are open at this point.
                assert Log.depth == 3

                with lsection("a subsection"):
                    _two_lines()

                    with lsection("a subsection"):
                        _two_lines()

                        # Five sections are open at this point.
                        assert Log.depth == 5

                        with lsection("a subsection"):
                            _two_lines()

                            with lsection("a subsection"):
                                _two_lines()

                                # Seven sections are open at this point.
                                assert Log.depth == 7

                    with lsection("a subsection"):
                        _two_lines()

                with lsection("a subsection"):
                    _two_lines()

        with lsection("a subsection"):
            _two_lines()

    lprint("test is finished...")

    assert Log.depth == 0
def _train_loop(self, data_loader, optimizer, loss_function):
    """Run the epoch loop: train, validate, adapt the learning rate and early-stop.

    For every batch the model is trained on the voxels outside the validation
    mask and then evaluated (without gradients) on the voxels inside it.
    The state dict with the lowest validation loss is kept on CPU and loaded
    back into ``self.model`` once the loop ends.

    Parameters
    ----------
    data_loader
        Iterable yielding ``(input_images, target_images, val_mask_images)``
        batches.
    optimizer
        Optimizer that is zeroed and stepped for every batch; it is also
        handed to ``ReduceLROnPlateau``.
    loss_function
        Callable returning a per-voxel loss. Called as
        ``loss_function(u, v, mask)`` when ``self.masking`` is set (with
        ``None`` as mask during validation) and ``loss_function(u, v)``
        otherwise.
    """
    # Scheduler:
    scheduler = ReduceLROnPlateau(
        optimizer,
        'min',
        factor=self.reduce_lr_factor,
        verbose=True,
        patience=self.reduce_lr_patience,
    )

    best_val_loss_value = math.inf
    best_model_state_dict = None
    patience_counter = 0

    # NOTE(review): indentation was reconstructed from a whitespace-collapsed
    # source; which log section the trailing messages belong to should be
    # confirmed (it only changes log indentation).
    with lsection("Training loop:"):
        lprint(f"Maximum number of epochs: {self.max_epochs}")
        lprint(
            f"Training type: {'self-supervised' if self.self_supervised else 'supervised'}"
        )

        for epoch in range(self.max_epochs):
            with lsection(f"Epoch {epoch}:"):

                if hasattr(self, 'masked_model'):
                    # Move the current density slightly towards the target
                    # masking density at every epoch.
                    self.masked_model.density = (
                        0.005 * self.masking_density
                        + 0.995 * self.masked_model.density
                    )
                    lprint(f"masking density: {self.masked_model.density}")

                train_loss_value = 0
                val_loss_value = 0
                iteration = 0
                for i, (input_images, target_images, val_mask_images) in enumerate(
                    data_loader
                ):
                    lprint(f"index: {i}, shape:{input_images.shape}")

                    input_images_gpu = input_images.to(self.device, non_blocking=True)
                    target_images_gpu = target_images.to(
                        self.device, non_blocking=True
                    )
                    validation_mask_images_gpu = val_mask_images.to(
                        self.device, non_blocking=True
                    )

                    # Adding training noise to input:
                    if self.training_noise > 0:
                        with torch.no_grad():
                            # The noise amplitude decreases as epochs go by.
                            alpha = self.training_noise / (
                                1 + (10000 * epoch / self.max_epochs)
                            )
                            lprint(f"Training noise level: {alpha}")
                            training_noise = alpha * torch.randn_like(input_images)
                            input_images_gpu += training_noise.to(
                                input_images_gpu.device
                            )

                    # Clear gradients w.r.t. parameters
                    optimizer.zero_grad()

                    # Forward pass:
                    self.model.train()
                    if self.masking:
                        translated_images_gpu = self.masked_model(input_images_gpu)
                    else:
                        translated_images_gpu = self.model(input_images_gpu)

                    # apply forward model:
                    forward_model_images_gpu = self._forward_model(
                        translated_images_gpu
                    )

                    # validation masking: training only sees voxels that are
                    # NOT part of the validation mask.
                    u = forward_model_images_gpu * (1 - validation_mask_images_gpu)
                    v = target_images_gpu * (1 - validation_mask_images_gpu)

                    # translation loss (per voxel):
                    if self.masking:
                        mask = self.masked_model.get_mask()
                        translation_loss = loss_function(u, v, mask)
                    else:
                        translation_loss = loss_function(u, v)

                    # loss value (for all voxels):
                    translation_loss_value = translation_loss.mean()

                    # Additional losses:
                    additional_loss_value = self._additional_losses(
                        translated_images_gpu, forward_model_images_gpu
                    )
                    if additional_loss_value is not None:
                        translation_loss_value += additional_loss_value

                    # backpropagation:
                    translation_loss_value.backward()

                    # Updating parameters
                    optimizer.step()

                    # post optimisation -- if needed:
                    self.model.post_optimisation()

                    # update training loss for whole image:
                    train_loss_value += translation_loss_value.item()

                    # Validation:
                    with torch.no_grad():
                        # Forward pass:
                        self.model.eval()
                        if self.masking:
                            translated_images_gpu = self.masked_model(
                                input_images_gpu
                            )
                        else:
                            translated_images_gpu = self.model(input_images_gpu)

                        # apply forward model:
                        forward_model_images_gpu = self._forward_model(
                            translated_images_gpu
                        )

                        # validation masking: only the validation voxels count.
                        u = forward_model_images_gpu * validation_mask_images_gpu
                        v = target_images_gpu * validation_mask_images_gpu

                        # translation loss (per voxel):
                        if self.masking:
                            translation_loss = loss_function(u, v, None)
                        else:
                            translation_loss = loss_function(u, v)

                        # loss values:
                        translation_loss_value = translation_loss.mean().cpu().item()

                        # update validation loss for whole image:
                        val_loss_value += translation_loss_value

                    # Count each batch exactly once so that both sums below
                    # are divided by the real number of batches.
                    iteration += 1

                train_loss_value /= iteration
                lprint(f"Training loss value: {train_loss_value}")

                val_loss_value /= iteration
                lprint(f"Validation loss value: {val_loss_value}")

                # Learning rate schedule:
                scheduler.step(val_loss_value)

                if val_loss_value < best_val_loss_value:
                    lprint("## New best val loss!")
                    if val_loss_value < best_val_loss_value - self.patience_epsilon:
                        lprint("## Good enough to reset patience!")
                        patience_counter = 0

                    # Update best val loss value:
                    best_val_loss_value = val_loss_value

                    # Save model (copied to CPU):
                    best_model_state_dict = OrderedDict(
                        {k: v.to('cpu') for k, v in self.model.state_dict().items()}
                    )

                else:
                    if (
                        epoch % max(1, self.reload_best_model_period) == 0
                        and best_model_state_dict
                    ):
                        lprint("Reloading best models to date!")
                        self.model.load_state_dict(best_model_state_dict)

                    if patience_counter > self.patience:
                        lprint("Early stopping!")
                        break

                    # No improvement:
                    lprint(
                        f"No improvement of validation losses, patience = {patience_counter}/{self.patience} "
                    )
                    patience_counter += 1

                lprint(f"## Best val loss: {best_val_loss_value}")

                if self._stop_training_flag:
                    lprint("Training interupted!")
                    break

        # No state is recorded when no epoch ran or when the validation loss
        # never compared below infinity (e.g. NaN); loading None would raise.
        if best_model_state_dict is not None:
            lprint("Reloading best models to date!")
            self.model.load_state_dict(best_model_state_dict)
def train(
    self,
    input_image,
    target_image=None,
    batch_dims=None,
    channel_dims=None,
    tile_size=None,
    train_valid_ratio=0.1,
    callback_period=3,
    jinv=None,
):
    """Train to translate a given input image to a given output image.

    Handles normaliser set-up and calibration, intensity normalisation and
    padding, then hands over to ``self._train``.

    :param input_image: image to learn from
    :param target_image: image to learn to produce; when None the input image
        itself is used and training is self-supervised
    :param batch_dims: per-dimension batch flags; defaults to all False
    :param channel_dims: per-dimension channel flags, forwarded to the
        normalisers as given
    :param tile_size: forwarded to ``self._train``
    :param train_valid_ratio: forwarded to ``self._train``
    :param callback_period: forwarded to ``self._train``
    :param jinv: forwarded to ``self._train``
    """
    if target_image is None:
        target_image = input_image

    with lsection(
        f"Learning to translate from image of dimensions {input_image.shape!s} to {target_image.shape!s} ."
    ):
        lprint("Running garbage collector...")
        gc.collect()

        # Same object for input and target means self-supervised training:
        self.self_supervised = input_image is target_image
        lprint(
            "Training is self-supervised."
            if self.self_supervised
            else "Training is supervised."
        )

        # Default: no dimension is a batch dimension.
        if batch_dims is None:
            batch_dims = (False,) * len(input_image.shape)

        # Normalisers: a single one is shared in the self-supervised case.
        self.input_normaliser = self.normalizer_class(
            transform=self.normaliser_transform, clip=self.normaliser_clip
        )
        if self.self_supervised:
            self.target_normaliser = self.input_normaliser
        else:
            self.target_normaliser = self.normalizer_class(
                transform=self.normaliser_transform, clip=self.normaliser_clip
            )

        # Calibration of the normaliser(s):
        self.input_normaliser.calibrate(input_image)
        if not self.self_supervised:
            self.target_normaliser.calibrate(target_image)

        # Normalisation of intensity values:
        norm_input = self.input_normaliser.normalise(
            input_image, batch_dims=batch_dims, channel_dims=channel_dims
        )
        if self.self_supervised:
            norm_target = norm_input
        else:
            norm_target = self.target_normaliser.normalise(
                target_image, batch_dims=batch_dims, channel_dims=channel_dims
            )

        # Padding avoids border effects. Since translation pads too, training
        # has to pad as well because of location-aware features such as
        # large-scale or spatial features.
        norm_input = self._pad_norm_image(norm_input)
        norm_target = self._pad_norm_image(norm_target)

        self._train(
            norm_input,
            norm_target,
            tile_size=tile_size,
            train_valid_ratio=train_valid_ratio,
            callback_period=callback_period,
            jinv=jinv,
        )
def _get_tilling_strategy_and_margins(
    self,
    image,
    max_voxels_per_tile,
    min_margin,
    max_margin,
    suggested_tile_size=None,
):
    """Decide into how many chunks each dimension is split, and with which margins.

    The required split factor is the largest of three constraints: estimated
    memory needs versus available memory, the maximum number of voxels per
    tile, and the suggested tile size. The first two dimensions of ``image``
    are treated as batch and channel; the remaining ones as spatio-temporal.

    Parameters
    ----------
    image
        Array to tile; only ``shape`` and ``size`` are read here, and the
        array is passed to ``self._estimate_memory_needed_and_available``.
    max_voxels_per_tile
        Upper bound on the number of voxels per tile.
    min_margin
        Lower bound for margins; ``None`` means 0.
    max_margin
        Upper bound for margins; ``None`` means no bound.
    suggested_tile_size
        Suggested tile side length; ``None`` means no suggestion.

    Returns
    -------
    tuple
        ``(tilling_strategy, margins)``, both tuples with one entry per image
        dimension: the number of chunks and the margin for that dimension.
    """
    # We will store the batch strategy as a list of integers representing the
    # number of chunks per dimension:
    with lsection("Determine tilling strategy:"):

        if suggested_tile_size is None:
            # An infinite tile size makes the corresponding split factor 0.
            suggested_tile_size = math.inf

        # image shape:
        shape = image.shape
        num_spatiotemp_dim = len(shape) - 2

        lprint(f"image shape = {shape}")
        lprint(f"max_voxels_per_tile = {max_voxels_per_tile}")

        # Estimated amount of memory needed for storing all features:
        (
            estimated_memory_needed,
            total_memory_available,
        ) = self._estimate_memory_needed_and_available(image)
        lprint(f"Estimated amount of memory needed: {estimated_memory_needed}")

        # Available physical memory; the reported reserve is derived from the
        # configured ratio instead of being hard-coded:
        total_memory_available *= self.max_memory_usage_ratio
        reserved_percent = 100 * (1 - self.max_memory_usage_ratio)
        lprint(
            f"Available memory (we reserve {reserved_percent:g}% for 'comfort'): {total_memory_available}"
        )

        # How much do we need to tile because of memory, if at all?
        split_factor_mem = estimated_memory_needed / total_memory_available
        lprint(
            f"How much do we need to tile because of memory? : {split_factor_mem} times."
        )

        # how much do we have to tile because of the limit on the number of voxels per tile?
        split_factor_max_voxels = image.size / max_voxels_per_tile
        lprint(
            f"How much do we need to tile because of the limit on the number of voxels per tile? : {split_factor_max_voxels} times."
        )

        # how much do we have to tile because of the suggested tile size?
        split_factor_suggested_tile_size = image.size / (
            suggested_tile_size ** num_spatiotemp_dim
        )
        lprint(
            f"How much do we need to tile because of the suggested tile size? : {split_factor_suggested_tile_size} times."
        )

        # we keep the max:
        desired_split_factor = max(
            split_factor_mem,
            split_factor_max_voxels,
            split_factor_suggested_tile_size,
        )
        # We cannot split less than 1 time:
        desired_split_factor = max(1, int(math.ceil(desired_split_factor)))
        lprint(f"Desired split factor: {desired_split_factor}")

        # Number of batches:
        num_batches = shape[0]

        # Does the number of batches split the data enough?
        if num_batches < desired_split_factor:
            # Not enough splitting happening along the batch dimension, we
            # need to split further -- how much?
            rest_split_factor = desired_split_factor / num_batches
            lprint(
                f"Not enough splitting happening along the batch dimension, we need to split spatio-temp dims by: {rest_split_factor}"
            )

            # let's split the dimensions in a way proportional to their lengths:
            split_per_dim = (rest_split_factor / numpy.prod(shape[2:])) ** (
                1 / num_spatiotemp_dim
            )

            spatiotemp_tilling_strategy = tuple(
                max(1, int(math.ceil(split_per_dim * s))) for s in shape[2:]
            )

            tilling_strategy = (num_batches, 1) + spatiotemp_tilling_strategy
            lprint(f"Preliminary tilling strategy is: {tilling_strategy}")

            # We correct for eventual oversplitting by favouring splitting
            # over the front dimensions: once the product of the splits kept
            # so far reaches the desired factor, later dimensions are not split.
            current_splitting_factor = 1
            corrected_tilling_strategy = []
            split_factor_reached = False
            for s in tilling_strategy:
                if split_factor_reached:
                    corrected_tilling_strategy.append(1)
                else:
                    corrected_tilling_strategy.append(s)
                    current_splitting_factor *= s

                if current_splitting_factor >= desired_split_factor:
                    split_factor_reached = True

            tilling_strategy = tuple(corrected_tilling_strategy)

        else:
            # Splitting along the batch dimension alone is sufficient.
            tilling_strategy = (desired_split_factor, 1) + (1,) * len(shape[2:])

        lprint(f"Tilling strategy is: {tilling_strategy}")

        # Handles defaults:
        if max_margin is None:
            max_margin = math.inf
        if min_margin is None:
            min_margin = 0

        # First we estimate the shape of a tile:
        estimated_tile_shape = tuple(
            int(round(s / ts)) for s, ts in zip(shape[2:], tilling_strategy[2:])
        )
        lprint(f"The estimated tile shape is: {estimated_tile_shape}")

        # Limit margins:
        # We automatically set the margin of the tile size:
        # the max-margin factor guarantees that tilling will incur no more
        # than a given max tiling overhead:
        margin_factor = 0.5 * (
            ((1 + self.max_tilling_overhead) ** (1 / num_spatiotemp_dim)) - 1
        )
        margins = tuple(int(s * margin_factor) for s in estimated_tile_shape)

        # Limit the margin to something reasonable (provided or automatically computed):
        margins = tuple(min(max_margin, m) for m in margins)
        margins = tuple(max(min_margin, m) for m in margins)

        # We add the batch and channel dimensions:
        margins = (0, 0) + margins

        # We only need margins if we split a dimension:
        margins = tuple(
            (0 if split == 1 else margin)
            for margin, split in zip(margins, tilling_strategy)
        )

        return tilling_strategy, margins
def translate(
    self,
    input_image,
    translated_image=None,
    batch_dims=None,
    channel_dims=None,
    tile_size=None,
    denormalise_values=True,
    leave_as_float=False,
    clip=True,
):
    """Translate an input image into an output image with the learned function.

    The image is normalised, padded, translated (as a whole or tile by tile),
    cropped and finally denormalised with the target normaliser.

    :param input_image: image to translate
    :param translated_image: optional destination array that is filled in
        place; when None the computed array is returned directly
    :param batch_dims: per-dimension batch flags; defaults to all False
    :param channel_dims: per-dimension channel flags; defaults to all False
    :param tile_size: suggested tile size; 0 forces translation without tiling
    :param denormalise_values: accepted but not forwarded to ``denormalise``
        by this implementation
    :param leave_as_float: forwarded to the target normaliser's ``denormalise``
    :param clip: forwarded to the target normaliser's ``denormalise``
    :return: the translated image
    """
    with lsection(
        f"Predicting output image from input image of dimension {input_image.shape}"
    ):
        # Defaults: no dimension is a batch or a channel dimension.
        ndim = len(input_image.shape)
        if batch_dims is None:
            batch_dims = (False,) * ndim
        if channel_dims is None:
            channel_dims = (False,) * ndim

        # Spatio-temporal dimensions are those flagged neither batch nor channel:
        num_spatiotemp_dim = sum(
            1
            for is_batch, is_channel in zip(batch_dims, channel_dims)
            if not (is_batch or is_channel)
        )

        # Normalisation of the input values comes first:
        padded_input = self.input_normaliser.normalise(
            input_image, batch_dims=batch_dims, channel_dims=channel_dims
        )

        # After supervised training, the target normaliser's permutated image
        # shape is refreshed so that batch dimensions of a different size than
        # those used for training can be accommodated.
        if not self.self_supervised:
            permutated_shape = self.target_normaliser.shape_normalize(
                input_image, batch_dims=batch_dims, channel_dims=channel_dims
            )[2]
            self.target_normaliser.permutated_image_shape = permutated_shape

        # Padding the input array avoids annoying border effects:
        padded_input = self._pad_norm_image(padded_input)
        padded_shape = padded_input.shape

        # Spatio-temporal shape:
        spatiotemp_shape = padded_shape[-num_spatiotemp_dim:]

        result = None

        if tile_size == 0:
            # Tiling is explicitly disabled (not the default): translate at once.
            result = self._translate(padded_input, whole_image_shape=padded_shape)
        else:
            # Tiled inference, because of a lack of memory or because a small
            # batch size was requested.
            tilling_strategy, margins = self._get_tilling_strategy_and_margins(
                padded_input,
                self.max_voxels_per_tile,
                self.tile_min_margin,
                self.tile_max_margin,
                suggested_tile_size=tile_size,
            )
            lprint(f"Tilling strategy: {tilling_strategy}")
            lprint(f"Margins for tiles: {margins} .")

            # Tile slice objects, with and without margins:
            slices_with_margins = list(
                nd_split_slices(padded_shape, tilling_strategy, margins=margins)
            )
            slices_without_margins = list(
                nd_split_slices(padded_shape, tilling_strategy)
            )

            number_of_tiles = len(slices_without_margins)
            lprint(f"Number of tiles (slices): {number_of_tiles}")

            tile_pairs = zip(slices_with_margins, slices_without_margins)
            for counter, (slice_margin_tuple, slice_tuple) in enumerate(
                tile_pairs, start=1
            ):
                with lsection(
                    f"Current tile: {counter}/{number_of_tiles}, slice: {slice_tuple} "
                ):
                    # Extract a copy of the tile, margins included:
                    input_tile = padded_input[slice_margin_tuple].copy()

                    # The actual translation:
                    lprint("Translating...")
                    translated_tile = self._translate(
                        input_tile,
                        image_slice=slice_margin_tuple,
                        whole_image_shape=padded_shape,
                    )

                    # Slice that cuts the margins out of the translated tile:
                    lprint("Removing margins...")
                    inner_slice = remove_margin_slice(
                        padded_shape, slice_margin_tuple, slice_tuple
                    )

                    # The destination array is allocated just in time, once
                    # the dtype of the first translated tile is known.
                    if result is None:
                        result = offcore_array(
                            shape=padded_shape[:2] + spatiotemp_shape,
                            dtype=translated_tile.dtype,
                            max_memory_usage_ratio=self.max_memory_usage_ratio,
                        )

                    # Insert the tile, without margins, into the destination:
                    lprint("Inserting translated batch into result image...")
                    result[slice_tuple] = translated_tile[inner_slice]

        # Crop the padding:
        result = self._crop_norm_image(result)

        # Then denormalise:
        denormalised = self.target_normaliser.denormalise(
            result,
            leave_as_float=leave_as_float,
            clip=clip,
        )

        if translated_image is None:
            translated_image = denormalised
        else:
            # Fill the caller-provided array in place.
            translated_image[...] = denormalised

        return translated_image
def offcore_array(
    shape: Union[Tuple[int, ...], Generator[int, None, None]],
    dtype: numpy.dtype,
    force_memmap: bool = False,
    zarr_allowed: bool = False,
    no_memmap_limit: bool = True,
    max_memory_usage_ratio: float = 0.9,
):
    """
    Instantiates an array of given shape and dtype, 'off-core' (i.e. not in
    main memory) when needed. A plain in-memory array is returned when it fits
    in physical+swap memory; otherwise a memory mapped temporary file or a
    zarr-backed array is used.

    Parameters
    ----------
    shape
        Dimensions of the array; a generator is materialised into a tuple.
    dtype
        Element type of the array.
    force_memmap
        If True, the in-memory branch is skipped even when memory suffices.
    zarr_allowed
        Allows the zarr-backed fallback; only reached when
        ``no_memmap_limit`` is False.
    no_memmap_limit
        If True, a memory mapped array is used whenever the in-memory
        branch is not taken.
    max_memory_usage_ratio
        Fraction of physical+swap memory the array may occupy for it to be
        allocated in memory.

    Returns
    -------
    A ``numpy.ndarray``, a ``numpy.memmap`` or a zarr array.

    Raises
    ------
    MemoryError
        If the in-memory branch is not taken and neither memory mapping nor
        zarr is permitted by the flags.
    """
    # The shape is needed several times below (size computation, allocation),
    # which a generator cannot provide: materialise it once.
    if not isinstance(shape, (int, numpy.integer)):
        shape = tuple(shape)

    with lsection(f"Array of shape: {shape} and dtype: {dtype} requested"):

        size_in_bytes = numpy.prod(shape) * numpy.dtype(dtype).itemsize
        lprint(f'Array requested will be {(size_in_bytes / 1E6)} MB.')

        total_physical_memory_in_bytes = psutil.virtual_memory().total
        total_swap_memory_in_bytes = psutil.swap_memory().total
        total_mem_in_bytes = (
            total_physical_memory_in_bytes + total_swap_memory_in_bytes
        )
        lprint(
            f'There is {int(total_physical_memory_in_bytes / 1E6)} MB of physical memory'
        )
        lprint(f'There is {int(total_swap_memory_in_bytes / 1E6)} MB of swap memory')
        lprint(f'There is {int(total_mem_in_bytes / 1E6)} MB of total memory')

        is_enough_total_memory = (
            size_in_bytes < max_memory_usage_ratio * total_mem_in_bytes
        )

        if not force_memmap and is_enough_total_memory:
            lprint(
                'There is enough physical+swap memory -- we do not need to use a mem mapped array or zarr-backed array.'
            )
            array = numpy.zeros(shape, dtype=dtype)

        elif no_memmap_limit:
            lprint(
                'There is not enough physical+swap memory -- we will use a mem mapped array.'
            )
            temp_file = tempfile.NamedTemporaryFile(dir=OffCore.memmap_directory)
            lprint(
                f'The temporary memory mapped file is at: {temp_file.name} (but you might not be able to see it!)'
            )
            array = numpy.memmap(temp_file, dtype=dtype, mode='w+', shape=shape)

        elif zarr_allowed:
            lprint(
                'There is not enough physical+swap memory -- we will use a zarr-backed array.'
            )
            import zarr

            array = zarr.create(
                shape=shape, dtype=dtype, store=zarr.TempStore("output.zarr")
            )

        else:
            # Without this branch 'array' would be unbound and the function
            # would fail with an unrelated UnboundLocalError on return.
            raise MemoryError(
                f'Cannot allocate array of shape {shape} and dtype {dtype}: '
                f'in-memory allocation was not possible (or memory mapping was forced) '
                f'and neither memory mapping nor zarr are allowed.'
            )

        return array
def _train_loop(self, data_loader, optimizer):
    """Run the epoch loop with LR scheduling, logging, checkpointing and early stopping.

    Each epoch is delegated to ``self._epoch``; the resulting validation loss
    drives the best-model bookkeeping and the patience counter. The best state
    dict (copied to CPU) is loaded back into ``self.model`` when the loop ends.

    Parameters
    ----------
    data_loader
        Passed through to ``self._epoch``.
    optimizer
        Optimizer passed to the scheduler and to ``self._epoch``.

    Raises
    ------
    ValueError
        If ``self.scheduler`` is neither ``"plateau"`` nor ``"cosine"``.
    """
    # Scheduler:
    if self.scheduler == "plateau":
        scheduler = ReduceLROnPlateau(
            optimizer,
            "min",
            factor=self.reduce_lr_factor,
            verbose=True,
            patience=self.reduce_lr_patience,
        )
    elif self.scheduler == "cosine":
        scheduler = CosineAnnealingLR(
            optimizer, T_max=self.max_epochs, eta_min=1e-6
        )
    else:
        raise ValueError(
            f"Unknown scheduler: {self.scheduler}, supported: plateau, cosine"
        )

    self.best_val_loss_value = math.inf
    self.best_model_state_dict = None
    self.patience_counter = 0

    # NOTE(review): indentation was reconstructed from a whitespace-collapsed
    # source; which log section the trailing messages belong to should be
    # confirmed (it only changes log indentation).
    with lsection("Training loop:"):
        lprint(f"Maximum number of epochs: {self.max_epochs}")
        lprint(
            f"Training type: {'self-supervised' if self.self_supervised else 'supervised'}"
        )

        for epoch in range(self.max_epochs):
            with lsection(f"Epoch {epoch}:"):
                # One epoch of training
                train_loss_value, val_loss_value, loss_log_epoch = self._epoch(
                    optimizer, data_loader, epoch
                )

                # Learning rate schedule: only the plateau scheduler takes
                # the monitored metric as argument.
                if self.scheduler == "plateau":
                    scheduler.step(val_loss_value)
                else:
                    scheduler.step()

                # Logging:
                loss_log_epoch["masking_density"] = self.masked_model.density
                # The learning rate currently applied is read from the
                # optimizer rather than from a private scheduler attribute.
                loss_log_epoch["lr"] = optimizer.param_groups[0]["lr"]
                if not self.check:
                    wandb.log(loss_log_epoch)

                # Monitoring and saving:
                if val_loss_value < self.best_val_loss_value:
                    lprint("## New best val loss!")
                    if (
                        val_loss_value
                        < self.best_val_loss_value - self.patience_epsilon
                    ):
                        lprint("## Good enough to reset patience!")
                        self.patience_counter = 0

                    # Update best val loss value:
                    self.best_val_loss_value = val_loss_value

                    # Save model (copied to CPU):
                    self.best_model_state_dict = OrderedDict(
                        {k: v.to("cpu") for k, v in self.model.state_dict().items()}
                    )

                else:
                    if (
                        epoch % max(1, self.reload_best_model_period) == 0
                        and self.best_model_state_dict
                    ):
                        lprint("Reloading best models to date!")
                        self.model.load_state_dict(self.best_model_state_dict)

                    if self.patience_counter > self.patience:
                        lprint("Early stopping!")
                        break

                    # No improvement:
                    lprint(
                        f"No improvement of validation losses, patience = {self.patience_counter}/{self.patience} "
                    )
                    self.patience_counter += 1

                lprint(f"## Best val loss: {self.best_val_loss_value}")

                if self._stop_training_flag:
                    lprint("Training interupted!")
                    break

        # No state is recorded when no epoch ran or when the validation loss
        # never compared below infinity (e.g. NaN); loading None would raise.
        if self.best_model_state_dict is not None:
            lprint("Reloading best models to date!")
            self.model.load_state_dict(self.best_model_state_dict)