Example No. 1
    def test_broadcasting_batch_shape(self):
        seed = test_util.test_seed(sampler_type='stateless')

        # Build three SSMs with broadcast batch shape.
        ssm1 = self._dummy_model(batch_shape=[2])
        ssm2 = self._dummy_model(batch_shape=[3, 2])
        ssm3 = self._dummy_model(batch_shape=[1, 2])

        additive_ssm = AdditiveStateSpaceModel(
            component_ssms=[ssm1, ssm2, ssm3])
        y = additive_ssm.sample(seed=seed)

        broadcast_batch_shape = [3, 2]
        if self.use_static_shape:
            self.assertAllEqual(
                tensorshape_util.as_list(additive_ssm.batch_shape),
                broadcast_batch_shape)
            self.assertAllEqual(
                tensorshape_util.as_list(y.shape)[:-2], broadcast_batch_shape)
        else:
            self.assertAllEqual(
                self.evaluate(additive_ssm.batch_shape_tensor()),
                broadcast_batch_shape)
            self.assertAllEqual(
                self.evaluate(tf.shape(y))[:-2], broadcast_batch_shape)
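
These snippets are method bodies excerpted from a larger test class, so they rely on imports and helpers defined elsewhere in that file. Below is a minimal sketch of the context they appear to assume; the module paths and aliases are assumptions based on the TensorFlow Probability codebase, not part of the examples themselves.

# Sketch of the imports these test snippets appear to assume (the exact
# paths are an assumption; adjust to the actual test module's layout).
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.internal import test_util

tfd = tfp.distributions
AdditiveStateSpaceModel = tfp.sts.AdditiveStateSpaceModel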
Example No. 2
  def test_batch_shape(self):
    batch_shape = [3, 2]

    ssm = self._dummy_model(batch_shape=batch_shape)
    additive_ssm = AdditiveStateSpaceModel([ssm, ssm])
    y = additive_ssm.sample()

    if self.use_static_shape:
      self.assertAllEqual(additive_ssm.batch_shape.as_list(), batch_shape)
      self.assertAllEqual(y.shape.as_list()[:-2], batch_shape)
    else:
      self.assertAllEqual(self.evaluate(additive_ssm.batch_shape_tensor()),
                          batch_shape)
      self.assertAllEqual(self.evaluate(tf.shape(y))[:-2], batch_shape)
Example No. 3
    def test_batch_shape(self):
        batch_shape = [3, 2]
        seed = test_util.test_seed(sampler_type='stateless')

        ssm = self._dummy_model(batch_shape=batch_shape)
        additive_ssm = AdditiveStateSpaceModel([ssm, ssm])
        y = additive_ssm.sample(seed=seed)

        if self.use_static_shape:
            self.assertAllEqual(
                tensorshape_util.as_list(additive_ssm.batch_shape),
                batch_shape)
            self.assertAllEqual(
                tensorshape_util.as_list(y.shape)[:-2], batch_shape)
        else:
            self.assertAllEqual(
                self.evaluate(additive_ssm.batch_shape_tensor()), batch_shape)
            self.assertAllEqual(self.evaluate(tf.shape(y))[:-2], batch_shape)
Example No. 4
    def test_broadcasting_batch_shape(self):

        # Build three SSMs with broadcast batch shape.
        ssm1 = self._dummy_model(batch_shape=[2])
        ssm2 = self._dummy_model(batch_shape=[3, 2])
        ssm3 = self._dummy_model(batch_shape=[1, 2])

        additive_ssm = AdditiveStateSpaceModel(
            component_ssms=[ssm1, ssm2, ssm3])
        y = additive_ssm.sample()

        broadcast_batch_shape = [3, 2]
        if self.use_static_shape:
            self.assertAllEqual(additive_ssm.batch_shape.as_list(),
                                broadcast_batch_shape)
            self.assertAllEqual(y.shape.as_list()[:-2], broadcast_batch_shape)
        else:
            self.assertAllEqual(
                self.evaluate(additive_ssm.batch_shape_tensor()),
                broadcast_batch_shape)
            self.assertAllEqual(
                self.evaluate(tf.shape(input=y))[:-2], broadcast_batch_shape)
Example No. 5
    def test_batch_shape_ignores_component_state_priors(self):
        # If we pass an initial_state_prior directly to an AdditiveSSM, overriding
        # the initial state priors of component models, the overall batch shape
        # should no longer depend on the (overridden) component priors.
        # This ensures that we produce correct shapes in forecasting, where the
        # shapes may have changed to include dimensions corresponding to posterior
        # draws.

        # Create a component model with no batch shape *except* in the initial state
        # prior.
        latent_size = 2
        ssm = self._dummy_model(latent_size=latent_size,
                                batch_shape=[],
                                initial_state_prior_batch_shape=[5, 5])

        # If we override the initial state prior with an unbatched prior, the
        # resulting AdditiveSSM should not have batch dimensions.
        unbatched_initial_state_prior = tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder(np.ones([latent_size])))
        additive_ssm = AdditiveStateSpaceModel(
            [ssm], initial_state_prior=unbatched_initial_state_prior)

        self.assertAllEqual(self.evaluate(additive_ssm.batch_shape_tensor()),
                            [])
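
Every example above calls a self._dummy_model fixture that is not shown. The following is a hypothetical stand-in, assuming the fixture builds a tfd.LinearGaussianStateSpaceModel whose parameters carry the requested batch_shape and whose initial state prior can be batched separately (as Example No. 5 requires); the real helper also routes tensors through self._build_placeholder to exercise the use_static_shape toggle, which this sketch omits.

    # Hypothetical _dummy_model fixture (an assumption, not the original
    # helper): builds a LinearGaussianStateSpaceModel whose parameter batch
    # shapes match the requested batch_shape, with an optionally separately
    # batched initial state prior.
    def _dummy_model(self,
                     batch_shape,
                     latent_size=2,
                     num_timesteps=5,
                     initial_state_prior_batch_shape=None):
        if initial_state_prior_batch_shape is None:
            initial_state_prior_batch_shape = batch_shape
        return tfd.LinearGaussianStateSpaceModel(
            num_timesteps=num_timesteps,
            transition_matrix=tf.eye(latent_size),
            transition_noise=tfd.MultivariateNormalDiag(
                scale_diag=tf.ones(list(batch_shape) + [latent_size])),
            observation_matrix=tf.ones([1, latent_size]),
            observation_noise=tfd.MultivariateNormalDiag(
                scale_diag=tf.ones(list(batch_shape) + [1])),
            initial_state_prior=tfd.MultivariateNormalDiag(
                scale_diag=tf.ones(
                    list(initial_state_prior_batch_shape) + [latent_size])))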