Ejemplo n.º 1
0
 def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
   """Runs `fn` once per device in `colocate_with` and regroups the results.

   For the device at position `i` of `colocate_with`, `fn` is called inside
   `ops.device(device)`, `distribute_lib.UpdateContext(device)` and
   `ops.name_scope("update_<i>")`. Its arguments are first passed through
   `values.select_device_mirrored(device, ...)`. Non-mirrored values come
   back unchanged.

   Args:
     colocate_with: tuple of devices (asserted).
     fn: callable invoked once per device.
     args: positional arguments for `fn`.
     kwargs: keyword arguments for `fn`.
     group: forwarded unchanged to `values.update_regroup`.

   Returns:
     The result of `values.update_regroup(self, self._device_map, results,
     group)`, where `results` lists the per-device outputs of `fn` in the
     order of `colocate_with`.
   """
   assert isinstance(colocate_with, tuple)
   # TODO(josh11b): In eager mode, use one thread per device.
   per_device_results = []
   for position, device in enumerate(colocate_with):
     scope_name = "update_%d" % position
     with ops.device(device):
       with distribute_lib.UpdateContext(device):
         with ops.name_scope(scope_name):
           device_args = values.select_device_mirrored(device, args)
           device_kwargs = values.select_device_mirrored(device, kwargs)
           per_device_results.append(fn(*device_args, **device_kwargs))
   return values.update_regroup(
       self, self._device_map, per_device_results, group)
Ejemplo n.º 2
0
 def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
   """Calls `fn` on each device in `colocate_with`; results keyed by device.

   For each device, the name scope is "update_%d" formatted with
   `self._device_index.get(device)`. The call runs inside
   `ops.device(device)`, `distribute_lib.UpdateContext(device)` and that
   name scope, with `args`/`kwargs` passed through
   `values.select_device_mirrored(device, ...)`.

   Args:
     colocate_with: list of devices (asserted).
     fn: callable invoked once per device.
     args: positional arguments for `fn`.
     kwargs: keyword arguments for `fn`.
     group: forwarded unchanged to `values.update_regroup`.

   Returns:
     `values.update_regroup(self, updates, group)`, where `updates` maps
     each device to the value `fn` returned on it.
   """
   assert isinstance(colocate_with, list)

   def _call_on(device):
     # NOTE(review): `"%d" % None` raises TypeError, so a device missing from
     # `self._device_index` fails here -- confirm that cannot happen.
     scope_name = "update_%d" % self._device_index.get(device)
     with ops.device(device):
       with distribute_lib.UpdateContext(device):
         with ops.name_scope(scope_name):
           return fn(*values.select_device_mirrored(device, args),
                     **values.select_device_mirrored(device, kwargs))

   # TODO(josh11b): In eager mode, use one thread per device.
   updates = {device: _call_on(device) for device in colocate_with}
   return values.update_regroup(self, updates, group)
Ejemplo n.º 3
0
 def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
   """Applies `fn` once per device of `colocate_with` and regroups outputs.

   Each call is placed on its device under
   `distribute_lib.UpdateContext(device)` and an "update_<i>" name scope,
   where `<i>` is the device's position in `colocate_with`. Arguments are
   passed through `values.select_device_mirrored(device, ...)`.

   Args:
     colocate_with: tuple of devices (asserted).
     fn: callable invoked once per device.
     args: positional arguments for `fn`.
     kwargs: keyword arguments for `fn`.
     group: forwarded unchanged to `values.update_regroup`.

   Returns:
     `values.update_regroup(self, self._device_map, updates, group)` over
     the list of per-device results.
   """
   assert isinstance(colocate_with, tuple)

   def _apply_on(index, device):
     # Contexts are entered in order: device placement, update context, then
     # a name scope suffixed with the device's position in `colocate_with`.
     scope_name = "update_%d" % index
     with ops.device(device), distribute_lib.UpdateContext(device):
       with ops.name_scope(scope_name):
         return fn(*values.select_device_mirrored(device, args),
                   **values.select_device_mirrored(device, kwargs))

   # TODO(josh11b): In eager mode, use one thread per device.
   updates = [_apply_on(index, device)
              for index, device in enumerate(colocate_with)]
   return values.update_regroup(self, self._device_map, updates, group)
