def test_broadcasting_batch_shape(self):
  """Component SSMs with broadcast-compatible batch shapes combine to [3, 2].

  Three component models with batch shapes [2], [3, 2] and [1, 2] are summed
  by an AdditiveStateSpaceModel. Both the additive model's batch shape and
  the leading dimensions of a sample drawn from it must equal the broadcast
  shape [3, 2].
  """
  stateless_seed = test_util.test_seed(sampler_type='stateless')
  # Component batch shapes, built in this order; they broadcast to [3, 2].
  components = [
      self._dummy_model(batch_shape=component_batch_shape)
      for component_batch_shape in ([2], [3, 2], [1, 2])
  ]
  additive_ssm = AdditiveStateSpaceModel(component_ssms=components)
  draw = additive_ssm.sample(seed=stateless_seed)
  expected_batch_shape = [3, 2]

  if not self.use_static_shape:
    # Dynamic-shape mode: shapes are only known once the tensors are evaluated.
    dynamic_model_batch_shape = self.evaluate(
        additive_ssm.batch_shape_tensor())
    self.assertAllEqual(dynamic_model_batch_shape, expected_batch_shape)
    # Drop the two trailing sample dimensions (presumably time steps and
    # observation size -- confirm against the SSM's event shape).
    dynamic_draw_shape = self.evaluate(tf.shape(draw))
    self.assertAllEqual(dynamic_draw_shape[:-2], expected_batch_shape)
    return

  static_model_batch_shape = tensorshape_util.as_list(
      additive_ssm.batch_shape)
  self.assertAllEqual(static_model_batch_shape, expected_batch_shape)
  static_draw_shape = tensorshape_util.as_list(draw.shape)
  self.assertAllEqual(static_draw_shape[:-2], expected_batch_shape)
def test_batch_shape(self):
  """Checks that an additive model of identical components keeps their batch shape.

  Two copies of a component SSM with batch shape [3, 2] are summed. Both the
  AdditiveStateSpaceModel's batch shape and the leading dimensions of a sample
  drawn from it must equal [3, 2].
  """
  # NOTE(review): a method named `test_batch_shape` is defined more than once
  # in this class. Python keeps only the last definition of a name in a class
  # body, so the earlier one is dead code. This body therefore mirrors the
  # seeded variant (explicit stateless seed, `tensorshape_util`): whichever
  # definition ends up active, the behavior is the same. Consider deleting
  # the duplicate.
  batch_shape = [3, 2]
  # An explicit stateless seed makes the sample deterministic. It is presumably
  # also needed for backends without implicit stateful sampling (e.g. JAX);
  # confirm.
  seed = test_util.test_seed(sampler_type='stateless')
  ssm = self._dummy_model(batch_shape=batch_shape)
  additive_ssm = AdditiveStateSpaceModel([ssm, ssm])
  y = additive_ssm.sample(seed=seed)
  if self.use_static_shape:
    # `tensorshape_util.as_list` is used instead of `TensorShape.as_list()`
    # to match the other shape assertions in this test class.
    self.assertAllEqual(
        tensorshape_util.as_list(additive_ssm.batch_shape), batch_shape)
    # The last two sample dimensions are dropped (presumably time steps and
    # observation size; confirm).
    self.assertAllEqual(
        tensorshape_util.as_list(y.shape)[:-2], batch_shape)
  else:
    self.assertAllEqual(
        self.evaluate(additive_ssm.batch_shape_tensor()), batch_shape)
    self.assertAllEqual(self.evaluate(tf.shape(y))[:-2], batch_shape)
