def _infer_mht_saveable_name(ms):
  """Returns the name under which a `MutableHashTable._Saveable` is saved.

  Args:
    ms: A `MutableHashTable._Saveable`.

  Returns:
    The first key of the dict produced by
    `saveable_object_util.op_list_to_dict([ms])`.

  Raises:
    TypeError: If `ms` resolves to more than one name.
  """
  name_to_ms_dict = saveable_object_util.op_list_to_dict([ms])
  names = list(name_to_ms_dict)
  if len(names) > 1:
    raise TypeError("`ms` = %s passed as arg violates the constraints. "
                    "name_to_var_dict = %s" % (ms, name_to_ms_dict))
  return names[0]
def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
  """Creates a saver that stores global center variables in place of locals.

  Call this after all variables have been created with
  ElasticAverageCustomGetter. Use the returned saver during training: for every
  entry of `var_list` that has a counterpart in the global map, the
  global_center_variable is saved under the original parameter name, which is
  the value intended for evaluation or inference.

  Args:
    var_list: List (or name-to-variable dict) of variables to save, as per
      `Saver()`. If None, all trainable variables created before this call
      are used.
    name: The name of the saver.
    **kwargs: Keyword arguments of `Saver()`.

  Returns:
    A `tf.compat.v1.train.Saver` object.

  Raises:
    RuntimeError: global_center_variable is empty, please make sure
                  this is called after model created and
                  ElasticAverageCustomGetter is used when declaring you model
  """
  if not self._global_map:
    raise RuntimeError(
        'global_center_variable is empty, please make sure '
        'this is called after model created and '
        'ElasticAverageCustomGetter is used when declaring '
        'you model')

  if var_list is None:
    var_list = variables.trainable_variables()
  if not isinstance(var_list, dict):
    var_list = saveable_object_util.op_list_to_dict(var_list)

  # Index the trainable variables by op name once instead of rescanning the
  # whole collection for every entry of `var_list`. `setdefault` keeps the
  # first variable seen for a name, matching a first-match linear scan.
  trainable_by_name = {}
  for tvar in variables.trainable_variables():
    trainable_by_name.setdefault(tvar.op.name, tvar)

  swapped_var_list = {}
  for key, var in var_list.items():
    if isinstance(var, list):
      # Partitioned variable: swap each slice independently, falling back to
      # the slice itself when it has no global center counterpart.
      swapped_var_list[key] = [self._global_map.get(lvar, lvar) for lvar in var]
      continue
    tvar = trainable_by_name.get(var.op.name)
    if tvar is None:
      # Not a trainable variable: saved unchanged.
      swapped_var_list[key] = var
    else:
      swapped_var_list[key] = self._global_map.get(tvar, var)

  return saver.Saver(swapped_var_list, name=name, **kwargs)
def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
  """Creates a saver that stores global center variables in place of locals.

  Call this after all variables have been created with
  ElasticAverageCustomGetter. Use the returned saver during training: for every
  entry of `var_list` that has a counterpart in the global map, the
  global_center_variable is saved under the original parameter name, which is
  the value intended for evaluation or inference.

  Args:
    var_list: List (or name-to-variable dict) of variables to save, as per
      `Saver()`. If None, all trainable variables created before this call
      are used.
    name: The name of the saver.
    **kwargs: Keyword arguments of `Saver()`.

  Returns:
    A `tf.compat.v1.train.Saver` object.

  Raises:
    RuntimeError: global_center_variable is empty, please make sure
                  this is called after model created and
                  ElasticAverageCustomGetter is used when declaring you model
  """
  if not self._global_map:
    raise RuntimeError('global_center_variable is empty, please make sure '
                       'this is called after model created and '
                       'ElasticAverageCustomGetter is used when declaring '
                       'you model')

  if var_list is None:
    var_list = variables.trainable_variables()
  if not isinstance(var_list, dict):
    var_list = saveable_object_util.op_list_to_dict(var_list)

  # Index the trainable variables by op name once instead of rescanning the
  # whole collection for every entry of `var_list`. `setdefault` keeps the
  # first variable seen for a name, matching a first-match linear scan.
  trainable_by_name = {}
  for tvar in variables.trainable_variables():
    trainable_by_name.setdefault(tvar.op.name, tvar)

  swapped_var_list = {}
  for key, var in var_list.items():
    if isinstance(var, list):
      # Partitioned variable: swap each slice independently, falling back to
      # the slice itself when it has no global center counterpart.
      swapped_var_list[key] = [self._global_map.get(lvar, lvar) for lvar in var]
      continue
    tvar = trainable_by_name.get(var.op.name)
    if tvar is None:
      # Not a trainable variable: saved unchanged.
      swapped_var_list[key] = var
    else:
      swapped_var_list[key] = self._global_map.get(tvar, var)

  return saver.Saver(swapped_var_list, name=name, **kwargs)
