def _value(self, dtype=None, name=None, as_ref=False):
     """Evaluates the deferred transformation and returns it as a `Tensor`.

     Applies `transform_fn` to `pretransformed_input`, then checks the result
     against the dtype and shape recorded on this object before converting.

     Args:
       dtype: Optional dtype forwarded to `tf.convert_to_tensor`.
       name: Optional name forwarded to `tf.convert_to_tensor`.
       as_ref: Not used by this implementation.

     Returns:
       The output of `tf.convert_to_tensor` applied to the transformed value.

     Raises:
       TypeError: if the transformed value's base dtype differs from
         `self.dtype`.
       TypeError: if the transformed value's shape is incompatible with
         `self.shape`.
     """
     transformed = self.transform_fn(self.pretransformed_input)  # pylint: disable=not-callable

     actual_dtype = dtype_util.base_dtype(transformed.dtype)
     if actual_dtype != self.dtype:
         dtype_msg = 'Actual dtype ({}) does not match deferred dtype ({}).'
         raise TypeError(
             dtype_msg.format(dtype_util.name(actual_dtype),
                              dtype_util.name(self.dtype)))

     shapes_agree = tensorshape_util.is_compatible_with(
         transformed.shape, self.shape)
     if not shapes_agree:
         shape_msg = (
             'Actual shape ({}) is incompatible with deferred shape ({}).')
         raise TypeError(shape_msg.format(transformed.shape, self.shape))

     return tf.convert_to_tensor(transformed, dtype=dtype, name=name)
    def __init__(self,
                 shift=None,
                 scale=None,
                 adjoint=False,
                 validate_args=False,
                 name="affine_linear_operator"):
        """Instantiates the `AffineLinearOperator` bijector.

        Args:
          shift: Floating-point `Tensor`.
          scale:  Subclass of `LinearOperator`. Represents the (batch) positive
            definite matrix `M` in `R^{k x k}`.
          adjoint: Python `bool` indicating whether to use the `scale` matrix
            as specified or its adjoint.
            Default value: `False`.
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.
          name: Python `str` name given to ops managed by this object.

        Raises:
          TypeError: if `scale` is not a `LinearOperator`.
          TypeError: if `shift.dtype` does not match `scale.dtype`.
          ValueError: if `validate_args` is `True` and
            `scale.is_non_singular` is not truthy.
        """
        with tf.name_scope(name) as name:
            # In the absence of `shift` and `scale`, we'll assume `dtype` is
            # `float32`.
            dtype = tf.float32
            if shift is not None:
                shift = tf.convert_to_tensor(value=shift, name="shift")
                dtype = dtype_util.base_dtype(shift.dtype)
            self._shift = shift
            if scale is not None:
                # Check the type before touching `scale.dtype`, so that any
                # non-`LinearOperator` argument produces the documented
                # `TypeError` rather than whatever attribute access on an
                # arbitrary object happens to raise.
                if not isinstance(scale, tf.linalg.LinearOperator):
                    raise TypeError(
                        "scale is not an instance of tf.LinearOperator")
                if (shift is not None and
                        not dtype_util.base_equal(shift.dtype, scale.dtype)):
                    raise TypeError(
                        "shift.dtype({}) is incompatible with scale.dtype({})."
                        .format(shift.dtype, scale.dtype))
                if validate_args and not scale.is_non_singular:
                    raise ValueError("Scale matrix must be non-singular.")
                # A `scale` dtype, when present, takes precedence over the one
                # derived from `shift`.
                if scale.dtype is not None:
                    dtype = dtype_util.base_dtype(scale.dtype)
            self._scale = scale
            self._adjoint = adjoint
            super(AffineLinearOperator,
                  self).__init__(forward_min_event_ndims=1,
                                 is_constant_jacobian=True,
                                 dtype=dtype,
                                 validate_args=validate_args,
                                 name=name)
示例#3
0
 def _fn(*args):
   """Returns a convex combination of the target and proposal log-probs.

   The weight on the target is `(iter_ + 1) / num_steps`; the proposal gets
   the complementary weight.
   """
   proposal_lp = tf.identity(
       proposal_log_prob_fn(*args), name='proposal_log_prob')
   target_lp = tf.identity(
       target_log_prob_fn(*args), name='target_log_prob')
   float_dtype = dtype_util.base_dtype(proposal_lp.dtype)
   step_number = tf.cast(iter_ + 1, float_dtype)
   total_steps = tf.cast(num_steps, float_dtype)
   beta = step_number / total_steps
   blended = beta * target_lp + (1. - beta) * proposal_lp
   return tf.identity(blended, name='convex_combined_log_prob')
示例#4
0
def _prepare_args_with_initial_vertex(objective_function,
                                      initial_vertex,
                                      step_sizes,
                                      objective_at_initial_vertex,
                                      batch_evaluate_objective):
  """Builds an axes-aligned simplex anchored at `initial_vertex`.

  The simplex consists of `initial_vertex` followed by one vertex per
  coordinate, each obtained by adding `step_sizes` times a unit vector.

  Args:
    objective_function: Passed through to `_evaluate_objective_multiple`.
    initial_vertex: `Tensor` giving the first vertex of the simplex.
    step_sizes: Multiplier applied to the unit vectors; must broadcast
      against them.
    objective_at_initial_vertex: Objective value at `initial_vertex`, or
      `None` to have it evaluated together with the other vertices.
    batch_evaluate_objective: Passed through to
      `_evaluate_objective_multiple`.

  Returns:
    A tuple `(dim, num_vertices, simplex, objective_at_simplex,
    num_evaluations)`.
  """
  dim = tf.size(initial_vertex)
  num_vertices = dim + 1

  vertex_dtype = dtype_util.base_dtype(initial_vertex.dtype)
  identity = tf.eye(dim, dim, dtype=vertex_dtype)
  axes_shape = tf.concat([[dim], tf.shape(initial_vertex)], axis=0)
  unit_vectors_along_axes = tf.reshape(identity, axes_shape)

  # The multiplication below fails if `step_sizes` does not broadcast to
  # `initial_vertex`.
  simplex_face = initial_vertex + step_sizes * unit_vectors_along_axes
  first_vertex = tf.expand_dims(initial_vertex, axis=0)
  simplex = tf.concat([first_vertex, simplex_face], axis=0)

  # Evaluate the objective at the vertices whose value is not yet known.
  if objective_at_initial_vertex is not None:
    face_values, num_evaluations = _evaluate_objective_multiple(
        objective_function, simplex_face, batch_evaluate_objective)
    known_value = tf.expand_dims(objective_at_initial_vertex, axis=0)
    objective_at_simplex = tf.concat([known_value, face_values], axis=0)
  else:
    objective_at_simplex, num_evaluations = _evaluate_objective_multiple(
        objective_function, simplex, batch_evaluate_objective)

  return dim, num_vertices, simplex, objective_at_simplex, num_evaluations
