Beispiel #1
0
    def __init__(self,
                 id: str,
                 covariates: Sequence[str],
                 process_variance: Union[bool, Collection[str]] = False,
                 decay: Union[bool, Dict[str, Tuple[float, float]],
                              Tuple[float, float]] = False,
                 initial_state: Optional[torch.nn.Module] = None):
        """
        Create one state-element per covariate; every element transitions only onto itself.

        :param id: A unique name for the process.
        :param covariates: Names of the predictors. A bare string is rejected.
        :param process_variance: With False (the default) the uncertainty about the coefficients does not grow from
        one timestep to the next, so they eventually settle on a value. With True all coefficients may 'drift' over
        time. Pass a collection of names to allow drift for only those covariates.
        :param decay: With True, the coefficients tend to shrink towards zero in forecasts (or for missing data);
        usually only combined with `process_variance=True`. Default False. A tuple is used as custom bounds for the
        decay-rate of every covariate; a dictionary maps individual predictors to their own bounds.
        :param initial_state: Optional, a callable (typically a torch.nn.Module). When the KalmanFilter is called,
        keyword-arguments can be passed to initial_state in the format `{this_process}_initial_state__{kwarg}`.
        :raises TypeError: If `covariates` is a single string.
        :raises ValueError: If `process_variance` names an element that is not in `covariates`.
        """
        if isinstance(covariates, str):
            raise TypeError(
                "`covariates` should be sequence of strings, not single string"
            )

        # which coefficients are allowed to drift:
        if not process_variance:
            self._dynamic_state_elements = []
        else:
            if isinstance(process_variance, bool):
                self._dynamic_state_elements = covariates
            else:
                self._dynamic_state_elements = process_variance
            extras = set(self._dynamic_state_elements) - set(covariates)
            if extras:
                raise ValueError(
                    f"`process_variance` includes items not in `covariates`:\n{extras}"
                )

        super().__init__(id=id,
                         state_elements=covariates,
                         initial_state=initial_state)

        # per-covariate decay bounds:
        self.decays = {}
        if decay is True:
            # default bounds, one independent Bounded per covariate
            self.decays = {name: Bounded(.95, 1.00) for name in covariates}
        elif decay and isinstance(decay, dict):
            assert set(decay).issubset(covariates)
            for name, bounds in decay.items():
                if isinstance(bounds, Bounded):
                    self.decays[name] = bounds
                else:
                    self.decays[name] = Bounded(*bounds)
        elif decay:
            # a single pair of bounds shared (by value) across covariates
            self.decays = {name: Bounded(*decay) for name in covariates}

        # each coefficient carries over to itself, optionally damped:
        for name in covariates:
            bounded = self.decays.get(name)
            if bounded:
                self_weight = bounded.get_value
            else:
                self_weight = 1.0
            self._set_transition(from_element=name,
                                 to_element=name,
                                 value=self_weight)
    def __init__(self,
                 id: str,
                 seasonal_period: int,
                 season_duration: int = 1,
                 decay: Union[bool, Tuple[float, float]] = False,
                 season_start: Optional[str] = None,
                 dt_unit: Optional[str] = None):
        """
        Process representing discrete seasons.

        :param id: Unique name for this process
        :param seasonal_period: How many seasons there are (e.g. 7 for day_in_week).
        :param season_duration: How long each season lasts, default 1 time-step.
        :param decay: Comparable to dampening a trend: the state reverts to zero the further we are from the last
        observation. Helpful when two processes model the same seasonal pattern -- one flexible but decaying towards
        zero, the other less variable but extrapolating into the future. Pass a pair of float bounds, or False for
        no decay.
        :param season_start: A string that can be parsed into a datetime by `numpy.datetime64`. See DTTracker.
        :param dt_unit: Currently supports {'Y', 'D', 'h', 'm', 's'}. 'W' is experimentally supported.
        """

        # datetime bookkeeping is delegated to the tracker:
        self.dt_tracker = DTTracker(season_start=season_start,
                                    dt_unit=dt_unit,
                                    process_id=id)

        self.seasonal_period = seasonal_period
        self.season_duration = season_duration

        # state-elements: 'measured' followed by zero-padded season numbers
        self.measured_name = 'measured'
        width = len(str(seasonal_period))
        season_names = [
            str(season).zfill(width) for season in range(1, seasonal_period)
        ]
        super().__init__(id=id,
                         state_elements=[self.measured_name] + season_names)

        # register placeholder (zero) transitions; real values are filled in per batch
        previous = None
        for position, element in enumerate(self.state_elements):
            self._set_transition(from_element=element,
                                 to_element=element,
                                 value=0.)
            if previous is not None:
                self._set_transition(from_element=previous,
                                     to_element=element,
                                     value=0.)
                if position >= 2:
                    self._set_transition(from_element=previous,
                                         to_element=self.measured_name,
                                         value=0.)
            previous = element

        if not decay:
            self.decay = None
        else:
            assert not isinstance(
                decay, bool
            ), "decay should be floats of bounds (or False for no decay)"
            assert decay[0] > 0. and decay[1] <= 1.0
            self.decay = Bounded(*decay)
