Example #1
    def test_rewrite_subgraph(self):
        model_proto = self.sample_net()
        nodes = model_proto.node
        g = tf2onnx.graph.Graph(nodes, output_shapes={}, dtypes={})
        pattern = \
            OpTypePattern('Abs', name='output', inputs=[
                OpTypePattern('Add', name='input')
            ])
        ops = g.get_nodes()
        matcher = GraphMatcher(pattern)
        match_results = list(matcher.match_ops(ops))
        for match in match_results:
            input_node = match.get_op('input')
            output_node = match.get_op('output')
            op_name = tf2onnx.utils.make_name("ReplacedOp")
            out_name = tf2onnx.utils.port_name(op_name)
            new_node = g.make_node("Sub",
                                   inputs=input_node.input,
                                   outputs=[out_name],
                                   name=op_name)
            ops = g.replace_subgraph(ops, match, [], [output_node], [],
                                     [new_node])
        g.topological_sort(ops)
        result = onnx_to_graphviz(g)
        expected = 'digraph { n1 [op_type=Abs] n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__2 [op_type=Sub] ' \
                   'n6 [op_type=Identity] input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__2 ' \
                   'n3:0 -> ReplacedOp__2 ReplacedOp__2:0 -> n6 }'
        self.assertEqual(expected, result)
Example #2
    def test_match_flipped(self):
        n1 = helper.make_node("Sub", ["i1", "i1"], ["n1:0"], name="n1")
        n2 = helper.make_node("Add", ["i2", "i2"], ["n2:0"], name="n2")
        n3 = helper.make_node("Mul", ["n1:0", "n2:0"], ["n3:0"], name="n3")

        graph_proto = helper.make_graph(
            nodes=[n1, n2, n3],
            name="test",
            inputs=[
                helper.make_tensor_value_info("i1", TensorProto.FLOAT, [2, 2]),
                helper.make_tensor_value_info("i2", TensorProto.FLOAT, [2, 2])
            ],
            outputs=[
                helper.make_tensor_value_info("n2:0", TensorProto.FLOAT,
                                              [2, 2])
            ],
            initializer=[])
        g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
        pattern = OpTypePattern(
            'Mul', inputs=[OpTypePattern('Add'),
                           OpTypePattern('Sub')])
        ops = g.get_nodes()
        matcher = GraphMatcher(pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(ops))
        self.assertEqual(1, len(match_results))
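
A note on the flag: the match succeeds only because allow_reorder=True lets the matcher permute commutative inputs; the pattern lists Add first, while n3's first input is actually the Sub output. As a sketch (hypothetically continuing the test above, reusing its g and pattern), the default strict matcher should find nothing:

        # allow_reorder defaults to False: input order must match the pattern exactly.
        strict_matcher = GraphMatcher(pattern, allow_reorder=False)
        self.assertEqual(0, len(list(strict_matcher.match_ops(g.get_nodes()))))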
Example #3
def rewrite_thresholded_relu(g, ops):
    if g.opset < 10:
        return ops

    pattern = \
        OpTypePattern('Mul', name='mul', inputs=[
            OpTypePattern('Cast', name='cast', inputs=[
                OpTypePattern('Greater', name='greater', inputs=[
                    OpTypePattern('*', name='greater_input'),
                    OpTypePattern('Const', name='theta')
                ])
            ]),
            OpTypePattern('*', name='mul_input')
        ])
    matcher = GraphMatcher(pattern, allow_reorder=True)
    match_results = list(matcher.match_ops(ops))

    for match in match_results:
        greater_node = match.get_op('greater')
        greater_input_node = match.get_op('greater_input')
        mul_node = match.get_op("mul")
        mul_input_node = match.get_op('mul_input')
        cast_node = match.get_op('cast')

        greater_input_edge_name = _find_edge_name_between_nodes(greater_input_node, greater_node)
        mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
        if greater_input_edge_name == mul_input_edge_name:
            theta = match.get_op('theta').get_tensor_value()
            thresholded_relu = g.make_node("ThresholdedRelu", inputs=[mul_input_edge_name], attr={"alpha": theta},
                                           shapes=[g.get_shape(mul_node.output[0])],
                                           dtypes=[g.get_dtype(mul_node.output[0])])
            g.replace_all_inputs(mul_node.output[0], thresholded_relu.output[0], ops=ops)
            to_delete = [cast_node, mul_node]
            g.safe_remove_nodes(to_delete)
    return ops
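
Rewriters with this (g, ops) signature are normally invoked by tf2onnx itself during conversion. Below is a minimal sketch of wiring one in by hand, assuming the custom_rewriter parameter of tf2onnx.tfonnx.process_tf_graph (check your tf2onnx version; the tensor names here are hypothetical):

# Sketch only: custom_rewriter is assumed to accept a list of (g, ops) functions.
from tf2onnx import tfonnx

onnx_graph = tfonnx.process_tf_graph(
    tf_graph,                      # a frozen tf.Graph, assumed to be in scope
    output_names=["output:0"],     # hypothetical output tensor name
    custom_rewriter=[rewrite_thresholded_relu])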
Example #4
    def run(self, unit_type):
        """
        main procedures:
        1 use cell op pattern to find cell >> the found cell is the start pointer of the procedures below
        2 find needed info from tensorflow graph:
            1 rnn scope name
            2 input_x
            3 weight
            4 sequence node
            5 initializer
            6 state output & hidden output
        3 process found info according to ONNX requirement

        remember: op pattern and scope name are useful
                  they are used to get needed info from tensorflow graph
                  raw found info need to be formatted according to ONNX requirement
        """
        # allow_reorder must be true. because LSTMCell and BasicLSTMCell's call function
        # are defining the calculation with different orders. Then we can share the same
        # pattern.
        cell_pattern = get_pattern(unit_type)
        matcher = GraphMatcher(cell_pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(self.g.get_nodes()))

        if match_results:
            for match in match_results:
                self.run_single_match(match)

            self.g.delete_unused_nodes(self.g.outputs)
            self.print_step("finish handling")

        return self.g.get_nodes()
Example #5
def rewrite_biasadd_with_conv2d(g, ops):
    pattern = \
        OpTypePattern('BiasAdd', name='biasadd', inputs=[
            OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*']), '*'])
    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        biasadd = match.get_op('biasadd')
        conv = match.get_op('conv')

        # back up the conv and biasadd values
        conv_type = conv.type
        conv_input = conv.input
        conv_attr = conv.attr
        dtype = g.get_dtype(conv.output[0])
        shape = g.get_shape(conv.output[0])
        conv_name = biasadd.name
        conv_output = biasadd.output
        conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]

        # Remove the Conv and BiasAdd node
        g.remove_node(conv.name)
        g.remove_node(biasadd.name)

        g.make_node(conv_type,
                    conv_inputs,
                    attr=conv_attr,
                    name=conv_name,
                    outputs=conv_output,
                    shapes=[shape],
                    dtypes=[dtype],
                    skip_conversion=False)
    return ops
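
Note the pattern shorthand this example relies on and that recurs below: '|' inside a type string matches any of the listed op types, '*' (or None) matches any single op, and a bare string in an inputs list is promoted to an OpTypePattern. The same pattern written out long-hand:

# Long-hand equivalent of the pattern above.
pattern_explicit = OpTypePattern('BiasAdd', name='biasadd', inputs=[
    OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=[
        OpTypePattern('*'),   # conv data input: any op
        OpTypePattern('*'),   # conv kernel input: any op
    ]),
    OpTypePattern('*'),       # the bias term
])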
Example #6
    def _parse_input_ta(self, context):
        graph_inputs = [
            v.switch_true_identity_output.id
            for v in context.loop_properties.all_variables.values()
            if v.switch_true_identity_output.id
        ]
        matcher = GraphMatcher(self.ta_read_input_pattern, allow_reorder=False)
        match_results = matcher.match_ops(self.g.get_nodes())
        match_results = [
            r for r in match_results
            if r.get_op("ta_index").output[0] in graph_inputs
        ]
        for match in match_results:
            ta_input_scatter = match.get_op("ta_input_scatter")
            # the 3rd input of scatter is the value
            data_input_id = ta_input_scatter.input[2]
            ta_read_node = match.get_op("ta_read")

            # TODO: check that ta's index variable is a scalar starting from 1 and increasing
            # by 1 each iteration; only then can we be sure this is equivalent to scan input behavior.
            index_input_id = ta_read_node.input[1]
            unstacked_ta_consumer = match.get_op("ta_read").output[0]
            ta = InputTensorArray(data_input_id, index_input_id,
                                  unstacked_ta_consumer, self.g)
            context.loop_properties.add_scan_input(ta)
Example #7
def rewrite_random_uniform(g, ops):
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', inputs=[
                OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
                OpTypePattern('Sub', name='input2', inputs=["Const|ConstV2", "Const|ConstV2"]),
            ]), None
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        input2 = match.get_op('input2')
        output = match.get_op('output')
        ru_op = match.get_op('input1')
        # max is on input 0
        tmax = input2.inputs[0].get_tensor_value()
        tmin = input2.inputs[1].get_tensor_value()
        to_delete = list(set(match.get_nodes()))
        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output,
                                                 to_delete)
        g.replace_all_inputs(output.output[0], new_node.output[0], ops=ops)
        g.safe_remove_nodes(to_delete)

    return ops
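
The subgraph being folded computes u * (tmax - tmin) + tmin, which maps RandomUniform samples from [0, 1) onto [tmin, tmax). A quick numpy check of that identity (illustration only, not tf2onnx code):

import numpy as np

rng = np.random.default_rng(0)
u = rng.random(10_000)              # stands in for the RandomUniform output in [0, 1)
tmin, tmax = -3.0, 5.0
scaled = u * (tmax - tmin) + tmin   # the Mul/Sub/Add subgraph the rewriter removes
assert scaled.min() >= tmin and scaled.max() < tmax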
Example #8
def rewrite_random_uniform_fold_const(g, ops):
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='mul', inputs=[
                OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
                "Const|ConstV2",
            ]),
            "Const|ConstV2",
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        mul = match.get_op('mul')
        ru_op = match.get_op('input1')

        tmax_minus_tmin = mul.inputs[1].get_tensor_value()
        tmin = output.inputs[1].get_tensor_value()
        tmax = tmin + tmax_minus_tmin
        to_delete = list(set(match.get_nodes()))
        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output,
                                                 to_delete)
        g.replace_all_inputs(output.output[0], new_node.output[0], ops=ops)
        g.safe_remove_nodes(to_delete)

    return ops
Example #9
    def test_rewrite_subgraph(self):
        graph_proto = self.sample_net()
        g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
        pattern = \
            OpTypePattern('Abs', name='output', inputs=[
                OpTypePattern('Add', name='input')
            ])
        ops = g.get_nodes()
        matcher = GraphMatcher(pattern)
        match_results = list(matcher.match_ops(ops))
        for match in match_results:
            input_node = match.get_op('input')
            output_node = match.get_op('output')
            op_name = utils.make_name("ReplacedOp")
            out_name = utils.port_name(op_name)
            new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
            g.replace_all_inputs(output_node.output[0], new_node.output[0])  # ops=ops
            for n in set(match.get_nodes()):
                g.remove_node(n.name)
        g.topological_sort(ops)
        result = onnx_to_graphviz(g)
        expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] ' \
                   'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__5 [op_type=Sub] ' \
                   'n6 [op_type=Identity] n5_graph_outputs_Identity__3 [op_type=Identity] ' \
                   'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__5 n3:0 -> ReplacedOp__5 ' \
                   'ReplacedOp__5:0 -> n6 ReplacedOp__5:0 -> n5_graph_outputs_Identity__3 }'
        self.assertEqual(expected, result)
Example #10
def rewrite_random_normal(g, ops):
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "*"
            ]), "*"
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        mean = output.inputs[1].get_tensor_value()
        dtype = g.get_dtype(output.output[0])
        op_name = utils.make_name("RandomNormal")
        out_name = port_name(op_name)

        rn_op = match.get_op('input1')
        if rn_op.inputs[0].type == "Shape":
            shape_node = rn_op.inputs[0]
            new_node = g.make_node("RandomNormalLike", [shape_node.input[0]], outputs=[out_name], name=op_name,
                                   attr={"mean": mean, "scale": 1.0, "dtype": dtype})
        else:
            shape = g.get_shape(output.output[0])
            new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name,
                                   attr={"shape": shape, "mean": mean, "scale": 1.0, "dtype": dtype})

        g.replace_all_inputs(ops, output.output[0], new_node.output[0])
        for n in set(match.get_nodes()):
            g.remove_node(n.name)
    return ops
Example #11
def rewrite_quantize_and_dequantize(g, ops):
    pattern_for_qdq_v2 = \
        OpTypePattern('QuantizeAndDequantizeV2', name='output', inputs=[
            OpTypePattern("*"),
            OpTypePattern(None),
            OpTypePattern(None),
        ])
    pattern_for_qdq_v3 = \
        OpTypePattern('QuantizeAndDequantizeV3', name='output', inputs=[
            OpTypePattern("*"),
            OpTypePattern(None),
            OpTypePattern(None),
            OpTypePattern(None),
        ])

    # Match all the patterns for QDQ ops
    patterns = [pattern_for_qdq_v3, pattern_for_qdq_v2]
    match_results = []
    for pattern in patterns:
        matcher = GraphMatcher(pattern)
        results = list(matcher.match_ops(ops))
        match_results.extend(results)

    return create_qdq_nodes(g, match_results)
Example #12
def rewrite_ragged_variant_shape(g, ops):
    pattern1 = \
        OpTypePattern('Shape', name='shape', inputs=[
            OpTypePattern('RaggedTensorToVariant', name='raggedtovariant')
        ])

    pattern_list = [pattern1]
    for pattern in pattern_list:
        matcher = GraphMatcher(pattern)
        match_results = list(matcher.match_ops(ops))
        for match in match_results:
            shape = match.get_op('shape')
            raggedtovariant = match.get_op('raggedtovariant')
            if raggedtovariant.get_attr_value("batched_input") != 1:
                continue
            if raggedtovariant.get_attr_value("RAGGED_RANK") != 1:
                continue
            # Shape of batched variant from ragged is same as number of splits minus 1
            g.replace_inputs(shape, [raggedtovariant.input[0]])
            np_dtype = utils.map_onnx_to_numpy_type(
                g.get_dtype(shape.output[0]))
            const_one = g.make_const(utils.make_name("const_one"),
                                     np.array(1, np_dtype)).output[0]
            g.insert_new_node_on_output("Sub",
                                        shape.output[0],
                                        inputs=[shape.output[0], const_one])

    return ops
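
The rewrite leans on a row-splits invariant: for a batched ragged tensor of ragged rank 1, the number of rows is len(row_splits) - 1, which is why pointing Shape at the splits input and subtracting one is safe. A toy illustration:

import numpy as np

row_splits = np.array([0, 2, 5, 7])   # three ragged rows: [0:2], [2:5], [5:7]
num_rows = row_splits.shape[0] - 1    # what the rewritten Shape -> Sub(const_one) computes
assert num_rows == 3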
Example #13
def rewrite_leakyrelu(g, ops):
    if g.opset < 6:
        return ops

    pattern = \
        OpTypePattern('Maximum', name='max', inputs=[
            OpTypePattern('Mul', name='mul', inputs=[
                OpTypePattern('Const', name='alpha'),
                OpTypePattern('*', name='mul_input'),
            ]),
            OpTypePattern('*', name='max_input'),
        ])

    matcher = GraphMatcher(pattern, allow_reorder=True)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        max_node = match.get_op('max')
        mul_node = match.get_op("mul")

        max_input_edge_name = match.get_tensor('max_input')
        mul_input_edge_name = match.get_tensor('mul_input')
        if max_input_edge_name == mul_input_edge_name:
            alpha = match.get_op("alpha").get_tensor_value()
            if alpha >= 1:
                continue
            leakyrelu = g.make_node("LeakyRelu", inputs=[max_input_edge_name], attr={"alpha": alpha},
                                    shapes=[g.get_shape(max_node.output[0])], dtypes=[g.get_dtype(max_node.output[0])])
            ops.append(leakyrelu)
            g.replace_all_inputs(max_node.output[0], leakyrelu.output[0], ops=ops)
            to_delete = [max_node, mul_node]
            g.safe_remove_nodes(to_delete)

    return ops
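
The alpha >= 1 guard exists because max(x, alpha * x) equals LeakyRelu(x, alpha) only for alpha < 1. A quick numpy check of the identity (illustration only):

import numpy as np

x = np.linspace(-4.0, 4.0, 9)
alpha = 0.2
fused = np.maximum(x, alpha * x)        # the Maximum/Mul subgraph being matched
leaky = np.where(x >= 0, x, alpha * x)  # LeakyRelu semantics
assert np.allclose(fused, leaky)        # holds for any 0 <= alpha < 1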
Example #14
def rewrite_dropout(g, ops):
    pattern = \
        OpTypePattern('Mul', name='outputs', inputs=[
            OpTypePattern('RealDiv', name="input2"),
            OpTypePattern('Floor', inputs=[
                OpTypePattern('Add', inputs=[
                    OpTypePattern(None, name="input3"),
                    OpTypePattern('RandomUniform|RandomUniformLike'),
                ])
            ]),
        ])
    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        inputs2 = match.get_op('input2')
        outputs = match.get_op('outputs')
        op_name = utils.make_name("Dropout")
        out_name = port_name(op_name)
        new_node = g.make_node("Dropout", [inputs2.input[0]],
                               outputs=[out_name],
                               name=op_name,
                               attr={"ratio": 1.0},
                               shapes=[g.get_shape(inputs2.input[0])],
                               dtypes=[g.get_dtype(inputs2.input[0])])
        g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
        g.safe_remove_nodes(match.get_nodes())

    # remove dropout if its ratio is 1.0
    for node in g.get_nodes():
        if node.type == "Dropout" and node.get_attr("ratio").f == 1.0:
            g.replace_all_inputs(g.get_nodes(), node.output[0], node.input[0])
            g.remove_node(node.name)

    return ops
Example #15
    def _parse_input_ta(self, context):
        matcher = GraphMatcher(self.rnn_input_pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(self.g.get_nodes()))
        match_results = [
            r for r in match_results
            if r.get_op("ta_input_scatter").name.startswith(context.rnn_scope)
        ]
        for match in match_results:
            ta_input_scatter = match.get_op("ta_input_scatter")
            # the 3rd input of scatter is the value
            input_ta = TensorArrayProp()

            # dynamic_rnn specific approach.
            input_ta.data_input_id = ta_input_scatter.input[2]

            ta_read_node = match.get_op("ta_read")
            input_ta.index_input_id = ta_read_node.input[1]
            input_ta.output_id = match.get_op("ta_read").output[0]

            input_shape = self.g.get_shape(input_ta.data_input_id)
            output_shape = self.g.get_shape(input_ta.output_id)
            if output_shape is None and input_shape is not None:
                self.g.set_shape(input_ta.output_id, input_shape[1:])

            context.input_tas.append(input_ta)

            log.debug(
                "input ta %s - data input (%s) shape: %s, output (%s) shape: %s",
                ta_read_node.name, input_ta.data_input_id,
                self.g.get_shape(input_ta.data_input_id), input_ta.output_id,
                self.g.get_shape(input_ta.output_id))
Example #16
def rewrite_random_normal(g, ops):
    pattern1 = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "Const|ConstV2"
            ]), "Const|ConstV2"
        ])

    pattern2 = \
        OpTypePattern('Identity', name='output', inputs=[
            OpTypePattern('Identity', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"])
            ])
        ])

    pattern_list = [pattern1, pattern2]
    for pattern in pattern_list:
        matcher = GraphMatcher(pattern)
        match_results = list(matcher.match_ops(ops))
        for match in match_results:
            output = match.get_op('output')
            if output.type == 'Add':
                # pattern 1
                mean = output.inputs[1].get_tensor_value()
            else:
                # pattern 2
                mean = 0.0
            input2 = match.get_op('input2')
            if input2.type == 'Mul':
                scale = input2.inputs[1].get_tensor_value()
            else:
                scale = 1.0
            dtype = g.get_dtype(output.output[0])
            op_name = utils.make_name("RandomNormal")
            out_name = utils.port_name(op_name)

            rn_op = match.get_op('input1')
            seed = float(rn_op.get_attr('seed2').i)

            attr = {"mean": mean, "scale": scale, "dtype": dtype, "seed": seed}
            if rn_op.inputs[0].type == "Shape":
                shape_node = rn_op.inputs[0]
                new_node = g.make_node("RandomNormalLike",
                                       [shape_node.input[0]],
                                       outputs=[out_name],
                                       name=op_name,
                                       attr=attr)
            else:
                shape = g.get_shape(output.output[0])
                if shape is None or -1 in shape:
                    continue
                attr['shape'] = shape
                new_node = g.make_node("RandomNormal", [],
                                       outputs=[out_name],
                                       name=op_name,
                                       attr=attr)

            g.replace_all_inputs(output.output[0], new_node.output[0], ops=ops)
            g.safe_remove_nodes(match.get_nodes())
    return ops
Example #17
def rewrite_depthwise_conv_dilations(g, ops):
    pattern1 = \
        OpTypePattern("BatchToSpaceND", name="batch_to_space", inputs=[
            OpTypePattern("DepthwiseConv2dNative", name="depthwise_conv", inputs=[
                OpTypePattern("SpaceToBatchND", name="space_to_batch", inputs=[
                    OpTypePattern("*"),
                    OpTypePattern("Const|ConstV2"),
                    OpTypePattern("Const|ConstV2"),
                ]),
                OpTypePattern("*"),
            ]),
            OpTypePattern("Const|ConstV2"),
            OpTypePattern("Const|ConstV2"),
        ])

    for pattern in [pattern1]:
        matcher = GraphMatcher(pattern, allow_reorder=False)
        match_results = list(matcher.match_ops(ops))
        for match_result in match_results:
            space_to_batch = match_result.get_op("space_to_batch")
            depthwise_conv = match_result.get_op("depthwise_conv")
            batch_to_space = match_result.get_op("batch_to_space")

            block_shape1 = space_to_batch.inputs[1].get_tensor_value(
                as_list=True)
            paddings = space_to_batch.inputs[2].get_tensor_value(
                as_list=False).flatten().tolist()
            block_shape2 = batch_to_space.inputs[1].get_tensor_value(
                as_list=True)
            crops = batch_to_space.inputs[2].get_tensor_value(as_list=True)
            if block_shape1 != block_shape2:
                continue
            if depthwise_conv.get_attr_value("dilations",
                                             [1, 1, 1, 1]) != [1, 1, 1, 1]:
                continue
            if depthwise_conv.get_attr_value("strides",
                                             [1, 1, 1, 1]) != [1, 1, 1, 1]:
                continue
            if depthwise_conv.get_attr_value("data_format",
                                             b"NHWC") != b"NHWC":
                continue
            if depthwise_conv.get_attr_value("padding") != b"VALID":
                continue
            if crops != [[0, 0], [0, 0]]:
                continue

            inp = space_to_batch.input[0]
            kernel = depthwise_conv.input[1]

            g.replace_inputs(depthwise_conv, [inp, kernel])
            depthwise_conv.set_attr("dilations", [1] + block_shape1 + [1])
            depthwise_conv.set_attr("explicit_paddings",
                                    [0, 0] + paddings + [0, 0])
            depthwise_conv.set_attr("padding", "EXPLICIT")
            g.copy_shape(batch_to_space.output[0], depthwise_conv.output[0])
            g.replace_all_inputs(batch_to_space.output[0],
                                 depthwise_conv.output[0])

    return g.get_nodes()
Example #18
    def rewrite_test(g, ops):
        pattern = \
            OpTypePattern('Add', name='op', inputs=["*", "*"])
        ops = g.get_nodes()
        matcher = GraphMatcher(pattern)
        match_results = list(matcher.match_ops(ops))
        for match in match_results:
            op = match.get_op('op')
            op.type = "Mul"
        return ops
Example #19
    def find_sequence_length_node(self, context):
        # get any state variable
        state_variable = list(context.state_variables.values())[0]
        next_iter_input_node = self.g.get_node_by_output(state_variable.next_iteration_input.id)
        if not is_select_op(next_iter_input_node):
            log.debug("no sequence length node is given")
            return None
        matcher = GraphMatcher(seq_len_pattern)
        match_result = matcher.match_op(next_iter_input_node)
        if not match_result:
            raise RuntimeError("failed to find sequence length.")
        return match_result.get_op("seq_len_node")
Example #20
def rewrite_conv2d_with_pad(g, ops):
    pattern = \
        OpTypePattern("Conv2D", name="conv", inputs=[
            OpTypePattern("Pad", name="pad"),
            OpTypePattern("*")
        ])
    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        conv = match.get_op("conv")
        pad = match.get_op("pad")
        paddings = pad.inputs[1]

        if not paddings.is_const():
            continue
        mode = pad.get_attr("mode")
        if mode:
            mode = mode.s.decode("utf-8").lower()
        if mode not in [None, "constant"] or len(pad.input) >= 3:
            continue
        # Conv2D already has a pad
        if conv.get_attr("padding").s.decode("utf-8") == "SAME":
            continue

        logger.debug("merge pad [%s] into conv [%s]", pad.name, conv.name)
        paddings_val = np.array(paddings.get_tensor_value())
        # can't pad on batch or channel dimensions
        data_format = conv.get_attr("data_format").s.decode("utf-8")
        if data_format == "NHWC":
            if np.any(paddings_val[0]) or np.any(paddings_val[3]):
                continue
            paddings_val = paddings_val[1:3]
        else:
            if np.any(paddings_val[0]) or np.any(paddings_val[1]):
                continue
            paddings_val = paddings_val[2:4]

        paddings_val = paddings_val.transpose().flatten()
        g.replace_input(conv, conv.input[0], pad.input[0])
        # convert Conv2D
        conv.type = "Conv2D"
        func, _ = handler.tf_op.find_effective_op("Conv2D")
        func(g, conv)
        conv.skip_conversion = True
        conv.set_attr("auto_pad", "NOTSET")
        conv.set_attr("pads", paddings_val)
    return ops
Example #21
def rewrite_tfl_qdq(g, ops):
    pattern0 = \
        OpTypePattern('TFL_DEQUANTIZE', name='dequant', inputs=[
            OpTypePattern('TFL_QUANTIZE', name='quant'),
        ])

    matcher = GraphMatcher(pattern0, allow_reorder=False)
    match_results = list(matcher.match_ops(ops))
    if match_results:
        for match in match_results:
            dequant = match.get_op("dequant")
            quant = match.get_op("quant")
            inp_node = quant.inputs[0]
            # skip pairs whose quantization parameters don't match
            if any(dequant.get_attr_value(k) != quant.get_attr_value(k)
                   for k in ["scale", "quantized_dimension", "zero_point"]):
                continue
            needed_relu = None
            if all(k in quant.attr and len(quant.get_attr_value(k)) == 1 for k in ["min", "max"]):
                min_val = quant.get_attr_value("min")[0]
                max_val = quant.get_attr_value("max")[0]
                if min_val == 0.0 and 5.999 <= max_val <= 6.0:
                    needed_relu = "TFL_RELU6"
                elif min_val == 0.0:
                    # This may introduce unneeded relu ops but will be correct.
                    # If the --dequantize feature is used a lot in the future we can optimize this.
                    needed_relu = "TFL_RELU"
                if inp_node.type == needed_relu:
                    # If it's really obviously unneeded, we skip it.
                    needed_relu = None
                elif "TFL_" + inp_node.get_attr_value("fused_activation_function", b'').decode() == needed_relu:
                    needed_relu = None

            if needed_relu is not None:
                relu_name = inp_node.name + "_relu"

                relu6 = g.make_node(needed_relu, [quant.input[0]], op_name_scope=relu_name,
                                    skip_conversion=False, shapes=quant.output_shapes, dtypes=quant.output_dtypes)
                g.replace_all_inputs(dequant.output[0], relu6.output[0])
            else:
                g.replace_all_inputs(dequant.output[0], quant.input[0])

            g.remove_node(dequant.name)
            if len(g.find_output_consumers(quant.output[0])) == 0:
                g.remove_node(quant.name)

    return ops
Example #22
def rewrite_tfl_rfft(g, ops):
    pattern0 = \
        OpTypePattern('TFL_COMPLEX_ABS', name='complex_abs', inputs=[
            OpTypePattern('TFL_RESHAPE', name='reshape', inputs=[
                OpTypePattern('TFL_RFFT2D', name='rfft2d', inputs=[
                    OpTypePattern('*'),
                    OpTypePattern('Const|ConstV2', name='length'),
                ]),
                OpTypePattern('Const|ConstV2', name='shape'),
            ], allow_reorder=True),
        ])

    matcher = GraphMatcher(pattern0, allow_reorder=False)
    match_results = list(matcher.match_ops(ops))
    if match_results:
        for match in match_results:
            length = match.get_op("length").get_tensor_value(as_list=True)
            rfft2d = match.get_op("rfft2d")
            complex_abs = match.get_op("complex_abs")
            reshape = match.get_op("reshape")
            shape = match.get_op("shape").get_tensor_value(as_list=True)
            output_shape = g.get_shape(rfft2d.output[0])

            if output_shape is None or output_shape != shape[:-1] + [1, shape[-1]]:
                continue
            if length[0] != 1:
                continue

            rfft2d.type = "RFFT"
            g.copy_shape(complex_abs.input[0], rfft2d.output[0])
            # Skip the Reshape
            g.replace_input(complex_abs, complex_abs.input[0],
                            rfft2d.output[0], 0)

            new_length = g.make_const(utils.make_name("rfft_length"),
                                      np.array([length[1]], np.int64))
            g.replace_input(rfft2d, rfft2d.input[1], new_length.output[0], 1)

            g.replace_all_inputs(complex_abs.output[0], reshape.output[0])
            # Move reshape below complex abs
            g.replace_input(reshape, reshape.input[0], complex_abs.output[0],
                            0)

    return ops
Example #23
    def _match_cell(self, context, unittype):
        """match unit cell"""
        for cell_pattern in get_pattern(unittype):
            matcher = GraphMatcher(cell_pattern, allow_reorder=True)

            loop_props = context.loop_properties
            inputs = loop_props.state_inputs + loop_props.scan_inputs
            input_ids = [input_tensor_value_info.id for input_tensor_value_info in inputs]
            outputs = loop_props.state_outputs + loop_props.scan_outputs
            output_ids = [out_tensor_value_info.id for out_tensor_value_info in outputs]
            body_graph_ops, _, _ = LoopRewriterBase.find_subgraph(
                set(input_ids),
                set(output_ids),
                self.g, merge_as_end=True
            )

            match_results = list(matcher.match_ops(body_graph_ops))
            if len(match_results) == 1:
                return match_results[0]
        return None
Example #24
def rewrite_transpose(g, ops):
    pattern = \
        OpTypePattern('Transpose', name='output', inputs=[
            OpTypePattern(None),
            OpTypePattern('Sub', inputs=[
                OpTypePattern('Sub', inputs=["*", "*"]),
                OpTypePattern('Range', inputs=["*", "*", "*"]),
            ]),
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        shape = g.get_shape(output.input[0])
        dims = [i for i in range(len(shape) - 1, -1, -1)]
        output.set_attr("perm", dims)
        g.remove_input(output, output.input[1])
        to_delete = [n for n in match.get_nodes() if n != output]
        g.safe_remove_nodes(to_delete)
    return ops
Example #25
def rewrite_biasadd_with_conv2d(g, ops):
    pattern1 = \
        OpTypePattern('BiasAdd', name='biasadd', inputs=[
            OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*']), '*'])
    pattern2 = \
        OpTypePattern('BiasAdd', name='biasadd', inputs=[
            OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=[
                '*', '*', '*']), '*'], allow_reorder=True)

    for pattern in [pattern1, pattern2]:
        matcher = GraphMatcher(pattern)
        match_results = list(matcher.match_ops(ops))
        for match in match_results:
            biasadd = match.get_op('biasadd')
            conv = match.get_op('conv')

            # Backup the conv and biasadd values
            conv_type = conv.type
            conv_input = conv.input
            conv_attr = conv.attr
            dtype = g.get_dtype(conv.output[0])
            shape = g.get_shape(conv.output[0])
            conv_name = biasadd.name
            conv_output = biasadd.output
            if pattern == pattern2:
                conv_inputs = [conv_input[0], conv_input[1], conv_input[2], biasadd.input[1]]
            else:
                conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]

            if len(g.find_output_consumers(conv.output[0])) > 1:
                continue
            # Remove the Conv and BiasAdd node
            g.remove_node(conv.name)
            g.remove_node(biasadd.name)

            g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
                        shapes=[shape], dtypes=[dtype], skip_conversion=False)
    return ops
Example #26
def rewrite_random_uniform(g, ops):
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', inputs=[
                OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
                OpTypePattern('Sub', name='input2', inputs=["*", "*"]),
            ]), None
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        input2 = match.get_op('input2')
        output = match.get_op('output')
        ru_op = match.get_op('input1')
        # max is on input 0
        tmax = input2.inputs[0].get_tensor_value()[0]
        tmin = input2.inputs[1].get_tensor_value()[0]

        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output)
        ops = g.replace_subgraph(ops, match, [], [output], [], [new_node])

    return ops
Example #27
def rewrite_random_uniform_fold_const(g, ops):
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='mul', inputs=[
                OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
                None,
            ]),
            None,
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        mul = match.get_op('mul')
        ru_op = match.get_op('input1')

        tmax_minus_tmin = mul.inputs[1].get_tensor_value()[0]
        tmin = output.inputs[1].get_tensor_value()[0]
        tmax = tmin + tmax_minus_tmin
        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output)
        ops = g.replace_subgraph(ops, match, [], [output], [], [new_node])

    return ops
Example #28
def rewrite_tfl_select_zero_mul(g, ops):
    pattern0 = \
        OpTypePattern('TFL_SELECT_V2', name='select', inputs=[
            OpTypePattern('TFL_EQUAL', name='equal', inputs=[
                OpTypePattern('Const|ConstV2', name='const_eq'),
                OpTypePattern('*', name='term_eq'),
            ], allow_reorder=True),
            OpTypePattern('Const|ConstV2', name='const_select'),
            OpTypePattern('TFL_MUL', name='mul', inputs=[
                OpTypePattern('*', name='term_mul1'),
                OpTypePattern('*', name='term_mul2'),
            ]),
        ])

    matcher = GraphMatcher(pattern0, allow_reorder=False)
    match_results = list(matcher.match_ops(ops))
    if match_results:
        for match in match_results:
            select = match.get_op("select")
            term_eq = match.get_op("term_eq")
            const_select = match.get_op("const_select")
            const_eq = match.get_op("const_eq")
            term_mul1 = match.get_op("term_mul1")
            term_mul2 = match.get_op("term_mul2")
            if const_select.get_tensor_value(as_list=True) != 0:
                continue
            if const_eq.get_tensor_value(as_list=True) != 0:
                continue
            if term_mul1.name != term_eq.name:
                term_mul1, term_mul2 = term_mul2, term_mul1
            if term_mul1.name != term_eq.name:
                continue
            # Tell downstream conversion to avoid Mul/Add optimization
            select.set_attr("handles_nan", True)

    return ops
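
The handles_nan attribute is set because this Select masks products that may be non-finite: folding select(x == 0, 0, x * y) into a plain Mul would leak NaN wherever x is 0 and y is inf, since 0 * inf is NaN. A numpy illustration of the difference:

import numpy as np

x = np.array([0.0, 2.0])
y = np.array([np.inf, 3.0])
with np.errstate(invalid="ignore"):        # 0 * inf would otherwise raise a numpy warning
    folded = x * y                         # naive Mul: [nan, 6.]
    masked = np.where(x == 0, 0.0, x * y)  # what the matched SELECT_V2 computes: [0., 6.]
assert np.isnan(folded[0]) and masked[0] == 0.0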
Example #29
def rewrite_dropout(g, ops):
    patterns = [
        OpTypePattern(
            'Mul',
            name='outputs',
            inputs=[
                OpTypePattern('RealDiv', name="input2"),
                OpTypePattern(
                    'Floor',
                    inputs=[
                        OpTypePattern(
                            'Add',
                            inputs=[
                                OpTypePattern("*", name="input3"),
                                OpTypePattern(
                                    'RandomUniform|RandomUniformLike'),
                            ])
                    ]),
            ]),
        OpTypePattern(
            "Mul",
            name="outputs",
            inputs=[
                OpTypePattern("Mul", name="input2"),
                OpTypePattern(
                    "Cast",
                    inputs=[
                        OpTypePattern(
                            "GreaterEqual",
                            inputs=[
                                OpTypePattern(
                                    "RandomUniform|RandomUniformLike"),
                                OpTypePattern("*", name="input3")
                            ])
                    ])
            ]),
        # pattern for tf-2.0 tf.nn.dropout()
        OpTypePattern(
            "Mul",
            name="outputs",
            inputs=[
                OpTypePattern(
                    "Cast",
                    inputs=[
                        OpTypePattern(
                            "GreaterEqual",
                            inputs=[
                                OpTypePattern(
                                    "RandomUniform|RandomUniformLike"),
                                OpTypePattern("*", name="input3")
                            ])
                    ]),
                OpTypePattern("Mul", name="input2"),
            ]),
    ]
    for pattern in patterns:
        matcher = GraphMatcher(pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(ops))
        for match in match_results:
            input2 = match.get_op('input2')
            input3 = match.get_op('input3')
            outputs = match.get_op('outputs')

            if not input3.is_scalar():
                logger.warning(
                    "Dropout pattern rooted at %s does not have a "
                    "constant ratio and cannot be replaced.", outputs.name)
                continue
            ratio = input3.get_tensor_value()

            if input2.inputs[0].is_scalar():
                data = input2.inputs[1]
                scaling_constant = input2.inputs[0].get_tensor_value()
            elif input2.inputs[1].is_scalar():
                data = input2.inputs[0]
                scaling_constant = input2.inputs[1].get_tensor_value()
            else:
                logger.warning(
                    "Could not find scaling constant for dropout pattern rooted at %s. "
                    "The pattern will not be replaced with an ONNX dropout node.",
                    outputs.name)
                continue

            # The scaling constant should be 1/(1-ratio); otherwise this isn't truly a dropout node
            if not np.allclose([1], [scaling_constant * (1 - ratio)]):
                logger.warning(
                    "Scaling constant %f for dropout pattern rooted at %s is inconsistent with dropout "
                    "ratio %f. The pattern will not be replaced with an ONNX dropout node.",
                    scaling_constant, outputs.name, ratio)
                continue

            nodes_to_remove = [
                n for n in match.get_nodes() if n.name != input3.name
            ]
            if not g.is_safe_to_remove_nodes(nodes_to_remove,
                                             [outputs.output[0]]):
                logger.warning(
                    "Nodes in dropout pattern rooted at %s cannot be removed because intermediate results "
                    "of some nodes are referenced elsewhere in graph.",
                    outputs.name)
                continue

            op_name = utils.make_name("Dropout")
            out_name = utils.port_name(op_name)
            new_node = g.make_node("Dropout",
                                   inputs=[data.output[0]],
                                   outputs=[out_name],
                                   name=op_name,
                                   attr={"ratio": ratio},
                                   shapes=[g.get_shape(data.output[0])],
                                   dtypes=[g.get_dtype(data.output[0])])
            g.replace_all_inputs(outputs.output[0],
                                 new_node.output[0],
                                 ops=ops)
            for n in nodes_to_remove:
                g.remove_node(n.name)

    return ops
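
The np.allclose guard encodes the inverted-dropout identity: TF keeps each element with probability 1 - ratio and scales survivors by 1 / (1 - ratio), so the expected value is unchanged. A numpy sanity check (illustration only):

import numpy as np

ratio = 0.25
scaling_constant = 1.0 / (1.0 - ratio)  # what the matched Mul/RealDiv applies
assert np.allclose([1], [scaling_constant * (1 - ratio)])

rng = np.random.default_rng(0)
x = rng.random(100_000)
mask = rng.random(x.shape) >= ratio     # mirrors GreaterEqual(RandomUniform, ratio)
dropped = x * mask * scaling_constant   # inverted dropout keeps the mean
assert abs(dropped.mean() - x.mean()) < 0.01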
Example #30
def rewrite_flatten(g, ops):
    pattern_fixed_shape_input = \
        OpTypePattern('Reshape', name='reshape', inputs=[
            OpTypePattern("*", name="input"),
            OpTypePattern('Pack', name="pack", inputs=[
                OpTypePattern('StridedSlice', name="slice", inputs=[
                    "*", "*", "*", "*",
                ]),
                "*",
            ]),
        ])
    pattern_non_fixed_shape_input = \
        OpTypePattern('Reshape', name='reshape', inputs=[
            OpTypePattern("*", name="input"),
            OpTypePattern('Pack', name="pack", inputs=[
                OpTypePattern('StridedSlice', name="slice", inputs=[
                    OpTypePattern('Shape', inputs=[
                        OpTypePattern("*", name="input2")
                    ]),
                    "*", "*", "*",
                ]),
                "*",
            ]),
        ])
    matcher = GraphMatcher(pattern_fixed_shape_input)
    match_results_1 = list(matcher.match_ops(ops))

    matcher = GraphMatcher(pattern_non_fixed_shape_input)
    match_results_2 = list(matcher.match_ops(ops))

    all_match_results = [(match_results_1, True), (match_results_2, False)]
    for match_results, check_fixed_input_shape in all_match_results:
        for match in match_results:
            input_node = match.get_op('input')
            reshape_node = match.get_op('reshape')
            pack_node = match.get_op('pack')
            slice_node = match.get_op('slice')
            need_rewrite = pack_node.inputs[1].is_const() and \
                           pack_node.inputs[1].get_tensor_value() == -1
            if not need_rewrite:
                continue

            input_shape = g.get_shape(reshape_node.input[0])
            need_rewrite = input_shape is not None
            if not need_rewrite:
                continue

            if check_fixed_input_shape:
                need_rewrite = slice_node.inputs[0].is_const() and \
                               np.array_equal(list(input_shape), list(slice_node.inputs[0].get_tensor_value()))
                if not need_rewrite:
                    continue

            begin = slice_node.inputs[1].get_tensor_value(as_list=False)
            end = slice_node.inputs[2].get_tensor_value(as_list=False)
            strides = slice_node.inputs[3].get_tensor_value(as_list=False)
            need_rewrite = np.array_equal(begin, [0]) and len(end) == 1 and \
                           np.array_equal(strides, [1]) and end[0] - begin[0] == 1
            if not need_rewrite:
                continue

            op_name = utils.make_name("Flatten")
            out_name = port_name(op_name)
            new_node = g.make_node("Flatten", [reshape_node.input[0]],
                                   outputs=[out_name],
                                   name=op_name)

            last_dim = input_shape[-1]
            sec_last_dim = input_shape[-2]
            if last_dim > 0 and sec_last_dim > 0:
                new_dim = last_dim * sec_last_dim
            else:
                new_dim = -1

            g.set_shape(out_name, input_shape[:-2] + [new_dim])
            g.replace_all_inputs(ops, reshape_node.output[0], out_name)
            to_delete = [n for n in match.get_nodes() if n != input_node]
            g.safe_remove_nodes(to_delete)

    return ops
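
The replacement is sound because ONNX Flatten with its default axis=1 reshapes to [batch, -1], which is exactly what Reshape(x, Pack(shape[0:1], -1)) requests; the shape bookkeeping then merges the last two dims when both are known. A numpy sketch of the 3-D case:

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
flattened = x.reshape(x.shape[0], -1)  # what Flatten(axis=1) produces
assert flattened.shape == (2, 3 * 4)   # last two dims merged, matching set_shape above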