Example #1
0
 def expand_dims_(x):
     """Reshape `x` so that size-1 dimensions are inserted per `axis`.

     `axis` and `name` are not parameters of this function; they are read as
     free variables from the surrounding scope.

     Args:
       x: Value convertible to a `Tensor`.

     Returns:
       `x` reshaped to its original shape with `size(axis)` additional
       entries equal to 1.
     """
     with tf.name_scope(name or 'expand_dims'):
         x = tf.convert_to_tensor(x, name='x')
         axes = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
         x_rank = prefer_static.rank(x)
         num_axes = prefer_static.size(axes)

         # Count the negative entries, then shift them by the rank of `x`.
         negative_mask = axes < 0
         num_negative = prefer_static.reduce_sum(
             prefer_static.cast(negative_mask, axes.dtype))
         shifted = prefer_static.where(negative_mask, axes + x_rank, axes)

         # After sorting, the first `num_negative` entries form one group and
         # the remainder the other.
         ordered = prefer_static.sort(shifted)
         leading, trailing = prefer_static.split(ordered, [num_negative, -1])

         # A stable argsort of [trailing, 0..rank-1, leading] gives the
         # permutation that interleaves the padded ones with the real dims.
         sort_keys = prefer_static.concat(
             [trailing, prefer_static.range(x_rank), leading], axis=0)
         permutation = prefer_static.argsort(sort_keys, stable=True)

         # Shape of `x` with ones prepended and appended, one per new axis.
         padded_shape = prefer_static.pad(
             prefer_static.shape(x),
             paddings=[[num_axes - num_negative, num_negative]],
             constant_values=1)
         expanded_shape = prefer_static.gather(padded_shape, permutation)
         return tf.reshape(x, expanded_shape)
Example #2
0
 def _forward_event_shape_tensor(self, input_shape, is_inverse=False):
   """Adjust the entries of `input_shape` at `self.axis` by the padding sizes.

   Args:
     input_shape: 1-D integer shape vector.
     is_inverse: When `True` the summed paddings are subtracted instead of
       added.

   Returns:
     A copy of `input_shape` where each entry indexed by
     `self.axis + size(input_shape)` has the sum over the last axis of
     `self.paddings` added (or subtracted).
   """
   rank = ps.size(input_shape)
   # One scatter index per padded axis, as a column vector.
   scatter_indices = ps.reshape(ps.add(self.axis, rank), shape=[-1, 1])
   size_deltas = ps.reduce_sum(self.paddings, axis=-1)
   if is_inverse:
     apply_update = ps.tensor_scatter_nd_sub
   else:
     apply_update = ps.tensor_scatter_nd_add
   return apply_update(ps.identity(input_shape), scatter_indices, size_deltas)
def moments_of_masked_time_series(time_series_tensor, broadcast_mask):
    """Compute mean and variance over the last axis, skipping masked entries.

    Args:
      time_series_tensor: float `Tensor` time series of shape
        `concat([batch_shape, [num_timesteps]])`.
      broadcast_mask: bool `Tensor` of the same shape as
        `time_series_tensor`. Entries that are `True` are excluded.

    Returns:
      mean: float `Tensor` of shape `batch_shape`.
      variance: float `Tensor` of shape `batch_shape`.
    """
    series_dtype = time_series_tensor.dtype
    unmasked_counts = ps.reduce_sum(
        ps.cast(~broadcast_mask, np.int32), axis=-1)
    num_unmasked_entries = ps.cast(unmasked_counts, series_dtype)

    def _mean_over_unmasked(values):
        """Sum `values` with masked entries zeroed; divide by unmasked count."""
        zeroed = tf.where(
            broadcast_mask, tf.zeros([], dtype=series_dtype), values)
        return tf.reduce_sum(zeroed, axis=-1) / num_unmasked_entries

    mean = _mean_over_unmasked(time_series_tensor)
    squared_deviation = (time_series_tensor - mean[..., tf.newaxis])**2
    variance = _mean_over_unmasked(squared_deviation)
    return mean, variance
