Beispiel #1
0
def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
    """Serialize the object graph of `saveable_view` into the export directory.

    Builds a SavedObjectGraph proto (similar to the CheckpointableObjectGraph
    proto written in checkpoints) and writes it to
    `<export_dir>/<EXTRA_ASSETS_DIRECTORY>/object_graph.pb`.

    Args:
      saveable_view: object exposing `fill_object_graph_proto(proto)` and an
        ordered `nodes` sequence; a node's position in `nodes` is its id.
      export_dir: path of the SavedModel export directory.
      asset_file_def_index: forwarded unchanged to `_write_object_proto`.
    """
    proto = saved_object_graph_pb2.SavedObjectGraph()
    saveable_view.fill_object_graph_proto(proto)

    # Map every node to its id. Resource variables also map their `handle`,
    # and TrackableAssets their `asset_path.handle`, to the owning node's id.
    # ObjectIdentityDictionary presumably keys on object identity (per its
    # name) -- confirm in `util`.
    ids_by_object = util.ObjectIdentityDictionary()
    for node_id, node in enumerate(saveable_view.nodes):
        ids_by_object[node] = node_id
        if resource_variable_ops.is_resource_variable(node):
            ids_by_object[node.handle] = node_id
        elif isinstance(node, tracking.TrackableAsset):
            ids_by_object[node.asset_path.handle] = node_id

    node_pairs = zip(saveable_view.nodes, proto.nodes)
    for node, node_proto in node_pairs:
        _write_object_proto(
            node, node_proto, asset_file_def_index, ids_by_object)

    assets_dir = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
    file_io.recursive_create_dir(assets_dir)
    target_path = os.path.join(assets_dir, compat.as_bytes("object_graph.pb"))
    serialized_graph = proto.SerializeToString()
    file_io.write_string_to_file(target_path, serialized_graph)
Beispiel #2
0
def _write_object_graph(root, export_dir, asset_file_def_index):
    """Save a SavedObjectGraph proto for the objects reachable from `root`.

    SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
    checkpoint. It will eventually go into the SavedModel. The proto is written
    to `<export_dir>/<EXTRA_ASSETS_DIRECTORY>/object_graph.pb`.

    Args:
      root: root object whose reachable objects are serialized (discovered via
        `util.find_objects`).
      export_dir: path of the SavedModel export directory.
      asset_file_def_index: forwarded unchanged to `_write_object_proto`.
    """
    proto = saved_object_graph_pb2.SavedObjectGraph()

    checkpointable_objects, graph_node_ids, slot_variables = util.find_objects(
        root)
    util.fill_object_graph_proto(checkpointable_objects, graph_node_ids,
                                 slot_variables, proto)

    # Build a separate id map (not reusing `graph_node_ids`). Besides each
    # object, it also maps a resource variable's `handle` and a
    # TrackableAsset's `asset_path.handle` to the owning object's id. It is
    # consumed by the polymorphic-function serializer below.
    node_ids = util.ObjectIdentityDictionary()
    for i, obj in enumerate(checkpointable_objects):
        node_ids[obj] = i
        if resource_variable_ops.is_resource_variable(obj):
            node_ids[obj.handle] = i
        elif isinstance(obj, tracking.TrackableAsset):
            node_ids[obj.asset_path.handle] = i

    for obj, obj_proto in zip(checkpointable_objects, proto.nodes):
        _write_object_proto(obj, obj_proto, asset_file_def_index)

    function_serialization.add_polymorphic_functions_to_object_graph_proto(
        checkpointable_objects, proto, node_ids)

    extra_asset_dir = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
    file_io.recursive_create_dir(extra_asset_dir)
    object_graph_filename = os.path.join(extra_asset_dir,
                                         compat.as_bytes("object_graph.pb"))
    file_io.write_string_to_file(object_graph_filename,
                                 proto.SerializeToString())