示例#5
0
def _prepare_args_with_initial_vertex(objective_function, initial_vertex,
                                      step_sizes, objective_at_initial_vertex,
                                      batch_evaluate_objective):
    """Builds an axes-aligned simplex anchored at `initial_vertex`.

    The simplex consists of `initial_vertex` followed by one vertex per
    coordinate, each obtained by adding `step_sizes` times a unit vector.

    Args:
      objective_function: Passed through to `_evaluate_objective_multiple`.
      initial_vertex: `Tensor` giving the first vertex of the simplex.
      step_sizes: Multiplier applied to the unit vectors; must broadcast
        against them.
      objective_at_initial_vertex: Objective value at `initial_vertex`, or
        `None` to have it evaluated together with the other vertices.
      batch_evaluate_objective: Passed through to
        `_evaluate_objective_multiple`.

    Returns:
      A tuple `(dim, num_vertices, simplex, objective_at_simplex,
      num_evaluations)`.
    """
    dim = ps.size(initial_vertex)
    # tf.eye complains about np.array(.., np.int32) num_rows, only welcomes
    # numpy scalars. TODO(b/162529062): Remove the following conversion.
    if not tf.is_tensor(dim):
        dim = int(dim)
    num_vertices = dim + 1

    vertex_dtype = dtype_util.base_dtype(initial_vertex.dtype)
    identity = tf.eye(dim, dim, dtype=vertex_dtype)
    axes_shape = ps.concat([[dim], ps.shape(initial_vertex)], axis=0)
    unit_vectors_along_axes = tf.reshape(identity, axes_shape)

    # The multiplication below fails if `step_sizes` does not broadcast to
    # `initial_vertex`.
    simplex_face = initial_vertex + step_sizes * unit_vectors_along_axes
    first_vertex = tf.expand_dims(initial_vertex, axis=0)
    simplex = tf.concat([first_vertex, simplex_face], axis=0)

    # Evaluate the objective at the vertices whose value is not yet known.
    if objective_at_initial_vertex is not None:
        face_values, num_evaluations = _evaluate_objective_multiple(
            objective_function, simplex_face, batch_evaluate_objective)
        known_value = tf.expand_dims(objective_at_initial_vertex, axis=0)
        objective_at_simplex = tf.concat([known_value, face_values], axis=0)
    else:
        objective_at_simplex, num_evaluations = _evaluate_objective_multiple(
            objective_function, simplex, batch_evaluate_objective)

    return dim, num_vertices, simplex, objective_at_simplex, num_evaluations
示例#6
0
    def _fn(state_parts, seed):
        """Adds a uniform perturbation to the input state.

        Each state part is replaced by a draw from the uniform distribution
        on `[state_part - scale_part, state_part + scale_part)`.

        Args:
          state_parts: A list of `Tensor`s of any shape and real dtype
            representing the state parts of the `current_state` of the Markov
            chain.
          seed: Seed handed to `samplers.split_seed` to derive one seed per
            state part.

        Returns:
          perturbed_state_parts: A Python `list` of `Tensor`s, one per entry
            of `state_parts`, each drawn with the shape and base dtype of the
            corresponding state part.

        Raises:
          ValueError: if, after a single scale has been repeated, the number
            of scales differs from the number of state parts.
        """
        with tf.name_scope(name or 'random_walk_uniform_fn'):
            scales = scale if mcmc_util.is_list_like(scale) else [scale]
            if len(scales) == 1:
                # Build a fresh list instead of using `scales *= ...`: when
                # `scale` is itself a one-element list, in-place repetition
                # would grow the closed-over list and break later calls made
                # with a different number of state parts.
                scales = list(scales) * len(state_parts)
            if len(state_parts) != len(scales):
                raise ValueError('`scale` must broadcast with `state_parts`.')
            part_seeds = samplers.split_seed(seed, n=len(state_parts))
            next_state_parts = [
                samplers.uniform(  # pylint: disable=g-complex-comprehension
                    minval=state_part - scale_part,
                    maxval=state_part + scale_part,
                    shape=tf.shape(state_part),
                    dtype=dtype_util.base_dtype(state_part.dtype),
                    seed=seed_part)
                for scale_part, state_part, seed_part in zip(
                    scales, state_parts, part_seeds)
            ]
            return next_state_parts
