Пример #1
0
  def test_compose_with_dataset_wires_correctly(self):
    """Composing a dataset-producing graph into a dataset-reducing graph works.

    The first graph serializes `tf.data.Dataset.range(5)` to a variant tensor.
    The second deserializes a variant placeholder back into an int64 dataset
    and sums its elements. Composing them (reduce applied after the dataset
    graph) must therefore evaluate to 0 + 1 + 2 + 3 + 4 == 10.
    """
    with tf.Graph().as_default() as dataset_graph:
      d1 = tf.data.Dataset.range(5)
      v1 = tf.data.experimental.to_variant(d1)

    ds_out_name = v1.name
    variant_type = v1.dtype

    with tf.Graph().as_default() as reduce_graph:
      variant = tf.compat.v1.placeholder(variant_type)
      structure = tf.data.experimental.TensorStructure(tf.int64, shape=[])
      ds1 = tf.data.experimental.from_variant(variant, structure=structure)
      out = ds1.reduce(tf.constant(0, dtype=tf.int64), lambda x, y: x + y)

    ds_in_name = variant.name
    reduce_out_name = out.name

    # Use the same `tf.compat.v1` initializer API as the other composition
    # tests; `tf.initializers.global_variables` is a TF1-only alias.
    with dataset_graph.as_default():
      init_op_name_1 = tf.compat.v1.global_variables_initializer().name
    with reduce_graph.as_default():
      init_op_name_2 = tf.compat.v1.global_variables_initializer().name
    # NOTE(review): the other tests build specs via `graph_spec.GraphSpec`;
    # confirm `graph_merge` still exposes `GraphSpec` in this version.
    dataset_graph_spec = graph_merge.GraphSpec(dataset_graph.as_graph_def(),
                                               init_op_name_1, [],
                                               [ds_out_name])
    reduce_graph_spec = graph_merge.GraphSpec(reduce_graph.as_graph_def(),
                                              init_op_name_2, [ds_in_name],
                                              [reduce_out_name])
    # Mathematical composition order: the last spec runs first, so the
    # dataset graph's variant output feeds the reduce graph's placeholder.
    arg_list = [reduce_graph_spec, dataset_graph_spec]
    composed_graph, init_op_name, _, out_name_map = (
        graph_merge.compose_graph_specs(arg_list))

    with composed_graph.as_default():
      with tf.compat.v1.Session() as sess:
        # Run the merged initializer before evaluating, as the other
        # composition tests do.
        sess.run(init_op_name)
        ten = sess.run(out_name_map[reduce_out_name])
    self.assertEqual(ten, 10)
Пример #2
0
    def test_compose_three_add_one_graphs_adds_three(self):
        """Chaining three add-one graphs maps an input of 0.0 to 3.0."""
        specs = []
        io_names = []
        for _ in range(3):
            graph, in_name, out_name = _make_add_one_graph()
            with graph.as_default():
                init_name = tf.compat.v1.global_variables_initializer().name
            specs.append(
                graph_spec.GraphSpec(graph.as_graph_def(), init_name,
                                     [in_name], [out_name]))
            io_names.append((in_name, out_name))

        # Feed the first graph's input and read the third graph's output.
        first_input = io_names[0][0]
        last_output = io_names[-1][1]

        merged, merged_init, in_map, out_map = (
            graph_merge.compose_graph_specs(specs))

        with merged.as_default(), tf.compat.v1.Session() as session:
            session.run(merged_init)
            result = session.run(out_map[last_output],
                                 feed_dict={in_map[first_input]: 0.0})

        self.assertAllClose(result, np.array(3.))
Пример #3
0
    def test_compose_no_input_graphs_raises(self):
        """Composing graphs that declare no inputs raises a 'mismatch' ValueError.

        Each graph emits a single constant and takes no inputs. The output of
        one therefore has nothing to connect to in the other, and
        `compose_graph_specs` must reject the pair.
        """
        specs = []
        for value in (1.0, 2.0):
            graph = tf.Graph()
            with graph.as_default():
                constant = tf.constant(value)
                init_name = tf.compat.v1.global_variables_initializer().name
            specs.append(
                graph_spec.GraphSpec(graph.as_graph_def(), init_name, [],
                                     [constant.name]))

        with self.assertRaisesRegex(ValueError, 'mismatch'):
            graph_merge.compose_graph_specs(specs)