def _infer_var_name(var):
  """Returns the single name that `var` is saved under.

  Args:
    var: A list. The list can contain either of the following:
      (i) A single `Variable`
      (ii) A single `ResourceVariable`
      (iii) Multiple `Variable` objects which must be slices of the same larger
        variable.
      (iv) A single `PartitionedVariable`

  Returns:
    The first key of the dict produced by
    `saveable_object_util.op_list_to_dict(var)`.

  Raises:
    TypeError: If `var` resolves to more than one name.
  """
  name_to_var_dict = saveable_object_util.op_list_to_dict(var)
  names = list(name_to_var_dict)
  if len(names) > 1:
    raise TypeError("`var` = %s passed as arg violates the constraints.  "
                    "name_to_var_dict = %s" % (var, name_to_var_dict))
  return names[0]
def _infer_var_name(var):
  """Looks up the name that the given variable list maps to.

  Args:
    var: A list. The list can contain either of the following:
      (i) A single `Variable`
      (ii) A single `ResourceVariable`
      (iii) Multiple `Variable` objects which must be slices of the same larger
        variable.
      (iv) A single `PartitionedVariable`

  Returns:
    Name of the `var`, i.e. the first key returned by
    `saveable_object_util.op_list_to_dict(var)`.

  Raises:
    TypeError: If more than one name is produced for `var`.
  """
  name_to_var_dict = saveable_object_util.op_list_to_dict(var)
  candidate_names = list(name_to_var_dict.keys())
  if len(candidate_names) <= 1:
    return candidate_names[0]
  raise TypeError("`var` = %s passed as arg violates the constraints.  "
                  "name_to_var_dict = %s" % (var, name_to_var_dict))
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Replaces a variable's initializer with a restore-from-checkpoint op.

  The variable's initializer op becomes the op returned by its saveable's
  `restore`, fed by a `restore_v2` read of `tensor_name` from `ckpt_file`; the
  restored tensor also becomes the variable's initial value.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  dtype = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restored = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [dtype], name=name)[0]

    saveables = [
        saveable
        for op_name, op in saveable_object_util.op_list_to_dict(
            [variable]).items()
        for saveable in saveable_object_util.saveable_objects_for_op(
            op, op_name)
    ]
    assert len(saveables) == 1  # Should be only one variable.
  init_op = saveables[0].restore([restored], restored_shapes=None)

  # pylint:disable=protected-access
  variable._initializer_op = init_op
  restored.set_shape(variable.shape)
  variable._initial_value = restored
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides the initialization op of `variable` with a checkpoint read.

  A `restore_v2` op reads `tensor_name` from `ckpt_file`; the single saveable
  built for `variable` turns that tensor into the new initializer op, and the
  restored tensor is installed as the variable's initial value.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restored_tensors = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec],
        [variable.dtype.base_dtype], name=name)
    checkpoint_value = restored_tensors[0]

    variable_saveables = []
    saveable_ops = saveable_object_util.op_list_to_dict([variable])
    for saveable_name, saveable_op in saveable_ops.items():
      variable_saveables.extend(
          saveable_object_util.saveable_objects_for_op(
              saveable_op, saveable_name))
    assert len(variable_saveables) == 1  # Should be only one variable.
  initializer = variable_saveables[0].restore(
      [checkpoint_value], restored_shapes=None)

  # pylint:disable=protected-access
  variable._initializer_op = initializer
  checkpoint_value.set_shape(variable.shape)
  variable._initial_value = checkpoint_value
