def _create_variable(self, next_creator, *args, **kwargs):
  """Creates a variable for this IPU strategy.

  The optional `colocate_with` keyword is popped from `kwargs`:
    * a `numpy_dataset.SingleDevice` target short-circuits to `next_creator`
      under that device;
    * `None` places the variable on `self._variable_device` (logical device 0);
    * any other target reuses its `device_map` / `logical_device`.
  Otherwise creation is delegated to `values.create_mirrored_variable` with a
  creator that builds exactly one underlying variable.

  Args:
    next_creator: The next variable creator to delegate to.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`.

  Returns:
    The raw `next_creator` result when colocating with a `SingleDevice`,
    otherwise the result of `values.create_mirrored_variable` built with
    `IPUMirroredVariable` / `IPUSyncOnReadVariable`.
  """
  colocation_target = kwargs.pop("colocate_with", None)

  # Colocating with a single (numpy dataset) device bypasses mirroring.
  if isinstance(colocation_target, numpy_dataset.SingleDevice):
    with ops.device(colocation_target.device):
      return next_creator(*args, **kwargs)

  if colocation_target is None:
    var_device_map = values.ReplicaDeviceMap([self._variable_device])
    logical_device_index = 0
  else:
    var_device_map = colocation_target.device_map
    logical_device_index = colocation_target.logical_device

  def _single_device_creator(devices, *creator_args, **creator_kwargs):
    """Builds the one underlying variable and returns it wrapped in a list."""
    # Exactly one device is expected: this strategy's variable device.
    assert len(devices) == 1
    assert devices[0] == self._variable_device

    # The initial value is always produced on the host; the chief worker
    # initializes it and broadcasts it to the other workers.
    creator_kwargs["initial_value"] = self._get_variable_creator_initial_value(
        replica_id=0,  # Each worker holds a single (first) replica.
        device=self._host_device,
        primary_var=None,
        **creator_kwargs)

    sync_mode = creator_kwargs.get("synchronization")
    keep_on_ipu = (
        not self._variables_on_host or
        sync_mode == variable_scope.VariableSynchronization.ON_READ)

    if keep_on_ipu:
      # Sync-on-read variables are always placed on the IPU; they are
      # transferred to, and reduced on, the hosts only when read.
      with ops.device(self._ipu_device):
        return [next_creator(*creator_args, **creator_kwargs)]

    # Host-resident variable: cache a snapshot on the IPU so that the XLA
    # cluster consuming it is not moved to the host to be colocated with it.
    creator_kwargs["caching_device"] = self._ipu_device

    # If we are inside an ipu_jit_scope, override it so that the variable's
    # initialization on the host is not compiled with XLA.
    xla_off_attrs = {
        "_XlaCompile": attr_value_pb2.AttrValue(b=False),
        "_XlaScope": attr_value_pb2.AttrValue(s=b''),
    }
    default_graph = ops.get_default_graph()
    with ops.device(self._host_device):
      with default_graph._attr_scope(xla_off_attrs):  # pylint: disable=protected-access
        return [next_creator(*creator_args, **creator_kwargs)]

  # NOTE: for TF1 the corresponding helper is
  # distribute_lib.create_mirrored_variable.
  return values.create_mirrored_variable(
      self._container_strategy(), var_device_map, logical_device_index,
      _single_device_creator, IPUMirroredVariable, IPUSyncOnReadVariable,
      *args, **kwargs)
