Esempio n. 1
0
    def test_batch_shape(self):
        """Checks that SeasonalStateSpaceModel broadcasts parameter batch shapes.

        `drift_scale` is built with batch shape `batch_shape` ([3, 2]), while
        the observation noise scale and the initial-state prior are built with
        batch shape `partial_batch_shape` ([2]). The test asserts that the
        model's batch shape, and the leading dims of a sample, equal the
        broadcast shape. It also asserts that the batched `log_prob` equals the
        log probs of independently constructed per-member models.
        """
        batch_shape = [3, 2]
        partial_batch_shape = [2]
        seed = test_util.test_seed(sampler_type='stateless')

        # Shared model configuration. The batched model and the per-member
        # models below must agree on these values, so each is named once here
        # instead of being repeated as a literal.
        num_timesteps = 9
        num_seasons = 24
        num_steps_per_season = 2

        initial_state_prior = tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder(
                np.exp(np.random.randn(*(partial_batch_shape +
                                         [num_seasons])))))
        drift_scale = self._build_placeholder(
            np.exp(np.random.randn(*batch_shape)))
        observation_noise_scale = self._build_placeholder(
            np.exp(np.random.randn(*partial_batch_shape)))

        ssm = SeasonalStateSpaceModel(
            num_timesteps=num_timesteps,
            num_seasons=num_seasons,
            num_steps_per_season=num_steps_per_season,
            drift_scale=drift_scale,
            observation_noise_scale=observation_noise_scale,
            initial_state_prior=initial_state_prior)

        # First check that the model's batch shape is the broadcast batch shape
        # of parameters, as expected.
        self.assertAllEqual(self.evaluate(ssm.batch_shape_tensor()),
                            batch_shape)
        y_ = self.evaluate(ssm.sample(seed=seed))
        # The last two sample dims are not batch dims; compare the rest.
        self.assertAllEqual(y_.shape[:-2], batch_shape)

        # Next check that the broadcasting works as expected, and the batch
        # log_prob actually matches the log probs of independent models.
        # Member (i, j) takes drift_scale[i, j] plus the j-th entry of the
        # parameters that were only batched over the trailing dim.
        individual_ssms = [
            SeasonalStateSpaceModel(
                num_timesteps=num_timesteps,
                num_seasons=num_seasons,
                num_steps_per_season=num_steps_per_season,
                drift_scale=drift_scale[i, j, ...],
                observation_noise_scale=observation_noise_scale[j, ...],
                initial_state_prior=tfd.MultivariateNormalDiag(
                    scale_diag=initial_state_prior.scale.diag[j, ...]))
            for i in range(batch_shape[0]) for j in range(batch_shape[1])
        ]

        # flatten() is row-major, which matches the i-outer / j-inner order
        # used to build the per-member lists.
        batch_lps_ = self.evaluate(ssm.log_prob(y_)).flatten()

        individual_ys = [
            y_[i, j, ...] for i in range(batch_shape[0])
            for j in range(batch_shape[1])
        ]
        individual_lps_ = self.evaluate([
            individual_ssm.log_prob(individual_y)
            for individual_ssm, individual_y in zip(individual_ssms,
                                                     individual_ys)
        ])

        self.assertAllClose(individual_lps_, batch_lps_)