Beispiel #3
0
    def __init__(self,
                 id: str,
                 input_dim: int,
                 state_dim: int,
                 nn: torch.nn.Module,
                 process_variance: bool = False,
                 decay: Union[bool, Tuple[float, float]] = False,
                 time_split_kwargs: Sequence[str] = (),
                 initial_state: Optional[torch.nn.Module] = None):
        """
        :param id: A unique identifier for the process.
        :param input_dim: How many inputs the nn takes.
        :param state_dim: How many outputs the nn produces.
        :param nn: A torch.nn.Module mapping a (num_groups, input_dim) Tensor to a (num_groups, state_dim) Tensor.
        :param process_variance: With False (the default) the uncertainty about the state values does not grow at
        each timestep, so they eventually converge. With True the latent-states may 'drift' over time.
        :param decay: With True, state-values tend to shrink towards zero in forecasts (or for missing data); usually
        only combined with `process_variance=True`. Default False. A tuple may be given instead of `True` to set
        custom bounds for the decay-rate.
        :param time_split_kwargs: The KalmanFilter receives a (num_groups, num_timesteps, input_dim) prediction Tensor
        for the nn.Module, which is split internally so the module sees (num_groups, input_dim) tensors. If the
        module's `forward()` takes a single argument the split can be inferred; if it takes several keyword
        arguments, list here the ones that should be split this way.
        :param initial_state: Optional, a callable (typically a torch.nn.Module). When the KalmanFilter is called,
        keyword-arguments can be passed to initial_state in the format `{this_process}_initial_state__{kwarg}`.
        """

        self.input_dim = input_dim
        self.nn = nn

        # annotate the module with call metadata, unless it already carries it
        module = self.nn
        if not hasattr(module, '_forward_kwargs'):
            module._forward_kwargs = infer_forward_kwargs(nn)
        if not hasattr(module, '_time_split_kwargs'):
            assert set(time_split_kwargs).issubset(module._forward_kwargs)
            module._time_split_kwargs = time_split_kwargs

        self._has_process_variance = process_variance

        # one zero-padded state-element per nn output
        width = len(str(state_dim))
        element_names = [zpad(idx, width) for idx in range(state_dim)]
        super().__init__(id=id, state_elements=element_names, initial_state=initial_state)

        # decay: `True` means the default bounds, anything else truthy is used as the bounds
        self.decays = {}
        if decay:
            bounds = (.95, 1.00) if decay is True else decay
            self.decays = {name: Bounded(*bounds) for name in self.state_elements}

        # every element carries over to itself, optionally damped
        for name in self.state_elements:
            bounded = self.decays.get(name)
            if bounded:
                self_weight = bounded.get_value
            else:
                self_weight = 1.0
            self._set_transition(from_element=name, to_element=name, value=self_weight)
Beispiel #4
0
    def __init__(self,
                 id: str,
                 decay_velocity: Union[bool, Tuple[float,
                                                   float]] = (.95, 1.00),
                 decay_position: Union[bool, Tuple[float, float]] = False,
                 multi: float = 1.0,
                 initial_state: Optional[Module] = None):
        """
        :param id: A unique identifier for this process.
        :param decay_velocity: If set, then the trend will decay to zero as we forecast out further. The default is
        to allow the trend to decay somewhere between .95 (moderate decay) and 1.00 (no decay), with the exact value
        being a learned parameter in the nn.Module. Pass False for no decay; `True` is not accepted, pass the bounds
        explicitly instead.
        :param decay_position: See `decay` in `LocalLevel`. A pair of bounds, or False (the default) for no decay.
        :param multi: A multiplier on the trend, so that `next_position = position + multi * trend`. Reducing this
        to .1 can be helpful since the trend has such a large effect on the prediction, so that large values can
        lead to exploding gradients.
        :param initial_state: Optional, a callable (typically a torch.nn.Module). When the KalmanFilter is called,
        keyword-arguments can be passed to initial_state in the format `{this_process}_initial_state__{kwarg}`.
        :raises TypeError: If `decay_position` or `decay_velocity` is `True` instead of a pair of bounds.
        """
        super().__init__(id=id,
                         state_elements=['position', 'velocity'],
                         initial_state=initial_state)

        self.decayed_transitions = {}

        # position first, then velocity: does each regress towards zero?
        decay_args = (('position', 'decay_position', decay_position),
                      ('velocity', 'decay_velocity', decay_velocity))
        for element, arg_name, bounds in decay_args:
            if bounds:
                # `True` is truthy but has no bounds to unpack; fail with a clear message rather than the
                # "'bool' object is not subscriptable" error the indexing below would otherwise produce.
                if isinstance(bounds, bool):
                    raise TypeError(
                        f"`{arg_name}` should be a (lower, upper) tuple of floats "
                        f"(or False for no decay), not `True`")
                assert bounds[0] > 0. and bounds[1] <= 1.
                self.decayed_transitions[element] = Bounded(*bounds)
                self_weight = self.decayed_transitions[element].get_value
            else:
                self_weight = 1.0
            self._set_transition(from_element=element,
                                 to_element=element,
                                 value=self_weight)

        # setting a low arbitrary multiplier on velocity's impact can sometimes be helpful for training in practice:
        self._set_transition(from_element='velocity',
                             to_element='position',
                             value=multi)
