def _inverse(self, y):
     map_values = tf.convert_to_tensor(self.map_values)
     flat_y = tf.reshape(y, shape=[-1])
     # Search for the indices of map_values that are closest to flat_y.
     # Since map_values is strictly increasing, the closest is either the
     # first one that is strictly greater than flat_y, or the one before it.
     upper_candidates = tf.minimum(
         tf.size(map_values) - 1,
         tf.searchsorted(map_values, values=flat_y, side='right'))
     lower_candidates = tf.maximum(0, upper_candidates - 1)
     candidates = tf.stack([lower_candidates, upper_candidates], axis=-1)
     lower_cand_diff = tf.abs(flat_y - self._forward(lower_candidates))
     upper_cand_diff = tf.abs(flat_y - self._forward(upper_candidates))
     if self.validate_args:
         with tf.control_dependencies([
                 assert_util.assert_near(tf.minimum(lower_cand_diff,
                                                    upper_cand_diff),
                                         0,
                                         message='inverse value not found')
         ]):
             candidates = tf.identity(candidates)
     candidate_selector = tf.stack([
         tf.range(tf.size(flat_y), dtype=tf.int32),
         tf.argmin([lower_cand_diff, upper_cand_diff], output_type=tf.int32)
     ],
                                   axis=-1)
     return tf.reshape(tf.gather_nd(candidates, candidate_selector),
                       shape=y.shape)
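A minimal NumPy sketch of the same nearest-value lookup, with toy `map_values`/`flat_y` (purely illustrative, not part of the original class):

import numpy as np

map_values = np.array([0.0, 1.5, 4.0, 10.0])   # strictly increasing
flat_y = np.array([1.4, 4.2, 11.0])

# Index of the first element strictly greater than each query, clipped to range.
upper = np.minimum(map_values.size - 1,
                   np.searchsorted(map_values, flat_y, side='right'))
lower = np.maximum(0, upper - 1)

# Keep whichever candidate is closer to the query.
pick_upper = np.abs(map_values[upper] - flat_y) < np.abs(map_values[lower] - flat_y)
nearest_idx = np.where(pick_upper, upper, lower)
print(nearest_idx)  # [1 2 3] -> values 1.5, 4.0, 10.0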
Example #2
 def _inverse(self, y):
     with tf.control_dependencies(self._maybe_assert_valid_y(y)):
         if self.power == 0.:
             return tf.math.log(y)
         # If large y accuracy is an issue, consider using:
         # (y**self.power - 1.) / self.power when y >> 1.
         return tf.math.expm1(tf.math.log(y) * self.power) / self.power
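Together with the matching `_forward` under Example #4, this implements the power transform with parameter c = `power` (the c -> 0 limit reduces to exp/log):

\[
g(x) = (1 + c\,x)^{1/c}, \qquad
g^{-1}(y) = \frac{y^{c} - 1}{c} = \frac{\operatorname{expm1}(c \log y)}{c}.
\]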
Example #3
def vector_size_to_square_matrix_size(d, validate_args, name=None):
    """Convert a vector size to a matrix size."""
    if isinstance(d, (float, int, np.generic, np.ndarray)):
        n = (-1 + np.sqrt(1 + 8 * d)) / 2.
        if float(int(n)) != n:
            raise ValueError(
                'Vector length {} is not a triangular number.'.format(d))
        return int(n)
    else:
        with tf.name_scope(name
                           or 'vector_size_to_square_matrix_size') as name:
            n = (-1. + tf.sqrt(1 + 8. * tf.cast(d, dtype=tf.float32))) / 2.
            if validate_args:
                with tf.control_dependencies([
                        assert_util.assert_equal(
                            tf.cast(tf.cast(n, dtype=tf.int32),
                                    dtype=tf.float32),
                            n,
                            data=[
                                'Vector length is not a triangular number: ', d
                            ],
                            message='Vector length is not a triangular number')
                ]):
                    n = tf.identity(n)
            return tf.cast(n, d.dtype)
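The quantity being inverted is the triangular-number relation d = n(n+1)/2; a quick illustrative check:

d = 6                              # entries in the lower triangle of a 3x3 matrix
n = (-1 + (1 + 8 * d) ** 0.5) / 2  # solve n * (n + 1) / 2 == d for n
assert n == 3.0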
Example #4
 def _forward(self, x):
     with tf.control_dependencies(self._maybe_assert_valid_x(x)):
         if self.power == 0.:
             return tf.exp(x)
         # If large x accuracy is an issue, consider using:
         # (1. + x * self.power)**(1. / self.power) when x >> 1.
         return tf.exp(tf.math.log1p(x * self.power) / self.power)
 def _batch_shape_tensor(self):
     with tf.control_dependencies(self._runtime_assertions):
         return tf.broadcast_dynamic_shape(
             self._initial_distribution.batch_shape_tensor(),
             tf.broadcast_dynamic_shape(
                 self._transition_distribution.batch_shape_tensor()[:-1],
                 self._observation_distribution.batch_shape_tensor()[:-1]))
Example #6
 def _inverse_log_det_jacobian(self, y):
     # If event_ndims = 2,
     # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1),
     # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0].
     with tf.control_dependencies(self._assertions(y)):
         zero = tf.zeros([], dtype=dtype_util.base_dtype(y.dtype))
         return zero, zero
Example #7
 def _call_and_reshape_output(self,
                              fn,
                              event_shape_list=None,
                              static_event_shape_list=None,
                              extra_kwargs=None):
     """Calls `fn` and appropriately reshapes its output."""
     # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
     # because it is possible the user provided extra kwargs would itself
     # have `fn`, `event_shape_list`, `static_event_shape_list` and/or
     # `extra_kwargs` as keys.
     with tf.control_dependencies(self._runtime_assertions):
         if event_shape_list is None:
             event_shape_list = [self._event_shape_tensor()]
         if static_event_shape_list is None:
             static_event_shape_list = [self.event_shape]
         new_shape = tf.concat([self._batch_shape_unexpanded] +
                               event_shape_list,
                               axis=0)
         result = tf.reshape(
             fn(**extra_kwargs) if extra_kwargs else fn(), new_shape)
         if (tensorshape_util.rank(self.batch_shape) is not None
                 and tensorshape_util.rank(self.event_shape) is not None):
             event_shape = tf.TensorShape([])
             for rss in static_event_shape_list:
                 event_shape = tensorshape_util.concatenate(
                     event_shape, rss)
             static_shape = tensorshape_util.concatenate(
                 self.batch_shape, event_shape)
             tensorshape_util.set_shape(result, static_shape)
         return result
    def _std_var_helper(self, statistic, statistic_name, statistic_ndims,
                        df_factor_fn):
        """Helper to compute stddev, covariance and variance."""
        df = tf.reshape(
            self.df,
            tf.concat([
                tf.shape(self.df),
                tf.ones([statistic_ndims], dtype=tf.int32)
            ], -1))
        # We need to put the inner tf.where inside the outer tf.where to ensure
        # we never hit a NaN in the gradient.
        denom = tf.where(df > 2., df - 2., tf.ones_like(df))
        statistic = statistic * df_factor_fn(df / denom)
        # When 1 < df <= 2, stddev/variance are infinite.
        result_where_defined = tf.where(
            df > 2., statistic,
            dtype_util.as_numpy_dtype(self.dtype)(np.inf))

        if self.allow_nan_stats:
            return tf.where(df > 1., result_where_defined,
                            dtype_util.as_numpy_dtype(self.dtype)(np.nan))
        else:
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.cast(1., self.dtype),
                        df,
                        message='{} not defined for components of df <= 1.'.
                        format(statistic_name.capitalize())),
            ]):
                return tf.identity(result_where_defined)
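For context (assuming, as the `df` handling suggests, that this helper backs a Student-t-family distribution, with `df_factor_fn` the identity for variance/covariance and a square root for stddev), the scaling corresponds to:

\[
\operatorname{Var}[X] \propto \frac{\nu}{\nu-2} \quad (\nu > 2), \qquad
\operatorname{Var}[X] = \infty \quad (1 < \nu \le 2), \qquad
\text{undefined for } \nu \le 1 .
\]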
Example #9
def matrix_rank(a, tol=None, validate_args=False, name=None):
    """Compute the matrix rank; the number of non-zero SVD singular values.

  Arguments:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be
      pseudo-inverted.
    tol: Threshold below which the singular value is counted as 'zero'.
      Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'matrix_rank'.

  Returns:
    matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
      singular values.
  """
    with tf.name_scope(name or 'matrix_rank'):
        a = tf.convert_to_tensor(a, dtype_hint=tf.float32, name='a')
        assertions = _maybe_validate_matrix(a, validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                a = tf.identity(a)
        s = tf.linalg.svd(a, compute_uv=False)
        if tol is None:
            if tensorshape_util.is_fully_defined(a.shape[-2:]):
                m = np.max(a.shape[-2:].as_list())
            else:
                m = tf.reduce_max(tf.shape(a)[-2:])
            eps = np.finfo(dtype_util.as_numpy_dtype(a.dtype)).eps
            tol = (eps * tf.cast(m, a.dtype) *
                   tf.reduce_max(s, axis=-1, keepdims=True))
        return tf.reduce_sum(tf.cast(s > tol, tf.int32), axis=-1)
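A minimal usage sketch, assuming the function above is importable along with the internal helpers it calls (`_maybe_validate_matrix`, `dtype_util`, `tensorshape_util`); the inputs are only illustrative:

import tensorflow as tf

a = tf.constant([[1., 2.],
                 [2., 4.]])    # rows are linearly dependent, so rank 1
print(matrix_rank(a))          # ==> 1
print(matrix_rank(tf.eye(3)))  # ==> 3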
Example #10
 def _entropy(self):
     logits, probs = self._logits_and_probs_no_checks()
     if not self.validate_args:
         assertions = []
     else:
         assertions = [
             assert_util.assert_less(
                 probs,
                 dtype_util.as_numpy_dtype(self.dtype)(1.),
                 message=
                 'Entropy is undefined when logits = inf or probs = 1.')
         ]
     with tf.control_dependencies(assertions):
         # Claim: entropy(p) = softplus(s)/p - s
         # where s=logits and p=probs.
         #
         # Proof:
         #
         # entropy(p)
         # := -[(1-p)log(1-p) + plog(p)]/p
         # = -[log(1-p) + plog(p/(1-p))]/p
         # = -[-softplus(s) + ps]/p
         # = softplus(s)/p - s
         #
         # since,
         # log[1-sigmoid(s)]
         # = log[1/(1+exp(s))]
         # = -log[1+exp(s)]
         # = -softplus(s)
         #
         # using the fact that,
         # 1-sigmoid(s) = sigmoid(-s) = 1/(1+exp(s))
         return tf.math.softplus(logits) / probs - logits
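A quick NumPy sanity check of the closed form, assuming the pmf convention P(X = k) = (1 - p)^k p for k = 0, 1, 2, ... used in the derivation (values are illustrative):

import numpy as np

p = 0.3
s = np.log(p / (1. - p))                   # logits
k = np.arange(200)                         # truncation error is negligible here
pmf = (1. - p) ** k * p
brute_force = -np.sum(pmf * np.log(pmf))
closed_form = np.log1p(np.exp(s)) / p - s  # softplus(s)/p - s
assert np.isclose(brute_force, closed_form)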
Example #11
 def _inverse_log_det_jacobian(self, y):
     with tf.control_dependencies(self._maybe_assert_valid_y(y)):
         scale = tf.convert_to_tensor(self.scale)
         concentration = tf.convert_to_tensor(self.concentration)
         return (-tf.math.log1p(-y) +
                 tf.math.xlogy(1 / concentration - 1, -tf.math.log1p(-y)) +
                 tf.math.log(scale / concentration))
Example #12
 def _log_prob(self, x):
   with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
     concentration = tf.convert_to_tensor(self.concentration)
     loc = tf.convert_to_tensor(self.loc)
     return (0.5 * (tf.math.log(concentration) - np.log(2. * np.pi) -
                    3. * tf.math.log(x)) + (-concentration * (x - loc)**2.) /
             (2. * loc**2. * x))
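For reference, the expression above is the log of the inverse Gaussian (Wald) density with mean \mu = `loc` and concentration \lambda = `concentration`:

\[
f(x;\mu,\lambda) = \sqrt{\frac{\lambda}{2\pi x^{3}}}\,
\exp\!\left(-\frac{\lambda (x-\mu)^{2}}{2\mu^{2} x}\right), \qquad x > 0 .
\]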
Example #13
def _validate_block_sizes(block_sizes, bijectors, validate_args):
    """Helper to validate block sizes."""
    block_sizes_shape = block_sizes.shape
    if tensorshape_util.is_fully_defined(block_sizes_shape):
        if (tensorshape_util.rank(block_sizes_shape) != 1
                or (tensorshape_util.num_elements(block_sizes_shape) !=
                    len(bijectors))):
            raise ValueError(
                '`block_sizes` must be `None`, or a vector of the same length as '
                '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
                'length {}'.format(block_sizes_shape, len(bijectors)))
        return block_sizes
    elif validate_args:
        message = (
            '`block_sizes` must be `None`, or a vector of the same length '
            'as `bijectors`.')
        with tf.control_dependencies([
                assert_util.assert_equal(tf.size(block_sizes),
                                         len(bijectors),
                                         message=message),
                assert_util.assert_equal(tf.rank(block_sizes), 1)
        ]):
            return tf.identity(block_sizes)
    else:
        return block_sizes
Example #14
def assert_finite(x, data=None, summarize=None, message=None, name=None):
  """Assert all elements of `x` are finite.

  Args:
    x:  Numeric `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_finite".

  Returns:
    Op raising `InvalidArgumentError` unless all elements of `x` are finite.
    If static checks determine all elements of `x` are finite, `x` is returned
    unchanged.

  Raises:
    ValueError:  If static checks determine `x` contains non-finite values.
  """
  with tf.name_scope(name or 'assert_finite'):
    x_ = tf.get_static_value(x)
    if x_ is not None:
      if not np.all(np.isfinite(x_)):
        raise ValueError(message)
      return x
    assertion = tf1.assert_equal(
        tf.math.is_finite(x), tf.ones_like(x, tf.bool),
        data=data, summarize=summarize, message=message)
    with tf.control_dependencies([assertion]):
      return tf.identity(x)
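A minimal usage sketch; with a statically known input the check runs immediately and raises `ValueError`, otherwise an assertion op is attached to the returned tensor (values are illustrative):

import numpy as np
import tensorflow as tf

x = tf.constant([1., 2., np.inf])
try:
  x = assert_finite(x, message='x must be finite')
except (ValueError, tf.errors.InvalidArgumentError):
  print('non-finite input rejected')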
Example #15
 def _log_prob(self, counts):
     with tf.control_dependencies(self._maybe_assert_valid_sample(counts)):
         log_p = (tf.math.log(self._probs) if self._logits is None else
                  tf.math.log_softmax(self._logits))
         k = tf.convert_to_tensor(self.total_count)
         return (tf.reduce_sum(counts * log_p, axis=-1) +  # log_unnorm_prob
                 tfp_math.log_combinations(k, counts))  # -log_normalization
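The two terms assemble the multinomial log-mass for counts c, total count k, and probabilities p:

\[
\log P(\mathbf{c}\mid k,\mathbf{p}) = \sum_j c_j \log p_j + \log\frac{k!}{\prod_j c_j!},
\]

where the second term is what `tfp_math.log_combinations(k, counts)` returns.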
Example #16
 def _call_reshape_input_output(self, fn, x, extra_kwargs=None):
     """Calls `fn`, appropriately reshaping its input `x` and output."""
     # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
     # because it is possible the user provided extra kwargs would itself
     # have `fn` and/or `x` as a key.
     with tf.control_dependencies(self._runtime_assertions +
                                  self._validate_sample_arg(x)):
         sample_shape, static_sample_shape = self._sample_shape(x)
         old_shape = tf.concat([
             sample_shape,
             self.distribution.batch_shape_tensor(),
             self.event_shape_tensor(),
         ],
                               axis=0)
         x_reshape = tf.reshape(x, old_shape)
         result = fn(x_reshape, **
                     extra_kwargs) if extra_kwargs else fn(x_reshape)
         new_shape = tf.concat([
             sample_shape,
             self._batch_shape_unexpanded,
         ],
                               axis=0)
         result = tf.reshape(result, new_shape)
         if (tensorshape_util.rank(static_sample_shape) is not None
                 and tensorshape_util.rank(self.batch_shape) is not None):
             new_shape = tensorshape_util.concatenate(
                 static_sample_shape, self.batch_shape)
             tensorshape_util.set_shape(result, new_shape)
         return result
Example #17
 def _mean(self):
     with tf.control_dependencies(self._runtime_assertions):
         probs = distribution_utils.pad_mixture_dimensions(
             self.mixture_distribution.probs_parameter(), self,
             self.mixture_distribution, self._event_ndims)  # [B, k, [1]*e]
         return tf.reduce_sum(probs * self.components_distribution.mean(),
                              axis=-1 - self._event_ndims)  # [B, E]
Example #18
 def _forward_log_det_jacobian(self, x):
     with tf.control_dependencies(self._maybe_assert_valid_x(x)):
         scale = tf.convert_to_tensor(self.scale)
         concentration = tf.convert_to_tensor(self.concentration)
         return (-(x / scale)**concentration +
                 tf.math.xlogy(concentration - 1, x) +
                 tf.math.log(concentration) -
                 concentration * tf.math.log(scale))
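This is the log-derivative of the Weibull CDF F(x) = 1 - exp(-(x/\sigma)^k), with concentration k and scale \sigma:

\[
\log F'(x) = -\left(\frac{x}{\sigma}\right)^{k} + (k-1)\log x + \log k - k\log\sigma .
\]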
Example #19
 def _cdf(self, x):
     with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
         concentration1 = tf.convert_to_tensor(self.concentration1)
         concentration0 = tf.convert_to_tensor(self.concentration0)
         shape = self._batch_shape_tensor(concentration1, concentration0)
         concentration1 = tf.broadcast_to(concentration1, shape)
         concentration0 = tf.broadcast_to(concentration0, shape)
         return tf.math.betainc(concentration1, concentration0, x)
Example #20
 def _prob(self, x):
     if self.validate_args:
         is_vector_check = assert_util.assert_rank_at_least(x, 1)
         right_vec_space_check = assert_util.assert_equal(
             self.event_shape_tensor(),
             tf.gather(tf.shape(x),
                       tf.rank(x) - 1),
             message=
             "Argument 'x' not defined in the same space R^k as this distribution"
         )
         with tf.control_dependencies([is_vector_check]):
             with tf.control_dependencies([right_vec_space_check]):
                 x = tf.identity(x)
     loc = tf.convert_to_tensor(self.loc)
     return tf.cast(tf.reduce_all(tf.abs(x - loc) <= self._slack(loc),
                                  axis=-1),
                    dtype=self.dtype)
 def _forward(self, x):
     with tf.control_dependencies(self._assertions(x)):
         shape = tf.shape(x)
         return tf.linalg.triangular_solve(x,
                                           tf.eye(shape[-1],
                                                  batch_shape=shape[:-2],
                                                  dtype=x.dtype),
                                           lower=True)
    def _variance(self):
        with tf.control_dependencies(self._runtime_assertions):
            probs = self._marginal_hidden_probs()
            # probs :: num_steps batch_shape num_states
            means = self._observation_distribution.mean()
            # means :: observation_batch_shape[:-1] num_states
            #          observation_event_shape
            means_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_states],
                self._observation_distribution.event_shape_tensor()
            ],
                                    axis=0)
            means = tf.broadcast_to(means, means_shape)
            # means :: batch_shape num_states observation_event_shape

            observation_event_shape = (
                self._observation_distribution.event_shape_tensor())
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
            flat_probs_shape = [self._num_steps, batch_size, self._num_states]
            flat_means_shape = [
                batch_size, 1, self._num_states,
                tf.reduce_prod(observation_event_shape)
            ]

            flat_probs = tf.reshape(probs, flat_probs_shape)
            # flat_probs :: num_steps batch_size num_states
            flat_means = tf.reshape(means, flat_means_shape)
            # flat_means :: batch_size 1 num_states observation_event_size
            flat_mean = tf.einsum("ijk,jmkl->jiml", flat_probs, flat_means)
            # flat_mean :: batch_size num_steps 1 observation_event_size

            variances = self._observation_distribution.variance()
            variances = tf.broadcast_to(variances, means_shape)
            # variances :: batch_shape num_states observation_event_shape
            flat_variances = tf.reshape(variances, flat_means_shape)
            # flat_variances :: batch_size 1 num_states observation_event_size

            # For a mixture of n distributions with mixture probabilities
            # p[i], and where the individual distributions have means and
            # variances given by mean[i] and var[i], the variance of
            # the mixture is given by:
            #
            # var = sum_{i=1..n} p[i] * ((mean[i] - mean)**2 + var[i])

            flat_variance = tf.einsum("ijk,jikl->jil", flat_probs,
                                      (flat_means - flat_mean)**2 +
                                      flat_variances)
            # flat_variance :: batch_size num_steps observation_event_size

            unflat_mean_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_steps],
                observation_event_shape
            ],
                                          axis=0)

            # returns :: batch_shape num_steps observation_event_shape
            return tf.reshape(flat_variance, unflat_mean_shape)
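Writing the comment's identity out: with mixture weights p_i, component means m_i, component variances v_i, and overall mean m = \sum_i p_i m_i,

\[
\operatorname{Var}[X] = \sum_{i=1}^{n} p_i\,\bigl[(m_i - m)^2 + v_i\bigr],
\]

which is the law of total variance for a discrete mixture; the final einsum evaluates it per batch element and time step.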
 def _forward_log_det_jacobian(self, x):
     # For a discussion of this (non-obvious) result, see Note 7.2.2 (and the
     # sections leading up to it, for context) in
     # http://neutrino.aquaphoenix.com/ReactionDiffusion/SERC5chap7.pdf
     with tf.control_dependencies(self._assertions(x)):
         matrix_dim = tf.cast(
             tf.shape(x)[-1], dtype_util.base_dtype(x.dtype))
         return -(matrix_dim + 1) * tf.reduce_sum(
             tf.math.log(tf.abs(tf.linalg.diag_part(x))), axis=-1)
Example #24
 def _forward(self, x):
     y = x
     if self.scale is not None:
         with tf.control_dependencies(self._maybe_collect_assertions(
         ) if self.validate_args else []):
             y = self.scale.matvec(y, adjoint=self.adjoint)
     if self.shift is not None:
         y = y + self.shift
     return y
Example #25
 def _log_prob(self, x):
     with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
         probs = self._probs_parameter_no_checks()
         if not self.validate_args:
             # For consistency with cdf, we take the floor.
             x = tf.floor(x)
         safe_domain = tf.where(tf.equal(x, 0.), tf.zeros_like(probs),
                                probs)
         return x * tf.math.log1p(-safe_domain) + tf.math.log(probs)
Example #26
 def _cdf(self, x):
     with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
         probs = self._probs_parameter_no_checks()
         if not self.validate_args:
             # Whether or not x is integer-form, the following is well-defined.
             # However, scipy takes the floor, so we do too.
             x = tf.floor(x)
         return tf.where(x < 0., tf.zeros_like(x), -tf.math.expm1(
             (1. + x) * tf.math.log1p(-probs)))
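For x >= 0 this is the geometric CDF under the P(X = k) = (1 - p)^k p convention:

\[
P(X \le x) = 1 - (1-p)^{\lfloor x\rfloor + 1},
\]

computed as `-expm1((1 + x) * log1p(-p))` for numerical stability, with the `tf.where` clamping it to 0 for x < 0.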
Example #27
 def _log_prob(self, x):
     with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
         concentration = tf.convert_to_tensor(self.concentration)
         scale = tf.convert_to_tensor(self.scale)
         unnormalized_prob = -(1. +
                               concentration) * tf.math.log(x) - scale / x
         normalization = (tf.math.lgamma(concentration) -
                          concentration * tf.math.log(scale))
         return unnormalized_prob - normalization
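The two pieces assemble the inverse-gamma log-density with concentration \alpha and scale \beta:

\[
\log p(x;\alpha,\beta) = -(\alpha+1)\log x - \frac{\beta}{x} - \log\Gamma(\alpha) + \alpha\log\beta .
\]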
Example #28
 def _log_prob(self, x):
     with tf.control_dependencies(self._runtime_assertions):
         x = self._pad_sample_dims(x)
         log_prob_x = self.components_distribution.log_prob(x)  # [S, B, k]
         log_mix_prob = tf.math.log_softmax(
             self.mixture_distribution.logits_parameter(),
             axis=-1)  # [B, k]
         return tf.reduce_logsumexp(log_prob_x + log_mix_prob,
                                    axis=-1)  # [S, B]
Example #29
 def _inverse_event_shape_tensor(self, output_shape_tensor):
     batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1]
     if self.validate_args:
         is_square_matrix = assert_util.assert_equal(
             n, output_shape_tensor[-2], message='Matrix must be square.')
         with tf.control_dependencies([is_square_matrix]):
             n = tf.identity(n)
     d = tf.cast(n * (n + 1) / 2, output_shape_tensor.dtype)
     return tf.concat([batch_shape, [d]], axis=0)
Example #30
 def _log_prob(self, x):
     concentration = 0.5 * self.df
     rate = tf.convert_to_tensor(0.5, dtype=self.dtype)
     with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
         log_unnormalized_prob = tf.math.xlogy(concentration - 1.,
                                               x) - rate * x
         log_normalization = (tf.math.lgamma(concentration) -
                              concentration * tf.math.log(rate))
         return log_unnormalized_prob - log_normalization
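This is the chi-square log-density, written as a Gamma(\nu/2, rate = 1/2) with \nu = `df`:

\[
\log p(x;\nu) = \left(\tfrac{\nu}{2}-1\right)\log x - \tfrac{x}{2}
- \log\Gamma\!\left(\tfrac{\nu}{2}\right) - \tfrac{\nu}{2}\log 2 .
\]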