Exemplo n.º 1
0
def sparse_or_dense_matmul(sparse_or_dense_a,
                           dense_b,
                           validate_args=False,
                           name=None,
                           **kwargs):
    """Multiplies a sparse or dense (batch of) matrices by a dense `Tensor`.

    Args:
      sparse_or_dense_a: `SparseTensor` or `Tensor` holding a (batch of)
        matrices; this is the left operand.
      dense_b: `Tensor` holding a (batch of) matrices; this is the right
        operand. It is converted with `tf.convert_to_tensor` using a
        `tf.float32` dtype hint.
      validate_args: When `True`, rank-at-least-2 assertions on both operands
        are attached as control dependencies.
        Default value: `False` (i.e., no graph assertions are added).
      name: Python `str` prefixed to ops created by this function.
        Default value: 'sparse_or_dense_matmul'.
      **kwargs: Forwarded unchanged to `_sparse_tensor_dense_matmul` (sparse
        left operand) or `tf.matmul` (dense left operand).

    Returns:
      product: The result of the sparse-dense or dense-dense matmul.
    """
    with tf.name_scope(name or 'sparse_or_dense_matmul'):
        dense_b = tf.convert_to_tensor(dense_b,
                                       dtype_hint=tf.float32,
                                       name='dense_b')

        if validate_args:
            rank_checks = [
                assert_util.assert_rank_at_least(
                    sparse_or_dense_a,
                    rank=2,
                    message=
                    'Input `sparse_or_dense_a` must have at least 2 dimensions.'),
                assert_util.assert_rank_at_least(
                    dense_b,
                    rank=2,
                    message='Input `dense_b` must have at least 2 dimensions.'),
            ]
            # Re-bind both operands under the assertions so that downstream
            # ops cannot run before the rank checks.
            with tf.control_dependencies(rank_checks):
                sparse_or_dense_a = tf.identity(sparse_or_dense_a)
                dense_b = tf.identity(dense_b)

        sparse_types = (tf.SparseTensor, tf1.SparseTensorValue)
        if not isinstance(sparse_or_dense_a, sparse_types):
            return tf.matmul(sparse_or_dense_a, dense_b, **kwargs)
        return _sparse_tensor_dense_matmul(sparse_or_dense_a, dense_b,
                                           **kwargs)
 def _inverse(self, y):
     """Maps each value in `y` to the index of the closest `map_values` entry.

     `y` is flattened, the two neighbouring candidate indices are found with
     `tf.searchsorted`, the candidate whose forward value is nearer to `y` is
     picked, and the result is reshaped to `y.shape`. When `validate_args` is
     set, an assertion requires the smaller of the two distances to be near 0.
     """
     map_values = tf.convert_to_tensor(self.map_values)
     y_flat = tf.reshape(y, shape=[-1])
     # The search relies on `map_values` being sorted (the original author
     # notes it is strictly increasing): the nearest entry is then either the
     # first one strictly greater than the query, or the one just before it.
     last_index = tf.size(map_values) - 1
     insertion_points = tf.searchsorted(
         map_values, values=y_flat, side='right')
     upper_idx = tf.minimum(last_index, insertion_points)
     lower_idx = tf.maximum(0, upper_idx - 1)
     index_pairs = tf.stack([lower_idx, upper_idx], axis=-1)
     lower_err = tf.abs(y_flat - self._forward(lower_idx))
     upper_err = tf.abs(y_flat - self._forward(upper_idx))
     if self.validate_args:
         exact_match_check = assert_util.assert_near(
             tf.minimum(lower_err, upper_err),
             0,
             message='inverse value not found')
         with tf.control_dependencies([exact_match_check]):
             index_pairs = tf.identity(index_pairs)
     # Build (row, column) pairs: row = position in the flattened input,
     # column = 0 for the lower candidate or 1 for the upper candidate.
     rows = tf.range(tf.size(y_flat), dtype=tf.int32)
     nearest_column = tf.argmin([lower_err, upper_err],
                                output_type=tf.int32)
     selector = tf.stack([rows, nearest_column], axis=-1)
     chosen = tf.gather_nd(index_pairs, selector)
     return tf.reshape(chosen, shape=y.shape)
