コード例 #1
0
    def create_predictor(
        self,
        transformation: Transformation,
        trained_network: ForkingSeq2SeqNetworkBase,
    ) -> Predictor:
        """Assemble a predictor around a freshly built prediction network.

        The estimator's ``quantile_output`` decides which flavour is built:

        * not ``None``: a ``ForkingSeq2SeqPredictionNetwork`` paired with a
          ``QuantileForecastGenerator`` fed with the quantile names;
        * ``None``: a ``ForkingSeq2SeqDistributionPredictionNetwork`` paired
          with a ``DistributionForecastGenerator`` built from
          ``self.distr_output``.

        The prediction network reuses the encoder, enc2dec, decoder and
        output blocks of ``trained_network``; ``copy_parameters`` is then
        invoked with the trained network as source and the new network as
        target.

        Parameters
        ----------
        transformation
            Transformation that is chained in front of the "test" instance
            splitter to form the predictor's input transform.
        trained_network
            Network whose sub-blocks and parameters are reused.

        Returns
        -------
        Predictor
            A ``RepresentableBlockPredictor`` wrapping the prediction network.
        """
        # Quantile names are only needed (and only computable) when a
        # quantile output has been configured.
        if self.quantile_output is not None:
            quantile_names = [
                Quantile.from_float(level).name
                for level in self.quantile_output.quantiles
            ]
        else:
            quantile_names = None

        test_splitter = self._create_instance_splitter("test")

        if self.quantile_output is not None:
            network_cls = ForkingSeq2SeqPredictionNetwork
        else:
            network_cls = ForkingSeq2SeqDistributionPredictionNetwork

        inference_network = network_cls(
            encoder=trained_network.encoder,
            enc2dec=trained_network.enc2dec,
            decoder=trained_network.decoder,
            quantile_output=trained_network.quantile_output,
            distr_output=trained_network.distr_output,
            context_length=self.context_length,
            num_forking=self.num_forking,
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            scaling=self.scaling,
            scaling_decoder_dynamic_feature=self.scaling_decoder_dynamic_feature,
            dtype=self.dtype,
        )

        copy_parameters(trained_network, inference_network)

        input_pipeline = transformation + test_splitter

        # The forecast generator has to match the kind of output the
        # selected network class produces.
        if quantile_names is not None:
            generator = QuantileForecastGenerator(quantile_names)
        else:
            generator = DistributionForecastGenerator(self.distr_output)

        return RepresentableBlockPredictor(
            input_transform=input_pipeline,
            prediction_net=inference_network,
            batch_size=self.batch_size,
            freq=self.freq,
            prediction_length=self.prediction_length,
            ctx=self.trainer.ctx,
            forecast_generator=generator,
        )
コード例 #2
0
ファイル: _estimator.py プロジェクト: yifeim/gluon-ts
    def create_predictor(self, transformation, trained_network):
        """Wrap the trained parameters in a prediction network and predictor.

        When ``self.sampling`` is exactly ``True`` a
        ``SimpleFeedForwardSamplingNetwork`` is built and the predictor is
        created without an explicit ``forecast_generator`` (so whatever
        default ``RepresentableBlockPredictor`` defines applies). In every
        other case a ``SimpleFeedForwardDistributionNetwork`` is built and
        a ``DistributionForecastGenerator`` based on ``self.distr_output``
        is passed to the predictor.

        Both network classes receive the same constructor arguments,
        including ``trained_network.collect_params()`` as ``params``.

        Parameters
        ----------
        transformation
            Transformation chained in front of the "test" instance splitter
            to form the predictor's input transform.
        trained_network
            Network whose collected parameters are handed to the
            prediction network.

        Returns
        -------
        RepresentableBlockPredictor
            Predictor wrapping the newly built prediction network.
        """
        test_splitter = self._create_instance_splitter("test")

        # Identity comparison on purpose: only the literal ``True`` selects
        # the sampling network; any other value takes the distribution path.
        sampling_enabled = self.sampling is True

        network_cls = (
            SimpleFeedForwardSamplingNetwork
            if sampling_enabled
            else SimpleFeedForwardDistributionNetwork
        )
        inference_network = network_cls(
            num_hidden_dimensions=self.num_hidden_dimensions,
            prediction_length=self.prediction_length,
            context_length=self.context_length,
            distr_output=self.distr_output,
            batch_normalization=self.batch_normalization,
            mean_scaling=self.mean_scaling,
            params=trained_network.collect_params(),
            num_parallel_samples=self.num_parallel_samples,
        )

        predictor_kwargs = {
            "input_transform": transformation + test_splitter,
            "prediction_net": inference_network,
            "batch_size": self.batch_size,
        }
        if not sampling_enabled:
            # Only the distribution network needs an explicit generator.
            predictor_kwargs["forecast_generator"] = (
                DistributionForecastGenerator(self.distr_output)
            )
        predictor_kwargs["freq"] = self.freq
        predictor_kwargs["prediction_length"] = self.prediction_length
        predictor_kwargs["ctx"] = self.trainer.ctx

        return RepresentableBlockPredictor(**predictor_kwargs)