def _real_mirrored_creator(devices, *args, **kwargs):
  """Creates one component variable per device via `next_creator`.

  `next_creator` comes from the enclosing scope. The first device's variable
  is built from `kwargs` unchanged. Each later device gets a variable named
  `<first-name>/replica_<i>/` (the trailing "/" makes the name absolute, so
  the active name scope is ignored) whose initial value is an identity copy
  of the first device's variable.

  Args:
    devices: Devices to place the component variables on, in order.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`. The local dict is
      updated between iterations ("name" and "initial_value" are overwritten
      for every device after the first).

  Returns:
    A list holding one created variable per entry of `devices`.
  """
  created = []
  for idx, device in enumerate(devices):
    with ops.device(device):
      if idx:
        base_name = created[0].name.partition(":")[0]
        kwargs["name"] = "%s/replica_%d/" % (base_name, idx)

        def initial_value_fn(device=device):
          # Look up the first replica lazily, when the creator asks for it.
          primary = created[0]
          if context.executing_eagerly() or ops.inside_function():
            return array_ops.identity(primary.value())
          with ops.device(device):
            return array_ops.identity(primary.initial_value)

        kwargs["initial_value"] = initial_value_fn
      with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
        var = next_creator(*args, **kwargs)
      assert not isinstance(var, values.TPUMirroredVariable)
      created.append(var)
  return created
def CreateInstances(cls, *args, **kwargs):
  """Builds an `EmbeddingVariable`, mirrored per device under a strategy.

  When `has_strategy()` is false, a single `EmbeddingVariable` with
  `local_replica_id=0` is returned. Otherwise one `EmbeddingVariable` is
  created per device of the current strategy (device `i` gets
  `local_replica_id=i`). Devices after the first are named
  `<first-name>/replica_<i>/`. The per-device variables are wrapped in a
  `DistributedVariable` with `ONLY_FIRST_REPLICA` aggregation.

  Args:
    *args: Positional arguments forwarded to every `EmbeddingVariable`.
    **kwargs: Keyword arguments forwarded to every `EmbeddingVariable`;
      "name" is overwritten for every device after the first.

  Returns:
    An `EmbeddingVariable`, or a `DistributedVariable` wrapping one per device.
  """
  if has_strategy():
    strategy = get_strategy()
    # pylint: disable=protected-access
    replica_devices = strategy.extended._devices
    # pylint: enable=protected-access
    replicas = []
    for replica_id, device in enumerate(replica_devices):
      with ops.device(device):
        if replica_id > 0:
          first_name = replicas[0].name.partition(":")[0]
          kwargs["name"] = "%s/replica_%d/" % (first_name, replica_id)
        with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
          # Keep ops issued while building the variable off the gradient tape.
          with tape.stop_recording():
            replica_var = EmbeddingVariable(
                *args, local_replica_id=replica_id, **kwargs)
        replicas.append(replica_var)
    # TODO: check whether it will impact the performance due to the
    # aggregation or synchronization setting.
    # NOTE(review): `var_policy` receives a `VariableSynchronization` enum
    # member; TensorFlow's `DistributedVariable` appears to expect a
    # `VariablePolicy` instance (or None) here — confirm for the TF version
    # in use.
    return DistributedVariable(
        strategy=strategy,
        values=replicas,
        aggregation=VariableAggregation.ONLY_FIRST_REPLICA,
        var_policy=VariableSynchronization.NONE)
  return EmbeddingVariable(*args, local_replica_id=0, **kwargs)
def _real_mirrored_creator(devices, *args, **kwargs):
  """Creates one component variable per device via `next_creator`.

  `next_creator` comes from the enclosing scope. After the first variable is
  created, its value is captured once inside `ops.init_scope()`: `v.value()`
  when executing eagerly outside functions, otherwise `v.initial_value`.
  Every later device's variable is named `<first-name>/replica_<i>/` and is
  initialized with an identity of that captured value.

  Args:
    devices: Devices to place the component variables on, in order.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`; "name" and
      "initial_value" are overwritten for every device after the first.

  Returns:
    A list holding one created variable per entry of `devices`.
  """
  shared_init = None
  created = []
  for idx, device in enumerate(devices):
    with ops.device(device):
      if idx > 0:
        kwargs["name"] = "%s/replica_%d/" % (
            created[0].name.partition(":")[0], idx)

        def initial_value_fn():
          # Reads `shared_init` at call time; it is set after replica 0.
          return array_ops.identity(shared_init)

        kwargs["initial_value"] = initial_value_fn
      with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
        var = next_creator(*args, **kwargs)
      if idx == 0:
        # To avoid incorrectly nested device scopes, we exit out of
        # existing control flow scopes and function building graphs.
        # TODO(b/132997073): Remove initialization scope once nested
        # device scope issue has been fixed.
        with ops.init_scope():
          if ops.executing_eagerly_outside_functions():
            shared_init = var.value()
          else:
            shared_init = var.initial_value
      assert not isinstance(var, values.TPUMirroredVariable)
      created.append(var)
  return created
