示例#1
0
 def _update(self, var, fn, *args, **kwargs):
   """Applies `fn` to each per-device component of a distributed variable.

   Every (device, component) entry of `var._index` is updated under
   `ops.device(device)`, `distribute_lib.UpdateContext(device)` and a name
   scope `update_<i>`, where `<i>` is looked up in `self._device_index`.

   Args:
     var: a `values.DistributedVariable`.
     fn: callable, invoked as `fn(component, *args, **kwargs)` once per
       device.
     *args: positional arguments; each device call receives
       `values.select_device_mirrored(device, args)`.
     **kwargs: keyword arguments; each device call receives
       `values.select_device_mirrored(device, kwargs)`.

   Returns:
     The per-device results combined with
     `values.regroup(results, values.Mirrored)`.
   """
   # TODO(josh11b): In eager mode, use one thread per device.
   assert isinstance(var, values.DistributedVariable)
   results = {}
   for dev, local_var in var._index.items():  # pylint: disable=protected-access
     scope_name = "update_%d" % self._device_index.get(dev)
     with ops.device(dev):
       with distribute_lib.UpdateContext(dev):
         with ops.name_scope(scope_name):
           # Arguments that are not mirrored come back unchanged.
           dev_args = values.select_device_mirrored(dev, args)
           dev_kwargs = values.select_device_mirrored(dev, kwargs)
           results[dev] = fn(local_var, *dev_args, **dev_kwargs)
   return values.regroup(results, values.Mirrored)
示例#2
0
 def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
   """Runs `fn` once for every device listed in `colocate_with`.

   Args:
     colocate_with: list whose entries are used as devices (passed to
       `ops.device` and `distribute_lib.UpdateContext`).
     options: dict of update options. Its "grouped" entry is popped (the
       caller's dict is mutated) and no other key may remain.
     fn: callable, invoked as `fn(*args, **kwargs)` once per device inside
       that device's scope and an `update_<i>` name scope.
     *args: positional arguments; each device call receives
       `values.select_device_mirrored(device, args)`.
     **kwargs: keyword arguments; each device call receives
       `values.select_device_mirrored(device, kwargs)`.

   Returns:
     `values.update_regroup(self, results, grouped)` over the per-device
     results.
   """
   assert isinstance(colocate_with, list)
   grouped = options.pop("grouped")
   # Any key still present is an option this method does not handle.
   assert not options
   # TODO(josh11b): In eager mode, use one thread per device.
   results = {}
   for dev in colocate_with:
     scope_name = "update_%d" % self._device_index.get(dev)
     with ops.device(dev):
       with distribute_lib.UpdateContext(dev):
         with ops.name_scope(scope_name):
           dev_args = values.select_device_mirrored(dev, args)
           dev_kwargs = values.select_device_mirrored(dev, kwargs)
           results[dev] = fn(*dev_args, **dev_kwargs)
   return values.update_regroup(self, results, grouped)
示例#3
0
 def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
     """Runs `fn` once per device in `colocate_with` and mirrors the results.

     Args:
       colocate_with: list whose entries are used as devices (each is
         passed to `ops.device`).
       fn: callable, invoked as `fn(*args, **kwargs)` for each device.
       *args: positional arguments; each device call receives
         `values.select_device_mirrored(device, args)`.
       **kwargs: keyword arguments, handled the same way.

     Returns:
       `values.regroup(results, values.Mirrored)` of the per-device results.
     """
     assert isinstance(colocate_with, list)
     # TODO(josh11b): In eager mode, use one thread per device.

     def run_on(dev):
         # Enter the device scope, update context and `update_<i>` name scope
         # for `dev`, then call `fn` with that device's view of the arguments.
         scope_name = "update_%d" % self._device_index.get(dev)
         with ops.device(dev), distribute_lib.UpdateContext(
                 dev), ops.name_scope(scope_name):
             return fn(*values.select_device_mirrored(dev, args),
                       **values.select_device_mirrored(dev, kwargs))

     results = {dev: run_on(dev) for dev in colocate_with}
     return values.regroup(results, values.Mirrored)
