def test_amp_and_parallel_for_scalar_models(
        test_output_dirs: TestOutputDirectories,
        execution_mode: ModelExecutionMode, use_mixed_precision: bool) -> None:
    """
    Tests the mix precision flag and data parallel for scalar models.
    """
    assert machine_has_gpu, "This test must be executed on a GPU machine."
    assert torch.cuda.device_count() > 1, "This test must be executed on a multi-GPU machine."
    config = ClassificationModelForTesting()
    config.use_mixed_precision = use_mixed_precision
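    # Build a minimal scalar model with an Identity activation, so the test sees the raw network output
    # without a final non-linearity applied.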
    model = DummyScalarModel(
        expected_image_size_zyx=config.expected_image_size_zyx,
        activation=Identity())
    model.use_mixed_precision = use_mixed_precision
    model_and_info = ModelAndInfo(model=model,
                                  model_execution_mode=execution_mode)
    # This is the same logic spelled out in update_model_for_multiple_gpus:
    # execution_mode == ModelExecutionMode.TRAIN or (not use_model_parallel), which is always True in our case.
    use_data_parallel = True
    model_and_info = model_util.update_model_for_multiple_gpus(
        model_and_info, config)
    if use_data_parallel:
        assert isinstance(model_and_info.model, DataParallelModel)
    data_loaders = config.create_data_loaders()
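    # GradScaler performs the dynamic loss scaling that float16 training needs; full precision runs without it.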
    gradient_scaler = GradScaler() if use_mixed_precision else None
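    # Assemble the parameters for a single train/validate pass; summary writers are not needed for this test.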
    train_val_parameters: TrainValidateParameters = TrainValidateParameters(
        model=model_and_info.model,
        data_loader=data_loaders[execution_mode],
        in_training_mode=execution_mode == ModelExecutionMode.TRAIN,
        gradient_scaler=gradient_scaler,
        dataframe_loggers=MetricsDataframeLoggers(
            Path(test_output_dirs.root_dir)),
        summary_writers=SummaryWriters(train=None, val=None)  # type: ignore
    )
    training_steps = ModelTrainingStepsForScalarModel(config,
                                                      train_val_parameters)
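    # Take a single batch from the data loader and push it through one forward pass of the training steps.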
    sample = next(iter(data_loaders[execution_mode]))
    model_input = get_scalar_model_inputs_and_labels(config, model, sample)
    logits, posteriors, loss = training_steps._compute_model_output_and_loss(
        model_input)
    # When using DataParallel, we expect to get a list of tensors back, one per GPU.
    if use_data_parallel:
        assert isinstance(logits, list)
        first_logit = logits[0]
    else:
        first_logit = logits
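    # With autocast enabled, the forward pass should produce float16 outputs; without it, everything stays float32.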
    if use_mixed_precision:
        assert first_logit.dtype == torch.float16
        assert posteriors.dtype == torch.float16
        # BCEWithLogitsLoss outputs float32, even with float16 args
        assert loss.dtype == torch.float32
    else:
        assert first_logit.dtype == torch.float32
        assert posteriors.dtype == torch.float32
        assert loss.dtype == torch.float32
    # Verify that a second forward pass does not throw. It would, for example, if gathering the per-GPU
    # tensors fails or float16 outputs are not converted back to float32.
    _, _, _ = training_steps._compute_model_output_and_loss(model_input)


def create_model_training_steps(model_config: ModelConfigBase,
                                train_val_params: TrainValidateParameters) -> ModelTrainingStepsBase:
    """
    Create model training steps based on the model config and train/val parameters
    :param model_config: Model configs to use
    :param train_val_params: Train/Val parameters to use
    :return:
    """
    if isinstance(model_config, SegmentationModelBase):
        return ModelTrainingStepsForSegmentation(model_config, train_val_params)
    elif isinstance(model_config, ScalarModelBase):
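        # SequenceModelBase subclasses ScalarModelBase, so the more specific type must be checked first.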
        if isinstance(model_config, SequenceModelBase):
            return ModelTrainingStepsForSequenceModel(model_config, train_val_params)
        else:
            return ModelTrainingStepsForScalarModel(model_config, train_val_params)
    else:
        raise NotImplementedError(f"There are no model training steps defined for config type {type(model_config)}")
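

# A minimal usage sketch of the dispatch above (hypothetical names: `config` stands for any
# SegmentationModelBase instance, `params` for a pre-built TrainValidateParameters):
#
#     steps = create_model_training_steps(config, params)
#     assert isinstance(steps, ModelTrainingStepsForSegmentation)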