Beispiel #5
0
    def __init__(self,
                 id: str,
                 seasonal_period: int,
                 season_duration: int = 1,
                 decay: Union[bool, Tuple[float, float]] = False,
                 dt_unit: Optional[str] = None,
                 fixed: bool = False):
        """
        :param id: Unique name for this process
        :param seasonal_period: How many seasons there are (e.g. 7 for day_in_week).
        :param season_duration: How long each season lasts, default 1 time-step.
        :param decay: Optional (float,float) boundaries for decay (between 0 and 1). Comparable to dampening a
        trend: the state reverts to zero the further we are from the last observation. Helpful when two processes
        model the same seasonal pattern -- one flexible but decaying towards zero, the other less variable but
        extrapolating into the future.
        :param dt_unit: Required here. Currently supports {'Y', 'D', 'h', 'm', 's'}. 'W' is experimentally supported.
        :param fixed: If True, then the seasonality does not vary over time, and this amounts to one-hot-encoding the
        seasons. Default False.
        :raises TypeError: If `dt_unit` is not given.
        """

        self.seasonal_period = seasonal_period
        self.season_duration = season_duration
        self.fixed = fixed

        # `dt_unit` is optional for some seasonal processes, but mandatory for this one
        if dt_unit is None:
            raise TypeError(f"Must pass `dt_unit` to {type(self).__name__}")
        self._dt_helper = DateTimeHelper(dt_unit=dt_unit)

        # state-elements: the measured element followed by zero-padded season numbers
        width = len(str(seasonal_period))
        season_names = [zpad(season, width) for season in range(1, seasonal_period)]
        super().__init__(id=id, state_elements=[self.measured_name] + season_names)

        # register placeholder (zero) transitions; real values are filled in per batch
        previous = None
        for position, element in enumerate(self.state_elements):
            self._set_transition(from_element=element, to_element=element, value=0.)
            if previous is not None:
                self._set_transition(from_element=previous, to_element=element, value=0.)
                if position >= 2:
                    self._set_transition(from_element=previous, to_element=self.measured_name, value=0.)
            previous = element

        if not decay:
            self.decay = None
        else:
            assert not isinstance(decay, bool), "decay should be floats of bounds (or False for no decay)"
            assert decay[0] > 0. and decay[1] <= 1.0
            self.decay = Bounded(*decay)
Beispiel #6
0
    def __init__(self,
                 id: str,
                 seasonal_period: Union[int, float],
                 K: Union[int, float],
                 decay: Union[bool, Tuple[float, float]] = False,
                 season_start: Optional[str] = None,
                 dt_unit: Optional[str] = None):
        """
        Store the season structure and optional decay, then register the transitions produced by `self._setup`.

        :param id: Identifier passed through to the parent initializer.
        :param seasonal_period: Stored unchanged on `self.seasonal_period`.
        :param K: Must be integer-valued (a float such as 2.0 is accepted); stored on `self.K` as an int.
        :param decay: False (the default) for no decay, otherwise a (lower, upper) pair of floats with lower > 0 and
        upper <= 1 that is used to build a `Bounded`. `True` is not accepted.
        :param season_start: Passed to `DateTimeHelper` as `start_datetime`.
        :param dt_unit: Passed to `DateTimeHelper` as `dt_unit`.
        :raises TypeError: If `decay` is `True` instead of a pair of bounds.
        """

        # season structure:
        self.seasonal_period = seasonal_period
        if isinstance(K, float):
            assert K.is_integer()
        self.K = int(K)

        self.decay = None
        if decay:
            # `True` is truthy but has no bounds to unpack; fail with a clear message rather than the
            # "'bool' object is not subscriptable" error the indexing below would otherwise produce.
            if isinstance(decay, bool):
                raise TypeError(
                    "`decay` should be a (lower, upper) tuple of floats (or False for no decay), not `True`")
            assert decay[0] > 0. and decay[1] <= 1.0
            self.decay = Bounded(*decay)

        # `_setup` is called before the parent initializer because it supplies the state-elements
        state_elements, list_of_trans_kwargs = self._setup(decay=decay)

        super().__init__(id=id, state_elements=state_elements)

        self._dt_helper = DateTimeHelper(dt_unit=dt_unit,
                                         start_datetime=season_start)

        for trans_kwargs in list_of_trans_kwargs:
            self._set_transition(**trans_kwargs)
Beispiel #7
0
    def __init__(self,
                 id: str,
                 seasonal_period: Union[int, float],
                 K: Union[int, float],
                 decay: Union[bool, Tuple[float, float]] = False,
                 season_start: Union[str, None, bool] = None,
                 dt_unit: Optional[str] = None):
        """
        Store the season structure and optional decay, then register the transitions produced by `self._setup`.

        :param id: Identifier passed to the parent initializer, and to `DTTracker` as `process_id`.
        :param seasonal_period: Stored unchanged on `self.seasonal_period`.
        :param K: Must be integer-valued (a float such as 2.0 is accepted); stored on `self.K` as an int.
        :param decay: False (the default) for no decay, otherwise a pair of float bounds with lower > 0 and
        upper <= 1 that is used to build a `Bounded`.
        :param season_start: Passed to `DTTracker`.
        :param dt_unit: Passed to `DTTracker`.
        """

        # datetime bookkeeping is delegated to the tracker:
        self.dt_tracker = DTTracker(season_start=season_start,
                                    dt_unit=dt_unit,
                                    process_id=id)

        # season structure; a float K is only accepted when it holds a whole number
        self.seasonal_period = seasonal_period
        assert not isinstance(K, float) or K.is_integer()
        self.K = int(K)

        self.decay: Optional[Bounded] = None
        if decay:
            assert not isinstance(
                decay, bool
            ), "decay should be floats of bounds (or False for no decay)"
            assert decay[0] > 0. and decay[1] <= 1.0
            self.decay = Bounded(*decay)

        # `_setup` supplies both the element names and the keyword-arguments for each transition
        element_names, transition_specs = self._setup(decay=decay)
        super().__init__(id=id, state_elements=element_names)
        for spec in transition_specs:
            self._set_transition(**spec)