示例#4
0
 def _update(self, var, fn, args, kwargs, group):
     """Updates a single (non-mirrored) variable by calling `fn` on it.

     Args:
       var: the variable to update. A `values.AggregatingVariable` is first
         unwrapped with `.get()`; the result must be a
         `resource_variable_ops.ResourceVariable`.
       fn: callable, invoked as `fn(var, *args, **kwargs)`.
       args: positional arguments, passed through `self._select_single_value`.
       kwargs: keyword arguments, passed through `self._select_single_value`.
       group: if true, `fn`'s result is returned as is; otherwise
         `self._unwrap` is mapped over it with `nest.map_structure`.

     Returns:
       The (possibly unwrapped) result of `fn`.

     Raises:
       ValueError: if `var` is not a `ResourceVariable` after unwrapping.
     """
     if isinstance(var, values.AggregatingVariable):
         var = var.get()
     if not isinstance(var, resource_variable_ops.ResourceVariable):
         # Wrap `var` in a 1-tuple. With a bare `% var`, a tuple argument
         # would be consumed as the format's argument list, and building the
         # message would raise TypeError instead of this ValueError.
         raise ValueError(
             "You can not update `var` %r. It must be a Variable." % (var,))
     with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
         result = fn(var, *self._select_single_value(args),
                     **self._select_single_value(kwargs))
         if group:
             return result
         else:
             return nest.map_structure(self._unwrap, result)
示例#5
0
 def _update(self, var, fn, *args, **kwargs):
     """Applies `fn` to each per-device component of a mirrored variable.

     Each (device, component) entry of `var._index` is updated under the
     device scope, an `UpdateContext` for that device and an `update_<i>`
     name scope, where `<i>` comes from `self._device_index`.

     Args:
       var: a `values.MirroredVariable`.
       fn: callable, invoked as `fn(component, *args, **kwargs)` per device.
       *args: positional arguments; each device call receives
         `values.select_device_mirrored(device, args)`.
       **kwargs: keyword arguments, handled the same way.

     Returns:
       `values.regroup(results, values.Mirrored)` of the per-device results.
     """
     # TODO(josh11b): Also support TowerLocalVariables here? If so, args and
     # kwargs don't need to be mirrored.
     assert isinstance(var, values.MirroredVariable)
     # TODO(josh11b): In eager mode, use one thread per device.

     def apply_on(dev, local_var):
         # Update one component inside the scopes belonging to its device.
         scope_name = "update_%d" % self._device_index.get(dev)
         with ops.device(dev), distribute_lib.UpdateContext(
                 dev), ops.name_scope(scope_name):
             return fn(local_var, *values.select_device_mirrored(dev, args),
                       **values.select_device_mirrored(dev, kwargs))

     results = {
         dev: apply_on(dev, local_var)
         for dev, local_var in var._index.items()  # pylint: disable=protected-access
     }
     return values.regroup(results, values.Mirrored)
示例#6
0
 def _update(self, var, options, fn, *args, **kwargs):
     """Applies `fn` to each per-device component of a distributed variable.

     Args:
       var: a `values.DistributedVariable`.
       options: dict of update options. Its "grouped" entry is popped (the
         caller's dict is mutated) and no other key may remain.
       fn: callable, invoked as `fn(component, *args, **kwargs)` once per
         device, inside that device's scope, an `UpdateContext` for it and an
         `update_<i>` name scope.
       *args: positional arguments; each device call receives
         `values.select_device_mirrored(device, args)`.
       **kwargs: keyword arguments, handled the same way.

     Returns:
       `values.update_regroup(self, results, grouped)` over the per-device
       results.
     """
     # TODO(josh11b): In eager mode, use one thread per device.
     assert isinstance(var, values.DistributedVariable)
     grouped = options.pop("grouped")
     assert not options  # Every option must have been consumed above.
     results = {}
     for dev, local_var in var._index.items():  # pylint: disable=protected-access
         scope_name = "update_%d" % self._device_index.get(dev)
         with ops.device(dev):
             with distribute_lib.UpdateContext(dev):
                 with ops.name_scope(scope_name):
                     # Non-mirrored arguments are returned unchanged.
                     dev_args = values.select_device_mirrored(dev, args)
                     dev_kwargs = values.select_device_mirrored(dev, kwargs)
                     results[dev] = fn(local_var, *dev_args, **dev_kwargs)
     return values.update_regroup(self, results, grouped)
