def test_classic_i1_positive_case(self):
    """Check ReLU/FakeQuantize fusion for a two-level quantize with a zero threshold.

    The graph is built from ``i1_edges``. ``const_1_d`` holds zeros
    (shape [1, 2, 3, 4], float32) and ``quantize`` has ``levels=2``. Then
    ``ReluFakeQuantizeMark`` runs, followed by ``ReluQuantizeFuse``. The result
    must match the ``ref_i1_edges`` reference graph when compared from the
    ``'output'`` node with operation attributes checked.
    """
    # Per-node attribute overrides applied on top of the shared `nodes` spec.
    overrides = {
        'const_1_d': {'value': np.zeros([1, 2, 3, 4], dtype=np.float32)},
        'quantize': {'levels': 2},
    }
    graph = build_graph(nodes, i1_edges, overrides, nodes_with_edges_only=True)
    graph.graph['layout'] = 'NHWC'
    graph.stage = 'middle'

    graph_ref = build_graph(nodes, ref_i1_edges, nodes_with_edges_only=True)

    # Instantiate and apply each pass in sequence: marking first, then fusion.
    for transform_cls in (ReluFakeQuantizeMark, ReluQuantizeFuse):
        transform_cls().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_const_extra_outputs_i1_case(self):
    """Check ReLU/FakeQuantize fusion on the ``const_extra`` topology.

    The graph is built from ``const_extra``. ``const_1_d`` is filled with -1
    (shape [1, 2, 3, 4], float32), and both ``quantize`` and ``quantize_1``
    have ``levels=2``. After ``ReluFakeQuantizeMark`` and ``ReluQuantizeFuse``
    run, the graph is compared with the ``ref_const_extra`` reference starting
    from the ``'relu'`` node, with operation attributes checked. The test then
    asserts that ``const_1_d`` holds -inf.
    """
    graph = build_graph(
        nodes,
        const_extra,
        {
            'const_1_d': {'value': np.full([1, 2, 3, 4], -1, dtype=np.float32)},
            'quantize': {'levels': 2},
            'quantize_1': {'levels': 2},
        },
        nodes_with_edges_only=True,
    )
    graph.graph['layout'] = 'NHWC'
    graph.stage = 'middle'

    graph_ref = build_graph(nodes, ref_const_extra, nodes_with_edges_only=True)

    ReluFakeQuantizeMark().find_and_replace_pattern(graph)
    ReluQuantizeFuse().find_and_replace_pattern(graph)

    (flag, resp) = compare_graphs(graph, graph_ref, 'relu', check_op_attrs=True)
    self.assertTrue(flag, resp)

    # np.array_equal only returns a bool; without an assertion around it the
    # comparison is a silent no-op and can never fail the test.
    # NOTE(review): this reads graph_ref (the reference, which the passes never
    # touch) rather than the transformed `graph` — confirm which was intended.
    expected_threshold = np.full([1, 2, 3, 4], float('-inf'), dtype=np.float32)
    self.assertTrue(
        np.array_equal(expected_threshold, graph_ref.node['const_1_d']['value']),
        "const_1_d value is not -inf",
    )