def test_gets_some_explicit_some_none_placement(self):
  """Mixes one explicitly CPU-placed op with ops that carry no placement."""
  with tf.Graph().as_default() as g:
    with tf.device('/cpu:0'):
      a = tf.constant(0)
    # `b` and `c` are created outside the device scope on purpose, so they
    # have no explicit placement; otherwise this test would be identical to
    # `test_gets_all_explicit_placement`.
    b = tf.constant(1)
    c = a + b
  _, result_binding = tensorflow_utils.capture_result_from_graph(c, g)
  packed_graph_def = serialization_utils.pack_graph_def(g.as_graph_def())
  function_type = computation_types.FunctionType(None, tf.int32)
  proto = pb.Computation(
      type=type_serialization.serialize_type(function_type),
      tensorflow=pb.TensorFlow(
          graph_def=packed_graph_def, parameter=None, result=result_binding))
  building_block = building_blocks.ComputationBuildingBlock.from_proto(proto)
  device_placements = building_block_analysis.get_device_placement_in(
      building_block)
  # Sort the keys so the empty (unplaced) device string always comes first,
  # rather than branching on whatever order the dict happens to yield.
  all_device_placements = sorted(device_placements)
  self.assertLen(all_device_placements, 2)
  self.assertEqual('', all_device_placements[0])
  self.assertIn('CPU', all_device_placements[1])
  self.assertGreater(device_placements[all_device_placements[0]], 0)
  self.assertGreater(device_placements[all_device_placements[1]], 0)
def test_gets_all_explicit_placement(self):
  """Places every user op on CPU and checks the reported placements."""
  with tf.Graph().as_default() as g:
    with tf.device('/cpu:0'):
      a = tf.constant(0)
      b = tf.constant(1)
      c = a + b
  # Captured outside the device scope, so the result op it adds has no
  # explicit placement.
  _, result_binding = tensorflow_utils.capture_result_from_graph(c, g)
  packed_graph_def = serialization_utils.pack_graph_def(g.as_graph_def())
  function_type = computation_types.FunctionType(None, tf.int32)
  proto = pb.Computation(
      type=type_serialization.serialize_type(function_type),
      tensorflow=pb.TensorFlow(
          graph_def=packed_graph_def, parameter=None, result=result_binding))
  building_block = building_blocks.ComputationBuildingBlock.from_proto(proto)
  device_placements = building_block_analysis.get_device_placement_in(
      building_block)
  # Sorted once: the empty string sorts before any non-empty device name.
  all_device_placements = sorted(device_placements)
  # Expect two placements: the explicit 'cpu' from above, and the empty
  # placement of the `tf.identity` op added to the captured result.
  self.assertLen(all_device_placements, 2)
  self.assertEqual('', all_device_placements[0])
  self.assertIn('CPU', all_device_placements[1])
  self.assertGreater(device_placements[all_device_placements[1]], 0)
def test_gets_none_placement(self):
  """With no device scope anywhere, only the empty placement is reported."""
  graph = tf.Graph()
  with graph.as_default():
    first = tf.Variable(0, name='variable1')
    second = tf.Variable(1, name='variable2')
    total = first + second
  _, result_binding = tensorflow_utils.capture_result_from_graph(total, graph)
  tensorflow_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  serialized_type = type_serialization.serialize_type(
      computation_types.FunctionType(None, tf.int32))
  computation_proto = pb.Computation(
      type=serialized_type, tensorflow=tensorflow_proto)
  block = building_blocks.ComputationBuildingBlock.from_proto(
      computation_proto)
  placement_counts = building_block_analysis.get_device_placement_in(block)
  devices = list(placement_counts)
  self.assertLen(devices, 1)
  self.assertEqual(devices[0], '')
  self.assertGreater(placement_counts[''], 0)
def test_raises_with_reference(self):
  """A `Reference` block is rejected with a ValueError mentioning tensorflow."""
  reference_block = building_blocks.Reference('x', tf.int32)
  with self.assertRaisesRegex(ValueError, 'tensorflow'):
    building_block_analysis.get_device_placement_in(reference_block)