Exemplo n.º 1
0
    def _create_variable(self, next_creator, *args, **kwargs):
        """Creates a mirrored variable whose initial value is broadcast from the chief.

        One component variable is created per device of this worker by calling
        `next_creator`, and the components are handed to
        `mirrored_strategy._create_mirrored_variable`. Each component's
        initializer takes part in a collective broadcast: the first device of
        the chief worker sends its initial value and every other participant
        receives it, so all components start with the same value.

        Args:
          next_creator: the next variable creator in the chain; called once per
            device to build each component variable.
          *args: positional arguments forwarded to `next_creator`.
          **kwargs: keyword arguments forwarded to `next_creator`. An optional
            "colocate_with" entry is popped and, when given, supplies the
            device map and logical device; "name" and "initial_value" are
            read, and rewritten per device, by the inner creator.

        Returns:
          The result of `mirrored_strategy._create_mirrored_variable`
          (presumably the wrapping mirrored variable -- not visible here).

        Raises:
          ValueError: from the inner creator, if "initial_value" is missing.
        """
        colocate_with = kwargs.pop("colocate_with", None)
        if colocate_with is None:
            device_map = self._device_map
            logical_device = 0  # TODO(josh11b): Get logical device from scope here.
        else:
            # Reuse the placement of the object we colocate with (presumably
            # an existing distributed variable -- TODO confirm).
            device_map = colocate_with.device_map
            logical_device = colocate_with.logical_device
        # Broadcast participants: every replica on every worker (assumes each
        # worker hosts the same number of replicas -- TODO confirm).
        group_size = device_map.num_replicas_in_graph * self._num_workers
        group_key = self._collective_keys.get_group_key(self.worker_devices)

        def _real_mirrored_creator(devices, *args, **kwargs):
            """Creates the per-device component variables on the current worker.

            Returns:
              A list with one variable per entry of `devices`, in order.
            """
            value_list = []
            # Predict, without reserving it, the graph-unique name the first
            # component will receive. The collective instance key is derived
            # from it, and the assertion after creation checks the prediction.
            # Any trailing "/" is stripped so it compares equal to v.name below.
            unique_var_name = ops.get_default_graph().unique_name(
                kwargs["name"], mark_as_used=False).rstrip("/")
            # NOTE(review): presumably the same key_id yields the same instance
            # key on every worker, which is what pairs senders with receivers.
            collective_instance_key = self._collective_keys.get_instance_key(
                key_id=unique_var_name)
            if "initial_value" not in kwargs:
                raise ValueError("Initial value must be specified.")
            initial_value = kwargs["initial_value"]
            # Normalize to a zero-argument callable so each component's
            # initializer can evaluate it on its own device.
            if callable(initial_value):
                initial_value_fn = initial_value
            else:
                initial_value_fn = lambda: initial_value

            for i, d in enumerate(devices):
                with ops.device(d):
                    if i > 0:
                        # Give replicas meaningful distinct names:
                        var0name = value_list[0].name.split(":")[0]
                        # We append a / to variable names created on replicas with id > 0 to
                        # ensure that we ignore the name scope and instead use the given
                        # name as the absolute name of the variable.
                        kwargs["name"] = "%s/replica_%d/" % (var0name, i)

                    # The initial value fn makes sure all variables are initialized to
                    # the same values. The first device of the chief worker sends its
                    # initial value to the other devices and other workers.
                    # `device` and `index` are bound as default arguments so each
                    # closure keeps this iteration's values (late-binding pitfall).
                    def _overridden_initial_value_fn(device=d, index=i):  # pylint: disable=g-missing-docstring
                        with ops.device(device):
                            # Every participant evaluates the user initializer: the
                            # sender needs the value, receivers need shape and dtype.
                            # `initial_value` here is a local of this function; it
                            # does not rebind the one captured by the lambda above.
                            initial_value = initial_value_fn()
                            assert not callable(initial_value)
                            initial_value = ops.convert_to_tensor(
                                initial_value)

                            if self._is_chief and index == 0:
                                bcast_send = collective_ops.broadcast_send(
                                    initial_value, initial_value.shape,
                                    initial_value.dtype, group_size, group_key,
                                    collective_instance_key)
                                # Make the returned value depend on the send so
                                # the send runs whenever the initializer does.
                                with ops.control_dependencies([bcast_send]):
                                    return array_ops.identity(initial_value)
                            else:
                                return collective_ops.broadcast_recv(
                                    initial_value.shape, initial_value.dtype,
                                    group_size, group_key,
                                    collective_instance_key)

                    kwargs["initial_value"] = _overridden_initial_value_fn

                    # Create under the silent device-placement policy (presumably
                    # so cross-device inputs are copied implicitly in eager mode).
                    with context.context().device_policy(
                            context.DEVICE_PLACEMENT_SILENT):
                        v = next_creator(*args, **kwargs)

                    if i == 0:
                        # The instance key was derived from a predicted name;
                        # verify the first component actually received it.
                        actual_var_name = v.name.split(":")[0]
                        assert unique_var_name == actual_var_name, "%r vs %r" % (
                            unique_var_name, actual_var_name)
                    # Components must be plain per-device variables.
                    assert not isinstance(v, values.DistributedVariable)
                    value_list.append(v)
            return value_list

        # pylint: disable=protected-access
        return mirrored_strategy._create_mirrored_variable(
            device_map, logical_device, _real_mirrored_creator, *args,
            **kwargs)
  def _create_variable(self, next_creator, *args, **kwargs):
    """Creates a mirrored variable whose initial value is broadcast from the chief.

    One component variable is created per device by calling `next_creator`,
    and the components are handed to
    `mirrored_strategy._create_mirrored_variable`. Each component's
    initializer takes part in a collective broadcast: the first device of the
    chief worker sends its initial value and every other participant receives
    it, so all components start with the same value.

    Args:
      next_creator: the next variable creator in the chain; called once per
        device to build each component variable.
      *args: positional arguments forwarded to `next_creator`.
      **kwargs: keyword arguments forwarded to `next_creator`. An optional
        "colocate_with" entry is popped and used to choose the devices;
        "name" and "initial_value" are read, and rewritten per device, by the
        inner creator.

    Returns:
      The result of `mirrored_strategy._create_mirrored_variable` (presumably
      the wrapping mirrored variable -- not visible from here).

    Raises:
      ValueError: from the inner creator, if "initial_value" is missing.
    """
    colocate_with = kwargs.pop("colocate_with", None)
    devices = self._get_devices_from(colocate_with)
    # Broadcast participants: every local device on every worker (assumes all
    # workers have the same number of devices -- TODO confirm).
    group_size = len(devices) * self._num_workers
    group_key = self._collective_keys.get_group_key(self._devices)

    def _real_mirrored_creator(devices, *args, **kwargs):
      """Creates the per-device component variables on the current worker.

      Returns:
        A dict mapping each entry of `devices` to the variable created on it.
      """
      index = {}
      # Derive the collective instance key from the graph-unique name the
      # first component will actually receive, not from the raw requested
      # name. Two variables requested under the same name (which the graph
      # later uniquifies) would otherwise ask for the same key_id and could
      # have their broadcasts cross-matched (presumably get_instance_key
      # returns the same key for the same key_id -- that is how workers agree
      # on it). mark_as_used=False only predicts the name without reserving
      # it, so next_creator still receives it; any trailing "/" is stripped so
      # it compares equal to v.name below.
      unique_var_name = ops.get_default_graph().unique_name(
          kwargs["name"], mark_as_used=False).rstrip("/")
      collective_instance_key = self._collective_keys.get_instance_key(
          key_id=unique_var_name)
      if "initial_value" not in kwargs:
        raise ValueError("Initial value must be specified.")
      initial_value = kwargs["initial_value"]
      # Normalize to a zero-argument callable so each component's initializer
      # can evaluate it on its own device.
      if callable(initial_value):
        initial_value_fn = initial_value
      else:
        initial_value_fn = lambda: initial_value

      for i, d in enumerate(devices):
        with ops.device(d):
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = index[devices[0]].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0
            # to ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)

          # The initial value fn makes sure all variables are initialized to
          # the same values. The first device of the chief worker sends its
          # initial value to the other devices and other workers.
          # `device` and `index` are bound as default arguments so each closure
          # keeps this iteration's values (avoids late-binding closures).
          def _overridden_initial_value_fn(device=d, index=i):  # pylint: disable=g-missing-docstring
            with ops.device(device):
              # Every participant evaluates the user initializer: the sender
              # needs the value, receivers need its shape and dtype.
              initial_value = initial_value_fn()
              assert not callable(initial_value)
              initial_value = ops.convert_to_tensor(initial_value)

              if self._is_chief and index == 0:
                bcast_send = collective_ops.broadcast_send(
                    initial_value, initial_value.shape, initial_value.dtype,
                    group_size, group_key, collective_instance_key)
                # Make the returned value depend on the send so the send runs
                # whenever the initializer does.
                with ops.control_dependencies([bcast_send]):
                  return array_ops.identity(initial_value)
              else:
                return collective_ops.broadcast_recv(
                    initial_value.shape, initial_value.dtype, group_size,
                    group_key, collective_instance_key)

          kwargs["initial_value"] = _overridden_initial_value_fn

          with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
            v = next_creator(*args, **kwargs)

          if i == 0:
            # The instance key was derived from a predicted name; verify the
            # first component really received that name.
            actual_var_name = v.name.split(":")[0]
            assert unique_var_name == actual_var_name, "%r vs %r" % (
                unique_var_name, actual_var_name)
          # Components must be plain per-device variables.
          assert not isinstance(v, values.DistributedVariable)
          index[d] = v
      return index

    # pylint: disable=protected-access
    return mirrored_strategy._create_mirrored_variable(
        devices, _real_mirrored_creator, *args, **kwargs)
  def _create_variable(self, next_creator, *args, **kwargs):
    """Creates a mirrored variable whose initial value is broadcast from the chief.

    One component variable is created per device by calling `next_creator`,
    and the components are handed to
    `mirrored_strategy._create_mirrored_variable`. Each component's
    initializer takes part in a collective broadcast: the first device of the
    chief worker sends its initial value and every other participant receives
    it, so all components start with the same value.

    Args:
      next_creator: the next variable creator in the chain; called once per
        device to build each component variable.
      *args: positional arguments forwarded to `next_creator`.
      **kwargs: keyword arguments forwarded to `next_creator`. An optional
        "colocate_with" entry is popped and used to choose the devices;
        "name" and "initial_value" are read, and rewritten per device, by the
        inner creator.

    Returns:
      The result of `mirrored_strategy._create_mirrored_variable` (presumably
      the wrapping mirrored variable -- not visible from here).

    Raises:
      ValueError: from the inner creator, if "initial_value" is missing.
    """
    colocate_with = kwargs.pop("colocate_with", None)
    devices = self._get_devices_from(colocate_with)
    # Broadcast participants: every local device on every worker (assumes all
    # workers have the same number of devices -- TODO confirm).
    group_size = len(devices) * self._num_workers
    group_key = self._collective_keys.get_group_key(self._devices)

    def _real_mirrored_creator(devices, *args, **kwargs):
      """Creates the per-device component variables on the current worker.

      Returns:
        A dict mapping each entry of `devices` to the variable created on it.
      """
      index = {}
      # Predict, without reserving it, the graph-unique name the first
      # component will receive. The collective instance key is derived from it
      # (so same-named requests get distinct keys), and the assertion after
      # creation checks the prediction. Any trailing "/" is stripped so it
      # compares equal to v.name below.
      unique_var_name = ops.get_default_graph().unique_name(
          kwargs["name"], mark_as_used=False).rstrip("/")
      # NOTE(review): presumably the same key_id yields the same instance key
      # on every worker, which is what pairs senders with receivers.
      collective_instance_key = self._collective_keys.get_instance_key(
          key_id=unique_var_name)
      if "initial_value" not in kwargs:
        raise ValueError("Initial value must be specified.")
      initial_value = kwargs["initial_value"]
      # Normalize to a zero-argument callable so each component's initializer
      # can evaluate it on its own device.
      if callable(initial_value):
        initial_value_fn = initial_value
      else:
        initial_value_fn = lambda: initial_value

      for i, d in enumerate(devices):
        with ops.device(d):
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = index[devices[0]].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0
            # to ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)

          # The initial value fn makes sure all variables are initialized to
          # the same values. The first device of the chief worker sends its
          # initial value to the other devices and other workers.
          # `device` and `index` are bound as default arguments so each closure
          # keeps this iteration's values (avoids late-binding closures).
          def _overridden_initial_value_fn(device=d, index=i):  # pylint: disable=g-missing-docstring
            with ops.device(device):
              # Every participant evaluates the user initializer: the sender
              # needs the value, receivers need its shape and dtype.
              initial_value = initial_value_fn()
              assert not callable(initial_value)
              initial_value = ops.convert_to_tensor(initial_value)

              if self._is_chief and index == 0:
                bcast_send = collective_ops.broadcast_send(
                    initial_value, initial_value.shape, initial_value.dtype,
                    group_size, group_key, collective_instance_key)
                # Make the returned value depend on the send so the send runs
                # whenever the initializer does.
                with ops.control_dependencies([bcast_send]):
                  return array_ops.identity(initial_value)
              else:
                return collective_ops.broadcast_recv(
                    initial_value.shape, initial_value.dtype, group_size,
                    group_key, collective_instance_key)

          kwargs["initial_value"] = _overridden_initial_value_fn

          with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
            v = next_creator(*args, **kwargs)

          if i == 0:
            # The instance key was derived from a predicted name; verify the
            # first component really received that name.
            actual_var_name = v.name.split(":")[0]
            assert unique_var_name == actual_var_name, "%r vs %r" % (
                unique_var_name, actual_var_name)
          # Components must be plain per-device variables.
          assert not isinstance(v, values.DistributedVariable)
          index[d] = v
      return index

    # pylint: disable=protected-access
    return mirrored_strategy._create_mirrored_variable(
        devices, _real_mirrored_creator, *args, **kwargs)
  def _create_variable(self, next_creator, *args, **kwargs):
    """Creates a mirrored variable whose initial value is shared across workers.

    When "colocate_with" is a `numpy_dataset.SingleDevice`, a single plain
    variable is created on that device and returned directly. Otherwise one
    component variable is created per device by calling `next_creator`:

    - The first device's component takes part in a cross-worker broadcast.
      The chief sends and the others receive, and only when there is more
      than one worker.
    - The remaining components on this worker copy their initial value from
      the first component.

    Args:
      next_creator: the next variable creator in the chain; called once per
        device (or once in the SingleDevice case).
      *args: positional arguments forwarded to `next_creator`.
      **kwargs: keyword arguments forwarded to `next_creator`. An optional
        "colocate_with" entry is popped and used to choose placement; "name"
        and "initial_value" are read, and rewritten per device, by the inner
        creator.

    Returns:
      In the SingleDevice case, whatever `next_creator` returns. Otherwise
      the result of `mirrored_strategy._create_mirrored_variable` (presumably
      the wrapping mirrored variable -- not visible from here).

    Raises:
      ValueError: from the inner creator, if "initial_value" is missing.
    """
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      device_map = self._device_map
      logical_device = 0  # TODO(josh11b): Get logical device from scope here.
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      # Single-device placement: no mirroring and no broadcast.
      with ops.device(colocate_with.device):
        return next_creator(*args, **kwargs)
    else:
      # Reuse the placement of the object we colocate with (presumably an
      # existing distributed variable -- TODO confirm).
      device_map = colocate_with.device_map
      logical_device = colocate_with.logical_device

    def _real_mirrored_creator(devices, *args, **kwargs):
      """Creates the per-device component variables on the current worker.

      Returns:
        A list with one variable per entry of `devices`, in order.
      """
      # Predict, without reserving it, the graph-unique name the first
      # component will receive. The collective instance key is derived from
      # it, and the assertion after creation checks the prediction. Any
      # trailing "/" is stripped so it compares equal to v.name below.
      unique_var_name = ops.get_default_graph().unique_name(
          kwargs["name"], mark_as_used=False).rstrip("/")
      # pylint: disable=protected-access
      collective_instance_key = self._collective_keys.get_instance_key(
          key_id=unique_var_name)
      # Only the first device participates in the broadcast of initial values,
      # so the group has one member per worker.
      group_key = self._collective_keys.get_group_key([devices[0]])
      group_size = self._num_workers
      if "initial_value" not in kwargs:
        raise ValueError("Initial value must be specified.")
      initial_value = kwargs["initial_value"]
      # Normalize to a zero-argument callable so the first component's
      # initializer can evaluate it on its own device.
      if callable(initial_value):
        initial_value_fn = initial_value
      else:
        initial_value_fn = lambda: initial_value

      value_list = []
      for i, d in enumerate(devices):
        # NOTE(review): init_scope presumably lifts variable creation out of
        # any function-building graph; see tf.init_scope.
        with ops.init_scope(), ops.device(d):
          if i == 0:
            # The initial value fn makes sure all variables are initialized to
            # the same values. The first device of the chief worker sends its
            # initial value to the other workers.
            # `device` and `index` are bound as default arguments so the
            # closure keeps this iteration's values.
            def _overridden_initial_value_fn(device=d, index=i):  # pylint: disable=g-missing-docstring
              with ops.device(device):
                # Evaluated on every worker: the chief needs the value,
                # receivers need its shape and dtype.
                initial_value = initial_value_fn()
                assert not callable(initial_value)
                initial_value = ops.convert_to_tensor(initial_value)

                assert index == 0, index
                # With a single worker there is nobody to broadcast to, so the
                # local initial value is used as-is.
                if self._num_workers > 1:
                  if self._is_chief:
                    bcast_send = collective_ops.broadcast_send(
                        initial_value, initial_value.shape, initial_value.dtype,
                        group_size, group_key, collective_instance_key)
                    # Make the returned value depend on the send so the send
                    # runs whenever the initializer does.
                    with ops.control_dependencies([bcast_send]):
                      return array_ops.identity(initial_value)
                  else:
                    return collective_ops.broadcast_recv(
                        initial_value.shape, initial_value.dtype, group_size,
                        group_key, collective_instance_key)
                return initial_value
          else:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)

            # Variables on non-first replicas get their initial values from
            # the variable created on the first device of the same worker.
            def _overridden_initial_value_fn(device=d, index=i):
              assert index > 0
              with ops.device(device):
                # Eagerly, the first component already exists (it was
                # appended to value_list), so read its current value;
                # presumably that is its freshly initialized value. In graph
                # mode, depend on its initial_value tensor instead.
                if context.executing_eagerly():
                  return array_ops.identity(value_list[0].value())
                else:
                  return array_ops.identity(value_list[0].initial_value)

          kwargs["initial_value"] = _overridden_initial_value_fn
          with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
            # Don't record operations (e.g. other variable reads) during
            # variable creation.
            with tape.stop_recording():
              v = next_creator(*args, **kwargs)

          if i == 0:
            # The instance key was derived from a predicted name; verify the
            # first component really received that name.
            actual_var_name = v.name.split(":")[0]
            assert unique_var_name == actual_var_name, "%r vs %r" % (
                unique_var_name, actual_var_name)
          # Components must be plain per-device variables.
          assert not isinstance(v, values.DistributedVariable)
          value_list.append(v)
      return value_list

    # pylint: disable=protected-access
    return mirrored_strategy._create_mirrored_variable(
        self._container_strategy(), device_map, logical_device,
        _real_mirrored_creator, *args, **kwargs)