def __init__(
    self,
    freq: str,
    context_length: int,
    prediction_length: int,
    trainer: Trainer = Trainer(),
    num_layers: int = 1,
    num_cells: int = 50,
    cell_type: str = "lstm",
    num_eval_samples: int = 100,
    cardinality: List[int] = list([1]),
    embedding_dimension: int = 10,
    distr_output: DistributionOutput = StudentTOutput(),
) -> None:
    """Create a canonical estimator whose model is a single ``RNN`` block.

    An ``RNN`` is built from ``cell_type``/``num_layers``/``num_cells`` and
    handed to the base estimator as ``model`` with ``is_sequential=True``;
    every other argument is forwarded to the base estimator unchanged.

    Parameters
    ----------
    freq
        Frequency of the data, forwarded to the base estimator
        (presumably a pandas-style frequency string -- confirm upstream).
    context_length
        Forwarded to the base estimator.
    prediction_length
        Forwarded to the base estimator.
    trainer
        Training configuration forwarded to the base estimator.
    num_layers
        Number of recurrent layers in the ``RNN`` block.
    num_cells
        Passed as ``num_hidden`` to the ``RNN`` block.
    cell_type
        Passed as ``mode`` to the ``RNN`` block (default ``"lstm"``).
    num_eval_samples
        Forwarded to the base estimator.
    cardinality
        Forwarded to the base estimator as a fresh list copy.
    embedding_dimension
        Forwarded to the base estimator.
    distr_output
        Output distribution forwarded to the base estimator.
    """
    # NOTE(review): ``Trainer()`` and ``StudentTOutput()`` defaults are single
    # instances shared by every call that relies on them; consider switching
    # to ``Optional[...] = None`` defaults if either object is ever mutated.
    model = RNN(mode=cell_type, num_layers=num_layers, num_hidden=num_cells)
    super(CanonicalRNNEstimator, self).__init__(
        model=model,
        is_sequential=True,
        freq=freq,
        context_length=context_length,
        prediction_length=prediction_length,
        trainer=trainer,
        num_eval_samples=num_eval_samples,
        # The default list is created once at definition time; pass a copy so
        # nothing downstream can mutate the shared default (or the caller's
        # list) through this estimator.
        cardinality=list(cardinality),
        embedding_dimension=embedding_dimension,
        distr_output=distr_output,
    )
def __init__(
    self,
    mode: str,
    hidden_size: int,
    num_layers: int,
    bidirectional: bool,
    **kwargs,
) -> None:
    """Validate the sizes and create the wrapped ``RNN`` block as ``self.rnn``.

    Parameters
    ----------
    mode
        RNN cell kind, passed positionally to ``RNN``.
    hidden_size
        Hidden size passed to ``RNN``; must be greater than zero.
    num_layers
        Number of layers passed to ``RNN``; must be greater than zero.
    bidirectional
        Passed to ``RNN``.
    **kwargs
        Forwarded to the parent block's constructor.

    Raises
    ------
    AssertionError
        If ``num_layers`` or ``hidden_size`` is not greater than zero.
    """
    # Explicit raises rather than ``assert`` so the checks still run under
    # ``python -O``; AssertionError is kept so existing callers see the same
    # exception type and message as before.
    if not (num_layers > 0):
        raise AssertionError("`num_layers` value must be greater than zero")
    if not (hidden_size > 0):
        raise AssertionError("`hidden_size` value must be greater than zero")
    super().__init__(**kwargs)
    # Create the child inside name_scope so its parameters are prefixed with
    # this block's name.
    with self.name_scope():
        self.rnn = RNN(mode, hidden_size, num_layers, bidirectional)
def __init__(
    self,
    mode,
    num_hidden,
    num_layers,
    num_output,
    bidirectional=False,
    **kwargs,
):
    """Build an ``RNN`` followed by a dense projection to ``num_output`` units.

    Parameters
    ----------
    mode
        RNN cell kind, passed to ``RNN``.
    num_hidden
        Hidden size of each recurrent layer.
    num_layers
        Number of recurrent layers.
    num_output
        Number of output units of the decoder; also stored as
        ``self.num_output``.
    bidirectional
        Whether the RNN is bidirectional. This also sets the decoder's input
        width.
    **kwargs
        Forwarded to the parent block's constructor.
    """
    super(RNNModel, self).__init__(**kwargs)
    self.num_output = num_output
    # A bidirectional recurrent layer concatenates its forward and backward
    # outputs, so the decoder must accept 2 * num_hidden features in that case.
    # (Assumes ``RNN`` wraps MXNet Gluon recurrent layers -- confirm.)
    rnn_output_units = num_hidden * (2 if bidirectional else 1)
    with self.name_scope():
        self.rnn = RNN(
            mode=mode,
            num_hidden=num_hidden,
            num_layers=num_layers,
            bidirectional=bidirectional,
        )
        # flatten=False: project only the last axis, keeping leading axes.
        self.decoder = nn.Dense(
            num_output, in_units=rnn_output_units, flatten=False
        )
def __init__(
    self,
    mode: str,
    hidden_size: int,
    num_layers: int,
    bidirectional: bool,
    use_static_feat: bool = False,
    use_dynamic_feat: bool = False,
    **kwargs,
) -> None:
    """Validate sizes, record the configuration, and create ``self.rnn``.

    Every constructor argument except ``**kwargs`` is stored on the instance
    under the same name.

    Parameters
    ----------
    mode
        RNN cell kind, passed positionally to ``RNN``.
    hidden_size
        Hidden size passed to ``RNN``; must be greater than zero.
    num_layers
        Number of layers passed to ``RNN``; must be greater than zero.
    bidirectional
        Passed to ``RNN``.
    use_static_feat
        Stored as ``self.use_static_feat``. It is not used in this
        constructor; presumably it is read by the forward pass.
    use_dynamic_feat
        Stored as ``self.use_dynamic_feat``. It is not used in this
        constructor; presumably it is read by the forward pass.
    **kwargs
        Forwarded to the parent block's constructor.

    Raises
    ------
    AssertionError
        If ``num_layers`` or ``hidden_size`` is not greater than zero.
    """
    # Explicit raises rather than ``assert`` so the checks still run under
    # ``python -O``; AssertionError is kept so existing callers see the same
    # exception type and message as before.
    if not (num_layers > 0):
        raise AssertionError("`num_layers` value must be greater than zero")
    if not (hidden_size > 0):
        raise AssertionError("`hidden_size` value must be greater than zero")
    super().__init__(**kwargs)
    self.mode = mode
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.bidirectional = bidirectional
    self.use_static_feat = use_static_feat
    self.use_dynamic_feat = use_dynamic_feat
    # Create the child inside name_scope so its parameters are prefixed with
    # this block's name.
    with self.name_scope():
        self.rnn = RNN(mode, hidden_size, num_layers, bidirectional)