Example #1
    def _fused_batch_norm(self, inputs, training):
        """Returns the output of fused batch norm."""
        # TODO(reedwm): Add support for fp16 inputs.
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const

        def _fused_batch_norm_training():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       epsilon=self.epsilon,
                                       data_format=self._data_format)

        def _fused_batch_norm_inference():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       mean=self.moving_mean,
                                       variance=self.moving_variance,
                                       epsilon=self.epsilon,
                                       is_training=False,
                                       data_format=self._data_format)

        output, mean, variance = utils.smart_cond(training,
                                                  _fused_batch_norm_training,
                                                  _fused_batch_norm_inference)
        mean = array_ops.reshape(mean, shape=self.moving_mean.get_shape())
        variance = array_ops.reshape(variance,
                                     shape=self.moving_variance.get_shape())
        if not self._bessels_correction_test_only:
            # Remove Bessel's correction to be consistent with non-fused batch norm.
            # Note that the variance computed by fused batch norm is
            # with Bessel's correction.
            sample_size = math_ops.cast(
                array_ops.size(inputs) / array_ops.size(variance),
                variance.dtype)
            factor = (sample_size -
                      math_ops.cast(1.0, variance.dtype)) / sample_size
            variance *= factor

        training_value = utils.constant_value(training)
        if training_value is None:
            one_minus_decay = utils.smart_cond(training,
                                               lambda: self._one_minus_decay,
                                               lambda: 0.)
        else:
            one_minus_decay = ops.convert_to_tensor(self._one_minus_decay)
        if training_value or training_value is None:
            mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                      one_minus_decay)
            variance_update = self._assign_moving_average(
                self.moving_variance, variance, one_minus_decay)
            if context.in_graph_mode():
                # Note that in Eager mode, the updates are already executed when running
                # assign_moving_averages. So we do not need to put them into
                # collections.
                self.add_update(mean_update, inputs=inputs)
                self.add_update(variance_update, inputs=inputs)

        return output
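The Bessel's-correction block above rescales the fused variance by (n - 1) / n so it matches the population variance used by the non-fused implementation. A minimal numpy sketch of that identity (the shapes, names, and tolerance here are illustrative assumptions, not part of the original code):

# Fused batch norm reports the sample variance (divides by n - 1); multiplying
# by (n - 1) / n recovers the population variance (divides by n).
import numpy as np

x = np.random.randn(4, 3, 3, 2).astype(np.float32)        # NHWC batch
sample_size = x.size / x.shape[-1]                         # elements per channel
per_channel = x.reshape(-1, x.shape[-1])
var_bessel = per_channel.var(axis=0, ddof=1)               # with Bessel's correction
var_population = per_channel.var(axis=0, ddof=0)           # without
factor = (sample_size - 1.0) / sample_size
np.testing.assert_allclose(var_bessel * factor, var_population, rtol=1e-5)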
Example #2
def _smart_select(pred, fn_then, fn_else):
  """Selects fn_then() or fn_else() based on the value of pred.

  The purpose of this function is the same as `utils.smart_cond`. However, at
  the moment there is a bug (b/36297356) that seems to kick in only when
  `smart_cond` delegates to `tf.cond`, which sometimes results in the training
  hanging when using parameter servers. This function will output the result
  of `fn_then` or `fn_else` if `pred` is known at graph construction time.
  Otherwise, it will use `tf.where` which will result in some redundant work
  (both branches will be computed but only one selected). However, the tensors
  involved will usually be small (means and variances in batchnorm), so the
  cost will be small and will not be incurred at all if `pred` is a constant.

  Args:
    pred: A boolean scalar `Tensor`.
    fn_then: A callable to use when pred==True.
    fn_else: A callable to use when pred==False.

  Returns:
    A `Tensor` whose value is fn_then() or fn_else() based on the value of pred.
  """
  pred_value = utils.constant_value(pred)
  if pred_value:
    return fn_then()
  elif pred_value is False:
    return fn_else()
  t_then = array_ops.expand_dims(fn_then(), 0)
  t_else = array_ops.expand_dims(fn_else(), 0)
  pred = array_ops.reshape(pred, [1])
  result = array_ops.where(pred, t_then, t_else)
  return array_ops.squeeze(result, [0])
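A hedged usage sketch of `_smart_select` (the momentum value and the placeholder predicate are chosen here for illustration; a TF 1.x environment with this file's imports is assumed):

import tensorflow as tf  # TF 1.x assumed

momentum = 0.99
# Constant predicates short-circuit to the chosen branch at graph construction.
decay_train = _smart_select(True, lambda: momentum, lambda: 1.)    # -> 0.99
decay_infer = _smart_select(False, lambda: momentum, lambda: 1.)   # -> 1.

# An unknown predicate falls back to the tf.where path: both branches are
# built, but the correct one is selected at run time.
training = tf.placeholder_with_default(True, shape=())
decay_dynamic = _smart_select(training,
                              lambda: tf.constant(momentum),
                              lambda: tf.constant(1.))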
Example #4
  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    # TODO(reedwm): Add support for fp16 inputs.
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          epsilon=self.epsilon,
          data_format=self._data_format)

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    output, mean, variance = utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)
    mean = array_ops.reshape(mean, shape=self.moving_mean.get_shape())
    variance = array_ops.reshape(variance,
                                 shape=self.moving_variance.get_shape())
    if not self._bessels_correction_test_only:
      # Remove Bessel's correction to be consistent with non-fused batch norm.
      # Note that the variance computed by fused batch norm is
      # with Bessel's correction.
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
      variance *= factor

    training_value = utils.constant_value(training)
    if training_value is None:
      one_minus_decay = utils.smart_cond(training,
                                         lambda: self._one_minus_decay,
                                         lambda: 0.)
    else:
      one_minus_decay = ops.convert_to_tensor(self._one_minus_decay)
    if training_value or training_value is None:
      mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                one_minus_decay)
      variance_update = self._assign_moving_average(self.moving_variance,
                                                    variance, one_minus_decay)
      if context.in_graph_mode():
        # Note that in Eager mode, the updates are already executed when running
        # assign_moving_averages. So we do not need to put them into
        # collections.
        self.add_update(mean_update, inputs=inputs)
        self.add_update(variance_update, inputs=inputs)

    return output
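For reference while reading the update logic above: `self._assign_moving_average` is the layer's own helper and is not shown in this snippet. Assuming it applies the usual exponential-moving-average step, the `lambda: 0.` branch makes the update a no-op on the inference path. A tiny numpy sketch of one step with illustrative numbers:

import numpy as np

one_minus_decay = 0.01                      # e.g. a decay/momentum of 0.99
moving_mean = np.array([0.0, 0.0])
batch_mean = np.array([1.0, 2.0])

# Assumed EMA step: moving <- moving - (moving - value) * one_minus_decay.
moving_mean -= (moving_mean - batch_mean) * one_minus_decay
print(moving_mean)                          # [0.01 0.02]
# With one_minus_decay = 0. (inference) the correction term vanishes and the
# moving statistics are left untouched.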
Example #5
def test_weighted_add():
    a1 = Input(shape=(3, 3, 2))
    a2 = Input(shape=(3, 3, 2))
    layer = WeightedAdd()
    b = layer([a1, a2])
    model = Model(inputs=[a1, a2], outputs=b)
    data = np.ones((1, 3, 3, 2))
    model.compile(optimizer='Adam', loss=mean_squared_error)
    model.fit([data, data * 2], data * 2, epochs=10, verbose=False)
    results = model.predict_on_batch([data, data * 2])
    assert not np.array_equal(results, data)
    assert constant_value(layer.one) == 1
Example #6
    def _fused_batch_norm(self, inputs, training):
        """Returns the output of fused batch norm."""
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const

        def _fused_batch_norm_training():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       epsilon=self.epsilon,
                                       data_format=self._data_format)

        def _fused_batch_norm_inference():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       mean=self.moving_mean,
                                       variance=self.moving_variance,
                                       epsilon=self.epsilon,
                                       is_training=False,
                                       data_format=self._data_format)

        output, mean, variance = utils.smart_cond(training,
                                                  _fused_batch_norm_training,
                                                  _fused_batch_norm_inference)
        if not self._bessels_correction_test_only:
            # Remove Bessel's correction to be consistent with non-fused batch norm.
            # Note that the variance computed by fused batch norm is
            # with Bessel's correction.
            sample_size = math_ops.cast(
                array_ops.size(inputs) / array_ops.size(variance),
                variance.dtype)
            factor = (sample_size -
                      math_ops.cast(1.0, variance.dtype)) / sample_size
            variance *= factor

        training_value = utils.constant_value(training)
        if training_value is not False:
            decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
            mean_update = moving_averages.assign_moving_average(
                self.moving_mean, mean, decay, zero_debias=False)
            variance_update = moving_averages.assign_moving_average(
                self.moving_variance, variance, decay, zero_debias=False)
            if context.in_graph_mode():
                # Note that in Eager mode, the updates are already executed when running
                # assign_moving_averages. So we do not need to put them into
                # collections.
                self.add_update(mean_update, inputs=inputs)
                self.add_update(variance_update, inputs=inputs)

        return output
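Here the decay passed to `moving_averages.assign_moving_average` is chosen with `_smart_select`: the documented update is variable <- variable * decay + value * (1 - decay), so the `lambda: 1.` branch freezes the moving statistics when not training. A small numpy sketch of both branches (numbers are illustrative):

import numpy as np

def ema_step(var, value, decay):
    # Same update rule as moving_averages.assign_moving_average (no zero-debias).
    return var * decay + value * (1.0 - decay)

moving_var = np.array([1.0, 1.0])
batch_var = np.array([4.0, 9.0])
print(ema_step(moving_var, batch_var, decay=0.99))  # training:  [1.03 1.08]
print(ema_step(moving_var, batch_var, decay=1.0))   # inference: [1. 1.] (unchanged)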
Example #7
    def _fused_batch_norm(self, inputs, training):
        """Returns the output of fused batch norm."""
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const

        def _fused_batch_norm_training():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       epsilon=self.epsilon,
                                       data_format=self._data_format)

        def _fused_batch_norm_inference():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       mean=self.moving_mean,
                                       variance=self.moving_variance,
                                       epsilon=self.epsilon,
                                       is_training=False,
                                       data_format=self._data_format)

        output, mean, variance = utils.smart_cond(training,
                                                  _fused_batch_norm_training,
                                                  _fused_batch_norm_inference)
        if not self._bessels_correction_test_only:
            # Remove Bessel's correction to be consistent with non-fused batch norm.
            # Note that the variance computed by fused batch norm is
            # with Bessel's correction.
            sample_size = math_ops.cast(
                array_ops.size(inputs) / array_ops.size(variance),
                variance.dtype)
            factor = (sample_size -
                      math_ops.cast(1.0, variance.dtype)) / sample_size
            variance *= factor

        training_value = utils.constant_value(training)
        if training_value is None:
            momentum = utils.smart_cond(training, lambda: self.momentum,
                                        lambda: 1.0)
        else:
            momentum = ops.convert_to_tensor(self.momentum)
        if training_value or training_value is None:
            mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                      momentum)
            variance_update = self._assign_moving_average(
                self.moving_variance, variance, momentum)
            self.add_update(mean_update, inputs=inputs)
            self.add_update(variance_update, inputs=inputs)

        return output
Example #8
  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          epsilon=self.epsilon,
          data_format=self._data_format)

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    output, mean, variance = utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)
    if not self._bessels_correction_test_only:
      # Remove Bessel's correction to be consistent with non-fused batch norm.
      # Note that the variance computed by fused batch norm is
      # with Bessel's correction.
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
      variance *= factor

    training_value = utils.constant_value(training)
    if training_value is None:
      momentum = utils.smart_cond(training, lambda: self.momentum, lambda: 1.0)
    else:
      momentum = ops.convert_to_tensor(self.momentum)
    if training_value or training_value is None:
      mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                momentum)
      variance_update = self._assign_moving_average(self.moving_variance,
                                                    variance, momentum)
      self.add_update(mean_update, inputs=inputs)
      self.add_update(variance_update, inputs=inputs)

    return output
Example #9
    def _build_update_ops(self, mean, variance, is_training):
        """Builds the moving average update ops when using moving variance.

        Args:
          mean: The mean value to update with.
          variance: The variance value to update with.
          is_training: Boolean Tensor to indicate if we're currently in
            training mode.

        Returns:
          Tuple of `(update_mean_op, update_variance_op)` when `is_training` is or
          could be `True`. Returns `None` when `is_training=False`.
        """
        def build_update_ops():
            """Builds the exponential moving average update ops."""

            update_mean_op = moving_averages.assign_moving_average(
                variable=self._moving_mean,
                value=mean,
                decay=self._decay_rate,
                zero_debias=False,
                name="update_moving_mean",
            ).op

            update_variance_op = moving_averages.assign_moving_average(
                variable=self._moving_variance,
                value=variance,
                decay=self._decay_rate,
                zero_debias=False,
                name="update_moving_variance",
            ).op

            return update_mean_op, update_variance_op

        def build_no_ops():
            return (tf.no_op(), tf.no_op())

        # Only make the ops if we know that `is_training=True`, or the value of
        # `is_training` is unknown.
        is_training_const = utils.constant_value(is_training)
        if is_training_const is None or is_training_const:
            update_mean_op, update_variance_op = utils.smart_cond(
                is_training,
                build_update_ops,
                build_no_ops,
            )
            return (update_mean_op, update_variance_op)
        else:
            return None
Example #10
  def _build_update_ops(self, mean, variance, is_training):
    """Builds the moving average update ops when using moving variance.

    Args:
      mean: The mean value to update with.
      variance: The variance value to update with.
      is_training: Boolean Tensor to indicate if we're currently in
        training mode.

    Returns:
      Tuple of `(update_mean_op, update_variance_op)` when `is_training` is or
      could be `True`. Returns `None` when `is_training=False`.
    """

    def build_update_ops():
      """Builds the exponential moving average update ops."""

      update_mean_op = moving_averages.assign_moving_average(
          variable=self._moving_mean,
          value=mean,
          decay=self._decay_rate,
          zero_debias=False,
          name="update_moving_mean").op

      update_variance_op = moving_averages.assign_moving_average(
          variable=self._moving_variance,
          value=variance,
          decay=self._decay_rate,
          zero_debias=False,
          name="update_moving_variance").op

      return update_mean_op, update_variance_op

    def build_no_ops():
      return (tf.no_op(), tf.no_op())

    # Only make the ops if we know that `is_training=True`, or the value of
    # `is_training` is unknown.
    is_training_const = utils.constant_value(is_training)
    if is_training_const is None or is_training_const:
      update_mean_op, update_variance_op = utils.smart_cond(
          is_training,
          build_update_ops,
          build_no_ops,
      )
      return (update_mean_op, update_variance_op)
    else:
      return None
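The ops returned by `_build_update_ops` still have to be run alongside the training step. A hedged, self-contained TF 1.x sketch of that wiring; `moving_mean`, `batch_mean`, and the hand-written `assign_sub` update are stand-ins for the real ops, not part of the original library:

import tensorflow as tf  # TF 1.x graph mode assumed

moving_mean = tf.Variable([0.0, 0.0], trainable=False)
batch_mean = tf.constant([1.0, 2.0])
# Stand-in for the update_mean_op returned by _build_update_ops.
update_mean_op = tf.assign_sub(moving_mean, (moving_mean - batch_mean) * 0.01).op

# Run the update together with (a stand-in for) the training step.
with tf.control_dependencies([update_mean_op]):
    train_step = tf.no_op(name="train_step")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_step)
    print(sess.run(moving_mean))  # ~[0.01 0.02]: nudged toward the batch mean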
Example #11
  def testConstantValue(self):
    f1 = lambda: constant_op.constant(5)
    f2 = lambda: constant_op.constant(32)

    # Boolean pred
    self.assertEqual(5, utils.constant_value(utils.smart_cond(True, f1, f2)))
    self.assertEqual(32, utils.constant_value(utils.smart_cond(False, f1, f2)))

    # Integer pred
    self.assertEqual(5, utils.constant_value(utils.smart_cond(1, f1, f2)))
    self.assertEqual(32, utils.constant_value(utils.smart_cond(0, f1, f2)))

    # Unknown pred
    pred = array_ops.placeholder_with_default(True, shape=())
    self.assertIsNone(utils.constant_value(utils.smart_cond(pred, f1, f2)))

    # Error case
    with self.assertRaises(TypeError):
      utils.constant_value(5)
Example #12
  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          epsilon=self.epsilon,
          data_format=self._data_format)

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    output, mean, variance = utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)

    training_value = utils.constant_value(training)
    if training_value is not False:
      decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, mean, decay, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, variance, decay, zero_debias=False)
      self.add_update(mean_update, inputs=inputs)
      self.add_update(variance_update, inputs=inputs)

    return output
Example #14
    def testConstantValue(self):
        f1 = lambda: constant_op.constant(5)
        f2 = lambda: constant_op.constant(32)

        # Boolean pred
        self.assertEqual(5,
                         utils.constant_value(utils.smart_cond(True, f1, f2)))
        self.assertEqual(32,
                         utils.constant_value(utils.smart_cond(False, f1, f2)))

        # Integer pred
        self.assertEqual(5, utils.constant_value(utils.smart_cond(1, f1, f2)))
        self.assertEqual(32, utils.constant_value(utils.smart_cond(0, f1, f2)))

        # Unknown pred
        pred = array_ops.placeholder_with_default(True, shape=())
        self.assertIsNone(utils.constant_value(utils.smart_cond(pred, f1, f2)))

        # Error case
        with self.assertRaises(TypeError):
            utils.constant_value(5)
Example #15
  def call(self, inputs, training=False):
    if self.virtual_batch_size is not None:
      # Virtual batches (aka ghost batches) can be simulated by reshaping the
      # Tensor and reusing the existing batch norm implementation
      original_shape = [-1] + inputs.shape.as_list()[1:]
      expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

      # Will cause errors if virtual_batch_size does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.virtual_batch_size is not None:
        # Currently never reaches here since fused_batch_norm does not support
        # virtual batching
        return undo_virtual_batching(outputs)
      return outputs

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.get_shape()
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
      del reduction_axes[1]     # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value
    def _broadcast(v):
      if (v is not None and
          len(v.get_shape()) != ndims and
          reduction_axes != list(range(ndims - 1))):
        return array_ops.reshape(v, broadcast_shape)
      return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
      if then_scale is not None:
        scale *= then_scale
        offset *= then_scale
      if then_offset is not None:
        offset += then_offset
      return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
      if self.adjustment:
        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
        # Adjust only during training.
        adj_scale = utils.smart_cond(training,
                                     lambda: adj_scale,
                                     lambda: array_ops.ones_like(adj_scale))
        adj_bias = utils.smart_cond(training,
                                    lambda: adj_bias,
                                    lambda: array_ops.zeros_like(adj_bias))
        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
      mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = utils.smart_cond(training,
                              lambda: mean,
                              lambda: moving_mean)
      variance = utils.smart_cond(training,
                                  lambda: variance,
                                  lambda: moving_variance)

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            mean, variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)
      else:
        new_mean, new_variance = mean, variance

      if self.virtual_batch_size is not None:
        # This isn't strictly correct since in ghost batch norm, you are
        # supposed to sequentially update the moving_mean and moving_variance
        # with each sub-batch. However, since the moving statistics are only
        # used during evaluation, it is more efficient to just update in one
        # step and should not make a significant difference in the result.
        new_mean = math_ops.reduce_mean(new_mean,
                                        axis=1, keep_dims=True)
        new_variance = math_ops.reduce_mean(new_variance,
                                            axis=1, keep_dims=True)

      def _do_update(var, value):
        return moving_averages.assign_moving_average(
            var, value, self.momentum, zero_debias=False)

      mean_update = utils.smart_cond(
          training,
          lambda: _do_update(self.moving_mean, new_mean),
          lambda: self.moving_mean)
      variance_update = utils.smart_cond(
          training,
          lambda: _do_update(self.moving_variance, new_variance),
          lambda: self.moving_variance)
      if context.in_graph_mode():
        self.add_update(mean_update, inputs=inputs)
        self.add_update(variance_update, inputs=inputs)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     offset,
                                     scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
      return undo_virtual_batching(outputs)

    return outputs
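A numpy sketch of the virtual ("ghost") batch reshape performed at the top of `call`; the batch size of 8, the NHWC shape, and `virtual_batch_size = 2` are assumptions chosen for illustration:

import numpy as np

virtual_batch_size = 2
inputs = np.zeros((8, 3, 3, 2), dtype=np.float32)                # NHWC batch of 8
original_shape = [-1] + list(inputs.shape[1:])                   # [-1, 3, 3, 2]
expanded_shape = [virtual_batch_size, -1] + original_shape[1:]   # [2, -1, 3, 3, 2]

expanded = inputs.reshape(expanded_shape)    # shape (2, 4, 3, 3, 2)
# Moments are later computed without reducing over axis 1 (the virtual batch dim).
restored = expanded.reshape(original_shape)  # undo_virtual_batching: back to (8, 3, 3, 2)
assert restored.shape == inputs.shape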
Example #16
  def call(self, inputs, training=False):
    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    scale, offset = self.gamma, self.beta

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      mean, variance = nn.moments(inputs, reduction_axes)
      mean = _smart_select(training,
                           lambda: mean,
                           lambda: self.moving_mean)
      variance = _smart_select(training,
                               lambda: variance,
                               lambda: self.moving_variance)

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            mean, variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        scale = array_ops.stop_gradient(r, name='renorm_r')
        offset = array_ops.stop_gradient(d, name='renorm_d')
        if self.gamma is not None:
          scale *= self.gamma
          offset *= self.gamma
        if self.beta is not None:
          offset += self.beta
      else:
        new_mean, new_variance = mean, variance

      # Update moving averages when training, and prevent updates otherwise.
      decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, new_mean, decay, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, new_variance, decay, zero_debias=False)

      if not self.updates:
        self.add_update(mean_update)
        self.add_update(variance_update)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    def _broadcast(v):
      if needs_broadcasting and v is not None:
        # In this case we must explicitly broadcast all parameters.
        return array_ops.reshape(v, broadcast_shape)
      return v

    return nn.batch_normalization(inputs,
                                  _broadcast(mean),
                                  _broadcast(variance),
                                  _broadcast(offset),
                                  _broadcast(scale),
                                  self.epsilon)
Example #17
  def call(self, inputs, training=False):
    if self.fused:
      return self._fused_batch_norm(inputs, training=training)

    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    scale, offset = self.gamma, self.beta

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      mean, variance = nn.moments(inputs, reduction_axes)
      mean = _smart_select(training,
                           lambda: mean,
                           lambda: self.moving_mean)
      variance = _smart_select(training,
                               lambda: variance,
                               lambda: self.moving_variance)

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            mean, variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        scale = array_ops.stop_gradient(r, name='renorm_r')
        offset = array_ops.stop_gradient(d, name='renorm_d')
        if self.gamma is not None:
          scale *= self.gamma
          offset *= self.gamma
        if self.beta is not None:
          offset += self.beta
      else:
        new_mean, new_variance = mean, variance

      # Update moving averages when training, and prevent updates otherwise.
      decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, new_mean, decay, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, new_variance, decay, zero_debias=False)

      self.add_update(mean_update, inputs=inputs)
      self.add_update(variance_update, inputs=inputs)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    def _broadcast(v):
      if needs_broadcasting and v is not None:
        # In this case we must explicitly broadcast all parameters.
        return array_ops.reshape(v, broadcast_shape)
      return v

    return nn.batch_normalization(inputs,
                                  _broadcast(mean),
                                  _broadcast(variance),
                                  _broadcast(offset),
                                  _broadcast(scale),
                                  self.epsilon)
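A numpy sketch of the broadcast handled by `_broadcast` above, assuming channels-first (NCHW) inputs with `axis = 1`; the shapes and names are illustrative:

import numpy as np

axis = 1
inputs = np.random.randn(4, 8, 5, 5).astype(np.float32)   # NCHW
gamma = np.ones((8,), dtype=np.float32)                   # per-channel scale

broadcast_shape = [1] * inputs.ndim
broadcast_shape[axis] = inputs.shape[axis]                 # [1, 8, 1, 1]
scaled = inputs * gamma.reshape(broadcast_shape)           # broadcasts over N, H, W
assert scaled.shape == inputs.shape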
Example #18
    def call(self, inputs, training=False):
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # Will cause errors if virtual_batch_size does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(inputs, training=training)
            if self.virtual_batch_size is not None:
                # Currently never reaches here since fused_batch_norm does not support
                # virtual batching
                return undo_virtual_batching(outputs)
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        scale, offset = self.gamma, self.beta

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = utils.constant_value(training)
        if training_value is not False:
            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1
            mean, variance = nn.moments(inputs,
                                        reduction_axes,
                                        keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = utils.smart_cond(training, lambda: mean,
                                    lambda: moving_mean)
            variance = utils.smart_cond(training, lambda: variance,
                                        lambda: moving_variance)

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    mean, variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                scale = array_ops.stop_gradient(r, name='renorm_r')
                offset = array_ops.stop_gradient(d, name='renorm_d')
                if self.gamma is not None:
                    scale *= self.gamma
                    offset *= self.gamma
                if self.beta is not None:
                    offset += self.beta
            else:
                new_mean, new_variance = mean, variance

            # Update moving averages when training, and prevent updates otherwise.
            decay = utils.smart_cond(training, lambda: self.momentum,
                                     lambda: 1.)
            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(new_mean,
                                                axis=1,
                                                keep_dims=True)
                new_variance = math_ops.reduce_mean(new_variance,
                                                    axis=1,
                                                    keep_dims=True)

            mean_update = moving_averages.assign_moving_average(
                self.moving_mean, new_mean, decay, zero_debias=False)
            variance_update = moving_averages.assign_moving_average(
                self.moving_variance, new_variance, decay, zero_debias=False)
            if context.in_graph_mode():
                self.add_update(mean_update, inputs=inputs)
                self.add_update(variance_update, inputs=inputs)

        else:
            mean, variance = self.moving_mean, self.moving_variance

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value
        rank = len(inputs.get_shape())

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != rank
                    and reduction_axes != list(range(ndims))[:-1]):
                return array_ops.reshape(v, broadcast_shape)
            return v

        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance),
                                         _broadcast(offset), _broadcast(scale),
                                         self.epsilon)

        if self.virtual_batch_size is not None:
            return undo_virtual_batching(outputs)

        return outputs
Example #19
    def _fused_batch_norm(self, inputs, training, use_moving_statistics):
        """Returns the output of fused batch norm."""
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const

        def _fused_batch_norm_training():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       epsilon=self.epsilon,
                                       data_format=self._data_format)

        # If use_moving_statistics is True, use moving_mean and moving_variance;
        # otherwise use self.mean and self.variance.
        mean = tf_utils.smart_cond(use_moving_statistics,
                                   lambda: self.moving_mean, lambda: self.mean)
        variance = tf_utils.smart_cond(use_moving_statistics,
                                       lambda: self.moving_variance,
                                       lambda: self.variance)

        # These variables are captured by _fused_batch_norm_inference() through a Python closure.

        def _fused_batch_norm_inference():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       mean=mean,
                                       variance=variance,
                                       epsilon=self.epsilon,
                                       is_training=False,
                                       data_format=self._data_format)

        output, mean, variance = tf_utils.smart_cond(
            training, _fused_batch_norm_training, _fused_batch_norm_inference)
        # If training == True, the mean and variance returned are those of the
        # current batch. If training == False, they are (self.mean, self.variance)
        # or (self.moving_mean, self.moving_variance), depending on the value of
        # use_moving_statistics.

        if not self._bessels_correction_test_only:
            # Remove Bessel's correction to be consistent with non-fused batch norm.
            # Note that the variance computed by fused batch norm is
            # with Bessel's correction.
            sample_size = math_ops.cast(
                array_ops.size(inputs) / array_ops.size(variance),
                variance.dtype)
            factor = (sample_size -
                      math_ops.cast(1.0, variance.dtype)) / sample_size
            variance *= factor

        training_value = tf_utils.constant_value(training)

        if training_value is None:
            momentum = tf_utils.smart_cond(training, lambda: self.momentum,
                                           lambda: 1.0)
        else:
            momentum = ops.convert_to_tensor(self.momentum)

        if training_value or training_value is None:
            # if training, first create operations which update self.mean and self.variance
            mean_update = self._update_statistics(self.mean, mean,
                                                  self.n_updates)
            variance_update = self._update_statistics(self.variance, variance,
                                                      self.n_updates)

            with ops.control_dependencies([mean_update, variance_update]):
                update_n_updates = state_ops.assign_add(self.n_updates, 1.)

            # add this combination of operations to a specific collection 'UPDATE_BN_OPS'
            ops.add_to_collection('UPDATE_BN_OPS', update_n_updates)

            # operations to reset bn statistics
            reset_mean = state_ops.assign(self.mean,
                                          array_ops.zeros_like(self.mean))
            reset_variance = state_ops.assign(
                self.variance, array_ops.zeros_like(self.variance))
            reset_n_updates = state_ops.assign(self.n_updates, 0.)
            with ops.control_dependencies(
                [reset_mean, reset_variance, reset_n_updates]):
                reset_bn = gen_control_flow_ops.no_op("ResetBatchNormStats")
            ops.add_to_collection('RESET_BN_OPS', reset_bn)

            # To keep the classical batch norm behavior, update the moving
            # averages and add the operations to tf.GraphKeys.UPDATE_OPS; these
            # operations must be run when optimizing the network.
            moving_mean_update = self._assign_moving_average(
                self.moving_mean, mean, momentum)
            moving_variance_update = self._assign_moving_average(
                self.moving_variance, variance, momentum)
            self.add_update(moving_mean_update, inputs=True)
            self.add_update(moving_variance_update, inputs=True)

        return output
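This variant stores its extra update/reset ops in the custom collections 'UPDATE_BN_OPS' and 'RESET_BN_OPS', in addition to the usual moving-average updates. A hedged TF 1.x sketch of how a training script might fetch and run them; the grouping and the commented session calls are illustrative, not part of the original code:

import tensorflow as tf  # TF 1.x assumed

update_bn_ops = tf.group(*tf.get_collection('UPDATE_BN_OPS'))
reset_bn_ops = tf.group(*tf.get_collection('RESET_BN_OPS'))
classic_bn_updates = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))

# with tf.Session() as sess:
#     sess.run(update_bn_ops)   # accumulate self.mean / self.variance / self.n_updates
#     sess.run(reset_bn_ops)    # reset the accumulated statistics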
Example #20
  def call(self, inputs, training=False):
    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)

    if needs_broadcasting:
      # In this case we must explicitly broadcast all parameters.
      if self.center:
        broadcast_beta = array_ops.reshape(self.beta, broadcast_shape)
      else:
        broadcast_beta = None
      if self.scale:
        broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape)
      else:
        broadcast_gamma = None

    if training_value is not False:
      if needs_broadcasting:
        broadcast_mean, broadcast_variance = nn.moments(
            inputs, reduction_axes, keep_dims=True)
        mean = array_ops.reshape(broadcast_mean, [-1])
        variance = array_ops.reshape(broadcast_variance, [-1])
      else:
        mean, variance = nn.moments(inputs, reduction_axes)

      # Prepare updates if necessary.
      if not self.updates:
        mean_update = moving_averages.assign_moving_average(
            self.moving_mean, mean, self.momentum, zero_debias=False)
        variance_update = moving_averages.assign_moving_average(
            self.moving_variance, variance, self.momentum, zero_debias=False)
        # In the future this should be refactored into a self.add_update
        # method in order to allow for instance-based BN layer sharing
        # across unrelated input streams (e.g. like in Keras).
        self.updates.append(mean_update)
        self.updates.append(variance_update)

    # Normalize batch. We do this inside separate functions for training
    # and inference so as to avoid evaluating both branches.
    def normalize_in_test():
      if needs_broadcasting:
        broadcast_moving_mean = array_ops.reshape(self.moving_mean,
                                                  broadcast_shape)
        broadcast_moving_variance = array_ops.reshape(self.moving_variance,
                                                      broadcast_shape)
        return nn.batch_normalization(inputs,
                                      broadcast_moving_mean,
                                      broadcast_moving_variance,
                                      broadcast_beta,
                                      broadcast_gamma,
                                      self.epsilon)
      else:
        return nn.batch_normalization(inputs,
                                      self.moving_mean,
                                      self.moving_variance,
                                      self.beta if self.center else None,
                                      self.gamma if self.scale else None,
                                      self.epsilon)

    def normalize_in_training():
      if needs_broadcasting:
        return nn.batch_normalization(inputs,
                                      broadcast_mean,
                                      broadcast_variance,
                                      broadcast_beta,
                                      broadcast_gamma,
                                      self.epsilon)
      else:
        return nn.batch_normalization(inputs,
                                      mean,
                                      variance,
                                      self.beta if self.center else None,
                                      self.gamma if self.scale else None,
                                      self.epsilon)

    return utils.smart_cond(training,
                            normalize_in_training,
                            normalize_in_test)
Example #21
  def call(self, inputs, training=False):
    if self.num_virtual_batches > 1:
      # Virtual batches (aka ghost batches) can be simulated by using some
      # reshape/transpose tricks on top of base batch normalization.
      original_shape = [-1] + inputs.shape.as_list()[1:]
      expanded_shape = [-1, self.num_virtual_batches] + original_shape[1:]

      # Will cause errors if num_virtual_batches does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      ndims = len(expanded_shape)
      if self.axis < 0:
        axis = ndims + self.axis
      else:
        axis = self.axis + 1      # Account for the added dimension

      # Permute the num_virtual_batch dimension (dim 1) to be adjacent to axis
      # TODO(b/66257056): when multi-axis batch normalization is implemented,
      # this permutation trick and the combined_dim reshape are no longer
      # necessary and can be reworked to simply use broadcasting.
      permutation = ([0] + list(range(2, axis)) + [1, axis] +
                     list(range(axis + 1, ndims)))
      inverse_permutation = [x[1] for x in
                             sorted(zip(permutation, range(ndims)))]
      inputs = array_ops.transpose(inputs, perm=permutation)

      # Combine the axis and num_virtual_batch dimension in order to take
      # advantage of fused batch normalization
      combined_dim = expanded_shape[1] * expanded_shape[axis]
      perm_shape = [-1] + inputs.shape.as_list()[1:]
      combined_shape = (perm_shape[:axis - 1] +
                        [combined_dim] +
                        perm_shape[axis + 1:])
      inputs = array_ops.reshape(inputs, combined_shape)
      # After the above reshape, the batch norm axis is the original self.axis

      # Undoes the reshaping and transposing tricks done above
      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, perm_shape)
        outputs = array_ops.transpose(outputs, perm=inverse_permutation)
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.num_virtual_batches > 1:
        return undo_virtual_batching(outputs)
      return outputs

    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    scale, offset = self.gamma, self.beta

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      mean, variance = nn.moments(inputs, reduction_axes)
      mean = _smart_select(training,
                           lambda: mean,
                           lambda: self.moving_mean)
      variance = _smart_select(training,
                               lambda: variance,
                               lambda: self.moving_variance)

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            mean, variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        scale = array_ops.stop_gradient(r, name='renorm_r')
        offset = array_ops.stop_gradient(d, name='renorm_d')
        if self.gamma is not None:
          scale *= self.gamma
          offset *= self.gamma
        if self.beta is not None:
          offset += self.beta
      else:
        new_mean, new_variance = mean, variance

      # Update moving averages when training, and prevent updates otherwise.
      decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, new_mean, decay, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, new_variance, decay, zero_debias=False)
      if context.in_graph_mode():
        self.add_update(mean_update, inputs=inputs)
        self.add_update(variance_update, inputs=inputs)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    def _broadcast(v):
      if needs_broadcasting and v is not None:
        # In this case we must explicitly broadcast all parameters.
        return array_ops.reshape(v, broadcast_shape)
      return v

    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     _broadcast(offset),
                                     _broadcast(scale),
                                     self.epsilon)

    if self.num_virtual_batches > 1:
      return undo_virtual_batching(outputs)

    return outputs
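A worked numpy example of the permutation trick above, assuming channels-last inputs (`self.axis = -1`), so `ndims = 5` and `axis = 4` once the virtual-batch dimension has been inserted; the tensor shape is illustrative:

import numpy as np

ndims, axis = 5, 4                                     # [N', V, H, W, C] layout
permutation = ([0] + list(range(2, axis)) + [1, axis] +
               list(range(axis + 1, ndims)))           # [0, 2, 3, 1, 4]
inverse_permutation = [x[1] for x in
                       sorted(zip(permutation, range(ndims)))]  # [0, 3, 1, 2, 4]

x = np.random.randn(3, 2, 4, 4, 8)                     # virtual-batch dim at axis 1
roundtrip = np.transpose(np.transpose(x, permutation), inverse_permutation)
assert np.array_equal(roundtrip, x)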
Example #22
    def call(self, inputs, training=False):
        in_eager_mode = context.executing_eagerly()
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # Will cause errors if virtual_batch_size does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(inputs, training=training)
            if self.virtual_batch_size is not None:
                # Currently never reaches here since fused_batch_norm does not support
                # virtual batching
                return undo_virtual_batching(outputs)
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = utils.constant_value(training)
        if training_value is not False:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1
            mean, variance = nn.moments(inputs,
                                        reduction_axes,
                                        keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = utils.smart_cond(training, lambda: mean,
                                    lambda: moving_mean)
            variance = utils.smart_cond(training, lambda: variance,
                                        lambda: moving_variance)

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    mean, variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)
            else:
                new_mean, new_variance = mean, variance

            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(new_mean,
                                                axis=1,
                                                keep_dims=True)
                new_variance = math_ops.reduce_mean(new_variance,
                                                    axis=1,
                                                    keep_dims=True)

            def _do_update(var, value):
                if in_eager_mode and not self.trainable:
                    return

                return self._assign_moving_average(var, value, self.momentum)

            mean_update = utils.smart_cond(
                training, lambda: _do_update(self.moving_mean, new_mean),
                lambda: self.moving_mean)
            variance_update = utils.smart_cond(
                training,
                lambda: _do_update(self.moving_variance, new_variance),
                lambda: self.moving_variance)
            if not context.executing_eagerly():
                self.add_update(mean_update, inputs=inputs)
                self.add_update(variance_update, inputs=inputs)

        else:
            mean, variance = self.moving_mean, self.moving_variance

        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            return undo_virtual_batching(outputs)

        return outputs
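The renorm branch above leans on the identity (x * r + d) * gamma + beta = x * (r * gamma) + (d * gamma + beta), which is exactly what `_compose_transforms` folds into a single scale/offset pair. A quick NumPy check of that identity (arbitrary values, not part of the snippet):

import numpy as np

rng = np.random.default_rng(0)
x, r, d, gamma, beta = (rng.standard_normal(4) for _ in range(5))

direct = (x * r + d) * gamma + beta      # apply the renorm correction, then gamma/beta
scale = r * gamma                        # what _compose_transforms builds
offset = d * gamma + beta
folded = x * scale + offset

assert np.allclose(direct, folded)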
Example No. 23
    def call(self, inputs, training=None, use_moving_statistics=True):
        """
        :param inputs: input features
        :param training: boolean or boolean Tensor (with shape []) which determines the current training phase
        :param use_moving_statistics: boolean or boolean Tensor (with shape []) which selects statistics to use
               when training==True (or the Tensor value) statistics (mean and variance) are from the inputs !
               when training==False, if use_moving_statistics==True -> feed forward with moving statistics (updated
                                        with operations defined in GraphKeys.UPDATE_OPS)
                                     else (use_moving_statistics==False -> feed forward with raw statistics (updated
                                        with operations from collections 'UPDATE_BN_OPS'
                                        'RESET_BN_OPS' contains operations to reset these vaiables between inferences.
        """
        in_eager_mode = context.executing_eagerly()
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # Will cause errors if virtual_batch_size does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(
                inputs,
                training=training,
                use_moving_statistics=use_moving_statistics)
            if self.virtual_batch_size is not None:
                # Currently never reaches here since fused_batch_norm does not support
                # virtual batching
                outputs = undo_virtual_batching(outputs)
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)

        if training_value is not False:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1

            # mean and variance of the current batch
            mean, variance = nn.moments(inputs,
                                        reduction_axes,
                                        keep_dims=keep_dims)

            mean = tf_utils.smart_cond(
                training, lambda: mean,
                lambda: tf_utils.smart_cond(use_moving_statistics,
                                            lambda: self.moving_mean,
                                            lambda: self.mean))
            variance = tf_utils.smart_cond(
                training, lambda: variance,
                lambda: tf_utils.smart_cond(use_moving_statistics,
                                            lambda: self.moving_variance,
                                            lambda: self.variance))

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    mean, variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)
            else:
                new_mean, new_variance = mean, variance

            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
                new_variance = math_ops.reduce_mean(variance,
                                                    axis=1,
                                                    keepdims=True)

            def _do_update(var, value):
                if in_eager_mode and not self.trainable:
                    return
                return self._assign_moving_average(var, value, self.momentum)

            moving_mean_update = tf_utils.smart_cond(
                training, lambda: _do_update(self.moving_mean, new_mean),
                lambda: self.moving_mean)
            moving_variance_update = tf_utils.smart_cond(
                training,
                lambda: _do_update(self.moving_variance, new_variance),
                lambda: self.moving_variance)

            if not context.executing_eagerly():
                self.add_update(moving_mean_update, inputs=True)
                self.add_update(moving_variance_update, inputs=True)

            mean_update = self._update_statistics(self.mean, mean,
                                                  self.n_updates)
            variance_update = self._update_statistics(self.variance, variance,
                                                      self.n_updates)

            with ops.control_dependencies([mean_update, variance_update]):
                # update n_updates only after updating self.mean and self.variance
                update_n_updates = state_ops.assign_add(self.n_updates, 1.)
                ops.add_to_collection('UPDATE_BN_OPS', update_n_updates)

            reset_mean = state_ops.assign(self.mean,
                                          array_ops.zeros_like(self.mean))
            reset_variance = state_ops.assign(
                self.variance, array_ops.zeros_like(self.variance))
            reset_n_updates = state_ops.assign(self.n_updates, 0.)
            with ops.control_dependencies(
                [reset_mean, reset_variance, reset_n_updates]):
                reset_bn = gen_control_flow_ops.no_op("ResetBatchNormStats")
            ops.add_to_collection('RESET_BN_OPS', reset_bn)

        else:
            # training == False
            mean = tf_utils.smart_cond(use_moving_statistics,
                                       lambda: self.moving_mean,
                                       lambda: self.mean)
            variance = tf_utils.smart_cond(use_moving_statistics,
                                           lambda: self.moving_variance,
                                           lambda: self.variance)

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)

        return outputs
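This variant keeps two sets of inference statistics: the usual moving averages, and a second "raw" pair (`self.mean`, `self.variance`) accumulated across batches through `_update_statistics` and `n_updates`, with dedicated reset ops. The helper itself is not shown; assuming it keeps a simple running average weighted by `n_updates`, the bookkeeping could look like this NumPy sketch (an assumption, not the snippet's actual implementation):

import numpy as np

class RawStatistics:
    # Sketch only: the real _update_statistics helper is not shown in the snippet;
    # a running average weighted by the update count is assumed here.
    def __init__(self, shape):
        self.mean = np.zeros(shape)
        self.variance = np.zeros(shape)
        self.n_updates = 0.0

    def update(self, batch_mean, batch_variance):
        n = self.n_updates
        self.mean = (self.mean * n + batch_mean) / (n + 1.0)
        self.variance = (self.variance * n + batch_variance) / (n + 1.0)
        self.n_updates = n + 1.0          # mirrors assign_add(self.n_updates, 1.)

    def reset(self):
        # Mirrors the reset_mean / reset_variance / reset_n_updates ops collected
        # for resetting the raw statistics between inferences.
        self.mean[...] = 0.0
        self.variance[...] = 0.0
        self.n_updates = 0.0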
Example No. 24
    def call(self, inputs, training=False):
        if self.num_virtual_batches > 1:
            # Virtual batches (aka ghost batches) can be simulated by using some
            # reshape/transpose tricks on top of base batch normalization.
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [-1, self.num_virtual_batches
                              ] + original_shape[1:]

            # Will cause errors if num_virtual_batches does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            ndims = len(expanded_shape)
            if self.axis < 0:
                axis = ndims + self.axis
            else:
                axis = self.axis + 1  # Account for the added dimension

            # Permute the num_virtual_batch dimension (dim 1) to be adjacent to axis
            # TODO(b/66257056): when multi-axis batch normalization is implemented,
            # this permutation trick and the combined_dim reshape are no longer
            # necessary and can be reworked to simply use broadcasting.
            permutation = ([0] + list(range(2, axis)) + [1, axis] +
                           list(range(axis + 1, ndims)))
            inverse_permutation = [
                x[1] for x in sorted(zip(permutation, range(ndims)))
            ]
            inputs = array_ops.transpose(inputs, perm=permutation)

            # Combine the axis and num_virtual_batch dimension in order to take
            # advantage of fused batch normalization
            combined_dim = expanded_shape[1] * expanded_shape[axis]
            perm_shape = [-1] + inputs.shape.as_list()[1:]
            combined_shape = (perm_shape[:axis - 1] + [combined_dim] +
                              perm_shape[axis + 1:])
            inputs = array_ops.reshape(inputs, combined_shape)

            # After the above reshape, the batch norm axis is the original self.axis

            # Undoes the reshaping and transposing tricks done above
            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, perm_shape)
                outputs = array_ops.transpose(outputs,
                                              perm=inverse_permutation)
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(inputs, training=training)
            if self.num_virtual_batches > 1:
                return undo_virtual_batching(outputs)
            return outputs

        # First, compute the axes along which to reduce the mean / variance,
        # as well as the broadcast shape to be used for all parameters.
        input_shape = inputs.get_shape()
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis].value

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        scale, offset = self.gamma, self.beta

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = utils.constant_value(training)
        if training_value is not False:
            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            mean, variance = nn.moments(inputs, reduction_axes)
            mean = _smart_select(training, lambda: mean,
                                 lambda: self.moving_mean)
            variance = _smart_select(training, lambda: variance,
                                     lambda: self.moving_variance)

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    mean, variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                scale = array_ops.stop_gradient(r, name='renorm_r')
                offset = array_ops.stop_gradient(d, name='renorm_d')
                if self.gamma is not None:
                    scale *= self.gamma
                    offset *= self.gamma
                if self.beta is not None:
                    offset += self.beta
            else:
                new_mean, new_variance = mean, variance

            # Update moving averages when training, and prevent updates otherwise.
            decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
            mean_update = moving_averages.assign_moving_average(
                self.moving_mean, new_mean, decay, zero_debias=False)
            variance_update = moving_averages.assign_moving_average(
                self.moving_variance, new_variance, decay, zero_debias=False)
            if context.in_graph_mode():
                self.add_update(mean_update, inputs=inputs)
                self.add_update(variance_update, inputs=inputs)

        else:
            mean, variance = self.moving_mean, self.moving_variance

        def _broadcast(v):
            if needs_broadcasting and v is not None:
                # In this case we must explicitly broadcast all parameters.
                return array_ops.reshape(v, broadcast_shape)
            return v

        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance),
                                         _broadcast(offset), _broadcast(scale),
                                         self.epsilon)

        if self.num_virtual_batches > 1:
            return undo_virtual_batching(outputs)

        return outputs
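The reshape/transpose dance above merges the virtual-batch dimension into the normalization axis, so a single (fused) batch norm sees num_virtual_batches * channels separate "channels". A NumPy walk-through of the shapes for an NHWC tensor (made-up sizes, only to show that the round trip restores the original layout):

import numpy as np

batch, height, width, channels = 8, 5, 5, 3
num_virtual_batches = 2                      # must divide the batch size
x = np.random.randn(batch, height, width, channels).astype(np.float32)

# Split the batch: [N, H, W, C] -> [N / nvb, nvb, H, W, C]
expanded = x.reshape(-1, num_virtual_batches, height, width, channels)

# Move the virtual-batch dim next to the channel axis (permutation [0, 2, 3, 1, 4]):
# -> [N / nvb, H, W, nvb, C]
permuted = expanded.transpose(0, 2, 3, 1, 4)

# Merge it into the channel axis so batch norm sees nvb * C channels:
combined = permuted.reshape(-1, height, width, num_virtual_batches * channels)

# ... normalization over the last axis would happen here ...

# undo_virtual_batching: reverse the reshape, the transpose, and the split.
restored = (combined
            .reshape(-1, height, width, num_virtual_batches, channels)
            .transpose(0, 3, 1, 2, 4)        # inverse of [0, 2, 3, 1, 4]
            .reshape(batch, height, width, channels))
assert np.array_equal(restored, x)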
Example No. 25
    def _fused_batch_norm(self, inputs, training):
        """Returns the output of fused batch norm."""
        # TODO(reedwm): Add support for fp16 inputs.
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const

        def _fused_batch_norm_training():
            return nn.fused_batch_norm(
                inputs[:self.Nb_list[0]],
                gamma,
                beta,
                epsilon=self.epsilon,
                data_format=self._data_format)

        def _fused_batch_norm_inference():
            return nn.fused_batch_norm(
                inputs,
                gamma,
                beta,
                mean=self.moving_mean,
                variance=self.moving_variance,
                epsilon=self.epsilon,
                is_training=False,
                data_format=self._data_format)

        output, mean, variance = utils.smart_cond(
            training, _fused_batch_norm_training, _fused_batch_norm_inference)
        add_to_collection('instant_means', mean)
        add_to_collection('instant_variances', variance)
        add_to_collection('moving', self.moving_mean)
        add_to_collection('moving', self.moving_variance)
        add_to_collection('bn', mean)
        add_to_collection('bn', variance)
        self.assign_instant_statistics(mean, variance)

        if training and len(self.Nb_list) > 1:
            # TODO (rob) use fused batch norm here. it is 6X faster
            # https://github.com/tensorflow/tensorflow/issues/7551#issuecomment-280421351
            inv = math_ops.rsqrt(stop_gradient(variance) + self.epsilon) * gamma
            second_output = inputs[self.Nb_list[0]:] * inv + beta - stop_gradient(mean) * inv
            log.debug('You are doing custom batchnorm by overwriting the BatchNormClass (1)')
            output = concat((output, second_output), axis=0)

        if not self._bessels_correction_test_only:
            # Remove Bessel's correction to be consistent with non-fused batch norm.
            # Note that the variance computed by fused batch norm is
            # with Bessel's correction.
            sample_size = math_ops.cast(
                array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
            factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
            variance *= factor

        training_value = utils.constant_value(training)
        if training_value is None:
            one_minus_decay = utils.smart_cond(training,
                                               lambda: 1.0 - self.momentum,
                                               lambda: 0.)
        else:
            one_minus_decay = ops.convert_to_tensor(1.0 - self.momentum)
        if training_value or training_value is None:
            mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                      one_minus_decay)
            variance_update = self._assign_moving_average(self.moving_variance,
                                                          variance, one_minus_decay)
            if True:
                # Note that in Eager mode, the updates are already executed when running
                # assign_moving_averages. So we do not need to put them into
                # collections.
                self.add_update(mean_update, inputs=inputs)
                self.add_update(variance_update, inputs=inputs)

        return output
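The Bessel's-correction block above rescales the variance returned by the fused kernel, which divides by n - 1, back to the population form that the non-fused path computes, which divides by n. A quick NumPy check of the (sample_size - 1) / sample_size factor (arbitrary data, not part of the snippet):

import numpy as np

x = np.random.randn(32, 16)                  # 32 samples, 16 "channels"
sample_size = x.size / x.shape[1]            # elements contributing to each channel's variance

unbiased = x.var(axis=0, ddof=1)             # divides by n - 1, like the fused kernel
biased = x.var(axis=0, ddof=0)               # divides by n, like nn.moments

factor = (sample_size - 1.0) / sample_size   # the factor applied in the snippet
assert np.allclose(unbiased * factor, biased)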