def validate_correction_factor(name, df, trainer_handle, model_parameter,
                               model_handle):
    """Report how well one correction factor fits ``df`` for fixed parameters.

    A trainer is built from ``trainer_handle`` for the given data, bitrate
    model handle and model parameters. The parameters it expects are derived
    with ``trainer_handle.params_from_mp``. The trainer's expected and
    predicted correction factors are then compared, and the RMSE and PC are
    printed under the heading "<name> validation results".

    Args:
        name: label of the correction factor, used in the printed heading.
        df: data the correction factor is validated on.
        trainer_handle: trainer class; must be callable as
            ``trainer_handle(df, model_handle, model_parameter)`` and provide
            ``params_from_mp``.
        model_parameter: bitrate model parameters to validate.
        model_handle: bitrate model handle forwarded to the trainer.
    """
    cf_trainer = trainer_handle(df, model_handle, model_parameter)
    cf_params = trainer_handle.params_from_mp(model_parameter)
    expected, predicted = cf_trainer.validate(cf_params)

    rmse = get_rmse(expected, predicted)
    pc = get_pc(expected, predicted)
    print_rmse_and_pc(f"{name} validation results", rmse, pc)


def train_correction_factor(name, df, trainer_handle, model_handle,
                            init_mode_parameter):
    """Fit one correction factor on ``df`` and report the fit quality.

    A trainer is built from ``trainer_handle`` with the data, the bitrate
    model handle and the initial model parameters, then trained. The trained
    parameters are immediately validated on the same data, and the RMSE and
    PC are printed under the heading "<name> training results".

    Note that, unlike ``validate_correction_factor``, this function takes
    ``model_handle`` before the parameter argument.

    Args:
        name: label of the correction factor, used in the printed heading.
        df: training data.
        trainer_handle: trainer class; must be callable as
            ``trainer_handle(df, model_handle, init_mode_parameter)``.
        model_handle: bitrate model handle forwarded to the trainer.
        init_mode_parameter: initial model parameters for the trainer.

    Returns:
        Whatever ``trainer.train()`` returned (the fitted parameters).
    """
    cf_trainer = trainer_handle(df, model_handle, init_mode_parameter)
    fitted_params = cf_trainer.train()
    expected, predicted = cf_trainer.validate(fitted_params)

    rmse = get_rmse(expected, predicted)
    pc = get_pc(expected, predicted)
    print_rmse_and_pc(f"{name} training results", rmse, pc)
    return fitted_params
# Example #3
# 0
def validate(model, dfw, input_scaler, output_scaler, data_name="test"):
    """Evaluate a trained ML bitrate model on one data split and print metrics.

    The inputs from ``dfw.get_ml_input()`` are scaled with ``input_scaler``
    and passed to ``model.evaluate`` together with the scaled expected outputs
    from ``dfw.get_ml_output()``; the returned loss/metrics are printed.
    Predictions are then mapped back with ``output_scaler.inverse_transform``.
    Three figures are printed as "Bitrate average":
      * the RMSE scaled by 0.001 (presumably a unit conversion — confirm),
      * the RMSE after dividing by ``dfw.get_ml_max_output()`` (normalized),
      * the PC of expected vs. predicted outputs.

    Args:
        model: trained model exposing ``evaluate`` and ``predict``.
        dfw: data wrapper providing the ML inputs/outputs.
        input_scaler: fitted scaler for the model inputs.
        output_scaler: fitted scaler for the model outputs.
        data_name: name of the split, used only in printed messages.
    """
    scaled_input = input_scaler.transform(dfw.get_ml_input())
    model_input = tf.convert_to_tensor(scaled_input, dtype=tf.float32)
    expected_output = dfw.get_ml_output()

    print(f"Evaluating on {data_name} data...")
    scaled_expected = output_scaler.transform(expected_output)
    results = model.evaluate(model_input, scaled_expected)
    print(f"{data_name} loss, accuracy: ", results)

    print(f"Predicting on {data_name} data...")
    predicted_output = output_scaler.inverse_transform(
        model.predict(model_input))

    max_bitrates = dfw.get_ml_max_output()
    # Transposed so that division broadcasts max_bitrates per output column.
    expected_t = expected_output.transpose()
    predicted_t = predicted_output.transpose()
    print_rmse_normalized_rmse_and_pc(
        "Bitrate average",
        get_rmse(expected_output, predicted_output) * 0.001,
        get_rmse(expected_t / max_bitrates, predicted_t / max_bitrates),
        get_pc(expected_t, predicted_t),
    )
