示例#1
0
    def test_hsigmoid_with_clamp_different_tensors(self):
        """HSigmoidWithClamp must leave the graph alone when the branches read different tensors.

        In this graph ``add`` is fed by ``input_2`` while the outer ``mul`` is
        fed by ``input``. The graph after the transformation is therefore
        expected to be identical to an untouched copy taken beforehand.
        """
        node_parts = [
            regular_op('input', {'type': 'Parameter'}),
            regular_op('input_2', {'type': 'Parameter'}),
            regular_op('add', {'op': 'Add'}),
            regular_op('relu6', {'op': 'Clamp'}),
            regular_op('mul', {'op': 'Mul'}),
            regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
            const('const_0', float_array([0.0])),
            const('const_3', float_array([3.0])),
            const('const_6', float_array([6.0])),
            const('const_1_6', float_array([1.0 / 6.0])),
            result('result'),
        ]
        nodes = {}
        for part in node_parts:
            nodes.update(part)

        # (source, destination, destination input port); every edge leaves output port 0.
        edges = [(src, dst, {'in': in_port, 'out': 0}) for src, dst, in_port in (
            ('input', 'mul', 0),
            ('input_2', 'add', 0),
            ('const_3', 'add', 1),
            ('add', 'relu6', 0),
            ('const_0', 'relu6', 1),
            ('const_6', 'relu6', 2),
            ('relu6', 'mul', 1),
            ('mul', 'mul_2', 0),
            ('const_1_6', 'mul_2', 1),
            ('mul_2', 'result', 0),
        )]
        graph = build_graph_with_edge_attrs(nodes, edges)

        # Snapshot before the pass runs: the expected result is "no change".
        graph_ref = graph.copy()
        graph.stage = 'front'

        HSigmoidWithClamp().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(flag, resp)
示例#2
0
    def test_hsigmoid_with_clamp_wrong_constant(self):
        """HSigmoidWithClamp must leave the graph alone when ``const_0`` holds 1e-5.

        The shared ``self.nodes``/``self.edges`` graph is built with ``const_0``
        overridden to ``0.00001``. The result is compared against a copy taken
        before the pass runs, so the test expects no rewrite to happen.
        NOTE(review): presumably ``const_0`` is the Clamp lower bound (as in the
        sibling "different tensors" test) -- confirm against ``self.nodes``.
        """
        perturbed_attrs = {'const_0': {'value': float_array([0.00001])}}
        graph = build_graph_with_edge_attrs(self.nodes, self.edges, perturbed_attrs)

        graph_ref = graph.copy()
        graph.stage = 'front'

        HSigmoidWithClamp().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(flag, resp)
示例#3
0
    def test_hsigmoid_with_clamp(self):
        """HSigmoidWithClamp must rewrite the shared test graph into the reference graph.

        Builds ``self.nodes``/``self.edges`` with no overrides and runs the pass
        on the front stage. It then checks two things:
        * the result matches the graph built from ``ref_nodes``/``ref_edges``;
        * exactly one op node is named ``final_mul``, and its ``op`` is
          ``'HSigmoid'``, so the replacement keeps the name of the node it replaced.
        """
        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})

        graph_ref = build_graph(ref_nodes, ref_edges)
        graph.stage = 'front'

        HSigmoidWithClamp().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(flag, resp)

        # Look the node up once and assert each condition separately, so a
        # failure reports the actual count or op type instead of "False is not true".
        # If the count check fails, the index below is never reached, which
        # matches the short-circuit of the original compound condition.
        final_mul_nodes = graph.get_op_nodes(name='final_mul')
        self.assertEqual(len(final_mul_nodes), 1)
        self.assertEqual(final_mul_nodes[0].op, 'HSigmoid')