def test_factored_joint_mvn_diag_full(self):
  # Combine a diagonal MVN (event size 3) with a full-covariance MVN (event
  # size 2), both with batch shape [3, 2], and check that the joint MVN stacks
  # their means and block-diagonalizes their covariances.
  batch_shape = [3, 2]
  mvn1 = tfd.MultivariateNormalDiag(
      loc=tf.zeros(batch_shape + [3]), scale_diag=tf.ones(batch_shape + [3]))
  mvn2 = tfd.MultivariateNormalFullCovariance(
      loc=tf.ones(batch_shape + [2]),
      covariance_matrix=(tf.ones(batch_shape + [2, 2]) *
                         [[5., -2], [-2, 3.1]]))

  joint = sts_util.factored_joint_mvn([mvn1, mvn2])
  # Event shapes are rank-1 tensors, so their elementwise sum ([3] + [2])
  # is the expected joint event shape ([5]). Compare as arrays rather than
  # relying on the truthiness of a one-element `==` result.
  self.assertAllEqual(
      self.evaluate(joint.event_shape_tensor()),
      self.evaluate(mvn1.event_shape_tensor() + mvn2.event_shape_tensor()))

  joint_mean_ = self.evaluate(joint.mean())
  self.assertAllEqual(joint_mean_[..., :3], self.evaluate(mvn1.mean()))
  self.assertAllEqual(joint_mean_[..., 3:], self.evaluate(mvn2.mean()))

  joint_cov_ = self.evaluate(joint.covariance())
  self.assertAllEqual(joint_cov_[..., :3, :3],
                      self.evaluate(mvn1.covariance()))
  self.assertAllEqual(joint_cov_[..., 3:, 3:],
                      self.evaluate(mvn2.covariance()))
  # The components are independent, so the cross-covariance blocks must be
  # exactly zero.
  self.assertAllEqual(joint_cov_[..., :3, 3:],
                      self.evaluate(tf.zeros(batch_shape + [3, 2])))
  self.assertAllEqual(joint_cov_[..., 3:, :3],
                      self.evaluate(tf.zeros(batch_shape + [2, 3])))
def test_factored_joint_mvn_diag_full(self):
  # Combine a diagonal MVN (event size 3) with a full-covariance MVN (event
  # size 2), both with batch shape [3, 2], and check that the joint MVN stacks
  # their means and block-diagonalizes their covariances.
  batch_shape = [3, 2]
  mvn1 = tfd.MultivariateNormalDiag(
      loc=tf.zeros(batch_shape + [3]), scale_diag=tf.ones(batch_shape + [3]))
  mvn2 = tfd.MultivariateNormalFullCovariance(
      loc=tf.ones(batch_shape + [2]),
      covariance_matrix=(tf.ones(batch_shape + [2, 2]) *
                         [[5., -2], [-2, 3.1]]))

  joint = sts_util.factored_joint_mvn([mvn1, mvn2])
  # Event shapes are rank-1 tensors, so their elementwise sum ([3] + [2])
  # is the expected joint event shape ([5]). Compare as arrays rather than
  # relying on the truthiness of a one-element `==` result.
  self.assertAllEqual(
      self.evaluate(joint.event_shape_tensor()),
      self.evaluate(mvn1.event_shape_tensor() + mvn2.event_shape_tensor()))

  joint_mean_ = self.evaluate(joint.mean())
  self.assertAllEqual(joint_mean_[..., :3], self.evaluate(mvn1.mean()))
  self.assertAllEqual(joint_mean_[..., 3:], self.evaluate(mvn2.mean()))

  joint_cov_ = self.evaluate(joint.covariance())
  self.assertAllEqual(joint_cov_[..., :3, :3],
                      self.evaluate(mvn1.covariance()))
  self.assertAllEqual(joint_cov_[..., 3:, 3:],
                      self.evaluate(mvn2.covariance()))
  # The components are independent, so the cross-covariance blocks must be
  # exactly zero.
  self.assertAllEqual(joint_cov_[..., :3, 3:],
                      self.evaluate(tf.zeros(batch_shape + [3, 2])))
  self.assertAllEqual(joint_cov_[..., 3:, :3],
                      self.evaluate(tf.zeros(batch_shape + [2, 3])))
