def testSameShape(self):
  full_batch_shape, dist = self.build_inputs([5, 4, 2, 3], [5, 4, 2, 3])
  sample_shape = _augment_sample_shape(
      dist, full_batch_shape, validate_args=True)
  self.assertAllEqual(self.maybe_evaluate(sample_shape), [])
def _joint_sample_n(self, n, seed=None):
  """Draw a joint sample from the prior over latents and observations.

  This sampler is specific to LocalLevel models and is faster than the
  generic LinearGaussianStateSpaceModel implementation.

  Args:
    n: `int` `Tensor` number of samples to draw.
    seed: Optional `int` `Tensor` seed for the random number generator.

  Returns:
    latents: `float` `Tensor` of shape
      `concat([[n], self.batch_shape, [self.num_timesteps,
      self.latent_size]], axis=0)` representing samples of latent
      trajectories.
    observations: `float` `Tensor` of shape
      `concat([[n], self.batch_shape, [self.num_timesteps,
      self.observation_size]], axis=0)` representing samples of observed
      series generated from the sampled `latents`.
  """
  with tf.name_scope('joint_sample_n'):
    (initial_level_seed,
     level_jumps_seed,
     prior_observation_seed) = samplers.split_seed(
         seed, n=3, salt='LocalLevelStateSpaceModel_joint_sample_n')

    if self.batch_shape.is_fully_defined():
      batch_shape = self.batch_shape.as_list()
    else:
      batch_shape = self.batch_shape_tensor()
    sample_and_batch_shape = tf.cast(
        prefer_static.concat([[n], batch_shape], axis=0), tf.int32)

    # Sample the initial timestep from the prior.  Since we want
    # this sample to have full batch shape (not just the batch shape
    # of the self.initial_state_prior object which might in general be
    # smaller), we augment the sample shape to include whatever
    # extra batch dimensions are required.
    initial_level = self.initial_state_prior.sample(
        linear_gaussian_ssm._augment_sample_shape(  # pylint: disable=protected-access
            self.initial_state_prior,
            sample_and_batch_shape,
            self.validate_args),
        seed=initial_level_seed)

    # Sample the latent random walk and observed noise, more efficiently
    # than the generic loop in `LinearGaussianStateSpaceModel`.
    level_jumps = self.level_scale[..., tf.newaxis] * samplers.normal(
        prefer_static.concat(
            [sample_and_batch_shape, [self.num_timesteps - 1]], axis=0),
        dtype=self.dtype,
        seed=level_jumps_seed)
    prior_level_sample = tf.cumsum(
        tf.concat([initial_level, level_jumps], axis=-1), axis=-1)
    prior_observation_sample = prior_level_sample + (  # Sample noise.
        self.observation_noise_scale[..., tf.newaxis] *
        samplers.normal(prefer_static.shape(prior_level_sample),
                        dtype=self.dtype,
                        seed=prior_observation_seed))

    return (prior_level_sample[..., tf.newaxis],
            prior_observation_sample[..., tf.newaxis])
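# Illustrative usage of the LocalLevel joint sampler above (a minimal
# sketch, not part of the library source: `_joint_sample_n` is a private
# method, and the constructor arguments are the standard
# `tfp.sts.LocalLevelStateSpaceModel` arguments):
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

local_level_ssm = tfp.sts.LocalLevelStateSpaceModel(
    num_timesteps=50,
    level_scale=0.5,
    observation_noise_scale=1.,
    initial_state_prior=tfd.MultivariateNormalDiag(scale_diag=[1.]))
latents, observations = local_level_ssm._joint_sample_n(
    n=7, seed=(0, 42))  # Stateless seed pair; an `int` seed also works.
# latents.shape == [7, 50, 1] and observations.shape == [7, 50, 1],
# i.e. `[n] + batch_shape + [num_timesteps, 1]` with an empty batch shape.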
def testNotPrefixThrowsError(self):
  full_batch_shape, dist = self.build_inputs([5, 4, 2, 3], [1, 3])
  with self.assertRaisesError("Broadcasting is not supported"):
    self.maybe_evaluate(
        _augment_sample_shape(dist, full_batch_shape, validate_args=True))
def testTooManyDimsThrowsError(self):
  full_batch_shape, dist = self.build_inputs([5, 4, 2, 3], [6, 5, 4, 2, 3])
  with self.assertRaisesError(
      "(Broadcasting is not supported|Cannot broadcast)"):
    self.maybe_evaluate(
        _augment_sample_shape(dist, full_batch_shape, validate_args=True))
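# Worked example of the contract these tests exercise (a minimal sketch;
# the import path of the private helper is an assumption about the
# library layout): `_augment_sample_shape(dist, full_batch_shape, ...)`
# returns the prefix of `full_batch_shape` left over after stripping the
# distribution's batch shape from the right, so that sampling with the
# returned shape yields batch dimensions equal to `full_batch_shape`.
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.distributions.linear_gaussian_ssm import (
    _augment_sample_shape)
tfd = tfp.distributions

dist = tfd.Normal(loc=tf.zeros([2, 3]), scale=1.)  # Batch shape [2, 3].
sample_shape = _augment_sample_shape(dist, [5, 4, 2, 3], validate_args=True)
# sample_shape == [5, 4]; dist.sample(sample_shape) has shape [5, 4, 2, 3].
# A batch shape that is not a trailing suffix (e.g. [1, 3]) or that has
# more dimensions than the target (e.g. [6, 5, 4, 2, 3]) raises an error,
# as the tests above verify.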
def _joint_sample_n(self, n, seed=None):
  """Draw a joint sample from the prior over latents and observations.

  This sampler is specific to LocalLinearTrend models and is faster than
  the generic LinearGaussianStateSpaceModel implementation.

  Args:
    n: `int` `Tensor` number of samples to draw.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    latents: `float` `Tensor` of shape
      `concat([[n], self.batch_shape, [self.num_timesteps,
      self.latent_size]], axis=0)` representing samples of latent
      trajectories.
    observations: `float` `Tensor` of shape
      `concat([[n], self.batch_shape, [self.num_timesteps,
      self.observation_size]], axis=0)` representing samples of observed
      series generated from the sampled `latents`.
  """
  with tf.name_scope('joint_sample_n'):
    (initial_state_seed,
     level_jumps_seed,
     slope_jumps_seed,
     prior_observation_seed) = samplers.split_seed(
         seed, n=4, salt='LocalLinearTrendStateSpaceModel_joint_sample_n')

    if self.batch_shape.is_fully_defined():
      batch_shape = self.batch_shape.as_list()
    else:
      batch_shape = self.batch_shape_tensor()
    sample_and_batch_shape = ps.cast(
        ps.concat([[n], batch_shape], axis=0), tf.int32)

    # Sample the initial timestep from the prior.  Since we want
    # this sample to have full batch shape (not just the batch shape
    # of the self.initial_state_prior object which might in general be
    # smaller), we augment the sample shape to include whatever
    # extra batch dimensions are required.
    initial_level_and_slope = self.initial_state_prior.sample(
        linear_gaussian_ssm._augment_sample_shape(  # pylint: disable=protected-access
            self.initial_state_prior,
            sample_and_batch_shape,
            self.validate_args),
        seed=initial_state_seed)

    # Sample the latent random walk on slopes.
    jumps_shape = ps.concat(
        [sample_and_batch_shape, [self.num_timesteps - 1]], axis=0)
    slope_jumps = samplers.normal(
        jumps_shape, dtype=self.dtype,
        seed=slope_jumps_seed) * self.slope_scale[..., tf.newaxis]
    prior_slope_sample = tf.cumsum(
        tf.concat([initial_level_and_slope[..., 1:], slope_jumps], axis=-1),
        axis=-1)

    # Sample latent levels, given latent slopes.
    level_jumps = samplers.normal(
        jumps_shape, dtype=self.dtype,
        seed=level_jumps_seed) * self.level_scale[..., tf.newaxis]
    prior_level_sample = tf.cumsum(
        tf.concat([initial_level_and_slope[..., :1],
                   level_jumps + prior_slope_sample[..., :-1]], axis=-1),
        axis=-1)

    # Sample noisy observations, given latent levels.
    prior_observation_sample = prior_level_sample + (
        samplers.normal(ps.shape(prior_level_sample),
                        dtype=self.dtype,
                        seed=prior_observation_seed) *
        self.observation_noise_scale[..., tf.newaxis])

    return (tf.stack([prior_level_sample, prior_slope_sample], axis=-1),
            prior_observation_sample[..., tf.newaxis])
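# Illustrative usage of the LocalLinearTrend joint sampler above (a
# minimal sketch, not part of the library source: `_joint_sample_n` is a
# private method, and the constructor arguments are the standard
# `tfp.sts.LocalLinearTrendStateSpaceModel` arguments). Unlike the
# LocalLevel case, the latents carry both level and slope components.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

local_linear_trend_ssm = tfp.sts.LocalLinearTrendStateSpaceModel(
    num_timesteps=50,
    level_scale=0.5,
    slope_scale=0.1,
    observation_noise_scale=1.,
    initial_state_prior=tfd.MultivariateNormalDiag(scale_diag=[1., 0.2]))
latents, observations = local_linear_trend_ssm._joint_sample_n(
    n=7, seed=(0, 42))
# latents.shape == [7, 50, 2] (level and slope stacked on the last axis);
# observations.shape == [7, 50, 1].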