Ejemplo n.º 4
0
 def _update(self, var, fn, args, kwargs, group):
   """Applies `fn` to each per-device value of the distributed variable `var`.

   `var.devices` and `var.values` are walked in lockstep. For the `i`-th
   pair, `fn(value, *args, **kwargs)` runs inside `ops.device(device)`,
   `distribute_lib.UpdateContext(device)` and an "update_<i>" name scope.
   `args`/`kwargs` are first passed through
   `values.select_device_mirrored(device, ...)`.

   Args:
     var: a `values.DistributedVariable` (asserted).
     fn: callable taking a per-device value followed by `args`/`kwargs`.
     args: positional arguments for `fn`.
     kwargs: keyword arguments for `fn`.
     group: forwarded unchanged to `values.update_regroup`.

   Returns:
     `values.update_regroup(self, self._device_map, results, group)` over
     the per-device results in device order.
   """
   # TODO(josh11b): In eager mode, use one thread per device.
   assert isinstance(var, values.DistributedVariable)
   results = []
   for index, (device, component) in enumerate(zip(var.devices, var.values)):
     scope_name = "update_%d" % index
     with ops.device(device):
       with distribute_lib.UpdateContext(device):
         with ops.name_scope(scope_name):
           # If args and kwargs are not mirrored, the value is returned as is.
           device_args = values.select_device_mirrored(device, args)
           device_kwargs = values.select_device_mirrored(device, kwargs)
           results.append(fn(component, *device_args, **device_kwargs))
   return values.update_regroup(self, self._device_map, results, group)
Ejemplo n.º 5
0
 def _update(self, var, fn, args, kwargs, group):
     """Applies `fn` to each (device, value) entry of `var._index`.

     Each call runs inside `ops.device(device)`,
     `distribute_lib.UpdateContext(device)` and a name scope "update_%d"
     formatted with `self._device_index.get(device)`. `args`/`kwargs` are
     passed through `values.select_device_mirrored(device, ...)`.

     Args:
         var: a `values.DistributedVariable` (asserted).
         fn: callable taking a per-device value followed by `args`/`kwargs`.
         args: positional arguments for `fn`.
         kwargs: keyword arguments for `fn`.
         group: forwarded unchanged to `values.update_regroup`.

     Returns:
         `values.update_regroup(self, updates, group)`, where `updates` maps
         each device to the result of `fn` on it.
     """
     # TODO(josh11b): In eager mode, use one thread per device.
     assert isinstance(var, values.DistributedVariable)

     def _update_component(device, component):
         scope_name = "update_%d" % self._device_index.get(device)
         with ops.device(device), distribute_lib.UpdateContext(device):
             with ops.name_scope(scope_name):
                 # If args and kwargs are not mirrored, the value is returned
                 # as is.
                 return fn(component,
                           *values.select_device_mirrored(device, args),
                           **values.select_device_mirrored(device, kwargs))

     updates = {
         device: _update_component(device, component)
         for device, component in var._index.items()  # pylint: disable=protected-access
     }
     return values.update_regroup(self, updates, group)
Ejemplo n.º 6
0
 def _update(self, var, fn, args, kwargs, group):
   """Applies `fn` to each (device, value) entry of `var._index`.

   For every entry, `fn(value, *args, **kwargs)` runs inside
   `ops.device(device)`, `distribute_lib.UpdateContext(device)` and a name
   scope "update_%d" formatted with `self._device_index.get(device)`.
   `args`/`kwargs` are passed through
   `values.select_device_mirrored(device, ...)`.

   Args:
     var: a `values.DistributedVariable` (asserted).
     fn: callable taking a per-device value followed by `args`/`kwargs`.
     args: positional arguments for `fn`.
     kwargs: keyword arguments for `fn`.
     group: forwarded unchanged to `values.update_regroup`.

   Returns:
     `values.update_regroup(self, updates, group)`, where `updates` maps
     each device to the result of `fn` on it.
   """
   # TODO(josh11b): In eager mode, use one thread per device.
   assert isinstance(var, values.DistributedVariable)
   updates = {}
   for device, component in var._index.items():  # pylint: disable=protected-access
     scope_name = "update_%d" % self._device_index.get(device)
     with ops.device(device):
       with distribute_lib.UpdateContext(device):
         with ops.name_scope(scope_name):
           # If args and kwargs are not mirrored, the value is returned as is.
           device_args = values.select_device_mirrored(device, args)
           device_kwargs = values.select_device_mirrored(device, kwargs)
           updates[device] = fn(component, *device_args, **device_kwargs)
   return values.update_regroup(self, updates, group)
