def determine_batch_event_shapes(grid, endpoint_affine):
    """Infer static and dynamic batch/event shapes from `grid` and affines.

    The batch shape starts as every dimension of `grid` except the last two
    and is then broadcast against the batch dimensions contributed by each
    affine's `shift` and `scale`.  The event shape is the broadcast of the
    trailing dimension of every non-`None` `shift` and the `range_dimension`
    of every non-`None` `scale`.

    Args:
      grid: Tensor whose leading dimensions (all but the last two) are batch
        dimensions.  (The original author's note gives the shape as
        `[B, k, q]`.)
      endpoint_affine: Iterable of objects exposing `shift` and `scale`
        attributes, either of which may be `None`.  (The original author's
        note gives `len=k, shape: [B, d, d]`.)

    Returns:
      Tuple `(batch_shape, batch_shape_tensor, event_shape,
      event_shape_tensor)`.  Both event entries stay `None` when no affine
      supplies a `shift` or a `scale`.
    """
    with tf.name_scope("determine_batch_event_shapes"):
        batch_shape = grid.shape[:-2]
        batch_shape_tensor = tf.shape(grid)[:-2]
        event_shape, event_shape_tensor = None, None

        def _merge_batch(static_part, dynamic_part):
            # Broadcast the running batch shape against a new contribution.
            return (tf.broadcast_static_shape(batch_shape, static_part),
                    tf.broadcast_dynamic_shape(batch_shape_tensor,
                                               dynamic_part))

        def _merge_event(static_part, dynamic_part):
            # The first contribution is adopted as-is; later ones broadcast.
            if event_shape is not None:
                return (tf.broadcast_static_shape(event_shape, static_part),
                        tf.broadcast_dynamic_shape(event_shape_tensor,
                                                   dynamic_part))
            return static_part, dynamic_part

        for affine in endpoint_affine:
            shift = affine.shift
            if shift is not None:
                batch_shape, batch_shape_tensor = _merge_batch(
                    shift.shape[:-1], tf.shape(shift)[:-1])
                event_shape, event_shape_tensor = _merge_event(
                    shift.shape[-1:], tf.shape(shift)[-1:])

            scale = affine.scale
            if scale is not None:
                batch_shape, batch_shape_tensor = _merge_batch(
                    scale.batch_shape, scale.batch_shape_tensor())
                event_shape, event_shape_tensor = _merge_event(
                    tf.TensorShape([scale.range_dimension]),
                    scale.range_dimension_tensor()[tf.newaxis])

        return (batch_shape, batch_shape_tensor,
                event_shape, event_shape_tensor)
 def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims,
                                **distribution_kwargs):
     """Complete the `prob` computation for one point of the inverse image.

     Evaluates the base distribution's density at the (rotated) `x`,
     optionally reduces over the overridden event indices, and scales the
     result by `exp(ildj)`.

     Args:
       y: Value whose static shape is used for the output shape hint.
       x: Point of the inverse image at which the base density is evaluated.
       ildj: Inverse log-det-Jacobian term; cast to the density's dtype.
       event_ndims: Number of event dimensions.  The static shape hint is
         only applied when this is a Python `int`.
       **distribution_kwargs: Forwarded to `self.distribution.prob`.

     Returns:
       The density tensor.
     """
     rotated_x = self._maybe_rotate_dims(x, rotate_right=True)
     density = self.distribution.prob(rotated_x, **distribution_kwargs)
     if self._is_maybe_event_override:
         density = tf.reduce_prod(density, axis=self._reduce_event_indices)

     jacobian_factor = tf.exp(tf.cast(ildj, density.dtype))
     density = density * jacobian_factor

     if self._is_maybe_event_override and isinstance(event_ndims, int):
         # Static shape hint: the sample/batch part of `y` broadcast against
         # this distribution's batch shape.
         y_shape = tensorshape_util.with_rank_at_least(y.shape, 1)
         static_hint = tf.broadcast_static_shape(
             y_shape[:-event_ndims], self.batch_shape)
         tensorshape_util.set_shape(density, static_hint)
     return density
