def test_diff_from_file(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, _, _, _ = _make_docs_project(po_user)
        out_config_dict = config_dict_from_project(
            project,
            APIRequestFactory().request())

        # this json file makes the same changes as _make_some_changes():
        with open(
                Path('forecast_app/tests/project_diff/docs-project-edited.json'
                     )) as fp:
            edited_config_dict = json.load(fp)
        changes = project_config_diff(out_config_dict, edited_config_dict)

        # # print a little report
        # print(f"* Analyzed {len(changes)} changes. Results:")
        # for change, num_points, num_named, num_bins, num_samples, num_quantiles, num_truth in \
        #         database_changes_for_project_config_diff(project, changes):
        #     print(f"- {change.change_type.name} on {change.object_type.name} {change.object_pk!r} will delete:\n"
        #           f"  = {num_points} point predictions\n"
        #           f"  = {num_named} named predictions\n"
        #           f"  = {num_bins} bin predictions\n"
        #           f"  = {num_samples} samples\n"
        #           f"  = {num_quantiles} quantiles\n"
        #           f"  = {num_truth} truth rows")

        # same tests as test_execute_project_config_diff():
        execute_project_config_diff(project, changes)
        self._do_make_some_changes_tests(project)

    def test_duplicate_abbreviation(self):
        # duplicate names are OK, but not abbreviations
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)

        # case: new name, duplicate abbreviation
        with self.assertRaises(ValidationError) as context:
            ForecastModel.objects.create(
                project=project,
                name=forecast_model.name + '2',
                abbreviation=forecast_model.abbreviation)
        self.assertIn('abbreviation must be unique', str(context.exception))

        # case: duplicate name, new abbreviation
        try:
            forecast_model2 = ForecastModel.objects.create(
                project=project,
                name=forecast_model.name,
                abbreviation=forecast_model.abbreviation + '2')
            forecast_model2.delete()
        except Exception as ex:
            self.fail(f"unexpected exception: {ex}")

        # case: new name, duplicate abbreviation, but saving existing forecast_model
        try:
            forecast_model.name = forecast_model.name + '2'  # new name, duplicate abbreviation
            forecast_model.save()

            forecast_model.abbreviation = forecast_model.abbreviation + '2'  # duplicate name, new abbreviation
            forecast_model.save()
        except Exception as ex:
            self.fail(f"unexpected exception: {ex}")
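
    # A minimal sketch (hypothetical, not the real ForecastModel implementation) of the rule the test above
    # exercises: within a project, model names may repeat but abbreviations may not, and re-saving an
    # existing model must not collide with its own row.
    def _assert_unique_abbreviation(self, project, abbreviation, instance_pk=None):
        qs = ForecastModel.objects.filter(project=project, abbreviation=abbreviation)
        if instance_pk is not None:
            qs = qs.exclude(pk=instance_pk)  # allow re-saving an existing model under its own abbreviation
        if qs.exists():
            raise ValidationError('abbreviation must be unique')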
    def test_last_update(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)

        # one truth and one forecast (yes truth, yes forecasts)
        self.assertEqual(forecast.created_at, project.last_update())

        # add a second forecast for a newer timezero (yes truth, yes forecasts)
        time_zero2 = TimeZero.objects.create(project=project,
                                             timezero_date=datetime.date(
                                                 2011, 10, 3))
        forecast2 = Forecast.objects.create(
            forecast_model=forecast_model,
            source='docs-predictions-non-dup.json',
            time_zero=time_zero2,
            notes="a small prediction file")
        with open(
                'forecast_app/tests/predictions/docs-predictions-non-dup.json'
        ) as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast2,
                                               json_io_dict_in,
                                               is_validate_cats=False)
        self.assertEqual(forecast2.created_at, project.last_update())
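
    # a sketch (hypothetical, not the real implementation) of what last_update() evidently computes, given
    # the assertions above: the created_at of the project's most recently created forecast:
    #
    #   def last_update(self):
    #       newest = Forecast.objects.filter(forecast_model__project=self).order_by('created_at').last()
    #       return newest.created_at if newest else None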
    def test_load_truth_data_versions(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)  # loads docs-ground-truth.csv

        oracle_model = oracle_model_for_project(project)
        self.assertEqual(3, oracle_model.forecasts.count())  # one per timezero: 2011-10-02, 2011-10-09, 2011-10-16
        self.assertEqual(14, truth_data_qs(project).count())
        self.assertTrue(is_truth_data_loaded(project))

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot load 100% duplicate data'):
            load_truth_data(
                project,
                Path('forecast_app/tests/truth_data/docs-ground-truth.csv'),
                file_name='docs-ground-truth.csv')

        load_truth_data(
            project,
            Path(
                'forecast_app/tests/truth_data/docs-ground-truth-non-dup.csv'),
            file_name='docs-ground-truth-non-dup.csv')
        self.assertEqual(3 * 2, oracle_model.forecasts.count())
        self.assertEqual(14 * 2, truth_data_qs(project).count())
    def test__upload_forecast_worker_blue_sky(self):
        # blue sky to verify load_predictions_from_json_io_dict() and cache_forecast_metadata() are called. also tests
        # that _upload_forecast_worker() correctly sets job.output_json. this test is complicated by that function's use
        # of the `job_cloud_file` context manager. solution is per https://stackoverflow.com/questions/60198229/python-patch-context-manager-to-return-object
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)
        forecast.issued_at -= datetime.timedelta(
            days=1)  # older version avoids unique constraint errors
        forecast.save()

        with patch('forecast_app.models.job.job_cloud_file') as job_cloud_file_mock, \
                patch('utils.forecast.load_predictions_from_json_io_dict') as load_preds_mock, \
                patch('utils.forecast.cache_forecast_metadata') as cache_metadata_mock, \
                open('forecast_app/tests/predictions/docs-predictions.json') as cloud_file_fp:
            job = Job.objects.create()
            job.input_json = {
                'forecast_pk': forecast.pk,
                'filename': 'a name!'
            }
            job.save()
            job_cloud_file_mock.return_value.__enter__.return_value = (
                job, cloud_file_fp)
            _upload_forecast_worker(job.pk)
            job.refresh_from_db()
            load_preds_mock.assert_called_once()
            cache_metadata_mock.assert_called_once()
            self.assertEqual(Job.SUCCESS, job.status)
            self.assertEqual(job.input_json['forecast_pk'],
                             job.output_json['forecast_pk'])
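
    # a self-contained sketch of the context-manager patching pattern used above (per the Stack Overflow
    # link): whatever a `with` statement yields is `mock.return_value.__enter__.return_value`, so assigning
    # that attribute controls the tuple the code under test unpacks. the yielded values here are arbitrary
    # sentinels, not the real (job, cloud_file_fp) pair:
    def test_patched_context_manager_sketch(self):
        with patch('forecast_app.models.job.job_cloud_file') as job_cloud_file_mock:
            job_cloud_file_mock.return_value.__enter__.return_value = ('a job', 'a file pointer')

            from forecast_app.models.job import job_cloud_file  # resolves to the mock while patched
            with job_cloud_file('args are ignored by the mock') as (job, cloud_file_fp):
                self.assertEqual('a job', job)
                self.assertEqual('a file pointer', cloud_file_fp)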
    @classmethod
    def setUpTestData(cls):
        # recall that _make_docs_project() calls cache_forecast_metadata(), but the below tests assume it doesn't,
        # so we clear it here
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
        cls.project, cls.time_zero, cls.forecast_model, cls.forecast = _make_docs_project(po_user)
        clear_forecast_metadata(cls.forecast)
        cls.forecast.issued_at -= datetime.timedelta(days=1)  # older version avoids unique constraint errors
        cls.forecast.save()
    def test_models_summary_table_rows_for_project(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)

        # test with just one forecast - oldest and newest forecast are the same. each row is a 6-tuple:
        #   [forecast_model, num_forecasts, oldest_forecast_tz_date, newest_forecast_tz_date,
        #    newest_forecast_id, newest_forecast_created_at].
        # NB: we have to work around a Django bug where DateField and DateTimeField come out of the database as
        # either datetime.date/datetime.datetime objects (postgres) or strings (sqlite3)
        exp_row = (
            forecast_model,
            forecast_model.forecasts.count(),
            str(time_zero.timezero_date),  # oldest_forecast_tz_date
            str(time_zero.timezero_date),  # newest_forecast_tz_date
            forecast.id,
            forecast.created_at.utctimetuple())  # newest_forecast_created_at
        act_rows = models_summary_table_rows_for_project(project)
        act_rows = [(
            act_rows[0][0],
            act_rows[0][1],
            str(act_rows[0][2]),  # oldest_forecast_tz_date
            str(act_rows[0][3]),  # newest_forecast_tz_date
            act_rows[0][4],
            act_rows[0][5].utctimetuple())]  # newest_forecast_created_at

        self.assertEqual([exp_row], act_rows)

        # test a second forecast
        time_zero2 = TimeZero.objects.create(project=project,
                                             timezero_date=datetime.date(
                                                 2017, 1, 1))
        forecast2 = Forecast.objects.create(forecast_model=forecast_model,
                                            source='docs-predictions.json',
                                            time_zero=time_zero2)
        exp_row = (
            forecast_model,
            forecast_model.forecasts.count(),
            str(time_zero.timezero_date),  # oldest_forecast_tz_date
            str(time_zero2.timezero_date),  # newest_forecast_tz_date
            forecast2.id,
            forecast2.created_at.utctimetuple())  # newest_forecast_created_at
        act_rows = models_summary_table_rows_for_project(project)
        act_rows = [(
            act_rows[0][0],
            act_rows[0][1],
            str(act_rows[0][2]),  # oldest_forecast_tz_date
            str(act_rows[0][3]),  # newest_forecast_tz_date
            act_rows[0][4],
            act_rows[0][5].utctimetuple())]  # newest_forecast_created_at
        self.assertEqual([exp_row], act_rows)
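
    # a sketch of the normalization idea used above: postgres returns DateField/DateTimeField values as
    # datetime.date/datetime.datetime while sqlite3 returns strings, so the test compares dates via str()
    # and datetimes via utctimetuple(). a hypothetical helper for one row (assuming created_at has already
    # been parsed to a datetime):
    #
    #   def _normalize_summary_row(row):
    #       model, num_forecasts, oldest_tz_date, newest_tz_date, newest_id, created_at = row
    #       return (model, num_forecasts, str(oldest_tz_date), str(newest_tz_date), newest_id,
    #               created_at.utctimetuple())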
    def test_truth_batches(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)  # loads batch: docs-ground-truth.csv

        # add a second batch
        load_truth_data(
            project,
            Path(
                'forecast_app/tests/truth_data/docs-ground-truth-non-dup.csv'),
            file_name='docs-ground-truth-non-dup.csv')
        oracle_model = oracle_model_for_project(project)
        first_forecast = oracle_model.forecasts.first()
        last_forecast = oracle_model.forecasts.last()

        # test truth_batches() and truth_batch_forecasts() for each batch
        batches = truth_batches(project)
        self.assertEqual(2, len(batches))
        self.assertEqual(first_forecast.source, batches[0][0])
        self.assertEqual(first_forecast.issued_at, batches[0][1])
        self.assertEqual(last_forecast.source, batches[1][0])
        self.assertEqual(last_forecast.issued_at, batches[1][1])

        for source, issued_at in batches:
            forecasts = truth_batch_forecasts(project, source, issued_at)
            self.assertEqual(3, len(forecasts))
            for forecast in forecasts:
                self.assertEqual(source, forecast.source)
                self.assertEqual(issued_at, forecast.issued_at)

        # test truth_batch_summary_table(). NB: utctimetuple() makes sqlite comparisons work
        exp_table = [(source, issued_at.utctimetuple(),
                      len(truth_batch_forecasts(project, source, issued_at)))
                     for source, issued_at in batches]
        act_table = [(source, issued_at.utctimetuple(), num_forecasts)
                     for source, issued_at, num_forecasts in
                     truth_batch_summary_table(project)]
        self.assertEqual(exp_table, act_table)

        # finally, test deleting a batch. try deleting the first, which should fail due to version rules.
        # NB: wrapping the failing call in transaction.atomic() opens a savepoint that is rolled back on the
        # error, which keeps the second `truth_delete_batch()` call below from failing with:
        # django.db.transaction.TransactionManagementError: An error occurred in the current transaction. You can't execute queries until the end of the 'atomic' block.
        with transaction.atomic():
            with self.assertRaisesRegex(
                    RuntimeError,
                    'you cannot delete a forecast that has any newer versions'
            ):
                truth_delete_batch(project, batches[0][0], batches[0][1])

        # delete second batch - should not fail
        truth_delete_batch(project, batches[1][0], batches[1][1])
        batches = truth_batches(project)
        self.assertEqual(1, len(batches))
        self.assertEqual(first_forecast.source, batches[0][0])
        self.assertEqual(first_forecast.issued_at, batches[0][1])
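
    # the savepoint pattern from the test above, in isolation: Django's TestCase runs each test inside a
    # transaction, and a statement that fails at the database level poisons that transaction so any later
    # query raises TransactionManagementError. wrapping only the failing call in its own atomic() block
    # opens a savepoint that is rolled back on the error (a sketch, with hypothetical names):
    #
    #   with transaction.atomic():                     # opens a savepoint
    #       with self.assertRaises(SomeError):
    #           call_that_errors_in_the_database()     # rolled back to the savepoint
    #   Model.objects.count()                          # later queries still succeed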

    def setUp(self):
        # runs before every test. done here instead of setUpTestData(cls) b/c the below tests modify the db
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        self.project, self.time_zero, self.forecast_model, self.forecast = _make_docs_project(
            po_user)
        self.tz1 = self.project.timezeros.filter(
            timezero_date=datetime.date(2011, 10, 2)).first()
        self.tz2 = self.project.timezeros.filter(
            timezero_date=datetime.date(2011, 10, 9)).first()
        self.tz3 = self.project.timezeros.filter(
            timezero_date=datetime.date(2011, 10, 16)).first()

    def test_execute_project_config_diff(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, _, _, _ = _make_docs_project(po_user)

        # make some changes
        out_config_dict = config_dict_from_project(
            project,
            APIRequestFactory().request())
        edit_config_dict = copy.deepcopy(out_config_dict)
        _make_some_changes(edit_config_dict)

        changes = project_config_diff(out_config_dict, edit_config_dict)
        execute_project_config_diff(project, changes)
        self._do_make_some_changes_tests(project)
    def test__upload_forecast_worker_atomic(self):
        # test `_upload_forecast_worker()` does not create a Forecast if subsequent calls to
        # `load_predictions_from_json_io_dict()` or `cache_forecast_metadata()` fail. this test is complicated by that
        # function's use of the `job_cloud_file` context manager. solution is per https://stackoverflow.com/questions/60198229/python-patch-context-manager-to-return-object
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)
        forecast.issued_at -= datetime.timedelta(
            days=1)  # older version avoids unique constraint errors
        forecast.save()

        with patch('forecast_app.models.job.job_cloud_file') as job_cloud_file_mock, \
                patch('utils.forecast.load_predictions_from_json_io_dict') as load_preds_mock, \
                patch('utils.forecast.cache_forecast_metadata') as cache_metadata_mock:
            forecast2 = Forecast.objects.create(forecast_model=forecast_model,
                                                time_zero=time_zero)
            job = Job.objects.create()
            job.input_json = {
                'forecast_pk': forecast2.pk,
                'filename': 'a name!'
            }
            job.save()

            job_cloud_file_mock.return_value.__enter__.return_value = (
                job, None)  # 2-tuple: (job, cloud_file_fp)

            # test that no Forecast is created when load_predictions_from_json_io_dict() fails
            load_preds_mock.side_effect = Exception(
                'load_preds_mock Exception')
            num_forecasts_before = forecast_model.forecasts.count()
            _upload_forecast_worker(job.pk)
            job.refresh_from_db()
            self.assertEqual(
                num_forecasts_before - 1,
                forecast_model.forecasts.count())  # -1 b/c forecast2 deleted
            self.assertEqual(Job.FAILED, job.status)

            # test when cache_forecast_metadata() fails
            load_preds_mock.reset_mock(side_effect=True)
            cache_metadata_mock.side_effect = Exception(
                'cache_metadata_mock Exception')
            num_forecasts_before = forecast_model.forecasts.count()
            _upload_forecast_worker(job.pk)
            job.refresh_from_db()
            self.assertEqual(num_forecasts_before,
                             forecast_model.forecasts.count())
            self.assertEqual(Job.FAILED, job.status)
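
    # taken together, the two tests above imply a worker shaped roughly like this (a hypothetical sketch,
    # not the real _upload_forecast_worker()): load inside a transaction, delete the passed-in (empty)
    # forecast on any failure, and record the outcome on the job:
    #
    #   def _upload_forecast_worker(job_pk):
    #       job = Job.objects.get(pk=job_pk)
    #       forecast = Forecast.objects.get(pk=job.input_json['forecast_pk'])
    #       try:
    #           with job_cloud_file(job.pk) as (job, cloud_file_fp), transaction.atomic():
    #               load_predictions_from_json_io_dict(forecast, json.load(cloud_file_fp))
    #               cache_forecast_metadata(forecast)
    #           job.status = Job.SUCCESS
    #       except Exception:
    #           forecast.delete()  # the presumably-empty forecast passed to the worker
    #           job.status = Job.FAILED
    #       job.save()
    #
    #   (the real worker evidently also re-raises JobTimeoutException and sets Job.TIMEOUT - see
    #   test__upload_forecast_worker_deletes_forecast below)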

    def test_order_project_config_diff(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, _, _, _ = _make_docs_project(po_user)

        out_config_dict = config_dict_from_project(
            project,
            APIRequestFactory().request())
        edit_config_dict = copy.deepcopy(out_config_dict)
        _make_some_changes(edit_config_dict)
        changes = project_config_diff(out_config_dict, edit_config_dict)
        # order_project_config_diff() removes one wasted change ('pct next week', ChangeType.FIELD_EDITED),
        # which is wasted b/c that target is itself being ChangeType.OBJ_REMOVED:
        ordered_changes = order_project_config_diff(changes)
        self.assertEqual(
            13, len(changes))  # contains two duplicate changes and one wasted change
        self.assertEqual(10, len(ordered_changes))
    def test_load_truth_data_partial_dup(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)  # loads batch: docs-ground-truth.csv

        try:
            load_truth_data(
                project,
                Path(
                    'forecast_app/tests/truth_data/docs-ground-truth-partial-dup.csv'
                ),
                file_name='docs-ground-truth-partial-dup.csv')
            batches = truth_batches(project)
            self.assertEqual(2, len(batches))
        except Exception as ex:
            self.fail(f"unexpected exception: {ex}")

    def test_diff_from_file_empty_data_version_date_string(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, _, _, _ = _make_docs_project(po_user)
        # note: using APIRequestFactory was the only way I could find to pass a request object. o/w you get:
        #   AssertionError: `HyperlinkedIdentityField` requires the request in the serializer context.
        out_config_dict = config_dict_from_project(
            project,
            APIRequestFactory().request())
        edited_config_dict = copy.deepcopy(out_config_dict)

        # change the '2011-10-02' timezero's data_version_date: None -> '' (incorrect, but we fix it for users)
        edited_config_dict['timezeros'][0]['data_version_date'] = ''

        changes = project_config_diff(out_config_dict, edited_config_dict)
        self.assertEqual(0, len(changes))  # would be 1 without the fix that treats '' the same as None
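
    # the APIRequestFactory workaround noted above, in isolation (a sketch): HyperlinkedIdentityField needs
    # a request in the serializer context to build absolute URLs, and a bare factory request is the cheapest
    # way to supply one outside a real view:
    #
    #   from rest_framework.test import APIRequestFactory
    #   request = APIRequestFactory().request()  # minimal request, no view or routing needed
    #   config = config_dict_from_project(project, request)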

    def test_execute_project_config_diff(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, _, _, _ = _make_docs_project(po_user)
        _update_scores_for_all_projects()

        # make some changes
        # note: using APIRequestFactory was the only way I could find to pass a request object. o/w you get:
        #   AssertionError: `HyperlinkedIdentityField` requires the request in the serializer context.
        out_config_dict = config_dict_from_project(
            project,
            APIRequestFactory().request())
        edit_config_dict = copy.deepcopy(out_config_dict)
        _make_some_changes(edit_config_dict)

        changes = project_config_diff(out_config_dict, edit_config_dict)
        execute_project_config_diff(project, changes)
        self._do_make_some_changes_tests(project)
    def test_load_predictions_from_json_io_dict_existing_pred_eles(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)
        json_io_dict = {
            "predictions": [{
                "unit": "location1",
                "target": "pct next week",
                "class": "point",
                "prediction": {
                    "value": 2.1
                }
            }]
        }
        with self.assertRaises(RuntimeError) as context:
            load_predictions_from_json_io_dict(forecast, json_io_dict)
        self.assertIn("cannot load data into a non-empty forecast",
                      str(context.exception))
    def test__upload_forecast_worker_deletes_forecast(self):
        # verifies that _upload_forecast_worker() deletes the (presumably empty) Forecast that's passed to it by
        # upload functions if the file is invalid. here we mock load_predictions_from_json_io_dict() to throw the two
        # exceptions that cause deletes: JobTimeoutException and Exception
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)
        forecast.issued_at -= datetime.timedelta(
            days=1)  # older version avoids unique constraint errors
        forecast.save()

        for exception, exp_job_status in [
            (Exception('load_preds_mock Exception'), Job.FAILED),
            (JobTimeoutException('load_preds_mock JobTimeoutException'),
             Job.TIMEOUT)
        ]:
            with patch('forecast_app.models.job.job_cloud_file') as job_cloud_file_mock, \
                    patch('utils.forecast.load_predictions_from_json_io_dict') as load_preds_mock, \
                    patch('utils.forecast.cache_forecast_metadata') as cache_metadata_mock, \
                    open('forecast_app/tests/predictions/docs-predictions.json') as cloud_file_fp:
                load_preds_mock.side_effect = exception
                forecast2 = Forecast.objects.create(
                    forecast_model=forecast_model, time_zero=time_zero)
                job = Job.objects.create()
                job.input_json = {
                    'forecast_pk': forecast2.pk,
                    'filename': 'a name!'
                }
                job.save()

                job_cloud_file_mock.return_value.__enter__.return_value = (
                    job, cloud_file_fp)
                try:
                    _upload_forecast_worker(job.pk)
                except JobTimeoutException:
                    pass  # expected: _upload_forecast_worker() re-raises this exception
                job.refresh_from_db()
                self.assertEqual(exp_job_status, job.status)
                self.assertIsNone(
                    Forecast.objects.filter(
                        id=forecast2.id).first())  # deleted

    def test_null_or_empty_name_or_abbreviation(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)
        for empty_name in [None, '']:
            with self.assertRaises(ValidationError) as context:
                ForecastModel.objects.create(project=project,
                                             name=empty_name,
                                             abbreviation='abbrev')
            self.assertIn('both name and abbreviation are required',
                          str(context.exception))

        for empty_abbreviation in [None, '']:
            with self.assertRaises(ValidationError) as context:
                ForecastModel.objects.create(project=project,
                                             name=forecast_model.name + '2',
                                             abbreviation=empty_abbreviation)
            self.assertIn('both name and abbreviation are required',
                          str(context.exception))

    def test_order_project_config_diff(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, _, _, _ = _make_docs_project(po_user)
        _update_scores_for_all_projects()

        # note: using APIRequestFactory was the only way I could find to pass a request object. o/w you get:
        #   AssertionError: `HyperlinkedIdentityField` requires the request in the serializer context.
        out_config_dict = config_dict_from_project(
            project,
            APIRequestFactory().request())
        edit_config_dict = copy.deepcopy(out_config_dict)
        _make_some_changes(edit_config_dict)
        changes = project_config_diff(out_config_dict, edit_config_dict)
        # order_project_config_diff() removes one wasted change ('pct next week', ChangeType.FIELD_EDITED),
        # which is wasted b/c that target is itself being ChangeType.OBJ_REMOVED:
        ordered_changes = order_project_config_diff(changes)
        self.assertEqual(
            13, len(changes))  # contains two duplicate changes and one wasted change
        self.assertEqual(10, len(ordered_changes))

    def test_database_changes_for_project_config_diff(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, _, _, _ = _make_docs_project(po_user)

        out_config_dict = config_dict_from_project(
            project,
            APIRequestFactory().request())
        edit_config_dict = copy.deepcopy(out_config_dict)
        _make_some_changes(edit_config_dict)

        changes = project_config_diff(
            out_config_dict,
            edit_config_dict)  # change, num_pred_eles, num_truth
        exp_changes = [(Change(ObjectType.UNIT, 'location3',
                               ChangeType.OBJ_REMOVED, None, None), 8, 0),
                       (Change(ObjectType.TARGET, 'pct next week',
                               ChangeType.OBJ_REMOVED, None, None), 7, 3),
                       (Change(ObjectType.TIMEZERO, '2011-10-02',
                               ChangeType.OBJ_REMOVED, None, None), 29, 5)]
        act_changes = database_changes_for_project_config_diff(
            project, changes)
        self.assertEqual(exp_changes, act_changes)

    def test_database_changes_for_project_config_diff(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, _, _, _ = _make_docs_project(po_user)
        _update_scores_for_all_projects()

        # note: using APIRequestFactory was the only way I could find to pass a request object. o/w you get:
        #   AssertionError: `HyperlinkedIdentityField` requires the request in the serializer context.
        out_config_dict = config_dict_from_project(
            project,
            APIRequestFactory().request())
        edit_config_dict = copy.deepcopy(out_config_dict)
        _make_some_changes(edit_config_dict)

        changes = project_config_diff(out_config_dict, edit_config_dict)
        self.assertEqual(  # change, num_points, num_named, num_bins, num_samples, num_truth
            [(Change(ObjectType.UNIT, 'location3', ChangeType.OBJ_REMOVED,
                     None, None), 3, 0, 2, 10, 0),
             (Change(ObjectType.TARGET, 'pct next week',
                     ChangeType.OBJ_REMOVED, None, None), 3, 1, 3, 5, 3),
             (Change(ObjectType.TIMEZERO, '2011-10-02', ChangeType.OBJ_REMOVED,
                     None, None), 11, 2, 16, 23, 5)],
            database_changes_for_project_config_diff(project, changes))
    def test_load_truth_data_diff(self):
        """
        Tests the relaxing of this forecast version rule when loading truth (issue
        [support truth "diff" uploads #319](https://github.com/reichlab/forecast-repository/issues/319) ):
            3. New forecast versions cannot imply any retracted prediction elements in existing versions, i.e., you
            cannot load data that's a subset of the previous forecast's data.
        """
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)  # loads docs-ground-truth.csv

        oracle_model = oracle_model_for_project(project)
        self.assertEqual(3, oracle_model.forecasts.count())  # one per timezero: 2011-10-02, 2011-10-09, 2011-10-16
        self.assertEqual(14, truth_data_qs(project).count())

        # updates only the five location2 rows:
        load_truth_data(
            project,
            Path('forecast_app/tests/truth_data/docs-ground-truth-diff.csv'),
            file_name='docs-ground-truth-diff.csv')
        self.assertEqual(3 + 1, oracle_model.forecasts.count())
        self.assertEqual(14 + 5, truth_data_qs(project).count())
    def test_target_rows_for_project(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        # recall that _make_docs_project() calls cache_forecast_metadata():
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)  # 2011, 10, 2

        # case: one model with one timezero that has five groups of one target each.
        # recall: `group_targets(project.targets.all())` (only one target/group in this case):
        #   {'pct next week':    [(1, 'pct next week', 'continuous', True, 1, 'percent')],
        #    'cases next week':  [(2, 'cases next week', 'discrete', True, 2, 'cases')],
        #    'season severity':  [(3, 'season severity', 'nominal', False, None, None)],
        #    'above baseline':   [(4, 'above baseline', 'binary', False, None, None)],
        #    'Season peak week': [(5, 'Season peak week', 'date', False, None, 'week')]}
        exp_rows = [(forecast_model, str(time_zero.timezero_date), forecast.id,
                     'Season peak week', 1),
                    (forecast_model, str(time_zero.timezero_date), forecast.id,
                     'above baseline', 1),
                    (forecast_model, str(time_zero.timezero_date), forecast.id,
                     'cases next week', 1),
                    (forecast_model, str(time_zero.timezero_date), forecast.id,
                     'pct next week', 1),
                    (forecast_model, str(time_zero.timezero_date), forecast.id,
                     'season severity', 1)]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4])
                    for row in target_rows_for_project(project)]
        self.assertEqual(sorted(exp_rows, key=lambda _: _[0].id),
                         sorted(act_rows, key=lambda _: _[0].id))

        # case: add a second forecast for a newer timezero
        time_zero2 = TimeZero.objects.create(project=project,
                                             timezero_date=datetime.date(
                                                 2011, 10, 3))
        forecast2 = Forecast.objects.create(forecast_model=forecast_model,
                                            source='docs-predictions.json',
                                            time_zero=time_zero2,
                                            notes="a small prediction file")
        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast2, json_io_dict_in,
                                               False)
            cache_forecast_metadata(
                forecast2
            )  # required by _forecast_ids_to_present_unit_or_target_id_sets()

        exp_rows = [(forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 'Season peak week', 1),
                    (forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 'above baseline', 1),
                    (forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 'cases next week', 1),
                    (forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 'pct next week', 1),
                    (forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 'season severity', 1)]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4])
                    for row in target_rows_for_project(project)]
        self.assertEqual(sorted(exp_rows, key=lambda _: _[0].id),
                         sorted(act_rows, key=lambda _: _[0].id))

        # case: add a second model with only forecasts for one target
        forecast_model2 = ForecastModel.objects.create(
            project=project,
            name=forecast_model.name + '2',
            abbreviation=forecast_model.abbreviation + '2')
        time_zero3 = TimeZero.objects.create(project=project,
                                             timezero_date=datetime.date(
                                                 2011, 10, 4))
        forecast3 = Forecast.objects.create(forecast_model=forecast_model2,
                                            source='docs-predictions.json',
                                            time_zero=time_zero3,
                                            notes="a small prediction file")
        json_io_dict = {
            "meta": {},
            "predictions": [{
                "unit": "location1",
                "target": "pct next week",
                "class": "point",
                "prediction": {
                    "value": 2.1
                }
            }]
        }
        load_predictions_from_json_io_dict(forecast3, json_io_dict, False)
        cache_forecast_metadata(
            forecast3
        )  # required by _forecast_ids_to_present_unit_or_target_id_sets()

        exp_rows = exp_rows + [(forecast_model2, str(
            time_zero3.timezero_date), forecast3.id, 'pct next week', 1)]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4])
                    for row in target_rows_for_project(project)]
        self.assertEqual(sorted(exp_rows, key=lambda _: _[0].id),
                         sorted(act_rows, key=lambda _: _[0].id))

        # case: no forecasts
        forecast.delete()
        forecast2.delete()
        forecast3.delete()
        exp_rows = [(forecast_model, '', '', '', 0),
                    (forecast_model2, '', '', '', 0)]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4])
                    for row in target_rows_for_project(project)]
        self.assertEqual(sorted(exp_rows, key=lambda _: _[0].id),
                         sorted(act_rows, key=lambda _: _[0].id))
    def test_unit_rows_for_project(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        # recall that _make_docs_project() calls cache_forecast_metadata():
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)  # 2011, 10, 2

        # case: one model with one timezero. recall rows:
        # (model, newest_forecast_tz_date, newest_forecast_id,
        #  num_present_unit_names, present_unit_names, missing_unit_names):
        exp_rows = [(forecast_model, str(time_zero.timezero_date), forecast.id,
                     3, '(all)', '')]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4], row[5])
                    for row in unit_rows_for_project(project)]
        self.assertEqual(exp_rows, act_rows)

        # case: add a second forecast for a newer timezero
        time_zero2 = TimeZero.objects.create(project=project,
                                             timezero_date=datetime.date(
                                                 2011, 10, 3))
        forecast2 = Forecast.objects.create(forecast_model=forecast_model,
                                            source='docs-predictions.json',
                                            time_zero=time_zero2,
                                            notes="a small prediction file")
        with open(
                'forecast_app/tests/predictions/docs-predictions.json') as fp:
            json_io_dict_in = json.load(fp)
            load_predictions_from_json_io_dict(forecast2, json_io_dict_in,
                                               False)
            cache_forecast_metadata(
                forecast2
            )  # required by _forecast_ids_to_present_unit_or_target_id_sets()

        exp_rows = [(forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 3, '(all)', '')]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4], row[5])
                    for row in unit_rows_for_project(project)]
        self.assertEqual(exp_rows, act_rows)

        # case: add a second model with only forecasts for one unit
        forecast_model2 = ForecastModel.objects.create(
            project=project,
            name=forecast_model.name + '2',
            abbreviation=forecast_model.abbreviation + '2')
        time_zero3 = TimeZero.objects.create(project=project,
                                             timezero_date=datetime.date(
                                                 2011, 10, 4))
        forecast3 = Forecast.objects.create(forecast_model=forecast_model2,
                                            source='docs-predictions.json',
                                            time_zero=time_zero3,
                                            notes="a small prediction file")
        json_io_dict = {
            "meta": {},
            "predictions": [{
                "unit": "location1",
                "target": "pct next week",
                "class": "point",
                "prediction": {
                    "value": 2.1
                }
            }]
        }
        load_predictions_from_json_io_dict(forecast3, json_io_dict, False)
        cache_forecast_metadata(
            forecast3
        )  # required by _forecast_ids_to_present_unit_or_target_id_sets()

        exp_rows = [(forecast_model, str(time_zero2.timezero_date),
                     forecast2.id, 3, '(all)', ''),
                    (forecast_model2, str(time_zero3.timezero_date),
                     forecast3.id, 1, 'location1', 'location2, location3')]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4], row[5])
                    for row in unit_rows_for_project(project)]
        self.assertEqual(exp_rows, act_rows)

        # case: exposes bug: syntax error when no forecasts in project:
        #   psycopg2.errors.SyntaxError: syntax error at or near ")"
        #   LINE 6:             WHERE f.id IN ()
        forecast.delete()
        forecast2.delete()
        forecast3.delete()
        # (model, newest_forecast_tz_date, newest_forecast_id, num_present_unit_names, present_unit_names,
        #  missing_unit_names):
        exp_rows = [(forecast_model, 'None', None, 0, '', '(all)'),
                    (forecast_model2, 'None', None, 0, '', '(all)')]
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4], row[5])
                    for row in unit_rows_for_project(project)]
        self.assertEqual(sorted(exp_rows, key=lambda _: _[0].id),
                         sorted(act_rows, key=lambda _: _[0].id))
    def test_validate_scores_query(self):
        """
        Nearly identical to test_validate_forecasts_query().
        """
        # case: query not a dict
        error_messages, _ = validate_scores_query(self.project, -1)
        self.assertEqual(1, len(error_messages))
        self.assertIn("query was not a dict", error_messages[0])

        # case: query contains invalid keys
        error_messages, _ = validate_scores_query(self.project, {'foo': -1})
        self.assertEqual(1, len(error_messages))
        self.assertIn("one or more query keys were invalid", error_messages[0])

        # case: query keys are not correct type (lists)
        for key_name in ['models', 'units', 'targets', 'timezeros']:
            error_messages, _ = validate_scores_query(self.project,
                                                      {key_name: -1})
            self.assertEqual(1, len(error_messages))
            self.assertIn(f"'{key_name}' was not a list", error_messages[0])

        # case: bad object id
        for key_name, exp_error_msg in [
            ('models', 'model with abbreviation not found'),
            ('units', 'unit with name not found'),
            ('targets', 'target with name not found'),
            ('timezeros', 'timezero with date not found')
        ]:
            error_messages, _ = validate_scores_query(self.project,
                                                      {key_name: [-1]})
            self.assertEqual(1, len(error_messages))
            self.assertIn(exp_error_msg, error_messages[0])

        # case: bad score
        error_messages, _ = validate_scores_query(self.project,
                                                  {'scores': ['bad score']})
        self.assertEqual(1, len(error_messages))
        self.assertIn("one or more scores were invalid abbreviations",
                      error_messages[0])

        # case: object references from other project (!)
        project2, time_zero2, forecast_model2, forecast2 = _make_docs_project(
            self.po_user)
        for query_dict, exp_error_msg in [
            ({
                'models': [project2.models.first().abbreviation]
            }, 'model with abbreviation not found'),
            ({
                'units': [project2.units.first().name]
            }, 'unit with name not found'),
            ({
                'targets': [project2.targets.first().name]
            }, 'target with name not found'),
            ({
                'timezeros': [
                    project2.timezeros.first().timezero_date.strftime(
                        YYYY_MM_DD_DATE_FORMAT)
                ]
            }, 'timezero with date not found')
        ]:
            error_messages, _ = validate_scores_query(self.project, query_dict)
            self.assertEqual(1, len(error_messages))
            self.assertIn(exp_error_msg, error_messages[0])

        # case: blue sky
        query = {
            'models':
            list(self.project.models.all().values_list('id', flat=True)),
            'units':
            list(self.project.units.all().values_list('id', flat=True)),
            'targets':
            list(self.project.targets.all().values_list('id', flat=True)),
            'timezeros':
            list(self.project.timezeros.all().values_list('id', flat=True)),
            'scores':
            list(SCORE_ABBREV_TO_NAME_AND_DESCR.keys())
        }
        error_messages, _ = validate_scores_query(self.project, query)
        self.assertEqual(0, len(error_messages))
    def test_validate_forecasts_query(self):
        # case: query not a dict
        error_messages, _ = validate_forecasts_query(self.project, -1)
        self.assertEqual(1, len(error_messages))
        self.assertIn("query was not a dict", error_messages[0])

        # case: query contains invalid keys
        error_messages, _ = validate_forecasts_query(self.project, {'foo': -1})
        self.assertEqual(1, len(error_messages))
        self.assertIn("one or more query keys were invalid", error_messages[0])

        # case: query keys are not correct type (lists)
        for key_name in ['models', 'units', 'targets', 'timezeros']:
            error_messages, _ = validate_forecasts_query(
                self.project, {key_name: -1})
            self.assertEqual(1, len(error_messages))
            self.assertIn(f"'{key_name}' was not a list", error_messages[0])

        # case: as_of is not a string, or is not a date in YYYY_MM_DD_DATE_FORMAT
        error_messages, _ = validate_forecasts_query(self.project,
                                                     {'as_of': -1})
        self.assertEqual(1, len(error_messages))
        self.assertIn("'as_of' was not a string", error_messages[0])

        error_messages, _ = validate_forecasts_query(self.project,
                                                     {'as_of': '20201011'})
        self.assertEqual(1, len(error_messages))
        self.assertIn("'as_of' was not in YYYY-MM-DD format",
                      error_messages[0])

        try:
            validate_forecasts_query(self.project, {'as_of': '2020-10-11'})
        except Exception as ex:
            self.fail(f"unexpected exception: {ex}")

        # case: bad object reference
        for key_name, exp_error_msg in [
            ('models', 'model with abbreviation not found'),
            ('units', 'unit with name not found'),
            ('targets', 'target with name not found'),
            ('timezeros', 'timezero with date not found')
        ]:
            error_messages, _ = validate_forecasts_query(
                self.project, {key_name: [-1]})
            self.assertEqual(1, len(error_messages))
            self.assertIn(exp_error_msg, error_messages[0])

        # case: bad type
        error_messages, _ = validate_forecasts_query(self.project,
                                                     {'types': ['bad type']})
        self.assertEqual(1, len(error_messages))
        self.assertIn("one or more types were invalid prediction types",
                      error_messages[0])

        # case: object references from other project (!)
        project2, time_zero2, forecast_model2, forecast2 = _make_docs_project(
            self.po_user)
        for query_dict, exp_error_msg in [
            ({
                'models': [project2.models.first().abbreviation]
            }, 'model with abbreviation not found'),
            ({
                'units': [project2.units.first().name]
            }, 'unit with name not found'),
            ({
                'targets': [project2.targets.first().name]
            }, 'target with name not found'),
            ({
                'timezeros': [
                    project2.timezeros.first().timezero_date.strftime(
                        YYYY_MM_DD_DATE_FORMAT)
                ]
            }, 'timezero with date not found')
        ]:
            error_messages, _ = validate_forecasts_query(
                self.project, query_dict)
            self.assertEqual(1, len(error_messages))
            self.assertIn(exp_error_msg, error_messages[0])

        # case: blue sky
        query = {
            'models':
            list(self.project.models.all().values_list('id', flat=True)),
            'units':
            list(self.project.units.all().values_list('id', flat=True)),
            'targets':
            list(self.project.targets.all().values_list('id', flat=True)),
            'timezeros':
            list(self.project.timezeros.all().values_list('id', flat=True)),
            'types':
            list(PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS.values())
        }
        error_messages, _ = validate_forecasts_query(self.project, query)
        self.assertEqual(0, len(error_messages))
    @classmethod
    def setUpTestData(cls):
        _, _, cls.po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        cls.project, cls.time_zero, cls.forecast_model, cls.forecast = _make_docs_project(
            cls.po_user)

    def test_data_rows_from_forecast(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, time_zero, forecast_model, forecast = _make_docs_project(
            po_user)
        unit_loc1 = project.units.filter(name='location1').first()
        unit_loc2 = project.units.filter(name='location2').first()
        unit_loc3 = project.units.filter(name='location3').first()
        target_pct_next_week = project.targets.filter(
            name='pct next week').first()
        target_cases_next_week = project.targets.filter(
            name='cases next week').first()
        target_season_severity = project.targets.filter(
            name='season severity').first()
        target_above_baseline = project.targets.filter(
            name='above baseline').first()
        target_season_peak_week = project.targets.filter(
            name='Season peak week').first()

        # rows: 5-tuple: (data_rows_bin, data_rows_named, data_rows_point, data_rows_quantile, data_rows_sample)
        loc_targ_to_exp_rows = {
            (unit_loc1, target_pct_next_week): (
                [],
                [('location1', 'pct next week', 'norm', 1.1, 2.2, None)
                 ],  # named
                [('location1', 'pct next week', 2.1)],  # point
                [],
                []),
            (unit_loc1, target_cases_next_week): (
                [],
                [('location1', 'cases next week', 'pois', 1.1, None, None)
                 ],  # named
                [],
                [],
                []),
            (unit_loc1, target_season_severity): (
                [
                    ('location1', 'season severity', 'mild', 0.0),  # bin
                    ('location1', 'season severity', 'moderate', 0.1),
                    ('location1', 'season severity', 'severe', 0.9)
                ],
                [],
                [('location1', 'season severity', 'mild')],  # point
                [],
                []),
            (unit_loc1, target_above_baseline): (
                [],
                [],
                [('location1', 'above baseline', True)],  # point
                [],
                []),
            (unit_loc1, target_season_peak_week): (
                [
                    ('location1', 'Season peak week', '2019-12-15',
                     0.01),  # bin
                    ('location1', 'Season peak week', '2019-12-22', 0.1),
                    ('location1', 'Season peak week', '2019-12-29', 0.89)
                ],
                [],
                [('location1', 'Season peak week', '2019-12-22')],  # point
                [],
                [
                    ('location1', 'Season peak week', '2020-01-05'),  # sample
                    ('location1', 'Season peak week', '2019-12-15')
                ]),
            (unit_loc2, target_pct_next_week): (
                [
                    ('location2', 'pct next week', 1.1, 0.3),  # bin
                    ('location2', 'pct next week', 2.2, 0.2),
                    ('location2', 'pct next week', 3.3, 0.5)
                ],
                [],
                [('location2', 'pct next week', 2.0)],  # point
                [
                    ('location2', 'pct next week', 0.025, 1.0),  # quantile
                    ('location2', 'pct next week', 0.25, 2.2),
                    ('location2', 'pct next week', 0.5, 2.2),
                    ('location2', 'pct next week', 0.75, 5.0),
                    ('location2', 'pct next week', 0.975, 50.0)
                ],
                []),
            (unit_loc2, target_cases_next_week): (
                [],
                [],
                [('location2', 'cases next week', 5)],  # point
                [],
                [
                    ('location2', 'cases next week', 0),  # sample
                    ('location2', 'cases next week', 2),
                    ('location2', 'cases next week', 5)
                ]),
            (unit_loc2, target_season_severity): (
                [],
                [],
                [('location2', 'season severity', 'moderate')],  # point
                [],
                [
                    ('location2', 'season severity', 'moderate'),  # sample
                    ('location2', 'season severity', 'severe'),
                    ('location2', 'season severity', 'high'),
                    ('location2', 'season severity', 'moderate'),
                    ('location2', 'season severity', 'mild')
                ]),
            (unit_loc2, target_above_baseline): (
                [('location2', 'above baseline', True, 0.9),
                 ('location2', 'above baseline', False, 0.1)],  # bin
                [],
                [],
                [],
                [
                    ('location2', 'above baseline', True),  # sample
                    ('location2', 'above baseline', False),
                    ('location2', 'above baseline', True)
                ]),
            (unit_loc2, target_season_peak_week): (
                [
                    ('location2', 'Season peak week', '2019-12-15',
                     0.01),  # bin
                    ('location2', 'Season peak week', '2019-12-22', 0.05),
                    ('location2', 'Season peak week', '2019-12-29', 0.05),
                    ('location2', 'Season peak week', '2020-01-05', 0.89)
                ],
                [],
                [('location2', 'Season peak week', '2020-01-05')],  # point
                [
                    ('location2', 'Season peak week', 0.5,
                     '2019-12-22'),  # quantile
                    ('location2', 'Season peak week', 0.75, '2019-12-29'),
                    ('location2', 'Season peak week', 0.975, '2020-01-05')
                ],
                []),
            (unit_loc3, target_pct_next_week): (
                [],
                [],
                [('location3', 'pct next week', 3.567)],  # point
                [],
                [
                    ('location3', 'pct next week', 2.3),  # sample
                    ('location3', 'pct next week', 6.5),
                    ('location3', 'pct next week', 0.0),
                    ('location3', 'pct next week', 10.0234),
                    ('location3', 'pct next week', 0.0001)
                ]),
            (unit_loc3, target_cases_next_week): (
                [
                    ('location3', 'cases next week', 0, 0.0),  # bin
                    ('location3', 'cases next week', 2, 0.1),
                    ('location3', 'cases next week', 50, 0.9)
                ],
                [],
                [('location3', 'cases next week', 10)],  # point
                [
                    ('location3', 'cases next week', 0.25, 0),  # quantile
                    ('location3', 'cases next week', 0.75, 50)
                ],
                []),
            (unit_loc3, target_season_severity): ([], [], [], [], []),
            (unit_loc3, target_above_baseline): (
                [],
                [],
                [],
                [],
                [
                    ('location3', 'above baseline', False),  # sample
                    ('location3', 'above baseline', True),
                    ('location3', 'above baseline', True)
                ]),
            (unit_loc3, target_season_peak_week): (
                [],
                [],
                [('location3', 'Season peak week', '2019-12-29')],  # point
                [],
                [
                    ('location3', 'Season peak week', '2020-01-06'),  # sample
                    ('location3', 'Season peak week', '2019-12-16')
                ]),
        }
        for (unit, target), exp_rows in loc_targ_to_exp_rows.items():
            act_rows = data_rows_from_forecast(forecast, unit, target)
            self.assertEqual(exp_rows, act_rows)

    def test_serialize_change_list(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, _, _, _ = _make_docs_project(po_user)

        # make some changes
        out_config_dict = config_dict_from_project(
            project,
            APIRequestFactory().request())
        edit_config_dict = copy.deepcopy(out_config_dict)
        _make_some_changes(edit_config_dict)

        # test round-trip for one Change
        changes = sorted(
            project_config_diff(out_config_dict, edit_config_dict),
            key=lambda change: (change.object_type, change.object_pk,
                                change.change_type))
        exp_dict = {
            'object_type': ObjectType.PROJECT,
            'object_pk': None,
            'change_type': ChangeType.FIELD_EDITED,
            'field_name': 'name',
            'object_dict': edit_config_dict
        }
        act_dict = changes[0].serialize_to_dict()
        self.assertEqual(exp_dict, act_dict)
        self.assertEqual(changes[0], Change.deserialize_dict(exp_dict))

        # test serialize_to_dict() for all changes
        exp_dicts = [{
            'object_type': ObjectType.PROJECT,
            'object_pk': None,
            'change_type': ChangeType.FIELD_EDITED,
            'field_name': 'name',
            'object_dict': edit_config_dict
        }, {
            'object_type': ObjectType.UNIT,
            'object_pk': 'location3',
            'change_type': ChangeType.OBJ_REMOVED,
            'field_name': None,
            'object_dict': None
        }, {
            'object_type': ObjectType.UNIT,
            'object_pk': 'location4',
            'change_type': ChangeType.OBJ_ADDED,
            'field_name': None,
            'object_dict': {
                'name': 'location4'
            }
        }, {
            'object_type': ObjectType.TARGET,
            'object_pk': 'cases next week',
            'change_type': ChangeType.FIELD_EDITED,
            'field_name': 'is_step_ahead',
            'object_dict': {
                'name': 'cases next week',
                'type': 'discrete',
                'description':
                'A forecasted integer number of cases for a future week.',
                'is_step_ahead': False,
                'unit': 'cases',
                'range': [0, 100000],
                'cats': [0, 2, 50]
            }
        }, {
            'object_type': ObjectType.TARGET,
            'object_pk': 'cases next week',
            'change_type': ChangeType.FIELD_REMOVED,
            'field_name': 'step_ahead_increment',
            'object_dict': None
        }, {
            'object_type': ObjectType.TARGET,
            'object_pk': 'pct next week',
            'change_type': ChangeType.OBJ_ADDED,
            'field_name': None,
            'object_dict': {
                'name': 'pct next week',
                'type': 'discrete',
                'description': 'new descr',
                'is_step_ahead': True,
                'step_ahead_increment': 1,
                'unit': 'percent',
                'range': [0, 100],
                'cats': [0, 1, 1, 2, 2, 3, 3, 5, 10, 50]
            }
        }, {
            'object_type': ObjectType.TARGET,
            'object_pk': 'pct next week',
            'change_type': ChangeType.OBJ_ADDED,
            'field_name': None,
            'object_dict': {
                'type': 'discrete',
                'name': 'pct next week',
                'description': 'new descr',
                'is_step_ahead': True,
                'step_ahead_increment': 1,
                'unit': 'percent',
                'range': [0, 100],
                'cats': [0, 1, 1, 2, 2, 3, 3, 5, 10, 50]
            }
        }, {
            'object_type': ObjectType.TARGET,
            'object_pk': 'pct next week',
            'change_type': ChangeType.OBJ_REMOVED,
            'field_name': None,
            'object_dict': None
        }, {
            'object_type': ObjectType.TARGET,
            'object_pk': 'pct next week',
            'change_type': ChangeType.OBJ_REMOVED,
            'field_name': None,
            'object_dict': None
        }, {
            'object_type': ObjectType.TARGET,
            'object_pk': 'pct next week',
            'change_type': ChangeType.FIELD_EDITED,
            'field_name': 'description',
            'object_dict': {
                'name': 'pct next week',
                'type': 'discrete',
                'description': 'new descr',
                'is_step_ahead': True,
                'step_ahead_increment': 1,
                'unit': 'percent',
                'range': [0, 100],
                'cats': [0, 1, 1, 2, 2, 3, 3, 5, 10, 50]
            }
        }, {
            'object_type': ObjectType.TIMEZERO,
            'object_pk': '2011-10-02',
            'change_type': ChangeType.OBJ_REMOVED,
            'field_name': None,
            'object_dict': None
        }, {
            'object_type': ObjectType.TIMEZERO,
            'object_pk': '2011-10-09',
            'change_type': ChangeType.FIELD_EDITED,
            'field_name': 'data_version_date',
            'object_dict': {
                'timezero_date': '2011-10-09',
                'data_version_date': '2011-10-19',
                'is_season_start': False
            }
        }, {
            'object_type': ObjectType.TIMEZERO,
            'object_pk': '2011-10-22',
            'change_type': ChangeType.OBJ_ADDED,
            'field_name': None,
            'object_dict': {
                'timezero_date': '2011-10-22',
                'data_version_date': None,
                'is_season_start': True,
                'season_name': '2011-2012'
            }
        }]
        act_dicts = [change.serialize_to_dict() for change in changes]
        # remove the 'id' and 'url' fields that TargetSerializer adds, to ease
        # testing. NB: the same object_dict can be shared by more than one
        # Change, so the keys may already have been deleted in an earlier
        # iteration
        for act_dict in act_dicts:
            object_dict = act_dict['object_dict']
            if object_dict and ('id' in object_dict):
                del object_dict['id']
                del object_dict['url']
        self.assertEqual(exp_dicts, act_dicts)

        # test round-trip for all changes
        for change in changes:
            serialized_change_dict = change.serialize_to_dict()
            deserialized_change = Change.deserialize_dict(
                serialized_change_dict)
            self.assertEqual(change, deserialized_change)
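
    # The following test exercises project_config_diff() directly. A pattern
    # visible in the cases below: editable fields (e.g. 'description',
    # 'is_step_ahead') surface as FIELD_EDITED/FIELD_ADDED/FIELD_REMOVED
    # changes, while editing non-editable fields (e.g. 'type', 'range',
    # 'cats') replaces the whole target via OBJ_REMOVED followed by OBJ_ADDED.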
    def test_project_config_diff(self):
        _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(
            is_create_super=True)
        project, _, _, _ = _make_docs_project(po_user)
        # first we remove 'id' and 'url' fields from serializers to ease testing
        current_config_dict = config_dict_from_project(
            project,
            APIRequestFactory().request())
        for the_dict_list in [
                current_config_dict['units'], current_config_dict['targets'],
                current_config_dict['timezeros']
        ]:
            for the_dict in the_dict_list:
                if 'id' in the_dict:
                    del the_dict['id']
                    del the_dict['url']

        # project fields: edit
        fields_new_values = [('name', 'new name'), ('is_public', False),
                             ('description', 'new descr'),
                             ('home_url', 'new home_url'),
                             ('logo_url', 'new logo_url'),
                             ('core_data', 'new core_data'),
                             ('time_interval_type', 'Biweek'),
                             ('visualization_y_label',
                              'new visualization_y_label')]
        edit_config_dict = copy.deepcopy(current_config_dict)
        for field_name, new_value in fields_new_values:
            edit_config_dict[field_name] = new_value
        exp_changes = [
            Change(ObjectType.PROJECT, None, ChangeType.FIELD_EDITED,
                   field_name, edit_config_dict)
            for field_name, new_value in fields_new_values
        ]
        act_changes = project_config_diff(current_config_dict,
                                          edit_config_dict)
        def change_sort_key(change):
            # the diff's change order is not significant for these tests, so
            # both lists are sorted with the same key before comparing
            return change.object_type, change.object_pk, change.change_type

        self.assertEqual(sorted(exp_changes, key=change_sort_key),
                         sorted(act_changes, key=change_sort_key))

        # project units: remove 'location3', add 'location4'
        edit_config_dict = copy.deepcopy(current_config_dict)
        location_3_dict = [
            unit_dict for unit_dict in edit_config_dict['units']
            if unit_dict['name'] == 'location3'
        ][0]
        location_3_dict['name'] = 'location4'  # was 'location3'
        exp_changes = [
            Change(ObjectType.UNIT, 'location3', ChangeType.OBJ_REMOVED, None,
                   None),
            Change(ObjectType.UNIT, 'location4', ChangeType.OBJ_ADDED, None,
                   location_3_dict)
        ]
        act_changes = project_config_diff(current_config_dict,
                                          edit_config_dict)
        self.assertEqual(sorted(exp_changes, key=change_sort_key),
                         sorted(act_changes, key=change_sort_key))
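
        # NB: renaming a unit is not reported as an edit: because the diff
        # keys units by 'name', the old name surfaces as OBJ_REMOVED and the
        # new name as OBJ_ADDED.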

        # project timezeros: remove '2011-10-02', add '2011-10-22', edit '2011-10-09' fields
        edit_config_dict = copy.deepcopy(current_config_dict)

        tz_2011_10_02_dict = [
            timezero_dict for timezero_dict in edit_config_dict['timezeros']
            if timezero_dict['timezero_date'] == '2011-10-02'
        ][0]
        tz_2011_10_02_dict['timezero_date'] = '2011-10-22'  # was '2011-10-02'

        tz_2011_10_09_dict = [
            timezero_dict for timezero_dict in edit_config_dict['timezeros']
            if timezero_dict['timezero_date'] == '2011-10-09'
        ][0]
        tz_2011_10_09_dict['data_version_date'] = '2011-10-19'  # was '2011-10-09'
        tz_2011_10_09_dict['is_season_start'] = True  # was False
        tz_2011_10_09_dict['season_name'] = 'season name'  # was null
        exp_changes = [
            Change(ObjectType.TIMEZERO, '2011-10-02', ChangeType.OBJ_REMOVED,
                   None, None),
            Change(ObjectType.TIMEZERO, '2011-10-22', ChangeType.OBJ_ADDED,
                   None, tz_2011_10_02_dict),
            Change(ObjectType.TIMEZERO, '2011-10-09', ChangeType.FIELD_EDITED,
                   'data_version_date', tz_2011_10_09_dict),
            Change(ObjectType.TIMEZERO, '2011-10-09', ChangeType.FIELD_EDITED,
                   'is_season_start', tz_2011_10_09_dict),
            Change(ObjectType.TIMEZERO, '2011-10-09', ChangeType.FIELD_ADDED,
                   'season_name', tz_2011_10_09_dict)
        ]
        act_changes = project_config_diff(current_config_dict,
                                          edit_config_dict)
        self.assertEqual(sorted(exp_changes, key=change_sort_key),
                         sorted(act_changes, key=change_sort_key))

        # project targets: remove 'pct next week', add 'pct next week 2', edit 'cases next week' and 'Season peak week'
        # fields
        edit_config_dict = copy.deepcopy(current_config_dict)
        pct_next_week_target_dict = [
            target_dict for target_dict in edit_config_dict['targets']
            if target_dict['name'] == 'pct next week'
        ][0]
        # was 'pct next week'
        pct_next_week_target_dict['name'] = 'pct next week 2'

        cases_next_week_target_dict = [
            target_dict for target_dict in edit_config_dict['targets']
            if target_dict['name'] == 'cases next week'
        ][0]
        cases_next_week_target_dict['description'] = 'new descr'
        cases_next_week_target_dict['is_step_ahead'] = False  # was True
        del cases_next_week_target_dict['step_ahead_increment']

        season_peak_week_target_dict = [
            target_dict for target_dict in edit_config_dict['targets']
            if target_dict['name'] == 'Season peak week'
        ][0]
        season_peak_week_target_dict['description'] = 'new descr 2'
        season_peak_week_target_dict['is_step_ahead'] = True  # was False
        season_peak_week_target_dict['step_ahead_increment'] = 2  # was absent
        season_peak_week_target_dict['unit'] = 'biweek'

        exp_changes = [
            Change(ObjectType.TARGET, 'pct next week', ChangeType.OBJ_REMOVED,
                   None, None),
            Change(ObjectType.TARGET, 'pct next week 2', ChangeType.OBJ_ADDED,
                   None, pct_next_week_target_dict),
            Change(ObjectType.TARGET, 'cases next week',
                   ChangeType.FIELD_REMOVED, 'step_ahead_increment', None),
            Change(ObjectType.TARGET, 'cases next week',
                   ChangeType.FIELD_EDITED, 'description',
                   cases_next_week_target_dict),
            Change(ObjectType.TARGET, 'cases next week',
                   ChangeType.FIELD_EDITED, 'is_step_ahead',
                   cases_next_week_target_dict),
            Change(ObjectType.TARGET, 'Season peak week',
                   ChangeType.FIELD_ADDED, 'step_ahead_increment',
                   season_peak_week_target_dict),
            Change(ObjectType.TARGET, 'Season peak week',
                   ChangeType.FIELD_EDITED, 'description',
                   season_peak_week_target_dict),
            Change(ObjectType.TARGET, 'Season peak week',
                   ChangeType.FIELD_EDITED, 'is_step_ahead',
                   season_peak_week_target_dict),
            Change(ObjectType.TARGET, 'Season peak week',
                   ChangeType.FIELD_EDITED, 'unit',
                   season_peak_week_target_dict)
        ]
        act_changes = project_config_diff(current_config_dict,
                                          edit_config_dict)
        self.assertEqual(sorted(exp_changes, key=change_sort_key),
                         sorted(act_changes, key=change_sort_key))

        # project targets: edit 'pct next week' 'type' (non-editable) and 'description' (editable) fields
        edit_config_dict = copy.deepcopy(current_config_dict)
        pct_next_week_target_dict = [
            target_dict for target_dict in edit_config_dict['targets']
            if target_dict['name'] == 'pct next week'
        ][0]
        pct_next_week_target_dict['type'] = 'discrete'  # was 'continuous'
        pct_next_week_target_dict['description'] = 'new descr'
        exp_changes = [
            Change(ObjectType.TARGET, 'pct next week', ChangeType.OBJ_REMOVED,
                   None, None),
            Change(ObjectType.TARGET, 'pct next week', ChangeType.OBJ_ADDED,
                   None, pct_next_week_target_dict),
            Change(ObjectType.TARGET, 'pct next week', ChangeType.FIELD_EDITED,
                   'description', pct_next_week_target_dict)
        ]
        act_changes = project_config_diff(current_config_dict,
                                          edit_config_dict)
        self.assertEqual(sorted(exp_changes, key=change_sort_key),
                         sorted(act_changes, key=change_sort_key))
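
        # NB: even though editing the non-editable 'type' forces an
        # OBJ_REMOVED/OBJ_ADDED pair, the editable 'description' edit is still
        # reported as its own FIELD_EDITED change.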

        # project targets: edit 'cases next week': remove 'range' (non-editable)
        edit_config_dict = copy.deepcopy(current_config_dict)
        cases_next_week_target_dict = [
            target_dict for target_dict in edit_config_dict['targets']
            if target_dict['name'] == 'cases next week'
        ][0]
        del cases_next_week_target_dict['range']  # 'range' is non-editable

        exp_changes = [
            Change(ObjectType.TARGET, 'cases next week',
                   ChangeType.OBJ_REMOVED, None, None),
            Change(ObjectType.TARGET, 'cases next week', ChangeType.OBJ_ADDED,
                   None, cases_next_week_target_dict)
        ]
        act_changes = project_config_diff(current_config_dict,
                                          edit_config_dict)
        self.assertEqual(sorted(exp_changes, key=change_sort_key),
                         sorted(act_changes, key=change_sort_key))
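
        # contrast with 'step_ahead_increment' earlier in this test: removing
        # an editable field yields FIELD_REMOVED, while removing the
        # non-editable 'range' replaces the whole target.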

        # project targets: edit 'season severity': edit 'cats' (non-editable)
        edit_config_dict = copy.deepcopy(current_config_dict)
        season_severity_target_dict = [
            target_dict for target_dict in edit_config_dict['targets']
            if target_dict['name'] == 'season severity'
        ][0]
        season_severity_target_dict['cats'] = (
            season_severity_target_dict['cats'] + ['cat 2'])
        exp_changes = [
            Change(ObjectType.TARGET, 'season severity',
                   ChangeType.OBJ_REMOVED, None, None),
            Change(ObjectType.TARGET, 'season severity', ChangeType.OBJ_ADDED,
                   None, season_severity_target_dict)
        ]
        act_changes = project_config_diff(current_config_dict,
                                          edit_config_dict)
        self.assertEqual(sorted(exp_changes, key=change_sort_key),
                         sorted(act_changes, key=change_sort_key))