コード例 #1
0
    def test_sum_of_white_noise_is_random_walk(self):
        """Integrated white-noise SSM should match a random-walk SSM.

        Both models use the same noise scales and a directly observed scalar
        latent state; they differ only in the transition matrix: `[[1.]]`
        for the random walk and `[[0.]]` for white noise. Samples drawn from
        `IntegratedStateSpaceModel(white_noise_ssm)` are scored under the
        random-walk model, and the log-probabilities must agree to 1e-5.
        """
        steps = 20
        level_stddev = 0.6
        obs_stddev = 0.3

        def make_ssm(transition):
            # Only the transition matrix varies between the two models.
            return tfd.LinearGaussianStateSpaceModel(
                num_timesteps=steps,
                transition_matrix=transition,
                transition_noise=tfd.MultivariateNormalDiag(
                    loc=[0.], scale_diag=[level_stddev]),
                observation_matrix=[[1.]],
                observation_noise=tfd.MultivariateNormalDiag(
                    loc=[0.], scale_diag=[obs_stddev]),
                initial_state_prior=tfd.MultivariateNormalDiag(
                    loc=[0.], scale_diag=[level_stddev]))

        reference = make_ssm([[1.]])
        integrated = IntegratedStateSpaceModel(make_ssm([[0.]]))
        samples, sample_lp = integrated.experimental_sample_and_log_prob(
            [3], seed=test_util.test_seed())
        self.assertAllClose(sample_lp, reference.log_prob(samples), atol=1e-5)
コード例 #2
0
ファイル: sum_test.py プロジェクト: zhouyonglong/probability
  def test_broadcasting_correctness(self):
    """Checks that AdditiveStateSpaceModel broadcasts component parameters.

    This test verifies that broadcasting of component parameters works as
    expected. We construct a SSM with no batch shape, and test that when we
    add it to another SSM of batch shape [3], we get the same model as if we
    had explicitly broadcast the parameters of the first SSM before adding.
    """
    num_timesteps = 5
    transition_matrix = np.random.randn(2, 2)
    transition_noise_diag = np.exp(np.random.randn(2))
    observation_matrix = np.random.randn(1, 2)
    observation_noise_diag = np.exp(np.random.randn(1))
    initial_state_prior_diag = np.exp(np.random.randn(2))

    # First build the model in which we let AdditiveSSM do the broadcasting.
    batchless_ssm = tfd.LinearGaussianStateSpaceModel(
        num_timesteps=num_timesteps,
        transition_matrix=self._build_placeholder(transition_matrix),
        transition_noise=tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder(transition_noise_diag)),
        observation_matrix=self._build_placeholder(observation_matrix),
        observation_noise=tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder(observation_noise_diag)),
        initial_state_prior=tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder(initial_state_prior_diag))
    )
    another_ssm = self._dummy_model(num_timesteps=num_timesteps,
                                    latent_size=4,
                                    batch_shape=[3])
    broadcast_additive_ssm = AdditiveStateSpaceModel(
        [batchless_ssm, another_ssm])

    # Next try doing our own broadcasting explicitly: multiplying by ones of
    # shape [3, 1] (vectors) or [3, 1, 1] (matrices) gives each parameter a
    # leading batch dimension of size 3.
    broadcast_vector = np.ones([3, 1])
    broadcast_matrix = np.ones([3, 1, 1])
    batch_ssm = tfd.LinearGaussianStateSpaceModel(
        num_timesteps=num_timesteps,
        transition_matrix=self._build_placeholder(
            transition_matrix * broadcast_matrix),
        transition_noise=tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder(
                transition_noise_diag * broadcast_vector)),
        observation_matrix=self._build_placeholder(
            observation_matrix * broadcast_matrix),
        observation_noise=tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder(
                observation_noise_diag * broadcast_vector)),
        initial_state_prior=tfd.MultivariateNormalDiag(
            scale_diag=self._build_placeholder(
                initial_state_prior_diag * broadcast_vector)))
    manual_additive_ssm = AdditiveStateSpaceModel([batch_ssm, another_ssm])

    # Both additive SSMs define the same model, so they should give the same
    # log_probs. The two are computed through different broadcasting paths,
    # so compare within floating-point tolerance rather than bit-for-bit.
    y = self.evaluate(broadcast_additive_ssm.sample(seed=42))
    self.assertAllClose(self.evaluate(broadcast_additive_ssm.log_prob(y)),
                        self.evaluate(manual_additive_ssm.log_prob(y)))
コード例 #3
0
    def _dummy_model(self,
                     num_timesteps=5,
                     batch_shape=None,
                     initial_state_prior_batch_shape=None,
                     latent_size=2,
                     observation_size=1,
                     dtype=None):
        """Build a simple `tfd.LinearGaussianStateSpaceModel` for tests.

        The model has an identity transition matrix, transition noise with
        unit `scale_diag`, an observation matrix drawn with
        `np.random.standard_normal`, observation noise with unit `loc` and
        unit `scale_diag`, and an initial state prior with unit `scale_diag`
        (no `loc` given).

        Args:
          num_timesteps: Number of timesteps of the model.
          batch_shape: Batch shape (list or tuple of ints) used for the
            transition noise, observation matrix and observation noise.
            Defaults to `[]`.
          initial_state_prior_batch_shape: Batch shape (list or tuple of ints)
            of the initial state prior. Defaults to `batch_shape`.
          latent_size: Dimension of the latent state.
          observation_size: Dimension of each observation.
          dtype: Dtype of the parameters. Defaults to `self.dtype`.

        Returns:
          A `tfd.LinearGaussianStateSpaceModel`.
        """
        # Normalize shapes to lists so that any sequence (e.g. a tuple) can be
        # concatenated with the trailing event dimensions below.
        batch_shape = list(batch_shape) if batch_shape is not None else []
        initial_state_prior_batch_shape = (
            list(initial_state_prior_batch_shape)
            if initial_state_prior_batch_shape is not None else batch_shape)
        dtype = dtype if dtype is not None else self.dtype

        observation_shape = batch_shape + [observation_size]
        return tfd.LinearGaussianStateSpaceModel(
            num_timesteps=num_timesteps,
            transition_matrix=self._build_placeholder(np.eye(latent_size),
                                                      dtype=dtype),
            # Unlike the other parameters, this scale is passed as a concrete
            # numpy array rather than through `_build_placeholder`.
            transition_noise=tfd.MultivariateNormalDiag(
                scale_diag=np.ones(batch_shape + [latent_size]).astype(dtype)),
            observation_matrix=self._build_placeholder(
                np.random.standard_normal(batch_shape +
                                          [observation_size, latent_size]),
                dtype=dtype),
            observation_noise=tfd.MultivariateNormalDiag(
                loc=self._build_placeholder(np.ones(observation_shape),
                                            dtype=dtype),
                scale_diag=self._build_placeholder(np.ones(observation_shape),
                                                   dtype=dtype)),
            initial_state_prior=tfd.MultivariateNormalDiag(
                scale_diag=self._build_placeholder(
                    np.ones(initial_state_prior_batch_shape + [latent_size]),
                    dtype=dtype)))