Example #1
 def _call_reshape_input_output(self, fn, x, extra_kwargs=None):
     """Calls `fn`, appropriately reshaping its input `x` and output."""
     # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
     # because the user-provided extra kwargs could themselves contain
     # `fn` and/or `x` as a key.
     with tf.control_dependencies(self._runtime_assertions +
                                  self._validate_sample_arg(x)):
         sample_shape, static_sample_shape = self._sample_shape(x)
         old_shape = tf.concat([
             sample_shape,
             self.distribution.batch_shape_tensor(),
             self.event_shape_tensor(),
         ],
                               axis=0)
         x_reshape = tf.reshape(x, old_shape)
         result = fn(x_reshape, **
                     extra_kwargs) if extra_kwargs else fn(x_reshape)
         new_shape = tf.concat([
             sample_shape,
             self._batch_shape_unexpanded,
         ],
                               axis=0)
         result = tf.reshape(result, new_shape)
         if (tensorshape_util.rank(static_sample_shape) is not None
                 and tensorshape_util.rank(self.batch_shape) is not None):
             new_shape = tensorshape_util.concatenate(
                 static_sample_shape, self.batch_shape)
             tensorshape_util.set_shape(result, new_shape)
         return result
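This helper reads like part of a batch-reshaping wrapper distribution. A minimal usage sketch, assuming it belongs to `tfp.distributions.BatchReshape` (an assumption; only the method is shown above), where `log_prob` is routed through exactly this reshape-call-reshape path:

# Hedged usage sketch: assumes the method above is from tfd.BatchReshape.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([6]), scale=1.)          # batch shape [6]
reshaped = tfd.BatchReshape(base, batch_shape=[2, 3])   # batch shape [2, 3]

x = tf.zeros([5, 2, 3])            # sample shape [5] + batch shape [2, 3]
lp = reshaped.log_prob(x)          # x is reshaped to [5, 6] for the base
print(lp.shape)                    # (5, 2, 3) after reshaping back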
Example #2
def _sparse_tensor_dense_matmul(sp_a, b, **kwargs):
    """Returns (batched) matmul of a SparseTensor with a Tensor.

  Args:
    sp_a: `SparseTensor` representing a (batch of) matrices.
    b: `Tensor` representing a (batch of) matrices, with the same batch shape as
      `sp_a`. The shape must be compatible with the shape of `sp_a` and kwargs.
    **kwargs: Keyword arguments to `tf.sparse.sparse_dense_matmul`.

  Returns:
    product: A dense (batch of) matrix-shaped Tensor of the same batch shape and
    dtype as `sp_a` and `b`. If `sp_a` or `b` is adjointed through `kwargs` then
    the shape is adjusted accordingly.
  """
    batch_shape = _get_shape(sp_a)[:-2]

    # Reshape the SparseTensor into a rank-3 SparseTensor, with the batch
    # shape flattened to a single dimension. If the batch rank is 0, then we
    # add a batch dimension of size 1.
    sp_a = tf.sparse.reshape(sp_a,
                             tf.concat([[-1], _get_shape(sp_a)[-2:]], axis=0))
    # Reshape b to stack the batch dimension along the rows.
    b = tf.reshape(b, tf.concat([[-1], _get_shape(b)[-1:]], axis=0))

    # Convert the SparseTensor to a matrix in block diagonal form with blocks of
    # matrices [M, N]. This allows us to use tf.sparse.sparse_dense_matmul,
    # which only accepts rank-2 (Sparse)Tensors.
    out = tf.sparse.sparse_dense_matmul(_sparse_block_diag(sp_a), b, **kwargs)

    # Finally retrieve the original batch shape from the resulting rank 2 Tensor.
    # Note that we avoid inferring the final shape from `sp_a` or `b` because we
    # might have transposed one or both of them.
    return tf.reshape(
        out,
        tf.concat([batch_shape, [-1], _get_shape(out)[-1:]], axis=0))
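A minimal sketch of the rank-2 base case that motivates the block-diagonal trick: `tf.sparse.sparse_dense_matmul` itself only accepts rank-2 operands, so batches must be flattened first (the `_sparse_block_diag` and `_get_shape` helpers used above are not shown here).

# Rank-2 base case: the batched helper above reduces to this via a
# block-diagonal rewrite of the sparse operand.
import tensorflow as tf

sp_a = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]],
                              values=[1., 2.],
                              dense_shape=[2, 3])
b = tf.ones([3, 4])

out = tf.sparse.sparse_dense_matmul(sp_a, b)
print(out.shape)  # (2, 4)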
Example #3
def cholesky_concat(chol, cols, name=None):
    """Concatenates `chol @ chol.T` with additional rows and columns.

  This operation is conceptually identical to:
  ```python
  def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
    mat = tf.matmul(chol, chol, adjoint_b=True)  # batch of n x n
    # Concat columns.
    mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
    # Concat rows.
    mat = tf.concat([mat, tf.linalg.matrix_transpose(cols)], axis=-2)  # z x z
    return tf.linalg.cholesky(mat)
  ```
  but whereas `cholesky_concat_slow` would cost `O(z**3)` work,
  `cholesky_concat` only costs `O(z**2 + m**3)` work.

  The resulting (implicit) matrix must be symmetric and positive definite.
  Thus, the bottom right `m x m` must be self-adjoint, and we do not require a
  separate `rows` argument (which can be inferred from `conj(cols.T)`).

  Args:
    chol: Cholesky decomposition of `mat = chol @ chol.T`.
    cols: The new columns whose first `n` rows we would like concatenated to the
      right of `mat = chol @ chol.T`, and whose conjugate transpose we would
      like concatenated to the bottom of `concat(mat, cols[:n,:])`. A `Tensor`
      with final dims `(n+m, m)`. The first `n` rows are the top right rectangle
      (their conjugate transpose forms the bottom left), and the bottom `m x m`
      is self-adjoint.
    name: Optional name for this op.

  Returns:
    chol_concat: The Cholesky decomposition of:
      ```
      [ [ mat  cols[:n, :] ]
        [   conj(cols.T)   ] ]
      ```
  """
    with tf.name_scope(name or 'cholesky_extend'):
        dtype = dtype_util.common_dtype([chol, cols], dtype_hint=tf.float32)
        chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
        cols = tf.convert_to_tensor(cols, name='cols', dtype=dtype)
        n = prefer_static.shape(chol)[-1]
        mat_nm, mat_mm = cols[..., :n, :], cols[..., n:, :]
        solved_nm = linear_operator_util.matrix_triangular_solve_with_broadcast(
            chol, mat_nm)
        lower_right_mm = tf.linalg.cholesky(
            mat_mm - tf.matmul(solved_nm, solved_nm, adjoint_a=True))
        lower_left_mn = tf.math.conj(tf.linalg.matrix_transpose(solved_nm))
        out_batch = prefer_static.shape(solved_nm)[:-2]
        chol = tf.broadcast_to(
            chol,
            tf.concat([out_batch, prefer_static.shape(chol)[-2:]], axis=0))
        top_right_zeros_nm = tf.zeros_like(solved_nm)
        return tf.concat([
            tf.concat([chol, top_right_zeros_nm], axis=-1),
            tf.concat([lower_left_mn, lower_right_mm], axis=-1)
        ],
                         axis=-2)
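A numeric sanity check of the rank-update idea, assuming the function is available as `tfp.math.cholesky_concat` (the public TFP name matching this signature): extend a factored 3x3 matrix by one row/column and compare against refactorizing from scratch.

# Hedged sketch: assumes tfp.math.cholesky_concat exposes the function above.
import tensorflow as tf
import tensorflow_probability as tfp

a = tf.constant([[2.0, 0.5, 0.1],
                 [0.5, 3.0, 0.2],
                 [0.1, 0.2, 4.0]])       # SPD, n = 3
chol = tf.linalg.cholesky(a)

# One new column (m = 1): the first n rows are the new off-diagonal block,
# the trailing 1x1 block is the new diagonal entry (keeps the matrix SPD).
cols = tf.constant([[0.3], [0.1], [0.2], [5.0]])

chol_ext = tfp.math.cholesky_concat(chol, cols)

# Reference: assemble the 4x4 matrix explicitly and refactorize (O(z**3)).
full = tf.concat([tf.concat([a, cols[:3]], axis=-1),
                  tf.linalg.matrix_transpose(cols)], axis=-2)
print(tf.reduce_max(tf.abs(chol_ext - tf.linalg.cholesky(full))))  # ~0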
Example #4
    def _variance(self):
        with tf.control_dependencies(self._runtime_assertions):
            probs = self._marginal_hidden_probs()
            # probs :: num_steps batch_shape num_states
            means = self._observation_distribution.mean()
            # means :: observation_batch_shape[:-1] num_states
            #          observation_event_shape
            means_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_states],
                self._observation_distribution.event_shape_tensor()
            ],
                                    axis=0)
            means = tf.broadcast_to(means, means_shape)
            # means :: batch_shape num_states observation_event_shape

            observation_event_shape = (
                self._observation_distribution.event_shape_tensor())
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
            flat_probs_shape = [self._num_steps, batch_size, self._num_states]
            flat_means_shape = [
                batch_size, 1, self._num_states,
                tf.reduce_prod(observation_event_shape)
            ]

            flat_probs = tf.reshape(probs, flat_probs_shape)
            # flat_probs :: num_steps batch_size num_states
            flat_means = tf.reshape(means, flat_means_shape)
            # flat_means :: batch_size 1 num_states observation_event_size
            flat_mean = tf.einsum("ijk,jmkl->jiml", flat_probs, flat_means)
            # flat_mean :: batch_size num_steps 1 observation_event_size

            variances = self._observation_distribution.variance()
            variances = tf.broadcast_to(variances, means_shape)
            # variances :: batch_shape num_states observation_event_shape
            flat_variances = tf.reshape(variances, flat_means_shape)
            # flat_variances :: batch_size 1 num_states observation_event_size

            # For a mixture of n distributions with mixture probabilities
            # p[i], and where the individual distributions have means and
            # variances given by mean[i] and var[i], the variance of
            # the mixture is given by:
            #
            # var = sum_{i=1..n} p[i] * ((mean[i] - mean)**2 + var[i])

            flat_variance = tf.einsum("ijk,jikl->jil", flat_probs,
                                      (flat_means - flat_mean)**2 +
                                      flat_variances)
            # flat_variance :: batch_size num_steps observation_event_size

            unflat_mean_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_steps],
                observation_event_shape
            ],
                                          axis=0)

            # returns :: batch_shape num_steps observation_event_shape
            return tf.reshape(flat_variance, unflat_mean_shape)
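The mixture-variance identity in the comment above (with `var[i]`, not squared) is just the law of total variance; a quick NumPy check on a toy mixture:

# Numeric check of var = sum_i p[i] * ((mean[i] - mean)**2 + var[i]).
import numpy as np

p = np.array([0.2, 0.5, 0.3])        # mixture weights
mean = np.array([-1.0, 0.0, 2.0])    # component means
var = np.array([0.5, 1.0, 0.25])     # component variances

mix_mean = np.sum(p * mean)
mix_var = np.sum(p * ((mean - mix_mean) ** 2 + var))

rng = np.random.default_rng(0)
comp = rng.choice(3, size=200_000, p=p)
samples = rng.normal(mean[comp], np.sqrt(var[comp]))
print(mix_var, samples.var())        # agree to roughly two decimals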
Example #5
        def body(m, pchol, perm, matrix_diag):
            """Body of a single `tf.while_loop` iteration."""
            # Here is roughly a numpy, non-batched version of what's going to happen.
            # (See also Algorithm 1 of Harbrecht et al.)
            # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
            # 2: maxval = matrix_diag[perm][maxi]
            # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
            # 4: row = matrix[perm[m]][perm[m + 1:]]
            # 5: row -= np.sum(pchol[:m, perm[m + 1:]] * pchol[:m, perm[m:m + 1]], axis=-2)
            # 6: pivot = np.sqrt(maxval); row /= pivot
            # 7: row = np.concatenate([[[pivot]], row], -1)
            # 8: matrix_diag[perm[m:]] -= row**2
            # 9: pchol[m, perm[m:]] = row

            # Find the maximal position of the (remaining) permuted diagonal.
            # Steps 1, 2 above.
            permuted_diag = batch_gather(matrix_diag, perm[..., m:])
            maxi = tf.argmax(permuted_diag, axis=-1,
                             output_type=tf.int64)[..., tf.newaxis]
            maxval = batch_gather(permuted_diag, maxi)
            maxi = maxi + m
            maxval = maxval[..., 0]
            # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
            perm = _swap_m_with_i(perm, m, maxi)
            # Step 4.
            row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
            row = batch_gather(row, perm[..., m + 1:])
            # Step 5.
            prev_rows = pchol[..., :m, :]
            prev_rows_perm_m_onward = batch_gather(prev_rows, perm[...,
                                                                   m + 1:])
            prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
            row -= tf.reduce_sum(prev_rows_perm_m_onward * prev_rows_pivot_col,
                                 axis=-2)[..., tf.newaxis, :]
            # Step 6.
            pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
            # Step 7.
            row = tf.concat([pivot, row / pivot], axis=-1)
            # TODO(b/130899118): Pad grad fails with int64 paddings.
            # Step 8.
            paddings = tf.concat([
                tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
                [[tf.cast(m, tf.int32), 0]]
            ],
                                 axis=0)
            diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
            reverse_perm = _invert_permutation(perm)
            matrix_diag -= batch_gather(diag_update, reverse_perm)
            # Step 9.
            row = tf.pad(row, paddings=paddings)
            # TODO(bjp): Defer the reverse permutation all-at-once at the end?
            row = batch_gather(row, reverse_perm)
            pchol_shape = pchol.shape
            pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                              axis=-2)
            tensorshape_util.set_shape(pchol, pchol_shape)
            return m + 1, pchol, perm, matrix_diag
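Assuming this loop body belongs to `tfp.math.pivoted_cholesky` (it follows Algorithm 1 of Harbrecht et al., as the comments note), a hedged usage sketch of the public entry point, which returns a low-rank factor `lr` with `lr @ lr.T ≈ matrix`:

# Hedged usage sketch; assumes the while-loop body above backs
# tfp.math.pivoted_cholesky.
import tensorflow as tf
import tensorflow_probability as tfp

u = tf.random.normal([6, 2], seed=1)
matrix = tf.matmul(u, u, adjoint_b=True) + 1e-3 * tf.eye(6)  # PSD, ~rank 2

lr = tfp.math.pivoted_cholesky(matrix, max_rank=3)
approx = tf.matmul(lr, lr, adjoint_b=True)
print(tf.reduce_max(tf.abs(approx - matrix)))  # small reconstruction error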
Example #6
    def _log_prob(self, value):
        with tf.control_dependencies(self._runtime_assertions):
            # The argument `value` is a tensor of sequences of observations.
            # `observation_batch_shape` is the shape of that tensor with the
            # sequence part removed.
            # `observation_batch_shape` is then broadcast to the full batch shape
            # to give the `batch_shape` that defines the shape of the result.

            observation_tensor_shape = tf.shape(value)
            observation_batch_shape = observation_tensor_shape[
                :-1 - self._underlying_event_rank]
            # value :: observation_batch_shape num_steps observation_event_shape
            batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                     self.batch_shape_tensor())
            log_init = tf.broadcast_to(
                self._log_init,
                tf.concat([batch_shape, [self._num_states]], axis=0))
            # log_init :: batch_shape num_states
            log_transition = self._log_trans

            # `observation_event_shape` is the shape of each sequence of observations
            # emitted by the model.
            observation_event_shape = observation_tensor_shape[
                -1 - self._underlying_event_rank:]
            working_obs = tf.broadcast_to(
                value, tf.concat([batch_shape, observation_event_shape],
                                 axis=0))
            # working_obs :: batch_shape observation_event_shape
            r = self._underlying_event_rank

            # Move index into sequence of observations to front so we can apply
            # tf.foldl
            working_obs = distribution_util.move_dimension(
                working_obs, -1 - r, 0)[..., tf.newaxis]
            # working_obs :: num_steps batch_shape underlying_event_shape
            observation_probs = (
                self._observation_distribution.log_prob(working_obs))

            def forward_step(log_prev_step, log_prob_observation):
                return _log_vector_matrix(
                    log_prev_step, log_transition) + log_prob_observation

            fwd_prob = tf.foldl(forward_step,
                                observation_probs,
                                initializer=log_init)
            # fwd_prob :: batch_shape num_states

            log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
            # log_prob :: batch_shape

            return log_prob
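`_variance` and `_log_prob` above read like methods of `tfp.distributions.HiddenMarkovModel`; a hedged end-to-end sketch of that public class, whose `log_prob` runs exactly this forward recursion:

# Hedged usage sketch, assuming tfd.HiddenMarkovModel is the class these
# methods come from.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 5.], scale=[1., 1.]),
    num_steps=7)

obs = tf.zeros([7])              # one length-7 observation sequence
print(hmm.log_prob(obs))         # scalar: forward algorithm over num_steps
print(hmm.variance().shape)      # (7,): per-step marginal mixture variance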
Example #7
 def _compute_quantiles():
   """Helper to build quantiles."""
   # Omit {0, 1} since they might lead to Inf/NaN.
   zero = tf.zeros([], dtype=dist.dtype)
   edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
   # Expand edges so it broadcasts across batch dims.
   edges = tf.reshape(
       edges,
       shape=tf.concat(
           [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
   quantiles = dist.quantile(edges)
   # Cyclically permute left by one.
   perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
   quantiles = tf.transpose(a=quantiles, perm=perm)
   return quantiles
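A concrete, hedged instantiation of the quantile grid above, assuming `dist` is a batch of Normals, `quadrature_size = 5`, and `batch_ndims = 1` (these names come from the enclosing scope, which is not shown here):

# Hedged sketch of the quantile grid with concrete values for the free
# variables `dist`, `quadrature_size` and `batch_ndims`.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.Normal(loc=[0., 10.], scale=[1., 2.])   # batch shape [2]
quadrature_size = 5
batch_ndims = 1

zero = tf.zeros([], dtype=dist.dtype)
edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]  # 6 interior points
edges = tf.reshape(edges, shape=[-1, 1])                  # broadcast over batch
quantiles = dist.quantile(edges)                          # shape [6, 2]
quantiles = tf.transpose(quantiles, perm=[1, 0])          # shape [2, 6]
print(quantiles.shape)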
Example #8
 def _sample_3d(self, n, seed=None):
     """Specialized inversion sampler for 3D."""
     seed = SeedStream(seed, salt='von_mises_fisher_3d')
     u_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
     z = tf.random.uniform(u_shape, seed=seed(), dtype=self.dtype)
     # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
     # be bisected for bounded sampling runtime (i.e. not rejection sampling).
     # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
     # The inversion is: u = 1 + log(z + (1-z)*exp(-2*kappa)) / kappa
     # We must protect against both kappa and z being zero.
     safe_conc = tf.where(self.concentration > 0, self.concentration,
                          tf.ones_like(self.concentration))
     safe_z = tf.where(z > 0, z, tf.ones_like(z))
     safe_u = 1 + tf.reduce_logsumexp(
         [tf.math.log(safe_z),
          tf.math.log1p(-safe_z) - 2 * safe_conc],
         axis=0) / safe_conc
     # Limit of the above expression as kappa->0 is 2*z-1
     u = tf.where(self.concentration > tf.zeros_like(safe_u), safe_u,
                  2 * z - 1)
     # Limit of the expression as z->0 is -1.
     u = tf.where(tf.equal(z, 0), -tf.ones_like(u), u)
     if not self._allow_nan_stats:
         u = tf.debugging.check_numerics(u, 'u in _sample_3d')
     return u[..., tf.newaxis]
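The `reduce_logsumexp` expression above is a numerically stable form of `u = 1 + log(z + (1 - z) * exp(-2 * kappa)) / kappa`; a quick NumPy check of that identity and of its `kappa -> 0` limit `2 * z - 1`:

# Numeric check of the 3D inversion formula and its small-kappa limit.
import numpy as np

def u_of_z(z, kappa):
    return 1. + np.log(z + (1. - z) * np.exp(-2. * kappa)) / kappa

z = np.linspace(0.05, 0.95, 5)
print(u_of_z(z, 2.0))                  # values lie in (-1, 1)
print(u_of_z(z, 1e-8))                 # approaches the limit below
print(2. * z - 1.)                     # kappa -> 0 limit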
Example #9
 def _sample_n(self, n, seed=None):
     # See https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution or
     # https://www.jstor.org/stable/2683801
     concentration = tf.convert_to_tensor(self.concentration)
     loc = tf.convert_to_tensor(self.loc)
     seed = SeedStream(seed, 'inverse_gaussian')
     shape = tf.concat(
         [[n],
          self._batch_shape_tensor(loc=loc, concentration=concentration)],
         axis=0)
     sampled_chi2 = (tf.random.normal(shape,
                                      mean=0.,
                                      stddev=1.,
                                      seed=seed(),
                                      dtype=self.dtype))**2.
     sampled_uniform = tf.random.uniform(shape,
                                         minval=0.,
                                         maxval=1.,
                                         seed=seed(),
                                         dtype=self.dtype)
     sampled = (loc + loc**2. * sampled_chi2 / (2. * concentration) - loc /
                (2. * concentration) *
                (4. * loc * concentration * sampled_chi2 +
                 (loc * sampled_chi2)**2)**0.5)
     return tf.where(sampled_uniform <= loc / (loc + sampled), sampled,
                     loc**2 / sampled)
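A hedged usage sketch, assuming this sampler backs `tfp.distributions.InverseGaussian`; the transformation above should reproduce the distribution's mean, which equals `loc`:

# Hedged sketch; assumes tfd.InverseGaussian uses the sampler above.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

d = tfd.InverseGaussian(loc=2., concentration=3.)
samples = d.sample(100000, seed=42)
print(tf.reduce_mean(samples))   # ~2.0
print(d.mean())                  # 2.0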
Example #10
    def _call_sample_n(self, sample_shape, seed, name, **kwargs):
        # We override `_call_sample_n` rather than `_sample_n` so we can ensure that
        # the result of `self.bijector.forward` is not modified (and thus caching
        # works).
        with self._name_and_control_scope(name):
            sample_shape = tf.convert_to_tensor(sample_shape,
                                                dtype=tf.int32,
                                                name="sample_shape")
            sample_shape, n = self._expand_sample_shape_to_vector(
                sample_shape, "sample_shape")

            distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(
                kwargs)

            # First, generate samples. We will possibly generate extra samples in the
            # event that we need to reinterpret the samples as part of the
            # event_shape.
            x = self._sample_n(n, seed, **distribution_kwargs)

            # Next, we reshape `x` into its final form. We do this prior to the call
            # to the bijector to ensure that the bijector caching works.
            batch_event_shape = tf.shape(x)[1:]
            final_shape = tf.concat([sample_shape, batch_event_shape], 0)
            x = tf.reshape(x, final_shape)

            # Finally, we apply the bijector's forward transformation. For caching to
            # work, it is imperative that this is the last modification to the
            # returned result.
            y = self.bijector.forward(x, **bijector_kwargs)
            y = self._set_sample_static_shape(y, sample_shape)

            return y
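The caching concern explained above exists because `bijector.forward(x)` memoizes the `x <-> y` pair so a subsequent `log_prob(y)` can skip the inverse; a hedged usage sketch with `tfd.TransformedDistribution`:

# Hedged usage sketch of the sample -> reshape -> bijector.forward pipeline.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# A log-normal built as exp(Normal): sampling draws x from the Normal,
# reshapes to the requested sample shape, then applies Exp last so the
# x <-> y pair stays cached for a following log_prob call.
log_normal = tfd.TransformedDistribution(distribution=tfd.Normal(0., 1.),
                                         bijector=tfb.Exp())
y = log_normal.sample([3, 2], seed=7)
print(y.shape)                       # (3, 2)
print(log_normal.log_prob(y).shape)  # (3, 2)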
Example #11
  def _mode(self, samples=None):
    # Sample counts can vary by batch member. Use map_fn to compute the mode
    # for each batch member separately.
    def _get_mode(samples):
      # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
      count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
      return tf.argmax(count)

    if samples is None:
      samples = tf.convert_to_tensor(self._samples)
    num_samples = self._compute_num_samples(samples)

    # Flatten samples for each batch.
    if self._event_ndims == 0:
      flattened_samples = tf.reshape(samples, [-1, num_samples])
      mode_shape = self._batch_shape_tensor(samples)
    else:
      event_size = tf.reduce_prod(self._event_shape_tensor(samples))
      mode_shape = tf.concat(
          [self._batch_shape_tensor(samples),
           self._event_shape_tensor(samples)],
          axis=0)
      flattened_samples = tf.reshape(samples, [-1, num_samples, event_size])

    indices = tf.map_fn(_get_mode, flattened_samples, dtype=tf.int64)
    full_indices = tf.stack(
        [tf.range(tf.shape(indices)[0]),
         tf.cast(indices, tf.int32)], axis=1)

    mode = tf.gather_nd(flattened_samples, full_indices)
    return tf.reshape(mode, mode_shape)
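Per batch member, the mode computation reduces to unique-with-counts followed by an argmax; a minimal single-batch sketch using the public `tf.unique_with_counts` instead of the internal op:

# Minimal single-batch sketch of the mode computation.
import tensorflow as tf

samples = tf.constant([3, 1, 3, 2, 3, 1])
values, _, counts = tf.unique_with_counts(samples)
mode = tf.gather(values, tf.argmax(counts))
print(mode)  # 3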
Example #12
 def _call_and_reshape_output(self,
                              fn,
                              event_shape_list=None,
                              static_event_shape_list=None,
                              extra_kwargs=None):
     """Calls `fn` and appropriately reshapes its output."""
     # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
      # because the user-provided extra kwargs could themselves contain `fn`,
      # `event_shape_list`, `static_event_shape_list`, and/or `extra_kwargs`
      # as keys.
     with tf.control_dependencies(self._runtime_assertions):
         if event_shape_list is None:
             event_shape_list = [self._event_shape_tensor()]
         if static_event_shape_list is None:
             static_event_shape_list = [self.event_shape]
         new_shape = tf.concat([self._batch_shape_unexpanded] +
                               event_shape_list,
                               axis=0)
         result = tf.reshape(
             fn(**extra_kwargs) if extra_kwargs else fn(), new_shape)
         if (tensorshape_util.rank(self.batch_shape) is not None
                 and tensorshape_util.rank(self.event_shape) is not None):
             event_shape = tf.TensorShape([])
             for rss in static_event_shape_list:
                 event_shape = tensorshape_util.concatenate(
                     event_shape, rss)
             static_shape = tensorshape_util.concatenate(
                 self.batch_shape, event_shape)
             tensorshape_util.set_shape(result, static_shape)
         return result
Example #13
 def _sample_n(self, n, seed=None):
     low = tf.convert_to_tensor(self.low)
     high = tf.convert_to_tensor(self.high)
     shape = tf.concat(
         [[n], self._batch_shape_tensor(low=low, high=high)], 0)
     samples = tf.random.uniform(shape=shape, dtype=self.dtype, seed=seed)
     return low + self._range(low=low, high=high) * samples
Example #14
    def _cdf(self, k):
        # TODO(b/135263541): Improve numerical precision of categorical.cdf.
        probs = self.probs_parameter()
        num_categories = self._num_categories(probs)

        k, probs = _broadcast_cat_event_and_params(
            k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

        # Since the lowest number in the support is 0, any k < 0 should be zero in
        # the output.
        should_be_zero = k < 0

        # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
        k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

        batch_shape = tf.shape(k)

        # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
        # to handle the case where the batch shape is dynamic, flatten the batch
        # dims (so we know batch_dims=1).
        k_flat_batch = tf.reshape(k, [-1])
        probs_flat_batch = tf.reshape(
            probs, tf.concat(([-1], [num_categories]), axis=0))

        cdf_flat = tf.gather(tf.cumsum(probs_flat_batch, axis=-1),
                             k_flat_batch[..., tf.newaxis],
                             batch_dims=1)

        cdf = tf.reshape(cdf_flat, shape=batch_shape)

        zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
        return tf.where(should_be_zero, zero, cdf)
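The heart of the computation is a cumulative sum over categories followed by a batched gather at `k`; a minimal flat-batch sketch:

# Minimal sketch of the cumsum + gather CDF lookup on a flat batch.
import tensorflow as tf

probs = tf.constant([[0.1, 0.2, 0.7],
                     [0.5, 0.25, 0.25]])     # batch of 2, 3 categories
k = tf.constant([1, 0])

cdf = tf.gather(tf.cumsum(probs, axis=-1),   # running totals per batch row
                k[..., tf.newaxis],
                batch_dims=1)[..., 0]
print(cdf)  # ~[0.3, 0.5]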
Example #15
def _swap_m_with_i(vecs, m, i):
    """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

  Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly shaped
  per-vector indices `i`, this function swaps elements `m` and `i` in each
  vector. For the use-case below, these are permutation vectors.

  Args:
    vecs: Vectors on which we perform the swap, int64 `Tensor`.
    m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
    i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
      which the `m`th element is going.

  Returns:
    vecs: The updated vectors.
  """
    vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
    m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
    i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
    trailing_elts = tf.broadcast_to(
        tf.range(m + 1,
                 prefer_static.shape(vecs, out_type=tf.int64)[-1]),
        prefer_static.shape(vecs[..., m + 1:]))
    trailing_elts = tf.where(tf.equal(trailing_elts, i),
                             tf.gather(vecs, [m], axis=-1), vecs[..., m + 1:])
    # TODO(bjp): Could we use tensor_scatter_nd_update?
    vecs_shape = vecs.shape
    vecs = tf.concat([
        vecs[..., :m],
        tf.gather(vecs, i, batch_dims=int(prefer_static.rank(vecs)) - 1),
        trailing_elts
    ],
                     axis=-1)
    tensorshape_util.set_shape(vecs, vecs_shape)
    return vecs
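A plain NumPy sketch of the swap semantics documented above, for a small batch of permutation vectors (positions `m` and `i` are exchanged along the last axis, independently per batch member):

# NumPy sketch of the swap semantics documented above.
import numpy as np

vecs = np.array([[0, 1, 2, 3],
                 [3, 2, 1, 0]])
m = 1
i = np.array([[3],    # swap positions 1 and 3 in the first vector
              [2]])   # swap positions 1 and 2 in the second vector

out = vecs.copy()
rows = np.arange(vecs.shape[0])
out[rows, m], out[rows, i[:, 0]] = vecs[rows, i[:, 0]], vecs[rows, m]
print(out)  # [[0 3 2 1], [3 1 2 0]]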
Example #16
    def _make_columnar(self, x):
        """Ensures non-scalar input has at least one column.

    Example:
      If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`.

      If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.

      If `x = 1` then the output is unchanged.

    Args:
      x: `Tensor`.

    Returns:
      columnar_x: `Tensor` with at least two dimensions.
    """
        if tensorshape_util.rank(x.shape) is not None:
            if tensorshape_util.rank(x.shape) == 1:
                x = x[tf.newaxis, :]
            return x
        shape = tf.shape(x)
        maybe_expanded_shape = tf.concat([
            shape[:-1],
            distribution_util.pick_vector(tf.equal(tf.rank(x), 1), [1],
                                          np.array([], dtype=np.int32)),
            shape[-1:],
        ], 0)
        return tf.reshape(x, maybe_expanded_shape)
Example #17
    def _marginal_hidden_probs(self):
        """Compute marginal pdf for each individual observable."""

        initial_log_probs = tf.broadcast_to(
            self._log_init,
            tf.concat([self.batch_shape_tensor(), [self._num_states]], axis=0))

        # initial_log_probs :: batch_shape num_states

        def _scan_multiple_steps():
            """Perform `scan` operation when `num_steps` > 1."""

            transition_log_probs = self._log_trans

            def forward_step(log_probs, _):
                return _log_vector_matrix(log_probs, transition_log_probs)

            dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)

            forward_log_probs = tf.scan(forward_step,
                                        dummy_index,
                                        initializer=initial_log_probs,
                                        name="forward_log_probs")

            return tf.concat([[initial_log_probs], forward_log_probs], axis=0)

        forward_log_probs = prefer_static.cond(
            self._num_steps > 1, _scan_multiple_steps,
            lambda: initial_log_probs[tf.newaxis, ...])

        return tf.exp(forward_log_probs)
Example #18
    def _sample_n(self, n, seed=None):
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)

        stream = SeedStream(seed, salt='triangular')
        shape = tf.concat(
            [[n], self._batch_shape_tensor(low=low, high=high, peak=peak)],
            axis=0)
        samples = tf.random.uniform(shape=shape,
                                    dtype=self.dtype,
                                    seed=stream())
        # We use Inverse CDF sampling here. Because the CDF is a quadratic function,
        # we must use sqrts here.
        interval_length = high - low
        return tf.where(
            # Note the CDF on the left side of the peak is
            # (x - low) ** 2 / ((high - low) * (peak - low)).
            # If we plug in peak for x, we get that the CDF at the peak
            # is (peak - low) / (high - low). Because of this we decide
            # which branch of the piecewise CDF to invert, based on the uniform
            # samples we drew.
            samples < (peak - low) / interval_length,
            # Inverse of (x - low) ** 2 / ((high - low) * (peak - low)).
            low + tf.sqrt(samples * interval_length * (peak - low)),
            # Inverse of 1 - (high - x) ** 2 / ((high - low) * (high - peak))
            high - tf.sqrt((1. - samples) * interval_length * (high - peak)))
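The two inverse-CDF branches can be checked against the distribution's own CDF; a hedged sketch, assuming `tfd.Triangular` exposes this parameterization:

# Hedged check that the piecewise inverse CDF above matches the CDF.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

low, high, peak = 0., 4., 1.
d = tfd.Triangular(low=low, high=high, peak=peak)

u = tf.constant([0.1, 0.25, 0.6, 0.9])
interval_length = high - low
x = tf.where(u < (peak - low) / interval_length,
             low + tf.sqrt(u * interval_length * (peak - low)),
             high - tf.sqrt((1. - u) * interval_length * (high - peak)))
print(d.cdf(x))  # ~[0.1, 0.25, 0.6, 0.9]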
Example #19
 def _sample_n(self, n, seed=None):
   loc = tf.convert_to_tensor(self.loc)
   scale = tf.convert_to_tensor(self.scale)
   shape = tf.concat([[n], self._batch_shape_tensor(loc=loc, scale=scale)],
                     axis=0)
   sampled = tf.random.normal(
       shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)
   return sampled * scale + loc
Example #20
    def _entropy(self, **kwargs):
        if not self.bijector.is_constant_jacobian:
            raise NotImplementedError("entropy is not implemented")
        if not self.bijector._is_injective:  # pylint: disable=protected-access
            raise NotImplementedError("entropy is not implemented when "
                                      "bijector is not injective.")
        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
        # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
        # can be shown that:
        #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
        # If is_constant_jacobian then:
        #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
        # where c can be anything.
        entropy = self.distribution.entropy(**distribution_kwargs)
        if self._is_maybe_event_override:
            # H[X] = sum_i H[X_i] if X_i are mutually independent.
            # This means that a reduce_sum is a simple rescaling.
            entropy = entropy * tf.cast(
                tf.reduce_prod(self._override_event_shape),
                dtype=dtype_util.base_dtype(entropy.dtype))
        if self._is_maybe_batch_override:
            new_shape = tf.concat([
                prefer_static.ones_like(self._override_batch_shape),
                self.distribution.batch_shape_tensor()
            ], 0)
            entropy = tf.reshape(entropy, new_shape)
            multiples = tf.concat([
                self._override_batch_shape,
                prefer_static.ones_like(self.distribution.batch_shape_tensor())
            ], 0)
            entropy = tf.tile(entropy, multiples)
        dummy = prefer_static.zeros(shape=tf.concat(
            [self.batch_shape_tensor(),
             self.event_shape_tensor()], 0),
                                    dtype=self.dtype)
        event_ndims = (tensorshape_util.rank(self.event_shape)
                       if tensorshape_util.rank(self.event_shape) is not None
                       else tf.size(self.event_shape_tensor()))
        ildj = self.bijector.inverse_log_det_jacobian(dummy,
                                                      event_ndims=event_ndims,
                                                      **bijector_kwargs)

        entropy = entropy - tf.cast(ildj, entropy.dtype)
        tensorshape_util.set_shape(entropy, self.batch_shape)
        return entropy
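The constant-Jacobian identity H[Y] = H[X] + log|det J| can be checked with a pure scale bijector; a hedged sketch (assumes a TFP version that ships `tfb.Scale`):

# Hedged check of H[Y] = H[X] + log|det J| for a constant-Jacobian bijector.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

base = tfd.Normal(loc=0., scale=1.)
scaled = tfd.TransformedDistribution(base, tfb.Scale(3.))

print(scaled.entropy())                        # H[N(0, 1)] + log(3)
print(base.entropy() + tf.math.log(3.))        # same value
print(tfd.Normal(loc=0., scale=3.).entropy())  # same value again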
Example #21
 def _rotate(self, samples):
     """Applies a Householder rotation to `samples`."""
     event_dim = (tf.compat.dimension_value(self.event_shape[0])
                  or self._event_shape_tensor()[0])
      basis = tf.concat(
          [[1.], tf.zeros([event_dim - 1], dtype=self.dtype)], axis=0)
     u = tf.math.l2_normalize(basis - self.mean_direction, axis=-1)
     return samples - 2 * tf.reduce_sum(samples * u, axis=-1,
                                        keepdims=True) * u
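With `u` proportional to `e1 - mean_direction`, the Householder reflection `x - 2 (x . u) u` exchanges the first basis vector and the mean direction, which is what rotates samples generated around `e1` onto the mean direction; a small NumPy check:

# NumPy check: the Householder reflection swaps e1 and the mean direction.
import numpy as np

mu = np.array([0.6, 0.8, 0.0])             # unit-norm mean direction
e1 = np.array([1.0, 0.0, 0.0])

u = (e1 - mu) / np.linalg.norm(e1 - mu)

def reflect(x):
    return x - 2.0 * (x @ u) * u

print(reflect(e1))   # ~[0.6, 0.8, 0.0]
print(reflect(mu))   # ~[1.0, 0.0, 0.0]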
Example #22
 def _inverse_event_shape_tensor(self, output_shape):
   if self.validate_args:
     # A shape dimension cannot be negative, so it suffices to check that the
     # last dimension is greater than 1.
     dependencies = [assert_util.assert_greater(
         output_shape[-1], 1, message="Need last dimension greater than 1.")]
   else:
     dependencies = []
   with tf.control_dependencies(dependencies):
     return tf.concat([output_shape[:-1], [output_shape[-1] - 1]], axis=0)
Example #23
 def _sample_n(self, n, seed=None):
     del seed  # unused
     loc = tf.convert_to_tensor(self.loc)
     return tf.broadcast_to(
         loc,
         tf.concat([[n],
                    self._batch_shape_tensor(loc=loc),
                    self._event_shape_tensor(loc=loc)],
                   axis=0))
Example #24
def _extract_log_probs(num_states, dist):
    """Tabulate log probabilities from a batch of distributions."""

    states = tf.reshape(
        tf.range(num_states),
        tf.concat([[num_states],
                   tf.ones_like(dist.batch_shape_tensor())],
                  axis=0))
    return distribution_util.move_dimension(dist.log_prob(states), 0, -1)
Example #25
 def _sample_n(self, n, seed=None):
     scale = tf.convert_to_tensor(self.scale)
     shape = tf.concat([[n], tf.shape(scale)], 0)
     sampled = tf.random.normal(shape=shape,
                                mean=0.,
                                stddev=1.,
                                dtype=self.dtype,
                                seed=seed)
     return tf.abs(sampled * scale)
Example #26
def _uniform_unit_norm(dimension, shape, dtype, seed):
    """Returns a batch of points chosen uniformly from the unit hypersphere."""
    # This works because the Gaussian distribution is spherically symmetric.
    # raw shape: shape + [dimension]
    raw = normal.Normal(loc=dtype_util.as_numpy_dtype(dtype)(0),
                        scale=dtype_util.as_numpy_dtype(dtype)(1)).sample(
                            tf.concat([shape, [dimension]], axis=0),
                            seed=seed())
    unit_norm = raw / tf.norm(raw, ord=2, axis=-1)[..., tf.newaxis]
    return unit_norm
Example #27
 def _sample_n(self, n, seed=None, **kwargs):
     with tf.control_dependencies(self._runtime_assertions):
         x = self.distribution.sample(sample_shape=n, seed=seed, **kwargs)
         new_shape = tf.concat([
             [n],
             self._batch_shape_unexpanded,
             self.event_shape_tensor(),
         ],
                               axis=0)
         return tf.reshape(x, new_shape)
Example #28
 def _pad_sample_dims(self, x):
     with tf.name_scope("pad_sample_dims"):
         ndims = tensorshape_util.rank(x.shape) if tensorshape_util.rank(
             x.shape) is not None else tf.rank(x)
         shape = tf.shape(x)
         d = ndims - self._event_ndims
         x = tf.reshape(x,
                        shape=tf.concat([shape[:d], [1], shape[d:]],
                                        axis=0))
         return x
Example #29
 def _sample_n(self, n, seed=None):
     loc = tf.convert_to_tensor(self.loc)
     scale = tf.convert_to_tensor(self.scale)
     batch_shape = self._batch_shape_tensor(loc=loc, scale=scale)
     shape = tf.concat([[n], batch_shape], 0)
     probs = tf.random.uniform(shape=shape,
                               minval=0.,
                               maxval=1.,
                               dtype=self.dtype,
                               seed=seed)
     return self._quantile(probs, loc=loc, scale=scale)
Example #30
 def _mode_mean_shape(self):
     """Shape for the mode/mean Tensors."""
     shape = tensorshape_util.concatenate(self.batch_shape,
                                          self.event_shape)
     has_static_shape = tensorshape_util.is_fully_defined(shape)
     if not has_static_shape:
         shape = tf.concat([
             self.batch_shape_tensor(),
             self.event_shape_tensor(),
         ], 0)
     return shape