def _train_kf(self, data: torch.Tensor, num_epochs: int = 8, cls: Type['KalmanFilter'] = KalmanFilter):
    """Fit a KalmanFilter of class ``cls`` to ``data`` with LBFGS, printing per-epoch loss.

    :param data: Tensor of observations for measure 'y'.
    :param num_epochs: Number of LBFGS steps (each step may involve multiple closure evals).
    :param cls: KalmanFilter class to instantiate (allows subclasses in tests).
    :return: The trained filter's predictions on ``data``.
    """
    processes = [
        LocalLevel(id='local_level').add_measure('y'),
        Season(id='day_in_week', seasonal_period=7, **self.config['season_spec']).add_measure('y'),
    ]
    kf = cls(measures=['y'], processes=processes)
    kf.opt = LBFGS(kf.parameters())

    # every group starts at the same datetime (zero offset from season_start):
    offsets = np.zeros(self.config['num_groups'], dtype='timedelta64')
    start_datetimes = offsets + self.config['season_spec']['season_start']

    def closure():
        # LBFGS re-evaluates the closure internally, so all work happens here:
        kf.opt.zero_grad()
        pred = kf(data, start_datetimes=start_datetimes)
        loss = -pred.log_prob(data).mean()
        loss.backward()
        return loss

    print(f"Will train for {num_epochs} epochs...")
    loss = float('nan')  # first DELTA prints as nan by design
    for i in range(num_epochs):
        new_loss = kf.opt.step(closure)
        print(f"EPOCH {i}, LOSS {new_loss.item()}, DELTA {loss - new_loss.item()}")
        loss = new_loss.item()

    return kf(data, start_datetimes=start_datetimes).predictions
def name_to_proc(id: str, **kwargs) -> Process:
    """Build a ``Process`` from a substring-matched id.

    Matching is by ``in`` (substring), checked in declaration order — e.g.
    'nn_predictors' must be tested before 'predictors'.

    :param id: Process id; its substrings select which Process class is built.
    :param kwargs: Forwarded to the chosen Process constructor.
    :return: The constructed Process.
    :raises NotImplementedError: If no known substring matches ``id``.
    """
    season_start = '2010-01-04'
    if 'hour_in_day' in id:
        out = FourierSeasonFixed(id=id, seasonal_period=24, season_start=season_start, dt_unit='h', **kwargs)
    elif 'day_in_year' in id:
        # 24 * 364.25: yearly period expressed in hours (dt_unit='h')
        out = FourierSeasonFixed(id=id, seasonal_period=24 * 364.25, season_start=season_start, dt_unit='h', **kwargs)
    elif 'local_level' in id:
        out = LocalLevel(id=id, **kwargs)
    elif 'local_trend' in id:
        out = LocalTrend(id=id, **kwargs)
    elif 'day_in_week' in id:
        # season_duration=24: each of the 7 seasons lasts 24 hourly steps
        out = Season(id=id, seasonal_period=7, season_duration=24, season_start=season_start, dt_unit='h', **kwargs)
    elif 'nn_predictors' in id:
        out = NN(id=id,
                 add_module_params_to_process=False,  # so we can use a separate parameter group
                 model_mat_kwarg_name='predictors',
                 **kwargs)
    elif 'predictors' in id:
        # NOTE(review): `self` is not a parameter of this function — this branch
        # will raise NameError unless the function is actually defined inside a
        # class/closure that provides `self`; confirm against the real context.
        out = LinearModel(id=id, covariates=self.predictors, model_mat_kwarg_name='predictors', **kwargs)
    else:
        raise NotImplementedError(f"Unsure what process to use for `{id}`.")
    return out
