def test_error_if_input_df_contains_na_in_transform(df_vartypes, df_na):
    """Check that ``transform`` raises ``ValueError`` for the ``df_na`` input.

    Parameters
    ----------
    df_vartypes
        Fixture used to fit the discretiser.
    df_na
        Fixture passed to ``transform`` -- presumably a dataframe holding
        missing values, going by its name; confirm against the fixture.
    """
    # np.inf rather than the np.Inf alias, which NumPy 2.0 removed.
    age_dict = {"Age": [0, 10, 20, 30, np.inf]}

    # Build and fit outside the ``raises`` block, so that only an error
    # coming from ``transform`` can satisfy the expectation.
    transformer = ArbitraryDiscretiser(binning_dict=age_dict)
    transformer.fit(df_vartypes)

    with pytest.raises(ValueError):
        transformer.transform(df_na[["Name", "City", "Age", "Marks", "dob"]])
def test_error_when_nan_introduced_during_transform():
    """Check how NaN introduced by the discretisation itself is reported.

    The bins given for ``var_b`` only span 0 to 5, whereas the dataframe
    passed to ``transform`` holds standard-normal draws in that column. The
    test expects exactly one ``UserWarning`` when ``errors="ignore"`` and a
    ``ValueError`` when ``errors="raise"``, both carrying the same message
    naming ``var_b``.
    """
    # Fixed seed so that the generated data, and therefore the outcome of the
    # test, is identical on every run.
    rng = default_rng(0)

    # Training data: one normal variable and one skewed variable whose values
    # are shifted so that the minimum is exactly zero.
    skewed = skewnorm.rvs(a=-50, loc=4, size=100, random_state=rng)
    skewed = skewed - skewed.min()

    train = pd.concat(
        [
            pd.Series(rng.standard_normal(100)),
            pd.Series(skewed),
        ],
        axis=1,
    )
    train.columns = ["var_a", "var_b"]

    # Data to transform: both variables normally distributed.
    test = pd.concat(
        [
            pd.Series(rng.standard_normal(100)),
            pd.Series(rng.standard_normal(100)),
        ],
        axis=1,
    )
    test.columns = ["var_a", "var_b"]

    msg = ("During the discretisation, NaN values were introduced "
           "in the feature(s) var_b.")

    limits_dict = {"var_a": [-5, -2, 0, 2, 5], "var_b": [0, 2, 5]}

    # errors='ignore': a warning is expected. The transformer is built and
    # fitted outside the context manager so that only warnings emitted by
    # ``transform`` are recorded.
    transformer = ArbitraryDiscretiser(binning_dict=limits_dict,
                                       errors="ignore")
    transformer.fit(train)
    with pytest.warns(UserWarning) as warning_record:
        transformer.transform(test)

    # check that only one warning was returned
    assert len(warning_record) == 1
    # check that message matches
    assert warning_record[0].message.args[0] == msg

    # errors='raise': an exception is expected, again from ``transform`` only.
    transformer = ArbitraryDiscretiser(binning_dict=limits_dict,
                                       errors="raise")
    transformer.fit(train)
    with pytest.raises(ValueError) as error_record:
        transformer.transform(test)

    # check that error message matches
    assert str(error_record.value) == msg