Example #1
0
 def set_non_tensor_output(self, name, output):
   """Record `output` under `name` as a non-tensor output.

   In cross-replica context the value is stored as given. In replica context
   the call goes through `merge_call` and the unwrapped value is stored.

   Args:
     name: Key under which the output is stored.
     output: The non-tensor value to capture.
   """
   if distribution_strategy_context.in_cross_replica_context():
     self._non_tensor_outputs[name] = output
     return

   def _store_unwrapped(strategy, per_replica_value):
     # Reducing non-tensor values is not meaningful, so all of the values are
     # kept, exactly as `unwrap` returns them.
     self._non_tensor_outputs[name] = strategy.unwrap(per_replica_value)

   replica_ctx = distribution_strategy_context.get_replica_context()
   replica_ctx.merge_call(_store_unwrapped, args=(output,))
Example #2
0
  def apply_gradients(self, grads_and_vars, name=None):
    """Apply gradients to variables.

    This is the second half of `minimize()`: given gradients that have already
    been computed, it builds the update that applies them.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation.  Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients, obtained by running
      `_distributed_apply` through the replica context's `merge_call`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    grads_and_vars = _filter_grads(grads_and_vars)
    variables = [var for _, var in grads_and_vars]

    # Hyperparameters are created first; slot variables are created under
    # `init_scope`.
    self._create_hypers()
    with ops.init_scope():
      self._create_slots(variables)

    self._prepare(variables)

    replica_ctx = distribute_ctx.get_replica_context()
    return replica_ctx.merge_call(
        self._distributed_apply, args=(grads_and_vars,), kwargs={"name": name})
def _assert_in_default_state(t):
  """Asserts, through test case `t`, that the strategy context is the default.

  Checks that the default replica context and default strategy are the current
  ones, that there is no cross-replica context, and that `has_strategy()` is
  false.
  """
  default_replica_ctx = ds_context._get_default_replica_context()
  t.assertIs(default_replica_ctx, ds_context.get_replica_context())
  t.assertIs(None, ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())
  default_strategy = ds_context._get_default_strategy()
  t.assertIs(default_strategy, ds_context.get_strategy())
  t.assertFalse(ds_context.has_strategy())
Example #4
0
  def decorated(_, *args):
    """Calls `result_fn`, going through `merge_call` when in replica context."""
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context is None:
      # No replica context means we are already in cross-replica context.
      return array_ops.identity(result_fn(*args))

    # TODO(psv): Test distribution of metrics using different distribution
    # strategies.

    def _call_result_fn(strategy, per_replica_fn, *fn_args):
      # `merge_call` passes the distribution object as the first argument,
      # which the result function does not take; this wrapper absorbs it.
      # The function arrives as a `PerReplica` value whose entries are all
      # identical copies of what was passed below, so the first one is used.
      local_fn = strategy.experimental_local_results(per_replica_fn)[0]
      # The identity wrapper keeps the control dependency between the
      # `update_state` update op and the result working when the result is a
      # tensor.
      return array_ops.identity(local_fn(*fn_args))

    # `merge_call` leaves replica mode so the value is computed in
    # cross-replica mode.
    return replica_context.merge_call(
        _call_result_fn, args=(result_fn,) + args)
 def merge_fn(dist, s):
   """Checks the context state inside a default-strategy merge, then tags `s`."""
   default_strategy = ds_context._get_default_strategy()
   self.assertIs(default_strategy, dist)
   self.assertIs(None, ds_context.get_replica_context())
   cross_replica_strategy = ds_context.get_cross_replica_context()
   self.assertIs(dist, cross_replica_strategy)
   self.assertTrue(ds_context.in_cross_replica_context())
   current_strategy = ds_context.get_strategy()
   self.assertIs(dist, current_strategy)
   self.assertFalse(ds_context.has_strategy())
   return "foo_" + s
Example #6
0
def merge_grads(grads_and_vars):
  """Sum-reduce gradients across replicas through `merge_call`.

  Args:
    grads_and_vars: Value forwarded to `strategy.extended.batch_reduce_to`.

  Returns:
    Whatever `batch_reduce_to` returns for a `SUM` reduction.
  """

  def _sum_across_replicas(strategy, per_replica_grads_and_vars):
    return strategy.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, per_replica_grads_and_vars)

  replica_ctx = distribute_ctx.get_replica_context()
  return replica_ctx.merge_call(
      _sum_across_replicas, args=(grads_and_vars,))
Example #7
0
  def _test_step_fn(inputs):
    """Runs one test step and returns its outputs."""
    inputs, targets = inputs

    # Build the model in cross-replica context before executing the step.
    replica_ctx = distribution_strategy_context.get_replica_context()
    replica_ctx.merge_call(_build_model, args=(model, mode, inputs, targets))

    distributed_model = distributed_training_utils.get_distributed_model(
        model, mode)
    step_results = _per_device_execution_function(distributed_model, mode)
    _, outputs, updates, _ = step_results

    # Returning under the control dependency ties `outputs` to `updates`.
    with ops.control_dependencies([updates]):
      return outputs
Example #8
0
  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String identifying the output. It does not have to match a tensor
        name.
      output: The tensors to be outputted under `name`. Supported types are
        described under `reduce_op`.
      reduce_op: Reduction method used to reduce outputs from multiple
        replicas. Required when called in a replica context, optional in
        cross-replica context.
        When given, the outputs are reduced with the current distribution
        strategy's `reduce` method, so `output` must be of a type that method
        supports. For example, with MirroredStrategy and a reduction set,
        `output` must be a `PerReplica` value.
        The reduce method is also stored in `_last_step_outputs_reduce_ops`
        so the outputs can later be interpreted as reduced or not.
    """
    if not distribution_strategy_context.in_cross_replica_context():
      assert reduce_op is not None

      def _reduce_and_record(strategy, per_replica_value):
        self._last_step_outputs[name] = strategy.reduce(
            reduce_op, per_replica_value, axis=None)
        # Recorded inside the merge function because all replicas share the
        # same context object, so it is more robust to set it only once (even
        # if every replica would set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      replica_ctx = distribution_strategy_context.get_replica_context()
      replica_ctx.merge_call(_reduce_and_record, args=(output,))
      return

    # Cross-replica context: record the reduce op, then store the output,
    # reduced only if an op was given.
    self._last_step_outputs_reduce_ops[name] = reduce_op
    if reduce_op is not None:
      strategy = distribution_strategy_context.get_strategy()
      output = strategy.reduce(reduce_op, output, axis=None)
    self._last_step_outputs[name] = output
 def run_fn():
   """Checks the context state and variable creation inside a replica call."""
   replica_context = ds_context.get_replica_context()
   self.assertTrue(replica_context is not None)
   self.assertIs(None, ds_context.get_cross_replica_context())
   self.assertFalse(ds_context.in_cross_replica_context())
   self.assertTrue(ds_context.has_strategy())
   self.assertIs(dist, ds_context.get_strategy())

   merge_result = replica_context.merge_call(None, test_arg="foo")
   self.assertEqual("foo", merge_result)

   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   created = variable_scope.variable(1.0, name="bar")
   self.assertDictEqual(expected_value, created)
  def _test_run(self, strategy):
    """Checks `experimental_run_v2` with no args, `args` and `kwargs`."""
    # No arguments: each replica returns its id plus one.
    out1 = strategy.experimental_run_v2(
        lambda: ds_context.get_replica_context().replica_id_in_sync_group + 1)
    out1_vals = self.evaluate(strategy.unwrap(out1))
    self.assertAllEqual([1, 2], out1_vals)

    # Positional argument in, dict of per-replica values out.
    out2 = strategy.experimental_run_v2(
        lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([2, 4], out2_vals["a"])
    self.assertAllEqual([1, 4], out2_vals["b"])

    # The dict from the previous call is passed on as keyword arguments.
    out3 = strategy.experimental_run_v2(lambda b, a: a + 2 * b + 2, kwargs=out2)
    out3_vals = self.evaluate(strategy.unwrap(out3))
    self.assertAllEqual([6, 14], out3_vals)
Example #11
0
def merge_update_step(update_ops, local_step):
  """Group the replicas' update ops, then increment `local_step` once.

  Args:
    update_ops: Iterable of update ops; each is grouped with `strategy.group`.
    local_step: Object exposing `assign_add`; incremented by 1 after the
      grouped updates.

  Returns:
    The result of `merge_call`, whose merge function returns the op of the
    `assign_add`.
  """

  def _group_then_increment(strategy, per_replica_update_ops, step):
    grouped = [strategy.group(op) for op in per_replica_update_ops]
    # The increment is created under a control dependency on every grouped
    # update.
    with ops.control_dependencies(grouped):
      return step.assign_add(1).op

  replica_ctx = distribute_ctx.get_replica_context()
  return replica_ctx.merge_call(
      _group_then_increment, args=(update_ops, local_step))
 def testScope(self):
   """Checks context state inside a strategy scope, and default state outside."""
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, ds_context.get_replica_context())
     cross_replica_strategy = ds_context.get_cross_replica_context()
     self.assertIs(dist, cross_replica_strategy)
     self.assertTrue(ds_context.in_cross_replica_context())
     self.assertTrue(ds_context.has_strategy())
     current_strategy = ds_context.get_strategy()
     self.assertIs(dist, current_strategy)

     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     created = variable_scope.variable(1.0, name="baz")
     self.assertDictEqual(expected_value, created)
   _assert_in_default_state(self)
  def _test_step_fn(inputs):
    """Runs one test step and returns its outputs."""
    # A 2-element tuple/list is split into (inputs, targets); anything else is
    # treated as inputs only.
    targets = None
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs

    # Build the model in cross-replica context before executing the step.
    replica_ctx = distribution_strategy_context.get_replica_context()
    replica_ctx.merge_call(_build_model, args=(model, mode, inputs, targets))

    distributed_model = distributed_training_utils.get_distributed_model(
        model, mode)
    step_results = _per_replica_execution_function(distributed_model, mode)
    _, outputs, updates, _ = step_results

    # Returning under the control dependency ties `outputs` to `updates`.
    with ops.control_dependencies([updates]):
      return outputs
  def testMergeCall(self):
    """Checks `merge_call` on the default replica context."""
    _assert_in_default_state(self)

    def merge_fn(dist, s):
      # Inside the merge the default strategy is the cross-replica context,
      # yet `has_strategy()` stays false.
      default_strategy = ds_context._get_default_strategy()
      self.assertIs(default_strategy, dist)
      self.assertIs(None, ds_context.get_replica_context())
      cross_replica_strategy = ds_context.get_cross_replica_context()
      self.assertIs(dist, cross_replica_strategy)
      self.assertTrue(ds_context.in_cross_replica_context())
      current_strategy = ds_context.get_strategy()
      self.assertIs(dist, current_strategy)
      self.assertFalse(ds_context.has_strategy())
      return "foo_" + s

    replica_ctx = ds_context.get_replica_context()
    default_replica_ctx = ds_context._get_default_replica_context()
    self.assertIs(default_replica_ctx, replica_ctx)
    merged = replica_ctx.merge_call(merge_fn, args=("bar",))
    self.assertEqual("foo_bar", merged)
    _assert_in_default_state(self)
