예제 #1
0
    def _create_variable(self, next_creator, **kwargs):
        """Build a TPU mirrored variable. See `DistributionStrategy.scope`.

        Args:
          next_creator: Callable that creates one underlying variable from
            keyword arguments.
          **kwargs: Variable creation arguments. `skip_mirrored_creator` is
            always consumed here. `colocate_with` is consumed only when
            mirroring is not skipped; on the skip path it is still forwarded
            to `next_creator`.

        Returns:
          The direct result of `next_creator` when mirroring is skipped or when
          colocating with a `numpy_dataset.SingleDevice`; otherwise the result
          of `distribute_utils.create_mirrored_variable`.
        """
        skip_mirroring = kwargs.pop("skip_mirrored_creator", False)
        if skip_mirroring:
            return next_creator(**kwargs)

        colocate_with = kwargs.pop("colocate_with", None)
        if colocate_with is not None and isinstance(
                colocate_with, numpy_dataset.SingleDevice):
            # Colocating with a single device: create the variable there
            # directly, without mirroring.
            with ops.device(colocate_with.device):
                return next_creator(**kwargs)

        # With no colocation target, index `_tpu_devices` by the entry on top of
        # the logical device stack; otherwise reuse the target's device list.
        target_devices = (
            self._tpu_devices[:, self._logical_device_stack[-1]]
            if colocate_with is None
            else colocate_with._devices)  # pylint: disable=protected-access

        def _real_mirrored_creator(**creator_kwargs):
            """Create one variable per target device and return them in order."""
            replicas = []
            resolved_initial_value = None
            for replica_id, device in enumerate(target_devices):
                with ops.device(device):
                    if replica_id == 0:
                        resolved_initial_value = creator_kwargs["initial_value"]
                        # Some v1 code expects the initializer to be created
                        # inside an init_scope.
                        with maybe_init_scope():
                            if callable(resolved_initial_value):
                                resolved_initial_value = resolved_initial_value()
                    else:
                        # Derive distinct replica names from the first variable.
                        # The trailing "/" makes the given name absolute so the
                        # current name scope is ignored.
                        primary_name = replicas[0].name.split(":")[0]
                        creator_kwargs["name"] = "%s/replica_%d/" % (
                            primary_name, replica_id)
                    # Every replica is created from the value resolved on the
                    # first iteration.
                    creator_kwargs["initial_value"] = resolved_initial_value

                    silent_placement = context.device_policy(
                        context.DEVICE_PLACEMENT_SILENT)
                    with silent_placement:
                        replica_var = next_creator(**creator_kwargs)

                    assert not isinstance(
                        replica_var, tpu_values.TPUMirroredVariable)
                    replicas.append(replica_var)
            return replicas

        return distribute_utils.create_mirrored_variable(
            self._container_strategy(),
            _real_mirrored_creator,
            distribute_utils.TPU_VARIABLE_CLASS_MAPPING,
            distribute_utils.TPU_VARIABLE_POLICY_MAPPING,
            **kwargs)
예제 #2
0
    def _create_variable(self, next_creator, **kwargs):
        """Build a mirrored variable. See `DistributionStrategy.scope`.

        Args:
          next_creator: Callable that creates one underlying variable from
            keyword arguments.
          **kwargs: Variable creation arguments; `colocate_with` is consumed
            here and not forwarded.

        Returns:
          The direct result of `next_creator` when colocating with a
          `numpy_dataset.SingleDevice`; otherwise the result of
          `distribute_utils.create_mirrored_variable`.
        """
        colocate_with = kwargs.pop("colocate_with", None)
        if colocate_with is not None and isinstance(
                colocate_with, numpy_dataset.SingleDevice):
            # Colocating with a single device: create the variable there
            # directly, without mirroring.
            with ops.device(colocate_with.device):
                return next_creator(**kwargs)

        # Use this object's own devices unless a colocation target supplies
        # its device list.
        target_devices = (
            self._devices
            if colocate_with is None
            else colocate_with._devices)  # pylint: disable=protected-access

        def _real_mirrored_creator(**creator_kwargs):
            """Create one variable per target device and return them in order."""
            replicas = []
            for replica_id, device in enumerate(target_devices):
                with ops.device(device):
                    primary = replicas[0] if replicas else None
                    # The initial value is resolved before the name is
                    # rewritten below, so the helper is called with the `name`
                    # left over from the previous iteration.
                    creator_kwargs["initial_value"] = (
                        self._get_variable_creator_initial_value(
                            replica_id=replica_id,
                            device=device,
                            primary_var=primary,
                            **creator_kwargs))
                    if replica_id > 0:
                        # Derive distinct replica names from the first variable.
                        # The trailing "/" makes the given name absolute so the
                        # current name scope is ignored.
                        primary_name = primary.name.split(":")[0]
                        creator_kwargs["name"] = "%s/replica_%d/" % (
                            primary_name, replica_id)
                    # Operations (e.g. other variable reads) must not be
                    # recorded while the variable is being created.
                    with context.device_policy(
                            context.DEVICE_PLACEMENT_SILENT), tape.stop_recording():
                        replica_var = next_creator(**creator_kwargs)
                    assert not isinstance(replica_var, values.DistributedVariable)
                    replicas.append(replica_var)
            return replicas

        return distribute_utils.create_mirrored_variable(
            self._container_strategy(),
            _real_mirrored_creator,
            distribute_utils.VARIABLE_CLASS_MAPPING,
            distribute_utils.VARIABLE_POLICY_MAPPING,
            **kwargs)