    def model_fn():
      replica_id_str = str(self.evaluate(_replica_id()))

      def thread_creator_fn(next_creator, *args, **kwargs):
        return next_creator(*args, **kwargs) + ":thread_" + replica_id_str

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        ds_context.get_replica_context().merge_call(lambda _: _)
      return v
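The snippets above hook variable creation through a custom creator and tag the variable name with the replica/thread id. Below is a minimal, self-contained sketch of the same hook against the public TensorFlow 2.x API; `logging_creator` is a hypothetical name used only for illustration.

import tensorflow as tf

def logging_creator(next_creator, **kwargs):
  # Delegate the actual creation to the next creator in the chain,
  # then record the name of the resulting variable.
  var = next_creator(**kwargs)
  print("created variable:", var.name)
  return var

with tf.variable_creator_scope(logging_creator):
  v = tf.Variable(1.0, name="v")  # prints "created variable: v:0"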
    def model_fn(device_id):
      assert isinstance(device_id, int)

      def thread_creator_fn(next_creator, *args, **kwargs):
        return next_creator(*args, **kwargs) + ":thread_" + str(device_id)

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        distribution_strategy_context.get_replica_context().merge_call(
            lambda _: _)
      return v
        def model_fn(device_id):
            assert isinstance(device_id, int)

            def thread_creator_fn(next_creator, *args, **kwargs):
                return (next_creator(*args, **kwargs) + ":thread_" +
                        str(device_id))

            with variable_scope.variable_creator_scope(thread_creator_fn):
                # Create a variable in this scope.
                v = variable_scope.variable(1.0)

                # This will pause the current thread, and execute the other thread.
                distribution_strategy_context.get_replica_context().merge_call(
                    lambda _: _)
            return v
Example #4
def _assert_in_default_state(t):
  t.assertIs(distribution_strategy_context._get_default_replica_context(),
             distribution_strategy_context.get_replica_context())
  t.assertIs(None, distribution_strategy_context.get_cross_replica_context())
  t.assertIs(distribution_strategy_context._get_default_distribution_strategy(),
             distribution_strategy_context.get_distribution_strategy())
  t.assertFalse(distribution_strategy_context.has_distribution_strategy())
    def testMergeCall(self):
        _assert_in_default_state(self)

        def merge_fn(dist, s):
            self.assertIs(
                distribution_strategy_context.
                _get_default_distribution_strategy(), dist)
            self.assertIs(None,
                          distribution_strategy_context.get_replica_context())
            self.assertIs(
                dist,
                distribution_strategy_context.get_cross_replica_context())
            self.assertTrue(
                distribution_strategy_context.in_cross_replica_context())
            self.assertIs(
                dist,
                distribution_strategy_context.get_distribution_strategy())
            self.assertFalse(
                distribution_strategy_context.has_distribution_strategy())
            return "foo_" + s

        replica_ctx = distribution_strategy_context.get_replica_context()
        self.assertIs(
            distribution_strategy_context._get_default_replica_context(),
            replica_ctx)
        self.assertEqual("foo_bar",
                         replica_ctx.merge_call(merge_fn, args=("bar", )))
        _assert_in_default_state(self)
Example #7
def _aggregate_across_replicas(metrics_collections, metric_value_fn, *args):
    """Aggregate metric value across replicas."""
    def fn(distribution, *a):
        """Call `metric_value_fn` in the correct control flow context."""
        if hasattr(distribution.extended, '_outer_control_flow_context'):
            # If there was an outer context captured before this method was called,
            # then we enter that context to create the metric value op. If the
            # captured context is `None`, ops.control_dependencies(None) gives the
            # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
            # captured context.
            # This special handling is needed because sometimes the metric is created
            # inside a while_loop (and perhaps a TPU rewrite context). But we don't
            # want the value op to be evaluated every step or on the TPU. So we
            # create it outside so that it can be evaluated at the end on the host,
            # once the update ops have been evaluated.

            # pylint: disable=protected-access
            if distribution.extended._outer_control_flow_context is None:
                with ops.control_dependencies(None):
                    metric_value = metric_value_fn(distribution, *a)
            else:
                distribution.extended._outer_control_flow_context.Enter()
                metric_value = metric_value_fn(distribution, *a)
                distribution.extended._outer_control_flow_context.Exit()
                # pylint: enable=protected-access
        else:
            metric_value = metric_value_fn(distribution, *a)
        if metrics_collections:
            ops.add_to_collections(metrics_collections, metric_value)
        return metric_value

    return distribution_strategy_context.get_replica_context().merge_call(
        fn, args=args)
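For reference, here is a hedged sketch of the same merge_call round trip written against the public tf.distribute API (TensorFlow 2.x); the SUM reduction is only illustrative and does not reproduce the metric logic above.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def merge_fn(strategy, per_replica_value):
  # Runs once, in cross-replica context, with the per-replica values combined.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_value, axis=None)

def replica_fn(x):
  # Runs on every replica; merge_call hops into cross-replica context.
  return tf.distribute.get_replica_context().merge_call(merge_fn, args=(x,))

total = strategy.run(replica_fn, args=(tf.constant(1.0),))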
def skip_summary():
  # If using multiple replicas in distributed strategy, skip summaries on all
  # replicas except the first one (replica_id=0).
  # TODO(priyag): Add a new optional argument that will provide multiple
  # alternatives to override default behavior. (e.g. run on last replica,
  # compute sum or mean across replicas).
  replica_context = distribution_strategy_context.get_replica_context()
  return replica_context and replica_context.replica_id > 0
Example #9
def skip_summary():
  # If using multiple replicas in distributed strategy, skip summaries on all
  # replicas except the first one (replica_id=0).
  # TODO(priyag): Add a new optional argument that will provide multiple
  # alternatives to override default behavior. (e.g. run on last replica,
  # compute sum or mean across replicas).
  replica_context = distribution_strategy_context.get_replica_context()
  # TODO(cjfj): Also check is sync group ID > 0?
  return replica_context and replica_context.replica_id_in_sync_group > 0
Example #10
def merge_grads(grads_and_vars):
    """Merge gradients from different replicas."""
    def merge_grad_fn(strategy, grads_and_vars):
        reduced_grads = strategy.batch_reduce(ds_reduce_util.ReduceOp.MEAN,
                                              grads_and_vars)
        return reduced_grads

    return distribution_strategy_context.get_replica_context().merge_call(
        merge_grad_fn, args=(grads_and_vars, ))
Example #11
def merge_grads(grads_and_vars):
    """Merge gradients from different replicas."""
    def merge_grad_fn(strategy, grads_and_vars):
        reduced_grads = strategy.batch_reduce(
            variable_scope.VariableAggregation.MEAN, grads_and_vars)
        return reduced_grads

    return distribution_strategy_context.get_replica_context().merge_call(
        merge_grad_fn, grads_and_vars)
Example #12
def merge_grads(grads_and_vars):
  """Merge gradients from different replicas."""

  def merge_grad_fn(strategy, grads_and_vars):
    reduced_grads = strategy.batch_reduce(
        ds_reduce_util.ReduceOp.MEAN, grads_and_vars)
    return reduced_grads

  return distribution_strategy_context.get_replica_context().merge_call(
      merge_grad_fn, args=(grads_and_vars,))
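The same averaging can be expressed with the public replica-context API. A hedged TensorFlow 2.x sketch, where `all_reduce` plays the role of the batch_reduce/merge_call pair above:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_step(grad):
  # Average this replica's gradient with the other replicas' gradients.
  ctx = tf.distribute.get_replica_context()
  return ctx.all_reduce(tf.distribute.ReduceOp.MEAN, grad)

mean_grad = strategy.run(replica_step, args=(tf.constant(2.0),))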
Example #13
 def merge_fn(dist, s):
   self.assertIs(
       distribution_strategy_context._get_default_distribution_strategy(),
       dist)
   self.assertIs(None, distribution_strategy_context.get_replica_context())
   self.assertIs(dist,
                 distribution_strategy_context.get_cross_replica_context())
   self.assertIs(dist,
                 distribution_strategy_context.get_distribution_strategy())
   self.assertFalse(
       distribution_strategy_context.has_distribution_strategy())
   return "foo_" + s
Example #15
def merge_update_step(update_ops, local_step):
    """Merge local step counter update from different replicas."""
    def merge_update_step_fn(strategy, update_ops, local_step):
        merged_ops = []
        for update_op in update_ops:
            merged_ops.append(strategy.group(update_op))
        with ops.control_dependencies(merged_ops):
            incre_op = local_step.assign_add(1).op
        return incre_op

    return distribution_strategy_context.get_replica_context().merge_call(
        merge_update_step_fn, update_ops, local_step)
Example #16
def merge_update_step(update_ops, local_step):
  """Merge local step counter update from different replicas."""

  def merge_update_step_fn(strategy, update_ops, local_step):
    merged_ops = []
    for update_op in update_ops:
      merged_ops.append(strategy.group(update_op))
    with ops.control_dependencies(merged_ops):
      incre_op = local_step.assign_add(1).op
    return incre_op

  return distribution_strategy_context.get_replica_context().merge_call(
      merge_update_step_fn, args=(update_ops, local_step))
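Stripped of the distribution plumbing, the helper above is the usual "group the updates, then bump a counter" pattern. A hedged graph-mode sketch, with `tf.no_op()` standing in for real update ops:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

local_step = tf.Variable(0, dtype=tf.int64)
update_ops = [tf.no_op(), tf.no_op()]          # stand-ins for real update ops
with tf.control_dependencies([tf.group(*update_ops)]):
  incre_op = local_step.assign_add(1).op       # runs only after the updates

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(incre_op)
  print(sess.run(local_step))                  # 1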
Example #17
 def run_fn():
   replica_context = distribution_strategy_context.get_replica_context()
   self.assertTrue(replica_context is not None)
   self.assertIs(None,
                 distribution_strategy_context.get_cross_replica_context())
   self.assertTrue(distribution_strategy_context.has_distribution_strategy())
   self.assertIs(dist,
                 distribution_strategy_context.get_distribution_strategy())
   self.assertEqual("foo", replica_context.merge_call(None, test_arg="foo"))
   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="bar"))
Example #19
 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, distribution_strategy_context.get_replica_context())
     self.assertIs(dist,
                   distribution_strategy_context.get_cross_replica_context())
     self.assertTrue(distribution_strategy_context.has_distribution_strategy())
     self.assertIs(dist,
                   distribution_strategy_context.get_distribution_strategy())
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
Example #21
  def testMergeCall(self):
    _assert_in_default_state(self)

    def merge_fn(dist, s):
      self.assertIs(
          distribution_strategy_context._get_default_distribution_strategy(),
          dist)
      self.assertIs(None, distribution_strategy_context.get_replica_context())
      self.assertIs(dist,
                    distribution_strategy_context.get_cross_replica_context())
      self.assertIs(dist,
                    distribution_strategy_context.get_distribution_strategy())
      self.assertFalse(
          distribution_strategy_context.has_distribution_strategy())
      return "foo_" + s

    replica_ctx = distribution_strategy_context.get_replica_context()
    self.assertIs(distribution_strategy_context._get_default_replica_context(),
                  replica_ctx)
    self.assertEqual("foo_bar", replica_ctx.merge_call(merge_fn, "bar"))
    _assert_in_default_state(self)
Example #22
def skip_summary():
  """Determines if summary should be skipped.

  If using multiple replicas in distributed strategy, skip summaries on all
  replicas except the first one (replica_id=0).

  Returns:
    True if the summary is skipped; False otherwise.
  """

  # TODO(priyag): Add a new optional argument that will provide multiple
  # alternatives to override default behavior. (e.g. run on last replica,
  # compute sum or mean across replicas).
  replica_context = distribution_strategy_context.get_replica_context()
  if not replica_context:
    return False
  # TODO(b/118385803): when replica_id of _TPUReplicaContext is properly
  # initialized, remember to change here as well.
  replica_id = replica_context.replica_id_in_sync_group
  if isinstance(replica_id, ops.Tensor):
    replica_id = tensor_util.constant_value(replica_id)
  return replica_id and replica_id > 0
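A hedged usage sketch: guard summary writes with the helper above so that only the first replica records them. The writer path and the `maybe_log` name are illustrative, not part of the original code.

import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/logs")

def maybe_log(step, loss):
  if skip_summary():                 # non-zero replicas skip the write
    return
  with writer.as_default():
    tf.summary.scalar("loss", loss, step=step)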
Example #23
  def decorated(metric_obj, *args):
    """Decorated function with merge_call."""
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context is None:  # if in cross replica context already
      result_t = result_fn(*args)
    else:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
      # with distribution object as the first parameter. We create a wrapper
      # here so that the result function need not have that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *args):
        # We will get `PerDevice` merge function. Taking the first one as all
        # are identical copies of the function that we had passed below.
        return distribution.unwrap(merge_fn)[0](*args)

      # Wrapping result in merge_call. merge_call is used when we want to leave
      # replica mode and compute a value in cross replica mode.
      result_t = replica_context.merge_call(merge_fn_wrapper, result_fn, *args)
    check_is_tensor_or_operation(result_t,
                                 'Metric {0}\'s result'.format(metric_obj.name))
    return result_t
Example #24
  def decorated(metric_obj, *args):
    """Decorated function with merge_call."""
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context is None:  # if in cross replica context already
      result_t = result_fn(*args)
    else:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
      # with distribution object as the first parameter. We create a wrapper
      # here so that the result function need not have that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *args):
        # We will get `PerDevice` merge function. Taking the first one as all
        # are identical copies of the function that we had passed below.
        return distribution.unwrap(merge_fn)[0](*args)

      # Wrapping result in merge_call. merge_call is used when we want to leave
      # replica mode and compute a value in cross replica mode.
      result_t = replica_context.merge_call(merge_fn_wrapper, result_fn, *args)
    check_is_tensor_or_operation(result_t,
                                 'Metric {0}\'s result'.format(metric_obj.name))
    return result_t
Example #25
def f1_score(labels, predictions, weights=None, num_thresholds=200,
             metrics_collections=None, updates_collections=None, name=None):
  """Computes the approximately best F1-score across different thresholds.

  The f1_score function applies a range of thresholds to the predictions to
  convert them from [0, 1] to bool. Precision and recall are computed by
  comparing them to the labels. The F1-Score is then defined as
  2 * precision * recall / (precision + recall). The best one across the
  thresholds is returned.

  Disclaimer: In practice it may be desirable to choose the best threshold on
  the validation set and evaluate the F1 score with this threshold on a
  separate test set. Or it may be desirable to use a fixed threshold (e.g. 0.5).

  This function internally creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the pairs of recall and precision values for a linearly spaced set of
  thresholds from which the best f1-score is derived.

  This value is ultimately returned as `f1-score`, an idempotent operation that
  computes the F1-score (computed using the aforementioned variables). The
  `num_thresholds` variable controls the degree of discretization with larger
  numbers of thresholds more closely approximating the true best F1-score.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the F1-score.

  Example usage with a custom estimator:
  def model_fn(features, labels, mode):
    predictions = make_predictions(features)
    loss = make_loss(predictions, labels)
    train_op = tf.contrib.training.create_train_op(
          total_loss=loss,
          optimizer='Adam')
    eval_metric_ops = {'f1': f1_score(labels, predictions)}
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs)
  estimator = tf.estimator.Estimator(model_fn=model_fn)

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `f1_score` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    f1_score: A scalar `Tensor` representing the current best f1-score across
      different thresholds.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches the `f1_score`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(
      name, 'f1', (labels, predictions, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions=predictions, labels=labels, weights=weights)
    # To account for floating point imprecisions / avoid division by zero.
    epsilon = 1e-7
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [0.0 - epsilon] + thresholds + [1.0 + epsilon]

    # Confusion matrix.
    values, update_ops = metrics_impl._confusion_matrix_at_thresholds(  # pylint: disable=protected-access
        labels, predictions, thresholds, weights, includes=('tp', 'fp', 'fn'))

    # Compute precision and recall at various thresholds.
    def compute_best_f1_score(tp, fp, fn, name):
      precision_at_t = math_ops.div(tp, epsilon + tp + fp,
                                    name='precision_' + name)
      recall_at_t = math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
      # Compute F1 score.
      f1_at_thresholds = (
          2.0 * precision_at_t * recall_at_t /
          (precision_at_t + recall_at_t + epsilon))
      return math_ops.reduce_max(f1_at_thresholds)

    def f1_across_replicas(_, values):
      best_f1 = compute_best_f1_score(tp=values['tp'], fp=values['fp'],
                                      fn=values['fn'], name='value')
      if metrics_collections:
        ops.add_to_collections(metrics_collections, best_f1)
      return best_f1

    best_f1 = distribution_strategy_context.get_replica_context().merge_call(
        f1_across_replicas, args=(values,))

    update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'],
                                      fn=update_ops['fn'], name='update')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return best_f1, update_op
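Like the other tf.metrics-style functions, `f1_score` returns a (value, update_op) pair. A hedged TF1-style usage sketch, assuming the function above is importable; the random data only illustrates the streaming pattern.

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

labels = tf.placeholder(tf.bool, [None])
predictions = tf.placeholder(tf.float32, [None])
f1, update_op = f1_score(labels, predictions, num_thresholds=20)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())   # metric variables are local
  for _ in range(10):                          # stream batches through update_op
    sess.run(update_op, {labels: np.random.rand(32) > 0.5,
                         predictions: np.random.rand(32)})
  print(sess.run(f1))                          # best F1 across thresholds so far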
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
    """Initializes current variables with tensors loaded from given checkpoint.

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.variable_scope('new_scope_1'):
    var1 = tf.get_variable('var1', shape=[20, 2],
                           initializer=tf.zeros_initializer())
  with tf.variable_scope('new_scope_2'):
    var2 = tf.get_variable('var2', shape=[50, 4],
                           initializer=tf.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in default graph).

  Raises:
    tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
    ValueError: If missing variables in current graph.
  """
    if distribution_strategy_context.get_cross_replica_context():
        _init_from_checkpoint(None, ckpt_dir_or_file, assignment_map)
    else:
        distribution_strategy_context.get_replica_context().merge_call(
            _init_from_checkpoint, ckpt_dir_or_file, assignment_map)
      def model_fn():
        if 'CPU' in compute_device:
          replica_compute_device = '/device:CPU:0'
        else:
          replica_compute_device = (
              '/device:GPU:%d' %
              distribution_strategy_context.get_replica_context().replica_id)
        replica_compute_device = device_util.canonicalize(
            replica_compute_device)

        if 'CPU' in variable_device:
          replica_variable_device = '/device:CPU:0'
        else:
          replica_variable_device = (
              '/device:GPU:%d' %
              distribution_strategy_context.get_replica_context().replica_id)
        replica_variable_device = device_util.canonicalize(
            replica_variable_device)

        a = constant_op.constant(1.0)
        b = constant_op.constant(2.0)
        c = a + b
        self.assertEqual(a.device, replica_compute_device)
        self.assertEqual(b.device, replica_compute_device)
        self.assertEqual(c.device, replica_compute_device)

        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/device:GPU:2'):
          x = variable_scope.get_variable(
              'x', initializer=10.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          x_add = x.assign_add(c)
          e = a + c
        self.assertEqual(
            device_util.canonicalize(x.device), replica_variable_device)
        self.assertEqual(x_add.device, x.device)
        self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2'))

        # The colocate_vars_with can override the distribution's device.
        with d.colocate_vars_with(x):
          y = variable_scope.get_variable(
              'y', initializer=20.0,
              aggregation=variable_scope.VariableAggregation.SUM)
        # We add an identity here to avoid complaints about summing
        # non-distributed values.
        y_add = y.assign_add(array_ops.identity(x_add))
        self.assertEqual(
            device_util.canonicalize(y.device), replica_variable_device)
        self.assertEqual(y_add.device, y.device)
        self.assertEqual(y.device, x.device)

        z = variable_scope.get_variable(
            'z', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertEqual(
            device_util.canonicalize(z.device), replica_variable_device)

        with ops.control_dependencies([y_add]):
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          z_add = z.assign_add(array_ops.identity(y))
        with ops.control_dependencies([z_add]):
          f = z + c
        self.assertEqual(f.device, replica_compute_device)

        # The device scope would merge with the default worker device.
        with ops.device('/CPU:1'):
          g = e + 1.0
        self.assertEqual(g.device, device_util.canonicalize('/device:CPU:1'))

        # The ops.colocate_with will be ignored when defining a variable but not
        # for a normal tensor.
        with ops.colocate_with(x):
          u = variable_scope.get_variable('u', initializer=30.0)
          h = f + 1.0
        self.assertEqual(
            device_util.canonicalize(u.device), replica_variable_device)
        self.assertEqual(
            device_util.canonicalize(x.device),
            device_util.canonicalize(h.device))
        return y_add, z_add, f
def _replica_id():
    # TODO(cjfj): Return `replica_id_...` directly, once it is a `Tensor`.
    return constant_op.constant(
        ds_context.get_replica_context().replica_id_in_sync_group)
            def model_fn():
                if num_gpus == 0:
                    last_part_device = 'device:CPU:0'
                else:
                    last_part_device = ('device:GPU:%d' %
                                        ds_context.get_replica_context(
                                        ).replica_id_in_sync_group)

                a = constant_op.constant(1.0)
                b = constant_op.constant(2.0)
                c = a + b
                self.assertEqual(a.device,
                                 worker_device + '/' + last_part_device)
                self.assertEqual(b.device,
                                 worker_device + '/' + last_part_device)
                self.assertEqual(c.device,
                                 worker_device + '/' + last_part_device)

                # The device scope is ignored for variables but not for normal ops.
                with ops.device('/job:worker/task:0'):
                    x = variable_scope.get_variable(
                        'x',
                        initializer=10.0,
                        aggregation=variable_scope.VariableAggregation.SUM)
                    x_add = x.assign_add(c)
                    e = a + c
                # The variable x is on task 1 since the device_function has been
                # called once before the model_fn.
                self.assertEqual(x.device, '/job:ps/task:1')
                self.assertEqual(x_add.device, x.device)
                self.assertEqual(
                    e.device,
                    '/job:worker/replica:0/task:0/%s' % last_part_device)

                # The colocate_vars_with can override the distribution's device.
                with d.colocate_vars_with(x):
                    y = variable_scope.get_variable(
                        'y',
                        initializer=20.0,
                        aggregation=variable_scope.VariableAggregation.SUM)
                # We add an identity here to avoid complaints about summing
                # non-distributed values.
                y_add = y.assign_add(array_ops.identity(x_add))
                self.assertEqual(y.device, '/job:ps/task:1')
                self.assertEqual(y_add.device, y.device)
                self.assertEqual(y.device, x.device)

                z = variable_scope.get_variable(
                    'z',
                    initializer=10.0,
                    aggregation=variable_scope.VariableAggregation.SUM)
                self.assertEqual(z.device, '/job:ps/task:0')
                self.assertNotEqual(z.device, x.device)

                with ops.control_dependencies([y_add]):
                    # We add an identity here to avoid complaints about summing
                    # non-distributed values.
                    z_add = z.assign_add(array_ops.identity(y))
                with ops.control_dependencies([z_add]):
                    f = z + c
                self.assertEqual(f.device,
                                 worker_device + '/' + last_part_device)

                # The device scope would merge with the default worker device.
                with ops.device('/CPU:1'):
                    g = e + 1.0
                self.assertEqual(g.device, worker_device + '/device:CPU:1')

                # The ops.colocate_with will be ignored when defining a variable but not
                # for a normal tensor.
                with ops.colocate_with(x):
                    u = variable_scope.get_variable('u', initializer=30.0)
                    v = variable_scope.get_variable('v', initializer=30.0)
                    h = f + 1.0
                self.assertIn('/job:ps/', u.device)
                self.assertIn('/job:ps/', v.device)
                # u and v are on different parameter servers.
                self.assertTrue(u.device != x.device or v.device != x.device)
                self.assertTrue(u.device == x.device or v.device == x.device)
                # Here h is not on one worker. Note h.device is canonical while
                # x.device is not.
                self.assertIn('/job:ps/', h.device)
                return y_add, z_add, f
def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
  """Compute the moving average of a variable.

  The moving average of 'variable' updated with 'value' is:
    variable * decay + value * (1 - decay)

  The returned Operation sets 'variable' to the newly computed moving average,
  by performing this subtraction:
     variable -= (1 - decay) * (variable - value)

  Since variables that are initialized to a `0` value will be `0` biased,
  `zero_debias` optionally enables scaling by the mathematically correct
  debiasing factor of
    1 - decay ** num_updates
  See `ADAM: A Method for Stochastic Optimization` Section 3 for more details
  (https://arxiv.org/abs/1412.6980).

  The names of the debias shadow variables, by default, include both the scope
  they were created in and the scope of the variables they debias. They are also
  given a uniquifying-suffix.

  E.g.:

  ```
    with tf.variable_scope('scope1'):
      with tf.variable_scope('scope2'):
        var = tf.get_variable('foo')
        update_1 = tf.assign_moving_average(var, 0.0, 1.0)
        update_2 = tf.assign_moving_average(var, 0.0, 0.9)

    # var.name: 'scope1/scope2/foo'
    # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
    #                   'scope1/scope2/scope1/scope2/foo/biased_1'
  ```

  Args:
    variable: A Variable.
    value: A tensor with the same shape as 'variable'.
    decay: A float Tensor or float value.  The moving average decay.
    zero_debias: A python bool. If true, assume the variable is 0-initialized
      and unbias it, as in https://arxiv.org/abs/1412.6980. See docstring in
      `_zero_debias` for more details.
    name: Optional name of the returned operation.

  Returns:
    A tensor which if evaluated will compute and return the new moving average.
  """
  def update_fn(v, value, decay=decay):
    decay = ops.convert_to_tensor(1.0 - decay, name="decay")
    if decay.dtype != v.dtype.base_dtype:
      decay = math_ops.cast(decay, v.dtype.base_dtype)
    if zero_debias:
      update_delta = _zero_debias(v, value, decay)
    else:
      update_delta = (v - value) * decay
    return state_ops.assign_sub(v, update_delta, name=scope)

  with ops.name_scope(name, "AssignMovingAvg",
                      [variable, value, decay]) as scope:
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context:
      # In a replica context, we update variable using the mean of value across
      # replicas.
      def merge_fn(strategy, v, value):
        value = strategy.extended.reduce_to(
            ds_reduce_util.ReduceOp.MEAN, value, v)
        return strategy.update(v, update_fn, value)

      return replica_context.merge_call(merge_fn, args=(variable, value))
    else:
      strategy = distribution_strategy_context.get_cross_replica_context()
      return strategy.update(variable, update_fn, value)
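A small numeric check of the update rule documented above, variable -= (1 - decay) * (variable - value), equivalently variable * decay + value * (1 - decay), ignoring zero-debiasing:

variable, value, decay = 10.0, 4.0, 0.9
variable -= (1 - decay) * (variable - value)
print(variable)   # ~9.4, i.e. 10.0 * 0.9 + 4.0 * 0.1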
            def model_fn():
                if 'CPU' in compute_device:
                    replica_compute_device = '/device:CPU:0'
                else:
                    replica_compute_device = ('/device:GPU:%d' %
                                              ds_context.get_replica_context(
                                              ).replica_id_in_sync_group)
                replica_compute_device = device_util.canonicalize(
                    replica_compute_device)

                if 'CPU' in variable_device:
                    replica_variable_device = '/device:CPU:0'
                else:
                    replica_variable_device = ('/device:GPU:%d' %
                                               ds_context.get_replica_context(
                                               ).replica_id_in_sync_group)
                replica_variable_device = device_util.canonicalize(
                    replica_variable_device)

                a = constant_op.constant(1.0)
                b = constant_op.constant(2.0)
                c = a + b
                self.assertEqual(a.device, replica_compute_device)
                self.assertEqual(b.device, replica_compute_device)
                self.assertEqual(c.device, replica_compute_device)

                # The device scope is ignored for variables but not for normal ops.
                with ops.device('/device:GPU:2'):
                    x = variable_scope.get_variable(
                        'x',
                        initializer=10.0,
                        aggregation=variable_scope.VariableAggregation.SUM)
                    x_add = x.assign_add(c)
                    e = a + c
                self.assertEqual(device_util.canonicalize(x.device),
                                 replica_variable_device)
                self.assertEqual(x_add.device, x.device)
                self.assertEqual(e.device,
                                 device_util.canonicalize('/device:GPU:2'))

                # The colocate_vars_with can override the distribution's device.
                with d.colocate_vars_with(x):
                    y = variable_scope.get_variable(
                        'y',
                        initializer=20.0,
                        aggregation=variable_scope.VariableAggregation.SUM)
                # We add an identity here to avoid complaints about summing
                # non-distributed values.
                y_add = y.assign_add(array_ops.identity(x_add))
                self.assertEqual(device_util.canonicalize(y.device),
                                 replica_variable_device)
                self.assertEqual(y_add.device, y.device)
                self.assertEqual(y.device, x.device)

                z = variable_scope.get_variable(
                    'z',
                    initializer=10.0,
                    aggregation=variable_scope.VariableAggregation.SUM)
                self.assertEqual(device_util.canonicalize(z.device),
                                 replica_variable_device)

                with ops.control_dependencies([y_add]):
                    # We add an identity here to avoid complaints about summing
                    # non-distributed values.
                    z_add = z.assign_add(array_ops.identity(y))
                with ops.control_dependencies([z_add]):
                    f = z + c
                self.assertEqual(f.device, replica_compute_device)

                # The device scope would merge with the default worker device.
                with ops.device('/CPU:1'):
                    g = e + 1.0
                self.assertEqual(g.device,
                                 device_util.canonicalize('/device:CPU:1'))

                # The ops.colocate_with will be ignored when defining a variable but not
                # for a normal tensor.
                with ops.colocate_with(x):
                    u = variable_scope.get_variable('u', initializer=30.0)
                    h = f + 1.0
                self.assertEqual(device_util.canonicalize(u.device),
                                 replica_variable_device)
                self.assertEqual(device_util.canonicalize(x.device),
                                 device_util.canonicalize(h.device))
                return y_add, z_add, f
def _replica_id():
  # TODO(cjfj): Return `replica_id` directly, once it is a `Tensor`.
  return constant_op.constant(
      distribution_strategy_context.get_replica_context().replica_id)
def _get_replica_id_integer():
  replica_id = ds_context.get_replica_context().replica_id_in_sync_group
  if isinstance(replica_id, ops.Tensor):
    replica_id = tensor_util.constant_value(replica_id)
  return replica_id
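With the public API the same id can be read directly from the replica context. A hedged TensorFlow 2.x sketch:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_fn():
  # 0 on the first replica, 1 on the second, and so on.
  return tf.distribute.get_replica_context().replica_id_in_sync_group

ids = strategy.run(replica_fn)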
Example #34
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Initializes current variables with tensors loaded from given checkpoint.

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.variable_scope('new_scope_1'):
    var1 = tf.get_variable('var1', shape=[20, 2],
                           initializer=tf.zeros_initializer())
  with tf.variable_scope('new_scope_2'):
    var2 = tf.get_variable('var2', shape=[50, 4],
                           initializer=tf.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in default graph).

  Raises:
    tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
    ValueError: If missing variables in current graph.
  """
  if distribution_strategy_context.get_cross_replica_context():
    _init_from_checkpoint(None, ckpt_dir_or_file, assignment_map)
  else:
    distribution_strategy_context.get_replica_context().merge_call(
        _init_from_checkpoint, ckpt_dir_or_file, assignment_map)
def _merge_call_merge_raises_fn():
  ds_context.get_replica_context().merge_call(_call_merge_raises_fn)
 def mark_devices_fn():
   replica_id = self.evaluate(
       ds_context.get_replica_context().replica_id_in_sync_group)
   self.assertLess(replica_id, len(d.extended.worker_devices))
   self.assertFalse(expected_devices[replica_id])
   expected_devices[replica_id] = True
Example #37
 def mark_devices_fn():
   replica_id = (
       distribution_strategy_context.get_replica_context().replica_id)
   self.assertLess(replica_id, len(d.worker_devices))
   self.assertFalse(expected_devices[replica_id])
   expected_devices[replica_id] = True
def f1_score(labels, predictions, weights=None, num_thresholds=200,
             metrics_collections=None, updates_collections=None, name=None):
  """Computes the approximately best F1-score across different thresholds.

  The f1_score function applies a range of thresholds to the predictions to
  convert them from [0, 1] to bool. Precision and recall are computed by
  comparing them to the labels. The F1-Score is then defined as
  2 * precision * recall / (precision + recall). The best one across the
  thresholds is returned.

  Disclaimer: In practice it may be desirable to choose the best threshold on
  the validation set and evaluate the F1 score with this threshold on a
  separate test set. Or it may be desirable to use a fixed threshold (e.g. 0.5).

  This function internally creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the pairs of recall and precision values for a linearly spaced set of
  thresholds from which the best f1-score is derived.

  This value is ultimately returned as `f1-score`, an idempotent operation that
  computes the F1-score (computed using the aforementioned variables). The
  `num_thresholds` variable controls the degree of discretization with larger
  numbers of thresholds more closely approximating the true best F1-score.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the F1-score.

  Example usage with a custom estimator:
  def model_fn(features, labels, mode):
    predictions = make_predictions(features)
    loss = make_loss(predictions, labels)
    train_op = tf.contrib.training.create_train_op(
          total_loss=loss,
          optimizer='Adam')
    eval_metric_ops = {'f1': f1_score(labels, predictions)}
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs)
  estimator = tf.estimator.Estimator(model_fn=model_fn)

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `f1_score` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    f1_score: A scalar `Tensor` representing the current best f1-score across
      different thresholds.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches the `f1_score`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(
      name, 'f1', (labels, predictions, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions=predictions, labels=labels, weights=weights)
    # To account for floating point imprecisions / avoid division by zero.
    epsilon = 1e-7
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [0.0 - epsilon] + thresholds + [1.0 + epsilon]

    # Confusion matrix.
    values, update_ops = metrics_impl._confusion_matrix_at_thresholds(  # pylint: disable=protected-access
        labels, predictions, thresholds, weights, includes=('tp', 'fp', 'fn'))

    # Compute precision and recall at various thresholds.
    def compute_best_f1_score(tp, fp, fn, name):
      precision_at_t = math_ops.div(tp, epsilon + tp + fp,
                                    name='precision_' + name)
      recall_at_t = math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
      # Compute F1 score.
      f1_at_thresholds = (
          2.0 * precision_at_t * recall_at_t /
          (precision_at_t + recall_at_t + epsilon))
      return math_ops.reduce_max(f1_at_thresholds)

    def f1_across_replicas(_, values):
      best_f1 = compute_best_f1_score(tp=values['tp'], fp=values['fp'],
                                      fn=values['fn'], name='value')
      if metrics_collections:
        ops.add_to_collections(metrics_collections, best_f1)
      return best_f1

    best_f1 = distribution_strategy_context.get_replica_context().merge_call(
        f1_across_replicas, values)

    update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'],
                                      fn=update_ops['fn'], name='update')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return best_f1, update_op
def _merge_raises_fn():
  ds_context.get_replica_context().merge_call(_raise_exception_fn)
Example #40
def _get_replica_id_integer():
    replica_id = ds_context.get_replica_context().replica_id_in_sync_group
    if isinstance(replica_id, ops.Tensor):
        replica_id = tensor_util.constant_value(replica_id)
    return replica_id
def _merge_call_merge_raises_fn():
    distribution_strategy_context.get_replica_context().merge_call(
        _call_merge_raises_fn)
def _replica_id():
  replica_id = ds_context.get_replica_context().replica_id_in_sync_group
  if not isinstance(replica_id, ops.Tensor):
    replica_id = constant_op.constant(replica_id)
  return replica_id
Example #44
  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Default to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If you should use `_distributed_apply()` instead.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers.  It relies on the subclass implementing the following
    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

    # Handle DistributionStrategy case.
    if distribute_ctx.get_cross_replica_context():
      raise RuntimeError("Use `_distributed_apply()` instead of "
                         "`apply_gradients()` in a cross-replica context.")
    # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
    # always calling _distributed_apply(), using the default distribution
    # as needed.
    if distribute_ctx.has_distribution_strategy():
      grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
      return distribute_ctx.get_replica_context().merge_call(
          self._distributed_apply, args=(grads_and_vars, global_step, name))

    # No DistributionStrategy case.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    converted_grads_and_vars = []
    for g, v in grads_and_vars:
      if g is not None:
        try:
          # Convert the grad to Tensor or IndexedSlices if necessary.
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)
      converted_grads_and_vars.append((g, v, p))

    converted_grads_and_vars = tuple(converted_grads_and_vars)
    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    if not var_list:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v, _ in converted_grads_and_vars],))
    with ops.init_scope():
      self._create_slots(var_list)
    update_ops = []
    with ops.name_scope(name, self._name) as name:
      self._prepare()
      for grad, var, processor in converted_grads_and_vars:
        if grad is None:
          continue
        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        if context.executing_eagerly() or isinstance(
            var,
            resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
          scope_name = ""
        else:
          scope_name = var.op.name
        with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
          update_ops.append(processor.update_op(self, grad))
      if global_step is None:
        apply_updates = self._finish(update_ops, name)
      else:
        with ops.control_dependencies([self._finish(update_ops, "update")]):
          with ops.colocate_with(global_step):
            if isinstance(global_step, resource_variable_ops.ResourceVariable):
              # TODO(apassos): the implicit read in assign_add is slow; consider
              # making it less so.
              apply_updates = resource_variable_ops.assign_add_variable_op(
                  global_step.handle,
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  name=name)
            else:
              apply_updates = state_ops.assign_add(global_step, 1, name=name)

      if not context.executing_eagerly():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates
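A hedged usage sketch of the compute_gradients/apply_gradients split that this method implements, using the TF1-style optimizer API and an illustrative one-variable loss:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

w = tf.Variable(3.0)
loss = tf.square(w - 1.0)
global_step = tf.train.get_or_create_global_step()

opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = opt.compute_gradients(loss, var_list=[w])
train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)
  print(sess.run([w, global_step]))   # w moves toward 1.0, global_step becomes 1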
 def mark_devices_fn():
     replica_id = (
         ds_context.get_replica_context().replica_id_in_sync_group)
     self.assertLess(replica_id, len(d.worker_devices))
     self.assertFalse(expected_devices[replica_id])
     expected_devices[replica_id] = True
def _merge_raises_fn():
    ds_context.get_replica_context().merge_call(_raise_exception_fn)
Example #47
def _merge_call_merge_raises_fn():
  distribution_strategy_context.get_replica_context().merge_call(
      _call_merge_raises_fn)
def _merge_call_merge_raises_fn():
    ds_context.get_replica_context().merge_call(_call_merge_raises_fn)
def _replica_id():
    replica_id = ds_context.get_replica_context().replica_id_in_sync_group
    if not isinstance(replica_id, ops.Tensor):
        replica_id = constant_op.constant(replica_id)
    return replica_id
      def model_fn():
        if num_gpus == 0:
          last_part_device = 'device:CPU:0'
        else:
          last_part_device = (
              'device:GPU:%d' %
              distribution_strategy_context.get_replica_context().replica_id)

        a = constant_op.constant(1.0)
        b = constant_op.constant(2.0)
        c = a + b
        self.assertEqual(a.device, worker_device + '/' + last_part_device)
        self.assertEqual(b.device, worker_device + '/' + last_part_device)
        self.assertEqual(c.device, worker_device + '/' + last_part_device)

        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/job:worker/task:0'):
          x = variable_scope.get_variable(
              'x', initializer=10.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          x_add = x.assign_add(c)
          e = a + c
        # The variable x is on task 1 since the device_function has been
        # called once before the model_fn.
        self.assertEqual(x.device, '/job:ps/task:1')
        self.assertEqual(x_add.device, x.device)
        self.assertEqual(e.device,
                         '/job:worker/replica:0/task:0/%s' % last_part_device)

        # The colocate_vars_with can override the distribution's device.
        with d.colocate_vars_with(x):
          y = variable_scope.get_variable(
              'y', initializer=20.0,
              aggregation=variable_scope.VariableAggregation.SUM)
        # We add an identity here to avoid complaints about summing
        # non-distributed values.
        y_add = y.assign_add(array_ops.identity(x_add))
        self.assertEqual(y.device, '/job:ps/task:1')
        self.assertEqual(y_add.device, y.device)
        self.assertEqual(y.device, x.device)

        z = variable_scope.get_variable(
            'z', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertEqual(z.device, '/job:ps/task:0')
        self.assertNotEqual(z.device, x.device)

        with ops.control_dependencies([y_add]):
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          z_add = z.assign_add(array_ops.identity(y))
        with ops.control_dependencies([z_add]):
          f = z + c
        self.assertEqual(f.device, worker_device + '/' + last_part_device)

        # The device scope would merge with the default worker device.
        with ops.device('/CPU:1'):
          g = e + 1.0
        self.assertEqual(g.device, worker_device + '/device:CPU:1')

        # The ops.colocate_with will be ignored when defining a variable but not
        # for a normal tensor.
        with ops.colocate_with(x):
          u = variable_scope.get_variable('u', initializer=30.0)
          v = variable_scope.get_variable('v', initializer=30.0)
          h = f + 1.0
        self.assertIn('/job:ps/', u.device)
        self.assertIn('/job:ps/', v.device)
        # u and v are on different parameter servers.
        self.assertTrue(u.device != x.device or v.device != x.device)
        self.assertTrue(u.device == x.device or v.device == x.device)
        # Here h is not on one worker. Note h.device is canonical while
        # x.device is not.
        self.assertIn('/job:ps/', h.device)
        return y_add, z_add, f