def observation_noise_fn(t):
    """Observation-noise distribution at timestep `t`, shifted by a constant offset.

    Builds a zero-scale `MultivariateNormalDiag` located at `offset_vector`,
    then sums it with each component's observation noise for timestep `t`
    via `sts_util.sum_mvns` (which adds means and adds covariances). Because
    the offset term has zero scale, it shifts the mean by `offset_vector` and
    contributes nothing to the covariance.

    `offset_vector` and `component_ssms` are free variables resolved from the
    surrounding scope, which is not visible here.

    Args:
      t: timestep passed through to each component's
        `get_observation_noise_for_timestep`.

    Returns:
      The distribution returned by `sts_util.sum_mvns`.
    """
    # Degenerate (zero-variance) MVN: used purely to add a constant to the mean.
    offset_mvn = tfd.MultivariateNormalDiag(
        loc=offset_vector,
        scale_diag=tf.zeros_like(offset_vector))
    component_noises = [
        ssm.get_observation_noise_for_timestep(t) for ssm in component_ssms
    ]
    return sts_util.sum_mvns([offset_mvn] + component_noises)
def test_sum_mvns(self):
    """Summing two diagonal MVNs yields summed means and summed covariances."""

    def draw(shape):
        # Standard-normal samples of the given shape, cast to float32.
        return np.random.standard_normal(shape).astype(np.float32)

    batch_shape = [4, 2]
    # Batch dimensions followed by a trailing dimension of size 3.
    full_shape = batch_shape + [3]

    # Draw order (loc before scale, first MVN before second) matters because
    # the samples come from numpy's global random state.
    first = tfd.MultivariateNormalDiag(
        loc=draw(full_shape),
        scale_diag=np.exp(draw(full_shape)))
    second = tfd.MultivariateNormalDiag(
        loc=draw(full_shape),
        scale_diag=np.exp(draw(full_shape)))

    summed = sts_util.sum_mvns([first, second])

    self.assertAllClose(
        self.evaluate(summed.mean()),
        self.evaluate(first.mean() + second.mean()))
    self.assertAllClose(
        self.evaluate(summed.covariance()),
        self.evaluate(first.covariance() + second.covariance()))
def test_sum_mvns_broadcast_batch_shape(self):
    """sum_mvns agrees with elementwise sums when input batch shapes differ.

    The three MVNs use batch shapes [2], [1, 2] / [3, 2] and [3, 2] / [2]
    for (loc, scale_diag), so the sums only make sense after broadcasting.
    """

    def draw(shape):
        # Standard-normal samples of the given shape, cast to float32.
        return np.random.standard_normal(shape).astype(np.float32)

    event_shape = [3]

    def make_mvn(loc_batch_shape, scale_batch_shape):
        # loc is drawn before scale_diag so the sequence of draws from the
        # global numpy random state is unchanged.
        return tfd.MultivariateNormalDiag(
            loc=draw(loc_batch_shape + event_shape),
            scale_diag=np.exp(draw(scale_batch_shape + event_shape)))

    mvn1 = make_mvn([2], [2])
    mvn2 = make_mvn([1, 2], [3, 2])
    mvn3 = make_mvn([3, 2], [2])

    summed = sts_util.sum_mvns([mvn1, mvn2, mvn3])

    self.assertAllClose(
        self.evaluate(summed.mean()),
        self.evaluate(mvn1.mean() + mvn2.mean() + mvn3.mean()))
    self.assertAllClose(
        self.evaluate(summed.covariance()),
        self.evaluate(
            mvn1.covariance() + mvn2.covariance() + mvn3.covariance()))
def observation_noise_fn(t):
    """Observation-noise distribution at timestep `t`, summed over components.

    Collects each component's `get_observation_noise_for_timestep(t)` and
    combines them with `sts_util.sum_mvns` (which adds means and adds
    covariances). `component_ssms` is a free variable resolved from the
    surrounding scope, which is not visible here.

    Args:
      t: timestep passed through to each component.

    Returns:
      The distribution returned by `sts_util.sum_mvns`.
    """
    noises = []
    for component in component_ssms:
        noises.append(component.get_observation_noise_for_timestep(t))
    return sts_util.sum_mvns(noises)