def forward_preserve_state(module: DeviceAwareModule, inputs: List[torch.Tensor]) -> torch.Tensor:
    """
    Perform a forward pass of the given module on the given list of torch tensors, without lasting side effects
    on the random number generators or on the module's training mode.

    The random state of the backend libraries is snapshotted (via RandomStateSnapshot) before the forward pass and
    restored afterwards, to avoid reproducibility issues. The module is temporarily set to evaluation mode and the
    forward pass runs under torch.no_grad(). If the module was in training mode, it is set back to training mode
    afterwards. Both restorations also happen when the forward pass raises.

    :param module: Callable torch module.
    :param inputs: List of input torch tensors, passed positionally to module.forward. They are moved to the GPU
        first if module.is_model_on_gpu() returns True.
    :return: The output of module.forward.
    :raises RuntimeError: If inputs is not a list.
    """
    if not isinstance(inputs, list):
        raise RuntimeError("Inputs object has to be a list of torch tensors")
    if module.is_model_on_gpu():
        inputs = [input_tensor.cuda() for input_tensor in inputs]
    # Collect the current training mode of the model and the state of the random number generators.
    is_train = module.training
    random_state = RandomStateSnapshot.snapshot_random_state()
    # Set the model in evaluation mode and perform a forward pass. The try/finally ensures that a failing
    # forward pass does not leave the module in eval mode, or the random generators in a modified state.
    module.eval()
    try:
        with torch.no_grad():
            output = module.forward(*inputs)
    finally:
        if is_train:
            module.train()
        # Restore the seed for torch and numpy.
        random_state.restore_random_state()
    return output
def generate_and_print_model_summary(config: ModelConfigBase, model: DeviceAwareModule) -> None:
    """
    Writes a human readable summary of the present model to logging.info, and logs the number of trainable
    parameters to AzureML.

    The state of the random number generators is snapshotted before the summary is generated and restored
    afterwards. This also happens when summary generation fails, so that generating the summary does not alter
    the random sequence seen by subsequent code.

    :param config: The configuration for the model.
    :param model: The instantiated Pytorch model.
    :raises ValueError: If the config is neither a ScalarModelBase nor a segmentation model config.
    """
    random_state = RandomStateSnapshot.snapshot_random_state()
    try:
        # There appears to be a bug in apex, where previous use (in training for example) causes problems
        # when another model is later built on the CPU (for example, before loading from a checkpoint)
        # https://github.com/NVIDIA/apex/issues/694
        # Hence, move the model to the GPU before doing model summary.
        if config.use_gpu:
            model = model.cuda()
        if isinstance(config, ScalarModelBase):
            # To generate the model summary, read the first item of the dataset. Then use the model's own
            # get_model_input function to convert the dataset item to input tensors, and feed them through the
            # model.
            train_dataset = config.get_torch_dataset_for_inference(ModelExecutionMode.TRAIN)
            train_item_0 = next(iter(train_dataset.as_data_loader(shuffle=False, batch_size=1,
                                                                  num_dataload_workers=0)))
            model_inputs = get_scalar_model_inputs_and_labels(config, model, train_item_0).model_inputs
            # NOTE(review): The model inputs may already be converted to float16, assuming that we would do mixed
            # precision, while the model is not yet converted to float16 when this function is called. No
            # conversion back to float32 is performed here - confirm that this is intended.
            summary = ModelSummary(model)
            summary.generate_summary(input_tensors=model_inputs, log_summaries_to_files=config.log_summaries_to_files)
        elif config.is_segmentation_model:
            summary_for_segmentation_models(config, model)
            # summary_for_segmentation_models is expected to have populated model.summarizer.
            assert model.summarizer
            summary = model.summarizer  # type: ignore
        else:
            raise ValueError("Don't know how to generate a summary for this type of model?")
        RUN_CONTEXT.log(LoggingColumns.NumTrainableParameters, summary.n_trainable_params)
    finally:
        # Undo any consumption of random numbers done while building the summary, also on failure.
        random_state.restore_random_state()