Exemplo n.º 3
0
def vector_size_to_square_matrix_size(d, validate_args, name=None):
    """Solves `d = n * (n + 1) / 2` for `n`.

    Args:
      d: Vector length. A Python/NumPy number is handled statically; anything
        else is handled with TensorFlow ops.
      validate_args: When `True` (tensor path only), embeds an assertion that
        `d` is a triangular number.
      name: Python `str` prefixed to ops created on the tensor path.
        Default value: 'vector_size_to_square_matrix_size'.

    Returns:
      A Python `int` on the static path, or a tensor cast to `d.dtype` on the
      tensor path.

    Raises:
      ValueError: On the static path, if `d` is not a triangular number.
    """
    if not isinstance(d, (float, int, np.generic, np.ndarray)):
        with tf.name_scope(name
                           or 'vector_size_to_square_matrix_size') as name:
            d_as_float = tf.cast(d, dtype=tf.float32)
            n = (-1. + tf.sqrt(1 + 8. * d_as_float)) / 2.
            if validate_args:
                # `n` is an integer exactly when truncating it changes nothing.
                truncated_n = tf.cast(tf.cast(n, dtype=tf.int32),
                                      dtype=tf.float32)
                is_triangular = assert_util.assert_equal(
                    truncated_n,
                    n,
                    data=['Vector length is not a triangular number: ', d],
                    message='Vector length is not a triangular number')
                with tf.control_dependencies([is_triangular]):
                    n = tf.identity(n)
            return tf.cast(n, d.dtype)

    # Static path: positive root of n^2 + n - 2d = 0.
    n = (-1 + np.sqrt(1 + 8 * d)) / 2.
    if float(int(n)) != n:
        raise ValueError(
            'Vector length {} is not a triangular number.'.format(d))
    return int(n)
Exemplo n.º 4
0
def matrix_rank(a, tol=None, validate_args=False, name=None):
    """Counts the singular values of `a` that exceed a tolerance.

    Arguments:
      a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is
        computed.
      tol: Threshold below which the singular value is counted as 'zero'.
        Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).
      validate_args: When `True`, additional assertions might be embedded in
        the graph.
        Default value: `False` (i.e., no graph assertions are added).
      name: Python `str` prefixed to ops created by this function.
        Default value: 'matrix_rank'.

    Returns:
      matrix_rank: (Batch of) `int32` scalars representing the number of
        singular values greater than `tol`.
    """
    with tf.name_scope(name or 'matrix_rank'):
        a = tf.convert_to_tensor(a, dtype_hint=tf.float32, name='a')
        checks = _maybe_validate_matrix(a, validate_args)
        if checks:
            with tf.control_dependencies(checks):
                a = tf.identity(a)
        singular_values = tf.linalg.svd(a, compute_uv=False)
        if tol is None:
            matrix_dims = a.shape[-2:]
            # Prefer the static matrix dimensions; fall back to the dynamic
            # shape only when they are not fully known.
            if tensorshape_util.is_fully_defined(matrix_dims):
                largest_dim = np.max(matrix_dims.as_list())
            else:
                largest_dim = tf.reduce_max(tf.shape(a)[-2:])
            machine_eps = np.finfo(dtype_util.as_numpy_dtype(a.dtype)).eps
            scaled_eps = machine_eps * tf.cast(largest_dim, a.dtype)
            tol = scaled_eps * tf.reduce_max(
                singular_values, axis=-1, keepdims=True)
        above_tol = tf.cast(singular_values > tol, tf.int32)
        return tf.reduce_sum(above_tol, axis=-1)
    def _std_var_helper(self, statistic, statistic_name, statistic_ndims,
                        df_factor_fn):
        """Scales `statistic` by a function of `df / (df - 2)` and masks it.

        Args:
          statistic: Tensor to rescale.
          statistic_name: Python `str` used (capitalized) in the assertion
            message.
          statistic_ndims: Number of trailing size-1 dimensions appended to
            `self.df` so it lines up with `statistic`.
          df_factor_fn: Callable applied to `df / (df - 2)` before it
            multiplies `statistic`.

        Returns:
          The rescaled statistic where `df > 2`, `inf` where `df <= 2`, and,
          when `allow_nan_stats` is set, `NaN` where `df <= 1`. Without
          `allow_nan_stats`, an assertion that `df > 1` guards the result
          instead.
        """
        expanded_shape = tf.concat([
            tf.shape(self.df),
            tf.ones([statistic_ndims], dtype=tf.int32)
        ], -1)
        df = tf.reshape(self.df, expanded_shape)
        # Substitute a harmless denominator where df <= 2 so that the division
        # below never produces a NaN that could leak into the gradient.
        safe_denom = tf.where(df > 2., df - 2., tf.ones_like(df))
        statistic = statistic * df_factor_fn(df / safe_denom)
        # When 1 < df <= 2, stddev/variance are infinite.
        infinity = dtype_util.as_numpy_dtype(self.dtype)(np.inf)
        result_where_defined = tf.where(df > 2., statistic, infinity)

        if not self.allow_nan_stats:
            df_above_one = assert_util.assert_less(
                tf.cast(1., self.dtype),
                df,
                message='{} not defined for components of df <= 1.'.
                format(statistic_name.capitalize()))
            with tf.control_dependencies([df_above_one]):
                return tf.identity(result_where_defined)

        return tf.where(df > 1., result_where_defined,
                        dtype_util.as_numpy_dtype(self.dtype)(np.nan))
