示例#1
0
 def prepare_config(config):
     """Return the ConfigProto to use for a test session.

     Args:
       config: A ConfigProto supplied by the caller, or None.

     Returns:
       - If ``config`` is None: a new ConfigProto with
         ``allow_soft_placement = not force_gpu`` and a per-process GPU memory
         fraction of 0.3.
       - If ``force_gpu`` is set and ``config`` allows soft placement: a
         *copy* of ``config`` with soft placement disabled; the caller's
         proto is left untouched.
       - Otherwise: ``config`` unchanged.

     Note: ``force_gpu`` is not defined in this function; presumably it is a
     flag captured from an enclosing scope -- TODO confirm.
     """
     if config is None:
         config = config_pb2.ConfigProto()
         config.allow_soft_placement = not force_gpu
         config.gpu_options.per_process_gpu_memory_fraction = 0.3
     elif force_gpu and config.allow_soft_placement:
         # Message.CopyFrom copies in place and returns None, so the copy must
         # be created first; chaining the call would leave `config` as None.
         config_copy = config_pb2.ConfigProto()
         config_copy.CopyFrom(config)
         config = config_copy
         config.allow_soft_placement = False
     return config
 def testManyCPUs(self):
     """A session configured with two CPU devices can evaluate a constant."""
     # TODO(keveman): Implement ListDevices and test for the number of
     # devices returned by ListDevices.
     two_cpu_config = config_pb2.ConfigProto(device_count={'CPU': 2})
     with session.Session(config=two_cpu_config):
         w1 = constant_op.constant(10.0, name='W1')
         evaluated = w1.eval()
         self.assertAllEqual(evaluated, 10.0)
示例#3
0
 def testPerSessionThreads(self):
     """A session using per-session threads can evaluate a constant."""
     # TODO(keveman): Implement ListDevices and test for the number of
     # devices returned by ListDevices.
     per_session_config = config_pb2.ConfigProto(use_per_session_threads=True)
     with session.Session(config=per_session_config):
         w1 = constant_op.constant(10.0, name='W1')
         evaluated = w1.eval()
         self.assertAllEqual(evaluated, 10.0)
