def train_loader(
    dataset: ListDataset,
    prediction_interval_length: float,
    context_interval_length: float,
    is_train: bool = True,
    override_args: dict = None,
) -> DataLoader:
    """Build a training or inference loader over a continuous-time dataset.

    Each dataset entry is passed through a ``ContinuousTimeInstanceSplitter``
    configured with the given past/future interval lengths and a uniform
    sampler drawing 10 instances. Batches of 10 are stacked with
    ``batchify`` (CPU context, float32, variable-length).

    Args:
        dataset: the dataset to iterate over.
        prediction_interval_length: passed as ``future_interval_length``.
        context_interval_length: passed as ``past_interval_length``.
        is_train: when true, return a ``TrainDataLoader`` with 22 batches per
            epoch; otherwise return an ``InferenceDataLoader``.
        override_args: extra keyword arguments that replace or extend the
            default loader arguments (``dataset``, ``transform``,
            ``batch_size``, ``stack_fn``).

    Returns:
        The constructed data loader.
    """
    extra_args = {} if override_args is None else override_args

    instance_splitter = ContinuousTimeInstanceSplitter(
        future_interval_length=prediction_interval_length,
        past_interval_length=context_interval_length,
        train_sampler=ContinuousTimeUniformSampler(num_instances=10),
    )

    loader_kwargs: Dict[str, Any] = {
        "dataset": dataset,
        "transform": instance_splitter,
        "batch_size": 10,
        # variable_length=True: entries may hold differing numbers of events.
        "stack_fn": partial(
            batchify, ctx=mx.cpu(), dtype=np.float32, variable_length=True
        ),
    }
    loader_kwargs.update(extra_args)

    if not is_train:
        return InferenceDataLoader(num_workers=None, **loader_kwargs)
    return TrainDataLoader(
        num_batches_per_epoch=22, num_workers=None, **loader_kwargs
    )
def create_transformation(self) -> Transformation:
    """Return the data transformation used by this estimator.

    The result is a two-step ``Chain``:

    1. A ``ContinuousTimeInstanceSplitter`` using ``self.context_interval_length``
       as the past interval and ``self.prediction_interval_length`` as the
       future interval. Training instances are drawn by a uniform sampler
       with ``self.num_training_instances`` instances.
    2. A ``RenameFields`` step mapping ``past_target`` -> ``target`` and
       ``past_valid_length`` -> ``valid_length``.
    """
    splitter = ContinuousTimeInstanceSplitter(
        past_interval_length=self.context_interval_length,
        future_interval_length=self.prediction_interval_length,
        train_sampler=ContinuousTimeUniformSampler(
            num_instances=self.num_training_instances
        ),
    )
    # Expose the splitter's "past_*" outputs under the plain field names.
    field_renames = {
        "past_target": "target",
        "past_valid_length": "valid_length",
    }
    return Chain([splitter, RenameFields(field_renames)])
def get_predictor(**kwargs) -> PointProcessGluonPredictor:
    """Create a ``PointProcessGluonPredictor`` with test-friendly defaults.

    Defaults: a ``MockTPPPredictionNet`` with a 5.0 prediction interval,
    batch size 128, hourly frequency, CPU context, and an input transform
    splitting with positional interval arguments ``1`` and ``5`` and a
    uniform sampler of 5 instances.

    Any keyword argument supplied by the caller overrides the default of the
    same name. All defaults are constructed before merging, even ones that
    get overridden.
    """
    defaults = {
        "input_names": ["past_target", "past_valid_length"],
        "prediction_net": MockTPPPredictionNet(prediction_interval_length=5.0),
        "batch_size": 128,
        "prediction_interval_length": 5.0,
        "freq": "H",
        "ctx": mx.cpu(),
        "input_transform": ContinuousTimeInstanceSplitter(
            1, 5, ContinuousTimeUniformSampler(num_instances=5)
        ),
    }
    predictor_args = {**defaults, **kwargs}
    return PointProcessGluonPredictor(**predictor_args)
