def model_fn():
   with ops.name_scope(None, "foo"):
     a = constant_op.constant(1.0, name="a")
     distribution_strategy_context.get_tower_context().merge_call(
         lambda _: _)
     b = constant_op.constant(2.0, name="b")
   return a, b
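The `merge_call(lambda _: _)` call above is essentially a synchronization barrier: it temporarily leaves tower (replica) context, runs a no-op merge function once in cross-tower context, and then resumes the tower. Below is a minimal, hedged sketch of how such a `model_fn` might be driven, assuming the TF 1.x contrib-era API (`tf.contrib.distribute.MirroredStrategy` and `call_for_each_tower`); exact names and arguments vary between 1.x releases.

import tensorflow as tf

# Hedged sketch, not part of the original example. Assumes two GPUs are
# available; the contrib-era constructor also accepted a devices= list.
dist = tf.contrib.distribute.MirroredStrategy(num_gpus=2)
with dist.scope():
  # Runs model_fn once per tower; merge_call acts as the rendezvous point
  # where all tower threads meet before continuing.
  results = dist.call_for_each_tower(model_fn)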
Example #2
    def _assign_func(self, *args, **kwargs):
        f = kwargs.pop("f")
        if distribution_strategy_context.get_cross_tower_context():
            update_device = distribute_lib.get_update_device()
            if update_device is not None:
                # We are calling an assign function in an update context.
                return f(self._v, *args, **kwargs)

            # We are calling an assign function in cross tower context, wrap it in an
            # update call.
            return distribution_strategy_context.get_distribution_strategy(
            ).update(self, f, *args, **kwargs)
        else:
            assert distribution_strategy_context.get_tower_context()
            # We are calling an assign function in tower context.
            # We reduce the value we want to assign/add/sub. More details about how we
            # handle the different use cases can be found in the _reduce method.
            # We call the function with the reduced value.
            if self._aggregation == vs.VariableAggregation.NONE:
                raise ValueError(
                    "You must specify an aggregation method to update a "
                    "a variable in Tower Context.")

            def merge_fn(strategy, value, *other_args, **other_kwargs):
                return strategy.update(
                    self, f,
                    strategy.reduce(aggregation=self._aggregation,
                                    value=value,
                                    destinations=self), *other_args,
                    **other_kwargs)

            return distribution_strategy_context.get_tower_context(
            ).merge_call(merge_fn, *args, **kwargs)
Example #3
  def _assign_func(self, *args, **kwargs):
    f = kwargs.pop("f")
    if distribution_strategy_context.get_cross_tower_context():
      update_device = distribute_lib.get_update_device()
      if update_device is not None:
        # We are calling an assign function in an update context.
        return f(self._v, *args, **kwargs)

      # We are calling an assign function in cross tower context, wrap it in an
      # update call.
      return distribution_strategy_context.get_distribution_strategy().update(
          self, f, *args, **kwargs)
    else:
      assert distribution_strategy_context.get_tower_context()
      # We are calling an assign function in tower context.
      # We reduce the value we want to assign/add/sub. More details about how we
      # handle the different use cases can be found in the _reduce method.
      # We call the function with the reduced value.
      if self._aggregation == vs.VariableAggregation.NONE:
        raise ValueError("You must specify an aggregation method to update a "
                         "a variable in Tower Context.")

      def merge_fn(strategy, value, *other_args, **other_kwargs):
        return strategy.update(
            self, f,
            strategy.reduce(
                aggregation=self._aggregation, value=value, destinations=self),
            *other_args, **other_kwargs)

      return distribution_strategy_context.get_tower_context().merge_call(
          merge_fn, *args, **kwargs)
Example #4
 def model_fn():
   with ops.name_scope(None, "foo"):
     a = constant_op.constant(1.0, name="a")
     distribution_strategy_context.get_tower_context().merge_call(
         lambda _: _)
     b = constant_op.constant(2.0, name="b")
   return a, b
Example #5
 def model_fn():
   vs = []
   vs.append(variable_scope.variable(1.0, name="foo/bar"))
   vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
   vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
   vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
   distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
   return vs
 def model_fn():
   vs = []
   vs.append(variable_scope.variable(1.0, name="foo/bar"))
   vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
   vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
   vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
   distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
   return vs
Example #7
 def set_non_tensor_output(self, name, output):
   """Set `output` with `name` to be captured as a non tensor output."""
   if distribution_strategy_context.get_cross_tower_context():
     self._non_tensor_outputs[name] = output
   else:
     def merge_fn(distribution, value):
       # NOTE(priyag): For non tensor outputs, we simply return all the values
       # in a list as aggregation doesn't make sense on non tensors.
       self._non_tensor_outputs[name] = distribution.unwrap(value)
     distribution_strategy_context.get_tower_context().merge_call(
         merge_fn, output)
Example #8
 def set_non_tensor_output(self, name, output):
   """Set `output` with `name` to be captured as a non tensor output."""
   if distribution_strategy_context.get_cross_tower_context():
     self._non_tensor_outputs[name] = output
   else:
     def merge_fn(distribution, value):
       # NOTE(priyag): For non tensor outputs, we simply return all the values
       # in a list as aggregation doesn't make sense on non tensors.
       self._non_tensor_outputs[name] = distribution.unwrap(value)
     distribution_strategy_context.get_tower_context().merge_call(
         merge_fn, output)
    def model_fn(device_id):
      assert isinstance(device_id, int)
      def thread_creator_fn(next_creator, *args, **kwargs):
        return next_creator(*args, **kwargs) + ":thread_" + str(device_id)

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        distribution_strategy_context.get_tower_context().merge_call(
            lambda _: _)
      return v
 def model_fn(features):
   with variable_scope.variable_scope("common"):
     layer1 = core.Dense(1)
     layer1(features)
     layer2 = core.Dense(1)
     layer2(features)
     # This will pause the current thread, and execute the other thread.
     distribution_strategy_context.get_tower_context().merge_call(
         lambda _: _)
     layer3 = core.Dense(1)
     layer3(features)
     return [(layer1.kernel, layer1.bias),
             (layer2.kernel, layer2.bias),
             (layer3.kernel, layer3.bias)]
Example #11
 def model_fn(features):
   with variable_scope.variable_scope("common"):
     layer1 = core.Dense(1)
     layer1(features)
     layer2 = core.Dense(1)
     layer2(features)
     # This will pause the current thread, and execute the other thread.
     distribution_strategy_context.get_tower_context().merge_call(
         lambda _: _)
     layer3 = core.Dense(1)
     layer3(features)
     return [(layer1.kernel, layer1.bias),
             (layer2.kernel, layer2.bias),
             (layer3.kernel, layer3.bias)]
Example #12
def _aggregate_across_towers(metrics_collections, metric_value_fn, *args):
    """Aggregate metric value across towers."""
    def fn(distribution, *a):
        """Call `metric_value_fn` in the correct control flow context."""
        if hasattr(distribution, '_outer_control_flow_context'):
            # If there was an outer context captured before this method was called,
            # then we enter that context to create the metric value op. If the
            # captured context is `None`, ops.control_dependencies(None) gives the
            # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
            # captured context.
            # This special handling is needed because sometimes the metric is created
            # inside a while_loop (and perhaps a TPU rewrite context). But we don't
            # want the value op to be evaluated every step or on the TPU. So we
            # create it outside so that it can be evaluated at the end on the host,
            # once the update ops have been evaluated.

            # pylint: disable=protected-access
            if distribution._outer_control_flow_context is None:
                with ops.control_dependencies(None):
                    metric_value = metric_value_fn(distribution, *a)
            else:
                distribution._outer_control_flow_context.Enter()
                metric_value = metric_value_fn(distribution, *a)
                distribution._outer_control_flow_context.Exit()
                # pylint: enable=protected-access
        else:
            metric_value = metric_value_fn(distribution, *a)
        if metrics_collections:
            ops.add_to_collections(metrics_collections, metric_value)
        return metric_value

    return distribution_strategy_context.get_tower_context().merge_call(
        fn, *args)
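As a hedged illustration of how `_aggregate_across_towers` might be used, the sketch below passes a hypothetical mean-style `metric_value_fn` through it; the helper name `safe_mean_value` and the tensors `total` and `count` are assumptions, not part of the original snippet.

import tensorflow as tf

def safe_mean_value(metrics_collections, total, count):
  # Hypothetical metric_value_fn: divide the accumulated total by the count,
  # guarding against division by zero. It is invoked by fn() above in the
  # correct control-flow context, in cross-tower mode.
  def metric_value_fn(_, t, c):
    return tf.div(t, tf.maximum(c, 1e-7), name='value')

  return _aggregate_across_towers(metrics_collections, metric_value_fn,
                                  total, count)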
Example #13
def _assert_in_default_state(t):
  t.assertIs(distribution_strategy_context._get_default_tower_context(),
             distribution_strategy_context.get_tower_context())
  t.assertIs(None, distribution_strategy_context.get_cross_tower_context())
  t.assertIs(distribution_strategy_context._get_default_distribution_strategy(),
             distribution_strategy_context.get_distribution_strategy())
  t.assertFalse(distribution_strategy_context.has_distribution_strategy())
Example #14
    def set_last_step_output(
            self,
            name,
            output,
            aggregation=variables_lib.VariableAggregation.NONE):
        """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      aggregation: Aggregation method to use to aggregate outputs from multiple
        towers. Required if `set_last_step_output` is called in a tower context.
        Optional in cross_tower_context.
        When present, the outputs from all the towers are aggregated using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For e.g. if using MirroredStrategy and aggregation is set, output
        must be a `PerDevice` value.
        The aggregation method is also recorded in a dictionary
        `_last_step_outputs_aggregations` for later interpreting of the
        outputs as already reduced or not.

    """
        if distribution_strategy_context.get_cross_tower_context():
            self._last_step_outputs_aggregations[name] = aggregation
            if aggregation is variables_lib.VariableAggregation.NONE:
                self._last_step_outputs[name] = output
            else:
                distribution = distribution_strategy_context.get_distribution_strategy(
                )
                self._last_step_outputs[name] = distribution.reduce(
                    aggregation, output, destinations="/device:CPU:0")
        else:
            assert aggregation is not variables_lib.VariableAggregation.NONE

            def merge_fn(distribution, value):
                self._last_step_outputs[name] = distribution.reduce(
                    aggregation, value, destinations="/device:CPU:0")
                # Setting this inside the `merge_fn` because all towers share the same
                # context object, so it's more robust to set it only once (even if all
                # the towers are trying to set the same value).
                self._last_step_outputs_aggregations[name] = aggregation

            distribution_strategy_context.get_tower_context().merge_call(
                merge_fn, output)
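A hedged sketch of calling this method from a per-tower step function. Here `ctx` stands for whatever context object exposes `set_last_step_output` (an assumption based on the snippet above), and `tf.VariableAggregation.MEAN` is used as the public spelling of the aggregation enum.

import tensorflow as tf

def step_fn(ctx, inputs):
  # Hypothetical per-tower step: record the mean loss of the last step and
  # let the strategy average it across towers.
  loss = tf.reduce_mean(inputs)
  ctx.set_last_step_output("mean_loss", loss,
                           aggregation=tf.VariableAggregation.MEAN)
  return loss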
Example #15
def skip_summary():
  # If using multiple towers in distributed strategy, skip summaries on all
  # towers except the first one (tower_id=0).
  # TODO(priyag): Add a new optional argument that will provide multiple
  # alternatives to override default behavior. (e.g. run on last tower,
  # compute sum or mean across towers).
  tower_context = distribution_strategy_context.get_tower_context()
  return tower_context and tower_context.tower_id > 0
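A hedged usage sketch: gate summary creation on `skip_summary()` so that only tower 0 writes summaries. The helper name `first_tower_scalar_summary` is an assumption for illustration.

import tensorflow as tf

def first_tower_scalar_summary(name, tensor):
  # Hypothetical helper: emit the scalar summary only on the first tower;
  # other towers get a no-op so the graph shape stays uniform.
  if skip_summary():
    return tf.no_op()
  return tf.summary.scalar(name, tensor)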
Example #16
    def model_fn():
      v0 = variable_scope.get_variable("var0", [1])
      with variable_scope.variable_scope("common"):
        v1 = variable_scope.get_variable("var1", [1])
        # This will pause the current thread, and execute the other thread.
        distribution_strategy_context.get_tower_context().merge_call(
            lambda _: _)
        v2 = variable_scope.get_variable(
            "var2", [1],
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        v3 = variable_scope.get_variable(
            "var3", [1],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation=variable_scope.VariableAggregation.MEAN)

      return v0, v1, v2, v3
def skip_summary():
    # If using multiple towers in distributed strategy, skip summaries on all
    # towers except the first one (tower_id=0).
    # TODO(priyag): Add a new optional argument that will provide multiple
    # alternatives to override default behavior. (e.g. run on last tower,
    # compute sum or mean across towers).
    tower_context = distribution_strategy_context.get_tower_context()
    return tower_context and tower_context.tower_id > 0
    def model_fn():
      v0 = variable_scope.get_variable("var0", [1])
      with variable_scope.variable_scope("common"):
        v1 = variable_scope.get_variable("var1", [1])
        # This will pause the current thread, and execute the other thread.
        distribution_strategy_context.get_tower_context().merge_call(
            lambda _: _)
        v2 = variable_scope.get_variable(
            "var2", [1],
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        v3 = variable_scope.get_variable(
            "var3", [1],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation=variable_scope.VariableAggregation.MEAN)

      return v0, v1, v2, v3
Example #19
def vq_discrete_bottleneck(x, hparams):
    """Simple vector quantized discrete bottleneck."""
    bottleneck_size = 2**hparams.bottleneck_bits
    x_shape = commons.shape_list(x)
    x = tf.reshape(x, [-1, hparams.hidden_size])
    x_means_hot, e_loss = vq_nearest_neighbor(x, hparams)

    if hparams.bottleneck_kind == "mog":
        loss = hparams.beta * e_loss
    else:
        tf.logging.info("Using EMA with beta = {}".format(hparams.beta))
        means, ema_means, ema_count = (hparams.means, hparams.ema_means,
                                       hparams.ema_count)
        # Update the ema variables
        updated_ema_count = commons.assign_moving_average(ema_count,
                                                          tf.reduce_sum(
                                                              x_means_hot,
                                                              axis=0),
                                                          hparams.decay,
                                                          zero_debias=False)

        dw = tf.matmul(x_means_hot, x, transpose_a=True)
        updated_ema_means = commons.assign_moving_average(ema_means,
                                                          dw,
                                                          hparams.decay,
                                                          zero_debias=False)
        n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True)
        updated_ema_count = ((updated_ema_count + hparams.epsilon) /
                             (n + bottleneck_size * hparams.epsilon) * n)
        # pylint: disable=g-no-augmented-assignment
        updated_ema_means = updated_ema_means / tf.expand_dims(
            updated_ema_count, axis=-1)
        # pylint: enable=g-no-augmented-assignment
        with tf.control_dependencies([e_loss]):
            # distribution_strategy
            def update_fn(v, value):
                return tf.assign(v, value)

            tower_context = distribution_strategy_context.get_tower_context()
            if tower_context:

                def merge_fn(strategy, v, value):
                    value = strategy.reduce(tf.VariableAggregation.MEAN, value,
                                            v)
                    return strategy.update(v, update_fn, value)

                update_means = tower_context.merge_call(
                    merge_fn, means, updated_ema_means)
            else:
                strategy = distribution_strategy_context.get_cross_tower_context(
                )
                update_means = strategy.update(means, update_fn,
                                               updated_ema_means)
            with tf.control_dependencies([update_means]):
                loss = hparams.beta * e_loss

    discrete = tf.reshape(x_means_hot, x_shape[:-1] + [bottleneck_size])
    return discrete, loss
Example #20
def merge_grads(grads_and_vars):
    """Merge gradients from different replicas."""
    def merge_grad_fn(strategy, grads_and_vars):
        reduced_grads = strategy.batch_reduce(
            variable_scope.VariableAggregation.MEAN, grads_and_vars)
        return reduced_grads

    return distribution_strategy_context.get_tower_context().merge_call(
        merge_grad_fn, grads_and_vars)
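A hedged sketch of where `merge_grads` typically sits: inside a per-tower training function, after local gradients are computed and before they are applied. The helper name `compute_averaged_grads` and its arguments are assumptions for illustration.

import tensorflow as tf

def compute_averaged_grads(loss, variables):
  # Hypothetical per-tower helper: local gradients are averaged across
  # replicas by merge_grads, which delegates to the strategy's batch_reduce.
  grads = tf.gradients(loss, variables)
  return merge_grads(list(zip(grads, variables)))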
Example #21
def increment_var(v, amount=1):
  """`v += amount`, distributed-aware version."""
  def update(vu):
    return vu.assign_add(amount, read_value=False)

  def merge_fn(dist, vm):
    return dist.update(vm, update)

  tower_context = distribution_strategy_context.get_tower_context()
  return tower_context.merge_call(merge_fn, v)
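A hedged usage sketch for `increment_var`, assuming it runs inside a tower context (for example from a step function executed via the contrib-era `call_for_each_tower`). The variable name `counter` is an assumption; a resource variable is used because this variant calls `assign_add(..., read_value=False)`.

import tensorflow as tf

def step_fn():
  # Hypothetical per-tower step: bump a shared counter in a
  # distribution-aware way.
  counter = tf.get_variable("counter", initializer=0, trainable=False,
                            use_resource=True)
  return increment_var(counter)  # distributed `counter += 1`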
Example #22
  def set_last_step_output(self, name, output,
                           aggregation=variables_lib.VariableAggregation.NONE):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      aggregation: Aggregation method to use to aggregate outputs from multiple
        towers. Required if `set_last_step_output` is called in a tower context.
        Optional in cross_tower_context.
        When present, the outputs from all the towers are aggregated using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For e.g. if using MirroredStrategy and aggregation is set, output
        must be a `PerDevice` value.
        The aggregation method is also recorded in a dictionary
        `_last_step_outputs_aggregations` for later interpreting of the
        outputs as already reduced or not.

    """
    if distribution_strategy_context.get_cross_tower_context():
      self._last_step_outputs_aggregations[name] = aggregation
      if aggregation is variables_lib.VariableAggregation.NONE:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_distribution_strategy()
        self._last_step_outputs[name] = distribution.reduce(
            aggregation, output, destinations="/device:CPU:0")
    else:
      assert aggregation is not variables_lib.VariableAggregation.NONE
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(
            aggregation, value, destinations="/device:CPU:0")
        # Setting this inside the `merge_fn` because all towers share the same
        # context object, so it's more robust to set it only once (even if all
        # the towers are trying to set the same value).
        self._last_step_outputs_aggregations[name] = aggregation

      distribution_strategy_context.get_tower_context().merge_call(
          merge_fn, output)
Example #23
def merge_update_step(update_ops, local_step):
    """Merge local step counter update from different replicas."""
    def merge_update_step_fn(strategy, update_ops, local_step):
        merged_ops = []
        for update_op in update_ops:
            merged_ops.append(strategy.group(update_op))
        with ops.control_dependencies(merged_ops):
            incre_op = local_step.assign_add(1).op
        return incre_op

    return distribution_strategy_context.get_tower_context().merge_call(
        merge_update_step_fn, update_ops, local_step)
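A hedged, self-contained graph-mode sketch of `merge_update_step`; the variables `acc` and `local_step` are assumptions for illustration.

import tensorflow as tf

# Hypothetical update ops whose completion should gate the step increment.
acc = tf.get_variable("acc", initializer=0.0, trainable=False)
update_ops = [acc.assign_add(1.0)]
local_step = tf.get_variable(
    "local_step", initializer=tf.constant(0, dtype=tf.int64), trainable=False)
incre_op = merge_update_step(update_ops, local_step)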
Example #24
 def merge_fn(dist, s):
   self.assertIs(
       distribution_strategy_context._get_default_distribution_strategy(),
       dist)
   self.assertIs(None, distribution_strategy_context.get_tower_context())
   self.assertIs(dist,
                 distribution_strategy_context.get_cross_tower_context())
   self.assertIs(dist,
                 distribution_strategy_context.get_distribution_strategy())
   self.assertFalse(
       distribution_strategy_context.has_distribution_strategy())
   return "foo_" + s
Example #25
def increment_var(v, amount=1):
  """`v += amount`, distributed-aware version."""
  def update(vu):
    if isinstance(vu, resource_variable_ops.ResourceVariable):
      return vu.assign_add(amount, read_value=False)
    else:
      return state_ops.assign_add(vu, amount)

  def merge_fn(dist, vm):
    return dist.group(dist.update(vm, update))

  tower_context = distribution_strategy_context.get_tower_context()
  return tower_context.merge_call(merge_fn, v)
Example #26
def increment_var(v, amount=1):
  """`v += amount`, distributed-aware version."""
  def update(vu):
    if isinstance(vu, resource_variable_ops.ResourceVariable):
      return vu.assign_add(amount, read_value=False)
    else:
      return state_ops.assign_add(vu, amount)

  def merge_fn(dist, vm):
    return dist.group(dist.update(vm, update))

  tower_context = distribution_strategy_context.get_tower_context()
  return tower_context.merge_call(merge_fn, v)
Example #27
    def _assign_func(self, *args, **kwargs):
        f = kwargs.pop("f")
        if distribution_strategy_context.get_cross_tower_context():
            update_device = distribute_lib.get_update_device()
            if update_device is not None:
                # We are calling an assign function on the mirrored variable in an
                # update context.
                v = self.get(device=update_device)
                return f(v, *args, **kwargs)

            # We are calling assign on the mirrored variable in cross tower context,
            # use update to update the variable.
            strategy = distribution_strategy_context.get_distribution_strategy(
            )
            updates = strategy.update(self, f, *args, **kwargs)
            grouped = strategy.group(updates)
            if isinstance(updates,
                          DistributedValues) and updates.is_tensor_like:
                # Make sure we run all updates. Without this, something like
                # session.run(mirrored_var.assign*(...)) may only update one tower.
                index = {}
                for d in updates.devices:
                    with ops.device(d), ops.control_dependencies([grouped]):
                        index[d] = array_ops.identity(updates.get(d))
                return Mirrored(index)
            else:
                return grouped
        else:
            _assert_tower_context()
            # We are calling an assign function on the mirrored variable in tower
            # context.
            # We reduce the value we want to assign/add/sub. More details about how we
            # handle the different use cases can be found in the _reduce method.
            # We call the function on each of the mirrored variables with the reduced
            # value.
            if self._aggregation == vs.VariableAggregation.NONE:
                raise ValueError(
                    "You must specify an aggregation method to update a "
                    "MirroredVariable in Tower Context.")

            def merge_fn(strategy, value, *other_args, **other_kwargs):
                return strategy.update(
                    self, f,
                    strategy.reduce(aggregation=self._aggregation,
                                    value=value,
                                    destinations=self), *other_args,
                    **other_kwargs)

            return distribution_strategy_context.get_tower_context(
            ).merge_call(merge_fn, *args, **kwargs)
Example #28
 def run_fn():
   tower_context = distribution_strategy_context.get_tower_context()
   self.assertTrue(tower_context is not None)
   self.assertIs(None,
                 distribution_strategy_context.get_cross_tower_context())
   self.assertTrue(distribution_strategy_context.has_distribution_strategy())
   self.assertIs(dist,
                 distribution_strategy_context.get_distribution_strategy())
   self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="bar"))
Example #29
 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, distribution_strategy_context.get_tower_context())
     self.assertIs(dist,
                   distribution_strategy_context.get_cross_tower_context())
     self.assertTrue(distribution_strategy_context.has_distribution_strategy())
     self.assertIs(dist,
                   distribution_strategy_context.get_distribution_strategy())
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
Example #30
 def get(self, device=None):
   """Returns the value for the current device or raises a ValueError."""
   if device is None:
     tower_context = distribution_strategy_context.get_tower_context()
     if tower_context:
       device = tower_context.device
     else:
       device = distribute_lib.get_update_device()
       if device is None:
         return self._get_cross_tower()
   device = device_util.canonicalize(device)
   try:
     return self._index[device]
   except KeyError as e:
     six.raise_from(
         ValueError("Device %s not found in %s (current device %s)" %
                    (device, self._index.keys(), device_util.current())), e)
Example #31
 def get(self, device=None):
   """Returns the value for the current device or raises a ValueError."""
   if device is None:
     tower_context = distribution_strategy_context.get_tower_context()
     if tower_context:
       device = tower_context.device
     else:
       device = distribute_lib.get_update_device()
       if device is None:
         return self._get_cross_tower()
   device = device_util.canonicalize(device)
   try:
     return self._index[device]
   except KeyError as e:
     six.raise_from(
         ValueError("Device %s not found in %s (current device %s)" %
                    (device, self._index.keys(), device_util.current())), e)
Example #32
  def _assign_func(self, *args, **kwargs):
    f = kwargs.pop("f")
    if distribution_strategy_context.get_cross_tower_context():
      update_device = distribute_lib.get_update_device()
      if update_device is not None:
        # We are calling an assign function on the mirrored variable in an
        # update context.
        v = self.get(device=update_device)
        return f(v, *args, **kwargs)

      # We are calling assign on the mirrored variable in cross tower context,
      # use update to update the variable.
      strategy = distribution_strategy_context.get_distribution_strategy()
      updates = strategy.update(self, f, *args, **kwargs)
      grouped = strategy.group(updates)
      if isinstance(updates, DistributedValues) and updates.is_tensor_like:
        # Make sure we run all updates. Without this, something like
        # session.run(mirrored_var.assign*(...)) may only update one tower.
        index = {}
        for d in updates.devices:
          with ops.device(d), ops.control_dependencies([grouped]):
            index[d] = array_ops.identity(updates.get(d))
        return Mirrored(index)
      else:
        return grouped
    else:
      _assert_tower_context()
      # We are calling an assign function on the mirrored variable in tower
      # context.
      # We reduce the value we want to assign/add/sub. More details about how we
      # handle the different use cases can be found in the _reduce method.
      # We call the function on each of the mirrored variables with the reduced
      # value.
      if self._aggregation == vs.VariableAggregation.NONE:
        raise ValueError("You must specify an aggregation method to update a "
                         "MirroredVariable in Tower Context.")

      def merge_fn(strategy, value, *other_args, **other_kwargs):
        return strategy.update(
            self, f,
            strategy.reduce(
                aggregation=self._aggregation, value=value, destinations=self),
            *other_args, **other_kwargs)

      return distribution_strategy_context.get_tower_context().merge_call(
          merge_fn, *args, **kwargs)
Example #33
  def testMergeCall(self):
    _assert_in_default_state(self)

    def merge_fn(dist, s):
      self.assertIs(
          distribution_strategy_context._get_default_distribution_strategy(),
          dist)
      self.assertIs(None, distribution_strategy_context.get_tower_context())
      self.assertIs(dist,
                    distribution_strategy_context.get_cross_tower_context())
      self.assertIs(dist,
                    distribution_strategy_context.get_distribution_strategy())
      self.assertFalse(
          distribution_strategy_context.has_distribution_strategy())
      return "foo_" + s

    tower_ctx = distribution_strategy_context.get_tower_context()
    self.assertIs(distribution_strategy_context._get_default_tower_context(),
                  tower_ctx)
    self.assertEqual("foo_bar", tower_ctx.merge_call(merge_fn, "bar"))
    _assert_in_default_state(self)
Example #34
  def decorated(metric_obj, *args):
    """Decorated function with merge_call."""
    tower_context = distribution_strategy_context.get_tower_context()
    if tower_context is None:  # if in cross tower context already
      result_t = result_fn(*args)
    else:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
      # with distribution object as the first parameter. We create a wrapper
      # here so that the result function need not have that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *args):
        # We will get `PerDevice` merge function. Taking the first one as all
        # are identical copies of the function that we had passed below.
        return distribution.unwrap(merge_fn)[0](*args)

      # Wrapping result in merge_call. merge_call is used when we want to leave
      # tower mode and compute a value in cross tower mode.
      result_t = tower_context.merge_call(merge_fn_wrapper, result_fn, *args)
    check_is_tensor_or_operation(result_t,
                                 'Metric {0}\'s result'.format(metric_obj.name))
    return result_t
Example #35
  def decorated(metric_obj, *args):
    """Decorated function with merge_call."""
    tower_context = distribution_strategy_context.get_tower_context()
    if tower_context is None:  # if in cross tower context already
      result_t = result_fn(*args)
    else:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
      # with distribution object as the first parameter. We create a wrapper
      # here so that the result function need not have that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *args):
        # We will get `PerDevice` merge function. Taking the first one as all
        # are identical copies of the function that we had passed below.
        return distribution.unwrap(merge_fn)[0](*args)

      # Wrapping result in merge_call. merge_call is used when we want to leave
      # tower mode and compute a value in cross tower mode.
      result_t = tower_context.merge_call(merge_fn_wrapper, result_fn, *args)
    check_is_tensor_or_operation(result_t,
                                 'Metric {0}\'s result'.format(metric_obj.name))
    return result_t
      def model_fn():
        if num_gpus == 0:
          last_part_device = 'device:CPU:0'
        else:
          last_part_device = (
              'device:GPU:%d' %
              distribution_strategy_context.get_tower_context().tower_id)

        a = constant_op.constant(1.0)
        b = constant_op.constant(2.0)
        c = a + b
        self.assertEqual(a.device, worker_device + '/' + last_part_device)
        self.assertEqual(b.device, worker_device + '/' + last_part_device)
        self.assertEqual(c.device, worker_device + '/' + last_part_device)

        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/job:worker/task:0'):
          x = variable_scope.get_variable(
              'x', initializer=10.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          x_add = x.assign_add(c)
          e = a + c
        # The variable x is on the task 1 since the device_function has been
        # called once before the model_fn.
        self.assertEqual(x.device, '/job:ps/task:1')
        self.assertEqual(x_add.device, x.device)
        self.assertEqual(e.device,
                         '/job:worker/replica:0/task:0/%s' % last_part_device)

        # The colocate_vars_with can override the distribution's device.
        with d.colocate_vars_with(x):
          y = variable_scope.get_variable(
              'y', initializer=20.0,
              aggregation=variable_scope.VariableAggregation.SUM)
        # We add an identity here to avoid complaints about summing
        # non-distributed values.
        y_add = y.assign_add(array_ops.identity(x_add))
        self.assertEqual(y.device, '/job:ps/task:1')
        self.assertEqual(y_add.device, y.device)
        self.assertEqual(y.device, x.device)

        z = variable_scope.get_variable(
            'z', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertEqual(z.device, '/job:ps/task:0')
        self.assertNotEqual(z.device, x.device)

        with ops.control_dependencies([y_add]):
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          z_add = z.assign_add(array_ops.identity(y))
        with ops.control_dependencies([z_add]):
          f = z + c
        self.assertEqual(f.device, worker_device + '/' + last_part_device)

        # The device scope would merge with the default worker device.
        with ops.device('/CPU:1'):
          g = e + 1.0
        self.assertEqual(g.device, worker_device + '/device:CPU:1')

        # The ops.colocate_with will be ignored when defining a variable but not
        # for a normal tensor.
        with ops.colocate_with(x):
          u = variable_scope.get_variable('u', initializer=30.0)
          v = variable_scope.get_variable('v', initializer=30.0)
          h = f + 1.0
        self.assertIn('/job:ps/', u.device)
        self.assertIn('/job:ps/', v.device)
        # u and v are on different parameter servers.
        self.assertTrue(u.device != x.device or v.device != x.device)
        self.assertTrue(u.device == x.device or v.device == x.device)
        # Here h is not on one worker. Note h.device is canonical while x.device
        # is not.
        self.assertIn('/job:ps/', h.device)
        return y_add, z_add, f
Example #37
def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
    """Compute the moving average of a variable.
  https://github.com/tensorflow/tensorflow/blob/c966b5eed60a570f2121cb84ddb4ece84c413719/tensorflow/python/training/moving_averages.py
  """
    def _zero_debias(unbiased_var, value, decay):
        """Compute the delta required for a debiased Variable.
    """
        with tf.variable_scope(unbiased_var.op.name,
                               values=[unbiased_var, value, decay]) as scope:
            with tf.init_scope():
                biased_initializer = tf.zeros_initializer(
                    dtype=unbiased_var.dtype)(unbiased_var.get_shape())
                local_step_initializer = tf.zeros_initializer()

            def _maybe_get_unique(name):
                """Get name for a unique variable, if not `reuse=True`."""
                if tf.get_variable_scope().reuse:
                    return name
                vs_vars = [
                    x.op.name
                    for x in tf.get_variable_scope().global_variables()
                ]
                full_name = tf.get_variable_scope().name + "/" + name
                if full_name not in vs_vars: return name
                idx = 1
                while full_name + ("_%d" % idx) in vs_vars:
                    idx += 1
                return name + ("_%d" % idx)

            biased_var = tf.get_variable(_maybe_get_unique("biased"),
                                         initializer=biased_initializer,
                                         trainable=False)
            local_step = tf.get_variable(_maybe_get_unique("local_step"),
                                         shape=[],
                                         dtype=unbiased_var.dtype,
                                         initializer=local_step_initializer,
                                         trainable=False)

            # Get an update ops for both shadow variables.
            update_biased = tf.assign_sub(biased_var,
                                          (biased_var - value) * decay,
                                          name=scope.name)
            update_local_step = local_step.assign_add(1)

            # Compute the value of the delta to update the unbiased EMA. Make sure to
            # use the new values of the biased variable and the local step.
            with tf.control_dependencies([update_biased, update_local_step]):
                # This function gets `1 - decay`, so use `1.0 - decay` in the exponent.
                unbiased_ema_delta = (
                    unbiased_var - biased_var.read_value() /
                    (1 - tf.pow(1.0 - decay, local_step.read_value())))

            return unbiased_ema_delta

    def update_fn(v, value, decay=decay):
        decay = tf.convert_to_tensor(1.0 - decay, name="decay")
        if decay.dtype != v.dtype.base_dtype:
            decay = tf.cast(decay, v.dtype.base_dtype)
        if zero_debias:
            update_delta = _zero_debias(v, value, decay)
        else:
            update_delta = (v - value) * decay
        return tf.assign_sub(v, update_delta, name=scope)

    with tf.name_scope(name, "AssignMovingAvg",
                       [variable, value, decay]) as scope:
        tower_context = distribution_strategy_context.get_tower_context()
        if tower_context:
            # In a tower context, we update variable using the mean of value across
            # towers.
            def merge_fn(strategy, v, value):
                try:
                    value = strategy.reduce(tf.VariableAggregation.MEAN, value,
                                            v)
                except:
                    pass  # Mirrored variables are loaded
                return strategy.update(v, update_fn, value)

            return tower_context.merge_call(merge_fn, variable, value)
        else:
            strategy = distribution_strategy_context.get_cross_tower_context()
            return strategy.update(variable, update_fn, value)
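A hedged, graph-mode usage sketch for this `assign_moving_average` variant; the placeholder and variable names are assumptions.

import tensorflow as tf

# Hypothetical: maintain an exponential moving average of a scalar batch loss.
losses = tf.placeholder(tf.float32, [None])
batch_loss = tf.reduce_mean(losses)
loss_ema = tf.get_variable("loss_ema", initializer=0.0, trainable=False)
update_ema = assign_moving_average(loss_ema, batch_loss, decay=0.99,
                                   zero_debias=False)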
Example #38
def f1_score(labels, predictions, weights=None, num_thresholds=200,
             metrics_collections=None, updates_collections=None, name=None):
  """Computes the approximately best F1-score across different thresholds.

  The f1_score function applies a range of thresholds to the predictions to
  convert them from [0, 1] to bool. Precision and recall are computed by
  comparing them to the labels. The F1-Score is then defined as
  2 * precision * recall / (precision + recall). The best one across the
  thresholds is returned.

  Disclaimer: In practice it may be desirable to choose the best threshold on
  the validation set and evaluate the F1 score with this threshold on a
  separate test set. Or it may be desirable to use a fixed threshold (e.g. 0.5).

  This function internally creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the pairs of recall and precision values for a linearly spaced set of
  thresholds from which the best f1-score is derived.

  This value is ultimately returned as `f1-score`, an idempotent operation that
  computes the F1-score (computed using the aforementioned variables). The
  `num_thresholds` variable controls the degree of discretization with larger
  numbers of thresholds more closely approximating the true best F1-score.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the F1-score.

  Example usage with a custom estimator:
  def model_fn(features, labels, mode):
    predictions = make_predictions(features)
    loss = make_loss(predictions, labels)
    train_op = tf.contrib.training.create_train_op(
          total_loss=loss,
          optimizer='Adam')
    eval_metric_ops = {'f1': f1_score(labels, predictions)}
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs)
  estimator = tf.estimator.Estimator(model_fn=model_fn)

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `f1_score` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    f1_score: A scalar `Tensor` representing the current best f1-score across
      different thresholds.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches the `f1_score`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
  """
  with variable_scope.variable_scope(
      name, 'f1', (labels, predictions, weights)):
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions=predictions, labels=labels, weights=weights)
    # To account for floating point imprecisions / avoid division by zero.
    epsilon = 1e-7
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [0.0 - epsilon] + thresholds + [1.0 + epsilon]

    # Confusion matrix.
    values, update_ops = metrics_impl._confusion_matrix_at_thresholds(  # pylint: disable=protected-access
        labels, predictions, thresholds, weights, includes=('tp', 'fp', 'fn'))

    # Compute precision and recall at various thresholds.
    def compute_best_f1_score(tp, fp, fn, name):
      precision_at_t = math_ops.div(tp, epsilon + tp + fp,
                                    name='precision_' + name)
      recall_at_t = math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
      # Compute F1 score.
      f1_at_thresholds = (
          2.0 * precision_at_t * recall_at_t /
          (precision_at_t + recall_at_t + epsilon))
      return math_ops.reduce_max(f1_at_thresholds)

    def f1_across_towers(_, values):
      best_f1 = compute_best_f1_score(tp=values['tp'], fp=values['fp'],
                                      fn=values['fn'], name='value')
      if metrics_collections:
        ops.add_to_collections(metrics_collections, best_f1)
      return best_f1

    best_f1 = distribution_strategy_context.get_tower_context().merge_call(
        f1_across_towers, values)

    update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'],
                                      fn=update_ops['fn'], name='update')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return best_f1, update_op
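Beyond the Estimator example in the docstring, here is a hedged sketch of direct graph-mode usage; the placeholder shapes and the fed values are assumptions.

import tensorflow as tf

labels = tf.placeholder(tf.bool, [None])
predictions = tf.placeholder(tf.float32, [None])
score, update_op = f1_score(labels, predictions, num_thresholds=100)

with tf.Session() as sess:
  # The underlying confusion-matrix counters are local variables.
  sess.run(tf.local_variables_initializer())
  sess.run(update_op, {labels: [True, False, True],
                       predictions: [0.9, 0.2, 0.7]})
  print(sess.run(score))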
      def model_fn():
        if 'CPU' in compute_device:
          tower_compute_device = '/device:CPU:0'
        else:
          tower_compute_device = (
              '/device:GPU:%d' %
              distribution_strategy_context.get_tower_context().tower_id)
        tower_compute_device = device_util.canonicalize(tower_compute_device)

        if 'CPU' in variable_device:
          tower_variable_device = '/device:CPU:0'
        else:
          tower_variable_device = (
              '/device:GPU:%d' %
              distribution_strategy_context.get_tower_context().tower_id)
        tower_variable_device = device_util.canonicalize(tower_variable_device)

        a = constant_op.constant(1.0)
        b = constant_op.constant(2.0)
        c = a + b
        self.assertEqual(a.device, tower_compute_device)
        self.assertEqual(b.device, tower_compute_device)
        self.assertEqual(c.device, tower_compute_device)

        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/device:GPU:2'):
          x = variable_scope.get_variable(
              'x', initializer=10.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          x_add = x.assign_add(c)
          e = a + c
        self.assertEqual(
            device_util.canonicalize(x.device), tower_variable_device)
        self.assertEqual(x_add.device, x.device)
        self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2'))

        # The colocate_vars_with can override the distribution's device.
        with d.colocate_vars_with(x):
          y = variable_scope.get_variable(
              'y', initializer=20.0,
              aggregation=variable_scope.VariableAggregation.SUM)
        # We add an identity here to avoid complaints about summing
        # non-distributed values.
        y_add = y.assign_add(array_ops.identity(x_add))
        self.assertEqual(
            device_util.canonicalize(y.device), tower_variable_device)
        self.assertEqual(y_add.device, y.device)
        self.assertEqual(y.device, x.device)

        z = variable_scope.get_variable(
            'z', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertEqual(
            device_util.canonicalize(z.device), tower_variable_device)

        with ops.control_dependencies([y_add]):
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          z_add = z.assign_add(array_ops.identity(y))
        with ops.control_dependencies([z_add]):
          f = z + c
        self.assertEqual(f.device, tower_compute_device)

        # The device scope would merge with the default worker device.
        with ops.device('/CPU:1'):
          g = e + 1.0
        self.assertEqual(g.device, device_util.canonicalize('/device:CPU:1'))

        # The ops.colocate_with will be ignored when defining a variable but not
        # for a normal tensor.
        with ops.colocate_with(x):
          u = variable_scope.get_variable('u', initializer=30.0)
          h = f + 1.0
        self.assertEqual(
            device_util.canonicalize(u.device), tower_variable_device)
        self.assertEqual(device_util.canonicalize(x.device), h.device)
        return y_add, z_add, f
Example #40
 def mark_devices_fn():
   tower_id = distribution_strategy_context.get_tower_context().tower_id
   self.assertLess(tower_id, len(d.worker_devices))
   self.assertFalse(expected_devices[tower_id])
   expected_devices[tower_id] = True
Example #41
def _merge_call_merge_raises_fn():
  distribution_strategy_context.get_tower_context().merge_call(
      _call_merge_raises_fn)
 def model_fn(name):
   v = variable_scope.variable(1.0, name=name)
   distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
   return v
 def model_fn():
   b = variable_scope.get_variable("b", [1])
   with ops.name_scope("foo"):
     c = distribution_strategy_context.get_tower_context().merge_call(
         in_cross_tower)
   return b, c
Example #44
def _assert_tower_context():
    if not distribution_strategy_context.get_tower_context():
        raise RuntimeError(
            "Tower-local variables may only be assigned in a tower context.")
Example #45
    def apply_gradients(self, loss, grads_and_vars, global_step=None, name=None):
        """Apply gradients to variables.
        This is the second part of `minimize()`. It returns an `Operation` that
        applies gradients.

        Args:
          loss: The loss `Tensor`; in this variant it is forwarded to the
            per-variable update ops and to `_finish()`.
          grads_and_vars: List of (gradient, variable) pairs as returned by
            `compute_gradients()`.
          global_step: Optional `Variable` to increment by one after the
            variables have been updated.
          name: Optional name for the returned operation. Defaults to the
            name passed to the `Optimizer` constructor.

        Returns:
          An `Operation` that applies the specified gradients. If `global_step`
          was not None, that operation also increments `global_step`.

        Raises:
          TypeError: If `grads_and_vars` is malformed.
          ValueError: If none of the variables have gradients.
          RuntimeError: If you should use `_distributed_apply()` instead.
        """
        # This is a default implementation of apply_gradients() that can be shared
        # by most optimizers.  It relies on the subclass implementing the following
        # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

        # Handle DistributionStrategy case.
        if distribution_strategy_context.get_cross_tower_context():
            raise RuntimeError("Use `_distributed_apply()` instead of `apply_gradients()` in a cross-tower context.")
        # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
        # always calling _distributed_apply(), using the default distribution
        # as needed.
        if distribution_strategy_context.has_distribution_strategy():
            grads_and_vars = optimizer.get_filtered_grad_fn(lambda: grads_and_vars)()
            return distribution_strategy_context.get_tower_context().merge_call(
                self._distributed_apply, grads_and_vars, global_step, name
            )

        # No DistributionStrategy case.
        grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
        if not grads_and_vars:
            raise ValueError("No variables provided.")
        converted_grads_and_vars = []
        for grad, var in grads_and_vars:
            if grad is not None:
                try:
                    # Convert the grad to Tensor or IndexedSlices if necessary.
                    grad = ops.convert_to_tensor_or_indexed_slices(grad)
                except TypeError:
                    raise TypeError("Gradient must be convertible to a Tensor or IndexedSlices, or None: %s" % grad)
                if not isinstance(grad, (ops.Tensor, ops.IndexedSlices)):
                    raise TypeError("Gradient must be a Tensor, IndexedSlices, or None: %s" % grad)
            processor = _get_processor(var)
            converted_grads_and_vars.append((grad, var, processor))

        converted_grads_and_vars = tuple(converted_grads_and_vars)
        var_list = [var for grad, var, _ in converted_grads_and_vars if grad is not None]
        if not var_list:
            raise ValueError("No gradients provided for any variable: %s." % ([str(var) for _, var, _ in converted_grads_and_vars],))
        with ops.init_scope():
            self._create_slots(var_list)
        update_ops = []
        with ops.name_scope(name, self._name) as name:
            self._prepare()
            for grad, var, processor in converted_grads_and_vars:
                if grad is None:
                    continue
                # We colocate all ops created in _apply_dense or _apply_sparse
                # on the same device as the variable.
                # TODO(apassos): figure out how to get the variable name here.
                if context.executing_eagerly() or isinstance(var, resource_variable_ops.ResourceVariable) and not var._in_graph_mode:
                    scope_name = ""
                else:
                    scope_name = var.op.name
                with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
                    update_ops.append(processor.update_op(self, loss, grad, global_step))
            if global_step is None:
                apply_updates = self._finish(update_ops, loss, name)
            else:
                with ops.control_dependencies([self._finish(update_ops, loss, "update")]):
                    with ops.colocate_with(global_step):
                        if isinstance(global_step, resource_variable_ops.ResourceVariable):
                            # TODO(apassos): the implicit read in assign_add is slow; consider
                            # making it less so.
                            apply_updates = resource_variable_ops.assign_add_variable_op(
                                resource=global_step.handle,
                                value=ops.convert_to_tensor(
                                    value=1,
                                    dtype=global_step.dtype
                                ),
                                name=name
                            )
                        else:
                            apply_updates = state_ops.assign_add(
                                ref=global_step,
                                value=1,
                                name=name
                            )

            if not context.executing_eagerly():
                if isinstance(apply_updates, ops.Tensor):
                    apply_updates = apply_updates.op
                train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
                if apply_updates not in train_op:
                    train_op.append(apply_updates)

            return apply_updates
Example #46
  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If you should use `_distributed_apply()` instead.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers.  It relies on the subclass implementing the following
    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

    # Handle DistributionStrategy case.
    if distribution_strategy_context.get_cross_tower_context():
      raise RuntimeError("Use `_distributed_apply()` instead of "
                         "`apply_gradients()` in a cross-tower context.")
    # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
    # always calling _distributed_apply(), using the default distribution
    # as needed.
    if distribution_strategy_context.has_distribution_strategy():
      grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
      return distribution_strategy_context.get_tower_context().merge_call(
          self._distributed_apply, grads_and_vars, global_step, name)

    # No DistributionStrategy case.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    converted_grads_and_vars = []
    for g, v in grads_and_vars:
      if g is not None:
        try:
          # Convert the grad to Tensor or IndexedSlices if necessary.
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)
      converted_grads_and_vars.append((g, v, p))

    converted_grads_and_vars = tuple(converted_grads_and_vars)
    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    if not var_list:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, _, v in converted_grads_and_vars],))
    with ops.init_scope():
      self._create_slots(var_list)
    update_ops = []
    with ops.name_scope(name, self._name) as name:
      self._prepare()
      for grad, var, processor in converted_grads_and_vars:
        if grad is None:
          continue
        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        if context.executing_eagerly() or isinstance(
            var,
            resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
          scope_name = ""
        else:
          scope_name = var.op.name
        with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
          update_ops.append(processor.update_op(self, grad))
      if global_step is None:
        apply_updates = self._finish(update_ops, name)
      else:
        with ops.control_dependencies([self._finish(update_ops, "update")]):
          with ops.colocate_with(global_step):
            if isinstance(global_step, resource_variable_ops.ResourceVariable):
              # TODO(apassos): the implicit read in assign_add is slow; consider
              # making it less so.
              apply_updates = resource_variable_ops.assign_add_variable_op(
                  global_step.handle,
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  name=name)
            else:
              apply_updates = state_ops.assign_add(global_step, 1, name=name)

      if not context.executing_eagerly():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates
 def model_fn(device_id):
   v = variable_scope.variable(1.0, name="foo_" + str(device_id))
   distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
   return v
 def model_fn():
   # This variable should be created only once across the threads because of
   # special variable_creator functions used by `dist.call_for_each_tower`.
   v = variable_scope.variable(1.0, name="foo")
   distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
   return v
 def model_fn():
   value = math_ops.cast(
       distribution_strategy_context.get_tower_context().tower_id,
       mirrored_var.dtype)
   return mirrored_var.assign_sub(value)
 def model_fn():
   vs = []
   for i in range(5):
     vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
   distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
   return vs