def create_predictor(
    self,
    transformation: Transformation,
    trained_network: HybridBlock,
) -> RepresentableBlockPredictor:
    """Wrap a trained self-attention network into a quantile predictor.

    A ``SelfAttentionPredictionNetwork`` is built from this estimator's
    hyperparameters, ``copy_parameters`` is called to carry the trained
    parameters over, and the result is wrapped in a
    ``RepresentableBlockPredictor``.

    Parameters
    ----------
    transformation
        Transformation passed to the predictor as ``input_transform``.
    trained_network
        Network whose parameters are copied into the prediction network.

    Returns
    -------
    RepresentableBlockPredictor
        Predictor whose forecast generator is a ``QuantileForecastGenerator``
        built from ``str()`` of each of the prediction network's quantiles.
    """
    architecture = dict(
        context_length=self.context_length,
        prediction_length=self.prediction_length,
        d_hidden=self.model_dim,
        m_ffn=self.ffn_dim_multiplier,
        n_head=self.num_heads,
        n_layers=self.num_layers,
        n_output=self.num_outputs,
        cardinalities=self.cardinalities,
        kernel_sizes=self.kernel_sizes,
        dist_enc=self.distance_encoding,
        pre_ln=self.pre_layer_norm,
        dropout=self.dropout,
        temperature=self.temperature,
    )
    inference_net = SelfAttentionPredictionNetwork(**architecture)
    copy_parameters(trained_network, inference_net)

    # The generator is configured with the string form of each quantile level.
    quantile_labels = [str(level) for level in inference_net.quantiles]
    generator = QuantileForecastGenerator(quantiles=quantile_labels)

    return RepresentableBlockPredictor(
        input_transform=transformation,
        prediction_net=inference_net,
        batch_size=self.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        ctx=self.trainer.ctx,
        forecast_generator=generator,
    )
def create_predictor(
    self,
    transformation: transform.Transformation,
    trained_network: Seq2SeqTrainingNetwork,
) -> Predictor:
    """Assemble a Seq2Seq predictor that reuses the trained sub-blocks.

    The prediction network is built directly from the trained network's
    ``embedder``, ``scaler``, ``encoder``, ``enc2dec``, ``decoder`` and
    ``quantile_output``. ``copy_parameters`` is then called on it, and the
    test-mode instance splitter is appended to ``transformation``.

    Parameters
    ----------
    transformation
        Base transformation; the test instance splitter is appended to it.
    trained_network
        Training network providing the sub-blocks and parameters.

    Returns
    -------
    Predictor
        A ``RepresentableBlockPredictor`` with a ``QuantileForecastGenerator``
        keyed by ``Quantile.from_float(q).name`` for each ``q`` in
        ``self.quantiles``.
    """
    # TODO: the forecast generator below assumes a quantile output head.
    quantile_names = [Quantile.from_float(level).name for level in self.quantiles]

    test_splitter = self._create_instance_splitter("test")

    reused_blocks = {
        attr: getattr(trained_network, attr)
        for attr in (
            "embedder",
            "scaler",
            "encoder",
            "enc2dec",
            "decoder",
            "quantile_output",
        )
    }
    inference_net = Seq2SeqPredictionNetwork(**reused_blocks)
    copy_parameters(trained_network, inference_net)

    generator = QuantileForecastGenerator(quantile_names)
    return RepresentableBlockPredictor(
        input_transform=transformation + test_splitter,
        prediction_net=inference_net,
        batch_size=self.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        ctx=self.trainer.ctx,
        forecast_generator=generator,
    )
