Example #1
0
def generator(z,
              progress,
              num_filters_fn,
              resolution_schedule,
              num_blocks=None,
              kernel_size=3,
              colors=3,
              to_rgb_activation=None,
              scope='progressive_gan_generator',
              reuse=None):
    """Builds the generator network of the progressive GAN model.

    Block 1 turns the flattened latent vector into a `start_h x start_w`
    feature map. Every following block upscales the previous feature map by
    `resolution_schedule.scale_base` and applies two convolutions. The output
    of each block is converted to RGB, upscaled by
    `resolution_schedule.scale_factor(block_id)` and weighted by
    `_generator_alpha(block_id, progress)`; the sum of the weighted images is
    the prediction.

    Args:
      z: A `Tensor` of latent vectors. The first dimension must be the batch
        size.
      progress: A scalar float `Tensor` of training progress.
      num_filters_fn: A function that maps `block_id` to the number of filters
        of that block.
      resolution_schedule: An object of `ResolutionSchedule`.
      num_blocks: An integer number of blocks. `None` (the default) means the
        maximum number of blocks, i.e.
        `resolution_schedule.num_resolutions`.
      kernel_size: An integer convolution kernel size.
      colors: Number of output color channels. Defaults to 3.
      to_rgb_activation: Activation function applied by the to-RGB
        convolution.
      scope: A string or variable scope.
      reuse: Whether to reuse `scope`. Defaults to None, which means to
        inherit the reuse option of the parent scope.

    Returns:
      A tuple `(predictions, end_points)`: the `Tensor` of model output and a
      dictionary of model end points (`latent_vector`, `upscaled_rgb_<id>`,
      `alpha_<id>` and `predictions`).
    """
    if num_blocks is None:
        num_blocks = resolution_schedule.num_resolutions

    start_h, start_w = resolution_schedule.start_resolutions
    final_h, final_w = resolution_schedule.final_resolutions

    def _leaky_relu_pixel_norm(features):
        # Activation shared by every feature convolution of the generator.
        return layers.pixel_norm(tf.nn.leaky_relu(features))

    def _conv2d(scope, x, kernel_size, filters, padding='SAME'):
        # Feature convolution followed by leaky ReLU and pixel norm.
        return layers.custom_conv2d(x=x,
                                    filters=filters,
                                    kernel_size=kernel_size,
                                    padding=padding,
                                    activation=_leaky_relu_pixel_norm,
                                    he_initializer_slope=0.0,
                                    scope=scope)

    def _to_rgb(x):
        # 1x1 convolution that maps features to `colors` channels.
        return layers.custom_conv2d(x=x,
                                    filters=colors,
                                    kernel_size=1,
                                    padding='SAME',
                                    activation=to_rgb_activation,
                                    scope='to_rgb')

    end_points = {}

    with tf.compat.v1.variable_scope(scope, reuse=reuse):
        with tf.compat.v1.name_scope('input'):
            latent = tf.compat.v1.layers.flatten(z)
            end_points['latent_vector'] = latent

        with tf.compat.v1.variable_scope(block_name(1)):
            # Insert two singleton spatial axes after the batch axis.
            net = tf.expand_dims(tf.expand_dims(latent, 1), 1)
            net = layers.pixel_norm(net)
            # Zero-pad the 1 x 1 image by (start_h - 1) rows and (start_w - 1)
            # columns on each side, i.e. to (2 * start_h - 1) x
            # (2 * start_w - 1), so that the VALID convolution below outputs
            # start_h x start_w x num_filters_fn(1).
            pad_h = start_h - 1
            pad_w = start_w - 1
            net = tf.pad(tensor=net,
                         paddings=[[0, 0], [pad_h, pad_h], [pad_w, pad_w],
                                   [0, 0]])
            net = _conv2d('conv0', net, (start_h, start_w), num_filters_fn(1),
                          'VALID')
            net = _conv2d('conv1', net, kernel_size, num_filters_fn(1))
            block_features = [net]

        for block_id in range(2, num_blocks + 1):
            with tf.compat.v1.variable_scope(block_name(block_id)):
                net = layers.upscale(net, resolution_schedule.scale_base)
                net = _conv2d('conv0', net, kernel_size,
                              num_filters_fn(block_id))
                net = _conv2d('conv1', net, kernel_size,
                              num_filters_fn(block_id))
                block_features.append(net)

        weighted_images = []
        block_ids = range(1, num_blocks + 1)
        for block_id, features in zip(block_ids, block_features):
            with tf.compat.v1.variable_scope(block_name(block_id)):
                rgb = _to_rgb(features)
                rgb = layers.upscale(
                    rgb, resolution_schedule.scale_factor(block_id))
                end_points['upscaled_rgb_{}'.format(block_id)] = rgb

                # alpha_i is used to replace lod_select. Note sum(alpha_i) is
                # guaranteed to be 1.
                alpha = _generator_alpha(block_id, progress)
                end_points['alpha_{}'.format(block_id)] = alpha

                weighted_images.append(rgb * alpha)

        predictions = tf.add_n(weighted_images)
        # Attach the static output shape to the summed image.
        predictions.set_shape(
            [dimension_value(z.shape[0]), final_h, final_w, colors])
        end_points['predictions'] = predictions

    return predictions, end_points
