def create_dummy_inputs(onnx_model_path, batch_size, sequence_length, samples):
    from onnx import TensorProto
    from onnx_model import OnnxModel

    onnx_model = OnnxModel(onnx.load(onnx_model_path))
    dummy_inputs = {}
    for input in onnx_model.get_graph_inputs_excluding_initializers():
        shape = get_shape_from_type_proto(input.type)
        symbol_dims = []
        for i, dim in enumerate(shape):
            if isinstance(dim, str):
                symbol_dims.append(i)

        # allowed symbolic dimensions: batch_size and sequence_length
        if len(symbol_dims) > 2:
            return None

        if len(symbol_dims) > 0:
            shape[symbol_dims[0]] = batch_size
        if len(symbol_dims) > 1:
            shape[symbol_dims[1]] = sequence_length

        elem_type = input.type.tensor_type.elem_type
        assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
        data_type = numpy.float32 if elem_type == TensorProto.FLOAT else (
            numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32)
        data = numpy.ones(shape, dtype=data_type)
        dummy_inputs[input.name] = data

    all_inputs = [dummy_inputs for _ in range(samples)]
    return all_inputs
def create_gpt2_inputs(onnx_model_path, batch_size, sequence_length, past_sequence_length, samples):
    from onnx import TensorProto
    from onnx_model import OnnxModel

    onnx_model = OnnxModel(onnx.load(onnx_model_path))

    # The symbolic names shall be the same as those used in the Gpt2Helper.export_onnx(...) function.
    symbols = {
        'batch_size': batch_size,
        'seq_len': sequence_length,
        'past_seq_len': past_sequence_length,
        'total_seq_len': sequence_length + past_sequence_length
    }

    dummy_inputs = {}
    for input in onnx_model.get_graph_inputs_excluding_initializers():
        shape = get_shape_from_type_proto(input.type)
        # Only map symbolic (string) dimensions; concrete integer dimensions are kept as-is.
        for i, dim in enumerate(shape):
            if isinstance(dim, str):
                if dim not in symbols:
                    raise RuntimeError(f"symbol is not supported: {dim}")
                shape[i] = symbols[dim]

        elem_type = input.type.tensor_type.elem_type
        assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
        data_type = numpy.float32 if elem_type == TensorProto.FLOAT else (
            numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32)
        data = numpy.ones(shape, dtype=data_type)
        dummy_inputs[input.name] = data

    all_inputs = [dummy_inputs for _ in range(samples)]
    return all_inputs
def create_longformer_inputs(onnx_model_path, batch_size, sequence_length, global_length, samples):
    from onnx import TensorProto
    from onnx_model import OnnxModel

    onnx_model = OnnxModel(onnx.load(onnx_model_path))
    symbols = {'batch_size': batch_size, 'sequence_length': sequence_length}

    dummy_inputs = {}
    for input in onnx_model.get_graph_inputs_excluding_initializers():
        shape = get_shape_from_type_proto(input.type)
        # Only map symbolic (string) dimensions; concrete integer dimensions are kept as-is.
        for i, dim in enumerate(shape):
            if isinstance(dim, str):
                if dim not in symbols:
                    raise RuntimeError(f"symbol is not supported: {dim}")
                shape[i] = symbols[dim]

        elem_type = input.type.tensor_type.elem_type
        assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
        data_type = numpy.float32 if elem_type == TensorProto.FLOAT else (
            numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32)
        if "global" in input.name:
            data = numpy.zeros(shape, dtype=data_type)
            data[:, :global_length] = 1
        else:
            data = numpy.ones(shape, dtype=data_type)
        dummy_inputs[input.name] = data

    all_inputs = [dummy_inputs for _ in range(samples)]
    return all_inputs
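def _example_run_dummy_inputs(session, onnx_model_path):
    """Usage sketch (illustrative; not part of the original tools): feed the dummy
    inputs created above into an onnxruntime InferenceSession. The session and
    model path are assumed to be supplied by the caller, and the model is assumed
    to use the GPT-2 symbolic dimension names (batch_size, seq_len, ...)."""
    all_inputs = create_gpt2_inputs(onnx_model_path,
                                    batch_size=1,
                                    sequence_length=8,
                                    past_sequence_length=0,
                                    samples=10)
    # Each element is a feed dict keyed by graph input name.
    for inputs in all_inputs:
        session.run(None, inputs)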
def verify_fusion(self, optimized_model, expected_model_filename):
    optimized_model.topological_sort()

    expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models",
                                       expected_model_filename)
    expected_model = OnnxModel(onnx.load(expected_model_path))
    expected_model.topological_sort()

    self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph))
def get_last_matmul_node_name(raw_onnx_model: str):
    model = onnx.load(raw_onnx_model)
    onnx_model = OnnxModel(model)
    output_name_to_node = onnx_model.output_name_to_node()
    assert model.graph.output[0].name in output_name_to_node

    node = output_name_to_node[model.graph.output[0].name]
    if node.op_type == "MatMul":
        logger.info(f"Found last MatMul node for logits: {node.name}")
        return node.name

    logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
    return None
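def _example_keep_logits_matmul_fp32(onnx_model_path):
    """Usage sketch (illustrative; path is hypothetical): the node name found above
    is typically placed on node_block_list so the logits MatMul stays in float32
    during float16 conversion, mirroring how auto_mixed_precision below passes
    node_block_list to OnnxModel.convert_float_to_float16."""
    last_matmul = get_last_matmul_node_name(onnx_model_path)
    model = OnnxModel(onnx.load(onnx_model_path))
    model.convert_float_to_float16(node_block_list=[last_matmul] if last_matmul else [])
    return model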
def get_bert_inputs(
    onnx_file: str,
    input_ids_name: Optional[str] = None,
    segment_ids_name: Optional[str] = None,
    input_mask_name: Optional[str] = None,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """Find graph inputs for BERT model.
    First, we will deduce inputs from EmbedLayerNormalization node.
    If not found, we will guess the meaning of graph inputs based on naming.

    Args:
        onnx_file (str): onnx model path
        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.

    Returns:
        Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids, segment_ids and input_mask
    """
    model = ModelProto()
    with open(onnx_file, "rb") as file:
        model.ParseFromString(file.read())

    onnx_model = OnnxModel(model)
    return find_bert_inputs(onnx_model, input_ids_name, segment_ids_name, input_mask_name)
def test_pytorch_model_0_gpu_onnxruntime(self):
    if 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
        print("skip test_pytorch_model_0_gpu_onnxruntime since no gpu found")
        return

    input = _get_test_model_path('bert_pytorch_0')
    output = 'temp.onnx'
    optimize_by_onnxruntime(input, use_gpu=True, optimized_model_path=output)

    model = ModelProto()
    with open(output, "rb") as f:
        model.ParseFromString(f.read())
    os.remove(output)

    bert_model = OnnxModel(model)
    expected_node_count = {
        'EmbedLayerNormalization': 1,
        'Attention': 12,
        'SkipLayerNormalization': 24,
        'Gelu': 0,
        'FastGelu': 12,
        'BiasGelu': 0
    }
    self.verify_node_count(bert_model, expected_node_count, 'test_pytorch_model_0_gpu_onnxruntime')
def main():
    args = parse_arguments()
    setup_logging(args.verbose)

    output_names = None if args.output_names is None else args.output_names.split(";")

    model = ModelProto()
    with open(args.input, "rb") as input_file:
        model.ParseFromString(input_file.read())
    onnx_model = OnnxModel(model)

    optimizer = BertOnnxModelShapeOptimizer(onnx_model)
    optimizer.optimize(
        args.output,
        args.input_ids,
        args.segment_ids,
        args.input_mask,
        args.enable_shape_opt,
        args.enable_reshape_opt,
        output_names,
        args.batch_size,
        args.sequence_length,
        args.verbose,
    )
def run(args):
    num_threads = args.thread_num if args.thread_num > 0 else psutil.cpu_count(logical=False)

    # Set OMP environment variable before importing onnxruntime. Needed for CPU only;
    # no impact for the onnxruntime-gpu package.
    if "OMP_NUM_THREADS" not in os.environ:
        os.environ["OMP_NUM_THREADS"] = str(num_threads)

    from onnx import load
    from onnx_model import OnnxModel

    onnx_model = OnnxModel(load(args.model))

    all_inputs = None
    if args.dummy_inputs == "bert":
        all_inputs = create_bert_inputs(
            onnx_model,
            args.batch_size,
            args.sequence_length,
            args.samples,
            args.input_ids_name,
            args.segment_ids_name,
            args.input_mask_name,
        )
    elif args.dummy_inputs == "gpt2":
        all_inputs = create_gpt2_inputs(
            onnx_model,
            args.batch_size,
            args.sequence_length,
            args.past_sequence_length,
            args.samples,
        )
    elif args.dummy_inputs == "longformer":
        all_inputs = create_longformer_inputs(
            onnx_model,
            args.batch_size,
            args.sequence_length,
            args.global_length,
            args.samples,
        )
    else:  # default
        all_inputs = create_dummy_inputs(onnx_model, args.batch_size, args.sequence_length, args.samples)

    profile_file = run_profile(
        args.model,
        args.use_gpu,
        args.provider,
        args.basic_optimization,
        args.thread_num,
        all_inputs,
    )

    return profile_file
def test_gpt2_past_fp16(self):
    input_model_path = _get_test_model_path('gpt2_past')
    model = OnnxModel(load_model(input_model_path, format=None, load_external_data=True))
    model.convert_model_float32_to_float16(cast_input_output=False, use_symbolic_shape_infer=False)
    for input in model.graph().input[1:]:
        self.assertEqual(input.type.tensor_type.elem_type, TensorProto.FLOAT16)
    for output in model.graph().output:
        self.assertEqual(output.type.tensor_type.elem_type, TensorProto.FLOAT16)
def run(args):
    num_threads = args.thread_num if args.thread_num > 0 else psutil.cpu_count(logical=False)

    # Set OMP environment variable before importing onnxruntime. Needed for CPU only;
    # no impact for the onnxruntime-gpu package.
    if "OMP_NUM_THREADS" not in os.environ:
        os.environ["OMP_NUM_THREADS"] = str(num_threads)

    from onnx import load
    from onnx_model import OnnxModel

    onnx_model = OnnxModel(load(args.model))

    all_inputs = None
    if args.dummy_inputs == 'bert':
        all_inputs = create_bert_inputs(onnx_model, args.batch_size, args.sequence_length, args.samples,
                                        args.input_ids_name, args.segment_ids_name, args.input_mask_name)
    elif args.dummy_inputs == 'gpt2':
        all_inputs = create_gpt2_inputs(onnx_model, args.batch_size, args.sequence_length,
                                        args.past_sequence_length, args.samples)
    elif args.dummy_inputs == 'longformer':
        all_inputs = create_longformer_inputs(onnx_model, args.batch_size, args.sequence_length,
                                              args.global_length, args.samples)
    else:  # default
        all_inputs = create_dummy_inputs(onnx_model, args.batch_size, args.sequence_length, args.samples)

    profile_file = run_profile(args.model, args.use_gpu, args.basic_optimization, args.thread_num, all_inputs)

    profile_records = load_profile_json(profile_file)

    lines = parse_profile_results(profile_records, args.kernel_time_only, args.threshold)

    lines.append("\nGrouped by operator type:")
    lines.append("-" * 64)
    lines += group_profile_results(profile_records, args.kernel_time_only, args.use_gpu)

    return lines
def test_pytorch_model_0_cpu_onnxruntime(self):
    input = _get_test_model_path('bert_pytorch_0')
    output = 'temp.onnx'
    optimize_by_onnxruntime(input, use_gpu=False, optimized_model_path=output)

    model = ModelProto()
    with open(output, "rb") as f:
        model.ParseFromString(f.read())
    os.remove(output)

    bert_model = OnnxModel(model)
    expected_node_count = {
        'EmbedLayerNormalization': 1,
        'Attention': 12,
        'SkipLayerNormalization': 24,
        'Gelu': 0,
        'FastGelu': 0,
        'BiasGelu': 12
    }
    self.verify_node_count(bert_model, expected_node_count, 'test_pytorch_model_0_cpu_onnxruntime')
def get_bert_inputs(onnx_file, input_ids_name=None, segment_ids_name=None, input_mask_name=None):
    """Find graph inputs for BERT model.
    First, we will deduce from EmbedLayerNormalization node.
    If not found, we will guess based on naming.

    Args:
        onnx_file (str): onnx model path
        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.
    """
    model = ModelProto()
    with open(onnx_file, "rb") as f:
        model.ParseFromString(f.read())

    onnx_model = OnnxModel(model)
    return find_bert_inputs(onnx_model, input_ids_name, segment_ids_name, input_mask_name)
def remove_useless_reshape_nodes(model: OnnxModel):
    """Remove Reshape nodes that are not needed based on symbolic shape inference:
    the input and output of such a node have the same shape.
    """
    shape_infer = model.infer_runtime_shape(update=True)
    if shape_infer is None:
        return

    nodes_to_remove = []
    for node in model.nodes():
        if node.op_type == 'Reshape':
            input_shape = shape_infer.get_edge_shape(node.input[0])
            output_shape = shape_infer.get_edge_shape(node.output[0])
            if input_shape and output_shape and input_shape == output_shape:
                logger.info(f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}")
                nodes_to_remove.append(node)

    if nodes_to_remove:
        for node in nodes_to_remove:
            model.replace_input_of_all_nodes(node.output[0], node.input[0])
            model.remove_node(node)
        model.prune_graph()
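def _example_remove_useless_reshapes(input_path, output_path):
    """Usage sketch (illustrative; paths are hypothetical): load a model, drop
    no-op Reshape nodes with the helper above, then save the cleaned model."""
    model = OnnxModel(onnx.load(input_path))
    remove_useless_reshape_nodes(model)
    model.save_model_to_file(output_path)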
def export_onnx(
    encoder: T5Encoder,
    device: torch.device,
    onnx_model_path: str,
    verbose: bool = True,
    use_external_data_format: bool = False,
    use_int32_inputs: bool = False,
):
    """Export encoder to ONNX

    Args:
        encoder (T5Encoder): encoder object
        device (torch.device): device of encoder object
        onnx_model_path (str): onnx path
        verbose (bool, optional): print verbose information. Defaults to True.
        use_external_data_format (bool, optional): use external data format or not. Defaults to False.
        use_int32_inputs (bool, optional): use int32 instead of int64 for graph inputs. Defaults to False.
    """
    config = encoder.config
    encoder_inputs = T5EncoderInputs.create_dummy(
        batch_size=2,
        sequence_length=4,
        vocab_size=config.vocab_size,
        device=device,
        use_int32_inputs=use_int32_inputs,
    )

    Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

    with tempfile.TemporaryDirectory() as tmp_dir_name:
        temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
        Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
        torch_onnx_export(
            encoder,
            args=tuple(encoder_inputs.to_list()),
            f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
            export_params=True,
            input_names=["input_ids", "attention_mask"],
            output_names=["hidden_states"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"},
                "hidden_states": {0: "batch_size", 1: "sequence_length"},
            },
            opset_version=12,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            verbose=verbose,
        )

        if use_external_data_format:
            model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
            OnnxModel.save(
                model,
                onnx_model_path,
                save_as_external_data=True,
                all_tensors_to_one_file=True,
            )
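def _example_export_t5_encoder(model_name="t5-small", output_path="t5_encoder.onnx"):
    """Usage sketch (illustrative; the model name, output path, and the
    T5Encoder(encoder, config) constructor signature are assumptions): export the
    encoder of a HuggingFace T5 model with export_onnx above."""
    import torch
    from transformers import T5ForConditionalGeneration

    device = torch.device("cpu")
    t5 = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
    encoder = T5Encoder(t5.encoder, t5.config)
    export_onnx(encoder, device, output_path, verbose=False)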
def optimize_fp16_onnx_no_cast(input_onnx_path, optimized_onnx_path, epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)

    nodes_to_remove = onnx_model.nodes()
    node_to_add = onnx.helper.make_node("LayerNormalization",
                                        ["input", "layer_norm.weight", "layer_norm.bias"], ["output"],
                                        "layer_norm",
                                        epsilon=epsilon)

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_node(node_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
def get_bert_inputs(onnx_file, input_ids_name=None, segment_ids_name=None, input_mask_name=None):
    """
    Get graph inputs for bert model.
    First, we will deduce from EmbedLayerNormalization node.
    If not found, we will guess based on naming.
    """
    model = ModelProto()
    with open(onnx_file, "rb") as f:
        model.ParseFromString(f.read())

    onnx_model = OnnxModel(model)
    graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()

    if input_ids_name is not None:
        input_ids = onnx_model.find_graph_input(input_ids_name)
        if input_ids is None:
            raise ValueError(f"Graph does not have input named {input_ids_name}")

        segment_ids = None
        if segment_ids_name:
            segment_ids = onnx_model.find_graph_input(segment_ids_name)
            if segment_ids is None:
                raise ValueError(f"Graph does not have input named {segment_ids_name}")

        input_mask = None
        if input_mask_name:
            input_mask = onnx_model.find_graph_input(input_mask_name)
            if input_mask is None:
                raise ValueError(f"Graph does not have input named {input_mask_name}")

        expected_inputs = 1 + (1 if segment_ids else 0) + (1 if input_mask else 0)
        if len(graph_inputs) != expected_inputs:
            raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")

        return input_ids, segment_ids, input_mask

    if len(graph_inputs) != 3:
        raise ValueError(f"Expect the graph to have 3 inputs. Got {len(graph_inputs)}")

    embed_nodes = onnx_model.get_nodes_by_op_type('EmbedLayerNormalization')
    if len(embed_nodes) == 1:
        embed_node = embed_nodes[0]
        input_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 0)
        segment_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 1)
        input_mask = get_graph_input_from_embed_node(onnx_model, embed_node, 7)
        return input_ids, segment_ids, input_mask

    # Try to guess the inputs based on naming.
    input_ids = None
    segment_ids = None
    input_mask = None
    for input in graph_inputs:
        input_name_lower = input.name.lower()
        if "mask" in input_name_lower:  # matches input with name like "attention_mask" or "input_mask"
            input_mask = input
        elif "token" in input_name_lower or "segment" in input_name_lower:
            # matches input with name like "segment_ids" or "token_type_ids"
            segment_ids = input
        else:
            input_ids = input

    if input_ids and segment_ids and input_mask:
        return input_ids, segment_ids, input_mask

    raise ValueError("Fail to assign 3 inputs. You might try renaming the graph inputs.")
def export_onnx(
    model,
    device,
    onnx_model_path: str,
    verbose: bool = False,
    use_external_data_format: bool = False,
    has_position_ids: bool = True,
    has_attention_mask: bool = True,
    input_ids_dtype: torch.dtype = torch.int32,
    position_ids_dtype: torch.dtype = torch.int32,
    attention_mask_dtype: torch.dtype = torch.int32,
):
    """Export GPT-2 model with past state to ONNX model."""
    config: GPT2Config = model.config
    num_layer = config.n_layer
    dummy_inputs = Gpt2Helper.get_dummy_inputs(
        batch_size=1,
        past_sequence_length=1,
        sequence_length=1,
        num_attention_heads=config.num_attention_heads,
        hidden_size=config.hidden_size,
        num_layer=num_layer,
        vocab_size=config.vocab_size,
        device=device,
        float16=False,
        has_position_ids=has_position_ids,
        has_attention_mask=has_attention_mask,
        input_ids_dtype=input_ids_dtype,
        position_ids_dtype=position_ids_dtype,
        attention_mask_dtype=attention_mask_dtype,
    )
    input_list = dummy_inputs.to_list()

    with torch.no_grad():
        outputs = model(*input_list)

    past_names = [f"past_{i}" for i in range(num_layer)]
    present_names = [f"present_{i}" for i in range(num_layer)]

    # GPT2Model outputs last_state; GPT2LMHeadModel outputs logits (prediction_scores)
    assert outputs[0].shape[2] == config.vocab_size or outputs[0].shape[2] == config.hidden_size
    output_names = ["logits" if outputs[0].shape[2] == config.vocab_size else "last_state"] + present_names

    # Shape of input tensors:
    #    input_ids: (batch_size, seq_len)
    #    past_{i}: (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads)
    #    attention_mask: (batch_size, past_seq_len + seq_len)
    # Shape of output tensors:
    #    last_state: (batch_size, seq_len, hidden_size)
    #       or logits: (batch_size, seq_len, vocab_size)
    #    present_{i}: (2, batch_size, num_heads, past_seq_len + seq_len, hidden_size/num_heads)
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        output_names[0]: {0: "batch_size", 1: "seq_len"},
    }
    for name in past_names:
        dynamic_axes[name] = {1: "batch_size", 3: "past_seq_len"}
    for name in present_names:
        dynamic_axes[name] = {1: "batch_size", 3: "total_seq_len"}

    input_names = ["input_ids"]
    if has_position_ids:
        dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
        input_names.append("position_ids")
    if has_attention_mask:
        dynamic_axes["attention_mask"] = {0: "batch_size", 1: "total_seq_len"}
        input_names.append("attention_mask")
    input_names.extend(past_names)

    assert len(outputs) == 2 and len(outputs[1]) == num_layer

    logger.info(
        f"Shapes: input_ids={dummy_inputs.input_ids.shape} past={dummy_inputs.past[0].shape} output={outputs[0].shape} present={outputs[1][0].shape}"
    )

    Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

    if use_external_data_format:
        # We let PyTorch export onnx to a temp directory first, then convert external data to one file.
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "gpt2.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            torch_onnx_export(
                model,
                args=tuple(input_list),
                f=temp_onnx_model_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=11,
                do_constant_folding=True,
                use_external_data_format=True,
                verbose=verbose,
            )

            model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
            OnnxModel.save(
                model,
                onnx_model_path,
                save_as_external_data=True,
                all_tensors_to_one_file=True,
            )
    else:
        torch_onnx_export(
            model,
            args=tuple(input_list),
            f=onnx_model_path,
            export_params=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=11,
            do_constant_folding=True,
            use_external_data_format=False,
            verbose=verbose,
        )
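def _example_export_gpt2(model_name="gpt2", output_path="gpt2_past.onnx"):
    """Usage sketch (illustrative; model name and path are assumptions): export a
    HuggingFace GPT-2 LM head model with past state. The original tool wraps the
    model in a Gpt2Helper model class before export; the raw model is passed here
    for brevity and is assumed to accept the dummy inputs positionally."""
    import torch
    from transformers import GPT2LMHeadModel

    device = torch.device("cpu")
    model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
    model.eval()
    export_onnx(model, device, output_path, verbose=False)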
def get_longformer_inputs(onnx_file, input_ids_name=None, input_mask_name=None, global_mask_name=None):
    """
    Get graph inputs for longformer model.
    """
    model = ModelProto()
    with open(onnx_file, "rb") as f:
        model.ParseFromString(f.read())

    onnx_model = OnnxModel(model)
    graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()

    if input_ids_name is not None:
        input_ids = onnx_model.find_graph_input(input_ids_name)
        if input_ids is None:
            raise ValueError(f"Graph does not have input named {input_ids_name}")

        input_mask = None
        if input_mask_name:
            input_mask = onnx_model.find_graph_input(input_mask_name)
            if input_mask is None:
                raise ValueError(f"Graph does not have input named {input_mask_name}")

        global_mask = None
        if global_mask_name:
            global_mask = onnx_model.find_graph_input(global_mask_name)
            if global_mask is None:
                raise ValueError(f"Graph does not have input named {global_mask_name}")

        expected_inputs = 1 + (1 if input_mask else 0) + (1 if global_mask else 0)
        if len(graph_inputs) != expected_inputs:
            raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")

        return input_ids, input_mask, global_mask

    if len(graph_inputs) != 3:
        raise ValueError(f"Expect the graph to have 3 inputs. Got {len(graph_inputs)}")

    # Try to guess the inputs based on naming.
    input_ids = None
    input_mask = None
    global_mask = None
    for input in graph_inputs:
        input_name_lower = input.name.lower()
        if "global" in input_name_lower:
            global_mask = input
        elif "mask" in input_name_lower:
            input_mask = input
        else:
            input_ids = input

    if input_ids and input_mask and global_mask:
        return input_ids, input_mask, global_mask

    raise ValueError("Fail to assign 3 inputs. You might try renaming the graph inputs.")
def change_graph_input_type(
    self,
    graph: GraphProto,
    graph_input: ValueInfoProto,
    new_type: int = TensorProto.INT32,
):
    """Change graph input type, and add Cast node if needed.

    Args:
        graph (GraphProto): graph
        graph_input (ValueInfoProto): input of the graph
        new_type (int, optional): new data type. Defaults to TensorProto.INT32.

    Returns:
        NodeProto: the new Cast node that was added. None if no Cast node is added.
        List[NodeProto]: Cast nodes that have been removed.
    """
    assert isinstance(graph, GraphProto)
    assert isinstance(graph_input, ValueInfoProto)
    assert self.find_graph_input(graph_input.name)

    if graph_input.type.tensor_type.elem_type == int(new_type):
        return None, []

    new_cast_node = None
    nodes_to_remove = []

    input_name_to_nodes = self.input_name_to_nodes()
    if graph_input.name in input_name_to_nodes:
        nodes = input_name_to_nodes[graph_input.name]

        # For children that are not Cast nodes, insert a Cast node to convert int32 to the original data type.
        nodes_not_cast = [node for node in nodes if node.op_type != "Cast"]
        if nodes_not_cast:
            node_name = self.create_node_name("Cast")
            output_name = node_name + "_" + graph_input.name
            new_value_info = graph.value_info.add()
            new_value_info.CopyFrom(graph_input)
            new_value_info.name = output_name
            new_cast_node = helper.make_node(
                "Cast",
                [graph_input.name],
                [output_name],
                to=int(graph_input.type.tensor_type.elem_type),
                name=node_name,
            )
            graph.node.extend([new_cast_node])

            for node in nodes_not_cast:
                OnnxModel.replace_node_input(node, graph_input.name, output_name)

        # For children that are Cast nodes, no need to insert Cast.
        # When a child casts to int32, we can remove that Cast node since the input type is int32 now.
        nodes_cast = [node for node in nodes if node.op_type == "Cast"]
        for node in nodes_cast:
            if OnnxModel.get_node_attribute(node, "to") == int(new_type):
                self.replace_input_of_all_nodes(node.output[0], graph_input.name)
                if not self.find_graph_output(node.output[0]):
                    nodes_to_remove.append(node)

    if nodes_to_remove:
        self.remove_nodes(nodes_to_remove)

    graph_input.type.tensor_type.elem_type = int(new_type)
    return new_cast_node, nodes_to_remove
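def _example_int64_inputs_to_int32(input_path, output_path):
    """Usage sketch (illustrative; paths are hypothetical): change_graph_input_type
    is an OnnxModel method, so it is invoked through a model instance, e.g. to
    convert all int64 graph inputs to int32."""
    from onnx import TensorProto

    model = OnnxModel(onnx.load(input_path))
    for graph_input in model.graph().input:
        if graph_input.type.tensor_type.elem_type == TensorProto.INT64:
            model.change_graph_input_type(model.graph(), graph_input, TensorProto.INT32)
    model.save_model_to_file(output_path)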
def auto_mixed_precision(onnx_model: OnnxModel,
                         op_block_list: List[str] = ['Add', 'LayerNormalization', 'FastGelu']):
    """Convert GPT-2 model to mixed precision.
       It detects whether the original model has fp16 precision weights, and sets parameters for float16 conversion automatically.

    Args:
        onnx_model (OnnxModel): optimized ONNX model
        op_block_list (List[str], optional): operator types to keep in float32. Defaults to ['Add', 'LayerNormalization', 'FastGelu'].

    Returns:
        parameters(dict): a dictionary of parameters used in float16 conversion
    """
    op_full_set = set([node.op_type for node in onnx_model.nodes()])
    fp32_op_set = set(op_block_list)
    fp16_op_set = op_full_set.difference(fp32_op_set)
    logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")

    # logits is the first output
    logits_output_name = onnx_model.graph().output[0].name

    # We use the weight in the last MatMul node to detect whether the model is stored with float16 weights from training.
    is_weight_fp16_precision = False
    output_name_to_node = onnx_model.output_name_to_node()
    assert logits_output_name in output_name_to_node
    node = output_name_to_node[logits_output_name]
    last_matmul_node = None
    if node.op_type == "MatMul":
        last_matmul_node = node
        logger.info(f"Found last MatMul node for logits: {node.name}")
        initializer = None
        for input in node.input:
            initializer = onnx_model.get_initializer(input)
            if initializer is not None:
                break

        # When the max difference of values after converting float to float16 is lower than a threshold (1e-6),
        # we can deduce that the weights are stored in float16 precision.
        max_diff = float_to_float16_max_diff(initializer)
        logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
        is_weight_fp16_precision = (max_diff < 1E-6)
    else:
        logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")

    if is_weight_fp16_precision:
        keep_io_types = []
        node_block_list = []
    else:
        # When the original weights are float32 precision, keeping logits and the last MatMul in float32 could get better precision.
        keep_io_types = [logits_output_name]
        # Guard against the case where no MatMul node was found for logits.
        node_block_list = [last_matmul_node.name] if last_matmul_node is not None else []

    parameters = {
        "keep_io_types": keep_io_types,
        "op_block_list": op_block_list,
        "node_block_list": node_block_list,
        "force_fp16_initializers": is_weight_fp16_precision
    }

    logger.info(f"auto_mixed_precision parameters: {parameters}")
    onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)

    fusion_utils = FusionUtils(onnx_model)
    fusion_utils.remove_cascaded_cast_nodes()
    fusion_utils.remove_useless_cast_nodes()

    return parameters
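def _example_auto_mixed_precision(input_path, output_path):
    """Usage sketch (illustrative; paths are hypothetical): run the automatic
    mixed precision conversion above on an optimized GPT-2 model and save it."""
    model = OnnxModel(onnx.load(input_path))
    parameters = auto_mixed_precision(model)
    logger.info(f"mixed precision conversion parameters: {parameters}")
    model.save_model_to_file(output_path)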
def export_onnx(
    decoder: Union[T5Decoder, T5DecoderInit],
    device: torch.device,
    onnx_model_path: str,
    verbose: bool = True,
    use_external_data_format: bool = False,
    use_int32_inputs: bool = False,
):
    """Export decoder to ONNX

    Args:
        decoder (Union[T5Decoder, T5DecoderInit]): decoder object
        device (torch.device): device of decoder object
        onnx_model_path (str): onnx path
        verbose (bool, optional): print verbose information. Defaults to True.
        use_external_data_format (bool, optional): use external data format or not. Defaults to False.
        use_int32_inputs (bool, optional): use int32 inputs. Defaults to False.
    """
    assert isinstance(decoder, (T5Decoder, T5DecoderInit))

    inputs = T5DecoderInputs.create_dummy(
        decoder.config,
        batch_size=2,
        encode_sequence_length=3,
        past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0,
        device=device,
        use_int32_inputs=use_int32_inputs,
    )
    input_list = inputs.to_list()

    past_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=False)
    present_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=True)
    present_self_names = present_names[: 2 * decoder.config.num_layers]

    input_past_names = past_names if isinstance(decoder, T5Decoder) else []
    output_present_names = present_self_names if isinstance(decoder, T5Decoder) else present_names
    output_names = ["logits"] + output_present_names

    # Shape of input tensors (sequence_length==1):
    #    input_ids: (batch_size, sequence_length)
    #    encoder_attention_mask: (batch_size, encode_sequence_length)
    #    encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
    #    past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
    #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
    # Shape of output tensors:
    #    logits: (batch_size, sequence_length, vocab_size)
    #    past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
    #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
    input_names = ["input_ids"]
    input_names.append("encoder_attention_mask")
    input_names.append("encoder_hidden_states")
    input_names.extend(input_past_names)

    dynamic_axes = {
        "input_ids": {
            0: "batch_size",
            # 1: 'sequence_length'
        },
        "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
        "encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length"},
        "logits": {
            0: "batch_size",
            # 1: 'sequence_length'
        },
    }

    for name in input_past_names:
        dynamic_axes[name] = {
            0: "batch_size",
            2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
        }

    for name in output_present_names:
        if "cross" in name:
            dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
        else:  # self attention past state
            if isinstance(decoder, T5Decoder):
                dynamic_axes[name] = {
                    0: "batch_size",
                    2: "past_decode_sequence_length + 1",
                }
            else:
                dynamic_axes[name] = {
                    0: "batch_size",
                    # 2: 'sequence_length'
                }

    Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

    with tempfile.TemporaryDirectory() as tmp_dir_name:
        temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
        Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
        torch_onnx_export(
            decoder,
            args=tuple(input_list),
            f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
            export_params=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=12,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            verbose=verbose,
        )

        if use_external_data_format:
            model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
            OnnxModel.save(
                model,
                onnx_model_path,
                save_as_external_data=True,
                all_tensors_to_one_file=True,
            )
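def _example_export_t5_decoder(decoder, output_path="t5_decoder.onnx"):
    """Usage sketch (illustrative; the path is an assumption): `decoder` is assumed
    to be a T5Decoder or T5DecoderInit instance already constructed from a
    HuggingFace T5ForConditionalGeneration (see T5Helper in these tools for the
    wiring)."""
    import torch

    device = torch.device("cpu")
    export_onnx(decoder, device, output_path, verbose=False, use_int32_inputs=True)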
def find_bert_inputs(
    onnx_model: OnnxModel,
    input_ids_name: Optional[str] = None,
    segment_ids_name: Optional[str] = None,
    input_mask_name: Optional[str] = None,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """Find graph inputs for BERT model.
    First, we will deduce inputs from EmbedLayerNormalization node.
    If not found, we will guess the meaning of graph inputs based on naming.

    Args:
        onnx_model (OnnxModel): onnx model object
        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.

    Raises:
        ValueError: Graph does not have an input named input_ids_name, segment_ids_name or input_mask_name
        ValueError: Expected graph input number does not match with specified input_ids_name, segment_ids_name and input_mask_name

    Returns:
        Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids, segment_ids and input_mask
    """
    graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()

    if input_ids_name is not None:
        input_ids = onnx_model.find_graph_input(input_ids_name)
        if input_ids is None:
            raise ValueError(f"Graph does not have input named {input_ids_name}")

        segment_ids = None
        if segment_ids_name:
            segment_ids = onnx_model.find_graph_input(segment_ids_name)
            if segment_ids is None:
                raise ValueError(f"Graph does not have input named {segment_ids_name}")

        input_mask = None
        if input_mask_name:
            input_mask = onnx_model.find_graph_input(input_mask_name)
            if input_mask is None:
                raise ValueError(f"Graph does not have input named {input_mask_name}")

        expected_inputs = 1 + (1 if segment_ids else 0) + (1 if input_mask else 0)
        if len(graph_inputs) != expected_inputs:
            raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")

        return input_ids, segment_ids, input_mask

    if len(graph_inputs) != 3:
        raise ValueError(f"Expect the graph to have 3 inputs. Got {len(graph_inputs)}")

    embed_nodes = onnx_model.get_nodes_by_op_type("EmbedLayerNormalization")
    if len(embed_nodes) == 1:
        embed_node = embed_nodes[0]
        input_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 0)
        segment_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 1)
        input_mask = get_graph_input_from_embed_node(onnx_model, embed_node, 7)

        if input_mask is None:
            for input in graph_inputs:
                input_name_lower = input.name.lower()
                if "mask" in input_name_lower:
                    input_mask = input
            if input_mask is None:
                raise ValueError("Failed to find attention mask input")

        return input_ids, segment_ids, input_mask

    # Try to guess the inputs based on naming.
    input_ids = None
    segment_ids = None
    input_mask = None
    for input in graph_inputs:
        input_name_lower = input.name.lower()
        if "mask" in input_name_lower:  # matches input with name like "attention_mask" or "input_mask"
            input_mask = input
        elif "token" in input_name_lower or "segment" in input_name_lower:
            # matches input with name like "segment_ids" or "token_type_ids"
            segment_ids = input
        else:
            input_ids = input

    if input_ids and segment_ids and input_mask:
        return input_ids, segment_ids, input_mask

    raise ValueError("Fail to assign 3 inputs. You might try renaming the graph inputs.")
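def _example_find_bert_inputs(onnx_path):
    """Usage sketch (illustrative; path is hypothetical): locate the three BERT
    graph inputs so that dummy test data can be generated for them."""
    model = ModelProto()
    with open(onnx_path, "rb") as f:
        model.ParseFromString(f.read())

    onnx_model = OnnxModel(model)
    input_ids, segment_ids, input_mask = find_bert_inputs(onnx_model)
    print(input_ids.name, segment_ids.name, input_mask.name)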
def optimize_fp16_onnx_no_cast(input_onnx_path, optimized_onnx_path, epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)

    weight_name = get_weight(onnx_model)
    bias_name = get_bias(onnx_model)

    # Keep the nodes that produce the weight and bias; remove everything else.
    # (A stray second assignment that overwrote this list with all nodes has been dropped.)
    nodes_to_remove = [n for n in onnx_model.nodes() if n.output[0] != weight_name and n.output[0] != bias_name]
    node_to_add = onnx.helper.make_node("LayerNormalization", ["input", weight_name, bias_name], ["output"],
                                        "layer_norm",
                                        epsilon=epsilon)

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_node(node_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
def optimize_fp16_onnx_with_cast(input_onnx_path, optimized_onnx_path, epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)

    weight_name = get_weight(onnx_model)
    bias_name = get_bias(onnx_model)
    nodes_to_remove = [n for n in onnx_model.nodes() if n.output[0] != weight_name and n.output[0] != bias_name]

    # to=1 is TensorProto.FLOAT (fp32); to=10 is TensorProto.FLOAT16.
    nodes_to_add = [
        onnx.helper.make_node("Cast", ["input"], ["fp32_input"], "cast_input", to=1),
        onnx.helper.make_node("Cast", [weight_name], ["fp32_layer_norm.weight"], "cast_weight", to=1),
        onnx.helper.make_node("Cast", [bias_name], ["fp32_layer_norm.bias"], "cast_bias", to=1),
        onnx.helper.make_node("LayerNormalization",
                              ["fp32_input", "fp32_layer_norm.weight", "fp32_layer_norm.bias"], ["fp32_output"],
                              "layer_norm",
                              epsilon=epsilon),  # use fp32 epsilon
        onnx.helper.make_node("Cast", ["fp32_output"], ["output"], "cast_output", to=10)
    ]

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_nodes(nodes_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
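def _example_build_fp16_layernorm_variants(input_path):
    """Usage sketch (illustrative; paths and epsilon are assumptions): produce both
    variants of the fp16 LayerNormalization model. The no-cast version runs
    LayerNormalization directly in fp16; the cast version computes it in fp32,
    trading speed for accuracy."""
    epsilon = 1e-5  # assumed; should match the epsilon of the original layer norm
    optimize_fp16_onnx_no_cast(input_path, "layer_norm_fp16_no_cast.onnx", epsilon)
    optimize_fp16_onnx_with_cast(input_path, "layer_norm_fp16_with_cast.onnx", epsilon)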