def call_replica_local_fn(fn, *args, **kwargs):
    """Invoke `fn`, which may touch replica-local variables.

    When called from a cross-replica context under a non-TPU strategy, `fn`
    is dispatched through `strategy.extended.call_for_each_replica` inside
    the strategy scope; in every other case `fn` is called directly.

    Arguments:
      fn: The function to call.
      *args: Positional arguments forwarded to `fn`.
      **kwargs: Keyword arguments forwarded to `fn`. A `strategy` entry, when
        present, is removed from `kwargs` and used instead of the strategy
        reported by `ds_context.get_strategy()`.

    Returns:
      The result of calling `fn`.
    """
    # TODO(b/120571621): We want to avoid reductions here since TPUStrategy
    # does not implement replica local variables. Remove this hack once we
    # support TPUReplicaLocalVariables.
    if 'strategy' in kwargs:
        strategy = kwargs.pop('strategy')
    else:
        # Adopt the ambient strategy only when it is truthy.
        strategy = ds_context.get_strategy() or None

    is_tpu = is_tpu_strategy(strategy)
    dispatch_per_replica = (
        not is_tpu and strategy and ds_context.in_cross_replica_context())
    if not dispatch_per_replica:
        return fn(*args, **kwargs)

    with strategy.scope():
        return strategy.extended.call_for_each_replica(fn, args, kwargs)
def create_slot(primary, val, name, colocate_with_primary=True):
    """Create a slot initialized to the given value.

    The type of the slot is determined by the given value.

    Args:
      primary: The primary `Variable` or `Tensor`.
      val: A `Tensor` specifying the initial value of the slot.
      name: Name to use for the slot variable.
      colocate_with_primary: Boolean. If True the slot is located
        on the same device as `primary`.

    Returns:
      A `Variable` object.
    """
    # Scope the slot name in the namespace of the primary variable.
    # Set "primary.op.name + '/' + name" as default name, so the scope name of
    # optimizer can be shared when reuse is True. Meanwhile when reuse is False
    # and the same name has been previously used, the scope name will add '_N'
    # as suffix for unique identifications.
    validate_shape = val.get_shape().is_fully_defined()
    if context.executing_eagerly():
        prefix = primary._shared_name  # pylint: disable=protected-access
    else:
        prefix = primary.op.name
    with variable_scope.variable_scope(None, prefix + "/" + name):
        if not colocate_with_primary:
            return _create_slot_var(primary, val, "", validate_shape, None,
                                    None)
        distribution_strategy = distribution_strategy_context.get_strategy()
        # Colocation is requested through the strategy's `extended` object,
        # matching `create_slot_with_initializer` and the optimizer slot
        # helpers.
        with distribution_strategy.extended.colocate_vars_with(primary):
            return _create_slot_var(primary, val, "", validate_shape, None,
                                    None)