def skip_summary():
  """Determines if summary should be skipped.

  If using multiple replicas in distributed strategy, skip summaries on all
  replicas except the first one (replica_id=0).

  Returns:
    True if the summary is skipped; False otherwise. The result is always a
    Python `bool`; when the replica id cannot be determined the summary is
    not skipped.
  """

  # TODO(priyag): Add a new optional argument that will provide multiple
  # alternatives to override default behavior. (e.g. run on last replica,
  # compute sum or mean across replicas).
  replica_context = distribution_strategy_context.get_replica_context()
  if not replica_context:
    return False
  # TODO(b/118385803): when replica_id of _TPUReplicaContext is properly
  # initialized, remember to change here as well.
  replica_id = replica_context.replica_id_in_sync_group
  if isinstance(replica_id, ops.Tensor):
    # `constant_value` can come back as None when the tensor has no
    # statically known value.
    replica_id = tensor_util.constant_value(replica_id)
  if replica_id is None:
    # Unknown replica id: keep the summary. Previously this path leaked
    # `None` to the caller instead of the documented False.
    return False
  # Coerce to a real bool so replica 0 yields False rather than `0`.
  return bool(replica_id > 0)
Example #16
0
  def decorated(_, *args):
    """Calls `result_fn`, going through `merge_call` when in replica context."""
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context is None:
      # No replica context means we are already in cross-replica context.
      return result_fn(*args)

    # TODO(psv): Test distribution of metrics using different distribution
    # strategies.

    def _call_result_fn(strategy, per_device_fn, *fn_args):
      # `merge_call` passes the distribution object as the first argument,
      # which the result function does not take; this wrapper absorbs it.
      # The function arrives as a `PerDevice` value whose entries are all
      # identical copies of what was passed below, so the first one is used.
      local_fn = strategy.unwrap(per_device_fn)[0]
      return local_fn(*fn_args)

    # `merge_call` leaves replica mode so the value is computed in
    # cross-replica mode.
    return replica_context.merge_call(
        _call_result_fn, args=(result_fn,) + args)
