Ejemplo n.º 1
0
    def test_load_truth_data_versions(self):
        """
        Re-loading identical truth data is rejected with a RuntimeError, whereas loading a non-duplicate truth file
        succeeds and doubles both the number of oracle forecasts and the number of truth rows.
        """
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
        project, _, _, _ = _make_docs_project(po_user)  # loads docs-ground-truth.csv
        truth_dir = Path('forecast_app/tests/truth_data')
        num_oracle_forecasts, num_truth_rows = 3, 14  # 3 timezeros: 2011-10-02, 2011-10-09, 2011-10-16

        oracle_model = oracle_model_for_project(project)
        self.assertEqual(num_oracle_forecasts, oracle_model.forecasts.count())
        self.assertEqual(num_truth_rows, truth_data_qs(project).count())
        self.assertTrue(is_truth_data_loaded(project))

        # the exact same file a second time adds nothing new, so it must be refused
        dup_file_name = 'docs-ground-truth.csv'
        with self.assertRaisesRegex(RuntimeError, 'cannot load 100% duplicate data'):
            load_truth_data(project, truth_dir / dup_file_name, file_name=dup_file_name)

        # a non-duplicate file is accepted as new versions alongside the existing ones
        non_dup_file_name = 'docs-ground-truth-non-dup.csv'
        load_truth_data(project, truth_dir / non_dup_file_name, file_name=non_dup_file_name)
        self.assertEqual(num_oracle_forecasts * 2, oracle_model.forecasts.count())
        self.assertEqual(num_truth_rows * 2, truth_data_qs(project).count())
Ejemplo n.º 2
0
    def test_load_truth_data(self):
        """
        Checks truth loading in three situations:
        - A clean CSV loads all of its rows.
        - CSVs that reference a timezero, unit or target the project lacks still load; the bad rows are skipped,
          and the resulting subsets are allowed.
        - A second, independently configured project gets its own truth data and preview.
        """
        truth_dir = Path('forecast_app/tests/truth_data')
        load_truth_data(self.project, truth_dir / 'truths-ok.csv', is_convert_na_none=True)
        self.assertEqual(5, truth_data_qs(self.project).count())
        self.assertTrue(is_truth_data_loaded(self.project))

        # Each file below references one object missing from the project: a TimeZero (2017-01-02), a unit, or a
        # target. The offending rows are skipped (for the timezero case, by _read_truth_data_rows()), so the data
        # actually loaded is a subset of what's already there. That used to raise 'new data is a subset of
        # previous', but subsets have been allowed since
        # [support truth "diff" uploads #319](https://github.com/reichlab/forecast-repository/issues/319).
        for bad_file_name in ('truths-bad-timezero.csv', 'truths-bad-location.csv', 'truths-bad-target.csv'):
            load_truth_data(self.project, truth_dir / bad_file_name, bad_file_name, is_convert_na_none=True)

        project2 = Project.objects.create()
        make_cdc_units_and_targets(project2)
        self.assertEqual(0, truth_data_qs(project2).count())
        self.assertFalse(is_truth_data_loaded(project2))

        timezero_date = datetime.date(2017, 1, 1)
        TimeZero.objects.create(project=project2, timezero_date=timezero_date)
        load_truth_data(project2, truth_dir / 'truths-ok.csv', is_convert_na_none=True)
        self.assertEqual(5, truth_data_qs(project2).count())

        # get_truth_data_preview() should yield exactly these (timezero_date, unit, target, value) tuples
        target_values = [('1 wk ahead', 0.73102),
                         ('2 wk ahead', 0.688338),
                         ('3 wk ahead', 0.732049),
                         ('4 wk ahead', 0.911641),
                         ('Season onset', '2017-11-20')]
        exp_truth_preview = [(timezero_date, 'US National', target_name, value)
                             for target_name, value in target_values]
        self.assertEqual(sorted(exp_truth_preview), sorted(get_truth_data_preview(project2)))
Ejemplo n.º 3
0
    def test_load_truth_data_dups(self):
        """
        Loads the null-value truth fixture, then loads the very same file again.

        Expects the first load to store all 14 rows. The identical second load must be rejected as 100% duplicate
        data, leaving the existing truth rows untouched. (The original version compared the counts to -1, which
        `count()` can never return, so the test could not pass.)
        """
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
        project = create_project_from_json(Path('forecast_app/tests/projects/docs-project.json'), po_user)
        truth_path = Path('forecast_app/tests/truth_data/docs-ground-truth-null-value.csv')

        # every row of the fixture is loaded, including the ones whose values are all null
        load_truth_data(project, truth_path, is_convert_na_none=True)
        self.assertEqual(14, truth_data_qs(project).count())

        # identical data must not create a new truth version
        with self.assertRaisesRegex(RuntimeError, 'cannot load 100% duplicate data'):
            load_truth_data(project, truth_path, is_convert_na_none=True)
        self.assertEqual(14, truth_data_qs(project).count())
