def model_train(checkpoint_handler: CheckpointHandler,
                container: LightningContainer,
                num_nodes: int = 1) -> Tuple[Trainer, Optional[StoringLogger]]:
    """
    The main training loop. It creates the Pytorch model based on the configuration options passed in,
    creates a Pytorch Lightning trainer, and trains the model.
    If a checkpoint was specified, then it loads the checkpoint before resuming training.

    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param num_nodes: The number of nodes to use in distributed training.
    :param container: A container object that holds the training data in PyTorch Lightning format
        and the model to train.
    :return: A tuple of [Trainer, StoringLogger]. Trainer is the Lightning Trainer object that was used for fitting
        the model. The StoringLogger object is returned when training an InnerEye built-in model, this is None when
        fitting other models.
    """
    # Get the path to the checkpoint to recover from
    checkpoint_path = checkpoint_handler.get_recovery_path_train()
    lightning_model = container.model

    resource_monitor: Optional[ResourceMonitor] = None
    # Execute some bookkeeping tasks only once if running distributed:
    if is_global_rank_zero():
        logging.info(f"Model checkpoints are saved at {container.checkpoint_folder}")
        # For InnerEye built-in containers, the legacy config object carries the arguments to write out.
        write_args_file(container.config if isinstance(container, InnerEyeContainer) else container,
                        outputs_folder=container.outputs_folder)
        if container.monitoring_interval_seconds > 0:
            resource_monitor = start_resource_monitor(container)

    # Run all of the container-related operations consistently with changed outputs folder, even ones that
    # should not rely on the current working directory, like get_data_module.
    with change_working_directory(container.outputs_folder):
        data_module = container.get_data_module()

    # Per-rank hooks: global-rank-zero runs once per job, local-rank-zero once per node, all-ranks everywhere.
    if is_global_rank_zero():
        container.before_training_on_global_rank_zero()
    if is_local_rank_zero():
        container.before_training_on_local_rank_zero()
    container.before_training_on_all_ranks()

    # Create the trainer object. Backup the environment variables before doing that, in case we need to run a second
    # training in the unit tests.
    old_environ = dict(os.environ)
    # Set random seeds just before training. For segmentation models, we have
    # something that changes the random seed in the before_training_on_rank_zero hook.
    seed_everything(container.get_effective_random_seed())
    trainer, storing_logger = create_lightning_trainer(container,
                                                       checkpoint_path,
                                                       num_nodes=num_nodes,
                                                       **container.get_trainer_arguments())
    rank_info = ", ".join(f"{env}: {os.getenv(env)}"
                          for env in [ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK])
    logging.info(f"Environment variables: {rank_info}. trainer.global_rank: {trainer.global_rank}")

    # InnerEye models use this logger for diagnostics
    if isinstance(lightning_model, InnerEyeLightning):
        if storing_logger is None:
            raise ValueError("InnerEye models require the storing_logger for diagnostics")
        lightning_model.storing_logger = storing_logger

    logging.info("Starting training")
    # When training models that are not built-in InnerEye models, we have no guarantee that they write
    # files to the right folder. Best guess is to change the current working directory to where files should go.
    with change_working_directory(container.outputs_folder):
        trainer.fit(lightning_model, datamodule=data_module)
        trainer.logger.close()  # type: ignore
    world_size = getattr(trainer, "world_size", 0)
    is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
    # Per-subject model outputs for regression models are written per rank, and need to be aggregated here.
    # Each thread per rank will come here, and upload its files to the run outputs. Rank 0 will later download them.
    if is_azureml_run and world_size > 1 and isinstance(lightning_model, ScalarLightning):
        upload_output_file_as_temp(lightning_model.train_subject_outputs_logger.csv_path,
                                   container.outputs_folder)
        upload_output_file_as_temp(lightning_model.val_subject_outputs_logger.csv_path,
                                   container.outputs_folder)

    # DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
    # We can now use the global_rank of the Lightning model, rather than environment variables, because DDP has set
    # all necessary properties.
    if lightning_model.global_rank != 0:
        logging.info(f"Terminating training thread with rank {lightning_model.global_rank}.")
        sys.exit()

    logging.info("Choosing the best checkpoint and removing redundant files.")
    create_best_checkpoint(container.checkpoint_folder)
    # Lightning modifies a ton of environment variables. If we first run training and then the test suite,
    # those environment variables will mislead the training runs in the test suite, and make them crash.
    # Hence, restore the original environment after training.
    os.environ.clear()
    os.environ.update(old_environ)

    if world_size and isinstance(lightning_model, ScalarLightning):
        if is_azureml_run and world_size > 1:
            # In a DDP run on the local box, all ranks will write to local disk, hence no download needed.
            # In a multi-node DDP, each rank would upload to AzureML, and rank 0 will now download all results and
            # concatenate
            for rank in range(world_size):
                for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
                    file = mode.value + "/" + get_subject_output_file_per_rank(rank)
                    RUN_CONTEXT.download_file(name=TEMP_PREFIX + file,
                                              output_file_path=container.outputs_folder / file)
        # Concatenate all temporary file per execution mode
        aggregate_and_create_subject_metrics_file(container.outputs_folder)

    logging.info("Finished training")

    # Since we have trained the model further, let the checkpoint_handler object know so it can handle
    # checkpoints correctly.
    checkpoint_handler.additional_training_done()

    # Upload visualization directory to AML run context to be able to see it in the Azure UI.
    if isinstance(container, InnerEyeContainer):
        if container.config.max_batch_grad_cam > 0 and container.visualization_folder.exists():
            RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER,
                                      path=str(container.visualization_folder))

    if resource_monitor:
        logging.info("Shutting down the resource monitor process.")
        if is_azureml_run:
            for gpu_name, metrics_per_gpu in resource_monitor.read_aggregate_metrics().items():
                # Log as a table, with GPU being the first column
                RUN_CONTEXT.log_row("GPU utilization", GPU=gpu_name, **metrics_per_gpu)
        resource_monitor.kill()

    return trainer, storing_logger
