Example #1
0
def rewrite_random_uniform(g, ops):
    """Collapse ``Add(Mul(RandomUniform, Sub(a, b)), <any>)`` into one node.

    Every match is handed to ``create_onnx_random_uniform_op`` together with
    the constant values of the two ``Sub`` operands; consumers of the matched
    ``Add`` are then redirected to the new node and the matched nodes removed.

    Args:
        g: graph being rewritten.
        ops: nodes that are searched for the pattern.

    Returns:
        The same ``ops`` object that was passed in.
    """
    random_uniform = OpTypePattern('RandomUniform', name='input1', inputs=["*"])
    value_range = OpTypePattern('Sub', name='input2', inputs=["*", "*"])
    scaled = OpTypePattern('Mul', inputs=[random_uniform, value_range])
    pattern = OpTypePattern('Add', name='output', inputs=[scaled, None])

    for match in list(GraphMatcher(pattern).match_ops(ops)):
        sub_node = match.get_op('input2')
        add_node = match.get_op('output')
        ru_op = match.get_op('input1')
        # The Sub node computes (max - min): max is input 0, min is input 1.
        tmax = sub_node.inputs[0].get_tensor_value()
        tmin = sub_node.inputs[1].get_tensor_value()
        # De-duplicate the matched nodes before handing them over for removal.
        to_delete = list(set(match.get_nodes()))
        replacement = create_onnx_random_uniform_op(
            g, tmax, tmin, ru_op, add_node, to_delete)
        g.replace_all_inputs(add_node.output[0], replacement.output[0], ops=ops)
        g.safe_remove_nodes(to_delete)

    return ops
Example #2
0
    def _parse_input_ta(self, context):
        """Register tensor-array reads inside the loop as scan inputs.

        Matches ``self.ta_read_input_pattern`` against all nodes of ``self.g``
        and keeps only matches whose ``ta_index`` op output is the
        ``switch_true_identity_output`` id of one of the loop variables. For
        each remaining match an ``InputTensorArray`` is built and added to
        ``context.loop_properties`` via ``add_scan_input``.
        """
        loop_props = context.loop_properties
        candidate_ids = (
            var.switch_true_identity_output.id
            for var in loop_props.all_variables.values()
        )
        # Drop falsy ids.
        graph_inputs = [tensor_id for tensor_id in candidate_ids if tensor_id]

        matcher = GraphMatcher(self.ta_read_input_pattern, allow_reorder=False)
        # Materialize the filtered matches before registering anything.
        indexed_reads = []
        for result in matcher.match_ops(self.g.get_nodes()):
            if result.get_op("ta_index").output[0] in graph_inputs:
                indexed_reads.append(result)

        for match in indexed_reads:
            scatter_node = match.get_op("ta_input_scatter")
            read_node = match.get_op("ta_read")

            # todo: need check ta's index variable is a scalar starting from 1, and increase by 1 each iteration.
            # then we can be sure this is equivalent to scan input behavior.
            scan_input = InputTensorArray(
                scatter_node.input[2],  # the 3rd input of scatter is the value
                read_node.input[1],     # index fed to the read
                read_node.output[0],    # tensor produced by the read
                self.g)
            loop_props.add_scan_input(scan_input)
Example #3
0
    def test_match_flipped(self):
        """A ``Mul(Add, Sub)`` pattern matches a ``Mul(Sub, Add)`` node once when reordering is allowed."""

        def float_2x2(tensor_name):
            # All graph inputs/outputs in this test are 2x2 float tensors.
            return helper.make_tensor_value_info(tensor_name, TensorProto.FLOAT,
                                                 [2, 2])

        sub_node = helper.make_node("Sub", ["i1", "i1"], ["n1:0"], name="n1")
        add_node = helper.make_node("Add", ["i2", "i2"], ["n2:0"], name="n2")
        # The Mul consumes Sub first and Add second, the reverse of the pattern.
        mul_node = helper.make_node("Mul", ["n1:0", "n2:0"], ["n3:0"], name="n3")

        graph_proto = helper.make_graph(
            nodes=[sub_node, add_node, mul_node],
            name="test",
            inputs=[float_2x2("i1"), float_2x2("i2")],
            outputs=[float_2x2("n2:0")],
            initializer=[])
        g = GraphUtil.create_graph_from_onnx_graph(graph_proto)

        flipped_inputs = [OpTypePattern('Add'), OpTypePattern('Sub')]
        pattern = OpTypePattern('Mul', inputs=flipped_inputs)
        matcher = GraphMatcher(pattern, allow_reorder=True)
        found = list(matcher.match_ops(g.get_nodes()))
        self.assertEqual(1, len(found))
Example #4
0
 def test_rewrite_subgraph(self):
     """Replace every ``Abs(Add(...))`` in the sample net with one ``Sub`` and check the rendered graph."""
     g = GraphUtil.create_graph_from_onnx_graph(self.sample_net())
     add_pattern = OpTypePattern('Add', name='input')
     pattern = OpTypePattern('Abs', name='output', inputs=[add_pattern])
     ops = g.get_nodes()

     for match in list(GraphMatcher(pattern).match_ops(ops)):
         add_node = match.get_op('input')
         abs_node = match.get_op('output')
         replacement_name = utils.make_name("ReplacedOp")
         replacement = g.make_node(
             "Sub",
             inputs=add_node.input,
             outputs=[utils.port_name(replacement_name)],
             name=replacement_name)
         # Redirect consumers of the Abs output to the replacement node.
         g.replace_all_inputs(abs_node.output[0], replacement.output[0])
         for matched in set(match.get_nodes()):
             g.remove_node(matched.name)

     g.topological_sort(ops)
     result = onnx_to_graphviz(g)
     expected = (
         'digraph { Placeholder__5 [op_type=Placeholder] n1 [op_type=Abs] '
         'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__6 [op_type=Sub] '
         'n6 [op_type=Identity] n5_graph_outputs_Identity__4 [op_type=Identity] '
         'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__6 n3:0 -> ReplacedOp__6 '
         'ReplacedOp__6:0 -> n6 ReplacedOp__6:0 -> n5_graph_outputs_Identity__4 }'
     )
     self.assertEqual(expected, result)