def _real_mirrored_creator(**kwargs):
  """Creates one component variable per device in `devices`.

  `devices`, `next_creator` and `maybe_init_scope` come from the enclosing
  scope. The caller's "initial_value" is resolved exactly once, on the first
  device and inside `maybe_init_scope()`; if it is callable it is called.
  That resolved value is then passed as "initial_value" for every device.
  Devices after the first get the name `<first-name>/replica_<i>/`.

  Args:
    **kwargs: Keyword arguments forwarded to `next_creator`; must contain
      "initial_value". "initial_value" is replaced by the resolved value and
      "name" is overwritten for every device after the first.

  Returns:
    A list holding one created variable per device.
  """
  resolved_init = None
  created = []
  for idx, device in enumerate(devices):
    with ops.device(device):
      if idx == 0:
        candidate = kwargs["initial_value"]
        # Note: some v1 code expects variable initializer creation to happen
        # inside a init_scope.
        with maybe_init_scope():
          resolved_init = candidate() if callable(candidate) else candidate
      else:
        # A trailing "/" makes the name absolute, bypassing the name scope.
        first_name = created[0].name.partition(":")[0]
        kwargs["name"] = "%s/replica_%d/" % (first_name, idx)
      kwargs["initial_value"] = resolved_init
      with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
        var = next_creator(**kwargs)
      assert not isinstance(var, values.TPUMirroredVariable)
      created.append(var)
  return created
def run(self):
  """Runs `self.main_fn` once for this replica, inside the replica's scopes.

  Waits on `self.should_run` (presumably a `threading.Event`) and clears it.
  If the coordinator has not asked to stop, the method then does three
  things in order:

  * Restores thread-local state through the `restore_thread_local_*`
    helpers. When both caching-scope counters are set, it also copies them
    into `distribute_utils.caching_scope_local`.
  * Enters the graph, device, name and variable scopes for replica
    `self.replica_id`.
  * Calls `self.main_fn(*self.main_args, **self.main_kwargs)`, stores the
    result in `self.main_result` and sets `self.done = True`.

  `self.has_paused.set()` is in the `finally` clause, so it runs on every
  exit path: normal completion, the early return, or an exception.
  """
  self.should_run.wait()
  self.should_run.clear()
  try:
    if self.coord.should_stop():
      return
    self.restore_thread_local_summary_state()
    self.restore_thread_local_callable()
    self.restore_thread_local_eager_context_state()
    # Only restore the caching-scope counters when both were recorded.
    if (self.caching_scope_entered is not None and
        self.caching_scope_exited is not None):
      distribute_utils.caching_scope_local.new_cache_scope_count = self.caching_scope_entered
      distribute_utils.caching_scope_local.cache_scope_exited_count = self.caching_scope_exited
    # The context managers below are entered left to right:
    #   * coordinator exception handling;
    #   * the init graph, then this replica's graph (passing the
    #     `_variable_creator_stack`);
    #   * the device policy and the replica context;
    #   * this replica's device `self.devices[self.replica_id]`;
    #   * the name scope and the variable scope (reuse is enabled for every
    #     replica except replica 0);
    #   * finally the variable creator scope.
    # NOTE(review): `stop_on_exception` presumably reports exceptions raised
    # in the body to the coordinator — confirm against the Coordinator API.
    # NOTE(review): the TODO below looks stale here, since the device is
    # taken from `self.devices` directly.
    # TODO(josh11b): Use current logical device instead of 0 here.
    with self.coord.stop_on_exception(), \
        _enter_graph(self._init_graph, self._init_in_eager), \
        _enter_graph(self.graph, self.in_eager,
                     self._variable_creator_stack), \
        context.device_policy(self.context_device_policy), \
        _MirroredReplicaContext(self.distribution,
                                self.replica_id_in_sync_group), \
        ops.device(self.devices[self.replica_id]), \
        ops.name_scope(self._name_scope), \
        variable_scope.variable_scope(
            self._var_scope, reuse=self.replica_id > 0), \
        variable_scope.variable_creator_scope(self.variable_creator_fn):
      self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
      self.done = True
  finally:
    self.has_paused.set()
