def test_concatenate_inputs_and_outputs_no_arg_graphs(self):
  """Concatenating two argument-free constant graphs yields both constants."""

  def _constant_spec(value):
    # Wraps a graph holding one constant (plus an init op) in a GraphSpec.
    graph = tf.Graph()
    with graph.as_default():
      constant = tf.constant(value)
      init_name = tf.compat.v1.global_variables_initializer().name
    spec = graph_spec.GraphSpec(graph.as_graph_def(), init_name, [],
                                [constant.name])
    return spec, constant.name

  first_spec, first_out = _constant_spec(1.0)
  second_spec, second_out = _constant_spec(2.0)

  merged_graph, init_op_name, _, out_name_maps = (
      graph_merge.concatenate_inputs_and_outputs([first_spec, second_spec]))
  fetches = [out_name_maps[0][first_out], out_name_maps[1][second_out]]

  with merged_graph.as_default():
    with tf.compat.v1.Session() as sess:
      sess.run(init_op_name)
      outputs = sess.run(fetches)
  self.assertAllClose(outputs, np.array([1., 2.]))
def test_concatenate_inputs_and_outputs_two_add_one_graphs(self):
  """Concatenated add-one graphs are fed and fetched independently."""
  graph_a, in_a, out_a = _make_add_one_graph()
  graph_b, in_b, out_b = _make_add_one_graph()

  def _init_op_name(graph):
    # Name of a fresh global-variables initializer added to `graph`.
    with graph.as_default():
      return tf.compat.v1.global_variables_initializer().name

  init_a = _init_op_name(graph_a)
  init_b = _init_op_name(graph_b)
  spec_a = graph_spec.GraphSpec(graph_a.as_graph_def(), init_a, [in_a],
                                [out_a])
  spec_b = graph_spec.GraphSpec(graph_b.as_graph_def(), init_b, [in_b],
                                [out_b])

  merged_graph, init_op_name, in_name_maps, out_name_maps = (
      graph_merge.concatenate_inputs_and_outputs([spec_a, spec_b]))
  feeds = {in_name_maps[0][in_a]: 1.0, in_name_maps[1][in_b]: 2.0}
  fetches = [out_name_maps[0][out_a], out_name_maps[1][out_b]]

  with merged_graph.as_default():
    with tf.compat.v1.Session() as sess:
      sess.run(init_op_name)
      outputs = sess.run(fetches, feed_dict=feeds)
  self.assertAllClose(outputs, np.array([2., 3.]))
def test_concatenate_inputs_and_outputs_with_dataset_wires_correctly(self):
  """Two reduce graphs over the same dataset graph each evaluate to 10."""
  dataset_graph, _, dataset_out_name = _make_dataset_constructing_graph()
  reduce_graphs = [
      _make_manual_reduce_graph(dataset_graph, dataset_out_name)
      for _ in range(2)
  ]

  specs = []
  out_names = []
  for reduce_graph, _, reduce_out_name in reduce_graphs:
    with reduce_graph.as_default():
      init_name = tf.compat.v1.global_variables_initializer().name
    specs.append(
        graph_spec.GraphSpec(reduce_graph.as_graph_def(), init_name, [],
                             [reduce_out_name]))
    out_names.append(reduce_out_name)

  merged_graph, init_op_name, _, out_name_maps = (
      graph_merge.concatenate_inputs_and_outputs(specs))
  fetches = [out_name_maps[i][name] for i, name in enumerate(out_names)]

  with merged_graph.as_default():
    with tf.compat.v1.Session() as sess:
      sess.run(init_op_name)
      tens = sess.run(fetches)
  self.assertEqual(tens, [10, 10])
def test_compose_three_add_one_graphs_adds_three(self):
  """Composing three add-one graphs maps an input of 0.0 to 3.0."""
  built = [_make_add_one_graph() for _ in range(3)]

  specs = []
  for graph, in_name, out_name in built:
    with graph.as_default():
      init_name = tf.compat.v1.global_variables_initializer().name
    specs.append(
        graph_spec.GraphSpec(graph.as_graph_def(), init_name, [in_name],
                             [out_name]))

  composed_graph, init_op_name, in_name_map, out_name_map = (
      graph_merge.compose_graph_specs(specs))
  first_input_name = built[0][1]
  last_output_name = built[-1][2]

  with composed_graph.as_default():
    with tf.compat.v1.Session() as sess:
      sess.run(init_op_name)
      outputs = sess.run(
          out_name_map[last_output_name],
          feed_dict={in_name_map[first_input_name]: 0.0})
  self.assertAllClose(outputs, np.array(3.))