Example #2
0
def generator_fn(noise, mode):
  """Passes `noise` through one dense layer of width `noise.shape[1]`.

  Args:
    noise: A `Tensor`; its second dimension sets the number of output units.
    mode: Unused.

  Returns:
    The output of `tf.compat.v1.layers.dense` applied to `noise`.
  """
  del mode  # Unused.
  num_units = dimension_value(noise.shape[1])
  return tf.compat.v1.layers.dense(noise, num_units)
Example #3
0
def batch_norm(inputs,
               is_training,
               conditional_class_labels=None,
               axis=-1,
               variance_epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer=tf.compat.v1.initializers.zeros(),
               gamma_initializer=tf.compat.v1.initializers.ones(),
               batch_axis=0,
               name='batch_norm'):
    """Adds Batch Norm or Conditional Batch Norm.

    Creates the optional `beta` / `gamma` variables and delegates the actual
    normalization to `standardize_batch`.

    Args:
      inputs: Tensor of inputs (e.g. images).
      is_training: Whether or not the layer is in training mode. Forwarded
        unchanged to `standardize_batch`.
      conditional_class_labels: If `None`, this layer is vanilla Batch
        Normalization. If not, it is a tensor of one-hot labels - same first
        dimension as inputs, and the layer is Conditional Batch Normalization
        with normalization constants determined by the class (see
        https://arxiv.org/pdf/1610.07629.pdf for more detail).
      axis: Integer, the axis that should be normalized (typically the
        features axis). For instance, after a `Convolution2D` layer with
        `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
      variance_epsilon: A small float number to avoid dividing by 0.
      center: If True, add offset of `beta` to normalized tensor. If False,
        `beta` is ignored (no variable is created and no offset is applied).
      scale: If True, multiply by `gamma`. If False, `gamma` is not used (no
        variable is created and no scaling is applied). When the next layer
        is linear (also e.g. `nn.relu`), this can be disabled since the
        scaling can be done by the next layer.
      beta_initializer: Initializer for the beta weight.
      gamma_initializer: Initializer for the gamma weight.
      batch_axis: The axis of the batch dimension.
      name: String name to be used for scoping.

    Returns:
      Output tensor.
    """
    with tf.compat.v1.variable_scope(name,
                                     values=[inputs],
                                     reuse=tf.compat.v1.AUTO_REUSE):
        # Determine the variable shape: 1 everywhere except the normalized
        # axis.
        var_shape = [1] * inputs.shape.rank
        var_shape[axis] = contrib.dimension_value(inputs.shape[axis])

        # In the conditional case there is one set of parameters per class,
        # stored along `batch_axis` and selected per example with `tf.gather`.
        labels = None
        if conditional_class_labels is not None:
            num_categories = contrib.dimension_value(
                conditional_class_labels.shape[-1])
            var_shape[batch_axis] = num_categories
            labels = tf.math.argmax(input=conditional_class_labels,
                                    axis=1)  # to integer

        # Both parameters default to None so that `center=False` or
        # `scale=False` simply skips the offset / scale instead of hitting an
        # unbound local variable below.
        beta, gamma = None, None
        if center:
            beta = tf.compat.v1.get_variable('beta',
                                             var_shape,
                                             initializer=beta_initializer)
            if labels is not None:
                beta = tf.gather(beta, labels)
        if scale:
            gamma = tf.compat.v1.get_variable('gamma',
                                              var_shape,
                                              initializer=gamma_initializer)
            if labels is not None:
                gamma = tf.gather(gamma, labels)

        outputs = standardize_batch(inputs,
                                    is_training=is_training,
                                    epsilon=variance_epsilon,
                                    offset=beta,
                                    scale=gamma)
        outputs.set_shape(inputs.shape)
        return outputs
