Esempio n. 1
0
 def test_do_transform(self):
     """Run BF16Convert on the small test graph and inspect the result.

     "conv3" goes in the first node list and "conv2" in the second one
     (presumably the FP32-kept list and the BF16 list respectively --
     TODO confirm against BF16Convert's signature). After the rewrite,
     relu2 must carry a bfloat16 "T" attribute and conv3 must list the
     inserted "relu2_BF16toFP32" node among its inputs.
     """
     converter = BF16Convert(self.test_graph, ["conv3"], ["conv2"])
     # The return value of do_transformation() is not inspected here;
     # the checks below read the converter's cur_graph instead.
     converter.do_transformation()
     details = converter.cur_graph.node_name_details
     # conv1 is looked up only so the test errors out if the node cannot
     # be found after conversion; its contents are not asserted.
     _ = details["conv1"].node
     relu2_node = details["relu2"].node
     conv3_node = details["conv3"].node
     self.assertEqual(relu2_node.attr["T"].type, dtypes.bfloat16)
     self.assertTrue("relu2_BF16toFP32" in conv3_node.input)
Esempio n. 2
0
 def test_rn50_convert(self):
     """Convert ResNet-50 to BF16 with conv14 held out, then compare dtypes.

     Every Conv2D / AvgPool / MatMul node except conv14 goes in the BF16
     list, and conv14 is passed alone in the first list. After conversion,
     conv11 and conv52 must share the same "T" type, while conv14's
     "T" type must differ from conv11's.
     """
     held_out = "v0/resnet_v13/conv14/conv2d/Conv2D"
     convertible_ops = ("Conv2D", "AvgPool", "MatMul")
     bf16_nodes = [
         graph_node.name
         for graph_node in self.input_graph.node
         if graph_node.op in convertible_ops
     ]
     # list.remove raises ValueError if conv14 was not collected above.
     bf16_nodes.remove(held_out)
     converter = BF16Convert(self.input_graph, [held_out], bf16_nodes)
     converter.do_transformation()
     details = converter.cur_graph.node_name_details
     conv11_node = details["v0/resnet_v13/conv11/conv2d/Conv2D"].node
     conv14_node = details[held_out].node
     conv52_node = details["v0/resnet_v115/conv52/conv2d/Conv2D"].node
     self.assertEqual(conv11_node.attr["T"].type, conv52_node.attr["T"].type)
     self.assertNotEqual(conv11_node.attr["T"].type, conv14_node.attr["T"].type)