def _assign_moving_average(self, variable, value, momentum):
    """Build the op that moves `variable` toward `value`.

    The update subtracts `(variable - value) * (1.0 - momentum)` from
    `variable`, with all ops created under an `AssignMovingAvg` name scope.

    Args:
        variable: The variable holding the moving average.
        value: The new observation; cast to `variable.dtype` before use.
        momentum: Moving-average momentum; the decay is `1.0 - momentum`.

    Returns:
        The result of `state_ops.assign_sub` applied to `variable`.
    """
    with ops.name_scope(None, 'AssignMovingAvg',
                        [variable, value, momentum]) as scope:
        # TODO(b/120571621): We want to avoid colocating the variables here
        # since TPUStrategy does not implement replica local variables.
        # Remove this hack once we support TPULocalVariables.
        is_tpu_strategy = (
            distribution_strategy_context.has_distribution_strategy() and
            (distribution_strategy_context.get_distribution_strategy()
             .__class__.__name__ == 'TPUStrategy'))

        # TODO(apassos,srbs,skyewm): the colocation constraints here are
        # disabled because of a bug which leads cond_v2/while_v2 to skip
        # rewriting them creating conflicts.
        skip_colocation = (
            control_flow_util.EnableControlFlowV2(ops.get_default_graph())
            or is_tpu_strategy)
        # The generator-based manager is a do-nothing context.
        placement = (contextlib.contextmanager(lambda: (yield))()
                     if skip_colocation else ops.colocate_with(variable))

        with placement:
            target_dtype = variable.dtype.base_dtype
            decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
            if decay.dtype != target_dtype:
                decay = math_ops.cast(decay, target_dtype)
            difference = variable - math_ops.cast(value, variable.dtype)
            update_delta = difference * decay
            return state_ops.assign_sub(variable, update_delta, name=scope)
def _scale_loss(loss_value):
    """Scale `loss_value` by `1 / num_replicas` under MEAN loss reduction.

    The loss is returned untouched unless the loss reduction reported by
    `distribute_lib.get_loss_reduction()` is `ReduceOp.MEAN` and the current
    strategy reports more than one replica in sync.

    Args:
        loss_value: The loss to scale.

    Returns:
        The (possibly scaled) loss.
    """
    uses_mean_reduction = (
        distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN)
    if not uses_mean_reduction:
        return loss_value

    strategy = distribute_ctx.get_distribution_strategy()
    replica_count = strategy.num_replicas_in_sync
    if replica_count > 1:
        loss_value *= (1. / replica_count)
    return loss_value