Example #5
0
def rewrite_quantize_and_dequantize(g, ops):
    """Find QuantizeAndDequantizeV2/V3 nodes and pass the matches to ``create_qdq_nodes``.

    Args:
        g: graph being rewritten.
        ops: nodes that are searched for the patterns.

    Returns:
        Whatever ``create_qdq_nodes(g, matches)`` returns. V3 matches come
        first in the list, followed by V2 matches.
    """

    def qdq_pattern(op_type, extra_input_count):
        # One wildcard input followed by `extra_input_count` inputs built
        # from OpTypePattern(None).
        inputs = [OpTypePattern("*")]
        for _ in range(extra_input_count):
            inputs.append(OpTypePattern(None))
        return OpTypePattern(op_type, name='output', inputs=inputs)

    pattern_for_qdq_v2 = qdq_pattern('QuantizeAndDequantizeV2', 2)
    pattern_for_qdq_v3 = qdq_pattern('QuantizeAndDequantizeV3', 3)

    # Match all the patterns for QDQ ops, V3 before V2.
    match_results = []
    for pattern in (pattern_for_qdq_v3, pattern_for_qdq_v2):
        match_results.extend(GraphMatcher(pattern).match_ops(ops))

    return create_qdq_nodes(g, match_results)
Example #6
0
def rewrite_random_uniform_fold_const(g, ops):
    """Collapse ``Add(Mul(RandomUniform, <any>), <any>)`` into one node.

    The second input of the ``Mul`` is read as ``max - min`` and the second
    input of the ``Add`` as ``min``; both are read with ``get_tensor_value``.
    The resulting bounds are passed to ``create_onnx_random_uniform_op``.

    Args:
        g: graph being rewritten.
        ops: nodes that are searched for the pattern.

    Returns:
        The same ``ops`` object that was passed in.
    """
    random_uniform = OpTypePattern('RandomUniform', name='input1', inputs=["*"])
    scaled = OpTypePattern('Mul', name='mul', inputs=[random_uniform, None])
    pattern = OpTypePattern('Add', name='output', inputs=[scaled, None])

    for match in list(GraphMatcher(pattern).match_ops(ops)):
        add_node = match.get_op('output')
        mul_node = match.get_op('mul')
        ru_op = match.get_op('input1')

        span = mul_node.inputs[1].get_tensor_value()
        tmin = add_node.inputs[1].get_tensor_value()
        # Recover the upper bound from the folded (max - min) constant.
        tmax = tmin + span

        to_delete = list(set(match.get_nodes()))
        replacement = create_onnx_random_uniform_op(
            g, tmax, tmin, ru_op, add_node, to_delete)
        g.replace_all_inputs(add_node.output[0], replacement.output[0], ops=ops)
        g.safe_remove_nodes(to_delete)

    return ops
def rewrite_biasadd_with_conv2d(g, ops):
    """Merge ``BiasAdd(Conv2D|Conv2DBackpropInput(x, w), b)`` into one conv node with three inputs.

    The replacement keeps the convolution's op type and attributes, takes the
    BiasAdd's name and outputs, and uses the shape and dtype recorded for the
    convolution's first output.

    Args:
        g: graph being rewritten.
        ops: nodes that are searched for the pattern.

    Returns:
        The same ``ops`` object that was passed in.
    """
    conv_pattern = OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv',
                                 inputs=['*', '*'])
    pattern = OpTypePattern('BiasAdd', name='biasadd', inputs=[conv_pattern, '*'])

    for match in list(GraphMatcher(pattern).match_ops(ops)):
        biasadd = match.get_op('biasadd')
        conv = match.get_op('conv')

        # Capture everything needed from both nodes before they are removed.
        fused_type = conv.type
        fused_attr = conv.attr
        conv_out = conv.output[0]
        fused_dtype = g.get_dtype(conv_out)
        fused_shape = g.get_shape(conv_out)
        fused_name = biasadd.name
        fused_outputs = biasadd.output
        fused_inputs = [conv.input[0], conv.input[1], biasadd.input[1]]

        # Remove the Conv and BiasAdd node
        for stale in (conv, biasadd):
            g.remove_node(stale.name)

        g.make_node(fused_type,
                    fused_inputs,
                    attr=fused_attr,
                    name=fused_name,
                    outputs=fused_outputs,
                    shapes=[fused_shape],
                    dtypes=[fused_dtype],
                    skip_conversion=False)
    return ops
Example #8
0
def rewrite_tfl_qdq(g, ops):
    """Remove ``TFL_DEQUANTIZE(TFL_QUANTIZE(x))`` pairs, optionally inserting a relu.

    A pair is only rewritten when the ``scale``, ``quantized_dimension`` and
    ``zero_point`` attributes of both nodes are equal. When the quantize node
    carries single-element ``min``/``max`` attributes with ``min == 0``, a
    ``TFL_RELU6`` (``max`` within [5.999, 6.0]) or ``TFL_RELU`` node is placed
    on the quantize input, unless the producer of that input already is that
    op or names it in its ``fused_activation_function`` attribute. Consumers
    of the dequantize output are redirected to the relu (or straight to the
    quantize input), the dequantize node is removed, and the quantize node is
    removed once its output has no consumers left.

    Args:
        g: graph being rewritten.
        ops: nodes that are searched for the pattern.

    Returns:
        The same ``ops`` object that was passed in.
    """
    pattern0 = \
        OpTypePattern('TFL_DEQUANTIZE', name='dequant', inputs=[
            OpTypePattern('TFL_QUANTIZE', name='quant'),
        ])

    matcher = GraphMatcher(pattern0, allow_reorder=False)
    for match in list(matcher.match_ops(ops)):
        dequant = match.get_op("dequant")
        quant = match.get_op("quant")
        inp_node = quant.inputs[0]

        # Skip the whole match when the quantization parameters disagree.
        # (A bare `continue` inside a `for k in ...` loop would only advance
        # that inner loop and never skip the match.)
        if any(dequant.get_attr_value(k) != quant.get_attr_value(k)
               for k in ["scale", "quantized_dimension", "zero_point"]):
            continue

        needed_relu = None
        if all(k in quant.attr and len(quant.get_attr_value(k)) == 1
               for k in ["min", "max"]):
            min_val = quant.get_attr_value("min")[0]
            max_val = quant.get_attr_value("max")[0]
            if min_val == 0.0 and 5.999 <= max_val <= 6.0:
                needed_relu = "TFL_RELU6"
            elif min_val == 0.0:
                # This may introduce unneeded relu ops but will be correct.
                # If the --dequantize feature is used a lot in the future we can optimize this.
                needed_relu = "TFL_RELU"
            if inp_node.type == needed_relu:
                # If it's really obviously unneeded, we skip it.
                needed_relu = None
            elif "TFL_" + inp_node.get_attr_value(
                    "fused_activation_function",
                    b'').decode() == needed_relu:
                # The producer already applies this activation as a fused op.
                needed_relu = None

        if needed_relu is not None:
            relu_name = inp_node.name + "_relu"
            relu = g.make_node(needed_relu, [quant.input[0]],
                               op_name_scope=relu_name,
                               skip_conversion=False,
                               shapes=quant.output_shapes,
                               dtypes=quant.output_dtypes)
            g.replace_all_inputs(dequant.output[0], relu.output[0])
        else:
            # No clamp needed: bypass the quantize/dequantize pair entirely.
            g.replace_all_inputs(dequant.output[0], quant.input[0])

        g.remove_node(dequant.name)
        # The quantize node may still feed other consumers; only drop it when unused.
        if len(g.find_output_consumers(quant.output[0])) == 0:
            g.remove_node(quant.name)

    return ops