Example #4
0
def expand_dims(x, axis, name=None):
    """Like `tf.expand_dims` but accepts a vector of axes to expand.

    Args:
      x: Value convertible to a `Tensor`.
      axis: Integer axis or vector of axes; entries may be negative.
      name: Optional name scope. Defaults to 'expand_dims'.

    Returns:
      `x` reshaped to its original shape with `size(axis)` additional
      entries equal to 1.
    """
    with tf.name_scope(name or 'expand_dims'):
        x = tf.convert_to_tensor(x, name='x')
        axes = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
        x_rank = prefer_static.rank(x)
        num_axes = prefer_static.size(axes)

        # Count the negative entries, then shift them by the rank of `x`.
        negative_mask = axes < 0
        num_negative = prefer_static.reduce_sum(
            prefer_static.cast(negative_mask, axes.dtype))
        shifted = prefer_static.where(negative_mask, axes + x_rank, axes)

        # After sorting, the first `num_negative` entries form one group and
        # the remainder the other.
        ordered = prefer_static.sort(shifted)
        leading, trailing = prefer_static.split(ordered, [num_negative, -1])

        # A stable argsort of [trailing, 0..rank-1, leading] gives the
        # permutation that interleaves the padded ones with the real dims.
        sort_keys = prefer_static.concat(
            [trailing, prefer_static.range(x_rank), leading], axis=0)
        permutation = prefer_static.argsort(sort_keys, stable=True)

        # Shape of `x` with ones prepended and appended, one per new axis.
        padded_shape = prefer_static.pad(
            prefer_static.shape(x),
            paddings=[[num_axes - num_negative, num_negative]],
            constant_values=1)
        expanded_shape = prefer_static.gather(padded_shape, permutation)
        return tf.reshape(x, expanded_shape)
Example #5
0
def _sample_bates(total_count, low, high, n, seed=None):
  """Draw `Bates` samples for a whole batch in one vectorized pass.

  Args:
    total_count: (Batches of) counts of `Uniform`s to take means of.  Should
      have integer dtype and already be broadcasted to the batch shape.
    low: (Batches of) lower bounds of the `Uniform` variables to sample.  Should
      be the same floating dtype as `high` and broadcastable to the batch shape.
    high: (Batches of) upper bounds of the `Uniform` variables to sample. Should
      be the same floating dtype as `low` and broadcastable to the batch shape.
    n: `int32` number of samples to generate.
    seed: Random seed to pass to `Uniform` sampler.

  Returns:
    samples: Samples of (batches of) the `Bates` variable.  Will have same dtype
      as `low` and `high`. If the batch shape is `[B1,..., Bn]`, `samples` has
      shape `[n, B1,..., Bn]`.
  """
  # Step 1: draw a `[sum(total_count), n]` block of Uniform(0, 1) values, so
  # every uniform needed by every batch member occupies one row.
  num_uniform_rows = ps.reduce_sum(total_count)
  unit_draws = samplers.uniform(
      ps.concat([[num_uniform_rows], [n]], axis=0),
      minval=0., maxval=1., dtype=low.dtype, seed=seed)

  # Step 2: label each row with the flat batch index it belongs to and
  # average the rows sharing a label.
  flat_counts = tf.reshape(total_count, [-1])
  flat_batch_ids = tf.range(tf.size(flat_counts))
  row_labels = tf.repeat(flat_batch_ids, flat_counts)
  flat_means = tf.math.segment_mean(unit_draws, row_labels)

  # Step 3: restore the batch shape, then rotate the trailing sample axis to
  # the front.
  batch_first_shape = tf.concat([tf.shape(total_count), [n]], axis=0)
  batch_first_means = tf.reshape(flat_means, batch_first_shape)
  dim_order = tf.roll(
      tf.range(tf.rank(batch_first_means)), shift=1, axis=0)
  sample_first_means = tf.transpose(batch_first_means, dim_order)

  # Step 4: affine map from (0, 1) onto (low, high).
  return low + (high - low) * sample_first_means
