def _create_instance_splitter(self, mode: str):
    """Build the instance-splitting transformation for the given mode.

    An ``InstanceSplitter`` is configured with
    ``past_length=self.history_length`` and
    ``future_length=self.prediction_length``. Its ``time_series_fields`` are
    ``FieldName.FEAT_TIME`` and ``FieldName.OBSERVED_VALUES``, presumably
    split alongside the target. It is followed by one of two steps:

    * ``CDFtoGaussianTransform``, when ``self.use_marginal_transformation``
      is truthy.
    * ``RenameFields``, otherwise. This step renames ``past_<target>`` to
      ``past_<target>_cdf`` and ``future_<target>`` to
      ``future_<target>_cdf``.

    Args:
        mode: One of ``"training"``, ``"validation"`` or ``"test"``. Selects
            ``self.train_sampler``, ``self.validation_sampler`` or a fresh
            ``TestSplitSampler()`` respectively.

    Returns:
        The splitter combined with the marginal step via ``+``. This
        presumably yields a chained transformation; confirm against
        ``Transformation.__add__``.

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let a bad mode through to a bare KeyError.
    if mode not in ("training", "validation", "test"):
        raise ValueError(
            "mode must be one of 'training', 'validation' or 'test', "
            f"got {mode!r}"
        )

    # Only touch or construct the sampler this mode actually needs.
    if mode == "training":
        instance_sampler = self.train_sampler
    elif mode == "validation":
        instance_sampler = self.validation_sampler
    else:
        instance_sampler = TestSplitSampler()

    splitter = InstanceSplitter(
        target_field=FieldName.TARGET,
        is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START,
        forecast_start_field=FieldName.FORECAST_START,
        instance_sampler=instance_sampler,
        past_length=self.history_length,
        future_length=self.prediction_length,
        time_series_fields=[
            FieldName.FEAT_TIME,
            FieldName.OBSERVED_VALUES,
        ],
    )

    if self.use_marginal_transformation:
        marginal_step = CDFtoGaussianTransform(
            target_field=FieldName.TARGET,
            observed_values_field=FieldName.OBSERVED_VALUES,
            max_context_length=self.conditioning_length,
            target_dim=self.target_dim,
        )
    else:
        # Even without the marginal transform, expose the `_cdf` field
        # names, presumably so downstream code sees the same fields either
        # way. Confirm against what CDFtoGaussianTransform emits.
        marginal_step = RenameFields(
            {
                f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf",
                f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf",
            }
        )

    return splitter + marginal_step
def use_marginal_transformation(
    marginal_transformation: bool,
) -> Transformation:
    """Choose the transformation step applied to the split target fields.

    Args:
        marginal_transformation: Selects which step is returned.
            * Truthy: a ``CDFtoGaussianTransform`` over ``FieldName.TARGET``
              with ``FieldName.OBSERVED_VALUES``, using
              ``max_context_length=self.conditioning_length`` and
              ``target_dim=self.target_dim``.
            * Falsy: a ``RenameFields`` step that only renames
              ``past_<target>`` to ``past_<target>_cdf`` and
              ``future_<target>`` to ``future_<target>_cdf``.

    Returns:
        The selected transformation step.

    NOTE(review): ``self`` is not a parameter of this function, so it must
    be resolved from an enclosing scope (presumably the estimator method
    this helper is nested in). If defined at module level, the truthy
    branch would raise ``NameError``. Confirm its placement.
    """
    if not marginal_transformation:
        # Without the marginal transform, only rename to the `_cdf` names.
        return RenameFields(
            {
                f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf",
                f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf",
            }
        )

    return CDFtoGaussianTransform(
        target_field=FieldName.TARGET,
        observed_values_field=FieldName.OBSERVED_VALUES,
        max_context_length=self.conditioning_length,
        target_dim=self.target_dim,
    )