from typing import Any, List

# `Tensor` is GluonTS's MXNet tensor alias (mx.nd.NDArray or mx.sym.Symbol).


def _index_tensor(x: Tensor, item: Any) -> Tensor:
    """
    Index an MXNet tensor with a numpy-style item (integers, slices, and
    Ellipsis), emulating basic ``x[item]`` indexing via ``slice_axis`` and
    ``squeeze``.
    """
    squeeze: List[int] = []
    if not isinstance(item, tuple):
        item = (item,)

    saw_ellipsis = False

    for i, item_i in enumerate(item):
        axis = i - len(item) if saw_ellipsis else i
        if isinstance(item_i, int):
            if item_i != -1:
                x = x.slice_axis(axis=axis, begin=item_i, end=item_i + 1)
            else:
                x = x.slice_axis(axis=axis, begin=-1, end=None)
            squeeze.append(axis)
        elif item_i == slice(None):
            continue
        elif item_i == Ellipsis:
            saw_ellipsis = True
            continue
        elif isinstance(item_i, slice):
            assert item_i.step is None
            start = item_i.start if item_i.start is not None else 0
            x = x.slice_axis(axis=axis, begin=start, end=item_i.stop)
        else:
            raise RuntimeError(f"invalid indexing item: {item}")
    if len(squeeze):
        x = x.squeeze(axis=tuple(squeeze))
    return x
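
A minimal usage sketch (assuming mxnet is installed) of how _index_tensor emulates numpy-style indexing: an integer item slices and squeezes its axis, a slice keeps the axis, and after an Ellipsis the remaining items count from the right.

import mxnet as mx

x = mx.nd.arange(24).reshape((2, 3, 4))

# Like x[0]: integer index slices axis 0 and squeezes it -> shape (3, 4).
print(_index_tensor(x, 0).shape)

# Like x[:, 1:3]: a slice keeps its axis -> shape (2, 2, 4).
print(_index_tensor(x, (slice(None), slice(1, 3))).shape)

# Like x[..., 0]: after Ellipsis, items index from the right -> shape (2, 3).
print(_index_tensor(x, (Ellipsis, 0)).shape)
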
Example #2
    def get_gp_params(
        self,
        F,
        past_target: Tensor,
        past_time_feat: Tensor,
        feat_static_cat: Tensor,
    ) -> Tuple:
        """
        Returns the GP kernel hyper-parameters and the noise sigma for the model.

        Parameters
        ----------
        F : ModuleType
            A module that can either refer to the Symbol API or the NDArray
            API in MXNet.
        past_target : Tensor
            Training time series values of shape (batch_size, context_length).
        past_time_feat : Tensor
            Training features of shape (batch_size, context_length, num_features).
        feat_static_cat : Tensor
            Time series indices of shape (batch_size, 1).

        Returns
        -------
        Tuple
            Tuple of kernel hyper-parameters of length num_hyperparams,
            each a Tensor of shape (batch_size, 1, 1), and the model noise
            sigma, a Tensor of shape (batch_size, 1, 1).
        """
        # Shape (batch_size, num_hyperparams + 1)
        output = self.embedding(feat_static_cat.squeeze())
        kernel_args = self.proj_kernel_args(output)
        sigma = softplus(
            F,
            output.slice_axis(  # sigma is the last hyper-parameter
                axis=1,
                begin=self.num_hyperparams,
                end=self.num_hyperparams + 1,
            ),
        )
        if self.params_scaling:
            scalings = self.kernel_output.gp_params_scaling(
                F, past_target, past_time_feat)
            sigma = F.broadcast_mul(sigma, scalings[self.num_hyperparams])
            kernel_args = (F.broadcast_mul(kernel_arg, scaling)
                           for kernel_arg, scaling in zip(
                               kernel_args, scalings[0:self.num_hyperparams]))
        min_value = 1e-5
        max_value = 1e8
        kernel_args = (kernel_arg.clip(min_value,
                                       max_value).expand_dims(axis=2)
                       for kernel_arg in kernel_args)
        sigma = sigma.clip(min_value, max_value).expand_dims(axis=2)
        return kernel_args, sigma
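
For intuition, the self-contained sketch below (assuming mxnet is installed) mirrors only the shape handling of get_gp_params: an embedding-like output of shape (batch_size, num_hyperparams + 1) is split into kernel hyper-parameters and a softplus-transformed noise term, each clipped and expanded to (batch_size, 1, 1). The column-slicing stand-in for proj_kernel_args and the softplus helper are assumptions for illustration, not the GluonTS network itself.

import mxnet as mx

batch_size, num_hyperparams = 4, 2

# Stand-in for the embedding output: shape (batch_size, num_hyperparams + 1).
output = mx.nd.random.normal(shape=(batch_size, num_hyperparams + 1))

# Plain softplus via MXNet's "softrelu" activation; GluonTS uses its own helper.
def softplus(F, x):
    return F.Activation(x, act_type="softrelu")

# Hypothetical stand-in for proj_kernel_args: one column per hyper-parameter.
kernel_args = tuple(
    output.slice_axis(axis=1, begin=i, end=i + 1) for i in range(num_hyperparams)
)
# Sigma is the last column, made positive with softplus.
sigma = softplus(
    mx.nd,
    output.slice_axis(axis=1, begin=num_hyperparams, end=num_hyperparams + 1),
)

# Clip and expand to (batch_size, 1, 1), as in get_gp_params.
min_value, max_value = 1e-5, 1e8
kernel_args = tuple(
    arg.clip(min_value, max_value).expand_dims(axis=2) for arg in kernel_args
)
sigma = sigma.clip(min_value, max_value).expand_dims(axis=2)

print([arg.shape for arg in kernel_args], sigma.shape)
# [(4, 1, 1), (4, 1, 1)] (4, 1, 1)
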