示例#4
0
    def doBasicsOneExportPath(self,
                              export_path,
                              clear_devices=False,
                              global_step=GLOBAL_STEP):
        """Export a small two-device graph and verify the exported artifacts.

        Builds a graph with variables v0/v1 pinned to /cpu:0 and /cpu:1 plus
        an unsaved variable v2 that is set by an init op. The graph is exported
        with ``exporter.Exporter``. The exported meta graph is then re-imported
        into a fresh default graph, and the test checks the stored graph def,
        init op, signatures, assets and restored variable values.

        Args:
          export_path: Base directory passed to ``Exporter.export``. The
            checked artifacts are read from the sub-path
            ``exporter.VERSION_FORMAT_SPECIFIER % global_step`` inside it.
          clear_devices: Forwarded to ``Exporter.init``. When True, the
            expected graph def has every node's device blanked before it is
            compared with the exported one.
          global_step: Initial value of the ``global_step`` variable. It also
            selects the version sub-path that is checked.
        """
        # Build a graph with 2 parameter nodes on different devices.
        tf.reset_default_graph()
        with tf.Session(target="",
                        config=config_pb2.ConfigProto(
                            device_count={"CPU": 2})) as sess:
            # v2 is an unsaved variable derived from v0 and v1.  It is used to
            # exercise the ability to run an init op when restoring a graph.
            with sess.graph.device("/cpu:0"):
                v0 = tf.Variable(10, name="v0")
            with sess.graph.device("/cpu:1"):
                v1 = tf.Variable(20, name="v1")
            v2 = tf.Variable(1, name="v2", trainable=False, collections=[])
            assign_v2 = tf.assign(v2, tf.add(v0, v1))
            init_op = tf.group(assign_v2, name="init_op")

            tf.add_to_collection("v", v0)
            tf.add_to_collection("v", v1)
            tf.add_to_collection("v", v2)

            global_step_tensor = tf.Variable(global_step, name="global_step")
            named_tensor_bindings = {
                "logical_input_A": v0,
                "logical_input_B": v1
            }
            signatures = {
                "foo":
                exporter.regression_signature(input_tensor=v0,
                                              output_tensor=v1),
                "generic":
                exporter.generic_signature(named_tensor_bindings)
            }

            def write_asset(path):
                # Passed to Exporter.init as assets_callback. The file written
                # here is read back below under exporter.ASSETS_DIRECTORY.
                file_path = os.path.join(path, "file.txt")
                with gfile.FastGFile(file_path, "w") as f:
                    f.write("your data here")

            asset_file = tf.Variable("hello42.txt", name="filename42")
            assets = {("hello42.txt", asset_file)}

            tf.initialize_all_variables().run()

            # Run an export.
            save = tf.train.Saver({"v0": v0, "v1": v1},
                                  restore_sequentially=True,
                                  sharded=True)
            export = exporter.Exporter(save)
            export.init(
                sess.graph.as_graph_def(),
                init_op=init_op,
                clear_devices=clear_devices,
                default_graph_signature=exporter.classification_signature(
                    input_tensor=v0),
                named_graph_signatures=signatures,
                assets=assets,
                assets_callback=write_asset)
            export.export(export_path,
                          global_step_tensor,
                          sess,
                          exports_to_keep=gc.largest_export_versions(2))

        # Restore graph.
        # Snapshot the export-time graph before resetting; it is the expected
        # value for the graph def stored in the export.
        compare_def = tf.get_default_graph().as_graph_def()
        tf.reset_default_graph()
        with tf.Session(target="",
                        config=config_pb2.ConfigProto(
                            device_count={"CPU": 2})) as sess:
            save = tf.train.import_meta_graph(
                os.path.join(export_path,
                             exporter.VERSION_FORMAT_SPECIFIER % global_step,
                             exporter.META_GRAPH_DEF_FILENAME))
            meta_graph_def = save.export_meta_graph()
            collection_def = meta_graph_def.collection_def

            # Validate custom graph_def.
            graph_def_any = collection_def[exporter.GRAPH_KEY].any_list.value
            self.assertEqual(len(graph_def_any), 1)
            graph_def = tf.GraphDef()
            graph_def_any[0].Unpack(graph_def)
            if clear_devices:
                # Mirror clear_devices on the expected side so the comparison
                # ignores device placement.
                for node in compare_def.node:
                    node.device = ""
            self.assertProtoEquals(compare_def, graph_def)

            # Validate init_op.
            init_ops = collection_def[exporter.INIT_OP_KEY].node_list.value
            self.assertEqual(len(init_ops), 1)
            self.assertEqual(init_ops[0], "init_op")

            # Validate signatures.
            signatures_any = collection_def[
                exporter.SIGNATURES_KEY].any_list.value
            self.assertEqual(len(signatures_any), 1)
            signatures = manifest_pb2.Signatures()
            signatures_any[0].Unpack(signatures)
            default_signature = signatures.default_signature
            self.assertEqual(
                default_signature.classification_signature.input.tensor_name,
                "v0:0")
            bindings = signatures.named_signatures[
                "generic"].generic_signature.map
            self.assertEqual(bindings["logical_input_A"].tensor_name, "v0:0")
            self.assertEqual(bindings["logical_input_B"].tensor_name, "v1:0")
            read_foo_signature = (
                signatures.named_signatures["foo"].regression_signature)
            self.assertEqual(read_foo_signature.input.tensor_name, "v0:0")
            self.assertEqual(read_foo_signature.output.tensor_name, "v1:0")

            # Validate the assets.
            assets_any = collection_def[exporter.ASSETS_KEY].any_list.value
            self.assertEqual(len(assets_any), 1)
            asset = manifest_pb2.AssetFile()
            assets_any[0].Unpack(asset)
            assets_path = os.path.join(
                export_path, exporter.VERSION_FORMAT_SPECIFIER % global_step,
                exporter.ASSETS_DIRECTORY, "file.txt")
            # Close the handle deterministically instead of leaking it.
            with gfile.GFile(assets_path) as asset_fp:
                asset_contents = asset_fp.read()
            self.assertEqual(asset_contents, "your data here")
            self.assertEqual("hello42.txt", asset.filename)
            self.assertEqual("filename42:0", asset.tensor_binding.tensor_name)

            # Validate graph restoration.
            save.restore(
                sess,
                os.path.join(export_path,
                             exporter.VERSION_FORMAT_SPECIFIER % global_step,
                             exporter.VARIABLES_DIRECTORY))
            self.assertEqual(10, tf.get_collection("v")[0].eval())
            self.assertEqual(20, tf.get_collection("v")[1].eval())
            # v2 was not saved; running the restored init op recomputes it as
            # v0 + v1.
            tf.get_collection(exporter.INIT_OP_KEY)[0].run()
            self.assertEqual(30, tf.get_collection("v")[2].eval())