def _get_or_make_slot_with_initializer(var, initializer, shape, dtype,
                                       slot_name, op_name):
  """Find or create a slot for a variable, using an Initializer.

  Args:
    var: A `Variable` object.
    initializer: An `Initializer`. The initial value of the slot.
    shape: Shape of the initial value of the slot.
    dtype: Type of the value of the slot.
    slot_name: Name for the slot.
    op_name: Name to use when scoping the Variable that needs to be
      created for the slot.

  Returns:
    A `Variable` object.
  """
  slots = self._slot_dict(slot_name)
  var_key = optimizer._var_key(var)
  if var_key not in slots:
    # Dynamic-embedding wrappers need their own slot-creation path; plain
    # variables go through the stock slot_creator.
    if isinstance(var, de.TrainableWrapper):
      created = de.create_slots(var, initializer, slot_name, op_name)
    else:
      created = slot_creator.create_slot_with_initializer(
          var, initializer, shape, dtype, op_name)
    self._restore_slot_variable(slot_name=slot_name,
                                variable=var,
                                slot_variable=created)
    slots[var_key] = created
  return slots[var_key]
def _get_or_make_slot(var, val, slot_name, op_name):
  """Find or create a slot for a variable.

  Args:
    var: A `Variable` object.
    val: A `Tensor`. The initial value of the slot.
    slot_name: Name for the slot.
    op_name: Name to use when scoping the Variable that needs to be
      created for the slot.

  Returns:
    A `Variable` object.
  """
  slots = self._slot_dict(slot_name)
  var_key = optimizer._var_key(var)
  if var_key not in slots:
    if isinstance(var, de.TrainableWrapper):
      # Dynamic-embedding variable: use the custom create_slots().
      created = de.create_slots(var, val, slot_name, op_name)
    else:
      # Ordinary variable: use the framework's create_slot().
      created = slot_creator.create_slot(var, val, op_name)
    self._restore_slot_variable(slot_name=slot_name,
                                variable=var,
                                slot_variable=created)
    slots[var_key] = created
  return slots[var_key]
def _zeros_slot(var, slot_name, op_name):
  """Find or create a slot initialized with 0.0.

  Args:
    var: A `Variable` object.
    slot_name: Name for the slot.
    op_name: Name to use when scoping the Variable that needs to be
      created for the slot.

  Returns:
    A `Variable` object.
  """
  slots = self._slot_dict(slot_name)
  var_key = optimizer._var_key(var)
  if var_key not in slots:
    # For dynamic-embedding wrappers, a scalar 0.0 seed is handed to the
    # custom slot factory; plain variables use the stock zeros slot.
    if isinstance(var, de.TrainableWrapper):
      created = de.create_slots(var, 0.0, slot_name, op_name)
    else:
      created = slot_creator.create_zeros_slot(var, op_name)
    self._restore_slot_variable(slot_name=slot_name,
                                variable=var,
                                slot_variable=created)
    slots[var_key] = created
  return slots[var_key]
def add_slot(var, slot_name, initializer="zeros"):
  """Add a new slot variable for `var`.

  Args:
    var: A `Variable` object.
    slot_name: Name for the slot.
    initializer: Initializer name, callable, or concrete initial value.

  Returns:
    The (possibly pre-existing) slot `Variable`.
  """
  if slot_name not in self._slot_names:
    self._slot_names.append(slot_name)
  slot_dict = self._slots.setdefault(optimizer_v2._var_key(var), {})
  existing = slot_dict.get(slot_name, None)
  if existing is not None:
    # Slot already created for this variable — reuse it.
    return existing

  # Resolve string/callable initializers into a deferred initial value.
  if isinstance(initializer, six.string_types) or callable(initializer):
    init_fn = initializers.get(initializer)
    initial_value = functools.partial(init_fn,
                                      shape=var.shape,
                                      dtype=var.dtype)
  else:
    initial_value = initializer

  strategy = distribute_ctx.get_strategy()
  with strategy.extended.colocate_vars_with(var):
    if isinstance(var, de.TrainableWrapper):
      weight = de.create_slots(var, initial_value, slot_name,
                               var._shared_name)
    else:
      weight = variables.Variable(
          name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
          dtype=var.dtype,
          trainable=False,
          initial_value=initial_value)
  backend.track_variable(weight)
  slot_dict[slot_name] = weight
  self._restore_slot_variable(slot_name=slot_name,
                              variable=var,
                              slot_variable=weight)
  self._weights.append(weight)
  return weight
def add_slot(var, slot_name, initializer="zeros", shape=None):
  """Add a new slot variable for `var`.

  Args:
    var: A `Variable` object.
    slot_name: Name for the slot.
    initializer: Initializer name, callable, or concrete initial value.
    shape: Optional explicit shape for the slot value; when omitted the
      variable's own shape is used (unless the initializer is a
      checkpoint-restore callable, which carries its own shape).

  Returns:
    The (possibly pre-existing) slot `Variable`.

  Raises:
    ValueError: if the slot would be created under a different
      `tf.distribute.Strategy` scope than `var` was created under.
  """
  if slot_name not in self._slot_names:
    self._slot_names.append(slot_name)
  slot_dict = self._slots.setdefault(optimizer_v2._var_key(var), {})
  existing = slot_dict.get(slot_name, None)
  if existing is not None:
    # Slot already created for this variable — reuse it.
    return existing

  # Resolve string/callable initializers into a deferred initial value.
  if isinstance(initializer, six.string_types) or callable(initializer):
    init_fn = initializers.get(initializer)
    # Checkpoint-restore callables (and an explicit `shape`) override the
    # variable's own shape.
    if isinstance(init_fn, trackable.CheckpointInitialValueCallable) or (
        shape is not None):
      slot_shape = shape
    else:
      slot_shape = var.shape
    initial_value = functools.partial(init_fn,
                                      shape=slot_shape,
                                      dtype=var.dtype)
  else:
    initial_value = initializer

  with self._distribution_strategy_scope():
    strategy = distribute_ctx.get_strategy()
    if not strategy.extended.variable_created_in_scope(var):
      raise ValueError(
          "Trying to create optimizer slot variable under the scope for "
          "tf.distribute.Strategy ({}), which is different from the scope "
          "used for the original variable ({}). Make sure the slot "
          "variables are created under the same strategy scope. This may "
          "happen if you're restoring from a checkpoint outside the scope"
          .format(strategy, var))
    with strategy.extended.colocate_vars_with(var):
      if isinstance(var, de.TrainableWrapper):
        weight = de.create_slots(var, initial_value, slot_name,
                                 var._shared_name, self._bp_v2)
      else:
        weight = variables.Variable(
            name="%s/%s" % (
                var._shared_name,
                slot_name,
            ),  # pylint: disable=protected-access
            dtype=var.dtype,
            trainable=False,
            initial_value=initial_value,
        )
    backend.track_variable(weight)
    slot_dict[slot_name] = weight
    self._restore_slot_variable(slot_name=slot_name,
                                variable=var,
                                slot_variable=weight)
    self._weights.append(weight)
  return weight