def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
  """Runs `fn` once per entry of `colocate_with` and regroups the results.

  Args:
    colocate_with: A tuple whose elements are used, one at a time, as the
      device spec passed to `ops.device`.
    fn: Callable invoked once per device with that device's slice of `args`
      and `kwargs`.
    args: Positional arguments for `fn`. Each call receives
      `values.select_replica_mirrored(i, args)` for its index `i`.
    kwargs: Keyword arguments for `fn`, narrowed the same way as `args`.
    group: Forwarded unchanged to `values.update_regroup`, which decides how
      the per-device results are combined.

  Returns:
    `values.update_regroup(self, updates, group)`, where `updates` holds one
    `fn` result per device in `colocate_with` order.
  """
  assert isinstance(colocate_with, tuple)
  # TODO(josh11b): In eager mode, use one thread per device.
  updates = []
  for i, device in enumerate(colocate_with):
    name = "update_%d" % i
    # Each call runs under its own device scope, an UpdateContext carrying the
    # replica index, and a distinct name scope ("update_0", "update_1", ...).
    with ops.device(device), distribute_lib.UpdateContext(i), \
        ops.name_scope(name):
      updates.append(
          fn(*values.select_replica_mirrored(i, args),
             **values.select_replica_mirrored(i, kwargs)))
  return values.update_regroup(self, updates, group)
def _update(self, var, fn, args, kwargs, group):
  """Applies `fn` to every component variable of `var`, one device at a time.

  Args:
    var: A `values.DistributedVariable`. Its `.values` are the per-device
      component variables that `fn` is applied to.
    fn: Callable invoked as `fn(component, *args_i, **kwargs_i)` for each
      component.
    args: Positional arguments for `fn`. The i-th call receives
      `values.select_replica_mirrored(i, args)`.
    kwargs: Keyword arguments for `fn`, narrowed the same way as `args`.
    group: Forwarded unchanged to `values.update_regroup`, which decides how
      the per-device results are combined.

  Returns:
    `values.update_regroup(self, updates, group)`, where `updates` holds one
    `fn` result per component variable, in `var.values` order.
  """
  # TODO(josh11b): In eager mode, use one thread per device.
  assert isinstance(var, values.DistributedVariable)
  updates = []
  for i, component in enumerate(var.values):
    name = "update_%d" % i
    # Run on the component's own device, with the replica index recorded in
    # an UpdateContext and a distinct name scope per component.
    with ops.device(component.device), \
        distribute_lib.UpdateContext(i), \
        ops.name_scope(name):
      # If args and kwargs are not mirrored, the value is returned as is.
      updates.append(
          fn(component,
             *values.select_replica_mirrored(i, args),
             **values.select_replica_mirrored(i, kwargs)))
  return values.update_regroup(self, updates, group)
def _update(self, var, fn, args, kwargs, group):
  """Updates `var` with `fn`, either directly in a TPU context or per device.

  If `values._enclosing_tpu_context()` is not None, `fn` is applied once to
  `var` itself. Otherwise this falls back to the per-device loop (mirrored
  behavior): `fn` is applied to each entry of `var.values` on that entry's
  device.

  Args:
    var: A `values.TPUVariableMixin` or a
      `resource_variable_ops.BaseResourceVariable`.
    fn: Callable invoked as `fn(variable, *args, **kwargs)`. In the fallback
      path, the arguments are first narrowed per index with
      `values.select_replica_mirrored`.
    args: Positional arguments for `fn`.
    kwargs: Keyword arguments for `fn`.
    group: In the TPU-context path, a truthy `group` returns the raw `fn`
      result and a falsy one returns it wrapped in a 1-tuple. In the fallback
      path it is forwarded to `values.update_regroup`.

  Returns:
    In a TPU context, `fn`'s result (or a 1-tuple of it when `group` is
    falsy). Otherwise, `values.update_regroup(self, updates, group)` over the
    per-device results.
  """
  assert isinstance(
      var,
      (values.TPUVariableMixin, resource_variable_ops.BaseResourceVariable))
  if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
    result = fn(var, *args, **kwargs)
    return result if group else (result,)

  # Otherwise, we revert to MirroredStrategy behavior and update each variable
  # directly.
  # NOTE(review): a plain BaseResourceVariable passes the assert above but
  # presumably has no `.values` attribute, so reaching this path with one
  # would fail. Confirm whether that combination can occur.
  updates = []
  for i, component in enumerate(var.values):
    name = "update_%d" % i
    with ops.device(component.device), \
        distribute_lib.UpdateContext(i), \
        ops.name_scope(name):
      # If args and kwargs are not mirrored, the value is returned as is.
      updates.append(
          fn(component,
             *values.select_replica_mirrored(i, args),
             **values.select_replica_mirrored(i, kwargs)))
  return values.update_regroup(self, updates, group)