예제 #1
0
  def _fused_batch_norm(self, inputs, training):
    """Computes fused batch norm and schedules the moving-statistic updates.

    Args:
      inputs: Tensor handed to `nn.fused_batch_norm`.
      training: Python value or tensor that selects, through
        `tf_utils.smart_cond`, between the training variant (batch
        statistics) and the inference variant (moving statistics).

    Returns:
      The normalized output, i.e. the first element of the
      `nn.fused_batch_norm` result.
    """
    # Fall back to the constant offset/scale when centering or scaling is off.
    offset = self.beta if self.center else self._beta_const
    scale = self.gamma if self.scale else self._gamma_const

    def _training_branch():
      return nn.fused_batch_norm(
          inputs, scale, offset,
          epsilon=self.epsilon,
          data_format=self._data_format)

    def _inference_branch():
      return nn.fused_batch_norm(
          inputs, scale, offset,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    output, batch_mean, batch_variance = tf_utils.smart_cond(
        training, _training_branch, _inference_branch)

    if not self._bessels_correction_test_only:
      # Fused batch norm reports the variance with Bessel's correction; undo
      # it so the value agrees with the non-fused implementation.
      elements_per_stat = (
          array_ops.size(inputs) / array_ops.size(batch_variance))
      sample_size = math_ops.cast(elements_per_stat, batch_variance.dtype)
      one = math_ops.cast(1.0, batch_variance.dtype)
      batch_variance *= (sample_size - one) / sample_size

    static_training = tf_utils.constant_value(training)
    if static_training is None:
      # `training` is only known at run time: a momentum of 1.0 is selected
      # for the non-training case.
      momentum = tf_utils.smart_cond(
          training, lambda: self.momentum, lambda: 1.0)
    else:
      momentum = ops.convert_to_tensor(self.momentum)

    if static_training is None or static_training:
      stat_pairs = ((self.moving_mean, batch_mean),
                    (self.moving_variance, batch_variance))
      # Build both update ops first, then register them, in that order.
      pending_updates = [
          self._assign_moving_average(moving_stat, batch_stat, momentum)
          for moving_stat, batch_stat in stat_pairs
      ]
      for update_op in pending_updates:
        self.add_update(update_op, inputs=True)

    return output
예제 #2
0
        def _update_renorm_variable(var, weight, value):
            """Advances a moving average and its weight; yields the debiased value.

            Args:
                var: Variable holding the (biased) exponential moving average.
                weight: Variable holding the accumulated averaging weight.
                value: New observation to fold into `var`.

            Returns:
                When `training` selects the update branch, the ratio of the
                updated average to the updated weight; otherwise an identity
                of `var`.
            """
            value = array_ops.identity(value)

            def _apply_moving_average():
                """Updates `var` and `weight` and returns updated var / weight."""
                # No zero-debiasing is applied to the variables themselves.
                # Dividing the moving average by the weight debiases it: after
                # one update the average is (1-decay) * value and the weight
                # is 1-decay, so their ratio is exactly `value`.
                # Creating the constant under a control dependency on `value`
                # keeps the weight update from running before `value` (and
                # everything `value` depends on) has been computed.
                with ops.control_dependencies([value]):
                    unit_weight = array_ops.constant(1., dtype=weight.dtype)
                decay = self.renorm_momentum
                averaged_var = self._assign_moving_average(var, value, decay)
                averaged_weight = self._assign_moving_average(
                    weight, unit_weight, decay)
                # TODO(yuefengz): fetching the updated values here prevents the
                # updates to var and weight from being batched together.
                # Consider computing the new values and delaying the updates.
                return averaged_var / averaged_weight

            def _read_current():
                return array_ops.identity(var)

            return tf_utils.smart_cond(training, _apply_moving_average,
                                       _read_current)
예제 #3
0
    def call(self, inputs, training=None):
        """Applies dropout to `inputs` when `training` selects the dropout branch.

        Args:
            inputs: Tensor to apply dropout to.
            training: Python value or tensor accepted by
                `tf_utils.smart_cond`. When `None`, `K.learning_phase()` is
                used instead.

        Returns:
            The result of `nn.dropout` (keep probability `1 - self.rate`) on
            the dropout branch, otherwise an identity of `inputs`. When not
            executing eagerly and `training` was not supplied, the result is
            tagged with `_uses_learning_phase = True`.
        """
        # Record whether the caller omitted `training` *before* defaulting it.
        # Comparing `training is K.learning_phase()` afterwards relies on
        # object identity, which is unreliable when the learning phase is a
        # plain Python int (an explicitly passed 0/1 would look like the
        # default), and it needlessly queries the backend a second time.
        training_was_unset = training is None
        if training_was_unset:
            training = K.learning_phase()

        def dropped_inputs():
            return nn.dropout(inputs,
                              1 - self.rate,
                              noise_shape=self._get_noise_shape(inputs),
                              seed=self.seed)

        output = tf_utils.smart_cond(training, dropped_inputs,
                                     lambda: array_ops.identity(inputs))
        # EagerTensor object has no attribute _uses_learning_phase
        if not context.executing_eagerly() and training_was_unset:
            output._uses_learning_phase = True  # pylint: disable=protected-access
        return output
예제 #4
0
  def call(self, inputs, training=None):
    """Runs dropout on `inputs`, or passes them through unchanged.

    Args:
      inputs: Tensor to apply dropout to.
      training: Python value or tensor choosing the branch through
        `tf_utils.smart_cond`. Defaults to `K.learning_phase()` when `None`.

    Returns:
      The dropped-out tensor on the training branch, otherwise an identity of
      `inputs`. Outside eager execution, if `training` was left as `None`,
      the result carries `_uses_learning_phase = True`.
    """
    phase_from_backend = training is None
    if phase_from_backend:
      training = K.learning_phase()

    def _apply_dropout():
      keep_prob = 1 - self.rate
      return nn.dropout(
          inputs,
          keep_prob,
          noise_shape=self._get_noise_shape(inputs),
          seed=self.seed)

    def _pass_through():
      return array_ops.identity(inputs)

    result = tf_utils.smart_cond(training, _apply_dropout, _pass_through)

    # EagerTensor objects have no `_uses_learning_phase` attribute, so the
    # tag is only applied outside eager execution.
    if context.executing_eagerly() or not phase_from_backend:
      return result
    result._uses_learning_phase = True  # pylint: disable=protected-access
    return result
예제 #5
0
    def call(self, inputs, training=None):
        """Applies batch normalization to `inputs`.

        When `self.fused` is set, the work is delegated to
        `self._fused_batch_norm`. Otherwise batch moments are computed with
        `nn.moments`, selected against the moving statistics according to
        `training`, optionally corrected by renorm (`self.renorm`), and the
        output is produced by `nn.batch_normalization`. Moving-average update
        ops are registered through `self.add_update` when not executing
        eagerly.

        Args:
            inputs: Input tensor.
            training: Training flag; a Python value or a tensor. When `None`,
                `K.learning_phase()` is used.

        Returns:
            The normalized tensor. If `self.virtual_batch_size` is set, it is
            reshaped back to the original input shape. When not executing
            eagerly and `training` was not supplied, `_uses_learning_phase`
            is set to True on the result.
        """
        # Keep the caller-supplied value so we can later tell whether the
        # learning phase was used as the default.
        original_training_value = training
        if training is None:
            training = K.learning_phase()

        in_eager_mode = context.executing_eagerly()
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # Will cause errors if virtual_batch_size does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            # Reshapes a tensor back to `original_shape`.
            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        # Fused path: everything is handled by `_fused_batch_norm`; return early.
        if self.fused:
            outputs = self._fused_batch_norm(inputs, training=training)
            if self.virtual_batch_size is not None:
                # Currently never reaches here since fused_batch_norm does not support
                # virtual batching
                outputs = undo_virtual_batching(outputs)
            if not context.executing_eagerly(
            ) and original_training_value is None:
                outputs._uses_learning_phase = True  # pylint: disable=protected-access
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value

        # Reshapes `v` to `broadcast_shape` when it is not None, its rank
        # differs from the input rank, and the reduction is not over all but
        # the last axis; otherwise returns `v` unchanged.
        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        # Folds two affine transforms into one: applying (scale, offset) and
        # then (then_scale, then_offset) equals applying the returned pair.
        # A `None` for `then_scale` / `then_offset` skips that component.
        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        # Taken unless `training` is known to be the constant False.
        if training_value is not False:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1
            mean, variance = nn.moments(inputs,
                                        reduction_axes,
                                        keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            # Use the batch moments when training, the moving statistics
            # otherwise.
            mean = tf_utils.smart_cond(training, lambda: mean,
                                       lambda: moving_mean)
            variance = tf_utils.smart_cond(training, lambda: variance,
                                           lambda: moving_variance)

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    mean, variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)
            else:
                new_mean, new_variance = mean, variance

            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(new_mean,
                                                axis=1,
                                                keepdims=True)
                new_variance = math_ops.reduce_mean(new_variance,
                                                    axis=1,
                                                    keepdims=True)

            # Moving-average update for one statistic. In eager mode with a
            # non-trainable layer no update is made and None is returned.
            def _do_update(var, value):
                if in_eager_mode and not self.trainable:
                    return

                return self._assign_moving_average(var, value, self.momentum)

            mean_update = tf_utils.smart_cond(
                training, lambda: _do_update(self.moving_mean, new_mean),
                lambda: self.moving_mean)
            variance_update = tf_utils.smart_cond(
                training,
                lambda: _do_update(self.moving_variance, new_variance),
                lambda: self.moving_variance)
            # Updates are only registered on the layer outside eager execution.
            if not context.executing_eagerly():
                self.add_update(mean_update, inputs=True)
                self.add_update(variance_update, inputs=True)

        else:
            # `training` is the constant False: normalize with the moving
            # statistics; no adjustment, renorm or update ops are created.
            mean, variance = self.moving_mean, self.moving_variance

        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        # EagerTensor objects cannot carry this attribute, hence the eager check.
        if not context.executing_eagerly() and original_training_value is None:
            outputs._uses_learning_phase = True  # pylint: disable=protected-access
        return outputs
예제 #6
0
    def _renorm_correction_and_moments(self, mean, variance, training):
        """Computes the renorm corrections and the moments used for updates.

        Args:
            mean: Mean to correct against the renorm moving mean.
            variance: Variance; converted to a standard deviation with
                `self.epsilon` added before the square root.
            training: Python value or tensor; selects the training behaviour
                through `tf_utils.smart_cond`.

        Returns:
            A tuple `(r, d, new_mean, new_variance)`. `r` and `d` are the
            (optionally clipped) scale and shift corrections, replaced by
            ones / zeros when not training. `new_mean` is the debiased renorm
            mean and `new_variance` satisfies
            `sqrt(new_variance + epsilon) == new_stddev`.
        """
        batch_stddev = math_ops.sqrt(variance + self.epsilon)

        # Averages of mean and standard deviation as if the moving averages
        # had been initialized with the moments of this batch.
        blended_mean = (
            self.renorm_mean + (1. - self.renorm_mean_weight) * mean)
        blended_stddev = (
            self.renorm_stddev +
            (1. - self.renorm_stddev_weight) * batch_stddev)

        # Raw batch renorm corrections.
        raw_r = batch_stddev / blended_stddev
        raw_d = (mean - blended_mean) / blended_stddev

        # Gate the values that feed the variable updates on the corrections,
        # so the corrections are computed from pre-update moving averages.
        with ops.control_dependencies([raw_r, raw_d]):
            gated_mean = array_ops.identity(mean)
            gated_stddev = array_ops.identity(batch_stddev)

        clipping = self.renorm_clipping
        rmin = clipping.get('rmin')
        rmax = clipping.get('rmax')
        dmax = clipping.get('dmax')

        clipped_r = raw_r
        if rmin is not None:
            clipped_r = math_ops.maximum(clipped_r, rmin)
        if rmax is not None:
            clipped_r = math_ops.minimum(clipped_r, rmax)

        clipped_d = raw_d
        if dmax is not None:
            clipped_d = math_ops.minimum(
                math_ops.maximum(clipped_d, -dmax), dmax)

        # Outside training the corrections are neutral: r = 1 and d = 0.
        r = tf_utils.smart_cond(training, lambda: clipped_r,
                                lambda: array_ops.ones_like(clipped_r))
        d = tf_utils.smart_cond(training, lambda: clipped_d,
                                lambda: array_ops.zeros_like(clipped_d))

        def _update_renorm_variable(var, weight, value):
            """Advances a moving average and weight; yields the unbiased value."""
            value = array_ops.identity(value)

            def _apply_moving_average():
                """Updates `var` and `weight`; returns updated var / weight."""
                # The variables are updated without zero debiasing. Debiasing
                # happens by dividing the moving average by the weight: after
                # one update the average is (1-decay) * value and the weight
                # is 1-decay, so the ratio gives back `value`.
                # The control dependency on `value` (itself gated on r and d
                # above) keeps the weight update from running before r and d
                # have been computed.
                with ops.control_dependencies([value]):
                    unit_weight = array_ops.constant(1., dtype=weight.dtype)
                decay = self.renorm_momentum
                averaged_var = self._assign_moving_average(var, value, decay)
                averaged_weight = self._assign_moving_average(
                    weight, unit_weight, decay)
                # TODO(yuefengz): fetching the updated values here prevents
                # the updates to var and weight from being batched together.
                # Consider computing new values and delaying the updates.
                return averaged_var / averaged_weight

            def _read_current():
                return array_ops.identity(var)

            return tf_utils.smart_cond(training, _apply_moving_average,
                                       _read_current)

        # TODO(yuefengz): colocate the operations
        new_mean = _update_renorm_variable(
            self.renorm_mean, self.renorm_mean_weight, gated_mean)
        new_stddev = _update_renorm_variable(
            self.renorm_stddev, self.renorm_stddev_weight, gated_stddev)
        # Chosen so that sqrt(moving_variance + epsilon) equals new_stddev.
        new_variance = math_ops.square(new_stddev) - self.epsilon

        return (r, d, new_mean, new_variance)