def test_multivariate_observations(self):
  """Checks that an additive SSM accepts multivariate-output components.

  STS components are scalar by design, so a multivariate-output model is
  constructed by hand and added to itself. The test then checks the event
  shape of the sum and the trailing two dimensions of a sample.
  """
  seed = test_util.test_seed(sampler_type='stateless')
  num_timesteps = 5
  observation_size = 2
  multivariate_ssm = self._dummy_model(num_timesteps=num_timesteps,
                                       observation_size=observation_size)

  # Note it would not work to specify observation_noise_scale here;
  # multivariate observations need to derive the (multivariate)
  # observation noise distribution from their components.
  combined_ssm = AdditiveStateSpaceModel(
      [multivariate_ssm, multivariate_ssm])

  # Sample with an explicit test seed, as the seeded batch-shape tests in
  # this file do, rather than relying on implicit global random state.
  y = combined_ssm.sample(seed=seed)
  expected_event_shape = [num_timesteps, observation_size]

  if self.use_static_shape:
    # Go through `tensorshape_util` instead of calling `.as_list()` on the
    # shape object directly, consistent with the other shape tests here.
    self.assertAllEqual(
        tensorshape_util.as_list(combined_ssm.event_shape),
        expected_event_shape)
    self.assertAllEqual(
        tensorshape_util.as_list(y.shape)[-2:], expected_event_shape)
  else:
    # Shapes are only known at run time, so evaluate the shape tensors.
    self.assertAllEqual(
        self.evaluate(combined_ssm.event_shape_tensor()),
        expected_event_shape)
    self.assertAllEqual(
        self.evaluate(tf.shape(y))[-2:], expected_event_shape)
def test_nesting_additive_ssms(self):
  """Checks that nesting additive SSMs leaves the model unchanged.

  A flat sum of two components is compared against a sum of two
  single-component additive SSMs wrapping the same components. Both models
  must agree on `log_prob` of a common draw, on `mean` and on `variance`.
  """
  noise_scale = 0.1
  first = self._dummy_model(batch_shape=[1, 2])
  second = self._dummy_model(batch_shape=[3, 2])

  flat_model = AdditiveStateSpaceModel(
      [first, second], observation_noise_scale=noise_scale)
  nested_model = AdditiveStateSpaceModel(
      [AdditiveStateSpaceModel([first]), AdditiveStateSpaceModel([second])],
      observation_noise_scale=noise_scale)

  # Draw once from the nested model; the same draw is scored by both models.
  draw = self.evaluate(nested_model.sample())

  # Each entry maps a model to the quantity being compared.
  quantities = (
      lambda model: model.log_prob(draw),
      lambda model: model.mean(),
      lambda model: model.variance(),
  )
  for quantity in quantities:
    flat_value = quantity(flat_model)
    nested_value = quantity(nested_model)
    self.assertAllClose(self.evaluate(flat_value),
                        self.evaluate(nested_value))
def test_broadcasting_batch_shape(self):
  """Checks that component batch shapes broadcast in the additive SSM.

  Components with batch shapes [2], [3, 2] and [1, 2] are summed; the
  resulting model and its samples are expected to have batch shape [3, 2].
  """
  # NOTE(review): another `test_broadcasting_batch_shape` is defined further
  # down in this file. If both live in the same class, the later definition
  # replaces this one -- confirm which should be kept.
  sample_seed = test_util.test_seed(sampler_type='stateless')
  expected_batch_shape = [3, 2]

  # Three components whose batch shapes are mutually broadcastable.
  components = [
      self._dummy_model(batch_shape=[2]),
      self._dummy_model(batch_shape=[3, 2]),
      self._dummy_model(batch_shape=[1, 2]),
  ]
  model = AdditiveStateSpaceModel(component_ssms=components)
  draw = model.sample(seed=sample_seed)

  if self.use_static_shape:
    model_batch_shape = tensorshape_util.as_list(model.batch_shape)
    self.assertAllEqual(model_batch_shape, expected_batch_shape)
    # The last two sample dimensions are dropped; the rest is batch shape.
    draw_batch_shape = tensorshape_util.as_list(draw.shape)[:-2]
    self.assertAllEqual(draw_batch_shape, expected_batch_shape)
  else:
    model_batch_shape = self.evaluate(model.batch_shape_tensor())
    self.assertAllEqual(model_batch_shape, expected_batch_shape)
    draw_batch_shape = self.evaluate(tf.shape(draw))[:-2]
    self.assertAllEqual(draw_batch_shape, expected_batch_shape)
