def var_key_test():
  # Eager-mode counterpart of the graph-mode var-key tests below; `a` and `b`
  # are presumably variables named 'var' created eagerly by the enclosing test
  # before this closure runs.
  self.assertFalse(a._in_graph_mode)
  self.assertFalse(b._in_graph_mode)
  var_key_a = optimizer_v2._var_key(a)
  self.assertStartsWith(var_key_a, 'var_')
  var_key_b = optimizer_v2._var_key(b)
  self.assertStartsWith(var_key_b, 'var_')
  self.assertNotEqual(var_key_a, var_key_b)
def testVarKey(self):
  with context.graph_mode():
    a = variables.Variable([1., 2.], name='var')
    b = variables.Variable([1.], name='var')
    self.assertTrue(a._in_graph_mode)
    self.assertTrue(b._in_graph_mode)
    var_key = optimizer_v2._var_key(a)
    self.assertEqual('var', var_key)
    var_key = optimizer_v2._var_key(b)
    self.assertEqual('var_1', var_key)
def testVarKey(self):
  with ops.get_default_graph().as_default():
    a = variables.Variable([1., 2.], name='var')
    b = variables.Variable([1.], name='var')
    self.assertTrue(a._in_graph_mode)
    self.assertTrue(b._in_graph_mode)
    var_key = optimizer_v2._var_key(a)
    self.assertEqual('var', var_key)
    var_key = optimizer_v2._var_key(b)
    self.assertEqual('var_1', var_key)
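
# A minimal standalone sketch of what the two var-key tests above check (the
# import path of the private Keras module differs across TF/Keras versions, so
# treat it as an assumption): in graph mode `_var_key` returns the shared op
# name ('var', 'var_1', ...), while for variables created eagerly it returns a
# unique 'var_<uid>' style key, so two eager variables with the same name
# never collide in the optimizer's slot dictionaries.
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import optimizer_v2  # path varies by version

a = tf.Variable([1., 2.], name='var')
b = tf.Variable([1.], name='var')
print(optimizer_v2._var_key(a))  # e.g. 'var_...'
print(optimizer_v2._var_key(b))  # a different key, despite the shared name
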
def _track_optimizer_slots(self, slots):
  if not all(isinstance(s, TrainableWrapper) for s in slots):
    raise TypeError(
        'Can only track TrainableWrapper slots, but got {}'.format(
            [type(s) for s in slots]))
  identifiers = [optimizer_v2._var_key(s) for s in self._tracked_slots]
  for s in slots:
    if optimizer_v2._var_key(s) not in identifiers:
      self._tracked_slots.append(s)

  if self.params.restrict_policy is not None:
    self.params.restrict_policy._track_params_from_optimizer_slots(slots)
def add_slot(var, slot_name, initializer="zeros"): """Add a new slot variable for `var`.""" if slot_name not in self._slot_names: self._slot_names.append(slot_name) var_key = optimizer_v2._var_key(var) slot_dict = self._slots.setdefault(var_key, {}) weight = slot_dict.get(slot_name, None) if weight is None: if isinstance(initializer, six.string_types) or callable(initializer): initializer = initializers.get(initializer) initial_value = functools.partial(initializer, shape=var.shape, dtype=var.dtype) else: initial_value = initializer strategy = distribute_ctx.get_strategy() with strategy.extended.colocate_vars_with(var): if isinstance(var, de.TrainableWrapper): weight = de.create_slots(var, initial_value, slot_name, var._shared_name) else: weight = variables.Variable( name="%s/%s" % (var._shared_name, slot_name), # pylint: disable=protected-access dtype=var.dtype, trainable=False, initial_value=initial_value) backend.track_variable(weight) slot_dict[slot_name] = weight self._restore_slot_variable(slot_name=slot_name, variable=var, slot_variable=weight) self._weights.append(weight) return weight
def add_slot(var, slot_name, initializer="zeros", shape=None): """Add a new slot variable for `var`.""" if slot_name not in self._slot_names: self._slot_names.append(slot_name) var_key = optimizer_v2._var_key(var) slot_dict = self._slots.setdefault(var_key, {}) weight = slot_dict.get(slot_name, None) if weight is None: if isinstance(initializer, six.string_types) or callable(initializer): initializer = initializers.get(initializer) if isinstance(initializer, trackable.CheckpointInitialValueCallable) or ( shape is not None): slot_shape = shape else: slot_shape = var.shape initial_value = functools.partial(initializer, shape=slot_shape, dtype=var.dtype) else: initial_value = initializer with self._distribution_strategy_scope(): strategy = distribute_ctx.get_strategy() if not strategy.extended.variable_created_in_scope(var): raise ValueError( "Trying to create optimizer slot variable under the scope for " "tf.distribute.Strategy ({}), which is different from the scope " "used for the original variable ({}). Make sure the slot " "variables are created under the same strategy scope. This may " "happen if you're restoring from a checkpoint outside the scope" .format(strategy, var)) with strategy.extended.colocate_vars_with(var): if isinstance(var, de.TrainableWrapper): weight = de.create_slots(var, initial_value, slot_name, var._shared_name, self._bp_v2) else: weight = variables.Variable( name="%s/%s" % ( var._shared_name, slot_name, ), # pylint: disable=protected-access dtype=var.dtype, trainable=False, initial_value=initial_value, ) backend.track_variable(weight) slot_dict[slot_name] = weight self._restore_slot_variable(slot_name=slot_name, variable=var, slot_variable=weight) self._weights.append(weight) return weight
def create_slots(primary, init, slot_name, op_name, bp_v2):
  """Helper function for creating a slot variable for stateful optimizers."""
  params_var_, params_ids_ = primary.params, primary.ids
  scope_store = variable_scope._get_default_variable_store()
  full_name = params_var_.name + "/" + op_name + "/" + slot_name
  if full_name not in scope_store._vars:
    with ops.colocate_with(primary, ignore_existing=True):
      slot_variable_ = de.Variable(
          name=full_name,
          key_dtype=params_var_.key_dtype,
          value_dtype=params_var_.value_dtype,
          dim=params_var_.dim,
          devices=params_var_.devices,
          partitioner=params_var_.partition_fn,
          initializer=init,
          kv_creator=params_var_.kv_creator,
          trainable=False,
          checkpoint=params_var_.checkpoint,
          bp_v2=bp_v2 if bp_v2 is not None else params_var_.bp_v2,
      )
    scope_store._vars[full_name] = slot_variable_
    # Record the optimizer Variable into trace.
    primary._optimizer_vars.append(slot_variable_)

  slot_trainable = None
  if context.executing_eagerly():
    slot_tw_name = slot_name + '-' + str(optimizer_v2._var_key(primary))
  else:
    # In graph mode, earlier versions used only slot_name to name the
    # trainable wrappers of slots, so keep slot_name here for forward
    # compatibility.
    slot_tw_name = slot_name
  if isinstance(primary, de.shadow_ops.ShadowVariable):
    slot_trainable = de.shadow_ops.ShadowVariable(
        params=scope_store._vars[full_name],
        ids=primary.ids,
        exists=primary.exists,
        name=full_name,
        trainable=False,
    )
  else:
    _, slot_trainable = de.embedding_lookup(
        params=scope_store._vars[full_name],
        ids=params_ids_,
        name=slot_tw_name,
        return_trainable=True,
    )
  return slot_trainable
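
# A hedged end-to-end sketch of how create_slots() above is reached in
# TensorFlow Recommenders Addons; the exact import path and variable/optimizer
# arguments are assumptions. Wrapping a Keras optimizer with
# DynamicEmbeddingOptimizer patches add_slot(), so slots for the
# TrainableWrapper returned by embedding_lookup() are created through
# de.create_slots() rather than as dense tf.Variable slots.
import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

embeddings = de.get_variable(name="embeddings", dim=8)
optimizer = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adam(1e-3))

ids = tf.constant([1, 7, 42], dtype=tf.int64)
with tf.GradientTape() as tape:
  emb, trainable = de.embedding_lookup(embeddings, ids, return_trainable=True)
  loss = tf.reduce_sum(emb * emb)

grads = tape.gradient(loss, [trainable])
# Adam's _create_slots() -> add_slot() -> de.create_slots() for the wrapper.
optimizer.apply_gradients(zip(grads, [trainable]))
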
def add_slot(self, var, slot_name, initializer="zeros", manifold_wise=False):
  """Add a new slot variable for `var`."""
  rank = self.manifold.rank  # rank of a tensor of the manifold
  if slot_name not in self._slot_names:
    self._slot_names.append(slot_name)
  var_key = opt._var_key(var)
  slot_dict = self._slots.setdefault(var_key, {})
  weight = slot_dict.get(slot_name, None)
  if weight is None:
    if isinstance(initializer, six.string_types) or callable(initializer):
      initializer = initializers.get(initializer)
      if manifold_wise:
        initial_value = functools.partial(initializer,
                                          shape=var.shape[:-rank - 1] +
                                          rank * (1,) + (2,),
                                          dtype=var.dtype)
      else:
        initial_value = functools.partial(initializer,
                                          shape=var.shape,
                                          dtype=var.dtype)
    else:
      initial_value = initializer
    strategy = distribute_ctx.get_strategy()
    if not strategy.extended.variable_created_in_scope(var):
      raise ValueError(
          "Trying to create optimizer slot variable under the scope for "
          "tf.distribute.Strategy ({}), which is different from the scope "
          "used for the original variable ({}). Make sure the slot "
          "variables are created under the same strategy scope. This may "
          "happen if you're restoring from a checkpoint outside the scope"
          .format(strategy, var))
    with strategy.extended.colocate_vars_with(var):
      weight = tf_variables.Variable(
          name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
          dtype=var.dtype,
          trainable=False,
          initial_value=initial_value)
    backend.track_variable(weight)
    slot_dict[slot_name] = weight
    self._restore_slot_variable(slot_name=slot_name,
                                variable=var,
                                slot_variable=weight)
    self._weights.append(weight)
  return weight
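
# A small worked example of the manifold_wise slot shape above (the concrete
# numbers and the re/im interpretation of the trailing axis of size 2 are
# assumptions for illustration): with a rank-2 manifold whose points are
# matrices stored as real tensors with a trailing axis of size 2, a batch of
# points of shape (5, 4, 4, 2) gets a slot of shape (5, 1, 1, 2), i.e. one
# broadcastable statistic per manifold point instead of one per tensor entry.
rank = 2
var_shape = (5, 4, 4, 2)
slot_shape = var_shape[:-rank - 1] + rank * (1,) + (2,)
assert slot_shape == (5, 1, 1, 2)
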