def convert_gru_to_scan(node, out_main_graph):
    assert node.op_type == 'GRU'
    nf = NodeFactory(out_main_graph)
    with nf.scoped_prefix(node.output[0]) as scoped_prefix:
        X = node.input[0]
        Wa = nf.get_initializer(node.input[1])
        Ra = nf.get_initializer(node.input[2])
        num_inputs = len(node.input)
        Ba = nf.get_initializer(node.input[3]) if num_inputs > 3 else None
        seq_len = node.input[4] if num_inputs > 4 else None
        InitHa = node.input[5] if num_inputs > 5 else None

        direction, num_directions, activations = handle_common_attributes(
            node, ['Sigmoid', 'Tanh'])

        hidden_size = NodeFactory.get_attribute(node, 'hidden_size')
        linear_before_reset = NodeFactory.get_attribute(
            node, 'linear_before_reset')
        InitHa = handle_init_state(InitHa, nf, num_directions)

        batch_size, batch_node = handle_batch_size(X, nf, InitHa is None)
        if InitHa is None:
            zero_init_state = default_init_state(X, batch_size, batch_node,
                                                 hidden_size, nf)

        scan_outputs = []
        scan_h_outputs = []
        for direction_index in range(num_directions):
            # for each direction
            # X [seq_len, batch_size, input_size]
            # W [3*hidden_size, input_size]
            # R [3*hidden_size, hidden_size]
            # B [6*hidden_size]
            # seq_len [batch_size]
            # init_h [batch_size, hidden_size]

            name_prefix = node.output[0] + '_' + str(direction_index) + '_'

            if InitHa is None:
                init_h = zero_init_state
            else:
                init_h = InitHa[direction_index]

            input_size = Wa.shape[len(Wa.shape) - 1]
            W_t = np.transpose(
                Wa[direction_index])  # [input_size, 3*hidden_size]
            R_t = np.transpose(
                Ra[direction_index])  # [hidden_size, 3*hidden_size]
            Rzr_t, Rh_t = np.hsplit(R_t, [
                2 * hidden_size
            ])  # [hidden_size, 2*hidden_size] and [hidden_size, hidden_size]
            Bzr, Bh = np.hsplit(
                Ba[direction_index].reshape(2, 3 * hidden_size),
                [2 * hidden_size])
            Bzr = Bzr.sum(axis=0)  # [2*hidden_size]
            Wbh = Bh[0]
            Rbh = Bh[1]
            X_proj = nf.make_node(
                'Add',
                [nf.make_node('MatMul', [X, W_t]),
                 np.concatenate(
                     (Bzr, Wbh))])  #[seq_len, batch_size, 3*hidden_size]
            if num_directions == 1:
                is_backward = 0 if direction == 'forward' else 1
            else:
                is_backward = direction_index

            scan_body = onnx.GraphProto()
            scan_body.name = name_prefix + '_subgraph'

            nf_body = NodeFactory(out_main_graph, scan_body)
            with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix:
                # subgraph inputs
                X_proj_subgraph = X_proj.name + '_subgraph'
                prev_h_subgraph = name_prefix + '_h_subgraph'

                seq_len_subgraph = declare_seq_len_in_subgraph(
                    seq_len, nf_body, X_proj.name, batch_size)

                nf_body.make_value_info(prev_h_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)

                nf_body.make_value_info(X_proj_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, 3 * hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)

                # subgraph nodes
                # zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
                # rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
                # ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # default, when linear_before_reset = 0
                # ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset != 0
                # Ht = (1 - zt) (.) ht + zt (.) Ht-1

                split_X_outputs = ['split_Xzr', 'split_Xh']
                nf_body.make_node('Split',
                                  X_proj_subgraph, {
                                      "axis": 1,
                                      "split": [2 * hidden_size, hidden_size]
                                  },
                                  output_names=split_X_outputs)
                nf_body.make_value_info('split_Xzr',
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, 2 * hidden_size))
                nf_body.make_value_info('split_Xh',
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, hidden_size))

                activation_f, activation_g = activations[direction_index *
                                                         2:(direction_index +
                                                            1) * 2]

                if linear_before_reset:
                    prev_h_proj = nf_body.make_node('Add', [
                        nf_body.make_node('MatMul', [prev_h_subgraph, R_t]),
                        np.concatenate((np.zeros(2 * hidden_size).astype(
                            np.float32), Rbh))
                    ])
                    split_prev_h_outputs = ['split_Hzr', 'split_Hh']
                    nf_body.make_node(
                        'Split',
                        prev_h_proj, {
                            "axis": 1,
                            "split": [2 * hidden_size, hidden_size]
                        },
                        output_names=split_prev_h_outputs)
                    nf_body.make_value_info('split_Hzr',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size,
                                                   2 * hidden_size))
                    nf_body.make_value_info('split_Hh',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    ztrt = nf_body.make_node(
                        activation_f,
                        nf_body.make_node('Add', ['split_Hzr', 'split_Xzr']))
                    split_ztrt_outputs = ['split_zt', 'split_rt']
                    nf_body.make_node('Split',
                                      ztrt, {
                                          "axis": 1,
                                          "split": [hidden_size, hidden_size]
                                      },
                                      output_names=split_ztrt_outputs)
                    nf_body.make_value_info('split_zt',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    nf_body.make_value_info('split_rt',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    ht = nf_body.make_node(
                        activation_g,
                        nf_body.make_node('Add', [
                            nf_body.make_node('Mul', ['split_rt', 'split_Hh']),
                            'split_Xh'
                        ]))
                else:
                    ztrt = nf_body.make_node(
                        activation_f,
                        nf_body.make_node('Add', [
                            nf_body.make_node('MatMul',
                                              [prev_h_subgraph, Rzr_t]),
                            'split_Xzr'
                        ]))
                    split_ztrt_outputs = ['split_zt', 'split_rt']
                    nf_body.make_node('Split',
                                      ztrt, {
                                          "axis": 1,
                                          "split": [hidden_size, hidden_size]
                                      },
                                      output_names=split_ztrt_outputs)
                    nf_body.make_value_info('split_zt',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    nf_body.make_value_info('split_rt',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    ht = nf_body.make_node(
                        activation_g,
                        nf_body.make_node('Add', [
                            nf_body.make_node('MatMul', [
                                nf_body.make_node(
                                    'Mul', [prev_h_subgraph, 'split_rt']), Rh_t
                            ]), 'split_Xh'
                        ]))

                Ht = nf_body.make_node('Add', [
                    nf_body.make_node('Mul', [
                        nf_body.make_node('Sub', [
                            np.asarray([1]).astype(np.float32), 'split_zt'
                        ]), ht
                    ]),
                    nf_body.make_node('Mul', ['split_zt', prev_h_subgraph])
                ])

                subgraph_outputs = handle_subgraph_outputs(
                    nf_body, seq_len_subgraph, batch_size, hidden_size,
                    [(Ht, prev_h_subgraph)] + ([
                        (Ht, np.zeros(shape=(), dtype=np.float32))
                    ] if node.output[0] else []))

                scan = nf.make_node(
                    'Scan', ([seq_len] if seq_len else []) + [init_h, X_proj],
                    {
                        'body': scan_body,
                        'scan_input_directions': [is_backward],
                        'scan_output_directions': [is_backward],
                        'num_scan_inputs': 1
                    },
                    output_names=[
                        o.name
                        for o in subgraph_outputs[(0 if seq_len else 1):]
                    ])

                scan_h_outputs.append(subgraph_outputs[1])
                if node.output[0]:
                    scan_outputs.append(subgraph_outputs[2])

        handle_final_scan_outputs(node, nf, scan_outputs, [scan_h_outputs],
                                  num_directions)

    # remove old initializers
    nf.remove_initializer(node.input[1])
    nf.remove_initializer(node.input[2])
    if num_inputs > 3:
        nf.remove_initializer(node.input[3])
    if num_inputs > 5:
        nf.remove_initializer(node.input[5], allow_empty=True)
    return True
def convert_rnn_to_scan(node, out_main_graph):
    assert node.op_type == 'RNN'
    nf = NodeFactory(out_main_graph)
    with nf.scoped_prefix(node.output[0]) as scoped_prefix:
        X = node.input[0]
        Wa = nf.get_initializer(node.input[1])
        Ra = nf.get_initializer(node.input[2])
        num_inputs = len(node.input)
        Ba = nf.get_initializer(node.input[3]) if num_inputs > 3 else None
        seq_len = node.input[4] if num_inputs > 4 else None
        InitHa = node.input[5] if num_inputs > 5 else None

        direction, num_directions, activations = handle_common_attributes(
            node, ['Tanh'])

        hidden_size = NodeFactory.get_attribute(node, 'hidden_size')

        InitHa = handle_init_state(InitHa, nf, num_directions)

        batch_size, batch_node = handle_batch_size(X, nf, InitHa is None)
        if InitHa is None:
            zero_init_state = default_init_state(X, batch_size, batch_node,
                                                 hidden_size, nf)

        scan_outputs = []
        scan_h_outputs = []
        for direction_index in range(num_directions):
            # for each direction
            # X [seq_len, batch_size, input_size]
            # W [hidden_size, input_size]
            # R [hidden_size, hidden_size]
            # B [2*hidden_size]
            # seq_len [batch_size]
            # init_h [batch_size, hidden_size]

            name_prefix = node.output[0] + '_' + str(direction_index) + '_'

            if InitHa is None:
                init_h = zero_init_state
            else:
                init_h = InitHa[direction_index]

            input_size = Wa.shape[len(Wa.shape) - 1]
            W_t = np.transpose(
                Wa[direction_index])  # [input_size, hidden_size]
            R_t = np.transpose(
                Ra[direction_index])  # [hidden_size, hidden_size]
            B = Ba[direction_index].reshape(2, hidden_size).sum(
                axis=0)  # [hidden_size]
            X_proj = nf.make_node('Add', [nf.make_node(
                'MatMul', [X, W_t]), B])  #[seq_len, batch_size, hidden_size]
            if num_directions == 1:
                is_backward = 0 if direction == 'forward' else 1
            else:
                is_backward = direction_index

            scan_body = onnx.GraphProto()
            scan_body.name = name_prefix + '_subgraph'

            nf_body = NodeFactory(out_main_graph, scan_body)
            with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix:
                # subgraph inputs
                X_proj_subgraph = X_proj.name + '_subgraph'
                prev_h_subgraph = name_prefix + '_h_subgraph'

                seq_len_subgraph = declare_seq_len_in_subgraph(
                    seq_len, nf_body, X_proj.name, batch_size)

                nf_body.make_value_info(prev_h_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)

                nf_body.make_value_info(X_proj_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)
                # subgraph nodes
                # Ht = f(Xt*(W^T) + Ht-1*(R^T) + Wb + Rb)

                activation_f = activations[direction_index]
                Ht = nf_body.make_node(
                    activation_f,
                    nf_body.make_node('Add', [
                        nf_body.make_node('MatMul', [prev_h_subgraph, R_t]),
                        X_proj_subgraph
                    ]))

                subgraph_outputs = handle_subgraph_outputs(
                    nf_body, seq_len_subgraph, batch_size, hidden_size,
                    [(Ht, prev_h_subgraph)] + ([
                        (Ht, np.zeros(shape=(), dtype=np.float32))
                    ] if node.output[0] else []))

                scan = nf.make_node(
                    'Scan', ([seq_len] if seq_len else []) + [init_h, X_proj],
                    {
                        'body': scan_body,
                        'scan_input_directions': [is_backward],
                        'scan_output_directions': [is_backward],
                        'num_scan_inputs': 1
                    },
                    output_names=[
                        o.name
                        for o in subgraph_outputs[(0 if seq_len else 1):]
                    ])

                scan_h_outputs.append(subgraph_outputs[1])
                if node.output[0]:
                    scan_outputs.append(subgraph_outputs[2])

        handle_final_scan_outputs(node, nf, scan_outputs, [scan_h_outputs],
                                  num_directions)

    # remove old initializers
    nf.remove_initializer(node.input[1])
    nf.remove_initializer(node.input[2])
    if num_inputs > 3:
        nf.remove_initializer(node.input[3])
    if num_inputs > 5:
        nf.remove_initializer(node.input[5])
    return True
def convert_lstm_to_scan(node, out_main_graph):
    assert node.op_type == 'LSTM'
    nf = NodeFactory(out_main_graph)
    with nf.scoped_prefix(node.output[0]) as scoped_prefix:
        X = node.input[0]
        Wa = nf.get_initializer(node.input[1])
        Ra = nf.get_initializer(node.input[2])
        num_inputs = len(node.input)
        Ba = nf.get_initializer(node.input[3]) if num_inputs > 3 else None
        seq_len = node.input[4] if num_inputs > 4 else None
        InitHa = node.input[5] if num_inputs > 5 else None
        InitCa = node.input[6] if num_inputs > 6 else None
        PB = node.input[7] if num_inputs > 7 else None

        # TODO: support peephole
        assert not PB

        direction, num_directions, activations = handle_common_attributes(
            node, ['Sigmoid', 'Tanh', 'Tanh'])

        hidden_size = NodeFactory.get_attribute(node, 'hidden_size')
        input_forget = NodeFactory.get_attribute(node, 'input_forget')

        # TODO: implement input_forget = 1
        assert not (input_forget != None and input_forget == 1)

        # split initializer if needed:
        is_same_init = InitHa == InitCa
        InitHa = handle_init_state(InitHa, nf, num_directions)
        if is_same_init:
            InitCa = InitHa
        else:
            InitCa = handle_init_state(InitCa, nf, num_directions)

        batch_size, batch_node = handle_batch_size(
            X, nf, InitHa is None or InitCa is None)

        scan_outputs = []
        scan_h_outputs = []
        scan_c_outputs = []
        for direction_index in range(num_directions):
            # for each direction
            # X [seq_len, batch_size, input_size]
            # W [4*hidden_size, input_size]
            # R [4*hidden_size, hidden_size]
            # B [8*hidden_size]
            # seq_len [batch_size]
            # init_h [batch_size, hidden_size]
            # init_c [batch_size, hidden_size]
            # PB [3*hidden_size]

            name_prefix = node.output[0] + '_' + str(direction_index) + '_'

            if InitHa is None:
                init_h = default_init_state(X, batch_size, batch_node,
                                            hidden_size, nf, '_H')
            else:
                init_h = InitHa[direction_index]

            if InitCa is None:
                init_c = default_init_state(X, batch_size, batch_node,
                                            hidden_size, nf, '_C')
            else:
                init_c = InitCa[direction_index]

            input_size = Wa.shape[len(Wa.shape) - 1]
            Wt = np.transpose(Wa[direction_index])
            Rt = np.transpose(Ra[direction_index])
            B = Ba[direction_index].reshape(2, 4 * hidden_size).sum(
                axis=0)  # [4*hidden_size]
            X_proj = nf.make_node(
                'MatMul', [X, Wt])  #[seq_len, batch_size, 4*hidden_size]
            X_proj = nf.make_node('Add', [X_proj, B])
            if num_directions == 1:
                is_backward = 0 if direction == 'forward' else 1
            else:
                is_backward = direction_index

            scan_body = onnx.GraphProto()
            scan_body.name = name_prefix + '_subgraph'

            nf_body = NodeFactory(out_main_graph, scan_body)
            with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix:
                # subgraph inputs
                X_proj_subgraph = X_proj.name + '_subgraph'
                prev_h_subgraph = name_prefix + '_h_subgraph'
                prev_c_subgraph = name_prefix + '_c_subgraph'

                seq_len_subgraph = declare_seq_len_in_subgraph(
                    seq_len, nf_body, X_proj.name, batch_size)

                for subgraph_i in [prev_h_subgraph, prev_c_subgraph]:
                    nf_body.make_value_info(
                        subgraph_i,
                        data_type=onnx.TensorProto.FLOAT,
                        shape=(batch_size, hidden_size),
                        usage=NodeFactory.ValueInfoType.input)

                nf_body.make_value_info(X_proj_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, 4 * hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)
                # subgraph nodes
                # it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
                # ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
                # ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
                # Ct = ft (.) Ct-1 + it (.) ct
                # ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
                # Ht = ot (.) h(Ct)
                prev_h_proj = nf_body.make_node('MatMul',
                                                [prev_h_subgraph, Rt])
                sum_x_proj_h_proj_bias = nf_body.make_node(
                    'Add', [X_proj_subgraph, prev_h_proj])
                split_outputs = ['split_i', 'split_o', 'split_f', 'split_c']
                nf_body.make_node('Split',
                                  sum_x_proj_h_proj_bias, {
                                      "axis": 1,
                                      "split": [hidden_size] * 4
                                  },
                                  output_names=split_outputs)
                # manually add shape inference to split outputs
                for split_o in split_outputs:
                    nf_body.make_value_info(split_o,
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                activation_f, activation_g, activation_h = activations[
                    direction_index * 3:(direction_index + 1) * 3]
                it = nf_body.make_node(activation_f, 'split_i')
                ft = nf_body.make_node(activation_f, 'split_f')
                ct = nf_body.make_node(activation_g, 'split_c')
                c_subgraph = nf_body.make_node('Add', [
                    nf_body.make_node('Mul', [ft, prev_c_subgraph]),
                    nf_body.make_node('Mul', [it, ct])
                ])
                ot = nf_body.make_node(activation_f, 'split_o')
                h_subgraph = nf_body.make_node(
                    'Mul',
                    [ot, nf_body.make_node(activation_h, c_subgraph)])

                subgraph_outputs = handle_subgraph_outputs(
                    nf_body, seq_len_subgraph, batch_size, hidden_size,
                    [(h_subgraph, prev_h_subgraph),
                     (c_subgraph, prev_c_subgraph)] +
                    ([(h_subgraph,
                       np.zeros(shape=(), dtype=np.float32))] if node.output[0]
                     else []))  # skip scan output if node.output[0] is empty

                scan = nf.make_node(
                    'Scan',
                    ([seq_len] if seq_len else []) + [init_h, init_c, X_proj],
                    {
                        'body': scan_body,
                        'scan_input_directions': [is_backward],
                        'scan_output_directions': [is_backward],
                        'num_scan_inputs': 1
                    },
                    output_names=[
                        o.name
                        for o in subgraph_outputs[(0 if seq_len else 1):]
                    ])

                scan_h_outputs.append(subgraph_outputs[1])
                scan_c_outputs.append(subgraph_outputs[2])
                if node.output[0]:
                    scan_outputs.append(subgraph_outputs[3])

        handle_final_scan_outputs(node, nf, scan_outputs,
                                  [scan_h_outputs, scan_c_outputs],
                                  num_directions)

    # remove old initializers
    nf.remove_initializer(node.input[1])
    nf.remove_initializer(node.input[2])
    if num_inputs > 3:
        nf.remove_initializer(node.input[3])
    if num_inputs > 5:
        nf.remove_initializer(node.input[5], allow_empty=True)
    if num_inputs > 6:
        nf.remove_initializer(node.input[6], allow_empty=True)
    return True