Example #1
0
 def create_validation_data_loader(
     self,
     data: Dataset,
     **kwargs,
 ) -> DataLoader:
     """
     Build the data loader used for validation.

     Parameters
     ----------
     data
         Dataset to draw validation instances from.
     **kwargs
         Accepted for signature compatibility; not used by this method.

     Returns
     -------
     DataLoader
         A ``ValidationDataLoader`` over ``data`` whose transform is the
         "validation" instance splitter followed by a ``SelectFields``
         restricted to the input names of ``CanonicalTrainingNetwork``.
     """
     network_inputs = get_hybrid_forward_input_names(
         CanonicalTrainingNetwork)

     # The splitter is created while ``max_idle_transforms`` is temporarily
     # set to the dataset length (0 if ``maybe_len`` gives a falsy value).
     idle_limit = maybe_len(data) or 0
     with env._let(max_idle_transforms=idle_limit):
         splitter = self._create_instance_splitter("validation")

     pipeline = splitter + SelectFields(network_inputs)
     batch_size = self.batch_size
     stack_fn = partial(batchify, ctx=self.trainer.ctx, dtype=self.dtype)

     return ValidationDataLoader(
         dataset=data,
         transform=pipeline,
         batch_size=batch_size,
         stack_fn=stack_fn,
     )
Example #2
0
 def create_validation_data_loader(
     self,
     data: Dataset,
     **kwargs,
 ) -> DataLoader:
     """
     Build the data loader used for validation.

     Parameters
     ----------
     data
         Dataset to draw validation instances from.
     **kwargs
         Accepted for signature compatibility; not used by this method.

     Returns
     -------
     DataLoader
         A ``ValidationDataLoader`` over ``data`` after applying the
         "validation" instance splitter, the post-split transform and a
         ``SelectFields`` keeping ``past_target`` and ``valid_length``.
     """
     # The whole chain is assembled while ``max_idle_transforms`` is
     # temporarily set to the dataset length (0 if ``maybe_len`` gives a
     # falsy value).
     with env._let(max_idle_transforms=maybe_len(data) or 0):
         pipeline = self._create_instance_splitter("validation")
         pipeline = pipeline + self._create_post_split_transform()
         pipeline = pipeline + SelectFields(["past_target", "valid_length"])

     transformed = pipeline.apply(data)
     return ValidationDataLoader(
         transformed,
         batch_size=self.batch_size,
         stack_fn=self._stack_fn(),
     )
Example #3
0
 def create_training_data_loader(
     self,
     data: Dataset,
     **kwargs,
 ) -> DataLoader:
     """
     Build the data loader used for training.

     Parameters
     ----------
     data
         Dataset to draw training instances from; it is wrapped in
         ``Cyclic`` before the transform chain is applied.
     **kwargs
         Accepted for signature compatibility; not used by this method.

     Returns
     -------
     DataLoader
         A ``TrainDataLoader`` over the transformed data. The chain is the
         "training" instance splitter, the post-split transform and a
         ``SelectFields`` keeping ``past_target`` and ``valid_length``.
     """
     # The whole chain is assembled while ``max_idle_transforms`` is
     # temporarily set to the dataset length (0 if ``maybe_len`` gives a
     # falsy value).
     with env._let(max_idle_transforms=maybe_len(data) or 0):
         pipeline = self._create_instance_splitter("training")
         pipeline = pipeline + self._create_post_split_transform()
         pipeline = pipeline + SelectFields(["past_target", "valid_length"])

     transformed = pipeline.apply(Cyclic(data))
     return TrainDataLoader(
         transformed,
         batch_size=self.batch_size,
         stack_fn=self._stack_fn(),
         decode_fn=partial(as_in_context, ctx=self.trainer.ctx),
     )
Example #4
0
 def create_training_data_loader(
     self,
     data: Dataset,
     **kwargs,
 ) -> DataLoader:
     """
     Build the data loader used for training.

     Parameters
     ----------
     data
         Dataset to draw training instances from.
     **kwargs
         Extra keyword arguments forwarded unchanged to ``TrainDataLoader``.

     Returns
     -------
     DataLoader
         A ``TrainDataLoader`` over ``data`` whose transform is the
         "training" instance splitter followed by a ``SelectFields``
         restricted to the input names of ``NBEATSTrainingNetwork``.
     """
     network_inputs = get_hybrid_forward_input_names(NBEATSTrainingNetwork)

     # The splitter is created while ``max_idle_transforms`` is temporarily
     # set to the dataset length (0 if ``maybe_len`` gives a falsy value).
     idle_limit = maybe_len(data) or 0
     with env._let(max_idle_transforms=idle_limit):
         splitter = self._create_instance_splitter("training")

     pipeline = splitter + SelectFields(network_inputs)
     batch_size = self.batch_size
     stack_fn = partial(batchify, ctx=self.trainer.ctx, dtype=self.dtype)
     decode_fn = partial(as_in_context, ctx=self.trainer.ctx)

     return TrainDataLoader(
         dataset=data,
         transform=pipeline,
         batch_size=batch_size,
         stack_fn=stack_fn,
         decode_fn=decode_fn,
         **kwargs,
     )