Esempio n. 2
0
    def test_day_of_week_example(self):
        """Checks a day-of-week seasonal model against per-day random walks.

        Test that the Seasonal SSM is equivalent to individually modeling a
        random walk on each season's slice of timesteps. For every season, the
        model's prior mean and variance on that season's (noncontiguous)
        timesteps must match a scalar random-walk model. The model's total
        log_prob must equal the sum of the per-season random-walk log probs.
        """
        seed = test_util.test_seed(sampler_type='stateless')

        # Every size below is derived from these, so the seasonal model, the
        # reference random walk, and the strided slices cannot disagree.
        num_seasons = 7  # Days of the week.
        num_weeks = 4
        drift_scale = 0.6
        observation_noise_scale = 0.1

        day_of_week = SeasonalStateSpaceModel(
            num_timesteps=num_seasons * num_weeks,
            num_seasons=num_seasons,
            drift_scale=self._build_placeholder(drift_scale),
            observation_noise_scale=observation_noise_scale,
            initial_state_prior=tfd.MultivariateNormalDiag(
                scale_diag=self._build_placeholder(np.ones([num_seasons]))),
            num_steps_per_season=1)

        # Reference model for a single season. With one step per season, each
        # season's slice contains one timestep per week.
        random_walk_model = tfd.LinearGaussianStateSpaceModel(
            num_timesteps=num_weeks,
            transition_matrix=self._build_placeholder([[1.]]),
            transition_noise=tfd.MultivariateNormalDiag(
                scale_diag=self._build_placeholder([drift_scale])),
            observation_matrix=self._build_placeholder([[1.]]),
            observation_noise=tfd.MultivariateNormalDiag(
                scale_diag=self._build_placeholder([observation_noise_scale])),
            initial_state_prior=tfd.MultivariateNormalDiag(
                scale_diag=self._build_placeholder([1.])))

        sampled_time_series = day_of_week.sample(seed=seed)
        (sampled_time_series_, total_lp_, prior_mean_,
         prior_variance_) = self.evaluate([
             sampled_time_series,
             day_of_week.log_prob(sampled_time_series),
             day_of_week.mean(),
             day_of_week.variance()
         ])

        expected_daily_means_, expected_daily_variances_ = self.evaluate(
            [random_walk_model.mean(),
             random_walk_model.variance()])

        # For the (noncontiguous) indices corresponding to each season, assert
        # that the model's mean, variance, and log_prob match a random-walk
        # model. Striding by num_seasons selects one season's timesteps.
        daily_lps = []
        for day_idx in range(num_seasons):
            self.assertAllClose(prior_mean_[day_idx::num_seasons],
                                expected_daily_means_)
            self.assertAllClose(prior_variance_[day_idx::num_seasons],
                                expected_daily_variances_)

            daily_lps.append(
                self.evaluate(
                    random_walk_model.log_prob(
                        sampled_time_series_[day_idx::num_seasons])))

        # If the seasons are independent random walks, the total log_prob is
        # the sum of the per-season log probs.
        self.assertAlmostEqual(total_lp_, sum(daily_lps), places=3)
Esempio n. 3
0
  def test_day_of_week_example(self):
    """Checks a day-of-week seasonal model against per-day random walks.

    Test that the Seasonal SSM is equivalent to individually modeling a random
    walk on each season's slice of timesteps. For every season, the model's
    prior mean and variance on that season's (noncontiguous) timesteps must
    match a scalar random-walk model. The model's total log_prob must equal
    the sum of the per-season random-walk log probs.
    """
    # Seed the sample so the compared log probs are reproducible across runs.
    seed = test_util.test_seed(sampler_type='stateless')

    # Every size below is derived from these, so the seasonal model, the
    # reference random walk, and the strided slices cannot disagree.
    num_seasons = 7  # Days of the week.
    num_weeks = 4
    drift_scale = 0.6
    observation_noise_scale = 0.1

    day_of_week = SeasonalStateSpaceModel(
        num_timesteps=num_seasons * num_weeks,
        num_seasons=num_seasons,
        drift_scale=self._build_placeholder(drift_scale),
        observation_noise_scale=observation_noise_scale,
        initial_state_prior=tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder(np.ones([num_seasons]))),
        num_steps_per_season=1)

    # Reference model for a single season. With one step per season, each
    # season's slice contains one timestep per week.
    random_walk_model = tfd.LinearGaussianStateSpaceModel(
        num_timesteps=num_weeks,
        transition_matrix=self._build_placeholder([[1.]]),
        transition_noise=tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder([drift_scale])),
        observation_matrix=self._build_placeholder([[1.]]),
        observation_noise=tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder([observation_noise_scale])),
        initial_state_prior=tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder([1.])))

    sampled_time_series = day_of_week.sample(seed=seed)
    (sampled_time_series_, total_lp_,
     prior_mean_, prior_variance_) = self.evaluate([
         sampled_time_series,
         day_of_week.log_prob(sampled_time_series),
         day_of_week.mean(),
         day_of_week.variance()])

    expected_daily_means_, expected_daily_variances_ = self.evaluate([
        random_walk_model.mean(),
        random_walk_model.variance()])

    # For the (noncontiguous) indices corresponding to each season, assert
    # that the model's mean, variance, and log_prob match a random-walk model.
    # Striding by num_seasons selects one season's timesteps.
    daily_lps = []
    for day_idx in range(num_seasons):
      self.assertAllClose(prior_mean_[day_idx::num_seasons],
                          expected_daily_means_)
      self.assertAllClose(prior_variance_[day_idx::num_seasons],
                          expected_daily_variances_)

      daily_lps.append(self.evaluate(random_walk_model.log_prob(
          sampled_time_series_[day_idx::num_seasons])))

    # If the seasons are independent random walks, the total log_prob is the
    # sum of the per-season log probs.
    self.assertAlmostEqual(total_lp_, sum(daily_lps), places=3)
