Example #1
    def _assign_moving_average(self, variable, value, momentum):
        with ops.name_scope(None, 'AssignMovingAvg',
                            [variable, value, momentum]) as scope:
            # TODO(b/120571621): We want to avoid colocating the variables here
            # since TPUStrategy does not implement replica local variables.
            # Remove this hack once we support TPULocalVariables.
            is_tpu_strategy = False
            if distribution_strategy_context.has_distribution_strategy():
                distribute = (
                    distribution_strategy_context.get_distribution_strategy())
                if distribute.__class__.__name__ == 'TPUStrategy':
                    is_tpu_strategy = True

            # TODO(apassos,srbs,skyewm): the colocation constraints here are
            # disabled because of a bug that leads cond_v2/while_v2 to skip
            # rewriting them, creating conflicts.
            if (control_flow_util.EnableControlFlowV2(ops.get_default_graph())
                    or is_tpu_strategy):
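                # Fall back to a no-op context manager: a one-shot generator
                # that yields once, so no colocation constraint is applied.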
                cm = contextlib.contextmanager(lambda: (yield))()
            else:
                cm = ops.colocate_with(variable)
            with cm:
                decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
                if decay.dtype != variable.dtype.base_dtype:
                    decay = math_ops.cast(decay, variable.dtype.base_dtype)
                update_delta = (variable -
                                math_ops.cast(value, variable.dtype)) * decay
                return state_ops.assign_sub(variable, update_delta, name=scope)
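The update above is the standard exponential moving average, variable := momentum * variable + (1 - momentum) * value, folded into a single assign_sub. A minimal plain-Python sketch of the arithmetic (no TensorFlow; names are illustrative):

def moving_average_update(variable, value, momentum):
    decay = 1.0 - momentum
    update_delta = (variable - value) * decay
    return variable - update_delta  # what assign_sub computes in place

assert moving_average_update(10.0, 0.0, momentum=0.9) == 9.0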
Example #2
 def _scale_loss(loss_value):
     if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
         num_replicas = (
           distribute_ctx.get_distribution_strategy().num_replicas_in_sync)
         if num_replicas > 1:
             loss_value *= (1. / num_replicas)
     return loss_value
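The same mean-reduction scaling can be written against the public tf.distribute API; a sketch assuming TF 2.x, with MirroredStrategy used purely for illustration:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def scale_loss(loss_value):
    # Dividing by the replica count turns a cross-replica sum into a mean.
    num_replicas = strategy.num_replicas_in_sync
    if num_replicas > 1:
        loss_value *= (1.0 / num_replicas)
    return loss_value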
Example #3
    def _fused_batch_norm(self, inputs, training):
        """Returns the output of fused batch norm."""
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const

        def _fused_batch_norm_training():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       epsilon=self.epsilon,
                                       data_format=self._data_format)

        def _fused_batch_norm_inference():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       mean=self.moving_mean,
                                       variance=self.moving_variance,
                                       epsilon=self.epsilon,
                                       is_training=False,
                                       data_format=self._data_format)

        output, mean, variance = tf_utils.smart_cond(
            training, _fused_batch_norm_training, _fused_batch_norm_inference)
        if not self._bessels_correction_test_only:
            # Remove Bessel's correction to be consistent with non-fused batch
            # norm. Note that the variance computed by fused batch norm
            # includes Bessel's correction.
            sample_size = math_ops.cast(
                array_ops.size(inputs) / array_ops.size(variance),
                variance.dtype)
            factor = (sample_size -
                      math_ops.cast(1.0, variance.dtype)) / sample_size
            variance *= factor

        training_value = tf_utils.constant_value(training)
        if training_value is None:
            momentum = tf_utils.smart_cond(training, lambda: self.momentum,
                                           lambda: 1.0)
        else:
            momentum = ops.convert_to_tensor(self.momentum)
        if training_value or training_value is None:
            if distribution_strategy_context.in_cross_replica_context():
                strategy = (
                    distribution_strategy_context.get_distribution_strategy())
                mean_update = strategy.extended.update(
                    self.moving_mean, self._assign_moving_average,
                    (mean, self.momentum))
                variance_update = strategy.extended.update(
                    self.moving_variance, self._assign_moving_average,
                    (variance, self.momentum))
            else:
                mean_update = self._assign_moving_average(
                    self.moving_mean, mean, momentum)
                variance_update = self._assign_moving_average(
                    self.moving_variance, variance, momentum)
            self.add_update(mean_update, inputs=True)
            self.add_update(variance_update, inputs=True)

        return output
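The branching above relies on smart_cond, which takes a branch eagerly when `training` is a known Python constant and only builds a graph-level cond otherwise. A minimal sketch of that idea (an illustrative re-implementation, not TF's own):

import tensorflow as tf

def smart_cond(pred, true_fn, false_fn):
    if isinstance(pred, bool):  # statically known: no cond op is built
        return true_fn() if pred else false_fn()
    return tf.cond(pred, true_fn, false_fn)  # resolved at run time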
Example #4
def create_slot(primary, val, name, colocate_with_primary=True):
    """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
    # Scope the slot name in the namespace of the primary variable.
    # Set "primary.op.name + '/' + name" as default name, so the scope name of
    # optimizer can be shared when reuse is True. Meanwhile when reuse is False
    # and the same name has been previously used, the scope name will add '_N'
    # as suffix for unique identifications.
    validate_shape = val.get_shape().is_fully_defined()
    if context.executing_eagerly():
        prefix = primary._shared_name  # pylint: disable=protected-access
    else:
        prefix = primary.op.name
    with variable_scope.variable_scope(None, prefix + "/" + name):
        if colocate_with_primary:
            distribution_strategy = (
                distribution_strategy_context.get_distribution_strategy())
            with distribution_strategy.colocate_vars_with(primary):
                return _create_slot_var(primary, val, "", validate_shape, None,
                                        None)
        else:
            return _create_slot_var(primary, val, "", validate_shape, None,
                                    None)
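A hypothetical call site, roughly how an optimizer would create a per-variable accumulator with the function above (assumes the surrounding TF internals, variable_scope and array_ops, are importable; all names are illustrative):

var = variable_scope.variable([1.0, 2.0], name="weights")
accum = create_slot(var, array_ops.zeros_like(var), name="momentum",
                    colocate_with_primary=True)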
Example #5
        def train_step():
          strategy = distribution_strategy_context.get_distribution_strategy()
          if strategy.cluster_resolver.task_id == raise_app_error_on_worker:
            raise errors_impl.ResourceExhaustedError(
                node_def=None, op=None, message='Running out of resources')

          model.v.assign_add(constant_op.constant(1.))
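The same task-id lookup via the public API; a sketch assuming a TF version where Strategy exposes cluster_resolver and a multi-worker setup that populates it:

import tensorflow as tf

strategy = tf.distribute.get_strategy()
resolver = strategy.cluster_resolver  # may be None for single-worker setups
if resolver is not None and resolver.task_id == 0:
    print("running on worker 0")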
Example #6
  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot."""
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
    graph = None if eager else colocate_with.graph

    key = (name, graph)
    v = self._non_slot_dict.get(key, None)
    if v is None:
      self._maybe_initialize_checkpointable()
      distribution_strategy = distribute_ctx.get_distribution_strategy()
      with distribution_strategy.colocate_vars_with(colocate_with):
        if eager:
          restored_initial_value = self._preload_simple_restoration(
              name=name, shape=None)
          if restored_initial_value is not None:
            initial_value = restored_initial_value
        v = variable_scope.variable(initial_value, name=name, trainable=False)
      # Restore this variable by name if necessary, but don't add a
      # Checkpointable dependency. Optimizers return the current graph's
      # non-slot variables from _checkpoint_dependencies explicitly rather
      # than unconditionally adding dependencies (since there may be multiple
      # non-slot variables with the same name in different graphs, trying to
      # save all of them would result in errors).
      self._handle_deferred_dependencies(name=name, checkpointable=v)
      self._non_slot_dict[key] = v

    return v
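The method memoizes one variable per (name, graph) pair; a plain-Python sketch of that keying (illustrative, no TensorFlow):

non_slot_dict = {}

def get_or_create(name, graph, factory):
    key = (name, graph)
    if key not in non_slot_dict:
        non_slot_dict[key] = factory()  # created once per (name, graph)
    return non_slot_dict[key]

v = get_or_create("beta1_power", None, lambda: 0.9)
assert get_or_create("beta1_power", None, lambda: 0.0) == 0.9  # cached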
Example #7
def _assert_in_default_state(t):
  t.assertIs(distribution_strategy_context._get_default_replica_context(),
             distribution_strategy_context.get_replica_context())
  t.assertIs(None, distribution_strategy_context.get_cross_replica_context())
  t.assertFalse(distribution_strategy_context.in_cross_replica_context())
  t.assertIs(distribution_strategy_context._get_default_distribution_strategy(),
             distribution_strategy_context.get_distribution_strategy())
  t.assertFalse(distribution_strategy_context.has_distribution_strategy())
Example #8
 def merge_fn(dist, s):
   self.assertIs(
       distribution_strategy_context._get_default_distribution_strategy(),
       dist)
   self.assertIs(None, distribution_strategy_context.get_replica_context())
   self.assertIs(dist,
                 distribution_strategy_context.get_cross_replica_context())
   self.assertTrue(distribution_strategy_context.in_cross_replica_context())
   self.assertIs(dist,
                 distribution_strategy_context.get_distribution_strategy())
   self.assertFalse(
       distribution_strategy_context.has_distribution_strategy())
   return "foo_" + s
Example #9
 def run_fn():
   replica_context = distribution_strategy_context.get_replica_context()
    self.assertIsNotNone(replica_context)
   self.assertIs(None,
                 distribution_strategy_context.get_cross_replica_context())
   self.assertFalse(distribution_strategy_context.in_cross_replica_context())
   self.assertTrue(distribution_strategy_context.has_distribution_strategy())
   self.assertIs(dist,
                 distribution_strategy_context.get_distribution_strategy())
   self.assertEqual("foo", replica_context.merge_call(None, test_arg="foo"))
   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="bar"))
Example #10
 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, distribution_strategy_context.get_replica_context())
     self.assertIs(dist,
                   distribution_strategy_context.get_cross_replica_context())
     self.assertTrue(distribution_strategy_context.in_cross_replica_context())
     self.assertTrue(distribution_strategy_context.has_distribution_strategy())
     self.assertIs(dist,
                   distribution_strategy_context.get_distribution_strategy())
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
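The same scope contract holds for the public API; a sketch assuming TF 2.x, with MirroredStrategy standing in for the test's _TestStrategy:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
assert not tf.distribute.has_strategy()  # default state outside any scope
with strategy.scope():
    assert tf.distribute.has_strategy()
    assert tf.distribute.in_cross_replica_context()
    assert tf.distribute.get_strategy() is strategy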
Example #11
    def _assign_moving_average(self, variable, value, momentum):
        with ops.name_scope(None, 'AssignMovingAvg',
                            [variable, value, momentum]) as scope:
            # TODO(b/120571621): We want to avoid colocating the variables here
            # since TPUStrategy does not implement replica local variables.
            # Remove this hack once we support TPULocalVariables.
            is_tpu_strategy = False
            if distribution_strategy_context.has_distribution_strategy():
                distribute = (
                    distribution_strategy_context.get_distribution_strategy())
                if distribute.__class__.__name__ == 'TPUStrategy':
                    is_tpu_strategy = True

            with ops.colocate_with(variable):
                decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
                if decay.dtype != variable.dtype.base_dtype:
                    decay = math_ops.cast(decay, variable.dtype.base_dtype)
                update_delta = (variable -
                                math_ops.cast(value, variable.dtype)) * decay
                return state_ops.assign_sub(variable, update_delta, name=scope)
Example #12
  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and reduction is set, `output`
        must be a `PerReplica` value.
        The reduce method is also recorded in the dictionary
        `_last_step_outputs_reduce_ops` so that consumers can later tell
        whether the outputs were already reduced.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_distribution_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output)
    else:
      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value)
        # Setting this inside the `merge_fn` because all replicas share the same
        # context object, so it's more robust to set it only once (even if all
        # the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))
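A hypothetical call site for the method above, from inside a replica context; `ctx` and `per_replica_loss` are placeholders, and ds_reduce_util is the reduction enum module imported in Example #2:

# `ctx` is the object defining set_last_step_output; both names are
# hypothetical placeholders for illustration.
ctx.set_last_step_output("mean_loss", per_replica_loss,
                         reduce_op=ds_reduce_util.ReduceOp.MEAN)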
Example #13
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Apply gradients to variables.

    This contains most of the synchronization implementation and also wraps the
    apply_gradients() from the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Default to the
        name passed to the Optimizer constructor.

    Returns:
      train_op: The op to dequeue a token so the replicas can exit this batch
      and start the next one. This is executed by each replica.

    Raises:
      ValueError: If the grads_and_vars is empty.
      ValueError: If `global_step` is not provided; without it the gradient
        staleness cannot be checked.
    """
        if not grads_and_vars:
            raise ValueError("Must supply at least one variable")

        if global_step is None:
            raise ValueError("Global step is required to check staleness")

        self._global_step = global_step
        train_ops = []
        aggregated_grad = []
        var_list = []

        # local_anchor op will be placed on this worker task by default.
        local_anchor = control_flow_ops.no_op()
        # Colocating local_step variable prevents it being placed on the PS.
        distribution_strategy = (
            distribution_strategy_context.get_distribution_strategy())
        with distribution_strategy.colocate_vars_with(local_anchor):
            self._local_step = variable_scope.variable(
                initial_value=0,
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                dtype=global_step.dtype.base_dtype,
                name="sync_rep_local_step")

        self.local_step_init_op = state_ops.assign(self._local_step,
                                                   global_step)
        chief_init_ops = [self.local_step_init_op]
        self.ready_for_local_init_op = variables.report_uninitialized_variables(
            variables.global_variables())

        with ops.name_scope(None, self._name):
            for grad, var in grads_and_vars:
                var_list.append(var)
                with ops.device(var.device):
                    # Dense gradients.
                    if grad is None:
                        aggregated_grad.append(None)  # pass-through.
                        continue
                    elif isinstance(grad, ops.Tensor):
                        grad_accum = data_flow_ops.ConditionalAccumulator(
                            grad.dtype,
                            shape=var.get_shape(),
                            shared_name=var.name + "/grad_accum")
                        train_ops.append(
                            grad_accum.apply_grad(grad,
                                                  local_step=self._local_step))
                        aggregated_grad.append(
                            grad_accum.take_grad(self._replicas_to_aggregate))
                    else:
                        if not isinstance(grad, ops.IndexedSlices):
                            raise ValueError("Unknown grad type!")
                        grad_accum = data_flow_ops.SparseConditionalAccumulator(
                            grad.dtype,
                            shape=(),
                            shared_name=var.name + "/grad_accum")
                        train_ops.append(
                            grad_accum.apply_indexed_slices_grad(
                                grad, local_step=self._local_step))
                        aggregated_grad.append(
                            grad_accum.take_indexed_slices_grad(
                                self._replicas_to_aggregate))

                    self._accumulator_list.append((grad_accum, var.device))

            aggregated_grads_and_vars = zip(aggregated_grad, var_list)

            # sync_op will be assigned to the same device as the global step.
            with ops.device(global_step.device), ops.name_scope(""):
                update_op = self._opt.apply_gradients(
                    aggregated_grads_and_vars, global_step)

            # Create token queue.
            with ops.device(global_step.device), ops.name_scope(""):
                sync_token_queue = (data_flow_ops.FIFOQueue(
                    -1,
                    global_step.dtype.base_dtype,
                    shapes=(),
                    name="sync_token_q",
                    shared_name="sync_token_q"))
                self._sync_token_queue = sync_token_queue

                # dummy_queue is passed to the queue runner. Don't use the real
                # queues because the queue runner doesn't automatically reopen
                # a queue once it has been closed on a PS device.
                dummy_queue = (data_flow_ops.FIFOQueue(
                    1,
                    types_pb2.DT_INT32,
                    shapes=(),
                    name="dummy_queue",
                    shared_name="dummy_queue"))

            with ops.device(global_step.device), ops.name_scope(""):
                # Replicas have to wait until they can get a token from the token queue.
                with ops.control_dependencies(train_ops):
                    token = sync_token_queue.dequeue()
                train_op = state_ops.assign(self._local_step, token)

                with ops.control_dependencies([update_op]):
                    # Sync_op needs to insert tokens to the token queue at the end of the
                    # step so the replicas can fetch them to start the next step.
                    tokens = array_ops.fill([self._tokens_per_step],
                                            global_step)
                    sync_op = sync_token_queue.enqueue_many((tokens, ))

                if self._variable_averages is not None:
                    with ops.control_dependencies([sync_op]), \
                            ops.name_scope(""):
                        sync_op = self._variable_averages.apply(
                            self._variables_to_average)

                self._chief_queue_runner = queue_runner.QueueRunner(
                    dummy_queue, [sync_op])
            for accum, dev in self._accumulator_list:
                with ops.device(dev):
                    chief_init_ops.append(
                        accum.set_global_step(global_step,
                                              name="SetGlobalStep"))
            self.chief_init_op = control_flow_ops.group(*chief_init_ops)
            self._gradients_applied = True
            return train_op
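Typical TF 1.x wiring for this synchronous optimizer (a sketch; assumes graph mode, an existing `loss` tensor, and a cluster already configured):

import tensorflow as tf

opt = tf.compat.v1.train.SyncReplicasOptimizer(
    tf.compat.v1.train.GradientDescentOptimizer(0.1),
    replicas_to_aggregate=4)
global_step = tf.compat.v1.train.get_or_create_global_step()
train_op = opt.minimize(loss, global_step=global_step)
# The returned hook runs the chief init ops and the token queue runner.
hook = opt.make_session_run_hook(is_chief=True)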
Example #14
    def build(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape)
        if not input_shape.ndims:
            raise ValueError('Input has undefined rank:', input_shape)
        ndims = len(input_shape)

        # Convert axis to list and resolve negatives
        if isinstance(self.axis, int):
            self.axis = [self.axis]

        for idx, x in enumerate(self.axis):
            if x < 0:
                self.axis[idx] = ndims + x

        # Validate axes
        for x in self.axis:
            if x < 0 or x >= ndims:
                raise ValueError('Invalid axis: %d' % x)
        if len(self.axis) != len(set(self.axis)):
            raise ValueError('Duplicate axis: %s' % self.axis)

        if self.virtual_batch_size is not None:
            if self.virtual_batch_size <= 0:
                raise ValueError(
                    'virtual_batch_size must be a positive integer that '
                    'divides the true batch size of the input Tensor')
            # If using virtual batches, the first dimension must be the batch
            # dimension and cannot be the batch norm axis
            if 0 in self.axis:
                raise ValueError(
                    'When using virtual_batch_size, the batch dimension '
                    'must be 0 and thus axis cannot include 0')
            if self.adjustment is not None:
                raise ValueError(
                    'When using virtual_batch_size, adjustment cannot '
                    'be specified')

        if self.fused in (None, True):
            # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
            # output back to its original shape accordingly.
            if self._USE_V2_BEHAVIOR:
                if self.fused is None:
                    self.fused = (ndims == 4)
                elif self.fused and ndims != 4:
                    raise ValueError(
                        'Batch normalization layers with fused=True only '
                        'support 4D input tensors.')
            else:
                assert self.fused is not None
                self.fused = (ndims == 4 and self._fused_can_be_used())
            # TODO(chrisying): fused batch norm is currently not supported for
            # multi-axis batch norm and by extension virtual batches. In some cases,
            # it might be possible to use fused batch norm but would require reshaping
            # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
            # particularly tricky. A compromise might be to just support the most
            # common use case (turning 5D w/ virtual batch to NCHW)

        if self.fused:
            if self.axis == [1]:
                self._data_format = 'NCHW'
            elif self.axis == [3]:
                self._data_format = 'NHWC'
            else:
                raise ValueError(
                    'Unsupported axis, fused batch norm only supports '
                    'axis == [1] or axis == [3]')

        # Raise parameters of fp16 batch norm to fp32
        if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
            param_dtype = dtypes.float32
        else:
            param_dtype = self.dtype or dtypes.float32

        axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
        for x in axis_to_dim:
            if axis_to_dim[x] is None:
                raise ValueError(
                    'Input has undefined `axis` dimension. Input shape: ',
                    input_shape)
        self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

        if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
            # Single axis batch norm (most common/default use-case)
            param_shape = (list(axis_to_dim.values())[0], )
        else:
            # Parameter shape is the original shape but with 1 in all non-axis dims
            param_shape = [
                axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
            ]
            if self.virtual_batch_size is not None:
                # When using virtual batches, add an extra dim at index 1
                param_shape.insert(1, 1)
                for idx, x in enumerate(self.axis):
                    self.axis[idx] = x + 1  # Account for added dimension

        if self.scale:
            self.gamma = self.add_weight(name='gamma',
                                         shape=param_shape,
                                         dtype=param_dtype,
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint,
                                         trainable=True)
        else:
            self.gamma = None
            if self.fused:
                self._gamma_const = array_ops.constant(1.0,
                                                       dtype=param_dtype,
                                                       shape=param_shape)

        if self.center:
            self.beta = self.add_weight(name='beta',
                                        shape=param_shape,
                                        dtype=param_dtype,
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint,
                                        trainable=True)
        else:
            self.beta = None
            if self.fused:
                self._beta_const = array_ops.constant(0.0,
                                                      dtype=param_dtype,
                                                      shape=param_shape)

        try:
            # Disable variable partitioning when creating the moving mean and variance
            if hasattr(self, '_scope') and self._scope:
                partitioner = self._scope.partitioner
                self._scope.set_partitioner(None)
            else:
                partitioner = None
            self.moving_mean = self.add_weight(
                name='moving_mean',
                shape=param_shape,
                dtype=param_dtype,
                initializer=self.moving_mean_initializer,
                synchronization=tf_variables.VariableSynchronization.ON_READ,
                trainable=False,
                aggregation=tf_variables.VariableAggregation.MEAN)

            self.moving_variance = self.add_weight(
                name='moving_variance',
                shape=param_shape,
                dtype=param_dtype,
                initializer=self.moving_variance_initializer,
                synchronization=tf_variables.VariableSynchronization.ON_READ,
                trainable=False,
                aggregation=tf_variables.VariableAggregation.MEAN)

            if self.renorm:
                # Create variables to maintain the moving mean and standard deviation.
                # These are used in training and thus are different from the moving
                # averages above. The renorm variables are colocated with moving_mean
                # and moving_variance.
                # NOTE: below, the outer `with device` block causes the current device
                # stack to be cleared. The nested ones use a `lambda` to set the desired
                # device and ignore any devices that may be set by the custom getter.
                def _renorm_variable(name, shape):
                    var = self.add_weight(
                        name=name,
                        shape=shape,
                        dtype=param_dtype,
                        initializer=init_ops.zeros_initializer(),
                        synchronization=tf_variables.VariableSynchronization.
                        ON_READ,
                        trainable=False,
                        aggregation=tf_variables.VariableAggregation.MEAN)
                    return var

                strategy = (
                    distribution_strategy_context.get_distribution_strategy())
                with strategy.colocate_vars_with(self.moving_mean):
                    self.renorm_mean = _renorm_variable(
                        'renorm_mean', param_shape)
                    self.renorm_mean_weight = _renorm_variable(
                        'renorm_mean_weight', ())
                # We initialize renorm_stddev to 0, and maintain the (0-initialized)
                # renorm_stddev_weight. This allows us to (1) mix the average
                # stddev with the minibatch stddev early in training, and (2) compute
                # the unbiased average stddev by dividing renorm_stddev by the weight.
                with strategy.colocate_vars_with(self.moving_variance):
                    self.renorm_stddev = _renorm_variable(
                        'renorm_stddev', param_shape)
                    self.renorm_stddev_weight = _renorm_variable(
                        'renorm_stddev_weight', ())
        finally:
            if partitioner:
                self._scope.set_partitioner(partitioner)
        self.built = True
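A plain-Python sketch of the parameter-shape computation above: with NHWC input [N, H, W, C] and axis=[3], gamma and beta get shape (C,):

ndims, axis, input_shape = 4, [3], [8, 32, 32, 16]
axis_to_dim = {x: input_shape[x] for x in axis}
if len(axis_to_dim) == 1:
    param_shape = (list(axis_to_dim.values())[0],)  # single-axis case
else:
    param_shape = [axis_to_dim.get(i, 1) for i in range(ndims)]
assert param_shape == (16,)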
Example #15
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()

        in_eager_mode = context.executing_eagerly()
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # Will cause errors if virtual_batch_size does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(inputs, training=training)
            if self.virtual_batch_size is not None:
                # Currently never reaches here since fused_batch_norm does not support
                # virtual batching
                outputs = undo_virtual_batching(outputs)
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        if training_value is not False:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = (self.virtual_batch_size is not None
                         or len(self.axis) > 1)
            mean, variance = self._moments(inputs,
                                           reduction_axes,
                                           keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(training, lambda: mean,
                                       lambda: moving_mean)
            variance = tf_utils.smart_cond(training, lambda: variance,
                                           lambda: moving_variance)

            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
                new_variance = math_ops.reduce_mean(variance,
                                                    axis=1,
                                                    keepdims=True)
            else:
                new_mean, new_variance = mean, variance

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    new_mean, new_variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)

            if distribution_strategy_context.in_cross_replica_context():
                strategy = (
                    distribution_strategy_context.get_distribution_strategy())

                def _do_update(var, value):
                    """Compute the updates for mean and variance."""
                    if in_eager_mode and not self.trainable:
                        return
                    return strategy.extended.update(
                        var,
                        self._assign_moving_average, (value, self.momentum),
                        group=False)

                # We need to unwrap the moving_mean or moving_variance in the case of
                # training being false to match the output of true_fn and false_fn
                # in the smart cond.
                mean_update = tf_utils.smart_cond(
                    training, lambda: _do_update(self.moving_mean, new_mean),
                    lambda: strategy.unwrap(self.moving_mean))
                variance_update = tf_utils.smart_cond(
                    training,
                    lambda: _do_update(self.moving_variance, new_variance),
                    lambda: strategy.unwrap(self.moving_variance))
            else:

                def _do_update(var, value):
                    """Compute the updates for mean and variance."""
                    if in_eager_mode and not self.trainable:
                        return
                    return self._assign_moving_average(var, value,
                                                       self.momentum)

                mean_update = tf_utils.smart_cond(
                    training, lambda: _do_update(self.moving_mean, new_mean),
                    lambda: self.moving_mean)
                variance_update = tf_utils.smart_cond(
                    training,
                    lambda: _do_update(self.moving_variance, new_variance),
                    lambda: self.moving_variance)
            if not context.executing_eagerly():
                self.add_update(mean_update, inputs=True)
                self.add_update(variance_update, inputs=True)

        else:
            mean, variance = self.moving_mean, self.moving_variance

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        return outputs
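A sketch of the virtual-batch reshape trick used above: a batch of 8 with virtual_batch_size=4 is viewed as 4 sub-batches, normalized, then restored (NumPy stands in for the TF reshape ops):

import numpy as np

inputs = np.zeros((8, 32, 32, 16))
virtual_batch_size = 4
original_shape = [-1] + list(inputs.shape[1:])
expanded_shape = [virtual_batch_size, -1] + original_shape[1:]
expanded = inputs.reshape(expanded_shape)    # (4, 2, 32, 32, 16)
restored = expanded.reshape(original_shape)  # back to (8, 32, 32, 16)
assert restored.shape == inputs.shape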