class InnerEyeLightning(LightningModule):
    """
    The base class for all InnerEye models for training in PyTorch Lightning. The base class handles all shared
    operations like choosing the optimizer and learning rate schedule, keeping track of IO performance (loading times),
    and IO to files.
    """
    def __init__(self, config: DeepLearningConfig, *args: Any,
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.outputs_folder = config.outputs_folder
        self.model: DeviceAwareModule = DeviceAwareModule()
        # These two will be set later in set_optimizer_and_scheduler.
        # The ddp_spawn accelerator only works if the model configuration object is
        # not stored in here. Hence, need to do operations that require a full config
        # in a way that does not require storing the config.
        self.optimizer: Optional[Optimizer] = None
        self.l_rate_scheduler: Optional[_LRScheduler] = None
        self.cross_validation_split_index = config.cross_validation_split_index
        self.effective_random_seed = config.get_effective_random_seed()
        # Timers for monitoring data loading time
        self.train_timers = EpochTimers()
        self.val_timers = EpochTimers()
        # This should be re-assigned on the outside, to a logger that is hooked up with the Trainer object.
        self.storing_logger = StoringLogger()
        # This will be initialized correctly in epoch_start
        self.random_state: Optional[RandomStateSnapshot] = None
        # training loggers
        self.train_metrics_folder = self.outputs_folder / ModelExecutionMode.TRAIN.value
        self.val_metrics_folder = self.outputs_folder / ModelExecutionMode.VAL.value
        fixed_logger_columns = {
            LoggingColumns.CrossValidationSplitIndex.value:
            config.cross_validation_split_index
        }
        self.train_epoch_metrics_logger = DataframeLogger(
            self.train_metrics_folder / EPOCH_METRICS_FILE_NAME,
            fixed_columns=fixed_logger_columns)
        self.val_epoch_metrics_logger = DataframeLogger(
            self.val_metrics_folder / EPOCH_METRICS_FILE_NAME,
            fixed_columns=fixed_logger_columns)
        # Fields to store diagnostics for unit testing
        self.train_diagnostics: List[Any] = []
        self.val_diagnostics: List[Any] = []
        # Stores information about the checkpoint that created this model, if any.
        self.checkpoint_loading_message = ""

    def set_optimizer_and_scheduler(self, config: DeepLearningConfig) -> None:
        self.optimizer = model_util.create_optimizer(config,
                                                     self.model.parameters())
        self.l_rate_scheduler = SchedulerWithWarmUp(config, self.optimizer)

    def configure_optimizers(
            self) -> Tuple[List[Optimizer], List[_LRScheduler]]:
        return [self.optimizer], [self.l_rate_scheduler]  # type: ignore

    def close_all_loggers(self) -> None:
        """
        Flushes all logger objects that the present object holds.
        """
        self.train_epoch_metrics_logger.flush()
        self.val_epoch_metrics_logger.flush()

    def on_train_epoch_start(self) -> None:
        self.train_timers.reset()

    def training_epoch_end(self, outputs: List[Any]) -> None:
        self.training_or_validation_epoch_end(is_training=True)

    def on_validation_epoch_start(self) -> None:
        """
        Stores the state of all random number generators, and resets them all to a fixed seed. This is done to ensure
        that any randomization when loading validation data is consistent during training. In particular, this ensures
        that drawing random patches for segmentation model training yields a validation set that does not fluctuate.
        """
        self.val_timers.reset()
        # In Lightning, the validation epoch runs "inside" the training epoch. If we get here, training
        # is done for this epoch, even though the on_train_epoch_end hook has not yet been called.
        self.train_timers.epoch_end()
        # Store the random number generator state, so that the next training epoch starts from here.
        self.random_state = RandomStateSnapshot.snapshot_random_state()
        # reset the random state for validation, so that we get consistent behaviour when drawing random patches
        # when validating segmentation models.
        seed = self.effective_random_seed
        set_random_seed(seed, "Validation")

    def on_validation_epoch_end(self) -> None:
        self.val_timers.epoch_end()

    def validation_epoch_end(self, outputs: List[Any]) -> None:
        """
        Resets the random number generator state to what it was before the current validation epoch started.
        :param outputs: The list of outputs from the individual validation minibatches.
        """
        # reset the random state for training, so that we continue from where we were before the validation step.
        assert self.random_state is not None
        self.random_state.restore_random_state()
        self.training_or_validation_epoch_end(is_training=False)

    @rank_zero_only
    def on_epoch_end(self) -> None:
        """
        This hook is called once per epoch, before on_train_epoch_end. Use it to write out all the metrics
        that have been accumulated in the StoringLogger in the previous epoch.
        """
        self.read_epoch_results_from_logger_and_store(
            epoch=self.current_epoch - 1)

    @rank_zero_only
    def on_train_end(self) -> None:
        """
        This hook is called at the very end of training. Use it to write the very last set of training and
        validation metrics from the StoringLogger to disk.
        """
        self.read_epoch_results_from_logger_and_store(epoch=self.current_epoch)

    def read_epoch_results_from_logger_and_store(self, epoch: int) -> None:
        """
        Reads the metrics for the given epoch from the StoringLogger, and writes them to disk, broken down into
        training and validation metrics.
        :param epoch: The epoch for which to read and store the results.
        """
        if epoch >= 0:
            if epoch in self.storing_logger.results:
                for is_training, prefix in [(True, TRAIN_PREFIX),
                                            (False, VALIDATION_PREFIX)]:
                    metrics = self.storing_logger.extract_by_prefix(
                        epoch, prefix)
                    self.store_epoch_results(metrics, epoch, is_training)
            else:
                print(f"Skipping, no results for {epoch}")

    @rank_zero_only
    def training_or_validation_epoch_end(self, is_training: bool) -> None:
        """
        This is a hook called at the end of a training or validation epoch. In here, we can still write
        metrics to a logger.
        :param is_training: If True, this is called at the end of a training epoch. If False, this is at the
        end of a validation epoch.
        """
        if not is_training:
            # In validation epochs, mark that it has been completed. Training epochs are marked completed already
            # at the start of the validation epoch.
            self.val_timers.epoch_end()
            # Write all IO stats here, so that the order on the console is Train start, train end, val start, val end.
            self.write_and_log_epoch_time(is_training=True)
            self.write_and_log_epoch_time(is_training=False)

    def write_and_log_epoch_time(self, is_training: bool) -> None:
        """
        Reads the IO timers for either the training or validation epoch, writes them to the console, and logs the
        time per epoch.
        :param is_training: If True, show and log the data for the training epoch. If False, use the data for the
        validation epoch.
        """
        timers = self.get_timers(is_training=is_training)
        epoch_time_seconds = timers.total_epoch_time
        status = "training" if is_training else "validation"
        logging.info(
            f"Epoch {self.current_epoch} {status} took {epoch_time_seconds:0.2f}sec, of which waiting for "
            f"data took {timers.total_load_time:0.2f} sec total.")
        if timers.num_load_time_exceeded > 0 and timers.should_warn_in_this_epoch:
            logging.warning(
                "The dataloaders were not fast enough to always supply the next batch in less than "
                f"{MAX_ITEM_LOAD_TIME_SEC}sec.")
            logging.warning(
                f"In this epoch, {timers.num_load_time_exceeded} out of {timers.num_batches} batches exceeded the load "
                f"time threshold. Total loading time for the slow batches was {timers.total_extra_load_time:0.2f}sec."
            )
        # This metric is only written at rank zero, and hence must not be synchronized across workers. If attempted,
        # training will get stuck.
        self.log_on_epoch(MetricType.SECONDS_PER_EPOCH,
                          epoch_time_seconds,
                          is_training=is_training,
                          sync_dist_override=False)

    def log_on_epoch(self,
                     name: Union[MetricType, str],
                     value: Any,
                     is_training: bool,
                     reduce_fx: Callable = torch.mean,
                     sync_dist_override: Optional[bool] = None,
                     sync_dist_op: Any = "mean") -> None:
        """
        Logs a metric to Pytorch Lightning with the on_epoch flag set. The metric will get a prefix indicating
        if it is a training or a validation metric. A custom reducer function can be provided.
        The method also ensures that the correct synchronization across nodes is used. If the value to log is a
        floating point, it is converted to a Tensor on the current device to enable synchronization.
        :param sync_dist_override: If not None, use this value for the sync_dist argument to self.log. If None,
        set it automatically depending on the use of DDP.
        :param name: The name of the metric to log
        :param value: The value of the metric. This can be a tensor, floating point value, or a Metric class.
        :param is_training: If true, give the metric a "train/" prefix, otherwise a "val/" prefix.
        :param reduce_fx: The reduce function to apply after synchronizing the tensors across GPUs.
        :param sync_dist_op: The reduce operation to use when synchronizing the tensors across GPUs. This must be
        a value recognized by sync_ddp: either None (which aggregates with 'sum'), or 'mean' or 'avg'.
        """
        metric_name = name if isinstance(name, str) else name.value
        if isinstance(value, numbers.Number):
            value = torch.tensor(value, dtype=torch.float, device=self.device)
        prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX
        sync_dist = self.use_ddp if sync_dist_override is None else sync_dist_override
        self.log(prefix + metric_name,
                 value,
                 sync_dist=sync_dist,
                 on_step=False,
                 on_epoch=True,
                 reduce_fx=reduce_fx,
                 sync_dist_op=sync_dist_op)

    def store_epoch_results(self, metrics: DictStrFloat, epoch: int,
                            is_training: bool) -> None:
        """
        Stores a set of metrics (key/value pairs) to a file logger. That file logger is either one that only holds
        training or only holds validation metrics.
        :param metrics: A dictionary with all the metrics to write, as key/value pairs.
        :param epoch: The epoch to which the metrics belong.
        :param is_training: If true, write the metrics to the logger for training metrics, if False, write to the logger
        for validation metrics.
        """
        file_logger = self.train_epoch_metrics_logger if is_training else self.val_epoch_metrics_logger
        store_epoch_metrics(metrics, epoch, file_logger=file_logger)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """
        This hook is called when loading a model from a checkpoint. It just prints out diagnostics about which epoch
        created the present checkpoint.
        :param checkpoint: The checkpoint dictionary loaded from disk.
        """
        keys = ['epoch', 'global_step']
        present_keys = [
            f"{key} = {checkpoint[key]}" for key in keys if key in checkpoint
        ]
        if present_keys:
            self.checkpoint_loading_message = f"Loading checkpoint that was created at ({', '.join(present_keys)})"
            logging.info(self.checkpoint_loading_message)

    def on_train_batch_start(self, batch: Any, batch_idx: int,
                             dataloader_idx: int) -> None:
        self.batch_start(batch_idx=batch_idx, is_training=True)

    def on_validation_batch_start(self, batch: Any, batch_idx: int,
                                  dataloader_idx: int) -> None:
        self.batch_start(batch_idx=batch_idx, is_training=False)

    def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int,
                           dataloader_idx: int) -> None:
        self.batch_end(is_training=True)

    def on_validation_batch_end(self, outputs: Any, batch: Any, batch_idx: int,
                                dataloader_idx: int) -> None:
        self.batch_end(is_training=False)

    def training_step(
            self,  # type: ignore
            sample: Dict[str, Any],
            batch_index: int) -> Any:
        return self.training_or_validation_step(sample,
                                                batch_index,
                                                is_training=True)

    def validation_step(
            self,  # type: ignore
            sample: Dict[str, Any],
            batch_index: int) -> Any:
        return self.training_or_validation_step(sample,
                                                batch_index,
                                                is_training=False)

    def training_or_validation_step(self, sample: Dict[str,
                                                       Any], batch_index: int,
                                    is_training: bool) -> Any:
        """
        This is the shared method that handles the training (when `is_training==True`) and validation steps
        (when `is_training==False`)
        :param sample: The minibatch of data that should be processed.
        :param batch_index: The index of the current minibatch.
        :param is_training: If true, this has been called from `training_step`, otherwise it has been called from
        `validation_step`.
        """
        raise NotImplementedError(
            "This method must be overwritten in a derived class.")

    @rank_zero_only
    def batch_start(self, batch_idx: int, is_training: bool) -> None:
        """
        Shared code to keep track of IO-related metrics when loading a minibatch. This is only done on rank zero.
        :param batch_idx: The index of the current minibatch.
        :param is_training: If true, this has been called from `on_train_batch_start`, otherwise it has been called from
        `on_validation_batch_start`.
        """
        timers = self.get_timers(is_training=is_training)
        message_prefix = f"Epoch {self.current_epoch} {'training' if is_training else 'validation'}"
        timers.batch_start(batch_index=batch_idx,
                           epoch=self.current_epoch,
                           message_prefix=message_prefix)

    @rank_zero_only
    def batch_end(self, is_training: bool) -> None:
        """
        Shared code to keep track of IO-related metrics when loading a minibatch.
        :param is_training: If true, this has been called from `on_train_batch_end`, otherwise it has been called from
        `on_validation_batch_end`.
        """
        timers = self.get_timers(is_training=is_training)
        batch_time = timers.batch_end()
        # This metric is only written at rank 0, and hence must not be synchronized. Trying to synchronize will
        # block training.
        self.log_on_epoch(MetricType.SECONDS_PER_BATCH,
                          batch_time,
                          is_training=is_training,
                          sync_dist_override=False)

    def get_timers(self, is_training: bool) -> EpochTimers:
        """
        Gets the object that holds all IO-related metrics and timers, for either the validation or the training epoch.
        """
        return self.train_timers if is_training else self.val_timers

    def reset_timers(self) -> None:
        """
        Resets all timers and counters for IO-related metrics, for both the validation and the training epoch.
        """
        self.train_timers.reset()
        self.val_timers.reset()

    def write_loss(self, is_training: bool, loss: torch.Tensor) -> None:
        """
        Writes the given loss value to Lightning, labelled either "val/loss" or "train/loss".
        If this comes from a training step, then also log the learning rate.
        :param loss: The loss value that should be logged.
        :param is_training: If True, the logged metric will be called "train/Loss". If False, the metric will
        be called "val/Loss"
        """
        self.log_on_epoch(MetricType.LOSS, loss, is_training)
        if is_training:
            learning_rate = self.trainer.lr_schedulers[0][
                'scheduler'].get_last_lr()[0]
            self.log_on_epoch(MetricType.LEARNING_RATE, learning_rate,
                              is_training)
Example #3
def test_storing_logger() -> None:
    """
    Test if the StoringLogger can correctly handle multiple metrics of the same name logged per epoch.
    """
    logger = StoringLogger()
    key1 = "key"
    key2 = "key2"
    value1 = 3.14
    value2 = 2.71
    value3 = 100.0
    assert value1 != value2
    epoch = 1
    # Add metrics for the same epoch in two calls, so that we test both the case where the epoch is already present
    # and the case where it is not.
    logger.log_metrics({"epoch": 1, key1: value1})
    logger.log_metrics({"epoch": 1, key2: value2})
    # All results for epoch 1 should be collated into a single dictionary
    assert logger.extract_by_prefix(epoch=epoch) == {
        key1: value1,
        key2: value2
    }
    # When updating a metric that already exists, the result should not be a float anymore but a list.
    logger.log_metrics({"epoch": epoch, key1: value3})
    assert logger.extract_by_prefix(epoch=epoch) == {
        key1: [value1, value3],
        key2: value2
    }
    # Add more metrics for key1, so that we also test the case that the results are already a list
    logger.log_metrics({"epoch": epoch, key1: value3})
    assert logger.extract_by_prefix(epoch=epoch) == {
        key1: [value1, value3, value3],
        key2: value2
    }
    # Add metrics that don't have an epoch key: This happens for example during testing with trainer.test
    other_metrics1 = {"foo": 1.0}
    other_metrics2 = {"foo": 2.0}
    logger.log_metrics(other_metrics1)
    logger.log_metrics(other_metrics2)
    assert logger.results_without_epoch == [other_metrics1, other_metrics2]
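For reference, the collation rules that this test asserts can be summarised in a small, self-contained stand-in; this is a toy illustration of the observed behaviour, not the actual StoringLogger implementation:

# Toy stand-in that only mirrors the behaviour asserted in test_storing_logger above.
from typing import Any, Dict, List


class ToyStoringLogger:
    def __init__(self) -> None:
        self.results: Dict[int, Dict[str, Any]] = {}
        self.results_without_epoch: List[Dict[str, Any]] = []

    def log_metrics(self, metrics: Dict[str, Any]) -> None:
        if "epoch" not in metrics:
            # Metrics without an "epoch" key (e.g. from trainer.test) are stored as-is.
            self.results_without_epoch.append(metrics)
            return
        epoch_results = self.results.setdefault(int(metrics["epoch"]), {})
        for key, value in metrics.items():
            if key == "epoch":
                continue
            if key not in epoch_results:
                epoch_results[key] = value
            elif isinstance(epoch_results[key], list):
                epoch_results[key].append(value)
            else:
                # A repeated key turns the stored scalar into a list of values.
                epoch_results[key] = [epoch_results[key], value]

    def extract_by_prefix(self, epoch: int, prefix: str = "") -> Dict[str, Any]:
        # Assumed to strip the prefix from the returned keys.
        return {key[len(prefix):]: value for key, value in self.results[epoch].items()
                if key.startswith(prefix)}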
Example #4
def create_lightning_trainer(container: LightningContainer,
                             resume_from_checkpoint: Optional[Path] = None,
                             num_nodes: int = 1,
                             **kwargs: Dict[str, Any]) -> \
        Tuple[Trainer, Optional[StoringLogger]]:
    """
    Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
    and loggers. That includes a diagnostic logger for use in unit tests, which is also returned as the second
    return value.
    :param container: The container with model and data.
    :param resume_from_checkpoint: If provided, training resumes from this checkpoint point.
    :param num_nodes: The number of nodes to use in distributed training.
    :param kwargs: Any additional keyword arguments will be passed to the constructor of Trainer.
    :return: A tuple [Trainer object, diagnostic logger]
    """
    # For now, stick with the legacy behaviour of always saving only the last epoch checkpoint. For large segmentation
    # models, this still appears to be the best way of choosing them because validation loss on the relatively small
    # training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but
    # not for the HeadAndNeck model.
    best_checkpoint_callback = ModelCheckpoint(
        dirpath=str(container.checkpoint_folder),
        # filename=BEST_CHECKPOINT_FILE_NAME,
        # monitor=f"{VALIDATION_PREFIX}{MetricType.LOSS.value}",
        # save_top_k=1,
        save_last=True)

    # Recovery checkpoints: {epoch} will turn into a string like "epoch=1"
    # Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs, keep the last
    # recovery_checkpoints_save_last_k.
    recovery_checkpoint_callback = InnerEyeRecoveryCheckpointCallback(
        container)

    num_gpus = container.num_gpus_per_node
    effective_num_gpus = num_gpus * num_nodes
    # Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of GPU memory).
    # For unit tests, only "ddp_spawn" works
    accelerator = "ddp" if effective_num_gpus > 1 else None
    if effective_num_gpus > 1:
        # Initialize the DDP plugin with find_unused_parameters=False by default. If it is True (PyTorch's default),
        # DDP prints out lengthy warnings about the performance impact of find_unused_parameters.
        plugins = [
            InnerEyeDDPPlugin(
                num_nodes=num_nodes,
                sync_batchnorm=True,
                find_unused_parameters=container.pl_find_unused_parameters)
        ]
    else:
        plugins = []
    logging.info(
        f"Using {num_gpus} GPUs per node with accelerator '{accelerator}'")
    tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder),
                                           name="Lightning",
                                           version="")
    loggers = [tensorboard_logger, AzureMLLogger()]
    storing_logger: Optional[StoringLogger]
    if isinstance(container, InnerEyeContainer):
        storing_logger = StoringLogger()
        loggers.append(storing_logger)
    else:
        storing_logger = None
    # Use 32bit precision when running on CPU. Otherwise, make it depend on use_mixed_precision flag.
    precision = 32 if num_gpus == 0 else 16 if container.use_mixed_precision else 32
    # The next two flags control the settings in torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark
    # https://pytorch.org/docs/stable/notes/randomness.html
    # For the classification models, we observed only a small performance deterioration (an increase of 10 sec on a
    # total training time of 22 min) when switching to deterministic mode.
    if container.pl_deterministic:
        deterministic = True
        benchmark = False
    else:
        deterministic = False
        benchmark = True
    # If the user provides additional callbacks via get_trainer_arguments (for custom containers),
    # they are added to the default callbacks here.
    callbacks = [best_checkpoint_callback, recovery_checkpoint_callback]
    if "callbacks" in kwargs:
        callbacks.append(kwargs.pop("callbacks"))  # type: ignore
    is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
    progress_bar_refresh_rate = container.pl_progress_bar_refresh_rate
    if progress_bar_refresh_rate is None and is_azureml_run:
        # When running in AzureML, the default progress bar clutters the output files with thousands of lines.
        progress_bar_refresh_rate = 50
        logging.info(
            f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. "
            f"To change, modify the pl_progress_bar_refresh_rate field of the container."
        )
    # Read out additional model-specific args here.
    # We probably want to keep essential ones like the number of GPUs and logging.
    trainer = Trainer(default_root_dir=str(container.outputs_folder),
                      deterministic=deterministic,
                      benchmark=benchmark,
                      accelerator=accelerator,
                      max_epochs=container.num_epochs,
                      num_sanity_val_steps=container.pl_num_sanity_val_steps,
                      callbacks=callbacks,
                      logger=loggers,
                      progress_bar_refresh_rate=progress_bar_refresh_rate,
                      num_nodes=num_nodes,
                      gpus=num_gpus,
                      precision=precision,
                      sync_batchnorm=True,
                      terminate_on_nan=container.detect_anomaly,
                      resume_from_checkpoint=str(resume_from_checkpoint)
                      if resume_from_checkpoint else None,
                      plugins=plugins,
                      **kwargs)
    return trainer, storing_logger
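A hedged usage sketch for the function above; the container class and its factory methods are placeholders, since how the LightningModule and the data are created depends on the concrete container:

# Hypothetical usage; MyContainer, create_model and get_data_module are assumptions for illustration.
container = MyContainer()  # some LightningContainer subclass
trainer, storing_logger = create_lightning_trainer(container,
                                                   resume_from_checkpoint=None,
                                                   num_nodes=1)
lightning_model = container.create_model()  # assumed factory method on the container
trainer.fit(lightning_model, datamodule=container.get_data_module())  # assumed data module accessor
if storing_logger is not None:
    # Only InnerEyeContainer instances get a StoringLogger attached (see above).
    print(sorted(storing_logger.results.keys()))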
Example #5
def create_lightning_trainer(config: ModelConfigBase,
                             resume_from_checkpoint: Optional[Path] = None,
                             num_nodes: int = 1) -> Tuple[Trainer, StoringLogger]:
    """
    Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
    and loggers. That includes a diagnostic logger for use in unit tests, which is also returned as the second
    return value.
    :param config: The model configuration.
    :param resume_from_checkpoint: If provided, training resumes from this checkpoint point.
    :param num_nodes: The number of nodes to use in distributed training.
    :return: A tuple [Trainer object, diagnostic logger]
    """
    # For now, stick with the legacy behaviour of always saving only the last epoch checkpoint. For large segmentation
    # models, this still appears to be the best way of choosing them because validation loss on the relatively small
    # training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but
    # not for the HeadAndNeck model.
    best_checkpoint_callback = ModelCheckpoint(dirpath=str(config.checkpoint_folder),
                                               # filename=BEST_CHECKPOINT_FILE_NAME,
                                               # monitor=f"{VALIDATION_PREFIX}{MetricType.LOSS.value}",
                                               # save_top_k=1,
                                               save_last=True)
    # Recovery checkpoints: {epoch} will turn into a string like "epoch=1"
    # Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs. Due to a bug in Lightning, this
    # will still write alternate files recovery.ckpt and recovery-v0.ckpt, which are cleaned up later in
    # cleanup_checkpoint_folder
    recovery_checkpoint_callback = ModelCheckpoint(dirpath=str(config.checkpoint_folder),
                                                   filename=RECOVERY_CHECKPOINT_FILE_NAME,
                                                   period=config.recovery_checkpoint_save_interval
                                                   )

    num_gpus = torch.cuda.device_count() if config.use_gpu else 0
    logging.info(f"Number of available GPUs: {num_gpus}")
    if config.max_num_gpus >= 0 and config.max_num_gpus < num_gpus:
        num_gpus = config.max_num_gpus
        logging.info(f"Restricting the number of GPUs to {num_gpus}")
    # Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of GPU memory).
    # For unit tests, only "ddp_spawn" works
    accelerator = "ddp" if num_gpus > 1 else None
    logging.info(f"Using {num_gpus} GPUs with accelerator '{accelerator}'")
    storing_logger = StoringLogger()
    tensorboard_logger = TensorBoardLogger(save_dir=str(config.logs_folder), name="Lightning", version="")
    loggers = [storing_logger, tensorboard_logger, AzureMLLogger()]
    # This leads to problems with run termination.
    # if not is_offline_run_context(RUN_CONTEXT):
    #     mlflow_logger = MLFlowLogger(experiment_name=RUN_CONTEXT.experiment.name,
    #                                  tracking_uri=RUN_CONTEXT.experiment.workspace.get_mlflow_tracking_uri())
    #     # The MLFlow logger needs to get its ID from the AzureML run context, otherwise there will be two sets of
    #     # results for each run, one from native AzureML and one from the MLFlow logger.
    #     mlflow_logger._run_id = RUN_CONTEXT.id
    #     loggers.append(mlflow_logger)
    # Use 32bit precision when running on CPU. Otherwise, make it depend on use_mixed_precision flag.
    precision = 32 if num_gpus == 0 else 16 if config.use_mixed_precision else 32
    # The next two flags control the settings in torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark
    # https://pytorch.org/docs/stable/notes/randomness.html
    # For the classification models, we observed only a small performance deterioration (an increase of 10 sec on a
    # total training time of 22 min) when switching to deterministic mode.
    if config.pl_deterministic:
        deterministic = True
        benchmark = False
    else:
        deterministic = False
        benchmark = True
    trainer = Trainer(default_root_dir=str(config.outputs_folder),
                      deterministic=deterministic,
                      benchmark=benchmark,
                      accelerator=accelerator,
                      max_epochs=config.num_epochs,
                      num_sanity_val_steps=config.pl_num_sanity_val_steps,
                      callbacks=[best_checkpoint_callback, recovery_checkpoint_callback],
                      logger=loggers,
                      progress_bar_refresh_rate=0,  # Disable the progress bar completely
                      num_nodes=num_nodes,
                      gpus=num_gpus,
                      precision=precision,
                      sync_batchnorm=True,
                      terminate_on_nan=config.detect_anomaly,
                      resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None
                      )
    return trainer, storing_logger
def create_lightning_trainer(container: LightningContainer,
                             resume_from_checkpoint: Optional[Path] = None,
                             num_nodes: int = 1,
                             multiple_trainloader_mode: str = "max_size_cycle") -> \
        Tuple[Trainer, StoringLogger]:
    """
    Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
    and loggers. That includes a diagnostic logger for use in unit tests, which is also returned as the second
    return value.
    :param container: The container with model and data.
    :param resume_from_checkpoint: If provided, training resumes from this checkpoint point.
    :param num_nodes: The number of nodes to use in distributed training.
    :param multiple_trainloader_mode: The mode that Lightning should use when multiple training dataloaders are
    present; this value is passed through to the Trainer constructor.
    :return: A tuple [Trainer object, diagnostic logger]
    """
    logging.debug(f"resume_from_checkpoint: {resume_from_checkpoint}")
    num_gpus = container.num_gpus_per_node()
    effective_num_gpus = num_gpus * num_nodes
    strategy = None
    if effective_num_gpus == 0:
        accelerator = "cpu"
        devices = 1
        message = "CPU"
    else:
        accelerator = "gpu"
        devices = num_gpus
        message = f"{devices} GPU"
        if effective_num_gpus > 1:
            # Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of
            # GPU memory).
            # Initialize the DDP plugin. The default for pl_find_unused_parameters is False. If True, the plugin
            # prints out lengthy warnings about the performance impact of find_unused_parameters.
            strategy = DDPPlugin(find_unused_parameters=container.pl_find_unused_parameters)
            message += "s per node with DDP"
    logging.info(f"Using {message}")
    tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="")
    loggers = [tensorboard_logger, AzureMLLogger(False)]
    storing_logger = StoringLogger()
    loggers.append(storing_logger)
    # Use 32bit precision when running on CPU. Otherwise, make it depend on use_mixed_precision flag.
    precision = 32 if num_gpus == 0 else 16 if container.use_mixed_precision else 32
    # The next two flags control the settings in torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark
    # https://pytorch.org/docs/stable/notes/randomness.html
    # Note that switching to deterministic models can have large performance downside.
    if container.pl_deterministic:
        deterministic = True
        benchmark = False
    else:
        deterministic = False
        benchmark = True

    # The last checkpoint is considered the "best" checkpoint. For large segmentation
    # models, this still appears to be the best way of choosing them because validation loss on the relatively small
    # training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but
    # not for the HeadAndNeck model.
    # Note that "last" is somehow a misnomer, it should rather be "latest". There is a "last" checkpoint written in
    # every epoch. We could use that for recovery too, but it could happen that the job gets preempted right during
    # writing that file, and we would end up with an invalid file.
    last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
                                               save_last=True,
                                               save_top_k=0)
    recovery_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
                                                   filename=AUTOSAVE_CHECKPOINT_FILE_NAME,
                                                   every_n_val_epochs=container.autosave_every_n_val_epochs,
                                                   save_last=False)
    callbacks: List[Callback] = [
        last_checkpoint_callback,
        recovery_checkpoint_callback,
    ]
    if container.monitor_loading:
        # TODO antonsc: Remove after fixing the callback.
        raise NotImplementedError("Monitoring batch loading times has been temporarily disabled.")
        # callbacks.append(BatchTimeCallback())
    if num_gpus > 0 and container.monitor_gpu:
        logging.info("Adding monitoring for GPU utilization")
        callbacks.append(GPUStatsMonitor(intra_step_time=True, inter_step_time=True))
    # Add the additional callbacks that were specified in get_trainer_arguments for LightningContainers
    additional_args = container.get_trainer_arguments()
    # Callbacks can be specified via the "callbacks" argument (the legacy behaviour) or the new get_callbacks method
    if "callbacks" in additional_args:
        more_callbacks = additional_args.pop("callbacks")
        if isinstance(more_callbacks, list):
            callbacks.extend(more_callbacks)  # type: ignore
        else:
            callbacks.append(more_callbacks)  # type: ignore
    callbacks.extend(container.get_callbacks())
    is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
    progress_bar_refresh_rate = container.pl_progress_bar_refresh_rate
    if progress_bar_refresh_rate is None:
        progress_bar_refresh_rate = 50
        logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. "
                     f"To change, modify the pl_progress_bar_refresh_rate field of the container.")
    if is_azureml_run:
        callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate,
                                            write_to_logging_info=True,
                                            print_timestamp=False))
    else:
        callbacks.append(TQDMProgressBar(refresh_rate=progress_bar_refresh_rate))
    # Read out additional model-specific args here.
    # We probably want to keep essential ones like the number of GPUs and logging.
    trainer = Trainer(default_root_dir=str(container.outputs_folder),
                      deterministic=deterministic,
                      benchmark=benchmark,
                      accelerator=accelerator,
                      strategy=strategy,
                      max_epochs=container.num_epochs,
                      # Both these arguments can be integers or floats. If integers, it is the number of batches.
                      # If float, it's the fraction of batches. We default to 1.0 (processing all batches).
                      limit_train_batches=container.pl_limit_train_batches or 1.0,
                      limit_val_batches=container.pl_limit_val_batches or 1.0,
                      num_sanity_val_steps=container.pl_num_sanity_val_steps,
                      check_val_every_n_epoch=container.pl_check_val_every_n_epoch,
                      callbacks=callbacks,
                      logger=loggers,
                      num_nodes=num_nodes,
                      devices=devices,
                      precision=precision,
                      sync_batchnorm=True,
                      detect_anomaly=container.detect_anomaly,
                      profiler=container.pl_profiler,
                      resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
                      multiple_trainloader_mode=multiple_trainloader_mode,
                      **additional_args)
    return trainer, storing_logger
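The newer variant above merges extra callbacks and keyword arguments supplied by the container. A hedged sketch of how a custom container might provide them; the container class and the chosen callback are illustrative assumptions, not taken from InnerEye:

# Hypothetical container; get_callbacks and get_trainer_arguments are the hooks used above.
from pytorch_lightning.callbacks import LearningRateMonitor


class MyMonitoringContainer(LightningContainer):
    def get_callbacks(self) -> List[Callback]:
        # Callbacks returned here are appended to the default checkpoint and progress-bar callbacks.
        return [LearningRateMonitor(logging_interval="epoch")]

    def get_trainer_arguments(self) -> Dict[str, Any]:
        # A "callbacks" entry here is also merged (legacy route); all other keys are forwarded
        # to the Trainer constructor as keyword arguments.
        return {"log_every_n_steps": 25}


trainer, storing_logger = create_lightning_trainer(MyMonitoringContainer(), num_nodes=1)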
class InnerEyeLightning(LightningModule):
    """
    The base class for all InnerEye models for training in PyTorch Lightning. The base class handles all shared
    operations like choosing the optimizer and learning rate schedule, keeping track of IO performance (loading times),
    and IO to files.
    """
    def __init__(self, config: DeepLearningConfig, *args: Any,
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.outputs_folder = config.outputs_folder
        self.checkpoint_folder = config.checkpoint_folder
        self.model: DeviceAwareModule = DeviceAwareModule()
        # These two will be set later in set_optimizer_and_scheduler.
        # The ddp_spawn accelerator only works if the model configuration object is
        # not stored in here. Hence, need to do operations that require a full config
        # in a way that does not require storing the config.
        self.optimizer: Optional[Optimizer] = None
        self.l_rate_scheduler: Optional[_LRScheduler] = None
        self.cross_validation_split_index = config.cross_validation_split_index
        self.effective_random_seed = config.get_effective_random_seed()
        # This should be re-assigned on the outside, to a logger that is hooked up with the Trainer object.
        self.storing_logger = StoringLogger()
        # This will be initialized correctly in epoch_start
        self.random_state: Optional[RandomStateSnapshot] = None
        # training loggers
        self.train_metrics_folder = self.outputs_folder / ModelExecutionMode.TRAIN.value
        self.val_metrics_folder = self.outputs_folder / ModelExecutionMode.VAL.value
        fixed_logger_columns = {
            LoggingColumns.CrossValidationSplitIndex.value:
            config.cross_validation_split_index
        }
        self.train_epoch_metrics_logger = DataframeLogger(
            self.train_metrics_folder / EPOCH_METRICS_FILE_NAME,
            fixed_columns=fixed_logger_columns)
        self.val_epoch_metrics_logger = DataframeLogger(
            self.val_metrics_folder / EPOCH_METRICS_FILE_NAME,
            fixed_columns=fixed_logger_columns)
        # Stores information about the checkpoint that created this model, if any.
        self.checkpoint_loading_message = ""

    def set_optimizer_and_scheduler(self, config: DeepLearningConfig) -> None:
        self.optimizer = model_util.create_optimizer(config,
                                                     self.model.parameters())
        self.l_rate_scheduler = SchedulerWithWarmUp(
            config, self.optimizer, num_epochs=config.num_epochs)

    def configure_optimizers(
            self) -> Tuple[List[Optimizer], List[_LRScheduler]]:
        return [self.optimizer], [self.l_rate_scheduler]  # type: ignore

    @rank_zero_only
    def on_fit_end(self) -> None:
        """
        Flushes all logger objects that the present object holds. This should only be run on rank zero, because
        otherwise ranks != 0 will create empty log files that can clash with the non-empty log files written on
        rank 0.
        """
        self.train_epoch_metrics_logger.flush()
        self.val_epoch_metrics_logger.flush()

    def training_epoch_end(self, outputs: List[Any]) -> None:
        # Write out all the metrics that have been accumulated in the StoringLogger in the previous epoch.
        # Metrics for the very last epoch are written in on_train_end
        self.read_epoch_results_from_logger_and_store(
            epoch=self.current_epoch - 1)
        self.training_or_validation_epoch_end(is_training=True)  # type: ignore

    def on_validation_epoch_start(self) -> None:
        """
        Stores the state of all random number generators, and resets them all to a fixed seed. This is done to ensure
        that any randomization when loading validation data is consistent during training. In particular, this ensures
        that drawing random patches for segmentation model training yields a validation set that does not fluctuate.
        """
        # Store the random number generator state, so that the next training epoch starts from here.
        self.random_state = RandomStateSnapshot.snapshot_random_state()
        # reset the random state for validation, so that we get consistent behaviour when drawing random patches
        # when validating segmentation models.
        seed = self.effective_random_seed
        set_random_seed(seed, "Validation")

    def validation_epoch_end(self, outputs: List[Any]) -> None:
        """
        Resets the random number generator state to what it was before the current validation epoch started.
        :param outputs: The list of outputs from the individual validation minibatches.
        """
        # reset the random state for training, so that we continue from where we were before the validation step.
        assert self.random_state is not None
        self.random_state.restore_random_state()
        self.training_or_validation_epoch_end(
            is_training=False)  # type: ignore

    @rank_zero_only
    def on_train_end(self) -> None:
        """
        This hook is called at the very end of training. Use it to write the very last set of training and
        validation metrics from the StoringLogger to disk.
        """
        self.read_epoch_results_from_logger_and_store(epoch=self.current_epoch)

    @rank_zero_only
    def read_epoch_results_from_logger_and_store(self, epoch: int) -> None:
        """
        Reads the metrics for the given epoch from the StoringLogger, and writes them to disk, broken down into
        training and validation metrics.
        :param epoch: The epoch for which to read and store the results.
        """
        if epoch >= 0:
            if epoch in self.storing_logger.results_per_epoch:
                for is_training, prefix in [(True, TRAIN_PREFIX),
                                            (False, VALIDATION_PREFIX)]:
                    metrics = self.storing_logger.extract_by_prefix(
                        epoch, prefix)
                    self.store_epoch_results(metrics, epoch, is_training)

    def log_on_epoch(self,
                     name: Union[MetricType, str],
                     value: Any,
                     is_training: bool,
                     reduce_fx: Union[str, Callable] = "mean",
                     sync_dist_override: Optional[bool] = None) -> None:
        """
        Logs a metric to Pytorch Lightning with the on_epoch flag set. The metric will get a prefix indicating
        if it is a training or a validation metric. A custom reducer function can be provided.
        The method also ensures that the correct synchronization across nodes is used. If the value to log is a
        floating point, it is converted to a Tensor on the current device to enable synchronization.

        :param sync_dist_override: If not None, use this value for the sync_dist argument to self.log. If None,
        set it automatically depending on the use of DDP.
        :param name: The name of the metric to log
        :param value: The value of the metric. This can be a tensor, floating point value, or a Metric class.
        :param is_training: If true, give the metric a "train/" prefix, otherwise a "val/" prefix.
        :param reduce_fx: The reduce function to use when synchronizing the tensors across GPUs. This must be
        a value recognized by sync_ddp: "sum", "mean"
        """
        metric_name = name if isinstance(name, str) else name.value
        prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX
        log_on_epoch(self,
                     name=prefix + metric_name,
                     value=value,
                     sync_dist=sync_dist_override,
                     reduce_fx=reduce_fx)

    def store_epoch_results(self, metrics: DictStrFloat, epoch: int,
                            is_training: bool) -> None:
        """
        Stores a set of metrics (key/value pairs) to a file logger. That file logger is either one that only holds
        training or only holds validation metrics.
        :param metrics: A dictionary with all the metrics to write, as key/value pairs.
        :param epoch: The epoch to which the metrics belong.
        :param is_training: If true, write the metrics to the logger for training metrics, if False, write to the logger
        for validation metrics.
        """
        file_logger = self.train_epoch_metrics_logger if is_training else self.val_epoch_metrics_logger
        store_epoch_metrics(metrics, epoch, file_logger=file_logger)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """
        This hook is called when loading a model from a checkpoint. It just prints out diagnostics about which epoch
        created the present checkpoint.
        :param checkpoint: The checkpoint dictionary loaded from disk.
        """
        keys = ['epoch', 'global_step']
        present_keys = [
            f"{key} = {checkpoint[key]}" for key in keys if key in checkpoint
        ]
        if present_keys:
            self.checkpoint_loading_message = f"Loading checkpoint that was created at ({', '.join(present_keys)})"
            logging.info(self.checkpoint_loading_message)

    def training_step(
            self,  # type: ignore
            sample: Dict[str, Any],
            batch_index: int) -> Any:
        return self.training_or_validation_step(sample,
                                                batch_index,
                                                is_training=True)

    def validation_step(
            self,  # type: ignore
            sample: Dict[str, Any],
            batch_index: int) -> Any:
        return self.training_or_validation_step(sample,
                                                batch_index,
                                                is_training=False)

    def training_or_validation_step(self, sample: Dict[str,
                                                       Any], batch_index: int,
                                    is_training: bool) -> Any:
        """
        This is the shared method that handles the training (when `is_training==True`) and validation steps
        (when `is_training==False`)
        :param sample: The minibatch of data that should be processed.
        :param batch_index: The index of the current minibatch.
        :param is_training: If true, this has been called from `training_step`, otherwise it has been called from
        `validation_step`.
        """
        raise NotImplementedError(
            "This method must be overwritten in a derived class.")

    def write_loss(self, is_training: bool, loss: torch.Tensor) -> None:
        """
        Writes the given loss value to Lightning, labelled either "val/loss" or "train/loss".
        If this comes from a training step, then also log the learning rate.
        :param loss: The loss value that should be logged.
        :param is_training: If True, the logged metric will be called "train/Loss". If False, the metric will
        be called "val/Loss"
        """
        assert isinstance(self.trainer, Trainer)
        self.log_on_epoch(MetricType.LOSS, loss, is_training)
        if is_training:
            learning_rate = self.trainer.lr_schedulers[0][
                'scheduler'].get_last_lr()[0]
            self.log_on_epoch(MetricType.LEARNING_RATE, learning_rate,
                              is_training)
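Because the model configuration is deliberately not stored on the module (see the comment in __init__), the optimizer and scheduler have to be wired up from the outside. A hedged sketch of that wiring before training; DummyModel, container and data_module are placeholders for illustration:

# Hypothetical wiring; DummyModel, container and data_module are placeholders for illustration.
model = DummyModel(config)                 # some InnerEyeLightning subclass
model.set_optimizer_and_scheduler(config)  # must run before fit, since the config is not kept on the module
trainer, storing_logger = create_lightning_trainer(container)
if storing_logger is not None:
    # Re-assign the StoringLogger that is hooked up with the Trainer, as noted in __init__.
    model.storing_logger = storing_logger
trainer.fit(model, datamodule=data_module)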