コード例 #1
0
 def test_get_ts_features_to_preprocess(self):
     """Checks the exact set returned by `get_ts_features_to_preprocess`.

     The expected members are the GOOGLE_MOBILITY_* constants, TOTAL_TESTS,
     the JAPAN_PREFECTURE_* time-series keys listed below, and DOW_WINDOW.
     """
     expected_ts_features = {
         constants.GOOGLE_MOBILITY_PARKS,
         constants.GOOGLE_MOBILITY_WORK,
         constants.GOOGLE_MOBILITY_RES,
         constants.GOOGLE_MOBILITY_TRANSIT,
         constants.GOOGLE_MOBILITY_GROCERY,
         constants.GOOGLE_MOBILITY_RETAIL,
         constants.TOTAL_TESTS,
         constants.JAPAN_PREFECTURE_STATE_OF_EMERGENCY_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_DISCHARGED_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_HOSPITALIZED_FEATURE_KEY,
         constants.
         JAPAN_PREFECTURE_EFFECTIVE_REPRODUCTIVE_NUMBER_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_COVID_LIKE_ILLNESS_SURVEY_FEATURE_KEY,
         constants.
         JAPAN_PREFECTURE_COVID_LIKE_ILLNESS_UNWEIGHTED_SURVEY_FEATURE_KEY,
         constants.
         JAPAN_PREFECTURE_COVID_LIKE_ILLNESS_PERCENT_SURVEY_FEATURE_KEY,
         constants.
         JAPAN_PREFECTURE_COVID_LIKE_ILLNESS_PERCENT_UNWEIGHTED_SURVEY_FEATURE_KEY,
         constants.DOW_WINDOW,
     }
     japan_model = japan_model_definitions.PrefectureModelDefinition()
     actual_ts_features = japan_model.get_ts_features_to_preprocess()
     # assertEqual compares with `==` and, when both sides are sets, delegates
     # to assertSetEqual so a failure lists the missing/unexpected members
     # with "expected" and "actual" on the correct sides.
     self.assertEqual(expected_ts_features, actual_ts_features)
コード例 #2
0
    def test_extract_japan_prefecture_ts_prefectures(self):
        """Smoke-tests time-series extraction for a single prefecture id.

        One dummy row is built per value of `get_ts_features()`, all on the
        same date and location. The test only asserts that the extraction
        result is not None.
        """
        geo_id = "4059"
        model = japan_model_definitions.PrefectureModelDefinition()

        # One constant-valued row for every feature name the model maps to.
        ts_rows = [
            {
                "feature_name": feature_name,
                "feature_value": 100,
                "dt": np.datetime64("2020-01-22"),
                "geo_id": geo_id,
            }
            for feature_name in model.get_ts_features().values()
        ]
        ts_data = pd.DataFrame(ts_rows)

        # A population value for the same location; the extracted static
        # features are handed to the time-series extraction below.
        static_data = pd.DataFrame([{
            "feature_name": constants.POPULATION,
            "feature_value": 120,
            "geo_id": geo_id,
        }])
        static_features, _ = model._extract_static_features(
            static_data=static_data, locations=[geo_id])

        extracted, _ = model._extract_ts_features(
            ts_data=ts_data,
            static_features=static_features,
            locations=[geo_id],
            training_window_size=2)
        self.assertIsNotNone(extracted)
