Example #1
0
    def test_gets_some_explicit_some_none_placement(self):
        """Mixing device-scoped and unscoped ops yields two placements.

        `a` is created under an explicit '/cpu:0' device scope while `b` and
        `c` are created outside it, so the analysis is expected to report
        exactly two placements: one containing 'CPU' and the empty placement,
        each with a positive count.
        """
        with tf.Graph().as_default() as g:
            with tf.device('/cpu:0'):
                a = tf.constant(0)
            b = tf.constant(1)
            c = a + b

        _, result_binding = tensorflow_utils.capture_result_from_graph(c, g)

        packed_graph_def = serialization_utils.pack_graph_def(g.as_graph_def())
        function_type = computation_types.FunctionType(None, tf.int32)
        proto = pb.Computation(
            type=type_serialization.serialize_type(function_type),
            tensorflow=pb.TensorFlow(graph_def=packed_graph_def,
                                     parameter=None,
                                     result=result_binding))
        building_block = building_blocks.ComputationBuildingBlock.from_proto(
            proto)
        device_placements = building_block_analysis.get_device_placement_in(
            building_block)
        # Sort the keys so their order is deterministic: the empty placement
        # string sorts before any non-empty device name, so no branching on
        # dict iteration order is needed.
        all_device_placements = sorted(device_placements)
        self.assertLen(all_device_placements, 2)
        self.assertEqual('', all_device_placements[0])
        self.assertIn('CPU', all_device_placements[1])
        self.assertGreater(device_placements[all_device_placements[0]], 0)
        self.assertGreater(device_placements[all_device_placements[1]], 0)
Example #2
0
    def test_gets_all_explicit_placement(self):
        """Ops built entirely under '/cpu:0' report a CPU placement.

        Besides the explicit CPU placement, one empty placement is expected
        from the op that capturing the result adds to the graph. Both
        placements must have a positive count.
        """
        with tf.Graph().as_default() as g:
            with tf.device('/cpu:0'):
                a = tf.constant(0)
                b = tf.constant(1)
                c = a + b

        _, result_binding = tensorflow_utils.capture_result_from_graph(c, g)

        packed_graph_def = serialization_utils.pack_graph_def(g.as_graph_def())
        function_type = computation_types.FunctionType(None, tf.int32)
        proto = pb.Computation(
            type=type_serialization.serialize_type(function_type),
            tensorflow=pb.TensorFlow(graph_def=packed_graph_def,
                                     parameter=None,
                                     result=result_binding))
        building_block = building_blocks.ComputationBuildingBlock.from_proto(
            proto)
        device_placements = building_block_analysis.get_device_placement_in(
            building_block)
        # `sorted` already returns a list, and the empty placement string
        # sorts before any non-empty device name.
        all_device_placements = sorted(device_placements)
        # Expect two placements: the explicit 'cpu' from above, and the empty
        # placement of the `tf.identity` op added to the captured result.
        self.assertLen(all_device_placements, 2)
        self.assertEqual('', all_device_placements[0])
        self.assertIn('CPU', all_device_placements[1])
        self.assertGreater(device_placements[all_device_placements[0]], 0)
        self.assertGreater(device_placements[all_device_placements[1]], 0)
  def test_gets_none_placement(self):
    """Without any `tf.device` scope, only the empty placement is reported.

    Two variables are summed with no device scope; the analysis must return
    exactly one placement key, `''`, with a positive count.
    """
    with tf.Graph().as_default() as graph:
      first = tf.Variable(0, name='variable1')
      second = tf.Variable(1, name='variable2')
      total = first + second

    _, result_binding = tensorflow_utils.capture_result_from_graph(
        total, graph)

    packed = serialization_utils.pack_graph_def(graph.as_graph_def())
    computation_proto = pb.Computation(
        type=type_serialization.serialize_type(
            computation_types.FunctionType(None, tf.int32)),
        tensorflow=pb.TensorFlow(
            graph_def=packed, parameter=None, result=result_binding))
    placement_counts = building_block_analysis.get_device_placement_in(
        building_blocks.ComputationBuildingBlock.from_proto(computation_proto))

    # Exactly one key, and that key is the empty placement.
    self.assertLen(placement_counts, 1)
    self.assertEqual(next(iter(placement_counts)), '')
    self.assertGreater(placement_counts[''], 0)
Example #4
0
 def test_raises_with_reference(self):
     """A `Reference` building block is rejected with a `ValueError`.

     The raised error's message must match the regex 'tensorflow'.
     """
     reference = building_blocks.Reference('x', tf.int32)
     self.assertRaisesRegex(ValueError, 'tensorflow',
                            building_block_analysis.get_device_placement_in,
                            reference)