Esempio n. 1
0
    def test_factored_joint_mvn_diag_full(self):
        """Joins a diagonal MVN and a full-covariance MVN into one factored MVN.

        Checks that the joint event size is the sum of the component event
        sizes, that each component's mean and covariance occupy their own
        block of the joint mean/covariance, and that the cross-covariance
        blocks between the two components are exactly zero.
        """
        batch_shape = [3, 2]

        mvn1 = tfd.MultivariateNormalDiag(loc=tf.zeros(batch_shape + [3]),
                                          scale_diag=tf.ones(batch_shape +
                                                             [3]))

        mvn2 = tfd.MultivariateNormalFullCovariance(
            loc=tf.ones(batch_shape + [2]),
            covariance_matrix=(tf.ones(batch_shape + [2, 2]) *
                               [[5., -2], [-2, 3.1]]))

        joint = sts_util.factored_joint_mvn([mvn1, mvn2])
        # Event shapes evaluate to arrays. Compare them element-wise instead of
        # relying on the truthiness of a one-element `==` result.
        self.assertAllEqual(
            self.evaluate(joint.event_shape_tensor()),
            self.evaluate(mvn1.event_shape_tensor() +
                          mvn2.event_shape_tensor()))

        joint_mean_ = self.evaluate(joint.mean())
        self.assertAllEqual(joint_mean_[..., :3], self.evaluate(mvn1.mean()))
        self.assertAllEqual(joint_mean_[..., 3:], self.evaluate(mvn2.mean()))

        joint_cov_ = self.evaluate(joint.covariance())
        self.assertAllEqual(joint_cov_[..., :3, :3],
                            self.evaluate(mvn1.covariance()))
        self.assertAllEqual(joint_cov_[..., 3:, 3:],
                            self.evaluate(mvn2.covariance()))
        # A factored joint treats its components as independent, so the
        # off-diagonal (cross-covariance) blocks must be exactly zero.
        self.assertAllEqual(joint_cov_[..., :3, 3:],
                            self.evaluate(tf.zeros(batch_shape + [3, 2])))
        self.assertAllEqual(joint_cov_[..., 3:, :3],
                            self.evaluate(tf.zeros(batch_shape + [2, 3])))
Esempio n. 2
0
  def test_factored_joint_mvn_diag_full(self):
    """Joins a diagonal MVN and a full-covariance MVN into one factored MVN.

    Checks that the joint event size is the sum of the component event sizes,
    that each component's mean and covariance occupy their own block of the
    joint mean/covariance, and that the cross-covariance blocks between the
    two components are exactly zero.
    """
    batch_shape = [3, 2]

    mvn1 = tfd.MultivariateNormalDiag(
        loc=tf.zeros(batch_shape + [3]),
        scale_diag=tf.ones(batch_shape + [3]))

    mvn2 = tfd.MultivariateNormalFullCovariance(
        loc=tf.ones(batch_shape + [2]),
        covariance_matrix=(tf.ones(batch_shape + [2, 2]) *
                           [[5., -2], [-2, 3.1]]))

    joint = sts_util.factored_joint_mvn([mvn1, mvn2])
    # Event shapes evaluate to arrays. Compare them element-wise instead of
    # relying on the truthiness of a one-element `==` result.
    self.assertAllEqual(self.evaluate(joint.event_shape_tensor()),
                        self.evaluate(mvn1.event_shape_tensor() +
                                      mvn2.event_shape_tensor()))

    joint_mean_ = self.evaluate(joint.mean())
    self.assertAllEqual(joint_mean_[..., :3], self.evaluate(mvn1.mean()))
    self.assertAllEqual(joint_mean_[..., 3:], self.evaluate(mvn2.mean()))

    joint_cov_ = self.evaluate(joint.covariance())
    self.assertAllEqual(joint_cov_[..., :3, :3],
                        self.evaluate(mvn1.covariance()))
    self.assertAllEqual(joint_cov_[..., 3:, 3:],
                        self.evaluate(mvn2.covariance()))
    # A factored joint treats its components as independent, so the
    # off-diagonal (cross-covariance) blocks must be exactly zero.
    self.assertAllEqual(joint_cov_[..., :3, 3:],
                        self.evaluate(tf.zeros(batch_shape + [3, 2])))
    self.assertAllEqual(joint_cov_[..., 3:, :3],
                        self.evaluate(tf.zeros(batch_shape + [2, 3])))