Ejemplo n.º 7
0
 def _update(self, var, fn, args, kwargs, group):
   """Applies `fn` to each per-device value of the distributed variable `var`.

   Pairs from `zip(var.devices, var.values)` are processed in order. The
   `i`-th call runs on its device under `distribute_lib.UpdateContext(device)`
   and an "update_<i>" name scope, with `args`/`kwargs` passed through
   `values.select_device_mirrored(device, ...)`.

   Args:
     var: a `values.DistributedVariable` (asserted).
     fn: callable taking a per-device value followed by `args`/`kwargs`.
     args: positional arguments for `fn`.
     kwargs: keyword arguments for `fn`.
     group: forwarded unchanged to `values.update_regroup`.

   Returns:
     `values.update_regroup(self, self._device_map, updates, group)` over
     the list of per-device results.
   """
   # TODO(josh11b): In eager mode, use one thread per device.
   assert isinstance(var, values.DistributedVariable)

   def _update_one(index, device, component):
     scope_name = "update_%d" % index
     with ops.device(device), distribute_lib.UpdateContext(device):
       with ops.name_scope(scope_name):
         # If args and kwargs are not mirrored, the value is returned as is.
         return fn(component,
                   *values.select_device_mirrored(device, args),
                   **values.select_device_mirrored(device, kwargs))

   updates = [
       _update_one(index, device, component)
       for index, (device, component) in enumerate(zip(var.devices, var.values))
   ]
   return values.update_regroup(self, self._device_map, updates, group)
Ejemplo n.º 8
0
    def _update(self, var, fn, args, kwargs, group):
        """Applies `fn` to a TPU mirrored variable.

        When `values._enclosing_tpu_context()` is not None, `fn` is called
        once on `var` itself with the unmodified `args`/`kwargs`. The result
        is returned as is if `group` is truthy, otherwise wrapped in a
        one-element list.

        Otherwise, every (device, value) entry of `var._index` is updated.
        Each call runs inside `ops.device(device)`,
        `distribute_lib.UpdateContext(device)` and a name scope "update_%d"
        formatted with `self._device_index.get(device)`. The results are
        regrouped with `values.update_regroup(self, updates, group)`.
        """
        assert isinstance(var, values.TPUMirroredVariable)
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            single_update = fn(var, *args, **kwargs)
            return single_update if group else [single_update]

        # Otherwise, we revert to MirroredStrategy behavior and update each
        # variable directly.
        updates = {}
        for device, component in var._index.items():  # pylint: disable=protected-access
            scope_name = "update_%d" % self._device_index.get(device)
            with ops.device(device):
                with distribute_lib.UpdateContext(device):
                    with ops.name_scope(scope_name):
                        # If args and kwargs are not mirrored, the value is
                        # returned as is.
                        updates[device] = fn(
                            component,
                            *values.select_device_mirrored(device, args),
                            **values.select_device_mirrored(device, kwargs))
        return values.update_regroup(self, updates, group)