def model_train(config: ModelConfigBase,
                checkpoint_handler: CheckpointHandler,
                num_nodes: int = 1) -> ModelTrainingResults:
    """
    The main training loop. It creates the Pytorch model based on the configuration options passed in,
    creates a Pytorch Lightning trainer, and trains the model.
    If a checkpoint was specified, then it loads the checkpoint before resuming training.

    :param config: The arguments which specify all required information.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param num_nodes: The number of nodes to use in distributed training.
    :return: A ModelTrainingResults object with per-epoch train/validation metrics and diagnostics.
    """
    # Get the path to the checkpoint to recover from
    checkpoint_path = checkpoint_handler.get_recovery_path_train()

    # This reads the dataset file, and possibly sets required pre-processing objects, like one-hot encoder
    # for categorical features, that need to be available before creating the model.
    config.read_dataset_if_needed()

    # Create the trainer object. Backup the environment variables before doing that, in case we need to run a second
    # training in the unit tests.
    old_environ = dict(os.environ)
    seed_everything(config.get_effective_random_seed())
    trainer, storing_logger = create_lightning_trainer(config, checkpoint_path, num_nodes=num_nodes)

    logging.info(f"GLOBAL_RANK: {os.getenv('GLOBAL_RANK')}, LOCAL_RANK {os.getenv('LOCAL_RANK')}. "
                 f"trainer.global_rank: {trainer.global_rank}")
    logging.debug("Creating the PyTorch model.")
    lightning_model = create_lightning_model(config)
    lightning_model.storing_logger = storing_logger

    resource_monitor = None
    # Execute some bookkeeping tasks only once if running distributed:
    if is_rank_zero():
        config.write_args_file()
        logging.info(str(config))
        # Save the dataset files for later use in cross validation analysis
        config.write_dataset_files()
        logging.info(f"Model checkpoints are saved at {config.checkpoint_folder}")

        # set the random seed for all libraries
        ml_util.set_random_seed(config.get_effective_random_seed(), "Patch visualization")
        # Visualize how patches are sampled for segmentation models. This changes the random generator, but we don't
        # want training to depend on how many patients we visualized, and hence set the random seed again right after.
        with logging_section("Visualizing the effect of sampling random crops for training"):
            visualize_random_crops_for_dataset(config)

        # Print out a detailed breakdown of layers, memory consumption and time.
        generate_and_print_model_summary(config, lightning_model.model)

        if config.monitoring_interval_seconds > 0:
            # initialize and start GPU monitoring
            gpu_tensorboard = config.logs_folder / "gpu_utilization"
            # Result file in CSV format should NOT live in the logs folder, the streaming upload that is
            # used for this folder might corrupt the file.
            gpu_csv = config.outputs_folder / "gpu_utilization"
            gpu_csv.mkdir(parents=True, exist_ok=True)
            logging.info(f"Starting resource monitor. GPU utilization will be written to Tensorboard in "
                         f"{gpu_tensorboard}, aggregate metrics to {gpu_csv}")
            resource_monitor = ResourceMonitor(interval_seconds=config.monitoring_interval_seconds,
                                               tensorboard_folder=gpu_tensorboard,
                                               csv_results_folder=gpu_csv)
            resource_monitor.start()

    # Training loop
    logging.info("Starting training")

    lightning_data = TrainingAndValidationDataLightning(config)  # type: ignore
    # When trying to store the config object in the constructor, it does not appear to get stored at all, later
    # reference of the object simply fail. Hence, have to set explicitly here.
    lightning_data.config = config
    trainer.fit(lightning_model, datamodule=lightning_data)
    trainer.logger.close()  # type: ignore
    lightning_model.close_all_loggers()
    world_size = getattr(trainer, "world_size", 0)
    is_azureml_run = not config.is_offline_run
    # Per-subject model outputs for regression models are written per rank, and need to be aggregated here.
    # Each thread per rank will come here, and upload its files to the run outputs. Rank 0 will later download them.
    if is_azureml_run and world_size > 1 and isinstance(lightning_model, ScalarLightning):
        upload_output_file_as_temp(lightning_model.train_subject_outputs_logger.csv_path, config.outputs_folder)
        upload_output_file_as_temp(lightning_model.val_subject_outputs_logger.csv_path, config.outputs_folder)

    # DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
    # We can now use the global_rank of the Lightning model, rather than environment variables, because DDP has set
    # all necessary properties.
    if lightning_model.global_rank != 0:
        logging.info(f"Terminating training thread with rank {lightning_model.global_rank}.")
        sys.exit()

    logging.info("Choosing the best checkpoint and removing redundant files.")
    cleanup_checkpoint_folder(config.checkpoint_folder)
    # Lightning modifies a ton of environment variables. If we first run training and then the test suite,
    # those environment variables will mislead the training runs in the test suite, and make them crash.
    # Hence, restore the original environment after training.
    os.environ.clear()
    os.environ.update(old_environ)

    if world_size and isinstance(lightning_model, ScalarLightning):
        if is_azureml_run and world_size > 1:
            # In a DDP run on the local box, all ranks will write to local disk, hence no download needed.
            # In a multi-node DDP, each rank would upload to AzureML, and rank 0 will now download all results and
            # concatenate
            for rank in range(world_size):
                for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
                    file = mode.value + "/" + get_subject_output_file_per_rank(rank)
                    RUN_CONTEXT.download_file(name=TEMP_PREFIX + file,
                                              output_file_path=config.outputs_folder / file)
        # Concatenate all temporary files per execution mode.
        for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
            temp_files = (config.outputs_folder / mode.value).rglob(SUBJECT_OUTPUT_PER_RANK_PREFIX + "*")
            result_file = config.outputs_folder / mode.value / SUBJECT_METRICS_FILE_NAME
            # BUGFIX: the previous code called result_file.write_text once per temporary file, which
            # overwrites the result on every iteration - only the last rank's (headerless) rows survived.
            # Instead, collect all rows first, keeping the column-header line only from the first file,
            # and write the aggregate in a single call.
            aggregated_lines = []
            for i, file in enumerate(temp_files):
                file_lines = file.read_text().splitlines()
                # For all files but the first one, cut off the header line.
                aggregated_lines.extend(file_lines if i == 0 else file_lines[1:])
            result_file.write_text(os.linesep.join(aggregated_lines))

    model_training_results = ModelTrainingResults(
        train_results_per_epoch=list(storing_logger.to_metrics_dicts(prefix_filter=TRAIN_PREFIX).values()),
        val_results_per_epoch=list(storing_logger.to_metrics_dicts(prefix_filter=VALIDATION_PREFIX).values()),
        train_diagnostics=lightning_model.train_diagnostics,
        val_diagnostics=lightning_model.val_diagnostics,
        optimal_temperature_scale_values_per_checkpoint_epoch=[])

    logging.info("Finished training")

    # Since we have trained the model further, let the checkpoint_handler object know so it can handle
    # checkpoints correctly.
    checkpoint_handler.additional_training_done()

    # Upload visualization directory to AML run context to be able to see it
    # in the Azure UI.
    if config.max_batch_grad_cam > 0 and config.visualization_folder.exists():
        RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER, path=str(config.visualization_folder))

    if resource_monitor:
        logging.info("Shutting down the resource monitor process.")
        if not config.is_offline_run:
            for gpu_name, metrics_per_gpu in resource_monitor.read_aggregate_metrics().items():
                # Log as a table, with GPU being the first column
                RUN_CONTEXT.log_row("GPU utilization", GPU=gpu_name, **metrics_per_gpu)
        resource_monitor.kill()

    return model_training_results