Пример #4
0
    def test_nested_composition_three_add_one_graphs_adds_three(self):
        """An already-composed GraphSpec can itself be composed further.

        Two add-variable-number graphs are merged first. The merged graph is
        re-wrapped as a GraphSpec and composed with a third such graph. Feeding
        0.0 through the whole chain is expected to yield 3.0.
        """
        graph1, input_name_1, output_name_1 = _make_add_variable_number_graph()
        graph2, input_name_2, output_name_2 = _make_add_variable_number_graph()
        graph3, input_name_3, output_name_3 = _make_add_variable_number_graph()
        with graph1.as_default():
            init_op_name_1 = tf.compat.v1.global_variables_initializer().name
        with graph2.as_default():
            init_op_name_2 = tf.compat.v1.global_variables_initializer().name
        with graph3.as_default():
            init_op_name_3 = tf.compat.v1.global_variables_initializer().name
        graph_spec_1 = graph_spec.GraphSpec(graph1.as_graph_def(),
                                            init_op_name_1, [input_name_1],
                                            [output_name_1])
        graph_spec_2 = graph_spec.GraphSpec(graph2.as_graph_def(),
                                            init_op_name_2, [input_name_2],
                                            [output_name_2])
        arg_list = [graph_spec_1, graph_spec_2]
        (partial_merge_graph, partial_merge_init_op_name,
         partial_merge_in_name_map, partial_merge_out_name_map
         ) = graph_merge.compose_graph_specs(arg_list)

        # Pass concrete lists rather than live dict views. Every other GraphSpec
        # in this file is built from list-typed in/out names, and a list
        # snapshots the names instead of aliasing the partial-merge maps.
        partial_graph_spec = graph_spec.GraphSpec(
            partial_merge_graph.as_graph_def(), partial_merge_init_op_name,
            list(partial_merge_in_name_map.values()),
            list(partial_merge_out_name_map.values()))
        graph_spec_3 = graph_spec.GraphSpec(graph3.as_graph_def(),
                                            init_op_name_3, [input_name_3],
                                            [output_name_3])
        composed_graph, init_op_name, in_name_map, out_name_map = graph_merge.compose_graph_specs(
            [graph_spec_3, partial_graph_spec])

        with tf.compat.v1.Session(graph=composed_graph) as sess:
            sess.run(init_op_name)
            # The original input name is first renamed by the partial merge
            # and then again by the final merge, hence the double lookup.
            outputs = sess.run(
                out_name_map[output_name_3],
                feed_dict={
                    in_name_map[partial_merge_in_name_map[input_name_1]]: 0.0,
                })

        self.assertAllClose(outputs, np.array(3.))
Пример #5
0
    def test_composition_happens_in_mathematical_composition_order(self):
        """`compose_graph_specs([f, g])` builds f(g(x)), so g runs first.

        Here g adds one and f casts a float32 to int32, so feeding 0.0 yields 1.
        The reversed order must raise ValueError (presumably because the int32
        cast output cannot feed the add-one graph's input).
        """
        add_graph, add_in, add_out = _make_add_one_graph()

        def _make_cast_to_int_graph():
            with tf.Graph().as_default() as graph:
                placeholder = tf.compat.v1.placeholder(tf.float32, name='input')
                casted = tf.cast(placeholder, tf.int32)
            return graph, placeholder.name, casted.name

        cast_graph, cast_in, cast_out = _make_cast_to_int_graph()

        init_names = []
        for graph in (add_graph, cast_graph):
            with graph.as_default():
                init_names.append(
                    tf.compat.v1.global_variables_initializer().name)
        add_init, cast_init = init_names

        add_spec = graph_spec.GraphSpec(add_graph.as_graph_def(), add_init,
                                        [add_in], [add_out])
        cast_spec = graph_spec.GraphSpec(cast_graph.as_graph_def(), cast_init,
                                         [cast_in], [cast_out])
        ordered_specs = [cast_spec, add_spec]

        merged, _, in_map, out_map = graph_merge.compose_graph_specs(
            ordered_specs)

        with merged.as_default(), tf.compat.v1.Session() as session:
            result = session.run(out_map[cast_out],
                                 feed_dict={in_map[add_in]: 0.0})

        self.assertEqual(result, 1)

        with self.assertRaises(ValueError):
            graph_merge.compose_graph_specs(ordered_specs[::-1])
Пример #6
0
    def test_compose_two_add_variable_number_graphs_executes_correctly(self):
        """Two chained add-variable-number graphs give 2.0 on every fresh run.

        The merged init op is re-run before each evaluation so that every run
        starts from the same initial state; all three results must match.
        """
        specs = []
        io_names = []
        for _ in range(2):
            graph, in_name, out_name = _make_add_variable_number_graph()
            with graph.as_default():
                init_name = tf.compat.v1.global_variables_initializer().name
            specs.append(
                graph_spec.GraphSpec(graph.as_graph_def(), init_name,
                                     [in_name], [out_name]))
            io_names.append((in_name, out_name))

        first_input = io_names[0][0]
        last_output = io_names[-1][1]

        merged, merged_init, in_map, out_map = (
            graph_merge.compose_graph_specs(specs))

        results = []
        with tf.compat.v1.Session(graph=merged) as session:
            for _ in range(3):
                session.run(merged_init)  # TFF is functional, reset session state.
                results.append(
                    session.run(out_map[last_output],
                                feed_dict={in_map[first_input]: 0.0}))

        for result in results:
            self.assertAllClose(result, np.array(2.))
Пример #7
0
 def test_raises_on_list_of_ints(self):
     """A list whose elements are not GraphSpecs is rejected with TypeError."""
     not_graph_specs = [0, 1]
     with self.assertRaises(TypeError):
         graph_merge.compose_graph_specs(not_graph_specs)
Пример #8
0
 def test_raises_on_graph_spec_set(self):
     """Passing a set of GraphSpecs to `compose_graph_specs` raises TypeError."""
     graph1, input_name_1, output_name_1 = _make_add_one_graph()
     graph_spec_1 = graph_spec.GraphSpec(graph1.as_graph_def(), '',
                                         [input_name_1], [output_name_1])
     # Build the set outside `assertRaises`. `set(graph_spec_1)` would try to
     # iterate the GraphSpec itself, and any TypeError from that construction
     # would let the test pass without compose_graph_specs ever seeing a set.
     # NOTE(review): assumes GraphSpec is hashable (default identity hash).
     graph_spec_set = {graph_spec_1}
     with self.assertRaises(TypeError):
         graph_merge.compose_graph_specs(graph_spec_set)
Пример #9
0
 def test_raises_on_none(self):
     """Passing None instead of a sequence of GraphSpecs raises TypeError."""
     missing_specs = None
     with self.assertRaises(TypeError):
         graph_merge.compose_graph_specs(missing_specs)