Example #1
  def _testExportImportAcrossScopes(self, graph_fn):
    """Tests exporting and importing a graph across scopes.

    Args:
      graph_fn: A closure that creates a graph on the current scope.
    """
    with ops.Graph().as_default() as original_graph:
      with variable_scope.variable_scope("dropA/dropB/keepA"):
        graph_fn()
    exported_meta_graph_def = meta_graph.export_scoped_meta_graph(
        graph=original_graph,
        export_scope="dropA/dropB")[0]

    with ops.Graph().as_default() as imported_graph:
      meta_graph.import_scoped_meta_graph(
          exported_meta_graph_def,
          import_scope="importA")

    with ops.Graph().as_default() as expected_graph:
      with variable_scope.variable_scope("importA/keepA"):
        graph_fn()

    result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0]
    expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0]
    self.assertProtoEquals(expected, result)
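
Distilled from the example above, the core round trip fits in a few lines. The following is a minimal, hedged sketch in TF1-style graph mode (the scope names "dropMe", "keepMe", and "copy" are illustrative choices, not part of the API): export_scoped_meta_graph strips the export_scope prefix from every node it serializes, and import_scoped_meta_graph prepends the import_scope prefix on the way back in.

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops

with ops.Graph().as_default() as g:
  with ops.name_scope("dropMe/keepMe"):
    constant_op.constant(1.0, name="c")  # full node name: dropMe/keepMe/c

# Export under "dropMe": the serialized proto contains only "keepMe/c".
mgd = meta_graph.export_scoped_meta_graph(graph=g, export_scope="dropMe")[0]

with ops.Graph().as_default() as g2:
  # Import under "copy": the node comes back as "copy/keepMe/c".
  meta_graph.import_scoped_meta_graph(mgd, import_scope="copy")
  g2.get_operation_by_name("copy/keepMe/c")  # resolves without raising
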
Example #2
  def testScopedImportWithSelectedCollections(self):
    meta_graph_filename = os.path.join(
        _TestDir("selected_collections_import"), "meta_graph.pb")

    graph = ops.Graph()
    # Add a variable to populate two collections. The functionality tested is
    # not specific to variables, but using variables in the test is convenient.
    with graph.as_default():
      variables.Variable(initial_value=1.0, trainable=True)
    self.assertTrue(
        all([
            graph.get_collection(key)
            for key in
            [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES]
        ]))
    meta_graph.export_scoped_meta_graph(
        filename=meta_graph_filename, graph=graph)

    def _test_import(include_collection_keys, omit_collection_keys):
      assert set(include_collection_keys).isdisjoint(omit_collection_keys)
      newgraph = ops.Graph()
      import_scope = "some_scope_name"

      def _restore_collections_predicate(collection_key):
        return (collection_key in include_collection_keys and
                collection_key not in omit_collection_keys)

      meta_graph.import_scoped_meta_graph(
          meta_graph_filename,
          graph=newgraph,
          import_scope=import_scope,
          restore_collections_predicate=_restore_collections_predicate)
      collection_values = [
          newgraph.get_collection(name=key, scope=import_scope)
          for key in include_collection_keys
      ]
      self.assertTrue(all(collection_values))
      collection_values = [
          newgraph.get_collection(name=key, scope=import_scope)
          for key in omit_collection_keys
      ]
      self.assertFalse(any(collection_values))

    _test_import(
        include_collection_keys=[
            ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES
        ],
        omit_collection_keys=[])
    _test_import(
        include_collection_keys=[ops.GraphKeys.GLOBAL_VARIABLES],
        omit_collection_keys=[ops.GraphKeys.TRAINABLE_VARIABLES])
    _test_import(
        include_collection_keys=[ops.GraphKeys.TRAINABLE_VARIABLES],
        omit_collection_keys=[ops.GraphKeys.GLOBAL_VARIABLES])
    _test_import(
        include_collection_keys=[],
        omit_collection_keys=[
            ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES
        ])
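
For quick reference, restore_collections_predicate is called once per collection key found in the MetaGraphDef, and only collections for which it returns True are recreated in the importing graph. A hedged one-call sketch, reusing the meta_graph_filename written by the test above (ops and meta_graph imported as in the test):

keep_keys = {ops.GraphKeys.TRAINABLE_VARIABLES}  # keep only this collection

meta_graph.import_scoped_meta_graph(
    meta_graph_filename,  # assumed: the file exported in the test above
    graph=ops.Graph(),
    import_scope="filtered",
    restore_collections_predicate=lambda key: key in keep_keys)
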
Example #3
  def testClearDevices(self):
    graph1 = ops.Graph()
    with graph1.as_default():
      with ops.device("/device:CPU:0"):
        a = variables.Variable(
            constant_op.constant(
                1.0, shape=[2, 2]), name="a")
      with ops.device("/job:ps/replica:0/task:0/gpu:0"):
        b = variables.Variable(
            constant_op.constant(
                2.0, shape=[2, 2]), name="b")
      with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
        math_ops.matmul(a, b, name="matmul")

    self.assertEqual("/device:CPU:0", str(graph1.as_graph_element("a").device))
    self.assertEqual("/job:ps/replica:0/task:0/device:GPU:0",
                     str(graph1.as_graph_element("b").device))
    self.assertEqual("/job:localhost/replica:0/task:0/device:CPU:0",
                     str(graph1.as_graph_element("matmul").device))

    # Verifies that devices are cleared on export.
    orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
        graph=graph1, clear_devices=True)

    graph2 = ops.Graph()
    with graph2.as_default():
      meta_graph.import_scoped_meta_graph(orig_meta_graph, clear_devices=False)

    self.assertEqual("", str(graph2.as_graph_element("a").device))
    self.assertEqual("", str(graph2.as_graph_element("b").device))
    self.assertEqual("", str(graph2.as_graph_element("matmul").device))

    # Verifies that devices are cleared on export when passing in graph_def.
    orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
        graph_def=graph1.as_graph_def(), clear_devices=True)

    graph2 = ops.Graph()
    with graph2.as_default():
      meta_graph.import_scoped_meta_graph(orig_meta_graph, clear_devices=False)

    self.assertEqual("", str(graph2.as_graph_element("a").device))
    self.assertEqual("", str(graph2.as_graph_element("b").device))
    self.assertEqual("", str(graph2.as_graph_element("matmul").device))

    # Verifies that devices are cleared on import.
    orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
        graph=graph1, clear_devices=False)

    graph2 = ops.Graph()
    with graph2.as_default():
      meta_graph.import_scoped_meta_graph(orig_meta_graph, clear_devices=True)

    self.assertEqual("", str(graph2.as_graph_element("a").device))
    self.assertEqual("", str(graph2.as_graph_element("b").device))
    self.assertEqual("", str(graph2.as_graph_element("matmul").device))
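
In short, device strings can be dropped on either side of the round trip, and both routes below should leave every imported op with an empty device string. A hedged sketch reusing graph1 from the test above:

# Route 1: strip devices at export time.
mgd, _ = meta_graph.export_scoped_meta_graph(graph=graph1, clear_devices=True)

# Route 2: keep devices in the proto, strip them at import time.
mgd, _ = meta_graph.export_scoped_meta_graph(graph=graph1)
with ops.Graph().as_default() as g:
  meta_graph.import_scoped_meta_graph(mgd, clear_devices=True)
  assert g.as_graph_element("matmul").device == ""
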
Example #4
  def testMetricsCollection(self):

    def _enqueue_vector(sess, queue, values, shape=None):
      if not shape:
        shape = (1, len(values))
      dtype = queue.dtypes[0]
      sess.run(
          queue.enqueue(constant_op.constant(
              values, dtype=dtype, shape=shape)))

    meta_graph_filename = os.path.join(
        _TestDir("metrics_export"), "meta_graph.pb")

    graph = ops.Graph()
    with self.session(graph=graph) as sess:
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      _, update_op = metrics.mean(values)

      initializer = variables.local_variables_initializer()
      self.evaluate(initializer)
      self.evaluate(update_op)

    meta_graph.export_scoped_meta_graph(
        filename=meta_graph_filename, graph=graph)

    # Verifies that importing a meta_graph with LOCAL_VARIABLES collection
    # works correctly.
    graph = ops.Graph()
    with self.session(graph=graph) as sess:
      meta_graph.import_scoped_meta_graph(meta_graph_filename)
      initializer = variables.local_variables_initializer()
      self.evaluate(initializer)

    # Verifies that importing an old meta_graph, where the "local_variables"
    # collection is of node_list type, still works, but that an initializer
    # cannot be built from that collection.
    graph = ops.Graph()
    with self.session(graph=graph) as sess:
      meta_graph.import_scoped_meta_graph(
          test.test_src_dir_path(
              "python/framework/testdata/metrics_export_meta_graph.pb"))
      self.assertEqual(len(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)),
                       2)
      with self.assertRaisesRegex(
          AttributeError, "'Tensor' object has no attribute 'initializer'"):
        initializer = variables.local_variables_initializer()
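
One API detail the tests lean on implicitly: import_scoped_meta_graph accepts either a MetaGraphDef proto or a path to a serialized one on disk, so the two hedged forms below should be equivalent (read_meta_graph_file is assumed to be the explicit parsing helper in the same meta_graph module):

meta_graph.import_scoped_meta_graph(meta_graph_filename)    # from a file path

mgd = meta_graph.read_meta_graph_file(meta_graph_filename)  # parse explicitly
meta_graph.import_scoped_meta_graph(mgd)                    # from a proto
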
Example #5
  def testNoVariables(self):
    test_dir = _TestDir("no_variables")
    filename = os.path.join(test_dir, "metafile")

    input_feed_value = -10  # Arbitrary input value for feed_dict.

    orig_graph = ops.Graph()
    with self.session(graph=orig_graph) as sess:
      # Create a minimal graph with zero variables.
      input_tensor = array_ops.placeholder(
          dtypes.float32, shape=[], name="input")
      offset = constant_op.constant(42, dtype=dtypes.float32, name="offset")
      output_tensor = math_ops.add(input_tensor, offset, name="add_offset")

      # Add input and output tensors to graph collections.
      ops.add_to_collection("input_tensor", input_tensor)
      ops.add_to_collection("output_tensor", output_tensor)

      output_value = sess.run(output_tensor, {input_tensor: input_feed_value})
      self.assertEqual(output_value, 32)

      # Generates MetaGraphDef.
      meta_graph_def, var_list = meta_graph.export_scoped_meta_graph(
          filename=filename,
          graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
          collection_list=["input_tensor", "output_tensor"],
          saver_def=None)
      self.assertTrue(meta_graph_def.HasField("meta_info_def"))
      self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "")
      self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_git_version,
                          "")
      self.assertEqual({}, var_list)

    # Create a clean graph and import the MetaGraphDef nodes.
    new_graph = ops.Graph()
    with self.session(graph=new_graph) as sess:
      # Import the previously exported meta graph.
      meta_graph.import_scoped_meta_graph(filename)

      # Re-exports the current graph state for comparison to the original.
      new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(filename +
                                                                  "_new")
      test_util.assert_meta_graph_protos_equal(self, meta_graph_def,
                                               new_meta_graph_def)

      # Ensures that we can still get a reference to our graph collections.
      new_input_tensor = ops.get_collection("input_tensor")[0]
      new_output_tensor = ops.get_collection("output_tensor")[0]
      # Verifies that the new graph computes the same result as the original.
      new_output_value = sess.run(new_output_tensor,
                                  {new_input_tensor: input_feed_value})
      self.assertEqual(new_output_value, output_value)
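
The collections are what make the imported tensors addressable again: with zero variables, the graph round-trips purely through the GraphDef plus the two custom collections. A hedged sketch of the retrieval step, with filename as written by the test above:

with ops.Graph().as_default(), session.Session() as sess:
  meta_graph.import_scoped_meta_graph(filename)  # assumed: exported above
  inp = ops.get_collection("input_tensor")[0]
  out = ops.get_collection("output_tensor")[0]
  print(sess.run(out, {inp: -10.0}))  # 32.0, i.e. -10 + 42, as before
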
Example #6
  def testDefaultAttrStripping(self):
    """Verifies that default attributes are stripped from a graph def."""

    # Complex Op has 2 attributes with defaults:
    #   o "T"    : float32.
    #   o "Tout" : complex64.

    # When inputs to the Complex Op are float32 instances, "T" maps to float32
    # and "Tout" maps to complex64. Since these attr values map to their
    # defaults, they must be stripped unless stripping of default attrs is
    # disabled.
    with self.cached_session():
      real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
      imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")

      # strip_default_attrs is enabled.
      meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
          graph_def=ops.get_default_graph().as_graph_def(),
          strip_default_attrs=True)
      node_def = test_util.get_node_def_from_graph("complex",
                                                   meta_graph_def.graph_def)
      self.assertNotIn("T", node_def.attr)
      self.assertNotIn("Tout", node_def.attr)
      self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)

      # strip_default_attrs is disabled.
      meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
          graph_def=ops.get_default_graph().as_graph_def(),
          strip_default_attrs=False)
      node_def = test_util.get_node_def_from_graph("complex",
                                                   meta_graph_def.graph_def)
      self.assertIn("T", node_def.attr)
      self.assertIn("Tout", node_def.attr)
      self.assertFalse(meta_graph_def.meta_info_def.stripped_default_attrs)

    # When inputs to the Complex Op are float64 instances, "T" maps to float64
    # and "Tout" maps to complex128. Since these attr values don't map to their
    # defaults, they must not be stripped.
    with self.session(graph=ops.Graph()):
      real_num = constant_op.constant(1.0, dtype=dtypes.float64, name="real")
      imag_num = constant_op.constant(2.0, dtype=dtypes.float64, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
          graph_def=ops.get_default_graph().as_graph_def(),
          strip_default_attrs=True)
      node_def = test_util.get_node_def_from_graph("complex",
                                                   meta_graph_def.graph_def)
      self.assertEqual(node_def.attr["T"].type, dtypes.float64)
      self.assertEqual(node_def.attr["Tout"].type, dtypes.complex128)
      self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
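
Condensed, the stripping rule is: an attr is removed only when its value equals the default declared in the op definition, which is why float32 Complex inputs (the defaults) strip while float64 inputs survive. A hedged sketch of checking the recorded flag after export:

mgd, _ = meta_graph.export_scoped_meta_graph(
    graph_def=ops.get_default_graph().as_graph_def(),
    strip_default_attrs=True)
# The proto records that stripping was requested, independent of how many
# attrs actually matched their defaults.
assert mgd.meta_info_def.stripped_default_attrs
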
Example #7
  def testWhileLoopGradients(self):
    # Create a simple while loop.
    with ops.Graph().as_default():
      with ops.name_scope("export"):
        var = variables.Variable(0.)
        var_name = var.name
        _, output = control_flow_ops.while_loop(
            lambda i, x: i < 5,
            lambda i, x: (i + 1, x + math_ops.cast(i, dtypes.float32)),
            [0, var])
        output_name = output.name

      # Generate a MetaGraphDef containing the while loop with an export scope.
      meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
          export_scope="export")

      # Build and run the gradients of the while loop. We use this below to
      # verify that the gradients are correct with the imported MetaGraphDef.
      init_op = variables.global_variables_initializer()
      grad = gradients_impl.gradients([output], [var])
      with session.Session() as sess:
        self.evaluate(init_op)
        expected_grad_value = self.evaluate(grad)

    # Restore the MetaGraphDef into a new Graph with an import scope.
    with ops.Graph().as_default():
      meta_graph.import_scoped_meta_graph(meta_graph_def, import_scope="import")

      # Re-export and make sure we get the same MetaGraphDef.
      new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
          export_scope="import")
      test_util.assert_meta_graph_protos_equal(
          self, meta_graph_def, new_meta_graph_def)

      # Make sure we can still build gradients and get the same result.

      def new_name(tensor_name):
        base_tensor_name = tensor_name.replace("export/", "")
        return "import/" + base_tensor_name

      var = ops.get_default_graph().get_tensor_by_name(new_name(var_name))
      output = ops.get_default_graph().get_tensor_by_name(new_name(output_name))
      grad = gradients_impl.gradients([output], [var])

      init_op = variables.global_variables_initializer()

      with session.Session() as sess:
        self.evaluate(init_op)
        actual_grad_value = self.evaluate(grad)
        self.assertEqual(expected_grad_value, actual_grad_value)
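
The renaming done by new_name above generalizes: a tensor exported under export_scope keeps its scope-relative name, and only the prefix changes on import. A hedged sketch of the lookup, to be run inside the importing graph's context (var_name and output_name as captured in the test above):

def renamed(tensor_name, old_scope="export", new_scope="import"):
  # Swap the export prefix for the import prefix; the rest of the name is
  # stable across the round trip.
  return new_scope + "/" + tensor_name.replace(old_scope + "/", "")

g = ops.get_default_graph()
var_copy = g.get_tensor_by_name(renamed(var_name))
output_copy = g.get_tensor_by_name(renamed(output_name))
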
Example #8
  def _testExportImportAcrossScopes(self, graph_fn, use_resource):
    """Tests exporting and importing a graph across scopes.

    Args:
      graph_fn: A closure that creates a graph on the current scope.
      use_resource: A bool indicating whether or not to use ResourceVariables.
    """
    with ops.Graph().as_default() as original_graph:
      with variable_scope.variable_scope("dropA/dropB/keepA"):
        graph_fn(use_resource=use_resource)
    exported_meta_graph_def = meta_graph.export_scoped_meta_graph(
        graph=original_graph,
        export_scope="dropA/dropB")[0]

    with ops.Graph().as_default() as imported_graph:
      meta_graph.import_scoped_meta_graph(
          exported_meta_graph_def,
          import_scope="importA")

    with ops.Graph().as_default() as expected_graph:
      with variable_scope.variable_scope("importA/keepA"):
        graph_fn(use_resource=use_resource)

      if use_resource:
        # Bringing in a collection that contains ResourceVariables adds ops
        # to the graph, so mimic the same behavior.
        for collection_key in sorted([
            ops.GraphKeys.GLOBAL_VARIABLES,
            ops.GraphKeys.TRAINABLE_VARIABLES,
        ]):
          for var in expected_graph.get_collection(collection_key):
            var._read_variable_op()

    result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0]
    expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0]

    if use_resource:
      # Clear all shared_name attributes before comparing, since they are
      # supposed to be orthogonal to scopes.
      for meta_graph_def in [result, expected]:
        for node in meta_graph_def.graph_def.node:
          shared_name_attr = "shared_name"
          shared_name_value = node.attr.get(shared_name_attr, None)
          if shared_name_value and shared_name_value.HasField("s"):
            if shared_name_value.s:
              node.attr[shared_name_attr].s = b""

    self.assertProtoEquals(expected, result)
Example #9
    def testDefaultAttrStrippingNestedFunctions(self):
        """Verifies that default attributes are stripped from function node defs."""
        with self.test_session():

            @function.Defun(dtypes.float32, dtypes.float32)
            def f0(i, j):
                return math_ops.complex(i, j, name="double_nested_complex")

            @function.Defun(dtypes.float32, dtypes.float32)
            def f1(i, j):
                return f0(i, j)

            _ = f1(constant_op.constant(1.0), constant_op.constant(2.0))
            meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
                graph_def=ops.get_default_graph().as_graph_def(),
                strip_default_attrs=True)

            double_nested_complex_node_def = None
            for function_def in meta_graph_def.graph_def.library.function:
                for node_def in function_def.node_def:
                    if node_def.name.startswith("double_nested_complex"):
                        double_nested_complex_node_def = node_def
                        break
                if double_nested_complex_node_def:
                    break

            self.assertIsNotNone(double_nested_complex_node_def)
            self.assertNotIn("T", double_nested_complex_node_def.attr)
            self.assertNotIn("Tout", double_nested_complex_node_def.attr)
            self.assertTrue(
                meta_graph_def.meta_info_def.stripped_default_attrs)
Example #10
  def testSummaryWithFamilyMetaGraphExport(self):
    with ops.name_scope('outer'):
      i = constant_op.constant(11)
      summ = summary_lib.scalar('inner', i)
      self.assertEqual(summ.op.name, 'outer/inner')
      summ_f = summary_lib.scalar('inner', i, family='family')
      self.assertEqual(summ_f.op.name, 'outer/family/inner')

    metagraph_def, _ = meta_graph.export_scoped_meta_graph(export_scope='outer')

    with ops.Graph().as_default() as g:
      meta_graph.import_scoped_meta_graph(metagraph_def, graph=g,
                                          import_scope='new_outer')
      # The summaries should exist, but with the outer scope renamed.
      new_summ = g.get_tensor_by_name('new_outer/inner:0')
      new_summ_f = g.get_tensor_by_name('new_outer/family/inner:0')

      # However, the tags are unaffected.
      with self.cached_session() as s:
        new_summ_str, new_summ_f_str = s.run([new_summ, new_summ_f])
        new_summ_pb = summary_pb2.Summary()
        new_summ_pb.ParseFromString(new_summ_str)
        self.assertEqual('outer/inner', new_summ_pb.value[0].tag)
        new_summ_f_pb = summary_pb2.Summary()
        new_summ_f_pb.ParseFromString(new_summ_f_str)
        self.assertEqual('family/outer/family/inner',
                         new_summ_f_pb.value[0].tag)
Example #11
    def testImportWhileLoopInWhileLoop(self):
        # Create a simple while loop.
        with ops.Graph().as_default():
            var = variables.Variable(0.0)
            _, output = control_flow_ops.while_loop(
                lambda i, x: i < 5, lambda i, x: (i + 1, x * 2.0), [0, var])
            output_name = output.name

            # Generate a MetaGraphDef containing the while loop with an export scope.
            meta_graph_def, _ = meta_graph.export_scoped_meta_graph()

        # Restore the MetaGraphDef in a while loop in a new graph.
        with ops.Graph().as_default():

            def body(i, _):
                meta_graph.import_scoped_meta_graph(meta_graph_def)
                return i + 1, ops.get_default_graph().get_tensor_by_name(
                    output_name)

            _, x = control_flow_ops.while_loop(lambda i, x: i < 2,
                                               body, [0, 0.0],
                                               name="")
            with session.Session() as sess:
                sess.run(variables.global_variables_initializer())
                sess.run(x)
Example #12
  def testDefaultAttrStrippingNestedFunctions(self):
    """Verifies that default attributes are stripped from function node defs."""
    with self.cached_session():

      @function.Defun(dtypes.float32, dtypes.float32)
      def f0(i, j):
        return math_ops.complex(i, j, name="double_nested_complex")

      @function.Defun(dtypes.float32, dtypes.float32)
      def f1(i, j):
        return f0(i, j)

      _ = f1(constant_op.constant(1.0), constant_op.constant(2.0))
      meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
          graph_def=ops.get_default_graph().as_graph_def(),
          strip_default_attrs=True)

      double_nested_complex_node_def = None
      for function_def in meta_graph_def.graph_def.library.function:
        for node_def in function_def.node_def:
          if node_def.name.startswith("double_nested_complex"):
            double_nested_complex_node_def = node_def
            break
        if double_nested_complex_node_def:
          break

      self.assertIsNotNone(double_nested_complex_node_def)
      self.assertNotIn("T", double_nested_complex_node_def.attr)
      self.assertNotIn("Tout", double_nested_complex_node_def.attr)
      self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
Example #13
    def testPotentialCycle(self):
        graph1 = ops.Graph()
        with graph1.as_default():
            a = constant_op.constant(1.0, shape=[2, 2])
            b = constant_op.constant(2.0, shape=[2, 2])
            matmul = math_ops.matmul(a, b)
            with ops.name_scope("hidden1"):
                c = nn_ops.relu(matmul)
                d = constant_op.constant(3.0, shape=[2, 2])
                matmul = math_ops.matmul(c, d)

        orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
            export_scope="hidden1", graph=graph1)

        graph2 = ops.Graph()
        with graph2.as_default():
            with self.assertRaisesRegex(ValueError,
                                        "Graph contains unbound inputs"):
                meta_graph.import_scoped_meta_graph(orig_meta_graph,
                                                    import_scope="new_hidden1")

            meta_graph.import_scoped_meta_graph(orig_meta_graph,
                                                import_scope="new_hidden1",
                                                input_map={
                                                    "$unbound_inputs_MatMul":
                                                    constant_op.constant(
                                                        4.0, shape=[2, 2])
                                                })
Example #14
  def testClearDevices(self):
    graph1 = tf.Graph()
    with graph1.as_default():
      with tf.device("/device:CPU:0"):
        a = tf.Variable(tf.constant(1.0, shape=[2, 2]), name="a")
      with tf.device("/job:ps/replica:0/task:0/gpu:0"):
        b = tf.Variable(tf.constant(2.0, shape=[2, 2]), name="b")
      with tf.device("/job:localhost/replica:0/task:0/cpu:0"):
        tf.matmul(a, b, name="matmul")

    self.assertEqual("/device:CPU:0", str(graph1.as_graph_element("a").device))
    self.assertEqual("/job:ps/replica:0/task:0/device:GPU:0",
                     str(graph1.as_graph_element("b").device))
    self.assertEqual("/job:localhost/replica:0/task:0/device:CPU:0",
                     str(graph1.as_graph_element("matmul").device))

    orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(graph=graph1)

    graph2 = tf.Graph()
    with graph2.as_default():
      meta_graph.import_scoped_meta_graph(orig_meta_graph, clear_devices=True)

    self.assertEqual("", str(graph2.as_graph_element("a").device))
    self.assertEqual("", str(graph2.as_graph_element("b").device))
    self.assertEqual("", str(graph2.as_graph_element("matmul").device))
Example #15
 def testBuild(self):
     graph = GraphPlacerTest._buildInception()
     mg = meta_graph.create_meta_graph_def(graph=graph)
     # gcluster = cluster.Cluster(devices=None)  # Generates a local machine cluster.
     gcluster = GraphPlacerTest._buildCluster()
     print(gcluster.ListDevices())  # Print cluster info.
     # Spend the allotted time trying to optimize the placement of the model.
     # This should be enough to exercise the code, though not necessarily
     # enough to find a good placement, so we only check for legality.
     placed_mg = graph_placer.PlaceGraph(mg,
                                         allotted_time=108000,
                                         cluster=gcluster,
                                         verbose=True)
     placed_g = placed_mg.graph_def
     meta_graph.export_scoped_meta_graph(filename="./g/g.meta",
                                         graph_def=placed_g)
Example #16
  def testPotentialCycle(self):
    graph1 = ops.Graph()
    with graph1.as_default():
      a = constant_op.constant(1.0, shape=[2, 2])
      b = constant_op.constant(2.0, shape=[2, 2])
      matmul = math_ops.matmul(a, b)
      with ops.name_scope("hidden1"):
        c = nn_ops.relu(matmul)
        d = constant_op.constant(3.0, shape=[2, 2])
        matmul = math_ops.matmul(c, d)

    orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
        export_scope="hidden1", graph=graph1)

    graph2 = ops.Graph()
    with graph2.as_default():
      with self.assertRaisesRegex(ValueError, "Graph contains unbound inputs"):
        meta_graph.import_scoped_meta_graph(
            orig_meta_graph, import_scope="new_hidden1")

      meta_graph.import_scoped_meta_graph(
          orig_meta_graph,
          import_scope="new_hidden1",
          input_map={
              "$unbound_inputs_MatMul": constant_op.constant(
                  4.0, shape=[2, 2])
          })
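
The key convention is the point here: tensors produced outside the exported scope are recorded as unbound inputs under a "$unbound_inputs_" prefix, and input_map must use that prefixed name rather than the plain tensor name. A hedged sketch, run inside a fresh graph's as_default() block with orig_meta_graph as exported above:

meta_graph.import_scoped_meta_graph(
    orig_meta_graph,  # assumed: exported with export_scope="hidden1" above
    import_scope="bound",
    input_map={
        # "MatMul" lives outside "hidden1", hence the unbound-input prefix.
        "$unbound_inputs_MatMul": constant_op.constant(4.0, shape=[2, 2]),
    })
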
Example #17
    def testSummaryWithFamilyMetaGraphExport(self):
        with ops.name_scope('outer'):
            i = constant_op.constant(11)
            summ = summary_lib.scalar('inner', i)
            self.assertEqual(summ.op.name, 'outer/inner')
            summ_f = summary_lib.scalar('inner', i, family='family')
            self.assertEqual(summ_f.op.name, 'outer/family/inner')

        metagraph_def, _ = meta_graph.export_scoped_meta_graph(
            export_scope='outer')

        with ops.Graph().as_default() as g:
            meta_graph.import_scoped_meta_graph(metagraph_def,
                                                graph=g,
                                                import_scope='new_outer')
            # The summaries should exist, but with the outer scope renamed.
            new_summ = g.get_tensor_by_name('new_outer/inner:0')
            new_summ_f = g.get_tensor_by_name('new_outer/family/inner:0')

            # However, the tags are unaffected.
            with self.cached_session() as s:
                new_summ_str, new_summ_f_str = s.run([new_summ, new_summ_f])
                new_summ_pb = summary_pb2.Summary()
                new_summ_pb.ParseFromString(new_summ_str)
                self.assertEqual('outer/inner', new_summ_pb.value[0].tag)
                new_summ_f_pb = summary_pb2.Summary()
                new_summ_f_pb.ParseFromString(new_summ_f_str)
                self.assertEqual('family/outer/family/inner',
                                 new_summ_f_pb.value[0].tag)
Example #18
    def testNoVariables(self):
        test_dir = _TestDir("no_variables")
        filename = os.path.join(test_dir, "metafile")

        input_feed_value = -10  # Arbitrary input value for feed_dict.

        orig_graph = tf.Graph()
        with self.test_session(graph=orig_graph) as sess:
            # Create a minimal graph with zero variables.
            input_tensor = tf.placeholder(tf.float32, shape=[], name="input")
            offset = tf.constant(42, dtype=tf.float32, name="offset")
            output_tensor = tf.add(input_tensor, offset, name="add_offset")

            # Add input and output tensors to graph collections.
            tf.add_to_collection("input_tensor", input_tensor)
            tf.add_to_collection("output_tensor", output_tensor)

            output_value = sess.run(output_tensor,
                                    {input_tensor: input_feed_value})
            self.assertEqual(output_value, 32)

            # Generates MetaGraphDef.
            meta_graph_def, var_list = meta_graph.export_scoped_meta_graph(
                filename=filename,
                graph_def=tf.get_default_graph().as_graph_def(add_shapes=True),
                collection_list=["input_tensor", "output_tensor"],
                saver_def=None)
            self.assertEqual({}, var_list)

        # Create a clean graph and import the MetaGraphDef nodes.
        new_graph = tf.Graph()
        with self.test_session(graph=new_graph) as sess:
            # Import the previously exported meta graph.
            meta_graph.import_scoped_meta_graph(filename)

            # Re-exports the current graph state for comparison to the original.
            new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
                filename + "_new")
            self.assertProtoEquals(meta_graph_def, new_meta_graph_def)

            # Ensures that we can still get a reference to our graph collections.
            new_input_tensor = tf.get_collection("input_tensor")[0]
            new_output_tensor = tf.get_collection("output_tensor")[0]
            # Verifies that the new graph computes the same result as the original.
            new_output_value = sess.run(new_output_tensor,
                                        {new_input_tensor: input_feed_value})
            self.assertEqual(new_output_value, output_value)
Example #19
    def _testExportImportAcrossScopes(self, graph_fn, use_resource):
        """Tests exporting and importing a graph across scopes.

        Args:
          graph_fn: A closure that creates a graph on the current scope.
          use_resource: A bool indicating whether or not to use
            ResourceVariables.
        """
        with ops.Graph().as_default() as original_graph:
            with variable_scope.variable_scope("dropA/dropB/keepA"):
                graph_fn(use_resource=use_resource)
        exported_meta_graph_def = meta_graph.export_scoped_meta_graph(
            graph=original_graph, export_scope="dropA/dropB")[0]

        with ops.Graph().as_default() as imported_graph:
            meta_graph.import_scoped_meta_graph(exported_meta_graph_def,
                                                import_scope="importA")

        with ops.Graph().as_default() as expected_graph:
            with variable_scope.variable_scope("importA/keepA"):
                graph_fn(use_resource=use_resource)

            if use_resource:
                # Bringing in a collection that contains ResourceVariables adds ops
                # to the graph, so mimic the same behavior.
                for collection_key in sorted([
                        ops.GraphKeys.GLOBAL_VARIABLES,
                        ops.GraphKeys.TRAINABLE_VARIABLES,
                ]):
                    for var in expected_graph.get_collection(collection_key):
                        var._read_variable_op()

        result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0]
        expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0]

        if use_resource:
            # Clear all shared_name attributes before comparing, since they are
            # supposed to be orthogonal to scopes.
            for meta_graph_def in [result, expected]:
                for node in meta_graph_def.graph_def.node:
                    shared_name_attr = "shared_name"
                    shared_name_value = node.attr.get(shared_name_attr, None)
                    if shared_name_value and shared_name_value.HasField("s"):
                        if shared_name_value.s:
                            node.attr[shared_name_attr].s = b""

        self.assertProtoEquals(expected, result)
Example #20
    def doTestExportNestedNames(self, use_resource=False):
        graph1 = ops.Graph()
        with graph1.as_default():
            with ops.name_scope("hidden1/hidden2/hidden3"):
                images = constant_op.constant(1.0,
                                              dtypes.float32,
                                              shape=[3, 2],
                                              name="images")
                if use_resource:
                    weights1 = variables.Variable(
                        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
                    biases1 = resource_variable_ops.ResourceVariable(
                        [0.1] * 3, name="biases")
                else:
                    biases1 = variables.Variable([0.1] * 3, name="biases")
                    weights1 = variables.Variable(
                        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
                nn_ops.relu(math_ops.matmul(images, weights1) + biases1,
                            name="relu")

        orig_meta_graph, var_list = meta_graph.export_scoped_meta_graph(
            export_scope="hidden1/hidden2", graph=graph1)
        var_names = [v.name for _, v in var_list.items()]
        self.assertEqual(["hidden3/biases:0", "hidden3/weights:0"],
                         sorted(var_list.keys()))
        self.assertEqual([
            "hidden1/hidden2/hidden3/biases:0",
            "hidden1/hidden2/hidden3/weights:0"
        ], sorted(var_names))
        for node in orig_meta_graph.graph_def.node:
            self.assertTrue(node.name.startswith("hidden3"))

        graph2 = ops.Graph()
        new_var_list = meta_graph.import_scoped_meta_graph(
            orig_meta_graph,
            import_scope="new_hidden1/new_hidden2",
            graph=graph2)
        self.assertEqual(["hidden3/biases:0", "hidden3/weights:0"],
                         sorted(new_var_list.keys()))
        new_var_names = [v.name for _, v in new_var_list.items()]
        self.assertEqual([
            "new_hidden1/new_hidden2/hidden3/biases:0",
            "new_hidden1/new_hidden2/hidden3/weights:0"
        ], sorted(new_var_names))

        nodes = [
            "new_hidden1/new_hidden2/hidden3/biases/Assign",
            "new_hidden1/new_hidden2/hidden3/weights/Assign"
        ]
        expected = [
            b"loc:@new_hidden1/new_hidden2/hidden3/biases",
            b"loc:@new_hidden1/new_hidden2/hidden3/weights"
        ]
        for n, e in zip(nodes, expected):
            self.assertEqual(
                [e],
                graph2.get_operation_by_name(n).get_attr("_class"))
Example #21
 def testImportsUsingSameScopeName(self):
   with ops.Graph().as_default():
     variables.Variable(0, name="v")
     meta_graph_def, _ = meta_graph.export_scoped_meta_graph()
   with ops.Graph().as_default():
     for suffix in ["", "_1"]:
       imported_variables = meta_graph.import_scoped_meta_graph(
           meta_graph_def, import_scope="s")
       self.assertEqual(len(imported_variables), 1)
       self.assertEqual(list(imported_variables.keys())[0], "v:0")
       self.assertEqual(list(imported_variables.values())[0].name,
                        "s" + suffix + "/v:0")
Example #22
  def _testExportImportAcrossScopes(self, graph_fn, use_resource):
    """Tests exporting and importing a graph across scopes.

    Args:
      graph_fn: A closure that creates a graph on the current scope.
      use_resource: A bool indicating whether or not to use ResourceVariables.
    """
    with ops.Graph().as_default() as original_graph:
      with variable_scope.variable_scope("dropA/dropB/keepA"):
        graph_fn(use_resource=use_resource)
    exported_meta_graph_def = meta_graph.export_scoped_meta_graph(
        graph=original_graph,
        export_scope="dropA/dropB")[0]

    with ops.Graph().as_default() as imported_graph:
      meta_graph.import_scoped_meta_graph(
          exported_meta_graph_def,
          import_scope="importA")

    with ops.Graph().as_default() as expected_graph:
      with variable_scope.variable_scope("importA/keepA"):
        graph_fn(use_resource=use_resource)

    result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0]
    expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0]

    if use_resource:
      # Clear all shared_name attributes before comparing, since they are
      # orthogonal to scopes and are not updated on export/import.
      for meta_graph_def in [result, expected]:
        for node in meta_graph_def.graph_def.node:
          shared_name_attr = "shared_name"
          shared_name_value = node.attr.get(shared_name_attr, None)
          if shared_name_value and shared_name_value.HasField("s"):
            if shared_name_value.s:
              node.attr[shared_name_attr].s = b""

    test_util.assert_meta_graph_protos_equal(self, expected, result)
Example #23
    def testScopedImportUnderNameScopeNoVarScope(self):
        graph = ops.Graph()
        with graph.as_default():
            variables.Variable(initial_value=1.0, trainable=True, name="myvar")
        meta_graph_def, _ = meta_graph.export_scoped_meta_graph(graph=graph)

        graph = ops.Graph()
        with graph.as_default():
            with ops.name_scope("foo"):
                imported_variables = meta_graph.import_scoped_meta_graph(
                    meta_graph_def)
                self.assertEqual(len(imported_variables), 1)
                self.assertEqual(
                    list(imported_variables.values())[0].name, "foo/myvar:0")
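
As the next example verifies, an explicit import_scope composes after the surrounding name scope rather than replacing it. A hedged sketch using the same meta_graph_def exported above:

with ops.Graph().as_default():
  with ops.name_scope("foo"):
    imported = meta_graph.import_scoped_meta_graph(
        meta_graph_def, import_scope="bar")
  # The name scope prefix applies first, then the import scope.
  assert list(imported.values())[0].name == "foo/bar/myvar:0"
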
Example #24
  def testScopedImportUnderNameScope(self):
    graph = ops.Graph()
    with graph.as_default():
      variables.Variable(initial_value=1.0, trainable=True, name="myvar")
    meta_graph_def, _ = meta_graph.export_scoped_meta_graph(graph=graph)

    graph = ops.Graph()
    with graph.as_default():
      with ops.name_scope("foo"):
        imported_variables = meta_graph.import_scoped_meta_graph(
            meta_graph_def, import_scope="bar")
        self.assertEqual(len(imported_variables), 1)
        self.assertEqual(list(imported_variables.values())[0].name,
                         "foo/bar/myvar:0")
Example #25
  def doTestExportNestedNames(self, use_resource=False):
    graph1 = ops.Graph()
    with graph1.as_default():
      with ops.name_scope("hidden1/hidden2/hidden3"):
        images = constant_op.constant(
            1.0, dtypes.float32, shape=[3, 2], name="images")
        if use_resource:
          weights1 = variables.Variable(
              [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
          biases1 = resource_variable_ops.ResourceVariable(
              [0.1] * 3, name="biases")
        else:
          biases1 = variables.Variable([0.1] * 3, name="biases")
          weights1 = variables.Variable(
              [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
        nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")

    orig_meta_graph, var_list = meta_graph.export_scoped_meta_graph(
        export_scope="hidden1/hidden2", graph=graph1)
    var_names = [v.name for _, v in var_list.items()]
    self.assertEqual(["hidden3/biases:0", "hidden3/weights:0"],
                     sorted(var_list.keys()))
    self.assertEqual([
        "hidden1/hidden2/hidden3/biases:0", "hidden1/hidden2/hidden3/weights:0"
    ], sorted(var_names))
    for node in orig_meta_graph.graph_def.node:
      self.assertTrue(node.name.startswith("hidden3"))

    graph2 = ops.Graph()
    new_var_list = meta_graph.import_scoped_meta_graph(
        orig_meta_graph, import_scope="new_hidden1/new_hidden2", graph=graph2)
    self.assertEqual(["hidden3/biases:0", "hidden3/weights:0"],
                     sorted(new_var_list.keys()))
    new_var_names = [v.name for _, v in new_var_list.items()]
    self.assertEqual([
        "new_hidden1/new_hidden2/hidden3/biases:0",
        "new_hidden1/new_hidden2/hidden3/weights:0"
    ], sorted(new_var_names))

    nodes = [
        "new_hidden1/new_hidden2/hidden3/biases/Assign",
        "new_hidden1/new_hidden2/hidden3/weights/Assign"
    ]
    expected = [
        b"loc:@new_hidden1/new_hidden2/hidden3/biases",
        b"loc:@new_hidden1/new_hidden2/hidden3/weights"
    ]
    for n, e in zip(nodes, expected):
      self.assertEqual([e], graph2.get_operation_by_name(n).get_attr("_class"))
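
A detail worth pulling out of the test above: both export and import return a dict keyed by scope-relative variable names, while the values are live Variable objects whose .name carries the full prefix. A hedged sketch under the same setup:

new_vars = meta_graph.import_scoped_meta_graph(
    orig_meta_graph,  # assumed: exported with export_scope="hidden1/hidden2"
    import_scope="new_hidden1/new_hidden2",
    graph=ops.Graph())
# Keys stay relative to the old export scope...
assert sorted(new_vars.keys()) == ["hidden3/biases:0", "hidden3/weights:0"]
# ...while the Variable objects live under the new import scope.
assert new_vars["hidden3/weights:0"].name == (
    "new_hidden1/new_hidden2/hidden3/weights:0")
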
Example #26
  def _testScopedImportWithQueue(self, test_dir, exported_filename,
                                 new_exported_filename):
    graph = tf.Graph()
    meta_graph.import_scoped_meta_graph(
        os.path.join(test_dir, exported_filename),
        graph=graph,
        import_scope="new_queue1")
    graph.as_graph_element("new_queue1/dequeue:0")
    graph.as_graph_element("new_queue1/close")
    with graph.as_default():
      new_meta_graph, _ = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, new_exported_filename),
          graph=graph, export_scope="new_queue1")

    return new_meta_graph
Example #27
  def _testScopedExportWithQueue(self, test_dir, exported_filename):
    graph = tf.Graph()
    with graph.as_default():
      with tf.name_scope("queue1"):
        input_queue = tf.FIFOQueue(10, tf.float32)
        enqueue = input_queue.enqueue((9876), name="enqueue")
        close = input_queue.close(name="close")
        qr = tf.train.QueueRunner(input_queue, [enqueue], close)
        tf.train.add_queue_runner(qr)
        input_queue.dequeue(name="dequeue")

      orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, exported_filename),
          graph=tf.get_default_graph(), export_scope="queue1")

    return orig_meta_graph
Example #28
  def testMetricVariablesCollectionLoadsBytesList(self):
    with ops.Graph().as_default() as graph1:
      v1 = variables.Variable(
          [1, 2, 3], shape=[3], dtype=dtypes.float64, name="v")

    orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(graph=graph1)

    # Copy bytes list from global variables collection to metric variables.
    orig_meta_graph.collection_def[ops.GraphKeys.METRIC_VARIABLES].CopyFrom(
        orig_meta_graph.collection_def["variables"])

    with ops.Graph().as_default() as graph2:
      meta_graph.import_scoped_meta_graph(orig_meta_graph)
      var_list = graph2.get_collection(ops.GraphKeys.METRIC_VARIABLES)
      self.assertEqual(len(var_list), 1)
      v2 = var_list[0]
      self.assertIsInstance(v2, variables.Variable)
      self.assertEqual(v1.name, v2.name)
      self.assertEqual(v1.dtype, v2.dtype)
      self.assertEqual(v1.shape, v2.shape)
Example #29
  def testVariableObjectsAreSharedAmongCollections(self):
    with ops.Graph().as_default() as graph1:
      v = variables.Variable(3.0)
      # A single instance of Variable is shared among the collections:
      global_vars = graph1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      trainable_vars = graph1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
      self.assertEqual(len(global_vars), 1)
      self.assertEqual(len(trainable_vars), 1)
      self.assertIs(global_vars[0], trainable_vars[0])
      self.assertIs(v, global_vars[0])

    orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(graph=graph1)
    del graph1  # To avoid accidental references in code involving graph2.

    with ops.Graph().as_default() as graph2:
      meta_graph.import_scoped_meta_graph(orig_meta_graph)
      global_vars = graph2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      trainable_vars = graph2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
      self.assertEqual(len(global_vars), 1)
      self.assertEqual(len(trainable_vars), 1)
      # A single instance of Variable is shared among the collections:
      self.assertIs(global_vars[0], trainable_vars[0])
Example #30
  def testImportWhileLoopInWhileLoop(self):
    # Create a simple while loop.
    with ops.Graph().as_default():
      var = variables.Variable(0.0)
      _, output = control_flow_ops.while_loop(lambda i, x: i < 5,
                                              lambda i, x: (i + 1, x * 2.0),
                                              [0, var])
      output_name = output.name

      # Generate a MetaGraphDef containing the while loop with an export scope.
      meta_graph_def, _ = meta_graph.export_scoped_meta_graph()

    # Restore the MetaGraphDef in a while loop in a new graph.
    with ops.Graph().as_default():

      def body(i, _):
        meta_graph.import_scoped_meta_graph(meta_graph_def)
        return i + 1, ops.get_default_graph().get_tensor_by_name(output_name)

      _, x = control_flow_ops.while_loop(lambda i, x: i < 2, body, [0, 0.0],
                                         name="")
      with session.Session() as sess:
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(x)
Example #31
    def _testScopedExport(self, test_dir, exported_filenames):
        graph = ops.Graph()
        with graph.as_default():
            # Creates an inference graph.
            # Hidden 1
            colocate_constraint = constant_op.constant(1.2, name="constraint")
            images = constant_op.constant(1.2,
                                          dtypes.float32,
                                          shape=[100, 28],
                                          name="images")
            with ops.name_scope("hidden1"):
                with graph.colocate_with(colocate_constraint.op):
                    weights1 = variables.Variable(random_ops.truncated_normal(
                        [28, 128], stddev=1.0 / math.sqrt(float(28))),
                                                  name="weights")
                # The use of control_flow_ops.cond here is purely for adding
                # test coverage for the save and restore of control flow
                # context (which doesn't make any sense here from a machine
                # learning perspective). A typical bias is a simple Variable
                # without the conditions.
                biases1 = variables.Variable(control_flow_ops.cond(
                    math_ops.less(random.random(),
                                  0.5), lambda: array_ops.ones([128]),
                    lambda: array_ops.zeros([128])),
                                             name="biases")
                hidden1 = nn_ops.relu(
                    math_ops.matmul(images, weights1) + biases1)

            # Hidden 2
            with ops.name_scope("hidden2"):
                weights2 = variables.Variable(random_ops.truncated_normal(
                    [128, 32], stddev=1.0 / math.sqrt(float(128))),
                                              name="weights")

                # The use of control_flow_ops.while_loop here is purely for
                # adding test coverage for the save and restore of control
                # flow context (which doesn't make any sense here from a
                # machine learning perspective). A typical bias is a simple
                # Variable without the conditions.
                def loop_cond(it, _):
                    return it < 2

                def loop_body(it, biases2):
                    biases2 += constant_op.constant(0.1, shape=[32])
                    return it + 1, biases2

                _, biases2 = control_flow_ops.while_loop(
                    loop_cond, loop_body, [
                        constant_op.constant(0),
                        variables.Variable(array_ops.zeros([32]),
                                           name="biases")
                    ])
                hidden2 = nn_ops.relu(
                    math_ops.matmul(hidden1, weights2) + biases2)
            # Linear
            with ops.name_scope("softmax_linear"):
                weights3 = variables.Variable(random_ops.truncated_normal(
                    [32, 10], stddev=1.0 / math.sqrt(float(32))),
                                              name="weights")
                biases3 = variables.Variable(array_ops.zeros([10]),
                                             name="biases")
                logits = math_ops.matmul(hidden2, weights3) + biases3
                ops.add_to_collection("logits", logits)

            # Exports each sub-graph.
            # Exports the first one with unbound_inputs_col_name set to default.
            orig_meta_graph1, var_list = meta_graph.export_scoped_meta_graph(
                filename=os.path.join(test_dir, exported_filenames[0]),
                graph=ops.get_default_graph(),
                export_scope="hidden1")
            self.assertEqual(["biases:0", "weights:0"],
                             sorted(var_list.keys()))
            var_names = [v.name for _, v in var_list.items()]
            self.assertEqual(["hidden1/biases:0", "hidden1/weights:0"],
                             sorted(var_names))

            # Exports the rest with no unbound_inputs_col_name.
            orig_meta_graph2, _ = meta_graph.export_scoped_meta_graph(
                filename=os.path.join(test_dir, exported_filenames[1]),
                graph=ops.get_default_graph(),
                export_scope="hidden2",
                unbound_inputs_col_name=None)
            orig_meta_graph3, _ = meta_graph.export_scoped_meta_graph(
                filename=os.path.join(test_dir, exported_filenames[2]),
                graph=ops.get_default_graph(),
                export_scope="softmax_linear",
                unbound_inputs_col_name=None)

        return [orig_meta_graph1, orig_meta_graph2, orig_meta_graph3]
Example #32
  def _testScopedImport(self, test_dir, exported_filename,
                        new_exported_filename, ckpt_filename):
    graph = tf.Graph()
    # Create all the missing inputs.
    with graph.as_default():
      new_image = tf.constant(1.2, tf.float32, shape=[100, 28],
                              name="images")

    with self.assertRaisesRegex(ValueError, "Graph contains unbound inputs"):
      meta_graph.import_scoped_meta_graph(
          os.path.join(test_dir, exported_filename), graph=graph,
          import_scope="new_hidden1")

    with self.assertRaisesRegex(ValueError, "Graph contains unbound inputs"):
      meta_graph.import_scoped_meta_graph(
          os.path.join(test_dir, exported_filename), graph=graph,
          input_map={"image:0": new_image},
          import_scope="new_hidden1")

    var_list = meta_graph.import_scoped_meta_graph(
        os.path.join(test_dir, exported_filename), graph=graph,
        input_map={"$unbound_inputs_images": new_image},
        import_scope="new_hidden1")
    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
    new_var_names = [v.name for _, v in var_list.items()]
    self.assertEqual(["new_hidden1/biases:0", "new_hidden1/weights:0"],
                     sorted(new_var_names))
    hidden1 = graph.as_graph_element("new_hidden1/Relu:0")

    with graph.as_default():
      # Hidden 2
      with tf.name_scope("hidden2"):
        weights = tf.Variable(
            tf.truncated_normal([128, 32],
                                stddev=1.0 / math.sqrt(float(128))),
            name="weights")
        # The use of control_flow_ops.while_loop here is purely to add test
        # coverage for the save and restore of a control flow context (which
        # makes no sense here from a machine learning perspective). A typical
        # bias term would be a plain Variable without these conditions.
        def loop_cond(it, _):
          return it < 2
        def loop_body(it, biases):
          biases += tf.constant(0.1, shape=[32])
          return it + 1, biases
        _, biases = control_flow_ops.while_loop(
            loop_cond, loop_body,
            [tf.constant(0), tf.Variable(tf.zeros([32]))])
        hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
      # Linear
      with tf.name_scope("softmax_linear"):
        weights = tf.Variable(
            tf.truncated_normal([32, 10],
                                stddev=1.0 / math.sqrt(float(32))),
            name="weights")
        biases = tf.Variable(tf.zeros([10]), name="biases")
        logits = tf.matmul(hidden2, weights) + biases
        tf.add_to_collection("logits", logits)

      new_meta_graph, var_list = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, new_exported_filename),
          graph=graph, export_scope="new_hidden1")
      self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))

    return new_meta_graph
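
The "$unbound_inputs_images" key above reflects how scoped export handles tensors consumed from outside the export scope: they are rewritten as unbound inputs, and the importer rebinds them through input_map. A minimal sketch of the convention, assuming the same internal modules (note the key drops the ":0" output suffix):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops

# "hidden1/add" consumes "images", which lives outside the exported scope.
with ops.Graph().as_default() as g:
  images = constant_op.constant(1.0, dtypes.float32, shape=[4], name="images")
  with ops.name_scope("hidden1"):
    math_ops.add(images, 1.0, name="add")
mgd, _ = meta_graph.export_scoped_meta_graph(graph=g, export_scope="hidden1")

# On import, rebind the unbound input; the key is the original tensor name
# prefixed with "$unbound_inputs_" and without the ":0" suffix.
with ops.Graph().as_default() as g2:
  new_images = constant_op.constant(2.0, dtypes.float32, shape=[4], name="images")
  meta_graph.import_scoped_meta_graph(
      mgd,
      graph=g2,
      input_map={"$unbound_inputs_images": new_images},
      import_scope="new_hidden1")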
Example No. 38
  def _testScopedExport(self, test_dir, exported_filename, ckpt_filename):
    graph = tf.Graph()
    with graph.as_default():
      # Creates an inference graph.
      # Hidden 1
      colocate_constraint = tf.constant(1.2, name="constraint")
      images = tf.constant(1.2, tf.float32, shape=[100, 28], name="images")
      with tf.name_scope("hidden1"):
        with graph.colocate_with(colocate_constraint.op):
          weights1 = tf.Variable(
              tf.truncated_normal([28, 128],
                                  stddev=1.0 / math.sqrt(float(28))),
              name="weights")
        # The use of control_flow_ops.cond here is purely to add test coverage
        # for the save and restore of a control flow context (which makes no
        # sense here from a machine learning perspective). A typical bias term
        # would be a plain Variable without these conditions.
        biases1 = tf.Variable(
            control_flow_ops.cond(tf.less(random.random(), 0.5),
                                  lambda: tf.ones([128]),
                                  lambda: tf.zeros([128])),
            name="biases")
        hidden1 = tf.nn.relu(tf.matmul(images, weights1) + biases1)

      # Hidden 2
      with tf.name_scope("hidden2"):
        weights2 = tf.Variable(
            tf.truncated_normal([128, 32],
                                stddev=1.0 / math.sqrt(float(128))),
            name="weights")
        # The use of control_flow_ops.while_loop here is purely to add test
        # coverage for the save and restore of a control flow context (which
        # makes no sense here from a machine learning perspective). A typical
        # bias term would be a plain Variable without these conditions.
        def loop_cond(it, _):
          return it < 2
        def loop_body(it, biases2):
          biases2 += tf.constant(0.1, shape=[32])
          return it + 1, biases2
        _, biases2 = control_flow_ops.while_loop(
            loop_cond, loop_body,
            [tf.constant(0), tf.Variable(tf.zeros([32]))])
        hidden2 = tf.nn.relu(tf.matmul(hidden1, weights2) + biases2)
      # Linear
      with tf.name_scope("softmax_linear"):
        weights3 = tf.Variable(
            tf.truncated_normal([32, 10],
                                stddev=1.0 / math.sqrt(float(32))),
            name="weights")
        biases3 = tf.Variable(tf.zeros([10]), name="biases")
        logits = tf.matmul(hidden2, weights3) + biases3
        tf.add_to_collection("logits", logits)
      orig_meta_graph, var_list = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, exported_filename),
          graph=tf.get_default_graph(), export_scope="hidden1")
      self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
      var_names = [v.name for _, v in var_list.items()]
      self.assertEqual(["hidden1/biases:0", "hidden1/weights:0"],
                       sorted(var_names))

    return orig_meta_graph
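
The colocate_with block above is there so the export covers colocation constraints. In TF 1.x graphs, colocation is recorded as a "loc:@<op>" entry in the op's colocation groups, which is what a scoped export has to carry along. A minimal sketch, assuming the same internal modules:

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

with ops.Graph().as_default() as g:
  constraint = constant_op.constant(1.2, name="constraint")
  with ops.name_scope("hidden1"):
    with g.colocate_with(constraint.op):
      weights = array_ops.zeros([28, 128], name="weights")

# The constraint shows up as a colocation group on the op.
print(weights.op.colocation_groups())  # [b'loc:@constraint']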
Example No. 39
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("package_names", type=str, nargs='*')
    parser.add_argument(
        "--root",
        metavar='DIR',
        type=str,
        help="""Specify root directory to search for imports from""")
    parser.add_argument(
        "--source",
        type=str,
        help="""Specify source code instead of reading from file""")

    parser.add_argument("--reopen-stderr",
                        metavar='FILE',
                        type=argparse.FileType('a', encoding='UTF-8'))
    parser.add_argument("--reopen-stdout",
                        metavar='FILE',
                        type=argparse.FileType('a', encoding='UTF-8'))

    parser.add_argument("--assets-fetch",
                        default=False,
                        action='store_const',
                        const=True,
                        help="""Fetch any assets we don't already have.""")
    parser.add_argument("--assets-root",
                        metavar='DIR',
                        type=str,
                        help="""Specify root directory for assets.""")

    parser.add_argument("--metagraphdef",
                        metavar='FILE',
                        type=str,
                        help="""Graph file to load.""")
    parser.add_argument("--binary-metagraphdef",
                        default=False,
                        action='store_const',
                        const=True,
                        help="""Whether or not input is binary.""")
    parser.add_argument(
        "--feed-constants",
        metavar='FILE',
        type=str,
        help="""Path to GraphDef protobuf with constants to feed""")
    parser.add_argument(
        "--feed-constants-strip",
        metavar='PREFIX',
        type=str,
        default="",
        help="""Prefix to filter for (and strip from) constants""")
    parser.add_argument("--feed-constants-prefix",
                        metavar='PREFIX',
                        type=str,
                        help="""Prefix to add to constant names in feed""")
    parser.add_argument(
        "--feed-constants-binary",
        default=False,
        action='store_const',
        const=True,
        help="""Whether or not feed constant protobuf is binary""")

    parser.add_argument(
        "--run",
        default=False,
        action='store_const',
        const=True,
        help=
        """Run the graph with given (or default) --result* and --feed-* options"""
    )
    parser.add_argument("--run-result-pattern",
                        metavar='PATTERN',
                        type=str,
                        default="^(${package}/Main)/outputs/(.*)$",
                        help="""Pattern to discover run results.""")
    parser.add_argument("--result-binary",
                        default=False,
                        action='store_const',
                        const=True,
                        help="""Whether or not to result in binary.""")
    parser.add_argument("--result",
                        metavar='FILE',
                        type=str,
                        default="/dev/stdout")

    parser.add_argument(
        "--test",
        default=False,
        action='store_const',
        const=True,
        help="""Run the tests graphs with given (or default) --test-* options"""
    )
    parser.add_argument("--test-result-pattern",
                        metavar='PATTERN',
                        type=str,
                        default="^(${package}/Test[^/]*)/outputs/(.*)$",
                        help="""Pattern to discover test graph results.""")

    parser.add_argument("--repl",
                        default=False,
                        action='store_const',
                        const=True,
                        help="""Start REPL""")

    parser.add_argument(
        "--tensorboard",
        nargs='?',
        default="",
        metavar="IP:PORT",
        help=
        """Start tensorboard server on the given address, with the given --log-root or --log-dir"""
    )

    parser.add_argument(
        "--jupyter-kernel",
        nargs='?',
        default="",
        metavar="CONFIG_FILE",
        help="""Start Jupyter kernel with the given configuration file""")

    parser.add_argument(
        "--train",
        default=False,
        action='store_const',
        const=True,
        help="""Run train graphs with given (or default) --train-* options""")
    parser.add_argument("--train-result-pattern",
                        metavar='PATTERN',
                        type=str,
                        default="^(${package}/Train[^/]*)/outputs/(.*)$",
                        help="""Pattern to discover train graph results.""")

    parser.add_argument("--workspace",
                        metavar='DIR',
                        type=str,
                        help="""Default value for workspace""")
    parser.add_argument(
        "--log-root",
        metavar='DIR',
        type=str,
        help="""Which directory to calculate default log dir from.""")
    parser.add_argument("--log-dir",
                        metavar='DIR',
                        type=str,
                        help="""Which directory to put logs in.""")

    parser.add_argument("--output",
                        default=False,
                        action='store_const',
                        const=True,
                        help="""Output graph""")
    parser.add_argument(
        "--output-root",
        metavar='DIR',
        type=str,
        help="""When automatically constructing output path, use this as base"""
    )
    parser.add_argument(
        "--output-name",
        metavar='NAME',
        type=str,
        help=
        """Base name to use for output file name. Defaults to ${package} if there's only one."""
    )
    parser.add_argument(
        "--output-result-pattern",
        metavar='PATTERN',
        type=str,
        default="^(${package}/[^/]*)(/outputs/[^/]*)?$",
        help="""Pattern to discover outputs of graph to output.""")
    parser.add_argument("--output-format",
                        metavar='FORMAT',
                        type=str,
                        default="metagraph",
                        help="""Defaults to metagraph""")
    parser.add_argument("--output-binary",
                        default=False,
                        action='store_const',
                        const=True,
                        help="""Whether or not to output in binary.""")
    parser.add_argument(
        "--output-file",
        metavar='FILE',
        type=str,
        help=
        """Path to write output to. Defaults to ${output-name}.${output-format}"""
    )

    FLAGS = parser.parse_args()

    if FLAGS.reopen_stderr:
        os.close(sys.stderr.fileno())
        os.dup2(FLAGS.reopen_stderr.fileno(), sys.stderr.fileno())

    if FLAGS.reopen_stdout:
        os.close(sys.stdout.fileno())
        os.dup2(FLAGS.reopen_stdout.fileno(), sys.stdout.fileno())

    package_names = FLAGS.package_names

    should_parse = len(package_names) > 0 or FLAGS.source
    if not (should_parse or FLAGS.run or FLAGS.test or FLAGS.output):
        if os.isatty(1):
            FLAGS.repl = True

    if should_parse and not (FLAGS.repl or FLAGS.run or FLAGS.test
                             or FLAGS.output):
        FLAGS.output = True

    def search_upwards(startdir, filename):
        curdir = startdir
        while True:
            if path.exists(path.join(curdir, filename)):
                return curdir
            lastdir = curdir
            curdir = path.dirname(curdir)
            if curdir == lastdir:
                return None

    if not FLAGS.workspace:
        FLAGS.workspace = os.environ.get("NAOPATH", "")
        if not FLAGS.workspace:
            FLAGS.workspace = search_upwards(os.getcwd(), ".naoconfig")
        if not FLAGS.workspace:
            FLAGS.workspace = "."

    if FLAGS.assets_root is None:
        FLAGS.assets_root = path.join(FLAGS.workspace, "assets")

    if FLAGS.output_root is None:
        FLAGS.output_root = path.join(FLAGS.workspace, "pkg")

    if FLAGS.root is None:
        FLAGS.root = path.join(FLAGS.workspace, "src")

    if FLAGS.log_root is None:
        FLAGS.log_root = path.join(FLAGS.workspace, "log")

    if FLAGS.tensorboard is None:
        FLAGS.tensorboard = "127.0.0.1:6006"

    def log_dir_fn_fn(pkg_names):
        if FLAGS.log_dir:
            return lambda: FLAGS.log_dir

        session_name = datetime.datetime.utcnow().strftime("%F_%H-%M-%S")
        base_log_dir = path.join(FLAGS.log_root, pkg_names[0], session_name)

        def log_dir_fn(run_id=None):
            log_dir = base_log_dir
            if run_id is not None:
                log_dir = path.join(log_dir, "%04d" % run_id)
            return log_dir

        return log_dir_fn

    def new_compiler():
        return Compiler(FLAGS.root, FLAGS.output_root, FLAGS.assets_root)

    meta_graph_def = None

    output_package_names = None

    if should_parse:
        p = new_compiler()
        if FLAGS.source:
            package_name = "main"
            package_names = [package_name]
            p.put_source(package_name + ".nao", FLAGS.source)
            p.resolve_import_path(package_name)
        else:
            # Look for matching packages _train
            if FLAGS.train:
                output_package_names = package_names[:]
                package_names.extend([pkg + "_train" for pkg in package_names])

            for package_name in package_names:
                p.resolve_import_path(package_name)

        meta_graph_def = p.meta_graph_def()
        p = None
        # print("parsed", expressions)
        # We need to do this so we clean up references to py_funcs. LAME.
        gc.collect()

    # Sometimes we want to output different packages than we're testing, training, etc.
    if output_package_names is None:
        output_package_names = package_names

    if not FLAGS.output_name and len(output_package_names) == 1:
        FLAGS.output_name = output_package_names[0]
        if FLAGS.train:
            FLAGS.output_name += "_trained"

    if FLAGS.metagraphdef:
        package_names = ["[^/]+"]
        meta_graph_def = graph_io.read_meta_graph_def(
            FLAGS.metagraphdef, FLAGS.binary_metagraphdef)

    if FLAGS.output and FLAGS.output_name and not FLAGS.output_file:
        output_suffix = "." + FLAGS.output_format + ".pb"
        if not FLAGS.output_binary:
            output_suffix += "txt"
        FLAGS.output_file = FLAGS.output_root + "/" + FLAGS.output_name + output_suffix

    # Now that we know our package names, use them to target the proper results.
    package_pattern = "(?:" + str.join("|", package_names) + ")"
    FLAGS.test_result_pattern = FLAGS.test_result_pattern.replace(
        "${package}", package_pattern)
    FLAGS.train_result_pattern = FLAGS.train_result_pattern.replace(
        "${package}", package_pattern)
    FLAGS.run_result_pattern = FLAGS.run_result_pattern.replace(
        "${package}", package_pattern)

    output_package_pattern = "(?:" + "|".join(output_package_names) + ")"
    FLAGS.output_result_pattern = FLAGS.output_result_pattern.replace(
        "${package}", output_package_pattern)
    eprint("FLAGS", FLAGS)
    eprint("package_names", package_names)

    if FLAGS.tensorboard != "":
        tb_host, _, tb_port = FLAGS.tensorboard.partition(':')
        tb_logdir = FLAGS.log_dir or FLAGS.log_root
        tb_port = int(tb_port) if tb_port else None

        from nao.tool import tensorboard_server
        sys.exit(
            tensorboard_server.main(tb_logdir,
                                    tb_host=tb_host,
                                    tb_port=tb_port))

    if FLAGS.jupyter_kernel != "":
        jupyter_config_file = FLAGS.jupyter_kernel
        from nao.tool import jupyter_kernel, jupyter_kernel_driver

        if jupyter_config_file:
            eprint("Reading jupyter_config file '%s'..." % jupyter_config_file)
            with open(jupyter_config_file) as f:
                jupyter_config = json.load(f)
        else:
            import uuid
            jupyter_config = {
                'control_port': 0,
                'hb_port': 0,
                'iopub_port': 0,
                'ip': '127.0.0.1',
                'key': str(uuid.uuid4()),
                'shell_port': 0,
                'signature_scheme': 'hmac-sha256',
                'stdin_port': 0,
                'transport': 'tcp'
            }

        pallet_parser = new_compiler()
        repl_session = graph_repl.ReplSession(pallet_parser,
                                              log_dir_fn_fn(["jupyter"]))
        driver = jupyter_kernel_driver.Driver(repl_session)
        sys.exit(
            jupyter_kernel.Kernel(jupyter_config, driver.info(),
                                  driver.do).run())

    def feed_dict_fn():
        feed_dict = {}
        # Find constants under the strip prefix, remove that prefix, and feed
        # them into feed_dict under the add prefix.
        if FLAGS.feed_constants:
            feed_graph_def = graph_io.read_graph_def(
                FLAGS.feed_constants, FLAGS.feed_constants_binary)
            constants = graph_query.find_nodes_with_prefix(
                feed_graph_def, FLAGS.feed_constants_strip)
            constants_dict = graph_xform.constants_as_dict(constants)
            strip_prefix = FLAGS.feed_constants_strip
            add_prefix = FLAGS.feed_constants_prefix
            for name, value in constants_dict.items():
                if strip_prefix is not None:
                    if name.startswith(strip_prefix):
                        name = name[len(strip_prefix):]
                    else:
                        continue
                feed_dict[add_prefix + name + ":0"] = value

        asset_map = graph_assets.load_asset_map(tf.get_default_graph())
        eprint("asset_map", asset_map)

        assets_by_path = {}
        missing_assets = {}
        for asset_name, asset in asset_map.items():
            asset_path = path.join(FLAGS.assets_root, asset_name)
            assets_by_path[asset_path] = asset
            feed_dict[asset["placeholder"]] = asset_path
            if not os.path.exists(asset_path):
                missing_assets[asset_path] = asset

        if missing_assets:
            if not FLAGS.assets_fetch:
                raise Exception("Missing assets: %s" % missing_assets)

            for asset_path, asset in missing_assets.items():
                graph_assets.maybe_download(asset_path, asset["url"])

        eprint("feed_dict", feed_dict)
        return feed_dict

    if FLAGS.train:

        def post_train(session, result_scope_prefixes):
            graph = session.graph

            trained_var_name_bs = set()
            for result_scope_prefix in result_scope_prefixes:
                collection_name = "%s:variable_names" % result_scope_prefix
                eprint("collection_name", collection_name)
                for var_name_b in graph.get_collection_ref(collection_name):
                    trained_var_name_bs.add(var_name_b)

            var_names = [b.decode('utf-8') for b in trained_var_name_bs]
            trained_vars = graph_query.find_variables_by_name(
                graph.get_collection_ref("variables"), var_names)
            eprint("saving vars", var_names, trained_vars)
            graph_xform.replace_variable_initializers_with_current_values(
                graph, trained_vars, "Trained")

        graph_execution.import_and_run_meta_graph(
            meta_graph_def=meta_graph_def,
            feed_dict_fn=feed_dict_fn,
            result_pattern=re.compile(FLAGS.train_result_pattern),
            finish_session_fn=post_train,
            log_dir_fn=lambda x: log_dir_fn_fn(x)(),
        )
        meta_graph_def, _ = meta_graph.export_scoped_meta_graph()

    if FLAGS.test:
        graph_execution.import_and_run_meta_graph(
            meta_graph_def=meta_graph_def,
            feed_dict_fn=feed_dict_fn,
            log_dir_fn=lambda x: log_dir_fn_fn(x)(),
            result_pattern=re.compile(FLAGS.test_result_pattern),
        )

    if meta_graph_def and FLAGS.output_file:
        eprint("meta_graph_def",
               [n.name for n in meta_graph_def.graph_def.node])
        graph_def = meta_graph_def.graph_def
        output_re = re.compile(FLAGS.output_result_pattern)
        output_node_names = ['py_funcs_json']  # HACK(adamb) So that pyfuncs still work.
        var_names = set()
        for n in graph_def.node:
            m = output_re.match(n.name)
            if not m:
                continue
            output_node_names.append(n.name)

            # If this isn't a function, then we're covered. Otherwise pick up needed
            # variables.
            if not m.group(2):
                continue

            # Look for collection of variable names referenced by this function.
            collection_name = "%s:variable_names" % m.group(1)
            eprint("collection_name", collection_name)
            function_var_name_bs = meta_graph_def.collection_def[
                collection_name].bytes_list.value
            for var_name_b in function_var_name_bs:
                # Remember the name of each variable referenced.
                var_names.add(var_name_b.decode('utf-8'))

        eprint("var_names", var_names)
        eprint("output_node_names", output_node_names)
        graph_xform.strip_meta_graph(meta_graph_def, output_node_names,
                                     var_names)

    if FLAGS.output_file:
        output_dirname = os.path.dirname(FLAGS.output_file)
        if output_dirname and not os.path.exists(output_dirname):
            os.makedirs(output_dirname)

        if FLAGS.output_format == "metagraph":
            graph_io.write_meta_graph_def(meta_graph_def=meta_graph_def,
                                          file=FLAGS.output_file,
                                          binary=FLAGS.output_binary)
        elif FLAGS.output_format == "graph":
            # If we trained and we're outputting a graph_def, we'll need to
            # replace all the trained variables with the *constants* their
            # initializers refer to.
            if FLAGS.train:
                pass
            graph_io.write_graph_def(graph_def=meta_graph_def.graph_def,
                                     file=FLAGS.output_file,
                                     binary=FLAGS.output_binary)

    if FLAGS.run:
        results = graph_execution.import_and_run_meta_graph(
            meta_graph_def=meta_graph_def,
            feed_dict_fn=feed_dict_fn,
            log_dir_fn=lambda x: log_dir_fn_fn(x)(),
            result_pattern=re.compile(FLAGS.run_result_pattern),
        )

        graph_def = graph_xform.dict_as_graph_def(results)
        graph_io.write_graph_def(
            graph_def,
            file=FLAGS.result,
            binary=FLAGS.result_binary,
        )

    if FLAGS.repl:
        graph_repl.run(new_compiler(), log_dir_fn_fn(["repl"]))
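
One stylistic note on the flags above: the recurring pattern default=False, action='store_const', const=True is the long-hand spelling of argparse's built-in store_true action; the two behave identically for these flags. A minimal sketch:

import argparse

parser = argparse.ArgumentParser()
# Long-hand boolean flag, as used throughout main() above.
parser.add_argument("--run", default=False, action="store_const", const=True)
# Built-in shorthand with the same observable behavior.
parser.add_argument("--test", action="store_true")

args = parser.parse_args(["--run"])
assert args.run is True and args.test is False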
Example No. 40
  def _testScopedImport(self, test_dir, exported_filenames):
    graph = ops.Graph()
    # Create all the missing inputs.
    with graph.as_default():
      new_image = constant_op.constant(
          1.2, dtypes.float32, shape=[100, 28], name="images")

    with self.assertRaisesRegexp(ValueError, "Graph contains unbound inputs"):
      meta_graph.import_scoped_meta_graph(
          os.path.join(test_dir, exported_filenames[0]),
          graph=graph,
          import_scope="new_hidden1")

    with self.assertRaisesRegexp(ValueError, "Graph contains unbound inputs"):
      meta_graph.import_scoped_meta_graph(
          os.path.join(test_dir, exported_filenames[0]),
          graph=graph,
          input_map={"image:0": new_image},
          import_scope="new_hidden1")

    # Verifies we can import the original "hidden1" into "new_hidden1".
    var_list = meta_graph.import_scoped_meta_graph(
        os.path.join(test_dir, exported_filenames[0]),
        graph=graph,
        input_map={"$unbound_inputs_images": new_image},
        import_scope="new_hidden1")

    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
    new_var_names = [v.name for _, v in var_list.items()]
    self.assertEqual(["new_hidden1/biases:0", "new_hidden1/weights:0"],
                     sorted(new_var_names))

    # Verifies we can import the original "hidden2" into "new_hidden2".
    hidden1 = array_ops.identity(
        graph.as_graph_element("new_hidden1/Relu:0"), name="hidden1/Relu")
    var_list = meta_graph.import_scoped_meta_graph(
        os.path.join(test_dir, exported_filenames[1]),
        graph=graph,
        input_map={"$unbound_inputs_hidden1/Relu": hidden1},
        import_scope="new_hidden2",
        unbound_inputs_col_name=None)

    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
    new_var_names = [v.name for _, v in var_list.items()]
    self.assertEqual(["new_hidden2/biases:0", "new_hidden2/weights:0"],
                     sorted(new_var_names))

    # Verifies we can import the original "softmax_linear" into
    # "new_softmax_linear".
    hidden2 = array_ops.identity(
        graph.as_graph_element("new_hidden2/Relu:0"), name="hidden2/Relu")
    var_list = meta_graph.import_scoped_meta_graph(
        os.path.join(test_dir, exported_filenames[2]),
        graph=graph,
        input_map={"$unbound_inputs_hidden2/Relu": hidden2},
        import_scope="new_softmax_linear",
        unbound_inputs_col_name=None)
    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
    new_var_names = [v.name for _, v in var_list.items()]
    self.assertEqual(
        ["new_softmax_linear/biases:0", "new_softmax_linear/weights:0"],
        sorted(new_var_names))

    # Exports the scoped meta graphs again.
    new_meta_graph1, var_list = meta_graph.export_scoped_meta_graph(
        graph=graph, export_scope="new_hidden1")
    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))

    new_meta_graph2, var_list = meta_graph.export_scoped_meta_graph(
        graph=graph, export_scope="new_hidden2", unbound_inputs_col_name=None)
    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))

    new_meta_graph3, var_list = meta_graph.export_scoped_meta_graph(
        graph=graph,
        export_scope="new_softmax_linear",
        unbound_inputs_col_name=None)
    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))

    return [new_meta_graph1, new_meta_graph2, new_meta_graph3]
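
Because import_scoped_meta_graph prefixes everything with import_scope, the same exported MetaGraphDef can be imported repeatedly into one graph to create independent copies of a sub-graph. A minimal sketch, assuming the same internal modules; the printed names are what the scoping rules above suggest:

from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables

with ops.Graph().as_default() as g:
  with ops.name_scope("hidden1"):
    variables.Variable(array_ops.zeros([3]), name="biases")
mgd, _ = meta_graph.export_scoped_meta_graph(graph=g, export_scope="hidden1")

# Two imports of the same sub-graph, kept apart by their import scopes.
with ops.Graph().as_default() as g2:
  meta_graph.import_scoped_meta_graph(mgd, graph=g2, import_scope="copy1")
  meta_graph.import_scoped_meta_graph(mgd, graph=g2, import_scope="copy2")
  print(sorted(v.name for v in g2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
  # Expected: ['copy1/biases:0', 'copy2/biases:0']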
Example No. 41
  def _testScopedExport(self, test_dir, exported_filenames):
    graph = ops.Graph()
    with graph.as_default():
      # Creates an inference graph.
      # Hidden 1
      colocate_constraint = constant_op.constant(1.2, name="constraint")
      images = constant_op.constant(
          1.2, dtypes.float32, shape=[100, 28], name="images")
      with ops.name_scope("hidden1"):
        with graph.colocate_with(colocate_constraint.op):
          weights1 = variables.Variable(
              random_ops.truncated_normal(
                  [28, 128], stddev=1.0 / math.sqrt(float(28))),
              name="weights")
        # The use of control_flow_ops.cond here is purely to add test coverage
        # for the save and restore of a control flow context (which makes no
        # sense here from a machine learning perspective). A typical bias term
        # would be a plain Variable without these conditions.
        biases1 = variables.Variable(
            control_flow_ops.cond(
                math_ops.less(random.random(), 0.5),
                lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
            name="biases")
        hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1)

      # Hidden 2
      with ops.name_scope("hidden2"):
        weights2 = variables.Variable(
            random_ops.truncated_normal(
                [128, 32], stddev=1.0 / math.sqrt(float(128))),
            name="weights")

        # The use of control_flow_ops.while_loop here is purely to add test
        # coverage for the save and restore of a control flow context (which
        # makes no sense here from a machine learning perspective). A typical
        # bias term would be a plain Variable without these conditions.
        def loop_cond(it, _):
          return it < 2

        def loop_body(it, biases2):
          biases2 += constant_op.constant(0.1, shape=[32])
          return it + 1, biases2

        _, biases2 = control_flow_ops.while_loop(
            loop_cond,
            loop_body, [
                constant_op.constant(0), variables.Variable(
                    array_ops.zeros([32]), name="biases")
            ])
        hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2)
      # Linear
      with ops.name_scope("softmax_linear"):
        weights3 = variables.Variable(
            random_ops.truncated_normal(
                [32, 10], stddev=1.0 / math.sqrt(float(32))),
            name="weights")
        biases3 = variables.Variable(array_ops.zeros([10]), name="biases")
        logits = math_ops.matmul(hidden2, weights3) + biases3
        ops.add_to_collection("logits", logits)

      # Exports each sub-graph.
      # Exports the first one with unbound_inputs_col_name set to default.
      orig_meta_graph1, var_list = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, exported_filenames[0]),
          graph=ops.get_default_graph(),
          export_scope="hidden1")
      self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
      var_names = [v.name for _, v in var_list.items()]
      self.assertEqual(["hidden1/biases:0", "hidden1/weights:0"],
                       sorted(var_names))

      # Exports the rest with no unbound_inputs_col_name.
      orig_meta_graph2, _ = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, exported_filenames[1]),
          graph=ops.get_default_graph(),
          export_scope="hidden2",
          unbound_inputs_col_name=None)
      orig_meta_graph3, _ = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, exported_filenames[2]),
          graph=ops.get_default_graph(),
          export_scope="softmax_linear",
          unbound_inputs_col_name=None)

    return [orig_meta_graph1, orig_meta_graph2, orig_meta_graph3]
Example No. 44
    def _testScopedImport(self, test_dir, exported_filenames):
        graph = ops.Graph()
        # Create all the missing inputs.
        with graph.as_default():
            new_image = constant_op.constant(1.2,
                                             dtypes.float32,
                                             shape=[100, 28],
                                             name="images")

        with self.assertRaisesRegexp(ValueError,
                                     "Graph contains unbound inputs"):
            meta_graph.import_scoped_meta_graph(os.path.join(
                test_dir, exported_filenames[0]),
                                                graph=graph,
                                                import_scope="new_hidden1")

        with self.assertRaisesRegexp(ValueError,
                                     "Graph contains unbound inputs"):
            meta_graph.import_scoped_meta_graph(
                os.path.join(test_dir, exported_filenames[0]),
                graph=graph,
                input_map={"image:0": new_image},
                import_scope="new_hidden1")

        # Verifies we can import the original "hidden1" into "new_hidden1".
        var_list = meta_graph.import_scoped_meta_graph(
            os.path.join(test_dir, exported_filenames[0]),
            graph=graph,
            input_map={"$unbound_inputs_images": new_image},
            import_scope="new_hidden1")

        self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
        new_var_names = [v.name for _, v in var_list.items()]
        self.assertEqual(["new_hidden1/biases:0", "new_hidden1/weights:0"],
                         sorted(new_var_names))

        # Verifies we can import the original "hidden2" into "new_hidden2".
        hidden1 = array_ops.identity(
            graph.as_graph_element("new_hidden1/Relu:0"), name="hidden1/Relu")
        var_list = meta_graph.import_scoped_meta_graph(
            os.path.join(test_dir, exported_filenames[1]),
            graph=graph,
            input_map={"$unbound_inputs_hidden1/Relu": hidden1},
            import_scope="new_hidden2",
            unbound_inputs_col_name=None)

        self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
        new_var_names = [v.name for _, v in var_list.items()]
        self.assertEqual(["new_hidden2/biases:0", "new_hidden2/weights:0"],
                         sorted(new_var_names))

        # Verifies we can import the original "softmax_linear" into
        # "new_softmax_linear".
        hidden2 = array_ops.identity(
            graph.as_graph_element("new_hidden2/Relu:0"), name="hidden2/Relu")
        var_list = meta_graph.import_scoped_meta_graph(
            os.path.join(test_dir, exported_filenames[2]),
            graph=graph,
            input_map={"$unbound_inputs_hidden2/Relu": hidden2},
            import_scope="new_softmax_linear",
            unbound_inputs_col_name=None)
        self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
        new_var_names = [v.name for _, v in var_list.items()]
        self.assertEqual(
            ["new_softmax_linear/biases:0", "new_softmax_linear/weights:0"],
            sorted(new_var_names))

        # Exports the scoped meta graphs again.
        new_meta_graph1, var_list = meta_graph.export_scoped_meta_graph(
            graph=graph, export_scope="new_hidden1")
        self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))

        new_meta_graph2, var_list = meta_graph.export_scoped_meta_graph(
            graph=graph,
            export_scope="new_hidden2",
            unbound_inputs_col_name=None)
        self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))

        new_meta_graph3, var_list = meta_graph.export_scoped_meta_graph(
            graph=graph,
            export_scope="new_softmax_linear",
            unbound_inputs_col_name=None)
        self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))

        return [new_meta_graph1, new_meta_graph2, new_meta_graph3]
Example No. 45
    def testDefaultAttrStripping(self):
        """Verifies that default attributes are stripped from a graph def."""

        # Complex Op has 2 attributes with defaults:
        #   o "T"    : float32.
        #   o "Tout" : complex64.

        # When inputs to the Complex Op are float32 instances, "T" maps to float32
        # and "Tout" maps to complex64. Since these attr values map to their
        # defaults, they must be stripped unless stripping of default attrs is
        # disabled.
        with self.test_session():
            real_num = constant_op.constant(1.0,
                                            dtype=dtypes.float32,
                                            name="real")
            imag_num = constant_op.constant(2.0,
                                            dtype=dtypes.float32,
                                            name="imag")
            math_ops.complex(real_num, imag_num, name="complex")

            # strip_default_attrs is enabled.
            meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
                graph_def=ops.get_default_graph().as_graph_def(),
                strip_default_attrs=True)
            node_def = test_util.get_node_def_from_graph(
                "complex", meta_graph_def.graph_def)
            self.assertNotIn("T", node_def.attr)
            self.assertNotIn("Tout", node_def.attr)
            self.assertTrue(
                meta_graph_def.meta_info_def.stripped_default_attrs)

            # strip_default_attrs is disabled.
            meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
                graph_def=ops.get_default_graph().as_graph_def(),
                strip_default_attrs=False)
            node_def = test_util.get_node_def_from_graph(
                "complex", meta_graph_def.graph_def)
            self.assertIn("T", node_def.attr)
            self.assertIn("Tout", node_def.attr)
            self.assertFalse(
                meta_graph_def.meta_info_def.stripped_default_attrs)

        # When inputs to the Complex Op are float64 instances, "T" maps to float64
        # and "Tout" maps to complex128. Since these attr values don't map to their
        # defaults, they must not be stripped.
        with self.session(graph=ops.Graph()):
            real_num = constant_op.constant(1.0,
                                            dtype=dtypes.float64,
                                            name="real")
            imag_num = constant_op.constant(2.0,
                                            dtype=dtypes.float64,
                                            name="imag")
            math_ops.complex(real_num, imag_num, name="complex")
            meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
                graph_def=ops.get_default_graph().as_graph_def(),
                strip_default_attrs=True)
            node_def = test_util.get_node_def_from_graph(
                "complex", meta_graph_def.graph_def)
            self.assertEqual(node_def.attr["T"].type, dtypes.float64)
            self.assertEqual(node_def.attr["Tout"].type, dtypes.complex128)
            self.assertTrue(
                meta_graph_def.meta_info_def.stripped_default_attrs)
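
For reference, the stripping check above condenses to a short standalone sketch, assuming the same internal modules (the test_util helper is replaced by a plain generator expression):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops

with ops.Graph().as_default() as g:
  real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
  imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
  math_ops.complex(real_num, imag_num, name="complex")
  mgd, _ = meta_graph.export_scoped_meta_graph(
      graph_def=g.as_graph_def(), strip_default_attrs=True)

node = next(n for n in mgd.graph_def.node if n.name == "complex")
# "T" (float32) and "Tout" (complex64) match their defaults, so both are gone.
assert "T" not in node.attr and "Tout" not in node.attr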