Ejemplo n.º 1
0
 def _sample_3d(self, n, seed=None):
     """Specialized inversion sampler for 3D.

     Draws `z ~ Uniform(0, 1)` and maps it through the closed-form inverse
     u = 1 + log(z + (1 - z) * exp(-2 * kappa)) / kappa,
     evaluated in log space with logsumexp. The kappa -> 0 limit (2*z - 1) and
     the z -> 0 limit (-1) are substituted explicitly.

     Args:
       n: Number of samples to draw.
       seed: Seed fed to a `SeedStream` salted with 'von_mises_fisher_3d'.

     Returns:
       `u` with a trailing singleton axis appended, i.e. shape
       `[n] + batch_shape + [1]`.
     """
     stream = SeedStream(seed, salt='von_mises_fisher_3d')
     uniform_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
     z = tf.random.uniform(uniform_shape, seed=stream(), dtype=self.dtype)
     # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
     # be bisected for bounded sampling runtime (i.e. not rejection sampling).
     # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
     conc = self.concentration
     # Replace zero kappa / zero z with 1 so neither the division nor the logs
     # can blow up; the true limits are patched in afterwards.
     safe_conc = tf.where(conc > 0, conc, tf.ones_like(conc))
     safe_z = tf.where(z > 0, z, tf.ones_like(z))
     log_terms = [tf.math.log(safe_z),
                  tf.math.log1p(-safe_z) - 2 * safe_conc]
     safe_u = 1 + tf.reduce_logsumexp(log_terms, axis=0) / safe_conc
     # kappa -> 0 limit of the inversion is 2*z - 1.
     u = tf.where(conc > tf.zeros_like(safe_u), safe_u, 2 * z - 1)
     # z -> 0 limit of the inversion is -1.
     u = tf.where(tf.equal(z, 0), -tf.ones_like(u), u)
     if not self._allow_nan_stats:
         u = tf.debugging.check_numerics(u, 'u in _sample_3d')
     return u[..., tf.newaxis]
Ejemplo n.º 2
0
    def _survival_function(self, y):
        """Survival function P[Y > y] of the quantized variable Y.

        Recall the promise:
          survival_function(y) := P[Y > y]
                                = 0, if y >= high,
                                = 1, if y < low,
                                = P[X > floor(y)], otherwise.

        Args:
          y: Points at which to evaluate the survival function.

        Returns:
          `Tensor` of P[Y > y]. The cutoff values are substituted only for
          cutoffs (`low` / `high`) that are not `None`.
        """
        low = self._low
        high = self._high

        # Y only places mass on integers, so for any real y:
        #   P[Y > y] = P[Y > floor(y)].
        # Using ceil(y) instead would skip the mass sitting at ceil(y) whenever
        # y is not an integer (e.g. P[Y > 0.5] must include P[Y = 1]).
        j = tf.floor(y)

        # P[X > j], used when low <= j < high.
        result_so_far = self.distribution.survival_function(j)

        # Re-define values at the cutoffs. Assuming integer-valued low/high,
        # floor(y) < low iff y < low, and floor(y) >= high iff y >= high.
        if low is not None:
            result_so_far = tf.where(j < low, tf.ones_like(result_so_far),
                                     result_so_far)
        if high is not None:
            result_so_far = tf.where(j >= high, tf.zeros_like(result_so_far),
                                     result_so_far)

        return result_so_far
Ejemplo n.º 3
0
  def _variance(self):
    """Variance `scale**2 * df / (df - 2)`, broadcast to the batch shape.

    Where 1 < df <= 2 the variance is +inf. Where df <= 1 the result is NaN if
    `allow_nan_stats` is set; otherwise the result carries a runtime assertion
    that 1 < df.
    """
    df = tf.convert_to_tensor(self.df)
    scale = tf.convert_to_tensor(self.scale)
    np_dtype = dtype_util.as_numpy_dtype(self.dtype)
    df_above_two = df > 2.
    # The denominator is made safe here (not only in the outer tf.where) so
    # the unselected branch never produces a NaN that would poison gradients.
    safe_denominator = tf.where(df_above_two, df - 2., tf.ones_like(df))
    batch_ones = tf.ones(self._batch_shape_tensor(df=df, scale=scale),
                         dtype=self.dtype)
    # Abs(scale) is unnecessary because scale is squared.
    finite_var = batch_ones * tf.square(scale) * df / safe_denominator
    # For 1 < df <= 2 the variance is infinite.
    result_where_defined = tf.where(df_above_two, finite_var, np_dtype(np.inf))

    if not self.allow_nan_stats:
      return distribution_util.with_dependencies([
          assert_util.assert_less(
              tf.ones([], dtype=self.dtype),
              df,
              message='variance not defined for components of df <= 1'),
      ], result_where_defined)
    return tf.where(df > 1., result_where_defined, np_dtype(np.nan))
Ejemplo n.º 4
0
def assert_finite(x, data=None, summarize=None, message=None, name=None):
    """Assert all elements of `x` are finite.

    Args:
      x: Numeric `Tensor`.
      data: The tensors to print out if the condition is False. Defaults to
        error message and first few entries of `x`.
      summarize: Print this many entries of each tensor.
      message: A string to prefix to the default message.
      name: A name for this operation (optional).
        Defaults to "assert_finite".

    Returns:
      If the value of `x` is statically known and every element is finite,
      `x` itself. Otherwise `tf.identity(x)` with a control dependency on an
      op that raises `InvalidArgumentError` at run time if any element of `x`
      is not finite.

    Raises:
      ValueError: If the value of `x` is statically known and contains a
        non-finite element (NaN or +/-inf).
    """
    with tf.name_scope(name or 'assert_finite'):
        x_ = tf.get_static_value(x)
        if x_ is not None:
            # Static path: check eagerly and skip building a runtime assertion.
            if not np.all(np.isfinite(x_)):
                raise ValueError(message)
            return x
        # Dynamic path: compare the elementwise finiteness mask to all-True.
        assertion = tf1.assert_equal(tf.math.is_finite(x),
                                     tf.ones_like(x, tf.bool),
                                     data=data,
                                     summarize=summarize,
                                     message=message)
        with tf.control_dependencies([assertion]):
            return tf.identity(x)
