def concat_vectors(*args):
    """Join the input vectors end to end, statically when possible.

    Args:
      *args: the vectors to concatenate, in order.

    Returns:
      If the static value of every input is available, a flat Python `list`
      holding the elements of each input in order. Otherwise, the result of
      `tf.concat(args, axis=0)` applied to the original inputs.
    """
    # Resolve every input up front so the decision below sees all of them.
    static_vectors = list(map(distribution_util.static_value, args))

    everything_static = all(vec is not None for vec in static_vectors)
    if not everything_static:
        # At least one input has no static value; defer to a graph concat.
        return tf.concat(args, axis=0)

    flattened = []
    for vec in static_vectors:
        flattened.extend(vec)
    return flattened
def __init__(self,
             component_ssms,
             observation_noise_scale=None,
             initial_state_prior=None,
             initial_step=0,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
    """Construct the state space model given by a sum of component models.

    Args:
      component_ssms: Python `list` of one or more
        `tfd.LinearGaussianStateSpaceModel` instances. Components may
        implement different time-series models and may differ in
        `latent_size`, but they must share the same `dtype` and event shape
        (`num_timesteps` and `observation_size`), and their batch shapes
        must broadcast to a compatible batch shape.
      observation_noise_scale: Optional scalar `float` `Tensor` giving the
        standard deviation of the observation noise. It may carry extra
        batch dimensions, which must broadcast with the batch shape of the
        elements of `component_ssms`. When given, the observation noise
        scales of the component models are ignored. When `None`, the scale
        is derived by summing the component noise variances, i.e.,
        `observation_noise_scale = sqrt(sum(
        [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
      initial_state_prior: Optional instance of `tfd.MultivariateNormal`
        giving a prior on the latent state at time `initial_step`. When
        `None`, the independent priors of the component models are used,
        i.e., `[component.initial_state_prior for component in
        component_ssms]`.
        Default value: `None`.
      initial_step: Optional scalar `int` `Tensor` specifying the starting
        timestep.
        Default value: 0.
      validate_args: Python `bool`. Whether to validate input with asserts.
        If `validate_args` is `False` and the inputs are invalid, correct
        behavior is not guaranteed.
        Default value: `False`.
      allow_nan_stats: Python `bool`. If `False`, raise an exception if a
        statistic (e.g. mean/mode/etc...) is undefined for any batch member.
        If `True`, batch members with valid parameters leading to undefined
        statistics will return NaN for this statistic.
        Default value: `True`.
      name: Python `str` name prefixed to ops created by this class.
        Default value: "AdditiveStateSpaceModel".

    Raises:
      ValueError: if components have different `num_timesteps`.
    """
    with tf.name_scope(name, 'AdditiveStateSpaceModel',
                       values=[observation_noise_scale,
                               initial_step]) as name:

        # Every component is required to use the same float dtype.
        tf.assert_same_float_dtype(component_ssms)

        # When no prior is supplied, combine the component state priors
        # into a single joint (block-diagonal) prior.
        if initial_state_prior is None:
            component_priors = [
                ssm.initial_state_prior for ssm in component_ssms]
            initial_state_prior = sts_util.factored_joint_mvn(
                component_priors)
        dtype = initial_state_prior.dtype

        # Collect the `num_timesteps` values that are known statically.
        maybe_static_steps = [
            distribution_util.static_value(ssm.num_timesteps)
            for ssm in component_ssms]
        static_num_timesteps = [
            steps for steps in maybe_static_steps if steps is not None]

        # Prefer a statically known `num_timesteps` for the additive model,
        # insisting that all statically known values agree; otherwise fall
        # back to the (dynamic) value from the first component.
        if static_num_timesteps:
            num_timesteps = static_num_timesteps[0]
            steps_agree = all(
                steps == num_timesteps for steps in static_num_timesteps)
            if not steps_agree:
                raise ValueError(
                    'Additive model components must all have the same '
                    'number of timesteps '
                    '(saw: {})'.format(static_num_timesteps))
        else:
            num_timesteps = component_ssms[0].num_timesteps

        # If some component's `num_timesteps` could not be checked
        # statically, build runtime equality assertions (only when
        # validation is requested).
        assertions = []
        some_steps_dynamic = (
            len(static_num_timesteps) != len(component_ssms))
        if validate_args and some_steps_dynamic:
            for ssm in component_ssms:
                assertions.append(
                    tf.assert_equal(
                        num_timesteps,
                        ssm.num_timesteps,
                        message='Additive model components must all have '
                        'the same number of timesteps.'))

        # Transition and observation models for the additive SSM; see the
        # "mathematical details" section of the class docstring. They are
        # expressed as callables of the timestep so that components with
        # time-varying dynamics are handled in full generality.
        def additive_transition_matrix(t):
            component_matrices = [
                ssm.get_transition_matrix_for_timestep(t)
                for ssm in component_ssms]
            return tfl.LinearOperatorBlockDiag(component_matrices)

        def additive_transition_noise(t):
            component_noises = [
                ssm.get_transition_noise_for_timestep(t)
                for ssm in component_ssms]
            return sts_util.factored_joint_mvn(component_noises)

        # The observation matrix concatenates the component observation
        # matrices, each multiplied by a ones tensor carrying the broadcast
        # batch shape. Any dynamic assertions created above are attached
        # here, as control dependencies of that ones tensor.
        broadcast_batch_shape = tf.convert_to_tensor(
            sts_util.broadcast_batch_shape(component_ssms), dtype=tf.int32)
        ones_shape = tf.concat([broadcast_batch_shape, [1, 1]], axis=0)
        broadcast_obs_matrix = tf.ones(ones_shape, dtype=dtype)
        if assertions:
            with tf.control_dependencies(assertions):
                broadcast_obs_matrix = tf.identity(broadcast_obs_matrix)

        def additive_observation_matrix(t):
            broadcast_blocks = [
                (ssm.get_observation_matrix_for_timestep(t).to_dense() *
                 broadcast_obs_matrix)
                for ssm in component_ssms]
            return tfl.LinearOperatorFullMatrix(
                tf.concat(broadcast_blocks, axis=-1))

        if observation_noise_scale is None:
            # No explicit scale: sum the component noise distributions.
            def additive_observation_noise(t):
                component_noises = [
                    ssm.get_observation_noise_for_timestep(t)
                    for ssm in component_ssms]
                return sts_util.sum_mvns(component_noises)
        else:
            # Explicit scale: keep the summed component noise means but
            # use the supplied scale in place of the component scales.
            observation_noise_scale = tf.convert_to_tensor(
                observation_noise_scale,
                name='observation_noise_scale',
                dtype=dtype)

            def additive_observation_noise(t):
                noise_means = [
                    ssm.get_observation_noise_for_timestep(t).mean()
                    for ssm in component_ssms]
                return tfd.MultivariateNormalDiag(
                    loc=sum(noise_means),
                    scale_diag=observation_noise_scale[..., tf.newaxis])

        super(AdditiveStateSpaceModel, self).__init__(
            num_timesteps=num_timesteps,
            transition_matrix=additive_transition_matrix,
            transition_noise=additive_transition_noise,
            observation_matrix=additive_observation_matrix,
            observation_noise=additive_observation_noise,
            initial_state_prior=initial_state_prior,
            initial_step=initial_step,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)