def normalize_in_training():
  if needs_broadcasting:
    return nn.batch_normalization(inputs,
                                  broadcast_mean,
                                  broadcast_variance,
                                  broadcast_beta,
                                  broadcast_gamma,
                                  self.epsilon)
  else:
    return nn.batch_normalization(
        inputs,
        mean,
        variance,
        self.beta if self.center else None,
        self.gamma if self.scale else None,
        self.epsilon)
def normalize_in_test():
  if needs_broadcasting:
    broadcast_moving_mean = array_ops.reshape(
        self.moving_mean, broadcast_shape)
    broadcast_moving_variance = array_ops.reshape(
        self.moving_variance, broadcast_shape)
    return nn.batch_normalization(inputs,
                                  broadcast_moving_mean,
                                  broadcast_moving_variance,
                                  broadcast_beta,
                                  broadcast_gamma,
                                  self.epsilon)
  else:
    return nn.batch_normalization(
        inputs,
        self.moving_mean,
        self.moving_variance,
        self.beta if self.center else None,
        self.gamma if self.scale else None,
        self.epsilon)
def normalize(self, inputs):
  """Apply normalization to input.

  The shape must match the declared shape in the constructor.
  [This is copied from tf.contrib.rnn.LayerNormBasicLSTMCell.]

  Args:
    inputs: Input tensor.

  Returns:
    Normalized version of input tensor.

  Raises:
    ValueError: if inputs has undefined rank.
  """
  inputs_shape = inputs.get_shape()
  inputs_rank = inputs_shape.ndims
  if inputs_rank is None:
    raise ValueError('Inputs %s has undefined rank.' % inputs.name)
  axis = list(range(1, inputs_rank))  # list() so this also works on Python 3

  beta = self._component.get_variable('beta_%s' % self._name)
  gamma = self._component.get_variable('gamma_%s' % self._name)

  with tf.variable_scope('layer_norm_%s' % self._name):
    # Calculate the moments over all non-batch axes (layer activations).
    mean, variance = nn.moments(inputs, axis, keep_dims=True)

    # Compute layer normalization using the batch_normalization function.
    variance_epsilon = 1E-12
    outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                     variance_epsilon)
    outputs.set_shape(inputs_shape)
    return outputs
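For intuition, a minimal NumPy sketch (independent of the layer above and of TensorFlow) of what normalize() computes: per-example moments over every non-batch axis, followed by the usual scale and shift.

import numpy as np

def layer_norm_reference(x, beta, gamma, eps=1e-12):
  # Reduce over every axis except the batch axis (axis 0), as above.
  axes = tuple(range(1, x.ndim))
  mean = x.mean(axis=axes, keepdims=True)
  var = x.var(axis=axes, keepdims=True)
  return (x - mean) / np.sqrt(var + eps) * gamma + beta

x = np.random.randn(2, 8).astype(np.float32)
out = layer_norm_reference(x, beta=np.zeros(8), gamma=np.ones(8))
assert np.allclose(out.mean(axis=1), 0.0, atol=1e-5)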
def __call__(self, inputs):
  """Run virtual batch normalization on inputs.

  Args:
    inputs: Tensor input.

  Returns:
    A virtual batch normalized version of `inputs`.

  Raises:
    ValueError: If `inputs` shape isn't compatible with the reference batch.
  """
  _validate_call_input([inputs, self._reference_batch], self._batch_axis)

  with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]):
    # Calculate the statistics on the current input on a per-example basis.
    vb_mean, vb_mean_sq = self._virtual_statistics(
        inputs, self._example_reduction_axes)
    vb_variance = vb_mean_sq - math_ops.square(vb_mean)

    # The exact broadcast shape of the input statistic Tensors depends on the
    # current batch, not the reference batch. The parameter broadcast shape
    # is independent of the shape of the input statistic Tensor dimensions.
    b_shape = self._broadcast_shape[:]  # shallow copy of a flat list suffices
    b_shape[self._batch_axis] = _static_or_dynamic_batch_size(
        inputs, self._batch_axis)
    return nn.batch_normalization(
        inputs,
        self._broadcast(vb_mean, b_shape),
        self._broadcast(vb_variance, b_shape),
        self._broadcast(self._beta, self._broadcast_shape),
        self._broadcast(self._gamma, self._broadcast_shape),
        self._epsilon)
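The variance recovery step above relies on the identity var(x) = E[x^2] - E[x]^2. A quick NumPy check (the real _virtual_statistics also mixes in the reference batch, which is elided here):

import numpy as np

x = np.random.randn(4, 3)
vb_mean = x.mean(axis=0)
vb_mean_sq = (x ** 2).mean(axis=0)
vb_variance = vb_mean_sq - vb_mean ** 2
assert np.allclose(vb_variance, x.var(axis=0))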
def call(self, inputs):
  # Compute the axes along which to reduce the mean / variance.
  input_shape = inputs.get_shape()
  ndims = len(input_shape)

  # Calculate the moments on the last axis (layer activations).
  mean, variance = nn.moments(inputs, self.norm_axis, keep_dims=True)

  # Broadcasting is only necessary when the params axes are not just the
  # last dimension.
  broadcast_shape = [1] * ndims
  for dim in self.params_axis:
    broadcast_shape[dim] = input_shape.dims[dim].value

  def _broadcast(v):
    if (v is not None and len(v.get_shape()) != ndims and
        self.params_axis != [ndims - 1]):
      return array_ops.reshape(v, broadcast_shape)
    return v

  scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

  # Compute layer normalization using the batch_normalization function.
  outputs = nn.batch_normalization(inputs,
                                   mean,
                                   variance,
                                   offset=offset,
                                   scale=scale,
                                   variance_epsilon=self.epsilon)

  # If some components of the shape got lost due to adjustments, fix that.
  outputs.set_shape(input_shape)
  return outputs
def batch_norm(x, deterministic, alpha=0.9, shift=True, scope='bn'):
  with vs.variable_scope(scope):
    dtype = x.dtype
    input_shape = x.get_shape().as_list()
    feat_dim = input_shape[-1]
    axes = list(range(len(input_shape) - 1))  # list() for Python 3
    if shift:
      beta = vs.get_variable(scope + "_beta", shape=[feat_dim],
                             initializer=init_ops.zeros_initializer,
                             dtype=dtype)
    else:
      beta = vs.get_variable(scope + "_beta", shape=[feat_dim],
                             initializer=init_ops.zeros_initializer,
                             dtype=dtype, trainable=False)

    gamma = vs.get_variable(scope + "_gamma", shape=[feat_dim],
                            initializer=init_ops.constant_initializer(0.1),
                            dtype=dtype)
    mean = vs.get_variable(scope + "_mean", shape=[feat_dim],
                           initializer=init_ops.zeros_initializer,
                           dtype=dtype, trainable=False)
    var = vs.get_variable(scope + "_var", shape=[feat_dim],
                          initializer=init_ops.ones_initializer,
                          dtype=dtype, trainable=False)
    counter = vs.get_variable(scope + "_counter", shape=[],
                              initializer=init_ops.constant_initializer(0),
                              dtype=tf.int64, trainable=False)
    zero_cnt = vs.get_variable(scope + "_zero_cnt", shape=[],
                               initializer=init_ops.constant_initializer(0),
                               dtype=tf.int64, trainable=False)

    batch_mean, batch_var = moments(x, axes, name=scope + '_moments')
    # On the very first step (counter == 0), fall back to batch statistics.
    mean, var = cond(math_ops.equal(counter, zero_cnt),
                     lambda: (batch_mean, batch_var),
                     lambda: (mean, var))
    # In training mode, blend the running statistics with the batch
    # statistics and advance the counter.
    mean, var, counter = cond(
        deterministic,
        lambda: (mean, var, counter),
        lambda: ((1 - alpha) * batch_mean + alpha * mean,
                 (1 - alpha) * batch_var + alpha * var,
                 counter + 1))
    normed = batch_normalization(x, mean, var, beta, gamma, 1e-8)
  return normed
def call(self, inputs):
  # Compute the axes along which to reduce the mean / variance.
  input_shape = inputs.shape
  ndims = len(input_shape)

  # Calculate the moments on the last axis (layer activations).
  mean, variance = nn.moments(inputs, self.axis, keep_dims=True)

  # Broadcasting is only necessary when the axis is not just the last
  # dimension.
  broadcast_shape = [1] * ndims
  for dim in self.axis:
    broadcast_shape[dim] = input_shape.dims[dim].value

  def _broadcast(v):
    if (v is not None and len(v.shape) != ndims and
        self.axis != [ndims - 1]):
      return array_ops.reshape(v, broadcast_shape)
    return v

  scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

  # Compute layer normalization using the batch_normalization function.
  outputs = nn.batch_normalization(
      inputs,
      mean,
      variance,
      offset=offset,
      scale=scale,
      variance_epsilon=self.epsilon)

  # If some components of the shape got lost due to adjustments, fix that.
  outputs.set_shape(input_shape)
  return outputs
def reference_batch_normalization(self):
  """Return the reference batch, but batch normalized."""
  with ops.name_scope(self._vs.name):
    return nn.batch_normalization(self._reference_batch,
                                  self._broadcast(self._ref_mean),
                                  self._broadcast(self._ref_variance),
                                  self._broadcast(self._beta),
                                  self._broadcast(self._gamma),
                                  self._epsilon)
def my_batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
  """Applies batch normalization on x given mean, var, beta and gamma.

  I.e. returns:
  `output = (x - mean) / sqrt(var + epsilon) * gamma + beta`

  Arguments:
      x: Input tensor or variable.
      mean: Mean of batch.
      var: Variance of batch.
      beta: Tensor with which to center the input.
      gamma: Tensor by which to scale the input.
      axis: Integer, the axis that should be normalized
          (typically the features axis).
      epsilon: Fuzz factor.

  Returns:
      A tensor.
  """
  if K.ndim(x) == 4:
    # The CPU implementation of `fused_batch_norm` only supports NHWC.
    if axis == 1 or axis == -3:
      tf_data_format = 'NCHW'
    elif axis == 3 or axis == -1:
      tf_data_format = 'NHWC'
    else:
      tf_data_format = None

    if (tf_data_format == 'NHWC' or
        (tf_data_format == 'NCHW' and _has_nchw_support())):
      # The mean / var / beta / gamma tensors may be broadcast,
      # so they may have extra axes of size 1 which should be squeezed.
      if K.ndim(mean) > 1:
        mean = array_ops.reshape(mean, [-1])
      if K.ndim(var) > 1:
        var = array_ops.reshape(var, [-1])
      if beta is None:
        beta = zeros_like(mean)
      elif K.ndim(beta) > 1:
        beta = array_ops.reshape(beta, [-1])
      if gamma is None:
        gamma = ones_like(mean)
      elif K.ndim(gamma) > 1:
        gamma = array_ops.reshape(gamma, [-1])
      y, _, _ = nn.fused_batch_norm(x,
                                    gamma,
                                    beta,
                                    epsilon=epsilon,
                                    mean=mean,
                                    variance=var,
                                    data_format=tf_data_format,
                                    is_training=False)
      return y
  return tf.map_fn(
      lambda xx: nn.batch_normalization(xx, mean, var, beta, gamma, epsilon),
      x)
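A small NumPy check of the docstring formula above, assuming the same argument conventions as tf.nn.batch_normalization:

import numpy as np

x = np.random.randn(5, 3)
mean, var = x.mean(axis=0), x.var(axis=0)
beta, gamma, eps = np.zeros(3), np.ones(3), 1e-3
y = (x - mean) / np.sqrt(var + eps) * gamma + beta
assert np.allclose(y.mean(axis=0), 0.0, atol=1e-6)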
def instance_norm(self, inputs, inputs_latent, name):
  inputs_rank = inputs.shape.ndims
  n_outputs = int(inputs.shape[-1])
  n_batch = int(inputs.shape[0])
  # Predict a per-sample gamma and beta from the latent code.
  inputs_latent_flatten = tf.layers.flatten(inputs_latent)
  gamma = self.MLP(inputs_latent_flatten, n_outputs, name + "_gamma")
  beta = self.MLP(inputs_latent_flatten, n_outputs, name + "_beta")
  gamma = tf.reshape(gamma, [n_batch, 1, 1, n_outputs])
  beta = tf.reshape(beta, [n_batch, 1, 1, n_outputs])
  # Reduce over the spatial axes only, so the statistics are computed per
  # instance and per channel.
  moments_axes = list(range(1, inputs_rank - 1))
  mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
  outputs = nn.batch_normalization(
      inputs, mean, variance, beta, gamma, 1e-6, name=name)
  return outputs
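A minimal NumPy sketch of the idea above (adaptive instance normalization): per-sample gamma/beta predicted from a latent code replace the usual learned constants. The gamma and beta arguments here stand in for the MLP outputs.

import numpy as np

def adaptive_instance_norm(x, gamma, beta, eps=1e-6):
  # x: [N, H, W, C]; gamma, beta: [N, 1, 1, C], predicted per sample.
  mean = x.mean(axis=(1, 2), keepdims=True)
  var = x.var(axis=(1, 2), keepdims=True)
  return (x - mean) / np.sqrt(var + eps) * gamma + beta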
def normalize_in_training():
  arg_mean = broadcast_mean if needs_broadcasting else mean
  arg_variance = broadcast_variance if needs_broadcasting else variance
  arg_beta = broadcast_beta if needs_broadcasting else (
      self.beta if self.center else None)
  arg_gamma = broadcast_gamma if needs_broadcasting else (
      self.gamma if self.scale else None)
  if self.quantizer is None:
    return nn.batch_normalization(inputs, arg_mean, arg_variance, arg_beta,
                                  arg_gamma, self.epsilon)
  else:
    return qbatch_normalization(inputs, arg_mean, arg_variance, arg_beta,
                                arg_gamma, self.epsilon, self.quantizer)
def normalize_in_test():
  if needs_broadcasting:
    broadcast_moving_mean = array_ops.reshape(
        self.moving_mean, broadcast_shape)
    broadcast_moving_variance = array_ops.reshape(
        self.moving_variance, broadcast_shape)
  arg_mean = (broadcast_moving_mean if needs_broadcasting
              else self.moving_mean)
  arg_variance = (broadcast_moving_variance if needs_broadcasting
                  else self.moving_variance)
  arg_beta = broadcast_beta if needs_broadcasting else (
      self.beta if self.center else None)
  arg_gamma = broadcast_gamma if needs_broadcasting else (
      self.gamma if self.scale else None)
  if self.quantizer is None:
    return nn.batch_normalization(inputs, arg_mean, arg_variance, arg_beta,
                                  arg_gamma, self.epsilon)
  else:
    return qbatch_normalization(inputs, arg_mean, arg_variance, arg_beta,
                                arg_gamma, self.epsilon, self.quantizer)
def call(self, inputs, training=True):
  layer_inputs = inputs[0]
  mix_weights = inputs[1]
  self.assign_mixture_value(name="template_beta",
                            mixture_weights=mix_weights)
  self.assign_mixture_value(name="template_gamma",
                            mixture_weights=mix_weights)
  norm_input = BatchNormalization.call(self, layer_inputs, training)
  input_shape = layer_inputs.shape
  reduction_axes = [i for i in range(len(input_shape))
                    if i not in self.axis]
  mean, var = nn.moments(norm_input, reduction_axes, keep_dims=True)
  scale = self._broadcast(self.template_gamma, input_shape)
  offset = self._broadcast(self.template_beta, input_shape)
  output = nn.batch_normalization(norm_input, mean, var, offset, scale,
                                  self.epsilon)
  self.reset_all_values()
  return output
def my_graph(a):
  with ops.device("/device:IPU:0"):
    with variable_scope.variable_scope("", use_resource=True):
      beta = variable_scope.get_variable(
          "x",
          dtype=np.float16,
          shape=[4],
          initializer=init_ops.constant_initializer(0.0))
      gamma = variable_scope.get_variable(
          "y",
          dtype=np.float16,
          shape=[4],
          initializer=init_ops.constant_initializer(1.0))
      b_mean, b_var = nn.moments(a, [0, 1, 2], name='moments')
      normed = nn.batch_normalization(a, b_mean, b_var, beta, gamma, 1e-3)
  return normed
def testBatchNormalizeFp16(self):
  x = array_ops.placeholder(np.float16, [4, 64, 64, 4], name="a")
  with ops.device("/device:IPU:0"):
    with variable_scope.variable_scope("", use_resource=True):
      beta = variable_scope.get_variable(
          "x",
          dtype=np.float16,
          shape=[4],
          initializer=init_ops.constant_initializer(0.0))
      gamma = variable_scope.get_variable(
          "y",
          dtype=np.float16,
          shape=[4],
          initializer=init_ops.constant_initializer(1.0))
      b_mean, b_var = nn.moments(x, [0, 1, 2], name='moments')
      normed = nn.batch_normalization(x, b_mean, b_var, beta, gamma, 1e-3)

  with ops.device('cpu'):
    report = gen_ipu_ops.ipu_event_trace()

  tu.configure_ipu_system()

  with tu.ipu_session() as sess:
    sess.run(report)
    sess.run(variables.global_variables_initializer())
    result = sess.run(normed, {x: np.zeros([4, 64, 64, 4])})
    self.assertAllClose(result, np.zeros([4, 64, 64, 4]))

    rep = sess.run(report)
    s = tu.extract_all_strings_from_event_trace(rep)
    cs = tu.get_compute_sets_from_report(s)

    # The fp16 batch norm should compile without inserting any casts.
    bl = ['*convert*/Cast*']
    self.assertTrue(tu.check_compute_sets_not_in_blacklist(cs, bl))
def instance_norm(inputs,
                  center=True,
                  scale=True,
                  epsilon=1e-6,
                  activation_fn=None,
                  param_initializers=None,
                  reuse=None,
                  variables_collections=None,
                  outputs_collections=None,
                  trainable=True,
                  data_format=DATA_FORMAT_NHWC,
                  scope=None):
  """Functional interface for the instance normalization layer.

  Reference: https://arxiv.org/abs/1607.08022.

    "Instance Normalization: The Missing Ingredient for Fast Stylization"
    Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky

  Args:
    inputs: A tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC` and the second dimension if `data_format` is
      `NCHW`.
    center: If True, add offset of `beta` to normalized tensor. If False,
      `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
      the next layer is linear (also e.g. `nn.relu`), this can be disabled
      since the scaling can be done by the next layer.
    epsilon: Small float added to variance to avoid dividing by zero.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: Optional initializers for beta, gamma, moving mean and
      moving variance.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional collections for the variables.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    scope: Optional scope for `variable_scope`.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: If the rank of `inputs` is undefined.
    ValueError: If rank or channels dimension of `inputs` is undefined.
  """
  inputs = ops.convert_to_tensor(inputs)
  inputs_shape = inputs.shape
  inputs_rank = inputs.shape.ndims

  if inputs_rank is None:
    raise ValueError('Inputs %s has undefined rank.' % inputs.name)
  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')

  with variable_scope.variable_scope(
      scope, 'InstanceNorm', [inputs], reuse=reuse) as sc:
    if data_format == DATA_FORMAT_NCHW:
      reduction_axis = 1
      # For NCHW format, rather than relying on implicit broadcasting, we
      # explicitly reshape the params to params_shape_broadcast when computing
      # the moments and the batch normalization.
      params_shape_broadcast = list(
          [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
    else:
      reduction_axis = inputs_rank - 1
      params_shape_broadcast = None
    moments_axes = list(range(inputs_rank))
    del moments_axes[reduction_axis]
    del moments_axes[0]
    params_shape = inputs_shape[reduction_axis:reduction_axis + 1]
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined channels dimension %s.' % (
          inputs.name, params_shape))

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    dtype = inputs.dtype.base_dtype
    if param_initializers is None:
      param_initializers = {}
    if center:
      beta_collections = utils.get_variable_collections(
          variables_collections, 'beta')
      beta_initializer = param_initializers.get(
          'beta', init_ops.zeros_initializer())
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=beta_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
      if params_shape_broadcast:
        beta = array_ops.reshape(beta, params_shape_broadcast)
    if scale:
      gamma_collections = utils.get_variable_collections(
          variables_collections, 'gamma')
      gamma_initializer = param_initializers.get(
          'gamma', init_ops.ones_initializer())
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=gamma_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)
      if params_shape_broadcast:
        gamma = array_ops.reshape(gamma, params_shape_broadcast)

    # Calculate the moments (instance activations).
    mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)

    # Compute instance normalization.
    outputs = nn.batch_normalization(
        inputs, mean, variance, beta, gamma, epsilon, name='instancenorm')
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
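For a rank-4 input, the axis bookkeeping above works out to: NHWC keeps the channel axis 3 and takes moments over [1, 2]; NCHW keeps axis 1 and takes moments over [2, 3]. A quick check of the two del statements:

inputs_rank = 4
for data_format, reduction_axis in [('NHWC', 3), ('NCHW', 1)]:
  moments_axes = list(range(inputs_rank))
  del moments_axes[reduction_axis]
  del moments_axes[0]
  print(data_format, moments_axes)  # NHWC [1, 2]; NCHW [2, 3]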
def call(self, inputs):
  # Compute the axes along which to reduce the mean / variance.
  input_shape = inputs.shape
  ndims = len(input_shape)

  # Broadcasting is only necessary when the axis is not just the last
  # dimension.
  broadcast_shape = [1] * ndims
  for dim in self.axis:
    broadcast_shape[dim] = input_shape.dims[dim].value

  def _broadcast(v):
    if (v is not None and len(v.shape) != ndims and
        self.axis != [ndims - 1]):
      return array_ops.reshape(v, broadcast_shape)
    return v

  if not self._fused:
    input_dtype = inputs.dtype
    if input_dtype in ('float16', 'bfloat16') and self.dtype == 'float32':
      # If mixed precision is used, cast inputs to float32 so that this is at
      # least as numerically stable as the fused version.
      inputs = math_ops.cast(inputs, 'float32')

    # Calculate the moments on the last axis (layer activations).
    mean, variance = nn.moments(inputs, self.axis, keep_dims=True)

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    # Compute layer normalization using the batch_normalization function.
    outputs = nn.batch_normalization(
        inputs,
        mean,
        variance,
        offset=offset,
        scale=scale,
        variance_epsilon=self.epsilon)
    outputs = math_ops.cast(outputs, input_dtype)
  else:
    # Collapse dims before self.axis, and dims in self.axis.
    pre_dim, in_dim = (1, 1)
    axis = sorted(self.axis)
    tensor_shape = array_ops.shape(inputs)
    for dim in range(0, ndims):
      dim_tensor = tensor_shape[dim]
      if dim < axis[0]:
        pre_dim = pre_dim * dim_tensor
      else:
        assert dim in axis
        in_dim = in_dim * dim_tensor

    squeezed_shape = [1, pre_dim, in_dim, 1]
    # This fused operation requires reshaped inputs to be NCHW.
    data_format = 'NCHW'

    inputs = array_ops.reshape(inputs, squeezed_shape)

    def _set_const_tensor(val, dtype, shape):
      return array_ops.fill(shape, constant_op.constant(val, dtype=dtype))

    # self.gamma and self.beta have the wrong shape for fused_batch_norm, so
    # we cannot pass them as the scale and offset parameters. Therefore, we
    # create two constant tensors in correct shapes for fused_batch_norm and
    # later construct a separate calculation on the scale and offset.
    scale = _set_const_tensor(1.0, self.dtype, [pre_dim])
    offset = _set_const_tensor(0.0, self.dtype, [pre_dim])

    # Compute layer normalization using the fused_batch_norm function.
    outputs, _, _ = nn.fused_batch_norm(
        inputs,
        scale=scale,
        offset=offset,
        epsilon=self.epsilon,
        data_format=data_format)
    outputs = array_ops.reshape(outputs, tensor_shape)
    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    if scale is not None:
      outputs = outputs * math_ops.cast(scale, outputs.dtype)
    if offset is not None:
      outputs = outputs + math_ops.cast(offset, outputs.dtype)

  # If some components of the shape got lost due to adjustments, fix that.
  outputs.set_shape(input_shape)
  return outputs
def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               scope=None):
  """Code modification of tensorflow/contrib/layers/python/layers/layers.py"""
  with variable_scope.variable_op_scope([inputs], scope, 'BatchNorm',
                                        reuse=reuse) as sc:
    inputs = ops.convert_to_tensor(inputs)
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    axis = list(range(inputs_rank - 1))
    params_shape = inputs_shape[-1:]
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined last dimension %s.' % (
          inputs.name, params_shape))

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if center:
      beta_collections = utils.get_variable_collections(
          variables_collections, 'beta')
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=init_ops.zeros_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
    if scale:
      gamma_collections = utils.get_variable_collections(
          variables_collections, 'gamma')
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=init_ops.ones_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)

    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections.
    moving_mean_collections = utils.get_variable_collections(
        variables_collections, 'moving_mean')
    moving_mean = variables.model_variable(
        'moving_mean',
        shape=params_shape,
        dtype=dtype,
        initializer=init_ops.zeros_initializer,
        trainable=False,
        collections=moving_mean_collections)
    moving_variance_collections = utils.get_variable_collections(
        variables_collections, 'moving_variance')
    moving_variance = variables.model_variable(
        'moving_variance',
        shape=params_shape,
        dtype=dtype,
        initializer=init_ops.ones_initializer,
        trainable=False,
        collections=moving_variance_collections)

    # Calculate the moments based on the individual batch.
    mean, variance = nn.moments(inputs, axis, shift=moving_mean)
    # Update the moving_mean and moving_variance moments.
    update_moving_mean = moving_averages.assign_moving_average(
        moving_mean, mean, decay)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, decay)
    if updates_collections is None:
      # Make sure the updates are computed here.
      with ops.control_dependencies([update_moving_mean,
                                     update_moving_variance]):
        outputs = nn.batch_normalization(
            inputs, mean, variance, beta, gamma, epsilon)
    else:
      # Collect the updates to be computed later.
      ops.add_to_collections(updates_collections, update_moving_mean)
      ops.add_to_collections(updates_collections, update_moving_variance)
      outputs = nn.batch_normalization(
          inputs, mean, variance, beta, gamma, epsilon)
    # At test time, normalize with the moving statistics instead.
    test_outputs = nn.batch_normalization(
        inputs, moving_mean, moving_variance, beta, gamma, epsilon)
    outputs = tf.cond(is_training, lambda: outputs, lambda: test_outputs)
    outputs.set_shape(inputs_shape)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
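The updates_collections branch above defers the moving-average updates, so callers must run them alongside the train op. A typical pairing, sketched for TF 1.x with a stand-in loss:

import tensorflow as tf  # assumes TF 1.x

w = tf.get_variable('w', shape=[], initializer=tf.zeros_initializer())
loss = tf.square(w - 1.0)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)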
def call(self, inputs, training=False):
  if self.num_virtual_batches > 1:
    # Virtual batches (aka ghost batches) can be simulated by using some
    # reshape/transpose tricks on top of base batch normalization.
    original_shape = [-1] + inputs.shape.as_list()[1:]
    expanded_shape = [-1, self.num_virtual_batches] + original_shape[1:]

    # Will cause errors if num_virtual_batches does not divide the batch size
    inputs = array_ops.reshape(inputs, expanded_shape)

    ndims = len(expanded_shape)
    if self.axis < 0:
      axis = ndims + self.axis
    else:
      axis = self.axis + 1  # Account for the added dimension

    # Permute the num_virtual_batch dimension (dim 1) to be adjacent to axis
    # TODO(b/66257056): when multi-axis batch normalization is implemented,
    # this permutation trick and the combined_dim reshape are no longer
    # necessary and can be reworked to simply use broadcasting.
    permutation = ([0] + list(range(2, axis)) + [1, axis] +
                   list(range(axis + 1, ndims)))
    inverse_permutation = [x[1] for x in
                           sorted(zip(permutation, range(ndims)))]
    inputs = array_ops.transpose(inputs, perm=permutation)

    # Combine the axis and num_virtual_batch dimension in order to take
    # advantage of fused batch normalization
    combined_dim = expanded_shape[1] * expanded_shape[axis]
    perm_shape = [-1] + inputs.shape.as_list()[1:]
    combined_shape = (perm_shape[:axis - 1] + [combined_dim] +
                      perm_shape[axis + 1:])
    inputs = array_ops.reshape(inputs, combined_shape)
    # After the above reshape, the batch norm axis is the original self.axis

    # Undoes the reshaping and transposing tricks done above
    def undo_virtual_batching(outputs):
      outputs = array_ops.reshape(outputs, perm_shape)
      outputs = array_ops.transpose(outputs, perm=inverse_permutation)
      outputs = array_ops.reshape(outputs, original_shape)
      return outputs

  if self.fused:
    outputs = self._fused_batch_norm(inputs, training=training)
    if self.num_virtual_batches > 1:
      return undo_virtual_batching(outputs)
    return outputs

  # First, compute the axes along which to reduce the mean / variance,
  # as well as the broadcast shape to be used for all parameters.
  input_shape = inputs.get_shape()
  ndim = len(input_shape)
  reduction_axes = list(range(len(input_shape)))
  del reduction_axes[self.axis]
  broadcast_shape = [1] * len(input_shape)
  broadcast_shape[self.axis] = input_shape[self.axis].value

  # Determines whether broadcasting is needed.
  needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

  scale, offset = self.gamma, self.beta

  # Determine a boolean value for `training`: could be True, False, or None.
  training_value = utils.constant_value(training)
  if training_value is not False:
    # Some of the computations here are not necessary when training==False
    # but not a constant. However, this makes the code simpler.
    mean, variance = nn.moments(inputs, reduction_axes)
    mean = _smart_select(training,
                         lambda: mean,
                         lambda: self.moving_mean)
    variance = _smart_select(training,
                             lambda: variance,
                             lambda: self.moving_variance)

    if self.renorm:
      r, d, new_mean, new_variance = self._renorm_correction_and_moments(
          mean, variance, training)
      # When training, the normalized values (say, x) will be transformed as
      # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
      # = x * (r * gamma) + (d * gamma + beta) with renorm.
      scale = array_ops.stop_gradient(r, name='renorm_r')
      offset = array_ops.stop_gradient(d, name='renorm_d')
      if self.gamma is not None:
        scale *= self.gamma
        offset *= self.gamma
      if self.beta is not None:
        offset += self.beta
    else:
      new_mean, new_variance = mean, variance

    # Update moving averages when training, and prevent updates otherwise.
    decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
    mean_update = moving_averages.assign_moving_average(
        self.moving_mean, new_mean, decay, zero_debias=False)
    variance_update = moving_averages.assign_moving_average(
        self.moving_variance, new_variance, decay, zero_debias=False)
    if context.in_graph_mode():
      self.add_update(mean_update, inputs=inputs)
      self.add_update(variance_update, inputs=inputs)
  else:
    mean, variance = self.moving_mean, self.moving_variance

  def _broadcast(v):
    if needs_broadcasting and v is not None:
      # In this case we must explicitly broadcast all parameters.
      return array_ops.reshape(v, broadcast_shape)
    return v

  outputs = nn.batch_normalization(inputs,
                                   _broadcast(mean),
                                   _broadcast(variance),
                                   _broadcast(offset),
                                   _broadcast(scale),
                                   self.epsilon)

  if self.num_virtual_batches > 1:
    return undo_virtual_batching(outputs)
  return outputs
def batch_norm_mine_old(inputs,
                        decay=0.999,
                        center=True,
                        scale=False,
                        epsilon=0.001,
                        activation_fn=None,
                        param_initializers=None,
                        param_regularizers=None,
                        updates_collections=ops.GraphKeys.UPDATE_OPS,
                        is_training=True,
                        reuse=None,
                        variables_collections=None,
                        outputs_collections=None,
                        trainable=True,
                        batch_weights=None,
                        fused=False,
                        data_format=DATA_FORMAT_NHWC,
                        zero_debias_moving_mean=False,
                        scope=None,
                        renorm=False,
                        renorm_clipping=None,
                        renorm_decay=0.99):
  """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.

  This earlier version of my modification to batch norm uses current_mean and
  current_variance if is_training is True, and moving_mean and moving_variance
  otherwise. This was leading to a large divergence between the results
  depending upon whether is_training was set to True or not. I think ideally
  it should always use moving_mean and moving_variance; batch_norm_mine does
  this.

  Copy of tensorflow.contrib.layers.

  Args:
    inputs: A tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC` and the second dimension if `data_format` is
      `NCHW`.
    decay: Decay for the moving average. Reasonable values for `decay` are
      close to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9,
      etc. Lower the `decay` value (recommend trying `decay`=0.9) if the model
      experiences reasonably good training performance but poor validation
      and/or test performance. Try zero_debias_moving_mean=True for improved
      stability.
    center: If True, add offset of `beta` to normalized tensor. If False,
      `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
      the next layer is linear (also e.g. `nn.relu`), this can be disabled
      since the scaling can be done by the next layer.
    epsilon: Small float added to variance to avoid dividing by zero.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: Optional initializers for beta, gamma, moving mean and
      moving variance.
    param_regularizers: Optional regularizer for beta and gamma.
    updates_collections: Collections to collect the update ops for
      computation. The updates_ops need to be executed with the train_op. If
      None, a control dependency would be added to make sure the updates are
      computed in place.
    is_training: Whether or not the layer is in training mode. In training
      mode it would accumulate the statistics of the moments into
      `moving_mean` and `moving_variance` using an exponential moving average
      with the given `decay`. When it is not in training mode then it would
      use the values of the `moving_mean` and the `moving_variance`.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional collections for the variables.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    batch_weights: An optional tensor of shape `[batch_size]`, containing a
      frequency weight for each batch item. If present, then the batch
      normalization uses weighted mean and variance. (This can be used to
      correct for bias in training example selection.)
    fused: Use nn.fused_batch_norm if True, nn.batch_normalization otherwise.
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
      pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
    scope: Optional scope for `variable_scope`.
    renorm: Whether to use Batch Renormalization
      (https://arxiv.org/abs/1702.03275). This adds extra variables during
      training. The inference is the same for either value of this parameter.
    renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
      scalar `Tensors` used to clip the renorm correction. The correction
      `(r, d)` is used as `corrected_value = normalized_value * r + d`, with
      `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax,
      rmin, dmax are set to inf, 0, inf, respectively.
    renorm_decay: Momentum used to update the moving means and standard
      deviations with renorm. Unlike `momentum`, this affects training and
      should be neither too small (which would add noise) nor too large
      (which would give stale estimates). Note that `decay` is still applied
      to get the means and variances for inference.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: If `batch_weights` is not None and `fused` is True.
    ValueError: If `param_regularizers` is not None and `fused` is True.
    ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: If the rank of `inputs` is undefined.
    ValueError: If rank or channels dimension of `inputs` is undefined.
  """
  if fused:
    if batch_weights is not None:
      raise ValueError('Weighted mean and variance is not currently '
                       'supported for fused batch norm.')
    if param_regularizers is not None:
      raise ValueError('Regularizers are not currently '
                       'supported for fused batch norm.')
    if renorm:
      raise ValueError('Renorm is not supported for fused batch norm.')
    return _fused_batch_norm(
        inputs,
        decay=decay,
        center=center,
        scale=scale,
        epsilon=epsilon,
        activation_fn=activation_fn,
        param_initializers=param_initializers,
        updates_collections=updates_collections,
        is_training=is_training,
        reuse=reuse,
        variables_collections=variables_collections,
        outputs_collections=outputs_collections,
        trainable=trainable,
        data_format=data_format,
        zero_debias_moving_mean=zero_debias_moving_mean,
        scope=scope)

  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')

  layer_variable_getter = _build_variable_getter()
  with variable_scope.variable_scope(
      scope, 'BatchNorm', [inputs], reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    inputs = ops.convert_to_tensor(inputs)

    # Determine whether we can use the core layer class.
    if (batch_weights is None and
        updates_collections is ops.GraphKeys.UPDATE_OPS and
        not zero_debias_moving_mean):
      # Use the core layer class.
      axis = 1 if data_format == DATA_FORMAT_NCHW else -1
      if not param_initializers:
        param_initializers = {}
      beta_initializer = param_initializers.get(
          'beta', init_ops.zeros_initializer())
      gamma_initializer = param_initializers.get(
          'gamma', init_ops.ones_initializer())
      moving_mean_initializer = param_initializers.get(
          'moving_mean', init_ops.zeros_initializer())
      moving_variance_initializer = param_initializers.get(
          'moving_variance', init_ops.ones_initializer())
      if not param_regularizers:
        param_regularizers = {}
      beta_regularizer = param_regularizers.get('beta')
      gamma_regularizer = param_regularizers.get('gamma')
      layer = normalization_layers.BatchNormalization(
          axis=axis,
          momentum=decay,
          epsilon=epsilon,
          center=center,
          scale=scale,
          beta_initializer=beta_initializer,
          gamma_initializer=gamma_initializer,
          moving_mean_initializer=moving_mean_initializer,
          moving_variance_initializer=moving_variance_initializer,
          beta_regularizer=beta_regularizer,
          gamma_regularizer=gamma_regularizer,
          trainable=trainable,
          renorm=renorm,
          renorm_clipping=renorm_clipping,
          renorm_momentum=renorm_decay,
          name=sc.name,
          _scope=sc,
          _reuse=reuse)
      outputs = layer.apply(inputs, training=is_training)

      # Add variables to collections.
      _add_variable_to_collections(
          layer.moving_mean, variables_collections, 'moving_mean')
      _add_variable_to_collections(
          layer.moving_variance, variables_collections, 'moving_variance')
      if layer.beta is not None:
        _add_variable_to_collections(layer.beta, variables_collections,
                                     'beta')
      if layer.gamma is not None:
        _add_variable_to_collections(layer.gamma, variables_collections,
                                     'gamma')

      if activation_fn is not None:
        outputs = activation_fn(outputs)
      return utils.collect_named_outputs(outputs_collections,
                                         sc.original_name_scope, outputs)

    # Not supported by layer class: batch_weights argument,
    # and custom updates_collections. In that case, use the legacy BN
    # implementation.
    # Custom updates collections are not supported because the update logic
    # is different in this case, in particular w.r.t. "forced updates" and
    # update op reuse.
    if renorm:
      raise ValueError('renorm is not supported with batch_weights, '
                       'updates_collections or zero_debias_moving_mean')
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    if batch_weights is not None:
      batch_weights = ops.convert_to_tensor(batch_weights)
      inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape())
      # Reshape batch weight values so they broadcast across inputs.
      nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
      batch_weights = array_ops.reshape(batch_weights, nshape)

    if data_format == DATA_FORMAT_NCHW:
      moments_axes = [0] + list(range(2, inputs_rank))
      params_shape = inputs_shape[1:2]
      # For NCHW format, rather than relying on implicit broadcasting, we
      # explicitly reshape the params to params_shape_broadcast when computing
      # the moments and the batch normalization.
      params_shape_broadcast = list(
          [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
    else:
      moments_axes = list(range(inputs_rank - 1))
      params_shape = inputs_shape[-1:]
      params_shape_broadcast = None
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined channels dimension %s.' % (
          inputs.name, params_shape))

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if not param_initializers:
      param_initializers = {}
    if center:
      beta_collections = utils.get_variable_collections(
          variables_collections, 'beta')
      beta_initializer = param_initializers.get(
          'beta', init_ops.zeros_initializer())
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=beta_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
    if scale:
      gamma_collections = utils.get_variable_collections(
          variables_collections, 'gamma')
      gamma_initializer = param_initializers.get(
          'gamma', init_ops.ones_initializer())
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=gamma_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)

    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections. We disable variable partitioning while creating
    # them, because assign_moving_average is not yet supported for partitioned
    # variables.
    partitioner = variable_scope.get_variable_scope().partitioner
    try:
      variable_scope.get_variable_scope().set_partitioner(None)
      moving_mean_collections = utils.get_variable_collections(
          variables_collections, 'moving_mean')
      moving_mean_initializer = param_initializers.get(
          'moving_mean', init_ops.zeros_initializer())
      moving_mean = variables.model_variable(
          'moving_mean',
          shape=params_shape,
          dtype=dtype,
          initializer=moving_mean_initializer,
          trainable=False,
          collections=moving_mean_collections)
      moving_variance_collections = utils.get_variable_collections(
          variables_collections, 'moving_variance')
      moving_variance_initializer = param_initializers.get(
          'moving_variance', init_ops.ones_initializer())
      moving_variance = variables.model_variable(
          'moving_variance',
          shape=params_shape,
          dtype=dtype,
          initializer=moving_variance_initializer,
          trainable=False,
          collections=moving_variance_collections)
    finally:
      variable_scope.get_variable_scope().set_partitioner(partitioner)

    # If `is_training` doesn't have a constant value, because it is a
    # `Tensor`, a `Variable` or `Placeholder` then is_training_value will be
    # None and `need_moments` will be true.
    is_training_value = utils.constant_value(is_training)
    need_moments = is_training_value is None or is_training_value
    if need_moments:
      # Calculate the moments based on the individual batch. Note the
      # modification: the variance is computed around the moving mean rather
      # than the batch mean.
      if batch_weights is None:
        if data_format == DATA_FORMAT_NCHW:
          mean, _ = nn.moments(inputs, moments_axes, keep_dims=True)
          variance, _ = nn.moments(
              (inputs - moving_mean)**2, moments_axes, keep_dims=True)
          mean = array_ops.reshape(mean, [-1])
          variance = array_ops.reshape(variance, [-1])
        else:
          mean, _ = nn.moments(inputs, moments_axes)
          variance, _ = nn.moments((inputs - moving_mean)**2, moments_axes)
      else:
        if data_format == DATA_FORMAT_NCHW:
          mean, _ = nn.weighted_moments(inputs, moments_axes,
                                        batch_weights, keep_dims=True)
          variance, _ = nn.weighted_moments(
              (inputs - moving_mean)**2, moments_axes, batch_weights,
              keep_dims=True)
          mean = array_ops.reshape(mean, [-1])
          variance = array_ops.reshape(variance, [-1])
        else:
          mean, _ = nn.weighted_moments(inputs, moments_axes, batch_weights)
          variance, _ = nn.weighted_moments(
              (inputs - moving_mean)**2, moments_axes, batch_weights)

      moving_vars_fn = lambda: (moving_mean, moving_variance)
      if updates_collections is None:
        def _force_updates():
          """Internal function that forces updates of moving_vars if
          is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          with ops.control_dependencies([update_moving_mean,
                                         update_moving_variance]):
            return array_ops.identity(mean), array_ops.identity(variance)
        mean, variance = utils.smart_cond(is_training,
                                          _force_updates,
                                          moving_vars_fn)
      else:
        def _delay_updates():
          """Internal function that delays updates of moving_vars if
          is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          return update_moving_mean, update_moving_variance

        update_mean, update_variance = utils.smart_cond(is_training,
                                                        _delay_updates,
                                                        moving_vars_fn)
        ops.add_to_collections(updates_collections, update_mean)
        ops.add_to_collections(updates_collections, update_variance)
        # Use computed moments during training and moving_vars otherwise.
        vars_fn = lambda: (mean, variance)
        mean, variance = utils.smart_cond(is_training, vars_fn,
                                          moving_vars_fn)
    else:
      mean, variance = moving_mean, moving_variance
    if data_format == DATA_FORMAT_NCHW:
      mean = array_ops.reshape(mean, params_shape_broadcast)
      variance = array_ops.reshape(variance, params_shape_broadcast)
      if beta is not None:
        beta = array_ops.reshape(beta, params_shape_broadcast)
      if gamma is not None:
        gamma = array_ops.reshape(gamma, params_shape_broadcast)

    # Compute batch_normalization.
    outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                     epsilon)
    outputs.set_shape(inputs_shape)
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)
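The renorm correction described in the docstring is simple to state outside TensorFlow. A NumPy sketch with the documented clipping defaults (rmax, rmin, dmax default to inf, 0, inf):

import numpy as np

def renorm_correction(mu_b, sigma_b, mu, sigma,
                      rmin=0.0, rmax=np.inf, dmax=np.inf):
  # corrected_value = normalized_value * r + d
  r = np.clip(sigma_b / sigma, rmin, rmax)
  d = np.clip((mu_b - mu) / sigma, -dmax, dmax)
  return r, d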
def call(self, inputs, training=False):
  if self.fused:
    return self._fused_batch_norm(inputs, training=training)

  # First, compute the axes along which to reduce the mean / variance,
  # as well as the broadcast shape to be used for all parameters.
  input_shape = inputs.get_shape()
  ndim = len(input_shape)
  reduction_axes = list(range(len(input_shape)))
  del reduction_axes[self.axis]
  broadcast_shape = [1] * len(input_shape)
  broadcast_shape[self.axis] = input_shape[self.axis].value

  # Determines whether broadcasting is needed.
  needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

  scale, offset = self.gamma, self.beta

  # Determine a boolean value for `training`: could be True, False, or None.
  training_value = utils.constant_value(training)
  if training_value is not False:
    # Some of the computations here are not necessary when training==False
    # but not a constant. However, this makes the code simpler.
    mean, variance = nn.moments(inputs, reduction_axes)
    mean = _smart_select(training,
                         lambda: mean,
                         lambda: self.moving_mean)
    variance = _smart_select(training,
                             lambda: variance,
                             lambda: self.moving_variance)

    if self.renorm:
      r, d, new_mean, new_variance = self._renorm_correction_and_moments(
          mean, variance, training)
      # When training, the normalized values (say, x) will be transformed as
      # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
      # = x * (r * gamma) + (d * gamma + beta) with renorm.
      scale = array_ops.stop_gradient(r, name='renorm_r')
      offset = array_ops.stop_gradient(d, name='renorm_d')
      if self.gamma is not None:
        scale *= self.gamma
        offset *= self.gamma
      if self.beta is not None:
        offset += self.beta
    else:
      new_mean, new_variance = mean, variance

    # Update moving averages when training, and prevent updates otherwise.
    decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
    mean_update = moving_averages.assign_moving_average(
        self.moving_mean, new_mean, decay, zero_debias=False)
    variance_update = moving_averages.assign_moving_average(
        self.moving_variance, new_variance, decay, zero_debias=False)
    self.add_update(mean_update, inputs=inputs)
    self.add_update(variance_update, inputs=inputs)
  else:
    mean, variance = self.moving_mean, self.moving_variance

  def _broadcast(v):
    if needs_broadcasting and v is not None:
      # In this case we must explicitly broadcast all parameters.
      return array_ops.reshape(v, broadcast_shape)
    return v

  return nn.batch_normalization(inputs,
                                _broadcast(mean),
                                _broadcast(variance),
                                _broadcast(offset),
                                _broadcast(scale),
                                self.epsilon)
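A one-line sanity check of the algebra in the renorm comment above:

import numpy as np

x, r, d, gamma, beta = np.random.randn(5)
assert np.allclose((x * r + d) * gamma + beta,
                   x * (r * gamma) + (d * gamma + beta))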
def call(self, inputs, training=False):
  if self.virtual_batch_size is not None:
    # Virtual batches (aka ghost batches) can be simulated by reshaping the
    # Tensor and reusing the existing batch norm implementation
    original_shape = [-1] + inputs.shape.as_list()[1:]
    expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

    # Will cause errors if virtual_batch_size does not divide the batch size
    inputs = array_ops.reshape(inputs, expanded_shape)

    def undo_virtual_batching(outputs):
      outputs = array_ops.reshape(outputs, original_shape)
      return outputs

  if self.fused:
    outputs = self._fused_batch_norm(inputs, training=training)
    if self.virtual_batch_size is not None:
      # Currently never reaches here since fused_batch_norm does not support
      # virtual batching
      return undo_virtual_batching(outputs)
    return outputs

  # Compute the axes along which to reduce the mean / variance
  input_shape = inputs.get_shape()
  ndims = len(input_shape)
  reduction_axes = [i for i in range(ndims) if i not in self.axis]
  if self.virtual_batch_size is not None:
    del reduction_axes[1]  # Do not reduce along virtual batch dim

  # Broadcasting only necessary for single-axis batch norm where the axis is
  # not the last dimension
  broadcast_shape = [1] * ndims
  broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value

  def _broadcast(v):
    if (v is not None and
        len(v.get_shape()) != ndims and
        reduction_axes != list(range(ndims - 1))):
      return array_ops.reshape(v, broadcast_shape)
    return v

  scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

  def _compose_transforms(scale, offset, then_scale, then_offset):
    if then_scale is not None:
      scale *= then_scale
      offset *= then_scale
    if then_offset is not None:
      offset += then_offset
    return (scale, offset)

  # Determine a boolean value for `training`: could be True, False, or None.
  training_value = utils.constant_value(training)
  if training_value is not False:
    if self.adjustment:
      adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
      # Adjust only during training.
      adj_scale = utils.smart_cond(training,
                                   lambda: adj_scale,
                                   lambda: array_ops.ones_like(adj_scale))
      adj_bias = utils.smart_cond(training,
                                  lambda: adj_bias,
                                  lambda: array_ops.zeros_like(adj_bias))
      scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

    # Some of the computations here are not necessary when training==False
    # but not a constant. However, this makes the code simpler.
    keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)

    moving_mean = self.moving_mean
    moving_variance = self.moving_variance

    mean = utils.smart_cond(training,
                            lambda: mean,
                            lambda: moving_mean)
    variance = utils.smart_cond(training,
                                lambda: variance,
                                lambda: moving_variance)

    if self.renorm:
      r, d, new_mean, new_variance = self._renorm_correction_and_moments(
          mean, variance, training)
      # When training, the normalized values (say, x) will be transformed as
      # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
      # = x * (r * gamma) + (d * gamma + beta) with renorm.
      r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
      d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
      scale, offset = _compose_transforms(r, d, scale, offset)
    else:
      new_mean, new_variance = mean, variance

    if self.virtual_batch_size is not None:
      # This isn't strictly correct since in ghost batch norm, you are
      # supposed to sequentially update the moving_mean and moving_variance
      # with each sub-batch. However, since the moving statistics are only
      # used during evaluation, it is more efficient to just update in one
      # step and should not make a significant difference in the result.
      new_mean = math_ops.reduce_mean(new_mean, axis=1, keep_dims=True)
      new_variance = math_ops.reduce_mean(new_variance, axis=1,
                                          keep_dims=True)

    def _do_update(var, value):
      return moving_averages.assign_moving_average(
          var, value, self.momentum, zero_debias=False)

    mean_update = utils.smart_cond(
        training,
        lambda: _do_update(self.moving_mean, new_mean),
        lambda: self.moving_mean)
    variance_update = utils.smart_cond(
        training,
        lambda: _do_update(self.moving_variance, new_variance),
        lambda: self.moving_variance)
    if context.in_graph_mode():
      self.add_update(mean_update, inputs=inputs)
      self.add_update(variance_update, inputs=inputs)
  else:
    mean, variance = self.moving_mean, self.moving_variance

  outputs = nn.batch_normalization(inputs,
                                   _broadcast(mean),
                                   _broadcast(variance),
                                   offset,
                                   scale,
                                   self.epsilon)
  # If some components of the shape got lost due to adjustments, fix that.
  outputs.set_shape(input_shape)

  if self.virtual_batch_size is not None:
    return undo_virtual_batching(outputs)
  return outputs
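A NumPy sketch of the ghost batch trick above: after the reshape, statistics are taken over the leading virtual-batch dimension only, so each group of examples is normalized with its own moments.

import numpy as np

def ghost_batch_norm(x, virtual_batch_size, eps=1e-3):
  n, c = x.shape
  xv = x.reshape(virtual_batch_size, -1, c)  # [vbs, groups, C]
  mean = xv.mean(axis=0, keepdims=True)      # reduce the virtual-batch dim
  var = xv.var(axis=0, keepdims=True)        # keep per-group statistics
  return ((xv - mean) / np.sqrt(var + eps)).reshape(n, c)

out = ghost_batch_norm(np.random.randn(8, 3), virtual_batch_size=4)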
def call(self, inputs, training=None): if training is None: training = K.learning_phase() in_eager_mode = context.executing_eagerly() if self.virtual_batch_size is not None: # Virtual batches (aka ghost batches) can be simulated by reshaping the # Tensor and reusing the existing batch norm implementation original_shape = [-1] + inputs.shape.as_list()[1:] expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:] # Will cause errors if virtual_batch_size does not divide the batch size inputs = array_ops.reshape(inputs, expanded_shape) def undo_virtual_batching(outputs): outputs = array_ops.reshape(outputs, original_shape) return outputs if self.fused: outputs = self._fused_batch_norm(inputs, training=training) if self.virtual_batch_size is not None: # Currently never reaches here since fused_batch_norm does not support # virtual batching outputs = undo_virtual_batching(outputs) return outputs # Compute the axes along which to reduce the mean / variance input_shape = inputs.get_shape() ndims = len(input_shape) reduction_axes = [i for i in range(ndims) if i not in self.axis] if self.virtual_batch_size is not None: del reduction_axes[1] # Do not reduce along virtual batch dim # Broadcasting only necessary for single-axis batch norm where the axis is # not the last dimension broadcast_shape = [1] * ndims broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value def _broadcast(v): if (v is not None and len(v.get_shape()) != ndims and reduction_axes != list(range(ndims - 1))): return array_ops.reshape(v, broadcast_shape) return v scale, offset = _broadcast(self.gamma), _broadcast(self.beta) def _compose_transforms(scale, offset, then_scale, then_offset): if then_scale is not None: scale *= then_scale offset *= then_scale if then_offset is not None: offset += then_offset return (scale, offset) # Determine a boolean value for `training`: could be True, False, or None. training_value = tf_utils.constant_value(training) if training_value is not False: if self.adjustment: adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs)) # Adjust only during training. adj_scale = tf_utils.smart_cond( training, lambda: adj_scale, lambda: array_ops.ones_like(adj_scale)) adj_bias = tf_utils.smart_cond( training, lambda: adj_bias, lambda: array_ops.zeros_like(adj_bias)) scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset) # Some of the computations here are not necessary when training==False # but not a constant. However, this makes the code simpler. keep_dims = self.virtual_batch_size is not None or len( self.axis) > 1 mean, variance = self._moments(inputs, reduction_axes, keep_dims=keep_dims) moving_mean = self.moving_mean moving_variance = self.moving_variance mean = tf_utils.smart_cond(training, lambda: mean, lambda: moving_mean) variance = tf_utils.smart_cond(training, lambda: variance, lambda: moving_variance) if self.virtual_batch_size is not None: # This isn't strictly correct since in ghost batch norm, you are # supposed to sequentially update the moving_mean and moving_variance # with each sub-batch. However, since the moving statistics are only # used during evaluation, it is more efficient to just update in one # step and should not make a significant difference in the result. 
new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True) new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True) else: new_mean, new_variance = mean, variance if self.renorm: r, d, new_mean, new_variance = self._renorm_correction_and_moments( new_mean, new_variance, training) # When training, the normalized values (say, x) will be transformed as # x * gamma + beta without renorm, and (x * r + d) * gamma + beta # = x * (r * gamma) + (d * gamma + beta) with renorm. r = _broadcast(array_ops.stop_gradient(r, name='renorm_r')) d = _broadcast(array_ops.stop_gradient(d, name='renorm_d')) scale, offset = _compose_transforms(r, d, scale, offset) if distribution_strategy_context.in_cross_replica_context(): strategy = distribution_strategy_context.get_strategy() def _do_update(var, value): """Compute the updates for mean and variance.""" if in_eager_mode and not self.trainable: return return strategy.extended.update( var, self._assign_moving_average, (value, self.momentum), group=False) # We need to unwrap the moving_mean or moving_variance in the case of # training being false to match the output of true_fn and false_fn # in the smart cond. mean_update = tf_utils.smart_cond( training, lambda: _do_update(self.moving_mean, new_mean), lambda: strategy.unwrap(self.moving_mean)) variance_update = tf_utils.smart_cond( training, lambda: _do_update(self.moving_variance, new_variance), lambda: strategy.unwrap(self.moving_variance)) else: def _do_update(var, value): """Compute the updates for mean and variance.""" if in_eager_mode and not self.trainable: return return self._assign_moving_average(var, value, self.momentum) mean_update = tf_utils.smart_cond( training, lambda: _do_update(self.moving_mean, new_mean), lambda: self.moving_mean) variance_update = tf_utils.smart_cond( training, lambda: _do_update(self.moving_variance, new_variance), lambda: self.moving_variance) if not context.executing_eagerly(): self.add_update(mean_update, inputs=True) self.add_update(variance_update, inputs=True) else: mean, variance = self.moving_mean, self.moving_variance mean = math_ops.cast(mean, inputs.dtype) variance = math_ops.cast(variance, inputs.dtype) if offset is not None: offset = math_ops.cast(offset, inputs.dtype) outputs = nn.batch_normalization(inputs, _broadcast(mean), _broadcast(variance), offset, scale, self.epsilon) # If some components of the shape got lost due to adjustments, fix that. outputs.set_shape(input_shape) if self.virtual_batch_size is not None: outputs = undo_virtual_batching(outputs) return outputs
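# --- Note: the virtual ("ghost") batch trick above is just a reshape plus
# per-sub-batch moments. A hedged NumPy sketch on a 2-D input (illustrative
# only; it mirrors expanded_shape = [virtual_batch_size, -1] + original_shape[1:]
# and the deletion of axis 1 from reduction_axes):
import numpy as np

def ghost_batch_norm_2d(x, virtual_batch_size, eps=1e-3):
    n, c = x.shape
    assert n % virtual_batch_size == 0, 'virtual batch must divide batch size'
    xv = x.reshape(virtual_batch_size, -1, c)   # (vbs, num_ghost_batches, C)
    # Reduce over the example axis only; axis 1 (the ghost-batch axis) is
    # kept, so each ghost batch gets its own mean and variance.
    mean = xv.mean(axis=0, keepdims=True)
    var = xv.var(axis=0, keepdims=True)
    y = (xv - mean) / np.sqrt(var + eps)
    return y.reshape(n, c)                      # undo_virtual_batching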
def batch_norm(inputs, decay=0.999, center=True, scale=False, epsilon=0.001, activation_fn=None, updates_collections=ops.GraphKeys.UPDATE_OPS, is_training=True, reuse=None, variables_collections=None, outputs_collections=None, trainable=True, scope=None): """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167. "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" Sergey Ioffe, Christian Szegedy Can be used as a normalizer function for conv2d and fully_connected. Args: inputs: a tensor of size `[batch_size, height, width, channels]` or `[batch_size, channels]`. decay: decay for the moving average. center: If True, subtract `beta`. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. epsilon: small float added to variance to avoid dividing by zero. activation_fn: Optional activation function. updates_collections: collections to collect the update ops for computation. If None, a control dependency would be added to make sure the updates are computed. is_training: whether or not the layer is in training mode. In training mode it would accumulate the statistics of the moments into `moving_mean` and `moving_variance` using an exponential moving average with the given `decay`. When it is not in training mode then it would use the values of the `moving_mean` and the `moving_variance`. reuse: whether or not the layer and its variables should be reused. To be able to reuse the layer, `scope` must be given. variables_collections: optional collections for the variables. outputs_collections: collections to add the outputs. trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). scope: Optional scope for `variable_op_scope`. Returns: a tensor representing the output of the operation. """ with variable_scope.variable_op_scope([inputs], scope, 'BatchNorm', reuse=reuse) as sc: inputs_shape = inputs.get_shape() dtype = inputs.dtype.base_dtype axis = list(range(len(inputs_shape) - 1)) params_shape = inputs_shape[-1:] # Allocate parameters for the beta and gamma of the normalization. beta, gamma = None, None if center: beta_collections = utils.get_variable_collections(variables_collections, 'beta') beta = variables.model_variable('beta', shape=params_shape, dtype=dtype, initializer=init_ops.zeros_initializer, collections=beta_collections, trainable=trainable) if scale: gamma_collections = utils.get_variable_collections(variables_collections, 'gamma') gamma = variables.model_variable('gamma', shape=params_shape, dtype=dtype, initializer=init_ops.ones_initializer, collections=gamma_collections, trainable=trainable) # Create moving_mean and moving_variance variables and add them to the # appropriate collections. moving_mean_collections = utils.get_variable_collections(variables_collections, 'moving_mean') moving_mean = variables.model_variable('moving_mean', shape=params_shape, dtype=dtype, initializer=init_ops.zeros_initializer, trainable=False, collections=moving_mean_collections) moving_variance_collections = utils.get_variable_collections(variables_collections, 'moving_variance') moving_variance = variables.model_variable('moving_variance', shape=params_shape, dtype=dtype, initializer=init_ops.ones_initializer, trainable=False, collections=moving_variance_collections) if is_training: # Calculate the moments based on the individual batch.
mean, variance = nn.moments(inputs, axis, shift=moving_mean) # Update the moving_mean and moving_variance moments. update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, decay) update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, decay) if updates_collections is None: # Make sure the updates are computed here. with ops.control_dependencies([update_moving_mean,update_moving_variance]): outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon) else: # Collect the updates to be computed later. ops.add_to_collections(updates_collections, update_moving_mean) ops.add_to_collections(updates_collections, update_moving_variance) outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon) else: outputs = nn.batch_normalization( inputs, moving_mean, moving_variance, beta, gamma, epsilon) outputs.set_shape(inputs.get_shape()) if activation_fn: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
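# --- Note: assign_moving_average(var, value, decay) implements, in effect,
# var <- decay * var + (1 - decay) * value. A small NumPy sketch of how the
# moving statistics drift toward the batch statistics (synthetic data,
# purely illustrative):
import numpy as np

decay = 0.999
moving_mean, moving_variance = 0.0, 1.0
for _ in range(1000):
    batch = np.random.randn(256) * 2.0 + 5.0   # true mean ~5, variance ~4
    moving_mean = decay * moving_mean + (1 - decay) * batch.mean()
    moving_variance = decay * moving_variance + (1 - decay) * batch.var()
print(moving_mean, moving_variance)  # slowly drifting toward 5 and 4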
def batch_norm(inputs, decay=0.999, center=True, scale=False, epsilon=0.001, updates_collections=ops.GraphKeys.UPDATE_OPS, is_training=True, reuse=None, variables_collections=None, outputs_collections=None, trainable=True, scope=None): """Code modification of tensorflow/contrib/layers/python/layers/layers.py """ with variable_scope.variable_op_scope([inputs], scope, 'BatchNorm', reuse=reuse) as sc: inputs = ops.convert_to_tensor(inputs) inputs_shape = inputs.get_shape() inputs_rank = inputs_shape.ndims if inputs_rank is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) dtype = inputs.dtype.base_dtype axis = list(range(inputs_rank - 1)) params_shape = inputs_shape[-1:] if not params_shape.is_fully_defined(): raise ValueError('Inputs %s has undefined last dimension %s.' % (inputs.name, params_shape)) # Allocate parameters for the beta and gamma of the normalization. beta, gamma = None, None if center: beta_collections = utils.get_variable_collections( variables_collections, 'beta') beta = variables.model_variable( 'beta', shape=params_shape, dtype=dtype, initializer=init_ops.zeros_initializer, collections=beta_collections, trainable=trainable) if scale: gamma_collections = utils.get_variable_collections( variables_collections, 'gamma') gamma = variables.model_variable( 'gamma', shape=params_shape, dtype=dtype, initializer=init_ops.ones_initializer, collections=gamma_collections, trainable=trainable) # Create moving_mean and moving_variance variables and add them to the # appropriate collections. moving_mean_collections = utils.get_variable_collections( variables_collections, 'moving_mean') moving_mean = variables.model_variable( 'moving_mean', shape=params_shape, dtype=dtype, initializer=init_ops.zeros_initializer, trainable=False, collections=moving_mean_collections) moving_variance_collections = utils.get_variable_collections( variables_collections, 'moving_variance') moving_variance = variables.model_variable( 'moving_variance', shape=params_shape, dtype=dtype, initializer=init_ops.ones_initializer, trainable=False, collections=moving_variance_collections) # Calculate the moments based on the individual batch. mean, variance = nn.moments(inputs, axis, shift=moving_mean) # Update the moving_mean and moving_variance moments. update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay) if updates_collections is None: # Make sure the updates are computed here. with ops.control_dependencies( [update_moving_mean, update_moving_variance]): outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon) else: # Collect the updates to be computed later. ops.add_to_collections(updates_collections, update_moving_mean) ops.add_to_collections(updates_collections, update_moving_variance) outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon) test_outputs = nn.batch_normalization(inputs, moving_mean, moving_variance, beta, gamma, epsilon) outputs = tf.cond(is_training, lambda: outputs, lambda: test_outputs) outputs.set_shape(inputs_shape) return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
def call(self, inputs, training=None): training = self._get_training_value(training) if self.virtual_batch_size is not None: # Virtual batches (aka ghost batches) can be simulated by reshaping the # Tensor and reusing the existing batch norm implementation original_shape = [-1] + inputs.shape.as_list()[1:] expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:] # Will cause errors if virtual_batch_size does not divide the batch size inputs = array_ops.reshape(inputs, expanded_shape) def undo_virtual_batching(outputs): outputs = array_ops.reshape(outputs, original_shape) return outputs if self.fused: outputs = self._fused_batch_norm(inputs, training=training) if self.virtual_batch_size is not None: # Currently never reaches here since fused_batch_norm does not support # virtual batching outputs = undo_virtual_batching(outputs) return outputs # Compute the axes along which to reduce the mean / variance input_shape = inputs.shape ndims = len(input_shape) reduction_axes = [i for i in range(ndims) if i not in self.axis] if self.virtual_batch_size is not None: del reduction_axes[1] # Do not reduce along virtual batch dim # Broadcasting only necessary for single-axis batch norm where the axis is # not the last dimension broadcast_shape = [1] * ndims broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value def _broadcast(v): if (v is not None and len(v.shape) != ndims and reduction_axes != list(range(ndims - 1))): return array_ops.reshape(v, broadcast_shape) return v scale, offset = _broadcast(self.gamma), _broadcast(self.beta) def _compose_transforms(scale, offset, then_scale, then_offset): if then_scale is not None: scale *= then_scale offset *= then_scale if then_offset is not None: offset += then_offset return (scale, offset) # Determine a boolean value for `training`: could be True, False, or None. training_value = tf_utils.constant_value(training) if training_value == False: # pylint: disable=singleton-comparison,g-explicit-bool-comparison mean, variance = self.moving_mean, self.moving_variance else: if self.adjustment: adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs)) # Adjust only during training. adj_scale = tf_utils.smart_cond( training, lambda: adj_scale, lambda: array_ops.ones_like(adj_scale)) adj_bias = tf_utils.smart_cond( training, lambda: adj_bias, lambda: array_ops.zeros_like(adj_bias)) scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset) # Some of the computations here are not necessary when training==False # but not a constant. However, this makes the code simpler. keep_dims = self.virtual_batch_size is not None or len( self.axis) > 1 mean, variance = self._moments(math_ops.cast( inputs, self._param_dtype), reduction_axes, keep_dims=keep_dims) moving_mean = self.moving_mean moving_variance = self.moving_variance mean = tf_utils.smart_cond( training, lambda: mean, lambda: ops.convert_to_tensor(moving_mean)) variance = tf_utils.smart_cond( training, lambda: variance, lambda: ops.convert_to_tensor(moving_variance)) if self.virtual_batch_size is not None: # This isn't strictly correct since in ghost batch norm, you are # supposed to sequentially update the moving_mean and moving_variance # with each sub-batch. However, since the moving statistics are only # used during evaluation, it is more efficient to just update in one # step and should not make a significant difference in the result. 
new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True) new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True) else: new_mean, new_variance = mean, variance if self._support_zero_size_input(): inputs_size = array_ops.size(inputs) else: inputs_size = None if self.renorm: r, d, new_mean, new_variance = self._renorm_correction_and_moments( new_mean, new_variance, training, inputs_size) # When training, the normalized values (say, x) will be transformed as # x * gamma + beta without renorm, and (x * r + d) * gamma + beta # = x * (r * gamma) + (d * gamma + beta) with renorm. r = _broadcast(array_ops.stop_gradient(r, name='renorm_r')) d = _broadcast(array_ops.stop_gradient(d, name='renorm_d')) scale, offset = _compose_transforms(r, d, scale, offset) def _do_update(var, value): """Compute the updates for mean and variance.""" return self._assign_moving_average(var, value, self.momentum, inputs_size) def mean_update(): true_branch = lambda: _do_update(self.moving_mean, new_mean) false_branch = lambda: self.moving_mean return tf_utils.smart_cond(training, true_branch, false_branch) def variance_update(): """Update the moving variance.""" def true_branch_renorm(): # We apply epsilon as part of the moving_stddev to mirror the training # code path. moving_stddev = _do_update( self.moving_stddev, math_ops.sqrt(new_variance + self.epsilon)) return self._assign_new_value( self.moving_variance, # Apply relu in case floating point rounding causes it to go # negative. K.relu(moving_stddev * moving_stddev - self.epsilon)) if self.renorm: true_branch = true_branch_renorm else: true_branch = lambda: _do_update(self.moving_variance, new_variance) false_branch = lambda: self.moving_variance return tf_utils.smart_cond(training, true_branch, false_branch) self.add_update(mean_update) self.add_update(variance_update) mean = math_ops.cast(mean, inputs.dtype) variance = math_ops.cast(variance, inputs.dtype) if offset is not None: offset = math_ops.cast(offset, inputs.dtype) if scale is not None: scale = math_ops.cast(scale, inputs.dtype) # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing # math in float16 hurts validation accuracy of popular models like resnet. outputs = nn.batch_normalization(inputs, _broadcast(mean), _broadcast(variance), offset, scale, self.epsilon) # If some components of the shape got lost due to adjustments, fix that. outputs.set_shape(input_shape) if self.virtual_batch_size is not None: outputs = undo_virtual_batching(outputs) return outputs
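# --- Note: the renorm branch above relies on the batch renormalization
# correction (Ioffe, 2017): r = sigma_batch / sigma_moving and
# d = (mu_batch - mu_moving) / sigma_moving, both clipped and treated as
# constants via stop_gradient. A hedged NumPy sketch (the rmax/dmax values
# are illustrative):
import numpy as np

def renorm_correction(batch_mean, batch_std, moving_mean, moving_std,
                      rmax=3.0, dmax=5.0):
    r = np.clip(batch_std / moving_std, 1.0 / rmax, rmax)
    d = np.clip((batch_mean - moving_mean) / moving_std, -dmax, dmax)
    return r, d

x = np.random.randn(512) * 1.5 + 0.3
r, d = renorm_correction(x.mean(), x.std(), moving_mean=0.0, moving_std=1.0)
x_hat = (x - x.mean()) / x.std()
y = x_hat * r + d   # without clipping this equals (x - mu_moving) / sigma_moving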
input_channels = 8 fixed_size = 8 fixed_prec = 4 testdata_scale = 10 inputs_vals = np.random.normal(size=(batch_size, input_width, input_height, input_channels)) * testdata_scale // 1 inputs = tf.constant(inputs_vals, dtype=tf.float64) means, variances = nn.moments(inputs, [0, 1, 2, 3]) quantizer = Quantizers.NoQuantizer() output = QBN.qbatch_normalization(inputs, means, variances, None, None, 0.0001, quantizer) gold_output = nn.batch_normalization(inputs, means, variances, None, None, 0.0001) with tf.Session() as sess: gold_result = gold_output.eval().flatten() result = output.eval().flatten() print('mean: %f' % (sess.run(means))) print('variance: %f' % (sess.run(variances))) failed = False for i in range(len(result)): if result[i] != gold_result[i]: failed = True
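# --- Note: the element-by-element `!=` comparison above is brittle for
# floating-point outputs; unless the quantized path is meant to be bit-exact,
# a tolerance-based check is usually preferable. Hedged sketch:
import numpy as np

failed = not np.allclose(result, gold_result, rtol=1e-5, atol=1e-8)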
def call(self, inputs, training=None): if training is None: training = K.learning_phase() if self.virtual_batch_size is not None: # Virtual batches (aka ghost batches) can be simulated by reshaping the # Tensor and reusing the existing batch norm implementation original_shape = [-1] + inputs.shape.as_list()[1:] expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:] # Will cause errors if virtual_batch_size does not divide the batch size inputs = array_ops.reshape(inputs, expanded_shape) def undo_virtual_batching(outputs): outputs = array_ops.reshape(outputs, original_shape) return outputs if self.fused: outputs = self._fused_batch_norm(inputs, training=training) if self.virtual_batch_size is not None: # Currently never reaches here since fused_batch_norm does not support # virtual batching outputs = undo_virtual_batching(outputs) return outputs # Compute the axes along which to reduce the mean / variance input_shape = inputs.shape ndims = len(input_shape) reduction_axes = [i for i in range(ndims) if i not in self.axis] if self.virtual_batch_size is not None: del reduction_axes[1] # Do not reduce along virtual batch dim # Broadcasting only necessary for single-axis batch norm where the axis is # not the last dimension broadcast_shape = [1] * ndims broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value def _broadcast(v): if (v is not None and len(v.shape) != ndims and reduction_axes != list(range(ndims - 1))): return array_ops.reshape(v, broadcast_shape) return v scale, offset = _broadcast(self.gamma), _broadcast(self.beta) def _compose_transforms(scale, offset, then_scale, then_offset): if then_scale is not None: scale *= then_scale offset *= then_scale if then_offset is not None: offset += then_offset return (scale, offset) # Determine a boolean value for `training`: could be True, False, or None. training_value = tf_utils.constant_value(training) if training_value is not False: if self.adjustment: adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs)) # Adjust only during training. adj_scale = tf_utils.smart_cond(training, lambda: adj_scale, lambda: array_ops.ones_like(adj_scale)) adj_bias = tf_utils.smart_cond(training, lambda: adj_bias, lambda: array_ops.zeros_like(adj_bias)) scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset) # Some of the computations here are not necessary when training==False # but not a constant. However, this makes the code simpler. keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1 mean, variance = self._moments( math_ops.cast(inputs, self._param_dtype), reduction_axes, keep_dims=keep_dims) moving_mean = self.moving_mean moving_variance = self.moving_variance mean = tf_utils.smart_cond(training, lambda: mean, lambda: moving_mean) variance = tf_utils.smart_cond(training, lambda: variance, lambda: moving_variance) if self.virtual_batch_size is not None: # This isn't strictly correct since in ghost batch norm, you are # supposed to sequentially update the moving_mean and moving_variance # with each sub-batch. However, since the moving statistics are only # used during evaluation, it is more efficient to just update in one # step and should not make a significant difference in the result. 
new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True) new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True) else: new_mean, new_variance = mean, variance if self.renorm: r, d, new_mean, new_variance = self._renorm_correction_and_moments( new_mean, new_variance, training) # When training, the normalized values (say, x) will be transformed as # x * gamma + beta without renorm, and (x * r + d) * gamma + beta # = x * (r * gamma) + (d * gamma + beta) with renorm. r = _broadcast(array_ops.stop_gradient(r, name='renorm_r')) d = _broadcast(array_ops.stop_gradient(d, name='renorm_d')) scale, offset = _compose_transforms(r, d, scale, offset) if distribution_strategy_context.in_cross_replica_context(): strategy = distribution_strategy_context.get_strategy() def _do_update(var, value): """Compute the updates for mean and variance.""" return strategy.extended.update( var, self._assign_moving_average, (value, self.momentum), group=False) # We need to unwrap the moving_mean or moving_variance in the case of # training being false to match the output of true_fn and false_fn # in the smart cond. def mean_update(): true_branch = lambda: _do_update(self.moving_mean, new_mean) false_branch = lambda: strategy.unwrap(self.moving_mean) return tf_utils.smart_cond(training, true_branch, false_branch) def variance_update(): return tf_utils.smart_cond( training, lambda: _do_update(self.moving_variance, new_variance), lambda: strategy.unwrap(self.moving_variance)) else: def _do_update(var, value): """Compute the updates for mean and variance.""" return self._assign_moving_average(var, value, self.momentum) def mean_update(): true_branch = lambda: _do_update(self.moving_mean, new_mean) false_branch = lambda: self.moving_mean return tf_utils.smart_cond(training, true_branch, false_branch) def variance_update(): true_branch = lambda: _do_update(self.moving_variance, new_variance) false_branch = lambda: self.moving_variance return tf_utils.smart_cond(training, true_branch, false_branch) self.add_update(mean_update, inputs=True) self.add_update(variance_update, inputs=True) else: mean, variance = self.moving_mean, self.moving_variance mean = math_ops.cast(mean, inputs.dtype) variance = math_ops.cast(variance, inputs.dtype) if offset is not None: offset = math_ops.cast(offset, inputs.dtype) if scale is not None: scale = math_ops.cast(scale, inputs.dtype) # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing # math in float16 hurts validation accuracy of popular models like resnet. outputs = nn.batch_normalization(inputs, _broadcast(mean), _broadcast(variance), offset, scale, self.epsilon) # If some components of the shape got lost due to adjustments, fix that. outputs.set_shape(input_shape) if self.virtual_batch_size is not None: outputs = undo_virtual_batching(outputs) return outputs
def _RedoRestBatchnorms(graph, is_training): """Finds fused batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise convolution. Args: graph: Graph to walk and modify. is_training: Bool, true if training. Raises: ValueError: When batch norm folding fails. """ matches = _FindRestBatchNorms(graph) print("Replacing", len(matches), "BatchNorms (without a preceding Conv2D)") for match in matches: scope, sep, _ = match.bn_op.name.rpartition('/') # Make sure new ops are added to `graph` and put on the same device as # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope # named `scope`. Otherwise, TF creates a unique scope whose name starts with # `scope`. with graph.as_default(), graph.name_scope(scope + sep): with graph.name_scope(scope + sep + '_psb' + sep): mean = match.mean_tensor variance = match.variance_tensor beta = match.beta_tensor gamma = match.gamma_tensor eps = match.batch_epsilon # new gamma = gamma / sqrt(variance + epsilon) # new biases = -mean * gamma / sqrt(variance + epsilon) + beta multfac = gamma / math_ops.sqrt(variance + eps) gamma = multfac beta = -multfac * mean + beta mean = array_ops.zeros_like(mean) variance = array_ops.ones_like(variance) eps = array_ops.zeros_like(eps) gamma = variableFromSettings([], hiddenVar=gamma)[0] # gamma = fixed_point(gamma,S("util.variable.fixed_point.bits"),max=S("util.variable.fixed_point.max"),min=S("util.variable.fixed_point.min")) # gamma = next_base2(gamma,strict_positive=False) # gamma = 1/variableFromSettings([],hiddenVar=1/gamma)[0] # variance = variableFromSettings([],hiddenVar=math_ops.sqrt(variance+eps))[0]**2 # beta = variableFromSettings([],hiddenVar=beta)[0] if S("util.variable.fixed_point.use"): beta = fixed_point(beta, S("util.variable.fixed_point.bits"), max=S("util.variable.fixed_point.max"), min=S("util.variable.fixed_point.min")) # gamma = fixed_point(gamma,S("util.variable.fixed_point.bits"),max=S("util.variable.fixed_point.max"),min=S("util.variable.fixed_point.min")) # mean = fixed_point(mean,S("util.variable.fixed_point.bits"),max=S("util.variable.fixed_point.max"),min=S("util.variable.fixed_point.min")) # variance = fixed_point(variance,S("util.variable.fixed_point.bits"),max=S("util.variable.fixed_point.max"),min=S("util.variable.fixed_point.min")) new_layer_tensor = nn.batch_normalization( match.input_tensor, mean, variance, beta, gamma, eps, name=match.bn_op.name.split("/")[-1] + "_psb") if S("util.variable.fixed_point.use"): new_layer_tensor = fixed_point( new_layer_tensor, S("util.variable.fixed_point.bits"), max=S("util.variable.fixed_point.max"), min=S("util.variable.fixed_point.min")) nodes_modified_count = common.RerouteTensor( new_layer_tensor, match.output_tensor) if nodes_modified_count == 0: raise ValueError( 'Folding batch norms failed, %s had no outputs.' % match['output_tensor'].name)
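# --- Note: the folding identity used above can be checked directly. With
# new_gamma = gamma / sqrt(variance + eps) and new_beta = beta - mean * new_gamma,
# normalizing with (mean=0, variance=1, eps=0) and the folded parameters must
# reproduce the original batch normalization. A NumPy sanity check
# (illustrative values only):
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 8))
mean, var = x.mean(axis=0), x.var(axis=0)
gamma, beta, eps = rng.normal(size=8), rng.normal(size=8), 1e-3

original = (x - mean) / np.sqrt(var + eps) * gamma + beta
new_gamma = gamma / np.sqrt(var + eps)
new_beta = beta - mean * new_gamma
folded = x * new_gamma + new_beta   # the (mean=0, var=1, eps=0) form

assert np.allclose(original, folded)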
def call(self, inputs, training=False): if self.virtual_batch_size is not None: # Virtual batches (aka ghost batches) can be simulated by reshaping the # Tensor and reusing the existing batch norm implementation original_shape = [-1] + inputs.shape.as_list()[1:] expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:] # Will cause errors if virtual_batch_size does not divide the batch size inputs = array_ops.reshape(inputs, expanded_shape) def undo_virtual_batching(outputs): outputs = array_ops.reshape(outputs, original_shape) return outputs if self.fused: outputs = self._fused_batch_norm(inputs, training=training) if self.virtual_batch_size is not None: # Currently never reaches here since fused_batch_norm does not support # virtual batching return undo_virtual_batching(outputs) return outputs # Compute the axes along which to reduce the mean / variance input_shape = inputs.get_shape() ndims = len(input_shape) reduction_axes = [i for i in range(ndims) if i not in self.axis] if self.virtual_batch_size is not None: del reduction_axes[1] # Do not reduce along virtual batch dim scale, offset = self.gamma, self.beta # Determine a boolean value for `training`: could be True, False, or None. training_value = utils.constant_value(training) if training_value is not False: # Some of the computations here are not necessary when training==False # but not a constant. However, this makes the code simpler. keep_dims = self.virtual_batch_size is not None or len( self.axis) > 1 mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims) moving_mean = self.moving_mean moving_variance = self.moving_variance mean = utils.smart_cond(training, lambda: mean, lambda: moving_mean) variance = utils.smart_cond(training, lambda: variance, lambda: moving_variance) if self.renorm: r, d, new_mean, new_variance = self._renorm_correction_and_moments( mean, variance, training) # When training, the normalized values (say, x) will be transformed as # x * gamma + beta without renorm, and (x * r + d) * gamma + beta # = x * (r * gamma) + (d * gamma + beta) with renorm. scale = array_ops.stop_gradient(r, name='renorm_r') offset = array_ops.stop_gradient(d, name='renorm_d') if self.gamma is not None: scale *= self.gamma offset *= self.gamma if self.beta is not None: offset += self.beta else: new_mean, new_variance = mean, variance # Update moving averages when training, and prevent updates otherwise. decay = utils.smart_cond(training, lambda: self.momentum, lambda: 1.) if self.virtual_batch_size is not None: # This isn't strictly correct since in ghost batch norm, you are # supposed to sequentially update the moving_mean and moving_variance # with each sub-batch. However, since the moving statistics are only # used during evaluation, it is more efficient to just update in one # step and should not make a significant difference in the result. 
new_mean = math_ops.reduce_mean(new_mean, axis=1, keep_dims=True) new_variance = math_ops.reduce_mean(new_variance, axis=1, keep_dims=True) mean_update = moving_averages.assign_moving_average( self.moving_mean, new_mean, decay, zero_debias=False) variance_update = moving_averages.assign_moving_average( self.moving_variance, new_variance, decay, zero_debias=False) if context.in_graph_mode(): self.add_update(mean_update, inputs=inputs) self.add_update(variance_update, inputs=inputs) else: mean, variance = self.moving_mean, self.moving_variance # Broadcasting only necessary for single-axis batch norm where the axis is # not the last dimension broadcast_shape = [1] * ndims broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value rank = len(inputs.get_shape()) def _broadcast(v): if (v is not None and len(v.get_shape()) != rank and reduction_axes != list(range(ndims))[:-1]): return array_ops.reshape(v, broadcast_shape) return v outputs = nn.batch_normalization(inputs, _broadcast(mean), _broadcast(variance), _broadcast(offset), _broadcast(scale), self.epsilon) if self.virtual_batch_size is not None: return undo_virtual_batching(outputs) return outputs
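# --- Note: the _broadcast helper above exists because per-channel parameters
# of shape (C,) do not line up with channel-first inputs. A small NumPy
# illustration (shapes are illustrative):
import numpy as np

x = np.zeros((4, 16, 8, 8))                 # NCHW, so self.axis == [1]
gamma = np.ones(16)
broadcast_shape = [1, 16, 1, 1]             # [1] * ndims with the axis filled in
y = x * gamma.reshape(broadcast_shape)      # aligns gamma with the channel axis
# For channels-last layouts the trailing-axis broadcasting rules already
# apply, which is why _broadcast is a no-op when the axis is the last one.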
def call(self, inputs, params=None, training=None): if params[self.name + '/gamma:0'] is None: return super(layers.BatchNormalization, self).call(inputs) else: gamma = params.get(self.name + '/gamma:0') beta = params.get(self.name + '/beta:0') original_training_value = training if training is None: training = backend.learning_phase() # Compute the axes along which to reduce the mean / variance input_shape = inputs.get_shape() ndims = len(input_shape) reduction_axes = [i for i in range(ndims) if i not in self.axis] # Broadcasting only necessary for single-axis batch norm where the axis is # not the last dimension broadcast_shape = [1] * ndims broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value def _broadcast(v): if (v is not None and len(v.get_shape()) != ndims and reduction_axes != list(range(ndims - 1))): return array_ops.reshape(v, broadcast_shape) return v scale, offset = _broadcast(gamma), _broadcast(beta) def _compose_transforms(scale, offset, then_scale, then_offset): if then_scale is not None: scale *= then_scale offset *= then_scale if then_offset is not None: offset += then_offset return (scale, offset) # Determine a boolean value for `training`: could be True, False, or None. training_value = tf_utils.constant_value(training) if training_value is not False: if self.adjustment: adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs)) # Adjust only during training. adj_scale = tf_utils.smart_cond( training, lambda: adj_scale, lambda: array_ops.ones_like(adj_scale)) adj_bias = tf_utils.smart_cond( training, lambda: adj_bias, lambda: array_ops.zeros_like(adj_bias)) scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset) # Some of the computations here are not necessary when training==False # but not a constant. However, this makes the code simpler. keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1 mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims) moving_mean = self.moving_mean moving_variance = self.moving_variance mean = tf_utils.smart_cond(training, lambda: mean, lambda: moving_mean) variance = tf_utils.smart_cond(training, lambda: variance, lambda: moving_variance) if self.virtual_batch_size is not None: # This isn't strictly correct since in ghost batch norm, you are # supposed to sequentially update the moving_mean and moving_variance # with each sub-batch. However, since the moving statistics are only # used during evaluation, it is more efficient to just update in one # step and should not make a significant difference in the result. new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True) new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True) else: new_mean, new_variance = mean, variance if self.renorm: r, d, new_mean, new_variance = self._renorm_correction_and_moments( new_mean, new_variance, training) # When training, the normalized values (say, x) will be transformed as # x * gamma + beta without renorm, and (x * r + d) * gamma + beta # = x * (r * gamma) + (d * gamma + beta) with renorm. 
r = _broadcast(array_ops.stop_gradient(r, name='renorm_r')) d = _broadcast(array_ops.stop_gradient(d, name='renorm_d')) scale, offset = _compose_transforms(r, d, scale, offset) def _do_update(var, value): return self._assign_moving_average(var, value, self.momentum) mean_update = tf_utils.smart_cond( training, lambda: _do_update(self.moving_mean, new_mean), lambda: self.moving_mean) variance_update = tf_utils.smart_cond( training, lambda: _do_update(self.moving_variance, new_variance), lambda: self.moving_variance) self.add_update(mean_update, inputs=True) self.add_update(variance_update, inputs=True) else: mean, variance = self.moving_mean, self.moving_variance mean = math_ops.cast(mean, inputs.dtype) variance = math_ops.cast(variance, inputs.dtype) if offset is not None: offset = math_ops.cast(offset, inputs.dtype) outputs = nn.batch_normalization(inputs, _broadcast(mean), _broadcast(variance), offset, scale, self.epsilon) # If some components of the shape got lost due to adjustments, fix that. outputs.set_shape(input_shape) if self.virtual_batch_size is not None: outputs = undo_virtual_batching(outputs) if original_training_value is None: outputs._uses_learning_phase = True # pylint: disable=protected-access return outputs
def _subdiv_batch_norm(self, inputs, training=None): training = self._get_training_value(training) inputs_dtype = inputs.dtype.base_dtype if inputs_dtype in (dtypes.float16, dtypes.bfloat16): # Do all math in float32 if given 16-bit inputs for numeric stability. # In particular, it's very easy for variance to overflow in float16 and # for safety we also choose to cast bfloat16 to float32. inputs = math_ops.cast(inputs, dtypes.float32) params_dtype = self._param_dtype # Compute the axes along which to reduce the mean / variance input_shape = inputs.shape ndims = len(input_shape) reduction_axes = [i for i in range(ndims) if i not in self.axis] if self.virtual_batch_size is not None: del reduction_axes[1] # Do not reduce along virtual batch dim # Broadcasting only necessary for single-axis batch norm where the axis is # not the last dimension broadcast_shape = [1] * ndims broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value def _broadcast(v): if (v is not None and len(v.shape) != ndims and reduction_axes != list(range(ndims - 1))): return array_ops.reshape(v, broadcast_shape) return v scale, offset = _broadcast(self.gamma), _broadcast(self.beta) # Compose an additional affine transform (then_scale, then_offset) on top # of the existing (scale, offset). def _compose_transforms(scale, offset, then_scale, then_offset): if then_scale is not None: scale *= then_scale offset *= then_scale if then_offset is not None: offset += then_offset return (scale, offset) # Determine whether `training` is True, False, or None. training_value = control_flow_util.constant_value(training) update_value = (self.local_count + 1) % self.subdivisions == 0 if training_value == False: # pylint: disable=singleton-comparison,g-explicit-bool-comparison mean, variance = self.moving_mean, self.moving_variance else: # training_value could be True or None -> None means determine at runtime if self.adjustment: adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs)) # Adjust only during training.
adj_scale = control_flow_util.smart_cond( training, lambda: adj_scale, lambda: array_ops.ones_like(adj_scale)) adj_bias = control_flow_util.smart_cond( training, lambda: adj_bias, lambda: array_ops.zeros_like(adj_bias)) scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset) keep_dims = self.virtual_batch_size is not None or len( self.axis) > 1 # Normalization stats for the current sub-batch; the important ones are the # mean and the squared mean. mean, net_sum, variance, squared_mean, input_batch_size = self.subdiv_moments( math_ops.cast(inputs, self._param_dtype), reduction_axes, keep_dims=keep_dims) # Aggregate the sub-batch sums into the rotating accumulators. def _update_aggregate_sum(): return self._assign_subdiv_rotating_sum( self.aggregated_sum_batch, net_sum, self.subdivisions, self.local_count, input_batch_size) def _update_aggregate_squared_sum(): return self._assign_subdiv_rotating_sum( self.aggregated_square_sum_batch, squared_mean, self.subdivisions, self.local_count, input_batch_size) def _update_aggregate_batch_size(): return self._assign_subdiv_rotating_sum( self.aggregated_batch_size, input_batch_size, self.subdivisions, self.local_count, input_batch_size) self.add_update(_update_aggregate_sum) self.add_update(_update_aggregate_squared_sum) self.add_update(_update_aggregate_batch_size) aggregated_mean = self.aggregated_sum_batch / math_ops.cast( self.aggregated_batch_size, params_dtype) aggregated_squared_mean = self.aggregated_square_sum_batch / math_ops.cast( self.aggregated_batch_size, params_dtype) aggregated_variance = aggregated_squared_mean - math_ops.square( aggregated_mean) moving_mean = self.moving_mean moving_variance = self.moving_variance # If training, use this sub-batch's stats to normalize; otherwise use the # moving averages. mean = control_flow_util.smart_cond( training, true_fn=lambda: mean, false_fn=lambda: ops.convert_to_tensor_v2_with_dispatch( moving_mean)) variance = control_flow_util.smart_cond( training, true_fn=lambda: variance, false_fn=lambda: ops.convert_to_tensor_v2_with_dispatch( moving_variance)) # Circular update of the mean and variance; should only take effect when # the moving statistics are due to be updated. new_mean = control_flow_util.smart_cond( update_value, true_fn=lambda: ops.convert_to_tensor_v2_with_dispatch( aggregated_mean), false_fn=lambda: moving_mean) new_variance = control_flow_util.smart_cond( update_value, true_fn=lambda: ops.convert_to_tensor_v2_with_dispatch( aggregated_variance), false_fn=lambda: moving_variance) if self.renorm: r, d, new_mean, new_variance = self._renorm_correction_and_moments( new_mean, new_variance, training, input_batch_size) # When training, the normalized values (say, x) will be transformed as # x * gamma + beta without renorm, and (x * r + d) * gamma + beta # = x * (r * gamma) + (d * gamma + beta) with renorm.
r = _broadcast(array_ops.stop_gradient(r, name='renorm_r')) d = _broadcast(array_ops.stop_gradient(d, name='renorm_d')) scale, offset = _compose_transforms(r, d, scale, offset) def _do_update(var, value): """Compute the updates for mean and variance.""" return self._assign_moving_average(var, value, self.momentum, self.aggregated_batch_size) def mean_update(): true_branch = lambda: _do_update(self.moving_mean, new_mean) false_branch = lambda: self.moving_mean return control_flow_util.smart_cond(training, true_branch, false_branch) def variance_update(): """Update the moving variance.""" def true_branch_renorm(): # We apply epsilon as part of the moving_stddev to mirror the training # code path. moving_stddev = _do_update( self.moving_stddev, math_ops.sqrt(new_variance + self.epsilon)) return self._assign_new_value( self.moving_variance, # Apply relu in case floating point rounding causes it to go # negative. K.relu(moving_stddev * moving_stddev - self.epsilon)) if self.renorm: true_branch = true_branch_renorm else: true_branch = lambda: _do_update(self.moving_variance, new_variance) false_branch = lambda: self.moving_variance return control_flow_util.smart_cond(training, true_branch, false_branch) def update_count(): with K.name_scope('update_count') as scope: # update the local count return state_ops.assign_add(self.local_count, tf.cast( 1, self.local_count.dtype), name=scope) self.add_update(mean_update) self.add_update(variance_update) self.add_update(update_count) mean = math_ops.cast(mean, inputs.dtype) variance = math_ops.cast(variance, inputs.dtype) if offset is not None: offset = math_ops.cast(offset, inputs.dtype) if scale is not None: scale = math_ops.cast(scale, inputs.dtype) outputs = nn.batch_normalization(inputs, _broadcast(mean), _broadcast(variance), offset, scale, self.epsilon) if inputs_dtype in (dtypes.float16, dtypes.bfloat16): outputs = math_ops.cast(outputs, inputs_dtype) # If some components of the shape got lost due to adjustments, fix that. outputs.set_shape(input_shape) if self.virtual_batch_size is not None: outputs = undo_virtual_batching(outputs) return outputs
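# --- Note: the subdivision variant above accumulates sums and squared sums
# (not means) so that sub-batches of different sizes combine exactly. A
# hedged NumPy sketch of the aggregation it performs (synthetic sub-batches):
import numpy as np

sub_batches = [np.random.randn(8, 4), np.random.randn(8, 4)]
total_sum = sum(b.sum(axis=0) for b in sub_batches)
total_sq_sum = sum((b ** 2).sum(axis=0) for b in sub_batches)
count = sum(b.shape[0] for b in sub_batches)

aggregated_mean = total_sum / count
aggregated_variance = total_sq_sum / count - aggregated_mean ** 2

# Matches the moments of the concatenated batch:
full = np.concatenate(sub_batches)
assert np.allclose(aggregated_mean, full.mean(axis=0))
assert np.allclose(aggregated_variance, full.var(axis=0))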
def fused_layer_norm(inputs, center=True, scale=True, activation_fn=None, reuse=None, variables_collections=None, outputs_collections=None, trainable=True, begin_norm_axis=1, begin_params_axis=-1, scope=None, use_fused_batch_norm=False): with tf.compat.v1.variable_scope(scope, 'LayerNorm', [inputs], reuse=reuse) as sc: inputs = ops.convert_to_tensor(inputs) inputs_shape = inputs.shape inputs_rank = inputs_shape.ndims if inputs_rank is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) dtype = inputs.dtype.base_dtype if begin_norm_axis < 0: begin_norm_axis = inputs_rank + begin_norm_axis if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank: raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) ' 'must be < rank(inputs) (%d)' % (begin_params_axis, begin_norm_axis, inputs_rank)) params_shape = inputs_shape[begin_params_axis:] if not params_shape.is_fully_defined(): raise ValueError( 'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' % (inputs.name, begin_params_axis, inputs_shape)) # Allocate parameters for the beta and gamma of the normalization. beta, gamma = None, None if center: beta_collections = utils.get_variable_collections( variables_collections, 'beta') beta = variables.model_variable( 'beta', shape=params_shape, dtype=dtype, initializer=init_ops.zeros_initializer(), collections=beta_collections, trainable=trainable) if scale: gamma_collections = utils.get_variable_collections( variables_collections, 'gamma') gamma = variables.model_variable( 'gamma', shape=params_shape, dtype=dtype, initializer=init_ops.ones_initializer(), collections=gamma_collections, trainable=trainable) if use_fused_batch_norm: # get static TensorShape if fully defined, # otherwise retrieve shape tensor norm_shape = inputs.shape[begin_norm_axis:] if norm_shape.is_fully_defined(): bn_shape = [1, -1, 1, numpy.prod(norm_shape.as_list())] else: norm_shape = tf.shape(input=inputs)[begin_norm_axis:] bn_shape = [1, -1, 1, tf.reduce_prod(input_tensor=norm_shape)] if inputs.get_shape().is_fully_defined(): outputs_shape = inputs.get_shape() else: outputs_shape = tf.shape(input=inputs) inputs = array_ops.reshape(inputs, bn_shape) if inputs.get_shape().is_fully_defined(): # static inputs TensorShape fully defined after reshape. ones = array_ops.ones(inputs.get_shape()[1], dtype=dtypes.float32) zeros = array_ops.zeros(inputs.get_shape()[1], dtype=dtypes.float32) else: # static inputs TensorShape NOT fully defined after reshape. # must use dynamic shape, which means these input tensors # have to be created at runtime, which causes a slowdown. scale_shape = tf.shape(input=inputs)[1] ones = array_ops.ones(scale_shape, dtype=dtypes.float32) zeros = array_ops.zeros(scale_shape, dtype=dtypes.float32) outputs, mean, variance = nn.fused_batch_norm(inputs, ones, zeros, epsilon=1e-4, data_format="NCHW") outputs = array_ops.reshape(outputs, outputs_shape) if center and scale: outputs = outputs * gamma + beta elif center: outputs = outputs + beta elif scale: outputs = outputs * gamma else: # Calculate the moments on the last axis (layer activations). norm_axes = list(range(begin_norm_axis, inputs_rank)) mean, variance = nn.moments(inputs, norm_axes, keep_dims=True) # Compute layer normalization using the batch_normalization function. 
variance_epsilon = 1e-4 outputs = nn.batch_normalization(inputs, mean, variance, offset=beta, scale=gamma, variance_epsilon=variance_epsilon) outputs.set_shape(inputs_shape) if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
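# --- Note: the fused path above reshapes the normalized axes into
# [1, -1, 1, prod(norm_shape)] so that per-"channel" batch-norm statistics
# (with scale ones and offset zeros) coincide with per-example layer-norm
# moments. A NumPy sketch of that equivalence (shapes illustrative):
import numpy as np

x = np.random.randn(6, 10).astype(np.float32)     # begin_norm_axis = 1

# Plain layer norm over the last axis.
mean = x.mean(axis=1, keepdims=True)
var = x.var(axis=1, keepdims=True)
ln = (x - mean) / np.sqrt(var + 1e-4)

# Reshape trick: each example becomes one "channel" of an NCHW tensor, so
# per-channel moments are exactly the per-example moments.
xr = x.reshape(1, x.shape[0], 1, x.shape[1])
m = xr.mean(axis=(0, 2, 3), keepdims=True)
v = xr.var(axis=(0, 2, 3), keepdims=True)
bn = ((xr - m) / np.sqrt(v + 1e-4)).reshape(x.shape)

assert np.allclose(ln, bn, atol=1e-5)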
def call(self, inputs, training=False): # First, compute the axes along which to reduce the mean / variance, # as well as the broadcast shape to be used for all parameters. input_shape = inputs.get_shape() ndim = len(input_shape) reduction_axes = list(range(len(input_shape))) del reduction_axes[self.axis] broadcast_shape = [1] * len(input_shape) broadcast_shape[self.axis] = input_shape[self.axis].value # Determines whether broadcasting is needed. needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1]) scale, offset = self.gamma, self.beta # Determine a boolean value for `training`: could be True, False, or None. training_value = utils.constant_value(training) if training_value is not False: # Some of the computations here are not necessary when training==False # but not a constant. However, this makes the code simpler. mean, variance = nn.moments(inputs, reduction_axes) mean = _smart_select(training, lambda: mean, lambda: self.moving_mean) variance = _smart_select(training, lambda: variance, lambda: self.moving_variance) if self.renorm: r, d, new_mean, new_variance = self._renorm_correction_and_moments( mean, variance, training) # When training, the normalized values (say, x) will be transformed as # x * gamma + beta without renorm, and (x * r + d) * gamma + beta # = x * (r * gamma) + (d * gamma + beta) with renorm. scale = array_ops.stop_gradient(r, name='renorm_r') offset = array_ops.stop_gradient(d, name='renorm_d') if self.gamma is not None: scale *= self.gamma offset *= self.gamma if self.beta is not None: offset += self.beta else: new_mean, new_variance = mean, variance # Update moving averages when training, and prevent updates otherwise. decay = _smart_select(training, lambda: self.momentum, lambda: 1.) mean_update = moving_averages.assign_moving_average( self.moving_mean, new_mean, decay, zero_debias=False) variance_update = moving_averages.assign_moving_average( self.moving_variance, new_variance, decay, zero_debias=False) if not self.updates: self.add_update(mean_update) self.add_update(variance_update) else: mean, variance = self.moving_mean, self.moving_variance def _broadcast(v): if needs_broadcasting and v is not None: # In this case we must explicitly broadcast all parameters. return array_ops.reshape(v, broadcast_shape) return v return nn.batch_normalization(inputs, _broadcast(mean), _broadcast(variance), _broadcast(offset), _broadcast(scale), self.epsilon)
def layer_norm(inputs, center=True, scale=True, activation_fn=None, reuse=None, trainable=True, begin_norm_axis=1, begin_params_axis=-1, scope=None): """Adds a Layer Normalization layer. Based on the paper: "Layer Normalization" Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton https://arxiv.org/abs/1607.06450. Can be used as a normalizer function for conv2d and fully_connected. Given a tensor `inputs` of rank `R`, moments are calculated and normalization is performed over axes `begin_norm_axis ... R - 1`. Scaling and centering, if requested, are performed over axes `begin_params_axis .. R - 1`. By default, `begin_norm_axis = 1` and `begin_params_axis = -1`, meaning that normalization is performed over all but the first axis (the `HWC` if `inputs` is `NHWC`), while the `beta` and `gamma` trainable parameters are calculated for the rightmost axis (the `C` if `inputs` is `NHWC`). Scaling and recentering are performed via broadcast of the `beta` and `gamma` parameters with the normalized tensor. The shapes of `beta` and `gamma` are `inputs.shape[begin_params_axis:]`, and this part of the inputs' shape must be fully defined. Args: inputs: A tensor having rank `R`. The normalization is performed over axes `begin_norm_axis ... R - 1` and centering and scaling parameters are calculated over `begin_params_axis ... R - 1`. center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. activation_fn: Activation function, default set to None to skip it and maintain a linear activation. reuse: Whether or not the layer and its variables should be reused. To be able to reuse the layer, scope must be given. trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). begin_norm_axis: The first normalization dimension: normalization will be performed along dimensions `begin_norm_axis : rank(inputs)` begin_params_axis: The first parameter (beta, gamma) dimension: scale and centering parameters will have dimensions `begin_params_axis : rank(inputs)` and will be broadcast with the normalized inputs accordingly. scope: Optional scope for `variable_scope`. Returns: A `Tensor` representing the output of the operation, having the same shape and dtype as `inputs`. Raises: ValueError: If the rank of `inputs` is not known at graph build time, or if `inputs.shape[begin_params_axis:]` is not fully defined at graph build time. """ with variable_scope.variable_scope(scope, 'LayerNorm', [inputs], reuse=reuse) as sc: inputs = ops.convert_to_tensor(inputs) inputs_shape = inputs.shape inputs_rank = inputs_shape.ndims if inputs_rank is None: raise ValueError('Inputs %s has undefined rank.'
% inputs.name) dtype = inputs.dtype.base_dtype if begin_norm_axis < 0: begin_norm_axis = inputs_rank + begin_norm_axis if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank: raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) ' 'must be < rank(inputs) (%d)' % (begin_params_axis, begin_norm_axis, inputs_rank)) params_shape = inputs_shape[begin_params_axis:] if not params_shape.is_fully_defined(): raise ValueError( 'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' % (inputs.name, begin_params_axis, inputs_shape)) # Allocate parameters for the beta and gamma of the normalization; gamma # starts at one (a zero initializer would null the normalized output). beta, gamma = None, None if center: beta = tf.get_variable(name='beta', shape=params_shape, dtype=dtype, initializer=tf.zeros_initializer(), trainable=trainable) if scale: gamma = tf.get_variable(name='gamma', shape=params_shape, dtype=dtype, initializer=tf.ones_initializer(), trainable=trainable) # Calculate the moments on the last axis (layer activations). norm_axes = list(range(begin_norm_axis, inputs_rank)) mean, variance = nn.moments(inputs, norm_axes, keep_dims=True) # Compute layer normalization using the batch_normalization function. variance_epsilon = 1e-12 outputs = nn.batch_normalization(inputs, mean, variance, offset=beta, scale=gamma, variance_epsilon=variance_epsilon) outputs.set_shape(inputs_shape) if activation_fn is not None: outputs = activation_fn(outputs) return collect_named_outputs(None, sc.name, outputs)
def call(self, inputs, training=None): if self.scale and self.gamma_quantizer: quantized_gamma = self.gamma_quantizer_internal(self.gamma) else: quantized_gamma = self.gamma if self.center and self.beta_quantizer: quantized_beta = self.beta_quantizer_internal(self.beta) else: quantized_beta = self.beta if self.mean_quantizer: quantized_moving_mean = self.mean_quantizer_internal( self.moving_mean) else: quantized_moving_mean = self.moving_mean if self.variance_quantizer: quantized_moving_variance = self.variance_quantizer_internal( self.moving_variance) else: quantized_moving_variance = self.moving_variance training = self._get_training_value(training) # Compute the axes along which to reduce the mean / variance input_shape = inputs.shape ndims = len(input_shape) reduction_axes = [i for i in range(ndims) if i not in self.axis] # Broadcasting only necessary for single-axis batch norm where the axis is # not the last dimension broadcast_shape = [1] * ndims broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value def _broadcast(v): if (v is not None and len(v.shape) != ndims and reduction_axes != list(range(ndims - 1))): return array_ops.reshape(v, broadcast_shape) return v scale, offset = _broadcast(quantized_gamma), _broadcast(quantized_beta) # Determine a boolean value for `training`: could be True, False, or None. training_value = tf_utils.smart_constant_value(training) if training_value == False: # pylint: disable=singleton-comparison,g-explicit-bool-comparison quantized_mean, quantized_variance = (quantized_moving_mean, quantized_moving_variance) else: # Some of the computations here are not necessary when training==False # but not a constant. However, this makes the code simpler. keep_dims = len(self.axis) > 1 mean, variance = self._moments(math_ops.cast( inputs, self._param_dtype), reduction_axes, keep_dims=keep_dims) moving_mean = self.moving_mean moving_variance = self.moving_variance mean = tf_utils.smart_cond( training, lambda: mean, lambda: ops.convert_to_tensor(moving_mean)) variance = tf_utils.smart_cond( training, lambda: variance, lambda: ops.convert_to_tensor(moving_variance)) new_mean, new_variance = mean, variance if self.mean_quantizer: quantized_mean = self.mean_quantizer_internal(mean) else: quantized_mean = mean if self.variance_quantizer: quantized_variance = self.variance_quantizer_internal(variance) else: quantized_variance = variance if self._support_zero_size_input(): inputs_size = array_ops.size(inputs) else: inputs_size = None def _do_update(var, value): """Compute the updates for mean and variance.""" return self._assign_moving_average(var, value, self.momentum, inputs_size) def mean_update(): true_branch = lambda: _do_update(self.moving_mean, new_mean) false_branch = lambda: self.moving_mean return tf_utils.smart_cond(training, true_branch, false_branch) def variance_update(): """Update the moving variance.""" true_branch = lambda: _do_update(self.moving_variance, new_variance) false_branch = lambda: self.moving_variance return tf_utils.smart_cond(training, true_branch, false_branch) self.add_update(mean_update) self.add_update(variance_update) quantized_mean = math_ops.cast(quantized_mean, inputs.dtype) quantized_variance = math_ops.cast(quantized_variance, inputs.dtype) if offset is not None: offset = math_ops.cast(offset, inputs.dtype) if scale is not None: scale = math_ops.cast(scale, inputs.dtype) # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing # math in float16 hurts validation accuracy of popular models like resnet. 
outputs = nn.batch_normalization(inputs, _broadcast(quantized_mean), _broadcast(quantized_variance), offset, scale, self.epsilon) # If some components of the shape got lost due to adjustments, fix that. outputs.set_shape(input_shape) return outputs
def call(self, inputs, training=False): # First, compute the axes along which to reduce the mean / variance, # as well as the broadcast shape to be used for all parameters. input_shape = inputs.get_shape() ndim = len(input_shape) reduction_axes = list(range(len(input_shape))) del reduction_axes[self.axis] broadcast_shape = [1] * len(input_shape) broadcast_shape[self.axis] = input_shape[self.axis].value # Determines whether broadcasting is needed. needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1]) # Determine the boolean training value. May be False, True, or None. # If None, it is assumed that `training` is a variable to be used in `cond`. if isinstance(training, bool): training_bool = training else: try: training_bool = tensor_util.constant_value(training) except TypeError: training_bool = None # Obtain the current batch mean and variance, if necessary. if training_bool is not False: # Use a copy of moving_mean as a shift to compute more reliable moments. shift = math_ops.add(self.moving_mean, 0) if needs_broadcasting: shift = array_ops.reshape(shift, broadcast_shape) broadcast_mean, broadcast_variance = nn.moments( inputs, reduction_axes, shift=shift, keep_dims=True) mean = array_ops.reshape(broadcast_mean, [-1]) variance = array_ops.reshape(broadcast_variance, [-1]) else: mean, variance = nn.moments(inputs, reduction_axes, shift=shift) # Prepare updates if necessary. if training_bool is not False and not self.updates: mean_update = moving_averages.assign_moving_average( self.moving_mean, mean, self.momentum, zero_debias=False) variance_update = moving_averages.assign_moving_average( self.moving_variance, variance, self.momentum, zero_debias=False) # In the future this should be refactored into a self.add_update # method in order to allow for instance-based BN layer sharing # across unrelated input streams (e.g. like in Keras). self.updates.append(mean_update) self.updates.append(variance_update) # Normalize batch. if needs_broadcasting: # In this case we must explicitly broadcast all parameters. broadcast_moving_mean = array_ops.reshape(self.moving_mean, broadcast_shape) broadcast_moving_variance = array_ops.reshape(self.moving_variance, broadcast_shape) if self.center: broadcast_beta = array_ops.reshape(self.beta, broadcast_shape) else: broadcast_beta = None if self.scale: broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape) else: broadcast_gamma = None if training_bool is not False: normed_inputs_training = nn.batch_normalization(inputs, broadcast_mean, broadcast_variance, broadcast_beta, broadcast_gamma, self.epsilon) normed_inputs = nn.batch_normalization(inputs, broadcast_moving_mean, broadcast_moving_variance, broadcast_beta, broadcast_gamma, self.epsilon) else: # No need for broadcasting. if training_bool is not False: normed_inputs_training = nn.batch_normalization( inputs, mean, variance, self.beta if self.center else None, self.gamma if self.scale else None, self.epsilon) normed_inputs = nn.batch_normalization(inputs, self.moving_mean, self.moving_variance, self.beta if self.center else None, self.gamma if self.scale else None, self.epsilon) # Return the proper output depending on the boolean training phase. if training_bool is True: return normed_inputs_training if training_bool is False: return normed_inputs return control_flow_ops.cond(training, lambda: normed_inputs_training, lambda: normed_inputs)
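# --- Note: several snippets above pass shift=moving_mean to nn.moments. The
# shift makes the one-pass variance formula E[(x-s)^2] - E[x-s]^2 numerically
# safer when the mean is large relative to the variance. A hedged NumPy
# sketch of the idea (function name and values are illustrative):
import numpy as np

def moments_with_shift(x, shift):
    centered = x - shift
    m = centered.mean()
    var = (centered ** 2).mean() - m ** 2   # one-pass formula on shifted data
    return m + shift, var

x = np.random.randn(1024) * 1e-3 + 1e4      # huge mean, tiny variance
print(moments_with_shift(x, shift=1e4))     # close to (1e4, 1e-6)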