Example #1
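This example, like the others on this page, comes from GluonTS's temporal point process (TPP) code and is shown without its module-level imports. The sketch below lists the names the snippets rely on; the exact module paths differ between GluonTS versions, so treat the locations as assumptions. PointProcessGluonPredictor (Example #3) lives in the TPP model package (gluonts.model.tpp or gluonts.mx.model.tpp depending on the version), and MockTPPPredictionNet is a helper network defined in the test suite itself.

    import itertools
    from functools import partial
    from typing import Any, Dict, Iterable, Optional

    import mxnet as mx
    import numpy as np

    from gluonts.dataset.common import ListDataset
    from gluonts.dataset.loader import (
        DataBatch,
        DataLoader,
        InferenceDataLoader,
        TrainDataLoader,
    )
    from gluonts.mx.batchify import batchify
    from gluonts.transform import (
        Chain,
        ContinuousTimeInstanceSplitter,
        ContinuousTimePointSampler,
        ContinuousTimePredictionSampler,
        ContinuousTimeUniformSampler,
        RenameFields,
        Transformation,
    )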
    def train_loader(
        dataset: ListDataset,
        prediction_interval_length: float,
        context_interval_length: float,
        is_train: bool = True,
        override_args: Optional[dict] = None,
    ) -> DataLoader:

        if override_args is None:
            override_args = {}

        # Cut a context/prediction window pair out of each continuous-time
        # series, sampling 10 windows uniformly per series. Note: this older
        # API takes train_sampler; later versions use instance_sampler
        # (see the later examples).
        splitter = ContinuousTimeInstanceSplitter(
            future_interval_length=prediction_interval_length,
            past_interval_length=context_interval_length,
            train_sampler=ContinuousTimeUniformSampler(num_instances=10),
        )

        # Arguments shared by both loaders; variable_length=True makes
        # batchify pad event sequences of different lengths into one batch.
        kwargs: Dict[str, Any] = dict(
            dataset=dataset,
            transform=splitter,
            batch_size=10,
            stack_fn=partial(
                batchify, ctx=mx.cpu(), dtype=np.float32, variable_length=True
            ),
        )
        # Let callers override any of the defaults above.
        kwargs.update(override_args)

        if is_train:
            return TrainDataLoader(
                num_batches_per_epoch=22, num_workers=None, **kwargs
            )
        else:
            return InferenceDataLoader(num_workers=None, **kwargs)
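A minimal usage sketch for the helper above; tpp_dataset is a placeholder for a ListDataset with the continuous-time (point process) entry layout that the splitter expects:

    # Hypothetical usage: tpp_dataset stands in for a point-process
    # ListDataset; see the GluonTS TPP docs for the exact entry schema.
    loader = train_loader(
        dataset=tpp_dataset,
        prediction_interval_length=5.0,
        context_interval_length=10.0,
        is_train=True,
    )
    first_batch = next(iter(loader))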
Example #2
    def create_transformation(self) -> Transformation:
        return Chain([
            # Sample context/prediction windows from each series.
            ContinuousTimeInstanceSplitter(
                past_interval_length=self.context_interval_length,
                future_interval_length=self.prediction_interval_length,
                train_sampler=ContinuousTimeUniformSampler(
                    num_instances=self.num_training_instances
                ),
            ),
            # Expose the past window under the field names the network uses.
            RenameFields({
                "past_target": "target",
                "past_valid_length": "valid_length",
            }),
        ])
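A rough sketch of applying the resulting transformation chain; estimator and tpp_dataset are placeholders for the estimator defining this method and a point-process dataset:

    # Hypothetical usage: stream transformed entries out of the chain.
    transformation = estimator.create_transformation()
    for entry in transformation(iter(tpp_dataset), is_train=True):
        # After RenameFields, the past window is exposed as "target" and
        # its event count as "valid_length".
        print(entry["target"].shape, entry["valid_length"])
        break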
Example #3
    def get_predictor(**kwargs) -> PointProcessGluonPredictor:
        # Defaults used throughout the tests; any of them can be replaced
        # via keyword arguments.
        default_kwargs = dict(
            input_names=["past_target", "past_valid_length"],
            prediction_net=MockTPPPredictionNet(
                prediction_interval_length=5.0
            ),
            batch_size=128,
            prediction_interval_length=5.0,
            freq="H",
            ctx=mx.cpu(),
            input_transform=ContinuousTimeInstanceSplitter(
                1, 5, ContinuousTimeUniformSampler(num_instances=5)
            ),
        )

        default_kwargs.update(kwargs)

        return PointProcessGluonPredictor(**default_kwargs)
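Merging kwargs into the defaults lets a test replace a single constructor argument while keeping the rest, for example:

    # Hypothetical usage: keep all defaults but use a smaller batch size.
    predictor = get_predictor(batch_size=32)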
Example #4
    def train_loader(
        dataset: ListDataset,
        prediction_interval_length: float,
        context_interval_length: float,
        is_train: bool = True,
        override_args: Optional[dict] = None,
    ) -> Iterable[DataBatch]:

        if override_args is None:
            override_args = {}

        # Training samples windows uniformly over each series; inference
        # takes a single window at the end of the observed interval.
        if is_train:
            sampler = ContinuousTimeUniformSampler(
                num_instances=10,
                min_past=context_interval_length,
                min_future=prediction_interval_length,
            )
        else:
            sampler = ContinuousTimePredictionSampler(
                min_past=context_interval_length
            )

        splitter = ContinuousTimeInstanceSplitter(
            future_interval_length=prediction_interval_length,
            past_interval_length=context_interval_length,
            instance_sampler=sampler,
            freq=dataset.freq,
        )

        # variable_length=True lets batchify combine event sequences of
        # different lengths into a single padded batch.
        kwargs = dict(
            dataset=dataset,
            transform=splitter,
            batch_size=10,
            stack_fn=partial(batchify, dtype=np.float32, variable_length=True),
        )

        kwargs.update(override_args)

        if is_train:
            # TrainDataLoader yields batches indefinitely, so cap the stream.
            return itertools.islice(
                TrainDataLoader(num_workers=None, **kwargs), NUM_BATCHES
            )
        else:
            return InferenceDataLoader(**kwargs)
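The snippet refers to a NUM_BATCHES constant defined at module level in the original test file; to run it standalone, something along these lines is needed (the value here is an assumption):

    NUM_BATCHES = 22  # assumed value; the original constant lives elsewhere in the test module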
Example #5
    def _create_instance_splitter(self, mode: str):
        assert mode in ["training", "validation", "test"]

        # Pick the sampler for the requested split: training windows are
        # sampled uniformly, validation and test use the end of each series.
        instance_sampler = {
            "training": ContinuousTimeUniformSampler(
                num_instances=self.num_training_instances,
                min_past=self.context_interval_length,
                min_future=self.prediction_interval_length,
            ),
            "validation": ContinuousTimePredictionSampler(
                allow_empty_interval=True,
                min_past=self.context_interval_length,
                min_future=self.prediction_interval_length,
            ),
            "test": ContinuousTimePredictionSampler(
                min_past=self.context_interval_length,
                allow_empty_interval=False,
            ),
        }[mode]

        assert isinstance(instance_sampler, ContinuousTimePointSampler)

        return Chain(
            [
                ContinuousTimeInstanceSplitter(
                    past_interval_length=self.context_interval_length,
                    future_interval_length=self.prediction_interval_length,
                    instance_sampler=instance_sampler,
                ),
                # Expose the past window under the field names the network uses.
                RenameFields(
                    {
                        "past_target": "target",
                        "past_valid_length": "valid_length",
                    }
                ),
            ]
        )
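A minimal sketch of how the three modes would be used; estimator is a placeholder for the estimator object that defines this method:

    # Hypothetical usage: one splitter chain per dataset split.
    training_chain = estimator._create_instance_splitter("training")
    validation_chain = estimator._create_instance_splitter("validation")
    test_chain = estimator._create_instance_splitter("test")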