def test_broadcasting_correctness(self):
  """Checks that implicit parameter broadcasting matches explicit tiling.

  An SSM with no batch shape is added to another SSM of batch shape [3].
  The result must give the same log_probs as a model in which the first
  SSM's parameters were explicitly broadcast to batch shape [3] before
  adding.
  """
  num_timesteps = 5

  # Random parameters for the batchless component. The draw order is kept
  # fixed so the same sequence of random numbers is consumed.
  transition_matrix = np.random.randn(2, 2)
  transition_noise_diag = np.exp(np.random.randn(2))
  observation_matrix = np.random.randn(1, 2)
  observation_noise_diag = np.exp(np.random.randn(1))
  initial_state_prior_diag = np.exp(np.random.randn(2))

  def diag_normal(scale_diag):
    """Builds a diagonal MVN whose scale is fed through a placeholder."""
    return tfd.MultivariateNormalDiag(
        scale_diag=self._build_placeholder(scale_diag))

  def make_lgssm(trans_mat, trans_noise, obs_mat, obs_noise, prior_diag):
    """Builds a linear Gaussian SSM from raw parameter arrays."""
    return tfd.LinearGaussianStateSpaceModel(
        num_timesteps=num_timesteps,
        transition_matrix=self._build_placeholder(trans_mat),
        transition_noise=diag_normal(trans_noise),
        observation_matrix=self._build_placeholder(obs_mat),
        observation_noise=diag_normal(obs_noise),
        initial_state_prior=diag_normal(prior_diag))

  # First build the model in which AdditiveSSM does the broadcasting.
  batchless_ssm = make_lgssm(transition_matrix,
                             transition_noise_diag,
                             observation_matrix,
                             observation_noise_diag,
                             initial_state_prior_diag)
  another_ssm = self._dummy_model(num_timesteps=num_timesteps,
                                  latent_size=4,
                                  batch_shape=[3])
  broadcast_additive_ssm = AdditiveStateSpaceModel(
      [batchless_ssm, another_ssm])

  # Next do the broadcasting by hand: multiplying by arrays of ones adds a
  # leading batch dimension of size 3 without changing any value.
  ones_for_vectors = np.ones([3, 1])
  ones_for_matrices = np.ones([3, 1, 1])
  batch_ssm = make_lgssm(transition_matrix * ones_for_matrices,
                         transition_noise_diag * ones_for_vectors,
                         observation_matrix * ones_for_matrices,
                         observation_noise_diag * ones_for_vectors,
                         initial_state_prior_diag * ones_for_vectors)
  manual_additive_ssm = AdditiveStateSpaceModel([batch_ssm, another_ssm])

  # Both additive SSMs define the same model, so they should give the same
  # log_probs for a common draw.
  draw = self.evaluate(broadcast_additive_ssm.sample(seed=42))
  implicit_lp = self.evaluate(broadcast_additive_ssm.log_prob(draw))
  explicit_lp = self.evaluate(manual_additive_ssm.log_prob(draw))
  self.assertAllEqual(implicit_lp, explicit_lp)
def test_mismatched_observation_size_error(self):
  """Checks that components with different observation sizes are rejected."""
  scalar_obs_ssm = self._dummy_model(observation_size=1)
  pair_obs_ssm = self._dummy_model(observation_size=2)
  mismatched = [scalar_obs_ssm, pair_obs_ssm]

  # The broad `Exception` type and empty message pattern are used because
  # the failure can come from either of the two statements below.
  with self.assertRaisesWithPredicateMatch(Exception, ''):
    # In the static case, the constructor should raise an exception.
    model = AdditiveStateSpaceModel(component_ssms=mismatched)
    # In the dynamic case, the exception is raised at runtime.
    self.evaluate(model.sample())
def test_mismatched_num_timesteps_error(self):
  """Checks that components with different `num_timesteps` are rejected.

  A `ValueError` whose message mentions 'same number of timesteps' is
  expected, either from the constructor or when a sample is evaluated.
  """
  longer_ssm = self._dummy_model(num_timesteps=10)
  shorter_ssm = self._dummy_model(num_timesteps=8)
  mismatched = [longer_ssm, shorter_ssm]

  with self.assertRaisesWithPredicateMatch(
      ValueError, 'same number of timesteps'):
    # In the static case, the constructor should raise an exception.
    model = AdditiveStateSpaceModel(component_ssms=mismatched)
    # In the dynamic case, the exception is raised at runtime.
    self.evaluate(model.sample())
