Example #1
    def visualise_weights(self):
        try:
            self.model.visualise_weights()
        except AttributeError:
            lprint(
                "Method 'visualise_weights()' unavailable, cannot visualise weights."
            )
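
The same fail-soft pattern can be exercised standalone; a minimal sketch, assuming only that `lprint` behaves like `print` (a stub stands in for it here):

def lprint(*args):
    # stub standing in for the project's logging helper
    print(*args)

class ModelWithoutVisualisation:
    pass

class Translator:
    def __init__(self, model):
        self.model = model

    def visualise_weights(self):
        try:
            self.model.visualise_weights()
        except AttributeError:
            lprint("Method 'visualise_weights()' unavailable, cannot visualise weights.")

# falls back to the log message because the model lacks the method:
Translator(ModelWithoutVisualisation()).visualise_weights()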
Example #2
    def _train_loop(self, data_loader, optimizer, loss_function):
        try:
            self.model.kernel_continuity_regularisation = False
        except AttributeError:
            lprint("Cannot deactivate kernel continuity regularisation")

        super()._train_loop(data_loader, optimizer, loss_function)
Example #3
    def _debug_allocation(self, info):
        if self.__debug_allocation:
            if self.backend == "cupy":
                import cupy

                lprint(
                    f"CUDA memory usage {info}: {cupy.get_default_memory_pool().used_bytes() / 1e6} MB"
                )
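
For reference, the memory-pool query used above also works standalone; this sketch assumes CuPy is installed and a CUDA device is present:

import cupy

# allocate something so the default pool is non-empty:
x = cupy.ones((1024, 1024), dtype=cupy.float32)

pool = cupy.get_default_memory_pool()
print(f"CUDA memory usage: {pool.used_bytes() / 1e6} MB")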
Example #4
    def __init__(
        self,
        max_epochs=2048,
        patience=None,
        patience_epsilon=0.0,
        learning_rate=0.01,
        batch_size=8,
        model_class=UNet,
        masking=True,
        masking_density=0.01,
        loss='l1',
        normaliser_type='percentile',
        balance_training_data=None,
        keep_ratio=1,
        max_voxels_for_training=4e6,
        monitor=None,
        use_cuda=True,
        device_index=0,
    ):
        """
        Constructs an image translator using the pytorch deep learning library.

        :param normaliser_type: normaliser type
        :param balance_training_data: balance data ? (limits number training entries per target value histogram bin)
        :param monitor: monitor to track progress of training externally (used by UI)
        """
        super().__init__(normaliser_type, monitor=monitor)

        use_cuda = use_cuda and (torch.cuda.device_count() > 0)
        self.device = torch.device(
            f"cuda:{device_index}" if use_cuda else "cpu")
        lprint(f"Using device: {self.device}")

        self.max_epochs = max_epochs
        self.patience = max_epochs if patience is None else patience
        self.patience_epsilon = patience_epsilon
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.loss = loss
        self.max_voxels_for_training = max_voxels_for_training
        self.keep_ratio = keep_ratio
        self.balance_training_data = balance_training_data

        self.model_class = model_class

        self.l1_weight_regularisation = 1e-6
        self.l2_weight_regularisation = 1e-6
        self.training_noise = 0.1
        self.reload_best_model_period = max_epochs  # //2
        self.reduce_lr_patience = self.patience // 2
        self.reduce_lr_factor = 0.9
        self.masking = masking
        self.masking_density = masking_density
        self.optimiser_class = ESAdam
        self.max_tile_size = 1024  # TODO: adjust based on available memory

        self._stop_training_flag = False
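
The device-selection idiom at the top of this constructor is self-contained PyTorch; a minimal sketch that falls back to the CPU when no CUDA device is visible:

import torch

use_cuda = True
device_index = 0

# fall back to CPU when no CUDA device is visible:
use_cuda = use_cuda and (torch.cuda.device_count() > 0)
device = torch.device(f"cuda:{device_index}" if use_cuda else "cpu")
print(f"Using device: {device}")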
Example #5
    def _additional_losses(self, translated_image, forward_model_image):

        loss = 0

        # Bounds loss:
        if self.bounds_loss and self.bounds_loss != 0:
            epsilon = 0.0  # tolerance term currently disabled (0 * 1e-8 == 0)
            bounds_loss = F.relu(-translated_image - epsilon)
            bounds_loss += F.relu(translated_image - 1 - epsilon)
            bounds_loss_value = bounds_loss.mean()
            lprint(f"bounds_loss_value = {bounds_loss_value}")
            loss += self.bounds_loss * bounds_loss_value**2

        # Sharpen loss_deconvolution:
        if self.sharpening and self.sharpening != 0:
            image_for_loss = translated_image
            num_elements = image_for_loss[0, 0].nelement()
            sharpening_loss = -torch.norm(
                image_for_loss, dim=(2, 3), keepdim=True, p=2) / (
                    num_elements**2
                )  # /torch.norm(image_for_loss, dim=(2, 3), keepdim=True, p=1)
            lprint(f"sharpening loss = {sharpening_loss}")
            loss += self.sharpening * sharpening_loss.mean()

        # Max entropy loss:
        if self.entropy and self.entropy != 0:
            entropy_value = entropy(translated_image)
            lprint(f"entropy_value = {entropy_value}")
            loss += -self.entropy * entropy_value

        return loss
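
To see what the bounds loss penalises, here is a small self-contained sketch applying the same two ReLU terms to a toy tensor with values outside [0, 1]:

import torch
import torch.nn.functional as F

translated_image = torch.tensor([[-0.2, 0.5], [0.9, 1.3]])
epsilon = 0.0

# penalise values below 0 and above 1:
bounds_loss = F.relu(-translated_image - epsilon)
bounds_loss += F.relu(translated_image - 1 - epsilon)
print(bounds_loss.mean())  # non-zero, driven by -0.2 and 1.3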
Example #6
def download_from_gdrive(id,
                         name,
                         dest_folder=datasets_folder,
                         overwrite=False,
                         unzip=False):
    os.makedirs(dest_folder, exist_ok=True)

    url = f'https://drive.google.com/uc?id={id}'
    output_path = join(dest_folder, name)
    if overwrite or not exists(output_path):
        lprint(f"Downloading file {output_path} as it does not exist yet.")
        gdown.download(url, output_path, quiet=False)

        if unzip:
            lprint(f"Unzipping file {output_path}...")
            with zipfile.ZipFile(output_path, 'r') as zip_ref:
                zip_ref.extractall(dest_folder)
            # os.remove(output_path)

        return output_path
    else:
        lprint(f"Not downloading file {output_path} as it already exists.")
        return None
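
A hypothetical call, assuming `datasets_folder` points at a writable directory; the file id below is a placeholder, not a real one:

# hypothetical usage; the id below is a placeholder, not a real file id:
path = download_from_gdrive(
    id="FILE_ID_PLACEHOLDER",
    name="example_dataset.zip",
    unzip=True,
)
if path is not None:
    print(f"Downloaded and unzipped to {path}")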
Example #7
    def __init__(
        self,
        psf_kernel: numpy.ndarray = None,
        broaden_psf: int = 1,
        sharpening: float = 0.0,
        bounds_loss: float = 0.1,
        entropy: float = 0.0,
        clip_before_psf: bool = True,
        fft_psf: Union[str, bool] = "auto",
        **kwargs,
    ):
        """
        Constructs a CNN image translator using the pytorch deep learning library.

        :param normaliser_type: normaliser type
        :param balance_training_data: balance data ? (limits number training entries per target value histogram bin)
        :param monitor: monitor to track progress of training externally (used by UI)

        :param clip_before_psf: torch.clamp(x, 0, 1) before PSF convolution
        :param fft_psf: "auto" or True or False
        """
        super().__init__(**kwargs)

        if self.standardize_image and clip_before_psf:
            lprint(
                "Clipping before PSF convolution is not supported when standardizing image"
            )
            clip_before_psf = False

        self.provided_psf_kernel = psf_kernel
        self.broaden_psf = broaden_psf
        self.sharpening = sharpening
        self.bounds_loss = bounds_loss
        self.entropy = entropy
        self.clip_before_psf = clip_before_psf
        self.fft_psf = fft_psf
Example #8
    def __init__(
        self,
        kernel_psf: ndarray,
        in_channels: int = 1,
        pad_mode: str = "reflect",
        trainable: bool = False,
        fft: Union[str, bool] = "auto",
        auto_padding: bool = False,
    ):
        """
        Parametrized trainable version of PSF
        :param kernel_psf: ndarray, PSF kernel (should be of required layer dimensionality)
        :param in_channels: number of input channels
        :param pad_mode: "reflect" for 2D, "replicate" for 3D
        :param trainable: if True, the kernel is trainable
        :param fft: ["auto", True, False] - if "auto" - use FFT if PSF has > 100 elements
        :param auto_padding: (bool) If True, automatically computes padding based on the
                             signal size, kernel size and stride.
        """
        super().__init__()
        self.kernel_size = kernel_psf.shape
        self.n_dim = len(kernel_psf.shape)
        self.in_channels = in_channels
        assert self.n_dim in (2, 3)

        if pad_mode is None:
            lprint("Padding mode is None, using default pad_mode='reflect'")
            pad_mode = "reflect"

        if self.n_dim == 3 and pad_mode == "reflect":
            # Not supported yet
            lprint(
                "Padding mode 'reflect' is not supported for 3D convolution, use 'replicate' instead"
            )
            pad_mode = "replicate"

        self.pad_mode = pad_mode
        self.pad = [k // 2 for k in self.kernel_size]

        self.fft = fft
        if self.fft == "auto":
            # Use FFT Conv if kernel has > 100 elements
            self.fft = np.prod(self.kernel_size) > 100
        if isinstance(self.fft, str):
            raise ValueError(f"Invalid fft value {self.fft}")

        if not self.fft:
            auto_padding = False
        self.auto_padding = auto_padding
        lprint(f"Use FFT for PSF: {self.fft}, auto padding: {auto_padding}")

        self.psf = torch.from_numpy(kernel_psf.squeeze()[(None, ) * 2]).float()
        self.psf = nn.Parameter(self.psf, requires_grad=trainable)
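
The `(None, ) * 2` indexing used for the PSF tensor simply prepends two singleton axes (batch and channel); a small NumPy sketch, independent of the class above:

import numpy as np

kernel_psf = np.ones((5, 5), dtype=np.float32)

# (None,) * 2 == (None, None): prepends two singleton axes
expanded = kernel_psf.squeeze()[(None, ) * 2]
print(expanded.shape)  # (1, 1, 5, 5)

# the same element-count threshold used above to decide on FFT convolution:
print(np.prod(kernel_psf.shape) > 100)  # False for a 5x5 kernel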
Example #9
    def _translate(self, input_image, image_slice=None, whole_image_shape=None):
        """Internal method that translates an input image on the basis of the trained model.

        :param input_image: input image
        :param batch_dims: batch dimensions
        :return:
        """
        import numpy

        convolve_method = self._get_convolution_method(
            input_image, self.psf_kernel_numpy
        )
        pad_method = self._get_pad_method(input_image)

        self.psf_kernel = self._convert_array_format_in(
            self.psf_kernel_numpy.astype(numpy.float32)
        )
        self.psf_kernel_mirror = self._convert_array_format_in(
            self.psf_kernel[::-1, ::-1]
        )

        input_image = input_image.astype(numpy.float32, copy=False)

        deconvolved_image = offcore_array(
            shape=input_image.shape, dtype=input_image.dtype
        )

        lprint(f"Number of Lucy-Richardson iterations: {self.max_num_iterations}")

        for batch_index, batch_image in enumerate(input_image):

            for channel_index, channel_image in enumerate(batch_image):

                channel_image = channel_image.clip(0, math.inf)
                channel_image = self._convert_array_format_in(channel_image)

                candidate_deconvolved_image = numpy.full(
                    channel_image.shape, float(numpy.mean(channel_image))
                )

                candidate_deconvolved_image = self._convert_array_format_in(
                    candidate_deconvolved_image
                )

                kernel_shape = self.psf_kernel.shape
                pad_width = tuple(
                    (max(self.padding, (s - 1) // 2), max(self.padding, (s - 1) // 2))
                    for s in kernel_shape
                )

                for i in range(self.max_num_iterations):

                    if self.padding > 0:
                        padded_candidate_deconvolved_image = pad_method(
                            candidate_deconvolved_image,
                            pad_width=pad_width,
                            mode=self.padding_mode,
                        )
                    else:
                        padded_candidate_deconvolved_image = candidate_deconvolved_image

                    convolved = convolve_method(
                        padded_candidate_deconvolved_image,
                        self.psf_kernel,
                        mode="valid" if self.padding else "same",
                    )

                    convolved[convolved == 0] = 1

                    relative_blur = channel_image / convolved

                    self._debug_allocation(f"after division")

                    if self.padding:
                        relative_blur = numpy.pad(
                            relative_blur, pad_width=pad_width, mode=self.padding_mode
                        )

                    multiplicative_correction = convolve_method(
                        relative_blur,
                        self.psf_kernel_mirror,
                        mode="valid" if self.padding else "same",
                    )

                    self._debug_allocation(f"after second convolution")

                    candidate_deconvolved_image *= multiplicative_correction

                if self.clip:
                    candidate_deconvolved_image[candidate_deconvolved_image > 1] = 1
                    candidate_deconvolved_image[candidate_deconvolved_image < -1] = -1

                candidate_deconvolved_image = self._convert_array_format_out(
                    candidate_deconvolved_image
                )

                deconvolved_image[
                    batch_index, channel_index
                ] = candidate_deconvolved_image

        return deconvolved_image
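
Stripped of padding, backend selection and array-format conversion, the inner loop above is textbook Lucy-Richardson; a minimal NumPy/SciPy sketch for a single 2D channel, assuming a normalised PSF:

import numpy as np
from scipy.signal import convolve

def lucy_richardson(channel_image, psf, iterations=10):
    psf_mirror = psf[::-1, ::-1]
    # start from a flat image at the mean intensity, as above:
    estimate = np.full(channel_image.shape, channel_image.mean())
    for _ in range(iterations):
        convolved = convolve(estimate, psf, mode="same")
        convolved[convolved == 0] = 1  # avoid division by zero
        relative_blur = channel_image / convolved
        estimate *= convolve(relative_blur, psf_mirror, mode="same")
    return estimate

image = np.random.rand(32, 32)
psf = np.ones((3, 3)) / 9
print(lucy_richardson(image, psf).shape)  # (32, 32)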
Example #10
    def _train_step(
        self,
        input_images: T,
        target_images: T,
        valid_mask_images: T,
        epoch: int = 0,
    ) -> Tuple[T, Dict[str, float]]:
        self.model.train()
        loss_log = {}

        # Adding training noise to input:
        if self.training_noise > 0:
            with torch.no_grad():
                alpha = self.training_noise / (1 + (10000 * epoch / self.max_epochs))
                lprint(f"Training noise level: {alpha}")
                loss_log["training_noise"] = alpha
                training_noise = alpha * torch.randn_like(
                    input_images, device=input_images.device
                )
                input_images += training_noise

        # Forward pass:
        if self.masking:
            translated_images = self.masked_model(input_images)  # pass with masking
        else:
            translated_images = self.model(input_images)

        if self.two_pass:
            # pass without masking
            translated_images_full = self.model(input_images)
            forward_model_images_full = self._forward_model(translated_images_full)

            # no masking for reconstruction
            reconstruction_loss = self.loss_function(
                forward_model_images_full, target_images, None
            ).mean()
            loss_log["reconstruction_loss"] = reconstruction_loss.item()

            if self.inv_mse_before_forward_model:
                u = translated_images_full * (1 - valid_mask_images)
                v = translated_images * (1 - valid_mask_images)
            else:
                forward_model_images = self._forward_model(translated_images)
                u = forward_model_images_full * (1 - valid_mask_images)
                v = forward_model_images * (1 - valid_mask_images)

            mask = self.masked_model.get_mask()
            invariance_loss = self.loss_function(u, v, mask).mean()
            loss_log["invariance_loss"] = invariance_loss.item()

            translation_loss_value = (
                reconstruction_loss + self.inv_mse_lambda * torch.sqrt(invariance_loss)
            )
        else:
            # validation masking:
            forward_model_images = self._forward_model(translated_images)
            u = forward_model_images * (1 - valid_mask_images)
            v = target_images * (1 - valid_mask_images)

            # translation loss (per voxel):
            if self.masking:
                mask = self.masked_model.get_mask()
                translation_loss = self.loss_function(u, v, mask)
            else:
                translation_loss = self.loss_function(u, v)

            # translation loss all voxels
            translation_loss_value = translation_loss.mean()

        loss_log["translation_loss"] = translation_loss_value.item()

        # Additional losses:
        (additional_loss_value, additional_loss_log,) = self._additional_losses(
            translated_images,
            forward_model_images_full if self.two_pass else forward_model_images,
        )
        if additional_loss_value is not None:
            translation_loss_value += additional_loss_value
            loss_log.update(additional_loss_log)

        return translation_loss_value, loss_log
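
The two-pass branch combines its losses as reconstruction + lambda * sqrt(invariance); a toy numeric sketch with scalar stand-ins for the two loss terms:

import torch

reconstruction_loss = torch.tensor(0.04)
invariance_loss = torch.tensor(0.01)
inv_mse_lambda = 2.0

translation_loss_value = (
    reconstruction_loss + inv_mse_lambda * torch.sqrt(invariance_loss)
)
print(translation_loss_value)  # tensor(0.2400): 0.04 + 2.0 * 0.1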
Example #11
    def _get_convolution_method(self, input_image, psf_kernel):

        if self.backend == "scipy":
            lprint("Using scipy backend.")
            from scipy.signal import convolve

            return convolve

        elif self.backend == "scipy-cupy":
            try:
                lprint("Attempting to use scipy-cupy backend.")
                import cupy  # noqa: F401
                import cupyx.scipy.fft
                import scipy.fft

                scipy.fft.set_backend(cupyx.scipy.fft)
                self.backend = "scipy"
                lprint("Succeeded to use scipy-cupy backend.")
                return self._get_convolution_method(input_image, psf_kernel)
            except Exception:
                track = traceback.format_exc()
                lprint(track)
                lprint("Failed to use scipy-cupy backend.")
                self.backend = "cupy"
                return self._get_convolution_method(input_image, psf_kernel)

        elif self.backend == "gputools":
            try:
                lprint("Attempting to use gputools backend.")
                # testing if gputools works:
                import gputools
                import numpy

                # try something simple and see if it crashes...
                data = numpy.ones((30, 40, 50))
                h = numpy.ones((10, 11, 12))
                out = gputools.convolve(data, h)  # noqa: F841

                def gputools_convolve(in1, in2, mode=None, method=None):
                    return gputools.convolve(in1, in2)

                # gputools backend does not need extra padding:
                self.padding = False

                lprint("Succeeded to use cupy backend.")
                return gputools_convolve

            except Exception:
                track = traceback.format_exc()
                lprint(track)
                lprint("Failed to use gputools backend.")
                pass

        elif self.backend == "cupy":
            try:
                lprint("Attempting to use cupy backend.")
                # try:
                # testing if gputools works:
                # try something simple and see if it crashes...
                import cupy
                import cupyx.scipy.ndimage

                data = cupy.ones((30, 40, 50))
                h = cupy.ones((10, 11, 12))
                cupyx.scipy.ndimage.convolve(data, h)

                # gputools backend does not need extra padding:
                self.padding = False

                def cupy_convolve(in1, in2, mode=None, method=None):
                    return cupyx.scipy.ndimage.convolve(in1, in2, mode="reflect")

                lprint("Succeeded to use cupy backend.")
                if psf_kernel.size > 500:
                    return self._cupy_convolve_fft
                else:
                    return cupy_convolve

            except Exception:
                track = traceback.format_exc()
                lprint(track)
                lprint("Failed to use cupy backend, trying gputools")
                self.backend = "gputools"
                return self._get_convolution_method(input_image, psf_kernel)

        lprint("Faling back to scipy backend.")

        # this is scipy's convolve:
        from scipy.signal import convolve

        return convolve
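
The final fallback is plain `scipy.signal.convolve`; a short sketch of the call shape the other backends emulate:

import numpy
from scipy.signal import convolve

out = convolve(numpy.ones((8, 8)), numpy.ones((3, 3)), mode="same")
print(out.shape)  # (8, 8)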
Example #12
    def _epoch(
        self, optimizer, data_loader, epoch: int = 0
    ) -> Tuple[float, float, Dict[str, float]]:
        train_loss_value = 0
        val_loss_value = 0
        loss_log_epoch = {}

        if hasattr(self, "masked_model"):
            self.masked_model.density = (
                0.005 * self.masking_density + 0.995 * self.masked_model.density
            )
            lprint(f"masking density: {self.masked_model.density}")

        for i, (input_images, target_images, val_mask_images) in enumerate(data_loader):
            # Clear gradients w.r.t. parameters
            optimizer.zero_grad()

            input_images_gpu = input_images.to(self.device, non_blocking=True)
            target_images_gpu = target_images.to(self.device, non_blocking=True)
            validation_mask_images_gpu = val_mask_images.to(
                self.device, non_blocking=True
            )

            if self.standardize_image:
                # Standardize input and target images with the same statistics
                input_images_gpu, mean, std = standardize(input_images_gpu)
                target_images_gpu, _, _ = standardize(target_images_gpu, mean, std)

            # Training step
            with autocast(enabled=self.amp):
                translation_loss_value, loss_log = self._train_step(
                    input_images_gpu,
                    target_images_gpu,
                    validation_mask_images_gpu,
                    epoch,
                )

            # Backpropagation
            translation_loss_value.backward()

            # Updating parameters
            optimizer.step()

            # post optimisation -- if needed:
            self.model.post_optimisation()

            # update training loss_deconvolution for whole image:
            train_loss_value += translation_loss_value.item()

            # Validation:
            with autocast(enabled=self.amp):
                translation_loss_value = self._valid_step(
                    input_images_gpu,
                    target_images_gpu,
                    validation_mask_images_gpu,
                )
            # update validation loss_deconvolution for whole image:
            loss_log["val_translation_loss"] = translation_loss_value
            val_loss_value += translation_loss_value

            if not loss_log_epoch:
                loss_log_epoch = deepcopy(loss_log)
            else:
                loss_log_epoch = {k: v + loss_log[k] for k, v in loss_log_epoch.items()}

        # Aggregate losses:
        iteration = len(data_loader)
        train_loss_value /= iteration
        lprint(f"Training loss value: {train_loss_value}")

        val_loss_value /= iteration
        lprint(f"Validation loss value: {val_loss_value}")

        loss_log_epoch = {k: v / iteration for k, v in loss_log_epoch.items()}

        return train_loss_value, val_loss_value, loss_log_epoch
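
`standardize` is project-internal; the stand-in below is an assumption (not the project's actual implementation) that merely matches how it is called above, returning the tensor together with the statistics so the target can reuse them:

import torch

def standardize(t, mean=None, std=None):
    # hypothetical stand-in: zero-mean/unit-variance with given or computed stats
    mean = t.mean() if mean is None else mean
    std = t.std() if std is None else std
    return (t - mean) / (std + 1e-8), mean, std

input_images = torch.randn(2, 1, 8, 8) * 5 + 3
target_images = torch.randn(2, 1, 8, 8) * 5 + 3

input_images, mean, std = standardize(input_images)
target_images, _, _ = standardize(target_images, mean, std)  # same statistics
print(float(input_images.mean()), float(input_images.std()))  # ~0, ~1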
Example #13
    def _train_loop(self, data_loader, optimizer):

        # Scheduler:
        if self.scheduler == "plateau":
            scheduler = ReduceLROnPlateau(
                optimizer,
                "min",
                factor=self.reduce_lr_factor,
                verbose=True,
                patience=self.reduce_lr_patience,
            )
        elif self.scheduler == "cosine":
            scheduler = CosineAnnealingLR(
                optimizer, T_max=self.max_epochs, eta_min=1e-6
            )
        else:
            raise ValueError(
                f"Unknown scheduler: {self.scheduler}, supported: plateau, cosine"
            )

        self.best_val_loss_value = math.inf
        self.best_model_state_dict = None
        self.patience_counter = 0

        with lsection(f"Training loop:"):
            lprint(f"Maximum number of epochs: {self.max_epochs}")
            lprint(
                f"Training type: {'self-supervised' if self.self_supervised else 'supervised'}"
            )

            for epoch in range(self.max_epochs):
                with lsection(f"Epoch {epoch}:"):
                    # One epoch of training
                    train_loss_value, val_loss_value, loss_log_epoch = self._epoch(
                        optimizer, data_loader, epoch
                    )

                    # Learning rate schedule:
                    if self.scheduler == "plateau":
                        scheduler.step(val_loss_value)
                    else:
                        scheduler.step()

                    # Logging:
                    if hasattr(self, "masked_model"):
                        loss_log_epoch["masking_density"] = self.masked_model.density
                    loss_log_epoch["lr"] = scheduler._last_lr[0]

                    if not self.check:
                        wandb.log(loss_log_epoch)

                    # Monitoring and saving:
                    if val_loss_value < self.best_val_loss_value:
                        lprint(f"## New best val loss!")
                        if (
                            val_loss_value
                            < self.best_val_loss_value - self.patience_epsilon
                        ):
                            lprint(f"## Good enough to reset patience!")
                            self.patience_counter = 0

                        # Update best val loss value:
                        self.best_val_loss_value = val_loss_value

                        # Save model:
                        self.best_model_state_dict = OrderedDict(
                            {k: v.to("cpu") for k, v in self.model.state_dict().items()}
                        )

                    else:
                        if (
                            epoch % max(1, self.reload_best_model_period) == 0
                            and self.best_model_state_dict
                        ):
                            lprint(f"Reloading best models to date!")
                            self.model.load_state_dict(self.best_model_state_dict)

                        if self.patience_counter > self.patience:
                            lprint(f"Early stopping!")
                            break

                        # No improvement:
                        lprint(
                            f"No improvement of validation losses, patience = {self.patience_counter}/{self.patience} "
                        )
                        self.patience_counter += 1

                    lprint(f"## Best val loss: {self.best_val_loss_value}")

                    if self._stop_training_flag:
                        lprint(f"Training interupted!")
                        break

        lprint(f"Reloading best models to date!")
        self.model.load_state_dict(self.best_model_state_dict)
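
Both scheduler branches use standard PyTorch classes; a self-contained sketch of the same construction and stepping calls on a throwaway optimiser (values are illustrative):

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=0.01)

# "plateau": stepped with the current validation loss
plateau = ReduceLROnPlateau(optimizer, "min", factor=0.9, patience=8)
plateau.step(0.5)

# "cosine": stepped once per epoch
cosine = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
cosine.step()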
Example #14
    def _train(
        self,
        input_image,
        target_image,
        tile_size=None,
        train_valid_ratio=0.1,
        callback_period=3,
        j_inv=False,
    ):
        self._stop_training_flag = False

        if j_inv is not None and not j_inv:
            self.masking = False

        shape = input_image.shape
        num_input_channels = input_image.shape[1]
        num_output_channels = target_image.shape[1]
        num_spatiotemp_dim = input_image.ndim - 2

        # tile size:
        if tile_size is None:
            # tile_size = min(self.max_tile_size, min(shape[2:]))
            tile_size = tuple(min(self.max_tile_size, s) for s in shape[2:])
            lprint(f"Estimated max tile size {tile_size}")

        # Decide on how many voxels to be used for validation:
        num_val_voxels = int(train_valid_ratio * input_image.size)
        lprint(
            f"Number of voxels used for validation: {num_val_voxels} (train_valid_ratio={train_valid_ratio})"
        )

        # Generate random coordinates for these voxels:
        val_voxels = tuple(numpy.random.randint(d, size=num_val_voxels) for d in shape)
        lprint(f"Validation voxel coordinates: {val_voxels}")

        # Training Tile size:
        lprint(f"Train Tile dimensions: {tile_size}")

        # Prepare Training Dataset:
        dataset = self._get_dataset(
            input_image,
            target_image,
            self.self_supervised,
            tile_size=tile_size,
            mode="grid",
            validation_voxels=val_voxels,
            batch_size=self.batch_size,
        )
        lprint(f"Number tiles for training: {len(dataset)}")

        # Training Data Loader:
        # num_workers = max(3, os.cpu_count() // 2)
        num_workers = 0  # faster if data is already in memory...
        lprint(f"Number of workers for loading training/validation data: {num_workers}")
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=1,  # self.batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
        )

        # Model
        self.model = self.model_class(
            num_input_channels, num_output_channels, ndim=num_spatiotemp_dim
        ).to(self.device)

        number_of_parameters = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        lprint(
            f"Number of trainable parameters in {self.model_class} model: {number_of_parameters}"
        )

        if self.masking:
            self.masked_model = Masking(self.model, density=0.5).to(self.device)

        lprint(f"Optimiser class: {self.optimizer_class}")
        lprint(f"Learning rate : {self.learning_rate}")

        # Optimiser:
        if self.optimizer_class is ESAdam:
            optimizer = partial(
                self.optimizer_class, start_noise_level=self.training_noise
            )
        else:
            optimizer = self.optimizer_class

        optimizer = optimizer(
            chain(self.model.parameters()),
            lr=self.learning_rate,
            weight_decay=self.l2_weight_regularisation,
        )

        lprint(f"Optimiser: {optimizer}")

        # Start training:
        try:
            self._train_loop(data_loader, optimizer)
        except KeyboardInterrupt:
            lprint("Training interrupted by user.")
            self._stop_training_flag = True
Example #15
    def __init__(
        self,
        max_epochs=2048,
        patience=None,
        patience_epsilon=0.0,
        learning_rate=0.01,
        batch_size=8,
        model_class=UNet,
        masking=True,
        two_pass=False,  # two-pass Noise2Same loss
        inv_mse_lambda: float = 2.0,
        inv_mse_before_forward_model=False,
        masking_density=0.01,
        training_noise=0.1,
        loss="l1",
        normaliser_type="percentile",
        balance_training_data=None,
        keep_ratio=1,
        max_voxels_for_training=4e6,
        monitor=None,
        use_cuda=True,
        device_index=0,
        max_tile_size: int = 1024,  # TODO: adjust based on available memory
        check: bool = True,
        optimizer: str = "esadam",
        scheduler: str = "step",
        standardize_image: bool = False,
        amp: bool = False,
    ):
        """
        Constructs an image translator using the pytorch deep learning library.

        :param normaliser_type: normaliser type
        :param balance_training_data: balance data ? (limits number training entries per target value histogram bin)
        :param monitor: monitor to track progress of training externally (used by UI)
        :param two_pass: bool, adopt Noise2Same two forward pass strategy (one masked, one unmasked)
        :param inv_mse_before_forward_model: bool, use invariance MSE before forward (PSF) model for Noise2Same
        :param check: bool, run smoke test
        :param optimizer: str, optimiser to use ["adam", "esadam"]
        :param standardize_image: bool, standardize input images to zero mean and unit variance
        """
        super().__init__(normaliser_type, monitor=monitor)
        if two_pass and not masking:
            lprint("Force masking=True, it is needed in two-pass")
            masking = True

        use_cuda = use_cuda and (torch.cuda.device_count() > 0)
        self.device = torch.device(f"cuda:{device_index}" if use_cuda else "cpu")
        lprint(f"Using device: {self.device}")

        self.max_epochs = max_epochs
        self.patience = max_epochs if patience is None else patience
        self.patience_epsilon = patience_epsilon
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.loss = loss
        self.max_voxels_for_training = max_voxels_for_training
        self.keep_ratio = keep_ratio
        self.balance_training_data = balance_training_data

        self.model_class = model_class

        self.l1_weight_regularisation = 1e-6
        self.l2_weight_regularisation = 1e-6
        self.training_noise = training_noise
        self.reload_best_model_period = max_epochs  # //2
        self.reduce_lr_patience = self.patience // 2
        self.reduce_lr_factor = 0.9
        self.masking = masking
        self.two_pass = two_pass
        self.inv_mse_before_forward_model = inv_mse_before_forward_model
        self.inv_mse_lambda = inv_mse_lambda
        self.masking_density = masking_density
        self.optimizer_class = ESAdam if optimizer == "esadam" else torch.optim.Adam
        self.scheduler = scheduler
        self.max_tile_size = max_tile_size

        self._stop_training_flag = False
        self.check = check
        self.standardize_image = standardize_image
        self.amp = amp

        # Denoise loss function:
        loss_function = nn.L1Loss()
        if self.loss.lower() == "l2":
            lprint(f"Training/Validation loss: L2")
            if self.masking:
                loss_function = (
                    lambda u, v, m: (u - v) ** 2 if m is None else ((u - v) * m) ** 2
                )
            else:
                loss_function = lambda u, v: (u - v) ** 2

        elif self.loss.lower() == "l1":
            lprint(f"Training/Validation loss: L1")
            if self.masking:
                loss_function = (
                    lambda u, v, m: torch.abs(u - v)
                    if m is None
                    else torch.abs((u - v) * m)
                )
            else:
                loss_function = lambda u, v: torch.abs(u - v)
            lprint(f"Training/Validation loss: L1")

        self.loss_function = loss_function

        # Monitor
        self.best_val_loss_value = math.inf
        self.best_model_state_dict = None
        self.patience_counter = 0
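
The masked loss closures can be checked in isolation; a small sketch evaluating the L1 variant with and without a mask on toy tensors:

import torch

loss_function = (
    lambda u, v, m: torch.abs(u - v) if m is None else torch.abs((u - v) * m)
)

u = torch.tensor([0.0, 1.0, 2.0])
v = torch.tensor([0.0, 0.5, 0.5])
mask = torch.tensor([1.0, 0.0, 1.0])

print(loss_function(u, v, None).mean())  # plain L1: tensor(0.6667)
print(loss_function(u, v, mask).mean())  # masked: middle element ignored, tensor(0.5000)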
Example #16
def offcore_array(
    shape: Union[Tuple[int, ...], Generator[int, None, None]],
    dtype: numpy.dtype,
    force_memmap: bool = False,
    zarr_allowed: bool = False,
    no_memmap_limit: bool = True,
    max_memory_usage_ratio: float = 0.9,
):
    """
    Instanciates an array of given shape and dtype in  'off-core' fashion i.e. not in main memory.
    Right now it simply uses memory mapping on temp file that is deleted after the file is closed

    Parameters
    ----------
    shape
    dtype
    force_memmap
    zarr_allowed
    no_memmap_limit
    max_memory_usage_ratio
    """

    with lsection(f"Array of shape: {shape} and dtype: {dtype} requested"):
        size_in_bytes = numpy.prod(shape) * numpy.dtype(dtype).itemsize
        lprint(f'Array requested will be {(size_in_bytes / 1E6)} MB.')

        total_physical_memory_in_bytes = psutil.virtual_memory().total
        total_swap_memory_in_bytes = psutil.swap_memory().total

        total_mem_in_bytes = total_physical_memory_in_bytes + total_swap_memory_in_bytes
        lprint(
            f'There is {int(psutil.virtual_memory().total / 1E6)} MB of physical memory'
        )
        lprint(
            f'There is {int(psutil.swap_memory().total / 1E6)} MB of swap memory'
        )
        lprint(f'There is {int(total_mem_in_bytes / 1E6)} MB of total memory')

        is_enough_physical_memory = (size_in_bytes < max_memory_usage_ratio *
                                     total_physical_memory_in_bytes)

        is_enough_total_memory = (size_in_bytes <
                                  max_memory_usage_ratio * total_mem_in_bytes)

        if not force_memmap and is_enough_total_memory:
            lprint(
                f'There is enough physical+swap memory -- we do not need to use a mem mapped array or zarr-backed array.'
            )
            array = numpy.zeros(shape, dtype=dtype)

        elif no_memmap_limit:
            lprint(
                f'There is not enough physical+swap memory -- we will use a mem mapped array.'
            )
            temp_file = tempfile.NamedTemporaryFile(
                dir=OffCore.memmap_directory)
            lprint(
                f'The temporary memory mapped file is at: {temp_file.name} (but you might not be able to see it!)'
            )
            array = numpy.memmap(temp_file,
                                 dtype=dtype,
                                 mode='w+',
                                 shape=shape)

        elif zarr_allowed:
            lprint(
                f'There is not enough physical+swap memory -- we will use a zarr-backed array.'
            )
            import zarr

            array = zarr.create(shape=shape,
                                dtype=dtype,
                                store=zarr.TempStore("output.zarr"))
            # from numcodecs import Blosc
            # compressor = Blosc(cname = 'zstd', clevel = 3, shuffle = Blosc.BITSHUFFLE)
            # array = zarr.zeros((102_0, 200, 210), chunks = (100, 200, 210), compressor = compressor

        return array
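
The memmap branch relies only on the standard library and NumPy; a minimal sketch of the same temp-file-backed allocation:

import tempfile
import numpy

shape, dtype = (1024, 1024), numpy.float32

# the temp file is deleted once closed, as in the memmap branch above:
temp_file = tempfile.NamedTemporaryFile()
array = numpy.memmap(temp_file, dtype=dtype, mode='w+', shape=shape)
array[0, 0] = 1.0
print(array.shape, array[0, 0])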
Example #17
def test_log():
    # This is required for this test to pass!
    Log.override_test_exclusion = True

    lprint('Test')

    with lsection('a section'):
        lprint('a line')
        lprint('another line')
        lprint('we are done')

        with lsection('a subsection'):
            lprint('another line')
            lprint('we are done')

            with lsection('a subsection'):
                lprint('another line')
                lprint('we are done')

                assert Log.depth == 3

                with lsection('a subsection'):
                    lprint('another line')
                    lprint('we are done')

                    with lsection('a subsection'):
                        lprint('another line')
                        lprint('we are done')

                        assert Log.depth == 5

                        with lsection('a subsection'):
                            lprint('another line')
                            lprint('we are done')

                            with lsection('a subsection'):
                                lprint('another line')
                                lprint('we are done')

                                assert Log.depth == 7

                        with lsection('a subsection'):
                            lprint('another line')
                            lprint('we are done')

                    with lsection('a subsection'):
                        lprint('another line')
                        lprint('we are done')

                with lsection('a subsection'):
                    lprint('another line')
                    lprint('we are done')

    lprint('test is finished...')

    assert Log.depth == 0
Example #18
    def _train_loop(self, data_loader, optimizer, loss_function):

        # Scheduler:
        scheduler = ReduceLROnPlateau(
            optimizer,
            'min',
            factor=self.reduce_lr_factor,
            verbose=True,
            patience=self.reduce_lr_patience,
        )

        best_val_loss_value = math.inf
        best_model_state_dict = None
        patience_counter = 0

        with lsection(f"Training loop:"):
            lprint(f"Maximum number of epochs: {self.max_epochs}")
            lprint(
                f"Training type: {'self-supervised' if self.self_supervised else 'supervised'}"
            )

            for epoch in range(self.max_epochs):
                with lsection(f"Epoch {epoch}:"):

                    if hasattr(self, 'masked_model'):
                        self.masked_model.density = 0.005 * self.masking_density + 0.995 * self.masked_model.density
                        lprint(f"masking density: {self.masked_model.density}")

                    train_loss_value = 0
                    val_loss_value = 0
                    iteration = 0
                    for i, (input_images, target_images,
                            val_mask_images) in enumerate(data_loader):

                        lprint(f"index: {i}, shape:{input_images.shape}")

                        input_images_gpu = input_images.to(self.device,
                                                           non_blocking=True)
                        target_images_gpu = target_images.to(self.device,
                                                             non_blocking=True)
                        validation_mask_images_gpu = val_mask_images.to(
                            self.device, non_blocking=True)

                        # Adding training noise to input:
                        if self.training_noise > 0:
                            with torch.no_grad():
                                alpha = self.training_noise / (
                                    1 + (10000 * epoch / self.max_epochs))
                                lprint(f"Training noise level: {alpha}")
                                training_noise = alpha * torch.randn_like(
                                    input_images)
                                input_images_gpu += training_noise.to(
                                    input_images_gpu.device)

                        # Clear gradients w.r.t. parameters
                        optimizer.zero_grad()

                        # Forward pass:
                        self.model.train()
                        if self.masking:
                            translated_images_gpu = self.masked_model(
                                input_images_gpu)
                        else:
                            translated_images_gpu = self.model(
                                input_images_gpu)

                        # apply forward model:
                        forward_model_images_gpu = self._forward_model(
                            translated_images_gpu)

                        # validation masking:
                        u = forward_model_images_gpu * (
                            1 - validation_mask_images_gpu)
                        v = target_images_gpu * (1 -
                                                 validation_mask_images_gpu)

                        # with napari.gui_qt():
                        #     viewer = napari.Viewer()
                        #     viewer.add_image(to_numpy(validation_mask_images_gpu), name='validation_mask_images_gpu')
                        #     viewer.add_image(to_numpy(forward_model_images_gpu), name='forward_model_images_gpu')
                        #     viewer.add_image(to_numpy(target_images_gpu), name='target_images_gpu')

                        # translation loss (per voxel):
                        if self.masking:
                            mask = self.masked_model.get_mask()
                            translation_loss = loss_function(u, v, mask)
                        else:
                            translation_loss = loss_function(u, v)

                        # loss value (for all voxels):
                        translation_loss_value = translation_loss.mean()

                        # Additional losses:
                        additional_loss_value = self._additional_losses(
                            translated_images_gpu, forward_model_images_gpu)
                        if additional_loss_value is not None:
                            translation_loss_value += additional_loss_value

                        # backpropagation:
                        translation_loss_value.backward()

                        # Updating parameters
                        optimizer.step()

                        # post optimisation -- if needed:
                        self.model.post_optimisation()

                        # update training loss_deconvolution for whole image:
                        train_loss_value += translation_loss_value.item()
                        iteration += 1

                        # Validation:
                        with torch.no_grad():
                            # Forward pass:
                            self.model.eval()
                            if self.masking:
                                translated_images_gpu = self.masked_model(
                                    input_images_gpu)
                            else:
                                translated_images_gpu = self.model(
                                    input_images_gpu)

                            # apply forward model:
                            forward_model_images_gpu = self._forward_model(
                                translated_images_gpu)

                            # validation masking:
                            u = forward_model_images_gpu * validation_mask_images_gpu
                            v = target_images_gpu * validation_mask_images_gpu

                            # translation loss (per voxel):
                            if self.masking:
                                translation_loss = loss_function(u, v, None)
                            else:
                                translation_loss = loss_function(u, v)

                            # loss values:
                            translation_loss_value = (
                                translation_loss.mean().cpu().item())

                            # update validation loss_deconvolution for whole image:
                            val_loss_value += translation_loss_value

                    train_loss_value /= iteration
                    lprint(f"Training loss value: {train_loss_value}")

                    val_loss_value /= iteration
                    lprint(f"Validation loss value: {val_loss_value}")

                    # Learning rate schedule:
                    scheduler.step(val_loss_value)

                    if val_loss_value < best_val_loss_value:
                        lprint(f"## New best val loss!")
                        if val_loss_value < best_val_loss_value - self.patience_epsilon:
                            lprint(f"## Good enough to reset patience!")
                            patience_counter = 0

                        # Update best val loss value:
                        best_val_loss_value = val_loss_value

                        # Save model:
                        best_model_state_dict = OrderedDict({
                            k: v.to('cpu')
                            for k, v in self.model.state_dict().items()
                        })

                    else:
                        if (epoch % max(1, self.reload_best_model_period) == 0
                                and best_model_state_dict):
                            lprint(f"Reloading best models to date!")
                            self.model.load_state_dict(best_model_state_dict)

                        if patience_counter > self.patience:
                            lprint(f"Early stopping!")
                            break

                        # No improvement:
                        lprint(
                            f"No improvement of validation losses, patience = {patience_counter}/{self.patience} "
                        )
                        patience_counter += 1

                    lprint(f"## Best val loss: {best_val_loss_value}")

                    if self._stop_training_flag:
                        lprint(f"Training interupted!")
                        break

        lprint(f"Reloading best models to date!")
        self.model.load_state_dict(best_model_state_dict)
Example #19
    def translate(
        self,
        input_image,
        translated_image=None,
        batch_dims=None,
        channel_dims=None,
        tile_size=None,
        denormalise_values=True,
        leave_as_float=False,
        clip=True,
    ):
        """
        Translates an input image into an output image according to the learned function.
        :param input_image:
        :type input_image:
        :param clip:
        :type clip:
        :return:
        :rtype:
        """

        with lsection(
            f"Predicting output image from input image of dimension {input_image.shape}"
        ):

            # set default batch_dim and channel_dim values:
            if batch_dims is None:
                batch_dims = (False,) * len(input_image.shape)
            if channel_dims is None:
                channel_dims = (False,) * len(input_image.shape)

            # Number of spatio-temporal dimensions:
            num_spatiotemp_dim = sum(
                0 if b or c else 1 for b, c in zip(batch_dims, channel_dims)
            )

            # First we normalise the input values:
            normalised_input_image = self.input_normaliser.normalise(
                input_image, batch_dims=batch_dims, channel_dims=channel_dims
            )

            # When trained in supervised mode we need to update the permutated image shape of the target_normaliser.
            # This way we can accommodate batch dimensions of different sizes than the ones used for training.
            if not self.self_supervised:
                (
                    _,
                    _,
                    self.target_normaliser.permutated_image_shape,
                ) = self.target_normaliser.shape_normalize(
                    input_image, batch_dims=batch_dims, channel_dims=channel_dims
                )

            # Let's pad the input array so we avoid annoying border-effects:
            normalised_input_image = self._pad_norm_image(normalised_input_image)

            # Spatio-temporal shape:
            spatiotemp_shape = normalised_input_image.shape[-num_spatiotemp_dim:]

            normalised_translated_image = None

            if tile_size == 0:
                # we _force_ no tiling, this is _not_ the default.

                # We translate:
                normalised_translated_image = self._translate(
                    normalised_input_image,
                    whole_image_shape=normalised_input_image.shape,
                )

            else:

                # We do need to do tiled inference because of a lack of memory
                # or because a small batch size was requested:

                normalised_input_shape = normalised_input_image.shape

                # We get the tiling strategy:
                # tile_size, shape, min_margin, max_margin
                tilling_strategy, margins = self._get_tilling_strategy_and_margins(
                    normalised_input_image,
                    self.max_voxels_per_tile,
                    self.tile_min_margin,
                    self.tile_max_margin,
                    suggested_tile_size=tile_size,
                )
                lprint(f"Tilling strategy: {tilling_strategy}")
                lprint(f"Margins for tiles: {margins} .")

                # tile slice objects (with and without margins):
                tile_slices_margins = list(
                    nd_split_slices(
                        normalised_input_shape, tilling_strategy, margins=margins
                    )
                )
                tile_slices = list(
                    nd_split_slices(normalised_input_shape, tilling_strategy)
                )

                # Number of tiles:
                number_of_tiles = len(tile_slices)
                lprint(f"Number of tiles (slices): {number_of_tiles}")

                # We create slice list:
                slicezip = zip(tile_slices_margins, tile_slices)

                counter = 1
                for slice_margin_tuple, slice_tuple in slicezip:
                    with lsection(
                        f"Current tile: {counter}/{number_of_tiles}, slice: {slice_tuple} "
                    ):

                        # We first extract the tile image:
                        input_image_tile = normalised_input_image[
                            slice_margin_tuple
                        ].copy()

                        # We do the actual translation:
                        lprint(f"Translating...")
                        translated_image_tile = self._translate(
                            input_image_tile,
                            image_slice=slice_margin_tuple,
                            whole_image_shape=normalised_input_image.shape,
                        )

                        # We compute the slice needed to cut out the margins:
                        lprint(f"Removing margins...")
                        remove_margin_slice_tuple = remove_margin_slice(
                            normalised_input_shape, slice_margin_tuple, slice_tuple
                        )

                        # We allocate -just in time- the translated array if needed:
                        # if the array is already provided, it must of course have the right dimensions...
                        if normalised_translated_image is None:
                            translated_image_shape = (
                                normalised_input_image.shape[:2] + spatiotemp_shape
                            )
                            normalised_translated_image = offcore_array(
                                shape=translated_image_shape,
                                dtype=translated_image_tile.dtype,
                                max_memory_usage_ratio=self.max_memory_usage_ratio,
                            )

                        # We plug in the batch without margins into the destination image:
                        lprint(f"Inserting translated batch into result image...")
                        normalised_translated_image[
                            slice_tuple
                        ] = translated_image_tile[remove_margin_slice_tuple]

                        counter += 1

            # Let's crop the padding:
            normalised_translated_image = self._crop_norm_image(
                normalised_translated_image
            )

            # Then we denormalise:
            denormalised_translated_image = self.target_normaliser.denormalise(
                normalised_translated_image,
                # denormalise_values=denormalise_values,
                leave_as_float=leave_as_float,
                clip=clip,
            )

            if translated_image is None:
                translated_image = denormalised_translated_image
            else:
                translated_image[...] = denormalised_translated_image

            return translated_image
Example #20
    def _train(
        self,
        input_image,
        target_image,
        train_valid_ratio=0.1,
        callback_period=3,
        jinv=False,
    ):
        self._stop_training_flag = False

        if jinv is not None and not jinv:
            self.masking = False

        shape = input_image.shape
        num_batches = shape[0]
        num_input_channels = input_image.shape[1]
        num_output_channels = target_image.shape[1]
        num_spatiotemp_dim = input_image.ndim - 2

        # tile size:
        tile_size = min(self.max_tile_size, min(shape[2:]))

        # Decide on how many voxels to be used for validation:
        num_val_voxels = int(train_valid_ratio * input_image.size)
        lprint(
            f"Number of voxels used for validation: {num_val_voxels} (train_valid_ratio={train_valid_ratio})"
        )

        # Generate random coordinates for these voxels:
        val_voxels = tuple(
            numpy.random.randint(d, size=num_val_voxels) for d in shape)
        lprint(f"Validation voxel coordinates: {val_voxels}")

        # Training Tile size:
        lprint(f"Train Tile dimensions: {tile_size}")

        # Prepare Training Dataset:
        dataset = self._get_dataset(input_image,
                                    target_image,
                                    self.self_supervised,
                                    tilesize=tile_size,
                                    mode='grid',
                                    validation_voxels=val_voxels,
                                    batch_size=self.batch_size)
        lprint(f"Number tiles for training: {len(dataset)}")

        # Training Data Loader:
        # num_workers = max(3, os.cpu_count() // 2)
        num_workers = 0  # faster if data is already in memory...
        lprint(
            f"Number of workers for loading training/validation data: {num_workers}"
        )
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=1,  # self.batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
        )

        # Model
        self.model = self.model_class(num_input_channels,
                                      num_output_channels,
                                      ndim=num_spatiotemp_dim).to(self.device)

        number_of_parameters = sum(p.numel() for p in self.model.parameters()
                                   if p.requires_grad)
        lprint(
            f"Number of trainable parameters in {self.model_class} model: {number_of_parameters}"
        )

        if self.masking:
            self.masked_model = Masking(self.model,
                                        density=0.5).to(self.device)

        lprint(f"Optimiser class: {self.optimiser_class}")
        lprint(f"Learning rate : {self.learning_rate}")

        # Optimiser:
        optimizer = self.optimiser_class(
            chain(self.model.parameters()),
            lr=self.learning_rate,
            start_noise_level=self.training_noise,
            weight_decay=self.l2_weight_regularisation,
        )

        lprint(f"Optimiser: {optimizer}")

        # Denoising loss function (defaults to L1):
        loss_function = nn.L1Loss()
        if self.loss.lower() == 'l2':
            lprint("Training/Validation loss: L2")
            if self.masking:
                loss_function = (lambda u, v, m: (u - v)**2
                                 if m is None else ((u - v) * m)**2)
            else:
                loss_function = lambda u, v: (u - v)**2

        elif self.loss.lower() == 'l1':
            lprint("Training/Validation loss: L1")
            if self.masking:
                loss_function = (lambda u, v, m: torch.abs(u - v)
                                 if m is None else torch.abs((u - v) * m))
            else:
                loss_function = lambda u, v: torch.abs(u - v)

        # Start training:
        self._train_loop(data_loader, optimizer, loss_function)
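
The masked loss variants above take a third argument: the blind-spot mask. When a mask is supplied, only masked pixels contribute to the loss, which is how the `jinv` (J-invariance) flag earlier in the method ties into masking. A tiny, self-contained illustration of that convention, with invented values:

import torch

u = torch.tensor([0.2, 0.8, 0.5])  # prediction
v = torch.tensor([0.0, 1.0, 0.5])  # target
m = torch.tensor([1.0, 0.0, 1.0])  # 1 = pixel participates in the loss

masked_l1 = (lambda u, v, m: torch.abs(u - v)
             if m is None else torch.abs((u - v) * m))

print(masked_l1(u, v, None).mean())  # plain L1 over all pixels
print(masked_l1(u, v, m).mean())     # the second pixel is zeroed out
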
Example #21
    def _get_tilling_strategy_and_margins(
        self,
        image,
        max_voxels_per_tile,
        min_margin,
        max_margin,
        suggested_tile_size=None,
    ):

        # We will store the batch strategy as a list of integers representing the number of chunks per dimension:
        with lsection(f"Determine tilling strategy:"):

            suggested_tile_size = (
                math.inf if suggested_tile_size is None else suggested_tile_size
            )

            # image shape:
            shape = image.shape
            num_spatio_temp_dim = num_spatiotemp_dim = len(shape) - 2

            lprint(f"image shape             = {shape}")
            lprint(f"max_voxels_per_tile     = {max_voxels_per_tile}")

            # Estimated amount of memory needed for storing all features:
            (
                estimated_memory_needed,
                total_memory_available,
            ) = self._estimate_memory_needed_and_available(image)
            lprint(f"Estimated amount of memory needed: {estimated_memory_needed}")

            # Available physical memory, scaled by the configured usage ratio:
            total_memory_available *= self.max_memory_usage_ratio

            lprint(
                f"Usable memory (after applying max_memory_usage_ratio): {total_memory_available}"
            )

            # How much do we need to tile because of memory, if at all?
            split_factor_mem = estimated_memory_needed / total_memory_available
            lprint(
                f"How much do we need to tile because of memory? {split_factor_mem} times."
            )

            # How much do we have to tile because of the limit on the number of voxels per tile?
            split_factor_max_voxels = image.size / max_voxels_per_tile
            lprint(
                f"How much do we need to tile because of the limit on the number of voxels per tile? {split_factor_max_voxels} times."
            )

            # How much do we have to tile because of the suggested tile size?
            split_factor_suggested_tile_size = image.size / (
                suggested_tile_size ** num_spatio_temp_dim
            )
            lprint(
                f"How much do we need to tile because of the suggested tile size? {split_factor_suggested_tile_size} times."
            )

            # we keep the max:
            desired_split_factor = max(
                split_factor_mem,
                split_factor_max_voxels,
                split_factor_suggested_tile_size,
            )
            # We cannot split less than 1 time:
            desired_split_factor = max(1, int(math.ceil(desired_split_factor)))
            lprint(f"Desired split factor: {desired_split_factor}")

            # Number of batches:
            num_batches = shape[0]

            # Does the number of batches split the data enough?
            if num_batches < desired_split_factor:
                # Not enough splitting happening along the batch dimension, we need to split further:
                # how much?
                rest_split_factor = desired_split_factor / num_batches
                lprint(
                    f"Not enough splitting happening along the batch dimension, we need to split spatio-temp dims by: {rest_split_factor}"
                )

                # let's split the dimensions in a way proportional to their lengths:
                split_per_dim = (rest_split_factor / numpy.prod(shape[2:])) ** (
                    1 / num_spatio_temp_dim
                )

                spatiotemp_tilling_strategy = tuple(
                    max(1, int(math.ceil(split_per_dim * s))) for s in shape[2:]
                )

                # correction_factor = numpy.prod(tuple(s for s in spatiotemp_tilling_strategy if s<1))

                tilling_strategy = (num_batches, 1) + spatiotemp_tilling_strategy
                lprint(f"Preliminary tilling strategy is: {tilling_strategy}")

                # We correct for possible oversplitting by favouring splits over the leading dimensions:
                current_splitting_factor = 1
                corrected_tilling_strategy = []
                split_factor_reached = False
                for i, s in enumerate(tilling_strategy):

                    if split_factor_reached:
                        corrected_tilling_strategy.append(1)
                    else:
                        corrected_tilling_strategy.append(s)
                        current_splitting_factor *= s

                    if current_splitting_factor >= desired_split_factor:
                        split_factor_reached = True

                tilling_strategy = tuple(corrected_tilling_strategy)

            else:
                tilling_strategy = (desired_split_factor, 1) + tuple(
                    1 for s in shape[2:]
                )

            lprint(f"Tilling strategy is: {tilling_strategy}")

            # Handles defaults:
            if max_margin is None:
                max_margin = math.inf
            if min_margin is None:
                min_margin = 0

            # First we estimate the shape of a tile:

            estimated_tile_shape = tuple(
                int(round(s / ts)) for s, ts in zip(shape[2:], tilling_strategy[2:])
            )
            lprint(f"The estimated tile shape is: {estimated_tile_shape}")

            # Limit margins:
            # We set the margins automatically from the estimated tile shape:
            # the margin factor guarantees that tiling incurs no more than the given maximum tiling overhead:
            margin_factor = 0.5 * (
                ((1 + self.max_tilling_overhead) ** (1 / num_spatiotemp_dim)) - 1
            )
            margins = tuple(int(s * margin_factor) for s in estimated_tile_shape)
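            # Worked example, assuming max_tilling_overhead = 0.1 and two
            # spatio-temporal dimensions: margin_factor = 0.5 * (1.1**0.5 - 1)
            # ≈ 0.024, so a 1024-pixel tile edge gets a margin of about 25 pixels.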

            # Limit the margin to something reasonable (provided or automatically computed):
            margins = tuple(min(max_margin, m) for m in margins)
            margins = tuple(max(min_margin, m) for m in margins)

            # We add the batch and channel dimensions:
            margins = (0, 0) + margins

            # We only need margins if we split a dimension:
            margins = tuple(
                (0 if split == 1 else margin)
                for margin, split in zip(margins, tilling_strategy)
            )

            return tilling_strategy, margins
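
To make the split-factor arithmetic concrete, here is a small stand-alone run with invented numbers: a single-batch (1, 1, 2048, 2048) image and a 4e6-voxel tile limit, ignoring memory pressure and any suggested tile size.

import math
import numpy

shape = (1, 1, 2048, 2048)
image_size = int(numpy.prod(shape))                 # 4,194,304 voxels
desired = max(1, int(math.ceil(image_size / 4e6)))  # voxel limit wins: 2

# The batch dimension alone (1) cannot absorb the split, so the spatial
# dimensions share the rest, following the formula in the snippet above:
rest = desired / shape[0]
split_per_dim = (rest / numpy.prod(shape[2:])) ** (1 / 2)
strategy = (shape[0], 1) + tuple(
    max(1, int(math.ceil(split_per_dim * s))) for s in shape[2:])
print(strategy)  # (1, 1, 2, 2), before the oversplit-correction pass
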
Example #22
    def train(
        self,
        input_image,
        target_image=None,
        batch_dims=None,
        channel_dims=None,
        tile_size=None,
        train_valid_ratio=0.1,
        callback_period=3,
        jinv=None,
    ):
        """Train to translate a given input image to a given output image.
        This has a lot of the machinery for batching and more...
        """

        if target_image is None:
            target_image = input_image

        with lsection(
            f"Learning to translate from image of dimensions {input_image.shape} to {target_image.shape}."
        ):

            lprint("Running garbage collector...")
            gc.collect()

            # If we use the same image for input and output then we are in a self-supervised setting:
            self.self_supervised = input_image is target_image

            if self.self_supervised:
                lprint("Training is self-supervised.")
            else:
                lprint("Training is supervised.")

            if batch_dims is None:  # set default batch_dim value:
                batch_dims = (False,) * len(input_image.shape)

            self.input_normaliser = self.normalizer_class(
                transform=self.normaliser_transform, clip=self.normaliser_clip
            )
            self.target_normaliser = (
                self.input_normaliser
                if self.self_supervised
                else self.normalizer_class(
                    transform=self.normaliser_transform, clip=self.normaliser_clip
                )
            )

            # Calibrates normaliser(s):
            self.input_normaliser.calibrate(input_image)
            if not self.self_supervised:
                self.target_normaliser.calibrate(target_image)

            # Intensity values normalisation:
            normalised_input_image = self.input_normaliser.normalise(
                input_image, batch_dims=batch_dims, channel_dims=channel_dims
            )
            normalised_target_image = (
                normalised_input_image
                if self.self_supervised
                else self.target_normaliser.normalise(
                    target_image, batch_dims=batch_dims, channel_dims=channel_dims
                )
            )

            # Let's pad the images to avoid border effects:
            # If we do it for translation we also have to do it for training because of
            # location-aware features such as large-scale features or spatial-features.
            normalised_input_image = self._pad_norm_image(normalised_input_image)
            normalised_target_image = self._pad_norm_image(normalised_target_image)

            self._train(
                normalised_input_image,
                normalised_target_image,
                tile_size=tile_size,
                train_valid_ratio=train_valid_ratio,
                callback_period=callback_period,
                jinv=jinv,
            )
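
For context, a hedged usage sketch of this training entry point: passing a single image makes `input_image is target_image` true, which switches the translator into the self-supervised branch above. The class name `ImageTranslatorPyTorch` is a stand-in for illustration, not necessarily the library's actual class name, and `translate()` is assumed to be the inference counterpart.

import numpy

noisy = numpy.random.rand(1, 1, 256, 256).astype(numpy.float32)

translator = ImageTranslatorPyTorch(max_epochs=128, masking=True)
translator.train(noisy)                 # target_image defaults to input_image
denoised = translator.translate(noisy)  # assumed inference method
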
Example #23
def test_log():
    # This is required for this test to pass!
    Log.override_test_exclusion = True

    lprint("Test")

    with lsection("a section"):
        lprint("a line")
        lprint("another line")
        lprint("we are done")

        with lsection("a subsection"):
            lprint("another line")
            lprint("we are done")

            with lsection("a subsection"):
                lprint("another line")
                lprint("we are done")

                assert Log.depth == 3

                with lsection("a subsection"):
                    lprint("another line")
                    lprint("we are done")

                    with lsection("a subsection"):
                        lprint("another line")
                        lprint("we are done")

                        assert Log.depth == 5

                        with lsection("a subsection"):
                            lprint("another line")
                            lprint("we are done")

                            with lsection("a subsection"):
                                lprint("another line")
                                lprint("we are done")

                                assert Log.depth == 7

                        with lsection("a subsection"):
                            lprint("another line")
                            lprint("we are done")

                    with lsection("a subsection"):
                        lprint("another line")
                        lprint("we are done")

                with lsection("a subsection"):
                    lprint("another line")
                    lprint("we are done")

    lprint("test is finished...")

    assert Log.depth == 0
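
The test relies on lsection incrementing Log.depth on entry and decrementing it on exit, so that depth always returns to 0 once every section has closed. A minimal sketch of that mechanism, assuming nothing about the real Log implementation:

from contextlib import contextmanager

class Log:
    depth = 0

def lprint(message):
    # Indent messages according to the current section depth:
    print("  " * Log.depth + str(message))

@contextmanager
def lsection(title):
    lprint(title)
    Log.depth += 1
    try:
        yield
    finally:
        Log.depth -= 1

with lsection("a section"):
    lprint("a line")
    assert Log.depth == 1
assert Log.depth == 0
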