def _fused_batch_norm(self, inputs, training):
  """Runs fused batch norm and registers the moving-statistic updates.

  Args:
    inputs: Tensor forwarded to `nn.fused_batch_norm`.
    training: Predicate forwarded to `tf_utils.smart_cond` to pick between
      the training and the inference variant of the fused op.

  Returns:
    The first element (the output) of the selected fused batch norm call.
  """
  # Fall back to the constant offset / scale when centering / scaling is off.
  if self.center:
    beta = self.beta
  else:
    beta = self._beta_const
  if self.scale:
    gamma = self.gamma
  else:
    gamma = self._gamma_const

  def _train_branch():
    return nn.fused_batch_norm(
        inputs,
        gamma,
        beta,
        epsilon=self.epsilon,
        data_format=self._data_format)

  def _inference_branch():
    return nn.fused_batch_norm(
        inputs,
        gamma,
        beta,
        mean=self.moving_mean,
        variance=self.moving_variance,
        epsilon=self.epsilon,
        is_training=False,
        data_format=self._data_format)

  output, mean, variance = tf_utils.smart_cond(
      training, _train_branch, _inference_branch)

  if not self._bessels_correction_test_only:
    # The fused op reports a variance that includes Bessel's correction.
    # Scale it by (n - 1) / n so that it agrees with the non-fused
    # implementation.
    elements_per_statistic = array_ops.size(inputs) / array_ops.size(variance)
    sample_size = math_ops.cast(elements_per_statistic, variance.dtype)
    one = math_ops.cast(1.0, variance.dtype)
    factor = (sample_size - one) / sample_size
    variance *= factor

  static_training = tf_utils.constant_value(training)
  if static_training is None:
    # `training` has no static value: pick the momentum with the same
    # predicate, using 1.0 on the non-training side.
    momentum = tf_utils.smart_cond(
        training, lambda: self.momentum, lambda: 1.0)
  else:
    momentum = ops.convert_to_tensor(self.momentum)

  if not (static_training or static_training is None):
    # Statically known to be inference: no moving-average updates.
    return output

  mean_update = self._assign_moving_average(
      self.moving_mean, mean, momentum)
  variance_update = self._assign_moving_average(
      self.moving_variance, variance, momentum)
  self.add_update(mean_update, inputs=True)
  self.add_update(variance_update, inputs=True)
  return output
def _update_renorm_variable(var, weight, value):
  """Updates a moving average and its weight, returns the unbiased value.

  Uses `self` and `training` from the enclosing scope.

  Args:
    var: Moving-average variable to update.
    weight: Variable accumulating the moving-average weight of `var`.
    value: New value to fold into `var`.

  Returns:
    `var / weight` after the update when `training` selects the update
    branch, otherwise an identity of `var`.
  """
  snapshot = array_ops.identity(value)

  def _apply_update():
    """Updates `var` and `weight`; returns the ratio of the new values."""
    # No zero-debiasing is applied to the variables themselves. Dividing the
    # exponential moving average by its weight debiases it instead: after a
    # single update the average is (1-decay) * value and the weight is
    # 1-decay, so their ratio is exactly value.
    # Tie the weight constant to the value so that the weight update is
    # ordered after whatever the value depends on.
    with ops.control_dependencies([snapshot]):
      unit_weight = array_ops.constant(1., dtype=weight.dtype)
    updated_var = self._assign_moving_average(
        var, snapshot, self.renorm_momentum)
    updated_weight = self._assign_moving_average(
        weight, unit_weight, self.renorm_momentum)
    # TODO(yuefengz): the updates to var and weight can not be batched
    # together if we fetch their updated values here. Consider calculating
    # new values and delaying the updates.
    return updated_var / updated_weight

  def _keep_current():
    return array_ops.identity(var)

  return tf_utils.smart_cond(training, _apply_update, _keep_current)
