示例#1
0
  def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
                                         slot_name, op_name):
    """Return the `slot_name` slot of `var`, creating it on first request.

    A missing slot is built with `slot_creator.create_slot_with_initializer`,
    handed to `self._restore_slot_variable`, and then stored in the
    dictionary returned by `self._slot_dict(slot_name)` under the key
    `_var_key(var)`.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`.  The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    slots_by_key = self._slot_dict(slot_name)
    key = _var_key(var)

    # Fast path: the slot for this variable already exists.
    if key in slots_by_key:
      return slots_by_key[key]

    created = slot_creator.create_slot_with_initializer(
        var, initializer, shape, dtype, op_name)
    # The restore hook runs before the slot is recorded in the dictionary.
    self._restore_slot_variable(
        slot_name=slot_name, variable=var, slot_variable=created)
    slots_by_key[key] = created
    return created
  def _get_or_make_slot_with_initializer(var, initializer, shape, dtype,
                                         slot_name, op_name):
    """Return the `slot_name` slot of `var`, creating it on first request.

    `self` is not a parameter of this function; it is resolved from a scope
    outside this definition (presumably the optimizer instance this function
    is attached to -- TODO confirm against the enclosing code).

    When `var` is a `de.TrainableWrapper` the slot is built with
    `de.create_slots`; otherwise
    `slot_creator.create_slot_with_initializer` is used.  The new slot is
    handed to `self._restore_slot_variable` and then stored in the
    dictionary returned by `self._slot_dict(slot_name)`.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`.  The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    slots_by_key = self._slot_dict(slot_name)
    key = optimizer._var_key(var)

    # Fast path: the slot for this variable already exists.
    if key in slots_by_key:
      return slots_by_key[key]

    if not isinstance(var, de.TrainableWrapper):
      created = slot_creator.create_slot_with_initializer(
          var, initializer, shape, dtype, op_name)
    else:
      # Trainable wrappers get their slots from the `de` module's factory,
      # which does not take `shape` or `dtype`.
      created = de.create_slots(var, initializer, slot_name, op_name)

    # The restore hook runs before the slot is recorded in the dictionary.
    self._restore_slot_variable(
        slot_name=slot_name, variable=var, slot_variable=created)
    slots_by_key[key] = created
    return created
示例#3
0
 def add_slot(self, var, slot_name, initializer="zeros"):
     """Return the slot variable `slot_name` for `var`, creating it if absent.

     `slot_name` is appended to `self._slot_names` the first time it is seen.
     Slots are stored per variable in `self._slots`, keyed by
     `_var_key(var)` and then by `slot_name`.

     Args:
       var: The variable the slot belongs to.
       slot_name: Name of the slot.
       initializer: A string or callable, which is resolved through
         `initializers.get` and used with the shape and dtype of `var`;
         any other value is passed unchanged to `slot_creator.create_slot`.

     Returns:
       The existing or newly created slot variable.
     """
     if slot_name not in self._slot_names:
         self._slot_names.append(slot_name)

     slots_for_var = self._slots.setdefault(_var_key(var), {})
     existing = slots_for_var.get(slot_name)
     if existing is not None:
         return existing

     if isinstance(initializer, six.string_types) or callable(initializer):
         resolved = initializers.get(initializer)
         slot = slot_creator.create_slot_with_initializer(
             var, resolved, shape=var.shape, dtype=var.dtype, name=slot_name)
     else:
         slot = slot_creator.create_slot(var, initializer, slot_name)

     # Bookkeeping order matters: track, record, restore, then expose the
     # new variable through the optimizer's weight list.
     backend.track_variable(slot)
     slots_for_var[slot_name] = slot
     self._restore_slot_variable(
         slot_name=slot_name, variable=var, slot_variable=slot)
     self._weights.append(slot)
     return slot
示例#4
0
  def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
                                         slot_name, op_name):
    """Return the `slot_name` slot of `var`, creating it on first request.

    A missing slot is built with `slot_creator.create_slot_with_initializer`
    and stored in the dictionary returned by `self._slot_dict(slot_name)`
    under the key `_var_key(var)`.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`.  The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    slots_by_key = self._slot_dict(slot_name)
    key = _var_key(var)

    # Fast path: the slot for this variable already exists.
    if key in slots_by_key:
      return slots_by_key[key]

    created = slot_creator.create_slot_with_initializer(
        var, initializer, shape, dtype, op_name)
    slots_by_key[key] = created
    return created