def _real_mirrored_creator(**kwargs):
  """Creates one component variable per device in `devices`.

  `devices`, `next_creator` and `self` come from the enclosing scope. For
  each device, "initial_value" is replaced by
  `self._get_variable_creator_initial_value(...)`. That call receives the
  replica index, the device, the first created variable (None for the first
  device) and the kwargs as they stand at that point. Devices after the first
  are then renamed to `<first-name>/replica_<i>/`.

  Args:
    **kwargs: Keyword arguments forwarded to `next_creator`; updated in
      place between iterations.

  Returns:
    A list holding one created variable per device.
  """
  created = []
  for idx, device in enumerate(devices):
    with ops.device(device):
      primary = created[0] if created else None
      # Must happen before "name" is rewritten below: the helper sees the
      # kwargs left over from the previous iteration.
      kwargs["initial_value"] = self._get_variable_creator_initial_value(
          replica_id=idx, device=device, primary_var=primary, **kwargs)
      if idx > 0:
        # A trailing "/" makes the name absolute, bypassing the name scope.
        kwargs["name"] = "%s/replica_%d/" % (
            primary.name.partition(":")[0], idx)
      with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
        # Keep operations issued during creation (e.g. reads of the primary
        # variable) off the gradient tape.
        with tape.stop_recording():
          var = next_creator(**kwargs)
      assert not isinstance(var, values.DistributedVariable)
      created.append(var)
  return created
def run(self):
  """Runs `self.main_fn` once for this replica, inside the replica's scopes.

  Waits on `self.should_run` (presumably a `threading.Event`) and clears it.
  If the coordinator has not asked to stop, the method then does three
  things in order:

  * Restores thread-local context fields.
  * Enters the graph, device, name and variable scopes for replica
    `self.replica_id`.
  * Calls `self.main_fn(*self.main_args, **self.main_kwargs)`, stores the
    result in `self.main_result` and sets `self.done = True`.

  `self.has_paused.set()` is in the `finally` clause, so it runs on every
  exit path: normal completion, the early return, or an exception.
  """
  self.should_run.wait()
  self.should_run.clear()
  try:
    if self.coord.should_stop():
      return
    self.restore_thread_local_context_fields()
    # The context managers below are entered left to right:
    #   * coordinator exception handling;
    #   * the init graph, then this replica's graph (passing the
    #     `_variable_creator_stack`);
    #   * the device policy;
    #   * a replica context whose replica id is an int32 constant;
    #   * the device at index `self.replica_id` among the actual devices of
    #     logical device 0;
    #   * the name scope and the variable scope (reuse is enabled for every
    #     replica except replica 0);
    #   * finally the variable creator scope.
    # NOTE(review): `stop_on_exception` presumably reports exceptions raised
    # in the body to the coordinator — confirm against the Coordinator API.
    # TODO(josh11b): Use current logical device instead of 0 here.
    with self.coord.stop_on_exception(), \
        _enter_graph(self._init_graph, self._init_in_eager), \
        _enter_graph(self.graph, self.in_eager,
                     self._variable_creator_stack), \
        context.device_policy(self.context_device_policy), \
        MirroredReplicaContext(self.distribution, constant_op.constant(
            self.replica_id, dtypes.int32)), \
        ops.device(self.device_map.logical_to_actual_devices(0)[
            self.replica_id]), \
        ops.name_scope(self._name_scope), \
        variable_scope.variable_scope(
            self._var_scope, reuse=self.replica_id > 0), \
        variable_scope.variable_creator_scope(self.variable_creator_fn):
      self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
      self.done = True
  finally:
    self.has_paused.set()