def call(self, inputs, training=None):
  """Applies dropout to `inputs` when `training` selects the dropout branch.

  Args:
    inputs: Input tensor.
    training: Optional predicate choosing between dropout and identity.
      When `None`, `K.learning_phase()` is used.

  Returns:
    The result of `nn.dropout` on `inputs`, or an identity of `inputs`,
    as selected by `tf_utils.smart_cond`. Outside eager execution the
    output is tagged with `_uses_learning_phase = True` when `training`
    was not supplied by the caller.
  """
  # Remember whether the caller left `training` unset. Checking
  # `training is K.learning_phase()` afterwards would call the backend a
  # second time and rely on object identity of its result; recording the
  # original argument matches the other `call` implementations in this file.
  original_training_value = training
  if training is None:
    training = K.learning_phase()

  def dropped_inputs():
    return nn.dropout(inputs, 1 - self.rate,
                      noise_shape=self._get_noise_shape(inputs),
                      seed=self.seed)

  output = tf_utils.smart_cond(training,
                               dropped_inputs,
                               lambda: array_ops.identity(inputs))
  # EagerTensor object has no attribute _uses_learning_phase
  if not context.executing_eagerly() and original_training_value is None:
    output._uses_learning_phase = True  # pylint: disable=protected-access
  return output
def call(self, inputs, training=None):
  """Runs dropout on `inputs`, or passes them through, based on `training`.

  Args:
    inputs: Input tensor.
    training: Optional predicate for `tf_utils.smart_cond`. Defaults to
      `K.learning_phase()` when omitted.

  Returns:
    The dropout output or an identity of `inputs`. In graph mode, the
    output carries `_uses_learning_phase = True` if `training` was omitted.
  """
  training_was_omitted = training is None
  if training_was_omitted:
    training = K.learning_phase()

  def _with_dropout():
    keep_prob = 1 - self.rate
    return nn.dropout(
        inputs,
        keep_prob,
        noise_shape=self._get_noise_shape(inputs),
        seed=self.seed)

  def _passthrough():
    return array_ops.identity(inputs)

  output = tf_utils.smart_cond(training, _with_dropout, _passthrough)

  # EagerTensor object has no attribute _uses_learning_phase, and an
  # explicitly supplied `training` must not mark the output.
  if context.executing_eagerly() or not training_was_omitted:
    return output
  output._uses_learning_phase = True  # pylint: disable=protected-access
  return output
