def create_predictor(
    self,
    transformation: Transformation,
    trained_network: nn.Module,
    device: torch.device,
) -> Predictor:
    """Wrap a trained N-BEATS network into a ``PyTorchPredictor``.

    A fresh ``NBEATSPredictionNetwork`` is built from this estimator's
    hyperparameters and moved to ``device``. It is then filled from
    ``trained_network`` via ``copy_parameters``.

    Args:
        transformation: Input transformation. The ``"test"`` instance
            splitter is chained after it.
        trained_network: Trained module whose parameters are copied into
            the prediction network.
        device: Device for the prediction network; also forwarded to the
            predictor.

    Returns:
        A ``PyTorchPredictor`` around the prediction network.
    """
    network_kwargs = dict(
        prediction_length=self.prediction_length,
        context_length=self.context_length,
        num_stacks=self.num_stacks,
        widths=self.widths,
        num_blocks=self.num_blocks,
        num_block_layers=self.num_block_layers,
        expansion_coefficient_lengths=self.expansion_coefficient_lengths,
        sharing=self.sharing,
        stack_types=self.stack_types,
    )
    net = NBEATSPredictionNetwork(**network_kwargs).to(device)
    copy_parameters(trained_network, net)

    # Presumably the parameter names of `net.forward`; the predictor uses
    # them to select inputs from each batch.
    forward_inputs = get_module_forward_input_names(net)
    test_splitter = self.create_instance_splitter("test")

    return PyTorchPredictor(
        input_transform=transformation + test_splitter,
        input_names=forward_inputs,
        prediction_net=net,
        batch_size=self.trainer.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        device=device,
    )
def create_predictor(
    self,
    transformation: Transformation,
    trained_network: nn.Module,
    device: torch.device,
) -> Predictor:
    """Wrap a trained simple feed-forward network into a ``PyTorchPredictor``.

    The ``"test"`` instance splitter is created first. A
    ``SimpleFeedForwardPredictionNetwork`` is then built from this
    estimator's settings, moved to ``device``, and populated from
    ``trained_network`` via ``copy_parameters``.

    Args:
        transformation: Input transformation that the test splitter is
            appended to.
        trained_network: Trained module providing the parameters.
        device: Target device for the prediction network and predictor.

    Returns:
        A ``PyTorchPredictor`` using the new prediction network.
    """
    test_splitter = self.create_instance_splitter("test")

    hyperparams = {
        "num_hidden_dimensions": self.num_hidden_dimensions,
        "prediction_length": self.prediction_length,
        "context_length": self.context_length,
        "distr_output": self.distr_output,
        "batch_normalization": self.batch_normalization,
        "mean_scaling": self.mean_scaling,
        "num_parallel_samples": self.num_parallel_samples,
    }
    net = SimpleFeedForwardPredictionNetwork(**hyperparams).to(device)
    copy_parameters(trained_network, net)

    forward_inputs = get_module_forward_input_names(net)
    return PyTorchPredictor(
        input_transform=transformation + test_splitter,
        input_names=forward_inputs,
        prediction_net=net,
        batch_size=self.trainer.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        device=device,
    )
def create_predictor(
    self,
    transformation: Transformation,
    trained_network: TemporalFusionTransformerTrainingNetwork,
    device: torch.device,
) -> Predictor:
    """Wrap a trained Temporal Fusion Transformer into a ``PyTorchPredictor``.

    A ``TemporalFusionTransformerPredictionNetwork`` is built from this
    estimator's settings. Each feature-dimension argument is passed through
    ``_default_feat_args``. The network is moved to ``device`` and filled
    from ``trained_network`` via ``copy_parameters``. The predictor uses a
    ``QuantileForecastGenerator`` whose quantile labels are the string forms
    of ``prediction_network.quantiles``.

    Args:
        transformation: Input transformation; the ``"test"`` instance
            splitter is appended to it.
        trained_network: Trained TFT whose parameters are copied over.
        device: Target device for the network and predictor.

    Returns:
        A ``PyTorchPredictor`` that emits quantile forecasts.
    """
    network_kwargs = {
        "context_length": self.context_length,
        "prediction_length": self.prediction_length,
        "variable_dim": self.variable_dim,
        "embed_dim": self.embed_dim,
        "num_heads": self.num_heads,
        "num_outputs": self.num_outputs,
        "dropout": self.dropout_rate,
        "d_past_feat_dynamic_real": _default_feat_args(
            list(self.past_dynamic_feature_dims.values())
        ),
        "c_past_feat_dynamic_cat": _default_feat_args(
            list(self.past_dynamic_cardinalities.values())
        ),
        # One width-1 slot per time feature, plus one more (+1) for the age
        # feature, followed by the user-declared dynamic real dims.
        "d_feat_dynamic_real": _default_feat_args(
            [1] * (len(self.time_features) + 1)
            + list(self.dynamic_feature_dims.values())
        ),
        "c_feat_dynamic_cat": _default_feat_args(
            list(self.dynamic_cardinalities.values())
        ),
        "d_feat_static_real": _default_feat_args(
            list(self.static_feature_dims.values())
        ),
        "c_feat_static_cat": _default_feat_args(
            list(self.static_cardinalities.values())
        ),
    }
    net = TemporalFusionTransformerPredictionNetwork(**network_kwargs).to(device)
    copy_parameters(trained_network, net)

    forward_inputs = get_module_forward_input_names(net)
    test_splitter = self.create_instance_splitter("test")

    return PyTorchPredictor(
        input_transform=transformation + test_splitter,
        input_names=forward_inputs,
        prediction_net=net,
        batch_size=self.trainer.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        device=device,
        forecast_generator=QuantileForecastGenerator(
            quantiles=list(map(str, net.quantiles)),
        ),
    )
