Example no. 1
 def _quantile(self, value):
   broadcast_shape = tf.broadcast_dynamic_shape(
       tf.shape(input=value), self.batch_shape_tensor())
   ones = tf.ones(broadcast_shape, dtype=self.dtype)
   broadcasted_value = value * ones
   return (1. - broadcasted_value) * self.low + broadcasted_value * self.high
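All of the examples below hinge on `tf.broadcast_dynamic_shape`, so here is a minimal standalone sketch (standard TensorFlow 2.x, toy shapes) of what it returns and how `_quantile` above uses it:

import tensorflow as tf

# Broadcasting a [3, 1] shape against a [4] shape yields [3, 4] (NumPy rules);
# the result is a 1-D integer shape tensor usable with tf.ones / tf.broadcast_to.
a = tf.zeros([3, 1])
b = tf.zeros([4])
shape = tf.broadcast_dynamic_shape(tf.shape(a), tf.shape(b))  # ==> [3, 4]
ones = tf.ones(shape, dtype=a.dtype)  # the same pattern as in _quantile above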
Example no. 2
 def _batch_shape_tensor(self, low=None, high=None):
     return tf.broadcast_dynamic_shape(
         tf.shape(self.low if low is None else low),
         tf.shape(self.high if high is None else high))
Example no. 3
 def _set_event_shape(shape, shape_tensor):
     if event_shape is None:
         return shape, shape_tensor
     return (tf.broadcast_static_shape(event_shape, shape),
             tf.broadcast_dynamic_shape(event_shape_tensor,
                                        shape_tensor))
Example no. 4
 def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(
         [] if self.amplitude is None else tf.shape(self.amplitude),
         [] if self.length_scale is None else tf.shape(self.length_scale))
Example no. 5
 def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(tf.shape(self.low),
                                       tf.shape(self.high))
Example no. 6
    def _chunked_matmul(x1, x2, x):
        """Chunk-at-a-time matrix multiplication and backprop."""
        fwd_ax_size = tf.shape(x2)[-kernel.feature_ndims - 1]
        fwd_part_size = fwd_ax_size // num_matmul_parts

        def cond(i, _):
            return i < num_matmul_parts

        def body(i, covx):
            return i + 1, covx + _forward_matmul_one_part(
                kernel, x1, x2, x, fwd_part_size, i)

        result_batch_shape = tf.broadcast_dynamic_shape(
            operator_shape[:-2],
            tf.shape(x)[:-2])
        result_shape = tf.concat(
            [result_batch_shape, [operator_shape[-2],
                                  tf.shape(x)[-1]]],
            axis=0)
        _, covx = tf.while_loop(
            cond,
            body, (tf.constant(0), tf.zeros(result_shape, dtype=x.dtype)),
            back_prop=False,
            parallel_iterations=1)
        covx = covx + _forward_matmul_one_part(
            kernel,
            x1,
            x2,
            x,
            fwd_part_size,
            num_matmul_parts,
            remainder_part_size=fwd_ax_size -
            (num_matmul_parts * fwd_part_size))
        del result_batch_shape, result_shape

        def grad_fn(dcovx):
            """Chunk-at-a-time backprop."""
            # Backward, we partition along the `x1`-defined axis.
            bwd_ax_size = tf.shape(x1)[-kernel.feature_ndims - 1]
            bwd_part_size = bwd_ax_size // num_matmul_parts

            def bw_cond(i, *_):
                return i < num_matmul_parts

            def bw_body(i, dx1, dx2, dx):
                """tf.while_loop body for backprop."""
                dx1part, dx2part, dxpart = _backward_matmul_one_part(
                    dcovx, kernel, x1, x2, x, bwd_part_size, i)
                return i + 1, dx1 + dx1part, dx2 + dx2part, dx + dxpart

            _, dx1, dx2, dx = tf.while_loop(
                bw_cond,
                bw_body,
                tuple(tf.zeros_like(t) for t in (0, x1, x2, x)),
                back_prop=False,
                parallel_iterations=1)
            dx1part, dx2part, dxpart = _backward_matmul_one_part(
                dcovx,
                kernel,
                x1,
                x2,
                x,
                bwd_part_size,
                num_matmul_parts,
                remainder_part_size=bwd_ax_size -
                (num_matmul_parts * bwd_part_size))
            return dx1 + dx1part, dx2 + dx2part, dx + dxpart

        return covx, grad_fn
Example no. 7
 def _batch_shape_tensor(self, mean_direction=None, concentration=None):
     return tf.broadcast_dynamic_shape(
         tf.shape(self.mean_direction
                  if mean_direction is None else mean_direction)[:-1],
         tf.shape(self.concentration
                  if concentration is None else concentration))
Example no. 8
    def _chunked_matmul(x1, x2, x):
        """Chunk-at-a-time matrix multiplication and backprop."""
        fwd_ax_size = tf.shape(x2)[-kernel.feature_ndims - 1]
        fwd_part_size = fwd_ax_size // num_matmul_parts

        def cond(i, _):
            niters = (fwd_ax_size + fwd_part_size - 1) // fwd_part_size
            return i < niters

        def body(i, covx):
            contraction_dim_slice = slice(fwd_part_size * i,
                                          fwd_part_size * (i + 1))
            slices = (Ellipsis, contraction_dim_slice)
            slices = slices + (slice(None), ) * kernel.feature_ndims
            return i + 1, covx + tf.matmul(kernel.matrix(x1, x2[slices]),
                                           x[..., contraction_dim_slice, :])

        result_batch_shape = tf.broadcast_dynamic_shape(
            operator_shape[:-2],
            tf.shape(x)[:-2])
        result_shape = tf.concat(
            [result_batch_shape, [operator_shape[-2],
                                  tf.shape(x)[-1]]],
            axis=0)
        _, covx = tf.while_loop(
            cond,
            body, (tf.constant(0), tf.zeros(result_shape, dtype=x.dtype)),
            back_prop=False,
            parallel_iterations=1)

        def grad_fn(dcovx):
            """Chunk-at-a-time backprop."""
            # `dcovx` matches result_shape.
            # `cov` shp (A,B), `x` shp (B,C), `covx`, `dcovx` shp (A,C).
            # Backward, we partition along the A axis.
            bwd_ax_size = tf.shape(x1)[-kernel.feature_ndims - 1]
            bwd_part_size = bwd_ax_size // num_matmul_parts

            def bw_cond(i, *_):
                niters = (bwd_ax_size + bwd_part_size - 1) // bwd_part_size
                return i < niters

            def bw_body(i, dx1, dx2, dx):
                """tf.while_loop body for backprop."""
                contraction_dim_slice = slice(bwd_part_size * i,
                                              bwd_part_size * (i + 1))
                dcovxpart = dcovx[..., contraction_dim_slice, :]  # PxC
                dcovpart = tf.matmul(dcovxpart, x,
                                     transpose_b=True)  # PxC @ CxB => PxB
                with tf.GradientTape() as tape:
                    tape.watch((x1, x2))
                    slices = (Ellipsis, contraction_dim_slice)
                    slices = slices + (slice(None), ) * kernel.feature_ndims
                    covpart = kernel.matrix(x1[slices], x2)  # PxB
                dx1part, dx2part = tape.gradient(covpart, (x1, x2),
                                                 output_gradients=dcovpart)
                dxpart = tf.matmul(covpart, dcovxpart,
                                   transpose_a=True)  # BxP @ PxC
                return i + 1, dx1 + dx1part, dx2 + dx2part, dx + dxpart

            return tf.while_loop(bw_cond,
                                 bw_body,
                                 tuple(
                                     tf.zeros_like(t) for t in (0, x1, x2, x)),
                                 back_prop=False,
                                 parallel_iterations=1)[1:]

        return covx, grad_fn
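The forward pass above relies on the fact that a matrix product can be accumulated over chunks of the contraction axis; a minimal sketch with toy shapes (illustrative only, not the TFP implementation):

import tensorflow as tf

cov = tf.random.normal([5, 8])
x = tf.random.normal([8, 3])
full = tf.matmul(cov, x)
# Summing partial products over column chunks of `cov` (row chunks of `x`)
# reproduces the full product up to floating-point error.
chunked = sum(tf.matmul(cov[:, i:i + 2], x[i:i + 2, :]) for i in range(0, 8, 2))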
Example no. 9
def sample_annealed_importance_chain(
    num_steps,
    proposal_log_prob_fn,
    target_log_prob_fn,
    current_state,
    make_kernel_fn,
    parallel_iterations=10,
    name=None):
  """Runs annealed importance sampling (AIS) to estimate normalizing constants.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial 'proposal' distribution:

  `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`

  and the target distribution:

  `exp(target_log_prob_fn(x) - target_log_normalizer)`,

  accumulating importance weights along the way. The product of these
  importance weights gives an unbiased estimate of the ratio of the
  normalizing constants of the initial distribution and the target
  distribution:

  `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.

  Note: When running in graph mode, `proposal_log_prob_fn` and
  `target_log_prob_fn` are called exactly three times (although this may be
  reduced to two times in the future).

  Args:
    num_steps: Integer number of Markov chain updates to run. More
      iterations means more expense, but smoother annealing between q
      and p, which in turn means exponentially lower variance for the
      normalizing constant estimator.
    proposal_log_prob_fn: Python callable that returns the log density of the
      initial distribution.
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. Must take one argument representing the `TransitionKernel`'s
      `target_log_prob_fn`. The `target_log_prob_fn` argument represents the
      `TransitionKernel`'s target log distribution.  Note:
      `sample_annealed_importance_chain` creates a new `target_log_prob_fn`
      which is an interpolation between the supplied `target_log_prob_fn` and
      `proposal_log_prob_fn`; it is this interpolated function which is used as
      an argument to `make_kernel_fn`.
    parallel_iterations: The number of iterations allowed to run in parallel.
        It must be a positive integer. See `tf.while_loop` for more details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_annealed_importance_chain').

  Returns:
    next_state: `Tensor` or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at the final iteration. Has same shape as
      input `current_state`.
    ais_weights: Tensor with the estimated weight(s). Has shape matching
      `target_log_prob_fn(current_state)`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  #### Examples

  ##### Estimate the normalizing constant of a log-gamma distribution.

  ```python
  tfd = tfp.distributions

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 20
  dtype = np.float32

  proposal = tfd.MultivariateNormalDiag(
     loc=tf.zeros([dims], dtype=dtype))

  target = tfd.TransformedDistribution(
    distribution=tfd.Gamma(concentration=dtype(2),
                           rate=dtype(3)),
    bijector=tfp.bijectors.Invert(tfp.bijectors.Exp()),
    event_shape=[dims])

  chains_state, ais_weights, kernels_results = (
      tfp.mcmc.sample_annealed_importance_chain(
          num_steps=1000,
          proposal_log_prob_fn=proposal.log_prob,
          target_log_prob_fn=target.log_prob,
          current_state=proposal.sample(num_chains),
          make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=tlp_fn,
            step_size=0.2,
            num_leapfrog_steps=2)))

  log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
                              - np.log(num_chains))
  log_true_normalizer = tf.math.lgamma(2.) - 2. * tf.math.log(3.)
  ```

  ##### Estimate marginal likelihood of a Bayesian regression model.

  ```python
  tfd = tfp.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, x):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, x, axes=[[0], [-1]]))

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 10
  dtype = np.float32

  # Make training data.
  x = np.random.randn(num_chains, dims).astype(dtype)
  true_weights = np.random.randn(dims).astype(dtype)
  y = np.dot(x, true_weights) + np.random.randn(num_chains)

  # Setup model.
  prior = make_prior(dims, dtype)
  def target_log_prob_fn(weights):
    return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)

  proposal = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype))

  weight_samples, ais_weights, kernel_results = (
      tfp.mcmc.sample_annealed_importance_chain(
        num_steps=1000,
        proposal_log_prob_fn=proposal.log_prob,
        target_log_prob_fn=target_log_prob_fn,
        current_state=tf.zeros([num_chains, dims], dtype),
        make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=tlp_fn,
          step_size=0.1,
          num_leapfrog_steps=2)))
  log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
                             - np.log(num_chains))
  ```

  """
  with tf.name_scope(name or 'sample_annealed_importance_chain'):
    num_steps = tf.convert_to_tensor(
        value=num_steps, dtype=tf.int32, name='num_steps')
    if mcmc_util.is_list_like(current_state):
      current_state = [
          tf.convert_to_tensor(s, name='current_state')
          for s in current_state
      ]
    else:
      current_state = tf.convert_to_tensor(
          value=current_state, name='current_state')

    def _make_convex_combined_log_prob_fn(iter_):
      def _fn(*args):
        p = tf.identity(proposal_log_prob_fn(*args), name='proposal_log_prob')
        t = tf.identity(target_log_prob_fn(*args), name='target_log_prob')
        dtype = dtype_util.base_dtype(p.dtype)
        beta = tf.cast(iter_ + 1, dtype) / tf.cast(num_steps, dtype)
        return tf.identity(beta * t + (1. - beta) * p,
                           name='convex_combined_log_prob')
      return _fn

    def _loop_body(iter_, ais_weights, current_state, kernel_results):
      """Closure which implements `tf.while_loop` body."""
      x = (current_state if mcmc_util.is_list_like(current_state)
           else [current_state])
      proposal_log_prob = proposal_log_prob_fn(*x)
      target_log_prob = target_log_prob_fn(*x)
      ais_weights += ((target_log_prob - proposal_log_prob) /
                      tf.cast(num_steps, ais_weights.dtype))
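      # With the linear schedule beta_k = k / num_steps used in
      # _make_convex_combined_log_prob_fn, the AIS log-weight increment for
      # moving from bridge distribution k-1 to k is
      # (beta_k - beta_{k-1}) * (target_log_prob - proposal_log_prob), which is
      # exactly the (target - proposal) / num_steps term accumulated above.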
      kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_))
      next_state, inner_results = kernel.one_step(
          current_state, kernel_results.inner_results)
      kernel_results = AISResults(
          proposal_log_prob=proposal_log_prob,
          target_log_prob=target_log_prob,
          inner_results=inner_results,
      )
      return [iter_ + 1, ais_weights, next_state, kernel_results]

    def _bootstrap_results(init_state):
      """Creates first version of `previous_kernel_results`."""
      kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_=0))
      inner_results = kernel.bootstrap_results(init_state)
      mh_results = _find_inner_mh_results(inner_results)

      convex_combined_log_prob = mh_results.accepted_results.target_log_prob
      dtype = dtype_util.as_numpy_dtype(convex_combined_log_prob.dtype)
      shape = tf.shape(convex_combined_log_prob)
      proposal_log_prob = tf.fill(shape, dtype(np.nan),
                                  name='bootstrap_proposal_log_prob')
      target_log_prob = tf.fill(shape, dtype(np.nan),
                                name='bootstrap_target_log_prob')

      return AISResults(
          proposal_log_prob=proposal_log_prob,
          target_log_prob=target_log_prob,
          inner_results=inner_results,
      )

    previous_kernel_results = _bootstrap_results(current_state)
    inner_results = previous_kernel_results.inner_results
    mh_results = _find_inner_mh_results(inner_results)

    ais_weights = tf.zeros(
        shape=tf.broadcast_dynamic_shape(
            tf.shape(mh_results.proposed_results.target_log_prob),
            tf.shape(mh_results.accepted_results.target_log_prob)),
        dtype=mh_results.proposed_results.target_log_prob.dtype)

    [_, ais_weights, current_state, kernel_results] = tf.while_loop(
        cond=lambda iter_, *args: iter_ < num_steps,
        body=_loop_body,
        loop_vars=[
            np.int32(0),  # iter_
            ais_weights,
            current_state,
            previous_kernel_results,
        ],
        parallel_iterations=parallel_iterations)

    return [current_state, ais_weights, kernel_results]
Example no. 10
  def _log_prob(self, value):
    # The argument `value` is a tensor of sequences of observations.
    # `observation_batch_shape` is the shape of that tensor with the
    # sequence part removed.
    # `observation_batch_shape` is then broadcast to the full batch shape
    # to give the `batch_shape` that defines the shape of the result.
    observation_tensor_shape = ps.shape(value)
    observation_distribution = self.observation_distribution
    underlying_event_rank = ps.size(
        observation_distribution.event_shape_tensor())
    observation_batch_shape = observation_tensor_shape[
        :-1 - underlying_event_rank]
    # value :: observation_batch_shape num_steps observation_event_shape
    batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                             self.batch_shape_tensor())
    num_states = self.transition_distribution.batch_shape_tensor()[-1]
    log_init = _extract_log_probs(num_states,
                                  self.initial_distribution)
    # log_init :: batch_shape num_states
    log_init = tf.broadcast_to(log_init,
                               ps.concat([batch_shape,
                                          [num_states]], axis=0))
    log_transition = _extract_log_probs(num_states,
                                        self.transition_distribution)

    # `observation_event_shape` is the shape of each sequence of observations
    # emitted by the model.
    observation_event_shape = observation_tensor_shape[
        -1 - underlying_event_rank:]
    working_obs = tf.broadcast_to(value,
                                  ps.concat([batch_shape,
                                             observation_event_shape],
                                            axis=0))
    # working_obs :: batch_shape observation_event_shape
    r = underlying_event_rank

    # Move index into sequence of observations to front so we can apply
    # tf.foldl
    if self._time_varying_observation_distribution:
      working_obs = tf.expand_dims(working_obs, -1 - r)
      # working_obs :: batch_shape num_steps 1 underlying_event_shape
      observation_probs = observation_distribution.log_prob(working_obs)
      # observation_probs :: batch_shape num_steps num_states
      observation_probs = distribution_util.move_dimension(
          observation_probs, -2, 0)
      # observation_probs :: num_steps batch_shape num_states
    else:
      working_obs = distribution_util.move_dimension(working_obs, -1 - r, 0)
      # working_obs :: num_steps batch_shape underlying_event_shape
      working_obs = tf.expand_dims(working_obs, -1 - r)
      # working_obs :: num_steps batch_shape 1 underlying_event_shape

      observation_probs = observation_distribution.log_prob(working_obs)
      # observation_probs :: num_steps batch_shape num_states

    def forward_step(log_prev_step, log_prob_observation):
      return _log_vector_matrix(log_prev_step,
                                log_transition) + log_prob_observation

    # TODO(davmre): Delete this warning after Dec 31, 2020.
    warnings.warn(
        'HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug '
        'in which the transition model was applied prior to the initial step. '
        'This bug has been fixed. You may observe a slight change in behavior.')
    fwd_prob = tf.foldl(forward_step, observation_probs[1:],
                        initializer=log_init + observation_probs[0])
    # fwd_prob :: batch_shape num_states

    log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
    # log_prob :: batch_shape

    return log_prob
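The fold above is the standard log-space forward algorithm. For reference, a minimal standalone sketch for a toy two-state HMM (illustrative only; the shapes and numbers are made up, and this is not TFP's implementation):

import tensorflow as tf

log_init = tf.math.log([0.6, 0.4])                      # [num_states]
log_transition = tf.math.log([[0.7, 0.3], [0.2, 0.8]])  # [num_states, num_states]
# observation_probs[t, s] = log p(obs_t | state = s), shape [num_steps, num_states]
observation_probs = tf.math.log([[0.9, 0.1], [0.2, 0.8], [0.9, 0.1]])

def forward_step(log_prev, log_obs):
  # Log-space vector-matrix product, then add the emission log-probs.
  return tf.reduce_logsumexp(
      log_prev[:, tf.newaxis] + log_transition, axis=0) + log_obs

fwd = tf.foldl(forward_step, observation_probs[1:],
               initializer=log_init + observation_probs[0])
sequence_log_prob = tf.reduce_logsumexp(fwd, axis=-1)    # scalar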
Example no. 11
  def _observation_log_probs(self, observations, mask):
    """Compute and shape tensor of log probs associated with observations."""

    # Let E be the underlying event shape
    #     M the number of steps in the HMM
    #     N the number of states of the HMM
    #
    # Then the incoming observations have shape
    #
    # observations : batch_o [M] E
    #
    # and the mask (if present) has shape
    #
    # mask : batch_m [M]
    #
    # Let this HMM distribution have batch shape batch_d
    # We need to broadcast all three of these batch shapes together
    # into the shape batch.
    #
    # We need to move the step dimension to the first dimension to make
    # them suitable for folding or scanning over.
    #
    # When we call `log_prob` for our observations we need to
    # do this for each state the observation could correspond to.
    # We do this by expanding the dimensions by 1 so we end up with:
    #
    # observations : [M] batch [1] [E]
    #
    # After calling `log_prob` we get
    #
    # observation_log_probs : [M] batch [N]
    #
    # We wish to use `mask` to select from this so we also
    # reshape and broadcast it up to shape
    #
    # mask : [M] batch [N]
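    #
    # Concrete (hypothetical) example: observations with batch_o = [7],
    # M = 5 steps, scalar E, an HMM with batch_d = [] and N = 3 states, and a
    # mask with batch_m = [7] give
    #   observation_log_probs : [5, 7, 3]
    #   mask (after moving the step axis and adding a new axis) : [5, 7, 1]
    # which broadcast together in the tf.where call below.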

    observation_distribution = self.observation_distribution
    underlying_event_rank = ps.size(
        observation_distribution.event_shape_tensor())
    observation_tensor_shape = ps.shape(observations)
    observation_batch_shape = observation_tensor_shape[
        :-1 - underlying_event_rank]
    observation_event_shape = observation_tensor_shape[
        -1 - underlying_event_rank:]

    if mask is not None:
      mask_tensor_shape = ps.shape(mask)
      mask_batch_shape = mask_tensor_shape[:-1]

    batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                             self.batch_shape_tensor())

    if mask is not None:
      batch_shape = tf.broadcast_dynamic_shape(batch_shape,
                                               mask_batch_shape)
    observations = tf.broadcast_to(observations,
                                   ps.concat([batch_shape,
                                              observation_event_shape],
                                             axis=0))
    observation_rank = ps.rank(observations)
    if self._time_varying_observation_distribution:
      observations = tf.expand_dims(
          observations,
          observation_rank - underlying_event_rank)
      observation_log_probs = observation_distribution.log_prob(
          observations)
      observation_log_probs = distribution_util.move_dimension(
          observation_log_probs, -2, 0)
    else:
      observations = distribution_util.move_dimension(
          observations, observation_rank - underlying_event_rank - 1, 0)
      observations = tf.expand_dims(
          observations,
          observation_rank - underlying_event_rank)
      observation_log_probs = observation_distribution.log_prob(
          observations)

    if mask is not None:
      mask = tf.broadcast_to(mask,
                             ps.concat([batch_shape, [self._num_steps]],
                                       axis=0))
      mask = distribution_util.move_dimension(mask, -1, 0)
      observation_log_probs = tf.where(mask[..., tf.newaxis],
                                       tf.zeros_like(observation_log_probs),
                                       observation_log_probs)

    return observation_log_probs
Example no. 12
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
    """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
      actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.debugging.assert_near(tf.linalg.inv(x), inv_x)
  # ==> True
  ```

  """

    with tf.name_scope(name or 'lu_solve'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
        rhs = tf.convert_to_tensor(rhs,
                                   dtype_hint=lower_upper.dtype,
                                   name='rhs')

        assertions = _lu_solve_assertions(lower_upper, perm, rhs,
                                          validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
                rhs = tf.identity(rhs)

        if (tensorshape_util.rank(rhs.shape) == 2
                and tensorshape_util.rank(perm.shape) == 1):
            # Both rhs and perm have scalar batch_shape.
            permuted_rhs = tf.gather(rhs, perm, axis=-2)
        else:
            # Either rhs or perm have non-scalar batch_shape or we can't determine
            # this information statically.
            rhs_shape = tf.shape(rhs)
            broadcast_batch_shape = tf.broadcast_dynamic_shape(
                rhs_shape[:-2],
                tf.shape(perm)[:-1])
            d, m = rhs_shape[-2], rhs_shape[-1]
            rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]],
                                            axis=0)

            # Tile out rhs.
            broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
            broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

            # Tile out perm and add batch indices.
            broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
            broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
            broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
            broadcast_batch_indices = tf.broadcast_to(
                tf.range(broadcast_batch_size)[:, tf.newaxis],
                [broadcast_batch_size, d])
            broadcast_perm = tf.stack(
                [broadcast_batch_indices, broadcast_perm], axis=-1)

            permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
            permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
        return linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower_upper,  # Only upper is accessed.
            linear_operator_util.matrix_triangular_solve_with_broadcast(
                lower, permuted_rhs),
            lower=False)
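A short usage sketch of the vector right-hand-side pattern mentioned in the docstring (toy data; assumes the `tf`/`tfp` imports shown in the docstring example):

A = tf.constant([[4., 3.], [6., 3.]])
b = tf.constant([10., 12.])
lu, perm = tf.linalg.lu(A)
x = tfp.math.lu_solve(lu, perm, rhs=b[..., tf.newaxis])[..., 0]
# x satisfies A @ x ≈ b.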
Example no. 13
 def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(
         tf.shape(
             self._probs if self._logits is None else self._logits)[:-1],
         tf.shape(self.total_count))
Example no. 14
 def _broadcast_inputs(self, inputs):
     shape = tf.broadcast_dynamic_shape(tf.shape(inputs),
                                        self.batch_shape_tensor())
     return tf.broadcast_to(inputs, shape)
Example no. 15
 def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(tf.shape(input=self.total_count),
                                       tf.shape(input=self.probs))
Example no. 16
 def _batch_shape_tensor(self, loc=None, concentration=None):
   return tf.broadcast_dynamic_shape(
       tf.shape(self.loc if loc is None else loc),
       tf.shape(
           self.concentration if concentration is None else concentration))
Example no. 17
 def _batch_shape_tensor(self):
     x = self._probs if self._logits is None else self._logits
     return tf.broadcast_dynamic_shape(tf.shape(self._total_count),
                                       tf.shape(x))
Example no. 18
 def _batch_shape_tensor(self):
   return tf.broadcast_dynamic_shape(
       tf.shape(self.mean_direction)[:-1], tf.shape(self.concentration))
Example no. 19
 def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(
         self.distribution.batch_shape_tensor(),
         tf.shape(input=self.mixture_distribution.logits))[:-1]
Example no. 20
 def _batch_shape_tensor(self, loc=None):
   return tf.broadcast_dynamic_shape(
       tf.shape(self.loc if loc is None else loc),
       tf.broadcast_dynamic_shape(tf.shape(self.atol),
                                  tf.shape(self.rtol)))[:-1]
Example no. 21
 def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(
         super(GeneralizedMatern, self)._batch_shape_tensor(),
         [] if self.df is None else tf.shape(self.df))
Example no. 22
 def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(tf.shape(self.concentration),
                                       tf.shape(self.scale))
Example no. 23
 def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(tf.shape(input=self.loc),
                                       tf.shape(input=self.scale))
Example no. 24
 def _batch_shape_tensor(self):
   return tf.broadcast_dynamic_shape(
       self.kernel.batch_shape_tensor(),
       tf.shape(self.scale_diag)[:-self.kernel.feature_ndims])
Example no. 25
 def _batch_shape_tensor(self, loc=None, scale=None):
     return tf.broadcast_dynamic_shape(
         tf.shape(self.loc if loc is None else loc),
         tf.shape(self.scale if scale is None else scale))
Example no. 26
 def _batch_shape_tensor(self, df=None):
     df = tf.convert_to_tensor(self.df) if df is None else df
     return tf.broadcast_dynamic_shape(tf.shape(df),
                                       self._scale.batch_shape_tensor())
Example no. 27
  def _sample_n(self, n, seed=None):
    loc, scale, low, high = self._loc_scale_low_high()
    batch_shape = self._batch_shape_tensor(
        loc=loc, scale=scale, low=low, high=high)
    sample_and_batch_shape = ps.concat([[n], batch_shape], 0)
    # TODO(b/162522020): Use this behavior unconditionally.
    if (tf.executing_eagerly() or
        not control_flow_util.GraphOrParentsInXlaContext(
            tf1.get_default_graph())):
      return tf.random.stateless_parameterized_truncated_normal(
          shape=sample_and_batch_shape,
          means=loc,
          stddevs=scale,
          minvals=low,
          maxvals=high,
          seed=samplers.sanitize_seed(seed))

    flat_batch_and_sample_shape = tf.stack([tf.reduce_prod(batch_shape), n])
    # In order to be reparameterizable we sample on the truncated_normal of
    # unit variance and mean and scale (but with the standardized
    # truncation bounds).

    @tf.custom_gradient
    def _std_samples_with_gradients(lower, upper):
      """Standard truncated Normal with gradient support for low, high."""
      # Note: Unlike the convention in TFP, parameterized_truncated_normal
      # returns a tensor with the final dimension being the sample dimension.
      std_samples = random_ops.parameterized_truncated_normal(
          shape=flat_batch_and_sample_shape,
          means=0.0,
          stddevs=1.0,
          minvals=lower,
          maxvals=upper,
          dtype=self.dtype,
          seed=seed)

      def grad(dy):
        """Computes a derivative for the min and max parameters.

        This function implements the derivative wrt the truncation bounds, which
        get blocked by the sampler. We use a custom expression for numerical
        stability instead of automatic differentiation on CDF for implicit
        gradients.

        Args:
          dy: output gradients

        Returns:
           The standard normal samples and the gradients wrt the upper
           bound and lower bound.
        """
        # std_samples has an extra dimension (the sample dimension), expand
        # lower and upper so they broadcast along this dimension.
        # See note above regarding parameterized_truncated_normal, the sample
        # dimension is the final dimension.
        lower_broadcast = lower[..., tf.newaxis]
        upper_broadcast = upper[..., tf.newaxis]

        cdf_samples = ((special_math.ndtr(std_samples) -
                        special_math.ndtr(lower_broadcast)) /
                       (special_math.ndtr(upper_broadcast) -
                        special_math.ndtr(lower_broadcast)))

        # tiny, eps are tolerance parameters to ensure we stay away from giving
        # a zero arg to the log CDF expression.

        tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
        eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
        cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

        du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
                    tf.math.log(cdf_samples))
        dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
                    tf.math.log1p(-cdf_samples))
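        # Equivalently, du = phi(upper) / phi(std_samples) * cdf_samples and
        # dl = phi(lower) / phi(std_samples) * (1 - cdf_samples), where phi is
        # the standard normal density. These are d(sample)/d(upper) and
        # d(sample)/d(lower), obtained by differentiating
        # Phi(sample) = Phi(lower) + u * (Phi(upper) - Phi(lower)) with the
        # underlying uniform variate u held fixed.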

        # Reduce the gradient across the samples
        grad_u = tf.reduce_sum(dy * du, axis=-1)
        grad_l = tf.reduce_sum(dy * dl, axis=-1)
        return [grad_l, grad_u]

      return std_samples, grad

    std_low, std_high = self._standardized_low_and_high(
        low=low, high=high, loc=loc, scale=scale)
    low_high_shp = tf.broadcast_dynamic_shape(
        tf.shape(std_low), tf.shape(std_high))
    std_low = tf.broadcast_to(std_low, low_high_shp)
    std_high = tf.broadcast_to(std_high, low_high_shp)

    std_samples = _std_samples_with_gradients(
        tf.reshape(std_low, [-1]), tf.reshape(std_high, [-1]))

    # The returned shape is [flat_batch x n]
    std_samples = tf.transpose(std_samples, perm=[1, 0])

    std_samples = tf.reshape(std_samples, sample_and_batch_shape)
    return std_samples * scale[tf.newaxis] + loc[tf.newaxis]
Example no. 28
  def _parameter_control_dependencies(self, is_init):
    """Validate parameters."""
    bw, bh, kd = None, None, None
    try:
      shape = tf.broadcast_static_shape(self.bin_widths.shape,
                                        self.bin_heights.shape)
    except ValueError as e:
      raise ValueError('`bin_widths`, `bin_heights` must broadcast: {}'.format(
          str(e)))
    bin_sizes_shape = shape
    try:
      shape = tf.broadcast_static_shape(shape[:-1], self.knot_slopes.shape[:-1])
    except ValueError as e:
      raise ValueError(
          '`bin_widths`, `bin_heights`, and `knot_slopes` must broadcast on '
          'batch axes: {}'.format(str(e)))

    assertions = []
    if (tensorshape_util.is_fully_defined(bin_sizes_shape[-1:]) and
        tensorshape_util.is_fully_defined(self.knot_slopes.shape[-1:])):
      if tensorshape_util.rank(self.knot_slopes.shape) > 0:
        num_interior_knots = tensorshape_util.dims(bin_sizes_shape)[-1] - 1
        if tensorshape_util.dims(
            self.knot_slopes.shape)[-1] not in (1, num_interior_knots):
          raise ValueError(
              'Innermost axis of non-scalar `knot_slopes` must broadcast with '
              '{}; got {}.'.format(num_interior_knots, self.knot_slopes.shape))
    elif self.validate_args:
      if is_init != any(
          tensor_util.is_ref(t)
          for t in (self.bin_widths, self.bin_heights, self.knot_slopes)):
        bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
        bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
        kd = _ensure_at_least_1d(self.knot_slopes) if kd is None else kd
        shape = tf.broadcast_dynamic_shape(
            tf.shape((bw + bh)[..., :-1]), tf.shape(kd))
        assertions.append(
            assert_util.assert_greater(
                tf.shape(shape)[0],
                tf.zeros([], dtype=shape.dtype),
                message='`(bin_widths + bin_heights)[..., :-1]` must broadcast '
                'with `knot_slopes` to at least 1-D.'))

    if not self.validate_args:
      assert not assertions
      return assertions

    if (is_init != tensor_util.is_ref(self.bin_widths) or
        is_init != tensor_util.is_ref(self.bin_heights)):
      bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
      bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
      assertions += [
          assert_util.assert_near(
              tf.reduce_sum(bw, axis=-1),
              tf.reduce_sum(bh, axis=-1),
              message='`sum(bin_widths, axis=-1)` must equal '
              '`sum(bin_heights, axis=-1)`.'),
      ]
    if is_init != tensor_util.is_ref(self.bin_widths):
      bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
      assertions += [
          assert_util.assert_positive(
              bw, message='`bin_widths` must be positive.'),
      ]
    if is_init != tensor_util.is_ref(self.bin_heights):
      bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
      assertions += [
          assert_util.assert_positive(
              bh, message='`bin_heights` must be positive.'),
      ]
    if is_init != tensor_util.is_ref(self.knot_slopes):
      kd = _ensure_at_least_1d(self.knot_slopes) if kd is None else kd
      assertions += [
          assert_util.assert_positive(
              kd, message='`knot_slopes` must be positive.'),
      ]
    return assertions
Example no. 29
 def _batch_shape_tensor(self):
     return tf.broadcast_dynamic_shape(tf.shape(input=self.loc),
                                       tf.shape(input=self.concentration))
Example no. 30
    def _chunked_matmul_cgrad(x1, x2, x, *kernel_args):
        """Chunk-at-a-time matrix multiplication and backprop."""
        kernel_args = tf.nest.pack_sequence_as(kernel_args_structure,
                                               kernel_args)
        kernel = kernel_fn(*kernel_args)
        fwd_ax_size = tf.shape(x2)[-kernel.feature_ndims - 1]
        fwd_part_size = fwd_ax_size // num_matmul_parts

        def cond(i, _):
            return i < num_matmul_parts

        def body(i, covx):
            return i + 1, covx + _forward_matmul_one_part(
                kernel, x1, x2, x, fwd_part_size, i)

        result_batch_shape = tf.broadcast_dynamic_shape(
            operator_shape[:-2],
            tf.shape(x)[:-2])
        result_shape = tf.concat(
            [result_batch_shape, [operator_shape[-2],
                                  tf.shape(x)[-1]]],
            axis=0)
        _, covx = tf.while_loop(
            cond,
            body, (tf.constant(0), tf.zeros(result_shape, dtype=x.dtype)),
            back_prop=False,
            parallel_iterations=1)
        covx = covx + _forward_matmul_one_part(
            kernel,
            x1,
            x2,
            x,
            fwd_part_size,
            num_matmul_parts,
            remainder_part_size=fwd_ax_size -
            (num_matmul_parts * fwd_part_size))
        del result_batch_shape, result_shape

        def grad_fn(dcovx):
            """Chunk-at-a-time backprop."""
            # Backward, we partition along the `x1`-defined axis.
            bwd_ax_size = tf.shape(x1)[-kernel.feature_ndims - 1]
            bwd_part_size = bwd_ax_size // num_matmul_parts

            def bw_cond(i, *_):
                return i < num_matmul_parts

            def bw_body(i, dx1, dx2, dx, dkernel_args):
                """tf.while_loop body for backprop."""
                dx1part, dx2part, dxpart, dkernel_argspart = _backward_matmul_one_part(
                    dcovx, kernel_fn, kernel_args, x1, x2, x, bwd_part_size, i)
                dx1, dx2, dx, dkernel_args = tf.nest.pack_sequence_as(
                    (dx1, dx2, dx, dkernel_args),
                    [
                        a + b for a, b in zip(  # pylint: disable=g-complex-comprehension
                            tf.nest.flatten((dx1, dx2, dx, dkernel_args)),
                            tf.nest.flatten((dx1part, dx2part, dxpart,
                                             dkernel_argspart)))
                    ])
                return i + 1, dx1, dx2, dx, dkernel_args

            _, dx1, dx2, dx, dkernel_args = tf.while_loop(
                bw_cond,
                bw_body,
                tf.nest.map_structure(tf.zeros_like,
                                      (0, x1, x2, x, kernel_args)),
                back_prop=False,
                parallel_iterations=1)
            dx1rem, dx2rem, dxrem, dkernel_argsrem = _backward_matmul_one_part(
                dcovx,
                kernel_fn,
                kernel_args,
                x1,
                x2,
                x,
                bwd_part_size,
                num_matmul_parts,
                remainder_part_size=bwd_ax_size -
                (num_matmul_parts * bwd_part_size))
            return tuple(a + b for a, b in zip(
                tf.nest.flatten((dx1, dx2, dx, dkernel_args)),
                tf.nest.flatten((dx1rem, dx2rem, dxrem, dkernel_argsrem))))

        return covx, grad_fn