def attach_onnx_to_onnx(model1: onnx.ModelProto, model2: onnx.ModelProto, prefix2: str):
    """Chain two ONNX models: model1's graph outputs feed model2's graph inputs.

    All names internal to model2 (initializers, intermediate node inputs and
    outputs) are namespaced with ``prefix2`` to avoid collisions with model1;
    model2's graph inputs are rewired to model1's graph outputs.  model2 is
    modified in place.

    Args:
        model1: producer model; its graph inputs become the combined inputs.
        model2: consumer model; its graph outputs become the combined outputs.
        prefix2: namespace prefix applied to model2's internal tensor names.

    Returns:
        The combined, fully checked ``onnx.ModelProto``.

    Raises:
        ValueError: if model1 has no outputs, or the graph output/input
            counts of the two models do not line up.
    """
    n_bridge = len(model1.graph.output)
    # Validate up front both the output pairing (used for the pass-through
    # check below) and the input pairing (used for rewiring), instead of
    # failing later with an IndexError.
    if n_bridge < 1 or n_bridge != len(model2.graph.output):
        raise ValueError(
            'Incompatible input/output dimensions: {} != {}'.format(
                len(model1.graph.output), len(model2.graph.output)))
    if n_bridge > len(model2.graph.input):
        raise ValueError(
            'Incompatible input/output dimensions: {} != {}'.format(
                n_bridge, len(model2.graph.input)))

    for initializer in model2.graph.initializer:
        initializer.name = prefix2 + initializer.name

    # model2.graph.input[o] is fed by model1.graph.output[o].
    bridge = {
        model2.graph.input[o].name: model1.graph.output[o].name
        for o in range(n_bridge)
    }
    graph_output_names = {out.name for out in model2.graph.output}

    # Rename each node's tensors exactly once.  (The previous code looped
    # once per bridged output, which applied the prefix repeatedly to the
    # same name whenever there was more than one graph output.)
    for node in model2.graph.node:
        for j, name in enumerate(node.input):
            node.input[j] = bridge.get(name, prefix2 + name)
        for j, name in enumerate(node.output):
            # Final graph outputs keep their original names.
            if name not in graph_output_names:
                node.output[j] = prefix2 + name

    graph = onnx.GraphProto()
    graph.node.extend(model1.graph.node)
    graph.node.extend(model2.graph.node)
    graph.name = model1.graph.name + " + " + model2.graph.name
    graph.input.extend(model1.graph.input)
    graph.output.extend(model2.graph.output)
    graph.initializer.extend(model1.graph.initializer)
    graph.initializer.extend(model2.graph.initializer)
    graph.value_info.extend(model2.graph.value_info)
    if model1.graph.doc_string:
        graph.doc_string = model1.graph.doc_string
    output_model = onnx.helper.make_model(graph)
    onnx.checker.check_model(output_model, full_check=True)
    return output_model
def parse_graph(graph_text: Text) -> onnx.GraphProto:
    """Parse ``graph_text`` with the native parser into a GraphProto.

    Raises:
        ParseError: when the text cannot be parsed.
    """
    success, msg, graph_proto_str = C.parse_graph(graph_text)
    if not success:
        raise ParseError(msg)
    graph = onnx.GraphProto()
    graph.ParseFromString(graph_proto_str)
    return graph
def reset_model(self, opset_ver=None):
    """Replace the current model with a fresh empty one and reset op counts.

    Args:
        opset_ver: optional opset version for the new model; when None, the
            instance's existing opset import is reused.
    """
    if opset_ver is None:
        opset_imports = [self.opset_import]
    else:
        opset_imports = [make_opsetid("", opset_ver)]
    self.model = make_model_gen_version(onnx.GraphProto(),
                                        opset_imports=opset_imports)
    self.op_counter = collections.Counter()
def __init__(self, opset_ver=OPSET_VER):
    """Initialize with an empty model at the given opset and a checker context."""
    opset = make_opsetid("", opset_ver)
    self.opset_import = opset
    self.model = make_model_gen_version(onnx.GraphProto(),
                                        opset_imports=[opset])
    self.op_counter = collections.Counter()
    # Checker context mirrors the freshly created model's IR/opset versions.
    checker_ctx = C.CheckerContext()
    checker_ctx.ir_version = self.model.ir_version
    checker_ctx.opset_imports = {'': opset_ver}
    self.ctx = checker_ctx
def parse_graph(graph_text: str) -> onnx.GraphProto:
    """Parse a string to build a GraphProto.

    Arguments:
        graph_text (string): formatted string

    Returns:
        GraphProto
    """
    success, msg, serialized = C.parse_graph(graph_text)
    if success:
        graph = onnx.GraphProto()
        graph.ParseFromString(serialized)
        return graph
    raise ParseError(msg)