示例#7
0
def _make_empty_queue_for(k, element):
    """Creates a zero `tf.Tensor` able to hold `k` element-shaped tensors.

    For example:

    ```python
      element = tf.constant([[0., 1., 2., 3., 4.],
                             [5., 6., 7., 8., 9.]])

      # A queue capable of holding 3 elements.
      _make_empty_queue_for(3, element)
      # => [[[ 0.,  0.,  0.,  0.,  0.],
      #      [ 0.,  0.,  0.,  0.,  0.]],
      #
      #     [[ 0.,  0.,  0.,  0.,  0.],
      #      [ 0.,  0.,  0.,  0.,  0.]],
      #
      #     [[ 0.,  0.,  0.,  0.,  0.],
      #      [ 0.,  0.,  0.,  0.,  0.]]]
    ```

    Args:
      k: A positive scalar integer, number of elements the queue will hold.
      element: A `tf.Tensor`; only its shape and dtype are used.

    Returns:
      A zero-filled `tf.Tensor` of shape `(k,) + tf.shape(element)` with the
      base dtype of `element`.
    """
    element_shape = distribution_util.prefer_static_shape(element)
    queue_shape = tf.concat([[k], element_shape], axis=0)
    queue_dtype = dtype_util.base_dtype(element.dtype)
    return tf.zeros(queue_shape, dtype=queue_dtype)
 def _log_prob(self, x):
     """Computes the log-probability of the category indices `x`.

     The value is the log of the difference of two neighbouring survival
     probabilities, gathered from the augmented log-survival table.

     Args:
       x: Category indices; broadcast against the distribution parameters.

     Returns:
       A `Tensor` of log-probabilities with the shape of the broadcast `x`.
     """
     # TODO(b/149334734): Consider using QuantizedDistribution for the log_prob
     # computation for better precision.
     num_categories = self._num_categories()
     log_survival_params = tf.math.log_sigmoid(
         self.loc[..., tf.newaxis] - self._augmented_cutpoints())
     x, augmented_log_survival = _broadcast_cat_event_and_params(
         event=x,
         params=log_survival_params,
         base_dtype=dtype_util.base_dtype(self.dtype))

     # Flatten the batch so that `tf.gather` can use a static `batch_dims=1`.
     x_flat = tf.reshape(x, [-1, 1])
     table_flat = tf.reshape(augmented_log_survival,
                             [-1, num_categories + 1])

     lower_index = tf.clip_by_value(x_flat, 0, num_categories)
     log_survival_lower = tf.gather(
         params=table_flat, indices=lower_index, batch_dims=1)
     upper_index = tf.clip_by_value(x_flat + 1, 0, num_categories)
     log_survival_upper = tf.gather(
         params=table_flat, indices=upper_index, batch_dims=1)

     log_prob_flat = tfp_math.log_sub_exp(log_survival_lower,
                                          log_survival_upper)

     # When both survival probabilities are -inf the subtraction yields nan;
     # the correct answer in that case is -inf.
     minus_inf = tf.constant(-np.inf, dtype=log_prob_flat.dtype)
     above_support = x_flat > num_categories - 1
     log_prob_flat = tf.where(above_support, minus_inf, log_prob_flat)
     return tf.reshape(log_prob_flat, shape=ps.shape(x))
    def _fn(state_parts, seed):
        """Adds a normal perturbation to the input state.

        Each state part is replaced by a draw from a normal distribution with
        mean `state_part` and standard deviation `scale_part`.

        Args:
          state_parts: A list of `Tensor`s of any shape and real dtype
            representing the state parts of the `current_state` of the Markov
            chain.
          seed: Seed used to construct the `SeedStream` from which one seed
            per state part is drawn.

        Returns:
          perturbed_state_parts: A Python `list` of `Tensor`s, one per entry
            of `state_parts`, each drawn with the shape and base dtype of the
            corresponding state part.

        Raises:
          ValueError: if, after a single scale has been repeated, the number
            of scales differs from the number of state parts.
        """
        with tf.name_scope(name or 'random_walk_normal_fn'):
            scales = scale if mcmc_util.is_list_like(scale) else [scale]
            if len(scales) == 1:
                # Build a fresh list instead of using `scales *= ...`: when
                # `scale` is itself a one-element list, in-place repetition
                # would grow the closed-over list and break later calls made
                # with a different number of state parts.
                scales = list(scales) * len(state_parts)
            if len(state_parts) != len(scales):
                raise ValueError('`scale` must broadcast with `state_parts`.')
            seed_stream = SeedStream(seed, salt='RandomWalkNormalFn')
            next_state_parts = [
                tf.random.normal(  # pylint: disable=g-complex-comprehension
                    mean=state_part,
                    stddev=scale_part,
                    shape=tf.shape(state_part),
                    dtype=dtype_util.base_dtype(state_part.dtype),
                    seed=seed_stream())
                for scale_part, state_part in zip(scales, state_parts)
            ]

            return next_state_parts
示例#10
0
  def __init__(self, transform_fn, pretransformed_input, dtype=None,
               shape=NONE_SPECIFIED, name=None):
    """Creates the `DeferredTensor` object.

    Args:
      transform_fn: Python `callable` taking `pretransformed_input` and
        returning a `Tensor` (represented by this object).
      pretransformed_input: object with `shape`, `dtype` properties (typically
        a `tf.Variable`) passed into `transform_fn` when this object is acted
        upon in a `Tensor` context, eg, `tf.convert_to_tensor`, `+`,
        `tf.math.exp`, etc.
      dtype: Equivalent to what would otherwise be
        `transform_fn(pretransformed_input).dtype`.
         Default value: `None` (i.e., `pretransformed_input.dtype`).
      shape: Equivalent to what would otherwise be
        `transform_fn(pretransformed_input).shape`.
         Default value: `'None'` (i.e., `pretransformed_input.shape`).
      name: Python `str` representing this object's `name`; used only in graph
        mode.
        Default value: `None` (i.e.,
        `transform_fn.__name__ + '_' + pretransformed_input.name`).

    Raises:
      TypeError: if `transform_fn` is not `callable`.
      TypeError: if `pretransformed_input` lacks `dtype` and/or `shape`
        properties (and `dtype` and/or `shape` arguments are unspecified).
    """
    if not callable(transform_fn):
      raise TypeError('Argument `transform_fn` must be a Python `callable`.')
    # `shape` falls back to `pretransformed_input.shape` exactly when it
    # compares equal to the string 'None' (see the `self._shape` assignment
    # below), so that case must also be covered by the attribute check.
    # Otherwise a missing `shape` attribute escapes as an `AttributeError`
    # instead of the documented `TypeError`.
    # NOTE(review): presumably `NONE_SPECIFIED == 'None'` - confirm.
    shape_is_unspecified = shape == 'None'
    if ((dtype is None and not hasattr(pretransformed_input, 'dtype')) or
        ((shape is None or shape_is_unspecified) and
         not hasattr(pretransformed_input, 'shape'))):
      raise TypeError('Argument `pretransformed_input` must have `dtype` and '
                      '`shape` properties (unless `dtype`, `shape` arguments '
                      'are explicitly provided.')
    has_name = bool(name)
    if not has_name:
      # Derive a default name from the transform and, when available, the
      # input's own name.
      name = '_'.join([
          transform_fn.__name__,
          getattr(pretransformed_input, 'name', '')])
      name = name_util.strip_invalid_chars(name)
      name = name_util.camel_to_lower_snake(name)
    name = name_util.get_name_scope_name(name)
    name = name_util.strip_invalid_chars(name)
    super(DeferredTensor, self).__init__(name=name)
    self._name = name

    self._transform_fn = transform_fn
    self._pretransformed_input = pretransformed_input
    self._dtype = dtype_util.base_dtype(dtype or pretransformed_input.dtype)
    self._shape = tf.TensorShape(
        pretransformed_input.shape if shape_is_unspecified else shape)

    # Secret handshake with tf.is_tensor to return True for DT.
    #
    # Works around an exception in LinearOperator (which in 2.0.0 checks only
    # `tf.is_tensor`, not also `linear_operator_util.is_ref`:
    # ValueError: Graph parent item 0 is not a Tensor;
    #   <DeferredTensor: dtype=float32, shape=[2], fn=exp>.
    # TODO(b/140157055): Remove this shim after LinOp is patched in 2.0.
    self.is_tensor_like = True