def run(self):
  """Runs `self.main_fn` once for this replica, inside the replica's scopes.

  Waits on `self.should_run` (presumably a `threading.Event`) and clears it.
  If the coordinator has not asked to stop, the method then does three
  things in order:

  * Restores thread-local context fields.
  * Enters the graph, device, name and variable scopes for replica
    `self.replica_id`.
  * Calls `self.main_fn(*self.main_args, **self.main_kwargs)`, stores the
    result in `self.main_result` and sets `self.done = True`.

  `self.has_paused.set()` is in the `finally` clause, so it runs on every
  exit path: normal completion, the early return, or an exception.
  """
  self.should_run.wait()
  self.should_run.clear()
  try:
    if self.coord.should_stop():
      return
    self.restore_thread_local_context_fields()
    # The context managers below are entered left to right:
    #   * coordinator exception handling;
    #   * the init graph, then this replica's graph (passing the
    #     `_variable_creator_stack`);
    #   * the device policy;
    #   * a replica context whose replica id is an int32 constant;
    #   * the device at index `self.replica_id` among the actual devices of
    #     logical device 0;
    #   * the name scope and the variable scope (reuse is enabled for every
    #     replica except replica 0);
    #   * finally the variable creator scope.
    # NOTE(review): `stop_on_exception` presumably reports exceptions raised
    # in the body to the coordinator — confirm against the Coordinator API.
    # TODO(josh11b): Use current logical device instead of 0 here.
    with self.coord.stop_on_exception(), \
        _enter_graph(self._init_graph, self._init_in_eager), \
        _enter_graph(self.graph, self.in_eager,
                     self._variable_creator_stack), \
        context.device_policy(self.context_device_policy), \
        MirroredReplicaContext(self.distribution, constant_op.constant(
            self.replica_id, dtypes.int32)), \
        ops.device(self.device_map.logical_to_actual_devices(0)[
            self.replica_id]), \
        ops.name_scope(self._name_scope), \
        variable_scope.variable_scope(
            self._var_scope, reuse=self.replica_id > 0), \
        variable_scope.variable_creator_scope(self.variable_creator_fn):
      self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
      self.done = True
  finally:
    self.has_paused.set()
def _real_mirrored_creator(devices, *args, **kwargs):
  """Creates one component variable per device via `next_creator`.

  `next_creator` comes from the enclosing scope. Each variable is created
  under `ops.init_scope()` plus its device scope, and with gradient
  recording stopped. Devices after the first get the name
  `<first-name>/replica_<i>/` and an initial value copied (via identity)
  from the first device's variable. The copy uses `.value()` when executing
  eagerly and `.initial_value` otherwise.

  Args:
    devices: Devices to place the component variables on, in order.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`; "name" and
      "initial_value" are overwritten for every device after the first.

  Returns:
    A list holding one created variable per entry of `devices`.
  """
  created = []
  for idx, device in enumerate(devices):
    with ops.init_scope(), ops.device(device):
      if idx > 0:
        # A trailing "/" makes the name absolute, bypassing the name scope.
        base = created[0].name.partition(":")[0]
        kwargs["name"] = "%s/replica_%d/" % (base, idx)

        def initial_value_fn(device=device):
          if context.executing_eagerly():
            return array_ops.identity(created[0].value())
          with ops.device(device):
            return array_ops.identity(created[0].initial_value)

        kwargs["initial_value"] = initial_value_fn
      with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
        # Keep operations issued during creation (e.g. reads of the first
        # variable) off the gradient tape.
        with tape.stop_recording():
          var = next_creator(*args, **kwargs)
      assert not isinstance(var, values.DistributedVariable)
      created.append(var)
  return created
def _real_mirrored_creator(devices, *args, **kwargs):
  """Creates one component variable per device via `next_creator`.

  `next_creator` comes from the enclosing scope. The caller's
  "initial_value" is resolved exactly once, on the first device and inside
  `ops.init_scope()`; if it is callable it is called. That resolved value is
  then passed as "initial_value" for every device. Devices after the first
  get the name `<first-name>/replica_<i>/`.

  Args:
    devices: Devices to place the component variables on, in order.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`; must contain
      "initial_value". "initial_value" is replaced by the resolved value and
      "name" is overwritten for every device after the first.

  Returns:
    A list holding one created variable per entry of `devices`.
  """
  resolved_init = None
  created = []
  for idx, device in enumerate(devices):
    with ops.device(device):
      if idx == 0:
        candidate = kwargs["initial_value"]
        # TODO(b/134779280): Remove initialization scope once the
        # "Tensor-typed variable initializers must either be wrapped in an "
        # "init_scope or callable" error is fixed.
        with ops.init_scope():
          resolved_init = candidate() if callable(candidate) else candidate
      else:
        # A trailing "/" makes the name absolute, bypassing the name scope.
        first_name = created[0].name.partition(":")[0]
        kwargs["name"] = "%s/replica_%d/" % (first_name, idx)
      kwargs["initial_value"] = resolved_init
      with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
        var = next_creator(*args, **kwargs)
      assert not isinstance(var, values.TPUMirroredVariable)
      created.append(var)
  return created
