def _get_or_make_slot_with_initializer(var, initializer, shape, dtype,
                                       slot_name, op_name):
    """Return the `slot_name` slot of `var`, creating it from an Initializer.

    `self`, `optimizer`, `de` and `slot_creator` are not parameters here; they
    are resolved from an enclosing or module scope that is not visible in this
    snippet (presumably the optimizer instance and TF/TFRA modules).

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`; the initial value of the slot.
      shape: Shape of the slot's initial value (only forwarded when `var` is
        not a `de.TrainableWrapper`).
      dtype: Type of the slot's value (only forwarded when `var` is not a
        `de.TrainableWrapper`).
      slot_name: Name for the slot.
      op_name: Name used to scope the Variable created for the slot.

    Returns:
      The slot object stored for `var` under `slot_name`.
    """
    slots_for_name = self._slot_dict(slot_name)
    var_key = optimizer._var_key(var)
    if var_key in slots_for_name:
        return slots_for_name[var_key]

    if isinstance(var, de.TrainableWrapper):
        # Dynamic-embedding variables get their slots from `de.create_slots`.
        slot_var = de.create_slots(var, initializer, slot_name, op_name)
    else:
        slot_var = slot_creator.create_slot_with_initializer(
            var, initializer, shape, dtype, op_name)
    self._restore_slot_variable(
        slot_name=slot_name, variable=var, slot_variable=slot_var)
    slots_for_name[var_key] = slot_var
    return slots_for_name[var_key]
    def _zeros_slot(var, slot_name, op_name):
        """Return the `slot_name` slot of `var`, creating a zero-filled one.

        `self`, `optimizer`, `de` and `slot_creator` are resolved from an
        enclosing or module scope not visible in this snippet.

        Args:
          var: A `Variable` object.
          slot_name: Name for the slot.
          op_name: Name used to scope the Variable created for the slot.

        Returns:
          The slot object stored for `var` under `slot_name`.
        """
        slots_for_name = self._slot_dict(slot_name)
        var_key = optimizer._var_key(var)
        if var_key in slots_for_name:
            return slots_for_name[var_key]

        if isinstance(var, de.TrainableWrapper):
            # Dynamic-embedding variables: initial value is the scalar 0.0.
            slot_var = de.create_slots(var, 0.0, slot_name, op_name)
        else:
            slot_var = slot_creator.create_zeros_slot(var, op_name)
        self._restore_slot_variable(
            slot_name=slot_name, variable=var, slot_variable=slot_var)
        slots_for_name[var_key] = slot_var
        return slots_for_name[var_key]
  def _get_or_make_slot(var, val, slot_name, op_name):
    """Return the `slot_name` slot of `var`, creating it from `val` if absent.

    `de.TrainableWrapper` variables get their slot from `de.create_slots`;
    every other variable goes through `slot_creator.create_slot`. `self`,
    `optimizer`, `de` and `slot_creator` are resolved from an enclosing or
    module scope not visible in this snippet.

    Args:
      var: A `Variable` object.
      val: A `Tensor`; the initial value of the slot.
      slot_name: Name for the slot.
      op_name: Name used to scope the Variable created for the slot.

    Returns:
      The slot object stored for `var` under `slot_name`.
    """
    slots_for_name = self._slot_dict(slot_name)
    var_key = optimizer._var_key(var)
    if var_key not in slots_for_name:
      slot_var = (de.create_slots(var, val, slot_name, op_name)
                  if isinstance(var, de.TrainableWrapper)
                  else slot_creator.create_slot(var, val, op_name))
      self._restore_slot_variable(
          slot_name=slot_name, variable=var, slot_variable=slot_var)
      slots_for_name[var_key] = slot_var
    return slots_for_name[var_key]
Exemplo n.º 4
0
 def _process_slot_restoration(self, slot_restoration, variable):
     """Restore a slot variable's value, creating the slot if necessary.

     If `variable` already has a slot named `slot_restoration.slot_name`, its
     value is updated with `_assign_existing_variable`. Otherwise a new slot is
     created whose initial value is read from the checkpoint with
     `io_ops.restore_v2`. The new slot is initialized immediately when the
     value pointer carries a session, and is then registered in the slot dict.

     Args:
       slot_restoration: Object with `slot_name` and `value_pointer`
         attributes; `value_pointer` supplies `dtype`, `save_path`,
         `checkpoint_key` and `session`.
       variable: The variable that owns the slot.
     """
     # TODO(allenl): Move this to Optimizer
     assert isinstance(self, optimizer_lib.Optimizer)
     named_slots = self._slot_dict(slot_restoration.slot_name)
     variable_key = optimizer_lib._var_key(variable)  # pylint: disable=protected-access
     existing_slot_variable = named_slots.get(variable_key)
     pointer = slot_restoration.value_pointer
     if existing_slot_variable is not None:
         _assign_existing_variable(existing_slot_variable, value_pointer=pointer)
         return

     # Single-element unpacking: restore_v2 is asked for exactly one tensor.
     (initializer,) = io_ops.restore_v2(
         prefix=pointer.save_path,
         tensor_names=[pointer.checkpoint_key],
         shape_and_slices=[""],
         dtypes=[pointer.dtype.base_dtype],
         name="checkpoint_initializer")
     new_slot_variable = slot_creator.create_slot(
         variable, initializer, slot_restoration.slot_name)
     if pointer.session is not None:
         pointer.session.run(new_slot_variable.initializer)
     named_slots[variable_key] = new_slot_variable