Exemplo n.º 6
0
def _validate_block_sizes(block_sizes, bijectors, validate_args):
    """Checks that `block_sizes` is a vector with one entry per bijector.

    A fully defined static shape is checked immediately and raises
    `ValueError` on mismatch. Otherwise, runtime assertions are attached only
    when `validate_args` is true.

    Returns:
      `block_sizes`, possibly wrapped in `tf.identity` under the assertions.
    """
    static_shape = block_sizes.shape
    if tensorshape_util.is_fully_defined(static_shape):
        is_vector = tensorshape_util.rank(static_shape) == 1
        if not (is_vector and
                tensorshape_util.num_elements(static_shape) == len(bijectors)):
            raise ValueError(
                '`block_sizes` must be `None`, or a vector of the same length as '
                '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
                'length {}'.format(static_shape, len(bijectors)))
        return block_sizes

    if not validate_args:
        return block_sizes

    message = (
        '`block_sizes` must be `None`, or a vector of the same length '
        'as `bijectors`.')
    size_check = assert_util.assert_equal(tf.size(block_sizes),
                                          len(bijectors),
                                          message=message)
    rank_check = assert_util.assert_equal(tf.rank(block_sizes), 1)
    with tf.control_dependencies([size_check, rank_check]):
        return tf.identity(block_sizes)
Exemplo n.º 7
0
def assert_finite(x, data=None, summarize=None, message=None, name=None):
  """Assert all elements of `x` are finite.

  Args:
    x:  Numeric `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_finite".

  Returns:
    `x` itself when its value is statically known and every element is
    finite. Otherwise `tf.identity(x)`, created under a control dependency on
    a runtime assertion that every element of `x` is finite.

  Raises:
    ValueError:  If static checks determine `x` has a non-finite element.
  """
  with tf.name_scope(name or 'assert_finite'):
    x_ = tf.get_static_value(x)
    if x_ is not None:
      # Use logical `not`, not bitwise `~`: `~` only negates correctly for
      # NumPy booleans and would always be truthy on a plain Python `bool`.
      if not np.all(np.isfinite(x_)):
        raise ValueError(message)
      return x
    assertion = tf1.assert_equal(
        tf.math.is_finite(x), tf.ones_like(x, tf.bool),
        data=data, summarize=summarize, message=message)
    with tf.control_dependencies([assertion]):
      return tf.identity(x)
Exemplo n.º 8
0
  def _mode(self):
    """Returns `loc` (if any) plus the bijector's scale applied to zeros."""
    zeros = tf.zeros(self._mode_mean_shape(), self.dtype)
    scaled_zeros = self.bijector.scale.matvec(zeros)

    if self.loc is not None:
      return tf.identity(self.loc) + scaled_zeros
    return scaled_zeros
Exemplo n.º 9
0
 def _inverse_event_shape_tensor(self, output_shape_tensor):
     """Maps a `[..., n, n]` shape tensor to `[..., n * (n + 1) / 2]`.

     With `validate_args`, asserts that the last two entries of
     `output_shape_tensor` are equal.
     """
     batch_shape = output_shape_tensor[:-2]
     n = output_shape_tensor[-1]
     if self.validate_args:
         square_check = assert_util.assert_equal(
             n, output_shape_tensor[-2], message='Matrix must be square.')
         with tf.control_dependencies([square_check]):
             n = tf.identity(n)
     triangle_size = n * (n + 1) / 2
     d = tf.cast(triangle_size, output_shape_tensor.dtype)
     return tf.concat([batch_shape, [d]], axis=0)
Exemplo n.º 10
0
 def _inverse(self, y):
     """Undoes the affine map: subtract `shift`, then divide out the scales.

     Each of `shift`, `scale` and `log_scale` is skipped when it is `None`.
     """
     result = tf.identity(y)
     if self.shift is not None:
         result = result - self.shift
     if self.scale is not None:
         result = result / self.scale
     if self.log_scale is not None:
         # Dividing by exp(log_scale) is done as multiplying by exp(-log_scale).
         result = result * tf.exp(-self.log_scale)
     return result
Exemplo n.º 11
0
 def _forward(self, x):
     """Applies the affine map: multiply by the scales, then add `shift`.

     Each of `scale`, `log_scale` and `shift` is skipped when it is `None`.
     """
     result = tf.identity(x)
     if self.scale is not None:
         result = result * self.scale
     if self.log_scale is not None:
         result = result * tf.exp(self.log_scale)
     if self.shift is not None:
         result = result + self.shift
     return result
