Example #1
def _make_docs_project(user):
    """
    Creates a project based on docs-project.json with forecasts from docs-predictions.json.
    """
    found_project = Project.objects.filter(name=DOCS_PROJECT_NAME).first()
    if found_project:
        click.echo("* deleting previous project: {}".format(found_project))
        found_project.delete()

    project = create_project_from_json(
        Path('forecast_app/tests/projects/docs-project.json'), user)  # atomic
    project.name = DOCS_PROJECT_NAME
    project.save()

    load_truth_data(
        project, Path('forecast_app/tests/truth_data/docs-ground-truth.csv'))

    forecast_model = ForecastModel.objects.create(project=project,
                                                  name='docs forecast model',
                                                  abbreviation='docs_mod')
    time_zero = project.timezeros.filter(
        timezero_date=datetime.date(2011, 10, 2)).first()
    forecast = Forecast.objects.create(forecast_model=forecast_model,
                                       source='docs-predictions.json',
                                       time_zero=time_zero,
                                       notes="a small prediction file")
    with open('forecast_app/tests/predictions/docs-predictions.json') as fp:
        json_io_dict_in = json.load(fp)
        load_predictions_from_json_io_dict(forecast, json_io_dict_in,
                                           False)  # atomic
        cache_forecast_metadata(forecast)  # atomic

    return project, time_zero, forecast_model, forecast
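
A usage sketch, mirroring the tests later in this section:

_, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
project, time_zero, forecast_model, forecast = _make_docs_project(po_user)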
Example #2
    def test_load_predictions_from_cdc_csv_file(self):
        # sanity-check that the predictions get converted and then loaded into the database
        project = Project.objects.create()
        make_cdc_units_and_targets(project)

        forecast_model = ForecastModel.objects.create(project=project,
                                                      name='model',
                                                      abbreviation='abbrev')
        time_zero = TimeZero.objects.create(project=project,
                                            timezero_date=datetime.date(
                                                2017, 1, 1))
        cdc_csv_path = Path(
            'forecast_app/tests/EW1-KoTsarima-2017-01-17-small.csv'
        )  # EW01 2017
        forecast = Forecast.objects.create(forecast_model=forecast_model,
                                           source=cdc_csv_path.name,
                                           time_zero=time_zero)

        with open(cdc_csv_path) as cdc_csv_fp:
            json_io_dict = json_io_dict_from_cdc_csv_file(2011, cdc_csv_fp)
            load_predictions_from_json_io_dict(forecast, json_io_dict, False)
        self.assertEqual(729, forecast.get_num_rows())
        self.assertEqual(722,
                         forecast.bin_distribution_qs().count())  # 729 - 7
        self.assertEqual(0, forecast.named_distribution_qs().count())
        self.assertEqual(7, forecast.point_prediction_qs().count())
        self.assertEqual(0, forecast.sample_distribution_qs().count())
        self.assertEqual(0, forecast.quantile_prediction_qs().count())
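
For reference, json_io_dict_from_cdc_csv_file() returns the same json_io_dict structure that the JSON prediction files use. A sketch of the general shape, with illustrative units/targets/values only (the parallel-list bin layout is an assumption, mirroring Example #6's expected bins and the quantile layout shown in Example #15):

json_io_dict = {
    "meta": {},  # optional
    "predictions": [
        {"unit": "HHS Region 1", "target": "1 wk ahead", "class": "point",
         "prediction": {"value": 2.1}},
        {"unit": "HHS Region 1", "target": "1 wk ahead", "class": "bin",
         "prediction": {"cat": [0.1, 0.2], "prob": [0.2, 0.8]}},
    ],
}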
Example #3
    def test_last_update(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)

        # one truth and one forecast (yes truth, yes forecasts)
        self.assertEqual(forecast.created_at, project.last_update())

        # add a second forecast for a newer timezero (yes truth, yes forecasts)
        time_zero2 = TimeZero.objects.create(project=project,
                                             timezero_date=datetime.date(
                                                 2011, 10, 3))
        forecast2 = Forecast.objects.create(
            forecast_model=forecast_model,
            source='docs-predictions-non-dup.json',
            time_zero=time_zero2,
            notes="a small prediction file")
        with open(
                'forecast_app/tests/predictions/docs-predictions-non-dup.json'
        ) as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast2,
                                               json_io_dict_in,
                                               is_validate_cats=False)
        self.assertEqual(forecast2.created_at, project.last_update())
Example #4
    def test_load_predictions_from_json_io_dict(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
        project = create_project_from_json(Path('forecast_app/tests/projects/docs-project.json'), po_user)
        forecast_model = ForecastModel.objects.create(project=project, name='name', abbreviation='abbrev')
        time_zero = TimeZero.objects.create(project=project, timezero_date=datetime.date(2017, 1, 1))
        forecast = Forecast.objects.create(forecast_model=forecast_model, source='docs-predictions.json',
                                           time_zero=time_zero)

        # test json with no 'predictions'
        with self.assertRaises(RuntimeError) as context:
            load_predictions_from_json_io_dict(forecast, {}, False)
        self.assertIn("json_io_dict had no 'predictions' key", str(context.exception))

        # load all five types of Predictions, call Forecast.*_qs() functions. see docs-predictionsexp-rows.xlsx.

        # counts from docs-predictionsexp-rows.xlsx: point: 11, named: 3, bin: 30 (3 zero prob), sample: 23
        # = total rows: 67
        #
        # counts based on .json file:
        # - 'pct next week':    point: 3, named: 1 , bin: 3, sample: 5, quantile: 5 = 17
        # - 'cases next week':  point: 2, named: 1 , bin: 3, sample: 3, quantile: 2 = 12
        # - 'season severity':  point: 2, named: 0 , bin: 3, sample: 5, quantile: 0 = 10
        # - 'above baseline':   point: 1, named: 0 , bin: 2, sample: 6, quantile: 0 =  9
        # - 'Season peak week': point: 3, named: 0 , bin: 7, sample: 4, quantile: 3 = 16
        # = total rows: 64 - 2 zero prob = 62

        with open('forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict = json.load(fp)
            load_predictions_from_json_io_dict(forecast, json_io_dict, False)
        self.assertEqual(62, forecast.get_num_rows())
        self.assertEqual(16, forecast.bin_distribution_qs().count())  # 18 - 2 zero prob
        self.assertEqual(2, forecast.named_distribution_qs().count())
        self.assertEqual(11, forecast.point_prediction_qs().count())
        self.assertEqual(23, forecast.sample_distribution_qs().count())
        self.assertEqual(10, forecast.quantile_prediction_qs().count())
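
A quick way to sanity-check the per-class tallies in the comments above is to count prediction elements by class straight from the file. Note these are element counts, not data-row counts; the expected values follow from the metadata counts asserted in Example #10:

from collections import Counter
import json

with open('forecast_app/tests/predictions/docs-predictions.json') as fp:
    json_io_dict = json.load(fp)

# expected: {'point': 11, 'named': 2, 'bin': 6, 'sample': 7, 'quantile': 3}
print(Counter(pred_ele['class'] for pred_ele in json_io_dict['predictions']))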
Example #5
def load_cdc_csv_forecast_file(season_start_year, forecast_model,
                               cdc_csv_file_path, time_zero):
    """
    Loads the passed cdc csv file into a new forecast_model Forecast for time_zero. NB: does not check if a Forecast
    already exists for time_zero and file_name. Is atomic so that an invalid forecast's data is not saved.

    :param season_start_year: the file's season start year (an int), e.g., 2016 for the 2016-2017 season
    :param forecast_model: the ForecastModel to create the new Forecast in
    :param cdc_csv_file_path: string or Path to a CDC CSV forecast file. the CDC CSV file format is documented at
        https://predict.cdc.gov/api/v1/attachments/flusight/flu_challenge_2016-17_update.docx
    :param time_zero: the TimeZero this forecast applies to
    :return: the new Forecast
    :raises RuntimeError: if the data could not be loaded
    """
    if time_zero not in forecast_model.project.timezeros.all():
        raise RuntimeError(
            f"time_zero was not in project. time_zero={time_zero}, "
            f"project timezeros={forecast_model.project.timezeros.all()}")

    cdc_csv_file_path = Path(cdc_csv_file_path)
    file_name = cdc_csv_file_path.name
    new_forecast = Forecast.objects.create(forecast_model=forecast_model,
                                           time_zero=time_zero,
                                           source=file_name)
    with open(cdc_csv_file_path) as cdc_csv_file_fp:
        json_io_dict = json_io_dict_from_cdc_csv_file(season_start_year,
                                                      cdc_csv_file_fp)
        load_predictions_from_json_io_dict(new_forecast,
                                           json_io_dict,
                                           is_validate_cats=False)  # atomic
        cache_forecast_metadata(new_forecast)  # atomic
    return new_forecast
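
A usage sketch (the file path and season start year are illustrative, and the project must already contain the TimeZero):

time_zero = forecast_model.project.timezeros.filter(
    timezero_date=datetime.date(2017, 1, 1)).first()
new_forecast = load_cdc_csv_forecast_file(
    2016, forecast_model,
    Path('forecast_app/tests/EW1-KoTsarima-2017-01-17-small.csv'), time_zero)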
Example #6
    def test_load_forecast_skips_zero_values(self):
        forecast2 = Forecast.objects.create(forecast_model=self.forecast_model,
                                            time_zero=self.time_zero)
        with open('forecast_app/tests/predictions/cdc_zero_probabilities.json'
                  ) as fp:
            json_io_dict = json.load(fp)
            load_predictions_from_json_io_dict(forecast2, json_io_dict, False)

        # test points: both should be there (points are not skipped)
        self.assertEqual(2, forecast2.point_prediction_qs().count())

        # test bins: 2 out of 6 have zero probabilities and should be skipped
        exp_bins = [
            ('HHS Region 1', '1 wk ahead', 0.2, None, 0.1, None, None,
             None),  # _i, _f, _t, _d, _b
            ('HHS Region 1', '1 wk ahead', 0.8, None, 0.2, None, None, None),
            ('US National', 'Season onset', 0.1, None, None, 'cat2', None,
             None),
            ('US National', 'Season onset', 0.9, None, None, 'cat3', None,
             None)
        ]
        bin_distribution_qs = forecast2.bin_distribution_qs() \
            .order_by('unit__name', 'target__name', 'prob') \
            .values_list('unit__name', 'target__name', 'prob', 'cat_i', 'cat_f', 'cat_t', 'cat_d', 'cat_b')
        self.assertEqual(4, bin_distribution_qs.count())
        self.assertEqual(exp_bins, list(bin_distribution_qs))
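
A minimal sketch of the skip rule this test exercises. It is an assumption about the loader's internals based on the assertions above, not its actual code, and nonzero_bin_rows is a hypothetical helper:

def nonzero_bin_rows(cats, probs):
    # bin rows whose probability is exactly zero are dropped
    return [(cat, prob) for cat, prob in zip(cats, probs) if prob != 0]

assert nonzero_bin_rows(['cat1', 'cat2', 'cat3'], [0.0, 0.1, 0.9]) == \
    [('cat2', 0.1), ('cat3', 0.9)]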
Example #7
    def test_load_predictions_from_json_io_dict(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project = create_project_from_json(
            Path('forecast_app/tests/projects/docs-project.json'), po_user)
        forecast_model = ForecastModel.objects.create(project=project,
                                                      name='name',
                                                      abbreviation='abbrev')
        time_zero = TimeZero.objects.create(project=project,
                                            timezero_date=datetime.date(
                                                2017, 1, 1))
        forecast = Forecast.objects.create(forecast_model=forecast_model,
                                           source='docs-predictions.json',
                                           time_zero=time_zero)

        # test json with no 'predictions'
        with self.assertRaises(RuntimeError) as context:
            load_predictions_from_json_io_dict(forecast, {},
                                               is_validate_cats=False)
        self.assertIn("json_io_dict had no 'predictions' key",
                      str(context.exception))

        # test loading all five types of Predictions
        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict = json.load(fp)
            load_predictions_from_json_io_dict(forecast,
                                               json_io_dict,
                                               is_validate_cats=False)

        # test prediction element counts match number in .json file
        pred_ele_qs = forecast.pred_eles.all()
        pred_data_qs = PredictionData.objects.filter(
            pred_ele__forecast=forecast)
        self.assertEqual(29, len(pred_ele_qs))
        self.assertEqual(29, len(pred_data_qs))

        # test there's a prediction element for every .json item
        unit_name_to_obj = {unit.name: unit for unit in project.units.all()}
        target_name_to_obj = {
            target.name: target
            for target in project.targets.all()
        }
        for pred_ele_dict in json_io_dict['predictions']:
            unit = unit_name_to_obj[pred_ele_dict['unit']]
            target = target_name_to_obj[pred_ele_dict['target']]
            pred_class_int = PRED_CLASS_NAME_TO_INT[pred_ele_dict['class']]
            data_hash = PredictionElement.hash_for_prediction_data_dict(
                pred_ele_dict['prediction'])
            pred_ele = pred_ele_qs.filter(pred_class=pred_class_int,
                                          unit=unit,
                                          target=target,
                                          is_retract=False,
                                          data_hash=data_hash).first()
            self.assertIsNotNone(pred_ele)
            self.assertIsNotNone(
                pred_data_qs.filter(pred_ele=pred_ele).first())
Example #8
    def test_load_predictions_from_json_io_dict_none_prediction(self):
        # tests load_predictions_from_json_io_dict() where `"prediction": None`
        project, forecast_model, f1, tz1, tz2, u1, u2, u3, t1 = ProjectQueriesTestCase._set_up_as_of_case(
        )

        f2 = Forecast.objects.create(forecast_model=forecast_model,
                                     source='f2',
                                     time_zero=tz1)
        predictions = [{
            "unit": u1.name,
            "target": t1.name,
            "class": "named",
            "prediction": None
        }, {
            "unit": u2.name,
            "target": t1.name,
            "class": "point",
            "prediction": None
        }, {
            "unit": u2.name,
            "target": t1.name,
            "class": "sample",
            "prediction": None
        }, {
            "unit": u3.name,
            "target": t1.name,
            "class": "bin",
            "prediction": None
        }, {
            "unit": u3.name,
            "target": t1.name,
            "class": "quantile",
            "prediction": None
        }]
        load_predictions_from_json_io_dict(f2, {'predictions': predictions},
                                           is_validate_cats=False)
        self.assertEqual(5, f2.pred_eles.count())
        self.assertEqual('', f2.pred_eles.first().data_hash)

        # test loading an initial version that includes retractions (we are not sure what this means, but it is
        # valid and should not fail :-)
        f3 = Forecast.objects.create(forecast_model=forecast_model,
                                     source='f3',
                                     time_zero=tz2)
        load_predictions_from_json_io_dict(f3, {'predictions': predictions},
                                           is_validate_cats=False)
        self.assertEqual(5, f3.pred_eles.count())
        self.assertEqual('', f3.pred_eles.first().data_hash)

        # test querying same
        try:
            rows = list(query_forecasts_for_project(project,
                                                    {}))  # list for generator
            self.assertEqual(1, len(rows))  # header
        except Exception as ex:
            self.fail(f"unexpected exception: {ex}")
Example #9
    def test_cache_forecast_metadata_second_forecast(self):
        # make sure only the passed forecast is cached
        forecast2 = Forecast.objects.create(forecast_model=self.forecast_model,
                                            source='docs-predictions.json',
                                            time_zero=self.time_zero,
                                            notes="a small prediction file")
        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast2, json_io_dict_in,
                                               False)

        self.assertEqual(
            0,
            ForecastMetaPrediction.objects.filter(
                forecast=self.forecast).count())
        self.assertEqual(
            0,
            ForecastMetaUnit.objects.filter(forecast=self.forecast).count())
        self.assertEqual(
            0,
            ForecastMetaTarget.objects.filter(forecast=self.forecast).count())

        self.assertEqual(
            0,
            ForecastMetaPrediction.objects.filter(forecast=forecast2).count())
        self.assertEqual(
            0,
            ForecastMetaUnit.objects.filter(forecast=forecast2).count())
        self.assertEqual(
            0,
            ForecastMetaTarget.objects.filter(forecast=forecast2).count())

        cache_forecast_metadata(self.forecast)
        self.assertEqual(
            1,
            ForecastMetaPrediction.objects.filter(
                forecast=self.forecast).count())
        self.assertEqual(
            3,
            ForecastMetaUnit.objects.filter(forecast=self.forecast).count())
        self.assertEqual(
            5,
            ForecastMetaTarget.objects.filter(forecast=self.forecast).count())

        self.assertEqual(
            0,
            ForecastMetaPrediction.objects.filter(forecast=forecast2).count())
        self.assertEqual(
            0,
            ForecastMetaUnit.objects.filter(forecast=forecast2).count())
        self.assertEqual(
            0,
            ForecastMetaTarget.objects.filter(forecast=forecast2).count())
Example #10
    def test_forecast_metadata_counts_for_f_ids(self):
        forecast2 = Forecast.objects.create(forecast_model=self.forecast_model, source='docs-predictions-non-dup.json',
                                            time_zero=self.time_zero, notes="a small prediction file")
        with open('forecast_app/tests/predictions/docs-predictions-non-dup.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast2, json_io_dict_in, is_validate_cats=False)
        cache_forecast_metadata(self.forecast)
        cache_forecast_metadata(forecast2)

        forecasts_qs = self.forecast_model.forecasts.all()
        forecast_id_to_counts = forecast_metadata_counts_for_f_ids(forecasts_qs)
        #  f_id:  [(point_count, named_count, bin_count, sample_count, quantile_count), num_names, num_targets]
        # {   4:  [(11,          2,           6,         7,            3),              3,         5          ],
        #     5:  [(11,          2,           6,         7,            3),              3,         5          ]}
        self.assertEqual(sorted([self.forecast.id, forecast2.id]), sorted(forecast_id_to_counts.keys()))
        self.assertEqual([(11, 2, 6, 7, 3), 3, 5], forecast_id_to_counts[self.forecast.id])
        self.assertEqual([(11, 2, 6, 7, 3), 3, 5], forecast_id_to_counts[forecast2.id])
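
A sketch of unpacking one entry, following the layout in the comment above:

(point_count, named_count, bin_count, sample_count, quantile_count), num_names, num_targets = \
    forecast_id_to_counts[self.forecast.id]
assert (point_count, named_count, bin_count, sample_count, quantile_count) == (11, 2, 6, 7, 3)
assert (num_names, num_targets) == (3, 5)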
Example #11
    def test_query_scores_for_project_max_num_rows(self):
        # add some more predictions to get more to work with
        time_zero_2 = self.project.timezeros.filter(
            timezero_date=datetime.date(2011, 10, 9)).first()
        forecast2 = Forecast.objects.create(forecast_model=self.forecast_model,
                                            source='docs-predictions.json 2',
                                            time_zero=time_zero_2,
                                            notes="f2")
        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast2, json_io_dict_in,
                                               False)

        forecast_model_2 = ForecastModel.objects.create(
            project=self.project,
            name='docs forecast model 2',
            abbreviation='docs_mod_2')
        time_zero_3 = self.project.timezeros.filter(
            timezero_date=datetime.date(2011, 10, 16)).first()
        forecast_3 = Forecast.objects.create(forecast_model=forecast_model_2,
                                             source='docs-predictions.json 3',
                                             time_zero=time_zero_3,
                                             notes="f3")
        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast_3, json_io_dict_in,
                                               False)

        Score.ensure_all_scores_exist()
        _update_scores_for_all_projects()

        try:
            list(query_scores_for_project(
                self.project, {},
                max_num_rows=14))  # actual number of rows = 14
        except Exception as ex:
            self.fail(f"unexpected exception: {ex}")

        with self.assertRaises(RuntimeError) as context:
            list(query_scores_for_project(self.project, {}, max_num_rows=13))
        self.assertIn("number of rows exceeded maximum",
                      str(context.exception))
Example #12
    def test_load_predictions_from_json_io_dict_existing_pred_eles(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)
        json_io_dict = {
            "predictions": [{
                "unit": "location1",
                "target": "pct next week",
                "class": "point",
                "prediction": {
                    "value": 2.1
                }
            }]
        }
        with self.assertRaises(RuntimeError) as context:
            load_predictions_from_json_io_dict(forecast, json_io_dict)
        self.assertIn("cannot load data into a non-empty forecast",
                      str(context.exception))
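
A minimal sketch of the guard this test exercises, assuming (based on the error message) that the loader refuses to run when the forecast already has prediction elements:

def _check_forecast_is_empty(forecast):  # hypothetical helper
    if forecast.pred_eles.exists():
        raise RuntimeError(f"cannot load data into a non-empty forecast: {forecast}")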
Example #13
    def test_load_predictions_from_cdc_csv_file(self):
        # sanity-check that the predictions get converted and then loaded into the database
        project = Project.objects.create()
        make_cdc_units_and_targets(project)

        forecast_model = ForecastModel.objects.create(project=project,
                                                      name='model',
                                                      abbreviation='abbrev')
        time_zero = TimeZero.objects.create(project=project,
                                            timezero_date=datetime.date(
                                                2017, 1, 1))
        forecast = Forecast.objects.create(forecast_model=forecast_model,
                                           time_zero=time_zero)

        with open(self.cdc_csv_path) as cdc_csv_fp:
            json_io_dict = json_io_dict_from_cdc_csv_file(2011, cdc_csv_fp)
            load_predictions_from_json_io_dict(forecast,
                                               json_io_dict,
                                               is_validate_cats=False)

        self.assertEqual(
            1 * 7 * 2,
            forecast.pred_eles.count())  # locations * targets * points/bins
Example #14
    def test_load_predictions_from_json_io_dict_phase_1(self):
        # tests pass 1/2 of load_predictions_from_json_io_dict(). NB: implicitly covers test_hash_for_prediction_dict()
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project = create_project_from_json(
            Path('forecast_app/tests/projects/docs-project.json'), po_user)
        forecast_model = ForecastModel.objects.create(project=project,
                                                      name='name',
                                                      abbreviation='abbrev')
        time_zero = TimeZero.objects.create(project=project,
                                            timezero_date=datetime.date(
                                                2017, 1, 1))
        forecast = Forecast.objects.create(forecast_model=forecast_model,
                                           source='docs-predictions.json',
                                           time_zero=time_zero)

        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict = json.load(fp)
            load_predictions_from_json_io_dict(forecast,
                                               json_io_dict,
                                               is_validate_cats=False)

        # test PredictionElement.forecast and is_retract
        self.assertEqual(29, forecast.pred_eles.count())
        self.assertEqual(
            0,
            PredictionElement.objects.filter(is_retract=True).count())

        exp_rows = [
            ('point', 'location1', 'pct next week',
             '2c343e1ea37e8b493c219066a8664276'),
            ('named', 'location1', 'pct next week',
             '58a7f8487958446d57333b262aaa8271'),
            ('point', 'location2', 'pct next week',
             '2b9db448ae1a3b7065ffee67d4857268'),
            ('bin', 'location2', 'pct next week',
             '7d1485af48de540dbcd954ee5cba51cb'),
            ('quantile', 'location2', 'pct next week',
             '0d3698e4e39456b8e36c750d73bb6870'),
            ('point', 'location3', 'pct next week',
             '5d321ea39f0af08cb3f40a58fa7c54d4'),
            ('sample', 'location3', 'pct next week',
             '0b431a76d5ad343981944c4b0792d738'),
            ('named', 'location1', 'cases next week',
             '845e3d041b6be23a381b6afd263fb113'),
            ('point', 'location2', 'cases next week',
             '2ed5d7d59eb10044644ab28a1b292efb'),
            ('sample', 'location2', 'cases next week',
             '74135c30ddfd5427c8b1e86b2989a642'),
            ('point', 'location3', 'cases next week',
             'a6ff82cc0637f67254df41352e1c00f9'),
            ('bin', 'location3', 'cases next week',
             'a74ea3f2472e0aec511eb1f604282220'),
            ('quantile', 'location3', 'cases next week',
             '838e6e3f77075f69eef3bb3d7bcdffdc'),
            ('point', 'location1', 'season severity',
             'bc55989f596fd157ccc6e3279b1f694a'),
            ('bin', 'location1', 'season severity',
             'ac263a19694da72f65e903c2ec2000d1'),
            ('point', 'location2', 'season severity',
             'ec5add7ea7a8abf3e68e9570d0b73898'),
            ('sample', 'location2', 'season severity',
             '51d07bda3e8a39da714f5767d93704ff'),
            ('point', 'location1', 'above baseline',
             '19d0e94bc24114abfa0d07ca41b8b3bf'),
            ('bin', 'location2', 'above baseline',
             '1b98c3c7b5b09d3ba0ea43566d5e9d03'),
            ('sample', 'location2', 'above baseline',
             'ae168d5bfdad1463672120d51787fed2'),
            ('sample', 'location3', 'above baseline',
             '380b79bea27bfa66e8864cfec9e3403a'),
            ('point', 'location1', 'Season peak week',
             'fad04bc4cd443ca7cd7cd53f5de4fa99'),
            ('bin', 'location1', 'Season peak week',
             'c74e3f626224eeb482368d9fb7a387da'),
            ('sample', 'location1', 'Season peak week',
             '2f0fdc8a293046d38eb912601cf0a5cf'),
            ('point', 'location2', 'Season peak week',
             '39c511635eb21cfde3657ab144521b94'),
            ('bin', 'location2', 'Season peak week',
             '4fa62ed754c3fc9b9ede90926efe8f7f'),
            ('quantile', 'location2', 'Season peak week',
             'd06cb30665b099e471c6dd9d50ba2c30'),
            ('point', 'location3', 'Season peak week',
             'f15fab078daf9adb53f464272b31dbf6'),
            ('sample', 'location3', 'Season peak week',
             '213d829834bceaaa4376a79b989161c3'),
        ]
        pred_data_qs = PredictionElement.objects \
            .filter(forecast=forecast) \
            .values_list('pred_class', 'unit__name', 'target__name', 'data_hash') \
            .order_by('id')
        act_rows = [(PredictionElement.prediction_class_int_as_str(row[0]),
                     row[1], row[2], row[3]) for row in pred_data_qs]
        self.assertEqual(sorted(exp_rows), sorted(act_rows))
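
A sketch of what PredictionElement.hash_for_prediction_data_dict() might look like. This is an assumption, not the real implementation: the 32-hex-character hashes above are consistent with md5 over a canonical (sorted-keys) JSON serialization:

import hashlib
import json

def hash_for_prediction_data_dict(prediction_data_dict):
    canonical = json.dumps(prediction_data_dict, sort_keys=True)
    return hashlib.md5(canonical.encode('utf-8')).hexdigest()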
Example #15
    def test_load_predictions_from_json_io_dict_dups(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project = create_project_from_json(
            Path('forecast_app/tests/projects/docs-project.json'), po_user)
        tz1 = TimeZero.objects.create(project=project,
                                      timezero_date=datetime.date(2020, 10, 4))
        forecast_model = ForecastModel.objects.create(project=project,
                                                      name='name',
                                                      abbreviation='abbrev')

        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict = json.load(fp)
            pred_dicts = json_io_dict[
                'predictions']  # get some prediction elements to work with (29)

        # per https://stackoverflow.com/questions/1937622/convert-date-to-datetime-in-python/1937636 :
        f1 = Forecast.objects.create(forecast_model=forecast_model,
                                     source='f1',
                                     time_zero=tz1,
                                     issued_at=datetime.datetime.combine(
                                         tz1.timezero_date,
                                         datetime.time(),
                                         tzinfo=datetime.timezone.utc))
        load_predictions_from_json_io_dict(f1, {
            'meta': {},
            'predictions': pred_dicts[:-2]
        })  # all but last 2 PEs

        # case: load the just-loaded file into a separate timezero -> should load all rows (duplicates are only within
        # the same timezero)
        tz2 = project.timezeros.filter(
            timezero_date=datetime.date(2011, 10, 9)).first()
        f2 = Forecast.objects.create(forecast_model=forecast_model,
                                     time_zero=tz2)
        load_predictions_from_json_io_dict(
            f2,
            {
                'meta': {},
                'predictions': pred_dicts[:-1]
            },  # all but last PE
            is_validate_cats=False)
        self.assertEqual(27, f1.pred_eles.count())
        self.assertEqual(28, f2.pred_eles.count())
        self.assertEqual(27 + 28,
                         project.num_pred_ele_rows_all_models(is_oracle=False))

        # case: load the same predictions into a new version -> only the 2 elements not already in f1 should
        # load (the rest are duplicates)
        f1.issued_at -= datetime.timedelta(days=1)
        f1.save()

        f3 = Forecast.objects.create(forecast_model=forecast_model,
                                     time_zero=tz1)
        load_predictions_from_json_io_dict(f3,
                                           json_io_dict,
                                           is_validate_cats=False)
        self.assertEqual(27, f1.pred_eles.count())
        self.assertEqual(28, f2.pred_eles.count())
        self.assertEqual(2, f3.pred_eles.count())  # 2 were new (non-dup)
        self.assertEqual(27 + 28 + 2,
                         project.num_pred_ele_rows_all_models(is_oracle=False))

        # case: load the same file, but change one multi-row prediction (a quantile) to have partial duplication
        f3.issued_at -= datetime.timedelta(days=2)
        f3.save()
        quantile_pred_dict = [
            pred_dict for pred_dict in json_io_dict['predictions']
            if (pred_dict['unit'] == 'location2') and (
                pred_dict['target'] == 'pct next week') and (
                    pred_dict['class'] == 'quantile')
        ][0]
        # original: {"quantile": [0.025, 0.25, 0.5, 0.75,  0.975 ],
        #            "value":    [1.0,   2.2,  2.2,  5.0, 50.0  ]}
        quantile_pred_dict['prediction']['value'][0] = 2.2  # was 1.0
        f4 = Forecast.objects.create(forecast_model=forecast_model,
                                     time_zero=tz1)
        load_predictions_from_json_io_dict(f4,
                                           json_io_dict,
                                           is_validate_cats=False)
        self.assertEqual(1, f4.pred_eles.count())
        self.assertEqual(27 + 28 + 2 + 1,
                         project.num_pred_ele_rows_all_models(is_oracle=False))
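
A sketch of the duplicate rule these cases exercise. This is an assumption about the loader's internals: an incoming element is skipped iff the previous version already contains an element with the same unit, target, class, and data hash:

def is_duplicate(prev_version, unit, target, pred_class, data_hash):  # hypothetical helper
    return prev_version.pred_eles.filter(
        unit=unit, target=target, pred_class=pred_class,
        data_hash=data_hash).exists()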
Example #16
    def test_calc_interval_20_docs_project_additional_version(self):
        Score.ensure_all_scores_exist()
        interval_20_score = Score.objects.filter(
            abbreviation='interval_20').first()
        self.assertIsNotNone(interval_20_score)

        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)

        unit_loc2 = project.units.filter(name='location2').first()
        targ_pct_next_wk = project.targets.filter(
            name='pct next week').first()  # continuous
        unit_loc3 = project.units.filter(name='location3').first()
        targ_cases_next_wk = project.targets.filter(
            name='cases next week').first()  # discrete

        # add two truths that result in two ScoreValues
        project.delete_truth_data()
        TruthData.objects.create(time_zero=time_zero,
                                 unit=unit_loc2,
                                 target=targ_pct_next_wk,
                                 value_f=2.2)  # 2/7
        TruthData.objects.create(time_zero=time_zero,
                                 unit=unit_loc3,
                                 target=targ_cases_next_wk,
                                 value_i=50)  # 6/7
        ScoreValue.objects \
            .filter(score=interval_20_score, forecast__forecast_model=forecast_model) \
            .delete()  # usually done by update_score_for_model()
        _calculate_interval_score_values(interval_20_score, forecast_model,
                                         0.5)
        self.assertEqual(2, interval_20_score.values.count())
        self.assertEqual([2.8, 50],
                         sorted(interval_20_score.values.all().values_list(
                             'value', flat=True)))

        # add a second forecast for a newer timezero
        time_zero2 = TimeZero.objects.create(project=project,
                                             timezero_date=datetime.date(
                                                 2011, 10, 3))
        forecast2 = Forecast.objects.create(forecast_model=forecast_model,
                                            source='docs-predictions.json',
                                            time_zero=time_zero2,
                                            notes="a small prediction file")
        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast2, json_io_dict_in,
                                               False)
        TruthData.objects.create(time_zero=time_zero2,
                                 unit=unit_loc2,
                                 target=targ_pct_next_wk,
                                 value_f=2.2)  # 2/7
        TruthData.objects.create(time_zero=time_zero2,
                                 unit=unit_loc3,
                                 target=targ_cases_next_wk,
                                 value_i=50)  # 6/7
        ScoreValue.objects \
            .filter(score=interval_20_score, forecast__forecast_model=forecast_model) \
            .delete()  # usually done by update_score_for_model()
        _calculate_interval_score_values(interval_20_score, forecast_model,
                                         0.5)
        self.assertEqual(4, interval_20_score.values.count())

        # finally, add a new version to timezero
        forecast.issue_date = forecast.time_zero.timezero_date
        forecast.save()

        forecast2.issue_date = forecast2.time_zero.timezero_date
        forecast2.save()

        forecast2 = Forecast.objects.create(forecast_model=forecast_model,
                                            source='f2',
                                            time_zero=time_zero)
        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast2, json_io_dict_in,
                                               False)  # atomic
            cache_forecast_metadata(forecast2)  # atomic

        # s/b no change from previous
        ScoreValue.objects \
            .filter(score=interval_20_score, forecast__forecast_model=forecast_model) \
            .delete()  # usually done by update_score_for_model()

        # RuntimeError: >2 lower_upper_interval_values: [2.2, 2.2, 5.0, 5.0]. timezero_id=4, unit_id=5, target_id=6
        _calculate_interval_score_values(interval_20_score, forecast_model,
                                         0.5)

        self.assertEqual(4, interval_20_score.values.count())
Example #17
    def test_target_rows_for_project(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        # recall that _make_docs_project() calls cache_forecast_metadata():
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)  # 2011, 10, 2

        # case: one model with one timezero that has five groups of one target each.
        # recall: `group_targets(project.targets.all())` (only one target/group in this case):
        #   {'pct next week':    [(1, 'pct next week', 'continuous', True, 1, 'percent')],
        #    'cases next week':  [(2, 'cases next week', 'discrete', True, 2, 'cases')],
        #    'season severity':  [(3, 'season severity', 'nominal', False, None, None)],
        #    'above baseline':   [(4, 'above baseline', 'binary', False, None, None)],
        #    'Season peak week': [(5, 'Season peak week', 'date', False, None, 'week')]}
        exp_rows = [(forecast_model, str(time_zero.timezero_date), forecast.id,
                     'Season peak week', 1),
                    (forecast_model, str(time_zero.timezero_date), forecast.id,
                     'above baseline', 1),
                    (forecast_model, str(time_zero.timezero_date), forecast.id,
                     'cases next week', 1),
                    (forecast_model, str(time_zero.timezero_date), forecast.id,
                     'pct next week', 1),
                    (forecast_model, str(time_zero.timezero_date), forecast.id,
                     'season severity', 1)]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4])
                    for row in target_rows_for_project(project)]
        self.assertEqual(sorted(exp_rows, key=lambda _: _[0].id),
                         sorted(act_rows, key=lambda _: _[0].id))

        # case: add a second forecast for a newer timezero
        time_zero2 = TimeZero.objects.create(project=project,
                                             timezero_date=datetime.date(
                                                 2011, 10, 3))
        forecast2 = Forecast.objects.create(forecast_model=forecast_model,
                                            source='docs-predictions.json',
                                            time_zero=time_zero2,
                                            notes="a small prediction file")
        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast2, json_io_dict_in,
                                               False)
            cache_forecast_metadata(
                forecast2
            )  # required by _forecast_ids_to_present_unit_or_target_id_sets()

        exp_rows = [(forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 'Season peak week', 1),
                    (forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 'above baseline', 1),
                    (forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 'cases next week', 1),
                    (forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 'pct next week', 1),
                    (forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 'season severity', 1)]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4])
                    for row in target_rows_for_project(project)]
        self.assertEqual(sorted(exp_rows, key=lambda _: _[0].id),
                         sorted(act_rows, key=lambda _: _[0].id))

        # case: add a second model with only forecasts for one target
        forecast_model2 = ForecastModel.objects.create(
            project=project,
            name=forecast_model.name + '2',
            abbreviation=forecast_model.abbreviation + '2')
        time_zero3 = TimeZero.objects.create(project=project,
                                             timezero_date=datetime.date(
                                                 2011, 10, 4))
        forecast3 = Forecast.objects.create(forecast_model=forecast_model2,
                                            source='docs-predictions.json',
                                            time_zero=time_zero3,
                                            notes="a small prediction file")
        json_io_dict = {
            "meta": {},
            "predictions": [{
                "unit": "location1",
                "target": "pct next week",
                "class": "point",
                "prediction": {
                    "value": 2.1
                }
            }]
        }
        load_predictions_from_json_io_dict(forecast3, json_io_dict, False)
        cache_forecast_metadata(
            forecast3
        )  # required by _forecast_ids_to_present_unit_or_target_id_sets()

        exp_rows = exp_rows + [(forecast_model2, str(
            time_zero3.timezero_date), forecast3.id, 'pct next week', 1)]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4])
                    for row in target_rows_for_project(project)]
        self.assertEqual(sorted(exp_rows, key=lambda _: _[0].id),
                         sorted(act_rows, key=lambda _: _[0].id))

        # case: no forecasts
        forecast.delete()
        forecast2.delete()
        forecast3.delete()
        exp_rows = [(forecast_model, '', '', '', 0),
                    (forecast_model2, '', '', '', 0)]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4])
                    for row in target_rows_for_project(project)]
        self.assertEqual(sorted(exp_rows, key=lambda _: _[0].id),
                         sorted(act_rows, key=lambda _: _[0].id))
Example #18
    def test_unit_rows_for_project(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        # recall that _make_docs_project() calls cache_forecast_metadata():
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)  # 2011, 10, 2

        # case: one model with one timezero. recall rows:
        # (model, newest_forecast_tz_date, newest_forecast_id,
        #  num_present_unit_names, present_unit_names, missing_unit_names):
        exp_rows = [(forecast_model, str(time_zero.timezero_date), forecast.id,
                     3, '(all)', '')]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4], row[5])
                    for row in unit_rows_for_project(project)]
        self.assertEqual(exp_rows, act_rows)

        # case: add a second forecast for a newer timezero
        time_zero2 = TimeZero.objects.create(project=project,
                                             timezero_date=datetime.date(
                                                 2011, 10, 3))
        forecast2 = Forecast.objects.create(forecast_model=forecast_model,
                                            source='docs-predictions.json',
                                            time_zero=time_zero2,
                                            notes="a small prediction file")
        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast2, json_io_dict_in,
                                               False)
            cache_forecast_metadata(
                forecast2
            )  # required by _forecast_ids_to_present_unit_or_target_id_sets()

        exp_rows = [(forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 3, '(all)', '')]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4], row[5])
                    for row in unit_rows_for_project(project)]
        self.assertEqual(exp_rows, act_rows)

        # case: add a second model with only forecasts for one unit
        forecast_model2 = ForecastModel.objects.create(
            project=project,
            name=forecast_model.name + '2',
            abbreviation=forecast_model.abbreviation + '2')
        time_zero3 = TimeZero.objects.create(project=project,
                                             timezero_date=datetime.date(
                                                 2011, 10, 4))
        forecast3 = Forecast.objects.create(forecast_model=forecast_model2,
                                            source='docs-predictions.json',
                                            time_zero=time_zero3,
                                            notes="a small prediction file")
        json_io_dict = {
            "meta": {},
            "predictions": [{
                "unit": "location1",
                "target": "pct next week",
                "class": "point",
                "prediction": {
                    "value": 2.1
                }
            }]
        }
        load_predictions_from_json_io_dict(forecast3, json_io_dict, False)
        cache_forecast_metadata(
            forecast3
        )  # required by _forecast_ids_to_present_unit_or_target_id_sets()

        exp_rows = [(forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 3, '(all)', ''),
                    (forecast_model2, str(time_zero3.timezero_date),
                     forecast3.id, 1, 'location1', 'location2, location3')]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4], row[5])
                    for row in unit_rows_for_project(project)]
        self.assertEqual(exp_rows, act_rows)

        # case: exposes bug: syntax error when no forecasts in project:
        #   psycopg2.errors.SyntaxError: syntax error at or near ")"
        #   LINE 6:             WHERE f.id IN ()
        forecast.delete()
        forecast2.delete()
        forecast3.delete()
        # (model, newest_forecast_tz_date, newest_forecast_id, num_present_unit_names, present_unit_names,
        #  missing_unit_names):
        exp_rows = [(forecast_model, 'None', None, 0, '', '(all)'),
                    (forecast_model2, 'None', None, 0, '', '(all)')]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4], row[5])
                    for row in unit_rows_for_project(project)]
        self.assertEqual(sorted(exp_rows, key=lambda _: _[0].id),
                         sorted(act_rows, key=lambda _: _[0].id))
Example #19
    def test_as_of_versions(self):
        # tests the case in [Add forecast versioning](https://github.com/reichlab/forecast-repository/issues/273):
        #
        # Here's an example database with versions (header is timezeros, rows are forecast `issue_date`s). Each forecast
        # only has one point prediction:
        #
        # +-----+-----+-----+
        # |10/2 |10/9 |10/16|
        # |tz1  |tz2  |tz3  |
        # +=====+=====+=====+
        # |10/2 |     |     |
        # |f1   | -   | -   |  2.1
        # +-----+-----+-----+
        # |     |     |10/17|
        # |-    | -   |f2   |  2.0
        # +-----+-----+-----+
        # |10/20|10/20|     |
        # |f3   | f4  | -   |  3.567 | 10
        # +-----+-----+-----+
        #
        # Here are some `as_of` examples (which forecast version would be used as of that date):
        #
        # +-----+----+----+----+
        # |as_of|tz1 |tz2 |tz3 |
        # +-----+----+----+----+
        # |10/1 | -  | -  | -  |
        # |10/3 | f1 | -  | -  |
        # |10/18| f1 | -  | f2 |
        # |10/20| f3 | f4 | f2 |
        # |10/21| f3 | f4 | f2 |
        # +-----+----+----+----+

        # set up database
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project = create_project_from_json(
            Path('forecast_app/tests/projects/docs-project.json'),
            po_user)  # atomic
        forecast_model = ForecastModel.objects.create(
            project=project,
            name='docs forecast model',
            abbreviation='docs_mod')
        tz1 = project.timezeros.filter(
            timezero_date=datetime.date(2011, 10, 2)).first()
        tz2 = project.timezeros.filter(
            timezero_date=datetime.date(2011, 10, 9)).first()
        tz3 = project.timezeros.filter(
            timezero_date=datetime.date(2011, 10, 16)).first()

        f1 = Forecast.objects.create(forecast_model=forecast_model,
                                     source='f1',
                                     time_zero=tz1)
        json_io_dict = {
            "predictions": [{
                "unit": "location1",
                "target": "pct next week",
                "class": "point",
                "prediction": {
                    "value": 2.1
                }
            }]
        }
        load_predictions_from_json_io_dict(f1, json_io_dict, False)
        f1.issue_date = tz1.timezero_date
        f1.save()

        f2 = Forecast.objects.create(forecast_model=forecast_model,
                                     source='f2',
                                     time_zero=tz3)
        json_io_dict = {
            "predictions": [{
                "unit": "location2",
                "target": "pct next week",
                "class": "point",
                "prediction": {
                    "value": 2.0
                }
            }]
        }
        load_predictions_from_json_io_dict(f2, json_io_dict, False)
        f2.issue_date = tz3.timezero_date + datetime.timedelta(days=1)
        f2.save()

        f3 = Forecast.objects.create(forecast_model=forecast_model,
                                     source='f3',
                                     time_zero=tz1)
        json_io_dict = {
            "predictions": [{
                "unit": "location3",
                "target": "pct next week",
                "class": "point",
                "prediction": {
                    "value": 3.567
                }
            }]
        }
        load_predictions_from_json_io_dict(f3, json_io_dict, False)
        f3.issue_date = tz1.timezero_date + datetime.timedelta(days=18)
        f3.save()

        f4 = Forecast.objects.create(forecast_model=forecast_model,
                                     source='f4',
                                     time_zero=tz2)
        json_io_dict = {
            "predictions": [{
                "unit": "location3",
                "target": "cases next week",
                "class": "point",
                "prediction": {
                    "value": 10
                }
            }]
        }
        load_predictions_from_json_io_dict(f4, json_io_dict, False)
        f4.issue_date = f3.issue_date
        f4.save()

        # case: default (no `as_of`): no f1 (f3 is newer)
        exp_rows = [
            ['2011-10-16', 'location2', 'pct next week', 'point', 2.0],
            ['2011-10-02', 'location3', 'pct next week', 'point', 3.567],
            ['2011-10-09', 'location3', 'cases next week', 'point', 10]
        ]
        act_rows = list(query_forecasts_for_project(project, {}))
        act_rows = [row[1:2] + row[3:7] for row in act_rows[1:]
                    ]  # 'timezero', 'unit', 'target', 'class', 'value'
        self.assertEqual(exp_rows, act_rows)

        # case: 10/20: same as default
        act_rows = list(
            query_forecasts_for_project(project, {'as_of': '2011-10-20'}))
        act_rows = [row[1:2] + row[3:7] for row in act_rows[1:]
                    ]  # 'timezero', 'unit', 'target', 'class', 'value'
        self.assertEqual(exp_rows, act_rows)

        # case: 10/21: same as default
        act_rows = list(
            query_forecasts_for_project(project, {'as_of': '2011-10-21'}))
        act_rows = [row[1:2] + row[3:7] for row in act_rows[1:]
                    ]  # 'timezero', 'unit', 'target', 'class', 'value'
        self.assertEqual(exp_rows, act_rows)

        # case: 10/1: none
        exp_rows = []
        act_rows = list(
            query_forecasts_for_project(project, {'as_of': '2011-10-01'}))
        act_rows = [row[1:2] + row[3:7] for row in act_rows[1:]
                    ]  # 'timezero', 'unit', 'target', 'class', 'value'
        self.assertEqual(exp_rows, act_rows)

        # case: 10/3: just f1
        exp_rows = [['2011-10-02', 'location1', 'pct next week', 'point', 2.1]]
        act_rows = list(
            query_forecasts_for_project(project, {'as_of': '2011-10-03'}))
        act_rows = [row[1:2] + row[3:7] for row in act_rows[1:]
                    ]  # 'timezero', 'unit', 'target', 'class', 'value'
        self.assertEqual(exp_rows, act_rows)

        # case: 10/18: f1 and f2
        exp_rows = [['2011-10-02', 'location1', 'pct next week', 'point', 2.1],
                    ['2011-10-16', 'location2', 'pct next week', 'point', 2.0]]
        act_rows = list(
            query_forecasts_for_project(project, {'as_of': '2011-10-18'}))
        act_rows = [row[1:2] + row[3:7] for row in act_rows[1:]
                    ]  # 'timezero', 'unit', 'target', 'class', 'value'
        self.assertEqual(exp_rows, act_rows)
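
A sketch of the version-selection rule described by the tables above. This is an assumption about query_forecasts_for_project()'s internals: for each timezero, the version with the latest issue_date on or before as_of is used, if any:

def version_as_of(forecast_model, time_zero, as_of_date):  # hypothetical helper
    return forecast_model.forecasts \
        .filter(time_zero=time_zero, issue_date__lte=as_of_date) \
        .order_by('-issue_date') \
        .first()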
Example #20
    def test_query_scores_for_project(self):
        # add some more predictions to get more to work with
        time_zero_2 = self.project.timezeros.filter(
            timezero_date=datetime.date(2011, 10, 9)).first()
        forecast2 = Forecast.objects.create(forecast_model=self.forecast_model,
                                            source='docs-predictions.json 2',
                                            time_zero=time_zero_2,
                                            notes="f2")
        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast2, json_io_dict_in,
                                               False)

        forecast_model_2 = ForecastModel.objects.create(
            project=self.project,
            name='docs forecast model 2',
            abbreviation='docs_mod_2')
        time_zero_3 = self.project.timezeros.filter(
            timezero_date=datetime.date(2011, 10, 16)).first()
        forecast_3 = Forecast.objects.create(forecast_model=forecast_model_2,
                                             source='docs-predictions.json 3',
                                             time_zero=time_zero_3,
                                             notes="f3")
        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast_3, json_io_dict_in,
                                               False)

        Score.ensure_all_scores_exist()
        _update_scores_for_all_projects()

        # ---- case: empty query -> all scores in project ----
        # note: following floating point values are as returned by postgres. sqlite3 rounds differently, so we use
        # assertAlmostEqual() to compare. columns: model, timezero, season, unit, target, truth, error, abs_error,
        # log_single_bin, log_multi_bin, pit, interval_2, interval_5, interval_10, interval_20, interval_30,
        # interval_40, interval_50, interval_60, interval_70, interval_80, interval_90, interval_100:
        exp_rows = [
            SCORE_CSV_HEADER_PREFIX +
            [score.abbreviation for score in Score.objects.all()],
            [
                'docs_mod', '2011-10-02', '2011-2012', 'location1',
                'pct next week', 4.5432, 2.4432, 2.4432, None, None, None,
                None, None, None, None, None, None, None, None, None, None,
                None, None
            ],
            [
                'docs_mod', '2011-10-09', '2011-2012', 'location2',
                'pct next week', 99.9, 97.9, 97.9, -999.0, -0.356674943938732,
                1.0, None, 2045.0, None, None, None, None, 382.4, None, None,
                None, None, 195.4
            ],
            [
                'docs_mod', '2011-10-09', '2011-2012', 'location2',
                'cases next week', 3, -2.0, 2.0, None, None, None, None, None,
                None, None, None, None, None, None, None, None, None, None
            ],
            [
                'docs_mod_2', '2011-10-16', '2011-2012', 'location1',
                'pct next week', 0.0, -2.1, 2.1, None, None, None, None, None,
                None, None, None, None, None, None, None, None, None, None
            ]
        ]
        act_rows = list(query_scores_for_project(self.project,
                                                 {}))  # list for generator
        self._assert_list_of_lists_almost_equal(exp_rows, act_rows)

        # ---- case: only one model ----
        act_rows = list(
            query_scores_for_project(self.project, {'models': ['docs_mod_2']}))
        self._assert_list_of_lists_almost_equal([exp_rows[0], exp_rows[4]],
                                                act_rows)

        # ---- case: only one unit ----
        act_rows = list(
            query_scores_for_project(self.project, {'units': ['location1']}))
        self._assert_list_of_lists_almost_equal(
            [exp_rows[0], exp_rows[1], exp_rows[4]], act_rows)

        # ---- case: only one target ----
        act_rows = list(
            query_scores_for_project(self.project,
                                     {'targets': ['cases next week']}))
        self._assert_list_of_lists_almost_equal([exp_rows[0], exp_rows[3]],
                                                act_rows)

        # ---- case: only one timezero ----
        act_rows = list(
            query_scores_for_project(self.project,
                                     {'timezeros': ['2011-10-02']}))
        self._assert_list_of_lists_almost_equal([exp_rows[0], exp_rows[1]],
                                                act_rows)

        # ---- case: only one score: some score values exist. 10 = pit ----
        exp_rows_pit = [[
            row[0], row[1], row[2], row[3], row[4], row[5], row[10]
        ] for row in [exp_rows[0], exp_rows[2]]]
        act_rows = list(
            query_scores_for_project(self.project,
                                     {'scores': ['pit']}))  # hard-coded abbrev
        self._assert_list_of_lists_almost_equal(exp_rows_pit, act_rows)

        # ---- case: only one score: no score values exist. 11 = interval_2 ----
        exp_rows_interval_2 = [[
            row[0], row[1], row[2], row[3], row[4], row[5], row[11]
        ] for row in [exp_rows[0]]]  # just header
        act_rows = list(
            query_scores_for_project(
                self.project, {'scores': ['interval_2']}))  # hard-coded abbrev
        self._assert_list_of_lists_almost_equal(exp_rows_interval_2, act_rows)
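
The comparisons above go through _assert_list_of_lists_almost_equal(), which the comment at the top of the test motivates: postgres and sqlite3 round floats differently, so exact equality would be flaky. The helper itself is not shown in these examples; the following is a plausible sketch under that assumption, not the project's actual implementation.

import unittest


class _AlmostEqualRowsMixin(unittest.TestCase):
    # hypothetical sketch of the helper the tests above call; the real
    # project's implementation may differ
    def _assert_list_of_lists_almost_equal(self, exp_rows, act_rows):
        """Compare rows cell-by-cell, using assertAlmostEqual for float pairs
        so postgres/sqlite3 rounding differences don't cause spurious failures."""
        self.assertEqual(len(exp_rows), len(act_rows))
        for exp_row, act_row in zip(exp_rows, act_rows):
            self.assertEqual(len(exp_row), len(act_row))
            for exp_value, act_value in zip(exp_row, act_row):
                if isinstance(exp_value, float) and isinstance(act_value, float):
                    self.assertAlmostEqual(exp_value, act_value)
                else:
                    self.assertEqual(exp_value, act_value)
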
Example #21
    def test_query_forecasts_for_project(self):
        model = self.forecast_model.abbreviation
        tz = self.time_zero.timezero_date.strftime(YYYY_MM_DD_DATE_FORMAT)
        timezero_to_season_name = self.project.timezero_to_season_name()
        seas = timezero_to_season_name[self.time_zero]

        # ---- case: all BinDistributions in project. check cat and prob columns ----
        rows = list(
            query_forecasts_for_project(self.project, {
                'types':
                [PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[BinDistribution]]
            }))  # list for generator
        self.assertEqual(FORECAST_CSV_HEADER, rows.pop(0))

        exp_rows_bin = [
            (model, tz, seas, 'location1', 'Season peak week', 'bin',
             '2019-12-15', 0.01),
            (model, tz, seas, 'location1', 'Season peak week', 'bin',
             '2019-12-22', 0.1),
            (model, tz, seas, 'location1', 'Season peak week', 'bin',
             '2019-12-29', 0.89),
            (model, tz, seas, 'location1', 'season severity', 'bin',
             'moderate', 0.1),
            (model, tz, seas, 'location1', 'season severity', 'bin', 'severe',
             0.9),
            (model, tz, seas, 'location2', 'Season peak week', 'bin',
             '2019-12-15', 0.01),
            (model, tz, seas, 'location2', 'Season peak week', 'bin',
             '2019-12-22', 0.05),
            (model, tz, seas, 'location2', 'Season peak week', 'bin',
             '2019-12-29', 0.05),
            (model, tz, seas, 'location2', 'Season peak week', 'bin',
             '2020-01-05', 0.89),
            (model, tz, seas, 'location2', 'above baseline', 'bin', False,
             0.1),
            (model, tz, seas, 'location2', 'above baseline', 'bin', True, 0.9),
            (model, tz, seas, 'location2', 'pct next week', 'bin', 1.1, 0.3),
            (model, tz, seas, 'location2', 'pct next week', 'bin', 2.2, 0.2),
            (model, tz, seas, 'location2', 'pct next week', 'bin', 3.3, 0.5),
            (model, tz, seas, 'location3', 'cases next week', 'bin', 2, 0.1),
            (model, tz, seas, 'location3', 'cases next week', 'bin', 50, 0.9)
        ]  # sorted
        # model, timezero, season, unit, target, class, value, cat, prob, sample, quantile, family, param1, 2, 3
        act_rows = [(row[0], row[1], row[2], row[3], row[4], row[5], row[7],
                     row[8]) for row in rows]
        self.assertEqual(exp_rows_bin, sorted(act_rows))

        # ---- case: all NamedDistributions in project. check family and param1, 2, and 3 columns ----
        rows = list(
            query_forecasts_for_project(
                self.project, {
                    'types': [
                        PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[
                            NamedDistribution]
                    ]
                }))
        self.assertEqual(FORECAST_CSV_HEADER, rows.pop(0))

        exp_rows_named = [
            (model, tz, seas, 'location1', 'cases next week', 'named',
             NamedDistribution.FAMILY_CHOICE_TO_ABBREVIATION[
                 NamedDistribution.POIS_DIST], 1.1, None, None),
            (model, tz, seas, 'location1', 'pct next week', 'named',
             NamedDistribution.FAMILY_CHOICE_TO_ABBREVIATION[
                 NamedDistribution.NORM_DIST], 1.1, 2.2, None)
        ]  # sorted
        # model, timezero, season, unit, target, class, value, cat, prob, sample, quantile, family, param1, 2, 3
        act_rows = [(row[0], row[1], row[2], row[3], row[4], row[5], row[11],
                     row[12], row[13], row[14]) for row in rows]
        self.assertEqual(exp_rows_named, sorted(act_rows))

        # ---- case: all PointPredictions in project. check value column ----
        rows = list(
            query_forecasts_for_project(self.project, {
                'types':
                [PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[PointPrediction]]
            }))
        self.assertEqual(FORECAST_CSV_HEADER, rows.pop(0))

        exp_rows_point = [
            (model, tz, seas, 'location1', 'Season peak week', 'point',
             '2019-12-22'),
            (model, tz, seas, 'location1', 'above baseline', 'point', True),
            (model, tz, seas, 'location1', 'pct next week', 'point', 2.1),
            (model, tz, seas, 'location1', 'season severity', 'point', 'mild'),
            (model, tz, seas, 'location2', 'Season peak week', 'point',
             '2020-01-05'),
            (model, tz, seas, 'location2', 'cases next week', 'point', 5),
            (model, tz, seas, 'location2', 'pct next week', 'point', 2.0),
            (model, tz, seas, 'location2', 'season severity', 'point',
             'moderate'),
            (model, tz, seas, 'location3', 'Season peak week', 'point',
             '2019-12-29'),
            (model, tz, seas, 'location3', 'cases next week', 'point', 10),
            (model, tz, seas, 'location3', 'pct next week', 'point', 3.567)
        ]  # sorted
        # model, timezero, season, unit, target, class, value, cat, prob, sample, quantile, family, param1, 2, 3
        act_rows = [(row[0], row[1], row[2], row[3], row[4], row[5], row[6])
                    for row in rows]
        self.assertEqual(exp_rows_point, sorted(act_rows))

        # ---- case: all SampleDistributions in project. check sample column ----
        rows = list(
            query_forecasts_for_project(
                self.project, {
                    'types': [
                        PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[
                            SampleDistribution]
                    ]
                }))
        self.assertEqual(FORECAST_CSV_HEADER, rows.pop(0))

        exp_rows_sample = [
            (model, tz, seas, 'location1', 'Season peak week', 'sample',
             '2019-12-15'),
            (model, tz, seas, 'location1', 'Season peak week', 'sample',
             '2020-01-05'),
            (model, tz, seas, 'location2', 'above baseline', 'sample', False),
            (model, tz, seas, 'location2', 'above baseline', 'sample', True),
            (model, tz, seas, 'location2', 'above baseline', 'sample', True),
            (model, tz, seas, 'location2', 'cases next week', 'sample', 0),
            (model, tz, seas, 'location2', 'cases next week', 'sample', 2),
            (model, tz, seas, 'location2', 'cases next week', 'sample', 5),
            (model, tz, seas, 'location2', 'season severity', 'sample',
             'high'),
            (model, tz, seas, 'location2', 'season severity', 'sample',
             'mild'),
            (model, tz, seas, 'location2', 'season severity', 'sample',
             'moderate'),
            (model, tz, seas, 'location2', 'season severity', 'sample',
             'moderate'),
            (model, tz, seas, 'location2', 'season severity', 'sample',
             'severe'),
            (model, tz, seas, 'location3', 'Season peak week', 'sample',
             '2019-12-16'),
            (model, tz, seas, 'location3', 'Season peak week', 'sample',
             '2020-01-06'),
            (model, tz, seas, 'location3', 'above baseline', 'sample', False),
            (model, tz, seas, 'location3', 'above baseline', 'sample', True),
            (model, tz, seas, 'location3', 'above baseline', 'sample', True),
            (model, tz, seas, 'location3', 'pct next week', 'sample', 0.0),
            (model, tz, seas, 'location3', 'pct next week', 'sample', 0.0001),
            (model, tz, seas, 'location3', 'pct next week', 'sample', 2.3),
            (model, tz, seas, 'location3', 'pct next week', 'sample', 6.5),
            (model, tz, seas, 'location3', 'pct next week', 'sample', 10.0234)
        ]  # sorted
        # model, timezero, season, unit, target, class, value, cat, prob, sample, quantile, family, param1, 2, 3
        act_rows = [(row[0], row[1], row[2], row[3], row[4], row[5], row[9])
                    for row in rows]
        self.assertEqual(exp_rows_sample, sorted(act_rows))

        # ---- case: all QuantileDistributions in project. check quantile and value columns ----
        rows = list(
            query_forecasts_for_project(
                self.project, {
                    'types': [
                        PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[
                            QuantileDistribution]
                    ]
                }))
        self.assertEqual(FORECAST_CSV_HEADER, rows.pop(0))

        exp_rows_quantile = [(model, tz, seas, 'location2', 'Season peak week',
                              'quantile', 0.5, '2019-12-22'),
                             (model, tz, seas, 'location2', 'Season peak week',
                              'quantile', 0.75, '2019-12-29'),
                             (model, tz, seas, 'location2', 'Season peak week',
                              'quantile', 0.975, '2020-01-05'),
                             (model, tz, seas, 'location2', 'pct next week',
                              'quantile', 0.025, 1.0),
                             (model, tz, seas, 'location2', 'pct next week',
                              'quantile', 0.25, 2.2),
                             (model, tz, seas, 'location2', 'pct next week',
                              'quantile', 0.5, 2.2),
                             (model, tz, seas, 'location2', 'pct next week',
                              'quantile', 0.75, 5.0),
                             (model, tz, seas, 'location2', 'pct next week',
                              'quantile', 0.975, 50.0),
                             (model, tz, seas, 'location3', 'cases next week',
                              'quantile', 0.25, 0),
                             (model, tz, seas, 'location3', 'cases next week',
                              'quantile', 0.75, 50)]  # sorted
        # model, timezero, season, unit, target, class, value, cat, prob, sample, quantile, family, param1, 2, 3
        act_rows = [(row[0], row[1], row[2], row[3], row[4], row[5], row[10],
                     row[6]) for row in rows]
        self.assertEqual(exp_rows_quantile, sorted(act_rows))

        # ---- case: empty query -> all forecasts in project ----
        rows = list(query_forecasts_for_project(self.project, {}))
        self.assertEqual(FORECAST_CSV_HEADER, rows.pop(0))
        self.assertEqual(
            len(exp_rows_quantile + exp_rows_sample + exp_rows_point +
                exp_rows_named + exp_rows_bin), len(rows))

        # ---- case: only one unit ----
        rows = list(
            query_forecasts_for_project(self.project,
                                        {'units': ['location3']}))
        self.assertEqual(FORECAST_CSV_HEADER, rows.pop(0))
        self.assertEqual(17, len(rows))

        # ---- case: only one target ----
        rows = list(
            query_forecasts_for_project(self.project,
                                        {'targets': ['above baseline']}))
        self.assertEqual(FORECAST_CSV_HEADER, rows.pop(0))
        self.assertEqual(9, len(rows))

        # the following cases require a second model, timezero, and forecast
        forecast_model2 = ForecastModel.objects.create(project=self.project,
                                                       name=model,
                                                       abbreviation='abbrev')
        time_zero2 = TimeZero.objects.create(project=self.project,
                                             timezero_date=datetime.date(
                                                 2011, 10, 22))
        forecast2 = Forecast.objects.create(forecast_model=forecast_model2,
                                            source='docs-predictions.json',
                                            time_zero=time_zero2,
                                            notes="a small prediction file")
        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast2, json_io_dict_in,
                                               False)

        # ---- case: empty query -> all forecasts in project. should be twice as many now ----
        rows = list(query_forecasts_for_project(self.project, {}))
        self.assertEqual(FORECAST_CSV_HEADER, rows.pop(0))
        self.assertEqual(
            len(exp_rows_quantile + exp_rows_sample + exp_rows_point +
                exp_rows_named + exp_rows_bin) * 2, len(rows))

        # ---- case: only one timezero ----
        rows = list(
            query_forecasts_for_project(self.project,
                                        {'timezeros': ['2011-10-22']}))
        self.assertEqual(FORECAST_CSV_HEADER, rows.pop(0))
        self.assertEqual(
            len(exp_rows_quantile + exp_rows_sample + exp_rows_point +
                exp_rows_named + exp_rows_bin), len(rows))

        # ---- case: only one model ----
        rows = list(
            query_forecasts_for_project(self.project, {'models': ['abbrev']}))
        self.assertEqual(FORECAST_CSV_HEADER, rows.pop(0))
        self.assertEqual(
            len(exp_rows_quantile + exp_rows_sample + exp_rows_point +
                exp_rows_named + exp_rows_bin), len(rows))
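
The tuple slices above (row[6], row[9], row[10], and so on) hard-code column positions from the repeated comment "model, timezero, season, unit, target, class, value, cat, prob, sample, quantile, family, param1, 2, 3". A small hedged alternative that looks columns up by name, so the checks would survive reordering; the CSV_HEADER list below merely mirrors that comment and is not the project's FORECAST_CSV_HEADER constant.

# illustrative only: column names taken from the comment above
CSV_HEADER = ['model', 'timezero', 'season', 'unit', 'target', 'class',
              'value', 'cat', 'prob', 'sample', 'quantile', 'family',
              'param1', 'param2', 'param3']
COL = {name: idx for idx, name in enumerate(CSV_HEADER)}


def pluck(row, *names):
    """Select the named columns from a forecast CSV row, in the given order."""
    return tuple(row[COL[name]] for name in names)


row = ['docs_mod', '2011-10-02', '2011-2012', 'location2', 'pct next week',
       'quantile', 2.2, None, None, None, 0.5, None, None, None, None]
assert pluck(row, 'quantile', 'value') == (0.5, 2.2)
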
Example #22
    def test_json_io_dict_from_forecast(self):
        # tests that json_io_dict_from_forecast() preserves the output order of SampleDistributions
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project = create_project_from_json(
            Path('forecast_app/tests/projects/docs-project.json'), po_user)
        forecast_model = ForecastModel.objects.create(project=project,
                                                      name='name',
                                                      abbreviation='abbrev')
        time_zero = TimeZero.objects.create(project=project,
                                            timezero_date=datetime.date(
                                                2017, 1, 1))
        forecast = Forecast.objects.create(forecast_model=forecast_model,
                                           source='docs-predictions.json',
                                           time_zero=time_zero)

        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast, json_io_dict_in,
                                               False)
            # note: using APIRequestFactory was the only way I could find to pass a request object; otherwise you get:
            #   AssertionError: `HyperlinkedIdentityField` requires the request in the serializer context.
            json_io_dict_out = json_io_dict_from_forecast(
                forecast,
                APIRequestFactory().request())

        # test round trip. ignore meta, but spot-check it first
        out_meta = json_io_dict_out['meta']
        self.assertEqual({'targets', 'forecast', 'units'},
                         set(out_meta.keys()))
        self.assertEqual(
            {
                'cats', 'unit', 'name', 'is_step_ahead', 'type', 'description',
                'id', 'url'
            }, set(out_meta['targets'][0].keys()))
        self.assertEqual(
            {
                'time_zero', 'forecast_model', 'created_at', 'issue_date',
                'notes', 'forecast_data', 'source', 'id', 'url'
            }, set(out_meta['forecast'].keys()))
        self.assertEqual({'id', 'name', 'url'},
                         set(out_meta['units'][0].keys()))
        self.assertIsInstance(out_meta['forecast']['time_zero'],
                              dict)  # test that time_zero is expanded, not URL

        del json_io_dict_in['meta']
        del json_io_dict_out['meta']

        # delete the two zero-probability bins in the input (they are discarded when loading predictions):
        # - [11] "unit": "location3", "target": "cases next week", "class": "bin"
        # - [14] "unit": "location1", "target": "season severity", "class": "bin"
        del json_io_dict_in['predictions'][11]['prediction']['cat'][0]  # 0
        del json_io_dict_in['predictions'][11]['prediction']['prob'][0]  # 0.0
        del json_io_dict_in['predictions'][14]['prediction']['cat'][0]  # 'mild'
        del json_io_dict_in['predictions'][14]['prediction']['prob'][0]  # 0.0

        json_io_dict_in['predictions'].sort(
            key=lambda _: (_['unit'], _['target'], _['class']))
        json_io_dict_out['predictions'].sort(
            key=lambda _: (_['unit'], _['target'], _['class']))

        self.assertEqual(json_io_dict_out, json_io_dict_in)

        # spot-check some sample predictions
        sample_pred_dict = [
            pred_dict for pred_dict in json_io_dict_out['predictions']
            if (pred_dict['unit'] == 'location3') and (
                pred_dict['target'] == 'pct next week') and (
                    pred_dict['class'] == 'sample')
        ][0]
        self.assertEqual([2.3, 6.5, 0.0, 10.0234, 0.0001],
                         sample_pred_dict['prediction']['sample'])

        sample_pred_dict = [
            pred_dict for pred_dict in json_io_dict_out['predictions']
            if (pred_dict['unit'] == 'location2') and (
                pred_dict['target'] == 'season severity') and (
                    pred_dict['class'] == 'sample')
        ][0]
        self.assertEqual(['moderate', 'severe', 'high', 'moderate', 'mild'],
                         sample_pred_dict['prediction']['sample'])

        sample_pred_dict = [
            pred_dict for pred_dict in json_io_dict_out['predictions']
            if (pred_dict['unit'] == 'location1') and (
                pred_dict['target'] == 'Season peak week') and (
                    pred_dict['class'] == 'sample')
        ][0]
        self.assertEqual(['2020-01-05', '2019-12-15'],
                         sample_pred_dict['prediction']['sample'])

        # spot-check some quantile predictions
        quantile_pred_dict = [
            pred_dict for pred_dict in json_io_dict_out['predictions']
            if (pred_dict['unit'] == 'location2') and (
                pred_dict['target'] == 'pct next week') and (
                    pred_dict['class'] == 'quantile')
        ][0]
        self.assertEqual([0.025, 0.25, 0.5, 0.75, 0.975],
                         quantile_pred_dict['prediction']['quantile'])
        self.assertEqual([1.0, 2.2, 2.2, 5.0, 50.0],
                         quantile_pred_dict['prediction']['value'])

        quantile_pred_dict = [
            pred_dict for pred_dict in json_io_dict_out['predictions']
            if (pred_dict['unit'] == 'location2') and (
                pred_dict['target'] == 'Season peak week') and (
                    pred_dict['class'] == 'quantile')
        ][0]
        self.assertEqual([0.5, 0.75, 0.975],
                         quantile_pred_dict['prediction']['quantile'])
        self.assertEqual(["2019-12-22", "2019-12-29", "2020-01-05"],
                         quantile_pred_dict['prediction']['value'])
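
The four deletions before the round-trip comparison compensate for a loading rule: zero-probability bin categories are discarded when predictions are loaded, so they never come back out. A minimal sketch of that filter, assuming a bin prediction dict with parallel 'cat' and 'prob' lists; it illustrates the rule the test works around, not the loader's actual code.

def drop_zero_prob_bins(prediction):
    """Return a bin prediction dict with zero-probability cats removed."""
    kept = [(cat, prob) for cat, prob
            in zip(prediction['cat'], prediction['prob']) if prob != 0]
    cats, probs = zip(*kept) if kept else ((), ())
    return {'cat': list(cats), 'prob': list(probs)}


# mirrors prediction [14] above: the 0.0-probability 'mild' bin is dropped
pred = {'cat': ['mild', 'moderate', 'severe'], 'prob': [0.0, 0.1, 0.9]}
assert drop_zero_prob_bins(pred) == {'cat': ['moderate', 'severe'],
                                     'prob': [0.1, 0.9]}
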
Example #23
def _load_truth_data(project, oracle_model, truth_file_fp, file_name,
                     is_convert_na_none):
    from forecast_app.models import Forecast  # avoid circular imports
    from utils.forecast import load_predictions_from_json_io_dict  # avoid circular imports

    # load, validate, and replace with objects and parsed values.
    # rows: (timezero, unit, target, parsed_value) (first three are objects)
    logger.debug("_load_truth_data(): entered. calling _read_truth_data_rows()")
    rows = _read_truth_data_rows(project, truth_file_fp, is_convert_na_none)
    if not rows:
        return 0

    # group rows by timezero, then create and load an oracle Forecast for each group. we leverage
    # load_predictions_from_json_io_dict() by building a json_io_dict for the truth data where each truth row
    # becomes its own 'point' prediction element. notes:
    # - these forecasts are identified as coming from the same truth file (aka "batch") by setting the same
    #   source and issued_at on all of them at the end
    # - we collect "cannot load 100% duplicate data" RuntimeErrors so we can count them at the end. the rule is
    #   that at least one oracle forecast must load without that error
    timezero_groups = defaultdict(list)
    for timezero, unit, target, parsed_value in rows:
        timezero_groups[timezero].append([unit, target, parsed_value])

    source = file_name if file_name else ''
    forecasts = []  # ones created
    forecasts_100pct_dup = []  # ones that raised RuntimeError "cannot load 100% duplicate data"
    logger.debug(
        f"_load_truth_data(): creating and loading {len(timezero_groups)} forecasts. source={source!r}"
    )
    point_class = PRED_CLASS_INT_TO_NAME[PredictionElement.POINT_CLASS]
    for timezero, timezero_rows in timezero_groups.items():
        forecast = Forecast.objects.create(forecast_model=oracle_model,
                                           source=source,
                                           time_zero=timezero,
                                           notes=f"oracle forecast")
        prediction_dicts = [{
            'unit': unit.name,
            'target': target.name,
            'class': point_class,
            'prediction': {
                'value':
                parsed_value.strftime(YYYY_MM_DD_DATE_FORMAT) if isinstance(
                    parsed_value, datetime.date) else parsed_value
            }
        } for unit, target, parsed_value in timezero_rows]
        try:
            load_predictions_from_json_io_dict(
                forecast, {
                    'meta': {},
                    'predictions': prediction_dicts
                },
                is_skip_validation=True,
                is_subset_allowed=True)  # NB: is_subset_allowed
            forecasts.append(forecast)
        except RuntimeError as rte:
            # todo instead of testing for a string, load_predictions_from_json_io_dict() should raise an application-
            # specific RuntimeError subclass
            if rte.args[0].startswith('cannot load 100% duplicate data'):
                forecasts_100pct_dup.append(forecast)
            else:
                raise rte

    # delete duplicate forecasts
    if forecasts_100pct_dup:
        Forecast.objects.filter(id__in=[f.id for f in forecasts_100pct_dup]) \
            .delete()

    # error if all oracle forecasts were 100% duplicate data
    if forecasts_100pct_dup and not forecasts:
        raise RuntimeError(
            f"cannot load 100% duplicate data (all {len(forecasts_100pct_dup)} oracle forecasts were "
            f"100% duplicate data)")

    # set all issued_ats to be the same - this avoids an edge case where midnight is spanned and some are a day later.
    # arbitrarily use the first forecast's issued_at
    if forecasts:
        issued_at = forecasts[0].issued_at
        logger.debug(
            f"_load_truth_data(): setting issued_ats to {issued_at}. # forecasts={len(forecasts)}, "
            f"# 100% dup data forecasts={len(forecasts_100pct_dup)}")
        for forecast in forecasts:
            forecast.issued_at = issued_at
            forecast.save()

    logger.debug(f"_load_truth_data(): done")
    return len(rows)
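
For context, a hypothetical call site for _load_truth_data(); the project and oracle_model objects are assumed to exist already, and the CSV path is the docs project's test fixture.

# hypothetical usage sketch; `project` and `oracle_model` are assumptions
with open('forecast_app/tests/truth_data/docs-ground-truth.csv') as truth_fp:
    num_rows = _load_truth_data(project, oracle_model, truth_fp,
                                'docs-ground-truth.csv',
                                is_convert_na_none=True)
print(f"loaded {num_rows} truth rows")
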
Example #24
    def test_json_io_dict_from_forecast(self):
        # tests that json_io_dict_from_forecast() preserves the output order of SampleDistributions
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project = create_project_from_json(
            Path('forecast_app/tests/projects/docs-project.json'), po_user)
        forecast_model = ForecastModel.objects.create(project=project,
                                                      name='name',
                                                      abbreviation='abbrev')
        time_zero = TimeZero.objects.create(project=project,
                                            timezero_date=datetime.date(
                                                2017, 1, 1))
        forecast = Forecast.objects.create(forecast_model=forecast_model,
                                           source='docs-predictions.json',
                                           time_zero=time_zero)

        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast,
                                               json_io_dict_in,
                                               is_validate_cats=False)
            json_io_dict_out = json_io_dict_from_forecast(
                forecast,
                APIRequestFactory().request())

        # test round trip. ignore meta, but spot-check it first
        out_meta = json_io_dict_out['meta']
        self.assertEqual({'targets', 'forecast', 'units'},
                         set(out_meta.keys()))
        self.assertEqual(
            {
                'cats', 'unit', 'name', 'is_step_ahead', 'type', 'description',
                'id', 'url'
            }, set(out_meta['targets'][0].keys()))
        self.assertEqual(
            {
                'time_zero', 'forecast_model', 'created_at', 'issued_at',
                'notes', 'forecast_data', 'source', 'id', 'url'
            }, set(out_meta['forecast'].keys()))
        self.assertEqual({'id', 'name', 'url'},
                         set(out_meta['units'][0].keys()))
        self.assertIsInstance(out_meta['forecast']['time_zero'],
                              dict)  # test that time_zero is expanded, not URL

        del json_io_dict_in['meta']
        del json_io_dict_out['meta']

        json_io_dict_in['predictions'].sort(
            key=lambda _: (_['unit'], _['target'], _['class']))
        json_io_dict_out['predictions'].sort(
            key=lambda _: (_['unit'], _['target'], _['class']))

        self.assertEqual(json_io_dict_in, json_io_dict_out)

        # spot-check some sample predictions
        sample_pred_dict = [
            pred_dict for pred_dict in json_io_dict_out['predictions']
            if (pred_dict['unit'] == 'location3') and (
                pred_dict['target'] == 'pct next week') and (
                    pred_dict['class'] == 'sample')
        ][0]
        self.assertEqual([2.3, 6.5, 0.0, 10.0234, 0.0001],
                         sample_pred_dict['prediction']['sample'])

        sample_pred_dict = [
            pred_dict for pred_dict in json_io_dict_out['predictions']
            if (pred_dict['unit'] == 'location2') and (
                pred_dict['target'] == 'season severity') and (
                    pred_dict['class'] == 'sample')
        ][0]
        self.assertEqual(['moderate', 'severe', 'high', 'moderate', 'mild'],
                         sample_pred_dict['prediction']['sample'])

        sample_pred_dict = [
            pred_dict for pred_dict in json_io_dict_out['predictions']
            if (pred_dict['unit'] == 'location1') and (
                pred_dict['target'] == 'Season peak week') and (
                    pred_dict['class'] == 'sample')
        ][0]
        self.assertEqual(['2020-01-05', '2019-12-15'],
                         sample_pred_dict['prediction']['sample'])

        # spot-check some quantile predictions
        quantile_pred_dict = [
            pred_dict for pred_dict in json_io_dict_out['predictions']
            if (pred_dict['unit'] == 'location2') and (
                pred_dict['target'] == 'pct next week') and (
                    pred_dict['class'] == 'quantile')
        ][0]
        self.assertEqual([0.025, 0.25, 0.5, 0.75, 0.975],
                         quantile_pred_dict['prediction']['quantile'])
        self.assertEqual([1.0, 2.2, 2.2, 5.0, 50.0],
                         quantile_pred_dict['prediction']['value'])

        quantile_pred_dict = [
            pred_dict for pred_dict in json_io_dict_out['predictions']
            if (pred_dict['unit'] == 'location2') and (
                pred_dict['target'] == 'Season peak week') and (
                    pred_dict['class'] == 'quantile')
        ][0]
        self.assertEqual([0.5, 0.75, 0.975],
                         quantile_pred_dict['prediction']['quantile'])
        self.assertEqual(["2019-12-22", "2019-12-29", "2020-01-05"],
                         quantile_pred_dict['prediction']['value'])
    def test_valid_prediction_types_by_target_type(self):
        # test invalid combinations of prediction types by target type (valid combos are tested elsewhere). see the
        # table at https://docs.zoltardata.com/targets/#valid-prediction-types-by-target-type . invalid combos:
        #
        # PredictionElement.NAMED_CLASS    + Target.CONTINUOUS_TARGET_TYPE + (pois, nbinom, nbinom2)
        # PredictionElement.NAMED_CLASS    + Target.DISCRETE_TARGET_TYPE   + (norm, lnorm, gamma, beta)
        # PredictionElement.NAMED_CLASS    + Target.NOMINAL_TARGET_TYPE    + any family (norm, lnorm, gamma, beta, pois, nbinom, nbinom2)
        # PredictionElement.NAMED_CLASS    + Target.BINARY_TARGET_TYPE     + any family
        # PredictionElement.NAMED_CLASS    + Target.DATE_TARGET_TYPE       + any family
        # PredictionElement.QUANTILE_CLASS + Target.NOMINAL_TARGET_TYPE
        # PredictionElement.QUANTILE_CLASS + Target.BINARY_TARGET_TYPE

        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)
        forecast.issued_at -= datetime.timedelta(
            days=1)  # make it an older version to avoid unique constraint errors
        forecast.save()
        forecast2 = Forecast.objects.create(forecast_model=forecast_model,
                                            time_zero=time_zero)

        # test PredictionElement.NAMED_CLASS
        all_families = (NamedData.NORM_DIST, NamedData.LNORM_DIST,
                        NamedData.GAMMA_DIST, NamedData.BETA_DIST,
                        NamedData.POIS_DIST, NamedData.NBINOM_DIST,
                        NamedData.NBINOM2_DIST)
        bad_target_families = [
            (
                'pct next week',
                (
                    NamedData.POIS_DIST,  # Target.CONTINUOUS_TARGET_TYPE
                    NamedData.NBINOM_DIST,
                    NamedData.NBINOM2_DIST)),
            (
                'cases next week',
                (
                    NamedData.NORM_DIST,  # Target.DISCRETE_TARGET_TYPE
                    NamedData.LNORM_DIST,
                    NamedData.GAMMA_DIST,
                    NamedData.BETA_DIST)),
            ('season severity', all_families),  # Target.NOMINAL_TARGET_TYPE
            ('above baseline', all_families),  # Target.BINARY_TARGET_TYPE
            ('Season peak week', all_families),  # Target.DATE_TARGET_TYPE
        ]
        for target, families in bad_target_families:
            for family_abbrev in families:
                prediction_dict = {
                    'unit': 'location1',
                    'target': target,
                    'class': 'named',
                    'prediction': {
                        'family': family_abbrev,
                        'param1': 1.1,
                        'param2': 2.2,
                        'param3': 3.3
                    }
                }
                with self.assertRaises(RuntimeError) as context:
                    load_predictions_from_json_io_dict(
                        forecast2, {'predictions': [prediction_dict]})
                self.assertIn('is not valid for', str(context.exception))

        # test PredictionElement.QUANTILE_CLASS
        bad_target_pred_data = [
            (
                'season severity',
                {
                    "quantile": [0.25, 0.75],  # Target.NOMINAL_TARGET_TYPE
                    "value": ["mild", "moderate"]
                }),
            (
                'above baseline',
                {
                    "quantile": [0.25, 0.75],  # Target.BINARY_TARGET_TYPE
                    "value": [True, False]
                })
        ]
        for target, pred_data in bad_target_pred_data:
            prediction_dict = {
                'unit': 'location1',
                'target': target,
                'class': 'quantile',
                'prediction': pred_data
            }
            with self.assertRaises(RuntimeError) as context:
                load_predictions_from_json_io_dict(
                    forecast2, {'predictions': [prediction_dict]})
            self.assertIn('is not valid for', str(context.exception))
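
The invalid-combination table in the comment at the top of this test lends itself to a data-driven encoding. The sketch below restates it as plain dicts with illustrative string labels; it is a clarifying restatement of that table, not the validation code load_predictions_from_json_io_dict() actually runs.

# invalid (prediction class, target type) combinations, per the table above.
# for 'named' the value is the set of disallowed families; for 'quantile'
# the target type is disallowed outright. labels are illustrative strings.
ALL_FAMILIES = {'norm', 'lnorm', 'gamma', 'beta', 'pois', 'nbinom', 'nbinom2'}
INVALID_NAMED = {
    'continuous': {'pois', 'nbinom', 'nbinom2'},
    'discrete': {'norm', 'lnorm', 'gamma', 'beta'},
    'nominal': ALL_FAMILIES,
    'binary': ALL_FAMILIES,
    'date': ALL_FAMILIES,
}
INVALID_QUANTILE_TARGET_TYPES = {'nominal', 'binary'}


def is_valid_combo(pred_class, target_type, family=None):
    """True if the prediction class (and family, for 'named') is allowed
    for the target type, per the invalid-combination table above."""
    if pred_class == 'named':
        return family not in INVALID_NAMED.get(target_type, set())
    if pred_class == 'quantile':
        return target_type not in INVALID_QUANTILE_TARGET_TYPES
    return True


assert not is_valid_combo('named', 'continuous', 'pois')
assert is_valid_combo('named', 'continuous', 'norm')
assert not is_valid_combo('quantile', 'binary')
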