Example #17
0
def _all_mean(value):
    """All-reduces `value` with `ReduceOp.MEAN` on the current replica context."""
    replica_ctx = ds_context.get_replica_context()
    mean_op = reduce_util.ReduceOp.MEAN
    return replica_ctx.all_reduce(mean_op, value)
Example #18
0
  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    The method sums gradients from all replicas in the presence of
    `tf.distribute.Strategy` by default. You can aggregate gradients yourself by
    passing `experimental_aggregate_gradients=False`.

    Example:

    ```python
    grads = tape.gradient(loss, vars)
    grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
    # Processing aggregated gradients.
    optimizer.apply_gradients(zip(grads, vars),
        experimental_aggregate_gradients=False)

    ```

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation. Default to the name passed
        to the `Optimizer` constructor.
      experimental_aggregate_gradients: Whether to sum gradients from different
        replicas in the presence of `tf.distribute.Strategy`. If False, it's
        user responsibility to aggregate the gradients. Default to True.

    Returns:
      An `Operation` that applies the specified gradients. The `iterations`
      will be automatically increased by 1.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If called in a cross-replica context.
      NotImplementedError: If `experimental_aggregate_gradients` is False and
        the current strategy's `extended` is a
        `ParameterServerStrategyExtended`.
    """
    grads_and_vars = _filter_grads(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]

    with backend.name_scope(self._name):
      # Create iteration if necessary.
      with ops.init_scope():
        # Accessing the attribute is done only for its side effect here; the
        # value itself is discarded.
        _ = self.iterations
        self._create_hypers()
        self._create_slots(var_list)

      if not grads_and_vars:
        # Distribution strategy does not support reducing an empty list of
        # gradients
        return control_flow_ops.no_op()

      # The `merge_call` at the end needs a replica context, so calling from a
      # cross-replica context is rejected up front.
      if distribute_ctx.in_cross_replica_context():
        raise RuntimeError(
            "`apply_gradients() cannot be called in cross-replica context. "
            "Use `tf.distribute.Strategy.run` to enter replica "
            "context.")

      strategy = distribute_ctx.get_strategy()
      # User-side aggregation is rejected for parameter-server style
      # strategies (see the error message below).
      if (not experimental_aggregate_gradients and strategy and isinstance(
          strategy.extended,
          parameter_server_strategy.ParameterServerStrategyExtended)):
        raise NotImplementedError(
            "`experimental_aggregate_gradients=False is not supported for "
            "ParameterServerStrategy and CentralStorageStrategy")

      apply_state = self._prepare(var_list)
      if experimental_aggregate_gradients:
        # Replace each gradient with its aggregated counterpart, keeping the
        # variables in the same order.
        reduced_grads = self._aggregate_gradients(grads_and_vars)
        var_list = [v for _, v in grads_and_vars]
        grads_and_vars = list(zip(reduced_grads, var_list))
      # `apply_state` is bound with `partial`; the remaining arguments are
      # supplied by `merge_call`.
      return distribute_ctx.get_replica_context().merge_call(
          functools.partial(self._distributed_apply, apply_state=apply_state),
          args=(grads_and_vars,),
          kwargs={
              "name": name,
          })
Example #19
0
 def model_fn():
   """Subtracts this replica's id (cast to the variable's dtype) from `mirrored_var`."""
   replica_id = ds_context.get_replica_context().replica_id_in_sync_group
   delta = math_ops.cast(replica_id, mirrored_var.dtype)
   return mirrored_var.assign_sub(delta)
Example #20
0
 def model_fn():
   """Creates variable "b", then runs `in_cross_replica` under scope "foo"."""
   var_b = variable_scope.get_variable("b", [1])
   with variable_scope.variable_scope("foo"):
     replica_ctx = ds_context.get_replica_context()
     merged = replica_ctx.merge_call(in_cross_replica)
   return var_b, merged
Example #21
0
 def run_fn():
     """Writes this replica's id as summary "a" using `summary_writer`."""
     with summary_writer.as_default():
         replica_ctx = ds_context.get_replica_context()
         replica_id = replica_ctx.replica_id_in_sync_group
         return summary_ops.write("a", replica_id)
def _merge_call_merge_raises_fn():
  """Runs `_call_merge_raises_fn` through the replica context's `merge_call`."""
  replica_ctx = ds_context.get_replica_context()
  replica_ctx.merge_call(_call_merge_raises_fn)
