コード例 #1
0
 def test_get_ts_features_to_preprocess(self):
     """Checks which time-series features the JHU county model preprocesses."""
     # Each feature appears exactly once: a set literal silently collapses
     # repeated entries, so a duplicate would only suggest an extra feature
     # that the comparison never actually checks.
     expected_ts_features = {
         constants.MOBILITY_INDEX,
         constants.MOBILITY_SAMPLES,
         constants.CSRP_TESTS,
         constants.CONFIRMED_PER_CSRP_TESTS,
         constants.TOTAL_TESTS_PER_CAPITA,
         constants.AMP_RESTAURANTS,
         constants.AMP_NON_ESSENTIAL_BUSINESS,
         constants.AMP_STAY_AT_HOME,
         constants.AMP_SCHOOLS_SECONDARY_EDUCATION,
         constants.AMP_EMERGENCY_DECLARATION,
         constants.AMP_GATHERINGS,
         constants.AMP_FACE_MASKS,
         constants.DEATH_PREPROCESSED,
         constants.CONFIRMED_PREPROCESSED,
         constants.DOW_WINDOW,
         constants.VACCINATED_RATIO_FIRST_DOSE_PER_DAY_PREPROCESSED,
         constants.VACCINATED_RATIO_SECOND_DOSE_PER_DAY_PREPROCESSED,
     }
     county_model = us_model_definitions.CountyModelDefinition(
         gt_source=constants.GT_SOURCE_JHU)
     actual_ts_features = county_model.get_ts_features_to_preprocess()
     np.testing.assert_equal(expected_ts_features, actual_ts_features)
コード例 #2
0
 def test_get_all_locations(self):
     """Checks that get_all_locations returns the geo IDs present in the input."""
     input_df = pd.DataFrame(
         {constants.GEO_ID_COLUMN: ["4059", "4060", "4061", "4062"]})
     # NOTE(review): FIPS 15005 (Kalawao County, which no longer exists) is
     # presumably dropped by get_all_locations, but that ID is not part of
     # this input, so no exclusion is exercised here: every input ID is
     # expected back unchanged.
     expected_locations = {"4059", "4060", "4061", "4062"}
     county_model = us_model_definitions.CountyModelDefinition(
         gt_source=constants.GT_SOURCE_JHU)
     actual_locations = county_model.get_all_locations(input_df)
     np.testing.assert_equal(expected_locations, actual_locations)
コード例 #3
0
    def test_extract_county_static_features(self):
        """Checks static-feature extraction restricted to the given counties.

        The input also contains a row for geo_id "4057", which is not in
        `locations`; the expected dicts below contain only "4059" and "4058".
        INCOME_PER_CAPITA inputs of 100/120 are expected back as 0/1, and the
        COUNTY_POPULATION rows are expected under the POPULATION key.
        """
        static_data = pd.DataFrame([{
            "feature_name": constants.AREA,
            "feature_value": 10,
            "geo_id": "4059"
        }, {
            "feature_name": constants.AREA,
            "feature_value": 10,
            "geo_id": "4058"
        }, {
            "feature_name": constants.INCOME_PER_CAPITA,
            "feature_value": 120,
            "geo_id": "4058"
        }, {
            "feature_name": constants.INCOME_PER_CAPITA,
            "feature_value": 100,
            "geo_id": "4059"
        }, {
            "feature_name": constants.COUNTY_POPULATION,
            "feature_value": 70,
            "geo_id": "4059"
        }, {
            "feature_name": constants.COUNTY_POPULATION,
            "feature_value": 50,
            "geo_id": "4058"
        }, {
            "feature_name": constants.COUNTY_POPULATION,
            "feature_value": 10,
            "geo_id": "4057"
        }])

        county_model = us_model_definitions.CountyModelDefinition(
            gt_source="JHU")
        actual, _ = county_model._extract_static_features(
            static_data=static_data, locations=["4059", "4058"])
        expected = {
            constants.INCOME_PER_CAPITA: {
                "4059": 0,
                "4058": 1
            },
            constants.POPULATION: {
                "4059": 70,
                "4058": 50
            }
        }

        for static_feature_name, expected_values in expected.items():
            # Report a missing feature as a readable assertion failure rather
            # than a bare KeyError from the lookup below.
            self.assertIn(static_feature_name, actual)
            self.assertEqual(
                actual[static_feature_name], expected_values,
                "Unexpected value for feature %s" % static_feature_name)
