def test_diff_from_file(self):
    """
    Applies the config diff between a fresh docs project and `docs-project-edited.json` (a file that encodes the same
    edits as `_make_some_changes()`), then runs the same checks as `test_execute_project_config_diff()` via
    `_do_make_some_changes_tests()`.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)
    out_config_dict = config_dict_from_project(project, APIRequestFactory().request())

    # this json file makes the same changes as _make_some_changes():
    with open(Path('forecast_app/tests/project_diff/docs-project-edited.json')) as fp:
        edited_config_dict = json.load(fp)
    changes = project_config_diff(out_config_dict, edited_config_dict)

    # same tests as test_execute_project_config_diff():
    execute_project_config_diff(project, changes)
    self._do_make_some_changes_tests(project)
def test_duplicate_abbreviation(self):
    """
    ForecastModel names may repeat within a project, but abbreviations may not. Re-saving an existing model after
    changing its name or abbreviation must not be rejected.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, model, _ = _make_docs_project(po_user)

    # case: new name, duplicate abbreviation -> rejected
    with self.assertRaises(ValidationError) as context:
        ForecastModel.objects.create(project=project, name=model.name + '2', abbreviation=model.abbreviation)
    self.assertIn('abbreviation must be unique', str(context.exception))

    # case: duplicate name, new abbreviation -> allowed
    try:
        same_name_model = ForecastModel.objects.create(project=project, name=model.name,
                                                       abbreviation=model.abbreviation + '2')
        same_name_model.delete()
    except Exception as exc:
        self.fail(f"unexpected exception: {exc}")

    # case: saving the existing model itself after editing its name, then its abbreviation -> allowed
    try:
        model.name = model.name + '2'  # new name, duplicate abbreviation
        model.save()
        model.abbreviation = model.abbreviation + '2'  # duplicate name, new abbreviation
        model.save()
    except Exception as exc:
        self.fail(f"unexpected exception: {exc}")
def test_last_update(self):
    """
    Checks that `project.last_update()` equals the newest forecast's `created_at`, both for the docs project's single
    forecast and after adding a forecast for a newer timezero.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, forecast_model, forecast = _make_docs_project(po_user)

    # one truth and one forecast (yes truth, yes forecasts)
    self.assertEqual(forecast.created_at, project.last_update())

    # add a second forecast for a newer timezero (yes truth, yes forecasts)
    newer_tz = TimeZero.objects.create(project=project, timezero_date=datetime.date(2011, 10, 3))
    newer_forecast = Forecast.objects.create(forecast_model=forecast_model, source='docs-predictions-non-dup.json',
                                             time_zero=newer_tz, notes="a small prediction file")
    with open('forecast_app/tests/predictions/docs-predictions-non-dup.json') as preds_fp:
        predictions_dict = json.load(preds_fp)
    load_predictions_from_json_io_dict(newer_forecast, predictions_dict, is_validate_cats=False)
    self.assertEqual(newer_forecast.created_at, project.last_update())
def test_load_truth_data_versions(self):
    """
    Reloading the exact same truth file is rejected. Loading a non-duplicate truth file doubles both the number of
    oracle forecasts and the number of truth rows.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)  # loads docs-ground-truth.csv
    oracle_model = oracle_model_for_project(project)

    num_oracle_forecasts = 3  # for 3 timezeros: 2011-10-02, 2011-10-09, 2011-10-16
    num_truth_rows = 14
    self.assertEqual(num_oracle_forecasts, oracle_model.forecasts.count())
    self.assertEqual(num_truth_rows, truth_data_qs(project).count())
    self.assertTrue(is_truth_data_loaded(project))

    truth_dir = Path('forecast_app/tests/truth_data')
    with self.assertRaisesRegex(RuntimeError, 'cannot load 100% duplicate data'):
        load_truth_data(project, truth_dir / 'docs-ground-truth.csv', file_name='docs-ground-truth.csv')

    load_truth_data(project, truth_dir / 'docs-ground-truth-non-dup.csv', file_name='docs-ground-truth-non-dup.csv')
    self.assertEqual(num_oracle_forecasts * 2, oracle_model.forecasts.count())
    self.assertEqual(num_truth_rows * 2, truth_data_qs(project).count())
def test__upload_forecast_worker_blue_sky(self):
    """
    Blue sky: verifies that `_upload_forecast_worker()` calls `load_predictions_from_json_io_dict()` and
    `cache_forecast_metadata()`, marks the job SUCCESS, and sets `job.output_json['forecast_pk']`.

    The worker uses the `job_cloud_file` context manager, so that manager is patched to hand back our job and an open
    predictions file, per https://stackoverflow.com/questions/60198229/python-patch-context-manager-to-return-object
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    _, _, _, forecast = _make_docs_project(po_user)
    forecast.issued_at -= datetime.timedelta(days=1)  # older version avoids unique constraint errors
    forecast.save()

    with patch('forecast_app.models.job.job_cloud_file') as cloud_file_cm_mock, \
            patch('utils.forecast.load_predictions_from_json_io_dict') as load_preds_mock, \
            patch('utils.forecast.cache_forecast_metadata') as cache_metadata_mock, \
            open('forecast_app/tests/predictions/docs-predictions.json') as cloud_file_fp:
        job = Job.objects.create()
        job.input_json = {'forecast_pk': forecast.pk, 'filename': 'a name!'}
        job.save()
        # the patched context manager yields a 2-tuple: (job, cloud_file_fp)
        cloud_file_cm_mock.return_value.__enter__.return_value = (job, cloud_file_fp)

        _upload_forecast_worker(job.pk)
        job.refresh_from_db()

        load_preds_mock.assert_called_once()
        cache_metadata_mock.assert_called_once()
        self.assertEqual(Job.SUCCESS, job.status)
        self.assertEqual(job.input_json['forecast_pk'], job.output_json['forecast_pk'])
def setUpTestData(cls):
    # _make_docs_project() calls cache_forecast_metadata(), but the tests below assume the forecast has no cached
    # metadata, so it is cleared here
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    docs_project_parts = _make_docs_project(po_user)
    cls.project, cls.time_zero, cls.forecast_model, cls.forecast = docs_project_parts
    clear_forecast_metadata(cls.forecast)
    # move the forecast back one day, making it an older version; avoids unique constraint errors
    cls.forecast.issued_at = cls.forecast.issued_at - datetime.timedelta(days=1)
    cls.forecast.save()
def test_models_summary_table_rows_for_project(self):
    """
    Checks `models_summary_table_rows_for_project()` for the docs project's model, first with its single forecast
    (oldest and newest forecast are the same) and then after adding a second forecast for a newer timezero.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, time_zero, forecast_model, forecast = _make_docs_project(po_user)

    def normalized_act_rows():
        # Returns the first row of models_summary_table_rows_for_project() as a 6-tuple:
        #   (forecast_model, num_forecasts, oldest_forecast_tz_date, newest_forecast_tz_date, newest_forecast_id,
        #    newest_forecast_created_at)
        # NB: we have to work around a Django bug where DateField and DateTimeField come out of the database as either
        # datetime.date/datetime.datetime objects (postgres) or strings (sqlite3), so the two dates are str()'d and
        # created_at is compared via utctimetuple().
        # NOTE(review): only the first row is compared; confirm the function returns exactly one row for this project.
        act_row = models_summary_table_rows_for_project(project)[0]
        return [(act_row[0], act_row[1], str(act_row[2]), str(act_row[3]), act_row[4], act_row[5].utctimetuple())]

    # case: just one forecast, so the oldest and newest forecast are the same
    exp_row = (forecast_model,
               forecast_model.forecasts.count(),
               str(time_zero.timezero_date),  # oldest_forecast_tz_date
               str(time_zero.timezero_date),  # newest_forecast_tz_date
               forecast.id,  # newest_forecast_id
               forecast.created_at.utctimetuple())  # newest_forecast_created_at
    self.assertEqual([exp_row], normalized_act_rows())

    # case: a second forecast for a newer timezero becomes the newest one
    time_zero2 = TimeZero.objects.create(project=project, timezero_date=datetime.date(2017, 1, 1))
    forecast2 = Forecast.objects.create(forecast_model=forecast_model, source='docs-predictions.json',
                                        time_zero=time_zero2)
    exp_row = (forecast_model,
               forecast_model.forecasts.count(),
               str(time_zero.timezero_date),  # oldest_forecast_tz_date
               str(time_zero2.timezero_date),  # newest_forecast_tz_date
               forecast2.id,  # newest_forecast_id
               forecast2.created_at.utctimetuple())  # newest_forecast_created_at
    self.assertEqual([exp_row], normalized_act_rows())