Exemplo n.º 12
0
    def _mean(self):
        """Returns `loc` broadcast to batch + event shape (zeros if no `loc`)."""
        full_shape = tensorshape_util.concatenate(self.batch_shape,
                                                  self.event_shape)
        shape_is_static = tensorshape_util.is_fully_defined(full_shape)
        if not shape_is_static:
            # Fall back to the dynamic shape when the static one is incomplete.
            full_shape = tf.concat([
                self.batch_shape_tensor(),
                self.event_shape_tensor(),
            ], 0)

        if self.loc is None:
            return tf.zeros(full_shape, self.dtype)

        loc_already_full = shape_is_static and full_shape == self.loc.shape
        if loc_already_full:
            return tf.identity(self.loc)

        # Adding zeros of the target shape broadcasts `loc`. It is only needed
        # when the shapes differ, but that cannot always be decided statically.
        return tf.identity(self.loc) + tf.zeros(full_shape, self.dtype)
Exemplo n.º 13
0
 def _inverse_event_shape_tensor(self, output_shape):
     """Returns `output_shape` with its last entry reduced by one.

     With `validate_args`, asserts that the last entry is greater than 1.
     """
     if self.validate_args:
         # A shape entry cannot be negative, so "> 1" is the only check needed.
         last_dim_check = assert_util.assert_greater(
             output_shape[-1],
             1,
             message="Need last dimension greater than 1.")
         with tf.control_dependencies([last_dim_check]):
             output_shape = tf.identity(output_shape)
     leading_dims = output_shape[:-1]
     shrunk_last_dim = output_shape[-1] - 1
     return tf.concat([leading_dims, [shrunk_last_dim]], axis=0)
Exemplo n.º 14
0
  def _mean(self):
    """Returns `loc` (if any) plus the bijector's scale applied to ones."""
    # Per the original author's derivation: with W = (w1, ..., wk) and each
    # wj ~ iid Exponential(0, 1), this distribution is X = loc + L W, hence
    # E[X] = loc + L 1, where 1 is the vector of ones.
    ones = tf.ones(self._mode_mean_shape(), self.dtype)
    scaled_ones = self.bijector.scale.matvec(ones)

    if self.loc is not None:
      return tf.identity(self.loc) + scaled_ones
    return scaled_ones
Exemplo n.º 15
0
 def _z(self, x, scale, concentration):
     """Returns `(x - loc) / scale`, optionally asserting `x` is in support.

     With `validate_args`, `x` must satisfy `x >= loc` and, where
     `concentration < 0`, also `x <= loc - scale / concentration`.
     """
     loc = tf.convert_to_tensor(self.loc)
     if self.validate_args:
         at_or_above_loc = x >= loc
         no_upper_bound = concentration >= 0
         within_upper_bound = x <= loc - scale / concentration
         in_support = at_or_above_loc & (no_upper_bound | within_upper_bound)
         support_check = assert_util.assert_equal(
             in_support,
             True,
             message='`x` outside distribution\'s support.')
         with tf.control_dependencies([support_check]):
             x = tf.identity(x)
     return (x - loc) / scale
 def _forward(self, x):
     """Looks up `map_values` at the indices `x`.

     With `validate_args`, asserts `0 <= x < tf.size(map_values)`.
     """
     map_values = tf.convert_to_tensor(self.map_values)
     if self.validate_args:
         in_bounds = (0 <= x) & (x < tf.size(map_values))
         bounds_check = assert_util.assert_equal(
             in_bounds, True, message='indices out of bound')
         with tf.control_dependencies([bounds_check]):
             x = tf.identity(x)
     # Should `self.map_values` ever carry batch dims, the original author
     # suggests (after broadcasting):
     # tf.gather(self.map_values, x, batch_dims=-1, axis=-1)
     return tf.gather(map_values, indices=x)
    def _mean(self):
        """Returns `loc` broadcast to the sample shape, guarded on `df > 1`.

        With `allow_nan_stats`, entries where `df <= 1` become `NaN`;
        otherwise an assertion that `df > 1` is attached to the result.
        """
        mean = tf.broadcast_to(self.loc, self._sample_shape())

        if not self.allow_nan_stats:
            df_above_one = assert_util.assert_less(
                tf.cast(1., self.dtype),
                self.df,
                message='Mean not defined for components of df <= 1.')
            with tf.control_dependencies([df_above_one]):
                return tf.identity(mean)

        mean_is_defined = self.df[..., tf.newaxis] > 1.
        return tf.where(mean_is_defined, mean,
                        dtype_util.as_numpy_dtype(self.dtype)(np.nan))