コード例 #4
0
 def test_get_ts_features(self):
     """Checks the model-feature -> source-key mapping of the JHU county model."""
     # Model features that are read from a differently named source key.
     remapped = {
         constants.DEATH: constants.JHU_COUNTY_DEATH_FEATURE_KEY,
         constants.CONFIRMED: constants.JHU_COUNTY_CONFIRMED_FEATURE_KEY,
         constants.RECOVERED_DOC: constants.CSRP_RECOVERED_FEATURE_KEY,
         constants.HOSPITALIZED: constants.CHA_HOSPITALIZED_FEATURE_KEY,
         constants.HOSPITALIZED_CUMULATIVE:
             constants.CHA_HOSPITALIZED_CUMULATIVE_FEATURE_KEY,
         constants.ICU: constants.CSRP_ICU_FEATURE_KEY,
     }
     # Model features whose source key is the feature name itself.
     same_named = (
         constants.MOBILITY_INDEX,
         constants.MOBILITY_SAMPLES,
         constants.CSRP_TESTS,
         constants.AMP_RESTAURANTS,
         constants.AMP_NON_ESSENTIAL_BUSINESS,
         constants.AMP_STAY_AT_HOME,
         constants.AMP_SCHOOLS_SECONDARY_EDUCATION,
         constants.AMP_EMERGENCY_DECLARATION,
         constants.AMP_GATHERINGS,
         constants.AMP_FACE_MASKS,
         constants.DOW_WINDOW,
         constants.VACCINES_GOVEX_FIRST_DOSE_TOTAL,
         constants.VACCINES_GOVEX_SECOND_DOSE_TOTAL,
     )
     expected_ts_features = dict(remapped)
     expected_ts_features.update((name, name) for name in same_named)

     ts_features = us_model_definitions.CountyModelDefinition(
         gt_source=constants.GT_SOURCE_JHU).get_ts_features()
     np.testing.assert_equal(expected_ts_features, ts_features)
