def _create_variable(self, next_creator, *args, **kwargs):
  """Create a TPUMirroredVariable. See `DistributionStrategy.scope`.

  The `tpu_embedding_variable_creator` and `colocate_with` entries are popped
  from `kwargs` before anything is forwarded.

  Args:
    next_creator: Callable invoked as `next_creator(*args, **kwargs)` to build
      each underlying variable.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`. When the
      per-device creator runs, `initial_value` must be present.

  Returns:
    The result of `next_creator` directly when
    `tpu_embedding_variable_creator` is truthy or when `colocate_with` is a
    `numpy_dataset.SingleDevice`; otherwise whatever
    `distribute_lib.create_mirrored_variable` returns.
  """
  # Variables flagged as TPU embedding variables bypass mirroring entirely.
  bypass_mirroring = kwargs.pop("tpu_embedding_variable_creator", False)
  if bypass_mirroring:
    return next_creator(*args, **kwargs)

  colocate_with = kwargs.pop("colocate_with", None)

  # Single-device colocation: build one plain variable on that device.
  if colocate_with is not None and isinstance(
      colocate_with, numpy_dataset.SingleDevice):
    with ops.device(colocate_with.device):
      return next_creator(*args, **kwargs)

  if colocate_with is None:
    device_map = self._device_map
    # TODO(josh11b): Get logical device from scope here.
    logical_device = 0
  else:
    device_map = colocate_with.device_map
    logical_device = colocate_with.logical_device

  def _real_mirrored_creator(devices, *args, **kwargs):
    """Returns a list with one variable created per entry of `devices`."""
    replicas = []
    resolved_init = None
    for replica_id, device in enumerate(devices):
      with ops.device(device):
        if replica_id == 0:
          raw_init = kwargs["initial_value"]
          # Some v1 code expects the variable initializer to be created
          # inside an init_scope, so callables are resolved under one.
          with maybe_init_scope():
            if callable(raw_init):
              resolved_init = raw_init()
            else:
              resolved_init = raw_init
        else:
          # Replicas after the first get distinct names derived from the
          # first variable. The trailing "/" makes the given name be used as
          # the absolute variable name, ignoring the current name scope.
          primary_name = replicas[0].name.split(":")[0]
          kwargs["name"] = "%s/replica_%d/" % (primary_name, replica_id)

        # Every replica is created from the value resolved for replica 0.
        kwargs["initial_value"] = resolved_init

        with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
          created = next_creator(*args, **kwargs)

        assert not isinstance(created, values.TPUMirroredVariable)
        replicas.append(created)
    return replicas

  return distribute_lib.create_mirrored_variable(
      self._container_strategy(),
      device_map,
      logical_device,
      _real_mirrored_creator,
      values.TPUMirroredVariable,
      values.TPUSyncOnReadVariable,
      *args,
      **kwargs)
def _create_variable(self, next_creator, *args, **kwargs):
  """Create a mirrored variable. See `DistributionStrategy.scope`.

  The `colocate_with` entry is popped from `kwargs` before anything is
  forwarded.

  Args:
    next_creator: Callable invoked as `next_creator(*args, **kwargs)` to build
      each underlying variable.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`; they are also
      passed to `self._get_variable_creator_initial_value` for every replica.

  Returns:
    The result of `next_creator` directly when `colocate_with` is a
    `numpy_dataset.SingleDevice`; otherwise whatever
    `distribute_lib.create_mirrored_variable` returns.
  """
  colocate_with = kwargs.pop("colocate_with", None)

  # Single-device colocation: build one plain variable on that device.
  if colocate_with is not None and isinstance(
      colocate_with, numpy_dataset.SingleDevice):
    with ops.device(colocate_with.device):
      return next_creator(*args, **kwargs)

  if colocate_with is None:
    device_map = self._device_map
    # TODO(josh11b): Get logical device from scope here.
    logical_device = 0
  else:
    device_map = colocate_with.device_map
    logical_device = colocate_with.logical_device

  def _real_mirrored_creator(devices, *args, **kwargs):
    """Returns a list with one variable created per entry of `devices`."""
    replicas = []
    for replica_id, device in enumerate(devices):
      with ops.device(device):
        # The first replica has no primary variable yet; later ones receive
        # the first variable that was created.
        primary = replicas[0] if replicas else None
        replica_init = self._get_variable_creator_initial_value(
            replica_id=replica_id,
            device=device,
            primary_var=primary,
            **kwargs)
        kwargs["initial_value"] = replica_init

        if replica_id > 0:
          # Replicas after the first get distinct names derived from the
          # first variable. The trailing "/" makes the given name be used as
          # the absolute variable name, ignoring the current name scope.
          primary_name = replicas[0].name.split(":")[0]
          kwargs["name"] = "%s/replica_%d/" % (primary_name, replica_id)

        silent_placement = context.device_policy(
            context.DEVICE_PLACEMENT_SILENT)
        # Don't record operations (e.g. other variable reads) during
        # variable creation.
        with silent_placement, tape.stop_recording():
          created = next_creator(*args, **kwargs)

        assert not isinstance(created, values.DistributedVariable)
        replicas.append(created)
    return replicas

  return distribute_lib.create_mirrored_variable(
      self._container_strategy(),
      device_map,
      logical_device,
      _real_mirrored_creator,
      values.MirroredVariable,
      values.SyncOnReadVariable,
      *args,
      **kwargs)