def test_factored_joint_mvn_broadcast_batch_shape(self):
  # Test that combining MVNs with different but broadcast-compatible
  # batch shapes yields an MVN with the correct broadcast batch shape.
  # A fixed seed makes the (otherwise arbitrary) parameters reproducible, so
  # any failure can be replayed exactly.
  rng = np.random.RandomState(seed=42)
  random_with_shape = (
      lambda shape: rng.standard_normal(shape).astype(np.float32))

  event_shape = [3]
  # mvn with batch shape [2]
  mvn1 = tfd.MultivariateNormalDiag(
      loc=random_with_shape([2] + event_shape),
      scale_diag=tf.exp(random_with_shape([2] + event_shape)))
  # mvn with batch shape [3, 2]
  mvn2 = tfd.MultivariateNormalDiag(
      loc=random_with_shape([3, 2] + event_shape),
      scale_diag=tf.exp(random_with_shape([1, 2] + event_shape)))
  # mvn with batch shape [1, 2]
  mvn3 = tfd.MultivariateNormalDiag(
      loc=random_with_shape([1, 2] + event_shape),
      scale_diag=tf.exp(random_with_shape([2] + event_shape)))
  components = [mvn1, mvn2, mvn3]
  # Span of the joint event axis occupied by each component, in order.
  blocks = [slice(0, 3), slice(3, 6), slice(6, 9)]

  joint = sts_util.factored_joint_mvn(components)
  self.assertAllEqual(self.evaluate(joint.batch_shape_tensor()), [3, 2])

  joint_mean_ = self.evaluate(joint.mean())
  # Multiplying by ones with the joint batch shape broadcasts each
  # component's statistics up to [3, 2] before comparison.
  broadcast_means = tf.ones_like(joint.mean()[..., 0:1])
  for mvn, block in zip(components, blocks):
    self.assertAllEqual(joint_mean_[..., block],
                        self.evaluate(broadcast_means * mvn.mean()))

  joint_cov_ = self.evaluate(joint.covariance())
  broadcast_covs = tf.ones_like(joint.covariance()[..., :1, :1])
  for i, (mvn, rows) in enumerate(zip(components, blocks)):
    self.assertAllEqual(joint_cov_[..., rows, rows],
                        self.evaluate(broadcast_covs * mvn.covariance()))
    # The components are independent, so every cross-covariance block must
    # be exactly zero.
    for j, cols in enumerate(blocks):
      if i != j:
        self.assertAllEqual(joint_cov_[..., rows, cols],
                            np.zeros_like(joint_cov_[..., rows, cols]))
