Example #1
class Engine(ABC):
    def __init__(
        self,
        cfg: BaseConfig,
        model: nn.Module,
        device: int,
        mixed_precision: bool = False,
        **models: Dict[str, nn.Module],
    ):
        self.logger = logging.getLogger(type(self).__name__)

        self.cfg = cfg
        self.model = model
        self.models = models
        self.device = device

        self.mixed_precision = mixed_precision
        self.checkpointer = None

        # TODO: This might not be needed, if these objects are changed in-place
        self.__optimizer = None
        self.__lr_scheduler = None
        self._scaler = GradScaler(enabled=self.mixed_precision)
        self.__writers = None
        self.__bind_sigint_signal()

        self._ndim = None

    @property
    def real_names(self):
        if self.ndim == 2:
            return ["batch", "height", "width"]
        elif self.ndim == 3:
            return ["batch", "slice", "height", "width"]
        else:
            raise NotImplementedError(
                f"{self.ndim}D named data is not yet supported")

    @property
    def complex_names(self):
        if self.ndim == 2:
            return ["batch", "complex", "height", "width"]
        elif self.ndim == 3:
            return ["batch", "complex", "slice", "height", "width"]
        else:
            raise NotImplementedError(
                f"{self.ndim}D named data is not yet supported")

    @property
    def complex_names_complex_last(self):
        if self.ndim == 2:
            return ["batch", "height", "width", "complex"]
        elif self.ndim == 3:
            return ["batch", "slice", "height", "width", "complex"]
        else:
            raise NotImplementedError(
                f"{self.ndim}D named data is not yet supported")

    @property
    def ndim(self):
        if self._ndim is None:
            raise ValueError("ndim needs to be set before it can be accessed.")

        return self._ndim

    @ndim.setter
    def ndim(self, ndim):
        if not isinstance(ndim, int) or ndim <= 0:
            raise ValueError(
                f"ndim has to be an integer larger than 0. Got {ndim}. "
                f"You can for instance use `compute_dimensionality_from_sample` "
                f"to determine the dimensionality or set it explicitly.")

        self._ndim = ndim

    @staticmethod
    def compute_dimensionality_from_sample(sample):
        """
        Computes the dimensionality of the data from a single sample by inspecting the named tensors
        in the sample dict and taking the largest number of relevant dimensions found among them.

        Parameters
        ----------
        sample : dict

        Returns
        -------
        int : the sample dimensionality
        """
        relevant_names = ["slice", "height", "width", "time"]
        ndim = 0
        for v in sample.values():
            if isinstance(v, torch.Tensor):
                curr_dimensionality = sum(
                    name in v.names for name in relevant_names)
                ndim = max(curr_dimensionality, ndim)

        return ndim
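
    # A usage sketch (hypothetical sample; names follow the named-tensor
    # conventions used elsewhere in this class):
    #
    #   sample = {"kspace": torch.zeros(1, 10, 320, 320,
    #                                   names=("batch", "slice", "height", "width"))}
    #   Engine.compute_dimensionality_from_sample(sample)  # -> 3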

    @abstractmethod
    def build_loss(self) -> Dict:
        pass

    @staticmethod
    def build_metrics(metrics_list) -> Dict:
        if not metrics_list:
            return {}

        # _metric is added as only keys containing "loss" or "metric" are logged.
        metrics_dict = {
            curr_metric + "_metric": str_to_class("direct.functionals",
                                                  curr_metric)
            for curr_metric in metrics_list
        }
        return metrics_dict
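
    # E.g. build_metrics(["psnr", "ssim"]) would, assuming `psnr` and `ssim`
    # exist in direct.functionals, return
    # {"psnr_metric": <callable>, "ssim_metric": <callable>};
    # the "_metric" suffix is what makes the keys show up in the logs.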

    @abstractmethod
    def _do_iteration(self, *args, **kwargs) -> Tuple[torch.Tensor, Dict]:
        """
        This is a placeholder for the iteration function. This needs to perform the backward pass.
        If using mixed-precision you need to implement `autocast` as well in this function.
        It is recommended you raise an error if `self.mixed_precision` is true but mixed precision is not available.
        """
        pass
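
    # A minimal sketch of a concrete `_do_iteration`, assuming a subclass whose
    # batches carry "masked_image" and "target" keys (engine-specific; not part
    # of this base class):
    #
    #   def _do_iteration(self, data, loss_fns):
    #       with torch.cuda.amp.autocast(enabled=self.mixed_precision):
    #           output = self.model(data["masked_image"].to(self.device))
    #           loss_dict = {name: fn(output, data["target"].to(self.device))
    #                        for name, fn in loss_fns.items()}
    #       self._scaler.scale(sum(loss_dict.values())).backward()
    #       return output, loss_dict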

    @torch.no_grad()
    def predict(
        self,
        dataset: Dataset,
        experiment_directory: pathlib.Path,
        checkpoint_number: int = -1,
        num_workers: int = 6,
    ) -> np.ndarray:
        # TODO: Improve the way of checkpointing
        self.logger.info(f"Predicting...")
        self.checkpointer = Checkpointer(self.model,
                                         experiment_directory,
                                         save_to_disk=False)

        # Do not load again if the checkpoint is already loaded.
        if self.checkpointer.checkpoint_loaded != checkpoint_number:
            self.checkpointer.load(iteration=checkpoint_number,
                                   checkpointable_objects=None)

        batch_sampler = self.build_batch_sampler(
            dataset,
            batch_size=self.cfg.validation.batch_size,
            sampler_type="sequential",
            limit_number_of_volumes=None,
        )
        # TODO: Batch size can be much larger, perhaps have a different batch size during evaluation.
        data_loader = self.build_loader(dataset,
                                        batch_sampler=batch_sampler,
                                        num_workers=num_workers)
        loss, output = self.evaluate(data_loader,
                                     loss_fns=None,
                                     crop=None,
                                     is_validation_process=False)

        return output

    @staticmethod
    def build_loader(
        dataset: Dataset,
        batch_sampler: Optional[Sampler] = None,
        num_workers: int = 6,
    ) -> DataLoader:
        # TODO(jt): Custom memory pinning.
        loader = DataLoader(
            dataset=dataset,
            sampler=None,
            batch_size=1,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=named_collate,
            drop_last=False,
            shuffle=False,
            pin_memory=True,
        )

        return loader

    @staticmethod
    def build_batch_sampler(
        dataset: Union[Dataset, List[Dataset]],
        batch_size: int,
        sampler_type: str,
        **kwargs,
    ) -> Sampler:
        if sampler_type == "random":
            if not isinstance(dataset, List) or any(
                [not isinstance(_, Dataset) for _ in dataset]):
                raise ValueError(
                    f"Random sampler requires a list of datasets as input.")
            batch_sampler = ConcatDatasetBatchSampler(datasets=dataset,
                                                      batch_size=batch_size)
        elif sampler_type == "sequential":
            sampler = direct.data.sampler.DistributedSequentialSampler(
                dataset, **kwargs)
            batch_sampler = direct.data.sampler.BatchVolumeSampler(
                sampler,
                batch_size=batch_size,
            )
        else:
            raise ValueError(f"Sampler type {sampler_type} not supported.")

        return batch_sampler
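
    # Usage sketch: "random" builds a ConcatDatasetBatchSampler over a list of
    # datasets (each batch drawn from a single constituent dataset), while
    # "sequential" batches whole volumes in order, e.g.
    #
    #   sampler = Engine.build_batch_sampler([ds_a, ds_b], batch_size=4,
    #                                        sampler_type="random")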

    def training_loop(
        self,
        data_loader: DataLoader,
        start_iter: int,
        validation_data_loaders: Optional[List[DataLoader]] = None,
        experiment_directory: Optional[pathlib.Path] = None,
    ):
        self.logger.info(f"Local rank: {communication.get_local_rank()}.")
        self.models_training_mode()

        loss_fns = self.build_loss()
        metric_fns = self.build_metrics(self.cfg.training.metrics)
        storage = get_event_storage()

        total_iter = self.cfg.training.num_iterations  # noqa
        for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
            data = AddNames()(data)
            if iter_idx == start_iter:
                self.ndim = self.compute_dimensionality_from_sample(data)
                self.logger.info(f"Data dimensionality: {self.ndim}.")

            if iter_idx == 0:
                self.log_first_training_example(data)

            try:
                output, loss_dict = self._do_iteration(data, loss_fns)
            except ProcessKilledException as e:
                # If the process is killed, a checkpoint is saved at iter_idx, the current state,
                # so the computation can restart from the last iteration.
                self.logger.exception(f"Exiting with exception: {e}.")
                self.checkpointer.save(iter_idx)  # Save a checkpoint at kill. # noqa
                # TODO: This causes the issue that current metrics are not written,
                # and you end up with an empty line.
                self.write_to_logs()
                sys.exit(-1)

            # Gradient accumulation
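            # Gradients from the last `gradient_steps` backward passes have
            # accumulated in `.grad`; dividing by `gradient_steps` below averages
            # them, so the effective learning rate is unchanged.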
            if (iter_idx + 1) % self.cfg.training.gradient_steps == 0:  # type: ignore
                if self.cfg.training.gradient_steps > 1:  # type: ignore
                    for parameter in self.model.parameters():
                        if parameter.grad is not None:
                            # In-place division
                            parameter.grad.div_(
                                self.cfg.training.gradient_steps)  # type: ignore
                if self.cfg.training.gradient_clipping > 0.0:  # type: ignore
                    self._scaler.unscale_(self.__optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.cfg.training.gradient_clipping)

                # Gradient norm
                if self.cfg.training.gradient_debug:  # type: ignore
                    warnings.warn(
                        f"Gradient debug set. This will affect training performance. Only use for debugging."
                        f"This message will only be displayed once.")
                    parameters = list(
                        filter(lambda p: p.grad is not None,
                               self.model.parameters()))
                    gradient_norm = sum(
                        parameter.grad.data.norm(2) ** 2
                        for parameter in parameters).sqrt()  # type: ignore
                    storage.add_scalar("train/gradient_norm", gradient_norm)

                # Same as self.__optimizer.step() for mixed precision.
                self._scaler.step(self.__optimizer)
                # Updates the scale for next iteration.
                self._scaler.update()
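                # Note: if the unscaled gradients contain infs/NaNs, `scaler.step`
                # skips the optimizer step and `scaler.update` shrinks the scale,
                # so an occasional skipped step under mixed precision is expected.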

            # Incorrect inference by mypy and pyflake
            self.__lr_scheduler.step()  # type: ignore # noqa
            storage.add_scalar("lr",
                               self.__optimizer.param_groups[0]["lr"],
                               smoothing_hint=False)

            self.__optimizer.zero_grad()  # type: ignore

            # Reduce the loss over all devices
            loss_dict_reduced = communication.reduce_tensor_dict(loss_dict)
            loss_reduced = sum(loss_dict_reduced.values())

            metrics_dict = evaluate_dict(
                metric_fns,
                transforms.modulus_if_complex(output.detach()).rename(None),
                data["target"].rename(None).detach().to(self.device),
                reduction="mean",
            )
            metrics_dict_reduced = (
                communication.reduce_tensor_dict(metrics_dict)
                if metrics_dict else {})
            storage.add_scalars(loss=loss_reduced,
                                **loss_dict_reduced,
                                **metrics_dict_reduced)

            if iter_idx > 5 and (
                    iter_idx % self.cfg.training.checkpointer.checkpoint_steps
                    == 0 or (iter_idx + 1) == total_iter):
                self.logger.info(f"Checkpointing at iteration {iter_idx}.")
                self.checkpointer.save(iter_idx)

            if (validation_data_loaders is not None and iter_idx > 5
                    and (iter_idx % self.cfg.training.validation_steps == 0 or
                         (iter_idx + 1) == total_iter)):
                for (
                        curr_dataset_name,
                        curr_validation_data_loader,
                ) in validation_data_loaders:
                    self.logger.info(
                        f"Evaluating: {curr_dataset_name}..."
                    )  # TODO(jt): Fix with better names and stuff.
                    (
                        curr_val_loss_dict,
                        curr_val_metric_dict_per_case,
                        visualize_slices,
                        visualize_target,
                    ) = self.evaluate(
                        curr_validation_data_loader,
                        loss_fns,
                        is_validation_process=True,
                    )

                    if experiment_directory:
                        # Make dictionary serializable for logging
                        serializable_val_metric_dict = {
                            k0: {k1: float(v1)
                                 for k1, v1 in v0.items()}
                            for k0, v0 in
                            curr_val_metric_dict_per_case.items()
                        }
                        write_json(
                            experiment_directory /
                            f"metrics_val_{curr_dataset_name}_{iter_idx}.json",
                            serializable_val_metric_dict,
                        )

                    # Metric dict still needs to be reduced as it gives values *per* data
                    curr_val_metric_dict = reduce_list_of_dicts(
                        list(curr_val_metric_dict_per_case.values()),
                        mode="average")

                    key_prefix = ("val/" if not curr_dataset_name else
                                  f"val/{curr_dataset_name}/")
                    val_loss_reduced = sum(curr_val_loss_dict.values())
                    storage.add_scalars(
                        **{key_prefix + "loss": val_loss_reduced},
                        **{
                            **prefix_dict_keys(curr_val_metric_dict, key_prefix),
                            **prefix_dict_keys(curr_val_loss_dict, key_prefix),
                        },
                        smoothing_hint=False,
                    )
                    visualize_slices = self.process_slices_for_visualization(
                        visualize_slices, visualize_target)
                    storage.add_image(f"{key_prefix}prediction",
                                      visualize_slices)

                    # The target does not change, so only log it on the first validation round.
                    if iter_idx // self.cfg.training.validation_steps == 1:
                        visualize_target = make_grid(
                            crop_to_largest(visualize_target, pad_value=0),
                            nrow=self.cfg.tensorboard.num_images,
                            scale_each=True,
                        )
                        storage.add_image(f"{key_prefix}target",
                                          visualize_target)

                self.logger.info(
                    f"Done evaluating at iteration {iter_idx}.")
                self.model.train()

            # Log every 20 iterations, or at a validation step or at the end of training.
            if iter_idx > 5 and (iter_idx % 20 == 0 or iter_idx %
                                 self.cfg.training.validation_steps == 0 or
                                 (iter_idx + 1) == total_iter):
                self.write_to_logs()

            storage.step()

    def process_slices_for_visualization(self, visualize_slices,
                                         visualize_target):
        # Log slices.
        # Compute the difference as well, and normalize for visualization
        difference_slices = [
            a - b for a, b in zip(visualize_slices, visualize_target)
        ]
        # Map the sign of the difference to {0, 1} for visualization
        difference_slices = [(d / np.abs(d)) * 0.5 + 0.5
                             for d in difference_slices]
        visualize_slices = [
            normalize_image(image) for image in visualize_slices
        ]

        # Visualize slices, and crop to the largest volume
        visualize_slices = make_grid(
            crop_to_largest(visualize_slices + difference_slices, pad_value=0),
            nrow=self.cfg.tensorboard.num_images,
            scale_each=False,
        )
        return visualize_slices

    def models_training_mode(self):
        self.model.train()
        for curr_model in self.models:
            self.models[curr_model].train()

    def models_validation_mode(self):
        self.model.eval()
        for curr_model in self.models:
            self.models[curr_model].eval()

    def models_to_device(self):
        self.model = self.model.to(self.device)
        for curr_model_name in self.models:
            self.models[curr_model_name] = self.models[curr_model_name].to(
                self.device)

    def train(
        self,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,  # noqa
        training_datasets: List[Dataset],
        experiment_directory: pathlib.Path,
        validation_data: Optional[Dataset] = None,
        resume: bool = False,
        initialization: Optional[PathOrString] = None,
        num_workers: int = 6,
    ) -> None:
        self.logger.info("Starting training.")
        # TODO: Does not need to be member of self.
        self.__optimizer = optimizer
        # TODO: Optimizer and LR scheduler need to be resumed too.
        self.__lr_scheduler = lr_scheduler
        training_data = ConcatDataset(training_datasets)
        self.logger.info(f"Concatenated dataset length: {len(training_data)}.")
        self.logger.info(
            f"Building batch sampler for training set with batch size {self.cfg.training.batch_size}."
        )
        training_sampler = self.build_batch_sampler(
            training_datasets, self.cfg.training.batch_size, "random")
        training_loader = self.build_loader(
            training_data,
            batch_sampler=training_sampler,
            num_workers=num_workers,
        )

        if validation_data:
            validation_loaders = []
            for idx, curr_validation_data in enumerate(validation_data):
                text_dataset_description = curr_validation_data.text_description
                self.logger.info(
                    f"Building dataloader for dataset: {text_dataset_description}."
                )
                curr_batch_sampler = self.build_batch_sampler(
                    curr_validation_data,
                    batch_size=self.cfg.validation.batch_size,
                    sampler_type="sequential",
                    limit_number_of_volumes=None,
                )
                validation_loaders.append((
                    text_dataset_description,
                    self.build_loader(
                        curr_validation_data,
                        batch_sampler=curr_batch_sampler,
                        num_workers=0,  # TODO(jt): num_workers > 0 seems to choke validation.
                    ),
                ))
        else:
            validation_loaders = None

        self.models_to_device()

        # Optimizer
        self.__optimizer.zero_grad()  # type: ignore

        # Mixed precision setup. This requires the model to be on the gpu.
        git_hash = direct.utils.git_hash()
        extra_checkpointing = {
            "__author__": git_hash if git_hash else "N/A",
            "__version__": direct.__version__,
            "__mixed_precision__": self.mixed_precision,
        }
        if self.mixed_precision:
            # TODO(jt): Check if on GPU
            self.logger.info(f"Using mixed precision training.")

        self.checkpointer = Checkpointer(
            self.model,
            experiment_directory,
            save_to_disk=communication.is_main_process(),
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            scaler=self._scaler,
            **self.models,
            **extra_checkpointing,
        )
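        # The extra keyword arguments above end up stored inside the checkpoint
        # dict itself, which is how the __author__/__version__/__mixed_precision__
        # fields read back below travel together with the weights.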

        # Load checkpoint
        start_iter = 0
        checkpoint = {}
        if resume:
            self.logger.info("Attempting to resume...")
            # This changes the model in-place.
            checkpoint = self.checkpointer.load(iteration="latest")
            if not checkpoint:
                self.logger.info("No checkpoint found. Starting from scratch.")
            else:
                start_iter = checkpoint["iteration"] + 1
                self.logger.info(f"Starting from iteration: {start_iter}.")

        if start_iter > 0 and initialization:
            self.logger.warning(
                f"Initialization checkpoint set to {initialization},"
                f" but model will resume training from previous checkpoint. Initialization ignored."
            )
        elif initialization:
            self.logger.info(f"Initializing from {initialization}...")
            self.checkpointer.load_from_file(initialization)

        if "__version__" in checkpoint:
            self.logger.info(
                f"DIRECT version of checkpoint: {checkpoint['__version__']}.")
            if checkpoint["__version__"] != direct.__version__:
                self.logger.warning(
                    f"Current DIRECT version {direct.__version__} is different from the one "
                    f"this checkpoint is saved with ({checkpoint['__version__']}. This can be fine, "
                    f"but beware that this can be a source of confusion.")

        if "__author__" in checkpoint:
            self.logger.info(
                f"Git hash of checkpoint: {checkpoint['__author__']}.")
            if checkpoint["__author__"] != direct.utils.git_hash():
                self.logger.warning(
                    f"Current git hash {direct.utils.git_hash()} is different from the one "
                    f"this checkpoint is saved with ({checkpoint['__author__']}. This can be fine, "
                    f"but beware that this can be a source of confusion.")

        if "__datetime__" in checkpoint:
            self.logger.info(
                f"Checkpoint created at: {checkpoint['__datetime__']}.")

        if "__mixed_precision__" in checkpoint:
            if not self.mixed_precision and checkpoint["__mixed_precision__"]:
                self.logger.warning(
                    "Mixed precision training is not enabled, yet the saved checkpoint requests it. "
                    "Mixed precision will now be enabled.")
                self.mixed_precision = True
            elif not checkpoint["__mixed_precision__"] and self.mixed_precision:
                self.logger.warning(
                    f"Mixed precision levels of training and loading checkpoint do not match. "
                    f"Requested mixed precision but checkpoint is saved without. "
                    f"This will almost surely lead to performance degradation."
                )

        self.logger.info(f"World size: {communication.get_world_size()}.")
        self.logger.info(f"Device count: {torch.cuda.device_count()}.")
        if communication.get_world_size() > 1:
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[communication.get_local_rank()],
                broadcast_buffers=False,
            )

        # World size > 1 if distributed mode, else allow a DataParallel fallback, can be convenient for debugging.
        elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
            self.model = DataParallel(self.model)

        self.__writers = ([
            JSONWriter(experiment_directory / "metrics.json"),
            CommonMetricPrinter(self.cfg.training.num_iterations),
            TensorboardWriter(experiment_directory / "tensorboard"),
        ] if communication.is_main_process() else [])

        with EventStorage(start_iter):
            self.training_loop(
                training_loader,
                start_iter,
                validation_loaders,
                experiment_directory=experiment_directory,
            )

        self.logger.info("Training completed.")

    @abstractmethod
    def evaluate(self, *args, **kwargs):  # noqa
        pass

    @abstractmethod
    def process_output(self, *args, **kwargs):  # noqa
        # Typically use this to scale data back to the original range.
        pass

    def log_process(self, idx, total):
        if idx % max(total // 5, 1) == 0 or total == (idx + 1):
            self.logger.info(f"Progress: {(idx + 1) / total * 100:.2f}%.")

    def log_first_training_example(self, data):
        storage = get_event_storage()
        self.logger.info(
            f"First case: slice_no: {data['slice_no'][0]}, filename: {data['filename'][0]}."
        )

        # TODO(jt): Cleaner, loop over types of images
        first_sampling_mask = data["sampling_mask"][0][0]
        first_acs_mask = data["acs_mask"][0][0] if "acs_mask" in data else None
        first_target = data["target"][0]
        first_masked_image = data["masked_image"][0]

        if self.ndim == 3:
            first_sampling_mask = first_sampling_mask[0]
            if first_acs_mask is not None:
                first_acs_mask = first_acs_mask[0]

            num_slices = first_target.shape[first_target.names.index("slice")]
            first_target = first_target[num_slices // 2]
            first_masked_image = first_masked_image[num_slices // 2]
        elif self.ndim > 3:
            raise NotImplementedError

        storage.add_image(
            "train/mask",
            first_sampling_mask[..., 0].rename(None).unsqueeze(0))
        if first_acs_mask is not None:
            storage.add_image(
                "train/acs_mask",
                first_acs_mask[..., 0].rename(None).unsqueeze(0),
            )
        storage.add_image(
            "train/target",
            normalize_image(first_target.rename(None).unsqueeze(0)),
        )
        storage.add_image(
            "train/masked_image",
            normalize_image(
                transforms.modulus_if_complex(first_masked_image).rename(
                    None).unsqueeze(0)),
        )
        self.write_to_logs()

    def write_to_logs(self):
        if self.__writers is not None:
            for writer in self.__writers:
                writer.write()

    def __bind_sigint_signal(self):
        """Bind SIGINT signal to handle preemption or other kill of the process."""
        def raise_process_killed_error(signal_id, _):
            """Raise the ProcessKilledError."""
            self.logger.info(
                f"Received {signal.Signals(signal_id).name}. Shutting down...")
            raise ProcessKilledException(signal_id,
                                         signal.Signals(signal_id).name)

        signal.signal(signalnum=signal.SIGINT,
                      handler=raise_process_killed_error)
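
# Note: converting SIGINT into a ProcessKilledException lets the try/except in
# `training_loop` save a checkpoint before exiting, so preempted jobs can
# resume from the last iteration.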
Example #2
class Engine(ABC):
    def __init__(self,
                 cfg: BaseConfig,
                 model: nn.Module,
                 device: int,
                 mixed_precision: int = 0):  # apex opt level; 0 disables mixed precision
        self.logger = logging.getLogger(type(self).__name__)

        self.cfg = cfg
        self.model = model
        self.device = device

        self.mixed_precision = mixed_precision
        self.checkpointer = None

        # TODO: This might not be needed, if these objects are changed in-place
        self.__optimizer = None
        self.__lr_scheduler = None
        self.__writers = None
        self.__bind_sigint_signal()

    @abstractmethod
    def build_loss(self) -> Dict:
        pass

    @abstractmethod
    def build_metrics(self) -> Dict:
        pass

    @abstractmethod
    def _do_iteration(self, *args, **kwargs) -> Tuple[torch.Tensor, Dict]:
        pass

    @staticmethod
    def build_loader(dataset: Dataset,
                     sampler: Optional[Sampler] = None,
                     batch_size: int = 1,
                     batch_sampler: Optional[Sampler] = None,
                     num_workers: int = 6,
                     drop_last: bool = False) -> DataLoader:
        # TODO(jt): Custom memory pinning.
        loader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=named_collate,
            drop_last=drop_last,
            pin_memory=True,
        )
        return loader

    @staticmethod
    def build_sampler(dataset: Dataset, sampler_type: str,
                      **kwargs) -> Sampler:
        if sampler_type == 'random':
            sampler: Sampler = direct.data.sampler.DistributedSampler(
                len(dataset), shuffle=True, seed=None,
                **kwargs)  # TODO: Set seed
        elif sampler_type == 'sequential':
            sampler: Sampler = direct.data.sampler.DistributedSequentialSampler(
                dataset, **kwargs)
        else:
            raise ValueError(f'Sampler type {sampler_type} not supported.')

        return sampler

    def training_loop(self,
                      data_loader: DataLoader,
                      start_iter: int,
                      validation_data_loader: Optional[DataLoader] = None):
        self.logger.info(f'Local rank: {communication.get_local_rank()}.')
        self.model.train()
        loss_fns = self.build_loss()
        metric_fns = self.build_metrics()
        storage = get_event_storage()

        total_iter = self.cfg.training.num_iterations  # noqa
        for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
            data = AddNames()(data)
            if iter_idx == 0:
                self.logger.info(
                    f"First case: slice_no: {data['slice_no'][0]}, filename: {data['filename'][0]}."
                )
                storage.add_image('train/mask', data['sampling_mask'][0, ..., 0])
                storage.add_image(
                    'train/target',
                    normalize_image(
                        data['target'][0].rename(None).unsqueeze(0)))
                storage.add_image(
                    'train/masked_image',
                    normalize_image(
                        transforms.modulus_if_complex(
                            data['masked_image'][0]).rename(None).unsqueeze(
                                0)))
                self.write_to_logs()

            try:
                output, loss_dict = self._do_iteration(data, loss_fns)
            except ProcessKilledException as e:
                # If the process is killed, a checkpoint is saved at iter_idx, the current state,
                # so the computation can restart from the last iteration.
                self.logger.exception(f'{e}.')
                self.checkpointer.save(iter_idx)  # Save a checkpoint at kill. # noqa
                # TODO: This causes the issue that current metrics are not written,
                # and you end up with an empty line.
                self.write_to_logs()
                sys.exit(f'Exited with exception: {e}')

            # Gradient accumulation
            if (iter_idx +
                    1) % self.cfg.training.gradient_steps == 0:  # type: ignore
                # TODO: Is this slow? This is a generator, so should be cheap.
                pass
                # parameter_list = self.model.parameters() if not self.mixed_precision else \
                #     amp.master_params(self.__optimizer)
                # if self.cfg.training.gradient_steps > 1:  # type: ignore
                #     for parameter in parameter_list:
                #         if parameter.grad is not None:
                #             # In-place division
                #             parameter.grad.div_(self.cfg.training.gradient_steps)  # type: ignore
                # if self.cfg.training.gradient_clipping > 0.0:  # type: ignore
                #     torch.nn.utils.clip_grad_norm_(parameter_list, self.cfg.training.gradient_clipping)  # type: ignore
                #
                # # Gradient norm
                # if self.cfg.training.gradient_debug:  # type: ignore
                #     parameters = list(filter(lambda p: p.grad is not None, parameter_list))
                #     gradient_norm = sum([parameter.grad.data ** 2 for parameter in parameters]).sqrt()  # typing: ignore
                #     storage.add_scalar('gradient_norm', gradient_norm)

            self.__optimizer.step()  # type: ignore
            # Incorrect inference by mypy and pyflake
            self.__lr_scheduler.step()  # type: ignore # noqa
            storage.add_scalar('lr',
                               self.__optimizer.param_groups[0]['lr'],
                               smoothing_hint=False)

            self.__optimizer.zero_grad()  # type: ignore

            # Reduce the loss over all devices
            loss_dict_reduced = communication.reduce_tensor_dict(loss_dict)
            loss_reduced = sum(loss_dict_reduced.values())

            metrics_dict = evaluate_dict(
                metric_fns,
                transforms.modulus_if_complex(output.detach()).rename(None),
                data['target'].rename(None).detach().to(self.device),
                reduction='mean')
            metrics_dict_reduced = communication.reduce_tensor_dict(
                metrics_dict) if metrics_dict else {}
            storage.add_scalars(loss=loss_reduced,
                                **loss_dict_reduced,
                                **metrics_dict_reduced)

            if validation_data_loader is not None \
                    and iter_idx > 5\
                    and (iter_idx % self.cfg.training.validation_steps == 0 or (iter_idx + 1) == total_iter):
                val_loss_dict = self.evaluate(
                    validation_data_loader,
                    loss_fns,
                    evaluation_round=iter_idx //
                    self.cfg.training.validation_steps - 1)
                self.logger.info(f'Done evaluation at iteration {iter_idx}.')
                storage.add_scalars(**prefix_dict_keys(val_loss_dict, 'val_'),
                                    smoothing_hint=False)
                self.model.train()

            if iter_idx > 5 and\
                    (iter_idx % self.cfg.training.checkpointer.checkpoint_steps == 0 or (iter_idx + 1) == total_iter):
                self.logger.info(f'Checkpointing at iteration: {iter_idx}.')
                self.checkpointer.save(iter_idx)

            # Log every 20 iterations, or at a validation step or at the end of training.
            if iter_idx > 5 and (iter_idx % 20 == 0 or iter_idx %
                                 self.cfg.training.validation_steps == 0 or
                                 (iter_idx + 1) == total_iter):
                self.write_to_logs()

            storage.step()

    def train(
            self,
            optimizer: torch.optim.Optimizer,
            lr_scheduler: torch.optim.lr_scheduler._LRScheduler,  # noqa
            training_data: Dataset,
            experiment_directory: pathlib.Path,
            validation_data: Dataset = None,
            resume: bool = False,
            num_workers: int = 0) -> None:

        # TODO: Does not need to be member of self.
        self.__optimizer = optimizer
        # TODO: Optimizer and LR scheduler need to be resumed too.
        self.__lr_scheduler = lr_scheduler

        training_sampler = self.build_sampler(training_data, 'random')
        # TODO: Configurable
        training_loader = self.build_loader(
            training_data,
            sampler=training_sampler,
            batch_size=self.cfg.training.batch_size,
            num_workers=num_workers,
            drop_last=True)

        if validation_data:
            validation_sampler = self.build_sampler(
                validation_data, 'sequential', limit_number_of_volumes=None)
            batch_sampler = BatchSampler(
                validation_sampler,
                batch_size=8 * self.cfg.training.batch_size,
                drop_last=False)
            # TODO: Batch size can be much larger, perhaps have a different batch size during evaluation.
            validation_loader = self.build_loader(
                validation_data,
                batch_sampler=batch_sampler,
                num_workers=num_workers,
            )
        else:
            validation_loader = None

        self.model = self.model.to(self.device)

        # Optimizer
        self.__optimizer.zero_grad()  # type: ignore

        # Mixed precision setup. This requires the model to be on the gpu.
        extra_checkpointing = {}
        if self.mixed_precision > 0:
            opt_level = f'O{self.mixed_precision}'
            self.logger.info(f'Using apex level {opt_level}.')
            self.model, self.__optimizer = amp.initialize(self.model,
                                                          self.__optimizer,
                                                          opt_level=opt_level)
            extra_checkpointing['amp'] = amp
            extra_checkpointing['opt_level'] = opt_level
            git_hash = direct.utils.git_hash()
            extra_checkpointing['__author__'] = git_hash if git_hash else 'N/A'
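            # (Apex opt levels range from "O0", pure FP32, to "O3", pure FP16;
            # "O1" patches selected ops to run in FP16 and is the usual choice.)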

        self.checkpointer = Checkpointer(
            self.model,
            experiment_directory,
            save_to_disk=communication.is_main_process(),
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            **extra_checkpointing)

        # Load checkpoint
        start_iter = 0
        if resume:
            self.logger.info('Attempting to resume...')
            # This changes the model in-place.
            checkpoint = self.checkpointer.load(
                iteration='latest',
                checkpointable_objects=['amp']
                if self.mixed_precision > 0 else [])
            if not checkpoint:
                self.logger.info('No checkpoint found. Starting from scratch.')
            else:
                start_iter = checkpoint['iteration'] + 1
                self.logger.info(f'Starting from iteration: {start_iter}.')

            if '__author__' in checkpoint:
                self.logger.info(
                    f"Git hash of checkpoint: {checkpoint['__author__']}")
                if checkpoint['__author__'] != direct.utils.git_hash():
                    self.logger.warning(
                        f"Current git hash {direct.utils.git_hash()} is different from the one "
                        f"this checkpoint is saved with ({checkpoint['__author__']}. This can be fine, "
                        f"but beware that this can be a source of confusion.")

            if '__datetime__' in checkpoint:
                self.logger.info(
                    f"Checkpoint created at: {checkpoint['__datetime__']}")
            if 'opt_level' in checkpoint:
                requested_opt_level = f'O{self.mixed_precision}'
                if checkpoint['opt_level'] != requested_opt_level:
                    self.logger.warning(
                        "Mixed precision opt-levels do not match. "
                        f"Requested {requested_opt_level} got {checkpoint['opt_level']} from checkpoint. "
                        "This will almost surely lead to performance degradation."
                    )

        self.logger.info(f'World size: {communication.get_world_size()}.')
        if communication.get_world_size() > 1:
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[communication.get_local_rank()],
                broadcast_buffers=False)

        # World size > 1 if distributed mode, else allow a DataParallel fallback, can be convenient for debugging.
        elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
            self.model = DataParallel(self.model)

        self.__writers = ([
            JSONWriter(experiment_directory / 'metrics.json'),
            CommonMetricPrinter(self.cfg.training.num_iterations),
            TensorboardWriter(experiment_directory / 'tensorboard')
        ] if communication.is_main_process() else [])

        with EventStorage(start_iter):
            self.training_loop(training_loader, start_iter, validation_loader)

        self.logger.info('Training completed.')

    @abstractmethod
    def evaluate(self, *args, **kwargs):  # noqa
        pass

    @abstractmethod
    def process_output(self, *args, **kwargs):  # noqa
        # Typically use this to scale data back to the original range.
        pass

    def log_process(self, idx, total):
        if idx % max(total // 5, 1) == 0 or total == (idx + 1):
            self.logger.info(f'Progress: {(idx + 1) / total * 100:.2f}%.')

    def write_to_logs(self):
        if self.__writers is not None:
            for writer in self.__writers:
                writer.write()

    def __bind_sigint_signal(self):
        """Bind SIGINT signal to handle preemption or other kill of the process."""
        def raise_process_killed_error(signal_id, _):
            """Raise the ProcessKilledError."""
            self.logger.info(
                f'Received {signal.Signals(signal_id).name}. Shutting down...')
            raise ProcessKilledException(signal_id,
                                         signal.Signals(signal_id).name)

        signal.signal(signalnum=signal.SIGINT,
                      handler=raise_process_killed_error)
Example #3
class Engine(ABC, DataDimensionality):
    def __init__(
        self,
        cfg: BaseConfig,
        model: nn.Module,
        device: int,
        forward_operator: Optional[Callable] = None,
        backward_operator: Optional[Callable] = None,
        mixed_precision: bool = False,
        **models: Dict[str, nn.Module],
    ):
        self.logger = logging.getLogger(type(self).__name__)

        self.cfg = cfg
        self.model = model
        self.models = models
        self.device = device

        # Operators can be useful in some operations
        self.forward_operator = forward_operator
        self.backward_operator = backward_operator

        self.mixed_precision = mixed_precision
        self.checkpointer = None

        self.__optimizer = None
        self.__lr_scheduler = None
        self._scaler = GradScaler(enabled=self.mixed_precision)
        self.__writers = None
        self.__bind_sigint_signal()

        DataDimensionality.__init__(self)

    @abstractmethod
    def build_loss(self) -> Dict:
        pass

    @staticmethod
    def _build_function_class(functions_list, root_module, postfix) -> Dict:
        if not functions_list:
            return {}

        # _postfix is added as only keys containing loss, metric or reg are logged.
        functions_dict = {
            curr_func.split("(")[0]
            + f"_{postfix}": str_to_class(root_module, curr_func)
            for curr_func in functions_list
        }
        return functions_dict
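
    # E.g. _build_function_class(["ssim"], "direct.functionals", "metric") would,
    # assuming `ssim` exists in direct.functionals, return {"ssim_metric": <callable>}.
    # The split on "(" suggests entries may carry constructor arguments,
    # e.g. "SSIMLoss(...)", with the parenthesized part stripped from the key.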

    def build_metrics(self, metrics_list) -> Dict:
        return self._build_function_class(metrics_list, "direct.functionals", "metric")

    def build_regularizers(self, regularizers_list) -> Dict:
        return self._build_function_class(
            regularizers_list, "direct.functionals", "reg"
        )

    @abstractmethod
    def _do_iteration(self, *args, **kwargs) -> namedtuple:
        """
        This is a placeholder for the iteration function. This needs to perform the backward pass.
        If using mixed-precision you need to implement `autocast` as well in this function.
        It is recommended you raise an error if `self.mixed_precision` is true but mixed precision is not available.
        """
        pass
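
    # Sketch of the expected return value, assuming a namedtuple along the lines of
    #
    #   DoIterationOutput = namedtuple(
    #       "DoIterationOutput", ["output_image", "data_dict"])
    #
    # so that `training_loop` can read `iteration_output.output_image` and
    # `iteration_output.data_dict` (field names taken from the usage below).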

    @torch.no_grad()
    def predict(
        self,
        dataset: Dataset,
        experiment_directory: pathlib.Path,
        checkpoint_number: int = -1,
        num_workers: int = 6,
    ) -> np.ndarray:
        self.logger.info(f"Predicting...")
        torch.cuda.empty_cache()
        self.ndim = dataset.ndim
        self.logger.info(f"Data dimensionality: {self.ndim}.")

        self.checkpointer = Checkpointer(
            self.model, experiment_directory, save_to_disk=False, **self.models
        )

        # Do not load again if the checkpoint is already loaded.
        if self.checkpointer.checkpoint_loaded != checkpoint_number:
            self.checkpointer.load(
                iteration=checkpoint_number, checkpointable_objects=None
            )

        batch_sampler = self.build_batch_sampler(
            dataset,
            batch_size=self.cfg.validation.batch_size,
            sampler_type="sequential",
            limit_number_of_volumes=None,
        )
        # TODO: Batch size can be much larger, perhaps have a different batch size during evaluation.
        data_loader = self.build_loader(
            dataset, batch_sampler=batch_sampler, num_workers=num_workers
        )
        loss, output = self.evaluate(
            data_loader, loss_fns=None, crop=None, is_validation_process=False
        )

        return output

    @staticmethod
    def build_loader(
        dataset: Dataset,
        batch_sampler: Optional[Sampler] = None,
        num_workers: int = 6,
    ) -> DataLoader:
        # TODO(jt): Custom memory pinning.
        loader = DataLoader(
            dataset=dataset,
            sampler=None,
            batch_size=1,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=named_collate,
            drop_last=False,
            shuffle=False,
            pin_memory=False,  # This can do strange things, and needs a custom implementation.
        )

        return loader

    def build_validation_loaders(self, validation_data, num_workers=0):
        for idx, curr_validation_data in enumerate(validation_data):
            text_dataset_description = curr_validation_data.text_description
            self.logger.info(
                f"Building dataloader for dataset: {text_dataset_description}."
            )
            curr_batch_sampler = self.build_batch_sampler(
                curr_validation_data,
                batch_size=self.cfg.validation.batch_size,
                sampler_type="sequential",
                limit_number_of_volumes=None,
            )
            yield (
                text_dataset_description,
                self.build_loader(
                    curr_validation_data,
                    batch_sampler=curr_batch_sampler,
                    num_workers=num_workers,
                ),
            )
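
    # Note: this is a generator, so each (description, loader) pair is built
    # lazily; `validation_loop` re-creates the loaders every time it runs.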

    @staticmethod
    def build_batch_sampler(
        dataset: Union[Dataset, List[Dataset]],
        batch_size: int,
        sampler_type: str,
        **kwargs,
    ) -> Sampler:
        if sampler_type == "random":
            if not isinstance(dataset, list) or any(
                not isinstance(_, Dataset) for _ in dataset
            ):
                raise ValueError(
                    "A random sampler requires a list of datasets as input."
                )
            batch_sampler = ConcatDatasetBatchSampler(
                datasets=dataset, batch_size=batch_size
            )
        elif sampler_type == "sequential":
            sampler = direct.data.samplers.DistributedSequentialSampler(
                dataset, **kwargs
            )
            batch_sampler = direct.data.samplers.BatchVolumeSampler(
                sampler,
                batch_size=batch_size,
            )
        else:
            raise ValueError(f"Sampler type {sampler_type} not supported.")

        return batch_sampler

    def training_loop(
        self,
        training_datasets: List,  # TODO(jt): Improve typing
        start_iter: int,
        validation_datasets: Optional[List] = None,
        experiment_directory: Optional[pathlib.Path] = None,
        num_workers: int = 6,
        start_with_validation: bool = False,
    ):
        self.logger.info(f"Local rank: {communication.get_local_rank()}.")
        self.models_training_mode()

        loss_fns = self.build_loss()
        metric_fns = self.build_metrics(self.cfg.training.metrics)
        regularizer_fns = self.build_regularizers(self.cfg.training.regularizers)
        storage = get_event_storage()

        self.ndim = training_datasets[0].ndim
        self.logger.info(f"Data dimensionality: {self.ndim}.")

        training_data = ConcatDataset(training_datasets)

        self.logger.info(f"Concatenated dataset length: {len(training_data)}.")
        self.logger.info(
            f"Building batch sampler for training set with batch size {self.cfg.training.batch_size}."
        )

        training_sampler = self.build_batch_sampler(
            training_datasets, self.cfg.training.batch_size, "random"
        )
        data_loader = self.build_loader(
            training_data,
            batch_sampler=training_sampler,
            num_workers=num_workers,
        )

        # Convenient shorthand
        validation_func = functools.partial(
            self.validation_loop,
            validation_datasets,
            loss_fns,
            experiment_directory,
            num_workers=num_workers,
        )

        total_iter = self.cfg.training.num_iterations  # noqa
        fail_counter = 0
        for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
            data = AddNames()(data)
            if iter_idx == 0:
                self.log_first_training_example_and_model(data)

            if start_with_validation and iter_idx == start_iter:
                self.logger.info(f"Starting with validation at iteration: {iter_idx}.")
                validation_func(iter_idx)
            try:
                iteration_output = self._do_iteration(
                    data, loss_fns, regularizer_fns=regularizer_fns
                )
                output = iteration_output.output_image
                loss_dict = iteration_output.data_dict
            except (ProcessKilledException, TrainingException) as e:
                # If the process is killed, a checkpoint is saved at iter_idx, the current state,
                # so the computation can restart from the last iteration.
                self.logger.exception(f"Exiting with exception: {e}.")
                self.checkpoint_and_write_to_logs(iter_idx)
                sys.exit(-1)
            except RuntimeError as e:
                # PyTorch signals OOM with a RuntimeError whose message contains
                # "out of memory"; matching on the string is brittle but standard.
                if "out of memory" in str(e):
                    if fail_counter == 3:
                        self.checkpoint_and_write_to_logs(iter_idx)
                        raise TrainingException(
                            f"OOM, could not recover after 3 tries: {e}."
                        )
                    fail_counter += 1
                    self.logger.info(
                        f"OOM Error: {e}. Skipping batch. Retry {fail_counter}/3."
                    )
                    self.__optimizer.zero_grad()
                    gc.collect()
                    torch.cuda.empty_cache()
                    continue
                self.checkpoint_and_write_to_logs(iter_idx)
                self.logger.info(f"Cannot recover from exception {e}. Exiting.")
                raise

            if fail_counter > 0:
                self.logger.info(f"Recovered from OOM, skipped batch.")
            fail_counter = 0
            # Gradient accumulation
            if (iter_idx + 1) % self.cfg.training.gradient_steps == 0:  # type: ignore
                if self.cfg.training.gradient_steps > 1:  # type: ignore
                    for parameter in self.model.parameters():
                        if parameter.grad is not None:
                            # In-place division
                            parameter.grad.div_(self.cfg.training.gradient_steps)  # type: ignore
                if self.cfg.training.gradient_clipping > 0.0:  # type: ignore
                    self._scaler.unscale_(self.__optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.cfg.training.gradient_clipping
                    )

                # Gradient norm
                if self.cfg.training.gradient_debug:  # type: ignore
                    warnings.warn(
                        "Gradient debug is set. This will affect training performance; only use it for debugging. "
                        "This message will only be displayed once."
                    )
                    parameters = list(
                        filter(lambda p: p.grad is not None, self.model.parameters())
                    )
                    gradient_norm = sum(
                        parameter.grad.data.norm(2) ** 2 for parameter in parameters
                    ).sqrt()  # type: ignore
                    storage.add_scalar("train/gradient_norm", gradient_norm)

                # Same as self.__optimizer.step() for mixed precision.
                self._scaler.step(self.__optimizer)
                # Updates the scale for next iteration.
                self._scaler.update()

            # Incorrect inference by mypy and pyflake
            self.__lr_scheduler.step()  # type: ignore # noqa
            storage.add_scalar(
                "lr", self.__optimizer.param_groups[0]["lr"], smoothing_hint=False
            )

            self.__optimizer.zero_grad()  # type: ignore

            # Reduce the loss over all devices
            loss_dict_reduced = communication.reduce_tensor_dict(loss_dict)
            loss_reduced = sum(loss_dict_reduced.values())

            metrics_dict = evaluate_dict(
                metric_fns,
                T.modulus_if_complex(output.detach()).rename(None),
                data["target"].rename(None).detach().to(self.device),
                reduction="mean",
            )
            metrics_dict_reduced = (
                communication.reduce_tensor_dict(metrics_dict) if metrics_dict else {}
            )
            storage.add_scalars(
                loss=loss_reduced, **loss_dict_reduced, **metrics_dict_reduced
            )

            self.checkpoint_model_at_interval(iter_idx, total_iter)
            self.write_to_logs_at_interval(iter_idx, total_iter)
            self.validate_model_at_interval(validation_func, iter_idx, total_iter)

            storage.step()

    def validate_model_at_interval(self, func, iter_idx, total_iter):
        if iter_idx >= 5:  # Skip validation during the first few iterations.
            if (
                iter_idx % self.cfg.training.validation_steps == 0
                or (iter_idx + 1) == total_iter
            ):
                func(iter_idx)
                self.write_to_logs()

    def checkpoint_model_at_interval(self, iter_idx, total_iter):
        if iter_idx >= 5:
            if (
                iter_idx % self.cfg.training.checkpointer.checkpoint_steps == 0
                or (iter_idx + 1) == total_iter
            ):
                self.logger.info(f"Checkpointing at iteration {iter_idx}.")
                self.checkpointer.save(iter_idx)

    def write_to_logs_at_interval(self, iter_idx, total_iter):
        if iter_idx >= 5:
            # Log every 20 iterations, or at a validation step or at the end of training.
            if (
                iter_idx % 20 == 0
                or iter_idx % self.cfg.training.validation_steps == 0
                or (iter_idx + 1) == total_iter
            ):
                self.write_to_logs()

    def checkpoint_and_write_to_logs(self, iter_idx):
        if iter_idx >= 5:
            self.checkpointer.save(iter_idx)  # Save checkpoint at kill. # noqa
        self.write_to_logs()

    def validation_loop(
        self,
        validation_datasets,
        loss_fns,
        experiment_directory,
        iter_idx,
        num_workers: int = 6,
    ):
        if not validation_datasets:
            return

        storage = get_event_storage()

        data_loaders = self.build_validation_loaders(
            validation_data=validation_datasets,
            num_workers=num_workers,
        )
        for curr_dataset_name, curr_data_loader in data_loaders:
            self.logger.info(f"Evaluating: {curr_dataset_name}...")
            (
                curr_loss_dict,
                curr_metrics_per_case,
                visualize_slices,
                visualize_target,
            ) = self.evaluate(
                curr_data_loader,
                loss_fns,
                is_validation_process=True,
            )

            if experiment_directory:
                json_output_fn = (
                    experiment_directory
                    / f"metrics_val_{curr_dataset_name}_{iter_idx}.json"
                )
                json_output_fn.parent.mkdir(
                    exist_ok=True, parents=True
                )  # A "/" in the dataset name creates a subfolder; make sure it exists.
                if communication.is_main_process():
                    write_json(
                        json_output_fn,
                        curr_metrics_per_case,
                    )
                self.logger.info(f"Wrote per image logs to: {json_output_fn}.")

            # The metrics dict still needs to be reduced, as it holds values *per* case.
            curr_metric_dict = reduce_list_of_dicts(
                list(curr_metrics_per_case.values()), mode="average"
            )

            key_prefix = (
                "val/" if not curr_dataset_name else f"val/{curr_dataset_name}/"
            )
            loss_reduced = sum(curr_loss_dict.values())
            storage.add_scalars(
                **{key_prefix + "loss": loss_reduced},
                **{
                    **prefix_dict_keys(curr_metric_dict, key_prefix),
                    **prefix_dict_keys(curr_loss_dict, key_prefix),
                },
                smoothing_hint=False,
            )
            visualize_slices = self.process_slices_for_visualization(
                visualize_slices, visualize_target
            )
            storage.add_image(f"{key_prefix}prediction", visualize_slices)

            # Only write the target images during the first regular validation;
            # the targets do not change between validation rounds.
            if iter_idx // self.cfg.training.validation_steps == 1:
                visualize_target = make_grid(
                    crop_to_largest(visualize_target, pad_value=0),
                    nrow=self.cfg.logging.tensorboard.num_images,
                    scale_each=True,
                )
                storage.add_image(f"{key_prefix}target", visualize_target)

            self.logger.info(
                f"Done evaluation of {curr_dataset_name} at iteration {iter_idx}."
            )
        self.model.train()

    def process_slices_for_visualization(self, visualize_slices, visualize_target):
        # Log slices: compute the difference between prediction and target,
        # and normalize everything for visualization.
        difference_slices = [a - b for a, b in zip(visualize_slices, visualize_target)]
        # Map each difference to [0, 1], rendering zero error as mid-gray (0.5).
        difference_slices = [
            (d / np.abs(d).max()) * 0.5 + 0.5 for d in difference_slices
        ]
        visualize_slices = [normalize_image(image) for image in visualize_slices]

        # Visualize slices, and crop to the largest volume
        visualize_slices = make_grid(
            crop_to_largest(visualize_slices + difference_slices, pad_value=0),
            nrow=self.cfg.logging.tensorboard.num_images,
            scale_each=True,
        )
        return visualize_slices
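
    # Sketch of the difference normalization above: with m = max|d| over a
    # difference image d, the mapping d -> (d / m) * 0.5 + 0.5 sends [-m, m]
    # linearly onto [0, 1], so zero error is rendered as mid-gray (0.5).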

    def models_training_mode(self):
        self.model.train()
        for curr_model in self.models:
            self.models[curr_model].train()

    def models_validation_mode(self):
        self.model.eval()
        for curr_model in self.models:
            self.models[curr_model].eval()

    def models_to_device(self):
        self.model = self.model.to(self.device)
        for curr_model_name in self.models:
            self.models[curr_model_name] = self.models[curr_model_name].to(self.device)

    def train(
        self,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,  # noqa
        training_datasets: List[Dataset],
        experiment_directory: pathlib.Path,
        validation_datasets: Optional[List[Dataset]] = None,
        resume: bool = False,
        start_with_validation: bool = False,
        initialization: Optional[PathOrString] = None,
        num_workers: int = 6,
    ) -> None:
        self.logger.info("Starting training.")
        # These could be kept out of self, but that would require passing the
        # optimizer to training_loop().
        self.__optimizer = optimizer
        self.__lr_scheduler = lr_scheduler

        self.models_to_device()

        # Optimizer
        self.__optimizer.zero_grad()  # type: ignore

        # Extra metadata to store alongside checkpoints.
        git_hash = direct.utils.git_hash()
        extra_checkpointing = {
            "__author__": git_hash if git_hash else "N/A",
            "__version__": direct.__version__,
            "__mixed_precision__": self.mixed_precision,
        }

        # Mixed precision requires the model to be on the GPU.
        if self.mixed_precision:
            # TODO(jt): Check if on GPU.
            self.logger.info("Using mixed precision training.")

        self.checkpointer = Checkpointer(
            self.model,
            experiment_directory,
            save_to_disk=communication.is_main_process(),
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            scaler=self._scaler,
            **self.models,
            **extra_checkpointing,
        )

        # Load checkpoint
        start_iter = 0
        checkpoint = {}

        if resume:
            self.logger.info("Attempting to resume...")
            # This changes the model inplace
            checkpoint = self.checkpointer.load(iteration="latest")
            if not checkpoint:
                self.logger.info("No checkpoint found. Starting from scratch.")
            else:
                start_iter = checkpoint["iteration"] + 1
                self.logger.info(f"Starting from iteration: {start_iter}.")

        if start_iter > 0 and initialization:
            self.logger.warning(
                f"Initialization checkpoint set to {initialization},"
                f" but model will resume training from previous checkpoint. Initialization ignored."
            )
        elif initialization:
            self.logger.info(f"Initializing from {initialization}...")
            self.checkpointer.load_models_from_file(initialization)
            start_with_validation = True
            self.logger.info(f"Setting start_with_validation to True.")

        if "__version__" in checkpoint:
            self.logger.info(
                f"DIRECT version of checkpoint: {checkpoint['__version__']}."
            )
            if checkpoint["__version__"] != direct.__version__:
                self.logger.warning(
                    f"Current DIRECT version {direct.__version__} is different from the one "
                    f"this checkpoint is saved with: {checkpoint['__version__']}. This can be fine, "
                    f"but beware that this can be a source of confusion."
                )

        if "__author__" in checkpoint:
            self.logger.info(f"Git hash of checkpoint: {checkpoint['__author__']}.")
            if checkpoint["__author__"] != direct.utils.git_hash():
                self.logger.warning(
                    f"Current git hash {direct.utils.git_hash()} is different from the one "
                    f"this checkpoint is saved with: {checkpoint['__author__']}. This can be fine, "
                    f"but beware that this can be a source of confusion."
                )

        if "__datetime__" in checkpoint:
            self.logger.info(f"Checkpoint created at: {checkpoint['__datetime__']}.")

        if "__mixed_precision__" in checkpoint:
            if (not self.mixed_precision) and checkpoint["__mixed_precision__"]:
                self.logger.warning(
                    "Mixed precision training is not enabled, yet the saved checkpoint "
                    "requests it. Will now enable mixed precision."
                )
                self.mixed_precision = True
            elif not checkpoint["__mixed_precision__"] and self.mixed_precision:
                self.logger.warning(
                    "Mixed precision levels of training and loaded checkpoint do not match. "
                    "Requested mixed precision, but the checkpoint was saved without. "
                    "This will almost surely lead to performance degradation."
                )
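
        # For reference, a loaded checkpoint dict may carry metadata such as
        # (illustrative values):
        #
        #     {"iteration": 12000, "__version__": direct.__version__,
        #      "__author__": "<git hash>", "__mixed_precision__": True,
        #      "__datetime__": "<timestamp>"}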

        if start_with_validation:
            self.logger.info(f"Requested to start with validation.")

        self.logger.info(f"World size: {communication.get_world_size()}.")
        self.logger.info(f"Device count: {torch.cuda.device_count()}.")
        if communication.get_world_size() > 1:
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[communication.get_local_rank()],
                broadcast_buffers=False,
            )

        # If not running distributed but multiple GPUs are available, fall back
        # to DataParallel; this can be convenient for debugging.
        elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
            self.model = DataParallel(self.model)

        self.__writers = (
            [
                JSONWriter(experiment_directory / "metrics.json"),
                CommonMetricPrinter(self.cfg.training.num_iterations),
                TensorboardWriter(experiment_directory / "tensorboard"),
            ]
            if communication.is_main_process()
            else []
        )

        with EventStorage(start_iter):
            self.training_loop(
                training_datasets,
                start_iter,
                validation_datasets,
                experiment_directory=experiment_directory,
                num_workers=num_workers,
                start_with_validation=start_with_validation,
            )

        self.logger.info("Training completed.")

    @abstractmethod
    def evaluate(self, *args, **kwargs):  # noqa
        pass

    @staticmethod
    def view_as_complex(data):
        """
        View data as complex named tensor.

        Parameters
        ----------
        data : torch.Tensor
            Named tensor with non-complex dtype and final axis named complex and shape 2.

        Returns
        -------
        torch.Tensor: Complex-valued tensor `data`.
        """
        names = data.names
        if names[-1] != "complex":
            raise ValueError(
                "Not a complex tensor. The complex axis needs to be last and be "
                f"named `complex`. Got {data.names}."
            )

        return torch.view_as_complex(data.rename(None)).refine_names(*names[:-1])

    @staticmethod
    def view_as_real(data):
        """
        View complex named tensor data as real tensor.

        Parameters
        ----------
        data : torch.Tensor
            Named tensor with complex dtype

        Returns
        -------
        torch.Tensor: Real-valued named tensor `data`, where last axis is of shape 2 denoting the complex axis.
        """

        if not data.is_complex():
            raise ValueError(f"Not a complex tensor. Got {data.dtype}.")

        names = list(data.names) + ["complex"]
        return torch.view_as_real(data.rename(None)).refine_names(*names)
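
        # Round-trip sketch for the two views (2D example):
        #
        #     x = torch.rand(1, 32, 32, 2,
        #                    names=("batch", "height", "width", "complex"))
        #     z = Engine.view_as_complex(x)  # complex dtype, "complex" name dropped
        #     y = Engine.view_as_real(z)     # real view again, last axis of size 2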

    @abstractmethod
    def process_output(self, *args, **kwargs):  # noqa
        # Typically use this to scale data back to the original range.
        pass

    def log_process(self, idx, total):
        # Log roughly every 10% of progress; max(..., 1) guards against total < 10.
        if idx % max(total // 10, 1) == 0 or total == (idx + 1):
            self.logger.info(f"Progress: {(idx + 1) / total * 100:.2f}%.")

    def log_first_training_example_and_model(self, data):
        storage = get_event_storage()
        self.logger.info(
            f"First case: slice_no: {data['slice_no'][0]}, filename: {data['filename'][0]}."
        )

        # TODO(jt): Cleaner, loop over types of images
        first_sampling_mask = data["sampling_mask"][0][0]
        first_target = data["target"][0]

        if self.ndim == 3:
            first_sampling_mask = first_sampling_mask[0]
            num_slices = first_target.shape[first_target.names.index("slice")]
            first_target = first_target[num_slices // 2]
        elif self.ndim > 3:
            raise NotImplementedError

        storage.add_image(
            "train/mask", first_sampling_mask[..., 0].rename(None).unsqueeze(0)
        )
        storage.add_image(
            "train/target",
            normalize_image(first_target.rename(None).unsqueeze(0)),
        )

        if "initial_image" in data:
            storage.add_image(
                "train/initial_image",
                normalize_image(
                    T.modulus(data["initial_image"][0]).rename(None).unsqueeze(0)
                ),
            )

        # TODO: Add graph

        self.write_to_logs()

    def write_to_logs(self):
        if self.__writers is not None:
            for writer in self.__writers:
                writer.write()

    def __bind_sigint_signal(self):
        """Bind SIGINT signal to handle preemption or other kill of the process."""

        def raise_process_killed_error(signal_id, _):
            """Raise the ProcessKilledError."""
            self.logger.info(
                f"Received {signal.Signals(signal_id).name}. Shutting down..."
            )
            raise ProcessKilledException(signal_id, signal.Signals(signal_id).name)

        signal.signal(signalnum=signal.SIGINT, handler=raise_process_killed_error)
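
    # Handling sketch (hypothetical catch site; the actual one lives in the
    # training entry point): the handler above raises ProcessKilledException,
    # which a caller can catch to checkpoint before shutting down, e.g.
    #
    #     try:
    #         engine.train(...)
    #     except ProcessKilledException:
    #         engine.checkpoint_and_write_to_logs(iter_idx)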