def test_truth_batches(self):
    """
    Exercises `truth_batches()`, `truth_batch_forecasts()`, `truth_batch_summary_table()`, and `truth_delete_batch()`
    on a docs project after loading a second truth batch.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)  # loads batch: docs-ground-truth.csv

    # add a second batch
    load_truth_data(project, Path('forecast_app/tests/truth_data/docs-ground-truth-non-dup.csv'),
                    file_name='docs-ground-truth-non-dup.csv')
    oracle_model = oracle_model_for_project(project)
    first_forecast = oracle_model.forecasts.first()
    last_forecast = oracle_model.forecasts.last()

    # test truth_batches() and truth_batch_forecasts() for each batch
    batches = truth_batches(project)
    self.assertEqual(2, len(batches))
    for batch_idx, expected_forecast in ((0, first_forecast), (1, last_forecast)):
        self.assertEqual(expected_forecast.source, batches[batch_idx][0])
        self.assertEqual(expected_forecast.issued_at, batches[batch_idx][1])

    for batch_source, batch_issued_at in batches:
        batch_forecasts = truth_batch_forecasts(project, batch_source, batch_issued_at)
        self.assertEqual(3, len(batch_forecasts))
        for batch_forecast in batch_forecasts:
            self.assertEqual(batch_source, batch_forecast.source)
            self.assertEqual(batch_issued_at, batch_forecast.issued_at)

    # test truth_batch_summary_table(). NB: utctimetuple() makes sqlite comparisons work
    exp_table = [(batch_source, batch_issued_at.utctimetuple(),
                  len(truth_batch_forecasts(project, batch_source, batch_issued_at)))
                 for batch_source, batch_issued_at in batches]
    act_table = [(batch_source, batch_issued_at.utctimetuple(), num_forecasts)
                 for batch_source, batch_issued_at, num_forecasts in truth_batch_summary_table(project)]
    self.assertEqual(exp_table, act_table)

    # finally, test deleting a batch. try deleting the first, which should fail due to version rules.
    # transaction.atomic() somehow avoids the second `truth_delete_batch()` call getting the error:
    # django.db.transaction.TransactionManagementError: An error occurred in the current transaction. You can't execute queries until the end of the 'atomic' block.
    with transaction.atomic():
        with self.assertRaisesRegex(RuntimeError, 'you cannot delete a forecast that has any newer versions'):
            truth_delete_batch(project, batches[0][0], batches[0][1])

    # delete second batch - should not fail
    truth_delete_batch(project, batches[1][0], batches[1][1])
    remaining_batches = truth_batches(project)
    self.assertEqual(1, len(remaining_batches))
    self.assertEqual(first_forecast.source, remaining_batches[0][0])
    self.assertEqual(first_forecast.issued_at, remaining_batches[0][1])
def setUp(self):
    # runs before every test. done here instead of setUpTestData(cls) b/c the tests below modify the db
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    self.project, self.time_zero, self.forecast_model, self.forecast = _make_docs_project(po_user)

    # the docs project's three timezeros, looked up by date: 2011-10-02, 2011-10-09, 2011-10-16
    self.tz1, self.tz2, self.tz3 = [self.project.timezeros.filter(timezero_date=datetime.date(2011, 10, day)).first()
                                    for day in (2, 9, 16)]
def test_execute_project_config_diff(self):
    """
    Applies the `_make_some_changes()` edits to a fresh docs project via a config diff, then checks the result with
    `_do_make_some_changes_tests()`.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)

    # make some changes
    before_dict = config_dict_from_project(project, APIRequestFactory().request())
    after_dict = copy.deepcopy(before_dict)
    _make_some_changes(after_dict)
    diff_changes = project_config_diff(before_dict, after_dict)

    execute_project_config_diff(project, diff_changes)
    self._do_make_some_changes_tests(project)
def test__upload_forecast_worker_atomic(self):
    """
    Tests that `_upload_forecast_worker()` does not leave a Forecast behind, and marks the Job FAILED, if its calls to
    `load_predictions_from_json_io_dict()` or `cache_forecast_metadata()` fail.

    This test is complicated by that function's use of the `job_cloud_file` context manager. The solution is per
    https://stackoverflow.com/questions/60198229/python-patch-context-manager-to-return-object
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, time_zero, forecast_model, forecast = _make_docs_project(po_user)
    forecast.issued_at -= datetime.timedelta(days=1)  # older version avoids unique constraint errors
    forecast.save()

    with patch('forecast_app.models.job.job_cloud_file') as job_cloud_file_mock, \
            patch('utils.forecast.load_predictions_from_json_io_dict') as load_preds_mock, \
            patch('utils.forecast.cache_forecast_metadata') as cache_metadata_mock:

        def run_upload_expecting_failure():
            # Creates an empty Forecast plus a Job for it, runs the worker, and checks that the Job FAILED and that the
            # Forecast was not kept. Each case needs its own Forecast/Job b/c the worker deletes the Forecast on
            # failure, so a reused job would fail for an unrelated reason (missing forecast). A real predictions file
            # is supplied (as in the blue sky test) so that the worker gets as far as the mocked calls.
            num_forecasts_before = forecast_model.forecasts.count()
            new_forecast = Forecast.objects.create(forecast_model=forecast_model, time_zero=time_zero)
            job = Job.objects.create()
            job.input_json = {'forecast_pk': new_forecast.pk, 'filename': 'a name!'}
            job.save()
            with open('forecast_app/tests/predictions/docs-predictions.json') as cloud_file_fp:
                # 2-tuple: (job, cloud_file_fp)
                job_cloud_file_mock.return_value.__enter__.return_value = (job, cloud_file_fp)
                _upload_forecast_worker(job.pk)
            job.refresh_from_db()
            self.assertEqual(Job.FAILED, job.status)
            self.assertIsNone(Forecast.objects.filter(id=new_forecast.id).first())  # deleted
            self.assertEqual(num_forecasts_before, forecast_model.forecasts.count())

        # case: load_predictions_from_json_io_dict() fails
        load_preds_mock.side_effect = Exception('load_preds_mock Exception')
        run_upload_expecting_failure()
        load_preds_mock.assert_called_once()

        # case: loading succeeds but cache_forecast_metadata() fails
        load_preds_mock.reset_mock(side_effect=True)
        cache_metadata_mock.reset_mock()
        cache_metadata_mock.side_effect = Exception('cache_metatdata_mock Exception')
        run_upload_expecting_failure()
        load_preds_mock.assert_called_once()
        cache_metadata_mock.assert_called_once()
def test_order_project_config_diff(self):
    """
    `order_project_config_diff()` drops the duplicate and wasted changes produced by the `_make_some_changes()` diff.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)

    original_dict = config_dict_from_project(project, APIRequestFactory().request())
    edited_dict = copy.deepcopy(original_dict)
    _make_some_changes(edited_dict)
    changes = project_config_diff(original_dict, edited_dict)

    # the raw diff contains two duplicate changes and one wasted change: ('pct next week', ChangeType.FIELD_EDITED) is
    # wasted b/c that target is also being ChangeType.OBJ_REMOVED. ordering removes all three
    ordered_changes = order_project_config_diff(changes)
    self.assertEqual(13, len(changes))
    self.assertEqual(10, len(ordered_changes))
