def visualise_weights(self):
    """Ask the wrapped model to visualise its weights, if it is able to.

    An ``AttributeError`` raised while looking up or running
    ``self.model.visualise_weights()`` is caught and reported through
    ``lprint`` instead of being propagated.
    """
    try:
        self.model.visualise_weights()
    except AttributeError:
        unavailable_message = (
            "Method 'visualise_weights()' unavailable, cannot visualise weights. "
        )
        lprint(unavailable_message)
def _train_loop(self, data_loader, optimizer, loss_function):
    """Run the parent training loop with kernel continuity regularisation off.

    ``kernel_continuity_regularisation`` is first set to ``False`` on
    ``self.model``; if that raises ``AttributeError`` a message is logged and
    training proceeds regardless.

    :param data_loader: forwarded unchanged to the parent ``_train_loop``
    :param optimizer: forwarded unchanged to the parent ``_train_loop``
    :param loss_function: forwarded unchanged to the parent ``_train_loop``
    """
    try:
        setattr(self.model, "kernel_continuity_regularisation", False)
    except AttributeError:
        lprint("Cannot deactivate kernel continuity regularisation")

    super()._train_loop(data_loader, optimizer, loss_function)
def _debug_allocation(self, info):
    """Log the CuPy default memory pool usage when allocation debugging is on.

    Nothing is logged unless the private debug flag is truthy and
    ``self.backend`` is ``"cupy"``.

    :param info: label inserted into the log line
    """
    if not self.__debug_allocation:
        return
    if self.backend != "cupy":
        return

    import cupy

    used_megabytes = cupy.get_default_memory_pool().used_bytes() / 1e6
    lprint(f"CUDA memory usage {info}: {used_megabytes} MB")
def __init__(
    self,
    max_epochs=2048,
    patience=None,
    patience_epsilon=0.0,
    learning_rate=0.01,
    batch_size=8,
    model_class=UNet,
    masking=True,
    masking_density=0.01,
    loss='l1',
    normaliser_type='percentile',
    balance_training_data=None,
    keep_ratio=1,
    max_voxels_for_training=4e6,
    monitor=None,
    use_cuda=True,
    device_index=0,
):
    """
    Constructs an image translator using the pytorch deep learning library.

    :param max_epochs: maximum number of epochs; also the default patience
        and the best-model reload period
    :param patience: early-stopping patience, defaults to ``max_epochs``
        when None
    :param patience_epsilon: stored as ``self.patience_epsilon``
    :param learning_rate: stored as ``self.learning_rate``
    :param batch_size: stored as ``self.batch_size``
    :param model_class: stored as ``self.model_class``
    :param masking: stored as ``self.masking``
    :param masking_density: stored as ``self.masking_density``
    :param loss: stored as ``self.loss``
    :param normaliser_type: normaliser type
    :param balance_training_data: balance data ? (limits number training entries per target value histogram bin)
    :param keep_ratio: stored as ``self.keep_ratio``
    :param max_voxels_for_training: stored as ``self.max_voxels_for_training``
    :param monitor: monitor to track progress of training externally (used by UI)
    :param use_cuda: use a CUDA device, only honoured when torch reports at
        least one CUDA device
    :param device_index: index of the CUDA device to use
    """
    super().__init__(normaliser_type, monitor=monitor)

    use_cuda = use_cuda and (torch.cuda.device_count() > 0)
    self.device = torch.device(f"cuda:{device_index}" if use_cuda else "cpu")
    lprint(f"Using device: {self.device}")

    self.max_epochs = max_epochs
    self.patience = max_epochs if patience is None else patience
    self.patience_epsilon = patience_epsilon
    self.learning_rate = learning_rate
    self.batch_size = batch_size
    self.loss = loss
    self.max_voxels_for_training = max_voxels_for_training
    self.keep_ratio = keep_ratio
    self.balance_training_data = balance_training_data

    self.model_class = model_class

    self.l1_weight_regularisation = 1e-6
    self.l2_weight_regularisation = 1e-6
    self.training_noise = 0.1
    self.reload_best_model_period = max_epochs  # //2
    # Derived from the *resolved* patience: the raw argument defaults to None
    # and `None // 2` raises a TypeError.
    self.reduce_lr_patience = self.patience // 2
    self.reduce_lr_factor = 0.9
    self.masking = masking
    self.masking_density = masking_density
    self.optimiser_class = ESAdam
    self.max_tile_size = 1024  # TODO: adjust based on available memory

    self._stop_training_flag = False
def _additional_losses(self, translated_image, forward_model_image):
    """Compute the optional regularisation terms added to the translation loss.

    Three independent terms are accumulated, each enabled by a non-zero
    weight on the instance:

    * ``self.bounds_loss``: penalises values of ``translated_image`` below 0
      or above 1 (squared mean of the excess);
    * ``self.sharpening``: rewards a large L2 norm over dims (2, 3);
    * ``self.entropy``: rewards a large value of ``entropy(translated_image)``.

    :param translated_image: output of the model
    :param forward_model_image: not used by this implementation
    :return: sum of the enabled terms, the int 0 when none is enabled
    """
    loss = 0

    # Bounds loss:
    if self.bounds_loss and self.bounds_loss != 0:
        epsilon = 0 * 1e-8
        # The two halves are summed out-of-place: autograd keeps the output of
        # relu for its backward pass, so updating it with `+=` invalidates the
        # saved tensor.
        below_zero = F.relu(-translated_image - epsilon)
        above_one = F.relu(translated_image - 1 - epsilon)
        bounds_loss = below_zero + above_one
        bounds_loss_value = bounds_loss.mean()
        lprint(f"bounds_loss_value = {bounds_loss_value}")
        loss += self.bounds_loss * bounds_loss_value**2

    # Sharpen loss_deconvolution:
    if self.sharpening and self.sharpening != 0:
        image_for_loss = translated_image
        num_elements = image_for_loss[0, 0].nelement()
        sharpening_loss = -torch.norm(
            image_for_loss, dim=(2, 3), keepdim=True, p=2
        ) / (num_elements**2)
        lprint(f"sharpening loss = {sharpening_loss}")
        loss += self.sharpening * sharpening_loss.mean()

    # Max entropy loss:
    if self.entropy and self.entropy != 0:
        entropy_value = entropy(translated_image)
        lprint(f"entropy_value = {entropy_value}")
        loss += -self.entropy * entropy_value

    return loss
def download_from_gdrive(
    id, name, dest_folder=datasets_folder, overwrite=False, unzip=False
):
    """Download a Google Drive file into ``dest_folder`` unless already there.

    :param id: Google Drive file id, inserted into the download URL
    :param name: file name to use inside ``dest_folder``
    :param dest_folder: destination folder, created if missing
    :param overwrite: download even if the destination file already exists
    :param unzip: extract the downloaded archive into ``dest_folder``
    :return: path of the downloaded file, or None when the download was
        skipped because the file already exists
    """
    # Creating an already existing folder is fine; any other failure would
    # make the download fail anyway, so it is no longer silently swallowed.
    os.makedirs(dest_folder, exist_ok=True)

    url = f'https://drive.google.com/uc?id={id}'
    output_path = join(dest_folder, name)

    if not overwrite and exists(output_path):
        lprint(f"Not downloading file {output_path} as it already exists.")
        return None

    lprint(f"Downloading file {output_path} as it does not exist yet.")
    gdown.download(url, output_path, quiet=False)

    if unzip:
        lprint(f"Unzipping file {output_path}...")
        # Context manager: the archive is closed even if extraction fails.
        with zipfile.ZipFile(output_path, 'r') as zip_ref:
            zip_ref.extractall(dest_folder)

    return output_path
def __init__(
    self,
    psf_kernel: numpy.ndarray = None,
    broaden_psf: int = 1,
    sharpening: float = 0.0,
    bounds_loss: float = 0.1,
    entropy: float = 0.0,
    clip_before_psf: bool = True,
    fft_psf: Union[str, bool] = "auto",
    **kwargs,
):
    """
    Constructs a CNN image translator using the pytorch deep learning library.

    :param psf_kernel: PSF kernel, kept as ``self.provided_psf_kernel``
    :param broaden_psf: stored as ``self.broaden_psf``
    :param sharpening: stored as ``self.sharpening``
    :param bounds_loss: stored as ``self.bounds_loss``
    :param entropy: stored as ``self.entropy``
    :param clip_before_psf: torch.clamp(x, 0, 1) before PSF convolution;
        forced to False when ``self.standardize_image`` is set
    :param fft_psf: "auto" or True or False
    :param kwargs: forwarded to the parent constructor
    """
    super().__init__(**kwargs)

    # Clipping cannot be combined with image standardisation:
    clipping_unsupported = bool(self.standardize_image) and bool(clip_before_psf)
    if clipping_unsupported:
        lprint(
            "Clipping before PSF convolution is not supported when standardizing image"
        )

    # PSF settings:
    self.provided_psf_kernel = psf_kernel
    self.broaden_psf = broaden_psf

    # Weights of the additional loss terms:
    self.sharpening = sharpening
    self.bounds_loss = bounds_loss
    self.entropy = entropy

    self.clip_before_psf = False if clipping_unsupported else clip_before_psf
    self.fft_psf = fft_psf
