Example #1
    def test_concatenate_inputs_and_outputs_two_add_one_graphs(self):
        graph1, input_name_1, output_name_1 = _make_add_one_graph()
        graph2, input_name_2, output_name_2 = _make_add_one_graph()
        with graph1.as_default():
            init_op_name_1 = tf.compat.v1.global_variables_initializer().name
        with graph2.as_default():
            init_op_name_2 = tf.compat.v1.global_variables_initializer().name
        graph_spec_1 = graph_spec.GraphSpec(graph1.as_graph_def(),
                                            init_op_name_1, [input_name_1],
                                            [output_name_1])
        graph_spec_2 = graph_spec.GraphSpec(graph2.as_graph_def(),
                                            init_op_name_2, [input_name_2],
                                            [output_name_2])
        arg_list = [graph_spec_1, graph_spec_2]
        merged_graph, init_op_name, in_name_maps, out_name_maps = graph_merge.concatenate_inputs_and_outputs(
            arg_list)

        with merged_graph.as_default():
            with tf.compat.v1.Session() as sess:
                sess.run(init_op_name)
                outputs = sess.run(
                    [
                        out_name_maps[0][output_name_1],
                        out_name_maps[1][output_name_2]
                    ],
                    feed_dict={
                        in_name_maps[0][input_name_1]: 1.0,
                        in_name_maps[1][input_name_2]: 2.0
                    })

        self.assertAllClose(outputs, np.array([2., 3.]))
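
Examples #1 and #4 call a `_make_add_one_graph` helper that is not shown in this listing. Below is a minimal sketch of what such a helper might look like, assuming a single float placeholder input and an output that adds a constant 1 (the op and placeholder names are assumptions). The `graph_spec` and `graph_merge` modules come from TensorFlow Federated's implementation internals; their exact import paths have shifted between TFF releases, so they are omitted here.

import tensorflow as tf


def _make_add_one_graph():
    # Build a graph with one float placeholder and an output tensor that adds 1,
    # so feeding 1.0 yields 2.0 and feeding 2.0 yields 3.0, as the test asserts.
    with tf.Graph().as_default() as graph:
        input_val = tf.compat.v1.placeholder(tf.float32, name='input')
        out = tf.add(input_val, tf.constant(1.0))
    return graph, input_val.name, out.name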
Example #2
    def test_concatenate_inputs_and_outputs_with_dataset_wires_correctly(self):
        dataset_graph, _, dataset_out_name = _make_dataset_constructing_graph()
        graph_1, _, out_name_1 = _make_manual_reduce_graph(
            dataset_graph, dataset_out_name)
        graph_2, _, out_name_2 = _make_manual_reduce_graph(
            dataset_graph, dataset_out_name)
        with graph_1.as_default():
            init_op_name_1 = tf.compat.v1.global_variables_initializer().name
        with graph_2.as_default():
            init_op_name_2 = tf.compat.v1.global_variables_initializer().name
        graph_spec_1 = graph_spec.GraphSpec(graph_1.as_graph_def(),
                                            init_op_name_1, [], [out_name_1])
        graph_spec_2 = graph_spec.GraphSpec(graph_2.as_graph_def(),
                                            init_op_name_2, [], [out_name_2])
        arg_list = [graph_spec_1, graph_spec_2]
        merged_graph, init_op_name, _, out_name_maps = graph_merge.concatenate_inputs_and_outputs(
            arg_list)

        with merged_graph.as_default():
            with tf.compat.v1.Session() as sess:
                sess.run(init_op_name)
                tens = sess.run([
                    out_name_maps[0][out_name_1], out_name_maps[1][out_name_2]
                ])
        self.assertEqual(tens, [10, 10])
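
Example #2 also relies on two helpers that are not shown. A rough sketch consistent with the asserted result of `[10, 10]` is given below, assuming each graph sums `tf.data.Dataset.range(5)` (0 + 1 + 2 + 3 + 4 == 10); the dataset construction and reduction details are assumptions, not the actual TFF test code.

import tensorflow as tf


def _make_dataset_constructing_graph():
    # A graph that builds range(5) and exposes it as a variant tensor.
    with tf.Graph().as_default() as graph:
        dataset = tf.data.Dataset.range(5)
        variant = tf.data.experimental.to_variant(dataset)
    return graph, '', variant.name


def _make_manual_reduce_graph(dataset_construction_graph, return_element):
    # Imports the dataset-producing graph and sums its elements to a scalar.
    with tf.Graph().as_default() as graph:
        variant = tf.import_graph_def(
            dataset_construction_graph.as_graph_def(),
            return_elements=[return_element])[0]
        dataset = tf.data.experimental.from_variant(
            variant, structure=tf.TensorSpec([], tf.int64))
        out = dataset.reduce(tf.constant(0, tf.int64), lambda x, y: x + y)
    return graph, '', out.name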
Example #3
    def test_concatenate_inputs_and_outputs_no_arg_graphs(self):
        graph1 = tf.Graph()
        with graph1.as_default():
            out1 = tf.constant(1.0)
            init_op_name_1 = tf.compat.v1.global_variables_initializer().name
        graph2 = tf.Graph()
        with graph2.as_default():
            out2 = tf.constant(2.0)
            init_op_name_2 = tf.compat.v1.global_variables_initializer().name

        graph_spec_1 = graph_spec.GraphSpec(graph1.as_graph_def(),
                                            init_op_name_1, [], [out1.name])
        graph_spec_2 = graph_spec.GraphSpec(graph2.as_graph_def(),
                                            init_op_name_2, [], [out2.name])
        arg_list = [graph_spec_1, graph_spec_2]
        merged_graph, init_op_name, _, out_name_maps = graph_merge.concatenate_inputs_and_outputs(
            arg_list)

        with merged_graph.as_default():
            with tf.compat.v1.Session() as sess:
                sess.run(init_op_name)
                outputs = sess.run(
                    [out_name_maps[0][out1.name], out_name_maps[1][out2.name]])

        self.assertAllClose(outputs, np.array([1., 2.]))
Example #4
    def test_concatenate_inputs_and_outputs_no_init_op_graphs(self):
        graph1, input_name_1, output_name_1 = _make_add_one_graph()
        graph2, input_name_2, output_name_2 = _make_add_one_graph()
        graph_spec_1 = graph_spec.GraphSpec(graph1.as_graph_def(), None,
                                            [input_name_1], [output_name_1])
        graph_spec_2 = graph_spec.GraphSpec(graph2.as_graph_def(), None,
                                            [input_name_2], [output_name_2])
        arg_list = [graph_spec_1, graph_spec_2]
        merged_graph, init_op_name, in_name_maps, out_name_maps = graph_merge.concatenate_inputs_and_outputs(
            arg_list)

        with tf.compat.v1.Session(graph=merged_graph) as sess:
            sess.run(init_op_name)
            outputs = sess.run(
                [
                    out_name_maps[0][output_name_1],
                    out_name_maps[1][output_name_2]
                ],
                feed_dict={
                    in_name_maps[0][input_name_1]: 1.0,
                    in_name_maps[1][input_name_2]: 2.0
                })

        self.assertAllClose(outputs, np.array([2., 3.]))
Example #5
    def test_raises_on_non_iterable(self):
        with self.assertRaises(TypeError):
            graph_merge.concatenate_inputs_and_outputs(1)
Example #6
def concatenate_tensorflow_blocks(tf_comp_list):
    """Concatenates inputs and outputs of its argument to a single TF block.

    Takes a Python `list` or `tuple` of instances of
    `computation_building_blocks.CompiledComputation`, and constructs a single
    instance of the same building block representing the computations in this
    list concatenated side-by-side.

    There is one important convention here for callers to be aware of.
    `concatenate_tensorflow_blocks` does not perform any more packing into
    tuples than necessary. That is, if only a single computation in
    `tf_comp_list` declares a parameter, the parameter type of the resulting
    computation is exactly that single parameter type. Since every TF block
    declares a result, this is only a concern for parameters, and the returned
    function always has a tuple for its result value.

    Args:
      tf_comp_list: Python `list` or `tuple` of
        `computation_building_blocks.CompiledComputation`s, whose inputs and
        outputs we wish to concatenate.

    Returns:
      A single instance of `computation_building_blocks.CompiledComputation`,
      representing all the computations in `tf_comp_list` concatenated
      side-by-side.

    Raises:
      ValueError: If fewer than two computations are passed in `tf_comp_list`.
        In this case, the caller is likely using the wrong function.
      TypeError: If `tf_comp_list` is not a `list` or `tuple`, or if it
        contains anything other than TF blocks.
    """
    py_typecheck.check_type(tf_comp_list, (list, tuple))
    if len(tf_comp_list) < 2:
        raise ValueError(
            'We expect to concatenate at least two blocks of '
            'TensorFlow; otherwise the transformation you seek '
            'represents simply type manipulation, and you will find '
            'your desired function elsewhere in '
            '`compiled_computation_transforms`. You passed a tuple of '
            'length {}'.format(len(tf_comp_list)))
    tf_proto_list = []
    for comp in tf_comp_list:
        py_typecheck.check_type(
            comp, computation_building_blocks.CompiledComputation)
        tf_proto_list.append(comp.proto)

    (merged_graph, init_op_name, parameter_name_maps,
     result_name_maps) = graph_merge.concatenate_inputs_and_outputs(
         [_unpack_proto_into_graph_spec(x) for x in tf_proto_list])

    concatenated_parameter_bindings = _pack_concatenated_bindings(
        [x.tensorflow.parameter for x in tf_proto_list], parameter_name_maps)
    concatenated_result_bindings = _pack_concatenated_bindings(
        [x.tensorflow.result for x in tf_proto_list], result_name_maps)

    if concatenated_parameter_bindings:
        tf_result_proto = pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(
                merged_graph.as_graph_def()),
            initialize_op=init_op_name,
            parameter=concatenated_parameter_bindings,
            result=concatenated_result_bindings)
    else:
        tf_result_proto = pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(
                merged_graph.as_graph_def()),
            initialize_op=init_op_name,
            result=concatenated_result_bindings)

    parameter_type = _construct_concatenated_type(
        [x.type_signature.parameter for x in tf_comp_list])
    return_type = _construct_concatenated_type(
        [x.type_signature.result for x in tf_comp_list])
    function_type = computation_types.FunctionType(parameter_type, return_type)
    serialized_function_type = type_serialization.serialize_type(function_type)

    constructed_proto = pb.Computation(type=serialized_function_type,
                                       tensorflow=tf_result_proto)
    return computation_building_blocks.CompiledComputation(constructed_proto)
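
The `_construct_concatenated_type` helper used above is not shown in this listing. Below is a minimal sketch consistent with the packing convention described in the docstring: `None` entries (computations with no parameter) are dropped, a single surviving type is returned unwrapped, and anything more is packed into a named tuple type. The exact TFF implementation may differ, and `NamedTupleType` was renamed `StructType` in later TFF releases.

from tensorflow_federated.python.core.api import computation_types


def _construct_concatenated_type(type_list):
    # Drop `None` entries, i.e. computations that declare no parameter.
    non_none_types = [t for t in type_list if t is not None]
    if not non_none_types:
        return None
    if len(non_none_types) == 1:
        # Avoid wrapping a single type in a one-element tuple.
        return non_none_types[0]
    return computation_types.NamedTupleType(non_none_types)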