def _create_variable(self, next_creator, *args, **kwargs):
  """Creates a variable on one device or mirrored across a device map.

  The "colocate_with" keyword argument is removed from `kwargs` and selects the
  placement:
    * a `numpy_dataset.SingleDevice`: the variable is built directly by
      `next_creator` under that device's scope;
    * `None`: this object's own device map and logical device 0 are used;
    * anything else: its `device_map` and `logical_device` attributes are used.

  Args:
    next_creator: Callable invoked as `next_creator(*args, **kwargs)` to build
      each underlying variable.
    *args: Positional arguments forwarded unchanged.
    **kwargs: Keyword arguments forwarded after "colocate_with" is popped. The
      per-worker creator reads "name" and "initial_value" from them.

  Returns:
    The result of `next_creator` in the single-device case, otherwise the
    result of `mirrored_strategy._create_mirrored_variable`.
  """
  colocate_with = kwargs.pop("colocate_with", None)

  # Single-device colocation bypasses mirroring entirely.
  if isinstance(colocate_with, numpy_dataset.SingleDevice):
    with ops.device(colocate_with.device):
      return next_creator(*args, **kwargs)

  if colocate_with is None:
    device_map = self._device_map
    logical_device = 0  # TODO(josh11b): Get logical device from scope here.
  else:
    device_map = colocate_with.device_map
    logical_device = colocate_with.logical_device

  def _real_mirrored_creator(devices, *args, **kwargs):
    """Creates one MirroredVariable on the current worker.

    Args:
      devices: Sequence of devices; one variable is created per entry.
      *args: Positional arguments forwarded to `next_creator`.
      **kwargs: Keyword arguments forwarded to `next_creator`; "name" and
        "initial_value" are overwritten per device.

    Returns:
      A list with the variable created for each entry of `devices`, in order.

    Raises:
      ValueError: If "initial_value" is not present in `kwargs`.
    """
    base_name = ops.get_default_graph().unique_name(
        kwargs["name"], mark_as_used=False).rstrip("/")
    # pylint: disable=protected-access
    instance_key = self._collective_keys.get_instance_key(key_id=base_name)
    # Only the first device participates in the broadcast of initial values.
    group_key = self._collective_keys.get_group_key([devices[0]])
    group_size = self._num_workers

    if "initial_value" not in kwargs:
      raise ValueError("Initial value must be specified.")
    requested_init = kwargs["initial_value"]
    # Normalise to a zero-argument callable so both forms are used alike.
    make_initial_value = (
        requested_init if callable(requested_init) else (lambda: requested_init))

    value_list = []
    for replica_id, device_name in enumerate(devices):
      with ops.init_scope():
        with ops.device(device_name):
          if replica_id == 0:

            def _overridden_initial_value_fn(device=device_name,
                                             index=replica_id):
              """Initial value for the first device, broadcast across workers.

              With more than one worker, the chief sends its value and every
              other worker receives it, so all workers start from the same
              value.
              """
              with ops.device(device):
                tensor = make_initial_value()
                assert not callable(tensor)
                tensor = ops.convert_to_tensor(tensor)

                assert index == 0, index
                if self._num_workers <= 1:
                  return tensor
                if not self._is_chief:
                  return collective_ops.broadcast_recv(
                      tensor.shape, tensor.dtype, group_size, group_key,
                      instance_key)
                bcast_send = collective_ops.broadcast_send(
                    tensor, tensor.shape, tensor.dtype, group_size, group_key,
                    instance_key)
                # Tie the returned tensor to the send op via a control
                # dependency.
                with ops.control_dependencies([bcast_send]):
                  return array_ops.identity(tensor)

          else:
            # Give replicas meaningful distinct names. The trailing "/" makes
            # the given name be used as the absolute variable name, ignoring
            # the current name scope.
            first_var_name = value_list[0].name.split(":")[0]
            kwargs["name"] = "%s/replica_%d/" % (first_var_name, replica_id)

            def _overridden_initial_value_fn(device=device_name,
                                             index=replica_id):
              """Initial value copied from the first device's variable."""
              assert index > 0
              with ops.device(device):
                source = (
                    value_list[0].value() if context.executing_eagerly()
                    else value_list[0].initial_value)
                return array_ops.identity(source)

          kwargs["initial_value"] = _overridden_initial_value_fn
          # Don't record operations (e.g. other variable reads) during
          # variable creation.
          with context.device_policy(
              context.DEVICE_PLACEMENT_SILENT), tape.stop_recording():
            created = next_creator(*args, **kwargs)

          if replica_id == 0:
            actual_name = created.name.split(":")[0]
            assert base_name == actual_name, "%r vs %r" % (
                base_name, actual_name)
          assert not isinstance(created, values.DistributedVariable)
          value_list.append(created)
    return value_list

  # pylint: disable=protected-access
  return mirrored_strategy._create_mirrored_variable(
      self._container_strategy(), device_map, logical_device,
      _real_mirrored_creator, *args, **kwargs)
