示例#1
0
 def test_cast_out_of_fp16_min(self):
     """Casting data below FP16's minimum must make the pass raise ``Error``.

     The input contains -100000, which is smaller than the lowest finite
     FP16 value (-65504), so converting the Cast destination type to FP16
     is expected to be rejected rather than applied.
     """
     values = np.array([0, -100000, 4, 9, 0])
     graph, _ = build_cast_test_graphs(values, dst_type_str='FP16')
     # Bind the pass entry point up front so only the call itself runs
     # inside the assertRaises context.
     run_pass = ChangeOutputTypeAttributes().find_and_replace_pattern
     with self.assertRaises(Error):
         run_pass(graph)
示例#2
0
 def test_range_different_values(self):
     """A 0..50000 (step 1) Range converted to FP16 must raise ``Error``.

     Every value stays inside FP16's finite range, but FP16 cannot represent
     every integer above 2048 exactly. Presumably the pass raises because
     distinct range values would collapse after conversion. TODO confirm
     against ChangeOutputTypeAttributes.
     """
     graph, _ = build_range_test_graphs(
         start=0, limit=50000, delta=1, dst_type_str='FP16')
     run_pass = ChangeOutputTypeAttributes().find_and_replace_pattern
     with self.assertRaises(Error):
         run_pass(graph)
示例#3
0
 def test_range_out_of_fp16_min(self):
     """A descending Range that falls below FP16's minimum must raise ``Error``.

     Counting from 0 toward -100000 in steps of -1 produces values below
     the lowest finite FP16 value (-65504).
     """
     graph, _ = build_range_test_graphs(
         start=0, limit=-100000, delta=-1, dst_type_str='FP16')
     run_pass = ChangeOutputTypeAttributes().find_and_replace_pattern
     with self.assertRaises(Error):
         run_pass(graph)
示例#4
0
 def test_cast_correct_case(self):
     """Casting values that FP16 holds exactly should transform cleanly.

     Every input element is an integer of magnitude at most 2048, so FP16
     represents it exactly. After the pass runs, the graph must match the
     reference graph, including op attributes.
     """
     values = np.array([0, 1000, 4, 9, 0])
     graph, expected_graph = build_cast_test_graphs(
         values, dst_type_str='FP16')
     ChangeOutputTypeAttributes().find_and_replace_pattern(graph)
     graphs_match, mismatch_info = compare_graphs(
         graph, expected_graph, 'res', check_op_attrs=True)
     self.assertTrue(graphs_match, mismatch_info)
示例#5
0
 def test_range_correct_case(self):
     """A small 0..10 Range converted to FP16 should transform cleanly.

     All produced values are small integers that FP16 represents exactly.
     After the pass runs, the graph must match the reference graph,
     including op attributes.
     """
     graph, expected_graph = build_range_test_graphs(
         start=0, limit=10, delta=1, dst_type_str='FP16')
     ChangeOutputTypeAttributes().find_and_replace_pattern(graph)
     graphs_match, mismatch_info = compare_graphs(
         graph, expected_graph, 'res', check_op_attrs=True)
     self.assertTrue(graphs_match, mismatch_info)