def _get_or_make_slot_with_initializer(var, initializer, shape, dtype,
                                       slot_name, op_name):
    """Return the `slot_name` slot for `var`, creating it from an Initializer.

    `self` is not a parameter here; it is resolved from the enclosing scope
    (presumably the optimizer instance this function is attached to --
    confirm against the caller).

    Args:
      var: A `Variable` object, or a `de.TrainableWrapper`.
      initializer: An `Initializer` that produces the slot's initial value.
      shape: Shape of the slot's initial value. Only forwarded on the
        regular-variable path; the `de.TrainableWrapper` path does not use it.
      dtype: Type of the slot's value. Like `shape`, only forwarded on the
        regular-variable path.
      slot_name: Name of the slot.
      op_name: Name used when scoping the Variable created for the slot.

    Returns:
      The slot `Variable` cached for `var` under `slot_name`.
    """
    slots_by_key = self._slot_dict(slot_name)
    key = optimizer._var_key(var)  # pylint: disable=protected-access
    if key in slots_by_key:
        return slots_by_key[key]

    # Dynamic-embedding wrappers are routed to `de.create_slots`; every other
    # variable goes through `slot_creator.create_slot_with_initializer`.
    if isinstance(var, de.TrainableWrapper):
        created = de.create_slots(var, initializer, slot_name, op_name)
    else:
        created = slot_creator.create_slot_with_initializer(
            var, initializer, shape, dtype, op_name)

    # Called before the slot is cached; presumably applies any pending
    # checkpoint restore to it (inferred from the method name -- confirm).
    self._restore_slot_variable(slot_name=slot_name,
                                variable=var,
                                slot_variable=created)
    slots_by_key[key] = created
    return slots_by_key[key]
  def _get_or_make_slot(var, val, slot_name, op_name):
      """Return the `slot_name` slot for `var`, creating it from `val` if absent.

      `self` is resolved from the enclosing scope rather than passed in
      (presumably the optimizer being patched -- confirm against the caller).

      Args:
        var: A `Variable` object, or a `de.TrainableWrapper`.
        val: A `Tensor` holding the slot's initial value.
        slot_name: Name of the slot.
        op_name: Name used when scoping the Variable created for the slot.

      Returns:
        The slot `Variable` cached for `var` under `slot_name`.
      """
      slots_by_key = self._slot_dict(slot_name)
      key = optimizer._var_key(var)  # pylint: disable=protected-access
      if key not in slots_by_key:
          # Dynamic-embedding variables use the custom `de.create_slots`;
          # ordinary variables use `slot_creator.create_slot`.
          new_slot = (de.create_slots(var, val, slot_name, op_name)
                      if isinstance(var, de.TrainableWrapper) else
                      slot_creator.create_slot(var, val, op_name))
          self._restore_slot_variable(slot_name=slot_name,
                                      variable=var,
                                      slot_variable=new_slot)
          slots_by_key[key] = new_slot
      return slots_by_key[key]
    def _zeros_slot(var, slot_name, op_name):
        """Return the `slot_name` slot for `var`, creating one initialized with 0.0.

        `self` is resolved from the enclosing scope rather than passed in
        (presumably the optimizer being patched -- confirm against the caller).

        Args:
          var: A `Variable` object, or a `de.TrainableWrapper`.
          slot_name: Name of the slot.
          op_name: Name used when scoping the Variable created for the slot.

        Returns:
          The slot `Variable` cached for `var` under `slot_name`.
        """
        slots_by_key = self._slot_dict(slot_name)
        key = optimizer._var_key(var)  # pylint: disable=protected-access
        if key in slots_by_key:
            return slots_by_key[key]

        if isinstance(var, de.TrainableWrapper):
            # Dynamic-embedding variables receive 0.0 as the slot's initial value.
            zero_slot = de.create_slots(var, 0.0, slot_name, op_name)
        else:
            zero_slot = slot_creator.create_zeros_slot(var, op_name)

        self._restore_slot_variable(slot_name=slot_name,
                                    variable=var,
                                    slot_variable=zero_slot)
        slots_by_key[key] = zero_slot
        return slots_by_key[key]
 def add_slot(var, slot_name, initializer="zeros"):
     """Add a new slot variable for `var`, or return the one already stored.

     `self` is resolved from the enclosing scope rather than passed in
     (presumably the optimizer being patched -- confirm against the caller).

     Args:
       var: The variable the slot belongs to; may be a `de.TrainableWrapper`.
       slot_name: Name of the slot. Recorded in `self._slot_names` if new.
       initializer: A string or callable is resolved via `initializers.get`
         and bound to `var.shape` / `var.dtype`; any other value is used
         directly as the initial value. Defaults to "zeros".

     Returns:
       The slot variable stored under `slot_name` for `var`.
     """
     if slot_name not in self._slot_names:
         self._slot_names.append(slot_name)
     per_var_slots = self._slots.setdefault(
         optimizer_v2._var_key(var), {})  # pylint: disable=protected-access
     existing = per_var_slots.get(slot_name)
     if existing is not None:
         return existing

     if isinstance(initializer, six.string_types) or callable(initializer):
         init_fn = initializers.get(initializer)
         initial_value = functools.partial(init_fn,
                                           shape=var.shape,
                                           dtype=var.dtype)
     else:
         initial_value = initializer

     strategy = distribute_ctx.get_strategy()
     # Create the slot colocated with `var` under the current strategy.
     with strategy.extended.colocate_vars_with(var):
         if isinstance(var, de.TrainableWrapper):
             weight = de.create_slots(
                 var, initial_value, slot_name,
                 var._shared_name)  # pylint: disable=protected-access
         else:
             weight = variables.Variable(
                 name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
                 dtype=var.dtype,
                 trainable=False,
                 initial_value=initial_value)

     backend.track_variable(weight)
     per_var_slots[slot_name] = weight
     self._restore_slot_variable(slot_name=slot_name,
                                 variable=var,
                                 slot_variable=weight)
     self._weights.append(weight)
     return weight
    def add_slot(var, slot_name, initializer="zeros", shape=None):
        """Add a new slot variable for `var`, or return the one already stored.

        `self` is resolved from the enclosing scope rather than passed in
        (presumably the optimizer being patched -- confirm against the caller).

        Args:
          var: The variable the slot belongs to; may be a `de.TrainableWrapper`.
          slot_name: Name of the slot. Recorded in `self._slot_names` if new.
          initializer: A string or callable is resolved via `initializers.get`
            and bound to a shape and `var.dtype`; any other value is used
            directly as the initial value. Defaults to "zeros".
          shape: Optional slot shape. When given, or when the resolved
            initializer is a `trackable.CheckpointInitialValueCallable`, it is
            used instead of `var.shape` (so it may be None in the latter case).

        Returns:
          The slot variable stored under `slot_name` for `var`.

        Raises:
          ValueError: If `var` was not created in the scope of the current
            `tf.distribute.Strategy`.
        """
        if slot_name not in self._slot_names:
            self._slot_names.append(slot_name)
        per_var_slots = self._slots.setdefault(
            optimizer_v2._var_key(var), {})  # pylint: disable=protected-access
        existing = per_var_slots.get(slot_name)
        if existing is not None:
            return existing

        if isinstance(initializer, six.string_types) or callable(initializer):
            init_fn = initializers.get(initializer)
            keep_given_shape = shape is not None or isinstance(
                init_fn, trackable.CheckpointInitialValueCallable)
            slot_shape = shape if keep_given_shape else var.shape
            initial_value = functools.partial(init_fn,
                                              shape=slot_shape,
                                              dtype=var.dtype)
        else:
            initial_value = initializer

        with self._distribution_strategy_scope():
            strategy = distribute_ctx.get_strategy()
            if not strategy.extended.variable_created_in_scope(var):
                raise ValueError(
                    "Trying to create optimizer slot variable under the scope for "
                    "tf.distribute.Strategy ({}), which is different from the scope "
                    "used for the original variable ({}). Make sure the slot "
                    "variables are created under the same strategy scope. This may "
                    "happen if you're restoring from a checkpoint outside the scope"
                    .format(strategy, var))

            with strategy.extended.colocate_vars_with(var):
                if isinstance(var, de.TrainableWrapper):
                    weight = de.create_slots(
                        var, initial_value, slot_name,
                        var._shared_name,  # pylint: disable=protected-access
                        self._bp_v2)
                else:
                    weight = variables.Variable(
                        name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
                        dtype=var.dtype,
                        trainable=False,
                        initial_value=initial_value,
                    )

            # Bookkeeping runs while the distribution-strategy scope is active.
            backend.track_variable(weight)
            per_var_slots[slot_name] = weight
            self._restore_slot_variable(slot_name=slot_name,
                                        variable=var,
                                        slot_variable=weight)
            self._weights.append(weight)
        return weight