Example #1
0
    def _create_variable(self, next_creator, *args, **kwargs):
        """Create a variable for this strategy's single variable device.

        The optional `colocate_with` keyword is consumed here (it is never
        forwarded to `next_creator`) and decides placement:
          * a `numpy_dataset.SingleDevice`: the variable is created directly
            by `next_creator` under that device, without mirroring.
          * `None`: mirror over a one-entry `values.ReplicaDeviceMap` holding
            `self._variable_device`, using logical device 0.
          * anything else: reuse its `device_map` and `logical_device`.

        Args:
          next_creator: The next variable creator to delegate to.
          *args: Positional arguments forwarded to the creator.
          **kwargs: Keyword arguments forwarded to the creator.

        Returns:
          The `next_creator` result in the `SingleDevice` case; otherwise the
          result of `values.create_mirrored_variable`, built with
          `IPUMirroredVariable` / `IPUSyncOnReadVariable`.
        """
        colocate_with = kwargs.pop("colocate_with", None)

        if isinstance(colocate_with, numpy_dataset.SingleDevice):
            with ops.device(colocate_with.device):
                return next_creator(*args, **kwargs)

        if colocate_with is None:
            device_map = values.ReplicaDeviceMap([self._variable_device])
            logical_device = 0
        else:
            device_map = colocate_with.device_map
            logical_device = colocate_with.logical_device

        def _single_replica_creator(devices, *args, **kwargs):
            """Create the single component variable, as a one-element list."""
            # Only one replica, located on `self._variable_device`, is expected.
            assert len(devices) == 1
            assert devices[0] == self._variable_device

            # The chief worker will initialize and broadcast the value to the
            # other workers. The initial value is always produced on the host,
            # as replica 0 (the first and only replica on each worker).
            kwargs["initial_value"] = self._get_variable_creator_initial_value(
                replica_id=0,
                device=self._host_device,
                primary_var=None,
                **kwargs)

            # Sync-on-read variables are always placed on the IPU; they are
            # transferred and reduced on the hosts only when read. All
            # variables go to the IPU when `_variables_on_host` is off.
            sync_mode = kwargs.get("synchronization")
            keep_on_ipu = (
                not self._variables_on_host or
                sync_mode == variable_scope.VariableSynchronization.ON_READ)
            if keep_on_ipu:
                with ops.device(self._ipu_device):
                    return [next_creator(*args, **kwargs)]

            # Host-resident variable: cache a snapshot on the IPU device,
            # otherwise the XLA cluster containing the ops consuming the
            # variable might be moved to the host to be colocated with it.
            kwargs["caching_device"] = self._ipu_device

            # Override any enclosing ipu_jit_scope so that variable
            # initialization on the host does not go through XLA.
            no_xla_attrs = {
                "_XlaCompile": attr_value_pb2.AttrValue(b=False),
                "_XlaScope": attr_value_pb2.AttrValue(s=b''),
            }
            default_graph = ops.get_default_graph()
            with ops.device(self._host_device):
                with default_graph._attr_scope(no_xla_attrs):  # pylint: disable=protected-access
                    return [next_creator(*args, **kwargs)]

        # For tf1: use distribute_lib.create_mirrored_variable
        return values.create_mirrored_variable(
            self._container_strategy(),
            device_map,
            logical_device,
            _single_replica_creator,
            IPUMirroredVariable,
            IPUSyncOnReadVariable,
            *args,
            **kwargs)