Esempio n. 3
0
    def test_factored_joint_mvn_broadcast_batch_shape(self):
        """Joining MVNs with broadcast-compatible batch shapes broadcasts them.

        Three MVNs with event size 3 and batch shapes [2], [3, 2] and [1, 2]
        are joined. The result must have batch shape [3, 2], and each
        component's mean and covariance, broadcast to that batch shape, must
        appear in that component's own event block.
        """
        def random_with_shape(shape):
            return np.random.standard_normal(shape).astype(np.float32)

        event_shape = [3]
        event_size = event_shape[0]

        # Batch shape [2].
        mvn1 = tfd.MultivariateNormalDiag(
            loc=random_with_shape([2] + event_shape),
            scale_diag=tf.exp(random_with_shape([2] + event_shape)))

        # Batch shape [3, 2] (loc [3, 2] broadcast against scale [1, 2]).
        mvn2 = tfd.MultivariateNormalDiag(
            loc=random_with_shape([3, 2] + event_shape),
            scale_diag=tf.exp(random_with_shape([1, 2] + event_shape)))

        # Batch shape [1, 2] (loc [1, 2] broadcast against scale [2]).
        mvn3 = tfd.MultivariateNormalDiag(
            loc=random_with_shape([1, 2] + event_shape),
            scale_diag=tf.exp(random_with_shape([2] + event_shape)))

        components = [mvn1, mvn2, mvn3]
        joint = sts_util.factored_joint_mvn(components)
        self.assertAllEqual(self.evaluate(joint.batch_shape_tensor()), [3, 2])

        # Event-axis slice occupied by the i-th component in the joint.
        blocks = [slice(i * event_size, (i + 1) * event_size)
                  for i in range(len(components))]

        joint_mean_ = self.evaluate(joint.mean())
        # Ones of the joint batch shape lift each component's mean to [3, 2].
        mean_lifter = tf.ones_like(joint.mean()[..., 0:1])
        for block, component in zip(blocks, components):
            self.assertAllEqual(joint_mean_[..., block],
                                self.evaluate(mean_lifter * component.mean()))

        joint_cov_ = self.evaluate(joint.covariance())
        cov_lifter = tf.ones_like(joint.covariance()[..., :1, :1])
        for block, component in zip(blocks, components):
            self.assertAllEqual(
                joint_cov_[..., block, block],
                self.evaluate(cov_lifter * component.covariance()))
Esempio n. 4
0
  def test_factored_joint_mvn_broadcast_batch_shape(self):
    """Joining MVNs with broadcast-compatible batch shapes broadcasts them.

    Three MVNs with event size 3 and batch shapes [2], [3, 2] and [1, 2] are
    joined. The result must have batch shape [3, 2], and each component's
    mean and covariance, broadcast to that batch shape, must appear in that
    component's own event block.
    """
    def random_with_shape(shape):
      return np.random.standard_normal(shape).astype(np.float32)

    event_shape = [3]
    event_size = event_shape[0]

    # Batch shape [2].
    mvn1 = tfd.MultivariateNormalDiag(
        loc=random_with_shape([2] + event_shape),
        scale_diag=tf.exp(random_with_shape([2] + event_shape)))

    # Batch shape [3, 2] (loc [3, 2] broadcast against scale [1, 2]).
    mvn2 = tfd.MultivariateNormalDiag(
        loc=random_with_shape([3, 2] + event_shape),
        scale_diag=tf.exp(random_with_shape([1, 2] + event_shape)))

    # Batch shape [1, 2] (loc [1, 2] broadcast against scale [2]).
    mvn3 = tfd.MultivariateNormalDiag(
        loc=random_with_shape([1, 2] + event_shape),
        scale_diag=tf.exp(random_with_shape([2] + event_shape)))

    components = [mvn1, mvn2, mvn3]
    joint = sts_util.factored_joint_mvn(components)
    self.assertAllEqual(self.evaluate(joint.batch_shape_tensor()), [3, 2])

    # Event-axis slice occupied by the i-th component in the joint.
    blocks = [slice(i * event_size, (i + 1) * event_size)
              for i in range(len(components))]

    joint_mean_ = self.evaluate(joint.mean())
    # Ones of the joint batch shape lift each component's mean to [3, 2].
    mean_lifter = tf.ones_like(joint.mean()[..., 0:1])
    for block, component in zip(blocks, components):
      self.assertAllEqual(joint_mean_[..., block],
                          self.evaluate(mean_lifter * component.mean()))

    joint_cov_ = self.evaluate(joint.covariance())
    cov_lifter = tf.ones_like(joint.covariance()[..., :1, :1])
    for block, component in zip(blocks, components):
      self.assertAllEqual(joint_cov_[..., block, block],
                          self.evaluate(cov_lifter * component.covariance()))