def test_factored_joint_mvn_broadcast_batch_shape(self):
  # Test that combining MVNs with different but broadcast-compatible
  # batch shapes yields an MVN with the correct broadcast batch shape.
  # A fixed seed makes the (otherwise arbitrary) parameters reproducible, so
  # any failure can be replayed exactly.
  rng = np.random.RandomState(seed=42)
  random_with_shape = (
      lambda shape: rng.standard_normal(shape).astype(np.float32))

  event_shape = [3]
  # mvn with batch shape [2]
  mvn1 = tfd.MultivariateNormalDiag(
      loc=random_with_shape([2] + event_shape),
      scale_diag=tf.exp(random_with_shape([2] + event_shape)))
  # mvn with batch shape [3, 2]
  mvn2 = tfd.MultivariateNormalDiag(
      loc=random_with_shape([3, 2] + event_shape),
      scale_diag=tf.exp(random_with_shape([1, 2] + event_shape)))
  # mvn with batch shape [1, 2]
  mvn3 = tfd.MultivariateNormalDiag(
      loc=random_with_shape([1, 2] + event_shape),
      scale_diag=tf.exp(random_with_shape([2] + event_shape)))
  components = [mvn1, mvn2, mvn3]
  # Span of the joint event axis occupied by each component, in order.
  blocks = [slice(0, 3), slice(3, 6), slice(6, 9)]

  joint = sts_util.factored_joint_mvn(components)
  self.assertAllEqual(self.evaluate(joint.batch_shape_tensor()), [3, 2])

  joint_mean_ = self.evaluate(joint.mean())
  # Multiplying by ones with the joint batch shape broadcasts each
  # component's statistics up to [3, 2] before comparison.
  broadcast_means = tf.ones_like(joint.mean()[..., 0:1])
  for mvn, block in zip(components, blocks):
    self.assertAllEqual(joint_mean_[..., block],
                        self.evaluate(broadcast_means * mvn.mean()))

  joint_cov_ = self.evaluate(joint.covariance())
  broadcast_covs = tf.ones_like(joint.covariance()[..., :1, :1])
  for i, (mvn, rows) in enumerate(zip(components, blocks)):
    self.assertAllEqual(joint_cov_[..., rows, rows],
                        self.evaluate(broadcast_covs * mvn.covariance()))
    # The components are independent, so every cross-covariance block must
    # be exactly zero.
    for j, cols in enumerate(blocks):
      if i != j:
        self.assertAllEqual(joint_cov_[..., rows, cols],
                            np.zeros_like(joint_cov_[..., rows, cols]))
def transition_noise_fn(t):
  """Returns the joint transition noise of all components at timestep `t`.

  Each component's transition-noise distribution for `t` is collected in
  component order and merged with `sts_util.factored_joint_mvn`.
  """
  per_component_noise = [
      component.get_transition_noise_for_timestep(t)
      for component in component_ssms
  ]
  return sts_util.factored_joint_mvn(per_component_noise)
