def __init__(self,
             id: str,
             covariates: Sequence[str],
             process_variance: Union[bool, Collection[str]] = False,
             decay: Union[bool, Dict[str, Tuple[float, float]], Tuple[float, float]] = False,
             initial_state: Optional[torch.nn.Module] = None):
    """
    Create a process with one state-element per covariate.

    :param id: A unique name for the process.
    :param covariates: The names of the predictors (a sequence of strings, not a single string).
    :param process_variance: If False (the default), uncertainty about the coefficients does not grow at each
     timestep, so the coefficients eventually converge to a fixed value. If True, all coefficients may 'drift' over
     time. Pass a collection of covariate names to allow drift for that subset only.
    :param decay: If True, coefficients shrink towards zero in forecasts (or for missing data), using default
     bounds of (.95, 1.00) for the decay-rate. A tuple gives custom bounds for every covariate; a dictionary maps
     individual covariates to their bounds (or to ready-made `Bounded` objects). Usually only used together with
     `process_variance`. Default False.
    :param initial_state: Optional, a callable (typically a torch.nn.Module). When the KalmanFilter is called,
     keyword-arguments can be passed to initial_state in the format `{this_process}_initial_state__{kwarg}`.
    :raises TypeError: If `covariates` is a single string.
    :raises ValueError: If `process_variance` names something that is not in `covariates`.
    """
    if isinstance(covariates, str):
        raise TypeError("`covariates` should be sequence of strings, not single string")

    # which state-elements have process-variance:
    if not process_variance:
        self._dynamic_state_elements = []
    else:
        if isinstance(process_variance, bool):
            self._dynamic_state_elements = covariates
        else:
            self._dynamic_state_elements = process_variance
        extras = set(self._dynamic_state_elements) - set(covariates)
        if extras:
            raise ValueError(f"`process_variance` includes items not in `covariates`:\n{extras}")

    super().__init__(id=id, state_elements=covariates, initial_state=initial_state)

    # build the per-covariate decay bounds:
    self.decays = {}
    if decay is True:
        self.decays = {name: Bounded(.95, 1.00) for name in covariates}
    elif decay and isinstance(decay, dict):
        assert set(decay).issubset(covariates)
        for name, bounds in decay.items():
            if isinstance(bounds, Bounded):
                self.decays[name] = bounds
            else:
                self.decays[name] = Bounded(*bounds)
    elif decay:
        self.decays = {name: Bounded(*decay) for name in covariates}

    # each coefficient transitions only to itself; covariates without decay keep a fixed transition of 1.0
    for name in covariates:
        bounded = self.decays.get(name)
        self_transition = bounded.get_value if bounded else 1.0
        self._set_transition(from_element=name, to_element=name, value=self_transition)
def __init__(self,
             id: str,
             seasonal_period: int,
             season_duration: int = 1,
             decay: Union[bool, Tuple[float, float]] = False,
             season_start: Optional[str] = None,
             dt_unit: Optional[str] = None):
    """
    Process representing discrete seasons.

    :param id: Unique name for this process.
    :param seasonal_period: The number of seasons (e.g. 7 for day_in_week).
    :param season_duration: The length of each season, default 1 time-step.
    :param decay: False for no decay, otherwise a (lower, upper) pair of bounds with lower > 0 and upper <= 1.
     Analogous to dampening a trend -- the state reverts to zero as we get further from the last observation.
     Useful when two processes capture the same seasonal pattern: one can be flexible but decaying, the other less
     variable but extrapolating into the future.
    :param season_start: A string that can be parsed into a datetime by `numpy.datetime64`. See DTTracker.
    :param dt_unit: Currently supports {'Y', 'D', 'h', 'm', 's'}. 'W' is experimentally supported.
    """
    # datetime bookkeeping is delegated to the tracker:
    self.dt_tracker = DTTracker(season_start=season_start, dt_unit=dt_unit, process_id=id)

    self.seasonal_period = seasonal_period
    self.season_duration = season_duration

    # state-elements: the measured season, followed by zero-padded names for the remaining seasons
    self.measured_name = 'measured'
    pad_n = len(str(seasonal_period))
    other_seasons = [str(i).rjust(pad_n, "0") for i in range(1, seasonal_period)]
    super().__init__(id=id, state_elements=[self.measured_name] + other_seasons)

    # every transition registered here is a zero placeholder; the real values are filled in per batch
    elements = self.state_elements
    for idx, element in enumerate(elements):
        self._set_transition(from_element=element, to_element=element, value=0.)
        if idx == 0:
            continue
        previous = elements[idx - 1]
        self._set_transition(from_element=previous, to_element=element, value=0.)
        if idx >= 2:
            self._set_transition(from_element=previous, to_element=self.measured_name, value=0.)

    if not decay:
        self.decay = None
    else:
        assert not isinstance(decay, bool), "decay should be floats of bounds (or False for no decay)"
        assert decay[0] > 0. and decay[1] <= 1.0
        self.decay = Bounded(*decay)