Esempio n. 5
0
 def transition_noise_fn(t):
     """Returns the joint transition noise of all components at step `t`.

     Each component's transition-noise distribution for `t` is collected in
     component order and passed to `sts_util.factored_joint_mvn`.
     """
     per_component_noise = [
         ssm.get_transition_noise_for_timestep(t) for ssm in component_ssms
     ]
     return sts_util.factored_joint_mvn(per_component_noise)
Esempio n. 6
0
    def __init__(self,
                 component_ssms,
                 constant_offset=0.,
                 observation_noise_scale=None,
                 initial_state_prior=None,
                 initial_step=0,
                 validate_args=False,
                 name=None,
                 **linear_gaussian_ssm_kwargs):
        """Build a state space model representing the sum of component models.

        Args:
          component_ssms: Python `list` containing one or more
            `tfd.LinearGaussianStateSpaceModel` instances. The components
            will in general implement different time-series models, with
            possibly different `latent_size`, but they must have the same
            `dtype`, event shape (`num_timesteps` and `observation_size`), and
            their batch shapes must broadcast to a compatible batch shape.
          constant_offset: `float` `Tensor` of shape broadcasting to
            `concat([batch_shape, [num_timesteps]]`) specifying a constant
            value added to the sum of outputs from the component models. This
            allows the components to model the shifted series
            `observed_time_series - constant_offset`.
            Default value: `0.`
          observation_noise_scale: Optional scalar `float` `Tensor` indicating
            the standard deviation of the observation noise. May contain
            additional batch dimensions, which must broadcast with the batch
            shape of elements in `component_ssms`. If
            `observation_noise_scale` is specified for the
            `AdditiveStateSpaceModel`, the observation noise scales of
            component models are ignored. If `None`, the observation noise
            scale is derived by summing the noise variances of the component
            models, i.e., `observation_noise_scale = sqrt(sum(
            [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
          initial_state_prior: Optional instance of `tfd.MultivariateNormal`
            representing a prior distribution on the latent state at time
            `initial_step`. If `None`, defaults to the independent priors from
            component models, i.e.,
            `[component.initial_state_prior for component in component_ssms]`.
            Default value: `None`.
          initial_step: Optional scalar `int` `Tensor` specifying the starting
            timestep.
            Default value: 0.
          validate_args: Python `bool`. Whether to validate input
            with asserts. If `validate_args` is `False`, and the inputs are
            invalid, correct behavior is not guaranteed.
            Default value: `False`.
          name: Python `str` name prefixed to ops created by this class.
            Default value: "AdditiveStateSpaceModel".
          **linear_gaussian_ssm_kwargs: Optional additional keyword arguments
            passed to the base `tfd.LinearGaussianStateSpaceModel`
            constructor.

        Raises:
          ValueError: if components have different `num_timesteps`.
        """
        # Snapshot the constructor arguments before any other local is bound.
        # This must remain the first statement so `parameters` holds only the
        # arguments (and `self`), with the extra kwargs flattened in place of
        # the `linear_gaussian_ssm_kwargs` entry.
        parameters = dict(locals())
        parameters.update(linear_gaussian_ssm_kwargs)
        del parameters['linear_gaussian_ssm_kwargs']
        with tf.name_scope(name or 'AdditiveStateSpaceModel') as name:
            # Check that all components have the same dtype
            dtype = tf.debugging.assert_same_float_dtype(component_ssms)

            # Convert scalar offsets to canonical shape `[..., num_timesteps]`.
            # Multiplying by `ones([1])` promotes a scalar offset to shape
            # `[1]`, so the `[-1]` lookup below always finds a trailing axis.
            constant_offset = (tf.convert_to_tensor(
                value=constant_offset, name='constant_offset', dtype=dtype) *
                               tf.ones([1], dtype=dtype))
            offset_length = prefer_static.shape(constant_offset)[-1]
            assertions = []

            # Construct an initial state prior as a block-diagonal combination
            # of the component state priors.
            if initial_state_prior is None:
                initial_state_prior = sts_util.factored_joint_mvn(
                    [ssm.initial_state_prior for ssm in component_ssms])
            # From here on, `dtype` follows the (possibly user-supplied) prior.
            dtype = initial_state_prior.dtype

            static_num_timesteps = [
                tf.get_static_value(ssm.num_timesteps)
                for ssm in component_ssms
                if tf.get_static_value(ssm.num_timesteps) is not None
            ]

            # If any components have a static value for `num_timesteps`, use that
            # value for the additive model. (and check that all other static values
            # match it).
            if static_num_timesteps:
                num_timesteps = static_num_timesteps[0]
                if not all([
                        component_timesteps == num_timesteps
                        for component_timesteps in static_num_timesteps
                ]):
                    raise ValueError(
                        'Additive model components must all have the same '
                        'number of timesteps '
                        '(saw: {})'.format(static_num_timesteps))
            else:
                num_timesteps = component_ssms[0].num_timesteps
            # Components whose `num_timesteps` is not static escaped the check
            # above; with `validate_args`, compare every component at runtime.
            if validate_args and len(static_num_timesteps) != len(
                    component_ssms):
                assertions += [
                    tf.debugging.assert_equal(  # pylint: disable=g-complex-comprehension
                        num_timesteps,
                        ssm.num_timesteps,
                        message='Additive model components must all have '
                        'the same number of timesteps.')
                    for ssm in component_ssms
                ]

            # Define the transition and observation models for the additive SSM.
            # See the "mathematical details" section of the class docstring for
            # further information. Note that we define these as callables to
            # handle the fully general case in which some components have time-
            # varying dynamics.
            # Transition matrix: block-diagonal of the component matrices at `t`.
            def transition_matrix_fn(t):
                return tfl.LinearOperatorBlockDiag([
                    ssm.get_transition_matrix_for_timestep(t)
                    for ssm in component_ssms
                ])

            # Transition noise: component noises at `t` joined with
            # `sts_util.factored_joint_mvn`.
            def transition_noise_fn(t):
                return sts_util.factored_joint_mvn([
                    ssm.get_transition_noise_for_timestep(t)
                    for ssm in component_ssms
                ])

            # Build the observation matrix, concatenating (broadcast) observation
            # matrices from components. We also take this as an opportunity to enforce
            # any dynamic assertions we may have generated above.
            # The batch shape is broadcast from the components' observation
            # matrices evaluated at `initial_step`.
            broadcast_batch_shape = tf.convert_to_tensor(
                value=sts_util.broadcast_batch_shape([
                    ssm.get_observation_matrix_for_timestep(initial_step)
                    for ssm in component_ssms
                ]),
                dtype=tf.int32)
            # Ones of shape `batch_shape + [1, 1]`; multiplying by it lifts every
            # component's observation matrix to the common batch shape so the
            # matrices can be concatenated.
            broadcast_obs_matrix = tf.ones(tf.concat(
                [broadcast_batch_shape, [1, 1]], axis=0),
                                           dtype=dtype)
            if assertions:
                with tf.control_dependencies(assertions):
                    broadcast_obs_matrix = tf.identity(broadcast_obs_matrix)

            # Observation matrix: broadcast component matrices at `t`,
            # concatenated along the last axis.
            def observation_matrix_fn(t):
                return tfl.LinearOperatorFullMatrix(
                    tf.concat([
                        ssm.get_observation_matrix_for_timestep(t).to_dense() *
                        broadcast_obs_matrix for ssm in component_ssms
                    ],
                              axis=-1))

            # Broadcast the constant offset across timesteps.
            # A length-1 offset is reused at every step. Otherwise step `t` reads
            # index `min(t, offset_length - 1)`, so steps past the end reuse the
            # last entry. The trailing `tf.newaxis` makes the gathered value end
            # in a size-1 axis, matching the length-1 branch.
            # NOTE(review): `t` is used directly as an index; confirm whether
            # callers pass absolute steps or steps relative to `initial_step`.
            offset_at_step = lambda t: (  # pylint: disable=g-long-lambda
                constant_offset if offset_length == 1 else tf.gather(
                    constant_offset, tf.minimum(t, offset_length - 1), axis=-1)
                [..., tf.newaxis])

            if observation_noise_scale is not None:
                observation_noise_scale = tf.convert_to_tensor(
                    value=observation_noise_scale,
                    name='observation_noise_scale',
                    dtype=dtype)

                # Explicit scale: keep only the components' noise means (plus
                # the offset) and replace their scales with the given one.
                def observation_noise_fn(t):
                    return tfd.MultivariateNormalDiag(
                        loc=(sum([
                            ssm.get_observation_noise_for_timestep(t).mean()
                            for ssm in component_ssms
                        ]) + offset_at_step(t)),
                        scale_diag=observation_noise_scale[..., tf.newaxis])
            else:

                # No explicit scale: sum the component noise MVNs. The offset is
                # included as a zero-scale MVN centered at the offset.
                def observation_noise_fn(t):
                    offset = offset_at_step(t)
                    return sts_util.sum_mvns([
                        tfd.MultivariateNormalDiag(
                            loc=offset, scale_diag=tf.zeros_like(offset))
                    ] + [
                        ssm.get_observation_noise_for_timestep(t)
                        for ssm in component_ssms
                    ])

            super(AdditiveStateSpaceModel,
                  self).__init__(num_timesteps=num_timesteps,
                                 transition_matrix=transition_matrix_fn,
                                 transition_noise=transition_noise_fn,
                                 observation_matrix=observation_matrix_fn,
                                 observation_noise=observation_noise_fn,
                                 initial_state_prior=initial_state_prior,
                                 initial_step=initial_step,
                                 validate_args=validate_args,
                                 name=name,
                                 **linear_gaussian_ssm_kwargs)
            # Overwrite whatever the base constructor recorded with the
            # snapshot taken at the top of this method.
            self._parameters = parameters
