Example #1
import json
import logging

# Imports assume the usual GluonTS package layout; exact module paths
# may differ between GluonTS versions.
from gluonts.dataset.common import Dataset
from gluonts.evaluation import Evaluator, backtest
from gluonts.itertools import maybe_len
from gluonts.model.forecast import Quantile
from gluonts.model.forecast_generator import QuantileForecastGenerator
from gluonts.model.predictor import Predictor
from gluonts.shell.sagemaker import TrainEnv
from gluonts.transform import FilterTransformation

logger = logging.getLogger(__name__)


def run_test(
    env: TrainEnv,
    predictor: Predictor,
    test_dataset: Dataset,
    hyperparameters: dict,
) -> None:
    len_original = maybe_len(test_dataset)

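    # Drop series whose target is too short to leave any history once
    # the last prediction_length points are held out for evaluation.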
    test_dataset = FilterTransformation(
        lambda x: x["target"].shape[-1] > predictor.prediction_length
    ).apply(test_dataset)

    len_filtered = len(test_dataset)

    if len_original is not None and len_original > len_filtered:
        logger.warning(
            f"Not all time-series in the test-channel have "
            f"enough data to be used for evaluation. Proceeding with "
            f"{len_filtered}/{len_original} "
            f"(~{int(len_filtered / len_original * 100)}%) items."
        )

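    # Hold out the last prediction_length points of each series as ground
    # truth and draw 100 sample paths per series from the predictor.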
    forecast_it, ts_it = backtest.make_evaluation_predictions(
        dataset=test_dataset, predictor=predictor, num_samples=100
    )

    test_quantiles = (
        [
            Quantile.parse(quantile).name
            for quantile in hyperparameters["test_quantiles"]
        ]
        if "test_quantiles" in hyperparameters
        else None
    )

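    # A quantile-based predictor can only be evaluated on quantiles it
    # actually computed; fall back to those if needed.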
    forecast_generator = getattr(predictor, "forecast_generator", None)
    if isinstance(forecast_generator, QuantileForecastGenerator):
        predictor_quantiles = forecast_generator.quantiles
        if test_quantiles is None:
            test_quantiles = predictor_quantiles
        elif not set(test_quantiles).issubset(predictor_quantiles):
            logger.warning(
                f"Some of the evaluation quantiles `{test_quantiles}` are "
                f"not in the computed quantile forecasts `{predictor_quantiles}`."
            )
            test_quantiles = predictor_quantiles

    if test_quantiles is not None:
        logger.info(f"Using quantiles `{test_quantiles}` for evaluation.")
        evaluator = Evaluator(quantiles=test_quantiles)
    else:
        evaluator = Evaluator()

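    # Consumes both iterators; returns aggregate metrics (a dict) and
    # per-item metrics (a pandas DataFrame).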
    agg_metrics, item_metrics = evaluator(
        ts_iterator=ts_it,
        fcst_iterator=forecast_it,
        num_series=len(test_dataset),
    )

    # we only log aggregate metrics for now as item metrics may be very large
    for name, score in agg_metrics.items():
        logger.info(f"#test_score ({env.current_host}, {name}): {score}")

    # store metrics
    with open(env.path.model / "agg_metrics.json", "w") as agg_metric_file:
        json.dump(agg_metrics, agg_metric_file)
    with open(env.path.model / "item_metrics.csv", "w") as item_metrics_file:
        item_metrics.to_csv(item_metrics_file, index=False)
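
The `maybe_len` guard near the top of run_test works because the helper
returns None for datasets that do not support len() (e.g. streaming
datasets). A minimal sketch, assuming it mirrors gluonts.itertools.maybe_len:

from typing import Optional

def maybe_len(obj) -> Optional[int]:
    # len() for sized datasets, None for pure iterators/streams.
    try:
        return len(obj)
    except (NotImplementedError, TypeError):
        return None
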
Example #2
def quantile(self, q):
    # Normalize the requested level (float, "0.9", or "p90") to a float.
    q = Quantile.parse(q).value
    # The empirical q-quantile is the sorted sample path at rank
    # round((num_samples - 1) * q); one value per prediction step.
    sample_idx = int(np.round((self.num_samples - 1) * q))
    return self._sorted_samples[sample_idx, :]
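
A quick standalone check of the index arithmetic above, assuming
self._sorted_samples holds the sample paths sorted along the sample axis
(shape (num_samples, prediction_length)); the rng seed and sizes below
are made up for illustration:

import numpy as np

num_samples, prediction_length = 100, 5
rng = np.random.default_rng(0)
sorted_samples = np.sort(rng.normal(size=(num_samples, prediction_length)), axis=0)

# q = 0.9 selects row int(round(99 * 0.9)) = 89: the 90th-smallest value
# at every prediction step, i.e. an empirical 0.9-quantile forecast.
sample_idx = int(np.round((num_samples - 1) * 0.9))
print(sorted_samples[sample_idx, :])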