def _get_or_make_slot_with_initializer(var, initializer, shape, dtype,
                                       slot_name, op_name):
    """Return the slot stored for `var`, building it from an initializer if absent.

    NOTE(review): `self` is not a parameter of this function; it is resolved
    from an enclosing scope, so this is presumably defined as a closure over
    an optimizer instance -- confirm against the surrounding code.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`. The initial value of the slot.
      shape: Shape of the initial value of the slot. Only forwarded when
        `var` is not a `de.TrainableWrapper`.
      dtype: Type of the value of the slot. Only forwarded when `var` is not
        a `de.TrainableWrapper`.
      slot_name: Name for the slot; selects the per-slot dictionary.
      op_name: Name to use when scoping the Variable that needs to be
        created for the slot.

    Returns:
      The slot `Variable` registered under the key of `var`.
    """
    slots_by_key = self._slot_dict(slot_name)
    slot_key = optimizer._var_key(var)

    if slot_key not in slots_by_key:
        if isinstance(var, de.TrainableWrapper):
            # Wrapper variables get their slot from the `de` module.
            created = de.create_slots(var, initializer, slot_name, op_name)
        else:
            created = slot_creator.create_slot_with_initializer(
                var, initializer, shape, dtype, op_name)
        # Hand the new slot to the restore hook before registering it.
        self._restore_slot_variable(
            slot_name=slot_name, variable=var, slot_variable=created)
        slots_by_key[slot_key] = created

    return slots_by_key[slot_key]
def _zeros_slot(var, slot_name, op_name):
    """Return the zero-initialized slot for `var`, creating it on first use.

    NOTE(review): `self` is not a parameter of this function; it is resolved
    from an enclosing scope, so this is presumably defined as a closure over
    an optimizer instance -- confirm against the surrounding code.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot; selects the per-slot dictionary.
      op_name: Name to use when scoping the Variable that needs to be
        created for the slot.

    Returns:
      The slot `Variable` registered under the key of `var`.
    """
    slots_by_key = self._slot_dict(slot_name)
    slot_key = optimizer._var_key(var)

    # Fast path: the slot was already created by an earlier call.
    if slot_key in slots_by_key:
        return slots_by_key[slot_key]

    if isinstance(var, de.TrainableWrapper):
        # Wrapper variables get their slot from the `de` module, seeded
        # with the scalar 0.0.
        zero_slot = de.create_slots(var, 0.0, slot_name, op_name)
    else:
        zero_slot = slot_creator.create_zeros_slot(var, op_name)

    # Hand the new slot to the restore hook before registering it.
    self._restore_slot_variable(
        slot_name=slot_name, variable=var, slot_variable=zero_slot)
    slots_by_key[slot_key] = zero_slot
    return slots_by_key[slot_key]
def _get_or_make_slot(var, val, slot_name, op_name):
    """Return the slot stored for `var`, creating it from `val` if absent.

    NOTE(review): `self` is not a parameter of this function; it is resolved
    from an enclosing scope, so this is presumably defined as a closure over
    an optimizer instance -- confirm against the surrounding code.

    Args:
      var: A `Variable` object.
      val: A `Tensor`. The initial value of the slot.
      slot_name: Name for the slot; selects the per-slot dictionary.
      op_name: Name to use when scoping the Variable that needs to be
        created for the slot.

    Returns:
      The slot `Variable` registered under the key of `var`.
    """
    slots_by_key = self._slot_dict(slot_name)
    slot_key = optimizer._var_key(var)

    if slot_key not in slots_by_key:
        if isinstance(var, de.TrainableWrapper):
            # For a `de` variable, call the custom de.create_slots().
            fresh_slot = de.create_slots(var, val, slot_name, op_name)
        else:
            # For a regular variable, call the stock
            # slot_creator.create_slot().
            fresh_slot = slot_creator.create_slot(var, val, op_name)
        # Hand the new slot to the restore hook before registering it.
        self._restore_slot_variable(
            slot_name=slot_name, variable=var, slot_variable=fresh_slot)
        slots_by_key[slot_key] = fresh_slot

    return slots_by_key[slot_key]
def _process_slot_restoration(self, slot_restoration, variable):
    """Restore a slot variable's value (creating it if necessary).

    Args:
      slot_restoration: Object exposing `slot_name` and a `value_pointer`
        whose `save_path`, `checkpoint_key`, `dtype` and `session`
        attributes are read here.
      variable: The variable whose slot is being restored.
    """
    # TODO(allenl): Move this to Optimizer
    assert isinstance(self, optimizer_lib.Optimizer)
    slots_by_key = self._slot_dict(slot_restoration.slot_name)
    slot_key = optimizer_lib._var_key(variable)  # pylint: disable=protected-access
    current_slot = slots_by_key.get(slot_key)
    pointer = slot_restoration.value_pointer

    if current_slot is not None:
        # The slot already exists: assign the checkpointed value into it.
        _assign_existing_variable(current_slot, value_pointer=pointer)
        return

    # No slot yet: create one whose initial value is read from the
    # checkpoint.
    base_dtype = pointer.dtype.base_dtype
    (initializer,) = io_ops.restore_v2(
        prefix=pointer.save_path,
        tensor_names=[pointer.checkpoint_key],
        shape_and_slices=[""],
        dtypes=[base_dtype],
        name="checkpoint_initializer")
    restored_slot = slot_creator.create_slot(
        variable, initializer, slot_restoration.slot_name)
    if pointer.session is not None:
        # A session is attached to the pointer, so run the initializer in
        # it right away.
        pointer.session.run(restored_slot.initializer)
    slots_by_key[slot_key] = restored_slot