def convert_rnn_to_scan(node, out_main_graph):
    """Rewrite a vanilla 'RNN' node in out_main_graph into an equivalent 'Scan'.

    The input projection X*W^T + B is hoisted out of the recurrence; the Scan
    subgraph computes only the per-step hidden-state update.  Returns True
    when the conversion was applied.
    """
    assert node.op_type == 'RNN'
    nf = NodeFactory(out_main_graph)
    with nf.scoped_prefix(node.output[0]) as scoped_prefix:
        X = node.input[0]
        Wa = nf.get_initializer(node.input[1])
        Ra = nf.get_initializer(node.input[2])
        num_inputs = len(node.input)
        # Optional inputs per the ONNX RNN signature: B, sequence_lens, initial_h.
        Ba = nf.get_initializer(node.input[3]) if num_inputs > 3 else None
        seq_len = node.input[4] if num_inputs > 4 else None
        InitHa = node.input[5] if num_inputs > 5 else None
        direction, num_directions, activations = handle_common_attributes(
            node, ['Tanh'])
        hidden_size = NodeFactory.get_attribute(node, 'hidden_size')
        InitHa = handle_init_state(InitHa, nf, num_directions)
        batch_size, batch_node = handle_batch_size(X, nf, InitHa is None)
        if InitHa is None:
            # No initial_h supplied: share one zero state across directions.
            zero_init_state = default_init_state(X, batch_size, batch_node,
                                                 hidden_size, nf)
        scan_outputs = []
        scan_h_outputs = []
        for direction_index in range(num_directions):
            # for each direction
            # X [seq_len, batch_size, input_size]
            # W [hidden_size, input_size]
            # R [hidden_size, hidden_size]
            # B [2*hidden_size]
            # seq_len [batch_size]
            # init_h [batch_size, hidden_size]
            name_prefix = node.output[0] + '_' + str(direction_index) + '_'
            if InitHa is None:
                init_h = zero_init_state
            else:
                init_h = InitHa[direction_index]
            input_size = Wa.shape[len(Wa.shape) - 1]
            W_t = np.transpose(Wa[direction_index])  # [input_size, hidden_size]
            R_t = np.transpose(Ra[direction_index])  # [hidden_size, hidden_size]
            # NOTE(review): assumes B is always present (Ba not None) — confirm
            # callers never pass an RNN without bias.
            B = Ba[direction_index].reshape(2, hidden_size).sum(
                axis=0)  # [hidden_size]
            X_proj = nf.make_node(
                'Add', [nf.make_node('MatMul', [X, W_t]),
                        B])  #[seq_len, batch_size, hidden_size]
            if num_directions == 1:
                is_backward = 0 if direction == 'forward' else 1
            else:
                is_backward = direction_index
            scan_body = onnx.GraphProto()
            scan_body.name = name_prefix + '_subgraph'
            nf_body = NodeFactory(out_main_graph, scan_body)
            with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix:
                # subgraph inputs
                X_proj_subgraph = X_proj.name + '_subgraph'
                prev_h_subgraph = name_prefix + '_h_subgraph'
                seq_len_subgraph = declare_seq_len_in_subgraph(
                    seq_len, nf_body, X_proj.name, batch_size)
                nf_body.make_value_info(prev_h_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)
                nf_body.make_value_info(X_proj_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)
                # subgraph nodes
                # Ht = f(Xt*(W^T) + Ht-1*(R^T) + Wb + Rb)
                activation_f = activations[direction_index]
                Ht = nf_body.make_node(
                    activation_f,
                    nf_body.make_node('Add', [
                        nf_body.make_node('MatMul', [prev_h_subgraph, R_t]),
                        X_proj_subgraph
                    ]))
                subgraph_outputs = handle_subgraph_outputs(
                    nf_body, seq_len_subgraph, batch_size, hidden_size,
                    [(Ht, prev_h_subgraph)] +
                    ([(Ht, np.zeros(shape=(), dtype=np.float32))]
                     if node.output[0] else []))
                scan = nf.make_node(
                    'Scan', ([seq_len] if seq_len else []) + [init_h, X_proj],
                    {
                        'body': scan_body,
                        'scan_input_directions': [is_backward],
                        'scan_output_directions': [is_backward],
                        'num_scan_inputs': 1
                    },
                    output_names=[
                        o.name
                        for o in subgraph_outputs[(0 if seq_len else 1):]
                    ])
                scan_h_outputs.append(subgraph_outputs[1])
                if node.output[0]:
                    scan_outputs.append(subgraph_outputs[2])
        handle_final_scan_outputs(node, nf, scan_outputs, [scan_h_outputs],
                                  num_directions)
        # remove old initializers
        nf.remove_initializer(node.input[1])
        nf.remove_initializer(node.input[2])
        if num_inputs > 3:
            nf.remove_initializer(node.input[3])
        if num_inputs > 5:
            # NOTE(review): unlike the GRU/LSTM converters, no allow_empty=True
            # here — confirm initial_h is never an empty optional input name.
            nf.remove_initializer(node.input[5])
    return True
