# Example #1
def query_scores_for_project(project, query, max_num_rows=MAX_NUM_QUERY_ROWS):
    """
    Top-level function for querying scores within project. Runs in the calling thread and therefore blocks.

    There is one column per Score, and all of a row key's score values are on one line. Thus, the row 'key' is the
    (fixed) first five columns:

        `ForecastModel.abbreviation | ForecastModel.name , TimeZero.timezero_date, season, Unit.name, Target.name`

    Followed on the same line by a variable number of ScoreValue.value columns, one for each Score. Score names are in
    the header. An example header and first few rows:

        model,           timezero,    season,     unit,  target,          constant score,  Absolute Error
        gam_lag1_tops3,  2017-04-23,  2017-2018,  TH01,  1_biweek_ahead,  1,               <blank>
        gam_lag1_tops3,  2017-04-23,  2017-2018,  TH01,  1_biweek_ahead,  <blank>,         2
        gam_lag1_tops3,  2017-04-23,  2017-2018,  TH01,  2_biweek_ahead,  <blank>,         1
        gam_lag1_tops3,  2017-04-23,  2017-2018,  TH01,  3_biweek_ahead,  <blank>,         9
        gam_lag1_tops3,  2017-04-23,  2017-2018,  TH01,  4_biweek_ahead,  <blank>,         6
        gam_lag1_tops3,  2017-04-23,  2017-2018,  TH01,  5_biweek_ahead,  <blank>,         8
        gam_lag1_tops3,  2017-04-23,  2017-2018,  TH02,  1_biweek_ahead,  <blank>,         6
        gam_lag1_tops3,  2017-04-23,  2017-2018,  TH02,  2_biweek_ahead,  <blank>,         6
        gam_lag1_tops3,  2017-04-23,  2017-2018,  TH02,  3_biweek_ahead,  <blank>,         37
        gam_lag1_tops3,  2017-04-23,  2017-2018,  TH02,  4_biweek_ahead,  <blank>,         25
        gam_lag1_tops3,  2017-04-23,  2017-2018,  TH02,  5_biweek_ahead,  <blank>,         62

    `query` is documented at https://docs.zoltardata.com/, but briefly, like query_forecasts_for_project(), it is a dict
    that contains up to five keys, each of which is a list of strings:

    - 'models': optional list of ForecastModel.abbreviation strings
    - 'units': "" Unit.name strings
    - 'targets': "" Target.name strings
    - 'timezeros': "" TimeZero.timezero_date strings in YYYY_MM_DD_DATE_FORMAT
    - 'scores': optional list of score abbreviations as defined in SCORE_ABBREV_TO_NAME_AND_DESCR keys

    Notes:
    - `season` is each TimeZero's containing season_name, similar to Project.timezeros_in_season().
    - for the model column we use the model's abbreviation if it's not empty, otherwise we use its name
    - NB: we previously used get_valid_filename() to ensure values were CSV-compliant (no commas, returns, tabs,
      etc.), but we removed it to speed up the loop
    - we use groupby to group row 'keys' so that all of a key's score values land on one line (see the standalone
      sketch after this function)

    :param project: a Project
    :param query: a dict specifying the query parameters. see https://docs.zoltardata.com/ for documentation, and above
        for a summary. NB: assumes it has passed validation via `validate_scores_query()`
    :param max_num_rows: the number of rows at which this function raises a RuntimeError
    :return: a generator of CSV rows, including the header
    """
    # validate query and set query defaults ("all in project") if necessary
    logger.debug(
        f"query_scores_for_project(): 1/5 validating query. query={query}, project={project}"
    )
    error_messages, (model_ids, unit_ids, target_ids, timezero_ids,
                     scores) = validate_scores_query(project, query)
    if error_messages:
        raise RuntimeError(
            f"invalid query. query={query}, errors={error_messages}")

    # set scores, translating Score abbreviations to objects, defaulting to all
    scores_qs = Score.objects.filter(abbreviation__in=scores).order_by('pk') if scores \
        else Score.objects.all().order_by('pk')

    # get Forecasts to be included, applying query's constraints
    forecast_ids = latest_forecast_ids_for_project(project,
                                                   True,
                                                   model_ids=model_ids,
                                                   timezero_ids=timezero_ids)

    # write the header, which depends on which scores are being queried
    score_csv_header = SCORE_CSV_HEADER_PREFIX + [
        score.abbreviation for score in scores_qs
    ]
    yield score_csv_header

    # do the query - sorted for groupby(). todo: use IDs!
    logger.debug(
        f"query_scores_for_project(): 2/5 preparing to iterate: project={project}"
    )
    forecast_model_id_to_obj = {
        forecast_model.pk: forecast_model
        for forecast_model in project.models.all()
    }
    timezero_id_to_obj = {
        timezero.pk: timezero
        for timezero in project.timezeros.all()
    }
    unit_id_to_obj = {unit.pk: unit for unit in project.units.all()}
    target_id_to_obj = {target.pk: target for target in project.targets.all()}
    timezero_to_season_name = project.timezero_to_season_name()

    # todo no unit_ids or target_ids -> do not pass '__in'
    if not unit_ids:
        unit_ids = project.units.all().values_list('id', flat=True)  # default to all Units
    if not target_ids:
        target_ids = project.targets.all().values_list('id', flat=True)  # default to all Targets

    logger.debug(
        f"query_scores_for_project(): 3/5 getting truth. project={project}")
    tz_unit_targ_pks_to_truth_vals = _tz_unit_targ_pks_to_truth_values(project)

    logger.debug(
        f"query_scores_for_project(): 4/5 iterating. project={project}")
    score_value_qs = ScoreValue.objects \
        .filter(score__id__in=list(scores_qs.values_list('id', flat=True)),
                forecast__id__in=list(forecast_ids),
                unit__id__in=list(unit_ids),
                target__id__in=list(target_ids)) \
        .order_by('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__id', 'target__id', 'score__id') \
        .values_list('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__id', 'target__id',
                     'score__id', 'value')

    num_rows = score_value_qs.count()
    if num_rows > max_num_rows:
        raise RuntimeError(
            f"number of rows exceeded maximum. num_rows={num_rows}, max_num_rows={max_num_rows}"
        )

    num_warnings = 0
    for (forecast_model_id, time_zero_id, unit_id, target_id), score_id_value_grouper \
            in groupby(score_value_qs.iterator(), key=lambda _: (_[0], _[1], _[2], _[3])):
        # get truth. should be only one value
        true_value, error_string = _validate_truth(
            tz_unit_targ_pks_to_truth_vals, time_zero_id, unit_id, target_id)
        if error_string:
            num_warnings += 1
            continue  # skip this (forecast_model_id, time_zero_id, unit_id, target_id) combination's score row

        forecast_model = forecast_model_id_to_obj[forecast_model_id]
        time_zero = timezero_id_to_obj[time_zero_id]
        unit = unit_id_to_obj[unit_id]
        target = target_id_to_obj[target_id]
        # ex score_groups: [(1, 18, 1, 1, 1, 1.0), (1, 18, 1, 1, 2, 2.0)]  # multiple scores per group
        #                  [(1, 18, 1, 2, 2, 0.0)]                         # single score
        score_groups = list(score_id_value_grouper)
        score_id_to_value = {
            score_group[-2]: score_group[-1]
            for score_group in score_groups
        }
        score_values = [
            score_id_to_value[score.id]
            if score.id in score_id_to_value else None for score in scores_qs
        ]

        # while name and abbreviation are now both required to be non-empty, we leave the check here just in case:
        model_name = forecast_model.abbreviation if forecast_model.abbreviation else forecast_model.name
        yield [
            model_name,
            time_zero.timezero_date.strftime(YYYY_MM_DD_DATE_FORMAT),
            timezero_to_season_name[time_zero], unit.name, target.name,
            true_value
        ] + score_values

    # print warning count
    logger.debug(
        f"query_scores_for_project(): 5/5 done. num_rows={num_rows}, num_warnings={num_warnings}, "
        f"project={project}")
# Example #2
def _query_worker(job_pk, query_project_fcn):
    # imported here so that tests can patch via mock:
    from utils.cloud_file import upload_file

    # run the query
    job = get_object_or_404(Job, pk=job_pk)
    project = get_object_or_404(Project, pk=job.input_json['project_pk'])
    query = job.input_json['query']
    try:
        logger.debug(
            f"_query_worker(): 1/4 querying rows. query={query}. job={job}")
        # use a transaction to set the scope of the postgres `statement_timeout` parameter. statement_timeout raises
        # this error: django.db.utils.OperationalError ('canceling statement due to statement timeout'). Similarly,
        # idle_in_transaction_session_timeout raises django.db.utils.InternalError . todo does not consistently work!
        if connection.vendor == 'postgresql':
            with transaction.atomic(), connection.cursor() as cursor:
                cursor.execute(
                    f"SET LOCAL statement_timeout = '{QUERY_FORECAST_STATEMENT_TIMEOUT}s';"
                )
                cursor.execute(
                    f"SET LOCAL idle_in_transaction_session_timeout = '{QUERY_FORECAST_STATEMENT_TIMEOUT}s';"
                )
                rows = query_project_fcn(project, query)
        else:
            rows = query_project_fcn(project, query)
    except JobTimeoutException as jte:
        job.status = Job.TIMEOUT
        job.save()
        logger.error(f"_query_worker(): error: {jte!r}. job={job}")
        return
    except Exception as ex:
        job.status = Job.FAILED
        job.failure_message = f"_query_worker(): error: {ex!r}"
        job.save()
        logger.error(job.failure_message + f". job={job}")
        return

    # upload the file to cloud storage
    try:
        # we need a BytesIO for upload_file() (o/w it errors: "Unicode-objects must be encoded before hashing"), but
        # writerows() needs a StringIO (o/w "a bytes-like object is required, not 'str'" error), so we use
        # TextIOWrapper. BUT: https://docs.python.org/3.6/library/io.html#io.TextIOWrapper :
        #     Text I/O over a binary storage (such as a file) is significantly slower than binary I/O over the same
        #     storage, because it requires conversions between unicode and binary data using a character codec. This can
        #     become noticeable handling huge amounts of text data like large log files.

        # note: a context manager is required, o/w the buffer is closed and becomes unusable, per
        # https://stackoverflow.com/questions/59079354/how-to-write-utf-8-csv-into-bytesio-in-python3 (a standalone
        # sketch of this pattern follows this function):
        with io.BytesIO() as bytes_io:
            logger.debug(f"_query_worker(): 2/4 writing rows. job={job}")
            text_io_wrapper = io.TextIOWrapper(bytes_io, 'utf-8', newline='')
            rows = IterCounter(rows)
            csv.writer(text_io_wrapper).writerows(rows)
            text_io_wrapper.flush()
            bytes_io.seek(0)

            logger.debug(f"_query_worker(): 3/4 uploading file. job={job}")
            upload_file(job, bytes_io)  # might raise S3 exception
            job.output_json = {'num_rows': rows.count}
            job.status = Job.SUCCESS
            job.save()
            logger.debug(f"_query_worker(): 4/4 done. job={job}")
    except (BotoCoreError, Boto3Error, ClientError,
            ConnectionClosedError) as aws_exc:
        job.status = Job.FAILED
        job.failure_message = f"_query_worker(): error: {aws_exc!r}"
        job.save()
        logger.error(job.failure_message + f". job={job}")
    except Exception as ex:
        job.status = Job.FAILED
        job.failure_message = f"_query_worker(): error: {ex!r}"
        logger.error(job.failure_message + f". job={job}")
        job.save()
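
# A standalone sketch of the BytesIO + TextIOWrapper pattern used above: csv.writer() needs a text
# stream while upload_file() needs a binary one, so we layer a utf-8 text wrapper over the binary
# buffer, flush it, and rewind. This is a hypothetical helper, not part of the Zoltar codebase.
def _rows_to_csv_bytes_sketch(rows):
    import csv
    import io

    with io.BytesIO() as bytes_io:
        text_io_wrapper = io.TextIOWrapper(bytes_io, 'utf-8', newline='')
        csv.writer(text_io_wrapper).writerows(rows)
        text_io_wrapper.flush()  # push buffered text through the codec into bytes_io
        bytes_io.seek(0)
        return bytes_io.getvalue()  # e.g., b'model,value\r\ngam_lag1_tops3,1.0\r\n'
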
def query_truth_for_project(project, query, max_num_rows=MAX_NUM_QUERY_ROWS):
    """
    Top-level function for querying truth within project. Runs in the calling thread and therefore blocks.
    Yields rows in a Zoltar-specific CSV row format. The columns are defined in TRUTH_CSV_HEADER, as detailed
    at https://docs.zoltardata.com/fileformats/#truth-data-format-csv .

    `query` is documented at https://docs.zoltardata.com/, but briefly, it is a dict of up to four keys, three of which
    are lists of strings:

    - 'units': optional list of Unit.name strings
    - 'targets': optional list of Target.name strings
    - 'timezeros': optional list of TimeZero.timezero_date strings in YYYY_MM_DD_DATE_FORMAT

    The fourth key allows searching based on `Forecast.issued_at`:
    - 'as_of': Passing a datetime string in the optional as_of field causes the query to return only those forecast
        versions whose issued_at is <= the as_of datetime (AKA timestamp).

    Note that _strings_ are passed to refer to object *contents*, not database IDs, which means validation will fail if
    the referred-to objects are not found. NB: If multiple objects are found with the same name then the program will
    arbitrarily choose one.

    NB: The returned response will contain only those rows that actually loaded from the original CSV file passed
    to Project.load_truth_data(), so it will contain fewer rows than that file if some were invalid. For that reason
    we change the filename to hint at what's going on.

    :param project: a Project
    :param query: a dict specifying the query parameters as described above. NB: assumes it has passed validation via
        `validate_truth_query()`
    :param max_num_rows: the number of rows at which this function raises a RuntimeError
    :return: a generator of CSV rows, including the header
    """
    # validate query
    logger.debug(
        f"query_truth_for_project(): 1/3 validating query. query={query}, project={project}"
    )
    error_messages, (unit_ids, target_ids, timezero_ids,
                     as_of) = validate_truth_query(project, query)
    if error_messages:
        raise RuntimeError(
            f"invalid query. query={query}, errors={error_messages}")

    timezero_id_to_obj = {
        timezero.pk: timezero
        for timezero in project.timezeros.all()
    }
    unit_id_to_obj = {unit.pk: unit for unit in project.units.all()}
    target_id_to_obj = {target.pk: target for target in project.targets.all()}

    yield TRUTH_CSV_HEADER

    oracle_model = oracle_model_for_project(project)
    if not oracle_model:
        return

    # get the SQL then execute and iterate over resulting data
    model_ids = [oracle_model.pk]
    sql = _query_forecasts_sql_for_pred_class(None, model_ids, unit_ids,
                                              target_ids, timezero_ids, as_of,
                                              False)
    logger.debug(
        f"query_truth_for_project(): 2/3 executing sql. model_ids, unit_ids, target_ids, timezero_ids, "
        f"as_of= {model_ids}, {unit_ids}, {target_ids}, {timezero_ids}, {as_of}"
    )
    num_rows = 0
    with connection.cursor() as cursor:
        cursor.execute(sql, (project.pk, ))
        for fm_id, tz_id, pred_class, unit_id, target_id, is_retract, pred_data in batched_rows(
                cursor):
            # we do not have to check is_retract b/c we pass `is_include_retract=False`, which skips retractions
            num_rows += 1
            if num_rows > max_num_rows:
                raise RuntimeError(
                    f"number of rows exceeded maximum. num_rows={num_rows}, "
                    f"max_num_rows={max_num_rows}")

            # counterintuitively must use json.loads per https://code.djangoproject.com/ticket/31991
            pred_data = json.loads(pred_data)
            tz_date = timezero_id_to_obj[tz_id].timezero_date.strftime(
                YYYY_MM_DD_DATE_FORMAT)
            yield [
                tz_date, unit_id_to_obj[unit_id].name,
                target_id_to_obj[target_id].name, pred_data['value']
            ]

    # done
    logger.debug(
        f"query_truth_for_project(): 3/3 done. num_rows={num_rows}, query={query}, project={project}"
    )
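
# A hypothetical usage sketch: stream query_truth_for_project()'s generator straight into a CSV
# file. `project` is assumed to be a Project instance; the query values here are placeholders.
def _write_truth_csv_sketch(project, csv_path):
    import csv

    query = {'units': ['TH01'], 'timezeros': ['2017-04-23']}
    with open(csv_path, 'w', newline='') as csv_fp:
        csv.writer(csv_fp).writerows(query_truth_for_project(project, query))
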
# Example #4
def query_forecasts_for_project(project,
                                query,
                                max_num_rows=MAX_NUM_QUERY_ROWS):
    """
    Top-level function for querying forecasts within project. Runs in the calling thread and therefore blocks.

    Yields rows in a Zoltar-specific CSV row format. The columns are defined in FORECAST_CSV_HEADER. Note
    that the csv is 'sparse': not every row uses all columns, and unused ones are empty (''). However, the first four
    columns are always non-empty, i.e., every prediction has them.

    The 'class' of each row is named the same as Zoltar's utils.forecast.PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS
    variable. Column ordering is FORECAST_CSV_HEADER.

    `query` is documented at https://docs.zoltardata.com/, but briefly, it is a dict of up to six keys, five of which
    are lists of strings:

    - 'models': optional list of ForecastModel.abbreviation strings
    - 'units': "" Unit.name strings
    - 'targets': "" Target.name strings
    - 'timezeros': "" TimeZero.timezero_date strings in YYYY_MM_DD_DATE_FORMAT
    - 'types': optional list of type strings as defined in PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS.values()

    The sixth key allows searching based on `Forecast.issue_date`:
    - 'as_of': optional inclusive issue_date in YYYY_MM_DD_DATE_FORMAT to limit the search to. The default behavior if
               not passed is to use the newest forecast for each TimeZero.

    Note that _strings_ are passed to refer to object *contents*, not database IDs, which means validation will fail if
    the referred-to objects are not found. NB: If multiple objects are found with the same name then the program will
    arbitrarily choose one.

    :param project: a Project
    :param query: a dict specifying the query parameters. see https://docs.zoltardata.com/ for documentation, and above
        for a summary. NB: assumes it has passed validation via `validate_forecasts_query()`
    :param max_num_rows: the number of rows at which this function raises a RuntimeError
    :return: a generator of CSV rows, including the header
    """
    from utils.forecast import PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS  # avoid circular imports

    # validate query
    logger.debug(
        f"query_forecasts_for_project(): 1/4 validating query. query={query}, project={project}"
    )
    error_messages, (model_ids, unit_ids, target_ids, timezero_ids,
                     types) = validate_forecasts_query(project, query)
    if error_messages:
        raise RuntimeError(
            f"invalid query. query={query}, errors={error_messages}")

    # get which types to include
    is_include_bin = (not types) or (
        PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[BinDistribution] in types)
    is_include_named = (not types) or (
        PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[NamedDistribution] in types)
    is_include_point = (not types) or (
        PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[PointPrediction] in types)
    is_include_sample = (not types) or (
        PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[SampleDistribution] in types)
    is_include_quantile = (not types) or (
        PREDICTION_CLASS_TO_JSON_IO_DICT_CLASS[QuantileDistribution] in types)

    # get Forecasts to be included, applying query's constraints
    forecast_ids = latest_forecast_ids_for_project(project,
                                                   True,
                                                   model_ids=model_ids,
                                                   timezero_ids=timezero_ids,
                                                   as_of=query.get(
                                                       'as_of', None))

    # create queries for each prediction type, but don't execute them yet. first check # rows and limit if necessary.
    # note that not all will be executed, depending on the 'types' key

    # todo no unit_ids or target_ids -> do not pass '__in'
    if not unit_ids:
        unit_ids = project.units.all().values_list('id', flat=True)  # default to all Units
    if not target_ids:
        target_ids = project.targets.all().values_list('id', flat=True)  # default to all Targets

    bin_qs = BinDistribution.objects.filter(forecast__id__in=list(forecast_ids),
                                            unit__id__in=list(unit_ids),
                                            target__id__in=list(target_ids)) \
        .values_list('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__name', 'target__name',
                     'prob', 'cat_i', 'cat_f', 'cat_t', 'cat_d', 'cat_b')
    named_qs = NamedDistribution.objects.filter(forecast__id__in=list(forecast_ids),
                                                unit__id__in=list(unit_ids),
                                                target__id__in=list(target_ids)) \
        .values_list('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__name', 'target__name',
                     'family', 'param1', 'param2', 'param3')
    point_qs = PointPrediction.objects.filter(forecast__id__in=list(forecast_ids),
                                              unit__id__in=list(unit_ids),
                                              target__id__in=list(target_ids)) \
        .values_list('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__name', 'target__name',
                     'value_i', 'value_f', 'value_t', 'value_d', 'value_b')
    sample_qs = SampleDistribution.objects.filter(forecast__id__in=list(forecast_ids),
                                                  unit__id__in=list(unit_ids),
                                                  target__id__in=list(target_ids)) \
        .values_list('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__name', 'target__name',
                     'sample_i', 'sample_f', 'sample_t', 'sample_d', 'sample_b')
    quantile_qs = QuantileDistribution.objects.filter(forecast__id__in=list(forecast_ids),
                                                      unit__id__in=list(unit_ids),
                                                      target__id__in=list(target_ids)) \
        .values_list('forecast__forecast_model__id', 'forecast__time_zero__id', 'unit__name', 'target__name',
                     'quantile', 'value_i', 'value_f', 'value_d')

    # count number of rows to query, and error if too many
    logger.debug(
        f"query_forecasts_for_project(): 2/4 getting counts. query={query}, project={project}"
    )
    is_include_query_set_pred_types = [
        (is_include_bin, bin_qs, 'bin'), (is_include_named, named_qs, 'named'),
        (is_include_point, point_qs, 'point'),
        (is_include_sample, sample_qs, 'sample'),
        (is_include_quantile, quantile_qs, 'quantile')
    ]

    # filled next. NB: we do not use a list comprehension b/c we want logging for each pred_type
    pred_type_counts = []
    for idx, (is_include, query_set,
              pred_type) in enumerate(is_include_query_set_pred_types):
        if is_include:
            logger.debug(
                f"query_forecasts_for_project(): 2{string.ascii_letters[idx]}/4 getting counts: {pred_type!r}"
            )
            pred_type_counts.append((pred_type, query_set.count()))

    num_rows = sum([_[1] for _ in pred_type_counts])
    logger.debug(
        f"query_forecasts_for_project(): 3/4 preparing to query. pred_type_counts={pred_type_counts}. total "
        f"num_rows={num_rows}. query={query}, project={project}")
    if num_rows > max_num_rows:
        raise RuntimeError(
            f"number of rows exceeded maximum. num_rows={num_rows}, max_num_rows={max_num_rows}"
        )

    # output rows for each Prediction subclass
    yield FORECAST_CSV_HEADER

    forecast_model_id_to_obj = {
        forecast_model.pk: forecast_model
        for forecast_model in project.models.all()
    }
    timezero_id_to_obj = {
        timezero.pk: timezero
        for timezero in project.timezeros.all()
    }
    timezero_to_season_name = project.timezero_to_season_name()

    # add BinDistributions
    if is_include_bin:
        logger.debug(
            f"query_forecasts_for_project(): 3a/4 getting BinDistributions")
        # class-specific columns all default to empty:
        value, cat, prob, sample, quantile, family, param1, param2, param3 = '', '', '', '', '', '', '', '', ''
        for forecast_model_id, timezero_id, unit_name, target_name, prob, cat_i, cat_f, cat_t, cat_d, cat_b in bin_qs:
            model_str, timezero_str, season, class_str = _model_tz_season_class_strs(
                forecast_model_id_to_obj[forecast_model_id],
                timezero_id_to_obj[timezero_id], timezero_to_season_name,
                BinDistribution)
            cat = PointPrediction.first_non_none_value(cat_i, cat_f, cat_t,
                                                       cat_d, cat_b)
            cat = cat.strftime(YYYY_MM_DD_DATE_FORMAT) if isinstance(
                cat, datetime.date) else cat
            yield [
                model_str, timezero_str, season, unit_name, target_name,
                class_str, value, cat, prob, sample, quantile, family, param1,
                param2, param3
            ]

    # add NamedDistributions
    if is_include_named:
        logger.debug(
            f"query_forecasts_for_project(): 3b/4 getting NamedDistributions")
        # class-specific columns all default to empty:
        value, cat, prob, sample, quantile, family, param1, param2, param3 = '', '', '', '', '', '', '', '', ''
        for forecast_model_id, timezero_id, unit_name, target_name, family, param1, param2, param3 in named_qs:
            model_str, timezero_str, season, class_str = _model_tz_season_class_strs(
                forecast_model_id_to_obj[forecast_model_id],
                timezero_id_to_obj[timezero_id], timezero_to_season_name,
                NamedDistribution)
            family = NamedDistribution.FAMILY_CHOICE_TO_ABBREVIATION[family]
            yield [
                model_str, timezero_str, season, unit_name, target_name,
                class_str, value, cat, prob, sample, quantile, family, param1,
                param2, param3
            ]

    # add PointPredictions
    if is_include_point:
        logger.debug(
            f"query_forecasts_for_project(): 3c/4 getting PointPredictions")
        # class-specific columns all default to empty:
        value, cat, prob, sample, quantile, family, param1, param2, param3 = '', '', '', '', '', '', '', '', ''
        for forecast_model_id, timezero_id, unit_name, target_name, value_i, value_f, value_t, value_d, value_b \
                in point_qs:
            model_str, timezero_str, season, class_str = _model_tz_season_class_strs(
                forecast_model_id_to_obj[forecast_model_id],
                timezero_id_to_obj[timezero_id], timezero_to_season_name,
                PointPrediction)
            value = PointPrediction.first_non_none_value(
                value_i, value_f, value_t, value_d, value_b)
            value = value.strftime(YYYY_MM_DD_DATE_FORMAT) if isinstance(
                value, datetime.date) else value
            yield [
                model_str, timezero_str, season, unit_name, target_name,
                class_str, value, cat, prob, sample, quantile, family, param1,
                param2, param3
            ]

    # add SampleDistribution
    if is_include_sample:
        logger.debug(
            f"query_forecasts_for_project(): 3d/4 getting SampleDistributions")
        # class-specific columns all default to empty:
        value, cat, prob, sample, quantile, family, param1, param2, param3 = '', '', '', '', '', '', '', '', ''
        for forecast_model_id, timezero_id, unit_name, target_name, \
            sample_i, sample_f, sample_t, sample_d, sample_b in sample_qs:
            model_str, timezero_str, season, class_str = _model_tz_season_class_strs(
                forecast_model_id_to_obj[forecast_model_id],
                timezero_id_to_obj[timezero_id], timezero_to_season_name,
                SampleDistribution)
            sample = PointPrediction.first_non_none_value(
                sample_i, sample_f, sample_t, sample_d, sample_b)
            sample = sample.strftime(YYYY_MM_DD_DATE_FORMAT) if isinstance(
                sample, datetime.date) else sample
            yield [
                model_str, timezero_str, season, unit_name, target_name,
                class_str, value, cat, prob, sample, quantile, family, param1,
                param2, param3
            ]

    # add QuantileDistribution
    if is_include_quantile:
        logger.debug(
            f"query_forecasts_for_project(): 3e/4 getting QuantileDistributions"
        )
        # class-specific columns all default to empty:
        value, cat, prob, sample, quantile, family, param1, param2, param3 = '', '', '', '', '', '', '', '', ''
        for forecast_model_id, timezero_id, unit_name, target_name, quantile, value_i, value_f, value_d in quantile_qs:
            model_str, timezero_str, season, class_str = _model_tz_season_class_strs(
                forecast_model_id_to_obj[forecast_model_id],
                timezero_id_to_obj[timezero_id], timezero_to_season_name,
                QuantileDistribution)
            value = PointPrediction.first_non_none_value(
                value_i, value_f, None, value_d, None)
            value = value.strftime(YYYY_MM_DD_DATE_FORMAT) if isinstance(
                value, datetime.date) else value
            yield [
                model_str, timezero_str, season, unit_name, target_name,
                class_str, value, cat, prob, sample, quantile, family, param1,
                param2, param3
            ]

    # NB: we do not sort b/c it's expensive
    logger.debug(
        f"query_forecasts_for_project(): 4/4 done. num_rows={num_rows}, query={query}, project={project}"
    )
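
# Each concrete prediction subclass above stores its datum in exactly one of several typed columns
# (e.g., value_i, value_f, value_t, value_d, value_b), the others being None. A minimal sketch of the
# coalescing that PointPrediction.first_non_none_value() performs - an illustration of the idea, not
# the actual implementation:
def _first_non_none_value_sketch(*typed_values):
    return next((value for value in typed_values if value is not None), None)

# e.g., _first_non_none_value_sketch(None, 3.7, None, None, None) -> 3.7
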
def query_forecasts_for_project(project,
                                query,
                                max_num_rows=MAX_NUM_QUERY_ROWS):
    """
    Top-level function for querying forecasts within project. Runs in the calling thread and therefore blocks.

    Yields rows in a Zoltar-specific CSV row format. The columns are defined in FORECAST_CSV_HEADER. Note
    that the csv is 'sparse': not every row uses all columns, and unused ones are empty (''). However, the first four
    columns are always non-empty, i.e., every prediction has them.

    The 'class' of each row is named the same as Zoltar's utils.forecast.PRED_CLASS_INT_TO_NAME variable. Column
    ordering is FORECAST_CSV_HEADER.

    `query` is documented at https://docs.zoltardata.com/, but briefly, it is a dict of up to six keys, five of which
    are lists of strings. All are optional:

    - 'models': Pass zero or more model abbreviations in the models field.
    - 'units': Pass zero or more unit names in the units field.
    - 'targets': Pass zero or more target names in the targets field.
    - 'timezeros': Pass zero or more timezero dates in YYYY_MM_DD_DATE_FORMAT format in the timezeros field.
    - 'types': Pass a list of string types in the types field. Choices are PRED_CLASS_INT_TO_NAME.values().

    The sixth key allows searching based on `Forecast.issued_at`:
    - 'as_of': Passing a datetime string in the optional as_of field causes the query to return only those forecast
        versions whose issued_at is <= the as_of datetime (AKA timestamp).

    Note that _strings_ are passed to refer to object *contents*, not database IDs, which means validation will fail if
    the referred-to objects are not found. NB: If multiple objects are found with the same name then the program will
    arbitrarily choose one.

    :param project: a Project
    :param query: a dict specifying the query parameters as described above. NB: assumes it has passed validation via
        `validate_forecasts_query()`
    :param max_num_rows: the number of rows at which this function raises a RuntimeError
    :return: a generator of CSV rows, including the header
    """
    logger.debug(
        f"query_forecasts_for_project(): 1/3 validating query. query={query}, project={project}"
    )

    # validate query
    error_messages, (model_ids, unit_ids, target_ids, timezero_ids, type_ints, as_of) = \
        validate_forecasts_query(project, query)
    if error_messages:
        raise RuntimeError(
            f"invalid query. query={query}, errors={error_messages}")

    forecast_model_id_to_obj = {
        forecast_model.pk: forecast_model
        for forecast_model in project.models.all()
    }
    timezero_id_to_obj = {
        timezero.pk: timezero
        for timezero in project.timezeros.all()
    }
    unit_id_to_obj = {unit.pk: unit for unit in project.units.all()}
    target_id_to_obj = {target.pk: target for target in project.targets.all()}
    timezero_to_season_name = project.timezero_to_season_name()

    yield FORECAST_CSV_HEADER

    # get the SQL then execute and iterate over resulting data
    sql = _query_forecasts_sql_for_pred_class(type_ints, model_ids, unit_ids,
                                              target_ids, timezero_ids, as_of,
                                              True)
    logger.debug(
        f"query_forecasts_for_project(): 2/3 executing sql. type_ints, model_ids, unit_ids, target_ids, "
        f"timezero_ids, as_of= {type_ints}, {model_ids}, {unit_ids}, {target_ids}, {timezero_ids}, "
        f"{as_of}")
    num_rows = 0
    with connection.cursor() as cursor:
        cursor.execute(sql, (project.pk, ))
        for fm_id, tz_id, pred_class, unit_id, target_id, is_retract, pred_data in batched_rows(
                cursor):
            # we do not have to check is_retract b/c we pass `is_include_retract=False`, which skips retractions
            num_rows += 1
            if num_rows > max_num_rows:
                raise RuntimeError(
                    f"number of rows exceeded maximum. num_rows={num_rows}, "
                    f"max_num_rows={max_num_rows}")

            # counterintuitively must use json.loads per https://code.djangoproject.com/ticket/31991
            pred_data = json.loads(pred_data)
            model_str, timezero_str, season, class_str = _model_tz_season_class_strs(
                forecast_model_id_to_obj[fm_id], timezero_id_to_obj[tz_id],
                timezero_to_season_name, pred_class)
            value, cat, prob, sample, quantile, family, param1, param2, param3 = '', '', '', '', '', '', '', '', ''
            if pred_class == PredictionElement.BIN_CLASS:
                for cat, prob in zip(pred_data['cat'], pred_data['prob']):
                    yield [
                        model_str, timezero_str, season,
                        unit_id_to_obj[unit_id].name,
                        target_id_to_obj[target_id].name, class_str, value,
                        cat, prob, sample, quantile, family, param1, param2,
                        param3
                    ]
            elif pred_class == PredictionElement.NAMED_CLASS:
                family = pred_data['family']
                param1 = pred_data.get('param1', '')
                param2 = pred_data.get('param2', '')
                param3 = pred_data.get('param3', '')
                yield [
                    model_str, timezero_str, season,
                    unit_id_to_obj[unit_id].name,
                    target_id_to_obj[target_id].name, class_str, value, cat,
                    prob, sample, quantile, family, param1, param2, param3
                ]
            elif pred_class == PredictionElement.POINT_CLASS:
                value = pred_data['value']
                yield [
                    model_str, timezero_str, season,
                    unit_id_to_obj[unit_id].name,
                    target_id_to_obj[target_id].name, class_str, value, cat,
                    prob, sample, quantile, family, param1, param2, param3
                ]
            elif pred_class == PredictionElement.QUANTILE_CLASS:
                for quantile, value in zip(pred_data['quantile'],
                                           pred_data['value']):
                    yield [
                        model_str, timezero_str, season,
                        unit_id_to_obj[unit_id].name,
                        target_id_to_obj[target_id].name, class_str, value,
                        cat, prob, sample, quantile, family, param1, param2,
                        param3
                    ]
            elif pred_class == PredictionElement.SAMPLE_CLASS:
                for sample in pred_data['sample']:
                    yield [
                        model_str, timezero_str, season,
                        unit_id_to_obj[unit_id].name,
                        target_id_to_obj[target_id].name, class_str, value,
                        cat, prob, sample, quantile, family, param1, param2,
                        param3
                    ]

    # done
    logger.debug(
        f"query_forecasts_for_project(): 3/3 done. num_rows={num_rows}, query={query}, project={project}"
    )
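
# A standalone sketch of the row expansion above: a single quantile PredictionElement's JSON expands
# into one sparse CSV row per (quantile, value) pair, with the unused class-specific columns left
# empty. The pred_data and row values here are hypothetical.
def _quantile_rows_sketch():
    pred_data = {'quantile': [0.25, 0.5, 0.75], 'value': [1.0, 2.2, 2.9]}
    value, cat, prob, sample, family, param1, param2, param3 = '', '', '', '', '', '', '', ''
    for quantile, value in zip(pred_data['quantile'], pred_data['value']):
        print(['gam_lag1_tops3', '2017-04-23', '2017-2018', 'TH01', '1_biweek_ahead', 'quantile',
               value, cat, prob, sample, quantile, family, param1, param2, param3])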