def __init__(self,
             id: str,
             input_dim: int,
             state_dim: int,
             nn: torch.nn.Module,
             process_variance: bool = False,
             decay: Union[bool, Tuple[float, float]] = False,
             time_split_kwargs: Sequence[str] = (),
             initial_state: Optional[torch.nn.Module] = None):
    """
    Create a process whose state-elements correspond to the outputs of a neural network.

    :param id: A unique identifier for the process.
    :param input_dim: The number of inputs to the nn.
    :param state_dim: The number of outputs of the nn.
    :param nn: A torch.nn.Module that takes a (num_groups, input_dim) Tensor, and outputs a
     (num_groups, state_dim) Tensor.
    :param process_variance: If False (the default), uncertainty about the state values does not grow at each
     timestep, so they eventually converge to a fixed value. If True, the latent states may 'drift' over time.
    :param decay: If True, state-values shrink towards zero in forecasts (or for missing data), using default
     bounds of (.95, 1.00) for the decay-rate; pass a tuple for custom bounds. Usually only used if
     `process_variance=True`. Default False.
    :param time_split_kwargs: When calling the KalmanFilter you pass a (num_groups, num_timesteps, input_dim)
     Tensor, which is split internally so that the nn receives (num_groups, input_dim) tensors. If the nn's
     `forward()` takes a single argument the split can be inferred; if it takes several keyword arguments, list
     here the ones that should be split this way.
    :param initial_state: Optional, a callable (typically a torch.nn.Module). When the KalmanFilter is called,
     keyword-arguments can be passed to initial_state in the format `{this_process}_initial_state__{kwarg}`.
    """
    self.input_dim = input_dim
    self.nn = nn

    # annotate the network with its forward-kwargs, unless the caller has already done so:
    if not hasattr(self.nn, '_forward_kwargs'):
        self.nn._forward_kwargs = infer_forward_kwargs(nn)
    if not hasattr(self.nn, '_time_split_kwargs'):
        assert set(time_split_kwargs).issubset(self.nn._forward_kwargs)
        self.nn._time_split_kwargs = time_split_kwargs

    self._has_process_variance = process_variance

    # one zero-padded state-element name per nn output:
    pad_n = len(str(state_dim))
    element_names = [zpad(i, pad_n) for i in range(state_dim)]
    super().__init__(id=id, state_elements=element_names, initial_state=initial_state)

    # decay bounds are shared by every state-element:
    self.decays = {}
    if decay is True:
        self.decays = {name: Bounded(.95, 1.00) for name in self.state_elements}
    elif decay:
        self.decays = {name: Bounded(*decay) for name in self.state_elements}

    for name in self.state_elements:
        bounded = self.decays.get(name)
        self_transition = bounded.get_value if bounded else 1.0
        self._set_transition(from_element=name, to_element=name, value=self_transition)
def __init__(self,
             id: str,
             decay_velocity: Union[bool, Tuple[float, float]] = (.95, 1.00),
             decay_position: Union[bool, Tuple[float, float]] = False,
             multi: float = 1.0,
             initial_state: Optional[Module] = None):
    """
    Create a process with a 'position' and a 'velocity' state-element.

    :param id: A unique identifier for this process.
    :param decay_velocity: If set, then the trend will decay to zero as we forecast out further. The default is
     to allow the trend to decay somewhere between .95 (moderate decay) and 1.00 (no decay), with the exact value
     being a learned parameter in the nn.Module. Pass False for no decay.
    :param decay_position: False for no decay, or a (lower, upper) pair of bounds. See `decay` in `LocalLevel`.
    :param multi: A multiplier on the trend, so that `next_position = position + multi * trend`. Reducing this
     to .1 can be helpful since the trend has such a large effect on the prediction, so that large values can lead
     to exploding gradients.
    :param initial_state: Optional, a callable (typically a torch.nn.Module). When the KalmanFilter is called,
     keyword-arguments can be passed to initial_state in the format `{this_process}_initial_state__{kwarg}`.
    :raises TypeError: If `decay_position` or `decay_velocity` is `True`; bounds must be given explicitly.
    """
    super().__init__(id=id, state_elements=['position', 'velocity'], initial_state=initial_state)

    self.decayed_transitions = {}

    # position and velocity follow the same rule: regress towards zero if bounds were given, else carry over as-is
    for element, bounds in (('position', decay_position), ('velocity', decay_velocity)):
        if not bounds:
            self._set_transition(from_element=element, to_element=element, value=1.0)
            continue
        if isinstance(bounds, bool):
            # `True` would otherwise fail on `bounds[0]` with an opaque "'bool' object is not subscriptable"
            raise TypeError(f"decay_{element} should be floats of bounds (or False for no decay)")
        assert bounds[0] > 0. and bounds[1] <= 1.
        bounded = Bounded(*bounds)
        self.decayed_transitions[element] = bounded
        self._set_transition(from_element=element, to_element=element, value=bounded.get_value)

    # setting a low arbitrary multiplier on velocity's impact can sometimes be helpful for training in practice:
    self._set_transition(from_element='velocity', to_element='position', value=multi)