Example #6
0
 def _calculate_batch_shape(self):
   """Computes the batch shape of the combined distribution.

   Each component in `self.distributions` contributes its static batch shape
   when that is fully defined, otherwise its `batch_shape_tensor()`.

   Returns:
     A shape vector equal to the elementwise maximum over the component batch
     shapes, except at `self._axis`, where the entry is the sum of the
     component sizes.
   """
   component_shapes = []
   for dist in self.distributions:
     if tensorshape_util.is_fully_defined(dist.batch_shape):
       component_shapes.append(dist.batch_shape.as_list())
     else:
       component_shapes.append(dist.batch_shape_tensor())
   stacked_shapes = ps.stack(component_shapes, axis=0)

   # Boolean vector that is True only at the concatenation axis.
   is_concat_axis = ps.cast(
       ps.one_hot(self._axis, ps.shape(stacked_shapes)[1]),
       dtype=tf.bool)
   summed_sizes = ps.reduce_sum(stacked_shapes, axis=0)
   concat_dim_size = ps.cast(summed_sizes[self._axis], dtype=tf.int32)
   max_sizes = ps.reduce_max(stacked_shapes, axis=0)
   return ps.where(is_concat_axis, concat_dim_size, max_sizes)
Example #7
0
    def reduce_fn(operands, inits, axis=None, keepdims=False):
        """Applies `reducer` to the given operands along the given axes.

        `reducer` is read from the surrounding scope.

        Args:
          operands: tuple of tensors, all having the same shape.
          inits: tuple of scalar tensors, with dtypes aligned to those of
            operands.
          axis: The axis or axes to reduce. One of `None`, an `int` or a
            sequence of `int`. `None` is taken to mean "reduce all axes".
          keepdims: When `True`, we do not squeeze away the reduced dims,
            instead returning values with singleton dims in those axes.

        Returns:
          reduced: A tuple of the reduced operands.

        Raises:
          ValueError: If the merged static rank of `operands` is unknown, or
            if `axis` has more than one dimension.
        """
        # Merge the static shapes of all operands; the rank must be known.
        merged_shape = operands[0].shape
        for operand in operands[1:]:
            merged_shape = tensorshape_util.merge_with(
                merged_shape, operand.shape)
        rank = tensorshape_util.rank(merged_shape)
        if rank is None:
            raise ValueError(
                'Rank of at least one of `operands` must be known statically.')

        # Normalize `axis` into a tuple of non-negative Python ints.
        if axis is None:
            axis = np.arange(rank)
        else:
            axis = np.array(axis)
        if axis.ndim > 1:
            raise ValueError(
                '`axis` must be `None`, an `int`, or a sequence of '
                '`int`, but got {}'.format(axis))
        flat_axis = np.reshape(axis, [-1])
        wrapped_axis = np.where(flat_axis < 0, flat_axis + rank, flat_axis)
        axis = tuple(int(ax) for ax in wrapped_axis)

        # Boolean vector of length `rank`, True at every reduced axis.
        reduced_mask = ps.reduce_sum(
            ps.one_hot(axis,
                       depth=rank,
                       on_value=True,
                       off_value=False,
                       dtype=tf.bool),
            axis=0)
        # Fall back to the dynamic shape when the static one has unknowns.
        if tensorshape_util.is_fully_defined(merged_shape):
            input_shape = merged_shape
        else:
            input_shape = tf.shape(operands[0])
        unsqueezed_shape = ps.where(reduced_mask, 1, input_shape)

        result = _variadic_reduce_custom_grad(
            operands, inits, axis, reducer, unsqueezed_shape)

        if not keepdims:
            return result

        def _restore_reduced_dims(t):
            return tf.reshape(t, unsqueezed_shape)

        return tf.nest.map_structure(_restore_reduced_dims, result)
    def preprocess_state(init_state):
      """Initial preprocessing at Stage 0.

      Builds the starting `SMCResults` for `init_state`. The names
      `likelihood_log_prob_fn`, `prior_log_prob_fn`, `make_kernel_fn`,
      `seed_stream` and `max_num_steps` are read from the surrounding scope.

      Args:
        init_state: Python `list` of `Tensor`s holding the initial particles.
          The state dimension is computed from `shape[1:]` of each part.

      Returns:
        An `SMCResults` with `num_steps` set to `max_num_steps`, a scalar zero
        `inverse_temperature`, a zero `log_marginal_likelihood`, and a
        `ParticleInfo` describing the bootstrapped kernel at `init_state`.
      """
      dimension = ps.reduce_sum([
          ps.reduce_prod(ps.shape(x)[1:]) for x in init_state])
      likelihood_log_prob = likelihood_log_prob_fn(*init_state)

      # Default to the optimal for normal distributed targets.
      # TODO(b/152412213): Revisit this default parameter.
      scale_start = (
          tf.constant(2.38 ** 2, dtype=likelihood_log_prob.dtype) /
          tf.constant(dimension, dtype=likelihood_log_prob.dtype))
      # TODO(b/152412213): Enable batch of batches style by using non-scalar
      # inverse_temperature
      inverse_temperature = tf.zeros([], dtype=likelihood_log_prob.dtype)
      # The initial scaling is capped at 1.
      scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(scale_start, 1.)
      kernel = make_kernel_fn(
          _make_tempered_target_log_prob_fn(
              prior_log_prob_fn,
              likelihood_log_prob_fn,
              inverse_temperature),
          init_state,
          scalings,
          seed=seed_stream())
      # Bootstrap on the same state the kernel was built from. Previously this
      # read `current_state` from the enclosing scope, which silently ignored
      # the `init_state` argument.
      pkr = kernel.bootstrap_results(init_state)
      _, kernel_target_log_prob = gather_mh_like_result(pkr)

      particle_info = ParticleInfo(
          log_accept_prob=ps.zeros_like(likelihood_log_prob),
          log_scalings=tf.math.log(scalings),
          tempered_log_prob=kernel_target_log_prob,
          likelihood_log_prob=likelihood_log_prob,
      )

      return SMCResults(
          num_steps=tf.convert_to_tensor(
              max_num_steps, dtype=tf.int32, name='num_steps'),
          inverse_temperature=inverse_temperature,
          log_marginal_likelihood=tf.constant(
              0., dtype=likelihood_log_prob.dtype),
          particle_info=particle_info
      )
