Пример #1
    def _standardize_var_nn(self,
                            var_nn: Union[torch.nn.Module, Sequence],
                            var_type: str,
                            top_level: bool = False) -> torch.nn.Module:
        """
        Convert a user-supplied variance-network specification into module form.

        :param var_nn: With ``top_level=True``: a torch.nn.ModuleList (returned untouched), a single callable, a
         single ``('alias', args_or_kwargs)`` pair, or a sequence of callables/pairs. With ``top_level=False``: one
         callable or one ``('alias', args_or_kwargs)`` pair, where alias is 'per_group' or 'seasonal' and
         args_or_kwargs is a dict of keyword-arguments or a sequence of positional arguments (for 'per_group' a bare
         int is also accepted and treated as a one-element tuple).
        :param var_type: If 'measure', alias-built networks are sized by ``len(self.measures)``; otherwise by
         ``len(self.dynamic_state_elements)``.
        :param top_level: If True, standardize each entry and return them wrapped in a torch.nn.ModuleList. If
         False (the default), standardize a single entry.
        :return: A torch.nn.ModuleList when ``top_level=True``; otherwise a single callable that is guaranteed to
         have ``_forward_kwargs`` and ``_forward_kwargs_aliases`` attributes.
        :raises ValueError: If the alias is neither 'per_group' nor 'seasonal'.
        :raises TypeError: If a single entry is neither a callable nor a tuple/list.
        """
        if top_level:
            if isinstance(var_nn, torch.nn.ModuleList):
                return var_nn

            if callable(var_nn):
                # they passed a single NN instead of a list, wrap it:
                var_nn = [var_nn]
            elif len(var_nn) > 0 and isinstance(var_nn[0], str):
                # they passed a single alias instead of a list, wrap it:
                var_nn = [var_nn]

            return torch.nn.ModuleList([
                self._standardize_var_nn(sub_nn, var_type) for sub_nn in var_nn
            ])

        if callable(var_nn):
            out_nn = var_nn
        elif isinstance(var_nn, (tuple, list)):
            alias, args_or_kwargs = var_nn
            num_outputs = len(self.measures if var_type == 'measure' else self.dynamic_state_elements)
            if alias == 'per_group' and isinstance(args_or_kwargs, int):
                # shorthand: a bare int stands for a single positional argument
                args_or_kwargs = (args_or_kwargs, )
            if isinstance(args_or_kwargs, dict):
                args, kwargs = (), args_or_kwargs
            else:
                args, kwargs = args_or_kwargs, {}

            if alias == 'per_group':
                if 'embedding_dim' not in kwargs:
                    kwargs['embedding_dim'] = num_outputs
                out_nn = NamedEmbedding(*args, **kwargs)
                out_nn._forward_kwargs_aliases = {'input': 'group_names'}
            elif alias == 'seasonal':
                # NOTE(review): the remaining arguments of this call were lost from the source; only `*args` was
                # visible. Confirm whether FourierSeasonNN also needs an output-size argument (e.g. `num_outputs`).
                out_nn = FourierSeasonNN(*args, **kwargs)
                out_nn._time_split_kwargs = ['datetimes']
            else:
                raise ValueError(
                    f"Known aliases are 'per_group' and 'seasonal'; got '{alias}'"
                )
        else:
            raise TypeError(
                f"Expected `{var_type}_var_nn` to be a callable/torch.nn.Module, or a tuple with format "
                f"`('alias',(arg1,arg2,...)`. Instead got `{type(var_nn)}`."
            )

        # make sure downstream code can always rely on these two attributes being present:
        if not hasattr(out_nn, '_forward_kwargs'):
            out_nn._forward_kwargs = infer_forward_kwargs(out_nn)
        if not hasattr(out_nn, '_forward_kwargs_aliases'):
            out_nn._forward_kwargs_aliases = {}
        return out_nn
