Exemple #1
0
    def _assign_func(self, *args, **kwargs):
        """Applies the assign-style callable passed as `f` to this variable.

        `f` is popped from `kwargs`; all other positional and keyword
        arguments are forwarded to it (or, in tower context, to `merge_call`).

        * Cross-tower context inside an update call (an update device is set):
          `f` is applied directly to the wrapped variable `self._v`.
        * Cross-tower context otherwise: the call is wrapped in the current
          strategy's `update`.
        * Tower context: inside a `merge_call`, the value is reduced with
          `self._aggregation` and the reduced value is passed to `update`
          together with any remaining arguments.

        Raises:
          ValueError: in tower context when `self._aggregation` is
            `vs.VariableAggregation.NONE`.
        """
        f = kwargs.pop("f")
        if distribution_strategy_context.get_cross_tower_context():
            update_device = distribute_lib.get_update_device()
            if update_device is not None:
                # We are calling an assign function in an update context.
                return f(self._v, *args, **kwargs)

            # We are calling an assign function in cross tower context, wrap it
            # in an update call.
            return distribution_strategy_context.get_distribution_strategy(
            ).update(self, f, *args, **kwargs)

        tower_context = distribution_strategy_context.get_tower_context()
        assert tower_context
        # We are calling an assign function in tower context. The value to
        # assign/add/sub is reduced first (see the _reduce method for how the
        # different use cases are handled) and the function is then called
        # with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
            raise ValueError(
                "You must specify an aggregation method to update a "
                "variable in Tower Context.")

        def merge_fn(strategy, value, *other_args, **other_kwargs):
            reduced = strategy.reduce(aggregation=self._aggregation,
                                      value=value,
                                      destinations=self)
            return strategy.update(self, f, reduced, *other_args,
                                   **other_kwargs)

        return tower_context.merge_call(merge_fn, *args, **kwargs)
Exemple #2
0
  def _assign_func(self, *args, **kwargs):
    """Runs the assign-style callable passed as `f` against this variable.

    `f` is removed from `kwargs`; the remaining arguments are forwarded.
    Behaviour depends on the current distribution context:

    * Cross-tower context inside an update (an update device is set): `f` is
      applied to this variable's component on that device.
    * Cross-tower context outside an update: the call is delegated to the
      current strategy's `update`.
    * Tower context: inside a `merge_call` the value is reduced with
      `self._aggregation`, and the reduced value is handed to `update`.

    Raises:
      ValueError: in tower context when `self._aggregation` is
        `vs.VariableAggregation.NONE`.
    """
    assign_fn = kwargs.pop("f")
    if not distribute_lib.get_cross_tower_context():
      _assert_tower_context()
      # Tower context: the per-tower values are reduced before being applied
      # to every component (see `_reduce`), so an aggregation is mandatory.
      if self._aggregation == vs.VariableAggregation.NONE:
        raise ValueError("You must specify an aggregation method to update a "
                         "MirroredVariable in Tower Context.")

      def merge_fn(strategy, value, *other_args, **other_kwargs):
        reduced = strategy.reduce(
            aggregation=self._aggregation, value=value, destinations=self)
        return strategy.update(self, assign_fn, reduced, *other_args,
                               **other_kwargs)

      return distribute_lib.get_tower_context().merge_call(
          merge_fn, *args, **kwargs)

    device = distribute_lib.get_update_device()
    if device is None:
      # Cross-tower context, not yet inside `update`: let the strategy drive.
      return distribute_lib.get_distribution_strategy().update(
          self, assign_fn, *args, **kwargs)
    # Already inside an update call: act on this device's component only.
    return assign_fn(self.get(device=device), *args, **kwargs)
