Example No. 1
 def _batch_shape_tensor(self, temperature=None, logits=None):
     param = logits
     if param is None:
         param = self._logits if self._logits is not None else self._probs
     if temperature is None:
         temperature = self.temperature
     return ps.broadcast_shape(ps.shape(temperature), ps.shape(param)[:-1])
Example No. 2
    def _log_prob(self, x):
        scores = tf.convert_to_tensor(self.scores)
        event_size = self._event_size(scores)

        x = tf.cast(x, self.dtype)
        # Broadcast scores or x if need be.
        if (not tensorshape_util.is_fully_defined(x.shape)
                or not tensorshape_util.is_fully_defined(scores.shape)
                or x.shape != scores.shape):
            broadcast_shape = ps.broadcast_shape(ps.shape(scores), ps.shape(x))
            scores = tf.broadcast_to(scores, broadcast_shape)
            x = tf.broadcast_to(x, broadcast_shape)
        scores_shape = ps.shape(scores)[:-1]
        scores_2d = tf.reshape(scores, [-1, event_size])
        x_2d = tf.reshape(x, [-1, event_size])
        # Ensure that these are indices that we can use in a gather.
        if dtype_util.is_floating(x_2d.dtype):
            x_2d = tf.cast(x_2d, tf.int32)

        rearranged_scores = tf.gather(scores_2d, x_2d, batch_dims=1)
        normalization_terms = tf.cumsum(rearranged_scores,
                                        axis=-1,
                                        reverse=True)
        ret = tf.math.reduce_sum(tf.math.log(rearranged_scores /
                                             normalization_terms),
                                 axis=-1)
        # Reshape back to user-supplied batch and sample dims prior to 2D reshape.
        ret = tf.reshape(ret, scores_shape)
        return ret
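The gather / reverse-cumsum structure above mirrors a Plackett-Luce-style factorization: each placed item's score is divided by the total score of the items not yet placed. A minimal NumPy sketch of the same computation for a single ranking (values are illustrative only):

import numpy as np

scores = np.array([3., 1., 2.])        # per-item scores (illustrative)
ranking = np.array([2, 0, 1])          # item 2 placed first, then item 0, then item 1

s = scores[ranking]                    # scores in placement order: [2., 3., 1.]
denoms = np.cumsum(s[::-1])[::-1]      # reverse cumulative sums:   [6., 4., 1.]
log_prob = np.sum(np.log(s / denoms))  # log(2/6) + log(3/4) + log(1/1)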
Example No. 3
    def test_with_broadcast_batch_shape(self, bijector_fn, x_event_ndims=None):
        bijector = bijector_fn()
        if x_event_ndims is None:
            x_event_ndims = bijector.forward_min_event_ndims
        batch_shape = bijector.experimental_batch_shape(
            x_event_ndims=x_event_ndims)
        param_batch_shapes = batch_shape_lib.batch_shape_parts(
            bijector, bijector_x_event_ndims=x_event_ndims)

        new_batch_shape = [4, 2, 1, 1, 1]
        broadcast_bijector = bijector._broadcast_parameters_with_batch_shape(
            new_batch_shape, x_event_ndims)
        broadcast_batch_shape = broadcast_bijector.experimental_batch_shape_tensor(
            x_event_ndims=x_event_ndims)
        self.assertAllEqual(broadcast_batch_shape,
                            ps.broadcast_shape(batch_shape, new_batch_shape))

        # Check that all params have the expected batch shape.
        broadcast_param_batch_shapes = batch_shape_lib.batch_shape_parts(
            broadcast_bijector, bijector_x_event_ndims=x_event_ndims)

        def _maybe_broadcast_param_batch_shape(p, s):
            if isinstance(p,
                          tfb.Invert) and not p.bijector._params_event_ndims():
                return s  # Can't broadcast a bijector that doesn't itself have params.
            return ps.broadcast_shape(s, new_batch_shape)

        expected_broadcast_param_batch_shapes = tf.nest.map_structure(
            _maybe_broadcast_param_batch_shape,
            {param: getattr(bijector, param)
             for param in param_batch_shapes}, param_batch_shapes)
        self.assertAllEqualNested(broadcast_param_batch_shapes,
                                  expected_broadcast_param_batch_shapes)
Example No. 4
 def _log_prob(self, x):
   log_nsphere_surface_area = (
       np.log(2.) + (self.dimension / 2) * np.log(np.pi) -
       tf.math.lgamma(tf.cast(self.dimension / 2., x.dtype)))
   batch_shape = ps.broadcast_shape(
       ps.shape(x)[:-1], self.batch_shape)
   return tf.fill(batch_shape, -log_nsphere_surface_area)
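For reference, the constant above is the log surface area of the unit (n-1)-sphere in R^n, with n = self.dimension, using the standard formula

    A_{n-1} = \frac{2\pi^{n/2}}{\Gamma(n/2)}, \qquad \log A_{n-1} = \log 2 + \tfrac{n}{2}\log\pi - \log\Gamma\!\left(\tfrac{n}{2}\right),

so the density is constant and _log_prob simply fills the broadcast batch shape with -\log A_{n-1}.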
Example No. 5
 def _batch_shape_tensor(self, logits_or_probs=None, total_count=None):
     if logits_or_probs is None:
         logits_or_probs = self._probs if self._logits is None else self._logits
     total_count = self._total_count if total_count is None else total_count
     return prefer_static.broadcast_shape(
         prefer_static.shape(logits_or_probs),
         prefer_static.shape(total_count))
Example No. 6
def _reduce_ldj_ratio(unreduced_ldj_ratio, p, q, input_shape, min_event_ndims,
                      event_ndims):
    """Reduces an LDJ ratio computed with event_ndims=min_event_ndims."""
    # pylint: disable=protected-access
    have_parameter_batch_shape = (p._parameter_batch_shape is not None
                                  and q._parameter_batch_shape is not None)
    if have_parameter_batch_shape:
        parameter_batch_shape = ps.broadcast_shape(p._parameter_batch_shape,
                                                   q._parameter_batch_shape)
    else:
        parameter_batch_shape = None

    reduce_shape, assertions = bijector_lib.ldj_reduction_shape(
        input_shape,
        event_ndims=event_ndims,
        min_event_ndims=min_event_ndims,
        parameter_batch_shape=parameter_batch_shape,
        allow_event_shape_broadcasting=not (p._parts_interact
                                            or q._parts_interact),
        validate_args=p.validate_args or q.validate_args)

    sum_fn = getattr(p, '_sum_fn', getattr(q, '_sum_fn', tf.reduce_sum))
    with tf.control_dependencies(assertions):
        return bijector_lib.reduce_jacobian_det_over_shape(
            unreduced_ldj_ratio, reduce_shape=reduce_shape, sum_fn=sum_fn)
Example No. 7
 def _entropy(self):
     scale = tf.broadcast_to(
         self.scale,
         ps.broadcast_shape(ps.shape(self.scale), ps.shape(self.loc)))
     euler_gamma = tf.constant(np.euler_gamma, self.dtype)
     return 1. + tf.math.log(scale) + euler_gamma * (1. +
                                                     self.concentration)
Example No. 8
  def _sample_n(self, n, seed=None):
    """Gamma sampler.

    Rather than use `tf.random.gamma` (which is as of February 2020 implemented
    in C++ for CPU only), we implement our own gamma sampler in Python, using
    `batched_las_vegas_algorithm` as a substrate. This has the advantage that
    our sampler is XLA compilable.

    If sampling becomes a bottleneck on CPU, one way to gain speed would be to
    consider switching back to the C++ sampler.

    Args:
      n: Number of samples to draw.
      seed: (optional) The random seed.

    Returns:
      n samples from the gamma distribution.
    """
    n = tf.convert_to_tensor(n, name='shape', dtype=tf.int32)
    alpha = tf.convert_to_tensor(self.concentration, name='alpha')
    beta = tf.convert_to_tensor(self.rate, name='beta')
    broadcast_shape = prefer_static.broadcast_shape(
        prefer_static.shape(alpha), prefer_static.shape(beta))
    result_shape = tf.concat([[n], broadcast_shape], axis=0)

    return random_gamma(result_shape, alpha, beta, seed=seed)
Example No. 9
def _batch_gather_with_broadcast(params, indices, axis):
    """Like batch_gather, but broadcasts to the left of axis."""
    # batch_gather assumes...
    #   params.shape =  [A1,...,AN, B1,...,BM]
    #   indices.shape = [A1,...,AN, C]
    # which gives output of shape
    #                   [A1,...,AN, C, B1,...,BM]
    # Here we broadcast dims of each to the left of `axis` in params, and left of
    # the rightmost dim in indices, e.g. we can
    # have
    #   params.shape =  [A1,...,AN, B1,...,BM]
    #   indices.shape = [a1,...,aN, C],
    # where ai broadcasts with Ai.

    # leading_bcast_shape is the broadcast of [A1,...,AN] and [a1,...,aN].
    leading_bcast_shape = ps.broadcast_shape(
        ps.shape_slice(params, np.s_[:axis]),
        ps.shape_slice(indices, np.s_[:-1]))
    params = _broadcast_with(
        params,
        ps.concat((leading_bcast_shape, ps.shape_slice(params, np.s_[axis:])),
                  axis=0))
    indices = _broadcast_with(
        indices,
        ps.concat((leading_bcast_shape, ps.shape_slice(indices, np.s_[-1:])),
                  axis=0))
    return tf.gather(params,
                     indices,
                     batch_dims=tensorshape_util.rank(indices.shape) - 1)
Example No. 10
def random_gamma_with_runtime(shape,
                              concentration,
                              rate=None,
                              log_rate=None,
                              seed=None,
                              log_space=False):
    """Returns both a sample and the id of the implementation-selected runtime."""
    # This method exists chiefly for testing purposes.
    dtype = dtype_util.common_dtype([concentration, rate, log_rate],
                                    tf.float32)
    concentration = tf.convert_to_tensor(concentration, dtype=dtype)
    shape = ps.convert_to_shape_tensor(shape,
                                       dtype_hint=tf.int32,
                                       name='shape')

    if rate is not None and log_rate is not None:
        raise ValueError(
            'At most one of `rate` and `log_rate` may be specified.')
    if rate is not None:
        rate = tf.convert_to_tensor(rate, dtype=dtype)
    if log_rate is not None:
        log_rate = tf.convert_to_tensor(log_rate, dtype=dtype)
    total_shape = ps.concat([
        shape,
        ps.broadcast_shape(ps.shape(concentration),
                           _shape_or_scalar(rate, log_rate))
    ],
                            axis=0)
    seed = samplers.sanitize_seed(seed, salt='random_gamma')
    return _random_gamma_gradient(total_shape, concentration, rate, log_rate,
                                  seed, log_space)
Example No. 11
    def _parameter_control_dependencies(self, is_init):
        if not self.validate_args:
            return []

        sample_shape = tf.concat(
            [self._batch_shape_tensor(),
             self._event_shape_tensor()], axis=0)

        low = None if self._low is None else tf.convert_to_tensor(self._low)
        high = None if self._high is None else tf.convert_to_tensor(self._high)

        assertions = []
        if self._low is not None and is_init != tensor_util.is_ref(self._low):
            low_shape = ps.shape(low)
            broadcast_shape = ps.broadcast_shape(sample_shape, low_shape)
            assertions.extend([
                distribution_util.assert_integer_form(
                    low, message='`low` has non-integer components.'),
                assert_util.assert_equal(
                    tf.reduce_prod(broadcast_shape),
                    tf.reduce_prod(sample_shape),
                    message=('Shape of `low` adds extra batch dimensions to '
                             'sample shape.'))
            ])
        if self._high is not None and is_init != tensor_util.is_ref(
                self._high):
            high_shape = ps.shape(high)
            broadcast_shape = ps.broadcast_shape(sample_shape, high_shape)
            assertions.extend([
                distribution_util.assert_integer_form(
                    high, message='`high` has non-integer components.'),
                assert_util.assert_equal(
                    tf.reduce_prod(broadcast_shape),
                    tf.reduce_prod(sample_shape),
                    message=('Shape of `high` adds extra batch dimensions to '
                             'sample shape.'))
            ])
        if (self._low is not None and self._high is not None
                and (is_init != (tensor_util.is_ref(self._low)
                                 or tensor_util.is_ref(self._high)))):
            assertions.append(
                assert_util.assert_less(
                    low,
                    high,
                    message='`low` must be strictly less than `high`.'))

        return assertions
Example No. 12
    def test_batching(self, input_batch_shape, kernel_batch_shape):
        input_shape = (12, 12, 2)
        filter_shape = (2, 2)
        channels_out = 3
        strides = (1, 1)
        dilations = (1, 1)
        padding = 'SAME'

        x, k = _make_input_and_kernel(self.make_input,
                                      input_batch_shape=input_batch_shape,
                                      input_shape=input_shape,
                                      kernel_batch_shape=kernel_batch_shape,
                                      filter_shape=filter_shape,
                                      channels_out=channels_out,
                                      dtype=self.dtype)

        conv_fn = tfn.util.make_convolution_fn(filter_shape,
                                               rank=2,
                                               strides=strides,
                                               padding=padding,
                                               dilations=dilations,
                                               validate_args=True)
        y_batched = conv_fn(x, k)

        broadcast_batch_shape = ps.broadcast_shape(input_batch_shape,
                                                   kernel_batch_shape)
        broadcasted_input = tf.broadcast_to(
            x, shape=ps.concat([broadcast_batch_shape, input_shape], axis=0))
        broadcasted_kernel = tf.broadcast_to(
            k,
            shape=ps.concat([broadcast_batch_shape,
                             ps.shape(k)[-2:]], axis=0))

        flat_y = tf.reshape(y_batched,
                            shape=ps.pad(ps.shape(y_batched)[-3:],
                                         paddings=[[1, 0]],
                                         constant_values=-1))
        flat_x = tf.reshape(broadcasted_input,
                            shape=ps.pad(input_shape,
                                         paddings=[[1, 0]],
                                         constant_values=-1))
        flat_tf_kernel = tf.reshape(broadcasted_kernel,
                                    shape=ps.concat(
                                        [(-1, ), filter_shape,
                                         (input_shape[-1], channels_out)],
                                        axis=0))

        y_expected = tf.vectorized_map(
            lambda args: tf.nn.conv2d(  # pylint: disable=g-long-lambda
                args[0][tf.newaxis],
                args[1],
                strides=strides,
                padding=padding),
            elems=(flat_x, flat_tf_kernel))

        [y_actual_,
         y_expected_] = self.evaluate([flat_y,
                                       tf.squeeze(y_expected, axis=1)])
        self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0)
Example No. 13
 def _broadcast_params(self):
     lower_upper = tf.convert_to_tensor(self.lower_upper)
     perm = tf.convert_to_tensor(self.permutation)
     shape = ps.broadcast_shape(ps.shape(lower_upper)[:-1], ps.shape(perm))
     lower_upper = tf.broadcast_to(lower_upper,
                                   ps.concat([shape, shape[-1:]], 0))
     perm = tf.broadcast_to(perm, shape)
     return lower_upper, perm
Example No. 14
 def _sample_n(self, n, seed=None):
     broadcast_shape = prefer_static.broadcast_shape(
         prefer_static.shape(self.concentration),
         prefer_static.shape(self.scale))
     return 1. / gamma.random_gamma(sample_shape=tf.concat(
         [[n], broadcast_shape], axis=0),
                                    alpha=self.concentration,
                                    beta=self.scale,
                                    seed=seed)
Example No. 15
 def fn(self, *args, **kwargs):
   val = getattr(self.distribution, fn_name)(*args, **kwargs)
   single_val_shape = self.batch_shape_tensor()
   if n_event_shapes:
     single_val_shape = ps.concat(
         [single_val_shape] + [self.event_shape_tensor()] * n_event_shapes,
         axis=0)
   return tf.broadcast_to(
       val, ps.broadcast_shape(ps.shape(val), single_val_shape))
Example No. 16
 def _cdf(self, x):
     loc = tf.convert_to_tensor(self.loc)
     concentration = tf.convert_to_tensor(self.concentration)
     batch_shape = ps.broadcast_shape(
         self._batch_shape_tensor(loc=loc, concentration=concentration),
         ps.shape(x))
     z = tf.broadcast_to(self._z(x, loc=loc), batch_shape)
     concentration = tf.broadcast_to(concentration, batch_shape)
     return von_mises_cdf(z, concentration)
Example No. 17
 def expand_right_dims(x, broadcast=False):
   """Expand x so it can bcast w/ tensors of output shape."""
   expanded_shape_left = ps.broadcast_shape(
       ps.shape(x)[:-1],
       ps.ones([ps.size(y_ref_shape_left)], dtype=tf.int32))
   expanded_shape = ps.concat(
       (expanded_shape_left, ps.shape(x)[-1:],
        ps.ones([ps.size(y_ref_shape_right)], dtype=tf.int32)),
       axis=0)
   x_expanded = tf.reshape(x, expanded_shape)
   if broadcast:
     broadcast_shape_left = ps.broadcast_shape(
         ps.shape(x)[:-1], y_ref_shape_left)
     broadcast_shape = ps.concat(
         (broadcast_shape_left, ps.shape(x)[-1:], y_ref_shape_right),
         axis=0)
     x_expanded = _broadcast_with(x_expanded, broadcast_shape)
   return x_expanded
Example No. 18
def _cumulative_broadcast_dynamic(event_shape):
  broadcast_shapes = [
      ps.slice(s, begin=[0], size=[ps.size(s)-1]) for s in event_shape]
  cumulative_shapes = [broadcast_shapes[0]]
  for shape in broadcast_shapes[1:]:
    out_shape = ps.broadcast_shape(shape, cumulative_shapes[-1])
    cumulative_shapes.append(out_shape)
  return [
      ps.concat([b, ps.slice(s, begin=[ps.size(s)-1], size=[1])], axis=0)
      for b, s in zip(cumulative_shapes, event_shape)]
Example No. 19
def _broadcast_with(tensor, shape):
    """Like broadcast_to, but allows singletons in the destination shape."""
    res = tf.broadcast_to(tensor, ps.broadcast_shape(ps.shape(tensor), shape))
    # We need this done explicitly because ps.broadcast_shape cannot deal with
    # partially specified shapes.
    tensorshape_util.set_shape(
        res,
        tf.broadcast_static_shape(tensor.shape,
                                  tf.TensorShape(tf.get_static_value(shape))))
    return res
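A small illustration of the docstring's point, with tf.broadcast_dynamic_shape standing in for ps.broadcast_shape so the sketch is self-contained; the shapes are illustrative:

import tensorflow as tf

x = tf.zeros([5, 1])
target = [1, 3]  # has a singleton where x has size 5, so tf.broadcast_to(x, target) would fail

# Broadcasting the shapes first yields the combined shape [5, 3], which is then the target.
result = tf.broadcast_to(x, tf.broadcast_dynamic_shape(tf.shape(x), target))
print(result.shape)  # (5, 3)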
Example No. 20
  def test_dynamic(self):
    if tf.executing_eagerly():
      return

    shape = prefer_static.broadcast_shape(
        tf.convert_to_tensor([3, 2, 1]),
        tf.shape(tf1.placeholder_with_default(np.zeros((1, 5)),
                                              shape=(None, 5))))
    self.assertIsNone(tf.get_static_value(shape))
    self.assertAllEqual([3, 2, 5], self.evaluate(shape))
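A complementary static-input check (hypothetical; not part of the original test), assuming the internal module path tensorflow_probability.python.internal.prefer_static and that fully static inputs yield a statically recoverable result:

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static

shape = prefer_static.broadcast_shape(
    tf.convert_to_tensor([3, 2, 1]), tf.convert_to_tensor([1, 5]))
print(tf.get_static_value(shape))  # expected: [3 2 5]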
Example No. 21
 def _variance(self):
     if self._precision is None:
         precision = self._precision_factor.matmul(self._precision_factor,
                                                   adjoint_arg=True)
     else:
         precision = self._precision
     variance = precision.inverse().diag_part()
     return tf.broadcast_to(
         variance, ps.broadcast_shape(ps.shape(variance),
                                      ps.shape(self.loc)))
Example No. 22
 def test_works_correctly(self, input_size, output_size, kernel_batch_shape,
                          input_batch_shape):
     affine = tfn.Affine(input_size,
                         output_size=output_size,
                         batch_shape=kernel_batch_shape)
     x = tf.ones((input_batch_shape + (input_size, )), dtype=tf.float32)
     y = affine(x)
     self.assertAllEqual(
         y.shape,
         ps.broadcast_shape(kernel_batch_shape,
                            input_batch_shape).concatenate(output_size))
Example No. 23
 def testAffineBatching(self, layer_batch, input_batch):
     dist = jdlayers.Affine(4, 3, dtype=self.dtype)
     layer = dist.sample(layer_batch,
                         seed=test_util.test_seed(sampler_type='stateless'))
     # Validate that we can map the layer.
     layer = tf.nest.map_structure(lambda x: x + 0., layer)
     x = tf.ones(input_batch + [3], dtype=self.dtype)
     y = layer(x)
     self.assertAllEqual(
         list(ps.broadcast_shape(layer_batch, input_batch)) + [4], y.shape)
     self.assertEqual(self.dtype, y.dtype)
Example No. 24
def _random_gamma_noncpu(shape, concentration, rate, seed=None):
    """Sample using XLA-friendly python-based rejection sampler."""
    shape = tf.concat([
        shape,
        prefer_static.broadcast_shape(tf.shape(concentration), tf.shape(rate))
    ],
                      axis=0)
    return random_gamma_rejection(sample_shape=shape,
                                  alpha=concentration,
                                  beta=rate,
                                  seed=seed)
Example No. 25
def random_gamma(shape, concentration, rate, seed=None):
    shape = ps.convert_to_shape_tensor(shape,
                                       dtype_hint=tf.int32,
                                       name='shape')

    total_shape = ps.concat(
        [shape,
         ps.broadcast_shape(ps.shape(concentration), ps.shape(rate))],
        axis=0)
    seed = samplers.sanitize_seed(seed, salt='random_gamma')
    return _random_gamma_gradient(total_shape, concentration, rate, seed)
Example No. 26
 def _cdf(self, x):
     low = tf.convert_to_tensor(self.low)
     high = tf.convert_to_tensor(self.high)
     batch_shape = self.batch_shape
     if not tensorshape_util.is_fully_defined(batch_shape):
         batch_shape = self._batch_shape_tensor(low=low, high=high)
     broadcast_shape = ps.broadcast_shape(ps.shape(x), batch_shape)
     zeros = tf.zeros(broadcast_shape, dtype=self.dtype)
     ones = tf.ones(broadcast_shape, dtype=self.dtype)
     result_if_not_big = tf.where(x < low, zeros, (x - low) /
                                  self._range(low=low, high=high))
     return tf.where(x >= high, ones, result_if_not_big)
Example No. 27
 def _finish_log_prob(self, lp, aux):
   (sample_ndims, extra_sample_ndims, batch_ndims) = aux
   # (1) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has
   #     full sample shape in the sample axes, before we reduce.
   bcast_lp_shape = ps.broadcast_shape(
       ps.shape(lp),
       ps.concat([ps.ones([sample_ndims], tf.int32),
                  ps.reshape(self.sample_shape, shape=[-1]),
                  ps.ones([batch_ndims], tf.int32)], axis=0))
   lp = tf.broadcast_to(lp, bcast_lp_shape)
   # (2) Make the final reduction.
   axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
   return self._sum_fn()(lp, axis=axis)
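A shape walkthrough of steps (1) and (2) with hypothetical values: one caller sample dim of size 2, self.sample_shape == [5], and a batch dim of size 3 (plain TF ops, with tf.reduce_sum standing in for self._sum_fn()):

import tensorflow as tf

sample_ndims, extra_sample_ndims, batch_ndims = 1, 1, 1
lp = tf.zeros([2, 1, 3])             # log-prob not yet broadcast over the sample_shape axis
lp = tf.broadcast_to(lp, [2, 5, 3])  # (1) fill out the sample_shape axis
axis = tf.range(sample_ndims, sample_ndims + extra_sample_ndims)  # [1]
print(tf.reduce_sum(lp, axis=axis).shape)  # (2, 3): caller sample and batch dims survive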
Example No. 28
def _log_gamma_difference_jvp(primals, tangents):
    """Computes JVP for log-gamma-difference (supports JAX custom derivative)."""
    x, y = primals
    dx, dy = tangents
    # TODO(https://github.com/google/jax/issues/3768): eliminate broadcast_to?
    bc_shp = prefer_static.broadcast_shape(prefer_static.shape(dx),
                                           prefer_static.shape(dy))
    dx = tf.broadcast_to(dx, bc_shp)
    dy = tf.broadcast_to(dy, bc_shp)
    # See note above in _log_gamma_difference_bwd.
    px = -tf.math.digamma(x + y)
    py = tf.math.digamma(y) + px
    return _log_gamma_difference_naive_gradient(x, y), px * dx + py * dy
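For reference, assuming the primal computes \log\Gamma(y) - \log\Gamma(x + y), the partial derivatives used above are

    \partial_x\,[\log\Gamma(y) - \log\Gamma(x+y)] = -\psi(x+y), \qquad \partial_y\,[\log\Gamma(y) - \log\Gamma(x+y)] = \psi(y) - \psi(x+y),

where \psi is the digamma function; dx and dy are broadcast to a common shape before forming px * dx + py * dy.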
Example No. 29
def _lbeta_jvp(primals, tangents):
    """Computes JVP for log-beta (supports JAX custom derivative)."""
    x, y = primals
    dx, dy = tangents
    # TODO(https://github.com/google/jax/issues/3768): eliminate broadcast_to?
    bc_shp = prefer_static.broadcast_shape(prefer_static.shape(dx),
                                           prefer_static.shape(dy))
    dx = tf.broadcast_to(dx, bc_shp)
    dy = tf.broadcast_to(dy, bc_shp)
    total_digamma = tf.math.digamma(x + y)
    px = tf.math.digamma(x) - total_digamma
    py = tf.math.digamma(y) - total_digamma
    return _lbeta_naive_gradient(x, y), px * dx + py * dy
Example No. 30
    def _inner_apply(x1, x2):
      order = ps.shape(self.amplitudes)[-1]

      def scan_fn(esp, i):
        s = self.kernel[..., i].apply(
            x1[..., i][..., tf.newaxis],
            x2[..., i][..., tf.newaxis],
            example_ndims=example_ndims)
        next_esp = esp[..., 1:] + s[..., tf.newaxis] * esp[..., :-1]
        # Add the zero-th polynomial.
        next_esp = tf.concat(
            [tf.ones_like(esp[..., 0][..., tf.newaxis]), next_esp], axis=-1)
        return next_esp

      batch_shape = ps.broadcast_shape(
          ps.shape(x1)[:-self.kernel.feature_ndims],
          ps.shape(x2)[:-self.kernel.feature_ndims])

      batch_shape = ps.broadcast_shape(
          batch_shape,
          ps.concat([
              self.batch_shape_tensor(),
              [1] * example_ndims], axis=0))

      initializer = tf.concat(
          [tf.ones(ps.concat([batch_shape, [1]], axis=0),
                   dtype=self.dtype),
           tf.zeros(ps.concat([batch_shape, [order]], axis=0),
                    dtype=self.dtype)], axis=-1)

      esps = tf.scan(
          scan_fn,
          elems=ps.range(0, ps.shape(x1)[-1], dtype=tf.int32),
          parallel_iterations=32,
          initializer=initializer)[-1, ..., 1:]
      amplitudes = util.pad_shape_with_ones(
          self.amplitudes, ndims=example_ndims, start=-2)
      return tf.reduce_sum(esps * tf.math.square(amplitudes), axis=-1)
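The scan implements the standard elementary symmetric polynomial (ESP) recursion, folding in one per-dimension kernel value s_i at a time; writing e_k^{(i)} for the degree-k ESP of s_1, ..., s_i,

    e_k^{(i)} = e_k^{(i-1)} + s_i\,e_{k-1}^{(i-1)}, \qquad e_0^{(i)} = 1,

which is what next_esp = esp[..., 1:] + s[..., tf.newaxis] * esp[..., :-1] computes before re-attaching the constant e_0 term.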