def create_predictor(
    self,
    transformation: Transformation,
    trained_network: TransformerTempFlowTrainingNetwork,
    device: torch.device,
) -> Predictor:
    """Wrap a trained TransformerTempFlow network into a ``PyTorchPredictor``.

    A ``TransformerTempFlowPredictionNetwork`` is constructed from this
    estimator's hyperparameters and moved to ``device``. It is then
    populated from ``trained_network`` via ``copy_parameters``.

    Args:
        transformation: Input transformation to which the ``"test"`` instance
            splitter is appended.
        trained_network: Trained module providing the parameters.
        device: Target device for the network and predictor.

    Returns:
        A ``PyTorchPredictor`` wrapping the prediction network.
    """
    network_kwargs = dict(
        input_size=self.input_size,
        target_dim=self.target_dim,
        num_heads=self.num_heads,
        act_type=self.act_type,
        d_model=self.d_model,
        dim_feedforward_scale=self.dim_feedforward_scale,
        num_encoder_layers=self.num_encoder_layers,
        num_decoder_layers=self.num_decoder_layers,
        history_length=self.history_length,
        context_length=self.context_length,
        prediction_length=self.prediction_length,
        dropout_rate=self.dropout_rate,
        cardinality=self.cardinality,
        embedding_dimension=self.embedding_dimension,
        lags_seq=self.lags_seq,
        scaling=self.scaling,
        flow_type=self.flow_type,
        n_blocks=self.n_blocks,
        hidden_size=self.hidden_size,
        n_hidden=self.n_hidden,
        conditioning_length=self.conditioning_length,
        dequantize=self.dequantize,
        num_parallel_samples=self.num_parallel_samples,
    )
    net = TransformerTempFlowPredictionNetwork(**network_kwargs).to(device)
    copy_parameters(trained_network, net)

    forward_inputs = get_module_forward_input_names(net)
    test_splitter = self.create_instance_splitter("test")

    return PyTorchPredictor(
        input_transform=transformation + test_splitter,
        input_names=forward_inputs,
        prediction_net=net,
        batch_size=self.trainer.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        device=device,
    )
def create_predictor(
    self,
    transformation: Transformation,
    trained_network: TimeGradTrainingNetwork,
    device: torch.device,
) -> Predictor:
    """Wrap a trained TimeGrad network into a ``PyTorchPredictor``.

    A ``TimeGradPredictionNetwork`` is built from this estimator's
    hyperparameters and moved to ``device``. Its parameters are then copied
    from ``trained_network`` via ``copy_parameters``.

    Args:
        transformation: Input transformation; the ``"test"`` instance
            splitter is chained after it.
        trained_network: Trained TimeGrad module supplying the parameters.
        device: Target device for the network and predictor.

    Returns:
        A ``PyTorchPredictor`` wrapping the prediction network.
    """
    hyperparams = {
        "input_size": self.input_size,
        "target_dim": self.target_dim,
        "num_layers": self.num_layers,
        "num_cells": self.num_cells,
        "cell_type": self.cell_type,
        "history_length": self.history_length,
        "context_length": self.context_length,
        "prediction_length": self.prediction_length,
        "dropout_rate": self.dropout_rate,
        "cardinality": self.cardinality,
        "embedding_dimension": self.embedding_dimension,
        "diff_steps": self.diff_steps,
        "loss_type": self.loss_type,
        "beta_end": self.beta_end,
        "beta_schedule": self.beta_schedule,
        "residual_layers": self.residual_layers,
        "residual_channels": self.residual_channels,
        "dilation_cycle_length": self.dilation_cycle_length,
        "lags_seq": self.lags_seq,
        "scaling": self.scaling,
        "conditioning_length": self.conditioning_length,
        "num_parallel_samples": self.num_parallel_samples,
    }
    net = TimeGradPredictionNetwork(**hyperparams).to(device)
    copy_parameters(trained_network, net)

    forward_inputs = get_module_forward_input_names(net)
    test_splitter = self.create_instance_splitter("test")

    return PyTorchPredictor(
        input_transform=transformation + test_splitter,
        input_names=forward_inputs,
        prediction_net=net,
        batch_size=self.trainer.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        device=device,
    )
