def connect_unit_rnn_output_to_graph(self, context):
        outputs = context.loop_properties.scan_outputs_exits
        if not outputs:
            logger.debug("no one consume output")
            return

        gb = GraphBuilder(self.g)
        gather_output_id = outputs[0].id
        logger.debug("found output for rnn: %s", gather_output_id)

        # in tf batch major mode, output shape is: [batch, time, hidden]
        # in time major mode, output shape is: [time, batch, hidden]
        # in onnx, output shape is: [time, num_directions, batch, hidden]

        rnn_node = context.rnn_node[-1]
        output_id = rnn_node.output[0]
        rnn_output_shape = self.g.get_shape(output_id)
        squeeze_output_shape = [
            rnn_output_shape[0], rnn_output_shape[2], rnn_output_shape[3]
        ]
        squeeze_node = gb.make_squeeze(
            {
                'data': output_id,
                "axes": [1]
            },
            shapes=[squeeze_output_shape],
            dtypes=[self.g.get_dtype(output_id)],
            return_node=True)
        self.g.replace_all_inputs(
            gather_output_id, squeeze_node.output[0])  # ops=self.g.get_nodes()
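
As a rough illustration (plain NumPy, not part of the converter), the squeeze above maps ONNX's Y output layout back to TF's time-major layout when there is a single direction:

import numpy as np

time, num_directions, batch, hidden = 5, 1, 2, 8
onnx_y = np.zeros((time, num_directions, batch, hidden), dtype=np.float32)
# squeezing axis 1 (num_directions == 1) yields TF's time-major [time, batch, hidden]
tf_time_major = np.squeeze(onnx_y, axis=1)
assert tf_time_major.shape == (time, batch, hidden)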
Example #2
    def _connect_lstm_ych_to_graph(self, context, i):
        # in tf, the concat of y_c and y_h has shape: [batch, hidden * 2]
        # in onnx, y_c/y_h output shape is: [num_directions, batch, hidden]
        gb = GraphBuilder(self.g)
        exit_output = context.state_variables["ct_ht" + str(i)].exit_output
        lstm_node = context.rnn_node[i]
        yc_shape = self.g.get_shape(lstm_node.output[2])
        concat_output_shape = [yc_shape[0], yc_shape[1], yc_shape[2] * 2]
        concat = self.g.make_node(
            "Concat", [lstm_node.output[2], lstm_node.output[1]],
            attr={"axis": 2},
            shapes=[concat_output_shape],
            dtypes=[self.g.get_dtype(lstm_node.output[2])])

        squeeze_output_shape = [concat_output_shape[1], concat_output_shape[2]]
        squeeze_node = gb.make_squeeze(
            {
                'data': concat.output[0],
                "axes": [0]
            },
            shapes=[squeeze_output_shape],
            dtypes=[self.g.get_dtype(concat.output[0])],
            return_node=True)

        self.g.replace_all_inputs(
            exit_output.id, squeeze_node.output[0])  # ops=self.g.get_nodes()
Example #3
    def create_rnn_node(self, context):
        gb = GraphBuilder(self.g)
        rnn_nodes = []
        outputs = context.loop_properties.scan_outputs_exits
        logger.debug("number of rnn node outputs: %s", len(outputs))

        for i in range(self.num_lstm_layers):
            logger.debug("creating rnn node for layer: %s", i)
            rnn_nodes.append(self.create_single_rnn_node(context, i))
            output_id = rnn_nodes[i].output[0]
            rnn_output_shape = self.g.get_shape(output_id)
            squeeze_output_shape = [
                rnn_output_shape[0], rnn_output_shape[2], rnn_output_shape[3]
            ]
            squeeze_node = gb.make_squeeze(
                {
                    "data": output_id,
                    "axes": [1]
                },
                shapes=[squeeze_output_shape],
                dtypes=[self.g.get_dtype(output_id)],
                return_node=True)
            if i + 1 < self.num_lstm_layers:
                logger.debug("setting input for layer: %s", i + 1)
                context.onnx_input_ids[i + 1]["X"] = squeeze_node.output[0]
        return rnn_nodes
Example #4
 def any_version(cls, opset, ctx, node, **kwargs):
     node.domain = constants.CONTRIB_OPS_DOMAIN
     separator = node.get_attr_value("separator")
     if separator is None:
         separator = b''
     separator = separator.decode('UTF-8')
     separator_node = ctx.make_const(utils.make_name("separator"),
                                     np.array([separator], object))
     axis_node = ctx.make_const(utils.make_name("axis"),
                                np.array([0], np.int64))
     inps_with_shapes = [i for i in node.input if ctx.get_shape(i) != []]
     shape_node = None
     if 0 < len(inps_with_shapes) < len(node.input):
         shape_node = ctx.make_node("Shape", [inps_with_shapes[0]])
     unsqueezes = []
     for inp in node.input:
         if ctx.get_shape(inp) == [] and shape_node is not None:
             expand_node = ctx.make_node("Expand",
                                         [inp, shape_node.output[0]])
             inp = expand_node.output[0]
         unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({
             'data': inp,
             'axes': [0]
         })
         unsqueezes.append(unsqueeze_node)
     stack_node = ctx.make_node("Concat", unsqueezes, attr={'axis': 0})
     ctx.replace_inputs(node, [
         stack_node.output[0], separator_node.output[0], axis_node.output[0]
     ])
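
A minimal NumPy sketch (illustrative only) of the Expand + Unsqueeze + Concat pattern above, which lifts scalar inputs to the shape of a non-scalar input before stacking everything along a new axis 0:

import numpy as np

inputs = [np.array(["a", "b"]), np.array("c")]  # one vector, one scalar
target_shape = next(i.shape for i in inputs if i.shape != ())
expanded = [np.broadcast_to(i, target_shape) if i.shape == () else i
            for i in inputs]                    # the Expand nodes
stacked = np.concatenate([e[np.newaxis] for e in expanded], axis=0)  # Unsqueeze + Concat
assert stacked.shape == (2, 2)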
Example #5
    def version_6(cls, ctx, node, **kwargs):
        # T output = All(T x, list(int) reduce_indices, @bool keepdims)
        # T output = Any(T x, list(int) reduce_indices, @bool keepdims)
        reduce_dim = node.inputs[1].get_tensor_value()

        # for Any, the reduce_indices can be scalar as observed.
        if np.isscalar(reduce_dim):
            reduce_dim = [reduce_dim]

        if ctx.opset < 11:
            utils.make_sure(all(i >= 0 for i in reduce_dim), "negative reduce axis is not supported in onnx for now")

        cast = ctx.make_node(op_type="Cast", inputs=[node.input[0]], attr={"to": onnx_pb.TensorProto.FLOAT})
        keepdims = helper.get_attribute_value(node.get_attr("keep_dims"))
        op_type = "ReduceMin" if node.type == "All" else "ReduceSum"

        if op_type == "ReduceSum":
            reduce_node_output = GraphBuilder(ctx).make_reduce_sum(
                {"data": cast.output[0], "axes": reduce_dim, "keepdims": keepdims, "noop_with_empty_axes": 1})
        else:
            reduce_node_output = ctx.make_node(op_type=op_type, inputs=cast.output,
                                               attr={"axes": reduce_dim, "keepdims": keepdims}).output[0]

        zero_node = ctx.make_const(utils.make_name("zero_reduce"), np.array(0, dtype=np.float32))

        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        ctx.make_node(op_type="Greater", inputs=[reduce_node_output, zero_node.output[0]],
                      name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
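
The reduction trick above can be sanity-checked with NumPy (a sketch, not converter code): after casting bools to float, All is equivalent to "min > 0" and Any to "sum > 0":

import numpy as np

x = np.array([[True, False], [True, True]])
f = x.astype(np.float32)                 # the Cast node
any_ = f.sum(axis=1) > 0                 # ReduceSum + Greater  -> Any
all_ = f.min(axis=1) > 0                 # ReduceMin + Greater  -> All
assert np.array_equal(any_, x.any(axis=1))
assert np.array_equal(all_, x.all(axis=1))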
Example #6
 def version_1(cls, ctx, node, **kwargs):
     # in tf-2.0, grappler optimizes the graph pretty well and our matching logic
     # in the rewriter does not trigger. grappler will send the random uniform
     # with the shape as an input, so we need to pick up the input here and, if
     # the shape is const, make it an attribute.
     seed = node.get_attr("seed")
     node.set_attr("seed", float(seed.f))
     utils.make_sure(node.inputs[0].is_const(),
                     "%s node with non-const shape requires opset >= 9",
                     node.name)
     shape = node.inputs[0].get_tensor_value()
     ctx.remove_input(node, node.input[0], 0)
     if len(shape) == 0:
         # ORT can't take an empty shape (scalar)
         node.set_attr("shape", [1])
         ctx.set_shape(node.output[0], [1])
         squeeze_node = GraphBuilder(ctx).make_squeeze(
             {
                 'data': node.output[0],
                 'axes': [0]
             }, return_node=True)
         ctx.insert_node_on_output(squeeze_node, node.output[0])
         rand_out = squeeze_node.output[0]
     else:
         node.set_attr("shape", shape)
         ctx.set_shape(node.output[0], shape)
         rand_out = node.output[0]
     if node.type == "RandomUniformInt":
         cls.randuniform_int(ctx, node, rand_out, node.input[0],
                             node.input[1])
         node.type = "RandomUniform"
         ctx.replace_inputs(node, [])
Example #7
    def _connect_lstm_yc_to_graph(self, context, i):
        # in tf, y_c output shape is: [batch, hidden]
        # in onnx, output shape is: [num_directions, batch, hidden]
        gb = GraphBuilder(self.g)
        exit_output = context.state_variables["ct" + str(i)].exit_output
        output_id = context.rnn_node[i].output[2]
        lstm_yc_shape = self.g.get_shape(output_id)
        squeeze_node = gb.make_squeeze(
            {
                "data": output_id,
                "axes": [0]
            },
            shapes=[[lstm_yc_shape[1], lstm_yc_shape[2]]],
            dtypes=[self.g.get_dtype(output_id)],
            return_node=True)

        self.g.replace_all_inputs(
            exit_output.id, squeeze_node.output[0])  # ops=self.g.get_nodes()
Example #8
    def _process_c_or_h_init_nodes(self, initializer_input_id, context):
        node = self.g.get_node_by_output(initializer_input_id)
        if node.is_const():
            val = node.get_tensor_value(as_list=False)
            initial_name = utils.make_name("Const")
            new_val = np.expand_dims(val, axis=0)
            const_node = self.g.make_const(initial_name, new_val)
            return const_node.output[0]

        gb = GraphBuilder(self.g)
        unsqueeze_node = gb.make_unsqueeze(
            {
                'data': initializer_input_id,
                "axes": [0]
            }, return_node=True)
        to_replace = [n for n in self.g.get_nodes() if n != unsqueeze_node]
        self.g.replace_all_inputs(initializer_input_id,
                                  unsqueeze_node.output[0],
                                  ops=to_replace)
        return unsqueeze_node.output[0]
Example #9
    def any_version(cls, opset, ctx, node, **kwargs):
        """
        Computes the modules of a complex.
        If the matrix dtype is not complex64 or complex128,
        it assumes the first dimension means real part (0)
        and imaginary part (1, :, :...).
        """
        supported_dtypes = [
            onnx_pb.TensorProto.FLOAT,
            onnx_pb.TensorProto.FLOAT16,
            onnx_pb.TensorProto.DOUBLE,
            onnx_pb.TensorProto.COMPLEX64,
            onnx_pb.TensorProto.COMPLEX128,
        ]
        onnx_dtype = ctx.get_dtype(node.input[0])
        utils.make_sure(onnx_dtype in supported_dtypes, "Unsupported input type.")
        shape = ctx.get_shape(node.input[0])
        np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
        utils.make_sure(shape[0] == 2, "ComplexAbs expected the first dimension to be 2 but shape is %r", shape)

        ind0 = ctx.make_const(name=utils.make_name('cst0'), np_val=np.array([0], dtype=np.int64))
        ind1 = ctx.make_const(name=utils.make_name('cst1'), np_val=np.array([1], dtype=np.int64))
        p2 = ctx.make_const(name=utils.make_name('p2'), np_val=np.array([2], dtype=np_dtype))

        real_part = ctx.make_node(
            'Gather', inputs=[node.input[0], ind0.name], attr=dict(axis=0),
            name=utils.make_name('Real_' + node.name))
        imag_part = ctx.make_node(
            'Gather', inputs=[node.input[0], ind1.name], attr=dict(axis=0),
            name=utils.make_name('Imag_' + node.name))

        real_part2 = ctx.make_node(
            'Pow', inputs=[real_part.output[0], p2.name],
            name=utils.make_name(real_part.name + 'p2p'))

        imag_part2 = ctx.make_node(
            'Pow', inputs=[imag_part.output[0], p2.name],
            name=utils.make_name(imag_part.name + 'p2p'))

        ctx.remove_node(node.name)
        add = ctx.make_node(
            "Add", inputs=[real_part2.output[0], imag_part2.output[0]],
            name=utils.make_name('ComplexAbs_' + node.name))

        squeezed = GraphBuilder(ctx).make_squeeze(
            {'data': add.output[0], 'axes': [0]}, name=utils.make_name('ComplexAbs' + node.name), return_node=True)

        last_node = ctx.make_node(
            "Sqrt", inputs=squeezed.output[:1],
            name=utils.make_name('ComplexAbs' + node.name),
            shapes=[shape[1:]], dtypes=[onnx_dtype])

        ctx.replace_all_inputs(node.output[0], last_node.output[0])  # ops=ctx.get_nodes()
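
For intuition, here is the same decomposition in plain NumPy (a sketch; the handler operates on a real tensor whose leading dimension of size 2 carries the real and imaginary parts):

import numpy as np

z = np.array([3.0 + 4.0j, 1.0 - 1.0j], dtype=np.complex64)
stacked = np.stack([z.real, z.imag])                  # shape [2, n], as the handler expects
real2 = stacked[0:1] ** 2                             # Gather(axis=0) + Pow
imag2 = stacked[1:2] ** 2
modulus = np.sqrt(np.squeeze(real2 + imag2, axis=0))  # Add + Squeeze + Sqrt
assert np.allclose(modulus, np.abs(z))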
Example #10
    def _process_non_tuple_ch_init_nodes(self, context, i):
        gb = GraphBuilder(self.g)
        input_id = context.state_variables["ct_ht" + str(i)].enter_input_id
        hidden_size = context.hidden_size[i]

        attr = {"axes": [1], "starts": [0], "ends": [hidden_size]}
        inputs_map = {"data": input_id, **attr}
        slice_node1 = GraphBuilder(self.g).make_slice(inputs_map)
        unsqueeze_node_1 = gb.make_unsqueeze({
            'data': slice_node1,
            "axes": [0]
        },
                                             return_node=True)

        attr = {
            "axes": [1],
            "starts": [hidden_size],
            "ends": [hidden_size * 2]
        }
        inputs_map = {"data": input_id, **attr}
        slice_node2 = GraphBuilder(self.g).make_slice(inputs_map)
        unsqueeze_node_2 = gb.make_unsqueeze({
            'data': slice_node2,
            "axes": [0]
        },
                                             return_node=True)

        return unsqueeze_node_1.output[0], unsqueeze_node_2.output[0]
Example #11
def slice_birnn_for_original_rnn_consumers(g, rnn_fw, rnn_bw, bi_rnn,
                                           rnn_output_index, all_nodes,
                                           to_remove):
    fw_consumers = g.find_output_consumers(rnn_fw.output[rnn_output_index])
    bw_consumers = g.find_output_consumers(rnn_bw.output[rnn_output_index])
    if not fw_consumers and not bw_consumers:
        return

    if rnn_output_index == 0:
        axis = 1
        # remove reverse op for rnn_bw
        reverse_nodes = get_reverse_nodes_after_y_output(g, rnn_bw)

        for r_op in reverse_nodes:
            logger.debug("remove reverse op %s", r_op.name)
            g.replace_all_inputs(r_op.output[0], r_op.input[0], ops=all_nodes)
            to_remove.append(r_op.name)
    elif rnn_output_index in [1, 2]:
        axis = 0
    else:
        raise ValueError("rnn only should has 3 outputs.")

    if fw_consumers:
        attr = {"axes": [axis], "starts": [0], "ends": [1]}
        inputs_map = {"data": bi_rnn.output[rnn_output_index], **attr}
        slice_node_fw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_fw))
        g.replace_all_inputs(rnn_fw.output[rnn_output_index],
                             slice_node_fw,
                             ops=fw_consumers)

    if bw_consumers:
        attr = {"axes": [axis], "starts": [1], "ends": [2]}
        inputs_map = {"data": bi_rnn.output[rnn_output_index], **attr}
        slice_node_bw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_bw))
        g.replace_all_inputs(rnn_bw.output[rnn_output_index],
                             slice_node_bw,
                             ops=bw_consumers)
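
A small NumPy sketch (illustrative only) of the slicing above for the Y output (rnn_output_index == 0), where the bidirectional result is split along the num_directions axis:

import numpy as np

time, batch, hidden = 4, 2, 3
bi_y = np.zeros((time, 2, batch, hidden), np.float32)  # num_directions == 2
fw_y = bi_y[:, 0:1]  # Slice starts=[0], ends=[1] on axis 1, for fw consumers
bw_y = bi_y[:, 1:2]  # Slice starts=[1], ends=[2] on axis 1, for bw consumers
assert fw_y.shape == bw_y.shape == (time, 1, batch, hidden)

For the state outputs (rnn_output_index 1 or 2) the same slice is taken on axis 0 instead.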
Example #12
    def any_version(cls, opset, ctx, node, **kwargs):
        if node.type == "StringSplit":
            skip_empty = node.get_attr_value('skip_empty', True)
        else:
            skip_empty = False
        node.type = "StringSplit"
        node.domain = constants.CONTRIB_OPS_DOMAIN
        for a in list(node.attr.keys()):
            del node.attr[a]
        unsqueeze_node = GraphBuilder(ctx).make_unsqueeze(
            {
                'data': node.input[1],
                'axes': [0]
            }, return_node=True)

        skip_empty_const = ctx.make_const(utils.make_name('skip_empty_const'),
                                          np.array([skip_empty], bool))
        ctx.replace_inputs(node, [
            node.input[0], unsqueeze_node.output[0], skip_empty_const.output[0]
        ])
Example #13
    def version_1(cls, ctx, node, **kwargs):
        node.domain = constants.CONTRIB_OPS_DOMAIN
        input_node = node.inputs[0]
        utils.make_sure(input_node.type == "SentencepieceOp",
                        "Input 0 to node %s is not SentencepieceOp", node.name)
        ctx.remove_input(node, node.input[0], 0)

        nbest_size_cast = ctx.make_node("Cast", [node.input[1]],
                                        attr={
                                            'to': TensorProto.INT64
                                        }).output[0]
        ctx.replace_input(node, node.input[1], nbest_size_cast, 1)
        for i in range(1, len(node.input)):
            unsqueeze = GraphBuilder(ctx).make_unsqueeze({
                'data': node.input[i],
                'axes': [0]
            })
            ctx.replace_input(node, node.input[i], unsqueeze, i)
        node.set_attr("model", input_node.attr['model'].s)
        node.type = "SentencepieceTokenizer"
        if ctx.is_safe_to_remove_nodes([input_node]):
            ctx.remove_node(input_node.name)
Example #14
    def any_version(cls, opset, ctx, node, **kwargs):
        node_inputs = node.input
        num_segments_specified = False
        if node.type.endswith("WithNumSegments") or node.type.startswith("Unsorted"):
            num_segments_specified = True
            num_segments = node_inputs.pop()
            node.type = node.type.replace("WithNumSegments", "")
            node.type = node.type.replace("Unsorted", "")
        if node.type.startswith("Sparse"):
            data_inp, indices_inp, segment_inp = node_inputs
            gather_node = ctx.make_node("Gather", [data_inp, indices_inp], attr={'axis': 0})
            data_inp = gather_node.output[0]
            node.type = node.type.replace("Sparse", "")
        else:
            data_inp, segment_inp = node_inputs

        # Data has shape [n, a, b, ..., c]
        data_shape = ctx.get_shape(data_inp)
        data_rank = len(data_shape) if data_shape is not None else None
        data_dtype = ctx.get_dtype(data_inp)
        data_np_dtype = utils.map_onnx_to_numpy_type(data_dtype)
        seg_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(segment_inp))

        if num_segments_specified and ctx.get_dtype(segment_inp) != ctx.get_dtype(num_segments):
            num_segments = ctx.make_node("Cast", [num_segments], attr={"to": ctx.get_dtype(segment_inp)}).output[0]

        data_is_float = np.dtype(data_np_dtype).kind == 'f'
        data_is_int = np.dtype(data_np_dtype).kind == 'i'
        utils.make_sure(data_is_float or data_is_int, "dtype for Segment ops must be float or int")

        if node.type in ["SegmentSum", "SegmentMean", "SegmentSqrtN"]:
            onnx_op = "ReduceSum"
            identity_value = np.array(0, dtype=data_np_dtype)
        elif node.type == "SegmentProd":
            onnx_op = "ReduceProd"
            identity_value = np.array(1, dtype=data_np_dtype)
        elif node.type == "SegmentMax":
            onnx_op = "ReduceMax"
            if data_is_float:
                identity_value = np.array('-inf', dtype=data_np_dtype)
            else:
                identity_value = np.iinfo(data_np_dtype).min
        elif node.type == "SegmentMin":
            onnx_op = "ReduceMin"
            if data_is_float:
                identity_value = np.array('inf', dtype=data_np_dtype)
            else:
                identity_value = np.iinfo(data_np_dtype).max

        if not num_segments_specified:
            max_segment = ctx.make_node("ReduceMax", [segment_inp], attr={'axes': [0], 'keepdims': 0})
            one_const = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=seg_np_dtype))
            num_segments = ctx.make_node("Add", [max_segment.output[0], one_const.output[0]]).output[0]
        # ORT doesn't support bool for OneHot so we use float32 and cast to bool
        onehot_values = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1], dtype=np.float32))
        # one_hot_node has shape [s, n] (s is # segments)
        one_hot_node = ctx.make_node("OneHot", [segment_inp, num_segments, onehot_values.output[0]],
                                     attr={'axis': 0})
        if node.type == "SegmentMean":
            scaling_node_output = GraphBuilder(ctx).make_reduce_sum(
                {"data": one_hot_node.output[0], "axes": [1], "keepdims": 0, "noop_with_empty_axes": 1})
        elif node.type == "SegmentSqrtN":
            seg_cnts_node_output = GraphBuilder(ctx).make_reduce_sum(
                {"data": one_hot_node.output[0], "axes": [1], "keepdims": 0, "noop_with_empty_axes": 1})
            scaling_node_output = ctx.make_node("Sqrt", [seg_cnts_node_output]).output[0]
        else:
            scaling_node_output = None

        if scaling_node_output is not None and num_segments_specified:
            # If empty segments are possible, we must avoid division by zero
            const_one_float = ctx.make_const(utils.make_name("const_one_float"), np.array(1, dtype=np.float32))
            scaling_node_output = ctx.make_node("Max", [scaling_node_output, const_one_float.output[0]]).output[0]


        if onnx_op == "ReduceSum":
            # If the op is a summation, we can use MatMul instead of Where, which is faster

            # Data shape is [n, a, b, ..., c]
            data_shape_node = ctx.make_node("Shape", [data_inp])
            new_shape = ctx.make_const(utils.make_name("reshape_const"), np.array([0, -1], dtype=np.int64))
            # Reshape the data from [n, a, b, ..., c] to [n, P]
            data_reshape = ctx.make_node("Reshape", [data_inp, new_shape.output[0]])

            one_hot_cast = one_hot_node
            if data_dtype != onnx_pb.TensorProto.FLOAT:
                one_hot_cast = ctx.make_node("Cast", [one_hot_node.output[0]], attr={'to': data_dtype})

            # Shapes [s, n] * [n, P] => [s, P]
            product = ctx.make_node("MatMul", [one_hot_cast.output[0], data_reshape.output[0]], op_name_scope=node.name)
            if scaling_node_output is not None:
                scaling_node_unsqueeze = GraphBuilder(ctx).make_unsqueeze(
                    {'data': scaling_node_output, 'axes': [1]}, return_node=True)
                product = ctx.make_node("Div", [product.output[0], scaling_node_unsqueeze.output[0]])

            # Create new shape [0, a, b, ..., c]
            max_int64 = int(utils.get_max_value(np.int64))
            new_shape_slice = GraphBuilder(ctx).make_slice(
                {"data": data_shape_node.output[0], "ends": [max_int64], "starts": [1], "axes": [0]})
            zero_const = ctx.make_const(utils.make_name("zero_const"), np.array([0], dtype=np.int64))
            new_shape = ctx.make_node("Concat", [zero_const.output[0], new_shape_slice], attr={'axis': 0})

            shapes = node.output_shapes
            dtypes = node.output_dtypes
            ctx.remove_node(node.name)
            # Reshape result from [s, P] to [s, a, b, ..., c]
            ctx.make_node("Reshape", [product.output[0], new_shape.output[0]],
                          name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
            return

        identity_const = ctx.make_const(utils.make_name("const_identity"), identity_value)
        one_hot_bool = ctx.make_node("Cast", [one_hot_node.output[0]], attr={"to": onnx_pb.TensorProto.BOOL})
        one_hot_unsqueeze = one_hot_bool

        # Make one_hot_unsqueeze have shape [s, n, 1, 1, ..., 1]
        if data_rank is None:
            # Unsqueeze requires known rank, but we can use Reshape if rank is unknown
            shape_node = ctx.make_node("Shape", [data_inp])
            rank_node = ctx.make_node("Shape", [shape_node.output[0]])
            one_const_int64 = ctx.make_const(utils.make_name("const_one"), np.array([1], dtype=np.int64))
            num_unsqueeze_dims = ctx.make_node("Sub", [rank_node.output[0], one_const_int64.output[0]])

            one_tensor = helper.make_tensor("value", onnx_pb.TensorProto.INT64, dims=[1], vals=[1])
            unsqueeze_dims = ctx.make_node("ConstantOfShape", inputs=[num_unsqueeze_dims.output[0]],
                                           attr={"value": one_tensor})
            # Zero indicates a dimension should be unchanged
            double_zero_const = ctx.make_const(utils.make_name("double_zero"), np.array([0, 0], dtype=np.int64))
            expanded_shape = ctx.make_node("Concat", [double_zero_const.output[0], unsqueeze_dims.output[0]],
                                           attr={'axis': 0})
            one_hot_unsqueeze = ctx.make_node("Reshape", [one_hot_bool.output[0], expanded_shape.output[0]])
        elif data_rank > 1:
            new_dims = list(range(2, 2 + data_rank - 1))
            one_hot_unsqueeze = GraphBuilder(ctx).make_unsqueeze(
                {'data': one_hot_bool.output[0], 'axes': new_dims}, return_node=True)

        # Shape of data:       [n, a, b, ..., c]
        # Shape of one_hot: [s, n, 1, 1, ..., 1]
        # Broadcast left-pads shape with 1s, so result is shape: [s, n, a, b, ..., c]
        where_node = ctx.make_node("Where", [one_hot_unsqueeze.output[0], data_inp, identity_const.output[0]])

        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        # After reduction over axis 1, shape is: [s, a, b, ..., c]
        ctx.make_node(onnx_op, [where_node.output[0]], attr={'axes': [1], 'keepdims': 0},
                      name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
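
The one-hot MatMul fast path above reduces to a few lines of NumPy (a sketch, not converter code): build a [s, n] one-hot matrix from the segment ids and multiply it against the flattened data:

import numpy as np

data = np.array([[1., 2.], [3., 4.], [5., 6.]], np.float32)  # shape [n, ...]
segment_ids = np.array([0, 0, 1])
num_segments = segment_ids.max() + 1                          # ReduceMax + Add(1)
one_hot = np.eye(num_segments, dtype=np.float32)[:, segment_ids]  # shape [s, n]
seg_sum = one_hot @ data.reshape(len(segment_ids), -1)        # MatMul: [s, n] @ [n, P]
assert np.array_equal(seg_sum, np.array([[4., 6.], [5., 6.]]))

For the non-sum reductions (min/max/prod) the handler instead broadcasts the one-hot mask through Where against an identity value and reduces over axis 1.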
Example #15
    def _adapt_scan_sequence_input_or_output(self,
                                             target_name,
                                             input_id,
                                             handle_output=False):
        nodes_to_add = []
        shape_node = self.g.make_node("Shape", [input_id])
        nodes_to_add.append(shape_node)
        inferred_shape = self.g.get_shape(input_id)
        if handle_output:
            # handle output:
            # if required dim values don't contain more than one -1,
            # just use a const for Reshape's shape input.
            if inferred_shape is not None and inferred_shape[1:].count(
                    -1) <= 1:
                new_shape_node = self.g.make_const(
                    utils.make_name(target_name + "_target_shape"),
                    np.array(inferred_shape[1:], dtype=np.int64))
                nodes_to_add.append(new_shape_node)
            else:
                # otherwise, get the dims dynamically, e.g. remove the fake
                # batch size (e.g. 1) from [1, time, real-batch, ...]
                origin_shape_node = self.g.make_node(
                    "Cast", [shape_node.output[0]],
                    {"to": onnx_pb.TensorProto.FLOAT})
                nodes_to_add.append(origin_shape_node)

                attr = {"axes": [0], "starts": [1], "ends": [sys.maxsize]}
                inputs_map = {"data": origin_shape_node.output[0], **attr}
                sliced_shape_node = GraphBuilder(self.g).make_slice(inputs_map)
                nodes_to_add.append(
                    self.g.get_node_by_output(sliced_shape_node))

                new_shape_node = self.g.make_node(
                    "Cast", [sliced_shape_node],
                    {"to": onnx_pb.TensorProto.INT64})
                nodes_to_add.append(new_shape_node)

            new_shape = inferred_shape[1:]
        else:
            # handle input:
            if inferred_shape is not None and inferred_shape.count(-1) <= 1:
                new_shape_node = self.g.make_const(
                    utils.make_name(target_name + "_target_shape"),
                    np.array([1] + inferred_shape, dtype=np.int64))
                nodes_to_add.append(new_shape_node)
            else:
                # add a fake batch size : 1
                fake_batch_size_node = self.g.make_const(
                    utils.make_name(target_name + "_target_shape"),
                    np.array([1], dtype=np.int64))
                nodes_to_add.append(fake_batch_size_node)
                new_shape_node = self.g.make_node(
                    "Concat",
                    [fake_batch_size_node.output[0], shape_node.output[0]],
                    attr={"axis": 0})
                nodes_to_add.append(new_shape_node)
            new_shape = [1] + inferred_shape

        reshape_node = self.g.make_node("Reshape",
                                        [input_id, new_shape_node.output[0]],
                                        shapes=[new_shape],
                                        dtypes=[self.g.get_dtype(input_id)],
                                        op_name_scope=target_name)
        nodes_to_add.append(reshape_node)
        logger.debug("create Reshape for scan output %s, with output shape %s",
                     reshape_node.output[0], new_shape)
        return nodes_to_add
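
A NumPy sketch (illustrative only) of the two reshapes this helper builds, prepending a fake batch dimension of 1 on the input side and dropping it again on the output side:

import numpy as np

time, batch, hidden = 4, 2, 3
scan_input = np.zeros((time, batch, hidden), np.float32)
wrapped = scan_input.reshape([1] + list(scan_input.shape))  # input side: [1, time, batch, ...]
unwrapped = wrapped.reshape(wrapped.shape[1:])              # output side: strip the fake dim
assert unwrapped.shape == scan_input.shape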
Example #16
    def version_11(cls, ctx, node, **kwargs):
        # This op is basically NMS with a little post-processing.
        # TFLite implementation:
        # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/kernels/detection_postprocess.cc

        # box_encodings.shape = [batch_dim, box_num, 4]
        # class_predictions.shape = [batch_dim, box_num, num_classes(+1)]
        # anchors.shape = [box_num, 4]
        box_encodings, class_predictions, anchors = node.input

        classes_dtype = ctx.get_dtype(node.output[1])
        box_cnt_dtype = ctx.get_dtype(node.output[3])

        num_classes = node.get_attr_value('num_classes')
        max_detections = node.get_attr_value('max_detections')

        # Remove 'other' class if present.
        max_int64 = int(utils.get_max_value(np.int64))
        class_predictions = GraphBuilder(ctx).make_slice(
            {'data': class_predictions, 'starts': [-num_classes], 'ends': [max_int64], 'axes': [2]})

        scaling_vector = [node.get_attr_value(a) for a in ['y_scale', 'x_scale', 'h_scale', 'w_scale']]
        scale_const = ctx.make_const(utils.make_name('scale_const'), np.array(scaling_vector, np.float32)).output[0]

        scaled_boxes = ctx.make_node('Div', [box_encodings, scale_const]).output[0]
        anchors_yx = GraphBuilder(ctx).make_slice({'data': anchors, 'starts': [0], 'ends': [2], 'axes': [1]})
        anchors_hw = GraphBuilder(ctx).make_slice({'data': anchors, 'starts': [2], 'ends': [4], 'axes': [1]})
        boxes_yx = GraphBuilder(ctx).make_slice({'data': scaled_boxes, 'starts': [0], 'ends': [2], 'axes': [2]})
        boxes_hw = GraphBuilder(ctx).make_slice({'data': scaled_boxes, 'starts': [2], 'ends': [4], 'axes': [2]})

        scaled_boxes_yx = ctx.make_node('Mul', [boxes_yx, anchors_hw]).output[0]
        boxes_hw_exp = ctx.make_node('Exp', [boxes_hw]).output[0]
        scaled_boxes_hw = ctx.make_node('Mul', [boxes_hw_exp, anchors_hw]).output[0]
        const_half = ctx.make_const(utils.make_name('const_half'), np.array(0.5, np.float32)).output[0]
        boxes_half_hw = ctx.make_node('Mul', [scaled_boxes_hw, const_half]).output[0]
        boxes_center_yx = ctx.make_node('Add', [scaled_boxes_yx, anchors_yx]).output[0]

        boxes_lower_left = ctx.make_node('Sub', [boxes_center_yx, boxes_half_hw]).output[0]
        boxes_upper_right = ctx.make_node('Add', [boxes_center_yx, boxes_half_hw]).output[0]
        adjusted_boxes = ctx.make_node('Concat', [boxes_lower_left, boxes_upper_right], attr={'axis': 2}).output[0]

        iou_threshold = np.array(node.get_attr_value('nms_iou_threshold'), np.float32)
        iou_threshold_const = ctx.make_const(utils.make_name('iou_threshold'), iou_threshold).output[0]

        score_threshold = np.array(node.get_attr_value('nms_score_threshold'), np.float32)
        score_threshold_const = ctx.make_const(utils.make_name('score_threshold'), score_threshold).output[0]

        if node.get_attr_value('use_regular_nms', False):
            boxes_per_class = np.array(node.get_attr_value('detections_per_class', 100), np.int64)
        else:
            # When tflite uses FastNMS, detections_per_class is ignored.
            logging.warning("NMS node %s uses fast NMS. ONNX will approximate with standard NMS.", node.name)
            boxes_per_class = np.array(max_detections, np.int64)
        max_boxes_per_class_const = ctx.make_const(utils.make_name('max_boxes_per_class'), boxes_per_class).output[0]

        # scores.shape = [batch_dim, classes_num, box_num]
        scores = ctx.make_node('Transpose', [class_predictions], attr={'perm': [0, 2, 1]}).output[0]

        nms_inputs = [adjusted_boxes, scores, max_boxes_per_class_const, iou_threshold_const, score_threshold_const]
        # shape: [-1, 3], elts of format [batch_index, class_index, box_index]
        selected_indices = ctx.make_node('NonMaxSuppression', nms_inputs, attr={'center_point_box': 0},
                                         op_name_scope=node.name).output[0]

        selected_boxes_idx = GraphBuilder(ctx).make_slice(
            {'data': selected_indices, 'starts': [2], 'ends': [3], 'axes': [1]})
        selected_boxes_idx_sq = GraphBuilder(ctx).make_squeeze({'data': selected_boxes_idx, 'axes': [1]})

        selected_classes = GraphBuilder(ctx).make_slice(
            {'data': selected_indices, 'starts': [1], 'ends': [2], 'axes': [1]})
        selected_classes_sq = GraphBuilder(ctx).make_squeeze({'data': selected_classes, 'axes': [1]})

        box_and_class_idx = ctx.make_node('Concat', [selected_boxes_idx, selected_classes], attr={'axis': 1}).output[0]

        box_cnt = ctx.make_node('Shape', [selected_classes_sq]).output[0]

        adjusted_boxes_sq = GraphBuilder(ctx).make_squeeze({'data': adjusted_boxes, 'axes': [0]})
        detection_boxes = ctx.make_node('Gather', [adjusted_boxes_sq, selected_boxes_idx_sq]).output[0]
        class_predictions_sq = GraphBuilder(ctx).make_squeeze({'data': class_predictions, 'axes': [0]})
        detection_scores = ctx.make_node('GatherND', [class_predictions_sq, box_and_class_idx]).output[0]

        k_const = ctx.make_const(utils.make_name('const_k'), np.array([max_detections], np.int64)).output[0]
        if ctx.opset >= 12:
            min_k = ctx.make_node('Min', [k_const, box_cnt]).output[0]
        else:
            # Lower opsets only support Min between floats
            box_cnt_float = ctx.make_node('Cast', [box_cnt], attr={'to': TensorProto.FLOAT}).output[0]
            k_const_float = ctx.make_node('Cast', [k_const], attr={'to': TensorProto.FLOAT}).output[0]
            min_k_float = ctx.make_node('Min', [k_const_float, box_cnt_float]).output[0]
            min_k = ctx.make_node('Cast', [min_k_float], attr={'to': TensorProto.INT64}).output[0]
        min_k_cast = ctx.make_node('Cast', [min_k], attr={'to': box_cnt_dtype}).output[0]

        scores_top_k, scores_top_k_idx = ctx.make_node('TopK', [detection_scores, min_k], output_count=2).output

        scores_top_k_idx_unsq = GraphBuilder(ctx).make_unsqueeze({'data': scores_top_k_idx, 'axes': [0]})
        scores_top_k_unsq = GraphBuilder(ctx).make_unsqueeze({'data': scores_top_k, 'axes': [0]})

        selected_classes_sort = ctx.make_node('Gather', [selected_classes_sq, scores_top_k_idx_unsq]).output[0]
        classes_sort_cast = ctx.make_node('Cast', [selected_classes_sort], attr={'to': classes_dtype}).output[0]
        detection_boxes_sorted = ctx.make_node('Gather', [detection_boxes, scores_top_k_idx_unsq]).output[0]

        pad_amount = ctx.make_node('Sub', [k_const, min_k]).output[0]

        quad_zero_const = ctx.make_const(utils.make_name('quad_zero_const'), np.array([0, 0, 0, 0], np.int64)).output[0]
        duo_zero_const = ctx.make_const(utils.make_name('duo_zero_const'), np.array([0, 0], np.int64)).output[0]
        zero_const = ctx.make_const(utils.make_name('zero_const'), np.array([0], np.int64)).output[0]

        pads_3d = ctx.make_node('Concat', [quad_zero_const, pad_amount, zero_const], attr={'axis': 0}).output[0]
        pads_2d = ctx.make_node('Concat', [duo_zero_const, zero_const, pad_amount], attr={'axis': 0}).output[0]

        detection_boxes_padded = ctx.make_node('Pad', [detection_boxes_sorted, pads_3d]).output[0]
        detection_classes_padded = ctx.make_node('Pad', [classes_sort_cast, pads_2d]).output[0]
        detection_scores_padded = ctx.make_node('Pad', [scores_top_k_unsq, pads_2d]).output[0]

        ctx.replace_all_inputs(node.output[0], detection_boxes_padded)
        ctx.replace_all_inputs(node.output[1], detection_classes_padded)
        ctx.replace_all_inputs(node.output[2], detection_scores_padded)
        ctx.replace_all_inputs(node.output[3], min_k_cast)

        ctx.remove_node(node.name)
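
For reference, the anchor-box decoding built above (Div, Mul, Exp, Add, Sub, Concat) corresponds to the following NumPy sketch for a single box; the scale values here are made-up example numbers, not defaults from the op:

import numpy as np

scales = np.array([10.0, 10.0, 5.0, 5.0], np.float32)       # y/x/h/w scale attrs (example values)
box = np.array([1.0, 2.0, 0.5, 0.2], np.float32) / scales   # the Div by scale_const
anchor_yx = np.array([0.5, 0.5], np.float32)
anchor_hw = np.array([0.2, 0.4], np.float32)
center_yx = box[:2] * anchor_hw + anchor_yx                  # Mul + Add
half_hw = np.exp(box[2:]) * anchor_hw * 0.5                  # Exp + Mul + Mul(0.5)
corners = np.concatenate([center_yx - half_hw, center_yx + half_hw])  # Sub/Add + Concat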
Example #17
    def rewrite(self, context):
        logger.debug("enter rewrite function")
        loop_node = None
        try:
            loop_props = context.loop_properties
            cell_g_info = context.cell_graph
            cond_g_info = context.cond_graph

            # create a dummy loop to calculate the init condition
            init_cond_output = self._create_subgraph_initial_cond(cond_g_info)

            ## create Loop body graph with existing nodes

            body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
            body_outputs = cond_g_info.outputs + cell_g_info.outputs
            for out_tensor_value_info in body_outputs:
                shape = out_tensor_value_info.shape
                utils.make_sure(
                    shape is not None,
                    "Conversion of Loop requires that output shape [{}] exists".format(out_tensor_value_info.id)
                )
                out_tensor_value_info.shape = utils.create_vague_shape_like(shape)

            loop_body_g = LoopRewriterBase.construct_graph_from_nodes(self.g, body_nodes, body_outputs)

            # create loop body graph inputs
            loop_body_g.add_graph_input(utils.make_name("i"), TensorProto.INT64, ())
            loop_body_g.add_graph_input(utils.make_name("cond"), TensorProto.BOOL, ())
            for i, tensor_value_info in enumerate(loop_props.state_inputs):
                input_name = tensor_value_info.id
                if input_name is None:
                    # if the variable is not used in the body graph, we create a fake one
                    # with the same type and shape as its corresponding output.
                    out_tensor_value_info = loop_props.state_outputs[i]
                    dtype = out_tensor_value_info.dtype
                    shape = out_tensor_value_info.shape
                    input_name = utils.make_name("unused_state_input_")
                else:
                    dtype = tensor_value_info.dtype
                    shape = tensor_value_info.shape

                loop_body_g.add_graph_input(input_name, dtype, utils.create_vague_shape_like(shape))

            for input_ta in loop_props.tensor_array_inputs:
                # Loop does not have scan inputs, so we use Gather to get data for each iteration.
                gb = GraphBuilder(loop_body_g)
                index_node = gb.make_unsqueeze({'data': input_ta.index_input_id, "axes": [0]}, return_node=True)
                gather_node = loop_body_g.make_node("Gather", [input_ta.data_input_id, index_node.output[0]])
                data_node = gb.make_squeeze({'data': gather_node.output[0], "axes": [0]}, return_node=True)
                loop_body_g.replace_all_inputs(input_ta.consumer.id, data_node.output[0])  # ops=loop_body_g.get_nodes()

            ## create Loop node
            branches = {"body": loop_body_g}
            loop_node = self._create_loop_node(context, loop_props, init_cond_output, branches=branches)
            if not loop_node:
                logger.error("failed to create loop node during rewrite")
                return REWRITER_RESULT.FAIL

            logger.debug("rewrite successfully")
            return REWRITER_RESULT.OK

        except Exception as ex:
            tb = traceback.format_exc()
            logger.error("loop rewrite failed, due to exception: %s, details:%s", ex, tb)
            return REWRITER_RESULT.FAIL
Example #18
    def version_10(cls, ctx, node, **kwargs):
        x = node.input[0]
        x_shape = ctx.get_shape(x)
        h = node.input[1]
        h_shape = ctx.get_shape(h)
        p = node.input[3]
        utils.make_sure(node.attr["rnn_mode"].s == b"gru",
                        "rnn mode other than gru are not supported yet")
        utils.make_sure(node.attr["dropout"].f == 0,
                        "dropout not supported yet")
        utils.make_sure(node.attr["input_mode"].s == b"linear_input",
                        "input mode must be linear input")
        num_dirs = 1 if node.attr["direction"].s == b"unidirectional" else 2
        num_layers = int(h_shape[0] / num_dirs)
        num_units = hidden_size = h_shape[2]
        input_size = x_shape[2]
        w_shape = [num_layers * num_dirs, 3 * hidden_size, input_size]
        w_shape_const = ctx.make_const(utils.make_name("w_shape"),
                                       np.array(w_shape, dtype=np.int64))
        r_shape = [num_layers * num_dirs, 3 * hidden_size, hidden_size]
        r_shape_const = ctx.make_const(utils.make_name("r_shape"),
                                       np.array(r_shape, dtype=np.int64))
        b_shape = [num_layers * num_dirs, 6 * hidden_size]
        b_shape_const = ctx.make_const(utils.make_name("b_shape"),
                                       np.array(b_shape, dtype=np.int64))
        zero_const = ctx.make_const(utils.make_name("zero"),
                                    np.array([0], dtype=np.int64))
        w_end = np.prod(w_shape)
        w_end_const = ctx.make_const(utils.make_name("w_end"),
                                     np.array([w_end], dtype=np.int64))
        r_end = w_end + np.prod(r_shape)
        r_end_const = ctx.make_const(utils.make_name("r_end"),
                                     np.array([r_end], dtype=np.int64))
        b_end = r_end + np.prod(b_shape)
        b_end_const = ctx.make_const(utils.make_name("b_end"),
                                     np.array([b_end], dtype=np.int64))

        def name(nm):
            return node.name + "_" + nm

        ws = [name('W_' + str(i)) for i in range(num_layers * num_dirs)]
        rs = [name('R_' + str(i)) for i in range(num_layers * num_dirs)]
        bs = [name('B_' + str(i)) for i in range(num_layers * num_dirs)]
        hs = [name('H_' + str(i)) for i in range(num_layers * num_dirs)]
        yhs = [name('YH_' + str(i)) for i in range(num_layers * num_dirs)]
        w_flattened = ctx.make_node(
            'Slice', [p, zero_const.output[0], w_end_const.output[0]])
        r_flattened = ctx.make_node(
            'Slice', [p, w_end_const.output[0], r_end_const.output[0]])
        b_flattened = ctx.make_node(
            'Slice', [p, r_end_const.output[0], b_end_const.output[0]])
        w = utils.make_name('W')
        r = utils.make_name('R')
        b = utils.make_name('B')
        ctx.make_node('Reshape',
                      [w_flattened.output[0], w_shape_const.output[0]],
                      outputs=[w])
        ctx.make_node('Reshape',
                      [r_flattened.output[0], r_shape_const.output[0]],
                      outputs=[r])
        ctx.make_node('Reshape',
                      [b_flattened.output[0], b_shape_const.output[0]],
                      outputs=[b])
        ctx.make_node('Split', [w], outputs=ws)
        ctx.make_node('Split', [r], outputs=rs)
        ctx.make_node('Split', [b], outputs=bs)
        ctx.make_node('Split', [h], outputs=hs)

        builder = GraphBuilder(ctx)

        xnf = xnb = x
        for i in range(num_layers):
            suffix = '_' + str(i * num_dirs)
            ctx.make_node('GRU', [
                xnf,
                name('W' + suffix),
                name('R' + suffix),
                name('B' + suffix), '',
                name('H' + suffix)
            ],
                          outputs=[name('Y' + suffix),
                                   name('YH' + suffix)],
                          attr={
                              'direction': 'forward',
                              'hidden_size': num_units
                          })
            xnf = name(x + suffix)
            builder.make_squeeze({
                'data': name('Y' + suffix),
                'outputs': [xnf],
                'axes': [1]
            })
            if num_dirs == 2:
                suffix = '_' + str(i * 2 + 1)
                ctx.make_node(
                    'GRU', [
                        xnb,
                        name('W' + suffix),
                        name('R' + suffix),
                        name('B' + suffix), '',
                        name('H' + suffix)
                    ],
                    outputs=[name('Y' + suffix),
                             name('YH' + suffix)],
                    attr={
                        'direction': 'reverse',
                        'hidden_size': num_units
                    })
                xnb = name(x + suffix)
                builder.make_squeeze({
                    'data': name('Y' + suffix),
                    'outputs': [xnb],
                    'axes': [1]
                })
        ctx.remove_node(node.name)
        if num_dirs == 2:
            ctx.make_node('Concat', [xnf, xnb],
                          outputs=[node.output[0]],
                          attr={'axis': -1})
        else:
            ctx.make_node('Identity', [xnf], outputs=[node.output[0]])
        ctx.make_node('Concat',
                      yhs,
                      outputs=[node.output[1]],
                      attr={'axis': 0})
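
The offset arithmetic for CudnnRNN's flat parameter blob can be checked with a NumPy sketch (illustrative; it assumes the W, R and B blocks are packed contiguously in that order, as the handler's slices do):

import numpy as np

num_layers, num_dirs, hidden_size, input_size = 1, 1, 3, 2
w_shape = [num_layers * num_dirs, 3 * hidden_size, input_size]
r_shape = [num_layers * num_dirs, 3 * hidden_size, hidden_size]
b_shape = [num_layers * num_dirs, 6 * hidden_size]
w_end = np.prod(w_shape)
r_end = w_end + np.prod(r_shape)
b_end = r_end + np.prod(b_shape)
params = np.arange(b_end, dtype=np.float32)   # stand-in for the packed blob
w = params[:w_end].reshape(w_shape)           # Slice + Reshape, as above
r = params[w_end:r_end].reshape(r_shape)
b = params[r_end:b_end].reshape(b_shape)
assert w.size + r.size + b.size == params.size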
Example #19
    def version_7(cls, ctx, node, **kwargs):
        tfl_while_inputs = node.input
        output_shapes = node.output_shapes
        output_dtypes = node.output_dtypes
        output_names = node.output

        cond_name = node.get_attr_str("cond_subgraph_index")
        cond_graph = find_function(cond_name)
        cond_graph.parent_graph = ctx

        body_name = node.get_attr_str("body_subgraph_index")
        body = find_function(body_name)
        body.parent_graph = ctx

        ctx.remove_node(node.name)

        cond_binding = parameter_binding(cond_graph, tfl_while_inputs)
        cond_outputs = inline_subgraph(ctx, cond_graph, cond_name,
                                       cond_binding)

        # Potential scan output candidates are identified in the body subgraph using tfl_scan_output_rewriter.
        # They can then be optimized in this tfl loop handler provided they are not used in the cond subgraph.
        scan_outputs = sorted(body.scan_outputs, reverse=True)

        def input_is_unused(g, index):
            return len(g.find_output_consumers(g.inputs[index])) == 0

        scan_outputs = [(i, out) for i, out in scan_outputs
                        if input_is_unused(cond_graph, i)]

        for idx, _ in scan_outputs:
            del tfl_while_inputs[idx]
            output_shapes.append(output_shapes.pop(idx))
            output_dtypes.append(output_dtypes.pop(idx))
            output_names.append(output_names.pop(idx))

        max_iterations = ctx.make_const(utils.make_name("max_iterations"),
                                        np.array(np.iinfo(np.int64).max))

        loop_node = ctx.make_node("Loop",
                                  [max_iterations.output[0], cond_outputs[0]] +
                                  tfl_while_inputs,
                                  output_count=len(output_shapes),
                                  name=node.name + "_loop",
                                  shapes=output_shapes,
                                  dtypes=output_dtypes,
                                  skip_conversion=True)

        output_map = dict(zip(output_names, loop_node.output))

        # shift output consumers
        for k, v in output_map.items():
            ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()

        body = wire_tfl_while_body(body, loop_node.inputs, output_shapes,
                                   output_dtypes, cond_graph, scan_outputs)

        for i in range(len(scan_outputs)):
            squeeze_node = GraphBuilder(body).make_squeeze(
                {
                    'data': body.outputs[-1 - i],
                    "axes": [0]
                }, return_node=True)
            body.outputs[-1 - i] = squeeze_node.output[0]

        loop_node.set_body_graph_as_attr("body", body)
Example #20
def rewrite_eye(g, ops):
    # The schema of eye is eye(num_rows, num_columns=None); if num_columns is not
    # specified, it is equal to num_rows.
    # tf.eye is implemented by a subgraph containing the op "MatrixDiag" or
    # "MatrixSetDiag", neither of which is supported directly in onnx,
    # but the onnx op EyeLike can be used to map the subgraph.
    # "rewrite_eye" supports tf.eye(non_const) and tf.eye(non_const1, non_const2).
    # tf.eye(const) and tf.eye(const1, const2) are not supported by this rewriter.

    # ConstantOfShape from opset 9 is used, so do nothing if the opset is less than 9
    if g.opset < 9:
        return g.get_nodes()

    pattern1 = \
        OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill", inputs=[
                OpTypePattern("Const", name="fill_value"),
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    "*",
                    OpTypePattern("Pack", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast")
                    ])
                ])
            ])
        ])
    pattern2 = \
        OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill"),
            OpTypePattern("Fill", inputs=[
                OpTypePattern("Const", name="fill_value"),
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    "*",
                    OpTypePattern("Pack", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast")
                    ])
                ])
            ])
        ])
    pattern3 = \
        OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill", inputs=[
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    OpTypePattern("ExpandDims", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast"),
                        "*"
                    ]),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ])
        ])
    pattern4 = \
        OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill"),
            OpTypePattern("Fill", inputs=[
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    OpTypePattern("ExpandDims", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast"),
                        "*"
                    ]),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ]),
        ])
    pattern5 = \
        OpTypePattern("MatrixDiagV3", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill", inputs=[
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    OpTypePattern("ExpandDims", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast"),
                        "*"
                    ]),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ]),
            "*", "*", "*", "*",
        ])
    pattern6 = \
        OpTypePattern("MatrixSetDiagV3", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill"),
            OpTypePattern("Fill", inputs=[
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    OpTypePattern("ExpandDims", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast"),
                        "*"
                    ]),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ]), "*"
        ])
    pattern7 = \
        OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill", inputs=[
                OpTypePattern("Reshape", inputs=[
                    OpTypePattern("Minimum|Cast", name="min_or_cast"),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ])
        ])
    pattern8 = \
        OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill"),
            OpTypePattern("Fill", inputs=[
                OpTypePattern("Reshape", inputs=[
                    OpTypePattern("Minimum|Cast", name="min_or_cast"),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ])
        ])

    for pattern in [
            pattern1, pattern2, pattern3, pattern4, pattern5, pattern6,
            pattern7, pattern8
    ]:
        matcher = GraphMatcher(pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(ops))
        for match_result in match_results:
            if match_result.get_op("fill_value").get_tensor_value() != 1:
                continue

            min_or_cast = match_result.get_op("min_or_cast")
            if min_or_cast.type == "Minimum":
                min_node = min_or_cast
            elif min_or_cast.type == "Cast" and min_or_cast.inputs[
                    0].type == "Minimum":
                min_node = min_or_cast.inputs[0]
            else:
                continue

            num_rows = min_node.inputs[0]
            num_columns = min_node.inputs[1]

            old_output = match_result.get_op("output_eye_matrix")
            output_dtypes = [g.get_dtype(old_output.output[0])]
            output_shapes = [g.get_shape(old_output.output[0])]
            g.remove_node(old_output.name)

            # onnx op "EyeLike" need a 2D tensor, so generate it

            num_rows = GraphBuilder(g).make_unsqueeze(
                {
                    "axes": [0],
                    "data": num_rows.output[0]
                }, return_node=True)
            num_columns = GraphBuilder(g).make_unsqueeze(
                {
                    "axes": [0],
                    "data": num_columns.output[0]
                }, return_node=True)
            matrix_shape = g.make_node(
                "Concat", [num_rows.output[0], num_columns.output[0]],
                attr={"axis": 0})
            # a Cast node is added because "ConstantOfShape" in ONNX only accepts int64 data.
            matrix_shape_int64 = g.make_node(
                "Cast",
                matrix_shape.output,
                attr={"to": onnx_pb.TensorProto.INT64})
            zero_matrix = g.make_node("ConstantOfShape",
                                      matrix_shape_int64.output)

            g.make_node("EyeLike",
                        zero_matrix.output,
                        attr={"dtype": output_dtypes[0]},
                        name=old_output.name,
                        shapes=output_shapes,
                        dtypes=output_dtypes,
                        outputs=old_output.output)

    return g.get_nodes()
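
The ConstantOfShape + EyeLike pair at the end is equivalent to np.eye with dynamic dimensions; a sketch (illustrative only):

import numpy as np

num_rows, num_columns = 4, 3
zero_matrix = np.zeros((num_rows, num_columns), np.float32)  # ConstantOfShape
eye = np.eye(*zero_matrix.shape, dtype=zero_matrix.dtype)    # EyeLike
assert np.array_equal(eye, np.eye(num_rows, num_columns))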