Ejemplo n.º 1
0
  def __init__(self, root):
    """Collects the nodes to serialize, starting from `root`.

    The objects found by `util.find_objects(root)` become the initial nodes.
    Each function reported by `self._list_functions` for one of those objects
    is then appended as an extra node unless it is already tracked.

    Args:
      root: The object whose reachable graph is collected.
    """
    found_objects, ids_by_object, slots = util.find_objects(root)
    # `self.nodes` is the very list returned by `find_objects`; it is extended
    # in place below.
    self.nodes = found_objects
    self.node_ids = ids_by_object
    self.slot_variables = slots
    # Maps each node to the dict returned by `_list_functions` for it; function
    # nodes appended below get an empty dict.
    self.functions = util.ObjectIdentityDictionary()

    # Walk a frozen copy of the original nodes so that functions appended
    # during the walk are not themselves visited.
    for owner in tuple(self.nodes):
      attached = self._list_functions(owner)
      self.functions[owner] = attached
      for fn in attached.values():
        already_tracked = fn in self.node_ids
        if not already_tracked:
          # A new node's id is its position in `self.nodes`.
          self.node_ids[fn] = len(self.nodes)
          self.nodes.append(fn)
          # Don't look for functions hanging off of functions: concrete
          # functions sometimes have them, but they are not helpful here.
          self.functions[fn] = {}
        if isinstance(fn, def_function.Function):
          # Called only for its side effects: it fills the cache for functions
          # that have an input_signature but were never called, and triggers
          # first-run effects of concrete-function creation (e.g. variables).
          fn._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
Ejemplo n.º 2
0
def _write_object_graph(root, export_dir, asset_file_def_index):
    """Save a SavedObjectGraph proto for `root`.

    The serialized proto is written to `object_graph.pb` inside
    `constants.EXTRA_ASSETS_DIRECTORY` under `export_dir`.

    Args:
        root: Root object whose reachable objects are serialized.
        export_dir: Export directory; the extra-assets subdirectory is created
            inside it if it does not exist yet.
        asset_file_def_index: Passed through unchanged to
            `_write_object_proto` for every node.
    """
    # SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
    # checkpoint. It will eventually go into the SavedModel.
    proto = saved_object_graph_pb2.SavedObjectGraph()

    checkpointable_objects, node_ids, slot_variables = util.find_objects(root)
    util.fill_object_graph_proto(checkpointable_objects, node_ids,
                                 slot_variables, proto)

    # Separate lookup table for function serialization. Besides each object,
    # it maps a resource variable's `handle` and an asset's
    # `asset_path.handle` to the index of the object that owns them.
    handle_node_ids = util.ObjectIdentityDictionary()
    for node_id, obj in enumerate(checkpointable_objects):
        handle_node_ids[obj] = node_id
        if resource_variable_ops.is_resource_variable(obj):
            handle_node_ids[obj.handle] = node_id
        elif isinstance(obj, tracking.TrackableAsset):
            handle_node_ids[obj.asset_path.handle] = node_id

    for obj, obj_proto in zip(checkpointable_objects, proto.nodes):
        _write_object_proto(obj, obj_proto, asset_file_def_index)

    function_serialization.add_polymorphic_functions_to_object_graph_proto(
        checkpointable_objects, proto, handle_node_ids)

    extra_asset_dir = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
    file_io.recursive_create_dir(extra_asset_dir)
    object_graph_filename = os.path.join(extra_asset_dir,
                                         compat.as_bytes("object_graph.pb"))
    file_io.write_string_to_file(object_graph_filename,
                                 proto.SerializeToString())
Ejemplo n.º 3
0
    def __init__(self, root):
        """Gathers every node reachable from `root`, plus attached functions.

        Nodes start as the objects returned by `util.find_objects(root)`.
        Functions listed by `self._list_functions` for those objects are
        appended as additional nodes when they are not already tracked.

        Args:
            root: The object whose reachable graph is collected.
        """
        objects, ids, slots = util.find_objects(root)
        # Kept as the same list object; functions are appended to it below.
        self.nodes = objects
        self.node_ids = ids
        self.slot_variables = slots
        # Node -> dict from `_list_functions` (empty for function nodes).
        self.functions = util.ObjectIdentityDictionary()

        # Iterate over a snapshot so newly appended function nodes are not
        # visited by this loop.
        for holder in tuple(self.nodes):
            listed = self._list_functions(holder)
            self.functions[holder] = listed
            for candidate in listed.values():
                if candidate in self.node_ids:
                    pass
                else:
                    # A new node's id equals its index in `self.nodes`.
                    self.node_ids[candidate] = len(self.nodes)
                    self.nodes.append(candidate)
                    # Functions are not searched for further functions; this
                    # sometimes finds some on concrete functions, but it is
                    # not helpful.
                    self.functions[candidate] = {}
                if isinstance(candidate, def_function.Function):
                    # Side effects only: populate the cache for functions with
                    # an input_signature that were never called, and force
                    # first-run effects such as variable creation.
                    candidate._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