示例#7
0
 def _update(self, var, options, fn, *args, **kwargs):
     """Updates a single (non-mirrored) variable by calling `fn` on it.

     Args:
       var: the variable to update. A `values.AggregatingVariable` is first
         unwrapped with `.get()`; the result must be a
         `resource_variable_ops.ResourceVariable`.
       options: dict of update options. Its "grouped" entry is popped (the
         caller's dict is mutated) and no other key may remain.
       fn: callable, invoked as `fn(var, *args, **kwargs)`.
       *args: positional arguments, passed through `self._select_single_value`.
       **kwargs: keyword arguments, passed through `self._select_single_value`.

     Returns:
       `fn`'s result when "grouped" is true; otherwise `self._unwrap` mapped
       over it with `nest.map_structure`.

     Raises:
       ValueError: if `var` is not a `ResourceVariable` after unwrapping.
     """
     if isinstance(var, values.AggregatingVariable):
         var = var.get()
     if not isinstance(var, resource_variable_ops.ResourceVariable):
         # Wrap `var` in a 1-tuple. With a bare `% var`, a tuple argument
         # would be consumed as the format's argument list, and building the
         # message would raise TypeError instead of this ValueError.
         raise ValueError(
             "You can not update `var` %r. It must be a Variable." % (var,))
     should_group = options.pop("grouped")
     assert not options  # Validate that we are processing all of the options.
     with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
         result = fn(var, *self._select_single_value(args),
                     **self._select_single_value(kwargs))
         if should_group:
             return result
         else:
             return nest.map_structure(self._unwrap, result)
示例#8
0
    def _update(self, var, fn, args, kwargs, group):
        """Updates a `values.TPUMirroredVariable` with `fn`.

        When `values._enclosing_tpu_context()` reports an enclosing TPU
        context, `fn` is called once on `var` itself. Otherwise each
        per-device component of `var` is updated separately, mirroring the
        MirroredStrategy code path.

        Args:
          var: a `values.TPUMirroredVariable`.
          fn: callable, invoked as `fn(variable, *args, **kwargs)`.
          args: iterable of positional arguments, unpacked into `fn`.
          kwargs: mapping of keyword arguments, unpacked into `fn`.
          group: in the TPU-context branch, selects between returning `fn`'s
            result directly (true) or wrapped in a one-element list (false);
            otherwise it is forwarded to `values.update_regroup`.

        Returns:
          The result of `fn`, shaped as described under `group`.
        """
        assert isinstance(var, values.TPUMirroredVariable)
        tpu_context = values._enclosing_tpu_context()  # pylint: disable=protected-access
        if tpu_context is not None:
            result = fn(var, *args, **kwargs)
            return result if group else [result]

        # No TPU context: fall back to MirroredStrategy behavior and update
        # every per-device component directly.
        results = {}
        for dev, local_var in var._index.items():  # pylint: disable=protected-access
            scope_name = "update_%d" % self._device_index.get(dev)
            with ops.device(dev):
                with distribute_lib.UpdateContext(dev):
                    with ops.name_scope(scope_name):
                        # Non-mirrored arguments are returned as is.
                        dev_args = values.select_device_mirrored(dev, args)
                        dev_kwargs = values.select_device_mirrored(dev, kwargs)
                        results[dev] = fn(local_var, *dev_args, **dev_kwargs)
        return values.update_regroup(self, results, group)
 def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
     """Calls `fn(*args, **kwargs)` on the device of `colocate_with`.

     Args:
       colocate_with: object whose `.device` selects where `fn` runs; the
         object itself is handed to `distribute_lib.UpdateContext`.
       fn: callable to invoke.
       *args: positional arguments forwarded unchanged to `fn`.
       **kwargs: keyword arguments forwarded unchanged to `fn`.

     Returns:
       Whatever `fn` returns.
     """
     target = colocate_with.device
     with ops.device(target):
         # NOTE(review): UpdateContext receives `colocate_with` itself rather
         # than its device string; confirm this is what UpdateContext expects.
         with distribute_lib.UpdateContext(colocate_with):
             return fn(*args, **kwargs)
示例#10
0
 def _update(self, var, fn, *args, **kwargs):
     """Calls `fn(var, *args, **kwargs)` on `self._device`.

     The call runs under `ops.device(self._device)` and
     `distribute_lib.UpdateContext(self._device)`; the arguments are
     forwarded unchanged.

     Returns:
       Whatever `fn` returns.
     """
     device = self._device
     with ops.device(device):
         with distribute_lib.UpdateContext(device):
             return fn(var, *args, **kwargs)