Example #1
    def _log_prob(self, x):
        logits = self._logits_parameter_no_checks()
        event_size = self._event_size(logits)

        x = tf.cast(x, logits.dtype)
        x = self._maybe_assert_valid_sample(x, dtype=logits.dtype)

        # broadcast logits or x if need be.
        if (not tensorshape_util.is_fully_defined(x.shape)
                or not tensorshape_util.is_fully_defined(logits.shape)
                or x.shape != logits.shape):
            broadcast_shape = tf.broadcast_dynamic_shape(
                tf.shape(logits), tf.shape(x))
            logits = tf.broadcast_to(logits, broadcast_shape)
            x = tf.broadcast_to(x, broadcast_shape)

        logits_shape = tf.shape(tf.reduce_sum(logits, axis=-1))
        logits_2d = tf.reshape(logits, [-1, event_size])
        x_2d = tf.reshape(x, [-1, event_size])
        ret = -tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(x_2d), logits=logits_2d)

        # Reshape back to user-supplied batch and sample dims prior to 2D reshape.
        ret = tf.reshape(ret, logits_shape)
        return ret
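A minimal check of the identity used above (a sketch with made-up values, not part of the original snippet): for a one-hot `x`, the negative softmax cross entropy equals the log-softmax entry of the hot class.

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])
x = tf.constant([[0.0, 1.0, 0.0]])  # one-hot sample

# log_prob via the cross-entropy identity used in `_log_prob` above.
log_prob = -tf.nn.softmax_cross_entropy_with_logits(labels=x, logits=logits)

# Equivalent direct computation: pick the log-softmax entry of the hot class.
direct = tf.reduce_sum(x * tf.nn.log_softmax(logits), axis=-1)

print(log_prob.numpy(), direct.numpy())  # both ~ [-1.7415]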
 def _batch_shape_tensor(self):
     shape_list = [
         self.scale.batch_shape_tensor(),
         tf.shape(self.df),
         tf.shape(self.loc)[:-1]
     ]
     return functools.reduce(tf.broadcast_dynamic_shape, shape_list)
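A small standalone sketch (with made-up shapes) of the `functools.reduce(tf.broadcast_dynamic_shape, ...)` pattern used above to combine the batch shapes of several parameters:

import functools
import tensorflow as tf

shape_list = [
    tf.constant([3, 1]),   # e.g. batch shape of scale
    tf.constant([5]),      # e.g. shape of df
    tf.constant([3, 5]),   # e.g. shape of loc without the event dim
]
batch_shape = functools.reduce(tf.broadcast_dynamic_shape, shape_list)
print(batch_shape.numpy())  # [3 5]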
 def _assertions(self, t):
    if not self.validate_args:
     return []
   is_matrix = assert_util.assert_rank_at_least(t, 2)
   is_square = assert_util.assert_equal(tf.shape(t)[-2], tf.shape(t)[-1])
   is_positive_definite = assert_util.assert_positive(
       tf.linalg.diag_part(t), message="Input must be positive definite.")
   return [is_matrix, is_square, is_positive_definite]
Example #4
 def _batch_shape_tensor(self, concentration=None, total_count=None):
     if concentration is None:
         concentration = tf.convert_to_tensor(self._concentration)
     if total_count is None:
         total_count = tf.convert_to_tensor(self._total_count)
     return tf.broadcast_dynamic_shape(
         tf.shape(total_count[..., tf.newaxis]),
         tf.shape(concentration))[:-1]
 def validate_equal_last_dim(tensor_a, tensor_b, message):
   event_size_a = tf.compat.dimension_value(tensor_a.shape[-1])
   event_size_b = tf.compat.dimension_value(tensor_b.shape[-1])
   if event_size_a is not None and event_size_b is not None:
     if event_size_a != event_size_b:
       raise ValueError(message)
   elif validate_args:
     return assert_util.assert_equal(
         tf.shape(tensor_a)[-1], tf.shape(tensor_b)[-1], message=message)
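The helper above follows a common TFP pattern: compare the last dimensions statically when both are known, otherwise (if `validate_args`, which here is captured from the enclosing function's scope) return a graph assertion. A rough standalone version of the same idea, using `tf.debugging.assert_equal` instead of `assert_util` and taking `validate_args` explicitly (names here are illustrative, not from the original):

import tensorflow as tf

def check_equal_last_dim(tensor_a, tensor_b, message, validate_args=False):
    """Static check if possible, else an optional runtime assertion."""
    size_a = tensor_a.shape[-1]
    size_b = tensor_b.shape[-1]
    if size_a is not None and size_b is not None:
        if size_a != size_b:
            raise ValueError(message)
        return None  # statically verified; nothing to add to the graph
    if validate_args:
        return tf.debugging.assert_equal(
            tf.shape(tensor_a)[-1], tf.shape(tensor_b)[-1], message=message)
    return None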
Example #6
def matrix_rank(a, tol=None, validate_args=False, name=None):
    """Compute the matrix rank; the number of non-zero SVD singular values.

  Arguments:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be
      pseudo-inverted.
    tol: Threshold below which the singular value is counted as 'zero'.
      Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'matrix_rank'.

  Returns:
    matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
      singular values.
  """
    with tf.name_scope(name or 'matrix_rank'):
        a = tf.convert_to_tensor(a, dtype_hint=tf.float32, name='a')
        assertions = _maybe_validate_matrix(a, validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                a = tf.identity(a)
        s = tf.linalg.svd(a, compute_uv=False)
        if tol is None:
            if tensorshape_util.is_fully_defined(a.shape[-2:]):
                m = np.max(a.shape[-2:].as_list())
            else:
                m = tf.reduce_max(tf.shape(a)[-2:])
            eps = np.finfo(dtype_util.as_numpy_dtype(a.dtype)).eps
            tol = (eps * tf.cast(m, a.dtype) *
                   tf.reduce_max(s, axis=-1, keepdims=True))
        return tf.reduce_sum(tf.cast(s > tol, tf.int32), axis=-1)
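A quick usage sketch of the tolerance-based rank count (made-up input), mirroring what `matrix_rank` computes for a rank-deficient matrix:

import numpy as np
import tensorflow as tf

a = tf.constant([[1., 2.],
                 [2., 4.]])  # second row is 2 * first row, so rank 1
s = tf.linalg.svd(a, compute_uv=False)
eps = np.finfo(np.float32).eps
tol = eps * 2 * tf.reduce_max(s)  # eps * max(rows, cols) * max singular value
rank = tf.reduce_sum(tf.cast(s > tol, tf.int32))
print(rank.numpy())  # 1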
Example #7
def lu_reconstruct_assertions(lower_upper, perm, validate_args):
    """Returns list of assertions related to `lu_reconstruct` assumptions."""
    assertions = []

    message = 'Input `lower_upper` must have at least 2 dimensions.'
    if tensorshape_util.rank(lower_upper.shape) is not None:
        if tensorshape_util.rank(lower_upper.shape) < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank_at_least(lower_upper,
                                             rank=2,
                                             message=message))

    message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
    if (tensorshape_util.rank(lower_upper.shape) is not None
            and tensorshape_util.rank(perm.shape) is not None):
        if (tensorshape_util.rank(lower_upper.shape) !=
                tensorshape_util.rank(perm.shape) + 1):
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank(lower_upper,
                                    rank=tf.rank(perm) + 1,
                                    message=message))

    message = '`lower_upper` must be square.'
    if tensorshape_util.is_fully_defined(lower_upper.shape[-2:]):
        if lower_upper.shape[-2] != lower_upper.shape[-1]:
            raise ValueError(message)
    elif validate_args:
        m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2)
        assertions.append(assert_util.assert_equal(m, n, message=message))

    return assertions
Example #8
def _get_shape(x, out_type=tf.int32):
    # Return the shape of a Tensor or a SparseTensor as an np.array if its shape
    # is known statically. Otherwise return a Tensor representing the shape.
    if tensorshape_util.is_fully_defined(x.shape):
        return np.array(tensorshape_util.as_list(x.shape),
                        dtype=dtype_util.as_numpy_dtype(out_type))
    return tf.shape(x, out_type=out_type)
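A brief illustration (made-up tensors) of the two branches: a fully defined shape comes back as a static NumPy array, while a partially unknown shape (e.g. inside a `tf.function` with an unspecified dimension) falls through to `tf.shape`:

import numpy as np
import tensorflow as tf

x = tf.zeros([2, 3])
print(np.array(x.shape.as_list(), dtype=np.int32))  # static branch: [2 3]

@tf.function(input_signature=[tf.TensorSpec([None, 3])])
def dynamic_shape(t):
    return tf.shape(t, out_type=tf.int32)  # dynamic fallback

print(dynamic_shape(tf.zeros([4, 3])).numpy())  # [4 3]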
  def _make_columnar(self, x):
    """Ensures non-scalar input has at least one column.

    Example:
      If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`.

      If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.

      If `x = 1` then the output is unchanged.

    Args:
      x: `Tensor`.

    Returns:
      columnar_x: `Tensor` with at least two dimensions.
    """
    if tensorshape_util.rank(x.shape) is not None:
      if tensorshape_util.rank(x.shape) == 1:
        x = x[tf.newaxis, :]
      return x
    shape = tf.shape(x)
    maybe_expanded_shape = tf.concat([
        shape[:-1],
        distribution_util.pick_vector(
            tf.equal(tf.rank(x), 1), [1], np.array([], dtype=np.int32)),
        shape[-1:],
    ], 0)
    return tf.reshape(x, maybe_expanded_shape)
    def _call_sample_n(self, sample_shape, seed, name, **kwargs):
        # We override `_call_sample_n` rather than `_sample_n` so we can ensure that
        # the result of `self.bijector.forward` is not modified (and thus caching
        # works).
        with self._name_and_control_scope(name):
            sample_shape = tf.convert_to_tensor(sample_shape,
                                                dtype=tf.int32,
                                                name="sample_shape")
            sample_shape, n = self._expand_sample_shape_to_vector(
                sample_shape, "sample_shape")

            distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(
                kwargs)

            # First, generate samples. We will possibly generate extra samples in the
            # event that we need to reinterpret the samples as part of the
            # event_shape.
            x = self._sample_n(n, seed, **distribution_kwargs)

            # Next, we reshape `x` into its final form. We do this prior to the call
            # to the bijector to ensure that the bijector caching works.
            batch_event_shape = tf.shape(x)[1:]
            final_shape = tf.concat([sample_shape, batch_event_shape], 0)
            x = tf.reshape(x, final_shape)

            # Finally, we apply the bijector's forward transformation. For caching to
            # work, it is imperative that this is the last modification to the
            # returned result.
            y = self.bijector.forward(x, **bijector_kwargs)
            y = self._set_sample_static_shape(y, sample_shape)

            return y
Example #11
def _broadcast_event_and_samples(event, samples, event_ndims):
    """Broadcasts the event or samples."""
    # This is the shape of self.samples, without the samples axis, i.e. the shape
    # of the result of a call to dist.sample(). This way we can broadcast it with
    # event to get a properly-sized event, then add the singleton dim back at
    # -event_ndims - 1.
    samples_shape = tf.concat([
        tf.shape(samples)[:-event_ndims - 1],
        tf.shape(samples)[tf.rank(samples) - event_ndims:]
    ], axis=0)
    event = event * tf.ones(samples_shape, dtype=event.dtype)
    event = tf.expand_dims(event, axis=-event_ndims - 1)
    samples = samples * tf.ones_like(event, dtype=samples.dtype)

    return event, samples
 def _sample_n(self, n, seed=None):
     n_draws = tf.cast(self.total_count, dtype=tf.int32)
     logits = self._logits_parameter_no_checks()
     k = tf.compat.dimension_value(logits.shape[-1])
     if k is None:
         k = tf.shape(logits)[-1]
     return draw_sample(n, k, logits, n_draws, self.dtype, seed)
Example #13
 def _log_prob(self, x, **kwargs):
     batch_ndims = prefer_static.rank_from_shape(
         self.distribution.batch_shape_tensor,
         self.distribution.batch_shape)
     extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
     event_ndims = prefer_static.rank_from_shape(
         self.distribution.event_shape_tensor,
         self.distribution.event_shape)
     ndims = prefer_static.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=tf.pad(
                        tf.shape(x),
                        paddings=[[prefer_static.maximum(0, -d), 0]],
                        constant_values=1))
     sample_ndims = prefer_static.maximum(0, d)
     # (2) Transpose x's dims.
     sample_dims = prefer_static.range(0, sample_ndims)
     batch_dims = prefer_static.range(sample_ndims,
                                      sample_ndims + batch_ndims)
     extra_sample_dims = prefer_static.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = prefer_static.range(
         sample_ndims + batch_ndims + extra_sample_ndims, ndims)
     perm = prefer_static.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     x = tf.transpose(a=x, perm=perm)
     # (3) Compute x's log_prob.
     lp = self.distribution.log_prob(x, **kwargs)
     # (4) Make the final reduction in x.
     axis = prefer_static.range(sample_ndims,
                                sample_ndims + extra_sample_ndims)
     return tf.reduce_sum(lp, axis=axis)
def maybe_check_quadrature_param(param, name, validate_args):
    """Helper which checks validity of `loc` and `scale` init args."""
    with tf.name_scope("check_" + name):
        assertions = []
        if tensorshape_util.rank(param.shape) is not None:
            if tensorshape_util.rank(param.shape) == 0:
                raise ValueError("Mixing params must be a (batch of) vector; "
                                 "{}.rank={} is not at least one.".format(
                                     name, tensorshape_util.rank(param.shape)))
        elif validate_args:
            assertions.append(
                assert_util.assert_rank_at_least(
                    param,
                    1,
                    message=("Mixing params must be a (batch of) vector; "
                             "{}.rank is not at least one.".format(name))))

        # TODO(jvdillon): Remove once we support k-mixtures.
        if tensorshape_util.with_rank_at_least(param.shape, 1)[-1] is not None:
            if tf.compat.dimension_value(param.shape[-1]) != 1:
                raise NotImplementedError(
                    "Currently only bimixtures are supported; "
                    "{}.shape[-1]={} is not 1.".format(
                        name, tf.compat.dimension_value(param.shape[-1])))
        elif validate_args:
            assertions.append(
                assert_util.assert_equal(
                    tf.shape(param)[-1],
                    1,
                    message=("Currently only bimixtures are supported; "
                             "{}.shape[-1] is not 1.".format(name))))

        if assertions:
            return distribution_util.with_dependencies(assertions, param)
        return param
Example #15
 def _forward(self, x):
     output_shape, output_tensorshape = _replace_event_shape_in_shape_tensor(
         tf.shape(x), self._event_shape_in, self._event_shape_out,
         self.validate_args)
     y = tf.reshape(x, output_shape)
     tensorshape_util.set_shape(y, output_tensorshape)
     return y
    def _std_var_helper(self, statistic, statistic_name, statistic_ndims,
                        df_factor_fn):
        """Helper to compute stddev, covariance and variance."""
        df = tf.reshape(
            self.df,
            tf.concat([
                tf.shape(self.df),
                tf.ones([statistic_ndims], dtype=tf.int32)
            ], -1))
        # We need to put the tf.where inside the outer tf1.where to ensure we never
        # hit a NaN in the gradient.
        denom = tf.where(df > 2., df - 2., tf.ones_like(df))
        statistic = statistic * df_factor_fn(df / denom)
        # When 1 < df <= 2, stddev/variance are infinite.
        result_where_defined = tf.where(
            df > 2., statistic,
            dtype_util.as_numpy_dtype(self.dtype)(np.inf))

        if self.allow_nan_stats:
            return tf.where(df > 1., result_where_defined,
                            dtype_util.as_numpy_dtype(self.dtype)(np.nan))
        else:
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.cast(1., self.dtype),
                        df,
                        message='{} not defined for components of df <= 1.'.
                        format(statistic_name.capitalize())),
            ]):
                return tf.identity(result_where_defined)
Example #17
 def _inverse(self, y):
     output_shape, output_tensorshape = _replace_event_shape_in_shape_tensor(
         tf.shape(y), self._event_shape_out, self._event_shape_in,
         self.validate_args)
     x = tf.reshape(y, output_shape)
     tensorshape_util.set_shape(x, output_tensorshape)
     return x
    def _cdf(self, k):
        # TODO(b/135263541): Improve numerical precision of categorical.cdf.
        probs = self.probs_parameter()
        num_categories = self._num_categories(probs)

        k, probs = _broadcast_cat_event_and_params(
            k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

        # Since the lowest number in the support is 0, any k < 0 should be zero in
        # the output.
        should_be_zero = k < 0

        # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
        k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

        batch_shape = tf.shape(k)

        # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
        # to handle the case where the batch shape is dynamic, flatten the batch
        # dims (so we know batch_dims=1).
        k_flat_batch = tf.reshape(k, [-1])
        probs_flat_batch = tf.reshape(
            probs, tf.concat(([-1], [num_categories]), axis=0))

        cdf_flat = tf.gather(tf.cumsum(probs_flat_batch, axis=-1),
                             k_flat_batch[..., tf.newaxis],
                             batch_dims=1)

        cdf = tf.reshape(cdf_flat, shape=batch_shape)

        zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
        return tf.where(should_be_zero, zero, cdf)
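A small sketch of the flatten-then-`tf.gather(..., batch_dims=1)` trick used above, with made-up probabilities: the CDF at `k` is entry `k` of the cumulative sum along the category axis.

import tensorflow as tf

probs = tf.constant([[0.1, 0.2, 0.7],
                     [0.5, 0.3, 0.2]])
k = tf.constant([1, 0])  # one query per batch member

cdf = tf.gather(tf.cumsum(probs, axis=-1),  # [[0.1, 0.3, 1.0], [0.5, 0.8, 1.0]]
                k[..., tf.newaxis],
                batch_dims=1)
print(tf.squeeze(cdf, axis=-1).numpy())  # [0.3 0.5]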
 def _expand_mix_distribution_probs(self):
     p = self.mixture_distribution.probs_parameter()  # [B, deg]
     deg = tf.compat.dimension_value(
         tensorshape_util.with_rank_at_least(p.shape, 1)[-1])
     if deg is None:
         deg = tf.shape(p)[-1]
     event_ndims = tensorshape_util.rank(self.event_shape)
     if event_ndims is None:
         event_ndims = tf.shape(self.event_shape_tensor())[0]
      expand_shape = tf.concat([
          self.mixture_distribution.batch_shape_tensor(),
          tf.ones([event_ndims], dtype=tf.int32),
          [deg],
      ], axis=0)
     return tf.reshape(p, shape=expand_shape)
Example #20
    def _mode(self, samples=None):
        # Samples count can vary by batch member. Use map_fn to compute mode for
        # each batch separately.
        def _get_mode(samples):
            # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
            count = gen_array_ops.unique_with_counts_v2(samples,
                                                        axis=[0]).count
            return tf.argmax(count)

        if samples is None:
            samples = tf.convert_to_tensor(self._samples)
        num_samples = self._compute_num_samples(samples)

        # Flatten samples for each batch.
        if self._event_ndims == 0:
            flattened_samples = tf.reshape(samples, [-1, num_samples])
            mode_shape = self._batch_shape_tensor(samples)
        else:
            event_size = tf.reduce_prod(self._event_shape_tensor(samples))
            mode_shape = tf.concat([
                self._batch_shape_tensor(samples),
                self._event_shape_tensor(samples)
            ], axis=0)
            flattened_samples = tf.reshape(samples,
                                           [-1, num_samples, event_size])

        indices = tf.map_fn(_get_mode, flattened_samples, dtype=tf.int64)
        full_indices = tf.stack(
            [tf.range(tf.shape(indices)[0]),
             tf.cast(indices, tf.int32)],
            axis=1)

        mode = tf.gather_nd(flattened_samples, full_indices)
        return tf.reshape(mode, mode_shape)
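For a single 1-D batch of samples, the per-batch `_get_mode` step above reduces to the following sketch (made-up data), using the public `tf.unique_with_counts` rather than the `gen_array_ops` internal:

import tensorflow as tf

samples = tf.constant([3, 1, 3, 2, 3, 1])
values, _, counts = tf.unique_with_counts(samples)
mode = tf.gather(values, tf.argmax(counts))
print(mode.numpy())  # 3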
Example #21
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
  """Slices a single parameter of a distribution.

  Args:
    param: A `Tensor`, the original parameter to slice.
    param_event_ndims: `int` event parameterization rank for this parameter.
    slices: A `tuple` of normalized slices.
    dist_batch_shape: The distribution's batch shape `Tensor`.

  Returns:
    new_param: A `Tensor`, batch-sliced according to slices.
  """
  # Extend param shape with ones on the left to match dist_batch_shape.
  param_shape = tf.shape(input=param)
  insert_ones = tf.ones(
      [tf.size(input=dist_batch_shape) + param_event_ndims - tf.rank(param)],
      dtype=param_shape.dtype)
  new_param_shape = tf.concat([insert_ones, param_shape], axis=0)
  full_batch_param = tf.reshape(param, new_param_shape)
  param_slices = []
  # We separately track the batch axis from the parameter axis because we want
  # them to align for positive indexing, and be offset by param_event_ndims for
  # negative indexing.
  param_dim_idx = 0
  batch_dim_idx = 0
  for slc in slices:
    if slc is tf.newaxis:
      param_slices.append(slc)
      continue
    if slc is Ellipsis:
      if batch_dim_idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      param_slices.append(slc)
      # Switch over to negative indexing for the broadcast check.
      num_remaining_non_newaxis_slices = sum(
          [s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]])
      batch_dim_idx = -num_remaining_non_newaxis_slices
      param_dim_idx = batch_dim_idx - param_event_ndims
      continue
    # Find the batch dimension sizes for both parameter and distribution.
    param_dim_size = new_param_shape[param_dim_idx]
    batch_dim_size = dist_batch_shape[batch_dim_idx]
    is_broadcast = batch_dim_size > param_dim_size
    # Slices are denoted by start:stop:step.
    if isinstance(slc, slice):
      start, stop, step = slc.start, slc.stop, slc.step
      if start is not None:
        start = tf.where(is_broadcast, 0, start)
      if stop is not None:
        stop = tf.where(is_broadcast, 1, stop)
      if step is not None:
        step = tf.where(is_broadcast, 1, step)
      param_slices.append(slice(start, stop, step))
    else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
      param_slices.append(tf.where(is_broadcast, 0, slc))
    param_dim_idx += 1
    batch_dim_idx += 1
  param_slices.extend([ALL_SLICE] * param_event_ndims)
  return full_batch_param.__getitem__(param_slices)
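This helper backs batch slicing of distributions. As a hedged usage sketch (assuming a TFP version whose distributions support `__getitem__`), slicing a distribution slices its parameters consistently even when one of them is broadcast:

import tensorflow_probability as tfp

dist = tfp.distributions.Normal(loc=0., scale=[1., 2., 3.])  # scalar loc broadcasts over batch [3]
print(dist.batch_shape)      # (3,)
print(dist[1:].batch_shape)  # (2,); `scale` is sliced, `loc` stays broadcastable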
 def _forward(self, x):
     with tf.control_dependencies(self._assertions(x)):
         shape = tf.shape(x)
         return tf.linalg.triangular_solve(x,
                                           tf.eye(shape[-1],
                                                  batch_shape=shape[:-2],
                                                  dtype=x.dtype),
                                           lower=True)
Example #23
 def _entropy(self):
   concentration = tf.convert_to_tensor(self.concentration)
   k = tf.cast(tf.shape(concentration)[-1], self.dtype)
   total_concentration = tf.reduce_sum(concentration, axis=-1)
   return (tf.math.lbeta(concentration) +
           ((total_concentration - k) * tf.math.digamma(total_concentration)) -
           tf.reduce_sum((concentration - 1.) * tf.math.digamma(concentration),
                         axis=-1))
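The expression above is the standard Dirichlet entropy, `lbeta(alpha) + (alpha0 - K) * digamma(alpha0) - sum_j (alpha_j - 1) * digamma(alpha_j)` with `alpha0 = sum_j alpha_j`. A quick sanity-check sketch against SciPy (made-up concentration, not part of the original):

import numpy as np
import tensorflow as tf
from scipy import stats

concentration = np.array([1.5, 2.0, 3.0], dtype=np.float32)

k = tf.cast(tf.shape(concentration)[-1], tf.float32)
total = tf.reduce_sum(concentration, axis=-1)
entropy_tf = (tf.math.lbeta(concentration) +
              (total - k) * tf.math.digamma(total) -
              tf.reduce_sum((concentration - 1.) * tf.math.digamma(concentration), axis=-1))

# Both values should agree up to float precision.
print(entropy_tf.numpy(), stats.dirichlet(concentration).entropy())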
Example #24
 def _event_size(self, param=None):
     if param is None:
         param = self._logits if self._logits is not None else self._probs
     if param.shape is not None:
         event_size = tf.compat.dimension_value(param.shape[-1])
         if event_size is not None:
             return event_size
     return tf.shape(param)[-1]
 def _forward_log_det_jacobian(self, x):
     # For a discussion of this (non-obvious) result, see Note 7.2.2 (and the
     # sections leading up to it, for context) in
     # http://neutrino.aquaphoenix.com/ReactionDiffusion/SERC5chap7.pdf
     with tf.control_dependencies(self._assertions(x)):
         matrix_dim = tf.cast(
             tf.shape(x)[-1], dtype_util.base_dtype(x.dtype))
         return -(matrix_dim + 1) * tf.reduce_sum(
             tf.math.log(tf.abs(tf.linalg.diag_part(x))), axis=-1)
 def _num_categories(self, x=None):
     """Scalar `int32` tensor: the number of categories."""
     with tf.name_scope('num_categories'):
         if x is None:
             x = self._probs if self._logits is None else self._logits
         num_categories = tf.compat.dimension_value(x.shape[-1])
         if num_categories is not None:
             return num_categories
         return tf.shape(x)[-1]
 def _sample_n(self, n, seed=None):
     scale = tf.convert_to_tensor(self.scale)
     shape = tf.concat([[n], tf.shape(scale)], 0)
     sampled = tf.random.normal(shape=shape,
                                mean=0.,
                                stddev=1.,
                                dtype=self.dtype,
                                seed=seed)
     return tf.abs(sampled * scale)
def _shape(input, out_type=tf.int32, name=None):  # pylint: disable=redefined-builtin
    if not hasattr(input, 'shape'):
        x = np.array(input)
        input = tf.convert_to_tensor(input) if x.dtype is np.object else x
    input_shape = tf.TensorShape(input.shape)
    if tensorshape_util.is_fully_defined(input.shape):
        return np.array(tensorshape_util.as_list(input_shape)).astype(
            _numpy_dtype(out_type))
    return tf.shape(input, out_type=out_type, name=name)
Example #29
 def _cdf(self, x):
     low = tf.convert_to_tensor(self.low)
     high = tf.convert_to_tensor(self.high)
     broadcast_shape = tf.broadcast_dynamic_shape(
         tf.shape(x), self._batch_shape_tensor(low=low, high=high))
     zeros = tf.zeros(broadcast_shape, dtype=self.dtype)
     ones = tf.ones(broadcast_shape, dtype=self.dtype)
     result_if_not_big = tf.where(x < low, zeros, (x - low) /
                                  self._range(low=low, high=high))
     return tf.where(x >= high, ones, result_if_not_big)
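A compact standalone version of the same clamp-to-[0, 1] CDF (made-up bounds), showing the two nested `tf.where` branches:

import tensorflow as tf

low, high = 2.0, 6.0
x = tf.constant([1.0, 3.0, 7.0])

inside = (x - low) / (high - low)
cdf = tf.where(x >= high, tf.ones_like(x),
               tf.where(x < low, tf.zeros_like(x), inside))
print(cdf.numpy())  # [0.   0.25 1.  ]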
Example #30
 def _inverse(self, y):
     # As specified in the Stan reference manual, the procedure is as follows:
     # N = y.shape[-1]
     # z_k = y_k / (1 - sum_{i=1 to k-1} y_i)
     # x_k = logit(z_k) - log(1 / (N - k))
     offset = tf.math.log(
         tf.cast(tf.range(tf.shape(y)[-1] - 1, 0, delta=-1),
                 dtype=dtype_util.base_dtype(y.dtype)))
     z = y / (1. - tf.math.cumsum(y, axis=-1, exclusive=True))
     return tf.math.log(z[..., :-1]) - tf.math.log1p(-z[..., :-1]) + offset
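A NumPy mirror of the inverse transform above (a sketch, not part of the original snippet), with the exclusive cumulative sum written as `cumsum(y) - y`:

import numpy as np

def stick_breaking_inverse(y):
    # y: points on the simplex, shape [..., N]; returns unconstrained [..., N-1].
    n = y.shape[-1]
    offset = np.log(np.arange(n - 1, 0, -1, dtype=y.dtype))
    z = y / (1. - (np.cumsum(y, axis=-1) - y))  # exclusive cumsum in the denominator
    return np.log(z[..., :-1]) - np.log1p(-z[..., :-1]) + offset

print(stick_breaking_inverse(np.array([0.25, 0.25, 0.5])))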