Example #1
0
    def testCreationOfAssetsKeyCollectionIsDeterministic(self):
        """Checks that the ASSETS_KEY collection is ordered by tensor name.

        The asset constants are deliberately created in reverse-sorted order.
        If they were created as file0 ... file9, graph (creation) order would
        already coincide with sorted order. The test could then pass even if
        `_make_assets_key_collection` did not sort its output. The final
        assertion also compares against the full expected list, so an empty
        or incomplete collection fails instead of passing trivially.
        """
        tmp_asset_dir = os.path.join(self.get_temp_dir(), "assets")
        tf_v1.gfile.MakeDirs(tmp_asset_dir)
        filenames = [
            os.path.join(tmp_asset_dir, "file%d.txt" % n) for n in range(10)
        ]
        for filename in filenames:
            _write_string_to_file(filename, "I am file %s" % filename)

        with tf.Graph().as_default() as graph:
            # Reverse creation order so it differs from sorted order.
            assets = [
                tf.constant(f, name=os.path.basename(f))
                for f in reversed(filenames)
            ]
            for asset in assets:
                graph.add_to_collection(tf_v1.GraphKeys.ASSET_FILEPATHS, asset)
            saved_model_lib.add_signature("default", {},
                                          {"default": assets[0]})

        handler = saved_model_lib.SavedModelHandler()
        handler.add_graph_copy(graph)
        # Operate on a deep copy so handler._proto is left untouched.
        saved_model_proto = copy.deepcopy(handler._proto)
        export_dir = os.path.join(self.get_temp_dir(), "assets_key_test")
        saved_model_lib._make_assets_key_collection(saved_model_proto,
                                                    export_dir)

        meta_graph = list(saved_model_proto.meta_graphs)[0]
        asset_tensor_names = []
        for asset_any_proto in meta_graph.collection_def[
                tf_v1.saved_model.constants.ASSETS_KEY].any_list.value:
            asset_proto = meta_graph_pb2.AssetFileDef()
            asset_any_proto.Unpack(asset_proto)
            asset_tensor_names.append(asset_proto.tensor_info.name)

        # Every asset must appear exactly once, in sorted tensor-name order.
        expected_tensor_names = sorted(asset.name for asset in assets)
        self.assertEqual(asset_tensor_names, expected_tensor_names)
Example #2
0
  def testCreationOfAssetsKeyCollectionIsDeterministic(self):
    """Checks that the ASSETS_KEY collection is ordered by tensor name.

    The asset constants are deliberately created in reverse-sorted order.
    If they were created as file0 ... file9, graph (creation) order would
    already coincide with sorted order. The test could then pass even if
    `_make_assets_key_collection` did not sort its output. The final
    assertion also compares against the full expected list, so an empty or
    incomplete collection fails instead of passing trivially.
    """
    tmp_asset_dir = os.path.join(self.get_temp_dir(), "assets")
    tf.gfile.MakeDirs(tmp_asset_dir)
    filenames = [
        os.path.join(tmp_asset_dir, "file%d.txt" % n) for n in range(10)
    ]
    for filename in filenames:
      _write_string_to_file(filename, "I am file %s" % filename)

    with tf.Graph().as_default() as graph:
      # Reverse creation order so it differs from sorted order.
      assets = [
          tf.constant(f, name=os.path.basename(f))
          for f in reversed(filenames)
      ]
      for asset in assets:
        graph.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset)
      saved_model_lib.add_signature("default", {}, {"default": assets[0]})

    handler = saved_model_lib.SavedModelHandler()
    handler.add_graph_copy(graph)
    # Operate on a deep copy so handler._proto is left untouched.
    saved_model_proto = copy.deepcopy(handler._proto)
    export_dir = os.path.join(self.get_temp_dir(), "assets_key_test")
    saved_model_lib._make_assets_key_collection(saved_model_proto, export_dir)

    meta_graph = list(saved_model_proto.meta_graphs)[0]
    asset_tensor_names = []
    for asset_any_proto in meta_graph.collection_def[
        tf.saved_model.constants.ASSETS_KEY].any_list.value:
      asset_proto = meta_graph_pb2.AssetFileDef()
      asset_any_proto.Unpack(asset_proto)
      asset_tensor_names.append(asset_proto.tensor_info.name)

    # Every asset must appear exactly once, in sorted tensor-name order.
    expected_tensor_names = sorted(asset.name for asset in assets)
    self.assertEqual(asset_tensor_names, expected_tensor_names)