Example #1
def forward_preserve_state(module: DeviceAwareModule,
                           inputs: List[torch.Tensor]) -> torch.Tensor:
    """
    Perform forward pass on input module with given list of torch tensors. The function preserves the random state
    of the backend libraries to avoid reproducibility issues. Additionally, it temporarily sets the model in
    evaluation mode for inference and then restores its previous state.
    :param module: Callable torch module
    :param inputs: List of input torch tensors
    :return: Output torch tensors
    """
    if not isinstance(inputs, list):
        raise RuntimeError("Inputs object has to be a list of torch tensors")

    if module.is_model_on_gpu():
        inputs = [input_tensor.cuda() for input_tensor in inputs]

    # collect the current state of the model
    is_train = module.training
    module_state = RandomStateSnapshot.snapshot_random_state()

    # set the model in evaluation mode and perform a forward pass
    module.eval()
    with torch.no_grad():
        output = module.forward(*inputs)
    if is_train:
        module.train()

    # restore the seed for torch and numpy
    module_state.restore_random_state()

    return output
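
# The same preserve-and-restore pattern can be written without the project-specific helpers. The sketch below
# assumes only a plain torch.nn.Module and relies on the standard random / numpy / torch RNG state APIs; it
# illustrates the idea behind forward_preserve_state rather than reproducing the library implementation.
import random

import numpy as np
import torch


def eval_forward_with_preserved_rng(module: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
    # Capture the RNG state of all three libraries before running inference.
    python_state = random.getstate()
    numpy_state = np.random.get_state()
    torch_state = torch.random.get_rng_state()
    was_training = module.training
    module.eval()
    try:
        with torch.no_grad():
            output = module(*inputs)
    finally:
        # Restore training mode and all RNG states, even if the forward pass raised.
        if was_training:
            module.train()
        random.setstate(python_state)
        np.random.set_state(numpy_state)
        torch.random.set_rng_state(torch_state)
    return output
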
def generate_and_print_model_summary(config: ModelConfigBase, model: DeviceAwareModule) -> None:
    """
    Writes a human readable summary of the present model to logging.info, and logs the number of trainable
    parameters to AzureML.

    :param config: The configuration for the model.
    :param model: The instantiated Pytorch model.
    """
    random_state = RandomStateSnapshot.snapshot_random_state()
    # There appears to be a bug in apex, where previous use (in training for example) causes problems
    # when another model is later built on the CPU (for example, before loading from a checkpoint)
    # https://github.com/NVIDIA/apex/issues/694
    # Hence, move the model to the GPU before doing model summary.
    if config.use_gpu:
        model = model.cuda()
    if isinstance(config, ScalarModelBase):
        # To generate the model summary, read the first item of the dataset. Then use the model's own
        # get_model_input function to convert the dataset item to input tensors, and feed them through the model.
        train_dataset = config.get_torch_dataset_for_inference(ModelExecutionMode.TRAIN)
        train_item_0 = next(iter(train_dataset.as_data_loader(shuffle=False, batch_size=1, num_dataload_workers=0)))
        model_inputs = get_scalar_model_inputs_and_labels(config, model, train_item_0).model_inputs
        # The model inputs may already be converted to float16, assuming that we would do mixed precision.
        # However, the model is not yet converted to float16 when this function is called, hence convert back to float32
        summary = ModelSummary(model)
        summary.generate_summary(input_tensors=model_inputs, log_summaries_to_files=config.log_summaries_to_files)
    elif config.is_segmentation_model:
        summary_for_segmentation_models(config, model)
        assert model.summarizer
        summary = model.summarizer  # type: ignore
    else:
        raise ValueError("Don't know how to generate a summary for this type of model?")
    RUN_CONTEXT.log(LoggingColumns.NumTrainableParameters, summary.n_trainable_params)
    random_state.restore_random_state()
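
# The snapshot/restore bracket around the body of generate_and_print_model_summary is the same pattern again.
# As a purely hypothetical helper (not part of the library), the pattern could also be expressed as a context
# manager; the sketch below only covers the CPU RNG states.
import random
from contextlib import contextmanager
from typing import Iterator

import numpy as np
import torch


@contextmanager
def preserved_random_state() -> Iterator[None]:
    """Snapshot the Python, NumPy and Torch RNG states, and restore them when the block exits."""
    python_state = random.getstate()
    numpy_state = np.random.get_state()
    torch_state = torch.random.get_rng_state()
    try:
        yield
    finally:
        random.setstate(python_state)
        np.random.set_state(numpy_state)
        torch.random.set_rng_state(torch_state)


# Usage: code inside the block may consume randomness without affecting the caller's RNG sequence.
# with preserved_random_state():
#     generate_and_print_model_summary(config, model)
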
def test_random_state_snapshot() -> None:
    """
    Test get and reset all random states via RandomStateSnapshot classes.
    """
    def _get_random_ints_from_libs() -> Tuple[List[int], np.ndarray, torch.Tensor]:
        _python_random = [random.randint(0, 100) for _ in range(0, 20)]
        _numpy_random = np.random.randint(0, 100, 20)
        _torch_random = torch.randint(0, 100, (20, 1))
        return _python_random, _numpy_random, _torch_random

    # set the random state
    ml_util.set_random_seed(0)
    # take a snapshot of the random state in its original form
    random_state = RandomStateSnapshot.snapshot_random_state()
    # create random numbers using python, numpy, and torch
    original_python_random, original_numpy_random, original_torch_random = _get_random_ints_from_libs()
    # re-set the random state
    ml_util.set_random_seed(0)

    # validate that the current random state is accurately captured
    assert random.getstate() == random_state.random_state
    for i, x in enumerate(np.random.get_state()):
        assert np.array_equal(x, random_state.numpy_random_state[i])
    assert torch.equal(torch.random.get_rng_state(),
                       random_state.torch_random_state)
    assert random_state.torch_cuda_random_state is None

    # change the random state
    ml_util.set_random_seed(10)
    # create random numbers using python, numpy, and torch
    new_python_random, new_numpy_random, new_torch_random = _get_random_ints_from_libs()
    # check that a new state was used to create these random numbers
    assert not new_python_random == original_python_random
    assert not np.array_equal(new_numpy_random, original_numpy_random)
    assert not torch.equal(new_torch_random, original_torch_random)

    # restore the original random state
    random_state.restore_random_state()
    # get restored random variables
    restored_python_random, restored_numpy_random, restored_torch_random = _get_random_ints_from_libs()
    # check restored variables match the original
    assert restored_python_random == original_python_random
    assert np.array_equal(restored_numpy_random, original_numpy_random)
    assert torch.equal(restored_torch_random, original_torch_random)
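
# The test above exercises four attributes of the snapshot object: random_state, numpy_random_state,
# torch_random_state and torch_cuda_random_state (the latter being None on CPU-only machines). A minimal
# class with the same shape could look like the sketch below; this is an assumption drawn from the test,
# not the actual InnerEye implementation.
import random
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

import numpy as np
import torch


@dataclass
class RandomStateSnapshotSketch:
    random_state: Any
    numpy_random_state: Tuple[Any, ...]
    torch_random_state: torch.Tensor
    torch_cuda_random_state: Optional[List[torch.Tensor]]

    @classmethod
    def snapshot_random_state(cls) -> "RandomStateSnapshotSketch":
        # Only capture CUDA RNG state when a GPU is present, matching the None assertion in the test.
        cuda_state = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
        return cls(random_state=random.getstate(),
                   numpy_random_state=np.random.get_state(),
                   torch_random_state=torch.random.get_rng_state(),
                   torch_cuda_random_state=cuda_state)

    def restore_random_state(self) -> None:
        random.setstate(self.random_state)
        np.random.set_state(self.numpy_random_state)
        torch.random.set_rng_state(self.torch_random_state)
        if self.torch_cuda_random_state is not None:
            torch.cuda.set_rng_state_all(self.torch_cuda_random_state)
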
def on_validation_epoch_start(self) -> None:
    """
    Stores the state of all random number generators, and resets them all to a fixed seed. This is done to ensure
    that any randomization when loading validation data is consistent during training. In particular, this ensures
    that drawing random patches for segmentation model training yields a validation set that does not fluctuate.
    """
    self.val_timers.reset()
    # In Lightning, the validation epoch is running "inside" the training. If we get here, it means that training
    # is done for this epoch, even though the on_training_epoch hook has not yet been called.
    self.train_timers.epoch_end()
    # Store the random number generator state, so that the next training epoch starts from here.
    self.random_state = RandomStateSnapshot.snapshot_random_state()
    # reset the random state for validation, so that we get consistent behaviour when drawing random patches
    # when validating segmentation models.
    seed = self.effective_random_seed
    set_random_seed(seed, "Validation")
def train_or_validate_epoch(
        training_steps: ModelTrainingStepsBase
) -> ModelOutputsAndMetricsForEpoch:
    """
    Trains or validates the model for one epoch.
    :param training_steps: Training pipeline to use.
    :returns: The results for training or validation. Result type depends on the type of model that is trained.
    """
    epoch_start_time = time()
    training_random_state = None
    train_val_params = training_steps.train_val_params
    config = training_steps.model_config
    if not train_val_params.in_training_mode:
        # take the snapshot of the existing random state
        training_random_state = RandomStateSnapshot.snapshot_random_state()
        # reset the random state for validation
        ml_util.set_random_seed(config.get_effective_random_seed(),
                                "Model validation")

    status_string = "training" if train_val_params.in_training_mode else "validation"
    item_start_time = time()
    num_load_time_warnings = 0
    num_load_time_exceeded = 0
    num_batches = 0
    total_extra_load_time = 0.0
    total_load_time = 0.0
    model_outputs_epoch = []
    for batch_index, sample in enumerate(train_val_params.data_loader):
        item_finish_time = time()
        item_load_time = item_finish_time - item_start_time
        # Having slow minibatch loading is OK in the very first batch of every epoch, when worker processes
        # are spawned. Later, the load time should be close to zero.
        if batch_index == 0:
            logging.info(
                f"Loaded the first minibatch of {status_string} data in {item_load_time:0.2f} sec."
            )
        elif item_load_time > MAX_ITEM_LOAD_TIME_SEC:
            num_load_time_exceeded += 1
            total_extra_load_time += item_load_time
            if num_load_time_warnings < MAX_LOAD_TIME_WARNINGS:
                logging.warning(
                    f"Loading {status_string} minibatch {batch_index} took {item_load_time:0.2f} sec. "
                    f"This can mean that there are not enough data loader worker processes, or that there "
                    f"is a "
                    f"performance problem in loading. This warning will be printed at most "
                    f"{MAX_LOAD_TIME_WARNINGS} times.")
                num_load_time_warnings += 1
        model_outputs_minibatch = training_steps.forward_and_backward_minibatch(
            sample, batch_index, train_val_params.epoch)
        model_outputs_epoch.append(model_outputs_minibatch)
        train_finish_time = time()
        logging.debug(
            f"Epoch {train_val_params.epoch} {status_string} batch {batch_index}: "
            f"Loaded in {item_load_time:0.2f}sec, "
            f"{status_string} in {(train_finish_time - item_finish_time):0.2f}sec. "
            f"Loss = {model_outputs_minibatch.loss}")
        total_load_time += item_finish_time - item_start_time
        num_batches += 1
        item_start_time = time()

    # restore the training random state when validation has finished
    if training_random_state is not None:
        training_random_state.restore_random_state()

    epoch_time_seconds = time() - epoch_start_time
    logging.info(
        f"Epoch {train_val_params.epoch} {status_string} took {epoch_time_seconds:0.2f} sec, "
        f"of which waiting for next minibatch took {total_load_time:0.2f} sec total. {num_batches} "
        "minibatches in total.")
    if num_load_time_exceeded > 0:
        logging.warning(
            "The dataloaders were not fast enough to always supply the next batch in less than "
            f"{MAX_ITEM_LOAD_TIME_SEC}sec.")
        logging.warning(
            f"In this epoch, {num_load_time_exceeded} out of {num_batches} batches exceeded the load time "
            f"threshold. The total loading time for the slow batches was {total_extra_load_time:0.2f}sec."
        )

    _metrics = training_steps.get_epoch_results_and_store(epoch_time_seconds) \
        if train_val_params.save_metrics else MetricsDict()
    return ModelOutputsAndMetricsForEpoch(
        metrics=_metrics,
        model_outputs=model_outputs_epoch,
        is_train=train_val_params.in_training_mode)
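
# The per-batch timing in the loop above interleaves measurement with training logic. The load-time measurement
# on its own can be isolated into a small generator, sketched here with a hypothetical threshold constant that
# mirrors MAX_ITEM_LOAD_TIME_SEC; it is an illustration, not part of the library.
import logging
from time import time
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

SLOW_LOAD_THRESHOLD_SEC = 0.5  # hypothetical threshold, analogous to MAX_ITEM_LOAD_TIME_SEC


def timed_batches(data_loader: Iterable[T]) -> Iterator[T]:
    """Yield batches from data_loader, warning whenever the loader is slow to produce the next one."""
    start = time()
    for index, batch in enumerate(data_loader):
        load_time = time() - start
        # The first batch is allowed to be slow because worker processes are still being spawned.
        if index > 0 and load_time > SLOW_LOAD_THRESHOLD_SEC:
            logging.warning(f"Batch {index} took {load_time:0.2f} sec to load.")
        yield batch
        # Restart the clock only after the consumer has finished processing the yielded batch.
        start = time()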