def on_validation_epoch_start(self) -> None:
    """
    Saves the state of all random number generators, then re-seeds them all with a fixed seed.

    Re-seeding makes any randomization that happens while loading validation data consistent during training.
    In particular, random patches drawn for segmentation models then give a validation set that does not
    fluctuate from one epoch to the next.
    """
    # Keep the current generator state in self.random_state, so that the next training epoch starts from here.
    self.random_state = RandomStateSnapshot.snapshot_random_state()
    # Fixed seed for validation: consistent behaviour when drawing random patches for segmentation models.
    set_random_seed(self.effective_random_seed, "Validation")
def test_random_state_snapshot() -> None:
    """
    Test getting and resetting all random states (Python's random, numpy and torch) via the RandomStateSnapshot
    class.
    """

    def _get_random_ints_from_libs() -> Tuple[List[int], np.ndarray, torch.Tensor]:
        """Draw 20 random integers each from Python's random module, numpy and torch."""
        _python_random = [random.randint(0, 100) for _ in range(0, 20)]
        _numpy_random = np.random.randint(0, 100, 20)
        _torch_random = torch.randint(0, 100, (20, 1))
        return _python_random, _numpy_random, _torch_random

    # set the random state
    ml_util.set_random_seed(0)
    # take a snapshot of the random state in its original state
    random_state = RandomStateSnapshot.snapshot_random_state()
    # create random numbers using python, numpy, and torch
    original_python_random, original_numpy_random, original_torch_random = _get_random_ints_from_libs()
    # re-set the random state
    ml_util.set_random_seed(0)
    # validate that the current random state is accurately captured
    assert random.getstate() == random_state.random_state
    for i, x in enumerate(np.random.get_state()):
        assert np.array_equal(x, random_state.numpy_random_state[i])
    assert torch.equal(torch.random.get_rng_state(), random_state.torch_random_state)
    assert random_state.torch_cuda_random_state is None
    # change the random state
    ml_util.set_random_seed(10)
    # create random numbers using python, numpy, and torch
    new_python_random, new_numpy_random, new_torch_random = _get_random_ints_from_libs()
    # check that a new state was used to create these random numbers
    assert new_python_random != original_python_random
    assert not np.array_equal(new_numpy_random, original_numpy_random)
    assert not torch.equal(new_torch_random, original_torch_random)
    # restore the original random state
    random_state.restore_random_state()
    # get restored random variables
    restored_python_random, restored_numpy_random, restored_torch_random = _get_random_ints_from_libs()
    # check restored variables match the original
    assert restored_python_random == original_python_random
    assert np.array_equal(restored_numpy_random, original_numpy_random)
    assert torch.equal(restored_torch_random, original_torch_random)
def on_validation_epoch_start(self) -> None:
    """
    Marks the end of the training part of the epoch, saves the state of all random number generators, and then
    re-seeds them all with a fixed seed.

    Re-seeding makes any randomization that happens while loading validation data consistent during training.
    In particular, random patches drawn for segmentation models then give a validation set that does not
    fluctuate from one epoch to the next.
    """
    self.val_timers.reset()
    # Lightning runs the validation epoch "inside" the training epoch. Reaching this hook therefore means that
    # training for this epoch is finished, even though the on_training_epoch hook has not been called yet.
    self.train_timers.epoch_end()
    # Keep the current generator state in self.random_state, so that the next training epoch starts from here.
    snapshot = RandomStateSnapshot.snapshot_random_state()
    self.random_state = snapshot
    # Fixed seed for validation: consistent behaviour when drawing random patches for segmentation models.
    validation_seed = self.effective_random_seed
    set_random_seed(validation_seed, "Validation")