Exemplo n.º 18
0
 def _maybe_assert_valid_sample(self, samples):
   """Returns `samples`, guarded by norm and shape checks if validating.

   With `validate_args`, asserts that the norm of `samples` along the last
   axis is near 1 and that its innermost dimension equals the event shape.
   """
   if self.validate_args:
     unit_norm_check = assert_util.assert_near(
         1.,
         tf.linalg.norm(samples, axis=-1),
         message='samples must be unit length')
     event_dim_check = assert_util.assert_equal(
         tf.shape(samples)[-1:],
         self.event_shape_tensor(),
         message=('samples must have innermost dimension matching that of '
                  '`self.mean_direction`'))
     with tf.control_dependencies([unit_norm_check, event_dim_check]):
       return tf.identity(samples)
   return samples
Exemplo n.º 19
0
 def _prob(self, x):
     """Returns 1 where every component of `x` is within slack of `loc`, else 0.

     With `validate_args`, asserts that `x` has rank at least 1 and that its
     last dimension equals this distribution's event shape.
     """
     if self.validate_args:
         is_vector_check = assert_util.assert_rank_at_least(x, 1)
         event_shape = self.event_shape_tensor()
         last_dim_of_x = tf.gather(tf.shape(x), tf.rank(x) - 1)
         right_vec_space_check = assert_util.assert_equal(
             event_shape,
             last_dim_of_x,
             message=
             "Argument 'x' not defined in the same space R^k as this distribution"
         )
         with tf.control_dependencies([is_vector_check]), \
                 tf.control_dependencies([right_vec_space_check]):
             x = tf.identity(x)
     loc = tf.convert_to_tensor(self.loc)
     within_slack = tf.abs(x - loc) <= self._slack(loc)
     return tf.cast(tf.reduce_all(within_slack, axis=-1), dtype=self.dtype)
Exemplo n.º 20
0
    def _mean(self):
        """Returns `concentration * mixing_rate / (mixing_concentration - 1)`.

        With `allow_nan_stats`, entries where `mixing_concentration <= 1`
        become `NaN`; otherwise an assertion that `mixing_concentration > 1`
        is attached to the result.
        """
        concentration = tf.convert_to_tensor(self.concentration)
        mixing_concentration = tf.convert_to_tensor(self.mixing_concentration)
        mixing_rate = tf.convert_to_tensor(self.mixing_rate)

        mean = concentration * mixing_rate / (mixing_concentration - 1.)
        if not self.allow_nan_stats:
            mean_exists = assert_util.assert_less(
                tf.ones([], self.dtype),
                mixing_concentration,
                message='mean undefined when `mixing_concentration` <= 1')
            with tf.control_dependencies([mean_exists]):
                return tf.identity(mean)

        return tf.where(mixing_concentration > 1., mean,
                        dtype_util.as_numpy_dtype(self.dtype)(np.nan))
Exemplo n.º 21
0
def _prevent_2nd_derivative(x):
    """Pairs an identity of `x` with a gradient that cannot be differentiated.

    NB (from the original author): a non-identity function must be applied to
    the output tensor for the exception to be raised.

    NOTE(review): the `(value, grad_fn)` return shape looks like the contract
    of a custom-gradient decorator, but no decorator is visible here; confirm
    against the full file.

    Arguments:
      x: A tensor.

    Returns:
      A pair `(tf.identity(x), grad)`. `grad` passes the incoming gradient
      through `array_ops.prevent_gradient` with the message
      "Second derivative is not implemented.".
    """
    def grad(dy):
        guarded_dy = array_ops.prevent_gradient(
            dy, message="Second derivative is not implemented.")
        return guarded_dy

    value = tf.identity(x)
    return value, grad
def maybe_check_wont_broadcast(flat_xs, validate_args):
  """Verifies that consecutive entries of `flat_xs` share the same shape.

  Args:
    flat_xs: Iterable of values (generators are accepted; it is materialized
      into a tuple).
    validate_args: When falsy, no check is performed at all.

  Returns:
    A tuple of the inputs. When shapes are only known at runtime, each entry
    is wrapped in `tf.identity` under the shape-equality assertions.

  Raises:
    ValueError: If `validate_args` is set and statically known shapes differ.
  """
  flat_xs = tuple(flat_xs)  # Materialize so generators can be re-iterated.
  if not validate_args:
    # No static validation either: a caller may rely on broadcasting on
    # purpose, so the check is strictly opt-in via `validate_args`.
    return flat_xs

  msg = 'Broadcasting probably indicates an error in model specification.'
  shapes = tuple(prefer_static.shape(x) for x in flat_xs)
  neighbours = list(zip(shapes[1:], shapes[:-1]))

  if all(prefer_static.is_numpy(shape) for shape in shapes):
    for later, earlier in neighbours:
      if not np.all(later == earlier):
        raise ValueError(msg)
    return flat_xs

  assertions = [assert_util.assert_equal(later, earlier, message=msg)
                for later, earlier in neighbours]
  with tf.control_dependencies(assertions):
    return tuple(tf.identity(x) for x in flat_xs)