def optimize_graph_spec(graph_spec_obj, config_proto):
  """Applies Grappler with given options to a `graph_spec.GraphSpec`.

  For more information on Grappler, see
  https://www.tensorflow.org/guide/graph_optimization

  The spec is converted with `graph_spec_obj.to_meta_graph_def()` and handed to
  `tf_optimizer.OptimizeGraph`. If Grappler raises a `ValueError`, the error is
  logged and the original, unoptimized `GraphDef` is used instead.

  Args:
    graph_spec_obj: Instance of `graph_spec.GraphSpec` representing the
      TensorFlow computation to optimize.
    config_proto: Instance of `tf.compat.v1.ConfigProto` specifying
      optimization options for Grappler.

  Returns:
    A new instance of `graph_spec.GraphSpec`. Its `graph_def` is the
    Grappler-optimized graph, or `graph_spec_obj.graph_def` if Grappler raised
    a `ValueError`. Its `init_op`, `in_names` and `out_names` are carried over
    unchanged from `graph_spec_obj`.
  """
  meta_graph_def = graph_spec_obj.to_meta_graph_def()
  try:
    # Grappler raises if it fails to find feeds and fetches, but can handle
    # *some* no-arg graphs, so we try/except here.
    optimized_graph_def = tf_optimizer.OptimizeGraph(config_proto,
                                                     meta_graph_def)
  except ValueError as error:
    logging.info(
        'Grappler has raised the error %s; falling back to using '
        'non-optimized graph.', error)
    optimized_graph_def = graph_spec_obj.graph_def
  return graph_spec.GraphSpec(
      optimized_graph_def,
      init_op=graph_spec_obj.init_op,
      in_names=graph_spec_obj.in_names,
      out_names=graph_spec_obj.out_names)
def test_semantic_equivalence_for_simple_graphdef(self):
  """Grappler preserves the outputs of the redundant add-one graph."""
  graph, in_name, out_name = _make_redundant_add_one_graph()
  original_spec = graph_spec.GraphSpec(graph.as_graph_def(), None, [in_name],
                                       [out_name])
  optimized_spec = graph_optimizations.optimize_graph_spec(
      original_spec, tf.compat.v1.ConfigProto())

  def _run_with_ones(spec):
    # Imports `spec` into a fresh graph and fetches its outputs, feeding 1 to
    # every input.
    with tf.Graph().as_default() as fresh_graph:
      tf.graph_util.import_graph_def(spec.graph_def, name='')
    with tf.compat.v1.Session(graph=fresh_graph) as sess:
      return sess.run(spec.out_names,
                      feed_dict=dict.fromkeys(spec.in_names, 1))

  original_outputs = _run_with_ones(original_spec)
  optimized_outputs = _run_with_ones(optimized_spec)
  self.assertEqual(optimized_outputs, original_outputs)
def test_reduces_graph_size_in_function_lib(self):
  """Grappler shrinks a graph whose function library holds a `tf.function`."""

  class StateHolder:
    pass

  holder = StateHolder()
  holder.variable = None

  @tf.function
  def foo(x):
    # Create the variable only on the first trace; later traces reuse it.
    if holder.variable is None:
      holder.variable = tf.Variable(initial_value=0.)
    holder.variable.assign_add(x)
    return holder.variable.read_value()

  with tf.Graph().as_default() as g:
    placeholder = tf.compat.v1.placeholder(shape=[], dtype=tf.float32)
    result = foo(placeholder)
    init_op = tf.compat.v1.global_variables_initializer()
  graph_def = g.as_graph_def()

  spec = graph_spec.GraphSpec(graph_def, init_op.name, [placeholder.name],
                              [result.name])
  optimized_spec = graph_optimizations.optimize_graph_spec(
      spec, tf.compat.v1.ConfigProto())

  self.assertIsInstance(optimized_spec, graph_spec.GraphSpec)
  self.assertLess(optimized_spec.graph_def.ByteSize(), graph_def.ByteSize())