Exemplo n.º 3
0
def broadcast_shape(x_shape, y_shape):
    """Return the shape obtained by broadcasting two shapes together.

    If the static value of both arguments can be obtained, the result is
    computed statically and returned as a `TensorShape`.  Otherwise the
    broadcast is done dynamically and a rank-1 `Tensor` is returned.

    Arguments:
      x_shape: A `TensorShape` or rank-1 integer `Tensor`; the first shape.
      y_shape: A `TensorShape` or rank-1 integer `Tensor`; the second shape.

    Returns:
      shape: A `TensorShape` or rank-1 integer `Tensor` representing the
        broadcasted shape.
    """
    static_values = [tf.get_static_value(s) for s in (x_shape, y_shape)]
    if any(value is None for value in static_values):
        # At least one shape is only known at run time.
        return tf.broadcast_dynamic_shape(x_shape, y_shape)

    static_x, static_y = static_values
    return tf.broadcast_static_shape(tf.TensorShape(static_x),
                                     tf.TensorShape(static_y))
Exemplo n.º 4
0
 def _batch_shape(self):
     """Static batch shape: broadcast of `concentration` and `rate` shapes."""
     concentration_shape = self.concentration.shape
     rate_shape = self.rate.shape
     return tf.broadcast_static_shape(concentration_shape, rate_shape)
Exemplo n.º 5
0
 def _batch_shape(self):
     """Static batch shape: broadcast of `df`, `loc` and `scale` shapes."""
     df_loc_shape = tf.broadcast_static_shape(self.df.shape, self.loc.shape)
     return tf.broadcast_static_shape(df_loc_shape, self.scale.shape)
 def _set_event_shape(shape, shape_tensor):
     """Combine a new event shape with the one accumulated so far.

     `event_shape` and `event_shape_tensor` are free names here; they are
     read from an enclosing scope that is not visible in this snippet.

     Args:
       shape: Static shape of the new contribution.
       shape_tensor: Dynamic shape of the new contribution.

     Returns:
       `(static_shape, dynamic_shape)`: the inputs unchanged when no event
       shape has been recorded yet, otherwise their broadcast against the
       recorded one.
     """
     if event_shape is not None:
         merged_static = tf.broadcast_static_shape(event_shape, shape)
         merged_dynamic = tf.broadcast_dynamic_shape(event_shape_tensor,
                                                     shape_tensor)
         return merged_static, merged_dynamic
     return shape, shape_tensor
Exemplo n.º 7
0
 def _batch_shape(self):
     """Static batch shape from `loc`, `atol` and `rtol`.

     The three parameter shapes are broadcast together and the last
     dimension of the result is dropped.
     """
     loc_shape = self.loc.shape
     tolerance_shape = tf.broadcast_static_shape(self.atol.shape,
                                                 self.rtol.shape)
     full_shape = tf.broadcast_static_shape(loc_shape, tolerance_shape)
     return full_shape[:-1]
Exemplo n.º 8
0
 def _batch_shape(self):
     """Static batch shape from `_total_count` and `_concentration`.

     A trailing dimension of size 1 is appended to the `_total_count` shape
     so it lines up against `_concentration`; the two are broadcast and the
     last dimension of the (at least rank-1) result is dropped.
     """
     count_shape = tf.TensorShape(self._total_count.shape).concatenate([1])
     concentration_shape = tf.TensorShape(self._concentration.shape)
     merged = tf.broadcast_static_shape(count_shape, concentration_shape)
     return tensorshape_util.with_rank_at_least(merged, 1)[:-1]
 def _batch_shape(self):
     """Static batch shape from the Poisson and mixture distributions.

     Broadcasts the first distribution's batch shape against the mixture
     logits' shape and drops the last dimension of the result.
     """
     poisson, mixture = self.poisson_and_mixture_distributions()
     joint_shape = tf.broadcast_static_shape(poisson.batch_shape,
                                             mixture.logits.shape)
     return joint_shape[:-1]