def __init__(self,
             id: str,
             seasonal_period: int,
             season_duration: int = 1,
             decay: Union[bool, Tuple[float, float]] = False,
             dt_unit: Optional[str] = None,
             fixed: bool = False):
    """
    Create a discrete-season process.

    :param id: Unique name for this process.
    :param seasonal_period: The number of seasons (e.g. 7 for day_in_week).
    :param season_duration: The length of each season, default 1 time-step.
    :param decay: Optional (float, float) boundaries for decay (between 0 and 1). Analogous to dampening a
     trend -- the state reverts to zero as we get further from the last observation. Useful when two processes
     capture the same seasonal pattern: one can be flexible but decaying, the other less variable but
     extrapolating into the future.
    :param dt_unit: Required here, despite the `None` default. Currently supports {'Y', 'D', 'h', 'm', 's'}.
     'W' is experimentally supported.
    :param fixed: If True, then the seasonality does not vary over time, and this amounts to one-hot-encoding
     the seasons. Default False.
    :raises TypeError: If `dt_unit` is not passed.
    """
    self.seasonal_period = seasonal_period
    self.season_duration = season_duration
    self.fixed = fixed

    # `dt_unit` is optional for some seasonal processes, but not for this one
    if dt_unit is None:
        raise TypeError(f"Must pass `dt_unit` to {type(self).__name__}")
    self._dt_helper = DateTimeHelper(dt_unit=dt_unit)

    # state-elements: the measured season, followed by zero-padded names for the remaining seasons
    pad_n = len(str(seasonal_period))
    other_seasons = [zpad(i, pad_n) for i in range(1, seasonal_period)]
    super().__init__(id=id, state_elements=[self.measured_name] + other_seasons)

    # every transition registered here is a zero placeholder; the real values are filled in per batch
    elements = self.state_elements
    for idx, element in enumerate(elements):
        self._set_transition(from_element=element, to_element=element, value=0.)
        if idx == 0:
            continue
        previous = elements[idx - 1]
        self._set_transition(from_element=previous, to_element=element, value=0.)
        if idx >= 2:
            self._set_transition(from_element=previous, to_element=self.measured_name, value=0.)

    if not decay:
        self.decay = None
    else:
        assert not isinstance(decay, bool), "decay should be floats of bounds (or False for no decay)"
        assert decay[0] > 0. and decay[1] <= 1.0
        self.decay = Bounded(*decay)
def __init__(self,
             id: str,
             seasonal_period: Union[int, float],
             K: Union[int, float],
             decay: Union[bool, Tuple[float, float]] = False,
             season_start: Optional[str] = None,
             dt_unit: Optional[str] = None):
    """
    Create a seasonal process parameterized by a period and a `K` value.

    :param id: Unique name for this instance.
    :param seasonal_period: The seasonal period; stored as given.
    :param K: Stored as an int. A float is accepted only if it is integer-valued.
    :param decay: False for no decay, otherwise a (lower, upper) pair of bounds with lower > 0 and upper <= 1,
     which is wrapped in a `Bounded`.
    :param season_start: Forwarded to `DateTimeHelper` as `start_datetime`.
    :param dt_unit: Forwarded to `DateTimeHelper`.
    :raises TypeError: If `decay` is `True`; bounds must be given explicitly.
    """
    # season structure:
    self.seasonal_period = seasonal_period
    if isinstance(K, float):
        assert K.is_integer()
    self.K = int(K)

    self.decay = None
    if decay:
        if isinstance(decay, bool):
            # `True` would otherwise fail on `decay[0]` with an opaque "'bool' object is not subscriptable"
            raise TypeError("decay should be floats of bounds (or False for no decay)")
        assert decay[0] > 0. and decay[1] <= 1.0
        self.decay = Bounded(*decay)

    # `_setup` supplies the state-elements and the kwargs for each transition; the transitions can only be
    # registered after the base-class has been initialized
    state_elements, list_of_trans_kwargs = self._setup(decay=decay)

    super().__init__(id=id, state_elements=state_elements)

    self._dt_helper = DateTimeHelper(dt_unit=dt_unit, start_datetime=season_start)

    for trans_kwargs in list_of_trans_kwargs:
        self._set_transition(**trans_kwargs)
