def test_transform(create_data):
    """A fitted OneHotEncoder must reproduce the expected encoding.

    The encoder is fitted on one draw from ``create_data`` and applied to a
    second draw, first with the columns in their original order and then with
    the columns at 0-based positions 2-4 moved to the end. Both results must
    equal the expected frame, i.e. the transform must not depend on column
    position.
    """
    # NOTE(review): ``create_data`` presumably is a fixture returning a
    # factory of (features, expected_encoding) pairs — confirm in conftest.
    fit_data, _ = create_data()
    encoder = OneHotEncoder().fit(fit_data)

    test_data, expected_data = create_data(165)
    tm.assert_frame_equal(encoder.transform(test_data), expected_data)

    leading, trailing, middle = (
        test_data.iloc[:, :2],
        test_data.iloc[:, 5:],
        test_data.iloc[:, 2:5],
    )
    reordered = pd.concat((leading, trailing, middle), axis=1)
    tm.assert_frame_equal(encoder.transform(reordered), expected_data)
def fit_and_prepare(x_train, y_train, test_df, event_col="specific_death"):
    """Fit a Cox proportional-hazards model and one-hot encode the test set.

    Parameters
    ----------
    x_train : pandas.DataFrame
        Training features. Columns of ``object`` or ``category`` dtype are
        treated as categorical and cast to ``category``.
    y_train : pandas.DataFrame
        Training outcome. Column ``event_col`` is cast to ``bool`` and the
        frame is converted with ``to_records(index=False)``.
        NOTE(review): presumably the columns are ordered (event, time) as
        ``CoxPHSurvivalAnalysis`` expects — confirm against the caller.
    test_df : pandas.DataFrame
        Test features. It must contain every categorical column of
        ``x_train``.
    event_col : str, default "specific_death"
        Name of the event-indicator column in ``y_train``.

    Returns
    -------
    tuple
        ``(estimator, x_test, x_train, y_train)``: the fitted
        ``CoxPHSurvivalAnalysis``, the encoded test features, the un-encoded
        training features with categorical columns cast to ``category``, and
        the structured outcome array used for fitting.

    Notes
    -----
    All inputs are copied first, so the caller's DataFrames are left
    unmodified.
    """
    # Work on copies so the dtype casts below do not leak into the caller's
    # frames (and do not trigger chained-assignment warnings on slices).
    x_train = x_train.copy()
    test_df = test_df.copy()
    y_train = y_train.copy()

    # Prepare Y: event indicator as bool, then a NumPy structured array.
    y_train[event_col] = y_train[event_col].astype(bool)
    y_train = y_train.to_records(index=False)

    # Prepare X: include columns that are already ``category`` in the
    # training data, so the test frame gets cast even when it still holds
    # them as ``object``. Both frames then present matching dtypes to the
    # encoder.
    categorical_cols = x_train.select_dtypes(
        include=["object", "category"]).columns.tolist()
    for col in categorical_cols:
        x_train[col] = x_train[col].astype('category')
        test_df[col] = test_df[col].astype('category')

    # Fit the encoder and the model on the training data.
    encoder = OneHotEncoder()
    estimator = CoxPHSurvivalAnalysis()
    estimator.fit(encoder.fit_transform(x_train), y_train)

    # Encode the test features with the encoder fitted on the training data.
    x_test = encoder.transform(test_df)

    return (estimator, x_test, x_train, y_train)
    def test_transform_other_columns(create_data):
        """Transform must reject data that lacks columns seen during fit.

        Each case renames or drops fitted columns and expects a ``ValueError``
        whose message names the missing features.
        """
        # NOTE(review): at this indentation the function is nested inside
        # fit_and_prepare, after that function's ``return``. It is therefore
        # never defined at module level and pytest will not collect it.
        # TODO: dedent to module level so the test actually runs.
        fit_data, _ = create_data()
        encoder = OneHotEncoder().fit(fit_data)
        test_data, _ = create_data(125)

        def expect_missing(frame, pattern):
            # Transforming ``frame`` must raise ValueError matching ``pattern``.
            with pytest.raises(ValueError, match=pattern):
                encoder.transform(frame)

        expect_missing(
            test_data.rename(columns={"binary_1": "renamed_1"}),
            r"1 features are missing from data: \['binary_1'\]",
        )
        expect_missing(
            test_data.drop('trinary', axis=1),
            r"1 features are missing from data: \['trinary'\]",
        )
        expect_missing(
            test_data.rename(columns={
                "binary_1": "renamed_1",
                "many": "too_many"
            }),
            r"2 features are missing from data: \['binary_1', 'many'\]",
        )