def _inverse(self, y):
     map_values = tf.convert_to_tensor(self.map_values)
     flat_y = tf.reshape(y, shape=[-1])
     # Search for the indices of map_values that are closest to flat_y.
     # Since map_values is strictly increasing, the closest is either the
     # first one that is strictly greater than flat_y, or the one before it.
     upper_candidates = tf.minimum(
         tf.size(map_values) - 1,
         tf.searchsorted(map_values, values=flat_y, side='right'))
     lower_candidates = tf.maximum(0, upper_candidates - 1)
     candidates = tf.stack([lower_candidates, upper_candidates], axis=-1)
     lower_cand_diff = tf.abs(flat_y - self._forward(lower_candidates))
     upper_cand_diff = tf.abs(flat_y - self._forward(upper_candidates))
     if self.validate_args:
         with tf.control_dependencies([
                 assert_util.assert_near(tf.minimum(lower_cand_diff,
                                                    upper_cand_diff),
                                         0,
                                         message='inverse value not found')
         ]):
             candidates = tf.identity(candidates)
     candidate_selector = tf.stack([
         tf.range(tf.size(flat_y), dtype=tf.int32),
         tf.argmin([lower_cand_diff, upper_cand_diff], output_type=tf.int32)
     ],
                                   axis=-1)
     return tf.reshape(tf.gather_nd(candidates, candidate_selector),
                       shape=y.shape)
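A minimal standalone sketch (not part of the library, toy values only) of the same nearest-value lookup trick with tf.searchsorted that the method above relies on:

import tensorflow as tf

map_values = tf.constant([0.0, 1.5, 4.0, 10.0])
y = tf.constant([1.4, 4.0, 9.0])
# Candidate indices bracketing each y.
upper = tf.minimum(tf.size(map_values) - 1,
                   tf.searchsorted(map_values, y, side='right'))
lower = tf.maximum(0, upper - 1)
# Pick whichever candidate value is closer to y.
pick_upper = (tf.abs(y - tf.gather(map_values, upper)) <
              tf.abs(y - tf.gather(map_values, lower)))
nearest = tf.where(pick_upper, upper, lower)
print(nearest.numpy())  # [1 2 3]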
Example #2
    def _cdf(self, k):
        # TODO(b/135263541): Improve numerical precision of categorical.cdf.
        probs = self.probs_parameter()
        num_categories = self._num_categories(probs)

        k, probs = _broadcast_cat_event_and_params(
            k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

        # Since the lowest number in the support is 0, any k < 0 should be zero in
        # the output.
        should_be_zero = k < 0

        # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
        k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

        batch_shape = tf.shape(k)

        # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
        # to handle the case where the batch shape is dynamic, flatten the batch
        # dims (so we know batch_dims=1).
        k_flat_batch = tf.reshape(k, [-1])
        probs_flat_batch = tf.reshape(
            probs, tf.concat(([-1], [num_categories]), axis=0))

        cdf_flat = tf.gather(tf.cumsum(probs_flat_batch, axis=-1),
                             k_flat_batch[..., tf.newaxis],
                             batch_dims=1)

        cdf = tf.reshape(cdf_flat, shape=batch_shape)

        zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
        return tf.where(should_be_zero, zero, cdf)
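A toy sketch (hedged, with hand-written probabilities) of the cumsum-then-gather(batch_dims=1) step used above:

import tensorflow as tf

probs = tf.constant([[0.1, 0.2, 0.7],
                     [0.5, 0.25, 0.25]])
k = tf.constant([1, 2])  # one category index per batch row
cdf = tf.gather(tf.cumsum(probs, axis=-1), k[..., tf.newaxis], batch_dims=1)
print(tf.squeeze(cdf, axis=-1).numpy())  # approximately [0.3, 1.0]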
Example #3
    def _mode(self, samples=None):
        # The sample count can vary by batch member. Use map_fn to compute the
        # mode for each batch member separately.
        def _get_mode(samples):
            # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
            count = gen_array_ops.unique_with_counts_v2(samples,
                                                        axis=[0]).count
            return tf.argmax(count)

        if samples is None:
            samples = tf.convert_to_tensor(self._samples)
        num_samples = self._compute_num_samples(samples)

        # Flatten samples for each batch.
        if self._event_ndims == 0:
            flattened_samples = tf.reshape(samples, [-1, num_samples])
            mode_shape = self._batch_shape_tensor(samples)
        else:
            event_size = tf.reduce_prod(self._event_shape_tensor(samples))
            mode_shape = tf.concat([
                self._batch_shape_tensor(samples),
                self._event_shape_tensor(samples)
            ],
                                   axis=0)
            flattened_samples = tf.reshape(samples,
                                           [-1, num_samples, event_size])

        indices = tf.map_fn(_get_mode, flattened_samples, dtype=tf.int64)
        full_indices = tf.stack(
            [tf.range(tf.shape(indices)[0]),
             tf.cast(indices, tf.int32)],
            axis=1)

        mode = tf.gather_nd(flattened_samples, full_indices)
        return tf.reshape(mode, mode_shape)
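Hedged sketch of the per-batch mode computation using the public tf.unique_with_counts on 1-D rows (the code above uses the not-yet-exposed unique_with_counts_v2 so that it can pass an axis argument):

import tensorflow as tf

samples = tf.constant([[3, 1, 3, 2, 3],
                       [5, 5, 4, 4, 4]])

def _row_mode(row):
    values, _, counts = tf.unique_with_counts(row)
    return tf.gather(values, tf.argmax(counts))

print(tf.map_fn(_row_mode, samples).numpy())  # [3 4]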
Example #4
 def _call_reshape_input_output(self, fn, x, extra_kwargs=None):
     """Calls `fn`, appropriately reshaping its input `x` and output."""
     # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
     # because the user-provided extra kwargs dict could itself have `fn`
     # and/or `x` as a key.
     with tf.control_dependencies(self._runtime_assertions +
                                  self._validate_sample_arg(x)):
         sample_shape, static_sample_shape = self._sample_shape(x)
         old_shape = tf.concat([
             sample_shape,
             self.distribution.batch_shape_tensor(),
             self.event_shape_tensor(),
         ],
                               axis=0)
         x_reshape = tf.reshape(x, old_shape)
          result = (fn(x_reshape, **extra_kwargs)
                    if extra_kwargs else fn(x_reshape))
         new_shape = tf.concat([
             sample_shape,
             self._batch_shape_unexpanded,
         ],
                               axis=0)
         result = tf.reshape(result, new_shape)
         if (tensorshape_util.rank(static_sample_shape) is not None
                 and tensorshape_util.rank(self.batch_shape) is not None):
             new_shape = tensorshape_util.concatenate(
                 static_sample_shape, self.batch_shape)
             tensorshape_util.set_shape(result, new_shape)
         return result
Example #5
def _sparse_tensor_dense_matmul(sp_a, b, **kwargs):
    """Returns (batched) matmul of a SparseTensor with a Tensor.

  Args:
    sp_a: `SparseTensor` representing a (batch of) matrices.
    b: `Tensor` representing a (batch of) matrices, with the same batch shape as
      `sp_a`. The shape must be compatible with the shape of `sp_a` and kwargs.
    **kwargs: Keyword arguments to `tf.sparse_tensor_dense_matmul`.

  Returns:
    product: A dense (batch of) matrix-shaped Tensor of the same batch shape and
    dtype as `sp_a` and `b`. If `sp_a` or `b` is adjointed through `kwargs` then
    the shape is adjusted accordingly.
  """
    batch_shape = _get_shape(sp_a)[:-2]

    # Reshape the SparseTensor into a rank-3 SparseTensor, with the
    # batch shape flattened to a single dimension. If the batch rank is 0, then
    # we add a batch dimension of size 1.
    sp_a = tf.sparse.reshape(sp_a,
                             tf.concat([[-1], _get_shape(sp_a)[-2:]], axis=0))
    # Reshape b to stack the batch dimension along the rows.
    b = tf.reshape(b, tf.concat([[-1], _get_shape(b)[-1:]], axis=0))

    # Convert the SparseTensor to a matrix in block diagonal form with blocks of
    # matrices [M, N]. This allows us to use tf.sparse.sparse_dense_matmul, which
    # only accepts rank-2 (Sparse)Tensors.
    out = tf.sparse.sparse_dense_matmul(_sparse_block_diag(sp_a), b, **kwargs)

    # Finally retrieve the original batch shape from the resulting rank 2 Tensor.
    # Note that we avoid inferring the final shape from `sp_a` or `b` because we
    # might have transposed one or both of them.
    return tf.reshape(
        out,
        tf.concat([batch_shape, [-1], _get_shape(out)[-1:]], axis=0))
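A tiny non-batched sketch of tf.sparse.sparse_dense_matmul, which only accepts rank-2 operands; that restriction is why the code above flattens the batch into a block-diagonal sparse matrix first (the toy values here are illustrative):

import tensorflow as tf

sp_a = tf.sparse.SparseTensor(indices=[[0, 1], [1, 0]],
                              values=[2.0, 3.0],
                              dense_shape=[2, 2])
b = tf.constant([[1.0], [4.0]])
# Dense sp_a is [[0, 2], [3, 0]], so the product is [[8], [3]].
print(tf.sparse.sparse_dense_matmul(sp_a, b).numpy())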
Example #6
    def _log_prob(self, x):
        logits = self._logits_parameter_no_checks()
        event_size = self._event_size(logits)

        x = tf.cast(x, logits.dtype)
        x = self._maybe_assert_valid_sample(x, dtype=logits.dtype)

        # broadcast logits or x if need be.
        if (not tensorshape_util.is_fully_defined(x.shape)
                or not tensorshape_util.is_fully_defined(logits.shape)
                or x.shape != logits.shape):
            broadcast_shape = tf.broadcast_dynamic_shape(
                tf.shape(logits), tf.shape(x))
            logits = tf.broadcast_to(logits, broadcast_shape)
            x = tf.broadcast_to(x, broadcast_shape)

        logits_shape = tf.shape(tf.reduce_sum(logits, axis=-1))
        logits_2d = tf.reshape(logits, [-1, event_size])
        x_2d = tf.reshape(x, [-1, event_size])
        ret = -tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(x_2d), logits=logits_2d)

        # Reshape back to user-supplied batch and sample dims prior to 2D reshape.
        ret = tf.reshape(ret, logits_shape)
        return ret
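Hedged sketch of why the 2-D reshape above feeds softmax_cross_entropy_with_logits: for a one-hot x, its negation is exactly the categorical log-probability of the selected class:

import tensorflow as tf

logits = tf.constant([[0.0, 1.0, 2.0]])
x = tf.constant([[0.0, 0.0, 1.0]])  # one-hot sample selecting class 2
log_prob = -tf.nn.softmax_cross_entropy_with_logits(labels=x, logits=logits)
# Matches log softmax(logits)[2].
print(log_prob.numpy())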
    def _variance(self):
        with tf.control_dependencies(self._runtime_assertions):
            probs = self._marginal_hidden_probs()
            # probs :: num_steps batch_shape num_states
            means = self._observation_distribution.mean()
            # means :: observation_batch_shape[:-1] num_states
            #          observation_event_shape
            means_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_states],
                self._observation_distribution.event_shape_tensor()
            ],
                                    axis=0)
            means = tf.broadcast_to(means, means_shape)
            # means :: batch_shape num_states observation_event_shape

            observation_event_shape = (
                self._observation_distribution.event_shape_tensor())
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
            flat_probs_shape = [self._num_steps, batch_size, self._num_states]
            flat_means_shape = [
                batch_size, 1, self._num_states,
                tf.reduce_prod(observation_event_shape)
            ]

            flat_probs = tf.reshape(probs, flat_probs_shape)
            # flat_probs :: num_steps batch_size num_states
            flat_means = tf.reshape(means, flat_means_shape)
            # flat_means :: batch_size 1 num_states observation_event_size
            flat_mean = tf.einsum("ijk,jmkl->jiml", flat_probs, flat_means)
            # flat_mean :: batch_size num_steps 1 observation_event_size

            variances = self._observation_distribution.variance()
            variances = tf.broadcast_to(variances, means_shape)
            # variances :: batch_shape num_states observation_event_shape
            flat_variances = tf.reshape(variances, flat_means_shape)
            # flat_variances :: batch_size 1 num_states observation_event_size

            # For a mixture of n distributions with mixture probabilities
            # p[i], and where the individual distributions have means and
            # variances given by mean[i] and var[i], the variance of
            # the mixture is given by:
            #
            # var = sum i=1..n p[i] * ((mean[i] - mean)**2 + var[i])

            flat_variance = tf.einsum("ijk,jikl->jil", flat_probs,
                                      (flat_means - flat_mean)**2 +
                                      flat_variances)
            # flat_variance :: batch_size num_steps observation_event_size

            unflat_mean_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_steps],
                observation_event_shape
            ],
                                          axis=0)

            # returns :: batch_shape num_steps observation_event_shape
            return tf.reshape(flat_variance, unflat_mean_shape)
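A toy numeric check (hedged sketch) of the mixture-variance identity used above, var = sum_i p[i] * ((mean[i] - mean)**2 + var[i]):

import tensorflow as tf

p = tf.constant([0.25, 0.75])
means = tf.constant([0.0, 4.0])
variances = tf.constant([1.0, 2.0])

mixture_mean = tf.reduce_sum(p * means)
mixture_var = tf.reduce_sum(p * ((means - mixture_mean)**2 + variances))
print(mixture_mean.numpy(), mixture_var.numpy())  # 3.0 4.75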
Example #8
 def _sample_n(self, n, seed=None):
     logits = self._logits_parameter_no_checks()
     logits_2d = tf.reshape(logits, [-1, self._num_categories(logits)])
     sample_dtype = tf.int64 if dtype_util.size(
         self.dtype) > 4 else tf.int32
     draws = tf.random.categorical(logits_2d,
                                   n,
                                   dtype=sample_dtype,
                                   seed=seed)
     draws = tf.cast(draws, self.dtype)
     return tf.reshape(tf.transpose(draws),
                       shape=tf.concat(
                           [[n], self._batch_shape_tensor(logits)], axis=0))
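Minimal sketch of the flatten, tf.random.categorical, then transpose/reshape pattern above, on a toy two-row logits matrix (values illustrative):

import tensorflow as tf

logits_2d = tf.math.log([[0.1, 0.9],
                         [0.9, 0.1]])
n = 5
draws = tf.random.categorical(logits_2d, n, seed=42)  # shape [2, n]
samples = tf.transpose(draws)                         # shape [n, 2]
print(samples.shape)  # (5, 2): n draws for each of the 2 batch rows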
Example #9
 def _cdf(self, x):
   x = tf.convert_to_tensor(x, name='x')
   flat_x = tf.reshape(x, shape=[-1])
   upper_bound = tf.searchsorted(self.outcomes, values=flat_x, side='right')
   values_at_ub = tf.gather(
       self.outcomes,
       indices=tf.minimum(upper_bound,
                          dist_util.prefer_static_shape(self.outcomes)[-1] -
                          1))
   should_use_upper_bound = self._is_equal_or_close(flat_x, values_at_ub)
   indices = tf.where(should_use_upper_bound, upper_bound, upper_bound - 1)
   return self._categorical.cdf(
       tf.reshape(indices, shape=dist_util.prefer_static_shape(x)))
Example #10
 def _sample_n(self, n, seed=None):
     logits = self._logits_parameter_no_checks()
     sample_shape = prefer_static.concat(
         [[n], prefer_static.shape(logits)], 0)
     event_size = self._event_size(logits)
     if tensorshape_util.rank(logits.shape) == 2:
         logits_2d = logits
     else:
         logits_2d = tf.reshape(logits, [-1, event_size])
     samples = tf.random.categorical(logits_2d, n, seed=seed)
     samples = tf.transpose(a=samples)
     samples = tf.one_hot(samples, event_size, dtype=self.dtype)
     ret = tf.reshape(samples, sample_shape)
     return ret
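Hedged sketch of the extra one-hot step that distinguishes this sampler from the plain categorical one above: each integer draw becomes an event-sized one-hot vector:

import tensorflow as tf

logits_2d = tf.constant([[0.0, 1.0, 2.0]])
draws = tf.random.categorical(logits_2d, 4, seed=0)  # [1, 4] integer draws
one_hot = tf.one_hot(tf.transpose(draws), depth=3)   # [4, 1, 3]
print(one_hot.shape)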
Example #11
 def _log_prob(self, x, **kwargs):
     batch_ndims = prefer_static.rank_from_shape(
         self.distribution.batch_shape_tensor,
         self.distribution.batch_shape)
     extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
     event_ndims = prefer_static.rank_from_shape(
         self.distribution.event_shape_tensor,
         self.distribution.event_shape)
     ndims = prefer_static.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=tf.pad(
                        tf.shape(x),
                        paddings=[[prefer_static.maximum(0, -d), 0]],
                        constant_values=1))
     sample_ndims = prefer_static.maximum(0, d)
     # (2) Transpose x's dims.
     sample_dims = prefer_static.range(0, sample_ndims)
     batch_dims = prefer_static.range(sample_ndims,
                                      sample_ndims + batch_ndims)
     extra_sample_dims = prefer_static.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = prefer_static.range(
         sample_ndims + batch_ndims + extra_sample_ndims, ndims)
     perm = prefer_static.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     x = tf.transpose(a=x, perm=perm)
     # (3) Compute x's log_prob.
     lp = self.distribution.log_prob(x, **kwargs)
     # (4) Make the final reduction in x.
     axis = prefer_static.range(sample_ndims,
                                sample_ndims + extra_sample_ndims)
     return tf.reduce_sum(lp, axis=axis)
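Hedged toy sketch of step (2) above: the extra sample dims are moved next to the leading sample dims with a single tf.transpose (plain TF ops here instead of the internal prefer_static helpers):

import tensorflow as tf

# x :: [sample=2, batch=3, extra_sample=4, event=5]
x = tf.zeros([2, 3, 4, 5])
perm = [0, 2, 1, 3]  # sample, extra_sample, batch, event
print(tf.transpose(x, perm=perm).shape)  # (2, 4, 3, 5)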
Example #12
 def _inverse(self, y):
     output_shape, output_tensorshape = _replace_event_shape_in_shape_tensor(
         tf.shape(y), self._event_shape_out, self._event_shape_in,
         self.validate_args)
     x = tf.reshape(y, output_shape)
     tensorshape_util.set_shape(x, output_tensorshape)
     return x
Example #13
 def _forward(self, x):
     output_shape, output_tensorshape = _replace_event_shape_in_shape_tensor(
         tf.shape(x), self._event_shape_in, self._event_shape_out,
         self.validate_args)
     y = tf.reshape(x, output_shape)
     tensorshape_util.set_shape(y, output_tensorshape)
     return y
    def _std_var_helper(self, statistic, statistic_name, statistic_ndims,
                        df_factor_fn):
        """Helper to compute stddev, covariance and variance."""
        df = tf.reshape(
            self.df,
            tf.concat([
                tf.shape(self.df),
                tf.ones([statistic_ndims], dtype=tf.int32)
            ], -1))
        # We nest this tf.where inside the outer tf.where below to ensure we
        # never hit a NaN in the gradient.
        denom = tf.where(df > 2., df - 2., tf.ones_like(df))
        statistic = statistic * df_factor_fn(df / denom)
        # When 1 < df <= 2, stddev/variance are infinite.
        result_where_defined = tf.where(
            df > 2., statistic,
            dtype_util.as_numpy_dtype(self.dtype)(np.inf))

        if self.allow_nan_stats:
            return tf.where(df > 1., result_where_defined,
                            dtype_util.as_numpy_dtype(self.dtype)(np.nan))
        else:
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.cast(1., self.dtype),
                        df,
                        message='{} not defined for components of df <= 1.'.
                        format(statistic_name.capitalize())),
            ]):
                return tf.identity(result_where_defined)
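Hedged sketch of the nested-where trick referenced in the comment above: the inner tf.where keeps the denominator finite so the gradient is NaN-free, and the outer tf.where masks the undefined entries:

import tensorflow as tf

df = tf.constant([1.5, 3.0])
denom = tf.where(df > 2., df - 2., tf.ones_like(df))  # safe denominator
scaled = df / denom
result = tf.where(df > 2., scaled, float('inf'))      # inf where undefined
print(result.numpy())  # [inf 3.]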
    def _call_sample_n(self, sample_shape, seed, name, **kwargs):
        # We override `_call_sample_n` rather than `_sample_n` so we can ensure that
        # the result of `self.bijector.forward` is not modified (and thus caching
        # works).
        with self._name_and_control_scope(name):
            sample_shape = tf.convert_to_tensor(sample_shape,
                                                dtype=tf.int32,
                                                name="sample_shape")
            sample_shape, n = self._expand_sample_shape_to_vector(
                sample_shape, "sample_shape")

            distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(
                kwargs)

            # First, generate samples. We will possibly generate extra samples in the
            # event that we need to reinterpret the samples as part of the
            # event_shape.
            x = self._sample_n(n, seed, **distribution_kwargs)

            # Next, we reshape `x` into its final form. We do this prior to the call
            # to the bijector to ensure that the bijector caching works.
            batch_event_shape = tf.shape(x)[1:]
            final_shape = tf.concat([sample_shape, batch_event_shape], 0)
            x = tf.reshape(x, final_shape)

            # Finally, we apply the bijector's forward transformation. For caching to
            # work, it is imperative that this is the last modification to the
            # returned result.
            y = self.bijector.forward(x, **bijector_kwargs)
            y = self._set_sample_static_shape(y, sample_shape)

            return y
    def _mean(self, **kwargs):
        if not self.bijector.is_constant_jacobian:
            raise NotImplementedError("mean is not implemented for non-affine "
                                      "bijectors")

        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
        x = self.distribution.mean(**distribution_kwargs)

        if self._is_maybe_batch_override or self._is_maybe_event_override:
            # A batch (respectively event) shape override is only allowed if the batch
            # (event) shape of the base distribution is [], so concatenating all the
            # shapes does the right thing.
            new_shape = prefer_static.concat([
                prefer_static.ones_like(self._override_batch_shape),
                self.distribution.batch_shape_tensor(),
                prefer_static.ones_like(self._override_event_shape),
                self.distribution.event_shape_tensor(),
            ], 0)
            x = tf.reshape(x, new_shape)
            new_shape = prefer_static.concat(
                [self.batch_shape_tensor(),
                 self.event_shape_tensor()], 0)
            x = tf.broadcast_to(x, new_shape)

        y = self.bijector.forward(x, **bijector_kwargs)

        sample_shape = tf.convert_to_tensor([],
                                            dtype=tf.int32,
                                            name="sample_shape")
        y = self._set_sample_static_shape(y, sample_shape)
        return y
  def _make_columnar(self, x):
    """Ensures non-scalar input has at least one column.

    Example:
      If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`.

      If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.

      If `x = 1` then the output is unchanged.

    Args:
      x: `Tensor`.

    Returns:
      columnar_x: `Tensor` with at least two dimensions.
    """
    if tensorshape_util.rank(x.shape) is not None:
      if tensorshape_util.rank(x.shape) == 1:
        x = x[tf.newaxis, :]
      return x
    shape = tf.shape(x)
    maybe_expanded_shape = tf.concat([
        shape[:-1],
        distribution_util.pick_vector(
            tf.equal(tf.rank(x), 1), [1], np.array([], dtype=np.int32)),
        shape[-1:],
    ], 0)
    return tf.reshape(x, maybe_expanded_shape)
Example #18
 def _call_and_reshape_output(self,
                              fn,
                              event_shape_list=None,
                              static_event_shape_list=None,
                              extra_kwargs=None):
     """Calls `fn` and appropriately reshapes its output."""
     # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
      # because the user-provided extra kwargs dict could itself have `fn`,
      # `event_shape_list`, `static_event_shape_list` and/or `extra_kwargs`
      # as keys.
     with tf.control_dependencies(self._runtime_assertions):
         if event_shape_list is None:
             event_shape_list = [self._event_shape_tensor()]
         if static_event_shape_list is None:
             static_event_shape_list = [self.event_shape]
         new_shape = tf.concat([self._batch_shape_unexpanded] +
                               event_shape_list,
                               axis=0)
         result = tf.reshape(
             fn(**extra_kwargs) if extra_kwargs else fn(), new_shape)
         if (tensorshape_util.rank(self.batch_shape) is not None
                 and tensorshape_util.rank(self.event_shape) is not None):
             event_shape = tf.TensorShape([])
             for rss in static_event_shape_list:
                 event_shape = tensorshape_util.concatenate(
                     event_shape, rss)
             static_shape = tensorshape_util.concatenate(
                 self.batch_shape, event_shape)
             tensorshape_util.set_shape(result, static_shape)
         return result
Example #19
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
  """Slices a single parameter of a distribution.

  Args:
    param: A `Tensor`, the original parameter to slice.
    param_event_ndims: `int` event parameterization rank for this parameter.
    slices: A `tuple` of normalized slices.
    dist_batch_shape: The distribution's batch shape `Tensor`.

  Returns:
    new_param: A `Tensor`, batch-sliced according to slices.
  """
  # Extend param shape with ones on the left to match dist_batch_shape.
  param_shape = tf.shape(input=param)
  insert_ones = tf.ones(
      [tf.size(input=dist_batch_shape) + param_event_ndims - tf.rank(param)],
      dtype=param_shape.dtype)
  new_param_shape = tf.concat([insert_ones, param_shape], axis=0)
  full_batch_param = tf.reshape(param, new_param_shape)
  param_slices = []
  # We separately track the batch axis from the parameter axis because we want
  # them to align for positive indexing, and be offset by param_event_ndims for
  # negative indexing.
  param_dim_idx = 0
  batch_dim_idx = 0
  for slc in slices:
    if slc is tf.newaxis:
      param_slices.append(slc)
      continue
    if slc is Ellipsis:
      if batch_dim_idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      param_slices.append(slc)
      # Switch over to negative indexing for the broadcast check.
      num_remaining_non_newaxis_slices = sum(
          [s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]])
      batch_dim_idx = -num_remaining_non_newaxis_slices
      param_dim_idx = batch_dim_idx - param_event_ndims
      continue
    # Find the batch dimension sizes for both parameter and distribution.
    param_dim_size = new_param_shape[param_dim_idx]
    batch_dim_size = dist_batch_shape[batch_dim_idx]
    is_broadcast = batch_dim_size > param_dim_size
    # Slices are denoted by start:stop:step.
    if isinstance(slc, slice):
      start, stop, step = slc.start, slc.stop, slc.step
      if start is not None:
        start = tf.where(is_broadcast, 0, start)
      if stop is not None:
        stop = tf.where(is_broadcast, 1, stop)
      if step is not None:
        step = tf.where(is_broadcast, 1, step)
      param_slices.append(slice(start, stop, step))
    else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
      param_slices.append(tf.where(is_broadcast, 0, slc))
    param_dim_idx += 1
    batch_dim_idx += 1
  param_slices.extend([ALL_SLICE] * param_event_ndims)
  return full_batch_param.__getitem__(param_slices)
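Hedged toy sketch of the left-padding step at the top of _slice_single_param: a parameter with fewer dims than the distribution batch shape gets leading 1s so batch slicing lines up (names and shapes here are illustrative):

import tensorflow as tf

param = tf.zeros([3, 2])                # batch part [3], event part [2]
dist_batch_shape = tf.constant([5, 3])
param_event_ndims = 1
insert_ones = tf.ones(
    [tf.size(dist_batch_shape) + param_event_ndims - tf.rank(param)],
    dtype=tf.int32)
new_shape = tf.concat([insert_ones, tf.shape(param)], axis=0)
print(tf.reshape(param, new_shape).shape)  # (1, 3, 2)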
Example #20
 def _reshape_part(part, dtype, event_shape):
     part = tf.cast(part, dtype)
     static_rank = tf.get_static_value(
         ps.rank_from_shape(event_shape))
     if static_rank == 1:
         return part
     new_shape = ps.concat([ps.shape(part)[:-1], event_shape],
                           axis=-1)
     return tf.reshape(part, ps.cast(new_shape, tf.int32))
def _extract_log_probs(num_states, dist):
    """Tabulate log probabilities from a batch of distributions."""

    states = tf.reshape(
        tf.range(num_states),
        tf.concat([[num_states],
                   tf.ones_like(dist.batch_shape_tensor())],
                  axis=0))
    return distribution_util.move_dimension(dist.log_prob(states), 0, -1)
    def _sample_n(self, n, seed=None):
        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[], in
        # which case get ids as a [n]-shaped vector.
        distributions = self.poisson_and_mixture_distributions()
        dist, mixture_dist = distributions
        batch_size = tensorshape_util.num_elements(self.batch_shape)
        if batch_size is None:
            batch_size = tf.reduce_prod(
                self._batch_shape_tensor(distributions=distributions))
        # We need to 'sample extra' from the mixture distribution if it doesn't
        # already specify a probs vector for each batch coordinate.
        # We only support this kind of reduced broadcasting, i.e., there is exactly
        # one probs vector for all batch dims or one for each.
        stream = SeedStream(seed, salt='PoissonLogNormalQuadratureCompound')
        ids = mixture_dist.sample(sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(mixture_dist.is_scalar_batch(),
                                          [batch_size], np.int32([]))),
                                  seed=stream())
        # We need to flatten batch dims in case mixture_dist has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride `quadrature_size` for `batch_size` number of times.
        offset = tf.range(start=0,
                          limit=batch_size * self._quadrature_size,
                          delta=self._quadrature_size,
                          dtype=ids.dtype)
        ids = ids + offset
        rate = tf.gather(tf.reshape(dist.rate, shape=[-1]), ids)
        rate = tf.reshape(
            rate,
            shape=concat_vectors(
                [n], self._batch_shape_tensor(distributions=distributions)))
        return tf.random.poisson(lam=rate,
                                 shape=[],
                                 dtype=self.dtype,
                                 seed=seed)
Example #23
  def _stddev(self):
    with tf.control_dependencies(self._assertions):
      distribution_means = [d.mean() for d in self.components]
      distribution_devs = [d.stddev() for d in self.components]
      cat_probs = self._cat_probs(log_probs=False)

      stacked_means = tf.stack(distribution_means, axis=-1)
      stacked_devs = tf.stack(distribution_devs, axis=-1)
      cat_probs = [self._expand_to_event_rank(c_p) for c_p in cat_probs]
      broadcasted_cat_probs = (
          tf.stack(cat_probs, axis=-1) * tf.ones_like(stacked_means))

      batched_dev = distribution_util.mixture_stddev(
          tf.reshape(broadcasted_cat_probs, [-1, len(self.components)]),
          tf.reshape(stacked_means, [-1, len(self.components)]),
          tf.reshape(stacked_devs, [-1, len(self.components)]))

      # I.e. re-shape to list(batch_shape) + list(event_shape).
      return tf.reshape(batched_dev, tf.shape(broadcasted_cat_probs)[:-1])
Example #24
 def _sample_n(self, n, seed=None, **kwargs):
     with tf.control_dependencies(self._runtime_assertions):
         x = self.distribution.sample(sample_shape=n, seed=seed, **kwargs)
         new_shape = tf.concat([
             [n],
             self._batch_shape_unexpanded,
             self.event_shape_tensor(),
         ],
                               axis=0)
         return tf.reshape(x, new_shape)
Example #25
 def _log_prob(self, x):
   x = tf.convert_to_tensor(x, name='x')
   right_indices = tf.minimum(
       tf.size(self.outcomes) - 1,
       tf.reshape(
           tf.searchsorted(
               self.outcomes, values=tf.reshape(x, shape=[-1]), side='right'),
           dist_util.prefer_static_shape(x)))
   use_right_indices = self._is_equal_or_close(
       x, tf.gather(self.outcomes, indices=right_indices))
   left_indices = tf.maximum(0, right_indices - 1)
   use_left_indices = self._is_equal_or_close(
       x, tf.gather(self.outcomes, indices=left_indices))
   log_probs = self._categorical.log_prob(
       tf.where(use_left_indices, left_indices, right_indices))
   return tf.where(
       tf.logical_not(use_left_indices | use_right_indices),
       dtype_util.as_numpy_dtype(log_probs.dtype)(-np.inf),
       log_probs)
Example #26
 def _pad_sample_dims(self, x):
     with tf.name_scope("pad_sample_dims"):
         ndims = tensorshape_util.rank(x.shape) if tensorshape_util.rank(
             x.shape) is not None else tf.rank(x)
         shape = tf.shape(x)
         d = ndims - self._event_ndims
         x = tf.reshape(x,
                        shape=tf.concat([shape[:d], [1], shape[d:]],
                                        axis=0))
         return x
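Hedged toy sketch of _pad_sample_dims: a length-1 axis is inserted just before the event dims so x can broadcast against the mixture components (shapes here are made up):

import tensorflow as tf

x = tf.zeros([7, 4, 3])   # [sample, batch, event], event_ndims = 1
shape = tf.shape(x)
d = 2                     # rank(x) - event_ndims
padded = tf.reshape(x, tf.concat([shape[:d], [1], shape[d:]], axis=0))
print(padded.shape)       # (7, 4, 1, 3)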
 def _expand_base_distribution_mean(self):
     """Ensures `self.distribution.mean()` has `[batch, event]` shape."""
     single_draw_shape = concat_vectors(self.batch_shape_tensor(),
                                        self.event_shape_tensor())
     m = tf.reshape(
         self.distribution.mean(),  # A scalar.
         shape=tf.ones_like(single_draw_shape, dtype=tf.int32))
     m = tf.tile(m, multiples=single_draw_shape)
     tensorshape_util.set_shape(
         m, tensorshape_util.concatenate(self.batch_shape,
                                         self.event_shape))
     return m
Example #28
 def _sample_one_batch_member(args):
     logits, num_cat_samples = args[0], args[1]  # [K], []
     # x has shape [1, num_cat_samples = num_samples * num_trials]
     x = tf.random.categorical(logits[tf.newaxis, ...],
                               num_cat_samples,
                               seed=seed)
     x = tf.reshape(x, shape=[num_samples,
                              -1])  # [num_samples, num_trials]
     x = tf.one_hot(
         x, depth=num_classes)  # [num_samples, num_trials, num_classes]
     x = tf.reduce_sum(x, axis=-2)  # [num_samples, num_classes]
     return tf.cast(x, dtype=dtype)
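Hedged toy sketch of the categorical -> one_hot -> reduce_sum trick above, which turns num_trials categorical draws into a single multinomial count vector:

import tensorflow as tf

logits = tf.math.log([[0.2, 0.3, 0.5]])
num_trials = 1000
draws = tf.random.categorical(logits, num_trials, seed=7)      # [1, num_trials]
counts = tf.reduce_sum(tf.one_hot(draws[0], depth=3), axis=0)  # [3]
print(counts.numpy())  # roughly [200. 300. 500.]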
Example #29
    def _entropy(self):
        samples = tf.convert_to_tensor(self.samples)
        num_samples = self._compute_num_samples(samples)
        entropy_shape = self._batch_shape_tensor(samples)

        # Flatten samples for each batch.
        if self._event_ndims == 0:
            samples = tf.reshape(samples, [-1, num_samples])
        else:
            event_size = tf.reduce_prod(self.event_shape_tensor())
            samples = tf.reshape(samples, [-1, num_samples, event_size])

        # Use map_fn to compute entropy for each batch separately.
        def _get_entropy(samples):
            # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
            count = gen_array_ops.unique_with_counts_v2(samples,
                                                        axis=[0]).count
            prob = tf.cast(count / num_samples, dtype=self.dtype)
            entropy = tf.reduce_sum(-prob * tf.math.log(prob))
            return entropy

        entropy = tf.map_fn(_get_entropy, samples, dtype=self.dtype)
        return tf.reshape(entropy, entropy_shape)
    def _mean(self):
        with tf.control_dependencies(self._runtime_assertions):
            probs = self._marginal_hidden_probs()
            # probs :: num_steps batch_shape num_states
            means = self._observation_distribution.mean()
            # means :: observation_batch_shape[:-1] num_states
            #          observation_event_shape
            means_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_states],
                self._observation_distribution.event_shape_tensor()
            ],
                                    axis=0)
            means = tf.broadcast_to(means, means_shape)
            # means :: batch_shape num_states observation_event_shape

            observation_event_shape = (
                self._observation_distribution.event_shape_tensor())
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
            flat_probs_shape = [self._num_steps, batch_size, self._num_states]
            flat_means_shape = [
                batch_size, self._num_states,
                tf.reduce_prod(observation_event_shape)
            ]

            flat_probs = tf.reshape(probs, flat_probs_shape)
            # flat_probs :: num_steps batch_size num_states
            flat_means = tf.reshape(means, flat_means_shape)
            # flat_means :: batch_size num_states observation_event_size
            flat_mean = tf.einsum("ijk,jkl->jil", flat_probs, flat_means)
            # flat_mean :: batch_size num_steps observation_event_size
            unflat_mean_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_steps],
                observation_event_shape
            ],
                                          axis=0)
            # returns :: batch_shape num_steps observation_event_shape
            return tf.reshape(flat_mean, unflat_mean_shape)