def testCopyScope(self):
  """Adds a constant built outside a `gpu:0` scope to a float inside it.

  Uses the SILENT device placement policy and expects 2.0. The test is
  skipped when no GPU is visible.
  """
  if not context.context().num_gpus():
    self.skipTest('No GPUs found')
  one = constant_op.constant(1.0)
  with ops.device('gpu:0'), context.device_policy(
      context.DEVICE_PLACEMENT_SILENT):
    result = one + 1.0
  self.assertAllEqual(result, 2.0)
def testCopyScope(self):
  """Adds a constant built outside a `gpu:0` scope to a float inside it.

  Uses the SILENT device placement policy and expects 2.0. The test is
  skipped when no GPU is visible.
  """
  if not context.context().num_gpus():
    self.skipTest('No GPUs found')
  one = constant_op.constant(1.0)
  with ops.device('gpu:0'), context.device_policy(
      context.DEVICE_PLACEMENT_SILENT):
    result = one + 1.0
  self.assertAllEqual(result, 2.0)
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
  """Reduces the components of `per_replica_value` on `reduce_to_device`.

  The components are combined with
  `cross_device_utils.aggregate_tensors_or_indexed_slices` using
  `accumulation_fn`. For `ReduceOp.MEAN` the result is then divided by the
  number of components.

  Args:
    per_replica_value: Object whose `.values` holds the per-replica
      components to reduce.
    reduce_to_device: Device on which the reduction ops are placed.
    accumulation_fn: Function passed through to
      `cross_device_utils.aggregate_tensors_or_indexed_slices`.
    reduce_op: `reduce_util.ReduceOp.SUM` or `reduce_util.ReduceOp.MEAN`.

  Returns:
    The reduced value.

  Raises:
    ValueError: If `per_replica_value` has no components, or `reduce_op` is
      neither SUM nor MEAN.
  """
  all_values = per_replica_value.values
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")
  # Reject unsupported reductions before building aggregation ops that
  # would only be thrown away.
  if reduce_op not in (reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN):
    raise ValueError("`reduce_op` must be Reduce.SUM or Reduce.MEAN.")
  count = len(all_values)

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
          all_values, accumulation_fn)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, count)
  return reduced
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
  """Reduces the components of `per_replica_value` on `reduce_to_device`.

  The components are combined with
  `cross_device_utils.aggregate_tensors_or_indexed_slices` using
  `accumulation_fn`. For `ReduceOp.MEAN` the result is then divided by the
  number of components.

  Args:
    per_replica_value: Object whose `.values` holds the per-replica
      components to reduce.
    reduce_to_device: Device on which the reduction ops are placed.
    accumulation_fn: Function passed through to
      `cross_device_utils.aggregate_tensors_or_indexed_slices`.
    reduce_op: `reduce_util.ReduceOp.SUM` or `reduce_util.ReduceOp.MEAN`.

  Returns:
    The reduced value.

  Raises:
    ValueError: If `per_replica_value` has no components, or `reduce_op` is
      neither SUM nor MEAN.
  """
  all_values = per_replica_value.values
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")
  # Reject unsupported reductions before building aggregation ops that
  # would only be thrown away.
  if reduce_op not in (reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN):
    raise ValueError("`reduce_op` must be Reduce.SUM or Reduce.MEAN.")
  count = len(all_values)

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
          all_values, accumulation_fn)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, count)
  return reduced