def _create_variable(self, next_creator, *args, **kwargs):
  """Creates a TPUMirroredVariable. See `DistributionStrategy.scope`.

  If the `tpu_embedding_variable_creator` keyword is truthy it is popped and
  creation is delegated straight to `next_creator`. Otherwise the optional
  `colocate_with` keyword is popped and decides placement:
    * a `numpy_dataset.SingleDevice` target short-circuits to `next_creator`
      under that device;
    * `None` uses `self._device_map` with logical device 0;
    * any other target reuses its `device_map` / `logical_device`.

  Args:
    next_creator: The next variable creator to delegate to.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`.

  Returns:
    Either the raw `next_creator` result (embedding / SingleDevice paths) or
    the result of `values.create_mirrored_variable` built with
    `values.TPUMirroredVariable` / `values.TPUSyncOnReadVariable`.
  """
  # Checked before `colocate_with` is popped, so on this path any
  # `colocate_with` entry is still forwarded to `next_creator`.
  if kwargs.pop("tpu_embedding_variable_creator", False):
    return next_creator(*args, **kwargs)

  colocation_target = kwargs.pop("colocate_with", None)
  if isinstance(colocation_target, numpy_dataset.SingleDevice):
    with ops.device(colocation_target.device):
      return next_creator(*args, **kwargs)

  if colocation_target is None:
    var_device_map = self._device_map
    # TODO(josh11b): Get logical device from scope here.
    logical_device_index = 0
  else:
    var_device_map = colocation_target.device_map
    logical_device_index = colocation_target.logical_device

  def _per_replica_creator(devices, *creator_args, **creator_kwargs):  # pylint: disable=g-missing-docstring
    components = []
    shared_initial_value = None
    for replica_idx, device in enumerate(devices):
      with ops.device(device):
        if replica_idx == 0:
          # Resolve the initial value once, on the first device; every later
          # replica reuses it. Some v1 code expects variable initializer
          # creation to happen inside an init_scope.
          shared_initial_value = creator_kwargs["initial_value"]
          with maybe_init_scope():
            if callable(shared_initial_value):
              shared_initial_value = shared_initial_value()
        else:
          # Give replicas meaningful distinct names. The trailing "/" makes the
          # given name absolute, so the current name scope is ignored.
          primary_name = components[0].name.split(":")[0]
          creator_kwargs["name"] = "%s/replica_%d/" % (primary_name,
                                                       replica_idx)
        creator_kwargs["initial_value"] = shared_initial_value

        with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
          component = next_creator(*creator_args, **creator_kwargs)

        assert not isinstance(component, values.TPUMirroredVariable)
        components.append(component)
    return components

  return values.create_mirrored_variable(
      self._container_strategy(), var_device_map, logical_device_index,
      _per_replica_creator, values.TPUMirroredVariable,
      values.TPUSyncOnReadVariable, *args, **kwargs)
def _create_variable(self, next_creator, *args, **kwargs):
  """Creates a mirrored variable. See `DistributionStrategy.scope`.

  The optional `colocate_with` keyword is popped from `kwargs` and decides
  placement:
    * a `numpy_dataset.SingleDevice` target short-circuits to `next_creator`
      under that device;
    * `None` uses `self._device_map` with logical device 0;
    * any other target reuses its `device_map` / `logical_device`.

  Args:
    next_creator: The next variable creator to delegate to.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`.

  Returns:
    The raw `next_creator` result on the SingleDevice path, otherwise the
    result of `values.create_mirrored_variable` built with
    `values.MirroredVariable` / `values.SyncOnReadVariable`.
  """
  colocation_target = kwargs.pop("colocate_with", None)
  if isinstance(colocation_target, numpy_dataset.SingleDevice):
    with ops.device(colocation_target.device):
      return next_creator(*args, **kwargs)

  if colocation_target is None:
    var_device_map = self._device_map
    # TODO(josh11b): Get logical device from scope here.
    logical_device_index = 0
  else:
    var_device_map = colocation_target.device_map
    logical_device_index = colocation_target.logical_device

  def _per_replica_creator(devices, *creator_args, **creator_kwargs):  # pylint: disable=g-missing-docstring
    components = []
    for replica_idx, device in enumerate(devices):
      with ops.device(device):
        # The first component created acts as the primary variable.
        primary = components[0] if components else None
        creator_kwargs["initial_value"] = (
            self._get_variable_creator_initial_value(
                replica_id=replica_idx,
                device=device,
                primary_var=primary,
                **creator_kwargs))
        if replica_idx:
          # Give replicas meaningful distinct names. The trailing "/" makes the
          # given name absolute, so the current name scope is ignored.
          primary_name = primary.name.split(":")[0]
          creator_kwargs["name"] = "%s/replica_%d/" % (primary_name,
                                                       replica_idx)
        with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
          # Keep operations performed during creation (e.g. reads of other
          # variables) off the gradient tape.
          with tape.stop_recording():
            component = next_creator(*creator_args, **creator_kwargs)
        assert not isinstance(component, values.DistributedVariable)
        components.append(component)
    return components

  return values.create_mirrored_variable(
      self._container_strategy(), var_device_map, logical_device_index,
      _per_replica_creator, values.MirroredVariable, values.SyncOnReadVariable,
      *args, **kwargs)