Esempio n. 4
0
  def test_batch_shape(self):
    """Checks that SeasonalStateSpaceModel broadcasts parameter batch shapes.

    `drift_scale` is built with batch shape `batch_shape` ([3, 2]), while the
    observation noise scale and the initial-state prior are built with batch
    shape `partial_batch_shape` ([2]). The test asserts that the model's batch
    shape, and the leading dims of a sample, equal the broadcast shape. It
    also asserts that the batched `log_prob` equals the log probs of
    independently constructed per-member models.
    """
    batch_shape = [3, 2]
    partial_batch_shape = [2]
    # Seed the sample so the compared log probs are reproducible across runs.
    seed = test_util.test_seed(sampler_type='stateless')

    # Shared model configuration. The batched model and the per-member models
    # below must agree on these values, so each is named once here instead of
    # being repeated as a literal.
    num_timesteps = 9
    num_seasons = 24
    num_steps_per_season = 2

    initial_state_prior = tfd.MultivariateNormalDiag(
        scale_diag=self._build_placeholder(
            np.exp(np.random.randn(*(partial_batch_shape + [num_seasons])))))
    drift_scale = self._build_placeholder(
        np.exp(np.random.randn(*batch_shape)))
    observation_noise_scale = self._build_placeholder(
        np.exp(np.random.randn(*partial_batch_shape)))

    ssm = SeasonalStateSpaceModel(
        num_timesteps=num_timesteps,
        num_seasons=num_seasons,
        num_steps_per_season=num_steps_per_season,
        drift_scale=drift_scale,
        observation_noise_scale=observation_noise_scale,
        initial_state_prior=initial_state_prior)

    # First check that the model's batch shape is the broadcast batch shape
    # of parameters, as expected.
    self.assertAllEqual(self.evaluate(ssm.batch_shape_tensor()), batch_shape)
    y_ = self.evaluate(ssm.sample(seed=seed))
    # The last two sample dims are not batch dims; compare the rest.
    self.assertAllEqual(y_.shape[:-2], batch_shape)

    # Next check that the broadcasting works as expected, and the batch log_prob
    # actually matches the log probs of independent models.
    # Member (i, j) takes drift_scale[i, j] plus the j-th entry of the
    # parameters that were only batched over the trailing dim.
    individual_ssms = [SeasonalStateSpaceModel(
        num_timesteps=num_timesteps,
        num_seasons=num_seasons,
        num_steps_per_season=num_steps_per_season,
        drift_scale=drift_scale[i, j, ...],
        observation_noise_scale=observation_noise_scale[j, ...],
        initial_state_prior=tfd.MultivariateNormalDiag(
            scale_diag=initial_state_prior.scale.diag[j, ...]))
                       for i in range(batch_shape[0])
                       for j in range(batch_shape[1])]

    # flatten() is row-major, which matches the i-outer / j-inner order used
    # to build the per-member lists.
    batch_lps_ = self.evaluate(ssm.log_prob(y_)).flatten()

    individual_ys = [y_[i, j, ...]
                     for i in range(batch_shape[0])
                     for j in range(batch_shape[1])]
    individual_lps_ = self.evaluate([
        individual_ssm.log_prob(individual_y)
        for (individual_ssm, individual_y)
        in zip(individual_ssms, individual_ys)])

    self.assertAllClose(individual_lps_, batch_lps_)