Ejemplo n.º 4
0
  def __init__(self, root):
    """Collects the nodes reachable from `root`, plus polymorphic functions.

    The objects found by `util.find_objects(root)` become the initial nodes.
    Every polymorphic function listed by `self._list_polymorphic_functions`
    for one of those objects is appended as an extra node unless it is
    already tracked.

    Args:
      root: The object whose reachable graph is collected.
    """
    checkpointable_objects, node_ids, slot_variables = util.find_objects(root)
    # `self.nodes` is the list returned by `find_objects`, extended in place.
    self.nodes = checkpointable_objects
    self.node_ids = node_ids
    self.slot_variables = slot_variables
    # Maps each node to the dict returned by `_list_polymorphic_functions`;
    # function nodes appended below get an empty dict.
    self.polymorphic_functions = util.ObjectIdentityDictionary()

    # Also add polymorphic functions as nodes. Iterate over a snapshot:
    # appending to `self.nodes` while looping over it would make the loop
    # visit the newly added functions and search them for further functions.
    nodes_without_functions = list(self.nodes)
    for obj in nodes_without_functions:
      self.polymorphic_functions[obj] = self._list_polymorphic_functions(obj)
      for function in self.polymorphic_functions[obj].values():
        if function not in self.node_ids:
          # A new node's id is its position in `self.nodes`.
          self.node_ids[function] = len(self.nodes)
          self.nodes.append(function)
          # Every node still gets an entry, as it did when the loop revisited
          # appended functions, but functions are not searched for further
          # functions.
          self.polymorphic_functions[function] = {}
        # Force listing the concrete functions for the side effects:
        #  - populate the cache for polymorphic functions that have an
        #  input_signature and have not been called.
        #  - force side effects of creation of concrete functions, e.g. create
        #  variables on first run.
        function._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
Ejemplo n.º 5
0
def _write_object_graph(root, export_dir, asset_file_def_index):
  """Writes `object_graph.pb`, a SavedObjectGraph proto describing `root`.

  The SavedObjectGraph plays the role the CheckpointableObjectGraph proto has
  in checkpoints and is meant to eventually become part of the SavedModel. It
  is stored under `constants.EXTRA_ASSETS_DIRECTORY` inside `export_dir`.

  Args:
    root: The root object whose reachable objects are serialized.
    export_dir: The export directory; the extra-assets subdirectory is created
      inside it when missing.
    asset_file_def_index: Forwarded unchanged to `_write_object_proto` for each
      node.
  """
  graph_proto = saved_object_graph_pb2.SavedObjectGraph()
  objects, ids, slots = util.find_objects(root)
  util.fill_object_graph_proto(objects, ids, slots, graph_proto)

  for obj, node_proto in zip(objects, graph_proto.nodes):
    _write_object_proto(obj, node_proto, asset_file_def_index)

  assets_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(assets_dir)
  graph_path = os.path.join(assets_dir, compat.as_bytes("object_graph.pb"))
  payload = graph_proto.SerializeToString()
  file_io.write_string_to_file(graph_path, payload)
Ejemplo n.º 6
0
def _write_object_graph(root, export_dir, asset_filename_map):
  """Writes `object_graph.pb`, a SavedObjectGraph proto describing `root`.

  The SavedObjectGraph plays the role the CheckpointableObjectGraph proto has
  in checkpoints and is meant to eventually become part of the SavedModel. It
  is stored under `constants.EXTRA_ASSETS_DIRECTORY` inside `export_dir`.

  Args:
    root: The root object whose reachable objects are serialized.
    export_dir: The export directory; the extra-assets subdirectory is created
      inside it when missing.
    asset_filename_map: Mapping from relative asset filename to original
      filename. It is inverted before being passed to `_write_object_proto`.
  """
  graph_proto = saved_object_graph_pb2.SavedObjectGraph()
  objects, ids, slots = util.find_objects(root)
  util.fill_object_graph_proto(objects, ids, slots, graph_proto)

  # Invert the asset map so it is keyed by original filename (later entries
  # win if two relative names share an original filename).
  original_to_relative = {}
  for relative_name, original_name in asset_filename_map.items():
    original_to_relative[original_name] = relative_name

  for obj, node_proto in zip(objects, graph_proto.nodes):
    _write_object_proto(obj, node_proto, original_to_relative)

  assets_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(assets_dir)
  graph_path = os.path.join(assets_dir, compat.as_bytes("object_graph.pb"))
  payload = graph_proto.SerializeToString()
  file_io.write_string_to_file(graph_path, payload)