Exemplo n.º 23
0
 def _mode(self):
   """Returns `(concentration - 1) / (sum(concentration) - k)`.

   `k` is the size of the last axis of `concentration`. With
   `allow_nan_stats`, rows where any concentration is `<= 1` become `NaN`;
   otherwise an assertion that every concentration exceeds 1 is attached.
   """
   concentration = tf.convert_to_tensor(self.concentration)
   num_components = tf.cast(tf.shape(concentration)[-1], self.dtype)
   concentration_sum = tf.reduce_sum(concentration, axis=-1)
   denominator = concentration_sum[..., tf.newaxis] - num_components
   mode = (concentration - 1.) / denominator
   if not self.allow_nan_stats:
     all_above_one_check = assert_util.assert_less(
         tf.ones([], self.dtype),
         concentration,
         message='Mode undefined when any concentration <= 1')
     with tf.control_dependencies([all_above_one_check]):
       return tf.identity(mode)
   mode_is_defined = tf.reduce_all(
       concentration > 1., axis=-1, keepdims=True)
   return tf.where(
       mode_is_defined,
       mode,
       dtype_util.as_numpy_dtype(self.dtype)(np.nan))
Exemplo n.º 24
0
    def _flat_sample_distributions(self,
                                   sample_shape=(),
                                   seed=None,
                                   value=None):
        """Executes `model`, creating both samples and distributions.

        Drives the coroutine returned by `self._model()`: each yielded item is
        recorded as a distribution, a value for it is either taken from
        `value` or sampled, and that value is sent back into the coroutine.

        Args:
          sample_shape: Shape passed to `sample` for items wrapped in
            `self.Root`; other items are sampled with an empty shape.
          seed: Seed used to construct the `SeedStream`.
          value: Optional indexable of per-distribution values; an entry that
            is present and not `None` is used instead of sampling.

        Returns:
          A pair `(ds, values_out)`: the list of distributions and the list of
          corresponding values, in yield order.

        Raises:
          ValueError: If the first item yielded by the coroutine is not a
            `self.Root` instance.
        """
        ds = []
        values_out = []
        seed = SeedStream('JointDistributionCoroutine', seed)
        gen = self._model()
        index = 0
        d = next(gen)
        if not isinstance(d, self.Root):
            raise ValueError('First distribution yielded by coroutine must '
                             'be wrapped in `Root`.')
        try:
            while True:
                # Unwrap `Root` items to reach the underlying distribution.
                actual_distribution = d.distribution if isinstance(
                    d, self.Root) else d
                ds.append(actual_distribution)
                if (value is not None and len(value) > index
                        and value[index] is not None):
                    # The seed stream is advanced even though nothing is
                    # sampled; presumably so that later distributions receive
                    # the same seeds whether or not this value was supplied.
                    seed()
                    next_value = value[index]
                else:
                    # Only `Root` items receive `sample_shape`.
                    next_value = actual_distribution.sample(
                        sample_shape=sample_shape
                        if isinstance(d, self.Root) else (),
                        seed=seed())

                if self._validate_args:
                    with tf.control_dependencies(
                            self._assert_compatible_shape(
                                index, sample_shape, next_value)):
                        values_out.append(tf.identity(next_value))
                else:
                    values_out.append(next_value)

                index += 1
                # Note: the raw `next_value` (not the identity-wrapped copy)
                # is what the coroutine receives.
                d = gen.send(next_value)
        except StopIteration:
            # The coroutine finishing is the normal loop exit.
            pass
        return ds, values_out
Exemplo n.º 25
0
    def _variance(self):
        """Returns the variance, guarded on `mixing_concentration > 2`.

        The value is
        `(concentration * mixing_rate / (mixing_concentration - 1))**2 /
        (mixing_concentration - 2)`. With `allow_nan_stats`, entries where
        `mixing_concentration <= 2` become `NaN`; otherwise an assertion that
        `mixing_concentration > 2` is attached to the result.
        """
        concentration = tf.convert_to_tensor(self.concentration)
        mixing_concentration = tf.convert_to_tensor(self.mixing_concentration)
        mixing_rate = tf.convert_to_tensor(self.mixing_rate)

        mean_like_term = (concentration * mixing_rate /
                          (mixing_concentration - 1.))
        variance = tf.square(mean_like_term) / (mixing_concentration - 2.)
        if not self.allow_nan_stats:
            variance_exists = assert_util.assert_less(
                tf.ones([], self.dtype) * 2.,
                mixing_concentration,
                message='variance undefined when `mixing_concentration` <= 2')
            with tf.control_dependencies([variance_exists]):
                return tf.identity(variance)

        return tf.where(mixing_concentration > 2., variance,
                        dtype_util.as_numpy_dtype(self.dtype)(np.nan))
