def _assert_in_default_state(t):
  t.assertIs(ds_context._get_default_replica_context(),
             ds_context.get_replica_context())
  t.assertIs(None, ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())
  t.assertIs(ds_context._get_default_strategy(), ds_context.get_strategy())
  t.assertFalse(ds_context.has_strategy())
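A minimal usage sketch (not part of the original listing): the helper above is meant to be called from a test case while no strategy scope is active. The module path for `ds_context` is assumed to be the TensorFlow-internal `tensorflow.python.distribute.distribution_strategy_context`, which may move between versions.

import unittest
from tensorflow.python.distribute import distribution_strategy_context as ds_context

class DefaultStateTest(unittest.TestCase):

  def test_default_state(self):
    # Outside any `strategy.scope()`, only the default (no-op) strategy is
    # active, so every assertion in the helper should hold.
    _assert_in_default_state(self)

if __name__ == "__main__":
  unittest.main()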
Example #2
  def get_updates(self, loss, params):
    if distribution_strategy_context.has_strategy():
      self.updates = []

      if not params:
        # After the model vars have been created, the second call to
        # get_updates receives params as an empty list. This ensures that we
        # call compute_gradients with params=None.
        grads = self.optimizer.compute_gradients(loss)
      else:
        grads = self.optimizer.compute_gradients(loss, params)
      global_step = training_util.get_global_step()
      opt_update = self.optimizer.apply_gradients(grads, global_step)
    else:
      if not params:
        self.updates = [state_ops.assign_add(self.iterations, 1)]
        return self.updates

      # Updates list starts out empty because the iterations variable is
      # incremented in optimizer.apply_gradients()
      self.updates = []
      grads = self.optimizer.compute_gradients(loss, params)
      opt_update = self.optimizer.apply_gradients(
          grads, global_step=self.iterations)

    self.updates.append(opt_update)
    return self.updates
 def merge_fn(dist, s):
   self.assertIs(ds_context._get_default_strategy(), dist)
   self.assertIs(None, ds_context.get_replica_context())
   self.assertIs(dist, ds_context.get_cross_replica_context())
   self.assertTrue(ds_context.in_cross_replica_context())
   self.assertIs(dist, ds_context.get_strategy())
   self.assertFalse(ds_context.has_strategy())
   return "foo_" + s
Example #4
 def _clip_gradients(self, grads):
     """Clip gradients according to the clipnorm and clipvalue attributes."""
     if self.clipnorm is not None:
         if distribute_ctx.has_strategy():
             raise ValueError(
                 "Gradient clipping in the optimizer "
                 "(by setting clipnorm or clipvalue) is currently "
                 "unsupported when using a distribution strategy.")
         grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
     if self.clipvalue is not None:
         if distribute_ctx.has_strategy():
             raise ValueError(
                 "Gradient clipping in the optimizer "
                 "(by setting clipnorm or clipvalue) is currently "
                 "unsupported when using a distribution strategy.")
         grads = [
             clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
             for g in grads
         ]
     return grads
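For comparison, the same clipping logic as a standalone sketch using only the public TensorFlow API (the gradient values are made up for illustration):

import tensorflow as tf

grads = [tf.constant([3.0, 4.0]), tf.constant([0.5, -2.0])]
clipnorm, clipvalue = 1.0, 1.0

# Rescale each gradient so its L2 norm is at most `clipnorm`...
grads = [tf.clip_by_norm(g, clipnorm) for g in grads]
# ...then clamp every element into [-clipvalue, clipvalue].
grads = [tf.clip_by_value(g, -clipvalue, clipvalue) for g in grads]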
 def _moments(self, inputs, reduction_axes, keep_dims):
   mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
   # TODO(b/129279393): Support zero batch input in non DistributionStrategy
   # code as well.
   # TODO(b/130185866): Support zero batch input in graph mode.
   if (ops.executing_eagerly_outside_functions() and
       distribution_strategy_context.has_strategy()):
     inputs_size = array_ops.size(inputs)
     mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
     variance = array_ops.where(inputs_size > 0, variance,
                                K.zeros_like(variance))
   return mean, variance
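The zero-batch guard above can be reproduced with public ops. A small sketch (the shapes are assumptions for illustration):

import tensorflow as tf

inputs = tf.zeros([0, 4])  # an empty batch
mean, variance = tf.nn.moments(inputs, axes=[0])  # NaN for empty input
inputs_size = tf.size(inputs)
# Swap the NaN statistics for zeros whenever the batch is empty.
mean = tf.where(inputs_size > 0, mean, tf.zeros_like(mean))
variance = tf.where(inputs_size > 0, variance, tf.zeros_like(variance))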
Example #6
    def decorated(metric_obj, *args):
        """Decorated function with merge_call."""
        has_strategy = distribution_strategy_context.has_strategy()
        replica_context = distribution_strategy_context.get_replica_context()
        if not has_strategy or replica_context is None:
            raw_result = result_fn(*args)
            # Results need to be wrapped in a `tf.identity` op to ensure
            # correct execution order.
            if isinstance(raw_result,
                          (ops.Tensor, variables_module.Variable, float, int)):
                result_t = array_ops.identity(raw_result)
            elif isinstance(raw_result, dict):
                result_t = {
                    key: array_ops.identity(value)
                    for key, value in raw_result.items()
                }
            else:
                try:
                    result_t = array_ops.identity(raw_result)
                except (ValueError, TypeError):
                    raise RuntimeError(
                        'The output of `metric.result()` can only be a single '
                        'Tensor/Variable, or a dict of Tensors/Variables. '
                        'For metric %s, got result %s.' %
                        (metric_obj.name, raw_result))
        else:
            # TODO(psv): Test distribution of metrics using different distribution
            # strategies.

            # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
            # with distribution object as the first parameter. We create a wrapper
            # here so that the result function need not have that parameter.
            def merge_fn_wrapper(distribution, merge_fn, *args):
                # We will get `PerReplica` merge function. Taking the first one as all
                # are identical copies of the function that we had passed below.
                result = distribution.experimental_local_results(merge_fn)[0](
                    *args)

                # Wrapping result in identity so that control dependency between
                # update_op from `update_state` and result works in case result returns
                # a tensor.
                return array_ops.identity(result)

            # Wrapping result in merge_call. merge_call is used when we want to leave
            # replica mode and compute a value in cross replica mode.
            result_t = replica_context.merge_call(merge_fn_wrapper,
                                                  args=(result_fn, ) + args)

        # We are saving the result op here to be used in train/test execution
        # functions. This basically gives the result op that was generated with a
        # control dep to the updates for these workflows.
        metric_obj._call_result = result_t
        return result_t
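The `merge_call` pattern used above can be shown in isolation. A hedged sketch with the public `tf.distribute` API; the sum reduction stands in for the metric aggregation and is not taken from the snippet:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def step():
  ctx = tf.distribute.get_replica_context()

  def merge_fn(strategy_in_merge, value):
    # Runs once in cross-replica mode; `value` holds the per-replica tensors.
    return strategy_in_merge.reduce(
        tf.distribute.ReduceOp.SUM, value, axis=None)

  # merge_call leaves replica mode so the per-replica values can be combined.
  return ctx.merge_call(merge_fn, args=(tf.constant(1.0),))

print(strategy.run(step))  # the summed value, mirrored back to each replica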
  def _get_tensor(self, is_finite):
    tensor = control_flow_ops.cond(is_finite, lambda: 1., lambda: float('NaN'))

    if not distribution_strategy_context.has_strategy():
      return tensor
    def get():
      rep_id = (distribution_strategy_context.get_replica_context()
                .replica_id_in_sync_group)
      return control_flow_ops.cond(math_ops.equal(rep_id, 0), lambda: tensor,
                                   lambda: 1.)
    distribution = distribution_strategy_context.get_strategy()
    return distribution.extended.call_for_each_replica(get)
Example #8
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        assignments = []
        for (grad, param) in grads_and_vars:
            if grad is None or param is None:
                continue
            param_name = _get_variable_name(param.name)
            m = tf.get_variable(name=param_name + "/momentum",
                                shape=param.shape.as_list(),
                                dtype=param.dtype,
                                trainable=False,
                                initializer=tf.zeros_initializer())

            next_m = self.momentum * m + grad

            update = next_m

            # update is scaled by loss_scaling,
            # so we need to restore its scale.
            update /= self.loss_scaling
            if _do_use_weight_decay(param_name, self.weight_decay_rate,
                                    self.exclude_from_weight_decay):
                update += self.weight_decay_rate * param

            update_with_lr = self.learning_rate * update

            next_param = param - update_with_lr

            if distribute_ctx.has_strategy():
                # Handle DistributionStrategy case.
                if distribute_ctx.in_cross_replica_context():
                    raise RuntimeError(
                        "Use `_distributed_apply()` instead of "
                        "`apply_gradients()` in a cross-replica context.")

                assign_params = distribute_ctx.get_replica_context().merge_call(
                    assign_vars, args=((param, m), (next_param, next_m)))
            else:
                assign_params = [param.assign(next_param), m.assign(next_m)]
            assignments.extend(assign_params)

            if _need_centering(param_name, self.darknet_gn, self.upsample_gn):
                with tf.control_dependencies(assign_params):
                    param_identity = tf.identity(param)
                centering_op = _centering_weights(param, param_identity)
                assignments.append(centering_op)

        if self.use_moving_avg:
            assignments.extend(
                _create_moving_avg(grads_and_vars, self.moving_avg_decay))

        return tf.group(*assignments, name=name)
 def run_fn():
   replica_context = ds_context.get_replica_context()
   self.assertTrue(replica_context is not None)
   self.assertIs(None, ds_context.get_cross_replica_context())
   self.assertFalse(ds_context.in_cross_replica_context())
   self.assertTrue(ds_context.has_strategy())
   self.assertIs(dist, ds_context.get_strategy())
   self.assertEqual("foo", replica_context.merge_call(None, test_arg="foo"))
   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="bar"))
Example #11
 def _moments(self, inputs, reduction_axes, keep_dims):
     mean, variance = nn.moments(inputs,
                                 reduction_axes,
                                 keep_dims=keep_dims)
     # TODO(b/129279393): Support zero batch input in non DistributionStrategy
     # code as well.
     if distribution_strategy_context.has_strategy():
         inputs_size = array_ops.size(inputs)
         mean = tf_utils.smart_cond(inputs_size > 0, lambda: mean,
                                    lambda: K.zeros_like(mean))
         variance = tf_utils.smart_cond(inputs_size > 0, lambda: variance,
                                        lambda: K.zeros_like(variance))
     return mean, variance
Example #12
 def merge_fn(dist, s):
     self.assertIs(
         distribution_strategy_context._get_default_strategy(), dist)
     self.assertIs(None,
                   distribution_strategy_context.get_replica_context())
     self.assertIs(
         dist,
         distribution_strategy_context.get_cross_replica_context())
     self.assertTrue(
         distribution_strategy_context.in_cross_replica_context())
     self.assertIs(dist, distribution_strategy_context.get_strategy())
     self.assertFalse(distribution_strategy_context.has_strategy())
     return "foo_" + s
Example #13
 def _moments(self, inputs, reduction_axes, keep_dims):
     mean, variance = nn.moments(inputs,
                                 reduction_axes,
                                 keep_dims=keep_dims)
     # TODO(b/129279393): Support zero batch input in non DistributionStrategy
     # code as well.
      if (distribution_strategy_context.has_strategy() and
              not inputs.shape.is_fully_defined()):
         inputs_size = array_ops.size(inputs)
         mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
         variance = array_ops.where(inputs_size > 0, variance,
                                    K.zeros_like(variance))
     return mean, variance
Example #14
    def __init__(self, copy_from=None, state=None, alg=None):
        """Creates a generator.

    The new generator will be initialized by one of the following ways, with
    decreasing precedence:
    (1) If `copy_from` is not None, the new generator is initialized by copying
        information from another generator.
    (2) If `state` and `alg` are not None (they must be set together), the new
        generator is initialized by a state.

    Args:
      copy_from: a generator to be copied from.
      state: a vector of dtype STATE_TYPE representing the initial state of the
        RNG, whose length and semantics are algorithm-specific. If it's a
        variable, the generator will reuse it instead of creating a new
        variable.
      alg: the RNG algorithm. Possible values are
        `tf.random.Algorithm.PHILOX` for the Philox algorithm and
        `tf.random.Algorithm.THREEFRY` for the ThreeFry algorithm
        (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
        [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
        The string names `"philox"` and `"threefry"` can also be used.
        Note `PHILOX` guarantees the same numbers are produced (given
        the same random state) across all architectures (CPU, GPU, XLA etc).
    """
        # TODO(b/175072242): Remove distribution-strategy dependencies in this file.
        if ds_context.has_strategy():
            self._distribution_strategy = ds_context.get_strategy()
        else:
            self._distribution_strategy = None
        if copy_from is not None:
            # All other arguments should be None
            assert (alg or state) is None
            self._state_var = self._create_variable(copy_from.state,
                                                    dtype=STATE_TYPE,
                                                    trainable=False)
            self._alg = copy_from.algorithm
        else:
            assert alg is not None and state is not None
            alg = stateless_random_ops.convert_alg_to_int(alg)
            if isinstance(state, variables.Variable):
                _check_state_shape(state.shape, alg)
                self._state_var = state
            else:
                state = _convert_to_state_tensor(state)
                _check_state_shape(state.shape, alg)
                self._state_var = self._create_variable(state,
                                                        dtype=STATE_TYPE,
                                                        trainable=False)
            self._alg = alg
    def update(self, grads):
        """Updates loss scale based on if gradients are finite in current step."""
        grads = nest.flatten(grads)
        if distribution_strategy_context.has_strategy():
            distribution = (
                distribution_strategy_context.get_cross_replica_context())

            def get_is_finite(grads):
                is_finite = _is_all_finite(grads)
                # We cast to float, because we cannot reduce booleans with
                # DistributionStrategy.
                return math_ops.cast(is_finite, dtypes.float32)

            is_finite_float = distribution.extended.call_for_each_replica(
                get_is_finite, args=(grads, ))
            reduced_is_finite_float = distribution.reduce(
                reduce_util.ReduceOp.SUM, is_finite_float, axis=None)
            is_finite = math_ops.equal(reduced_is_finite_float,
                                       distribution.num_replicas_in_sync)
        else:
            is_finite = _is_all_finite(grads)

        def update_if_finite_grads():
            """Update assuming the gradients are finite."""
            def incr_loss_scale():
                new_loss_scale = math_ops.minimum(
                    self._current_loss_scale * self._multiplier, 2**32)
                return control_flow_ops.group(
                    _assign_if_finite(self._current_loss_scale,
                                      new_loss_scale),
                    self._num_good_steps.assign(0))

            return control_flow_ops.cond(
                self._num_good_steps + 1 >= self._increment_period,
                incr_loss_scale,
                lambda: _op_in_graph_mode(self._num_good_steps.assign_add(1)))

        def update_if_not_finite_grads():
            """Update assuming the gradients are nonfinite."""

            new_loss_scale = math_ops.maximum(
                self._current_loss_scale / self._multiplier, 1)
            return control_flow_ops.group(
                self._num_good_steps.assign(0),
                self._current_loss_scale.assign(new_loss_scale))

        update_op = control_flow_ops.cond(is_finite, update_if_finite_grads,
                                          update_if_not_finite_grads)
        should_apply_gradients = is_finite
        return update_op, should_apply_gradients
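The control flow above is a standard dynamic loss-scaling policy. A pure-Python paraphrase of the update rule (default parameter values are assumptions, mirroring the attributes used above):

def update_loss_scale(loss_scale, num_good_steps, grads_are_finite,
                      multiplier=2.0, increment_period=2000):
  """Returns (new_loss_scale, new_num_good_steps)."""
  if grads_are_finite:
    if num_good_steps + 1 >= increment_period:
      # Enough consecutive finite steps: grow the scale, capped at 2**32.
      return min(loss_scale * multiplier, 2.0 ** 32), 0
    return loss_scale, num_good_steps + 1
  # Non-finite gradients: shrink the scale (floored at 1) and reset the count.
  return max(loss_scale / multiplier, 1.0), 0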
Example #16
    def model_fn(features, labels, mode):
        """model_fn for keras Estimator."""
        model = _clone_and_build_model(mode, keras_model, custom_objects,
                                       features, labels)
        model_output_names = []
        # We need to make sure that the output names of the last layer in the
        # model are the same for each of the cloned models. This is required
        # for mirrored strategy when we call regroup.
        if distribution_strategy_context.has_strategy():
            for name in model.output_names:
                name = re.compile(r'_\d$').sub('', name)
                model_output_names.append(name)
        else:
            model_output_names = model.output_names

        # Get inputs to EstimatorSpec
        predictions = dict(zip(model_output_names, model.outputs))

        loss = None
        train_op = None
        eval_metric_ops = None

        # Set loss and metric only during train and evaluate.
        if mode is not ModeKeys.PREDICT:
            if mode is ModeKeys.TRAIN:
                model._make_train_function()  # pylint: disable=protected-access
            else:
                model._make_test_function()  # pylint: disable=protected-access
            loss = model.total_loss

            eval_metric_ops = _convert_keras_metrics_to_estimator(model)

        # Set train_op only during train.
        if mode is ModeKeys.TRAIN:
            train_op = model.train_function.updates_op

        if not model._is_graph_network:
            # Reset model state to original state,
            # to avoid `model_fn` being destructive for the initial model argument.
            models.in_place_subclassed_model_state_restoration(keras_model)
        return model_fn_lib.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            export_outputs={
                _DEFAULT_SERVING_KEY: export_lib.PredictOutput(predictions)
            })
Example #17
  def _get_tensor(self, is_finite):
    tensor = control_flow_ops.cond(is_finite, lambda: 1., lambda: float('NaN'))

    if not distribution_strategy_context.has_strategy():
      return tensor

    def get():
      rep_id = (
          distribution_strategy_context.get_replica_context()
          .replica_id_in_sync_group)
      return control_flow_ops.cond(
          math_ops.equal(rep_id, 0), lambda: tensor, lambda: 1.)

    distribution = distribution_strategy_context.get_strategy()
    return distribution.extended.call_for_each_replica(get)
Example #18
 def enumerate_epochs(self):
   """Yields `(epoch, tf.data.Iterator)`."""
   data_iterator = iter(self._dataset)
   for epoch in range(self._initial_epoch, self._epochs):
     if self._insufficient_data:  # Set by `catch_stop_iteration`.
       break
     if self._adapter.should_recreate_iterator():
       if ds_context.has_strategy():
         # TODO(b/138326910): remove this when MultiDeviceIterator is a
         # CompositeTensor (unless this is more efficient)
         data_iterator._initializer  # pylint: disable=pointless-statement, protected-access
       else:
         data_iterator = iter(self._dataset)
     yield epoch, data_iterator
     self._adapter.on_epoch_end()
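A runnable toy sketch of the same epoch/iterator contract (the handler class here is a simplified stand-in, not the original data adapter):

import tensorflow as tf

class _ToyHandler:
  """Stand-in exposing the `enumerate_epochs` contract shown above."""

  def __init__(self, dataset, epochs):
    self._dataset, self._epochs = dataset, epochs

  def enumerate_epochs(self):
    for epoch in range(self._epochs):
      yield epoch, iter(self._dataset)  # fresh iterator per epoch

handler = _ToyHandler(tf.data.Dataset.range(4), epochs=2)
for epoch, iterator in handler.enumerate_epochs():
  for batch in iterator:
    pass  # one training step per batch would go here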
Example #19
 def testScope(self):
     _assert_in_default_state(self)
     dist = _TestStrategy()
     with dist.scope():
         self.assertIs(None, ds_context.get_replica_context())
         self.assertIs(dist, ds_context.get_cross_replica_context())
         self.assertTrue(ds_context.in_cross_replica_context())
         self.assertTrue(ds_context.has_strategy())
         self.assertIs(dist, ds_context.get_strategy())
         expected_value = _get_test_variable(
             "baz", variable_scope.VariableSynchronization.AUTO,
             variable_scope.VariableAggregation.NONE)
         self.assertDictEqual(expected_value,
                              variable_scope.variable(1.0, name="baz"))
     _assert_in_default_state(self)
 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, ds_context.get_replica_context())
     self.assertIs(dist, ds_context.get_cross_replica_context())
     self.assertTrue(ds_context.in_cross_replica_context())
     self.assertTrue(ds_context.has_strategy())
     self.assertIs(dist, ds_context.get_strategy())
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
Example #21
def default_model_compile(model, lr, loss='mean_absolute_error'):
    opt_kwargs = {}
    precision_policy = mixed_precision.global_policy()
    distributed = distribute_ctx.has_strategy()
    if precision_policy.loss_scale is None and not distributed:
        opt_kwargs['clipnorm'] = 1.
    if loss == 'compound_mssim':
        loss = compound_l1_mssim_loss
    elif loss == 'mssim':
        loss = partial(compound_l1_mssim_loss, alpha=0.9999)
        loss.__name__ = "mssim"
    model.compile(
        optimizer=tfa.optimizers.RectifiedAdam(lr=lr, **opt_kwargs),
        loss=loss,
        metrics=['mean_squared_error', keras_psnr, keras_ssim],
    )
Example #22
 def _assign_moving_average(self, variable, value, momentum, inputs_size):
     with ops.name_scope(None, 'AssignMovingAvg',
                         [variable, value, momentum]) as scope:
         with ops.colocate_with(variable):
             decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
             if decay.dtype != variable.dtype.base_dtype:
                 decay = math_ops.cast(decay, variable.dtype.base_dtype)
             update_delta = (variable -
                             math_ops.cast(value, variable.dtype)) * decay
             # TODO(b/129279393): Support zero batch input in non
             # DistributionStrategy code as well.
             if distribution_strategy_context.has_strategy():
                 update_delta = tf_utils.smart_cond(
                     inputs_size > 0, lambda: update_delta,
                     lambda: K.zeros_like(update_delta))
             return state_ops.assign_sub(variable, update_delta, name=scope)
Example #23
def compute_weighted_loss(losses, sample_weight=None, reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE):
    if distribution_strategy_context.has_strategy() and \
            reduction in {tf.keras.losses.Reduction.AUTO, tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE}:
        raise ValueError(
            'Please use `tf.keras.losses.Reduction.SUM` or `tf.keras.losses.Reduction.NONE` for loss reduction when '
            'losses are used with `tf.distribute.Strategy` outside of the built-in training loops. You can implement '
            '`tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch size like:\n'
            '```\n'
            'with strategy.scope():\n'
            '    loss_obj = tf.keras.losses.CategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)\n'
            '....\n'
            '    loss = tf.reduce_sum(loss_obj(labels, predictions)) * (1. / global_batch_size)\n'
            '```\n'
            'Please see https://www.tensorflow.org/tutorials/distribute/custom_training for more details.')

    return losses_utils.compute_weighted_loss(losses, sample_weight=sample_weight, reduction=reduction)
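The fix that the error message prescribes, as a runnable sketch (shapes and data are made up):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
global_batch_size = 4

with strategy.scope():
  loss_obj = tf.keras.losses.CategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.NONE)

def compute_loss(labels, predictions):
  per_example_loss = loss_obj(labels, predictions)
  # Average over the *global* batch size, not the per-replica batch size.
  return tf.reduce_sum(per_example_loss) * (1.0 / global_batch_size)

labels = tf.one_hot([0, 1, 2, 1], depth=3)
predictions = tf.nn.softmax(tf.random.uniform([4, 3]))
print(compute_loss(labels, predictions))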
Example #24
  def update(self, grads):
    """Updates loss scale based on if gradients are finite in current step."""
    if distribution_strategy_context.has_strategy():
      distribution = distribution_strategy_context.get_cross_replica_context()

      def get_is_finite(grads):
        is_finite = _is_all_finite(grads)
        # We cast to float, because we cannot reduce booleans with
        # DistributionStrategy.
        return math_ops.cast(is_finite, dtypes.float32)

      is_finite_float = distribution.extended.call_for_each_replica(
          get_is_finite, args=(grads,))
      reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM,
                                                    is_finite_float, axis=None)
      is_finite = math_ops.equal(reduced_is_finite_float,
                                 distribution.num_replicas_in_sync)
    else:
      is_finite = _is_all_finite(grads)

    def update_if_finite_grads():
      """Update assuming the gradients are finite."""

      def incr_loss_scale():
        new_loss_scale = self._current_loss_scale * self._multiplier
        return control_flow_ops.group(
            _assign_if_finite(self._current_loss_scale, new_loss_scale),
            self._num_good_steps.assign(0))

      return control_flow_ops.cond(
          self._num_good_steps + 1 >= self._increment_period,
          incr_loss_scale, lambda: _op_in_graph_mode(
              self._num_good_steps.assign_add(1)))

    def update_if_not_finite_grads():
      """Update assuming the gradients are nonfinite."""

      new_loss_scale = math_ops.maximum(
          self._current_loss_scale / self._multiplier, 1)
      return control_flow_ops.group(
          self._num_good_steps.assign(0),
          self._current_loss_scale.assign(new_loss_scale))

    update_op = control_flow_ops.cond(is_finite, update_if_finite_grads,
                                      update_if_not_finite_grads)
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients
def _create_variable(*args, **kwargs):
    """Creates a variable, and check that it's not MirroredVariable.

  Args:
    *args: positional arguments passed along to `variables.Variable.
    **kwargs: keyword arguments passed along to `variables.Variable.

  Returns:
    The created variable.
  """
    if ds_context.has_strategy():
        raise ValueError(
            "Creating a generator within a strategy scope is disallowed, because "
            "there is ambiguity on how to replicate a generator (e.g. should it be "
            "copied so that each replica gets the same random numbers, or 'split' "
            "so that each replica gets different random numbers).")
        # TODO(wangpeng): Link to the RNG guide for solutions in such cases.
    var = variables.Variable(*args, **kwargs)
    return var
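A hedged sketch of the workaround the error implies: construct the generator before entering the strategy scope and reuse it afterwards.

import tensorflow as tf

# Created outside any strategy scope, so the check above does not trigger.
rng = tf.random.Generator.from_seed(1234)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Creating a tf.random.Generator *here* would raise the ValueError above.
  pass

print(rng.normal([2, 2]))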
Example #26
    def _assign_moving_average(self, variable, value, momentum):
        with ops.name_scope(None, 'AssignMovingAvg',
                            [variable, value, momentum]) as scope:
            # TODO(b/120571621): We want to avoid colocating the variables here
            # since TPUStrategy does not implement replica local variables.
            # Remove this hack once we support TPULocalVariables.
            is_tpu_strategy = False
            if distribution_strategy_context.has_strategy():
                distribute = distribution_strategy_context.get_strategy()
                if distribute.__class__.__name__ == 'TPUStrategy':
                    is_tpu_strategy = True

            with ops.colocate_with(variable):
                decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
                if decay.dtype != variable.dtype.base_dtype:
                    decay = math_ops.cast(decay, variable.dtype.base_dtype)
                update_delta = (variable -
                                math_ops.cast(value, variable.dtype)) * decay
                return state_ops.assign_sub(variable, update_delta, name=scope)
Example #27
    def _update_weights(self, fast_weights, slow_weights, alpha):
        def _update_slow_weight(slow_weight, fast_weight, a):
            slow_weight.assign_add(a * (fast_weight - slow_weight))

        def _update_fast_weight(fast_weight, slow_weight):
            fast_weight.assign(slow_weight)

        if tf.equal(tf.cast(self._iterations, tf.float32) % self.k, 0):
            if distribution_strategy_context.has_strategy():
                # `call_for_each_replica` lives on the strategy's `extended`
                # object; a replica context has no `extended` attribute.
                distribution = distribution_strategy_context.get_strategy()

                for fast, slow in zip(fast_weights, slow_weights):
                    distribution.extended.call_for_each_replica(_update_slow_weight,
                                                                args=(slow, fast.value(), alpha))
                    distribution.extended.call_for_each_replica(_update_fast_weight,
                                                                args=(fast, slow.value()))
            else:
                for fast, slow in zip(fast_weights, slow_weights):
                    _update_slow_weight(slow, fast.value(), alpha)
                    _update_fast_weight(fast, slow.value())
Example #28
def _var_key(var):
  """Key for representing a primary variable, for looking up slots.

  In graph mode the name is derived from the var shared name.
  In eager mode the name is derived from the var unique id.
  If distribution strategy exists, get the primary variable first.

  Args:
    var: the variable.

  Returns:
    the unique name of the variable.
  """

  # pylint: disable=protected-access
  if distribute_ctx.has_strategy() and hasattr(var, "_primary_var"):
    var = var._primary_var
  if hasattr(var, "op"):
    return var._shared_name
  return var._unique_id
def strategy_supports_loss_scaling():
  """Returns True if the current Strategy supports loss scaling."""
  if not distribution_strategy_context.has_strategy():
    return True
  strategy = distribution_strategy_context.get_strategy()
  # Strategies are supported if either there is only one replica or if variables
  # are replicated per device. Otherwise, the current model.fit() implementation
  # and most custom training loops incorrectly unscale the gradients. Currently,
  # gradients are unscaled once per compute replica, but they should be unscaled
  # once per variable replica. When there is one variable replica for each
  # compute replica, this works fine, but otherwise issues will occur.
  # TODO(reedwm): Support all strategies.
  return isinstance(strategy, (
      collective_all_reduce_strategy.CollectiveAllReduceStrategy,
      collective_all_reduce_strategy.CollectiveAllReduceStrategyV1,
      one_device_strategy.OneDeviceStrategy,
      one_device_strategy.OneDeviceStrategyV1,
      mirrored_strategy.MirroredStrategy,
      mirrored_strategy.MirroredStrategyV1,
  ))
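Roughly the same check with public symbols, under the assumption that the public strategy classes correspond to the internal ones listed above:

import tensorflow as tf

def current_strategy_supports_loss_scaling():
  if not tf.distribute.has_strategy():
    return True
  strategy = tf.distribute.get_strategy()
  # One variable replica per compute replica means gradients are unscaled
  # the right number of times (see the comment above).
  return isinstance(strategy, (tf.distribute.MirroredStrategy,
                               tf.distribute.OneDeviceStrategy,
                               tf.distribute.MultiWorkerMirroredStrategy))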
Example #31
  def _get_reduction(self):
    if distribution_strategy_context.has_strategy() and (
        self.reduction == losses_utils.ReductionV2.AUTO or
        self.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE):
      raise ValueError(
          'Please use `tf.keras.losses.Reduction.SUM` or '
          '`tf.keras.losses.Reduction.NONE` for loss reduction when losses are '
          'used with `tf.distribute.Strategy` outside of the built-in training '
          'loops. You can implement '
          '`tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch '
          'size like:\n```\nwith strategy.scope():\n'
          '    loss_obj = tf.keras.losses.CategoricalCrossentropy('
          'reduction=tf.keras.losses.Reduction.NONE)\n....\n'
          '    loss = tf.reduce_sum(loss_obj(labels, predictions)) * '
          '(1. / global_batch_size)\n```\nPlease see '
          'https://www.tensorflow.org/alpha/tutorials/distribute/training_loops'
          ' for more details.')

    if self.reduction == losses_utils.ReductionV2.AUTO:
      return losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
    return self.reduction
Example #32
def _get_distribution_strategy(model):
  """Get the model's distribution strategy."""
  if model._distribution_strategy:
    return model._distribution_strategy
  else:
    # Use the default strategy if no strategy was present at compile.
    # Validate there is no actual strategy scope active at execution
    # time.
    strategy = distribution_strategy_context.get_strategy()
    if distribution_strategy_context.has_strategy():
      raise ValueError(
          'Model was compiled without any active distribution strategy, '
          'but there is an execution-time distribution '
          'strategy scope of (%s). '
          'Try to make sure your code looks similar to the following.\n'
          'with strategy.scope():\n'
          '  model=_create_model()\n'
          '  model.compile(...)\n'
          '  model.fit(...)' % strategy)

    return strategy
Example #33
  def _get_reduction(self):
    """Handles `AUTO` reduction cases and returns the reduction value."""
    if distribution_strategy_context.has_strategy() and (
        self.reduction == losses_utils.ReductionV2.AUTO or
        self.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE):
      raise ValueError(
          'Please use `tf.keras.losses.Reduction.SUM` or '
          '`tf.keras.losses.Reduction.NONE` for loss reduction when losses are '
          'used with `tf.distribute.Strategy` outside of the built-in training '
          'loops. You can implement '
          '`tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch '
          'size like:\n```\nwith strategy.scope():\n'
          '    loss_obj = tf.keras.losses.CategoricalCrossentropy('
          'reduction=tf.keras.losses.Reduction.NONE)\n....\n'
          '    loss = tf.reduce_sum(loss_obj(labels, predictions)) * '
          '(1. / global_batch_size)\n```\nPlease see '
          'https://www.tensorflow.org/alpha/tutorials/distribute/training_loops'
          ' for more details.')

    if self.reduction == losses_utils.ReductionV2.AUTO:
      return losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
    return self.reduction
Example #34
def _centering_weights(weight, weight_identity):
    # When using group norm, normalizing the weight variable
    # may give better results.
    centered_weight = weight_identity - tf.reduce_mean(weight_identity,
                                                       axis=[0, 1, 2])
    weight_norm = linalg_ops.norm(
        tf.cast(tf.reshape(centered_weight, [-1, centered_weight.shape[-1]]),
                dtype=tf.float32),
        ord=2,
        axis=-2)
    normed_weight = centered_weight / tf.cast(weight_norm, dtype=weight.dtype)

    if distribute_ctx.has_strategy():
        # Handle DistributionStrategy case.
        if distribute_ctx.in_cross_replica_context():
            raise RuntimeError(
                "Use `_distributed_apply()` instead of "
                "`apply_gradients()` in a cross-replica context.")

        assign_op = distribute_ctx.get_replica_context().merge_call(
            assign_vars, args=(weight, normed_weight))[0]
    else:
        assign_op = weight.assign(normed_weight)
    return [assign_op]
Example #35
  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Default to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If you should use `_distributed_apply()` instead.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers.  It relies on the subclass implementing the following
    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

    # TODO(isaprykin): Get rid of `has_strategy()` check by
    # always calling _distributed_apply(), using the default distribution
    # as needed.
    if distribute_ctx.has_strategy():
      # Handle DistributionStrategy case.
      if distribute_ctx.in_cross_replica_context():
        raise RuntimeError("Use `_distributed_apply()` instead of "
                           "`apply_gradients()` in a cross-replica context.")

      grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
      return distribute_ctx.get_replica_context().merge_call(
          self._distributed_apply, args=(grads_and_vars, global_step, name))

    # No DistributionStrategy case.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    converted_grads_and_vars = []
    for g, v in grads_and_vars:
      if g is not None:
        try:
          # Convert the grad to Tensor or IndexedSlices if necessary.
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)
      converted_grads_and_vars.append((g, v, p))

    converted_grads_and_vars = tuple(converted_grads_and_vars)
    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    if not var_list:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v, _ in converted_grads_and_vars],))
    with ops.init_scope():
      self._create_slots(var_list)
    update_ops = []
    with ops.name_scope(name, self._name) as name:
      self._prepare()
      for grad, var, processor in converted_grads_and_vars:
        if grad is None:
          continue
        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        if context.executing_eagerly() or isinstance(
            var,
            resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
          scope_name = ""
        else:
          scope_name = var.op.name
        with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
          update_ops.append(processor.update_op(self, grad))
      if global_step is None:
        apply_updates = self._finish(update_ops, name)
      else:
        with ops.control_dependencies([self._finish(update_ops, "update")]):
          with ops.colocate_with(global_step):
            if isinstance(global_step, resource_variable_ops.ResourceVariable):
              # TODO(apassos): the implicit read in assign_add is slow; consider
              # making it less so.
              apply_updates = resource_variable_ops.assign_add_variable_op(
                  global_step.handle,
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  name=name)
            else:
              apply_updates = state_ops.assign_add(global_step, 1, name=name)

      if not context.executing_eagerly():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates
def _compute_gradients_until_finite(
    distribution, loss_scale_gradient_tapes, loss_scale, target, sources,
    output_gradients, unconnected_gradients):
  """Compute gradients and update the loss scale until the gradients are finite.

  This must be called in a cross-replica context.

  This is a function instead of a method of LossScaleGradientTape, as the `self`
  parameter would be meaningless. There is one LossScaleGradientTape per
  replica, but this function is called once total (not per replica), so there
  cannot be a singular `self` parameter.

  Args:
    distribution: The distribution strategy in effect.
    loss_scale_gradient_tapes: A PerReplica value of LossScaleGradientTapes.
      Contains the LossScaleGradientTape of each replica.
    loss_scale: The loss scale to use to scale the loss and unscale the
      gradient.
    target: a list or nested structure of Tensors or Variables to be
      differentiated.
    sources: a list or nested structure of Tensors or Variables. `target` will
      be differentiated against elements in `sources`.
    output_gradients: Passed to `GradientTape.gradient`.
    unconnected_gradients: Passed to `GradientTape.gradient`.

  Returns:
    The gradients of `target` with respect to `sources`.
  """
  # Autograph cannot convert this function, so we must use an explicit
  # tf.while_loop.
  # TODO(b/143572314): Fix Autograph so that it can convert this function, then
  # replace the tf.while_loop with a Python while loop.

  # For convenience, we only deal with flattened sources
  flattened_sources = nest.flatten(sources)

  # Define the initial loop variables of the while loop.

  # Dummy value for initial_grads. The first iteration of the loop will
  # overwrite `grads` to the actual gradients.
  initial_grads = flattened_sources
  if distribution_strategy_context.has_strategy():
    # A while_loop requires the initial values to have the same types as the
    # return values from the body. However, 'initial_grads' may have type
    # 'DistributionVariable', while body returns a 'PerReplica'. While both
    # types subclass 'DistributedValues', while_loop will still throw an error.
    # So we convert 'initial_grads' to be PerReplica values.
    # TODO(b/146084534): Once the bug is fixed, remove this special case.
    initial_grads = _convert_to_per_replicas(distribution, initial_grads)
  initial_ready_to_update = False
  initial_is_first_iteration = True

  def cond(grads, ready_to_update, is_first_iteration):
    """The condition of the while loop."""
    del grads
    # Equivalent to:
    # `is_first_iteration or (not ready_to_update and loss_scale() > 1)`
    return math_ops.logical_or(
        is_first_iteration,
        math_ops.logical_and(
            math_ops.logical_not(ready_to_update),
            math_ops.greater(loss_scale(), 1)))

  # Boolean list specifying whether each gradient is None or not. Set by body().
  is_nones = []

  def body(grads, ready_to_update, is_first_iteration):
    """The body of the while loop."""
    del grads, ready_to_update, is_first_iteration
    def replica_fn(gradient_tape, target, flattened_sources, output_gradients,
                   initial_grads):
      """Scales the loss, computes the gradients, and unscales the gradients."""
      loss_scale_val = loss_scale()
      with gradient_tape:  # re-enter gradient tape so it sees the loss scaling
        scaled_target = nest.map_structure(
            lambda t: t * math_ops.cast(loss_scale_val, t.dtype), target)
      scaled_grads = super(LossScaleGradientTape, gradient_tape).gradient(
          scaled_target, flattened_sources, output_gradients,
          unconnected_gradients)

      is_nones[:] = [g is None for g in scaled_grads]
      inv_loss_scale = 1.0 / loss_scale_val
      grads = []  # The unscaled gradients
      for g, initial_grad in zip(scaled_grads, initial_grads):
        if g is not None:
          # We call ensure_shape as shape information can be lost for certain
          # ops, such as tf.transpose, if the op is called in a tf.function and
          # has inputs created outside the tf.function.
          # TODO(b/132092188): Remove ensure_shape call after this has been
          # fixed.
          g = array_ops.ensure_shape(g, initial_grad.shape)
          grads.append(g * math_ops.cast(inv_loss_scale, g.dtype))
        else:
          # We cannot return None from a tf.while_loop, so we pass a dummy
          # tensor instead. We use initial_grad as a dummy tensor as it has the
          # correct shape and dtype. We replace it with None outside the while
          # loop.
          grads.append(initial_grad)
      return grads

    # Switch to a replica-context to compute gradients once per replica.
    grads = distribution.experimental_run_v2(
        replica_fn, args=(loss_scale_gradient_tapes, target, flattened_sources,
                          output_gradients, initial_grads))
    # Check for non-finite gradients possibly resulting from scaling.
    _, ready_to_update = loss_scale.update(grads)
    is_first_iteration = False
    return grads, ready_to_update, is_first_iteration

  grads, _, _ = control_flow_ops.while_loop(
      cond, body, [initial_grads, initial_ready_to_update,
                   initial_is_first_iteration],
      )
  grads = [None if is_none else g for g, is_none in zip(grads, is_nones)]
  grads = nest.pack_sequence_as(sources, grads)
  return grads
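For reference, the public mixed-precision API covers the same scale/compute/unscale cycle far more compactly. A hedged sketch; the toy model and data are placeholders:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD())

x = tf.random.uniform([8, 4])
y = tf.random.uniform([8, 1])

with tf.GradientTape() as tape:
  loss = tf.reduce_mean(tf.square(model(x) - y))
  # Scale the loss up so small gradients survive float16 underflow.
  scaled_loss = opt.get_scaled_loss(loss)
scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
# Undo the scaling before applying; the optimizer also adjusts the loss scale
# when it sees non-finite gradients, much like the hand-rolled loop above.
grads = opt.get_unscaled_gradients(scaled_grads)
opt.apply_gradients(zip(grads, model.trainable_variables))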
Example #37
  def fit(
      self, model, x=None, y=None, batch_size=None, epochs=1, verbose=1,
      callbacks=None, validation_split=0., validation_data=None, shuffle=True,
      class_weight=None, sample_weight=None, initial_epoch=0,
      steps_per_epoch=None, validation_steps=None, validation_freq=1,
      max_queue_size=10, workers=1, use_multiprocessing=False, **kwargs):
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)

    strategy = _get_distribution_strategy(model)
    batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
        strategy,
        x,
        batch_size,
        steps_per_epoch,
        ModeKeys.TRAIN,
        validation_split=validation_split)
    dist_utils.validate_callbacks(input_callbacks=callbacks,
                                  optimizer=model.optimizer)
    # Enter tf.distribute.Strategy scope.
    with strategy.scope():
      training_data_adapter, validation_adapter = _process_training_inputs(
          model,
          x,
          y,
          batch_size=batch_size,
          epochs=epochs,
          sample_weights=sample_weight,
          class_weights=class_weight,
          validation_split=validation_split,
          steps_per_epoch=steps_per_epoch,
          shuffle=shuffle,
          validation_data=validation_data,
          validation_steps=validation_steps,
          distribution_strategy=strategy,
          max_queue_size=max_queue_size,
          workers=workers,
          use_multiprocessing=use_multiprocessing)

      total_samples = _get_total_number_of_samples(training_data_adapter)
      use_sample = total_samples is not None
      do_validation = (validation_adapter is not None)

      recreate_training_iterator = (
          training_data_adapter.should_recreate_iterator(steps_per_epoch))
      if not steps_per_epoch:
        # TODO(b/139762795): Add step inference for when steps is None to
        # prevent end of sequence warning message.
        steps_per_epoch = training_data_adapter.get_size()

      # tf.print('{} on {} steps.'.format(ModeKeys.TRAIN, steps_per_epoch))
      training_context = TrainingContext()

      training_dataset = training_data_adapter.get_dataset()
      # Raise an error if steps_per_epoch isn't specified but the dataset
      # is infinite.
      # TODO(scottzhu): This check should probably happen in the adapter
      inferred_steps = training_utils.infer_steps_for_dataset(
          model,
          training_dataset,
          steps_per_epoch,
          steps_name='steps_per_epoch',
          epochs=0)

      steps_per_epoch = (
          inferred_steps if steps_per_epoch is None else steps_per_epoch)

      training_dataset = strategy.experimental_distribute_dataset(
          training_dataset)

      training_function = training_v2_utils._get_or_make_execution_function(
          model, ModeKeys.TRAIN)

      training_data_iter = None
      if do_validation:
        validation_dataset = validation_adapter.get_dataset()
        if not validation_steps:
          # Raise an error if validation_steps isn't specified but the
          # validation dataset is infinite.
          validation_steps = (
              validation_adapter.get_size() or
              training_utils.infer_steps_for_dataset(
                  model,
                  validation_dataset,
                  validation_steps,
                  steps_name='validation_steps'))
        eval_function = training_v2_utils._get_or_make_execution_function(
            model, ModeKeys.TEST)
        eval_data_iter = None
        validation_dataset = strategy.experimental_distribute_dataset(
            validation_dataset)
        val_total_samples = _get_total_number_of_samples(validation_adapter)
      else:
        val_total_samples = None

      if verbose and (total_samples or steps_per_epoch):
        _print_train_info(total_samples, steps_per_epoch, val_total_samples,
                          validation_steps)

      training_callbacks = cbks.configure_callbacks(
          callbacks,
          model,
          do_validation=do_validation,
          batch_size=batch_size,
          epochs=epochs,
          steps_per_epoch=steps_per_epoch,
          samples=total_samples or steps_per_epoch,
          count_mode='samples' if use_sample else 'steps',
          verbose=0,  # Handle ProgBarLogger separately in this loop.
          mode=ModeKeys.TRAIN)

      with training_context.on_start(model, training_callbacks, use_sample,
                                     verbose, ModeKeys.TRAIN):

        initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
            initial_epoch, ModeKeys.TRAIN)

        for epoch in range(initial_epoch, epochs):
          if training_context.callbacks.model.stop_training:
            break

          # Training
          with training_context.on_epoch(epoch, ModeKeys.TRAIN) as epoch_logs:
            model.reset_metrics()
            if training_data_iter is None or recreate_training_iterator:
              if (training_data_iter is not None and
                  distribution_strategy_context.has_strategy()):
                # TODO(kaftan): remove this when MultiDeviceIterator is a
                # CompositeTensor (unless this is more efficient)
                training_data_iter._initializer  # pylint: disable=pointless-statement
              else:
                training_data_iter = iter(training_dataset)

            training_result = run_one_epoch(
                model,
                training_data_iter,
                training_function,
                dataset_size=training_data_adapter.get_size(),
                batch_size=training_data_adapter.batch_size(),
                strategy=strategy,
                steps_per_epoch=steps_per_epoch,
                num_samples=total_samples,
                mode=ModeKeys.TRAIN,
                training_context=training_context,
                total_epochs=epochs)
            cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)

            # In the case of steps_per_epoch = None, the final cardinality will
            # be determined when the inputs are fully consumed (eg dataset or
            # generator). Update the steps_per_epoch to the new value.
            if (steps_per_epoch is None
                and training_context.progbar.progbar.target is not None):
              steps_per_epoch = training_context.progbar.progbar.target

            # Evaluation
            if (do_validation and
                training_utils.should_run_validation(validation_freq, epoch) and
                not training_callbacks.model.stop_training):
              if (eval_data_iter is not None and
                  distribution_strategy_context.has_strategy()):
                # TODO(kaftan): remove this when MultiDeviceIterator is a
                # CompositeTensor (unless this is more efficient)
                eval_data_iter._initializer  # pylint: disable=pointless-statement
              else:
                eval_data_iter = iter(validation_dataset)

              validation_callbacks = cbks.configure_callbacks(
                  training_callbacks,
                  model,
                  batch_size=batch_size,
                  epochs=1,
                  steps_per_epoch=validation_steps,
                  samples=val_total_samples or validation_steps,
                  count_mode='samples' if use_sample else 'steps',
                  verbose=0,  # Handle ProgBarLogger separately in this loop.
                  mode=ModeKeys.TEST)

              eval_context = TrainingContext()
              with eval_context.on_start(
                  model,
                  validation_callbacks,
                  use_sample,
                  verbose=0,
                  mode=ModeKeys.TEST):
                with eval_context.on_epoch(epoch, ModeKeys.TEST):
                  model.reset_metrics()
                  eval_result = run_one_epoch(
                      model,
                      eval_data_iter,
                      eval_function,
                      dataset_size=validation_adapter.get_size(),
                      batch_size=validation_adapter.batch_size(),
                      strategy=strategy,
                      steps_per_epoch=validation_steps,
                      num_samples=val_total_samples,
                      mode=ModeKeys.TEST,
                      training_context=eval_context,
                      total_epochs=1)
                  cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST,
                                 prefix='val_')

    return model.history
Example #38
def is_default_strategy(strategy):
    with strategy.scope():
        return not distribution_strategy_context.has_strategy()
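Usage sketch, assuming the function above is importable alongside TensorFlow:

import tensorflow as tf

# Outside any scope, get_strategy() returns the default (no-op) strategy,
# whose scope intentionally leaves has_strategy() False.
print(is_default_strategy(tf.distribute.get_strategy()))      # True
print(is_default_strategy(tf.distribute.MirroredStrategy()))  # False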
Example #39
  def apply_gradients(self, grads_and_vars, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      name: Optional name for the returned operation.  Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    grads_and_vars = _filter_grads(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]
    if distribute_ctx.has_strategy():
      reduced_grads = merge_grads(grads_and_vars)
      grads_and_vars = zip(reduced_grads, var_list)

    self._create_hypers()
    with ops.init_scope():
      self._create_slots(var_list)
    update_ops = []

    self._prepare(var_list)

    def update_grad_to_var(grad, var):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", var)
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices)
      update_op = self._resource_apply_dense(grad, var)
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      else:
        return update_op

    with ops.name_scope(name, self._name) as name:
      for grad, var in grads_and_vars:
        scope_name = ("" if ops.executing_eagerly_outside_functions() else
                      "_" + var.op.name)
        with ops.name_scope("update" + scope_name):
          update_ops.append(update_grad_to_var(grad, var))
      # control dependencies does not work in per replica mode, please change
      # this once b/118841692 is fixed.
      # with ops.control_dependencies(update_ops):
      #   apply_updates = self._iterations.assign_add(1).op
      apply_updates = merge_update_step(update_ops, self.iterations)
      return apply_updates
Example #40
  def decorated(metric_obj, *args):
    """Decorated function with merge_call."""
    has_strategy = distribution_strategy_context.has_strategy()
    replica_context = distribution_strategy_context.get_replica_context()

    # The purpose of using `merge_call` to call `result()` is to trigger cross
    # replica aggregation of metric state variables (SyncOnReadVariable). After
    # we introduced `variable_sync_on_read_context`, in principle there is no
    # need to use `merge_call` here. However the branch still exists because:
    #
    # 1. Keras V1 training code sometimes assumes `result_t` is the same tensor
    #    across replicas (achieved by `merge_call`). With
    #    `variable_sync_on_read_context` each replica gets their own tensors
    #    residing on replica's device, thus breaking the assumption.
    # 2. Keras compile/fit creates a tf.function (a.k.a. train_function) that returns
    #    the metric values of the first replica. With
    #    `variable_sync_on_read_context` since each replica gets their own
    #    tensors, the metric result tensors on the non-first replicas are not in
    #    the return value of train_function, making TF graph optimizer prune the
    #    branch that computes and aggregates those metric results. As a result,
    #    if NCCL is used to do the aggregation, the program will hang because
    #    NCCL ops are only launched on the non-pruned first replica.
    #
    # We condition on strategy.extended._use_merge_call() since we know if it is
    # false, the program uses `jit_compile` to compile replica fn, meaning it is
    # not V1 training (hence #1 is okay), and no pruning will happen as
    # compiled functions are not inlined (hence #2 is okay).

    if (not has_strategy or replica_context is None or
        not distribution_strategy_context.get_strategy(
        ).extended._use_merge_call()):
      with distribution_strategy_context.variable_sync_on_read_context():
        raw_result = result_fn(*args)
        # Results need to be wrapped in a `tf.identity` op to ensure
        # correct execution order.
        if isinstance(raw_result,
                      (ops.Tensor, variables_module.Variable, float, int)):
          result_t = array_ops.identity(raw_result)
        elif isinstance(raw_result, dict):
          result_t = {
              key: array_ops.identity(value)
              for key, value in raw_result.items()
          }
        else:
          try:
            result_t = array_ops.identity(raw_result)
          except (ValueError, TypeError):
            raise RuntimeError(
                'The output of `metric.result()` can only be a single '
                'Tensor/Variable, or a dict of Tensors/Variables. '
                'For metric %s, got result %s.' % (metric_obj.name, raw_result))
    else:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
      # with distribution object as the first parameter. We create a wrapper
      # here so that the result function need not have that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *args):
        # We will get `PerReplica` merge function. Taking the first one as all
        # are identical copies of the function that we had passed below.
        result = distribution.experimental_local_results(merge_fn)[0](*args)

        # Wrapping result in identity so that control dependency between
        # update_op from `update_state` and result works in case result returns
        # a tensor.
        return array_ops.identity(result)

      # Wrapping result in merge_call. merge_call is used when we want to leave
      # replica mode and compute a value in cross replica mode.
      result_t = replica_context.merge_call(
          merge_fn_wrapper, args=(result_fn,) + args)

    # We are saving the result op here to be used in train/test execution
    # functions. This basically gives the result op that was generated with a
    # control dep to the updates for these workflows.
    metric_obj._call_result = result_t
    return result_t
Example #41
 def _support_zero_size_input(self):
     return distribution_strategy_context.has_strategy() and getattr(
         distribution_strategy_context.get_strategy().extended,
         'experimental_enable_get_next_as_optional', False)
def is_default_strategy(strategy):
  with strategy.scope():
    return not distribution_strategy_context.has_strategy()