Beispiel #8
0
    def __init__(self,
                 id: str,
                 decay: Union[bool, Tuple[float, float]] = False,
                 initial_state: Optional[torch.nn.Module] = None):
        """
        :param id: A unique identifier for this process.
        :param decay: If the process has decay, then the random walk will tend towards zero as we forecast out further
        (note that this means you should center your time-series, or you should include another process that does not
        have this property). Decay can be between 0 and 1, but values < .50 (or even .90) can often be too rapid and
        you will run into trouble with vanishing gradients. When passing a pair of floats, the nn.Module will assign a
        parameter representing the decay as a learned parameter somewhere between these bounds. The bounds are checked
        to satisfy lower >= -1 and upper <= 1. Pass False (the default) for no decay; `True` is not accepted.
        TODO: support {process}__decay__{kwarg}
        :param initial_state: Optional, a callable (typically a torch.nn.Module). When the KalmanFilter is called,
        keyword-arguments can be passed to initial_state in the format `{this_process}_initial_state__{kwarg}`.
        :raises TypeError: If `decay` is `True` instead of a pair of bounds.
        """
        super().__init__(id=id, state_elements=['position'], initial_state=initial_state)

        self.decay = None
        if decay:
            # `True` is truthy but has no bounds to unpack; fail with a clear message rather than the
            # "'bool' object is not subscriptable" error the indexing below would otherwise produce.
            if isinstance(decay, bool):
                raise TypeError(
                    "`decay` should be a (lower, upper) tuple of floats (or False for no decay), not `True`")
            assert decay[0] >= -1. and decay[1] <= 1.
            self.decay = Bounded(*decay)
            # the self-transition is the bounded decay value
            self._set_transition(
                from_element='position',
                to_element='position',
                value=self.decay.get_value
            )
        else:
            # no decay: position carries over unchanged
            self._set_transition(from_element='position', to_element='position', value=1.)
Beispiel #9
0
    def __init__(self,
                 id: str,
                 decay_velocity: Union[bool, Tuple[float,
                                                   float]] = (.95, 1.00),
                 decay_position: Union[bool, Tuple[float, float]] = False):
        """
        Set up 'position' and 'velocity' state-elements, with velocity feeding into position.

        :param id: Identifier passed through to the parent initializer.
        :param decay_velocity: A pair of float bounds (lower > 0, upper <= 1) used to build a `Bounded` for the
        velocity self-transition, or False for a fixed self-transition of 1.0. Defaults to (.95, 1.00).
        :param decay_position: Same as `decay_velocity`, for the position self-transition. Defaults to False.
        """

        super().__init__(id=id, state_elements=['position', 'velocity'])
        self._set_transition(from_element='velocity',
                             to_element='position',
                             value=1.0)

        self.decayed_transitions = {}

        # position is handled first, then velocity; each entry carries its own bool-rejection message
        per_element = (
            ('position', decay_position,
             "decay_position should be floats of bounds (or False for no decay)"),
            ('velocity', decay_velocity,
             "decay_velocity should be floats of bounds (or False for no decay)"),
        )
        for element, bounds, bool_message in per_element:
            if not bounds:
                # no decay: the element carries over unchanged
                self._set_transition(from_element=element,
                                     to_element=element,
                                     value=1.0)
                continue

            assert not isinstance(bounds, bool), bool_message
            assert bounds[0] > 0. and bounds[1] <= 1.
            bounded = Bounded(*bounds)
            self.decayed_transitions[element] = bounded
            self._set_transition(from_element=element,
                                 to_element=element,
                                 value=bounded.get_value,
                                 inv_link=False)
    def __init__(self,
                 id: str,
                 decay: Union[bool, Tuple[float, float]] = False):
        """
        Set up a single 'position' state-element that transitions onto itself.

        :param id: Identifier passed through to the parent initializer.
        :param decay: A pair of float bounds (lower >= -1, upper <= 1) used to build a `Bounded` for the
        self-transition, or False (the default) for a fixed self-transition of 1.
        """
        super().__init__(id=id, state_elements=['position'])

        self_transition = dict(from_element='position', to_element='position')

        self.decay: Optional[Bounded] = None
        if not decay:
            # no decay: position carries over unchanged
            self._set_transition(value=1., **self_transition)
            return

        assert not isinstance(
            decay, bool
        ), "decay should be floats of bounds (or False for no decay)"
        assert decay[0] >= -1. and decay[1] <= 1.
        self.decay = Bounded(*decay)
        self._set_transition(value=self.decay.get_value,
                             inv_link=False,
                             **self_transition)
