def create_predictor(
    self,
    transformation: Transformation,
    trained_network: ForkingSeq2SeqNetworkBase,
) -> Predictor:
    """
    Build a predictor from the sub-modules of a trained forking seq2seq
    network.

    Whether ``self.quantile_output`` is set decides both the prediction
    network class and the forecast generator:

    * quantile output set: ``ForkingSeq2SeqPredictionNetwork`` together with
      a ``QuantileForecastGenerator`` over the configured quantile names;
    * otherwise: ``ForkingSeq2SeqDistributionPredictionNetwork`` together
      with a ``DistributionForecastGenerator`` for ``self.distr_output``.

    Parameters
    ----------
    transformation
        Input transformation; the instance splitter created for mode
        ``"test"`` is chained after it with ``+``.
    trained_network
        Trained network whose ``encoder``, ``enc2dec``, ``decoder``,
        ``quantile_output`` and ``distr_output`` are handed to the
        prediction network. It is also passed to ``copy_parameters``
        together with the new network, presumably to transfer the trained
        weights.

    Returns
    -------
    Predictor
        A ``RepresentableBlockPredictor`` wrapping the prediction network.
    """
    use_quantiles = self.quantile_output is not None

    if use_quantiles:
        quantile_names = [
            Quantile.from_float(level).name
            for level in self.quantile_output.quantiles
        ]
    else:
        quantile_names = None

    test_splitter = self._create_instance_splitter("test")

    if use_quantiles:
        network_cls = ForkingSeq2SeqPredictionNetwork
    else:
        network_cls = ForkingSeq2SeqDistributionPredictionNetwork

    network = network_cls(
        encoder=trained_network.encoder,
        enc2dec=trained_network.enc2dec,
        decoder=trained_network.decoder,
        quantile_output=trained_network.quantile_output,
        distr_output=trained_network.distr_output,
        context_length=self.context_length,
        num_forking=self.num_forking,
        cardinality=self.cardinality,
        embedding_dimension=self.embedding_dimension,
        scaling=self.scaling,
        scaling_decoder_dynamic_feature=self.scaling_decoder_dynamic_feature,
        dtype=self.dtype,
    )
    copy_parameters(trained_network, network)

    pipeline = transformation + test_splitter

    # The generator must match the network's output kind: quantile names
    # for the quantile network, the distribution output otherwise.
    if use_quantiles:
        generator = QuantileForecastGenerator(quantile_names)
    else:
        generator = DistributionForecastGenerator(self.distr_output)

    return RepresentableBlockPredictor(
        input_transform=pipeline,
        prediction_net=network,
        batch_size=self.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        ctx=self.trainer.ctx,
        forecast_generator=generator,
    )
def create_predictor(self, transformation, trained_network):
    """
    Build a predictor around a prediction network that receives the
    parameters of ``trained_network``.

    When ``self.sampling`` is truthy, a ``SimpleFeedForwardSamplingNetwork``
    is used and ``RepresentableBlockPredictor`` keeps its default forecast
    generator. Otherwise, a ``SimpleFeedForwardDistributionNetwork`` is used
    together with a ``DistributionForecastGenerator`` for
    ``self.distr_output``.

    Parameters
    ----------
    transformation
        Input transformation; the instance splitter created for mode
        ``"test"`` is chained after it with ``+``.
    trained_network
        Trained network; its ``collect_params()`` result is passed to the
        prediction network as ``params``.

    Returns
    -------
    RepresentableBlockPredictor
        Predictor wrapping the selected prediction network.
    """
    prediction_splitter = self._create_instance_splitter("test")

    # Plain truthiness instead of ``is True``: the identity check silently
    # picked the distribution branch for truthy non-``bool`` values such as
    # ``1`` or ``numpy.bool_(True)``.
    sampling = bool(self.sampling)
    if sampling:
        network_cls = SimpleFeedForwardSamplingNetwork
    else:
        network_cls = SimpleFeedForwardDistributionNetwork

    # Both network classes take the same constructor arguments.
    prediction_network = network_cls(
        num_hidden_dimensions=self.num_hidden_dimensions,
        prediction_length=self.prediction_length,
        context_length=self.context_length,
        distr_output=self.distr_output,
        batch_normalization=self.batch_normalization,
        mean_scaling=self.mean_scaling,
        params=trained_network.collect_params(),
        num_parallel_samples=self.num_parallel_samples,
    )

    predictor_kwargs = dict(
        input_transform=transformation + prediction_splitter,
        prediction_net=prediction_network,
        batch_size=self.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        ctx=self.trainer.ctx,
    )
    if not sampling:
        # Only the distribution network gets an explicit generator; the
        # sampling path deliberately relies on the predictor's default.
        predictor_kwargs["forecast_generator"] = DistributionForecastGenerator(
            self.distr_output
        )

    return RepresentableBlockPredictor(**predictor_kwargs)