def __init__(self,
             id: str,
             seasonal_period: Union[int, float],
             K: Union[int, float],
             decay: Union[bool, Tuple[float, float]] = False,
             season_start: Union[str, None, bool] = None,
             dt_unit: Optional[str] = None):
    """
    Create a seasonal process parameterized by a period and a `K` value.

    :param id: Unique name for this instance; also passed to the `DTTracker` as `process_id`.
    :param seasonal_period: The seasonal period; stored as given.
    :param K: Stored as an int. A float is accepted only if it is integer-valued.
    :param decay: False for no decay, otherwise a (lower, upper) pair of bounds with lower > 0 and upper <= 1,
     which is wrapped in a `Bounded`.
    :param season_start: Forwarded to `DTTracker`.
    :param dt_unit: Forwarded to `DTTracker`.
    """
    # datetime bookkeeping is delegated to the tracker:
    self.dt_tracker = DTTracker(season_start=season_start, dt_unit=dt_unit, process_id=id)

    # season structure:
    self.seasonal_period = seasonal_period
    k_is_float = isinstance(K, float)
    if k_is_float:
        assert K.is_integer()
    self.K = int(K)

    self.decay: Optional[Bounded] = None
    if decay:
        assert not isinstance(decay, bool), "decay should be floats of bounds (or False for no decay)"
        lower_ok = decay[0] > 0.
        assert lower_ok and decay[1] <= 1.0
        self.decay = Bounded(*decay)

    # `_setup` supplies the state-elements and the kwargs for each transition; the transitions can only be
    # registered after the base-class has been initialized
    elements, transition_specs = self._setup(decay=decay)
    super().__init__(id=id, state_elements=elements)
    for spec in transition_specs:
        self._set_transition(**spec)
def __init__(self,
             id: str,
             decay: Union[bool, Tuple[float, float]] = False,
             initial_state: Optional[torch.nn.Module] = None):
    """
    Create a process with a single 'position' state-element.

    :param id: A unique identifier for this process.
    :param decay: False for no decay, or a pair of floats. If the process has decay, then the random walk will
     tend towards zero as we forecast out further (note that this means you should center your time-series, or
     you should include another process that does not have this property). Decay can be between 0 and 1, but
     values < .50 (or even .90) can often be too rapid and you will run into trouble with vanishing gradients.
     When passing a pair of floats, the nn.Module will assign a parameter representing the decay as a learned
     parameter somewhere between these bounds.
     TODO: support {process}__decay__{kwarg}
    :param initial_state: Optional, a callable (typically a torch.nn.Module). When the KalmanFilter is called,
     keyword-arguments can be passed to initial_state in the format `{this_process}_initial_state__{kwarg}`.
    :raises TypeError: If `decay` is `True`; bounds must be given explicitly.
    """
    super().__init__(id=id, state_elements=['position'], initial_state=initial_state)

    self.decay = None
    if not decay:
        self._set_transition(from_element='position', to_element='position', value=1.)
        return

    if isinstance(decay, bool):
        # `True` would otherwise fail on `decay[0]` with an opaque "'bool' object is not subscriptable"
        raise TypeError("decay should be floats of bounds (or False for no decay)")
    assert decay[0] >= -1. and decay[1] <= 1.
    self.decay = Bounded(*decay)
    self._set_transition(from_element='position', to_element='position', value=self.decay.get_value)
def __init__(self,
             id: str,
             decay_velocity: Union[bool, Tuple[float, float]] = (.95, 1.00),
             decay_position: Union[bool, Tuple[float, float]] = False):
    """
    Create a process with a 'position' and a 'velocity' state-element, where velocity feeds into position with
    a fixed transition of 1.0.

    :param id: A unique identifier for this process.
    :param decay_velocity: False for no decay, otherwise a (lower, upper) pair of bounds with lower > 0 and
     upper <= 1 for the velocity's self-transition. Default (.95, 1.00).
    :param decay_position: Same as `decay_velocity`, but for the position's self-transition. Default False.
    """
    super().__init__(id=id, state_elements=['position', 'velocity'])

    self._set_transition(from_element='velocity', to_element='position', value=1.0)

    self.decayed_transitions = {}
    decay_specs = (
        ('position', decay_position, "decay_position should be floats of bounds (or False for no decay)"),
        ('velocity', decay_velocity, "decay_velocity should be floats of bounds (or False for no decay)"),
    )
    for element, bounds, bool_msg in decay_specs:
        if not bounds:
            # no decay: the element carries over unchanged
            self._set_transition(from_element=element, to_element=element, value=1.0)
            continue
        assert not isinstance(bounds, bool), bool_msg
        assert bounds[0] > 0. and bounds[1] <= 1.
        bounded = Bounded(*bounds)
        self.decayed_transitions[element] = bounded
        self._set_transition(from_element=element,
                             to_element=element,
                             value=bounded.get_value,
                             inv_link=False)
def __init__(self, id: str, decay: Union[bool, Tuple[float, float]] = False):
    """
    Create a process with a single 'position' state-element.

    :param id: A unique identifier for this process.
    :param decay: False for no decay (the position's self-transition is fixed at 1), otherwise a (lower, upper)
     pair of bounds with lower >= -1 and upper <= 1, which is wrapped in a `Bounded`.
    """
    super().__init__(id=id, state_elements=['position'])

    self.decay: Optional[Bounded] = None
    if not decay:
        self._set_transition(from_element='position', to_element='position', value=1.)
        return

    assert not isinstance(decay, bool), "decay should be floats of bounds (or False for no decay)"
    lower_ok = decay[0] >= -1.
    assert lower_ok and decay[1] <= 1.
    self.decay = Bounded(*decay)
    self._set_transition(from_element='position',
                         to_element='position',
                         value=self.decay.get_value,
                         inv_link=False)