def rewrite_conv2d_with_pad(g, ops):
    """Fold a constant ``Pad`` feeding a ``Conv2D`` into the convolution's ``pads`` attribute.

    A match is left untouched when the paddings input is not constant, the
    pad mode is anything other than unset/"constant", the Pad has three or
    more inputs, the convolution uses ``SAME`` padding, or the paddings touch
    the batch or channel dimension for the convolution's ``data_format``.

    Args:
        g: graph being rewritten.
        ops: nodes that are searched for the pattern.

    Returns:
        The same ``ops`` object that was passed in.
    """
    pattern = OpTypePattern("Conv2D", name="conv", inputs=[
        OpTypePattern("Pad", name="pad"),
        OpTypePattern("*"),
    ])

    for match in list(GraphMatcher(pattern).match_ops(ops)):
        conv = match.get_op("conv")
        pad = match.get_op("pad")
        paddings = pad.inputs[1]
        if not paddings.is_const():
            continue

        mode_attr = pad.get_attr("mode")
        mode = mode_attr.s.decode("utf-8").lower() if mode_attr else mode_attr
        is_plain_constant_pad = mode in [None, "constant"] and len(pad.input) < 3
        if not is_plain_constant_pad:
            continue

        # Conv2D already has a pad
        if conv.get_attr("padding").s.decode("utf-8") == "SAME":
            continue

        logger.debug("merge pad [%s] into conv [%s]", pad.name, conv.name)
        paddings_val = np.array(paddings.get_tensor_value())

        # can't pad on batch or channel dimensions
        if conv.get_attr("data_format").s.decode("utf-8") == "NHWC":
            non_spatial_axes, spatial_axes = (0, 3), slice(1, 3)
        else:
            non_spatial_axes, spatial_axes = (0, 1), slice(2, 4)
        if any(np.any(paddings_val[axis]) for axis in non_spatial_axes):
            continue

        # Rows of (begin, end) per spatial dim become all begins followed by all ends.
        pads = paddings_val[spatial_axes].transpose().flatten()

        g.replace_input(conv, conv.input[0], pad.input[0], 0)
        # convert Conv2D
        conv.type = "Conv2D"
        func, _ = handler.tf_op.find_effective_op("Conv2D")
        func(g, conv)
        conv.skip_conversion = True
        conv.set_attr("auto_pad", "NOTSET")
        conv.set_attr("pads", pads)
    return ops
Example #10
0
def rewrite_thresholded_relu(g, ops):
    """Replace ``Mul(Cast(Greater(x, theta)), x)`` with ``ThresholdedRelu(x, alpha=theta)``.

    Nothing is rewritten for opsets below 10. A match is only rewritten when
    the edge feeding ``Greater`` and the edge feeding ``Mul`` are the same
    tensor; ``theta`` is the value of the matched ``Const`` node.

    Args:
        g: graph being rewritten.
        ops: nodes that are searched for the pattern.

    Returns:
        The same ``ops`` object that was passed in.
    """
    if g.opset < 10:
        return ops

    greater_pattern = OpTypePattern('Greater', name='greater', inputs=[
        OpTypePattern('*', name='greater_input'),
        OpTypePattern('Const', name='theta'),
    ])
    cast_pattern = OpTypePattern('Cast', name='cast', inputs=[greater_pattern])
    pattern = OpTypePattern('Mul', name='mul', inputs=[
        cast_pattern,
        OpTypePattern('*', name='mul_input'),
    ])

    matcher = GraphMatcher(pattern, allow_reorder=True)
    for match in list(matcher.match_ops(ops)):
        greater_node = match.get_op('greater')
        mul_node = match.get_op("mul")
        cast_node = match.get_op('cast')

        compared_edge = _find_edge_name_between_nodes(
            match.get_op('greater_input'), greater_node)
        multiplied_edge = _find_edge_name_between_nodes(
            match.get_op('mul_input'), mul_node)
        # Both branches must consume the very same tensor.
        if compared_edge != multiplied_edge:
            continue

        theta = match.get_op('theta').get_tensor_value()
        mul_output = mul_node.output[0]
        thresholded_relu = g.make_node(
            "ThresholdedRelu",
            inputs=[multiplied_edge],
            attr={"alpha": theta},
            shapes=[g.get_shape(mul_output)],
            dtypes=[g.get_dtype(mul_output)])
        g.replace_all_inputs(mul_output, thresholded_relu.output[0], ops=ops)
        g.safe_remove_nodes([cast_node, mul_node])
    return ops
Example #11
0
def rewrite_leakyrelu(g, ops):
    """Replace ``Maximum(Mul(alpha, x), x)`` with ``LeakyRelu(x, alpha)``.

    Nothing is rewritten for opsets below 6. A match is only rewritten when
    the edge feeding ``Maximum`` and the edge feeding ``Mul`` are the same
    tensor and the constant ``alpha`` is below 1. The new node is appended
    to ``ops``.

    Args:
        g: graph being rewritten.
        ops: nodes that are searched for the pattern.

    Returns:
        The same ``ops`` object that was passed in.
    """
    if g.opset < 6:
        return ops

    scaled_pattern = OpTypePattern('Mul', name='mul', inputs=[
        OpTypePattern('Const', name='alpha'),
        OpTypePattern('*', name='mul_input'),
    ])
    pattern = OpTypePattern('Maximum', name='max', inputs=[
        scaled_pattern,
        OpTypePattern('*', name='max_input'),
    ])

    matcher = GraphMatcher(pattern, allow_reorder=True)
    for match in list(matcher.match_ops(ops)):
        max_node = match.get_op('max')
        mul_node = match.get_op("mul")

        max_edge = _find_edge_name_between_nodes(
            match.get_op('max_input'), max_node)
        mul_edge = _find_edge_name_between_nodes(
            match.get_op('mul_input'), mul_node)
        # Both branches must consume the very same tensor.
        if max_edge != mul_edge:
            continue

        alpha = match.get_op("alpha").get_tensor_value()
        if alpha >= 1:
            continue

        max_output = max_node.output[0]
        leakyrelu = g.make_node("LeakyRelu",
                                inputs=[max_edge],
                                attr={"alpha": alpha},
                                shapes=[g.get_shape(max_output)],
                                dtypes=[g.get_dtype(max_output)])
        ops.append(leakyrelu)
        g.replace_all_inputs(max_output, leakyrelu.output[0], ops=ops)
        g.safe_remove_nodes([max_node, mul_node])

    return ops
