def test_batch_shape(self):
  """Checks batch-shape broadcasting of SeasonalStateSpaceModel parameters.

  The drift scale has batch shape [3, 2], while the observation noise scale
  and the initial-state prior have batch shape [2]. The test verifies that
  the model's batch shape is the broadcast shape [3, 2]. It also verifies
  that the batched `log_prob` of a sample equals the `log_prob` computed by
  independently constructed single-member models.
  """
  batch_shape = [3, 2]
  partial_batch_shape = [2]
  seed = test_util.test_seed(sampler_type='stateless')
  num_seasons = 24
  num_timesteps = 9
  num_steps_per_season = 2

  # Parameters with different but broadcast-compatible batch shapes. All
  # values are positive because they are exponentiated standard normals.
  initial_state_prior = tfd.MultivariateNormalDiag(
      scale_diag=self._build_placeholder(
          np.exp(np.random.randn(*(partial_batch_shape + [num_seasons])))))
  drift_scale = self._build_placeholder(
      np.exp(np.random.randn(*batch_shape)))
  observation_noise_scale = self._build_placeholder(
      np.exp(np.random.randn(*partial_batch_shape)))

  ssm = SeasonalStateSpaceModel(
      num_timesteps=num_timesteps,
      # Shared constant (previously a duplicated literal 24), so it stays in
      # sync with the event size of `initial_state_prior` above.
      num_seasons=num_seasons,
      num_steps_per_season=num_steps_per_season,
      drift_scale=drift_scale,
      observation_noise_scale=observation_noise_scale,
      initial_state_prior=initial_state_prior)

  # First check that the model's batch shape is the broadcast batch shape
  # of parameters, as expected.
  self.assertAllEqual(self.evaluate(ssm.batch_shape_tensor()), batch_shape)
  y_ = self.evaluate(ssm.sample(seed=seed))
  # The sample's leading dimensions (all but the last two) are the batch.
  self.assertAllEqual(y_.shape[:-2], batch_shape)

  # Next check that the broadcasting works as expected, and the batch log_prob
  # actually matches the log probs of independent models. The index pairs are
  # enumerated in row-major order (i outer, j inner), matching the element
  # order produced by `.flatten()` on the batched log_prob below.
  batch_indices = [(i, j)
                   for i in range(batch_shape[0])
                   for j in range(batch_shape[1])]
  # Parameters with batch shape [2] broadcast along the leading batch axis,
  # so they are indexed by `j` only.
  individual_ssms = [
      SeasonalStateSpaceModel(
          num_timesteps=num_timesteps,
          num_seasons=num_seasons,
          num_steps_per_season=num_steps_per_season,
          drift_scale=drift_scale[i, j, ...],
          observation_noise_scale=observation_noise_scale[j, ...],
          initial_state_prior=tfd.MultivariateNormalDiag(
              scale_diag=initial_state_prior.scale.diag[j, ...]))
      for i, j in batch_indices
  ]
  individual_ys = [y_[i, j, ...] for i, j in batch_indices]

  batch_lps_ = self.evaluate(ssm.log_prob(y_)).flatten()
  individual_lps_ = self.evaluate([
      individual_ssm.log_prob(individual_y)
      for individual_ssm, individual_y in zip(individual_ssms, individual_ys)
  ])
  self.assertAllClose(individual_lps_, batch_lps_)