def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm.

    Runs `nn.fused_batch_norm` in training or inference mode (selected via
    `tf_utils.smart_cond` on `training`) and, unless `training` is statically
    False, registers moving mean / variance updates through `add_update`.

    Args:
        inputs: Input passed to `nn.fused_batch_norm`.
        training: Value selecting training vs. inference behavior; resolved
            with `tf_utils.constant_value` / `tf_utils.smart_cond`.

    Returns:
        The normalized output produced by `nn.fused_batch_norm`.
    """
    # Fall back to the constant scale/offset when the layer has no trainable
    # gamma/beta.
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    def _fused_batch_norm_training():
        # Training mode: statistics are computed from `inputs`.
        return nn.fused_batch_norm(inputs,
                                   gamma,
                                   beta,
                                   epsilon=self.epsilon,
                                   data_format=self._data_format)

    def _fused_batch_norm_inference():
        # Inference mode: normalize with the stored moving statistics.
        return nn.fused_batch_norm(inputs,
                                   gamma,
                                   beta,
                                   mean=self.moving_mean,
                                   variance=self.moving_variance,
                                   epsilon=self.epsilon,
                                   is_training=False,
                                   data_format=self._data_format)

    output, mean, variance = tf_utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)
    if not self._bessels_correction_test_only:
        # Remove Bessel's correction to be consistent with non-fused batch norm.
        # Note that the variance computed by fused batch norm is
        # with Bessel's correction.
        sample_size = math_ops.cast(
            array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
        factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
        variance *= factor

    training_value = tf_utils.constant_value(training)
    if training_value is None:
        # `training` is not statically known: a momentum of 1.0 makes the
        # moving-average update a no-op on the non-training path.
        momentum = tf_utils.smart_cond(training,
                                       lambda: self.momentum,
                                       lambda: 1.0)
    else:
        momentum = ops.convert_to_tensor(self.momentum)
    if training_value or training_value is None:
        if distribution_strategy_context.in_cross_replica_context():
            # NOTE(review): this branch passes `self.momentum` rather than the
            # `training`-conditioned `momentum` computed above, unlike the
            # replica-context branch below -- confirm this is intentional.
            strategy = distribution_strategy_context.get_distribution_strategy(
            )
            mean_update = strategy.extended.update(
                self.moving_mean, self._assign_moving_average,
                (mean, self.momentum))
            variance_update = strategy.extended.update(
                self.moving_variance, self._assign_moving_average,
                (variance, self.momentum))
        else:
            mean_update = self._assign_moving_average(
                self.moving_mean, mean, momentum)
            variance_update = self._assign_moving_average(
                self.moving_variance, variance, momentum)
        self.add_update(mean_update, inputs=True)
        self.add_update(variance_update, inputs=True)

    return output
def create_slot(primary, val, name, colocate_with_primary=True):
    """Create a slot variable initialized to `val`.

    The slot's type is taken from `val`.

    Args:
        primary: The primary `Variable` or `Tensor`.
        val: A `Tensor` giving the slot's initial value.
        name: Name to use for the slot variable.
        colocate_with_primary: Boolean. When True the slot is created under
            the current distribution strategy's `colocate_vars_with(primary)`.

    Returns:
        A `Variable` object.
    """
    # The slot lives in the primary variable's namespace. Using
    # "<primary name>/<name>" as the *default* scope name lets the optimizer
    # scope be shared when reuse is True, while a '_N' suffix is appended for
    # uniqueness when reuse is False and the name was already taken.
    validate_shape = val.get_shape().is_fully_defined()
    prefix = (primary._shared_name  # pylint: disable=protected-access
              if context.executing_eagerly() else primary.op.name)
    slot_scope_name = prefix + "/" + name
    with variable_scope.variable_scope(None, slot_scope_name):
        if not colocate_with_primary:
            return _create_slot_var(primary, val, "", validate_shape, None,
                                    None)
        strategy = distribution_strategy_context.get_distribution_strategy()
        with strategy.colocate_vars_with(primary):
            return _create_slot_var(primary, val, "", validate_shape, None,
                                    None)
def train_step():
    """Fail on the designated task, otherwise add 1 to `model.v`.

    Raises:
        errors_impl.ResourceExhaustedError: If the current strategy's
            `cluster_resolver.task_id` equals `raise_app_error_on_worker`.
    """
    strategy = distribution_strategy_context.get_distribution_strategy()
    current_task_id = strategy.cluster_resolver.task_id
    if current_task_id == raise_app_error_on_worker:
        raise errors_impl.ResourceExhaustedError(
            node_def=None, op=None, message='Running out of resources')
    increment = constant_op.constant(1.)
    model.v.assign_add(increment)
def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Return the non-slot variable `name`, creating it on first use.

    Variables are cached in `self._non_slot_dict` under `(name, graph)`,
    where `graph` is None when executing eagerly and `colocate_with.graph`
    otherwise.

    Args:
        initial_value: Initial value for a newly created variable. When
            executing eagerly it is replaced by the value from
            `_preload_simple_restoration` if that returns one.
        name: Name of the variable.
        colocate_with: Object passed to the strategy's `colocate_vars_with`.

    Returns:
        The cached or newly created variable.
    """
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
    if eager:
        graph = None
    else:
        graph = colocate_with.graph
    key = (name, graph)

    cached = self._non_slot_dict.get(key, None)
    if cached is not None:
        return cached

    self._maybe_initialize_checkpointable()
    strategy = distribute_ctx.get_distribution_strategy()
    with strategy.colocate_vars_with(colocate_with):
        if eager:
            restored = self._preload_simple_restoration(name=name, shape=None)
            if restored is not None:
                initial_value = restored
        created = variable_scope.variable(
            initial_value, name=name, trainable=False)

    # Restore by name if necessary, without adding a Checkpointable
    # dependency: the current graph's non-slot variables are returned
    # explicitly from _checkpoint_dependencies, because several graphs may
    # each hold a non-slot variable with the same name and saving all of
    # them would fail.
    self._handle_deferred_dependencies(name=name, checkpointable=created)
    self._non_slot_dict[key] = created
    return created