def rewrite_transpose(g, ops):
    """Give ``Transpose(x, Sub(Sub(*, *), Range(*, *, *)))`` a static ``perm`` attribute.

    For each match the ``perm`` attribute is set to the reversed dimension
    order of the first input, the second (permutation) input is removed from
    the Transpose, and the other matched nodes are passed to
    ``g.safe_remove_nodes``. Matches whose first input has no known shape are
    left untouched, since the rank is needed to build ``perm``.

    Args:
        g: graph being rewritten.
        ops: nodes that are searched for the pattern.

    Returns:
        The same ``ops`` object that was passed in.
    """
    pattern = \
        OpTypePattern('Transpose', name='output', inputs=[
            OpTypePattern(None),
            OpTypePattern('Sub', inputs=[
                OpTypePattern('Sub', inputs=["*", "*"]),
                OpTypePattern('Range', inputs=["*", "*", "*"]),
            ]),
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        shape = g.get_shape(output.input[0])
        if shape is None:
            # Rank unknown: skip instead of raising TypeError on len(None).
            continue
        # Reverse every axis: rank-1, ..., 1, 0.
        dims = range(len(shape) - 1, -1, -1)
        output.set_attr("perm", dims)
        g.remove_input(output, output.input[1], 1)
        to_delete = [n for n in match.get_nodes() if n != output]
        g.safe_remove_nodes(to_delete)
    return ops
    def _match_cell(self, context, unittype):
        """Return the single cell match inside the loop body, or ``None``.

        Each pattern from ``get_pattern(unittype)`` is tried in turn against
        the body sub-graph located by ``LoopRewriterBase.find_subgraph``
        between the loop's state/scan inputs and its state/scan outputs. The
        first pattern that yields exactly one match wins.
        """
        for cell_pattern in get_pattern(unittype):
            matcher = GraphMatcher(cell_pattern, allow_reorder=True)

            props = context.loop_properties
            boundary_inputs = {
                info.id for info in props.state_inputs + props.scan_inputs
            }
            boundary_outputs = {
                info.id for info in props.state_outputs + props.scan_outputs
            }
            body_graph_ops, _, _ = LoopRewriterBase.find_subgraph(
                boundary_inputs, boundary_outputs, self.g, merge_as_end=True)

            found = list(matcher.match_ops(body_graph_ops))
            # Zero or several matches: move on to the next pattern.
            if len(found) != 1:
                continue
            return found[0]
        return None
def rewrite_gemm(g, ops):
    """Replace ``MatMul`` followed by an add (optionally scaled) with a ``Gemm`` node.

    Nothing is rewritten for opset 6 or lower. Five patterns are tried in
    order (see the inline comments). A match is rewritten only when:

    * the first MatMul input is float32,
    * ``get_gemm_attr(match)`` reports the match as valid,
    * the shapes of C and of the MatMul output are known and contain no -1,
    * both MatMul inputs have rank 2,
    * C has no more dimensions than the MatMul output and every trailing
      dimension of C is either 1 or equal to the corresponding dimension of
      the MatMul output.

    The new node is appended to ``ops``.

    Args:
        g: graph being rewritten.
        ops: nodes that are searched for the patterns.

    Returns:
        The same ``ops`` object that was passed in.
    """
    if g.opset <= 6:
        return ops

    # pattern0: alpha*A*B + beta*C
    pattern0 = \
        OpTypePattern('Add|AddV2', name='add', inputs=[
            OpTypePattern('Mul', name='mul1', inputs=[
                OpTypePattern('Const', name='alpha'),
                OpTypePattern('MatMul', name='matmul')
            ]),
            OpTypePattern('Mul', name='mul2', inputs=[
                OpTypePattern('Const', name='beta'),
                OpTypePattern('*', name='C')
            ])
        ])

    # pattern1: alpha*A*B + C
    pattern1 = \
        OpTypePattern('Add|AddV2', name='add', inputs=[
            OpTypePattern('Mul', name='mul1', inputs=[
                OpTypePattern('MatMul', name='matmul'),
                OpTypePattern('Const', name='alpha')
            ]),
            OpTypePattern('*', name='C'),
        ])

    # pattern2: A*B + beta*C
    pattern2 = \
        OpTypePattern('Add|AddV2', name='add', inputs=[
            OpTypePattern('MatMul', name='matmul'),
            OpTypePattern('Mul', name='mul2', inputs=[
                OpTypePattern('Const', name='beta'),
                OpTypePattern('*', name='C')
            ])
        ])

    # pattern3: A*B + C
    pattern3 = \
        OpTypePattern('Add|AddV2', name='add', inputs=[
            OpTypePattern('MatMul', name='matmul'),
            OpTypePattern('*', name='C'),
        ])

    # pattern4: A*B + c
    pattern4 = \
        OpTypePattern('BiasAdd', name='add', inputs=[
            OpTypePattern('MatMul', name='matmul'),
            OpTypePattern('*', name='C'),
        ])

    pattern_list = [pattern0, pattern1, pattern2, pattern3, pattern4]

    for pattern in pattern_list:
        matcher = GraphMatcher(pattern, allow_reorder=True)
        for match in list(matcher.match_ops(ops)):
            matmul_node = match.get_op("matmul")

            if g.get_dtype(matmul_node.input[0]) != onnx_pb.TensorProto.FLOAT:
                logging.warning(u"For now, onnxruntime only support float32 type for Gemm rewriter")
                continue

            attr, is_valid = get_gemm_attr(match)
            if not is_valid:
                continue

            add_node = match.get_op('add')
            input_c_node = match.get_op("C")
            a_edge_name = matmul_node.input[0]
            b_edge_name = matmul_node.input[1]
            c_edge_name = input_c_node.output[0]

            a_mul_b_shape = g.get_shape(matmul_node.output[0])
            c_shape = g.get_shape(c_edge_name)
            if c_shape is None or a_mul_b_shape is None:
                continue
            if -1 in c_shape + a_mul_b_shape:
                continue
            if g.get_rank(a_edge_name) != 2 or g.get_rank(b_edge_name) != 2:
                continue
            # C must not have more dimensions than A*B; without this guard the
            # trailing-dimension comparison below would index past the start
            # of a_mul_b_shape and raise IndexError.
            if len(c_shape) > len(a_mul_b_shape):
                continue
            # Compare trailing dimensions: each C dim must be 1 or equal to
            # the matching A*B dim.
            trailing_dims = zip(reversed(c_shape), reversed(a_mul_b_shape))
            if any(c_dim not in (1, ab_dim) for c_dim, ab_dim in trailing_dims):
                continue

            gemm = g.make_node("Gemm", inputs=[a_edge_name, b_edge_name, c_edge_name],
                               attr=attr,
                               shapes=[g.get_shape(add_node.output[0])],
                               dtypes=[g.get_dtype(add_node.output[0])], op_name_scope=matmul_node.name)

            ops.append(gemm)
            g.replace_all_inputs(add_node.output[0], gemm.output[0], ops=ops)
            to_delete = [add_node, matmul_node]
            g.safe_remove_nodes(to_delete)
    return ops
def rewrite_eye(g, ops):
    """Map the sub-graph that tf.eye expands to onto the ONNX ``EyeLike`` op.

    The schema of eye is eye(num_rows, num_columns=None); when num_columns is
    not specified it equals num_rows. tf.eye is implemented by a sub-graph
    containing "MatrixDiag" or "MatrixSetDiag", neither of which is directly
    supported in onnx, but EyeLike can stand in for the whole sub-graph.
    This rewriter supports tf.eye(non_const) and
    tf.eye(non_const1, non_const2); tf.eye(const) and tf.eye(const1, const2)
    are not supported here.

    Args:
        g: graph being rewritten.
        ops: nodes that are searched for the patterns.

    Returns:
        ``g.get_nodes()``.
    """
    # ConstantOfShape in opset 9 is used, so if opset less than 9 then do nothing
    if g.opset < 9:
        return g.get_nodes()

    # Every factory below returns fresh OpTypePattern objects on each call,
    # so no pattern object is shared between two patterns.
    def size_pattern():
        return OpTypePattern("Minimum|Cast", name="min_or_cast")

    def fill_value_pattern():
        return OpTypePattern("Const", name="fill_value")

    def fill_from_pack():
        # Fill(Const, ConcatV2(*, *, Pack(size)))
        return OpTypePattern("Fill", inputs=[
            fill_value_pattern(),
            OpTypePattern("ConcatV2", inputs=[
                "*",
                "*",
                OpTypePattern("Pack", inputs=[size_pattern()]),
            ]),
        ])

    def fill_from_expand_dims():
        # Fill(ConcatV2(*, ExpandDims(size, *), *), Const)
        return OpTypePattern("Fill", inputs=[
            OpTypePattern("ConcatV2", inputs=[
                "*",
                OpTypePattern("ExpandDims", inputs=[size_pattern(), "*"]),
                "*",
            ]),
            fill_value_pattern(),
        ])

    def fill_from_reshape():
        # Fill(Reshape(size, *), Const)
        return OpTypePattern("Fill", inputs=[
            OpTypePattern("Reshape", inputs=[size_pattern(), "*"]),
            fill_value_pattern(),
        ])

    def diag_pattern(op_type, make_fill, *trailing):
        # <op_type>(fill, *trailing)
        return OpTypePattern(op_type, name="output_eye_matrix",
                             inputs=[make_fill()] + list(trailing))

    def set_diag_pattern(op_type, make_fill, *trailing):
        # <op_type>(Fill, fill, *trailing)
        return OpTypePattern(op_type, name="output_eye_matrix",
                             inputs=[OpTypePattern("Fill"), make_fill()] + list(trailing))

    patterns = [
        diag_pattern("MatrixDiag", fill_from_pack),
        set_diag_pattern("MatrixSetDiag", fill_from_pack),
        diag_pattern("MatrixDiag", fill_from_expand_dims),
        set_diag_pattern("MatrixSetDiag", fill_from_expand_dims),
        diag_pattern("MatrixDiagV3", fill_from_expand_dims, "*", "*", "*", "*"),
        set_diag_pattern("MatrixSetDiagV3", fill_from_expand_dims, "*"),
        diag_pattern("MatrixDiag", fill_from_reshape),
        set_diag_pattern("MatrixSetDiag", fill_from_reshape),
    ]

    def unsqueeze_front(node):
        # Add a leading axis to the node's first output.
        return GraphBuilder(g).make_unsqueeze(
            {
                "axes": [0],
                "data": node.output[0]
            }, return_node=True)

    for pattern in patterns:
        matcher = GraphMatcher(pattern, allow_reorder=True)
        for match_result in list(matcher.match_ops(ops)):
            # Only a diagonal filled with 1 is an identity matrix.
            if match_result.get_op("fill_value").get_tensor_value() != 1:
                continue

            # The matched node is either the Minimum itself or a Cast of it.
            min_node = match_result.get_op("min_or_cast")
            if min_node.type == "Cast":
                min_node = min_node.inputs[0]
            if min_node.type != "Minimum":
                continue

            row_source = min_node.inputs[0]
            column_source = min_node.inputs[1]

            old_output = match_result.get_op("output_eye_matrix")
            output_dtypes = [g.get_dtype(old_output.output[0])]
            output_shapes = [g.get_shape(old_output.output[0])]
            g.remove_node(old_output.name)

            # onnx op "EyeLike" need a 2D tensor, so generate it
            num_rows = unsqueeze_front(row_source)
            num_columns = unsqueeze_front(column_source)
            matrix_shape = g.make_node(
                "Concat", [num_rows.output[0], num_columns.output[0]],
                attr={"axis": 0})
            # cast nodes added for "ConstantOfShape" in ONNX only accepts int64 data.
            matrix_shape_int64 = g.make_node(
                "Cast",
                matrix_shape.output,
                attr={"to": onnx_pb.TensorProto.INT64})
            zero_matrix = g.make_node("ConstantOfShape",
                                      matrix_shape_int64.output)

            # The EyeLike node takes over the name and outputs of the removed node.
            g.make_node("EyeLike",
                        zero_matrix.output,
                        attr={"dtype": output_dtypes[0]},
                        name=old_output.name,
                        shapes=output_shapes,
                        dtypes=output_dtypes,
                        outputs=old_output.output)

    return g.get_nodes()
def rewrite_dropout(g, ops):
    """Replace dropout sub-graphs with a single ONNX ``Dropout`` node.

    Three patterns are tried, each rooted at a ``Mul`` named ``outputs``.
    A match is replaced only when the ratio node (``input3``) is a scalar,
    one input of the scaling node (``input2``) is a scalar, that scalar is
    consistent with ``1 / (1 - ratio)``, and ``g.is_safe_to_remove_nodes``
    accepts the matched nodes. Otherwise a warning is logged and the match is
    left alone.

    Args:
        g: graph being rewritten.
        ops: nodes that are searched for the patterns.

    Returns:
        The same ``ops`` object that was passed in.
    """

    def random_source():
        return OpTypePattern('RandomUniform|RandomUniformLike')

    def keep_mask():
        # Cast(GreaterEqual(random, ratio))
        return OpTypePattern(
            "Cast",
            inputs=[
                OpTypePattern(
                    "GreaterEqual",
                    inputs=[
                        random_source(),
                        OpTypePattern("*", name="input3"),
                    ])
            ])

    patterns = [
        # Mul(RealDiv, Floor(Add(ratio, random)))
        OpTypePattern(
            'Mul',
            name='outputs',
            inputs=[
                OpTypePattern('RealDiv', name="input2"),
                OpTypePattern(
                    'Floor',
                    inputs=[
                        OpTypePattern(
                            'Add',
                            inputs=[
                                OpTypePattern("*", name="input3"),
                                random_source(),
                            ])
                    ]),
            ]),
        # Mul(Mul, Cast(GreaterEqual(random, ratio)))
        OpTypePattern(
            "Mul",
            name="outputs",
            inputs=[OpTypePattern("Mul", name="input2"), keep_mask()]),
        # pattern for tf-2.0 tf.nn.dropout()
        OpTypePattern(
            "Mul",
            name="outputs",
            inputs=[keep_mask(), OpTypePattern("Mul", name="input2")]),
    ]

    for pattern in patterns:
        matcher = GraphMatcher(pattern, allow_reorder=True)
        for match in list(matcher.match_ops(ops)):
            scale_op = match.get_op('input2')
            ratio_op = match.get_op('input3')
            outputs = match.get_op('outputs')

            if not ratio_op.is_scalar():
                logger.warning(
                    "Dropout pattern rooted at %s does not have a "
                    "constant ratio and cannot be replaced.", outputs.name)
                continue
            ratio = ratio_op.get_tensor_value()

            # Whichever input of the scaling node is a scalar is the constant;
            # the other one carries the data.
            first, second = scale_op.inputs[0], scale_op.inputs[1]
            if first.is_scalar():
                constant_node, data = first, second
            elif second.is_scalar():
                constant_node, data = second, first
            else:
                logger.warning(
                    "Could not find scaling constant for dropout pattern rooted at %s. "
                    "The pattern will not be replaced with an ONNX dropout node.",
                    outputs.name)
                continue
            scaling_constant = constant_node.get_tensor_value()

            #The scaling constant should be 1/(1-ratio), otherwise this isn't truly a dropout node
            if not np.allclose([1], [scaling_constant * (1 - ratio)]):
                logger.warning(
                    "Scaling constant %f for dropout pattern rooted at %s is inconsistent with dropout "
                    "ratio %f. The pattern will not be replaced with an ONNX dropout node.",
                    scaling_constant, outputs.name, ratio)
                continue

            # Everything matched except the ratio node is removed.
            nodes_to_remove = [
                node for node in match.get_nodes()
                if node.name != ratio_op.name
            ]
            if not g.is_safe_to_remove_nodes(nodes_to_remove,
                                             [outputs.output[0]]):
                logger.warning(
                    "Nodes in dropout pattern rooted at %s cannot be removed because intermediate results "
                    "of some nodes are referenced elsewhere in graph.",
                    outputs.name)
                continue

            op_name = utils.make_name("Dropout")
            out_name = utils.port_name(op_name)
            data_edge = data.output[0]
            dropout = g.make_node("Dropout",
                                  inputs=[data_edge],
                                  outputs=[out_name],
                                  name=op_name,
                                  attr={"ratio": ratio},
                                  shapes=[g.get_shape(data_edge)],
                                  dtypes=[g.get_dtype(data_edge)])
            g.replace_all_inputs(outputs.output[0],
                                 dropout.output[0],
                                 ops=ops)
            for node in nodes_to_remove:
                g.remove_node(node.name)

    return ops
def rewrite_tfl_scan_outputs(g, ops):
    """Detect TFLite "insert one element at index i" patterns and record them as scan outputs.

    The pattern searched for is a ``TFL_CONCATENATION`` along axis 0 of three
    inputs, in this exact order: a ``TFL_SLICE`` of some tensor, an arbitrary
    ``middle`` node, and a second ``TFL_SLICE`` of the same tensor.  A match is
    accepted only if it passes every check in the loop body below; each failed
    check silently skips the match.

    For every accepted match ``(scan_output_idx, scan_output)`` is appended to
    ``g.scan_outputs``, where ``scan_output`` is the first output of ``middle``
    and ``scan_output_idx`` is the position of the sliced tensor among the
    outputs of ``g.inputs``.  When ``g.opset < 10`` a ``TMP_SCAN_OUTPUT`` node
    is also created and every consumer of the concat output is redirected to it.

    Args:
        g: graph being inspected and mutated.
        ops: nodes handed to the pattern matcher.

    Returns:
        ``ops`` as passed in; this function never rebinds it.
    """
    pattern0 = \
        OpTypePattern('TFL_CONCATENATION', name='concat', inputs=[
            OpTypePattern('TFL_SLICE', name='begin_slice'),
            OpTypePattern('*', name='middle'),
            OpTypePattern('TFL_SLICE', name='end_slice')
        ])

    # allow_reorder=False: the three concat inputs must appear in the order given above.
    matcher = GraphMatcher(pattern0, allow_reorder=False)
    match_results = list(matcher.match_ops(ops))
    if match_results:
        for match in match_results:
            concat = match.get_op("concat")
            begin_slice = match.get_op("begin_slice")
            middle = match.get_op("middle")
            end_slice = match.get_op("end_slice")
            middle_shape = g.get_shape(middle.output[0])

            # Both slices must be slicing the same tensor
            if begin_slice.input[0] != end_slice.input[0]:
                continue
            original_tensor = begin_slice.input[0]
            if concat.get_attr_int("axis") != 0:
                continue
            # The inserted slice must have length 1 (to be a single index)
            if middle_shape is None or len(
                    middle_shape) == 0 or middle_shape[0] != 1:
                continue
            rank = len(middle_shape)
            scan_output = middle.output[0]
            # NOTE(review): presumably TFL_SLICE inputs are (data, begin, size), so
            # inputs[1] is "begin" and inputs[2] is "size" -- confirm against the TFLite op.
            if not begin_slice.inputs[1].is_const(
            ) or not end_slice.inputs[2].is_const():
                continue
            # The first slice must start from the beginning (0) for all dims
            if not all(v == 0
                       for v in begin_slice.inputs[1].get_tensor_value()):
                continue
            # The second slice must slice to the end (-1) for all dims
            if not all(v == -1
                       for v in end_slice.inputs[2].get_tensor_value()):
                continue
            # The other slice dims are assembled by concatenation if rank > 1
            if rank > 1:
                begin_concat = begin_slice.inputs[2]
                end_concat = end_slice.inputs[1]
                if not begin_concat.type == "TFL_CONCATENATION":
                    continue
                if not end_concat.type == "TFL_CONCATENATION":
                    continue
                # Except for dim 0, slice from beginning to end
                # (get_uniform_const_val is defined elsewhere; presumably it returns the
                # single repeated value of a constant node -- confirm.)
                if not all(
                        get_uniform_const_val(inp) == -1
                        for inp in begin_concat.inputs[1:]):
                    continue
                if not all(
                        get_uniform_const_val(inp) == 0
                        for inp in end_concat.inputs[1:]):
                    continue
                # The first concat input carries the dim-0 value.
                begin_idx = begin_concat.inputs[0]
                end_idx = end_concat.inputs[0]
            else:
                # Rank 1: the slice operands themselves are the dim-0 values.
                begin_idx = begin_slice.inputs[2]
                end_idx = end_slice.inputs[1]
            # For dim 0, slice to i for first part and from i+1 for second
            if not node_is_one_plus_node(begin_idx, end_idx):
                continue
            # Only the first element of the returned pair is used here.
            out1, _ = get_out_and_offset(begin_idx)
            graph_inps = [n.output[0] for n in g.inputs]
            # To be a scan output, i must be a graph input
            if out1 not in graph_inps:
                continue
            # The array being sliced must be a graph input
            if original_tensor not in graph_inps:
                continue
            # The input/output index of i
            idx = graph_inps.index(out1)
            # The input/output index of the array
            scan_output_idx = graph_inps.index(original_tensor)
            # For a scan output, i must be assigned to i+1 with each iteration
            if not node_is_one_plus_node(g.get_node_by_output(out1),
                                         g.get_node_by_output(g.outputs[idx])):
                continue
            # The concat result may have at most one consumer.
            if len(g.find_output_consumers(concat.output[0])) > 1:
                continue

            # NOTE(review): the second condition below is always true at this point,
            # because the "> 1 consumers" case was skipped just above.
            if g.opset < 10 and len(g.find_output_consumers(
                    concat.output[0])) <= 1:
                # If opset is < 10, conversion of the subgraph will fail unless we remove the slice nodes
                # We add a tmp node to replace them.
                shape = g.get_shape(concat.output[0])
                dtype = g.get_dtype(concat.output[0])
                tmp_node = g.make_node("TMP_SCAN_OUTPUT",
                                       [original_tensor, scan_output],
                                       shapes=[shape],
                                       dtypes=[dtype])
                g.replace_all_inputs(concat.output[0], tmp_node.output[0])

            # Collect the nodes that are checked for safe removal below.
            to_remove = []
            # Start at the producer of the graph output that shares the array's index ...
            out = g.outputs[scan_output_idx]
            node = g.get_node_by_output(out)
            to_remove.append(node)

            # ... and follow first inputs upstream until concat or a node without inputs.
            while len(node.input) > 0 and node != concat:
                out = node.input[0]
                node = g.get_node_by_output(out)
                to_remove.append(node)

            to_remove += [begin_slice, end_slice, concat]

            # Same upstream walk, this time starting from the producer of the sliced tensor.
            out = original_tensor
            node = g.get_node_by_output(out)
            to_remove.append(node)

            while len(node.input) > 0:
                out = node.input[0]
                node = g.get_node_by_output(out)
                to_remove.append(node)

            # NOTE(review): for opset < 10 the graph was already rewired to
            # TMP_SCAN_OUTPUT above, even when this check fails -- confirm that is intended.
            # The nodes in to_remove are only checked here, not removed by this function.
            if not g.is_safe_to_remove_nodes(to_remove):
                continue

            g.scan_outputs.append((scan_output_idx, scan_output))
    return ops
Example #18
0
def _random_normal_scalar_const(node):
    """Return the value of a constant scalar node as a float, else None.

    None is returned when the node is not a constant or when its value is not
    a single Python number, because only a plain float can be stored in the
    ``mean``/``scale`` attributes of the ONNX random ops.
    """
    if not node.is_const():
        return None
    value = node.get_tensor_value()
    # bool is an int subclass; a boolean constant is not a usable mean/stddev.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    return float(value)


def rewrite_random_normal(g, ops):
    """Replace RandomStandardNormal sub-graphs with ONNX RandomNormal(Like).

    Two patterns are handled:

    1. ``Add(Mul(RandomStandardNormal(shape), stddev), mean)`` -- rewritten
       only when both ``stddev`` and ``mean`` are scalar constants; their
       values become the ``scale`` and ``mean`` attributes.
    2. ``Identity(Identity(RandomStandardNormal(shape)))`` -- rewritten with
       ``mean=0.0`` and ``scale=1.0``.

    If the shape operand is produced by a ``Shape`` node, ``RandomNormalLike``
    is emitted with that node's input; otherwise ``RandomNormal`` is emitted
    with the statically known output shape.  Matches that cannot be expressed
    faithfully (non-constant mean/stddev, unknown static shape) are left
    untouched.

    Args:
        g: graph being rewritten (mutated in place).
        ops: nodes to match against; also passed to ``g.replace_all_inputs``.

    Returns:
        ``ops`` as passed in.
    """
    pattern1 = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "*"
            ]), "*"
        ])

    pattern2 = \
        OpTypePattern('Identity', name='output', inputs=[
            OpTypePattern('Identity', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"])
            ])
        ])

    for pattern in (pattern1, pattern2):
        matcher = GraphMatcher(pattern)
        for match in list(matcher.match_ops(ops)):
            output = match.get_op('output')
            if output.type == 'Add':
                # pattern 1: the second operands of Add and Mul are wildcards,
                # so they must be verified to be scalar constants before their
                # values are folded into attributes.
                mean = _random_normal_scalar_const(output.inputs[1])
                scale = _random_normal_scalar_const(match.get_op('input2').inputs[1])
                if mean is None or scale is None:
                    continue
            else:
                # pattern 2: plain standard normal
                mean = 0.0
                scale = 1.0

            rn_op = match.get_op('input1')
            attr = {
                "mean": mean,
                "scale": scale,
                "dtype": g.get_dtype(output.output[0]),
                # ONNX declares ``seed`` as a float attribute for both ops.
                "seed": float(rn_op.get_attr('seed2').i),
            }

            if rn_op.inputs[0].type == "Shape":
                # Shape is taken at runtime from the tensor feeding the Shape node.
                op_type = "RandomNormalLike"
                inputs = [rn_op.inputs[0].input[0]]
            else:
                shape = g.get_shape(output.output[0])
                # RandomNormal needs a fully known static shape attribute.
                if shape is None or any(d is None or d < 0 for d in shape):
                    continue
                op_type = "RandomNormal"
                inputs = []
                attr["shape"] = shape

            op_name = utils.make_name("RandomNormal")
            out_name = utils.port_name(op_name)
            new_node = g.make_node(op_type, inputs,
                                   outputs=[out_name],
                                   name=op_name,
                                   attr=attr)

            g.replace_all_inputs(output.output[0], new_node.output[0], ops=ops)
            g.safe_remove_nodes(match.get_nodes())
    return ops
