def _standardize_var_nn(self,
                        var_nn: Union[torch.nn.Module, Sequence],
                        var_type: str,
                        top_level: bool = False) -> torch.nn.Module:
    """
    Convert a user-supplied variance-network specification into module(s) carrying the bookkeeping
    attributes `_forward_kwargs` and `_forward_kwargs_aliases`.

    :param var_nn: With `top_level=True`: a `torch.nn.ModuleList` (returned as-is, its members are not
      re-processed), a single callable, a single alias-spec whose first element is a string, or a sequence of
      callables/alias-specs. With `top_level=False`: a single callable, or a two-element tuple/list
      `(alias, args_or_kwargs)` where alias is 'per_group' or 'seasonal' and `args_or_kwargs` is a sequence of
      positional arguments or a dict of keyword arguments (for 'per_group', a bare int is also accepted and
      treated as a single positional argument).
    :param var_type: If 'measure', the number of outputs for alias-built networks is `len(self.measures)`;
      otherwise it is `len(self.dynamic_state_elements)`. Also used in the error message.
    :param top_level: If True, standardize each entry and return them in a `torch.nn.ModuleList`. If False
      (the default), standardize a single entry.
    :return: A `torch.nn.ModuleList` if `top_level`, otherwise the standardized callable.
    :raises ValueError: If the alias is not 'per_group' or 'seasonal'.
    :raises TypeError: If a (non-top-level) `var_nn` is neither callable nor a tuple/list.
    """
    if top_level:
        if isinstance(var_nn, torch.nn.ModuleList):
            return var_nn
        if callable(var_nn):
            # they passed a single NN instead of a list, wrap it:
            var_nn = [var_nn]
        elif len(var_nn) > 0 and isinstance(var_nn[0], str):
            # they passed a single alias instead of a list, wrap it:
            var_nn = [var_nn]
        return torch.nn.ModuleList([self._standardize_var_nn(sub_nn, var_type) for sub_nn in var_nn])

    if callable(var_nn):
        out_nn = var_nn
    elif isinstance(var_nn, (tuple, list)):
        alias, args_or_kwargs = var_nn
        num_outputs = len(self.measures if var_type == 'measure' else self.dynamic_state_elements)

        if alias == 'per_group' and isinstance(args_or_kwargs, int):
            # shorthand: `('per_group', n)` means a single positional argument
            args_or_kwargs = (args_or_kwargs,)

        if isinstance(args_or_kwargs, dict):
            # copy, so that filling in defaults below never mutates the caller's dict (the same spec may be
            # reused for another call where `num_outputs` differs):
            args, kwargs = (), dict(args_or_kwargs)
        else:
            args, kwargs = args_or_kwargs, {}

        if alias == 'per_group':
            if 'embedding_dim' not in kwargs:
                kwargs['embedding_dim'] = num_outputs
            out_nn = NamedEmbedding(*args, **kwargs)
            out_nn._forward_kwargs_aliases = {'input': 'group_names'}
        elif alias == 'seasonal':
            out_nn = FourierSeasonNN(*args, **kwargs, num_outputs=num_outputs)
            out_nn._time_split_kwargs = ['datetimes']
        else:
            raise ValueError(f"Known aliases are 'per_group' and 'seasonal'; got '{alias}'")
    else:
        raise TypeError(
            f"Expected `{var_type}_var_nn` to be a callable/torch.nn.Module, or a tuple with format "
            f"`('alias',(arg1,arg2,...)`. Instead got `{type(var_nn)}`."
        )

    # fill in the bookkeeping attributes only if they weren't already provided:
    if not hasattr(out_nn, '_forward_kwargs'):
        out_nn._forward_kwargs = infer_forward_kwargs(out_nn)
    if not hasattr(out_nn, '_forward_kwargs_aliases'):
        out_nn._forward_kwargs_aliases = {}
    return out_nn
