Ejemplo n.º 1
0
 def observation_noise_fn(t):
   """Build the observation-noise distribution for timestep `t`.

   The result is the `sts_util.sum_mvns` combination of two parts. The first
   is a zero-scale MVN located at `offset_vector`, which contributes a pure
   shift to the mean. The second is every component's observation noise at
   `t`. `offset_vector` and `component_ssms` are free names resolved from
   the enclosing scope.
   """
   # Zero scale_diag: this term moves the mean without adding variance.
   constant_offset = tfd.MultivariateNormalDiag(
       loc=offset_vector,
       scale_diag=tf.zeros_like(offset_vector))
   component_noise = [component.get_observation_noise_for_timestep(t)
                      for component in component_ssms]
   return sts_util.sum_mvns([constant_offset] + component_noise)
Ejemplo n.º 2
0
  def test_sum_mvns(self):
    """Summing two diagonal MVNs should add their means and covariances."""
    # Batch shape [4, 2] followed by an event size of 3.
    full_shape = [4, 2] + [3]

    def draw_float32():
      return np.random.standard_normal(full_shape).astype(np.float32)

    def make_random_mvn():
      loc = draw_float32()
      # Exponentiate so every diagonal scale entry is strictly positive.
      scale_diag = np.exp(draw_float32())
      return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)

    mvn1 = make_random_mvn()
    mvn2 = make_random_mvn()

    sum_mvn = sts_util.sum_mvns([mvn1, mvn2])
    expected_mean = mvn1.mean() + mvn2.mean()
    expected_covariance = mvn1.covariance() + mvn2.covariance()
    self.assertAllClose(self.evaluate(sum_mvn.mean()),
                        self.evaluate(expected_mean))
    self.assertAllClose(self.evaluate(sum_mvn.covariance()),
                        self.evaluate(expected_covariance))
Ejemplo n.º 3
0
  def test_sum_mvns(self):
    """Checks that `sts_util.sum_mvns` adds component means and covariances."""
    batch_shape = [4, 2]
    param_shape = batch_shape + [3]  # three-dimensional events

    components = []
    for _ in range(2):
      loc = np.random.standard_normal(param_shape).astype(np.float32)
      log_scale = np.random.standard_normal(param_shape).astype(np.float32)
      components.append(
          tfd.MultivariateNormalDiag(loc=loc, scale_diag=np.exp(log_scale)))
    first, second = components

    summed = sts_util.sum_mvns([first, second])
    mean_of_sum = summed.mean()
    sum_of_means = first.mean() + second.mean()
    self.assertAllClose(self.evaluate(mean_of_sum),
                        self.evaluate(sum_of_means))
    covariance_of_sum = summed.covariance()
    sum_of_covariances = first.covariance() + second.covariance()
    self.assertAllClose(self.evaluate(covariance_of_sum),
                        self.evaluate(sum_of_covariances))
Ejemplo n.º 4
0
  def test_sum_mvns_broadcast_batch_shape(self):
    """`sum_mvns` should accept components whose batch shapes differ."""
    event_shape = [3]

    def random_param(batch_shape):
      shape = batch_shape + event_shape
      return np.random.standard_normal(shape).astype(np.float32)

    # (loc batch shape, scale batch shape) for each component. The shapes
    # differ within and across components but broadcast together to [3, 2].
    param_batch_shapes = [([2], [2]), ([1, 2], [3, 2]), ([3, 2], [2])]
    mvn1, mvn2, mvn3 = [
        tfd.MultivariateNormalDiag(
            loc=random_param(loc_batch),
            scale_diag=np.exp(random_param(scale_batch)))
        for loc_batch, scale_batch in param_batch_shapes]

    sum_mvn = sts_util.sum_mvns([mvn1, mvn2, mvn3])
    expected_mean = mvn1.mean() + mvn2.mean() + mvn3.mean()
    expected_covariance = (
        mvn1.covariance() + mvn2.covariance() + mvn3.covariance())
    self.assertAllClose(self.evaluate(sum_mvn.mean()),
                        self.evaluate(expected_mean))
    self.assertAllClose(self.evaluate(sum_mvn.covariance()),
                        self.evaluate(expected_covariance))
Ejemplo n.º 5
0
  def test_sum_mvns_broadcast_batch_shape(self):
    """Checks summing three MVNs with mismatched but broadcastable batches."""
    event_shape = [3]

    def standard_normal_f32(batch_shape):
      return np.random.standard_normal(
          batch_shape + event_shape).astype(np.float32)

    def diag_mvn(loc_batch_shape, scale_batch_shape):
      loc = standard_normal_f32(loc_batch_shape)
      # exp() keeps every scale entry strictly positive.
      scale_diag = np.exp(standard_normal_f32(scale_batch_shape))
      return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)

    mvn1 = diag_mvn([2], [2])
    mvn2 = diag_mvn([1, 2], [3, 2])
    mvn3 = diag_mvn([3, 2], [2])

    combined = sts_util.sum_mvns([mvn1, mvn2, mvn3])
    summed_means = mvn1.mean() + mvn2.mean() + mvn3.mean()
    summed_covariances = (mvn1.covariance() +
                          mvn2.covariance() +
                          mvn3.covariance())
    self.assertAllClose(self.evaluate(combined.mean()),
                        self.evaluate(summed_means))
    self.assertAllClose(self.evaluate(combined.covariance()),
                        self.evaluate(summed_covariances))
Ejemplo n.º 6
0
 def observation_noise_fn(t):
   """Return the `sts_util.sum_mvns` total of all components' noise at `t`.

   `component_ssms` is a free name resolved from the enclosing scope.
   """
   noise_per_component = [component.get_observation_noise_for_timestep(t)
                          for component in component_ssms]
   return sts_util.sum_mvns(noise_per_component)
Ejemplo n.º 7
0
 def observation_noise_fn(t):
   """Combine every component's observation noise at timestep `t`.

   Each entry of `component_ssms` (taken from the enclosing scope) supplies
   its noise distribution for `t`. The list is merged with
   `sts_util.sum_mvns`.
   """
   per_component_noise = []
   for component in component_ssms:
     per_component_noise.append(
         component.get_observation_noise_for_timestep(t))
   return sts_util.sum_mvns(per_component_noise)