コード例 #3
0
 def test_get_ts_features(self):
     """Checks the full mapping returned by `get_ts_features`.

     The expected mapping is built in two parts: keys whose value is a
     different constant, and keys that are expected to map to themselves.
     """
     # Keys whose expected value is a different constant.
     remapped = {
         constants.DEATH:
             constants.JAPAN_PREFECTURE_DEATH_FEATURE_KEY,
         constants.CONFIRMED:
             constants.JAPAN_PREFECTURE_CONFIRMED_FEATURE_KEY,
         constants.RECOVERED_DOC:
             constants.JAPAN_PREFECTURE_DISCHARGED_FEATURE_KEY,
         constants.HOSPITALIZED:
             constants.JAPAN_PREFECTURE_HOSPITALIZED_FEATURE_KEY,
         constants.GOOGLE_MOBILITY_PARKS:
             constants.JAPAN_PREFECTURE_MOBILITY_PARKS_PERCENT_FROM_BASELINE_FEATURE_KEY,
         constants.GOOGLE_MOBILITY_WORK:
             constants.JAPAN_PREFECTURE_MOBILITY_WORKPLACE_PERCENT_FROM_BASELINE_FEATURE_KEY,
         constants.GOOGLE_MOBILITY_RES:
             constants.JAPAN_PREFECTURE_MOBILITY_RESIDENTIAL_PERCENT_FROM_BASELINE_FEATURE_KEY,
         constants.GOOGLE_MOBILITY_TRANSIT:
             constants.JAPAN_PREFECTURE_MOBILITY_TRAIN_STATION_PERCENT_FROM_BASELINE_FEATURE_KEY,
         constants.GOOGLE_MOBILITY_GROCERY:
             constants.JAPAN_PREFECTURE_MOBILITY_GROCERY_AND_PHARMACY_PERCENT_FROM_BASELINE_FEATURE_KEY,
         constants.GOOGLE_MOBILITY_RETAIL:
             constants.JAPAN_PREFECTURE_MOBILITY_RETAIL_AND_RECREATION_PERCENT_FROM_BASELINE_FEATURE_KEY,
         constants.TOTAL_TESTS:
             constants.JAPAN_PREFECTURE_TESTED_FEATURE_KEY,
     }
     # Keys expected to map to themselves. They are applied after `remapped`
     # so that, as in a single dict literal, later entries take precedence.
     self_mapped = (
         constants.JAPAN_PREFECTURE_STATE_OF_EMERGENCY_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_DISCHARGED_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_HOSPITALIZED_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_EFFECTIVE_REPRODUCTIVE_NUMBER_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_COVID_LIKE_ILLNESS_SURVEY_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_COVID_LIKE_ILLNESS_UNWEIGHTED_SURVEY_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_COVID_LIKE_ILLNESS_PERCENT_SURVEY_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_COVID_LIKE_ILLNESS_PERCENT_UNWEIGHTED_SURVEY_FEATURE_KEY,
         constants.DOW_WINDOW,
     )
     expected = dict(remapped)
     expected.update((name, name) for name in self_mapped)

     model = japan_model_definitions.PrefectureModelDefinition()
     self.assertDictEqual(expected, model.get_ts_features())
コード例 #4
0
    def test_extract_japan_prefecture_static_features(self):
        """Smoke-tests static feature extraction on a one-row frame.

        Only asserts that the extraction result is not None.
        """
        income_row = {
            "feature_name": constants.INCOME_PER_CAPITA,
            "feature_value": 120,
            "geo_id": "4058",
        }
        static_data = pd.DataFrame([income_row])

        model = japan_model_definitions.PrefectureModelDefinition()
        # NOTE(review): the requested locations ("US", "IR") do not include
        # the only geo_id present in `static_data` ("4058"); confirm this
        # mismatch is intentional.
        extracted, _ = model._extract_static_features(
            static_data=static_data, locations=["US", "IR"])
        # TODO(joelshor): Add actual checks.
        self.assertIsNotNone(extracted)
コード例 #5
0
  def test_japan_sanity(self):
    """End-to-end smoke test of `TfSeir.fit_forecast_fixed` for Japan.

    Builds correctly shaped but meaningless inputs for a single location and
    runs one fitting iteration. There is no assertion: the test passes as
    long as the call completes without raising.
    """
    # Create the basic model.
    model_definition = japan_model_definitions.PrefectureModelDefinition()
    tf_seir_model = tf_seir.TfSeir(
        model_type="TIME_VARYING_WITH_COVARIATES",
        location_granularity="JAPAN_PREFECTURE",
        model_definition=model_definition,
        covariate_delay=0,
        random_seed=1,
        fine_tuning_steps=1)

    # Generate some data that is shaped correctly but is nonsense. The window
    # and horizon lengths are named once and reused in the fit call below so
    # the series length and the fit arguments cannot drift apart.
    train_window_size = 90
    num_forecast_steps = 7
    full_sim_steps = train_window_size + num_forecast_steps
    gt_confirmed = np.arange(1, full_sim_steps + 1)

    model_spec = model_definition.get_model_spec(
        constants.MODEL_TYPE_TIME_VARYING_WITH_COVARIATES)

    # Format the data as the fit pipeline expects.
    required_ts_constants = [
        constants.CONFIRMED,
        constants.INFECTED,
        constants.RECOVERED_DOC,
        constants.HOSPITALIZED,
        constants.HOSPITALIZED_CUMULATIVE,
        constants.HOSPITALIZED_INCREASE,
        constants.DEATH,
        constants.ICU,
        constants.VENTILATOR,
    ] + model_spec.covariate_names
    required_static_constants = [
        constants.POPULATION,
        constants.DENSITY,
        constants.JAPAN_PREFECTURE_AGE_0_TO_14_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_AGE_15_TO_64_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_AGE_64_PLUS_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_AGE_75_PLUS_FEATURE_KEY,
        constants.INCOME_PER_CAPITA,
        constants.JAPAN_PREFECTURE_NUM_DOCTORS_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_DOCTORS_PER_100K_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_NUM_HOSPITAL_BEDS_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_NUM_HOSPITAL_BEDS_PER_100K_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_NUM_CLINIC_BEDS_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_NUM_CLINIC_BEDS_PER_100K_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_NUM_NEW_ICU_BEDS_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_H1N1_in_2010_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_ALCOHOL_INTAKE_SCORE_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_BMI_MALE_AVERAGE_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_BMI_FEMALE_LOWER_RANGE_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_SMOKERS_MALE_FEATURE_KEY,
        constants.JAPAN_PREFECTURE_SMOKERS_FEMALE_FEATURE_KEY,
    ]
    # Single location id, used both as the feature-dict key and in
    # `locations` so the two stay in sync.
    jp_ = "JAPAN_PREFECTURE"
    # Every time-series feature reuses the same `gt_confirmed` array object.
    ts_features = {c: {jp_: gt_confirmed} for c in required_ts_constants}
    static_features = {c: {jp_: 0.0} for c in required_static_constants}
    # Sanity check the fit forecast function.
    # TODO(joelshor): Consider using `fit_forecast_moving_window`, which is what
    # is actually used in `fit_forecast_pipeline`.
    with patch("tensorflow.function", lambda func: func):
      # Tests with @tf.function turned into no-op to save time.
      tf_seir_model.fit_forecast_fixed(
          train_window_end_index=train_window_size,
          train_window_end_date=parser.parse("4/3/2020 00:00 UTC"),
          num_forecast_steps=num_forecast_steps,
          num_train_forecast_steps=7,
          static_features=static_features,
          static_overrides=None,
          ts_features=ts_features,
          ts_overrides=None,
          ts_categorical_features=None,
          ts_state_features=None,
          locations=[jp_],
          num_iterations=1,  # execute quickly, not converge.
          display_iterations=100,
          optimization="RMSprop",
          training_data_generator=False,
          quantile_regression=False,
          static_scalers=None,
          ts_scalers=None,
          ts_state_scalers=None)
