Example #1
    def _train_kf(self,
                  data: torch.Tensor,
                  num_epochs: int = 8,
                  cls: Type['KalmanFilter'] = KalmanFilter):
        kf = cls(measures=['y'],
                 processes=[
                     LocalLevel(id='local_level').add_measure('y'),
                     Season(id='day_in_week',
                            seasonal_period=7,
                            **self.config['season_spec']).add_measure('y')
                 ])
        kf.opt = LBFGS(kf.parameters())

        start_datetimes = (
            np.zeros(self.config['num_groups'], dtype='timedelta64') +
            self.config['season_spec']['season_start'])

        def closure():
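            # LBFGS may re-evaluate the loss several times per step, hence the closure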
            kf.opt.zero_grad()
            pred = kf(data, start_datetimes=start_datetimes)
            loss = -pred.log_prob(data).mean()
            loss.backward()
            return loss

        print(f"Will train for {num_epochs} epochs...")
        loss = float('nan')
        for i in range(num_epochs):
            new_loss = kf.opt.step(closure)
            print(
                f"EPOCH {i}, LOSS {new_loss.item()}, DELTA {loss - new_loss.item()}"
            )
            loss = new_loss.item()

        return kf(data, start_datetimes=start_datetimes).predictions
Example #2
def name_to_proc(id: str, predictors: list = None, **kwargs) -> Process:
    season_start = '2010-01-04'

    if 'hour_in_day' in id:
        out = FourierSeasonFixed(id=id,
                                 seasonal_period=24, season_start=season_start, dt_unit='h',
                                 **kwargs)
    elif 'day_in_year' in id:
        out = FourierSeasonFixed(id=id,
                                 seasonal_period=24 * 364.25, season_start=season_start, dt_unit='h',
                                 **kwargs)
    elif 'local_level' in id:
        out = LocalLevel(id=id, **kwargs)
    elif 'local_trend' in id:
        out = LocalTrend(id=id, **kwargs)
    elif 'day_in_week' in id:
        out = Season(id=id,
                     seasonal_period=7, season_duration=24,
                     season_start=season_start, dt_unit='h',
                     **kwargs)
    elif 'nn_predictors' in id:
        out = NN(id=id,
                 add_module_params_to_process=False,  # so we can use a separate parameter group
                 model_mat_kwarg_name='predictors',
                 **kwargs)
    elif 'predictors' in id:
        out = LinearModel(id=id,
                          covariates=predictors,
                          model_mat_kwarg_name='predictors',
                          **kwargs)
    else:
        raise NotImplementedError(f"Unsure what process to use for `{id}`.")

    return out
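
A hypothetical usage sketch (not from the original source): the process names below are assumptions chosen to exercise the branches above, and the chained `add_measure('y')` / `KalmanFilter(...)` calls mirror the other examples in this collection.

    processes = [name_to_proc(name).add_measure('y')
                 for name in ('local_level', 'day_in_week', 'hour_in_day')]
    kf = KalmanFilter(measures=['y'], processes=processes)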
Example #3
    def test_predictions(self, ndim: int = 2):
        data = torch.zeros((2, 5, ndim))
        kf = KalmanFilter(
            processes=[LocalLevel(id=f'lm{i}', measure=str(i)) for i in range(ndim)],
            measures=[str(i) for i in range(ndim)],
            compiled=False)
        pred = kf(data)
        self.assertEqual(len(tuple(pred)), 2)
        self.assertIsInstance(np.asanyarray(pred), np.ndarray)
        means, covs = pred
        self.assertIsInstance(means, torch.Tensor)
        self.assertIsInstance(covs, torch.Tensor)

        with self.assertRaises(TypeError):
            pred[1]

        with self.assertRaises(TypeError):
            pred[(1, )]

        pred_group2 = pred[[1]]
        self.assertTupleEqual(tuple(pred_group2.covs.shape),
                              (1, 5, ndim, ndim))
        self.assertTrue(
            (pred_group2.state_means == pred.state_means[1, :, :]).all())
        self.assertTrue(
            (pred_group2.state_covs == pred.state_covs[1, :, :, :]).all())

        pred_time3 = pred[:, [2]]
        self.assertTupleEqual(tuple(pred_time3.covs.shape), (2, 1, ndim, ndim))
        self.assertTrue(
            (pred_time3.state_means == pred.state_means[:, 2, :]).all())
        self.assertTrue(
            (pred_time3.state_covs == pred.state_covs[:, 2, :, :]).all())
Example #4
def simulate(num_groups: int,
             num_timesteps: int,
             season_spec: dict,
             noise: float = 1.0) -> torch.Tensor:
    # make kf:
    processes = [
        LocalLevel(id='local_level').add_measure('y'),
        Season(id='day_in_week', seasonal_period=7, fixed=True,
               **season_spec).add_measure('y'),
        FourierSeasonFixed(id='day_in_month',
                           seasonal_period=30,
                           K=2,
                           **season_spec).add_measure('y')
    ]
    kf = KalmanFilter(measures=['y'], processes=processes)

    # make local-level less aggressive:
    pcov = kf.design.process_covariance.create().data
    pcov[0, 0] *= .1
    kf.design.process_covariance.set(pcov)

    # simulate:
    start_datetimes = np.zeros(
        num_groups, dtype='timedelta64') + season_spec['season_start']
    with torch.no_grad():
        dfb = kf.design.for_batch(num_groups=num_groups,
                                  num_timesteps=num_timesteps,
                                  start_datetimes=start_datetimes)
        initial_state = kf.predict_initial_state(dfb)
        simulated_trajectories = initial_state.simulate_trajectories(dfb)
        sim_data = simulated_trajectories.sample_measurements(eps=noise)

    return sim_data
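
A hypothetical call (the exact contents of `season_spec` are an assumption; the other examples route a `season_start` datetime and a `dt_unit` through it):

    season_spec = {'season_start': np.datetime64('2010-01-04'), 'dt_unit': 'D'}
    sim_data = simulate(num_groups=10, num_timesteps=200, season_spec=season_spec)
    # expected: a (num_groups, num_timesteps, 1) tensor for the single measure 'y'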
Example #5
def name_to_proc(id: str, **kwargs) -> Process:
    season_start = '2010-01-04'

    if 'hour_in_day' in id:
        out = FourierSeasonFixed(id=id,
                                 seasonal_period=24,
                                 season_start=season_start,
                                 dt_unit='h',
                                 **kwargs)
    elif 'day_in_year' in id:
        out = FourierSeasonFixed(id=id,
                                 seasonal_period=24 * 364.25,
                                 season_start=season_start,
                                 dt_unit='h',
                                 **kwargs)
    elif 'local_level' in id:
        out = LocalLevel(id=id, **kwargs)
    elif 'local_trend' in id:
        out = LocalTrend(id=id, **kwargs)
    elif 'day_in_week' in id:
        out = Season(id=id,
                     seasonal_period=7,
                     season_duration=24,
                     season_start=season_start,
                     dt_unit='h',
                     **kwargs)
    else:
        raise NotImplementedError(f"Unsure what process to use for `{id}`.")

    return out
Example #6
    def test_nans(self, ndim: int = 3, n_step: int = 1):
        ntimes = 4 + n_step
        data = torch.ones((5, ntimes, ndim)) * 10
        data[0, 2, 0:(ndim - 1)] = float('nan')
        data[2, 2, 0] = float('nan')

        # test critical helper fun:
        get_nan_groups2 = torch.jit.script(get_nan_groups)
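        # as exercised below, `get_nan_groups` yields (group_idx, valid_idx) pairs: groups of rows
        # sharing the same pattern of observed dims, with valid_idx=None when nothing is missing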
        nan_groups = {2}
        if ndim > 1:
            nan_groups.add(0)
        for t in range(ntimes):
            for group_idx, valid_idx in get_nan_groups2(torch.isnan(data[:, t])):
                if t == 2:
                    if valid_idx is None:
                        self.assertEqual(len(group_idx),
                                         data.shape[0] - len(nan_groups))
                        self.assertFalse(
                            bool(
                                set(group_idx.tolist()).intersection(
                                    nan_groups)))
                    else:
                        self.assertLess(len(valid_idx), ndim)
                        self.assertGreater(len(valid_idx), 0)
                        if len(valid_idx) == 1:
                            if ndim == 2:
                                self.assertSetEqual(set(valid_idx.tolist()),
                                                    {1})
                                self.assertSetEqual(set(group_idx.tolist()),
                                                    nan_groups)
                            else:
                                self.assertSetEqual(set(valid_idx.tolist()),
                                                    {ndim - 1})
                                self.assertSetEqual(set(group_idx.tolist()),
                                                    {0})
                        else:
                            self.assertSetEqual(set(valid_idx.tolist()),
                                                {1, 2})
                            self.assertSetEqual(set(group_idx.tolist()), {2})
                else:
                    self.assertIsNone(valid_idx)

        # test `update`
        # TODO: measure dim vs. state-dim

        # test integration:
        # TODO: make missing dim highly correlated with observed dims. upward trend in observed should get reflected in
        #       unobserved state
        kf = KalmanFilter(
            processes=[LocalLevel(id=f'lm{i}', measure=str(i)) for i in range(ndim)],
            measures=[str(i) for i in range(ndim)],
            compiled=True)
        obs_means, obs_covs = kf(data, n_step=n_step)
        self.assertFalse(torch.isnan(obs_means).any())
        self.assertFalse(torch.isnan(obs_covs).any())
        self.assertEqual(tuple(obs_means.shape), (5, ntimes, ndim))
Example #7
    def _make_kf():
        return KalmanFilter(
            processes=(
                [LocalLevel(id=f'll{i + 1}', measure=str(i + 1)) for i in range(ndim)] +
                [LinearModel(id=f'lm{i + 1}',
                             predictors=['x1', 'x2', 'x3', 'x4', 'x5'],
                             measure=str(i + 1)) for i in range(ndim)]
            ),
            measures=[str(i + 1) for i in range(ndim)])
Example #8
    def test_gaussian_log_prob(self, ndim: int = 1):
        from torch.distributions import MultivariateNormal

        data = torch.zeros((2, 5, ndim))
        kf = KalmanFilter(
            processes=[LocalLevel(id=f'lm{i}', measure=str(i)) for i in range(ndim)],
            measures=[str(i) for i in range(ndim)])
        pred = kf(data)
        log_lik1 = kf.kf_step.log_prob(data, *pred)
        mv = MultivariateNormal(*pred)
        log_lik2 = mv.log_prob(data)
        self.assertAlmostEqual(log_lik1.sum().item(), log_lik2.sum().item())
Example #9
    def test_dtype(self,
                   dtype: torch.dtype,
                   ndim: int = 2,
                   compiled: bool = True):
        data = torch.zeros((2, 5, ndim), dtype=dtype)
        kf = KalmanFilter(
            processes=[LocalLevel(id=f'll{i}', measure=str(i)) for i in range(ndim)],
            measures=[str(i) for i in range(ndim)],
            compiled=compiled)
        kf.to(dtype=dtype)
        pred = kf(data)
        self.assertEqual(pred.means.dtype, dtype)
        loss = pred.log_prob(data)
        self.assertEqual(loss.dtype, dtype)
Example #10
def _simulate(num_groups: int, num_timesteps: int, dt_unit: str, noise: float = 1.0) -> torch.Tensor:
    # make kf:
    processes = [
        LocalLevel(id='local_level').add_measure('y'),
        Season(id='day_in_week', seasonal_period=7, fixed=True, dt_unit=dt_unit).add_measure('y'),
        FourierSeason(id='day_in_year', seasonal_period=365.25, K=2, fixed=True, dt_unit=dt_unit).add_measure('y')
    ]
    kf = KalmanFilter(measures=['y'], processes=processes)

    # simulate:
    start_datetimes = np.zeros(num_groups, dtype='timedelta64') + DEFAULT_START_DT
    with torch.no_grad():
        dfb = kf.design.for_batch(num_groups=num_groups, num_timesteps=num_timesteps, start_datetimes=start_datetimes)
        initial_state = kf._predict_initial_state(dfb)
        simulated_trajectories = initial_state.simulate_trajectories(dfb)
        sim_data = simulated_trajectories.sample_measurements(eps=noise)

    return sim_data
Example #11
                                               time_colname='date')

# Train/Val split:
dataset_train, dataset_val = dataset_all.train_val_split(dt=SPLIT_DT)
dataset_train, dataset_val
# -

# #### Specify our Model
#
# The `KalmanFilter` subclasses `torch.nn.Module`. We specify the model by passing `processes` that capture the behaviors of our `measures`.

processes = []
for measure in measures_pp:
    processes.extend([
        LocalTrend(id=f'{measure}_trend', multi=.01).add_measure(measure),
        LocalLevel(id=f'{measure}_local_level',
                   decay=(.90, 1.00)).add_measure(measure),
        FourierSeason(id=f'{measure}_day_in_year',
                      seasonal_period=365.25 / 7.,
                      dt_unit='W',
                      K=2,
                      fixed=True).add_measure(measure)
    ])
kf_first = KalmanFilter(
    measures=measures_pp,
    processes=processes,
    measure_var_predict=('seasonal', dict(K=2, period='yearly', dt_unit='W')))

# Here we're showing off a few useful features of `torch-kalman`:
#
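
# Because `KalmanFilter` is a `torch.nn.Module`, `kf_first` can be trained with an ordinary
# PyTorch loop. A minimal, hypothetical sketch: `y_train` and `start_dts` stand in for the
# (groups, timesteps, measures) tensor and per-group start datetimes that would come from
# `dataset_train`; the notebook's actual training cell is not reproduced here.

optimizer = torch.optim.Adam(kf_first.parameters(), lr=.05)
for epoch in range(10):
    optimizer.zero_grad()
    pred = kf_first(y_train, start_datetimes=start_dts)
    loss = -pred.log_prob(y_train).mean()
    loss.backward()
    optimizer.step()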