def test_constant(self):
  for v in [True, False, 1, 0, 1.0]:
    c = constant_op.constant(v)
    value = utils.constant_value(c)
    self.assertEqual(value, v)
    with self.test_session():
      self.assertEqual(c.eval(), v)
def _build_update_ops_variance(self, mean, variance, is_training):

  def build_update_ops():
    update_mean_op = moving_averages.assign_moving_average(
        variable=self._moving_mean,
        value=mean,
        decay=self._decay_rate,
        name="update_moving_mean").op

    update_variance_op = moving_averages.assign_moving_average(
        variable=self._moving_variance,
        value=variance,
        decay=self._decay_rate,
        name="update_moving_variance").op

    return update_mean_op, update_variance_op

  def build_no_ops():
    return (tf.no_op(), tf.no_op())

  # Only make the ops if we know that `is_training=True`, or the value of
  # `is_training` is unknown.
  is_training_const = utils.constant_value(is_training)
  if is_training_const is None or is_training_const:
    update_mean_op, update_variance_op = utils.smart_cond(
        is_training,
        build_update_ops,
        build_no_ops,
    )

    # Every new connection creates a new op which adds its contribution
    # to the running average when run.
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean_op)
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_variance_op)
def test_constant(self):
  for v in [True, False, 1, 0, 1.0]:
    c = tf.constant(v)
    value = utils.constant_value(c)
    self.assertEqual(value, v)
    with self.test_session():
      self.assertEqual(c.eval(), v)
def _build_update_ops_second_moment(self, mean, second_moment, is_training):

  def build_update_ops():
    update_mean_op = moving_averages.assign_moving_average(
        variable=self._moving_mean,
        value=mean,
        decay=self._decay_rate,
        name="update_moving_mean").op

    update_second_moment_op = moving_averages.assign_moving_average(
        variable=self._moving_second_moment,
        value=second_moment,
        decay=self._decay_rate,
        name="update_moving_second_moment").op

    return update_mean_op, update_second_moment_op

  def build_no_ops():
    return (tf.no_op(), tf.no_op())

  is_training_const = utils.constant_value(is_training)
  if is_training_const is None or is_training_const:
    update_mean_op, update_second_moment_op = utils.smart_cond(
        is_training,
        build_update_ops,
        build_no_ops,
    )

    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean_op)
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_second_moment_op)
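# Hedged usage sketch (the names `optimizer` and `loss` are assumptions, not
# part of the snippets above): the moving-average ops that the two builders
# above add to tf.GraphKeys.UPDATE_OPS only run if they are made a dependency
# of the training step, e.g. by grouping them under the train op.
import tensorflow as tf


def make_train_op(optimizer, loss):
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    # The moving mean/variance (or second moment) updates now run on every
    # training step.
    return optimizer.minimize(loss)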
def test_placeholder(self):
  for v in [True, False, 1, 0, 1.0]:
    p = tf.placeholder(np.dtype(type(v)), [])
    x = tf.identity(p)
    value = utils.constant_value(p)
    self.assertEqual(value, None)
    with self.test_session():
      self.assertEqual(x.eval(feed_dict={p: v}), v)
def test_placeholder(self):
  for v in [True, False, 1, 0, 1.0]:
    p = array_ops.placeholder(np.dtype(type(v)), [])
    x = array_ops.identity(p)
    value = utils.constant_value(p)
    self.assertEqual(value, None)
    with self.test_session():
      self.assertEqual(x.eval(feed_dict={p: v}), v)
def test_variable(self):
  for v in [True, False, 1, 0, 1.0]:
    with tf.Graph().as_default() as g, self.test_session(g) as sess:
      x = tf.Variable(v)
      value = utils.constant_value(x)
      self.assertEqual(value, None)
      sess.run(tf.global_variables_initializer())
      self.assertEqual(x.eval(), v)
def test_variable(self):
  for v in [True, False, 1, 0, 1.0]:
    with ops.Graph().as_default() as g, self.test_session(g) as sess:
      x = variables.Variable(v)
      value = utils.constant_value(x)
      self.assertEqual(value, None)
      sess.run(variables.global_variables_initializer())
      self.assertEqual(x.eval(), v)
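# A minimal sketch, assuming TF 1.x, of the behaviour the tests above expect
# from `utils.constant_value`: return the statically known Python value of a
# constant, and None for placeholders, variables, or anything only known at
# run time. This is not the library implementation.
import tensorflow as tf
from tensorflow.python.framework import tensor_util


def constant_value_sketch(value):
  """Returns the graph-construction-time value of `value`, or None."""
  if isinstance(value, (bool, int, float)):
    return value                      # plain Python constants pass through
  if isinstance(value, tf.Variable):
    return None                       # variables are only known at run time
  if isinstance(value, tf.Tensor):
    # Constant tensors fold to their value; placeholders and ops fed by them
    # yield None.
    return tensor_util.constant_value(value)
  return None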
def _build_update_ops(self, mean, variance, is_training):
  """Builds the moving average update ops when using moving variance.

  Args:
    mean: The mean value to update with.
    variance: The variance value to update with.
    is_training: Boolean Tensor to indicate if we're currently in
      training mode.

  Returns:
    Tuple of `(update_mean_op, update_variance_op)` when `is_training` is or
    could be `True`. Returns `None` when `is_training=False`.
  """

  def build_update_ops():
    """Builds the exponential moving average update ops."""

    update_mean_op = moving_averages.assign_moving_average(
        variable=self._moving_mean,
        value=mean,
        decay=self._decay_rate,
        zero_debias=False,
        name="update_moving_mean").op

    update_variance_op = moving_averages.assign_moving_average(
        variable=self._moving_variance,
        value=variance,
        decay=self._decay_rate,
        zero_debias=False,
        name="update_moving_variance").op

    return update_mean_op, update_variance_op

  def build_no_ops():
    return (tf.no_op(), tf.no_op())

  # Only make the ops if we know that `is_training=True`, or the value of
  # `is_training` is unknown.
  is_training_const = utils.constant_value(is_training)
  if is_training_const is None or is_training_const:
    update_mean_op, update_variance_op = utils.smart_cond(
        is_training,
        build_update_ops,
        build_no_ops,
    )
    return (update_mean_op, update_variance_op)
  else:
    return None
def _build_update_ops_second_moment(self, mean, second_moment, is_training):
  """Builds the moving average update ops when using the moving second moment.

  Args:
    mean: The mean value to update with.
    second_moment: The second_moment value to update with.
    is_training: Boolean Tensor to indicate if we're currently in
      training mode.
  """

  def build_update_ops():
    """Builds the exponential moving average update ops."""

    update_mean_op = moving_averages.assign_moving_average(
        variable=self._moving_mean,
        value=mean,
        decay=self._decay_rate,
        name="update_moving_mean").op

    update_second_moment_op = moving_averages.assign_moving_average(
        variable=self._moving_second_moment,
        value=second_moment,
        decay=self._decay_rate,
        name="update_moving_second_moment").op

    return update_mean_op, update_second_moment_op

  def build_no_ops():
    return (tf.no_op(), tf.no_op())

  # Only make the ops if we know that `is_training=True`, or the value of
  # `is_training` is unknown.
  is_training_const = utils.constant_value(is_training)
  if is_training_const is None or is_training_const:
    update_mean_op, update_second_moment_op = utils.smart_cond(
        is_training,
        build_update_ops,
        build_no_ops,
    )

    # Every new connection creates a new op which adds its contribution
    # to the running average when run.
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean_op)
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_second_moment_op)
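# Hedged sketch of the `utils.smart_cond` behaviour the builders above rely
# on: when the predicate is statically known, build only the chosen branch;
# otherwise fall back to a dynamic tf.cond. Not the library implementation.
import tensorflow as tf
from tensorflow.python.framework import tensor_util


def smart_cond_sketch(pred, true_fn, false_fn):
  pred_value = pred if isinstance(pred, bool) else tensor_util.constant_value(pred)
  if pred_value is not None:
    # Statically known: only one branch's ops are ever added to the graph.
    return true_fn() if pred_value else false_fn()
  # Unknown until run time: both branches are built and selected dynamically.
  return tf.cond(pred, true_fn, false_fn)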
def dropout(inputs,
            keep_prob=0.5,
            noise_shape=None,
            is_training=True,
            outputs_collections=None,
            scope=None):
  """Returns a dropout op applied to the input.

  With probability `keep_prob`, outputs the input element scaled up by
  `1 / keep_prob`, otherwise outputs `0`. The scaling is so that the expected
  sum is unchanged.

  Args:
    inputs: the tensor to pass to the nn.dropout op.
    keep_prob: A scalar `Tensor` with the same type as x. The probability
      that each element is kept.
    noise_shape: A 1-D `Tensor` of type `int32`, representing the
      shape for randomly generated keep/drop flags.
    is_training: A bool `Tensor` indicating whether or not the model
      is in training mode. If so, dropout is applied and values scaled.
      Otherwise, inputs is returned.
    outputs_collections: collection to add the outputs.
    scope: Optional scope for op_scope.

  Returns:
    a tensor representing the output of the operation.
  """
  with ops.op_scope([inputs], scope, 'Dropout') as sc:
    inputs = ops.convert_to_tensor(inputs)
    is_training_value = utils.constant_value(is_training, dtypes.bool)
    if is_training_value is not None:
      if is_training_value:
        outputs = nn.dropout(inputs, keep_prob, noise_shape)
      else:
        outputs = inputs
    else:
      def _dropout():
        return nn.dropout(inputs, keep_prob, noise_shape)
      outputs = control_flow_ops.cond(is_training, _dropout, lambda: inputs)
    return utils.collect_named_outputs(outputs_collections, sc, outputs)
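# Hedged usage sketch for the dropout layer above (assumes it is importable
# as `dropout`): with a bool placeholder for `is_training`, the same graph can
# be fed True while training and False for evaluation.
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 128])
is_training = tf.placeholder(tf.bool, [])
y = dropout(x, keep_prob=0.8, is_training=is_training)

with tf.Session() as sess:
  train_out = sess.run(y, {x: np.ones((4, 128), np.float32), is_training: True})
  eval_out = sess.run(y, {x: np.ones((4, 128), np.float32), is_training: False})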
def test_value(self):
  for v in [True, False, 1, 0, 1.0]:
    value = utils.constant_value(v)
    self.assertEqual(value, v)
def nan_batch_norm(inputs,
                   decay=0.999,
                   center=True,
                   scale=False,
                   epsilon=0.001,
                   is_training=True,
                   reuse=None,
                   variables_collections=None,
                   outputs_collections=None,
                   trainable=False,
                   scope=None):
  with variable_scope.variable_op_scope([inputs], scope, 'NanBatchNorm',
                                        reuse=reuse) as sc:
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    axis = list(range(inputs_rank - 1))
    params_shape = inputs_shape[-1:]

    beta, gamma = None, None
    if center:
      beta_collections = utils.get_variable_collections(
          variables_collections, 'beta')
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=init_ops.zeros_initializer,
                                      collections=beta_collections,
                                      trainable=False)
    if scale:
      gamma_collections = utils.get_variable_collections(
          variables_collections, 'gamma')
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=init_ops.ones_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)

    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections.
    moving_mean_collections = utils.get_variable_collections(
        variables_collections, 'moving_mean')
    moving_mean = variables.model_variable(
        'moving_mean',
        shape=params_shape,
        dtype=dtype,
        initializer=init_ops.zeros_initializer,
        trainable=False,
        collections=moving_mean_collections)
    moving_variance_collections = utils.get_variable_collections(
        variables_collections, 'moving_variance')
    moving_variance = variables.model_variable(
        'moving_variance',
        shape=params_shape,
        dtype=dtype,
        initializer=init_ops.ones_initializer,
        trainable=False,
        collections=moving_variance_collections)

    is_training_value = utils.constant_value(is_training)
    need_moments = is_training_value is None or is_training_value
    if need_moments:
      mean = nanmean(inputs, axis=axis)
      variance = nanvar(inputs, axis=axis)
      moving_mean = moving_averages.assign_moving_average(
          moving_mean, mean, decay)
      moving_variance = moving_averages.assign_moving_average(
          moving_variance, variance, decay)
    mean, variance = moving_mean, moving_variance

    outputs = tf.nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                        epsilon)
    outputs.set_shape(inputs_shape)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
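# `nanmean` and `nanvar` are not defined in this snippet. A plausible hedged
# sketch using standard TF 1.x ops: zero out NaN entries, count only the
# finite ones, and compute mean/variance over those (assumes at least one
# non-NaN entry per reduced slice).
import tensorflow as tf


def nanmean(x, axis=None):
  finite = tf.logical_not(tf.is_nan(x))
  x_clean = tf.where(finite, x, tf.zeros_like(x))
  count = tf.reduce_sum(tf.cast(finite, x.dtype), axis=axis)
  return tf.reduce_sum(x_clean, axis=axis) / count


def nanvar(x, axis=None):
  mean = nanmean(x, axis=axis)
  # NaN inputs stay NaN here and are ignored again by the outer nanmean.
  return nanmean(tf.squared_difference(x, mean), axis=axis)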
def batch_norm(inputs, decay=0.999, center=True, scale=False, epsilon=0.001, activation_fn=None, initializers={}, updates_collections=None, is_training=True, reuse=None, variables_collections=None, outputs_collections=None, trainable=False, scope=None): """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167. "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" Sergey Ioffe, Christian Szegedy Can be used as a normalizer function for conv2d and fully_connected. Note: When is_training is True the moving_mean and moving_variance need to be updated, by default the update_ops are placed in tf.GraphKeys.UPDATE_OPS so they need to be added as a dependency to the train_op, example: update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) if update_ops: updates = tf.group(*update_ops) total_loss = control_flow_ops.with_dependencies([updates], total_loss) One can set update_collections=None to force the updates in place, but that can have speed penalty, specially in distributed settings. Args: inputs: a tensor with 2 or more dimensions, where the first dimension has `batch_size`. The normalization is over all but the last dimension. decay: decay for the moving average. center: If True, subtract `beta`. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. epsilon: small float added to variance to avoid dividing by zero. activation_fn: activation function, default set to None to skip it and maintain a linear activation. updates_collections: collections to collect the update ops for computation. The updates_ops need to be excuted with the train_op. If None, a control dependency would be added to make sure the updates are computed in place. is_training: whether or not the layer is in training mode. In training mode it would accumulate the statistics of the moments into `moving_mean` and `moving_variance` using an exponential moving average with the given `decay`. When it is not in training mode then it would use the values of the `moving_mean` and the `moving_variance`. reuse: whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. variables_collections: optional collections for the variables. outputs_collections: collections to add the outputs. trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). scope: Optional scope for `variable_scope`. Returns: A `Tensor` representing the output of the operation. Raises: ValueError: if rank or last dimension of `inputs` is undefined. """ with variable_scope.variable_scope(scope, 'BatchNorm', [inputs], reuse=reuse) as sc: inputs = ops.convert_to_tensor(inputs) inputs_shape = inputs.get_shape() inputs_rank = inputs_shape.ndims if inputs_rank is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) dtype = inputs.dtype.base_dtype axis = list(range(inputs_rank - 1)) params_shape = inputs_shape[-1:] if not params_shape.is_fully_defined(): raise ValueError('Inputs %s has undefined last dimension %s.' % (inputs.name, params_shape)) # Allocate parameters for the beta and gamma of the normalization. beta, gamma = None, None # Create moving_mean and moving_variance variables and add them to the # appropiate collections. 
moving_mean_initializer = initializers.get('moving_mean', init_ops.zeros_initializer) moving_mean = variables.model_variable( 'moving_mean', shape=params_shape, dtype=dtype, initializer=moving_mean_initializer, trainable=False) moving_variance_initializer = initializers.get( 'moving_variance', init_ops.ones_initializer) moving_variance = variables.model_variable( 'moving_variance', shape=params_shape, dtype=dtype, initializer=moving_variance_initializer, trainable=False) # If `is_training` doesn't have a constant value, because it is a `Tensor`, # a `Variable` or `Placeholder` then is_training_value will be None and # `needs_moments` will be true. is_training_value = utils.constant_value(is_training) need_moments = is_training_value is None or is_training_value if need_moments: # Calculate the moments based on the individual batch. # Use a copy of moving_mean as a shift to compute more reliable moments. shift = math_ops.add(moving_mean, 0) mean, variance = nn.moments(inputs, axis, shift=shift) moving_vars_fn = lambda: (moving_mean, moving_variance) if updates_collections is None: def _force_updates(): """Internal function forces updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay) with ops.control_dependencies( [update_moving_mean, update_moving_variance]): return array_ops.identity(mean), array_ops.identity( variance) mean, variance = utils.smart_cond(is_training, _force_updates, moving_vars_fn) else: def _delay_updates(): """Internal function that delay updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay) return update_moving_mean, update_moving_variance update_mean, update_variance = utils.smart_cond( is_training, _delay_updates, moving_vars_fn) ops.add_to_collections(updates_collections, update_mean) ops.add_to_collections(updates_collections, update_variance) # Use computed moments during training and moving_vars otherwise. vars_fn = lambda: (mean, variance) mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn) else: mean, variance = moving_mean, moving_variance # Compute batch_normalization. outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon) outputs.set_shape(inputs_shape) if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.original_name_scope, outputs)
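# Hedged usage sketch (assumes this modified `batch_norm` is importable and
# that tf.contrib.layers.conv2d is available): it can be plugged in as the
# normalizer_fn of a conv layer. With the default updates_collections=None,
# the moving-average updates are forced in place via control dependencies, so
# no extra wiring of UPDATE_OPS into the train op is needed.
import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 32, 32, 3])
is_training = tf.placeholder(tf.bool, [])

net = tf.contrib.layers.conv2d(
    images, num_outputs=16, kernel_size=3,
    normalizer_fn=batch_norm,
    normalizer_params={'is_training': is_training, 'decay': 0.99})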
def batch_norm( inputs, decay=0.999, center=True, scale=False, epsilon=0.001, activation_fn=None, param_initializers=None, updates_collections=ops.GraphKeys.UPDATE_OPS, is_training=True, reuse=None, variables_collections=None, outputs_collections=None, trainable=True, batch_weights=None, fused=False, #data_format=DATA_FORMAT_NHWC, scope=None): """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167. "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" Sergey Ioffe, Christian Szegedy Can be used as a normalizer function for conv2d and fully_connected. Note: When is_training is True the moving_mean and moving_variance need to be updated, by default the update_ops are placed in `tf.GraphKeys.UPDATE_OPS` so they need to be added as a dependency to the `train_op`, example: update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) if update_ops: updates = tf.group(*update_ops) total_loss = control_flow_ops.with_dependencies([updates], total_loss) One can set updates_collections=None to force the updates in place, but that can have speed penalty, specially in distributed settings. Args: inputs: a tensor with 2 or more dimensions, where the first dimension has `batch_size`. The normalization is over all but the last dimension if `data_format` is `NHWC` and the second dimension if `data_format` is `NCHW`. decay: decay for the moving average. center: If True, subtract `beta`. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. epsilon: small float added to variance to avoid dividing by zero. activation_fn: activation function, default set to None to skip it and maintain a linear activation. param_initializers: optional initializers for beta, gamma, moving mean and moving variance. updates_collections: collections to collect the update ops for computation. The updates_ops need to be executed with the train_op. If None, a control dependency would be added to make sure the updates are computed in place. is_training: whether or not the layer is in training mode. In training mode it would accumulate the statistics of the moments into `moving_mean` and `moving_variance` using an exponential moving average with the given `decay`. When it is not in training mode then it would use the values of the `moving_mean` and the `moving_variance`. reuse: whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. variables_collections: optional collections for the variables. outputs_collections: collections to add the outputs. trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). batch_weights: An optional tensor of shape `[batch_size]`, containing a frequency weight for each batch item. If present, then the batch normalization uses weighted mean and variance. (This can be used to correct for bias in training example selection.) fused: Use nn.fused_batch_norm if True, nn.batch_normalization otherwise. data_format: A string. `NHWC` (default) and `NCHW` are supported. scope: Optional scope for `variable_scope`. Returns: A `Tensor` representing the output of the operation. Raises: ValueError: if `batch_weights` is not None and `fused` is True. ValueError: if `data_format` is neither `NHWC` nor `NCHW`. ValueError: if `data_format` is `NCHW` while `fused` is False. 
ValueError: if the rank of `inputs` is undefined. ValueError: if rank or last dimension of `inputs` is undefined. """ if fused: if batch_weights is not None: raise ValueError('Weighted mean and variance is not currently ' 'supported for fused batch norm.') return _fused_batch_norm(inputs, decay=decay, center=center, scale=scale, epsilon=epsilon, activation_fn=activation_fn, param_initializers=param_initializers, updates_collections=updates_collections, is_training=is_training, reuse=reuse, variables_collections=variables_collections, outputs_collections=outputs_collections, trainable=trainable, data_format=data_format, scope=scope) #if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): #raise ValueError('data_format has to be either NCHW or NHWC.') #if data_format == DATA_FORMAT_NCHW: #raise ValueError('data_format must be NHWC if fused is False.') with variable_scope.variable_scope(scope, 'BatchNorm', [inputs], reuse=reuse) as sc: inputs = ops.convert_to_tensor(inputs) inputs_shape = inputs.get_shape() inputs_rank = inputs_shape.ndims if inputs_rank is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) dtype = inputs.dtype.base_dtype if batch_weights is not None: batch_weights = ops.convert_to_tensor(batch_weights) inputs_shape[0:1].assert_is_compatible_with( batch_weights.get_shape()) # Reshape batch weight values so they broadcast across inputs. nshape = [-1] + [1 for _ in range(inputs_rank - 1)] batch_weights = array_ops.reshape(batch_weights, nshape) axis = list(range(inputs_rank - 1)) params_shape = inputs_shape[-1:] if not params_shape.is_fully_defined(): raise ValueError('Inputs %s has undefined last dimension %s.' % (inputs.name, params_shape)) # Allocate parameters for the beta and gamma of the normalization. beta, gamma = None, None if not param_initializers: param_initializers = {} if center: beta_collections = utils.get_variable_collections( variables_collections, 'beta') beta_initializer = param_initializers.get( 'beta', init_ops.zeros_initializer) beta = variables.model_variable('beta', shape=params_shape, dtype=dtype, initializer=beta_initializer, collections=beta_collections, trainable=trainable) if scale: gamma_collections = utils.get_variable_collections( variables_collections, 'gamma') gamma_initializer = param_initializers.get( 'gamma', init_ops.ones_initializer) gamma = variables.model_variable('gamma', shape=params_shape, dtype=dtype, initializer=gamma_initializer, collections=gamma_collections, trainable=trainable) # Create moving_mean and moving_variance variables and add them to the # appropiate collections. We disable variable partitioning while creating # them, because assign_moving_average is not yet supported for partitioned # variables. 
partitioner = variable_scope.get_variable_scope().partitioner try: variable_scope.get_variable_scope().set_partitioner(None) moving_mean_collections = utils.get_variable_collections( variables_collections, 'moving_mean') moving_mean_initializer = param_initializers.get( 'moving_mean', init_ops.zeros_initializer) moving_mean = variables.model_variable( 'moving_mean', shape=params_shape, dtype=dtype, initializer=moving_mean_initializer, trainable=False, collections=moving_mean_collections) moving_variance_collections = utils.get_variable_collections( variables_collections, 'moving_variance') moving_variance_initializer = param_initializers.get( 'moving_variance', init_ops.ones_initializer) moving_variance = variables.model_variable( 'moving_variance', shape=params_shape, dtype=dtype, initializer=moving_variance_initializer, trainable=False, collections=moving_variance_collections) finally: variable_scope.get_variable_scope().set_partitioner(partitioner) # If `is_training` doesn't have a constant value, because it is a `Tensor`, # a `Variable` or `Placeholder` then is_training_value will be None and # `needs_moments` will be true. is_training_value = utils.constant_value(is_training) need_moments = is_training_value is None or is_training_value if need_moments: # Calculate the moments based on the individual batch. if batch_weights is None: # Use a copy of moving_mean as a shift to compute more reliable moments. shift = math_ops.add(moving_mean, 0) mean, variance = nn.moments(inputs, axis, shift=shift) else: mean, variance = nn.weighted_moments(inputs, axis, batch_weights) moving_vars_fn = lambda: (moving_mean, moving_variance) if updates_collections is None: def _force_updates(): """Internal function forces updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay) with ops.control_dependencies( [update_moving_mean, update_moving_variance]): return array_ops.identity(mean), array_ops.identity( variance) mean, variance = utils.smart_cond(is_training, _force_updates, moving_vars_fn) else: def _delay_updates(): """Internal function that delay updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay, zero_debias=False) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) return update_moving_mean, update_moving_variance update_mean, update_variance = utils.smart_cond( is_training, _delay_updates, moving_vars_fn) ops.add_to_collections(updates_collections, update_mean) ops.add_to_collections(updates_collections, update_variance) # Use computed moments during training and moving_vars otherwise. vars_fn = lambda: (mean, variance) mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn) else: mean, variance = moving_mean, moving_variance # Compute batch_normalization. 
# Print out offset, scale, mean, variance import tensorflow as tf print_op_gamma = tf.Print(gamma, [gamma], message="scale factor is: ") print_op_beta = tf.Print(beta, [beta], message="offset factor is: ") print_op_mean = tf.Print(mean, [mean], message="mean is: ") print_op_var = tf.Print(variance, [variance], message="variance is: ") with ops.control_dependencies( [print_op_gamma, print_op_beta, print_op_mean, print_op_var]): outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon) outputs.set_shape(inputs_shape) if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.original_name_scope, outputs)
def batch_norm_backbone(inputs, decay=0.999, center=True, scale=False, epsilon=0.001, activation_fn=None, param_initializers=None, param_regularizers=None, updates_collections=ops.GraphKeys.UPDATE_OPS, is_training=True, reuse=None, variables_collections=None, outputs_collections=None, trainable=True, batch_weights=None, fused=None, data_format=DATA_FORMAT_NHWC, zero_debias_moving_mean=False, scope=None, renorm=False, renorm_clipping=None, renorm_decay=0.99, adjustment=None, tower_config=None): """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167. "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" Sergey Ioffe, Christian Szegedy Can be used as a normalizer function for conv2d and fully_connected. The normalization is over all but the last dimension if `data_format` is `NHWC` and all but the second dimension if `data_format` is `NCHW`. In case of a 2D tensor this corresponds to the batch dimension, while in case of a 4D tensor this corresponds to the batch and space dimensions. Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they need to be added as a dependency to the `train_op`. For example: ```python update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss) ``` One can set updates_collections=None to force the updates in place, but that can have a speed penalty, especially in distributed settings. Args: inputs: A tensor with 2 or more dimensions, where the first dimension has `batch_size`. The normalization is over all but the last dimension if `data_format` is `NHWC` and the second dimension if `data_format` is `NCHW`. decay: Decay for the moving average. Reasonable values for `decay` are close to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc. Lower `decay` value (recommend trying `decay`=0.9) if model experiences reasonably good training performance but poor validation and/or test performance. Try zero_debias_moving_mean=True for improved stability. center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. epsilon: Small float added to variance to avoid dividing by zero. activation_fn: Activation function, default set to None to skip it and maintain a linear activation. param_initializers: Optional initializers for beta, gamma, moving mean and moving variance. param_regularizers: Optional regularizer for beta and gamma. updates_collections: Collections to collect the update ops for computation. The updates_ops need to be executed with the train_op. If None, a control dependency would be added to make sure the updates are computed in place. is_training: Whether or not the layer is in training mode. In training mode it would accumulate the statistics of the moments into `moving_mean` and `moving_variance` using an exponential moving average with the given `decay`. When it is not in training mode then it would use the values of the `moving_mean` and the `moving_variance`. reuse: Whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. variables_collections: Optional collections for the variables. outputs_collections: Collections to add the outputs. 
trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). batch_weights: An optional tensor of shape `[batch_size]`, containing a frequency weight for each batch item. If present, then the batch normalization uses weighted mean and variance. (This can be used to correct for bias in training example selection.) fused: if `None` or `True`, use a faster, fused implementation if possible. If `False`, use the system recommended implementation. data_format: A string. `NHWC` (default) and `NCHW` are supported. zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new pair of variables 'moving_mean/biased' and 'moving_mean/local_step'. scope: Optional scope for `variable_scope`. renorm: Whether to use Batch Renormalization (https://arxiv.org/abs/1702.03275). This adds extra variables during training. The inference is the same for either value of this parameter. renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to scalar `Tensors` used to clip the renorm correction. The correction `(r, d)` is used as `corrected_value = normalized_value * r + d`, with `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, dmax are set to inf, 0, inf, respectively. renorm_decay: Momentum used to update the moving means and standard deviations with renorm. Unlike `momentum`, this affects training and should be neither too small (which would add noise) nor too large (which would give stale estimates). Note that `decay` is still applied to get the means and variances for inference. adjustment: A function taking the `Tensor` containing the (dynamic) shape of the input tensor and returning a pair (scale, bias) to apply to the normalized values (before gamma and beta), only during training. For example, `adjustment = lambda shape: ( tf.random_uniform(shape[-1:], 0.93, 1.07), tf.random_uniform(shape[-1:], -0.1, 0.1))` will scale the normalized value by up to 7% up or down, then shift the result by up to 0.1 (with independent scaling and bias for each feature but shared across all examples), and finally apply gamma and/or beta. If `None`, no adjustment is applied. Returns: A `Tensor` representing the output of the operation. Raises: ValueError: If `data_format` is neither `NHWC` nor `NCHW`. ValueError: If the rank of `inputs` is undefined. ValueError: If rank or channels dimension of `inputs` is undefined. """ # if fused is None: # fused = True # Only use _fused_batch_norm if all of the following three # conditions are true: # (1) fused is set True; # (2) it is possible to use (currently it doesn't support batch weights, # renorm, and the case when rank is neither 2 nor 4); # (3) it is used with zero_debias_moving_mean, or an input shape of rank 2, # or non-default updates_collections (not implemented in # normalization_layers.BatchNormalization yet); otherwise use the fused # implementation in normalization_layers.BatchNormalization. 
# inputs = ops.convert_to_tensor(inputs) # rank = inputs.get_shape().ndims # possible_to_fuse = ( # batch_weights is None and not renorm and rank in [2, 4] and # adjustment is None) # if fused and possible_to_fuse and ( # zero_debias_moving_mean or rank == 2 or # updates_collections is not ops.GraphKeys.UPDATE_OPS): # return _fused_batch_norm( # inputs, # decay=decay, # center=center, # scale=scale, # epsilon=epsilon, # activation_fn=activation_fn, # param_initializers=param_initializers, # param_regularizers=param_regularizers, # updates_collections=updates_collections, # is_training=is_training, # reuse=reuse, # variables_collections=variables_collections, # outputs_collections=outputs_collections, # trainable=trainable, # data_format=data_format, # zero_debias_moving_mean=zero_debias_moving_mean, # scope=scope) if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') layer_variable_getter = _build_variable_getter() with variable_scope.variable_scope( scope, 'BatchNorm', [inputs], reuse=reuse, custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) # # Determine whether we can use the core layer class. # if (batch_weights is None and # updates_collections is ops.GraphKeys.UPDATE_OPS and # not zero_debias_moving_mean): # print("F**K !!!!") # # Use the core layer class. # axis = 1 if data_format == DATA_FORMAT_NCHW else -1 # if not param_initializers: # param_initializers = {} # beta_initializer = param_initializers.get('beta', # init_ops.zeros_initializer()) # gamma_initializer = param_initializers.get('gamma', # init_ops.ones_initializer()) # moving_mean_initializer = param_initializers.get( # 'moving_mean', init_ops.zeros_initializer()) # moving_variance_initializer = param_initializers.get( # 'moving_variance', init_ops.ones_initializer()) # if not param_regularizers: # param_regularizers = {} # beta_regularizer = param_regularizers.get('beta') # gamma_regularizer = param_regularizers.get('gamma') # layer = normalization_layers.BatchNormalization( # axis=axis, # momentum=decay, # epsilon=epsilon, # center=center, # scale=scale, # beta_initializer=beta_initializer, # gamma_initializer=gamma_initializer, # moving_mean_initializer=moving_mean_initializer, # moving_variance_initializer=moving_variance_initializer, # beta_regularizer=beta_regularizer, # gamma_regularizer=gamma_regularizer, # trainable=trainable, # renorm=renorm, # renorm_clipping=renorm_clipping, # renorm_momentum=renorm_decay, # adjustment=adjustment, # name=sc.name, # _scope=sc, # _reuse=reuse, # fused=fused) # outputs = layer.apply(inputs, training=is_training) # # # Add variables to collections. # _add_variable_to_collections(layer.moving_mean, variables_collections, # 'moving_mean') # _add_variable_to_collections(layer.moving_variance, variables_collections, # 'moving_variance') # if layer.beta is not None: # _add_variable_to_collections(layer.beta, variables_collections, 'beta') # if layer.gamma is not None: # _add_variable_to_collections(layer.gamma, variables_collections, # 'gamma') # # if activation_fn is not None: # outputs = activation_fn(outputs) # return utils.collect_named_outputs(outputs_collections, sc.name, outputs) # Not supported by layer class: batch_weights argument, # and custom updates_collections. In that case, use the legacy BN # implementation. # Custom updates collections are not supported because the update logic # is different in this case, in particular w.r.t. "forced updates" and # update op reuse. 
if renorm: raise ValueError('renorm is not supported with batch_weights, ' 'updates_collections or zero_debias_moving_mean') inputs_shape = inputs.get_shape() inputs_rank = inputs_shape.ndims if inputs_rank is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) dtype = inputs.dtype.base_dtype if batch_weights is not None: batch_weights = ops.convert_to_tensor(batch_weights) inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape()) # Reshape batch weight values so they broadcast across inputs. nshape = [-1] + [1 for _ in range(inputs_rank - 1)] batch_weights = array_ops.reshape(batch_weights, nshape) if data_format == DATA_FORMAT_NCHW: moments_axes = [0] + list(range(2, inputs_rank)) params_shape = inputs_shape[1:2] # For NCHW format, rather than relying on implicit broadcasting, we # explicitly reshape the params to params_shape_broadcast when computing # the moments and the batch normalization. params_shape_broadcast = list( [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)]) else: moments_axes = list(range(inputs_rank - 1)) params_shape = inputs_shape[-1:] params_shape_broadcast = None if not params_shape.is_fully_defined(): raise ValueError('Inputs %s has undefined channels dimension %s.' % (inputs.name, params_shape)) # Allocate parameters for the beta and gamma of the normalization. beta, gamma = None, None if not param_initializers: param_initializers = {} if center: beta_collections = utils.get_variable_collections(variables_collections, 'beta') beta_initializer = param_initializers.get('beta', init_ops.zeros_initializer()) beta = variables.model_variable( 'beta', shape=params_shape, dtype=dtype, initializer=beta_initializer, collections=beta_collections, trainable=trainable) if scale: gamma_collections = utils.get_variable_collections( variables_collections, 'gamma') gamma_initializer = param_initializers.get('gamma', init_ops.ones_initializer()) gamma = variables.model_variable( 'gamma', shape=params_shape, dtype=dtype, initializer=gamma_initializer, collections=gamma_collections, trainable=trainable) # Create moving_mean and moving_variance variables and add them to the # appropriate collections. We disable variable partitioning while creating # them, because assign_moving_average is not yet supported for partitioned # variables (this needs to be handled carefully, as it may break # the checkpoint backward compatibility). with variable_scope.variable_scope( variable_scope.get_variable_scope()) as local_scope: local_scope.set_partitioner(None) moving_mean_collections = utils.get_variable_collections( variables_collections, 'moving_mean') moving_mean_initializer = param_initializers.get( 'moving_mean', init_ops.zeros_initializer()) moving_mean = variables.model_variable( 'moving_mean', shape=params_shape, dtype=dtype, initializer=moving_mean_initializer, trainable=False, collections=moving_mean_collections) moving_variance_collections = utils.get_variable_collections( variables_collections, 'moving_variance') moving_variance_initializer = param_initializers.get( 'moving_variance', init_ops.ones_initializer()) moving_variance = variables.model_variable( 'moving_variance', shape=params_shape, dtype=dtype, initializer=moving_variance_initializer, trainable=False, collections=moving_variance_collections) # If `is_training` doesn't have a constant value, because it is a `Tensor`, # a `Variable` or `Placeholder` then is_training_value will be None and # `needs_moments` will be true. 
is_training_value = utils.constant_value(is_training) need_moments = is_training_value is None or is_training_value if need_moments: # Calculate the moments based on the individual batch. if batch_weights is None: if data_format == DATA_FORMAT_NCHW: mean, variance = moments(inputs, moments_axes, tower_config=tower_config, keep_dims=True) mean = array_ops.reshape(mean, [-1]) variance = array_ops.reshape(variance, [-1]) else: mean, variance = moments(inputs, moments_axes, tower_config=tower_config) else: if data_format == DATA_FORMAT_NCHW: mean, variance = weighted_moments( inputs, moments_axes, batch_weights, tower_config, keep_dims=True) mean = array_ops.reshape(mean, [-1]) variance = array_ops.reshape(variance, [-1]) else: mean, variance = weighted_moments(inputs, moments_axes, batch_weights, tower_config=tower_config) moving_vars_fn = lambda: (moving_mean, moving_variance) if updates_collections is None: def _force_updates(): """Internal function forces updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay, zero_debias=zero_debias_moving_mean) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) with ops.control_dependencies( [update_moving_mean, update_moving_variance]): return array_ops.identity(mean), array_ops.identity(variance) mean, variance = utils.smart_cond(is_training, _force_updates, moving_vars_fn) else: def _delay_updates(): """Internal function that delay updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay, zero_debias=zero_debias_moving_mean) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) return update_moving_mean, update_moving_variance update_mean, update_variance = utils.smart_cond( is_training, _delay_updates, moving_vars_fn) ops.add_to_collections(updates_collections, update_mean) ops.add_to_collections(updates_collections, update_variance) # Use computed moments during training and moving_vars otherwise. vars_fn = lambda: (mean, variance) mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn) else: mean, variance = moving_mean, moving_variance if data_format == DATA_FORMAT_NCHW: mean = array_ops.reshape(mean, params_shape_broadcast) variance = array_ops.reshape(variance, params_shape_broadcast) if beta is not None: beta = array_ops.reshape(beta, params_shape_broadcast) if gamma is not None: gamma = array_ops.reshape(gamma, params_shape_broadcast) # Compute batch_normalization. outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon) outputs.set_shape(inputs_shape) if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
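# Illustrative sketch (NumPy, hypothetical shapes) of the NCHW broadcast
# reshape used above: for a 4-D input [N, C, H, W] the per-channel statistics
# are reshaped to [1, C, 1, 1] so they broadcast explicitly over the batch
# and spatial dimensions.
import numpy as np

inputs = np.zeros((8, 16, 32, 32), dtype=np.float32)     # N, C, H, W
params_shape_broadcast = [1, inputs.shape[1], 1, 1]       # [1, 16, 1, 1]
mean = np.zeros(16, dtype=np.float32).reshape(params_shape_broadcast)
assert (inputs - mean).shape == inputs.shape              # broadcasts over N, H, W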
def batch_norm_mine_old(inputs, decay=0.999, center=True, scale=False, epsilon=0.001, activation_fn=None, param_initializers=None, param_regularizers=None, updates_collections=ops.GraphKeys.UPDATE_OPS, is_training=True, reuse=None, variables_collections=None, outputs_collections=None, trainable=True, batch_weights=None, fused=False, data_format=DATA_FORMAT_NHWC, zero_debias_moving_mean=False, scope=None, renorm=False, renorm_clipping=None, renorm_decay=0.99): """ This earlier version of my modification to batch norm uses current_mean and current_variance if is_training is True and moving_mean and moving_variance otherwise. This was leading a large divergence between the results depending upon whether the is_training set to True or not. I think ideally it should always use moving_mean and moving_variance. batch_norm_mine does this. Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167. copy of tensorflow.contrib.layers Args: inputs: A tensor with 2 or more dimensions, where the first dimension has `batch_size`. The normalization is over all but the last dimension if `data_format` is `NHWC` and the second dimension if `data_format` is `NCHW`. decay: Decay for the moving average. Reasonable values for `decay` are close to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc. Lower `decay` value (recommend trying `decay`=0.9) if model experiences reasonably good training performance but poor validation and/or test performance. Try zero_debias_moving_mean=True for improved stability. center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. epsilon: Small float added to variance to avoid dividing by zero. activation_fn: Activation function, default set to None to skip it and maintain a linear activation. param_initializers: Optional initializers for beta, gamma, moving mean and moving variance. param_regularizers: Optional regularizer for beta and gamma. updates_collections: Collections to collect the update ops for computation. The updates_ops need to be executed with the train_op. If None, a control dependency would be added to make sure the updates are computed in place. is_training: Whether or not the layer is in training mode. In training mode it would accumulate the statistics of the moments into `moving_mean` and `moving_variance` using an exponential moving average with the given `decay`. When it is not in training mode then it would use the values of the `moving_mean` and the `moving_variance`. reuse: Whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. variables_collections: Optional collections for the variables. outputs_collections: Collections to add the outputs. trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). batch_weights: An optional tensor of shape `[batch_size]`, containing a frequency weight for each batch item. If present, then the batch normalization uses weighted mean and variance. (This can be used to correct for bias in training example selection.) fused: Use nn.fused_batch_norm if True, nn.batch_normalization otherwise. data_format: A string. `NHWC` (default) and `NCHW` are supported. zero_debias_moving_mean: Use zero_debias for moving_mean. 
It creates a new pair of variables 'moving_mean/biased' and 'moving_mean/local_step'. scope: Optional scope for `variable_scope`. renorm: Whether to use Batch Renormalization (https://arxiv.org/abs/1702.03275). This adds extra variables during training. The inference is the same for either value of this parameter. renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to scalar `Tensors` used to clip the renorm correction. The correction `(r, d)` is used as `corrected_value = normalized_value * r + d`, with `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, dmax are set to inf, 0, inf, respectively. renorm_decay: Momentum used to update the moving means and standard deviations with renorm. Unlike `momentum`, this affects training and should be neither too small (which would add noise) nor too large (which would give stale estimates). Note that `decay` is still applied to get the means and variances for inference. Returns: A `Tensor` representing the output of the operation. Raises: ValueError: If `batch_weights` is not None and `fused` is True. ValueError: If `param_regularizers` is not None and `fused` is True. ValueError: If `data_format` is neither `NHWC` nor `NCHW`. ValueError: If the rank of `inputs` is undefined. ValueError: If rank or channels dimension of `inputs` is undefined. """ if fused: if batch_weights is not None: raise ValueError('Weighted mean and variance is not currently ' 'supported for fused batch norm.') if param_regularizers is not None: raise ValueError('Regularizers are not currently ' 'supported for fused batch norm.') if renorm: raise ValueError('Renorm is not supported for fused batch norm.') return _fused_batch_norm( inputs, decay=decay, center=center, scale=scale, epsilon=epsilon, activation_fn=activation_fn, param_initializers=param_initializers, updates_collections=updates_collections, is_training=is_training, reuse=reuse, variables_collections=variables_collections, outputs_collections=outputs_collections, trainable=trainable, data_format=data_format, zero_debias_moving_mean=zero_debias_moving_mean, scope=scope) if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') layer_variable_getter = _build_variable_getter() with variable_scope.variable_scope( scope, 'BatchNorm', [inputs], reuse=reuse, custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) # Determine whether we can use the core layer class. if (batch_weights is None and updates_collections is ops.GraphKeys.UPDATE_OPS and not zero_debias_moving_mean): # Use the core layer class. 
axis = 1 if data_format == DATA_FORMAT_NCHW else -1 if not param_initializers: param_initializers = {} beta_initializer = param_initializers.get('beta', init_ops.zeros_initializer()) gamma_initializer = param_initializers.get('gamma', init_ops.ones_initializer()) moving_mean_initializer = param_initializers.get( 'moving_mean', init_ops.zeros_initializer()) moving_variance_initializer = param_initializers.get( 'moving_variance', init_ops.ones_initializer()) if not param_regularizers: param_regularizers = {} beta_regularizer = param_regularizers.get('beta') gamma_regularizer = param_regularizers.get('gamma') layer = normalization_layers.BatchNormalization( axis=axis, momentum=decay, epsilon=epsilon, center=center, scale=scale, beta_initializer=beta_initializer, gamma_initializer=gamma_initializer, moving_mean_initializer=moving_mean_initializer, moving_variance_initializer=moving_variance_initializer, beta_regularizer=beta_regularizer, gamma_regularizer=gamma_regularizer, trainable=trainable, renorm=renorm, renorm_clipping=renorm_clipping, renorm_momentum=renorm_decay, name=sc.name, _scope=sc, _reuse=reuse) outputs = layer.apply(inputs, training=is_training) # Add variables to collections. _add_variable_to_collections( layer.moving_mean, variables_collections, 'moving_mean') _add_variable_to_collections( layer.moving_variance, variables_collections, 'moving_variance') if layer.beta: _add_variable_to_collections(layer.beta, variables_collections, 'beta') if layer.gamma: _add_variable_to_collections( layer.gamma, variables_collections, 'gamma') if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.original_name_scope, outputs) # Not supported by layer class: batch_weights argument, # and custom updates_collections. In that case, use the legacy BN # implementation. # Custom updates collections are not supported because the update logic # is different in this case, in particular w.r.t. "forced updates" and # update op reuse. if renorm: raise ValueError('renorm is not supported with batch_weights, ' 'updates_collections or zero_debias_moving_mean') inputs_shape = inputs.get_shape() inputs_rank = inputs_shape.ndims if inputs_rank is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) dtype = inputs.dtype.base_dtype if batch_weights is not None: batch_weights = ops.convert_to_tensor(batch_weights) inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape()) # Reshape batch weight values so they broadcast across inputs. nshape = [-1] + [1 for _ in range(inputs_rank - 1)] batch_weights = array_ops.reshape(batch_weights, nshape) if data_format == DATA_FORMAT_NCHW: moments_axes = [0] + list(range(2, inputs_rank)) params_shape = inputs_shape[1:2] # For NCHW format, rather than relying on implicit broadcasting, we # explicitly reshape the params to params_shape_broadcast when computing # the moments and the batch normalization. params_shape_broadcast = list( [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)]) else: moments_axes = list(range(inputs_rank - 1)) params_shape = inputs_shape[-1:] params_shape_broadcast = None if not params_shape.is_fully_defined(): raise ValueError('Inputs %s has undefined channels dimension %s.' % ( inputs.name, params_shape)) # Allocate parameters for the beta and gamma of the normalization. 
beta, gamma = None, None if not param_initializers: param_initializers = {} if center: beta_collections = utils.get_variable_collections(variables_collections, 'beta') beta_initializer = param_initializers.get('beta', init_ops.zeros_initializer()) beta = variables.model_variable('beta', shape=params_shape, dtype=dtype, initializer=beta_initializer, collections=beta_collections, trainable=trainable) if scale: gamma_collections = utils.get_variable_collections(variables_collections, 'gamma') gamma_initializer = param_initializers.get('gamma', init_ops.ones_initializer()) gamma = variables.model_variable('gamma', shape=params_shape, dtype=dtype, initializer=gamma_initializer, collections=gamma_collections, trainable=trainable) # Create moving_mean and moving_variance variables and add them to the # appropriate collections. We disable variable partitioning while creating # them, because assign_moving_average is not yet supported for partitioned # variables. partitioner = variable_scope.get_variable_scope().partitioner try: variable_scope.get_variable_scope().set_partitioner(None) moving_mean_collections = utils.get_variable_collections( variables_collections, 'moving_mean') moving_mean_initializer = param_initializers.get( 'moving_mean', init_ops.zeros_initializer()) moving_mean = variables.model_variable( 'moving_mean', shape=params_shape, dtype=dtype, initializer=moving_mean_initializer, trainable=False, collections=moving_mean_collections) moving_variance_collections = utils.get_variable_collections( variables_collections, 'moving_variance') moving_variance_initializer = param_initializers.get( 'moving_variance', init_ops.ones_initializer()) moving_variance = variables.model_variable( 'moving_variance', shape=params_shape, dtype=dtype, initializer=moving_variance_initializer, trainable=False, collections=moving_variance_collections) finally: variable_scope.get_variable_scope().set_partitioner(partitioner) # If `is_training` doesn't have a constant value, because it is a `Tensor`, # a `Variable` or `Placeholder` then is_training_value will be None and # `needs_moments` will be true. is_training_value = utils.constant_value(is_training) need_moments = is_training_value is None or is_training_value if need_moments: # Calculate the moments based on the individual batch. 
if batch_weights is None: if data_format == DATA_FORMAT_NCHW: mean, _ = nn.moments(inputs, moments_axes, keep_dims=True) variance,_ = nn.moments( (inputs-moving_mean)**2, moments_axes, keep_dims=True) mean = array_ops.reshape(mean, [-1]) variance = array_ops.reshape(variance, [-1]) else: mean, _ = nn.moments(inputs, moments_axes) variance, _ = nn.moments( (inputs-moving_mean)**2, moments_axes) else: if data_format == DATA_FORMAT_NCHW: mean, _ = nn.weighted_moments(inputs, moments_axes, batch_weights, keep_dims=True) variance, _ = nn.weighted_moments( (inputs-moving_mean)**2, moments_axes, batch_weights, keep_dims=True) mean = array_ops.reshape(mean, [-1]) variance = array_ops.reshape(variance, [-1]) else: mean, _ = nn.weighted_moments(inputs, moments_axes, batch_weights) variance, _ = nn.weighted_moments( (inputs-moving_mean)**2, moments_axes, batch_weights) moving_vars_fn = lambda: (moving_mean, moving_variance) if updates_collections is None: def _force_updates(): """Internal function forces updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay, zero_debias=zero_debias_moving_mean) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) with ops.control_dependencies([update_moving_mean, update_moving_variance]): return array_ops.identity(mean), array_ops.identity(variance) mean, variance = utils.smart_cond(is_training, _force_updates, moving_vars_fn) else: def _delay_updates(): """Internal function that delay updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay, zero_debias=zero_debias_moving_mean) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) return update_moving_mean, update_moving_variance update_mean, update_variance = utils.smart_cond(is_training, _delay_updates, moving_vars_fn) ops.add_to_collections(updates_collections, update_mean) ops.add_to_collections(updates_collections, update_variance) # Use computed moments during training and moving_vars otherwise. vars_fn = lambda: (mean, variance) mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn) else: mean, variance = moving_mean, moving_variance if data_format == DATA_FORMAT_NCHW: mean = array_ops.reshape(mean, params_shape_broadcast) variance = array_ops.reshape(variance, params_shape_broadcast) beta = array_ops.reshape(beta, params_shape_broadcast) if gamma is not None: gamma = array_ops.reshape(gamma, params_shape_broadcast) # Compute batch_normalization. outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon) outputs.set_shape(inputs_shape) if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.original_name_scope, outputs)
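# Quick numerical check (an illustrative sketch, not part of the layer) of
# the statistic used above: averaging (x - moving_mean)**2 over the batch
# equals the batch variance plus the squared gap between the batch mean and
# the moving mean, so this variant's variance estimate exceeds the plain
# batch variance whenever the moving mean lags the batch mean.
import numpy as np

x = np.random.randn(128)
moving_mean = 0.3
lhs = np.mean((x - moving_mean) ** 2)
rhs = x.var() + (x.mean() - moving_mean) ** 2
assert np.allclose(lhs, rhs)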
def fused_batch_norm(
    inputs,
    renorm=False,
    RMAX=None,
    DMAX=None,
    decay=0.999,
    center=True,
    scale=False,
    epsilon=0.001,
    activation_fn=None,
    param_initializers=None,
    is_training=True,
    reuse=None,
    variables_collections=None,
    outputs_collections=None,
    trainable=True,
    data_format=DATA_FORMAT_NHWC,
    zero_debias_moving_mean=False,
    scope=None):
  """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.

  "Batch Normalization: Accelerating Deep Network Training by Reducing
  Internal Covariate Shift", Sergey Ioffe, Christian Szegedy.

  Can be used as a normalizer function for conv2d and fully_connected.

  Note: when `is_training` is True, `moving_mean` and `moving_variance` need
  to be updated. By default the update ops are placed in
  `tf.GraphKeys.UPDATE_OPS`, so they need to be added as a dependency to the
  `train_op`. For example:

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
      updates = tf.group(*update_ops)
      total_loss = control_flow_ops.with_dependencies([updates], total_loss)

  Args:
    inputs: a tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC` and over all but the second dimension if
      `data_format` is `NCHW`.
    renorm: whether to apply the batch renormalization correction
      (https://arxiv.org/abs/1702.03275) during training.
    RMAX: clipping bound for the renorm scale correction `r`; only used when
      `renorm` is True.
    DMAX: clipping bound for the renorm shift correction `d`; only used when
      `renorm` is True.
    decay: decay for the moving average. Reasonable values for `decay` are
      close to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9,
      etc. Lower the `decay` value (try `decay=0.9`) if the model experiences
      reasonably good training performance but poor validation and/or test
      performance.
    center: If True, add offset of `beta` to the normalized tensor. If False,
      `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
      the next layer is linear (also e.g. `nn.relu`), this can be disabled
      since the scaling can be done by the next layer.
    epsilon: small float added to variance to avoid dividing by zero.
    activation_fn: activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: optional initializers for beta, gamma, moving mean and
      moving variance.
    is_training: whether or not the layer is in training mode. In training
      mode it accumulates the statistics of the moments into `moving_mean`
      and `moving_variance` using an exponential moving average with the
      given `decay`. When not in training mode it uses the values of
      `moving_mean` and `moving_variance`.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer, the scope must be given.
    variables_collections: optional collections for the variables.
    outputs_collections: collections to add the outputs to.
    trainable: If `True`, also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    zero_debias_moving_mean: Use zero_debias for moving_mean.
    scope: Optional scope for `variable_scope`.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: if `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: if the rank of `inputs` is undefined.
    ValueError: if the rank of `inputs` is neither 2 nor 4.
    ValueError: if the rank or the `C` dimension of `inputs` is undefined.
""" if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') with tf.variable_scope( scope, 'BatchNorm', [inputs], reuse=reuse) as sc: inputs = ops.convert_to_tensor(inputs) original_shape = inputs.get_shape() original_rank = original_shape.ndims if original_rank is None: raise ValueError('Inputs %s has undefined rank' % inputs.name) elif original_rank not in [2, 4]: raise ValueError('Inputs %s has unsupported rank.' ' Expected 2 or 4 but got %d' % ( inputs.name, original_rank)) if original_rank == 2: channels = inputs.get_shape()[-1].value if channels is None: raise ValueError('`C` dimension must be known but is None') new_shape = [-1, 1, 1, channels] if data_format == DATA_FORMAT_NCHW: new_shape = [-1, channels, 1, 1] inputs = array_ops.reshape(inputs, new_shape) inputs_shape = inputs.get_shape() dtype = inputs.dtype.base_dtype if data_format == DATA_FORMAT_NHWC: params_shape = inputs_shape[-1:] else: params_shape = inputs_shape[1:2] if not params_shape.is_fully_defined(): raise ValueError('Inputs %s has undefined `C` dimension %s.' % (inputs.name, params_shape)) if not param_initializers: param_initializers = {} # Allocate parameters for the beta and gamma of the normalization. trainable_beta = trainable and center if trainable_beta: beta_collections = utils.get_variable_collections(variables_collections, 'beta') beta_initializer = param_initializers.get('beta', init_ops.zeros_initializer()) real_beta = variables.model_variable( 'beta', shape=params_shape, dtype=dtype, initializer=beta_initializer, collections=beta_collections, trainable=trainable_beta) beta = tf.zeros(params_shape, name='fakebeta') else: real_beta = tf.zeros(params_shape, name='beta') beta = tf.zeros(params_shape, name='fakebeta') trainable_gamma = trainable and scale if trainable_gamma: gamma_collections = utils.get_variable_collections(variables_collections, 'gamma') gamma_initializer = param_initializers.get('gamma', init_ops.ones_initializer()) gamma = variables.model_variable( 'gamma', shape=params_shape, dtype=dtype, initializer=gamma_initializer, collections=gamma_collections, trainable=trainable_gamma) else: gamma = tf.ones(params_shape, name='gamma') # Create moving_mean and moving_variance variables and add them to the # appropiate collections. 
moving_mean_collections = utils.get_variable_collections( variables_collections, 'moving_mean') moving_mean_initializer = param_initializers.get( 'moving_mean', init_ops.zeros_initializer()) moving_mean = variables.model_variable( 'moving_mean', shape=params_shape, dtype=dtype, initializer=moving_mean_initializer, trainable=False, collections=moving_mean_collections) moving_variance_collections = utils.get_variable_collections( variables_collections, 'moving_variance') moving_variance_initializer = param_initializers.get( 'moving_variance', init_ops.ones_initializer()) moving_variance = variables.model_variable( 'moving_variance', shape=params_shape, dtype=dtype, initializer=moving_variance_initializer, trainable=False, collections=moving_variance_collections) def _fused_batch_norm_training(): outputs, mean, variance = nn.fused_batch_norm( inputs, gamma, beta, epsilon=epsilon, data_format=data_format) if renorm: moving_inv = math_ops.rsqrt(moving_variance + epsilon) r = tf.stop_gradient(tf.clip_by_value(tf.sqrt(variance + epsilon) * moving_inv, 1/RMAX, RMAX)) d = tf.stop_gradient(tf.clip_by_value((mean - moving_mean) * moving_inv, -DMAX, DMAX)) outputs = outputs * r + d return outputs, mean, variance def _fused_batch_norm_inference(): return nn.fused_batch_norm( inputs, gamma, beta, mean=moving_mean, variance=moving_variance, epsilon=epsilon, is_training=False, data_format=data_format) outputs, mean, variance = utils.smart_cond(is_training, _fused_batch_norm_training, _fused_batch_norm_inference) outputs = tf.nn.bias_add(outputs, real_beta) # If `is_training` doesn't have a constant value, because it is a `Tensor`, # a `Variable` or `Placeholder` then is_training_value will be None and # `need_updates` will be true. is_training_value = utils.constant_value(is_training) need_updates = is_training_value is None or is_training_value if need_updates: moving_vars_fn = lambda: (moving_mean, moving_variance) def _delay_updates(): """Internal function that delay updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay, zero_debias=zero_debias_moving_mean) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) return update_moving_mean, update_moving_variance update_mean, update_variance = utils.smart_cond(is_training, _delay_updates, moving_vars_fn) ops.add_to_collections(ops.GraphKeys.UPDATE_OPS, update_mean) ops.add_to_collections(ops.GraphKeys.UPDATE_OPS, update_variance) outputs.set_shape(inputs_shape) if original_shape.ndims == 2: outputs = array_ops.reshape(outputs, original_shape) if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.original_name_scope, outputs)
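For reference, the renorm branch in _fused_batch_norm_training above applies the correction from the Batch Renormalization paper (https://arxiv.org/abs/1702.03275): the already-normalized activations are rescaled by r and shifted by d, with gradients stopped through both, and r and d clipped by RMAX and DMAX. A minimal sketch of that correction, assuming TF 1.x; renorm_correction and its argument names are illustrative, not the function's actual helpers.

import tensorflow as tf  # assumes TF 1.x graph mode


def renorm_correction(normalized, batch_mean, batch_variance,
                      moving_mean, moving_variance, rmax, dmax,
                      epsilon=1e-3):
  """Applies the batch-renorm r/d correction to normalized activations."""
  moving_inv_std = tf.math.rsqrt(moving_variance + epsilon)
  # r compares the minibatch stddev with the moving stddev, clipped to
  # [1/rmax, rmax]; d compares the means, clipped to [-dmax, dmax].
  r = tf.stop_gradient(tf.clip_by_value(
      tf.sqrt(batch_variance + epsilon) * moving_inv_std, 1.0 / rmax, rmax))
  d = tf.stop_gradient(tf.clip_by_value(
      (batch_mean - moving_mean) * moving_inv_std, -dmax, dmax))
  return normalized * r + d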
def call(self, inputs, training=False):
  # First, compute the axes along which to reduce the mean / variance,
  # as well as the broadcast shape to be used for all parameters.
  input_shape = inputs.get_shape()
  ndim = len(input_shape)
  reduction_axes = list(range(len(input_shape)))
  del reduction_axes[self.axis]
  broadcast_shape = [1] * len(input_shape)
  broadcast_shape[self.axis] = input_shape[self.axis].value

  # Determines whether broadcasting is needed.
  needs_broadcasting = (sorted(reduction_axes) != list(range(ndim - 1)))

  # Determine a boolean value for `training`: could be True, False, or None.
  training_value = utils.constant_value(training)

  if needs_broadcasting:
    # In this case we must explicitly broadcast all parameters.
    if self.center:
      broadcast_beta = array_ops.reshape(self.beta, broadcast_shape)
    else:
      broadcast_beta = None
    if self.scale:
      broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape)
    else:
      broadcast_gamma = None

  # Determine the moments.
  if training_value is not False:
    if needs_broadcasting:
      broadcast_mean, broadcast_variance = nn.moments(
          inputs, reduction_axes, keep_dims=True)
      mean = array_ops.reshape(broadcast_mean, [-1])
      variance = array_ops.reshape(broadcast_variance, [-1])
    else:
      mean, variance = nn.moments(inputs, reduction_axes)

    # Prepare updates if necessary.
    if not self.updates:
      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, mean, self.momentum, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, variance, self.momentum, zero_debias=False)
      # In the future this should be refactored into a self.add_update
      # method in order to allow for instance-based BN layer sharing
      # across unrelated input streams (e.g. like in Keras).
      self.updates.append(mean_update)
      self.updates.append(variance_update)

  # Normalize the batch. We do this inside separate functions for training
  # and inference so as to avoid evaluating both branches.
  def normalize_in_test():
    if needs_broadcasting:
      broadcast_moving_mean = array_ops.reshape(
          self.moving_mean, broadcast_shape)
      broadcast_moving_variance = array_ops.reshape(
          self.moving_variance, broadcast_shape)
    arg_mean = (broadcast_moving_mean if needs_broadcasting
                else self.moving_mean)
    arg_variance = (broadcast_moving_variance if needs_broadcasting
                    else self.moving_variance)
    arg_beta = broadcast_beta if needs_broadcasting else (
        self.beta if self.center else None)
    arg_gamma = broadcast_gamma if needs_broadcasting else (
        self.gamma if self.scale else None)
    if self.quantizer is None:
      return nn.batch_normalization(inputs, arg_mean, arg_variance,
                                    arg_beta, arg_gamma, self.epsilon)
    else:
      return qbatch_normalization(inputs, arg_mean, arg_variance,
                                  arg_beta, arg_gamma, self.epsilon,
                                  self.quantizer)

  def normalize_in_training():
    arg_mean = broadcast_mean if needs_broadcasting else mean
    arg_variance = broadcast_variance if needs_broadcasting else variance
    arg_beta = broadcast_beta if needs_broadcasting else (
        self.beta if self.center else None)
    arg_gamma = broadcast_gamma if needs_broadcasting else (
        self.gamma if self.scale else None)
    if self.quantizer is None:
      return nn.batch_normalization(inputs, arg_mean, arg_variance,
                                    arg_beta, arg_gamma, self.epsilon)
    else:
      return qbatch_normalization(inputs, arg_mean, arg_variance,
                                  arg_beta, arg_gamma, self.epsilon,
                                  self.quantizer)

  return utils.smart_cond(training, normalize_in_training, normalize_in_test)
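Both branches above end in nn.batch_normalization (or its quantized counterpart), which computes gamma * (x - mean) / sqrt(variance + epsilon) + beta, with beta and gamma optional. A minimal equivalent sketch for reference, assuming TF 1.x; manual_batch_norm is an illustrative name, not a TensorFlow API.

import tensorflow as tf  # assumes TF 1.x graph mode


def manual_batch_norm(x, mean, variance, beta=None, gamma=None,
                      epsilon=1e-3):
  """Reference re-implementation of tf.nn.batch_normalization semantics."""
  inv_std = tf.math.rsqrt(variance + epsilon)
  out = (x - mean) * inv_std
  if gamma is not None:
    out = out * gamma  # optional scale
  if beta is not None:
    out = out + beta   # optional offset
  return out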
def conditional_batch_norm(inputs, conditional_layer, var_scope_postfix='', decay=0.999, center=True, scale=False, epsilon=0.001, activation_fn=None, param_initializers=None, param_regularizers=None, updates_collections=tf.GraphKeys.UPDATE_OPS, is_training=True, reuse=None, variables_collections=None, outputs_collections=None, trainable=True, data_format=DATA_FORMAT_NHWC, zero_debias_moving_mean=False, renorm=False, renorm_clipping=None, renorm_momentum=0.99, scope=None): """Custom implementation of batch norm to support the optional `conditional_layer` and `var_scope_postfix`. For comments on the other parameters, see tensorflow.contrib.layers.python.layers.batch_norm, where this is copied from (tf 1.5 version). Args: conditional_layer: A tensor with 2 dimensions [batch, channels]. If not None, the beta and gamma parameters will be conditioned on the `conditional_layer`. var_scope_postfix: A string. Append it to the var scopes of all variables other than the weight and bias. e.g. var scope of the `gamma` variable becomes `'gamma' + var_scope_postfix`. """ if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') if inputs.dtype != tf.float32: raise NotImplementedError( 'This implementation may not be compatible with mixed precision training.' ) with tf.variable_scope(scope, 'BatchNorm', [inputs], reuse=reuse) as sc: if conditional_layer is not None: conditional_layer = tf.convert_to_tensor(conditional_layer) # Normalizing the conditional layer seems to stabilize training a little. conditional_layer = tf.nn.l2_normalize( conditional_layer, dim=1, name='normalized_conditional_layer') conditional_layer_shape = conditional_layer.get_shape() conditional_layer_rank = conditional_layer_shape.ndims if conditional_layer_rank is None: raise ValueError('Conditional layer %s has undefined rank' % conditional_layer.name) elif conditional_layer_rank != 2: raise ValueError('Conditional layer %s is not rank 2.' % conditional_layer.name) inputs = tf.convert_to_tensor(inputs) original_shape = inputs.get_shape() original_inputs = inputs original_rank = original_shape.ndims if original_rank is None: raise ValueError('Inputs %s has undefined rank' % inputs.name) elif original_rank not in [2, 4]: raise ValueError('Inputs %s has unsupported rank.' ' Expected 2 or 4 but got %d' % (inputs.name, original_rank)) if original_rank == 2: channels = inputs.get_shape()[-1].value if channels is None: raise ValueError('`C` dimension must be known but is None') new_shape = [-1, 1, 1, channels] if data_format == DATA_FORMAT_NCHW: new_shape = [-1, channels, 1, 1] inputs = tf.reshape(inputs, new_shape) inputs_shape = inputs.get_shape() if data_format == DATA_FORMAT_NHWC: params_shape = inputs_shape[-1:] else: params_shape = inputs_shape[1:2] if not params_shape.is_fully_defined(): raise ValueError('Inputs %s has undefined `C` dimension %s.' % (inputs.name, params_shape)) # Allocate parameters for the beta and gamma of the normalization. beta_collections = utils.get_variable_collections( variables_collections, 'beta') variable_dtype = inputs.dtype if not param_initializers: param_initializers = {} if not param_regularizers: param_regularizers = {} if center: beta_scope = 'beta' + var_scope_postfix if conditional_layer is not None: assert not param_initializers, 'param_initializers are not supported with conditional layer.' assert not param_regularizers, 'param_initializers are not supported with conditional layer.' 
beta = get_conditional_batch_norm_param(conditional_layer, int(params_shape[-1]), scope=beta_scope) else: # Behaves like normal batch norm. beta_collections = utils.get_variable_collections( variables_collections, beta_scope) beta_initializer = param_initializers.get( beta_scope, tf.zeros_initializer()) beta_regularizer = param_regularizers.get('beta') beta = variables.model_variable(beta_scope, shape=params_shape, dtype=variable_dtype, initializer=beta_initializer, regularizer=beta_regularizer, collections=beta_collections, trainable=trainable) else: beta = array_ops.constant(0.0, dtype=variable_dtype, shape=params_shape) if scale: gamma_scope = 'gamma' + var_scope_postfix if conditional_layer is not None: assert not param_initializers, 'param_initializers are not supported with conditional layer.' assert not param_regularizers, 'param_initializers are not supported with conditional layer.' delta_gamma = get_conditional_batch_norm_param( conditional_layer, int(params_shape[-1]), scope=gamma_scope) # Per https://arxiv.org/pdf/1707.03017.pdf. gamma = tf.constant( 1.0, dtype=variable_dtype, ) + delta_gamma else: gamma_collections = utils.get_variable_collections( variables_collections, gamma_scope) gamma_initializer = param_initializers.get( gamma_scope, tf.ones_initializer()) gamma_regularizer = param_regularizers.get('gamma') gamma = variables.model_variable(gamma_scope, shape=params_shape, dtype=variable_dtype, initializer=gamma_initializer, regularizer=gamma_regularizer, collections=gamma_collections, trainable=trainable) else: gamma = tf.constant(1.0, dtype=variable_dtype, shape=params_shape) # Create moving_mean and moving_variance variables and add them to the # appropriate collections. We disable variable partitioning while creating # them, because assign_moving_average is not yet supported for partitioned # variables (this needs to be handled carefully, as it may break # the checkpoint backward compatibility). with tf.variable_scope(tf.get_variable_scope()) as local_scope: local_scope.set_partitioner(None) moving_mean_scope = 'moving_mean' + var_scope_postfix moving_mean_collections = utils.get_variable_collections( variables_collections, moving_mean_scope) moving_mean_initializer = param_initializers.get( moving_mean_scope, tf.zeros_initializer()) moving_mean = variables.model_variable( moving_mean_scope, shape=params_shape, dtype=tf.float32, initializer=moving_mean_initializer, trainable=False, collections=moving_mean_collections) moving_variance_scope = 'moving_variance' + var_scope_postfix moving_variance_collections = utils.get_variable_collections( variables_collections, moving_variance_scope) moving_variance_initializer = param_initializers.get( moving_variance_scope, tf.ones_initializer()) moving_variance = variables.model_variable( moving_variance_scope, shape=params_shape, dtype=tf.float32, initializer=moving_variance_initializer, trainable=False, collections=moving_variance_collections) if renorm: renorm_clipping = renorm_clipping or {} keys = ['rmax', 'rmin', 'dmax'] if set(renorm_clipping) - set(keys): raise ValueError( 'renorm_clipping %s contains keys not in %s' % (renorm_clipping, keys)) # Create variables to maintain the moving mean and standard deviation. # These are used in training and thus are different from the moving # averages above. The renorm variables are colocated with moving_mean # and moving_variance. # NOTE: below, the outer `with device` block causes the current device # stack to be cleared. 
The nested ones use a `lambda` to set the desired # device and ignore any devices that may be set by the custom getter. def _renorm_variable(name, shape): var = variables.model_variable( name= name, # renorm variable should be dependent on var_scope_postfix. shape=shape, dtype=tf.float32, initializer=param_initializers.get( name, tf.zeros_initializer()), trainable=False) return var with ops.device(None): device = ((lambda _: moving_mean.device) if context.executing_eagerly() else moving_mean.device) with ops.device(device): renorm_mean = _renorm_variable( 'renorm_mean' + var_scope_postfix, params_shape) renorm_mean_weight = _renorm_variable( 'renorm_mean_weight' + var_scope_postfix, ()) # We initialize renorm_stddev to 0, and maintain the (0-initialized) # renorm_stddev_weight. This allows us to (1) mix the average # stddev with the minibatch stddev early in training, and (2) compute # the unbiased average stddev by dividing renorm_stddev by the weight. device = ((lambda _: moving_variance.device) if context.executing_eagerly() else moving_variance.device) with ops.device(device): renorm_stddev = _renorm_variable( 'renorm_stddev' + var_scope_postfix, params_shape) renorm_stddev_weight = _renorm_variable( 'renorm_stddev_weight' + var_scope_postfix, ()) class dotdict(dict): """dot.notation access to dictionary attributes""" __getattr__ = dict.get __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ renorm_params = dotdict({ 'renorm_mean': renorm_mean, 'renorm_mean_weight': renorm_mean_weight, 'renorm_stddev': renorm_stddev, 'renorm_stddev_weight': renorm_stddev_weight, 'renorm_clipping': renorm_clipping, 'renorm_momentum': renorm_momentum, 'moving_mean': moving_mean, 'moving_variance': moving_variance, 'epsilon': epsilon }) else: renorm_params = None def _batch_norm_training(): # return tf.nn.fused_batch_norm( return _batch_norm_aux(inputs, gamma, beta, epsilon=epsilon, data_format=data_format, renorm=renorm, renorm_params=renorm_params) def _batch_norm_inference(): # return tf.nn.fused_batch_norm( return _batch_norm_aux(inputs, gamma, beta, mean=tf.cast(moving_mean, dtype=variable_dtype), variance=tf.cast(moving_variance, dtype=variable_dtype), epsilon=epsilon, is_training=False, data_format=data_format, renorm=renorm, renorm_params=renorm_params) outputs, mean, variance = utils.smart_cond(is_training, _batch_norm_training, _batch_norm_inference) # If `is_training` doesn't have a constant value, because it is a `Tensor`, # a `Variable` or `Placeholder` then is_training_value will be None and # `need_updates` will be true. 
    is_training_value = utils.constant_value(is_training)
    need_updates = is_training_value is None or is_training_value
    if need_updates:
      if updates_collections is None:
        no_updates = lambda: outputs

        def _force_updates():
          """Internal function that forces updates to moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          with tf.control_dependencies(
              [update_moving_mean, update_moving_variance]):
            return tf.identity(outputs)

        outputs = utils.smart_cond(is_training, _force_updates, no_updates)
      else:
        moving_vars_fn = lambda: (moving_mean, moving_variance)

        def _delay_updates():
          """Internal function that delays updates to moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, tf.cast(mean, dtype=moving_mean.dtype), decay,
              zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, tf.cast(variance, dtype=moving_variance.dtype),
              decay, zero_debias=False)
          return update_moving_mean, update_moving_variance

        update_mean, update_variance = utils.smart_cond(
            is_training, _delay_updates, moving_vars_fn)
        ops.add_to_collections(updates_collections, update_mean)
        ops.add_to_collections(updates_collections, update_variance)

    outputs.set_shape(inputs_shape)
    if original_shape.ndims == 2:
      outputs = array_ops.reshape(outputs, array_ops.shape(original_inputs))
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
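get_conditional_batch_norm_param is used above for the conditional beta and delta-gamma but is not defined in this section. The sketch below is an assumption about its shape, not the actual helper: it projects the [batch, channels] conditional layer to one value per normalized channel with a zero-initialized dense layer, so training starts from plain batch norm. Depending on how _batch_norm_aux broadcasts its beta/gamma arguments, the [batch, C] result may still need reshaping to [batch, 1, 1, C] for NHWC inputs.

import tensorflow as tf  # assumes TF 1.x graph mode


def conditional_bn_param_sketch(conditional_layer, num_channels, scope):
  """Illustrative stand-in for get_conditional_batch_norm_param (assumption)."""
  with tf.variable_scope(scope):
    # Zero-initialized projection: at initialization the conditional term is
    # 0, so beta starts at 0 and gamma starts at 1 (via the 1 + delta_gamma
    # construction used above), i.e. unconditional batch norm.
    return tf.layers.dense(
        conditional_layer, num_channels, use_bias=False,
        kernel_initializer=tf.zeros_initializer())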