コード例 #6
0
 def test_get_static_features(self):
     """Checks the full mapping returned by `get_static_features`.

     Entries are grouped as population/demographics, hospital resources and
     wellness/health, following the section comments below.
     """
     expected_static_features = {
         # Population and demographics.
         constants.POPULATION:
         constants.JAPAN_PREFECTURE_NUM_PEOPLE_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_NUM_MALE_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_NUM_MALE_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_NUM_FEMALE_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_NUM_FEMALE_FEATURE_KEY,
         constants.DENSITY:
         constants.JAPAN_PREFECTURE_POPULATION_DENSITY_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_AGE_0_TO_14_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_AGE_0_TO_14_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_AGE_15_TO_64_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_AGE_15_TO_64_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_AGE_64_PLUS_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_AGE_64_PLUS_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_AGE_75_PLUS_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_AGE_75_PLUS_FEATURE_KEY,
         constants.INCOME_PER_CAPITA:
         constants.JAPAN_PREFECTURE_GDP_PER_CAPITA_FEATURE_KEY,
         # Hospital resources.
         constants.JAPAN_PREFECTURE_NUM_DOCTORS_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_NUM_DOCTORS_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_DOCTORS_PER_100K_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_DOCTORS_PER_100K_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_NUM_HOSPITAL_BEDS_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_NUM_HOSPITAL_BEDS_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_NUM_HOSPITAL_BEDS_PER_100K_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_NUM_HOSPITAL_BEDS_PER_100K_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_NUM_CLINIC_BEDS_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_NUM_CLINIC_BEDS_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_NUM_CLINIC_BEDS_PER_100K_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_NUM_CLINIC_BEDS_PER_100K_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_NUM_NEW_ICU_BEDS_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_NUM_NEW_ICU_BEDS_FEATURE_KEY,
         # Wellness and health.
         constants.JAPAN_PREFECTURE_H1N1_in_2010_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_H1N1_in_2010_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_ALCOHOL_INTAKE_SCORE_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_ALCOHOL_INTAKE_SCORE_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_BMI_MALE_AVERAGE_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_BMI_MALE_AVERAGE_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_BMI_MALE_LOWER_RANGE_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_BMI_MALE_LOWER_RANGE_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_BMI_MALE_UPPER_RANGE_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_BMI_MALE_UPPER_RANGE_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_BMI_FEMALE_AVERAGE_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_BMI_FEMALE_AVERAGE_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_BMI_FEMALE_LOWER_RANGE_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_BMI_FEMALE_LOWER_RANGE_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_BMI_FEMALE_UPPER_RANGE_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_BMI_FEMALE_UPPER_RANGE_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_SMOKERS_MALE_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_SMOKERS_MALE_FEATURE_KEY,
         constants.JAPAN_PREFECTURE_SMOKERS_FEMALE_FEATURE_KEY:
         constants.JAPAN_PREFECTURE_SMOKERS_FEMALE_FEATURE_KEY,
     }
     japan_model = japan_model_definitions.PrefectureModelDefinition()
     actual_static_features = japan_model.get_static_features()
     # Same assertion style as the sibling `get_ts_features` test. Unlike the
     # previous numpy call, this has no (actual, desired) argument order to
     # get wrong, and its diff output names the mismatched keys.
     self.assertDictEqual(expected_static_features, actual_static_features)