Ejemplo n.º 9
0
  def _update(self, var, fn, args, kwargs, group):
    """Applies `fn` to a TPU mirrored variable.

    When `values._enclosing_tpu_context()` is not None, `fn` is called once
    on `var` itself with the unmodified `args`/`kwargs`. The result is
    returned as is if `group` is truthy, otherwise wrapped in a 1-tuple.

    Otherwise, `fn` is called on each pair of `zip(var.devices, var.values)`.
    The `i`-th call runs inside `ops.device(device)`,
    `distribute_lib.UpdateContext(device)` and an "update_<i>" name scope.
    The results are passed to
    `values.update_regroup(self, self._device_map, results, group)`.
    """
    assert isinstance(var, values.TPUMirroredVariable)
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      result = fn(var, *args, **kwargs)
      return result if group else (result,)

    # Otherwise, we revert to MirroredStrategy behavior and update each variable
    # directly.
    per_device = []
    for index, (device, component) in enumerate(zip(var.devices, var.values)):
      scope_name = "update_%d" % index
      with ops.device(device):
        with distribute_lib.UpdateContext(device):
          with ops.name_scope(scope_name):
            # If args and kwargs are not mirrored, the value is returned as is.
            per_device.append(
                fn(component,
                   *values.select_device_mirrored(device, args),
                   **values.select_device_mirrored(device, kwargs)))
    return values.update_regroup(self, self._device_map, per_device, group)
Ejemplo n.º 10
0
  def _update(self, var, fn, args, kwargs, group):
    """Applies `fn` to a TPU mirrored variable.

    When `values._enclosing_tpu_context()` is not None, `fn` is called once
    on `var` itself with the unmodified `args`/`kwargs`. The result is
    returned as is if `group` is truthy, otherwise wrapped in a one-element
    list.

    Otherwise, each (device, value) entry of `var._index` is updated inside
    `ops.device(device)`, `distribute_lib.UpdateContext(device)` and a name
    scope "update_%d" formatted with `self._device_index.get(device)`. The
    device-keyed results go to `values.update_regroup(self, updates, group)`.
    """
    assert isinstance(var, values.TPUMirroredVariable)
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      result = fn(var, *args, **kwargs)
      return result if group else [result]

    # Otherwise, we revert to MirroredStrategy behavior and update each variable
    # directly.
    def _update_on(device, component):
      scope_name = "update_%d" % self._device_index.get(device)
      with ops.device(device), distribute_lib.UpdateContext(device):
        with ops.name_scope(scope_name):
          # If args and kwargs are not mirrored, the value is returned as is.
          return fn(component,
                    *values.select_device_mirrored(device, args),
                    **values.select_device_mirrored(device, kwargs))

    updates = {
        device: _update_on(device, component)
        for device, component in var._index.items()  # pylint: disable=protected-access
    }
    return values.update_regroup(self, updates, group)
Ejemplo n.º 11
0
    def _update(self, var, fn, args, kwargs, group):
        """Applies `fn` to a TPU variable or a resource variable.

        When `values._enclosing_tpu_context()` is not None, `fn` is called
        once on `var` itself with the unmodified `args`/`kwargs`. The result
        is returned as is if `group` is truthy, otherwise wrapped in a
        1-tuple.

        Otherwise, `fn` is called on each entry of `var.values`. The `i`-th
        call runs inside `ops.device(value.device)`,
        `distribute_lib.UpdateContext(i)` and an "update_<i>" name scope.
        Its `args`/`kwargs` are passed through
        `values.select_replica_mirrored(i, ...)`. The results are regrouped
        with `values.update_regroup(self, updates, group)`.
        """
        assert (isinstance(var, values.TPUVariableMixin) or
                isinstance(var, resource_variable_ops.BaseResourceVariable))
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            result = fn(var, *args, **kwargs)
            return result if group else (result,)

        # Otherwise, we revert to MirroredStrategy behavior and update each
        # variable directly.
        def _update_replica(replica_id, component):
            scope_name = "update_%d" % replica_id
            with ops.device(component.device):
                with distribute_lib.UpdateContext(replica_id):
                    with ops.name_scope(scope_name):
                        # If args and kwargs are not mirrored, the value is
                        # returned as is.
                        return fn(
                            component,
                            *values.select_replica_mirrored(replica_id, args),
                            **values.select_replica_mirrored(replica_id, kwargs))

        updates = [_update_replica(replica_id, component)
                   for replica_id, component in enumerate(var.values)]
        return values.update_regroup(self, updates, group)