Example #23
0
def f1_score(labels, predictions, weights=None, num_thresholds=200,
             metrics_collections=None, updates_collections=None, name=None):
  """Computes the approximately best F1-score across different thresholds.

  The f1_score function applies a range of thresholds to the predictions to
  convert them from [0, 1] to bool. Precision and recall are computed by
  comparing them to the labels. The F1-Score is then defined as
  2 * precision * recall / (precision + recall). The best one across the
  thresholds is returned.

  Disclaimer: In practice it may be desirable to choose the best threshold on
  the validation set and evaluate the F1 score with this threshold on a
  separate test set. Or it may be desirable to use a fixed threshold (e.g. 0.5).

  This function internally creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the pairs of recall and precision values for a linearly spaced set of
  thresholds from which the best f1-score is derived.

  This value is ultimately returned as `f1-score`, an idempotent operation that
  computes the F1-score (computed using the aforementioned variables). The
  `num_thresholds` variable controls the degree of discretization with larger
  numbers of thresholds more closely approximating the true best F1-score.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the F1-score.

  Example usage with a custom estimator:
  def model_fn(features, labels, mode):
    predictions = make_predictions(features)
    loss = make_loss(predictions, labels)
    train_op = tf.contrib.training.create_train_op(
          total_loss=loss,
          optimizer='Adam')
    eval_metric_ops = {'f1': f1_score(labels, predictions)}
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs)
  estimator = tf.estimator.Estimator(model_fn=model_fn)

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `f1_score` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    f1_score: A scalar `Tensor` representing the current best f1-score across
      different thresholds.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches the `f1_score`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(
      name, 'f1', (labels, predictions, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions=predictions, labels=labels, weights=weights)
    # To account for floating point imprecisions / avoid division by zero.
    epsilon = 1e-7
    # `num_thresholds - 2` evenly spaced interior thresholds; the two end
    # points are added below, pushed just outside [0, 1] by `epsilon`.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [0.0 - epsilon] + thresholds + [1.0 + epsilon]

    # Confusion matrix.
    # NOTE(review): only 'tp', 'fp' and 'fn' are requested here, while the
    # docstring above also mentions a `true_negatives` variable -- confirm
    # which is accurate.
    values, update_ops = metrics_impl._confusion_matrix_at_thresholds(  # pylint: disable=protected-access
        labels, predictions, thresholds, weights, includes=('tp', 'fp', 'fn'))

    # Compute precision and recall at various thresholds.
    def compute_best_f1_score(tp, fp, fn, name):
      """Returns the maximum F1 over thresholds for the given tp/fp/fn.

      `name` is used as a suffix for the precision and recall op names.
      `epsilon` is added to every denominator.
      """
      precision_at_t = math_ops.div(tp, epsilon + tp + fp,
                                    name='precision_' + name)
      recall_at_t = math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
      # Compute F1 score.
      f1_at_thresholds = (
          2.0 * precision_at_t * recall_at_t /
          (precision_at_t + recall_at_t + epsilon))
      return math_ops.reduce_max(f1_at_thresholds)

    def f1_across_replicas(_, values):
      """Merge function: best F1 from `values`, added to `metrics_collections`.

      The first argument (supplied by `merge_call`) is ignored.
      """
      best_f1 = compute_best_f1_score(tp=values['tp'], fp=values['fp'],
                                      fn=values['fn'], name='value')
      if metrics_collections:
        ops.add_to_collections(metrics_collections, best_f1)
      return best_f1

    # The value tensor is computed through `merge_call` on the replica context.
    best_f1 = distribution_strategy_context.get_replica_context().merge_call(
        f1_across_replicas, args=(values,))

    # The update op is computed directly from the per-threshold update ops,
    # without going through `merge_call`.
    update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'],
                                      fn=update_ops['fn'], name='update')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return best_f1, update_op
Example #24
0
 def run(value):
     """All-gathers a list of two identity-wrapped tensors along `axis`.

     `value_2` and `axis` are not defined in this function; they are read from
     an enclosing scope.
     """
     first = array_ops.identity(value)
     second = array_ops.identity(value_2)
     replica_ctx = ds_context.get_replica_context()
     return replica_ctx._all_gather([first, second], axis=axis)
Example #25
0
 def replica_fn(per_replica_value):
     """All-gathers an identity copy of `per_replica_value` along `axis`."""
     replica_ctx = ds_context.get_replica_context()
     local_copy = array_ops.identity(per_replica_value)
     gathered = replica_ctx._all_gather(local_copy, axis=axis)
     return gathered
def _merge_call_merge_raises_fn():
    """Runs `_call_merge_raises_fn` through the replica context's `merge_call`."""
    replica_ctx = ds_context.get_replica_context()
    replica_ctx.merge_call(_call_merge_raises_fn)
def _merge_raises_fn():
    """Runs `_raise_exception_fn` through the replica context's `merge_call`."""
    replica_ctx = ds_context.get_replica_context()
    replica_ctx.merge_call(_raise_exception_fn)
 def mark_devices_fn():
     """Marks this replica's slot in `expected_devices`, asserting it is new."""
     replica_ctx = ds_context.get_replica_context()
     replica_id = self.evaluate(replica_ctx.replica_id_in_sync_group)
     device_count = len(d.extended.worker_devices)
     self.assertLess(replica_id, device_count)
     # Each replica id must be seen at most once.
     already_marked = expected_devices[replica_id]
     self.assertFalse(already_marked)
     expected_devices[replica_id] = True
 def run():
     """Returns this replica's id, converted with `asarray`, times `multiplier`."""
     replica_ctx = distribution_strategy_context.get_replica_context()
     replica_id = np_array_ops.asarray(replica_ctx.replica_id_in_sync_group)
     return replica_id * multiplier
Example #30
0
 def run():
     """All-gathers a two-element list holding the same identity tensor twice."""
     wrapped = array_ops.identity(single_value)
     replica_ctx = ds_context.get_replica_context()
     pair = [wrapped, wrapped]
     return replica_ctx._all_gather(pair, axis=axis)
 def run_fn():
   """Writes this replica's id as summary "a" using `summary_writer`."""
   with summary_writer.as_default():
     replica_ctx = ds_context.get_replica_context()
     replica_id = replica_ctx.replica_id_in_sync_group
     return summary_ops.write("a", replica_id)
Example #32
0
 def replica_fn(value):
     """All-gathers `value` along axis 0 on the current replica context."""
     replica_ctx = ds_context.get_replica_context()
     gathered = replica_ctx._all_gather(value, axis=0)
     return gathered
def _all_mean(value):
  """All-reduces `value` with `ReduceOp.MEAN` on the current replica context."""
  replica_ctx = ds_context.get_replica_context()
  mean_op = reduce_util.ReduceOp.MEAN
  return replica_ctx.all_reduce(mean_op, value)
Example #34
0
 def run(value):
     """All-gathers an identity copy of `value` along axis 0."""
     wrapped = array_ops.identity(value)
     replica_ctx = ds_context.get_replica_context()
     gathered = replica_ctx._all_gather(wrapped, axis=0)
     return gathered
