def test_get_ts_features_to_preprocess(self):
    """Checks which time-series features the JHU county model preprocesses."""
    # Each feature appears exactly once in this set literal.
    expected_ts_features = {
        constants.MOBILITY_INDEX,
        constants.MOBILITY_SAMPLES,
        constants.CSRP_TESTS,
        constants.CONFIRMED_PER_CSRP_TESTS,
        constants.TOTAL_TESTS_PER_CAPITA,
        constants.AMP_RESTAURANTS,
        constants.AMP_NON_ESSENTIAL_BUSINESS,
        constants.AMP_STAY_AT_HOME,
        constants.AMP_SCHOOLS_SECONDARY_EDUCATION,
        constants.AMP_EMERGENCY_DECLARATION,
        constants.AMP_GATHERINGS,
        constants.AMP_FACE_MASKS,
        constants.DEATH_PREPROCESSED,
        constants.CONFIRMED_PREPROCESSED,
        constants.DOW_WINDOW,
        constants.VACCINATED_RATIO_FIRST_DOSE_PER_DAY_PREPROCESSED,
        constants.VACCINATED_RATIO_SECOND_DOSE_PER_DAY_PREPROCESSED,
    }
    county_model = us_model_definitions.CountyModelDefinition(
        gt_source=constants.GT_SOURCE_JHU)
    actual_ts_features = county_model.get_ts_features_to_preprocess()
    np.testing.assert_equal(expected_ts_features, actual_ts_features)
def test_get_all_locations(self):
    """Checks that every geo id in the input frame is returned as a location."""
    input_df = pd.DataFrame(
        {constants.GEO_ID_COLUMN: ["4059", "4060", "4061", "4062"]})
    # NOTE(review): get_all_locations presumably drops FIPS 15005 (Kalawao
    # County, which no longer exists). This input does not contain that id, so
    # the exclusion is not exercised here; consider adding a case for it.
    expected_locations = {"4059", "4060", "4061", "4062"}
    county_model = us_model_definitions.CountyModelDefinition(
        gt_source=constants.GT_SOURCE_JHU)
    actual_locations = county_model.get_all_locations(input_df)
    np.testing.assert_equal(expected_locations, actual_locations)
def test_extract_county_static_features(self):
    """Verifies static-feature extraction restricted to two county geo ids."""
    # (feature_name, feature_value, geo_id). Geo "4057" is not among the
    # requested locations below.
    static_rows = [
        (constants.AREA, 10, "4059"),
        (constants.AREA, 10, "4058"),
        (constants.INCOME_PER_CAPITA, 120, "4058"),
        (constants.INCOME_PER_CAPITA, 100, "4059"),
        (constants.COUNTY_POPULATION, 70, "4059"),
        (constants.COUNTY_POPULATION, 50, "4058"),
        (constants.COUNTY_POPULATION, 10, "4057"),
    ]
    static_data = pd.DataFrame([{
        "feature_name": feature_name,
        "feature_value": feature_value,
        "geo_id": geo_id
    } for feature_name, feature_value, geo_id in static_rows])

    model = us_model_definitions.CountyModelDefinition(gt_source="JHU")
    extracted, _ = model._extract_static_features(
        static_data=static_data, locations=["4059", "4058"])

    # Income is expected rescaled (lowest value -> 0, highest -> 1), while
    # POPULATION is expected to carry the raw COUNTY_POPULATION values.
    expected = {
        constants.INCOME_PER_CAPITA: {"4059": 0, "4058": 1},
        constants.POPULATION: {"4059": 70, "4058": 50},
    }
    for feature_name, expected_values in expected.items():
        self.assertEqual(extracted[feature_name], expected_values,
                         "Unexpected value for feature %s" % feature_name)