def test_semantic_equivalence_for_graphdef_with_variables(self):
  """Grappler preserves the outputs of a foldable graph with variables."""
  graph, in_name, out_name = _make_foldable_add_variable_number_graph()
  with graph.as_default():
    init_op = tf.compat.v1.global_variables_initializer().name
  original_spec = graph_spec.GraphSpec(graph.as_graph_def(), init_op,
                                       [in_name], [out_name])
  optimized_spec = graph_optimizations.optimize_graph_spec(
      original_spec, tf.compat.v1.ConfigProto())

  def _initialize_and_run(spec):
    # Imports `spec` into a fresh graph, runs its init op, then fetches its
    # outputs with every input fed 1.
    with tf.Graph().as_default() as fresh_graph:
      tf.graph_util.import_graph_def(spec.graph_def, name='')
    with tf.compat.v1.Session(graph=fresh_graph) as sess:
      sess.run(spec.init_op)
      return sess.run(spec.out_names,
                      feed_dict=dict.fromkeys(spec.in_names, 1))

  original_outputs = _initialize_and_run(original_spec)
  optimized_outputs = _initialize_and_run(optimized_spec)
  self.assertEqual(optimized_outputs, original_outputs)
def test_reduces_bytesize_for_foldable_graphdef_with_variables(self):
  """Optimization shrinks the graph and leaves fewer `Const` 1.0 entries."""

  def _const_nodes_with_one(graph_def):
    # One list entry per `float_val` equal to 1. in each `Const` node, so a
    # node holding several such values is counted several times.
    return [
        node for node in graph_def.node if node.op == 'Const'
        for float_val in node.attr['value'].tensor.float_val
        if float_val == 1.
    ]

  graph, in_name, out_name = _make_foldable_add_variable_number_graph()
  with graph.as_default():
    init_op = tf.compat.v1.global_variables_initializer().name
  graph_def = graph.as_graph_def()
  orig_constants_1 = _const_nodes_with_one(graph_def)

  spec = graph_spec.GraphSpec(graph_def, init_op, [in_name], [out_name])
  optimized_spec = graph_optimizations.optimize_graph_spec(
      spec, tf.compat.v1.ConfigProto())
  opt_constants_1 = _const_nodes_with_one(optimized_spec.graph_def)

  self.assertIsInstance(optimized_spec, graph_spec.GraphSpec)
  self.assertLess(optimized_spec.graph_def.ByteSize(), graph_def.ByteSize())
  self.assertGreater(len(orig_constants_1), 1)
  self.assertLess(len(opt_constants_1), len(orig_constants_1))
def test_meta_graph_def_restores_and_runs_with_variables(self):
  """A restored `MetaGraphDef` keeps variable state across session runs."""
  graph, in_name, out_name = _make_add_variable_number_graph()
  with graph.as_default():
    init_op = tf.compat.v1.global_variables_initializer().name
  spec = graph_spec.GraphSpec(graph.as_graph_def(), init_op, [in_name],
                              [out_name])
  meta_graph_def = spec.to_meta_graph_def()

  with tf.Graph().as_default() as restored_graph:
    tf.compat.v1.train.import_meta_graph(meta_graph_def)
    # Group every op in the restored INIT_OP collection into one runnable op.
    restored_init_op = tf.group(
        *tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.INIT_OP)).name

  with tf.compat.v1.Session(graph=restored_graph) as sess:
    sess.run(restored_init_op)
    results = [sess.run(out_name, feed_dict={in_name: 0}) for _ in range(3)]

  # Without re-initialization, each successive run increments the output.
  for expected, actual in zip((1., 2., 3.), results):
    self.assertEqual(actual, expected)
def _unpack_proto_into_graph_spec(tf_block_proto):
  """Unpacks a TensorFlow computation proto into a `graph_spec.GraphSpec`.

  Args:
    tf_block_proto: Instance of `computation_pb2.Computation` with `tensorflow`
      `computation` attribute.

  Returns:
    Instance of `graph_spec.GraphSpec` containing Python representations of
    the information present in `tf_block_proto`:
      * the graph unpacked from `tensorflow.graph_def`;
      * `tensorflow.initialize_op`, or `None` when that field is empty;
      * the tensor names in the parameter binding, or `[]` when no parameter
        binding is set;
      * the tensor names in the result binding.
  """
  tf_proto = tf_block_proto.tensorflow
  graph = serialization_utils.unpack_graph_def(tf_proto.graph_def)
  # An unset proto string field reads back as '', which GraphSpec should see
  # as "no init op", i.e. None.
  graph_init_op_name = tf_proto.initialize_op or None

  graph_parameter_binding = tf_proto.parameter
  if graph_parameter_binding.WhichOneof('binding') is not None:
    graph_parameter_list = tensorflow_utils.extract_tensor_names_from_binding(
        graph_parameter_binding)
  else:
    # No-arg computation: there is nothing to feed.
    graph_parameter_list = []
  graph_result_list = tensorflow_utils.extract_tensor_names_from_binding(
      tf_proto.result)
  return graph_spec.GraphSpec(graph, graph_init_op_name, graph_parameter_list,
                              graph_result_list)
