def _batch_shape_tensor(self):
     with tf.control_dependencies(self._runtime_assertions):
         return tf.broadcast_dynamic_shape(
             self._initial_distribution.batch_shape_tensor(),
             tf.broadcast_dynamic_shape(
                 self._transition_distribution.batch_shape_tensor()[:-1],
                 self._observation_distribution.batch_shape_tensor()[:-1]))
 def _batch_shape_tensor(self, distributions=None):
     if distributions is None:
         distributions = self.poisson_and_mixture_distributions()
     dist, mixture_dist = distributions
     return tf.broadcast_dynamic_shape(
         dist.batch_shape_tensor(),
         prefer_static.shape(mixture_dist.logits))[:-1]
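The `_batch_shape_tensor` implementations collected here all follow the same recipe: take the (dynamic) shape of each parameter, drop any event dimensions, and broadcast the results together with `tf.broadcast_dynamic_shape`. A minimal standalone sketch of just that broadcasting step, with made-up parameter shapes:

import tensorflow as tf

# Hypothetical parameter batch shapes: one parameter batched as [4, 1],
# another as [3]; the distribution's batch shape is their broadcast, [4, 3].
rate_shape = tf.constant([4, 1])
logits_batch_shape = tf.constant([3])
batch_shape = tf.broadcast_dynamic_shape(rate_shape, logits_batch_shape)
print(batch_shape)  # ==> [4 3]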
Example #3
    def _log_prob(self, x):
        logits = self._logits_parameter_no_checks()
        event_size = self._event_size(logits)

        x = tf.cast(x, logits.dtype)
        x = self._maybe_assert_valid_sample(x, dtype=logits.dtype)

        # broadcast logits or x if need be.
        if (not tensorshape_util.is_fully_defined(x.shape)
                or not tensorshape_util.is_fully_defined(logits.shape)
                or x.shape != logits.shape):
            broadcast_shape = tf.broadcast_dynamic_shape(
                tf.shape(logits), tf.shape(x))
            logits = tf.broadcast_to(logits, broadcast_shape)
            x = tf.broadcast_to(x, broadcast_shape)

        logits_shape = tf.shape(tf.reduce_sum(logits, axis=-1))
        logits_2d = tf.reshape(logits, [-1, event_size])
        x_2d = tf.reshape(x, [-1, event_size])
        ret = -tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(x_2d), logits=logits_2d)

        # Reshape back to user-supplied batch and sample dims prior to 2D reshape.
        ret = tf.reshape(ret, logits_shape)
        return ret
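The `-tf.nn.softmax_cross_entropy_with_logits` call is what actually computes the one-hot categorical log-probability: for a one-hot `x` it equals `sum(x * log_softmax(logits))` over the event axis. A small sanity-check sketch (values are made up, not from the library):

import tensorflow as tf

logits = tf.constant([[1.0, 2.0, 0.5],
                      [0.0, -1.0, 3.0]])
x = tf.constant([[0., 1., 0.],      # one-hot samples
                 [0., 0., 1.]])

via_xent = -tf.nn.softmax_cross_entropy_with_logits(labels=x, logits=logits)
by_hand = tf.reduce_sum(x * tf.nn.log_softmax(logits, axis=-1), axis=-1)
tf.debugging.assert_near(via_xent, by_hand)  # the two agree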
Example #4
 def _batch_shape_tensor(self, concentration=None, total_count=None):
     if concentration is None:
         concentration = tf.convert_to_tensor(self._concentration)
     if total_count is None:
         total_count = tf.convert_to_tensor(self._total_count)
     return tf.broadcast_dynamic_shape(
         tf.shape(total_count[..., tf.newaxis]),
         tf.shape(concentration))[:-1]
Example #5
 def _cdf(self, x):
     low = tf.convert_to_tensor(self.low)
     high = tf.convert_to_tensor(self.high)
     broadcast_shape = tf.broadcast_dynamic_shape(
         tf.shape(x), self._batch_shape_tensor(low=low, high=high))
     zeros = tf.zeros(broadcast_shape, dtype=self.dtype)
     ones = tf.ones(broadcast_shape, dtype=self.dtype)
     result_if_not_big = tf.where(x < low, zeros, (x - low) /
                                  self._range(low=low, high=high))
     return tf.where(x >= high, ones, result_if_not_big)
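The two `tf.where` calls clamp the CDF to 0 below `low` and 1 above `high`; the broadcast against the batch shape is what lets a scalar argument be evaluated across a whole batch of distributions. A short usage sketch (parameter values are made up):

import tensorflow_probability as tfp
tfd = tfp.distributions

u = tfd.Uniform(low=[0., 1.], high=[2., 3.])  # batch_shape == [2]
print(u.cdf(1.5))  # ==> [0.75, 0.25]; the scalar 1.5 broadcasts over the batch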
    def _log_prob(self, value):
        with tf.control_dependencies(self._runtime_assertions):
            # The argument `value` is a tensor of sequences of observations.
            # `observation_batch_shape` is the shape of that tensor with the
            # sequence part removed.
            # `observation_batch_shape` is then broadcast to the full batch shape
            # to give the `batch_shape` that defines the shape of the result.

            observation_tensor_shape = tf.shape(value)
            observation_batch_shape = observation_tensor_shape[
                :-1 - self._underlying_event_rank]
            # value :: observation_batch_shape num_steps observation_event_shape
            batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                     self.batch_shape_tensor())
            log_init = tf.broadcast_to(
                self._log_init,
                tf.concat([batch_shape, [self._num_states]], axis=0))
            # log_init :: batch_shape num_states
            log_transition = self._log_trans

            # `observation_event_shape` is the shape of each sequence of observations
            # emitted by the model.
            observation_event_shape = observation_tensor_shape[
                -1 - self._underlying_event_rank:]
            working_obs = tf.broadcast_to(
                value, tf.concat([batch_shape, observation_event_shape],
                                 axis=0))
            # working_obs :: batch_shape observation_event_shape
            r = self._underlying_event_rank

            # Move index into sequence of observations to front so we can apply
            # tf.foldl
            working_obs = distribution_util.move_dimension(
                working_obs, -1 - r, 0)[..., tf.newaxis]
            # working_obs :: num_steps batch_shape underlying_event_shape
            observation_probs = (
                self._observation_distribution.log_prob(working_obs))

            def forward_step(log_prev_step, log_prob_observation):
                return _log_vector_matrix(
                    log_prev_step, log_transition) + log_prob_observation

            fwd_prob = tf.foldl(forward_step,
                                observation_probs,
                                initializer=log_init)
            # fwd_prob :: batch_shape num_states

            log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
            # log_prob :: batch_shape

            return log_prob
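The `tf.foldl` over `forward_step` is the HMM forward algorithm carried out in log space: each step multiplies the running state distribution by the transition matrix and then adds the per-state observation log-likelihoods. The helper `_log_vector_matrix` is not shown above; a hedged sketch of what such a log-space vector-matrix product computes (the library's actual implementation may differ):

import tensorflow as tf

def log_vector_matrix(log_vec, log_mat):
    # log_vec :: batch_shape num_states
    # log_mat :: batch_shape num_states num_states
    # Returns log(exp(log_vec) @ exp(log_mat)) :: batch_shape num_states,
    # computed stably with logsumexp over the shared state index.
    return tf.reduce_logsumexp(log_vec[..., tf.newaxis] + log_mat, axis=-2)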
def determine_batch_event_shapes(grid, endpoint_affine):
    """Helper to infer batch_shape and event_shape."""
    with tf.name_scope("determine_batch_event_shapes"):
        # grid  # shape: [B, k, q]
        # endpoint_affine     # len=k, shape: [B, d, d]
        batch_shape = grid.shape[:-2]
        batch_shape_tensor = tf.shape(grid)[:-2]
        event_shape = None
        event_shape_tensor = None

        def _set_event_shape(shape, shape_tensor):
            if event_shape is None:
                return shape, shape_tensor
            return (tf.broadcast_static_shape(event_shape, shape),
                    tf.broadcast_dynamic_shape(event_shape_tensor,
                                               shape_tensor))

        for aff in endpoint_affine:
            if aff.shift is not None:
                batch_shape = tf.broadcast_static_shape(
                    batch_shape, aff.shift.shape[:-1])
                batch_shape_tensor = tf.broadcast_dynamic_shape(
                    batch_shape_tensor,
                    tf.shape(aff.shift)[:-1])
                event_shape, event_shape_tensor = _set_event_shape(
                    aff.shift.shape[-1:],
                    tf.shape(aff.shift)[-1:])

            if aff.scale is not None:
                batch_shape = tf.broadcast_static_shape(
                    batch_shape, aff.scale.batch_shape)
                batch_shape_tensor = tf.broadcast_dynamic_shape(
                    batch_shape_tensor, aff.scale.batch_shape_tensor())
                event_shape, event_shape_tensor = _set_event_shape(
                    tf.TensorShape([aff.scale.range_dimension]),
                    aff.scale.range_dimension_tensor()[tf.newaxis])

        return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor
def broadcast_shape(x_shape, y_shape):
    """Computes the shape of a broadcast.

  When both arguments are statically-known, the broadcasted shape will be
  computed statically and returned as a `TensorShape`.  Otherwise, a rank-1
  `Tensor` will be returned.

  Arguments:
    x_shape: A `TensorShape` or rank-1 integer `Tensor` giving the first shape
      to broadcast.
    y_shape: A `TensorShape` or rank-1 integer `Tensor` giving the second shape
      to broadcast.

  Returns:
    shape: A `TensorShape` or rank-1 integer `Tensor` representing the
      broadcasted shape.
  """
    x_shape_static = tf.get_static_value(x_shape)
    y_shape_static = tf.get_static_value(y_shape)
    if (x_shape_static is None) or (y_shape_static is None):
        return tf.broadcast_dynamic_shape(x_shape, y_shape)

    return tf.broadcast_static_shape(tf.TensorShape(x_shape_static),
                                     tf.TensorShape(y_shape_static))
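A usage sketch for `broadcast_shape` (shapes are made up): when both arguments are statically known the result stays static; inside a `tf.function` whose shapes are only known at run time, the same call would instead return a rank-1 integer `Tensor` via `tf.broadcast_dynamic_shape`.

import tensorflow as tf

static = broadcast_shape(tf.TensorShape([3, 1]), tf.TensorShape([4]))
print(static)  # ==> (3, 4), i.e. TensorShape([3, 4])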
Example #9
 def _batch_shape_tensor(self, low=None, high=None):
     return tf.broadcast_dynamic_shape(
         tf.shape(self.low if low is None else low),
         tf.shape(self.high if high is None else high))
 def _set_event_shape(shape, shape_tensor):
     if event_shape is None:
         return shape, shape_tensor
     return (tf.broadcast_static_shape(event_shape, shape),
             tf.broadcast_dynamic_shape(event_shape_tensor,
                                        shape_tensor))
Example #11
 def _batch_shape_tensor(self, loc=None):
     return tf.broadcast_dynamic_shape(
         tf.shape(self.loc if loc is None else loc),
         tf.broadcast_dynamic_shape(tf.shape(self.atol),
                                    tf.shape(self.rtol)))[:-1]
Example #12
 def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(
         tf.shape(self._probs if self._logits is None else self._logits)
         [:-1], tf.shape(self.total_count))
Example #13
 def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(
         tf.shape(self.df), self.scale_operator.batch_shape_tensor())
Example #14
 def _batch_shape_tensor(self, loc=None, scale=None):
     return tf.broadcast_dynamic_shape(
         tf.shape(self.loc if loc is None else loc),
         tf.shape(self.scale if scale is None else scale))
    def _observation_log_probs(self, observations, mask):
        """Compute and shape tensor of log probs associated with observations.."""

        # Let E be the underlying event shape
        #     M the number of steps in the HMM
        #     N the number of states of the HMM
        #
        # Then the incoming observations have shape
        #
        # observations : batch_o [M] E
        #
        # and the mask (if present) has shape
        #
        # mask : batch_m [M]
        #
        # Let this HMM distribution have batch shape batch_d
        # We need to broadcast all three of these batch shapes together
        # into the shape batch.
        #
        # We need to move the step dimension to the first dimension to make
        # them suitable for folding or scanning over.
        #
        # When we call `log_prob` for our observations we need to
        # do this for each state the observation could correspond to.
        # We do this by expanding the dimensions by 1 so we end up with:
        #
        # observations : [M] batch [1] [E]
        #
        # After calling `log_prob` we get
        #
        # observation_log_probs : [M] batch [N]
        #
        # We wish to use `mask` to select from this so we also
        # reshape and broadcast it up to shape
        #
        # mask : [M] batch [N]

        observation_tensor_shape = tf.shape(observations)
        observation_batch_shape = observation_tensor_shape[
            :-1 - self._underlying_event_rank]
        observation_event_shape = observation_tensor_shape[
            -1 - self._underlying_event_rank:]

        if mask is not None:
            mask_tensor_shape = tf.shape(mask)
            mask_batch_shape = mask_tensor_shape[:-1]

        batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                 self.batch_shape_tensor())

        if mask is not None:
            batch_shape = tf.broadcast_dynamic_shape(batch_shape,
                                                     mask_batch_shape)
        observations = tf.broadcast_to(
            observations,
            tf.concat([batch_shape, observation_event_shape], axis=0))
        observation_rank = tf.rank(observations)
        underlying_event_rank = self._underlying_event_rank
        observations = distribution_util.move_dimension(
            observations, observation_rank - underlying_event_rank - 1, 0)
        observations = tf.expand_dims(observations,
                                      observation_rank - underlying_event_rank)
        observation_log_probs = self._observation_distribution.log_prob(
            observations)

        if mask is not None:
            mask = tf.broadcast_to(
                mask, tf.concat([batch_shape, [self._num_steps]], axis=0))
            mask = distribution_util.move_dimension(mask, -1, 0)
            observation_log_probs = tf.where(
                mask[..., tf.newaxis], tf.zeros_like(observation_log_probs),
                observation_log_probs)

        return observation_log_probs
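The final `tf.where` is the missing-data handling: wherever `mask` is true, the per-state log-probability for that step is replaced by zero, so the step drops out of the forward recursion. A toy illustration of just that masking step (shapes and values are made up):

import tensorflow as tf

observation_log_probs = tf.constant([[-1.2, -0.3],   # step 0, per-state log-probs
                                     [-2.0, -0.7]])  # step 1
mask = tf.constant([False, True])                    # step 1 is a missing observation

masked = tf.where(mask[..., tf.newaxis],
                  tf.zeros_like(observation_log_probs),
                  observation_log_probs)
print(masked)  # ==> [[-1.2, -0.3], [0., 0.]]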
Example #16
 def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(tf.shape(self.concentration),
                                       tf.shape(self.scale))
Example #17
 def _batch_shape_tensor(self, concentration=None, scale=None):
   return tf.broadcast_dynamic_shape(
       tf.shape(
           self.concentration if concentration is None else concentration),
       tf.shape(self.scale if scale is None else scale))
Example #18
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
    """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
      actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.debugging.assert_near(tf.linalg.inv(x), inv_x)
  # ==> True
  ```

  """

    with tf.name_scope(name or 'lu_solve'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
        rhs = tf.convert_to_tensor(rhs,
                                   dtype_hint=lower_upper.dtype,
                                   name='rhs')

        assertions = _lu_solve_assertions(lower_upper, perm, rhs,
                                          validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
                rhs = tf.identity(rhs)

        if (tensorshape_util.rank(rhs.shape) == 2
                and tensorshape_util.rank(perm.shape) == 1):
            # Both rhs and perm have scalar batch_shape.
            permuted_rhs = tf.gather(rhs, perm, axis=-2)
        else:
            # Either rhs or perm has a non-scalar batch_shape, or we can't
            # determine this statically.
            rhs_shape = tf.shape(rhs)
            broadcast_batch_shape = tf.broadcast_dynamic_shape(
                rhs_shape[:-2],
                tf.shape(perm)[:-1])
            d, m = rhs_shape[-2], rhs_shape[-1]
            rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]],
                                            axis=0)

            # Tile out rhs.
            broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
            broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

            # Tile out perm and add batch indices.
            broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
            broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
            broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
            broadcast_batch_indices = tf.broadcast_to(
                tf.range(broadcast_batch_size)[:, tf.newaxis],
                [broadcast_batch_size, d])
            broadcast_perm = tf.stack(
                [broadcast_batch_indices, broadcast_perm], axis=-1)

            permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
            permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
        return linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower_upper,  # Only upper is accessed.
            linear_operator_util.matrix_triangular_solve_with_broadcast(
                lower, permuted_rhs),
            lower=False)
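Because `rhs` must be matrix-shaped, vector systems go through the `rhs[..., tf.newaxis]` trick mentioned in the docstring. A small usage sketch (values are made up):

import tensorflow as tf
import tensorflow_probability as tfp

a = tf.constant([[4., 3.],
                 [6., 3.]])
b = tf.constant([10., 12.])
lu, perm = tf.linalg.lu(a)
x = tfp.math.lu_solve(lu, perm, rhs=b[..., tf.newaxis])[..., 0]
tf.debugging.assert_near(tf.linalg.matvec(a, x), b)  # A @ x == b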
Example #19
 def _batch_shape_tensor(self):
   x = self._probs if self._logits is None else self._logits
   return tf.broadcast_dynamic_shape(
       tf.shape(self._total_count), tf.shape(x))
 def _batch_shape_tensor(self):
   return tf.broadcast_dynamic_shape(
       tf.shape(self.mean_direction)[:-1], tf.shape(self.concentration))
 def _batch_shape_tensor(self, loc=None, concentration=None):
   return tf.broadcast_dynamic_shape(
       tf.shape(self.loc if loc is None else loc),
       tf.shape(
           self.concentration if concentration is None else concentration))