def test_get_ts_features(self):
    """Checks the feature-name -> source-key mapping of the JHU county model."""
    # Features mapped to a separately named source-key constant.
    expected_ts_features = {
        constants.DEATH: constants.JHU_COUNTY_DEATH_FEATURE_KEY,
        constants.CONFIRMED: constants.JHU_COUNTY_CONFIRMED_FEATURE_KEY,
        constants.RECOVERED_DOC: constants.CSRP_RECOVERED_FEATURE_KEY,
        constants.HOSPITALIZED: constants.CHA_HOSPITALIZED_FEATURE_KEY,
        constants.HOSPITALIZED_CUMULATIVE: (
            constants.CHA_HOSPITALIZED_CUMULATIVE_FEATURE_KEY),
        constants.ICU: constants.CSRP_ICU_FEATURE_KEY,
    }
    # Features whose expected source key is the feature constant itself.
    identity_features = (
        constants.MOBILITY_INDEX,
        constants.MOBILITY_SAMPLES,
        constants.CSRP_TESTS,
        constants.AMP_RESTAURANTS,
        constants.AMP_NON_ESSENTIAL_BUSINESS,
        constants.AMP_STAY_AT_HOME,
        constants.AMP_SCHOOLS_SECONDARY_EDUCATION,
        constants.AMP_EMERGENCY_DECLARATION,
        constants.AMP_GATHERINGS,
        constants.AMP_FACE_MASKS,
        constants.DOW_WINDOW,
        constants.VACCINES_GOVEX_FIRST_DOSE_TOTAL,
        constants.VACCINES_GOVEX_SECOND_DOSE_TOTAL,
    )
    expected_ts_features.update(
        (feature, feature) for feature in identity_features)

    model = us_model_definitions.CountyModelDefinition(
        gt_source=constants.GT_SOURCE_JHU)
    np.testing.assert_equal(expected_ts_features, model.get_ts_features())