Beispiel #11
0
    def __init__(self,
                 id: str,
                 seasonal_period: Union[int, float],
                 K: Union[int, float],
                 decay: Union[bool, Tuple[float, float]] = False,
                 season_start: Optional[str] = None,
                 dt_unit: Optional[str] = None):
        """
        :param id: Unique name for this instance.
        :param seasonal_period: The seasonal period (e.g. 24 for daily season in hourly data, 365.25 for yearly season
        in daily data)
        :param K: The "K" parameter of the fourier series, see `fourier_tensor`. Must be integer-valued.
        :param decay: Optional (float,float) boundaries for decay (between 0 and 1). Analogous to dampening a trend --
        the state will revert to zero as we get further from the last observation. This can be useful if two processes
        are capturing the same seasonal pattern: one can be more flexible, but with decay have a tendency to revert to
        zero, while the other is less variable but extrapolates into the future. Pass False (the default) for no
        decay; `True` is not accepted.
        :param season_start:  A string that can be parsed into a datetime by `numpy.datetime64`. This is when the season
        starts, which is useful to specify if season boundaries are meaningful. It is important to specify if different
        groups in your dataset start on different dates; when calling the kalman-filter you'll pass an array of
        `start_datetimes` for group in the input, and this will be used to align the seasons for each group.
        :param dt_unit: Currently supports {'Y', 'D', 'h', 'm', 's'}. 'W' is experimentally supported.
        :raises TypeError: If `decay` is `True` instead of a pair of bounds.
        """

        # season structure:
        self.seasonal_period = seasonal_period
        if isinstance(K, float):
            assert K.is_integer()
        self.K = int(K)

        self.decay = None
        if decay:
            # `True` is truthy but has no bounds to unpack; fail with a clear message rather than the
            # "'bool' object is not subscriptable" error the indexing below would otherwise produce.
            if isinstance(decay, bool):
                raise TypeError(
                    "`decay` should be a (lower, upper) tuple of floats (or False for no decay), not `True`")
            assert decay[0] > 0. and decay[1] <= 1.0
            self.decay = Bounded(*decay)

        # `_setup` is called before the parent initializer because it supplies the state-elements
        state_elements, list_of_trans_kwargs = self._setup(decay=decay)

        super().__init__(id=id, state_elements=state_elements, season_start=season_start, dt_unit=dt_unit)

        for trans_kwargs in list_of_trans_kwargs:
            self._set_transition(**trans_kwargs)
Beispiel #12
0
    def __init__(self,
                 id: str,
                 decay: Union[bool, Tuple[float, float]] = False):
        """
        :param id: A unique identifier for this process.
        :param decay: If the process has decay, then the random walk will tend towards zero as we forecast out further
        (note that this means you should center your time-series, or you should include another process that does not
        have this property). Decay can be between 0 and 1, but values < .50 (or even .90) can often be too rapid and
        you will run into trouble with vanishing gradients. When passing a pair of floats, the nn.Module will assign a
        parameter representing the decay as a learned parameter somewhere between these bounds. The bounds are checked
        to satisfy lower >= -1 and upper <= 1. Pass False (the default) for no decay; `True` is not accepted.
        :raises TypeError: If `decay` is `True` instead of a pair of bounds.
        """
        super().__init__(id=id, state_elements=['position'])

        self.decay = None
        if decay:
            # `True` is truthy but has no bounds to unpack; fail with a clear message rather than the
            # "'bool' object is not subscriptable" error the indexing below would otherwise produce.
            if isinstance(decay, bool):
                raise TypeError(
                    "`decay` should be a (lower, upper) tuple of floats (or False for no decay), not `True`")
            assert decay[0] >= -1. and decay[1] <= 1.
            self.decay = Bounded(*decay)
            # the self-transition is the bounded decay value
            self._set_transition(
                from_element='position',
                to_element='position',
                value=self.decay.get_value
            )
        else:
            # no decay: position carries over unchanged
            self._set_transition(from_element='position', to_element='position', value=1.)
