def _create_variable(self, next_creator, *args, **kwargs):
        """Creates a variable mirrored across this strategy's devices.

        The `colocate_with` keyword (popped from `kwargs`) selects placement:
          * None: use `self._device_map` with logical device 0.
          * `numpy_dataset.SingleDevice`: create one ordinary variable on that
            device via `next_creator`, without mirroring.
          * anything else: use its `device_map` and `logical_device`
            attributes.

        For the mirrored cases, creation is delegated to
        `mirrored_strategy._create_mirrored_variable`. It receives a creator
        that builds one component variable per local device and synchronizes
        the initial value across workers with collective broadcasts.

        Args:
          next_creator: The next variable creator in the chain; called once per
            component variable (or once for the `SingleDevice` case).
          *args: Positional arguments forwarded to `next_creator`.
          **kwargs: Keyword arguments forwarded to `next_creator`. The mirrored
            path reads `"initial_value"`, `"name"` and (optionally) `"dtype"`.

        Returns:
          The result of `next_creator` for the `SingleDevice` case. Otherwise,
          the result of `mirrored_strategy._create_mirrored_variable`
          (presumably a mirrored wrapper around the per-device components).

        Raises:
          ValueError: From the per-worker creator (assuming
            `_create_mirrored_variable` invokes it) if `"initial_value"` is
            missing from `kwargs`.
        """
        colocate_with = kwargs.pop("colocate_with", None)
        if colocate_with is None:
            device_map = self._device_map
            logical_device = 0  # TODO(josh11b): Get logical device from scope here.
        elif isinstance(colocate_with, numpy_dataset.SingleDevice):
            # Colocating with a single device: no mirroring, just place an
            # ordinary variable on that device.
            with ops.device(colocate_with.device):
                return next_creator(*args, **kwargs)
        else:
            device_map = colocate_with.device_map
            logical_device = colocate_with.logical_device

        def _real_mirrored_creator(devices, *args, **kwargs):
            """Creates one MirroredVariable's components on the current worker.

            Args:
              devices: Local devices to create a component on, in order. Only
                the component on `devices[0]` takes part in the cross-worker
                broadcast; the others copy from it.
              *args: Forwarded to `next_creator`.
              **kwargs: Forwarded to `next_creator`. `"initial_value"` is
                replaced per device, and `"name"` is replaced for every device
                after the first.

            Returns:
              A list with one component variable per entry of `devices`.

            Raises:
              ValueError: If `kwargs` has no `"initial_value"`.
            """
            # Validate before reserving any collective keys for this variable.
            if "initial_value" not in kwargs:
                raise ValueError("Initial value must be specified.")
            initial_value = kwargs["initial_value"]
            if callable(initial_value):
                initial_value_fn = initial_value
            else:
                initial_value_fn = lambda: initial_value
            # Requested variable dtype (None if unspecified). The initial value
            # is converted with it so that the broadcast tensor already has
            # the variable's dtype. Otherwise a Python scalar would be
            # converted with its inferred dtype (e.g. float32). That would then
            # clash with an explicit `dtype` when the variable is built.
            var_dtype = kwargs.get("dtype", None)

            unique_var_name = ops.get_default_graph().unique_name(
                kwargs["name"], mark_as_used=False).rstrip("/")
            # pylint: disable=protected-access
            collective_instance_key = self._collective_keys.get_instance_key(
                key_id=unique_var_name)
            # Only the first device participates in the broadcast of initial values.
            group_key = self._collective_keys.get_group_key([devices[0]])
            group_size = self._num_workers

            value_list = []
            for i, d in enumerate(devices):
                with ops.init_scope(), ops.device(d):
                    if i == 0:
                        # The initial value fn makes sure all variables are
                        # initialized to the same value: the first device of
                        # the chief worker sends its value to the other workers.
                        def _overridden_initial_value_fn(device=d, index=i):  # pylint: disable=g-missing-docstring
                            with ops.device(device):
                                initial_value = initial_value_fn()
                                assert not callable(initial_value)
                                initial_value = ops.convert_to_tensor(
                                    initial_value, dtype=var_dtype)

                                assert index == 0, index
                                if self._num_workers > 1:
                                    if self._is_chief:
                                        bcast_send = collective_ops.broadcast_send(
                                            initial_value, initial_value.shape,
                                            initial_value.dtype, group_size,
                                            group_key, collective_instance_key)
                                        # Tie the send to the returned value so
                                        # it runs whenever the value is used.
                                        with ops.control_dependencies(
                                            [bcast_send]):
                                            return array_ops.identity(
                                                initial_value)
                                    else:
                                        # Non-chief workers discard their local
                                        # value and use the chief's instead.
                                        return collective_ops.broadcast_recv(
                                            initial_value.shape,
                                            initial_value.dtype, group_size,
                                            group_key, collective_instance_key)
                                return initial_value
                    else:
                        # Give replicas meaningful distinct names:
                        var0name = value_list[0].name.split(":")[0]
                        # We append a / to variable names created on replicas with id > 0 to
                        # ensure that we ignore the name scope and instead use the given
                        # name as the absolute name of the variable.
                        kwargs["name"] = "%s/replica_%d/" % (var0name, i)

                        # Variables on non-first replicas get their initial
                        # value from the variable created on this worker's
                        # first device (its current value when executing
                        # eagerly, its initial value in graph mode).
                        def _overridden_initial_value_fn(device=d, index=i):
                            assert index > 0
                            with ops.device(device):
                                if context.executing_eagerly():
                                    return array_ops.identity(
                                        value_list[0].value())
                                else:
                                    return array_ops.identity(
                                        value_list[0].initial_value)

                    kwargs["initial_value"] = _overridden_initial_value_fn
                    with context.device_policy(
                            context.DEVICE_PLACEMENT_SILENT):
                        # Don't record operations (e.g. other variable reads) during
                        # variable creation.
                        with tape.stop_recording():
                            v = next_creator(*args, **kwargs)

                    if i == 0:
                        # The collective instance key was derived from the
                        # predicted name, so the created name must match it.
                        actual_var_name = v.name.split(":")[0]
                        assert unique_var_name == actual_var_name, "%r vs %r" % (
                            unique_var_name, actual_var_name)
                    assert not isinstance(v, values.DistributedVariable)
                    value_list.append(v)
            return value_list

        # pylint: disable=protected-access
        return mirrored_strategy._create_mirrored_variable(
            self._container_strategy(), device_map, logical_device,
            _real_mirrored_creator, *args, **kwargs)
  def _create_variable(self, next_creator, *args, **kwargs):
    """Creates a variable mirrored across this strategy's devices.

    The `colocate_with` keyword (popped from `kwargs`) selects placement:
      * None: use `self._device_map` with logical device 0.
      * `numpy_dataset.SingleDevice`: create one ordinary variable on that
        device via `next_creator`, without mirroring.
      * anything else: use its `device_map` and `logical_device` attributes.

    For the mirrored cases, creation is delegated to
    `mirrored_strategy._create_mirrored_variable`. It receives a creator that
    builds one component variable per local device and synchronizes the
    initial value across workers with collective broadcasts.

    Args:
      next_creator: The next variable creator in the chain; called once per
        component variable (or once for the `SingleDevice` case).
      *args: Positional arguments forwarded to `next_creator`.
      **kwargs: Keyword arguments forwarded to `next_creator`. The mirrored
        path reads `"initial_value"`, `"name"` and (optionally) `"dtype"`.

    Returns:
      The result of `next_creator` for the `SingleDevice` case. Otherwise, the
      result of `mirrored_strategy._create_mirrored_variable` (presumably a
      mirrored wrapper around the per-device components).

    Raises:
      ValueError: From the per-worker creator (assuming
        `_create_mirrored_variable` invokes it) if `"initial_value"` is
        missing from `kwargs`.
    """
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      device_map = self._device_map
      logical_device = 0  # TODO(josh11b): Get logical device from scope here.
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      # Colocating with a single device: no mirroring, just place an ordinary
      # variable on that device.
      with ops.device(colocate_with.device):
        return next_creator(*args, **kwargs)
    else:
      device_map = colocate_with.device_map
      logical_device = colocate_with.logical_device

    def _real_mirrored_creator(devices, *args, **kwargs):
      """Creates one MirroredVariable's components on the current worker.

      Args:
        devices: Local devices to create a component on, in order. Only the
          component on `devices[0]` takes part in the cross-worker broadcast;
          the others copy from it.
        *args: Forwarded to `next_creator`.
        **kwargs: Forwarded to `next_creator`. `"initial_value"` is replaced
          per device, and `"name"` is replaced for every device after the first.

      Returns:
        A list with one component variable per entry of `devices`.

      Raises:
        ValueError: If `kwargs` has no `"initial_value"`.
      """
      # Validate before reserving any collective keys for this variable.
      if "initial_value" not in kwargs:
        raise ValueError("Initial value must be specified.")
      initial_value = kwargs["initial_value"]
      if callable(initial_value):
        initial_value_fn = initial_value
      else:
        initial_value_fn = lambda: initial_value
      # Requested variable dtype (None if unspecified). The initial value is
      # converted with it so that the broadcast tensor already has the
      # variable's dtype. Otherwise a Python scalar would be converted with its
      # inferred dtype (e.g. float32). That would then clash with an explicit
      # `dtype` when the variable is built.
      var_dtype = kwargs.get("dtype", None)

      unique_var_name = ops.get_default_graph().unique_name(
          kwargs["name"], mark_as_used=False).rstrip("/")
      # pylint: disable=protected-access
      collective_instance_key = self._collective_keys.get_instance_key(
          key_id=unique_var_name)
      # Only the first device participates in the broadcast of initial values.
      group_key = self._collective_keys.get_group_key([devices[0]])
      group_size = self._num_workers

      value_list = []
      for i, d in enumerate(devices):
        with ops.init_scope(), ops.device(d):
          if i == 0:
            # The initial value fn makes sure all variables are initialized to
            # the same value: the first device of the chief worker sends its
            # value to the other workers.
            def _overridden_initial_value_fn(device=d, index=i):  # pylint: disable=g-missing-docstring
              with ops.device(device):
                initial_value = initial_value_fn()
                assert not callable(initial_value)
                initial_value = ops.convert_to_tensor(
                    initial_value, dtype=var_dtype)

                assert index == 0, index
                if self._num_workers > 1:
                  if self._is_chief:
                    bcast_send = collective_ops.broadcast_send(
                        initial_value, initial_value.shape, initial_value.dtype,
                        group_size, group_key, collective_instance_key)
                    # Tie the send to the returned value so it runs whenever
                    # the value is used.
                    with ops.control_dependencies([bcast_send]):
                      return array_ops.identity(initial_value)
                  else:
                    # Non-chief workers discard their local value and use the
                    # chief's instead.
                    return collective_ops.broadcast_recv(
                        initial_value.shape, initial_value.dtype, group_size,
                        group_key, collective_instance_key)
                return initial_value
          else:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)

            # Variables on non-first replicas get their initial value from the
            # variable created on this worker's first device (its current value
            # when executing eagerly, its initial value in graph mode).
            def _overridden_initial_value_fn(device=d, index=i):
              assert index > 0
              with ops.device(device):
                if context.executing_eagerly():
                  return array_ops.identity(value_list[0].value())
                else:
                  return array_ops.identity(value_list[0].initial_value)

          kwargs["initial_value"] = _overridden_initial_value_fn
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            # Don't record operations (e.g. other variable reads) during
            # variable creation.
            with tape.stop_recording():
              v = next_creator(*args, **kwargs)

          if i == 0:
            # The collective instance key was derived from the predicted name,
            # so the created name must match it.
            actual_var_name = v.name.split(":")[0]
            assert unique_var_name == actual_var_name, "%r vs %r" % (
                unique_var_name, actual_var_name)
          assert not isinstance(v, values.DistributedVariable)
          value_list.append(v)
      return value_list

    # pylint: disable=protected-access
    return mirrored_strategy._create_mirrored_variable(
        self._container_strategy(), device_map, logical_device,
        _real_mirrored_creator, *args, **kwargs)