def train_or_validate_epoch(
        training_steps: ModelTrainingStepsBase
) -> ModelOutputsAndMetricsForEpoch:
    """
    Trains or validates the model for one epoch.

    Random state handling in validation mode (train_val_params.in_training_mode is False):
    - The state of all random number generators is snapshotted.
    - The generators are re-seeded with config.get_effective_random_seed().
    - The snapshot is restored once all minibatches have been processed, including when processing a minibatch
      raises.

    Minibatch load times are monitored. Apart from the first minibatch, loads that take longer than
    MAX_ITEM_LOAD_TIME_SEC are counted and reported. At most MAX_LOAD_TIME_WARNINGS individual warnings are
    printed per call.

    :param training_steps: Training pipeline to use.
    :returns: The results for training or validation. Result type depends on the type of model that is trained.
    """
    epoch_start_time = time()
    training_random_state = None
    train_val_params = training_steps.train_val_params
    config = training_steps.model_config
    if not train_val_params.in_training_mode:
        # take the snapshot of the existing random state
        training_random_state = RandomStateSnapshot.snapshot_random_state()
        # reset the random state for validation
        ml_util.set_random_seed(config.get_effective_random_seed(), "Model validation")

    status_string = "training" if train_val_params.in_training_mode else "validation"
    item_start_time = time()
    num_load_time_warnings = 0
    num_load_time_exceeded = 0
    num_batches = 0
    total_extra_load_time = 0.0
    total_load_time = 0.0
    model_outputs_epoch = []
    try:
        for batch_index, sample in enumerate(train_val_params.data_loader):
            item_finish_time = time()
            item_load_time = item_finish_time - item_start_time
            # Having slow minibatch loading is OK in the very first batch of every epoch, where processes
            # are spawned. Later, the load time should be zero.
            if batch_index == 0:
                logging.info(
                    f"Loaded the first minibatch of {status_string} data in {item_load_time:0.2f} sec."
                )
            elif item_load_time > MAX_ITEM_LOAD_TIME_SEC:
                num_load_time_exceeded += 1
                total_extra_load_time += item_load_time
                if num_load_time_warnings < MAX_LOAD_TIME_WARNINGS:
                    logging.warning(
                        f"Loading {status_string} minibatch {batch_index} took {item_load_time:0.2f} sec. "
                        "This can mean that there are not enough data loader worker processes, or that there "
                        "is a performance problem in loading. This warning will be printed at most "
                        f"{MAX_LOAD_TIME_WARNINGS} times.")
                    num_load_time_warnings += 1
            model_outputs_minibatch = training_steps.forward_and_backward_minibatch(
                sample, batch_index, train_val_params.epoch)
            model_outputs_epoch.append(model_outputs_minibatch)
            train_finish_time = time()
            logging.debug(
                f"Epoch {train_val_params.epoch} {status_string} batch {batch_index}: "
                f"Loaded in {item_load_time:0.2f}sec, "
                f"{status_string} in {(train_finish_time - item_finish_time):0.2f}sec. "
                f"Loss = {model_outputs_minibatch.loss}")
            total_load_time += item_load_time
            num_batches += 1
            # Restart the clock here, so that the next load time excludes the forward/backward pass.
            item_start_time = time()
    finally:
        # Restore the training random state when validation has finished. This also runs if validation failed,
        # so that an exception does not leave the generators re-seeded.
        if training_random_state is not None:
            training_random_state.restore_random_state()

    epoch_time_seconds = time() - epoch_start_time
    logging.info(
        f"Epoch {train_val_params.epoch} {status_string} took {epoch_time_seconds:0.2f} sec, "
        f"of which waiting for next minibatch took {total_load_time:0.2f} sec total. {num_batches} "
        "minibatches in total.")
    if num_load_time_exceeded > 0:
        logging.warning(
            "The dataloaders were not fast enough to always supply the next batch in less than "
            f"{MAX_ITEM_LOAD_TIME_SEC}sec.")
        logging.warning(
            f"In this epoch, {num_load_time_exceeded} out of {num_batches} batches exceeded the load time "
            f"threshold. The total loading time for the slow batches was {total_extra_load_time:0.2f}sec."
        )

    _metrics = training_steps.get_epoch_results_and_store(epoch_time_seconds) \
        if train_val_params.save_metrics else MetricsDict()
    return ModelOutputsAndMetricsForEpoch(
        metrics=_metrics,
        model_outputs=model_outputs_epoch,
        is_train=train_val_params.in_training_mode)