def test_batch_shape(self):
  """An additive model of two identical [3, 2]-batched components is [3, 2].

  The same component SSM is passed twice to AdditiveStateSpaceModel. The sum
  must report batch shape [3, 2], and a seeded sample must carry that batch
  shape ahead of its two trailing dimensions.
  """
  target_shape = [3, 2]
  stateless_seed = test_util.test_seed(sampler_type='stateless')
  component = self._dummy_model(batch_shape=target_shape)
  summed_model = AdditiveStateSpaceModel([component, component])
  draw = summed_model.sample(seed=stateless_seed)

  if not self.use_static_shape:
    # Dynamic-shape mode: evaluate the runtime shape tensors.
    runtime_batch_shape = self.evaluate(summed_model.batch_shape_tensor())
    self.assertAllEqual(runtime_batch_shape, target_shape)
    runtime_draw_shape = self.evaluate(tf.shape(draw))
    self.assertAllEqual(runtime_draw_shape[:-2], target_shape)
    return

  # Static-shape mode: the shapes are available at graph-construction time.
  self.assertAllEqual(
      tensorshape_util.as_list(summed_model.batch_shape), target_shape)
  self.assertAllEqual(
      tensorshape_util.as_list(draw.shape)[:-2], target_shape)
def test_broadcasting_batch_shape(self):
  """Checks that component batch shapes broadcast in the additive model.

  Components with batch shapes [2], [3, 2] and [1, 2] are combined into an
  AdditiveStateSpaceModel. Both its batch shape and the leading dimensions of
  a sample drawn from it must equal the broadcast shape [3, 2].
  """
  # NOTE(review): a method named `test_broadcasting_batch_shape` is defined
  # more than once in this class. Python keeps only the last definition of a
  # name in a class body. In the visible source this definition comes after
  # the seeded one and so shadows it. This body therefore uses the same
  # explicit stateless seed and `tensorshape_util` helpers, so the seeded
  # behavior is what actually runs. Consider deleting the duplicate.
  seed = test_util.test_seed(sampler_type='stateless')
  # Build three SSMs with broadcast batch shape.
  ssm1 = self._dummy_model(batch_shape=[2])
  ssm2 = self._dummy_model(batch_shape=[3, 2])
  ssm3 = self._dummy_model(batch_shape=[1, 2])
  additive_ssm = AdditiveStateSpaceModel(
      component_ssms=[ssm1, ssm2, ssm3])
  # An explicit stateless seed makes the sample deterministic. It is presumably
  # also needed for backends without implicit stateful sampling (e.g. JAX);
  # confirm.
  y = additive_ssm.sample(seed=seed)
  broadcast_batch_shape = [3, 2]
  if self.use_static_shape:
    self.assertAllEqual(
        tensorshape_util.as_list(additive_ssm.batch_shape),
        broadcast_batch_shape)
    # The last two sample dimensions are dropped (presumably time steps and
    # observation size; confirm).
    self.assertAllEqual(
        tensorshape_util.as_list(y.shape)[:-2], broadcast_batch_shape)
  else:
    self.assertAllEqual(
        self.evaluate(additive_ssm.batch_shape_tensor()),
        broadcast_batch_shape)
    self.assertAllEqual(
        self.evaluate(tf.shape(y))[:-2], broadcast_batch_shape)
def test_batch_shape_ignores_component_state_priors(self):
  """An explicit initial_state_prior overrides component prior batch shapes.

  The component model here is unbatched except for its initial state prior,
  which has batch shape [5, 5]. When an unbatched initial_state_prior is passed
  directly to AdditiveStateSpaceModel, the overridden component prior must no
  longer affect the additive model's batch shape, which should be [].

  The original comments give the motivation: in forecasting, component prior
  shapes may have gained dimensions for posterior draws, and those dimensions
  must not leak into the overall batch shape once the prior is overridden.
  """
  state_dim = 2
  # Only the component's initial state prior carries batch dimensions.
  component = self._dummy_model(
      latent_size=state_dim,
      batch_shape=[],
      initial_state_prior_batch_shape=[5, 5])

  # Replacement prior with no batch dimensions.
  unit_scale = self._build_placeholder(np.ones([state_dim]))
  override_prior = tfd.MultivariateNormalDiag(scale_diag=unit_scale)

  model = AdditiveStateSpaceModel(
      [component], initial_state_prior=override_prior)
  observed_batch_shape = self.evaluate(model.batch_shape_tensor())
  self.assertAllEqual(observed_batch_shape, [])