def _assert_in_default_state(t):
    """Assert that the default distribution-strategy state is active.

    Args:
        t: Test case object providing `assertIs` / `assertFalse`.
    """
    ctx = distribution_strategy_context

    # pylint: disable=protected-access
    default_replica_context = ctx._get_default_replica_context()
    t.assertIs(default_replica_context, ctx.get_replica_context())

    cross_replica_context = ctx.get_cross_replica_context()
    t.assertIs(None, cross_replica_context)
    t.assertFalse(ctx.in_cross_replica_context())

    default_strategy = ctx._get_default_distribution_strategy()
    # pylint: enable=protected-access
    t.assertIs(default_strategy, ctx.get_distribution_strategy())
    t.assertFalse(ctx.has_distribution_strategy())
def merge_fn(dist, s):
    """Check the cross-replica state under the default strategy.

    Args:
        dist: The strategy handed to the merge function; expected to be the
            default distribution strategy.
        s: String to append to the `"foo_"` prefix.

    Returns:
        `"foo_" + s`.
    """
    ctx = distribution_strategy_context

    default_strategy = (
        ctx._get_default_distribution_strategy())  # pylint: disable=protected-access
    self.assertIs(default_strategy, dist)

    replica_context = ctx.get_replica_context()
    self.assertIs(None, replica_context)

    cross_replica_context = ctx.get_cross_replica_context()
    self.assertIs(dist, cross_replica_context)
    self.assertTrue(ctx.in_cross_replica_context())

    current_strategy = ctx.get_distribution_strategy()
    self.assertIs(dist, current_strategy)
    self.assertFalse(ctx.has_distribution_strategy())

    return "foo_" + s
def run_fn():
    """Check the replica-context state, `merge_call`, and variable creation."""
    ctx = distribution_strategy_context

    replica_context = ctx.get_replica_context()
    self.assertTrue(replica_context is not None)

    cross_replica_context = ctx.get_cross_replica_context()
    self.assertIs(None, cross_replica_context)
    self.assertFalse(ctx.in_cross_replica_context())
    self.assertTrue(ctx.has_distribution_strategy())

    current_strategy = ctx.get_distribution_strategy()
    self.assertIs(dist, current_strategy)

    merged = replica_context.merge_call(None, test_arg="foo")
    self.assertEqual("foo", merged)

    synchronization = variable_scope.VariableSynchronization.AUTO
    aggregation = variable_scope.VariableAggregation.NONE
    expected_value = _get_test_variable("bar", synchronization, aggregation)
    created = variable_scope.variable(1.0, name="bar")
    self.assertDictEqual(expected_value, created)
def testScope(self):
    """Check the context state before, inside, and after `dist.scope()`."""
    ctx = distribution_strategy_context

    _assert_in_default_state(self)
    dist = _TestStrategy()
    with dist.scope():
        replica_context = ctx.get_replica_context()
        self.assertIs(None, replica_context)

        cross_replica_context = ctx.get_cross_replica_context()
        self.assertIs(dist, cross_replica_context)
        self.assertTrue(ctx.in_cross_replica_context())
        self.assertTrue(ctx.has_distribution_strategy())

        current_strategy = ctx.get_distribution_strategy()
        self.assertIs(dist, current_strategy)

        synchronization = variable_scope.VariableSynchronization.AUTO
        aggregation = variable_scope.VariableAggregation.NONE
        expected_value = _get_test_variable(
            "baz", synchronization, aggregation)
        created = variable_scope.variable(1.0, name="baz")
        self.assertDictEqual(expected_value, created)
    _assert_in_default_state(self)
def _assign_moving_average(self, variable, value, momentum):
    """Build the op that moves `variable` toward `value`.

    The update subtracts `(variable - value) * (1.0 - momentum)` from
    `variable`, with all ops created under an `AssignMovingAvg` name scope.

    The ops are colocated with `variable` unless the active distribution
    strategy is a `TPUStrategy`. Previously the TPU check was computed but
    its result was never used, so colocation was applied unconditionally,
    contrary to the TODO below.

    Args:
        variable: The variable holding the moving average.
        value: The new observation; cast to `variable.dtype` before use.
        momentum: Moving-average momentum; the decay is `1.0 - momentum`.

    Returns:
        The result of `state_ops.assign_sub` applied to `variable`.
    """
    with ops.name_scope(None, 'AssignMovingAvg',
                        [variable, value, momentum]) as scope:
        # TODO(b/120571621): We want to avoid colocating the variables here
        # since TPUStrategy does not implement replica local variables.
        # Remove this hack once we support TPULocalVariables.
        is_tpu_strategy = False
        if distribution_strategy_context.has_distribution_strategy():
            distribute = (
                distribution_strategy_context.get_distribution_strategy())
            # Compared by class name, as in the original check.
            if distribute.__class__.__name__ == 'TPUStrategy':
                is_tpu_strategy = True

        def _build_update():
            """Create the decay, delta and assign_sub ops."""
            decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
            if decay.dtype != variable.dtype.base_dtype:
                decay = math_ops.cast(decay, variable.dtype.base_dtype)
            update_delta = (
                variable - math_ops.cast(value, variable.dtype)) * decay
            return state_ops.assign_sub(variable, update_delta, name=scope)

        if is_tpu_strategy:
            # Honor the TPU check: build the update without colocation.
            return _build_update()
        with ops.colocate_with(variable):
            return _build_update()
