def _update(self, var, fn, args, kwargs, group):
        """Applies `fn` to `var`, either directly in a TPU context or per device.

        Args:
          var: variable to update; asserted to be a `TPUVariableMixin` or a
            `BaseResourceVariable`.
          fn: callable invoked as `fn(value, *args, **kwargs)`.
          args: positional arguments; outside a TPU context they are narrowed
            to each replica with `distribute_utils.select_replica_mirrored`.
          kwargs: keyword arguments, narrowed the same way as `args`.
          group: inside a TPU context, a truthy value returns `fn`'s result as
            is and a falsy value wraps it in a one-element tuple; outside a
            TPU context it is forwarded to `distribute_utils.update_regroup`.

        Returns:
          Inside a TPU context, the (possibly tuple-wrapped) result of `fn`;
          otherwise the value produced by `distribute_utils.update_regroup`
          from the per-device results.
        """
        assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
            var, resource_variable_ops.BaseResourceVariable)

        if tpu_values.enclosing_tpu_context() is not None:
            # `group` is checked before `fn` runs, and `fn` runs exactly once.
            return (fn(var, *args, **kwargs) if group
                    else (fn(var, *args, **kwargs),))

        # Not inside a TPU computation: fall back to MirroredStrategy-style
        # behaviour and update the variable on every device separately.
        packed_var = var._packed_variable  # pylint: disable=protected-access
        if packed_var is None:
            # One target per component variable, placed on its own device.
            targets = [(component, component.device) for component in var.values]
        else:
            # The single packed variable is updated once per device it spans.
            targets = [(packed_var, device) for device in packed_var.devices]

        per_device_results = []
        for replica_id, (target, device) in enumerate(targets):
            scope_name = "update_%d" % replica_id
            with ops.device(device):
                with distribute_lib.UpdateContext(replica_id):
                    with ops.name_scope(scope_name):
                        # Mirrored args/kwargs are narrowed to this replica;
                        # non-mirrored values pass through unchanged.
                        replica_args = distribute_utils.select_replica_mirrored(
                            replica_id, args)
                        replica_kwargs = distribute_utils.select_replica_mirrored(
                            replica_id, kwargs)
                        per_device_results.append(
                            fn(target, *replica_args, **replica_kwargs))
        return distribute_utils.update_regroup(self, per_device_results, group)
 def gather_fn():
     """Collectively gathers `inputs` along axis 0 and regroups the result.

     Uses `inputs`, `devices`, `group_size`, `collective_keys` and `strategy`
     from the enclosing function's scope (not visible in this snippet).
     """
     collected = cross_device_utils.build_collective_gather(
         inputs, devices, group_size, collective_keys, axis=0)
     return distribute_utils.update_regroup(
         strategy.extended, collected, group=True)
 def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
   """Runs `fn` once per entry of `colocate_with` and regroups the results.

   Args:
     colocate_with: tuple (asserted); each element is passed to `ops.device`.
     fn: callable invoked as `fn(*args, **kwargs)` with per-replica arguments.
     args: positional arguments narrowed to each replica by
       `distribute_utils.select_replica_mirrored`.
     kwargs: keyword arguments narrowed the same way.
     group: forwarded to `distribute_utils.update_regroup`.

   Returns:
     The value produced by `distribute_utils.update_regroup` from the
     per-device results.
   """
   assert isinstance(colocate_with, tuple)
   # TODO(josh11b): In eager mode, use one thread per device.
   per_device_results = []
   for replica_id, device in enumerate(colocate_with):
     scope_name = "update_%d" % replica_id
     with ops.device(device):
       with distribute_lib.UpdateContext(replica_id):
         with ops.name_scope(scope_name):
           replica_args = distribute_utils.select_replica_mirrored(
               replica_id, args)
           replica_kwargs = distribute_utils.select_replica_mirrored(
               replica_id, kwargs)
           per_device_results.append(fn(*replica_args, **replica_kwargs))
   return distribute_utils.update_regroup(self, per_device_results, group)
 def _update(self, var, fn, args, kwargs, group):
   """Applies `fn` to every component of a distributed variable.

   Args:
     var: a `values.DistributedVariable` (asserted); each entry of
       `var.values` is updated on its own device.
     fn: callable invoked as `fn(component, *args, **kwargs)`.
     args: positional arguments narrowed to each replica by
       `distribute_utils.select_replica_mirrored`.
     kwargs: keyword arguments narrowed the same way.
     group: forwarded to `distribute_utils.update_regroup`.

   Returns:
     The value produced by `distribute_utils.update_regroup` from the
     per-component results.
   """
   # TODO(josh11b): In eager mode, use one thread per device.
   assert isinstance(var, values.DistributedVariable)
   per_component_results = []
   for replica_id, component in enumerate(var.values):
     scope_name = "update_%d" % replica_id
     with ops.device(component.device):
       with distribute_lib.UpdateContext(replica_id):
         with ops.name_scope(scope_name):
           # Mirrored args/kwargs are narrowed to this replica; anything
           # else is passed through unchanged.
           replica_args = distribute_utils.select_replica_mirrored(
               replica_id, args)
           replica_kwargs = distribute_utils.select_replica_mirrored(
               replica_id, kwargs)
           per_component_results.append(
               fn(component, *replica_args, **replica_kwargs))
   return distribute_utils.update_regroup(self, per_component_results, group)
Example #5
0
    def _update(self, var, fn, args, kwargs, group):
        """Applies `fn` to `var`, directly on TPU or component by component.

        Args:
          var: variable to update; asserted to be a `TPUVariableMixin` or a
            `BaseResourceVariable`.
          fn: callable invoked as `fn(value, *args, **kwargs)`.
          args: positional arguments; outside a TPU context they are narrowed
            to each replica with `distribute_utils.select_replica_mirrored`.
          kwargs: keyword arguments, narrowed the same way as `args`.
          group: inside a TPU context, a truthy value returns `fn`'s result as
            is and a falsy value wraps it in a one-element tuple; outside a
            TPU context it is forwarded to `distribute_utils.update_regroup`.

        Returns:
          Inside a TPU context, the (possibly tuple-wrapped) result of `fn`;
          otherwise the value produced by `distribute_utils.update_regroup`
          from the per-component results.
        """
        assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
            var, resource_variable_ops.BaseResourceVariable)

        if tpu_values.enclosing_tpu_context() is not None:
            # `group` is checked before `fn` runs, and `fn` runs exactly once.
            return (fn(var, *args, **kwargs) if group
                    else (fn(var, *args, **kwargs),))

        # Not inside a TPU computation: mimic MirroredStrategy and update each
        # component of `var` on its own device.
        per_component_results = []
        for replica_id, component in enumerate(var.values):
            scope_name = "update_%d" % replica_id
            with ops.device(component.device):
                with distribute_lib.UpdateContext(replica_id):
                    with ops.name_scope(scope_name):
                        # Mirrored args/kwargs are narrowed to this replica;
                        # anything else is passed through unchanged.
                        replica_args = distribute_utils.select_replica_mirrored(
                            replica_id, args)
                        replica_kwargs = distribute_utils.select_replica_mirrored(
                            replica_id, kwargs)
                        per_component_results.append(
                            fn(component, *replica_args, **replica_kwargs))
        return distribute_utils.update_regroup(self, per_component_results, group)