def __init__(self,
             id: str,
             seasonal_period: Union[int, float],
             K: Union[int, float],
             decay: Union[bool, Tuple[float, float]] = False,
             season_start: Optional[str] = None,
             dt_unit: Optional[str] = None):
    """
    :param id: Unique name for this instance.
    :param seasonal_period: The seasonal period (e.g. 24 for daily season in hourly data, 365.25 for yearly
     season in daily data).
    :param K: The "K" parameter of the fourier series, see `fourier_tensor`. A float is accepted only if it is
     integer-valued.
    :param decay: Optional (float, float) boundaries for decay (between 0 and 1); False for no decay. Analogous
     to dampening a trend -- the state will revert to zero as we get further from the last observation. This can
     be useful if two processes are capturing the same seasonal pattern: one can be more flexible, but with decay
     have a tendency to revert to zero, while the other is less variable but extrapolates into the future.
    :param season_start: A string that can be parsed into a datetime by `numpy.datetime64`. This is when the
     season starts, which is useful to specify if season boundaries are meaningful. It is important to specify if
     different groups in your dataset start on different dates; when calling the kalman-filter you'll pass an
     array of `start_datetimes` for group in the input, and this will be used to align the seasons for each
     group.
    :param dt_unit: Currently supports {'Y', 'D', 'h', 'm', 's'}. 'W' is experimentally supported.
    :raises TypeError: If `decay` is `True`; bounds must be given explicitly.
    """
    # season structure:
    self.seasonal_period = seasonal_period
    if isinstance(K, float):
        assert K.is_integer()
    self.K = int(K)

    self.decay = None
    if decay:
        if isinstance(decay, bool):
            # `True` would otherwise fail on `decay[0]` with an opaque "'bool' object is not subscriptable"
            raise TypeError("decay should be floats of bounds (or False for no decay)")
        assert decay[0] > 0. and decay[1] <= 1.0
        self.decay = Bounded(*decay)

    # `_setup` supplies the state-elements and the kwargs for each transition; the transitions can only be
    # registered after the base-class has been initialized
    state_elements, list_of_trans_kwargs = self._setup(decay=decay)

    super().__init__(id=id, state_elements=state_elements, season_start=season_start, dt_unit=dt_unit)

    for trans_kwargs in list_of_trans_kwargs:
        self._set_transition(**trans_kwargs)
def __init__(self, id: str, decay: Union[bool, Tuple[float, float]] = False):
    """
    Create a process with a single 'position' state-element.

    :param id: A unique identifier for this process.
    :param decay: False for no decay, or a pair of floats. If the process has decay, then the random walk will
     tend towards zero as we forecast out further (note that this means you should center your time-series, or
     you should include another process that does not have this property). Decay can be between 0 and 1, but
     values < .50 (or even .90) can often be too rapid and you will run into trouble with vanishing gradients.
     When passing a pair of floats, the nn.Module will assign a parameter representing the decay as a learned
     parameter somewhere between these bounds.
    :raises TypeError: If `decay` is `True`; bounds must be given explicitly.
    """
    super().__init__(id=id, state_elements=['position'])

    self.decay = None
    if not decay:
        self._set_transition(from_element='position', to_element='position', value=1.)
        return

    if isinstance(decay, bool):
        # `True` would otherwise fail on `decay[0]` with an opaque "'bool' object is not subscriptable"
        raise TypeError("decay should be floats of bounds (or False for no decay)")
    assert decay[0] >= -1. and decay[1] <= 1.
    self.decay = Bounded(*decay)
    self._set_transition(from_element='position', to_element='position', value=self.decay.get_value)
