def test_cast_out_of_fp16_min(self):
    """Casting data containing -100000 to FP16 must raise ``Error``.

    -100000 lies below the lowest finite FP16 value (-65504), which is what
    the test name ("out of fp16 min") refers to.
    """
    values = np.array([0, -100000, 4, 9, 0])
    graph, _graph_ref = build_cast_test_graphs(values, dst_type_str='FP16')
    # Resolve the bound method before entering the assertion context so that
    # only the call itself is covered, matching the callable assertRaises form.
    replace = ChangeOutputTypeAttributes().find_and_replace_pattern
    with self.assertRaises(Error):
        replace(graph)
def test_range_different_values(self):
    """Converting Range(0, 50000, 1) to FP16 must raise ``Error``.

    NOTE(review): presumably the transform rejects this because FP16 cannot
    represent every integer in [0, 50000) exactly (the FP16 spacing exceeds 1
    above 2048), so converted values would differ from the originals.
    Confirm against ChangeOutputTypeAttributes.
    """
    graph, _graph_ref = build_range_test_graphs(start=0, limit=50000, delta=1, dst_type_str='FP16')
    replace = ChangeOutputTypeAttributes().find_and_replace_pattern
    with self.assertRaises(Error):
        replace(graph)
def test_range_out_of_fp16_min(self):
    """Converting Range(0, -100000, -1) to FP16 must raise ``Error``.

    The range descends past the lowest finite FP16 value (-65504).
    """
    graph, _graph_ref = build_range_test_graphs(start=0, limit=-100000, delta=-1, dst_type_str='FP16')
    replace = ChangeOutputTypeAttributes().find_and_replace_pattern
    with self.assertRaises(Error):
        replace(graph)
def test_cast_correct_case(self):
    """Casting small integers to FP16 succeeds and matches the reference graph.

    After the transform runs, the graph is compared with the reference graph
    from ``'res'``, including operation attributes.
    """
    values = np.array([0, 1000, 4, 9, 0])
    graph, expected = build_cast_test_graphs(values, dst_type_str='FP16')
    transform = ChangeOutputTypeAttributes()
    transform.find_and_replace_pattern(graph)
    equal, message = compare_graphs(graph, expected, 'res', check_op_attrs=True)
    self.assertTrue(equal, message)
def test_range_correct_case(self):
    """Converting Range(0, 10, 1) to FP16 succeeds and matches the reference graph.

    After the transform runs, the graph is compared with the reference graph
    from ``'res'``, including operation attributes.
    """
    graph, expected = build_range_test_graphs(start=0, limit=10, delta=1, dst_type_str='FP16')
    transform = ChangeOutputTypeAttributes()
    transform.find_and_replace_pattern(graph)
    equal, message = compare_graphs(graph, expected, 'res', check_op_attrs=True)
    self.assertTrue(equal, message)