def test_predictions(self, ndim: int = 2):
    """Predictions behave like a (means, covs) pair and slice by group/time with list indices."""
    measures = [str(i) for i in range(ndim)]
    kf = KalmanFilter(
        processes=[LocalLevel(id=f'lm{i}', measure=m) for i, m in enumerate(measures)],
        measures=measures,
        compiled=False,
    )
    pred = kf(torch.zeros((2, 5, ndim)))

    # tuple / array protocols:
    self.assertEqual(len(tuple(pred)), 2)
    self.assertIsInstance(np.asanyarray(pred), np.ndarray)
    means, covs = pred
    self.assertIsInstance(means, torch.Tensor)
    self.assertIsInstance(covs, torch.Tensor)

    # scalar and tuple indexing are rejected:
    with self.assertRaises(TypeError):
        pred[1]
    with self.assertRaises(TypeError):
        pred[(1, )]

    # list-index along the group dim keeps that dim:
    pred_group2 = pred[[1]]
    self.assertTupleEqual(tuple(pred_group2.covs.shape), (1, 5, ndim, ndim))
    self.assertTrue((pred_group2.state_means == pred.state_means[1, :, :]).all())
    self.assertTrue((pred_group2.state_covs == pred.state_covs[1, :, :, :]).all())

    # list-index along the time dim keeps that dim:
    pred_time3 = pred[:, [2]]
    self.assertTupleEqual(tuple(pred_time3.covs.shape), (2, 1, ndim, ndim))
    self.assertTrue((pred_time3.state_means == pred.state_means[:, 2, :]).all())
    self.assertTrue((pred_time3.state_covs == pred.state_covs[:, 2, :, :]).all())
def simulate(num_groups: int, num_timesteps: int, season_spec: dict, noise: float = 1.0) -> torch.Tensor:
    """Draw simulated 'y' trajectories from a level + weekly + monthly-Fourier model.

    :param num_groups: Number of independent series to simulate.
    :param num_timesteps: Length of each series.
    :param season_spec: Kwargs shared by the seasonal processes (incl. 'season_start').
    :param noise: Measurement-noise scale passed to ``sample_measurements``.
    :return: Simulated data tensor.
    """
    kf = KalmanFilter(
        measures=['y'],
        processes=[
            LocalLevel(id='local_level').add_measure('y'),
            Season(id='day_in_week', seasonal_period=7, fixed=True, **season_spec).add_measure('y'),
            FourierSeasonFixed(id='day_in_month', seasonal_period=30, K=2, **season_spec).add_measure('y'),
        ],
    )

    # make local-level less aggressive (shrink its process variance):
    pcov = kf.design.process_covariance.create().data
    pcov[0, 0] *= .1
    kf.design.process_covariance.set(pcov)

    # all groups share the same start datetime:
    start_datetimes = np.zeros(num_groups, dtype='timedelta64') + season_spec['season_start']
    with torch.no_grad():
        dfb = kf.design.for_batch(num_groups=num_groups,
                                  num_timesteps=num_timesteps,
                                  start_datetimes=start_datetimes)
        trajectories = kf.predict_initial_state(dfb).simulate_trajectories(dfb)
        return trajectories.sample_measurements(eps=noise)
def name_to_proc(id: str, **kwargs) -> Process:
    """Build a ``Process`` from a substring-matched id.

    Matching is by ``in`` (substring), checked in declaration order.

    :param id: Process id; its substrings select which Process class is built.
    :param kwargs: Forwarded to the chosen Process constructor.
    :return: The constructed Process.
    :raises NotImplementedError: If no known substring matches ``id``.
    """
    season_start = '2010-01-04'
    if 'hour_in_day' in id:
        return FourierSeasonFixed(id=id, seasonal_period=24, season_start=season_start, dt_unit='h', **kwargs)
    if 'day_in_year' in id:
        # 24 * 364.25: yearly period expressed in hours (dt_unit='h')
        return FourierSeasonFixed(id=id, seasonal_period=24 * 364.25, season_start=season_start, dt_unit='h', **kwargs)
    if 'local_level' in id:
        return LocalLevel(id=id, **kwargs)
    if 'local_trend' in id:
        return LocalTrend(id=id, **kwargs)
    if 'day_in_week' in id:
        # season_duration=24: each of the 7 seasons lasts 24 hourly steps
        return Season(id=id, seasonal_period=7, season_duration=24, season_start=season_start, dt_unit='h', **kwargs)
    raise NotImplementedError(f"Unsure what process to use for `{id}`.")