def test_day_of_week_example(self):
  """Compares a 7-season, one-step-per-season SSM against per-day random walks.

  The timesteps belonging to a single season (every 7th step, starting at
  that day's offset) should behave like an independent scalar random walk.
  The test checks that the seasonal model's prior mean and variance on each
  such slice match a `LinearGaussianStateSpaceModel` random walk with the
  same drift and observation-noise scales. It also checks that the seasonal
  model's log_prob of a sample equals the sum of the per-day log_probs.
  """
  seed = test_util.test_seed(sampler_type='stateless')
  num_days = 7
  num_weeks = 4
  drift_scale = 0.6
  observation_noise_scale = 0.1

  seasonal_model = SeasonalStateSpaceModel(
      num_timesteps=num_days * num_weeks,
      num_seasons=num_days,
      drift_scale=self._build_placeholder(drift_scale),
      observation_noise_scale=observation_noise_scale,
      initial_state_prior=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder(np.ones([num_days]))),
      num_steps_per_season=1)

  # Reference model: a scalar random walk with one step per week.
  per_day_walk = tfd.LinearGaussianStateSpaceModel(
      num_timesteps=num_weeks,
      transition_matrix=self._build_placeholder([[1.]]),
      transition_noise=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder([drift_scale])),
      observation_matrix=self._build_placeholder([[1.]]),
      observation_noise=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder([observation_noise_scale])),
      initial_state_prior=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder([1.])))

  draw = seasonal_model.sample(seed=seed)
  # Evaluate the draw together with its log_prob so that `joint_lp_` is the
  # log_prob of exactly the returned sample.
  draw_, joint_lp_, seasonal_mean_, seasonal_var_ = self.evaluate([
      draw,
      seasonal_model.log_prob(draw),
      seasonal_model.mean(),
      seasonal_model.variance(),
  ])
  walk_mean_, walk_var_ = self.evaluate(
      [per_day_walk.mean(), per_day_walk.variance()])

  # A stride of `num_days` selects the (noncontiguous) timesteps that belong
  # to a single season.
  per_day_lp_tensors = []
  for day in range(num_days):
    day_slice = slice(day, None, num_days)
    self.assertAllClose(seasonal_mean_[day_slice], walk_mean_)
    self.assertAllClose(seasonal_var_[day_slice], walk_var_)
    per_day_lp_tensors.append(per_day_walk.log_prob(draw_[day_slice]))

  # The joint log_prob should equal the sum of the per-day log_probs.
  self.assertAlmostEqual(
      joint_lp_, sum(self.evaluate(per_day_lp_tensors)), places=3)
def test_day_of_week_example(self):
  """Checks that a day-of-week seasonal SSM matches per-day random walks.

  With `num_seasons=7` and `num_steps_per_season=1`, the timesteps belonging
  to each season (every 7th step, starting at that day's offset) are
  compared against a scalar `LinearGaussianStateSpaceModel` random walk. The
  random walk uses the same drift and observation-noise scales. The test
  checks prior mean, prior variance, and the log_prob of a sample.
  """
  # NOTE(review): a method with this name may also be defined elsewhere in
  # this file; if both live in the same class, only the later definition is
  # run by the test runner — confirm which version is intended.
  # Test that the Seasonal SSM is equivalent to individually modeling
  # a random walk on each season's slice of timesteps.
  # A stateless seed makes the sampled series (and thus the log_prob check)
  # deterministic across runs; previously `sample()` was unseeded.
  seed = test_util.test_seed(sampler_type='stateless')
  drift_scale = 0.6
  observation_noise_scale = 0.1
  day_of_week = SeasonalStateSpaceModel(
      num_timesteps=28,
      num_seasons=7,
      drift_scale=self._build_placeholder(drift_scale),
      observation_noise_scale=observation_noise_scale,
      initial_state_prior=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder(np.ones([7]))),
      num_steps_per_season=1)

  # 28 timesteps / 7 seasons = 4 steps for each per-day random walk.
  random_walk_model = tfd.LinearGaussianStateSpaceModel(
      num_timesteps=4,
      transition_matrix=self._build_placeholder([[1.]]),
      transition_noise=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder([drift_scale])),
      observation_matrix=self._build_placeholder([[1.]]),
      observation_noise=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder([observation_noise_scale])),
      initial_state_prior=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder([1.])))

  sampled_time_series = day_of_week.sample(seed=seed)
  # Evaluate the sample together with its log_prob so that `total_lp_` is the
  # log_prob of exactly the returned sample.
  (sampled_time_series_, total_lp_,
   prior_mean_, prior_variance_) = self.evaluate([
       sampled_time_series,
       day_of_week.log_prob(sampled_time_series),
       day_of_week.mean(),
       day_of_week.variance()])
  expected_daily_means_, expected_daily_variances_ = self.evaluate([
      random_walk_model.mean(),
      random_walk_model.variance()])

  # For the (noncontiguous) indices corresponding to each season, assert
  # that the model's mean, variance, and log_prob match a random-walk model.
  daily_lps = []
  for day_idx in range(7):
    self.assertAllClose(prior_mean_[day_idx::7], expected_daily_means_)
    self.assertAllClose(prior_variance_[day_idx::7], expected_daily_variances_)
    daily_lps.append(self.evaluate(random_walk_model.log_prob(
        sampled_time_series_[day_idx::7])))

  self.assertAlmostEqual(total_lp_, sum(daily_lps), places=3)
