def test_resource_variable_use_localhost(self):
  v1 = resource_variable_ops.ResourceVariable(2.)
  self.evaluate(v1.initializer)
  saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(v1, "x"))
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
  self.assertEqual(2, len(gfile.Glob(prefix + "*")))
  self.evaluate(v1.assign(1.))
  self.evaluate(saver.restore(prefix, self.local_options))
  self.assertEqual(2., self.evaluate(v1))

  v2 = resource_variable_ops.ResourceVariable(3.)
  self.evaluate(v2.initializer)
  second_saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(v2, "x"))
  self.evaluate(second_saver.restore(prefix, self.local_options))
  self.assertEqual(2., self.evaluate(v2))

  # In graph mode, verify that the save and restore ops were set to run on
  # localhost.
  if not context.executing_eagerly():
    for op in ops.get_default_graph().get_operations():
      if op.type in ("SaveV2", "RestoreV2"):
        self.assertEqual(LOCALHOST, op.device)
def test_to_proto(self):
  v1 = resource_variable_ops.ResourceVariable(2.)
  saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(v1, "x"))
  prefix = os.path.join(self.get_temp_dir(), "ckpt")

  proto_accumulator = []
  wrapped = wrap_function.wrap_function(
      lambda: proto_accumulator.append(saver.to_proto()), signature=())
  self.assertEqual(1, len(proto_accumulator))
  proto = proto_accumulator[0]
  save = wrapped.prune(
      feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name),
      fetches=wrapped.graph.get_tensor_by_name(proto.save_tensor_name))
  restore = wrapped.prune(
      feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name),
      fetches=wrapped.graph.get_operation_by_name(proto.restore_op_name))
  save_path = save(constant_op.constant(prefix))
  v1.assign(1.)
  restore(constant_op.constant(save_path))
  self.assertEqual(2., self.evaluate(v1))

  v2 = resource_variable_ops.ResourceVariable(3.)
  second_saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(v2, "x"))
  second_saver.restore(save_path)
  self.assertEqual(2., self.evaluate(v2))
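
# Note on test_to_proto: the SaverDef returned by to_proto() refers to graph
# elements by name rather than holding them directly. filename_tensor_name
# identifies the tensor fed with the checkpoint prefix, save_tensor_name the
# tensor that evaluates to the path actually written, and restore_op_name the
# op that performs the restore; pruning the wrapped graph on those names, as
# above, recovers callable save and restore functions.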
def test_checkpoint_multi_device_using_localhost(self):
  with ops.device("cpu:0"):
    v0 = resource_variable_ops.ResourceVariable(0.)
  with ops.device("cpu:1"):
    v1 = resource_variable_ops.ResourceVariable(1.)
  with ops.device("cpu:2"):
    v2 = resource_variable_ops.ResourceVariable(2.)

  self.evaluate([v0.initializer, v1.initializer, v2.initializer])
  saver = functional_saver.MultiDeviceSaver(
      list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
      list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
      list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
  self.assertEqual(2, len(gfile.Glob(prefix + "*")))
  self.evaluate(v0.assign(-1.))
  self.evaluate(v1.assign(-1.))
  self.evaluate(v2.assign(-1.))
  self.evaluate(
      saver.restore(constant_op.constant(prefix), self.local_options))
  self.assertEqual(0., self.evaluate(v0))
  self.assertEqual(1., self.evaluate(v1))
  self.assertEqual(2., self.evaluate(v2))

  # In graph mode, verify that the save and restore ops were set to run on
  # localhost.
  if not context.executing_eagerly():
    for op in ops.get_default_graph().get_operations():
      if op.type in ("SaveV2", "RestoreV2", "MergeV2Checkpoints"):
        self.assertEqual(LOCALHOST, op.device)
def test_checkpoint_is_sharded_by_task(self):
  servers = [server_lib.Server.create_local_server() for _ in range(3)]
  cluster_spec = server_lib.ClusterSpec({
      "worker": [s.target[len("grpc://"):] for s in servers]})
  remote.connect_to_cluster(cluster_spec)
  with ops.device("/job:worker/task:0/cpu:0"):
    v0 = resource_variable_ops.ResourceVariable(0.)
  with ops.device("/job:worker/task:1/cpu:0"):
    v1 = resource_variable_ops.ResourceVariable(1.)
  with ops.device("/job:worker/task:2/cpu:0"):
    v2 = resource_variable_ops.ResourceVariable(2.)

  self.evaluate([v0.initializer, v1.initializer, v2.initializer])
  saver = functional_saver.MultiDeviceSaver(
      list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
      list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
      list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  self.evaluate(saver.save(constant_op.constant(prefix)))
  self.assertEqual(4, len(gfile.Glob(prefix + "*")))
  self.evaluate(v0.assign(-1.))
  self.evaluate(v1.assign(-1.))
  self.evaluate(v2.assign(-1.))
  self.evaluate(saver.restore(constant_op.constant(prefix)))
  self.assertEqual(0., self.evaluate(v0))
  self.assertEqual(1., self.evaluate(v1))
  self.assertEqual(2., self.evaluate(v2))
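
# For reference, the four files asserted above match the sharded V2
# checkpoint layout: one merged metadata index plus one data shard per task.
# Illustrative file names (the exact names are not asserted by the test):
#   ckpt.index
#   ckpt.data-00000-of-00003
#   ckpt.data-00001-of-00003
#   ckpt.data-00002-of-00003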
def test_resource_variable(self):
  v1 = resource_variable_ops.ResourceVariable(2.)
  saver = functional_saver.Saver(
      saveable_object_util.saveable_objects_for_op(v1, "x"))
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  save_path = saver.save(constant_op.constant(prefix))
  v1.assign(1.)
  saver.restore(save_path)
  self.assertEqual(2., self.evaluate(v1))

  v2 = resource_variable_ops.ResourceVariable(3.)
  second_saver = functional_saver.Saver(
      saveable_object_util.saveable_objects_for_op(v2, "x"))
  second_saver.restore(save_path)
  self.assertEqual(2., self.evaluate(v2))
def test_resource_variable(self):
  v1 = resource_variable_ops.ResourceVariable(2.)
  self.evaluate(v1.initializer)
  saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(v1, "x"))
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  self.evaluate(saver.save(constant_op.constant(prefix)))
  self.assertEqual(2, len(gfile.Glob(prefix + "*")))
  self.evaluate(v1.assign(1.))
  self.evaluate(saver.restore(prefix))
  self.assertEqual(2., self.evaluate(v1))

  v2 = resource_variable_ops.ResourceVariable(3.)
  self.evaluate(v2.initializer)
  second_saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(v2, "x"))
  self.evaluate(second_saver.restore(prefix))
  self.assertEqual(2., self.evaluate(v2))
def test_resource_variable(self):
  v1 = resource_variable_ops.ResourceVariable(2.)
  self.evaluate(v1.initializer)
  saver = functional_saver._SingleDeviceSaver(
      saveable_object_util.saveable_objects_for_op(v1, "x"))
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  self.evaluate(saver.save(constant_op.constant(prefix)))
  self.assertEqual(2, len(gfile.Glob(prefix + "*")))
  self.evaluate(v1.assign(1.))
  self.evaluate(saver.restore(prefix))
  self.assertEqual(2., self.evaluate(v1))

  v2 = resource_variable_ops.ResourceVariable(3.)
  self.evaluate(v2.initializer)
  second_saver = functional_saver._SingleDeviceSaver(
      saveable_object_util.saveable_objects_for_op(v2, "x"))
  self.evaluate(second_saver.restore(prefix))
  self.assertEqual(2., self.evaluate(v2))
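
# The three test_resource_variable variants above run the same
# save/assign/restore round trip against functional_saver.Saver,
# functional_saver.MultiDeviceSaver, and functional_saver._SingleDeviceSaver,
# checking that each restores the saved value both into the original variable
# and into a freshly created variable registered under the same checkpoint
# key "x".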
def test_checkpoint_is_sharded_by_device(self):
  with ops.device("cpu:0"):
    v0 = resource_variable_ops.ResourceVariable(0.)
  with ops.device("cpu:1"):
    v1 = resource_variable_ops.ResourceVariable(1.)
  with ops.device("cpu:2"):
    v2 = resource_variable_ops.ResourceVariable(2.)

  self.evaluate([v0.initializer, v1.initializer, v2.initializer])
  saver = functional_saver.MultiDeviceSaver(
      list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
      list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
      list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  self.evaluate(saver.save(constant_op.constant(prefix)))
  self.assertEqual(4, len(gfile.Glob(prefix + "*")))
  self.evaluate(v0.assign(-1.))
  self.evaluate(v1.assign(-1.))
  self.evaluate(v2.assign(-1.))
  self.evaluate(saver.restore(constant_op.constant(prefix)))
  self.assertEqual(0., self.evaluate(v0))
  self.assertEqual(1., self.evaluate(v1))
  self.assertEqual(2., self.evaluate(v2))
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides the given variable's initialization op.

  Sets the variable initializer to an assign op that initializes the variable
  from the tensor's value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    names_to_saveables = saveable_object_util.op_list_to_dict([variable])
    saveable_objects = []
    for name, op in names_to_saveables.items():
      for s in saveable_object_util.saveable_objects_for_op(op, name):
        saveable_objects.append(s)

    assert len(saveable_objects) == 1  # Should be only one variable.
    init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

    # pylint:disable=protected-access
    variable._initializer_op = init_op
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op
    # pylint:enable=protected-access
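
# A minimal usage sketch for _set_checkpoint_initializer. The variable name,
# checkpoint path, and tensor name below are hypothetical; it assumes graph
# mode and an existing V2 checkpoint containing a "dense/kernel" entry whose
# shape matches the variable:
#
#   kernel = variables.Variable(array_ops.zeros([128, 64]),
#                               name="dense/kernel")
#   _set_checkpoint_initializer(
#       kernel,
#       ckpt_file="/tmp/model/ckpt",  # hypothetical checkpoint prefix
#       tensor_name="dense/kernel",   # entry to read from the checkpoint
#       slice_spec="")                # empty spec: load the full tensor
#   # Running kernel.initializer now assigns the checkpointed value rather
#   # than the original initial value.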
def _add_attributes_to_object_graph(self, trackable_objects,
                                    object_graph_proto, node_ids,
                                    object_names, object_map,
                                    call_with_mapped_captures):
  """Create SaveableObjects and corresponding SerializedTensor protos."""
  named_saveable_objects = []
  if self._saveables_cache is None:
    # No SaveableObject caching. Either we're executing eagerly, or building a
    # static save which is specialized to the current Python state.
    feed_additions = None
  else:
    # If we are caching SaveableObjects, we need to build up a feed_dict with
    # functions computing volatile Python state to be saved with the
    # checkpoint.
    feed_additions = {}
  for checkpoint_id, (trackable, object_proto) in enumerate(
      zip(trackable_objects, object_graph_proto.nodes)):
    assert node_ids[trackable] == checkpoint_id
    object_name = object_names[trackable]
    if object_map is None:
      object_to_save = trackable
    else:
      object_to_save = object_map.get(trackable, trackable)
    if self._saveables_cache is not None:
      cached_attributes = self._saveables_cache.setdefault(object_to_save, {})
    else:
      cached_attributes = None

    for name, saveable_factory in (
        object_to_save._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
      attribute = object_proto.attributes.add()
      attribute.name = name
      attribute.checkpoint_key = "%s/%s/%s" % (
          object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
      if cached_attributes is None:
        saveables = None
      else:
        saveables = cached_attributes.get(name, None)
        if saveables is not None:
          for saveable in saveables:
            if attribute.checkpoint_key not in saveable.name:
              # The checkpoint key for this SaveableObject is different. We
              # need to re-create it.
              saveables = None
              del cached_attributes[name]
              break
      if saveables is None:
        if callable(saveable_factory):
          maybe_saveable = saveable_object_util.create_saveable_object(
              saveable_factory, attribute.checkpoint_key,
              call_with_mapped_captures)
        else:
          maybe_saveable = saveable_factory
        if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
          saveables = (maybe_saveable,)
        else:
          # Figure out the name-based Saver's name for this variable. If it's
          # already a SaveableObject we'd just get the checkpoint key back, so
          # we leave full_name blank.
          saver_dict = saveable_object_util.op_list_to_dict(
              [maybe_saveable], convert_variable_to_tensor=False)
          full_name, = saver_dict.keys()
          saveables = tuple(saveable_object_util.saveable_objects_for_op(
              op=maybe_saveable, name=attribute.checkpoint_key))
          for saveable in saveables:
            saveable.full_name = full_name
        for saveable in saveables:
          if attribute.checkpoint_key not in saveable.name:
            raise AssertionError(
                ("The object %s produced a SaveableObject with name '%s' for "
                 "attribute '%s'. Expected a name containing '%s'.")
                % (trackable, saveable.name, name, attribute.checkpoint_key))
        if cached_attributes is not None:
          cached_attributes[name] = saveables

      optional_restore = None
      for saveable in saveables:
        if optional_restore is None:
          optional_restore = saveable.optional_restore
        else:
          optional_restore = optional_restore and saveable.optional_restore

        if hasattr(saveable, "full_name"):
          attribute.full_name = saveable.full_name
        if isinstance(saveable, base.PythonStateSaveable):
          if feed_additions is None:
            assert self._saveables_cache is None
            # If we're not caching saveables, then we're either executing
            # eagerly or building a static save/restore (e.g. for a
            # SavedModel). In either case, we should embed the current Python
            # state in the graph rather than relying on a feed dict.
            saveable = saveable.freeze()
          else:
            saveable_feed_dict = saveable.feed_dict_additions()
            for new_feed_key in saveable_feed_dict.keys():
              if new_feed_key in feed_additions:
                raise AssertionError(
                    ("The object %s tried to feed a value for the Tensor %s "
                     "when saving, but another object is already feeding a "
                     "value.") % (trackable, new_feed_key))
            feed_additions.update(saveable_feed_dict)
        named_saveable_objects.append(saveable)
      if optional_restore is None:
        optional_restore = False
      attribute.optional_restore = optional_restore
  return named_saveable_objects, feed_additions
def _add_attributes_to_object_graph_for_saveable_objects(
    self, checkpoint_factory_map, object_graph_proto, node_ids, object_map,
    call_with_mapped_captures):
  """Create SaveableObjects and corresponding SerializedTensor protos."""
  named_saveable_objects = []
  if self._saveables_cache is None:
    # No SaveableObject caching. Either we're executing eagerly, or building a
    # static save which is specialized to the current Python state.
    feed_additions = None
  else:
    # If we are caching SaveableObjects, we need to build up a feed_dict with
    # functions computing volatile Python state to be saved with the
    # checkpoint.
    feed_additions = {}
  for trackable, factory_data_list in checkpoint_factory_map.items():
    object_proto = object_graph_proto.nodes[node_ids[trackable]]
    if self._saveables_cache is not None:
      object_to_save = _get_mapped_trackable(trackable, object_map)
      cached_attributes = self._saveables_cache.setdefault(object_to_save, {})
    else:
      cached_attributes = None

    for factory_data in factory_data_list:
      attribute = object_proto.attributes.add()
      attribute.name = name = factory_data.name
      attribute.checkpoint_key = key = factory_data.checkpoint_key
      saveable_factory = factory_data.factory

      # See if we can skip saving this checkpoint key.
      saveables = cached_attributes.get(name) if cached_attributes else None
      if saveables is not None:
        for saveable in saveables:
          if key not in saveable.name:
            # The checkpoint key for this SaveableObject is different. We
            # need to re-create it.
            saveables = None
            del cached_attributes[name]
            break

      if saveables is None:
        if callable(saveable_factory):
          maybe_saveable = saveable_object_util.create_saveable_object(
              saveable_factory, key, call_with_mapped_captures)
        else:
          maybe_saveable = saveable_factory
        if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
          saveables = (maybe_saveable,)
        else:
          # Figure out the name-based Saver's name for this variable. If it's
          # already a SaveableObject we'd just get the checkpoint key back, so
          # we leave full_name blank.
          saver_dict = saveable_object_util.op_list_to_dict(
              [maybe_saveable], convert_variable_to_tensor=False)
          full_name, = saver_dict.keys()
          saveables = tuple(saveable_object_util.saveable_objects_for_op(
              op=maybe_saveable, name=key))
          for saveable in saveables:
            saveable.full_name = full_name
        for saveable in saveables:
          if key not in saveable.name:
            raise AssertionError(
                f"The object {trackable} produced a SaveableObject with name "
                f"'{saveable.name}' for attribute '{name}'. Expected a name "
                f"containing '{key}'.")
        if cached_attributes is not None:
          cached_attributes[name] = saveables

      for saveable in saveables:
        if hasattr(saveable, "full_name"):
          attribute.full_name = saveable.full_name
        if isinstance(saveable, base.PythonStateSaveable):
          if feed_additions is None:
            assert self._saveables_cache is None
            # If we're not caching saveables, then we're either executing
            # eagerly or building a static save/restore (e.g. for a
            # SavedModel). In either case, we should embed the current Python
            # state in the graph rather than relying on a feed dict.
            saveable = saveable.freeze()
          else:
            saveable_feed_dict = saveable.feed_dict_additions()
            for new_feed_key in saveable_feed_dict.keys():
              if new_feed_key in feed_additions:
                raise AssertionError(
                    f"The object {trackable} tried to feed a value for the "
                    f"Tensor {new_feed_key} when saving, but another object "
                    "is already feeding a value.")
            feed_additions.update(saveable_feed_dict)
        named_saveable_objects.append(saveable)
  return named_saveable_objects, feed_additions
def _add_attributes_to_object_graph_for_saveable_objects(
    checkpoint_factory_map, object_graph_proto, node_ids, object_map,
    call_with_mapped_captures, saveables_cache):
  """Create SaveableObjects and corresponding SerializedTensor protos."""
  named_saveable_objects = []
  if saveables_cache is None:
    # No SaveableObject caching. Either we're executing eagerly, or building a
    # static save which is specialized to the current Python state.
    feed_additions = None
  else:
    # If we are caching SaveableObjects, we need to build up a feed_dict with
    # functions computing volatile Python state to be saved with the
    # checkpoint.
    feed_additions = {}
  for trackable, factory_data_list in checkpoint_factory_map.items():
    object_proto = object_graph_proto.nodes[node_ids[trackable]]
    object_to_save = _get_mapped_trackable(trackable, object_map)
    if saveables_cache is not None:
      cached_attributes = saveables_cache.setdefault(object_to_save, {})
    else:
      cached_attributes = None

    for factory_data in factory_data_list:
      name = factory_data.name
      key = factory_data.checkpoint_key
      saveable_factory = factory_data.factory

      # See if we can skip saving this checkpoint key.
      saveables = cached_attributes.get(name) if cached_attributes else None
      if saveables is not None:
        for saveable in saveables:
          if key not in saveable.name:
            # The checkpoint key for this SaveableObject is different. We
            # need to re-create it.
            saveables = None
            del cached_attributes[name]
            break

      if saveables is None:
        if callable(saveable_factory):
          maybe_saveable = saveable_object_util.create_saveable_object(
              saveable_factory, key, call_with_mapped_captures)
        else:
          maybe_saveable = saveable_factory
        if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
          saveables = (maybe_saveable,)
        else:
          saveables = tuple(saveable_object_util.saveable_objects_for_op(
              op=maybe_saveable, name=key))
        for saveable in saveables:
          if key not in saveable.name:
            raise AssertionError(
                f"The object {trackable} produced a SaveableObject with name "
                f"'{saveable.name}' for attribute '{name}'. Expected a name "
                f"containing '{key}'.")
        if cached_attributes is not None:
          cached_attributes[name] = saveables

      for saveable in saveables:
        if isinstance(saveable, python_state.PythonStateSaveable):
          if feed_additions is None:
            assert saveables_cache is None
            # If we're not caching saveables, then we're either executing
            # eagerly or building a static save/restore (e.g. for a
            # SavedModel). In either case, we should embed the current Python
            # state in the graph rather than relying on a feed dict.
            saveable = saveable.freeze()
          else:
            saveable_feed_dict = saveable.feed_dict_additions()
            for new_feed_key in saveable_feed_dict.keys():
              if new_feed_key in feed_additions:
                raise AssertionError(
                    f"The object {trackable} tried to feed a value for the "
                    f"Tensor {new_feed_key} when saving, but another object "
                    "is already feeding a value.")
            feed_additions.update(saveable_feed_dict)
        named_saveable_objects.append(saveable)

      # Update the object proto.
      # For updated Trackables that override serialize_to_tensors, add an
      # attribute for each tensor that is serialized. For Trackables that
      # have SaveableObjects or a legacy saveable name, add a single
      # attribute to the proto.
      if (isinstance(saveables[0], saveable_object_util.TrackableSaveable) and
          saveable_compat.get_saveable_name(object_to_save) is None):
        for local_name, local_key in (
            saveables[0].get_proto_names_and_checkpoint_keys()):
          object_proto.attributes.add(
              name=local_name,
              checkpoint_key=local_key,
              full_name=_get_full_name(object_to_save))
      else:
        object_proto.attributes.add(
            name=name,
            checkpoint_key=key,
            full_name=_get_full_name(object_to_save))

  return named_saveable_objects, feed_additions
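
# Note: each factory_data entry consumed above is assumed to expose three
# fields: name (the attribute name), checkpoint_key (the key written into the
# checkpoint), and factory (either a SaveableObject or a callable producing
# one). checkpoint_factory_map is assumed to map each trackable to a list of
# such entries, as gathered elsewhere from the object graph.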