def convert_gru_to_scan(node, out_main_graph):
    """Rewrite a 'GRU' node in out_main_graph into an equivalent 'Scan'.

    The input projection X*W^T (plus the zr bias and Wbh) is hoisted out of
    the recurrence; the Scan subgraph computes the gated state update, with
    both linear_before_reset variants supported.  Returns True on success.
    """
    assert node.op_type == 'GRU'
    nf = NodeFactory(out_main_graph)
    with nf.scoped_prefix(node.output[0]) as scoped_prefix:
        X = node.input[0]
        Wa = nf.get_initializer(node.input[1])
        Ra = nf.get_initializer(node.input[2])
        num_inputs = len(node.input)
        # Optional inputs per the ONNX GRU signature: B, sequence_lens, initial_h.
        Ba = nf.get_initializer(node.input[3]) if num_inputs > 3 else None
        seq_len = node.input[4] if num_inputs > 4 else None
        InitHa = node.input[5] if num_inputs > 5 else None
        direction, num_directions, activations = handle_common_attributes(
            node, ['Sigmoid', 'Tanh'])
        hidden_size = NodeFactory.get_attribute(node, 'hidden_size')
        linear_before_reset = NodeFactory.get_attribute(
            node, 'linear_before_reset')
        InitHa = handle_init_state(InitHa, nf, num_directions)
        batch_size, batch_node = handle_batch_size(X, nf, InitHa is None)
        if InitHa is None:
            # No initial_h supplied: share one zero state across directions.
            zero_init_state = default_init_state(X, batch_size, batch_node,
                                                 hidden_size, nf)
        scan_outputs = []
        scan_h_outputs = []
        for direction_index in range(num_directions):
            # for each direction
            # X [seq_len, batch_size, input_size]
            # W [3*hidden_size, input_size]
            # R [3*hidden_size, hidden_size]
            # B [6*hidden_size]
            # seq_len [batch_size]
            # init_h [batch_size, hidden_size]
            name_prefix = node.output[0] + '_' + str(direction_index) + '_'
            if InitHa is None:
                init_h = zero_init_state
            else:
                init_h = InitHa[direction_index]
            input_size = Wa.shape[len(Wa.shape) - 1]
            W_t = np.transpose(
                Wa[direction_index])  # [input_size, 3*hidden_size]
            R_t = np.transpose(
                Ra[direction_index])  # [hidden_size, 3*hidden_size]
            Rzr_t, Rh_t = np.hsplit(R_t, [
                2 * hidden_size
            ])  # [hidden_size, 2*hidden_size] and [hidden_size, hidden_size]
            # NOTE(review): assumes B is always present (Ba not None).
            Bzr, Bh = np.hsplit(
                Ba[direction_index].reshape(2, 3 * hidden_size),
                [2 * hidden_size])
            Bzr = Bzr.sum(axis=0)  # [2*hidden_size]
            Wbh = Bh[0]
            Rbh = Bh[1]
            # Fold the zr bias and Wbh into the hoisted input projection.
            X_proj = nf.make_node(
                'Add',
                [nf.make_node('MatMul', [X, W_t]),
                 np.concatenate((Bzr, Wbh))])  #[seq_len, batch_size, 3*hidden_size]
            if num_directions == 1:
                is_backward = 0 if direction == 'forward' else 1
            else:
                is_backward = direction_index
            scan_body = onnx.GraphProto()
            scan_body.name = name_prefix + '_subgraph'
            nf_body = NodeFactory(out_main_graph, scan_body)
            with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix:
                # subgraph inputs
                X_proj_subgraph = X_proj.name + '_subgraph'
                prev_h_subgraph = name_prefix + '_h_subgraph'
                seq_len_subgraph = declare_seq_len_in_subgraph(
                    seq_len, nf_body, X_proj.name, batch_size)
                nf_body.make_value_info(prev_h_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)
                nf_body.make_value_info(X_proj_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, 3 * hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)
                # subgraph nodes
                # zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
                # rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
                # ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
                #     default, when linear_before_reset = 0
                # ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
                #     when linear_before_reset != 0
                # Ht = (1 - zt) (.) ht + zt (.) Ht-1
                split_X_outputs = ['split_Xzr', 'split_Xh']
                nf_body.make_node('Split', X_proj_subgraph, {
                    "axis": 1,
                    "split": [2 * hidden_size, hidden_size]
                },
                                  output_names=split_X_outputs)
                nf_body.make_value_info('split_Xzr',
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, 2 * hidden_size))
                nf_body.make_value_info('split_Xh',
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, hidden_size))
                activation_f, activation_g = activations[
                    direction_index * 2:(direction_index + 1) * 2]
                if linear_before_reset:
                    # Project Ht-1 by the full R first (with Rbh on the h
                    # slice), then gate the h slice by rt.
                    prev_h_proj = nf_body.make_node('Add', [
                        nf_body.make_node('MatMul', [prev_h_subgraph, R_t]),
                        np.concatenate((np.zeros(2 * hidden_size).astype(
                            np.float32), Rbh))
                    ])
                    split_prev_h_outputs = ['split_Hzr', 'split_Hh']
                    nf_body.make_node('Split', prev_h_proj, {
                        "axis": 1,
                        "split": [2 * hidden_size, hidden_size]
                    },
                                      output_names=split_prev_h_outputs)
                    nf_body.make_value_info('split_Hzr',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size,
                                                   2 * hidden_size))
                    nf_body.make_value_info('split_Hh',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    ztrt = nf_body.make_node(
                        activation_f,
                        nf_body.make_node('Add', ['split_Hzr', 'split_Xzr']))
                    split_ztrt_outputs = ['split_zt', 'split_rt']
                    nf_body.make_node('Split', ztrt, {
                        "axis": 1,
                        "split": [hidden_size, hidden_size]
                    },
                                      output_names=split_ztrt_outputs)
                    nf_body.make_value_info('split_zt',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    nf_body.make_value_info('split_rt',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    ht = nf_body.make_node(
                        activation_g,
                        nf_body.make_node('Add', [
                            nf_body.make_node('Mul',
                                              ['split_rt', 'split_Hh']),
                            'split_Xh'
                        ]))
                else:
                    # Gate Ht-1 by rt before the Rh projection.
                    ztrt = nf_body.make_node(
                        activation_f,
                        nf_body.make_node('Add', [
                            nf_body.make_node('MatMul',
                                              [prev_h_subgraph, Rzr_t]),
                            'split_Xzr'
                        ]))
                    split_ztrt_outputs = ['split_zt', 'split_rt']
                    nf_body.make_node('Split', ztrt, {
                        "axis": 1,
                        "split": [hidden_size, hidden_size]
                    },
                                      output_names=split_ztrt_outputs)
                    nf_body.make_value_info('split_zt',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    nf_body.make_value_info('split_rt',
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                    ht = nf_body.make_node(
                        activation_g,
                        nf_body.make_node('Add', [
                            nf_body.make_node('MatMul', [
                                nf_body.make_node(
                                    'Mul', [prev_h_subgraph, 'split_rt']),
                                Rh_t
                            ]), 'split_Xh'
                        ]))
                # Ht = (1 - zt) (.) ht + zt (.) Ht-1
                Ht = nf_body.make_node('Add', [
                    nf_body.make_node('Mul', [
                        nf_body.make_node('Sub', [
                            np.asarray([1]).astype(np.float32), 'split_zt'
                        ]), ht
                    ]),
                    nf_body.make_node('Mul', ['split_zt', prev_h_subgraph])
                ])
                subgraph_outputs = handle_subgraph_outputs(
                    nf_body, seq_len_subgraph, batch_size, hidden_size,
                    [(Ht, prev_h_subgraph)] +
                    ([(Ht, np.zeros(shape=(), dtype=np.float32))]
                     if node.output[0] else []))
                scan = nf.make_node(
                    'Scan', ([seq_len] if seq_len else []) + [init_h, X_proj],
                    {
                        'body': scan_body,
                        'scan_input_directions': [is_backward],
                        'scan_output_directions': [is_backward],
                        'num_scan_inputs': 1
                    },
                    output_names=[
                        o.name
                        for o in subgraph_outputs[(0 if seq_len else 1):]
                    ])
                scan_h_outputs.append(subgraph_outputs[1])
                if node.output[0]:
                    scan_outputs.append(subgraph_outputs[2])
        handle_final_scan_outputs(node, nf, scan_outputs, [scan_h_outputs],
                                  num_directions)
        # remove old initializers
        nf.remove_initializer(node.input[1])
        nf.remove_initializer(node.input[2])
        if num_inputs > 3:
            nf.remove_initializer(node.input[3])
        if num_inputs > 5:
            nf.remove_initializer(node.input[5], allow_empty=True)
    return True
def convert_lstm_to_scan(node, out_main_graph):
    """Rewrite an 'LSTM' node in out_main_graph into an equivalent 'Scan'.

    The input projection X*W^T + B is hoisted out of the recurrence; the Scan
    subgraph carries both the hidden and cell states.  Peephole weights and
    input_forget=1 are not supported (asserted).  Returns True on success.
    """
    assert node.op_type == 'LSTM'
    nf = NodeFactory(out_main_graph)
    with nf.scoped_prefix(node.output[0]) as scoped_prefix:
        X = node.input[0]
        Wa = nf.get_initializer(node.input[1])
        Ra = nf.get_initializer(node.input[2])
        num_inputs = len(node.input)
        # Optional inputs per the ONNX LSTM signature:
        # B, sequence_lens, initial_h, initial_c, P (peephole).
        Ba = nf.get_initializer(node.input[3]) if num_inputs > 3 else None
        seq_len = node.input[4] if num_inputs > 4 else None
        InitHa = node.input[5] if num_inputs > 5 else None
        InitCa = node.input[6] if num_inputs > 6 else None
        PB = node.input[7] if num_inputs > 7 else None
        # TODO: support peephole
        assert not PB
        direction, num_directions, activations = handle_common_attributes(
            node, ['Sigmoid', 'Tanh', 'Tanh'])
        hidden_size = NodeFactory.get_attribute(node, 'hidden_size')
        input_forget = NodeFactory.get_attribute(node, 'input_forget')
        # TODO: implement input_forget = 1
        assert not (input_forget != None and input_forget == 1)
        # split initializer if needed:
        is_same_init = InitHa == InitCa
        InitHa = handle_init_state(InitHa, nf, num_directions)
        if is_same_init:
            InitCa = InitHa
        else:
            InitCa = handle_init_state(InitCa, nf, num_directions)
        batch_size, batch_node = handle_batch_size(
            X, nf, InitHa is None or InitCa is None)
        scan_outputs = []
        scan_h_outputs = []
        scan_c_outputs = []
        for direction_index in range(num_directions):
            # for each direction
            # X [seq_len, batch_size, input_size]
            # W [4*hidden_size, input_size]
            # R [4*hidden_size, hidden_size]
            # B [8*hidden_size]
            # seq_len [batch_size]
            # init_h [batch_size, hidden_size]
            # init_c [batch_size, hidden_size]
            # PB [3*hidden_size]
            name_prefix = node.output[0] + '_' + str(direction_index) + '_'
            if InitHa is None:
                init_h = default_init_state(X, batch_size, batch_node,
                                            hidden_size, nf, '_H')
            else:
                init_h = InitHa[direction_index]
            if InitCa is None:
                init_c = default_init_state(X, batch_size, batch_node,
                                            hidden_size, nf, '_C')
            else:
                init_c = InitCa[direction_index]
            input_size = Wa.shape[len(Wa.shape) - 1]
            Wt = np.transpose(Wa[direction_index])
            Rt = np.transpose(Ra[direction_index])
            # NOTE(review): assumes B is always present (Ba not None).
            B = Ba[direction_index].reshape(2, 4 * hidden_size).sum(
                axis=0)  # [4*hidden_size]
            X_proj = nf.make_node(
                'MatMul', [X, Wt])  #[seq_len, batch_size, 4*hidden_size]
            X_proj = nf.make_node('Add', [X_proj, B])
            if num_directions == 1:
                is_backward = 0 if direction == 'forward' else 1
            else:
                is_backward = direction_index
            scan_body = onnx.GraphProto()
            scan_body.name = name_prefix + '_subgraph'
            nf_body = NodeFactory(out_main_graph, scan_body)
            with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix:
                # subgraph inputs
                X_proj_subgraph = X_proj.name + '_subgraph'
                prev_h_subgraph = name_prefix + '_h_subgraph'
                prev_c_subgraph = name_prefix + '_c_subgraph'
                seq_len_subgraph = declare_seq_len_in_subgraph(
                    seq_len, nf_body, X_proj.name, batch_size)
                for subgraph_i in [prev_h_subgraph, prev_c_subgraph]:
                    nf_body.make_value_info(
                        subgraph_i,
                        data_type=onnx.TensorProto.FLOAT,
                        shape=(batch_size, hidden_size),
                        usage=NodeFactory.ValueInfoType.input)
                nf_body.make_value_info(X_proj_subgraph,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, 4 * hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)
                # subgraph nodes
                # it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
                # ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
                # ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
                # Ct = ft (.) Ct-1 + it (.) ct
                # ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
                # Ht = ot (.) h(Ct)
                prev_h_proj = nf_body.make_node('MatMul',
                                                [prev_h_subgraph, Rt])
                sum_x_proj_h_proj_bias = nf_body.make_node(
                    'Add', [X_proj_subgraph, prev_h_proj])
                split_outputs = ['split_i', 'split_o', 'split_f', 'split_c']
                nf_body.make_node('Split', sum_x_proj_h_proj_bias, {
                    "axis": 1,
                    "split": [hidden_size] * 4
                },
                                  output_names=split_outputs)
                # manually add shape inference to split outputs
                for split_o in split_outputs:
                    nf_body.make_value_info(split_o,
                                            data_type=onnx.TensorProto.FLOAT,
                                            shape=(batch_size, hidden_size))
                activation_f, activation_g, activation_h = activations[
                    direction_index * 3:(direction_index + 1) * 3]
                it = nf_body.make_node(activation_f, 'split_i')
                ft = nf_body.make_node(activation_f, 'split_f')
                ct = nf_body.make_node(activation_g, 'split_c')
                c_subgraph = nf_body.make_node('Add', [
                    nf_body.make_node('Mul', [ft, prev_c_subgraph]),
                    nf_body.make_node('Mul', [it, ct])
                ])
                ot = nf_body.make_node(activation_f, 'split_o')
                h_subgraph = nf_body.make_node(
                    'Mul',
                    [ot, nf_body.make_node(activation_h, c_subgraph)])
                subgraph_outputs = handle_subgraph_outputs(
                    nf_body, seq_len_subgraph, batch_size, hidden_size,
                    [(h_subgraph, prev_h_subgraph),
                     (c_subgraph, prev_c_subgraph)] +
                    ([(h_subgraph, np.zeros(shape=(), dtype=np.float32))]
                     if node.output[0] else []))
                # skip scan output if node.output[0] is empty
                scan = nf.make_node(
                    'Scan',
                    ([seq_len] if seq_len else []) +
                    [init_h, init_c, X_proj], {
                        'body': scan_body,
                        'scan_input_directions': [is_backward],
                        'scan_output_directions': [is_backward],
                        'num_scan_inputs': 1
                    },
                    output_names=[
                        o.name
                        for o in subgraph_outputs[(0 if seq_len else 1):]
                    ])
                scan_h_outputs.append(subgraph_outputs[1])
                scan_c_outputs.append(subgraph_outputs[2])
                if node.output[0]:
                    scan_outputs.append(subgraph_outputs[3])
        handle_final_scan_outputs(node, nf, scan_outputs,
                                  [scan_h_outputs, scan_c_outputs],
                                  num_directions)
        # remove old initializers
        nf.remove_initializer(node.input[1])
        nf.remove_initializer(node.input[2])
        if num_inputs > 3:
            nf.remove_initializer(node.input[3])
        if num_inputs > 5:
            nf.remove_initializer(node.input[5], allow_empty=True)
        if num_inputs > 6:
            nf.remove_initializer(node.input[6], allow_empty=True)
    return True
def graph_def_to_onnx_graph(
    cls,
    graph_def,
    input_names=None,
    output_names=None,
    input_shapes=None,
    constants=None,
    value_info=None,
    opset_version=None,
    workspace=None,
    verbose=True,
):
    """Translate a workspace GraphDef into an ``onnx.GraphProto``.

    Args:
        graph_def: the source graph definition whose ops are translated.
        input_names / output_names: aliases for graph_def's inputs/outputs,
            positionally matched.
        input_shapes: dict or sequence of shapes overriding ``value_info``
            shapes (e.g. ``-1`` for dynamic dims in runtimes like TensorRT).
        constants: dict name -> array, added as initializers.
        value_info: dict name -> (dtype, shape) for tensors not in workspace.
        opset_version: target opset; resolved via ``cls._check_opset_version``
            when None.
        workspace: tensor workspace used to look up shapes and constant data.
        verbose: when True, print the readable graph.

    Returns:
        The assembled ``onnx.GraphProto``.

    Raises:
        ValueError: on malformed arguments or a tensor with no shape info.
    """
    input_names = [] if input_names is None else input_names
    output_names = [] if output_names is None else output_names
    constants = {} if constants is None else constants
    value_info = {} if value_info is None else value_info
    if not nest.is_sequence(input_names):
        raise ValueError('<input_names> should be a sequence.')
    if not nest.is_sequence(output_names):
        raise ValueError('<output_names> should be a sequence.')
    if not isinstance(constants, dict):
        raise ValueError('<constants> should be a dict with name -> value.')
    if not isinstance(value_info, dict):
        raise ValueError(
            '<value_info> should be a dict with name -> (dtype, shape).')
    # Determine the opset version to select exporters.
    # NOTE(review): passes None through — presumably _check_opset_version
    # returns the default opset in that case; confirm.
    if opset_version is None:
        opset_version = cls._check_opset_version(opset_version)
    # Create aliases for blobs.
    blob_aliases = {}
    for i, alias in enumerate(output_names):
        blob_aliases[graph_def.output[i]] = alias
        workspace.RegisterAlias(graph_def.output[i], alias)
        if graph_def.output[i] in value_info:
            value_info[alias] = value_info[graph_def.output[i]]
    for i, alias in enumerate(input_names):
        blob_aliases[graph_def.input[i]] = alias
        workspace.RegisterAlias(graph_def.input[i], alias)
        if graph_def.input[i] in value_info:
            value_info[alias] = value_info[graph_def.input[i]]
    # Maybe rewrite the input shapes for future development.
    # A common case is that we should fill ``-1`` for dynamic dimension
    # in the inference runtime like TensorRT.
    if input_shapes is not None:
        if isinstance(input_shapes, dict):
            for k, v in input_shapes.items():
                value_info[k] = (value_info[k][0], v)
        else:
            for k, v in zip(graph_def.input[:], input_shapes):
                value_info[k] = (value_info[k][0], v)
    # Prepare to make the graph.
    onnx_graph = onnx.GraphProto(
        name=graph_def.name if len(graph_def.name) > 0 else 'onnx-model')
    blob_shapes, blob_names = {}, {}
    # Inputs of the source graph start at version 1 for SSA renaming.
    blob_versions = collections.defaultdict(
        int, **dict((blob_aliases.get(k, k), 1)
                    for k in helper.collect_inputs(graph_def)))
    initializers, seen_initializers = [], set()
    # Build translator context.
    context = export_util.TranslatorContext(
        workspace=workspace,
        blob_names=blob_names,
        blob_shapes=blob_shapes,
        blob_versions=blob_versions,
        opset_version=opset_version,
    )
    # Add nodes.
    for op in graph_def.op:
        # Get the shape of inputs and outputs.
        for name in itertools.chain(op.input, op.output):
            impl = workspace.GetTensor(name)
            if impl is not None:
                blob_shapes[name] = impl.dims
            else:
                blob_shapes[name] = value_info[name][1]
        # Translate definition.
        nodes, const_tensors = cls._make_node(op, context)
        # Rewritten for names.
        for node in nodes:
            node.input[:] = [blob_aliases.get(e, e) for e in node.input]
            node.output[:] = [blob_aliases.get(e, e) for e in node.output]
            cls._rewrite_for_ssa(node, context)
        # Convert constant outputs if necessary.
        if None in nodes:
            const_tensors = [
                helper.from_tensor(name, workspace) for name in op.output
            ]
        else:
            onnx_graph.node.extend(nodes)
        # Merge constant tensors.
        if const_tensors is not None:
            value_info = {
                **value_info,
                **dict((e.name, (e.data_type, e.dims))
                       for e in const_tensors)
            }
            for tensor in const_tensors:
                if tensor.name not in seen_initializers:
                    initializers.append(tensor)
                    seen_initializers.add(tensor.name)
    # Add constants.
    if constants is not None:
        for k, v in constants.items():
            initializers.append(helper.from_array(v, name=k))
    # Add inputs.
    for name in helper.collect_inputs(onnx_graph):
        try:
            onnx_graph.input.extend([
                helper.make_tensor_value_info(
                    name=name,
                    elem_type=value_info[name][0],
                    shape=value_info[name][1])
            ])
        except KeyError:
            # Not described in value_info: fall back to workspace data and
            # register the tensor as an initializer-backed input.
            impl = workspace.GetTensor(name)
            if impl is not None:
                initializer = helper.from_tensor(name, workspace)
                onnx_graph.input.extend([
                    helper.make_tensor_value_info(
                        name=name,
                        elem_type=initializer.data_type,
                        shape=initializer.dims)
                ])
                if name not in seen_initializers:
                    initializers.append(initializer)
                    seen_initializers.add(initializer.name)
            else:
                raise ValueError('Info of tensor `{}` is missing, '
                                 'specify it in <value_info>.'.format(name))
    # Add initializers.
    onnx_graph.initializer.extend(initializers)
    # Add outputs.
    onnx_graph.output.extend(
        helper.make_tensor_value_info(
            name=blob_names.get(name_v2, name_v2),
            elem_type=value_info[name_v2][0],
            shape=value_info[name_v2][1])
        for name_v2 in
        [blob_aliases.get(name, name) for name in set(graph_def.output)])
    if verbose:
        print(helper.printable_graph(onnx_graph))
    return onnx_graph
def generate_gemm_scan_model(model_name, config1, config2):
    """Build and save an ONNX model whose Scan subgraph runs two Gemms.

    The subgraph computes Gemm_1 - Gemm_2; the Sub output is both the scan
    output and (when C is a state input) the next-iteration state.  The model
    is written to ``model_name``.

    Args:
        config1 / config2: dicts configuring each Gemm (keys used here:
            'M', 'N', 'C', 'Y', 'withC', 'initC'; alpha/beta/transA/transB
            are consumed by the generate_gemm_* helpers).

    Returns:
        Tuple (a1, b1, c1, a2, b2, c2) of the generated main-graph
        inputs/initializers for the two Gemms.
    """
    model = onnx.ModelProto()
    model.ir_version = onnx.IR_VERSION
    opset = model.opset_import.add()
    opset.version = 11
    # Based on the given configs, we would have a model like below:
    # Main graph, where C is an initializer and passed as the input state for the Scan:
    #  C  input_1A  input_2A
    #   \    |     /
    #    \   |    /
    #      Scan
    #        |
    #      output
    #
    # Scan's subgraph, where out_C is the output state of the Scan
    #  input_1A  B  C     input_2A  B  C
    #      \     |  /        \      |  /
    #       \    | /          \     | /
    #        Gemm_1            Gemm_2
    #            \              /
    #             \            /
    #                  Sub
    #                 /   \
    #            out_C    output
    #
    # config1 and config2 configure alpha/beta/transA/transB for Gemm_1 and
    # Gemm_2, respectively.
    scan_body = onnx.GraphProto()
    scan_body.name = 'gemm_subgraph'
    shape_c1 = [config1['M'], config1['N']]
    shape_c2 = [config2['M'], config2['N']]
    assert shape_c1 == shape_c2
    C1 = config1['C']
    C2 = config2['C']
    scan_node_inputs = []
    postfix = '_subgraph'
    states_cnt = 0
    # make sure we create state inputs first
    if config1['withC']:
        assert config1['initC']
        states_cnt = states_cnt + 1
        scan_node_inputs.append(C1)
        scan_body.input.add().CopyFrom(
            helper.make_tensor_value_info('in_' + C1 + postfix,
                                          onnx.TensorProto.FLOAT, shape_c1))
    # Share the state when both configs use the same C tensor.
    if config2['withC'] and C1 != C2:
        assert config2['initC']
        states_cnt = states_cnt + 1
        scan_node_inputs.append(C2)
        scan_body.input.add().CopyFrom(
            helper.make_tensor_value_info('in_' + C2 + postfix,
                                          onnx.TensorProto.FLOAT, shape_c2))
    added_inputs_subgraph = {}
    generate_gemm_node_subgraph(scan_body, scan_node_inputs, postfix, config1,
                                added_inputs_subgraph)
    generate_gemm_node_subgraph(scan_body, scan_node_inputs, postfix, config2,
                                added_inputs_subgraph)
    sub_output = 'sub_output' + postfix
    # create a Sub op instead of Add to break the MatMul-to-Gemm rewriting rule
    # performed by the ort optimizer
    sub_node = helper.make_node(
        'Sub', [config1['Y'] + postfix, config2['Y'] + postfix], [sub_output],
        'sub_node')
    scan_body.node.add().CopyFrom(sub_node)
    scan_node_outputs = []
    # create state outputs
    if config1['withC']:
        id_node1 = onnx.helper.make_node('Identity', [sub_output],
                                         ['out_' + C1 + postfix], 'id_node1')
        scan_body.node.add().CopyFrom(id_node1)
        scan_body.output.add().CopyFrom(
            helper.make_tensor_value_info('out_' + C1 + postfix,
                                          onnx.TensorProto.FLOAT, shape_c1))
        scan_node_outputs.append('out_' + C1)
    if config2['withC'] and C1 != C2:
        id_node2 = onnx.helper.make_node('Identity', [sub_output],
                                         ['out_' + C2 + postfix], 'id_node2')
        scan_body.node.add().CopyFrom(id_node2)
        scan_body.output.add().CopyFrom(
            helper.make_tensor_value_info('out_' + C2 + postfix,
                                          onnx.TensorProto.FLOAT, shape_c2))
        scan_node_outputs.append('out_' + C2)
    # scan subgraph output
    scan_body.output.add().CopyFrom(
        helper.make_tensor_value_info(sub_output, onnx.TensorProto.FLOAT,
                                      shape_c1))
    scan_node_outputs.append('scan_output')
    # create scan node
    inputs_cnt = len(scan_node_inputs) - states_cnt
    assert inputs_cnt > 0
    scan_node = onnx.helper.make_node('Scan',
                                      scan_node_inputs,
                                      scan_node_outputs,
                                      'scan_node',
                                      num_scan_inputs=inputs_cnt,
                                      body=scan_body)
    model.graph.node.add().CopyFrom(scan_node)
    added_inputs_initializers = {}
    # main graph inputs and initializers
    (a1, b1, c1) = generate_gemm_inputs_initializers(
        model.graph, config1, added_inputs_initializers, extend=True)
    (a2, b2, c2) = generate_gemm_inputs_initializers(
        model.graph, config2, added_inputs_initializers, extend=True)
    shape_output = ['seq', config1['M'], config1['N']]
    # main graph outputs
    model.graph.output.add().CopyFrom(
        helper.make_tensor_value_info('scan_output', onnx.TensorProto.FLOAT,
                                      shape_output))
    onnx.save(model, model_name)
    return (a1, b1, c1, a2, b2, c2)
def optimize_input_projection(input_model, output_model):
    """Load ``input_model``, fold initializer-only MatMuls, hoist the input
    projection of single-scan-input Scan subgraphs out of the Scan, and save
    the result to ``output_model``.

    The targeted pattern inside a Scan body is Concat(scan input, state)
    followed by MatMul with a constant weight: the weight is split so the
    scan-input half of the projection runs once outside the Scan while the
    state half stays inside.
    """
    in_mp = onnx.load(input_model)
    out_mp = onnx.ModelProto()
    out_mp.CopyFrom(in_mp)
    out_mp.ir_version = 5  # update ir version to avoid requirement of initializer in graph input
    out_mp.graph.ClearField('node')
    nf = NodeFactory(out_mp.graph, prefix='opt_inproj_')
    initializers = dict([(i.name, i) for i in in_mp.graph.initializer])
    # first find possible fused SVD and do constant folding on MatMul of
    # initializers
    const_matmuls = [
        n for n in in_mp.graph.node
        if n.op_type == 'MatMul' and all([i in initializers for i in n.input])
    ]
    for mm in const_matmuls:
        lhs = numpy_helper.to_array(initializers[mm.input[0]])
        rhs = numpy_helper.to_array(initializers[mm.input[1]])
        val = np.matmul(lhs, rhs)
        new_initializer = out_mp.graph.initializer.add()
        new_initializer.CopyFrom(numpy_helper.from_array(val, mm.output[0]))
        # Drop the folded operands when no other node still consumes them.
        if not [
                n for n in in_mp.graph.node
                if n != mm and mm.input[0] in n.input
        ]:
            nf.remove_initializer(mm.input[0])
        if not [
                n for n in in_mp.graph.node
                if n != mm and mm.input[1] in n.input
        ]:
            nf.remove_initializer(mm.input[1])
    initializers = dict([(i.name, i) for i in out_mp.graph.initializer])
    # remove const_matmul output from graph outputs
    new_outputs = [
        i for i in out_mp.graph.output
        if not [m for m in const_matmuls if m.output[0] == i.name]
    ]
    out_mp.graph.ClearField('output')
    out_mp.graph.output.extend(new_outputs)
    for in_n in in_mp.graph.node:
        if in_n in const_matmuls:
            continue
        optimize_scan = False
        if in_n.op_type == 'Scan':
            in_sg = NodeFactory.get_attribute(in_n, 'body')
            num_scan_inputs = NodeFactory.get_attribute(
                in_n, 'num_scan_inputs')
            # only support 1 scan input
            if num_scan_inputs == 1:
                optimize_scan = True
        # copy the node if it's not the scan node that is supported at the
        # moment
        if not optimize_scan:
            out_n = out_mp.graph.node.add()
            out_n.CopyFrom(in_n)
            continue
        scan_input_directions = NodeFactory.get_attribute(
            in_n, 'scan_input_directions')
        scan_output_directions = NodeFactory.get_attribute(
            in_n, 'scan_output_directions')
        out_sg = onnx.GraphProto()
        out_sg.CopyFrom(in_sg)
        out_sg.ClearField('node')
        nf_subgraph = NodeFactory(out_mp.graph,
                                  out_sg,
                                  prefix='opt_inproj_sg_' + in_n.name + '_')
        new_inputs = list(in_n.input)
        in_sg_inputs = [i.name for i in in_sg.input]
        replaced_matmul = None
        for in_sn in in_sg.node:
            if in_sn.op_type == 'Concat' and len(in_sn.input) == 2 and all(
                    [i in in_sg_inputs for i in in_sn.input]):
                # make sure the concat's inputs are scan input and scan state
                if NodeFactory.get_attribute(in_sn, 'axis') != len(
                        in_sg.input[-1].type.tensor_type.shape.dim) - 1:
                    continue  # must concat last dim
                matmul_node = [
                    nn for nn in in_sg.node
                    if nn.op_type == 'MatMul' and in_sn.output[0] in nn.input
                ]
                if not matmul_node:
                    continue
                replaced_matmul = matmul_node[0]
                assert replaced_matmul.input[1] in initializers
                aa = nf.get_initializer(replaced_matmul.input[1])
                input_size = in_sg.input[-1].type.tensor_type.shape.dim[
                    -1].dim_value
                if in_sg_inputs[-1] == in_sn.input[0]:
                    # Scan input comes first in the concat: its rows are the
                    # top of the weight matrix.
                    hidden_idx = 1
                    input_proj_weights, hidden_proj_weights = np.vsplit(
                        aa, [input_size])
                else:
                    hidden_idx = 0
                    # NOTE(review): vsplit splits along rows but the offset
                    # uses aa.shape[-1] (columns) — looks suspicious for
                    # non-square weights; confirm intended aa.shape[0].
                    hidden_proj_weights, input_proj_weights = np.vsplit(
                        aa, [aa.shape[-1] - input_size])
                # add matmul for input_proj outside of Scan
                input_proj = nf.make_node(
                    'MatMul', [new_inputs[-1], input_proj_weights])
                input_proj.doc_string = replaced_matmul.doc_string
                new_inputs[-1] = input_proj.name
                out_sg.input[-1].type.tensor_type.shape.dim[
                    -1].dim_value = input_proj_weights.shape[-1]
                # add matmul for hidden_proj inside Scan
                hidden_proj = nf_subgraph.make_node(
                    'MatMul', [in_sn.input[hidden_idx], hidden_proj_weights])
                hidden_proj.doc_string = replaced_matmul.doc_string
                nf_subgraph.make_node(
                    'Add', [out_sg.input[-1].name, hidden_proj],
                    output_names=replaced_matmul.output[0])
                # remove initializer of concat matmul
                if not [
                        n for n in in_mp.graph.node
                        if n != in_n and replaced_matmul.input[1] in n.input
                ]:
                    nf.remove_initializer(replaced_matmul.input[1])
            elif in_sn != replaced_matmul:
                out_sg.node.add().CopyFrom(in_sn)
        # Re-emit the Scan with the rewritten subgraph and inputs.
        scan = nf.make_node(
            'Scan', new_inputs, {
                'body': out_sg,
                'scan_input_directions': scan_input_directions,
                'scan_output_directions': scan_output_directions,
                'num_scan_inputs': num_scan_inputs
            },
            output_names=list(in_n.output))
        scan.name = in_n.name
        scan.doc_string = in_n.doc_string
    onnx.save(out_mp, output_model)
def attach_onnx_to_onnx_2(model1: onnx.ModelProto, model2: onnx.ModelProto,
                          model3: onnx.ModelProto, prefix2: str, prefix3: str):
    """Chain model2 and model3 onto model1's outputs, producing one merged model.

    model1's graph outputs are assumed to correspond positionally to the graph
    inputs of model2 and of model3 (NOTE(review): positional matching is what
    the original code did — confirm with callers). Every name inside model2 /
    model3 (initializers, node names, intermediate values, graph outputs) is
    namespaced with prefix2 / prefix3 to avoid collisions with model1; node
    inputs that read a model2/model3 graph input are rewired to the matching
    model1 output instead.

    Returns:
        The merged onnx.ModelProto (inputs from model1, outputs from model2
        and model3 concatenated), validated with the full ONNX checker.

    Raises:
        ValueError: if model1 has no outputs or its output count differs from
            the input count of model2 or model3.
    """
    if len(model1.graph.output) < 1 or (len(model1.graph.output) != len(
            model2.graph.input)) or (len(model1.graph.output) != len(
                model3.graph.input)):
        raise ValueError(
            'Incompatible input/output dimensions: {} != {} or {}'.format(
                len(model1.graph.output), len(model2.graph.input),
                len(model3.graph.input)))

    def _prefix_and_rewire(model, prefix):
        # Positional map: downstream graph-input name -> upstream (model1)
        # output name.
        input_to_upstream = {
            graph_input.name: model1.graph.output[o].name
            for o, graph_input in enumerate(model.graph.input)
        }
        # Namespace initializers and node names.
        for initializer in model.graph.initializer:
            initializer.name = prefix + initializer.name
        for node in model.graph.node:
            node.name = prefix + node.name
        # Rewire each node exactly once. (The previous implementation looped
        # over model1's outputs OUTSIDE the node loop, so with more than one
        # output every non-matching name was re-prefixed per iteration,
        # producing names like prefix2 + prefix2 + original.)
        for node in model.graph.node:
            for j, name in enumerate(node.input):
                node.input[j] = input_to_upstream.get(name, prefix + name)
            for j, name in enumerate(node.output):
                node.output[j] = prefix + name
        # Graph outputs must keep the same names as the node outputs that
        # produce them.
        for graph_output in model.graph.output:
            graph_output.name = prefix + graph_output.name

    _prefix_and_rewire(model2, prefix2)
    _prefix_and_rewire(model3, prefix3)

    graph = onnx.GraphProto()
    graph.node.extend(model1.graph.node)
    graph.node.extend(model2.graph.node)
    graph.node.extend(model3.graph.node)
    # e.g. torch-jit-export + torch-jit-export + torch-jit-export
    graph.name = model1.graph.name + " + " + model2.graph.name + " + " + model3.graph.name
    graph.input.extend(model1.graph.input)
    graph.output.extend(model2.graph.output)
    graph.output.extend(model3.graph.output)
    graph.initializer.extend(model1.graph.initializer)
    graph.initializer.extend(model2.graph.initializer)
    graph.initializer.extend(model3.graph.initializer)
    graph.value_info.extend(model2.graph.value_info)
    graph.value_info.extend(model3.graph.value_info)
    if model1.graph.doc_string:
        graph.doc_string = model1.graph.doc_string
    output_model = onnx.helper.make_model(graph,
                                          opset_imports=model1.opset_import)
    onnx.checker.check_model(output_model, full_check=True)
    return output_model
def generate_gemm_scan_model(model_name, config1, config2):
    """Build and save a test model wrapping two Gemm nodes inside a Scan.

    config1/config2 are dicts (keys used here: "M", "N", "C", "Y", "withC",
    "initC", plus whatever generate_gemm_node_subgraph /
    generate_gemm_inputs_initializers read — alpha/beta/transA/transB per the
    diagram below; confirm against those helpers).  The model is serialized to
    ``model_name`` and the main-graph input/initializer names
    (a1, b1, c1, a2, b2, c2) are returned so callers can feed the model.
    """
    model = onnx.ModelProto()
    model.ir_version = 7  # use stable onnx ir version
    opset = model.opset_import.add()
    opset.version = 11
    # Based on the given configs, we would have a model like below:
    # Main graph, where C is an initializer and passed as the input state for the Scan:
    #    C  input_1A  input_2A
    #     \    |     /
    #      \   |    /
    #         Scan
    #          |
    #        output
    #
    # Scan's subgraph, where out_C is the output state of the Scan
    #  input_1A  B  C     input_2A  B  C
    #      \     |  /        \      |  /
    #       \    | /          \     | /
    #        Gemm_1            Gemm_2
    #            \              /
    #             \            /
    #                  Sub
    #                 /   \
    #             out_C  output
    #
    # config1 and config2 configure alpha/beta/transA/transB for Gemm_1 and Gemm_2, respectively.
    scan_body = onnx.GraphProto()
    scan_body.name = "gemm_subgraph"
    # Both Gemms must produce the same [M, N] shape so Sub can combine them.
    shape_c1 = [config1["M"], config1["N"]]
    shape_c2 = [config2["M"], config2["N"]]
    assert shape_c1 == shape_c2
    C1 = config1["C"]
    C2 = config2["C"]
    scan_node_inputs = []
    # Subgraph value names get this suffix to stay distinct from main-graph names.
    postfix = "_subgraph"
    states_cnt = 0
    # make sure we create state inputs first
    # (ONNX Scan requires all state inputs to precede the scan inputs)
    if config1["withC"]:
        assert config1["initC"]
        states_cnt = states_cnt + 1
        scan_node_inputs.append(C1)
        scan_body.input.add().CopyFrom(
            helper.make_tensor_value_info("in_" + C1 + postfix,
                                          onnx.TensorProto.FLOAT, shape_c1))
    # Only add a second state when the two configs use distinct C tensors.
    if config2["withC"] and C1 != C2:
        assert config2["initC"]
        states_cnt = states_cnt + 1
        scan_node_inputs.append(C2)
        scan_body.input.add().CopyFrom(
            helper.make_tensor_value_info("in_" + C2 + postfix,
                                          onnx.TensorProto.FLOAT, shape_c2))
    added_inputs_subgraph = {}
    # Add the two Gemm nodes (and their remaining inputs) to the subgraph;
    # also appends the corresponding scan inputs to scan_node_inputs.
    generate_gemm_node_subgraph(scan_body, scan_node_inputs, postfix, config1,
                                added_inputs_subgraph)
    generate_gemm_node_subgraph(scan_body, scan_node_inputs, postfix, config2,
                                added_inputs_subgraph)
    sub_output = "sub_output" + postfix
    # create a Sub op instead of Add to break the MatMul-to-Gemm rewriting rule
    # performed by the ort optimizer
    sub_node = helper.make_node(
        "Sub",
        [config1["Y"] + postfix, config2["Y"] + postfix],
        [sub_output],
        "sub_node",
    )
    scan_body.node.add().CopyFrom(sub_node)
    scan_node_outputs = []
    # create state outputs
    # (state outputs must precede scan outputs, mirroring the input ordering;
    # each state is fed the Sub result via an Identity node)
    if config1["withC"]:
        id_node1 = onnx.helper.make_node("Identity", [sub_output],
                                         ["out_" + C1 + postfix], "id_node1")
        scan_body.node.add().CopyFrom(id_node1)
        scan_body.output.add().CopyFrom(
            helper.make_tensor_value_info("out_" + C1 + postfix,
                                          onnx.TensorProto.FLOAT, shape_c1))
        # Note: the Scan node's own output name has no subgraph postfix.
        scan_node_outputs.append("out_" + C1)
    if config2["withC"] and C1 != C2:
        id_node2 = onnx.helper.make_node("Identity", [sub_output],
                                         ["out_" + C2 + postfix], "id_node2")
        scan_body.node.add().CopyFrom(id_node2)
        scan_body.output.add().CopyFrom(
            helper.make_tensor_value_info("out_" + C2 + postfix,
                                          onnx.TensorProto.FLOAT, shape_c2))
        scan_node_outputs.append("out_" + C2)
    # scan subgraph output
    scan_body.output.add().CopyFrom(
        helper.make_tensor_value_info(sub_output, onnx.TensorProto.FLOAT,
                                      shape_c1))
    scan_node_outputs.append("scan_output")
    # create scan node
    inputs_cnt = len(scan_node_inputs) - states_cnt
    assert inputs_cnt > 0
    scan_node = onnx.helper.make_node(
        "Scan",
        scan_node_inputs,
        scan_node_outputs,
        "scan_node",
        num_scan_inputs=inputs_cnt,
        body=scan_body,
    )
    model.graph.node.add().CopyFrom(scan_node)
    added_inputs_initializers = {}
    # main graph inputs and initializers
    (a1, b1, c1) = generate_gemm_inputs_initializers(model.graph, config1,
                                                     added_inputs_initializers,
                                                     extend=True)
    (a2, b2, c2) = generate_gemm_inputs_initializers(model.graph, config2,
                                                     added_inputs_initializers,
                                                     extend=True)
    # "seq" is a symbolic sequence-length dimension iterated by the Scan.
    shape_output = ["seq", config1["M"], config1["N"]]
    # main graph outputs
    model.graph.output.add().CopyFrom(
        helper.make_tensor_value_info("scan_output", onnx.TensorProto.FLOAT,
                                      shape_output))
    onnx.save(model, model_name)
    return (a1, b1, c1, a2, b2, c2)
def _duplicate_dq_nodes_with_multiple_consumers(graph: onnx.GraphProto, **kwargs):
    """Duplicate DequantizeLinear (DQ) nodes so each has exactly one consumer.

    This keeps QDQ node groups clean: no DQ node is shared between groups.

    Expected kwargs:
        updated_graphs: list to which a graph is appended when it was modified.
        node_to_consumers: mapping from a node to the nodes consuming its output.
        validate_updates: when True, perform no edits — only verify that no
            multi-consumer DQ nodes remain.

    Raises:
        IndexError: if a DQ output is consumed from a subgraph (not supported).
        ValueError: if validation finds remaining multi-consumer DQ nodes.
    """
    updated_graphs = kwargs["updated_graphs"]
    node_to_consumers = kwargs["node_to_consumers"]
    validate_updates = kwargs["validate_updates"]

    nodes_to_update = []

    for node in filter(lambda node: node.op_type == "DequantizeLinear", graph.node):
        # node providing graph output won't have consumer nodes
        consumers = node_to_consumers.get(node, [])
        if len(consumers) > 1:
            if not all(consumer in graph.node for consumer in consumers):
                # TODO: If this does ever occur, as long as it's only consumed in one subgraph we could leave that
                # value as is (no need to handle recursing into the subgraph) and update the consumers in this
                # graph only
                raise IndexError(
                    "DequantizeLinear node output is consumed by a subgraph. "
                    "This is not currently supported."
                )

            nodes_to_update.append(node)

    if validate_updates:
        if nodes_to_update:
            # internal error. we somehow missed an update in the first pass when validate_updates was false
            raise ValueError("Graph still has DequantizeLinear nodes with multiple consumers.")

        return

    if nodes_to_update:
        dup_idx = 0
        new_graph = onnx.GraphProto()
        graph_outputs = {output.name for output in graph.output}
        for node in graph.node:
            new_graph.node.append(node)
            if node in nodes_to_update:
                is_graph_output = node.output[0] in graph_outputs

                # create duplicate DQ nodes as needed so that there is one consumer per node.
                # this allows us to cleanly create a QDQ node group with no DQ nodes shared with other QDQ node groups.
                # if the node produces a graph output we need a duplicate DQ node for every consumer node.
                # if not, we can leave the first consumer as is and create duplicate nodes for the other consumers.
                start_idx = 0 if is_graph_output else 1
                consumers = list(node_to_consumers[node])[start_idx:]

                for idx, consumer in enumerate(consumers):
                    # create duplicate DQ node
                    duplicate = onnx.NodeProto()
                    duplicate.CopyFrom(node)
                    # update node name for debugging. use the global dup idx for node duplication
                    duplicate.name += f"/qdq_utils_dup_{dup_idx}"
                    # update output. use the local idx for value duplication
                    orig_output = node.output[0]
                    new_output = f"{orig_output}/qdq_utils_dup_{idx}"
                    duplicate.output[0] = new_output
                    # update input on the consumer node.
                    for input_idx, input_name in enumerate(consumer.input):
                        if input_name == orig_output:
                            consumer.input[input_idx] = new_output

                    new_graph.node.append(duplicate)
                    dup_idx += 1

        # replace nodes
        del graph.node[:]
        graph.node.extend(new_graph.node)
        updated_graphs.append(graph)