Example #4
0
def standardize_batch(inputs,
                      is_training,
                      offset=None,
                      scale=None,
                      decay=0.999,
                      epsilon=1e-3,
                      data_format='NHWC',
                      use_moving_averages=True,
                      use_cross_replica_mean=None):
    """Adds TPU-enabled batch normalization layer.

    Details on Batch Normalization can be found in 'Batch Normalization:
    Accelerating Deep Network Training by Reducing Internal Covariate Shift',
    Ioffe S. and Szegedy C. 2015 [http://arxiv.org/abs/1502.03167].

    Note #1: This method computes the batch statistic across all TPU replicas,
    thus simulating the true batch norm in the distributed setting. If one
    wants to avoid the cross-replica communication set
    use_cross_replica_mean=False.

    Note #2: When is_training is True the moving_mean and moving_variance need
    to be updated in each training step. By default, the update_ops are placed
    in `tf.GraphKeys.UPDATE_OPS` and they need to be added as a dependency to
    the `train_op`. For example:

      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      if update_ops:
        updates = tf.group(*update_ops)
        total_loss = control_flow_ops.with_dependencies([updates], total_loss)

    Note #3: Reasonable values for `decay` are close to 1.0, typically in the
    multiple-nines range: 0.999, 0.99, 0.9, etc. Lower the `decay` value
    (trying `decay`=0.9) if model experiences reasonably good training
    performance but poor validation and/or test performance.

    Args:
      inputs: A tensor with 2 or 4 dimensions, where the first dimension is
        `batch_size`. The channel dimension is the last one if `data_format`
        is `NHWC` and the second one if `data_format` is `NCHW`; statistics
        are reduced over all other dimensions.
      is_training: Whether or not the layer is in training mode. In training
        mode it would accumulate the statistics of the moments into the
        `moving_mean` and `moving_variance` using an exponential moving
        average with the given `decay`. When is_training=False, these
        variables are not updated, and the precomputed values are used
        verbatim.
      offset: An offset `Tensor`, often denoted `beta` in equations, or
        None. If present, will be added to the normalized tensor.
      scale: A scale `Tensor`, often denoted `gamma` in equations, or
        `None`. If present, the scale is applied to the normalized tensor.
      decay: Decay for the moving averages. See notes above for reasonable
        values.
      epsilon: Small float added to variance to avoid dividing by zero.
      data_format: Input data format. NHWC or NCHW.
      use_moving_averages: If True keep moving averages of mean and variance
        that are used during inference. Otherwise use accumulators.
      use_cross_replica_mean: If True add operations that compute batch norm
        statistics across all TPU cores. These ops are not compatible with
        other platforms. The default (None) will only add the operations if
        running on TPU.

    Returns:
      The normalized tensor with the same type and shape as `inputs`.

    Raises:
      ValueError: If `data_format` is neither `NCHW` nor `NHWC`, if the rank
        of `inputs` is undefined or not 2 or 4, or if the channel dimension
        of `inputs` is not statically known.
    """
    if data_format not in {'NCHW', 'NHWC'}:
        raise ValueError(
            'Invalid data_format {}. Allowed: NCHW, NHWC.'.format(data_format))
    if use_cross_replica_mean is None:
        # Default to global batch norm only on TPUs.
        use_cross_replica_mean = (
            contrib.tpu_function.get_tpu_context().number_of_shards
            is not None)
        logging.debug('Automatically determined use_cross_replica_mean=%s.',
                      use_cross_replica_mean)

    inputs = tf.convert_to_tensor(value=inputs)
    inputs_dtype = inputs.dtype
    inputs_shape = inputs.get_shape()

    # Validate the rank before touching individual dimensions: indexing a
    # shape of unknown rank yields an unknown dimension, which would otherwise
    # be reported as a missing channel dimension and hide the real problem.
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
        raise ValueError('Inputs %s has undefined rank' % inputs.name)
    if inputs_rank not in [2, 4]:
        raise ValueError('Inputs %s has unsupported rank.'
                         ' Expected 2 or 4 but got %d' %
                         (inputs.name, inputs_rank))

    # The channel dimension is the second one for NCHW and the last one for
    # NHWC (both coincide for 2-D inputs).
    channels_index = 1 if data_format == 'NCHW' else -1
    num_channels = contrib.dimension_value(inputs_shape[channels_index])
    if num_channels is None:
        raise ValueError('`C` dimension must be known but is None')

    # Bring 2-D inputs into 4-D format.
    if inputs_rank == 2:
        new_shape = [-1, 1, 1, num_channels]
        if data_format == 'NCHW':
            new_shape = [-1, num_channels, 1, 1]
        inputs = tf.reshape(inputs, new_shape)

    # Execute a distributed batch normalization. Statistics are computed in
    # float32 and the result is cast back to the input dtype at the end.
    axis = 1 if data_format == 'NCHW' else 3
    inputs = tf.cast(inputs, tf.float32)
    reduction_axes = [i for i in range(4) if i != axis]
    if use_cross_replica_mean:
        mean, variance = cross_replica_moments(inputs, reduction_axes)
    else:
        counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
            inputs, reduction_axes, keepdims=False)
        mean, variance = tf.nn.normalize_moments(counts,
                                                 mean_ss,
                                                 variance_ss,
                                                 shift=None)

    if use_moving_averages:
        mean, variance = moving_moments_for_inference(mean=mean,
                                                      variance=variance,
                                                      is_training=is_training,
                                                      decay=decay)
    else:
        mean, variance = accumulated_moments_for_inference(
            mean=mean, variance=variance, is_training=is_training)

    # NOTE(review): with keepdims=False the statistics above are 1-D. For 4-D
    # NCHW inputs they presumably need to be reshaped to [1, C, 1, 1] before
    # broadcasting against `inputs` - confirm against the moments helpers.
    outputs = tf.nn.batch_normalization(inputs,
                                        mean=mean,
                                        variance=variance,
                                        offset=offset,
                                        scale=scale,
                                        variance_epsilon=epsilon)
    outputs = tf.cast(outputs, inputs_dtype)

    # Bring 2-D inputs back into 2-D format.
    if inputs_rank == 2:
        outputs = tf.reshape(outputs, [-1] + inputs_shape[1:].as_list())
    outputs.set_shape(inputs_shape)
    return outputs