def test_compose_no_input_graphs_raises(self):
  """Composing two argument-free graphs raises a 'mismatch' `ValueError`."""
  specs = []
  for value in (1.0, 2.0):
    graph = tf.Graph()
    with graph.as_default():
      constant = tf.constant(value)
      init_name = tf.compat.v1.global_variables_initializer().name
    specs.append(
        graph_spec.GraphSpec(graph.as_graph_def(), init_name, [],
                             [constant.name]))

  with self.assertRaisesRegex(ValueError, 'mismatch'):
    graph_merge.compose_graph_specs(specs)
def test_graph_spec_to_meta_graph_def_simplest_case(self):
  """`to_meta_graph_def` on an add-one spec yields a `MetaGraphDef`."""
  graph, in_name, out_name = _make_add_one_graph()
  spec = graph_spec.GraphSpec(
      graph.as_graph_def(), None, [in_name], [out_name])
  self.assertIsInstance(spec.to_meta_graph_def(), tf.compat.v1.MetaGraphDef)
def optimize_graph_spec(graph_spec_obj, config_proto=None):
  """Runs Grappler over the graph held by a `graph_spec.GraphSpec`.

  Args:
    graph_spec_obj: Instance of `graph_spec.GraphSpec` representing the
      TensorFlow computation to optimize.
    config_proto: Optional `tf.compat.v1.ConfigProto` specifying optimization
      options for Grappler. When `None` (the default), a default-constructed
      `ConfigProto` is used, so single-argument calls behave as before.

  Returns:
    A new `graph_spec.GraphSpec` holding the Grappler-optimized `GraphDef`,
    with `init_op`, `in_names` and `out_names` copied from `graph_spec_obj`.

  Raises:
    Any error raised by `tf_optimizer.OptimizeGraph` propagates unchanged.
  """
  if config_proto is None:
    # TODO(b/154367032): Determine a set of optimizer configurations for TFF.
    config_proto = tf.compat.v1.ConfigProto()
  meta_graph_def = graph_spec_obj.to_meta_graph_def()
  optimized_graph_def = tf_optimizer.OptimizeGraph(config_proto,
                                                   meta_graph_def)
  return graph_spec.GraphSpec(
      optimized_graph_def,
      init_op=graph_spec_obj.init_op,
      in_names=graph_spec_obj.in_names,
      out_names=graph_spec_obj.out_names)
def test_graph_spec_constructs_whimsy_data(self):
  """`GraphSpec` exposes its constructor arguments as the very same objects."""
  graph_def = _make_add_one_graph()[0].as_graph_def()
  init_op = 'init'
  in_names = ['in']
  out_names = ['out']
  spec = graph_spec.GraphSpec(graph_def, init_op, in_names, out_names)

  expected_attrs = (
      ('graph_def', graph_def),
      ('init_op', init_op),
      ('in_names', in_names),
      ('out_names', out_names),
  )
  for attr_name, expected in expected_attrs:
    self.assertIs(getattr(spec, attr_name), expected)
def test_reduces_bytesize_for_simple_graphdef(self):
  """Grappler shrinks the serialized redundant add-one graph."""
  graph, in_name, out_name = _make_redundant_add_one_graph()
  graph_def = graph.as_graph_def()
  unoptimized = graph_spec.GraphSpec(graph_def, None, [in_name], [out_name])
  optimized = graph_optimizations.optimize_graph_spec(
      unoptimized, tf.compat.v1.ConfigProto())

  self.assertIsInstance(optimized, graph_spec.GraphSpec)
  self.assertLess(optimized.graph_def.ByteSize(), graph_def.ByteSize())