Example #35
0
  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Default to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If you should use `_distributed_apply()` instead.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers.  It relies on the subclass implementing the following
    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

    # TODO(isaprykin): Get rid of `has_strategy()` check by
    # always calling _distributed_apply(), using the default distribution
    # as needed.
    if distribute_ctx.has_strategy():
      # Handle DistributionStrategy case.
      if distribute_ctx.in_cross_replica_context():
        raise RuntimeError("Use `_distributed_apply()` instead of "
                           "`apply_gradients()` in a cross-replica context.")

      grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
      # In replica context the actual apply is delegated to
      # `_distributed_apply` through `merge_call`.
      return distribute_ctx.get_replica_context().merge_call(
          self._distributed_apply, args=(grads_and_vars, global_step, name))

    # No DistributionStrategy case.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    # Each entry becomes a (gradient, variable, processor) triple; a `None`
    # gradient is kept as-is and skipped later.
    converted_grads_and_vars = []
    for g, v in grads_and_vars:
      if g is not None:
        try:
          # Convert the grad to Tensor or IndexedSlices if necessary.
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)
      converted_grads_and_vars.append((g, v, p))

    converted_grads_and_vars = tuple(converted_grads_and_vars)
    # Only variables that actually have a gradient get slots.
    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    if not var_list:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v, _ in converted_grads_and_vars],))
    with ops.init_scope():
      self._create_slots(var_list)
    update_ops = []
    with ops.name_scope(name, self._name) as name:
      self._prepare()
      for grad, var, processor in converted_grads_and_vars:
        if grad is None:
          continue
        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        # Note the precedence: `and` binds tighter than `or`, so the empty
        # scope name is used when executing eagerly, or when `var` is a
        # ResourceVariable that is not in graph mode.
        if context.executing_eagerly() or isinstance(
            var,
            resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
          scope_name = ""
        else:
          scope_name = var.op.name
        with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
          update_ops.append(processor.update_op(self, grad))
      if global_step is None:
        apply_updates = self._finish(update_ops, name)
      else:
        # The `global_step` increment is created under a control dependency on
        # the finished updates and colocated with `global_step`.
        with ops.control_dependencies([self._finish(update_ops, "update")]):
          with ops.colocate_with(global_step):
            if isinstance(global_step, resource_variable_ops.ResourceVariable):
              # TODO(apassos): the implicit read in assign_add is slow; consider
              # making it less so.
              apply_updates = resource_variable_ops.assign_add_variable_op(
                  global_step.handle,
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  name=name)
            else:
              apply_updates = state_ops.assign_add(global_step, 1, name=name)

      if not context.executing_eagerly():
        # Outside eager execution, the op is registered in the TRAIN_OP
        # collection (once), and a Tensor result is replaced by its op.
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates
Example #36
0
 def f():
     """Returns `replica_id_in_sync_group` of the current replica context."""
     replica_ctx = ds_context.get_replica_context()
     return replica_ctx.replica_id_in_sync_group
Example #37
0
 def model_fn():
   """Creates variable "b", then runs `in_cross_replica` under name scope "foo"."""
   var_b = variable_scope.variable(1.0, name="b")
   with ops.name_scope("foo"):
     replica_ctx = ds_context.get_replica_context()
     merged = replica_ctx.merge_call(in_cross_replica)
   return var_b, merged
Example #38
0
    def _calculate_mean_and_var(self, x, axes, keep_dims):
        """Compute mean and variance of `x` over `axes`, synced across replicas.

        In a replica context, per-replica sums and squared sums are all-reduced
        and divided by the global element count, so every replica ends up with
        statistics of the global batch. Without a replica context, plain
        `reduce_mean` statistics are computed.

        Args:
            x: Input tensor. float16 inputs are computed in float32 and the
                results are cast back to float16.
            axes: Axes to reduce over. The replica-context branch takes the
                batch size from dimension 0 and the remaining element count
                from `axes[1:]`, i.e. it assumes `axes[0]` is the batch axis 0
                -- TODO confirm against callers.
            keep_dims: If false, the reduced axes are squeezed out of the
                results.

        Returns:
            A `(mean, variance)` tuple. The variance is the sample (biased)
            variance.
        """

        with K.name_scope('moments'):
            # The dynamic range of fp16 is too limited to support the collection of
            # sufficient statistics. As a workaround we simply perform the operations
            # on 32-bit floats before converting the mean and variance back to fp16
            y = math_ops.cast(
                x, dtypes.float32) if x.dtype == dtypes.float16 else x
            replica_ctx = ds.get_replica_context()
            if replica_ctx:
                # Per-replica partial sums.
                local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
                local_squared_sum = math_ops.reduce_sum(math_ops.square(y),
                                                        axis=axes,
                                                        keepdims=True)
                batch_size = math_ops.cast(
                    array_ops.shape_v2(y)[0], dtypes.float32)
                # TODO(b/163099951): batch the all-reduces once we sort out the ordering
                # issue for NCCL. We don't have a mechanism to launch NCCL in the same
                # order in each replica nowadays, so we limit NCCL to batch all-reduces.

                # Sum of the values over all replicas.
                y_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                               local_sum)
                # Sum of the squared values over all replicas.
                y_squared_sum = replica_ctx.all_reduce(
                    reduce_util.ReduceOp.SUM, local_squared_sum)
                # Total batch size over all replicas.
                global_batch_size = replica_ctx.all_reduce(
                    reduce_util.ReduceOp.SUM, batch_size)

                # Number of elements averaged per batch entry: the product of
                # the sizes of the reduced non-batch axes. The shape must be
                # indexed with `axes[i]`, not the loop counter `i`; the two
                # only coincide when `axes` is `[0, 1, ..., k]`, and differ
                # for e.g. `axes == [0, 2, 3]`.
                axes_vals = [(array_ops.shape_v2(y))[axes[i]]
                             for i in range(1, len(axes))]
                multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals),
                                           dtypes.float32)
                multiplier = multiplier * global_batch_size

                # Convert the global sums into mean and variance.
                mean = y_sum / multiplier
                y_squared_mean = y_squared_sum / multiplier
                # var = E(x^2) - E(x)^2
                variance = y_squared_mean - math_ops.square(mean)
            else:
                # No replica context: nothing to synchronize.
                # Compute true mean while keeping the dims for proper broadcasting.
                mean = math_ops.reduce_mean(y,
                                            axes,
                                            keepdims=True,
                                            name='mean')
                # sample variance, not unbiased variance
                # Note: stop_gradient does not change the gradient that gets
                #       backpropagated to the mean from the variance calculation,
                #       because that gradient is zero
                variance = math_ops.reduce_mean(math_ops.squared_difference(
                    y, mean),
                                                axes,
                                                keepdims=True,
                                                name='variance')
            if not keep_dims:
                mean = array_ops.squeeze(mean, axes)
                variance = array_ops.squeeze(variance, axes)
            if x.dtype == dtypes.float16:
                return (math_ops.cast(mean, dtypes.float16),
                        math_ops.cast(variance, dtypes.float16))
            else:
                return (mean, variance)