Exemplo n.º 10
0
 def _batch_shape(self):
   """Static batch shape from `mean_direction` and `concentration`.

   The last dimension of `mean_direction` (treated as at least rank 1) is
   dropped before broadcasting against the `concentration` shape.
   """
   direction_shape = tensorshape_util.with_rank_at_least(
       self.mean_direction.shape, 1)
   direction_batch = direction_shape[:-1]
   return tf.broadcast_static_shape(direction_batch,
                                    self.concentration.shape)
Exemplo n.º 11
0
def _kl_brute_force(a, b, name=None):
    """Batched KL divergence `KL(a || b)` for multivariate Normals.

    With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`,
    `mu_b` and covariance `C_a`, `C_b` respectively,

    ```
    KL(a || b) = 0.5 * ( L - k + T + Q ),
    L := Log[Det(C_b)] - Log[Det(C_a)]
    T := trace(C_b^{-1} C_a),
    Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
    ```

    The trace is obtained by solving `C_b^{-1} C_a`.  Even when efficient
    solvers for `C_b` exist, a dense version of (the square root of) `C_a`
    is used, so performance is `O(B s k**2)` where `B` is the batch size and
    `s` is the cost of solving `C_b x = y` for vectors `x` and `y`.

    Args:
      a: Instance of `MultivariateNormalLinearOperator`.
      b: Instance of `MultivariateNormalLinearOperator`.
      name: (optional) name to use for created ops. Default "kl_mvn".

    Returns:
      Batchwise `KL(a || b)`.
    """
    # TODO(b/35041439): See also b/35040945. Remove the helpers below once
    # LinOp supports something like:
    #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
    diagonal_operator_types = (
        tf.linalg.LinearOperatorIdentity,
        tf.linalg.LinearOperatorScaledIdentity,
        tf.linalg.LinearOperatorDiag,
    )

    def _has_only_diagonal(operator):
        """True when `operator` is one of the diagonal LinearOperators."""
        return isinstance(operator, diagonal_operator_types)

    def _sum_of_squares(matrix):
        """Squared Frobenius norm over the last two axes."""
        # http://mathworld.wolfram.com/FrobeniusNorm.html
        # tf.norm is deliberately avoided: the gradient of KL[p,q] is not
        # defined through it when p==q.
        return tf.reduce_sum(tf.square(matrix), axis=[-2, -1])

    with tf.name_scope(name or 'kl_mvn'):
        # Calculation is based on:
        # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
        # and,
        # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
        # i.e.,
        #   If Ca = AA', Cb = BB', then
        #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
        #                  = tr[inv(B) A A' inv(B)']
        #                  = tr[(inv(B) A) (inv(B) A)']
        #                  = sum_{ij} (inv(B) A)_{ij}**2
        #                  = ||inv(B) A||_F**2
        # where ||.||_F is the Frobenius norm and the second equality follows
        # from the cyclic permutation property.
        if _has_only_diagonal(a.scale) and _has_only_diagonal(b.scale):
            # Using `stddev` because it handles expansion of Identity cases.
            stddev_ratio = a.stddev() / b.stddev()
            b_inv_a = stddev_ratio[..., tf.newaxis]
        else:
            b_inv_a = b.scale.solve(a.scale.to_dense())

        log_det_gap = (b.scale.log_abs_determinant() -
                       a.scale.log_abs_determinant())
        negative_dim = -tf.cast(a.scale.domain_dimension_tensor(), a.dtype)
        trace_term = _sum_of_squares(b_inv_a)
        mean_gap = (b.mean() - a.mean())[..., tf.newaxis]
        quadratic_term = _sum_of_squares(b.scale.solve(mean_gap))

        kl_div = log_det_gap + 0.5 * (
            negative_dim + trace_term + quadratic_term)

        static_batch = tf.broadcast_static_shape(a.batch_shape,
                                                 b.batch_shape)
        tensorshape_util.set_shape(kl_div, static_batch)
        return kl_div