コード例 #5
0
    def test_extract_ts_county_features(self):
        """Checks time-series extraction for a single county (geo_id "4059").

        Every raw feature has one value on each of two consecutive days; for
        some features the later day is listed first. The extracted
        per-location arrays are compared against fixed float32 expectations.
        """
        geo_id = "4059"
        day_1, day_2 = "2020-01-22", "2020-01-23"
        # (feature_name, feature_value, dt) triples in the exact row order of
        # the raw input; note some features list day_2 before day_1.
        ts_rows = [
            ("confirmed_cases", 100, day_1),
            ("confirmed_cases", 200, day_2),
            ("deaths", 10, day_1),
            ("deaths", 13, day_2),
            (constants.MOBILITY_INDEX, 0.0, day_1),
            (constants.MOBILITY_INDEX, 1.0, day_2),
            (constants.MOBILITY_SAMPLES, 10, day_1),
            (constants.MOBILITY_SAMPLES, 12, day_2),
            (constants.CSRP_TESTS, 70, day_1),
            (constants.CSRP_TESTS, 140, day_2),
            (constants.AMP_GATHERINGS, 1.0, day_2),
            (constants.AMP_GATHERINGS, 1.2, day_1),
            (constants.AMP_EMERGENCY_DECLARATION, 1.0, day_2),
            (constants.AMP_EMERGENCY_DECLARATION, 1.2, day_1),
            (constants.AMP_SCHOOLS_SECONDARY_EDUCATION, 1.0, day_2),
            (constants.AMP_SCHOOLS_SECONDARY_EDUCATION, 1.2, day_1),
            (constants.AMP_RESTAURANTS, 1.0, day_2),
            (constants.AMP_RESTAURANTS, 1.2, day_1),
            (constants.AMP_NON_ESSENTIAL_BUSINESS, 1.0, day_2),
            (constants.AMP_NON_ESSENTIAL_BUSINESS, 1.2, day_1),
            (constants.AMP_STAY_AT_HOME, 1.0, day_2),
            (constants.AMP_STAY_AT_HOME, 1.2, day_1),
            (constants.AMP_FACE_MASKS, 1.0, day_2),
            (constants.AMP_FACE_MASKS, 1.2, day_1),
            (constants.CSRP_RECOVERED_FEATURE_KEY, 12, day_2),
            (constants.CSRP_RECOVERED_FEATURE_KEY, 11, day_1),
            (constants.CHA_HOSPITALIZED_FEATURE_KEY, 100, day_1),
            (constants.CHA_HOSPITALIZED_FEATURE_KEY, 200, day_2),
            (constants.CHA_HOSPITALIZED_CUMULATIVE_FEATURE_KEY, 200, day_1),
            (constants.CHA_HOSPITALIZED_CUMULATIVE_FEATURE_KEY, 300, day_2),
            (constants.CSRP_ICU_FEATURE_KEY, 20, day_1),
            (constants.CSRP_ICU_FEATURE_KEY, 30, day_2),
            (constants.VACCINES_GOVEX_FIRST_DOSE_TOTAL, 10, day_1),
            (constants.VACCINES_GOVEX_FIRST_DOSE_TOTAL, 20, day_2),
            (constants.VACCINES_GOVEX_SECOND_DOSE_TOTAL, 5, day_1),
            (constants.VACCINES_GOVEX_SECOND_DOSE_TOTAL, 10, day_2),
        ]
        ts_data = pd.DataFrame([{
            "feature_name": name,
            "feature_value": value,
            "dt": np.datetime64(day),
            "geo_id": geo_id,
        } for name, value, day in ts_rows])

        # (feature_name, feature_value, geo_id) triples; "4057" and "4058" are
        # outside the requested locations.
        static_rows = [
            (constants.AREA, 10, "4059"),
            (constants.AREA, 10, "4058"),
            (constants.INCOME_PER_CAPITA, 120, "4058"),
            (constants.INCOME_PER_CAPITA, 100, "4059"),
            (constants.COUNTY_POPULATION, 70, "4059"),
            (constants.COUNTY_POPULATION, 50, "4058"),
            (constants.COUNTY_POPULATION, 10, "4057"),
        ]
        static_data = pd.DataFrame([{
            "feature_name": name,
            "feature_value": value,
            "geo_id": location,
        } for name, value, location in static_rows])

        # This is a county-level model, hence the name (previously the
        # misleading `state_model`).
        county_model = us_model_definitions.CountyModelDefinition(
            gt_source="USAFACTS")

        static_features, _ = county_model._extract_static_features(
            static_data=static_data, locations=[geo_id])

        actual, _ = county_model._extract_ts_features(
            ts_data=ts_data,
            static_features=static_features,
            locations=[geo_id],
            training_window_size=2)

        expected = {
            constants.DEATH: {
                geo_id: np.array([10, 13], dtype="float32")
            },
            constants.CONFIRMED: {
                geo_id: np.array([100, 200], dtype="float32")
            },
            constants.MOBILITY_SAMPLES: {
                geo_id: np.array([0, 1], dtype="float32")
            },
            constants.MOBILITY_INDEX: {
                geo_id: np.array([0, 1], dtype="float32")
            },
            constants.CSRP_TESTS: {
                geo_id: np.array([0, 1], dtype="float32")
            },
            constants.RECOVERED_DOC: {
                geo_id: np.array([11, 12], dtype="float32"),
            },
            constants.HOSPITALIZED: {
                geo_id: np.array([100, 200], dtype="float32"),
            },
            constants.HOSPITALIZED_CUMULATIVE: {
                geo_id: np.array([200, 300], dtype="float32"),
            },
            constants.ICU: {
                geo_id: np.array([20, 30], dtype="float32"),
            },
            constants.TOTAL_TESTS_PER_CAPITA: {
                geo_id: np.array([0, 0], dtype="float32"),
            },
        }

        for ts_feature_name, expected_values in expected.items():
            self.assertIn(ts_feature_name, actual)
            np.testing.assert_equal(
                actual[ts_feature_name], expected_values,
                "Unexpected value for feature %s" % ts_feature_name)
コード例 #6
0
 def test_get_static_features(self):
     """The JHU-backed county model should expose exactly 51 static features."""
     expected_feature_count = 51
     model = us_model_definitions.CountyModelDefinition(
         gt_source=constants.GT_SOURCE_JHU)
     static_features = model.get_static_features()
     self.assertEqual(len(static_features), expected_feature_count)