Code example #1
0
File: export_utils.py  Project: vinayphadnis/NeMo
def attach_onnx_to_onnx(model1: onnx.ModelProto, model2: onnx.ModelProto, prefix2: str):
    """Stitch two ONNX models together, feeding model1's outputs into model2's inputs.

    model2's internal names (initializers, node inputs/outputs) are renamed
    with ``prefix2`` to avoid collisions with model1; the node inputs of
    model2 that consume graph inputs are rewired to the corresponding
    model1 graph outputs, matched positionally.

    Args:
        model1: upstream model; its graph inputs become the combined inputs.
        model2: downstream model; its graph outputs become the combined outputs.
        prefix2: prefix applied to model2's internal names.

    Returns:
        A validated onnx.ModelProto holding the combined graph.

    Raises:
        ValueError: if model1's output count does not match model2's input count.
    """
    # Fix: model1's outputs are consumed positionally by model2's *inputs*
    # (see the rewiring below), so validate against model2.graph.input —
    # the original code incorrectly compared against model2.graph.output.
    if len(model1.graph.output) < 1 or len(model1.graph.output) != len(model2.graph.input):
        raise ValueError(
            'Incompatible input/output dimensions: {} != {}'.format(
                len(model1.graph.output), len(model2.graph.input))
        )
    for initializer in model2.graph.initializer:
        initializer.name = prefix2 + initializer.name

    # Positional wiring: model2.graph.input[o] is driven by model1.graph.output[o].
    io_map = {
        model2.graph.input[o].name: model1.graph.output[o].name
        for o in range(len(model1.graph.output))
    }
    graph2_output_names = {o.name for o in model2.graph.output}

    # Rename each node exactly once.  The original nested loop iterated the
    # renaming once per model1 output, re-applying prefix2 repeatedly and
    # corrupting names whenever there was more than one output.
    for node in model2.graph.node:
        for j in range(len(node.input)):
            if node.input[j] in io_map:
                node.input[j] = io_map[node.input[j]]
            else:
                node.input[j] = prefix2 + node.input[j]
        for j in range(len(node.output)):
            if node.output[j] not in graph2_output_names:
                node.output[j] = prefix2 + node.output[j]

    graph = onnx.GraphProto()
    graph.node.extend(model1.graph.node)
    graph.node.extend(model2.graph.node)
    graph.name = model1.graph.name + " + " + model2.graph.name
    graph.input.extend(model1.graph.input)
    graph.output.extend(model2.graph.output)
    graph.initializer.extend(model1.graph.initializer)
    graph.initializer.extend(model2.graph.initializer)
    graph.value_info.extend(model2.graph.value_info)
    if model1.graph.doc_string:
        graph.doc_string = model1.graph.doc_string
    output_model = onnx.helper.make_model(graph)
    onnx.checker.check_model(output_model, full_check=True)
    return output_model
Code example #2
0
File: parser.py  Project: kaydoh/onnx
def parse_graph(graph_text: Text) -> onnx.GraphProto:
    """Parse a textual graph representation into an onnx.GraphProto.

    Raises ParseError with the parser's message when the text is invalid.
    """
    success, msg, graph_proto_str = C.parse_graph(graph_text)
    if not success:
        raise ParseError(msg)
    parsed = onnx.GraphProto()
    parsed.ParseFromString(graph_proto_str)
    return parsed
Code example #3
0
File: __init__.py  Project: fumihwh/onnx-pytorch
 def reset_model(self, opset_ver=None):
   """Discard the current model and start over from an empty GraphProto.

   Args:
     opset_ver: optional opset version for the fresh model; when omitted,
       the opset recorded at construction time is reused.
   """
   if opset_ver is None:
     opset_imports = [self.opset_import]
   else:
     opset_imports = [make_opsetid("", opset_ver)]
   self.model = make_model_gen_version(
       onnx.GraphProto(), opset_imports=opset_imports)
   self.op_counter = collections.Counter()
Code example #4
0
File: __init__.py  Project: fumihwh/onnx-pytorch
 def __init__(self, opset_ver=OPSET_VER):
   """Build an empty ONNX model plus a checker context for ``opset_ver``."""
   self.opset_import = make_opsetid("", opset_ver)
   self.model = make_model_gen_version(
       onnx.GraphProto(), opset_imports=[self.opset_import])
   # Per-op counters start empty.
   self.op_counter = collections.Counter()
   # The checker context mirrors the freshly created model's IR version
   # and the requested opset.
   self.ctx = C.CheckerContext()
   self.ctx.ir_version = self.model.ir_version
   self.ctx.opset_imports = {'': opset_ver}
Code example #5
0
def parse_graph(graph_text: str) -> onnx.GraphProto:
    """Parse a string to build a GraphProto.

    Arguments:
        graph_text (string): formatted string
    Returns:
        GraphProto
    """
    success, msg, graph_proto_str = C.parse_graph(graph_text)
    if not success:
        raise ParseError(msg)
    result = onnx.GraphProto()
    result.ParseFromString(graph_proto_str)
    return result
