Code example #1
def make_momentum_distribution(state_parts,
                               batch_shape,
                               running_variance_parts=None,
                               shard_axis_names=None):
    """Construct a momentum distribution from the running variance.

  This uses a running variance to construct a momentum distribution with the
  correct batch_shape and event_shape.

  Args:
    state_parts: List of `Tensor`.
    batch_shape: Batch shape.
    running_variance_parts: Optional, list of `Tensor`
       outputs of `tfp.experimental.stats.RunningVariance.variance()`. Defaults
       to ones with the same shape as state_parts.
    shard_axis_names: A structure of string names indicating how members of the
      state are sharded.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as `state_parts`,
    and `.log_prob` of a sample will have rank `batch_ndims`.
  """
    if running_variance_parts is None:
        running_variance_parts = tf.nest.map_structure(tf.ones_like,
                                                       state_parts)
    distributions = []
    batch_ndims = ps.rank_from_shape(batch_shape)
    use_sharded_jd = True
    if shard_axis_names is None:
        use_sharded_jd = False
        shard_axis_names = [None] * len(state_parts)
    for variance_part, state_part, shard_axes in zip(running_variance_parts,
                                                     state_parts,
                                                     shard_axis_names):
        event_shape = state_part.shape[batch_ndims:]
        if not tensorshape_util.is_fully_defined(event_shape):
            event_shape = ps.shape(state_part,
                                   name='state_part_shp')[batch_ndims:]
        variance_tiled = tf.broadcast_to(
            variance_part, ps.concat([batch_shape, event_shape], axis=0))
        nevt = ps.cast(ps.reduce_prod(event_shape), tf.int32)
        variance_flattened = tf.reshape(
            variance_tiled, ps.concat([batch_shape, [nevt]], axis=0))

        distribution = _CompositeTransformedDistribution(
            bijector=_CompositeReshape(event_shape_out=event_shape,
                                       name='reshape_mvnpfl'),
            distribution=(
                _CompositeMultivariateNormalPrecisionFactorLinearOperator(
                    precision_factor=_CompositeLinearOperatorDiag(
                        tf.math.sqrt(variance_flattened)),
                    precision=_CompositeLinearOperatorDiag(variance_flattened),
                    name='momentum')))
        if shard_axes:
            distribution = sharded.Sharded(distribution,
                                           shard_axis_name=shard_axes)
        distributions.append(distribution)
    if use_sharded_jd:
        jd = _CompositeShardedJointDistributionSequential(distributions)
    else:
        jd = _CompositeJointDistributionSequential(distributions)
    return maybe_make_list_and_batch_broadcast(jd, batch_shape)
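The composite wrappers above are TFP-internal; for a single state part with empty batch shape, the construction reduces to a reshaped diagonal-precision multivariate normal. Below is a minimal sketch of that equivalent using public `tfd`/`tfb` APIs (the `[3, 2]` state shape and the variance values are hypothetical):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Hypothetical running-variance estimate for one [3, 2]-shaped state part.
variance = tf.fill([3, 2], 0.25)
variance_flat = tf.reshape(variance, [-1])

# Momentum ~ N(0, diag(variance)^-1): the precision is the variance estimate,
# so the scale of the underlying MVN is 1 / sqrt(variance).
momentum_dist = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(
        scale_diag=tf.math.rsqrt(variance_flat)),
    bijector=tfb.Reshape(event_shape_out=[3, 2]))

sample = momentum_dist.sample(seed=42)     # shape [3, 2], like the state part
log_prob = momentum_dist.log_prob(sample)  # scalar, since batch_ndims == 0
```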
Code example #2
def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
    """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky decomposition
  of `matrix`, i.e. as described in [(Harbrecht et al., 2012)][1]. The
  currently-worst-approximated diagonal element is selected as the pivot at each
  iteration. Given a `[B1...Bn, N, N]` shaped `matrix`, this yields a
  `[B1...Bn, N, K]` shaped rank-`K` approximation `lr` such that
  `lr @ lr.T ~= matrix`.
  Note that, unlike the Cholesky decomposition, `lr` is not triangular even in
  a rectangular-matrix sense. However, under a permutation it could be made
  triangular (it has one more zero in each column as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, i.e. as in [(Wang et al. 2019)][2], as matmuls and solves can be
  cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
      errors of all diagonal elements of `lr @ lr.T` are each lower than
      `element * diag_rtol`, iteration is permitted to terminate early.
    name: Optional name for the op.

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by the
       pivoted Cholesky decomposition. _Applied numerical mathematics_,
       62(4):428-440, 2012.

  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019. https://arxiv.org/abs/1903.08114
  """
    with tf.name_scope(name or 'pivoted_cholesky'):
        dtype = dtype_util.common_dtype([matrix, diag_rtol],
                                        dtype_hint=tf.float32)
        if not isinstance(matrix, tf.linalg.LinearOperator):
            matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
        if tensorshape_util.rank(matrix.shape) is None:
            raise NotImplementedError(
                'Rank of `matrix` must be known statically')
        if isinstance(matrix, tf.linalg.LinearOperator):
            matrix_shape = tf.cast(matrix.shape_tensor(), tf.int64)
        else:
            matrix_shape = ps.shape(matrix, out_type=tf.int64)

        max_rank = tf.convert_to_tensor(max_rank,
                                        name='max_rank',
                                        dtype=tf.int64)
        max_rank = tf.minimum(max_rank, matrix_shape[-1])
        diag_rtol = tf.convert_to_tensor(diag_rtol,
                                         dtype=dtype,
                                         name='diag_rtol')
        matrix_diag = tf.linalg.diag_part(matrix)
        # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
        orig_error = tf.reduce_max(matrix_diag, axis=-1)

        def cond(m, pchol, perm, matrix_diag):
            """Condition for `tf.while_loop` continuation."""
            del pchol
            del perm
            error = tf.linalg.norm(matrix_diag, ord=1, axis=-1)
            max_err = tf.reduce_max(error / orig_error)
            return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

        batch_dims = tensorshape_util.rank(matrix.shape) - 2

        def batch_gather(params, indices, axis=-1):
            return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

        def body(m, pchol, perm, matrix_diag):
            """Body of a single `tf.while_loop` iteration."""
            # Here is roughly a numpy, non-batched version of what's going to happen.
            # (See also Algorithm 1 of Harbrecht et al.)
            # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
            # 2: maxval = matrix_diag[perm][maxi]
            # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
            # 4: row = matrix[perm[m]][perm[m + 1:]]
            # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
            # 6: pivot = np.sqrt(maxval); row /= pivot
            # 7: row = np.concatenate([[[pivot]], row], -1)
            # 8: matrix_diag[perm[m:]] -= row**2
            # 9: pchol[m, perm[m:]] = row

            # Find the maximal position of the (remaining) permuted diagonal.
            # Steps 1, 2 above.
            permuted_diag = batch_gather(matrix_diag, perm[..., m:])
            maxi = tf.argmax(permuted_diag, axis=-1,
                             output_type=tf.int64)[..., tf.newaxis]
            maxval = batch_gather(permuted_diag, maxi)
            maxi = maxi + m
            maxval = maxval[..., 0]
            # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
            perm = _swap_m_with_i(perm, m, maxi)
            # Step 4.
            if callable(getattr(matrix, 'row', None)):
                row = matrix.row(perm[..., m])[..., tf.newaxis, :]
            else:
                row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
            row = batch_gather(row, perm[..., m + 1:])
            # Step 5.
            prev_rows = pchol[..., :m, :]
            prev_rows_perm_m_onward = batch_gather(
                prev_rows, perm[..., m + 1:])
            prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
            row -= tf.reduce_sum(prev_rows_perm_m_onward * prev_rows_pivot_col,
                                 axis=-2)[..., tf.newaxis, :]
            # Step 6.
            pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
            # Step 7.
            row = tf.concat([pivot, row / pivot], axis=-1)
            # TODO(b/130899118): Pad grad fails with int64 paddings.
            # Step 8.
            paddings = tf.concat([
                tf.zeros([ps.rank(pchol) - 1, 2], dtype=tf.int32),
                [[tf.cast(m, tf.int32), 0]]
            ],
                                 axis=0)
            diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
            reverse_perm = _invert_permutation(perm)
            matrix_diag -= batch_gather(diag_update, reverse_perm)
            # Step 9.
            row = tf.pad(row, paddings=paddings)
            # TODO(bjp): Defer the reverse permutation all-at-once at the end?
            row = batch_gather(row, reverse_perm)
            pchol_shape = pchol.shape
            pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                              axis=-2)
            tensorshape_util.set_shape(pchol, pchol_shape)
            return m + 1, pchol, perm, matrix_diag

        m = np.int64(0)
        pchol = tf.zeros(matrix_shape, dtype=matrix.dtype)[..., :max_rank, :]
        perm = tf.broadcast_to(ps.range(matrix_shape[-1]), matrix_shape[:-1])
        _, pchol, _, _ = tf.while_loop(cond=cond,
                                       body=body,
                                       loop_vars=(m, pchol, perm, matrix_diag))
        pchol = tf.linalg.matrix_transpose(pchol)
        tensorshape_util.set_shape(
            pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
        return pchol
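A short usage sketch, assuming this function is exposed as `tfp.math.pivoted_cholesky`: with `max_rank` equal to the full dimension and a tight `diag_rtol`, the factor reproduces the input matrix.

```python
import numpy as np
import tensorflow_probability as tfp

# A small random symmetric positive-definite matrix.
rng = np.random.RandomState(0)
a = rng.randn(5, 5).astype(np.float32)
matrix = a @ a.T + 5. * np.eye(5, dtype=np.float32)

lr = tfp.math.pivoted_cholesky(matrix, max_rank=5, diag_rtol=1e-6)
print(lr.shape)  # (5, 5): one column per pivot, up to max_rank
np.testing.assert_allclose(matrix, lr.numpy() @ lr.numpy().T, atol=1e-3)
```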
Code example #3
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
    """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```

  """
    with tf.name_scope(name or 'lu_reconstruct'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

        assertions = lu_reconstruct_assertions(lower_upper, perm,
                                               validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)

        shape = tf.shape(lower_upper)

        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(shape[:-1], dtype=lower_upper.dtype))
        upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
        x = tf.matmul(lower, upper)

        if (tensorshape_util.rank(lower_upper.shape) is None
                or tensorshape_util.rank(lower_upper.shape) != 2):
            # We either don't know the batch rank or there are >0 batch dims.
            batch_size = tf.reduce_prod(shape[:-2])
            d = shape[-1]
            x = tf.reshape(x, [batch_size, d, d])
            perm = tf.reshape(perm, [batch_size, d])
            perm = tf.map_fn(tf.math.invert_permutation, perm)
            batch_indices = tf.broadcast_to(
                tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
            x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
            x = tf.reshape(x, shape)
        else:
            x = tf.gather(x, tf.math.invert_permutation(perm))

        tensorshape_util.set_shape(x, lower_upper.shape)
        return x
Code example #4
 def _entropy(self):
   scale = tf.convert_to_tensor(self.scale)
   return tf.broadcast_to(2. + tf.math.log(scale),
                          self._batch_shape_tensor(scale=scale))
Code example #5
 def _stddev(self):
   scale = tf.convert_to_tensor(self.scale)
   return tf.broadcast_to(
       scale * tf.constant(np.pi / np.sqrt(3), dtype=scale.dtype),
       self._batch_shape_tensor(scale=scale))
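The two closed forms above, `2 + log(scale)` for the entropy and `scale * pi / sqrt(3)` for the standard deviation, match the Logistic distribution, which these methods presumably belong to. A quick numerical cross-check against SciPy:

```python
import numpy as np
from scipy import stats

scale = 2.0
dist = stats.logistic(scale=scale)
assert np.isclose(dist.entropy(), 2. + np.log(scale))
assert np.isclose(dist.std(), scale * np.pi / np.sqrt(3.))
```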
Code example #6
 def _bcast_x(self, x):
     shp = self.bijector.inverse_event_shape(self._single_event_shape())
     if not tensorshape_util.is_fully_defined(shp):
         shp = self.bijector.inverse_event_shape_tensor(
             self._single_event_shape_tensor())
     return tf.broadcast_to(x, ps.broadcast_shape(ps.shape(x), shp))
Code example #7
 def _mode(self):
     return tf.broadcast_to(tf.convert_to_tensor(self.peak),
                            self._batch_shape_tensor())
Code example #8
def _random_gamma_rejection(shape,
                            concentration,
                            rate=None,
                            log_rate=None,
                            seed=None,
                            log_space=False,
                            internal_dtype=None):
    """Samples from the gamma distribution.

  The sampling algorithm is rejection sampling [1], and pathwise gradients with
  respect to concentration are computed via implicit differentiation [2].

  Args:
    shape: The output sample shape. Trailing dims must match broadcast of
      `concentration` with `rate` or `log_rate`.
    concentration: Floating point tensor, the concentration params of the
      distribution(s). Must contain only positive values. Must broadcast with
      `rate` or `log_rate`, if given.
    rate: Floating point tensor, the inverse scale params of the
      distribution(s). Must contain only positive values. Must broadcast with
      `concentration`. If `None`, handled as if 1 (but possibly more
      efficiently). Mutually exclusive with `log_rate`.
    log_rate: Floating point tensor, log of the inverse scale params of the
      distribution(s). Must broadcast with `concentration`. If `None`, handled
      as if 0 (but possibly more efficiently). Mutually exclusive with `rate`.
    seed: (optional) The random seed.
    log_space: Optionally sample log(gamma) variates.
    internal_dtype: dtype to use for internal computations. If unspecified, we
      use the same dtype as the output (i.e. the dtype of `concentration`,
      `rate`, or `log_rate`) when `log_space==True`, and `tf.float64` otherwise.

  Returns:
    Differentiable samples from the gamma distribution.

  #### References

  [1] George Marsaglia and Wai Wan Tsang. A simple method for generating Gamma
      variables. ACM Transactions on Mathematical Software, 2000.

  [2] Michael Figurnov, Shakir Mohamed, and Andriy Mnih. Implicit
      Reparameterization Gradients. Neural Information Processing Systems, 2018.
  """
    generate_and_test_samples_seed, concentration_fix_seed = samplers.split_seed(
        seed, salt='random_gamma')
    output_dtype = dtype_util.common_dtype([concentration, rate, log_rate],
                                           dtype_hint=tf.float32)
    if internal_dtype is None:
        internal_dtype = output_dtype if log_space else tf.float64

    def rejection_sample(concentration):
        """Gamma rejection sampler."""
        # Note, concentration here already has a shape that is broadcast with rate.
        cast_concentration = tf.cast(concentration, internal_dtype)

        good_params_mask = (concentration > 0.)
        # When replacing NaN values, use 100. for concentration, since that leads to
        # a high likelihood of the rejection sampler accepting on the first pass.
        safe_concentration = tf.where(good_params_mask, cast_concentration,
                                      100.)

        modified_safe_concentration = tf.where(safe_concentration < 1.,
                                               safe_concentration + 1.,
                                               safe_concentration)

        one_third = tf.constant(1. / 3, dtype=internal_dtype)
        d = modified_safe_concentration - one_third
        c = one_third * tf.math.rsqrt(d)

        def generate_and_test_samples(seed):
            """Generate and test samples."""
            v_seed, u_seed = samplers.split_seed(seed)

            x = samplers.normal(shape, dtype=internal_dtype, seed=v_seed)
            # This implicitly broadcasts concentration up to sample shape.
            v = 1 + c * x
            # In [1], there is an 'inner' rejection sampling loop which checks that
            # v > 0 and generates a new normal sample if it's not, saving the rest of
            # the computations below. We found that merging the check for v > 0 with
            # the `good_sample_mask` not only simplifies the code, but leads to a
            # ~2x speedup for small concentrations on GPU, at the cost of deviating
            # slightly from the implementation given in Ref. [1].
            accept_v = v > 0.
            logv = tf.math.log1p(c * x)
            x2 = x * x
            v3 = v * v * v
            logv3 = logv * 3

            u = samplers.uniform(shape, dtype=internal_dtype, seed=u_seed)

            # In [1], the suggestion is to first check u < 1 - 0.331 * x2 * x2, and to
            # run the check below only if it fails, in order to avoid the relatively
            # expensive logarithm calls. Our algorithm operates in batch mode: we will
            # have to compute or not compute the logarithms for the entire batch, and
            # as the batch gets larger, the odds we compute it grow. Therefore we
            # don't bother with the "cheap" check.
            good_sample_mask = tf.logical_and(
                tf.math.log(u) < (x2 / 2. + d * (1 - v3 + logv3)), accept_v)

            return logv3 if log_space else v3, good_sample_mask

        samples = brs.batched_las_vegas_algorithm(
            generate_and_test_samples, seed=generate_and_test_samples_seed)[0]

        concentration_fix_unif = samplers.uniform(  # in [0, 1)
            shape,
            dtype=internal_dtype,
            seed=concentration_fix_seed)

        if log_space:
            concentration_lt_one_fix = tf.where(
                safe_concentration < 1.,
                # Why do we use log1p(-x)? x is in [0, 1), and log(0) = -inf is bad.
                # x ~ U(0,1) => 1-x ~ U(0,1)
                # But at the boundary, 1-x in (0, 1]. Good.
                # So we can take log(unif(0,1)) safely as log(1-unif(0,1)).
                # log1p(-0) = 0, and log1p(-almost_one) = -not_quite_inf. Good.
                tf.math.log1p(-concentration_fix_unif) / safe_concentration,
                tf.zeros((), dtype=internal_dtype))
            samples = samples + tf.math.log(d) + concentration_lt_one_fix
        else:
            concentration_lt_one_fix = tf.where(
                safe_concentration < 1.,
                tf.math.pow(concentration_fix_unif,
                            tf.math.reciprocal(safe_concentration)),
                tf.ones((), dtype=internal_dtype))
            samples = samples * d * concentration_lt_one_fix

        samples = tf.where(good_params_mask, samples, np.nan)
        output_type_samples = tf.cast(samples, output_dtype)

        return output_type_samples

    broadcast_conc_shape = ps.broadcast_shape(ps.shape(concentration),
                                              _shape_or_scalar(rate, log_rate))
    broadcast_concentration = tf.broadcast_to(concentration,
                                              broadcast_conc_shape)
    concentration_samples = rejection_sample(broadcast_concentration)

    if rate is not None and log_rate is not None:
        raise ValueError('`rate` and `log_rate` are mutually exclusive.')

    if rate is None and log_rate is None:
        if not log_space:
            concentration_samples = _fix_zero_samples(concentration_samples)
        return concentration_samples

    if log_space:
        if log_rate is None:
            log_rate = tf.math.log(tf.where(rate >= 0., rate, np.nan))
        return concentration_samples - log_rate
    else:
        if rate is None:
            rate = tf.math.exp(log_rate)
        corrected_rate = tf.where(rate >= 0., rate, np.nan)
        return _fix_zero_samples(concentration_samples / corrected_rate)
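For reference, here is a minimal non-batched NumPy version of the Marsaglia-Tsang accept/reject step implemented above. It covers only `concentration >= 1` with unit rate; the `concentration < 1` boost and the rate/log-rate handling are omitted, and the function name and defaults are illustrative.

```python
import numpy as np

def marsaglia_tsang_gamma(concentration, num_samples, seed=0):
    """Scalar Marsaglia-Tsang gamma sampler for concentration >= 1, rate == 1."""
    rng = np.random.default_rng(seed)
    d = concentration - 1. / 3.
    c = 1. / np.sqrt(9. * d)
    samples = np.empty(num_samples)
    for i in range(num_samples):
        while True:
            x = rng.standard_normal()
            v = (1. + c * x) ** 3
            if v <= 0.:  # the 'inner' rejection folded into `good_sample_mask` above
                continue
            u = rng.uniform()
            # Same acceptance test as `good_sample_mask` above.
            if np.log(u) < 0.5 * x * x + d * (1. - v + np.log(v)):
                samples[i] = d * v
                break
    return samples

print(marsaglia_tsang_gamma(3., num_samples=100_000).mean())  # close to 3.0
```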
Code example #9
def broadcast_batch_shape(x, batch_shape):
    """Broadcasts batch shape of `x`."""
    return tf.broadcast_to(x, tf.TensorShape(batch_shape) + x.shape[-1])
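A small illustration of the intended behavior, assuming `x` carries exactly one trailing event dimension as the `x.shape[-1]` above implies:

```python
import tensorflow as tf

x = tf.zeros([4])                                 # one event dimension
y = broadcast_batch_shape(x, batch_shape=[2, 3])  # prepends the batch shape
print(y.shape)                                    # expected: (2, 3, 4)
```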
Code example #10
 def scalar_broadcast_to(self, x, shape):
   return tf.broadcast_to(x, shape)
Code example #11
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):

                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x,
                                                  kernel=kernel,
                                                  filter_shape=filter_shape,
                                                  strides=(strides, ) * rank,
                                                  padding=padding,
                                                  dilations=dilations,
                                                  c_out=c_out,
                                                  batch_shape=batch_shape,
                                                  event_shape=event_shape)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                x_pad_shape = ps.shape(x_pad)[:-3]
                flat_shape = ps.pad(x_pad_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.reshape(x_pad, shape=flat_shape)

                idx, s = im2row_index(
                    (xh + tf.reduce_sum(padding_vals[0]),
                     xw + tf.reduce_sum(padding_vals[1]), c_in),
                    block_shape=(sub_fh, sub_fw),
                    slice_step=(1, 1),
                    dilations=dilations)

                x_ = tf.gather(flat_x, indices=idx, axis=-1)
                im_x = tf.reshape(x_,
                                  shape=ps.concat([x_pad_shape, s], axis=0))

                # Add channels to subkernel indices
                idx_event = event_ind * [[c_in, 1]]
                idx_event_channels = (idx_event[tf.newaxis] + tf.stack(
                    [ps.range(c_in),
                     tf.zeros(
                         (c_in, ), dtype=dtype)], axis=-1)[:, tf.newaxis, :])
                idx_event = tf.squeeze(tf.batch_to_space(idx_event_channels,
                                                         block_shape=[c_in],
                                                         crops=[[0, 0]]),
                                       axis=0)
                idx_event_broadcast = tf.broadcast_to(
                    idx_event,
                    shape=ps.concat(
                        [kernel_batch, ps.shape(idx_event)], axis=0))

                # Add cartesian product of batch indices, since scatter_nd can only be
                # applied to leading dimensions.
                idx_batch = tf.stack(tf.meshgrid(*[
                    ps.range(b_, delta=1, dtype=dtype)
                    for b_ in tf.unstack(kernel_batch)
                ],
                                                 indexing='ij'),
                                     axis=ps.size(kernel_batch))

                idx_batch = tf.cast(idx_batch,
                                    dtype=dtype)  # empty tensor is float

                idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
                    (ps.shape(idx_event)[0], 1), dtype=dtype)
                idx_kernel = tf.concat(
                    [idx_batch_broadcast, idx_event_broadcast], axis=-1)

                kernel_mat = tf.scatter_nd(
                    idx_kernel,
                    updates=kernel,
                    shape=ps.cast(ps.concat([
                        kernel_batch,
                        [sub_fh * sub_fw * c_in, strides**2, c_out]
                    ],
                                            axis=0),
                                  dtype=dtype))

                kernel_mat = tf.reshape(
                    kernel_mat,
                    shape=ps.concat(
                        [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]],
                        axis=0))

                kernel_mat = kernel_mat[..., tf.newaxis, :, :]
                out = tf.matmul(im_x, kernel_mat)
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)

                if strides > 1:
                    tot_size = tf.reduce_prod(broadcast_batch_shape)
                    flat_out = tf.reshape(out,
                                          shape=ps.concat([[tot_size],
                                                           ps.shape(out)[-3:]],
                                                          axis=0))
                    out = tf.nn.depth_to_space(flat_out, block_size=strides)

                out_height = _deconv_output_length(xh,
                                                   filter_size=fh,
                                                   padding=padding,
                                                   output_padding=None,
                                                   stride=strides,
                                                   dilation=dh)
                out_width = _deconv_output_length(xw,
                                                  filter_size=fw,
                                                  padding=padding,
                                                  output_padding=None,
                                                  stride=strides,
                                                  dilation=dw)

                out = out[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
                out = tf.reshape(
                    out,
                    shape=ps.concat([
                        broadcast_batch_shape, [out_height, out_width, c_out]
                    ],
                                    axis=0))
                return out
Code example #12
    def _compute_shared(self, x=None, y=None):
        """Captures shared computations across forward/inverse/logdet.

    Only one of `x` or `y` should be specified.

    Args:
      x: The `x` values we will search for.
      y: The `y` values we will search for.

    Returns:
      data: A namedtuple with named fields containing shared computations.
    """
        assert (x is None) != (y is None)
        is_x = x is not None

        range_min = tf.convert_to_tensor(self.range_min, name='range_min')
        kx = _knot_positions(self.bin_widths, range_min)
        ky = _knot_positions(self.bin_heights, range_min)
        kd = _padded(_ensure_at_least_1d(self.knot_slopes), lhs=1, rhs=1)
        kx_or_ky = kx if is_x else ky
        kx_or_ky_min = kx_or_ky[..., 0]
        kx_or_ky_max = kx_or_ky[..., -1]
        x_or_y = x if is_x else y
        out_of_bounds = (x_or_y <= kx_or_ky_min) | (x_or_y >= kx_or_ky_max)
        x_or_y = tf.where(out_of_bounds, kx_or_ky_min, x_or_y)

        shape = functools.reduce(
            tf.broadcast_dynamic_shape,
            (
                tf.shape(x_or_y[..., tf.newaxis]),  # Add a n_knots dim.
                tf.shape(kx),
                tf.shape(ky),
                tf.shape(kd)))

        bc_x_or_y = tf.broadcast_to(x_or_y, shape[:-1])
        bc_kx = tf.broadcast_to(kx, shape)
        bc_ky = tf.broadcast_to(ky, shape)
        bc_kd = tf.broadcast_to(kd, shape)
        bc_kx_or_ky = bc_kx if is_x else bc_ky
        indices = tf.maximum(
            tf.zeros([], dtype=tf.int64),
            tf.searchsorted(bc_kx_or_ky[..., :-1],
                            bc_x_or_y[..., tf.newaxis],
                            side='right',
                            out_type=tf.int64) - 1)

        def gather_squeeze(params, indices):
            rank = tensorshape_util.rank(indices.shape)
            if rank is None:
                raise ValueError('`indices` must have statically known rank.')
            return tf.gather(params, indices, axis=-1,
                             batch_dims=rank - 1)[..., 0]

        x_k = gather_squeeze(bc_kx, indices)
        x_kp1 = gather_squeeze(bc_kx, indices + 1)
        y_k = gather_squeeze(bc_ky, indices)
        y_kp1 = gather_squeeze(bc_ky, indices + 1)
        d_k = gather_squeeze(bc_kd, indices)
        d_kp1 = gather_squeeze(bc_kd, indices + 1)
        h_k = y_kp1 - y_k
        w_k = x_kp1 - x_k
        s_k = h_k / w_k

        return _SplineShared(out_of_bounds=out_of_bounds,
                             x_k=x_k,
                             y_k=y_k,
                             d_k=d_k,
                             d_kp1=d_kp1,
                             h_k=h_k,
                             w_k=w_k,
                             s_k=s_k)
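The bin lookup above boils down to a `searchsorted` over the knot positions followed by a batched gather. A non-batched NumPy sketch of that step, with hypothetical knot values:

```python
import numpy as np

kx = np.array([-1., -0.5, 0.1, 0.6, 1.])  # knot x-positions (4 bins)
x = np.array([-0.9, 0.0, 0.7])            # points to transform

# Index of the bin [kx[k], kx[k + 1]] containing each x, mirroring the
# tf.searchsorted(..., side='right') - 1 logic above.
indices = np.maximum(0, np.searchsorted(kx[:-1], x, side='right') - 1)
x_k, x_kp1 = kx[indices], kx[indices + 1]
w_k = x_kp1 - x_k                         # bin widths fed into the spline
print(indices)                            # [0 1 3]
```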
Code example #13
    def forward_model(self, sample):
        """The forward model.

    Args:
      sample: A sample from the model.

    Returns:
      mesh_emissivity: Float `Tensor` with shape [num_wavelengths, num_sensors,
        num_integration_points]. The spatial spectral emissivity. The last two
        dimensions describe locations perpendicular and parallel to the
        spectrometer lines of sight, respectively. Those dimensions span
        [-sensor_span, sensor_span] and [-outer_shell_radius,
        outer_shell_radius] respectively.
      mean_measurement: Float `Tensor` with shape [num_wavelengths,
        num_sensors]. The measurement means.
    """
        wavelengths = tf.convert_to_tensor(self.wavelengths, tf.float32)
        center_wavelength = tf.convert_to_tensor(self.center_wavelength,
                                                 tf.float32)

        shell_radii = tf.linspace(0., self.outer_shell_radius, self.num_shells)

        kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(
            length_scale=self.prior_length_scale)
        prior_cov = kernel.matrix(shell_radii[..., tf.newaxis],
                                  shell_radii[..., tf.newaxis])
        prior_cov = (prior_cov + self.prior_diag_noise_variance * tf.eye(
            int(prior_cov.shape[-1]))) / (1 + self.prior_diag_noise_variance)
        prior_scale = tf.linalg.cholesky(prior_cov)

        amplitude = tf.linalg.matvec(prior_scale, sample.amplitude)
        temperature = tf.linalg.matvec(prior_scale, sample.temperature)
        velocity = tf.linalg.matvec(prior_scale, sample.velocity)

        # [1, num_shells]
        amplitude = (self.amplitude_scale *
                     tf.nn.softplus(amplitude))[..., tf.newaxis, :]
        # [1, num_shells]
        temperature = (self.temperature_scale *
                       tf.nn.softplus(temperature))[..., tf.newaxis, :]
        # [1, num_shells]
        velocity = (self.velocity_scale * velocity)[..., tf.newaxis, :]
        shift = sample.shift[..., tf.newaxis]

        doppler_shifted_center_wavelength = center_wavelength * (1 - velocity)
        bandwidth = center_wavelength * tf.sqrt(temperature)

        # [num_wavelengths, num_shells]
        emissivity = amplitude / (tf.constant(np.sqrt(
            2 * np.pi), bandwidth.dtype) * bandwidth) * tf.exp(
                -(wavelengths[:, tf.newaxis] -
                  doppler_shifted_center_wavelength)**2 / (2 * bandwidth**2))

        if self.use_bump_function:
            emissivity *= tfp.math.round_exponential_bump_function(
                tf.linspace(-1., 1., self.num_shells))

        x = tf.linspace(-self.outer_shell_radius, self.outer_shell_radius,
                        self.num_integration_points)
        y = tf.linspace(-self.sensor_span, self.sensor_span, self.num_sensors)

        mesh_x, mesh_y = tf.meshgrid(x, y)
        # [num_sensors, num_integration_points]
        mesh_y = -shift[..., tf.newaxis] + mesh_y
        mesh_x = tf.broadcast_to(mesh_x, mesh_y.shape)
        mesh_r = tf.linalg.norm(tf.stack([mesh_x, mesh_y], -1), axis=-1)

        # [num_wavelengths, num_sensors, num_integration_points]
        mesh_emissivity = tfp.math.batch_interp_regular_1d_grid(
            mesh_r[..., tf.newaxis, :, :],
            0.,
            self.outer_shell_radius,
            emissivity[..., :, tf.newaxis, :],
            fill_value=0.)

        # [num_wavelengths, num_sensors]
        mean_measurement = tfp.math.trapz(
            mesh_emissivity,
            tf.broadcast_to(mesh_x[..., tf.newaxis, :, :],
                            mesh_emissivity.shape))
        return mesh_emissivity, mean_measurement
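The last two calls follow an interpolate-then-integrate pattern: resample a radial profile along each line of sight, then integrate with the trapezoidal rule. A toy self-contained version of that pattern (the profile and grids are illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Radial profile f(r) = exp(-r**2), tabulated on a regular grid over [0, 3].
r_grid = tf.linspace(0., 3., 100)
profile = tf.exp(-r_grid**2)

# Integrate along the chord y == 0, where r == |x|.
x = tf.linspace(-3., 3., 401)
integrand = tfp.math.batch_interp_regular_1d_grid(
    tf.abs(x), 0., 3., profile, fill_value=0.)
integral = tfp.math.trapz(integrand, x)
print(float(integral))  # ~1.772, roughly sqrt(pi)
```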
Code example #14
    def _setup(self, coupon_spec):
        """Setup tensors for efficient computations."""

        if isinstance(coupon_spec, list):
            cpn_frequency = dates.periods.PeriodTensor.stack(
                [x.coupon_frequency for x in coupon_spec], axis=0)
            businessday_rule = coupon_spec[-1].businessday_rule
            ref_term = dates.periods.PeriodTensor.stack(
                [x.reference_rate_term for x in coupon_spec], axis=0)
            daycount_convention = coupon_spec[-1].daycount_convention
            notional = tf.convert_to_tensor([x.notional for x in coupon_spec],
                                            dtype=self._dtype)
            coupon_basis = tf.convert_to_tensor(
                [x.coupon_basis for x in coupon_spec], dtype=self._dtype)
            coupon_multiplier = tf.convert_to_tensor(
                [x.coupon_multiplier for x in coupon_spec], dtype=self._dtype)
        else:
            cpn_frequency = coupon_spec.coupon_frequency
            businessday_rule = coupon_spec.businessday_rule
            ref_term = coupon_spec.reference_rate_term
            daycount_convention = coupon_spec.daycount_convention
            notional = tf.broadcast_to(
                tf.convert_to_tensor(coupon_spec.notional, dtype=self._dtype),
                self._start_date.shape)
            coupon_basis = tf.broadcast_to(
                tf.convert_to_tensor(coupon_spec.coupon_basis,
                                     dtype=self._dtype),
                self._start_date.shape)
            coupon_multiplier = tf.broadcast_to(
                tf.convert_to_tensor(coupon_spec.coupon_multiplier,
                                     dtype=self._dtype),
                self._start_date.shape)

        cpn_dates = self._generate_schedule(cpn_frequency, businessday_rule)
        accrual_start_dates = cpn_dates[:, :-1]

        accrual_end_dates = cpn_dates[:, :-1] + dates.periods.PeriodTensor.expand_dims(
            ref_term, axis=-1).broadcast_to(accrual_start_dates.shape)
        coupon_start_dates = cpn_dates[:, :-1]
        coupon_end_dates = cpn_dates[:, 1:]
        payment_dates = cpn_dates[:, 1:]

        daycount_fractions = rc.get_daycount_fraction(cpn_dates[:, :-1],
                                                      cpn_dates[:, 1:],
                                                      daycount_convention,
                                                      dtype=self._dtype)

        notional = tf.repeat(notional, payment_dates.shape.as_list()[-1])
        coupon_basis = tf.repeat(coupon_basis,
                                 payment_dates.shape.as_list()[-1])
        coupon_multiplier = tf.repeat(coupon_multiplier,
                                      payment_dates.shape.as_list()[-1])

        contract_index = tf.repeat(tf.range(0, self._batch_size),
                                   payment_dates.shape.as_list()[-1])

        self._num_cashflows = daycount_fractions.shape.as_list()[-1]
        self._coupon_start_dates = coupon_start_dates.reshape([-1])
        self._coupon_end_dates = coupon_end_dates.reshape([-1])
        self._payment_dates = payment_dates.reshape([-1])
        self._accrual_start_date = accrual_start_dates.reshape([-1])
        self._accrual_end_date = accrual_end_dates.reshape([-1])
        self._notional = notional
        self._daycount_fractions = tf.reshape(daycount_fractions, [-1])
        self._coupon_basis = coupon_basis
        self._coupon_multiplier = coupon_multiplier
        self._contract_index = contract_index
Code example #15
def left_justified_broadcast_to(x, shape, name=None):
    """Broadcasts `x` to shape, in a left-justified manner."""
    with tf.name_scope(name or 'left_justified_broadcast_to'):
        return tf.broadcast_to(
            left_justified_expand_dims_to(x, prefer_static.size(shape)), shape)
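'Left-justified' here means the shape of `x` is aligned against the leading axes of `shape` rather than the trailing ones. In plain TF ops the effect is roughly:

```python
import tensorflow as tf

x = tf.constant([1., 2., 3.])            # shape [3]

# Right-justified broadcasting would try to match [3] against the last axis
# of [3, 2] and fail. Left-justified broadcasting pads x on the right first:
x_expanded = tf.reshape(x, [3, 1])       # roughly what the expand-dims helper does
y = tf.broadcast_to(x_expanded, [3, 2])  # [[1., 1.], [2., 2.], [3., 3.]]
```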
Code example #16
    def testAutoVectorization(self, bijector_name, data):

        # TODO(b/150161911): reconcile numeric behavior of eager and graph mode.
        if tf.executing_eagerly():
            return

        bijector, event_dim = self._draw_bijector(
            bijector_name,
            data,
            batch_shape=[],  # Avoid conflict with vmap sample dimension.
            validate_args=False,  # Work around lack of `If` support in vmap.
            allowed_bijectors=(set(TF2_FRIENDLY_BIJECTORS) -
                               set(AUTOVECTORIZATION_IS_BROKEN)))
        atol = AUTOVECTORIZATION_ATOL[bijector_name]
        rtol = AUTOVECTORIZATION_RTOL[bijector_name]

        # Forward
        n = 3
        xs = self._draw_domain_tensor(bijector,
                                      data,
                                      event_dim,
                                      sample_shape=[n])
        ys = bijector.forward(xs)
        vectorized_ys = tf.vectorized_map(bijector.forward,
                                          xs,
                                          fallback_to_while_loop=False)
        self.assertAllClose(*self.evaluate((ys, vectorized_ys)),
                            atol=atol,
                            rtol=rtol)

        # FLDJ
        event_ndims = data.draw(
            hps.integers(min_value=bijector.forward_min_event_ndims,
                         max_value=ps.rank_from_shape(xs.shape) - 1))
        fldj_fn = functools.partial(bijector.forward_log_det_jacobian,
                                    event_ndims=event_ndims)
        vectorized_fldj = tf.vectorized_map(fldj_fn,
                                            xs,
                                            fallback_to_while_loop=False)
        fldj = tf.broadcast_to(fldj_fn(xs), tf.shape(vectorized_fldj))
        self.assertAllClose(*self.evaluate((fldj, vectorized_fldj)),
                            atol=atol,
                            rtol=rtol)

        # Inverse
        ys = self._draw_codomain_tensor(bijector,
                                        data,
                                        event_dim,
                                        sample_shape=[n])
        xs = bijector.inverse(ys)
        vectorized_xs = tf.vectorized_map(bijector.inverse,
                                          ys,
                                          fallback_to_while_loop=False)
        self.assertAllClose(*self.evaluate((xs, vectorized_xs)),
                            atol=atol,
                            rtol=rtol)

        # ILDJ
        event_ndims = data.draw(
            hps.integers(min_value=bijector.inverse_min_event_ndims,
                         max_value=ps.rank_from_shape(ys.shape) - 1))
        ildj_fn = functools.partial(bijector.inverse_log_det_jacobian,
                                    event_ndims=event_ndims)
        vectorized_ildj = tf.vectorized_map(ildj_fn,
                                            ys,
                                            fallback_to_while_loop=False)
        ildj = tf.broadcast_to(ildj_fn(ys), tf.shape(vectorized_ildj))
        self.assertAllClose(*self.evaluate((ildj, vectorized_ildj)),
                            atol=atol,
                            rtol=rtol)
Code example #17
    def state_y(self,
                t: types.RealTensor,
                name: str = None) -> types.RealTensor:
        """Computes the state variable `y(t)` for tha Gaussian HJM Model.

    For the Gaussian HJM model, the state variable `y(t)` can be computed
    analytically as follows:

    y_ij(t) = exp(-k_i * t) * exp(-k_j * t) * (
        int_0^t rho_ij * exp(k_i * u) * sigma_i(u) * exp(k_j * u) * sigma_j(u) du)

    Args:
      t: A rank 1 real `Tensor` of shape `[num_times]` specifying the time `t`.
      name: Python string. The name to give to the ops created by this function.
        Default value: `None` which maps to the default name `state_y`.

    Returns:
      A real `Tensor` of shape [self._factors, self._factors, num_times]
      containing the computed y_ij(t).
    """
        name = name or 'state_y'
        with tf.name_scope(name):
            t = tf.convert_to_tensor(t, dtype=self._dtype)
            t_shape = tf.shape(t)
            t = tf.broadcast_to(t, tf.concat([[self._dim], t_shape], axis=0))
            time_index = tf.searchsorted(self._jump_locations, t)
            # create a matrix k2(i,j) = k(i) + k(j)
            mr2 = tf.expand_dims(self._mean_reversion, axis=-1)
            # Add a dimension corresponding to `num_times`
            mr2 = tf.expand_dims(mr2 + tf.transpose(mr2), axis=-1)

            def _integrate_volatility_squared(vol, l_limit, u_limit):
                # create sigma2_ij = sigma_i * sigma_j
                vol = tf.expand_dims(vol, axis=-2)
                vol_squared = tf.expand_dims(self._rho, axis=-1) * (
                    vol * tf.transpose(vol, perm=[1, 0, 2]))
                return vol_squared / mr2 * (tf.math.exp(mr2 * u_limit) -
                                            tf.math.exp(mr2 * l_limit))

            is_constant_vol = tf.math.equal(
                tf.shape(self._jump_values_vol)[-1], 0)
            v_squared_between_vol_knots = tf.cond(
                is_constant_vol,
                lambda: tf.zeros(shape=(self._dim, self._dim, 0),
                                 dtype=self._dtype),
                lambda: _integrate_volatility_squared(  # pylint: disable=g-long-lambda
                    self._jump_values_vol, self._padded_knots,
                    self._jump_locations))
            v_squared_at_vol_knots = tf.concat([
                tf.zeros((self._dim, self._dim, 1), dtype=self._dtype),
                utils.cumsum_using_matvec(v_squared_between_vol_knots)
            ],
                                               axis=-1)

            vn = tf.concat([self._zero_padding, self._jump_locations], axis=1)

            v_squared_t = _integrate_volatility_squared(
                self._volatility(t), tf.gather(vn, time_index, batch_dims=1),
                t)
            v_squared_t += tf.gather(v_squared_at_vol_knots,
                                     time_index,
                                     batch_dims=-1)

            return tf.math.exp(-mr2 * t) * v_squared_t
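For constant parameters the integral above has a simple closed form, which makes a handy sanity check on `state_y`: with constant `sigma_i`, `y_ij(t) = rho_ij * sigma_i * sigma_j * (1 - exp(-(k_i + k_j) * t)) / (k_i + k_j)`. A NumPy evaluation with hypothetical parameters:

```python
import numpy as np

k = np.array([0.03, 0.1])         # mean reversions k_i
sigma = np.array([0.01, 0.015])   # constant volatilities sigma_i
rho = np.array([[1., 0.5],
                [0.5, 1.]])
t = 2.0

mr2 = k[:, None] + k[None, :]     # k_i + k_j, as in `mr2` above
y = rho * np.outer(sigma, sigma) * (1. - np.exp(-mr2 * t)) / mr2
print(y)                          # symmetric 2x2 matrix y_ij(t)
```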
Code example #18
    def posterior_marginals(self, observations, name=None):
        """Compute marginal posterior distribution for each state.

    This function computes, for each time step, the marginal
    conditional probability that the hidden Markov model was in
    each possible state given the observations that were made
    at each time step.
    So if the hidden states are `z[0],...,z[num_steps - 1]` and
    the observations are `x[0], ..., x[num_steps - 1]`, then
    this function computes `P(z[i] | x[0], ..., x[num_steps - 1])`
    for all `i` from `0` to `num_steps - 1`.

    This operation is sometimes called smoothing. It uses a form
    of the forward-backward algorithm.

    Note: the behavior of this function is undefined if the
    `observations` argument represents impossible observations
    from the model.

    Args:
      observations: A tensor representing a batch of observations
        made on the hidden Markov model.  The rightmost dimension of this tensor
        gives the steps in a sequence of observations from a single sample from
        the hidden Markov model. The size of this dimension should match the
        `num_steps` parameter of the hidden Markov model object. The other
        dimensions are the dimensions of the batch and these are broadcast with
        the hidden Markov model's parameters.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Returns:
      posterior_marginal: A `Categorical` distribution object representing the
        marginal probability of the hidden Markov model being in each state at
        each step. The rightmost dimension of the `Categorical` distributions
        batch will equal the `num_steps` parameter providing one marginal
        distribution for each step. The other dimensions are the dimensions
        corresponding to the batch of observations.

    Raises:
      ValueError: if rightmost dimension of `observations` does not
      have size `num_steps`.
    """

        with tf.name_scope(name or "posterior_marginals"):
            with tf.control_dependencies(self._runtime_assertions):
                observation_tensor_shape = tf.shape(input=observations)

                with self._observation_shape_preconditions(
                        observation_tensor_shape):
                    observation_batch_shape = observation_tensor_shape[
                        :-1 - self._underlying_event_rank]
                    observation_event_shape = observation_tensor_shape[
                        -1 - self._underlying_event_rank:]

                    batch_shape = tf.broadcast_dynamic_shape(
                        observation_batch_shape, self.batch_shape_tensor())
                    log_init = tf.broadcast_to(
                        self._log_init,
                        tf.concat([batch_shape, [self._num_states]], axis=0))
                    log_transition = self._log_trans

                    observations = tf.broadcast_to(
                        observations,
                        tf.concat([batch_shape, observation_event_shape],
                                  axis=0))
                    observation_rank = tf.rank(observations)
                    underlying_event_rank = self._underlying_event_rank
                    observations = distribution_util.move_dimension(
                        observations,
                        observation_rank - underlying_event_rank - 1, 0)
                    observations = tf.expand_dims(
                        observations, observation_rank - underlying_event_rank)
                    observation_log_probs = self._observation_distribution.log_prob(
                        observations)

                    log_adjoint_prob = tf.zeros_like(log_init)

                    def forward_step(log_previous_step, log_prob_observation):
                        return _log_vector_matrix(
                            log_previous_step,
                            log_transition) + log_prob_observation

                    log_prob = log_init + observation_log_probs[0]

                    forward_log_probs = tf.scan(forward_step,
                                                observation_log_probs[1:],
                                                initializer=log_prob,
                                                name="forward_log_probs")

                    forward_log_probs = tf.concat(
                        [[log_prob], forward_log_probs], axis=0)

                    def backward_step(log_previous_step, log_prob_observation):
                        return _log_matrix_vector(
                            log_transition,
                            log_prob_observation + log_previous_step)

                    backward_log_adjoint_probs = tf.scan(
                        backward_step,
                        observation_log_probs[1:],
                        initializer=log_adjoint_prob,
                        reverse=True,
                        name="backward_log_adjoint_probs")

                    total_log_prob = tf.reduce_logsumexp(
                        input_tensor=forward_log_probs[-1], axis=-1)

                    backward_log_adjoint_probs = tf.concat(
                        [backward_log_adjoint_probs, [log_adjoint_prob]],
                        axis=0)

                    log_likelihoods = forward_log_probs + backward_log_adjoint_probs

                    marginal_log_probs = distribution_util.move_dimension(
                        log_likelihoods - total_log_prob[..., tf.newaxis], 0,
                        -2)

                    return categorical.Categorical(logits=marginal_log_probs)
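The two `tf.scan` passes above are the standard log-space forward-backward recursions. A minimal non-batched NumPy version of the same computation (the helper name is hypothetical):

```python
import numpy as np
from scipy.special import logsumexp

def forward_backward_marginals(log_init, log_trans, obs_log_probs):
    """log_init: [S]; log_trans: [S, S], rows index the previous state;
    obs_log_probs: [T, S]. Returns [T, S] log P(z[t] | x[0], ..., x[T-1])."""
    num_steps, num_states = obs_log_probs.shape
    fwd = np.zeros((num_steps, num_states))
    fwd[0] = log_init + obs_log_probs[0]
    for t in range(1, num_steps):  # mirrors forward_step above
        fwd[t] = logsumexp(fwd[t - 1][:, None] + log_trans, axis=0) + obs_log_probs[t]
    bwd = np.zeros((num_steps, num_states))
    for t in range(num_steps - 2, -1, -1):  # mirrors backward_step above, reversed
        bwd[t] = logsumexp(log_trans + (obs_log_probs[t + 1] + bwd[t + 1])[None, :],
                           axis=1)
    total_log_prob = logsumexp(fwd[-1])
    return fwd + bwd - total_log_prob
```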
Code example #19
 def _bcast_y(self, y):
     return tf.broadcast_to(
         y,
         ps.broadcast_shape(ps.shape(y), self._single_event_shape_tensor()))
Code example #20
    def posterior_mode(self, observations, name=None):
        """Compute maximum likelihood sequence of hidden states.

    When this function is provided with a sequence of observations
    `x[0], ..., x[num_steps - 1]`, it returns the sequence of hidden
    states `z[0], ..., z[num_steps - 1]`, drawn from the underlying
    Markov chain, that is most likely to yield those observations.

    It uses the [Viterbi algorithm](
    https://en.wikipedia.org/wiki/Viterbi_algorithm).

    Note: the behavior of this function is undefined if the
    `observations` argument represents impossible observations
    from the model.

    Note: if there isn't a unique most likely sequence then one
    of the equally most likely sequences is chosen.

    Args:
      observations: A tensor representing a batch of observations made on the
        hidden Markov model.  The rightmost dimensions of this tensor correspond
        to the dimensions of the observation distributions of the underlying
        Markov chain.  The next dimension from the right indexes the steps in a
        sequence of observations from a single sample from the hidden Markov
        model.  The size of this dimension should match the `num_steps`
        parameter of the hidden Markov model object.  The other dimensions are
        the dimensions of the batch and these are broadcast with the hidden
        Markov model's parameters.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Returns:
      posterior_mode: A `Tensor` representing the most likely sequence of hidden
        states. The rightmost dimension of this tensor will equal the
        `num_steps` parameter providing one hidden state for each step. The
        other dimensions are those of the batch.

    Raises:
      ValueError: if the `observations` tensor does not consist of
      sequences of `num_steps` observations.

    #### Examples

    ```python
    tfd = tfp.distributions

    # A simple weather model.

    # Represent a cold day with 0 and a hot day with 1.
    # Suppose the first day of a sequence has a 0.8 chance of being cold.

    initial_distribution = tfd.Categorical(probs=[0.8, 0.2])

    # Suppose a cold day has a 30% chance of being followed by a hot day
    # and a hot day has a 20% chance of being followed by a cold day.

    transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                     [0.2, 0.8]])

    # Suppose additionally that on each day the temperature is
    # normally distributed with mean and standard deviation 0 and 5 on
    # a cold day and mean and standard deviation 15 and 10 on a hot day.

    observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

    # This gives the hidden Markov model:

    model = tfd.HiddenMarkovModel(
        initial_distribution=initial_distribution,
        transition_distribution=transition_distribution,
        observation_distribution=observation_distribution,
        num_steps=7)

    # Suppose we observe gradually rising temperatures over a week:
    temps = [-2., 0., 2., 4., 6., 8., 10.]

    # We can now compute the most probable sequence of hidden states:

    model.posterior_mode(temps)

    # The result is [0 0 0 0 0 1 1] telling us that the transition
    # from "cold" to "hot" most likely happened between the
    # 5th and 6th days.
    ```
    """

        with tf.name_scope(name or "posterior_mode"):
            with tf.control_dependencies(self._runtime_assertions):
                observation_tensor_shape = tf.shape(input=observations)

                with self._observation_shape_preconditions(
                        observation_tensor_shape):
                    observation_batch_shape = observation_tensor_shape[
                        :-1 - self._underlying_event_rank]
                    observation_event_shape = observation_tensor_shape[
                        -1 - self._underlying_event_rank:]

                    batch_shape = tf.broadcast_dynamic_shape(
                        observation_batch_shape, self.batch_shape_tensor())
                    log_init = tf.broadcast_to(
                        self._log_init,
                        tf.concat([batch_shape, [self._num_states]], axis=0))

                    observations = tf.broadcast_to(
                        observations,
                        tf.concat([batch_shape, observation_event_shape],
                                  axis=0))
                    observation_rank = tf.rank(observations)
                    underlying_event_rank = self._underlying_event_rank
                    observations = distribution_util.move_dimension(
                        observations,
                        observation_rank - underlying_event_rank - 1, 0)

                    # We need to compute the probability of each observation for
                    # each possible state.
                    # This requires inserting an extra index just before the
                    # observation event indices that will be broadcast with the
                    # last batch index in `observation_distribution`.
                    observations = tf.expand_dims(
                        observations, observation_rank - underlying_event_rank)
                    observation_log_probs = self._observation_distribution.log_prob(
                        observations)

                    log_prob = log_init + observation_log_probs[0]

                    if self._num_steps == 1:
                        most_likely_end = tf.argmax(input=log_prob, axis=-1)
                        return most_likely_end[..., tf.newaxis]

                    def forward_step(previous_step_pair, log_prob_observation):
                        log_prob_previous = previous_step_pair[0]
                        log_prob = (log_prob_previous[..., tf.newaxis] +
                                    self._log_trans +
                                    log_prob_observation[..., tf.newaxis, :])
                        most_likely_given_successor = tf.argmax(input=log_prob,
                                                                axis=-2)
                        max_log_p_given_successor = tf.reduce_max(
                            input_tensor=log_prob, axis=-2)
                        return (max_log_p_given_successor,
                                most_likely_given_successor)

                    forward_log_probs, all_most_likely_given_successor = tf.scan(
                        forward_step,
                        observation_log_probs[1:],
                        initializer=(log_prob,
                                     tf.zeros(tf.shape(input=log_init),
                                              dtype=tf.int64)),
                        name="forward_log_probs")

                    most_likely_end = tf.argmax(input=forward_log_probs[-1],
                                                axis=-1)

                    # We require the operation that gives C from A and B where
                    # C[i...j] = A[i...j, B[i...j]]
                    # and A = most_likely_given_successor
                    #     B = most_likely_successor.
                    # tf.gather requires indices of known shape so instead we use
                    # reduction with tf.one_hot(B) to pick out elements from B
                    def backward_step(most_likely_successor,
                                      most_likely_given_successor):
                        return tf.reduce_sum(
                            input_tensor=(most_likely_given_successor *
                                          tf.one_hot(most_likely_successor,
                                                     self._num_states,
                                                     dtype=tf.int64)),
                            axis=-1)

                    backward_scan = tf.scan(backward_step,
                                            all_most_likely_given_successor,
                                            most_likely_end,
                                            reverse=True)
                    most_likely_sequences = tf.concat(
                        [backward_scan, [most_likely_end]], axis=0)
                    return distribution_util.move_dimension(
                        most_likely_sequences, 0, -1)
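The `backward_step` above relies on a trick worth seeing in isolation: selecting `C[i...j] = A[i...j, B[i...j]]` by multiplying with a one-hot mask and reducing, instead of using `tf.gather`. A minimal standalone sketch with illustrative values:

```python
import tensorflow as tf

# A holds, per batch member, one candidate value per hidden state;
# B holds, per batch member, the index of the state to select.
a = tf.constant([[10, 20, 30],
                 [40, 50, 60]], dtype=tf.int64)   # shape [batch, num_states]
b = tf.constant([2, 0], dtype=tf.int64)           # shape [batch]

# Multiply by a one-hot mask and reduce over the state axis; this avoids
# tf.gather's requirement of statically-known index shapes.
picked = tf.reduce_sum(a * tf.one_hot(b, 3, dtype=tf.int64), axis=-1)
# picked == [30, 40], i.e. picked[i] == a[i, b[i]].
```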
Code example #21
def _queue_push(queue, should_update, new_vecs):
    """Conditionally push new vectors into a batch of first-in-first-out queues.

  The `queue` of shape `[k, ..., n]` can be thought of as a batch of queues,
  each holding `k` n-D vectors; while `new_vecs` of shape `[..., n]` is a
  fresh new batch of n-D vectors. The `should_update` batch of Boolean scalars,
  i.e. shape `[...]`, indicates batch members whose corresponding n-D vector in
  `new_vecs` should be added at the back of its queue, pushing out the
  corresponding n-D vector from the front. Batch members in `new_vecs` for
  which `should_update` is False are ignored.

  Note: the choice of placing `k` at dimension 0 of the queue is
  constrained by the L-BFGS two-loop algorithm above. The algorithm uses
  tf.scan to iterate over the `k` correction pairs simultaneously across all
  batches, and tf.scan itself can only iterate over dimension 0.

  For example:

  ```python
    k, b, n = (3, 2, 5)
    queue = tf.reshape(tf.range(30), (k, b, n))
    # => [[[ 0,  1,  2,  3,  4],
    #      [ 5,  6,  7,  8,  9]],
    #
    #     [[10, 11, 12, 13, 14],
    #      [15, 16, 17, 18, 19]],
    #
    #     [[20, 21, 22, 23, 24],
    #      [25, 26, 27, 28, 29]]]

    element = tf.reshape(tf.range(30, 40), (b, n))
    # => [[30, 31, 32, 33, 34],
    #      [35, 36, 37, 38, 39]]

    should_update = tf.constant([True, False])  # Shape: (b,)

    _queue_push(queue, should_update, element)
    # => [[[10, 11, 12, 13, 14],
    #      [ 5,  6,  7,  8,  9]],
    #
    #     [[20, 21, 22, 23, 24],
    #      [15, 16, 17, 18, 19]],
    #
    #     [[30, 31, 32, 33, 34],
    #      [25, 26, 27, 28, 29]]]
  ```

  Args:
    queue: A `tf.Tensor` of shape `[k, ..., n]`; a batch of queues each with
      `k` n-D vectors.
    should_update: A Boolean `tf.Tensor` of shape `[...]` indicating batch
      members where new vectors should be added to their queues.
    new_vecs: A `tf.Tensor` of shape `[..., n]`; a batch of n-D vectors to add
      at the end of their respective queues, pushing out the first element from
      each.

  Returns:
    A new `tf.Tensor` of shape `[k, ..., n]`.
  """
    new_queue = tf.concat([queue[1:], [new_vecs]], axis=0)
    update_pattern = tf.broadcast_to(
        should_update[tf.newaxis, ..., tf.newaxis],
        distribution_util.prefer_static_shape(queue))
    return tf1.where(update_pattern, new_queue, queue)
Code example #22
def log_concave_rejection_sampler(mode,
                                  prob_fn,
                                  dtype,
                                  sample_shape=(),
                                  distribution_minimum=None,
                                  distribution_maximum=None,
                                  seed=None):
    """Utility for rejection sampling from log-concave discrete distributions.

  This utility constructs an easy-to-sample-from upper bound for a discrete
  univariate log-concave distribution (for discrete univariate distributions, a
  necessary and sufficient condition is p_k^2 >= p_{k-1} p_{k+1} for all k).
  The method requires that the mode of the distribution is known. While a better
  method can likely be derived for any given distribution, this method is
  general and easy to implement. The expected number of iterations is bounded by
  4+m, where m is the probability of the mode. For details, see [(Devroye,
  1987)][1].

  Args:
    mode: Tensor, the mode[s] of the [batch of] distribution[s].
    prob_fn: Python callable, counts -> prob(counts).
    dtype: DType of the generated samples.
    sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples.
    distribution_minimum: Tensor of type `dtype`. The minimum value
      taken by the distribution. The `prob` method will only be called on values
      greater than or equal to the specified minimum. The shape must broadcast with
      the batch shape of the distribution. If unspecified, the domain is treated
      as unbounded below.
    distribution_maximum: Tensor of type `dtype`. The maximum value
      taken by the distribution. See `distribution_minimum` for details.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    samples: a `Tensor` with prepended dimensions `sample_shape`.

  #### References

  [1] Luc Devroye. A Simple Generator for Discrete Log-Concave
      Distributions. Computing, 1987.
  """
    mode = tf.broadcast_to(mode,
                           ps.concat(
                               [sample_shape, ps.shape(mode)], axis=0))

    mode_height = prob_fn(mode)
    mode_shape = ps.shape(mode)

    top_width = 1. + mode_height / 2.  # w in ref [1].
    top_fraction = top_width / (1 + top_width)
    exponential_distribution = exponential.Exponential(rate=tf.ones(
        [], dtype=dtype))  # E in ref [1].

    if distribution_minimum is None:
        distribution_minimum = tf.constant(-np.inf, dtype)
    if distribution_maximum is None:
        distribution_maximum = tf.constant(np.inf, dtype)

    def proposal(seed):
        """Proposal for log-concave rejection sampler."""
        (top_lobe_fractions_seed, exponential_samples_seed, top_selector_seed,
         rademacher_seed) = samplers.split_seed(seed, n=4)

        top_lobe_fractions = samplers.uniform(mode_shape,
                                              seed=top_lobe_fractions_seed,
                                              dtype=dtype)  # V in ref [1].
        top_offsets = top_lobe_fractions * top_width / mode_height

        exponential_samples = exponential_distribution.sample(
            mode_shape, seed=exponential_samples_seed)  # E in ref [1].
        exponential_height = (
            exponential_distribution.prob(exponential_samples) * mode_height)
        exponential_offsets = (top_width + exponential_samples) / mode_height

        top_selector = samplers.uniform(mode_shape,
                                        seed=top_selector_seed,
                                        dtype=dtype)  # U in ref [1].
        on_top_mask = (top_selector <= top_fraction)

        unsigned_offsets = tf.where(on_top_mask, top_offsets,
                                    exponential_offsets)
        offsets = tf.round(
            tfp_random.rademacher(
                mode_shape, seed=rademacher_seed, dtype=dtype) *
            unsigned_offsets)

        potential_samples = mode + offsets
        envelope_height = tf.where(on_top_mask, mode_height,
                                   exponential_height)

        return potential_samples, envelope_height

    def target(values):
        # Check for out of bounds rather than in bounds to avoid accidentally
        # masking a `nan` value.
        out_of_bounds_mask = ((values < distribution_minimum) |
                              (values > distribution_maximum))
        in_bounds_values = tf.where(out_of_bounds_mask,
                                    tf.constant(0., dtype=values.dtype),
                                    values)
        probs = prob_fn(in_bounds_values)
        return tf.where(out_of_bounds_mask, tf.zeros([], probs.dtype), probs)

    return tf.stop_gradient(
        brs.batched_rejection_sampler(proposal, target, seed,
                                      dtype=dtype)[0])  # Discard `num_iters`.
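As a rough usage sketch (not from the source), the sampler above can be driven with a Poisson target, whose pmf is log-concave and whose mode is `floor(rate)`. The call below assumes `log_concave_rejection_sampler` is in scope with its module-level imports, and that a stateless seed tuple is acceptable to the seed-sanitizing helpers:

```python
import tensorflow as tf
import tensorflow_probability as tfp

rate = 7.5
poisson = tfp.distributions.Poisson(rate=rate)

samples = log_concave_rejection_sampler(
    mode=tf.constant(7., tf.float32),               # floor(rate), known mode
    prob_fn=poisson.prob,                           # counts -> pmf(counts)
    dtype=tf.float32,
    sample_shape=[10_000],
    distribution_minimum=tf.constant(0., tf.float32),
    seed=(1, 2))

# The empirical mean should be close to the Poisson rate (~7.5).
tf.print(tf.reduce_mean(samples))
```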
Code example #23
    def _mean(self):
        loc = tf.convert_to_tensor(self.loc)
        return tf.broadcast_to(loc, self._batch_shape_tensor(loc=loc))
Code example #24
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes the replica states.
    """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')):
            init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
                init_state)

            inverse_temperatures = tf.convert_to_tensor(
                self.inverse_temperatures, name='inverse_temperatures')

            if self._state_includes_replicas:
                it_n_replica = inverse_temperatures.shape[0]
                state_n_replica = init_state[0].shape[0]
                if ((it_n_replica is not None)
                        and (state_n_replica is not None)
                        and (it_n_replica != state_n_replica)):
                    raise ValueError(
                        'Number of replicas implied by initial state ({}) must equal '
                        'number of replicas implied by inverse_temperatures ({}), but '
                        'did not'.format(state_n_replica, it_n_replica))

            # We will now replicate each of a possible batch of initial states, one for
            # each inverse_temperature. So if init_state=[x, y] of shapes [Sx, Sy]
            # then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
            # concatenation and T=shape(inverse_temperature).
            num_replica = ps.size0(inverse_temperatures)
            replica_shape = ps.convert_to_shape_tensor([num_replica])

            if self._state_includes_replicas:
                replica_states = init_state
            else:
                replica_states = [
                    tf.broadcast_to(  # pylint: disable=g-complex-comprehension
                        x,
                        ps.concat([replica_shape, ps.shape(x)], axis=0),
                        name='replica_states') for x in init_state
                ]

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                target_log_prob_fn=self.target_log_prob_fn,
                inverse_temperatures=inverse_temperatures,
                untempered_log_prob_fn=self.untempered_log_prob_fn,
                tempered_log_prob_fn=self.tempered_log_prob_fn,
            )
            # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                raise TypeError(
                    '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a second '
                    '(`seed`) argument. `TransitionKernel` instances now receive seeds '
                    'via `one_step`.')

            replica_results = inner_kernel.bootstrap_results(replica_states)

            pre_swap_replica_target_log_prob = _get_field(
                replica_results, 'target_log_prob')

            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]

            inverse_temperatures = bu.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Pretend we did a "null swap", which will always be accepted.
            swaps = bu.left_justified_broadcast_to(tf.range(num_replica),
                                                   replica_and_batch_shape)
            # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
            is_swap_accepted = distribution_util.rotate_transpose(tf.eye(
                num_replica, batch_shape=batch_shape, dtype=tf.bool),
                                                                  shift=2)

            return ReplicaExchangeMCKernelResults(
                post_swap_replica_states=replica_states,
                pre_swap_replica_results=replica_results,
                post_swap_replica_results=_set_swapped_fields_to_nan(
                    replica_results),
                is_swap_proposed=is_swap_accepted,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                inverse_temperatures=self.inverse_temperatures,
                swaps=swaps,
                step_count=tf.zeros(shape=(), dtype=tf.int32),
                seed=samplers.zeros_seed(),
                potential_energy=tf.zeros_like(
                    pre_swap_replica_target_log_prob),
            )
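For context, a minimal sketch of how `bootstrap_results` is typically reached, assuming the standard `tfp.mcmc.ReplicaExchangeMC` public API: the user supplies `inverse_temperatures` and a `make_kernel_fn` that takes only a target log-prob function, and the initial state is replicated across replicas internally:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

target = tfd.Normal(loc=0., scale=1.)

def make_kernel_fn(target_log_prob_fn):
  # Note: no `seed` argument; seeds are now passed via `one_step`.
  return tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob_fn,
      step_size=0.5,
      num_leapfrog_steps=3)

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=[1., 0.75, 0.5],
    make_kernel_fn=make_kernel_fn)

# A scalar initial state is broadcast to one state per replica, and the swap
# bookkeeping starts from an always-accepted "null swap".
kernel_results = remc.bootstrap_results(init_state=tf.zeros([]))
```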
Code example #25
def cholesky_update(chol, update_vector, multiplier=1., name=None):
    """Returns cholesky of chol @ chol.T + multiplier * u @ u.T.

  Given a (batch of) lower triangular cholesky factor(s) `chol`, along with a
  (batch of) vector(s) `update_vector`, compute the lower triangular cholesky
  factor of the rank-1 update `chol @ chol.T + multiplier * u @ u.T`, where
  `multiplier` is a (batch of) scalar(s).

  If `chol` has shape `[L, L]`, this has complexity `O(L^2)` compared to the
  naive algorithm which has complexity `O(L^3)`.

  Args:
    chol: Floating-point `Tensor` with shape `[B1, ..., Bn, L, L]`.
      Cholesky decomposition of `mat = chol @ chol.T`. Batch dimensions
      must be broadcastable with `update_vector` and `multiplier`.
    update_vector: Floating-point `Tensor` with shape `[B1, ..., Bn, L]`. Vector
      defining rank-one update. Batch dimensions must be broadcastable with
      `chol` and `multiplier`.
    multiplier: Floating-point `Tensor` with shape `[B1, ..., Bn]`. Scalar
      multiplier to rank-one update. Batch dimensions must be broadcastable
      with `chol` and `update_vector`. Note that updates where `multiplier` is
      positive are numerically stable, while when `multiplier` is negative
      (downdating), the update will only work if the new resulting matrix is
      still positive definite.
    name: Optional name for this op.

  #### References
  [1] Oswin Krause and Christian Igel. A More Efficient Rank-one Covariance
      Matrix Update for Evolution Strategies. 2015 ACM Conference.
      https://www.researchgate.net/publication/300581419_A_More_Efficient_Rank-one_Covariance_Matrix_Update_for_Evolution_Strategies
  """
    # TODO(b/154638092): Move this functionality in to TensorFlow.
    with tf.name_scope(name or 'cholesky_update'):
        dtype = dtype_util.common_dtype([chol, update_vector, multiplier],
                                        dtype_hint=tf.float32)
        chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
        update_vector = tf.convert_to_tensor(update_vector,
                                             name='update_vector',
                                             dtype=dtype)
        multiplier = tf.convert_to_tensor(multiplier,
                                          name='multiplier',
                                          dtype=dtype)

        batch_shape = ps.broadcast_shape(
            ps.broadcast_shape(
                ps.shape(chol)[:-2],
                ps.shape(update_vector)[:-1]), ps.shape(multiplier))
        chol = tf.broadcast_to(
            chol, ps.concat([batch_shape, ps.shape(chol)[-2:]], axis=0))
        update_vector = tf.broadcast_to(
            update_vector,
            ps.concat([batch_shape, ps.shape(update_vector)[-1:]], axis=0))
        multiplier = tf.broadcast_to(multiplier, batch_shape)

        chol_diag = tf.linalg.diag_part(chol)

        # The algorithm in [1] is implemented as a double for loop. We can treat
        # the inner loop in Algorithm 3.1 as a vector operation, and thus the
        # whole algorithm as a single for loop, and hence can use a `tf.scan`
        # on it.

        # We accumulate `omega` and `b` as defined in Algorithm 3.1, since
        # these are updated per iteration.

        def compute_new_column(accumulated_quantities, state):
            """Computes the next column of the updated cholesky."""
            _, _, omega, b = accumulated_quantities
            index, diagonal_member, col, col_mask = state
            omega_at_index = omega[..., index]

            # Line 4
            new_diagonal_member = tf.math.sqrt(
                tf.math.square(diagonal_member) +
                multiplier / b * tf.math.square(omega_at_index))
            # `scaling_factor` is the same as `gamma` on Line 5.
            scaling_factor = (tf.math.square(diagonal_member) * b +
                              multiplier * tf.math.square(omega_at_index))

            # The following updates are the same as the for loop in lines 6-8.
            omega = omega - (omega_at_index /
                             diagonal_member)[..., tf.newaxis] * col
            new_col = new_diagonal_member[..., tf.newaxis] * (
                col / diagonal_member[..., tf.newaxis] +
                (multiplier * omega_at_index / scaling_factor)[..., tf.newaxis]
                * omega * col_mask)
            b = b + multiplier * tf.math.square(
                omega_at_index / diagonal_member)
            return new_diagonal_member, new_col, omega, b

        # We will scan over the columns.
        cols_mask = distribution_util.move_dimension(tf.linalg.band_part(
            tf.ones_like(chol), -1, 0),
                                                     source_idx=-1,
                                                     dest_idx=0)
        chol = distribution_util.move_dimension(chol,
                                                source_idx=-1,
                                                dest_idx=0)
        chol_diag = distribution_util.move_dimension(chol_diag,
                                                     source_idx=-1,
                                                     dest_idx=0)

        new_diag, new_chol, _, _ = tf.scan(
            fn=compute_new_column,
            elems=(tf.range(0,
                            ps.shape(chol)[0]), chol_diag, chol, cols_mask),
            initializer=(tf.zeros_like(multiplier), tf.zeros_like(chol[0,
                                                                       ...]),
                         update_vector, tf.ones_like(multiplier)))
        new_chol = distribution_util.move_dimension(new_chol,
                                                    source_idx=0,
                                                    dest_idx=-1)
        new_diag = distribution_util.move_dimension(new_diag,
                                                    source_idx=0,
                                                    dest_idx=-1)
        new_chol = tf.linalg.set_diag(new_chol, new_diag)
        return new_chol
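A quick sanity-check sketch, assuming this helper is exposed as `tfp.math.cholesky_update` (as in recent TFP releases): it compares the O(L^2) update against recomputing the Cholesky factor of the rank-one-updated matrix from scratch.

```python
import tensorflow as tf
import tensorflow_probability as tfp

a = tf.constant([[4., 1.],
                 [1., 3.]])
chol = tf.linalg.cholesky(a)
u = tf.constant([0.5, -1.0])
multiplier = 2.0

updated_chol = tfp.math.cholesky_update(chol, u, multiplier=multiplier)

# Reference: factor the updated matrix directly and compare.
direct_chol = tf.linalg.cholesky(
    a + multiplier * u[:, tf.newaxis] * u[tf.newaxis, :])
tf.debugging.assert_near(updated_chol, direct_chol)
```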
Code example #26
    def _mode(self):
        scale = tf.convert_to_tensor(self.scale)
        return tf.broadcast_to(scale, self._batch_shape_tensor(scale=scale))
Code example #27
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
    """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible,
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
      actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.debugging.assert_near(tf.linalg.inv(x), inv_x)
  # Raises no error: `inv_x` matches the dense inverse.
  ```

  """

    with tf.name_scope(name or 'lu_solve'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
        rhs = tf.convert_to_tensor(rhs,
                                   dtype_hint=lower_upper.dtype,
                                   name='rhs')

        assertions = _lu_solve_assertions(lower_upper, perm, rhs,
                                          validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
                rhs = tf.identity(rhs)

        if (tensorshape_util.rank(rhs.shape) == 2
                and tensorshape_util.rank(perm.shape) == 1):
            # Both rhs and perm have scalar batch_shape.
            permuted_rhs = tf.gather(rhs, perm, axis=-2)
        else:
            # Either rhs or perm have non-scalar batch_shape or we can't determine
            # this information statically.
            rhs_shape = tf.shape(rhs)
            broadcast_batch_shape = tf.broadcast_dynamic_shape(
                rhs_shape[:-2],
                tf.shape(perm)[:-1])
            d, m = rhs_shape[-2], rhs_shape[-1]
            rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]],
                                            axis=0)

            # Tile out rhs.
            broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
            broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

            # Tile out perm and add batch indices.
            broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
            broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
            broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
            broadcast_batch_indices = tf.broadcast_to(
                tf.range(broadcast_batch_size)[:, tf.newaxis],
                [broadcast_batch_size, d])
            broadcast_perm = tf.stack(
                [broadcast_batch_indices, broadcast_perm], axis=-1)

            permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
            permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
        return tf.linalg.triangular_solve(
            lower_upper,  # Only upper is accessed.
            tf.linalg.triangular_solve(lower, permuted_rhs),
            lower=False)
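The batched permutation in the else-branch above (building explicit batch indices and gathering with `tf.gather_nd`) is easy to get wrong, so here is a standalone sketch of the same pattern with small illustrative values. Note that modern TF can also express this directly as `tf.gather(rhs, perm, axis=-2, batch_dims=1)`:

```python
import tensorflow as tf

rhs = tf.reshape(tf.range(12.), [2, 3, 2])    # batch of two 3x2 matrices
perm = tf.constant([[2, 0, 1],
                    [1, 2, 0]])               # one row permutation per batch

batch_size, d = 2, 3
# Pair each permutation index with its batch index ...
batch_indices = tf.broadcast_to(
    tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
indices = tf.stack([batch_indices, perm], axis=-1)   # shape [2, 3, 2]

# ... and gather rows per batch member.
permuted_rhs = tf.gather_nd(rhs, indices)
# permuted_rhs[b, i] == rhs[b, perm[b, i]]
```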
Code example #28
    def _observation_log_probs(self, observations, mask):
        # Let E be the underlying event shape
        #     M the number of steps in the HMM
        #     N the number of states of the HMM
        #
        # Then the incoming observations have shape
        #
        # observations : batch_o [M] E
        #
        # and the mask (if present) has shape
        #
        # mask : batch_m [M]
        #
        # Let this HMM distribution have batch shape batch_d
        # We need to broadcast all three of these batch shapes together
        # into the shape batch.
        #
        # We need to move the step dimension to the first dimension to make
        # them suitable for folding or scanning over.
        #
        # When we call `log_prob` for our observations we need to
        # do this for each state the observation could correspond to.
        # We do this by expanding the dimensions by 1 so we end up with:
        #
        # observations : [M] batch [1] [E]
        #
        # After calling `log_prob` we get
        #
        # observation_log_probs : [M] batch [N]
        #
        # We wish to use `mask` to select from this so we also
        # reshape and broadcast it up to shape
        #
        # mask : [M] batch [N]

        observation_tensor_shape = tf.shape(input=observations)
        observation_batch_shape = observation_tensor_shape[
            :-1 - self._underlying_event_rank]
        observation_event_shape = observation_tensor_shape[
            -1 - self._underlying_event_rank:]

        if mask is not None:
            mask_tensor_shape = tf.shape(mask)
            mask_batch_shape = mask_tensor_shape[:-1]

        batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                 self.batch_shape_tensor())

        if mask is not None:
            batch_shape = tf.broadcast_dynamic_shape(batch_shape,
                                                     mask_batch_shape)
        observations = tf.broadcast_to(
            observations,
            tf.concat([batch_shape, observation_event_shape], axis=0))
        observation_rank = tf.rank(observations)
        underlying_event_rank = self._underlying_event_rank
        observations = distribution_util.move_dimension(
            observations, observation_rank - underlying_event_rank - 1, 0)
        observations = tf.expand_dims(observations,
                                      observation_rank - underlying_event_rank)
        observation_log_probs = self._observation_distribution.log_prob(
            observations)

        if mask is not None:
            mask = tf.broadcast_to(
                mask, tf.concat([batch_shape, [self._num_steps]], axis=0))
            mask = distribution_util.move_dimension(mask, -1, 0)
            mask = tf.expand_dims(mask, -1)
            mask = tf.broadcast_to(mask, tf.shape(observation_log_probs))

            observation_log_probs = tf1.where(
                mask, tf.zeros_like(observation_log_probs),
                observation_log_probs)

        return observation_log_probs
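The final masking step above can be illustrated on its own: masked (missing) steps are given log-probability 0, i.e. probability 1 for every hidden state, so they drop out of any downstream forward/backward sums. A small sketch with made-up numbers:

```python
import tensorflow as tf

# Made-up log-probs for 4 steps and 3 hidden states, steps leading.
observation_log_probs = tf.math.log(
    tf.constant([[0.1, 0.2, 0.7],
                 [0.3, 0.3, 0.4],
                 [0.5, 0.4, 0.1],
                 [0.2, 0.2, 0.6]]))

# Mark step 2 as missing, then broadcast the mask across states.
mask = tf.constant([False, False, True, False])[:, tf.newaxis]
mask = tf.broadcast_to(mask, tf.shape(observation_log_probs))

masked = tf.where(mask, tf.zeros_like(observation_log_probs),
                  observation_log_probs)
# Row 2 of `masked` is all zeros: the missing observation is uninformative.
```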
Code example #29
def prepare_args(model_matrix,
                 response,
                 model_coefficients,
                 predicted_linear_response,
                 offset,
                 name=None):
    """Helper to `fit` which sanitizes input args.

  Args:
    model_matrix: (Batch of) `float`-like, matrix-shaped `Tensor` where each row
      represents a sample's features.
    response: (Batch of) vector-shaped `Tensor` where each element represents a
      sample's observed response (to the corresponding row of features). Must
      have same `dtype` as `model_matrix`.
    model_coefficients: Optional (batch of) vector-shaped `Tensor` representing
      the model coefficients, one for each column in `model_matrix`. Must have
      same `dtype` as `model_matrix`.
      Default value: `tf.zeros(tf.shape(model_matrix)[-1], model_matrix.dtype)`.
    predicted_linear_response: Optional `Tensor` with `shape`, `dtype` matching
      `response`; represents `offset` shifted initial linear predictions based
      on current `model_coefficients`.
      Default value: `offset` if `model_coefficients is None`, and
      `tf.linalg.matvec(model_matrix, model_coefficients_start) + offset`
      otherwise.
    offset: Optional `Tensor` with `shape`, `dtype` matching `response`;
      represents constant shift applied to `predicted_linear_response`.
      Default value: `None` (i.e., `tf.zeros_like(response)`).
    name: Python `str` used as name prefix to ops created by this function.
      Default value: `"prepare_args"`.

  Returns:
    model_matrix: A `Tensor` with `shape`, `dtype` and values of the
      `model_matrix` argument.
    response: A `Tensor` with `shape`, `dtype` and values of the
      `response` argument.
    model_coefficients_start: A `Tensor` with `shape`, `dtype` and
      values of the `model_coefficients` argument if specified.
      A (batch of) vector-shaped `Tensors` with `dtype` matching `model_matrix`
      containing the default starting point otherwise.
    predicted_linear_response:  A `Tensor` with `shape`, `dtype` and
      values of the `predicted_linear_response` argument if specified.
      A `Tensor` with `shape`, `dtype` matching `response` containing the
      default value otherwise.
    offset: A `Tensor` with `shape`, `dtype` and values of the `offset` argument
      if specified or `None` otherwise.
  """
    graph_deps = [
        model_matrix, response, model_coefficients, predicted_linear_response,
        offset
    ]
    with tf1.name_scope(name, 'prepare_args', graph_deps):
        dtype = dtype_util.common_dtype(graph_deps, np.float32)

        model_matrix = tf.convert_to_tensor(value=model_matrix,
                                            dtype=dtype,
                                            name='model_matrix')

        if offset is not None:
            offset = tf.convert_to_tensor(value=offset,
                                          dtype=dtype,
                                          name='offset')

        response = tf.convert_to_tensor(value=response,
                                        dtype=dtype,
                                        name='response')

        use_default_model_coefficients = model_coefficients is None
        if use_default_model_coefficients:
            # User did not supply model coefficients; assume they're all zero.
            batch_shape = tf.shape(input=model_matrix)[:-2]
            num_columns = tf.shape(input=model_matrix)[-1]
            model_coefficients = tf.zeros(shape=tf.concat(
                [batch_shape, [num_columns]], axis=0),
                                          dtype=dtype,
                                          name='model_coefficients')
        else:
            # User did supply model coefficients; convert to Tensor in case it's
            # numpy or literal.
            model_coefficients = tf.convert_to_tensor(
                value=model_coefficients,
                dtype=dtype,
                name='model_coefficients')

        if predicted_linear_response is None:
            if use_default_model_coefficients:
                # Since we're using zeros for model_coefficients, we know the predicted
                # linear response will also be all zeros.
                if offset is None:
                    predicted_linear_response = tf.zeros_like(
                        response, dtype, name='predicted_linear_response')
                else:
                    predicted_linear_response = tf.broadcast_to(
                        offset,
                        tf.shape(input=response),
                        name='predicted_linear_response')
            else:
                # We were given model_coefficients but not the predicted linear
                # response.
                predicted_linear_response = calculate_linear_predictor(
                    model_matrix, model_coefficients, offset)
        else:
            predicted_linear_response = tf.convert_to_tensor(
                value=predicted_linear_response,
                dtype=dtype,
                name='predicted_linear_response')

    return [
        model_matrix,
        response,
        model_coefficients,
        predicted_linear_response,
        offset,
    ]
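The defaulting logic above can be summarised with a small sketch (names and values are illustrative, not from the source): when `model_coefficients` is omitted it starts at zeros, so the default predicted linear response is just the `offset` (or zeros when there is no offset); otherwise it is the linear predictor plus `offset`.

```python
import tensorflow as tf

model_matrix = tf.random.normal([4, 3])      # 4 samples, 3 features
offset = tf.fill([4], 0.1)

# Case 1: no coefficients supplied -> zeros, so the prediction is the offset.
default_coefficients = tf.zeros([3])
default_prediction = offset                   # == matvec(X, 0) + offset

# Case 2: coefficients supplied -> standard linear predictor plus offset.
model_coefficients = tf.constant([1., -2., 0.5])
prediction = tf.linalg.matvec(model_matrix, model_coefficients) + offset
```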
Code example #30
  def vectorized_fn(*args):
    """Vectorized version of `fn` that accepts arguments of any rank."""
    with tf.name_scope(name or 'make_rank_polymorphic'):
      assertions = []

      # If we got a single value for core_ndims, tile it across all args.
      core_ndims_structure = (
          core_ndims
          if tf.nest.is_nested(core_ndims)
          else tf.nest.map_structure(lambda _: core_ndims, args))

      # Build flat lists of all argument parts and their corresponding core
      # ndims.
      flat_core_ndims = tf.nest.flatten(core_ndims_structure)
      flat_args = nest.flatten_up_to(
          core_ndims_structure, args, check_types=False)

      # Filter to only the `Tensor`-valued args (taken to be those with `None`
      # values for `core_ndims`). Other args will be passed through to `fn`
      # unmodified.
      (vectorized_arg_core_ndims,
       vectorized_args,
       fn_of_vectorized_args) = _lock_in_non_vectorized_args(
           fn,
           arg_structure=core_ndims_structure,
           flat_core_ndims=flat_core_ndims,
           flat_args=flat_args)

      # `vectorized_map` requires all inputs to have a single, common batch
      # dimension `[n]`. So we broadcast all input parts to a common
      # batch shape, then flatten it down to a single dimension.

      # First, compute how many 'extra' (batch) ndims each part has. This must
      # be nonnegative.
      vectorized_arg_shapes = [ps.shape(arg) for arg in vectorized_args]
      batch_ndims = [
          ps.rank_from_shape(arg_shape) - nd
          for (arg_shape, nd) in zip(
              vectorized_arg_shapes, vectorized_arg_core_ndims)]
      static_ndims = [tf.get_static_value(nd) for nd in batch_ndims]
      if any([nd and nd < 0 for nd in static_ndims]):
        raise ValueError('Cannot broadcast a Tensor having lower rank than the '
                         'specified `core_ndims`! (saw input ranks {}, '
                         '`core_ndims` {}).'.format(
                             tf.nest.map_structure(
                                 ps.rank_from_shape,
                                 vectorized_arg_shapes),
                             vectorized_arg_core_ndims))
      if validate_args:
        for nd, part, core_nd in zip(
            batch_ndims, vectorized_args, vectorized_arg_core_ndims):
          assertions.append(tf.debugging.assert_non_negative(
              nd, message='Cannot broadcast a Tensor having lower rank than '
              'the specified `core_ndims`! (saw {} vs minimum rank {}).'.format(
                  part, core_nd)))

      # Next, split each part's shape into batch and core shapes, and
      # broadcast the batch shapes.
      with tf.control_dependencies(assertions):
        empty_shape = np.zeros([0], dtype=np.int32)
        batch_shapes, core_shapes = empty_shape, empty_shape
        if vectorized_arg_shapes:
          batch_shapes, core_shapes = zip(*[
              (arg_shape[:nd], arg_shape[nd:])
              for (arg_shape, nd) in zip(vectorized_arg_shapes, batch_ndims)])
        broadcast_batch_shape = (
            functools.reduce(ps.broadcast_shape, batch_shapes, []))

      # Flatten all of the batch dimensions into one.
      n = tf.cast(ps.reduce_prod(broadcast_batch_shape), tf.int32)
      static_n = tf.get_static_value(n)
      if static_n == 1:
        result = fn(*args)
      else:
        # Pad all input parts to the common shape, then flatten
        # into the single leading dimension `[n]`.
        # TODO(b/145227909): If/when vmap supports broadcasting, use nested vmap
        # when batch rank is static so that we can exploit broadcasting.
        broadcast_vectorized_args = [
            tf.broadcast_to(part, ps.concat(
                [broadcast_batch_shape, core_shape], axis=0))
            for (part, core_shape) in zip(vectorized_args, core_shapes)]
        vectorized_args_with_flattened_batch_dim = [
            tf.reshape(part, ps.concat([[n], core_shape], axis=0))
            for (part, core_shape) in zip(
                broadcast_vectorized_args, core_shapes)]
        batched_result = tf.vectorized_map(
            fn_of_vectorized_args, vectorized_args_with_flattened_batch_dim)

        # Unflatten any `Tensor`s in the result.
        unflatten = lambda x: tf.reshape(x, ps.concat([  # pylint: disable=g-long-lambda
            broadcast_batch_shape, ps.shape(x)[1:]], axis=0))
        result = tf.nest.map_structure(
            lambda x: unflatten(x) if tf.is_tensor(x) else x, batched_result,
            expand_composites=True)
    return result
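A self-contained sketch of the pattern `vectorized_fn` implements: broadcast all batch shapes together, flatten them to a single leading dimension, apply `tf.vectorized_map`, then unflatten. The function and shapes below are illustrative only.

```python
import tensorflow as tf

def scalar_fn(x, y):
  # Core computation defined for rank-0 (core_ndims == 0) inputs.
  return x * y + 1.

x = tf.random.normal([3, 1])   # batch shape [3, 1]
y = tf.random.normal([4])      # batch shape [4]

# Broadcast both inputs to the common batch shape [3, 4], then flatten it.
batch_shape = tf.broadcast_dynamic_shape(tf.shape(x), tf.shape(y))
xf = tf.reshape(tf.broadcast_to(x, batch_shape), [-1])
yf = tf.reshape(tf.broadcast_to(y, batch_shape), [-1])

# Map the core function over the single flattened batch dimension.
flat_result = tf.vectorized_map(lambda args: scalar_fn(*args), (xf, yf))

# Restore the broadcast batch shape.
result = tf.reshape(flat_result, batch_shape)   # shape [3, 4]
```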