Esempio n. 7
0
  def __init__(self,
               component_ssms,
               observation_noise_scale=None,
               initial_state_prior=None,
               initial_step=0,
               validate_args=False,
               allow_nan_stats=True,
               name=None):
    """Build a state space model representing the sum of component models.

    Args:
      component_ssms: Python `list` containing one or more
        `tfd.LinearGaussianStateSpaceModel` instances. The components
        will in general implement different time-series models, with possibly
        different `latent_size`, but they must have the same `dtype`, event
        shape (`num_timesteps` and `observation_size`), and their batch shapes
        must broadcast to a compatible batch shape.
      observation_noise_scale: Optional scalar `float` `Tensor` indicating the
        standard deviation of the observation noise. May contain additional
        batch dimensions, which must broadcast with the batch shape of elements
        in `component_ssms`. If `observation_noise_scale` is specified for the
        `AdditiveStateSpaceModel`, the observation noise scales of component
        models are ignored. If `None`, the observation noise scale is derived
        by summing the noise variances of the component models, i.e.,
        `observation_noise_scale = sqrt(sum(
        [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
      initial_state_prior: Optional instance of `tfd.MultivariateNormal`
        representing a prior distribution on the latent state at time
        `initial_step`. If `None`, defaults to the independent priors from
        component models, i.e.,
        `[component.initial_state_prior for component in component_ssms]`.
        Default value: `None`.
      initial_step: Optional scalar `int` `Tensor` specifying the starting
        timestep.
        Default value: 0.
      validate_args: Python `bool`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
        Default value: `False`.
      allow_nan_stats: Python `bool`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
        Default value: `True`.
      name: Python `str` name prefixed to ops created by this class.
        Default value: "AdditiveStateSpaceModel".

    Raises:
      ValueError: if components have different `num_timesteps`.
    """

    with tf.compat.v1.name_scope(
        name,
        'AdditiveStateSpaceModel',
        values=[observation_noise_scale, initial_step]) as name:

      # Runtime assertions collected below; attached to the observation
      # matrix via control dependencies.
      assertions = []

      # Check that all components have the same dtype
      # (the returned dtype is discarded; `dtype` is taken from the initial
      # state prior below).
      tf.debugging.assert_same_float_dtype(component_ssms)

      # Construct an initial state prior as a block-diagonal combination
      # of the component state priors.
      if initial_state_prior is None:
        initial_state_prior = sts_util.factored_joint_mvn(
            [ssm.initial_state_prior for ssm in component_ssms])
      dtype = initial_state_prior.dtype

      static_num_timesteps = [
          tf.get_static_value(ssm.num_timesteps)
          for ssm in component_ssms
          if tf.get_static_value(ssm.num_timesteps) is not None
      ]

      # If any components have a static value for `num_timesteps`, use that
      # value for the additive model. (and check that all other static values
      # match it).
      if static_num_timesteps:
        num_timesteps = static_num_timesteps[0]
        if not all([component_timesteps == num_timesteps
                    for component_timesteps in static_num_timesteps]):
          raise ValueError('Additive model components must all have the same '
                           'number of timesteps '
                           '(saw: {})'.format(static_num_timesteps))
      else:
        num_timesteps = component_ssms[0].num_timesteps
      # Components whose `num_timesteps` is not static escaped the check
      # above; with `validate_args`, compare every component at runtime.
      if validate_args and len(static_num_timesteps) != len(component_ssms):
        assertions += [
            tf.compat.v1.assert_equal(
                num_timesteps,
                ssm.num_timesteps,
                message='Additive model components must all have '
                'the same number of timesteps.') for ssm in component_ssms
        ]

      # Define the transition and observation models for the additive SSM.
      # See the "mathematical details" section of the class docstring for
      # further information. Note that we define these as callables to
      # handle the fully general case in which some components have time-
      # varying dynamics.
      # Transition matrix: block-diagonal of the component matrices at `t`.
      def transition_matrix_fn(t):
        return tfl.LinearOperatorBlockDiag(
            [ssm.get_transition_matrix_for_timestep(t)
             for ssm in component_ssms])

      # Transition noise: component noises at `t` joined with
      # `sts_util.factored_joint_mvn`.
      def transition_noise_fn(t):
        return sts_util.factored_joint_mvn(
            [ssm.get_transition_noise_for_timestep(t)
             for ssm in component_ssms])

      # Build the observation matrix, concatenating (broadcast) observation
      # matrices from components. We also take this as an opportunity to enforce
      # any dynamic assertions we may have generated above.
      broadcast_batch_shape = tf.convert_to_tensor(
          value=sts_util.broadcast_batch_shape(component_ssms), dtype=tf.int32)
      # Ones of shape `batch_shape + [1, 1]`; multiplying by it lifts every
      # component's observation matrix to the common batch shape.
      broadcast_obs_matrix = tf.ones(
          tf.concat([broadcast_batch_shape, [1, 1]], axis=0), dtype=dtype)
      if assertions:
        with tf.control_dependencies(assertions):
          broadcast_obs_matrix = tf.identity(broadcast_obs_matrix)

      # Observation matrix: broadcast component matrices at `t`, concatenated
      # along the last axis.
      def observation_matrix_fn(t):
        return tfl.LinearOperatorFullMatrix(
            tf.concat([ssm.get_observation_matrix_for_timestep(t).to_dense() *
                       broadcast_obs_matrix for ssm in component_ssms],
                      axis=-1))

      if observation_noise_scale is not None:
        observation_noise_scale = tf.convert_to_tensor(
            value=observation_noise_scale,
            name='observation_noise_scale',
            dtype=dtype)
        # Explicit scale: keep only the components' noise means and replace
        # their scales with the given one.
        def observation_noise_fn(t):
          return tfd.MultivariateNormalDiag(
              loc=sum([ssm.get_observation_noise_for_timestep(t).mean()
                       for ssm in component_ssms]),
              scale_diag=observation_noise_scale[..., tf.newaxis])
      else:
        # No explicit scale: combine the component noise MVNs with
        # `sts_util.sum_mvns`.
        def observation_noise_fn(t):
          return sts_util.sum_mvns(
              [ssm.get_observation_noise_for_timestep(t)
               for ssm in component_ssms])

      super(AdditiveStateSpaceModel, self).__init__(
          num_timesteps=num_timesteps,
          transition_matrix=transition_matrix_fn,
          transition_noise=transition_noise_fn,
          observation_matrix=observation_matrix_fn,
          observation_noise=observation_noise_fn,
          initial_state_prior=initial_state_prior,
          initial_step=initial_step,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
Esempio n. 8
0
 def transition_noise_fn(t):
   """Returns the joint transition noise of all components at step `t`.

   Each component's transition-noise distribution for `t` is collected in
   component order and passed to `sts_util.factored_joint_mvn`.
   """
   noise_by_component = [
       ssm.get_transition_noise_for_timestep(t) for ssm in component_ssms
   ]
   return sts_util.factored_joint_mvn(noise_by_component)
Esempio n. 9
0
  def __init__(self,
               component_ssms,
               observation_noise_scale=None,
               initial_state_prior=None,
               initial_step=0,
               validate_args=False,
               allow_nan_stats=True,
               name=None):
    """Build a state space model representing the sum of component models.

    Args:
      component_ssms: Python `list` containing one or more
        `tfd.LinearGaussianStateSpaceModel` instances. The components
        will in general implement different time-series models, with possibly
        different `latent_size`, but they must have the same `dtype`, event
        shape (`num_timesteps` and `observation_size`), and their batch shapes
        must broadcast to a compatible batch shape.
      observation_noise_scale: Optional scalar `float` `Tensor` indicating the
        standard deviation of the observation noise. May contain additional
        batch dimensions, which must broadcast with the batch shape of elements
        in `component_ssms`. If `observation_noise_scale` is specified for the
        `AdditiveStateSpaceModel`, the observation noise scales of component
        models are ignored. If `None`, the observation noise scale is derived
        by summing the noise variances of the component models, i.e.,
        `observation_noise_scale = sqrt(sum(
        [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
      initial_state_prior: Optional instance of `tfd.MultivariateNormal`
        representing a prior distribution on the latent state at time
        `initial_step`. If `None`, defaults to the independent priors from
        component models, i.e.,
        `[component.initial_state_prior for component in component_ssms]`.
        Default value: `None`.
      initial_step: Optional scalar `int` `Tensor` specifying the starting
        timestep.
        Default value: 0.
      validate_args: Python `bool`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
        Default value: `False`.
      allow_nan_stats: Python `bool`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
        Default value: `True`.
      name: Python `str` name prefixed to ops created by this class.
        Default value: "AdditiveStateSpaceModel".

    Raises:
      ValueError: if components have different `num_timesteps`.
    """

    # NOTE(review): this uses the TF1-style `name_scope(name, default_name,
    # values)` signature; confirm the TensorFlow version this targets.
    with tf.name_scope(name, 'AdditiveStateSpaceModel',
                       values=[observation_noise_scale, initial_step]) as name:

      # Runtime assertions collected below; attached to the observation
      # matrix via control dependencies.
      assertions = []

      # Check that all components have the same dtype
      # (the returned dtype is discarded; `dtype` is taken from the initial
      # state prior below).
      tf.assert_same_float_dtype(component_ssms)

      # Construct an initial state prior as a block-diagonal combination
      # of the component state priors.
      if initial_state_prior is None:
        initial_state_prior = sts_util.factored_joint_mvn(
            [ssm.initial_state_prior for ssm in component_ssms])
      dtype = initial_state_prior.dtype

      static_num_timesteps = [
          distribution_util.static_value(ssm.num_timesteps)
          for ssm in component_ssms
          if distribution_util.static_value(ssm.num_timesteps) is not None
      ]

      # If any components have a static value for `num_timesteps`, use that
      # value for the additive model. (and check that all other static values
      # match it).
      if static_num_timesteps:
        num_timesteps = static_num_timesteps[0]
        if not all([component_timesteps == num_timesteps
                    for component_timesteps in static_num_timesteps]):
          raise ValueError('Additive model components must all have the same '
                           'number of timesteps '
                           '(saw: {})'.format(static_num_timesteps))
      else:
        num_timesteps = component_ssms[0].num_timesteps
      # Components whose `num_timesteps` is not static escaped the check
      # above; with `validate_args`, compare every component at runtime.
      if validate_args and len(static_num_timesteps) != len(component_ssms):
        assertions += [
            tf.assert_equal(num_timesteps,
                            ssm.num_timesteps,
                            message='Additive model components must all have '
                            'the same number of timesteps.')
            for ssm in component_ssms]

      # Define the transition and observation models for the additive SSM.
      # See the "mathematical details" section of the class docstring for
      # further information. Note that we define these as callables to
      # handle the fully general case in which some components have time-
      # varying dynamics.
      # Transition matrix: block-diagonal of the component matrices at `t`.
      def transition_matrix_fn(t):
        return tfl.LinearOperatorBlockDiag(
            [ssm.get_transition_matrix_for_timestep(t)
             for ssm in component_ssms])

      # Transition noise: component noises at `t` joined with
      # `sts_util.factored_joint_mvn`.
      def transition_noise_fn(t):
        return sts_util.factored_joint_mvn(
            [ssm.get_transition_noise_for_timestep(t)
             for ssm in component_ssms])

      # Build the observation matrix, concatenating (broadcast) observation
      # matrices from components. We also take this as an opportunity to enforce
      # any dynamic assertions we may have generated above.
      broadcast_batch_shape = tf.convert_to_tensor(
          sts_util.broadcast_batch_shape(component_ssms), dtype=tf.int32)
      # Ones of shape `batch_shape + [1, 1]`; multiplying by it lifts every
      # component's observation matrix to the common batch shape.
      broadcast_obs_matrix = tf.ones(
          tf.concat([broadcast_batch_shape, [1, 1]], axis=0), dtype=dtype)
      if assertions:
        with tf.control_dependencies(assertions):
          broadcast_obs_matrix = tf.identity(broadcast_obs_matrix)

      # Observation matrix: broadcast component matrices at `t`, concatenated
      # along the last axis.
      def observation_matrix_fn(t):
        return tfl.LinearOperatorFullMatrix(
            tf.concat([ssm.get_observation_matrix_for_timestep(t).to_dense() *
                       broadcast_obs_matrix for ssm in component_ssms],
                      axis=-1))

      if observation_noise_scale is not None:
        observation_noise_scale = tf.convert_to_tensor(
            observation_noise_scale,
            name='observation_noise_scale',
            dtype=dtype)
        # Explicit scale: keep only the components' noise means and replace
        # their scales with the given one.
        def observation_noise_fn(t):
          return tfd.MultivariateNormalDiag(
              loc=sum([ssm.get_observation_noise_for_timestep(t).mean()
                       for ssm in component_ssms]),
              scale_diag=observation_noise_scale[..., tf.newaxis])
      else:
        # No explicit scale: combine the component noise MVNs with
        # `sts_util.sum_mvns`.
        def observation_noise_fn(t):
          return sts_util.sum_mvns(
              [ssm.get_observation_noise_for_timestep(t)
               for ssm in component_ssms])

      super(AdditiveStateSpaceModel, self).__init__(
          num_timesteps=num_timesteps,
          transition_matrix=transition_matrix_fn,
          transition_noise=transition_noise_fn,
          observation_matrix=observation_matrix_fn,
          observation_noise=observation_noise_fn,
          initial_state_prior=initial_state_prior,
          initial_step=initial_step,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
def _pad_mvn_with_trailing_zeros(mvn, num_zeros):
  """Appends `num_zeros` zero-mean, zero-scale dimensions to an MVN's event.

  Args:
    mvn: the multivariate normal distribution to pad.
    num_zeros: number of trailing event dimensions to append.

  Returns:
    The result of `sts_util.factored_joint_mvn` joining `mvn` with a
    `tfd.MultivariateNormalDiag` whose `loc` and `scale_diag` are both zero
    vectors of length `num_zeros` in `mvn.dtype`.
  """
  zero_vector = tf.zeros([num_zeros], dtype=mvn.dtype)
  # The same zero vector serves as both mean and scale: the padding
  # dimensions are fixed at zero with no spread.
  padding_mvn = tfd.MultivariateNormalDiag(
      loc=zero_vector, scale_diag=zero_vector)
  return sts_util.factored_joint_mvn([mvn, padding_mvn])