def _add_attributes_to_object_graph(self, trackable_objects,
                                    object_graph_proto, node_ids,
                                    object_names, object_map,
                                    call_with_mapped_captures):
  """Create SaveableObjects and corresponding SerializedTensor protos.

  For every trackable object, one attribute is added to its node in
  `object_graph_proto` per entry returned by
  `_gather_saveables_for_checkpoint()`, and the matching SaveableObjects are
  collected. SaveableObjects are reused from `self._saveables_cache` when it
  is set and their names still contain the attribute's checkpoint key.

  Args:
    trackable_objects: Sequence of trackable objects, in the same order as
      `object_graph_proto.nodes`.
    object_graph_proto: Object graph proto whose nodes receive the attributes.
    node_ids: Mapping from trackable object to its index in
      `trackable_objects`.
    object_names: Mapping from trackable object to the name used as the
      checkpoint key prefix.
    object_map: Optional mapping from a trackable object to the object that
      should be saved in its place; None means no substitution.
    call_with_mapped_captures: Passed through to
      `saveable_object_util.create_saveable_object` for callable factories.

  Returns:
    A `(named_saveable_objects, feed_additions)` tuple. `feed_additions` is
    None when `self._saveables_cache` is None, otherwise a dict accumulated
    from the `feed_dict_additions()` of each `PythonStateSaveable`.

  Raises:
    AssertionError: If a newly created SaveableObject's name does not contain
      the attribute's checkpoint key, or if two saveables try to feed the same
      key.
  """
  named_saveable_objects = []
  if self._saveables_cache is None:
    # No SaveableObject caching. Either we're executing eagerly, or building a
    # static save which is specialized to the current Python state.
    feed_additions = None
  else:
    # If we are caching SaveableObjects, we need to build up a feed_dict with
    # functions computing volatile Python state to be saved with the
    # checkpoint.
    feed_additions = {}
  for checkpoint_id, (trackable, object_proto) in enumerate(
      zip(trackable_objects, object_graph_proto.nodes)):
    assert node_ids[trackable] == checkpoint_id
    object_name = object_names[trackable]
    if object_map is None:
      object_to_save = trackable
    else:
      object_to_save = object_map.get(trackable, trackable)
    if self._saveables_cache is not None:
      cached_attributes = self._saveables_cache.setdefault(object_to_save, {})
    else:
      cached_attributes = None

    for name, saveable_factory in (
        object_to_save._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
      attribute = object_proto.attributes.add()
      attribute.name = name
      attribute.checkpoint_key = "%s/%s/%s" % (
          object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
      if cached_attributes is None:
        saveables = None
      else:
        saveables = cached_attributes.get(name, None)
        if saveables is not None:
          for saveable in saveables:
            if attribute.checkpoint_key not in saveable.name:
              # The checkpoint key for this SaveableObject is different. We
              # need to re-create it.
              saveables = None
              del cached_attributes[name]
              break
      if saveables is None:
        if callable(saveable_factory):
          maybe_saveable = saveable_object_util.create_saveable_object(
              saveable_factory, attribute.checkpoint_key,
              call_with_mapped_captures)
        else:
          maybe_saveable = saveable_factory
        if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
          saveables = (maybe_saveable,)
        else:
          # Figure out the name-based Saver's name for this variable. If it's
          # already a SaveableObject we'd just get the checkpoint key back, so
          # we leave full_name blank.
          saver_dict = saveable_object_util.op_list_to_dict(
              [maybe_saveable], convert_variable_to_tensor=False)
          full_name, = saver_dict.keys()
          saveables = tuple(
              saveable_object_util.saveable_objects_for_op(
                  op=maybe_saveable, name=attribute.checkpoint_key))
          for saveable in saveables:
            saveable.full_name = full_name
        for saveable in saveables:
          if attribute.checkpoint_key not in saveable.name:
            # Argument order follows the message: the SaveableObject's name
            # first, then the attribute name.
            raise AssertionError((
                "The object %s produced a SaveableObject with name '%s' for "
                "attribute '%s'. Expected a name containing '%s'."
            ) % (trackable, saveable.name, name, attribute.checkpoint_key))
        if cached_attributes is not None:
          cached_attributes[name] = saveables

      # The attribute is optionally restorable only if every one of its
      # saveables is; with no saveables at all it defaults to False below.
      optional_restore = None
      for saveable in saveables:
        if optional_restore is None:
          optional_restore = saveable.optional_restore
        else:
          optional_restore = optional_restore and saveable.optional_restore

        if hasattr(saveable, "full_name"):
          attribute.full_name = saveable.full_name
        if isinstance(saveable, base.PythonStateSaveable):
          if feed_additions is None:
            assert self._saveables_cache is None
            # If we're not caching saveables, then we're either executing
            # eagerly or building a static save/restore (e.g. for a
            # SavedModel). In either case, we should embed the current Python
            # state in the graph rather than relying on a feed dict.
            saveable = saveable.freeze()
          else:
            saveable_feed_dict = saveable.feed_dict_additions()
            for new_feed_key in saveable_feed_dict.keys():
              if new_feed_key in feed_additions:
                raise AssertionError((
                    "The object %s tried to feed a value for the Tensor %s "
                    "when saving, but another object is already feeding a "
                    "value.") % (trackable, new_feed_key))
            feed_additions.update(saveable_feed_dict)
        named_saveable_objects.append(saveable)
      if optional_restore is None:
        optional_restore = False
      attribute.optional_restore = optional_restore

  return named_saveable_objects, feed_additions
def _add_attributes_to_object_graph_for_saveable_objects(
    self, checkpoint_factory_map, object_graph_proto, node_ids, object_map,
    call_with_mapped_captures):
  """Create SaveableObjects and corresponding SerializedTensor protos.

  For each trackable in `checkpoint_factory_map`, one attribute per factory
  data entry is added to the trackable's node in `object_graph_proto`, and the
  matching SaveableObjects are collected. When `self._saveables_cache` is set,
  previously created SaveableObjects are reused as long as their names still
  contain the entry's checkpoint key.

  Args:
    checkpoint_factory_map: Mapping from trackable object to a list of factory
      data entries; each entry is read for `name`, `checkpoint_key` and
      `factory`.
    object_graph_proto: Object graph proto whose nodes receive the attributes.
    node_ids: Mapping from trackable object to its index in
      `object_graph_proto.nodes`.
    object_map: Forwarded to `_get_mapped_trackable` to pick the object used as
      the cache key.
    call_with_mapped_captures: Forwarded to
      `saveable_object_util.create_saveable_object` for callable factories.

  Returns:
    A `(named_saveable_objects, feed_additions)` tuple. `feed_additions` is
    None when `self._saveables_cache` is None, otherwise a dict accumulated
    from the `feed_dict_additions()` of each `PythonStateSaveable`.

  Raises:
    AssertionError: If a newly created SaveableObject's name does not contain
      the checkpoint key, or if two saveables try to feed the same key.
  """
  named_saveable_objects = []
  if self._saveables_cache is None:
    # No SaveableObject caching. Either we're executing eagerly, or building a
    # static save which is specialized to the current Python state.
    feed_additions = None
  else:
    # If we are caching SaveableObjects, we need to build up a feed_dict with
    # functions computing volatile Python state to be saved with the
    # checkpoint.
    feed_additions = {}
  for trackable, factory_data_list in checkpoint_factory_map.items():
    object_proto = object_graph_proto.nodes[node_ids[trackable]]
    if self._saveables_cache is not None:
      object_to_save = _get_mapped_trackable(trackable, object_map)
      cached_attributes = self._saveables_cache.setdefault(
          object_to_save, {})
    else:
      cached_attributes = None

    for factory_data in factory_data_list:
      attribute = object_proto.attributes.add()
      attribute.name = name = factory_data.name
      attribute.checkpoint_key = key = factory_data.checkpoint_key
      saveable_factory = factory_data.factory

      # See if we can skip saving this checkpoint key.
      saveables = cached_attributes.get(
          name) if cached_attributes else None
      if saveables is not None:
        for saveable in saveables:
          if key not in saveable.name:
            # The checkpoint key for this SaveableObject is different. We
            # need to re-create it.
            saveables = None
            del cached_attributes[name]
            break

      if saveables is None:
        # Nothing usable in the cache: build the SaveableObjects from the
        # factory (calling it first when it is callable).
        if callable(saveable_factory):
          maybe_saveable = saveable_object_util.create_saveable_object(
              saveable_factory, key, call_with_mapped_captures)
        else:
          maybe_saveable = saveable_factory
        if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
          saveables = (maybe_saveable, )
        else:
          # Figure out the name-based Saver's name for this variable. If it's
          # already a SaveableObject we'd just get the checkpoint key back, so
          # we leave full_name blank.
          saver_dict = saveable_object_util.op_list_to_dict(
              [maybe_saveable], convert_variable_to_tensor=False)
          full_name, = saver_dict.keys()
          saveables = tuple(
              saveable_object_util.saveable_objects_for_op(
                  op=maybe_saveable, name=key))
          for saveable in saveables:
            saveable.full_name = full_name
        for saveable in saveables:
          if key not in saveable.name:
            raise AssertionError(
                f"The object {trackable} produced a SaveableObject with name "
                f"'{saveable.name}' for attribute '{name}'. Expected a name"
                f" containing '{key}'.")
        if cached_attributes is not None:
          cached_attributes[name] = saveables

      for saveable in saveables:
        if hasattr(saveable, "full_name"):
          attribute.full_name = saveable.full_name
        if isinstance(saveable, base.PythonStateSaveable):
          if feed_additions is None:
            assert self._saveables_cache is None
            # If we're not caching saveables, then we're either executing
            # eagerly or building a static save/restore (e.g. for a
            # SavedModel). In either case, we should embed the current Python
            # state in the graph rather than relying on a feed dict.
            saveable = saveable.freeze()
          else:
            saveable_feed_dict = saveable.feed_dict_additions()
            # Each feed key may be supplied by only one saveable.
            for new_feed_key in saveable_feed_dict.keys():
              if new_feed_key in feed_additions:
                raise AssertionError(
                    f"The object {trackable} tried to feed a value for the "
                    f"Tensor {new_feed_key} when saving, but another object "
                    "is already feeding a value.")
            feed_additions.update(saveable_feed_dict)
        named_saveable_objects.append(saveable)

  return named_saveable_objects, feed_additions
def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
  """Builds a saver that swaps variables with their moving averages.

  Use the returned saver during training: it saves the moving averages of the
  trained parameters under the original parameter names. For evaluation or
  inference use a regular saver, which will then pick up the moving averages
  for the trained variables.

  This must be called after all variables have been created and after
  Optimizer.minimize() has been called.

  Args:
    var_list: List of variables to save, as per `Saver()`.
              If set to None, will save all the variables that have been
              created before this call.
    name: The name of the saver.
    **kwargs: Keyword arguments of `Saver()`.

  Returns:
    A `tf.compat.v1.train.Saver` object.

  Raises:
    RuntimeError: If apply_gradients or minimize has not been called before.
    ValueError: If var_list is provided and contains some variables but not
      their moving average counterpart.
  """
  if self._swapped_variable_name_map is None:
    raise RuntimeError('Must call apply_gradients or minimize before '
                       'creating the swapping_saver')
  if var_list is None:
    var_list = variables.global_variables()
  if not isinstance(var_list, dict):
    var_list = saveable_object_util.op_list_to_dict(var_list)

  # First pass: index every tensor by name. For each partitioned variable
  # OpListToDict returns list of constituent parts instead of single tensor,
  # so those are indexed part by part under their own op names.
  v_name_to_tensor = {}
  for key, entry in six.iteritems(var_list):
    if isinstance(entry, (list, variables.PartitionedVariable)):
      for part in entry:
        v_name_to_tensor[part.op.name] = part
    else:
      v_name_to_tensor[key] = entry

  # Second pass: swap variables and moving averages.
  swapped_var_list = {}
  for key, entry in six.iteritems(var_list):
    if not isinstance(entry, list):
      swapped_var_list[key] = self._find_swapped_variable(
          v_name_to_tensor, key, entry)
      continue
    swapped_var_list[key] = [
        self._find_swapped_variable(v_name_to_tensor, part.op.name, part)
        for part in entry
    ]

  # Build the swapping saver.
  return saver.Saver(swapped_var_list, name=name, **kwargs)
def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
  """Creates a saver that exchanges variables and their moving averages.

  The returned saver is meant for training: moving averages of the trained
  parameters are written under the original parameter names. A regular saver
  used for evaluation or inference then loads the moving averages into the
  trained variables automatically.

  Call this only after all variables have been created and after
  Optimizer.minimize() has been called.

  Args:
    var_list: List of variables to save, as per `Saver()`.
              If set to None, will save all the variables that have been
              created before this call.
    name: The name of the saver.
    **kwargs: Keyword arguments of `Saver()`.

  Returns:
    A `tf.compat.v1.train.Saver` object.

  Raises:
    RuntimeError: If apply_gradients or minimize has not been called before.
    ValueError: If var_list is provided and contains some variables but not
      their moving average counterpart.
  """
  if self._swapped_variable_name_map is None:
    raise RuntimeError('Must call apply_gradients or minimize before '
                       'creating the swapping_saver')
  if var_list is None:
    var_list = variables.global_variables()
  if not isinstance(var_list, dict):
    var_list = saveable_object_util.op_list_to_dict(var_list)

  v_name_to_tensor = {}
  for k, tensor_or_list in six.iteritems(var_list):
    # For each partitioned variable OpListToDict returns list of constituent
    # parts instead of single tensor.
    is_multi_part = isinstance(
        tensor_or_list, (list, variables.PartitionedVariable))
    if not is_multi_part:
      v_name_to_tensor[k] = tensor_or_list
      continue
    for tensor in tensor_or_list:
      v_name_to_tensor[tensor.op.name] = tensor

  def _swap(tensor_name, tensor):
    # Resolves one tensor to its swapped counterpart via the name index.
    return self._find_swapped_variable(v_name_to_tensor, tensor_name, tensor)

  # Now swap variables and moving averages.
  swapped_var_list = {}
  for k, tensor_or_list in six.iteritems(var_list):
    if isinstance(tensor_or_list, list):
      swapped_var_list[k] = [
          _swap(tensor.op.name, tensor) for tensor in tensor_or_list
      ]
    else:
      swapped_var_list[k] = _swap(k, tensor_or_list)

  # Build the swapping saver.
  return saver.Saver(swapped_var_list, name=name, **kwargs)