Example #1
0
    def test_logprob(self):
        """Log-density of a short series matches a hard-coded expected value."""
        observations = self._build_placeholder([1.0, 2.5, 4.3, 6.1, 7.8])

        # Hyperparameters are built in the same order as the constructor's
        # keyword arguments so placeholder creation order is unchanged.
        model_kwargs = dict(
            num_timesteps=5,
            level_scale=self._build_placeholder(0.5),
            slope_scale=self._build_placeholder(0.5),
            slope_mean=self._build_placeholder(0.2),
            autoregressive_coef=self._build_placeholder(0.3),
        )
        # Zero-mean (no `loc` given) diagonal prior over the two latent states.
        prior = tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder([1., 1.]))
        model = SemiLocalLinearTrendStateSpaceModel(
            initial_state_prior=prior, **model_kwargs)

        # A trailing size-1 axis turns the length-5 series into shape [5, 1],
        # i.e. one length-1 observation vector per timestep.
        log_density = model.log_prob(observations[:, np.newaxis])
        self.assertAllClose(self.evaluate(log_density), -9.846248626708984)
Example #2
0
    def test_matches_locallineartrend(self):
        """A SemiLocalLinearTrend with a trivial AR process equals a LocalLinearTrend.

        The semi-local model is configured with `slope_mean=0` and
        `autoregressive_coef=1`. Its log-prob, mean and variance must then
        agree with those of a LocalLinearTrend model that shares the same
        scales and initial-state prior.
        """
        level_scale = self._build_placeholder(0.5)
        slope_scale = self._build_placeholder(0.5)
        initial_level = self._build_placeholder(3.)
        initial_slope = self._build_placeholder(-2.)
        series = self._build_placeholder([1.0, 2.5, 4.3, 6.1, 7.8])

        def make_prior():
            # Each model gets its own prior object (and its own scale_diag
            # placeholder), centred on the same initial level and slope.
            return tfd.MultivariateNormalDiag(
                loc=[initial_level, initial_slope],
                scale_diag=self._build_placeholder([1., 1.]))

        common = dict(num_timesteps=5,
                      level_scale=level_scale,
                      slope_scale=slope_scale)

        semilocal_ssm = SemiLocalLinearTrendStateSpaceModel(
            slope_mean=self._build_placeholder(0.),
            autoregressive_coef=self._build_placeholder(1.),
            initial_state_prior=make_prior(),
            **common)
        local_ssm = LocalLinearTrendStateSpaceModel(
            initial_state_prior=make_prior(), **common)

        statistics = (
            lambda model: model.log_prob(series[:, tf.newaxis]),
            lambda model: model.mean(),
            lambda model: model.variance(),
        )
        for statistic in statistics:
            # Build both tensors first, then evaluate them pairwise.
            semilocal_value = statistic(semilocal_ssm)
            local_value = statistic(local_ssm)
            self.assertAllClose(self.evaluate(semilocal_value),
                                self.evaluate(local_value))