def set_last_step_output(self, name, output, reduce_op=None):
    """Record `output` under `name` as an output of the last step.

    Args:
        name: String identifying the output. It does not need to match a
            tensor name.
        output: The tensors to output under `name`. Supported types depend
            on `reduce_op`, see below.
        reduce_op: Reduction used to combine outputs from multiple replicas.
            Required when called in a replica context, optional in a
            cross-replica context. When given, the outputs of all replicas
            are reduced with the current distribution strategy's `reduce`
            method, so `output` must be of a type that method supports
            (e.g. a `PerReplica` value with MirroredStrategy). The reduce op
            is also stored in `_last_step_outputs_reduce_ops` so the outputs
            can later be interpreted as already reduced or not.
    """
    if not distribution_strategy_context.in_cross_replica_context():
        assert reduce_op is not None

        def merge_fn(distribution, value):
            self._last_step_outputs[name] = distribution.reduce(
                reduce_op, value)
            # Recorded inside `merge_fn` because every replica shares the
            # same context object; setting it once here is more robust than
            # having each replica write the (identical) value.
            self._last_step_outputs_reduce_ops[name] = reduce_op

        replica_context = distribution_strategy_context.get_replica_context()
        replica_context.merge_call(merge_fn, args=(output,))
        return

    # Cross-replica context: record the reduce op first, then the output.
    self._last_step_outputs_reduce_ops[name] = reduce_op
    if reduce_op is None:
        self._last_step_outputs[name] = output
        return
    strategy = distribution_strategy_context.get_distribution_strategy()
    self._last_step_outputs[name] = strategy.reduce(reduce_op, output)
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This contains most of the synchronization implementation and also wraps
    the apply_gradients() from the real optimizer.

    Besides returning the train op, this method sets several attributes on
    `self`: `_global_step`, `_local_step`, `local_step_init_op`,
    `ready_for_local_init_op`, `_sync_token_queue`, `_chief_queue_runner`,
    `chief_init_op` and `_gradients_applied`.

    Args:
        grads_and_vars: List of (gradient, variable) pairs as returned by
            compute_gradients().
        global_step: Optional Variable to increment by one after the
            variables have been updated.
        name: Optional name for the returned operation. Default to the name
            passed to the Optimizer constructor.

    Returns:
        train_op: The op to dequeue a token so the replicas can exit this
            batch and start the next one. This is executed by each replica.

    Raises:
        ValueError: If the grads_and_vars is empty.
        ValueError: If global step is not provided, the staleness cannot be
            checked.
        ValueError: If a gradient is neither None, an `ops.Tensor`, nor an
            `ops.IndexedSlices`.
    """
    if not grads_and_vars:
        raise ValueError("Must supply at least one variable")

    if global_step is None:
        raise ValueError("Global step is required to check staleness")

    self._global_step = global_step
    train_ops = []
    aggregated_grad = []
    var_list = []

    # local_anchor op will be placed on this worker task by default.
    local_anchor = control_flow_ops.no_op()
    # Colocating local_step variable prevents it being placed on the PS.
    distribution_strategy = (
        distribution_strategy_context.get_distribution_strategy())
    with distribution_strategy.colocate_vars_with(local_anchor):
        self._local_step = variable_scope.variable(
            initial_value=0,
            trainable=False,
            collections=[ops.GraphKeys.LOCAL_VARIABLES],
            dtype=global_step.dtype.base_dtype,
            name="sync_rep_local_step")

    # Initializing the local step copies the current global step into it.
    self.local_step_init_op = state_ops.assign(self._local_step, global_step)
    chief_init_ops = [self.local_step_init_op]
    self.ready_for_local_init_op = variables.report_uninitialized_variables(
        variables.global_variables())

    with ops.name_scope(None, self._name):
        # One accumulator per variable, created on the variable's device.
        for grad, var in grads_and_vars:
            var_list.append(var)
            with ops.device(var.device):
                # Dense gradients.
                if grad is None:
                    aggregated_grad.append(None)  # pass-through.
                    continue
                elif isinstance(grad, ops.Tensor):
                    grad_accum = data_flow_ops.ConditionalAccumulator(
                        grad.dtype,
                        shape=var.get_shape(),
                        shared_name=var.name + "/grad_accum")
                    train_ops.append(
                        grad_accum.apply_grad(grad, local_step=self._local_step))
                    aggregated_grad.append(
                        grad_accum.take_grad(self._replicas_to_aggregate))
                else:
                    # Anything else must be a sparse (IndexedSlices) gradient.
                    if not isinstance(grad, ops.IndexedSlices):
                        raise ValueError("Unknown grad type!")
                    grad_accum = data_flow_ops.SparseConditionalAccumulator(
                        grad.dtype, shape=(), shared_name=var.name + "/grad_accum")
                    train_ops.append(
                        grad_accum.apply_indexed_slices_grad(
                            grad, local_step=self._local_step))
                    aggregated_grad.append(
                        grad_accum.take_indexed_slices_grad(
                            self._replicas_to_aggregate))

                # Remembered so the accumulator's global step can be set in
                # chief_init_op below.
                self._accumulator_list.append((grad_accum, var.device))

        aggregated_grads_and_vars = zip(aggregated_grad, var_list)

        # sync_op will be assigned to the same device as the global step.
        with ops.device(global_step.device), ops.name_scope(""):
            update_op = self._opt.apply_gradients(
                aggregated_grads_and_vars, global_step)

        # Create token queue.
        with ops.device(global_step.device), ops.name_scope(""):
            sync_token_queue = (data_flow_ops.FIFOQueue(
                -1,
                global_step.dtype.base_dtype,
                shapes=(),
                name="sync_token_q",
                shared_name="sync_token_q"))
            self._sync_token_queue = sync_token_queue

            # dummy_queue is passed to the queue runner. Don't use the real queues
            # because the queue runner doesn't automatically reopen it once it
            # closed queues in PS devices.
            dummy_queue = (data_flow_ops.FIFOQueue(
                1,
                types_pb2.DT_INT32,
                shapes=(),
                name="dummy_queue",
                shared_name="dummy_queue"))

        with ops.device(global_step.device), ops.name_scope(""):
            # Replicas have to wait until they can get a token from the token queue.
            with ops.control_dependencies(train_ops):
                token = sync_token_queue.dequeue()
            # The dequeued token becomes the replica's new local step.
            train_op = state_ops.assign(self._local_step, token)

            with ops.control_dependencies([update_op]):
                # Sync_op needs to insert tokens to the token queue at the end of the
                # step so the replicas can fetch them to start the next step.
                tokens = array_ops.fill([self._tokens_per_step], global_step)
                sync_op = sync_token_queue.enqueue_many((tokens, ))

            if self._variable_averages is not None:
                # Variable averaging runs only after the tokens are enqueued.
                with ops.control_dependencies([sync_op
                                               ]), ops.name_scope(""):
                    sync_op = self._variable_averages.apply(
                        self._variables_to_average)

            self._chief_queue_runner = queue_runner.QueueRunner(
                dummy_queue, [sync_op])
        for accum, dev in self._accumulator_list:
            with ops.device(dev):
                chief_init_ops.append(
                    accum.set_global_step(global_step, name="SetGlobalStep"))
        self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
        self._gradients_applied = True
        return train_op
def build(self, input_shape):
    """Validate the configuration and create the layer's weights.

    Normalizes `self.axis` to a list of non-negative indices, validates the
    virtual-batch settings, decides whether the fused implementation is
    used, then creates `gamma`, `beta`, `moving_mean`, `moving_variance`
    and, when `self.renorm` is set, the renorm variables. Sets
    `self.input_spec` and `self.built`.

    Args:
        input_shape: Shape of the input; converted with
            `tensor_shape.TensorShape`.

    Raises:
        ValueError: If the input rank is undefined, an axis is out of range
            or duplicated, the virtual batch configuration is invalid,
            fused batch norm is requested for an unsupported rank or axis,
            or an `axis` dimension of the input is undefined.
    """
    input_shape = tensor_shape.TensorShape(input_shape)
    if not input_shape.ndims:
        raise ValueError('Input has undefined rank:', input_shape)
    ndims = len(input_shape)

    # Convert axis to list and resolve negatives
    if isinstance(self.axis, int):
        self.axis = [self.axis]

    for idx, x in enumerate(self.axis):
        if x < 0:
            self.axis[idx] = ndims + x

    # Validate axes
    for x in self.axis:
        if x < 0 or x >= ndims:
            raise ValueError('Invalid axis: %d' % x)
    if len(self.axis) != len(set(self.axis)):
        raise ValueError('Duplicate axis: %s' % self.axis)

    if self.virtual_batch_size is not None:
        if self.virtual_batch_size <= 0:
            raise ValueError(
                'virtual_batch_size must be a positive integer that '
                'divides the true batch size of the input Tensor')
        # If using virtual batches, the first dimension must be the batch
        # dimension and cannot be the batch norm axis
        if 0 in self.axis:
            raise ValueError(
                'When using virtual_batch_size, the batch dimension '
                'must be 0 and thus axis cannot include 0')
        if self.adjustment is not None:
            raise ValueError(
                'When using virtual_batch_size, adjustment cannot '
                'be specified')

    if self.fused in (None, True):
        # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
        # output back to its original shape accordingly.
        if self._USE_V2_BEHAVIOR:
            # V2: an unspecified `fused` is resolved from the input rank,
            # while an explicit fused=True on non-4D input is an error.
            if self.fused is None:
                self.fused = (ndims == 4)
            elif self.fused and ndims != 4:
                raise ValueError(
                    'Batch normalization layers with fused=True only '
                    'support 4D input tensors.')
        else:
            # Non-V2: fall back to non-fused when fused cannot be used.
            assert self.fused is not None
            self.fused = (ndims == 4 and self._fused_can_be_used())
        # TODO(chrisying): fused batch norm is currently not supported for
        # multi-axis batch norm and by extension virtual batches. In some cases,
        # it might be possible to use fused batch norm but would require reshaping
        # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
        # particularly tricky. A compromise might be to just support the most
        # common use case (turning 5D w/ virtual batch to NCHW)

    if self.fused:
        # The fused path only understands channel axis 1 (NCHW) or 3 (NHWC).
        if self.axis == [1]:
            self._data_format = 'NCHW'
        elif self.axis == [3]:
            self._data_format = 'NHWC'
        else:
            raise ValueError(
                'Unsupported axis, fused batch norm only supports '
                'axis == [1] or axis == [3]')

    # Raise parameters of fp16 batch norm to fp32
    if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
        param_dtype = dtypes.float32
    else:
        param_dtype = self.dtype or dtypes.float32

    # Map each normalized axis to its static size; all must be known.
    axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
    for x in axis_to_dim:
        if axis_to_dim[x] is None:
            raise ValueError(
                'Input has undefined `axis` dimension. Input shape: ',
                input_shape)
    self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

    if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
        # Single axis batch norm (most common/default use-case)
        param_shape = (list(axis_to_dim.values())[0], )
    else:
        # Parameter shape is the original shape but with 1 in all non-axis dims
        param_shape = [
            axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
        ]
        if self.virtual_batch_size is not None:
            # When using virtual batches, add an extra dim at index 1
            param_shape.insert(1, 1)
            for idx, x in enumerate(self.axis):
                self.axis[idx] = x + 1  # Account for added dimension

    if self.scale:
        self.gamma = self.add_weight(name='gamma',
                                     shape=param_shape,
                                     dtype=param_dtype,
                                     initializer=self.gamma_initializer,
                                     regularizer=self.gamma_regularizer,
                                     constraint=self.gamma_constraint,
                                     trainable=True)
    else:
        self.gamma = None
        if self.fused:
            # The fused path reads a constant scale of 1.0 in place of gamma.
            self._gamma_const = array_ops.constant(1.0,
                                                   dtype=param_dtype,
                                                   shape=param_shape)

    if self.center:
        self.beta = self.add_weight(name='beta',
                                    shape=param_shape,
                                    dtype=param_dtype,
                                    initializer=self.beta_initializer,
                                    regularizer=self.beta_regularizer,
                                    constraint=self.beta_constraint,
                                    trainable=True)
    else:
        self.beta = None
        if self.fused:
            # The fused path reads a constant offset of 0.0 in place of beta.
            self._beta_const = array_ops.constant(0.0,
                                                  dtype=param_dtype,
                                                  shape=param_shape)

    try:
        # Disable variable partitioning when creating the moving mean and variance
        if hasattr(self, '_scope') and self._scope:
            partitioner = self._scope.partitioner
            self._scope.set_partitioner(None)
        else:
            partitioner = None
        self.moving_mean = self.add_weight(
            name='moving_mean',
            shape=param_shape,
            dtype=param_dtype,
            initializer=self.moving_mean_initializer,
            synchronization=tf_variables.VariableSynchronization.ON_READ,
            trainable=False,
            aggregation=tf_variables.VariableAggregation.MEAN)

        self.moving_variance = self.add_weight(
            name='moving_variance',
            shape=param_shape,
            dtype=param_dtype,
            initializer=self.moving_variance_initializer,
            synchronization=tf_variables.VariableSynchronization.ON_READ,
            trainable=False,
            aggregation=tf_variables.VariableAggregation.MEAN)

        if self.renorm:
            # Create variables to maintain the moving mean and standard deviation.
            # These are used in training and thus are different from the moving
            # averages above. The renorm variables are colocated with moving_mean
            # and moving_variance.
            # NOTE: below, the outer `with device` block causes the current device
            # stack to be cleared. The nested ones use a `lambda` to set the desired
            # device and ignore any devices that may be set by the custom getter.
            def _renorm_variable(name, shape):
                # Zero-initialized, non-trainable weight with the same
                # synchronization/aggregation as the moving statistics.
                var = self.add_weight(
                    name=name,
                    shape=shape,
                    dtype=param_dtype,
                    initializer=init_ops.zeros_initializer(),
                    synchronization=tf_variables.VariableSynchronization.
                    ON_READ,
                    trainable=False,
                    aggregation=tf_variables.VariableAggregation.MEAN)
                return var

            with distribution_strategy_context.get_distribution_strategy(
            ).colocate_vars_with(self.moving_mean):
                self.renorm_mean = _renorm_variable(
                    'renorm_mean', param_shape)
                self.renorm_mean_weight = _renorm_variable(
                    'renorm_mean_weight', ())
            # We initialize renorm_stddev to 0, and maintain the (0-initialized)
            # renorm_stddev_weight. This allows us to (1) mix the average
            # stddev with the minibatch stddev early in training, and (2) compute
            # the unbiased average stddev by dividing renorm_stddev by the weight.
            with distribution_strategy_context.get_distribution_strategy(
            ).colocate_vars_with(self.moving_variance):
                self.renorm_stddev = _renorm_variable(
                    'renorm_stddev', param_shape)
                self.renorm_stddev_weight = _renorm_variable(
                    'renorm_stddev_weight', ())
    finally:
        # Restore the partitioner that was disabled above, if there was one.
        if partitioner:
            self._scope.set_partitioner(partitioner)
    self.built = True
def call(self, inputs, training=None):
    """Apply batch normalization to `inputs`.

    Delegates to `_fused_batch_norm` when `self.fused` is set. Otherwise
    computes the moments over the non-normalized axes, builds (and, outside
    eager execution, registers via `add_update`) the moving-average updates
    unless `training` is statically False, and normalizes with
    `nn.batch_normalization`.

    Args:
        inputs: Input tensor.
        training: Training-mode indicator. When None, `K.learning_phase()`
            is used.

    Returns:
        The normalized outputs, reshaped back to the input's layout when
        virtual batching is used.
    """
    if training is None:
        training = K.learning_phase()

    in_eager_mode = context.executing_eagerly()
    if self.virtual_batch_size is not None:
        # Virtual batches (aka ghost batches) can be simulated by reshaping the
        # Tensor and reusing the existing batch norm implementation
        original_shape = [-1] + inputs.shape.as_list()[1:]
        expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

        # Will cause errors if virtual_batch_size does not divide the batch size
        inputs = array_ops.reshape(inputs, expanded_shape)

        def undo_virtual_batching(outputs):
            # Collapse the virtual-batch dimension back into the batch.
            outputs = array_ops.reshape(outputs, original_shape)
            return outputs

    if self.fused:
        outputs = self._fused_batch_norm(inputs, training=training)
        if self.virtual_batch_size is not None:
            # Currently never reaches here since fused_batch_norm does not support
            # virtual batching
            outputs = undo_virtual_batching(outputs)
        return outputs

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.get_shape()
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
        del reduction_axes[1]  # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

    def _broadcast(v):
        # Reshape only when `v` exists, has a different rank than the
        # input, and the reduction is not over all-but-the-last axis.
        if (v is not None and len(v.get_shape()) != ndims
                and reduction_axes != list(range(ndims - 1))):
            return array_ops.reshape(v, broadcast_shape)
        return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
        # Compose x -> x * scale + offset with a following
        # y -> y * then_scale + then_offset; None arguments are skipped.
        if then_scale is not None:
            scale *= then_scale
            offset *= then_scale
        if then_offset is not None:
            offset += then_offset
        return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:
        if self.adjustment:
            adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
            # Adjust only during training.
            adj_scale = tf_utils.smart_cond(
                training, lambda: adj_scale,
                lambda: array_ops.ones_like(adj_scale))
            adj_bias = tf_utils.smart_cond(
                training, lambda: adj_bias,
                lambda: array_ops.zeros_like(adj_bias))
            scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                offset)

        # Some of the computations here are not necessary when training==False
        # but not a constant. However, this makes the code simpler.
        keep_dims = self.virtual_batch_size is not None or len(
            self.axis) > 1
        mean, variance = self._moments(inputs,
                                       reduction_axes,
                                       keep_dims=keep_dims)

        moving_mean = self.moving_mean
        moving_variance = self.moving_variance

        # Batch statistics when training, moving statistics otherwise.
        mean = tf_utils.smart_cond(training, lambda: mean,
                                   lambda: moving_mean)
        variance = tf_utils.smart_cond(training, lambda: variance,
                                       lambda: moving_variance)

        if self.virtual_batch_size is not None:
            # This isn't strictly correct since in ghost batch norm, you are
            # supposed to sequentially update the moving_mean and moving_variance
            # with each sub-batch. However, since the moving statistics are only
            # used during evaluation, it is more efficient to just update in one
            # step and should not make a significant difference in the result.
            new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
            new_variance = math_ops.reduce_mean(variance,
                                                axis=1,
                                                keepdims=True)
        else:
            new_mean, new_variance = mean, variance

        if self.renorm:
            r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                new_mean, new_variance, training)
            # When training, the normalized values (say, x) will be transformed as
            # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
            # = x * (r * gamma) + (d * gamma + beta) with renorm.
            r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
            d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
            scale, offset = _compose_transforms(r, d, scale, offset)

        if distribution_strategy_context.in_cross_replica_context():
            strategy = distribution_strategy_context.get_distribution_strategy(
            )

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                # No update for a non-trainable layer in eager mode.
                if in_eager_mode and not self.trainable:
                    return
                return strategy.extended.update(
                    var,
                    self._assign_moving_average, (value, self.momentum),
                    group=False)

            # We need to unwrap the moving_mean or moving_variance in the case of
            # training being false to match the output of true_fn and false_fn
            # in the smart cond.
            mean_update = tf_utils.smart_cond(
                training, lambda: _do_update(self.moving_mean, new_mean),
                lambda: strategy.unwrap(self.moving_mean))
            variance_update = tf_utils.smart_cond(
                training,
                lambda: _do_update(self.moving_variance, new_variance),
                lambda: strategy.unwrap(self.moving_variance))
        else:

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                # No update for a non-trainable layer in eager mode.
                if in_eager_mode and not self.trainable:
                    return
                return self._assign_moving_average(var, value, self.momentum)

            mean_update = tf_utils.smart_cond(
                training, lambda: _do_update(self.moving_mean, new_mean),
                lambda: self.moving_mean)
            variance_update = tf_utils.smart_cond(
                training,
                lambda: _do_update(self.moving_variance, new_variance),
                lambda: self.moving_variance)
        # Updates are only registered with the layer outside eager execution.
        if not context.executing_eagerly():
            self.add_update(mean_update, inputs=True)
            self.add_update(variance_update, inputs=True)

    else:
        # Statically not training: normalize with the moving statistics.
        mean, variance = self.moving_mean, self.moving_variance

    # Statistics (and offset) are cast to the input dtype before normalizing.
    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
        offset = math_ops.cast(offset, inputs.dtype)
    outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                     _broadcast(variance), offset, scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
        outputs = undo_virtual_batching(outputs)

    return outputs