class Season(Process):
    def __init__(self,
                 id: str,
                 seasonal_period: int,
                 season_duration: int = 1,
                 decay: Union[bool, Tuple[float, float]] = False,
                 season_start: Optional[str] = None,
                 dt_unit: Optional[str] = None):
        """
        Process representing discrete seasons.

        :param id: Unique name for this process.
        :param seasonal_period: The number of seasons (e.g. 7 for day_in_week).
        :param season_duration: The length of each season, default 1 time-step.
        :param decay: False for no decay, otherwise a (lower, upper) pair of bounds with lower > 0 and
         upper <= 1. Analogous to dampening a trend -- the state reverts to zero as we get further from the last
         observation. Useful when two processes capture the same seasonal pattern: one can be flexible but
         decaying, the other less variable but extrapolating into the future.
        :param season_start: A string that can be parsed into a datetime by `numpy.datetime64`. See DTTracker.
        :param dt_unit: Currently supports {'Y', 'D', 'h', 'm', 's'}. 'W' is experimentally supported.
        """
        # datetime bookkeeping is delegated to the tracker:
        self.dt_tracker = DTTracker(season_start=season_start, dt_unit=dt_unit, process_id=id)

        self.seasonal_period = seasonal_period
        self.season_duration = season_duration

        # state-elements: the measured season, followed by zero-padded names for the remaining seasons
        self.measured_name = 'measured'
        pad_n = len(str(seasonal_period))
        other_seasons = [str(i).rjust(pad_n, "0") for i in range(1, seasonal_period)]
        super().__init__(id=id, state_elements=[self.measured_name] + other_seasons)

        # every transition registered here is a zero placeholder; the real values are filled in per batch
        elements = self.state_elements
        for idx, element in enumerate(elements):
            self._set_transition(from_element=element, to_element=element, value=0.)
            if idx == 0:
                continue
            previous = elements[idx - 1]
            self._set_transition(from_element=previous, to_element=element, value=0.)
            if idx >= 2:
                self._set_transition(from_element=previous, to_element=self.measured_name, value=0.)

        if not decay:
            self.decay = None
        else:
            assert not isinstance(decay, bool), "decay should be floats of bounds (or False for no decay)"
            assert decay[0] > 0. and decay[1] <= 1.0
            self.decay = Bounded(*decay)

    def add_measure(self, measure: str):
        """Link `measure` to the 'measured' state-element with a value of 1.0."""
        self._set_measure(measure=measure, state_element='measured', value=1.0)

    def parameters(self) -> Generator[Parameter, None, None]:
        """Yield the decay parameter, if this process has decay; otherwise yield nothing."""
        if self.decay is None:
            return
        yield self.decay.parameter

    @property
    def dynamic_state_elements(self) -> Sequence[str]:
        """Only the measured state-element is listed as dynamic."""
        return [self.measured_name]

    def for_batch(self,
                  num_groups: int,
                  num_timesteps: int,
                  start_datetimes: Optional[np.ndarray] = None) -> ProcessForBatch:
        """
        Build the batch-specific process, replacing the placeholder transitions with values that depend on where
        each group/timestep falls within its season.

        :param num_groups: Number of groups in the batch.
        :param num_timesteps: Number of timesteps in the batch.
        :param start_datetimes: Optional array passed through to the `DTTracker` when computing time-offsets.
        :return: The object returned by the base-class `for_batch`, with its transitions adjusted.
        """
        batch = super().for_batch(num_groups=num_groups, num_timesteps=num_timesteps)

        delta = self.dt_tracker.get_delta(batch.num_groups, batch.num_timesteps, start_datetimes=start_datetimes)

        # a season hands over to the next one on its final time-step:
        last_step_of_season = self.season_duration - 1
        in_transition = (delta % self.season_duration) == last_step_of_season

        to_next_state = torch.from_numpy(in_transition.astype('float32'))
        transitions = {
            'to_next_state': to_next_state,
            'to_self': 1 - to_next_state,
            'to_measured': -to_next_state,
            'from_measured_to_measured': torch.from_numpy(np.where(in_transition, -1., 1.).astype('float32')),
        }

        # split each transition along dim 1, applying decay to every piece if this process has it:
        for name in transitions.keys():
            pieces = split_flat(transitions[name], dim=1, clone=True)
            if self.decay is not None:
                decay_value = self.decay.get_value()
                pieces = [piece * decay_value for piece in pieces]
            transitions[name] = pieces

        # this is convoluted, but the idea is to manipulate the transitions so that we use one less degree of
        # freedom than the number of seasons, by having the 'measured' state be equal to -sum(all others)
        elements = self.state_elements
        for prev, current in zip(elements, elements[1:]):
            # the measured element requires a special-case for its transition back onto itself
            if prev == self.measured_name:
                to_measured = transitions['from_measured_to_measured']
            else:
                to_measured = transitions['to_measured']

            batch.adjust_transition(from_element=prev,
                                    to_element=current,
                                    adjustment=transitions['to_next_state'])
            batch.adjust_transition(from_element=prev,
                                    to_element=self.measured_name,
                                    adjustment=to_measured)
            # from state to itself:
            batch.adjust_transition(from_element=current,
                                    to_element=current,
                                    adjustment=transitions['to_self'])

        return batch

    def initial_state_means_for_batch(self,
                                      parameters: Parameter,
                                      num_groups: int,
                                      start_datetimes: Optional[np.ndarray] = None) -> Tensor:
        """
        Rotate the initial-state parameters for each group according to that group's season-offset.

        :param parameters: The initial-state mean parameters.
        :param num_groups: Number of groups in the batch.
        :param start_datetimes: Optional array passed through to the `DTTracker` when computing time-offsets.
        :return: A tensor stacking one rotated copy of `parameters` per group along dim 0.
        """
        delta = self.dt_tracker.get_delta(num_groups, 1, start_datetimes=start_datetimes).squeeze(1)
        seasons_elapsed = np.floor(delta / self.season_duration)
        season_shift = (seasons_elapsed % self.seasonal_period).astype('int')

        means = []
        for shift in season_shift:
            # move the last `shift` entries to the front
            rotated = torch.cat([parameters[-shift:], parameters[:-shift]])
            means.append(rotated)
        return torch.stack(means, 0)