def test_load_truth_data_partial_dup(self):
    """
    Loading a truth file that only partially duplicates the already-loaded truth must succeed and produce a second
    truth batch.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)  # loads batch: docs-ground-truth.csv
    try:
        load_truth_data(project, Path('forecast_app/tests/truth_data/docs-ground-truth-partial-dup.csv'),
                        file_name='docs-ground-truth-partial-dup.csv')
    except Exception as ex:
        self.fail(f"unexpected exception: {ex}")

    # checked outside the `try` so that a failing assertion (AssertionError is an Exception subclass) is reported as
    # itself rather than being turned into a misleading "unexpected exception" failure
    batches = truth_batches(project)
    self.assertEqual(2, len(batches))
def test_diff_from_file_empty_data_version_date_string(self):
    """
    Changing a timezero's `data_version_date` from None to '' (incorrect, but fixed up for users) yields no changes.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)

    # note: using APIRequestFactory was the only way I could find to pass a request object. o/w you get:
    # AssertionError: `HyperlinkedIdentityField` requires the request in the serializer context.
    out_config_dict = config_dict_from_project(project, APIRequestFactory().request())
    edited_config_dict = copy.deepcopy(out_config_dict)
    first_timezero_dict = edited_config_dict['timezeros'][0]
    first_timezero_dict['data_version_date'] = ''  # '2011-10-02': None -> ''

    changes = project_config_diff(out_config_dict, edited_config_dict)
    # is 1 without the fix "this test for `!= ''` matches this one below"
    self.assertEqual(0, len(changes))
def test_execute_project_config_diff(self):
    """
    Applies the `_make_some_changes()` edits (after scores have been updated) to a fresh docs project via a config
    diff, then checks the result with `_do_make_some_changes_tests()`.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)
    _update_scores_for_all_projects()

    # make some changes
    # note: using APIRequestFactory was the only way I could find to pass a request object. o/w you get:
    # AssertionError: `HyperlinkedIdentityField` requires the request in the serializer context.
    before_dict = config_dict_from_project(project, APIRequestFactory().request())
    after_dict = copy.deepcopy(before_dict)
    _make_some_changes(after_dict)
    diff_changes = project_config_diff(before_dict, after_dict)

    execute_project_config_diff(project, diff_changes)
    self._do_make_some_changes_tests(project)
def test_load_predictions_from_json_io_dict_existing_pred_eles(self):
    """Loading predictions into the docs project's already-loaded forecast raises RuntimeError."""
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    _, _, _, forecast = _make_docs_project(po_user)

    point_prediction = {"unit": "location1", "target": "pct next week", "class": "point",
                        "prediction": {"value": 2.1}}
    with self.assertRaises(RuntimeError) as context:
        load_predictions_from_json_io_dict(forecast, {"predictions": [point_prediction]})
    self.assertIn("cannot load data into a non-empty forecast", str(context.exception))
def test__upload_forecast_worker_deletes_forecast(self):
    """
    Verifies that `_upload_forecast_worker()` deletes the (presumably empty) Forecast passed to it by the upload
    functions if the file is invalid. `load_predictions_from_json_io_dict()` is mocked to raise each of the two
    exception types that cause deletes (Exception and JobTimeoutException), and the resulting job status is checked.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    _, time_zero, forecast_model, forecast = _make_docs_project(po_user)
    forecast.issued_at -= datetime.timedelta(days=1)  # older version avoids unique constraint errors
    forecast.save()

    exception_to_job_status = [
        (Exception('load_preds_mock Exception'), Job.FAILED),
        (JobTimeoutException('load_preds_mock JobTimeoutException'), Job.TIMEOUT),
    ]
    for raised_exception, exp_job_status in exception_to_job_status:
        with patch('forecast_app.models.job.job_cloud_file') as job_cloud_file_mock, \
                patch('utils.forecast.load_predictions_from_json_io_dict') as load_preds_mock, \
                patch('utils.forecast.cache_forecast_metadata'), \
                open('forecast_app/tests/predictions/docs-predictions.json') as cloud_file_fp:
            load_preds_mock.side_effect = raised_exception
            empty_forecast = Forecast.objects.create(forecast_model=forecast_model, time_zero=time_zero)
            job = Job.objects.create()
            job.input_json = {'forecast_pk': empty_forecast.pk, 'filename': 'a name!'}
            job.save()
            job_cloud_file_mock.return_value.__enter__.return_value = (job, cloud_file_fp)
            try:
                _upload_forecast_worker(job.pk)
            except JobTimeoutException:
                pass  # expected re-raise of this exception
            job.refresh_from_db()
            self.assertEqual(exp_job_status, job.status)
            self.assertIsNone(Forecast.objects.filter(id=empty_forecast.id).first())  # deleted
def test_null_or_empty_name_or_abbreviation(self):
    """Creating a ForecastModel whose name or abbreviation is None or '' raises ValidationError."""
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, forecast_model, _ = _make_docs_project(po_user)

    def assert_rejected(name, abbreviation):
        with self.assertRaises(ValidationError) as context:
            ForecastModel.objects.create(project=project, name=name, abbreviation=abbreviation)
        self.assertIn('both name and abbreviation are required', str(context.exception))

    blank_values = [None, '']
    for blank_name in blank_values:
        assert_rejected(blank_name, 'abbrev')
    for blank_abbreviation in blank_values:
        assert_rejected(forecast_model.name + '2', blank_abbreviation)
def test_order_project_config_diff(self):
    """
    `order_project_config_diff()` drops the duplicate and wasted changes produced by the `_make_some_changes()` diff
    (with scores updated beforehand).
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)
    _update_scores_for_all_projects()

    # note: using APIRequestFactory was the only way I could find to pass a request object. o/w you get:
    # AssertionError: `HyperlinkedIdentityField` requires the request in the serializer context.
    original_dict = config_dict_from_project(project, APIRequestFactory().request())
    edited_dict = copy.deepcopy(original_dict)
    _make_some_changes(edited_dict)
    changes = project_config_diff(original_dict, edited_dict)

    # the raw diff contains two duplicate changes and one wasted change: ('pct next week', ChangeType.FIELD_EDITED) is
    # wasted b/c that target is also being ChangeType.OBJ_REMOVED. ordering removes all three
    ordered_changes = order_project_config_diff(changes)
    self.assertEqual(13, len(changes))
    self.assertEqual(10, len(ordered_changes))
