Example #1
class TemporalFusionTransformerNetwork(HybridBlock):
    @validated()
    def __init__(
        self,
        context_length: int,
        prediction_length: int,
        d_var: int,
        d_hidden: int,
        n_head: int,
        n_output: int,
        d_past_feat_dynamic_real: List[int],
        c_past_feat_dynamic_cat: List[int],
        d_feat_dynamic_real: List[int],
        c_feat_dynamic_cat: List[int],
        d_feat_static_real: List[int],
        c_feat_static_cat: List[int],
        dropout: float = 0.0,
        **kwargs,
    ):
        super(TemporalFusionTransformerNetwork, self).__init__(**kwargs)
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.d_var = d_var
        self.d_hidden = d_hidden
        self.n_head = n_head
        self.n_output = n_output
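        # Quantile levels: the median plus symmetric pairs (i/10, 1 - i/10);
        # n_output is assumed to be odd, so this yields n_output levels in total.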
        self.quantiles = sum(
            [[i / 10, 1.0 - i / 10] for i in range(1, (n_output + 1) // 2)],
            [0.5],
        )
        self.normalize_eps = 1e-5

        self.d_past_feat_dynamic_real = d_past_feat_dynamic_real
        self.c_past_feat_dynamic_cat = c_past_feat_dynamic_cat
        self.d_feat_dynamic_real = d_feat_dynamic_real
        self.c_feat_dynamic_cat = c_feat_dynamic_cat
        self.d_feat_static_real = d_feat_static_real
        self.c_feat_static_cat = c_feat_static_cat
        self.n_past_feat_dynamic = len(self.d_past_feat_dynamic_real) + len(
            self.c_past_feat_dynamic_cat
        )
        self.n_feat_dynamic = len(self.d_feat_dynamic_real) + len(
            self.c_feat_dynamic_cat
        )
        self.n_feat_static = len(self.d_feat_static_real) + len(
            self.c_feat_static_cat
        )

        with self.name_scope():
            self.target_proj = nn.Dense(
                units=self.d_var,
                in_units=1,
                flatten=False,
                prefix="target_projection_",
            )
            if self.d_past_feat_dynamic_real:
                self.past_feat_dynamic_proj = FeatureProjector(
                    feature_dims=self.d_past_feat_dynamic_real,
                    embedding_dims=[self.d_var]
                    * len(self.d_past_feat_dynamic_real),
                    prefix="past_feat_dynamic_",
                )
            else:
                self.past_feat_dynamic_proj = None

            if self.c_past_feat_dynamic_cat:
                self.past_feat_dynamic_embed = FeatureEmbedder(
                    cardinalities=self.c_past_feat_dynamic_cat,
                    embedding_dims=[self.d_var]
                    * len(self.c_past_feat_dynamic_cat),
                    prefix="past_feat_dynamic_",
                )
            else:
                self.past_feat_dynamic_embed = None

            if self.d_feat_dynamic_real:
                self.feat_dynamic_proj = FeatureProjector(
                    feature_dims=self.d_feat_dynamic_real,
                    embedding_dims=[self.d_var]
                    * len(self.d_feat_dynamic_real),
                    prefix="feat_dynamic_",
                )
            else:
                self.feat_dynamic_proj = None

            if self.c_feat_dynamic_cat:
                self.feat_dynamic_embed = FeatureEmbedder(
                    cardinalities=self.c_feat_dynamic_cat,
                    embedding_dims=[self.d_var] * len(self.c_feat_dynamic_cat),
                    prefix="feat_dynamic_",
                )
            else:
                self.feat_dynamic_embed = None

            if self.d_feat_static_real:
                self.feat_static_proj = FeatureProjector(
                    feature_dims=self.d_feat_static_real,
                    embedding_dims=[self.d_var] * len(self.d_feat_static_real),
                    prefix="feat_static_",
                )
            else:
                self.feat_static_proj = None

            if self.c_feat_static_cat:
                self.feat_static_embed = FeatureEmbedder(
                    cardinalities=self.c_feat_static_cat,
                    embedding_dims=[self.d_var] * len(self.c_feat_static_cat),
                    prefix="feat_static_",
                )
            else:
                self.feat_static_embed = None

            self.static_selector = VariableSelectionNetwork(
                d_hidden=self.d_var,
                n_vars=self.n_feat_static,
                dropout=dropout,
            )
            self.ctx_selector = VariableSelectionNetwork(
                d_hidden=self.d_var,
                n_vars=self.n_past_feat_dynamic + self.n_feat_dynamic + 1,
                add_static=True,
                dropout=dropout,
            )
            self.tgt_selector = VariableSelectionNetwork(
                d_hidden=self.d_var,
                n_vars=self.n_feat_dynamic,
                add_static=True,
                dropout=dropout,
            )
            self.selection = GatedResidualNetwork(
                d_hidden=self.d_var,
                dropout=dropout,
            )
            self.enrichment = GatedResidualNetwork(
                d_hidden=self.d_var,
                dropout=dropout,
            )
            self.state_h = GatedResidualNetwork(
                d_hidden=self.d_var,
                d_output=self.d_hidden,
                dropout=dropout,
            )
            self.state_c = GatedResidualNetwork(
                d_hidden=self.d_var,
                d_output=self.d_hidden,
                dropout=dropout,
            )
            self.temporal_encoder = TemporalFusionEncoder(
                context_length=self.context_length,
                prediction_length=self.prediction_length,
                d_input=self.d_var,
                d_hidden=self.d_hidden,
            )
            self.temporal_decoder = TemporalFusionDecoder(
                context_length=self.context_length,
                prediction_length=self.prediction_length,
                d_hidden=self.d_hidden,
                d_var=self.d_var,
                n_head=self.n_head,
                dropout=dropout,
            )
            self.output = QuantileOutput(quantiles=self.quantiles)
            self.output_proj = self.output.get_quantile_proj()
            self.loss = self.output.get_loss()

    def _preprocess(
        self,
        F,
        past_target: Tensor,
        past_observed_values: Tensor,
        past_feat_dynamic_real: Tensor,
        past_feat_dynamic_cat: Tensor,
        feat_dynamic_real: Tensor,
        feat_dynamic_cat: Tensor,
        feat_static_real: Tensor,
        feat_static_cat: Tensor,
    ):
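        # Standardize the past target per series, using mean and variance computed
        # only over the observed (non-missing) time steps.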
        obs = F.broadcast_mul(past_target, past_observed_values)
        count = F.sum(past_observed_values, axis=1, keepdims=True)
        offset = F.broadcast_div(
            F.sum(obs, axis=1, keepdims=True),
            count + self.normalize_eps,
        )
        scale = F.broadcast_div(
            F.sum(obs ** 2, axis=1, keepdims=True),
            count + self.normalize_eps,
        )
        scale = F.broadcast_sub(scale, offset ** 2)
        scale = F.sqrt(scale)
        past_target = F.broadcast_div(
            F.broadcast_sub(past_target, offset),
            scale + self.normalize_eps,
        )
        past_target = F.expand_dims(past_target, axis=-1)

        past_covariates = []
        future_covariates = []
        static_covariates: List[Tensor] = []
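        # Project real-valued features and embed categorical features into
        # d_var-dimensional vectors; known dynamic features are split into the
        # context (past) and prediction (future) windows.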
        proj = self.target_proj(past_target)
        past_covariates.append(proj)
        if self.past_feat_dynamic_proj is not None:
            projs = self.past_feat_dynamic_proj(past_feat_dynamic_real)
            past_covariates.extend(projs)
        if self.past_feat_dynamic_embed is not None:
            embs = self.past_feat_dynamic_embed(past_feat_dynamic_cat)
            past_covariates.extend(embs)
        if self.feat_dynamic_proj is not None:
            projs = self.feat_dynamic_proj(feat_dynamic_real)
            for proj in projs:
                ctx_proj = F.slice_axis(
                    proj, axis=1, begin=0, end=self.context_length
                )
                tgt_proj = F.slice_axis(
                    proj, axis=1, begin=self.context_length, end=None
                )
                past_covariates.append(ctx_proj)
                future_covariates.append(tgt_proj)
        if self.feat_dynamic_embed is not None:
            embs = self.feat_dynamic_embed(feat_dynamic_cat)
            for emb in embs:
                ctx_emb = F.slice_axis(
                    emb, axis=1, begin=0, end=self.context_length
                )
                tgt_emb = F.slice_axis(
                    emb, axis=1, begin=self.context_length, end=None
                )
                past_covariates.append(ctx_emb)
                future_covariates.append(tgt_emb)

        if self.feat_static_proj is not None:
            projs = self.feat_static_proj(feat_static_real)
            static_covariates.extend(projs)
        if self.feat_static_embed is not None:
            embs = self.feat_static_embed(feat_static_cat)
            static_covariates.extend(embs)

        return (
            past_covariates,
            future_covariates,
            static_covariates,
            offset,
            scale,
        )

    def _postprocess(
        self,
        F,
        preds: Tensor,
        offset: Tensor,
        scale: Tensor,
    ) -> Tensor:
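        # Invert the standardization applied in _preprocess: rescale by the
        # standard deviation and add back the mean.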
        offset = F.expand_dims(offset, axis=-1)
        scale = F.expand_dims(scale, axis=-1)
        preds = F.broadcast_add(
            F.broadcast_mul(preds, (scale + self.normalize_eps)),
            offset,
        )
        return preds

    def _forward(
        self,
        F,
        past_observed_values: Tensor,
        past_covariates: Tensor,
        future_covariates: Tensor,
        static_covariates: Tensor,
    ):
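        # Static covariates produce the context vectors for variable selection and
        # enrichment, plus the initial (h, c) states of the temporal encoder.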
        static_var, _ = self.static_selector(static_covariates)
        c_selection = self.selection(static_var).expand_dims(axis=1)
        c_enrichment = self.enrichment(static_var).expand_dims(axis=1)
        c_h = self.state_h(static_var)
        c_c = self.state_c(static_var)

        ctx_input, _ = self.ctx_selector(past_covariates, c_selection)
        tgt_input, _ = self.tgt_selector(future_covariates, c_selection)

        encoding = self.temporal_encoder(ctx_input, tgt_input, [c_h, c_c])
        decoding = self.temporal_decoder(
            encoding, c_enrichment, past_observed_values
        )
        preds = self.output_proj(decoding)

        return preds
Example #2
    def __init__(
        self,
        context_length: int,
        prediction_length: int,
        d_hidden: int,
        m_ffn: int,
        n_head: int,
        n_layers: int,
        n_output: int,
        cardinalities: List[int],
        kernel_sizes: Optional[List[int]],
        dist_enc: Optional[str],
        pre_ln: bool,
        dropout: float,
        temperature: float,
        normalizer_eps: float = 1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if kernel_sizes is None or len(kernel_sizes) == 0:
            self.kernel_sizes = (1,)
        else:
            self.kernel_sizes = kernel_sizes
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.d_hidden = d_hidden
        assert (n_output % 2 == 1) and (n_output <= 9)
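        # Quantile levels: the median plus (n_output - 1) / 2 symmetric pairs,
        # e.g. n_output = 3 gives [0.5, 0.1, 0.9].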
        self.quantiles = sum(
            ([i / 10, 1.0 - i / 10] for i in range(1, (n_output + 1) // 2)),
            [0.5],
        )
        self.normalizer_eps = normalizer_eps

        with self.name_scope():
            self._blocks = []
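            # The attention blocks live in a plain Python list, so each one is
            # registered explicitly to expose its parameters to the parent block.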
            for layer in range(n_layers):
                block = SelfAttentionBlock(
                    d_hidden=self.d_hidden,
                    m_ffn=m_ffn,
                    kernel_sizes=self.kernel_sizes,
                    n_head=n_head,
                    dist_enc=dist_enc,
                    pre_ln=pre_ln,
                    dropout=dropout,
                    temperature=temperature,
                )
                self.register_child(block=block, name=f"block_{layer+1}")
                self._blocks.append(block)

            self.target_proj = nn.Dense(
                units=self.d_hidden,
                in_units=1,
                use_bias=True,
                flatten=False,
                weight_initializer=init.Xavier(),
                prefix="target_proj_",
            )
            self.covar_proj = nn.Dense(
                units=self.d_hidden,
                use_bias=True,
                flatten=False,
                weight_initializer=init.Xavier(),
                prefix="covar_proj_",
            )
            if cardinalities:
                self.embedder = FeatureEmbedder(
                    cardinalities=cardinalities,
                    embedding_dims=[self.d_hidden] * len(cardinalities),
                    prefix="embedder_",
                )
            self.output = QuantileOutput(quantiles=self.quantiles)
            self.output_proj = self.output.get_quantile_proj()
            self.loss = self.output.get_loss()
Example #3
class SelfAttentionNetwork(HybridBlock):
    @validated()
    def __init__(
        self,
        context_length: int,
        prediction_length: int,
        d_hidden: int,
        m_ffn: int,
        n_head: int,
        n_layers: int,
        n_output: int,
        cardinalities: List[int],
        kernel_sizes: Optional[List[int]],
        dist_enc: Optional[str],
        pre_ln: bool,
        dropout: float,
        temperature: float,
        normalizer_eps: float = 1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if kernel_sizes is None or len(kernel_sizes) == 0:
            self.kernel_sizes = (1,)
        else:
            self.kernel_sizes = kernel_sizes
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.d_hidden = d_hidden
        assert (n_output % 2 == 1) and (n_output <= 9)
        self.quantiles = sum(
            ([i / 10, 1.0 - i / 10] for i in range(1, (n_output + 1) // 2)),
            [0.5],
        )
        self.normalizer_eps = normalizer_eps

        with self.name_scope():
            self._blocks = []
            for layer in range(n_layers):
                block = SelfAttentionBlock(
                    d_hidden=self.d_hidden,
                    m_ffn=m_ffn,
                    kernel_sizes=self.kernel_sizes,
                    n_head=n_head,
                    dist_enc=dist_enc,
                    pre_ln=pre_ln,
                    dropout=dropout,
                    temperature=temperature,
                )
                self.register_child(block=block, name=f"block_{layer+1}")
                self._blocks.append(block)

            self.target_proj = nn.Dense(
                units=self.d_hidden,
                in_units=1,
                use_bias=True,
                flatten=False,
                weight_initializer=init.Xavier(),
                prefix="target_proj_",
            )
            self.covar_proj = nn.Dense(
                units=self.d_hidden,
                use_bias=True,
                flatten=False,
                weight_initializer=init.Xavier(),
                prefix="covar_proj_",
            )
            if cardinalities:
                self.embedder = FeatureEmbedder(
                    cardinalities=cardinalities,
                    embedding_dims=[self.d_hidden] * len(cardinalities),
                    prefix="embedder_",
                )
            self.output = QuantileOutput(quantiles=self.quantiles)
            self.output_proj = self.output.get_quantile_proj()
            self.loss = self.output.get_loss()

    def _preprocess(
        self,
        F,
        past_target: Tensor,
        past_observed_values: Tensor,
        past_is_pad: Tensor,
        past_feat_dynamic_real: Tensor,
        past_feat_dynamic_cat: Tensor,
        future_target: Tensor,
        future_feat_dynamic_real: Tensor,
        future_feat_dynamic_cat: Tensor,
        feat_static_real: Tensor,
        feat_static_cat: Tensor,
    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor],
               Optional[Tensor], Tensor, Tensor]:
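        # Standardize the past target (and, during training, the future target)
        # using the mean and standard deviation of the observed past values.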
        obs = past_target * past_observed_values
        count = F.sum(past_observed_values, axis=1, keepdims=True)
        offset = F.sum(obs, axis=1,
                       keepdims=True) / (count + self.normalizer_eps)
        scale = F.sum(obs**2, axis=1,
                      keepdims=True) / (count + self.normalizer_eps)
        scale = scale - offset**2
        scale = scale.sqrt()

        past_target = (past_target - offset) / (scale + self.normalizer_eps)
        if future_target is not None:
            future_target = (future_target - offset) / (scale +
                                                        self.normalizer_eps)

        def _assemble_covariates(
            feat_dynamic_real: Tensor,
            feat_dynamic_cat: Tensor,
            feat_static_real: Tensor,
            feat_static_cat: Tensor,
            is_past: bool,
        ) -> Tensor:
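            # Concatenate dynamic and broadcast static real-valued features, project
            # them to d_hidden, then add the summed embeddings of the categorical
            # features.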
            covariates = []
            if feat_dynamic_real.shape[-1] > 0:
                covariates.append(feat_dynamic_real)
            if feat_static_real.shape[-1] > 0:
                covariates.append(
                    feat_static_real.expand_dims(axis=1).repeat(
                        axis=1,
                        repeats=self.context_length
                        if is_past else self.prediction_length,
                    ))
            if len(covariates) > 0:
                covariates = F.concat(*covariates, dim=-1)
                covariates = self.covar_proj(covariates)
            else:
                covariates = None

            categories = []
            if feat_dynamic_cat.shape[-1] > 0:
                categories.append(feat_dynamic_cat)
            if feat_static_cat.shape[-1] > 0:
                categories.append(
                    feat_static_cat.expand_dims(axis=1).repeat(
                        axis=1,
                        repeats=self.context_length
                        if is_past else self.prediction_length,
                    ))
            if len(categories) > 0:
                categories = F.concat(*categories, dim=-1)
                embeddings = self.embedder(categories)
                embeddings = F.reshape(embeddings,
                                       shape=(0, 0, -4, self.d_hidden,
                                              -1)).sum(axis=-1)
                if covariates is not None:
                    covariates = covariates + embeddings
                else:
                    covariates = embeddings

            return covariates

        past_covariates = _assemble_covariates(
            past_feat_dynamic_real,
            past_feat_dynamic_cat,
            feat_static_real,
            feat_static_cat,
            is_past=True,
        )
        future_covariates = _assemble_covariates(
            future_feat_dynamic_real,
            future_feat_dynamic_cat,
            feat_static_real,
            feat_static_cat,
            is_past=False,
        )
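        # Treat padded positions as unobserved when building the observation mask.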
        past_observed_values = F.broadcast_logical_and(
            past_observed_values,
            F.logical_not(past_is_pad),
        )

        return (
            past_target,
            past_covariates,
            past_observed_values,
            future_target,
            future_covariates,
            offset,
            scale,
        )

    def _postprocess(self, F, preds: Tensor, offset: Tensor,
                     scale: Tensor) -> Tensor:
        offset = F.expand_dims(offset, axis=-1)
        scale = F.expand_dims(scale, axis=-1)
        preds = preds * (scale + self.normalizer_eps) + offset
        return preds

    def _forward_step(
        self,
        F,
        horizon: int,
        target: Tensor,
        covars: Optional[Tensor],
        mask: Tensor,
    ) -> Tensor:
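        # Project the target into the hidden space, add covariate embeddings, run
        # the self-attention stack, and emit quantile outputs for the last
        # `horizon` positions.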
        target = F.expand_dims(target, axis=-1)
        mask = F.expand_dims(mask, axis=-1)
        value = self.target_proj(target)
        if covars is not None:
            value = value + covars
        for block in self._blocks:
            value = block(value, mask)
        value = F.slice_axis(value, axis=1, begin=-horizon, end=None)
        preds = self.output_proj(value)
        return preds
Example #4
    def __init__(
        self,
        prediction_length: int,
        freq: str,
        context_length: Optional[int] = None,
        decoder_mlp_dim_seq: List[int] = None,
        trainer: Trainer = Trainer(),
        quantiles: Optional[List[float]] = None,
        distr_output: Optional[DistributionOutput] = None,
        scaling: bool = False,
        scaling_decoder_dynamic_feature: bool = False,
    ) -> None:

        assert (prediction_length >
                0), f"Invalid prediction length: {prediction_length}."
        assert decoder_mlp_dim_seq is None or all(
            d > 0 for d in decoder_mlp_dim_seq
        ), "Elements of `decoder_mlp_dim_seq` should be > 0"
        assert quantiles is None or all(
            0 <= d <= 1 for d in
            quantiles), "Elements of `quantiles` should be >= 0 and <= 1"

        self.decoder_mlp_dim_seq = (decoder_mlp_dim_seq if decoder_mlp_dim_seq
                                    is not None else [30])
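        # Default to the nine deciles unless explicit quantiles or a parametric
        # distribution output are provided.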
        self.quantiles = (quantiles if
                          (quantiles is not None) or (distr_output is not None)
                          else [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])

        # `use_static_feat` and `use_dynamic_feat` always True because network
        # always receives input; either from the input data or constants
        encoder = RNNEncoder(
            mode="gru",
            hidden_size=50,
            num_layers=1,
            bidirectional=True,
            prefix="encoder_",
            use_static_feat=True,
            use_dynamic_feat=True,
        )

        # The forking decoder emits a full prediction_length-step forecast from
        # every encoder time step.
        decoder = ForkingMLPDecoder(
            dec_len=prediction_length,
            final_dim=self.decoder_mlp_dim_seq[-1],
            hidden_dimension_sequence=self.decoder_mlp_dim_seq[:-1],
            prefix="decoder_",
        )

        quantile_output = (QuantileOutput(self.quantiles)
                           if self.quantiles else None)

        super().__init__(
            encoder=encoder,
            decoder=decoder,
            quantile_output=quantile_output,
            distr_output=distr_output,
            freq=freq,
            prediction_length=prediction_length,
            context_length=context_length,
            trainer=trainer,
            scaling=scaling,
            scaling_decoder_dynamic_feature=scaling_decoder_dynamic_feature,
        )
Example #5
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        context_length: Optional[int] = None,
        use_past_feat_dynamic_real: bool = False,
        use_feat_dynamic_real: bool = False,
        use_feat_static_cat: bool = False,
        cardinality: List[int] = None,
        embedding_dimension: List[int] = None,
        add_time_feature: bool = True,
        add_age_feature: bool = False,
        enable_encoder_dynamic_feature: bool = True,
        enable_decoder_dynamic_feature: bool = True,
        seed: Optional[int] = None,
        decoder_mlp_dim_seq: Optional[List[int]] = None,
        channels_seq: Optional[List[int]] = None,
        dilation_seq: Optional[List[int]] = None,
        kernel_size_seq: Optional[List[int]] = None,
        use_residual: bool = True,
        quantiles: Optional[List[float]] = None,
        distr_output: Optional[DistributionOutput] = None,
        trainer: Trainer = Trainer(),
        scaling: bool = False,
        scaling_decoder_dynamic_feature: bool = False,
        num_forking: Optional[int] = None,
    ) -> None:

        assert (distr_output is None) or (quantiles is None), (
            "Only one of `distr_output` and `quantiles` can be specified"
        )
        assert (prediction_length >
                0), f"Invalid prediction length: {prediction_length}."
        assert decoder_mlp_dim_seq is None or all(
            d > 0 for d in decoder_mlp_dim_seq
        ), "Elements of `decoder_mlp_dim_seq` should be > 0"
        assert channels_seq is None or all(
            d > 0
            for d in channels_seq), "Elements of `channels_seq` should be > 0"
        assert dilation_seq is None or all(
            d > 0
            for d in dilation_seq), "Elements of `dilation_seq` should be > 0"
        # TODO: add support for kernel size=1
        assert kernel_size_seq is None or all(
            d > 1 for d in
            kernel_size_seq), "Elements of `kernel_size_seq` should be > 1"
        assert quantiles is None or all(
            0 <= d <= 1 for d in
            quantiles), "Elements of `quantiles` should be >= 0 and <= 1"

        self.decoder_mlp_dim_seq = (decoder_mlp_dim_seq if decoder_mlp_dim_seq
                                    is not None else [30])
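        # Default temporal convolution stack: three layers of 30 channels with
        # dilations 1, 3, 9 and kernel sizes 7, 3, 3, giving a progressively wider
        # causal receptive field.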
        self.channels_seq = (channels_seq
                             if channels_seq is not None else [30, 30, 30])
        self.dilation_seq = (dilation_seq
                             if dilation_seq is not None else [1, 3, 9])
        self.kernel_size_seq = (kernel_size_seq
                                if kernel_size_seq is not None else [7, 3, 3])
        self.quantiles = (quantiles if
                          (quantiles is not None) or (distr_output is not None)
                          else [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])

        assert (len(self.channels_seq) == len(self.dilation_seq) == len(
            self.kernel_size_seq)), (
                f"mismatch CNN configurations: {len(self.channels_seq)} vs. "
                f"{len(self.dilation_seq)} vs. {len(self.kernel_size_seq)}")

        if seed is not None:
            np.random.seed(seed)
            mx.random.seed(seed, trainer.ctx)

        # `use_static_feat` and `use_dynamic_feat` always True because network
        # always receives input; either from the input data or constants
        encoder = HierarchicalCausalConv1DEncoder(
            dilation_seq=self.dilation_seq,
            kernel_size_seq=self.kernel_size_seq,
            channels_seq=self.channels_seq,
            use_residual=use_residual,
            use_static_feat=True,
            use_dynamic_feat=True,
            prefix="encoder_",
        )

        decoder = ForkingMLPDecoder(
            dec_len=prediction_length,
            final_dim=self.decoder_mlp_dim_seq[-1],
            hidden_dimension_sequence=self.decoder_mlp_dim_seq[:-1],
            prefix="decoder_",
        )

        quantile_output = (QuantileOutput(self.quantiles)
                           if self.quantiles else None)

        super().__init__(
            encoder=encoder,
            decoder=decoder,
            quantile_output=quantile_output,
            distr_output=distr_output,
            freq=freq,
            prediction_length=prediction_length,
            context_length=context_length,
            use_past_feat_dynamic_real=use_past_feat_dynamic_real,
            use_feat_dynamic_real=use_feat_dynamic_real,
            use_feat_static_cat=use_feat_static_cat,
            enable_encoder_dynamic_feature=enable_encoder_dynamic_feature,
            enable_decoder_dynamic_feature=enable_decoder_dynamic_feature,
            cardinality=cardinality,
            embedding_dimension=embedding_dimension,
            add_time_feature=add_time_feature,
            add_age_feature=add_age_feature,
            trainer=trainer,
            scaling=scaling,
            scaling_decoder_dynamic_feature=scaling_decoder_dynamic_feature,
            num_forking=num_forking,
        )