Code example #6
0
def convert_rnn_to_scan(node, out_main_graph):
    """Rewrite an ONNX 'RNN' node into an equivalent 'Scan' op.

    The input projection Xt*(W^T) + b is hoisted out of the recurrence and
    computed once over the whole sequence; the per-step update
    Ht = f(X_proj_t + Ht-1*(R^T)) runs inside a Scan subgraph, one Scan per
    direction.  The RNN weight/bias initializers are removed at the end.

    Arguments:
        node: NodeProto with op_type == 'RNN', owned by ``out_main_graph``.
        out_main_graph: GraphProto rewritten in place via NodeFactory.
    Returns:
        True on success.
    """
    assert node.op_type == 'RNN'
    nf = NodeFactory(out_main_graph)
    with nf.scoped_prefix(node.output[0]) as scoped_prefix:
        X = node.input[0]
        Wa = nf.get_initializer(node.input[1])
        Ra = nf.get_initializer(node.input[2])
        num_inputs = len(node.input)
        # Optional RNN inputs: B (bias), sequence_lens, initial_h.
        Ba = nf.get_initializer(node.input[3]) if num_inputs > 3 else None
        seq_len = node.input[4] if num_inputs > 4 else None
        InitHa = node.input[5] if num_inputs > 5 else None

        direction, num_directions, activations = handle_common_attributes(
            node, ['Tanh'])

        hidden_size = NodeFactory.get_attribute(node, 'hidden_size')

        InitHa = handle_init_state(InitHa, nf, num_directions)

        batch_size, batch_node = handle_batch_size(X, nf, InitHa is None)
        if InitHa is None:
            zero_init_state = default_init_state(X, batch_size, batch_node,
                                                 hidden_size, nf)

        scan_outputs = []
        scan_h_outputs = []
        for direction_index in range(num_directions):
            # for each direction
            # X [seq_len, batch_size, input_size]
            # W [hidden_size, input_size]
            # R [hidden_size, hidden_size]
            # B [2*hidden_size]
            # seq_len [batch_size]
            # init_h [batch_size, hidden_size]

            name_prefix = node.output[0] + '_' + str(direction_index) + '_'

            if InitHa is None:
                init_h = zero_init_state
            else:
                init_h = InitHa[direction_index]

            input_size = Wa.shape[len(Wa.shape) - 1]
            W_t = np.transpose(
                Wa[direction_index])  # [input_size, hidden_size]
            R_t = np.transpose(
                Ra[direction_index])  # [hidden_size, hidden_size]
            # Wb and Rb are summed into a single bias vector.
            B = Ba[direction_index].reshape(2, hidden_size).sum(
                axis=0)  # [hidden_size]
            X_proj = nf.make_node('Add', [nf.make_node(
                'MatMul', [X, W_t]), B])  #[seq_len, batch_size, hidden_size]
            if num_directions == 1:
                is_backward = 0 if direction == 'forward' else 1
            else:
                is_backward = direction_index

            # Per-step subgraph executed by the Scan op.
            scan_body = onnx.GraphProto()
            scan_body.name = name_prefix + '_subgraph'

            nf_body = NodeFactory(out_main_graph, scan_body)
            with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix:
                # subgraph inputs
                X_proj_subgraph = X_proj.name + '_subgraph'
                prev_h_subgraph = name_prefix + '_h_subgraph'

                seq_len_subgraph = declare_seq_len_in_subgraph(
                    seq_len, nf_body, X_proj.name, batch_size)

                nf_body.make_value_info(prev_h_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)

                nf_body.make_value_info(X_proj_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)
                # subgraph nodes
                # Ht = f(Xt*(W^T) + Ht-1*(R^T) + Wb + Rb)

                activation_f = activations[direction_index]
                Ht = nf_body.make_node(
                    activation_f,
                    nf_body.make_node('Add', [
                        nf_body.make_node('MatMul', [prev_h_subgraph, R_t]),
                        X_proj_subgraph
                    ]))

                subgraph_outputs = handle_subgraph_outputs(
                    nf_body, seq_len_subgraph, batch_size, hidden_size,
                    [(Ht, prev_h_subgraph)] + ([
                        (Ht, np.zeros(shape=(), dtype=np.float32))
                    ] if node.output[0] else []))

                scan = nf.make_node(
                    'Scan', ([seq_len] if seq_len else []) + [init_h, X_proj],
                    {
                        'body': scan_body,
                        'scan_input_directions': [is_backward],
                        'scan_output_directions': [is_backward],
                        'num_scan_inputs': 1
                    },
                    output_names=[
                        o.name
                        for o in subgraph_outputs[(0 if seq_len else 1):]
                    ])

                scan_h_outputs.append(subgraph_outputs[1])
                if node.output[0]:
                    scan_outputs.append(subgraph_outputs[2])

        handle_final_scan_outputs(node, nf, scan_outputs, [scan_h_outputs],
                                  num_directions)

    # remove old initializers
    nf.remove_initializer(node.input[1])
    nf.remove_initializer(node.input[2])
    if num_inputs > 3:
        nf.remove_initializer(node.input[3])
    if num_inputs > 5:
        nf.remove_initializer(node.input[5])
    return True