def __init__(self,
             id: str,
             seasonal_period: int,
             season_duration: int = 1,
             decay: Union[bool, Tuple[float, float]] = False,
             season_start: Optional[str] = None,
             dt_unit: Optional[str] = None,
             fixed: bool = False):
    """
    Process representing discrete seasons.

    :param id: Unique name for this process.
    :param seasonal_period: The number of seasons (e.g. 7 for day_in_week).
    :param season_duration: The length of each season, default 1 time-step.
    :param decay: Optional (float, float) boundaries for decay (between 0 and 1); False for no decay.
     Analogous to dampening a trend -- the state reverts to zero as we get further from the last observation.
     Useful when two processes capture the same seasonal pattern: one can be flexible but decaying, the other
     less variable but extrapolating into the future.
    :param season_start: A string that can be parsed into a datetime by `numpy.datetime64`. This is when the
     season starts, which is useful to specify if season boundaries are meaningful. It is important to specify if
     different groups in your dataset start on different dates; when calling the kalman-filter you'll pass an
     array of `start_datetimes` for group in the input, and this will be used to align the seasons for each
     group.
    :param dt_unit: Currently supports {'Y', 'D', 'h', 'm', 's'}. 'W' is experimentally supported.
    :param fixed: If True, then the seasonality does not vary over time, and this amounts to one-hot-encoding
     the seasons. Default False.
    """
    self.seasonal_period = seasonal_period
    self.season_duration = season_duration
    self.fixed = fixed

    # state-elements: the measured season, followed by zero-padded names for the remaining seasons
    pad_n = len(str(seasonal_period))
    other_seasons = [zpad(i, pad_n) for i in range(1, seasonal_period)]
    super().__init__(id=id,
                     state_elements=[self.measured_name] + other_seasons,
                     season_start=season_start,
                     dt_unit=dt_unit)

    # every transition registered here is a zero placeholder; the real values are filled in per batch
    elements = self.state_elements
    for idx, element in enumerate(elements):
        self._set_transition(from_element=element, to_element=element, value=0.)
        if idx == 0:
            continue
        previous = elements[idx - 1]
        self._set_transition(from_element=previous, to_element=element, value=0.)
        if idx >= 2:
            self._set_transition(from_element=previous, to_element=self.measured_name, value=0.)

    if not decay:
        self.decay = None
    else:
        assert not isinstance(decay, bool), "decay should be floats of bounds (or False for no decay)"
        assert decay[0] > 0. and decay[1] <= 1.0
        self.decay = Bounded(*decay)