Exemplo n.º 26
0
 def _validate_correlationness(self, x):
     """Returns `x`, guarded by correlation-matrix assertions if validating.

     No checks are added when `validate_args` is false or when
     `input_output_cholesky` is set. Otherwise the assertions require entries
     in [-1, 1], a diagonal near 1, and `x` near its own transpose.
     """
     if not self.validate_args or self.input_output_cholesky:
         return x
     np_dtype = dtype_util.as_numpy_dtype(x.dtype)
     lower_bound_check = assert_util.assert_less_equal(
         np_dtype(-1),
         x,
         message='Correlations must be >= -1.')
     upper_bound_check = assert_util.assert_less_equal(
         x,
         dtype_util.as_numpy_dtype(x.dtype)(1),
         message='Correlations must be <= 1.')
     unit_diagonal_check = assert_util.assert_near(
         tf.linalg.diag_part(x),
         dtype_util.as_numpy_dtype(x.dtype)(1),
         message='Self-correlations must be = 1.')
     symmetry_check = assert_util.assert_near(
         x,
         tf.linalg.matrix_transpose(x),
         message='Correlation matrices must be symmetric')
     checks = [
         lower_bound_check,
         upper_bound_check,
         unit_diagonal_check,
         symmetry_check,
     ]
     with tf.control_dependencies(checks):
         return tf.identity(x)
Exemplo n.º 27
0
 def _validate_dimension(self, x):
     """Checks that the last two dims of `x` both equal `self.dimension`.

     Statically known trailing dims are checked immediately. Otherwise,
     runtime assertions are attached only when `validate_args` is set.

     Returns:
       `x` converted to a tensor (identity-wrapped if assertions were added).

     Raises:
       ValueError: If the statically known trailing dims do not match.
     """
     x = tf.convert_to_tensor(x, name='x')
     if tensorshape_util.is_fully_defined(x.shape[-2:]):
         rows = tensorshape_util.dims(x.shape)[-2]
         cols = tensorshape_util.dims(x.shape)[-1]
         if not (rows == cols == self.dimension):
             raise ValueError(
                 'Input dimension mismatch: expected [..., {}, {}], got {}'.
                 format(self.dimension, self.dimension,
                        tensorshape_util.dims(x.shape)))
         return x

     if self.validate_args:
         msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
             self.dimension, self.dimension, tf.shape(x))
         rows_check = assert_util.assert_equal(tf.shape(x)[-2],
                                               self.dimension,
                                               message=msg)
         cols_check = assert_util.assert_equal(tf.shape(x)[-1],
                                               self.dimension,
                                               message=msg)
         with tf.control_dependencies([rows_check, cols_check]):
             x = tf.identity(x)
     return x
Exemplo n.º 28
0
    def _prob(self, x):
        """Evaluates the piecewise-linear density at `x`.

        Returns 0 where `x` is outside `[low, high]`. Inside, it uses the
        rising segment for `low <= x <= peak` and the falling segment
        otherwise. With `validate_args`, asserts `low <= x <= high`.
        """
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)

        if self.validate_args:
            support_checks = [
                assert_util.assert_greater_equal(x, low),
                assert_util.assert_less_equal(x, high)
            ]
            with tf.control_dependencies(support_checks):
                x = tf.identity(x)

        interval_length = high - low
        # Density for low <= x <= high. It is triangle-shaped, so each of the
        # two line segments is computed separately.
        on_rising_segment = (x >= low) & (x <= peak)
        # Line segment from (low, 0) to (peak, 2 / (high - low)).
        rising_value = 2. * (x - low) / (interval_length * (peak - low))
        # Line segment from (peak, 2 / (high - low)) to (high, 0).
        falling_value = 2. * (high - x) / (interval_length * (high - peak))
        density_inside = tf.where(on_rising_segment, rising_value,
                                  falling_value)

        outside_support = (x < low) | (x > high)
        return tf.where(outside_support, tf.zeros_like(x), density_inside)
