def test_constant_offset(self):
    """A constant offset shifts the mean but leaves the stddev unchanged.

    Builds three `AdditiveStateSpaceModel`s around the same dummy component:
    one without an offset, one with `constant_offset`, and one with both
    `constant_offset` and an explicit `observation_noise_scale`. The explicit
    scale is taken from the component's first-timestep observation noise.

    It then checks two things:
      * both offset models' means equal the plain model's mean plus the
        offset;
      * all three models report the same stddev.

    NOTE(review): this file also appears to define a parameterized
    `test_constant_offset(self, is_scalar=True)`. If both live in the same
    class, whichever is defined later shadows the other. Confirm which one is
    intended to run.
    """
    offset_ = 1.23456
    offset = self._build_placeholder(offset_)

    ssm = self._dummy_model()
    additive_ssm = AdditiveStateSpaceModel([ssm])
    additive_ssm_with_offset = AdditiveStateSpaceModel(
        [ssm], constant_offset=offset)
    # Pass the component's own observation-noise stddev explicitly. This
    # exercises the code path where the caller supplies
    # `observation_noise_scale` alongside the offset.
    additive_ssm_with_offset_and_explicit_scale = AdditiveStateSpaceModel(
        [ssm],
        constant_offset=offset,
        observation_noise_scale=(
            ssm.get_observation_noise_for_timestep(0).stddev()[..., 0]))

    mean_, offset_mean_, offset_with_scale_mean_ = self.evaluate(
        (additive_ssm.mean(),
         additive_ssm_with_offset.mean(),
         additive_ssm_with_offset_and_explicit_scale.mean()))
    # The offset is a Python scalar, so it broadcasts against any mean shape.
    self.assertAllClose(mean_, offset_mean_ - offset_)
    self.assertAllClose(mean_, offset_with_scale_mean_ - offset_)

    # Offset should not affect the stddev.
    stddev_, offset_stddev_, offset_with_scale_stddev_ = self.evaluate(
        (additive_ssm.stddev(),
         additive_ssm_with_offset.stddev(),
         additive_ssm_with_offset_and_explicit_scale.stddev()))
    self.assertAllClose(stddev_, offset_stddev_)
    self.assertAllClose(stddev_, offset_with_scale_stddev_)
def test_constant_offset(self, is_scalar=True):
    """Verify that `constant_offset` moves the mean and not the stddev.

    Args:
      is_scalar: when True (the default), the offset is the scalar 3.1415.
        Otherwise it is the length-5 vector [3, 1, 4, 1, 5].

    Three additive models wrap the same dummy component: a plain one, one with
    the offset, and one with the offset plus an explicit
    `observation_noise_scale`. The means of the last two, minus the offset,
    must match the plain mean. All three stddevs must match each other.
    """
    if is_scalar:
        offset_value = np.array(3.1415)
    else:
        offset_value = np.array([3., 1., 4., 1., 5.])
    offset_tensor = self._build_placeholder(offset_value)

    component = self._dummy_model()
    models = [
        AdditiveStateSpaceModel([component]),
        AdditiveStateSpaceModel([component], constant_offset=offset_tensor),
    ]
    # Reuse the component's first-timestep observation-noise stddev as an
    # explicit scale for the third model.
    explicit_scale = (
        component.get_observation_noise_for_timestep(0).stddev()[..., 0])
    models.append(
        AdditiveStateSpaceModel(
            [component],
            constant_offset=offset_tensor,
            observation_noise_scale=explicit_scale))

    plain_mean, *offset_means = self.evaluate(
        tuple(model.mean() for model in models))
    # Append a trailing axis to the offset so it broadcasts against the mean.
    # Presumably the mean ends in a size-1 event dimension; confirm against
    # AdditiveStateSpaceModel.
    expected_shift = offset_value[..., tf.newaxis]
    for shifted_mean in offset_means:
        self.assertAllClose(plain_mean, shifted_mean - expected_shift)

    # A deterministic offset must leave the spread untouched.
    plain_stddev, *offset_stddevs = self.evaluate(
        tuple(model.stddev() for model in models))
    for shifted_stddev in offset_stddevs:
        self.assertAllClose(plain_stddev, shifted_stddev)