Example 1
 def _inverse(self, y):
     map_values = tf.convert_to_tensor(self.map_values)
     flat_y = tf.reshape(y, shape=[-1])
     # Search for the indices of map_values that are closest to flat_y.
     # Since map_values is strictly increasing, the closest is either the
     # first one that is strictly greater than flat_y, or the one before it.
     upper_candidates = tf.minimum(
         tf.size(map_values) - 1,
         tf.searchsorted(map_values, values=flat_y, side='right'))
     lower_candidates = tf.maximum(0, upper_candidates - 1)
     candidates = tf.stack([lower_candidates, upper_candidates], axis=-1)
     lower_cand_diff = tf.abs(flat_y - self._forward(lower_candidates))
     upper_cand_diff = tf.abs(flat_y - self._forward(upper_candidates))
     if self.validate_args:
         with tf.control_dependencies([
                 assert_util.assert_near(tf.minimum(lower_cand_diff,
                                                    upper_cand_diff),
                                         0,
                                         message='inverse value not found')
         ]):
             candidates = tf.identity(candidates)
     candidate_selector = tf.stack([
         tf.range(tf.size(flat_y), dtype=tf.int32),
         tf.argmin([lower_cand_diff, upper_cand_diff], output_type=tf.int32)
     ],
                                   axis=-1)
     return tf.reshape(tf.gather_nd(candidates, candidate_selector),
                       shape=y.shape)
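For reference, a minimal standalone sketch of the same nearest-value lookup with `tf.searchsorted`, using made-up `map_values` and a plain `tf.gather` in place of the bijector's `_forward`:

import tensorflow as tf

# Hypothetical strictly increasing map and query points.
map_values = tf.constant([0., 1., 4., 9.])
flat_y = tf.constant([0.9, 4.2, 8.0])

# Index of the first element strictly greater than each query, clamped to the
# valid range, and the element just before it.
upper = tf.minimum(tf.size(map_values) - 1,
                   tf.searchsorted(map_values, flat_y, side='right'))
lower = tf.maximum(0, upper - 1)

# Keep whichever neighbor is closer to the query.
lower_diff = tf.abs(flat_y - tf.gather(map_values, lower))
upper_diff = tf.abs(flat_y - tf.gather(map_values, upper))
nearest = tf.where(lower_diff <= upper_diff, lower, upper)
# nearest == [1, 2, 3], i.e. map values 1., 4., 9.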
Example 2
    def _cdf(self, k):
        # TODO(b/135263541): Improve numerical precision of categorical.cdf.
        probs = self.probs_parameter()
        num_categories = self._num_categories(probs)

        k, probs = _broadcast_cat_event_and_params(
            k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

        # Since the lowest number in the support is 0, any k < 0 should be zero in
        # the output.
        should_be_zero = k < 0

        # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
        k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

        batch_shape = tf.shape(k)

        # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
        # to handle the case where the batch shape is dynamic, flatten the batch
        # dims (so we know batch_dims=1).
        k_flat_batch = tf.reshape(k, [-1])
        probs_flat_batch = tf.reshape(
            probs, tf.concat(([-1], [num_categories]), axis=0))

        cdf_flat = tf.gather(tf.cumsum(probs_flat_batch, axis=-1),
                             k_flat_batch[..., tf.newaxis],
                             batch_dims=1)

        cdf = tf.reshape(cdf_flat, shape=batch_shape)

        zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
        return tf.where(should_be_zero, zero, cdf)
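The clamp-then-gather pattern above, sketched on toy values (the batch is already flat here, so the reshape step is omitted):

import tensorflow as tf

probs = tf.constant([[0.2, 0.3, 0.5],
                     [0.1, 0.1, 0.8]])      # [batch=2, num_categories=3]
k = tf.constant([-1, 1])                    # one query per batch member

should_be_zero = k < 0
k = tf.clip_by_value(k, 0, 2)               # keep a valid gather index

# Row-wise CDF table: [[0.2, 0.5, 1.0], [0.1, 0.2, 1.0]].
cdf = tf.gather(tf.cumsum(probs, axis=-1),
                k[..., tf.newaxis], batch_dims=1)[..., 0]
cdf = tf.where(should_be_zero, tf.zeros_like(cdf), cdf)
# cdf == [0.0, 0.2]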
Example 3
def _sparse_tensor_dense_matmul(sp_a, b, **kwargs):
    """Returns (batched) matmul of a SparseTensor with a Tensor.

  Args:
    sp_a: `SparseTensor` representing a (batch of) matrices.
    b: `Tensor` representing a (batch of) matrices, with the same batch shape as
      `sp_a`. The shape must be compatible with the shape of `sp_a` and kwargs.
    **kwargs: Keyword arguments to `tf.sparse.sparse_dense_matmul`.

  Returns:
    product: A dense (batch of) matrix-shaped Tensor of the same batch shape and
    dtype as `sp_a` and `b`. If `sp_a` or `b` is adjointed through `kwargs` then
    the shape is adjusted accordingly.
  """
    batch_shape = _get_shape(sp_a)[:-2]

    # Reshape the SparseTensor into a rank-3 SparseTensor, with the
    # batch shape flattened to a single dimension. If the batch rank is 0, then
    # we add a batch dimension of size 1.
    sp_a = tf.sparse.reshape(sp_a,
                             tf.concat([[-1], _get_shape(sp_a)[-2:]], axis=0))
    # Reshape b to stack the batch dimension along the rows.
    b = tf.reshape(b, tf.concat([[-1], _get_shape(b)[-1:]], axis=0))

    # Convert the SparseTensor to a matrix in block diagonal form with blocks of
    # matrices [M, N]. This allows us to use tf.sparse.sparse_dense_matmul, which
    # only accepts rank 2 (Sparse)Tensors.
    out = tf.sparse.sparse_dense_matmul(_sparse_block_diag(sp_a), b, **kwargs)

    # Finally retrieve the original batch shape from the resulting rank 2 Tensor.
    # Note that we avoid inferring the final shape from `sp_a` or `b` because we
    # might have transposed one or both of them.
    return tf.reshape(
        out,
        tf.concat([batch_shape, [-1], _get_shape(out)[-1:]], axis=0))
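To see the block-diagonal trick on concrete numbers, here is a toy sketch in which the block-diagonal operand is written out by hand (the `_sparse_block_diag` helper itself is not shown in this example):

import tensorflow as tf

# Two 2x2 blocks placed on the diagonal of a 4x4 sparse matrix:
# A0 = [[1, 2], [0, 3]] and A1 = [[4, 0], [5, 6]].
block_diag = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 1],
             [2, 2], [3, 2], [3, 3]],
    values=[1., 2., 3., 4., 5., 6.],
    dense_shape=[4, 4])

b = tf.constant([[[1.], [1.]],
                 [[1.], [1.]]])             # batch of right-hand sides, [2, 2, 1]
b_stacked = tf.reshape(b, [-1, 1])          # stack the batch along rows: [4, 1]

out = tf.sparse.sparse_dense_matmul(block_diag, b_stacked)
out = tf.reshape(out, [2, 2, 1])            # back to the batched shape
# out == [[[3.], [3.]], [[4.], [11.]]], i.e. A0 @ b[0] and A1 @ b[1].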
Example 4
  def _mode(self, samples=None):
    # Sample counts can vary by batch member. Use map_fn to compute the mode for
    # each batch separately.
    def _get_mode(samples):
      # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
      count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
      return tf.argmax(count)

    if samples is None:
      samples = tf.convert_to_tensor(self._samples)
    num_samples = self._compute_num_samples(samples)

    # Flatten samples for each batch.
    if self._event_ndims == 0:
      flattened_samples = tf.reshape(samples, [-1, num_samples])
      mode_shape = self._batch_shape_tensor(samples)
    else:
      event_size = tf.reduce_prod(self._event_shape_tensor(samples))
      mode_shape = tf.concat(
          [self._batch_shape_tensor(samples),
           self._event_shape_tensor(samples)],
          axis=0)
      flattened_samples = tf.reshape(samples, [-1, num_samples, event_size])

    indices = tf.map_fn(_get_mode, flattened_samples, dtype=tf.int64)
    full_indices = tf.stack(
        [tf.range(tf.shape(indices)[0]),
         tf.cast(indices, tf.int32)], axis=1)

    mode = tf.gather_nd(flattened_samples, full_indices)
    return tf.reshape(mode, mode_shape)
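The per-batch helper amounts to counting distinct samples and taking the most frequent one; a rough standalone equivalent using the public `tf.unique_with_counts` (toy data, scalar events only):

import tensorflow as tf

samples = tf.constant([3, 1, 3, 2, 3, 1])            # one batch member
values, _, counts = tf.unique_with_counts(samples)
mode = tf.gather(values, tf.argmax(counts))           # == 3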
Example 5
    def _log_prob(self, x):
        logits = self._logits_parameter_no_checks()
        event_size = self._event_size(logits)

        x = tf.cast(x, logits.dtype)
        x = self._maybe_assert_valid_sample(x, dtype=logits.dtype)

        # broadcast logits or x if need be.
        if (not tensorshape_util.is_fully_defined(x.shape)
                or not tensorshape_util.is_fully_defined(logits.shape)
                or x.shape != logits.shape):
            broadcast_shape = tf.broadcast_dynamic_shape(
                tf.shape(logits), tf.shape(x))
            logits = tf.broadcast_to(logits, broadcast_shape)
            x = tf.broadcast_to(x, broadcast_shape)

        logits_shape = tf.shape(tf.reduce_sum(logits, axis=-1))
        logits_2d = tf.reshape(logits, [-1, event_size])
        x_2d = tf.reshape(x, [-1, event_size])
        ret = -tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(x_2d), logits=logits_2d)

        # Reshape back to user-supplied batch and sample dims prior to 2D reshape.
        ret = tf.reshape(ret, logits_shape)
        return ret
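The identity this relies on is that, for a one-hot `x`, `-softmax_cross_entropy_with_logits` is exactly the log-probability of the selected class; a quick check on toy values:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])
x = tf.constant([[0.0, 1.0, 0.0]])                    # one-hot sample of class 1

lp = -tf.nn.softmax_cross_entropy_with_logits(labels=x, logits=logits)
direct = tf.math.log_softmax(logits)[:, 1]            # log P(class 1)
# lp and direct agree up to floating-point error.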
Example 6
 def _call_reshape_input_output(self, fn, x, extra_kwargs=None):
     """Calls `fn`, appropriately reshaping its input `x` and output."""
     # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
     # because it is possible the user provided extra kwargs would itself
     # have `fn` and/or `x` as a key.
     with tf.control_dependencies(self._runtime_assertions +
                                  self._validate_sample_arg(x)):
         sample_shape, static_sample_shape = self._sample_shape(x)
         old_shape = tf.concat([
             sample_shape,
             self.distribution.batch_shape_tensor(),
             self.event_shape_tensor(),
         ],
                               axis=0)
         x_reshape = tf.reshape(x, old_shape)
         result = fn(x_reshape, **
                     extra_kwargs) if extra_kwargs else fn(x_reshape)
         new_shape = tf.concat([
             sample_shape,
             self._batch_shape_unexpanded,
         ],
                               axis=0)
         result = tf.reshape(result, new_shape)
         if (tensorshape_util.rank(static_sample_shape) is not None
                 and tensorshape_util.rank(self.batch_shape) is not None):
             new_shape = tensorshape_util.concatenate(
                 static_sample_shape, self.batch_shape)
             tensorshape_util.set_shape(result, new_shape)
         return result
Example 7
    def _variance(self):
        with tf.control_dependencies(self._runtime_assertions):
            probs = self._marginal_hidden_probs()
            # probs :: num_steps batch_shape num_states
            means = self._observation_distribution.mean()
            # means :: observation_batch_shape[:-1] num_states
            #          observation_event_shape
            means_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_states],
                self._observation_distribution.event_shape_tensor()
            ],
                                    axis=0)
            means = tf.broadcast_to(means, means_shape)
            # means :: batch_shape num_states observation_event_shape

            observation_event_shape = (
                self._observation_distribution.event_shape_tensor())
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
            flat_probs_shape = [self._num_steps, batch_size, self._num_states]
            flat_means_shape = [
                batch_size, 1, self._num_states,
                tf.reduce_prod(observation_event_shape)
            ]

            flat_probs = tf.reshape(probs, flat_probs_shape)
            # flat_probs :: num_steps batch_size num_states
            flat_means = tf.reshape(means, flat_means_shape)
            # flat_means :: batch_size 1 num_states observation_event_size
            flat_mean = tf.einsum("ijk,jmkl->jiml", flat_probs, flat_means)
            # flat_mean :: batch_size num_steps 1 observation_event_size

            variances = self._observation_distribution.variance()
            variances = tf.broadcast_to(variances, means_shape)
            # variances :: batch_shape num_states observation_event_shape
            flat_variances = tf.reshape(variances, flat_means_shape)
            # flat_variances :: batch_size 1 num_states observation_event_size

            # For a mixture of n distributions with mixture probabilities
            # p[i], and where the individual distributions have means and
            # variances given by mean[i] and var[i], the variance of
            # the mixture is given by:
            #
            # var = sum i=1..n p[i] * ((mean[i] - mean)**2 + var[i])

            flat_variance = tf.einsum("ijk,jikl->jil", flat_probs,
                                      (flat_means - flat_mean)**2 +
                                      flat_variances)
            # flat_variance :: batch_size num_steps observation_event_size

            unflat_mean_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_steps],
                observation_event_shape
            ],
                                          axis=0)

            # returns :: batch_shape num_steps observation_event_shape
            return tf.reshape(flat_variance, unflat_mean_shape)
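A small numeric check of the mixture-variance identity in the comment above (hypothetical two-component mixture):

import numpy as np

p = np.array([0.25, 0.75])        # mixture probabilities
mean = np.array([0.0, 2.0])       # component means
var = np.array([1.0, 4.0])        # component variances

mixture_mean = np.sum(p * mean)                              # 1.5
mixture_var = np.sum(p * ((mean - mixture_mean)**2 + var))   # 4.0
# Matches E[X^2] - E[X]^2 = (0.25 * 1.0 + 0.75 * 8.0) - 1.5**2 = 4.0.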
Example 8
 def _sample_n(self, n, seed=None):
     logits = self._logits_parameter_no_checks()
     logits_2d = tf.reshape(logits, [-1, self._num_categories(logits)])
     sample_dtype = tf.int64 if dtype_util.size(
         self.dtype) > 4 else tf.int32
     draws = tf.random.categorical(logits_2d,
                                   n,
                                   dtype=sample_dtype,
                                   seed=seed)
     draws = tf.cast(draws, self.dtype)
     return tf.reshape(tf.transpose(draws),
                       shape=tf.concat(
                           [[n], self._batch_shape_tensor(logits)], axis=0))
Example 9
 def _sample_n(self, n, seed=None):
     logits = self._logits_parameter_no_checks()
     sample_shape = prefer_static.concat(
         [[n], prefer_static.shape(logits)], 0)
     event_size = self._event_size(logits)
     if tensorshape_util.rank(logits.shape) == 2:
         logits_2d = logits
     else:
         logits_2d = tf.reshape(logits, [-1, event_size])
     samples = tf.random.categorical(logits_2d, n, seed=seed)
     samples = tf.transpose(a=samples)
     samples = tf.one_hot(samples, event_size, dtype=self.dtype)
     ret = tf.reshape(samples, sample_shape)
     return ret
Example 10
 def _log_prob(self, x, **kwargs):
     batch_ndims = prefer_static.rank_from_shape(
         self.distribution.batch_shape_tensor,
         self.distribution.batch_shape)
     extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
     event_ndims = prefer_static.rank_from_shape(
         self.distribution.event_shape_tensor,
         self.distribution.event_shape)
     ndims = prefer_static.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=tf.pad(
                        tf.shape(x),
                        paddings=[[prefer_static.maximum(0, -d), 0]],
                        constant_values=1))
     sample_ndims = prefer_static.maximum(0, d)
     # (2) Transpose x's dims.
     sample_dims = prefer_static.range(0, sample_ndims)
     batch_dims = prefer_static.range(sample_ndims,
                                      sample_ndims + batch_ndims)
     extra_sample_dims = prefer_static.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = prefer_static.range(
         sample_ndims + batch_ndims + extra_sample_ndims, ndims)
     perm = prefer_static.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     x = tf.transpose(a=x, perm=perm)
     # (3) Compute x's log_prob.
     lp = self.distribution.log_prob(x, **kwargs)
     # (4) Make the final reduction in x.
     axis = prefer_static.range(sample_ndims,
                                sample_ndims + extra_sample_ndims)
     return tf.reduce_sum(lp, axis=axis)
Example 11
 def _call_and_reshape_output(self,
                              fn,
                              event_shape_list=None,
                              static_event_shape_list=None,
                              extra_kwargs=None):
     """Calls `fn` and appropriately reshapes its output."""
     # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
     # because it is possible the user provided extra kwargs would itself
     # have `fn`, `event_shape_list`, `static_event_shape_list` and/or
     # `extra_kwargs` as keys.
     with tf.control_dependencies(self._runtime_assertions):
         if event_shape_list is None:
             event_shape_list = [self._event_shape_tensor()]
         if static_event_shape_list is None:
             static_event_shape_list = [self.event_shape]
         new_shape = tf.concat([self._batch_shape_unexpanded] +
                               event_shape_list,
                               axis=0)
         result = tf.reshape(
             fn(**extra_kwargs) if extra_kwargs else fn(), new_shape)
         if (tensorshape_util.rank(self.batch_shape) is not None
                 and tensorshape_util.rank(self.event_shape) is not None):
             event_shape = tf.TensorShape([])
             for rss in static_event_shape_list:
                 event_shape = tensorshape_util.concatenate(
                     event_shape, rss)
             static_shape = tensorshape_util.concatenate(
                 self.batch_shape, event_shape)
             tensorshape_util.set_shape(result, static_shape)
         return result
Example 12
    def _mean(self, **kwargs):
        if not self.bijector.is_constant_jacobian:
            raise NotImplementedError("mean is not implemented for non-affine "
                                      "bijectors")

        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
        x = self.distribution.mean(**distribution_kwargs)

        if self._is_maybe_batch_override or self._is_maybe_event_override:
            # A batch (respectively event) shape override is only allowed if the batch
            # (event) shape of the base distribution is [], so concatenating all the
            # shapes does the right thing.
            new_shape = prefer_static.concat([
                prefer_static.ones_like(self._override_batch_shape),
                self.distribution.batch_shape_tensor(),
                prefer_static.ones_like(self._override_event_shape),
                self.distribution.event_shape_tensor(),
            ], 0)
            x = tf.reshape(x, new_shape)
            new_shape = prefer_static.concat(
                [self.batch_shape_tensor(),
                 self.event_shape_tensor()], 0)
            x = tf.broadcast_to(x, new_shape)

        y = self.bijector.forward(x, **bijector_kwargs)

        sample_shape = tf.convert_to_tensor([],
                                            dtype=tf.int32,
                                            name="sample_shape")
        y = self._set_sample_static_shape(y, sample_shape)
        return y
Example 13
 def _inverse(self, y):
     output_shape, output_tensorshape = _replace_event_shape_in_shape_tensor(
         tf.shape(y), self._event_shape_out, self._event_shape_in,
         self.validate_args)
     x = tf.reshape(y, output_shape)
     tensorshape_util.set_shape(x, output_tensorshape)
     return x
Example 14
 def _cdf(self, x):
     x = tf.convert_to_tensor(x, name='x')
     flat_x = tf.reshape(x, shape=[-1])
     upper_bound = tf.searchsorted(self.outcomes,
                                   values=flat_x,
                                   side='right')
     values_at_ub = tf.gather(
         self.outcomes,
         indices=tf.minimum(
             upper_bound,
             dist_util.prefer_static_shape(self.outcomes)[-1] - 1))
     should_use_upper_bound = self._is_equal_or_close(flat_x, values_at_ub)
     indices = tf.where(should_use_upper_bound, upper_bound,
                        upper_bound - 1)
     return self._categorical.cdf(
         tf.reshape(indices, shape=dist_util.prefer_static_shape(x)))
Example 15
    def _make_columnar(self, x):
        """Ensures non-scalar input has at least one column.

    Example:
      If `x = [1, 2, 3]` then the output is `[[1, 2, 3]]`.

      If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.

      If `x = 1` then the output is unchanged.

    Args:
      x: `Tensor`.

    Returns:
      columnar_x: `Tensor` with at least two dimensions.
    """
        if tensorshape_util.rank(x.shape) is not None:
            if tensorshape_util.rank(x.shape) == 1:
                x = x[tf.newaxis, :]
            return x
        shape = tf.shape(x)
        maybe_expanded_shape = tf.concat([
            shape[:-1],
            distribution_util.pick_vector(tf.equal(tf.rank(x), 1), [1],
                                          np.array([], dtype=np.int32)),
            shape[-1:],
        ], 0)
        return tf.reshape(x, maybe_expanded_shape)
Example 16
    def _call_sample_n(self, sample_shape, seed, name, **kwargs):
        # We override `_call_sample_n` rather than `_sample_n` so we can ensure that
        # the result of `self.bijector.forward` is not modified (and thus caching
        # works).
        with self._name_and_control_scope(name):
            sample_shape = tf.convert_to_tensor(sample_shape,
                                                dtype=tf.int32,
                                                name="sample_shape")
            sample_shape, n = self._expand_sample_shape_to_vector(
                sample_shape, "sample_shape")

            distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(
                kwargs)

            # First, generate samples. We will possibly generate extra samples in the
            # event that we need to reinterpret the samples as part of the
            # event_shape.
            x = self._sample_n(n, seed, **distribution_kwargs)

            # Next, we reshape `x` into its final form. We do this prior to the call
            # to the bijector to ensure that the bijector caching works.
            batch_event_shape = tf.shape(x)[1:]
            final_shape = tf.concat([sample_shape, batch_event_shape], 0)
            x = tf.reshape(x, final_shape)

            # Finally, we apply the bijector's forward transformation. For caching to
            # work, it is imperative that this is the last modification to the
            # returned result.
            y = self.bijector.forward(x, **bijector_kwargs)
            y = self._set_sample_static_shape(y, sample_shape)

            return y
Example 17
 def _forward(self, x):
     output_shape, output_tensorshape = _replace_event_shape_in_shape_tensor(
         tf.shape(x), self._event_shape_in, self._event_shape_out,
         self.validate_args)
     y = tf.reshape(x, output_shape)
     tensorshape_util.set_shape(y, output_tensorshape)
     return y
Example 18
 def _reshape_part(part, dtype, event_shape):
     part = tf.cast(part, dtype)
     static_rank = tf.get_static_value(
         ps.rank_from_shape(event_shape))
     if static_rank == 1:
         return part
     new_shape = ps.concat([ps.shape(part)[:-1], event_shape],
                           axis=-1)
     return tf.reshape(part, ps.cast(new_shape, tf.int32))
Example 19
def _extract_log_probs(num_states, dist):
    """Tabulate log probabilities from a batch of distributions."""

    states = tf.reshape(
        tf.range(num_states),
        tf.concat([[num_states],
                   tf.ones_like(dist.batch_shape_tensor())],
                  axis=0))
    return distribution_util.move_dimension(dist.log_prob(states), 0, -1)
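A sketch of the tabulation on a concrete distribution (hypothetical `tfd.Categorical` transition model; `tf.transpose` stands in for the internal `move_dimension` helper, which suffices for this rank-2 case):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
num_states = 3

dist = tfd.Categorical(probs=[[0.6, 0.3, 0.1],
                              [0.2, 0.2, 0.6]])       # batch_shape == [2]

states = tf.reshape(
    tf.range(num_states),
    tf.concat([[num_states], tf.ones_like(dist.batch_shape_tensor())], axis=0))
# states has shape [3, 1]; log_prob broadcasts it against the batch to [3, 2].
table = tf.transpose(dist.log_prob(states))           # [batch, num_states] == [2, 3]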
Example 20
 def _log_prob(self, x):
     x = tf.convert_to_tensor(x, name='x')
     right_indices = tf.minimum(
         tf.size(self.outcomes) - 1,
         tf.reshape(
             tf.searchsorted(self.outcomes,
                             values=tf.reshape(x, shape=[-1]),
                             side='right'),
             dist_util.prefer_static_shape(x)))
     use_right_indices = self._is_equal_or_close(
         x, tf.gather(self.outcomes, indices=right_indices))
     left_indices = tf.maximum(0, right_indices - 1)
     use_left_indices = self._is_equal_or_close(
         x, tf.gather(self.outcomes, indices=left_indices))
     log_probs = self._categorical.log_prob(
         tf.where(use_left_indices, left_indices, right_indices))
     return tf.where(tf.logical_not(use_left_indices | use_right_indices),
                     dtype_util.as_numpy_dtype(log_probs.dtype)(-np.inf),
                     log_probs)
Example 21
 def _sample_n(self, n, seed=None, **kwargs):
     with tf.control_dependencies(self._runtime_assertions):
         x = self.distribution.sample(sample_shape=n, seed=seed, **kwargs)
         new_shape = tf.concat([
             [n],
             self._batch_shape_unexpanded,
             self.event_shape_tensor(),
         ],
                               axis=0)
         return tf.reshape(x, new_shape)
Example 22
 def _pad_sample_dims(self, x):
     with tf.name_scope("pad_sample_dims"):
         ndims = tensorshape_util.rank(x.shape) if tensorshape_util.rank(
             x.shape) is not None else tf.rank(x)
         shape = tf.shape(x)
         d = ndims - self._event_ndims
         x = tf.reshape(x,
                        shape=tf.concat([shape[:d], [1], shape[d:]],
                                        axis=0))
         return x
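What the reshape does, on a tensor with made-up shapes (`event_ndims = 1`):

import tensorflow as tf

x = tf.zeros([5, 2, 3])             # [sample, batch, event]
event_ndims = 1
d = len(x.shape) - event_ndims      # position for the inserted length-1 axis
shape = tf.shape(x)
padded = tf.reshape(x, tf.concat([shape[:d], [1], shape[d:]], axis=0))
# padded.shape == [5, 2, 1, 3]: a length-1 axis now sits just before the event dims.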
Example 23
  def _sample_n(self, n, seed=None):
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
    # ids as a [n]-shaped vector.
    distributions = self.poisson_and_mixture_distributions()
    dist, mixture_dist = distributions
    batch_size = tensorshape_util.num_elements(self.batch_shape)
    if batch_size is None:
      batch_size = tf.reduce_prod(
          self._batch_shape_tensor(distributions=distributions))
    # We need to 'sample extra' from the mixture distribution if it doesn't
    # already specify a probs vector for each batch coordinate.
    # We only support this kind of reduced broadcasting, i.e., there is exactly
    # one probs vector for all batch dims or one for each.
    stream = SeedStream(seed, salt='PoissonLogNormalQuadratureCompound')
    ids = mixture_dist.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                mixture_dist.is_scalar_batch(),
                [batch_size],
                np.int32([]))),
        seed=stream())
    # We need to flatten batch dims in case mixture_dist has its own
    # batch dims.
    ids = tf.reshape(
        ids,
        shape=concat_vectors([n],
                             distribution_util.pick_vector(
                                 self.is_scalar_batch(), np.int32([]),
                                 np.int32([-1]))))

    # Stride `quadrature_size` for `batch_size` number of times.
    offset = tf.range(
        start=0,
        limit=batch_size * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=ids.dtype)
    ids = ids + offset
    rate = tf.gather(tf.reshape(dist.rate, shape=[-1]), ids)
    rate = tf.reshape(
        rate, shape=concat_vectors([n], self._batch_shape_tensor(
            distributions=distributions)))
    return tf.random.poisson(lam=rate, shape=[], dtype=self.dtype, seed=seed)
Example 24
 def _sample_one_batch_member(args):
     logits, num_cat_samples = args[0], args[1]  # [K], []
     # x has shape [1, num_cat_samples = num_samples * num_trials]
     x = tf.random.categorical(logits[tf.newaxis, ...],
                               num_cat_samples,
                               seed=seed)
     x = tf.reshape(x, shape=[num_samples,
                              -1])  # [num_samples, num_trials]
     x = tf.one_hot(
         x, depth=num_classes)  # [num_samples, num_trials, num_classes]
     x = tf.reduce_sum(x, axis=-2)  # [num_samples, num_classes]
     return tf.cast(x, dtype=dtype)
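The underlying trick is that summing one-hot encodings of `num_trials` categorical draws yields multinomial counts; a standalone sketch with made-up sizes:

import tensorflow as tf

logits = tf.math.log([0.2, 0.3, 0.5])       # num_classes = 3
num_samples, num_trials = 4, 10

x = tf.random.categorical(logits[tf.newaxis, ...],
                          num_samples * num_trials, seed=17)
x = tf.reshape(x, [num_samples, num_trials])
counts = tf.reduce_sum(tf.one_hot(x, depth=3), axis=-2)
# counts: [num_samples, 3]; every row sums to num_trials.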
Example 25
 def _expand_base_distribution_mean(self):
     """Ensures `self.distribution.mean()` has `[batch, event]` shape."""
     single_draw_shape = concat_vectors(self.batch_shape_tensor(),
                                        self.event_shape_tensor())
     m = tf.reshape(
         self.distribution.mean(),  # A scalar.
         shape=tf.ones_like(single_draw_shape, dtype=tf.int32))
     m = tf.tile(m, multiples=single_draw_shape)
     tensorshape_util.set_shape(
         m, tensorshape_util.concatenate(self.batch_shape,
                                         self.event_shape))
     return m
Example 26
    def _stddev(self):
        with tf.control_dependencies(self._assertions):
            distribution_means = [d.mean() for d in self.components]
            distribution_devs = [d.stddev() for d in self.components]
            cat_probs = self._cat_probs(log_probs=False)

            stacked_means = tf.stack(distribution_means, axis=-1)
            stacked_devs = tf.stack(distribution_devs, axis=-1)
            cat_probs = [self._expand_to_event_rank(c_p) for c_p in cat_probs]
            broadcasted_cat_probs = (tf.stack(cat_probs, axis=-1) *
                                     tf.ones_like(stacked_means))

            batched_dev = distribution_util.mixture_stddev(
                tf.reshape(broadcasted_cat_probs,
                           [-1, len(self.components)]),
                tf.reshape(stacked_means, [-1, len(self.components)]),
                tf.reshape(stacked_devs, [-1, len(self.components)]))

            # I.e. re-shape to list(batch_shape) + list(event_shape).
            return tf.reshape(batched_dev,
                              tf.shape(broadcasted_cat_probs)[:-1])
Example 27
  def _entropy(self):
    samples = tf.convert_to_tensor(self.samples)
    num_samples = self._compute_num_samples(samples)
    entropy_shape = self._batch_shape_tensor(samples)

    # Flatten samples for each batch.
    if self._event_ndims == 0:
      samples = tf.reshape(samples, [-1, num_samples])
    else:
      event_size = tf.reduce_prod(self.event_shape_tensor())
      samples = tf.reshape(samples, [-1, num_samples, event_size])

    # Use map_fn to compute entropy for each batch separately.
    def _get_entropy(samples):
      # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
      count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
      prob = tf.cast(count / num_samples, dtype=self.dtype)
      entropy = tf.reduce_sum(-prob * tf.math.log(prob))
      return entropy

    entropy = tf.map_fn(_get_entropy, samples, dtype=self.dtype)
    return tf.reshape(entropy, entropy_shape)
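The count-to-entropy step on its own, for a single batch member with scalar events (toy data, public `tf.unique_with_counts`):

import tensorflow as tf

samples = tf.constant([0, 0, 1, 1, 1, 2])             # 6 draws
_, _, counts = tf.unique_with_counts(samples)
prob = tf.cast(counts, tf.float32) / 6.0              # [1/3, 1/2, 1/6]
entropy = tf.reduce_sum(-prob * tf.math.log(prob))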
Example 28
    def _mean(self):
        with tf.control_dependencies(self._runtime_assertions):
            probs = self._marginal_hidden_probs()
            # probs :: num_steps batch_shape num_states
            means = self._observation_distribution.mean()
            # means :: observation_batch_shape[:-1] num_states
            #          observation_event_shape
            means_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_states],
                self._observation_distribution.event_shape_tensor()
            ],
                                    axis=0)
            means = tf.broadcast_to(means, means_shape)
            # means :: batch_shape num_states observation_event_shape

            observation_event_shape = (
                self._observation_distribution.event_shape_tensor())
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
            flat_probs_shape = [self._num_steps, batch_size, self._num_states]
            flat_means_shape = [
                batch_size, self._num_states,
                tf.reduce_prod(observation_event_shape)
            ]

            flat_probs = tf.reshape(probs, flat_probs_shape)
            # flat_probs :: num_steps batch_size num_states
            flat_means = tf.reshape(means, flat_means_shape)
            # flat_means :: batch_size num_states observation_event_size
            flat_mean = tf.einsum("ijk,jkl->jil", flat_probs, flat_means)
            # flat_mean :: batch_size num_steps observation_event_size
            unflat_mean_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_steps],
                observation_event_shape
            ],
                                          axis=0)
            # returns :: batch_shape num_steps observation_event_shape
            return tf.reshape(flat_mean, unflat_mean_shape)
Example 29
 def _compute_quantiles():
   """Helper to build quantiles."""
   # Omit {0, 1} since they might lead to Inf/NaN.
   zero = tf.zeros([], dtype=dist.dtype)
   edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
   # Expand edges so it broadcasts across batch dims.
   edges = tf.reshape(
       edges,
       shape=tf.concat(
           [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
   quantiles = dist.quantile(edges)
   # Cyclically permute left by one.
   perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
   quantiles = tf.transpose(a=quantiles, perm=perm)
   return quantiles
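A sketch of the same quantile grid for a concrete case (hypothetical standard Normal, `quadrature_size = 3`, no batch dims, so the cyclic transpose is a no-op and omitted):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
dist = tfd.Normal(loc=0., scale=1.)
quadrature_size = 3

# Interior grid points of (0, 1); the endpoints are dropped to avoid +/- inf.
edges = tf.linspace(0., 1., quadrature_size + 3)[1:-1]   # [0.2, 0.4, 0.6, 0.8]
quantiles = dist.quantile(edges)                          # 4 interior quantiles of N(0, 1)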
Example 30
 def _expand_mix_distribution_probs(self):
     p = self.mixture_distribution.probs_parameter()  # [B, deg]
     deg = tf.compat.dimension_value(
         tensorshape_util.with_rank_at_least(p.shape, 1)[-1])
     if deg is None:
         deg = tf.shape(p)[-1]
     event_ndims = tensorshape_util.rank(self.event_shape)
     if event_ndims is None:
         event_ndims = tf.shape(self.event_shape_tensor())[0]
     expand_shape = tf.concat([
         self.mixture_distribution.batch_shape_tensor(),
         tf.ones([event_ndims], dtype=tf.int32),
         [deg],
     ],
                              axis=0)
     return tf.reshape(p, shape=expand_shape)