Пример #2
    def __init__(self,
                 id: str,
                 input_dim: int,
                 state_dim: int,
                 nn: torch.nn.Module,
                 process_variance: bool = False,
                 decay: Union[bool, Tuple[float, float]] = False,
                 time_split_kwargs: Sequence[str] = (),
                 initial_state: Optional[torch.nn.Module] = None):
        """
        :param id: A unique identifier for the process.
        :param input_dim: The number of inputs to the nn.
        :param state_dim: The number of outputs of the nn.
        :param nn: A torch.nn.Module that takes a (num_groups, input_dim) Tensor, and outputs a (num_groups, state_dim)
        Tensor.
        :param process_variance: If False (the default), then the uncertainty about the values of the states does not
        grow at each timestep, so over time these eventually converge to a certain value. If True, then the latent-
        states are allowed to 'drift' over time.
        :param decay: If True, then in forecasts (or for missing data) the state-values will tend to shrink towards
        zero. Usually only used if `process_variance=True`. Default False. Instead of `True` you can specify custom-
        bounds for the decay-rate as a tuple.
        :param time_split_kwargs: When calling the KalmanFilter, you will pass a prediction Tensor for your nn.Module
        that is (num_groups, num_timesteps, input_dim). However, internally, this will be split up into multiple
        tensors, and your nn.Module will take a (num_groups, input_dim) tensor. If your nn.Module's `forward()` method
        takes just a single argument, then we can infer how to split this tensor. But if it takes multiple keyword
        arguments, you need to specify which will be split in this fashion.
        :param initial_state: Optional, a callable (typically a torch.nn.Module). When the KalmanFilter is called,
        keyword-arguments can be passed to initial_state in the format `{this_process}_initial_state__{kwarg}`.
        """

        self.input_dim = input_dim
        self.nn = nn
        # attributes already present on the nn take precedence over what is inferred/passed here:
        if not hasattr(self.nn, '_forward_kwargs'):
            self.nn._forward_kwargs = infer_forward_kwargs(nn)
        if not hasattr(self.nn, '_time_split_kwargs'):
            assert set(time_split_kwargs).issubset(self.nn._forward_kwargs), \
                "`time_split_kwargs` must be a subset of the nn's forward keyword-arguments"
            self.nn._time_split_kwargs = time_split_kwargs

        self._has_process_variance = process_variance

        # one state-element per nn output, named by `zpad(i, pad_n)`:
        pad_n = len(str(state_dim))
        super().__init__(id=id, state_elements=[zpad(i, pad_n) for i in range(state_dim)], initial_state=initial_state)

        # decay:
        self.decays = {}
        if decay:
            # `True` means the default bounds; anything else truthy is taken as custom bounds
            bounds = (.95, 1.00) if decay is True else decay
            self.decays = {se: Bounded(*bounds) for se in self.state_elements}

        # each element transitions only to itself: by its decay if it has one, otherwise unchanged (1.0)
        for se in self.state_elements:
            se_decay = self.decays.get(se)
            self._set_transition(from_element=se, to_element=se, value=se_decay.get_value if se_decay else 1.0)
Пример #3
    def __init__(self,
                 id: str,
                 input_dim: int,
                 state_dim: int,
                 nn: torch.nn.Module,
                 init_variance: bool = True,
                 process_variance: bool = False,
                 add_module_params_to_process: bool = True,
                 inv_link: Optional[Callable] = None,
                 time_split_kwargs: Sequence[str] = ()):
        """
        :param id: A unique identifier for the process.
        :param input_dim: The number of inputs to the nn.
        :param state_dim: The number of outputs of the nn.
        :param nn: A torch.nn.Module that takes a (num_groups, input_dim) Tensor, and outputs a (num_groups, state_dim)
        Tensor.
        :param init_variance: If True (the default), then there is initial uncertainty about the values of the states.
        :param process_variance: If False (the default), then the uncertainty about the values of the states does not
        grow at each timestep, so over time these eventually converge to a certain value. If True, then the latent-
        states are allowed to 'drift' over time.
        :param add_module_params_to_process: If `False`, then you need to pass your nn.Module's `.parameters()` to the
        optimizer manually. This can be useful if you are using parameter-groups in your optimizer (e.g. for different
        learning rates).
        :param inv_link: An inverse link function that maps the linear-model to the prediction; default the identity.
        :param time_split_kwargs: When calling the KalmanFilter, you will pass a prediction Tensor for your nn.Module
        that is (num_groups, num_timesteps, input_dim). However, internally, this will be split up into multiple
        tensors, and your nn.Module will take a (num_groups, input_dim) tensor. If your nn.Module's `forward()` method
        takes just a single argument, then we can infer how to split this tensor. But if it takes multiple keyword
        arguments, you need to specify which will be split in this fashion.
        """
        self.inv_link = inv_link

        self.add_module_params_to_process = add_module_params_to_process
        self.input_dim = input_dim
        self.nn = nn
        # attributes already present on the nn take precedence over what is inferred/passed here:
        if not hasattr(self.nn, '_forward_kwargs'):
            self.nn._forward_kwargs = infer_forward_kwargs(nn)
        if not hasattr(self.nn, '_time_split_kwargs'):
            assert set(time_split_kwargs).issubset(self.nn._forward_kwargs), \
                "`time_split_kwargs` must be a subset of the nn's forward keyword-arguments"
            self.nn._time_split_kwargs = time_split_kwargs

        self._has_process_variance = process_variance
        self._has_init_variance = init_variance

        # one state-element per nn output, named by `zpad(i, pad_n)`:
        pad_n = len(str(state_dim))
        super().__init__(id=id, state_elements=[zpad(i, pad_n) for i in range(state_dim)])

        # each element transitions only to itself, unchanged:
        for se in self.state_elements:
            self._set_transition(from_element=se, to_element=se, value=1.0)
Пример #4
 def init_mean_kwargs(self) -> Iterable[str]:
     """
     Return the keyword-argument names recorded on ``self.initial_state`` as ``_forward_kwargs``.

     If ``self.initial_state`` does not have a ``_forward_kwargs`` attribute yet, it is computed with
     ``infer_forward_kwargs`` and stored on the object, so later calls reuse the stored value.

     :return: The value of ``self.initial_state._forward_kwargs``.
     """
     if not hasattr(self.initial_state, '_forward_kwargs'):
         # NOTE(review): the argument of this call was lost from the source; reconstructed as the object that
         # receives the attribute -- confirm.
         self.initial_state._forward_kwargs = infer_forward_kwargs(self.initial_state)
     return self.initial_state._forward_kwargs