# Example #4
# 0
def plot_bitrates(bitrate_measured, bitrate_estimated, key, decimals=-1):
    """
    Scatter-plot measured bitrates (y) against estimated bitrates (x).

    RMSE and PC are computed on the unrounded values and shown in the title.
    Only the plotted points are rounded with ``np.around`` to ``decimals``
    (default -1, i.e. to the nearest multiple of ten). This keeps the plot
    small enough for LaTeX to handle. Because the plot is shown at a very
    small size, the lost precision is not visible.

    A new pyplot figure is created and becomes the current figure; nothing
    is returned.
    """
    rmse = get_rmse(bitrate_measured, bitrate_estimated)
    pc = get_pc(bitrate_measured, bitrate_estimated)
    x_points = np.around(bitrate_estimated, decimals=decimals)
    y_points = np.around(bitrate_measured, decimals=decimals)

    plt.figure()
    plt.scatter(x_points, y_points, marker="x")
    plt.xlabel("Estimated Bitrate [kBit/s]")
    plt.ylabel("Measured Bitrate [kBit/s]")
    plt.title(
        f"Bitrate model comparison {key} (RMSE: {rmse:.2f}, PC: {pc:.5f})")


def validate(dfw, model_parameter, model_handle, full_evaluation):
    """
    Validate the bitrate model per video and all supported correction factors.

    For each key in ``dfw.video_keys``, the bitrates estimated by
    ``model_handle(model_parameter)`` are compared with the measured ones.
    The RMSE, normalized RMSE (RMSE / max measured bitrate) and PC are
    printed. Afterwards the Rmax, SCF, TCF, NCF, RCF, GCF and SDCF correction
    factors are each validated on their matching ``dfw.include(...)`` subset.
    Finally the per-video metrics are averaged and printed. The average is
    skipped (with a message) when there are no videos, instead of raising
    ``statistics.StatisticsError``.

    Args:
        dfw: data wrapper providing ``video_keys`` and ``include(...)``.
        model_parameter: parameters of the bitrate model under validation.
        model_handle: callable building a bitrate model from the parameters.
        full_evaluation: forwarded to ``get_df_for_evaluation``.
    """
    video_keys = dfw.video_keys
    bitrate_model = model_handle(model_parameter)
    # The evaluation frame does not depend on the video key, so build it once
    # rather than once per video.
    # NOTE(review): assumes validate_bitrates_for_video does not mutate it.
    df_eval = get_df_for_evaluation(dfw, model_handle, full_evaluation)
    rmses = []
    nrmses = []
    pcs = []

    for key in video_keys.tolist():
        bitrate_measured, bitrate_estimated = validate_bitrates_for_video(
            df_eval,
            bitrate_model,
            key,
        )
        rmses.append(get_rmse(bitrate_measured, bitrate_estimated))
        nrmses.append(rmses[-1] / max(bitrate_measured))
        pcs.append(get_pc(bitrate_measured.T, bitrate_estimated.T))
        print_rmse_normalized_rmse_and_pc(f"Bitrate validation {key}",
                                          rmses[-1], nrmses[-1], pcs[-1])

    # (name, dfw.include() filter flags, trainer class) per correction factor,
    # validated in this order.
    correction_factors = (
        ("Rmax", {}, RMaxTrainer),
        ("SCF", {"variable_qp": True}, SCFTrainer),
        ("TCF", {"variable_rate": True}, TCFTrainer),
        ("NCF", {"variable_gop": True}, NCFTrainer),
        ("RCF", {"variable_res": True}, RCFTrainer),
        ("GCF", {"variable_k_size": True}, GCFTrainer),
    )
    for cf_name, include_flags, trainer_handle in correction_factors:
        validate_correction_factor(cf_name, dfw.include(**include_flags),
                                   trainer_handle, model_parameter,
                                   model_handle)

    # SDCF uses the subset with variable kernel size and sigma, restricted to
    # rows whose kernel size is 3.
    df_gauss = dfw.include(variable_k_size=True, variable_sigma=True)
    validate_correction_factor(
        "SDCF",
        df_gauss.loc[df_gauss[KEYS.KSIZE] == 3],
        SDCFTrainer,
        model_parameter,
        model_handle,
    )

    if not rmses:
        # statistics.mean would raise StatisticsError on empty input.
        print("Bitrate validation average: no videos to validate")
        return
    print_rmse_normalized_rmse_and_pc(
        "Bitrate validation average",
        statistics.mean(rmses),
        statistics.mean(nrmses),
        statistics.mean(pcs),
    )