class Season(Process):
    def __init__(self,
                 id: str,
                 seasonal_period: int,
                 season_duration: int = 1,
                 decay: Union[bool, Tuple[float, float]] = False,
                 season_start: Optional[str] = None,
                 dt_unit: Optional[str] = None):
        """
        Process representing discrete seasons.

        :param id: Unique name for this process
        :param seasonal_period: How many seasons there are (e.g. 7 for day_in_week).
        :param season_duration: How long each season lasts, default 1 time-step.
        :param decay: Comparable to dampening a trend: the state reverts to zero the further we are from the last
        observation. Helpful when two processes model the same seasonal pattern -- one flexible but decaying towards
        zero, the other less variable but extrapolating into the future. Pass a pair of float bounds, or False for
        no decay.
        :param season_start: A string that can be parsed into a datetime by `numpy.datetime64`. See DTTracker.
        :param dt_unit: Currently supports {'Y', 'D', 'h', 'm', 's'}. 'W' is experimentally supported.
        """

        # datetime bookkeeping is delegated to the tracker:
        self.dt_tracker = DTTracker(season_start=season_start,
                                    dt_unit=dt_unit,
                                    process_id=id)

        self.seasonal_period = seasonal_period
        self.season_duration = season_duration

        # state-elements: 'measured' followed by zero-padded season numbers
        self.measured_name = 'measured'
        width = len(str(seasonal_period))
        season_names = [
            str(season).zfill(width) for season in range(1, seasonal_period)
        ]
        super().__init__(id=id,
                         state_elements=[self.measured_name] + season_names)

        # register placeholder (zero) transitions; real values are filled in per batch
        previous = None
        for position, element in enumerate(self.state_elements):
            self._set_transition(from_element=element,
                                 to_element=element,
                                 value=0.)
            if previous is not None:
                self._set_transition(from_element=previous,
                                     to_element=element,
                                     value=0.)
                if position >= 2:
                    self._set_transition(from_element=previous,
                                         to_element=self.measured_name,
                                         value=0.)
            previous = element

        if not decay:
            self.decay = None
        else:
            assert not isinstance(
                decay, bool
            ), "decay should be floats of bounds (or False for no decay)"
            assert decay[0] > 0. and decay[1] <= 1.0
            self.decay = Bounded(*decay)

    def add_measure(self, measure: str):
        """Link `measure` to the 'measured' state-element with a weight of 1.0."""
        self._set_measure(measure=measure, state_element='measured', value=1.0)

    def parameters(self) -> Generator[Parameter, None, None]:
        """Yield the decay's parameter when decay is configured; otherwise yield nothing."""
        if self.decay is None:
            return
        yield self.decay.parameter

    @property
    def dynamic_state_elements(self) -> Sequence[str]:
        """Only the measured element is reported as dynamic."""
        return [self.measured_name]

    def for_batch(
            self,
            num_groups: int,
            num_timesteps: int,
            start_datetimes: Optional[np.ndarray] = None) -> ProcessForBatch:
        """
        Build the batch-specific process and fill in the placeholder transitions registered in `__init__`.

        :param num_groups: Passed to the parent's `for_batch`.
        :param num_timesteps: Passed to the parent's `for_batch`.
        :param start_datetimes: Optional, passed to the DTTracker when computing the per-timestep delta.
        :return: The parent's batch object, with its transitions adjusted.
        """

        batch = super().for_batch(num_groups=num_groups,
                                  num_timesteps=num_timesteps)

        delta = self.dt_tracker.get_delta(batch.num_groups,
                                          batch.num_timesteps,
                                          start_datetimes=start_datetimes)

        # true wherever the delta falls on the final step of a season
        last_step = self.season_duration - 1
        in_transition = (delta % self.season_duration) == last_step

        advancing = torch.from_numpy(in_transition.astype('float32'))
        flat = {
            'to_next_state': advancing,
            'to_self': 1 - advancing,
            'to_measured': -advancing,
            'from_measured_to_measured': torch.from_numpy(
                np.where(in_transition, -1., 1.).astype('float32')),
        }

        # split each tensor along dim 1, scaling every piece by the decay value when decay is configured
        transitions = {}
        for key, tensor in flat.items():
            pieces = split_flat(tensor, dim=1, clone=True)
            if self.decay is not None:
                decay_value = self.decay.get_value()
                pieces = [piece * decay_value for piece in pieces]
            transitions[key] = pieces

        # this is convoluted, but the idea is to manipulate the transitions so that we use one less degree of freedom
        # than the number of seasons, by having the 'measured' state be equal to -sum(all others)
        elements = self.state_elements
        for idx in range(1, len(elements)):
            prev, current = elements[idx - 1], elements[idx]

            # the measured element needs its own measured-to-measured adjustment
            if prev == self.measured_name:
                into_measured = transitions['from_measured_to_measured']
            else:
                into_measured = transitions['to_measured']

            batch.adjust_transition(from_element=prev,
                                    to_element=current,
                                    adjustment=transitions['to_next_state'])
            batch.adjust_transition(from_element=prev,
                                    to_element=self.measured_name,
                                    adjustment=into_measured)

            # from state to itself:
            batch.adjust_transition(from_element=current,
                                    to_element=current,
                                    adjustment=transitions['to_self'])

        return batch

    def initial_state_means_for_batch(
            self,
            parameters: Parameter,
            num_groups: int,
            start_datetimes: Optional[np.ndarray] = None) -> Tensor:
        """
        Rotate `parameters` once per group, by that group's season offset, and stack the results.

        :param parameters: The initial-state means, sliced along their first dimension.
        :param num_groups: Passed to the DTTracker when computing the delta.
        :param start_datetimes: Optional, passed to the DTTracker when computing the delta.
        :return: The per-group rotated means, stacked along dimension 0.
        """

        delta = self.dt_tracker.get_delta(
            num_groups, 1, start_datetimes=start_datetimes).squeeze(1)

        # whole seasons elapsed, wrapped around the seasonal period
        seasons_elapsed = np.floor(delta / self.season_duration)
        season_shift = (seasons_elapsed % self.seasonal_period).astype('int')

        means = []
        for shift in season_shift:
            # move the last `shift` entries to the front
            rotated = torch.cat([parameters[-shift:], parameters[:-shift]])
            means.append(rotated)
        return torch.stack(means, 0)