def call(self, inputs, training=None):
  """Normalizes `inputs` and registers updates of the moving statistics.

  Args:
    inputs: Input tensor to normalize.
    training: Optional training indicator. When `None`, it is replaced by
      `K.learning_phase()`.

  Returns:
    The normalized tensor. When `virtual_batch_size` is set, the result is
    reshaped back to the original (pre-virtual-batching) shape. Outside
    eager execution the result is tagged with `_uses_learning_phase = True`
    if `training` was not supplied by the caller.
  """
  # Keep the caller's argument so we can later tell "omitted" apart from an
  # explicitly supplied value.
  original_training_value = training
  if training is None:
    training = K.learning_phase()

  in_eager_mode = context.executing_eagerly()
  if self.virtual_batch_size is not None:
    # Virtual batches (aka ghost batches) can be simulated by reshaping the
    # Tensor and reusing the existing batch norm implementation
    original_shape = [-1] + inputs.shape.as_list()[1:]
    expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

    # Will cause errors if virtual_batch_size does not divide the batch size
    inputs = array_ops.reshape(inputs, expanded_shape)

    # Only defined when virtual batching is enabled; every call site below is
    # guarded by the same `virtual_batch_size is not None` check.
    def undo_virtual_batching(outputs):
      outputs = array_ops.reshape(outputs, original_shape)
      return outputs

  if self.fused:
    # Fused path: delegate everything (including moving-average updates) to
    # `_fused_batch_norm` and return early.
    outputs = self._fused_batch_norm(inputs, training=training)
    if self.virtual_batch_size is not None:
      # Currently never reaches here since fused_batch_norm does not support
      # virtual batching
      outputs = undo_virtual_batching(outputs)
    if not context.executing_eagerly() and original_training_value is None:
      outputs._uses_learning_phase = True  # pylint: disable=protected-access
    return outputs

  # Compute the axes along which to reduce the mean / variance
  input_shape = inputs.get_shape()
  ndims = len(input_shape)
  reduction_axes = [i for i in range(ndims) if i not in self.axis]
  if self.virtual_batch_size is not None:
    del reduction_axes[1]  # Do not reduce along virtual batch dim

  # Broadcasting only necessary for single-axis batch norm where the axis is
  # not the last dimension
  broadcast_shape = [1] * ndims
  broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value

  def _broadcast(v):
    """Reshapes `v` to `broadcast_shape` when broadcasting is required.

    `v` is returned unchanged if it is None, already has `ndims` dimensions,
    or the reduction covers every axis except the last one.
    """
    if (v is not None and
        len(v.get_shape()) != ndims and
        reduction_axes != list(range(ndims - 1))):
      return array_ops.reshape(v, broadcast_shape)
    return v

  scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

  def _compose_transforms(scale, offset, then_scale, then_offset):
    """Composes `x * scale + offset` with a following scale / offset.

    (x * scale + offset) * then_scale + then_offset
      = x * (scale * then_scale) + (offset * then_scale + then_offset).
    A `None` `then_scale` / `then_offset` leaves that part untouched.
    """
    if then_scale is not None:
      scale *= then_scale
      offset *= then_scale
    if then_offset is not None:
      offset += then_offset
    return (scale, offset)

  # Determine a boolean value for `training`: could be True, False, or None.
  training_value = tf_utils.constant_value(training)
  if training_value is not False:
    # Taken for every `training_value` other than the constant False (per the
    # comment above: True or None).
    if self.adjustment:
      adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
      # Adjust only during training.
      adj_scale = tf_utils.smart_cond(training,
                                      lambda: adj_scale,
                                      lambda: array_ops.ones_like(adj_scale))
      adj_bias = tf_utils.smart_cond(training,
                                     lambda: adj_bias,
                                     lambda: array_ops.zeros_like(adj_bias))
      # The adjustment is applied first, then gamma / beta.
      scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

    # Some of the computations here are not necessary when training==False
    # but not a constant. However, this makes the code simpler.
    keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)

    moving_mean = self.moving_mean
    moving_variance = self.moving_variance

    # Use the batch moments when training, the moving statistics otherwise.
    mean = tf_utils.smart_cond(training,
                               lambda: mean,
                               lambda: moving_mean)
    variance = tf_utils.smart_cond(training,
                                   lambda: variance,
                                   lambda: moving_variance)

    if self.renorm:
      r, d, new_mean, new_variance = self._renorm_correction_and_moments(
          mean, variance, training)
      # When training, the normalized values (say, x) will be transformed as
      # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
      # = x * (r * gamma) + (d * gamma + beta) with renorm.
      r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
      d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
      scale, offset = _compose_transforms(r, d, scale, offset)
    else:
      new_mean, new_variance = mean, variance

    if self.virtual_batch_size is not None:
      # This isn't strictly correct since in ghost batch norm, you are
      # supposed to sequentially update the moving_mean and moving_variance
      # with each sub-batch. However, since the moving statistics are only
      # used during evaluation, it is more efficient to just update in one
      # step and should not make a significant difference in the result.
      new_mean = math_ops.reduce_mean(new_mean, axis=1, keepdims=True)
      new_variance = math_ops.reduce_mean(new_variance, axis=1, keepdims=True)

    def _do_update(var, value):
      """Returns the moving-average update of `var`, or None if skipped.

      The update is skipped in eager mode when the layer is not trainable.
      """
      if in_eager_mode and not self.trainable:
        return

      return self._assign_moving_average(var, value, self.momentum)

    # When not training, the "update" is just the current moving statistic.
    mean_update = tf_utils.smart_cond(
        training,
        lambda: _do_update(self.moving_mean, new_mean),
        lambda: self.moving_mean)
    variance_update = tf_utils.smart_cond(
        training,
        lambda: _do_update(self.moving_variance, new_variance),
        lambda: self.moving_variance)
    # Updates are only registered on the layer outside eager execution.
    if not context.executing_eagerly():
      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

  else:
    # `training` is the constant False: normalize with the moving statistics.
    mean, variance = self.moving_mean, self.moving_variance

  outputs = nn.batch_normalization(inputs,
                                   _broadcast(mean),
                                   _broadcast(variance),
                                   offset,
                                   scale,
                                   self.epsilon)
  # If some components of the shape got lost due to adjustments, fix that.
  outputs.set_shape(input_shape)

  if self.virtual_batch_size is not None:
    outputs = undo_virtual_batching(outputs)
  # Only mark the output when the learning phase was picked up implicitly;
  # eager tensors cannot carry this attribute.
  if not context.executing_eagerly() and original_training_value is None:
    outputs._uses_learning_phase = True  # pylint: disable=protected-access
  return outputs
