def temperature_scaling_steps(config: SequenceModelBase,
                              train_val_params: TrainValidateParameters,
                              val_results_for_epoch: ModelOutputsAndMetricsForEpoch) -> \
        Tuple[float, ModelOutputsAndMetricsForEpoch]:
    """
    Performs the two steps required for temperature scaling:
    1) Learn the temperature parameter on the logits and labels of the provided validation epoch.
    2) Re-run the validation epoch with the learned scaling parameter in place.

    :param config: Config for a sequence model.
    :param train_val_params: Train/Validate parameters to use.
    :param val_results_for_epoch: Results from the validation epoch to use for temperature scaling.
    :return: The optimal temperature value and the validation results after scaling has been applied.
    """
    # re-create the training steps for the repeat pass, but with metrics saving enabled
    train_val_params.save_metrics = True
    training_steps = create_model_training_steps(config, train_val_params)
    assert isinstance(training_steps, ModelTrainingStepsForSequenceModel)
    # make sure results for a validation epoch have been passed in
    assert val_results_for_epoch.is_train is False
    # perform temperature scaling
    logits = val_results_for_epoch.get_logits()
    labels = val_results_for_epoch.get_labels()
    temperature_value = training_steps.learn_temperature_scale_parameter(
        logits, labels)
    # recompute the validation set results for the temperature scaled model
    val_epoch_results = train_or_validate_epoch(training_steps)

    return temperature_value, val_epoch_results
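

# For reference, a minimal sketch of what "learning a temperature scale parameter"
# typically means (Guo et al. 2017): a single scalar T > 0 is learned that divides the
# logits so that the negative log-likelihood on held-out validation data is minimised.
# This is an illustrative, self-contained example only; the actual optimisation lives in
# ModelTrainingStepsForSequenceModel.learn_temperature_scale_parameter, whose details may
# differ. The helper name below is hypothetical, and labels are assumed to be class indices.
def _example_learn_temperature(logits: "torch.Tensor", labels: "torch.Tensor") -> float:
    import torch  # local import so the sketch is self-contained

    # One scalar temperature, initialised to 1 (i.e. no scaling).
    temperature = torch.nn.Parameter(torch.ones(1))
    # LBFGS converges quickly for this single-parameter problem.
    optimizer = torch.optim.LBFGS([temperature], lr=0.01, max_iter=50)
    loss_fn = torch.nn.CrossEntropyLoss()

    def closure() -> "torch.Tensor":
        optimizer.zero_grad()
        # Dividing logits by T leaves the argmax (and hence accuracy) unchanged,
        # but reshapes the softmax confidences.
        loss = loss_fn(logits / temperature, labels.long())
        loss.backward()
        return loss

    optimizer.step(closure)
    return temperature.item()

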
def train_or_validate_epoch(
        training_steps: ModelTrainingStepsBase
) -> ModelOutputsAndMetricsForEpoch:
    """
    Trains or validates the model for one epoch.
    :param training_steps: Training pipeline to use.
    :return: The results of training or validation. The result type depends on the type of model being trained.
    """
    epoch_start_time = time()
    training_random_state = None
    train_val_params = training_steps.train_val_params
    config = training_steps.model_config
    if not train_val_params.in_training_mode:
        # take a snapshot of the current random state
        training_random_state = RandomStateSnapshot.snapshot_random_state()
        # reset the random state for validation
        ml_util.set_random_seed(config.get_effective_random_seed(),
                                "Model validation")

    status_string = "training" if train_val_params.in_training_mode else "validation"
    item_start_time = time()
    num_load_time_warnings = 0
    num_load_time_exceeded = 0
    num_batches = 0
    total_extra_load_time = 0.0
    total_load_time = 0.0
    model_outputs_epoch = []
    for batch_index, sample in enumerate(train_val_params.data_loader):
        item_finish_time = time()
        item_load_time = item_finish_time - item_start_time
        # Slow minibatch loading is acceptable for the very first batch of every epoch,
        # when the data loader worker processes are spawned. After that, the load time
        # should be negligible.
        if batch_index == 0:
            logging.info(
                f"Loaded the first minibatch of {status_string} data in {item_load_time:0.2f} sec."
            )
        elif item_load_time > MAX_ITEM_LOAD_TIME_SEC:
            num_load_time_exceeded += 1
            total_extra_load_time += item_load_time
            if num_load_time_warnings < MAX_LOAD_TIME_WARNINGS:
                logging.warning(
                    f"Loading {status_string} minibatch {batch_index} took {item_load_time:0.2f} sec. "
                    f"This can mean that there are not enough data loader worker processes, or that "
                    f"there is a performance problem in loading. This warning will be printed at most "
                    f"{MAX_LOAD_TIME_WARNINGS} times.")
                num_load_time_warnings += 1
        model_outputs_minibatch = training_steps.forward_and_backward_minibatch(
            sample, batch_index, train_val_params.epoch)
        model_outputs_epoch.append(model_outputs_minibatch)
        train_finish_time = time()
        logging.debug(
            f"Epoch {train_val_params.epoch} {status_string} batch {batch_index}: "
            f"Loaded in {item_load_time:0.2f} sec, "
            f"{status_string} in {(train_finish_time - item_finish_time):0.2f} sec. "
            f"Loss = {model_outputs_minibatch.loss}")
        total_load_time += item_load_time
        num_batches += 1
        item_start_time = time()

    # restore the training random state when validation has finished
    if training_random_state is not None:
        training_random_state.restore_random_state()

    epoch_time_seconds = time() - epoch_start_time
    logging.info(
        f"Epoch {train_val_params.epoch} {status_string} took {epoch_time_seconds:0.2f} sec, "
        f"of which waiting for next minibatch took {total_load_time:0.2f} sec total. {num_batches} "
        "minibatches in total.")
    if num_load_time_exceeded > 0:
        logging.warning(
            "The dataloaders were not fast enough to always supply the next batch in less than "
            f"{MAX_ITEM_LOAD_TIME_SEC}sec.")
        logging.warning(
            f"In this epoch, {num_load_time_exceeded} out of {num_batches} batches exceeded the load time "
            f"threshold. The total loading time for the slow batches was {total_extra_load_time:0.2f}sec."
        )

    _metrics = training_steps.get_epoch_results_and_store(epoch_time_seconds) \
        if train_val_params.save_metrics else MetricsDict()
    return ModelOutputsAndMetricsForEpoch(
        metrics=_metrics,
        model_outputs=model_outputs_epoch,
        is_train=train_val_params.in_training_mode)
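

# A minimal sketch of the snapshot/restore pattern used above: validation runs with a
# fixed seed so it is reproducible, while the training RNG stream must continue
# undisturbed afterwards. RandomStateSnapshot is assumed to capture the states of
# Python's `random`, numpy and torch; the standalone function below (hypothetical name)
# shows one way such a round-trip can be written.
def _example_random_state_roundtrip() -> None:
    import random

    import numpy as np
    import torch

    # Capture the current state of all three RNGs.
    python_state = random.getstate()
    numpy_state = np.random.get_state()
    torch_state = torch.random.get_rng_state()

    # Re-seed deterministically, e.g. for a validation pass.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    _ = torch.rand(3)  # validation-time randomness, identical on every call

    # Restore the captured states so training continues exactly where it left off.
    random.setstate(python_state)
    np.random.set_state(numpy_state)
    torch.random.set_rng_state(torch_state)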