def __init__(self,
             component_ssms,
             constant_offset=0.,
             observation_noise_scale=None,
             initial_state_prior=None,
             initial_step=0,
             validate_args=False,
             name=None,
             **linear_gaussian_ssm_kwargs):
  """Build a state space model representing the sum of component models.

  Args:
    component_ssms: Python `list` containing one or more
      `tfd.LinearGaussianStateSpaceModel` instances. The components
      will in general implement different time-series models, with possibly
      different `latent_size`, but they must have the same `dtype`, event
      shape (`num_timesteps` and `observation_size`), and their batch shapes
      must broadcast to a compatible batch shape.
    constant_offset: `float` `Tensor` of shape broadcasting to
      `concat([batch_shape, [num_timesteps]])` specifying a constant value
      added to the sum of outputs from the component models.
      This allows the components to model the shifted series
      `observed_time_series - constant_offset`. A single trailing entry is
      used at every timestep; otherwise entry `t` of the last axis is used at
      timestep `t`, and timesteps past the end reuse the final entry.
      Default value: `0.`
    observation_noise_scale: Optional scalar `float` `Tensor` indicating the
      standard deviation of the observation noise. May contain additional
      batch dimensions, which must broadcast with the batch shape of elements
      in `component_ssms`. If `observation_noise_scale` is specified for the
      `AdditiveStateSpaceModel`, the observation noise scales of component
      models are ignored. If `None`, the observation noise scale is derived
      by summing the noise variances of the component models, i.e.,
      `observation_noise_scale = sqrt(sum(
      [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
    initial_state_prior: Optional instance of `tfd.MultivariateNormal`
      representing a prior distribution on the latent state at time
      `initial_step`. If `None`, defaults to the independent priors from
      component models, i.e.,
      `[component.initial_state_prior for component in component_ssms]`.
      Default value: `None`.
    initial_step: Optional scalar `int` `Tensor` specifying the starting
      timestep.
      Default value: 0.
    validate_args: Python `bool`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this class.
      Default value: "AdditiveStateSpaceModel".
    **linear_gaussian_ssm_kwargs: Optional additional keyword arguments to
      pass to the base `tfd.LinearGaussianStateSpaceModel` constructor.

  Raises:
    ValueError: if components have different `num_timesteps`.
  """
  # `locals()` snapshots the constructor arguments, so this must stay the
  # first statement (before any other local is bound). The extra kwargs are
  # then flattened into the same dict.
  parameters = dict(locals())
  parameters.update(linear_gaussian_ssm_kwargs)
  del parameters['linear_gaussian_ssm_kwargs']
  with tf.name_scope(name or 'AdditiveStateSpaceModel') as name:
    # Check that all components have the same dtype
    dtype = tf.debugging.assert_same_float_dtype(component_ssms)

    # Convert scalar offsets to canonical shape `[..., num_timesteps]`.
    # Multiplying by `tf.ones([1])` gives a scalar offset a trailing axis of
    # length 1 so that `offset_length` below is always defined.
    constant_offset = (tf.convert_to_tensor(
        value=constant_offset, name='constant_offset', dtype=dtype) *
                       tf.ones([1], dtype=dtype))
    offset_length = prefer_static.shape(constant_offset)[-1]
    assertions = []

    # Construct an initial state prior as a block-diagonal combination
    # of the component state priors.
    if initial_state_prior is None:
      initial_state_prior = sts_util.factored_joint_mvn(
          [ssm.initial_state_prior for ssm in component_ssms])
    # From here on `dtype` comes from the (possibly user-supplied) prior.
    dtype = initial_state_prior.dtype

    static_num_timesteps = [
        tf.get_static_value(ssm.num_timesteps)
        for ssm in component_ssms
        if tf.get_static_value(ssm.num_timesteps) is not None
    ]

    # If any components have a static value for `num_timesteps`, use that
    # value for the additive model. (and check that all other static values
    # match it).
    if static_num_timesteps:
      num_timesteps = static_num_timesteps[0]
      if not all([
          component_timesteps == num_timesteps
          for component_timesteps in static_num_timesteps
      ]):
        raise ValueError(
            'Additive model components must all have the same '
            'number of timesteps '
            '(saw: {})'.format(static_num_timesteps))
    else:
      num_timesteps = component_ssms[0].num_timesteps
    # Runtime assertions are only needed when at least one component's
    # `num_timesteps` is not statically known.
    if validate_args and len(static_num_timesteps) != len(
        component_ssms):
      assertions += [
          tf.debugging.assert_equal(  # pylint: disable=g-complex-comprehension
              num_timesteps,
              ssm.num_timesteps,
              message='Additive model components must all have '
              'the same number of timesteps.') for ssm in component_ssms
      ]

    # Define the transition and observation models for the additive SSM.
    # See the "mathematical details" section of the class docstring for
    # further information. Note that we define these as callables to
    # handle the fully general case in which some components have time-
    # varying dynamics.
    def transition_matrix_fn(t):
      return tfl.LinearOperatorBlockDiag([
          ssm.get_transition_matrix_for_timestep(t)
          for ssm in component_ssms
      ])

    def transition_noise_fn(t):
      return sts_util.factored_joint_mvn([
          ssm.get_transition_noise_for_timestep(t)
          for ssm in component_ssms
      ])

    # Build the observation matrix, concatenating (broadcast) observation
    # matrices from components. We also take this as an opportunity to enforce
    # any dynamic assertions we may have generated above.
    # The broadcast batch shape is computed from the observation matrices at
    # `initial_step`.
    broadcast_batch_shape = tf.convert_to_tensor(
        value=sts_util.broadcast_batch_shape([
            ssm.get_observation_matrix_for_timestep(initial_step)
            for ssm in component_ssms
        ]),
        dtype=tf.int32)
    # Ones of shape `concat([broadcast_batch_shape, [1, 1]])`: multiplying a
    # component's observation matrix by this broadcasts it to the common
    # batch shape so the matrices can be concatenated along the last axis.
    broadcast_obs_matrix = tf.ones(tf.concat(
        [broadcast_batch_shape, [1, 1]], axis=0), dtype=dtype)
    if assertions:
      # Every use of `broadcast_obs_matrix` now depends on the assertions.
      with tf.control_dependencies(assertions):
        broadcast_obs_matrix = tf.identity(broadcast_obs_matrix)

    def observation_matrix_fn(t):
      return tfl.LinearOperatorFullMatrix(
          tf.concat([
              ssm.get_observation_matrix_for_timestep(t).to_dense() *
              broadcast_obs_matrix for ssm in component_ssms
          ], axis=-1))

    # Broadcast the constant offset across timesteps.
    # A length-1 offset is used as-is; otherwise take entry `t`, clamped to
    # the last entry, and restore a trailing axis of size 1.
    offset_at_step = lambda t: (  # pylint: disable=g-long-lambda
        constant_offset if offset_length == 1 else tf.gather(
            constant_offset, tf.minimum(t, offset_length - 1), axis=-1)
        [..., tf.newaxis])

    if observation_noise_scale is not None:
      observation_noise_scale = tf.convert_to_tensor(
          value=observation_noise_scale,
          name='observation_noise_scale',
          dtype=dtype)

      # Shared noise scale: the component means (plus the offset) are summed,
      # but the component observation-noise scales are not used.
      def observation_noise_fn(t):
        return tfd.MultivariateNormalDiag(
            loc=(sum([
                ssm.get_observation_noise_for_timestep(t).mean()
                for ssm in component_ssms
            ]) + offset_at_step(t)),
            scale_diag=observation_noise_scale[..., tf.newaxis])
    else:

      # The offset enters as a zero-scale MVN, so within `sum_mvns` it is
      # expected only to shift the mean of the summed component noises.
      def observation_noise_fn(t):
        offset = offset_at_step(t)
        return sts_util.sum_mvns([
            tfd.MultivariateNormalDiag(
                loc=offset, scale_diag=tf.zeros_like(offset))
        ] + [
            ssm.get_observation_noise_for_timestep(t)
            for ssm in component_ssms
        ])

    super(AdditiveStateSpaceModel, self).__init__(
        num_timesteps=num_timesteps,
        transition_matrix=transition_matrix_fn,
        transition_noise=transition_noise_fn,
        observation_matrix=observation_matrix_fn,
        observation_noise=observation_noise_fn,
        initial_state_prior=initial_state_prior,
        initial_step=initial_step,
        validate_args=validate_args,
        name=name,
        **linear_gaussian_ssm_kwargs)
    self._parameters = parameters
