def test_batch_shape(self):
  """Model and sample batch shapes follow the batch shape of the parameters."""
  batch_shape = [4, 2]

  def _random_param(positive=False):
    # Standard-normal draws of shape `batch_shape`; exponentiated for the
    # scale parameters so they are strictly positive.
    draws = np.random.randn(*batch_shape)
    if positive:
      draws = np.exp(draws)
    return self._build_placeholder(draws)

  level_scale = _random_param(positive=True)
  slope_scale = _random_param(positive=True)
  autoregressive_coef = _random_param()
  slope_mean = _random_param()
  initial_state_prior = tfd.MultivariateNormalDiag(
      scale_diag=self._build_placeholder([1., 1.]))

  ssm = SemiLocalLinearTrendStateSpaceModel(
      num_timesteps=10,
      level_scale=level_scale,
      slope_scale=slope_scale,
      autoregressive_coef=autoregressive_coef,
      slope_mean=slope_mean,
      initial_state_prior=initial_state_prior)

  model_batch_shape = (
      ssm.batch_shape.as_list() if self.use_static_shape
      else self.evaluate(ssm.batch_shape_tensor()))
  self.assertAllEqual(model_batch_shape, batch_shape)

  # The two trailing axes of a sample are event dimensions; everything
  # before them is batch.
  y = ssm.sample()
  y_batch_shape = (
      y.shape.as_list()[:-2] if self.use_static_shape
      else self.evaluate(tf.shape(input=y))[:-2])
  self.assertAllEqual(y_batch_shape, batch_shape)
def test_logprob(self):
  """Log-density of a short fixed series matches a precomputed reference."""
  observed = self._build_placeholder([1.0, 2.5, 4.3, 6.1, 7.8])
  reference_log_prob = -9.846248626708984

  ssm = SemiLocalLinearTrendStateSpaceModel(
      num_timesteps=5,
      level_scale=self._build_placeholder(0.5),
      slope_scale=self._build_placeholder(0.5),
      slope_mean=self._build_placeholder(0.2),
      autoregressive_coef=self._build_placeholder(0.3),
      initial_state_prior=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder([1., 1.])))

  # Append a trailing unit axis so each timestep is a length-1 observation.
  log_prob = ssm.log_prob(observed[:, tf.newaxis])
  self.assertAllClose(self.evaluate(log_prob), reference_log_prob)
def test_slope_mean_and_variance(self):
  """Check that slope follows `slope_mean` and has stationary variance.

  The latent slope starts at `initial_slope`. Its mean should be pulled toward
  `slope_mean`, and its variance should match the stationary variance of the
  AR(1) slope process.
  """
  level_scale = 0.1
  slope_scale = 0.2
  initial_level = 3.
  initial_slope = 0.
  slope_mean = -2.
  autoregressive_coef = 0.9
  num_timesteps = 50

  # Stationary distribution of an AR1 process, from
  # (https://en.wikipedia.org/wiki/Autoregressive_model#Example:_An_AR(1)_process)  # pylint: disable=line-too-long
  stationary_slope_variance = slope_scale**2 / (1. - autoregressive_coef**2)

  # Initialize the slope prior at the stationary variance.
  initial_state_prior = tfd.MultivariateNormalDiag(
      loc=self._build_placeholder([initial_level, initial_slope]),
      scale_diag=self._build_placeholder(
          [1., np.sqrt(stationary_slope_variance)]))

  semilocal_ssm = SemiLocalLinearTrendStateSpaceModel(
      num_timesteps=num_timesteps,
      level_scale=self._build_placeholder(level_scale),
      slope_scale=self._build_placeholder(slope_scale),
      slope_mean=self._build_placeholder(slope_mean),
      autoregressive_coef=self._build_placeholder(autoregressive_coef),
      initial_state_prior=initial_state_prior)

  # The slope of the mean should converge to `slope_mean` (as opposed to
  # staying fixed at `initial_slope` as in a LocalLinearTrend). The last
  # one-step difference of the observation mean estimates that slope.
  mean_ = self.evaluate(semilocal_ssm.mean()[..., 0])
  final_slope = mean_[-1] - mean_[-2]
  self.assertAllClose(final_slope, slope_mean, atol=0.05)

  # The variance in latent `slope` should converge to the stationary
  # distribution of an AR1 process. `_joint_covariances` is private; its first
  # output is presumably the per-timestep latent covariance, with diagonal
  # entry 1 being the slope component -- confirm if the SSM internals change.
  latent_covs, _ = semilocal_ssm._joint_covariances()  # pylint: disable=protected-access
  actual_slope_variances = tf.linalg.diag_part(latent_covs)[:, 1]
  converged_slope_variance = actual_slope_variances[-1]
  self.assertAllClose(self.evaluate(converged_slope_variance),
                      stationary_slope_variance, atol=1e-4)
def test_matches_locallineartrend(self):
  """SemiLocalLinearTrend with trivial AR process is a LocalLinearTrend.

  With `slope_mean=0` and `autoregressive_coef=1`, the semi-local model should
  agree with an identically parameterized `LocalLinearTrendStateSpaceModel`
  in log-probability, mean, and variance.
  """
  level_scale = self._build_placeholder(0.5)
  slope_scale = self._build_placeholder(0.5)
  initial_level = self._build_placeholder(3.)
  initial_slope = self._build_placeholder(-2.)
  num_timesteps = 5
  series = self._build_placeholder([1.0, 2.5, 4.3, 6.1, 7.8])

  def _make_prior():
    # Each model receives its own, identically parameterized, prior.
    return tfd.MultivariateNormalDiag(
        loc=[initial_level, initial_slope],
        scale_diag=self._build_placeholder([1., 1.]))

  semilocal_ssm = SemiLocalLinearTrendStateSpaceModel(
      num_timesteps=num_timesteps,
      level_scale=level_scale,
      slope_scale=slope_scale,
      slope_mean=self._build_placeholder(0.),
      autoregressive_coef=self._build_placeholder(1.),
      initial_state_prior=_make_prior())
  local_ssm = LocalLinearTrendStateSpaceModel(
      num_timesteps=num_timesteps,
      level_scale=level_scale,
      slope_scale=slope_scale,
      initial_state_prior=_make_prior())

  # Give each timestep a trailing unit axis (a length-1 observation).
  observations = series[:, tf.newaxis]
  semilocal_lp = semilocal_ssm.log_prob(observations)
  local_lp = local_ssm.log_prob(observations)
  self.assertAllClose(self.evaluate(semilocal_lp), self.evaluate(local_lp))

  for statistic in ('mean', 'variance'):
    semilocal_value = getattr(semilocal_ssm, statistic)()
    local_value = getattr(local_ssm, statistic)()
    self.assertAllClose(self.evaluate(semilocal_value),
                        self.evaluate(local_value))