Code example #7
0
def convert_gru_to_scan(node, out_main_graph):
    """Rewrite an ONNX 'GRU' node into an equivalent 'Scan' op.

    The input projection is hoisted out of the recurrence; the per-step
    gate computations (zt, rt, ht, Ht — see the in-body comments) run in a
    Scan subgraph, one Scan per direction.  Both the default and the
    ``linear_before_reset`` formulations are supported.  The GRU
    weight/bias initializers are removed at the end.

    Arguments:
        node: NodeProto with op_type == 'GRU', owned by ``out_main_graph``.
        out_main_graph: GraphProto rewritten in place via NodeFactory.
    Returns:
        True on success.
    """
    assert node.op_type == 'GRU'
    nf = NodeFactory(out_main_graph)
    with nf.scoped_prefix(node.output[0]) as scoped_prefix:
        X = node.input[0]
        Wa = nf.get_initializer(node.input[1])
        Ra = nf.get_initializer(node.input[2])
        num_inputs = len(node.input)
        # Optional GRU inputs: B (bias), sequence_lens, initial_h.
        Ba = nf.get_initializer(node.input[3]) if num_inputs > 3 else None
        seq_len = node.input[4] if num_inputs > 4 else None
        InitHa = node.input[5] if num_inputs > 5 else None

        direction, num_directions, activations = handle_common_attributes(
            node, ['Sigmoid', 'Tanh'])

        hidden_size = NodeFactory.get_attribute(node, 'hidden_size')
        linear_before_reset = NodeFactory.get_attribute(
            node, 'linear_before_reset')
        InitHa = handle_init_state(InitHa, nf, num_directions)

        batch_size, batch_node = handle_batch_size(X, nf, InitHa is None)
        if InitHa is None:
            zero_init_state = default_init_state(X, batch_size, batch_node,
                                                 hidden_size, nf)

        scan_outputs = []
        scan_h_outputs = []
        for direction_index in range(num_directions):
            # for each direction
            # X [seq_len, batch_size, input_size]
            # W [3*hidden_size, input_size]
            # R [3*hidden_size, hidden_size]
            # B [6*hidden_size]
            # seq_len [batch_size]
            # init_h [batch_size, hidden_size]

            name_prefix = node.output[0] + '_' + str(direction_index) + '_'

            if InitHa is None:
                init_h = zero_init_state
            else:
                init_h = InitHa[direction_index]

            input_size = Wa.shape[len(Wa.shape) - 1]
            W_t = np.transpose(
                Wa[direction_index])  # [input_size, 3*hidden_size]
            R_t = np.transpose(
                Ra[direction_index])  # [hidden_size, 3*hidden_size]
            Rzr_t, Rh_t = np.hsplit(R_t, [
                2 * hidden_size
            ])  # [hidden_size, 2*hidden_size] and [hidden_size, hidden_size]
            # Split the bias into the z/r part and the h part; Wbh and Rbh
            # must stay separate for the linear_before_reset formulation.
            Bzr, Bh = np.hsplit(
                Ba[direction_index].reshape(2, 3 * hidden_size),
                [2 * hidden_size])
            Bzr = Bzr.sum(axis=0)  # [2*hidden_size]
            Wbh = Bh[0]
            Rbh = Bh[1]
            X_proj = nf.make_node(
                'Add',
                [nf.make_node('MatMul', [X, W_t]),
                 np.concatenate(
                     (Bzr, Wbh))])  #[seq_len, batch_size, 3*hidden_size]
            if num_directions == 1:
                is_backward = 0 if direction == 'forward' else 1
            else:
                is_backward = direction_index

            # Per-step subgraph executed by the Scan op.
            scan_body = onnx.GraphProto()
            scan_body.name = name_prefix + '_subgraph'

            nf_body = NodeFactory(out_main_graph, scan_body)
            with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix:
                # subgraph inputs
                X_proj_subgraph = X_proj.name + '_subgraph'
                prev_h_subgraph = name_prefix + '_h_subgraph'

                seq_len_subgraph = declare_seq_len_in_subgraph(
                    seq_len, nf_body, X_proj.name, batch_size)

                nf_body.make_value_info(prev_h_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)

                nf_body.make_value_info(X_proj_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, 3 * hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)

                # subgraph nodes
                # zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
                # rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
                # ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # default, when linear_before_reset = 0
                # ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset != 0
                # Ht = (1 - zt) (.) ht + zt (.) Ht-1

                split_X_outputs = ['split_Xzr', 'split_Xh']
                nf_body.make_node('Split',
                                  X_proj_subgraph, {
                                      "axis": 1,
                                      "split": [2 * hidden_size, hidden_size]
                                  },
                                  output_names=split_X_outputs)
                nf_body.make_value_info('split_Xzr',
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, 2 * hidden_size))
                nf_body.make_value_info('split_Xh',
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, hidden_size))

                activation_f, activation_g = activations[direction_index *
                                                         2:(direction_index +
                                                            1) * 2]

                if linear_before_reset:
                    # Project Ht-1 by the full R first, then apply the reset
                    # gate to the h-slice of the projection.
                    prev_h_proj = nf_body.make_node('Add', [
                        nf_body.make_node('MatMul', [prev_h_subgraph, R_t]),
                        np.concatenate((np.zeros(2 * hidden_size).astype(
                            np.float32), Rbh))
                    ])
                    split_prev_h_outputs = ['split_Hzr', 'split_Hh']
                    nf_body.make_node(
                        'Split',
                        prev_h_proj, {
                            "axis": 1,
                            "split": [2 * hidden_size, hidden_size]
                        },
                        output_names=split_prev_h_outputs)
                    nf_body.make_value_info('split_Hzr',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size,
                                                   2 * hidden_size))
                    nf_body.make_value_info('split_Hh',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    ztrt = nf_body.make_node(
                        activation_f,
                        nf_body.make_node('Add', ['split_Hzr', 'split_Xzr']))
                    split_ztrt_outputs = ['split_zt', 'split_rt']
                    nf_body.make_node('Split',
                                      ztrt, {
                                          "axis": 1,
                                          "split": [hidden_size, hidden_size]
                                      },
                                      output_names=split_ztrt_outputs)
                    nf_body.make_value_info('split_zt',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    nf_body.make_value_info('split_rt',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    ht = nf_body.make_node(
                        activation_g,
                        nf_body.make_node('Add', [
                            nf_body.make_node('Mul', ['split_rt', 'split_Hh']),
                            'split_Xh'
                        ]))
                else:
                    # Default formulation: apply the reset gate to Ht-1
                    # before projecting by Rh.
                    ztrt = nf_body.make_node(
                        activation_f,
                        nf_body.make_node('Add', [
                            nf_body.make_node('MatMul',
                                              [prev_h_subgraph, Rzr_t]),
                            'split_Xzr'
                        ]))
                    split_ztrt_outputs = ['split_zt', 'split_rt']
                    nf_body.make_node('Split',
                                      ztrt, {
                                          "axis": 1,
                                          "split": [hidden_size, hidden_size]
                                      },
                                      output_names=split_ztrt_outputs)
                    nf_body.make_value_info('split_zt',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    nf_body.make_value_info('split_rt',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    ht = nf_body.make_node(
                        activation_g,
                        nf_body.make_node('Add', [
                            nf_body.make_node('MatMul', [
                                nf_body.make_node(
                                    'Mul', [prev_h_subgraph, 'split_rt']), Rh_t
                            ]), 'split_Xh'
                        ]))

                Ht = nf_body.make_node('Add', [
                    nf_body.make_node('Mul', [
                        nf_body.make_node('Sub', [
                            np.asarray([1]).astype(np.float32), 'split_zt'
                        ]), ht
                    ]),
                    nf_body.make_node('Mul', ['split_zt', prev_h_subgraph])
                ])

                subgraph_outputs = handle_subgraph_outputs(
                    nf_body, seq_len_subgraph, batch_size, hidden_size,
                    [(Ht, prev_h_subgraph)] + ([
                        (Ht, np.zeros(shape=(), dtype=np.float32))
                    ] if node.output[0] else []))

                scan = nf.make_node(
                    'Scan', ([seq_len] if seq_len else []) + [init_h, X_proj],
                    {
                        'body': scan_body,
                        'scan_input_directions': [is_backward],
                        'scan_output_directions': [is_backward],
                        'num_scan_inputs': 1
                    },
                    output_names=[
                        o.name
                        for o in subgraph_outputs[(0 if seq_len else 1):]
                    ])

                scan_h_outputs.append(subgraph_outputs[1])
                if node.output[0]:
                    scan_outputs.append(subgraph_outputs[2])

        handle_final_scan_outputs(node, nf, scan_outputs, [scan_h_outputs],
                                  num_directions)

    # remove old initializers
    nf.remove_initializer(node.input[1])
    nf.remove_initializer(node.input[2])
    if num_inputs > 3:
        nf.remove_initializer(node.input[3])
    if num_inputs > 5:
        nf.remove_initializer(node.input[5], allow_empty=True)
    return True
Code example #8
0
def convert_lstm_to_scan(node, out_main_graph):
    """Rewrite an ONNX 'LSTM' node into an equivalent 'Scan' op.

    The input projection is hoisted out of the recurrence; the per-step
    gate computations (it, ft, ct, Ct, ot, Ht — see the in-body comments)
    run in a Scan subgraph with two loop states (h and c), one Scan per
    direction.  Peepholes and input_forget=1 are not supported (asserted).
    The LSTM weight/bias initializers are removed at the end.

    Arguments:
        node: NodeProto with op_type == 'LSTM', owned by ``out_main_graph``.
        out_main_graph: GraphProto rewritten in place via NodeFactory.
    Returns:
        True on success.
    """
    assert node.op_type == 'LSTM'
    nf = NodeFactory(out_main_graph)
    with nf.scoped_prefix(node.output[0]) as scoped_prefix:
        X = node.input[0]
        Wa = nf.get_initializer(node.input[1])
        Ra = nf.get_initializer(node.input[2])
        num_inputs = len(node.input)
        # Optional LSTM inputs: B, sequence_lens, initial_h, initial_c, P.
        Ba = nf.get_initializer(node.input[3]) if num_inputs > 3 else None
        seq_len = node.input[4] if num_inputs > 4 else None
        InitHa = node.input[5] if num_inputs > 5 else None
        InitCa = node.input[6] if num_inputs > 6 else None
        PB = node.input[7] if num_inputs > 7 else None

        # TODO: support peephole
        assert not PB

        direction, num_directions, activations = handle_common_attributes(
            node, ['Sigmoid', 'Tanh', 'Tanh'])

        hidden_size = NodeFactory.get_attribute(node, 'hidden_size')
        input_forget = NodeFactory.get_attribute(node, 'input_forget')

        # TODO: implement input_forget = 1
        assert not (input_forget != None and input_forget == 1)

        # split initializer if needed:
        is_same_init = InitHa == InitCa
        InitHa = handle_init_state(InitHa, nf, num_directions)
        if is_same_init:
            InitCa = InitHa
        else:
            InitCa = handle_init_state(InitCa, nf, num_directions)

        batch_size, batch_node = handle_batch_size(
            X, nf, InitHa is None or InitCa is None)

        scan_outputs = []
        scan_h_outputs = []
        scan_c_outputs = []
        for direction_index in range(num_directions):
            # for each direction
            # X [seq_len, batch_size, input_size]
            # W [4*hidden_size, input_size]
            # R [4*hidden_size, hidden_size]
            # B [8*hidden_size]
            # seq_len [batch_size]
            # init_h [batch_size, hidden_size]
            # init_c [batch_size, hidden_size]
            # PB [3*hidden_size]

            name_prefix = node.output[0] + '_' + str(direction_index) + '_'

            if InitHa is None:
                init_h = default_init_state(X, batch_size, batch_node,
                                            hidden_size, nf, '_H')
            else:
                init_h = InitHa[direction_index]

            if InitCa is None:
                init_c = default_init_state(X, batch_size, batch_node,
                                            hidden_size, nf, '_C')
            else:
                init_c = InitCa[direction_index]

            input_size = Wa.shape[len(Wa.shape) - 1]
            Wt = np.transpose(Wa[direction_index])
            Rt = np.transpose(Ra[direction_index])
            # Wb and Rb are summed into a single bias vector.
            B = Ba[direction_index].reshape(2, 4 * hidden_size).sum(
                axis=0)  # [4*hidden_size]
            X_proj = nf.make_node(
                'MatMul', [X, Wt])  #[seq_len, batch_size, 4*hidden_size]
            X_proj = nf.make_node('Add', [X_proj, B])
            if num_directions == 1:
                is_backward = 0 if direction == 'forward' else 1
            else:
                is_backward = direction_index

            # Per-step subgraph executed by the Scan op.
            scan_body = onnx.GraphProto()
            scan_body.name = name_prefix + '_subgraph'

            nf_body = NodeFactory(out_main_graph, scan_body)
            with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix:
                # subgraph inputs
                X_proj_subgraph = X_proj.name + '_subgraph'
                prev_h_subgraph = name_prefix + '_h_subgraph'
                prev_c_subgraph = name_prefix + '_c_subgraph'

                seq_len_subgraph = declare_seq_len_in_subgraph(
                    seq_len, nf_body, X_proj.name, batch_size)

                for subgraph_i in [prev_h_subgraph, prev_c_subgraph]:
                    nf_body.make_value_info(
                        subgraph_i,
                        data_type=onnx.TensorProto.FLOAT,
                        shape=(batch_size, hidden_size),
                        usage=NodeFactory.ValueInfoType.input)

                nf_body.make_value_info(X_proj_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, 4 * hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)
                # subgraph nodes
                # it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
                # ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
                # ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
                # Ct = ft (.) Ct-1 + it (.) ct
                # ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
                # Ht = ot (.) h(Ct)
                prev_h_proj = nf_body.make_node('MatMul',
                                                [prev_h_subgraph, Rt])
                sum_x_proj_h_proj_bias = nf_body.make_node(
                    'Add', [X_proj_subgraph, prev_h_proj])
                # NOTE(review): gate layout here is i, o, f, c — confirm this
                # matches the W/R initializer layout produced upstream.
                split_outputs = ['split_i', 'split_o', 'split_f', 'split_c']
                nf_body.make_node('Split',
                                  sum_x_proj_h_proj_bias, {
                                      "axis": 1,
                                      "split": [hidden_size] * 4
                                  },
                                  output_names=split_outputs)
                # manually add shape inference to split outputs
                for split_o in split_outputs:
                    nf_body.make_value_info(split_o,
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                activation_f, activation_g, activation_h = activations[
                    direction_index * 3:(direction_index + 1) * 3]
                it = nf_body.make_node(activation_f, 'split_i')
                ft = nf_body.make_node(activation_f, 'split_f')
                ct = nf_body.make_node(activation_g, 'split_c')
                c_subgraph = nf_body.make_node('Add', [
                    nf_body.make_node('Mul', [ft, prev_c_subgraph]),
                    nf_body.make_node('Mul', [it, ct])
                ])
                ot = nf_body.make_node(activation_f, 'split_o')
                h_subgraph = nf_body.make_node(
                    'Mul',
                    [ot, nf_body.make_node(activation_h, c_subgraph)])

                subgraph_outputs = handle_subgraph_outputs(
                    nf_body, seq_len_subgraph, batch_size, hidden_size,
                    [(h_subgraph, prev_h_subgraph),
                     (c_subgraph, prev_c_subgraph)] +
                    ([(h_subgraph,
                       np.zeros(shape=(), dtype=np.float32))] if node.output[0]
                     else []))  # skip scan output if node.output[0] is empty

                scan = nf.make_node(
                    'Scan',
                    ([seq_len] if seq_len else []) + [init_h, init_c, X_proj],
                    {
                        'body': scan_body,
                        'scan_input_directions': [is_backward],
                        'scan_output_directions': [is_backward],
                        'num_scan_inputs': 1
                    },
                    output_names=[
                        o.name
                        for o in subgraph_outputs[(0 if seq_len else 1):]
                    ])

                scan_h_outputs.append(subgraph_outputs[1])
                scan_c_outputs.append(subgraph_outputs[2])
                if node.output[0]:
                    scan_outputs.append(subgraph_outputs[3])

        handle_final_scan_outputs(node, nf, scan_outputs,
                                  [scan_h_outputs, scan_c_outputs],
                                  num_directions)

    # remove old initializers
    nf.remove_initializer(node.input[1])
    nf.remove_initializer(node.input[2])
    if num_inputs > 3:
        nf.remove_initializer(node.input[3])
    if num_inputs > 5:
        nf.remove_initializer(node.input[5], allow_empty=True)
    if num_inputs > 6:
        nf.remove_initializer(node.input[6], allow_empty=True)
    return True
Code example #9
0
File: native.py  Project: ORG-MARS/dragon
    def graph_def_to_onnx_graph(
        cls,
        graph_def,
        input_names=None,
        output_names=None,
        input_shapes=None,
        constants=None,
        value_info=None,
        opset_version=None,
        workspace=None,
        verbose=True,
    ):
        """Translate a framework GraphDef into an onnx.GraphProto.

        Args:
            graph_def: the source graph definition whose ops are translated.
            input_names: optional aliases for graph_def's inputs (positional).
            output_names: optional aliases for graph_def's outputs (positional).
            input_shapes: dict or sequence of shapes used to overwrite the
                shape part of ``value_info`` (e.g. to mark dynamic dims).
            constants: dict of name -> array added as extra initializers.
            value_info: dict of name -> (dtype, shape) for tensors whose
                info cannot be pulled from the workspace.
            opset_version: target opset; resolved via cls._check_opset_version
                when None.
            workspace: the workspace used to look up tensor values/shapes and
                to register aliases.
            verbose: when True, print a readable dump of the result.

        Returns:
            The constructed onnx.GraphProto.

        Raises:
            ValueError: on malformed arguments or when a tensor's info is
                neither in the workspace nor in ``value_info``.
        """
        input_names = [] if input_names is None else input_names
        output_names = [] if output_names is None else output_names
        constants = {} if constants is None else constants
        value_info = {} if value_info is None else value_info

        if not nest.is_sequence(input_names):
            raise ValueError('<input_names> should be a sequence.')
        if not nest.is_sequence(output_names):
            raise ValueError('<output_names> should be a sequence.')
        if not isinstance(constants, dict):
            raise ValueError('<constants> should be a dict with name -> value.')
        if not isinstance(value_info, dict):
            raise ValueError('<value_info> should be a dict with name -> (dtype, shape).')

        # Determine the opset version to select exporters.
        # NOTE(review): opset_version is None inside this branch, so
        # _check_opset_version presumably returns a default — confirm.
        if opset_version is None:
            opset_version = cls._check_opset_version(opset_version)

        # Create aliases for blobs.
        blob_aliases = {}
        for i, alias in enumerate(output_names):
            blob_aliases[graph_def.output[i]] = alias
            workspace.RegisterAlias(graph_def.output[i], alias)
            if graph_def.output[i] in value_info:
                value_info[alias] = value_info[graph_def.output[i]]
        for i, alias in enumerate(input_names):
            blob_aliases[graph_def.input[i]] = alias
            workspace.RegisterAlias(graph_def.input[i], alias)
            if graph_def.input[i] in value_info:
                value_info[alias] = value_info[graph_def.input[i]]

        # Maybe rewrite the input shapes for future development.
        # A common case is that we should fill ``-1`` for dynamic dimension
        # in the inference runtime like TensorRT.
        if input_shapes is not None:
            if isinstance(input_shapes, dict):
                for k, v in input_shapes.items():
                    value_info[k] = (value_info[k][0], v)
            else:
                for k, v in zip(graph_def.input[:], input_shapes):
                    value_info[k] = (value_info[k][0], v)

        # Prepare to make the graph.
        onnx_graph = onnx.GraphProto(name=graph_def.name
                                     if len(graph_def.name) > 0
                                     else 'onnx-model')
        blob_shapes, blob_names = {}, {}
        # Seed SSA versions: every existing graph input starts at version 1.
        blob_versions = collections.defaultdict(
            int, **dict((blob_aliases.get(k, k), 1)
                        for k in helper.collect_inputs(graph_def)))
        initializers, seen_initializers = [], set()

        # Build translator context.
        context = export_util.TranslatorContext(
            workspace=workspace,
            blob_names=blob_names,
            blob_shapes=blob_shapes,
            blob_versions=blob_versions,
            opset_version=opset_version,
        )

        # Add nodes.
        for op in graph_def.op:
            # Get the shape of inputs and outputs.
            for name in itertools.chain(op.input, op.output):
                impl = workspace.GetTensor(name)
                if impl is not None:
                    blob_shapes[name] = impl.dims
                else:
                    blob_shapes[name] = value_info[name][1]

            # Translate definition.
            nodes, const_tensors = cls._make_node(op, context)

            # Rewritten for names.
            for node in nodes:
                node.input[:] = [blob_aliases.get(e, e) for e in node.input]
                node.output[:] = [blob_aliases.get(e, e) for e in node.output]
                cls._rewrite_for_ssa(node, context)

            # Convert constant outputs if necessary.
            # A ``None`` entry signals that the op folded to constants.
            if None in nodes:
                const_tensors = [helper.from_tensor(name, workspace)
                                 for name in op.output]
            else:
                onnx_graph.node.extend(nodes)

            # Merge constant tensors.
            if const_tensors is not None:
                value_info = {**value_info,
                              **dict((e.name, (e.data_type, e.dims))
                                     for e in const_tensors)}
                for tensor in const_tensors:
                    if tensor.name not in seen_initializers:
                        initializers.append(tensor)
                        seen_initializers.add(tensor.name)

        # Add constants.
        if constants is not None:
            for k, v in constants.items():
                initializers.append(helper.from_array(v, name=k))

        # Add inputs.
        # EAFP: prefer caller-supplied value_info, fall back to the workspace
        # tensor (which then also becomes an initializer).
        for name in helper.collect_inputs(onnx_graph):
            try:
                onnx_graph.input.extend([
                    helper.make_tensor_value_info(
                        name=name,
                        elem_type=value_info[name][0],
                        shape=value_info[name][1])])
            except KeyError:
                impl = workspace.GetTensor(name)
                if impl is not None:
                    initializer = helper.from_tensor(name, workspace)
                    onnx_graph.input.extend([
                        helper.make_tensor_value_info(
                            name=name,
                            elem_type=initializer.data_type,
                            shape=initializer.dims)])
                    if name not in seen_initializers:
                        initializers.append(initializer)
                        seen_initializers.add(initializer.name)
                else:
                    raise ValueError(
                        'Info of tensor `{}` is missing, '
                        'specify it in <value_info>.'.format(name))

        # Add initializers.
        onnx_graph.initializer.extend(initializers)

        # Add outputs.
        onnx_graph.output.extend(
            helper.make_tensor_value_info(
                name=blob_names.get(name_v2, name_v2),
                elem_type=value_info[name_v2][0],
                shape=value_info[name_v2][1])
            for name_v2 in [blob_aliases.get(name, name)
                            for name in set(graph_def.output)])

        if verbose:
            print(helper.printable_graph(onnx_graph))

        return onnx_graph
Code example #10
0
def generate_gemm_scan_model(model_name, config1, config2):
    """Build and save an ONNX model whose main graph is a single Scan node.

    The Scan body runs two Gemm nodes (configured by config1/config2:
    alpha/beta/transA/transB plus shapes 'M'/'N' and tensor names) and
    combines them with a Sub node; the Sub result feeds both the state
    output(s) and the per-iteration scan output.

    Returns the (A, B, C) names produced by
    generate_gemm_inputs_initializers for each config, as
    (a1, b1, c1, a2, b2, c2).
    """
    model = onnx.ModelProto()
    model.ir_version = onnx.IR_VERSION
    opset_entry = model.opset_import.add()
    opset_entry.version = 11

    # Based on the given configs, we would have a model like below:
    # Main graph, where C is an initializer and passed as the input state for the Scan:
    #      C  input_1A input_2A
    #       \     |    /
    #        \    |   /
    #           Scan
    #             |
    #           output
    #
    # Scan's subgraph, where out_C is the output state of the Scan
    # input_1A  B  C  input_2A B  C
    #     \     | /       \    | /
    #      \    |/         \   |/
    #      Gemm_1           Gemm_2
    #           \          /
    #            \        /
    #               Sub
    #              /   \
    #           out_C  output
    #
    # config1 and config2 configure alpha/beta/transA/transB for Gemm_1 and Gemm_2, respectively.

    scan_body = onnx.GraphProto()
    scan_body.name = 'gemm_subgraph'

    postfix = '_subgraph'
    state_shape = [config1['M'], config1['N']]
    # Both Gemms must agree on the state shape.
    assert state_shape == [config2['M'], config2['N']]
    c1_name = config1['C']
    c2_name = config2['C']

    def _float_info(name, shape):
        # Shorthand: float tensor value_info with the given name/shape.
        return helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT,
                                             shape)

    scan_node_inputs = []
    states_cnt = 0
    # State inputs have to be declared before the scanned inputs, so add
    # them first. The second state only exists when the two configs name
    # different C tensors.
    for cfg, c_name, is_distinct in ((config1, c1_name, True),
                                     (config2, c2_name,
                                      c1_name != c2_name)):
        if cfg['withC'] and is_distinct:
            assert cfg['initC']
            states_cnt += 1
            scan_node_inputs.append(c_name)
            scan_body.input.add().CopyFrom(
                _float_info('in_' + c_name + postfix, state_shape))

    added_inputs_subgraph = {}
    generate_gemm_node_subgraph(scan_body, scan_node_inputs, postfix, config1,
                                added_inputs_subgraph)
    generate_gemm_node_subgraph(scan_body, scan_node_inputs, postfix, config2,
                                added_inputs_subgraph)

    sub_output = 'sub_output' + postfix
    # create a Sub op instead of Add to break the MatMul-to-Gemm rewriting rule
    # performed by the ort optimizer
    scan_body.node.add().CopyFrom(
        helper.make_node('Sub',
                         [config1['Y'] + postfix, config2['Y'] + postfix],
                         [sub_output], 'sub_node'))

    scan_node_outputs = []
    # Mirror the state inputs with Identity-fed state outputs.
    for cfg, c_name, is_distinct, id_name in ((config1, c1_name, True,
                                               'id_node1'),
                                              (config2, c2_name,
                                               c1_name != c2_name,
                                               'id_node2')):
        if cfg['withC'] and is_distinct:
            scan_body.node.add().CopyFrom(
                onnx.helper.make_node('Identity', [sub_output],
                                      ['out_' + c_name + postfix], id_name))
            scan_body.output.add().CopyFrom(
                _float_info('out_' + c_name + postfix, state_shape))
            scan_node_outputs.append('out_' + c_name)

    # The per-iteration scan output follows the state outputs.
    scan_body.output.add().CopyFrom(_float_info(sub_output, state_shape))
    scan_node_outputs.append('scan_output')

    # Scan must consume at least one scanned (non-state) input.
    inputs_cnt = len(scan_node_inputs) - states_cnt
    assert inputs_cnt > 0

    model.graph.node.add().CopyFrom(
        onnx.helper.make_node('Scan',
                              scan_node_inputs,
                              scan_node_outputs,
                              'scan_node',
                              num_scan_inputs=inputs_cnt,
                              body=scan_body))

    added_inputs_initializers = {}
    # Main-graph inputs and initializers for both Gemm configs.
    (a1, b1, c1) = generate_gemm_inputs_initializers(model.graph,
                                                     config1,
                                                     added_inputs_initializers,
                                                     extend=True)
    (a2, b2, c2) = generate_gemm_inputs_initializers(model.graph,
                                                     config2,
                                                     added_inputs_initializers,
                                                     extend=True)

    # The scan output stacks one [M, N] slice per iteration -> ['seq', M, N].
    model.graph.output.add().CopyFrom(
        helper.make_tensor_value_info('scan_output', onnx.TensorProto.FLOAT,
                                      ['seq', config1['M'], config1['N']]))
    onnx.save(model, model_name)
    return (a1, b1, c1, a2, b2, c2)
コード例 #11
0
def optimize_input_projection(input_model, output_model):
    """Hoist the input-projection MatMul of single-scan-input Scans out of the Scan.

    Loads `input_model`, constant-folds MatMul nodes whose inputs are all
    initializers, then for each Scan with num_scan_inputs == 1 whose body
    contains Concat(scan state, scan input) feeding a MatMul with an
    initializer weight, splits the weight so the scan-input projection runs
    once outside the Scan while only the hidden-state projection stays
    inside. Saves the rewritten model to `output_model`.
    """
    in_mp = onnx.load(input_model)
    out_mp = onnx.ModelProto()
    out_mp.CopyFrom(in_mp)
    out_mp.ir_version = 5  # update ir version to avoid requirement of initializer in graph input
    # Nodes are rebuilt selectively below; everything else is kept from in_mp.
    out_mp.graph.ClearField('node')
    nf = NodeFactory(out_mp.graph, prefix='opt_inproj_')
    initializers = dict([(i.name, i) for i in in_mp.graph.initializer])
    # first find possible fused SVD and do constant folding on MatMul of initializers
    const_matmuls = [
        n for n in in_mp.graph.node
        if n.op_type == 'MatMul' and all([i in initializers for i in n.input])
    ]
    for mm in const_matmuls:
        lhs = numpy_helper.to_array(initializers[mm.input[0]])
        rhs = numpy_helper.to_array(initializers[mm.input[1]])
        val = np.matmul(lhs, rhs)
        # The folded product becomes an initializer under the MatMul's output name.
        new_initializer = out_mp.graph.initializer.add()
        new_initializer.CopyFrom(numpy_helper.from_array(val, mm.output[0]))
        # Drop each input initializer if no other node still consumes it.
        if not [
                n
                for n in in_mp.graph.node if n != mm and mm.input[0] in n.input
        ]:
            nf.remove_initializer(mm.input[0])
        if not [
                n
                for n in in_mp.graph.node if n != mm and mm.input[1] in n.input
        ]:
            nf.remove_initializer(mm.input[1])

    # Refresh the lookup: the folding above added and removed initializers.
    initializers = dict([(i.name, i) for i in out_mp.graph.initializer])

    # remove const_matmul output from graph outputs
    new_outputs = [
        i for i in out_mp.graph.output
        if not [m for m in const_matmuls if m.output[0] == i.name]
    ]
    out_mp.graph.ClearField('output')
    out_mp.graph.output.extend(new_outputs)

    for in_n in in_mp.graph.node:
        # Constant-folded MatMuls were already replaced by initializers.
        if in_n in const_matmuls:
            continue

        optimize_scan = False
        if in_n.op_type == 'Scan':
            in_sg = NodeFactory.get_attribute(in_n, 'body')
            num_scan_inputs = NodeFactory.get_attribute(
                in_n, 'num_scan_inputs')
            # only support 1 scan input
            if num_scan_inputs == 1:
                optimize_scan = True

        # copy the node if it's not the scan node that is supported at the moment
        if not optimize_scan:
            out_n = out_mp.graph.node.add()
            out_n.CopyFrom(in_n)
            continue

        scan_input_directions = NodeFactory.get_attribute(
            in_n, 'scan_input_directions')
        scan_output_directions = NodeFactory.get_attribute(
            in_n, 'scan_output_directions')
        # Rebuild the Scan body without the Concat+MatMul that gets split.
        out_sg = onnx.GraphProto()
        out_sg.CopyFrom(in_sg)
        out_sg.ClearField('node')
        nf_subgraph = NodeFactory(out_mp.graph,
                                  out_sg,
                                  prefix='opt_inproj_sg_' + in_n.name + '_')
        new_inputs = list(in_n.input)
        in_sg_inputs = [i.name for i in in_sg.input]
        replaced_matmul = None
        for in_sn in in_sg.node:
            if in_sn.op_type == 'Concat' and len(in_sn.input) == 2 and all(
                [i in in_sg_inputs for i in in_sn.input]):
                # make sure the concat's inputs are scan input and scan state
                if NodeFactory.get_attribute(in_sn, 'axis') != len(
                        in_sg.input[-1].type.tensor_type.shape.dim) - 1:
                    continue  # must concat last dim
                matmul_node = [
                    nn for nn in in_sg.node
                    if nn.op_type == 'MatMul' and in_sn.output[0] in nn.input
                ]
                if not matmul_node:
                    continue
                replaced_matmul = matmul_node[0]
                assert replaced_matmul.input[1] in initializers
                aa = nf.get_initializer(replaced_matmul.input[1])
                # Width of the scan input (last subgraph input) along the
                # concatenated axis.
                input_size = in_sg.input[-1].type.tensor_type.shape.dim[
                    -1].dim_value
                # Split the stacked weights by which Concat operand is the
                # scan input vs. the hidden state.
                if in_sg_inputs[-1] == in_sn.input[0]:
                    hidden_idx = 1
                    input_proj_weights, hidden_proj_weights = np.vsplit(
                        aa, [input_size])
                else:
                    hidden_idx = 0
                    hidden_proj_weights, input_proj_weights = np.vsplit(
                        aa, [aa.shape[-1] - input_size])
                # add matmul for input_proj outside of Scan
                # NOTE(review): NodeFactory.make_node appears to accept a
                # numpy array as a node input (materializing it as an
                # initializer) and a bare string for output_names — confirm
                # against the NodeFactory implementation.
                input_proj = nf.make_node('MatMul',
                                          [new_inputs[-1], input_proj_weights])
                input_proj.doc_string = replaced_matmul.doc_string
                # The Scan now consumes the pre-projected input instead.
                new_inputs[-1] = input_proj.name
                out_sg.input[-1].type.tensor_type.shape.dim[
                    -1].dim_value = input_proj_weights.shape[-1]
                # add matmul for hidden_proj inside Scan
                hidden_proj = nf_subgraph.make_node(
                    'MatMul', [in_sn.input[hidden_idx], hidden_proj_weights])
                hidden_proj.doc_string = replaced_matmul.doc_string
                # Sum of the two projections replaces the original MatMul output.
                nf_subgraph.make_node('Add',
                                      [out_sg.input[-1].name, hidden_proj],
                                      output_names=replaced_matmul.output[0])
                # remove initializer of concat matmul
                if not [
                        n for n in in_mp.graph.node
                        if n != in_n and replaced_matmul.input[1] in n.input
                ]:
                    nf.remove_initializer(replaced_matmul.input[1])
            elif in_sn != replaced_matmul:
                out_sg.node.add().CopyFrom(in_sn)

        # Recreate the Scan with the updated body/inputs, preserving the
        # original outputs, name, doc_string and direction attributes.
        scan = nf.make_node(
            'Scan',
            new_inputs, {
                'body': out_sg,
                'scan_input_directions': scan_input_directions,
                'scan_output_directions': scan_output_directions,
                'num_scan_inputs': num_scan_inputs
            },
            output_names=list(in_n.output))
        scan.name = in_n.name
        scan.doc_string = in_n.doc_string

    onnx.save(out_mp, output_model)
コード例 #12
0
ファイル: export_utils.py プロジェクト: gvskalyan/NeMo
def attach_onnx_to_onnx_2(model1: onnx.ModelProto, model2: onnx.ModelProto,
                          model3: onnx.ModelProto, prefix2: str, prefix3: str):
    """Stitch three ONNX models into one: model1 feeding model2 and model3.

    model1's graph outputs are wired positionally to the graph inputs of
    model2 and of model3 (output[o] -> input[o] of each). Every other name
    in model2/model3 (initializers, node names, intermediate values, graph
    outputs) is namespaced with prefix2/prefix3 to avoid collisions.

    Fix over the previous version: the old nested ``for o ... for i/j``
    loops applied the prefix once per model1 output, so any model with more
    than one output got names like ``prefixprefixfoo``. Each name is now
    rewritten exactly once.

    Returns the combined, checker-validated ModelProto.
    Raises ValueError when the input/output counts do not line up.
    """
    if len(model1.graph.output) < 1 or (len(model1.graph.output) != len(
            model2.graph.input)) or (len(model1.graph.output) != len(
                model3.graph.input)):
        raise ValueError(
            'Incompatible input/output dimensions: {} != {} or {}'.format(
                len(model1.graph.output), len(model2.graph.input),
                len(model3.graph.input)))

    def _rewire(model: onnx.ModelProto, prefix: str) -> None:
        # Namespace initializers and node names.
        for initializer in model.graph.initializer:
            initializer.name = prefix + initializer.name
        for node in model.graph.node:
            node.name = prefix + node.name
        # Positional map: this graph's input name -> model1's output name.
        # Built before any value renaming so it matches the original names.
        input_map = {
            model.graph.input[o].name: model1.graph.output[o].name
            for o in range(len(model1.graph.output))
        }
        # Namespace graph outputs so model2/model3 outputs cannot collide
        # in the merged graph.
        for output in model.graph.output:
            output.name = prefix + output.name
        # Rewrite node edges exactly once per name: graph inputs become
        # model1 outputs, everything else gets the prefix.
        for node in model.graph.node:
            for j in range(len(node.input)):
                node.input[j] = input_map.get(node.input[j],
                                              prefix + node.input[j])
            for j in range(len(node.output)):
                node.output[j] = prefix + node.output[j]
        # NOTE(review): value_info entries are intentionally left with
        # their un-prefixed names, matching the original behavior.

    _rewire(model2, prefix2)
    _rewire(model3, prefix3)

    graph = onnx.GraphProto()
    graph.node.extend(model1.graph.node)
    graph.node.extend(model2.graph.node)
    graph.node.extend(model3.graph.node)
    graph.name = model1.graph.name + " + " + model2.graph.name + " + " + model3.graph.name
    graph.input.extend(model1.graph.input)
    graph.output.extend(model2.graph.output)
    graph.output.extend(model3.graph.output)
    graph.initializer.extend(model1.graph.initializer)
    graph.initializer.extend(model2.graph.initializer)
    graph.initializer.extend(model3.graph.initializer)
    graph.value_info.extend(model2.graph.value_info)
    graph.value_info.extend(model3.graph.value_info)
    if model1.graph.doc_string:
        graph.doc_string = model1.graph.doc_string
    output_model = onnx.helper.make_model(graph,
                                          opset_imports=model1.opset_import)
    # full_check also runs shape inference over the stitched graph.
    onnx.checker.check_model(output_model, full_check=True)
    return output_model
コード例 #13
0
def generate_gemm_scan_model(model_name, config1, config2):
    """Build and save an ONNX model whose main graph is a single Scan node.

    The Scan body runs two Gemm nodes (configured by config1/config2:
    alpha/beta/transA/transB plus shapes "M"/"N" and tensor names) and
    combines them with a Sub node; the Sub result feeds both the state
    output(s) and the per-iteration scan output.

    Returns the (A, B, C) names produced by
    generate_gemm_inputs_initializers for each config, as
    (a1, b1, c1, a2, b2, c2).
    """
    model = onnx.ModelProto()
    model.ir_version = 7  # use stable onnx ir version
    opset_entry = model.opset_import.add()
    opset_entry.version = 11

    # Based on the given configs, we would have a model like below:
    # Main graph, where C is an initializer and passed as the input state for the Scan:
    #      C  input_1A input_2A
    #       \     |    /
    #        \    |   /
    #           Scan
    #             |
    #           output
    #
    # Scan's subgraph, where out_C is the output state of the Scan
    # input_1A  B  C  input_2A B  C
    #     \     | /       \    | /
    #      \    |/         \   |/
    #      Gemm_1           Gemm_2
    #           \          /
    #            \        /
    #               Sub
    #              /   \
    #           out_C  output
    #
    # config1 and config2 configure alpha/beta/transA/transB for Gemm_1 and Gemm_2, respectively.

    scan_body = onnx.GraphProto()
    scan_body.name = "gemm_subgraph"

    postfix = "_subgraph"
    state_shape = [config1["M"], config1["N"]]
    # Both Gemms must agree on the state shape.
    assert state_shape == [config2["M"], config2["N"]]
    c1_name = config1["C"]
    c2_name = config2["C"]

    def _float_info(name, shape):
        # Shorthand: float tensor value_info with the given name/shape.
        return helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT,
                                             shape)

    scan_node_inputs = []
    states_cnt = 0
    # State inputs have to be declared before the scanned inputs, so add
    # them first. The second state only exists when the two configs name
    # different C tensors.
    for cfg, c_name, is_distinct in ((config1, c1_name, True),
                                     (config2, c2_name,
                                      c1_name != c2_name)):
        if cfg["withC"] and is_distinct:
            assert cfg["initC"]
            states_cnt += 1
            scan_node_inputs.append(c_name)
            scan_body.input.add().CopyFrom(
                _float_info("in_" + c_name + postfix, state_shape))

    added_inputs_subgraph = {}
    generate_gemm_node_subgraph(scan_body, scan_node_inputs, postfix, config1,
                                added_inputs_subgraph)
    generate_gemm_node_subgraph(scan_body, scan_node_inputs, postfix, config2,
                                added_inputs_subgraph)

    sub_output = "sub_output" + postfix
    # create a Sub op instead of Add to break the MatMul-to-Gemm rewriting rule
    # performed by the ort optimizer
    scan_body.node.add().CopyFrom(
        helper.make_node("Sub",
                         [config1["Y"] + postfix, config2["Y"] + postfix],
                         [sub_output], "sub_node"))

    scan_node_outputs = []
    # Mirror the state inputs with Identity-fed state outputs.
    for cfg, c_name, is_distinct, id_name in ((config1, c1_name, True,
                                               "id_node1"),
                                              (config2, c2_name,
                                               c1_name != c2_name,
                                               "id_node2")):
        if cfg["withC"] and is_distinct:
            scan_body.node.add().CopyFrom(
                onnx.helper.make_node("Identity", [sub_output],
                                      ["out_" + c_name + postfix], id_name))
            scan_body.output.add().CopyFrom(
                _float_info("out_" + c_name + postfix, state_shape))
            scan_node_outputs.append("out_" + c_name)

    # The per-iteration scan output follows the state outputs.
    scan_body.output.add().CopyFrom(_float_info(sub_output, state_shape))
    scan_node_outputs.append("scan_output")

    # Scan must consume at least one scanned (non-state) input.
    inputs_cnt = len(scan_node_inputs) - states_cnt
    assert inputs_cnt > 0

    model.graph.node.add().CopyFrom(
        onnx.helper.make_node("Scan",
                              scan_node_inputs,
                              scan_node_outputs,
                              "scan_node",
                              num_scan_inputs=inputs_cnt,
                              body=scan_body))

    added_inputs_initializers = {}
    # Main-graph inputs and initializers for both Gemm configs.
    (a1, b1, c1) = generate_gemm_inputs_initializers(model.graph,
                                                     config1,
                                                     added_inputs_initializers,
                                                     extend=True)
    (a2, b2, c2) = generate_gemm_inputs_initializers(model.graph,
                                                     config2,
                                                     added_inputs_initializers,
                                                     extend=True)

    # The scan output stacks one [M, N] slice per iteration -> ["seq", M, N].
    model.graph.output.add().CopyFrom(
        helper.make_tensor_value_info("scan_output", onnx.TensorProto.FLOAT,
                                      ["seq", config1["M"], config1["N"]]))
    onnx.save(model, model_name)
    return (a1, b1, c1, a2, b2, c2)
コード例 #14
0
def _duplicate_dq_nodes_with_multiple_consumers(graph: onnx.GraphProto, **kwargs):
    """Duplicate DequantizeLinear (DQ) nodes that feed more than one consumer.

    Ensures each DQ node has exactly one consumer so QDQ node groups never
    share a DQ node.

    Expected kwargs:
        updated_graphs: list; this graph is appended to it when modified.
        node_to_consumers: dict mapping a node to the nodes consuming its outputs.
        validate_updates: when True, only verify no multi-consumer DQ nodes remain.

    Raises:
        IndexError: if a multi-consumer DQ output is consumed by a subgraph.
        ValueError: if validation finds remaining multi-consumer DQ nodes.
    """
    updated_graphs = kwargs["updated_graphs"]
    node_to_consumers = kwargs["node_to_consumers"]
    validate_updates = kwargs["validate_updates"]

    nodes_to_update = []
    for node in filter(lambda node: node.op_type == "DequantizeLinear", graph.node):
        # node providing graph output won't have consumer nodes
        consumers = node_to_consumers.get(node, [])
        if len(consumers) > 1:
            if not all(consumer in graph.node for consumer in consumers):
                # TODO: If this does ever occur, as long as it's only consumed in one subgraph we could leave that
                # value as is (no need to handle recursing into the subgraph) and update the consumers in this
                # graph only
                raise IndexError(
                    "DequantizeLinear node output is consumed by a subgraph. " "This is not currently supported."
                )

            nodes_to_update.append(node)

    if validate_updates:
        if nodes_to_update:
            # internal error. we somehow missed an update in the first pass when validate_updates was false
            raise ValueError("Graph still has DequantizeLinear nodes with multiple consumers.")

        return

    if nodes_to_update:
        dup_idx = 0
        new_graph = onnx.GraphProto()
        graph_outputs = {output.name for output in graph.output}
        for node in graph.node:
            new_graph.node.append(node)
            if node in nodes_to_update:
                is_graph_output = node.output[0] in graph_outputs
                # create duplicate DQ nodes as needed so that there is one consumer per node.
                # this allows us to cleanly create a QDQ node group with no DQ nodes shared with other QDQ node groups.
                # if the node produces a graph output we need a duplicate DQ node for every consumer node.
                # if not, we can leave the first consumer as is and create duplicate nodes for the other consumers.
                start_idx = 0 if is_graph_output else 1
                consumers = list(node_to_consumers[node])[start_idx:]

                for idx, consumer in enumerate(consumers):
                    # create duplicate DQ node
                    duplicate = onnx.NodeProto()
                    duplicate.CopyFrom(node)
                    # update node name for debugging. use the global dup idx for node duplication
                    duplicate.name += f"/qdq_utils_dup_{dup_idx}"

                    # update output. use the local idx for value duplication
                    orig_output = node.output[0]
                    new_output = f"{orig_output}/qdq_utils_dup_{idx}"
                    duplicate.output[0] = new_output

                    # update input on the consumer node.
                    for input_idx, input_name in enumerate(consumer.input):
                        if input_name == orig_output:
                            consumer.input[input_idx] = new_output

                    new_graph.node.append(duplicate)
                    dup_idx += 1

        # replace nodes in place, keeping the same GraphProto object
        del graph.node[:]
        graph.node.extend(new_graph.node)
        updated_graphs.append(graph)