def _create_variable(self, next_creator, *args, **kwargs):
  """Builds a variable, mirroring it over devices unless pinned to one.

  Args:
    next_creator: Callable that actually constructs a variable; it is called
      as `next_creator(*args, **kwargs)`.
    *args: Positional arguments passed through to `next_creator`.
    **kwargs: Keyword arguments passed through to `next_creator`. The
      "colocate_with" entry is popped first and decides placement: `None`
      uses `self._device_map` with logical device 0, a
      `numpy_dataset.SingleDevice` creates the variable directly on that
      device, and any other value supplies `device_map` and `logical_device`.

  Returns:
    Whatever `next_creator` returns for a `numpy_dataset.SingleDevice`
    colocation target; otherwise whatever
    `mirrored_strategy._create_mirrored_variable` returns.
  """
  colocate_with = kwargs.pop("colocate_with", None)
  if colocate_with is None:
    # TODO(josh11b): Get logical device from scope here.
    device_map, logical_device = self._device_map, 0
  elif isinstance(colocate_with, numpy_dataset.SingleDevice):
    with ops.device(colocate_with.device):
      return next_creator(*args, **kwargs)
  else:
    device_map = colocate_with.device_map
    logical_device = colocate_with.logical_device

  def _real_mirrored_creator(devices, *args, **kwargs):
    """Creates one MirroredVariable on the current worker.

    Raises a ValueError when `kwargs` carries no "initial_value". Returns the
    list of per-device variables, ordered like `devices`.
    """
    graph = ops.get_default_graph()
    expected_name = graph.unique_name(
        kwargs["name"], mark_as_used=False).rstrip("/")
    # pylint: disable=protected-access
    collective_keys = self._collective_keys
    collective_instance_key = collective_keys.get_instance_key(
        key_id=expected_name)
    # Only the first device participates in the broadcast of initial values.
    group_key = collective_keys.get_group_key([devices[0]])
    group_size = self._num_workers

    if "initial_value" not in kwargs:
      raise ValueError("Initial value must be specified.")
    user_initial_value = kwargs["initial_value"]
    if callable(user_initial_value):
      produce_initial_value = user_initial_value
    else:
      produce_initial_value = lambda: user_initial_value

    value_list = []
    for position, target_device in enumerate(devices):
      is_first = position == 0
      with ops.init_scope(), ops.device(target_device):
        if is_first:
          # The initial value fn makes sure all variables are initialized to
          # the same values: the first device of the chief worker sends its
          # value to the other workers.
          def _overridden_initial_value_fn(device=target_device,
                                           index=position):
            """Returns the first device's initial value, synced over workers."""
            with ops.device(device):
              value = produce_initial_value()
              assert not callable(value)
              value = ops.convert_to_tensor(value)

              assert index == 0, index
              multi_worker = self._num_workers > 1
              if multi_worker and self._is_chief:
                bcast_send = collective_ops.broadcast_send(
                    value, value.shape, value.dtype, group_size, group_key,
                    collective_instance_key)
                with ops.control_dependencies([bcast_send]):
                  return array_ops.identity(value)
              if multi_worker:
                return collective_ops.broadcast_recv(
                    value.shape, value.dtype, group_size, group_key,
                    collective_instance_key)
              return value
        else:
          # Give replicas meaningful distinct names:
          leader_name = value_list[0].name.split(":")[0]
          # A trailing "/" is appended to names of variables on replicas with
          # id > 0 so that the name scope is ignored and the given name is
          # used as the absolute name of the variable.
          kwargs["name"] = "%s/replica_%d/" % (leader_name, position)

          # Variables on non-first replicas take their initial values from
          # the variable created on the first device of this worker.
          def _overridden_initial_value_fn(device=target_device,
                                           index=position):
            """Returns a copy of the first device's variable value."""
            assert index > 0
            with ops.device(device):
              if context.executing_eagerly():
                leader_value = value_list[0].value()
              else:
                leader_value = value_list[0].initial_value
              return array_ops.identity(leader_value)

        kwargs["initial_value"] = _overridden_initial_value_fn
        with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
          # Don't record operations (e.g. other variable reads) during
          # variable creation.
          with tape.stop_recording():
            variable = next_creator(*args, **kwargs)

        if is_first:
          observed_name = variable.name.split(":")[0]
          assert expected_name == observed_name, "%r vs %r" % (
              expected_name, observed_name)
        assert not isinstance(variable, values.DistributedVariable)
        value_list.append(variable)
    return value_list

  # pylint: disable=protected-access
  return mirrored_strategy._create_mirrored_variable(
      self._container_strategy(), device_map, logical_device,
      _real_mirrored_creator, *args, **kwargs)