def create_predictor(
    self,
    transformation: Transformation,
    trained_network: DeepVARTrainingNetwork,
    device: torch.device,
) -> Predictor:
    """Wrap a trained DeepVAR network into a ``PyTorchPredictor``.

    A ``DeepVARPredictionNetwork`` is constructed from this estimator's
    settings and moved to ``device``. It is then populated from
    ``trained_network`` via ``copy_parameters``. The estimator's
    ``output_transform`` is forwarded to the predictor.

    Args:
        transformation: Input transformation that the ``"test"`` instance
            splitter is appended to.
        trained_network: Trained DeepVAR module supplying the parameters.
        device: Target device for the network and predictor.

    Returns:
        A ``PyTorchPredictor`` wrapping the prediction network.
    """
    network_kwargs = dict(
        input_size=self.input_size,
        target_dim=self.target_dim,
        num_parallel_samples=self.num_parallel_samples,
        num_layers=self.num_layers,
        num_cells=self.num_cells,
        cell_type=self.cell_type,
        history_length=self.history_length,
        context_length=self.context_length,
        prediction_length=self.prediction_length,
        distr_output=self.distr_output,
        dropout_rate=self.dropout_rate,
        cardinality=self.cardinality,
        embedding_dimension=self.embedding_dimension,
        lags_seq=self.lags_seq,
        scaling=self.scaling,
    )
    net = DeepVARPredictionNetwork(**network_kwargs).to(device)
    copy_parameters(trained_network, net)

    forward_inputs = get_module_forward_input_names(net)
    test_splitter = self.create_instance_splitter("test")

    return PyTorchPredictor(
        input_transform=transformation + test_splitter,
        input_names=forward_inputs,
        prediction_net=net,
        batch_size=self.trainer.batch_size,
        freq=self.freq,
        prediction_length=self.prediction_length,
        device=device,
        output_transform=self.output_transform,
    )
def create_predictor(
    self,
    transformation: Transformation,
    trained_network: LSTNetTrain,
    device: torch.device,
) -> PyTorchPredictor:
    """Wrap a trained LSTNet into a ``PyTorchPredictor``.

    An ``LSTNetPredict`` network is built from this estimator's settings and
    moved to ``device``. It is then populated from ``trained_network`` via
    ``copy_parameters``.

    Args:
        transformation: Input transformation; the ``"test"`` instance
            splitter is chained after it.
        trained_network: Trained LSTNet supplying the parameters.
        device: Target device for the network and predictor.

    Returns:
        A ``PyTorchPredictor`` whose ``prediction_length`` is ``self.horizon``
        when that is truthy, otherwise ``self.prediction_length``.
    """
    hyperparams = {
        "num_series": self.num_series,
        "channels": self.channels,
        "kernel_size": self.kernel_size,
        "rnn_cell_type": self.rnn_cell_type,
        "rnn_num_cells": self.rnn_num_cells,
        "skip_rnn_cell_type": self.skip_rnn_cell_type,
        "skip_rnn_num_cells": self.skip_rnn_num_cells,
        "skip_size": self.skip_size,
        "ar_window": self.ar_window,
        "context_length": self.context_length,
        "horizon": self.horizon,
        "prediction_length": self.prediction_length,
        "dropout_rate": self.dropout_rate,
        "output_activation": self.output_activation,
        "scaling": self.scaling,
    }
    net = LSTNetPredict(**hyperparams).to(device)
    copy_parameters(trained_network, net)

    forward_inputs = get_module_forward_input_names(net)
    test_splitter = self.create_instance_splitter("test")

    return PyTorchPredictor(
        input_transform=transformation + test_splitter,
        input_names=forward_inputs,
        prediction_net=net,
        batch_size=self.trainer.batch_size,
        freq=self.freq,
        # `or` (not an `is None` check) is deliberate: any falsy horizon
        # falls back to prediction_length, as in the original.
        prediction_length=self.horizon or self.prediction_length,
        device=device,
    )