Exemplo n.º 5
0
    def set_slot_shadow(self, var, val, slot_name, replace=False):
        """Store `val` as the shadow entry of slot `slot_name` for `var`.

        Entries live in the slot dict named ``slot_name + '_shadow'``.

        Args:
          var: A variable, or a string used directly as the dict key
            (otherwise the key is ``_var_key(var)``).
          val: The value to store.
          slot_name: Base slot name; ``'_shadow'`` is appended to it.
          replace: If true, an entry for `var` must already exist and is
            overwritten; if false, no entry may exist yet.

        Raises:
          AssertionError: If `replace` is true and no entry exists, or if
            `replace` is false and an entry already exists.
        """
        named_slots = self._slot_dict(slot_name + '_shadow')
        key = var if isinstance(var, str) else _var_key(var)

        # Explicit checks rather than `assert` statements: asserts are
        # stripped under `python -O`, which would silently allow overwrites.
        # AssertionError is kept so callers see the same exception type.
        exists = key in named_slots
        if replace and not exists:
            raise AssertionError(
                'No shadow slot %r exists for key %r to replace.'
                % (slot_name, key))
        if not replace and exists:
            raise AssertionError(
                'Shadow slot %r already exists for key %r; pass replace=True '
                'to overwrite it.' % (slot_name, key))
        named_slots[key] = val
Exemplo n.º 6
0
    def _get_or_make_slot_with_initializer(self, var, initializer, shape,
                                           dtype, slot_name, op_name):
        """Return the `slot_name` slot of `var`, creating it from an Initializer.

        A missing slot is built by `create_slot_with_initializer` inside a
        `tf.variable_scope('slots', reuse=tf.AUTO_REUSE)` block.

        NOTE(review): unlike the stock optimizer helper, this override never
        calls `self._restore_slot_variable` on a newly created slot; confirm
        that skipping deferred checkpoint restoration is intended.

        Args:
          var: A `Variable` object.
          initializer: An `Initializer`; the initial value of the slot.
          shape: Shape of the slot's initial value.
          dtype: Type of the slot's value.
          slot_name: Name for the slot.
          op_name: Name used to scope the Variable created for the slot.

        Returns:
          The slot object stored for `var` under `slot_name`.
        """
        slots_for_name = self._slot_dict(slot_name)
        var_key = _var_key(var)
        if var_key not in slots_for_name:
            with tf.variable_scope('slots', reuse=tf.AUTO_REUSE):
                new_slot = create_slot_with_initializer(
                    var, initializer, shape, dtype, op_name)
            slots_for_name[var_key] = new_slot
        return slots_for_name[var_key]
Exemplo n.º 7
0
        def _zeros_slot(self, var, slot_name, op_name):
            """Find or create a zero-initialized slot, returned cast to var's dtype.

            This is effectively a copy of the original TF optimizer method,
            except that it passes `self.slots_dtype` as the `dtype` argument
            of `slot_creator.create_zeros_slot`.

            Args:
              var: A `Variable` object.
              slot_name: Name for the slot.
              op_name: Name used to scope the Variable created for the slot.

            Returns:
              `tf.cast(slot, var.dtype)` applied to the stored slot. This may be
              a cast tensor rather than the slot `Variable` itself.
            """
            slots_for_name = self._slot_dict(slot_name)
            var_key = _var_key(var)
            if var_key not in slots_for_name:
                zeros_slot = slot_creator.create_zeros_slot(
                    var, op_name, dtype=self.slots_dtype)
                self._restore_slot_variable(
                    slot_name=slot_name, variable=var, slot_variable=zeros_slot)
                slots_for_name[var_key] = zeros_slot

            stored_slot = slots_for_name[var_key]
            return tf.cast(stored_slot, var.dtype)
Exemplo n.º 8
0
 def _process_slot_restoration(self, slot_restoration, variable):
   """Restore a slot variable's value from a checkpoint.

   An existing slot for `variable` is updated via `_assign_existing_variable`.
   Otherwise a new slot is created from the tensor that `io_ops.restore_v2`
   reads from the checkpoint, initialized at once if a session is available,
   and stored in the slot dict.

   Args:
     slot_restoration: Object with `slot_name` and `value_pointer`
       attributes; `value_pointer` supplies `dtype`, `save_path`,
       `checkpoint_key` and `session`.
     variable: The variable that owns the slot.
   """
   # TODO(allenl): Move this to Optimizer
   assert isinstance(self, optimizer_lib.Optimizer)
   slot_name = slot_restoration.slot_name
   named_slots = self._slot_dict(slot_name)
   key = optimizer_lib._var_key(variable)  # pylint: disable=protected-access
   current = named_slots.get(key, None)
   pointer = slot_restoration.value_pointer
   if current is not None:
     _assign_existing_variable(current, value_pointer=pointer)
   else:
     [initializer] = io_ops.restore_v2(
         prefix=pointer.save_path,
         tensor_names=[pointer.checkpoint_key],
         shape_and_slices=[""],
         dtypes=[pointer.dtype.base_dtype],
         name="checkpoint_initializer")
     slot_var = slot_creator.create_slot(variable, initializer, slot_name)
     session = pointer.session
     if session is not None:
       session.run(slot_var.initializer)
     named_slots[key] = slot_var