Exemple #3
0
    def _assign_func(self, *args, **kwargs):
        """Applies the assign-style callable passed as `f` to this variable.

        `f` is popped from `kwargs`; the remaining positional and keyword
        arguments are forwarded in every context.

        * Cross-tower context inside an update call (an update device is set):
          `f` is applied to this variable's component on that device.
        * Cross-tower context otherwise: delegated to the current strategy's
          `update`.
        * Tower context: inside a `merge_call` the value is reduced with
          `self._aggregation_method`, then `update` applies `f` with the
          reduced value plus any extra arguments.

        Raises:
          ValueError: in tower context when no aggregation method is set.
        """
        f = kwargs.pop("f")
        if distribute_lib.get_cross_tower_context():
            update_device = distribute_lib.get_update_device()
            if update_device is not None:
                # We are calling an assign function on the mirrored variable in
                # an update call: act on this device's component.
                v = self.get(device=update_device)
                return f(v, *args, **kwargs)

            # Cross tower context outside an update call: let the strategy
            # apply `f` to every component.
            return distribute_lib.get_distribution_strategy().update(
                self, f, *args, **kwargs)
        else:
            # We are calling an assign function on the mirrored variable in
            # tower context. We reduce the value we want to assign/add/sub
            # (more details in the _reduce method) and call the function on
            # each of the mirrored variables with the reduced value.
            if not self._aggregation_method:
                raise ValueError(
                    "You must specify an aggregation method to update a "
                    "MirroredVariable in Tower Context.")

            def merge_fn(strategy, value, *other_args, **other_kwargs):
                # `merge_call` forwards every argument of the assign call, not
                # just the value; pass the extras on to `update` (and thus to
                # `f`) exactly as the cross-tower path does, instead of failing
                # with a TypeError.
                reduced = strategy.reduce(
                    method_string=self._aggregation_method,
                    value=value,
                    destinations=self)
                return strategy.update(self, f, reduced, *other_args,
                                       **other_kwargs)

            return distribute_lib.get_tower_context().merge_call(
                merge_fn, *args, **kwargs)
Exemple #4
0
    def _assign_func(self, *args, **kwargs):
        """Runs the assign-style callable passed as `f` on this variable.

        `f` is removed from `kwargs`; all other arguments are forwarded.

        * Tower context: inside a `merge_call`, the value is reduced with
          `self._aggregation` and the reduced value is applied via `update`.
        * Cross-tower context inside an update call (an update device is set):
          `f` runs on this variable's component for that device.
        * Cross-tower context outside an update call: `f` runs through the
          strategy's `update`. When the result is tensor-like
          `DistributedValues`, a `Mirrored` of per-device identities that
          depend on the grouped updates is returned. Otherwise the grouped
          result is returned.

        Raises:
          ValueError: in tower context when `self._aggregation` is
            `vs.VariableAggregation.NONE`.
        """
        assign_fn = kwargs.pop("f")

        if not distribution_strategy_context.get_cross_tower_context():
            _assert_tower_context()
            # The values being assigned/added/subtracted are reduced first
            # (see the _reduce method), so an aggregation must be configured.
            if self._aggregation == vs.VariableAggregation.NONE:
                raise ValueError(
                    "You must specify an aggregation method to update a "
                    "MirroredVariable in Tower Context.")

            def merge_fn(strategy, value, *other_args, **other_kwargs):
                reduced = strategy.reduce(aggregation=self._aggregation,
                                          value=value,
                                          destinations=self)
                return strategy.update(self, assign_fn, reduced, *other_args,
                                       **other_kwargs)

            tower_context = distribution_strategy_context.get_tower_context()
            return tower_context.merge_call(merge_fn, *args, **kwargs)

        device = distribute_lib.get_update_device()
        if device is not None:
            # Already inside an update call: touch only this device's copy.
            return assign_fn(self.get(device=device), *args, **kwargs)

        strategy = distribution_strategy_context.get_distribution_strategy()
        updates = strategy.update(self, assign_fn, *args, **kwargs)
        grouped = strategy.group(updates)
        if not (isinstance(updates, DistributedValues) and
                updates.is_tensor_like):
            return grouped

        # Make every per-device result depend on `grouped`, so fetching any of
        # them runs all updates. Otherwise something like
        # session.run(mirrored_var.assign*(...)) may only update one tower.
        per_device = {}
        for dev in updates.devices:
            with ops.device(dev), ops.control_dependencies([grouped]):
                per_device[dev] = array_ops.identity(updates.get(dev))
        return Mirrored(per_device)