class Season(DatetimeProcess, Process):
    measured_name = 'measured'

    def __init__(self,
                 id: str,
                 seasonal_period: int,
                 season_duration: int = 1,
                 decay: Union[bool, Tuple[float, float]] = False,
                 season_start: Optional[str] = None,
                 dt_unit: Optional[str] = None,
                 fixed: bool = False):
        """
        Process representing discrete seasons.

        :param id: Unique name for this process.
        :param seasonal_period: The number of seasons (e.g. 7 for day_in_week).
        :param season_duration: The length of each season, default 1 time-step.
        :param decay: Optional (float, float) boundaries for decay (between 0 and 1); False for no decay.
         Analogous to dampening a trend -- the state reverts to zero as we get further from the last
         observation. Useful when two processes capture the same seasonal pattern: one can be flexible but
         decaying, the other less variable but extrapolating into the future.
        :param season_start: A string that can be parsed into a datetime by `numpy.datetime64`. This is when the
         season starts, which is useful to specify if season boundaries are meaningful. It is important to
         specify if different groups in your dataset start on different dates; when calling the kalman-filter
         you'll pass an array of `start_datetimes` for group in the input, and this will be used to align the
         seasons for each group.
        :param dt_unit: Currently supports {'Y', 'D', 'h', 'm', 's'}. 'W' is experimentally supported.
        :param fixed: If True, then the seasonality does not vary over time, and this amounts to
         one-hot-encoding the seasons. Default False.
        """
        self.seasonal_period = seasonal_period
        self.season_duration = season_duration
        self.fixed = fixed

        # state-elements: the measured season, followed by zero-padded names for the remaining seasons
        pad_n = len(str(seasonal_period))
        other_seasons = [zpad(i, pad_n) for i in range(1, seasonal_period)]
        super().__init__(id=id,
                         state_elements=[self.measured_name] + other_seasons,
                         season_start=season_start,
                         dt_unit=dt_unit)

        # every transition registered here is a zero placeholder; the real values are filled in per batch
        elements = self.state_elements
        for idx, element in enumerate(elements):
            self._set_transition(from_element=element, to_element=element, value=0.)
            if idx == 0:
                continue
            previous = elements[idx - 1]
            self._set_transition(from_element=previous, to_element=element, value=0.)
            if idx >= 2:
                self._set_transition(from_element=previous, to_element=self.measured_name, value=0.)

        if not decay:
            self.decay = None
        else:
            assert not isinstance(decay, bool), "decay should be floats of bounds (or False for no decay)"
            assert decay[0] > 0. and decay[1] <= 1.0
            self.decay = Bounded(*decay)

    def add_measure(self, measure: str) -> 'Season':
        """Link `measure` to the 'measured' state-element with a value of 1.0, and return this process."""
        self._set_measure(measure=measure, state_element='measured', value=1.0)
        return self

    def param_dict(self) -> ParameterDict:
        """Return a ParameterDict holding the decay parameter under 'decay', or empty if there is no decay."""
        params = ParameterDict()
        if self.decay is None:
            return params
        params['decay'] = self.decay.parameter
        return params

    @property
    def dynamic_state_elements(self) -> Sequence[str]:
        """The measured state-element, unless the process is `fixed`, in which case nothing is dynamic."""
        if self.fixed:
            return []
        return [self.measured_name]

    def for_batch(self,
                  num_groups: int,
                  num_timesteps: int,
                  start_datetimes: Optional[np.ndarray] = None):
        """
        Build the batch-specific process, replacing the placeholder transitions with values that depend on where
        each group/timestep falls within its season.

        :param num_groups: Number of groups in the batch.
        :param num_timesteps: Number of timesteps in the batch.
        :param start_datetimes: Optional 1D array with one entry per group, used when computing time-offsets.
        :return: The object returned by the parent `for_batch`, with its transitions adjusted.
        :raises ValueError: If `start_datetimes` is given but is not a 1D array of length `num_groups`.
        """
        if start_datetimes is not None:
            shape_ok = len(start_datetimes.shape) == 1 and len(start_datetimes) == num_groups
            if not shape_ok:
                raise ValueError(f"Expected `start_datetimes` to be 1D array of length {num_groups}.")

        batch = super().for_batch(num_groups=num_groups, num_timesteps=num_timesteps)

        delta = self._get_delta(num_groups, num_timesteps, start_datetimes=start_datetimes)

        # a season hands over to the next one on its final time-step:
        last_step_of_season = self.season_duration - 1
        in_transition = (delta % self.season_duration) == last_step_of_season

        to_next_state = torch.from_numpy(in_transition.astype('float32'))
        transitions = {
            'to_next_state': to_next_state,
            'from_measured_to_measured': torch.from_numpy(np.where(in_transition, -1., 1.).astype('float32')),
            'to_self': 1 - to_next_state,
            'to_measured': -to_next_state,
        }

        # split each transition along dim 1, applying decay to every piece if this process has it:
        for name in transitions.keys():
            pieces = split_flat(transitions[name], dim=1, clone=True)
            if self.decay is not None:
                decay_value = self.decay.get_value()
                pieces = [piece * decay_value for piece in pieces]
            transitions[name] = pieces

        # this is convoluted, but the idea is to manipulate the transitions so that we use one less degree of
        # freedom than the number of seasons, by having the 'measured' state be equal to -sum(all others)
        elements = self.state_elements
        for prev, current in zip(elements, elements[1:]):
            # the measured element requires a special-case for its transition back onto itself
            if prev == self.measured_name:
                to_measured = transitions['from_measured_to_measured']
            else:
                to_measured = transitions['to_measured']

            batch._adjust_transition(from_element=prev,
                                     to_element=current,
                                     adjustment=transitions['to_next_state'])
            batch._adjust_transition(from_element=prev,
                                     to_element=self.measured_name,
                                     adjustment=to_measured)
            # from state to itself:
            batch._adjust_transition(from_element=current,
                                     to_element=current,
                                     adjustment=transitions['to_self'])

        return batch

    def initial_state_means_for_batch(self,
                                      parameters: Parameter,
                                      num_groups: int,
                                      start_datetimes: Optional[np.ndarray] = None) -> Tensor:
        """
        Rotate the initial-state parameters for each group according to that group's season-offset.

        :param parameters: The initial-state mean parameters.
        :param num_groups: Number of groups in the batch.
        :param start_datetimes: Optional array used when computing time-offsets.
        :return: A tensor stacking one rotated copy of `parameters` per group along dim 0.
        """
        delta = self._get_delta(num_groups, 1, start_datetimes=start_datetimes).squeeze(1)
        seasons_elapsed = np.floor(delta / self.season_duration)
        season_shift = (seasons_elapsed % self.seasonal_period).astype('int')

        means = []
        for shift in season_shift:
            # move the last `shift` entries to the front
            rotated = torch.cat([parameters[-shift:], parameters[:-shift]])
            means.append(rotated)
        return torch.stack(means, 0)