Example #1
  def test_broadcast_batch_shape_static(self):

    batch_shapes = ([2], [3, 2], [1, 2])
    distributions = [
        tfd.Normal(loc=tf.zeros(batch_shape), scale=tf.ones(batch_shape))
        for batch_shape in batch_shapes
    ]
    self.assertEqual(sts_util.broadcast_batch_shape(distributions), [3, 2])
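
The helper under test combines the components' batch shapes under ordinary broadcasting rules, so [2], [3, 2], and [1, 2] broadcast to [3, 2]. A minimal sketch of the static case, assuming only TensorFlow's public shape utilities (illustrative, not the library's actual implementation):

import functools
import tensorflow as tf

def broadcast_batch_shape_sketch(distributions):
    # Reduce the distributions' static batch shapes pairwise under
    # broadcasting rules; returns a tf.TensorShape.
    return functools.reduce(
        tf.broadcast_static_shape,
        [d.batch_shape for d in distributions])
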
Example #2
    def test_broadcast_batch_shape_static(self):

        batch_shapes = ([2], [3, 2], [1, 2])
        distributions = [
            tfd.Normal(loc=tf.zeros(batch_shape), scale=tf.ones(batch_shape))
            for batch_shape in batch_shapes
        ]
        self.assertEqual(sts_util.broadcast_batch_shape(distributions), [3, 2])
Example #3
 def test_broadcast_batch_shape(self):
     batch_shapes = ([2], [3, 2], [1, 2])
     distributions = [
         tfd.Normal(loc=self._build_tensor(np.zeros(batch_shape)),
                    scale=self._build_tensor(np.ones(batch_shape)))
         for batch_shape in batch_shapes
     ]
     if self.use_static_shape:
         self.assertEqual([3, 2],
                          sts_util.broadcast_batch_shape(distributions))
     else:
         broadcast_batch_shape = sts_util.broadcast_batch_shape(
             distributions)
         # Broadcast shape in Eager can contain Python `int`s, so we need to
         # explicitly convert to Tensor.
         self.assertAllEqual(
             [3, 2],
             self.evaluate(
                 tf.convert_to_tensor(value=broadcast_batch_shape)))
Example #4
  def test_broadcast_batch_shape_dynamic(self):
    # Run in graph mode only, since eager mode always takes the static path
    if tf.executing_eagerly(): return

    batch_shapes = ([2], [3, 2], [1, 2])
    distributions = [tfd.Normal(
        loc=tf.placeholder_with_default(
            input=tf.zeros(batch_shape), shape=None),
        scale=tf.placeholder_with_default(
            input=tf.ones(batch_shape), shape=None))
                     for batch_shape in batch_shapes]

    self.assertAllEqual(self.evaluate(
        sts_util.broadcast_batch_shape(distributions)),
                        [3, 2])
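
With placeholders of unknown shape, the batch shapes are only available at graph-run time, so the result is a shape Tensor that must be evaluated, as the assertion above does. A rough sketch of that dynamic fallback, assuming tf.broadcast_dynamic_shape (illustrative only, not the library's implementation):

import functools
import tensorflow as tf

def broadcast_batch_shape_dynamic_sketch(distributions):
    # Combine runtime batch-shape Tensors pairwise; returns an int32 Tensor.
    return functools.reduce(
        tf.broadcast_dynamic_shape,
        [d.batch_shape_tensor() for d in distributions])
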
Example #5
    def test_broadcast_batch_shape_dynamic(self):
        # Run in graph mode only, since eager mode always takes the static path
        if tf.executing_eagerly(): return

        batch_shapes = ([2], [3, 2], [1, 2])
        distributions = [
            tfd.Normal(loc=tf.compat.v1.placeholder_with_default(
                input=tf.zeros(batch_shape), shape=None),
                       scale=tf.compat.v1.placeholder_with_default(
                           input=tf.ones(batch_shape), shape=None))
            for batch_shape in batch_shapes
        ]

        self.assertAllEqual([3, 2],
                            self.evaluate(
                                sts_util.broadcast_batch_shape(distributions)))
Example #6
    def __init__(self,
                 component_ssms,
                 constant_offset=0.,
                 observation_noise_scale=None,
                 initial_state_prior=None,
                 initial_step=0,
                 validate_args=False,
                 name=None,
                 **linear_gaussian_ssm_kwargs):
        """Build a state space model representing the sum of component models.

    Args:
      component_ssms: Python `list` containing one or more
        `tfd.LinearGaussianStateSpaceModel` instances. The components
        will in general implement different time-series models, with possibly
        different `latent_size`, but they must have the same `dtype`, event
        shape (`num_timesteps` and `observation_size`), and their batch shapes
        must broadcast to a compatible batch shape.
      constant_offset: `float` `Tensor` of shape broadcasting to
        `concat([batch_shape, [num_timesteps]])` specifying a constant value
        added to the sum of outputs from the component models. This allows the
        components to model the shifted series
        `observed_time_series - constant_offset`.
        Default value: `0.`
      observation_noise_scale: Optional scalar `float` `Tensor` indicating the
        standard deviation of the observation noise. May contain additional
        batch dimensions, which must broadcast with the batch shape of elements
        in `component_ssms`. If `observation_noise_scale` is specified for the
        `AdditiveStateSpaceModel`, the observation noise scales of component
        models are ignored. If `None`, the observation noise scale is derived
        by summing the noise variances of the component models, i.e.,
        `observation_noise_scale = sqrt(sum(
        [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
      initial_state_prior: Optional instance of `tfd.MultivariateNormal`
        representing a prior distribution on the latent state at time
        `initial_step`. If `None`, defaults to the independent priors from
        component models, i.e.,
        `[component.initial_state_prior for component in component_ssms]`.
        Default value: `None`.
      initial_step: Optional scalar `int` `Tensor` specifying the starting
        timestep.
        Default value: 0.
      validate_args: Python `bool`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
        Default value: `False`.
      name: Python `str` name prefixed to ops created by this class.
        Default value: "AdditiveStateSpaceModel".
      **linear_gaussian_ssm_kwargs: Optional additional keyword arguments
        passed to the base `tfd.LinearGaussianStateSpaceModel` constructor.

    Raises:
      ValueError: if components have different `num_timesteps`.
    """
        parameters = dict(locals())
        parameters.update(linear_gaussian_ssm_kwargs)
        del parameters['linear_gaussian_ssm_kwargs']
        with tf.name_scope(name or 'AdditiveStateSpaceModel') as name:
            # Check that all components have the same dtype
            dtype = tf.debugging.assert_same_float_dtype(component_ssms)

            # Convert scalar offsets to canonical shape `[..., num_timesteps]`.
            constant_offset = (tf.convert_to_tensor(
                value=constant_offset, name='constant_offset', dtype=dtype) *
                               tf.ones([1], dtype=dtype))
            offset_length = prefer_static.shape(constant_offset)[-1]
            assertions = []

            # Construct an initial state prior as a block-diagonal combination
            # of the component state priors.
            if initial_state_prior is None:
                initial_state_prior = sts_util.factored_joint_mvn(
                    [ssm.initial_state_prior for ssm in component_ssms])
            dtype = initial_state_prior.dtype

            static_num_timesteps = [
                tf.get_static_value(ssm.num_timesteps)
                for ssm in component_ssms
                if tf.get_static_value(ssm.num_timesteps) is not None
            ]

            # If any components have a static value for `num_timesteps`, use
            # that value for the additive model (and check that all other
            # static values match it).
            if static_num_timesteps:
                num_timesteps = static_num_timesteps[0]
                if not all([
                        component_timesteps == num_timesteps
                        for component_timesteps in static_num_timesteps
                ]):
                    raise ValueError(
                        'Additive model components must all have the same '
                        'number of timesteps '
                        '(saw: {})'.format(static_num_timesteps))
            else:
                num_timesteps = component_ssms[0].num_timesteps
            if validate_args and len(static_num_timesteps) != len(
                    component_ssms):
                assertions += [
                    tf.debugging.assert_equal(  # pylint: disable=g-complex-comprehension
                        num_timesteps,
                        ssm.num_timesteps,
                        message='Additive model components must all have '
                        'the same number of timesteps.')
                    for ssm in component_ssms
                ]

            # Define the transition and observation models for the additive SSM.
            # See the "mathematical details" section of the class docstring for
            # further information. Note that we define these as callables to
            # handle the fully general case in which some components have time-
            # varying dynamics.
            def transition_matrix_fn(t):
                return tfl.LinearOperatorBlockDiag([
                    ssm.get_transition_matrix_for_timestep(t)
                    for ssm in component_ssms
                ])

            def transition_noise_fn(t):
                return sts_util.factored_joint_mvn([
                    ssm.get_transition_noise_for_timestep(t)
                    for ssm in component_ssms
                ])

            # Build the observation matrix, concatenating (broadcast) observation
            # matrices from components. We also take this as an opportunity to enforce
            # any dynamic assertions we may have generated above.
            broadcast_batch_shape = tf.convert_to_tensor(
                value=sts_util.broadcast_batch_shape([
                    ssm.get_observation_matrix_for_timestep(initial_step)
                    for ssm in component_ssms
                ]),
                dtype=tf.int32)
            broadcast_obs_matrix = tf.ones(tf.concat(
                [broadcast_batch_shape, [1, 1]], axis=0),
                                           dtype=dtype)
            if assertions:
                with tf.control_dependencies(assertions):
                    broadcast_obs_matrix = tf.identity(broadcast_obs_matrix)

            def observation_matrix_fn(t):
                return tfl.LinearOperatorFullMatrix(
                    tf.concat([
                        ssm.get_observation_matrix_for_timestep(t).to_dense() *
                        broadcast_obs_matrix for ssm in component_ssms
                    ],
                              axis=-1))

            # Broadcast the constant offset across timesteps.
            offset_at_step = lambda t: (  # pylint: disable=g-long-lambda
                constant_offset if offset_length == 1
                else tf.gather(constant_offset,
                               tf.minimum(t, offset_length - 1),
                               axis=-1)[..., tf.newaxis])

            if observation_noise_scale is not None:
                observation_noise_scale = tf.convert_to_tensor(
                    value=observation_noise_scale,
                    name='observation_noise_scale',
                    dtype=dtype)

                def observation_noise_fn(t):
                    return tfd.MultivariateNormalDiag(
                        loc=(sum([
                            ssm.get_observation_noise_for_timestep(t).mean()
                            for ssm in component_ssms
                        ]) + offset_at_step(t)),
                        scale_diag=observation_noise_scale[..., tf.newaxis])
            else:

                def observation_noise_fn(t):
                    offset = offset_at_step(t)
                    return sts_util.sum_mvns([
                        tfd.MultivariateNormalDiag(
                            loc=offset, scale_diag=tf.zeros_like(offset))
                    ] + [
                        ssm.get_observation_noise_for_timestep(t)
                        for ssm in component_ssms
                    ])

            super(AdditiveStateSpaceModel,
                  self).__init__(num_timesteps=num_timesteps,
                                 transition_matrix=transition_matrix_fn,
                                 transition_noise=transition_noise_fn,
                                 observation_matrix=observation_matrix_fn,
                                 observation_noise=observation_noise_fn,
                                 initial_state_prior=initial_state_prior,
                                 initial_step=initial_step,
                                 validate_args=validate_args,
                                 name=name,
                                 **linear_gaussian_ssm_kwargs)
            self._parameters = parameters
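
For orientation, a minimal usage sketch of the constructor above: two component state space models over the same timesteps are summed into one observed series. The module paths follow the public tfp.sts layout, and the variable names and parameter values are made up for illustration:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Two components over the same 100 timesteps: a random-walk level and a
# weekly seasonal effect.
level_ssm = tfp.sts.LocalLevelStateSpaceModel(
    num_timesteps=100,
    level_scale=0.5,
    initial_state_prior=tfd.MultivariateNormalDiag(scale_diag=[1.]))
seasonal_ssm = tfp.sts.SeasonalStateSpaceModel(
    num_timesteps=100,
    num_seasons=7,
    drift_scale=0.1,
    initial_state_prior=tfd.MultivariateNormalDiag(scale_diag=tf.ones([7])))

additive_ssm = tfp.sts.AdditiveStateSpaceModel(
    component_ssms=[level_ssm, seasonal_ssm],
    constant_offset=3.,              # models `observed_series - 3.`
    observation_noise_scale=0.2)

samples = additive_ssm.sample(5)     # shape [5, 100, 1]
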
Example #7
  def __init__(self,
               component_ssms,
               observation_noise_scale=None,
               initial_state_prior=None,
               initial_step=0,
               validate_args=False,
               allow_nan_stats=True,
               name=None):
    """Build a state space model representing the sum of component models.

    Args:
      component_ssms: Python `list` containing one or more
        `tfd.LinearGaussianStateSpaceModel` instances. The components
        will in general implement different time-series models, with possibly
        different `latent_size`, but they must have the same `dtype`, event
        shape (`num_timesteps` and `observation_size`), and their batch shapes
        must broadcast to a compatible batch shape.
      observation_noise_scale: Optional scalar `float` `Tensor` indicating the
        standard deviation of the observation noise. May contain additional
        batch dimensions, which must broadcast with the batch shape of elements
        in `component_ssms`. If `observation_noise_scale` is specified for the
        `AdditiveStateSpaceModel`, the observation noise scales of component
        models are ignored. If `None`, the observation noise scale is derived
        by summing the noise variances of the component models, i.e.,
        `observation_noise_scale = sqrt(sum(
        [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
      initial_state_prior: Optional instance of `tfd.MultivariateNormal`
        representing a prior distribution on the latent state at time
        `initial_step`. If `None`, defaults to the independent priors from
        component models, i.e.,
        `[component.initial_state_prior for component in component_ssms]`.
        Default value: `None`.
      initial_step: Optional scalar `int` `Tensor` specifying the starting
        timestep.
        Default value: 0.
      validate_args: Python `bool`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
        Default value: `False`.
      allow_nan_stats: Python `bool`. If `False`, raise an
        exception if a statistic (e.g., mean, mode, etc.) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
        Default value: `True`.
      name: Python `str` name prefixed to ops created by this class.
        Default value: "AdditiveStateSpaceModel".

    Raises:
      ValueError: if components have different `num_timesteps`.
    """

    with tf.compat.v1.name_scope(
        name,
        'AdditiveStateSpaceModel',
        values=[observation_noise_scale, initial_step]) as name:

      assertions = []

      # Check that all components have the same dtype
      tf.debugging.assert_same_float_dtype(component_ssms)

      # Construct an initial state prior as a block-diagonal combination
      # of the component state priors.
      if initial_state_prior is None:
        initial_state_prior = sts_util.factored_joint_mvn(
            [ssm.initial_state_prior for ssm in component_ssms])
      dtype = initial_state_prior.dtype

      static_num_timesteps = [
          tf.get_static_value(ssm.num_timesteps)
          for ssm in component_ssms
          if tf.get_static_value(ssm.num_timesteps) is not None
      ]

      # If any components have a static value for `num_timesteps`, use that
      # value for the additive model (and check that all other static values
      # match it).
      if static_num_timesteps:
        num_timesteps = static_num_timesteps[0]
        if not all([component_timesteps == num_timesteps
                    for component_timesteps in static_num_timesteps]):
          raise ValueError('Additive model components must all have the same '
                           'number of timesteps '
                           '(saw: {})'.format(static_num_timesteps))
      else:
        num_timesteps = component_ssms[0].num_timesteps
      if validate_args and len(static_num_timesteps) != len(component_ssms):
        assertions += [
            tf.compat.v1.assert_equal(
                num_timesteps,
                ssm.num_timesteps,
                message='Additive model components must all have '
                'the same number of timesteps.') for ssm in component_ssms
        ]

      # Define the transition and observation models for the additive SSM.
      # See the "mathematical details" section of the class docstring for
      # further information. Note that we define these as callables to
      # handle the fully general case in which some components have time-
      # varying dynamics.
      def transition_matrix_fn(t):
        return tfl.LinearOperatorBlockDiag(
            [ssm.get_transition_matrix_for_timestep(t)
             for ssm in component_ssms])

      def transition_noise_fn(t):
        return sts_util.factored_joint_mvn(
            [ssm.get_transition_noise_for_timestep(t)
             for ssm in component_ssms])

      # Build the observation matrix, concatenating (broadcast) observation
      # matrices from components. We also take this as an opportunity to enforce
      # any dynamic assertions we may have generated above.
      broadcast_batch_shape = tf.convert_to_tensor(
          value=sts_util.broadcast_batch_shape(component_ssms), dtype=tf.int32)
      broadcast_obs_matrix = tf.ones(
          tf.concat([broadcast_batch_shape, [1, 1]], axis=0), dtype=dtype)
      if assertions:
        with tf.control_dependencies(assertions):
          broadcast_obs_matrix = tf.identity(broadcast_obs_matrix)

      def observation_matrix_fn(t):
        return tfl.LinearOperatorFullMatrix(
            tf.concat([ssm.get_observation_matrix_for_timestep(t).to_dense() *
                       broadcast_obs_matrix for ssm in component_ssms],
                      axis=-1))

      if observation_noise_scale is not None:
        observation_noise_scale = tf.convert_to_tensor(
            value=observation_noise_scale,
            name='observation_noise_scale',
            dtype=dtype)
        def observation_noise_fn(t):
          return tfd.MultivariateNormalDiag(
              loc=sum([ssm.get_observation_noise_for_timestep(t).mean()
                       for ssm in component_ssms]),
              scale_diag=observation_noise_scale[..., tf.newaxis])
      else:
        def observation_noise_fn(t):
          return sts_util.sum_mvns(
              [ssm.get_observation_noise_for_timestep(t)
               for ssm in component_ssms])

      super(AdditiveStateSpaceModel, self).__init__(
          num_timesteps=num_timesteps,
          transition_matrix=transition_matrix_fn,
          transition_noise=transition_noise_fn,
          observation_matrix=observation_matrix_fn,
          observation_noise=observation_noise_fn,
          initial_state_prior=initial_state_prior,
          initial_step=initial_step,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
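
As a concrete instance of the docstring's default rule for `observation_noise_scale` (illustrative arithmetic with made-up component scales):

import numpy as np

# Hypothetical per-component observation noise scales.
component_scales = np.array([0.3, 0.4])

# Default rule when `observation_noise_scale` is None: sum the component
# noise variances, then take the square root.
derived_scale = np.sqrt(np.sum(component_scales ** 2))  # == 0.5
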
Example #8
  def __init__(self,
               component_ssms,
               observation_noise_scale=None,
               initial_state_prior=None,
               initial_step=0,
               validate_args=False,
               allow_nan_stats=True,
               name=None):
    """Build a state space model representing the sum of component models.

    Args:
      component_ssms: Python `list` containing one or more
        `tfd.LinearGaussianStateSpaceModel` instances. The components
        will in general implement different time-series models, with possibly
        different `latent_size`, but they must have the same `dtype`, event
        shape (`num_timesteps` and `observation_size`), and their batch shapes
        must broadcast to a compatible batch shape.
      observation_noise_scale: Optional scalar `float` `Tensor` indicating the
        standard deviation of the observation noise. May contain additional
        batch dimensions, which must broadcast with the batch shape of elements
        in `component_ssms`. If `observation_noise_scale` is specified for the
        `AdditiveStateSpaceModel`, the observation noise scales of component
        models are ignored. If `None`, the observation noise scale is derived
        by summing the noise variances of the component models, i.e.,
        `observation_noise_scale = sqrt(sum(
        [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
      initial_state_prior: Optional instance of `tfd.MultivariateNormal`
        representing a prior distribution on the latent state at time
        `initial_step`. If `None`, defaults to the independent priors from
        component models, i.e.,
        `[component.initial_state_prior for component in component_ssms]`.
        Default value: `None`.
      initial_step: Optional scalar `int` `Tensor` specifying the starting
        timestep.
        Default value: 0.
      validate_args: Python `bool`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
        Default value: `False`.
      allow_nan_stats: Python `bool`. If `False`, raise an
        exception if a statistic (e.g., mean, mode, etc.) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
        Default value: `True`.
      name: Python `str` name prefixed to ops created by this class.
        Default value: "AdditiveStateSpaceModel".

    Raises:
      ValueError: if components have different `num_timesteps`.
    """

    with tf.name_scope(name, 'AdditiveStateSpaceModel',
                       values=[observation_noise_scale, initial_step]) as name:

      assertions = []

      # Check that all components have the same dtype
      tf.assert_same_float_dtype(component_ssms)

      # Construct an initial state prior as a block-diagonal combination
      # of the component state priors.
      if initial_state_prior is None:
        initial_state_prior = sts_util.factored_joint_mvn(
            [ssm.initial_state_prior for ssm in component_ssms])
      dtype = initial_state_prior.dtype

      static_num_timesteps = [
          distribution_util.static_value(ssm.num_timesteps)
          for ssm in component_ssms
          if distribution_util.static_value(ssm.num_timesteps) is not None
      ]

      # If any components have a static value for `num_timesteps`, use that
      # value for the additive model (and check that all other static values
      # match it).
      if static_num_timesteps:
        num_timesteps = static_num_timesteps[0]
        if not all([component_timesteps == num_timesteps
                    for component_timesteps in static_num_timesteps]):
          raise ValueError('Additive model components must all have the same '
                           'number of timesteps '
                           '(saw: {})'.format(static_num_timesteps))
      else:
        num_timesteps = component_ssms[0].num_timesteps
      if validate_args and len(static_num_timesteps) != len(component_ssms):
        assertions += [
            tf.assert_equal(num_timesteps,
                            ssm.num_timesteps,
                            message='Additive model components must all have '
                            'the same number of timesteps.')
            for ssm in component_ssms]

      # Define the transition and observation models for the additive SSM.
      # See the "mathematical details" section of the class docstring for
      # further information. Note that we define these as callables to
      # handle the fully general case in which some components have time-
      # varying dynamics.
      def transition_matrix_fn(t):
        return tfl.LinearOperatorBlockDiag(
            [ssm.get_transition_matrix_for_timestep(t)
             for ssm in component_ssms])

      def transition_noise_fn(t):
        return sts_util.factored_joint_mvn(
            [ssm.get_transition_noise_for_timestep(t)
             for ssm in component_ssms])

      # Build the observation matrix, concatenating (broadcast) observation
      # matrices from components. We also take this as an opportunity to enforce
      # any dynamic assertions we may have generated above.
      broadcast_batch_shape = tf.convert_to_tensor(
          sts_util.broadcast_batch_shape(component_ssms), dtype=tf.int32)
      broadcast_obs_matrix = tf.ones(
          tf.concat([broadcast_batch_shape, [1, 1]], axis=0), dtype=dtype)
      if assertions:
        with tf.control_dependencies(assertions):
          broadcast_obs_matrix = tf.identity(broadcast_obs_matrix)

      def observation_matrix_fn(t):
        return tfl.LinearOperatorFullMatrix(
            tf.concat([ssm.get_observation_matrix_for_timestep(t).to_dense() *
                       broadcast_obs_matrix for ssm in component_ssms],
                      axis=-1))

      if observation_noise_scale is not None:
        observation_noise_scale = tf.convert_to_tensor(
            observation_noise_scale,
            name='observation_noise_scale',
            dtype=dtype)
        def observation_noise_fn(t):
          return tfd.MultivariateNormalDiag(
              loc=sum([ssm.get_observation_noise_for_timestep(t).mean()
                       for ssm in component_ssms]),
              scale_diag=observation_noise_scale[..., tf.newaxis])
      else:
        def observation_noise_fn(t):
          return sts_util.sum_mvns(
              [ssm.get_observation_noise_for_timestep(t)
               for ssm in component_ssms])

      super(AdditiveStateSpaceModel, self).__init__(
          num_timesteps=num_timesteps,
          transition_matrix=transition_matrix_fn,
          transition_noise=transition_noise_fn,
          observation_matrix=observation_matrix_fn,
          observation_noise=observation_noise_fn,
          initial_state_prior=initial_state_prior,
          initial_step=initial_step,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
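
The inline comments in these constructors describe the joint model as block-diagonal latent dynamics over the stacked component latents, with the per-component observation matrices concatenated so the emitted series is the sum of the components. A small standalone sketch of that structure, with made-up 2x2 and 1x1 components (illustrative only):

import tensorflow as tf

# Hypothetical component transition matrices: a 2x2 local-linear-trend-like
# block and a 1x1 autoregressive block.
transition_a = tf.linalg.LinearOperatorFullMatrix([[1., 1.], [0., 1.]])
transition_b = tf.linalg.LinearOperatorFullMatrix([[0.9]])

# Joint latent dynamics: block-diagonal over the components' latents.
joint_transition = tf.linalg.LinearOperatorBlockDiag(
    [transition_a, transition_b])           # acts on a 3-dim joint latent

# Joint observation: concatenate the per-component observation matrices
# along the last axis, so the observed value sums the components.
obs_a = tf.constant([[1., 0.]])             # reads the level of component A
obs_b = tf.constant([[1.]])                 # reads component B directly
joint_observation = tf.concat([obs_a, obs_b], axis=-1)   # shape [1, 3]
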