Example #5
0
def wasserstein_gradient_penalty(
        real_data,
        generated_data,
        generator_inputs,
        discriminator_fn,
        discriminator_scope,
        epsilon=1e-10,
        target=1.0,
        one_sided=False,
        weights=1.0,
        scope=None,
        loss_collection=tf.compat.v1.GraphKeys.LOSSES,
        reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
        add_summaries=False):
    """Computes the gradient penalty for the Wasserstein discriminator loss.

    The discriminator is evaluated on random per-example interpolations
    between `real_data` and `generated_data`, and the squared deviation of its
    gradient norm (divided by `target`) from 1 is penalized.

    See `Improved Training of Wasserstein GANs`
    (https://arxiv.org/abs/1704.00028) for more details.

    Args:
      real_data: Real data.
      generated_data: Output of the generator.
      generator_inputs: Exact argument to pass to the generator, which is used
        as optional conditioning to the discriminator.
      discriminator_fn: A discriminator function that conforms to TF-GAN API.
      discriminator_scope: If not `None`, reuse discriminators from this
        scope.
      epsilon: A small positive number added for numerical stability when
        computing the gradient norm.
      target: Optional Python number or `Tensor` indicating the target value
        of gradient norm. Defaults to 1.0.
      one_sided: If `True`, penalty proposed in
        https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
      weights: Optional `Tensor` whose rank is either 0, or the same rank as
        `real_data` and `generated_data`, and must be broadcastable to
        them (i.e., all dimensions must be either `1`, or the same as the
        corresponding dimension).
      scope: The scope for the operations performed in computing the loss.
      loss_collection: collection to which this loss will be added.
      reduction: A `tf.losses.Reduction` to apply to loss.
      add_summaries: Whether or not to add summaries for the loss.

    Returns:
      A loss Tensor. The shape depends on `reduction`.

    Raises:
      ValueError: If the rank of data Tensors is unknown.
      RuntimeError: If TensorFlow is executing eagerly.
    """
    if tf.executing_eagerly():
        raise RuntimeError("Can't use `tf.gradient` when executing eagerly.")

    with tf.compat.v1.name_scope(scope, 'wasserstein_gradient_penalty',
                                 (real_data, generated_data)) as loss_scope:
        real_data = tf.convert_to_tensor(value=real_data)
        generated_data = tf.convert_to_tensor(value=generated_data)
        if real_data.shape.ndims is None:
            raise ValueError("`real_data` can't have unknown rank.")
        if generated_data.shape.ndims is None:
            raise ValueError("`generated_data` can't have unknown rank.")

        # Draw one uniform coefficient per example and broadcast it over all
        # remaining axes to interpolate between real and generated data.
        differences = generated_data - real_data
        static_batch_size = contrib.dimension_value(differences.shape.dims[0])
        batch_size = static_batch_size or tf.shape(input=differences)[0]
        trailing_ones = [1] * (differences.shape.ndims - 1)
        alpha = tf.random.uniform(shape=[batch_size] + trailing_ones)
        interpolates = real_data + (alpha * differences)

        # Clear scope so update ops are added properly.
        with tf.name_scope(''):
            # Reuse variables if variables already exists.
            with tf.compat.v1.variable_scope(discriminator_scope,
                                             'gpenalty_dscope',
                                             reuse=tf.compat.v1.AUTO_REUSE):
                disc_interpolates = discriminator_fn(interpolates,
                                                     generator_inputs)

        # ACGAN case: disc outputs more than one tensor.
        if isinstance(disc_interpolates, tuple):
            disc_interpolates = disc_interpolates[0]

        gradients = tf.gradients(ys=disc_interpolates, xs=interpolates)[0]
        non_batch_axes = list(range(1, gradients.shape.ndims))
        gradient_squares = tf.reduce_sum(input_tensor=tf.square(gradients),
                                         axis=non_batch_axes)
        # Propagate shape information, if possible.
        if isinstance(batch_size, int):
            remaining_dims = gradient_squares.shape.as_list()[1:]
            gradient_squares.set_shape([batch_size] + remaining_dims)

        # For numerical stability, add epsilon to the sum before taking the
        # square root. Note tf.norm does not add epsilon.
        slopes = tf.sqrt(gradient_squares + epsilon)
        penalties = slopes / target - 1.0
        if one_sided:
            penalties = tf.maximum(0., penalties)
        penalty = tf.compat.v1.losses.compute_weighted_loss(
            tf.square(penalties),
            weights,
            scope=loss_scope,
            loss_collection=loss_collection,
            reduction=reduction)

        if add_summaries:
            tf.compat.v1.summary.scalar('gradient_penalty_loss', penalty)

        return penalty