Exemplo n.º 29
0
def kl_divergence(distribution_a,
                  distribution_b,
                  allow_nan_stats=True,
                  name=None):
    """Computes KL(distribution_a || distribution_b) via the KL registry.

    The implementation is looked up with `_registered_kl` from the types of
    the two distributions. As described by the original author, the lookup
    falls back to the class hierarchies of both types when no exact match is
    registered, preferring the registration with the shortest summed MRO
    distance and, on ties, the first one found (favoring a shorter MRO
    distance to `type(distribution_a)`).

    Args:
      distribution_a: The first distribution.
      distribution_b: The second distribution.
      allow_nan_stats: Python `bool`, default `True`. When `True`, the KL
        tensor is returned as computed. When `False`, an assertion that the
        result contains no `NaN` is attached to the returned tensor.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: "KullbackLeibler".

    Returns:
      A Tensor with the batchwise KL-divergence between `distribution_a`
      and `distribution_b`.

    Raises:
      NotImplementedError: If no KL method is defined for distribution types
        of `distribution_a` and `distribution_b`.
    """
    type_a = type(distribution_a)
    type_b = type(distribution_b)
    kl_fn = _registered_kl(type_a, type_b)
    if kl_fn is None:
        raise NotImplementedError(
            "No KL(distribution_a || distribution_b) registered for distribution_a "
            "type {} and distribution_b type {}".format(
                type_a.__name__, type_b.__name__))

    name = name or "KullbackLeibler"
    with tf.name_scope(name):
        # pylint: disable=protected-access
        with distribution_a._name_and_control_scope(name + "_a"), \
                distribution_b._name_and_control_scope(name + "_b"):
            kl_t = kl_fn(distribution_a, distribution_b, name=name)
            if allow_nan_stats:
                return kl_t

        # NaNs are not allowed: guard the result with a runtime assertion.
        kl_t = tf.identity(kl_t, name="kl")
        has_no_nan = tf.logical_not(tf.reduce_any(tf.math.is_nan(kl_t)))
        nan_report = (
            "KL calculation between {} and {} returned NaN values "
            "(and was called with allow_nan_stats=False). Values:".format(
                distribution_a.name, distribution_b.name))
        nan_check = tf.Assert(has_no_nan, [nan_report, kl_t])

        with tf.control_dependencies([nan_check]):
            return tf.identity(kl_t, name="checked_kl")
Exemplo n.º 30
0
    def _parameter_control_dependencies(self, is_init):
        """Builds validation checks for the `logits` / `probs` parameter.

        Args:
          is_init: Python `bool`. When true, dtype, rank and final-dimension
            checks are performed: statically where the shape is known, as
            graph assertions otherwise (and only if `validate_args` is set).

        Returns:
          A list of assertion ops. It is always empty when `validate_args` is
          false.

        Raises:
          TypeError: If `is_init` and the parameter is not floating point.
          ValueError: If `is_init` and the static shape has rank < 1, or a
            final dimension that is < 1 or > `tf.int32.max`.
        """
        assertions = []

        logits = self._logits
        probs = self._probs
        # Validate whichever parameterization is in use; `logits` is used
        # unless it is `None`.
        param, name = (probs, 'probs') if logits is None else (logits,
                                                               'logits')

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init:
            if not dtype_util.is_floating(param.dtype):
                raise TypeError(
                    'Argument `{}` must having floating type.'.format(name))

            msg = 'Argument `{}` must have rank at least 1.'.format(name)
            shape_static = tensorshape_util.dims(param.shape)
            if shape_static is not None:
                # Rank is known statically: fail immediately.
                if len(shape_static) < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                # Rank unknown: defer the check to a runtime assertion and
                # re-bind `param` under it.
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_rank_at_least(param, 1, message=msg))
                with tf.control_dependencies(assertions):
                    param = tf.identity(param)

            msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
            msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
                name, tf.int32.max)
            event_size = shape_static[-1] if shape_static is not None else None
            if event_size is not None:
                if event_size < 1:
                    raise ValueError(msg1)
                if event_size > tf.int32.max:
                    raise ValueError(msg2)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_greater_equal(tf.shape(param)[-1],
                                                     1,
                                                     message=msg1))
                # NOTE: For now, we leave out a runtime assertion that
                # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
                # will fail before we get to this point.

        if not self.validate_args:
            # Every append above is gated on `validate_args`, so nothing can
            # have been collected on this path.
            assert not assertions  # Should never happen.
            return []

        if probs is not None:
            probs = param  # reuse tensor conversion from above
            # Value checks are added when exactly one of `is_init` and
            # `tensor_util.is_ref(probs)` is true. NOTE(review): presumably
            # this validates plain tensors once at init and ref-like
            # (variable-backed) values on later calls -- confirm.
            if is_init != tensor_util.is_ref(probs):
                probs = tf.convert_to_tensor(probs)
                one = tf.ones([], dtype=probs.dtype)
                assertions.extend([
                    assert_util.assert_non_negative(probs),
                    assert_util.assert_less_equal(probs, one),
                    assert_util.assert_near(
                        tf.reduce_sum(probs, axis=-1),
                        one,
                        message='Argument `probs` must sum to 1.'),
                ])

        return assertions