def query_scores_for_project(project, query, max_num_rows=MAX_NUM_QUERY_ROWS):
    """
    Top-level function for querying scores within project. Runs in the calling thread and therefore blocks.

    There is one column per ScoreValue BUT: all Scores are on one line. Thus, the row 'key' is the (fixed) first five
    columns:

        `ForecastModel.abbreviation | ForecastModel.name, TimeZero.timezero_date, season, Unit.name, Target.name`

    Followed on the same line by a variable number of ScoreValue.value columns, one for each Score. Score names are
    in the header. An example header and first few rows:

        model,          timezero,   season,    unit, target,         constant score, Absolute Error
        gam_lag1_tops3, 2017-04-23, 2017-2018, TH01, 1_biweek_ahead, 1,              <blank>
        gam_lag1_tops3, 2017-04-23, 2017-2018, TH01, 1_biweek_ahead, <blank>,        2
        gam_lag1_tops3, 2017-04-23, 2017-2018, TH01, 2_biweek_ahead, <blank>,        1
        gam_lag1_tops3, 2017-04-23, 2017-2018, TH01, 3_biweek_ahead, <blank>,        9
        gam_lag1_tops3, 2017-04-23, 2017-2018, TH01, 4_biweek_ahead, <blank>,        6
        gam_lag1_tops3, 2017-04-23, 2017-2018, TH01, 5_biweek_ahead, <blank>,        8
        gam_lag1_tops3, 2017-04-23, 2017-2018, TH02, 1_biweek_ahead, <blank>,        6
        gam_lag1_tops3, 2017-04-23, 2017-2018, TH02, 2_biweek_ahead, <blank>,        6
        gam_lag1_tops3, 2017-04-23, 2017-2018, TH02, 3_biweek_ahead, <blank>,        37
        gam_lag1_tops3, 2017-04-23, 2017-2018, TH02, 4_biweek_ahead, <blank>,        25
        gam_lag1_tops3, 2017-04-23, 2017-2018, TH02, 5_biweek_ahead, <blank>,        62

    `query` is documented at https://docs.zoltardata.com/, but briefly, like query_forecasts_for_project(), it is a
    dict that contains up to five keys, each of which is a list of strings:

    - 'models': optional list of ForecastModel.abbreviation strings
    - 'units': "" Unit.name strings
    - 'targets': "" Target.name strings
    - 'timezeros': "" TimeZero.timezero_date strings in YYYY_MM_DD_DATE_FORMAT
    - 'scores': optional list of score abbreviations as defined in SCORE_ABBREV_TO_NAME_AND_DESCR keys

    Notes:
    - `season` is each TimeZero's containing season_name, similar to Project.timezeros_in_season().
    - for the model column we use the model's abbreviation if it's not empty, otherwise we use its name
    - NB: we were using get_valid_filename() to ensure values are CSV-compliant, i.e., no commas, returns, tabs, etc.
      (a function that was as good as any), but we removed it to help performance in the loop
    - we use groupby to group row 'keys' so that all score values are together

    :param project: a Project
    :param query: a dict specifying the query parameters. see https://docs.zoltardata.com/ for documentation, and
        above for a summary. NB: assumes it has passed validation via `validate_scores_query()`
    :param max_num_rows: the number of rows at which this function raises a RuntimeError
    :return: an iterator (generator) of CSV rows, including the header
    """
    # validate query and set query defaults ("all in project") if necessary
    logger.debug(f"query_scores_for_project(): 1/5 validating query. query={query}, project={project}")
    error_messages, (model_ids, unit_ids, target_ids, timezero_ids, scores) = validate_scores_query(project, query)

    # set scores, translating Score abbreviations to objects, defaulting to all
    scores_qs = Score.objects.filter(abbreviation__in=scores).order_by('pk') if scores \
        else Score.objects.all().order_by('pk')

    # get Forecasts to be included, applying query's constraints
    forecast_ids = latest_forecast_ids_for_project(project, True, model_ids=model_ids, timezero_ids=timezero_ids)

    # write the header, which depends on which scores are being queried
    score_csv_header = SCORE_CSV_HEADER_PREFIX + [score.abbreviation for score in scores_qs]
    yield score_csv_header

    # do the query - sorted for groupby(). todo xx use IDs!
    logger.debug(f"query_scores_for_project(): 2/5 preparing to iterate: project={project}")
    forecast_model_id_to_obj = {forecast_model.pk: forecast_model for forecast_model in project.models.all()}
    timezero_id_to_obj = {timezero.pk: timezero for timezero in project.timezeros.all()}
    unit_id_to_obj = {unit.pk: unit for unit in project.units.all()}
    target_id_to_obj = {target.pk: target for target in project.targets.all()}
    timezero_to_season_name = project.timezero_to_season_name()

    # todo no unit_ids or target_ids -> do not pass '__in'
    if not unit_ids:
        unit_ids = project.units.all().values_list('id', flat=True)  # "" Units ""
    if not target_ids:
        target_ids = project.targets.all().values_list('id', flat=True)  # "" Targets ""

    logger.debug(f"query_scores_for_project(): 3/5 getting truth. project={project}")
    tz_unit_targ_pks_to_truth_vals = _tz_unit_targ_pks_to_truth_values(project)

    logger.debug(f"query_scores_for_project(): 4/5 iterating. project={project}")
    score_value_qs = ScoreValue.objects \
        .filter(score__id__in=list(scores_qs.values_list('id', flat=True)),
                forecast__id__in=list(forecast_ids),
                unit__id__in=list(unit_ids),
                target__id__in=list(target_ids)) \
        .order_by('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__id', 'target__id', 'score__id') \
        .values_list('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__id', 'target__id',
                     'score__id', 'value')
    num_rows = score_value_qs.count()
    if num_rows > max_num_rows:
        raise RuntimeError(f"number of rows exceeded maximum. num_rows={num_rows}, max_num_rows={max_num_rows}")

    num_warnings = 0
    for (forecast_model_id, time_zero_id, unit_id, target_id), score_id_value_grouper \
            in groupby(score_value_qs.iterator(), key=lambda _: (_[0], _[1], _[2], _[3])):
        # get truth. should be only one value
        true_value, error_string = _validate_truth(tz_unit_targ_pks_to_truth_vals, time_zero_id, unit_id, target_id)
        if error_string:
            num_warnings += 1
            continue  # skip this (forecast_model_id, time_zero_id, unit_id, target_id) combination's score row

        forecast_model = forecast_model_id_to_obj[forecast_model_id]
        time_zero = timezero_id_to_obj[time_zero_id]
        unit = unit_id_to_obj[unit_id]
        target = target_id_to_obj[target_id]
        # ex score_groups: [(1, 18, 1, 1, 1, 1.0), (1, 18, 1, 1, 2, 2.0)]  # multiple scores per group
        #                  [(1, 18, 1, 2, 2, 0.0)]                         # single score
        score_groups = list(score_id_value_grouper)
        score_id_to_value = {score_group[-2]: score_group[-1] for score_group in score_groups}
        score_values = [score_id_to_value[score.id] if score.id in score_id_to_value else None
                        for score in scores_qs]
        # while name and abbreviation are now both required to be non-empty, we leave the check here just in case:
        model_name = forecast_model.abbreviation if forecast_model.abbreviation else forecast_model.name
        yield [model_name, time_zero.timezero_date.strftime(YYYY_MM_DD_DATE_FORMAT),
               timezero_to_season_name[time_zero], unit.name, target.name, true_value] + score_values

    # print warning count
    logger.debug(f"query_scores_for_project(): 5/5 done. num_rows={num_rows}, num_warnings={num_warnings}, "
                 f"project={project}")
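
# --- Illustrative usage sketch (added for exposition, not part of the original module) ---
# A minimal example of streaming query_scores_for_project()'s generator into a local CSV file. The query
# contents (model and score abbreviations) and the output path are assumptions for illustration only.
def _example_write_scores_csv(project):
    """Minimal sketch: write the score CSV rows for `project` to a local file."""
    import csv

    query = {'models': ['gam_lag1_tops3'], 'scores': ['abs_error']}  # hypothetical abbreviations
    with open('/tmp/scores.csv', 'w', newline='') as csv_fp:
        # the generator yields the header row first, then one row per (model, timezero, unit, target) 'key'
        csv.writer(csv_fp).writerows(query_scores_for_project(project, query))
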
def _query_worker(job_pk, query_project_fcn):
    # imported here so that tests can patch via mock:
    from utils.cloud_file import upload_file


    # run the query
    job = get_object_or_404(Job, pk=job_pk)
    project = get_object_or_404(Project, pk=job.input_json['project_pk'])
    query = job.input_json['query']
    try:
        logger.debug(f"_query_worker(): 1/4 querying rows. query={query}. job={job}")
        # use a transaction to set the scope of the postgres `statement_timeout` parameter. statement_timeout raises
        # this error: django.db.utils.OperationalError ('canceling statement due to statement timeout'). Similarly,
        # idle_in_transaction_session_timeout raises django.db.utils.InternalError . todo does not consistently work!
        if connection.vendor == 'postgresql':
            with transaction.atomic(), connection.cursor() as cursor:
                cursor.execute(f"SET LOCAL statement_timeout = '{QUERY_FORECAST_STATEMENT_TIMEOUT}s';")
                cursor.execute(f"SET LOCAL idle_in_transaction_session_timeout = '{QUERY_FORECAST_STATEMENT_TIMEOUT}s';")
                rows = query_project_fcn(project, query)
        else:
            rows = query_project_fcn(project, query)
    except JobTimeoutException as jte:
        job.status = Job.TIMEOUT
        job.save()
        logger.error(f"_query_worker(): error: {jte!r}. job={job}")
        return
    except Exception as ex:
        job.status = Job.FAILED
        job.failure_message = f"_query_worker(): error: {ex!r}"
        job.save()
        logger.error(job.failure_message + f". job={job}")
        return

    # upload the file to cloud storage
    try:
        # we need a BytesIO for upload_file() (o/w it errors: "Unicode-objects must be encoded before hashing"), but
        # writerows() needs a StringIO (o/w "a bytes-like object is required, not 'str'" error), so we use
        # TextIOWrapper. BUT: https://docs.python.org/3.6/library/io.html#io.TextIOWrapper :
        #     Text I/O over a binary storage (such as a file) is significantly slower than binary I/O over the same
        #     storage, because it requires conversions between unicode and binary data using a character codec. This
        #     can become noticeable handling huge amounts of text data like large log files.
        # note: using a context is required o/w is closed and becomes unusable:
        # per https://stackoverflow.com/questions/59079354/how-to-write-utf-8-csv-into-bytesio-in-python3 :
        with io.BytesIO() as bytes_io:
            logger.debug(f"_query_worker(): 2/4 writing rows. job={job}")
            text_io_wrapper = io.TextIOWrapper(bytes_io, 'utf-8', newline='')
            rows = IterCounter(rows)
            csv.writer(text_io_wrapper).writerows(rows)
            text_io_wrapper.flush()
            bytes_io.seek(0)

            logger.debug(f"_query_worker(): 3/4 uploading file. job={job}")
            upload_file(job, bytes_io)  # might raise S3 exception
            job.output_json = {'num_rows': rows.count}
            job.status = Job.SUCCESS
            job.save()
            logger.debug(f"_query_worker(): 4/4 done. job={job}")
    except (BotoCoreError, Boto3Error, ClientError, ConnectionClosedError) as aws_exc:
        job.status = Job.FAILED
        job.failure_message = f"_query_worker(): error: {aws_exc!r}"
        job.save()
        logger.error(job.failure_message + f". job={job}")
    except Exception as ex:
        job.status = Job.FAILED
        job.failure_message = f"_query_worker(): error: {ex!r}"
        logger.error(job.failure_message + f". job={job}")
        job.save()
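
# --- Illustrative sketch (added for exposition, not part of the original module) ---
# _query_worker() relies on IterCounter to report how many rows were written without materializing the
# generator. A minimal counting-iterator wrapper could look like the following; the real IterCounter's
# implementation may differ.
class _ExampleIterCounter:
    """Wraps an iterable and counts items as they are consumed, exposing the total via `count`."""

    def __init__(self, iterable):
        self.count = 0
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        item = next(self._iterator)  # StopIteration propagates naturally when the source is exhausted
        self.count += 1
        return item
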
def query_truth_for_project(project, query, max_num_rows=MAX_NUM_QUERY_ROWS):
    """
    Top-level function for querying truth within project. Runs in the calling thread and therefore blocks.
    Returns rows in a Zoltar-specific CSV row format. The columns are defined in TRUTH_CSV_HEADER, as detailed at
    https://docs.zoltardata.com/fileformats/#truth-data-format-csv .

    `query` is documented at https://docs.zoltardata.com/, but briefly, it is a dict of up to four keys, three of
    which are lists of strings:

    - 'units': "" Unit.name strings
    - 'targets': "" Target.name strings
    - 'timezeros': "" TimeZero.timezero_date strings in YYYY_MM_DD_DATE_FORMAT

    The fourth key allows searching based on `Forecast.issued_at`:

    - 'as_of': passing a datetime string in the optional as_of field causes the query to return only those forecast
      versions whose issued_at is <= the as_of datetime (AKA timestamp)

    Note that _strings_ are passed to refer to object *contents*, not database IDs, which means validation will fail
    if the referred-to objects are not found. NB: If multiple objects are found with the same name then the program
    will arbitrarily choose one.

    NB: The returned response will contain only those rows that actually loaded from the original CSV file passed to
    Project.load_truth_data(), which will contain fewer rows if some were invalid. For that reason we change the
    filename to hopefully hint at what's going on.

    :param project: a Project
    :param query: a dict specifying the query parameters as described above. NB: assumes it has passed validation via
        `validate_truth_query()`
    :param max_num_rows: the number of rows at which this function raises a RuntimeError
    :return: an iterator (generator) of CSV rows, including the header
    """
    # validate query
    logger.debug(f"query_truth_for_project(): 1/3 validating query. query={query}, project={project}")
    error_messages, (unit_ids, target_ids, timezero_ids, as_of) = validate_truth_query(project, query)
    if error_messages:
        raise RuntimeError(f"invalid query. query={query}, errors={error_messages}")

    timezero_id_to_obj = {timezero.pk: timezero for timezero in project.timezeros.all()}
    unit_id_to_obj = {unit.pk: unit for unit in project.units.all()}
    target_id_to_obj = {target.pk: target for target in project.targets.all()}

    yield TRUTH_CSV_HEADER

    oracle_model = oracle_model_for_project(project)
    if not oracle_model:
        return

    # get the SQL then execute and iterate over resulting data
    model_ids = [oracle_model.pk]
    sql = _query_forecasts_sql_for_pred_class(None, model_ids, unit_ids, target_ids, timezero_ids, as_of, False)
    logger.debug(f"query_truth_for_project(): 2/3 executing sql. model_ids, unit_ids, target_ids, timezero_ids, "
                 f"as_of= {model_ids}, {unit_ids}, {target_ids}, {timezero_ids}, {as_of}")
    num_rows = 0
    with connection.cursor() as cursor:
        cursor.execute(sql, (project.pk,))
        for fm_id, tz_id, pred_class, unit_id, target_id, is_retract, pred_data in batched_rows(cursor):
            # we do not have to check is_retract b/c we pass `is_include_retract=False`, which skips retractions
            num_rows += 1
            if num_rows > max_num_rows:
                raise RuntimeError(f"number of rows exceeded maximum. num_rows={num_rows}, "
                                   f"max_num_rows={max_num_rows}")

            # counterintuitively must use json.loads per https://code.djangoproject.com/ticket/31991
            pred_data = json.loads(pred_data)
            tz_date = timezero_id_to_obj[tz_id].timezero_date.strftime(YYYY_MM_DD_DATE_FORMAT)
            yield [tz_date, unit_id_to_obj[unit_id].name, target_id_to_obj[target_id].name, pred_data['value']]

    # done
    logger.debug(f"query_truth_for_project(): 3/3 done. num_rows={num_rows}, query={query}, project={project}")
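
# --- Illustrative usage sketch (added for exposition, not part of the original module) ---
# Example of querying truth "as of" a point in time and writing it to a local CSV file. The unit and
# target names, the as_of datetime string, and the output path are assumptions for illustration only.
def _example_write_truth_csv(project):
    """Minimal sketch: write the truth CSV rows for `project` to a local file."""
    import csv

    query = {'units': ['TH01'], 'targets': ['1_biweek_ahead'],
             'as_of': '2020-07-01 12:00:00 UTC'}  # hypothetical as_of datetime string
    with open('/tmp/truth.csv', 'w', newline='') as csv_fp:
        # the first yielded row is TRUTH_CSV_HEADER, followed by one row per truth value
        csv.writer(csv_fp).writerows(query_truth_for_project(project, query))
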
def query_forecasts_for_project(project, query, max_num_rows=MAX_NUM_QUERY_ROWS):
    """
    Top-level function for querying forecasts within project. Runs in the calling thread and therefore blocks.
    Returns rows in a Zoltar-specific CSV row format. The columns are defined in FORECAST_CSV_HEADER. Note that the
    csv is 'sparse': not every row uses all columns, and unused ones are empty (''). However, the first four columns
    are always non-empty, i.e., every prediction has them.

    The 'class' of each row is named to be the same as Zoltar's utils.forecast.PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS
    variable. Column ordering is FORECAST_CSV_HEADER.

    `query` is documented at https://docs.zoltardata.com/, but briefly, it is a dict of up to six keys, five of which
    are lists of strings:

    - 'models': optional list of ForecastModel.abbreviation strings
    - 'units': "" Unit.name strings
    - 'targets': "" Target.name strings
    - 'timezeros': "" TimeZero.timezero_date strings in YYYY_MM_DD_DATE_FORMAT
    - 'types': optional list of type strings as defined in PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS.values()

    The sixth key allows searching based on `Forecast.issue_date`:

    - 'as_of': optional inclusive issue_date in YYYY_MM_DD_DATE_FORMAT to limit the search to. the default behavior
      if not passed is to use the newest forecast for each TimeZero

    Note that _strings_ are passed to refer to object *contents*, not database IDs, which means validation will fail
    if the referred-to objects are not found. NB: If multiple objects are found with the same name then the program
    will arbitrarily choose one.

    :param project: a Project
    :param query: a dict specifying the query parameters. see https://docs.zoltardata.com/ for documentation, and
        above for a summary. NB: assumes it has passed validation via `validate_forecasts_query()`
    :param max_num_rows: the number of rows at which this function raises a RuntimeError
    :return: an iterator (generator) of CSV rows, including the header
    """
    from utils.forecast import PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS  # avoid circular imports


    # validate query
    logger.debug(f"query_forecasts_for_project(): 1/4 validating query. query={query}, project={project}")
    error_messages, (model_ids, unit_ids, target_ids, timezero_ids, types) = validate_forecasts_query(project, query)

    # get which types to include
    is_include_bin = (not types) or (PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[BinDistribution] in types)
    is_include_named = (not types) or (PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[NamedDistribution] in types)
    is_include_point = (not types) or (PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[PointPrediction] in types)
    is_include_sample = (not types) or (PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[SampleDistribution] in types)
    is_include_quantile = (not types) or (PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[QuantileDistribution] in types)

    # get Forecasts to be included, applying query's constraints
    forecast_ids = latest_forecast_ids_for_project(project, True, model_ids=model_ids, timezero_ids=timezero_ids,
                                                   as_of=query.get('as_of', None))

    # create queries for each prediction type, but don't execute them yet. first check # rows and limit if necessary.
    # note that not all will be executed, depending on the 'types' key

    # todo no unit_ids or target_ids -> do not pass '__in'
    if not unit_ids:
        unit_ids = project.units.all().values_list('id', flat=True)  # "" Units ""
    if not target_ids:
        target_ids = project.targets.all().values_list('id', flat=True)  # "" Targets ""

    bin_qs = BinDistribution.objects.filter(forecast__id__in=list(forecast_ids),
                                            unit__id__in=list(unit_ids),
                                            target__id__in=list(target_ids)) \
        .values_list('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__name', 'target__name',
                     'prob', 'cat_i', 'cat_f', 'cat_t', 'cat_d', 'cat_b')
    named_qs = NamedDistribution.objects.filter(forecast__id__in=list(forecast_ids),
                                                unit__id__in=list(unit_ids),
                                                target__id__in=list(target_ids)) \
        .values_list('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__name', 'target__name',
                     'family', 'param1', 'param2', 'param3')
    point_qs = PointPrediction.objects.filter(forecast__id__in=list(forecast_ids),
                                              unit__id__in=list(unit_ids),
                                              target__id__in=list(target_ids)) \
        .values_list('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__name', 'target__name',
                     'value_i', 'value_f', 'value_t', 'value_d', 'value_b')
    sample_qs = SampleDistribution.objects.filter(forecast__id__in=list(forecast_ids),
                                                  unit__id__in=list(unit_ids),
                                                  target__id__in=list(target_ids)) \
        .values_list('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__name', 'target__name',
                     'sample_i', 'sample_f', 'sample_t', 'sample_d', 'sample_b')
    quantile_qs = QuantileDistribution.objects.filter(forecast__id__in=list(forecast_ids),
                                                      unit__id__in=list(unit_ids),
                                                      target__id__in=list(target_ids)) \
        .values_list('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__name', 'target__name',
                     'quantile', 'value_i', 'value_f', 'value_d')

    # count number of rows to query, and error if too many
    logger.debug(f"query_forecasts_for_project(): 2/4 getting counts. query={query}, project={project}")
    is_include_query_set_pred_types = [(is_include_bin, bin_qs, 'bin'),
                                       (is_include_named, named_qs, 'named'),
                                       (is_include_point, point_qs, 'point'),
                                       (is_include_sample, sample_qs, 'sample'),
                                       (is_include_quantile, quantile_qs, 'quantile')]
    pred_type_counts = []  # filled next. NB: we do not use a list comprehension b/c we want logging for each pred_type
    for idx, (is_include, query_set, pred_type) in enumerate(is_include_query_set_pred_types):
        if is_include:
            logger.debug(f"query_forecasts_for_project(): 2{string.ascii_letters[idx]}/4 getting counts: {pred_type!r}")
            pred_type_counts.append((pred_type, query_set.count()))
    num_rows = sum([_[1] for _ in pred_type_counts])
    logger.debug(f"query_forecasts_for_project(): 3/4 preparing to query. pred_type_counts={pred_type_counts}. total "
                 f"num_rows={num_rows}. query={query}, project={project}")
    if num_rows > max_num_rows:
        raise RuntimeError(f"number of rows exceeded maximum. num_rows={num_rows}, max_num_rows={max_num_rows}")

    # output rows for each Prediction subclass
    yield FORECAST_CSV_HEADER

    forecast_model_id_to_obj = {forecast_model.pk: forecast_model for forecast_model in project.models.all()}
    timezero_id_to_obj = {timezero.pk: timezero for timezero in project.timezeros.all()}
    timezero_to_season_name = project.timezero_to_season_name()

    # add BinDistributions
    if is_include_bin:
        logger.debug(f"query_forecasts_for_project(): 3a/4 getting BinDistributions")
        # class-specific columns all default to empty:
        value, cat, prob, sample, quantile, family, param1, param2, param3 = '', '', '', '', '', '', '', '', ''
        for forecast_model_id, timezero_id, unit_name, target_name, prob, cat_i, cat_f, cat_t, cat_d, cat_b in bin_qs:
            model_str, timezero_str, season, class_str = \
                _model_tz_season_class_strs(forecast_model_id_to_obj[forecast_model_id],
                                            timezero_id_to_obj[timezero_id], timezero_to_season_name, BinDistribution)
            cat = PointPrediction.first_non_none_value(cat_i, cat_f, cat_t, cat_d, cat_b)
            cat = cat.strftime(YYYY_MM_DD_DATE_FORMAT) if isinstance(cat, datetime.date) else cat
            yield [model_str, timezero_str, season, unit_name, target_name, class_str,
                   value, cat, prob, sample, quantile, family, param1, param2, param3]

    # add NamedDistributions
    if is_include_named:
        logger.debug(f"query_forecasts_for_project(): 3b/4 getting NamedDistributions")
        # class-specific columns all default to empty:
        value, cat, prob, sample, quantile, family, param1, param2, param3 = '', '', '', '', '', '', '', '', ''
        for forecast_model_id, timezero_id, unit_name, target_name, family, param1, param2, param3 in named_qs:
            model_str, timezero_str, season, class_str = \
                _model_tz_season_class_strs(forecast_model_id_to_obj[forecast_model_id],
                                            timezero_id_to_obj[timezero_id], timezero_to_season_name,
                                            NamedDistribution)
            family = NamedDistribution.FAMILY_CHOICE_TO_ABBREVIATION[family]
            yield [model_str, timezero_str, season, unit_name, target_name, class_str,
                   value, cat, prob, sample, quantile, family, param1, param2, param3]

    # add PointPredictions
    if is_include_point:
        logger.debug(f"query_forecasts_for_project(): 3c/4 getting PointPredictions")
        # class-specific columns all default to empty:
        value, cat, prob, sample, quantile, family, param1, param2, param3 = '', '', '', '', '', '', '', '', ''
        for forecast_model_id, timezero_id, unit_name, target_name, value_i, value_f, value_t, value_d, value_b \
                in point_qs:
            model_str, timezero_str, season, class_str = \
                _model_tz_season_class_strs(forecast_model_id_to_obj[forecast_model_id],
                                            timezero_id_to_obj[timezero_id], timezero_to_season_name, PointPrediction)
            value = PointPrediction.first_non_none_value(value_i, value_f, value_t, value_d, value_b)
            value = value.strftime(YYYY_MM_DD_DATE_FORMAT) if isinstance(value, datetime.date) else value
            yield [model_str, timezero_str, season, unit_name, target_name, class_str,
                   value, cat, prob, sample, quantile, family, param1, param2, param3]

    # add SampleDistributions
    if is_include_sample:
        logger.debug(f"query_forecasts_for_project(): 3d/4 getting SampleDistributions")
        # class-specific columns all default to empty:
        value, cat, prob, sample, quantile, family, param1, param2, param3 = '', '', '', '', '', '', '', '', ''
        for forecast_model_id, timezero_id, unit_name, target_name, \
                sample_i, sample_f, sample_t, sample_d, sample_b in sample_qs:
            model_str, timezero_str, season, class_str = \
                _model_tz_season_class_strs(forecast_model_id_to_obj[forecast_model_id],
                                            timezero_id_to_obj[timezero_id], timezero_to_season_name,
                                            SampleDistribution)
            sample = PointPrediction.first_non_none_value(sample_i, sample_f, sample_t, sample_d, sample_b)
            sample = sample.strftime(YYYY_MM_DD_DATE_FORMAT) if isinstance(sample, datetime.date) else sample
            yield [model_str, timezero_str, season, unit_name, target_name, class_str,
                   value, cat, prob, sample, quantile, family, param1, param2, param3]

    # add QuantileDistributions
    if is_include_quantile:
        logger.debug(f"query_forecasts_for_project(): 3e/4 getting QuantileDistributions")
        # class-specific columns all default to empty:
        value, cat, prob, sample, quantile, family, param1, param2, param3 = '', '', '', '', '', '', '', '', ''
        for forecast_model_id, timezero_id, unit_name, target_name, quantile, value_i, value_f, value_d in quantile_qs:
            model_str, timezero_str, season, class_str = \
                _model_tz_season_class_strs(forecast_model_id_to_obj[forecast_model_id],
                                            timezero_id_to_obj[timezero_id], timezero_to_season_name,
                                            QuantileDistribution)
            value = PointPrediction.first_non_none_value(value_i, value_f, None, value_d, None)
            value = value.strftime(YYYY_MM_DD_DATE_FORMAT) if isinstance(value, datetime.date) else value
            yield [model_str, timezero_str, season, unit_name, target_name, class_str,
                   value, cat, prob, sample, quantile, family, param1, param2, param3]

    # NB: we do not sort b/c it's expensive
    logger.debug(f"query_forecasts_for_project(): 4/4 done. num_rows={num_rows}, query={query}, project={project}")
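
# --- Illustrative usage sketch (added for exposition, not part of the original module) ---
# Example of restricting the forecast query to point predictions only and guarding against the
# max_num_rows RuntimeError. The model abbreviation and the 'point' type string are assumptions for
# illustration (type strings are defined by PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS.values()).
def _example_query_point_forecasts(project):
    """Minimal sketch: return the forecast CSV rows for point predictions, or None if the query is too large."""
    query = {'models': ['gam_lag1_tops3'], 'types': ['point']}  # hypothetical abbreviation and type string
    try:
        return list(query_forecasts_for_project(project, query, max_num_rows=10_000))
    except RuntimeError as rte:  # raised when the row count exceeds max_num_rows
        logger.warning(f"_example_query_point_forecasts(): query too large: {rte!r}")
        return None
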
def query_forecasts_for_project(project, query, max_num_rows=MAX_NUM_QUERY_ROWS):
    """
    Top-level function for querying forecasts within project. Runs in the calling thread and therefore blocks.
    Returns rows in a Zoltar-specific CSV row format. The columns are defined in FORECAST_CSV_HEADER. Note that the
    csv is 'sparse': not every row uses all columns, and unused ones are empty (''). However, the first four columns
    are always non-empty, i.e., every prediction has them.

    The 'class' of each row is named to be the same as Zoltar's utils.forecast.PRED_CLASS_INT_TO_NAME variable.
    Column ordering is FORECAST_CSV_HEADER.

    `query` is documented at https://docs.zoltardata.com/, but briefly, it is a dict of up to six keys, five of which
    are lists of strings. all are optional:

    - 'models': pass zero or more model abbreviations in the models field
    - 'units': pass zero or more unit names in the units field
    - 'targets': pass zero or more target names in the targets field
    - 'timezeros': pass zero or more timezero dates in YYYY_MM_DD_DATE_FORMAT format in the timezeros field
    - 'types': pass a list of string types in the types field. Choices are PRED_CLASS_INT_TO_NAME.values()

    The sixth key allows searching based on `Forecast.issued_at`:

    - 'as_of': passing a datetime string in the optional as_of field causes the query to return only those forecast
      versions whose issued_at is <= the as_of datetime (AKA timestamp)

    Note that _strings_ are passed to refer to object *contents*, not database IDs, which means validation will fail
    if the referred-to objects are not found. NB: If multiple objects are found with the same name then the program
    will arbitrarily choose one.

    :param project: a Project
    :param query: a dict specifying the query parameters as described above. NB: assumes it has passed validation via
        `validate_forecasts_query()`
    :param max_num_rows: the number of rows at which this function raises a RuntimeError
    :return: an iterator (generator) of CSV rows, including the header
    """
    logger.debug(f"query_forecasts_for_project(): 1/3 validating query. query={query}, project={project}")

    # validate query
    error_messages, (model_ids, unit_ids, target_ids, timezero_ids, type_ints, as_of) = \
        validate_forecasts_query(project, query)
    if error_messages:
        raise RuntimeError(f"invalid query. query={query}, errors={error_messages}")

    forecast_model_id_to_obj = {forecast_model.pk: forecast_model for forecast_model in project.models.all()}
    timezero_id_to_obj = {timezero.pk: timezero for timezero in project.timezeros.all()}
    unit_id_to_obj = {unit.pk: unit for unit in project.units.all()}
    target_id_to_obj = {target.pk: target for target in project.targets.all()}
    timezero_to_season_name = project.timezero_to_season_name()

    yield FORECAST_CSV_HEADER

    # get the SQL then execute and iterate over resulting data
    sql = _query_forecasts_sql_for_pred_class(type_ints, model_ids, unit_ids, target_ids, timezero_ids, as_of, True)
    logger.debug(f"query_forecasts_for_project(): 2/3 executing sql. type_ints, model_ids, unit_ids, target_ids, "
                 f"timezero_ids, as_of= {type_ints}, {model_ids}, {unit_ids}, {target_ids}, {timezero_ids}, {as_of}")
    num_rows = 0
    with connection.cursor() as cursor:
        cursor.execute(sql, (project.pk,))
        for fm_id, tz_id, pred_class, unit_id, target_id, is_retract, pred_data in batched_rows(cursor):
            # we do not have to check is_retract b/c we pass `is_include_retract=False`, which skips retractions
            num_rows += 1
            if num_rows > max_num_rows:
                raise RuntimeError(f"number of rows exceeded maximum. num_rows={num_rows}, "
                                   f"max_num_rows={max_num_rows}")

            # counterintuitively must use json.loads per https://code.djangoproject.com/ticket/31991
            pred_data = json.loads(pred_data)
            model_str, timezero_str, season, class_str = \
                _model_tz_season_class_strs(forecast_model_id_to_obj[fm_id], timezero_id_to_obj[tz_id],
                                            timezero_to_season_name, pred_class)
            value, cat, prob, sample, quantile, family, param1, param2, param3 = '', '', '', '', '', '', '', '', ''
            if pred_class == PredictionElement.BIN_CLASS:
                for cat, prob in zip(pred_data['cat'], pred_data['prob']):
                    yield [model_str, timezero_str, season, unit_id_to_obj[unit_id].name,
                           target_id_to_obj[target_id].name, class_str,
                           value, cat, prob, sample, quantile, family, param1, param2, param3]
            elif pred_class == PredictionElement.NAMED_CLASS:
                family = pred_data['family']
                param1 = pred_data.get('param1', '')
                param2 = pred_data.get('param2', '')
                param3 = pred_data.get('param3', '')
                yield [model_str, timezero_str, season, unit_id_to_obj[unit_id].name,
                       target_id_to_obj[target_id].name, class_str,
                       value, cat, prob, sample, quantile, family, param1, param2, param3]
            elif pred_class == PredictionElement.POINT_CLASS:
                value = pred_data['value']
                yield [model_str, timezero_str, season, unit_id_to_obj[unit_id].name,
                       target_id_to_obj[target_id].name, class_str,
                       value, cat, prob, sample, quantile, family, param1, param2, param3]
            elif pred_class == PredictionElement.QUANTILE_CLASS:
                for quantile, value in zip(pred_data['quantile'], pred_data['value']):
                    yield [model_str, timezero_str, season, unit_id_to_obj[unit_id].name,
                           target_id_to_obj[target_id].name, class_str,
                           value, cat, prob, sample, quantile, family, param1, param2, param3]
            elif pred_class == PredictionElement.SAMPLE_CLASS:
                for sample in pred_data['sample']:
                    yield [model_str, timezero_str, season, unit_id_to_obj[unit_id].name,
                           target_id_to_obj[target_id].name, class_str,
                           value, cat, prob, sample, quantile, family, param1, param2, param3]

    # done
    logger.debug(f"query_forecasts_for_project(): 3/3 done. num_rows={num_rows}, query={query}, project={project}")
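
# --- Illustrative example (added for exposition, not part of the original module) ---
# How one bin prediction element fans out into multiple CSV rows in the loop above: with pred_data
# {'cat': [0, 2, 50], 'prob': [0.0, 0.1, 0.9]} three rows are yielded, one per (cat, prob) pair, with the
# other class-specific columns left as ''. All values below are made up for illustration.
def _example_bin_rows():
    """Minimal sketch: build the three CSV rows that a single hypothetical bin prediction element produces."""
    pred_data = {'cat': [0, 2, 50], 'prob': [0.0, 0.1, 0.9]}  # hypothetical bin prediction data
    rows = []
    for cat, prob in zip(pred_data['cat'], pred_data['prob']):
        # model, timezero, season, unit, target, class, value, cat, prob, sample, quantile, family, param1-3
        rows.append(['gam_lag1_tops3', '2017-04-23', '2017-2018', 'TH01', '1_biweek_ahead', 'bin',
                     '', cat, prob, '', '', '', '', '', ''])
    return rows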