コード例 #1
0
    def _create_instance_splitter(self, mode: str):
        """Build the instance-splitting transformation for ``mode``.

        The result is an ``InstanceSplitter`` (configured with
        ``past_length=self.history_length`` and
        ``future_length=self.prediction_length``) chained via ``+`` with one
        more step:

        * a ``CDFtoGaussianTransform`` when ``self.use_marginal_transformation``
          is truthy, or
        * a ``RenameFields`` step that renames ``past_<target>`` and
          ``future_<target>`` to the same names with a ``_cdf`` suffix.

        Parameters
        ----------
        mode
            ``"training"``, ``"validation"`` or ``"test"``; selects which
            instance sampler the splitter uses.

        Returns
        -------
        The combined transformation ``splitter + tail``.
        """
        assert mode in ["training", "validation", "test"]

        # Every sampler is resolved up front (lookup table), then the one
        # matching ``mode`` is selected.
        samplers_by_mode = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }
        chosen_sampler = samplers_by_mode[mode]

        splitter = InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=chosen_sampler,
            past_length=self.history_length,
            future_length=self.prediction_length,
            time_series_fields=[FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES],
        )

        if self.use_marginal_transformation:
            tail = CDFtoGaussianTransform(
                target_field=FieldName.TARGET,
                observed_values_field=FieldName.OBSERVED_VALUES,
                max_context_length=self.conditioning_length,
                target_dim=self.target_dim,
            )
        else:
            # Give the past/future target fields a ``_cdf`` suffix.
            # Presumably this keeps downstream field names the same whether or
            # not the marginal transform is used; confirm against consumers.
            tail = RenameFields(
                {
                    f"{window}_{FieldName.TARGET}": f"{window}_{FieldName.TARGET}_cdf"
                    for window in ("past", "future")
                }
            )

        return splitter + tail
コード例 #2
0
 def use_marginal_transformation(marginal_transformation: bool) -> Transformation:
     """Return the transformation step selected by ``marginal_transformation``.

     If the flag is truthy, returns a ``CDFtoGaussianTransform`` over the
     target field. Otherwise returns a ``RenameFields`` step that renames
     ``past_<target>`` and ``future_<target>`` to the same names with a
     ``_cdf`` suffix.
     """
     if not marginal_transformation:
         suffixed_names = {
             f"{window}_{FieldName.TARGET}": f"{window}_{FieldName.TARGET}_cdf"
             for window in ("past", "future")
         }
         return RenameFields(suffixed_names)

     # NOTE(review): ``self`` is not a parameter of this function; it appears
     # to be captured from an enclosing method that is not visible here.
     # Confirm, otherwise this branch raises NameError.
     return CDFtoGaussianTransform(
         target_field=FieldName.TARGET,
         observed_values_field=FieldName.OBSERVED_VALUES,
         max_context_length=self.conditioning_length,
         target_dim=self.target_dim,
     )