def create_predictor(
    self,
    transformation: Transformation,
    trained_network: HybridBlock,
) -> RepresentableBlockPredictor:
    """Wrap a trained Temporal Fusion Transformer into a quantile predictor.

    Per-group feature sizes and cardinalities come from the ``.values()`` of
    this estimator's feature dictionaries. ``copy_parameters`` is called to
    carry the trained parameters into the new prediction network.

    Parameters
    ----------
    transformation
        Transformation passed to the predictor as ``input_transform``.
    trained_network
        Network whose parameters are copied into the prediction network.

    Returns
    -------
    RepresentableBlockPredictor
        Predictor using ``self.trainer.batch_size`` and a
        ``QuantileForecastGenerator`` built from ``str()`` of each quantile.
    """
    # Each time feature contributes a single real-valued dimension.
    time_feature_sizes = [1] * len(self.time_features)

    past_real_sizes = list(self.past_dynamic_feature_dims.values())
    past_cat_cards = list(self.past_dynamic_cardinalities.values())
    dynamic_real_sizes = time_feature_sizes + list(self.dynamic_feature_dims.values())
    dynamic_cat_cards = list(self.dynamic_cardinalities.values())
    static_real_sizes = list(self.static_feature_dims.values())
    static_cat_cards = list(self.static_cardinalities.values())

    inference_net = TemporalFusionTransformerPredictionNetwork(
        context_length=self.context_length,
        prediction_length=self.prediction_length,
        d_var=self.variable_dim,
        d_hidden=self.hidden_dim,
        n_head=self.num_heads,
        n_output=self.num_outputs,
        d_past_feat_dynamic_real=past_real_sizes,
        c_past_feat_dynamic_cat=past_cat_cards,
        d_feat_dynamic_real=dynamic_real_sizes,
        c_feat_dynamic_cat=dynamic_cat_cards,
        d_feat_static_real=static_real_sizes,
        c_feat_static_cat=static_cat_cards,
        dropout=self.dropout_rate,
    )
    copy_parameters(trained_network, inference_net)

    generator = QuantileForecastGenerator(
        quantiles=[str(level) for level in inference_net.quantiles],
    )
    return RepresentableBlockPredictor(
        input_transform=transformation,
        prediction_net=inference_net,
        batch_size=self.trainer.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        ctx=self.trainer.ctx,
        forecast_generator=generator,
    )
def create_predictor(
    self,
    transformation: Transformation,
    trained_network: ForkingSeq2SeqNetworkBase,
) -> Predictor:
    """Assemble a forking Seq2Seq predictor from a trained network.

    The trained network's ``encoder``, ``enc2dec``, ``decoder`` and
    ``quantile_output`` are handed to a ``ForkingSeq2SeqPredictionNetwork``
    together with this estimator's context length, embedding configuration,
    scaling flag and dtype. ``copy_parameters`` is then called on it.

    Parameters
    ----------
    transformation
        Transformation passed to the predictor as ``input_transform``.
    trained_network
        Training network providing the sub-blocks and parameters.

    Returns
    -------
    Predictor
        A ``RepresentableBlockPredictor`` using ``self.trainer.batch_size``
        and a ``QuantileForecastGenerator``.
    """
    # Quantile names are only meaningful for a quantile output head.
    quantile_names = [
        Quantile.from_float(level).name
        for level in self.quantile_output.quantiles
    ]

    reused_blocks = {
        attr: getattr(trained_network, attr)
        for attr in ("encoder", "enc2dec", "decoder", "quantile_output")
    }
    inference_net = ForkingSeq2SeqPredictionNetwork(
        **reused_blocks,
        context_length=self.context_length,
        cardinality=self.cardinality,
        embedding_dimension=self.embedding_dimension,
        scaling=self.scaling,
        dtype=self.dtype,
    )
    copy_parameters(trained_network, inference_net)

    generator = QuantileForecastGenerator(quantile_names)
    return RepresentableBlockPredictor(
        input_transform=transformation,
        prediction_net=inference_net,
        batch_size=self.trainer.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        ctx=self.trainer.ctx,
        forecast_generator=generator,
    )
