Example #1
def batch_normalize(tensor_in, epsilon=1e-5, convnet=False, decay=0.9, scale_after_normalization=True):
    """Batch Normalization

  Args:
    tensor_in: input Tensor, 4D shape: [batch, in_height, in_width, in_depth].
    epsilon : A float number to avoid being divided by 0.
    decay: decay rate for exponential moving average.
    convnet: Whether this is for convolutional net use. If this is True,
      moments will sum across axis [0, 1, 2]. Otherwise, only [0].
    scale_after_normalization: Whether to scale after normalization.
  """
    shape = tensor_in.get_shape().as_list()

    with vs.variable_scope("batch_norm"):
        gamma = vs.get_variable("gamma", [shape[-1]], initializer=init_ops.random_normal_initializer(1.0, 0.02))
        beta = vs.get_variable("beta", [shape[-1]], initializer=init_ops.constant_initializer(0.0))
        ema = moving_averages.ExponentialMovingAverage(decay=decay)
        if convnet:
            assign_mean, assign_var = nn.moments(tensor_in, [0, 1, 2])
        else:
            assign_mean, assign_var = nn.moments(tensor_in, [0])
        ema_assign_op = ema.apply([assign_mean, assign_var])
        ema_mean, ema_var = ema.average(assign_mean), ema.average(assign_var)

        def update_mean_var():
            """Internal function that updates mean and variance during training"""
            with ops.control_dependencies([ema_assign_op]):
                return array_ops_.identity(assign_mean), array_ops_.identity(assign_var)

        is_training = array_ops_.squeeze(ops.get_collection("IS_TRAINING"))
        mean, variance = control_flow_ops.cond(is_training, update_mean_var, lambda: (ema_mean, ema_var))
        return nn.batch_norm_with_global_normalization(
            tensor_in, mean, variance, beta, gamma, epsilon, scale_after_normalization=scale_after_normalization
        )
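
The convnet branch above reduces over axes [0, 1, 2], i.e. it computes one mean and variance per channel. Below is a minimal NumPy sketch of that arithmetic, independent of the TensorFlow code above; the array names are made up for illustration.

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 5, 5, 3)).astype(np.float32)  # [batch, h, w, depth]
gamma, beta, eps = 1.0, 0.0, 1e-5

# Per-channel statistics, as in the convnet=True branch (axes [0, 1, 2]).
mean = x.mean(axis=(0, 1, 2))
var = x.var(axis=(0, 1, 2))

# Normalize, then scale and shift.
y = gamma * (x - mean) / np.sqrt(var + eps) + beta

# The output has (approximately) zero mean and unit variance per channel.
print(y.mean(axis=(0, 1, 2)), y.var(axis=(0, 1, 2)))
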
Example #2
def fused_switch_norm(
        x,
        scale,
        offset,  # pylint: disable=invalid-name
        mean_weight,
        var_weight,
        mean_bn=None,
        variance_bn=None,
        epsilon=0.001,
        data_format="NHWC",
        is_training=True,
        name=None):
    #x = ops.convert_to_tensor(x, name="input")
    scale = ops.convert_to_tensor(scale, name="scale")
    offset = ops.convert_to_tensor(offset, name="offset")

    # Switchable normalization statistics: instance norm (IN) over [1, 2],
    # layer norm (LN) over [1, 2, 3].
    mean_in, variance_in = nn.moments(x, [1, 2], keep_dims=True)
    mean_ln, variance_ln = nn.moments(x, [1, 2, 3], keep_dims=True)

    if is_training:
        if (mean_bn is not None) or (variance_bn is not None):
            raise ValueError("Both 'mean_bn' and 'variance_bn' must be None "
                             "if is_training is True.")
        mean_bn, variance_bn = nn.moments(x, [0, 1, 2], keep_dims=True)
    mean_weight = tf.nn.softmax(mean_weight)
    var_weight = tf.nn.softmax(var_weight)
    mean = (mean_weight[0] * mean_in + mean_weight[1] * mean_ln +
            mean_weight[2] * mean_bn)
    variance = (var_weight[0] * variance_in + var_weight[1] * variance_ln +
                var_weight[2] * variance_bn)

    outputs = scale * (x - mean) / (tf.sqrt(variance + epsilon)) + offset

    return outputs, tf.squeeze(mean_bn), tf.squeeze(variance_bn)
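
For reference, here is a small NumPy sketch of the switchable-normalization mixture computed above: IN, LN and BN statistics are blended with softmax weights before a single normalization step. It only illustrates the arithmetic, not the function itself; all names are hypothetical and scale/offset are omitted.

import numpy as np

def softmax(w):
    e = np.exp(w - w.max())
    return e / e.sum()

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8, 8, 16))          # NHWC
mean_w, var_w = softmax(np.zeros(3)), softmax(np.zeros(3))

# Instance-, layer- and batch-norm statistics (keepdims for broadcasting).
mean_in, var_in = x.mean((1, 2), keepdims=True), x.var((1, 2), keepdims=True)
mean_ln, var_ln = x.mean((1, 2, 3), keepdims=True), x.var((1, 2, 3), keepdims=True)
mean_bn, var_bn = x.mean((0, 1, 2), keepdims=True), x.var((0, 1, 2), keepdims=True)

mean = mean_w[0] * mean_in + mean_w[1] * mean_ln + mean_w[2] * mean_bn
var = var_w[0] * var_in + var_w[1] * var_ln + var_w[2] * var_bn
y = (x - mean) / np.sqrt(var + 1e-3)
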
Example #3
def batch_normalize(tensor_in,
                    epsilon=1e-5,
                    convnet=False,
                    decay=0.9,
                    scale_after_normalization=True):
  """Batch normalization.

  Args:
    tensor_in: input `Tensor`, 4D shape: [batch, in_height, in_width, in_depth].
    epsilon: A small float to avoid division by zero.
    convnet: Whether this is for convolutional net use. If `True`, moments
        will sum across axes `[0, 1, 2]`. Otherwise, only `[0]`.
    decay: Decay rate for exponential moving average.
    scale_after_normalization: Whether to scale after normalization.

  Returns:
    A batch-normalized `Tensor`.
  """
  shape = tensor_in.get_shape().as_list()

  with vs.variable_scope("batch_norm"):
    gamma = vs.get_variable(
        "gamma", [shape[-1]],
        initializer=init_ops.random_normal_initializer(1., 0.02))
    beta = vs.get_variable("beta", [shape[-1]],
                           initializer=init_ops.constant_initializer(0.))
    ema = moving_averages.ExponentialMovingAverage(decay=decay)
    if convnet:
      assign_mean, assign_var = nn.moments(tensor_in, [0, 1, 2])
    else:
      assign_mean, assign_var = nn.moments(tensor_in, [0])
    ema_assign_op = ema.apply([assign_mean, assign_var])
    ema_mean, ema_var = ema.average(assign_mean), ema.average(assign_var)

    def _update_mean_var():
      """Internal function that updates mean and variance during training."""
      with ops.control_dependencies([ema_assign_op]):
        return array_ops_.identity(assign_mean), array_ops_.identity(assign_var)

    is_training = array_ops_.squeeze(ops.get_collection("IS_TRAINING"))
    mean, variance = control_flow_ops.cond(is_training, _update_mean_var,
                                           lambda: (ema_mean, ema_var))
    return nn.batch_norm_with_global_normalization(
        tensor_in,
        mean,
        variance,
        beta,
        gamma,
        epsilon,
        scale_after_normalization=scale_after_normalization)
    def test_virtual_statistics(self):
        """Check that `_virtual_statistics` gives same result as `nn.moments`."""
        random_seed.set_random_seed(1234)

        batch_axis = 0
        partial_batch = random_ops.random_normal([4, 5, 7, 3])
        single_example = random_ops.random_normal([1, 5, 7, 3])
        full_batch = array_ops.concat([partial_batch, single_example], axis=0)

        for reduction_axis in range(1, 4):
            # Get `nn.moments` on the full batch.
            reduction_axes = list(range(4))
            del reduction_axes[reduction_axis]
            mom_mean, mom_variance = nn.moments(full_batch, reduction_axes)

            # Get virtual batch statistics.
            vb_reduction_axes = list(range(4))
            del vb_reduction_axes[reduction_axis]
            del vb_reduction_axes[batch_axis]
            vbn = virtual_batchnorm.VBN(partial_batch, reduction_axis)
            vb_mean, mean_sq = vbn._virtual_statistics(single_example,
                                                       vb_reduction_axes)
            vb_variance = mean_sq - math_ops.square(vb_mean)
            # Remove singleton batch dim for easy comparisons.
            vb_mean = array_ops.squeeze(vb_mean, batch_axis)
            vb_variance = array_ops.squeeze(vb_variance, batch_axis)

            with self.test_session(use_gpu=True) as sess:
                vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run(
                    [vb_mean, vb_variance, mom_mean, mom_variance])

            self.assertAllClose(mom_mean_np, vb_mean_np)
            self.assertAllClose(mom_var_np, vb_var_np)
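
The test above relies on the fact that full-batch moments can be recovered from per-subset sums. Here is a hedged NumPy sketch of that identity (it is not the `_virtual_statistics` implementation itself; the names are illustrative).

import numpy as np

rng = np.random.default_rng(1234)
ref = rng.standard_normal((4, 5, 7, 3))      # "reference" batch
new = rng.standard_normal((1, 5, 7, 3))      # single new example
full = np.concatenate([ref, new], axis=0)

axes = (0, 1, 2)                             # reduce everything but the last axis
n_ref = np.prod([ref.shape[a] for a in axes])
n_new = np.prod([new.shape[a] for a in axes])

# Combine first and second moments by their element counts.
mean = (ref.sum(axes) + new.sum(axes)) / (n_ref + n_new)
mean_sq = ((ref ** 2).sum(axes) + (new ** 2).sum(axes)) / (n_ref + n_new)
var = mean_sq - mean ** 2

assert np.allclose(mean, full.mean(axes))
assert np.allclose(var, full.var(axes), atol=1e-6)
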
Example #5
    def call(self, inputs):
        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)

        # Calculate the moments on the last axis (layer activations).
        mean, variance = nn.moments(inputs, self.norm_axis, keep_dims=True)

        # Broadcasting only necessary for norm where the params axes aren't just
        # the last dimension
        broadcast_shape = [1] * ndims
        for dim in self.params_axis:
            broadcast_shape[dim] = input_shape.dims[dim].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and self.params_axis != [ndims - 1]):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        # Compute layer normalization using the batch_normalization function.
        outputs = nn.batch_normalization(inputs,
                                         mean,
                                         variance,
                                         offset=offset,
                                         scale=scale,
                                         variance_epsilon=self.epsilon)

        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        return outputs
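
Layer normalization as in the call() above reduces over the feature axes of each example and then applies a broadcast scale and offset. A standalone NumPy sketch of the same idea follows; gamma and beta here are hypothetical stand-ins for the layer's variables.

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 10))                 # [batch, features]
gamma = np.ones((10,))
beta = np.zeros((10,))
eps = 1e-12

mean = x.mean(axis=-1, keepdims=True)            # per-example statistics
var = x.var(axis=-1, keepdims=True)
y = gamma * (x - mean) / np.sqrt(var + eps) + beta
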
  def _inverse_log_det_jacobian(self, y, use_saved_statistics=False):
    if not y.shape.is_fully_defined():
      raise ValueError("Input must have shape known at graph construction.")
    input_shape = np.int32(y.shape.as_list())

    if not self.batchnorm.built:
      # Create variables.
      self.batchnorm.build(input_shape)

    event_dims = self.batchnorm.axis
    reduction_axes = [i for i in range(len(input_shape)) if i not in event_dims]

    if use_saved_statistics or not self._training:
      log_variance = math_ops.log(
          self.batchnorm.moving_variance + self.batchnorm.epsilon)
    else:
      # At training-time, ildj is computed from the mean and log-variance across
      # the current minibatch.
      _, v = nn.moments(y, axes=reduction_axes, keep_dims=True)
      log_variance = math_ops.log(v + self.batchnorm.epsilon)

    # `gamma` and `log Var(y)` reductions over event_dims.
    # Log(total change in area from gamma term).
    log_total_gamma = math_ops.reduce_sum(math_ops.log(self.batchnorm.gamma))

    # Log(total change in area from log-variance term).
    log_total_variance = math_ops.reduce_sum(log_variance)
    # The ildj is a scalar: it does not depend on the values of x and is
    # constant across minibatch elements.
    return log_total_gamma - 0.5 * log_total_variance
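
That final comment holds because, for fixed statistics, batch normalization is an elementwise affine map, so its Jacobian is diagonal and the log-determinant is just the sum of the log scale factors, independent of the input. A tiny NumPy illustration (not the bijector code itself; the numbers are made up):

import numpy as np

gamma = np.array([0.5, 2.0, 1.5])
var = np.array([0.3, 1.2, 0.7])
eps = 1e-3

# y = gamma * (x - mean) / sqrt(var + eps) + beta  =>  dy/dx is diagonal.
scales = gamma / np.sqrt(var + eps)
log_det = np.sum(np.log(scales))

# The same quantity written as in the return expression above.
assert np.isclose(log_det,
                  np.sum(np.log(gamma)) - 0.5 * np.sum(np.log(var + eps)))
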
Example #7
    def RunMomentTestWithDynamicShape(self, shape, global_norm):
        with self.test_session():
            # shape = [batch, width, height, depth]
            assert len(shape) == 4

            x_numpy = np.random.normal(size=shape).astype(np.float32)
            x = array_ops.placeholder(types.float32, shape=[None] * len(shape))

            axes = [0, 1, 2] if global_norm else [0]
            mean, var = nn.moments(x, axes)

            num_elements = np.prod([shape[i] for i in axes])

            ax = (0, 1, 2) if global_norm else (0)
            expected_mean = np.sum(x_numpy, axis=ax) / num_elements
            expected_mean_squared = np.multiply(expected_mean, expected_mean)
            expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                        axis=ax) / num_elements
            expected_variance = expected_x_squared - expected_mean_squared

            # Check that the moments are correct.
            self.assertAllClose(expected_mean,
                                mean.eval(feed_dict={x: x_numpy}))
            self.assertAllClose(expected_variance,
                                var.eval(feed_dict={x: x_numpy}))
  def test_virtual_statistics(self):
    """Check that `_virtual_statistics` gives same result as `nn.moments`."""
    random_seed.set_random_seed(1234)

    batch_axis = 0
    partial_batch = random_ops.random_normal([4, 5, 7, 3])
    single_example = random_ops.random_normal([1, 5, 7, 3])
    full_batch = array_ops.concat([partial_batch, single_example], axis=0)

    for reduction_axis in range(1, 4):
      # Get `nn.moments` on the full batch.
      reduction_axes = list(range(4))
      del reduction_axes[reduction_axis]
      mom_mean, mom_variance = nn.moments(full_batch, reduction_axes)

      # Get virtual batch statistics.
      vb_reduction_axes = list(range(4))
      del vb_reduction_axes[reduction_axis]
      del vb_reduction_axes[batch_axis]
      vbn = virtual_batchnorm.VBN(partial_batch, reduction_axis)
      vb_mean, mean_sq = vbn._virtual_statistics(
          single_example, vb_reduction_axes)
      vb_variance = mean_sq - math_ops.square(vb_mean)
      # Remove singleton batch dim for easy comparisons.
      vb_mean = array_ops.squeeze(vb_mean, batch_axis)
      vb_variance = array_ops.squeeze(vb_variance, batch_axis)

      with self.cached_session(use_gpu=True) as sess:
        vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([
            vb_mean, vb_variance, mom_mean, mom_variance])

      self.assertAllClose(mom_mean_np, vb_mean_np)
      self.assertAllClose(mom_var_np, vb_var_np)
Example #9
            def my_graph(a):
                with ops.device("/device:IPU:0"):
                    with variable_scope.variable_scope("", use_resource=True):

                        beta = variable_scope.get_variable(
                            "x",
                            dtype=np.float32,
                            shape=[4],
                            initializer=init_ops.constant_initializer(0.0))
                        gamma = variable_scope.get_variable(
                            "y",
                            dtype=np.float32,
                            shape=[4],
                            initializer=init_ops.constant_initializer(1.0))

                        b_mean, b_var = nn.moments(a, [0, 1, 2],
                                                   name='moments')

                        normed = nn.fused_batch_norm(a,
                                                     gamma,
                                                     beta,
                                                     b_mean,
                                                     b_var,
                                                     is_training=False)
                        return normed
Example #10
  def call(self, inputs):
    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)

    # Calculate the moments on the last axis (layer activations).
    mean, variance = nn.moments(inputs, self.axis, keep_dims=True)

    # Broadcasting only necessary for norm where the axis is not just
    # the last dimension
    broadcast_shape = [1] * ndims
    for dim in self.axis:
      broadcast_shape[dim] = input_shape.dims[dim].value
    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and
          self.axis != [ndims - 1]):
        return array_ops.reshape(v, broadcast_shape)
      return v
    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    # Compute layer normalization using the batch_normalization function.
    outputs = nn.batch_normalization(
        inputs,
        mean,
        variance,
        offset=offset,
        scale=scale,
        variance_epsilon=self.epsilon)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    return outputs
  def _inverse_log_det_jacobian(self, y, use_saved_statistics=False):
    if not y.shape.is_fully_defined():
      raise ValueError("Input must have shape known at graph construction.")
    input_shape = np.int32(y.shape.as_list())

    if not self.batchnorm.built:
      # Create variables.
      self.batchnorm.build(input_shape)

    event_dims = self.batchnorm.axis
    reduction_axes = [i for i in range(len(input_shape)) if i not in event_dims]

    if use_saved_statistics or not self._training:
      log_variance = math_ops.log(
          self.batchnorm.moving_variance + self.batchnorm.epsilon)
    else:
      # At training-time, ildj is computed from the mean and log-variance across
      # the current minibatch.
      _, v = nn.moments(y, axes=reduction_axes, keepdims=True)
      log_variance = math_ops.log(v + self.batchnorm.epsilon)

    # `gamma` and `log Var(y)` reductions over event_dims.
    # Log(total change in area from gamma term).
    log_total_gamma = math_ops.reduce_sum(math_ops.log(self.batchnorm.gamma))

    # Log(total change in area from log-variance term).
    log_total_variance = math_ops.reduce_sum(log_variance)
    # The ildj is a scalar: it does not depend on the values of x and is
    # constant across minibatch elements.
    return log_total_gamma - 0.5 * log_total_variance
Example #12
    def normalize(self, inputs):
        """Apply normalization to input.

    The shape must match the declared shape in the constructor.
    [This is copied from tf.contrib.rnn.LayerNormBasicLSTMCell.]

    Args:
      inputs: Input tensor

    Returns:
      Normalized version of input tensor.

    Raises:
      ValueError: if inputs has undefined rank.
    """
        inputs_shape = inputs.get_shape()
        inputs_rank = inputs_shape.ndims
        if inputs_rank is None:
            raise ValueError('Inputs %s has undefined rank.' % inputs.name)
        axis = range(1, inputs_rank)

        beta = self._component.get_variable('beta_%s' % self._name)
        gamma = self._component.get_variable('gamma_%s' % self._name)

        with tf.variable_scope('layer_norm_%s' % self._name):
            # Calculate the moments on the last axis (layer activations).
            mean, variance = nn.moments(inputs, axis, keep_dims=True)

            # Compute layer normalization using the batch_normalization function.
            variance_epsilon = 1E-12
            outputs = nn.batch_normalization(inputs, mean, variance, beta,
                                             gamma, variance_epsilon)
            outputs.set_shape(inputs_shape)
            return outputs
Example #13
def batch_norm(x, deterministic, alpha=0.9, shift=True, scope='bn'):
    with vs.variable_scope(scope):
        dtype = x.dtype
        input_shape = x.get_shape().as_list()
        feat_dim = input_shape[-1]
        axes = range(len(input_shape) - 1)

        if shift:
            beta = vs.get_variable(scope + "_beta",
                                   shape=[feat_dim],
                                   initializer=init_ops.zeros_initializer,
                                   dtype=dtype)
        else:
            beta = vs.get_variable(scope + "_beta",
                                   shape=[feat_dim],
                                   initializer=init_ops.zeros_initializer,
                                   dtype=dtype,
                                   trainable=False)

        gamma = vs.get_variable(scope + "_gamma",
                                shape=[feat_dim],
                                initializer=init_ops.constant_initializer(0.1),
                                dtype=dtype)

        mean = vs.get_variable(scope + "_mean",
                               shape=[feat_dim],
                               initializer=init_ops.zeros_initializer,
                               dtype=dtype,
                               trainable=False)

        var = vs.get_variable(scope + "_var",
                              shape=[feat_dim],
                              initializer=init_ops.ones_initializer,
                              dtype=dtype,
                              trainable=False)

        counter = vs.get_variable(scope + "_counter",
                                  shape=[],
                                  initializer=init_ops.constant_initializer(0),
                                  dtype=tf.int64,
                                  trainable=False)

        zero_cnt = vs.get_variable(
            scope + "_zero_cnt",
            shape=[],
            initializer=init_ops.constant_initializer(0),
            dtype=tf.int64,
            trainable=False)

        batch_mean, batch_var = moments(x, axes, name=scope + '_moments')

        mean, var = cond(math_ops.equal(counter, zero_cnt), lambda:
                         (batch_mean, batch_var), lambda: (mean, var))

        mean, var, counter = cond(
            deterministic, lambda: (mean, var, counter), lambda:
            ((1 - alpha) * batch_mean + alpha * mean,
             (1 - alpha) * batch_var + alpha * var, counter + 1))
        normed = batch_normalization(x, mean, var, beta, gamma, 1e-8)
    return normed
Example #14
 def _moments(self, inputs, reduction_axes, keep_dims):
   mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
   # TODO(b/129279393): Support zero batch input in non DistributionStrategy
   # code as well.
   if self._support_zero_size_input():
     inputs_size = array_ops.size(inputs)
     mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
     variance = array_ops.where(inputs_size > 0, variance,
                                K.zeros_like(variance))
   return mean, variance
Example #15
 def _update_mean_var():
   """Internal function that updates mean and variance during training."""
   axis = [0, 1, 2] if convnet else [0]
   mean, var = nn.moments(tensor_in, axis)
   update_moving_mean = moving_averages.assign_moving_average(
       moving_mean, mean, decay)
   update_moving_var = moving_averages.assign_moving_average(
       moving_var, var, decay)
   with ops.control_dependencies([update_moving_mean, update_moving_var]):
     return array_ops_.identity(mean), array_ops_.identity(var)
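
assign_moving_average above applies the usual exponential moving average update (ignoring zero-debiasing). Here is a hedged NumPy sketch of that recurrence; the variable names are illustrative.

import numpy as np

decay = 0.9
moving_mean = np.zeros(3)

for step in range(5):
    batch_mean = np.random.default_rng(step).standard_normal(3)
    # moving <- decay * moving + (1 - decay) * batch
    moving_mean = decay * moving_mean + (1.0 - decay) * batch_mean
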
Example #16
def _normalize_patches(patches):
    """Normalize patches by their mean and standard deviation.
  Args:
      patches: (tensor) The batch of patches (batch, size, size, channels).
  Returns:
      Tensor (batch, size, size, channels) of the normalized patches.
  """
    patches = array_ops.concat(patches, 0)
    mean, variance = nn.moments(patches, [1, 2, 3], keep_dims=True)
    patches = (patches - mean) / math_ops.sqrt(variance)
    return array_ops.reshape(patches, [array_ops.shape(patches)[0], -1])
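
A quick NumPy rendering of what _normalize_patches computes per patch; this is purely illustrative, and the epsilon-free division mirrors the code above.

import numpy as np

rng = np.random.default_rng(0)
patches = rng.standard_normal((6, 8, 8, 3))          # (batch, size, size, channels)

mean = patches.mean(axis=(1, 2, 3), keepdims=True)   # one mean/std per patch
std = np.sqrt(patches.var(axis=(1, 2, 3), keepdims=True))
normalized = (patches - mean) / std
flat = normalized.reshape(patches.shape[0], -1)      # (batch, size*size*channels)
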
Example #17
 def var(weights, name=None):
   """Applies variance regularization to weights."""
   with ops.name_scope(scope, 'var_regularizer', [weights]) as name:
     my_scale = ops.convert_to_tensor(scale,
                                      dtype=weights.dtype.base_dtype,
                                      name='scale')
     _, var_axis0 = nn.moments(weights, axes=axes)
     return standard_ops.multiply(
         my_scale,
         var_axis0,
         name=name)
def _normalize_patches(patches):
  """Normalize patches by their mean and standard deviation.

  Args:
      patches: (tensor) The batch of patches (batch, size, size, channels).
  Returns:
      Tensor (batch, size, size, channels) of the normalized patches.
  """
  patches = array_ops.concat(patches, 0)
  mean, variance = nn.moments(patches, [1, 2, 3], keep_dims=True)
  patches = (patches - mean) / math_ops.sqrt(variance)
  return array_ops.reshape(patches, [array_ops.shape(patches)[0], -1])
Example #19
 def _moments(self, inputs, reduction_axes, keep_dims):
   mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
   # TODO(b/129279393): Support zero batch input in non DistributionStrategy
   # code as well.
   # TODO(b/130185866): Support zero batch input in graph mode.
   if (ops.executing_eagerly_outside_functions() and
       distribution_strategy_context.has_strategy()):
     inputs_size = array_ops.size(inputs)
     mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
     variance = array_ops.where(inputs_size > 0, variance,
                                K.zeros_like(variance))
   return mean, variance
Example #20
 def _moments(self, inputs, reduction_axes, keep_dims):
     mean, variance = nn.moments(inputs,
                                 reduction_axes,
                                 keep_dims=keep_dims)
     # TODO(b/129279393): Support zero batch input in non DistributionStrategy
     # code as well.
     if distribution_strategy_context.has_strategy(
     ) and not inputs.shape.is_fully_defined():
         inputs_size = array_ops.size(inputs)
         mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
         variance = array_ops.where(inputs_size > 0, variance,
                                    K.zeros_like(variance))
     return mean, variance
Example #21
 def series_start_updates():
   # If this is the lowest-time chunk that we have seen so far, update
   # series start moments to reflect that. Note that these statistics are
   # "best effort", as there are race conditions in the update (however,
   # they should eventually converge if the start of the series is
   # presented enough times).
   mean, variance = nn.moments(
       values[min_time_batch, :self._starting_variance_window_size],
       axes=[0])
   return control_flow_ops.group(
       state_ops.assign(statistics.series_start_moments.mean, mean),
       state_ops.assign(statistics.series_start_moments.variance,
                        variance))
Example #23
 def _moments(self, inputs, reduction_axes, keep_dims):
     mean, variance = nn.moments(inputs,
                                 reduction_axes,
                                 keep_dims=keep_dims)
     # TODO(b/129279393): Support zero batch input in non DistributionStrategy
     # code as well.
     if distribution_strategy_context.has_strategy():
         inputs_size = array_ops.size(inputs)
         mean = tf_utils.smart_cond(inputs_size > 0, lambda: mean,
                                    lambda: K.zeros_like(mean))
         variance = tf_utils.smart_cond(inputs_size > 0, lambda: variance,
                                        lambda: K.zeros_like(variance))
     return mean, variance
 def instance_norm(self, inputs, inputs_latent, name):
     inputs_rank = inputs.shape.ndims
     n_outputs = np.int(inputs.shape[-1])
     n_batch = np.int(inputs.shape[0])
     inputs_latent_flatten = tf.layers.flatten(inputs_latent)
     gamma = self.MLP(inputs_latent_flatten, n_outputs, name+"_gamma")
     beta = self.MLP(inputs_latent_flatten, n_outputs, name+"_beta")
     gamma = tf.reshape(gamma, [n_batch, 1, 1, n_outputs])
     beta = tf.reshape(beta, [n_batch, 1, 1, n_outputs])
     moments_axes = list(range(inputs_rank))
     mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
     outputs = nn.batch_normalization(
         inputs, mean, variance, beta, gamma, 1e-6, name=name)
     return outputs
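
The snippet above predicts per-sample gamma/beta from a latent code and then normalizes. Below is a related NumPy sketch of classic adaptive instance normalization, where the statistics are taken over the spatial axes only (the snippet above reduces over all axes; this sketch keeps the per-sample, per-channel convention and uses made-up arrays).

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 16, 8))              # NHWC feature map
gamma = rng.standard_normal((2, 1, 1, 8))            # predicted per sample/channel
beta = rng.standard_normal((2, 1, 1, 8))

mean = x.mean(axis=(1, 2), keepdims=True)            # instance statistics
var = x.var(axis=(1, 2), keepdims=True)
y = gamma * (x - mean) / np.sqrt(var + 1e-6) + beta
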
Example #25
def batch_norm(x, deterministic, alpha=0.9, shift=True, scope='bn'):
    with vs.variable_scope(scope):
        dtype = x.dtype
        input_shape = x.get_shape().as_list()
        feat_dim = input_shape[-1]
        axes = range(len(input_shape)-1)
        
        if shift:
            beta = vs.get_variable(
                    scope+"_beta", shape=[feat_dim],
                    initializer=init_ops.zeros_initializer, dtype=dtype)
        else:
            beta = vs.get_variable(
                scope+"_beta", shape=[feat_dim],
                initializer=init_ops.zeros_initializer, 
                dtype=dtype, trainable=False)
        
        gamma = vs.get_variable(
                    scope+"_gamma", shape=[feat_dim],
                    initializer=init_ops.constant_initializer(0.1), dtype=dtype)
        
        mean = vs.get_variable(scope+"_mean", shape=[feat_dim],
                                       initializer=init_ops.zeros_initializer,
                                       dtype=dtype, trainable=False)
        
        var = vs.get_variable(scope+"_var", shape=[feat_dim],
                                          initializer=init_ops.ones_initializer,
                                          dtype=dtype, trainable=False)
        
        counter = vs.get_variable(scope+"_counter", shape=[],
                                          initializer=init_ops.constant_initializer(0),
                                          dtype=tf.int64, trainable=False)
        
        zero_cnt = vs.get_variable(scope+"_zero_cnt", shape=[],
                                          initializer=init_ops.constant_initializer(0),
                                          dtype=tf.int64, trainable=False)
        
        batch_mean, batch_var = moments(x, axes, name=scope+'_moments')
        
        mean, var = cond(math_ops.equal(counter, zero_cnt),
                         lambda: (batch_mean, batch_var),
                         lambda: (mean, var))

        mean, var, counter = cond(deterministic,
                                  lambda: (mean, var, counter),
                                  lambda: ((1 - alpha) * batch_mean + alpha * mean,
                                           (1 - alpha) * batch_var + alpha * var,
                                           counter + 1))
        normed = batch_normalization(x, mean, var, beta, gamma, 1e-8)
    return normed
Example #26
 def after_forward(self, layer, outputs, inputs, **kwargs):
     message = ops.get_name_scope() + '/' + layer.name
     axes = list(range(len(outputs[0].shape) - 1))
     print_ops = []
     for i, x in enumerate(outputs):
         mean, var = nn.moments(x, axes=axes)
         print_ops.append(
             logging_ops.print_v2(array_ops.constant(message +
                                                     '/output:%d' % i),
                                  output_stream=sys.stdout))
         print_ops.append(
             logging_ops.print_v2(mean, var, output_stream=sys.stdout))
     with control_dependencies(print_ops):
         for i, x in enumerate(outputs):
             outputs[i] = array_ops.identity(x)
Example #27
def ln(tensor, scope=None, epsilon=1e-5):
    """ Layer normalizes a 2D tensor along its second axis """
    assert len(tensor.get_shape()) == 2
    m, v = nn.moments(tensor, [1], keep_dims=True)
    if not isinstance(scope, str):
        scope = ''
    with vs.variable_scope(scope + 'layer_norm'):
        scale = vs.get_variable('scale',
                                shape=[tensor.get_shape()[1]],
                                initializer=init_ops.constant_initializer(1))
        shift = vs.get_variable('shift',
                                shape=[tensor.get_shape()[1]],
                                initializer=init_ops.constant_initializer(0))
    ln_initial = (tensor - m) / math_ops.sqrt(v + epsilon)

    return ln_initial * scale + shift
    def test_statistics(self):
        """Check that `_statistics` gives the same result as `nn.moments`."""
        random_seed.set_random_seed(1234)

        tensors = random_ops.random_normal([4, 5, 7, 3])
        for axes in [(3), (0, 2), (1, 2, 3)]:
            vb_mean, mean_sq = virtual_batchnorm._statistics(tensors, axes)
            mom_mean, mom_var = nn.moments(tensors, axes)
            vb_var = mean_sq - math_ops.square(vb_mean)

            with self.test_session(use_gpu=True) as sess:
                vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run(
                    [vb_mean, vb_var, mom_mean, mom_var])

            self.assertAllClose(mom_mean_np, vb_mean_np)
            self.assertAllClose(mom_var_np, vb_var_np)
  def test_statistics(self):
    """Check that `_statistics` gives the same result as `nn.moments`."""
    random_seed.set_random_seed(1234)

    tensors = random_ops.random_normal([4, 5, 7, 3])
    for axes in [(3), (0, 2), (1, 2, 3)]:
      vb_mean, mean_sq = virtual_batchnorm._statistics(tensors, axes)
      mom_mean, mom_var = nn.moments(tensors, axes)
      vb_var = mean_sq - math_ops.square(vb_mean)

      with self.cached_session(use_gpu=True) as sess:
        vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([
            vb_mean, vb_var, mom_mean, mom_var])

      self.assertAllClose(mom_mean_np, vb_mean_np)
      self.assertAllClose(mom_var_np, vb_var_np)
Example #30
    def _testGlobalGradient(self, from_y="mean"):
        with self.test_session():
            x_shape = [3, 5, 4, 2]
            x_val = np.random.random_sample(x_shape).astype(np.float64)
            x = constant_op.constant(x_val)
            x.set_shape(x_shape)

            axes = [0, 1, 2]
            y_shape = [2]  # Depth of x
            out_mean, out_var = nn.moments(x, axes)
            if from_y == "mean":
                y = out_mean
            elif from_y == "var":
                y = out_var
            err = gc.ComputeGradientError(x, x_shape, y, y_shape)
            print("Moments %s gradient err = %g" % (from_y, err))
            self.assertLess(err, 1e-11)
Example #31
    def _data_dep_init(self, inputs):
        """Data dependent initialization"""

        with name_scope('data_dep_init'):
            # Generate data dependent init values
            activation = self.layer.activation
            self.layer.activation = None
            x_init = self.layer.call(inputs)
            data_norm_axes = list(range(x_init.shape.rank - 1))
            m_init, v_init = moments(x_init, data_norm_axes)
            scale_init = 1. / sqrt(v_init + 1e-10)

        # Assign data dependent init values
        self.layer.g = self.layer.g * scale_init
        self.layer.bias = (-m_init * scale_init)
        self.layer.activation = activation
        self.initialized = True
Example #32
  def _testGlobalGradient(self, from_y="mean"):
    with self.test_session():
      x_shape = [3, 5, 4, 2]
      x_val = np.random.random_sample(x_shape).astype(np.float64)
      x = constant_op.constant(x_val)
      x.set_shape(x_shape)

      axes = [0, 1, 2]
      y_shape = [2]  # Depth of x
      out_mean, out_var = nn.moments(x, axes)
      if from_y == "mean":
        y = out_mean
      elif from_y == "var":
        y = out_var
      err = gc.ComputeGradientError(x, x_shape, y, y_shape)
      print "Moments %s gradient err = %g" % (from_y, err)
      self.assertLess(err, 1e-11)
Example #33
    def _data_dep_init(self, inputs):
        """Data dependent initialization for eager execution"""
        from tensorflow.python.ops.nn import moments
        from tensorflow.python.ops.math_ops import sqrt

        with name_scope('data_dep_init'):
            # Generate data dependent init values
            activation = self.layer.activation
            self.layer.activation = None
            x_init = self.layer.call(inputs)
            m_init, v_init = moments(x_init, self.norm_axes)
            scale_init = 1. / sqrt(v_init + 1e-10)

        # Assign data dependent init values
        self.layer.g = self.layer.g * scale_init
        self.layer.bias = (-m_init * scale_init)
        self.layer.activation = activation
        self.initialized = True
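
Both _data_dep_init variants above follow the weight-normalization recipe: rescale so the initial pre-activations have unit variance, and shift so they have zero mean. A NumPy sketch of that arithmetic follows; the names and numbers are hypothetical.

import numpy as np

rng = np.random.default_rng(0)
x_init = rng.standard_normal((32, 64)) * 3.0 + 1.5   # pre-activations from a forward pass

m_init = x_init.mean(axis=0)
v_init = x_init.var(axis=0)
scale_init = 1.0 / np.sqrt(v_init + 1e-10)

# Rescaling by scale_init and shifting by -m_init * scale_init leaves the
# features roughly zero-mean, unit-variance at initialization.
x_scaled = x_init * scale_init - m_init * scale_init
print(x_scaled.mean(axis=0).round(6), x_scaled.var(axis=0).round(6))
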
Example #34
    def testBatchNormalizeFused(self):
        x = array_ops.placeholder(np.float32, [4, 64, 64, 4], name="a")

        with ops.device("/device:IPU:0"):
            with variable_scope.variable_scope("", use_resource=True):

                beta = variable_scope.get_variable(
                    "x",
                    dtype=np.float32,
                    shape=[4],
                    initializer=init_ops.constant_initializer(0.0))
                gamma = variable_scope.get_variable(
                    "y",
                    dtype=np.float32,
                    shape=[4],
                    initializer=init_ops.constant_initializer(1.0))

                b_mean, b_var = nn.moments(x, [0, 1, 2], name='moments')

                normed = nn.fused_batch_norm(x,
                                             gamma,
                                             beta,
                                             b_mean,
                                             b_var,
                                             is_training=False)

        with ops.device('cpu'):
            report = gen_ipu_ops.ipu_event_trace()

        tu.configure_ipu_system()

        with tu.ipu_session() as sess:
            sess.run(report)

            sess.run(variables.global_variables_initializer())
            result, _, _ = sess.run(normed, {x: np.zeros([4, 64, 64, 4])})
            self.assertAllClose(result, np.zeros([4, 64, 64, 4]))

            rep = sess.run(report)
            s = tu.extract_all_strings_from_event_trace(rep)
            cs = tu.get_compute_sets_from_report(s)

            bl = ['*convert*/Cast*']
            self.assertTrue(tu.check_compute_sets_not_in_blacklist(cs, bl))
Example #35
    def call(self, inputs, training=True):
        layer_inputs = inputs[0]
        mix_weights = inputs[1]
        self.assign_mixture_value(name="template_beta",
                                  mixture_weights=mix_weights)
        self.assign_mixture_value(name="template_gamma",
                                  mixture_weights=mix_weights)
        norm_input = BatchNormalization.call(self, layer_inputs, training)

        input_shape = layer_inputs.shape
        reduction_axes = [
            i for i in range(len(input_shape)) if i not in self.axis
        ]
        mean, var = nn.moments(norm_input, reduction_axes, keep_dims=True)
        scale = self._broadcast(self.template_gamma, input_shape)
        offset = self._broadcast(self.template_beta, input_shape)
        output = nn.batch_normalization(norm_input, mean, var, offset, scale,
                                        self.epsilon)
        self.reset_all_values()
        return output
Example #36
                    def my_graph(a):
                        beta = variable_scope.get_variable(
                            "x",
                            dtype=np.float16,
                            shape=[4],
                            initializer=init_ops.constant_initializer(0.0))
                        gamma = variable_scope.get_variable(
                            "y",
                            dtype=np.float16,
                            shape=[4],
                            initializer=init_ops.constant_initializer(1.0))

                        b_mean, b_var = nn.moments(a, [0, 1, 2],
                                                   name='moments')

                        normed = nn.fused_batch_norm(a,
                                                     gamma,
                                                     beta,
                                                     b_mean,
                                                     b_var,
                                                     is_training=False)
                        return normed
Example #37
    def RunMomentTest(self, shape, global_norm):
        with self.test_session():
            # shape = [batch, width, height, depth]
            assert len(shape) == 4

            x_numpy = np.random.normal(size=shape).astype(np.float32)
            x = constant_op.constant(x_numpy)

            axes = [0, 1, 2] if global_norm else [0]
            mean, var = nn.moments(x, axes)

            num_elements = np.prod([shape[i] for i in axes])

            ax = (0, 1, 2) if global_norm else (0)
            expected_mean = np.sum(x_numpy, axis=ax) / num_elements
            expected_mean_squared = np.multiply(expected_mean, expected_mean)
            expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy), axis=ax) / num_elements
            expected_variance = expected_x_squared - expected_mean_squared

            # Check that the moments are correct.
            self.assertAllClose(expected_mean, mean.eval())
            self.assertAllClose(expected_variance, var.eval())
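
The expected values in RunMomentTest come from the identity Var[x] = E[x^2] - (E[x])^2, which is easy to sanity-check in NumPy:

import numpy as np

x = np.random.default_rng(0).normal(size=(2, 3, 5, 4)).astype(np.float32)
axes = (0, 1, 2)

mean = x.mean(axes)
var_from_identity = (x ** 2).mean(axes) - mean ** 2

assert np.allclose(var_from_identity, x.var(axes), atol=1e-5)
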
Example #38
            def model(x, y, z):
                scale = gen_array_ops.broadcast_to(z, shape=[65536])
                offset = scale
                b_mean, b_var = nn.moments(x, [0, 1, 2], name='moments')
                a = nn.fused_batch_norm(x,
                                        scale,
                                        offset,
                                        b_mean,
                                        b_var,
                                        1e-3,
                                        is_training=False,
                                        name="a")
                b = nn.fused_batch_norm(y,
                                        scale,
                                        offset,
                                        b_mean,
                                        b_var,
                                        1e-3,
                                        is_training=False,
                                        name="b")

                return a[0] + b[0]
Example #39
    def RunMomentTest(self, shape, global_norm):
        with self.test_session():
            # shape = [batch, width, height, depth]
            assert len(shape) == 4

            x_numpy = np.random.normal(size=shape).astype(np.float32)
            x = constant_op.constant(x_numpy)
            x.set_shape(shape)
            axes = [0, 1, 2] if global_norm else [0]
            mean, var = nn.moments(x, axes)

            num_elements = np.prod([shape[i] for i in axes])

            ax = (0, 1, 2) if global_norm else (0)
            expected_mean = np.sum(x_numpy, axis=ax) / num_elements
            expected_mean_squared = np.multiply(expected_mean, expected_mean)
            expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                        axis=ax) / num_elements
            expected_variance = expected_x_squared - expected_mean_squared

            # Check that the moments are correct.
            self.assertAllClose(expected_mean, mean.eval())
            self.assertAllClose(expected_variance, var.eval())
Example #40
    def _subdiv_calculate_mean_and_var(self, inputs, reduction_axes,
                                       keep_dims):
        # Calculate the sums of the inputs and of their squares over the
        # reduction axes.
        net_sum = math_ops.reduce_sum(inputs,
                                      axis=reduction_axes,
                                      keepdims=keep_dims)
        squared_mean = math_ops.reduce_sum(math_ops.square(inputs),
                                           axis=reduction_axes,
                                           keepdims=keep_dims)

        if self._support_zero_size_input():
            # Keras assumes that batch dimension is the first dimension for Batch
            # Normalization.
            input_batch_size = array_ops.shape(inputs)[0]
        else:
            input_batch_size = None

        # Get the total number of params being averaged over, including the
        # (local) batch size.
        axes_vals = [(array_ops.shape_v2(inputs))[i]
                     for i in range(1, len(reduction_axes))]
        multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals),
                                   dtypes.float32)

        squared_mean = squared_mean / multiplier
        net_sum = net_sum / multiplier

        if input_batch_size is None:
            mean, variance = nn.moments(inputs,
                                        reduction_axes,
                                        keep_dims=keep_dims)
        else:
            batches_ = math_ops.cast(input_batch_size, self._param_dtype)
            mean = net_sum / batches_
            variance = squared_mean / batches_ - math_ops.square(
                array_ops.stop_gradient(mean))

        return mean, net_sum, variance, squared_mean, input_batch_size
Example #41
  def call(self, inputs, training=False):
    if self.virtual_batch_size is not None:
      # Virtual batches (aka ghost batches) can be simulated by reshaping the
      # Tensor and reusing the existing batch norm implementation
      original_shape = [-1] + inputs.shape.as_list()[1:]
      expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

      # Will cause errors if virtual_batch_size does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.virtual_batch_size is not None:
        # Currently never reaches here since fused_batch_norm does not support
        # virtual batching
        return undo_virtual_batching(outputs)
      return outputs

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.get_shape()
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
      del reduction_axes[1]     # Do not reduce along virtual batch dim

    scale, offset = self.gamma, self.beta

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
      mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = utils.smart_cond(training,
                              lambda: mean,
                              lambda: moving_mean)
      variance = utils.smart_cond(training,
                                  lambda: variance,
                                  lambda: moving_variance)

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            mean, variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        scale = array_ops.stop_gradient(r, name='renorm_r')
        offset = array_ops.stop_gradient(d, name='renorm_d')
        if self.gamma is not None:
          scale *= self.gamma
          offset *= self.gamma
        if self.beta is not None:
          offset += self.beta
      else:
        new_mean, new_variance = mean, variance

      # Update moving averages when training, and prevent updates otherwise.
      decay = utils.smart_cond(training, lambda: self.momentum, lambda: 1.)
      if self.virtual_batch_size is not None:
        # This isn't strictly correct since in ghost batch norm, you are
        # supposed to sequentially update the moving_mean and moving_variance
        # with each sub-batch. However, since the moving statistics are only
        # used during evaluation, it is more efficient to just update in one
        # step and should not make a significant difference in the result.
        new_mean = math_ops.reduce_mean(new_mean,
                                        axis=1, keep_dims=True)
        new_variance = math_ops.reduce_mean(new_variance,
                                            axis=1, keep_dims=True)

      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, new_mean, decay, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, new_variance, decay, zero_debias=False)
      if context.in_graph_mode():
        self.add_update(mean_update, inputs=inputs)
        self.add_update(variance_update, inputs=inputs)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value
    rank = len(inputs.get_shape())
    def _broadcast(v):
      if (v is not None and
          len(v.get_shape()) != rank and
          reduction_axes != list(range(ndims))[:-1]):
        return array_ops.reshape(v, broadcast_shape)
      return v

    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     _broadcast(offset),
                                     _broadcast(scale),
                                     self.epsilon)

    if self.virtual_batch_size is not None:
      return undo_virtual_batching(outputs)

    return outputs
def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               activation_fn=None,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               scope=None):
  """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
    "Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift"
    Sergey Ioffe, Christian Szegedy
  Can be used as a normalizer function for conv2d and fully_connected.
  Args:
    inputs: a tensor of size `[batch_size, height, width, channels]`
      or `[batch_size, channels]`.
    decay: decay for the moving average.
    center: If True, subtract `beta`. If False, `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: small float added to variance to avoid dividing by zero.
    activation_fn: Optional activation function.
    updates_collections: collections to collect the update ops for computation.
      If None, a control dependency would be added to make sure the updates are
      computed.
    is_training: whether or not the layer is in training mode. In training mode
      it would accumulate the statistics of the moments into `moving_mean` and
      `moving_variance` using an exponential moving average with the given
      `decay`. When it is not in training mode then it would use the values of
      the `moving_mean` and the `moving_variance`.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: optional collections for the variables.
    outputs_collections: collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for `variable_op_scope`.
  Returns:
    a tensor representing the output of the operation.
  """
  with variable_scope.variable_op_scope(
      [inputs], scope, 'BatchNorm', reuse=reuse) as sc:
    inputs_shape = inputs.get_shape()
    dtype = inputs.dtype.base_dtype
    axis = list(range(len(inputs_shape) - 1))
    params_shape = inputs_shape[-1:]
    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if center:
      beta_collections = utils.get_variable_collections(
          variables_collections, 'beta')
      beta = variables.model_variable(
          'beta',
          shape=params_shape,
          dtype=dtype,
          initializer=init_ops.zeros_initializer,
          collections=beta_collections,
          trainable=trainable)
    if scale:
      gamma_collections = utils.get_variable_collections(
          variables_collections, 'gamma')
      gamma = variables.model_variable(
          'gamma',
          shape=params_shape,
          dtype=dtype,
          initializer=init_ops.ones_initializer,
          collections=gamma_collections,
          trainable=trainable)
    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections.
    moving_mean_collections = utils.get_variable_collections(
        variables_collections, 'moving_mean')
    moving_mean = variables.model_variable(
        'moving_mean',
        shape=params_shape,
        dtype=dtype,
        initializer=init_ops.zeros_initializer,
        trainable=False,
        collections=moving_mean_collections)
    moving_variance_collections = utils.get_variable_collections(
        variables_collections, 'moving_variance')
    moving_variance = variables.model_variable(
        'moving_variance',
        shape=params_shape,
        dtype=dtype,
        initializer=init_ops.ones_initializer,
        trainable=False,
        collections=moving_variance_collections)
    if is_training:
      # Calculate the moments based on the individual batch.
      mean, variance = nn.moments(inputs, axis, shift=moving_mean)
      # Update the moving_mean and moving_variance moments.
      update_moving_mean = moving_averages.assign_moving_average(
          moving_mean, mean, decay)
      update_moving_variance = moving_averages.assign_moving_average(
          moving_variance, variance, decay)
      if updates_collections is None:
        # Make sure the updates are computed here.
        with ops.control_dependencies(
            [update_moving_mean, update_moving_variance]):
          outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon)
      else:
        # Collect the updates to be computed later.
        ops.add_to_collections(updates_collections, update_moving_mean)
        ops.add_to_collections(updates_collections, update_moving_variance)
        outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon)
    else:
      outputs = nn.batch_normalization(
          inputs, moving_mean, moving_variance, beta, gamma, epsilon)
    outputs.set_shape(inputs.get_shape())
    if activation_fn:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
Example #43
  def call(self, inputs, training=False):
    if self.num_virtual_batches > 1:
      # Virtual batches (aka ghost batches) can be simulated by using some
      # reshape/transpose tricks on top of base batch normalization.
      original_shape = [-1] + inputs.shape.as_list()[1:]
      expanded_shape = [-1, self.num_virtual_batches] + original_shape[1:]

      # Will cause errors if num_virtual_batches does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      ndims = len(expanded_shape)
      if self.axis < 0:
        axis = ndims + self.axis
      else:
        axis = self.axis + 1      # Account for the added dimension

      # Permute the num_virtual_batch dimension (dim 1) to be adjacent to axis
      # TODO(b/66257056): when multi-axis batch normalization is implemented,
      # this permutation trick and the combined_dim reshape are no longer
      # necessary and can be reworked to simply use broadcasting.
      permutation = ([0] + list(range(2, axis)) + [1, axis] +
                     list(range(axis + 1, ndims)))
      inverse_permutation = [x[1] for x in
                             sorted(zip(permutation, range(ndims)))]
      inputs = array_ops.transpose(inputs, perm=permutation)

      # Combine the axis and num_virtual_batch dimension in order to take
      # advantage of fused batch normalization
      combined_dim = expanded_shape[1] * expanded_shape[axis]
      perm_shape = [-1] + inputs.shape.as_list()[1:]
      combined_shape = (perm_shape[:axis - 1] +
                        [combined_dim] +
                        perm_shape[axis + 1:])
      inputs = array_ops.reshape(inputs, combined_shape)
      # After the above reshape, the batch norm axis is the original self.axis

      # Undoes the reshaping and transposing tricks done above
      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, perm_shape)
        outputs = array_ops.transpose(outputs, perm=inverse_permutation)
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.num_virtual_batches > 1:
        return undo_virtual_batching(outputs)
      return outputs

    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    scale, offset = self.gamma, self.beta

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      mean, variance = nn.moments(inputs, reduction_axes)
      mean = _smart_select(training,
                           lambda: mean,
                           lambda: self.moving_mean)
      variance = _smart_select(training,
                               lambda: variance,
                               lambda: self.moving_variance)

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            mean, variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        scale = array_ops.stop_gradient(r, name='renorm_r')
        offset = array_ops.stop_gradient(d, name='renorm_d')
        if self.gamma is not None:
          scale *= self.gamma
          offset *= self.gamma
        if self.beta is not None:
          offset += self.beta
      else:
        new_mean, new_variance = mean, variance

      # Update moving averages when training, and prevent updates otherwise.
      decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, new_mean, decay, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, new_variance, decay, zero_debias=False)
      if context.in_graph_mode():
        self.add_update(mean_update, inputs=inputs)
        self.add_update(variance_update, inputs=inputs)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    def _broadcast(v):
      if needs_broadcasting and v is not None:
        # In this case we must explicitly broadcast all parameters.
        return array_ops.reshape(v, broadcast_shape)
      return v

    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     _broadcast(offset),
                                     _broadcast(scale),
                                     self.epsilon)

    if self.num_virtual_batches > 1:
      return undo_virtual_batching(outputs)

    return outputs
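
A quick aside, not part of the snippet above: the ghost-batch reshape/transpose/reshape round trip can be checked in isolation. The sketch below assumes a 4-D NHWC input, axis=-1 and num_virtual_batches=2; names and shapes are purely illustrative, and it only verifies that undo_virtual_batching recovers the original layout.

import numpy as np

num_virtual_batches = 2
x = np.arange(4 * 4 * 4 * 3, dtype=np.float32).reshape(4, 4, 4, 3)  # [batch, H, W, C]

original_shape = [-1] + list(x.shape[1:])                        # [-1, 4, 4, 3]
expanded_shape = [-1, num_virtual_batches] + original_shape[1:]  # [-1, 2, 4, 4, 3]
x_v = x.reshape(expanded_shape)

ndims = len(expanded_shape)
axis = ndims - 1                                                 # self.axis == -1
permutation = ([0] + list(range(2, axis)) + [1, axis] +
               list(range(axis + 1, ndims)))
inverse_permutation = [p[1] for p in sorted(zip(permutation, range(ndims)))]

x_p = x_v.transpose(permutation)                  # virtual-batch dim next to channels
combined_dim = num_virtual_batches * expanded_shape[axis]
perm_shape = [-1] + list(x_p.shape[1:])
combined_shape = perm_shape[:axis - 1] + [combined_dim] + perm_shape[axis + 1:]
x_c = x_p.reshape(combined_shape)                 # what the fused batch norm would see

# Undo the virtual batching and confirm the original tensor is recovered.
x_back = x_c.reshape(perm_shape).transpose(inverse_permutation).reshape(original_shape)
assert np.array_equal(x_back, x)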
Esempio n. 44
0
  def call(self, inputs, training=False):
    if self.fused:
      return self._fused_batch_norm(inputs, training=training)

    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    scale, offset = self.gamma, self.beta

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
      # Some of the computations here are not necessary when training evaluates
      # to False at runtime but `training` is not a constant; keeping them
      # unconditional makes the code simpler.
      mean, variance = nn.moments(inputs, reduction_axes)
      mean = _smart_select(training,
                           lambda: mean,
                           lambda: self.moving_mean)
      variance = _smart_select(training,
                               lambda: variance,
                               lambda: self.moving_variance)

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            mean, variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        scale = array_ops.stop_gradient(r, name='renorm_r')
        offset = array_ops.stop_gradient(d, name='renorm_d')
        if self.gamma is not None:
          scale *= self.gamma
          offset *= self.gamma
        if self.beta is not None:
          offset += self.beta
      else:
        new_mean, new_variance = mean, variance

      # Update moving averages when training, and prevent updates otherwise.
      decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, new_mean, decay, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, new_variance, decay, zero_debias=False)

      self.add_update(mean_update, inputs=inputs)
      self.add_update(variance_update, inputs=inputs)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    def _broadcast(v):
      if needs_broadcasting and v is not None:
        # In this case we must explicitly broadcast all parameters.
        return array_ops.reshape(v, broadcast_shape)
      return v

    return nn.batch_normalization(inputs,
                                  _broadcast(mean),
                                  _broadcast(variance),
                                  _broadcast(offset),
                                  _broadcast(scale),
                                  self.epsilon)
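
Both call() variants above push their batch statistics through moving_averages.assign_moving_average with zero_debias=False. As a framework-free reminder of the arithmetic those update ops perform (a sketch only, not the TensorFlow op itself):

import numpy as np

def update_moving_average(moving, value, momentum):
    # assign_moving_average(var, value, decay, zero_debias=False) performs
    #   var <- var - (1 - decay) * (var - value)
    # i.e. the usual exponential moving average decay * var + (1 - decay) * value.
    return moving - (1.0 - momentum) * (moving - value)

moving_mean = np.zeros(3, dtype=np.float32)
batch_mean = np.array([0.5, -1.0, 2.0], dtype=np.float32)
for _ in range(50):
    moving_mean = update_moving_average(moving_mean, batch_mean, momentum=0.9)
print(moving_mean)  # approaches batch_mean as updates accumulate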
Esempio n. 45
0
def batch_norm_mine_old(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               activation_fn=None,
               param_initializers=None,
               param_regularizers=None,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               batch_weights=None,
               fused=False,
               data_format=DATA_FORMAT_NHWC,
               zero_debias_moving_mean=False,
               scope=None,
               renorm=False,
               renorm_clipping=None,
               renorm_decay=0.99):
  """
  This earlier version of my modification to batch norm uses
current_mean and current_variance if is_training is True and
moving_mean and moving_variance otherwise. This was leading a large divergence between
the results depending upon whether the is_training set to True or not.

I think ideally it should always use moving_mean and moving_variance. batch_norm_mine
does this.

  Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
copy of tensorflow.contrib.layers
  Args:
    inputs: A tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC` and the second dimension if `data_format` is
      `NCHW`.
    decay: Decay for the moving average. Reasonable values for `decay` are close
      to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc.
      Lower `decay` value (recommend trying `decay`=0.9) if model experiences
      reasonably good training performance but poor validation and/or test
      performance. Try zero_debias_moving_mean=True for improved stability.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: Small float added to variance to avoid dividing by zero.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: Optional initializers for beta, gamma, moving mean and
      moving variance.
    param_regularizers: Optional regularizer for beta and gamma.
    updates_collections: Collections to collect the update ops for computation.
      The updates_ops need to be executed with the train_op.
      If None, a control dependency would be added to make sure the updates are
      computed in place.
    is_training: Whether or not the layer is in training mode. In training mode
      it would accumulate the statistics of the moments into `moving_mean` and
      `moving_variance` using an exponential moving average with the given
      `decay`. When it is not in training mode then it would use the values of
      the `moving_mean` and the `moving_variance`.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional collections for the variables.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    batch_weights: An optional tensor of shape `[batch_size]`,
      containing a frequency weight for each batch item. If present,
      then the batch normalization uses weighted mean and
      variance. (This can be used to correct for bias in training
      example selection.)
    fused:  Use nn.fused_batch_norm if True, nn.batch_normalization otherwise.
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
      pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
    scope: Optional scope for `variable_scope`.
    renorm: Whether to use Batch Renormalization
      (https://arxiv.org/abs/1702.03275). This adds extra variables during
      training. The inference is the same for either value of this parameter.
    renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
      scalar `Tensors` used to clip the renorm correction. The correction
      `(r, d)` is used as `corrected_value = normalized_value * r + d`, with
      `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
      dmax are set to inf, 0, inf, respectively.
    renorm_decay: Momentum used to update the moving means and standard
      deviations with renorm. Unlike `momentum`, this affects training
      and should be neither too small (which would add noise) nor too large
      (which would give stale estimates). Note that `decay` is still applied
      to get the means and variances for inference.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: If `batch_weights` is not None and `fused` is True.
    ValueError: If `param_regularizers` is not None and `fused` is True.
    ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: If the rank of `inputs` is undefined.
    ValueError: If rank or channels dimension of `inputs` is undefined.
  """
  if fused:
    if batch_weights is not None:
      raise ValueError('Weighted mean and variance is not currently '
                       'supported for fused batch norm.')
    if param_regularizers is not None:
      raise ValueError('Regularizers are not currently '
                       'supported for fused batch norm.')
    if renorm:
      raise ValueError('Renorm is not supported for fused batch norm.')
    return _fused_batch_norm(
        inputs,
        decay=decay,
        center=center,
        scale=scale,
        epsilon=epsilon,
        activation_fn=activation_fn,
        param_initializers=param_initializers,
        updates_collections=updates_collections,
        is_training=is_training,
        reuse=reuse,
        variables_collections=variables_collections,
        outputs_collections=outputs_collections,
        trainable=trainable,
        data_format=data_format,
        zero_debias_moving_mean=zero_debias_moving_mean,
        scope=scope)

  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')

  layer_variable_getter = _build_variable_getter()
  with variable_scope.variable_scope(
      scope, 'BatchNorm', [inputs], reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    inputs = ops.convert_to_tensor(inputs)

    # Determine whether we can use the core layer class.
    if (batch_weights is None and
        updates_collections is ops.GraphKeys.UPDATE_OPS and
        not zero_debias_moving_mean):
      # Use the core layer class.
      axis = 1 if data_format == DATA_FORMAT_NCHW else -1
      if not param_initializers:
        param_initializers = {}
      beta_initializer = param_initializers.get('beta',
                                                init_ops.zeros_initializer())
      gamma_initializer = param_initializers.get('gamma',
                                                 init_ops.ones_initializer())
      moving_mean_initializer = param_initializers.get(
          'moving_mean', init_ops.zeros_initializer())
      moving_variance_initializer = param_initializers.get(
          'moving_variance', init_ops.ones_initializer())
      if not param_regularizers:
        param_regularizers = {}
      beta_regularizer = param_regularizers.get('beta')
      gamma_regularizer = param_regularizers.get('gamma')
      layer = normalization_layers.BatchNormalization(
          axis=axis,
          momentum=decay,
          epsilon=epsilon,
          center=center,
          scale=scale,
          beta_initializer=beta_initializer,
          gamma_initializer=gamma_initializer,
          moving_mean_initializer=moving_mean_initializer,
          moving_variance_initializer=moving_variance_initializer,
          beta_regularizer=beta_regularizer,
          gamma_regularizer=gamma_regularizer,
          trainable=trainable,
          renorm=renorm,
          renorm_clipping=renorm_clipping,
          renorm_momentum=renorm_decay,
          name=sc.name,
          _scope=sc,
          _reuse=reuse)
      outputs = layer.apply(inputs, training=is_training)

      # Add variables to collections.
      _add_variable_to_collections(
          layer.moving_mean, variables_collections, 'moving_mean')
      _add_variable_to_collections(
          layer.moving_variance, variables_collections, 'moving_variance')
      if layer.beta is not None:
        _add_variable_to_collections(layer.beta, variables_collections, 'beta')
      if layer.gamma is not None:
        _add_variable_to_collections(
            layer.gamma, variables_collections, 'gamma')

      if activation_fn is not None:
        outputs = activation_fn(outputs)
      return utils.collect_named_outputs(outputs_collections,
                                         sc.original_name_scope, outputs)

    # Not supported by layer class: batch_weights argument,
    # and custom updates_collections. In that case, use the legacy BN
    # implementation.
    # Custom updates collections are not supported because the update logic
    # is different in this case, in particular w.r.t. "forced updates" and
    # update op reuse.
    if renorm:
      raise ValueError('renorm is not supported with batch_weights, '
                       'updates_collections or zero_debias_moving_mean')
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    if batch_weights is not None:
      batch_weights = ops.convert_to_tensor(batch_weights)
      inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape())
      # Reshape batch weight values so they broadcast across inputs.
      nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
      batch_weights = array_ops.reshape(batch_weights, nshape)

    if data_format == DATA_FORMAT_NCHW:
      moments_axes = [0] + list(range(2, inputs_rank))
      params_shape = inputs_shape[1:2]
      # For NCHW format, rather than relying on implicit broadcasting, we
      # explicitly reshape the params to params_shape_broadcast when computing
      # the moments and the batch normalization.
      params_shape_broadcast = list(
          [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
    else:
      moments_axes = list(range(inputs_rank - 1))
      params_shape = inputs_shape[-1:]
      params_shape_broadcast = None
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined channels dimension %s.' % (
          inputs.name, params_shape))

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if not param_initializers:
      param_initializers = {}
    if center:
      beta_collections = utils.get_variable_collections(variables_collections,
                                                        'beta')
      beta_initializer = param_initializers.get('beta',
                                                init_ops.zeros_initializer())
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=beta_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
    if scale:
      gamma_collections = utils.get_variable_collections(variables_collections,
                                                         'gamma')
      gamma_initializer = param_initializers.get('gamma',
                                                 init_ops.ones_initializer())
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=gamma_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)

    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections. We disable variable partitioning while creating
    # them, because assign_moving_average is not yet supported for partitioned
    # variables.
    partitioner = variable_scope.get_variable_scope().partitioner
    try:
      variable_scope.get_variable_scope().set_partitioner(None)
      moving_mean_collections = utils.get_variable_collections(
          variables_collections, 'moving_mean')
      moving_mean_initializer = param_initializers.get(
          'moving_mean', init_ops.zeros_initializer())
      moving_mean = variables.model_variable(
          'moving_mean',
          shape=params_shape,
          dtype=dtype,
          initializer=moving_mean_initializer,
          trainable=False,
          collections=moving_mean_collections)
      moving_variance_collections = utils.get_variable_collections(
          variables_collections, 'moving_variance')
      moving_variance_initializer = param_initializers.get(
          'moving_variance', init_ops.ones_initializer())
      moving_variance = variables.model_variable(
          'moving_variance',
          shape=params_shape,
          dtype=dtype,
          initializer=moving_variance_initializer,
          trainable=False,
          collections=moving_variance_collections)
    finally:
      variable_scope.get_variable_scope().set_partitioner(partitioner)

    # If `is_training` doesn't have a constant value (because it is a `Tensor`,
    # a `Variable` or a `Placeholder`), then is_training_value will be None and
    # `need_moments` will be True.
    is_training_value = utils.constant_value(is_training)
    need_moments = is_training_value is None or is_training_value
    if need_moments:
      # Calculate the moments based on the individual batch.
      if batch_weights is None:
        if data_format == DATA_FORMAT_NCHW:
          mean, _ = nn.moments(inputs, moments_axes, keep_dims=True)
          variance,_ = nn.moments( (inputs-moving_mean)**2, moments_axes, keep_dims=True)
          mean = array_ops.reshape(mean, [-1])
          variance = array_ops.reshape(variance, [-1])
        else:
          mean, _ = nn.moments(inputs, moments_axes)
          variance, _ = nn.moments( (inputs-moving_mean)**2, moments_axes)
      else:
        if data_format == DATA_FORMAT_NCHW:
          mean, _ = nn.weighted_moments(inputs, moments_axes,
                                               batch_weights, keep_dims=True)
          variance, _ = nn.weighted_moments( (inputs-moving_mean)**2, moments_axes,
                                               batch_weights, keep_dims=True)
          mean = array_ops.reshape(mean, [-1])
          variance = array_ops.reshape(variance, [-1])
        else:
          mean, _ = nn.weighted_moments(inputs, moments_axes,
                                               batch_weights)
          variance, _ = nn.weighted_moments( (inputs-moving_mean)**2, moments_axes,
                                               batch_weights)

      moving_vars_fn = lambda: (moving_mean, moving_variance)
      if updates_collections is None:
        def _force_updates():
          """Internal function forces updates moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          with ops.control_dependencies([update_moving_mean,
                                         update_moving_variance]):
            return array_ops.identity(mean), array_ops.identity(variance)
        mean, variance = utils.smart_cond(is_training,
                                          _force_updates,
                                          moving_vars_fn)
      else:
        def _delay_updates():
          """Internal function that delay updates moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          return update_moving_mean, update_moving_variance

        update_mean, update_variance = utils.smart_cond(is_training,
                                                        _delay_updates,
                                                        moving_vars_fn)
        ops.add_to_collections(updates_collections, update_mean)
        ops.add_to_collections(updates_collections, update_variance)
        # Use computed moments during training and moving_vars otherwise.
        vars_fn = lambda: (mean, variance)
        mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)
    else:
      mean, variance = moving_mean, moving_variance
    if data_format == DATA_FORMAT_NCHW:
      mean = array_ops.reshape(mean, params_shape_broadcast)
      variance = array_ops.reshape(variance, params_shape_broadcast)
      beta = array_ops.reshape(beta, params_shape_broadcast)
      if gamma is not None:
        gamma = array_ops.reshape(gamma, params_shape_broadcast)

    # Compute batch_normalization.
    outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                     epsilon)
    outputs.set_shape(inputs_shape)
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)
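
Since batch_norm_mine_old (like the contrib layer it copies) only adds its moving-average assignments to updates_collections by default, a caller still has to run those ops. Below is a minimal TF 1.x usage sketch of that pattern, using the stock tf.contrib.layers.batch_norm as a stand-in and a toy graph that is purely illustrative:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
is_training = tf.placeholder(tf.bool, [])

h = tf.contrib.layers.batch_norm(x, is_training=is_training, scale=True)
loss = tf.reduce_mean(tf.square(h))

# The moving_mean / moving_variance assignments were only collected into
# tf.GraphKeys.UPDATE_OPS; attach them to the train op so they actually run.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)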
Esempio n. 46
0
 def _moments(self, inputs, reduction_axes, keep_dims):
   return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
Esempio n. 47
0
  def call(self, inputs, training=False):
    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    # Determine a boolean value for `training`: may be True, False, or None.
    # If None, it is assumed that `training` is a variable to be used in `cond`.
    if isinstance(training, bool):
      training_bool = training
    else:
      try:
        training_bool = tensor_util.constant_value(training)
      except TypeError:
        training_bool = None

    # Obtain the current batch mean and variance, if necessary.
    if training_bool is not False:
      # Use a copy of moving_mean as a shift to compute more reliable moments.
      shift = math_ops.add(self.moving_mean, 0)
      if needs_broadcasting:
        shift = array_ops.reshape(shift, broadcast_shape)
        broadcast_mean, broadcast_variance = nn.moments(
            inputs, reduction_axes, shift=shift, keep_dims=True)
        mean = array_ops.reshape(broadcast_mean, [-1])
        variance = array_ops.reshape(broadcast_variance, [-1])
      else:
        mean, variance = nn.moments(inputs, reduction_axes, shift=shift)

    # Prepare updates if necessary.
    if training_bool is not False and not self.updates:
      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, mean, self.momentum, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, variance, self.momentum, zero_debias=False)
      # In the future this should be refactored into a self.add_update
      # method in order to allow for instance-based BN layer sharing
      # across unrelated input streams (e.g. like in Keras).
      self.updates.append(mean_update)
      self.updates.append(variance_update)

    # Normalize batch.
    if needs_broadcasting:
      # In this case we must explicitly broadcast all parameters.
      broadcast_moving_mean = array_ops.reshape(self.moving_mean,
                                                broadcast_shape)
      broadcast_moving_variance = array_ops.reshape(self.moving_variance,
                                                    broadcast_shape)
      if self.center:
        broadcast_beta = array_ops.reshape(self.beta, broadcast_shape)
      else:
        broadcast_beta = None
      if self.scale:
        broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape)
      else:
        broadcast_gamma = None

      if training_bool is not False:
        normed_inputs_training = nn.batch_normalization(inputs,
                                                        broadcast_mean,
                                                        broadcast_variance,
                                                        broadcast_beta,
                                                        broadcast_gamma,
                                                        self.epsilon)
      normed_inputs = nn.batch_normalization(inputs,
                                             broadcast_moving_mean,
                                             broadcast_moving_variance,
                                             broadcast_beta,
                                             broadcast_gamma,
                                             self.epsilon)
    else:
      # No need for broadcasting.
      if training_bool is not False:
        normed_inputs_training = nn.batch_normalization(
            inputs,
            mean,
            variance,
            self.beta if self.center else None,
            self.gamma if self.scale else None,
            self.epsilon)
      normed_inputs = nn.batch_normalization(inputs,
                                             self.moving_mean,
                                             self.moving_variance,
                                             self.beta if self.center else None,
                                             self.gamma if self.scale else None,
                                             self.epsilon)

    # Return the proper output depending on the boolean training phase.
    if training_bool is True:
      return normed_inputs_training
    if training_bool is False:
      return normed_inputs
    return control_flow_ops.cond(training,
                                 lambda: normed_inputs_training,
                                 lambda: normed_inputs)
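
The shift trick in the example above ("use a copy of moving_mean as a shift to compute more reliable moments") is about float precision: computing E[x^2] - E[x]^2 directly suffers catastrophic cancellation when the data has a large mean. A small NumPy illustration of the effect (values are arbitrary; this is roughly what a shifted one-pass moment computation avoids):

import numpy as np

x = (1e4 + np.random.randn(10000)).astype(np.float32)  # large mean, unit variance

# Naive one-pass variance: E[x^2] - E[x]^2 loses almost all precision in float32.
naive_var = np.mean(x * x) - np.mean(x) ** 2

# Shifted one-pass variance: subtract an approximate mean (e.g. the moving mean)
# before accumulating, so the accumulated squares stay small.
shift = np.float32(1e4)
d = x - shift
shifted_var = np.mean(d * d) - np.mean(d) ** 2

print(naive_var, shifted_var)  # shifted_var is close to 1.0, naive_var is not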
Esempio n. 48
0
def group_norm(inputs,
               groups=32,
               channels_axis=-1,
               reduction_axes=(-3, -2),
               center=True,
               scale=True,
               epsilon=1e-6,
               activation_fn=None,
               param_initializers=None,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               scope=None,
               mean_close_to_zero=False):
  """Functional interface for the group normalization layer.

  Reference: https://arxiv.org/abs/1803.08494.

    "Group Normalization", Yuxin Wu, Kaiming He

  Args:
    inputs: A Tensor with at least 2 dimensions one which is channels. All
     shape dimensions must be fully defined.
    groups: Integer. Divide the channels into this number of groups over which
      normalization statistics are computed. This number must be commensurate
      with the number of channels in `inputs`.
    channels_axis: An integer. Specifies the index of the channels axis, which
      will be broken into `groups`, each of which has its statistics computed
      across the `reduction_axes`. Must be mutually exclusive with
      `reduction_axes`. Preferred usage is to specify negative integers to be
      agnostic as to whether a batch dimension is included.
    reduction_axes: Tuple of integers. Specifies dimensions over which
      statistics will be accumulated. Must be mutually exclusive with
      `channels_axis`. Statistics will not be accumulated across axes not
      specified in `reduction_axes` nor `channels_axis`. Preferred usage is to
      specify negative integers to be agnostic to whether a batch dimension is
      included.

      Some sample usage cases:
        NHWC format: channels_axis=-1, reduction_axes=[-3, -2]
        NCHW format: channels_axis=-3, reduction_axes=[-2, -1]

    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: Small float added to variance to avoid dividing by zero.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: Optional initializers for beta, gamma, moving mean and
      moving variance.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional collections for the variables.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    scope: Optional scope for `variable_scope`.
    mean_close_to_zero: The mean of `input` before ReLU will be close to zero
      when batch size >= 4k for Resnet-50 on TPU. If `True`, use
      `nn.sufficient_statistics` and `nn.normalize_moments` to calculate the
      variance. This is the same behavior as `fused` equals `True` in batch
      normalization. If `False`, use `nn.moments` to calculate the variance.
      When `mean` is close to zero, like 1e-4, using `mean` to calculate the
      variance may give poor results due to repeated roundoff error and
      denormalization in `mean`. When `mean` is large, like 1e2, sum(`input`^2)
      is so large that only the high-order digits of the elements are
      accumulated. Thus, using sum((`input` - `mean`)^2)/n to calculate the
      variance is more accurate than (sum(`input`^2)/n - `mean`^2) when `mean`
      is large.


  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: If the rank of `inputs` is undefined.
    ValueError: If rank or channels dimension of `inputs` is undefined.
    ValueError: If number of groups is not commensurate with number of channels.
    ValueError: If reduction_axes or channels_axis are out of bounds.
    ValueError: If reduction_axes are not mutually exclusive with channels_axis.
  """
  # TODO(shlens): Support partially defined shapes for the inputs.
  inputs = ops.convert_to_tensor(inputs)
  original_shape = inputs.shape

  if inputs.shape.ndims is None:
    raise ValueError('Inputs %s has undefined rank.' % inputs.name)
  if channels_axis > (inputs.shape.ndims - 1):
    raise ValueError('Axis is out of bounds.')

  # Standardize the channels_axis to be positive and identify # of channels.
  if channels_axis < 0:
    channels_axis = inputs.shape.ndims + channels_axis
  channels = inputs.shape[channels_axis].value

  if channels is None:
    raise ValueError('Inputs %s has undefined channel dimension: %d.' % (
        inputs.name, channels_axis))

  # Standardize the reduction_axes to be positive.
  reduction_axes = list(reduction_axes)
  for i in range(len(reduction_axes)):
    if reduction_axes[i] < 0:
      reduction_axes[i] += inputs.shape.ndims

  for a in reduction_axes:
    if a > inputs.shape.ndims - 1:
      raise ValueError('Axis is out of bounds.')
    if inputs.shape[a].value is None:
      raise ValueError('Inputs %s has undefined dimensions %d.' % (
          inputs.name, a))
    if channels_axis == a:
      raise ValueError('reduction_axis must be mutually exclusive '
                       'with channels_axis')
  if groups > channels:
    raise ValueError('Invalid groups %d for %d channels.' % (groups, channels))
  if channels % groups != 0:
    raise ValueError('%d channels is not commensurate with %d groups.' %
                     (channels, groups))

  # Determine axes before channels. Some examples of common image formats:
  #  'NCHW': before = [N], after = [HW]
  #  'NHWC': before = [NHW], after = []
  axes_before_channels = inputs.shape.as_list()[:channels_axis]
  axes_after_channels = inputs.shape.as_list()[channels_axis+1:]

  # Manually broadcast the parameters to conform to the number of groups.
  params_shape_broadcast = ([1] * len(axes_before_channels) +
                            [groups, channels // groups] +
                            [1] * len(axes_after_channels))

  # Reshape the input by the group within the channel dimension.
  inputs_shape = (axes_before_channels + [groups, channels // groups] +
                  axes_after_channels)
  inputs = array_ops.reshape(inputs, inputs_shape)

  # Determine the dimensions across which moments are calculated.
  moments_axes = [channels_axis + 1]
  for a in reduction_axes:
    if a > channels_axis:
      moments_axes.append(a + 1)
    else:
      moments_axes.append(a)

  with variable_scope.variable_scope(
      scope, 'GroupNorm', [inputs], reuse=reuse) as sc:
    # Note that params_shape is always the number of channels.
    params_shape = [channels]

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    dtype = inputs.dtype.base_dtype
    if param_initializers is None:
      param_initializers = {}
    if center:
      beta_collections = utils.get_variable_collections(
          variables_collections, 'beta')
      beta_initializer = param_initializers.get(
          'beta', init_ops.zeros_initializer())
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=beta_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
      beta = array_ops.reshape(beta, params_shape_broadcast)

    if scale:
      gamma_collections = utils.get_variable_collections(
          variables_collections, 'gamma')
      gamma_initializer = param_initializers.get(
          'gamma', init_ops.ones_initializer())
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=gamma_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)
      gamma = array_ops.reshape(gamma, params_shape_broadcast)

    # Calculate the moments.
    if mean_close_to_zero:
      # One pass algorithm returns better result when mean is close to zero.
      counts, means_ss, variance_ss, _ = nn.sufficient_statistics(
          inputs, moments_axes, keep_dims=True)
      mean, variance = nn.normalize_moments(
          counts, means_ss, variance_ss, shift=None)
    else:
      mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)

    # Compute normalization.
    # TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor
    # appropriately so that this operation may be faster.
    gain = math_ops.rsqrt(variance + epsilon)
    offset = -mean * gain
    if gamma is not None:
      gain *= gamma
      offset *= gamma
    if beta is not None:
      offset += beta
    outputs = inputs * gain + offset

    # Collapse the groups into the channel dimension.
    outputs = array_ops.reshape(outputs, original_shape)

    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
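
For reference, the group-norm computation above reduces, in the default NHWC case (channels_axis=-1, reduction_axes=(-3, -2)), to the compact NumPy sketch below; gamma/beta and the mean_close_to_zero variance path are omitted, and the helper name group_norm_ref is purely illustrative.

import numpy as np

def group_norm_ref(x, groups, epsilon=1e-6):
    n, h, w, c = x.shape
    assert c % groups == 0
    xg = x.reshape(n, h, w, groups, c // groups)
    # Statistics per (sample, group), reduced over H, W and the within-group channels.
    mean = xg.mean(axis=(1, 2, 4), keepdims=True)
    var = xg.var(axis=(1, 2, 4), keepdims=True)
    out = (xg - mean) / np.sqrt(var + epsilon)
    return out.reshape(n, h, w, c)

y = group_norm_ref(np.random.randn(2, 8, 8, 32).astype(np.float32), groups=4)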
Esempio n. 49
0
def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               scope=None):
  """Code modification of tensorflow/contrib/layers/python/layers/layers.py
  """
  with variable_scope.variable_op_scope([inputs],
                                        scope, 'BatchNorm', reuse=reuse) as sc:
    inputs = ops.convert_to_tensor(inputs)
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    axis = list(range(inputs_rank - 1))
    params_shape = inputs_shape[-1:]
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined last dimension %s.' % (
          inputs.name, params_shape))
    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if center:
      beta_collections = utils.get_variable_collections(variables_collections,
                                                        'beta')
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=init_ops.zeros_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
    if scale:
      gamma_collections = utils.get_variable_collections(variables_collections,
                                                         'gamma')
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=init_ops.ones_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)
    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections.
    moving_mean_collections = utils.get_variable_collections(
        variables_collections, 'moving_mean')
    moving_mean = variables.model_variable(
        'moving_mean',
        shape=params_shape,
        dtype=dtype,
        initializer=init_ops.zeros_initializer,
        trainable=False,
        collections=moving_mean_collections)
    moving_variance_collections = utils.get_variable_collections(
        variables_collections, 'moving_variance')
    moving_variance = variables.model_variable(
        'moving_variance',
        shape=params_shape,
        dtype=dtype,
        initializer=init_ops.ones_initializer,
        trainable=False,
        collections=moving_variance_collections)

    # Calculate the moments based on the individual batch.
    mean, variance = nn.moments(inputs, axis, shift=moving_mean)
    # Update the moving_mean and moving_variance moments.
    update_moving_mean = moving_averages.assign_moving_average(
        moving_mean, mean, decay)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, decay)
    if updates_collections is None:
      # Make sure the updates are computed here.
      with ops.control_dependencies([update_moving_mean,
                                      update_moving_variance]):
        outputs = nn.batch_normalization(
            inputs, mean, variance, beta, gamma, epsilon)
    else:
      # Collect the updates to be computed later.
      ops.add_to_collections(updates_collections, update_moving_mean)
      ops.add_to_collections(updates_collections, update_moving_variance)
      outputs = nn.batch_normalization(
          inputs, mean, variance, beta, gamma, epsilon)

    test_outputs = nn.batch_normalization(
        inputs, moving_mean, moving_variance, beta, gamma, epsilon)

    outputs = tf.cond(is_training, lambda: outputs, lambda: test_outputs)
    outputs.set_shape(inputs_shape)

    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
Esempio n. 50
0
def instance_norm(inputs,
                  center=True,
                  scale=True,
                  epsilon=1e-6,
                  activation_fn=None,
                  param_initializers=None,
                  reuse=None,
                  variables_collections=None,
                  outputs_collections=None,
                  trainable=True,
                  data_format=DATA_FORMAT_NHWC,
                  scope=None):
  """Functional interface for the instance normalization layer.

  Reference: https://arxiv.org/abs/1607.08022.

    "Instance Normalization: The Missing Ingredient for Fast Stylization"
    Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky

  Args:
    inputs: A tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC` and the second dimension if `data_format` is
      `NCHW`.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: Small float added to variance to avoid dividing by zero.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: Optional initializers for beta, gamma, moving mean and
      moving variance.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional collections for the variables.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    scope: Optional scope for `variable_scope`.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: If the rank of `inputs` is undefined.
    ValueError: If rank or channels dimension of `inputs` is undefined.
  """
  inputs = ops.convert_to_tensor(inputs)
  inputs_shape = inputs.shape
  inputs_rank = inputs.shape.ndims

  if inputs_rank is None:
    raise ValueError('Inputs %s has undefined rank.' % inputs.name)
  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')

  with variable_scope.variable_scope(
      scope, 'InstanceNorm', [inputs], reuse=reuse) as sc:
    if data_format == DATA_FORMAT_NCHW:
      reduction_axis = 1
      # For NCHW format, rather than relying on implicit broadcasting, we
      # explicitly reshape the params to params_shape_broadcast when computing
      # the moments and the batch normalization.
      params_shape_broadcast = list(
          [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
    else:
      reduction_axis = inputs_rank - 1
      params_shape_broadcast = None
    moments_axes = list(range(inputs_rank))
    del moments_axes[reduction_axis]
    del moments_axes[0]
    params_shape = inputs_shape[reduction_axis:reduction_axis + 1]
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined channels dimension %s.' % (
          inputs.name, params_shape))

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    dtype = inputs.dtype.base_dtype
    if param_initializers is None:
      param_initializers = {}
    if center:
      beta_collections = utils.get_variable_collections(
          variables_collections, 'beta')
      beta_initializer = param_initializers.get(
          'beta', init_ops.zeros_initializer())
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=beta_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
      if params_shape_broadcast:
        beta = array_ops.reshape(beta, params_shape_broadcast)
    if scale:
      gamma_collections = utils.get_variable_collections(
          variables_collections, 'gamma')
      gamma_initializer = param_initializers.get(
          'gamma', init_ops.ones_initializer())
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=gamma_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)
      if params_shape_broadcast:
        gamma = array_ops.reshape(gamma, params_shape_broadcast)

    # Calculate the moments (instance activations).
    mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)

    # Compute instance normalization.
    outputs = nn.batch_normalization(
        inputs, mean, variance, beta, gamma, epsilon, name='instancenorm')
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
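
The NHWC branch of instance_norm above boils down to normalizing each (sample, channel) slice over its spatial positions. A minimal NumPy reference of just that computation (beta/gamma omitted; the name instance_norm_ref is illustrative):

import numpy as np

def instance_norm_ref(x, epsilon=1e-6):
    # x: [N, H, W, C]; moments over the H and W axes only.
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

y = instance_norm_ref(np.random.randn(2, 8, 8, 3).astype(np.float32))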
Esempio n. 51
0
  def call(self, inputs, training=False):
    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)

    if needs_broadcasting:
      # In this case we must explicitly broadcast all parameters.
      if self.center:
        broadcast_beta = array_ops.reshape(self.beta, broadcast_shape)
      else:
        broadcast_beta = None
      if self.scale:
        broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape)
      else:
        broadcast_gamma = None

    if training_value is not False:
      if needs_broadcasting:
        broadcast_mean, broadcast_variance = nn.moments(
            inputs, reduction_axes, keep_dims=True)
        mean = array_ops.reshape(broadcast_mean, [-1])
        variance = array_ops.reshape(broadcast_variance, [-1])
      else:
        mean, variance = nn.moments(inputs, reduction_axes)

      # Prepare updates if necessary.
      if not self.updates:
        mean_update = moving_averages.assign_moving_average(
            self.moving_mean, mean, self.momentum, zero_debias=False)
        variance_update = moving_averages.assign_moving_average(
            self.moving_variance, variance, self.momentum, zero_debias=False)
        # In the future this should be refactored into a self.add_update
        # method in order to allow for instance-based BN layer sharing
        # across unrelated input streams (e.g. like in Keras).
        self.updates.append(mean_update)
        self.updates.append(variance_update)

    # Normalize batch. We do this inside separate functions for training
    # and inference so as to avoid evaluating both branches.
    def normalize_in_test():
      if needs_broadcasting:
        broadcast_moving_mean = array_ops.reshape(self.moving_mean,
                                                  broadcast_shape)
        broadcast_moving_variance = array_ops.reshape(self.moving_variance,
                                                      broadcast_shape)
        return nn.batch_normalization(inputs,
                                      broadcast_moving_mean,
                                      broadcast_moving_variance,
                                      broadcast_beta,
                                      broadcast_gamma,
                                      self.epsilon)
      else:
        return nn.batch_normalization(inputs,
                                      self.moving_mean,
                                      self.moving_variance,
                                      self.beta if self.center else None,
                                      self.gamma if self.scale else None,
                                      self.epsilon)

    def normalize_in_training():
      if needs_broadcasting:
        return nn.batch_normalization(inputs,
                                      broadcast_mean,
                                      broadcast_variance,
                                      broadcast_beta,
                                      broadcast_gamma,
                                      self.epsilon)
      else:
        return nn.batch_normalization(inputs,
                                      mean,
                                      variance,
                                      self.beta if self.center else None,
                                      self.gamma if self.scale else None,
                                      self.epsilon)

    return utils.smart_cond(training,
                            normalize_in_training,
                            normalize_in_test)
Esempio n. 52
0
  def call(self, inputs, training=None):
    original_training_value = training
    if training is None:
      training = K.learning_phase()

    in_eager_mode = context.executing_eagerly()
    if self.virtual_batch_size is not None:
      # Virtual batches (aka ghost batches) can be simulated by reshaping the
      # Tensor and reusing the existing batch norm implementation
      original_shape = [-1] + inputs.shape.as_list()[1:]
      expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

      # Will cause errors if virtual_batch_size does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.virtual_batch_size is not None:
        # Currently never reaches here since fused_batch_norm does not support
        # virtual batching
        outputs = undo_virtual_batching(outputs)
      if not context.executing_eagerly() and original_training_value is None:
        outputs._uses_learning_phase = True  # pylint: disable=protected-access
      return outputs

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.get_shape()
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
      del reduction_axes[1]     # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value
    def _broadcast(v):
      if (v is not None and
          len(v.get_shape()) != ndims and
          reduction_axes != list(range(ndims - 1))):
        return array_ops.reshape(v, broadcast_shape)
      return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
      if then_scale is not None:
        scale *= then_scale
        offset *= then_scale
      if then_offset is not None:
        offset += then_offset
      return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:
      if self.adjustment:
        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
        # Adjust only during training.
        adj_scale = tf_utils.smart_cond(training,
                                        lambda: adj_scale,
                                        lambda: array_ops.ones_like(adj_scale))
        adj_bias = tf_utils.smart_cond(training,
                                       lambda: adj_bias,
                                       lambda: array_ops.zeros_like(adj_bias))
        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

      # Some of the computations here are not necessary when training evaluates
      # to False at runtime but `training` is not a constant; keeping them
      # unconditional makes the code simpler.
      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
      mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = tf_utils.smart_cond(training,
                                 lambda: mean,
                                 lambda: moving_mean)
      variance = tf_utils.smart_cond(training,
                                     lambda: variance,
                                     lambda: moving_variance)

      if self.virtual_batch_size is not None:
        # This isn't strictly correct since in ghost batch norm, you are
        # supposed to sequentially update the moving_mean and moving_variance
        # with each sub-batch. However, since the moving statistics are only
        # used during evaluation, it is more efficient to just update in one
        # step and should not make a significant difference in the result.
        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
      else:
        new_mean, new_variance = mean, variance

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            new_mean, new_variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)

      def _do_update(var, value):
        if in_eager_mode and not self.trainable:
          return

        return self._assign_moving_average(var, value, self.momentum)

      mean_update = tf_utils.smart_cond(
          training,
          lambda: _do_update(self.moving_mean, new_mean),
          lambda: self.moving_mean)
      variance_update = tf_utils.smart_cond(
          training,
          lambda: _do_update(self.moving_variance, new_variance),
          lambda: self.moving_variance)
      if not context.executing_eagerly():
        self.add_update(mean_update, inputs=True)
        self.add_update(variance_update, inputs=True)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     offset,
                                     scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
      outputs = undo_virtual_batching(outputs)
    if not context.executing_eagerly() and original_training_value is None:
      outputs._uses_learning_phase = True  # pylint: disable=protected-access
    return outputs
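
Finally, the renorm comments in several of the examples above rely on the identity (x * r + d) * gamma + beta == x * (r * gamma) + (d * gamma + beta), with r and d computed from batch versus moving statistics as in the Batch Renormalization paper. A small NumPy check of that identity (clipping and stop_gradient omitted; the moving statistics below are stand-ins):

import numpy as np

x = np.random.randn(64, 8).astype(np.float64)
gamma, beta = 1.5, 0.3
mu_b, sigma_b = x.mean(axis=0), x.std(axis=0)
mu_m, sigma_m = np.zeros(8), np.ones(8)          # stand-ins for the moving statistics

x_hat = (x - mu_b) / sigma_b                     # batch-normalized activations
r = sigma_b / sigma_m
d = (mu_b - mu_m) / sigma_m

lhs = (x_hat * r + d) * gamma + beta
rhs = x_hat * (r * gamma) + (d * gamma + beta)
assert np.allclose(lhs, rhs)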