示例#11
0
 def _inverse_log_det_jacobian(self, y):
     """Returns a pair of scalar zeros as the inverse log-det-Jacobian.

     Args:
       y: Input whose base dtype determines the dtype of the result; it is
         also passed to `self._assertions`.

     Returns:
       A 2-tuple holding the same scalar zero `Tensor` twice.
     """
     # If event_ndims = 2,
     # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1),
     # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0].
     with tf.control_dependencies(self._assertions(y)):
         log_det = tf.zeros([], dtype=dtype_util.base_dtype(y.dtype))
     return log_det, log_det
示例#12
0
    def _cdf(self, k):
        """Evaluates the cumulative distribution function at `k`.

        Args:
          k: Category indices; broadcast against the probabilities.

        Returns:
          A `Tensor` with the shape of the broadcast `k`. Entries where
          `k < 0` are zero; all others are the cumulative sum of the
          probabilities up to and including the clipped index.
        """
        # TODO(b/135263541): Improve numerical precision of categorical.cdf.
        probs = self.probs_parameter()
        num_categories = self._num_categories(probs)

        k, probs = _broadcast_cat_event_and_params(
            k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

        # The support starts at 0, so the result must be zero for any k < 0.
        below_support = k < 0

        # `k` indexes into the cumulative sums, so clip it to {0,...,K-1}.
        k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)
        result_shape = tf.shape(k)

        # tf.gather(..., batch_dims=batch_dims) needs a static batch_dims, so
        # flatten the (possibly dynamic) batch dims to get batch_dims=1.
        flat_index = tf.reshape(k, [-1])
        flat_probs_shape = tf.concat(([-1], [num_categories]), axis=0)
        flat_probs = tf.reshape(probs, flat_probs_shape)

        flat_cumulative = tf.cumsum(flat_probs, axis=-1)
        flat_cdf = tf.gather(flat_cumulative,
                             flat_index[..., tf.newaxis],
                             batch_dims=1)
        cdf = tf.reshape(flat_cdf, shape=result_shape)

        zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
        return tf.where(below_support, zero, cdf)
示例#13
0
def dense_to_sparse(x, ignore_value=None, name=None):
    """Converts a dense `Tensor` to a `SparseTensor` without `ignore_value`.

    Args:
      x: A `Tensor`.
      ignore_value: Entries of `x` equal to this value are left out of the
        returned `SparseTensor`. If `None`, the empty string is used for
        string tensors and a zero of `x`'s dtype otherwise.
      name: Python `str` prefix for ops created by this function.

    Returns:
      sparse_x: A `tf.SparseTensor` whose `dense_shape` is the shape of `x`.
    """
    # Copied (with modifications) from:
    # tensorflow/contrib/layers/python/ops/sparse_ops.py.
    with tf.name_scope(name or 'dense_to_sparse'):
        x = tf.convert_to_tensor(x, name='x')
        if ignore_value is None:
            is_string = dtype_util.base_dtype(x.dtype) == tf.string
            # Strings are special-cased because TF strings are converted to
            # numpy objects by default.
            default_value = (
                '' if is_string else dtype_util.as_numpy_dtype(x.dtype)(0))
            ignore_value = tf.cast(default_value, x.dtype,
                                   name='ignore_value')
        kept = tf.not_equal(x, ignore_value)
        indices = tf.where(kept, name='indices')
        values = tf.gather_nd(x, indices, name='values')
        dense_shape = tf.shape(x, out_type=tf.int64, name='dense_shape')
        return tf.SparseTensor(indices=indices,
                               values=values,
                               dense_shape=dense_shape)
示例#14
0
    def _sample_n(self, n, seed):
        """Draws `n` samples by transforming normal and gamma variates.

        A lower-triangular matrix is filled with standard normal draws below
        the diagonal and square roots of gamma draws on the diagonal, then
        left-multiplied by `self._scale`.

        Args:
          n: Number of samples; becomes the leading dimension of the result.
          seed: Seed that is split (with salt 'Wishart') into separate seeds
            for the normal and the gamma draws.

        Returns:
          A `Tensor` with leading dimension `n`. When
          `self.input_output_cholesky` is falsy the transformed factor `x` is
          replaced by `x @ adjoint(x)`; otherwise `x` itself is returned.
        """
        df = tf.convert_to_tensor(self.df)
        batch_shape = self._batch_shape_tensor(df)
        event_shape = self._event_shape_tensor()
        batch_ndims = tf.shape(batch_shape)[0]

        ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
        shape = tf.concat([[n], batch_shape, event_shape], 0)
        normal_seed, gamma_seed = samplers.split_seed(seed, salt='Wishart')

        # Complexity: O(nbk**2)
        x = samplers.normal(shape=shape,
                            mean=0.,
                            stddev=1.,
                            dtype=self.dtype,
                            seed=normal_seed)

        # Complexity: O(nbk)
        # This parameterization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        # Multiplying by ones broadcasts `df` against the batch shape of
        # `self._scale`.
        expanded_df = df * tf.ones(self._scale.batch_shape_tensor(),
                                   dtype=dtype_util.base_dtype(df.dtype))

        g = gamma_lib.random_gamma(shape=[n],
                                   concentration=self._multi_gamma_sequence(
                                       0.5 * expanded_df, self._dimension()),
                                   rate=0.5,
                                   seed=gamma_seed)

        # Complexity: O(nbk**2)
        x = tf.linalg.band_part(x, -1, 0)  # Tri-lower.

        # Complexity: O(nbk)
        # Overwrite the diagonal with the square roots of the gamma draws.
        x = tf.linalg.set_diag(x, tf.sqrt(g))

        # Make batch-op ready.
        # Complexity: O(nbk**2)
        # Rotate the sample dimension to the end, then fold it into the last
        # event dimension so a single matmul covers all samples.
        perm = tf.concat([tf.range(1, ndims), [0]], 0)
        x = tf.transpose(a=x, perm=perm)
        shape = tf.concat(
            [batch_shape, [event_shape[0]], [event_shape[1] * n]], 0)
        x = tf.reshape(x, shape)

        # Complexity: O(nbM) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each matmul is O(k^3) so
        # this step has complexity O(nbk^3).
        x = self._scale.matmul(x)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        # Split the sample dimension back out and rotate it to the front.
        shape = tf.concat([batch_shape, event_shape, [n]], 0)
        x = tf.reshape(x, shape)
        perm = tf.concat([[ndims - 1], tf.range(0, ndims - 1)], 0)
        x = tf.transpose(a=x, perm=perm)

        if not self.input_output_cholesky:
            # Complexity: O(nbk**3)
            x = tf.matmul(x, x, adjoint_b=True)

        return x