Example #5
0
    def train_model(
        self,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        num_workers: int = 0,
        prefetch_factor: int = 2,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
        **kwargs,
    ) -> TrainOutput:
        """
        Build the data pipelines, run the trainer and package the result.

        Parameters
        ----------
        training_data
            Dataset used to build the training data loader.
        validation_data
            Optional dataset used to build the validation data loader. When
            ``None`` the trainer is called with ``validation_iter=None``.
        num_workers
            Forwarded to both ``DataLoader`` instances.
        prefetch_factor
            Forwarded to both ``DataLoader`` instances.
        shuffle_buffer_length
            Forwarded to the training ``TransformedIterableDataset`` only.
        cache_data
            Forwarded to both ``TransformedIterableDataset`` instances.
        **kwargs
            Extra keyword arguments forwarded to both ``DataLoader``
            instances.

        Returns
        -------
        TrainOutput
            The transformation, the trained network and a predictor created
            from them via ``self.create_predictor``.
        """
        transformation = self.create_transformation()
        trained_net = self.create_training_network(self.trainer.device)
        input_names = get_module_forward_input_names(trained_net)

        def build_loader(dataset, mode, **dataset_options):
            # Create the splitter while ``max_idle_transforms`` is
            # temporarily set to the dataset length (0 if ``maybe_len``
            # gives a falsy value).
            with env._let(max_idle_transforms=maybe_len(dataset) or 0):
                splitter = self.create_instance_splitter(mode)

            # ``is_train=True`` is used for both modes.
            # NOTE(review): presumably intentional for validation so the
            # splitter still sees the future target -- confirm.
            iterable = TransformedIterableDataset(
                dataset=dataset,
                transform=transformation + splitter +
                SelectFields(input_names),
                is_train=True,
                cache_data=cache_data,
                **dataset_options,
            )
            return DataLoader(
                iterable,
                batch_size=self.trainer.batch_size,
                num_workers=num_workers,
                prefetch_factor=prefetch_factor,
                pin_memory=True,
                worker_init_fn=self._worker_init_fn,
                **kwargs,
            )

        training_data_loader = build_loader(
            training_data,
            "training",
            shuffle_buffer_length=shuffle_buffer_length,
        )

        if validation_data is None:
            validation_data_loader = None
        else:
            validation_data_loader = build_loader(validation_data,
                                                  "validation")

        self.trainer(
            net=trained_net,
            train_iter=training_data_loader,
            validation_iter=validation_data_loader,
        )

        predictor = self.create_predictor(transformation, trained_net,
                                          self.trainer.device)
        return TrainOutput(
            transformation=transformation,
            trained_net=trained_net,
            predictor=predictor,
        )
Example #6
0
def backtest_metrics(
    test_dataset: Dataset,
    predictor: Predictor,
    evaluator=Evaluator(quantiles=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8,
                                   0.9)),
    num_samples: int = 100,
    logging_file: Optional[str] = None,
) -> Tuple[dict, pd.DataFrame]:
    """
    Evaluate ``predictor`` on ``test_dataset`` and log the aggregate metrics.

    Parameters
    ----------
    test_dataset
        Dataset to use for testing.
    predictor
        The predictor to test.
    evaluator
        Evaluator to use. Note that the default instance is created once,
        when the function is defined, and shared between calls.
    num_samples
        Number of samples to use when generating sample-based forecasts. Only
        sampling-based models will use this.
    logging_file
        If specified, information of the backtest is additionally written to
        this file through a temporary ``logging.FileHandler``.

    Returns
    -------
    Tuple[dict, pd.DataFrame]
        A tuple of aggregate metrics and per-time-series metrics obtained by
        applying ``evaluator`` to the forecasts that ``predictor`` produces
        on ``test_dataset``.
    """
    logger = logging.getLogger(__name__)
    handler = None

    if logging_file is not None:
        log_formatter = logging.Formatter(
            "[%(asctime)s %(levelname)s %(thread)d] %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
        )
        handler = logging.FileHandler(logging_file)
        handler.setFormatter(log_formatter)
        logger.addHandler(handler)

    try:
        test_statistics = calculate_dataset_statistics(test_dataset)
        serialize_message(logger, test_dataset_stats_key, test_statistics)

        forecast_it, ts_it = make_evaluation_predictions(
            test_dataset,
            predictor=predictor,
            num_samples=num_samples,
        )

        agg_metrics, item_metrics = evaluator(
            ts_it,
            forecast_it,
            num_series=maybe_len(test_dataset),
        )

        # we only log aggregate metrics for now as item metrics may be very
        # large
        for name, value in agg_metrics.items():
            serialize_message(logger, f"metric-{name}", value)
    finally:
        if handler is not None:
            # Detach AND close the file handler, even when evaluation fails,
            # so the log file is not left open and later log records are not
            # written to it.
            # https://stackoverflow.com/questions/24816456/python-logging-wont-shutdown
            logger.removeHandler(handler)
            handler.close()

    return agg_metrics, item_metrics