Exemplo n.º 12
0
 def _batch_shape(self):
   """Static batch shape from `total_count` and the probs/logits parameter.

   `_logits` is used when it is set; otherwise `_probs` is used.
   """
   if self._logits is None:
     param = self._probs
   else:
     param = self._logits
   return tf.broadcast_static_shape(self.total_count.shape, param.shape)
Exemplo n.º 13
0
    def _log_prob(self, x):
        """Compute the log-density at `x`.

        When `self.input_output_cholesky` is true, `x` is used directly as
        the Cholesky factor; otherwise it is factored with
        `tf.linalg.cholesky`.

        Args:
          x: Input tensor; its last two dimensions are treated as the event
            (matrix) dimensions.

        Returns:
          The log-probability tensor.  When the static ranks of `x` and of
          `self.batch_shape` are both known, a static shape hint is set on
          it.
        """
        if self.input_output_cholesky:
            x_sqrt = x
        else:
            # Complexity: O(nbk**3)
            x_sqrt = tf.linalg.cholesky(x)

        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()
        x_ndims = tf.rank(x_sqrt)
        # Pad `x_sqrt` with leading singleton axes so that its rank is at
        # least `batch_ndims + 2`.
        num_singleton_axes_to_prepend = (
            tf.maximum(tf.size(batch_shape) + 2, x_ndims) - x_ndims)
        x_with_prepended_singletons_shape = tf.concat([
            tf.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
            tf.shape(x_sqrt)
        ], 0)
        x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
        ndims = tf.rank(x_sqrt)
        # sample_ndims = ndims - batch_ndims - event_ndims
        sample_ndims = ndims - tf.size(batch_shape) - 2
        sample_shape = tf.shape(x_sqrt)[:sample_ndims]

        # We need to be able to pre-multiply each matrix by its corresponding
        # batch scale matrix. Since a Distribution Tensor supports multiple
        # samples per batch, this means we need to reshape the input matrix `x`
        # so that the first b dimensions are batch dimensions and the last two
        # are of shape [dimension, dimensions*number_of_samples]. Doing these
        # gymnastics allows us to do a batch_solve.
        #
        # After we're done with sqrt_solve (the batch operation) we need to undo
        # this reshaping so what we're left with is a Tensor partitionable by
        # sample, batch, event dimensions.

        # Complexity: O(nbk**2) since transpose must access every element.
        scale_sqrt_inv_x_sqrt = x_sqrt
        # Move the sample axes to the end: [batch..., event..., sample...].
        perm = tf.concat(
            [tf.range(sample_ndims, ndims),
             tf.range(0, sample_ndims)], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)
        # Fold the trailing event axis and all sample axes into one axis.
        last_dim_size = (
            tf.cast(self.dimension, dtype=tf.int32) *
            tf.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
        shape = tf.concat([
            x_with_prepended_singletons_shape[sample_ndims:-2],
            [tf.cast(self.dimension, dtype=tf.int32), last_dim_size]
        ],
                          axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

        # Complexity: O(nbM*k) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
        # this step has complexity O(nbk^3).
        scale_sqrt_inv_x_sqrt = self.scale_operator.solve(
            scale_sqrt_inv_x_sqrt)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        shape = tf.concat(
            [tf.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
            axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
        # Move the sample axes back to the front.
        perm = tf.concat([
            tf.range(ndims - sample_ndims, ndims),
            tf.range(0, ndims - sample_ndims)
        ], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)

        # Write V = SS', X = LL'. Then:
        # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
        #              = tr[inv(S) L L' inv(S)']
        #              = tr[(inv(S) L) (inv(S) L)']
        #              = sum_{ik} (inv(S) L)_{ik}**2
        # The second equality follows from the cyclic permutation property.
        # Complexity: O(nbk**2)
        trace_scale_inv_x = tf.reduce_sum(tf.square(scale_sqrt_inv_x_sqrt),
                                          axis=[-2, -1])

        # Complexity: O(nbk)
        half_log_det_x = tf.reduce_sum(tf.math.log(
            tf.linalg.diag_part(x_sqrt)),
                                       axis=[-1])

        # Complexity: O(nbk**2)
        log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
                    0.5 * trace_scale_inv_x - self.log_normalization())

        # Set shape hints.
        # Try to merge what we know from the input x with what we know from the
        # parameters of this distribution.
        if tensorshape_util.rank(
                x.shape) is not None and tensorshape_util.rank(
                    self.batch_shape) is not None:
            tensorshape_util.set_shape(
                log_prob,
                tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

        return log_prob
Exemplo n.º 14
0
 def _batch_shape(self):
     """Static batch shape: `df` shape broadcast with the scale's batch."""
     df_shape = self.df.shape
     scale_batch = self.scale_operator.batch_shape
     return tf.broadcast_static_shape(df_shape, scale_batch)
    def __init__(self,
                 initial_distribution,
                 transition_distribution,
                 observation_distribution,
                 num_steps,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HiddenMarkovModel"):
        """Initialize hidden Markov model.

        Args:
          initial_distribution: A `Categorical`-like instance.
            Determines probability of first hidden state in Markov chain.
            The number of categories must match the number of categories of
            `transition_distribution` as well as both the rightmost batch
            dimension of `transition_distribution` and the rightmost batch
            dimension of `observation_distribution`.
          transition_distribution: A `Categorical`-like instance.
            The rightmost batch dimension indexes the probability distribution
            of each hidden state conditioned on the previous hidden state.
          observation_distribution: A `tfp.distributions.Distribution`-like
            instance.  The rightmost batch dimension indexes the distribution
            of each observation conditioned on the corresponding hidden state.
          num_steps: The number of steps taken in Markov chain. A python `int`.
            Only checked (scalar, at least 1) through runtime assertions when
            `validate_args` is `True`.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
            Default value: `False`.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
            Default value: `True`.
          name: Python `str` name prefixed to Ops created by this class.
            Default value: "HiddenMarkovModel".

        Raises:
          ValueError: if `initial_distribution` has a statically known,
            non-scalar `event_shape`.
          ValueError: if `transition_distribution` has a statically known,
            non-scalar `event_shape`.
          ValueError: if `transition_distribution` or
            `observation_distribution` has a statically known scalar
            `batch_shape`.
          ValueError: if `transition_distribution` and
            `observation_distribution` have statically known rightmost batch
            dimensions that do not match.
        """

        # Must stay the first statement so that only the constructor
        # arguments (and `self`) are captured.
        parameters = dict(locals())

        # pylint: disable=protected-access
        with tf.name_scope(name) as name:
            self._runtime_assertions = []  # pylint: enable=protected-access

            num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
            if validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.rank(num_steps),
                        0,
                        message="`num_steps` must be a scalar")
                ]
                self._runtime_assertions += [
                    assert_util.assert_greater_equal(
                        num_steps,
                        1,
                        message="`num_steps` must be at least 1.")
                ]

            self._initial_distribution = initial_distribution
            self._observation_distribution = observation_distribution
            self._transition_distribution = transition_distribution

            # A static rank of `None` means "unknown"; that case must fall
            # through to the runtime assertion instead of raising here.
            if (initial_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        initial_distribution.event_shape) not in (None, 0)):
                raise ValueError(
                    "`initial_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(initial_distribution.event_shape_tensor())[0],
                        0,
                        message="`initial_distribution` must have scalar "
                        "`event_dim`s")
                ]

            if (transition_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.event_shape)
                    not in (None, 0)):
                raise ValueError(
                    "`transition_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(
                            transition_distribution.event_shape_tensor())[0],
                        0,
                        message="`transition_distribution` must have scalar "
                        "`event_dim`s")
                ]

            if (transition_distribution.batch_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.batch_shape) == 0):
                raise ValueError(
                    "`transition_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(transition_distribution.batch_shape_tensor()),
                        0,
                        message="`transition_distribution` can't have scalar "
                        "batches")
                ]

            if (observation_distribution.batch_shape is not None
                    and tensorshape_util.rank(
                        observation_distribution.batch_shape) == 0):
                raise ValueError(
                    "`observation_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(observation_distribution.batch_shape_tensor()),
                        0,
                        message="`observation_distribution` can't have scalar "
                        "batches")
                ]

            # Infer number of hidden states and check consistency
            # between transitions and observations.  The static value is
            # preferred; the dynamic tensor is the fallback when the static
            # expression is falsy.
            with tf.control_dependencies(self._runtime_assertions):
                self._num_states = (
                    (transition_distribution.batch_shape
                     and transition_distribution.batch_shape[-1])
                    or transition_distribution.batch_shape_tensor()[-1])

                observation_states = (
                    (observation_distribution.batch_shape
                     and observation_distribution.batch_shape[-1])
                    or observation_distribution.batch_shape_tensor()[-1])

            if (tf.is_tensor(self._num_states)
                    or tf.is_tensor(observation_states)):
                if validate_args:
                    self._runtime_assertions += [
                        assert_util.assert_equal(
                            self._num_states,
                            observation_states,
                            message="`transition_distribution` and "
                            "`observation_distribution` must agree on "
                            "last dimension of batch size")
                    ]
            elif self._num_states != observation_states:
                raise ValueError("`transition_distribution` and "
                                 "`observation_distribution` must agree on "
                                 "last dimension of batch size")

            self._log_init = _extract_log_probs(self._num_states,
                                                initial_distribution)
            self._log_trans = _extract_log_probs(self._num_states,
                                                 transition_distribution)

            self._num_steps = num_steps
            # From here on the number of states is taken from the extracted
            # initial log-probabilities.
            self._num_states = tf.shape(self._log_init)[-1]

            self._underlying_event_rank = tf.size(
                self._observation_distribution.event_shape_tensor())

            num_steps_ = tf.get_static_value(num_steps)
            if num_steps_ is not None:
                self.static_event_shape = tf.TensorShape([
                    num_steps_
                ]).concatenate(self._observation_distribution.event_shape)
            else:
                self.static_event_shape = None

            with tf.control_dependencies(self._runtime_assertions):
                self.static_batch_shape = tf.broadcast_static_shape(
                    self._initial_distribution.batch_shape,
                    tf.broadcast_static_shape(
                        self._transition_distribution.batch_shape[:-1],
                        self._observation_distribution.batch_shape[:-1]))

            # pylint: disable=protected-access
            super(HiddenMarkovModel, self).__init__(
                dtype=self._observation_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
            # pylint: enable=protected-access

            self._parameters = parameters
Exemplo n.º 16
0
 def _batch_shape(self):
     """Static batch shape: broadcast of the `low` and `high` shapes."""
     low_shape = self.low.shape
     high_shape = self.high.shape
     return tf.broadcast_static_shape(low_shape, high_shape)
Exemplo n.º 17
0
 def _batch_shape(self):
     """Static batch shape from the probs/logits parameter and `total_count`.

     `_logits` is used when it is set, otherwise `_probs`; the parameter's
     last dimension is dropped before broadcasting.
     """
     if self._logits is None:
         param = self._probs
     else:
         param = self._logits
     param_batch = param.shape[:-1]
     return tf.broadcast_static_shape(param_batch, self.total_count.shape)
Exemplo n.º 18
0
 def _batch_shape(self):
   """Static batch shape from `temperature` and the logits/probs parameter.

   `_logits` is used when it is set, otherwise `_probs`; the parameter's
   last dimension is dropped before broadcasting.
   """
   if self._logits is not None:
     param = self._logits
   else:
     param = self._probs
   temperature_shape = self.temperature.shape
   return tf.broadcast_static_shape(temperature_shape, param.shape[:-1])