示例#15
0
    def _cdf(self, k):
        """Evaluates the cumulative distribution function at `k`.

        Args:
          k: Category indices; converted to a `Tensor` and broadcast against
            `self.probs`.

        Returns:
          A `Tensor` with the shape of the broadcast `k`. Entries where
          `k < 0` are zero; all others are the cumulative sum of the
          probabilities up to and including the clipped index.
        """
        k = tf.convert_to_tensor(value=k, name="k")

        k, probs = _broadcast_cat_event_and_params(
            k, self.probs, base_dtype=dtype_util.base_dtype(self.dtype))

        # The support starts at 0, so the result must be zero for any k < 0.
        below_support = k < 0

        # `k` indexes into the cumulative sums, so clip it to {0,...,K-1}.
        k = tf.clip_by_value(tf.cast(k, tf.int32), 0, self.num_categories - 1)
        result_shape = tf.shape(input=k)

        # tf.gather(..., batch_dims=batch_dims) needs a static batch_dims, so
        # flatten the (possibly dynamic) batch dims to get batch_dims=1.
        flat_index = tf.reshape(k, [-1])
        flat_probs_shape = tf.concat(([-1], [self.num_categories]), axis=0)
        flat_probs = tf.reshape(probs, flat_probs_shape)

        flat_cumulative = tf.cumsum(flat_probs, axis=-1)
        flat_cdf = tf.gather(flat_cumulative,
                             flat_index[..., tf.newaxis],
                             batch_dims=1)
        cdf = tf.reshape(flat_cdf, shape=result_shape)

        return tf.where(below_support, tf.zeros_like(cdf), cdf)
示例#16
0
    def bootstrap_results(self, init_state):
        """Creates initial `state`.

        Args:
          init_state: A `Tensor`, or a list-like of `Tensor`s, giving the
            initial chain state. It is forwarded unchanged to the inner
            kernel's `bootstrap_results`.

        Returns:
          The value of `self.extra_setter_fn` applied to the inner kernel's
          bootstrap results, a zero counter, the initial covariance scaling
          divided by the last dimension of the stacked state, the initial
          covariance, `self._accum_covar` and `self.initial_u`.
        """
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    "AdaptiveRandomWalkMetropolisHastings",
                                    "bootstrap_results")):
            if mcmc_util.is_list_like(init_state):
                initial_state_parts = list(init_state)
            else:
                initial_state_parts = [init_state]
            initial_state_parts = [
                tf.convert_to_tensor(s, name="init_state")
                for s in initial_state_parts
            ]

            # Stack once and reuse the result for both the shape and the
            # dtype, rather than building two identical stack ops.
            stacked_state = tf.stack(initial_state_parts)
            shape = stacked_state.shape
            dtype = dtype_util.base_dtype(stacked_state.dtype)

            # One scaling value per state part (leading dim of the stack).
            init_covariance_scaling = tf.cast(
                tf.repeat([self.initial_covariance_scaling],
                          repeats=[shape[0]],
                          axis=0),
                dtype=dtype,
            )

            inner_results = self._impl.bootstrap_results(init_state)
            return self.extra_setter_fn(
                inner_results,
                0,
                init_covariance_scaling / shape[-1],
                self.initial_covariance,
                self._accum_covar,
                self.initial_u,
            )
示例#17
0
    def __init__(self,
                 transform_fn,
                 pretransformed_input,
                 dtype=None,
                 shape=NONE_SPECIFIED,
                 name=None):
        """Creates the `DeferredTensor` object.

        Args:
          transform_fn: Python `callable` taking `pretransformed_input` and
            returning a `Tensor` (represented by this object).
          pretransformed_input: object with `shape`, `dtype` properties
            (typically a `tf.Variable`) passed into `transform_fn` when this
            object is acted upon in a `Tensor` context, eg,
            `tf.convert_to_tensor`, `+`, `tf.math.exp`, etc.
          dtype: Equivalent to what would otherwise be
            `transform_fn(pretransformed_input).dtype`.
             Default value: `None` (i.e., `pretransformed_input.dtype`).
          shape: Equivalent to what would otherwise be
            `transform_fn(pretransformed_input).shape`.
             Default value: `'None'` (i.e., `pretransformed_input.shape`).
          name: Python `str` representing this object's `name`; used only in
            graph mode.
            Default value: `None` (i.e.,
            `transform_fn.__name__ + '_' + pretransformed_input.name`).

        Raises:
          TypeError: if `transform_fn` is not `callable`.
          TypeError: if `pretransformed_input` lacks `dtype` and/or `shape`
            properties (and `dtype` and/or `shape` arguments are
            unspecified).
        """
        if not callable(transform_fn):
            raise TypeError(
                'Argument `transform_fn` must be a Python `callable`.')
        # `shape` falls back to `pretransformed_input.shape` exactly when it
        # compares equal to the string 'None' (see the `self._shape`
        # assignment below), so that case must also be covered by the
        # attribute check. Otherwise a missing `shape` attribute escapes as
        # an `AttributeError` instead of the documented `TypeError`.
        # NOTE(review): presumably `NONE_SPECIFIED == 'None'` - confirm.
        shape_is_unspecified = shape == 'None'
        if ((dtype is None and not hasattr(pretransformed_input, 'dtype')) or
            ((shape is None or shape_is_unspecified) and
             not hasattr(pretransformed_input, 'shape'))):
            raise TypeError(
                'Argument `pretransformed_input` must have `dtype` and '
                '`shape` properties (unless `dtype`, `shape` arguments '
                'are explicitly provided.')
        has_name = bool(name)
        if not has_name:
            # Derive a default name from the transform and, when available,
            # the input's own name.
            name = '_'.join([
                transform_fn.__name__,
                getattr(pretransformed_input, 'name', '')
            ])
            name = name_util.strip_invalid_chars(name)
            name = name_util.camel_to_lower_snake(name)
        name = name_util.get_name_scope_name(name)
        name = name_util.strip_invalid_chars(name)
        super(DeferredTensor, self).__init__(name=name)
        self._name = name

        self._transform_fn = transform_fn
        self._pretransformed_input = pretransformed_input
        self._dtype = dtype_util.base_dtype(dtype
                                            or pretransformed_input.dtype)
        self._shape = tf.TensorShape(
            pretransformed_input.shape if shape_is_unspecified else shape)