def __init__(self,
             id: str,
             input_dim: int,
             state_dim: int,
             nn: torch.nn.Module,
             process_variance: bool = False,
             decay: Union[bool, Tuple[float, float]] = False,
             time_split_kwargs: Sequence[str] = (),
             initial_state: Optional[torch.nn.Module] = None):
    """
    Set up a process whose states are driven by the outputs of a neural network.

    :param id: Unique identifier of this process.
    :param input_dim: Number of inputs the nn takes.
    :param state_dim: Number of outputs the nn produces; one state-element is created per output.
    :param nn: A torch.nn.Module mapping a (num_groups, input_dim) Tensor to a (num_groups, state_dim) Tensor.
    :param process_variance: With False (the default), uncertainty about the state-values does not grow from
      one timestep to the next, so the states eventually converge to a value. With True, the latent states
      are allowed to 'drift' over time.
    :param decay: With True, state-values tend to shrink towards zero in forecasts (or for missing data);
      usually only combined with `process_variance=True`. A tuple can be given instead of `True` to specify
      custom bounds for the decay-rate. Default False.
    :param time_split_kwargs: When calling the KalmanFilter, the prediction Tensor passed for the nn.Module is
      (num_groups, num_timesteps, input_dim); internally it is split into multiple tensors so that the
      nn.Module receives a (num_groups, input_dim) tensor. If the module's `forward()` takes a single
      argument, the argument to split can be inferred; if it takes several keyword arguments, list here the
      ones that should be split this way.
    :param initial_state: Optional callable (typically a torch.nn.Module). When the KalmanFilter is called,
      keyword-arguments can be routed to it using the format `{this_process}_initial_state__{kwarg}`.
    """
    self.input_dim = input_dim

    # attach the call-signature bookkeeping to the network, unless it already carries it:
    self.nn = nn
    if not hasattr(self.nn, '_forward_kwargs'):
        self.nn._forward_kwargs = infer_forward_kwargs(nn)
    if not hasattr(self.nn, '_time_split_kwargs'):
        assert set(time_split_kwargs) <= set(self.nn._forward_kwargs)
        self.nn._time_split_kwargs = time_split_kwargs

    self._has_process_variance = process_variance

    # state-elements are named by their zero-padded index:
    width = len(str(state_dim))
    element_names = [zpad(idx, width) for idx in range(state_dim)]
    super().__init__(id=id, state_elements=element_names, initial_state=initial_state)

    # decay: one bounded parameter per state-element (none if decay is disabled)
    self.decays = {}
    if decay:
        bounds = (.95, 1.00) if decay is True else decay
        self.decays = {name: Bounded(*bounds) for name in self.state_elements}

    # each state-element transitions only into itself, either with its decay or with a fixed 1.0:
    for name in self.state_elements:
        element_decay = self.decays.get(name)
        self._set_transition(
            from_element=name,
            to_element=name,
            value=element_decay.get_value if element_decay else 1.0
        )
def __init__(self,
             id: str,
             input_dim: int,
             state_dim: int,
             nn: torch.nn.Module,
             init_variance: bool = True,
             process_variance: bool = False,
             add_module_params_to_process: bool = True,
             inv_link: Optional[Callable] = None,
             time_split_kwargs: Sequence[str] = ()):
    """
    Set up a process whose states are driven by the outputs of a neural network.

    :param id: Unique identifier of this process.
    :param input_dim: Number of inputs the nn takes.
    :param state_dim: Number of outputs the nn produces; one state-element is created per output.
    :param nn: A torch.nn.Module mapping a (num_groups, input_dim) Tensor to a (num_groups, state_dim) Tensor.
    :param init_variance: With True (the default), there is initial uncertainty about the state-values.
    :param process_variance: With False (the default), uncertainty about the state-values does not grow from
      one timestep to the next, so the states eventually converge to a value. With True, the latent states
      are allowed to 'drift' over time.
    :param add_module_params_to_process: With `False`, the nn.Module's `.parameters()` must be passed to the
      optimizer manually; useful when the optimizer uses parameter-groups (e.g. for different learning
      rates).
    :param inv_link: Inverse link function mapping the linear-model to the prediction; default the identity.
    :param time_split_kwargs: When calling the KalmanFilter, the prediction Tensor passed for the nn.Module is
      (num_groups, num_timesteps, input_dim); internally it is split into multiple tensors so that the
      nn.Module receives a (num_groups, input_dim) tensor. If the module's `forward()` takes a single
      argument, the argument to split can be inferred; if it takes several keyword arguments, list here the
      ones that should be split this way.
    """
    self.inv_link = inv_link
    self.add_module_params_to_process = add_module_params_to_process
    self.input_dim = input_dim

    # attach the call-signature bookkeeping to the network, unless it already carries it:
    self.nn = nn
    if not hasattr(self.nn, '_forward_kwargs'):
        self.nn._forward_kwargs = infer_forward_kwargs(nn)
    if not hasattr(self.nn, '_time_split_kwargs'):
        assert set(time_split_kwargs) <= set(self.nn._forward_kwargs)
        self.nn._time_split_kwargs = time_split_kwargs

    self._has_process_variance = process_variance
    self._has_init_variance = init_variance

    # state-elements are named by their zero-padded index:
    width = len(str(state_dim))
    element_names = [zpad(idx, width) for idx in range(state_dim)]
    super().__init__(id=id, state_elements=element_names)

    # each state-element transitions only into itself, unchanged:
    for name in self.state_elements:
        self._set_transition(from_element=name, to_element=name, value=1.0)
def init_mean_kwargs(self) -> Iterable[str]:
    """
    Return the `_forward_kwargs` of `self.initial_state`.

    If `self.initial_state` does not have a `_forward_kwargs` attribute yet, it is computed with
    `infer_forward_kwargs` and stored on `self.initial_state` before being returned.
    """
    state_fn = self.initial_state
    if hasattr(state_fn, '_forward_kwargs'):
        return state_fn._forward_kwargs
    # first access: infer the keyword names and cache them on the callable itself
    state_fn._forward_kwargs = infer_forward_kwargs(state_fn)
    return state_fn._forward_kwargs