Example #6
0
def instance_norm(inputs,
                  center=True,
                  scale=True,
                  epsilon=1e-6,
                  activation_fn=None,
                  param_initializers=None,
                  reuse=None,
                  outputs_collections=None,
                  trainable=True,
                  data_format=DATA_FORMAT_NHWC,
                  scope=None):
    """Functional interface for the instance normalization layer.

    Reference: https://arxiv.org/abs/1607.08022.

      "Instance Normalization: The Missing Ingredient for Fast Stylization"
      Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky

    Args:
      inputs: A tensor with 2 or more dimensions, where the first dimension
        has `batch_size`. Moments are computed over every axis except the
        first one and the channels axis, which is the last dimension if
        `data_format` is `NHWC` and the second dimension if `data_format` is
        `NCHW`.
      center: If True, add offset of `beta` to normalized tensor. If False,
        `beta` is ignored.
      scale: If True, multiply by `gamma`. If False, `gamma` is not used.
        When the next layer is linear (also e.g. `tf.nn.relu`), this can be
        disabled since the scaling can be done by the next layer.
      epsilon: Small float added to variance to avoid dividing by zero.
      activation_fn: Activation function, default set to None to skip it and
        maintain a linear activation.
      param_initializers: Optional dict of initializers; the keys `'beta'`
        and `'gamma'` are looked up.
      reuse: Whether or not the layer and its variables should be reused. To
        be able to reuse the layer scope must be given.
      outputs_collections: Collections to add the outputs.
      trainable: If `True` also add variables to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
      data_format: A string. `NHWC` (default) and `NCHW` are supported.
      scope: Optional scope for `variable_scope`.

    Returns:
      A `Tensor` representing the output of the operation.

    Raises:
      ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
      ValueError: If the rank of `inputs` is undefined.
      ValueError: If rank or channels dimension of `inputs` is undefined.
    """
    inputs = tf.convert_to_tensor(value=inputs)
    inputs_shape = inputs.shape
    inputs_rank = inputs.shape.ndims

    if inputs_rank is None:
        raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
        raise ValueError('data_format has to be either NCHW or NHWC.')

    with tf.compat.v1.variable_scope(scope,
                                     'InstanceNorm', [inputs],
                                     reuse=reuse):
        if data_format == DATA_FORMAT_NCHW:
            channels_axis = 1
            # For NCHW format, rather than relying on implicit broadcasting,
            # the params are explicitly reshaped to this shape before they are
            # used in the normalization.
            num_channels = contrib.dimension_value(inputs_shape[1])
            params_shape_broadcast = [1, num_channels] + [1] * (inputs_rank -
                                                                2)
        else:
            channels_axis = inputs_rank - 1
            params_shape_broadcast = None

        # Moments are taken over every axis except batch and channels.
        moments_axes = list(range(inputs_rank))
        del moments_axes[channels_axis], moments_axes[0]

        params_shape = inputs_shape[channels_axis:channels_axis + 1]
        if not params_shape.is_fully_defined():
            raise ValueError('Inputs %s has undefined channels dimension %s.' %
                             (inputs.name, params_shape))

        dtype = inputs.dtype.base_dtype
        if param_initializers is None:
            param_initializers = {}

        def _make_param(param_name, default_initializer):
            # Creates one per-channel variable, reshaped for NCHW broadcast.
            initializer = param_initializers.get(param_name,
                                                 default_initializer)
            param = tf.compat.v1.get_variable(name=param_name,
                                              shape=params_shape,
                                              dtype=dtype,
                                              initializer=initializer,
                                              trainable=trainable)
            if params_shape_broadcast:
                param = tf.reshape(param, params_shape_broadcast)
            return param

        # Allocate parameters for the beta and gamma of the normalization.
        beta = None
        if center:
            beta = _make_param('beta', tf.compat.v1.initializers.zeros())
        gamma = None
        if scale:
            gamma = _make_param('gamma', tf.compat.v1.initializers.ones())

        # Calculate the moments (instance activations).
        mean, variance = tf.nn.moments(x=inputs,
                                       axes=moments_axes,
                                       keepdims=True)

        # Compute instance normalization.
        outputs = tf.nn.batch_normalization(inputs,
                                            mean=mean,
                                            variance=variance,
                                            offset=beta,
                                            scale=gamma,
                                            variance_epsilon=epsilon,
                                            name='instancenorm')
        if activation_fn is not None:
            outputs = activation_fn(outputs)

        if outputs_collections:
            tf.compat.v1.add_to_collections(outputs_collections, outputs)

        return outputs
