def test_slope_mean_and_variance(self):
  """Check that slope follows `slope_mean` and has stationary variance.

  The latent slope of a SemiLocalLinearTrend is expected to behave like an
  AR(1) process with coefficient `autoregressive_coef`. The slope prior is
  initialized at the AR(1) stationary variance. The test then checks that:

    1. the per-step increment of the model mean converges to `slope_mean`
       (rather than staying at `initial_slope`), and
    2. the latent slope variance at the final timestep equals the AR(1)
       stationary variance.
  """
  level_scale = 0.1
  slope_scale = 0.2
  initial_level = 3.
  initial_slope = 0.
  slope_mean = -2.
  autoregressive_coef = 0.9
  num_timesteps = 50

  # Stationary distribution of an AR1 process, from
  # (https://en.wikipedia.org/wiki/Autoregressive_model#Example:_An_AR(1)_process)  # pylint: disable=line-too-long
  stationary_slope_variance = slope_scale**2 / (1. - autoregressive_coef**2)

  # Initialize the slope prior at the stationary variance. The level prior
  # has unit scale; only the slope component matters for the variance check.
  initial_state_prior = tfd.MultivariateNormalDiag(
      loc=self._build_placeholder([initial_level, initial_slope]),
      scale_diag=self._build_placeholder(
          [1., np.sqrt(stationary_slope_variance)]))

  semilocal_ssm = SemiLocalLinearTrendStateSpaceModel(
      num_timesteps=num_timesteps,
      level_scale=self._build_placeholder(level_scale),
      slope_scale=self._build_placeholder(slope_scale),
      slope_mean=self._build_placeholder(slope_mean),
      autoregressive_coef=self._build_placeholder(autoregressive_coef),
      initial_state_prior=initial_state_prior)

  # The slope of the mean should converge to `slope_mean` (as opposed to
  # staying fixed at `initial_slope` as in a LocalLinearTrend). The empirical
  # slope is the difference between the last two mean values.
  mean_ = self.evaluate(semilocal_ssm.mean()[..., 0])
  final_slope = mean_[-1] - mean_[-2]
  self.assertAllClose(final_slope, slope_mean, atol=0.05)

  # The variance in latent `slope` should converge to the stationary
  # distribution of an AR1 process.
  # NOTE(review): assumes `_joint_covariances()` returns latent covariances
  # first, shaped [num_timesteps, 2, 2] with the slope at state index 1 --
  # confirm against the StateSpaceModel implementation.
  latent_covs, _ = semilocal_ssm._joint_covariances()  # pylint: disable=protected-access
  actual_slope_variances = tf.linalg.diag_part(latent_covs)[:, 1]
  converged_slope_variance = actual_slope_variances[-1]
  self.assertAllClose(self.evaluate(converged_slope_variance),
                      stationary_slope_variance, atol=1e-4)
def test_matches_locallineartrend(self):
  """SemiLocalLinearTrend with trivial AR process is a LocalLinearTrend.

  With `slope_mean=0.` and `autoregressive_coef=1.`, the semi-local model
  is compared against a LocalLinearTrendStateSpaceModel that shares the same
  scales and prior parameters. Their log-probabilities, means and variances
  must all agree.
  """
  num_timesteps = 5
  level_scale = self._build_placeholder(0.5)
  slope_scale = self._build_placeholder(0.5)
  initial_level = self._build_placeholder(3.)
  initial_slope = self._build_placeholder(-2.)
  observations = self._build_placeholder([1.0, 2.5, 4.3, 6.1, 7.8])

  def make_prior():
    # Each model gets its own freshly constructed prior with identical
    # parameters.
    return tfd.MultivariateNormalDiag(
        loc=[initial_level, initial_slope],
        scale_diag=self._build_placeholder([1., 1.]))

  semilocal_ssm = SemiLocalLinearTrendStateSpaceModel(
      num_timesteps=num_timesteps,
      level_scale=level_scale,
      slope_scale=slope_scale,
      slope_mean=self._build_placeholder(0.),
      autoregressive_coef=self._build_placeholder(1.),
      initial_state_prior=make_prior())

  local_ssm = LocalLinearTrendStateSpaceModel(
      num_timesteps=num_timesteps,
      level_scale=level_scale,
      slope_scale=slope_scale,
      initial_state_prior=make_prior())

  # Add a trailing axis of size 1 to the observations (presumably the
  # per-timestep observation dimension expected by `log_prob`).
  observed_series = observations[:, tf.newaxis]

  # Each statistic is computed on both models and compared.
  statistics = (
      lambda ssm: ssm.log_prob(observed_series),
      lambda ssm: ssm.mean(),
      lambda ssm: ssm.variance(),
  )
  for statistic in statistics:
    semilocal_value = self.evaluate(statistic(semilocal_ssm))
    local_value = self.evaluate(statistic(local_ssm))
    self.assertAllClose(semilocal_value, local_value)