def test_nans(self, ndim: int = 3, n_step: int = 1):
    """NaN handling: ``get_nan_groups`` partitions groups correctly and the filter emits no NaNs.

    Data layout: 5 groups x (4 + n_step) timesteps x ndim measures, all 10s,
    with NaNs injected only at t=2: group 0 loses dims [0, ndim-1) and group 2
    loses dim 0.
    """
    ntimes = 4 + n_step
    data = torch.ones((5, ntimes, ndim)) * 10
    data[0, 2, 0:(ndim - 1)] = float('nan')
    data[2, 2, 0] = float('nan')

    # test critical helper fun (also checks it survives torch.jit.script):
    get_nan_groups2 = torch.jit.script(get_nan_groups)
    # groups that contain any NaN at t=2; group 0 only has NaNs when ndim > 1
    # (its NaN slice 0:(ndim-1) is empty for ndim == 1):
    nan_groups = {2}
    if ndim > 1:
        nan_groups.add(0)
    for t in range(ntimes):
        for group_idx, valid_idx in get_nan_groups2(torch.isnan(data[:, t])):
            if t == 2:
                if valid_idx is None:
                    # fully-observed partition: everyone except the NaN groups
                    self.assertEqual(len(group_idx), data.shape[0] - len(nan_groups))
                    self.assertFalse(bool(set(group_idx.tolist()).intersection(nan_groups)))
                else:
                    # partially-observed partition: some but not all dims valid
                    self.assertLess(len(valid_idx), ndim)
                    self.assertGreater(len(valid_idx), 0)
                    if len(valid_idx) == 1:
                        if ndim == 2:
                            # both NaN groups collapse to the same valid-dim pattern {1}
                            self.assertSetEqual(set(valid_idx.tolist()), {1})
                            self.assertSetEqual(set(group_idx.tolist()), nan_groups)
                        else:
                            # only group 0 is left with a single valid dim (the last one)
                            self.assertSetEqual(set(valid_idx.tolist()), {ndim - 1})
                            self.assertSetEqual(set(group_idx.tolist()), {0})
                    else:
                        # group 2 lost only dim 0, so dims {1, 2} remain valid
                        self.assertSetEqual(set(valid_idx.tolist()), {1, 2})
                        self.assertSetEqual(set(group_idx.tolist()), {2})
            else:
                # no NaNs at other timesteps -> single fully-valid partition
                self.assertIsNone(valid_idx)

    # test `update`
    # TODO: measure dim vs. state-dim

    # test integration:
    # TODO: make missing dim highly correlated with observed dims. upward trend in observed should get reflected in
    # unobserved state
    kf = KalmanFilter(processes=[
        LocalLevel(id=f'lm{i}', measure=str(i)) for i in range(ndim)
    ],
                      measures=[str(i) for i in range(ndim)],
                      compiled=True)
    obs_means, obs_covs = kf(data, n_step=n_step)
    # NaN inputs must not propagate into the filtered output:
    self.assertFalse(torch.isnan(obs_means).any())
    self.assertFalse(torch.isnan(obs_covs).any())
    self.assertEqual(tuple(obs_means.shape), (5, ntimes, ndim))
def _make_kf():
    """Build a KalmanFilter with a local level plus a 5-predictor linear model per measure.

    Uses ``ndim`` from the enclosing scope; measures are named '1'..'ndim'.
    """
    measures = [str(i + 1) for i in range(ndim)]
    levels = [LocalLevel(id=f'll{i + 1}', measure=str(i + 1)) for i in range(ndim)]
    linear_models = [
        LinearModel(id=f'lm{i + 1}',
                    predictors=['x1', 'x2', 'x3', 'x4', 'x5'],
                    measure=str(i + 1))
        for i in range(ndim)
    ]
    return KalmanFilter(processes=levels + linear_models, measures=measures)