def _real_mirrored_creator(devices, *args, **kwargs):
  """Creates the per-device component variables of one MirroredVariable.

  Runs on the current worker and uses `next_creator` and `self` from the
  enclosing scope.

  The variable on the first device gets an initial value that, when
  `self._num_workers > 1`, is broadcast with collective ops: the chief
  worker sends its value and every other worker receives it.

  Variables on later devices are named `<first-name>/replica_<i>/` and are
  initialized from the first device's variable.

  Args:
    devices: This worker's devices, in order; only `devices[0]` takes part in
      the cross-worker broadcast.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`. Must contain
      "name" and "initial_value" (a value or a zero-argument callable).

  Returns:
    A list holding one created (non-distributed) variable per device.

  Raises:
    ValueError: If "initial_value" is missing from `kwargs`.
  """
  # Predict the first variable's name without reserving it, so the collective
  # instance key can be derived before the variable exists. The prediction is
  # checked against the real name after creation below.
  unique_var_name = ops.get_default_graph().unique_name(
      kwargs["name"], mark_as_used=False).rstrip("/")
  # pylint: disable=protected-access
  collective_instance_key = self._collective_keys.get_instance_key(
      key_id=unique_var_name)
  # Only the first device participates in the broadcast of initial values.
  group_key = self._collective_keys.get_group_key([devices[0]])
  group_size = self._num_workers
  if "initial_value" not in kwargs:
    raise ValueError("Initial value must be specified.")
  initial_value = kwargs["initial_value"]
  # Normalize the initial value to a zero-argument callable.
  if callable(initial_value):
    initial_value_fn = initial_value
  else:
    initial_value_fn = lambda: initial_value

  value_list = []
  for i, d in enumerate(devices):
    with ops.init_scope(), ops.device(d):
      if i == 0:
        # This initial value fn makes every worker's variable start from the
        # same value. The first device of the chief worker sends its value,
        # and the first device of every other worker receives it.
        def _overridden_initial_value_fn(device=d, index=i):  # pylint: disable=g-missing-docstring
          with ops.device(device):
            initial_value = initial_value_fn()
            assert not callable(initial_value)
            initial_value = ops.convert_to_tensor(initial_value)

            assert index == 0, index
            if self._num_workers > 1:
              if self._is_chief:
                bcast_send = collective_ops.broadcast_send(
                    initial_value, initial_value.shape, initial_value.dtype,
                    group_size, group_key, collective_instance_key)
                # The returned identity depends on the send op.
                with ops.control_dependencies([bcast_send]):
                  return array_ops.identity(initial_value)
              else:
                # Non-chief workers use their local value only for shape and
                # dtype, and return the value received from the chief.
                return collective_ops.broadcast_recv(
                    initial_value.shape, initial_value.dtype, group_size,
                    group_key, collective_instance_key)
            return initial_value
      else:
        # Give replicas meaningful distinct names:
        var0name = value_list[0].name.split(":")[0]
        # We append a / to variable names created on replicas with id > 0 to
        # ensure that we ignore the name scope and instead use the given
        # name as the absolute name of the variable.
        kwargs["name"] = "%s/replica_%d/" % (var0name, i)

        # Variables on non-first replicas get their initial values from the
        # variable created on this worker's first device.
        def _overridden_initial_value_fn(device=d, index=i):
          assert index > 0
          with ops.device(device):
            if context.executing_eagerly():
              return array_ops.identity(value_list[0].value())
            else:
              return array_ops.identity(value_list[0].initial_value)

      kwargs["initial_value"] = _overridden_initial_value_fn
      with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
        # Don't record operations (e.g. other variable reads) during
        # variable creation.
        with tape.stop_recording():
          v = next_creator(*args, **kwargs)

      if i == 0:
        # The collective instance key was derived from the predicted name;
        # check that the variable actually received that name.
        actual_var_name = v.name.split(":")[0]
        assert unique_var_name == actual_var_name, "%r vs %r" % (
            unique_var_name, actual_var_name)
      assert not isinstance(v, values.DistributedVariable)
      value_list.append(v)
  return value_list
def testCopyScope(self):
  """Adds a constant built outside a `gpu:0` scope to a float inside it.

  Uses the SILENT device placement policy and expects 2.0. The test is
  skipped when no GPU is visible. Without a GPU, the `gpu:0` scope cannot
  exercise the cross-device placement the test is about: the op either
  fails or silently runs elsewhere.
  """
  if not context.context().num_gpus():
    self.skipTest('No GPUs found')
  constant = constant_op.constant(1.0)
  with ops.device('gpu:0'):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      c = constant + 1.0
  self.assertAllEqual(c, 2.0)
