Example 1
def export_module_spec(spec, export_path):
  """Export a module built from `spec` with freshly initialized variables."""
  with tf_v1.Graph().as_default():
    module = hub.Module(spec)
    with tf_v1.Session() as sess:
      # Random initialization only; no checkpoint is restored.
      sess.run(tf_v1.initializers.global_variables())
      module.export(export_path, sess)
Example 2
    def testMatchingTensorInfoProtoMaps(self):
        """Tensor-info maps match only when dtype/shape/sparsity all agree."""
        with tf_v1.Graph().as_default():
            sig1 = _make_signature(
                {"x": tf_v1.placeholder(tf.int32, [2])},
                {"x": tf_v1.placeholder(tf.int32, [2])})

            sig2 = _make_signature(
                {"x": tf_v1.placeholder(tf.int32, [2])},
                {"x": tf_v1.sparse_placeholder(tf.int64, [2])})
            # Inputs have the same dtype and shape -> they match.
            self.assertTrue(
                tensor_info.tensor_info_proto_maps_match(
                    sig1.inputs, sig2.inputs))
            # Dense int32 vs. sparse int64 outputs -> they do not match.
            self.assertFalse(
                tensor_info.tensor_info_proto_maps_match(
                    sig1.outputs, sig2.outputs))

            sig3 = _make_signature(
                {"x": tf_v1.placeholder(tf.int32, [None])},
                {"x": tf_v1.placeholder(tf.int32, [2])})
            # Unknown vs. fixed leading dimension -> inputs do not match.
            self.assertFalse(
                tensor_info.tensor_info_proto_maps_match(
                    sig1.inputs, sig3.inputs))
            self.assertTrue(
                tensor_info.tensor_info_proto_maps_match(
                    sig1.outputs, sig3.outputs))
Example 3
def prune_unused_nodes(meta_graph, signature_def):
  """Prunes unused ops from `meta_graph` given a signature def.

  This function does a backward graph traversal from all outputs defined
  in the signature_def to collect the set of used nodes; any node outside
  that set is discarded. This is useful for graphs which are executed
  eagerly or on TPUs.

  Args:
    meta_graph: The input/output MetaGraphDef for which we wish to prune.
    signature_def: A SignatureDef which specifies the outputs from which we
      wish to start graph traversal.
  """
  # Instantiate a temporary empty graph so that we have access to Graph API
  # and import the meta_graph.
  graph = tf_v1.Graph()
  with graph.as_default():
    tf_v1.train.import_meta_graph(meta_graph, input_map={}, import_scope="")
    # Traverse from all outputs and mark all nodes reachable backwards.
    used_node_names = set()
    for tensor_def in signature_def.outputs.values():
      output_tensor = graph.get_tensor_by_name(tensor_def.name)
      mark_backward(output_tensor, used_node_names)
    # Keep only used nodes. Make a special exception for VarHandleOp:
    # removing VarHandleOps will make the graph not importable as they often
    # leave nodes hanging. These will be disconnected through the feedmap
    # when importing the metagraph.
    kept_nodes = [
        node for node in meta_graph.graph_def.node
        if node.name in used_node_names or node.op == "VarHandleOp"
    ]
    del meta_graph.graph_def.node[:]
    meta_graph.graph_def.node.extend(kept_nodes)
  del graph
Example 4
    def testParsingTensorInfoProtoMaps(self):
        """Parsed tensor-info entries expose shape, dtype and sparsity."""
        with tf_v1.Graph().as_default():
            sig = _make_signature(
                {
                    "x": tf_v1.placeholder(tf.string, [2]),
                }, {
                    "y": tf_v1.placeholder(tf.int32, [2]),
                    "z": tf_v1.sparse_placeholder(tf.float32, [2, 10]),
                })

            # Dense input keeps its static shape and dtype.
            parsed_inputs = tensor_info.parse_tensor_info_map(sig.inputs)
            self.assertEqual(set(parsed_inputs.keys()), {"x"})
            self.assertEqual(parsed_inputs["x"].get_shape(), [2])
            self.assertEqual(parsed_inputs["x"].dtype, tf.string)
            self.assertFalse(parsed_inputs["x"].is_sparse)

            parsed_outputs = tensor_info.parse_tensor_info_map(sig.outputs)
            self.assertEqual(set(parsed_outputs.keys()), {"y", "z"})

            # Dense output.
            self.assertEqual(parsed_outputs["y"].get_shape(), [2])
            self.assertEqual(parsed_outputs["y"].dtype, tf.int32)
            self.assertFalse(parsed_outputs["y"].is_sparse)

            # Sparse output is flagged as such.
            self.assertEqual(parsed_outputs["z"].get_shape(), [2, 10])
            self.assertEqual(parsed_outputs["z"].dtype, tf.float32)
            self.assertTrue(parsed_outputs["z"].is_sparse)