Ejemplo n.º 5
0
 def _covariance(self):
     """Covariance of the distribution (event dimension <= 2 only).

     Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/

     With h = I_{d/2}(kappa) / I_{d/2 - 1}(kappa) and mean direction mu:
       Cov = (1 - d * h / kappa - h**2) * mu mu^T + (h / kappa) * Id.
     Where concentration <= 0 the result is Id / d.

     Raises:
       ValueError: If the event dimension is not statically known, or is > 2.
     """
     event_dim = tf.compat.dimension_value(self.event_shape[0])
     if event_dim is None:
         raise ValueError(
             'event shape must be statically known for _bessel_ive')
     # TODO(bjp): Enable this; numerically unstable.
     if event_dim > 2:
         raise ValueError(
             'vMF covariance is numerically unstable for dim>2')
     concentration = self.concentration[..., tf.newaxis]
     # Substitute 1 for non-positive concentration so divisions stay finite;
     # those entries are replaced by the Id / d fallback at the end.
     safe_conc = tf.where(concentration > 0, concentration,
                          tf.ones_like(concentration))
     h = (_bessel_ive(event_dim / 2, safe_conc) /
          _bessel_ive(event_dim / 2 - 1, safe_conc))
     intermediate = (
         tf.matmul(self.mean_direction[..., :, tf.newaxis],
                   self.mean_direction[..., tf.newaxis, :]) *
         (1 - event_dim * h / safe_conc - h**2)[..., tf.newaxis])
     cov = tf.linalg.set_diag(
         intermediate,
         tf.linalg.diag_part(intermediate) + (h / safe_conc))
     # `tf.linalg.eye` defaults to float32; request the distribution's dtype so
     # the `tf.where` below does not fail on a dtype mismatch (e.g. float64).
     fallback_cov = tf.linalg.eye(
         event_dim,
         batch_shape=self.batch_shape_tensor(),
         dtype=self.dtype) / event_dim
     return tf.where(
         concentration[..., tf.newaxis] > tf.zeros_like(cov), cov,
         fallback_cov)
Ejemplo n.º 6
0
    def _cdf(self, y):
        """CDF P[Y <= y] of the quantized variable Y.

        Evaluates the base CDF at floor(y). Where a cutoff is not `None` the
        result is overridden: 0 where floor(y) < low, 1 where floor(y) >= high.
        """
        # Mass sits only on integers, so P[Y <= y] == P[Y <= floor(y)].
        j = tf.floor(y)
        cdf_at_j = self.distribution.cdf(j)

        lower_cutoff = self._low
        if lower_cutoff is not None:
            cdf_at_j = tf.where(j < lower_cutoff, tf.zeros_like(cdf_at_j),
                                cdf_at_j)

        upper_cutoff = self._high
        if upper_cutoff is not None:
            cdf_at_j = tf.where(j >= upper_cutoff, tf.ones_like(cdf_at_j),
                                cdf_at_j)

        return cdf_at_j
def _extract_log_probs(num_states, dist):
    """Tabulate log probabilities from a batch of distributions.

    Evaluates `dist.log_prob` at every state index 0, ..., num_states - 1 and
    moves the state axis from the front to the back.

    Args:
      num_states: Number of states to tabulate.
      dist: Distribution whose `log_prob` is evaluated at each state index.

    Returns:
      `Tensor` of log probabilities whose last axis indexes the state.
    """
    # Shape [num_states, 1, ..., 1] (one singleton per batch dimension) so the
    # state indices broadcast against the distribution's batch shape.
    singleton_batch_dims = tf.ones_like(dist.batch_shape_tensor())
    states_shape = tf.concat([[num_states], singleton_batch_dims], axis=0)
    states = tf.reshape(tf.range(num_states), states_shape)
    log_probs = dist.log_prob(states)
    return distribution_util.move_dimension(log_probs, 0, -1)
Ejemplo n.º 8
0
 def _prob(self, x):
     """Density `1 / range` inside [low, high], 0 outside; NaN inputs pass through."""
     low = tf.convert_to_tensor(self.low)
     high = tf.convert_to_tensor(self.high)
     # This `>` is only sound for a continuous uniform.
     outside_support = (x < low) | (x > high)
     inside_density = tf.ones_like(x) / self._range(low=low, high=high)
     density = tf.where(outside_support, tf.zeros_like(x), inside_density)
     return tf.where(tf.math.is_nan(x), x, density)
 def _expand_base_distribution_mean(self):
     """Ensures `self.distribution.mean()` has `[batch, event]` shape.

     Reshapes the (scalar) base mean to an all-ones shape of the right rank,
     tiles it to `batch_shape + event_shape`, and sets the static shape.
     """
     target_shape = concat_vectors(self.batch_shape_tensor(),
                                   self.event_shape_tensor())
     all_ones_shape = tf.ones_like(target_shape, dtype=tf.int32)
     expanded = tf.reshape(self.distribution.mean(), shape=all_ones_shape)
     expanded = tf.tile(expanded, multiples=target_shape)
     static_shape = tensorshape_util.concatenate(self.batch_shape,
                                                 self.event_shape)
     tensorshape_util.set_shape(expanded, static_shape)
     return expanded
