def test_sum_of_white_noise_is_random_walk(self):
  """Integrating a white-noise SSM should reproduce a random-walk SSM."""
  num_timesteps = 20
  level_scale = 0.6
  noise_scale = 0.3

  def build_ssm(transition_coefficient):
    # The random-walk and white-noise models share all parameters except
    # the (scalar) transition matrix.
    return tfd.LinearGaussianStateSpaceModel(
        num_timesteps=num_timesteps,
        transition_matrix=[[transition_coefficient]],
        transition_noise=tfd.MultivariateNormalDiag(
            loc=[0.], scale_diag=[level_scale]),
        observation_matrix=[[1.]],
        observation_noise=tfd.MultivariateNormalDiag(
            loc=[0.], scale_diag=[noise_scale]),
        initial_state_prior=tfd.MultivariateNormalDiag(
            loc=[0.], scale_diag=[level_scale]))

  random_walk_ssm = build_ssm(1.)   # Latents accumulate over time.
  white_noise_ssm = build_ssm(0.)   # Latents are independent draws.

  # Taking the running sum of the white-noise latents should define the
  # same distribution over observations as the random walk, so samples
  # from the integrated model should score equally under both.
  cumsum_white_noise_ssm = IntegratedStateSpaceModel(white_noise_ssm)
  x, lp = cumsum_white_noise_ssm.experimental_sample_and_log_prob(
      [3], seed=test_util.test_seed())
  self.assertAllClose(lp, random_walk_ssm.log_prob(x), atol=1e-5)
def test_broadcasting_correctness(self):
  """Implicit broadcasting in AdditiveSSM matches explicit broadcasting."""
  # This test verifies that broadcasting of component parameters works as
  # expected. We construct a SSM with no batch shape, and test that when we
  # add it to another SSM of batch shape [3], we get the same model
  # as if we had explicitly broadcast the parameters of the first SSM before
  # adding.
  num_timesteps = 5
  transition_matrix = np.random.randn(2, 2)
  transition_noise_diag = np.exp(np.random.randn(2))
  observation_matrix = np.random.randn(1, 2)
  observation_noise_diag = np.exp(np.random.randn(1))
  initial_state_prior_diag = np.exp(np.random.randn(2))

  # First build the model in which we let AdditiveSSM do the broadcasting.
  batchless_ssm = tfd.LinearGaussianStateSpaceModel(
      num_timesteps=num_timesteps,
      transition_matrix=self._build_placeholder(transition_matrix),
      transition_noise=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder(transition_noise_diag)),
      observation_matrix=self._build_placeholder(observation_matrix),
      observation_noise=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder(observation_noise_diag)),
      initial_state_prior=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder(initial_state_prior_diag)))
  another_ssm = self._dummy_model(num_timesteps=num_timesteps,
                                  latent_size=4,
                                  batch_shape=[3])
  broadcast_additive_ssm = AdditiveStateSpaceModel(
      [batchless_ssm, another_ssm])

  # Next try doing our own broadcasting explicitly.
  broadcast_vector = np.ones([3, 1])
  broadcast_matrix = np.ones([3, 1, 1])
  batch_ssm = tfd.LinearGaussianStateSpaceModel(
      num_timesteps=num_timesteps,
      transition_matrix=self._build_placeholder(
          transition_matrix * broadcast_matrix),
      transition_noise=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder(
              transition_noise_diag * broadcast_vector)),
      observation_matrix=self._build_placeholder(
          observation_matrix * broadcast_matrix),
      observation_noise=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder(
              observation_noise_diag * broadcast_vector)),
      initial_state_prior=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder(
              initial_state_prior_diag * broadcast_vector)))
  manual_additive_ssm = AdditiveStateSpaceModel([batch_ssm, another_ssm])

  # Both additive SSMs define the same model, so they should give the same
  # log_probs.
  # Fixed: use `test_util.test_seed()` rather than the hardcoded `seed=42`,
  # matching the seed convention used by the other tests in this file.
  y = self.evaluate(
      broadcast_additive_ssm.sample(seed=test_util.test_seed()))
  self.assertAllEqual(self.evaluate(broadcast_additive_ssm.log_prob(y)),
                      self.evaluate(manual_additive_ssm.log_prob(y)))
def _dummy_model(self,
                 num_timesteps=5,
                 batch_shape=None,
                 initial_state_prior_batch_shape=None,
                 latent_size=2,
                 observation_size=1,
                 dtype=None):
  """Builds a simple `LinearGaussianStateSpaceModel` for use in tests.

  Args:
    num_timesteps: Python `int` number of timesteps.
    batch_shape: Optional Python `list` batch shape for the model's
      parameters. Defaults to `[]` (no batch dimensions).
    initial_state_prior_batch_shape: Optional Python `list` batch shape
      for the initial state prior; defaults to `batch_shape`.
    latent_size: Python `int` dimension of the latent state.
    observation_size: Python `int` dimension of the observations.
    dtype: Optional dtype for all parameters; defaults to `self.dtype`.

  Returns:
    A `tfd.LinearGaussianStateSpaceModel` with identity transition matrix,
    random observation matrix, and unit-scale noise.
  """
  batch_shape = batch_shape if batch_shape is not None else []
  initial_state_prior_batch_shape = (
      initial_state_prior_batch_shape
      if initial_state_prior_batch_shape is not None else batch_shape)
  dtype = dtype if dtype is not None else self.dtype

  return tfd.LinearGaussianStateSpaceModel(
      num_timesteps=num_timesteps,
      transition_matrix=self._build_placeholder(np.eye(latent_size),
                                                dtype=dtype),
      transition_noise=tfd.MultivariateNormalDiag(
          # Fixed inconsistency: this was the only parameter passed as a
          # raw numpy array (`.astype(dtype)`); route it through
          # `_build_placeholder` like every other parameter so it also
          # exercises the placeholder / dynamic-shape path.
          scale_diag=self._build_placeholder(
              np.ones(batch_shape + [latent_size]), dtype=dtype)),
      observation_matrix=self._build_placeholder(
          np.random.standard_normal(
              batch_shape + [observation_size, latent_size]),
          dtype=dtype),
      observation_noise=tfd.MultivariateNormalDiag(
          loc=self._build_placeholder(
              np.ones(batch_shape + [observation_size]), dtype=dtype),
          scale_diag=self._build_placeholder(
              np.ones(batch_shape + [observation_size]), dtype=dtype)),
      initial_state_prior=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder(
              np.ones(initial_state_prior_batch_shape + [latent_size]),
              dtype=dtype)))