def sample_sequential_monte_carlo(
        prior_log_prob_fn,
        likelihood_log_prob_fn,
        current_state,
        max_num_steps=25,
        max_stage=100,
        make_kernel_fn=make_rwmh_kernel_fn,
        tuning_fn=simple_heuristic_tuning,
        make_tempered_target_log_prob_fn=default_make_tempered_target_log_prob_fn,
        ess_threshold_ratio=0.5,
        parallel_iterations=10,
        seed=None,
        name=None):
    """Runs Sequential Monte Carlo to sample from the posterior distribution.

    This function uses an MCMC transition operator (e.g., Hamiltonian Monte
    Carlo) to sample from a series of distributions that slowly interpolates
    between an initial 'prior' distribution:

      `exp(prior_log_prob_fn(x))`

    and the target 'posterior' distribution:

      `exp(prior_log_prob_fn(x) + likelihood_log_prob_fn(x))`,

    by mutating a collection of MC samples (i.e., particles). The approach is
    also known as Particle Filter in some literature. The current
    implementation is largely based on Del Moral et al [1], which adapts the
    tempering sequence (based on the effective sample size) and the scaling of
    the mutation kernel at each stage. In this function the number of steps
    and the scaling for the next mutation are produced by `tuning_fn`.

    Args:
      prior_log_prob_fn: Python callable that returns the log density of the
        prior distribution.
      likelihood_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the likelihood distribution.
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s). The first `r` dimensions
        index particles and batches, where
        `r = rank(likelihood_log_prob_fn(*current_state))`; the first of
        those dimensions is taken as the number of particles.
      max_num_steps: The maximum number of kernel transition steps in one
        mutation of the MC samples. It is also the initial number of steps;
        later values come from `tuning_fn` and are clipped to
        `[2, max_num_steps]`.
      max_stage: Integer upper bound on the number of SMC stages run while
        increasing the inverse temperature from 0 to 1.
      make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
        object. It is called as
        `make_kernel_fn(target_log_prob_fn, state, scalings, seed=...)`, where
        `target_log_prob_fn` is the tempered target created by
        `make_tempered_target_log_prob_fn`.
      tuning_fn: Python `callable` which takes the number of steps, the log
        scaling, and the log acceptance probability from the last mutation
        and outputs the number of steps and log scaling for the next mutation.
      make_tempered_target_log_prob_fn: Python `callable` that takes the
        `prior_log_prob_fn`, `likelihood_log_prob_fn`, and
        `inverse_temperatures` and creates the `target_log_prob_fn` `callable`
        that is passed to `make_kernel_fn`.
      ess_threshold_ratio: Target ratio for effective sample size; the
        threshold used is `num_particles * ess_threshold_ratio` cast to
        `int32`.
      parallel_iterations: The number of iterations allowed to run in parallel.
        It must be a positive integer. See `tf.while_loop` for more details.
      seed: Python integer or TFP seedstream to seed the random number
        generator.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'sample_sequential_monte_carlo').

    Returns:
      n_stage: Number of the mutation stage SMC ran.
      final_state: `Tensor` or Python `list` of `Tensor`s representing the
        final state(s) of the Markov chain(s). The output are the posterior
        samples.
      final_kernel_results: `collections.namedtuple` of internal calculations
        used to advance the chain.

    #### References

    [1] Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. An adaptive sequential
        Monte Carlo method for approximate Bayesian computation.
        _Statistics and Computing_, 22.5(1009-1020), 2012.

    """

    with tf.name_scope(name or 'sample_sequential_monte_carlo'):
        seed_stream = SeedStream(seed, salt='smc_seed')

        # Work on a list of tensors internally; remember whether to unwrap the
        # single-tensor case again before returning.
        unwrap_state_list = not tf.nest.is_nested(current_state)
        if unwrap_state_list:
            current_state = [current_state]
        current_state = [
            tf.convert_to_tensor(s, dtype_hint=tf.float32)
            for s in current_state
        ]

        # Initial preprocessing at Stage 0
        likelihood_log_prob = likelihood_log_prob_fn(*current_state)

        # The state dimension is the total size of each state part's trailing
        # dims, i.e. those beyond the rank of the likelihood.
        likelihood_rank = ps.rank(likelihood_log_prob)
        dimension = ps.reduce_sum([
            ps.reduce_prod(ps.shape(x)[likelihood_rank:])
            for x in current_state
        ])

        # We infer the particle shapes from the resulting likelihood:
        # [num_particles, b1, ..., bN]
        particle_shape = ps.shape(likelihood_log_prob)
        num_particles, batch_shape = particle_shape[0], particle_shape[1:]
        effective_sample_size_threshold = tf.cast(
            num_particles * ess_threshold_ratio, tf.int32)

        # TODO(b/152412213): Revisit this default parameter.
        # Default to the optimal scaling of a random walk kernel for a d-dimensional
        # normal distributed targets: 2.38 ** 2 / d.
        # For more detail see:
        # Roberts GO, Gelman A, Gilks WR. Weak convergence and optimal scaling of
        # random walk Metropolis algorithms. _The annals of applied probability_.
        # 1997;7(1):110-20.
        scale_start = (tf.constant(2.38**2, dtype=likelihood_log_prob.dtype) /
                       tf.constant(dimension, dtype=likelihood_log_prob.dtype))

        # Start at inverse temperature 0 for every batch member; the initial
        # scaling is capped at 1.
        inverse_temperature = tf.zeros(batch_shape,
                                       dtype=likelihood_log_prob.dtype)
        scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(
            scale_start, 1.)
        kernel = make_kernel_fn(make_tempered_target_log_prob_fn(
            prior_log_prob_fn, likelihood_log_prob_fn, inverse_temperature),
                                current_state,
                                scalings,
                                seed=seed_stream)
        pkr = kernel.bootstrap_results(current_state)
        _, kernel_target_log_prob = gather_mh_like_result(pkr)

        particle_info = ParticleInfo(
            log_accept_prob=ps.zeros_like(likelihood_log_prob),
            log_scalings=tf.math.log(scalings),
            tempered_log_prob=kernel_target_log_prob,
            likelihood_log_prob=likelihood_log_prob,
        )

        # Results carried through the outer SMC loop.
        current_pkr = SMCResults(
            num_steps=tf.convert_to_tensor(max_num_steps,
                                           dtype=tf.int32,
                                           name='num_steps'),
            inverse_temperature=inverse_temperature,
            log_marginal_likelihood=tf.zeros_like(inverse_temperature),
            particle_info=particle_info)

        def update_weights_temperature(inverse_temperature,
                                       likelihood_log_prob):
            """Calculate the next inverse temperature and update weights.

            Bisects on the candidate inverse temperature until the effective
            sample size matches the threshold or the bracket is narrower than
            a fixed tolerance.

            Returns:
              A tuple `(marginal_loglike_, new_inverse_temperature,
              log_weights)`.
            """
            # Subtract the per-batch max over particles before scaling.
            likelihood_diff = likelihood_log_prob - tf.reduce_max(
                likelihood_log_prob, axis=0)

            def _body_fn(new_beta, upper_beta, lower_beta, eff_size,
                         log_weights):
                """One iteration of the temperature and weight update."""
                # Midpoint of the current bracket.
                new_beta = (lower_beta + upper_beta) / 2.0
                log_weights = (new_beta -
                               inverse_temperature) * likelihood_diff
                log_weights_norm = tf.math.log_softmax(log_weights, axis=0)
                # exp(-logsumexp(2 * log w)) == 1 / sum(w ** 2) for the
                # normalized weights w, truncated to an integer.
                eff_size = tf.cast(
                    tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm,
                                                     axis=0)), tf.int32)
                # Below the threshold: the midpoint becomes the new upper
                # bound. Otherwise it becomes the new lower bound.
                upper_beta = tf.where(
                    eff_size < effective_sample_size_threshold, new_beta,
                    upper_beta)
                lower_beta = tf.where(
                    eff_size < effective_sample_size_threshold, lower_beta,
                    new_beta)
                return new_beta, upper_beta, lower_beta, eff_size, log_weights

            def _cond_fn(new_beta, upper_beta, lower_beta, eff_size, *_):  # pylint: disable=unused-argument
                # TODO(junpenglao): revisit threshold below to be dtype specific.
                threshold = 1e-6
                return (tf.math.reduce_any(upper_beta - lower_beta > threshold)
                        & tf.math.reduce_any(
                            eff_size != effective_sample_size_threshold))

            # The bracket starts at [inverse_temperature, 2].
            (new_beta, upper_beta, lower_beta, eff_size,
             log_weights) = tf.while_loop(  # pylint: disable=unused-variable
                 cond=_cond_fn,
                 body=_body_fn,
                 loop_vars=(tf.zeros_like(inverse_temperature),
                            tf.fill(ps.shape(inverse_temperature),
                                    tf.constant(2, inverse_temperature.dtype)),
                            inverse_temperature,
                            tf.zeros_like(inverse_temperature, dtype=tf.int32),
                            tf.zeros_like(likelihood_diff)),
                 parallel_iterations=parallel_iterations)

            # Where the search reached or passed 1, use the weights for a
            # jump straight to inverse temperature 1.
            log_weights = tf.where(new_beta < 1., log_weights,
                                   (1. - inverse_temperature) *
                                   likelihood_diff)
            # NOTE(review): this uses the unclipped `new_beta`, which can
            # exceed 1 given the initial upper bound of 2 - confirm intended.
            marginal_loglike_ = reduce_logmeanexp(
                (new_beta - inverse_temperature) * likelihood_log_prob, axis=0)
            new_inverse_temperature = tf.clip_by_value(new_beta, 0., 1.)

            return marginal_loglike_, new_inverse_temperature, log_weights

        def mutate(current_state, log_scalings, num_steps,
                   inverse_temperature):
            """Mutate the state using a Transition kernel.

            Returns:
              A tuple `(next_state, avg_log_accept_prob_per_particle,
              kernel_target_log_prob)`.
            """
            with tf.name_scope('mutate_states'):
                scalings = tf.exp(log_scalings)
                kernel = make_kernel_fn(make_tempered_target_log_prob_fn(
                    prior_log_prob_fn, likelihood_log_prob_fn,
                    inverse_temperature),
                                        current_state,
                                        scalings,
                                        seed=seed_stream)
                pkr = kernel.bootstrap_results(current_state)
                # Only used below for the shape and dtype of the accumulator.
                kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)

                def mutate_onestep(i, state, pkr, log_accept_prob_sum):
                    next_state, next_kernel_results = kernel.one_step(
                        state, pkr)
                    # NOTE(review): the acceptance ratio is read from the
                    # incoming `pkr`, not from `next_kernel_results` -
                    # confirm this is intended.
                    kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)
                    log_accept_prob = tf.minimum(kernel_log_accept_ratio, 0.)
                    log_accept_prob_sum = log_add_exp(log_accept_prob_sum,
                                                      log_accept_prob)
                    return i + 1, next_state, next_kernel_results, log_accept_prob_sum

                (
                    _, next_state, next_kernel_results, log_accept_prob_sum
                ) = tf.while_loop(
                    cond=lambda i, *args: i < num_steps,
                    body=mutate_onestep,
                    loop_vars=(
                        tf.zeros([], dtype=tf.int32),
                        current_state,
                        pkr,
                        # we accumulate the acceptance probability in log space.
                        tf.fill(
                            ps.shape(kernel_log_accept_ratio),
                            tf.constant(-np.inf,
                                        kernel_log_accept_ratio.dtype))),
                    parallel_iterations=parallel_iterations)
                _, kernel_target_log_prob = gather_mh_like_result(
                    next_kernel_results)
                # NOTE(review): the loop body runs `num_steps` times but the
                # accumulated sum is divided by `num_steps + 1` - confirm.
                avg_log_accept_prob_per_particle = log_accept_prob_sum - tf.math.log(
                    tf.cast(num_steps + 1, log_accept_prob_sum.dtype))
                return (next_state, avg_log_accept_prob_per_particle,
                        kernel_target_log_prob)

        # One SMC steps.
        def smc_body_fn(stage, state, smc_kernel_result):
            """Run one stage of SMC with constant temperature.

            The stage picks the next inverse temperature and weights,
            resamples, tunes the kernel, and then mutates the particles.
            """
            (new_marginal, new_inv_temperature,
             log_weights) = update_weights_temperature(
                 smc_kernel_result.inverse_temperature,
                 smc_kernel_result.particle_info.likelihood_log_prob)
            # TODO(b/152412213) Use a tf.scan to better collect debug info.
            if PRINT_DEBUG:
                tf.print(
                    'Stage:', stage, 'Beta:', new_inv_temperature, 'n_steps:',
                    smc_kernel_result.num_steps, 'accept:',
                    tf.exp(
                        reduce_logmeanexp(
                            smc_kernel_result.particle_info.log_accept_prob,
                            axis=0)), 'scaling:',
                    tf.exp(
                        reduce_logmeanexp(
                            smc_kernel_result.particle_info.log_scalings,
                            axis=0)))
            # Resample the state together with its per-particle info.
            (resampled_state,
             resampled_particle_info), _ = resample_particle_and_info(
                 (state, smc_kernel_result.particle_info),
                 log_weights,
                 seed=seed_stream)
            next_num_steps, next_log_scalings = tuning_fn(
                smc_kernel_result.num_steps,
                resampled_particle_info.log_scalings,
                resampled_particle_info.log_accept_prob)
            # Skip tuning at stage 0.
            next_num_steps = tf.where(stage == 0, smc_kernel_result.num_steps,
                                      next_num_steps)
            next_log_scalings = tf.where(stage == 0,
                                         resampled_particle_info.log_scalings,
                                         next_log_scalings)
            # Always take at least 2 and at most `max_num_steps` steps.
            next_num_steps = tf.clip_by_value(next_num_steps, 2, max_num_steps)

            next_state, log_accept_prob, tempered_log_prob = mutate(
                resampled_state, next_log_scalings, next_num_steps,
                new_inv_temperature)
            # The log marginal likelihood accumulates across stages.
            next_pkr = SMCResults(
                num_steps=next_num_steps,
                inverse_temperature=new_inv_temperature,
                log_marginal_likelihood=(
                    new_marginal + smc_kernel_result.log_marginal_likelihood),
                particle_info=ParticleInfo(
                    log_accept_prob=log_accept_prob,
                    log_scalings=next_log_scalings,
                    tempered_log_prob=tempered_log_prob,
                    likelihood_log_prob=likelihood_log_prob_fn(*next_state),
                ))
            return stage + 1, next_state, next_pkr

        # Iterate until every inverse temperature has reached 1 or
        # `max_stage` stages have run.
        (n_stage, final_state, final_kernel_results) = tf.while_loop(
            cond=lambda i, state, pkr: (  # pylint: disable=g-long-lambda
                (i < max_stage) & tf.reduce_any(pkr.inverse_temperature < 1.)),
            body=smc_body_fn,
            loop_vars=(tf.zeros([],
                                dtype=tf.int32), current_state, current_pkr),
            parallel_iterations=parallel_iterations)
        if unwrap_state_list:
            final_state = final_state[0]
        return n_stage, final_state, final_kernel_results