def _renorm_correction_and_moments(self, mean, variance, training):
  """Returns the correction and update values for renorm.

  Args:
    mean: Mean used to compute the corrections and to update `renorm_mean`.
    variance: Variance; converted to a standard deviation via
      `sqrt(variance + epsilon)`.
    training: Predicate passed to `tf_utils.smart_cond`; selects between the
      real corrections / updates and their no-op counterparts.

  Returns:
    A tuple `(r, d, new_mean, new_variance)` where `r` and `d` are the
    (optionally clipped) corrections, replaced by ones / zeros when not
    training, `new_mean` is the debiased renorm mean, and `new_variance` is
    `square(new_stddev) - epsilon` for the debiased renorm stddev.
  """
  stddev = math_ops.sqrt(variance + self.epsilon)
  # Compute the average mean and standard deviation, as if they were
  # initialized with this batch's moments.
  mixed_renorm_mean = (self.renorm_mean +
                       (1. - self.renorm_mean_weight) * mean)
  mixed_renorm_stddev = (self.renorm_stddev +
                         (1. - self.renorm_stddev_weight) * stddev)
  # Compute the corrections for batch renorm.
  r = stddev / mixed_renorm_stddev
  d = (mean - mixed_renorm_mean) / mixed_renorm_stddev
  # Ensure the corrections use pre-update moving averages.
  with ops.control_dependencies([r, d]):
    mean = array_ops.identity(mean)
    stddev = array_ops.identity(stddev)
  # Clipping bounds are optional; a missing key yields None and is skipped.
  rmin, rmax, dmax = [
      self.renorm_clipping.get(key) for key in ['rmin', 'rmax', 'dmax']
  ]
  if rmin is not None:
    r = math_ops.maximum(r, rmin)
  if rmax is not None:
    r = math_ops.minimum(r, rmax)
  if dmax is not None:
    # Clip d symmetrically to [-dmax, dmax].
    d = math_ops.maximum(d, -dmax)
    d = math_ops.minimum(d, dmax)
  # When not training, use r=1, d=0.
  r = tf_utils.smart_cond(training, lambda: r, lambda: array_ops.ones_like(r))
  d = tf_utils.smart_cond(training, lambda: d, lambda: array_ops.zeros_like(d))

  def _update_renorm_variable(var, weight, value):
    """Updates a moving average and weight, returns the unbiased value."""
    value = array_ops.identity(value)

    def _do_update():
      """Updates the var and weight, returns their updated ratio."""
      # Update the variables without zero debiasing. The debiasing will be
      # accomplished by dividing the exponential moving average by the weight.
      # For example, after a single update, the moving average would be
      # (1-decay) * value. and the weight will be 1-decay, with their ratio
      # giving the value.
      # Make sure the weight is not updated until after the r and d
      # computation: `value` derives from the identities created above under
      # a control dependency on r and d.
      with ops.control_dependencies([value]):
        weight_value = array_ops.constant(1., dtype=weight.dtype)
      new_var = self._assign_moving_average(var, value, self.renorm_momentum)
      new_weight = self._assign_moving_average(
          weight, weight_value, self.renorm_momentum)
      # TODO(yuefengz): the updates to var and weight can not be batched
      # together if we fetch their updated values here. Consider calculating
      # new values and delaying the updates.
      return new_var / new_weight

    def _fake_update():
      # Not training: leave the variables untouched and return `var` as is.
      return array_ops.identity(var)

    return tf_utils.smart_cond(training, _do_update, _fake_update)

  # TODO(yuefengz): colocate the operations
  new_mean = _update_renorm_variable(self.renorm_mean,
                                     self.renorm_mean_weight, mean)
  new_stddev = _update_renorm_variable(self.renorm_stddev,
                                       self.renorm_stddev_weight, stddev)
  # Make sqrt(moving_variance + epsilon) = new_stddev.
  new_variance = math_ops.square(new_stddev) - self.epsilon

  return (r, d, new_mean, new_variance)