def rewrite_flatten(g, ops):
    """Replace ``Reshape(x, Pack(StridedSlice(shape, 0, 1, 1), -1))`` with Flatten.

    Two variants of the pattern are matched: one where the StridedSlice reads
    a constant that must equal the static shape of ``x`` ("fixed shape"), and
    one where it reads the output of a ``Shape`` node ("non-fixed shape").
    In both cases the reshape keeps the first dimension and collapses all the
    others, which is what ONNX ``Flatten`` does with its default ``axis=1``.

    A match is rewritten only if the second Pack input is the constant -1,
    the static shape of ``x`` is known and has at least one dimension, the
    slice parameters are the constants ``[0]``, ``[1]``, ``[1]``, and at most
    one of the matched nodes (expected to be the Reshape itself) is still
    referenced from outside the pattern.

    Args:
        g: graph being rewritten (mutated in place).
        ops: nodes to match against; also passed to ``g.replace_all_inputs``.

    Returns:
        ``ops`` as passed in.
    """
    pattern_fixed_shape_input = \
        OpTypePattern('Reshape', name='reshape', inputs=[
            OpTypePattern("*", name="input"),
            OpTypePattern('Pack', name="pack", inputs=[
                OpTypePattern('StridedSlice', name="slice", inputs=[
                    "*", "*", "*", "*",
                ]),
                "*",
            ]),
        ])
    pattern_non_fixed_shape_input = \
        OpTypePattern('Reshape', name='reshape', inputs=[
            OpTypePattern("*", name="input"),
            OpTypePattern('Pack', name="pack", inputs=[
                OpTypePattern('StridedSlice', name="slice", inputs=[
                    OpTypePattern('Shape', inputs=[
                        OpTypePattern("*", name="input2")
                    ]),
                    "*", "*", "*",
                ]),
                "*",
            ]),
        ])

    # Both match lists are collected before any rewriting takes place.
    fixed_matches = list(GraphMatcher(pattern_fixed_shape_input).match_ops(ops))
    non_fixed_matches = list(GraphMatcher(pattern_non_fixed_shape_input).match_ops(ops))

    candidates = [(fixed_matches, True), (non_fixed_matches, False)]
    for matches, check_fixed_input_shape in candidates:
        for match in matches:
            input_node = match.get_op('input')
            reshape_node = match.get_op('reshape')
            pack_node = match.get_op('pack')
            slice_node = match.get_op('slice')

            # The second packed value must be the constant -1 ("infer the rest").
            collapse_dim = pack_node.inputs[1]
            if not (collapse_dim.is_const() and collapse_dim.get_tensor_value() == -1):
                continue

            # Need a known static shape with at least one dimension.
            input_shape = g.get_shape(reshape_node.input[0])
            if not input_shape:
                continue

            if check_fixed_input_shape:
                # The sliced constant must be exactly the static shape of the input.
                shape_source = slice_node.inputs[0]
                if not (shape_source.is_const() and
                        np.array_equal(list(input_shape),
                                       shape_source.get_tensor_value(as_list=False))):
                    continue
            # NOTE(review): in the non-fixed variant the Shape node's input
            # ("input2") is not compared with the Reshape input -- confirm that
            # both are guaranteed to share their first dimension.

            # begin/end/strides are wildcards in the pattern; they have to be
            # constants before their values can be read.
            slice_params = [slice_node.inputs[i] for i in (1, 2, 3)]
            if not all(param.is_const() for param in slice_params):
                continue
            begin, end, strides = (param.get_tensor_value(as_list=False)
                                   for param in slice_params)
            # The slice must pick exactly element 0 (the first dimension).
            if not (np.array_equal(begin, [0]) and np.array_equal(end, [1]) and
                    np.array_equal(strides, [1])):
                continue

            to_remove = [n for n in match.get_nodes() if n != input_node]
            safe = g.safe_to_remove_nodes(to_remove)

            # Ok if reshape_node is not safe. Will make it safe later.
            if len(to_remove) - len(safe) > 1:
                continue

            op_name = utils.make_name("Flatten")
            out_name = utils.port_name(op_name)
            g.make_node("Flatten", [reshape_node.input[0]], outputs=[out_name], name=op_name)

            # Flatten(axis=1) yields [d0, d1 * d2 * ... * dn]; the collapsed
            # size is only known when every trailing dimension is known.
            trailing_dims = input_shape[1:]
            if all(dim is not None and dim > 0 for dim in trailing_dims):
                flat_dim = 1
                for dim in trailing_dims:
                    flat_dim *= dim
            else:
                flat_dim = -1

            g.set_shape(out_name, [input_shape[0], flat_dim])
            g.replace_all_inputs(reshape_node.output[0], out_name, ops=ops)
            for n in to_remove:
                g.remove_node(n.name)

    return ops