示例#1
0
def objects_ids_and_slot_variables_and_paths(graph_view):
    """Traverse the object graph reachable from `graph_view` and index it.

    Looks for `Trackable` objects which are dependencies of the root of
    `graph_view`. Slot variables are collected by `_serialize_slot_variables`
    and are included only if the variable they are slotting for and the
    optimizer are both dependencies of that root (i.e. if they would be saved
    with a checkpoint).

    Args:
      graph_view: A GraphView object.

    Returns:
      A 5-tuple `(trackable_objects, node_paths, node_ids, slot_variables,
      object_names)` where:
        trackable_objects: the objects in the order produced by
          `graph_view.breadth_first_traversal()`.
        node_paths: the path from the root for each object, exactly as
          returned by the traversal.
        node_ids: an `ObjectIdentityDictionary` mapping each object to its
          index in `trackable_objects`.
        slot_variables: the value returned by `_serialize_slot_variables`.
        object_names: an `ObjectIdentityDictionary` mapping each object to
          the string form of its path
          (`trackable_utils.object_path_to_string`).
    """
    ordered_trackables, paths_by_trackable = (
        graph_view.breadth_first_traversal())

    # Human-readable name for every object, derived from its root path.
    names_by_trackable = object_identity.ObjectIdentityDictionary()
    for trackable, root_path in paths_by_trackable.items():
        names_by_trackable[trackable] = (
            trackable_utils.object_path_to_string(root_path))

    # An object's node id is simply its position in traversal order.
    ids_by_trackable = object_identity.ObjectIdentityDictionary()
    for position, trackable in enumerate(ordered_trackables):
        ids_by_trackable[trackable] = position

    serialized_slots = _serialize_slot_variables(
        trackable_objects=ordered_trackables,
        node_ids=ids_by_trackable,
        object_names=names_by_trackable)

    return (ordered_trackables, paths_by_trackable, ids_by_trackable,
            serialized_slots, names_by_trackable)
示例#2
0
def serialize_gathered_objects(graph_view,
                               object_map=None,
                               call_with_mapped_captures=None,
                               saveables_cache=None):
    """Create SaveableObjects and protos for gathered objects.

    Args:
      graph_view: A GraphView object whose reachable objects are serialized.
      object_map: Optional; forwarded unchanged to
        `_add_attributes_to_object_graph`.
      call_with_mapped_captures: Optional; forwarded unchanged to
        `_add_attributes_to_object_graph`.
      saveables_cache: Optional; forwarded unchanged to
        `_add_attributes_to_object_graph`.

    Returns:
      A tuple `(named_saveable_objects, object_graph_proto, feed_additions,
      registered_savers)`. `object_graph_proto` is built by
      `_fill_object_graph_proto` and then annotated in place by
      `_add_checkpoint_values_check`. The other three elements are returned
      exactly as produced by `_add_attributes_to_object_graph`.
    """
    # Reuse the shared traversal/indexing helper rather than duplicating its
    # logic here; the per-object root paths are not needed by this function.
    trackable_objects, _, node_ids, slot_variables, object_names = (
        objects_ids_and_slot_variables_and_paths(graph_view))
    object_graph_proto = _fill_object_graph_proto(
        graph_view=graph_view,
        trackable_objects=trackable_objects,
        node_ids=node_ids,
        slot_variables=slot_variables)
    named_saveable_objects, feed_additions, registered_savers = (
        _add_attributes_to_object_graph(
            trackable_objects=trackable_objects,
            object_graph_proto=object_graph_proto,
            node_ids=node_ids,
            object_names=object_names,
            object_map=object_map,
            call_with_mapped_captures=call_with_mapped_captures,
            saveables_cache=saveables_cache))
    # Gather all trackables that have checkpoint values or descendants with
    # checkpoint values, and add that info to the proto.
    _add_checkpoint_values_check(trackable_objects, object_graph_proto)
    return (named_saveable_objects, object_graph_proto, feed_additions,
            registered_savers)