def __init__(self,
             component_ssms,
             observation_noise_scale=None,
             initial_state_prior=None,
             initial_step=0,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Build a state space model representing the sum of component models.

  Args:
    component_ssms: Python `list` containing one or more
      `tfd.LinearGaussianStateSpaceModel` instances. The components
      will in general implement different time-series models, with possibly
      different `latent_size`, but they must have the same `dtype`, event
      shape (`num_timesteps` and `observation_size`), and their batch shapes
      must broadcast to a compatible batch shape.
    observation_noise_scale: Optional scalar `float` `Tensor` indicating the
      standard deviation of the observation noise. May contain additional
      batch dimensions, which must broadcast with the batch shape of elements
      in `component_ssms`. If `observation_noise_scale` is specified for the
      `AdditiveStateSpaceModel`, the observation noise scales of component
      models are ignored. If `None`, the observation noise scale is derived
      by summing the noise variances of the component models, i.e.,
      `observation_noise_scale = sqrt(sum(
      [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
    initial_state_prior: Optional instance of `tfd.MultivariateNormal`
      representing a prior distribution on the latent state at time
      `initial_step`. If `None`, defaults to the independent priors from
      component models, i.e.,
      `[component.initial_state_prior for component in component_ssms]`.
      Default value: `None`.
    initial_step: Optional scalar `int` `Tensor` specifying the starting
      timestep.
      Default value: 0.
    validate_args: Python `bool`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    allow_nan_stats: Python `bool`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.
      Default value: `True`.
    name: Python `str` name prefixed to ops created by this class.
      Default value: "AdditiveStateSpaceModel".

  Raises:
    ValueError: if components have different `num_timesteps`.
  """

  with tf.compat.v1.name_scope(
      name,
      'AdditiveStateSpaceModel',
      values=[observation_noise_scale, initial_step]) as name:

    assertions = []

    # Check that all components have the same dtype
    # (the returned dtype is not used; `dtype` is taken from the prior below).
    tf.debugging.assert_same_float_dtype(component_ssms)

    # Construct an initial state prior as a block-diagonal combination
    # of the component state priors.
    if initial_state_prior is None:
      initial_state_prior = sts_util.factored_joint_mvn(
          [ssm.initial_state_prior for ssm in component_ssms])
    dtype = initial_state_prior.dtype

    static_num_timesteps = [
        tf.get_static_value(ssm.num_timesteps)
        for ssm in component_ssms
        if tf.get_static_value(ssm.num_timesteps) is not None
    ]

    # If any components have a static value for `num_timesteps`, use that
    # value for the additive model. (and check that all other static values
    # match it).
    if static_num_timesteps:
      num_timesteps = static_num_timesteps[0]
      if not all([component_timesteps == num_timesteps
                  for component_timesteps in static_num_timesteps]):
        raise ValueError('Additive model components must all have the same '
                         'number of timesteps '
                         '(saw: {})'.format(static_num_timesteps))
    else:
      num_timesteps = component_ssms[0].num_timesteps
    # Runtime assertions are only needed when at least one component's
    # `num_timesteps` is not statically known.
    if validate_args and len(static_num_timesteps) != len(component_ssms):
      assertions += [
          tf.compat.v1.assert_equal(
              num_timesteps,
              ssm.num_timesteps,
              message='Additive model components must all have '
              'the same number of timesteps.') for ssm in component_ssms
      ]

    # Define the transition and observation models for the additive SSM.
    # See the "mathematical details" section of the class docstring for
    # further information. Note that we define these as callables to
    # handle the fully general case in which some components have time-
    # varying dynamics.
    def transition_matrix_fn(t):
      return tfl.LinearOperatorBlockDiag(
          [ssm.get_transition_matrix_for_timestep(t)
           for ssm in component_ssms])

    def transition_noise_fn(t):
      return sts_util.factored_joint_mvn(
          [ssm.get_transition_noise_for_timestep(t)
           for ssm in component_ssms])

    # Build the observation matrix, concatenating (broadcast) observation
    # matrices from components. We also take this as an opportunity to enforce
    # any dynamic assertions we may have generated above.
    broadcast_batch_shape = tf.convert_to_tensor(
        value=sts_util.broadcast_batch_shape(component_ssms), dtype=tf.int32)
    # Ones of shape `concat([broadcast_batch_shape, [1, 1]])`: multiplying a
    # component's observation matrix by this broadcasts it to the common
    # batch shape so the matrices can be concatenated along the last axis.
    broadcast_obs_matrix = tf.ones(
        tf.concat([broadcast_batch_shape, [1, 1]], axis=0), dtype=dtype)
    if assertions:
      # Every use of `broadcast_obs_matrix` now depends on the assertions.
      with tf.control_dependencies(assertions):
        broadcast_obs_matrix = tf.identity(broadcast_obs_matrix)

    def observation_matrix_fn(t):
      return tfl.LinearOperatorFullMatrix(
          tf.concat([ssm.get_observation_matrix_for_timestep(t).to_dense() *
                     broadcast_obs_matrix for ssm in component_ssms],
                    axis=-1))

    if observation_noise_scale is not None:
      observation_noise_scale = tf.convert_to_tensor(
          value=observation_noise_scale,
          name='observation_noise_scale',
          dtype=dtype)
      # Shared noise scale: component means are summed, but the component
      # observation-noise scales are not used.
      def observation_noise_fn(t):
        return tfd.MultivariateNormalDiag(
            loc=sum([ssm.get_observation_noise_for_timestep(t).mean()
                     for ssm in component_ssms]),
            scale_diag=observation_noise_scale[..., tf.newaxis])
    else:
      def observation_noise_fn(t):
        return sts_util.sum_mvns(
            [ssm.get_observation_noise_for_timestep(t)
             for ssm in component_ssms])

    super(AdditiveStateSpaceModel, self).__init__(
        num_timesteps=num_timesteps,
        transition_matrix=transition_matrix_fn,
        transition_noise=transition_noise_fn,
        observation_matrix=observation_matrix_fn,
        observation_noise=observation_noise_fn,
        initial_state_prior=initial_state_prior,
        initial_step=initial_step,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
def transition_noise_fn(t):
  """Builds the combined transition noise for timestep `t`.

  The transition-noise distributions of all entries of `component_ssms` are
  gathered in order and joined via `sts_util.factored_joint_mvn`.
  """
  noise_at_t = [
      component_ssm.get_transition_noise_for_timestep(t)
      for component_ssm in component_ssms
  ]
  return sts_util.factored_joint_mvn(noise_at_t)
def __init__(self,
             component_ssms,
             observation_noise_scale=None,
             initial_state_prior=None,
             initial_step=0,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Build a state space model representing the sum of component models.

  Args:
    component_ssms: Python `list` containing one or more
      `tfd.LinearGaussianStateSpaceModel` instances. The components
      will in general implement different time-series models, with possibly
      different `latent_size`, but they must have the same `dtype`, event
      shape (`num_timesteps` and `observation_size`), and their batch shapes
      must broadcast to a compatible batch shape.
    observation_noise_scale: Optional scalar `float` `Tensor` indicating the
      standard deviation of the observation noise. May contain additional
      batch dimensions, which must broadcast with the batch shape of elements
      in `component_ssms`. If `observation_noise_scale` is specified for the
      `AdditiveStateSpaceModel`, the observation noise scales of component
      models are ignored. If `None`, the observation noise scale is derived
      by summing the noise variances of the component models, i.e.,
      `observation_noise_scale = sqrt(sum(
      [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
    initial_state_prior: Optional instance of `tfd.MultivariateNormal`
      representing a prior distribution on the latent state at time
      `initial_step`. If `None`, defaults to the independent priors from
      component models, i.e.,
      `[component.initial_state_prior for component in component_ssms]`.
      Default value: `None`.
    initial_step: Optional scalar `int` `Tensor` specifying the starting
      timestep.
      Default value: 0.
    validate_args: Python `bool`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    allow_nan_stats: Python `bool`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.
      Default value: `True`.
    name: Python `str` name prefixed to ops created by this class.
      Default value: "AdditiveStateSpaceModel".

  Raises:
    ValueError: if components have different `num_timesteps`.
  """

  # NOTE(review): this is the TF1 `tf.name_scope(name, default_name, values)`
  # signature; TF2's `tf.name_scope` accepts only a name. Confirm the targeted
  # TensorFlow version.
  with tf.name_scope(name, 'AdditiveStateSpaceModel',
                     values=[observation_noise_scale, initial_step]) as name:

    assertions = []

    # Check that all components have the same dtype
    # (the result is not used; `dtype` is taken from the prior below).
    tf.assert_same_float_dtype(component_ssms)

    # Construct an initial state prior as a block-diagonal combination
    # of the component state priors.
    if initial_state_prior is None:
      initial_state_prior = sts_util.factored_joint_mvn(
          [ssm.initial_state_prior for ssm in component_ssms])
    dtype = initial_state_prior.dtype

    static_num_timesteps = [
        distribution_util.static_value(ssm.num_timesteps)
        for ssm in component_ssms
        if distribution_util.static_value(ssm.num_timesteps) is not None
    ]

    # If any components have a static value for `num_timesteps`, use that
    # value for the additive model. (and check that all other static values
    # match it).
    if static_num_timesteps:
      num_timesteps = static_num_timesteps[0]
      if not all([component_timesteps == num_timesteps
                  for component_timesteps in static_num_timesteps]):
        raise ValueError('Additive model components must all have the same '
                         'number of timesteps '
                         '(saw: {})'.format(static_num_timesteps))
    else:
      num_timesteps = component_ssms[0].num_timesteps
    # Runtime assertions are only needed when at least one component's
    # `num_timesteps` is not statically known.
    if validate_args and len(static_num_timesteps) != len(component_ssms):
      assertions += [
          tf.assert_equal(num_timesteps, ssm.num_timesteps,
                          message='Additive model components must all have '
                          'the same number of timesteps.')
          for ssm in component_ssms]

    # Define the transition and observation models for the additive SSM.
    # See the "mathematical details" section of the class docstring for
    # further information. Note that we define these as callables to
    # handle the fully general case in which some components have time-
    # varying dynamics.
    def transition_matrix_fn(t):
      return tfl.LinearOperatorBlockDiag(
          [ssm.get_transition_matrix_for_timestep(t)
           for ssm in component_ssms])

    def transition_noise_fn(t):
      return sts_util.factored_joint_mvn(
          [ssm.get_transition_noise_for_timestep(t)
           for ssm in component_ssms])

    # Build the observation matrix, concatenating (broadcast) observation
    # matrices from components. We also take this as an opportunity to enforce
    # any dynamic assertions we may have generated above.
    broadcast_batch_shape = tf.convert_to_tensor(
        sts_util.broadcast_batch_shape(component_ssms), dtype=tf.int32)
    # Ones of shape `concat([broadcast_batch_shape, [1, 1]])`: multiplying a
    # component's observation matrix by this broadcasts it to the common
    # batch shape so the matrices can be concatenated along the last axis.
    broadcast_obs_matrix = tf.ones(
        tf.concat([broadcast_batch_shape, [1, 1]], axis=0), dtype=dtype)
    if assertions:
      # Every use of `broadcast_obs_matrix` now depends on the assertions.
      with tf.control_dependencies(assertions):
        broadcast_obs_matrix = tf.identity(broadcast_obs_matrix)

    def observation_matrix_fn(t):
      return tfl.LinearOperatorFullMatrix(
          tf.concat([ssm.get_observation_matrix_for_timestep(t).to_dense() *
                     broadcast_obs_matrix for ssm in component_ssms],
                    axis=-1))

    if observation_noise_scale is not None:
      observation_noise_scale = tf.convert_to_tensor(
          observation_noise_scale,
          name='observation_noise_scale',
          dtype=dtype)
      # Shared noise scale: component means are summed, but the component
      # observation-noise scales are not used.
      def observation_noise_fn(t):
        return tfd.MultivariateNormalDiag(
            loc=sum([ssm.get_observation_noise_for_timestep(t).mean()
                     for ssm in component_ssms]),
            scale_diag=observation_noise_scale[..., tf.newaxis])
    else:
      def observation_noise_fn(t):
        return sts_util.sum_mvns(
            [ssm.get_observation_noise_for_timestep(t)
             for ssm in component_ssms])

    super(AdditiveStateSpaceModel, self).__init__(
        num_timesteps=num_timesteps,
        transition_matrix=transition_matrix_fn,
        transition_noise=transition_noise_fn,
        observation_matrix=observation_matrix_fn,
        observation_noise=observation_noise_fn,
        initial_state_prior=initial_state_prior,
        initial_step=initial_step,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
def _pad_mvn_with_trailing_zeros(mvn, num_zeros):
  """Extends `mvn`'s event with `num_zeros` trailing zero-mean, zero-scale dims.

  Args:
    mvn: A multivariate normal distribution to pad.
    num_zeros: Number of trailing event dimensions to append.

  Returns:
    The result of `sts_util.factored_joint_mvn` applied to `mvn` followed by a
    `tfd.MultivariateNormalDiag` whose `loc` and `scale_diag` are both zero
    vectors of length `num_zeros` (with the same dtype as `mvn`).
  """
  zero_vector = tf.zeros([num_zeros], dtype=mvn.dtype)
  degenerate_padding = tfd.MultivariateNormalDiag(
      loc=zero_vector, scale_diag=zero_vector)
  return sts_util.factored_joint_mvn([mvn, degenerate_padding])