Example #2
0
    def _create_variable(self, next_creator, *args, **kwargs):
        """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
        if kwargs.pop("tpu_embedding_variable_creator", False):
            return next_creator(*args, **kwargs)

        colocate_with = kwargs.pop("colocate_with", None)
        if colocate_with is None:
            device_map = self._device_map
            logical_device = 0  # TODO(josh11b): Get logical device from scope here.
        elif isinstance(colocate_with, numpy_dataset.SingleDevice):
            with ops.device(colocate_with.device):
                return next_creator(*args, **kwargs)
        else:
            device_map = colocate_with.device_map
            logical_device = colocate_with.logical_device

        def _real_mirrored_creator(devices, *args, **kwargs):  # pylint: disable=g-missing-docstring
            initial_value = None
            value_list = []
            for i, d in enumerate(devices):
                with ops.device(d):
                    if i == 0:
                        initial_value = kwargs["initial_value"]
                        # Note: some v1 code expects variable initializer creation to happen
                        # inside a init_scope.
                        with maybe_init_scope():
                            initial_value = initial_value() if callable(
                                initial_value) else initial_value

                    if i > 0:
                        # Give replicas meaningful distinct names:
                        var0name = value_list[0].name.split(":")[0]
                        # We append a / to variable names created on replicas with id > 0 to
                        # ensure that we ignore the name scope and instead use the given
                        # name as the absolute name of the variable.
                        kwargs["name"] = "%s/replica_%d/" % (var0name, i)
                    kwargs["initial_value"] = initial_value

                    with context.device_policy(
                            context.DEVICE_PLACEMENT_SILENT):
                        v = next_creator(*args, **kwargs)

                    assert not isinstance(v, values.TPUMirroredVariable)
                    value_list.append(v)
            return value_list

        return values.create_mirrored_variable(self._container_strategy(),
                                               device_map, logical_device,
                                               _real_mirrored_creator,
                                               values.TPUMirroredVariable,
                                               values.TPUSyncOnReadVariable,
                                               *args, **kwargs)
Example #3
0
    def _create_variable(self, next_creator, *args, **kwargs):
        """Create a mirrored variable. See `DistributionStrategy.scope`.

        The optional `colocate_with` keyword is consumed here (it is never
        forwarded to `next_creator`) and decides placement:
          * a `numpy_dataset.SingleDevice`: create the variable directly with
            `next_creator` under that device, without mirroring.
          * `None`: mirror across `self._device_map`, logical device 0.
          * anything else: reuse its `device_map` and `logical_device`.

        Args:
          next_creator: The next variable creator to delegate to.
          *args: Positional arguments forwarded to the creator.
          **kwargs: Keyword arguments forwarded to the creator.

        Returns:
          The `next_creator` result in the `SingleDevice` case; otherwise the
          result of `values.create_mirrored_variable`, built with
          `values.MirroredVariable` / `values.SyncOnReadVariable`.
        """
        colocate_with = kwargs.pop("colocate_with", None)

        if isinstance(colocate_with, numpy_dataset.SingleDevice):
            with ops.device(colocate_with.device):
                return next_creator(*args, **kwargs)

        if colocate_with is None:
            device_map = self._device_map
            logical_device = 0  # TODO(josh11b): Get logical device from scope here.
        else:
            device_map = colocate_with.device_map
            logical_device = colocate_with.logical_device

        def _per_device_creator(devices, *args, **kwargs):
            """Create one component variable per device in `devices`.

            Each replica's initial value comes from
            `self._get_variable_creator_initial_value`. That call also receives
            the first component as `primary_var`, or `None` while the first
            component is being created.
            """
            components = []
            for replica_id, device in enumerate(devices):
                with ops.device(device):
                    primary = components[0] if components else None
                    kwargs["initial_value"] = (
                        self._get_variable_creator_initial_value(
                            replica_id=replica_id,
                            device=device,
                            primary_var=primary,
                            **kwargs))
                    if replica_id > 0:
                        # Give replicas meaningful distinct names. The
                        # trailing / makes the name absolute, so the current
                        # name scope is ignored for replicas with id > 0.
                        primary_name = primary.name.split(":")[0]
                        kwargs["name"] = "%s/replica_%d/" % (primary_name,
                                                            replica_id)
                    # Don't record operations (e.g. other variable reads)
                    # during variable creation.
                    with context.device_policy(
                            context.DEVICE_PLACEMENT_SILENT), \
                            tape.stop_recording():
                        component = next_creator(*args, **kwargs)
                    assert not isinstance(component,
                                          values.DistributedVariable)
                    components.append(component)
            return components

        return values.create_mirrored_variable(
            self._container_strategy(),
            device_map,
            logical_device,
            _per_device_creator,
            values.MirroredVariable,
            values.SyncOnReadVariable,
            *args,
            **kwargs)