示例#18
0
    def _forward_log_det_jacobian(self, x):
        """Returns `log|scale|`, or a scalar zero when there is no scale.

        Args:
          x: Input; only its base dtype is used, and only when `self.scale`
            is `None`.

        Returns:
          `tf.math.log(tf.abs(self.scale))`, or a constant `0.` with the base
          dtype of `x` when `self.scale` is `None`.
        """
        # is_constant_jacobian = True for this bijector, so the
        # `log_det_jacobian` only has to be given for a single input; it is
        # tiled to match `event_ndims`.
        if self.scale is not None:
            return tf.math.log(tf.abs(self.scale))

        zero_dtype = dtype_util.base_dtype(x.dtype)
        return tf.constant(0., dtype=zero_dtype)
 def _forward_log_det_jacobian(self, x):
   """Computes the forward log-det-Jacobian from the diagonal of `x`.

   Args:
     x: Matrix-valued input (batches allowed); its diagonal entries and the
       size of its last dimension determine the result.

   Returns:
     `-(n + 1) * sum(log|diag(x)|)` over the last axis, where `n` is the size
     of the last dimension of `x`.
   """
   # For a discussion of this (non-obvious) result, see Note 7.2.2 (and the
   # sections leading up to it, for context) in
   # http://neutrino.aquaphoenix.com/ReactionDiffusion/SERC5chap7.pdf
   with tf.control_dependencies(self._assertions(x)):
     float_dtype = dtype_util.base_dtype(x.dtype)
     matrix_dim = tf.cast(tf.shape(x)[-1], float_dtype)
     coefficient = -(matrix_dim + 1)
     log_abs_diag = tf.math.log(tf.abs(tf.linalg.diag_part(x)))
     return coefficient * tf.reduce_sum(log_abs_diag, axis=-1)
示例#20
0
 def _log_prob(self, k):
     """Computes the log-probability of the category indices `k`.

     Args:
       k: Category indices; broadcast against the logits.

     Returns:
       The negated sparse softmax cross-entropy between `k` and the logits.
     """
     logits = self.logits_parameter()
     if self.validate_args:
         k = distribution_util.embed_check_integer_casting_closed(
             k, target_dtype=tf.int32)
     event_dtype = dtype_util.base_dtype(self.dtype)
     k, logits = _broadcast_cat_event_and_params(
         k, logits, base_dtype=event_dtype)
     cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
         labels=k, logits=logits)
     return -cross_entropy
示例#21
0
    def histogram(self, x, value_range=None, nbins=None, name=None):
        """Returns a fixed-width histogram of `x` together with bin edges.

        The bins have equal width and are determined by `value_range` and
        `nbins`.

        Args:
          x: 1D numeric `Tensor` of items to count.
          value_range: Shape [2] `Tensor` giving the lower and upper end of
            the histogram range. Must be same dtype as `x`.
            Default value: `None` (i.e., `[min(x), 1 + max(x)]`).
          nbins: Scalar `int32 Tensor`. Number of histogram bins.
            Default value: `None` (i.e., the width of `value_range` cast to
            `int32`).
          name: Python `str` name prefixed to Ops created by this class.

        Returns:
          counts: `Tensor` of per-bin counts as produced by
            `tf.histogram_fixed_width`.
          edges: 1D `Tensor` of equally spaced values starting at
            `value_range[0]` and stopping before `value_range[1]`.
        """
        with tf.compat.v2.name_scope(name or 'histogram'):
            x = tf.convert_to_tensor(value=x, name='x')
            if value_range is None:
                smallest = tf.reduce_min(input_tensor=x)
                upper_bound = 1 + tf.reduce_max(input_tensor=x)
                value_range = [smallest, upper_bound]
            value_range = tf.convert_to_tensor(value=value_range,
                                               name='value_range')
            lo, hi = value_range[0], value_range[1]
            if nbins is None:
                nbins = tf.cast(hi - lo, dtype=tf.int32)
            range_dtype = dtype_util.base_dtype(value_range.dtype)
            delta = (hi - lo) / tf.cast(nbins, dtype=range_dtype)
            edges = tf.range(start=lo,
                             limit=hi,
                             delta=delta,
                             dtype=dtype_util.base_dtype(x.dtype))
            counts = tf.histogram_fixed_width(x,
                                              value_range=value_range,
                                              nbins=nbins)
            return counts, edges
示例#22
0
 def _forward(self, x):
     """Applies the forward transformation to `x`.

     The input is solved against an identity matrix, the solution is
     multiplied by its own adjoint on the left, and the Cholesky factor of
     that product is returned.

     Args:
       x: Matrix-valued input (batches allowed); also passed to
         `self._assertions`.

     Returns:
       The result of `tf.linalg.cholesky` on the product described above.
     """
     with tf.control_dependencies(self._assertions(x)):
         dims = tf.shape(input=x)
         eye = tf.eye(dims[-1],
                      batch_shape=dims[:-2],
                      dtype=dtype_util.base_dtype(x.dtype))
         # Note `matrix_triangular_solve` implicitly zeros upper triangular
         # of `x`.
         solved = tf.linalg.triangular_solve(x, eye)
         gram = tf.matmul(solved, solved, adjoint_a=True)
         return tf.linalg.cholesky(gram)
示例#23
0
 def _log_prob(self, k):
     """Computes the log-probability of the category indices `k`.

     Args:
       k: Category indices; broadcast against the logits and used to gather
         from the normalised logits with `batch_dims=1`.

     Returns:
       The entries of the log-softmax of the logits selected by `k`.
     """
     with tf.name_scope("Cat2log_prob"):
         logits = self.logits_parameter()
         if self.validate_args:
             k = distribution_util.embed_check_integer_casting_closed(
                 k, target_dtype=self.dtype)
         k, logits = _broadcast_cat_event_and_params(
             k, logits, base_dtype=dtype_util.base_dtype(self.dtype))
         # Use `log_softmax` rather than `log(softmax(...))`: the latter
         # underflows to `log(0) = -inf` for logits far below the maximum,
         # while `log_softmax` computes the same quantity stably.
         logits_normalised = tf.math.log_softmax(logits)
         return tf.gather(logits_normalised, k, batch_dims=1)
示例#24
0
    def _forward_log_det_jacobian(self, x):
        """Returns the log absolute determinant of `self.scale`.

        Args:
          x: Input; only its base dtype is used, and only when `self.scale`
            is `None`.

        Returns:
          `self.scale.log_abs_determinant()`, or a constant `0.` with the
          base dtype of `x` when `self.scale` is `None`.
        """
        # is_constant_jacobian = True for this bijector, so the
        # `log_det_jacobian` only has to be given for a single input; it is
        # tiled to match `event_ndims`.
        if self.scale is None:
            zero_dtype = dtype_util.base_dtype(x.dtype)
            return tf.constant(0., dtype=zero_dtype)

        # Assertions are only collected when argument validation is enabled.
        if self.validate_args:
            assertions = self._maybe_collect_assertions()
        else:
            assertions = []
        with tf.control_dependencies(assertions):
            return self.scale.log_abs_determinant()