def __init__(
    self,
    kernel_psf: ndarray,
    in_channels: int = 1,
    pad_mode: str = "reflect",
    trainable: bool = False,
    fft: Union[str, bool] = "auto",
    auto_padding: bool = False,
):
    """
    Parametrized trainable version of PSF

    :param kernel_psf: ndarray, PSF kernel (should be of required layer dimensionality)
    :param in_channels: number of input channels
    :param pad_mode: "reflect" for 2D, "replicate" for 3D
    :param trainable: if True, the kernel is trainable
    :param fft: ["auto", True, False] - if "auto" - use FFT if PSF has > 100 elements
    :param auto_padding: (bool) If True, automatically computes padding based on the
        signal size, kernel size and stride. Forced to False when FFT is not used.
    :raises ValueError: if ``fft`` is a string other than "auto"
    """
    super().__init__()

    self.kernel_size = kernel_psf.shape
    self.n_dim = len(kernel_psf.shape)
    self.in_channels = in_channels

    assert self.n_dim in (2, 3)

    if pad_mode is None:
        lprint("Padding mode is None, using default pad_mode='reflect'")
        pad_mode = "reflect"
    if self.n_dim == 3 and pad_mode == "reflect":
        # Not supported yet
        lprint(
            "Padding mode 'reflect' is not supported for 3D convolution, use 'replicate' instead"
        )
        pad_mode = "replicate"
    self.pad_mode = pad_mode
    self.pad = [k // 2 for k in self.kernel_size]

    self.fft = fft
    if self.fft == "auto":
        # Use FFT Conv if kernel has > 100 elements
        # (np.prod: the np.product alias was removed in NumPy 2.0)
        self.fft = np.prod(self.kernel_size) > 100
    if isinstance(self.fft, str):
        raise ValueError(f"Invalid fft value {self.fft}")

    # Automatic padding is only honoured together with FFT convolution:
    if not self.fft:
        auto_padding = False
    self.auto_padding = auto_padding
    lprint(f"Use FFT for PSF: {self.fft}, auto padding: {auto_padding}")

    # Two leading singleton axes are added in front of the squeezed kernel:
    self.psf = torch.from_numpy(kernel_psf.squeeze()[(None,) * 2]).float()
    self.psf = nn.Parameter(self.psf, requires_grad=trainable)
def _translate(self, input_image, image_slice=None, whole_image_shape=None):
    """Internal method that translates an input image on the basis of the trained model.

    Every ``input_image[batch][channel]`` plane is deconvolved independently
    with ``self.max_num_iterations`` Lucy-Richardson iterations, using the PSF
    held in ``self.psf_kernel_numpy``.

    :param input_image: input image, iterated over its two leading axes
    :param image_slice: accepted for interface compatibility, not used here
    :param whole_image_shape: accepted for interface compatibility, not used here
    :return: deconvolved image, same shape as ``input_image``, float32
    """
    import numpy

    convolve_method = self._get_convolution_method(
        input_image, self.psf_kernel_numpy
    )
    pad_method = self._get_pad_method(input_image)

    self.psf_kernel = self._convert_array_format_in(
        self.psf_kernel_numpy.astype(numpy.float32)
    )

    # The mirrored PSF must be flipped along *every* axis; flipping only the
    # first two is wrong for 3D kernels.
    flip_every_axis = (slice(None, None, -1),) * self.psf_kernel.ndim
    self.psf_kernel_mirror = self._convert_array_format_in(
        self.psf_kernel[flip_every_axis]
    )

    input_image = input_image.astype(numpy.float32, copy=False)

    deconvolved_image = offcore_array(
        shape=input_image.shape, dtype=input_image.dtype
    )

    lprint(f"Number of Lucy-Richardson iterations: {self.max_num_iterations}")

    # Loop invariants (self.padding may have been changed by the call to
    # _get_convolution_method above, so they are computed only now):
    pad_width = tuple(
        (max(self.padding, (s - 1) // 2), max(self.padding, (s - 1) // 2))
        for s in self.psf_kernel.shape
    )
    convolution_mode = "valid" if self.padding else "same"

    for batch_index, batch_image in enumerate(input_image):
        for channel_index, channel_image in enumerate(batch_image):

            channel_image = channel_image.clip(0, math.inf)
            channel_image = self._convert_array_format_in(channel_image)

            # Start from a flat image holding the mean intensity:
            candidate_deconvolved_image = numpy.full(
                channel_image.shape, float(numpy.mean(channel_image))
            )
            candidate_deconvolved_image = self._convert_array_format_in(
                candidate_deconvolved_image
            )

            for _ in range(self.max_num_iterations):

                if self.padding > 0:
                    padded_candidate_deconvolved_image = pad_method(
                        candidate_deconvolved_image,
                        pad_width=pad_width,
                        mode=self.padding_mode,
                    )
                else:
                    padded_candidate_deconvolved_image = candidate_deconvolved_image

                convolved = convolve_method(
                    padded_candidate_deconvolved_image,
                    self.psf_kernel,
                    mode=convolution_mode,
                )

                # Avoid dividing by zero:
                convolved[convolved == 0] = 1

                relative_blur = channel_image / convolved

                self._debug_allocation("after division")

                if self.padding:
                    # Same backend-aware padding routine as for the first
                    # convolution (was hard-coded to numpy.pad).
                    relative_blur = pad_method(
                        relative_blur, pad_width=pad_width, mode=self.padding_mode
                    )

                multiplicative_correction = convolve_method(
                    relative_blur,
                    self.psf_kernel_mirror,
                    mode=convolution_mode,
                )

                self._debug_allocation("after second convolution")

                candidate_deconvolved_image *= multiplicative_correction

            # NOTE(review): clipping is applied once, after the iterations
            # -- confirm against the intended behaviour.
            if self.clip:
                candidate_deconvolved_image[candidate_deconvolved_image > 1] = 1
                candidate_deconvolved_image[candidate_deconvolved_image < -1] = -1

            candidate_deconvolved_image = self._convert_array_format_out(
                candidate_deconvolved_image
            )

            deconvolved_image[
                batch_index, channel_index
            ] = candidate_deconvolved_image

    return deconvolved_image
def _train_step(
    self,
    input_images: T,
    target_images: T,
    valid_mask_images: T,
    epoch: int = 0,
) -> Tuple[T, Dict[str, float]]:
    """Run one forward pass in training mode and compute the loss to optimise.

    When ``self.training_noise`` is positive, noise decaying with ``epoch`` is
    added *in place* to ``input_images``. With ``self.two_pass`` the loss is a
    reconstruction term plus ``self.inv_mse_lambda`` times the square root of
    an invariance term; otherwise it is the mean translation loss computed on
    voxels where ``valid_mask_images`` is 0. The result of
    ``self._additional_losses`` is added when it is not None.

    :param input_images: batch of input images (modified in place by the noise)
    :param target_images: batch of target images
    :param valid_mask_images: validation mask, multiplied in as ``1 - mask``
    :param epoch: current epoch, drives the decay of the training noise
    :return: tuple (loss to back-propagate, dict of logged loss values)
    """
    self.model.train()
    loss_log = {}

    # Adding training noise to input:
    if self.training_noise > 0:
        with torch.no_grad():
            alpha = self.training_noise / (1 + (10000 * epoch / self.max_epochs))
            lprint(f"Training noise level: {alpha}")
            loss_log["training_noise"] = alpha
            input_images += alpha * torch.randn_like(
                input_images, device=input_images.device
            )

    # Forward pass (with masking when enabled):
    network = self.masked_model if self.masking else self.model
    translated_images = network(input_images)

    if self.two_pass:
        # Second pass, without masking:
        translated_images_full = self.model(input_images)
        forward_model_images_full = self._forward_model(translated_images_full)
        images_for_additional_losses = forward_model_images_full

        # no masking for reconstruction
        reconstruction_loss = self.loss_function(
            forward_model_images_full, target_images, None
        ).mean()
        loss_log["reconstruction_loss"] = reconstruction_loss.item()

        # Invariance is measured either before or after the forward model:
        if self.inv_mse_before_forward_model:
            unmasked_side = translated_images_full
            masked_side = translated_images
        else:
            unmasked_side = forward_model_images_full
            masked_side = self._forward_model(translated_images)

        u = unmasked_side * (1 - valid_mask_images)
        v = masked_side * (1 - valid_mask_images)

        mask = self.masked_model.get_mask()
        invariance_loss = self.loss_function(u, v, mask).mean()
        loss_log["invariance_loss"] = invariance_loss.item()

        translation_loss_value = (
            reconstruction_loss + self.inv_mse_lambda * torch.sqrt(invariance_loss)
        )
    else:
        forward_model_images = self._forward_model(translated_images)
        images_for_additional_losses = forward_model_images

        # validation masking:
        u = forward_model_images * (1 - valid_mask_images)
        v = target_images * (1 - valid_mask_images)

        # translation loss (per voxel):
        if self.masking:
            translation_loss = self.loss_function(
                u, v, self.masked_model.get_mask()
            )
        else:
            translation_loss = self.loss_function(u, v)

        # translation loss all voxels
        translation_loss_value = translation_loss.mean()
        loss_log["translation_loss"] = translation_loss_value.item()

    # Additional losses:
    additional_loss_value, additional_loss_log = self._additional_losses(
        translated_images, images_for_additional_losses
    )
    if additional_loss_value is not None:
        translation_loss_value += additional_loss_value
        loss_log.update(additional_loss_log)

    return translation_loss_value, loss_log
def _get_convolution_method(self, input_image, psf_kernel):
    """Select the convolution routine matching ``self.backend``.

    Backends are tried with fallbacks: "scipy-cupy" falls back to "cupy",
    "cupy" falls back to "gputools", and anything that fails or is unknown
    ends with scipy's ``convolve``. ``self.backend`` is updated along the way
    and ``self.padding`` is set to False for the gputools and cupy backends.

    :param input_image: image about to be processed (only passed along)
    :param psf_kernel: PSF kernel; with the cupy backend a kernel of more
        than 500 elements selects the FFT based convolution
    :return: callable accepting ``(in1, in2, mode=..., method=...)``
    """
    if self.backend == "scipy":
        lprint("Using scipy backend.")
        from scipy.signal import convolve

        return convolve

    elif self.backend == "scipy-cupy":
        try:
            lprint("Attempting to use scipy-cupy backend.")
            import cupy
            import scipy

            # NOTE(review): scipy.fft.set_backend is documented as a context
            # manager and cupy's scipy backend lives in cupyx.scipy.fft --
            # confirm that this call has the intended effect.
            scipy.fft.set_backend(cupy.fft)
            self.backend = "scipy"
            lprint("Succeeded to use scipy-cupy backend.")
            return self._get_convolution_method(input_image, psf_kernel)
        except Exception:
            track = traceback.format_exc()
            lprint(track)
            lprint("Failed to use scipy-cupy backend.")
            self.backend = "cupy"
            return self._get_convolution_method(input_image, psf_kernel)

    elif self.backend == "gputools":
        try:
            lprint("Attempting to use gputools backend.")
            # testing if gputools works:
            import gputools
            import numpy

            # try something simple and see if it crashes...
            data = numpy.ones((30, 40, 50))
            h = numpy.ones((10, 11, 12))
            gputools.convolve(data, h)

            def gputools_convolve(in1, in2, mode=None, method=None):
                return gputools.convolve(in1, in2)

            # gputools backend does not need extra padding:
            self.padding = False

            # This message used to claim that the *cupy* backend succeeded.
            lprint("Succeeded to use gputools backend.")
            return gputools_convolve
        except Exception:
            # Fall through to the scipy fallback below.
            track = traceback.format_exc()
            lprint(track)
            lprint("Failed to use gputools backend.")

    elif self.backend == "cupy":
        try:
            lprint("Attempting to use cupy backend.")
            # testing if cupy works:
            # try something simple and see if it crashes...
            import cupy
            import cupyx.scipy.ndimage

            data = cupy.ones((30, 40, 50))
            h = cupy.ones((10, 11, 12))
            cupyx.scipy.ndimage.convolve(data, h)

            # cupy backend does not need extra padding:
            self.padding = False

            def cupy_convolve(in1, in2, mode=None, method=None):
                return cupyx.scipy.ndimage.convolve(in1, in2, mode="reflect")

            lprint("Succeeded to use cupy backend.")
            if psf_kernel.size > 500:
                return self._cupy_convolve_fft
            else:
                return cupy_convolve
        except Exception:
            track = traceback.format_exc()
            lprint(track)
            lprint("Failed to use cupy backend, trying gputools")
            self.backend = "gputools"
            return self._get_convolution_method(input_image, psf_kernel)

    lprint("Falling back to scipy backend.")

    # this is scipy's convolve:
    from scipy.signal import convolve

    return convolve
def _epoch(
    self, optimizer, data_loader, epoch: int = 0
) -> Tuple[float, float, Dict[str, float]]:
    """Train on every batch of ``data_loader`` once and validate each batch.

    The masking density (when a masked model exists) is first moved slightly
    towards ``self.masking_density``. Each batch is then moved to
    ``self.device``, optionally standardised, used for one optimisation step
    and finally evaluated with ``self._valid_step``.

    :param optimizer: optimizer stepped once per batch
    :param data_loader: iterable of (input, target, validation mask) batches
    :param epoch: current epoch, forwarded to ``self._train_step``
    :return: tuple (mean training loss, mean validation loss, dict of logged
        values averaged over the batches)
    """
    if hasattr(self, "masked_model"):
        current_density = self.masked_model.density
        self.masked_model.density = (
            0.005 * self.masking_density + 0.995 * current_density
        )
        lprint(f"masking density: {self.masked_model.density}")

    train_loss_value = 0
    val_loss_value = 0
    loss_log_epoch = {}

    for input_images, target_images, val_mask_images in data_loader:
        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()

        inputs = input_images.to(self.device, non_blocking=True)
        targets = target_images.to(self.device, non_blocking=True)
        validation_masks = val_mask_images.to(self.device, non_blocking=True)

        if self.standardize_image:
            # Standardize input and target images with the same statistics
            inputs, mean, std = standardize(inputs)
            targets, _, _ = standardize(targets, mean, std)

        # Training step
        with autocast(enabled=self.amp):
            step_loss, loss_log = self._train_step(
                inputs, targets, validation_masks, epoch
            )

        # Backpropagation
        step_loss.backward()

        # Updating parameters
        optimizer.step()

        # post optimisation -- if needed:
        self.model.post_optimisation()

        # update training loss_deconvolution for whole image:
        train_loss_value += step_loss.item()

        # Validation:
        with autocast(enabled=self.amp):
            validation_loss = self._valid_step(inputs, targets, validation_masks)

        # update validation loss_deconvolution for whole image:
        loss_log["val_translation_loss"] = validation_loss
        val_loss_value += validation_loss

        # Accumulate the logged values, keyed by what the first batch logged:
        if loss_log_epoch:
            loss_log_epoch = {
                name: total + loss_log[name]
                for name, total in loss_log_epoch.items()
            }
        else:
            loss_log_epoch = deepcopy(loss_log)

    # Aggregate losses:
    iteration = len(data_loader)
    train_loss_value /= iteration
    lprint(f"Training loss value: {train_loss_value}")

    val_loss_value /= iteration
    lprint(f"Validation loss value: {val_loss_value}")

    averaged_log = {name: total / iteration for name, total in loss_log_epoch.items()}

    return train_loss_value, val_loss_value, averaged_log
def _train_loop(self, data_loader, optimizer):
    """Train for up to ``self.max_epochs`` epochs with early stopping.

    Each epoch runs ``self._epoch``, steps the learning-rate scheduler, logs
    the epoch losses and keeps a CPU copy of the model state with the best
    validation loss. Training stops early when the patience is exhausted or
    ``self._stop_training_flag`` is set. The best state found, if any, is
    loaded back into the model at the end.

    :param data_loader: training data loader, forwarded to ``self._epoch``
    :param optimizer: optimizer, forwarded to ``self._epoch`` and driven by
        the scheduler
    :raises ValueError: if ``self.scheduler`` is neither "plateau" nor "cosine"
    """
    # Scheduler:
    if self.scheduler == "plateau":
        scheduler = ReduceLROnPlateau(
            optimizer,
            "min",
            factor=self.reduce_lr_factor,
            verbose=True,
            patience=self.reduce_lr_patience,
        )
    elif self.scheduler == "cosine":
        scheduler = CosineAnnealingLR(
            optimizer, T_max=self.max_epochs, eta_min=1e-6
        )
    else:
        raise ValueError(
            f"Unknown scheduler: {self.scheduler}, supported: plateau, cosine"
        )

    self.best_val_loss_value = math.inf
    self.best_model_state_dict = None
    self.patience_counter = 0

    with lsection(f"Training loop:"):
        lprint(f"Maximum number of epochs: {self.max_epochs}")
        lprint(
            f"Training type: {'self-supervised' if self.self_supervised else 'supervised'}"
        )

        for epoch in range(self.max_epochs):
            with lsection(f"Epoch {epoch}:"):

                # One epoch of training
                train_loss_value, val_loss_value, loss_log_epoch = self._epoch(
                    optimizer, data_loader, epoch
                )

                # Learning rate schedule:
                if self.scheduler == "plateau":
                    scheduler.step(val_loss_value)
                else:
                    scheduler.step()

                # Logging:
                # The masked model only exists when masking is enabled.
                if hasattr(self, "masked_model"):
                    loss_log_epoch["masking_density"] = self.masked_model.density
                # Read the learning rate from the optimizer instead of the
                # private scheduler attribute `_last_lr`.
                loss_log_epoch["lr"] = optimizer.param_groups[0]["lr"]
                if not self.check:
                    wandb.log(loss_log_epoch)

                # Monitoring and saving:
                if val_loss_value < self.best_val_loss_value:
                    lprint(f"## New best val loss!")
                    if (
                        val_loss_value
                        < self.best_val_loss_value - self.patience_epsilon
                    ):
                        lprint(f"## Good enough to reset patience!")
                        self.patience_counter = 0

                    # Update best val loss value:
                    self.best_val_loss_value = val_loss_value

                    # Save model:
                    self.best_model_state_dict = OrderedDict(
                        {k: v.to("cpu") for k, v in self.model.state_dict().items()}
                    )

                else:
                    if (
                        epoch % max(1, self.reload_best_model_period) == 0
                        and self.best_model_state_dict
                    ):
                        lprint(f"Reloading best models to date!")
                        self.model.load_state_dict(self.best_model_state_dict)

                    if self.patience_counter > self.patience:
                        lprint(f"Early stopping!")
                        break

                    # No improvement:
                    lprint(
                        f"No improvement of validation losses, patience = {self.patience_counter}/{self.patience} "
                    )
                    self.patience_counter += 1

                lprint(f"## Best val loss: {self.best_val_loss_value}")

                if self._stop_training_flag:
                    lprint(f"Training interupted!")
                    break

    # No best state exists if no epoch ran or the validation loss never
    # compared below infinity (e.g. NaN); the current weights are then kept.
    if self.best_model_state_dict is not None:
        lprint(f"Reloading best models to date!")
        self.model.load_state_dict(self.best_model_state_dict)
def _train(
    self,
    input_image,
    target_image,
    tile_size=None,
    train_valid_ratio=0.1,
    callback_period=3,
    j_inv=False,
):
    """Build dataset, model and optimizer, then run the training loop.

    :param input_image: input image; axis 1 gives the number of input
        channels, axes 2+ are tiled
    :param target_image: target image; axis 1 gives the number of output
        channels
    :param tile_size: training tile size; when None each axis is capped at
        ``self.max_tile_size``
    :param train_valid_ratio: fraction of the voxels drawn for validation
    :param callback_period: not used by this implementation
    :param j_inv: when False, ``self.masking`` is switched off; None leaves
        it untouched
    """
    self._stop_training_flag = False

    if j_inv is not None and not j_inv:
        self.masking = False

    shape = input_image.shape
    num_input_channels = input_image.shape[1]
    num_output_channels = target_image.shape[1]
    num_spatiotemp_dim = input_image.ndim - 2

    # tile size:
    if tile_size is None:
        # tile_size = min(self.max_tile_size, min(shape[2:]))
        tile_size = tuple(min(self.max_tile_size, s) for s in shape[2:])
        lprint(f"Estimated max tile size {tile_size}")

    # Decide on how many voxels to be used for validation:
    num_val_voxels = int(train_valid_ratio * input_image.size)
    lprint(
        f"Number of voxels used for validation: {num_val_voxels} (train_valid_ratio={train_valid_ratio})"
    )

    # Generate random coordinates for these voxels:
    val_voxels = tuple(numpy.random.randint(d, size=num_val_voxels) for d in shape)
    lprint(f"Validation voxel coordinates: {val_voxels}")

    # Training Tile size:
    lprint(f"Train Tile dimensions: {tile_size}")

    # Prepare Training Dataset:
    dataset = self._get_dataset(
        input_image,
        target_image,
        self.self_supervised,
        tile_size=tile_size,
        mode="grid",
        validation_voxels=val_voxels,
        batch_size=self.batch_size,
    )
    lprint(f"Number tiles for training: {len(dataset)}")

    # Training Data Loader:
    # num_workers = max(3, os.cpu_count() // 2)
    num_workers = 0  # faster if data is already in memory...
    lprint(f"Number of workers for loading training/validation data: {num_workers}")
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,  # self.batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    # Model
    self.model = self.model_class(
        num_input_channels, num_output_channels, ndim=num_spatiotemp_dim
    ).to(self.device)

    number_of_parameters = sum(
        p.numel() for p in self.model.parameters() if p.requires_grad
    )
    lprint(
        f"Number of trainable parameters in {self.model_class} model: {number_of_parameters}"
    )

    if self.masking:
        self.masked_model = Masking(self.model, density=0.5).to(self.device)

    lprint(f"Optimiser class: {self.optimizer_class}")
    lprint(f"Learning rate : {self.learning_rate}")

    # Optimiser:
    # `self.optimizer_class` is a class, not an instance: isinstance() against
    # ESAdam was always False, so the noise level was never handed over.
    uses_esadam = isinstance(self.optimizer_class, type) and issubclass(
        self.optimizer_class, ESAdam
    )
    if uses_esadam:
        optimizer = partial(
            self.optimizer_class, start_noise_level=self.training_noise
        )
    else:
        optimizer = self.optimizer_class

    optimizer = optimizer(
        chain(self.model.parameters()),
        lr=self.learning_rate,
        weight_decay=self.l2_weight_regularisation,
    )

    lprint(f"Optimiser: {optimizer}")

    # Start training:
    try:
        self._train_loop(data_loader, optimizer)
    except KeyboardInterrupt:
        lprint("Training interrupted by user.")
        self._stop_training_flag = True
def __init__(
    self,
    max_epochs=2048,
    patience=None,
    patience_epsilon=0.0,
    learning_rate=0.01,
    batch_size=8,
    model_class=UNet,
    masking=True,
    two_pass=False,  # two-pass Noise2Same loss
    inv_mse_lambda: float = 2.0,
    inv_mse_before_forward_model=False,
    masking_density=0.01,
    training_noise=0.1,
    loss="l1",
    normaliser_type="percentile",
    balance_training_data=None,
    keep_ratio=1,
    max_voxels_for_training=4e6,
    monitor=None,
    use_cuda=True,
    device_index=0,
    max_tile_size: int = 1024,  # TODO: adjust based on available memory
    check: bool = True,
    optimizer: str = "esadam",
    scheduler: str = "step",
    standardize_image: bool = False,
    amp: bool = False,
):
    """
    Constructs an image translator using the pytorch deep learning library.

    :param max_epochs: maximum number of epochs; also the default patience
        and the best-model reload period
    :param patience: early-stopping patience, defaults to ``max_epochs``
        when None
    :param masking: use masking; forced to True when ``two_pass`` is set
    :param loss: "l1" or "l2" (case-insensitive); any other value keeps
        ``nn.L1Loss()``
    :param normaliser_type: normaliser type
    :param balance_training_data: balance data ? (limits number training entries per target value histogram bin)
    :param monitor: monitor to track progress of training externally (used by UI)
    :param use_cuda: use a CUDA device, only honoured when torch reports at
        least one CUDA device
    :param device_index: index of the CUDA device to use
    :param two_pass: bool, adopt Noise2Same two forward pass strategy (one masked, one unmasked)
    :param inv_mse_before_forward_model: bool, use invariance MSE before forward (PSF) model for Noise2Same
    :param check: bool, run smoke test
    :param optimizer: str, optimiser to use ["adam", "esadam"]
    :param scheduler: str, learning-rate scheduler name, stored as
        ``self.scheduler``. NOTE(review): the default "step" is not one of
        the values ("plateau", "cosine") accepted by the two-argument
        ``_train_loop`` in this file -- confirm the intended default.
    :param standardize_image: bool, standardize input images to zero mean and unit variance
    :param amp: stored as ``self.amp``

    The remaining parameters are stored unchanged on the instance.
    """
    super().__init__(normaliser_type, monitor=monitor)

    if two_pass and not masking:
        lprint("Force masking=True, it is needed in two-pass")
        masking = True

    use_cuda = use_cuda and (torch.cuda.device_count() > 0)
    self.device = torch.device(f"cuda:{device_index}" if use_cuda else "cpu")
    lprint(f"Using device: {self.device}")

    self.max_epochs = max_epochs
    self.patience = max_epochs if patience is None else patience
    self.patience_epsilon = patience_epsilon
    self.learning_rate = learning_rate
    self.batch_size = batch_size
    self.loss = loss
    self.max_voxels_for_training = max_voxels_for_training
    self.keep_ratio = keep_ratio
    self.balance_training_data = balance_training_data

    self.model_class = model_class

    self.l1_weight_regularisation = 1e-6
    self.l2_weight_regularisation = 1e-6
    self.training_noise = training_noise
    self.reload_best_model_period = max_epochs  # //2
    # Derived from the *resolved* patience: the raw argument defaults to None
    # and `None // 2` raises a TypeError.
    self.reduce_lr_patience = self.patience // 2
    self.reduce_lr_factor = 0.9
    self.masking = masking
    self.two_pass = two_pass
    self.inv_mse_before_forward_model = inv_mse_before_forward_model
    self.inv_mse_lambda = inv_mse_lambda
    self.masking_density = masking_density
    self.optimizer_class = ESAdam if optimizer == "esadam" else torch.optim.Adam
    self.scheduler = scheduler
    self.max_tile_size = max_tile_size

    self._stop_training_flag = False
    self.check = check
    self.standardize_image = standardize_image
    self.amp = amp

    # Denoise loss function:
    # The mask argument is optional so the same callable serves both calling
    # conventions, loss(u, v) and loss(u, v, mask): `self.masking` can be
    # switched off after construction.
    def _l2_loss(u, v, m=None):
        return (u - v) ** 2 if m is None else ((u - v) * m) ** 2

    def _l1_loss(u, v, m=None):
        return torch.abs(u - v) if m is None else torch.abs((u - v) * m)

    loss_function = nn.L1Loss()
    loss_name = self.loss.lower()
    if loss_name == "l2":
        lprint(f"Training/Validation loss: L2")
        loss_function = _l2_loss
    elif loss_name == "l1":
        lprint(f"Training/Validation loss: L1")
        loss_function = _l1_loss
    self.loss_function = loss_function

    # Monitor
    self.best_val_loss_value = math.inf
    self.best_model_state_dict = None
    self.patience_counter = 0
def offcore_array(
    shape: Union[Tuple[int, ...], Generator[int, None, None]],
    dtype: numpy.dtype,
    force_memmap: bool = False,
    zarr_allowed: bool = False,
    no_memmap_limit: bool = True,
    max_memory_usage_ratio: float = 0.9,
):
    """
    Instanciates an array of given shape and dtype in 'off-core' fashion i.e. not in main memory.
    Right now it simply uses memory mapping on temp file that is deleted after the file is closed

    Parameters
    ----------
    shape
        Shape of the array; a generator is materialised into a tuple first.
    dtype
        Data type of the array.
    force_memmap
        Skip the in-memory allocation even when enough memory is available.
    zarr_allowed
        Allow a zarr-backed array; only considered when ``no_memmap_limit``
        is False.
    no_memmap_limit
        Allow a memory mapped array whenever the in-memory allocation is not
        used.
    max_memory_usage_ratio
        Fraction of the physical+swap memory the array may occupy for an
        in-memory allocation.

    Returns
    -------
    A ``numpy.ndarray``, ``numpy.memmap`` or zarr array.

    Raises
    ------
    MemoryError
        If the in-memory allocation is not used and neither memory mapping
        nor zarr is allowed.
    """
    # A generator can be consumed only once, and the shape is needed both for
    # the size estimate and for the allocation:
    if isinstance(shape, (int, numpy.integer)):
        shape = (int(shape),)
    else:
        shape = tuple(shape)

    with lsection(f"Array of shape: {shape} and dtype: {dtype} requested"):

        # 64 bit accumulation, independent of the platform's default integer:
        number_of_elements = int(numpy.prod(shape, dtype=numpy.int64))
        size_in_bytes = number_of_elements * numpy.dtype(dtype).itemsize
        lprint(f'Array requested will be {(size_in_bytes / 1E6)} MB.')

        total_physical_memory_in_bytes = psutil.virtual_memory().total
        total_swap_memory_in_bytes = psutil.swap_memory().total
        total_mem_in_bytes = total_physical_memory_in_bytes + total_swap_memory_in_bytes
        lprint(
            f'There is {int(total_physical_memory_in_bytes / 1E6)} MB of physical memory'
        )
        lprint(f'There is {int(total_swap_memory_in_bytes / 1E6)} MB of swap memory')
        lprint(f'There is {int(total_mem_in_bytes / 1E6)} MB of total memory')

        is_enough_total_memory = (
            size_in_bytes < max_memory_usage_ratio * total_mem_in_bytes
        )

        if not force_memmap and is_enough_total_memory:
            lprint(
                f'There is enough physical+swap memory -- we do not need to use a mem mapped array or zarr-backed array.'
            )
            array = numpy.zeros(shape, dtype=dtype)

        elif no_memmap_limit:
            lprint(
                f'There is not enough physical+swap memory -- we will use a mem mapped array.'
            )
            temp_file = tempfile.NamedTemporaryFile(dir=OffCore.memmap_directory)
            lprint(
                f'The temporary memory mapped file is at: {temp_file.name} (but you might not be able to see it!)'
            )
            array = numpy.memmap(temp_file, dtype=dtype, mode='w+', shape=shape)

        elif zarr_allowed:
            lprint(
                f'There is not enough physical+swap memory -- we will use a zarr-backed array.'
            )
            import zarr

            array = zarr.create(
                shape=shape, dtype=dtype, store=zarr.TempStore("output.zarr")
            )

        else:
            # Used to fail with an UnboundLocalError on `return array`.
            raise MemoryError(
                f'Cannot allocate array of shape {shape} and dtype {dtype}: '
                f'in-memory allocation is not possible and neither memory mapping nor zarr is allowed.'
            )

        return array
def test_log():
    """Check that ``Log.depth`` follows the nesting of ``lsection`` blocks.

    Sections are nested seven levels deep, with the depth asserted at levels
    3, 5 and 7, and asserted to be back at 0 once every section is closed.
    """
    # This is required for this test to pass!
    Log.override_test_exclusion = True

    lprint('Test')

    with lsection('a section'):
        lprint('a line')
        lprint('another line')
        lprint('we are done')

        with lsection('a subsection'):
            lprint('another line')
            lprint('we are done')

            with lsection('a subsection'):
                lprint('another line')
                lprint('we are done')

                # section + two subsections are open at this point
                assert Log.depth == 3

                with lsection('a subsection'):
                    lprint('another line')
                    lprint('we are done')

                    with lsection('a subsection'):
                        lprint('another line')
                        lprint('we are done')

                        assert Log.depth == 5

                        with lsection('a subsection'):
                            lprint('another line')
                            lprint('we are done')

                            with lsection('a subsection'):
                                lprint('another line')
                                lprint('we are done')

                                assert Log.depth == 7

        # NOTE(review): the three sections below are taken to be siblings
        # directly under 'a section' -- confirm their nesting level.
        with lsection('a subsection'):
            lprint('another line')
            lprint('we are done')

        with lsection('a subsection'):
            lprint('another line')
            lprint('we are done')

        with lsection('a subsection'):
            lprint('another line')
            lprint('we are done')

    lprint('test is finished...')

    # every section has been exited
    assert Log.depth == 0
def _train_loop(self, data_loader, optimizer, loss_function):
    """Train for up to ``self.max_epochs`` epochs with early stopping.

    Every batch is used for one optimisation step on the voxels where the
    validation mask is 0, then evaluated without gradients on the voxels
    where the mask is 1. A CPU copy of the model state with the best mean
    validation loss is kept and loaded back into the model at the end.
    Training stops early when the patience is exhausted or
    ``self._stop_training_flag`` is set.

    :param data_loader: iterable of (input, target, validation mask) batches
    :param optimizer: optimizer stepped once per batch
    :param loss_function: per-voxel loss, called as ``loss(u, v, mask)`` when
        ``self.masking`` is set and as ``loss(u, v)`` otherwise
    """
    # Scheduler:
    scheduler = ReduceLROnPlateau(
        optimizer,
        'min',
        factor=self.reduce_lr_factor,
        verbose=True,
        patience=self.reduce_lr_patience,
    )

    best_val_loss_value = math.inf
    best_model_state_dict = None
    patience_counter = 0

    with lsection(f"Training loop:"):
        lprint(f"Maximum number of epochs: {self.max_epochs}")
        lprint(
            f"Training type: {'self-supervised' if self.self_supervised else 'supervised'}"
        )

        for epoch in range(self.max_epochs):
            with lsection(f"Epoch {epoch}:"):

                if hasattr(self, 'masked_model'):
                    self.masked_model.density = (
                        0.005 * self.masking_density
                        + 0.995 * self.masked_model.density
                    )
                    lprint(f"masking density: {self.masked_model.density}")

                train_loss_value = 0
                val_loss_value = 0
                iteration = 0

                for i, (input_images, target_images, val_mask_images) in enumerate(
                    data_loader
                ):
                    lprint(f"index: {i}, shape:{input_images.shape}")

                    input_images_gpu = input_images.to(self.device, non_blocking=True)
                    target_images_gpu = target_images.to(
                        self.device, non_blocking=True
                    )
                    validation_mask_images_gpu = val_mask_images.to(
                        self.device, non_blocking=True
                    )

                    # Adding training noise to input:
                    if self.training_noise > 0:
                        with torch.no_grad():
                            alpha = self.training_noise / (
                                1 + (10000 * epoch / self.max_epochs)
                            )
                            lprint(f"Training noise level: {alpha}")
                            training_noise = alpha * torch.randn_like(input_images)
                            input_images_gpu += training_noise.to(
                                input_images_gpu.device
                            )

                    # Clear gradients w.r.t. parameters
                    optimizer.zero_grad()

                    # Forward pass:
                    self.model.train()
                    if self.masking:
                        translated_images_gpu = self.masked_model(input_images_gpu)
                    else:
                        translated_images_gpu = self.model(input_images_gpu)

                    # apply forward model:
                    forward_model_images_gpu = self._forward_model(
                        translated_images_gpu
                    )

                    # validation masking:
                    u = forward_model_images_gpu * (1 - validation_mask_images_gpu)
                    v = target_images_gpu * (1 - validation_mask_images_gpu)

                    # translation loss (per voxel):
                    if self.masking:
                        mask = self.masked_model.get_mask()
                        translation_loss = loss_function(u, v, mask)
                    else:
                        translation_loss = loss_function(u, v)

                    # loss value (for all voxels):
                    translation_loss_value = translation_loss.mean()

                    # Additional losses:
                    additional_loss_value = self._additional_losses(
                        translated_images_gpu, forward_model_images_gpu
                    )
                    if additional_loss_value is not None:
                        translation_loss_value += additional_loss_value

                    # backpropagation:
                    translation_loss_value.backward()

                    # Updating parameters
                    optimizer.step()

                    # post optimisation -- if needed:
                    self.model.post_optimisation()

                    # update training loss_deconvolution for whole image:
                    train_loss_value += translation_loss_value.item()

                    # Validation:
                    with torch.no_grad():
                        # Forward pass:
                        self.model.eval()
                        if self.masking:
                            translated_images_gpu = self.masked_model(
                                input_images_gpu
                            )
                        else:
                            translated_images_gpu = self.model(input_images_gpu)

                        # apply forward model:
                        forward_model_images_gpu = self._forward_model(
                            translated_images_gpu
                        )

                        # validation masking:
                        u = forward_model_images_gpu * validation_mask_images_gpu
                        v = target_images_gpu * validation_mask_images_gpu

                        # translation loss (per voxel):
                        if self.masking:
                            translation_loss = loss_function(u, v, None)
                        else:
                            translation_loss = loss_function(u, v)

                        # loss values:
                        translation_loss_value = translation_loss.mean().cpu().item()

                        # update validation loss_deconvolution for whole image:
                        val_loss_value += translation_loss_value

                    # One count per batch: the counter used to be incremented
                    # after training *and* after validation, which halved
                    # both reported mean losses.
                    iteration += 1

                train_loss_value /= iteration
                lprint(f"Training loss value: {train_loss_value}")

                val_loss_value /= iteration
                lprint(f"Validation loss value: {val_loss_value}")

                # Learning rate schedule:
                scheduler.step(val_loss_value)

                if val_loss_value < best_val_loss_value:
                    lprint(f"## New best val loss!")
                    if val_loss_value < best_val_loss_value - self.patience_epsilon:
                        lprint(f"## Good enough to reset patience!")
                        patience_counter = 0

                    # Update best val loss value:
                    best_val_loss_value = val_loss_value

                    # Save model:
                    best_model_state_dict = OrderedDict(
                        {k: v.to('cpu') for k, v in self.model.state_dict().items()}
                    )

                else:
                    if (
                        epoch % max(1, self.reload_best_model_period) == 0
                        and best_model_state_dict
                    ):
                        lprint(f"Reloading best models to date!")
                        self.model.load_state_dict(best_model_state_dict)

                    if patience_counter > self.patience:
                        lprint(f"Early stopping!")
                        break

                    # No improvement:
                    lprint(
                        f"No improvement of validation losses, patience = {patience_counter}/{self.patience} "
                    )
                    patience_counter += 1

                lprint(f"## Best val loss: {best_val_loss_value}")

                if self._stop_training_flag:
                    lprint(f"Training interupted!")
                    break

    # No best state exists if no epoch ran or the validation loss never
    # compared below infinity (e.g. NaN); the current weights are then kept.
    if best_model_state_dict is not None:
        lprint(f"Reloading best models to date!")
        self.model.load_state_dict(best_model_state_dict)
def translate(
    self,
    input_image,
    translated_image=None,
    batch_dims=None,
    channel_dims=None,
    tile_size=None,
    denormalise_values=True,
    leave_as_float=False,
    clip=True,
):
    """
    Translates an input image into an output image using the trained model.

    The image is normalised, padded, translated (either in one go or tile
    by tile), cropped back, and finally denormalised.

    :param input_image: image to translate
    :param translated_image: optional destination array; when given, the
        result is written into it in place and it is returned
    :param batch_dims: per-axis flags marking batch axes (default: none)
    :param channel_dims: per-axis flags marking channel axes (default: none)
    :param tile_size: suggested tile size handed to the tilling strategy;
        the value 0 forces translation of the whole image without tiling
    :param denormalise_values: accepted by the signature but not forwarded
        to the denormaliser by this implementation
    :param leave_as_float: forwarded to the target normaliser's denormalise
    :param clip: forwarded to the target normaliser's denormalise
    :return: the translated image
    """
    with lsection(
        f"Predicting output image from input image of dimension {input_image.shape}"
    ):
        # Absent flags mean: no axis is a batch (resp. channel) axis.
        no_flags = (False,) * len(input_image.shape)
        if batch_dims is None:
            batch_dims = no_flags
        if channel_dims is None:
            channel_dims = no_flags

        # Axes flagged neither batch nor channel are spatio-temporal:
        num_spatiotemp_dim = sum(
            1 for b, c in zip(batch_dims, channel_dims) if not (b or c)
        )

        # Input values are normalised first:
        normalised_input_image = self.input_normaliser.normalise(
            input_image, batch_dims=batch_dims, channel_dims=channel_dims
        )

        # After supervised training the target normaliser still remembers
        # the permutated shape seen at training time; refresh it from the
        # current input so that other batch sizes can be denormalised.
        if not self.self_supervised:
            shape_info = self.target_normaliser.shape_normalize(
                input_image, batch_dims=batch_dims, channel_dims=channel_dims
            )
            _, _, self.target_normaliser.permutated_image_shape = shape_info

        # Padding avoids border effects:
        normalised_input_image = self._pad_norm_image(normalised_input_image)
        normalised_input_shape = normalised_input_image.shape

        # Spatio-temporal shape (trailing axes of the padded image):
        spatiotemp_shape = normalised_input_shape[-num_spatiotemp_dim:]

        normalised_translated_image = None

        if tile_size == 0:
            # Tiling is explicitly disabled (not the default): one pass.
            normalised_translated_image = self._translate(
                normalised_input_image,
                whole_image_shape=normalised_input_shape,
            )
        else:
            # Tiled inference, needed when memory is short or when a small
            # tile size was requested.
            tilling_strategy, margins = self._get_tilling_strategy_and_margins(
                normalised_input_image,
                self.max_voxels_per_tile,
                self.tile_min_margin,
                self.tile_max_margin,
                suggested_tile_size=tile_size,
            )
            lprint(f"Tilling strategy: {tilling_strategy}")
            lprint(f"Margins for tiles: {margins} .")

            # Tile slices, with and without margins:
            tile_slices_margins = list(
                nd_split_slices(
                    normalised_input_shape, tilling_strategy, margins=margins
                )
            )
            tile_slices = list(
                nd_split_slices(normalised_input_shape, tilling_strategy)
            )

            number_of_tiles = len(tile_slices)
            lprint(f"Number of tiles (slices): {number_of_tiles}")

            paired_slices = zip(tile_slices_margins, tile_slices)
            for counter, (slice_margin_tuple, slice_tuple) in enumerate(
                paired_slices, start=1
            ):
                with lsection(
                    f"Current tile: {counter}/{number_of_tiles}, slice: {slice_tuple} "
                ):
                    # Extract the tile, margins included:
                    input_image_tile = normalised_input_image[
                        slice_margin_tuple
                    ].copy()

                    # Actual translation of that tile:
                    lprint(f"Translating...")
                    translated_image_tile = self._translate(
                        input_image_tile,
                        image_slice=slice_margin_tuple,
                        whole_image_shape=normalised_input_shape,
                    )

                    # Slice that cuts the margins off the translated tile:
                    lprint(f"Removing margins...")
                    remove_margin_slice_tuple = remove_margin_slice(
                        normalised_input_shape, slice_margin_tuple, slice_tuple
                    )

                    # The destination is allocated lazily, once the dtype
                    # of the first translated tile is known.
                    if normalised_translated_image is None:
                        translated_image_shape = (
                            normalised_input_shape[:2] + spatiotemp_shape
                        )
                        normalised_translated_image = offcore_array(
                            shape=translated_image_shape,
                            dtype=translated_image_tile.dtype,
                            max_memory_usage_ratio=self.max_memory_usage_ratio,
                        )

                    # Write the margin-free tile into the destination:
                    lprint(f"Inserting translated batch into result image...")
                    normalised_translated_image[
                        slice_tuple
                    ] = translated_image_tile[remove_margin_slice_tuple]

        # Remove the padding added above:
        normalised_translated_image = self._crop_norm_image(
            normalised_translated_image
        )

        # Back to the value range of the target:
        denormalised_translated_image = self.target_normaliser.denormalise(
            normalised_translated_image,
            leave_as_float=leave_as_float,
            clip=clip,
        )

        if translated_image is not None:
            # Fill the caller-provided array in place.
            translated_image[...] = denormalised_translated_image
        else:
            translated_image = denormalised_translated_image

        return translated_image
def _train(
    self,
    input_image,
    target_image,
    train_valid_ratio=0.1,
    callback_period=3,
    jinv=False,
    tile_size=None,
):
    """
    Builds the dataset, model, optimiser and loss, then runs the training loop.

    The images are indexed here as (batch, channel, *spatio-temporal axes):
    axis 1 gives the channel counts and the remaining axes after it are
    treated as spatio-temporal.

    :param input_image: normalised input image
    :param target_image: normalised target image
    :param train_valid_ratio: fraction of the voxels used for validation
    :param callback_period: accepted for interface compatibility; not used
        by this implementation
    :param jinv: an explicit False switches masking off on this translator;
        None leaves ``self.masking`` as configured
    :param tile_size: optional upper bound for the training tile size;
        None (or 0) keeps the automatic choice. Accepting this keyword
        keeps the method callable from ``train``, which forwards it.
    """
    self._stop_training_flag = False

    # Only an explicit `jinv=False` disables masking; `None` means
    # "keep whatever was configured".
    if jinv is not None and not jinv:
        self.masking = False

    shape = input_image.shape
    num_input_channels = input_image.shape[1]
    num_output_channels = target_image.shape[1]
    num_spatiotemp_dim = input_image.ndim - 2

    # tile size: never larger than the configured maximum nor than the
    # smallest spatio-temporal axis; a caller-provided value can only
    # shrink it further.
    largest_usable_tile_size = min(self.max_tile_size, min(shape[2:]))
    if tile_size:
        tile_size = min(tile_size, largest_usable_tile_size)
    else:
        tile_size = largest_usable_tile_size

    # Decide on how many voxels to be used for validation:
    num_val_voxels = int(train_valid_ratio * input_image.size)
    lprint(
        f"Number of voxels used for validation: {num_val_voxels} (train_valid_ratio={train_valid_ratio})"
    )

    # Generate random coordinates for these voxels (one index array per axis):
    val_voxels = tuple(
        numpy.random.randint(d, size=num_val_voxels) for d in shape
    )
    lprint(f"Validation voxel coordinates: {val_voxels}")

    # Training Tile size:
    lprint(f"Train Tile dimensions: {tile_size}")

    # Prepare Training Dataset:
    dataset = self._get_dataset(
        input_image,
        target_image,
        self.self_supervised,
        tilesize=tile_size,
        mode='grid',
        validation_voxels=val_voxels,
        batch_size=self.batch_size,
    )
    lprint(f"Number tiles for training: {len(dataset)}")

    # Training Data Loader:
    num_workers = 0  # faster if data is already in memory...
    lprint(
        f"Number of workers for loading training/validation data: {num_workers}"
    )
    # The dataset receives `batch_size` itself, hence batch_size=1 here.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    # Model
    self.model = self.model_class(
        num_input_channels, num_output_channels, ndim=num_spatiotemp_dim
    ).to(self.device)

    number_of_parameters = sum(
        p.numel() for p in self.model.parameters() if p.requires_grad
    )
    lprint(
        f"Number of trainable parameters in {self.model_class} model: {number_of_parameters}"
    )

    if self.masking:
        # NOTE(review): the density is fixed at 0.5 here and
        # `self.masking_density` is not consulted -- confirm this is intended.
        self.masked_model = Masking(self.model, density=0.5).to(self.device)

    lprint(f"Optimiser class: {self.optimiser_class}")
    lprint(f"Learning rate : {self.learning_rate}")

    # Optimiser:
    optimizer = self.optimiser_class(
        self.model.parameters(),
        lr=self.learning_rate,
        start_noise_level=self.training_noise,
        weight_decay=self.l2_weight_regularisation,
    )
    lprint(f"Optimiser: {optimizer}")

    # Per-voxel error, selected by name. Anything that is not 'l2' is
    # treated as L1: this keeps the previous default, but in a form that
    # also accepts the mask argument used when masking is active.
    loss_name = self.loss.lower()
    if loss_name == 'l2':
        lprint(f"Training/Validation loss: L2")

        def per_voxel_error(difference):
            return difference ** 2

    else:
        if loss_name != 'l1':
            lprint(f"Unknown loss '{self.loss}', falling back to L1")
        lprint(f"Training/Validation loss: L1")
        per_voxel_error = torch.abs

    if self.masking:
        # With masking, the loss takes a third argument: the mask, or None
        # when no mask should be applied.
        def loss_function(u, v, m):
            difference = u - v
            if m is not None:
                difference = difference * m
            return per_voxel_error(difference)

    else:

        def loss_function(u, v):
            return per_voxel_error(u - v)

    # Start training:
    self._train_loop(data_loader, optimizer, loss_function)
def _get_tilling_strategy_and_margins(
    self,
    image,
    max_voxels_per_tile,
    min_margin,
    max_margin,
    suggested_tile_size=None,
):
    """
    Works out how many chunks each axis of `image` is split into, and the
    margin to add around the tiles along each axis.

    The first two axes of `image` are handled as batch and channel axes
    and the remaining ones as spatio-temporal axes.

    :param image: image to tile
    :param max_voxels_per_tile: limit on the number of voxels per tile
    :param min_margin: lower bound for the margins (None means 0)
    :param max_margin: upper bound for the margins (None means unbounded)
    :param suggested_tile_size: desired tile side length (None means no
        constraint from the tile size)
    :return: (tilling_strategy, margins), both tuples with one entry for
        the batch axis, one for the channel axis, and one per
        spatio-temporal axis
    """
    with lsection(f"Determine tilling strategy:"):

        # No suggestion means the tile size imposes no splitting at all:
        if suggested_tile_size is None:
            suggested_tile_size = math.inf

        shape = image.shape
        spatiotemp_shape = shape[2:]
        num_spatiotemp_dim = len(shape) - 2

        lprint(f"image shape = {shape}")
        lprint(f"max_voxels_per_tile = {max_voxels_per_tile}")

        # Memory needed versus memory available:
        memory_figures = self._estimate_memory_needed_and_available(image)
        estimated_memory_needed, total_memory_available = memory_figures
        lprint(f"Estimated amount of memory needed: {estimated_memory_needed}")

        # Only a fraction of the available memory is considered usable:
        total_memory_available *= self.max_memory_usage_ratio
        lprint(
            f"Available memory (we reserve 10% for 'comfort'): {total_memory_available}"
        )

        # Splitting imposed by memory:
        split_factor_mem = estimated_memory_needed / total_memory_available
        lprint(
            f"How much do we need to tile because of memory? : {split_factor_mem} times."
        )

        # Splitting imposed by the voxel limit per tile:
        split_factor_max_voxels = image.size / max_voxels_per_tile
        lprint(
            f"How much do we need to tile because of the limit on the number of voxels per tile? : {split_factor_max_voxels} times."
        )

        # Splitting imposed by the suggested tile size:
        suggested_tile_voxels = suggested_tile_size ** num_spatiotemp_dim
        split_factor_suggested_tile_size = image.size / suggested_tile_voxels
        lprint(
            f"How much do we need to tile because of the suggested tile size? : {split_factor_suggested_tile_size} times."
        )

        # The strongest constraint wins, rounded up, and never below 1:
        strongest_constraint = max(
            split_factor_mem,
            split_factor_max_voxels,
            split_factor_suggested_tile_size,
        )
        desired_split_factor = max(1, int(math.ceil(strongest_constraint)))
        lprint(f"Desired split factor: {desired_split_factor}")

        num_batches = shape[0]

        if num_batches >= desired_split_factor:
            # Splitting along the batch axis alone is sufficient:
            tilling_strategy = (desired_split_factor, 1) + (1,) * len(
                spatiotemp_shape
            )
        else:
            # The batch axis does not split enough; the remainder is
            # spread over the spatio-temporal axes.
            rest_split_factor = desired_split_factor / num_batches
            lprint(
                f"Not enough splitting happening along the batch dimension, we need to split spatio-temp dims by: {rest_split_factor}"
            )

            # Split each axis proportionally to its length:
            split_per_dim = (rest_split_factor / numpy.prod(shape[2:])) ** (
                1 / num_spatiotemp_dim
            )
            spatiotemp_tilling_strategy = tuple(
                max(1, int(math.ceil(split_per_dim * length)))
                for length in spatiotemp_shape
            )

            tilling_strategy = (num_batches, 1) + spatiotemp_tilling_strategy
            lprint(f"Preliminary tilling strategy is: {tilling_strategy}")

            # Rounding up per axis may oversplit: keep the splits of the
            # leading axes until the desired factor is reached, then stop
            # splitting the remaining axes.
            corrected_tilling_strategy = []
            accumulated_split = 1
            split_factor_reached = False
            for chunks in tilling_strategy:
                corrected_tilling_strategy.append(
                    1 if split_factor_reached else chunks
                )
                if not split_factor_reached:
                    accumulated_split *= chunks
                    split_factor_reached = (
                        accumulated_split >= desired_split_factor
                    )

            tilling_strategy = tuple(corrected_tilling_strategy)

        lprint(f"Tilling strategy is: {tilling_strategy}")

        # Margin bounds default to "unbounded above" and "zero below":
        if max_margin is None:
            max_margin = math.inf
        if min_margin is None:
            min_margin = 0

        # Approximate shape of one tile:
        estimated_tile_shape = tuple(
            int(round(length / chunks))
            for length, chunks in zip(spatiotemp_shape, tilling_strategy[2:])
        )
        lprint(f"The estimated tile shape is: {estimated_tile_shape}")

        # The margin is a fraction of the tile size, derived from
        # `self.max_tilling_overhead` spread over the spatio-temporal axes.
        overhead_per_dim = (1 + self.max_tilling_overhead) ** (
            1 / num_spatiotemp_dim
        )
        margin_factor = 0.5 * (overhead_per_dim - 1)

        # Clamp from above with max_margin, then from below with min_margin:
        margins = tuple(
            max(min_margin, min(max_margin, int(length * margin_factor)))
            for length in estimated_tile_shape
        )

        # Batch and channel axes never get a margin:
        margins = (0, 0) + margins

        # An axis that is not split needs no margin either:
        margins = tuple(
            0 if chunks == 1 else margin
            for margin, chunks in zip(margins, tilling_strategy)
        )

        return tilling_strategy, margins
def train(
    self,
    input_image,
    target_image=None,
    batch_dims=None,
    channel_dims=None,
    tile_size=None,
    train_valid_ratio=0.1,
    callback_period=3,
    jinv=None,
):
    """
    Trains the translator to map `input_image` onto `target_image`.

    Sets up and calibrates the normalisers, normalises and pads both
    images, then delegates the actual training to `_train`.

    :param input_image: image to translate from
    :param target_image: image to translate to; when omitted the input
        image is used and training is self-supervised
    :param batch_dims: per-axis flags marking batch axes (default: none)
    :param channel_dims: per-axis flags marking channel axes, passed to
        the normalisers as given
    :param tile_size: forwarded to `_train`
    :param train_valid_ratio: forwarded to `_train`
    :param callback_period: forwarded to `_train`
    :param jinv: forwarded to `_train`
    """
    if target_image is None:
        target_image = input_image

    with lsection(
        f"Learning to translate from image of dimensions {str(input_image.shape)} to {str(target_image.shape)} ."
    ):
        lprint("Running garbage collector...")
        gc.collect()

        # Same object for input and target means self-supervised training:
        self.self_supervised = input_image is target_image
        if not self.self_supervised:
            lprint("Training is supervised.")
        else:
            lprint("Training is self-supervised.")

        # Default: no axis is a batch axis.
        if batch_dims is None:
            batch_dims = (False,) * len(input_image.shape)

        # One normaliser is shared when self-supervised; otherwise the
        # target gets its own instance.
        self.input_normaliser = self.normalizer_class(
            transform=self.normaliser_transform, clip=self.normaliser_clip
        )
        if self.self_supervised:
            self.target_normaliser = self.input_normaliser
        else:
            self.target_normaliser = self.normalizer_class(
                transform=self.normaliser_transform, clip=self.normaliser_clip
            )

        # Calibrate the normaliser(s):
        self.input_normaliser.calibrate(input_image)
        if not self.self_supervised:
            self.target_normaliser.calibrate(target_image)

        # Normalise intensity values:
        normalised_input_image = self.input_normaliser.normalise(
            input_image, batch_dims=batch_dims, channel_dims=channel_dims
        )
        if self.self_supervised:
            normalised_target_image = normalised_input_image
        else:
            normalised_target_image = self.target_normaliser.normalise(
                target_image, batch_dims=batch_dims, channel_dims=channel_dims
            )

        # Padding is applied at translation time to avoid border effects,
        # so it has to be applied at training time too (location-aware
        # features would otherwise differ between the two).
        normalised_input_image = self._pad_norm_image(normalised_input_image)
        normalised_target_image = self._pad_norm_image(normalised_target_image)

        self._train(
            normalised_input_image,
            normalised_target_image,
            tile_size=tile_size,
            train_valid_ratio=train_valid_ratio,
            callback_period=callback_period,
            jinv=jinv,
        )
def test_log():
    """Logs through nested sections and checks that Log.depth tracks the nesting."""

    def log_subsection_lines():
        # The two lines logged inside every section.
        lprint("another line")
        lprint("we are done")

    # This is required for this test to pass!
    Log.override_test_exclusion = True

    lprint("Test")

    with lsection("a section"):
        lprint("a line")
        log_subsection_lines()

        with lsection("a subsection"):
            log_subsection_lines()

            with lsection("a subsection"):
                log_subsection_lines()

                assert Log.depth == 3

                with lsection("a subsection"):
                    log_subsection_lines()

                    with lsection("a subsection"):
                        log_subsection_lines()

                        assert Log.depth == 5

                        with lsection("a subsection"):
                            log_subsection_lines()

                            with lsection("a subsection"):
                                log_subsection_lines()

                                assert Log.depth == 7

                    # NOTE(review): the nesting level of the three sections
                    # below is not pinned down by any assert -- confirm the
                    # intended layout.
                    with lsection("a subsection"):
                        log_subsection_lines()

                with lsection("a subsection"):
                    log_subsection_lines()

        with lsection("a subsection"):
            log_subsection_lines()

    lprint("test is finished...")

    # Every section has been closed again:
    assert Log.depth == 0