Ejemplo n.º 10
0
    def _log_prob(self, x):
        """Log density at `x`; the last axis of `x` and `logits` is the event axis.

        log p(x) = lgamma(k) + (k - 1) * log(temperature)
                   + sum_i log_softmax(logits - x * temperature)_i,
        where k is the event size.
        """
        temperature = tf.convert_to_tensor(self.temperature)
        logits = self._logits_parameter_no_checks()
        x = self._assert_valid_sample(x)

        # Broadcast logits and x against each other unless both static shapes
        # are fully known and already identical.
        needs_broadcast = (
            not tensorshape_util.is_fully_defined(x.shape)
            or not tensorshape_util.is_fully_defined(logits.shape)
            or x.shape != logits.shape)
        if needs_broadcast:
            logits = tf.ones_like(x, dtype=logits.dtype) * logits
            x = tf.ones_like(logits, dtype=x.dtype) * x

        num_categories = tf.cast(self._event_size(logits), x.dtype)
        # Normalization constant.
        log_norm_const = (tf.math.lgamma(num_categories) +
                          (num_categories - 1.) * tf.math.log(temperature))
        # Unnormalized density.
        shifted_logits = logits - x * temperature[..., tf.newaxis]
        log_unnorm_prob = tf.reduce_sum(tf.math.log_softmax(shifted_logits),
                                        axis=[-1], keepdims=False)
        return log_norm_const + log_unnorm_prob
Ejemplo n.º 11
0
def _bessel_ive(v, z, cache=None):
    """Computes I_v(z)*exp(-abs(z)) using a recurrence relation, where z > 0.

    Closed forms are used for v in {-0.5, 0, 0.5, 1}. Orders 1 < v < 2 use the
    recurrence
      I_v(z) = I_{v-2}(z) - 2 * (v - 1) / z * I_{v-1}(z),
    which holds unchanged for the exponentially scaled functions because every
    term carries the same exp(-|z|) factor.

    Args:
      v: Python number, the order. Must be < 2.
      z: `Tensor` argument; expected to be positive.
      cache: Optional dict mapping order -> already-computed `Tensor`. It is
        shared with (and filled by) the recursive calls.

    Returns:
      `Tensor` equal to I_v(z) * exp(-|z|), wrapped in
      `tf.debugging.check_numerics`.

    Raises:
      ValueError: If v >= 2, or if v (or an order the recurrence needs) is not
        one of the supported closed-form orders.
    """
    # TODO(b/67497980): Switch to a more numerically faithful implementation.
    z = tf.convert_to_tensor(z)

    def wrap(result):
        return tf.debugging.check_numerics(result, 'besseli{}'.format(v))

    if float(v) >= 2:
        raise ValueError(
            'Evaluating bessel_i by recurrence becomes imprecise for large v')

    # `cache or {}` would swap an empty caller-supplied dict for a fresh one,
    # so results would never be shared across the recursive calls.
    if cache is None:
        cache = {}
    if v in cache:
        return wrap(cache[v])

    safe_z = tf.where(z > 0, z, tf.ones_like(z))
    if v == 0:
        cache[v] = tf.math.bessel_i0e(z)
    elif v == 1:
        cache[v] = tf.math.bessel_i1e(z)
    elif v == 0.5:
        # sinh(x)*exp(-abs(x)), sinh(x) = (e^x - e^{-x}) / 2
        sinhe = lambda x: (tf.exp(x - tf.abs(x)) - tf.exp(-x - tf.abs(x))) / 2
        cache[v] = (
            np.sqrt(2 / np.pi) * sinhe(z) *
            tf.where(z > 0, tf.math.rsqrt(safe_z), tf.ones_like(safe_z)))
    elif v == -0.5:
        # cosh(x)*exp(-abs(x)), cosh(x) = (e^x + e^{-x}) / 2
        coshe = lambda x: (tf.exp(x - tf.abs(x)) + tf.exp(-x - tf.abs(x))) / 2
        cache[v] = (
            np.sqrt(2 / np.pi) * coshe(z) *
            tf.where(z > 0, tf.math.rsqrt(safe_z), tf.ones_like(safe_z)))
    elif v <= 1:
        # Previously this fell through to a bare KeyError on `cache[v]`.
        raise ValueError(
            'Unsupported order for _bessel_ive: {}; supported closed forms '
            'are -0.5, 0, 0.5 and 1'.format(v))
    else:
        # Recurrence relation (1 < v < 2).
        cache[v] = (_bessel_ive(v - 2, z, cache) -
                    (2 * (v - 1)) * _bessel_ive(v - 1, z, cache) / z)
    return wrap(cache[v])
Ejemplo n.º 12
0
 def _mean(self):
     """Mean: mean_direction * I_{d/2}(kappa) / I_{d/2 - 1}(kappa).

     Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/
     The result is zero wherever concentration <= 0.

     Raises:
       ValueError: If the event dimension is not statically known.
     """
     dim = tf.compat.dimension_value(self.event_shape[0])
     if dim is None:
         raise ValueError(
             'event shape must be statically known for _bessel_ive')
     kappa = self.concentration
     safe_kappa = tf.where(kappa > 0, kappa, tf.ones_like(kappa))
     # The exp(-|kappa|) scaling of both ive terms cancels in the ratio.
     bessel_ratio = (_bessel_ive(dim / 2, safe_kappa) /
                     _bessel_ive(dim / 2 - 1, safe_kappa))
     safe_mean = self.mean_direction * bessel_ratio[..., tf.newaxis]
     is_positive = kappa[..., tf.newaxis] > tf.zeros_like(safe_mean)
     return tf.where(is_positive, safe_mean, tf.zeros_like(safe_mean))