def _real_mirrored_creator(devices, *args, **kwargs):
  """Creates the per-device component variables of one MirroredVariable.

  Runs on the current worker and uses `next_creator` and `self` from the
  enclosing scope.

  The variable on the first device gets an initial value that, when
  `self._num_workers > 1`, is broadcast with collective ops: the chief
  worker sends its value and every other worker receives it.

  Variables on later devices are named `<first-name>/replica_<i>/` and are
  initialized from the first device's variable.

  Args:
    devices: This worker's devices, in order; only `devices[0]` takes part in
      the cross-worker broadcast.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`. Must contain
      "name" and "initial_value" (a value or a zero-argument callable).

  Returns:
    A list holding one created (non-distributed) variable per device.

  Raises:
    ValueError: If "initial_value" is missing from `kwargs`.
  """
  # Predict the first variable's name without reserving it, so the collective
  # instance key can be derived before the variable exists. The prediction is
  # checked against the real name after creation below.
  unique_var_name = ops.get_default_graph().unique_name(
      kwargs["name"], mark_as_used=False).rstrip("/")
  # pylint: disable=protected-access
  collective_instance_key = self._collective_keys.get_instance_key(
      key_id=unique_var_name)
  # Only the first device participates in the broadcast of initial values.
  group_key = self._collective_keys.get_group_key([devices[0]])
  group_size = self._num_workers
  if "initial_value" not in kwargs:
    raise ValueError("Initial value must be specified.")
  initial_value = kwargs["initial_value"]
  # Normalize the initial value to a zero-argument callable.
  if callable(initial_value):
    initial_value_fn = initial_value
  else:
    initial_value_fn = lambda: initial_value

  value_list = []
  for i, d in enumerate(devices):
    with ops.init_scope(), ops.device(d):
      if i == 0:
        # This initial value fn makes every worker's variable start from the
        # same value. The first device of the chief worker sends its value,
        # and the first device of every other worker receives it.
        def _overridden_initial_value_fn(device=d, index=i):  # pylint: disable=g-missing-docstring
          with ops.device(device):
            initial_value = initial_value_fn()
            assert not callable(initial_value)
            initial_value = ops.convert_to_tensor(initial_value)

            assert index == 0, index
            if self._num_workers > 1:
              if self._is_chief:
                bcast_send = collective_ops.broadcast_send(
                    initial_value, initial_value.shape, initial_value.dtype,
                    group_size, group_key, collective_instance_key)
                # The returned identity depends on the send op.
                with ops.control_dependencies([bcast_send]):
                  return array_ops.identity(initial_value)
              else:
                # Non-chief workers use their local value only for shape and
                # dtype, and return the value received from the chief.
                return collective_ops.broadcast_recv(
                    initial_value.shape, initial_value.dtype, group_size,
                    group_key, collective_instance_key)
            return initial_value
      else:
        # Give replicas meaningful distinct names:
        var0name = value_list[0].name.split(":")[0]
        # We append a / to variable names created on replicas with id > 0 to
        # ensure that we ignore the name scope and instead use the given
        # name as the absolute name of the variable.
        kwargs["name"] = "%s/replica_%d/" % (var0name, i)

        # Variables on non-first replicas get their initial values from the
        # variable created on this worker's first device.
        def _overridden_initial_value_fn(device=d, index=i):
          assert index > 0
          with ops.device(device):
            if context.executing_eagerly():
              return array_ops.identity(value_list[0].value())
            else:
              return array_ops.identity(value_list[0].initial_value)

      kwargs["initial_value"] = _overridden_initial_value_fn
      with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
        # Don't record operations (e.g. other variable reads) during
        # variable creation.
        with tape.stop_recording():
          v = next_creator(*args, **kwargs)

      if i == 0:
        # The collective instance key was derived from the predicted name;
        # check that the variable actually received that name.
        actual_var_name = v.name.split(":")[0]
        assert unique_var_name == actual_var_name, "%r vs %r" % (
            unique_var_name, actual_var_name)
      assert not isinstance(v, values.DistributedVariable)
      value_list.append(v)
  return value_list
def testCopyScope(self):
  """Adds a constant built outside a `gpu:0` scope to a float inside it.

  Uses the SILENT device placement policy and expects 2.0. The test is
  skipped when no GPU is visible. Without a GPU, the `gpu:0` scope cannot
  exercise the cross-device placement the test is about: the op either
  fails or silently runs elsewhere.
  """
  if not context.context().num_gpus():
    self.skipTest('No GPUs found')
  constant = constant_op.constant(1.0)
  with ops.device('gpu:0'):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      c = constant + 1.0
  self.assertAllEqual(c, 2.0)