Example #1
 def _log_prob(self, x, **kwargs):
     batch_ndims = prefer_static.rank_from_shape(
         self.distribution.batch_shape_tensor,
         self.distribution.batch_shape)
     extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
     event_ndims = prefer_static.rank_from_shape(
         self.distribution.event_shape_tensor,
         self.distribution.event_shape)
     ndims = prefer_static.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=tf.pad(
                        tf.shape(x),
                        paddings=[[prefer_static.maximum(0, -d), 0]],
                        constant_values=1))
     sample_ndims = prefer_static.maximum(0, d)
     # (2) Transpose x's dims.
     sample_dims = prefer_static.range(0, sample_ndims)
     batch_dims = prefer_static.range(sample_ndims,
                                      sample_ndims + batch_ndims)
     extra_sample_dims = prefer_static.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = prefer_static.range(
         sample_ndims + batch_ndims + extra_sample_ndims, ndims)
     perm = prefer_static.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     x = tf.transpose(a=x, perm=perm)
     # (3) Compute x's log_prob.
     lp = self.distribution.log_prob(x, **kwargs)
     # (4) Sum the log_prob over the extra sample dims.
     axis = prefer_static.range(sample_ndims,
                                sample_ndims + extra_sample_ndims)
     return tf.reduce_sum(lp, axis=axis)
 def _maybe_rotate_dims(self, x, rotate_right=False):
     """Helper which rolls left event_dims left or right event_dims right."""
     needs_rotation_const = tf.get_static_value(self._needs_rotation)
     if needs_rotation_const is not None and not needs_rotation_const:
         return x
     ndims = prefer_static.rank(x)
     n = (ndims -
          self._rotate_ndims) if rotate_right else self._rotate_ndims
     perm = prefer_static.concat(
         [prefer_static.range(n, ndims),
          prefer_static.range(0, n)], axis=0)
     return tf.transpose(a=x, perm=perm)
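A minimal standalone sketch of the axis-group permutation that `_log_prob` above builds (the axis counts and shapes here are made up for illustration): `tf.range` produces each contiguous block of axis indices, `tf.concat` glues the blocks together in the new order, and `tf.transpose` applies the result.

import tensorflow as tf

# Layout assumed for illustration: [sample, batch, batch, extra_sample, event].
sample_ndims, batch_ndims, extra_sample_ndims = 1, 2, 1
x = tf.zeros([7, 3, 4, 5, 2])

sample_dims = tf.range(0, sample_ndims)
batch_dims = tf.range(sample_ndims, sample_ndims + batch_ndims)
extra_sample_dims = tf.range(sample_ndims + batch_ndims,
                             sample_ndims + batch_ndims + extra_sample_ndims)
event_dims = tf.range(sample_ndims + batch_ndims + extra_sample_ndims,
                      tf.rank(x))
perm = tf.concat(
    [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
print(tf.transpose(x, perm=perm).shape)  # (7, 5, 3, 4, 2)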
Example #3
 def _sample_n(self, n, seed=None):
     logits = self._logits_parameter_no_checks()
     logits_2d = tf.reshape(logits, [-1, self._num_categories(logits)])
     sample_dtype = (
         tf.int64 if dtype_util.size(self.dtype) > 4 else tf.int32)
     draws = tf.random.categorical(logits_2d,
                                   n,
                                   dtype=sample_dtype,
                                   seed=seed)
     draws = tf.cast(draws, self.dtype)
     return tf.reshape(tf.transpose(draws),
                       shape=tf.concat(
                           [[n], self._batch_shape_tensor(logits)], axis=0))
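Worth noting about the transpose above: `tf.random.categorical` returns draws shaped [batch_size, num_samples], so transposing moves the sample axis to the front before the batch shape is restored. A small self-contained sketch (the logits here are made up):

import tensorflow as tf

logits_2d = tf.math.log([[0.2, 0.8], [0.5, 0.5], [0.9, 0.1]])  # 3 batch members
n = 4
draws = tf.random.categorical(logits_2d, n, seed=17)  # shape [3, 4]
samples = tf.transpose(draws)                         # shape [4, 3]: sample axis first
print(samples.shape)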
Example #4
 def _sample_n(self, n, seed=None):
     logits = self._logits_parameter_no_checks()
     sample_shape = prefer_static.concat(
         [[n], prefer_static.shape(logits)], 0)
     event_size = self._event_size(logits)
     if tensorshape_util.rank(logits.shape) == 2:
         logits_2d = logits
     else:
         logits_2d = tf.reshape(logits, [-1, event_size])
     samples = tf.random.categorical(logits_2d, n, seed=seed)
     samples = tf.transpose(a=samples)
     samples = tf.one_hot(samples, event_size, dtype=self.dtype)
     ret = tf.reshape(samples, sample_shape)
     return ret
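The one-hot variant follows the same convention and then expands each drawn index into an event vector. A minimal sketch with assumed logits and sizes:

import tensorflow as tf

logits = tf.math.log([[0.3, 0.7], [0.6, 0.4]])     # batch of 2, event_size 2
n, event_size = 3, 2
draws = tf.random.categorical(logits, n, seed=42)   # [batch, n] = [2, 3]
draws = tf.transpose(draws)                         # [n, batch] = [3, 2]
samples = tf.one_hot(draws, event_size)             # [n] + batch + [event] = [3, 2, 2]
print(samples.shape)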
Example #5
 def _sample_n(self, n, seed=None):
     samples = tf.convert_to_tensor(self._samples)
     indices = tf.random.uniform([n],
                                 maxval=self._compute_num_samples(samples),
                                 dtype=tf.int32,
                                 seed=seed)
     draws = tf.gather(samples, indices, axis=self._samples_axis)
     axes = tf.concat([[self._samples_axis],
                       tf.range(self._samples_axis, dtype=tf.int32),
                       tf.range(self._event_ndims, dtype=tf.int32) +
                       self._samples_axis + 1],
                      axis=0)
     draws = tf.transpose(a=draws, perm=axes)
     return draws
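In `_sample_n` above, `tf.gather` leaves the freshly drawn axis at `self._samples_axis`; the concatenated `axes` permutation moves it to the front while keeping the leading batch axes and the trailing event axes in their original order. A standalone sketch with assumed values:

import tensorflow as tf

samples_axis, event_ndims = 2, 1
# Layout assumed after the gather: [batch, batch, n, event].
draws = tf.zeros([3, 4, 5, 2])
axes = tf.concat([[samples_axis],
                  tf.range(samples_axis, dtype=tf.int32),
                  tf.range(event_ndims, dtype=tf.int32) + samples_axis + 1],
                 axis=0)
print(tf.transpose(draws, perm=axes).shape)  # (5, 3, 4, 2): n moved to the front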
 def _compute_quantiles():
     """Helper to build quantiles."""
     # Omit {0, 1} since they might lead to Inf/NaN.
     zero = tf.zeros([], dtype=dist.dtype)
     edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
     # Expand edges so it broadcasts across batch dims.
     edges = tf.reshape(
         edges,
         shape=tf.concat(
             [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
     quantiles = dist.quantile(edges)
     # Cyclically permute left by one.
     perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
     quantiles = tf.transpose(a=quantiles, perm=perm)
     return quantiles
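The "cyclically permute left by one" step in `_compute_quantiles` is worth isolating: the perm [1, ..., rank-1, 0] rotates the leading quadrature axis to the back so it lands after the batch dims. A sketch with assumed shapes:

import tensorflow as tf

batch_ndims = 2
quantiles = tf.zeros([9, 2, 3])   # [quadrature, batch, batch]
perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)  # [1, 2, 0]
print(tf.transpose(quantiles, perm=perm).shape)  # (2, 3, 9)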
Example #7
 def _sample_n(self, n, seed, **kwargs):
     fake_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
     event_ndims = prefer_static.rank_from_shape(
         self.distribution.event_shape_tensor,
         self.distribution.event_shape)
     batch_ndims = prefer_static.rank_from_shape(
         self.distribution.batch_shape_tensor,
         self.distribution.batch_shape)
     perm = prefer_static.concat([
         [0],
         prefer_static.range(1 + fake_sample_ndims,
                             1 + fake_sample_ndims + batch_ndims),
         prefer_static.range(1, 1 + fake_sample_ndims),
         prefer_static.range(1 + fake_sample_ndims + batch_ndims,
                             1 + fake_sample_ndims + batch_ndims + event_ndims),
     ], axis=0)
     x = self.distribution.sample(
         prefer_static.concat([[n], self.sample_shape], axis=0),
         seed=seed,
         **kwargs)
     return tf.transpose(a=x, perm=perm)
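Here the underlying distribution emits draws shaped [n] + sample_shape + batch_shape + event_shape, and the perm reorders them to [n] + batch_shape + sample_shape + event_shape. A short sketch with made-up axis counts:

import tensorflow as tf

n, fake_sample_ndims, batch_ndims, event_ndims = 2, 1, 2, 1
x = tf.zeros([n, 5, 3, 4, 7])   # [n, sample, batch, batch, event]
perm = tf.concat([
    [0],
    tf.range(1 + fake_sample_ndims, 1 + fake_sample_ndims + batch_ndims),
    tf.range(1, 1 + fake_sample_ndims),
    tf.range(1 + fake_sample_ndims + batch_ndims,
             1 + fake_sample_ndims + batch_ndims + event_ndims),
], axis=0)
print(tf.transpose(x, perm=perm).shape)  # (2, 3, 4, 5, 7)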
Example #8
def draw_sample(num_samples, num_classes, logits, num_trials, dtype, seed):
    """Sample a multinomial.

  The batch shape is given by broadcasting num_trials with
  remove_last_dimension(logits).

  Args:
    num_samples: Python int or singleton integer Tensor: number of multinomial
      samples to draw.
    num_classes: Python int or singleton integer Tensor: number of classes.
    logits: Floating Tensor with last dimension k, of (unnormalized) logit
      probabilities per class.
    num_trials: Tensor of number of categorical trials each multinomial consists
      of.  num_trials[..., tf.newaxis] must broadcast with logits.
    dtype: dtype at which to emit samples.
    seed: Random seed.

  Returns:
    samples: Tensor of given dtype and shape [n] + batch_shape + [k].
  """
    with tf.name_scope('draw_sample'):
        # Broadcast num_trials and logits to the same shape.
        num_trials = tf.ones_like(logits[..., 0],
                                  dtype=num_trials.dtype) * num_trials
        logits = tf.ones_like(num_trials[..., tf.newaxis],
                              dtype=logits.dtype) * logits

        # Flatten num_trials and logits.
        # flat_logits has shape [B1B2...Bm, num_classes]
        flat_logits = tf.reshape(logits, [-1, num_classes])
        flat_num_trials = num_samples * tf.reshape(num_trials,
                                                   [-1])  # [B1B2...Bm]

        # Compute each (logits, num_trials) batch member separately via map_fn.

        # Using just one batch tf.random.categorical call doesn't work because that
        # requires num_trials to be the same across all members of the batch of
        # logits.  This restriction makes sense for tf.random.categorical because
        # for it, num_trials is part of the returned shape.  However, the
        # multinomial sampler does not need that restriction, because it sums out
        # exactly that dimension.

        # One possibility would be to draw a batch categorical whose sample count is
        # max(num_trials) and mask out the excess ones.  However, if the elements of
        # num_trials vary widely, this can be wasteful of memory.

        # TODO(b/123763054, b/112152209): Revisit the possibility of writing this
        # with a batch categorical followed by batch unsorted_segment_sum, once both
        # of those work and are memory-efficient enough.
        def _sample_one_batch_member(args):
            logits, num_cat_samples = args[0], args[1]  # [K], []
            # x has shape [1, num_cat_samples = num_samples * num_trials]
            x = tf.random.categorical(logits[tf.newaxis, ...],
                                      num_cat_samples,
                                      seed=seed)
            x = tf.reshape(x, shape=[num_samples,
                                     -1])  # [num_samples, num_trials]
            x = tf.one_hot(
                x, depth=num_classes)  # [num_samples, num_trials, num_classes]
            x = tf.reduce_sum(x, axis=-2)  # [num_samples, num_classes]
            return tf.cast(x, dtype=dtype)

        if seed is not None:
            # Force parallel_iterations to 1 to ensure reproducibility
            # b/139210489
            x = tf.map_fn(
                _sample_one_batch_member,
                [flat_logits, flat_num_trials],
                dtype=dtype,  # [B1B2...Bm, num_samples, num_classes]
                parallel_iterations=1)
        else:
            # Invoke default parallel_iterations behavior
            x = tf.map_fn(_sample_one_batch_member,
                          [flat_logits, flat_num_trials],
                          dtype=dtype)  # [B1B2...Bm, num_samples, num_classes]

        # Reshape the results to the proper shape.
        x = tf.transpose(a=x, perm=[1, 0, 2])
        final_shape = tf.concat(
            [[num_samples], tf.shape(num_trials), [num_classes]], axis=0)
        x = tf.reshape(x, final_shape)

        return x
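The final reshaping is a pattern of its own: `tf.map_fn` returns its results with the flattened batch axis leading, so perm=[1, 0, 2] swaps it with the per-member sample axis before `tf.reshape` restores the batch shape. A stripped-down sketch (shapes assumed):

import tensorflow as tf

num_samples, num_classes = 5, 3
batch_shape = [2, 3]
flat_batch = 6                                   # prod(batch_shape)
x = tf.zeros([flat_batch, num_samples, num_classes])
x = tf.transpose(x, perm=[1, 0, 2])              # [num_samples, flat_batch, k]
x = tf.reshape(x, [num_samples] + batch_shape + [num_classes])
print(x.shape)                                   # (5, 2, 3, 3)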
Example #9
 def _transpose(self, x, perm):
     perm = self._make_perm(tf.rank(x), perm)
     return tf.transpose(a=x, perm=perm)
Example #10
    def _sample_n(self, n, seed=None):
        sample_and_batch_shape = tf.concat([[n], self.batch_shape_tensor()], 0)
        flat_batch_and_sample_shape = tf.stack(
            [tf.reduce_prod(self.batch_shape_tensor()), n])

        # In order to be reparameterizable we sample on the truncated_normal of
        # unit variance and mean and scale (but with the standardized
        # truncation bounds).

        @tf.custom_gradient
        def _std_samples_with_gradients(lower, upper):
            """Standard truncated Normal with gradient support for low, high."""
            # Note: Unlike the convention in tf_probability,
            # parameterized_truncated_normal returns a tensor with the final dimension
            # being the sample dimension.
            std_samples = random_ops.parameterized_truncated_normal(
                shape=flat_batch_and_sample_shape,
                means=0.0,
                stddevs=1.0,
                minvals=lower,
                maxvals=upper,
                dtype=self.dtype,
                seed=seed)

            def grad(dy):
                """Computes a derivative for the min and max parameters.

        This function implements the derivative wrt the truncation bounds, which
        get blocked by the sampler. We use a custom expression for numerical
        stability instead of automatic differentiation on CDF for implicit
        gradients.

        Args:
          dy: output gradients

        Returns:
           The standard normal samples and the gradients wrt the upper
           bound and lower bound.
        """
                # std_samples has an extra dimension (the sample dimension), expand
                # lower and upper so they broadcast along this dimension.
                # See note above regarding parameterized_truncated_normal, the sample
                # dimension is the final dimension.
                lower_broadcast = lower[..., tf.newaxis]
                upper_broadcast = upper[..., tf.newaxis]

                cdf_samples = ((special_math.ndtr(std_samples) -
                                special_math.ndtr(lower_broadcast)) /
                               (special_math.ndtr(upper_broadcast) -
                                special_math.ndtr(lower_broadcast)))

                # tiny, eps are tolerance parameters to ensure we stay away from giving
                # a zero arg to the log CDF expression.

                tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
                eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
                cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

                du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
                            tf.math.log(cdf_samples))
                dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
                            tf.math.log1p(-cdf_samples))

                # Reduce the gradient across the samples
                grad_u = tf.reduce_sum(dy * du, axis=-1)
                grad_l = tf.reduce_sum(dy * dl, axis=-1)
                return [grad_l, grad_u]

            return std_samples, grad

        std_samples = _std_samples_with_gradients(
            tf.reshape(self._standardized_low, [-1]),
            tf.reshape(self._standardized_high, [-1]))

        # The returned shape is [flat_batch x n]
        std_samples = tf.transpose(a=std_samples, perm=[1, 0])

        std_samples = tf.reshape(std_samples, sample_and_batch_shape)
        samples = (std_samples * tf.expand_dims(self._scale, axis=0) +
                   tf.expand_dims(self._loc, axis=0))

        return samples
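The transpose near the end compensates for the low-level sampler placing the sample dimension last; perm=[1, 0] moves it to the front before the batch shape is restored. A minimal sketch that uses an ordinary normal sampler purely for illustration (shapes assumed):

import tensorflow as tf

batch_shape = [2, 3]
flat_batch, n = 6, 4
flat_samples = tf.random.normal([flat_batch, n])   # sample axis last
samples = tf.transpose(flat_samples, perm=[1, 0])  # [n, flat_batch]
samples = tf.reshape(samples, [n] + batch_shape)   # [n] + batch_shape
print(samples.shape)                               # (4, 2, 3)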
Example #11
    def _log_prob(self, x):
        if self.input_output_cholesky:
            x_sqrt = x
        else:
            # Complexity: O(nbk**3)
            x_sqrt = tf.linalg.cholesky(x)

        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()
        x_ndims = tf.rank(x_sqrt)
        num_singleton_axes_to_prepend = (
            tf.maximum(tf.size(batch_shape) + 2, x_ndims) - x_ndims)
        x_with_prepended_singletons_shape = tf.concat([
            tf.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
            tf.shape(x_sqrt)
        ], 0)
        x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
        ndims = tf.rank(x_sqrt)
        # sample_ndims = ndims - batch_ndims - event_ndims
        sample_ndims = ndims - tf.size(batch_shape) - 2
        sample_shape = tf.shape(x_sqrt)[:sample_ndims]

        # We need to be able to pre-multiply each matrix by its corresponding
        # batch scale matrix. Since a Distribution Tensor supports multiple
        # samples per batch, this means we need to reshape the input matrix `x`
        # so that the first b dimensions are batch dimensions and the last two
        # are of shape [dimension, dimension * number_of_samples]. Doing these
        # gymnastics allows us to do a batch_solve.
        #
        # After we're done with sqrt_solve (the batch operation) we need to undo
        # this reshaping so what we're left with is a Tensor partitionable by
        # sample, batch, event dimensions.

        # Complexity: O(nbk**2) since transpose must access every element.
        scale_sqrt_inv_x_sqrt = x_sqrt
        perm = tf.concat(
            [tf.range(sample_ndims, ndims),
             tf.range(0, sample_ndims)], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)
        last_dim_size = (
            tf.cast(self.dimension, dtype=tf.int32) *
            tf.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
        shape = tf.concat([
            x_with_prepended_singletons_shape[sample_ndims:-2],
            [tf.cast(self.dimension, dtype=tf.int32), last_dim_size]
        ], axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

        # Complexity: O(nbM*k) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
        # this step has complexity O(nbk^3).
        scale_sqrt_inv_x_sqrt = self.scale_operator.solve(
            scale_sqrt_inv_x_sqrt)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        shape = tf.concat(
            [tf.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
            axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
        perm = tf.concat([
            tf.range(ndims - sample_ndims, ndims),
            tf.range(0, ndims - sample_ndims)
        ], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)

        # Write V = SS', X = LL'. Then:
        # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
        #              = tr[inv(S) L L' inv(S)']
        #              = tr[(inv(S) L) (inv(S) L)']
        #              = sum_{ik} (inv(S) L)_{ik}**2
        # The second equality follows from the cyclic permutation property.
        # Complexity: O(nbk**2)
        trace_scale_inv_x = tf.reduce_sum(tf.square(scale_sqrt_inv_x_sqrt),
                                          axis=[-2, -1])

        # Complexity: O(nbk)
        half_log_det_x = tf.reduce_sum(
            tf.math.log(tf.linalg.diag_part(x_sqrt)), axis=[-1])

        # Complexity: O(nbk**2)
        log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
                    0.5 * trace_scale_inv_x - self.log_normalization())

        # Set shape hints.
        # Try to merge what we know from the input x with what we know from the
        # parameters of this distribution.
        if (tensorshape_util.rank(x.shape) is not None and
                tensorshape_util.rank(self.batch_shape) is not None):
            tensorshape_util.set_shape(
                log_prob,
                tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

        return log_prob
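The "make batch-op ready, then undo" round trip above is easy to lose track of. The sketch below isolates just the reshuffling, with assumed shapes and the batched solve left out: sample axes rotate to the back, fold into the trailing matrix dimension, and are later split out and rotated to the front again.

import tensorflow as tf

sample_ndims, k, n_samp = 1, 4, 5
batch_shape = [2, 3]
x = tf.zeros([n_samp] + batch_shape + [k, k])   # [sample, batch, batch, k, k]
ndims = tf.rank(x)

# Rotate the sample axes to the back: [batch, batch, k, k, sample].
perm = tf.concat([tf.range(sample_ndims, ndims), tf.range(0, sample_ndims)], 0)
x = tf.transpose(x, perm=perm)
# Fold the samples into the trailing matrix dim: [batch, batch, k, k * sample].
x = tf.reshape(x, batch_shape + [k, k * n_samp])

# ... the batched solve/matmul would operate on x here ...

# Undo: split the samples back out and rotate them to the front.
x = tf.reshape(x, batch_shape + [k, k, n_samp])
perm = tf.concat(
    [tf.range(ndims - sample_ndims, ndims), tf.range(0, ndims - sample_ndims)], 0)
x = tf.transpose(x, perm=perm)
print(x.shape)                                  # (5, 2, 3, 4, 4)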
Example #12
    def _sample_n(self, n, seed):
        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()
        batch_ndims = tf.shape(batch_shape)[0]

        ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
        shape = tf.concat([[n], batch_shape, event_shape], 0)
        stream = SeedStream(seed, salt="Wishart")

        # Complexity: O(nbk**2)
        x = tf.random.normal(shape=shape,
                             mean=0.,
                             stddev=1.,
                             dtype=self.dtype,
                             seed=stream())

        # Complexity: O(nbk)
        # This parametrization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        expanded_df = self.df * tf.ones(
            self.scale_operator.batch_shape_tensor(),
            dtype=dtype_util.base_dtype(self.df.dtype))

        g = tf.random.gamma(shape=[n],
                            alpha=self._multi_gamma_sequence(
                                0.5 * expanded_df, self.dimension),
                            beta=0.5,
                            dtype=self.dtype,
                            seed=stream())

        # Complexity: O(nbk**2)
        x = tf.linalg.band_part(x, -1, 0)  # Tri-lower.

        # Complexity: O(nbk)
        x = tf.linalg.set_diag(x, tf.sqrt(g))

        # Make batch-op ready.
        # Complexity: O(nbk**2)
        perm = tf.concat([tf.range(1, ndims), [0]], 0)
        x = tf.transpose(a=x, perm=perm)
        shape = tf.concat(
            [batch_shape, [event_shape[0]], [event_shape[1] * n]], 0)
        x = tf.reshape(x, shape)

        # Complexity: O(nbM) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each matmul is O(k^3) so
        # this step has complexity O(nbk^3).
        x = self.scale_operator.matmul(x)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        shape = tf.concat([batch_shape, event_shape, [n]], 0)
        x = tf.reshape(x, shape)
        perm = tf.concat([[ndims - 1], tf.range(0, ndims - 1)], 0)
        x = tf.transpose(a=x, perm=perm)

        if not self.input_output_cholesky:
            # Complexity: O(nbk**3)
            x = tf.matmul(x, x, adjoint_b=True)

        return x