Ejemplo n.º 13
0
 def _log_normalization(self):
     """Computes the log-normalizer of the distribution.

     For kappa > 0 this is
       -[(d/2 - 1) log kappa - (d/2) log(2 pi) - log ive_{d/2-1}(kappa) - |kappa|].
     Elsewhere it is the log surface area of the sphere,
       log 2 + (d/2) log pi - lgamma(d/2).

     Raises:
       ValueError: If the event dimension is not statically known.
     """
     dim = tf.compat.dimension_value(self.event_shape[0])
     if dim is None:
         raise ValueError('vMF _log_normalizer currently only supports '
                          'statically known event shape')
     kappa = self.concentration
     is_positive = kappa > 0
     safe_kappa = tf.where(is_positive, kappa, tf.ones_like(kappa))
     half_dim = dim / 2
     # log(ive) + |kappa| == log(I_v(kappa)), undoing the exponential scaling.
     safe_lognorm = ((half_dim - 1) * tf.math.log(safe_kappa) -
                     half_dim * np.log(2 * np.pi) -
                     tf.math.log(_bessel_ive(half_dim - 1, safe_kappa)) -
                     tf.abs(safe_kappa))
     log_sphere_area = (
         np.log(2.) + half_dim * np.log(np.pi) -
         tf.math.lgamma(tf.cast(half_dim, self.dtype)))
     return tf.where(is_positive, -safe_lognorm, log_sphere_area)
Ejemplo n.º 14
0
def _broadcast_event_and_samples(event, samples, event_ndims):
  """Broadcasts `event` against `samples` and vice versa.

  Args:
    event: `Tensor` of events.
    samples: `Tensor` of samples whose samples axis sits at
      `-event_ndims - 1`.
    event_ndims: Number of event dimensions.

  Returns:
    `(event, samples)`: `event` gains a singleton axis at `-event_ndims - 1`,
    and both tensors are broadcast to a common shape.
  """
  samples_dims = tf.shape(samples)
  # Shape of one dist.sample() call: every axis of `samples` except the
  # samples axis. The trailing slice starts at rank - event_ndims rather than
  # -event_ndims, because -0 would select the whole shape.
  leading_dims = samples_dims[:-event_ndims - 1]
  trailing_dims = samples_dims[tf.rank(samples) - event_ndims:]
  single_draw_shape = tf.concat([leading_dims, trailing_dims], axis=0)

  event = event * tf.ones(single_draw_shape, dtype=event.dtype)
  event = tf.expand_dims(event, axis=-event_ndims - 1)
  samples = samples * tf.ones_like(event, dtype=samples.dtype)
  return event, samples
Ejemplo n.º 15
0
def _bdtr(k, n, p):
    """The binomial cumulative distribution function.

    Args:
      k: floating point `Tensor`.
      n: floating point `Tensor`.
      p: floating point `Tensor`.

    Returns:
      `sum_{j=0}^k C(n, j) p^j (1 - p)^(n - j)`, computed as the regularized
      incomplete beta function `betainc(n - k, k + 1, 1 - p)`; exactly 1
      where `k == n`.
    """
    # Trick for getting safe backprop/gradients into n, k when
    #   betainc(a = 0, ..) = nan
    # Write:
    #   where(unsafe, safe_output, betainc(where(unsafe, safe_input, input)))
    ones = tf.ones_like(n - k)
    k_eq_n = tf.equal(k, n)
    safe_dn = tf.where(k_eq_n, ones, n - k)
    dk = tf.math.betainc(a=safe_dn, b=k + 1, x=1 - p)
    return tf.where(k_eq_n, ones, dk)
Ejemplo n.º 16
0
    def _stddev(self):
        """Standard deviation of the mixture, shaped batch_shape + event_shape.

        Component means and stddevs are stacked along a trailing component
        axis and flattened to `[-1, num_components]`. The result of
        `distribution_util.mixture_stddev` is then reshaped back.
        """
        with tf.control_dependencies(self._assertions):
            means = [component.mean() for component in self.components]
            devs = [component.stddev() for component in self.components]
            weights = self._cat_probs(log_probs=False)
            num_components = len(self.components)

            stacked_means = tf.stack(means, axis=-1)
            stacked_devs = tf.stack(devs, axis=-1)
            expanded_weights = [self._expand_to_event_rank(w) for w in weights]
            # Broadcast the mixture weights to the stacked means' shape.
            broadcast_weights = (tf.stack(expanded_weights, axis=-1) *
                                 tf.ones_like(stacked_means))

            flat_shape = [-1, num_components]
            flat_dev = distribution_util.mixture_stddev(
                tf.reshape(broadcast_weights, flat_shape),
                tf.reshape(stacked_means, flat_shape),
                tf.reshape(stacked_devs, flat_shape))

            # Restore list(batch_shape) + list(event_shape).
            return tf.reshape(flat_dev, tf.shape(broadcast_weights)[:-1])
Ejemplo n.º 17
0
    def _cdf(self, x):
        """CDF for a PDF made of two line segments on [low, high] peaking at `peak`.

        Piecewise:
          0                                                  for x < low
          (x - low)^2 / ((high - low) * (peak - low))        for low <= x < peak
          1 - (high - x)^2 / ((high - low) * (high - peak))  for peak <= x < high
          1                                                  for x >= high
        """
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)

        interval_length = high - low
        # Due to the PDF being not smooth at the peak, we have to treat each side
        # somewhat differently. The PDF is two line segments, and thus we get
        # quadratics here for the CDF.
        # The two quadratics agree at x == peak, so the strict `<` loses nothing.
        # It matters when peak == low: then x == low would otherwise hit the
        # left formula's 0 / 0 = NaN, while the right formula correctly gives 0.
        result_inside_interval = tf.where(
            (x >= low) & (x < peak),
            # (x - low) ** 2 / ((high - low) * (peak - low))
            tf.math.squared_difference(x, low) / (interval_length *
                                                  (peak - low)),
            # 1 - (high - x) ** 2 / ((high - low) * (high - peak))
            1. - tf.math.squared_difference(high, x) / (interval_length *
                                                        (high - peak)))

        # We now add that the left tail is 0 and the right tail is 1.
        result_if_not_big = tf.where(x < low, tf.zeros_like(x),
                                     result_inside_interval)

        return tf.where(x >= high, tf.ones_like(x), result_if_not_big)