Example #39
0
 def model_fn():
   """Creates variable "foo", runs an identity `merge_call`, returns the variable."""
   created = variable_scope.variable(1.0, name="foo")
   replica_ctx = ds_context.get_replica_context()
   replica_ctx.merge_call(lambda _: _)
   return created
Example #40
0
def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
  """Apply one exponential-moving-average step to `variable`.

  After the update, `variable` holds
    variable * decay + value * (1 - decay)
  which is carried out as the in-place subtraction
    variable -= (1 - decay) * (variable - value)

  A variable initialized to `0` produces an average biased toward `0`. With
  `zero_debias` enabled the result is scaled by the debiasing factor
    1 - decay ** num_updates
  as described in Section 3 of `ADAM: A Method for Stochastic Optimization`
  (https://arxiv.org/abs/1412.6980).

  By default, the debias shadow variables are named after both the scope in
  which they are created and the scope of the variable they debias, followed
  by a uniquifying suffix. For example:

  ```
    with tf.compat.v1.variable_scope('scope1'):
      with tf.compat.v1.variable_scope('scope2'):
        var = tf.compat.v1.get_variable('foo')
        update_1 = tf.assign_moving_average(var, 0.0, 1.0)
        update_2 = tf.assign_moving_average(var, 0.0, 0.9)

    # var.name: 'scope1/scope2/foo'
    # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
    #                   'scope1/scope2/scope1/scope2/foo/biased_1'
  ```

  When called in a replica context, `value` is first reduced to its mean
  across replicas and the update is issued from a `merge_call`; otherwise the
  update runs directly with the cross-replica strategy.

  Args:
    variable: A Variable.
    value: A tensor with the same shape as 'variable'.
    decay: A float Tensor or float value.  The moving average decay.
    zero_debias: A python bool. If true, treat the variable as 0-initialized
      and unbias it, as in https://arxiv.org/abs/1412.6980. See the docstring
      of `_zero_debias` for more details.
    name: Optional name of the returned operation.

  Returns:
    A tensor which, when evaluated, computes and returns the new moving
    average.
  """

  with ops.name_scope(name, "AssignMovingAvg",
                      [variable, value, decay]) as scope:
    # From here on we only need the complement (1 - decay), in the variable's
    # base dtype.
    step = ops.convert_to_tensor(1.0 - decay, name="decay")
    target_dtype = variable.dtype.base_dtype
    if step.dtype != target_dtype:
      step = math_ops.cast(step, target_dtype)

    def _subtract_delta(v, value):
      return state_ops.assign_sub(v, (v - value) * step, name=scope)

    def _run_update(strategy, v, value):
      if not zero_debias:
        return strategy.extended.update(v, _subtract_delta, args=(value,))
      return _zero_debias(strategy, v, value, step)

    replica_context = distribution_strategy_context.get_replica_context()
    if not replica_context:
      cross_strategy = (
          distribution_strategy_context.get_cross_replica_context())
      return _run_update(cross_strategy, variable, value)

    def _merge(strategy, v, value):
      # Update with the mean of `value` across replicas.
      mean_value = strategy.extended.reduce_to(
          ds_reduce_util.ReduceOp.MEAN, value, v)
      return _run_update(strategy, v, mean_value)

    return replica_context.merge_call(_merge, args=(variable, value))
Example #41
0
 def model_fn():
     """Map the current replica's sync-group id through `_get_local_replica_id`."""
     to_local_id = distribution.extended._get_local_replica_id
     sync_group_id = ds_context.get_replica_context().replica_id_in_sync_group
     return to_local_id(sync_group_id)