Ejemplo n.º 4
0
    def test_load_truth_data_diff(self):
        """
        Checks that loading truth relaxes this forecast version rule (issue
        [support truth "diff" uploads #319](https://github.com/reichlab/forecast-repository/issues/319) ):
            3. New forecast versions cannot imply any retracted prediction elements in existing versions, i.e., you
            cannot load data that's a subset of the previous forecast's data.
        A truth file holding only some of the rows loads without error and adds just those rows.
        """
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
        project, _, _, _ = _make_docs_project(po_user)  # loads docs-ground-truth.csv
        oracle_model = oracle_model_for_project(project)

        def assert_counts(exp_num_forecasts, exp_num_truth):
            self.assertEqual(exp_num_forecasts, oracle_model.forecasts.count())
            self.assertEqual(exp_num_truth, truth_data_qs(project).count())

        assert_counts(3, 14)  # oracle forecasts for 3 timezeros: 2011-10-02, 2011-10-09, 2011-10-16

        # the diff file only updates the five location2 rows: one more oracle forecast, five more truth rows
        diff_file_name = 'docs-ground-truth-diff.csv'
        load_truth_data(project, Path('forecast_app/tests/truth_data') / diff_file_name, file_name=diff_file_name)
        assert_counts(3 + 1, 14 + 5)
Ejemplo n.º 5
0
 def test_load_truth_data_null_rows(self):
     """
     Loads a truth file containing null values with is_convert_na_none=True. Checks that the nulls end up as None
     in the value_* columns, and that rows whose values are all None are still stored.
     """
     _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
     project = create_project_from_json(Path('forecast_app/tests/projects/docs-project.json'), po_user)
     load_truth_data(project, Path('forecast_app/tests/truth_data/docs-ground-truth-null-value.csv'),
                     is_convert_na_none=True)

     tz_oct_02 = datetime.date(2011, 10, 2)
     tz_oct_09 = datetime.date(2011, 10, 9)
     tz_oct_16 = datetime.date(2011, 10, 16)
     # row layout: (timezero_date, unit, target, value_i, value_f, value_t, value_d, value_b)
     exp_rows = [
         (tz_oct_02, 'location1', 'Season peak week', None, None, None, datetime.date(2019, 12, 15), None),
         (tz_oct_02, 'location1', 'above baseline', None, None, None, None, True),
         (tz_oct_02, 'location1', 'season severity', None, None, 'moderate', None, None),
         (tz_oct_02, 'location1', 'cases next week', None, None, None, None, None),  # all None
         (tz_oct_02, 'location1', 'pct next week', None, None, None, None, None),  # all None
         (tz_oct_09, 'location2', 'Season peak week', None, None, None, datetime.date(2019, 12, 29), None),
         (tz_oct_09, 'location2', 'above baseline', None, None, None, None, True),
         (tz_oct_09, 'location2', 'season severity', None, None, 'severe', None, None),
         (tz_oct_09, 'location2', 'cases next week', 3, None, None, None, None),
         (tz_oct_09, 'location2', 'pct next week', None, 99.9, None, None, None),
         (tz_oct_16, 'location1', 'Season peak week', None, None, None, datetime.date(2019, 12, 22), None),
         (tz_oct_16, 'location1', 'above baseline', None, None, None, None, False),
         (tz_oct_16, 'location1', 'cases next week', 0, None, None, None, None),
         (tz_oct_16, 'location1', 'pct next week', None, 0.0, None, None, None),
     ]
     act_rows = truth_data_qs(project).values_list(
         'pred_ele__forecast__time_zero__timezero_date', 'pred_ele__unit__name', 'pred_ele__target__name',
         'value_i', 'value_f', 'value_t', 'value_d', 'value_b')
     self.assertEqual(sorted(exp_rows), sorted(act_rows))
Ejemplo n.º 6
0
def database_changes_for_project_config_diff(project, changes):
    """
    Reports how many database rows each removal in `changes` would delete from `project`.

    Only ChangeType.OBJ_REMOVED changes can delete rows. PROJECT-level changes and all other change types are
    skipped.

    :param project: the Project whose data would be affected
    :param changes: list of Changes as returned by project_config_diff()
    :return: a list of 3-tuples (change, num_pred_eles, num_truth), in order_project_config_diff() order. A change
        is included only if it would delete at least one non-oracle PredictionElement.
    """
    forecast_pred_eles = PredictionElement.objects.filter(forecast__forecast_model__project=project,
                                                          forecast__forecast_model__is_oracle=False)
    truth_pred_eles = truth_data_qs(project)

    def row_counts(change):
        # Map the removed Unit/Target/TimeZero to the filter kwargs that select its rows.
        # object_for_change() raises if the object can't be found.
        removed_obj = object_for_change(project, change, [])
        if change.object_type == ObjectType.UNIT:
            lookup = {'unit': removed_obj}
        elif change.object_type == ObjectType.TARGET:
            lookup = {'target': removed_obj}
        else:  # any other object type is treated as a TimeZero
            lookup = {'forecast__time_zero': removed_obj}
        return forecast_pred_eles.filter(**lookup).count(), truth_pred_eles.filter(**lookup).count()

    impacts = []
    for change in order_project_config_diff(changes):
        is_removal = change.change_type == ChangeType.OBJ_REMOVED
        if change.object_type == ObjectType.PROJECT or not is_removal:
            continue
        num_points, num_truth = row_counts(change)
        if num_points:
            impacts.append((change, num_points, num_truth))
    return impacts