Ejemplo n.º 18
0
def _broadcast_cat_event_and_params(event, params, base_dtype):
    """Broadcasts the event or distribution parameters.

    Args:
      event: Integer or floating `Tensor` of category indices. Floating values
        are cast to int32.
      params: `Tensor` of per-category parameters; its last axis is dropped
        when comparing against `event.shape`.
      base_dtype: dtype quoted in the error message for unsupported `event`
        dtypes.

    Returns:
      `(event, params)` broadcast so that `params.shape[:-1]` matches
      `event.shape`.

    Raises:
      TypeError: If `event` is neither integer nor floating.
    """
    if dtype_util.is_floating(event.dtype):
        # When `validate_args=True` we've already ensured int/float casting
        # is closed.
        event = tf.cast(event, dtype=tf.int32)
    elif not dtype_util.is_integer(event.dtype):
        raise TypeError('`value` should have integer `dtype` or '
                        '`self.dtype` ({})'.format(base_dtype))

    static_shapes_match = (
        tensorshape_util.rank(params.shape) is not None
        and tensorshape_util.is_fully_defined(params.shape[:-1])
        and tensorshape_util.is_fully_defined(event.shape)
        and params.shape[:-1] == event.shape)
    if static_shapes_match:
        return event, params

    # Shapes are unknown or differ: broadcast both at run time.
    params = params * tf.ones_like(event[..., tf.newaxis],
                                   dtype=params.dtype)
    event = event * tf.ones(tf.shape(params)[:-1], dtype=event.dtype)
    if tensorshape_util.rank(params.shape) is not None:
        tensorshape_util.set_shape(event, params.shape[:-1])
    return event, params
  def _sample_n(self, n, seed=None):
    """Draws `n` samples by compounding gamma draws with a multinomial draw.

    Gamma variates with shape parameter `concentration` (broadcast against
    `total_count`) are drawn, their logs are used as unnormalized logits, and
    one multinomial sample of `total_count` trials is drawn from those logits.

    Args:
      n: Number of samples to draw.
      seed: Seed for a `SeedStream` salted with 'dirichlet_multinomial'.

    Returns:
      `Tensor` of shape `[n] + batch_shape + [k]`, where `k` is the first
      event dimension.
    """
    seed = SeedStream(seed, 'dirichlet_multinomial')

    concentration = tf.convert_to_tensor(self._concentration)
    total_count = tf.convert_to_tensor(self._total_count)

    # Trial counts for the multinomial sampler (tf.cast truncates to int32).
    n_draws = tf.cast(total_count, dtype=tf.int32)
    k = self._event_shape_tensor(concentration)[0]
    # Broadcast concentration against total_count's batch shape.
    alpha = tf.math.multiply(
        tf.ones_like(total_count[..., tf.newaxis]),
        concentration,
        name='alpha')

    # Log of gamma variates acts as (unnormalized) categorical logits.
    # NOTE(review): the stream is consumed in a fixed order (gamma first,
    # multinomial second); presumably reordering would change the samples
    # drawn for a given seed.
    unnormalized_logits = tf.math.log(
        tf.random.gamma(
            shape=[n],
            alpha=alpha,
            dtype=self.dtype,
            seed=seed()))
    # One multinomial sample per logits row; the leading singleton sample
    # axis is folded away by the reshape below.
    x = multinomial.draw_sample(
        1, k, unnormalized_logits, n_draws, self.dtype, seed())
    final_shape = tf.concat(
        [[n], self._batch_shape_tensor(concentration, total_count), [k]], 0)
    return tf.reshape(x, final_shape)
Ejemplo n.º 20
0
 def _mode(self):
     """Mode: `loc`, broadcast against the shape of `scale`."""
     ones = tf.ones_like(self.scale)
     return self.loc * ones
Ejemplo n.º 21
0
 def _entropy(self):
     """Entropy `log(2 * pi) + log(scale)`, broadcast against `loc`."""
     log_two_pi = np.log(2 * np.pi)
     entropy = log_two_pi + tf.math.log(self.scale)
     return entropy * tf.ones_like(self.loc)
Ejemplo n.º 22
0
 def _entropy(self):
   """Entropy `0.5 + 0.5 * log(2 * pi) + log(scale)`, broadcast against `loc`."""
   half_log_two_pi = 0.5 * np.log(2. * np.pi)
   log_norm = half_log_two_pi + tf.math.log(self.scale)
   return (0.5 + log_norm) * tf.ones_like(self.loc)
Ejemplo n.º 23
0
 def _stddev(self):
     """Standard deviation `scale * pi / sqrt(6)`, broadcast against `loc`."""
     broadcast_scale = self.scale * tf.ones_like(self.loc)
     return broadcast_scale * np.pi / np.sqrt(6)
Ejemplo n.º 24
0
 def _stddev(self):
   """Standard deviation: `scale`, broadcast against the shape of `loc`."""
   ones = tf.ones_like(self.loc)
   return self.scale * ones