def test_nested_composition_three_add_one_graphs_adds_three(self):
  """Composing a third graph onto an already-composed pair maps 0.0 to 3.0."""
  built = [_make_add_variable_number_graph() for _ in range(3)]
  init_names = []
  for graph, _, _ in built:
    with graph.as_default():
      init_names.append(tf.compat.v1.global_variables_initializer().name)

  def _spec_at(index):
    # GraphSpec for the `index`-th built graph and its initializer.
    graph, in_name, out_name = built[index]
    return graph_spec.GraphSpec(graph.as_graph_def(), init_names[index],
                                [in_name], [out_name])

  (partial_merge_graph, partial_merge_init_op_name, partial_merge_in_name_map,
   partial_merge_out_name_map) = graph_merge.compose_graph_specs(
       [_spec_at(0), _spec_at(1)])
  partial_graph_spec = graph_spec.GraphSpec(
      partial_merge_graph.as_graph_def(), partial_merge_init_op_name,
      partial_merge_in_name_map.values(), partial_merge_out_name_map.values())

  composed_graph, init_op_name, in_name_map, out_name_map = (
      graph_merge.compose_graph_specs([_spec_at(2), partial_graph_spec]))
  first_input_name = built[0][1]
  last_output_name = built[2][2]
  # Input names go through both rounds of renaming.
  feed_name = in_name_map[partial_merge_in_name_map[first_input_name]]

  with tf.compat.v1.Session(graph=composed_graph) as sess:
    sess.run(init_op_name)
    outputs = sess.run(out_name_map[last_output_name],
                       feed_dict={feed_name: 0.0})
  self.assertAllClose(outputs, np.array(3.))
def test_reduces_bytesize_for_dataset_reduction(self):
  """Grappler shrinks a graph that reduces over a dataset."""
  dataset_graph, _, dataset_out_name = _make_dataset_constructing_graph()
  reduce_graph, _, reduce_out_name = _make_manual_reduce_graph(
      dataset_graph, dataset_out_name)
  with reduce_graph.as_default():
    init_op = tf.compat.v1.global_variables_initializer().name
  graph_def = reduce_graph.as_graph_def()

  # The reduce graph is fed nothing, so it has no input names.
  spec = graph_spec.GraphSpec(graph_def, init_op, [], [reduce_out_name])
  optimized = graph_optimizations.optimize_graph_spec(
      spec, tf.compat.v1.ConfigProto())

  self.assertIsInstance(optimized, graph_spec.GraphSpec)
  self.assertLess(optimized.graph_def.ByteSize(), graph_def.ByteSize())
def test_semantic_equivalence_for_graphdef_with_function(self):
  """Grappler preserves outputs of a graph calling a stateful `tf.function`."""

  class StateHolder:
    pass

  holder = StateHolder()
  holder.variable = None

  @tf.function
  def foo(x):
    # Create the variable only on the first trace; later traces reuse it.
    if holder.variable is None:
      holder.variable = tf.Variable(initial_value=0.)
    holder.variable.assign_add(x)
    return holder.variable.read_value()

  with tf.Graph().as_default() as g:
    placeholder = tf.compat.v1.placeholder(shape=[], dtype=tf.float32)
    result = foo(placeholder)
    init_op = tf.compat.v1.global_variables_initializer()
  graph_def = g.as_graph_def()

  original_spec = graph_spec.GraphSpec(graph_def, init_op.name,
                                       [placeholder.name], [result.name])
  optimized_spec = graph_optimizations.optimize_graph_spec(
      original_spec, tf.compat.v1.ConfigProto())

  def _initialize_and_run(spec):
    # Imports `spec` into a fresh graph, runs its init op, then fetches its
    # outputs with every input fed 1.
    with tf.Graph().as_default() as fresh_graph:
      tf.graph_util.import_graph_def(spec.graph_def, name='')
    with tf.compat.v1.Session(graph=fresh_graph) as sess:
      sess.run(spec.init_op)
      return sess.run(spec.out_names,
                      feed_dict=dict.fromkeys(spec.in_names, 1))

  original_outputs = _initialize_and_run(original_spec)
  optimized_outputs = _initialize_and_run(optimized_spec)
  self.assertEqual(optimized_outputs, original_outputs)