def train_loader(
    dataset: ListDataset,
    prediction_interval_length: float,
    context_interval_length: float,
    is_train: bool = True,
    override_args: dict = None,
) -> Iterable[DataBatch]:
    """Build an iterable of batches from a continuous-time dataset.

    The instance sampler depends on the mode:

    - Training: a ``ContinuousTimeUniformSampler`` drawing 10 instances, with
      ``min_past=context_interval_length`` and
      ``min_future=prediction_interval_length``.
    - Inference: a ``ContinuousTimePredictionSampler`` with
      ``min_past=context_interval_length``.

    The sampler drives a ``ContinuousTimeInstanceSplitter`` using
    ``dataset.freq``. Batches of 10 are stacked with ``batchify`` (float32,
    variable-length).

    Args:
        dataset: the dataset to iterate over; its ``freq`` attribute is read.
        prediction_interval_length: future interval length for the splitter.
        context_interval_length: past interval length for the splitter.
        is_train: when true, return the first ``NUM_BATCHES`` batches of a
            ``TrainDataLoader``; otherwise return an ``InferenceDataLoader``.
        override_args: extra keyword arguments that replace or extend the
            default loader arguments.

    Returns:
        An iterable of batches.
    """
    if is_train:
        point_sampler = ContinuousTimeUniformSampler(
            num_instances=10,
            min_past=context_interval_length,
            min_future=prediction_interval_length,
        )
    else:
        point_sampler = ContinuousTimePredictionSampler(
            min_past=context_interval_length
        )

    instance_splitter = ContinuousTimeInstanceSplitter(
        future_interval_length=prediction_interval_length,
        past_interval_length=context_interval_length,
        instance_sampler=point_sampler,
        freq=dataset.freq,
    )

    loader_kwargs = {
        "dataset": dataset,
        "transform": instance_splitter,
        "batch_size": 10,
        "stack_fn": partial(batchify, dtype=np.float32, variable_length=True),
    }
    if override_args is not None:
        loader_kwargs.update(override_args)

    if not is_train:
        return InferenceDataLoader(**loader_kwargs)
    # The training loader is consumed lazily; islice bounds it to NUM_BATCHES.
    training_loader = TrainDataLoader(num_workers=None, **loader_kwargs)
    return itertools.islice(training_loader, NUM_BATCHES)
def _create_instance_splitter(self, mode: str):
    """Build the splitter-plus-rename transformation for a given mode.

    Args:
        mode: one of ``"training"``, ``"validation"`` or ``"test"``. Any other
            value fails the leading ``assert``.

    Returns:
        A ``Chain`` of a ``ContinuousTimeInstanceSplitter`` and a
        ``RenameFields`` step that maps ``past_target`` -> ``target`` and
        ``past_valid_length`` -> ``valid_length``.

    Sampler per mode:

    - training: uniform sampler with ``self.num_training_instances``
      instances and both ``min_past`` and ``min_future`` set.
    - validation: prediction sampler allowing empty intervals, with
      ``min_past`` and ``min_future`` set.
    - test: prediction sampler disallowing empty intervals, with only
      ``min_past`` set.
    """
    assert mode in ("training", "validation", "test")

    past_length = self.context_interval_length
    future_length = self.prediction_interval_length

    # All three candidates are built up front and one is selected by key.
    candidate_samplers = {
        "training": ContinuousTimeUniformSampler(
            num_instances=self.num_training_instances,
            min_past=past_length,
            min_future=future_length,
        ),
        "validation": ContinuousTimePredictionSampler(
            allow_empty_interval=True,
            min_past=past_length,
            min_future=future_length,
        ),
        "test": ContinuousTimePredictionSampler(
            min_past=past_length,
            allow_empty_interval=False,
        ),
    }
    instance_sampler = candidate_samplers[mode]
    assert isinstance(instance_sampler, ContinuousTimePointSampler)

    splitter = ContinuousTimeInstanceSplitter(
        past_interval_length=past_length,
        future_interval_length=future_length,
        instance_sampler=instance_sampler,
    )
    renamer = RenameFields(
        {
            "past_target": "target",
            "past_valid_length": "valid_length",
        }
    )
    return Chain([splitter, renamer])