Beispiel #14
0
    def __init__(self,
                 id: str,
                 seasonal_period: int,
                 season_duration: int = 1,
                 decay: Union[bool, Tuple[float, float]] = False,
                 season_start: Optional[str] = None,
                 dt_unit: Optional[str] = None,
                 fixed: bool = False):
        """
        Process representing discrete seasons.

        :param id: Unique name for this process
        :param seasonal_period: How many seasons there are (e.g. 7 for day_in_week).
        :param season_duration: How long each season lasts, default 1 time-step.
        :param decay: Optional (float,float) boundaries for decay (between 0 and 1). Comparable to dampening a
        trend: the state reverts to zero the further we are from the last observation. Helpful when two processes
        model the same seasonal pattern -- one flexible but decaying towards zero, the other less variable but
        extrapolating into the future.
        :param season_start: A string that can be parsed into a datetime by `numpy.datetime64`, marking when the
        season starts; worth specifying if season boundaries are meaningful. It matters when different groups in your
        dataset start on different dates: when calling the kalman-filter you pass an array of `start_datetimes` for
        the groups in the input, and this is used to align the seasons for each group.
        :param dt_unit: Currently supports {'Y', 'D', 'h', 'm', 's'}. 'W' is experimentally supported.
        :param fixed: If True, then the seasonality does not vary over time, and this amounts to one-hot-encoding the
        seasons. Default False.
        """

        self.seasonal_period = seasonal_period
        self.season_duration = season_duration
        self.fixed = fixed

        # state-elements: the measured element followed by zero-padded season numbers
        width = len(str(seasonal_period))
        season_names = [zpad(season, width) for season in range(1, seasonal_period)]
        super().__init__(id=id,
                         state_elements=[self.measured_name] + season_names,
                         season_start=season_start,
                         dt_unit=dt_unit)

        # register placeholder (zero) transitions; real values are filled in per batch
        previous = None
        for position, element in enumerate(self.state_elements):
            self._set_transition(from_element=element,
                                 to_element=element,
                                 value=0.)
            if previous is not None:
                self._set_transition(from_element=previous,
                                     to_element=element,
                                     value=0.)
                if position >= 2:
                    self._set_transition(from_element=previous,
                                         to_element=self.measured_name,
                                         value=0.)
            previous = element

        if not decay:
            self.decay = None
        else:
            assert not isinstance(
                decay, bool
            ), "decay should be floats of bounds (or False for no decay)"
            assert decay[0] > 0. and decay[1] <= 1.0
            self.decay = Bounded(*decay)