Example 5
    def testBuildOutputMap(self):
        """build_output_map resolves dense and sparse signature outputs."""
        with tf_v1.Graph().as_default():
            x = tf_v1.placeholder(tf.int32, [2])
            y = tf_v1.sparse_placeholder(tf.string, [None])
            sig = _make_signature({}, {"x": x, "y": y})

            def _lookup(tensor_name):
                graph = tf_v1.get_default_graph()
                return graph.get_tensor_by_name(tensor_name)

            output_map = tensor_info.build_output_map(sig.outputs, _lookup)

            self.assertEqual(len(output_map), 2)
            # Dense output resolves to the original tensor.
            self.assertEqual(output_map["x"], x)
            # Sparse output resolves component-wise.
            self.assertEqual(output_map["y"].indices, y.indices)
            self.assertEqual(output_map["y"].values, y.values)
            self.assertEqual(output_map["y"].dense_shape, y.dense_shape)
Example 6
    def testBuildInputMap(self):
        """build_input_map expands sparse inputs into their components."""
        with tf_v1.Graph().as_default():
            x = tf_v1.placeholder(tf.int32, [2])
            y = tf_v1.sparse_placeholder(tf.string, [None])
            sig = _make_signature({"x": x, "y": y}, {})

            input_map = tensor_info.build_input_map(
                sig.inputs, {"x": x, "y": y})

            # One entry for the dense input, three for the sparse one
            # (indices, values, dense_shape).
            self.assertEqual(len(input_map), 4)
            self.assertEqual(input_map[x.name], x)
            for component in (y.indices, y.values, y.dense_shape):
                self.assertEqual(input_map[component.name], component)
Example 7
    def testConvertTensors(self):
        """Converts compatible values and rejects sparse-for-dense input."""
        with tf_v1.Graph().as_default():
            a = tf_v1.placeholder(tf.int32, [None])
            protomap = _make_signature({"a": a}, {}).inputs
            targets = tensor_info.parse_tensor_info_map(protomap)

            # A plain Python list is converted to the target dtype.
            in0 = [1, 2, 3]
            output = tensor_info.convert_dict_to_compatible_tensor({"a": in0},
                                                                   targets)
            self.assertEqual(output["a"].dtype, a.dtype)

            # A sparse tensor cannot satisfy a dense target.
            # NOTE: assertRaisesRegexp is a deprecated alias removed in
            # Python 3.12; use assertRaisesRegex.
            in1 = tf_v1.sparse_placeholder(tf.int32, [])
            with self.assertRaisesRegex(TypeError, "dense"):
                tensor_info.convert_dict_to_compatible_tensor({"a": in1},
                                                              targets)
Example 8
    def testRepr(self):
        """repr() of a parsed tensor info shows shape, dtype and sparsity."""
        with tf_v1.Graph().as_default():
            sig = _make_signature(
                {
                    "x": tf_v1.placeholder(tf.string, [2]),
                }, {
                    "y": tf_v1.placeholder(tf.int32, [2]),
                    "z": tf_v1.sparse_placeholder(tf.float32, [2, 10]),
                })

            parsed = tensor_info.parse_tensor_info_map(sig.outputs)
            expected = {
                "y": "<hub.ParsedTensorInfo shape=(2,) dtype=int32 is_sparse=False>",
                "z": "<hub.ParsedTensorInfo shape=(2, 10) dtype=float32 is_sparse=True>",
            }
            for key, want in expected.items():
                self.assertEqual(repr(parsed[key]), want)