def test_compose_with_dataset_wires_correctly(self):
  """A dataset graph composed into a reduce graph evaluates to 10."""
  with tf.Graph().as_default() as dataset_graph:
    dataset = tf.data.Dataset.range(5)
    dataset_variant = tf.data.experimental.to_variant(dataset)

  with tf.Graph().as_default() as reduce_graph:
    variant_placeholder = tf.compat.v1.placeholder(dataset_variant.dtype)
    element_spec = tf.TensorSpec([], tf.int64)
    restored = tf.data.experimental.from_variant(
        variant_placeholder, structure=element_spec)
    total = restored.reduce(
        tf.constant(0, dtype=tf.int64), lambda x, y: x + y)

  with dataset_graph.as_default():
    dataset_init_name = tf.compat.v1.global_variables_initializer().name
  with reduce_graph.as_default():
    reduce_init_name = tf.compat.v1.global_variables_initializer().name

  dataset_spec = graph_spec.GraphSpec(dataset_graph.as_graph_def(),
                                      dataset_init_name, [],
                                      [dataset_variant.name])
  reduce_spec = graph_spec.GraphSpec(reduce_graph.as_graph_def(),
                                     reduce_init_name,
                                     [variant_placeholder.name], [total.name])

  # The reduce spec comes first: it consumes the dataset spec's output.
  composed_graph, _, _, out_name_map = graph_merge.compose_graph_specs(
      [reduce_spec, dataset_spec])
  with composed_graph.as_default():
    with tf.compat.v1.Session() as sess:
      ten = sess.run(out_name_map[total.name])
  self.assertEqual(ten, 10)
def test_composition_happens_in_mathematical_composition_order(self):
  """`compose_graph_specs([f, g])` computes f(g(x)); the reverse raises."""
  add_one_graph, add_one_in, add_one_out = _make_add_one_graph()

  def _make_cast_to_int_graph():
    # Graph casting a float32 placeholder named 'input' to int32.
    with tf.Graph().as_default() as graph:
      input_val = tf.compat.v1.placeholder(tf.float32, name='input')
      casted = tf.cast(input_val, tf.int32)
    return graph, input_val.name, casted.name

  cast_graph, cast_in, cast_out = _make_cast_to_int_graph()
  with add_one_graph.as_default():
    add_one_init = tf.compat.v1.global_variables_initializer().name
  with cast_graph.as_default():
    cast_init = tf.compat.v1.global_variables_initializer().name

  add_one_spec = graph_spec.GraphSpec(add_one_graph.as_graph_def(),
                                      add_one_init, [add_one_in],
                                      [add_one_out])
  cast_spec = graph_spec.GraphSpec(cast_graph.as_graph_def(), cast_init,
                                   [cast_in], [cast_out])
  arg_list = [cast_spec, add_one_spec]

  composed_graph, _, in_name_map, out_name_map = (
      graph_merge.compose_graph_specs(arg_list))
  with composed_graph.as_default():
    with tf.compat.v1.Session() as sess:
      outputs = sess.run(
          out_name_map[cast_out],
          feed_dict={in_name_map[add_one_in]: 0.0})
  self.assertEqual(outputs, 1)

  # The opposite ordering must be rejected.
  with self.assertRaises(ValueError):
    graph_merge.compose_graph_specs(list(reversed(arg_list)))
def test_compose_two_add_variable_number_graphs_executes_correctly(self):
  """With re-initialization before each run, every composed run yields 2.0."""
  built = [_make_add_variable_number_graph() for _ in range(2)]
  specs = []
  for graph, in_name, out_name in built:
    with graph.as_default():
      init_name = tf.compat.v1.global_variables_initializer().name
    specs.append(
        graph_spec.GraphSpec(graph.as_graph_def(), init_name, [in_name],
                             [out_name]))

  composed_graph, init_op_name, in_name_map, out_name_map = (
      graph_merge.compose_graph_specs(specs))
  fetch = out_name_map[built[1][2]]
  feeds = {in_name_map[built[0][1]]: 0.0}

  outputs = []
  with tf.compat.v1.Session(graph=composed_graph) as sess:
    for _ in range(3):
      sess.run(init_op_name)  # TFF is functional, reset session state.
      outputs.append(sess.run(fetch, feed_dict=feeds))

  for output in outputs:
    self.assertAllClose(output, np.array(2.))