def test_database_changes_for_project_config_diff(self):
    """
    Checks the per-change counts of prediction elements and truth rows that the `_make_some_changes()` removals would
    delete.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)

    original_dict = config_dict_from_project(project, APIRequestFactory().request())
    edited_dict = copy.deepcopy(original_dict)
    _make_some_changes(edited_dict)
    changes = project_config_diff(original_dict, edited_dict)

    def obj_removed(object_type, object_pk):
        return Change(object_type, object_pk, ChangeType.OBJ_REMOVED, None, None)

    # each entry: (change, num_pred_eles, num_truth)
    exp_changes = [(obj_removed(ObjectType.UNIT, 'location3'), 8, 0),
                   (obj_removed(ObjectType.TARGET, 'pct next week'), 7, 3),
                   (obj_removed(ObjectType.TIMEZERO, '2011-10-02'), 29, 5)]
    self.assertEqual(exp_changes, database_changes_for_project_config_diff(project, changes))
def test_database_changes_for_project_config_diff(self):
    """
    Checks the per-change counts (points, named, bins, samples, truth) that the `_make_some_changes()` removals would
    delete, with scores updated beforehand.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)
    _update_scores_for_all_projects()

    # note: using APIRequestFactory was the only way I could find to pass a request object. o/w you get:
    # AssertionError: `HyperlinkedIdentityField` requires the request in the serializer context.
    original_dict = config_dict_from_project(project, APIRequestFactory().request())
    edited_dict = copy.deepcopy(original_dict)
    _make_some_changes(edited_dict)
    changes = project_config_diff(original_dict, edited_dict)

    def obj_removed(object_type, object_pk):
        return Change(object_type, object_pk, ChangeType.OBJ_REMOVED, None, None)

    # change, num_points, num_named, num_bins, num_samples, num_truth
    exp_changes = [(obj_removed(ObjectType.UNIT, 'location3'), 3, 0, 2, 10, 0),
                   (obj_removed(ObjectType.TARGET, 'pct next week'), 3, 1, 3, 5, 3),
                   (obj_removed(ObjectType.TIMEZERO, '2011-10-02'), 11, 2, 16, 23, 5)]
    self.assertEqual(exp_changes, database_changes_for_project_config_diff(project, changes))
def test_load_truth_data_diff(self):
    """
    Tests that this forecast version rule is relaxed when loading truth (issue [support truth "diff" uploads #319](
    https://github.com/reichlab/forecast-repository/issues/319) ):

        3. New forecast versions cannot imply any retracted prediction elements in existing versions, i.e., you
        cannot load data that's a subset of the previous forecast's data.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)  # loads docs-ground-truth.csv
    oracle_model = oracle_model_for_project(project)
    self.assertEqual(3, oracle_model.forecasts.count())  # for 3 timezeros: 2011-10-02, 2011-10-09, 2011-10-16
    self.assertEqual(14, truth_data_qs(project).count())

    # updates only the five location2 rows:
    diff_file_name = 'docs-ground-truth-diff.csv'
    load_truth_data(project, Path('forecast_app/tests/truth_data') / diff_file_name, file_name=diff_file_name)
    self.assertEqual(3 + 1, oracle_model.forecasts.count())
    self.assertEqual(14 + 5, truth_data_qs(project).count())
def test_target_rows_for_project(self):
    """
    Checks `target_rows_for_project()` as forecasts are added (a newer timezero, then a second model that forecasts
    only one target) and finally after all forecasts are deleted.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    # recall that _make_docs_project() calls cache_forecast_metadata():
    project, time_zero, forecast_model, forecast = _make_docs_project(po_user)  # 2011, 10, 2

    def by_model_id(row):
        return row[0].id

    def assert_target_rows(exp_rows):
        # str() the timezero date column, then compare both row lists sorted (stably) by model id
        act_rows = [(row[0], str(row[1]), row[2], row[3], row[4]) for row in target_rows_for_project(project)]
        self.assertEqual(sorted(exp_rows, key=by_model_id), sorted(act_rows, key=by_model_id))

    target_names = ['Season peak week', 'above baseline', 'cases next week', 'pct next week', 'season severity']

    # case: one model with one timezero that has five groups of one target each.
    # recall: `group_targets(project.targets.all())` (only one target/group in this case):
    # {'pct next week': [(1, 'pct next week', 'continuous', True, 1, 'percent')],
    #  'cases next week': [(2, 'cases next week', 'discrete', True, 2, 'cases')],
    #  'season severity': [(3, 'season severity', 'nominal', False, None, None)],
    #  'above baseline': [(4, 'above baseline', 'binary', False, None, None)],
    #  'Season peak week': [(5, 'Season peak week', 'date', False, None, 'week')]}
    assert_target_rows([(forecast_model, str(time_zero.timezero_date), forecast.id, target_name, 1)
                        for target_name in target_names])

    # case: add a second forecast for a newer timezero
    time_zero2 = TimeZero.objects.create(project=project, timezero_date=datetime.date(2011, 10, 3))
    forecast2 = Forecast.objects.create(forecast_model=forecast_model, source='docs-predictions.json',
                                        time_zero=time_zero2, notes="a small prediction file")
    with open('forecast_app/tests/predictions/docs-predictions.json') as preds_fp:
        predictions_dict = json.load(preds_fp)
    load_predictions_from_json_io_dict(forecast2, predictions_dict, False)
    cache_forecast_metadata(forecast2)  # required by _forecast_ids_to_present_unit_or_target_id_sets()
    newest_model1_rows = [(forecast_model, str(time_zero2.timezero_date), forecast2.id, target_name, 1)
                          for target_name in target_names]
    assert_target_rows(newest_model1_rows)

    # case: add a second model with only forecasts for one target
    forecast_model2 = ForecastModel.objects.create(project=project, name=forecast_model.name + '2',
                                                   abbreviation=forecast_model.abbreviation + '2')
    time_zero3 = TimeZero.objects.create(project=project, timezero_date=datetime.date(2011, 10, 4))
    forecast3 = Forecast.objects.create(forecast_model=forecast_model2, source='docs-predictions.json',
                                        time_zero=time_zero3, notes="a small prediction file")
    single_point_dict = {
        "meta": {},
        "predictions": [{"unit": "location1", "target": "pct next week", "class": "point",
                         "prediction": {"value": 2.1}}]
    }
    load_predictions_from_json_io_dict(forecast3, single_point_dict, False)
    cache_forecast_metadata(forecast3)  # required by _forecast_ids_to_present_unit_or_target_id_sets()
    assert_target_rows(newest_model1_rows +
                       [(forecast_model2, str(time_zero3.timezero_date), forecast3.id, 'pct next week', 1)])

    # case: no forecasts
    for doomed_forecast in (forecast, forecast2, forecast3):
        doomed_forecast.delete()
    assert_target_rows([(forecast_model, '', '', '', 0), (forecast_model2, '', '', '', 0)])
