Example #1
def DistanceBetweenCentroidsAndBBoxesFastAndFurious(centroids, bboxes, masks):
    """Computes the distance between centroids and bboxes.

  The distance/loss loosely follows the 'Fast and Furious' paper by Luo et
  al., CVPR'18. This is just one way of calculating the distance; other
  formulations may be added later.

  Args:
    centroids: [..., 4]. x/y/w/h for bboxes.
    bboxes: [..., 4]. ymin/xmin/ymax/xmax for bboxes.
    masks: [...]. masks[i] == 1 means i-th entry (centroids[i] and bboxes[i])
      should be considered in the distance/loss calculation.

  Returns:
    A [...] tensor. i-th value is the distance measure of centroids[i] and
    bboxes[i].
  """
    x, y, w, h = tf.unstack(centroids, axis=-1, num=4)
    # "gt" suffix means 'ground truth'.
    x_gt, y_gt, w_gt, h_gt = tf.unstack(BBoxesToXYWH(bboxes), axis=-1, num=4)

    def Pos(x):
        return tf.maximum(tf.constant(1e-8, x.dtype), x)

    # The following terms are zeros when masks[i] is 0.
    l_x = py_utils.CheckNumerics(masks * (x - x_gt) / Pos(w_gt))
    l_y = py_utils.CheckNumerics(masks * (y - y_gt) / Pos(h_gt))
    s_w = py_utils.CheckNumerics(masks * tf.math.log(Pos(w) / Pos(w_gt)))
    s_h = py_utils.CheckNumerics(masks * tf.math.log(Pos(h) / Pos(h_gt)))
    return (_SmoothL1Norm(l_x) + _SmoothL1Norm(l_y) + _SmoothL1Norm(s_w) +
            _SmoothL1Norm(s_h))
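
Note: the example above relies on two helpers that are not shown here, _SmoothL1Norm and BBoxesToXYWH. A minimal sketch of what they might look like, assuming the usual Huber-style smooth-L1 norm and a ymin/xmin/ymax/xmax box layout (an illustration, not lingvo's actual implementation):

import tensorflow as tf

def _SmoothL1Norm(a):
  """Element-wise smooth L1 (Huber) norm: 0.5*a^2 if |a| < 1, else |a| - 0.5."""
  return tf.where(tf.abs(a) < 1.0, 0.5 * tf.square(a), tf.abs(a) - 0.5)

def BBoxesToXYWH(bboxes):
  """Converts [..., 4] ymin/xmin/ymax/xmax boxes to x/y/w/h centroids."""
  ymin, xmin, ymax, xmax = tf.unstack(bboxes, axis=-1, num=4)
  return tf.stack([(xmin + xmax) / 2.0, (ymin + ymax) / 2.0,
                   xmax - xmin, ymax - ymin], axis=-1)
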
Example #2
  def FProp(self, theta, inputs):
    """Applies batch normalization.

    Uses the implementation in
    github.com/tensorflow/tpu/blob/master/models/official/amoeba_net/network_utils.py#L550

    Args:
      theta: A nested map object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
    p = self.params
    inputs_dtype = inputs.dtype
    inputs = tf.cast(inputs, p.dtype)
    inputs = py_utils.with_dependencies(
        [py_utils.assert_shape_match([tf.shape(inputs)[-1]], [p.dim])], inputs)
    with tf.name_scope(p.name) as scope:
      if p.is_eval:
        outputs = tf.nn.batch_normalization(inputs, theta.moving_mean,
                                            theta.moving_variance,
                                            theta.beta, theta.gamma, p.epsilon)
      else:
        mean, variance = self._Moments(inputs, p.bn_group_size)
        mean = py_utils.CheckNumerics(
            mean, 'mean of {} failed numeric check'.format(scope))
        variance = py_utils.CheckNumerics(
            variance, 'variance of {} failed numeric check'.format(scope))
        outputs = tf.nn.batch_normalization(inputs, mean, variance, theta.beta,
                                            theta.gamma, p.epsilon)
      outputs.set_shape(inputs.get_shape())
      return tf.cast(outputs, inputs_dtype)
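
Note: py_utils.with_dependencies, used above and in several later examples, follows the classic control-dependency-plus-identity pattern. A rough, simplified equivalent (an assumption for readability, not lingvo's exact code):

import tensorflow as tf

def with_dependencies(dependencies, output_tensor):
  """Returns output_tensor, evaluated only after all dependencies have run."""
  with tf.control_dependencies(dependencies):
    return tf.identity(output_tensor)
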
Example #3
    def _Normalize(self, theta, grouped_inputs, group_mean, group_variance):
        p = self.params
        group_mean = py_utils.CheckNumerics(
            group_mean, f'mean of {p.name} failed numeric check.')
        group_variance = py_utils.CheckNumerics(
            group_variance, f'variance of {p.name} failed numeric check.')

        input_shape = py_utils.GetShape(grouped_inputs)
        moment_shape = list(input_shape)
        if p.input_rank == 4:
            moment_shape[2] = 1
            moment_shape[-1] = 1
        else:
            moment_shape[-1] = 1
        if not p.cumulative:
            # If not cumulative, the seqlen dimension is also reduced.
            moment_shape[1] = 1

        group_mean = py_utils.HasShape(group_mean, moment_shape)
        group_variance = py_utils.HasShape(group_variance, moment_shape)
        group_variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(group_variance,
                                          tf.cast(0, group_variance.dtype))
        ], group_variance)

        grouped_inputs = (grouped_inputs - group_mean
                          ) * tf.math.rsqrt(group_variance + self._epsilon)
        # Merges the last two dims.
        grouped_inputs = tf.reshape(grouped_inputs, input_shape[:-2] + [-1])

        # Note: the real gamma to use is 1 + gamma.
        outputs = grouped_inputs * (theta.gamma + 1) + theta.beta
        return outputs
Example #4
 def PostTrainingStepUpdate(self, global_step):
     p = self.params
     counts = self.accumulators.counts.GetValue()
     mean_ss = self.accumulators.mean_ss.GetValue()
     variance_ss = self.accumulators.variance_ss.GetValue()
     mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss,
                                              None)
     decay = tf.convert_to_tensor(1.0 - p.decay, p.dtype)
     with tf.name_scope(p.name) as scope:
         with tf.colocate_with(self.vars.moving_mean):
             mean_update = tf.assign_sub(
                 self.vars.moving_mean,
                 (self.vars.moving_mean - tf.cast(mean, p.dtype)) * decay,
                 name='moving_mean_update')
         with tf.colocate_with(self.vars.moving_variance):
             var_update = tf.assign_sub(
                 self.vars.moving_variance,
                 (self.vars.moving_variance - tf.cast(variance, p.dtype)) *
                 decay,
                 name='moving_variance_update')
         py_utils.CheckNumerics(
             self.vars.moving_mean,
             'moving mean of {} failed numeric check'.format(scope))
         py_utils.CheckNumerics(
             self.vars.moving_variance,
             'moving variance of {} failed numeric check'.format(scope))
     self.accumulators.counts.Reset()
     self.accumulators.mean_ss.Reset()
     self.accumulators.variance_ss.Reset()
     return tf.group(mean_update, var_update)
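
Note: the assign_sub form above is just an exponential moving average written as a subtraction. A small sketch with made-up scalar values shows the equivalence:

decay = 0.99        # p.decay (hypothetical value)
moving_mean = 0.0   # current moving statistic
batch_mean = 2.0    # statistic from the current step

# assign_sub form used above:
moving_mean -= (moving_mean - batch_mean) * (1.0 - decay)
# ...is equivalent to the usual EMA update:
#   moving_mean = decay * moving_mean + (1.0 - decay) * batch_mean
# Both yield 0.02 here.
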
Example #5
    def FProp(self, theta, inputs, paddings=None):
        """Apply group normalization.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor with shape [batch_size, height, width, channel].
      paddings: The paddings tensor with shape [batch_size, height]. Intended to
        be used for sequence processing where `height` is `time`.

    Returns:
      A single tensor with the same shape as 'inputs' after applying group
      normalization, or an (output, output_paddings) pair if input paddings is
      not None.
    """
        p = self.params
        n, h, w, c = tf.unstack(tf.shape(inputs), axis=0, num=4)
        group_size = p.dim // p.num_groups
        num_groups = p.num_groups
        min_group_size = p.min_group_size if p.dim > p.min_group_size else p.dim
        if group_size <= min_group_size:
            group_size = min_group_size
            num_groups = p.dim // group_size

        with tf.name_scope(p.name):
            x = tf.reshape(inputs, [n, h, w, num_groups, group_size])
            if paddings is None:
                counts, means_ss, variance_ss, _, = tf.nn.sufficient_statistics(
                    x, axes=[1, 2, 4], keepdims=True)
                norm_mean, norm_variance = tf.nn.normalize_moments(
                    counts, means_ss, variance_ss, None)
            else:
                expanded_paddings = tf.reshape(paddings, [n, h, 1, 1, 1])
                norm_mean, norm_variance = ComputeMomentsWithPadding(
                    x, expanded_paddings, [1, 2, 4], keepdims=True)

            norm_mean = py_utils.CheckNumerics(
                norm_mean, 'mean of %s failed numeric check' % p.name)
            norm_variance = py_utils.CheckNumerics(
                norm_variance, 'variance of %s failed numeric check' % p.name)

            beta = theta.beta
            gamma = theta.gamma

            with tf.control_dependencies([
                    py_utils.assert_greater_equal(
                        norm_variance, tf.cast(0., norm_variance.dtype)),
                    py_utils.assert_shape_match([n, 1, 1, num_groups, 1],
                                                tf.shape(norm_mean)),
                    py_utils.assert_shape_match([n, 1, 1, num_groups, 1],
                                                tf.shape(norm_variance)),
            ]):
                x = (x - norm_mean) / tf.sqrt(norm_variance + self._epsilon)
                x = tf.reshape(x, [n, h, w, c])
                gn_output = x * gamma + beta
                gn_output = tf.reshape(gn_output, [n, h, w, c])
                if paddings is None:
                    return gn_output
                else:
                    return gn_output, paddings
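
Note: ComputeMomentsWithPadding is referenced above but not shown. A minimal sketch of padded moments over the given axes, assuming paddings are 1.0 at padded positions and omitting the cumulative variant used in Example #16 (an illustration, not the library's implementation):

import tensorflow as tf

def ComputeMomentsWithPadding(x, paddings, reduce_over_dims, keepdims=False):
  """Mean and variance of x over reduce_over_dims, ignoring padded positions."""
  mask = 1.0 - paddings  # 1.0 where the value is real, 0.0 where padded.
  count = tf.maximum(
      tf.reduce_sum(tf.ones_like(x) * mask, axis=reduce_over_dims,
                    keepdims=keepdims), 1.0)
  mean = tf.reduce_sum(x * mask, axis=reduce_over_dims,
                       keepdims=keepdims) / count
  mean_sq = tf.reduce_sum(tf.square(x) * mask, axis=reduce_over_dims,
                          keepdims=keepdims) / count
  return mean, mean_sq - tf.square(mean)
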
Example #6
 def PostTrainingStepUpdate(self, global_step):
     """Updates moving_mean, moving_variance after each training step."""
     p = self.params
     # Get sufficient stats that accumulate over microbatches.
     counts = self.accumulators.counts.GetValue()
     mean_ss = self.accumulators.mean_ss.GetValue()
     variance_ss = self.accumulators.variance_ss.GetValue()
     # Compute batch mean and batch variance from sufficient stats.
     mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss,
                                              None)
     decay = tf.convert_to_tensor(1.0 - p.decay, p.dtype)
     # Update moving_mean, moving_variance from batch mean and batch variance.
     with tf.name_scope(p.name) as scope:
         with tf.colocate_with(self.vars.moving_mean):
             mean_update = tf.assign_sub(
                 self.vars.moving_mean,
                 (self.vars.moving_mean - tf.cast(mean, p.dtype)) * decay,
                 name='moving_mean_update')
         with tf.colocate_with(self.vars.moving_variance):
             var_update = tf.assign_sub(
                 self.vars.moving_variance,
                 (self.vars.moving_variance - tf.cast(variance, p.dtype)) *
                 decay,
                 name='moving_variance_update')
         py_utils.CheckNumerics(
             self.vars.moving_mean,
             'moving mean of {} failed numeric check'.format(scope))
         py_utils.CheckNumerics(
             self.vars.moving_variance,
             'moving variance of {} failed numeric check'.format(scope))
     self.accumulators.counts.Reset()
     self.accumulators.mean_ss.Reset()
     self.accumulators.variance_ss.Reset()
     return tf.group(mean_update, var_update)
Example #7
    def testCheckNumerics(self):
        checked = py_utils.CheckNumerics(
            tf.convert_to_tensor([2.0, 3.0], tf.float32))
        self.assertListEqual([2.0, 3.0], checked.numpy().tolist())

        with self.assertRaisesRegex(tf.errors.InvalidArgumentError, 'NaN'):
            py_utils.CheckNumerics(
                tf.reduce_mean(tf.convert_to_tensor([], tf.float32)))
Example #8
  def testCheckNumerics(self):
    xv = [[1, 2], [3, 4]]
    yv = [10] * 4
    with self.session() as sess:
      x = tf.constant(xv, tf.float32)
      y = tf.constant(yv)
      z = tf.reduce_mean(tf.constant([], tf.float32))
      self.assertAllClose(xv, sess.run(py_utils.CheckNumerics(x)))
      self.assertAllClose(yv, sess.run(py_utils.CheckNumerics(y)))
      actual_xv, actual_yv = sess.run(py_utils.CheckNumerics([x, y]))
      self.assertAllClose(xv, actual_xv)
      self.assertAllClose(yv, actual_yv)
      actual_xv, actual_yv = sess.run(py_utils.CheckNumerics((x, y)))
      self.assertAllClose(xv, actual_xv)
      self.assertAllClose(yv, actual_yv)

      with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, 'NaN'):
        sess.run(py_utils.CheckNumerics(z))
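
Note: for readers who only need the behaviour these tests exercise, py_utils.CheckNumerics can be thought of as a thin wrapper around tf.debugging.check_numerics that also accepts lists and tuples. A rough sketch inferred from the tests above (an assumption, not lingvo's actual code):

import tensorflow as tf

def CheckNumerics(inp, message='Tensor failed numeric check.'):
  """Raises InvalidArgumentError if any element of inp is NaN or Inf."""
  if isinstance(inp, (list, tuple)):
    return type(inp)(CheckNumerics(x, message) for x in inp)
  if inp.dtype.is_floating:
    return tf.debugging.check_numerics(inp, message)
  return inp  # Non-float tensors (e.g. y in Example #8) pass through unchecked.
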
Example #9
 def _FPropMetrics(self, metrics):
   # Adds stats about the input batch.
   metrics['num_samples_in_batch'] = (tf.convert_to_tensor(
       self.input_generator.InputBatchSize()), tf.constant(1.0))
   # Generates summaries.
   for name, (value, weight) in six.iteritems(metrics):
     self.AddEvalMetric(name, value, weight)
   # Loss.
   self._loss, self._num_predicts = metrics['loss']
   self._loss = py_utils.CheckNumerics(self._loss)
Example #10
 def _FPropResult(self, metrics, per_example):
   # Adds stats about the input batch.
   metrics['num_samples_in_batch'] = (tf.convert_to_tensor(
       self.input_generator.InputBatchSize()), tf.constant(1.0))
   # Generates summaries.
   for name, (value, weight) in six.iteritems(metrics):
     self.AddEvalMetric(name, value, weight)
   per_example = self.FilterPerExampleTensors(per_example)
   for name, value in six.iteritems(per_example):
     self.AddPerExampleTensor(name, value)
   # Loss.
   self._loss, self._num_predictions = metrics['loss']
   self._loss = py_utils.CheckNumerics(self._loss)
   summary_utils.scalar('num_predictions', self._num_predictions)
Example #11
        def _PaddedMaxFn(inp):
            """Apply padded max using reduce_max with paddings replaced by neginf."""
            # Replace all padded features with -inf.
            neginf_padding = tf.where(inp.padding > 0, -np.inf * inp.padding,
                                      inp.padding)
            features = inp.features + neginf_padding[..., tf.newaxis]
            features = tf.reduce_max(features, axis=-2)

            # Replace the features of all-padded points with zeros. If all points
            # in a batch are padded, then reduce_min over the padding will be 1.
            # We set the features to zero so that we don't get any downstream
            # issues with NaNs. Note that inf * 0 = NaN.
            all_padded = tf.cast(tf.reduce_min(inp.padding, axis=-1), tf.bool)
            all_padded = tf.broadcast_to(all_padded[..., tf.newaxis],
                                         py_utils.GetShape(features))
            features = tf.where(all_padded, tf.zeros_like(features), features)
            return py_utils.CheckNumerics(features)
Example #12
    def _PaddedMeanFn(inp):
      """Apply padded mean using reduce_sum and dividing by # real points."""
      # Replace all padded features with 0 by masking the padded features out.
      mask = 1 - inp.padding
      features = inp.features * mask[..., tf.newaxis]
      features = tf.reduce_sum(features, axis=-2)
      num_real_points = tf.reduce_sum(mask, axis=-1, keep_dims=True)
      # Prevent the divisor of our padded mean from ever being 0, so that
      # the gradient flowing back through this op doesn't give us NaNs.
      num_real_points = tf.maximum(num_real_points, 1)
      features = features / num_real_points

      # Replace the features of all-padded points with zeros. If all points in
      # a batch are padded, then num_real_points will be zero. We set the
      # features to zero so that we don't get any downstream issues with NaNs.
      # Note that inf * 0 = NaN.
      all_padded = tf.equal(num_real_points, 0.)
      all_padded = tf.broadcast_to(all_padded, py_utils.GetShape(features))
      features = tf.where(all_padded, tf.zeros_like(features), features)
      return py_utils.CheckNumerics(features)
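
Note: a toy run of the padded-mean idea with made-up shapes makes the masking explicit (an illustration only; _PaddedMeanFn itself expects a NestedMap inp with features and padding fields):

import tensorflow as tf

# One example with two points; the second point is padded out.
features = tf.constant([[[1.0, 2.0], [100.0, 200.0]]])  # [batch=1, points=2, dims=2]
padding = tf.constant([[0.0, 1.0]])                      # 1.0 marks a padded point.

mask = 1.0 - padding
summed = tf.reduce_sum(features * mask[..., tf.newaxis], axis=-2)
num_real = tf.maximum(tf.reduce_sum(mask, axis=-1, keepdims=True), 1.0)
mean = summed / num_real  # -> [[1.0, 2.0]]: the padded point is ignored.
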
Example #13
    def ComputeAndUpdateMoments(self, theta, inputs, paddings=None):
        """Computes moments and updates state.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].
      paddings: The paddings tensor.  Shaped [..., 1], with the same rank as the
        input tensor.

    Returns:
      Tuple of (mean, variance, beta, gamma).
    """
        p = self.params
        if paddings is None:
            paddings = self._GetDefaultPaddings(inputs)
        inputs = py_utils.with_dependencies([
            py_utils.assert_shape_match([tf.shape(paddings)[-1]], [1]),
        ], inputs)
        with tf.name_scope(p.name):
            if self.do_eval:
                # The mean and variance used for normalization.
                norm_mean, norm_variance = self._moving_mean, self._moving_variance
            else:
                mean, variance = self._Moments(
                    inputs, 1.0 - paddings, p.enable_cross_replica_sum_on_tpu)

                py_utils.UpdateBatchNormVars(self._moving_mean, mean,
                                             self._decay)
                py_utils.UpdateBatchNormVars(self._moving_variance, variance,
                                             self._decay)
                # Add some summaries for visualization.
                summary_utils.histogram('%s_mean' % p.name,
                                        tf.cast(mean, tf.float32))
                summary_utils.histogram('%s_variance' % p.name,
                                        tf.cast(variance, tf.float32))
                summary_utils.histogram('%s_moving_mean' % p.name,
                                        tf.cast(self._moving_mean, tf.float32))
                summary_utils.histogram(
                    '%s_moving_variance' % p.name,
                    tf.cast(self._moving_variance, tf.float32))
                summary_utils.histogram(
                    '%s_mean_diff' % p.name,
                    tf.cast(mean - self._moving_mean, tf.float32))
                summary_utils.histogram(
                    '%s_variance_diff' % p.name,
                    tf.cast(variance - self._moving_variance, tf.float32))
                if p.use_moving_avg_in_training:
                    # Use the global statistics for normalization.
                    # Control dependencies on mean and variance make sure
                    # moving_mean and moving_variance are updated every training step.
                    norm_mean = py_utils.with_dependencies([mean],
                                                           self._moving_mean)
                    norm_variance = py_utils.with_dependencies(
                        [variance], self._moving_variance)
                else:
                    # Use the batch statistics for normalization.
                    norm_mean = mean
                    norm_variance = variance

            norm_mean = py_utils.CheckNumerics(
                norm_mean, 'mean of %s failed numeric check' % p.name)
            norm_variance = py_utils.CheckNumerics(
                norm_variance, 'variance of %s failed numeric check' % p.name)

            if p.use_moving_avg_in_training:
                beta = 0.0
                gamma = 1.0
            else:
                beta = theta.beta
                gamma = theta.gamma
            return norm_mean, norm_variance, beta, gamma
Example #14
    def LocalizationResiduals(self, anchor_bboxes, assigned_gt_bboxes):
        """Computes the anchor residuals for every bbox.

    For a given bbox, compute residuals in the following way:

      Let ``anchor_bbox = (x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a)``
      and ``assigned_gt_bbox = (x_gt, y_gt, z_gt, dx_gt, dy_gt, dz_gt, phi_gt)``

      Define ``diagonal_xy = sqrt(dx_a^2 + dy_a^2)``

      Then the corresponding residuals are given by::

        x_residual = (x_gt - x_a) / (diagonal_xy)
        y_residual = (y_gt - y_a) / (diagonal_xy)
        z_residual = (z_gt - z_a) / (dz_a)

        dx_residual = log(dx_gt / dx_a)
        dy_residual = log(dy_gt / dy_a)
        dz_residual = log(dz_gt / dz_a)

        phi_residual = phi_gt - phi_a

      The normalization for x and y residuals by the diagonal was first
      proposed by [1]. Intuitively, this reflects that objects can usually
      move freely in the x-y plane, including diagonally. On the other hand,
      moving in the z-axis (up and down) can be considered orthogonal to x-y.

      For phi_residual, one way to frame the loss is with
      SmoothL1(sine(phi_residual - phi_predicted)).
      The use of sine to wrap the phi residual was proposed by [2]. This
      stems from the observation that bboxes at phi and phi + pi are the same
      bbox, fully overlapping in 3D space, except that the direction is
      different. Note that the use of sine makes this residual invariant to
      direction when a symmetric loss like SmoothL1 is used. In
      ResidualsToBBoxes, we ensure that the phi predicted is between [0, pi).

    The Huber (SmoothL1) loss can then be applied to the delta between these
    target residuals and the model predicted residuals.

    [1] VoxelNet: End-to-End Learning for Point Cloud Based 3D Object Detection
        https://arxiv.org/abs/1711.06396

    [2] SECOND: Sparsely Embedded Convolutional Detection
        https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf

    Args:
      anchor_bboxes: A tf.float32 tensor where [..., :7] contains (x, y, z, dx,
        dy, dz, phi), the parameters of each anchor bbox.
      assigned_gt_bboxes: tf.float32 of the same shape as anchor_bboxes
        containing the corresponding assigned ground-truth bboxes.

    Returns:
      A tf.float32 tensor of the same shape as anchor_bboxes with target
      residuals for every corresponding bbox.
    """
        anchor_bboxes_shape = py_utils.GetShape(anchor_bboxes)
        anchor_bboxes = py_utils.with_dependencies(
            [py_utils.assert_equal(anchor_bboxes_shape[-1], 7)], anchor_bboxes)
        assigned_gt_bboxes = py_utils.HasShape(assigned_gt_bboxes,
                                               anchor_bboxes_shape)

        x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a = tf.unstack(anchor_bboxes,
                                                            num=7,
                                                            axis=-1)
        x_gt, y_gt, z_gt, dx_gt, dy_gt, dz_gt, phi_gt = tf.unstack(
            assigned_gt_bboxes, num=7, axis=-1)

        diagonal_xy = tf.sqrt(tf.square(dx_a) + tf.square(dy_a))

        # The anchor dimensions are usually hard-coded params given to the input
        # generator and should not be 0. We use CheckNumerics to ensure that is
        # the case.
        x_residual = py_utils.CheckNumerics((x_gt - x_a) / diagonal_xy)
        y_residual = py_utils.CheckNumerics((y_gt - y_a) / diagonal_xy)
        z_residual = py_utils.CheckNumerics((z_gt - z_a) / dz_a)

        dx_residual = py_utils.CheckNumerics(tf.log(dx_gt / dx_a))
        dy_residual = py_utils.CheckNumerics(tf.log(dy_gt / dy_a))
        dz_residual = py_utils.CheckNumerics(tf.log(dz_gt / dz_a))

        phi_residual = phi_gt - phi_a

        return tf.stack([
            x_residual,
            y_residual,
            z_residual,
            dx_residual,
            dy_residual,
            dz_residual,
            phi_residual,
        ],
                        axis=-1)  # pyformat: disable
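
Note: the docstring above suggests wrapping the heading residual in a sine before applying the smooth-L1 loss. A minimal sketch of that idea (names here are illustrative; _SmoothL1Norm is the same Huber-style helper sketched after Example #1):

import tensorflow as tf

def _SmoothL1Norm(a):
  """Element-wise smooth L1 (Huber) norm: 0.5*a^2 if |a| < 1, else |a| - 0.5."""
  return tf.where(tf.abs(a) < 1.0, 0.5 * tf.square(a), tf.abs(a) - 0.5)

def RotationLoss(phi_residual_gt, phi_residual_pred):
  """Smooth-L1 on the sine of the heading error, following the note on [2] above."""
  # The sine wraps the error so that a prediction off by pi (a flipped box)
  # incurs no penalty under this symmetric loss.
  return _SmoothL1Norm(tf.sin(phi_residual_gt - phi_residual_pred))
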
Example #15
    def FProp(self, theta):
        """Forward propagation.

    This default `FProp` implementation here supports batch splitting in
    synchronous and asynchronous training when sub-classes implement
    `FPropTower`.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.

    Returns:
      A dict mapping metric names to (value, weight) pairs. One of the keys
      should be 'loss' and its value should be a (loss, num_predictions) pair.
    """
        p = self.params
        cluster = cluster_factory.Current()

        with tf.name_scope('fprop'), tf.name_scope(p.name):
            all_fprop_metrics = []

            if py_utils.use_tpu():
                batch = self.input_generator.CreateTpuFeeds()
                with tf.name_scope('tower_0_0'):
                    dec_metrics = self.FPropTower(theta, batch)
                all_fprop_metrics.append(dec_metrics)
            else:
                # Splits the input batch on the input device.
                num_splits = cluster.num_splits_per_client
                with tf.device(cluster.input_device):
                    batches = self.input_generator.SplitInputBatch(num_splits)
                    assert num_splits == len(batches)

                # dev_list_per_replica[i][j] is the i-th worker's j-th device.
                dev_list_per_replica = cluster.available_devices.tolist()

                # Asserts the invariant relating the total number of splits to
                # the number of splits per worker.
                splits_per_replica = cluster.num_splits_per_replica
                assert num_splits == splits_per_replica * len(
                    dev_list_per_replica)

                for w_id, w_devs in enumerate(dev_list_per_replica):
                    # Make local copy of the vars, shard on devices for this worker.
                    theta_local = py_utils.CreateLocalTheta(theta,
                                                            w_devs,
                                                            label='worker %d' %
                                                            w_id)

                    for s_id in range(splits_per_replica):
                        # s_id-th split for the w_id-th worker.
                        split_id = splits_per_replica * w_id + s_id
                        with py_utils.ModelSplit(split_id):
                            with tf.device(
                                    cluster.WorkerDeviceInModelSplit(0)):
                                with tf.name_scope('tower_%d_%d' %
                                                   (w_id, s_id)):
                                    batch = self.input_generator.PreprocessInputBatch(
                                        batches[split_id])
                                    dec_metrics = self.FPropTower(
                                        theta_local, batch)
                        all_fprop_metrics.append(dec_metrics)

            metrics = py_utils.WeightedAvgOfMetrics(all_fprop_metrics)

        # Adds stats about the input batch.
        metrics['num_samples_in_batch'] = (tf.convert_to_tensor(
            self.input_generator.InputBatchSize()), tf.constant(1.0))
        # Generates summaries.
        for name, (value, weight) in six.iteritems(metrics):
            self.AddEvalMetric(name, value, weight)

        # Loss.
        self._loss, self._num_predicts = metrics['loss']
        self._loss = py_utils.CheckNumerics(self._loss)

        return metrics
Example #16
    def FProp(self, theta, inputs, paddings=None):
        """Apply group normalization.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor with shape [batch_size, height, width, channel].
      paddings: The paddings tensor with shape [batch_size, height]. Intended to
        be used for sequence processing where `height` is `time`.

    Returns:
      A single tensor with the same shape as 'inputs' after applying group
      normalization, or an (output, output_paddings) pair if input paddings is
      not None.
    """
        p = self.params
        inputs = py_utils.with_dependencies([
            py_utils.assert_greater_equal(py_utils.GetRank(inputs),
                                          p.input_rank)
        ], inputs)

        min_group_size = min(p.min_group_size, p.dim)
        group_size = max(p.dim // p.num_groups, min_group_size)
        num_groups = p.dim // group_size

        input_shape = py_utils.GetShape(inputs)
        with tf.name_scope(p.name):
            x = tf.reshape(inputs, input_shape[:-1] + [num_groups, group_size])
            expanded_rank = p.input_rank + 1
            all_dims = list(range(expanded_rank))
            if paddings is None:
                # Skip d0, d[-2]
                axes = all_dims[1:-2] + all_dims[-1:]
                counts, means_ss, variance_ss, _, = tf.nn.sufficient_statistics(
                    x, axes=axes, keepdims=True)
                norm_mean, norm_variance = tf.nn.normalize_moments(
                    counts, means_ss, variance_ss, None)
            else:
                expanded_paddings = tf.reshape(
                    paddings, input_shape[:2] + [1] * (expanded_rank - 2))
                # skip the batching and group dim
                if p.cumulative:
                    # Skip d0, d1 and d[-2]
                    reduce_over_dims = all_dims[2:-2] + all_dims[-1:]
                    norm_mean, norm_variance = ComputeMomentsWithPadding(
                        x,
                        expanded_paddings,
                        reduce_over_dims=reduce_over_dims,
                        cumulative_axis=1,
                        keepdims=True)
                else:
                    # Skip d0, d[-2]
                    reduce_over_dims = all_dims[1:-2] + all_dims[-1:]
                    norm_mean, norm_variance = ComputeMomentsWithPadding(
                        x, expanded_paddings, reduce_over_dims, keepdims=True)

            norm_mean = py_utils.CheckNumerics(
                norm_mean, 'mean of %s failed numeric check' % p.name)
            norm_variance = py_utils.CheckNumerics(
                norm_variance, 'variance of %s failed numeric check' % p.name)

            beta = theta.beta
            gamma = theta.gamma
            n = input_shape[0]
            t = input_shape[1] if p.cumulative else 1
            norm_shape = [n, t, 1, num_groups, 1
                          ] if p.input_rank == 4 else [n, t, num_groups, 1]
            with tf.control_dependencies([
                    py_utils.assert_greater_equal(
                        norm_variance, tf.cast(0., norm_variance.dtype)),
                    py_utils.assert_shape_match(norm_shape,
                                                tf.shape(norm_mean)),
                    py_utils.assert_shape_match(norm_shape,
                                                tf.shape(norm_variance)),
            ]):
                x = (x - norm_mean) / tf.sqrt(norm_variance + self._epsilon)
                x = tf.reshape(x, input_shape)
                gn_output = x * gamma + beta
                gn_output = tf.reshape(gn_output, input_shape)
                if paddings is None:
                    return gn_output
                else:
                    return gn_output, paddings