Ejemplo n.º 25
0
def reduce_weighted_logsumexp(logx,
                              w=None,
                              axis=None,
                              keep_dims=False,
                              return_sign=False,
                              name=None):
  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to directly
  use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.math.log(w))` is
  more efficient than `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `logx` along the dimensions given in `axis`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `axis` is `None`, all dimensions are reduced, and a
  tensor with a single element is returned.

  This function is more numerically stable than log(sum(w * exp(input))). It
  avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])

  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. It is converted to the dtype of `logx`. If `None`,
      this is a plain `tf.reduce_logsumexp` (all weights 1).
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(logx), rank(logx))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor. It is `-inf`
      where the weighted sum is exactly zero.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
  with tf.name_scope(name or 'reduce_weighted_logsumexp'):
    logx = tf.convert_to_tensor(logx, name='logx')
    if w is None:
      lswe = tf.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
      if return_sign:
        # Unweighted sum of exponentials is always positive.
        sgn = tf.ones_like(lswe)
        return lswe, sgn
      return lswe
    w = tf.convert_to_tensor(w, dtype=logx.dtype, name='w')
    # Work with log|w| + logx; the sign of w is reapplied after exponentiating.
    log_absw_x = logx + tf.math.log(tf.abs(w))
    max_log_absw_x = tf.reduce_max(log_absw_x, axis=axis, keepdims=True)
    # If the largest element is `-inf` or `inf` then we don't bother subtracting
    # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That
    # this is ok follows from the fact that we're actually free to subtract any
    # value we like, so long as we add it back after taking the `log(sum(...))`.
    max_log_absw_x = tf.where(
        tf.math.is_inf(max_log_absw_x),
        tf.zeros([], max_log_absw_x.dtype),
        max_log_absw_x)
    wx_over_max_absw_x = (tf.sign(w) * tf.exp(log_absw_x - max_log_absw_x))
    sum_wx_over_max_absw_x = tf.reduce_sum(
        wx_over_max_absw_x, axis=axis, keepdims=keep_dims)
    if not keep_dims:
      # The max was kept with singleton dims; drop them to match the sum.
      max_log_absw_x = tf.squeeze(max_log_absw_x, axis)
    sgn = tf.sign(sum_wx_over_max_absw_x)
    # sgn * sum == |sum|, so the log argument is non-negative.
    lswe = max_log_absw_x + tf.math.log(sgn * sum_wx_over_max_absw_x)
    if return_sign:
      return lswe, sgn
    return lswe
Ejemplo n.º 26
0
def draw_sample(num_samples, num_classes, logits, num_trials, dtype, seed):
    """Sample a multinomial.

    The batch shape is given by broadcasting num_trials with
    remove_last_dimension(logits).

    Args:
      num_samples: Python int or singleton integer Tensor: number of multinomial
        samples to draw.
      num_classes: Python int or singleton integer Tensor: number of classes.
      logits: Floating Tensor with last dimension k, of (unnormalized) logit
        probabilities per class.
      num_trials: Tensor of number of categorical trials each multinomial
        consists of. num_trials[..., tf.newaxis] must broadcast with logits.
      dtype: dtype at which to emit samples.
      seed: Random seed. When not `None`, `tf.map_fn` runs with
        `parallel_iterations=1`.

    Returns:
      samples: Tensor of given dtype and shape [n] + batch_shape + [k].
    """
    with tf.name_scope('draw_sample'):
        # broadcast the num_trials and logits to same shape
        num_trials = tf.ones_like(logits[..., 0],
                                  dtype=num_trials.dtype) * num_trials
        logits = tf.ones_like(num_trials[..., tf.newaxis],
                              dtype=logits.dtype) * logits

        # flatten the total_count and logits
        # flat_logits has shape [B1B2...Bm, num_classes]
        flat_logits = tf.reshape(logits, [-1, num_classes])
        # Each batch member draws num_samples * num_trials categorical samples
        # in one call; they are split back into num_samples groups below.
        flat_num_trials = num_samples * tf.reshape(num_trials,
                                                   [-1])  # [B1B2...Bm]

        # Computes each logits and num_trials situation by map_fn.

        # Using just one batch tf.random.categorical call doesn't work because that
        # requires num_trials to be the same across all members of the batch of
        # logits.  This restriction makes sense for tf.random.categorical because
        # for it, num_trials is part of the returned shape.  However, the
        # multinomial sampler does not need that restriction, because it sums out
        # exactly that dimension.

        # One possibility would be to draw a batch categorical whose sample count is
        # max(num_trials) and mask out the excess ones.  However, if the elements of
        # num_trials vary widely, this can be wasteful of memory.

        # TODO(b/123763054, b/112152209): Revisit the possibility of writing this
        # with a batch categorical followed by batch unsorted_segment_sum, once both
        # of those work and are memory-efficient enough.
        def _sample_one_batch_member(args):
            """Draws num_samples multinomial counts for one batch member."""
            logits, num_cat_samples = args[0], args[1]  # [K], []
            # x has shape [1, num_cat_samples = num_samples * num_trials]
            x = tf.random.categorical(logits[tf.newaxis, ...],
                                      num_cat_samples,
                                      seed=seed)
            x = tf.reshape(x, shape=[num_samples,
                                     -1])  # [num_samples, num_trials]
            # Count occurrences of each class by summing one-hot vectors.
            x = tf.one_hot(
                x, depth=num_classes)  # [num_samples, num_trials, num_classes]
            x = tf.reduce_sum(x, axis=-2)  # [num_samples, num_classes]
            return tf.cast(x, dtype=dtype)

        if seed is not None:
            # Force parallel_iterations to 1 to ensure reproducibility
            # b/139210489
            x = tf.map_fn(
                _sample_one_batch_member,
                [flat_logits, flat_num_trials],
                dtype=dtype,  # [B1B2...Bm, num_samples, num_classes]
                parallel_iterations=1)
        else:
            # Invoke default parallel_iterations behavior
            x = tf.map_fn(_sample_one_batch_member,
                          [flat_logits, flat_num_trials],
                          dtype=dtype)  # [B1B2...Bm, num_samples, num_classes]

        # reshape the results to proper shape
        # Move the sample axis to the front: [num_samples, B1B2...Bm, num_classes].
        x = tf.transpose(a=x, perm=[1, 0, 2])
        final_shape = tf.concat(
            [[num_samples], tf.shape(num_trials), [num_classes]], axis=0)
        x = tf.reshape(x, final_shape)

        return x