Example #42
0
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Rewires `tf.Variable` initializers to read their values from a checkpoint.

  Nothing is loaded at call time; values are read when the initializer runs
  (typically via a `tf.compat.v1.global_variables_initializer` op).

  Note: The default initialization ops of the specified variables are
  overridden and their dtype is redefined.

  The assignment map accepts the following forms:

  * `'checkpoint_scope_name/': 'scope_name/'` - loads every variable in the
    current `scope_name` from `checkpoint_scope_name`, matching on tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    initializes the `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - initializes the given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - initializes a list of
    partitioned variables with tensor 'scope_variable_name' from the
    checkpoint.
  * `'/': 'scope_name/'` - loads every variable in the current `scope_name`
    from the checkpoint's root (e.g. no scope).

  Partitioned variables are supported; they are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.compat.v1.variable_scope('new_scope_1'):
    var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                           initializer=tf.compat.v1.zeros_initializer())
  with tf.compat.v1.variable_scope('new_scope_2'):
    var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                           initializer=tf.compat.v1.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.compat.v1.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict whose keys are names of variables in the checkpoint
      and whose values are current variables or names of current variables
      (in default graph).

  Raises:
    ValueError: If missing variables in current graph, or if missing
      checkpoints or tensors in checkpoints.
  """

  def _load(_):
    # The single positional argument is ignored; it exists so the function
    # can also be passed to `merge_call`.
    return _init_from_checkpoint(ckpt_dir_or_file, assignment_map)

  if distribution_strategy_context.get_cross_replica_context():
    _load(None)
    return

  replica_ctx = distribution_strategy_context.get_replica_context()
  replica_ctx.merge_call(_load)
Example #43
0
    def _subdiv_calculate_mean_and_var(self, x, axes, keep_dims):
        """Compute mean/variance of `x` plus per-batch partial statistics.

        All reductions run over `axes`. In a replica context the local sums
        and the local batch size are SUM all-reduced across replicas before
        the moments are formed; otherwise everything is computed locally.

        Args:
          x: Input tensor. float16 inputs are cast to float32 for the
            reductions and the float results are cast back to float16.
          axes: Sequence of integer axes to reduce over. The first entry is
            assumed to be the batch axis (dimension 0 is used as the batch
            size); the remaining entries are the other reduced dimensions.
          keep_dims: If false, the reduced axes are squeezed out of the
            returned statistics.

        Returns:
          A tuple `(mean, net_sum, variance, squared_mean, input_batch_size)`:
            mean: Mean of the input over `axes`.
            net_sum: Sum of the input over `axes` divided by the number of
              non-batch reduced elements.
            variance: `E[x^2] - E[x]^2` over `axes` (taken from `nn.moments`
              when `self._support_zero_size_input()` is false outside a
              replica context).
            squared_mean: Sum of the squared input over `axes` divided by the
              number of non-batch reduced elements.
            input_batch_size: int32 tensor. The all-reduced batch size in a
              replica context; otherwise the local batch size, or 0 when
              `self._support_zero_size_input()` is false.
        """
        with K.name_scope('moments'):
            # The dynamic range of fp16 is too limited to support the collection of
            # sufficient statistics. As a workaround we simply perform the operations
            # on 32-bit floats before converting the mean and variance back to fp16
            y = math_ops.cast(
                x, dtypes.float32) if x.dtype == dtypes.float16 else x
            replica_ctx = ds.get_replica_context()

            # Number of elements reduced per example: the product of the
            # dimensions named by the non-batch reduction axes. The shape must
            # be indexed with the axis values themselves (axes[1:]), not with
            # their positions in `axes`; the two only coincide when the
            # reduction axes are [0, 1, ..., k-1].
            axes_vals = [array_ops.shape_v2(y)[axis] for axis in axes[1:]]
            per_example_count = math_ops.cast(
                math_ops.reduce_prod(axes_vals), dtypes.float32)

            if replica_ctx:
                # Statistics local to this replica.
                local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
                local_squared_sum = math_ops.reduce_sum(math_ops.square(y),
                                                        axis=axes,
                                                        keepdims=True)
                batch_size = math_ops.cast(
                    array_ops.shape_v2(y)[0], dtypes.float32)
                # TODO(b/163099951): batch the all-reduces once we sort out the ordering
                # issue for NCCL. We don't have a mechanism to launch NCCL in the same
                # order in each replica nowadays, so we limit NCCL to batch all-reduces.
                # Sum of the inputs over all replicas.
                y_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                               local_sum)
                # Sum of the squared inputs over all replicas.
                y_squared_sum = replica_ctx.all_reduce(
                    reduce_util.ReduceOp.SUM, local_squared_sum)
                # Total batch size over all replicas.
                input_batch_size = replica_ctx.all_reduce(
                    reduce_util.ReduceOp.SUM, batch_size)

                # Total number of elements that went into each statistic.
                global_count = per_example_count * input_batch_size

                # Convert the sums into mean and variance (locally).
                mean = y_sum / global_count
                y_squared_mean = y_squared_sum / global_count
                # var = E(x^2) - E(x)^2
                variance = y_squared_mean - math_ops.square(mean)
                net_sum = y_sum / per_example_count
                squared_mean = y_squared_sum / per_example_count

            else:
                net_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
                squared_mean = math_ops.reduce_sum(math_ops.square(y),
                                                   axis=axes,
                                                   keepdims=True)

                if self._support_zero_size_input():
                    # Keras assumes that batch dimension is the first dimension for Batch
                    # Normalization.
                    input_batch_size = array_ops.shape(y)[0]
                else:
                    input_batch_size = None

                squared_mean = squared_mean / per_example_count
                net_sum = net_sum / per_example_count

                if input_batch_size is None:
                    mean, variance = nn.moments(y, axes, keep_dims=True)
                    input_batch_size = 0
                else:
                    batches_ = math_ops.cast(input_batch_size,
                                             self._param_dtype)
                    # Single replica: divide the per-example-normalized sums
                    # by the batch size, keeping dims for proper broadcasting.
                    mean = net_sum / batches_
                    variance = squared_mean / batches_ - math_ops.square(mean)

            input_batch_size = math_ops.cast(input_batch_size, dtypes.int32)
            if not keep_dims:
                mean = array_ops.squeeze(mean, axes)
                net_sum = array_ops.squeeze(net_sum, axes)
                variance = array_ops.squeeze(variance, axes)
                squared_mean = array_ops.squeeze(squared_mean, axes)
            if x.dtype == dtypes.float16:
                return (math_ops.cast(mean, dtypes.float16),
                        math_ops.cast(net_sum, dtypes.float16),
                        math_ops.cast(variance, dtypes.float16),
                        math_ops.cast(squared_mean,
                                      dtypes.float16), input_batch_size)
            else:
                return (mean, net_sum, variance, squared_mean,
                        input_batch_size)
Example #44
0
def _all_sum(value):
    """SUM all-reduce `value` through the current replica context."""
    return ds_context.get_replica_context().all_reduce(
        reduce_util.ReduceOp.SUM, value)
Example #45
0
 def get():
   """Return `tensor` when the replica id equals 0, otherwise `1.`."""
   replica_ctx = distribution_strategy_context.get_replica_context()
   replica_index = replica_ctx.replica_id_in_sync_group
   return control_flow_ops.cond(
       math_ops.equal(replica_index, 0), lambda: tensor, lambda: 1.)
 def apply_gradients(self, grads_and_vars, name=None):
   """Forward gradient application to the cross-replica implementation.

   Args:
     grads_and_vars: Passed through to `_apply_gradients_cross_replica`.
     name: Passed through to `_apply_gradients_cross_replica`.

   Returns:
     The result of the `merge_call` on the current replica context.

   Raises:
     ValueError: If called in a cross-replica context.
   """
   ctx_lib = distribution_strategy_context
   if ctx_lib.in_cross_replica_context():
     raise ValueError('apply_gradients() must be called in a replica context.')
   replica_ctx = ctx_lib.get_replica_context()
   merge_args = (grads_and_vars, name)
   return replica_ctx.merge_call(
       self._apply_gradients_cross_replica, args=merge_args)
def _get_replica_id_integer():
  """Return the current replica id, resolving a Tensor id via `constant_value`."""
  rid = ds_context.get_replica_context().replica_id_in_sync_group
  if not isinstance(rid, ops.Tensor):
    return rid
  return tensor_util.constant_value(rid)
Example #48
0
 def all_reduce(x):
     """SUM all-reduce the free variable `w` and add `x` to the result."""
     reduced_w = distribution_strategy_context.get_replica_context().all_reduce(
         "SUM", w)
     return reduced_w + x
Example #49
0
 def get():
   """Select `tensor` if the replica id equals 0, else the constant `1.`."""
   current_ctx = distribution_strategy_context.get_replica_context()
   sync_group_id = current_ctx.replica_id_in_sync_group
   return control_flow_ops.cond(
       math_ops.equal(sync_group_id, 0), lambda: tensor, lambda: 1.)
Example #50
0
def _replica_id():
  """Return the current replica id, wrapping non-Tensor ids in a constant."""
  rid = ds_context.get_replica_context().replica_id_in_sync_group
  if isinstance(rid, ops.Tensor):
    return rid
  return constant_op.constant(rid)
 def mark_devices_fn():
   """Assert this replica's id is in range and unmarked, then mark it."""
   replica_id = self.evaluate(
       ds_context.get_replica_context().replica_id_in_sync_group)
   num_worker_devices = len(d.extended.worker_devices)
   self.assertLess(replica_id, num_worker_devices)
   already_marked = expected_devices[replica_id]
   self.assertFalse(already_marked)
   expected_devices[replica_id] = True
Example #52
0
def _replica_id_as_int():
  """Return the current replica id; Tensor ids go through `constant_value`."""
  raw_id = ds_context.get_replica_context().replica_id_in_sync_group
  return (tensor_util.constant_value(raw_id)
          if isinstance(raw_id, ops.Tensor) else raw_id)
def _all_sum(value):
  """All-reduce `value` with `ReduceOp.SUM` in the current replica context."""
  replica_ctx = ds_context.get_replica_context()
  summed = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, value)
  return summed
Example #54
0
 def model_fn():
   """Append 1 to `traces` and return the current replica's sync-group id."""
   traces.append(1)
   replica_ctx = ds_context.get_replica_context()
   return replica_ctx.replica_id_in_sync_group
def _merge_raises_fn():
  """Invoke `merge_call` with `_raise_exception_fn` on the replica context."""
  replica_ctx = ds_context.get_replica_context()
  replica_ctx.merge_call(_raise_exception_fn)
Example #56
0
 def model_fn():
   """Run `merge_fn` through `merge_call`, then return `0.`."""
   replica_ctx = ds_context.get_replica_context()
   replica_ctx.merge_call(merge_fn)
   return 0.
def _replica_id():
  """Return the replica id as a Tensor, converting it with `constant` if needed."""
  current_id = ds_context.get_replica_context().replica_id_in_sync_group
  return (current_id if isinstance(current_id, ops.Tensor)
          else constant_op.constant(current_id))
Example #58
0
 def model_fn():
   """Create constants "a" and "b" under name scope "foo" around a `merge_call`.

   Returns:
     The pair of constants, "a" (created before the merge call) and "b"
     (created after it).
   """
   with ops.name_scope(None, "foo"):
     first = constant_op.constant(1.0, name="a")
     replica_ctx = ds_context.get_replica_context()
     # Identity merge function: returns the argument it is given.
     replica_ctx.merge_call(lambda arg: arg)
     second = constant_op.constant(2.0, name="b")
   return first, second
Example #59
0
 def test_fn():
     """Call `merge_fn` via `merge_call` with the single argument "bar"."""
     merge_args = ("bar", )
     ds_context.get_replica_context().merge_call(merge_fn, args=merge_args)
Example #60
0
def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
    """Move `variable` toward `value` by one moving-average step.

    The new value of `variable` is
      variable * decay + value * (1 - decay)
    and the returned Operation realizes it through the subtraction
      variable -= (1 - decay) * (variable - value)

    Because a variable initialized to `0` gives a `0`-biased average,
    `zero_debias` optionally scales the result by the debiasing factor
      1 - decay ** num_updates
    See Section 3 of (Kingma et al., 2015) for more details.

    By default the debias shadow variables carry both the scope they were
    created in and the scope of the variable they debias in their names, plus
    a uniquifying suffix.

    E.g.:

    ```
      with tf.compat.v1.variable_scope('scope1'):
        with tf.compat.v1.variable_scope('scope2'):
          var = tf.compat.v1.get_variable('foo')
          update_1 = tf.assign_moving_average(var, 0.0, 1.0)
          update_2 = tf.assign_moving_average(var, 0.0, 0.9)

      # var.name: 'scope1/scope2/foo'
      # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
      #                   'scope1/scope2/scope1/scope2/foo/biased_1'
    ```

    In a replica context the mean of `value` across replicas is used and the
    update is performed inside a `merge_call`; otherwise the update is applied
    directly with the cross-replica strategy.

    Args:
      variable: A Variable.
      value: A tensor with the same shape as 'variable'.
      decay: A float Tensor or float value.  The moving average decay.
      zero_debias: A python bool. If true, assume the variable is
        0-initialized and unbias it, as in (Kingma et al., 2015). See the
        docstring of `_zero_debias` for more details.
      name: Optional name of the returned operation.

    Returns:
      A tensor which if evaluated will compute and return the new moving
      average.

    References:
      Adam - A Method for Stochastic Optimization:
        [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
        ([pdf](https://arxiv.org/pdf/1412.6980.pdf))
    """
    with ops.name_scope(name, "AssignMovingAvg",
                        [variable, value, decay]) as scope:
        # Only the complement (1 - decay) is needed below, expressed in the
        # variable's base dtype.
        complement = ops.convert_to_tensor(1.0 - decay, name="decay")
        base_dtype = variable.dtype.base_dtype
        if complement.dtype != base_dtype:
            complement = math_ops.cast(complement, base_dtype)

        def _assign_delta(v, value):
            delta = (v - value) * complement
            return state_ops.assign_sub(v, delta, name=scope)

        def _dispatch(strategy, v, value):
            if not zero_debias:
                return _update(strategy, v, _assign_delta, args=(value, ))
            return _zero_debias(strategy, v, value, complement)

        replica_context = distribution_strategy_context.get_replica_context()
        if not replica_context:
            cross_strategy = (
                distribution_strategy_context.get_cross_replica_context())
            return _dispatch(cross_strategy, variable, value)

        def _merge(strategy, v, value):
            # Use the mean of `value` across replicas for the update.
            averaged = strategy.extended.reduce_to(
                ds_reduce_util.ReduceOp.MEAN, value, v)
            return _dispatch(strategy, v, averaged)

        return replica_context.merge_call(_merge, args=(variable, value))