示例#25
0
 def _inverse(self, y):
     """Applies the inverse transformation to `y`.

     Args:
       y: Input whose last axis is transformed.

     Returns:
       A `Tensor` whose last axis is one entry shorter than that of `y`.
     """
     # As specified in the Stan reference manual, the procedure is as follows:
     # N = y.shape[-1]
     # z_k = y_k / (1 - sum_{i=1 to k-1} y_i)
     # x_k = logit(z_k) - log(1 / (N - k))
     last_dim = ps.shape(y)[-1]
     countdown = tf.range(last_dim - 1, 0, delta=-1)
     offset = tf.math.log(
         tf.cast(countdown, dtype=dtype_util.base_dtype(y.dtype)))
     preceding_mass = tf.math.cumsum(y, axis=-1, exclusive=True)
     z = y / (1. - preceding_mass)
     logit_z = tf.math.log(z[..., :-1]) - tf.math.log1p(-z[..., :-1])
     return logit_z + offset
    def _entropy(self, **kwargs):
        """Computes the entropy of the transformed distribution.

        The result is the base distribution's entropy (rescaled and tiled
        when event or batch shapes may be overridden) minus the bijector's
        inverse log-det-Jacobian evaluated at an all-zeros dummy input.

        Args:
          **kwargs: Keyword arguments split by `self._kwargs_split_fn` into
            those passed to the base distribution's `entropy` and those
            passed to the bijector's `inverse_log_det_jacobian`.

        Returns:
          The entropy `Tensor`, with its static shape set to
          `self.batch_shape`.

        Raises:
          NotImplementedError: if the bijector does not have a constant
            Jacobian.
          NotImplementedError: if the bijector is not injective.
        """
        if not self.bijector.is_constant_jacobian:
            raise NotImplementedError('`entropy` is not implemented.')
        if not self.bijector._is_injective:  # pylint: disable=protected-access
            raise NotImplementedError('`entropy` is not implemented when '
                                      '`bijector` is not injective.')
        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
        override_event_shape = tf.convert_to_tensor(self._override_event_shape)
        override_batch_shape = tf.convert_to_tensor(self._override_batch_shape)
        base_batch_shape_tensor = self.distribution.batch_shape_tensor()
        base_event_shape_tensor = self.distribution.event_shape_tensor()
        # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
        # can be shown that:
        #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
        # If is_constant_jacobian then:
        #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
        # where c can be anything.
        entropy = self.distribution.entropy(**distribution_kwargs)
        if self._is_maybe_event_override:
            # H[X] = sum_i H[X_i] if X_i are mutually independent.
            # This means that a reduce_sum is a simple rescaling.
            entropy = entropy * tf.cast(tf.reduce_prod(override_event_shape),
                                        dtype=dtype_util.base_dtype(
                                            entropy.dtype))
        if self._is_maybe_batch_override:
            # Prepend size-1 dims (one per overridden batch dim), then tile
            # those dims up to the overridden batch shape.
            new_shape = tf.concat([
                prefer_static.ones_like(override_batch_shape),
                base_batch_shape_tensor
            ], 0)
            entropy = tf.reshape(entropy, new_shape)
            multiples = tf.concat([
                override_batch_shape,
                prefer_static.ones_like(base_batch_shape_tensor)
            ], 0)
            entropy = tf.tile(entropy, multiples)
        # Because the Jacobian is constant, any input works; use zeros with
        # the full batch + event shape.
        dummy = prefer_static.zeros(shape=tf.concat([
            self._batch_shape_tensor(override_batch_shape,
                                     base_batch_shape_tensor),
            self._event_shape_tensor(override_event_shape,
                                     base_event_shape_tensor)
        ], 0),
                                    dtype=self.dtype)
        # Prefer the static event rank; fall back to the dynamic size of the
        # event shape tensor when the rank is not statically known.
        event_ndims = (
            tensorshape_util.rank(self.event_shape)  # pylint: disable=g-long-ternary
            if tensorshape_util.rank(self.event_shape) is not None else
            tf.size(
                self._event_shape_tensor(override_event_shape,
                                         base_event_shape_tensor)))
        ildj = self.bijector.inverse_log_det_jacobian(dummy,
                                                      event_ndims=event_ndims,
                                                      **bijector_kwargs)

        entropy = entropy - tf.cast(ildj, entropy.dtype)
        tensorshape_util.set_shape(entropy, self.batch_shape)
        return entropy