def test_extract_ts_county_features(self):
    """Checks time-series extraction for one county over a two-day window."""
    # (feature_name, feature_value, day) for geo "4059". Some features list the
    # later day first, so extraction must not depend on input row order.
    ts_rows = [
        ("confirmed_cases", 100, "2020-01-22"),
        ("confirmed_cases", 200, "2020-01-23"),
        ("deaths", 10, "2020-01-22"),
        ("deaths", 13, "2020-01-23"),
        (constants.MOBILITY_INDEX, 0.0, "2020-01-22"),
        (constants.MOBILITY_INDEX, 1.0, "2020-01-23"),
        (constants.MOBILITY_SAMPLES, 10, "2020-01-22"),
        (constants.MOBILITY_SAMPLES, 12, "2020-01-23"),
        (constants.CSRP_TESTS, 70, "2020-01-22"),
        (constants.CSRP_TESTS, 140, "2020-01-23"),
        (constants.AMP_GATHERINGS, 1.0, "2020-01-23"),
        (constants.AMP_GATHERINGS, 1.2, "2020-01-22"),
        (constants.AMP_EMERGENCY_DECLARATION, 1.0, "2020-01-23"),
        (constants.AMP_EMERGENCY_DECLARATION, 1.2, "2020-01-22"),
        (constants.AMP_SCHOOLS_SECONDARY_EDUCATION, 1.0, "2020-01-23"),
        (constants.AMP_SCHOOLS_SECONDARY_EDUCATION, 1.2, "2020-01-22"),
        (constants.AMP_RESTAURANTS, 1.0, "2020-01-23"),
        (constants.AMP_RESTAURANTS, 1.2, "2020-01-22"),
        (constants.AMP_NON_ESSENTIAL_BUSINESS, 1.0, "2020-01-23"),
        (constants.AMP_NON_ESSENTIAL_BUSINESS, 1.2, "2020-01-22"),
        (constants.AMP_STAY_AT_HOME, 1.0, "2020-01-23"),
        (constants.AMP_STAY_AT_HOME, 1.2, "2020-01-22"),
        (constants.AMP_FACE_MASKS, 1.0, "2020-01-23"),
        (constants.AMP_FACE_MASKS, 1.2, "2020-01-22"),
        (constants.CSRP_RECOVERED_FEATURE_KEY, 12, "2020-01-23"),
        (constants.CSRP_RECOVERED_FEATURE_KEY, 11, "2020-01-22"),
        (constants.CHA_HOSPITALIZED_FEATURE_KEY, 100, "2020-01-22"),
        (constants.CHA_HOSPITALIZED_FEATURE_KEY, 200, "2020-01-23"),
        (constants.CHA_HOSPITALIZED_CUMULATIVE_FEATURE_KEY, 200, "2020-01-22"),
        (constants.CHA_HOSPITALIZED_CUMULATIVE_FEATURE_KEY, 300, "2020-01-23"),
        (constants.CSRP_ICU_FEATURE_KEY, 20, "2020-01-22"),
        (constants.CSRP_ICU_FEATURE_KEY, 30, "2020-01-23"),
        (constants.VACCINES_GOVEX_FIRST_DOSE_TOTAL, 10, "2020-01-22"),
        (constants.VACCINES_GOVEX_FIRST_DOSE_TOTAL, 20, "2020-01-23"),
        (constants.VACCINES_GOVEX_SECOND_DOSE_TOTAL, 5, "2020-01-22"),
        (constants.VACCINES_GOVEX_SECOND_DOSE_TOTAL, 10, "2020-01-23"),
    ]
    ts_data = pd.DataFrame([{
        "feature_name": feature_name,
        "feature_value": feature_value,
        "dt": np.datetime64(day),
        "geo_id": "4059"
    } for feature_name, feature_value, day in ts_rows])

    # (feature_name, feature_value, geo_id). Only "4059" is extracted below.
    static_rows = [
        (constants.AREA, 10, "4059"),
        (constants.AREA, 10, "4058"),
        (constants.INCOME_PER_CAPITA, 120, "4058"),
        (constants.INCOME_PER_CAPITA, 100, "4059"),
        (constants.COUNTY_POPULATION, 70, "4059"),
        (constants.COUNTY_POPULATION, 50, "4058"),
        (constants.COUNTY_POPULATION, 10, "4057"),
    ]
    static_data = pd.DataFrame([{
        "feature_name": feature_name,
        "feature_value": feature_value,
        "geo_id": geo_id
    } for feature_name, feature_value, geo_id in static_rows])

    # This is a county model, so the local is named accordingly (it was
    # misleadingly called `state_model`).
    county_model = us_model_definitions.CountyModelDefinition(
        gt_source="USAFACTS")
    static_features, _ = county_model._extract_static_features(
        static_data=static_data, locations=["4059"])
    actual, _ = county_model._extract_ts_features(
        ts_data=ts_data,
        static_features=static_features,
        locations=["4059"],
        training_window_size=2)

    # MOBILITY_SAMPLES / MOBILITY_INDEX / CSRP_TESTS are expected as [0, 1],
    # i.e. rescaled, while counts such as deaths and confirmed stay raw.
    expected = {
        constants.DEATH: {
            "4059": np.array([10, 13], dtype="float32")
        },
        constants.CONFIRMED: {
            "4059": np.array([100, 200], dtype="float32")
        },
        constants.MOBILITY_SAMPLES: {
            "4059": np.array([0, 1], dtype="float32")
        },
        constants.MOBILITY_INDEX: {
            "4059": np.array([0, 1], dtype="float32")
        },
        constants.CSRP_TESTS: {
            "4059": np.array([0, 1], dtype="float32")
        },
        constants.RECOVERED_DOC: {
            "4059": np.array([11, 12], dtype="float32"),
        },
        constants.HOSPITALIZED: {
            "4059": np.array([100, 200], dtype="float32"),
        },
        constants.HOSPITALIZED_CUMULATIVE: {
            "4059": np.array([200, 300], dtype="float32"),
        },
        constants.ICU: {
            "4059": np.array([20, 30], dtype="float32"),
        },
        constants.TOTAL_TESTS_PER_CAPITA: {
            "4059": np.array([0, 0], dtype="float32"),
        },
    }
    for ts_feature_name in expected:
        self.assertIn(ts_feature_name, actual)
        np.testing.assert_equal(
            actual[ts_feature_name], expected[ts_feature_name],
            "Unexpected value for feature %s" % ts_feature_name)
def test_get_static_features(self):
    """Checks how many static features the JHU county model declares."""
    # NOTE(review): 51 is a hard-coded expected count; it must be kept in sync
    # with CountyModelDefinition.get_static_features().
    expected_num_static_features = 51
    static_features = us_model_definitions.CountyModelDefinition(
        gt_source=constants.GT_SOURCE_JHU).get_static_features()
    self.assertEqual(len(static_features), expected_num_static_features)