def set_slot_shadow(self, var, val, slot_name, replace=False):
    """Store `val` as the shadow entry of `var` for the slot `slot_name`.

    Shadow entries are kept in the slot dictionary named
    `slot_name + '_shadow'`.

    Args:
      var: Either a string, which is used directly as the dictionary key,
        or any other object, which is converted to a key with `_var_key`.
      val: The value to store.
      slot_name: Base slot name; `'_shadow'` is appended to it.
      replace: If true, an entry for `var` must already exist and is
        overwritten. If false, no entry for `var` may exist yet.

    Raises:
      AssertionError: If `replace` is true but no entry exists, or
        `replace` is false and an entry already exists.
    """
    shadow_slots = self._slot_dict(slot_name + '_shadow')
    key = var if isinstance(var, str) else _var_key(var)
    already_present = key in shadow_slots

    # Raised explicitly rather than via `assert` so the check is still
    # enforced when Python runs with -O; the exception type is unchanged.
    if replace and not already_present:
        raise AssertionError(
            'cannot replace shadow slot %r for key %r: no existing entry'
            % (slot_name, key))
    if not replace and already_present:
        raise AssertionError(
            'shadow slot %r already has an entry for key %r'
            % (slot_name, key))

    shadow_slots[key] = val
def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
                                       slot_name, op_name):
    """Return the slot stored for `var`, building it from an initializer if absent.

    A missing slot is created inside the `'slots'` variable scope, opened
    with `reuse=tf.AUTO_REUSE`.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`. The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot; selects the per-slot dictionary.
      op_name: Name to use when scoping the Variable that needs to be
        created for the slot.

    Returns:
      The slot `Variable` registered under the key of `var`.
    """
    slots_by_key = self._slot_dict(slot_name)
    slot_key = _var_key(var)

    # Fast path: the slot was already created by an earlier call.
    if slot_key in slots_by_key:
        return slots_by_key[slot_key]

    with tf.variable_scope('slots', reuse=tf.AUTO_REUSE):
        slots_by_key[slot_key] = create_slot_with_initializer(
            var, initializer, shape, dtype, op_name)
    return slots_by_key[slot_key]
def _zeros_slot(self, var, slot_name, op_name):
    """Return the zero-initialized slot for `var`, creating it on first use.

    This is effectively a copy of the original TF optimizer method, except
    that it passes `self.slots_dtype` as the dtype to `create_zeros_slot`.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot; selects the per-slot dictionary.
      op_name: Name to use when scoping the Variable that needs to be
        created for the slot.

    Returns:
      The slot registered under the key of `var`, passed through
      `tf.cast(..., var.dtype)`.
    """
    slots_by_key = self._slot_dict(slot_name)
    slot_key = _var_key(var)

    if slot_key not in slots_by_key:
        # The slot is created with `self.slots_dtype`, not `var.dtype`.
        zero_slot = slot_creator.create_zeros_slot(
            var, op_name, dtype=self.slots_dtype)
        # Hand the new slot to the restore hook before registering it.
        self._restore_slot_variable(
            slot_name=slot_name, variable=var, slot_variable=zero_slot)
        slots_by_key[slot_key] = zero_slot

    stored_slot = slots_by_key[slot_key]
    return tf.cast(stored_slot, var.dtype)
def _process_slot_restoration(self, slot_restoration, variable):
    """Restore a slot variable's value (creating it if necessary).

    Args:
      slot_restoration: Object exposing `slot_name` and a `value_pointer`
        whose `save_path`, `checkpoint_key`, `dtype` and `session`
        attributes are read here.
      variable: The variable whose slot is being restored.
    """
    # TODO(allenl): Move this to Optimizer
    assert isinstance(self, optimizer_lib.Optimizer)
    slot_name = slot_restoration.slot_name
    registry = self._slot_dict(slot_name)
    key = optimizer_lib._var_key(variable)  # pylint: disable=protected-access
    found = registry.get(key)

    if found is None:
        # No slot yet: create one whose initial value is read from the
        # checkpoint.
        value_pointer = slot_restoration.value_pointer
        base_dtype = value_pointer.dtype.base_dtype
        [initializer] = io_ops.restore_v2(
            prefix=value_pointer.save_path,
            tensor_names=[value_pointer.checkpoint_key],
            shape_and_slices=[""],
            dtypes=[base_dtype],
            name="checkpoint_initializer")
        created = slot_creator.create_slot(variable, initializer, slot_name)
        session = value_pointer.session
        if session is not None:
            # A session is attached to the pointer, so run the initializer
            # in it right away.
            session.run(created.initializer)
        registry[key] = created
    else:
        # The slot already exists: assign the checkpointed value into it.
        _assign_existing_variable(
            found, value_pointer=slot_restoration.value_pointer)