示例#27
0
def _maybe_validate_distributions(distributions, dtype_override,
                                  validate_args):
    """Checks that `distributions` satisfies all assumptions.

    Properties that are known statically are checked immediately and raise
    on violation. Properties that are only known at runtime are turned into
    assertions, and only when `validate_args` is truthy.

    Args:
      distributions: Non-empty list of distribution objects. Each one is read
        through `dtype`, `event_shape`, `batch_shape`, `event_shape_tensor()`
        and `batch_shape_tensor()`.
      dtype_override: When `None`, all components with a non-`None` dtype
        must share the same base dtype. Any other value skips the dtype
        check.
      validate_args: Whether to build runtime assertions for properties that
        cannot be verified statically.

    Returns:
      assertions: List of the values returned by `assert_util.assert_equal`
        for every check deferred to runtime (possibly empty).

    Raises:
      ValueError: if `distributions` is not iterable or is empty.
      TypeError: if `dtype_override` is `None` and the base dtypes differ.
      ValueError: if a component has a statically known event rank other
        than 1.
      ValueError: if all batch shapes are fully defined and not all equal.
    """
    assertions = []

    if not _is_iterable(distributions) or not distributions:
        raise ValueError('`distributions` must be a list of one or more '
                         'distributions.')

    if dtype_override is None:
        dts = [
            dtype_util.base_dtype(d.dtype) for d in distributions
            if d.dtype is not None
        ]
        # Comparing the list against itself shifted by one is an
        # "all neighbours equal" test, i.e. all dtypes are the same.
        if dts[1:] != dts[:-1]:
            raise TypeError(
                'Distributions must have same dtype; found: {}.'.format(
                    set(dtype_util.name(dt) for dt in dts)))

    # Validate event_ndims.
    for d in distributions:
        if tensorshape_util.rank(d.event_shape) is not None:
            if tensorshape_util.rank(d.event_shape) != 1:
                raise ValueError('`Distribution` must be vector variate, '
                                 'found event ndims: {}.'.format(
                                     tensorshape_util.rank(d.event_shape)))
        elif validate_args:
            # Event rank is unknown statically; defer the check to runtime.
            assertions.append(
                assert_util.assert_equal(
                    1,
                    tf.size(d.event_shape_tensor()),
                    message='`Distribution` must be vector variate.'))

    # Validate batch shapes: statically when every shape is fully defined,
    # otherwise through pairwise runtime assertions.
    batch_shapes = [d.batch_shape for d in distributions]
    if all(tensorshape_util.is_fully_defined(b) for b in batch_shapes):
        if batch_shapes[1:] != batch_shapes[:-1]:
            raise ValueError('Distributions must have the same `batch_shape`; '
                             'found: {}.'.format(batch_shapes))
    elif validate_args:
        # Use the static shape where it is fully defined and fall back to
        # the shape tensor for the remaining components.
        batch_shapes = [
            tensorshape_util.as_list(d.batch_shape)  # pylint: disable=g-complex-comprehension
            if tensorshape_util.is_fully_defined(d.batch_shape) else
            d.batch_shape_tensor() for d in distributions
        ]
        assertions.extend(
            assert_util.assert_equal(  # pylint: disable=g-complex-comprehension
                b1,
                b2,
                message='Distribution `batch_shape`s must be identical.')
            for b1, b2 in zip(batch_shapes[1:], batch_shapes[:-1]))

    return assertions
 def _log_survival_function(self, x):
     """Looks up the log survival value for each event in `x`.

     The log-sigmoid of `loc` minus the augmented cutpoints is broadcast
     against `x`, flattened to one row of `num_categories + 1` entries per
     event, and the entry at index `x + 1` (clipped to
     `[0, num_categories]`) is gathered from each row.

     Args:
       x: Events to evaluate.

     Returns:
       The gathered values, reshaped to the shape of `x` after broadcasting.
     """
     n_categories = self._num_categories()
     log_sigmoid_table = tf.math.log_sigmoid(
         self.loc[..., tf.newaxis] - self._augmented_cutpoints())
     x, log_survival_table = _broadcast_cat_event_and_params(
         event=x,
         params=log_sigmoid_table,
         base_dtype=dtype_util.base_dtype(self.dtype))

     # Flatten to one event / one row of candidates per leading index.
     events_col = tf.reshape(x, [-1, 1])
     table_rows = tf.reshape(log_survival_table, [-1, n_categories + 1])

     # Shift by one into the augmented table; clipping keeps out-of-range
     # events on the first / last column.
     lookup = tf.clip_by_value(events_col + 1, 0, n_categories)
     gathered = tf.gather(params=table_rows, indices=lookup, batch_dims=1)
     return tf.reshape(gathered, shape=ps.shape(x))
示例#29
0
 def _forward(self, x):
     """Applies the iterated-sigmoid (stick-breaking) forward transform.

     Args:
       x: Tensor whose last axis holds the unconstrained coordinates.

     Returns:
       Tensor with one more entry along the last axis than `x`; the extra
       final entry is one minus the sum of the preceding entries.
     """
     # As specified in the Stan reference manual, the procedure is as follows:
     # N = x.shape[-1] + 1
     # z_k = sigmoid(x + log(1 / (N - k)))
     # y_1 = z_1
     # y_k = (1 - sum_{i=1 to k-1} y_i) * z_k
     # y_N = 1 - sum_{i=1 to N-1} y_i
     # TODO(b/128857065): The numerics can possibly be improved here with a
     # log-space computation.
     countdown = tf.range(ps.shape(x)[-1], 0, delta=-1)  # N-1, N-2, ..., 1
     offset = -tf.math.log(
         tf.cast(countdown, dtype=dtype_util.base_dtype(x.dtype)))
     break_fractions = tf.math.sigmoid(x + offset)
     # Exclusive cumprod gives the part of the stick left before each break.
     stick_left = tf.math.cumprod(1 - break_fractions, axis=-1, exclusive=True)
     leading = break_fractions * stick_left
     final = 1. - tf.reduce_sum(leading, axis=-1, keepdims=True)
     return tf.concat([leading, final], axis=-1)
示例#30
0
    def _parameter_control_dependencies(self, is_init):
        """Validates `sample_shape` and collects runtime assertions for it.

        Three properties are checked: rank at most 1, integer dtype, and
        non-negative values. A property that is known statically raises
        right away on violation; one that is not known statically becomes an
        assertion, and only when `self.validate_args` is set.

        Args:
          is_init: Whether the call happens during initialization. It
            selects which of the static / deferred checks run.

        Returns:
          List of the values returned by `assert_util.assert_less` /
          `assert_util.assert_greater` for the deferred checks (possibly
          empty).

        Raises:
          ValueError: if the static rank of `sample_shape` exceeds 1.
          TypeError: if `is_init` and the dtype is neither int32 nor int64.
          ValueError: if the static value of `sample_shape` has a negative
            entry.
        """
        checks = []
        concretized = []  # Memoize concretization (at most one element).

        def _concrete_sample_shape():
            # Convert `self.sample_shape` to a tensor at most once per call.
            if not concretized:
                concretized.append(tf.convert_to_tensor(self.sample_shape))
            return concretized[0]

        # Check valid shape.
        static_rank = tensorshape_util.rank(self.sample_shape.shape)
        rank_unknown = static_rank is None
        if is_init != rank_unknown:
            rank_msg = (
                'Argument `sample_shape` must be either a scalar or a vector.')
            if rank_unknown:
                if self.validate_args:
                    checks.append(
                        assert_util.assert_less(
                            tf.rank(_concrete_sample_shape()),
                            2,
                            message=rank_msg))
            elif static_rank > 1:
                raise ValueError(rank_msg)

        # Check valid dtype.
        if is_init:  # No xor check because `dtype` cannot change.
            shape_dtype = self.sample_shape.dtype
            if shape_dtype is None:
                shape_dtype = _concrete_sample_shape().dtype
            if dtype_util.base_dtype(shape_dtype) not in {tf.int32, tf.int64}:
                raise TypeError(
                    'Argument `sample_shape` must be integer type; '
                    'saw {}.'.format(dtype_util.name(shape_dtype)))

        # Check valid "value".
        if is_init != tensor_util.is_ref(self.sample_shape):
            static_value = tf.get_static_value(self.sample_shape)
            value_msg = 'Argument `sample_shape` must have non-negative values.'
            if static_value is None:
                if self.validate_args:
                    checks.append(
                        assert_util.assert_greater(_concrete_sample_shape(),
                                                   -1,
                                                   message=value_msg))
            elif np.any(np.array(static_value) < 0):
                raise ValueError('{} Saw: {}'.format(value_msg, static_value))

        return checks