Example #1
def model_train(checkpoint_handler: CheckpointHandler,
                container: LightningContainer,
                num_nodes: int = 1) -> Tuple[Trainer, Optional[StoringLogger]]:
    """
    The main training loop. It creates the PyTorch model based on the configuration options passed in,
    creates a PyTorch Lightning trainer, and trains the model.
    If a checkpoint was specified, then it loads the checkpoint before resuming training.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param container: A container object that holds the training data in PyTorch Lightning format
    and the model to train.
    :param num_nodes: The number of nodes to use in distributed training.
    :return: A tuple of [Trainer, StoringLogger]. Trainer is the Lightning Trainer object that was used for fitting
    the model. The StoringLogger object is returned when training an InnerEye built-in model; it is None when
    fitting other models.
    """
    # Get the path to the checkpoint to recover from
    checkpoint_path = checkpoint_handler.get_recovery_path_train()
    lightning_model = container.model

    resource_monitor: Optional[ResourceMonitor] = None
    # Execute some bookkeeping tasks only once if running distributed:
    if is_global_rank_zero():
        logging.info(
            f"Model checkpoints are saved at {container.checkpoint_folder}")
        config_or_container = container.config if isinstance(container, InnerEyeContainer) else container
        write_args_file(config_or_container, outputs_folder=container.outputs_folder)
        if container.monitoring_interval_seconds > 0:
            resource_monitor = start_resource_monitor(container)

    # Run all of the container-related operations with the working directory changed to the outputs folder,
    # even the ones that should not rely on the current working directory, like get_data_module.
    with change_working_directory(container.outputs_folder):
        data_module = container.get_data_module()
        if is_global_rank_zero():
            container.before_training_on_global_rank_zero()
        if is_local_rank_zero():
            container.before_training_on_local_rank_zero()
        container.before_training_on_all_ranks()

    # Create the trainer object. Back up the environment variables before doing that, in case we need to run a
    # second training in the unit tests.
    old_environ = dict(os.environ)
    # Set random seeds just before training. For segmentation models, the random seed is changed again in the
    # before_training_on_global_rank_zero hook.
    seed_everything(container.get_effective_random_seed())
    trainer, storing_logger = create_lightning_trainer(
        container,
        checkpoint_path,
        num_nodes=num_nodes,
        **container.get_trainer_arguments())
    rank_info = ", ".join(
        f"{env}: {os.getenv(env)}"
        for env in [ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK])
    logging.info(
        f"Environment variables: {rank_info}. trainer.global_rank: {trainer.global_rank}"
    )
    # InnerEye models use this logger for diagnostics
    if isinstance(lightning_model, InnerEyeLightning):
        if storing_logger is None:
            raise ValueError(
                "InnerEye models require the storing_logger for diagnostics")
        lightning_model.storing_logger = storing_logger

    logging.info("Starting training")
    # When training models that are not built-in InnerEye models, we have no guarantee that they write
    # files to the right folder. Best guess is to change the current working directory to where files should go.
    with change_working_directory(container.outputs_folder):
        trainer.fit(lightning_model, datamodule=data_module)
        trainer.logger.close()  # type: ignore
    world_size = getattr(trainer, "world_size", 0)
    is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
    # Per-subject model outputs for regression models are written per rank, and need to be aggregated here.
    # The process for each rank will come here, and upload its files to the run outputs. Rank 0 will later download them.
    if is_azureml_run and world_size > 1 and isinstance(
            lightning_model, ScalarLightning):
        upload_output_file_as_temp(
            lightning_model.train_subject_outputs_logger.csv_path,
            container.outputs_folder)
        upload_output_file_as_temp(
            lightning_model.val_subject_outputs_logger.csv_path,
            container.outputs_folder)
    # DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
    # We can now use the global_rank of the Lightning model, rather than environment variables, because DDP has set
    # all necessary properties.
    if lightning_model.global_rank != 0:
        logging.info(
            f"Terminating training thread with rank {lightning_model.global_rank}."
        )
        sys.exit()

    logging.info("Choosing the best checkpoint and removing redundant files.")
    create_best_checkpoint(container.checkpoint_folder)
    # Lightning modifies a ton of environment variables. If we first run training and then the test suite,
    # those environment variables will mislead the training runs in the test suite, and make them crash.
    # Hence, restore the original environment after training.
    os.environ.clear()
    os.environ.update(old_environ)

    if world_size and isinstance(lightning_model, ScalarLightning):
        if is_azureml_run and world_size > 1:
            # In a DDP run on the local box, all ranks will write to local disk, hence no download is needed.
            # In a multi-node DDP run, each rank would upload to AzureML, and rank 0 will now download all results
            # and concatenate them.
            for rank in range(world_size):
                for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
                    file = mode.value + "/" + get_subject_output_file_per_rank(
                        rank)
                    RUN_CONTEXT.download_file(
                        name=TEMP_PREFIX + file,
                        output_file_path=container.outputs_folder / file)
        # Concatenate all temporary files per execution mode
        aggregate_and_create_subject_metrics_file(container.outputs_folder)

    logging.info("Finished training")

    # Since we have trained the model further, let the checkpoint_handler object know so it can handle
    # checkpoints correctly.
    checkpoint_handler.additional_training_done()

    # Upload visualization directory to AML run context to be able to see it in the Azure UI.
    if isinstance(container, InnerEyeContainer):
        if container.config.max_batch_grad_cam > 0 and container.visualization_folder.exists():
            RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER,
                                      path=str(container.visualization_folder))

    if resource_monitor:
        logging.info("Shutting down the resource monitor process.")
        if is_azureml_run:
            for gpu_name, metrics_per_gpu in resource_monitor.read_aggregate_metrics().items():
                # Log as a table, with GPU being the first column
                RUN_CONTEXT.log_row("GPU utilization",
                                    GPU=gpu_name,
                                    **metrics_per_gpu)
        resource_monitor.kill()

    return trainer, storing_logger
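
For orientation, a minimal usage sketch of this container-based entry point follows. The import paths, the MyContainer class, and the CheckpointHandler constructor arguments are assumptions for illustration only, not the library's documented API.

# Usage sketch only: the import paths, MyContainer, and the CheckpointHandler
# construction below are assumptions, not verified against the library.
from pathlib import Path

from InnerEye.Azure.azure_config import AzureConfig                  # assumed import path
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler  # assumed import path
from my_project.containers import MyContainer  # hypothetical LightningContainer subclass

container = MyContainer()
container.setup()  # assumed: resolves datasets and output folders before training

checkpoint_handler = CheckpointHandler(container=container,          # assumed signature
                                       azure_config=AzureConfig(),
                                       project_root=Path.cwd())
trainer, storing_logger = model_train(checkpoint_handler=checkpoint_handler,
                                      container=container,
                                      num_nodes=1)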

Example #2
def model_train(config: ModelConfigBase,
                checkpoint_handler: CheckpointHandler,
                num_nodes: int = 1) -> ModelTrainingResults:
    """
    The main training loop. It creates the PyTorch model based on the configuration options passed in,
    creates a PyTorch Lightning trainer, and trains the model.
    If a checkpoint was specified, then it loads the checkpoint before resuming training.
    :param config: The arguments which specify all required information.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param num_nodes: The number of nodes to use in distributed training.
    """
    # Get the path to the checkpoint to recover from
    checkpoint_path = checkpoint_handler.get_recovery_path_train()
    # This reads the dataset file, and possibly sets required pre-processing objects, like one-hot encoder
    # for categorical features, that need to be available before creating the model.
    config.read_dataset_if_needed()

    # Create the trainer object. Back up the environment variables before doing that, in case we need to run a
    # second training in the unit tests.
    old_environ = dict(os.environ)
    seed_everything(config.get_effective_random_seed())
    trainer, storing_logger = create_lightning_trainer(config,
                                                       checkpoint_path,
                                                       num_nodes=num_nodes)

    logging.info(
        f"GLOBAL_RANK: {os.getenv('GLOBAL_RANK')}, LOCAL_RANK {os.getenv('LOCAL_RANK')}. "
        f"trainer.global_rank: {trainer.global_rank}")
    logging.debug("Creating the PyTorch model.")
    lightning_model = create_lightning_model(config)
    lightning_model.storing_logger = storing_logger

    resource_monitor = None
    # Execute some bookkeeping tasks only once if running distributed:
    if is_rank_zero():
        config.write_args_file()
        logging.info(str(config))
        # Save the dataset files for later use in cross validation analysis
        config.write_dataset_files()
        logging.info(
            f"Model checkpoints are saved at {config.checkpoint_folder}")

        # set the random seed for all libraries
        ml_util.set_random_seed(config.get_effective_random_seed(),
                                "Patch visualization")
        # Visualize how patches are sampled for segmentation models. This changes the random generator, but we don't
        # want training to depend on how many patients we visualized, and hence set the random seed again right after.
        with logging_section("Visualizing the effect of sampling random crops for training"):
            visualize_random_crops_for_dataset(config)

        # Print out a detailed breakdown of layers, memory consumption and time.
        generate_and_print_model_summary(config, lightning_model.model)

        if config.monitoring_interval_seconds > 0:
            # initialize and start GPU monitoring
            gpu_tensorboard = config.logs_folder / "gpu_utilization"
            # The result file in CSV format should NOT live in the logs folder; the streaming upload that is
            # used for this folder might corrupt the file.
            gpu_csv = config.outputs_folder / "gpu_utilization"
            gpu_csv.mkdir(parents=True, exist_ok=True)
            logging.info(
                f"Starting resource monitor. GPU utilization will be written to Tensorboard in "
                f"{gpu_tensorboard}, aggregate metrics to {gpu_csv}")
            resource_monitor = ResourceMonitor(
                interval_seconds=config.monitoring_interval_seconds,
                tensorboard_folder=gpu_tensorboard,
                csv_results_folder=gpu_csv)
            resource_monitor.start()

    # Training loop
    logging.info("Starting training")

    lightning_data = TrainingAndValidationDataLightning(config)  # type: ignore
    # When trying to store the config object in the constructor, it does not appear to get stored at all, and
    # later references to the object simply fail. Hence, it has to be set explicitly here.
    lightning_data.config = config
    trainer.fit(lightning_model, datamodule=lightning_data)
    trainer.logger.close()  # type: ignore
    lightning_model.close_all_loggers()
    world_size = getattr(trainer, "world_size", 0)
    is_azureml_run = not config.is_offline_run
    # Per-subject model outputs for regression models are written per rank, and need to be aggregated here.
    # The process for each rank will come here, and upload its files to the run outputs. Rank 0 will later download them.
    if is_azureml_run and world_size > 1 and isinstance(
            lightning_model, ScalarLightning):
        upload_output_file_as_temp(
            lightning_model.train_subject_outputs_logger.csv_path,
            config.outputs_folder)
        upload_output_file_as_temp(
            lightning_model.val_subject_outputs_logger.csv_path,
            config.outputs_folder)
    # DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
    # We can now use the global_rank of the Lightning model, rather than environment variables, because DDP has set
    # all necessary properties.
    if lightning_model.global_rank != 0:
        logging.info(
            f"Terminating training thread with rank {lightning_model.global_rank}."
        )
        sys.exit()

    logging.info("Choosing the best checkpoint and removing redundant files.")
    cleanup_checkpoint_folder(config.checkpoint_folder)
    # Lightning modifies a ton of environment variables. If we first run training and then the test suite,
    # those environment variables will mislead the training runs in the test suite, and make them crash.
    # Hence, restore the original environment after training.
    os.environ.clear()
    os.environ.update(old_environ)

    if world_size and isinstance(lightning_model, ScalarLightning):
        if is_azureml_run and world_size > 1:
            # In a DDP run on the local box, all ranks will write to local disk, hence no download is needed.
            # In a multi-node DDP run, each rank would upload to AzureML, and rank 0 will now download all results
            # and concatenate them.
            for rank in range(world_size):
                for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
                    file = mode.value + "/" + get_subject_output_file_per_rank(
                        rank)
                    RUN_CONTEXT.download_file(
                        name=TEMP_PREFIX + file,
                        output_file_path=config.outputs_folder / file)
        # Concatenate all temporary files per execution mode
        for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
            temp_files = (config.outputs_folder /
                          mode.value).rglob(SUBJECT_OUTPUT_PER_RANK_PREFIX +
                                            "*")
            result_file = config.outputs_folder / mode.value / SUBJECT_METRICS_FILE_NAME
            # Open the aggregate file once and append each rank's file, rather than overwriting it per file.
            with result_file.open("w") as result:
                for i, file in enumerate(temp_files):
                    lines = file.read_text().splitlines()
                    # Copy the first file including its column header line; for all other files, cut off the header.
                    result.write(os.linesep.join(lines if i == 0 else lines[1:]) + os.linesep)

    model_training_results = ModelTrainingResults(
        train_results_per_epoch=list(
            storing_logger.to_metrics_dicts(
                prefix_filter=TRAIN_PREFIX).values()),
        val_results_per_epoch=list(
            storing_logger.to_metrics_dicts(
                prefix_filter=VALIDATION_PREFIX).values()),
        train_diagnostics=lightning_model.train_diagnostics,
        val_diagnostics=lightning_model.val_diagnostics,
        optimal_temperature_scale_values_per_checkpoint_epoch=[])

    logging.info("Finished training")

    # Since we have trained the model further, let the checkpoint_handler object know so it can handle
    # checkpoints correctly.
    checkpoint_handler.additional_training_done()

    # Upload visualization directory to AML run context to be able to see it
    # in the Azure UI.
    if config.max_batch_grad_cam > 0 and config.visualization_folder.exists():
        RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER,
                                  path=str(config.visualization_folder))

    if resource_monitor:
        logging.info("Shutting down the resource monitor process.")
        if not config.is_offline_run:
            for gpu_name, metrics_per_gpu in resource_monitor.read_aggregate_metrics().items():
                # Log as a table, with GPU being the first column
                RUN_CONTEXT.log_row("GPU utilization",
                                    GPU=gpu_name,
                                    **metrics_per_gpu)
        resource_monitor.kill()

    return model_training_results
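
The first example delegates the per-rank CSV aggregation to aggregate_and_create_subject_metrics_file, which is not shown on this page. A rough sketch of such a helper, reconstructed from the inline concatenation loop in the second example, might look as follows; the import locations and the exact signature are assumptions.

# Sketch of the aggregation helper used in Example #1, reconstructed from the
# inline loop in Example #2. Import locations and the signature are assumptions.
import os
from pathlib import Path

from InnerEye.ML.common import ModelExecutionMode, SUBJECT_METRICS_FILE_NAME  # assumed paths
from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX  # assumed path


def aggregate_and_create_subject_metrics_file(outputs_folder: Path) -> None:
    """Concatenate the per-rank subject output CSVs into one metrics file per execution mode."""
    for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
        temp_files = (outputs_folder / mode.value).rglob(SUBJECT_OUTPUT_PER_RANK_PREFIX + "*")
        result_file = outputs_folder / mode.value / SUBJECT_METRICS_FILE_NAME
        with result_file.open("w") as result:
            for i, file in enumerate(temp_files):
                lines = file.read_text().splitlines()
                # Keep the CSV header only from the first per-rank file.
                result.write(os.linesep.join(lines if i == 0 else lines[1:]) + os.linesep)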