def test_unit_rows_for_project(self):
    """
    Checks `unit_rows_for_project()` as forecasts are added (a newer timezero, then a second model that forecasts only
    one unit) and finally after all forecasts are deleted.
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    # recall that _make_docs_project() calls cache_forecast_metadata():
    project, time_zero, forecast_model, forecast = _make_docs_project(po_user)  # 2011, 10, 2

    def actual_unit_rows():
        # rows: (model, newest_forecast_tz_date, newest_forecast_id, num_present_unit_names, present_unit_names,
        # missing_unit_names), with the tz date column str()'d
        return [(row[0], str(row[1]), row[2], row[3], row[4], row[5]) for row in unit_rows_for_project(project)]

    # case: one model with one timezero
    self.assertEqual([(forecast_model, str(time_zero.timezero_date), forecast.id, 3, '(all)', '')],
                     actual_unit_rows())

    # case: add a second forecast for a newer timezero
    time_zero2 = TimeZero.objects.create(project=project, timezero_date=datetime.date(2011, 10, 3))
    forecast2 = Forecast.objects.create(forecast_model=forecast_model, source='docs-predictions.json',
                                        time_zero=time_zero2, notes="a small prediction file")
    with open('forecast_app/tests/predictions/docs-predictions.json') as preds_fp:
        predictions_dict = json.load(preds_fp)
    load_predictions_from_json_io_dict(forecast2, predictions_dict, False)
    cache_forecast_metadata(forecast2)  # required by _forecast_ids_to_present_unit_or_target_id_sets()
    model1_row = (forecast_model, str(time_zero2.timezero_date), forecast2.id, 3, '(all)', '')
    self.assertEqual([model1_row], actual_unit_rows())

    # case: add a second model with only forecasts for one unit
    forecast_model2 = ForecastModel.objects.create(project=project, name=forecast_model.name + '2',
                                                   abbreviation=forecast_model.abbreviation + '2')
    time_zero3 = TimeZero.objects.create(project=project, timezero_date=datetime.date(2011, 10, 4))
    forecast3 = Forecast.objects.create(forecast_model=forecast_model2, source='docs-predictions.json',
                                        time_zero=time_zero3, notes="a small prediction file")
    single_point_dict = {
        "meta": {},
        "predictions": [{"unit": "location1", "target": "pct next week", "class": "point",
                         "prediction": {"value": 2.1}}]
    }
    load_predictions_from_json_io_dict(forecast3, single_point_dict, False)
    cache_forecast_metadata(forecast3)  # required by _forecast_ids_to_present_unit_or_target_id_sets()
    model2_row = (forecast_model2, str(time_zero3.timezero_date), forecast3.id, 1, 'location1',
                  'location2, location3')
    self.assertEqual([model1_row, model2_row], actual_unit_rows())

    # case: exposes bug: syntax error when no forecasts in project:
    #   psycopg2.errors.SyntaxError: syntax error at or near ")"
    #   LINE 6: WHERE f.id IN ()
    for doomed_forecast in (forecast, forecast2, forecast3):
        doomed_forecast.delete()
    exp_rows = [(forecast_model, 'None', None, 0, '', '(all)'),
                (forecast_model2, 'None', None, 0, '', '(all)')]
    self.assertEqual(sorted(exp_rows, key=lambda row: row[0].id),
                     sorted(actual_unit_rows(), key=lambda row: row[0].id))
def test_validate_scores_query(self):
    """
    Checks the error messages validate_scores_query() returns for malformed queries, and that a fully valid query
    yields none. Nearly identical to test_validate_forecasts_query().
    """
    def assert_single_error(query, exp_error_msg):
        # a bad query must produce exactly one error message, and it must mention exp_error_msg
        error_messages, _ = validate_scores_query(self.project, query)
        self.assertEqual(1, len(error_messages))
        self.assertIn(exp_error_msg, error_messages[0])

    key_to_not_found_msg = {
        'models': 'model with abbreviation not found',
        'units': 'unit with name not found',
        'targets': 'target with name not found',
        'timezeros': 'timezero with date not found',
    }

    # case: query not a dict
    assert_single_error(-1, "query was not a dict")

    # case: query contains invalid keys
    assert_single_error({'foo': -1}, "one or more query keys were invalid")

    # case: query keys are not correct type (lists)
    for key_name in key_to_not_found_msg:
        assert_single_error({key_name: -1}, f"'{key_name}' was not a list")

    # case: bad object id
    for key_name, exp_error_msg in key_to_not_found_msg.items():
        assert_single_error({key_name: [-1]}, exp_error_msg)

    # case: bad score
    assert_single_error({'scores': ['bad score']}, "one or more scores were invalid abbreviations")

    # case: object references from other project (!)
    project2, _, _, _ = _make_docs_project(self.po_user)
    other_project_refs = {
        'models': [project2.models.first().abbreviation],
        'units': [project2.units.first().name],
        'targets': [project2.targets.first().name],
        'timezeros': [project2.timezeros.first().timezero_date.strftime(YYYY_MM_DD_DATE_FORMAT)],
    }
    for key_name, refs in other_project_refs.items():
        assert_single_error({key_name: refs}, key_to_not_found_msg[key_name])

    # case: blue sky
    valid_query = {
        'models': list(self.project.models.all().values_list('id', flat=True)),
        'units': list(self.project.units.all().values_list('id', flat=True)),
        'targets': list(self.project.targets.all().values_list('id', flat=True)),
        'timezeros': list(self.project.timezeros.all().values_list('id', flat=True)),
        'scores': list(SCORE_ABBREV_TO_NAME_AND_DESCR.keys()),
    }
    error_messages, _ = validate_scores_query(self.project, valid_query)
    self.assertEqual(0, len(error_messages))
def test_validate_forecasts_query(self):
    """
    Checks the error messages validate_forecasts_query() returns for malformed queries, and that well-formed queries
    (a valid 'as_of', and a "blue sky" query) yield none. Nearly identical to test_validate_scores_query().
    """
    # case: query not a dict
    error_messages, _ = validate_forecasts_query(self.project, -1)
    self.assertEqual(1, len(error_messages))
    self.assertIn("query was not a dict", error_messages[0])

    # case: query contains invalid keys
    error_messages, _ = validate_forecasts_query(self.project, {'foo': -1})
    self.assertEqual(1, len(error_messages))
    self.assertIn("one or more query keys were invalid", error_messages[0])

    # case: query keys are not correct type (lists)
    for key_name in ['models', 'units', 'targets', 'timezeros']:
        error_messages, _ = validate_forecasts_query(self.project, {key_name: -1})
        self.assertEqual(1, len(error_messages))
        self.assertIn(f"'{key_name}' was not a list", error_messages[0])

    # case: as_of is not a string, or is not a date in YYYY_MM_DD_DATE_FORMAT
    error_messages, _ = validate_forecasts_query(self.project, {'as_of': -1})
    self.assertEqual(1, len(error_messages))
    self.assertIn("'as_of' was not a string", error_messages[0])

    error_messages, _ = validate_forecasts_query(self.project, {'as_of': '20201011'})
    self.assertEqual(1, len(error_messages))
    self.assertIn("'as_of' was not in YYYY-MM-DD format", error_messages[0])

    # case: valid as_of. validate_forecasts_query() reports problems via its returned error messages (as the cases
    # above show) rather than by raising, so we must also check that no messages came back - merely not raising
    # would let a wrongly-rejected valid date pass unnoticed
    try:
        error_messages, _ = validate_forecasts_query(self.project, {'as_of': '2020-10-11'})
    except Exception as ex:
        self.fail(f"unexpected exception: {ex}")
    self.assertEqual(0, len(error_messages))

    # case: bad object reference
    for key_name, exp_error_msg in [('models', 'model with abbreviation not found'),
                                    ('units', 'unit with name not found'),
                                    ('targets', 'target with name not found'),
                                    ('timezeros', 'timezero with date not found')]:
        error_messages, _ = validate_forecasts_query(self.project, {key_name: [-1]})
        self.assertEqual(1, len(error_messages))
        self.assertIn(exp_error_msg, error_messages[0])

    # case: bad type
    error_messages, _ = validate_forecasts_query(self.project, {'types': ['bad type']})
    self.assertEqual(1, len(error_messages))
    self.assertIn("one or more types were invalid prediction types", error_messages[0])

    # case: object references from other project (!)
    project2, _, _, _ = _make_docs_project(self.po_user)
    for query_dict, exp_error_msg in [
        ({'models': [project2.models.first().abbreviation]}, 'model with abbreviation not found'),
        ({'units': [project2.units.first().name]}, 'unit with name not found'),
        ({'targets': [project2.targets.first().name]}, 'target with name not found'),
        ({'timezeros': [project2.timezeros.first().timezero_date.strftime(YYYY_MM_DD_DATE_FORMAT)]},
         'timezero with date not found'),
    ]:
        error_messages, _ = validate_forecasts_query(self.project, query_dict)
        self.assertEqual(1, len(error_messages))
        self.assertIn(exp_error_msg, error_messages[0])

    # case: blue sky
    query = {
        'models': list(self.project.models.all().values_list('id', flat=True)),
        'units': list(self.project.units.all().values_list('id', flat=True)),
        'targets': list(self.project.targets.all().values_list('id', flat=True)),
        'timezeros': list(self.project.timezeros.all().values_list('id', flat=True)),
        'types': list(PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS.values()),
    }
    error_messages, _ = validate_forecasts_query(self.project, query)
    self.assertEqual(0, len(error_messages))
def setUpTestData(cls):
    """
    Class-level fixture: stores the project-owner user in cls.po_user, and the docs project plus its timezero, model,
    and forecast in cls.project, cls.time_zero, cls.forecast_model, and cls.forecast.

    NOTE(review): Django calls setUpTestData() on the class, so it must be a @classmethod. No decorator is visible in
    this chunk; confirm it is applied.
    """
    users = get_or_create_super_po_mo_users(is_create_super=True)
    _, _, po_user, _, _, _, _, _ = users
    cls.po_user = po_user
    docs_project = _make_docs_project(cls.po_user)
    cls.project, cls.time_zero, cls.forecast_model, cls.forecast = docs_project
def test_data_rows_from_forecast(self):
    """
    Checks data_rows_from_forecast() for every (unit, target) pair in the docs project's forecast.

    Each expected value is a 5-tuple of row lists, one per prediction type, in this order:
    (bin rows, named rows, point rows, quantile rows, sample rows).
    """
    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, time_zero, forecast_model, forecast = _make_docs_project(po_user)

    unit_loc1, unit_loc2, unit_loc3 = [project.units.filter(name=unit_name).first()
                                       for unit_name in ('location1', 'location2', 'location3')]
    (target_pct_next_week, target_cases_next_week, target_season_severity, target_above_baseline,
     target_season_peak_week) = [project.targets.filter(name=target_name).first()
                                 for target_name in ('pct next week', 'cases next week', 'season severity',
                                                     'above baseline', 'Season peak week')]

    unit_target_to_exp_rows = {
        # ---- location1 ----
        (unit_loc1, target_pct_next_week): (
            [],
            [('location1', 'pct next week', 'norm', 1.1, 2.2, None)],  # named
            [('location1', 'pct next week', 2.1)],  # point
            [],
            []),
        (unit_loc1, target_cases_next_week): (
            [],
            [('location1', 'cases next week', 'pois', 1.1, None, None)],  # named
            [],
            [],
            []),
        (unit_loc1, target_season_severity): (
            [('location1', 'season severity', 'mild', 0.0),  # bin
             ('location1', 'season severity', 'moderate', 0.1),
             ('location1', 'season severity', 'severe', 0.9)],
            [],
            [('location1', 'season severity', 'mild')],  # point
            [],
            []),
        (unit_loc1, target_above_baseline): (
            [],
            [],
            [('location1', 'above baseline', True)],  # point
            [],
            []),
        (unit_loc1, target_season_peak_week): (
            [('location1', 'Season peak week', '2019-12-15', 0.01),  # bin
             ('location1', 'Season peak week', '2019-12-22', 0.1),
             ('location1', 'Season peak week', '2019-12-29', 0.89)],
            [],
            [('location1', 'Season peak week', '2019-12-22')],  # point
            [],
            [('location1', 'Season peak week', '2020-01-05'),  # sample
             ('location1', 'Season peak week', '2019-12-15')]),
        # ---- location2 ----
        (unit_loc2, target_pct_next_week): (
            [('location2', 'pct next week', 1.1, 0.3),  # bin
             ('location2', 'pct next week', 2.2, 0.2),
             ('location2', 'pct next week', 3.3, 0.5)],
            [],
            [('location2', 'pct next week', 2.0)],  # point
            [('location2', 'pct next week', 0.025, 1.0),  # quantile
             ('location2', 'pct next week', 0.25, 2.2),
             ('location2', 'pct next week', 0.5, 2.2),
             ('location2', 'pct next week', 0.75, 5.0),
             ('location2', 'pct next week', 0.975, 50.0)],
            []),
        (unit_loc2, target_cases_next_week): (
            [],
            [],
            [('location2', 'cases next week', 5)],  # point
            [],
            [('location2', 'cases next week', 0),  # sample
             ('location2', 'cases next week', 2),
             ('location2', 'cases next week', 5)]),
        (unit_loc2, target_season_severity): (
            [],
            [],
            [('location2', 'season severity', 'moderate')],  # point
            [],
            [('location2', 'season severity', 'moderate'),  # sample
             ('location2', 'season severity', 'severe'),
             ('location2', 'season severity', 'high'),
             ('location2', 'season severity', 'moderate'),
             ('location2', 'season severity', 'mild')]),
        (unit_loc2, target_above_baseline): (
            [('location2', 'above baseline', True, 0.9),  # bin
             ('location2', 'above baseline', False, 0.1)],
            [],
            [],
            [],
            [('location2', 'above baseline', True),  # sample
             ('location2', 'above baseline', False),
             ('location2', 'above baseline', True)]),
        (unit_loc2, target_season_peak_week): (
            [('location2', 'Season peak week', '2019-12-15', 0.01),  # bin
             ('location2', 'Season peak week', '2019-12-22', 0.05),
             ('location2', 'Season peak week', '2019-12-29', 0.05),
             ('location2', 'Season peak week', '2020-01-05', 0.89)],
            [],
            [('location2', 'Season peak week', '2020-01-05')],  # point
            [('location2', 'Season peak week', 0.5, '2019-12-22'),  # quantile
             ('location2', 'Season peak week', 0.75, '2019-12-29'),
             ('location2', 'Season peak week', 0.975, '2020-01-05')],
            []),
        # ---- location3 ----
        (unit_loc3, target_pct_next_week): (
            [],
            [],
            [('location3', 'pct next week', 3.567)],  # point
            [],
            [('location3', 'pct next week', 2.3),  # sample
             ('location3', 'pct next week', 6.5),
             ('location3', 'pct next week', 0.0),
             ('location3', 'pct next week', 10.0234),
             ('location3', 'pct next week', 0.0001)]),
        (unit_loc3, target_cases_next_week): (
            [('location3', 'cases next week', 0, 0.0),  # bin
             ('location3', 'cases next week', 2, 0.1),
             ('location3', 'cases next week', 50, 0.9)],
            [],
            [('location3', 'cases next week', 10)],  # point
            [('location3', 'cases next week', 0.25, 0),  # quantile
             ('location3', 'cases next week', 0.75, 50)],
            []),
        (unit_loc3, target_season_severity): ([], [], [], [], []),
        (unit_loc3, target_above_baseline): (
            [],
            [],
            [],
            [],
            [('location3', 'above baseline', False),  # sample
             ('location3', 'above baseline', True),
             ('location3', 'above baseline', True)]),
        (unit_loc3, target_season_peak_week): (
            [],
            [],
            [('location3', 'Season peak week', '2019-12-29')],  # point
            [],
            [('location3', 'Season peak week', '2020-01-06'),  # sample
             ('location3', 'Season peak week', '2019-12-16')]),
    }
    for (unit, target), exp_rows in unit_target_to_exp_rows.items():
        self.assertEqual(exp_rows, data_rows_from_forecast(forecast, unit, target))
def test_serialize_change_list(self):
    """
    Checks Change.serialize_to_dict() and Change.deserialize_dict() on the changes produced by diffing the docs
    project's config dict against a copy edited by _make_some_changes().
    """
    def change_sort_key(change):
        return change.object_type, change.object_pk, change.change_type

    def change_dict(object_type, object_pk, change_type, field_name, object_dict):
        # the expected Change.serialize_to_dict() output for a single change
        return {'object_type': object_type, 'object_pk': object_pk, 'change_type': change_type,
                'field_name': field_name, 'object_dict': object_dict}

    def pct_next_week_dict():
        # a fresh copy each call so no two expected dicts share one object
        return {'name': 'pct next week', 'type': 'discrete', 'description': 'new descr', 'is_step_ahead': True,
                'step_ahead_increment': 1, 'unit': 'percent', 'range': [0, 100],
                'cats': [0, 1, 1, 2, 2, 3, 3, 5, 10, 50]}

    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)

    # make some changes
    out_config_dict = config_dict_from_project(project, APIRequestFactory().request())
    edit_config_dict = copy.deepcopy(out_config_dict)
    _make_some_changes(edit_config_dict)
    changes = sorted(project_config_diff(out_config_dict, edit_config_dict), key=change_sort_key)

    # test round-trip for one Change
    first_exp_dict = change_dict(ObjectType.PROJECT, None, ChangeType.FIELD_EDITED, 'name', edit_config_dict)
    self.assertEqual(first_exp_dict, changes[0].serialize_to_dict())
    self.assertEqual(changes[0], Change.deserialize_dict(first_exp_dict))

    # test serialize_to_dict() for all changes
    cases_next_week_dict = {'name': 'cases next week', 'type': 'discrete',
                            'description': 'A forecasted integer number of cases for a future week.',
                            'is_step_ahead': False, 'unit': 'cases', 'range': [0, 100000], 'cats': [0, 2, 50]}
    exp_dicts = [
        change_dict(ObjectType.PROJECT, None, ChangeType.FIELD_EDITED, 'name', edit_config_dict),
        change_dict(ObjectType.UNIT, 'location3', ChangeType.OBJ_REMOVED, None, None),
        change_dict(ObjectType.UNIT, 'location4', ChangeType.OBJ_ADDED, None, {'name': 'location4'}),
        change_dict(ObjectType.TARGET, 'cases next week', ChangeType.FIELD_EDITED, 'is_step_ahead',
                    cases_next_week_dict),
        change_dict(ObjectType.TARGET, 'cases next week', ChangeType.FIELD_REMOVED, 'step_ahead_increment', None),
        change_dict(ObjectType.TARGET, 'pct next week', ChangeType.OBJ_ADDED, None, pct_next_week_dict()),
        change_dict(ObjectType.TARGET, 'pct next week', ChangeType.OBJ_ADDED, None, pct_next_week_dict()),
        change_dict(ObjectType.TARGET, 'pct next week', ChangeType.OBJ_REMOVED, None, None),
        change_dict(ObjectType.TARGET, 'pct next week', ChangeType.OBJ_REMOVED, None, None),
        change_dict(ObjectType.TARGET, 'pct next week', ChangeType.FIELD_EDITED, 'description',
                    pct_next_week_dict()),
        change_dict(ObjectType.TIMEZERO, '2011-10-02', ChangeType.OBJ_REMOVED, None, None),
        change_dict(ObjectType.TIMEZERO, '2011-10-09', ChangeType.FIELD_EDITED, 'data_version_date',
                    {'timezero_date': '2011-10-09', 'data_version_date': '2011-10-19', 'is_season_start': False}),
        change_dict(ObjectType.TIMEZERO, '2011-10-22', ChangeType.OBJ_ADDED, None,
                    {'timezero_date': '2011-10-22', 'data_version_date': None, 'is_season_start': True,
                     'season_name': '2011-2012'}),
    ]
    act_dicts = [change.serialize_to_dict() for change in changes]

    # strip the serializers' 'id' and 'url' fields to ease testing. the 'id' check guards against a dict that was
    # already stripped on an earlier iteration (presumably because some changes share one object_dict)
    for act_dict in act_dicts:
        obj_dict = act_dict['object_dict']
        if obj_dict and 'id' in obj_dict:
            del obj_dict['id']
            del obj_dict['url']
    self.assertEqual(exp_dicts, act_dicts)

    # test round-trip for all changes
    for change in changes:
        self.assertEqual(change, Change.deserialize_dict(change.serialize_to_dict()))
def test_project_config_diff(self):
    """
    Checks the Changes that project_config_diff() reports for a series of independent edits to a copy of the docs
    project's config dict:
    - project field edits
    - unit and timezero additions, removals, and field edits
    - target edits: editable fields yield FIELD_* changes, while non-editable fields ('type', 'range', 'cats') yield
      an OBJ_REMOVED + OBJ_ADDED pair
    """
    def change_sort_key(change):
        return change.object_type, change.object_pk, change.change_type

    def assert_changes_equal(exp_changes, act_changes):
        # compare order-insensitively by sorting both sides on the same key
        self.assertEqual(sorted(exp_changes, key=change_sort_key), sorted(act_changes, key=change_sort_key))

    def find_dict(dicts, key, value):
        # the first dict in `dicts` whose `key` equals `value`. raises IndexError if there is none
        return [the_dict for the_dict in dicts if the_dict[key] == value][0]

    _, _, po_user, _, _, _, _, _ = get_or_create_super_po_mo_users(is_create_super=True)
    project, _, _, _ = _make_docs_project(po_user)

    # first we remove 'id' and 'url' fields from serializers to ease testing
    current_config_dict = config_dict_from_project(project, APIRequestFactory().request())
    for the_dict_list in [current_config_dict['units'], current_config_dict['targets'],
                          current_config_dict['timezeros']]:
        for the_dict in the_dict_list:
            if 'id' in the_dict:
                del the_dict['id']
                del the_dict['url']

    # project fields: edit
    fields_new_values = [('name', 'new name'),
                         ('is_public', False),
                         ('description', 'new descr'),
                         ('home_url', 'new home_url'),
                         ('logo_url', 'new logo_url'),
                         ('core_data', 'new core_data'),
                         ('time_interval_type', 'Biweek'),
                         ('visualization_y_label', 'new visualization_y_label')]
    edit_config_dict = copy.deepcopy(current_config_dict)
    for field_name, new_value in fields_new_values:
        edit_config_dict[field_name] = new_value
    exp_changes = [Change(ObjectType.PROJECT, None, ChangeType.FIELD_EDITED, field_name, edit_config_dict)
                   for field_name, _ in fields_new_values]
    assert_changes_equal(exp_changes, project_config_diff(current_config_dict, edit_config_dict))

    # project units: remove 'location3', add 'location4' (done by renaming 'location3')
    edit_config_dict = copy.deepcopy(current_config_dict)
    location_3_dict = find_dict(edit_config_dict['units'], 'name', 'location3')
    location_3_dict['name'] = 'location4'  # was 'location3'
    exp_changes = [Change(ObjectType.UNIT, 'location3', ChangeType.OBJ_REMOVED, None, None),
                   Change(ObjectType.UNIT, 'location4', ChangeType.OBJ_ADDED, None, location_3_dict)]
    assert_changes_equal(exp_changes, project_config_diff(current_config_dict, edit_config_dict))

    # project timezeros: remove '2011-10-02', add '2011-10-22', edit '2011-10-09' fields
    edit_config_dict = copy.deepcopy(current_config_dict)
    tz_2011_10_02_dict = find_dict(edit_config_dict['timezeros'], 'timezero_date', '2011-10-02')
    tz_2011_10_02_dict['timezero_date'] = '2011-10-22'  # was '2011-10-02'
    tz_2011_10_09_dict = find_dict(edit_config_dict['timezeros'], 'timezero_date', '2011-10-09')
    tz_2011_10_09_dict['data_version_date'] = '2011-10-19'  # was '2011-10-09'
    tz_2011_10_09_dict['is_season_start'] = True  # was false
    tz_2011_10_09_dict['season_name'] = 'season name'  # was null
    exp_changes = [
        Change(ObjectType.TIMEZERO, '2011-10-02', ChangeType.OBJ_REMOVED, None, None),
        Change(ObjectType.TIMEZERO, '2011-10-22', ChangeType.OBJ_ADDED, None, tz_2011_10_02_dict),
        Change(ObjectType.TIMEZERO, '2011-10-09', ChangeType.FIELD_EDITED, 'data_version_date', tz_2011_10_09_dict),
        Change(ObjectType.TIMEZERO, '2011-10-09', ChangeType.FIELD_EDITED, 'is_season_start', tz_2011_10_09_dict),
        Change(ObjectType.TIMEZERO, '2011-10-09', ChangeType.FIELD_ADDED, 'season_name', tz_2011_10_09_dict),
    ]
    assert_changes_equal(exp_changes, project_config_diff(current_config_dict, edit_config_dict))

    # project targets: remove 'pct next week', add 'pct next week 2', edit 'cases next week' and 'Season peak week'
    # fields
    edit_config_dict = copy.deepcopy(current_config_dict)
    pct_next_week_target_dict = find_dict(edit_config_dict['targets'], 'name', 'pct next week')
    pct_next_week_target_dict['name'] = 'pct next week 2'  # was 'pct next week'
    cases_next_week_target_dict = find_dict(edit_config_dict['targets'], 'name', 'cases next week')
    cases_next_week_target_dict['description'] = 'new descr'
    cases_next_week_target_dict['is_step_ahead'] = False
    del cases_next_week_target_dict['step_ahead_increment']
    season_peak_week_target_dict = find_dict(edit_config_dict['targets'], 'name', 'Season peak week')
    season_peak_week_target_dict['description'] = 'new descr 2'
    season_peak_week_target_dict['is_step_ahead'] = True
    season_peak_week_target_dict['step_ahead_increment'] = 2
    season_peak_week_target_dict['unit'] = 'biweek'
    exp_changes = [
        Change(ObjectType.TARGET, 'pct next week', ChangeType.OBJ_REMOVED, None, None),
        Change(ObjectType.TARGET, 'pct next week 2', ChangeType.OBJ_ADDED, None, pct_next_week_target_dict),
        Change(ObjectType.TARGET, 'cases next week', ChangeType.FIELD_REMOVED, 'step_ahead_increment', None),
        Change(ObjectType.TARGET, 'cases next week', ChangeType.FIELD_EDITED, 'description',
               cases_next_week_target_dict),
        Change(ObjectType.TARGET, 'cases next week', ChangeType.FIELD_EDITED, 'is_step_ahead',
               cases_next_week_target_dict),
        Change(ObjectType.TARGET, 'Season peak week', ChangeType.FIELD_ADDED, 'step_ahead_increment',
               season_peak_week_target_dict),
        Change(ObjectType.TARGET, 'Season peak week', ChangeType.FIELD_EDITED, 'description',
               season_peak_week_target_dict),
        Change(ObjectType.TARGET, 'Season peak week', ChangeType.FIELD_EDITED, 'is_step_ahead',
               season_peak_week_target_dict),
        Change(ObjectType.TARGET, 'Season peak week', ChangeType.FIELD_EDITED, 'unit', season_peak_week_target_dict),
    ]
    assert_changes_equal(exp_changes, project_config_diff(current_config_dict, edit_config_dict))

    # project targets: edit 'pct next week' 'type' (non-editable) and 'description' (editable) fields
    edit_config_dict = copy.deepcopy(current_config_dict)
    pct_next_week_target_dict = find_dict(edit_config_dict['targets'], 'name', 'pct next week')
    pct_next_week_target_dict['type'] = 'discrete'
    pct_next_week_target_dict['description'] = 'new descr'
    exp_changes = [
        Change(ObjectType.TARGET, 'pct next week', ChangeType.OBJ_REMOVED, None, None),
        Change(ObjectType.TARGET, 'pct next week', ChangeType.OBJ_ADDED, None, pct_next_week_target_dict),
        Change(ObjectType.TARGET, 'pct next week', ChangeType.FIELD_EDITED, 'description', pct_next_week_target_dict),
    ]
    assert_changes_equal(exp_changes, project_config_diff(current_config_dict, edit_config_dict))

    # project targets: edit 'cases next week': remove 'range' (non-editable)
    edit_config_dict = copy.deepcopy(current_config_dict)
    cases_next_week_target_dict = find_dict(edit_config_dict['targets'], 'name', 'cases next week')
    del cases_next_week_target_dict['range']
    exp_changes = [
        Change(ObjectType.TARGET, 'cases next week', ChangeType.OBJ_REMOVED, None, None),
        Change(ObjectType.TARGET, 'cases next week', ChangeType.OBJ_ADDED, None, cases_next_week_target_dict),
    ]
    assert_changes_equal(exp_changes, project_config_diff(current_config_dict, edit_config_dict))

    # project targets: edit 'season severity': edit 'cats' (non-editable)
    edit_config_dict = copy.deepcopy(current_config_dict)
    season_severity_target_dict = find_dict(edit_config_dict['targets'], 'name', 'season severity')
    season_severity_target_dict['cats'] = season_severity_target_dict['cats'] + ['cat 2']
    exp_changes = [
        Change(ObjectType.TARGET, 'season severity', ChangeType.OBJ_REMOVED, None, None),
        Change(ObjectType.TARGET, 'season severity', ChangeType.OBJ_ADDED, None, season_severity_target_dict),
    ]
    assert_changes_equal(exp_changes, project_config_diff(current_config_dict, edit_config_dict))