Example #7
0
def group_norm(inputs,
               groups=32,
               channels_axis=-1,
               reduction_axes=(-3, -2),
               center=True,
               scale=True,
               epsilon=1e-6,
               activation_fn=None,
               param_initializers=None,
               reuse=None,
               outputs_collections=None,
               trainable=True,
               scope=None,
               mean_close_to_zero=False):
    """Functional interface for the group normalization layer.

    Reference: https://arxiv.org/abs/1803.08494.

      "Group Normalization", Yuxin Wu, Kaiming He

    Args:
      inputs: A Tensor with at least 2 dimensions one which is channels. The
        channels dimension and every dimension listed in `reduction_axes`
        must be statically defined; other dimensions may be dynamic.
      groups: Integer. Divide the channels into this number of groups over
        which normalization statistics are computed. This number must be
        positive and commensurate with the number of channels in `inputs`.
      channels_axis: An integer. Specifies index of channels axis which will
        be broken into `groups`, each of which whose statistics will be
        computed across. Must be mutually exclusive with `reduction_axes`.
        Preferred usage is to specify negative integers to be agnostic as to
        whether a batch dimension is included.
      reduction_axes: Tuple of integers. Specifies dimensions over which
        statistics will be accumulated. Must be mutually exclusive with
        `channels_axis`. Statistics will not be accumulated across axes not
        specified in `reduction_axes` nor `channel_axis`. Preferred usage is
        to specify negative integers to be agnostic to whether a batch
        dimension is included.

        Some sample usage cases:
          NHWC format: channels_axis=-1, reduction_axes=[-3, -2]
          NCHW format: channels_axis=-3, reduction_axes=[-2, -1]

      center: If True, add offset of `beta` to normalized tensor. If False,
        `beta` is ignored.
      scale: If True, multiply by `gamma`. If False, `gamma` is not used.
        When the next layer is linear (also e.g. `tf.nn.relu`), this can be
        disabled since the scaling can be done by the next layer.
      epsilon: Small float added to variance to avoid dividing by zero.
      activation_fn: Activation function, default set to None to skip it and
        maintain a linear activation.
      param_initializers: Optional dict of initializers; the keys `'beta'`
        and `'gamma'` are looked up.
      reuse: Whether or not the layer and its variables should be reused. To
        be able to reuse the layer scope must be given.
      outputs_collections: Collections to add the outputs.
      trainable: If `True` also add variables to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
      scope: Optional scope for `variable_scope`.
      mean_close_to_zero: The mean of `input` before ReLU will be close to
        zero when batch size >= 4k for Resnet-50 on TPU. If `True`, use
        `tf.nn.sufficient_statistics` and `tf.nn.normalize_moments` to
        calculate the variance. This is the same behavior as `fused` equals
        `True` in batch normalization. If `False`, use `tf.nn.moments` to
        calculate the variance. When `mean` is close to zero, like 1e-4, use
        `mean` to calculate the variance may have poor result due to repeated
        roundoff error and denormalization in `mean`. When `mean` is large,
        like 1e2, sum(`input`^2) is so large that only the high-order digits
        of the elements are being accumulated. Thus, use
        sum(`input` - `mean`)^2/n to calculate the variance has better
        accuracy compared to (sum(`input`^2)/n - `mean`^2) when `mean` is
        large.

    Returns:
      A `Tensor` representing the output of the operation.

    Raises:
      ValueError: If the rank of `inputs` is undefined.
      ValueError: If rank or channels dimension of `inputs` is undefined.
      ValueError: If number of groups is not positive or not commensurate
        with number of channels.
      ValueError: If reduction_axes or channels_axis are out of bounds.
      ValueError: If reduction_axes are not mutually exclusive with
        channels_axis.
    """
    # TODO(shlens): Support partially defined shapes for the inputs.
    inputs = tf.convert_to_tensor(value=inputs)

    ndims = inputs.shape.ndims
    if ndims is None:
        raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    # Valid axis indices are -ndims .. ndims - 1; anything else would either
    # fail later with an IndexError or silently wrap around.
    if not -ndims <= channels_axis < ndims:
        raise ValueError('Axis is out of bounds.')

    # Use dynamic shape for not fully defined dimensions in the inputs.
    dynamic_shape = tf.shape(input=inputs)
    input_shape_list = []
    for i, dim in enumerate(inputs.shape):
        static_dim = contrib.dimension_value(dim)
        input_shape_list.append(
            dynamic_shape[i] if static_dim is None else static_dim)

    # Standardize the channels_axis to be positive and identify # of channels.
    if channels_axis < 0:
        channels_axis += ndims
    channels = contrib.dimension_value(inputs.shape[channels_axis])

    if channels is None:
        raise ValueError('Inputs %s has undefined channel dimension: %d.' %
                         (inputs.name, channels_axis))

    # Standardize the reduction_axes to be positive.
    reduction_axes = [a + ndims if a < 0 else a for a in reduction_axes]

    for a in reduction_axes:
        # After standardization a valid axis lies in [0, ndims).
        if a < 0 or a >= ndims:
            raise ValueError('Axis is out of bounds.')
        if contrib.dimension_value(inputs.shape[a]) is None:
            raise ValueError('Inputs %s has undefined dimensions %d.' %
                             (inputs.name, a))
        if channels_axis == a:
            raise ValueError('reduction_axis must be mutually exclusive '
                             'with channels_axis')
    # Reject non-positive group counts up front so that the modulo below
    # cannot divide by zero.
    if groups < 1 or groups > channels:
        raise ValueError('Invalid groups %d for %d channels.' %
                         (groups, channels))
    if channels % groups != 0:
        raise ValueError('%d channels is not commensurate with %d groups.' %
                         (channels, groups))

    # Determine axes before channels. Some examples of common image formats:
    #  'NCHW': before = [N], after = [HW]
    #  'NHWC': before = [NHW], after = []
    axes_before_channels = input_shape_list[:channels_axis]
    axes_after_channels = input_shape_list[channels_axis + 1:]

    # Manually broadcast the parameters to conform to the number of groups.
    params_shape_broadcast = ([1] * len(axes_before_channels) +
                              [groups, channels // groups] +
                              [1] * len(axes_after_channels))

    # Reshape the input by the group within the channel dimension.
    inputs_shape = (axes_before_channels + [groups, channels // groups] +
                    axes_after_channels)
    inputs = tf.reshape(inputs, inputs_shape)

    # Determine the dimensions across which moments are calculated. The
    # channels axis was split into [groups, channels // groups], so the
    # within-group axis sits at channels_axis + 1 and every axis after the
    # original channels axis shifts by one.
    moments_axes = [channels_axis + 1]
    moments_axes.extend(
        a + 1 if a > channels_axis else a for a in reduction_axes)

    with tf.compat.v1.variable_scope(scope, 'GroupNorm', [inputs],
                                     reuse=reuse):
        # Note that the params_shape is the number of channels always.
        params_shape = [channels]

        # Allocate parameters for the beta and gamma of the normalization.
        beta, gamma = None, None
        dtype = inputs.dtype.base_dtype
        if param_initializers is None:
            param_initializers = {}
        if center:
            beta_initializer = param_initializers.get(
                'beta', tf.compat.v1.initializers.zeros())
            beta = tf.compat.v1.get_variable(name='beta',
                                             shape=params_shape,
                                             dtype=dtype,
                                             initializer=beta_initializer,
                                             trainable=trainable)
            beta = tf.reshape(beta, params_shape_broadcast)

        if scale:
            gamma_initializer = param_initializers.get(
                'gamma', tf.compat.v1.initializers.ones())
            gamma = tf.compat.v1.get_variable(name='gamma',
                                              shape=params_shape,
                                              dtype=dtype,
                                              initializer=gamma_initializer,
                                              trainable=trainable)
            gamma = tf.reshape(gamma, params_shape_broadcast)

        # Calculate the moments.
        if mean_close_to_zero:
            # One pass algorithm returns better result when mean is close to
            # zero.
            counts, means_ss, variance_ss, _ = tf.nn.sufficient_statistics(
                inputs, moments_axes, keepdims=True)
            mean, variance = tf.nn.normalize_moments(counts,
                                                     means_ss,
                                                     variance_ss,
                                                     shift=None)
        else:
            mean, variance = tf.nn.moments(x=inputs,
                                           axes=moments_axes,
                                           keepdims=True)

        # Compute normalization as `inputs * gain + offset`, folding gamma and
        # beta into the two terms.
        # TODO(shlens): Fix tf.nn.batch_normalization to handle the 5-D Tensor
        # appropriately so that this operation may be faster.
        gain = tf.math.rsqrt(variance + epsilon)
        offset = -mean * gain
        if gamma is not None:
            gain *= gamma
            offset *= gamma
        if beta is not None:
            offset += beta
        outputs = inputs * gain + offset

        # Collapse the groups into the channel dimension.
        outputs = tf.reshape(outputs, input_shape_list)

        if activation_fn is not None:
            outputs = activation_fn(outputs)

        if outputs_collections:
            tf.compat.v1.add_to_collections(outputs_collections, outputs)

        return outputs