Example #10
0
 def _inverse_event_shape_tensor(self, output_shape):
     """Return `output_shape` with its last entry set to `sum(block_sizes)`.

     Args:
       output_shape: 1-D shape vector.

     Returns:
       All but the last entry of `output_shape`, followed by the sum of
       `self.block_sizes`.
     """
     total_input_size = ps.reduce_sum(self.block_sizes)
     leading_dims = output_shape[:-1]
     last_dim = total_input_size[tf.newaxis]
     return ps.concat([leading_dims, last_dim], -1)
Example #11
0
 def _forward_event_shape_tensor(self, input_shape):
     """Return `input_shape` with its last entry set to the total output size.

     Args:
       input_shape: 1-D shape vector.

     Returns:
       All but the last entry of `input_shape`, followed by the sum of
       `self._output_block_sizes()`.
     """
     total_output_size = ps.reduce_sum(self._output_block_sizes())
     leading_dims = input_shape[:-1]
     last_dim = total_output_size[tf.newaxis]
     return ps.concat([leading_dims, last_dim], -1)
Example #12
0
    def reduce_fn(operands, inits, axis=None, keepdims=False):
        """Applies `reducer` to the given operands along the given axes.

        `reducer` is read from the surrounding scope. The reduction is
        dispatched to `jax.lax.reduce` under `JAX_MODE`, to `_xla_reduce` when
        building a graph inside an XLA context, and to `_variadic_reduce`
        otherwise.

        Args:
          operands: tuple of tensors, all having the same shape.
          inits: tuple of scalar tensors, with dtypes aligned to those of
            operands.
          axis: The axis or axes to reduce. One of `None`, an `int` or a
            sequence of `int`. `None` is taken to mean "reduce all axes".
          keepdims: When `True`, we do not squeeze away the reduced dims,
            instead returning values with singleton dims in those axes.

        Returns:
          reduced: A tuple of the reduced operands.

        Raises:
          ValueError: If the merged static rank of `operands` is unknown, or
            if `axis` has more than one dimension.
        """
        # Merge the static shapes of all operands; the rank must be known.
        merged_shape = operands[0].shape
        for operand in operands[1:]:
            merged_shape = tensorshape_util.merge_with(
                merged_shape, operand.shape)
        rank = tensorshape_util.rank(merged_shape)
        if rank is None:
            raise ValueError(
                'Rank of at least one of `operands` must be known statically.')

        # Normalize `axis` into a tuple of non-negative Python ints.
        if axis is None:
            axis = np.arange(rank)
        else:
            axis = np.array(axis)
        if axis.ndim > 1:
            raise ValueError(
                '`axis` must be `None`, an `int`, or a sequence of '
                '`int`, but got {}'.format(axis))
        flat_axis = np.reshape(axis, [-1])
        wrapped_axis = np.where(flat_axis < 0, flat_axis + rank, flat_axis)
        axis = tuple(int(ax) for ax in wrapped_axis)

        # Choose the backend implementation.
        if JAX_MODE:
            from jax import lax  # pylint: disable=g-import-not-at-top
            result = lax.reduce(operands,
                                init_values=inits,
                                dimensions=axis,
                                computation=reducer)
        else:
            in_xla_graph = (
                not tf.executing_eagerly()
                and control_flow_util.GraphOrParentsInXlaContext(
                    tf1.get_default_graph()))
            if in_xla_graph:
                result = _xla_reduce(operands, inits, axis)
            else:
                result = _variadic_reduce(operands,
                                          init=inits,
                                          axis=axis,
                                          reducer=reducer)

        if not keepdims:
            return result

        # Boolean vector of length `rank`, True at every reduced axis.
        reduced_mask = ps.reduce_sum(
            ps.one_hot(axis,
                       depth=rank,
                       on_value=True,
                       off_value=False,
                       dtype=tf.bool),
            axis=0)
        # Fall back to the dynamic shape when the static one has unknowns.
        if tensorshape_util.is_fully_defined(merged_shape):
            input_shape = merged_shape
        else:
            input_shape = tf.shape(operands[0])
        final_shape = ps.where(reduced_mask, 1, input_shape)

        def _restore_reduced_dims(t):
            return tf.reshape(t, final_shape)

        return tf.nest.map_structure(_restore_reduced_dims, result)