Ejemplo n.º 27
0
 def _entropy(self):
     """Entropy `1 + log(scale) + euler_gamma`, with `scale` broadcast to `loc`."""
     # Broadcasting against loc yields the full batch-shaped scale.
     broadcast_scale = self.scale * tf.ones_like(self.loc)
     log_scale = tf.math.log(broadcast_scale)
     return 1. + log_scale + np.euler_gamma
Ejemplo n.º 28
0
 def _mode(self):
     """Mode: 1 everywhere, shaped like `power` and cast to the distribution dtype."""
     mode_dtype = self.dtype
     return tf.ones_like(self.power, dtype=mode_dtype)
    def _sample_n(self, n, seed=None):
        """Draws `n` sampled observation sequences.

        A hidden-state trajectory of `self._num_steps` steps is sampled first:
        an initial state, then repeated transition draws. One observation per
        step is then selected according to the hidden state.

        Args:
          n: Number of samples to draw.
          seed: Seed for a `SeedStream` salted with "HiddenMarkovModel". When
            not `None`, `tf.scan` runs with `parallel_iterations=1`.

        Returns:
          Observations with shape
          `[n] + batch_shape + [num_steps] + observation event_shape`.
        """
        with tf.control_dependencies(self._runtime_assertions):
            strm = SeedStream(seed, salt="HiddenMarkovModel")

            num_states = self._num_states

            batch_shape = self.batch_shape_tensor()
            batch_size = tf.reduce_prod(batch_shape)

            # The batch sizes of the underlying initial distributions and
            # transition distributions might not match the batch size of
            # the HMM distribution.
            # As a result we need to ask for more samples from the
            # underlying distributions and then reshape the results into
            # the correct batch size for the HMM.
            init_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._initial_distribution.batch_shape_tensor()))
            init_state = self._initial_distribution.sample(n * init_repeat,
                                                           seed=strm())
            init_state = tf.reshape(init_state, [n, batch_size])
            # init_state :: n batch_size

            # The transition distribution's last batch axis indexes the
            # current state, so it is excluded from the repeat computation.
            transition_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._transition_distribution.batch_shape_tensor()[:-1]))

            def generate_step(state, _):
                """Take a single step in Markov chain."""

                # Sample a candidate next state for every possible current
                # state, then pick the one matching `state`.
                gen = self._transition_distribution.sample(n *
                                                           transition_repeat,
                                                           seed=strm())
                # gen :: (n * transition_repeat) transition_batch

                new_states = tf.reshape(gen, [n, batch_size, num_states])

                # new_states :: n batch_size num_states

                # NOTE(review): assumes sampled states are int32, since they are
                # multiplied by an int32 one-hot below — confirm.
                old_states_one_hot = tf.one_hot(state,
                                                num_states,
                                                dtype=tf.int32)

                # old_states :: n batch_size num_states

                return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

            def _scan_multiple_steps():
                """Take multiple steps with tf.scan."""
                # Values are unused; the tensor only sets the number of scan
                # iterations (num_steps - 1 transitions after the initial state).
                dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
                if seed is not None:
                    # Force parallel_iterations to 1 to ensure reproducibility
                    # b/139210489
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state,
                                            parallel_iterations=1)
                else:
                    # Invoke default parallel_iterations behavior
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state)

                # TODO(b/115618503): add/use prepend_initializer to tf.scan
                return tf.concat([[init_state], hidden_states], axis=0)

            # With a single step there are no transitions; use only the
            # initial state.
            hidden_states = prefer_static.cond(
                self._num_steps > 1, _scan_multiple_steps,
                lambda: init_state[tf.newaxis, ...])

            hidden_one_hot = tf.one_hot(
                hidden_states,
                num_states,
                dtype=self._observation_distribution.dtype)
            # hidden_one_hot :: num_steps n batch_size num_states

            # The observation distribution batch size might not match
            # the required batch size so as with the initial and
            # transition distributions we generate more samples and
            # reshape.
            observation_repeat = (batch_size // tf.reduce_prod(
                self._observation_distribution.batch_shape_tensor()[:-1]))

            # An observation is drawn for every (step, state) pair; the
            # one-hot hidden states later select the relevant one.
            possible_observations = self._observation_distribution.sample(
                [self._num_steps, observation_repeat * n], seed=strm())

            inner_shape = self._observation_distribution.event_shape

            # possible_observations :: num_steps (observation_repeat * n)
            #                          observation_batch[:-1] num_states inner_shape

            possible_observations = tf.reshape(
                possible_observations,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           inner_shape],
                          axis=0))

            # possible_observations :: steps n batch_size num_states inner_shape

            # Append singleton axes so the one-hot mask broadcasts over the
            # observation event dimensions.
            hidden_one_hot = tf.reshape(
                hidden_one_hot,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           tf.ones_like(inner_shape)],
                          axis=0))

            # hidden_one_hot :: steps n batch_size num_states "inner_shape"

            # Sum out the state axis, keeping only the observation of the
            # active hidden state.
            observations = tf.reduce_sum(hidden_one_hot *
                                         possible_observations,
                                         axis=-1 - tf.size(inner_shape))

            # observations :: steps n batch_size inner_shape

            observations = distribution_util.move_dimension(
                observations, 0, 1 + tf.size(batch_shape))

            # returned :: n batch_shape steps inner_shape

            return observations
