class TemporalFusionTransformerNetwork(HybridBlock):
    @validated()
    def __init__(
        self,
        context_length: int,
        prediction_length: int,
        d_var: int,
        d_hidden: int,
        n_head: int,
        n_output: int,
        d_past_feat_dynamic_real: List[int],
        c_past_feat_dynamic_cat: List[int],
        d_feat_dynamic_real: List[int],
        c_feat_dynamic_cat: List[int],
        d_feat_static_real: List[int],
        c_feat_static_cat: List[int],
        dropout: float = 0.0,
        **kwargs,
    ):
        super(TemporalFusionTransformerNetwork, self).__init__(**kwargs)
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.d_var = d_var
        self.d_hidden = d_hidden
        self.n_head = n_head
        self.n_output = n_output
        # e.g. n_output=3 -> [0.5, 0.1, 0.9]; n_output=5 -> [0.5, 0.1, 0.9, 0.2, 0.8]
        self.quantiles = sum(
            [[i / 10, 1.0 - i / 10] for i in range(1, (n_output + 1) // 2)],
            [0.5],
        )
        self.normalize_eps = 1e-5
        self.d_past_feat_dynamic_real = d_past_feat_dynamic_real
        self.c_past_feat_dynamic_cat = c_past_feat_dynamic_cat
        self.d_feat_dynamic_real = d_feat_dynamic_real
        self.c_feat_dynamic_cat = c_feat_dynamic_cat
        self.d_feat_static_real = d_feat_static_real
        self.c_feat_static_cat = c_feat_static_cat
        self.n_past_feat_dynamic = len(self.d_past_feat_dynamic_real) + len(
            self.c_past_feat_dynamic_cat
        )
        self.n_feat_dynamic = len(self.d_feat_dynamic_real) + len(
            self.c_feat_dynamic_cat
        )
        self.n_feat_static = len(self.d_feat_static_real) + len(
            self.c_feat_static_cat
        )

        with self.name_scope():
            self.target_proj = nn.Dense(
                units=self.d_var,
                in_units=1,
                flatten=False,
                prefix="target_projection_",
            )
            if self.d_past_feat_dynamic_real:
                self.past_feat_dynamic_proj = FeatureProjector(
                    feature_dims=self.d_past_feat_dynamic_real,
                    embedding_dims=[self.d_var]
                    * len(self.d_past_feat_dynamic_real),
                    prefix="past_feat_dynamic_",
                )
            else:
                self.past_feat_dynamic_proj = None
            if self.c_past_feat_dynamic_cat:
                self.past_feat_dynamic_embed = FeatureEmbedder(
                    cardinalities=self.c_past_feat_dynamic_cat,
                    embedding_dims=[self.d_var]
                    * len(self.c_past_feat_dynamic_cat),
                    prefix="past_feat_dynamic_",
                )
            else:
                self.past_feat_dynamic_embed = None
            if self.d_feat_dynamic_real:
                self.feat_dynamic_proj = FeatureProjector(
                    feature_dims=self.d_feat_dynamic_real,
                    embedding_dims=[self.d_var] * len(self.d_feat_dynamic_real),
                    prefix="feat_dynamic_",
                )
            else:
                self.feat_dynamic_proj = None
            if self.c_feat_dynamic_cat:
                self.feat_dynamic_embed = FeatureEmbedder(
                    cardinalities=self.c_feat_dynamic_cat,
                    embedding_dims=[self.d_var] * len(self.c_feat_dynamic_cat),
                    prefix="feat_dynamic_",
                )
            else:
                self.feat_dynamic_embed = None
            if self.d_feat_static_real:
                self.feat_static_proj = FeatureProjector(
                    feature_dims=self.d_feat_static_real,
                    embedding_dims=[self.d_var] * len(self.d_feat_static_real),
                    prefix="feat_static_",
                )
            else:
                self.feat_static_proj = None
            if self.c_feat_static_cat:
                self.feat_static_embed = FeatureEmbedder(
                    cardinalities=self.c_feat_static_cat,
                    embedding_dims=[self.d_var] * len(self.c_feat_static_cat),
                    prefix="feat_static_",
                )
            else:
                self.feat_static_embed = None

            self.static_selector = VariableSelectionNetwork(
                d_hidden=self.d_var,
                n_vars=self.n_feat_static,
                dropout=dropout,
            )
            self.ctx_selector = VariableSelectionNetwork(
                d_hidden=self.d_var,
                n_vars=self.n_past_feat_dynamic + self.n_feat_dynamic + 1,
                add_static=True,
                dropout=dropout,
            )
            self.tgt_selector = VariableSelectionNetwork(
                d_hidden=self.d_var,
                n_vars=self.n_feat_dynamic,
                add_static=True,
                dropout=dropout,
            )
            self.selection = GatedResidualNetwork(
                d_hidden=self.d_var,
                dropout=dropout,
            )
            self.enrichment = GatedResidualNetwork(
                d_hidden=self.d_var,
                dropout=dropout,
            )
            self.state_h = GatedResidualNetwork(
                d_hidden=self.d_var,
                d_output=self.d_hidden,
                dropout=dropout,
            )
            self.state_c = GatedResidualNetwork(
                d_hidden=self.d_var,
                d_output=self.d_hidden,
                dropout=dropout,
            )
            self.temporal_encoder = TemporalFusionEncoder(
                context_length=self.context_length,
                prediction_length=self.prediction_length,
                d_input=self.d_var,
                d_hidden=self.d_hidden,
            )
            self.temporal_decoder = TemporalFusionDecoder(
                context_length=self.context_length,
                prediction_length=self.prediction_length,
                d_hidden=self.d_hidden,
                d_var=self.d_var,
                n_head=self.n_head,
                dropout=dropout,
            )
            self.output = QuantileOutput(quantiles=self.quantiles)
            self.output_proj = self.output.get_quantile_proj()
            self.loss = self.output.get_loss()

    def _preprocess(
        self,
        F,
        past_target: Tensor,
        past_observed_values: Tensor,
        past_feat_dynamic_real: Tensor,
        past_feat_dynamic_cat: Tensor,
        feat_dynamic_real: Tensor,
        feat_dynamic_cat: Tensor,
        feat_static_real: Tensor,
        feat_static_cat: Tensor,
    ):
        # Normalize the target with the mean and standard deviation of the
        # observed part of the context window.
        obs = F.broadcast_mul(past_target, past_observed_values)
        count = F.sum(past_observed_values, axis=1, keepdims=True)
        offset = F.broadcast_div(
            F.sum(obs, axis=1, keepdims=True),
            count + self.normalize_eps,
        )
        scale = F.broadcast_div(
            F.sum(obs ** 2, axis=1, keepdims=True),
            count + self.normalize_eps,
        )
        scale = F.broadcast_sub(scale, offset ** 2)
        scale = F.sqrt(scale)
        past_target = F.broadcast_div(
            F.broadcast_sub(past_target, offset),
            scale + self.normalize_eps,
        )
        past_target = F.expand_dims(past_target, axis=-1)

        past_covariates = []
        future_covariates = []
        static_covariates: List[Tensor] = []

        proj = self.target_proj(past_target)
        past_covariates.append(proj)
        if self.past_feat_dynamic_proj is not None:
            projs = self.past_feat_dynamic_proj(past_feat_dynamic_real)
            past_covariates.extend(projs)
        if self.past_feat_dynamic_embed is not None:
            embs = self.past_feat_dynamic_embed(past_feat_dynamic_cat)
            past_covariates.extend(embs)
        if self.feat_dynamic_proj is not None:
            projs = self.feat_dynamic_proj(feat_dynamic_real)
            for proj in projs:
                ctx_proj = F.slice_axis(
                    proj, axis=1, begin=0, end=self.context_length
                )
                tgt_proj = F.slice_axis(
                    proj, axis=1, begin=self.context_length, end=None
                )
                past_covariates.append(ctx_proj)
                future_covariates.append(tgt_proj)
        if self.feat_dynamic_embed is not None:
            embs = self.feat_dynamic_embed(feat_dynamic_cat)
            for emb in embs:
                ctx_emb = F.slice_axis(
                    emb, axis=1, begin=0, end=self.context_length
                )
                tgt_emb = F.slice_axis(
                    emb, axis=1, begin=self.context_length, end=None
                )
                past_covariates.append(ctx_emb)
                future_covariates.append(tgt_emb)
        if self.feat_static_proj is not None:
            projs = self.feat_static_proj(feat_static_real)
            static_covariates.extend(projs)
        if self.feat_static_embed is not None:
            embs = self.feat_static_embed(feat_static_cat)
            static_covariates.extend(embs)

        return (
            past_covariates,
            future_covariates,
            static_covariates,
            offset,
            scale,
        )

    def _postprocess(
        self,
        F,
        preds: Tensor,
        offset: Tensor,
        scale: Tensor,
    ) -> Tensor:
        # Undo the normalization applied in `_preprocess`.
        offset = F.expand_dims(offset, axis=-1)
        scale = F.expand_dims(scale, axis=-1)
        preds = F.broadcast_add(
            F.broadcast_mul(preds, (scale + self.normalize_eps)),
            offset,
        )
        return preds

    def _forward(
        self,
        F,
        past_observed_values: Tensor,
        past_covariates: Tensor,
        future_covariates: Tensor,
        static_covariates: Tensor,
    ):
        static_var, _ = self.static_selector(static_covariates)
        c_selection = self.selection(static_var).expand_dims(axis=1)
        c_enrichment = self.enrichment(static_var).expand_dims(axis=1)
        c_h = self.state_h(static_var)
        c_c = self.state_c(static_var)

        ctx_input, _ = self.ctx_selector(past_covariates, c_selection)
        tgt_input, _ = self.tgt_selector(future_covariates, c_selection)

        encoding = self.temporal_encoder(ctx_input, tgt_input, [c_h, c_c])
        decoding = self.temporal_decoder(
            encoding, c_enrichment, past_observed_values
        )
        preds = self.output_proj(decoding)
        return preds
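
# Illustrative construction sketch (not part of the original module): the
# argument values below are hypothetical and chosen only to show how the
# feature-dimension lists map onto the constructor; in practice an estimator
# wrapper would derive these from the dataset's features.
if __name__ == "__main__":
    _example_tft = TemporalFusionTransformerNetwork(
        context_length=24,
        prediction_length=12,
        d_var=16,
        d_hidden=32,
        n_head=4,
        n_output=3,  # odd -> quantiles [0.5, 0.1, 0.9]
        d_past_feat_dynamic_real=[1],  # one real-valued, past-only feature
        c_past_feat_dynamic_cat=[],
        d_feat_dynamic_real=[1],  # one real-valued feature known in the future
        c_feat_dynamic_cat=[],
        d_feat_static_real=[],
        c_feat_static_cat=[10],  # one static categorical with 10 categories
        dropout=0.1,
    )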
class SelfAttentionNetwork(HybridBlock):
    @validated()
    def __init__(
        self,
        context_length: int,
        prediction_length: int,
        d_hidden: int,
        m_ffn: int,
        n_head: int,
        n_layers: int,
        n_output: int,
        cardinalities: List[int],
        kernel_sizes: Optional[List[int]],
        dist_enc: Optional[str],
        pre_ln: bool,
        dropout: float,
        temperature: float,
        normalizer_eps: float = 1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if kernel_sizes is None or len(kernel_sizes) == 0:
            self.kernel_sizes = (1,)
        else:
            self.kernel_sizes = kernel_sizes
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.d_hidden = d_hidden
        assert (n_output % 2 == 1) and (n_output <= 9)
        # e.g. n_output=3 -> [0.5, 0.1, 0.9]; n_output=5 -> [0.5, 0.1, 0.9, 0.2, 0.8]
        self.quantiles = sum(
            ([i / 10, 1.0 - i / 10] for i in range(1, (n_output + 1) // 2)),
            [0.5],
        )
        self.normalizer_eps = normalizer_eps

        with self.name_scope():
            self._blocks = []
            for layer in range(n_layers):
                block = SelfAttentionBlock(
                    d_hidden=self.d_hidden,
                    m_ffn=m_ffn,
                    kernel_sizes=self.kernel_sizes,
                    n_head=n_head,
                    dist_enc=dist_enc,
                    pre_ln=pre_ln,
                    dropout=dropout,
                    temperature=temperature,
                )
                self.register_child(block=block, name=f"block_{layer + 1}")
                self._blocks.append(block)
            self.target_proj = nn.Dense(
                units=self.d_hidden,
                in_units=1,
                use_bias=True,
                flatten=False,
                weight_initializer=init.Xavier(),
                prefix="target_proj_",
            )
            self.covar_proj = nn.Dense(
                units=self.d_hidden,
                use_bias=True,
                flatten=False,
                weight_initializer=init.Xavier(),
                prefix="covar_proj_",
            )
            if cardinalities:
                self.embedder = FeatureEmbedder(
                    cardinalities=cardinalities,
                    embedding_dims=[self.d_hidden] * len(cardinalities),
                    prefix="embedder_",
                )
            self.output = QuantileOutput(quantiles=self.quantiles)
            self.output_proj = self.output.get_quantile_proj()
            self.loss = self.output.get_loss()

    def _preprocess(
        self,
        F,
        past_target: Tensor,
        past_observed_values: Tensor,
        past_is_pad: Tensor,
        past_feat_dynamic_real: Tensor,
        past_feat_dynamic_cat: Tensor,
        future_target: Tensor,
        future_feat_dynamic_real: Tensor,
        future_feat_dynamic_cat: Tensor,
        feat_static_real: Tensor,
        feat_static_cat: Tensor,
    ) -> Tuple[
        Tensor,
        Optional[Tensor],
        Tensor,
        Tensor,
        Optional[Tensor],
        Tensor,
        Tensor,
    ]:
        # Normalize targets with the mean and standard deviation of the
        # observed part of the context window.
        obs = past_target * past_observed_values
        count = F.sum(past_observed_values, axis=1, keepdims=True)
        offset = F.sum(obs, axis=1, keepdims=True) / (
            count + self.normalizer_eps
        )
        scale = F.sum(obs ** 2, axis=1, keepdims=True) / (
            count + self.normalizer_eps
        )
        scale = scale - offset ** 2
        scale = scale.sqrt()
        past_target = (past_target - offset) / (scale + self.normalizer_eps)
        if future_target is not None:
            future_target = (future_target - offset) / (
                scale + self.normalizer_eps
            )

        def _assemble_covariates(
            feat_dynamic_real: Tensor,
            feat_dynamic_cat: Tensor,
            feat_static_real: Tensor,
            feat_static_cat: Tensor,
            is_past: bool,
        ) -> Tensor:
            covariates = []
            if feat_dynamic_real.shape[-1] > 0:
                covariates.append(feat_dynamic_real)
            if feat_static_real.shape[-1] > 0:
                covariates.append(
                    feat_static_real.expand_dims(axis=1).repeat(
                        axis=1,
                        repeats=self.context_length
                        if is_past
                        else self.prediction_length,
                    )
                )
            if len(covariates) > 0:
                covariates = F.concat(*covariates, dim=-1)
                covariates = self.covar_proj(covariates)
            else:
                covariates = None

            categories = []
            if feat_dynamic_cat.shape[-1] > 0:
                categories.append(feat_dynamic_cat)
            if feat_static_cat.shape[-1] > 0:
                categories.append(
                    feat_static_cat.expand_dims(axis=1).repeat(
                        axis=1,
                        repeats=self.context_length
                        if is_past
                        else self.prediction_length,
                    )
                )
            if len(categories) > 0:
                categories = F.concat(*categories, dim=-1)
                embeddings = self.embedder(categories)
                embeddings = F.reshape(
                    embeddings, shape=(0, 0, -4, self.d_hidden, -1)
                ).sum(axis=-1)
                if covariates is not None:
                    covariates = covariates + embeddings
                else:
                    covariates = embeddings

            return covariates

        past_covariates = _assemble_covariates(
            past_feat_dynamic_real,
            past_feat_dynamic_cat,
            feat_static_real,
            feat_static_cat,
            is_past=True,
        )
        future_covariates = _assemble_covariates(
            future_feat_dynamic_real,
            future_feat_dynamic_cat,
            feat_static_real,
            feat_static_cat,
            is_past=False,
        )
        past_observed_values = F.broadcast_logical_and(
            past_observed_values,
            F.logical_not(past_is_pad),
        )

        return (
            past_target,
            past_covariates,
            past_observed_values,
            future_target,
            future_covariates,
            offset,
            scale,
        )

    def _postprocess(
        self, F, preds: Tensor, offset: Tensor, scale: Tensor
    ) -> Tensor:
        # Undo the normalization applied in `_preprocess`.
        offset = F.expand_dims(offset, axis=-1)
        scale = F.expand_dims(scale, axis=-1)
        preds = preds * (scale + self.normalizer_eps) + offset
        return preds

    def _forward_step(
        self,
        F,
        horizon: int,
        target: Tensor,
        covars: Optional[Tensor],
        mask: Tensor,
    ) -> Tensor:
        target = F.expand_dims(target, axis=-1)
        mask = F.expand_dims(mask, axis=-1)
        value = self.target_proj(target)
        if covars is not None:
            value = value + covars
        for block in self._blocks:
            value = block(value, mask)
        value = F.slice_axis(value, axis=1, begin=-horizon, end=None)
        preds = self.output_proj(value)
        return preds
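
# Illustrative construction sketch (not part of the original module): all
# values are placeholders, picked only to satisfy the constraint that
# `n_output` is odd and at most 9.
if __name__ == "__main__":
    _example_attention_net = SelfAttentionNetwork(
        context_length=24,
        prediction_length=12,
        d_hidden=32,
        m_ffn=64,  # passed through to each SelfAttentionBlock
        n_head=4,
        n_layers=2,
        n_output=3,  # odd and <= 9 -> quantiles [0.5, 0.1, 0.9]
        cardinalities=[10],  # one categorical feature with 10 categories
        kernel_sizes=[3, 5],
        dist_enc=None,
        pre_ln=True,
        dropout=0.1,
        temperature=1.0,
    )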
def __init__(
    self,
    prediction_length: int,
    freq: str,
    context_length: Optional[int] = None,
    decoder_mlp_dim_seq: List[int] = None,
    trainer: Trainer = Trainer(),
    quantiles: Optional[List[float]] = None,
    distr_output: Optional[DistributionOutput] = None,
    scaling: bool = False,
    scaling_decoder_dynamic_feature: bool = False,
) -> None:
    assert (
        prediction_length > 0
    ), f"Invalid prediction length: {prediction_length}."
    assert decoder_mlp_dim_seq is None or all(
        d > 0 for d in decoder_mlp_dim_seq
    ), "Elements of `decoder_mlp_dim_seq` should be > 0"
    assert quantiles is None or all(
        0 <= d <= 1 for d in quantiles
    ), "Elements of `quantiles` should be >= 0 and <= 1"

    self.decoder_mlp_dim_seq = (
        decoder_mlp_dim_seq if decoder_mlp_dim_seq is not None else [30]
    )
    self.quantiles = (
        quantiles
        if (quantiles is not None) or (distr_output is not None)
        else [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    )

    # `use_static_feat` and `use_dynamic_feat` are always True because the
    # network always receives input, either from the data or from constants.
    encoder = RNNEncoder(
        mode="gru",
        hidden_size=50,
        num_layers=1,
        bidirectional=True,
        prefix="encoder_",
        use_static_feat=True,
        use_dynamic_feat=True,
    )
    decoder = ForkingMLPDecoder(
        dec_len=prediction_length,
        final_dim=self.decoder_mlp_dim_seq[-1],
        hidden_dimension_sequence=self.decoder_mlp_dim_seq[:-1],
        prefix="decoder_",
    )
    quantile_output = (
        QuantileOutput(self.quantiles) if self.quantiles else None
    )

    super().__init__(
        encoder=encoder,
        decoder=decoder,
        quantile_output=quantile_output,
        distr_output=distr_output,
        freq=freq,
        prediction_length=prediction_length,
        context_length=context_length,
        trainer=trainer,
        scaling=scaling,
        scaling_decoder_dynamic_feature=scaling_decoder_dynamic_feature,
    )
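
# Illustrative usage sketch, kept as a comment because the enclosing class is
# not shown in this excerpt; the signature matches GluonTS's RNN-based forking
# seq2seq estimator (MQRNNEstimator -- an assumption, not confirmed above):
#
#     estimator = MQRNNEstimator(
#         prediction_length=24,
#         freq="H",
#         context_length=4 * 24,
#         quantiles=[0.1, 0.5, 0.9],
#         trainer=Trainer(epochs=10),
#     )
#
# Leaving both `quantiles` and `distr_output` unset falls back to the nine
# default quantiles 0.1 ... 0.9 defined above.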
def __init__(
    self,
    freq: str,
    prediction_length: int,
    context_length: Optional[int] = None,
    use_past_feat_dynamic_real: bool = False,
    use_feat_dynamic_real: bool = False,
    use_feat_static_cat: bool = False,
    cardinality: List[int] = None,
    embedding_dimension: List[int] = None,
    add_time_feature: bool = True,
    add_age_feature: bool = False,
    enable_encoder_dynamic_feature: bool = True,
    enable_decoder_dynamic_feature: bool = True,
    seed: Optional[int] = None,
    decoder_mlp_dim_seq: Optional[List[int]] = None,
    channels_seq: Optional[List[int]] = None,
    dilation_seq: Optional[List[int]] = None,
    kernel_size_seq: Optional[List[int]] = None,
    use_residual: bool = True,
    quantiles: Optional[List[float]] = None,
    distr_output: Optional[DistributionOutput] = None,
    trainer: Trainer = Trainer(),
    scaling: bool = False,
    scaling_decoder_dynamic_feature: bool = False,
    num_forking: Optional[int] = None,
) -> None:
    assert (distr_output is None) or (quantiles is None)
    assert (
        prediction_length > 0
    ), f"Invalid prediction length: {prediction_length}."
    assert decoder_mlp_dim_seq is None or all(
        d > 0 for d in decoder_mlp_dim_seq
    ), "Elements of `decoder_mlp_dim_seq` should be > 0"
    assert channels_seq is None or all(
        d > 0 for d in channels_seq
    ), "Elements of `channels_seq` should be > 0"
    assert dilation_seq is None or all(
        d > 0 for d in dilation_seq
    ), "Elements of `dilation_seq` should be > 0"
    # TODO: add support for kernel size=1
    assert kernel_size_seq is None or all(
        d > 1 for d in kernel_size_seq
    ), "Elements of `kernel_size_seq` should be > 1"
    assert quantiles is None or all(
        0 <= d <= 1 for d in quantiles
    ), "Elements of `quantiles` should be >= 0 and <= 1"

    self.decoder_mlp_dim_seq = (
        decoder_mlp_dim_seq if decoder_mlp_dim_seq is not None else [30]
    )
    self.channels_seq = (
        channels_seq if channels_seq is not None else [30, 30, 30]
    )
    self.dilation_seq = (
        dilation_seq if dilation_seq is not None else [1, 3, 9]
    )
    self.kernel_size_seq = (
        kernel_size_seq if kernel_size_seq is not None else [7, 3, 3]
    )
    self.quantiles = (
        quantiles
        if (quantiles is not None) or (distr_output is not None)
        else [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    )

    assert (
        len(self.channels_seq)
        == len(self.dilation_seq)
        == len(self.kernel_size_seq)
    ), (
        f"mismatch CNN configurations: {len(self.channels_seq)} vs. "
        f"{len(self.dilation_seq)} vs. {len(self.kernel_size_seq)}"
    )

    if seed:
        np.random.seed(seed)
        mx.random.seed(seed, trainer.ctx)

    # `use_static_feat` and `use_dynamic_feat` are always True because the
    # network always receives input, either from the data or from constants.
    encoder = HierarchicalCausalConv1DEncoder(
        dilation_seq=self.dilation_seq,
        kernel_size_seq=self.kernel_size_seq,
        channels_seq=self.channels_seq,
        use_residual=use_residual,
        use_static_feat=True,
        use_dynamic_feat=True,
        prefix="encoder_",
    )
    decoder = ForkingMLPDecoder(
        dec_len=prediction_length,
        final_dim=self.decoder_mlp_dim_seq[-1],
        hidden_dimension_sequence=self.decoder_mlp_dim_seq[:-1],
        prefix="decoder_",
    )
    quantile_output = (
        QuantileOutput(self.quantiles) if self.quantiles else None
    )

    super().__init__(
        encoder=encoder,
        decoder=decoder,
        quantile_output=quantile_output,
        distr_output=distr_output,
        freq=freq,
        prediction_length=prediction_length,
        context_length=context_length,
        use_past_feat_dynamic_real=use_past_feat_dynamic_real,
        use_feat_dynamic_real=use_feat_dynamic_real,
        use_feat_static_cat=use_feat_static_cat,
        enable_encoder_dynamic_feature=enable_encoder_dynamic_feature,
        enable_decoder_dynamic_feature=enable_decoder_dynamic_feature,
        cardinality=cardinality,
        embedding_dimension=embedding_dimension,
        add_time_feature=add_time_feature,
        add_age_feature=add_age_feature,
        trainer=trainer,
        scaling=scaling,
        scaling_decoder_dynamic_feature=scaling_decoder_dynamic_feature,
        num_forking=num_forking,
    )