def test_gaussian_log_prob(self, ndim: int = 1):
    """The filter's own log_prob should agree with torch's MultivariateNormal."""
    from torch.distributions import MultivariateNormal

    data = torch.zeros((2, 5, ndim))
    measures = [str(i) for i in range(ndim)]
    kf = KalmanFilter(
        processes=[LocalLevel(id=f'lm{i}', measure=m) for i, m in enumerate(measures)],
        measures=measures,
    )
    pred = kf(data)
    # pred unpacks to (means, covs) — feed both implementations the same pair:
    log_lik1 = kf.kf_step.log_prob(data, *pred)
    log_lik2 = MultivariateNormal(*pred).log_prob(data)
    self.assertAlmostEqual(log_lik1.sum().item(), log_lik2.sum().item())
def test_dtype(self, dtype: torch.dtype, ndim: int = 2, compiled: bool = True):
    """Outputs and log-prob should carry the dtype of the module/data."""
    measures = [str(i) for i in range(ndim)]
    kf = KalmanFilter(
        processes=[LocalLevel(id=f'll{i}', measure=m) for i, m in enumerate(measures)],
        measures=measures,
        compiled=compiled,
    )
    kf.to(dtype=dtype)

    data = torch.zeros((2, 5, ndim), dtype=dtype)
    pred = kf(data)
    self.assertEqual(pred.means.dtype, dtype)
    self.assertEqual(pred.log_prob(data).dtype, dtype)
def _simulate(num_groups: int, num_timesteps: int, dt_unit: str, noise: float = 1.0) -> torch.Tensor:
    """Draw simulated 'y' trajectories from a level + weekly + yearly-Fourier model.

    :param num_groups: Number of independent series to simulate.
    :param num_timesteps: Length of each series.
    :param dt_unit: Datetime unit for the seasonal processes.
    :param noise: Measurement-noise scale passed to ``sample_measurements``.
    :return: Simulated data tensor.
    """
    kf = KalmanFilter(
        measures=['y'],
        processes=[
            LocalLevel(id='local_level').add_measure('y'),
            Season(id='day_in_week', seasonal_period=7, fixed=True, dt_unit=dt_unit).add_measure('y'),
            FourierSeason(id='day_in_year', seasonal_period=365.25, K=2, fixed=True,
                          dt_unit=dt_unit).add_measure('y'),
        ],
    )

    # all groups share the default start datetime:
    start_datetimes = np.zeros(num_groups, dtype='timedelta64') + DEFAULT_START_DT
    with torch.no_grad():
        dfb = kf.design.for_batch(num_groups=num_groups,
                                  num_timesteps=num_timesteps,
                                  start_datetimes=start_datetimes)
        trajectories = kf._predict_initial_state(dfb).simulate_trajectories(dfb)
        return trajectories.sample_measurements(eps=noise)
time_colname='date')  # NOTE(review): tail of a call whose opening lines are above this chunk

# Train/Val split:
dataset_train, dataset_val = dataset_all.train_val_split(dt=SPLIT_DT)
dataset_train, dataset_val  # bare expression: displayed as notebook cell output

# -

# #### Specify our Model
#
# The `KalmanFilter` subclasses `torch.nn.Module`. We specify the model by passing `processes` that capture the behaviors of our `measures`.

# one trend + local-level + yearly Fourier-seasonality process per measure:
processes = []
for measure in measures_pp:
    processes.extend([
        LocalTrend(id=f'{measure}_trend', multi=.01).add_measure(measure),
        LocalLevel(id=f'{measure}_local_level',
                   decay=(.90, 1.00)).add_measure(measure),
        FourierSeason(id=f'{measure}_day_in_year',
                      seasonal_period=365.25 / 7.,  # weekly steps -> yearly period
                      dt_unit='W',
                      K=2,
                      fixed=True).add_measure(measure)
    ])

# measure-variance itself is predicted by a yearly seasonal process (weekly dt):
kf_first = KalmanFilter(measures=measures_pp,
                        processes=processes,
                        measure_var_predict=('seasonal', dict(K=2, period='yearly', dt_unit='W')))

# Here we're showing off a few useful features of `torch-kalman`:
#