def call_replica_local_fn(fn, *args, **kwargs):
  """Call a function that uses replica-local variables.

  This function correctly handles calling `fn` in a cross-replica
  context.

  Arguments:
    fn: The function to call.
    *args: Positional arguments to `fn`.
    **kwargs: Keyword arguments to `fn`.

  Returns:
    The result of calling `fn`.
  """
  # TODO(b/120571621): We want to avoid reductions here since
  # TPUStrategy does not implement replica-local variables.
  # Remove this hack once we support TPUReplicaLocalVariables.
  strategy = None
  if 'strategy' in kwargs:
    strategy = kwargs.pop('strategy')
  elif ds_context.has_strategy():
    # `get_strategy()` always returns a (default) strategy object, so the
    # presence check must use `has_strategy()`.
    strategy = ds_context.get_strategy()

  is_tpu = is_tpu_strategy(strategy)
  if not is_tpu and strategy and ds_context.in_cross_replica_context():
    with strategy.scope():
      return strategy.extended.call_for_each_replica(fn, args, kwargs)
  return fn(*args, **kwargs)
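
A minimal sketch of the same dispatch pattern using only the public `tf.distribute` API (an illustration, not the original helper); `strategy.run` plays the role of `extended.call_for_each_replica` here:

import tensorflow as tf

def run_replica_local(fn, *args, **kwargs):
  # Prefer an explicitly passed strategy; otherwise use the ambient one.
  strategy = kwargs.pop('strategy', None)
  if strategy is None and tf.distribute.has_strategy():
    strategy = tf.distribute.get_strategy()
  # From cross-replica context, fan the call out to every replica.
  if strategy is not None and tf.distribute.in_cross_replica_context():
    with strategy.scope():
      return strategy.run(fn, args=args, kwargs=kwargs)
  return fn(*args, **kwargs)
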
def create_slot(primary, val, name, colocate_with_primary=True):
  """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Set "primary.op.name + '/' + name" as default name, so the scope name of
  # optimizer can be shared when reuse is True. Meanwhile when reuse is False
  # and the same name has been previously used, the scope name will add '_N'
  # as suffix for unique identifications.
  validate_shape = val.get_shape().is_fully_defined()
  if context.executing_eagerly():
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = distribution_strategy_context.get_strategy()
      with distribution_strategy.colocate_vars_with(primary):
        return _create_slot_var(primary, val, "", validate_shape, None, None)
    else:
      return _create_slot_var(primary, val, "", validate_shape, None, None)
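
A hedged usage sketch for `create_slot`, assuming it is importable from the internal `tensorflow.python.training.slot_creator` module (the path may differ across TF versions); the slot here is a zero-initialized accumulator colocated with its primary variable:

import tensorflow as tf
from tensorflow.python.training import slot_creator  # internal module path

v = tf.Variable([1.0, 2.0], name='weights')
# A zero-initialized slot (e.g. a momentum buffer) placed on v's device.
momentum_slot = slot_creator.create_slot(v, tf.zeros_like(v), 'momentum')
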
  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          epsilon=self.epsilon,
          data_format=self._data_format)

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    output, mean, variance = tf_utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)
    if not self._bessels_correction_test_only:
      # Remove Bessel's correction to be consistent with non-fused batch norm.
      # Note that the variance computed by fused batch norm is
      # with Bessel's correction.
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
      variance *= factor

    training_value = tf_utils.constant_value(training)
    if training_value is None:
      momentum = tf_utils.smart_cond(training,
                                     lambda: self.momentum,
                                     lambda: 1.0)
    else:
      momentum = ops.convert_to_tensor(self.momentum)
    if training_value or training_value is None:
      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()
        mean_update = strategy.extended.update(
            self.moving_mean, self._assign_moving_average,
            (mean, self.momentum))
        variance_update = strategy.extended.update(
            self.moving_variance, self._assign_moving_average,
            (variance, self.momentum))
      else:
        mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                  momentum)
        variance_update = self._assign_moving_average(self.moving_variance,
                                                      variance, momentum)
      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    return output
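
A quick numeric check of the Bessel's-correction removal above (an illustration, not from the source): multiplying the unbiased variance by (n - 1) / n recovers the biased, divide-by-n variance that non-fused batch norm computes.

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
n = x.size
unbiased = x.var(ddof=1)  # what fused batch norm returns
assert np.isclose(unbiased * (n - 1) / n, x.var(ddof=0))
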
def scale_loss_for_distribution(loss_value):
  """Scales and returns the given loss value by the number of replicas."""
  num_replicas = (
      distribution_strategy_context.get_strategy().num_replicas_in_sync)
  if num_replicas > 1:
    loss_value *= (1. / num_replicas)
  return loss_value
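
A minimal sketch of the same scaling written against the public API; under the default (single-replica) strategy the loss passes through unchanged:

import tensorflow as tf

def scaled_loss(per_replica_loss):
  # Dividing by the replica count keeps the summed cross-replica gradient
  # equal to the gradient of the global-batch mean loss.
  num = tf.distribute.get_strategy().num_replicas_in_sync
  return per_replica_loss / num if num > 1 else per_replica_loss

loss = scaled_loss(tf.constant(4.0))  # unchanged when num_replicas is 1
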
def _assert_in_default_state(t):
  t.assertIs(ds_context._get_default_replica_context(),
             ds_context.get_replica_context())
  t.assertIs(None, ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())
  t.assertIs(ds_context._get_default_strategy(), ds_context.get_strategy())
  t.assertFalse(ds_context.has_strategy())
 def add_slot(self, var, slot_name, initializer="zeros"):
   """Add a new slot variable for `var`."""
   if slot_name not in self._slot_names:
     self._slot_names.append(slot_name)
   var_key = _var_key(var)
   slot_dict = self._slots.setdefault(var_key, {})
   weight = slot_dict.get(slot_name, None)
   if weight is None:
     if isinstance(initializer, six.string_types) or callable(initializer):
       initializer = initializers.get(initializer)
       initial_value = functools.partial(
           initializer, shape=var.shape, dtype=var.dtype)
     else:
       initial_value = initializer
     strategy = distribute_ctx.get_strategy()
     with strategy.extended.colocate_vars_with(var):
       weight = tf_variables.Variable(
           name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
           dtype=var.dtype,
           trainable=False,
           initial_value=initial_value)
     backend.track_variable(weight)
     slot_dict[slot_name] = weight
     self._restore_slot_variable(
         slot_name=slot_name, variable=var,
         slot_variable=weight)
     self._weights.append(weight)
   return weight
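
A usage sketch for the public counterpart of this method, assuming the OptimizerV2 generation of Keras optimizers these snippets come from (`tf.keras.optimizers.legacy.SGD` in newer releases):

import tensorflow as tf

opt = tf.keras.optimizers.SGD(momentum=0.9)  # OptimizerV2 assumed
var = tf.Variable([1.0, 2.0])
accum = opt.add_slot(var, 'accum')  # zero-initialized by default
assert opt.get_slot(var, 'accum') is accum
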
  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot."""
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
    graph = None if eager else colocate_with.graph

    key = (name, graph)
    v = self._non_slot_dict.get(key, None)
    if v is None:
      self._maybe_initialize_trackable()
      distribution_strategy = distribute_ctx.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(colocate_with):
        if eager:
          restored_initial_value = self._preload_simple_restoration(
              name=name, shape=None)
          if restored_initial_value is not None:
            initial_value = restored_initial_value
        v = variable_scope.variable(
            initial_value, name=name, trainable=False,
            use_resource=resource_variable_ops.is_resource_variable(
                colocate_with))
      # Restore this variable by name if necessary, but don't add a
      # Trackable dependency. Optimizers return the current graph's
      # non-slot variables from _checkpoint_dependencies explicitly rather
      # than unconditionally adding dependencies (since there may be multiple
      # non-slot variables with the same name in different graphs, trying to
      # save all of them would result in errors).
      self._handle_deferred_dependencies(name=name, trackable=v)
      self._non_slot_dict[key] = v

    return v
 def _scale_loss(loss_value):
   ops.get_default_graph()._is_loss_scaled_by_optimizer = False  # pylint: disable=protected-access
   if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
     num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
     if num_replicas > 1:
       loss_value *= (1. / num_replicas)
       ops.get_default_graph()._is_loss_scaled_by_optimizer = True  # pylint: disable=protected-access
   return loss_value
 def merge_fn(dist, s):
   self.assertIs(ds_context._get_default_strategy(), dist)
   self.assertIs(None, ds_context.get_replica_context())
   self.assertIs(dist, ds_context.get_cross_replica_context())
   self.assertTrue(ds_context.in_cross_replica_context())
   self.assertIs(dist, ds_context.get_strategy())
   self.assertFalse(ds_context.has_strategy())
   return "foo_" + s
 def testSetStrategy(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   dist2 = _TestStrategy()
   ds_context.experimental_set_strategy(dist)
   self.assertIs(None, ds_context.get_replica_context())
   self.assertIs(dist, ds_context.get_cross_replica_context())
   self.assertTrue(ds_context.in_cross_replica_context())
   self.assertTrue(ds_context.has_strategy())
   self.assertIs(dist, ds_context.get_strategy())
   expected_value = _get_test_variable(
       "baz", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="baz"))
   ds_context.experimental_set_strategy(dist2)
   self.assertIs(dist2, ds_context.get_strategy())
   ds_context.experimental_set_strategy(None)
   _assert_in_default_state(self)
def _reduce_weighted_loss(
    weighted_losses, reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE):
  """Reduces the individual weighted loss measurements."""
  if reduction == losses_impl.ReductionV2.NONE:
    loss = weighted_losses
  else:
    loss = math_ops.reduce_sum(weighted_losses)
    if reduction == losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE:
      num_replicas = (  # Used to convert from local to global batch size.
          distribution_strategy_context.get_strategy().num_replicas_in_sync)
      loss = _safe_mean(loss, num_replicas * _num_elements(weighted_losses))
  return loss
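
A numeric sketch of the SUM_OVER_BATCH_SIZE path (an illustration): with four weighted losses on one replica the divisor is 4, so the result is the plain mean; under N replicas the divisor becomes N times the local element count, converting the local batch size to the global one.

import tensorflow as tf

weighted = tf.constant([1.0, 2.0, 3.0, 4.0])
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync  # 1 here
loss = tf.reduce_sum(weighted) / (num_replicas * 4)  # -> 2.5
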
  def _get_tensor(self, is_finite):
    tensor = control_flow_ops.cond(is_finite, lambda: 1., lambda: float('NaN'))

    if not distribution_strategy_context.has_strategy():
      return tensor
    def get():
      rep_id = (distribution_strategy_context.get_replica_context()
                .replica_id_in_sync_group)
      return control_flow_ops.cond(math_ops.equal(rep_id, 0), lambda: tensor,
                                   lambda: 1.)
    distribution = distribution_strategy_context.get_strategy()
    return distribution.extended.call_for_each_replica(get)
 def testScopeDeviceNestingError(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   # Open a device scope with dist.scope().
   dist.extended._default_device = "/device:GPU:0"
   scope = dist.scope()
   scope.__enter__()
   self.assertIs(dist, ds_context.get_strategy())
   with ops.device("/device:CPU:0"):
     with self.assertRaisesRegex(RuntimeError, "Device scope nesting error"):
       scope.__exit__(None, None, None)
   scope.__exit__(None, None, None)
   _assert_in_default_state(self)
 def run_fn():
   replica_context = ds_context.get_replica_context()
   self.assertIsNotNone(replica_context)
   self.assertIs(None, ds_context.get_cross_replica_context())
   self.assertFalse(ds_context.in_cross_replica_context())
   self.assertTrue(ds_context.has_strategy())
   self.assertIs(dist, ds_context.get_strategy())
   self.assertEqual("foo", replica_context.merge_call(None, test_arg="foo"))
   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="bar"))
 def testScopeVarScopeNestingError(self):
   # We create a new graph here to simplify clean-up, since the error
   # we are triggering happens in the middle of scope.__exit__() and
   # leaves us in a weird state.
   with ops.Graph().as_default():
     _assert_in_default_state(self)
     dist = _TestStrategy()
     scope = dist.scope()
     scope.__enter__()
     self.assertIs(dist, ds_context.get_strategy())
     with variable_scope.variable_scope("AA"):
       with self.assertRaisesRegex(RuntimeError,
                                   "Variable scope nesting error"):
         scope.__exit__(None, None, None)
   _assert_in_default_state(self)
  def testScopeVarCreatorNestingError(self):

    def creator(next_creator, **kwargs):
      return next_creator(**kwargs)

    _assert_in_default_state(self)
    dist = _TestStrategy()
    scope = dist.scope()
    scope.__enter__()
    self.assertIs(dist, ds_context.get_strategy())
    with variable_scope.variable_creator_scope(creator):
    with variable_scope.variable_creator_scope(creator):
      with self.assertRaisesRegex(RuntimeError,
                                  "Variable creator scope nesting error"):
        scope.__exit__(None, None, None)
    scope.__exit__(None, None, None)
    _assert_in_default_state(self)
    def begin(self):
        self._global_step_tensor = training_util.get_global_step()
        self._stop_var = self._get_or_create_stop_var_with_aggregation()
        assert distribution_strategy_context.in_cross_replica_context()

        strategy = distribution_strategy_context.get_strategy()
        self._stop_placeholder = None

        def stop_op_fn(var):
            placeholder = array_ops.placeholder_with_default(0,
                                                             tuple(),
                                                             name='stop_value')
            if self._stop_placeholder is None:
                self._stop_placeholder = placeholder
            return var.assign_add(placeholder)

        self._stop_op = strategy.run(stop_op_fn, args=(self._stop_var, ))
 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, distribution_strategy_context.get_replica_context())
     self.assertIs(dist,
                   distribution_strategy_context.get_cross_replica_context())
     self.assertTrue(distribution_strategy_context.in_cross_replica_context())
     self.assertTrue(distribution_strategy_context.has_strategy())
     self.assertIs(dist,
                   distribution_strategy_context.get_strategy())
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
    def __init__(self, trackable):
        if not isinstance(trackable, tracking.Trackable):
            raise ValueError('%s is not a Trackable object.' % (trackable, ))
        self._trackable = trackable
        self._distribute_strategy = (
            distribution_strategy_context.get_strategy())

        # TODO(b/141682913): Figure out why this is private and fix it.
        saveables = trackable._gather_saveables_for_checkpoint().values()  # pylint: disable=protected-access
        if len(saveables) != 1:
            raise ValueError(
                'Only Trackables with one Saveable are supported.')
        saveable = list(saveables)[0]

        if ops.executing_eagerly_outside_functions():
            # If we're in eager mode, we need to defer calling the Trackable's
            # saveable() callable until data export time.
            # However, it is safe to call the saveable as many times as we want, so
            # we will call it now to figure out how many tensors this Trackable will
            # produce.
            self._saveable = saveable
            self._num_tensors = len(self._saveable().specs)
            self._setter = lambda weights: self._saveable().restore(
                weights, None)
            self._getter = lambda: [
                spec.tensor for spec in self._saveable().specs
            ]
        else:
            # If we're in Graph mode, we need to evaluate the Saveable only once and
            # cache the resulting restore graph. Failing to do this will result in
            # new assignment ops being added to the graph each time set_weights() is
            # called.
            self._placeholder_tensors = []
            self._saveable = saveable()
            self._num_tensors = len(self._saveable.specs)
            for spec in self._saveable.specs:
                tensor = spec.tensor
                self._placeholder_tensors.append(
                    array_ops.placeholder(tensor.dtype, tensor.shape))
            self._assign_op = self._saveable.restore(self._placeholder_tensors,
                                                     None)
            self._setter = self._set_weights_v1
            self._getter = lambda: [
                spec.tensor for spec in self._saveable.specs
            ]
 def _raise_if_strategy_unsupported(self):
     if not strategy_supports_loss_scaling():
         strategy = distribution_strategy_context.get_strategy()
         if isinstance(
                 strategy,
             (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
             raise ValueError(
                 'Loss scaling is not supported with TPUStrategy. Loss scaling is '
                 'unnecessary with TPUs, since they support bfloat16 instead of '
                 'float16 and bfloat16 does not require loss scaling. You should '
                 'remove the use of the LossScaleOptimizer when TPUs are used.'
             )
         else:
             raise ValueError(
                 'Loss scaling is not supported with the '
                 'tf.distribute.Strategy: %s. Try using a different '
                 'Strategy, e.g. a MirroredStrategy' %
                 strategy.__class__.__name__)
 def run_fn():
      replica_context = distribution_strategy_context.get_replica_context()
      self.assertIsNotNone(replica_context)
     self.assertIs(
         None,
         distribution_strategy_context.get_cross_replica_context())
     self.assertFalse(
         distribution_strategy_context.in_cross_replica_context())
     self.assertTrue(distribution_strategy_context.has_strategy())
     self.assertIs(dist, distribution_strategy_context.get_strategy())
     self.assertEqual("foo",
                      replica_context.merge_call(None, test_arg="foo"))
     expected_value = _get_test_variable(
         "bar", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="bar"))
  def decorated(metric_obj, *args, **kwargs):
    """Decorated function with `add_update()`."""
    strategy = distribution_strategy_context.get_strategy()

    for weight in metric_obj.weights:
      if (backend.is_tpu_strategy(strategy) and
          not strategy.extended.variable_created_in_scope(weight)
          and not distribution_strategy_context.in_cross_replica_context()):
        raise ValueError(
            'Trying to run metric.update_state in replica context when '
            'the metric was not created in TPUStrategy scope. '
            'Make sure the keras Metric is created in TPUStrategy scope.')

    with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
      update_op = update_state_fn(*args, **kwargs)
    if update_op is not None:  # update_op will be None in eager execution.
      metric_obj.add_update(update_op)
    return update_op
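
A sketch of the pattern the check above enforces, using the default (single-replica) strategy as a stand-in for TPUStrategy so it is runnable anywhere: create the metric under the strategy scope, then update it from replica context via `strategy.run`.

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # stand-in; normally TPUStrategy
with strategy.scope():
  acc = tf.keras.metrics.SparseCategoricalAccuracy()

def step(y_true, y_pred):
  acc.update_state(y_true, y_pred)  # replica context: allowed

strategy.run(step, args=(tf.constant([1]), tf.constant([[0.1, 0.9]])))
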
 def add_slot(var, slot_name, initializer="zeros", shape=None):
     """Add a new slot variable for `var`."""
     if slot_name not in self._slot_names:
         self._slot_names.append(slot_name)
     var_key = optimizer_v2._var_key(var)
     slot_dict = self._slots.setdefault(var_key, {})
     weight = slot_dict.get(slot_name, None)
     if weight is None:
         if isinstance(initializer,
                       six.string_types) or callable(initializer):
             initializer = initializers.get(initializer)
             if isinstance(initializer,
                           trackable.CheckpointInitialValueCallable) or (
                               shape is not None):
                 slot_shape = shape
             else:
                 slot_shape = var.shape
             initial_value = functools.partial(initializer,
                                               shape=slot_shape,
                                               dtype=var.dtype)
         else:
             initial_value = initializer
         strategy = distribute_ctx.get_strategy()
         with strategy.extended.colocate_vars_with(var):
             if isinstance(var, de.TrainableWrapper):
                 weight = de.create_slots(var, initial_value, slot_name,
                                          var._shared_name, self._bp_v2)
             else:
                 weight = variables.Variable(
                     name="%s/%s" % (
                         var._shared_name,
                         slot_name,
                     ),  # pylint: disable=protected-access
                     dtype=var.dtype,
                     trainable=False,
                     initial_value=initial_value,
                 )
         backend.track_variable(weight)
         slot_dict[slot_name] = weight
         self._restore_slot_variable(slot_name=slot_name,
                                     variable=var,
                                     slot_variable=weight)
         self._weights.append(weight)
     return weight
    def apply_gradients(self,
                        grads_vars_and_constraints,
                        name=None,
                        experimental_aggregate_gradients=True):
        grads_vars_and_constraints = _filter_grads(grads_vars_and_constraints)
        var_list = [v for (_, v, _) in grads_vars_and_constraints]
        constraint_list = [c for (_, _, c) in grads_vars_and_constraints]

        with backend.name_scope(self._name):
            with ops.init_scope():
                self._create_all_weights(var_list)

            if not grads_vars_and_constraints:
                return control_flow_ops.no_op()

            if distribute_ctx.in_cross_replica_context():
                raise RuntimeError(
                    "`apply_gradients() cannot be called in cross-replica context. "
                    "Use `tf.distribute.Strategy.run` to enter replica "
                    "context.")

            strategy = distribute_ctx.get_strategy()
            if (not experimental_aggregate_gradients and strategy
                    and isinstance(
                        strategy.extended, parameter_server_strategy.
                        ParameterServerStrategyExtended)):
                raise NotImplementedError(
                    "`experimental_aggregate_gradients=False is not supported for "
                    "ParameterServerStrategy and CentralStorageStrategy")

            apply_state = self._prepare(var_list)
            if experimental_aggregate_gradients:
                reduced_grads = self._aggregate_gradients(
                    grads_vars_and_constraints)
                var_list = [v for _, v, _ in grads_vars_and_constraints]
                grads_vars_and_constraints = list(
                    zip(reduced_grads, var_list, constraint_list))
            return distribute_ctx.get_replica_context().merge_call(
                functools.partial(self._distributed_apply,
                                  apply_state=apply_state),
                args=(grads_vars_and_constraints, ),
                kwargs={
                    "name": name,
                })
    def set_last_step_output(self, name, output, reduce_op=None):
        """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and reduction is set, output
        must be a `PerReplica` value.
        The reduce method is also recorded in a dictionary
        `_last_step_outputs_reduce_ops` for later interpreting of the
        outputs as already reduced or not.
    """
        if distribution_strategy_context.in_cross_replica_context():
            self._last_step_outputs_reduce_ops[name] = reduce_op
            if reduce_op is None:
                self._last_step_outputs[name] = output
            else:
                distribution = distribution_strategy_context.get_strategy()
                self._last_step_outputs[name] = distribution.reduce(reduce_op,
                                                                    output,
                                                                    axis=None)
        else:
            assert reduce_op is not None

            def merge_fn(distribution, value):
                self._last_step_outputs[name] = distribution.reduce(reduce_op,
                                                                    value,
                                                                    axis=None)
                # Setting this inside the `merge_fn` because all replicas share the same
                # context object, so it's more robust to set it only once (even if all
                # the replicas are trying to set the same value).
                self._last_step_outputs_reduce_ops[name] = reduce_op

            distribution_strategy_context.get_replica_context().merge_call(
                merge_fn, args=(output, ))
 def testScope(self):
     _assert_in_default_state(self)
     dist = _TestStrategy()
     with dist.scope():
         self.assertIs(None,
                       distribution_strategy_context.get_replica_context())
         self.assertIs(
             dist,
             distribution_strategy_context.get_cross_replica_context())
         self.assertTrue(
             distribution_strategy_context.in_cross_replica_context())
         self.assertTrue(distribution_strategy_context.has_strategy())
         self.assertIs(dist, distribution_strategy_context.get_strategy())
         expected_value = _get_test_variable(
             "baz", variable_scope.VariableSynchronization.AUTO,
             variable_scope.VariableAggregation.NONE)
         self.assertDictEqual(expected_value,
                              variable_scope.variable(1.0, name="baz"))
     _assert_in_default_state(self)
    def _assign_moving_average(self, variable, value, momentum):
        with ops.name_scope(None, 'AssignMovingAvg',
                            [variable, value, momentum]) as scope:
            # TODO(b/120571621): We want to avoid colocating the variables here
            # since TPUStrategy does not implement replica local variables.
            # Remove this hack once we support TPULocalVariables.
            is_tpu_strategy = False
            if distribution_strategy_context.has_strategy():
                distribute = distribution_strategy_context.get_strategy()
                if distribute.__class__.__name__ == 'TPUStrategy':
                    is_tpu_strategy = True

            with ops.colocate_with(variable):
                decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
                if decay.dtype != variable.dtype.base_dtype:
                    decay = math_ops.cast(decay, variable.dtype.base_dtype)
                update_delta = (variable -
                                math_ops.cast(value, variable.dtype)) * decay
                return state_ops.assign_sub(variable, update_delta, name=scope)
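
A numeric check of the update rule above (an illustration, not part of the source): assign_sub((var - value) * (1 - momentum)) is the classic exponential moving average momentum * var + (1 - momentum) * value.

import tensorflow as tf

var = tf.Variable(10.0)
value, momentum = 4.0, 0.9
var.assign_sub((var - value) * (1.0 - momentum))
# var is now 0.9 * 10.0 + 0.1 * 4.0 == 9.4
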
def create_slot_with_initializer(primary,
                                 initializer,
                                 shape,
                                 dtype,
                                 name,
                                 colocate_with_primary=True):
    """Creates a slot initialized using an `Initializer`.

  The type of the slot is determined by the given `dtype`.

  Args:
    primary: The primary `Variable` or `Tensor`.
    initializer: An `Initializer`.  The initial value of the slot.
    shape: Shape of the initial value of the slot.
    dtype: Type of the value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
    # Scope the slot name in the namespace of the primary variable.
    # Set "primary.op.name + '/' + name" as default name, so the scope name of
    # optimizer can be shared when reuse is True. Meanwhile when reuse is False
    # and the same name has been previously used, the scope name will add '_N'
    # as suffix for unique identifications.
    validate_shape = shape.is_fully_defined()
    if context.executing_eagerly():
        prefix = primary._shared_name  # pylint: disable=protected-access
    else:
        prefix = primary.op.name
    with variable_scope.variable_scope(None, prefix + "/" + name):
        if colocate_with_primary:
            distribution_strategy = (
                distribution_strategy_context.get_strategy())
            with distribution_strategy.extended.colocate_vars_with(primary):
                return _create_slot_var(primary, initializer, "",
                                        validate_shape, shape, dtype)
        else:
            return _create_slot_var(primary, initializer, "", validate_shape,
                                    shape, dtype)
def _get_input_from_iterator(iterator):
    """Get elements from the iterator and verify the input shape and type."""
    next_element = next(iterator)

    if tensor_util.is_tensor(next_element) or isinstance(next_element, dict):
        next_element = [next_element]
    if len(next_element) == 1:
        x, = next_element
        y = None
        sample_weights = None
    elif len(next_element) == 2:
        x, y = next_element
        sample_weights = None
    else:
        x, y, sample_weights = next_element

    # Validate that all the elements in x and y are of the same type and shape.
    dist_utils.validate_distributed_dataset_inputs(
        distribution_strategy_context.get_strategy(), x, y, sample_weights)
    return x, y, sample_weights
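
A small sketch of the arity handling this helper performs; the dataset here is a hypothetical stand-in:

import tensorflow as tf

features = tf.zeros([4, 2])
labels = tf.ones([4])
weights = tf.fill([4], 0.5)

it = iter(tf.data.Dataset.from_tensor_slices(
    (features, labels, weights)).batch(2))
x, y, sample_weights = next(it)  # a 3-tuple yields all three pieces
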
def strategy_supports_loss_scaling():
  """Returns True if the current Strategy supports loss scaling."""
  if not distribution_strategy_context.has_strategy():
    return True
  strategy = distribution_strategy_context.get_strategy()
  # Strategies are supported if either there is only one replica or if variables
  # are replicated per device. Otherwise, the current model.fit() implementation
  # and most custom training loops incorrectly unscale the gradients. Currently,
  # gradients are unscaled once per compute replica, but they should be unscaled
  # once per variable replica. When there is one variable replica for each
  # compute replica, this works fine, but otherwise issues will occur.
  # TODO(reedwm): Support all strategies.
  return isinstance(strategy, (
      collective_all_reduce_strategy.CollectiveAllReduceStrategy,
      collective_all_reduce_strategy.CollectiveAllReduceStrategyV1,
      one_device_strategy.OneDeviceStrategy,
      one_device_strategy.OneDeviceStrategyV1,
      mirrored_strategy.MirroredStrategy,
      mirrored_strategy.MirroredStrategyV1,
  ))
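
A hedged usage sketch of this guard with public names (assuming TF 2.4+ for `tf.distribute.TPUStrategy` and `tf.keras.mixed_precision.LossScaleOptimizer`):

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # default strategy: supported
opt = tf.keras.optimizers.SGD()
if not isinstance(strategy, tf.distribute.TPUStrategy):
  # Loss scaling only makes sense off-TPU; bfloat16 needs no scaling.
  opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
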
  def decorated(metric_obj, *args, **kwargs):
    """Decorated function with `add_update()`."""
    strategy = distribution_strategy_context.get_strategy()
    # TODO(b/142574744): Remove this check if a better solution is found for
    # declaring keras Metric outside of TPUStrategy and then updating it per
    # replica.

    for weight in metric_obj.weights:
      if (tpu.is_tpu_strategy(strategy) and
          not strategy.extended.variable_created_in_scope(weight)
          and not distribution_strategy_context.in_cross_replica_context()):
        raise ValueError(
            'Trying to run metric.update_state in replica context when '
            'the metric was not created in TPUStrategy scope. '
            'Make sure the keras Metric is created in TPUStrategy scope.')

    with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
      update_op = update_state_fn(*args, **kwargs)
    if update_op is not None:  # update_op will be None in eager execution.
      metric_obj.add_update(update_op)
    return update_op
def _get_distribution_strategy(model):
  """Get the model's distribution strategy."""
  if model._distribution_strategy:
    return model._distribution_strategy
  else:
    # Use the default strategy if no strategy was present at compile.
    # Validate there is no actual strategy scope active at execution
    # time.
    strategy = distribution_strategy_context.get_strategy()
    if distribution_strategy_context.has_strategy():
      raise ValueError(
          'Model was compiled without any active distribution strategy, '
          'but there is an execution-time distribution '
          'strategy scope of (%s). '
          'Try to make sure your code looks similar to the following.\n'
          'with strategy.scope():\n'
          '  model=_create_model()\n'
          '  model.compile(...)\n'
          '  model.fit(...)' % strategy)

    return strategy
    def _skip(self, delta):
        def update_fn(v):
            return self._skip_single_var(v, delta)

        # TODO(b/170515001): Always call strategy.extended.update after calling it
        #   from both replica context and cross-replica context is supported.
        if values_util.is_saving_non_distributed():
            # Assumes replica context with replica_id=0, since we only save the first
            # replica.
            return update_fn(self.state)
        if self._distribution_strategy is not None:
            with ds_context.enter_or_assert_strategy(
                    self._distribution_strategy):
                if ds_context.in_cross_replica_context():
                    # Code that operates on all replicas of a variable cannot be saved
                    # without retracing.
                    values_util.mark_as_unsaveable()
                    # In cross-replica context we need to use strategy.extended.update.
                    return ds_context.get_strategy().extended.update(
                        self.state, update_fn)
        return update_fn(self.state)
def remove_temp_dirpath(dirpath, strategy):
  """Removes the temp path after writing is finished.

  Args:
    dirpath: Original dirpath that would be used without distribution.
    strategy: The tf.distribute strategy object currently used.
  """
  if strategy is None:
    # Infer strategy from `distribution_strategy_context` if not given.
    strategy = distribution_strategy_context.get_strategy()
  if strategy is None:
    # If strategy is still not available, this is not in distributed training.
    # Fallback to no-op.
    return
  # TODO(anjalisridhar): Consider removing the check for multi worker mode since
  # it is redundant when used with the should_checkpoint property.
  if (strategy.extended._in_multi_worker_mode() and  # pylint: disable=protected-access
      not strategy.extended.should_checkpoint):
    # If this worker is not chief and hence should not save file, remove
    # the temporary directory.
    file_io.delete_recursively(_get_temp_dir(dirpath, strategy))
  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and reduction is set, output
        must be a `PerReplica` value.
        The reduce method is also recorded in a dictionary
        `_last_step_outputs_reduce_ops` for later interpreting of the
        outputs as already reduced or not.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
                                                            axis=None)
    else:
      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
                                                            axis=None)
        # Setting this inside the `merge_fn` because all replicas share the same
        # context object, so it's more robust to set it only once (even if all
        # the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))
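
A self-contained sketch of the merge_call pattern used in the replica-context branch above (an illustration with the default strategy standing in for a multi-replica one): each replica passes its local value to a single cross-replica merge_fn, which reduces with the strategy.

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # stand-in for a multi-replica one

def merge_fn(strategy, per_replica_value):
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_value, axis=None)

def replica_fn():
  ctx = tf.distribute.get_replica_context()
  return ctx.merge_call(merge_fn, args=(tf.constant(1.0),))

total = strategy.run(replica_fn)  # 1.0 per replica, summed across replicas
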
  def test_optimizer_with_slot_creation_fn(self, use_tpu):
    def slot_creation_fn(table, slot_names, _):
      slots = {}
      for slot in slot_names:
        slots[slot] = tf_variables.Variable(
            name='{}_{}'.format(table.name, slot),
            initial_value=functools.partial(
                init_ops_v2.Zeros(), shape=table.shape, dtype=dtypes.float32),
            trainable=False)
      return slots
    optimizer = tpu_embedding_v2_utils.Adagrad(
        learning_rate=0.1,
        slot_variable_creation_fn=slot_creation_fn)
    if use_tpu:
      strategy = self._get_strategy()
    else:
      strategy = distribution_strategy_context.get_strategy()
    with strategy.scope():
      mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config=self.feature_config,
          optimizer=optimizer)
      # We aren't going to actually run anything, so the batch_size here does
      # not matter.
      mid_level.build(self.batch_size)
    video_accumulator = mid_level._variables['video']['accumulators']
    user_accumulator = mid_level._variables['user']['accumulators']
    if use_tpu:
      # To check the table contents (ensure that it is zero rather than the
      # normal initial accumulator value specified to in the optimizer config),
      # we need to select the underlying table variable on TPU.
      # We only have one shard on Forge.
      video_accumulator = video_accumulator.variables[0]
      user_accumulator = user_accumulator.variables[0]

    self.assertAllClose(video_accumulator.numpy(),
                        np.zeros((self.table_video.vocabulary_size,
                                  self.table_video.dim)))
    self.assertAllClose(user_accumulator.numpy(),
                        np.zeros((self.table_user.vocabulary_size,
                                  self.table_user.dim)))
    def add_slot(self, var, slot_name, initializer="zeros"):
        """Add a new slot variable for `var`."""
        if slot_name not in self._slot_names:
            self._slot_names.append(slot_name)
        var_key = _var_key(var)
        slot_dict = self._slots.setdefault(var_key, {})
        weight = slot_dict.get(slot_name, None)
        if weight is None:
            if isinstance(initializer,
                          six.string_types) or callable(initializer):
                initializer = initializers.get(initializer)
                initial_value = functools.partial(initializer,
                                                  shape=var.shape,
                                                  dtype=var.dtype)
            else:
                initial_value = initializer
            strategy = distribute_ctx.get_strategy()
            if not strategy.extended.variable_created_in_scope(var):
                raise ValueError(
                    "Trying to create optimizer slot variable under the scope for "
                    "tf.distribute.Strategy ({}), which is different from the scope "
                    "used for the original variable ({}). Make sure the slot "
                    "variables are created under the same strategy scope. This may "
                    "happen if you're restoring from a checkpoint outside the scope"
                    .format(strategy, var))

            with strategy.extended.colocate_vars_with(var):
                weight = tf_variables.Variable(
                    name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
                    dtype=var.dtype,
                    trainable=False,
                    initial_value=initial_value)
            backend.track_variable(weight)
            slot_dict[slot_name] = weight
            self._restore_slot_variable(slot_name=slot_name,
                                        variable=var,
                                        slot_variable=weight)
            self._weights.append(weight)
        return weight
    def _handle_xshards(self, dataset, steps, local_batch_size, shuffle):
        import tensorflow as tf
        data, label = ray_partition_get_data_label(ray.get(dataset),
                                                   allow_tuple=True,
                                                   allow_list=False)

        def dataset_fn(input_context):
            dataset = tf.data.Dataset.from_tensor_slices((data, label))
            options = tf.data.Options()
            options.experimental_distribute.auto_shard_policy = \
                tf.data.experimental.AutoShardPolicy.OFF
            dataset = dataset.with_options(options)
            dataset = dataset.repeat()
            dataset = dataset.take(steps * local_batch_size)
            if shuffle:
                dataset = dataset.shuffle(local_batch_size * min(steps, 10))
            dataset = dataset.batch(local_batch_size)
            return dataset

        from tensorflow.python.distribute import distribution_strategy_context as ds_context
        strategy = ds_context.get_strategy()
        dataset = strategy.experimental_distribute_datasets_from_function(dataset_fn)
        return dataset
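
A minimal sketch of the same distribution step with the non-experimental name that later TF releases use (`distribute_datasets_from_function`, TF 2.4+ assumed); the per-replica batch size comes from the `InputContext`:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def dataset_fn(input_context):
    batch = input_context.get_per_replica_batch_size(64)
    return tf.data.Dataset.range(1024).batch(batch)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
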
  def _infer_steps(self, steps, dataset):
    """Infers steps_per_epoch needed to loop through a dataset."""
    if steps is not None:
      return steps

    adapter_steps = self._adapter.get_size()
    if adapter_steps is not None:
      return adapter_steps

    if (ds_context.get_strategy().extended._in_multi_worker_mode() and  # pylint: disable=protected-access
        (dataset.options().experimental_distribute.auto_shard_policy !=
         distribute_options.AutoShardPolicy.OFF)):
      # If the dataset would be auto-sharded, we should not infer a local
      # steps_per_epoch due to the possible imbalanced sharding between workers.
      return None

    size = cardinality.cardinality(dataset)
    if size == cardinality.INFINITE and steps is None:
      raise ValueError("When passing an infinitely repeating dataset, you "
                       "must specify how many steps to draw.")
    if size >= 0:
      return size
    return None
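
A sketch of the cardinality probe used above, via the public tf.data API:

import tensorflow as tf

ds = tf.data.Dataset.range(10).batch(3)
tf.data.experimental.cardinality(ds)  # -> 4 steps
rep = ds.repeat()
assert (tf.data.experimental.cardinality(rep)
        == tf.data.experimental.INFINITE_CARDINALITY)
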
    def skip(self, delta):
        """Advance the counter of a counter-based RNG.

    Args:
      delta: the amount of advancement. The state of the RNG after
        `skip(n)` will be the same as that after `normal([n])`
        (or any other distribution). The actual increment added to the
        counter is an unspecified implementation detail.

    Returns:
      A `Tensor` of type `int64`.
    """
        def update_fn(v):
            return self._skip_single_var(v, delta)

        # TODO(b/170515001): Always call strategy.extended.update after calling it
        #   from both replica context and cross-replica context is supported.
        if values_util.is_saving_non_distributed():
            # Assumes replica context with replica_id=0, since we only save the first
            # replica.
            return update_fn(self.state)
        if self._distribution_strategy is not None:
            with ds_context.enter_or_assert_strategy(
                    self._distribution_strategy):
                if ds_context.in_cross_replica_context():
                    # Code that operates on all replicas of a variable cannot be saved
                    # without retracing.
                    values_util.mark_as_unsaveable()
                if (ds_context.in_cross_replica_context() or "CentralStorage"
                        in type(self._distribution_strategy).__name__):
                    # In cross-replica context we need to use strategy.extended.update.
                    # In CentralStorageStrategy we also need to use
                    # strategy.extended.update (even for replica context),
                    # because variable updates here must be within merge_call.
                    return ds_context.get_strategy().extended.update(
                        self.state, update_fn)
        return update_fn(self.state)
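
A usage sketch with the public `tf.random.Generator` API, which exposes this skip behavior: skipping n after seeding leaves the counter in the same state as drawing n samples first.

import tensorflow as tf

g1 = tf.random.Generator.from_seed(1)
_ = g1.normal([3])  # advances the counter by one draw of 3
g2 = tf.random.Generator.from_seed(1)
g2.skip(3)          # same state as g1 after normal([3])
tf.debugging.assert_equal(g1.normal([3]), g2.normal([3]))
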
 def testSameScopeNesting(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   scope_a = dist.scope()
   with scope_a:
     self.assertIs(dist, ds_context.get_strategy())
     scope_b = dist.scope()
     with scope_b:
       self.assertIs(dist, ds_context.get_strategy())
       with scope_a:
         self.assertIs(dist, ds_context.get_strategy())
       self.assertIs(dist, ds_context.get_strategy())
     self.assertIs(dist, ds_context.get_strategy())
     dist2 = _TestStrategy()
     scope2 = dist2.scope()
     with self.assertRaisesRegex(
         RuntimeError, "Mixing different tf.distribute.Strategy objects"):
       with scope2:
         pass
   _assert_in_default_state(self)
   with scope_b:
     self.assertIs(dist, ds_context.get_strategy())
   _assert_in_default_state(self)
 def testSameScopeNesting(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   scope_a = dist.scope()
   with scope_a:
     self.assertIs(dist, ds_context.get_strategy())
     scope_b = dist.scope()
     with scope_b:
       self.assertIs(dist, ds_context.get_strategy())
       with scope_a:
         self.assertIs(dist, ds_context.get_strategy())
       self.assertIs(dist, ds_context.get_strategy())
     self.assertIs(dist, ds_context.get_strategy())
     dist2 = _TestStrategy()
     scope2 = dist2.scope()
      with self.assertRaisesRegex(
          RuntimeError,
          "Mixing different tf.distribute.Strategy objects"):
       with scope2:
         pass
   _assert_in_default_state(self)
   with scope_b:
     self.assertIs(dist, ds_context.get_strategy())
   _assert_in_default_state(self)
 def _scale_loss(loss_value):
   if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
     num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
     if num_replicas > 1:
       loss_value *= (1. / num_replicas)
   return loss_value
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Apply gradients to variables.

    This contains most of the synchronization implementation and also wraps the
    apply_gradients() from the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Default to the
        name passed to the Optimizer constructor.

    Returns:
      train_op: The op to dequeue a token so the replicas can exit this batch
      and start the next one. This is executed by each replica.

    Raises:
      ValueError: If the grads_and_vars is empty.
      ValueError: If global step is not provided, the staleness cannot be
        checked.
    """
        if not grads_and_vars:
            raise ValueError("Must supply at least one variable")

        if global_step is None:
            raise ValueError("Global step is required to check staleness")

        self._global_step = global_step
        train_ops = []
        aggregated_grad = []
        var_list = []

        # local_anchor op will be placed on this worker task by default.
        local_anchor = control_flow_ops.no_op()
        # Colocating local_step variable prevents it being placed on the PS.
        distribution_strategy = distribution_strategy_context.get_strategy()
        with distribution_strategy.extended.colocate_vars_with(local_anchor):
            self._local_step = variable_scope.variable(
                initial_value=0,
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                dtype=global_step.dtype.base_dtype,
                name="sync_rep_local_step")

        self.local_step_init_op = state_ops.assign(self._local_step,
                                                   global_step)
        chief_init_ops = [self.local_step_init_op]
        self.ready_for_local_init_op = variables.report_uninitialized_variables(
            variables.global_variables())

        with ops.name_scope(None, self._name):
            for grad, var in grads_and_vars:
                var_list.append(var)
                with ops.device(var.device):
                    # Dense gradients.
                    if grad is None:
                        aggregated_grad.append(None)  # pass-through.
                        continue
                    elif isinstance(grad, ops.Tensor):
                        grad_accum = data_flow_ops.ConditionalAccumulator(
                            grad.dtype,
                            shape=var.get_shape(),
                            shared_name=var.name + "/grad_accum")
                        train_ops.append(
                            grad_accum.apply_grad(grad,
                                                  local_step=self._local_step))
                        aggregated_grad.append(
                            grad_accum.take_grad(self._replicas_to_aggregate))
                    else:
                        if not isinstance(grad, ops.IndexedSlices):
                            raise ValueError("Unknown grad type!")
                        grad_accum = data_flow_ops.SparseConditionalAccumulator(
                            grad.dtype,
                            shape=(),
                            shared_name=var.name + "/grad_accum")
                        train_ops.append(
                            grad_accum.apply_indexed_slices_grad(
                                grad, local_step=self._local_step))
                        aggregated_grad.append(
                            grad_accum.take_indexed_slices_grad(
                                self._replicas_to_aggregate))

                    self._accumulator_list.append((grad_accum, var.device))

            aggregated_grads_and_vars = zip(aggregated_grad, var_list)

            # sync_op will be assigned to the same device as the global step.
            with ops.device(global_step.device), ops.name_scope(""):
                update_op = self._opt.apply_gradients(
                    aggregated_grads_and_vars, global_step)

            # Create token queue.
            with ops.device(global_step.device), ops.name_scope(""):
                sync_token_queue = (data_flow_ops.FIFOQueue(
                    -1,
                    global_step.dtype.base_dtype,
                    shapes=(),
                    name="sync_token_q",
                    shared_name="sync_token_q"))
                self._sync_token_queue = sync_token_queue

                # dummy_queue is passed to the queue runner. Don't use the real queues
                # because the queue runner doesn't automatically reopen it once it
                # closed queues in PS devices.
                dummy_queue = (data_flow_ops.FIFOQueue(
                    1,
                    types_pb2.DT_INT32,
                    shapes=(),
                    name="dummy_queue",
                    shared_name="dummy_queue"))

            with ops.device(global_step.device), ops.name_scope(""):
                # Replicas have to wait until they can get a token from the token queue.
                with ops.control_dependencies(train_ops):
                    token = sync_token_queue.dequeue()
                train_op = state_ops.assign(self._local_step, token)

                with ops.control_dependencies([update_op]):
                    # Sync_op needs to insert tokens to the token queue at the end of the
                    # step so the replicas can fetch them to start the next step.
                    tokens = array_ops.fill([self._tokens_per_step],
                                            global_step)
                    sync_op = sync_token_queue.enqueue_many((tokens, ))

                if self._variable_averages is not None:
                    with ops.control_dependencies([sync_op]), \
                            ops.name_scope(""):
                        sync_op = self._variable_averages.apply(
                            self._variables_to_average)

                self._chief_queue_runner = queue_runner.QueueRunner(
                    dummy_queue, [sync_op])
            for accum, dev in self._accumulator_list:
                with ops.device(dev):
                    chief_init_ops.append(
                        accum.set_global_step(global_step,
                                              name="SetGlobalStep"))
            self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
            self._gradients_applied = True
            return train_op
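
A hedged construction sketch for SyncReplicasOptimizer (TF1-style graph mode assumed; the replica counts are placeholders for a real cluster):

import tensorflow as tf

base_opt = tf.compat.v1.train.GradientDescentOptimizer(0.1)
sync_opt = tf.compat.v1.train.SyncReplicasOptimizer(
    base_opt, replicas_to_aggregate=4, total_num_replicas=4)
# Each worker then calls sync_opt.minimize(loss, global_step=global_step),
# and the chief additionally runs sync_opt.get_chief_queue_runner().
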
  def _restore_checkpoint(self,
                          master,
                          saver=None,
                          checkpoint_dir=None,
                          checkpoint_filename_with_path=None,
                          wait_for_checkpoint=False,
                          max_wait_secs=7200,
                          config=None):
    """Creates a `Session`, and tries to restore a checkpoint.


    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    self._target = master

    # This is required so that we initialize the TPU device before
    # restoring from checkpoint since we'll be placing variables on the device
    # and TPUInitialize wipes out the memory of the device.
    strategy = distribution_strategy_context.get_strategy()
    if strategy and hasattr(strategy.extended,
                            "_experimental_initialize_system"):
      strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access

    sess = session.Session(self._target, graph=self._graph, config=config)
    if checkpoint_dir and checkpoint_filename_with_path:
      raise ValueError("Can not provide both checkpoint_dir and "
                       "checkpoint_filename_with_path.")
    # If either saver or checkpoint_* is not specified, cannot restore. Just
    # return.
    if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
      return sess, False

    if checkpoint_filename_with_path:
      saver.restore(sess, checkpoint_filename_with_path)
      return sess, True

    # Waits up to max_wait_secs for a checkpoint to become available.
    wait_time = 0
    ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
    while not ckpt or not ckpt.model_checkpoint_path:
      if wait_for_checkpoint and wait_time < max_wait_secs:
        logging.info("Waiting for checkpoint to be available.")
        time.sleep(self._recovery_wait_secs)
        wait_time += self._recovery_wait_secs
        ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
      else:
        return sess, False

    # Loads the checkpoint.
    saver.restore(sess, ckpt.model_checkpoint_path)
    saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
    return sess, True
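
For context, a sketch of how this private helper is typically reached through the public `tf.train.SessionManager` API (TF1); the checkpoint directory is illustrative, and `prepare_session` falls back to `init_op` when no checkpoint is found.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

v = tf.Variable(0, name='v')  # a variable so the Saver has something to save
saver = tf.train.Saver()
sm = tf.train.SessionManager()
# prepare_session() calls _restore_checkpoint() internally; with no
# checkpoint present it runs init_op instead.
sess = sm.prepare_session(
    master='',
    init_op=tf.global_variables_initializer(),
    saver=saver,
    checkpoint_dir='/tmp/my_model_ckpts')  # illustrative path
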
Example #46
  def decorated(metric_obj, *args):
    """Decorated function with merge_call."""
    has_strategy = distribution_strategy_context.has_strategy()
    replica_context = distribution_strategy_context.get_replica_context()

    # The purpose of using `merge_call` to call `result()` is to trigger cross
    # replica aggregation of metric state variables (SyncOnReadVariable). After
    # we introduced `variable_sync_on_read_context`, in principle there is no
    # need to use `merge_call` here. However the branch still exists because:
    #
    # 1. Keras V1 training code sometimes assumes `result_t` is the same tensor
    #    across replicas (achieved by `merge_call`). With
    #    `variable_sync_on_read_context` each replica gets their own tensors
    #    residing on replica's device, thus breaking the assumption.
    # 2. Keras compile/fit creates a tf.function (a.k.a. train_function) that returns
    #    the metric values of the first replica. With
    #    `variable_sync_on_read_context` since each replica gets their own
    #    tensors, the metric result tensors on the non-first replicas are not in
    #    the return value of train_function, making the TF graph optimizer
    #    prune the branch that computes and aggregates those metric results.
    #    As a result,
    #    if NCCL is used to do the aggregation, the program will hang because
    #    NCCL ops are only launched on the non-pruned first replica.
    #
    # We condition on strategy.extended._use_merge_call() since we know if it is
    # false, the program uses `jit_compile` to compile replica fn, meaning it is
    # not V1 training (hence #1 is okay), and no pruning will happen as
    # compiled functions are not inlined (hence #2 is okay).

    if (not has_strategy or replica_context is None or
        not distribution_strategy_context.get_strategy(
        ).extended._use_merge_call()):
      with distribution_strategy_context.variable_sync_on_read_context():
        raw_result = result_fn(*args)
        # Results need to be wrapped in a `tf.identity` op to ensure
        # correct execution order.
        if isinstance(raw_result,
                      (ops.Tensor, variables_module.Variable, float, int)):
          result_t = array_ops.identity(raw_result)
        elif isinstance(raw_result, dict):
          result_t = {
              key: array_ops.identity(value)
              for key, value in raw_result.items()
          }
        else:
          try:
            result_t = array_ops.identity(raw_result)
          except (ValueError, TypeError):
            raise RuntimeError(
                'The output of `metric.result()` can only be a single '
                'Tensor/Variable, or a dict of Tensors/Variables. '
                'For metric %s, got result %s.' % (metric_obj.name, raw_result))
    else:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
      # with distribution object as the first parameter. We create a wrapper
      # here so that the result function need not have that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *args):
        # We will get a `PerReplica` merge function. Take the first one, as
        # all are identical copies of the function that we had passed below.
        result = distribution.experimental_local_results(merge_fn)[0](*args)

        # Wrapping result in identity so that control dependency between
        # update_op from `update_state` and result works in case result returns
        # a tensor.
        return array_ops.identity(result)

      # Wrapping result in merge_call. merge_call is used when we want to leave
      # replica mode and compute a value in cross replica mode.
      result_t = replica_context.merge_call(
          merge_fn_wrapper, args=(result_fn,) + args)

    # We are saving the result op here to be used in train/test execution
    # functions. This basically gives the result op that was generated with a
    # control dep to the updates for these workflows.
    metric_obj._call_result = result_t
    return result_t
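
A minimal sketch of the decorator pattern in play; `result_wrapper` here is a hypothetical stand-in for the outer closure that supplies `result_fn` (the real wrapper lives in Keras' metrics utilities), and the body is reduced to the simplest no-strategy path.

import functools

def result_wrapper(result_fn):
  """Hypothetical outer wrapper: closes over result_fn and returns a
  `decorated` function shaped like the one above (body elided)."""
  @functools.wraps(result_fn)
  def decorated(metric_obj, *args):
    # In the simplest (no-strategy) path this reduces to passing the call
    # through to the metric's raw result function.
    return result_fn(*args)
  return decorated
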
  def call(self, inputs, training=None):
    if training is None:
      training = K.learning_phase()

    if self.virtual_batch_size is not None:
      # Virtual batches (aka ghost batches) can be simulated by reshaping the
      # Tensor and reusing the existing batch norm implementation
      original_shape = [-1] + inputs.shape.as_list()[1:]
      expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

      # Will cause errors if virtual_batch_size does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.virtual_batch_size is not None:
        # Currently never reaches here since fused_batch_norm does not support
        # virtual batching
        outputs = undo_virtual_batching(outputs)
      return outputs

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
      del reduction_axes[1]     # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value
    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and
          reduction_axes != list(range(ndims - 1))):
        return array_ops.reshape(v, broadcast_shape)
      return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
      if then_scale is not None:
        scale *= then_scale
        offset *= then_scale
      if then_offset is not None:
        offset += then_offset
      return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:
      if self.adjustment:
        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
        # Adjust only during training.
        adj_scale = tf_utils.smart_cond(training,
                                        lambda: adj_scale,
                                        lambda: array_ops.ones_like(adj_scale))
        adj_bias = tf_utils.smart_cond(training,
                                       lambda: adj_bias,
                                       lambda: array_ops.zeros_like(adj_bias))
        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

      # Some of the computations here are unnecessary when training==False at
      # runtime but `training` is not a constant; accepting that keeps the
      # code simpler.
      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
      mean, variance = self._moments(
          math_ops.cast(inputs, self._param_dtype),
          reduction_axes,
          keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = tf_utils.smart_cond(training,
                                 lambda: mean,
                                 lambda: moving_mean)
      variance = tf_utils.smart_cond(training,
                                     lambda: variance,
                                     lambda: moving_variance)

      if self.virtual_batch_size is not None:
        # This isn't strictly correct since in ghost batch norm, you are
        # supposed to sequentially update the moving_mean and moving_variance
        # with each sub-batch. However, since the moving statistics are only
        # used during evaluation, it is more efficient to just update in one
        # step and should not make a significant difference in the result.
        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
      else:
        new_mean, new_variance = mean, variance

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            new_mean, new_variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)

      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()

        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          return strategy.extended.update(
              var, self._assign_moving_average, (value, self.momentum),
              group=False)
        # We need to unwrap the moving_mean or moving_variance in the case of
        # training being false to match the output of true_fn and false_fn
        # in the smart cond.
        def mean_update():
          true_branch = lambda: _do_update(self.moving_mean, new_mean)
          false_branch = lambda: strategy.unwrap(self.moving_mean)
          return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
          return tf_utils.smart_cond(
              training, lambda: _do_update(self.moving_variance, new_variance),
              lambda: strategy.unwrap(self.moving_variance))
      else:
        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          return self._assign_moving_average(var, value, self.momentum)

        def mean_update():
          true_branch = lambda: _do_update(self.moving_mean, new_mean)
          false_branch = lambda: self.moving_mean
          return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
          true_branch = lambda: _do_update(self.moving_variance, new_variance)
          false_branch = lambda: self.moving_variance
          return tf_utils.smart_cond(training, true_branch, false_branch)

      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
      offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
      scale = math_ops.cast(scale, inputs.dtype)
    # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
    # math in float16 hurts validation accuracy of popular models like resnet.
    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     offset,
                                     scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
      outputs = undo_virtual_batching(outputs)
    return outputs
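
A standalone NumPy sketch of the virtual-batching reshape above (ignoring gamma/beta and the moving averages): the batch is viewed as (virtual_batch_size, batch // virtual_batch_size, features), statistics are reduced over the first axis only, and the reshape is undone at the end. Note that with a row-major reshape the examples sharing statistics are strided across the original batch.

import numpy as np

x = np.random.randn(6, 4).astype(np.float32)   # (batch=6, features=4)
vbs = 2                                        # virtual_batch_size
expanded = x.reshape(vbs, -1, 4)               # (vbs, 6 // vbs, features)
# Reduce over the vbs axis only; the kept axis 1 indexes the ghost batches.
mean = expanded.mean(axis=0, keepdims=True)
var = expanded.var(axis=0, keepdims=True)
normalized = (expanded - mean) / np.sqrt(var + 1e-3)
outputs = normalized.reshape(-1, 4)            # undo_virtual_batching
assert outputs.shape == x.shape
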
  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if not input_shape.ndims:
      raise ValueError('Input has undefined rank:', input_shape)
    ndims = len(input_shape)

    # Convert axis to list and resolve negatives
    if isinstance(self.axis, int):
      self.axis = [self.axis]

    for idx, x in enumerate(self.axis):
      if x < 0:
        self.axis[idx] = ndims + x

    # Validate axes
    for x in self.axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %d' % x)
    if len(self.axis) != len(set(self.axis)):
      raise ValueError('Duplicate axis: %s' % self.axis)

    if self.virtual_batch_size is not None:
      if self.virtual_batch_size <= 0:
        raise ValueError('virtual_batch_size must be a positive integer that '
                         'divides the true batch size of the input Tensor')
      # If using virtual batches, the first dimension must be the batch
      # dimension and cannot be the batch norm axis
      if 0 in self.axis:
        raise ValueError('When using virtual_batch_size, the batch dimension '
                         'must be 0 and thus axis cannot include 0')
      if self.adjustment is not None:
        raise ValueError('When using virtual_batch_size, adjustment cannot '
                         'be specified')

    if self.fused in (None, True):
      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
      # output back to its original shape accordingly.
      if self._USE_V2_BEHAVIOR:
        if self.fused is None:
          self.fused = (ndims == 4)
        elif self.fused and ndims != 4:
          raise ValueError('Batch normalization layers with fused=True only '
                           'support 4D input tensors.')
      else:
        assert self.fused is not None
        self.fused = (ndims == 4 and self._fused_can_be_used())
      # TODO(chrisying): fused batch norm is currently not supported for
      # multi-axis batch norm and by extension virtual batches. In some cases,
      # it might be possible to use fused batch norm but would require reshaping
      # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
      # particularly tricky. A compromise might be to just support the most
      # common use case (turning 5D w/ virtual batch to NCHW)

    if self.fused:
      if self.axis == [1]:
        self._data_format = 'NCHW'
      elif self.axis == [3]:
        self._data_format = 'NHWC'
      else:
        raise ValueError('Unsupported axis, fused batch norm only supports '
                         'axis == [1] or axis == [3]')

    axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
    for x in axis_to_dim:
      if axis_to_dim[x] is None:
        raise ValueError('Input has undefined `axis` dimension. Input shape: ',
                         input_shape)
    self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

    if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
      # Single axis batch norm (most common/default use-case)
      param_shape = (list(axis_to_dim.values())[0],)
    else:
      # Parameter shape is the original shape but with 1 in all non-axis dims
      param_shape = [axis_to_dim[i] if i in axis_to_dim
                     else 1 for i in range(ndims)]
      if self.virtual_batch_size is not None:
        # When using virtual batches, add an extra dim at index 1
        param_shape.insert(1, 1)
        for idx, x in enumerate(self.axis):
          self.axis[idx] = x + 1      # Account for added dimension

    if self.scale:
      self.gamma = self.add_weight(
          name='gamma',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.gamma = None
      if self.fused:
        self._gamma_const = K.constant(
            1.0, dtype=self._param_dtype, shape=param_shape)

    if self.center:
      self.beta = self.add_weight(
          name='beta',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.beta = None
      if self.fused:
        self._beta_const = K.constant(
            0.0, dtype=self._param_dtype, shape=param_shape)

    try:
      # Disable variable partitioning when creating the moving mean and variance
      if hasattr(self, '_scope') and self._scope:
        partitioner = self._scope.partitioner
        self._scope.set_partitioner(None)
      else:
        partitioner = None
      self.moving_mean = self.add_weight(
          name='moving_mean',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.moving_mean_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN,
          experimental_autocast=False)

      self.moving_variance = self.add_weight(
          name='moving_variance',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.moving_variance_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN,
          experimental_autocast=False)

      if self.renorm:
        # Create variables to maintain the moving mean and standard deviation.
        # These are used in training and thus are different from the moving
        # averages above. The renorm variables are colocated with moving_mean
        # and moving_variance.
        # NOTE: below, the outer `with device` block causes the current device
        # stack to be cleared. The nested ones use a `lambda` to set the desired
        # device and ignore any devices that may be set by the custom getter.
        def _renorm_variable(name, shape):
          """Create a renorm variable."""
          var = self.add_weight(
              name=name,
              shape=shape,
              dtype=self._param_dtype,
              initializer=init_ops.zeros_initializer(),
              synchronization=tf_variables.VariableSynchronization.ON_READ,
              trainable=False,
              aggregation=tf_variables.VariableAggregation.MEAN,
              experimental_autocast=False)
          return var

        with distribution_strategy_context.get_strategy(
        ).extended.colocate_vars_with(self.moving_mean):
          self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
          self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
        # We initialize renorm_stddev to 0, and maintain the (0-initialized)
        # renorm_stddev_weight. This allows us to (1) mix the average
        # stddev with the minibatch stddev early in training, and (2) compute
        # the unbiased average stddev by dividing renorm_stddev by the weight.
        with distribution_strategy_context.get_strategy(
        ).extended.colocate_vars_with(self.moving_variance):
          self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
          self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
                                                       ())
    finally:
      if partitioner:
        self._scope.set_partitioner(partitioner)
    self.built = True
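
A small self-contained illustration of the `param_shape` computation above for the multi-axis case; the shapes here are made up for illustration.

ndims = 4
axis = [1, 3]                  # normalize over the dims at indices 1 and 3
axis_to_dim = {1: 8, 3: 16}    # e.g. an input of shape (None, 8, 32, 16)
param_shape = [axis_to_dim[i] if i in axis_to_dim else 1
               for i in range(ndims)]
assert param_shape == [1, 8, 1, 16]  # broadcastable against the input
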
Example #49
  def __init__(self,
               input_shape=None,
               batch_size=None,
               dtype=None,
               input_tensor=None,
               sparse=False,
               name=None,
               **kwargs):
    strategy = distribution_strategy_context.get_strategy()
    if (strategy and batch_size is not None and
        distributed_training_utils.global_batch_size_supported(strategy)):
      if batch_size % strategy.num_replicas_in_sync != 0:
        raise ValueError('The `batch_size` argument value {} is not '
                         'divisible by the number of replicas {}.'.format(
                             batch_size, strategy.num_replicas_in_sync))
      batch_size = batch_size // strategy.num_replicas_in_sync

    if 'batch_input_shape' in kwargs:
      batch_input_shape = kwargs.pop('batch_input_shape')
      if input_shape and batch_input_shape:
        raise ValueError('Only provide the input_shape OR '
                         'batch_input_shape argument to '
                         'InputLayer, not both at the same time.')
      batch_size = batch_input_shape[0]
      input_shape = batch_input_shape[1:]
    if kwargs:
      raise ValueError('Unrecognized keyword arguments:', kwargs.keys())

    if not name:
      prefix = 'input'
      name = prefix + '_' + str(backend.get_uid(prefix))

    if not dtype:
      if input_tensor is None:
        dtype = backend.floatx()
      else:
        dtype = backend.dtype(input_tensor)
    elif input_tensor is not None and input_tensor.dtype != dtype:
      raise ValueError('`input_tensor.dtype` differs from `dtype`: %s vs. %s' %
                       (input_tensor.dtype, dtype))
    super(InputLayer, self).__init__(dtype=dtype, name=name)
    self.built = True
    self.sparse = sparse
    self.batch_size = batch_size
    self.supports_masking = True

    if isinstance(input_shape, tensor_shape.TensorShape):
      input_shape = tuple(input_shape.as_list())
    elif isinstance(input_shape, int):
      input_shape = (input_shape,)

    if input_tensor is None:
      if input_shape is not None:
        batch_input_shape = (batch_size,) + tuple(input_shape)
      else:
        batch_input_shape = None
      graph = backend.get_graph()
      with graph.as_default():
        # In graph mode, create a graph placeholder to call the layer on.
        if sparse:
          input_tensor = backend.placeholder(
              shape=batch_input_shape,
              dtype=dtype,
              name=self.name,
              sparse=True)
        else:
          input_tensor = backend.placeholder(
              shape=batch_input_shape,
              dtype=dtype,
              name=self.name)

      self.is_placeholder = True
      self._batch_input_shape = batch_input_shape
    else:
      if not tf_utils.is_symbolic_tensor(input_tensor):
        raise ValueError('You should not pass an EagerTensor to `Input`. '
                         'For example, instead of creating an '
                         'InputLayer, you should instantiate your model and '
                         'directly call it on your input.')
      self.is_placeholder = False
      self._batch_input_shape = tuple(input_tensor.shape.as_list())

    # Create an input node to add to self.outbound_node
    # and set output_tensors' _keras_history.
    input_tensor._keras_history = (self, 0, 0)  # pylint: disable=protected-access
    input_tensor._keras_mask = None
    base_layer.Node(
        self,
        inbound_layers=[],
        node_indices=[],
        tensor_indices=[],
        input_tensors=[input_tensor],
        output_tensors=[input_tensor])
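
A tiny sketch of the global-to-per-replica batch size arithmetic at the top of this constructor, with illustrative numbers: a global batch of 64 across 8 replicas gives each replica a batch of 8, and a non-divisible combination is rejected.

global_batch_size, num_replicas_in_sync = 64, 8
if global_batch_size % num_replicas_in_sync != 0:
  raise ValueError('batch_size must be divisible by the number of replicas')
per_replica_batch_size = global_batch_size // num_replicas_in_sync
assert per_replica_batch_size == 8  # each replica sees 1/8 of the batch
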
Example #51
    def _support_zero_size_input(self):
        return distribution_strategy_context.has_strategy() and getattr(
            distribution_strategy_context.get_strategy().extended,
            'experimental_enable_get_next_as_optional', False)
    def build(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape)
        if not input_shape.ndims:
            raise ValueError('Input has undefined rank:', input_shape)
        ndims = len(input_shape)

        # Convert axis to list and resolve negatives
        if isinstance(self.axis, int):
            self.axis = [self.axis]

        for idx, x in enumerate(self.axis):
            if x < 0:
                self.axis[idx] = ndims + x

        # Validate axes
        for x in self.axis:
            if x < 0 or x >= ndims:
                raise ValueError('Invalid axis: %d' % x)
        if len(self.axis) != len(set(self.axis)):
            raise ValueError('Duplicate axis: %s' % self.axis)

        if self.virtual_batch_size is not None:
            if self.virtual_batch_size <= 0:
                raise ValueError(
                    'virtual_batch_size must be a positive integer that '
                    'divides the true batch size of the input Tensor')
            # If using virtual batches, the first dimension must be the batch
            # dimension and cannot be the batch norm axis
            if 0 in self.axis:
                raise ValueError(
                    'When using virtual_batch_size, the batch dimension '
                    'must be 0 and thus axis cannot include 0')
            if self.adjustment is not None:
                raise ValueError(
                    'When using virtual_batch_size, adjustment cannot '
                    'be specified')

        if self.fused in (None, True):
            # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
            # output back to its original shape accordingly.
            if self._USE_V2_BEHAVIOR:
                if self.fused is None:
                    self.fused = (ndims == 4)
                elif self.fused and ndims != 4:
                    raise ValueError(
                        'Batch normalization layers with fused=True only '
                        'support 4D input tensors.')
            else:
                assert self.fused is not None
                self.fused = (ndims == 4 and self._fused_can_be_used())
            # TODO(chrisying): fused batch norm is currently not supported for
            # multi-axis batch norm and by extension virtual batches. In some cases,
            # it might be possible to use fused batch norm but would require reshaping
            # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
            # particularly tricky. A compromise might be to just support the most
            # common use case (turning 5D w/ virtual batch to NCHW)

        if self.fused:
            if self.axis == [1]:
                self._data_format = 'NCHW'
            elif self.axis == [3]:
                self._data_format = 'NHWC'
            else:
                raise ValueError(
                    'Unsupported axis, fused batch norm only supports '
                    'axis == [1] or axis == [3]')

        # Raise parameters of fp16 batch norm to fp32
        if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
            param_dtype = dtypes.float32
        else:
            param_dtype = self.dtype or dtypes.float32

        axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
        for x in axis_to_dim:
            if axis_to_dim[x] is None:
                raise ValueError(
                    'Input has undefined `axis` dimension. Input shape: ',
                    input_shape)
        self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

        if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
            # Single axis batch norm (most common/default use-case)
            param_shape = (list(axis_to_dim.values())[0], )
        else:
            # Parameter shape is the original shape but with 1 in all non-axis dims
            param_shape = [
                axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
            ]
            if self.virtual_batch_size is not None:
                # When using virtual batches, add an extra dim at index 1
                param_shape.insert(1, 1)
                for idx, x in enumerate(self.axis):
                    self.axis[idx] = x + 1  # Account for added dimension

        if self.scale:
            self.gamma = self.add_weight(name='gamma',
                                         shape=param_shape,
                                         dtype=param_dtype,
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint,
                                         trainable=True)
        else:
            self.gamma = None
            if self.fused:
                self._gamma_const = K.constant(1.0,
                                               dtype=param_dtype,
                                               shape=param_shape)

        if self.center:
            self.beta = self.add_weight(name='beta',
                                        shape=param_shape,
                                        dtype=param_dtype,
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint,
                                        trainable=True)
        else:
            self.beta = None
            if self.fused:
                self._beta_const = K.constant(0.0,
                                              dtype=param_dtype,
                                              shape=param_shape)

        try:
            # Disable variable partitioning when creating the moving mean and variance
            if hasattr(self, '_scope') and self._scope:
                partitioner = self._scope.partitioner
                self._scope.set_partitioner(None)
            else:
                partitioner = None
            self.moving_mean = self.add_weight(
                name='moving_mean',
                shape=param_shape,
                dtype=param_dtype,
                initializer=self.moving_mean_initializer,
                synchronization=tf_variables.VariableSynchronization.ON_READ,
                trainable=False,
                aggregation=tf_variables.VariableAggregation.MEAN)

            self.moving_variance = self.add_weight(
                name='moving_variance',
                shape=param_shape,
                dtype=param_dtype,
                initializer=self.moving_variance_initializer,
                synchronization=tf_variables.VariableSynchronization.ON_READ,
                trainable=False,
                aggregation=tf_variables.VariableAggregation.MEAN)

            if self.renorm:
                # Create variables to maintain the moving mean and standard deviation.
                # These are used in training and thus are different from the moving
                # averages above. The renorm variables are colocated with moving_mean
                # and moving_variance.
                # NOTE: below, the outer `with device` block causes the current device
                # stack to be cleared. The nested ones use a `lambda` to set the desired
                # device and ignore any devices that may be set by the custom getter.
                def _renorm_variable(name, shape):
                    var = self.add_weight(
                        name=name,
                        shape=shape,
                        dtype=param_dtype,
                        initializer=init_ops.zeros_initializer(),
                        synchronization=tf_variables.VariableSynchronization.
                        ON_READ,
                        trainable=False,
                        aggregation=tf_variables.VariableAggregation.MEAN)
                    return var

                with distribution_strategy_context.get_strategy(
                ).extended.colocate_vars_with(self.moving_mean):
                    self.renorm_mean = _renorm_variable(
                        'renorm_mean', param_shape)
                    self.renorm_mean_weight = _renorm_variable(
                        'renorm_mean_weight', ())
                # We initialize renorm_stddev to 0, and maintain the (0-initialized)
                # renorm_stddev_weight. This allows us to (1) mix the average
                # stddev with the minibatch stddev early in training, and (2) compute
                # the unbiased average stddev by dividing renorm_stddev by the weight.
                with distribution_strategy_context.get_strategy(
                ).extended.colocate_vars_with(self.moving_variance):
                    self.renorm_stddev = _renorm_variable(
                        'renorm_stddev', param_shape)
                    self.renorm_stddev_weight = _renorm_variable(
                        'renorm_stddev_weight', ())
        finally:
            if partitioner:
                self._scope.set_partitioner(partitioner)
        self.built = True
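
A pure-Python restatement of the parameter-dtype promotion near the top of this `build`; dtype names are strings here for illustration only.

def resolve_param_dtype(layer_dtype):
  # Keep batch-norm parameters and statistics in float32 when the layer
  # itself runs in a half-precision dtype, mirroring the check above.
  if layer_dtype in ('float16', 'bfloat16'):
    return 'float32'
  return layer_dtype or 'float32'

assert resolve_param_dtype('float16') == 'float32'
assert resolve_param_dtype(None) == 'float32'
assert resolve_param_dtype('float64') == 'float64'
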
  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This contains most of the synchronization implementation and also wraps the
    apply_gradients() from the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Default to the
        name passed to the Optimizer constructor.

    Returns:
      train_op: The op to dequeue a token so the replicas can exit this batch
      and start the next one. This is executed by each replica.

    Raises:
      ValueError: If the grads_and_vars is empty.
      ValueError: If `global_step` is not provided, since staleness cannot
        be checked without it.
    """
    if not grads_and_vars:
      raise ValueError("Must supply at least one variable")

    if global_step is None:
      raise ValueError("Global step is required to check staleness")

    self._global_step = global_step
    train_ops = []
    aggregated_grad = []
    var_list = []

    # local_anchor op will be placed on this worker task by default.
    local_anchor = control_flow_ops.no_op()
    # Colocating local_step variable prevents it being placed on the PS.
    distribution_strategy = distribution_strategy_context.get_strategy()
    with distribution_strategy.extended.colocate_vars_with(local_anchor):
      self._local_step = variable_scope.variable(
          initial_value=0,
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          dtype=global_step.dtype.base_dtype,
          name="sync_rep_local_step")

    self.local_step_init_op = state_ops.assign(self._local_step, global_step)
    chief_init_ops = [self.local_step_init_op]
    self.ready_for_local_init_op = variables.report_uninitialized_variables(
        variables.global_variables())

    with ops.name_scope(None, self._name):
      for grad, var in grads_and_vars:
        var_list.append(var)
        with ops.device(var.device):
          # Dense gradients.
          if grad is None:
            aggregated_grad.append(None)  # pass-through.
            continue
          elif isinstance(grad, ops.Tensor):
            grad_accum = data_flow_ops.ConditionalAccumulator(
                grad.dtype,
                shape=var.get_shape(),
                shared_name=var.name + "/grad_accum")
            train_ops.append(grad_accum.apply_grad(
                grad, local_step=self._local_step))
            aggregated_grad.append(grad_accum.take_grad(
                self._replicas_to_aggregate))
          else:
            if not isinstance(grad, ops.IndexedSlices):
              raise ValueError("Unknown grad type!")
            grad_accum = data_flow_ops.SparseConditionalAccumulator(
                grad.dtype, shape=(), shared_name=var.name + "/grad_accum")
            train_ops.append(grad_accum.apply_indexed_slices_grad(
                grad, local_step=self._local_step))
            aggregated_grad.append(grad_accum.take_indexed_slices_grad(
                self._replicas_to_aggregate))

          self._accumulator_list.append((grad_accum, var.device))

      aggregated_grads_and_vars = zip(aggregated_grad, var_list)

      # sync_op will be assigned to the same device as the global step.
      with ops.device(global_step.device), ops.name_scope(""):
        update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
                                              global_step)

      # Create token queue.
      with ops.device(global_step.device), ops.name_scope(""):
        sync_token_queue = (
            data_flow_ops.FIFOQueue(-1,
                                    global_step.dtype.base_dtype,
                                    shapes=(),
                                    name="sync_token_q",
                                    shared_name="sync_token_q"))
        self._sync_token_queue = sync_token_queue

        # dummy_queue is passed to the queue runner. Don't use the real queues,
        # because the queue runner doesn't automatically reopen queues on PS
        # devices once they have been closed.
        dummy_queue = (
            data_flow_ops.FIFOQueue(1,
                                    types_pb2.DT_INT32,
                                    shapes=(),
                                    name="dummy_queue",
                                    shared_name="dummy_queue"))

      with ops.device(global_step.device), ops.name_scope(""):
        # Replicas have to wait until they can get a token from the token queue.
        with ops.control_dependencies(train_ops):
          token = sync_token_queue.dequeue()
        train_op = state_ops.assign(self._local_step, token)

        with ops.control_dependencies([update_op]):
          # sync_op needs to insert tokens into the token queue at the end of
          # the step so the replicas can fetch them to start the next step.
          tokens = array_ops.fill([self._tokens_per_step], global_step)
          sync_op = sync_token_queue.enqueue_many((tokens,))

        if self._variable_averages is not None:
          with ops.control_dependencies([sync_op]), ops.name_scope(""):
            sync_op = self._variable_averages.apply(
                self._variables_to_average)

        self._chief_queue_runner = queue_runner.QueueRunner(dummy_queue,
                                                            [sync_op])
      for accum, dev in self._accumulator_list:
        with ops.device(dev):
          chief_init_ops.append(
              accum.set_global_step(
                  global_step, name="SetGlobalStep"))
      self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
      self._gradients_applied = True
      return train_op
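
A standalone sketch of the gradient-accumulator pattern used above, with the public TF1 `tf.ConditionalAccumulator`: two scalar "replica" gradients are applied, and `take_grad` returns their average once `num_required` fresh gradients have arrived.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

accum = tf.ConditionalAccumulator(
    tf.float32, shape=(), shared_name='sketch_grad_accum')
apply_ops = [accum.apply_grad(tf.constant(g), local_step=0)
             for g in (1.0, 3.0)]
avg_grad = accum.take_grad(num_required=2)  # blocks until 2 grads arrive
with tf.Session() as sess:
  sess.run(apply_ops)
  print(sess.run(avg_grad))  # 2.0, the mean of the applied gradients
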
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()

        in_eager_mode = context.executing_eagerly()
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # Will cause errors if virtual_batch_size does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(inputs, training=training)
            if self.virtual_batch_size is not None:
                # Currently never reaches here since fused_batch_norm does not support
                # virtual batching
                outputs = undo_virtual_batching(outputs)
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        if training_value is not False:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            # Some of the computations here are unnecessary when
            # training==False at runtime but `training` is not a constant;
            # accepting that keeps the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1
            mean, variance = self._moments(inputs,
                                           reduction_axes,
                                           keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(training, lambda: mean,
                                       lambda: moving_mean)
            variance = tf_utils.smart_cond(training, lambda: variance,
                                           lambda: moving_variance)

            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
                new_variance = math_ops.reduce_mean(variance,
                                                    axis=1,
                                                    keepdims=True)
            else:
                new_mean, new_variance = mean, variance

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    new_mean, new_variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)

            if distribution_strategy_context.in_cross_replica_context():
                strategy = distribution_strategy_context.get_strategy()

                def _do_update(var, value):
                    """Compute the updates for mean and variance."""
                    if in_eager_mode and not self.trainable:
                        return
                    return strategy.extended.update(
                        var,
                        self._assign_moving_average, (value, self.momentum),
                        group=False)

                # We need to unwrap the moving_mean or moving_variance in the case of
                # training being false to match the output of true_fn and false_fn
                # in the smart cond.
                mean_update = tf_utils.smart_cond(
                    training, lambda: _do_update(self.moving_mean, new_mean),
                    lambda: strategy.unwrap(self.moving_mean))
                variance_update = tf_utils.smart_cond(
                    training,
                    lambda: _do_update(self.moving_variance, new_variance),
                    lambda: strategy.unwrap(self.moving_variance))
            else:

                def _do_update(var, value):
                    """Compute the updates for mean and variance."""
                    if in_eager_mode and not self.trainable:
                        return
                    return self._assign_moving_average(var, value,
                                                       self.momentum)

                mean_update = tf_utils.smart_cond(
                    training, lambda: _do_update(self.moving_mean, new_mean),
                    lambda: self.moving_mean)
                variance_update = tf_utils.smart_cond(
                    training,
                    lambda: _do_update(self.moving_variance, new_variance),
                    lambda: self.moving_variance)
            if not context.executing_eagerly():
                self.add_update(mean_update, inputs=True)
                self.add_update(variance_update, inputs=True)

        else:
            mean, variance = self.moving_mean, self.moving_variance

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        return outputs
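
For reference, the moving-average update that `_assign_moving_average` is expected to perform: a standard exponential decay with `momentum` as the decay rate. This pure-Python restatement is an assumption for illustration, not the layer's code.

def moving_average_update(moving, value, momentum):
  # moving <- moving * momentum + value * (1 - momentum)
  return moving * momentum + value * (1.0 - momentum)

assert moving_average_update(0.0, 10.0, 0.5) == 5.0
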
def compute_weighted_loss(
    losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
  """Computes the weighted loss.

  Args:
    losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `losses`, and must be broadcastable to `losses` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
    scope: the scope for the operations performed in computing the loss.
    loss_collection: the loss will be added to these collections.
    reduction: Type of reduction to apply to loss.

  Returns:
    Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
    `NONE`, this has the same shape as `losses`; otherwise, it is scalar.

  Raises:
    ValueError: If `weights` is `None` or the shape is not compatible with
      `losses`, or if the number of dimensions (rank) of either `losses` or
      `weights` is missing.

  Note:
    When calculating the gradient of a weighted loss contributions from
    both `losses` and `weights` are considered. If your `weights` depend
    on some model parameters but you do not want this to affect the loss
    gradient, you need to apply `tf.stop_gradient` to `weights` before
    passing them to `compute_weighted_loss`.

  @compatibility(eager)
  The `loss_collection` argument is ignored when executing eagerly. Consider
  holding on to the return value or collecting losses via a `tf.keras.Model`.
  @end_compatibility
  """
  Reduction.validate(reduction)
  with ops.name_scope(scope, "weighted_loss", (losses, weights)):
    with ops.control_dependencies((
        weights_broadcast_ops.assert_broadcastable(weights, losses),)):
      losses = ops.convert_to_tensor(losses)
      input_dtype = losses.dtype
      losses = math_ops.cast(losses, dtype=dtypes.float32)
      weights = math_ops.cast(weights, dtype=dtypes.float32)
      weighted_losses = math_ops.multiply(losses, weights)
      if reduction == Reduction.NONE:
        loss = weighted_losses
      else:
        loss = math_ops.reduce_sum(weighted_losses)
        num_replicas = (  # Used to convert from local to global batch size.
            distribution_strategy_context.get_strategy().num_replicas_in_sync)
        if reduction == Reduction.MEAN:
          denom = (num_replicas *
                   math_ops.reduce_sum(array_ops.ones_like(losses) * weights))
          loss = _safe_mean(loss, denom)
        elif (reduction == Reduction.SUM_BY_NONZERO_WEIGHTS or
              reduction == Reduction.SUM_OVER_NONZERO_WEIGHTS):
          loss = _safe_mean(loss, num_replicas * _num_present(losses, weights))
        elif reduction == Reduction.SUM_OVER_BATCH_SIZE:
          loss = _safe_mean(loss, num_replicas * _num_elements(losses))

      # Convert the result back to the input type.
      loss = math_ops.cast(loss, input_dtype)
      util.add_loss(loss, loss_collection)
      return loss
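
A worked numeric example of the default `SUM_BY_NONZERO_WEIGHTS` reduction above, on a single replica (`num_replicas == 1`): the weighted sum is divided by the count of nonzero weights, not by the batch size.

import numpy as np

losses = np.array([2.0, 4.0, 6.0], dtype=np.float32)
weights = np.array([1.0, 0.0, 1.0], dtype=np.float32)
weighted_sum = float((losses * weights).sum())  # 2.0 + 0.0 + 6.0 = 8.0
num_present = float((weights != 0.0).sum())     # two nonzero weights
loss = weighted_sum / num_present               # 8.0 / 2 = 4.0
assert loss == 4.0
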