def test_dedup_holiday_dict():
    """Tests dedup_holiday_dict.

    Checks that merging the per-country holiday DataFrames yields no
    duplicated (date, label) pairs, and that every holiday of every
    country is present in the merged result.
    """
    countries = ["UnitedStates", "UnitedKingdom", "India",
                 "France", "Ireland"]
    year_start = 2019
    year_end = 2020
    key_cols = [EVENT_DF_DATE_COL, EVENT_DF_LABEL_COL]
    # retrieves separate DataFrame for each country, with list of holidays
    holidays_dict = get_holidays(
        countries,
        year_start=year_start,
        year_end=year_end)
    # merges country DataFrames, removes duplicate holidays
    holiday_df = dedup_holiday_dict(holidays_dict)
    # Checking on the key columns (not whole rows) ensures each
    # (date, label) pair appears at most once in the merged result.
    assert not holiday_df.duplicated(subset=key_cols).any()

    # ensure all country holidays are included
    for country, country_df in holidays_dict.items():
        # A left merge with an indicator marks every country row that has no
        # match in ``holiday_df``. Unlike comparing inner-merge row counts,
        # this cannot be fooled by duplicate keys compensating for missing rows.
        joined = country_df.merge(
            holiday_df[key_cols],
            on=key_cols,
            how="left",
            indicator=True)
        missing = joined[joined["_merge"] != "both"]
        assert missing.empty, (
            f"{country} holidays missing from deduplicated holidays:\n{missing}")


def test_generate_holiday_events2():
    """Tests proper handling of pre_num = 0 and post_num = 0.

    With no pre/post window, the result of ``generate_holiday_events`` is
    expected to match running ``get_holidays`` -> ``dedup_holiday_dict`` ->
    ``split_events_into_dictionaries`` directly: the same event keys (no
    extra window events) and the same number of rows for each event.
    """
    countries = ["UnitedStates", "UnitedKingdom", "India", "France"]
    year_start = 2019
    year_end = 2020
    holidays_to_model_separately = [
        "New Year's Day",
        "Christmas Day",
        "Independence Day",
        "Thanksgiving",
        "Labor Day",
        "Good Friday",
        "Easter Monday [England, Wales, Northern Ireland]",
        "Memorial Day",
        "Veterans Day"]

    daily_event_df_dict1 = generate_holiday_events(
        countries=countries,
        holidays_to_model_separately=holidays_to_model_separately,
        year_start=year_start,
        year_end=year_end,
        pre_num=0,
        post_num=0)

    holidays_dict = get_holidays(
        countries,
        year_start=year_start,
        year_end=year_end)
    # merges country DataFrames, removes duplicate holidays
    holiday_df = dedup_holiday_dict(holidays_dict)
    # creates separate DataFrame for each holiday
    daily_event_df_dict2 = split_events_into_dictionaries(
        holiday_df,
        holidays_to_model_separately)

    # Same event keys: a zero-width window must not add shifted events.
    # Comparing as sets is equivalent to comparing ``.keys()`` views and
    # gives a readable diff on failure.
    assert set(daily_event_df_dict1) == set(daily_event_df_dict2)
    # Same number of rows per event: no rows were added or dropped.
    for event_name, event_df in daily_event_df_dict2.items():
        assert len(daily_event_df_dict1[event_name]) == len(event_df), event_name