def test_concatenate_inputs_and_outputs_no_init_op_graphs(self):
  """Concatenation works when neither input spec carries an init op."""
  graph_a, in_a, out_a = _make_add_one_graph()
  graph_b, in_b, out_b = _make_add_one_graph()
  specs = [
      graph_spec.GraphSpec(graph_a.as_graph_def(), None, [in_a], [out_a]),
      graph_spec.GraphSpec(graph_b.as_graph_def(), None, [in_b], [out_b]),
  ]

  merged_graph, init_op_name, in_name_maps, out_name_maps = (
      graph_merge.concatenate_inputs_and_outputs(specs))
  feeds = {in_name_maps[0][in_a]: 1.0, in_name_maps[1][in_b]: 2.0}
  fetches = [out_name_maps[0][out_a], out_name_maps[1][out_b]]

  with tf.compat.v1.Session(graph=merged_graph) as sess:
    sess.run(init_op_name)
    outputs = sess.run(fetches, feed_dict=feeds)
  self.assertAllClose(outputs, np.array([2., 3.]))
def test_meta_graph_def_runs_simplest_case(self):
  """A restored add-one `MetaGraphDef` maps an input of 0 to 1."""
  graph, in_name, out_name = _make_add_one_graph()
  spec = graph_spec.GraphSpec(graph.as_graph_def(), None, [in_name],
                              [out_name])
  meta_graph_def = spec.to_meta_graph_def()

  with tf.Graph().as_default() as restored_graph:
    tf.compat.v1.train.import_meta_graph(meta_graph_def)
  with tf.compat.v1.Session(graph=restored_graph) as sess:
    result = sess.run(out_name, feed_dict={in_name: 0})
  self.assertEqual(result, 1.)
def test_meta_graph_def_restores_and_runs_with_datasets(self):
  """A restored `MetaGraphDef` with a dataset reduction evaluates to 10."""
  dataset_graph, _, dataset_out_name = _make_dataset_constructing_graph()
  reduce_graph, _, out_name = _make_manual_reduce_graph(dataset_graph,
                                                        dataset_out_name)
  with reduce_graph.as_default():
    init_op = tf.compat.v1.global_variables_initializer().name
  spec = graph_spec.GraphSpec(reduce_graph.as_graph_def(), init_op, [],
                              [out_name])
  meta_graph_def = spec.to_meta_graph_def()

  with tf.Graph().as_default() as restored_graph:
    tf.compat.v1.train.import_meta_graph(meta_graph_def)
    # Init ops are looked up in the restored graph's INIT_OP collection;
    # running the returned list runs all of them.
    restored_init_ops = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.INIT_OP)

  with tf.compat.v1.Session(graph=restored_graph) as sess:
    sess.run(restored_init_ops)
    should_be_ten = sess.run(out_name)
  self.assertEqual(should_be_ten, 10)
def test_raises_on_graph_spec_set(self):
  """`compose_graph_specs` rejects a `set` of specs with a `TypeError`."""
  graph1, input_name_1, output_name_1 = _make_add_one_graph()
  graph_spec_1 = graph_spec.GraphSpec(graph1.as_graph_def(), '',
                                      [input_name_1], [output_name_1])
  # Build a one-element set *outside* `assertRaises`, so the expected
  # TypeError can only come from `compose_graph_specs`. Writing
  # `set(graph_spec_1)` would instead try to iterate the spec itself and
  # raise before `compose_graph_specs` is ever called.
  spec_set = {graph_spec_1}
  with self.assertRaises(TypeError):
    graph_merge.compose_graph_specs(spec_set)
def test_graph_spec_fails_out_names_ints(self):
  """A non-string entry in `out_names` is rejected with a `TypeError`."""
  add_one_graph_def = _make_add_one_graph()[0].as_graph_def()
  self.assertRaises(TypeError, graph_spec.GraphSpec, add_one_graph_def,
                    'test', ['test'], [1])
def test_graph_spec_succeeds_empty_init_op(self):
  """An empty-string `init_op` is accepted by the constructor."""
  add_one_graph = _make_add_one_graph()[0]
  graph_spec.GraphSpec(add_one_graph.as_graph_def(), '', ['test'], ['test'])
def test_graph_spec_fails_bad_init_op(self):
  """A non-string `init_op` is rejected with a `TypeError`."""
  add_one_graph_def = _make_add_one_graph()[0].as_graph_def()
  self.assertRaises(TypeError, graph_spec.GraphSpec, add_one_graph_def, 1,
                    ['test'], ['test'])
def test_graph_spec_fails_no_graph_def(self):
  """Passing `None` instead of a `GraphDef` raises a `TypeError`."""
  self.assertRaises(TypeError, graph_spec.GraphSpec, None, 'test', ['test'],
                    ['test'])