def test_batch_shape(self):
  """Checks batch-shape broadcasting of SeasonalStateSpaceModel parameters.

  The drift scale has batch shape [3, 2], while the observation noise scale
  and the initial-state prior have batch shape [2]. The test verifies that
  the model's batch shape is the broadcast shape [3, 2]. It also verifies
  that the batched `log_prob` of a sample equals the `log_prob` computed by
  independently constructed single-member models.
  """
  # NOTE(review): a method with this name may also be defined elsewhere in
  # this file; if both live in the same class, only the later definition is
  # run by the test runner — confirm which version is intended.
  batch_shape = [3, 2]
  partial_batch_shape = [2]
  # Stateless seed so the sample (and the log_prob comparison) is
  # deterministic; previously `ssm.sample()` was unseeded.
  seed = test_util.test_seed(sampler_type='stateless')
  num_seasons = 24
  num_timesteps = 9
  num_steps_per_season = 2

  # Parameters with different but broadcast-compatible batch shapes. All
  # values are positive because they are exponentiated standard normals.
  initial_state_prior = tfd.MultivariateNormalDiag(
      scale_diag=self._build_placeholder(
          np.exp(np.random.randn(*(partial_batch_shape + [num_seasons])))))
  drift_scale = self._build_placeholder(
      np.exp(np.random.randn(*batch_shape)))
  observation_noise_scale = self._build_placeholder(
      np.exp(np.random.randn(*partial_batch_shape)))

  ssm = SeasonalStateSpaceModel(
      num_timesteps=num_timesteps,
      # Shared constant (previously a duplicated literal 24), so it stays in
      # sync with the event size of `initial_state_prior` above.
      num_seasons=num_seasons,
      num_steps_per_season=num_steps_per_season,
      drift_scale=drift_scale,
      observation_noise_scale=observation_noise_scale,
      initial_state_prior=initial_state_prior)

  # First check that the model's batch shape is the broadcast batch shape
  # of parameters, as expected.
  self.assertAllEqual(self.evaluate(ssm.batch_shape_tensor()), batch_shape)
  y_ = self.evaluate(ssm.sample(seed=seed))
  # The sample's leading dimensions (all but the last two) are the batch.
  self.assertAllEqual(y_.shape[:-2], batch_shape)

  # Next check that the broadcasting works as expected, and the batch log_prob
  # actually matches the log probs of independent models. The index pairs are
  # enumerated in row-major order (i outer, j inner), matching the element
  # order produced by `.flatten()` on the batched log_prob below.
  batch_indices = [(i, j)
                   for i in range(batch_shape[0])
                   for j in range(batch_shape[1])]
  # Parameters with batch shape [2] broadcast along the leading batch axis,
  # so they are indexed by `j` only.
  individual_ssms = [
      SeasonalStateSpaceModel(
          num_timesteps=num_timesteps,
          num_seasons=num_seasons,
          num_steps_per_season=num_steps_per_season,
          drift_scale=drift_scale[i, j, ...],
          observation_noise_scale=observation_noise_scale[j, ...],
          initial_state_prior=tfd.MultivariateNormalDiag(
              scale_diag=initial_state_prior.scale.diag[j, ...]))
      for i, j in batch_indices
  ]
  individual_ys = [y_[i, j, ...] for i, j in batch_indices]

  batch_lps_ = self.evaluate(ssm.log_prob(y_)).flatten()
  individual_lps_ = self.evaluate([
      individual_ssm.log_prob(individual_y)
      for individual_ssm, individual_y in zip(individual_ssms, individual_ys)
  ])
  self.assertAllClose(individual_lps_, batch_lps_)