def _fused_batch_norm(self, inputs, training):
    """Run fused batch norm and register the moving-statistic updates.

    Args:
      inputs: Tensor to normalize.
      training: Value passed to `tf_utils.smart_cond` to choose between the
        training kernel and the inference kernel (which is fed the moving
        mean and variance).

    Returns:
      The normalized output of `nn.fused_batch_norm`.
    """
    # Use the constant stand-ins when scaling/centering is disabled.
    gamma = self.gamma if self.scale else self._gamma_const
    beta = self.beta if self.center else self._beta_const

    def _fused_batch_norm_training():
        return nn.fused_batch_norm(
            inputs,
            gamma,
            beta,
            epsilon=self.epsilon,
            data_format=self._data_format)

    def _fused_batch_norm_inference():
        return nn.fused_batch_norm(
            inputs,
            gamma,
            beta,
            mean=self.moving_mean,
            variance=self.moving_variance,
            epsilon=self.epsilon,
            is_training=False,
            data_format=self._data_format)

    output, mean, variance = tf_utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)

    if not self._bessels_correction_test_only:
        # Remove Bessel's correction to be consistent with non-fused batch
        # norm. Note that the variance computed by fused batch norm is
        # with Bessel's correction.
        sample_size = math_ops.cast(
            array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
        one = math_ops.cast(1.0, variance.dtype)
        variance *= (sample_size - one) / sample_size

    training_value = tf_utils.constant_value(training)
    if training_value is None:
        # `training` is not statically known: a momentum of 1.0 is selected
        # on the non-training side of the conditional.
        momentum = tf_utils.smart_cond(training,
                                       lambda: self.momentum,
                                       lambda: 1.0)
    else:
        momentum = ops.convert_to_tensor(self.momentum)

    # Statically known to be inference: no moving-average updates.
    if training_value is not None and not training_value:
        return output

    if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()

        # NOTE(review): this path passes `self.momentum`, not the conditional
        # `momentum` tensor computed above -- confirm this is intended.
        def _moving_update(moving_stat, batch_stat):
            return strategy.extended.update(
                moving_stat, self._assign_moving_average,
                (batch_stat, self.momentum))
    else:

        def _moving_update(moving_stat, batch_stat):
            return self._assign_moving_average(moving_stat, batch_stat,
                                               momentum)

    mean_update = _moving_update(self.moving_mean, mean)
    variance_update = _moving_update(self.moving_variance, variance)
    self.add_update(mean_update, inputs=True)
    self.add_update(variance_update, inputs=True)
    return output
def scale_loss_for_distribution(loss_value):
    """Divide `loss_value` by the number of replicas in sync.

    Args:
      loss_value: The loss to scale.

    Returns:
      `loss_value` multiplied by `1 / num_replicas_in_sync` when the current
      strategy reports more than one replica; otherwise `loss_value` as is.
    """
    strategy = distribution_strategy_context.get_strategy()
    num_replicas = strategy.num_replicas_in_sync
    if num_replicas <= 1:
        return loss_value
    loss_value *= (1. / num_replicas)
    return loss_value
def _assert_in_default_state(t):
    """Assert that the distribution-strategy context is in its default state.

    Args:
      t: Object providing the `assertIs` / `assertFalse` methods used to
        report failures.
    """
    # The replica context must be the default one.
    default_replica_context = ds_context._get_default_replica_context()
    t.assertIs(default_replica_context, ds_context.get_replica_context())

    # There must be no cross-replica context.
    t.assertIs(None, ds_context.get_cross_replica_context())
    t.assertFalse(ds_context.in_cross_replica_context())

    # The current strategy must be the default one and not count as "set".
    default_strategy = ds_context._get_default_strategy()
    t.assertIs(default_strategy, ds_context.get_strategy())
    t.assertFalse(ds_context.has_strategy())
def add_slot(self, var, slot_name, initializer="zeros"):
    """Return the `slot_name` slot variable of `var`, creating it if needed.

    Args:
      var: The variable the slot belongs to.
      slot_name: Name of the slot.
      initializer: A string or callable resolved through `initializers.get`
        and invoked with `var`'s shape and dtype, or any other object, which
        is used directly as the initial value.

    Returns:
      The slot variable.
    """
    if slot_name not in self._slot_names:
        self._slot_names.append(slot_name)

    slot_dict = self._slots.setdefault(_var_key(var), {})
    existing = slot_dict.get(slot_name, None)
    if existing is not None:
        return existing

    if isinstance(initializer, six.string_types) or callable(initializer):
        initializer = initializers.get(initializer)
        initial_value = functools.partial(
            initializer, shape=var.shape, dtype=var.dtype)
    else:
        initial_value = initializer

    strategy = distribute_ctx.get_strategy()
    with strategy.extended.colocate_vars_with(var):
        slot_variable_name = "%s/%s" % (var._shared_name, slot_name)  # pylint: disable=protected-access
        weight = tf_variables.Variable(
            name=slot_variable_name,
            dtype=var.dtype,
            trainable=False,
            initial_value=initial_value)
    backend.track_variable(weight)
    slot_dict[slot_name] = weight
    self._restore_slot_variable(
        slot_name=slot_name, variable=var, slot_variable=weight)
    self._weights.append(weight)
    return weight
def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Return the non-slot variable `name`, creating it on first use.

    Variables are cached per `(name, graph)`; the graph component is `None`
    when executing eagerly.

    Args:
      initial_value: Initial value for a newly created variable. When
        executing eagerly it is replaced by the value returned from
        `_preload_simple_restoration`, if that is not `None`.
      name: Name of the variable.
      colocate_with: Variable the new variable is colocated with.

    Returns:
      The cached or newly created variable.
    """
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot
    # variables.
    eager = context.executing_eagerly()
    key = (name, None if eager else colocate_with.graph)
    cached = self._non_slot_dict.get(key, None)
    if cached is not None:
        return cached

    self._maybe_initialize_trackable()
    distribution_strategy = distribute_ctx.get_strategy()
    with distribution_strategy.extended.colocate_vars_with(colocate_with):
        if eager:
            restored_initial_value = self._preload_simple_restoration(
                name=name, shape=None)
            if restored_initial_value is not None:
                initial_value = restored_initial_value
        use_resource = resource_variable_ops.is_resource_variable(
            colocate_with)
        v = variable_scope.variable(
            initial_value,
            name=name,
            trainable=False,
            use_resource=use_resource)
    # Restore this variable by name if necessary, but don't add a
    # Trackable dependency. Optimizers return the current graph's
    # non-slot variables from _checkpoint_dependencies explicitly rather
    # than unconditionally adding dependencies (since there may be multiple
    # non-slot variables with the same name in different graphs, trying to
    # save all of them would result in errors).
    self._handle_deferred_dependencies(name=name, trackable=v)
    self._non_slot_dict[key] = v
    return v
def _scale_loss(loss_value):
    """Scale `loss_value` by `1 / num_replicas` under MEAN loss reduction.

    Also records on the default graph, via `_is_loss_scaled_by_optimizer`,
    whether the scaling was applied.

    Args:
      loss_value: The loss to scale.

    Returns:
      The possibly scaled loss.
    """
    graph = ops.get_default_graph()
    graph._is_loss_scaled_by_optimizer = False  # pylint: disable=protected-access

    if distribute_lib.get_loss_reduction() != ds_reduce_util.ReduceOp.MEAN:
        return loss_value

    num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
    if num_replicas > 1:
        loss_value *= (1. / num_replicas)
        graph._is_loss_scaled_by_optimizer = True  # pylint: disable=protected-access
    return loss_value
def merge_fn(dist, s):
    """Check the cross-replica state under the default strategy.

    Args:
      dist: The strategy handed to the merge function.
      s: String to append to the `"foo_"` prefix.

    Returns:
      `"foo_" + s`.
    """
    default_strategy = ds_context._get_default_strategy()
    self.assertIs(default_strategy, dist)

    # Inside a merge call there is a cross-replica context and no replica
    # context.
    self.assertIs(None, ds_context.get_replica_context())
    self.assertIs(dist, ds_context.get_cross_replica_context())
    self.assertTrue(ds_context.in_cross_replica_context())

    # The default strategy is current, but does not count as "set".
    self.assertIs(dist, ds_context.get_strategy())
    self.assertFalse(ds_context.has_strategy())
    return "foo_" + s
def testSetStrategy(self):
    """`experimental_set_strategy` installs, replaces and clears a strategy."""
    _assert_in_default_state(self)
    first_strategy = _TestStrategy()
    second_strategy = _TestStrategy()

    # Installing a strategy puts us in its cross-replica context.
    ds_context.experimental_set_strategy(first_strategy)
    self.assertIs(None, ds_context.get_replica_context())
    self.assertIs(first_strategy, ds_context.get_cross_replica_context())
    self.assertTrue(ds_context.in_cross_replica_context())
    self.assertTrue(ds_context.has_strategy())
    self.assertIs(first_strategy, ds_context.get_strategy())

    expected_value = _get_test_variable(
        "baz",
        variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    created = variable_scope.variable(1.0, name="baz")
    self.assertDictEqual(expected_value, created)

    # A second call replaces the installed strategy.
    ds_context.experimental_set_strategy(second_strategy)
    self.assertIs(second_strategy, ds_context.get_strategy())

    # Passing None restores the default state.
    ds_context.experimental_set_strategy(None)
    _assert_in_default_state(self)
def _reduce_weighted_loss(
        weighted_losses,
        reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE):
    """Reduce the individual weighted loss measurements.

    Args:
      weighted_losses: The weighted losses to reduce.
      reduction: A `losses_impl.ReductionV2` value. `NONE` returns the input
        unchanged, `SUM_OVER_BATCH_SIZE` divides the sum via `_safe_mean`,
        and any other value returns the plain sum.

    Returns:
      The reduced loss.
    """
    if reduction == losses_impl.ReductionV2.NONE:
        return weighted_losses

    loss = math_ops.reduce_sum(weighted_losses)
    if reduction == losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE:
        # Used to convert from local to global batch size.
        strategy = distribution_strategy_context.get_strategy()
        num_replicas = strategy.num_replicas_in_sync
        loss = _safe_mean(loss,
                          num_replicas * _num_elements(weighted_losses))
    return loss
def _get_tensor(self, is_finite):
    """Build a tensor that is `1.` when `is_finite` holds and NaN otherwise.

    Under a distribution strategy the value is produced per replica: the
    replica with id 0 yields the tensor, every other replica yields `1.`.

    Args:
      is_finite: Predicate passed to `control_flow_ops.cond`.

    Returns:
      The tensor, or the result of `call_for_each_replica` when a strategy is
      active.
    """
    tensor = control_flow_ops.cond(is_finite,
                                   lambda: 1.,
                                   lambda: float('NaN'))

    if distribution_strategy_context.has_strategy():

        def get():
            replica_context = (
                distribution_strategy_context.get_replica_context())
            rep_id = replica_context.replica_id_in_sync_group
            is_first_replica = math_ops.equal(rep_id, 0)
            return control_flow_ops.cond(is_first_replica,
                                         lambda: tensor,
                                         lambda: 1.)

        distribution = distribution_strategy_context.get_strategy()
        return distribution.extended.call_for_each_replica(get)

    return tensor
def testScopeDeviceNestingError(self):
    """Exiting a strategy scope inside a newer device scope must fail."""
    _assert_in_default_state(self)
    strategy = _TestStrategy()
    # Open a device scope with dist.scope().
    strategy.extended._default_device = "/device:GPU:0"
    strategy_scope = strategy.scope()
    strategy_scope.__enter__()
    self.assertIs(strategy, ds_context.get_strategy())

    # Leaving the strategy scope while an inner device scope is still open
    # is expected to raise.
    with ops.device("/device:CPU:0"), self.assertRaisesRegexp(
            RuntimeError, "Device scope nesting error"):
        strategy_scope.__exit__(None, None, None)

    # Once the device scope is closed, the exit succeeds.
    strategy_scope.__exit__(None, None, None)
    _assert_in_default_state(self)
def run_fn():
    """Check the replica-context state and variable creation in a replica."""
    replica_context = ds_context.get_replica_context()
    self.assertTrue(replica_context is not None)

    # In replica context there is no cross-replica context.
    self.assertIs(None, ds_context.get_cross_replica_context())
    self.assertFalse(ds_context.in_cross_replica_context())
    self.assertTrue(ds_context.has_strategy())
    self.assertIs(dist, ds_context.get_strategy())

    merged = replica_context.merge_call(None, test_arg="foo")
    self.assertEqual("foo", merged)

    expected_value = _get_test_variable(
        "bar",
        variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    created = variable_scope.variable(1.0, name="bar")
    self.assertDictEqual(expected_value, created)
def testScopeVarScopeNestingError(self):
    """Exiting a strategy scope inside a newer variable scope must fail."""
    # We create a new graph here to simplify clean-up, since the error
    # we are triggering happens in the middle of scope.__exit__() and
    # leaves us in a weird state.
    with ops.Graph().as_default():
        _assert_in_default_state(self)
        strategy = _TestStrategy()
        strategy_scope = strategy.scope()
        strategy_scope.__enter__()
        self.assertIs(strategy, ds_context.get_strategy())
        with variable_scope.variable_scope("AA"), self.assertRaisesRegexp(
                RuntimeError, "Variable scope nesting error"):
            strategy_scope.__exit__(None, None, None)
    _assert_in_default_state(self)
def testScopeVarCreatorNestingError(self):
    """Exiting a strategy scope inside a newer creator scope must fail."""

    def creator(next_creator, **kwargs):
        # Pass-through creator; only its presence on the stack matters.
        return next_creator(**kwargs)

    _assert_in_default_state(self)
    strategy = _TestStrategy()
    strategy_scope = strategy.scope()
    strategy_scope.__enter__()
    self.assertIs(strategy, ds_context.get_strategy())

    creator_scope = variable_scope.variable_creator_scope(creator)
    with creator_scope, self.assertRaisesRegexp(
            RuntimeError, "Variable creator scope nesting error"):
        strategy_scope.__exit__(None, None, None)

    # With the creator scope closed, the exit succeeds.
    strategy_scope.__exit__(None, None, None)
    _assert_in_default_state(self)
def begin(self):
    """Create the stop variable and the op that adds a fed value to it.

    Must be called in a cross-replica context. Sets `_global_step_tensor`,
    `_stop_var`, `_stop_placeholder` and `_stop_op` on the hook.
    """
    self._global_step_tensor = training_util.get_global_step()
    self._stop_var = self._get_or_create_stop_var_with_aggregation()
    assert distribution_strategy_context.in_cross_replica_context()

    strategy = distribution_strategy_context.get_strategy()
    self._stop_placeholder = None

    def stop_op_fn(var):
        placeholder = array_ops.placeholder_with_default(
            0, (), name='stop_value')
        # Remember only the first placeholder that gets created.
        if self._stop_placeholder is None:
            self._stop_placeholder = placeholder
        return var.assign_add(placeholder)

    self._stop_op = strategy.run(stop_op_fn, args=(self._stop_var,))
def testScope(self):
    """Inside `strategy.scope()` the strategy's cross-replica state is set."""
    _assert_in_default_state(self)
    strategy = _TestStrategy()
    with strategy.scope():
        self.assertIs(None,
                      distribution_strategy_context.get_replica_context())
        cross_replica_context = (
            distribution_strategy_context.get_cross_replica_context())
        self.assertIs(strategy, cross_replica_context)
        self.assertTrue(
            distribution_strategy_context.in_cross_replica_context())
        self.assertTrue(distribution_strategy_context.has_strategy())
        self.assertIs(strategy,
                      distribution_strategy_context.get_strategy())

        expected_value = _get_test_variable(
            "baz",
            variable_scope.VariableSynchronization.AUTO,
            variable_scope.VariableAggregation.NONE)
        created = variable_scope.variable(1.0, name="baz")
        self.assertDictEqual(expected_value, created)
    # Leaving the scope restores the default state.
    _assert_in_default_state(self)
def __init__(self, trackable):
    """Wrap a `Trackable` that exposes exactly one saveable.

    Sets up `_getter` / `_setter` callables and `_num_tensors` for the
    trackable's saveable, using a different mechanism in eager and graph
    mode.

    Args:
      trackable: A `tracking.Trackable` instance.

    Raises:
      ValueError: If `trackable` is not a `Trackable`, or if it does not
        report exactly one saveable.
    """
    if not isinstance(trackable, tracking.Trackable):
        raise ValueError('%s is not a Trackable object.' % (trackable,))
    self._trackable = trackable
    self._distribute_strategy = distribution_strategy_context.get_strategy()

    # TODO(b/141682913): Figure out why this is private and fix it.
    saveable_factories = trackable._gather_saveables_for_checkpoint().values()  # pylint: disable=protected-access
    if len(saveable_factories) != 1:
        raise ValueError('Only Trackables with one Saveable are supported.')
    saveable = list(saveable_factories)[0]

    if not ops.executing_eagerly_outside_functions():
        # If we're in Graph mode, we need to evaluate the Saveable only once
        # and cache the resulting restore graph. Failing to do this will
        # result in new assignment ops being added to the graph each time
        # set_weights() is called.
        self._placeholder_tensors = []
        self._saveable = saveable()
        self._num_tensors = len(self._saveable.specs)
        for spec in self._saveable.specs:
            tensor = spec.tensor
            placeholder = array_ops.placeholder(tensor.dtype, tensor.shape)
            self._placeholder_tensors.append(placeholder)
        self._assign_op = self._saveable.restore(self._placeholder_tensors,
                                                 None)
        self._setter = self._set_weights_v1
        self._getter = lambda: [
            spec.tensor for spec in self._saveable.specs
        ]
        return

    # If we're in eager mode, we need to defer calling the Trackable's
    # saveable() callable until data export time.
    # However, it is safe to call the saveable as many times as we want, so
    # we will call it now to figure out how many tensors this Trackable will
    # produce.
    self._saveable = saveable
    self._num_tensors = len(self._saveable().specs)
    self._setter = lambda weights: self._saveable().restore(weights, None)
    self._getter = lambda: [
        spec.tensor for spec in self._saveable().specs
    ]
def _raise_if_strategy_unsupported(self):
    """Raise if the current strategy does not support loss scaling.

    Raises:
      ValueError: If `strategy_supports_loss_scaling()` is false. The message
        is specific to TPU strategies when one of them is current.
    """
    if strategy_supports_loss_scaling():
        return

    strategy = distribution_strategy_context.get_strategy()
    tpu_strategy_types = (tpu_strategy.TPUStrategy,
                          tpu_strategy.TPUStrategyV1)
    if isinstance(strategy, tpu_strategy_types):
        raise ValueError(
            'Loss scaling is not supported with TPUStrategy. Loss scaling is '
            'unnecessary with TPUs, since they support bfloat16 instead of '
            'float16 and bfloat16 does not require loss scaling. You should '
            'remove the use of the LossScaleOptimizer when TPUs are used.')
    raise ValueError(
        'Loss scaling is not supported with the '
        'tf.distribute.Strategy: %s. Try using a different '
        'Strategy, e.g. a MirroredStrategy' % strategy.__class__.__name__)
def run_fn():
    """Check the replica-context state and variable creation in a replica."""
    replica_context = distribution_strategy_context.get_replica_context()
    self.assertTrue(replica_context is not None)

    # In replica context there is no cross-replica context.
    cross_replica_context = (
        distribution_strategy_context.get_cross_replica_context())
    self.assertIs(None, cross_replica_context)
    self.assertFalse(
        distribution_strategy_context.in_cross_replica_context())
    self.assertTrue(distribution_strategy_context.has_strategy())
    self.assertIs(dist, distribution_strategy_context.get_strategy())

    merged = replica_context.merge_call(None, test_arg="foo")
    self.assertEqual("foo", merged)

    expected_value = _get_test_variable(
        "bar",
        variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    created = variable_scope.variable(1.0, name="bar")
    self.assertDictEqual(expected_value, created)
def decorated(metric_obj, *args, **kwargs):
    """Call `update_state_fn` and register its op via `add_update()`.

    Args:
      metric_obj: The metric whose state is being updated.
      *args: Positional arguments forwarded to `update_state_fn`.
      **kwargs: Keyword arguments forwarded to `update_state_fn`.

    Returns:
      The update op returned by `update_state_fn` (may be `None`).

    Raises:
      ValueError: Under a TPU strategy in replica context, if any metric
        weight was not created in the strategy's scope.
    """
    strategy = distribution_strategy_context.get_strategy()

    def _created_outside_tpu_scope(weight):
        return (
            backend.is_tpu_strategy(strategy)
            and not strategy.extended.variable_created_in_scope(weight)
            and not distribution_strategy_context.in_cross_replica_context())

    if any(_created_outside_tpu_scope(w) for w in metric_obj.weights):
        raise ValueError(
            'Trying to run metric.update_state in replica context when '
            'the metric was not created in TPUStrategy scope. '
            'Make sure the keras Metric is created in TPUstrategy scope. ')

    with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
        update_op = update_state_fn(*args, **kwargs)
    # update_op will be None in eager execution.
    if update_op is not None:
        metric_obj.add_update(update_op)
    return update_op
def add_slot(var, slot_name, initializer="zeros", shape=None):
    """Return the `slot_name` slot variable of `var`, creating it if needed.

    Args:
      var: The variable the slot belongs to.
      slot_name: Name of the slot.
      initializer: A string or callable resolved through `initializers.get`,
        or any other object, which is used directly as the initial value.
      shape: Optional shape passed to the initializer. When `None` and the
        resolved initializer is not a `CheckpointInitialValueCallable`,
        `var.shape` is used.

    Returns:
      The slot variable.
    """
    if slot_name not in self._slot_names:
        self._slot_names.append(slot_name)

    slot_dict = self._slots.setdefault(optimizer_v2._var_key(var), {})
    existing = slot_dict.get(slot_name, None)
    if existing is not None:
        return existing

    if isinstance(initializer, six.string_types) or callable(initializer):
        initializer = initializers.get(initializer)
        use_given_shape = isinstance(
            initializer,
            trackable.CheckpointInitialValueCallable) or (shape is not None)
        slot_shape = shape if use_given_shape else var.shape
        initial_value = functools.partial(
            initializer, shape=slot_shape, dtype=var.dtype)
    else:
        initial_value = initializer

    strategy = distribute_ctx.get_strategy()
    with strategy.extended.colocate_vars_with(var):
        if isinstance(var, de.TrainableWrapper):
            # `TrainableWrapper` variables get their slots from
            # `de.create_slots`.
            weight = de.create_slots(var, initial_value, slot_name,
                                     var._shared_name, self._bp_v2)
        else:
            slot_variable_name = "%s/%s" % (var._shared_name, slot_name)  # pylint: disable=protected-access
            weight = variables.Variable(
                name=slot_variable_name,
                dtype=var.dtype,
                trainable=False,
                initial_value=initial_value,
            )
    backend.track_variable(weight)
    slot_dict[slot_name] = weight
    self._restore_slot_variable(
        slot_name=slot_name, variable=var, slot_variable=weight)
    self._weights.append(weight)
    return weight
def apply_gradients(self,
                    grads_vars_and_constraints,
                    name=None,
                    experimental_aggregate_gradients=True):
    """Apply gradients to variables, carrying a constraint per variable.

    Args:
      grads_vars_and_constraints: Iterable of
        `(gradient, variable, constraint)` triples.
      name: Optional name forwarded to `_distributed_apply`.
      experimental_aggregate_gradients: Whether to pass the gradients
        through `_aggregate_gradients` before applying them.

    Returns:
      A no-op when nothing is left after filtering, otherwise the result of
      the replica context's `merge_call` on `_distributed_apply`.

    Raises:
      RuntimeError: If called in a cross-replica context.
      NotImplementedError: If aggregation is disabled under a strategy whose
        `extended` is a `ParameterServerStrategyExtended`.
    """
    grads_vars_and_constraints = _filter_grads(grads_vars_and_constraints)
    var_list = [var for (_, var, _) in grads_vars_and_constraints]
    constraint_list = [
        constraint for (_, _, constraint) in grads_vars_and_constraints
    ]

    with backend.name_scope(self._name):
        with ops.init_scope():
            self._create_all_weights(var_list)

        if not grads_vars_and_constraints:
            return control_flow_ops.no_op()

        if distribute_ctx.in_cross_replica_context():
            raise RuntimeError(
                "`apply_gradients() cannot be called in cross-replica context. "
                "Use `tf.distribute.Strategy.run` to enter replica "
                "context.")

        strategy = distribute_ctx.get_strategy()
        aggregation_unsupported = (
            not experimental_aggregate_gradients
            and strategy
            and isinstance(
                strategy.extended,
                parameter_server_strategy.ParameterServerStrategyExtended))
        if aggregation_unsupported:
            raise NotImplementedError(
                "`experimental_aggregate_gradients=False is not supported for "
                "ParameterServerStrategy and CentralStorageStrategy")

        apply_state = self._prepare(var_list)
        if experimental_aggregate_gradients:
            reduced_grads = self._aggregate_gradients(
                grads_vars_and_constraints)
            var_list = [var for _, var, _ in grads_vars_and_constraints]
            # Re-pair the reduced gradients with their variables and
            # constraints.
            grads_vars_and_constraints = list(
                zip(reduced_grads, var_list, constraint_list))

        apply_fn = functools.partial(
            self._distributed_apply, apply_state=apply_state)
        return distribute_ctx.get_replica_context().merge_call(
            apply_fn,
            args=(grads_vars_and_constraints,),
            kwargs={"name": name})
def set_last_step_output(self, name, output, reduce_op=None):
    """Record `output` under `name` as an output of the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match
        tensor name.
      output: The tensors that should be outputted with `name`.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context. When present, the
        outputs from all the replicas are reduced using the current
        distribution strategy's `reduce` method, so the type of `output`
        must be what that method supports (e.g. a `PerReplica` value under
        MirroredStrategy). The reduce method is also recorded in
        `_last_step_outputs_reduce_ops` so the outputs can later be
        interpreted as already reduced or not.
    """
    if not distribution_strategy_context.in_cross_replica_context():
        assert reduce_op is not None

        def merge_fn(distribution, value):
            self._last_step_outputs[name] = distribution.reduce(
                reduce_op, value, axis=None)
            # Setting this inside the `merge_fn` because all replicas share
            # the same context object, so it's more robust to set it only
            # once (even if all the replicas are trying to set the same
            # value).
            self._last_step_outputs_reduce_ops[name] = reduce_op

        replica_context = distribution_strategy_context.get_replica_context()
        replica_context.merge_call(merge_fn, args=(output,))
        return

    self._last_step_outputs_reduce_ops[name] = reduce_op
    if reduce_op is None:
        self._last_step_outputs[name] = output
        return
    distribution = distribution_strategy_context.get_strategy()
    self._last_step_outputs[name] = distribution.reduce(
        reduce_op, output, axis=None)
def testScope(self):
    """Inside `strategy.scope()` the strategy's cross-replica state is set."""
    _assert_in_default_state(self)
    strategy = _TestStrategy()
    with strategy.scope():
        self.assertIs(None,
                      distribution_strategy_context.get_replica_context())
        cross_replica_context = (
            distribution_strategy_context.get_cross_replica_context())
        self.assertIs(strategy, cross_replica_context)
        self.assertTrue(
            distribution_strategy_context.in_cross_replica_context())
        self.assertTrue(distribution_strategy_context.has_strategy())
        self.assertIs(strategy,
                      distribution_strategy_context.get_strategy())

        expected_value = _get_test_variable(
            "baz",
            variable_scope.VariableSynchronization.AUTO,
            variable_scope.VariableAggregation.NONE)
        created = variable_scope.variable(1.0, name="baz")
        self.assertDictEqual(expected_value, created)
    # Leaving the scope restores the default state.
    _assert_in_default_state(self)
def _assign_moving_average(self, variable, value, momentum):
    """Move `variable` towards `value` by a factor of `1 - momentum`.

    Computes `variable -= (variable - value) * (1 - momentum)`.

    Args:
      variable: The moving-average variable to update.
      value: The new observation; cast to `variable.dtype`.
      momentum: Momentum of the moving average.

    Returns:
      The `assign_sub` op that updates `variable`.
    """
    with ops.name_scope(None, 'AssignMovingAvg',
                        [variable, value, momentum]) as scope:
        # TODO(b/120571621): We want to avoid colocating the variables here
        # since TPUStrategy does not implement replica local variables.
        # Remove this hack once we support TPULocalVariables.
        is_tpu_strategy = False
        if distribution_strategy_context.has_strategy():
            distribute = distribution_strategy_context.get_strategy()
            is_tpu_strategy = distribute.__class__.__name__ == 'TPUStrategy'

        def _apply_update():
            decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
            if decay.dtype != variable.dtype.base_dtype:
                decay = math_ops.cast(decay, variable.dtype.base_dtype)
            update_delta = (
                variable - math_ops.cast(value, variable.dtype)) * decay
            return state_ops.assign_sub(variable, update_delta, name=scope)

        # The flag used to be computed and then ignored, so the colocation
        # constraint was applied under TPUStrategy as well. Honor it now.
        if is_tpu_strategy:
            return _apply_update()
        with ops.colocate_with(variable):
            return _apply_update()
def create_slot_with_initializer(primary,
                                 initializer,
                                 shape,
                                 dtype,
                                 name,
                                 colocate_with_primary=True):
    """Create a slot variable initialized by an `Initializer`.

    The shape and type of the slot are given explicitly by `shape` and
    `dtype`.

    Args:
      primary: The primary `Variable` or `Tensor`.
      initializer: An `Initializer`. The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      name: Name to use for the slot variable.
      colocate_with_primary: Boolean. If True the slot is located
        on the same device as `primary`.

    Returns:
      A `Variable` object.
    """
    # Scope the slot name in the namespace of the primary variable.
    # Set "primary.op.name + '/' + name" as default name, so the scope name of
    # optimizer can be shared when reuse is True. Meanwhile when reuse is False
    # and the same name has been previously used, the scope name will add '_N'
    # as suffix for unique identifications.
    validate_shape = shape.is_fully_defined()
    if context.executing_eagerly():
        prefix = primary._shared_name  # pylint: disable=protected-access
    else:
        prefix = primary.op.name

    with variable_scope.variable_scope(None, prefix + "/" + name):
        if not colocate_with_primary:
            return _create_slot_var(primary, initializer, "", validate_shape,
                                    shape, dtype)
        distribution_strategy = distribution_strategy_context.get_strategy()
        with distribution_strategy.extended.colocate_vars_with(primary):
            return _create_slot_var(primary, initializer, "", validate_shape,
                                    shape, dtype)
def _get_input_from_iterator(iterator):
    """Pull one element from `iterator` and split it into model inputs.

    A bare tensor or dict is treated as a one-element sequence. Elements of
    length 1, 2 and 3 map to `(x,)`, `(x, y)` and `(x, y, sample_weights)`;
    missing entries are `None`.

    Args:
      iterator: Iterator yielding the input elements.

    Returns:
      A tuple `(x, y, sample_weights)`.
    """
    next_element = next(iterator)
    if tensor_util.is_tensor(next_element) or isinstance(next_element, dict):
        next_element = [next_element]

    num_parts = len(next_element)
    y = None
    sample_weights = None
    if num_parts == 1:
        (x,) = next_element
    elif num_parts == 2:
        x, y = next_element
    else:
        x, y, sample_weights = next_element

    # Validate that all the elements in x and y are of the same type and
    # shape.
    strategy = distribution_strategy_context.get_strategy()
    dist_utils.validate_distributed_dataset_inputs(strategy, x, y,
                                                   sample_weights)
    return x, y, sample_weights
def strategy_supports_loss_scaling():
    """Return True if the current Strategy supports loss scaling.

    Returns:
      True when no strategy is active or the active strategy is an instance
      of one of the supported strategy classes; False otherwise.
    """
    if not distribution_strategy_context.has_strategy():
        return True

    # Strategies are supported if either there is only one replica or if
    # variables are replicated per device. Otherwise, the current model.fit()
    # implementation and most custom training loops incorrectly unscale the
    # gradients. Currently, gradients are unscaled once per compute replica,
    # but they should be unscaled once per variable replica. When there is one
    # variable replica for each compute replica, this works fine, but
    # otherwise issues will occur.
    # TODO(reedwm): Support all strategies.
    supported_strategy_types = (
        collective_all_reduce_strategy.CollectiveAllReduceStrategy,
        collective_all_reduce_strategy.CollectiveAllReduceStrategyV1,
        one_device_strategy.OneDeviceStrategy,
        one_device_strategy.OneDeviceStrategyV1,
        mirrored_strategy.MirroredStrategy,
        mirrored_strategy.MirroredStrategyV1,
    )
    strategy = distribution_strategy_context.get_strategy()
    return isinstance(strategy, supported_strategy_types)
def decorated(metric_obj, *args, **kwargs):
    """Call `update_state_fn` and register its op via `add_update()`.

    Args:
      metric_obj: The metric whose state is being updated.
      *args: Positional arguments forwarded to `update_state_fn`.
      **kwargs: Keyword arguments forwarded to `update_state_fn`.

    Returns:
      The update op returned by `update_state_fn` (may be `None`).

    Raises:
      ValueError: Under a TPU strategy in replica context, if any metric
        weight was not created in the strategy's scope.
    """
    strategy = distribution_strategy_context.get_strategy()

    # TODO(b/142574744): Remove this check if a better solution is found for
    # declaring keras Metric outside of TPUStrategy and then updating it per
    # replica.
    def _created_outside_tpu_scope(weight):
        return (
            tpu.is_tpu_strategy(strategy)
            and not strategy.extended.variable_created_in_scope(weight)
            and not distribution_strategy_context.in_cross_replica_context())

    if any(_created_outside_tpu_scope(w) for w in metric_obj.weights):
        raise ValueError(
            'Trying to run metric.update_state in replica context when '
            'the metric was not created in TPUStrategy scope. '
            'Make sure the keras Metric is created in TPUstrategy scope. ')

    with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
        update_op = update_state_fn(*args, **kwargs)
    # update_op will be None in eager execution.
    if update_op is not None:
        metric_obj.add_update(update_op)
    return update_op
def _get_distribution_strategy(model):
    """Return the distribution strategy to use for `model`.

    Args:
      model: Object with a `_distribution_strategy` attribute.

    Returns:
      `model._distribution_strategy` when it is truthy, otherwise the
      strategy reported by `distribution_strategy_context.get_strategy()`.

    Raises:
      ValueError: If the model has no strategy of its own but a strategy
        scope is active at call time.
    """
    if model._distribution_strategy:
        return model._distribution_strategy

    # Use the default strategy if no strategy was present at compile.
    # Validate there is no actual strategy scope active at execution
    # time.
    current_strategy = distribution_strategy_context.get_strategy()
    if distribution_strategy_context.has_strategy():
        raise ValueError(
            'Model was compiled without any active distribution strategy, '
            'but there is an execution-time distribution '
            'strategy scope of (%s). '
            'Try to make sure your code looks similar to the following.\n'
            'with strategy.scope():\n'
            ' model=_create_model()\n'
            ' model.compile(...)\n'
            ' model.fit(...)' % current_strategy)
    return current_strategy
def _skip(self, delta):
    """Apply `_skip_single_var` with `delta` to the generator state.

    Args:
      delta: Amount forwarded to `_skip_single_var`.

    Returns:
      The result of `_skip_single_var` on `self.state`, obtained through
      `strategy.extended.update` when called in a cross-replica context of
      the generator's distribution strategy.
    """

    def update_fn(v):
        return self._skip_single_var(v, delta)

    # TODO(b/170515001): Always call strategy.extended.update after calling
    # it from both replica context and cross-replica context is supported.
    if values_util.is_saving_non_distributed():
        # Assumes replica context with replica_id=0, since we only save the
        # first replica.
        return update_fn(self.state)

    owning_strategy = self._distribution_strategy
    if owning_strategy is not None:
        with ds_context.enter_or_assert_strategy(owning_strategy):
            if ds_context.in_cross_replica_context():
                # Code that operates on all replicas of a variable cannot be
                # saved without retracing.
                values_util.mark_as_unsaveable()
                # In cross-replica context we need to use
                # strategy.extended.update.
                current_strategy = ds_context.get_strategy()
                return current_strategy.extended.update(self.state,
                                                        update_fn)
    return update_fn(self.state)
def remove_temp_dirpath(dirpath, strategy):
    """Remove the temp path after writing is finished.

    Args:
      dirpath: Original dirpath that would be used without distribution.
      strategy: The tf.distribute strategy object currently used. When
        `None`, the strategy is taken from `distribution_strategy_context`.
    """
    if strategy is None:
        # Infer strategy from `distribution_strategy_context` if not given.
        strategy = distribution_strategy_context.get_strategy()
        if strategy is None:
            # If strategy is still not available, this is not in distributed
            # training. Fallback to no-op.
            return

    # TODO(anjalisridhar): Consider removing the check for multi worker mode
    # since it is redundant when used with the should_checkpoint property.
    extended = strategy.extended
    if extended._in_multi_worker_mode() and not extended.should_checkpoint:  # pylint: disable=protected-access
        # If this worker is not chief and hence should not save file, remove
        # the temporary directory.
        file_io.delete_recursively(_get_temp_dir(dirpath, strategy))
def set_last_step_output(self, name, output, reduce_op=None):
    """Record `output` under `name` as an output of the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match
        tensor name.
      output: The tensors that should be outputted with `name`.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context. When present, the
        outputs from all the replicas are reduced using the current
        distribution strategy's `reduce` method, so the type of `output`
        must be what that method supports (e.g. a `PerReplica` value under
        MirroredStrategy). The reduce method is also recorded in
        `_last_step_outputs_reduce_ops` so the outputs can later be
        interpreted as already reduced or not.
    """
    if not distribution_strategy_context.in_cross_replica_context():
        assert reduce_op is not None

        def merge_fn(distribution, value):
            self._last_step_outputs[name] = distribution.reduce(
                reduce_op, value, axis=None)
            # Setting this inside the `merge_fn` because all replicas share
            # the same context object, so it's more robust to set it only
            # once (even if all the replicas are trying to set the same
            # value).
            self._last_step_outputs_reduce_ops[name] = reduce_op

        replica_context = distribution_strategy_context.get_replica_context()
        replica_context.merge_call(merge_fn, args=(output,))
        return

    self._last_step_outputs_reduce_ops[name] = reduce_op
    if reduce_op is None:
        self._last_step_outputs[name] = output
        return
    distribution = distribution_strategy_context.get_strategy()
    self._last_step_outputs[name] = distribution.reduce(
        reduce_op, output, axis=None)
def test_optimizer_with_slot_creation_fn(self, use_tpu):
    """A custom slot creation fn overrides the optimizer's accumulators.

    Args:
      use_tpu: Whether to build under the strategy from
        `self._get_strategy()` instead of the current strategy.
    """

    def slot_creation_fn(table, slot_names, _):
        # Every requested slot is a zero-initialized, non-trainable variable
        # with the table's shape.
        return {
            slot: tf_variables.Variable(
                name='{}_{}'.format(table.name, slot),
                initial_value=functools.partial(
                    init_ops_v2.Zeros(),
                    shape=table.shape,
                    dtype=dtypes.float32),
                trainable=False)
            for slot in slot_names
        }

    optimizer = tpu_embedding_v2_utils.Adagrad(
        learning_rate=0.1,
        slot_variable_creation_fn=slot_creation_fn)

    strategy = (self._get_strategy() if use_tpu
                else distribution_strategy_context.get_strategy())
    with strategy.scope():
        mid_level = tpu_embedding_v2.TPUEmbedding(
            feature_config=self.feature_config,
            optimizer=optimizer)
        # We aren't going to actually run anything, so the batch_size here
        # does not matter.
        mid_level.build(self.batch_size)

    video_accumulator = mid_level._variables['video']['accumulators']
    user_accumulator = mid_level._variables['user']['accumulators']
    if use_tpu:
        # To check the table contents (ensure that it is zero rather than the
        # normal initial accumulator value specified to in the optimizer
        # config), we need to select the underlying table variable on TPU.
        # We only have one shard on Forge.
        video_accumulator = video_accumulator.variables[0]
        user_accumulator = user_accumulator.variables[0]

    expected_video = np.zeros((self.table_video.vocabulary_size,
                               self.table_video.dim))
    self.assertAllClose(video_accumulator.numpy(), expected_video)
    expected_user = np.zeros((self.table_user.vocabulary_size,
                              self.table_user.dim))
    self.assertAllClose(user_accumulator.numpy(), expected_user)
def add_slot(self, var, slot_name, initializer="zeros"):
  """Return the `slot_name` slot variable of `var`, creating it on first use.

  Args:
    var: The variable the slot belongs to.
    slot_name: Name of the slot; appended to `self._slot_names` if new.
    initializer: Either a string / callable resolved through
      `initializers.get` and invoked with `var`'s shape and dtype, or any
      other object, which is passed to the new variable as its initial value
      unchanged.

  Returns:
    The existing or newly created slot variable.

  Raises:
    ValueError: If the current strategy reports that `var` was not created in
      its scope.
  """
  if slot_name not in self._slot_names:
    self._slot_names.append(slot_name)

  slots_for_var = self._slots.setdefault(_var_key(var), {})
  existing = slots_for_var.get(slot_name)
  if existing is not None:
    return existing

  needs_lookup = (
      isinstance(initializer, six.string_types) or callable(initializer))
  if needs_lookup:
    init_fn = initializers.get(initializer)
    initial_value = functools.partial(
        init_fn, shape=var.shape, dtype=var.dtype)
  else:
    initial_value = initializer

  strategy = distribute_ctx.get_strategy()
  if not strategy.extended.variable_created_in_scope(var):
    raise ValueError(
        "Trying to create optimizer slot variable under the scope for "
        "tf.distribute.Strategy ({}), which is different from the scope "
        "used for the original variable ({}). Make sure the slot "
        "variables are created under the same strategy scope. This may "
        "happen if you're restoring from a checkpoint outside the scope"
        .format(strategy, var))

  slot_var_name = "%s/%s" % (var._shared_name, slot_name)  # pylint: disable=protected-access
  with strategy.extended.colocate_vars_with(var):
    slot_var = tf_variables.Variable(
        name=slot_var_name,
        dtype=var.dtype,
        trainable=False,
        initial_value=initial_value)

  # Register the new slot with the backend, this optimizer and any pending
  # checkpoint restore.
  backend.track_variable(slot_var)
  slots_for_var[slot_name] = slot_var
  self._restore_slot_variable(
      slot_name=slot_name, variable=var, slot_variable=slot_var)
  self._weights.append(slot_var)
  return slot_var
def _handle_xshards(self, dataset, steps, local_batch_size, shuffle):
    """Turn a Ray-held data partition into a strategy-distributed dataset.

    The object referenced by `dataset` is fetched with `ray.get`, split into
    data and label, and wrapped in a `tf.data` pipeline that repeats, keeps
    `steps * local_batch_size` elements, optionally shuffles, and batches by
    `local_batch_size`. Auto-sharding is switched off on the pipeline.

    Returns:
        The result of distributing the pipeline through the current strategy.
    """
    import tensorflow as tf

    data, label = ray_partition_get_data_label(
        ray.get(dataset), allow_tuple=True, allow_list=False)

    def dataset_fn(input_context):
        no_shard_options = tf.data.Options()
        no_shard_options.experimental_distribute.auto_shard_policy = \
            tf.data.experimental.AutoShardPolicy.OFF

        pipeline = (tf.data.Dataset.from_tensor_slices((data, label))
                    .with_options(no_shard_options)
                    .repeat()
                    .take(steps * local_batch_size))
        if shuffle:
            pipeline = pipeline.shuffle(local_batch_size * min(steps, 10))
        return pipeline.batch(local_batch_size)

    from tensorflow.python.distribute import distribution_strategy_context as ds_context
    current_strategy = ds_context.get_strategy()
    return current_strategy.experimental_distribute_datasets_from_function(dataset_fn)
def _infer_steps(self, steps, dataset):
  """Infers steps_per_epoch needed to loop through a dataset.

  Resolution order: an explicit `steps`, then the adapter's size, then the
  dataset's cardinality. Returns `None` when the count cannot be inferred.

  Raises:
    ValueError: If the dataset repeats infinitely and no step count is known.
  """
  if steps is not None:
    return steps

  size_from_adapter = self._adapter.get_size()
  if size_from_adapter is not None:
    return size_from_adapter

  multi_worker = ds_context.get_strategy().extended._in_multi_worker_mode()  # pylint: disable=protected-access
  if multi_worker and (
      dataset.options().experimental_distribute.auto_shard_policy !=
      distribute_options.AutoShardPolicy.OFF):
    # The dataset would be auto-sharded; a local steps_per_epoch must not be
    # inferred because sharding may be unbalanced across workers.
    return None

  dataset_size = cardinality.cardinality(dataset)
  if dataset_size == cardinality.INFINITE and steps is None:
    raise ValueError("When passing an infinitely repeating dataset, you "
                     "must specify how many steps to draw.")
  return dataset_size if dataset_size >= 0 else None
def skip(self, delta):
  """Advance the counter of a counter-based RNG.

  Args:
    delta: the amount of advancement. The state of the RNG after
      `skip(n)` will be the same as that after `normal([n])`
      (or any other distribution). The actual increment added to the
      counter is an unspecified implementation detail.

  Returns:
    A `Tensor` of type `int64`.
  """

  def _advance(state_var):
    return self._skip_single_var(state_var, delta)

  # TODO(b/170515001): Always call strategy.extended.update after calling it
  #   from both replica context and cross-replica context is supported.
  if values_util.is_saving_non_distributed():
    # Assumes replica context with replica_id=0, since only the first replica
    # is saved.
    return _advance(self.state)

  strategy = self._distribution_strategy
  if strategy is not None:
    with ds_context.enter_or_assert_strategy(strategy):
      if ds_context.in_cross_replica_context():
        # Code that operates on all replicas of a variable cannot be saved
        # without retracing.
        values_util.mark_as_unsaveable()

      # Cross-replica context needs strategy.extended.update. So does
      # CentralStorageStrategy, even in replica context, because variable
      # updates there must happen within merge_call.
      needs_extended_update = (
          ds_context.in_cross_replica_context() or
          "CentralStorage" in type(strategy).__name__)
      if needs_extended_update:
        return ds_context.get_strategy().extended.update(
            self.state, _advance)

  return _advance(self.state)
def testSameScopeNesting(self):
  # Re-entering scopes of the SAME strategy may be nested freely; entering a
  # scope of a DIFFERENT strategy while one is active must fail.
  _assert_in_default_state(self)
  strategy = _TestStrategy()
  outer_scope = strategy.scope()
  with outer_scope:
    self.assertIs(strategy, ds_context.get_strategy())
    inner_scope = strategy.scope()
    with inner_scope:
      self.assertIs(strategy, ds_context.get_strategy())
      # The outer scope object can be re-entered while still active.
      with outer_scope:
        self.assertIs(strategy, ds_context.get_strategy())
      self.assertIs(strategy, ds_context.get_strategy())
    self.assertIs(strategy, ds_context.get_strategy())

    other_strategy = _TestStrategy()
    other_scope = other_strategy.scope()
    with self.assertRaisesRegex(
        RuntimeError, "Mixing different tf.distribute.Strategy objects"):
      with other_scope:
        pass

  # Leaving the outermost scope restores the default state, and a scope
  # object created earlier remains usable afterwards.
  _assert_in_default_state(self)
  with inner_scope:
    self.assertIs(strategy, ds_context.get_strategy())
  _assert_in_default_state(self)
def testSameScopeNesting(self):
  # Checks that scopes of one strategy can be nested and re-entered, that
  # mixing in a scope of another strategy raises, and that the default state
  # is restored once all scopes are exited.
  _assert_in_default_state(self)
  dist = _TestStrategy()
  scope_a = dist.scope()
  with scope_a:
    self.assertIs(dist, ds_context.get_strategy())
    scope_b = dist.scope()
    with scope_b:
      self.assertIs(dist, ds_context.get_strategy())
      # Re-enter the outer scope object while it is already active.
      with scope_a:
        self.assertIs(dist, ds_context.get_strategy())
      self.assertIs(dist, ds_context.get_strategy())
    self.assertIs(dist, ds_context.get_strategy())

    dist2 = _TestStrategy()
    scope2 = dist2.scope()
    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
    # which no longer exists in Python 3.12+.
    with self.assertRaisesRegex(
        RuntimeError, "Mixing different tf.distribute.Strategy objects"):
      with scope2:
        pass

  _assert_in_default_state(self)
  # A scope object created inside the outer scope can still be entered later.
  with scope_b:
    self.assertIs(dist, ds_context.get_strategy())
  _assert_in_default_state(self)
def _scale_loss(loss_value):
  """Scale `loss_value` by 1/num_replicas when the loss reduction is MEAN.

  The value is returned unchanged when the reduction is not
  `ReduceOp.MEAN` or when the current strategy has a single replica in sync.
  """
  if distribute_lib.get_loss_reduction() != ds_reduce_util.ReduceOp.MEAN:
    return loss_value

  replica_count = distribute_ctx.get_strategy().num_replicas_in_sync
  if replica_count > 1:
    loss_value *= (1. / replica_count)
  return loss_value
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Apply gradients to variables.

  This contains most of the synchronization implementation and also wraps the
  apply_gradients() from the real optimizer.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      compute_gradients().
    global_step: Optional Variable to increment by one after the
      variables have been updated.
    name: Optional name for the returned operation.  Default to the
      name passed to the Optimizer constructor. (Not referenced in this
      body; ops are created under `self._name`.)

  Returns:
    train_op: The op to dequeue a token so the replicas can exit this batch
    and start the next one. This is executed by each replica.

  Raises:
    ValueError: If the grads_and_vars is empty.
    ValueError: If global step is not provided, the staleness cannot be
      checked.
    ValueError: If a gradient is neither `None`, an `ops.Tensor` nor an
      `ops.IndexedSlices`.
  """
  if not grads_and_vars:
    raise ValueError("Must supply at least one variable")

  if global_step is None:
    raise ValueError("Global step is required to check staleness")

  self._global_step = global_step
  # Per-variable ops that push this replica's gradient into its accumulator.
  train_ops = []
  # Aggregated gradients taken back out of the accumulators (None passes
  # through), kept index-aligned with `var_list`.
  aggregated_grad = []
  var_list = []

  # local_anchor op will be placed on this worker task by default.
  local_anchor = control_flow_ops.no_op()
  # Colocating local_step variable prevents it being placed on the PS.
  distribution_strategy = distribution_strategy_context.get_strategy()
  with distribution_strategy.extended.colocate_vars_with(local_anchor):
    self._local_step = variable_scope.variable(
        initial_value=0,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        dtype=global_step.dtype.base_dtype,
        name="sync_rep_local_step")

  # Init op that copies the current global step into the local step.
  self.local_step_init_op = state_ops.assign(self._local_step, global_step)
  chief_init_ops = [self.local_step_init_op]
  self.ready_for_local_init_op = variables.report_uninitialized_variables(
      variables.global_variables())

  with ops.name_scope(None, self._name):
    # One accumulator per variable, placed on that variable's device.
    for grad, var in grads_and_vars:
      var_list.append(var)
      with ops.device(var.device):
        # Dense gradients.
        if grad is None:
          aggregated_grad.append(None)  # pass-through.
          continue
        elif isinstance(grad, ops.Tensor):
          grad_accum = data_flow_ops.ConditionalAccumulator(
              grad.dtype,
              shape=var.get_shape(),
              shared_name=var.name + "/grad_accum")
          train_ops.append(
              grad_accum.apply_grad(grad, local_step=self._local_step))
          aggregated_grad.append(
              grad_accum.take_grad(self._replicas_to_aggregate))
        else:
          # Anything else must be a sparse (IndexedSlices) gradient.
          if not isinstance(grad, ops.IndexedSlices):
            raise ValueError("Unknown grad type!")
          grad_accum = data_flow_ops.SparseConditionalAccumulator(
              grad.dtype, shape=(), shared_name=var.name + "/grad_accum")
          train_ops.append(
              grad_accum.apply_indexed_slices_grad(
                  grad, local_step=self._local_step))
          aggregated_grad.append(
              grad_accum.take_indexed_slices_grad(
                  self._replicas_to_aggregate))

        # Remembered so the chief can set each accumulator's global step
        # below.
        self._accumulator_list.append((grad_accum, var.device))

    aggregated_grads_and_vars = zip(aggregated_grad, var_list)

    # sync_op will be assigned to the same device as the global step.
    with ops.device(global_step.device), ops.name_scope(""):
      update_op = self._opt.apply_gradients(
          aggregated_grads_and_vars, global_step)

    # Create token queue.
    with ops.device(global_step.device), ops.name_scope(""):
      sync_token_queue = (data_flow_ops.FIFOQueue(
          -1,
          global_step.dtype.base_dtype,
          shapes=(),
          name="sync_token_q",
          shared_name="sync_token_q"))
      self._sync_token_queue = sync_token_queue

      # dummy_queue is passed to the queue runner. Don't use the real queues
      # because the queue runner doesn't automatically reopen it once it
      # closed queues in PS devices.
      dummy_queue = (data_flow_ops.FIFOQueue(
          1,
          types_pb2.DT_INT32,
          shapes=(),
          name="dummy_queue",
          shared_name="dummy_queue"))

    with ops.device(global_step.device), ops.name_scope(""):
      # Replicas have to wait until they can get a token from the token queue.
      with ops.control_dependencies(train_ops):
        token = sync_token_queue.dequeue()
      # The dequeued token becomes this replica's new local step.
      train_op = state_ops.assign(self._local_step, token)

      with ops.control_dependencies([update_op]):
        # Sync_op needs to insert tokens to the token queue at the end of the
        # step so the replicas can fetch them to start the next step.
        tokens = array_ops.fill([self._tokens_per_step], global_step)
        sync_op = sync_token_queue.enqueue_many((tokens, ))

      if self._variable_averages is not None:
        # Chain the moving-average update after the token enqueue.
        with ops.control_dependencies([sync_op
                                      ]), ops.name_scope(""):
          sync_op = self._variable_averages.apply(
              self._variables_to_average)

      self._chief_queue_runner = queue_runner.QueueRunner(
          dummy_queue, [sync_op])
    # Chief init also sets every accumulator's global step, on the
    # accumulator's own device.
    for accum, dev in self._accumulator_list:
      with ops.device(dev):
        chief_init_ops.append(
            accum.set_global_step(global_step, name="SetGlobalStep"))
    self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
    self._gradients_applied = True
    return train_op
def _restore_checkpoint(self,
                        master,
                        saver=None,
                        checkpoint_dir=None,
                        checkpoint_filename_with_path=None,
                        wait_for_checkpoint=False,
                        max_wait_secs=7200,
                        config=None):
  """Creates a `Session`, and tries to restore a checkpoint.

  Args:
    master: `String` representation of the TensorFlow master to use.
    saver: A `Saver` object used to restore a model.
    checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
      dir will be used to restore.
    checkpoint_filename_with_path: Full file name path to the checkpoint file.
    wait_for_checkpoint: Whether to wait for checkpoint to become available.
    max_wait_secs: Maximum time to wait for checkpoints to become available.
    config: Optional `ConfigProto` proto used to configure the session.

  Returns:
    A pair (sess, is_restored) where 'is_restored' is `True` if
    the session could be restored, `False` otherwise.

  Raises:
    ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
      set.
  """
  self._target = master

  # The TPU device must be initialized before restoring from a checkpoint:
  # variables are placed on the device, and TPUInitialize wipes out the
  # device memory.
  strategy = distribution_strategy_context.get_strategy()
  can_init_system = strategy and hasattr(
      strategy.extended, "_experimental_initialize_system")
  if can_init_system:
    strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access

  sess = session.Session(self._target, graph=self._graph, config=config)

  if checkpoint_dir and checkpoint_filename_with_path:
    raise ValueError("Can not provide both checkpoint_dir and "
                     "checkpoint_filename_with_path.")

  # Without a saver, or without any checkpoint location, nothing can be
  # restored; hand back the fresh session.
  checkpoint_location = checkpoint_dir or checkpoint_filename_with_path
  if not saver or not checkpoint_location:
    return sess, False

  if checkpoint_filename_with_path:
    saver.restore(sess, checkpoint_filename_with_path)
    return sess, True

  # Poll the directory, for up to max_wait_secs if waiting was requested,
  # until a checkpoint shows up.
  waited_secs = 0
  ckpt_state = checkpoint_management.get_checkpoint_state(checkpoint_dir)
  while not ckpt_state or not ckpt_state.model_checkpoint_path:
    if not (wait_for_checkpoint and waited_secs < max_wait_secs):
      return sess, False
    logging.info("Waiting for checkpoint to be available.")
    time.sleep(self._recovery_wait_secs)
    waited_secs += self._recovery_wait_secs
    ckpt_state = checkpoint_management.get_checkpoint_state(checkpoint_dir)

  # Loads the checkpoint.
  saver.restore(sess, ckpt_state.model_checkpoint_path)
  saver.recover_last_checkpoints(ckpt_state.all_model_checkpoint_paths)
  return sess, True
def decorated(metric_obj, *args):
  """Decorated function with merge_call.

  Calls the enclosing scope's `result_fn` with `*args`, either directly under
  `variable_sync_on_read_context` or through `merge_call`, wraps the result in
  identity ops, stores it on `metric_obj._call_result` and returns it.

  Args:
    metric_obj: The metric whose result is being computed; used for the error
      message and to cache the result op.
    *args: Positional arguments forwarded to `result_fn`.

  Returns:
    The identity-wrapped result (a single value, or a dict of values when
    `result_fn` returns a dict).

  Raises:
    RuntimeError: If the raw result is of an unrecognized type and cannot be
      passed through `array_ops.identity`.
  """
  has_strategy = distribution_strategy_context.has_strategy()
  replica_context = distribution_strategy_context.get_replica_context()

  # The purpose of using `merge_call` to call `result()` is to trigger cross
  # replica aggregation of metric state variables (SyncOnReadVariable). After
  # we introduced `variable_sync_on_read_context`, in principle there is no
  # need to use `merge_call` here. However the branch still exists because:
  #
  # 1. Keras V1 training code sometimes assumes `result_t` is the same tensor
  #    across replicas (achieved by `merge_call`). With
  #    `variable_sync_on_read_context` each replica gets their own tensors
  #    residing on replica's device, thus breaking the assumption.
  # 2. Keras c/fit creates a tf.function (a.k.a, train_function) that returns
  #    the metric values of the first replica. With
  #    `variable_sync_on_read_context` since each replica gets their own
  #    tensors, the metric result tensors on the non-first replicas are not in
  #    the return value of train_function, making TF graph optimizer prune the
  #    branch that computes and aggregates those metric results. As a result,
  #    if NCCL is used to do the aggregation, the program will hang because
  #    NCCL ops are only launched on the non-pruned first replica.
  #
  # We condition on strategy.extended._use_merge_call() since we know if it is
  # false, the program uses `jit_compile` to compile replica fn, meaning it is
  # not V1 training (hence #1 is okay), and no pruning will happen as
  # compiled functions are not inlined (hence #2 is okay).
  if (not has_strategy or replica_context is None or
      not distribution_strategy_context.get_strategy(
      ).extended._use_merge_call()):
    with distribution_strategy_context.variable_sync_on_read_context():
      raw_result = result_fn(*args)
      # Results need to be wrapped in a `tf.identity` op to ensure
      # correct execution order.
      if isinstance(raw_result,
                    (ops.Tensor, variables_module.Variable, float, int)):
        result_t = array_ops.identity(raw_result)
      elif isinstance(raw_result, dict):
        # Wrap every dict value individually, keeping the keys.
        result_t = {
            key: array_ops.identity(value)
            for key, value in raw_result.items()
        }
      else:
        # Unknown type: try the conversion anyway and report a clearer error
        # if it is rejected.
        try:
          result_t = array_ops.identity(raw_result)
        except (ValueError, TypeError):
          raise RuntimeError(
              'The output of `metric.result()` can only be a single '
              'Tensor/Variable, or a dict of Tensors/Variables. '
              'For metric %s, got result %s.' % (metric_obj.name, raw_result))
  else:
    # TODO(psv): Test distribution of metrics using different distribution
    # strategies.

    # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
    # with distribution object as the first parameter. We create a wrapper
    # here so that the result function need not have that parameter.
    def merge_fn_wrapper(distribution, merge_fn, *args):
      # We will get `PerReplica` merge function. Taking the first one as all
      # are identical copies of the function that we had passed below.
      result = distribution.experimental_local_results(merge_fn)[0](*args)

      # Wrapping result in identity so that control dependency between
      # update_op from `update_state` and result works in case result returns
      # a tensor.
      return array_ops.identity(result)

    # Wrapping result in merge_call. merge_call is used when we want to leave
    # replica mode and compute a value in cross replica mode.
    result_t = replica_context.merge_call(
        merge_fn_wrapper, args=(result_fn,) + args)

  # We are saving the result op here to be used in train/test execution
  # functions. This basically gives the result op that was generated with a
  # control dep to the updates for these workflows.
  metric_obj._call_result = result_t
  return result_t
def call(self, inputs, training=None):
  """Apply batch normalization to `inputs`.

  Args:
    inputs: Input tensor.
    training: Training-mode indicator. When `None`, `K.learning_phase()` is
      used. It is passed to `tf_utils.smart_cond`, so it may be a Python
      value or a tensor.

  Returns:
    The normalized tensor. When the fused path is active this is the result
    of `self._fused_batch_norm`; otherwise it is the result of
    `nn.batch_normalization`, with the virtual-batch reshape undone if
    `virtual_batch_size` is set.
  """
  if training is None:
    training = K.learning_phase()

  if self.virtual_batch_size is not None:
    # Virtual batches (aka ghost batches) can be simulated by reshaping the
    # Tensor and reusing the existing batch norm implementation
    original_shape = [-1] + inputs.shape.as_list()[1:]
    expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

    # Will cause errors if virtual_batch_size does not divide the batch size
    inputs = array_ops.reshape(inputs, expanded_shape)

    def undo_virtual_batching(outputs):
      outputs = array_ops.reshape(outputs, original_shape)
      return outputs

  if self.fused:
    outputs = self._fused_batch_norm(inputs, training=training)
    if self.virtual_batch_size is not None:
      # Currently never reaches here since fused_batch_norm does not support
      # virtual batching
      outputs = undo_virtual_batching(outputs)
    return outputs

  # Compute the axes along which to reduce the mean / variance
  input_shape = inputs.shape
  ndims = len(input_shape)
  reduction_axes = [i for i in range(ndims) if i not in self.axis]
  if self.virtual_batch_size is not None:
    del reduction_axes[1]  # Do not reduce along virtual batch dim

  # Broadcasting only necessary for single-axis batch norm where the axis is
  # not the last dimension
  broadcast_shape = [1] * ndims
  broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

  def _broadcast(v):
    # Reshape `v` to `broadcast_shape` unless it is None, already has `ndims`
    # dimensions, or the reduction covers every axis except the last.
    if (v is not None and len(v.shape) != ndims and
        reduction_axes != list(range(ndims - 1))):
      return array_ops.reshape(v, broadcast_shape)
    return v

  scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

  def _compose_transforms(scale, offset, then_scale, then_offset):
    # Compose x*scale + offset followed by (.)*then_scale + then_offset into
    # a single (scale, offset) pair; a None `then_*` leaves that part as is.
    if then_scale is not None:
      scale *= then_scale
      offset *= then_scale
    if then_offset is not None:
      offset += then_offset
    return (scale, offset)

  # Determine a boolean value for `training`: could be True, False, or None.
  training_value = tf_utils.constant_value(training)
  if training_value is not False:
    if self.adjustment:
      adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
      # Adjust only during training.
      adj_scale = tf_utils.smart_cond(training, lambda: adj_scale,
                                      lambda: array_ops.ones_like(adj_scale))
      adj_bias = tf_utils.smart_cond(training, lambda: adj_bias,
                                     lambda: array_ops.zeros_like(adj_bias))
      scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

    # Some of the computations here are not necessary when training==False
    # but not a constant. However, this makes the code simpler.
    keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
    mean, variance = self._moments(
        math_ops.cast(inputs, self._param_dtype),
        reduction_axes,
        keep_dims=keep_dims)

    moving_mean = self.moving_mean
    moving_variance = self.moving_variance

    # Batch statistics when training, moving statistics otherwise.
    mean = tf_utils.smart_cond(training, lambda: mean, lambda: moving_mean)
    variance = tf_utils.smart_cond(training, lambda: variance,
                                   lambda: moving_variance)

    if self.virtual_batch_size is not None:
      # This isn't strictly correct since in ghost batch norm, you are
      # supposed to sequentially update the moving_mean and moving_variance
      # with each sub-batch. However, since the moving statistics are only
      # used during evaluation, it is more efficient to just update in one
      # step and should not make a significant difference in the result.
      new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
      new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
    else:
      new_mean, new_variance = mean, variance

    if self.renorm:
      r, d, new_mean, new_variance = self._renorm_correction_and_moments(
          new_mean, new_variance, training)
      # When training, the normalized values (say, x) will be transformed as
      # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
      # = x * (r * gamma) + (d * gamma + beta) with renorm.
      r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
      d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
      scale, offset = _compose_transforms(r, d, scale, offset)

    # The moving-average updates are built differently depending on whether
    # this runs in a cross-replica context.
    if distribution_strategy_context.in_cross_replica_context():
      strategy = distribution_strategy_context.get_strategy()

      def _do_update(var, value):
        """Compute the updates for mean and variance."""
        return strategy.extended.update(
            var,
            self._assign_moving_average, (value, self.momentum),
            group=False)

      # We need to unwrap the moving_mean or moving_variance in the case of
      # training being false to match the output of true_fn and false_fn
      # in the smart cond.
      def mean_update():
        true_branch = lambda: _do_update(self.moving_mean, new_mean)
        false_branch = lambda: strategy.unwrap(self.moving_mean)
        return tf_utils.smart_cond(training, true_branch, false_branch)

      def variance_update():
        return tf_utils.smart_cond(
            training, lambda: _do_update(self.moving_variance, new_variance),
            lambda: strategy.unwrap(self.moving_variance))
    else:

      def _do_update(var, value):
        """Compute the updates for mean and variance."""
        return self._assign_moving_average(var, value, self.momentum)

      def mean_update():
        true_branch = lambda: _do_update(self.moving_mean, new_mean)
        false_branch = lambda: self.moving_mean
        return tf_utils.smart_cond(training, true_branch, false_branch)

      def variance_update():
        true_branch = lambda: _do_update(self.moving_variance, new_variance)
        false_branch = lambda: self.moving_variance
        return tf_utils.smart_cond(training, true_branch, false_branch)

    # The updates are registered as callables rather than evaluated here.
    self.add_update(mean_update, inputs=True)
    self.add_update(variance_update, inputs=True)

  else:
    # `training` is statically False: use the moving statistics directly.
    mean, variance = self.moving_mean, self.moving_variance

  # Cast statistics and affine parameters to the input dtype.
  mean = math_ops.cast(mean, inputs.dtype)
  variance = math_ops.cast(variance, inputs.dtype)
  if offset is not None:
    offset = math_ops.cast(offset, inputs.dtype)
  if scale is not None:
    scale = math_ops.cast(scale, inputs.dtype)
  # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
  # math in float16 hurts validation accuracy of popular models like resnet.
  outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                   _broadcast(variance), offset, scale,
                                   self.epsilon)
  # If some components of the shape got lost due to adjustments, fix that.
  outputs.set_shape(input_shape)

  if self.virtual_batch_size is not None:
    outputs = undo_virtual_batching(outputs)

  return outputs
def build(self, input_shape):
  """Validate the configuration against `input_shape` and create weights.

  Normalizes `self.axis` to a list of non-negative indices, decides whether
  the fused implementation is used, and creates `gamma`, `beta`,
  `moving_mean`, `moving_variance` and, when `self.renorm` is set, the renorm
  variables.

  Args:
    input_shape: Shape of the input; converted with
      `tensor_shape.TensorShape`.

  Raises:
    ValueError: If the rank is undefined, an axis is out of range or
      duplicated, `virtual_batch_size` is invalid or combined with axis 0 or
      an adjustment, `fused=True` is requested under V2 behavior for a
      non-4D input, the fused path is active with an axis other than `[1]`
      or `[3]`, or a normalized dimension is undefined.
  """
  input_shape = tensor_shape.TensorShape(input_shape)
  if not input_shape.ndims:
    raise ValueError('Input has undefined rank:', input_shape)
  ndims = len(input_shape)

  # Convert axis to list and resolve negatives
  if isinstance(self.axis, int):
    self.axis = [self.axis]

  for idx, x in enumerate(self.axis):
    if x < 0:
      self.axis[idx] = ndims + x

  # Validate axes
  for x in self.axis:
    if x < 0 or x >= ndims:
      raise ValueError('Invalid axis: %d' % x)
  if len(self.axis) != len(set(self.axis)):
    raise ValueError('Duplicate axis: %s' % self.axis)

  if self.virtual_batch_size is not None:
    if self.virtual_batch_size <= 0:
      raise ValueError('virtual_batch_size must be a positive integer that '
                       'divides the true batch size of the input Tensor')
    # If using virtual batches, the first dimension must be the batch
    # dimension and cannot be the batch norm axis
    if 0 in self.axis:
      raise ValueError('When using virtual_batch_size, the batch dimension '
                       'must be 0 and thus axis cannot include 0')
    if self.adjustment is not None:
      raise ValueError('When using virtual_batch_size, adjustment cannot '
                       'be specified')

  if self.fused in (None, True):
    # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
    # output back to its original shape accordingly.
    if self._USE_V2_BEHAVIOR:
      if self.fused is None:
        self.fused = (ndims == 4)
      elif self.fused and ndims != 4:
        raise ValueError('Batch normalization layers with fused=True only '
                         'support 4D input tensors.')
    else:
      assert self.fused is not None
      self.fused = (ndims == 4 and self._fused_can_be_used())
    # TODO(chrisying): fused batch norm is currently not supported for
    # multi-axis batch norm and by extension virtual batches. In some cases,
    # it might be possible to use fused batch norm but would require reshaping
    # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
    # particularly tricky. A compromise might be to just support the most
    # common use case (turning 5D w/ virtual batch to NCHW)

  if self.fused:
    if self.axis == [1]:
      self._data_format = 'NCHW'
    elif self.axis == [3]:
      self._data_format = 'NHWC'
    else:
      raise ValueError('Unsupported axis, fused batch norm only supports '
                       'axis == [1] or axis == [3]')

  # Map each normalized axis to its static size; every one must be known.
  axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
  for x in axis_to_dim:
    if axis_to_dim[x] is None:
      raise ValueError('Input has undefined `axis` dimension. Input shape: ',
                       input_shape)
  self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

  if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
    # Single axis batch norm (most common/default use-case)
    param_shape = (list(axis_to_dim.values())[0],)
  else:
    # Parameter shape is the original shape but with 1 in all non-axis dims
    param_shape = [axis_to_dim[i] if i in axis_to_dim
                   else 1 for i in range(ndims)]
    if self.virtual_batch_size is not None:
      # When using virtual batches, add an extra dim at index 1
      param_shape.insert(1, 1)
      for idx, x in enumerate(self.axis):
        self.axis[idx] = x + 1  # Account for added dimension

  if self.scale:
    self.gamma = self.add_weight(
        name='gamma',
        shape=param_shape,
        dtype=self._param_dtype,
        initializer=self.gamma_initializer,
        regularizer=self.gamma_regularizer,
        constraint=self.gamma_constraint,
        trainable=True,
        experimental_autocast=False)
  else:
    self.gamma = None
    # With no learned scale, the fused path gets a constant 1.0 instead.
    if self.fused:
      self._gamma_const = K.constant(
          1.0, dtype=self._param_dtype, shape=param_shape)

  if self.center:
    self.beta = self.add_weight(
        name='beta',
        shape=param_shape,
        dtype=self._param_dtype,
        initializer=self.beta_initializer,
        regularizer=self.beta_regularizer,
        constraint=self.beta_constraint,
        trainable=True,
        experimental_autocast=False)
  else:
    self.beta = None
    # With no learned offset, the fused path gets a constant 0.0 instead.
    if self.fused:
      self._beta_const = K.constant(
          0.0, dtype=self._param_dtype, shape=param_shape)

  try:
    # Disable variable partitioning when creating the moving mean and variance
    if hasattr(self, '_scope') and self._scope:
      partitioner = self._scope.partitioner
      self._scope.set_partitioner(None)
    else:
      partitioner = None
    self.moving_mean = self.add_weight(
        name='moving_mean',
        shape=param_shape,
        dtype=self._param_dtype,
        initializer=self.moving_mean_initializer,
        synchronization=tf_variables.VariableSynchronization.ON_READ,
        trainable=False,
        aggregation=tf_variables.VariableAggregation.MEAN,
        experimental_autocast=False)

    self.moving_variance = self.add_weight(
        name='moving_variance',
        shape=param_shape,
        dtype=self._param_dtype,
        initializer=self.moving_variance_initializer,
        synchronization=tf_variables.VariableSynchronization.ON_READ,
        trainable=False,
        aggregation=tf_variables.VariableAggregation.MEAN,
        experimental_autocast=False)

    if self.renorm:
      # Create variables to maintain the moving mean and standard deviation.
      # These are used in training and thus are different from the moving
      # averages above. The renorm variables are colocated with moving_mean
      # and moving_variance.
      # NOTE: below, the outer `with device` block causes the current device
      # stack to be cleared. The nested ones use a `lambda` to set the desired
      # device and ignore any devices that may be set by the custom getter.
      def _renorm_variable(name, shape):
        """Create a renorm variable."""
        var = self.add_weight(
            name=name,
            shape=shape,
            dtype=self._param_dtype,
            initializer=init_ops.zeros_initializer(),
            synchronization=tf_variables.VariableSynchronization.ON_READ,
            trainable=False,
            aggregation=tf_variables.VariableAggregation.MEAN,
            experimental_autocast=False)
        return var

      with distribution_strategy_context.get_strategy(
      ).extended.colocate_vars_with(self.moving_mean):
        self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
        self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
      # We initialize renorm_stddev to 0, and maintain the (0-initialized)
      # renorm_stddev_weight. This allows us to (1) mix the average
      # stddev with the minibatch stddev early in training, and (2) compute
      # the unbiased average stddev by dividing renorm_stddev by the weight.
      with distribution_strategy_context.get_strategy(
      ).extended.colocate_vars_with(self.moving_variance):
        self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
        self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
                                                     ())
  finally:
    # Restore the partitioner that was cleared above, if there was one.
    if partitioner:
      self._scope.set_partitioner(partitioner)
  self.built = True
def __init__(self,
             input_shape=None,
             batch_size=None,
             dtype=None,
             input_tensor=None,
             sparse=False,
             name=None,
             **kwargs):
  """Create the input layer and its placeholder or wrapped tensor.

  Args:
    input_shape: Shape without the batch dimension; a `TensorShape`, an int
      or a sequence.
    batch_size: Optional batch size. When the current strategy supports a
      global batch size, it is divided by the strategy's
      `num_replicas_in_sync`.
    dtype: Data type. When falsy it is taken from `input_tensor`, or from
      `backend.floatx()` if no tensor is given.
    input_tensor: Optional existing tensor to wrap; when `None` a placeholder
      is created.
    sparse: Whether the created placeholder is sparse.
    name: Layer name; when falsy, an 'input_<uid>' name is generated.
    **kwargs: Only `batch_input_shape` is accepted; it supplies both
      `batch_size` (first entry) and `input_shape` (remaining entries).

  Raises:
    ValueError: If `batch_size` is not a multiple of the replica count, both
      `input_shape` and `batch_input_shape` are given, unknown keyword
      arguments are passed, `input_tensor.dtype` differs from `dtype`, or
      `input_tensor` is not a symbolic tensor.
  """
  strategy = distribution_strategy_context.get_strategy()
  if strategy and batch_size is not None and \
      distributed_training_utils.global_batch_size_supported(strategy):
    if batch_size % strategy.num_replicas_in_sync != 0:
      raise ValueError('The `batch_size` argument value {} cannot be '
                       'divisible by number of replicas {}'.format(
                           batch_size, strategy.num_replicas_in_sync))
    # Convert the global batch size into a per-replica batch size.
    batch_size = batch_size // strategy.num_replicas_in_sync

  if 'batch_input_shape' in kwargs:
    batch_input_shape = kwargs.pop('batch_input_shape')
    if input_shape and batch_input_shape:
      raise ValueError('Only provide the input_shape OR '
                       'batch_input_shape argument to '
                       'InputLayer, not both at the same time.')
    # NOTE: this overrides any batch_size computed above.
    batch_size = batch_input_shape[0]
    input_shape = batch_input_shape[1:]
  if kwargs:
    raise ValueError('Unrecognized keyword arguments:', kwargs.keys())

  if not name:
    prefix = 'input'
    name = prefix + '_' + str(backend.get_uid(prefix))

  if not dtype:
    if input_tensor is None:
      dtype = backend.floatx()
    else:
      dtype = backend.dtype(input_tensor)
  elif input_tensor is not None and input_tensor.dtype != dtype:
    raise ValueError('`input_tensor.dtype` differs from `dtype`: %s vs. %s' %
                     (input_tensor.dtype, dtype))
  super(InputLayer, self).__init__(dtype=dtype, name=name)
  self.built = True
  self.sparse = sparse
  self.batch_size = batch_size
  self.supports_masking = True

  # Normalize input_shape to a tuple where possible.
  if isinstance(input_shape, tensor_shape.TensorShape):
    input_shape = tuple(input_shape.as_list())
  elif isinstance(input_shape, int):
    input_shape = (input_shape,)

  if input_tensor is None:
    if input_shape is not None:
      batch_input_shape = (batch_size,) + tuple(input_shape)
    else:
      batch_input_shape = None
    graph = backend.get_graph()
    with graph.as_default():
      # In graph mode, create a graph placeholder to call the layer on.
      if sparse:
        input_tensor = backend.placeholder(
            shape=batch_input_shape,
            dtype=dtype,
            name=self.name,
            sparse=True)
      else:
        input_tensor = backend.placeholder(
            shape=batch_input_shape,
            dtype=dtype,
            name=self.name)

    self.is_placeholder = True
    self._batch_input_shape = batch_input_shape
  else:
    if not tf_utils.is_symbolic_tensor(input_tensor):
      raise ValueError('You should not pass an EagerTensor to `Input`. '
                       'For example, instead of creating an '
                       'InputLayer, you should instantiate your model and '
                       'directly call it on your input.')
    self.is_placeholder = False
    # The batch input shape is taken from the supplied tensor.
    self._batch_input_shape = tuple(input_tensor.shape.as_list())

  # Create an input node to add to self.outbound_node
  # and set output_tensors' _keras_history.
  input_tensor._keras_history = (self, 0, 0)  # pylint: disable=protected-access
  input_tensor._keras_mask = None
  base_layer.Node(
      self,
      inbound_layers=[],
      node_indices=[],
      tensor_indices=[],
      input_tensors=[input_tensor],
      output_tensors=[input_tensor])
def _restore_checkpoint(self,
                        master,
                        saver=None,
                        checkpoint_dir=None,
                        checkpoint_filename_with_path=None,
                        wait_for_checkpoint=False,
                        max_wait_secs=7200,
                        config=None):
  """Creates a `Session`, and tries to restore a checkpoint.

  Args:
    master: `String` representation of the TensorFlow master to use.
    saver: A `Saver` object used to restore a model.
    checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
      dir will be used to restore.
    checkpoint_filename_with_path: Full file name path to the checkpoint file.
    wait_for_checkpoint: Whether to wait for checkpoint to become available.
    max_wait_secs: Maximum time to wait for checkpoints to become available.
    config: Optional `ConfigProto` proto used to configure the session.

  Returns:
    A pair (sess, is_restored) where 'is_restored' is `True` if
    the session could be restored, `False` otherwise.

  Raises:
    ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
      set. This is checked before the device is initialized and before the
      session is created, so no session is left open on invalid arguments.
  """
  self._target = master

  # Reject contradictory arguments up front: failing after the session has
  # been created would leak it, since the caller never receives a handle.
  if checkpoint_dir and checkpoint_filename_with_path:
    raise ValueError("Can not provide both checkpoint_dir and "
                     "checkpoint_filename_with_path.")

  # This is required so that we initialize the TPU device before
  # restoring from checkpoint since we'll be placing variables on the device
  # and TPUInitialize wipes out the memory of the device.
  strategy = distribution_strategy_context.get_strategy()
  if strategy and hasattr(strategy.extended,
                          "_experimental_initialize_system"):
    strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access

  sess = session.Session(self._target, graph=self._graph, config=config)

  # If either saver or checkpoint_* is not specified, cannot restore. Just
  # return.
  if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
    return sess, False

  if checkpoint_filename_with_path:
    saver.restore(sess, checkpoint_filename_with_path)
    return sess, True

  # Waits up until max_wait_secs for checkpoint to become available.
  wait_time = 0
  ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
  while not ckpt or not ckpt.model_checkpoint_path:
    if not (wait_for_checkpoint and wait_time < max_wait_secs):
      # Either waiting was not requested or the time budget is exhausted.
      return sess, False
    logging.info("Waiting for checkpoint to be available.")
    time.sleep(self._recovery_wait_secs)
    wait_time += self._recovery_wait_secs
    ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)

  # Loads the checkpoint.
  saver.restore(sess, ckpt.model_checkpoint_path)
  saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
  return sess, True
def _support_zero_size_input(self):
  """Reports whether the active strategy opts in to get-next-as-optional.

  Returns:
    The result of `has_strategy()` when it is falsy; otherwise the value of
    the `experimental_enable_get_next_as_optional` attribute on the current
    strategy's `extended` object, or `False` if that attribute is absent.
  """
  strategy_active = distribution_strategy_context.has_strategy()
  if not strategy_active:
    return strategy_active
  extended = distribution_strategy_context.get_strategy().extended
  return getattr(extended, 'experimental_enable_get_next_as_optional', False)
def build(self, input_shape):
  """Validates the configuration and creates the layer's weights.

  Normalizes `self.axis` in place to a list of non-negative indices, decides
  whether the fused implementation is used (`self.fused`), sets
  `self.input_spec`, and creates `gamma`/`beta` (or constant stand-ins when
  fused), `moving_mean`, `moving_variance` and, when `self.renorm` is set,
  the four renorm variables. Finally marks the layer as built.

  Args:
    input_shape: Shape of the input; anything accepted by
      `tensor_shape.TensorShape`.

  Raises:
    ValueError: If the input rank is undefined (or zero), an axis is out of
      range or duplicated, `virtual_batch_size` is non-positive or combined
      with axis 0 or with `adjustment`, `fused=True` is requested for a
      non-4D input under V2 behavior, the fused path is selected with an
      axis other than `[1]` or `[3]`, or a normalized axis has an undefined
      dimension.
  """
  input_shape = tensor_shape.TensorShape(input_shape)
  # `ndims` is None for unknown rank; a rank of 0 is also rejected here.
  if not input_shape.ndims:
    raise ValueError('Input has undefined rank:', input_shape)
  ndims = len(input_shape)

  # Convert axis to list and resolve negatives
  if isinstance(self.axis, int):
    self.axis = [self.axis]

  # NOTE: this assigns into `self.axis` by index, so it must be a mutable
  # sequence at this point.
  for idx, x in enumerate(self.axis):
    if x < 0:
      self.axis[idx] = ndims + x

  # Validate axes
  for x in self.axis:
    if x < 0 or x >= ndims:
      raise ValueError('Invalid axis: %d' % x)
  if len(self.axis) != len(set(self.axis)):
    raise ValueError('Duplicate axis: %s' % self.axis)

  if self.virtual_batch_size is not None:
    if self.virtual_batch_size <= 0:
      raise ValueError('virtual_batch_size must be a positive integer that '
                       'divides the true batch size of the input Tensor')
    # If using virtual batches, the first dimension must be the batch
    # dimension and cannot be the batch norm axis
    if 0 in self.axis:
      raise ValueError('When using virtual_batch_size, the batch dimension '
                       'must be 0 and thus axis cannot include 0')
    if self.adjustment is not None:
      raise ValueError('When using virtual_batch_size, adjustment cannot '
                       'be specified')

  # Resolve `self.fused` to a definite boolean when it is None or True.
  if self.fused in (None, True):
    # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
    # output back to its original shape accordingly.
    if self._USE_V2_BEHAVIOR:
      if self.fused is None:
        self.fused = (ndims == 4)
      elif self.fused and ndims != 4:
        raise ValueError('Batch normalization layers with fused=True only '
                         'support 4D input tensors.')
    else:
      # Outside V2 behavior `fused` is expected to be True here; a True
      # value silently falls back to non-fused when it cannot be used.
      assert self.fused is not None
      self.fused = (ndims == 4 and self._fused_can_be_used())
    # TODO(chrisying): fused batch norm is currently not supported for
    # multi-axis batch norm and by extension virtual batches. In some cases,
    # it might be possible to use fused batch norm but would require reshaping
    # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
    # particularly tricky. A compromise might be to just support the most
    # common use case (turning 5D w/ virtual batch to NCHW)

  if self.fused:
    # The fused path only knows the two data formats below.
    if self.axis == [1]:
      self._data_format = 'NCHW'
    elif self.axis == [3]:
      self._data_format = 'NHWC'
    else:
      raise ValueError('Unsupported axis, fused batch norm only supports '
                       'axis == [1] or axis == [3]')

  # Raise parameters of fp16 batch norm to fp32
  if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
    param_dtype = dtypes.float32
  else:
    param_dtype = self.dtype or dtypes.float32

  # Map each normalized axis to its static size; all must be known.
  axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
  for x in axis_to_dim:
    if axis_to_dim[x] is None:
      raise ValueError('Input has undefined `axis` dimension. Input shape: ',
                       input_shape)
  self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

  if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
    # Single axis batch norm (most common/default use-case)
    param_shape = (list(axis_to_dim.values())[0], )
  else:
    # Parameter shape is the original shape but with 1 in all non-axis dims
    param_shape = [
        axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
    ]
    if self.virtual_batch_size is not None:
      # When using virtual batches, add an extra dim at index 1
      param_shape.insert(1, 1)
      # `self.axis` is shifted in place, so from here on it refers to the
      # reshaped (virtual-batch) layout rather than the original input.
      for idx, x in enumerate(self.axis):
        self.axis[idx] = x + 1  # Account for added dimension

  if self.scale:
    self.gamma = self.add_weight(
        name='gamma',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.gamma_initializer,
        regularizer=self.gamma_regularizer,
        constraint=self.gamma_constraint,
        trainable=True)
  else:
    self.gamma = None
    if self.fused:
      # With no trainable scale, the fused path gets a constant 1.0 instead.
      self._gamma_const = K.constant(
          1.0, dtype=param_dtype, shape=param_shape)

  if self.center:
    self.beta = self.add_weight(
        name='beta',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.beta_initializer,
        regularizer=self.beta_regularizer,
        constraint=self.beta_constraint,
        trainable=True)
  else:
    self.beta = None
    if self.fused:
      # With no trainable offset, the fused path gets a constant 0.0 instead.
      self._beta_const = K.constant(
          0.0, dtype=param_dtype, shape=param_shape)

  try:
    # Disable variable partitioning when creating the moving mean and variance
    if hasattr(self, '_scope') and self._scope:
      partitioner = self._scope.partitioner
      self._scope.set_partitioner(None)
    else:
      partitioner = None
    self.moving_mean = self.add_weight(
        name='moving_mean',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.moving_mean_initializer,
        synchronization=tf_variables.VariableSynchronization.ON_READ,
        trainable=False,
        aggregation=tf_variables.VariableAggregation.MEAN)

    self.moving_variance = self.add_weight(
        name='moving_variance',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.moving_variance_initializer,
        synchronization=tf_variables.VariableSynchronization.ON_READ,
        trainable=False,
        aggregation=tf_variables.VariableAggregation.MEAN)

    if self.renorm:
      # Create variables to maintain the moving mean and standard deviation.
      # These are used in training and thus are different from the moving
      # averages above. The renorm variables are colocated with moving_mean
      # and moving_variance.
      # NOTE(review): an earlier note here described an outer `with device`
      # block and nested `lambda` device functions; no such constructs
      # appear in this method, which colocates via `colocate_vars_with`.
      def _renorm_variable(name, shape):
        # Zero-initialized, non-trainable variable with the same
        # synchronization/aggregation settings as the moving statistics.
        var = self.add_weight(
            name=name,
            shape=shape,
            dtype=param_dtype,
            initializer=init_ops.zeros_initializer(),
            synchronization=tf_variables.VariableSynchronization.ON_READ,
            trainable=False,
            aggregation=tf_variables.VariableAggregation.MEAN)
        return var

      with distribution_strategy_context.get_strategy(
      ).extended.colocate_vars_with(self.moving_mean):
        self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
        self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
      # We initialize renorm_stddev to 0, and maintain the (0-initialized)
      # renorm_stddev_weight. This allows us to (1) mix the average
      # stddev with the minibatch stddev early in training, and (2) compute
      # the unbiased average stddev by dividing renorm_stddev by the weight.
      with distribution_strategy_context.get_strategy(
      ).extended.colocate_vars_with(self.moving_variance):
        self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
        self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
                                                     ())
  finally:
    # Restore the partitioner that was disabled above, even on failure.
    if partitioner:
      self._scope.set_partitioner(partitioner)
  self.built = True
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Apply gradients to variables.

  This contains most of the synchronization implementation and also wraps the
  apply_gradients() from the real optimizer.

  Besides returning the train op, this method records several objects on
  `self`: `_global_step`, `_local_step`, `local_step_init_op`,
  `ready_for_local_init_op`, `_sync_token_queue`, `_chief_queue_runner`,
  `chief_init_op` and `_gradients_applied`; it also appends one
  `(accumulator, device)` pair per non-`None` gradient to
  `_accumulator_list`.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      compute_gradients().
    global_step: Optional Variable to increment by one after the
      variables have been updated.
    name: Optional name for the returned operation.  Default to the
      name passed to the Optimizer constructor.
      NOTE(review): `name` is not referenced anywhere in this method body;
      the name scope uses `self._name` instead.

  Returns:
    train_op: The op to dequeue a token so the replicas can exit this batch
    and start the next one. This is executed by each replica.

  Raises:
    ValueError: If the grads_and_vars is empty.
    ValueError: If global step is not provided, the staleness cannot be
      checked.
    ValueError: If a gradient is neither `None`, an `ops.Tensor`, nor an
      `ops.IndexedSlices`.
  """
  if not grads_and_vars:
    raise ValueError("Must supply at least one variable")

  if global_step is None:
    raise ValueError("Global step is required to check staleness")

  self._global_step = global_step
  train_ops = []
  aggregated_grad = []
  var_list = []

  # local_anchor op will be placed on this worker task by default.
  local_anchor = control_flow_ops.no_op()
  # Colocating local_step variable prevents it being placed on the PS.
  distribution_strategy = distribution_strategy_context.get_strategy()
  with distribution_strategy.extended.colocate_vars_with(local_anchor):
    self._local_step = variable_scope.variable(
        initial_value=0,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        dtype=global_step.dtype.base_dtype,
        name="sync_rep_local_step")

  # Initializing the local step copies the current global step into it.
  self.local_step_init_op = state_ops.assign(self._local_step, global_step)
  chief_init_ops = [self.local_step_init_op]
  self.ready_for_local_init_op = variables.report_uninitialized_variables(
      variables.global_variables())

  with ops.name_scope(None, self._name):
    for grad, var in grads_and_vars:
      var_list.append(var)
      # Each accumulator is created on the device of its variable.
      with ops.device(var.device):
        # Dense gradients.
        if grad is None:
          aggregated_grad.append(None)  # pass-through.
          continue
        elif isinstance(grad, ops.Tensor):
          grad_accum = data_flow_ops.ConditionalAccumulator(
              grad.dtype,
              shape=var.get_shape(),
              shared_name=var.name + "/grad_accum")
          # Applying is tagged with the local step; taking waits for
          # `_replicas_to_aggregate` contributions.
          train_ops.append(grad_accum.apply_grad(
              grad, local_step=self._local_step))
          aggregated_grad.append(grad_accum.take_grad(
              self._replicas_to_aggregate))
        else:
          # Sparse gradients: anything that is not IndexedSlices is rejected.
          if not isinstance(grad, ops.IndexedSlices):
            raise ValueError("Unknown grad type!")
          grad_accum = data_flow_ops.SparseConditionalAccumulator(
              grad.dtype, shape=(), shared_name=var.name + "/grad_accum")
          train_ops.append(grad_accum.apply_indexed_slices_grad(
              grad, local_step=self._local_step))
          aggregated_grad.append(grad_accum.take_indexed_slices_grad(
              self._replicas_to_aggregate))

        # Remembered so the chief can set the global step on it below.
        self._accumulator_list.append((grad_accum, var.device))

    # Pairs line up by position: `None` gradients were appended above too.
    aggregated_grads_and_vars = zip(aggregated_grad, var_list)

    # sync_op will be assigned to the same device as the global step.
    with ops.device(global_step.device), ops.name_scope(""):
      update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
                                            global_step)

    # Create token queue.
    with ops.device(global_step.device), ops.name_scope(""):
      sync_token_queue = (
          data_flow_ops.FIFOQueue(-1,
                                  global_step.dtype.base_dtype,
                                  shapes=(),
                                  name="sync_token_q",
                                  shared_name="sync_token_q"))
      self._sync_token_queue = sync_token_queue

      # dummy_queue is passed to the queue runner. Don't use the real queues
      # because the queue runner doesn't automatically reopen it once it
      # closed queues in PS devices.
      dummy_queue = (
          data_flow_ops.FIFOQueue(1,
                                  types_pb2.DT_INT32,
                                  shapes=(),
                                  name="dummy_queue",
                                  shared_name="dummy_queue"))

    with ops.device(global_step.device), ops.name_scope(""):
      # Replicas have to wait until they can get a token from the token queue.
      with ops.control_dependencies(train_ops):
        token = sync_token_queue.dequeue()
      # The dequeued token becomes this replica's new local step.
      train_op = state_ops.assign(self._local_step, token)

      with ops.control_dependencies([update_op]):
        # Sync_op needs to insert tokens to the token queue at the end of the
        # step so the replicas can fetch them to start the next step.
        tokens = array_ops.fill([self._tokens_per_step], global_step)
        sync_op = sync_token_queue.enqueue_many((tokens,))

      if self._variable_averages is not None:
        # Chain the moving-average update after the token enqueue.
        with ops.control_dependencies([sync_op]), ops.name_scope(""):
          sync_op = self._variable_averages.apply(
              self._variables_to_average)

      self._chief_queue_runner = queue_runner.QueueRunner(dummy_queue,
                                                          [sync_op])
    # Iterates every accumulator recorded on `self`, on its own device.
    for accum, dev in self._accumulator_list:
      with ops.device(dev):
        chief_init_ops.append(
            accum.set_global_step(
                global_step, name="SetGlobalStep"))
    self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
    self._gradients_applied = True
    return train_op
def call(self, inputs, training=None):
  """Applies batch normalization to `inputs`.

  When `self.fused` is set the work is delegated to `_fused_batch_norm`.
  Otherwise batch moments are computed, the moving statistics are updated
  (unless `training` is statically `False`), and
  `nn.batch_normalization` is applied with all statistics and parameters
  cast to the dtype of `inputs`.

  Args:
    inputs: Input tensor.
    training: Value passed to `tf_utils.smart_cond` to select between batch
      statistics and moving statistics. When `None`, `K.learning_phase()`
      is used.

  Returns:
    The normalized tensor, with the static shape of the (possibly
    virtual-batch reshaped) input restored and virtual batching undone.
  """
  if training is None:
    training = K.learning_phase()

  in_eager_mode = context.executing_eagerly()
  if self.virtual_batch_size is not None:
    # Virtual batches (aka ghost batches) can be simulated by reshaping the
    # Tensor and reusing the existing batch norm implementation
    original_shape = [-1] + inputs.shape.as_list()[1:]
    expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

    # Will cause errors if virtual_batch_size does not divide the batch size
    inputs = array_ops.reshape(inputs, expanded_shape)

    def undo_virtual_batching(outputs):
      outputs = array_ops.reshape(outputs, original_shape)
      return outputs

  if self.fused:
    outputs = self._fused_batch_norm(inputs, training=training)
    if self.virtual_batch_size is not None:
      # Currently never reaches here since fused_batch_norm does not support
      # virtual batching
      outputs = undo_virtual_batching(outputs)
    return outputs

  # Compute the axes along which to reduce the mean / variance
  input_shape = inputs.get_shape()
  ndims = len(input_shape)
  reduction_axes = [i for i in range(ndims) if i not in self.axis]
  if self.virtual_batch_size is not None:
    del reduction_axes[1]  # Do not reduce along virtual batch dim

  # Broadcasting only necessary for single-axis batch norm where the axis is
  # not the last dimension
  broadcast_shape = [1] * ndims
  broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

  def _broadcast(v):
    if (v is not None and
        len(v.get_shape()) != ndims and
        reduction_axes != list(range(ndims - 1))):
      return array_ops.reshape(v, broadcast_shape)
    return v

  scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

  def _compose_transforms(scale, offset, then_scale, then_offset):
    # Composes (x * scale + offset) followed by (* then_scale + then_offset).
    if then_scale is not None:
      scale *= then_scale
      offset *= then_scale
    if then_offset is not None:
      offset += then_offset
    return (scale, offset)

  # Determine a boolean value for `training`: could be True, False, or None.
  training_value = tf_utils.constant_value(training)
  if training_value is not False:
    if self.adjustment:
      adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
      # Adjust only during training.
      adj_scale = tf_utils.smart_cond(training,
                                      lambda: adj_scale,
                                      lambda: array_ops.ones_like(adj_scale))
      adj_bias = tf_utils.smart_cond(training,
                                     lambda: adj_bias,
                                     lambda: array_ops.zeros_like(adj_bias))
      scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

    # Some of the computations here are not necessary when training==False
    # but not a constant. However, this makes the code simpler.
    keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
    mean, variance = self._moments(
        inputs, reduction_axes, keep_dims=keep_dims)

    moving_mean = self.moving_mean
    moving_variance = self.moving_variance

    mean = tf_utils.smart_cond(training,
                               lambda: mean,
                               lambda: moving_mean)
    variance = tf_utils.smart_cond(training,
                                   lambda: variance,
                                   lambda: moving_variance)

    if self.virtual_batch_size is not None:
      # This isn't strictly correct since in ghost batch norm, you are
      # supposed to sequentially update the moving_mean and moving_variance
      # with each sub-batch. However, since the moving statistics are only
      # used during evaluation, it is more efficient to just update in one
      # step and should not make a significant difference in the result.
      new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
      new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
    else:
      new_mean, new_variance = mean, variance

    if self.renorm:
      r, d, new_mean, new_variance = self._renorm_correction_and_moments(
          new_mean, new_variance, training)
      # When training, the normalized values (say, x) will be transformed as
      # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
      # = x * (r * gamma) + (d * gamma + beta) with renorm.
      r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
      d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
      scale, offset = _compose_transforms(r, d, scale, offset)

    if distribution_strategy_context.in_cross_replica_context():
      strategy = distribution_strategy_context.get_strategy()

      def _do_update(var, value):
        """Compute the updates for mean and variance."""
        if in_eager_mode and not self.trainable:
          return
        return strategy.extended.update(
            var,
            self._assign_moving_average, (value, self.momentum),
            group=False)

      # We need to unwrap the moving_mean or moving_variance in the case of
      # training being false to match the output of true_fn and false_fn
      # in the smart cond.
      mean_update = tf_utils.smart_cond(
          training,
          lambda: _do_update(self.moving_mean, new_mean),
          lambda: strategy.unwrap(self.moving_mean))
      variance_update = tf_utils.smart_cond(
          training,
          lambda: _do_update(self.moving_variance, new_variance),
          lambda: strategy.unwrap(self.moving_variance))
    else:

      def _do_update(var, value):
        """Compute the updates for mean and variance."""
        if in_eager_mode and not self.trainable:
          return
        return self._assign_moving_average(var, value, self.momentum)

      mean_update = tf_utils.smart_cond(
          training,
          lambda: _do_update(self.moving_mean, new_mean),
          lambda: self.moving_mean)
      variance_update = tf_utils.smart_cond(
          training,
          lambda: _do_update(self.moving_variance, new_variance),
          lambda: self.moving_variance)
    if not context.executing_eagerly():
      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

  else:
    mean, variance = self.moving_mean, self.moving_variance

  # Bring every operand to the dtype of `inputs`. `scale` must be cast like
  # the others: the parameters may have been created in a wider dtype than
  # the inputs, and mixing dtypes in the normalization arithmetic fails.
  mean = math_ops.cast(mean, inputs.dtype)
  variance = math_ops.cast(variance, inputs.dtype)
  if offset is not None:
    offset = math_ops.cast(offset, inputs.dtype)
  if scale is not None:
    scale = math_ops.cast(scale, inputs.dtype)
  outputs = nn.batch_normalization(inputs,
                                   _broadcast(mean),
                                   _broadcast(variance),
                                   offset,
                                   scale,
                                   self.epsilon)
  # If some components of the shape got lost due to adjustments, fix that.
  outputs.set_shape(input_shape)

  if self.virtual_batch_size is not None:
    outputs = undo_virtual_batching(outputs)

  return outputs
def compute_weighted_loss(
    losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
  """Multiplies `losses` by `weights` and reduces the result.

  Args:
    losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
    weights: Optional `Tensor` of rank 0 or of the same rank as `losses`;
      it must be broadcastable to `losses` (each dimension is either `1` or
      equal to the matching `losses` dimension).
    scope: Name scope for the ops created while computing the loss.
    loss_collection: Collection(s) the resulting loss is added to.
    reduction: The kind of reduction applied to the weighted losses.

  Returns:
    A weighted loss `Tensor` with the dtype of `losses`. With reduction
    `NONE` it keeps the shape of `losses`; otherwise it is a scalar.

  Raises:
    ValueError: If `weights` is `None`, if its shape is incompatible with
      `losses`, or if the rank of `losses` or `weights` is unknown.

  Note:
    Gradients of a weighted loss flow through both `losses` and `weights`.
    If `weights` depend on model parameters and should not influence the
    loss gradient, wrap them in `tf.stop_gradient` before calling
    `compute_weighted_loss`.

  @compatibility(eager)
  The `loss_collection` argument is ignored when executing eagerly. Consider
  holding on to the return value or collecting losses via a `tf.keras.Model`.
  @end_compatibility
  """
  Reduction.validate(reduction)
  with ops.name_scope(scope, "weighted_loss", (losses, weights)):
    broadcast_check = weights_broadcast_ops.assert_broadcastable(
        weights, losses)
    with ops.control_dependencies((broadcast_check,)):
      losses = ops.convert_to_tensor(losses)
      # Remember the caller's dtype; all arithmetic below runs in float32.
      input_dtype = losses.dtype
      losses = math_ops.cast(losses, dtype=dtypes.float32)
      weights = math_ops.cast(weights, dtype=dtypes.float32)
      weighted_losses = math_ops.multiply(losses, weights)

      if reduction == Reduction.NONE:
        loss = weighted_losses
      else:
        loss = math_ops.reduce_sum(weighted_losses)
        # Used to convert from local to global batch size.
        strategy = distribution_strategy_context.get_strategy()
        num_replicas = strategy.num_replicas_in_sync

        # Pick the normalizer for this reduction; `None` leaves the plain
        # sum untouched.
        denominator = None
        if reduction == Reduction.MEAN:
          total_weight = math_ops.reduce_sum(
              array_ops.ones_like(losses) * weights)
          denominator = num_replicas * total_weight
        elif reduction in (Reduction.SUM_BY_NONZERO_WEIGHTS,
                           Reduction.SUM_OVER_NONZERO_WEIGHTS):
          denominator = num_replicas * _num_present(losses, weights)
        elif reduction == Reduction.SUM_OVER_BATCH_SIZE:
          denominator = num_replicas * _num_elements(losses)

        if denominator is not None:
          loss = _safe_mean(loss, denominator)

      # Convert the result back to the input type.
      loss = math_ops.cast(loss, input_dtype)
      util.add_loss(loss, loss_collection)
      return loss