Beispiel #3
0
def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
    """Serialize `saveable_view`'s object graph and concrete functions to disk.

    Builds a SavedObjectGraph proto (similar to the CheckpointableObjectGraph
    proto written in checkpoints), adds every serializable concrete function,
    and writes the result to
    `<export_dir>/<EXTRA_ASSETS_DIRECTORY>/object_graph.pb`.

    Args:
      saveable_view: object exposing `fill_object_graph_proto(proto)`, `nodes`,
        `concrete_functions` and `captured_tensor_node_ids`.
      export_dir: path of the SavedModel export directory.
      asset_file_def_index: forwarded unchanged to `_write_object_proto`.
    """
    proto = saved_object_graph_pb2.SavedObjectGraph()
    saveable_view.fill_object_graph_proto(proto)

    # Functions the serializer returns None for are left out of the proto.
    structure_coder = nested_structure_coder.StructureCoder()
    for fn in saveable_view.concrete_functions:
        fn_proto = function_serialization.serialize_concrete_function(
            fn, saveable_view.captured_tensor_node_ids, structure_coder)
        if fn_proto is None:
            continue
        proto.concrete_functions[fn.name].CopyFrom(fn_proto)

    node_pairs = zip(saveable_view.nodes, proto.nodes)
    for node, node_proto in node_pairs:
        _write_object_proto(node, node_proto, asset_file_def_index)

    assets_dir = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
    file_io.recursive_create_dir(assets_dir)
    target_path = os.path.join(assets_dir, compat.as_bytes("object_graph.pb"))
    serialized_graph = proto.SerializeToString()
    file_io.write_string_to_file(target_path, serialized_graph)
Beispiel #4
0
def _write_object_graph(obj, export_dir):
    """Write the object graph rooted at `obj` as `object_graph.pb`.

    The graph is built by `util.make_object_graph_without_attributes` into a
    SavedObjectGraph proto (similar to the CheckpointableObjectGraph proto in
    checkpoints). It is written to
    `<export_dir>/<EXTRA_ASSETS_DIRECTORY>/object_graph.pb`.

    Args:
      obj: root object of the graph to serialize.
      export_dir: path of the SavedModel export directory.
    """
    graph_proto = util.make_object_graph_without_attributes(
        obj, proto=saved_object_graph_pb2.SavedObjectGraph())

    assets_dir = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
    file_io.recursive_create_dir(assets_dir)

    target_path = os.path.join(assets_dir, compat.as_bytes("object_graph.pb"))
    serialized_graph = graph_proto.SerializeToString()
    file_io.write_string_to_file(target_path, serialized_graph)
Beispiel #5
0
def _write_object_graph(root, export_dir, asset_file_def_index):
  """Serialize the objects reachable from `root` as `object_graph.pb`.

  Builds a SavedObjectGraph proto (similar to the CheckpointableObjectGraph
  proto written in checkpoints) and writes it to
  `<export_dir>/<EXTRA_ASSETS_DIRECTORY>/object_graph.pb`.

  Args:
    root: root object; reachable objects are discovered with
      `util.find_objects`.
    export_dir: path of the SavedModel export directory.
    asset_file_def_index: forwarded unchanged to `_write_object_proto`.
  """
  graph_proto = saved_object_graph_pb2.SavedObjectGraph()

  found = util.find_objects(root)
  objects, ids, slots = found
  util.fill_object_graph_proto(objects, ids, slots, graph_proto)

  for node, node_proto in zip(objects, graph_proto.nodes):
    _write_object_proto(node, node_proto, asset_file_def_index)

  assets_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(assets_dir)
  target_path = os.path.join(assets_dir, compat.as_bytes("object_graph.pb"))
  serialized_graph = graph_proto.SerializeToString()
  file_io.write_string_to_file(target_path, serialized_graph)
Beispiel #6
0
def load(export_dir):
  """Load a SavedModel from `export_dir`.

  Only SavedModels that contain an object graph file at
  `<export_dir>/<EXTRA_ASSETS_DIRECTORY>/object_graph.pb` are supported.

  Args:
    export_dir: path to the SavedModel directory.

  Returns:
    The root object recreated from the serialized SavedObjectGraph.

  Raises:
    NotImplementedError: if `export_dir` has no object graph file.
  """
  object_graph_filename = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY),
      compat.as_bytes("object_graph.pb"))
  if not file_io.file_exists(object_graph_filename):
    raise NotImplementedError(
        "Currently only SavedModels exported with `tf.saved_model.save` may be "
        "imported. Other SavedModels may eventually be supported via load().")
  # Close the file deterministically rather than relying on garbage
  # collection to release the handle.
  with file_io.FileIO(object_graph_filename, "rb") as f:
    object_graph_string = f.read()
  object_graph_proto = saved_object_graph_pb2.SavedObjectGraph()
  object_graph_proto.ParseFromString(object_graph_string)
  root = _recreate_object_graph(object_graph_proto)
  # TODO(allenl): load functions from the SavedModel into the eager context
  return root