Exemple #5
0
 def get(self, device=None):
   """Look up this value's component for a device.

   Args:
     device: device to look up. When None, it is taken from the current tower
       context if one exists, else from the current update device, else from
       `device_util.current()`.

   Returns:
     The entry of `self._index` for the canonicalized device.

   Raises:
     ValueError: if `self._index` has no entry for that device.
   """
   if device is None:
     tower_context = distribute_lib.get_tower_context()
     if not tower_context:
       device = distribute_lib.get_update_device()
       if device is None:
         device = device_util.current()
     else:
       device = tower_context.device
   key = device_util.canonicalize(device)
   try:
     return self._index[key]
   except KeyError:
     raise ValueError("Device %s not found in %s (current device %s)" %
                      (key, self._index.keys(), device_util.current()))
Exemple #6
0
  def _assign_func(self, *args, **kwargs):
    """Runs the assign-style callable passed as `f` on this variable.

    `f` is removed from `kwargs`; the remaining arguments are forwarded.

    * Cross-tower context inside an update call (an update device is set):
      `f` is applied to this variable's component on that device.
    * Cross-tower context outside an update call: `f` goes through the
      strategy's `update`. Tensor-like `DistributedValues` results are
      re-wrapped as a `Mirrored` whose entries depend on the grouped updates.
      Anything else yields the grouped result.
    * Tower context: inside a `merge_call` the value is reduced with
      `self._aggregation`, then passed to `update`.

    Raises:
      ValueError: in tower context when `self._aggregation` is
        `vs.VariableAggregation.NONE`.
    """
    assign_fn = kwargs.pop("f")
    if distribution_strategy_context.get_cross_tower_context():
      device = distribute_lib.get_update_device()
      if device is not None:
        # Inside an update call: modify only this device's component.
        component = self.get(device=device)
        return assign_fn(component, *args, **kwargs)

      # Cross tower context: route the assignment through `update`.
      strategy = distribution_strategy_context.get_distribution_strategy()
      updates = strategy.update(self, assign_fn, *args, **kwargs)
      grouped = strategy.group(updates)
      tensor_like = (isinstance(updates, DistributedValues) and
                     updates.is_tensor_like)
      if not tensor_like:
        return grouped
      # Tie each per-device output to `grouped` so that fetching any of them
      # runs every update; without this something like
      # session.run(mirrored_var.assign*(...)) may only update one tower.
      per_device = {}
      for dev in updates.devices:
        with ops.device(dev), ops.control_dependencies([grouped]):
          per_device[dev] = array_ops.identity(updates.get(dev))
      return Mirrored(per_device)

    _assert_tower_context()
    # Tower context: the value to assign/add/sub is reduced first (see the
    # _reduce method), which requires an aggregation method.
    if self._aggregation == vs.VariableAggregation.NONE:
      raise ValueError("You must specify an aggregation method to update a "
                       "MirroredVariable in Tower Context.")

    def merge_fn(strategy, value, *other_args, **other_kwargs):
      reduced = strategy.reduce(
          aggregation=self._aggregation, value=value, destinations=self)
      return strategy.update(self, assign_fn, reduced, *other_args,
                             **other_kwargs)

    tower_context = distribution_strategy_context.get_tower_context()
    return tower_context.merge_call(merge_fn, *args, **kwargs)
Exemple #7
0
def _get_update_device():
  """Return the device of the enclosing update call, raising if there is none.

  MirroredVariable.assign* members call this so that they only run from
  within DistributionStrategy.update() / update_non_slot(). This keeps all
  components of the variable updated consistently.

  Returns:
    A string device.

  Raises:
    RuntimeError: If not in distribution.update()/.update_non_slot().
  """
  current = distribute_lib.get_update_device()
  if current is not None:
    return current
  raise RuntimeError(
      "Use DistributionStrategy.update() to modify a MirroredVariable.")