Ejemplo n.º 30
0
    def _sample_n(self, num_samples, seed=None, name=None):
        """Returns a Tensor of samples from an LKJ distribution.

        The sampler follows reference [1], but builds the Cholesky factor of
        each sample row by row instead of the correlation matrix itself (see
        the derivation in the loop body below).

        Args:
          num_samples: Python `int`. The number of samples to draw.
          seed: Python integer seed for RNG.
          name: Python `str` name prefixed to Ops created by this function.
            Falls back to `'sample_lkj'` when `None` (or empty).

        Returns:
          samples: A Tensor of correlation matrices with shape `[n, B, D, D]`,
            where `B` is the shape of the `concentration` parameter, and `D`
            is the `dimension`. When `self.input_output_cholesky` is truthy,
            the lower-triangular Cholesky factors of those matrices are
            returned instead.

        Raises:
          ValueError: If `dimension` is negative.
          TypeError: If `concentration` does not have a floating dtype.
        """
        if self.dimension < 0:
            raise ValueError(
                'Cannot sample negative-dimension correlation matrices.')
        # Notation below: B is the batch shape, i.e., tf.shape(concentration)
        seed = SeedStream(seed, 'sample_lkj')
        # The caller-supplied `name` takes precedence; the previous
        # `'sample_lkj' or name` always evaluated to 'sample_lkj'.
        with tf.name_scope(name or 'sample_lkj'):
            concentration = tf.convert_to_tensor(self.concentration)
            if not dtype_util.is_floating(concentration.dtype):
                raise TypeError(
                    'The concentration argument should have floating type, not '
                    '{}'.format(dtype_util.name(concentration.dtype)))

            concentration = _replicate(num_samples, concentration)
            concentration_shape = tf.shape(concentration)
            if self.dimension <= 1:
                # For any dimension <= 1, there is only one possible correlation
                # matrix. An all-ones tensor is also its own Cholesky factor, so
                # this serves both output modes.
                shape = tf.concat(
                    [concentration_shape, [self.dimension, self.dimension]],
                    axis=0)
                return tf.ones(shape=shape, dtype=concentration.dtype)
            beta_conc = concentration + (self.dimension - 2.) / 2.
            beta_dist = beta.Beta(concentration1=beta_conc,
                                  concentration0=beta_conc)

            # Note that the sampler below deviates from [1], by doing the sampling in
            # cholesky space. This does not change the fundamental logic of the
            # sampler, but does speed up the sampling.

            # This is the correlation coefficient between the first two dimensions.
            # This is also `r` in reference [1].
            corr12 = 2. * beta_dist.sample(seed=seed()) - 1.

            # Below we construct the Cholesky of the initial 2x2 correlation matrix,
            # which is of the form:
            # [[1, 0], [r, sqrt(1 - r**2)]], where r is the correlation between the
            # first two dimensions.
            # This is the top-left corner of the cholesky of the final sample.
            first_row = tf.concat([
                tf.ones_like(corr12)[..., tf.newaxis],
                tf.zeros_like(corr12)[..., tf.newaxis]
            ],
                                  axis=-1)
            second_row = tf.concat([
                corr12[..., tf.newaxis],
                tf.sqrt(1 - corr12**2)[..., tf.newaxis]
            ],
                                   axis=-1)

            chol_result = tf.concat([
                first_row[..., tf.newaxis, :], second_row[..., tf.newaxis, :]
            ],
                                    axis=-2)

            for n in range(2, self.dimension):
                # Loop invariant: on entry, result has shape B + [n, n]
                beta_conc = beta_conc - 0.5
                # norm is y in reference [1].
                norm = beta.Beta(concentration1=n / 2.,
                                 concentration0=beta_conc).sample(seed=seed())
                # distance shape: B + [1] for broadcast
                distance = tf.sqrt(norm)[..., tf.newaxis]
                # direction is u in reference [1].
                # direction shape: B + [n]
                direction = _uniform_unit_norm(n, concentration_shape,
                                               concentration.dtype, seed)
                # raw_correlation is w in reference [1].
                raw_correlation = distance * direction  # shape: B + [n]

                # This is the next row in the cholesky of the result,
                # which differs from the construction in reference [1].
                # In the reference, the new row `z` = chol_result @ raw_correlation^T
                # = C @ raw_correlation^T (where, as shorthand, C = chol_result).
                # We prove that the below equation is the right row to add to the
                # cholesky, by showing equality with reference [1].
                # Let S be the sample constructed so far, and let `z` be as in
                # reference [1]. Then at this iteration, the new sample S' will be
                # [[S z^T]
                #  [z 1]]
                # In our case we have the cholesky decomposition factor C, so
                # we want our new row x (same size as z) to satisfy:
                #  [[S z^T]  [[C 0]    [[C^T  x^T]         [[CC^T  Cx^T]
                #   [z 1]] =  [x k]]    [0     k]]  =       [xC^T   xx^T + k**2]]
                # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
                # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
                # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
                # distance**2).
                new_row = tf.concat(
                    [raw_correlation,
                     tf.sqrt(1. - norm[..., tf.newaxis])],
                    axis=-1)

                # Finally add this new row, by growing the cholesky of the result:
                # first append a zero column (keeping it lower-triangular), then
                # append the new row at the bottom.
                chol_result = tf.concat([
                    chol_result,
                    tf.zeros_like(chol_result[..., 0][..., tf.newaxis])
                ],
                                        axis=-1)

                chol_result = tf.concat(
                    [chol_result, new_row[..., tf.newaxis, :]], axis=-2)

            if self.input_output_cholesky:
                return chol_result

            result = tf.matmul(chol_result, chol_result, transpose_b=True)
            # The diagonal for a correlation matrix should always be ones. Due to
            # numerical instability the matmul might not achieve that, so manually set
            # these to ones.
            result = tf.linalg.set_diag(
                result, tf.ones(shape=tf.shape(result)[:-1],
                                dtype=result.dtype))
            # This sampling algorithm can produce near-PSD matrices on which standard
            # algorithms such as `tf.cholesky` or `tf.linalg.self_adjoint_eigvals`
            # fail. Specifically, as documented in b/116828694, around 2% of trials
            # of 900,000 5x5 matrices (distributed according to 9 different
            # concentration parameter values) contained at least one matrix on which
            # the Cholesky decomposition failed.
            return result