def group_norm(inputs, groups=32, channels_axis=-1, reduction_axes=(-3, -2), center=True, scale=True, epsilon=1e-6, activation_fn=None, param_initializers=None, reuse=None, variables_collections=None, outputs_collections=None, trainable=True, scope=None, mean_close_to_zero=False): """Functional interface for the group normalization layer. Reference: https://arxiv.org/abs/1803.08494. "Group Normalization", Yuxin Wu, Kaiming He Args: inputs: A Tensor with at least 2 dimensions one which is channels. All shape dimensions must be fully defined. groups: Integer. Divide the channels into this number of groups over which normalization statistics are computed. This number must be commensurate with the number of channels in `inputs`. channels_axis: An integer. Specifies index of channels axis which will be broken into `groups`, each of which whose statistics will be computed across. Must be mutually exclusive with `reduction_axes`. Preferred usage is to specify negative integers to be agnostic as to whether a batch dimension is included. reduction_axes: Tuple of integers. Specifies dimensions over which statistics will be accumulated. Must be mutually exclusive with `channels_axis`. Statistics will not be accumulated across axes not specified in `reduction_axes` nor `channel_axis`. Preferred usage is to specify negative integers to be agnostic to whether a batch dimension is included. Some sample usage cases: NHWC format: channels_axis=-1, reduction_axes=[-3, -2] NCHW format: channels_axis=-3, reduction_axes=[-2, -1] center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. epsilon: Small float added to variance to avoid dividing by zero. activation_fn: Activation function, default set to None to skip it and maintain a linear activation. param_initializers: Optional initializers for beta, gamma, moving mean and moving variance. reuse: Whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. variables_collections: Optional collections for the variables. outputs_collections: Collections to add the outputs. trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). scope: Optional scope for `variable_scope`. mean_close_to_zero: The mean of `input` before ReLU will be close to zero when batch size >= 4k for Resnet-50 on TPU. If `True`, use `nn.sufficient_statistics` and `nn.normalize_moments` to calculate the variance. This is the same behavior as `fused` equals `True` in batch normalization. If `False`, use `nn.moments` to calculate the variance. When `mean` is close to zero, like 1e-4, use `mean` to calculate the variance may have poor result due to repeated roundoff error and denormalization in `mean`. When `mean` is large, like 1e2, sum(`input`^2) is so large that only the high-order digits of the elements are being accumulated. Thus, use sum(`input` - `mean`)^2/n to calculate the variance has better accuracy compared to (sum(`input`^2)/n - `mean`^2) when `mean` is large. Returns: A `Tensor` representing the output of the operation. Raises: ValueError: If the rank of `inputs` is undefined. ValueError: If rank or channels dimension of `inputs` is undefined. ValueError: If number of groups is not commensurate with number of channels. ValueError: If reduction_axes or channels_axis are out of bounds. ValueError: If reduction_axes are not mutually exclusive with channels_axis. """ # TODO(shlens): Support partially defined shapes for the inputs. inputs = ops.convert_to_tensor(inputs) original_shape = inputs.shape if inputs.shape.ndims is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) if channels_axis > (inputs.shape.ndims - 1): raise ValueError('Axis is out of bounds.') # Standardize the channels_axis to be positive and identify # of channels. if channels_axis < 0: channels_axis = inputs.shape.ndims + channels_axis channels = inputs.shape[channels_axis].value if channels is None: raise ValueError('Inputs %s has undefined channel dimension: %d.' % (inputs.name, channels_axis)) # Standardize the reduction_axes to be positive. reduction_axes = list(reduction_axes) for i in range(len(reduction_axes)): if reduction_axes[i] < 0: reduction_axes[i] += inputs.shape.ndims for a in reduction_axes: if a > inputs.shape.ndims: raise ValueError('Axis is out of bounds.') if inputs.shape[a].value is None: raise ValueError('Inputs %s has undefined dimensions %d.' % (inputs.name, a)) if channels_axis == a: raise ValueError('reduction_axis must be mutually exclusive ' 'with channels_axis') if groups > channels: raise ValueError('Invalid groups %d for %d channels.' % (groups, channels)) if channels % groups != 0: raise ValueError('%d channels is not commensurate with %d groups.' % (channels, groups)) # Determine axes before channels. Some examples of common image formats: # 'NCHW': before = [N], after = [HW] # 'NHWC': before = [NHW], after = [] axes_before_channels = inputs.shape.as_list()[:channels_axis] axes_after_channels = inputs.shape.as_list()[channels_axis + 1:] # Manually broadcast the parameters to conform to the number of groups. params_shape_broadcast = ([1] * len(axes_before_channels) + [groups, channels // groups] + [1] * len(axes_after_channels)) # Reshape the input by the group within the channel dimension. inputs_shape = (axes_before_channels + [groups, channels // groups] + axes_after_channels) inputs = array_ops.reshape(inputs, inputs_shape) # Determine the dimensions across which moments are calculated. moments_axes = [channels_axis + 1] for a in reduction_axes: if a > channels_axis: moments_axes.append(a + 1) else: moments_axes.append(a) with variable_scope.variable_scope(scope, 'GroupNorm', [inputs], reuse=reuse) as sc: # Note that the params_shape is the number of channels always. params_shape = [channels] # Allocate parameters for the beta and gamma of the normalization. beta, gamma = None, None dtype = inputs.dtype.base_dtype if param_initializers is None: param_initializers = {} if center: beta_collections = utils.get_variable_collections( variables_collections, 'beta') beta_initializer = param_initializers.get( 'beta', init_ops.zeros_initializer()) beta = variables.model_variable('beta', shape=params_shape, dtype=dtype, initializer=beta_initializer, collections=beta_collections, trainable=trainable) beta = array_ops.reshape(beta, params_shape_broadcast) if scale: gamma_collections = utils.get_variable_collections( variables_collections, 'gamma') gamma_initializer = param_initializers.get( 'gamma', init_ops.ones_initializer()) gamma = variables.model_variable('gamma', shape=params_shape, dtype=dtype, initializer=gamma_initializer, collections=gamma_collections, trainable=trainable) gamma = array_ops.reshape(gamma, params_shape_broadcast) # Calculate the moments. if mean_close_to_zero: # One pass algorithm returns better result when mean is close to zero. counts, means_ss, variance_ss, _ = nn.sufficient_statistics( inputs, moments_axes, keep_dims=True) mean, variance = nn.normalize_moments(counts, means_ss, variance_ss, shift=None) else: mean, variance = nn.moments(inputs, moments_axes, keep_dims=True) # Compute normalization. # TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor # appropriately so that this operation may be faster. gain = math_ops.rsqrt(variance + epsilon) offset = -mean * gain if gamma is not None: gain *= gamma offset *= gamma if beta is not None: offset += beta outputs = inputs * gain + offset # Collapse the groups into the channel dimension. outputs = array_ops.reshape(outputs, original_shape) if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
def group_norm(inputs, groups=32, channels_axis=-1, reduction_axes=(-3, -2), center=True, scale=True, epsilon=1e-6, activation_fn=None, param_initializers=None, reuse=None, variables_collections=None, outputs_collections=None, trainable=True, scope=None, mean_close_to_zero=False): """Functional interface for the group normalization layer. Reference: https://arxiv.org/abs/1803.08494. "Group Normalization", Yuxin Wu, Kaiming He Args: inputs: A Tensor with at least 2 dimensions one which is channels. All shape dimensions must be fully defined. groups: Integer. Divide the channels into this number of groups over which normalization statistics are computed. This number must be commensurate with the number of channels in `inputs`. channels_axis: An integer. Specifies index of channels axis which will be broken into `groups`, each of which whose statistics will be computed across. Must be mutually exclusive with `reduction_axes`. Preferred usage is to specify negative integers to be agnostic as to whether a batch dimension is included. reduction_axes: Tuple of integers. Specifies dimensions over which statistics will be accumulated. Must be mutually exclusive with `channels_axis`. Statistics will not be accumulated across axes not specified in `reduction_axes` nor `channel_axis`. Preferred usage is to specify negative integers to be agnostic to whether a batch dimension is included. Some sample usage cases: NHWC format: channels_axis=-1, reduction_axes=[-3, -2] NCHW format: channels_axis=-3, reduction_axes=[-2, -1] center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. epsilon: Small float added to variance to avoid dividing by zero. activation_fn: Activation function, default set to None to skip it and maintain a linear activation. param_initializers: Optional initializers for beta, gamma, moving mean and moving variance. reuse: Whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. variables_collections: Optional collections for the variables. outputs_collections: Collections to add the outputs. trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). scope: Optional scope for `variable_scope`. mean_close_to_zero: The mean of `input` before ReLU will be close to zero when batch size >= 4k for Resnet-50 on TPU. If `True`, use `nn.sufficient_statistics` and `nn.normalize_moments` to calculate the variance. This is the same behavior as `fused` equals `True` in batch normalization. If `False`, use `nn.moments` to calculate the variance. When `mean` is close to zero, like 1e-4, use `mean` to calculate the variance may have poor result due to repeated roundoff error and denormalization in `mean`. When `mean` is large, like 1e2, sum(`input`^2) is so large that only the high-order digits of the elements are being accumulated. Thus, use sum(`input` - `mean`)^2/n to calculate the variance has better accuracy compared to (sum(`input`^2)/n - `mean`^2) when `mean` is large. Returns: A `Tensor` representing the output of the operation. Raises: ValueError: If the rank of `inputs` is undefined. ValueError: If rank or channels dimension of `inputs` is undefined. ValueError: If number of groups is not commensurate with number of channels. ValueError: If reduction_axes or channels_axis are out of bounds. ValueError: If reduction_axes are not mutually exclusive with channels_axis. """ # TODO(shlens): Support partially defined shapes for the inputs. inputs = ops.convert_to_tensor(inputs) original_shape = inputs.shape if inputs.shape.ndims is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) if channels_axis > (inputs.shape.ndims - 1): raise ValueError('Axis is out of bounds.') # Standardize the channels_axis to be positive and identify # of channels. if channels_axis < 0: channels_axis = inputs.shape.ndims + channels_axis channels = inputs.shape[channels_axis].value if channels is None: raise ValueError('Inputs %s has undefined channel dimension: %d.' % ( inputs.name, channels_axis)) # Standardize the reduction_axes to be positive. reduction_axes = list(reduction_axes) for i in range(len(reduction_axes)): if reduction_axes[i] < 0: reduction_axes[i] += inputs.shape.ndims for a in reduction_axes: if a > inputs.shape.ndims: raise ValueError('Axis is out of bounds.') if inputs.shape[a].value is None: raise ValueError('Inputs %s has undefined dimensions %d.' % ( inputs.name, a)) if channels_axis == a: raise ValueError('reduction_axis must be mutually exclusive ' 'with channels_axis') if groups > channels: raise ValueError('Invalid groups %d for %d channels.' % (groups, channels)) if channels % groups != 0: raise ValueError('%d channels is not commensurate with %d groups.' % (channels, groups)) # Determine axes before channels. Some examples of common image formats: # 'NCHW': before = [N], after = [HW] # 'NHWC': before = [NHW], after = [] axes_before_channels = inputs.shape.as_list()[:channels_axis] axes_after_channels = inputs.shape.as_list()[channels_axis+1:] # Manually broadcast the parameters to conform to the number of groups. params_shape_broadcast = ([1] * len(axes_before_channels) + [groups, channels // groups] + [1] * len(axes_after_channels)) # Reshape the input by the group within the channel dimension. inputs_shape = (axes_before_channels + [groups, channels // groups] + axes_after_channels) inputs = array_ops.reshape(inputs, inputs_shape) # Determine the dimensions across which moments are calculated. moments_axes = [channels_axis + 1] for a in reduction_axes: if a > channels_axis: moments_axes.append(a + 1) else: moments_axes.append(a) with variable_scope.variable_scope( scope, 'GroupNorm', [inputs], reuse=reuse) as sc: # Note that the params_shape is the number of channels always. params_shape = [channels] # Allocate parameters for the beta and gamma of the normalization. beta, gamma = None, None dtype = inputs.dtype.base_dtype if param_initializers is None: param_initializers = {} if center: beta_collections = utils.get_variable_collections( variables_collections, 'beta') beta_initializer = param_initializers.get( 'beta', init_ops.zeros_initializer()) beta = variables.model_variable('beta', shape=params_shape, dtype=dtype, initializer=beta_initializer, collections=beta_collections, trainable=trainable) beta = array_ops.reshape(beta, params_shape_broadcast) if scale: gamma_collections = utils.get_variable_collections( variables_collections, 'gamma') gamma_initializer = param_initializers.get( 'gamma', init_ops.ones_initializer()) gamma = variables.model_variable('gamma', shape=params_shape, dtype=dtype, initializer=gamma_initializer, collections=gamma_collections, trainable=trainable) gamma = array_ops.reshape(gamma, params_shape_broadcast) # Calculate the moments. if mean_close_to_zero: # One pass algorithm returns better result when mean is close to zero. counts, means_ss, variance_ss, _ = nn.sufficient_statistics( inputs, moments_axes, keep_dims=True) mean, variance = nn.normalize_moments( counts, means_ss, variance_ss, shift=None) else: mean, variance = nn.moments(inputs, moments_axes, keep_dims=True) # Compute normalization. # TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor # appropriately so that this operation may be faster. gain = math_ops.rsqrt(variance + epsilon) offset = -mean * gain if gamma is not None: gain *= gamma offset *= gamma if beta is not None: offset += beta outputs = inputs * gain + offset # Collapse the groups into the channel dimension. outputs = array_ops.reshape(outputs, original_shape) if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs)