def _update(self, var, fn, *args, **kwargs):
    """Applies `fn` to each per-device component of a distributed variable.

    For every (device, component) pair in `var._index`, `fn` runs inside
    `ops.device(device)`, `distribute_lib.UpdateContext(device)` and a name
    scope called "update_<n>", where <n> is looked up in `self._device_index`.
    Positional and keyword arguments are narrowed to that device with
    `values.select_device_mirrored` before the call.

    Args:
      var: the variable to update; asserted to be a
        `values.DistributedVariable`.
      fn: callable invoked as `fn(component, *device_args, **device_kwargs)`.
      *args: positional arguments forwarded to `fn` after per-device selection.
      **kwargs: keyword arguments forwarded to `fn` after per-device selection.

    Returns:
      The per-device results combined by
      `values.regroup(<results>, values.Mirrored)`.
    """
    # TODO(josh11b): In eager mode, use one thread per device.
    assert isinstance(var, values.DistributedVariable)
    per_device = {}
    for device, component in var._index.items():  # pylint: disable=protected-access
        scope_name = "update_%d" % self._device_index.get(device)
        with ops.device(device):
            with distribute_lib.UpdateContext(device):
                with ops.name_scope(scope_name):
                    # Non-mirrored args/kwargs are passed through unchanged.
                    device_args = values.select_device_mirrored(device, args)
                    device_kwargs = values.select_device_mirrored(device, kwargs)
                    per_device[device] = fn(component, *device_args,
                                            **device_kwargs)
    return values.regroup(per_device, values.Mirrored)
def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
    """Runs `fn` once for every device in `colocate_with`.

    Each call happens under `ops.device(device)`,
    `distribute_lib.UpdateContext(device)` and a name scope "update_<n>",
    where <n> is looked up in `self._device_index`.

    Args:
      colocate_with: list of devices (asserted to be a `list`); presumably
        device strings, since each entry is passed to `ops.device` — confirm.
      options: dict of options. The "grouped" entry is popped (so the caller's
        dict is mutated, and a missing key raises KeyError). The dict is then
        asserted to be empty, i.e. no other options are accepted.
      fn: callable invoked as `fn(*device_args, **device_kwargs)`.
      *args: forwarded to `fn` after `values.select_device_mirrored`.
      **kwargs: forwarded to `fn` after `values.select_device_mirrored`.

    Returns:
      `values.update_regroup(self, <results by device>, grouped)`.
    """
    assert isinstance(colocate_with, list)
    should_group = options.pop("grouped")
    assert not options  # Validate that we are processing all of the options.

    # TODO(josh11b): In eager mode, use one thread per device.
    def _run_on(device):
        # Runs `fn` for one device inside its device/update/name scopes.
        scope_name = "update_%d" % self._device_index.get(device)
        with ops.device(device), distribute_lib.UpdateContext(device), \
                ops.name_scope(scope_name):
            return fn(*values.select_device_mirrored(device, args),
                      **values.select_device_mirrored(device, kwargs))

    results = {device: _run_on(device) for device in colocate_with}
    return values.update_regroup(self, results, should_group)
def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
    """Runs `fn` once for every device in `colocate_with`.

    Each call happens under `ops.device(device)`,
    `distribute_lib.UpdateContext(device)` and a name scope "update_<n>",
    where <n> is looked up in `self._device_index`.

    Args:
      colocate_with: list of devices (asserted to be a `list`); presumably
        device strings, since each entry is passed to `ops.device` — confirm.
      fn: callable invoked as `fn(*device_args, **device_kwargs)`.
      *args: forwarded to `fn` after `values.select_device_mirrored`.
      **kwargs: forwarded to `fn` after `values.select_device_mirrored`.

    Returns:
      The per-device results combined by
      `values.regroup(<results>, values.Mirrored)`.
    """
    assert isinstance(colocate_with, list)

    # TODO(josh11b): In eager mode, use one thread per device.
    def _run_on(device):
        # Runs `fn` for one device inside its device/update/name scopes.
        scope_name = "update_%d" % self._device_index.get(device)
        with ops.device(device), distribute_lib.UpdateContext(device), \
                ops.name_scope(scope_name):
            return fn(*values.select_device_mirrored(device, args),
                      **values.select_device_mirrored(device, kwargs))

    results = {device: _run_on(device) for device in colocate_with}
    return values.regroup(results, values.Mirrored)
def _update(self, var, fn, args, kwargs, group):
    """Applies `fn` to a single (non-mirrored) variable on its own device.

    Args:
      var: the variable to update. A `values.AggregatingVariable` is first
        unwrapped with `.get()`; the result must be a
        `resource_variable_ops.ResourceVariable`.
      fn: callable invoked as `fn(var, *args', **kwargs')`, where args' and
        kwargs' are `args`/`kwargs` passed through
        `self._select_single_value`.
      args: positional arguments for `fn` (expanded with `*`).
      kwargs: keyword arguments for `fn` (expanded with `**`).
      group: if true, `fn`'s result is returned as-is; otherwise
        `self._unwrap` is mapped over its structure with `nest.map_structure`.

    Returns:
      The (optionally unwrapped) result of `fn`.

    Raises:
      ValueError: if `var` (after unwrapping) is not a ResourceVariable.
    """
    if isinstance(var, values.AggregatingVariable):
        var = var.get()
    if not isinstance(var, resource_variable_ops.ResourceVariable):
        # Wrap `var` in a 1-tuple: if `var` were itself a tuple, `%` would
        # treat it as the argument list and raise TypeError, not ValueError.
        raise ValueError(
            "You can not update `var` %r. It must be a Variable." % (var,))
    with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
        result = fn(var, *self._select_single_value(args),
                    **self._select_single_value(kwargs))
        if group:
            return result
        return nest.map_structure(self._unwrap, result)
def _update(self, var, fn, *args, **kwargs):
    """Applies `fn` to each per-device component of a mirrored variable.

    For every (device, component) pair in `var._index`, `fn` runs inside
    `ops.device(device)`, `distribute_lib.UpdateContext(device)` and a name
    scope "update_<n>", where <n> is looked up in `self._device_index`.

    Args:
      var: the variable to update; asserted to be a `values.MirroredVariable`.
      fn: callable invoked as `fn(component, *device_args, **device_kwargs)`.
      *args: forwarded to `fn` after `values.select_device_mirrored`.
      **kwargs: forwarded to `fn` after `values.select_device_mirrored`.

    Returns:
      The per-device results combined by
      `values.regroup(<results>, values.Mirrored)`.
    """
    # TODO(josh11b): Also support TowerLocalVariables here? If so, args and
    # kwargs don't need to be mirrored.
    assert isinstance(var, values.MirroredVariable)
    # TODO(josh11b): In eager mode, use one thread per device.
    outputs = {}
    for device, component in var._index.items():  # pylint: disable=protected-access
        scope = "update_%d" % self._device_index.get(device)
        with ops.device(device):
            with distribute_lib.UpdateContext(device):
                with ops.name_scope(scope):
                    selected_args = values.select_device_mirrored(device, args)
                    selected_kwargs = values.select_device_mirrored(device,
                                                                    kwargs)
                    outputs[device] = fn(component, *selected_args,
                                         **selected_kwargs)
    return values.regroup(outputs, values.Mirrored)
def _update(self, var, options, fn, *args, **kwargs):
    """Applies `fn` to each per-device component of a distributed variable.

    For every (device, component) pair in `var._index`, `fn` runs inside
    `ops.device(device)`, `distribute_lib.UpdateContext(device)` and a name
    scope "update_<n>", where <n> is looked up in `self._device_index`.

    Args:
      var: the variable to update; asserted to be a
        `values.DistributedVariable`.
      options: dict of options. The "grouped" entry is popped (mutating the
        caller's dict; a missing key raises KeyError), and the dict is then
        asserted to be empty.
      fn: callable invoked as `fn(component, *device_args, **device_kwargs)`.
      *args: forwarded to `fn` after `values.select_device_mirrored`.
      **kwargs: forwarded to `fn` after `values.select_device_mirrored`.

    Returns:
      `values.update_regroup(self, <results by device>, grouped)`.
    """
    # TODO(josh11b): In eager mode, use one thread per device.
    assert isinstance(var, values.DistributedVariable)
    should_group = options.pop("grouped")
    assert not options  # Validate that we are processing all of the options.

    def _apply(device, component):
        # Runs `fn` on one component inside its device/update/name scopes.
        scope_name = "update_%d" % self._device_index.get(device)
        with ops.device(device), distribute_lib.UpdateContext(device), \
                ops.name_scope(scope_name):
            # If args and kwargs are not mirrored, the value is returned as is.
            return fn(component,
                      *values.select_device_mirrored(device, args),
                      **values.select_device_mirrored(device, kwargs))

    results = {
        device: _apply(device, component)
        for device, component in var._index.items()  # pylint: disable=protected-access
    }
    return values.update_regroup(self, results, should_group)
def _update(self, var, options, fn, *args, **kwargs):
    """Applies `fn` to a single (non-mirrored) variable on its own device.

    Args:
      var: the variable to update. A `values.AggregatingVariable` is first
        unwrapped with `.get()`; the result must be a
        `resource_variable_ops.ResourceVariable`.
      options: dict of options. The "grouped" entry is popped (mutating the
        caller's dict; a missing key raises KeyError), and the dict is then
        asserted to be empty.
      fn: callable invoked as `fn(var, *args', **kwargs')`, where args' and
        kwargs' are `args`/`kwargs` passed through
        `self._select_single_value`.
      *args: positional arguments for `fn`.
      **kwargs: keyword arguments for `fn`.

    Returns:
      `fn`'s result as-is when "grouped" is true; otherwise `self._unwrap`
      mapped over its structure with `nest.map_structure`.

    Raises:
      ValueError: if `var` (after unwrapping) is not a ResourceVariable.
    """
    if isinstance(var, values.AggregatingVariable):
        var = var.get()
    if not isinstance(var, resource_variable_ops.ResourceVariable):
        # Wrap `var` in a 1-tuple: if `var` were itself a tuple, `%` would
        # treat it as the argument list and raise TypeError, not ValueError.
        raise ValueError(
            "You can not update `var` %r. It must be a Variable." % (var,))
    should_group = options.pop("grouped")
    assert not options  # Validate that we are processing all of the options.
    with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
        result = fn(var, *self._select_single_value(args),
                    **self._select_single_value(kwargs))
        if should_group:
            return result
        return nest.map_structure(self._unwrap, result)
def _update(self, var, fn, args, kwargs, group):
    """Updates a TPU mirrored variable, inside or outside a TPU context.

    When `values._enclosing_tpu_context()` is not None, `fn` is called exactly
    once on `var` itself. It returns that result directly when `group` is true,
    or wrapped in a one-element list otherwise. Outside a TPU context this
    falls back to per-device updates, like MirroredStrategy: `fn` runs on each
    component of `var._index` under that device's device/update/name scopes.

    Args:
      var: asserted to be a `values.TPUMirroredVariable`.
      fn: callable taking the variable (or component) as its first argument.
      args: positional arguments for `fn`. On the per-device path they are
        narrowed with `values.select_device_mirrored`.
      kwargs: keyword arguments for `fn`. On the per-device path they are
        narrowed with `values.select_device_mirrored`.
      group: grouping flag. It selects the return shape in the TPU branch and
        is passed to `values.update_regroup` otherwise.

    Returns:
      See above.
    """
    assert isinstance(var, values.TPUMirroredVariable)
    tpu_context = values._enclosing_tpu_context()  # pylint: disable=protected-access
    if tpu_context is not None:
        single_result = fn(var, *args, **kwargs)
        return single_result if group else [single_result]

    # Otherwise, we revert to MirroredStrategy behavior and update each variable
    # directly.
    per_device = {}
    for device, component in var._index.items():  # pylint: disable=protected-access
        scope_name = "update_%d" % self._device_index.get(device)
        with ops.device(device):
            with distribute_lib.UpdateContext(device):
                with ops.name_scope(scope_name):
                    # Non-mirrored args/kwargs are passed through unchanged.
                    per_device[device] = fn(
                        component,
                        *values.select_device_mirrored(device, args),
                        **values.select_device_mirrored(device, kwargs))
    return values.update_regroup(self, per_device, group)
def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
    """Runs `fn(*args, **kwargs)` on the device of `colocate_with`.

    Args:
      colocate_with: object with a `.device` attribute; that device is used
        for `ops.device`.
      fn: callable invoked as `fn(*args, **kwargs)`.
      *args: forwarded to `fn` unchanged.
      **kwargs: forwarded to `fn` unchanged.

    Returns:
      Whatever `fn` returns.
    """
    with ops.device(colocate_with.device):
        # NOTE(review): UpdateContext receives `colocate_with` itself, not its
        # `.device` as the sibling implementations pass — confirm intended.
        with distribute_lib.UpdateContext(colocate_with):
            return fn(*args, **kwargs)
def _update(self, var, fn, *args, **kwargs):
    """Runs `fn(var, *args, **kwargs)` on this strategy's single device.

    Args:
      var: the variable passed as `fn`'s first argument.
      fn: callable invoked as `fn(var, *args, **kwargs)`.
      *args: forwarded to `fn` unchanged.
      **kwargs: forwarded to `fn` unchanged.

    Returns:
      Whatever `fn` returns.
    """
    with ops.device(self._device):
        with distribute_lib.UpdateContext(self._device):
            return fn(var, *args, **kwargs)