コード例 #1
0
    def test_select_graph_output_by_index_single_level_of_nesting(self):
        """Selects each element of a flat two-element tuple result by index.

        Each selection should report the chosen element's type as its result
        type, reuse the source computation's graph_def, parameter binding and
        initialize_op unchanged, and expose the chosen tuple element's tensor
        binding as its own result binding.
        """
        tuple_type = computation_types.NamedTupleType([tf.int32, tf.float32])
        identity_comp = _create_compiled_computation(lambda x: x, tuple_type)
        select = compiled_computation_transforms.select_graph_output

        # Build both selections before asserting anything, as the original
        # test did, so both transform calls always run.
        selections = [select(identity_comp, index=i) for i in (0, 1)]

        source_tf = identity_comp.proto.tensorflow
        for position, selected in enumerate(selections):
            selected_tf = selected.proto.tensorflow
            self.assertEqual(selected.type_signature.result,
                             identity_comp.type_signature.result[position])
            self.assertEqual(source_tf.graph_def, selected_tf.graph_def)
            self.assertEqual(source_tf.parameter, selected_tf.parameter)
            self.assertEqual(source_tf.initialize_op,
                             selected_tf.initialize_op)
            self.assertEqual(source_tf.result.tuple.element[position].tensor,
                             selected_tf.result.tensor)
コード例 #2
0
    def test_select_graph_output_with_wrong_return_type_raises_type_error(
            self):
        """Index-based selection on a non-tuple result raises ``TypeError``.

        The computation under test is the identity over a plain ``tf.int32``.
        Its result is presumably not a tuple, which is what this test relies
        on, so there is no element to select.
        """
        scalar_type = computation_types.to_type(tf.int32)
        identity_comp = _create_compiled_computation(lambda x: x, scalar_type)
        self.assertRaises(
            TypeError,
            compiled_computation_transforms.select_graph_output,
            identity_comp,
            index=0)
コード例 #3
0
    def test_select_graph_output_by_name_bad_name_raises_value_error(self):
        """Selecting by a name absent from the tuple raises ``ValueError``.

        The computation is the identity over a tuple whose elements are named
        ``'a'`` and ``'b'``. The requested name ``'x'`` matches neither.
        """
        named_pair_type = computation_types.NamedTupleType([('a', tf.int32),
                                                            ('b', tf.float32)])
        identity_comp = _create_compiled_computation(lambda x: x,
                                                     named_pair_type)
        self.assertRaises(ValueError,
                          compiled_computation_transforms.select_graph_output,
                          identity_comp,
                          name='x')
コード例 #4
0
    def test_select_graph_output_by_index_two_nested_levels_keeps_nested_type(
            self):
        """Selecting a nested tuple element by index keeps its tuple type.

        The source computation is the identity over ``<x=<a,b>, y=<c,d>>``.
        Selecting index 0 or 1 should yield a computation that:

        * has the corresponding nested named tuple as its result type, both
          as declared in this test and as reported by the source
          computation's own result type;
        * reuses the source graph_def, parameter binding and initialize_op;
        * exposes the chosen element's nested tuple binding as its result.
        """
        nested_type1 = computation_types.NamedTupleType([('a', tf.int32),
                                                         ('b', tf.float32)])
        nested_type2 = computation_types.NamedTupleType([('c', tf.int32),
                                                         ('d', tf.float32)])

        computation_arg_type = computation_types.NamedTupleType([
            ('x', nested_type1), ('y', nested_type2)
        ])

        foo = _create_compiled_computation(lambda x: x, computation_arg_type)

        # Check every element against both the declared nested type and
        # foo's own result element type. This way a mismatch in either
        # position is caught, not just in the last one.
        for index, expected_type in enumerate((nested_type1, nested_type2)):
            with self.subTest(index=index):
                selected = compiled_computation_transforms.select_graph_output(
                    foo, index=index)

                self.assertEqual(selected.type_signature.result,
                                 expected_type)
                self.assertEqual(selected.type_signature.result,
                                 foo.type_signature.result[index])

                self.assertEqual(foo.proto.tensorflow.graph_def,
                                 selected.proto.tensorflow.graph_def)
                self.assertEqual(foo.proto.tensorflow.parameter,
                                 selected.proto.tensorflow.parameter)
                self.assertEqual(foo.proto.tensorflow.initialize_op,
                                 selected.proto.tensorflow.initialize_op)
                self.assertEqual(
                    foo.proto.tensorflow.result.tuple.element[index].tuple,
                    selected.proto.tensorflow.result.tuple)
コード例 #5
0
 def test_select_graph_output_with_none_comp_raises_type_error(self):
     """Passing ``None`` instead of a computation raises ``TypeError``."""
     self.assertRaises(
         TypeError,
         compiled_computation_transforms.select_graph_output,
         None,
         index=0)