def create_predictor(
    self,
    transformation: Transformation,
    trained_network: ForkingSeq2SeqNetworkBase,
) -> Predictor:
    """Assemble a forking Seq2Seq predictor for a quantile or distribution head.

    When ``self.quantile_output`` is set, the predictor uses
    ``ForkingSeq2SeqPredictionNetwork`` and a ``QuantileForecastGenerator``.
    Otherwise it uses ``ForkingSeq2SeqDistributionPredictionNetwork`` and a
    ``DistributionForecastGenerator`` over ``self.distr_output``. Either
    network class receives both ``quantile_output`` and ``distr_output``
    from the trained network.

    Parameters
    ----------
    transformation
        Base transformation; the test instance splitter is appended to it.
    trained_network
        Training network providing the sub-blocks and parameters.

    Returns
    -------
    Predictor
        A ``RepresentableBlockPredictor`` using ``self.batch_size``.
    """
    if self.quantile_output is None:
        quantile_names = None
        network_cls = ForkingSeq2SeqDistributionPredictionNetwork
    else:
        quantile_names = [
            Quantile.from_float(level).name
            for level in self.quantile_output.quantiles
        ]
        network_cls = ForkingSeq2SeqPredictionNetwork

    test_splitter = self._create_instance_splitter("test")

    inference_net = network_cls(
        encoder=trained_network.encoder,
        enc2dec=trained_network.enc2dec,
        decoder=trained_network.decoder,
        quantile_output=trained_network.quantile_output,
        distr_output=trained_network.distr_output,
        context_length=self.context_length,
        num_forking=self.num_forking,
        cardinality=self.cardinality,
        embedding_dimension=self.embedding_dimension,
        scaling=self.scaling,
        scaling_decoder_dynamic_feature=self.scaling_decoder_dynamic_feature,
        dtype=self.dtype,
    )
    copy_parameters(trained_network, inference_net)

    # Pick the generator that matches the output head chosen above.
    if quantile_names is None:
        generator = DistributionForecastGenerator(self.distr_output)
    else:
        generator = QuantileForecastGenerator(quantile_names)

    return RepresentableBlockPredictor(
        input_transform=transformation + test_splitter,
        prediction_net=inference_net,
        batch_size=self.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        ctx=self.trainer.ctx,
        forecast_generator=generator,
    )
def create_predictor(
    self,
    transformation: Transformation,
    trained_network: TemporalFusionTransformerTrainingNetwork,
    device: torch.device,
) -> Predictor:
    """Wrap a trained PyTorch Temporal Fusion Transformer into a predictor.

    The prediction network is built from this estimator's hyperparameters.
    Each feature-group list is passed through ``_default_feat_args``, and the
    network is moved to ``device`` before ``copy_parameters`` is called. The
    predictor's input names come from ``get_module_forward_input_names``.

    Parameters
    ----------
    transformation
        Base transformation; the test instance splitter is appended to it.
    trained_network
        Training network whose parameters are copied into the new network.
    device
        Device the prediction network is moved to and the predictor runs on.

    Returns
    -------
    Predictor
        A ``PyTorchPredictor`` with a ``QuantileForecastGenerator`` built from
        ``str()`` of each of the prediction network's quantiles.
    """
    # One real-valued slot per time feature; the extra +1 is for the age feature.
    time_slots = [1] * (len(self.time_features) + 1)

    raw_feature_args = {
        "d_past_feat_dynamic_real": list(self.past_dynamic_feature_dims.values()),
        "c_past_feat_dynamic_cat": list(self.past_dynamic_cardinalities.values()),
        "d_feat_dynamic_real": time_slots + list(self.dynamic_feature_dims.values()),
        "c_feat_dynamic_cat": list(self.dynamic_cardinalities.values()),
        "d_feat_static_real": list(self.static_feature_dims.values()),
        "c_feat_static_cat": list(self.static_cardinalities.values()),
    }
    feature_args = {
        key: _default_feat_args(sizes) for key, sizes in raw_feature_args.items()
    }

    inference_net = TemporalFusionTransformerPredictionNetwork(
        context_length=self.context_length,
        prediction_length=self.prediction_length,
        variable_dim=self.variable_dim,
        embed_dim=self.embed_dim,
        num_heads=self.num_heads,
        num_outputs=self.num_outputs,
        dropout=self.dropout_rate,
        **feature_args,
    ).to(device)
    copy_parameters(trained_network, inference_net)

    forward_inputs = get_module_forward_input_names(inference_net)
    test_splitter = self.create_instance_splitter("test")

    generator = QuantileForecastGenerator(
        quantiles=[str(level) for level in inference_net.quantiles],
    )
    return PyTorchPredictor(
        input_transform=transformation + test_splitter,
        input_names=forward_inputs,
        prediction_net=inference_net,
        batch_size=self.trainer.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        device=device,
        forecast_generator=generator,
    )