def test_batch_shape(self):
  """Checks that the additive SSM and its samples keep the batch shape.

  A component with batch shape [3, 2] is added to itself; the sum and a
  sample drawn from it are both expected to report batch shape [3, 2].
  """
  # NOTE(review): another `test_batch_shape` is defined further down in this
  # file. If both live in the same class, the later definition replaces this
  # one -- confirm which should be kept.
  batch_shape = [3, 2]
  seed = test_util.test_seed(sampler_type='stateless')
  ssm = self._dummy_model(batch_shape=batch_shape)
  additive_ssm = AdditiveStateSpaceModel([ssm, ssm])

  # Sample with an explicit test seed, as the other seeded shape tests in
  # this file do, rather than relying on implicit global random state.
  y = additive_ssm.sample(seed=seed)

  if self.use_static_shape:
    # Go through `tensorshape_util` instead of calling `.as_list()` on the
    # shape object directly, consistent with the other shape tests here.
    self.assertAllEqual(
        tensorshape_util.as_list(additive_ssm.batch_shape), batch_shape)
    # The last two sample dimensions are dropped; the rest is batch shape.
    self.assertAllEqual(
        tensorshape_util.as_list(y.shape)[:-2], batch_shape)
  else:
    # Shapes are only known at run time, so evaluate the shape tensors.
    self.assertAllEqual(
        self.evaluate(additive_ssm.batch_shape_tensor()), batch_shape)
    self.assertAllEqual(self.evaluate(tf.shape(y))[:-2], batch_shape)
def test_batch_shape(self):
  """Checks that the additive SSM and its samples keep the batch shape.

  A component with batch shape [3, 2] is added to itself; the sum and a
  seeded sample drawn from it are both expected to report that batch shape.
  """
  # NOTE(review): an earlier `test_batch_shape` is also defined in this
  # file. If both live in the same class, this later definition is the one
  # that takes effect -- confirm the earlier one can be removed.
  expected_batch_shape = [3, 2]
  sample_seed = test_util.test_seed(sampler_type='stateless')
  component = self._dummy_model(batch_shape=expected_batch_shape)
  model = AdditiveStateSpaceModel([component, component])
  draw = model.sample(seed=sample_seed)

  if not self.use_static_shape:
    # Shapes are only known at run time, so evaluate the shape tensors.
    dynamic_model_batch = self.evaluate(model.batch_shape_tensor())
    self.assertAllEqual(dynamic_model_batch, expected_batch_shape)
    dynamic_draw_batch = self.evaluate(tf.shape(draw))[:-2]
    self.assertAllEqual(dynamic_draw_batch, expected_batch_shape)
    return

  static_model_batch = tensorshape_util.as_list(model.batch_shape)
  self.assertAllEqual(static_model_batch, expected_batch_shape)
  # The last two sample dimensions are dropped; the rest is batch shape.
  static_draw_batch = tensorshape_util.as_list(draw.shape)[:-2]
  self.assertAllEqual(static_draw_batch, expected_batch_shape)
def test_broadcasting_batch_shape(self):
  """Checks that component batch shapes broadcast in the additive SSM.

  Components with batch shapes [2], [3, 2] and [1, 2] are summed; the
  resulting model and its samples are expected to have batch shape [3, 2].
  """
  # NOTE(review): an earlier `test_broadcasting_batch_shape` is also defined
  # in this file. If both live in the same class, this later definition is
  # the one that takes effect -- confirm the duplicate can be removed.
  seed = test_util.test_seed(sampler_type='stateless')

  # Build three SSMs with broadcast batch shape.
  ssm1 = self._dummy_model(batch_shape=[2])
  ssm2 = self._dummy_model(batch_shape=[3, 2])
  ssm3 = self._dummy_model(batch_shape=[1, 2])

  additive_ssm = AdditiveStateSpaceModel(
      component_ssms=[ssm1, ssm2, ssm3])

  # Sample with an explicit test seed, as the other seeded shape tests in
  # this file do, rather than relying on implicit global random state.
  y = additive_ssm.sample(seed=seed)

  broadcast_batch_shape = [3, 2]
  if self.use_static_shape:
    # Go through `tensorshape_util` instead of calling `.as_list()` on the
    # shape object directly, consistent with the other shape tests here.
    self.assertAllEqual(
        tensorshape_util.as_list(additive_ssm.batch_shape),
        broadcast_batch_shape)
    # The last two sample dimensions are dropped; the rest is batch shape.
    self.assertAllEqual(
        tensorshape_util.as_list(y.shape)[:-2], broadcast_batch_shape)
  else:
    # Shapes are only known at run time, so evaluate the shape tensors.
    self.assertAllEqual(
        self.evaluate(additive_ssm.batch_shape_tensor()),
        broadcast_batch_shape)
    self.assertAllEqual(
        self.evaluate(tf.shape(y))[:-2], broadcast_batch_shape)