Beispiel #15
0
class Season(DatetimeProcess, Process):
    """
    Process representing discrete seasons (e.g. day-in-week).

    The state consists of one element named by `measured_name` plus
    `seasonal_period - 1` zero-padded numbered elements. The transitions between
    them are registered as placeholders in `__init__` and filled in per batch by
    `for_batch`.
    """
    # name of the state-element that is linked to the measure(s) in `add_measure`:
    measured_name = 'measured'

    def __init__(self,
                 id: str,
                 seasonal_period: int,
                 season_duration: int = 1,
                 decay: Union[bool, Tuple[float, float]] = False,
                 season_start: Optional[str] = None,
                 dt_unit: Optional[str] = None,
                 fixed: bool = False):
        """
        Process representing discrete seasons.

        :param id: Unique name for this process
        :param seasonal_period: The number of seasons (e.g. 7 for day_in_week).
        :param season_duration: The length of each season, default 1 time-step.
        :param decay: Optional (float,float) boundaries for decay (between 0 and 1). Analogous to dampening a trend --
        the state will revert to zero as we get further from the last observation. This can be useful if two processes
        are capturing the same seasonal pattern: one can be more flexible, but with decay have a tendency to revert to
        zero, while the other is less variable but extrapolates into the future. Passing `True` is rejected: bounds
        must be given explicitly (or a falsy value for no decay).
        :param season_start:  A string that can be parsed into a datetime by `numpy.datetime64`. This is when the season
        starts, which is useful to specify if season boundaries are meaningful. It is important to specify if different
        groups in your dataset start on different dates; when calling the kalman-filter you'll pass an array of
        `start_datetimes` for each group in the input, and this will be used to align the seasons for each group.
        :param dt_unit: Currently supports {'Y', 'D', 'h', 'm', 's'}. 'W' is experimentally supported.
        :param fixed: If True, then the seasonality does not vary over time, and this amounts to one-hot-encoding the
        seasons. Default False.
        """

        # these are set before `super().__init__` is called, as in the original implementation:
        self.seasonal_period = seasonal_period
        self.season_duration = season_duration
        self.fixed = fixed

        # state-elements: the measured element, then `seasonal_period - 1` zero-padded numbered elements
        pad_n = len(str(seasonal_period))
        super().__init__(id=id,
                         state_elements=[self.measured_name] +
                         [zpad(i, pad_n) for i in range(1, seasonal_period)],
                         season_start=season_start,
                         dt_unit=dt_unit)

        # transitions are placeholders, filled in w/batch
        for i, current in enumerate(self.state_elements):
            # each element to itself:
            self._set_transition(from_element=current,
                                 to_element=current,
                                 value=0.)
            if i > 0:
                # previous element to this one:
                prev = self.state_elements[i - 1]
                self._set_transition(from_element=prev,
                                     to_element=current,
                                     value=0.)
                if i > 1:
                    # previous element to the measured element; for i == 1 the previous element *is* the measured
                    # element, and that self-transition was already registered above.
                    self._set_transition(from_element=prev,
                                         to_element=self.measured_name,
                                         value=0.)

        if decay:
            assert not isinstance(
                decay, bool
            ), "decay should be floats of bounds (or False for no decay)"
            assert decay[0] > 0. and decay[1] <= 1.0, (
                "decay bounds must satisfy `lower > 0` and `upper <= 1`")
            self.decay = Bounded(*decay)
        else:
            self.decay = None

    def add_measure(self, measure: str) -> 'Season':
        """
        Link `measure` to this process' measured state-element with a multiplier of 1.0.

        :param measure: Name of the measure.
        :return: This process, so that calls can be chained.
        """
        # use the class-level name rather than a hard-coded string, so this stays in sync with the state-elements
        # that were built from `measured_name` in `__init__`:
        self._set_measure(measure=measure,
                          state_element=self.measured_name,
                          value=1.0)
        return self

    def param_dict(self) -> ParameterDict:
        """
        :return: A ParameterDict holding the decay parameter under the key 'decay' if decay was requested, otherwise
        an empty ParameterDict.
        """
        p = ParameterDict()
        if self.decay is not None:
            p['decay'] = self.decay.parameter
        return p

    @property
    def dynamic_state_elements(self) -> Sequence[str]:
        """
        :return: Empty if `fixed=True`, otherwise a list with only the measured state-element.
        """
        return [] if self.fixed else [self.measured_name]

    def for_batch(self,
                  num_groups: int,
                  num_timesteps: int,
                  start_datetimes: Optional[np.ndarray] = None):
        """
        Create the batch-specific version of this process, with the placeholder transitions filled in.

        :param num_groups: Number of groups in the batch.
        :param num_timesteps: Number of timesteps in the batch.
        :param start_datetimes: Optional 1D array with one entry per group; passed to `_get_delta` to align the
        seasons of each group.
        :return: The object returned by the parent's `for_batch`, after its transitions have been adjusted.
        :raises ValueError: If `start_datetimes` is given but is not a 1D array of length `num_groups`.
        """

        if start_datetimes is not None:
            if len(start_datetimes.shape) != 1 or len(
                    start_datetimes) != num_groups:
                raise ValueError(
                    f"Expected `start_datetimes` to be 1D array of length {num_groups}."
                )

        for_batch = super().for_batch(num_groups=num_groups,
                                      num_timesteps=num_timesteps)

        # NOTE(review): `delta` is presumably the number of `dt_unit` steps since `season_start`, per group and
        # timestep -- confirm against `DatetimeProcess._get_delta`.
        delta = self._get_delta(num_groups,
                                num_timesteps,
                                start_datetimes=start_datetimes)

        # True on the last step within a season, i.e. when the state moves on to the next season:
        in_transition = (delta %
                         self.season_duration) == (self.season_duration - 1)

        transitions = {
            'to_next_state':
            torch.from_numpy(in_transition.astype('float32')),
            'from_measured_to_measured':
            torch.from_numpy(
                np.where(in_transition, -1., 1.).astype('float32'))
        }
        transitions['to_self'] = 1 - transitions['to_next_state']
        transitions['to_measured'] = -transitions['to_next_state']

        # the decay value does not depend on the transition being processed, so evaluate it once up front:
        decay_value = None if self.decay is None else self.decay.get_value()
        for k in transitions:
            # NOTE(review): `split_flat` presumably splits along dim 1 into one piece per timestep -- confirm.
            pieces = split_flat(transitions[k], dim=1, clone=True)
            if decay_value is not None:
                pieces = [x * decay_value for x in pieces]
            transitions[k] = pieces

        # this is convoluted, but the idea is to manipulate the transitions so that we use one less degree of freedom
        # than the number of seasons, by having the 'measured' state be equal to -sum(all others)
        elements = list(self.state_elements)
        for prev, current in zip(elements, elements[1:]):
            if prev == self.measured_name:  # measured requires special-case
                to_measured = transitions['from_measured_to_measured']
            else:
                to_measured = transitions['to_measured']

            for_batch._adjust_transition(
                from_element=prev,
                to_element=current,
                adjustment=transitions['to_next_state'])
            for_batch._adjust_transition(from_element=prev,
                                         to_element=self.measured_name,
                                         adjustment=to_measured)

            # from state to itself:
            for_batch._adjust_transition(from_element=current,
                                         to_element=current,
                                         adjustment=transitions['to_self'])

        return for_batch

    def initial_state_means_for_batch(
            self,
            parameters: Parameter,
            num_groups: int,
            start_datetimes: Optional[np.ndarray] = None) -> Tensor:
        """
        Build the initial state-means for each group by rotating `parameters` according to the season each group
        starts in.

        :param parameters: The initial-mean parameters, indexed along their first dimension.
        :param num_groups: Number of groups in the batch.
        :param start_datetimes: Optional array with one entry per group; passed to `_get_delta`.
        :return: A tensor with the per-group rotated parameters stacked along dimension 0.
        """

        delta = self._get_delta(num_groups, 1,
                                start_datetimes=start_datetimes).squeeze(1)
        # index of the season each group starts in:
        season_shift = (np.floor(delta / self.season_duration) %
                        self.seasonal_period).astype('int')
        # rotate by `shift`; for shift == 0 the second slice is empty, so the parameters are returned unrotated.
        means = [
            torch.cat([parameters[-shift:], parameters[:-shift]])
            for shift in season_shift
        ]
        return torch.stack(means, 0)