def change_input_to_int32(self):
    original_opset_version = self.model.opset_import[0].version
    graph = self.graph()

    new_graph_inputs = []
    casted_bert_graph_inputs = self.get_graph_inputs_from_embed_nodes(casted=True)
    utils = FusionUtils(self)
    for input in graph.input:
        if input.name in casted_bert_graph_inputs:
            utils.remove_cast_int32(input.name)
            int32_input = helper.make_tensor_value_info(input.name, TensorProto.INT32,
                                                        self.tensor_shape_to_list(input.type.tensor_type))
            new_graph_inputs.append(int32_input)
        else:
            new_graph_inputs.append(input)

    graph_def = helper.make_graph(graph.node,
                                  'int32 inputs',
                                  new_graph_inputs,
                                  graph.output,
                                  initializer=graph.initializer,
                                  value_info=graph.value_info)

    self.model = helper.make_model(graph_def, producer_name='onnxruntime-tools')

    # Restore opset version.
    self.model.opset_import[0].version = original_opset_version
def __init__(self, model: OnnxModel, num_heads: int):
    super().__init__(model, "Attention", "LayerNormalization", "with past")

    # TODO: detect num_heads from graph like FusionAttention
    self.num_heads = num_heads
    self.utils = FusionUtils(model)
    self.casted_attention_mask = {}  # map from the name of an attention mask to the name casted to int32
class AttentionMask():
    """
    Helper that processes the attention mask input shared by fused Attention nodes.
    """
    def __init__(self, model: OnnxModel):
        self.model = model
        # A lookup table with mask input as key, and mask index output as value
        self.mask_indice = {}
        # A lookup table with mask input as key, and cast (to int32) output as value
        self.mask_casted = {}
        self.utils = FusionUtils(model)
        self.mask_format = AttentionMaskFormat.MaskIndexEnd

    def set_mask_format(self, mask_format: AttentionMaskFormat):
        self.mask_format = mask_format

    def set_mask_indice(self, mask, mask_index):
        if mask in self.mask_indice:
            assert mask_index == self.mask_indice[mask]
        self.mask_indice[mask] = mask_index

    def get_first_mask(self):
        assert len(self.mask_indice) > 0
        return next(iter(self.mask_indice))

    def process_mask(self, input):
        if input in self.mask_indice:
            return self.mask_indice[input]

        # Add cast to convert int64 to int32
        if self.model.find_graph_input(input):
            casted, input_name = self.utils.cast_graph_input_to_int32(input)
        else:
            input_name, cast_node = self.utils.cast_input_to_int32(input)
            casted = True

        if casted:
            self.mask_casted[input] = input_name

        # Attention supports int32 attention mask (2D) since 1.4.0
        if self.mask_format == AttentionMaskFormat.AttentionMask:
            self.mask_indice[input] = input_name
            return input_name

        # Add a mask processing node to convert attention mask to mask index (1D)
        output_name = self.model.create_node_name('mask_index')
        mask_index_node = helper.make_node('ReduceSum',
                                           inputs=[input_name],
                                           outputs=[output_name],
                                           name=self.model.create_node_name('ReduceSum', 'MaskReduceSum'))
        mask_index_node.attribute.extend(
            [helper.make_attribute("axes", [1]),
             helper.make_attribute("keepdims", 0)])
        self.model.add_node(mask_index_node)

        self.mask_indice[input] = output_name
        return output_name
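# Illustrative sketch only, not part of the original source: shows how AttentionMask
# might be driven. `onnx_model` and the input name "attention_mask" are hypothetical
# placeholders.
def _attention_mask_usage_example(onnx_model: OnnxModel) -> str:
    """Convert a hypothetical 2D mask input into the 1D mask-index form."""
    attention_mask = AttentionMask(onnx_model)
    attention_mask.set_mask_format(AttentionMaskFormat.MaskIndexEnd)
    # Inserts Cast (to int32) and ReduceSum(axes=[1], keepdims=0) nodes on first
    # use, and returns the name of the resulting mask-index tensor.
    return attention_mask.process_mask("attention_mask")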
def fuse(self, concat_node: NodeProto, input_name_to_nodes: Dict[str, List[NodeProto]],
         output_name_to_node: Dict[str, NodeProto]):
    """
    Simplify subgraph like

                  (2d_input)
                   /      \\
               Shape      Shape
                 |          |
       Gather(indices=0)  Gather(indices=1)
                 |          |
       Unsqueeze(axes=0)  Unsqueeze(axes=0)
                  \\        /
                    Concat
                      |

    into (2d_input) --> Shape -->
    """
    opset_version = self.model.get_opset_version()

    inputs = len(concat_node.input)
    root = None
    shape_output = None
    for i in range(inputs):
        path = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [i, 0, 0],
                                            output_name_to_node)
        if path is None:
            return

        unsqueeze, gather, shape = path
        if i == 0:
            shape_output = shape.output[0]
        if root is None:
            root = shape.input[0]
            if self.get_dimensions(root) != inputs:
                return
        elif shape.input[0] != root:
            return

        if not FusionUtils.check_node_attribute(unsqueeze, 'axis', 0, default_value=0):
            return

        if opset_version < 13:
            if not FusionUtils.check_node_attribute(unsqueeze, 'axes', [0]):
                return
        else:
            if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
                return

        value = self.model.get_constant_value(gather.input[1])

        from numpy import ndarray
        if not (isinstance(value, ndarray) and value.size == 1 and value.item() == i):
            return

    if self.model.find_graph_output(concat_node.output[0]) is None:
        self.model.replace_input_of_all_nodes(concat_node.output[0], shape_output)
        self.fused_count += 1
        self.prune_graph = True
def __init__(self, model, num_heads, hidden_size):
    assert num_heads > 0
    assert hidden_size % num_heads == 0

    super().__init__(model)
    self.num_heads = num_heads
    self.hidden_size = hidden_size

    self.attention_mask = AttentionMask(self)
    self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
    self.utils = FusionUtils(self)
def optimize(self, options: FusionOptions = None, add_dynamic_axes=False):
    if (options is None) or options.enable_layer_norm:
        self.fuse_layer_norm()

    if (options is None) or options.enable_gelu:
        self.fuse_gelu()

    self.preprocess()

    self.fuse_reshape()

    if (options is None) or options.enable_skip_layer_norm:
        self.fuse_skip_layer_norm()

    if (options is None) or options.enable_attention:
        if options is not None:
            self.attention_mask.set_mask_format(options.attention_mask_format)
        self.fuse_attention()

    if (options is None) or options.enable_embed_layer_norm:
        self.fuse_embed_layer()

    # Remove reshape nodes that have the same input and output shape, based on symbolic shape inference.
    FusionUtils.remove_useless_reshape_nodes(self)

    self.postprocess()

    # Bias fusion is done after postprocess to avoid an extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization.
    if (options is None) or options.enable_bias_gelu:
        # Fuse Gelu and Add Bias before it.
        self.fuse_bias_gelu(is_fastgelu=True)
        self.fuse_bias_gelu(is_fastgelu=False)

    if (options is None) or options.enable_bias_skip_layer_norm:
        # Fuse SkipLayerNormalization and Add Bias before it.
        self.fuse_add_bias_skip_layer_norm()

    if options is not None and options.enable_gelu_approximation:
        self.gelu_approximation()

    self.remove_unused_constant()

    # Use symbolic batch dimension in input and output.
    if add_dynamic_axes:
        self.use_dynamic_axes()

    logger.info(f"opset version: {self.model.opset_import[0].version}")
def adjust_reshape_and_expand(self):
    # Remove reshape nodes that have the same input and output shape, based on symbolic shape inference.
    FusionUtils.remove_useless_reshape_nodes(self)

    nodes_to_remove = []
    for node in self.nodes():
        if node.op_type == 'Reshape':
            # Clean up unnecessary reshape nodes.
            # Find reshape nodes with no actual data in the "shape" attribute and remove them.
            reshape_shape = self.get_constant_value(node.input[1])
            if reshape_shape is not None and reshape_shape.size == 0:
                nodes_to_remove.extend([node])
                self.replace_input_of_all_nodes(node.output[0], node.input[0])
                continue

            # Find path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", and simplify the
            # graph by changing the current reshape's input to the output of slice.
            reshape_path = self.match_parent_path(node, ['Expand', 'Expand', 'Reshape', 'Slice'], [0, 0, 0, 0],
                                                  self.output_name_to_node())
            if reshape_path is not None:
                expand_node = reshape_path[-3]
                expand_shape_value = self.get_constant_value(expand_node.input[1])

                reshape_before_expand = reshape_path[-2]
                shape_value = self.get_constant_value(reshape_before_expand.input[1])

                slice_node = reshape_path[-1]
                if expand_shape_value is not None and shape_value is not None and \
                        len(expand_shape_value) == 2 and len(shape_value) == 1 and \
                        expand_shape_value[1] == shape_value[0]:
                    node.input[0] = slice_node.output[0]

    if nodes_to_remove:
        self.remove_nodes(nodes_to_remove)
        logger.info(f"Removed Reshape and Expand count: {len(nodes_to_remove)}")
def change_input_to_int32(self):
    original_opset_version = self.model.opset_import[0].version
    graph = self.graph()

    batch_size, sequence_length = self.get_bert_input_shape()
    new_graph_inputs = []

    bert_inputs = self.get_bert_inputs()
    utils = FusionUtils(self)
    for input in graph.input:
        if input.name in bert_inputs:
            utils.remove_cast_int32(input.name)
            input_shape = [
                batch_size if isinstance(batch_size, int) else 1,
                sequence_length if isinstance(sequence_length, int) else 128
            ]
            int32_input = helper.make_tensor_value_info(input.name, TensorProto.INT32, input_shape)
            new_graph_inputs.append(int32_input)
        else:
            new_graph_inputs.append(input)

    graph_def = helper.make_graph(graph.node,
                                  'int32 inputs',
                                  new_graph_inputs,
                                  graph.output,
                                  initializer=graph.initializer,
                                  value_info=graph.value_info)

    self.model = helper.make_model(graph_def, producer_name='bert model optimizer')

    if isinstance(batch_size, str) or isinstance(sequence_length, str):
        self.use_dynamic_axes(batch_size if isinstance(batch_size, str) else None,
                              sequence_length if isinstance(sequence_length, str) else None)

    # Restore opset version.
    self.model.opset_import[0].version = original_opset_version
class BertOnnxModel(OnnxModel):
    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize BERT ONNX Model.

        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0, which means the parameter is detected automatically.
            hidden_size (int, optional): hidden dimension. Defaults to 0, which means the parameter is detected automatically.
        """
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)

        super().__init__(model)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
        self.utils = FusionUtils(self)

    def fuse_attention(self):
        self.attention_fusion.apply()

    def fuse_gelu(self):
        fusion = FusionGelu(self)
        fusion.apply()
        fusion = FusionFastGelu(self)
        fusion.apply()

    def fuse_bias_gelu(self, is_fastgelu):
        fusion = FusionBiasGelu(self, is_fastgelu)
        fusion.apply()

    def gelu_approximation(self):
        fusion = FusionGeluApproximation(self)
        fusion.apply()

    def fuse_add_bias_skip_layer_norm(self):
        fusion = FusionBiasSkipLayerNormalization(self)
        fusion.apply()

    def fuse_reshape(self):
        fusion = FusionReshape(self)
        fusion.apply()

    def fuse_embed_layer(self):
        fusion = FusionEmbedLayerNormalization(self)
        fusion.apply()

    def fuse_layer_norm(self):
        fusion = FusionLayerNormalization(self)
        fusion.apply()
        fusion = FusionLayerNormalizationTF(self)
        fusion.apply()

    def fuse_skip_layer_norm(self):
        fusion = FusionSkipLayerNormalization(self)
        fusion.apply()

    def get_graph_inputs_from_node_type(self, op_type: str, input_indices: List[int], casted: bool):
        """
        Get graph inputs that feed into nodes of the given type (like EmbedLayerNormalization or Attention).
        Returns a list of graph input names, filtered by whether the input is casted or not.
        """
        graph_inputs = []
        output_name_to_node = self.output_name_to_node()
        nodes = self.get_nodes_by_op_type(op_type)
        for node in nodes:
            bert_inputs = [node.input[i] for i in input_indices if i < len(node.input)]
            for bert_input in bert_inputs:
                if self.find_graph_input(bert_input):
                    if not casted:
                        graph_inputs.append(bert_input)
                elif bert_input in output_name_to_node:
                    parent = output_name_to_node[bert_input]
                    if parent.op_type == 'Cast' and self.find_graph_input(parent.input[0]) is not None:
                        if casted:
                            graph_inputs.append(parent.input[0])
        return graph_inputs

    def get_graph_inputs_from_fused_nodes(self, casted: bool):
        inputs = self.get_graph_inputs_from_node_type('EmbedLayerNormalization', [0, 1, 7], casted)
        inputs += self.get_graph_inputs_from_node_type('Attention', [3], casted)
        return inputs

    def change_input_to_int32(self):
        original_opset_version = self.model.opset_import[0].version
        graph = self.graph()

        new_graph_inputs = []
        casted_bert_graph_inputs = self.get_graph_inputs_from_fused_nodes(casted=True)
        for input in graph.input:
            if input.name in casted_bert_graph_inputs:
                self.utils.remove_cast_int32(input.name)
                int32_input = helper.make_tensor_value_info(input.name, TensorProto.INT32,
                                                            self.tensor_shape_to_list(input.type.tensor_type))
                new_graph_inputs.append(int32_input)
            else:
                new_graph_inputs.append(input)

        graph_def = helper.make_graph(graph.node,
                                      'int32 inputs',
                                      new_graph_inputs,
                                      graph.output,
                                      initializer=graph.initializer,
                                      value_info=graph.value_info)

        self.model = helper.make_model(graph_def, producer_name='onnxruntime-tools')

        # Restore opset version.
        self.model.opset_import[0].version = original_opset_version

    def use_dynamic_axes(self, dynamic_batch_dim='batch_size', dynamic_seq_len='max_seq_len'):
        """
        Update input and output shapes to use dynamic axes.
        """
        bert_graph_inputs = self.get_graph_inputs_from_fused_nodes(casted=True) + \
                            self.get_graph_inputs_from_fused_nodes(casted=False)

        for input in self.model.graph.input:
            if input.name in bert_graph_inputs:
                dim_proto = input.type.tensor_type.shape.dim[0]
                dim_proto.dim_param = dynamic_batch_dim
                if dynamic_seq_len is not None:
                    dim_proto = input.type.tensor_type.shape.dim[1]
                    dim_proto.dim_param = dynamic_seq_len

        for output in self.model.graph.output:
            dim_proto = output.type.tensor_type.shape.dim[0]
            dim_proto.dim_param = dynamic_batch_dim

    def preprocess(self):
        self.adjust_reshape_and_expand()

    def adjust_reshape_and_expand(self):
        nodes_to_remove = []
        for node in self.nodes():
            if node.op_type == 'Reshape':
                # Clean up unnecessary reshape nodes.
                # Find reshape nodes with no actual data in the "shape" attribute and remove them.
                reshape_shape = self.get_constant_value(node.input[1])
                if reshape_shape is not None and reshape_shape.size == 0:
                    nodes_to_remove.extend([node])
                    self.replace_input_of_all_nodes(node.output[0], node.input[0])
                    continue

                # Find path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", and simplify
                # the graph by changing the current reshape's input to the output of slice.
                reshape_path = self.match_parent_path(node, ['Expand', 'Expand', 'Reshape', 'Slice'],
                                                      [0, 0, 0, 0], self.output_name_to_node())
                if reshape_path is not None:
                    expand_node = reshape_path[-3]
                    expand_shape_value = self.get_constant_value(expand_node.input[1])

                    reshape_before_expand = reshape_path[-2]
                    shape_value = self.get_constant_value(reshape_before_expand.input[1])

                    slice_node = reshape_path[-1]
                    if expand_shape_value is not None and shape_value is not None and \
                            len(expand_shape_value) == 2 and len(shape_value) == 1 and \
                            expand_shape_value[1] == shape_value[0]:
                        node.input[0] = slice_node.output[0]

        self.remove_nodes(nodes_to_remove)
        logger.info(f"Removed Reshape and Expand count: {len(nodes_to_remove)}")

    def clean_graph(self):
        output_name_to_node = self.output_name_to_node()
        nodes_to_add = []
        nodes_to_remove = []
        for node in self.nodes():
            # Before:
            #  input_ids --> Shape --> Gather(indices=0) --> Unsqueeze ------+
            #          |                                                     |
            #          |                                                     v
            #          +----> Shape --> Gather(indices=1) --> Unsqueeze --> Concat --> ConstantOfShape --> Cast --> EmbedLayerNormalization/ReduceSum
            # After:
            #  input_ids --> Shape --> ConstantOfShape --> Cast --> EmbedLayerNormalization/ReduceSum
            # TODO: merge ConstantOfShape --> Cast into ConstantOfShape (need to update the data type of value)
            op_input_id = {"EmbedLayerNormalization": 1, "ReduceSum": 0, "Attention": 3}
            if node.op_type in op_input_id:
                i = op_input_id[node.op_type]
                parent_nodes = self.match_parent_path(
                    node, ['Cast', 'ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape'],
                    [i, 0, 0, 0, 0, 0], output_name_to_node)
                if parent_nodes is not None:
                    cast, constantOfShape, concat, unsqueeze, gather, shape = parent_nodes
                    if shape.input[0] == self.graph().input[0].name:
                        constantOfShape.input[0] = shape.output[0]
                        output_name_to_node = self.output_name_to_node()

            if node.op_type == 'Attention':
                # Before:
                #   input_ids --> Shape --> ConstantOfShape --> Cast --> ReduceSum --> Attention
                # After:
                #   remove this path, and remove the optional mask_index input of the Attention node.
                parent_nodes = self.match_parent_path(node, ['ReduceSum', 'Cast', 'ConstantOfShape', 'Shape'],
                                                      [3, 0, 0, 0], output_name_to_node)
                if parent_nodes is not None:
                    if parent_nodes[-1].input[0] == self.graph().input[0].name:
                        attention_node = helper.make_node('Attention',
                                                          inputs=node.input[0:len(node.input) - 1],
                                                          outputs=node.output,
                                                          name=node.name + "_remove_mask")
                        attention_node.domain = "com.microsoft"
                        attention_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
                        nodes_to_add.append(attention_node)
                        nodes_to_remove.append(node)
        self.remove_nodes(nodes_to_remove)
        self.add_nodes(nodes_to_add)

    def postprocess(self):
        self.clean_graph()
        self.prune_graph()

    def optimize(self, options: BertOptimizationOptions = None, add_dynamic_axes=False):
        if (options is None) or options.enable_layer_norm:
            self.fuse_layer_norm()

        if (options is None) or options.enable_gelu:
            self.fuse_gelu()

        self.preprocess()

        self.fuse_reshape()

        if (options is None) or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()

        if (options is None) or options.enable_attention:
            if options is not None:
                self.attention_mask.set_mask_format(options.attention_mask_format)
            self.fuse_attention()

        if (options is None) or options.enable_embed_layer_norm:
            self.fuse_embed_layer()

        # Post-processing like removing extra reshape nodes.
        self.postprocess()

        # Bias fusion is done after postprocess to avoid an extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization.
        if (options is None) or options.enable_bias_gelu:
            # Fuse Gelu and Add Bias before it.
            self.fuse_bias_gelu(is_fastgelu=True)
            self.fuse_bias_gelu(is_fastgelu=False)

        if (options is None) or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization and Add Bias before it.
            self.fuse_add_bias_skip_layer_norm()

        if options is not None and options.enable_gelu_approximation:
            self.gelu_approximation()

        self.remove_unused_constant()

        # Use symbolic batch dimension in input and output.
        if add_dynamic_axes:
            self.use_dynamic_axes()

        logger.info(f"opset version: {self.model.opset_import[0].version}")

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        op_count = {}
        ops = [
            'EmbedLayerNormalization', 'Attention', 'Gelu', 'FastGelu', 'BiasGelu', 'LayerNormalization',
            'SkipLayerNormalization'
        ]
        for op in ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)
        logger.info(f"Optimized operators: {op_count}")
        return op_count

    def is_fully_optimized(self):
        """
        Returns True when the model is fully optimized.
        """
        op_count = self.get_fused_operator_statistics()
        embed = op_count['EmbedLayerNormalization']
        attention = op_count['Attention']
        gelu = op_count['Gelu'] + op_count['BiasGelu'] + op_count['FastGelu']
        layer_norm = op_count['LayerNormalization'] + op_count['SkipLayerNormalization']
        is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention)

        if layer_norm == 0:
            logger.debug("Layer Normalization not fused")

        if gelu == 0:
            logger.debug("Gelu/FastGelu not fused")

        if embed == 0:
            logger.debug("Embed Layer not fused")

        if attention == 0:
            logger.warning("Attention not fused")

        return is_perfect
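# Illustrative driver, not part of the original source: an end-to-end use of
# BertOnnxModel. The file paths and the head/hidden sizes are placeholders.
def _optimize_bert_example():
    import onnx

    model = onnx.load("bert.onnx")  # hypothetical input path
    bert_model = BertOnnxModel(model, num_heads=12, hidden_size=768)
    bert_model.optimize()

    # Logs fused-operator counts and checks the expected fusion ratios.
    if not bert_model.is_fully_optimized():
        logger.warning("Some operators were not fused")

    onnx.save(bert_model.model, "bert_opt.onnx")  # hypothetical output path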
def __init__(self, model: OnnxModel, description='no mask'):
    super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
    self.utils = FusionUtils(model)
class FusionGptAttentionPastBase(Fusion):
    """Base class for GPT Attention Fusion with past state
    """
    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", "LayerNormalization", "with past")
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {}  # map from the name of an attention mask to the name casted to int32

    def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node):
        # Pattern 1:
        #                        {past}
        #                       /      \
        #                      /        \
        #   Gather(axes=0, indices=0)  Gather(indices=1)
        #             |                      |
        #   Transpose (perm=0,1,3,2)         |
        #             |                      |
        #          Concat_k               Concat_v
        #             |                     /
        #   Transpose (perm=0,1,3,2)       /
        #             |                   /
        #         Unsqueeze        Unsqueeze
        #              \              /
        #               \            /
        #                   Concat
        #                     |
        #                 {present}
        gather = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather.op_type != 'Gather':
            logger.debug("match_past_pattern_1: expect Gather for past")
            return None

        if not self.model.find_constant_input(gather, 1) == 1:
            logger.debug("match_past_pattern_1: expect indices=1 for Gather of past")
            return None
        past = gather.input[0]

        parent = self.model.get_parent(concat_k, 0, output_name_to_node)
        if parent.op_type == 'Gather':
            gather_past_k = parent
        else:
            past_k_nodes = self.model.match_parent_path(concat_k, ['Transpose', 'Gather'], [0, 0])
            if past_k_nodes is None:
                logger.debug("match_past_pattern_1: failed to match Transpose and Gather")
                return None
            gather_past_k = past_k_nodes[-1]

        if not self.model.find_constant_input(gather_past_k, 0) == 1:
            logger.debug("match_past_pattern_1: expect indices=0 for Gather k of past")
            return None
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.debug("match_past_pattern_1: expect past to be same")
            return None

        return past

    def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node):
        # Pattern 2:
        #        Split (QKV)
        #       /    |     |
        #      /     |     +----------------------+
        #            |                             |
        #            |         {past}              |
        #            |           |                 |
        #        Reshape       Split            Reshape
        #            |        /     \              |
        #     Transpose_k  Squeeze  Squeeze   Transpose_v
        #            |        |        \          /
        #            +--------+         \        /
        #                |               \      /
        #             Concat_k           Concat_v
        #                |                  |
        #            Unsqueeze          Unsqueeze
        #                 \                /
        #                      Concat
        #                        |
        #                    {present}
        squeeze = self.model.get_parent(concat_v, 0, output_name_to_node)
        if squeeze.op_type != 'Squeeze':
            logger.debug("match_past_pattern_2: expect Squeeze as parent of concat_v")
            return None

        split = self.model.get_parent(squeeze, 0, output_name_to_node)
        if split.op_type != "Split":
            logger.debug("match_past_pattern_2: expect Split for past path")
            return None

        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(squeeze, 'axes', [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not FusionUtils.check_node_attribute(split, 'split', [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None
        else:
            if not self.utils.check_node_input_value(squeeze, 1, [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not self.utils.check_node_input_value(split, 1, [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None

        if not FusionUtils.check_node_attribute(split, 'axis', 0, default_value=0):
            logger.debug("match_past_pattern_2: attribute axis of Split is not expected in past path")
            return None
        past = split.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k, ['Squeeze', 'Split'], [0, 0])
        if past_k_nodes is None:
            logger.debug("match_past_pattern_2: failed to match past_k_nodes path")
            return None
        past_k = past_k_nodes[-1].input[0]

        if past != past_k:
            logger.info("match_past_pattern_2: expect past to be same")
            return None

        return past

    def match_present(self, concat_v, input_name_to_nodes):
        unsqueeze_present_v = self.model.find_first_child_by_type(concat_v,
                                                                  'Unsqueeze',
                                                                  input_name_to_nodes,
                                                                  recursive=False)
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return None

        concat_present = self.model.find_first_child_by_type(unsqueeze_present_v,
                                                             'Concat',
                                                             input_name_to_nodes,
                                                             recursive=False)
        if not concat_present:
            logger.info("expect concat for present")
            return None

        present = concat_present.output[0]
        return present

    def cast_attention_mask(self, input_name):
        if input_name in self.casted_attention_mask:
            attention_mask_input_name = self.casted_attention_mask[input_name]
        elif self.model.find_graph_input(input_name):
            casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        else:
            attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        return attention_mask_input_name
class BertOnnxModel(OnnxModel):
    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize BERT ONNX Model.

        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0, which means the parameter is detected automatically.
            hidden_size (int, optional): hidden dimension. Defaults to 0, which means the parameter is detected automatically.
        """
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)

        super().__init__(model)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
        self.utils = FusionUtils(self)

    def fuse_attention(self):
        self.attention_fusion.apply()

    def fuse_gelu(self):
        fusion = FusionGelu(self)
        fusion.apply()
        fusion = FusionFastGelu(self)
        fusion.apply()

    def fuse_bias_gelu(self, is_fastgelu):
        fusion = FusionBiasGelu(self, is_fastgelu)
        fusion.apply()

    def gelu_approximation(self):
        fusion = FusionGeluApproximation(self)
        fusion.apply()

    def fuse_add_bias_skip_layer_norm(self):
        fusion = FusionBiasSkipLayerNormalization(self)
        fusion.apply()

    def fuse_reshape(self):
        fusion = FusionReshape(self)
        fusion.apply()

    def fuse_shape(self):
        fusion = FusionShape(self)
        fusion.apply()

    def fuse_embed_layer(self):
        fusion = FusionEmbedLayerNormalization(self)
        fusion.apply()

    def fuse_layer_norm(self):
        fusion = FusionLayerNormalization(self)
        fusion.apply()
        fusion = FusionLayerNormalizationTF(self)
        fusion.apply()

    def fuse_skip_layer_norm(self):
        fusion = FusionSkipLayerNormalization(self)
        fusion.apply()

    def get_graph_inputs_from_node_type(self, op_type: str, input_indices: List[int], casted: bool):
        """
        Get graph inputs that feed into nodes of the given type (like EmbedLayerNormalization or Attention).
        Returns a list of graph input names, filtered by whether the input is casted or not.
        """
        graph_inputs = []
        output_name_to_node = self.output_name_to_node()
        nodes = self.get_nodes_by_op_type(op_type)
        for node in nodes:
            bert_inputs = [node.input[i] for i in input_indices if i < len(node.input)]
            for bert_input in bert_inputs:
                if self.find_graph_input(bert_input):
                    if not casted:
                        graph_inputs.append(bert_input)
                elif bert_input in output_name_to_node:
                    parent = output_name_to_node[bert_input]
                    if parent.op_type == "Cast" and self.find_graph_input(parent.input[0]) is not None:
                        if casted:
                            graph_inputs.append(parent.input[0])
        return graph_inputs

    def get_graph_inputs_from_fused_nodes(self, casted: bool):
        inputs = self.get_graph_inputs_from_node_type("EmbedLayerNormalization", [0, 1, 7], casted)
        inputs += self.get_graph_inputs_from_node_type("Attention", [3], casted)
        return inputs

    def change_graph_input_type(
        self,
        graph: GraphProto,
        graph_input: ValueInfoProto,
        new_type: int = TensorProto.INT32,
    ):
        """Change graph input type, and add a Cast node if needed.

        Args:
            graph (GraphProto): graph
            graph_input (ValueInfoProto): input of the graph
            new_type (int, optional): new data type. Defaults to TensorProto.INT32.

        Returns:
            NodeProto: the new Cast node that was added, or None if no Cast node was added.
            List[NodeProto]: Cast nodes that have been removed.
        """
        assert isinstance(graph, GraphProto)
        assert isinstance(graph_input, ValueInfoProto)
        assert self.find_graph_input(graph_input.name)

        if graph_input.type.tensor_type.elem_type == int(new_type):
            return None, []

        new_cast_node = None
        nodes_to_remove = []

        input_name_to_nodes = self.input_name_to_nodes()
        if graph_input.name in input_name_to_nodes:
            nodes = input_name_to_nodes[graph_input.name]

            # For children that are not Cast nodes, insert a Cast node to convert int32 to the original data type.
            nodes_not_cast = [node for node in nodes if node.op_type != "Cast"]
            if nodes_not_cast:
                node_name = self.create_node_name("Cast")
                output_name = node_name + "_" + graph_input.name
                new_value_info = graph.value_info.add()
                new_value_info.CopyFrom(graph_input)
                new_value_info.name = output_name
                new_cast_node = helper.make_node(
                    "Cast",
                    [graph_input.name],
                    [output_name],
                    to=int(graph_input.type.tensor_type.elem_type),
                    name=node_name,
                )
                graph.node.extend([new_cast_node])

                for node in nodes_not_cast:
                    OnnxModel.replace_node_input(node, graph_input.name, output_name)

            # For children that are Cast nodes, there is no need to insert a Cast.
            # When a child casts to int32, we can remove that Cast node since the input type is int32 now.
            nodes_cast = [node for node in nodes if node.op_type == "Cast"]
            for node in nodes_cast:
                if OnnxModel.get_node_attribute(node, "to") == int(new_type):
                    self.replace_input_of_all_nodes(node.output[0], graph_input.name)
                    if not self.find_graph_output(node.output[0]):
                        nodes_to_remove.append(node)
            if nodes_to_remove:
                self.remove_nodes(nodes_to_remove)

        graph_input.type.tensor_type.elem_type = int(new_type)
        return new_cast_node, nodes_to_remove

    def change_graph_inputs_to_int32(self):
        """Change the data type of all graph inputs to int32, adding Cast nodes if needed."""
        graph = self.graph()
        add_cast_count = 0
        remove_cast_count = 0
        for graph_input in graph.input:
            new_node, removed_nodes = self.change_graph_input_type(graph, graph_input, TensorProto.INT32)
            if new_node:
                add_cast_count += 1
            remove_cast_count += len(removed_nodes)
        logger.info(
            f"Graph inputs are changed to int32. Added {add_cast_count} Cast nodes, and removed {remove_cast_count} Cast nodes."
        )

    def use_dynamic_axes(self, dynamic_batch_dim="batch_size", dynamic_seq_len="max_seq_len"):
        """
        Update input and output shapes to use dynamic axes.
        """
        bert_graph_inputs = self.get_graph_inputs_from_fused_nodes(casted=True) + \
                            self.get_graph_inputs_from_fused_nodes(casted=False)

        for input in self.model.graph.input:
            if input.name in bert_graph_inputs:
                dim_proto = input.type.tensor_type.shape.dim[0]
                dim_proto.dim_param = dynamic_batch_dim
                if dynamic_seq_len is not None:
                    dim_proto = input.type.tensor_type.shape.dim[1]
                    dim_proto.dim_param = dynamic_seq_len

        for output in self.model.graph.output:
            dim_proto = output.type.tensor_type.shape.dim[0]
            dim_proto.dim_param = dynamic_batch_dim

    def preprocess(self):
        self.adjust_reshape_and_expand()

    def adjust_reshape_and_expand(self):
        nodes_to_remove = []
        for node in self.nodes():
            if node.op_type == "Reshape":
                # Clean up unnecessary reshape nodes.
                # Find reshape nodes with no actual data in the "shape" attribute and remove them.
                reshape_shape = self.get_constant_value(node.input[1])
                if reshape_shape is not None and reshape_shape.size == 0:
                    nodes_to_remove.extend([node])
                    self.replace_input_of_all_nodes(node.output[0], node.input[0])
                    continue

                # Find path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", and simplify
                # the graph by changing the current reshape's input to the output of slice.
                reshape_path = self.match_parent_path(
                    node,
                    ["Expand", "Expand", "Reshape", "Slice"],
                    [0, 0, 0, 0],
                    self.output_name_to_node(),
                )
                if reshape_path is not None:
                    expand_node = reshape_path[-3]
                    expand_shape_value = self.get_constant_value(expand_node.input[1])

                    reshape_before_expand = reshape_path[-2]
                    shape_value = self.get_constant_value(reshape_before_expand.input[1])

                    slice_node = reshape_path[-1]
                    if (expand_shape_value is not None and shape_value is not None
                            and len(expand_shape_value) == 2 and len(shape_value) == 1
                            and expand_shape_value[1] == shape_value[0]):
                        node.input[0] = slice_node.output[0]

        if nodes_to_remove:
            self.remove_nodes(nodes_to_remove)
            logger.info(f"Removed Reshape and Expand count: {len(nodes_to_remove)}")

    def clean_graph(self):
        output_name_to_node = self.output_name_to_node()
        nodes_to_remove = []
        for node in self.nodes():
            # Before:
            #  input_ids --> Shape --> Gather(indices=0) --> Unsqueeze ------+
            #          |                                                     |
            #          |                                                     v
            #          +----> Shape --> Gather(indices=1) --> Unsqueeze --> Concat --> ConstantOfShape --> Cast --> EmbedLayerNormalization/ReduceSum
            # After:
            #  input_ids --> Shape --> ConstantOfShape --> Cast --> EmbedLayerNormalization/ReduceSum
            # TODO: merge ConstantOfShape --> Cast into ConstantOfShape (need to update the data type of value)
            op_input_id = {"EmbedLayerNormalization": 1, "ReduceSum": 0, "Attention": 3}
            if node.op_type in op_input_id:
                i = op_input_id[node.op_type]
                parent_nodes = self.match_parent_path(
                    node,
                    ["Cast", "ConstantOfShape", "Concat", "Unsqueeze", "Gather", "Shape"],
                    [i, 0, 0, 0, 0, 0],
                    output_name_to_node,
                )
                if parent_nodes is not None:
                    (cast, constantOfShape, concat, unsqueeze, gather, shape) = parent_nodes
                    if shape.input[0] == self.graph().input[0].name:
                        constantOfShape.input[0] = shape.output[0]
                        output_name_to_node = self.output_name_to_node()

            if node.op_type == "Attention":
                # Before:
                #   input_ids --> Shape --> ConstantOfShape --> Cast --> ReduceSum --> Attention
                # After:
                #   remove this path, and remove the optional mask_index input of the Attention node.
                parent_nodes = self.match_parent_path(
                    node,
                    ["ReduceSum", "Cast", "ConstantOfShape", "Shape"],
                    [3, 0, 0, 0],
                    output_name_to_node,
                )
                if parent_nodes is not None:
                    if parent_nodes[-1].input[0] == self.graph().input[0].name:
                        attention_node = helper.make_node(
                            "Attention",
                            inputs=node.input[0:len(node.input) - 1],
                            outputs=node.output,
                            name=node.name + "_remove_mask",
                        )
                        attention_node.domain = "com.microsoft"
                        attention_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
                        self.add_node(attention_node, self.get_graph_by_node(attention_node).name)
                        nodes_to_remove.append(node)
        self.remove_nodes(nodes_to_remove)

    def postprocess(self):
        self.clean_graph()
        self.prune_graph()

    def optimize(self, options: FusionOptions = None, add_dynamic_axes=False):
        # Remove cast nodes that have the same input and output data type, based on symbolic shape inference.
        self.utils.remove_useless_cast_nodes()

        if (options is None) or options.enable_layer_norm:
            self.fuse_layer_norm()

        if (options is None) or options.enable_gelu:
            self.fuse_gelu()

        self.preprocess()

        self.fuse_reshape()

        if (options is None) or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()

        if (options is None) or options.enable_attention:
            if options is not None:
                self.attention_mask.set_mask_format(options.attention_mask_format)
            self.fuse_attention()

        self.fuse_shape()

        if (options is None) or options.enable_embed_layer_norm:
            self.fuse_embed_layer()

        # Remove reshape nodes that have the same input and output shape, based on symbolic shape inference.
        self.utils.remove_useless_reshape_nodes()

        self.postprocess()

        # Bias fusion is done after postprocess to avoid an extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization.
        if (options is None) or options.enable_bias_gelu:
            # Fuse Gelu and Add Bias before it.
            self.fuse_bias_gelu(is_fastgelu=True)
            self.fuse_bias_gelu(is_fastgelu=False)

        if (options is None) or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization and Add Bias before it.
            self.fuse_add_bias_skip_layer_norm()

        if options is not None and options.enable_gelu_approximation:
            self.gelu_approximation()

        self.remove_unused_constant()

        # Use symbolic batch dimension in input and output.
        if add_dynamic_axes:
            self.use_dynamic_axes()

        logger.info(f"opset version: {self.get_opset_version()}")

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        op_count = {}
        ops = [
            "EmbedLayerNormalization",
            "Attention",
            "Gelu",
            "FastGelu",
            "BiasGelu",
            "LayerNormalization",
            "SkipLayerNormalization",
        ]
        for op in ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)
        logger.info(f"Optimized operators: {op_count}")
        return op_count

    def is_fully_optimized(self):
        """
        Returns True when the model is fully optimized.
        """
        op_count = self.get_fused_operator_statistics()
        embed = op_count["EmbedLayerNormalization"]
        attention = op_count["Attention"]
        gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"]
        layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"]
        is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention)

        if layer_norm == 0:
            logger.debug("Layer Normalization not fused")

        if gelu == 0:
            logger.debug("Gelu/FastGelu not fused")

        if embed == 0:
            logger.debug("Embed Layer not fused")

        if attention == 0:
            logger.warning("Attention not fused")

        return is_perfect
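# Illustrative sketch, not part of the original source: converting graph inputs
# to int32 in isolation after optimization. The model path is a placeholder.
def _int32_inputs_example():
    import onnx

    bert_model = BertOnnxModel(onnx.load("bert_opt.onnx"))  # hypothetical path
    # Rewrites each graph input to int32: removes Cast-to-int32 children, and
    # inserts a Cast back to the original type for any other consumer.
    bert_model.change_graph_inputs_to_int32()
    bert_model.use_dynamic_axes()  # optional: symbolic batch/sequence dims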
def __init__(self,
             model: OnnxModel,
             name: str = "EmbedLayerNormalization(no mask)",
             search_op_types="SkipLayerNormalization"):
    super().__init__(model, name, search_op_types)
    self.utils = FusionUtils(model)
class FusionEmbedLayerNoMask(Fusion):
    """
    Embed Layer Normalization will fuse embeddings and mask processing into one node.
    The embeddings before conversion:

    (input_ids) --------> Gather ----------+       (segment_ids)
       |                                   |            |
       |                                   v            v
       +--> Shape --> Expand -> Gather--->Add         Gather
       |                ^                  |            |
       |                |                  v            v
       +---(optional graph)            SkipLayerNormalization

    The optional graph is used to generate a position list (0, 1, ...) per batch. It can be a constant in some models.

    (input_ids) --> Gather -----+          Slice
                                |            |
                                v            v
    (segment_ids)--> Gather --->Add       Reshape
                                |            |
                                v            v
                           SkipLayerNormalization
    """
    def __init__(self, model: OnnxModel, description='no mask'):
        super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
        self.utils = FusionUtils(model)
        self.attention = None

    def match_segment_path(self, normalize_node, input_name_to_nodes, output_name_to_node, input_ids_cast_node):
        segment_ids = None
        segment_embedding_gather = None

        segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])

        if segment_embedding_path is None:
            segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
            if segment_embedding_path is None:
                logger.info("Segment embedding is not found. Embed layer cannot be fused.")
                return
            _, segment_embedding_gather = segment_embedding_path
        else:
            segment_embedding_gather = segment_embedding_path[0]

        segment_ids = segment_embedding_gather.input[1]

        self.nodes_to_remove.extend(segment_embedding_path)

        if self.model.find_graph_input(segment_ids):
            casted, segment_ids = self.utils.cast_graph_input_to_int32(segment_ids)
        else:
            segment_ids, segment_ids_cast_node = self.utils.cast_input_to_int32(segment_ids)

            # Cast might be removed by OnnxRuntime.
            _, segment_id_path, _ = self.model.match_parent_paths(
                segment_ids_cast_node,
                [(['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape', 'Cast'], [0, 0, 1, 0, 0, 0]),
                 (['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [0, 0, 1, 0, 0])],
                output_name_to_node)

            if segment_id_path and input_ids_cast_node and input_ids_cast_node.input[0] == segment_id_path[-1].input[0]:
                logger.debug("Simplify segment id path...")
                self.model.add_node(
                    helper.make_node('Shape', inputs=[input_ids_cast_node.input[0]], outputs=["input_shape"]))
                self.model.add_node(
                    helper.make_node('ConstantOfShape',
                                     inputs=["input_shape"],
                                     outputs=["zeros_for_input_shape"],
                                     value=helper.make_tensor("value", onnx.TensorProto.INT32, [1], [1])))
                segment_ids = "zeros_for_input_shape"

        return segment_ids, segment_embedding_gather

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        is_distill = False

        if self.model.match_parent_path(node, ['Add', 'Gather'], [0, 0]) is None and \
                self.model.match_parent_path(node, ['Gather'], [0]) is None:
            logger.debug(
                "Failed to match path SkipLayerNormalization[0] <-- Add <-- Gather or SkipLayerNormalization[0] <-- Gather")
            return

        self.attention = self.model.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False)
        if self.attention is None:
            # In case user disables attention fusion, check whether the subgraph looks like Attention.
            if node.output[0] not in input_name_to_nodes:
                return
            children = input_name_to_nodes[node.output[0]]
            children_types = sorted([child.op_type for child in children])
            if children_types != ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization'] and \
                    children_types != ['MatMul', 'MatMul', 'MatMul', 'Shape', 'Shape', 'SkipLayerNormalization']:
                logger.debug("No Attention-like subgraph in children of SkipLayerNormalization")
                return

        # Assume the order of embeddings is word_embedding + position_embedding + segment_embedding.
        normalize_node = node
        add_node = None
        word_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 0])
        if word_embedding_path is not None:
            add_node, word_embedding_gather = word_embedding_path
        else:
            word_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [0])
            if word_embedding_path is not None:
                word_embedding_gather = word_embedding_path[0]
                is_distill = True
            else:
                logger.info("Word embedding path is not found. Embed layer cannot be fused.")
                return

        input_ids = word_embedding_gather.input[1]

        position_embedding_expand = None
        position_embedding_shape = None

        position_embedding_path = self.model.match_parent_path(normalize_node, ['Gather', 'Expand'], [1, 1])  # for DistilBert
        if position_embedding_path is not None:
            position_embedding_weight_node, position_embedding_expand = position_embedding_path
        else:
            position_embedding_path = self.model.match_parent_path(normalize_node, ['Reshape', 'Slice'], [1, 0])
            if position_embedding_path is not None:
                _, position_embedding_weight_node = position_embedding_path
            else:
                position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand', 'Shape'],
                                                                       [1, 1, 1])
                if position_embedding_path is not None:
                    position_embedding_weight_node, position_embedding_expand, position_embedding_shape = position_embedding_path
                else:
                    position_embedding_path = self.model.match_parent_path(
                        add_node, ['Gather', 'Expand', 'Concat', 'Unsqueeze', 'Gather', 'Shape'],
                        [1, 1, 1, 1, 0, 0])
                    if position_embedding_path is not None:
                        position_embedding_weight_node, position_embedding_expand, _, _, _, position_embedding_shape = position_embedding_path
                    else:
                        # Here we will not try to get an exact match. Instead, we only try to identify the position embedding weights.
                        position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand'],
                                                                               [1, 1])
                        if position_embedding_path is not None:
                            position_embedding_weight_node, position_embedding_expand = position_embedding_path
                        else:
                            logger.info("Position embedding path is not found. Embed layer cannot be fused.")
                            return

        if position_embedding_shape is not None and position_embedding_shape.input[0] != input_ids:
            logger.info("position and word embedding are expected to be applied on the same input")
            return

        if position_embedding_expand and position_embedding_shape:
            input_parent = self.model.get_parent(position_embedding_shape, 0, output_name_to_node)
            subgraph_nodes = self.model.get_parent_subgraph_nodes(position_embedding_expand,
                                                                  [input_parent] if input_parent else [],
                                                                  output_name_to_node)
            self.nodes_to_remove.extend(subgraph_nodes)

        self.nodes_to_remove.extend(word_embedding_path)
        self.nodes_to_remove.extend(position_embedding_path)
        self.nodes_to_remove.extend([normalize_node])

        # Cast input_ids and segment_ids to int32.
        input_ids_cast_node = None
        if self.model.find_graph_input(input_ids):
            casted, input_ids = self.utils.cast_graph_input_to_int32(input_ids)
        else:
            input_ids, input_ids_cast_node = self.utils.cast_input_to_int32(input_ids)

        node_name = self.model.create_node_name('EmbedLayerNormalization')
        output_name = node_name + "_output"

        embed_node_inputs = None
        if is_distill == False:
            segment_path = self.match_segment_path(normalize_node, input_name_to_nodes, output_name_to_node,
                                                   input_ids_cast_node)
            if segment_path is None:
                return
            segment_ids, segment_embedding_gather = segment_path

            embed_node_inputs = [
                input_ids, segment_ids, word_embedding_gather.input[0], position_embedding_weight_node.input[0],
                segment_embedding_gather.input[0], normalize_node.input[2], normalize_node.input[3]  # gamma and beta
            ]
        else:
            from packaging.version import Version
            import onnxruntime
            if Version(onnxruntime.__version__) <= Version("1.4.0"):
                logger.warning(
                    'Please install onnxruntime with version > 1.4.0 for embedlayer fusion support for distilbert')
                return

            embed_node_inputs = [
                input_ids, '', word_embedding_gather.input[0], position_embedding_weight_node.input[0], '',
                normalize_node.input[2], normalize_node.input[3]  # gamma and beta
            ]

        embed_node = helper.make_node('EmbedLayerNormalization',
                                      embed_node_inputs,
                                      outputs=[node_name + "_output", node_name + "_dummy_mask_index"],
                                      name=node_name)

        embed_node.domain = "com.microsoft"

        # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
        for att in normalize_node.attribute:
            if att.name == 'epsilon':
                embed_node.attribute.extend([att])

        # Set default value to 1e-12 if no attribute is found.
        # OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later.
        if len(embed_node.attribute) == 0:
            embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0E-12)])

        self.model.replace_input_of_all_nodes(normalize_node.output[0], output_name)
        self.nodes_to_add.append(embed_node)
class FusionShape(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "Shape", "Concat")
        self.utils = FusionUtils(model)
        self.shape_infer = None
        self.shape_infer_done = False

    def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]:
        if tensor_proto.type.tensor_type.HasField("shape"):
            return len(tensor_proto.type.tensor_type.shape.dim)
        else:
            return None

    def get_dimensions(self, input_name: str) -> Union[int, None]:
        graph_input = self.model.find_graph_input(input_name)
        if graph_input:
            return self.get_dimensions_from_tensor_proto(graph_input)

        if not self.shape_infer_done:
            self.shape_infer = self.model.infer_runtime_shape({}, update=True)
            self.shape_infer_done = True

        if self.shape_infer is not None:
            return self.get_dimensions_from_tensor_proto(self.shape_infer.known_vi_[input_name])

        return None

    def fuse(
        self,
        concat_node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """
        Simplify subgraph like

                      (2d_input)
                       /      \\
                   Shape      Shape
                     |          |
           Gather(indices=0)  Gather(indices=1)
                     |          |
           Unsqueeze(axes=0)  Unsqueeze(axes=0)
                      \\        /
                        Concat
                          |

        into (2d_input) --> Shape -->
        """
        opset_version = self.model.get_opset_version()

        inputs = len(concat_node.input)
        root = None
        shape_output = None
        for i in range(inputs):
            path = self.model.match_parent_path(
                concat_node,
                ["Unsqueeze", "Gather", "Shape"],
                [i, 0, 0],
                output_name_to_node,
            )
            if path is None:
                return

            unsqueeze, gather, shape = path
            if i == 0:
                shape_output = shape.output[0]
            if root is None:
                root = shape.input[0]
                if self.get_dimensions(root) != inputs:
                    return
            elif shape.input[0] != root:
                return

            if not FusionUtils.check_node_attribute(unsqueeze, "axis", 0, default_value=0):
                return

            if opset_version < 13:
                if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
                    return
            else:
                if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
                    return

            value = self.model.get_constant_value(gather.input[1])

            from numpy import ndarray
            if not (isinstance(value, ndarray) and value.size == 1 and value.item() == i):
                return

        if self.model.find_graph_output(concat_node.output[0]) is None:
            self.model.replace_input_of_all_nodes(concat_node.output[0], shape_output)
            self.fused_count += 1
            self.prune_graph = True
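# Illustrative sketch, not part of the original source: running FusionShape
# standalone. `onnx_model` is a hypothetical OnnxModel instance; apply() is the
# driver inherited from the Fusion base class, which visits each Concat node
# and calls fuse() on it.
def _fuse_shape_example(onnx_model: OnnxModel):
    fusion = FusionShape(onnx_model)
    fusion.apply()
    # Remove the now-dead Gather/Unsqueeze subgraphs left behind by the fusion.
    onnx_model.prune_graph()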
class FusionEmbedLayerNoMask(Fusion):
    """
    Fuse embedding layer into one node (EmbedLayerNormalization).
    It supports the following model types: BERT, DistilBert, ALBert.
    """
    def __init__(self, model: OnnxModel, description: str = 'no mask'):
        super().__init__(model, "EmbedLayerNormalization", ["LayerNormalization", "SkipLayerNormalization"],
                         description)
        self.utils = FusionUtils(model)
        self.shape_infer_helper = self.model.infer_runtime_shape({}, update=True)
        # The following will be reset in each fuse call of FusionEmbedLayerNormalization
        self.attention = None
        self.embed_node = None

    def match_two_gather(self, add: NodeProto) -> Union[None, Tuple[NodeProto, NodeProto]]:
        gather_0_path = self.model.match_parent_path(add, ['Gather'], [0])
        if gather_0_path is None:
            return None

        gather_1_path = self.model.match_parent_path(add, ['Gather'], [1])
        if gather_1_path is None:
            return None

        return gather_0_path[0], gather_1_path[0]

    def check_attention_subgraph(self, layernorm: NodeProto, input_name_to_nodes: Dict[str, List[NodeProto]],
                                 is_distil_bert: bool) -> bool:
        """Check that LayerNormalization has a child of Attention node or subgraph like Attention.

        Args:
            layernorm (NodeProto): LayerNormalization node
            input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
            is_distil_bert (bool): whether it is DistilBert or not

        Returns:
            bool: whether there is an Attention node or a subgraph like Attention
        """
        self.attention = self.model.find_first_child_by_type(layernorm,
                                                             'Attention',
                                                             input_name_to_nodes,
                                                             recursive=False)
        if self.attention is None:
            # In case user disables attention fusion, check whether the subgraph looks like Attention.
            if layernorm.output[0] not in input_name_to_nodes:
                return False
            children = input_name_to_nodes[layernorm.output[0]]

            # For Albert, there is MatMul+Add after the embedding layer, before attention.
            if len(children) == 1 and children[0].op_type == "MatMul" and \
                    children[0].output[0] in input_name_to_nodes:
                grandchildren = input_name_to_nodes[children[0].output[0]]
                if len(grandchildren) == 1 and grandchildren[0].op_type == "Add" and \
                        grandchildren[0].output[0] in input_name_to_nodes:
                    nodes = input_name_to_nodes[grandchildren[0].output[0]]
                    for node in nodes:
                        if node.op_type == "Attention":
                            self.attention = node
                            return True
                    children_types = sorted([child.op_type for child in nodes])
            else:
                children_types = sorted([child.op_type for child in children])

            # Two Shape nodes might be merged by ORT.
            if is_distil_bert:
                # SkipLayerNormalization might exist when the model has been optimized by ORT first.
                if children_types != ['MatMul', 'MatMul', 'MatMul', 'Shape', 'SkipLayerNormalization'] and \
                        children_types != ['Add', 'MatMul', 'MatMul', 'MatMul', 'Shape', 'Shape'] and \
                        children_types != ['Add', 'MatMul', 'MatMul', 'MatMul', 'Shape']:
                    logger.debug("No Attention-like subgraph in children of LayerNormalization")
                    return False
            else:
                if children_types != ['Add', 'MatMul', 'MatMul', 'MatMul'] and \
                        children_types != ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization']:
                    logger.debug("No Attention-like subgraph in children of LayerNormalization")
                    return False
        return True

    def match_position_embedding_distilbert(self, position_embedding_gather, input_ids, output_name_to_node):
        """Match position embedding path from input_ids to Gather for DistilBert.

        Pattern is like the following:

              (input_ids)
                  |
                Shape
                |    \\
                |   Gather (indices=1)
                |        |
                |      Cast (optional)
                |        |
                |      Range (start=0, end=*, delta=1)
                |        |
                |     Unsqueeze
                |      /
               Expand
                  |
                Gather
        """
        path1 = self.model.match_parent_path(position_embedding_gather, ['Expand', 'Shape'], [1, 1])
        if path1 is None:
            return False

        expand, shape = path1
        if shape.input[0] != input_ids:
            return False

        _, path2, _ = self.model.match_parent_paths(
            expand, [(['Unsqueeze', 'Range', 'Cast', 'Gather', 'Shape'], [0, 0, 1, 0, 0]),
                     (['Unsqueeze', 'Range', 'Gather', 'Shape'], [0, 0, 1, 0])], output_name_to_node)
        if path2 is None:
            return False

        range_node = path2[1]
        if not (self.utils.check_node_input_value(range_node, 0, 0)
                and self.utils.check_node_input_value(range_node, 2, 1)):
            return False

        gather_node = path2[-2]
        if not (self.utils.check_node_input_value(gather_node, 1, 1)):
            return False

        shape_node = path2[-1]
        if shape_node.input[0] != input_ids:
            return False

        return True

    def match_position_embedding_roberta(self, position_embedding_gather, input_ids, output_name_to_node):
        """Match position embedding path from input_ids to Gather for Roberta.

        Roberta Embedding Layer Pattern (* is optional since it might be removed by ORT, ? is the padding word id):
          (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Mul -- Cast(to=7) -- Add(B=1) -- Cast(to=7)* --> Gather
                      |                                                                              ^
                      v                                                                              |
                      +------------------------------------------------------------------------------+

        Roberta new pattern from transformers v4.9:
          (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Add(B=0) -- Mul -- Cast(to=7) -- Add(B=1) --> Gather
                      |                                                                                          ^
                      v                                                                                          |
                      +-------------------------------------------------------------------------------------------+

        start_node = position_embedding_gather
        start_index = 1

        # match optional Cast node.
        parent = self.model.get_parent(start_node, start_index, output_name_to_node)
        if parent is None:
            return
        if parent.op_type == "Cast":
            if OnnxModel.get_node_attribute(parent, "to") != 7:
                return
            start_node = parent
            start_index = 0

        i, path, return_indices = self.model.match_parent_paths(
            start_node,
            [(['Add', 'Cast', 'Mul', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0]),
             (['Add', 'Cast', 'Mul', 'Add', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0, 0])],
            output_name_to_node)

        if path is not None:
            # constant input of Add shall be 1.
            i, value = self.model.get_constant_input(path[0])
            if value != 1:
                return False

            _, self.padding_word_id = self.model.get_constant_input(path[-1])

            return input_ids == path[-1].input[0]
        """

        return False

    def match_position_embedding_bert(self, position_embedding_gather, input_ids, output_name_to_node):
        """Match position embedding path from input_ids to Gather for BERT.

        BERT Embedding Layer Pattern:
                                    (input_ids)
                                   /         \\
                                  /          Shape
                                 /             |
                                /            Gather (indices=1)
                               /               |
                              /               Add (optional, B=0)
                             /                 |
                 Gather (segment_ids)        Unsqueeze (axes=0)
                   \\        |                 |
                    \\     Gather            Slice (data[1,512], starts=0, ends=*, axes=1, steps=1)
                     \\    /                   |
                      Add                   Gather
                        \\                  /
                             Add
                              |
                      LayerNormalization
        """
        path = self.model.match_parent_path(position_embedding_gather, ['Slice', 'Unsqueeze'], [1, 2],
                                            output_name_to_node)
        if path is None:
            return False

        slice, unsqueeze = path
        slice_weight = self.model.get_constant_value(slice.input[0])
        if not (slice_weight is not None and len(slice_weight.shape) == 2 and slice_weight.shape[0] == 1
                and self.utils.check_node_input_value(slice, 1, [0])
                and self.utils.check_node_input_value(slice, 3, [1])
                and (len(slice.input) == 4 or self.utils.check_node_input_value(slice, 4, [1]))):
            return False

        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(unsqueeze, 'axes', [0]):
                return False
        else:
            if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
                return False

        node = self.model.get_parent(unsqueeze, 0, output_name_to_node)
        if node is None:
            return False
        if node.op_type == "Add":
            if not self.utils.check_node_input_value(node, 1, 0):
                return False
            gather = self.model.get_parent(node, 0, output_name_to_node)
        else:
            gather = node

        if gather is None or gather.op_type != "Gather":
            return False
        if not (self.utils.check_node_input_value(gather, 1, 1)):
            return False

        shape = self.model.get_parent(gather, 0, output_name_to_node)
        if shape is None or shape.op_type != "Shape":
            return False

        return input_ids == shape.input[0]

    def match_position_embedding(self, position_embedding_gather, input_ids, output_name_to_node):
        if self.match_position_embedding_bert(position_embedding_gather, input_ids, output_name_to_node):
            return True

        # TODO: Support roberta (position starts from 2 instead of 0) in the EmbedLayerNormalization kernel.
        # Related: https://github.com/huggingface/transformers/issues/10736
        # if self.match_position_embedding_roberta(position_embedding_gather, input_ids, output_name_to_node):
        #    return True

        if self.match_position_embedding_distilbert(position_embedding_gather, input_ids, output_name_to_node):
            return True

        return False

    def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather):
        """Sanity check of embedding weights, and match hidden_size of weights and shape of inputs.
        """
        input_ids = word_embedding_gather.input[1]
        segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None
        position_ids = position_embedding_gather.input[1]

        if self.shape_infer_helper is not None:
            input_ids_shape = self.shape_infer_helper.get_edge_shape(input_ids)
            position_ids_shape = self.shape_infer_helper.get_edge_shape(position_ids)
            assert input_ids_shape and position_ids_shape
            if not (len(input_ids_shape) == 2 and len(position_ids_shape) == 2
                    and input_ids_shape[1] == position_ids_shape[1]):
                logger.info(
                    "Cannot fuse EmbedLayerNormalization: input_ids and position_ids not matched in 2nd dimension: {} vs {}"
                    .format(input_ids_shape, position_ids_shape))
                return False

            if segment_ids and not self.shape_infer_helper.compare_shape(input_ids, segment_ids):
                logger.info(
                    "Cannot fuse EmbedLayerNormalization: input_ids and segment_ids do not have the same shape: {} != {}"
                    .format(input_ids_shape, self.shape_infer_helper.get_edge_shape(segment_ids)))
                return False

        word_embedding_table = self.model.get_constant_value(word_embedding_gather.input[0])
        if word_embedding_table is None or len(word_embedding_table.shape) != 2:
            logger.info("Cannot fuse EmbedLayerNormalization: word embedding table is not expected")
            return False

        position_embedding_table = self.model.get_constant_value(position_embedding_gather.input[0])
        if position_embedding_table is None or len(position_embedding_table.shape) != 2 or \
                (word_embedding_table.shape[1] != position_embedding_table.shape[1]):
            logger.info("Cannot fuse EmbedLayerNormalization: position embedding table is not expected")
            return False

        if segment_ids:
            segment_embedding_table = self.model.get_constant_value(segment_embedding_gather.input[0])
            if segment_embedding_table is None or len(segment_embedding_table.shape) != 2 or \
                    (word_embedding_table.shape[1] != segment_embedding_table.shape[1]):
                logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected")
                return False

        # In the normal case, the word embedding table is the largest and the segment embedding table is the
        # smallest, while the position embedding table is in between.
        # TODO: use other information (like initializer names) to identify different embedding weights automatically.
        if word_embedding_table.shape[0] <= position_embedding_table.shape[0]:
            logger.warning(
                f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]}"
            )

        if segment_ids:
            if word_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
                logger.warning(
                    f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
                )

            if position_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
                logger.warning(
                    f"position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
                )

        return True

    def cast_to_int32(self, input_name: str) -> Tuple[str, Union[None, NodeProto]]:
        """Cast a graph input or node input to int32.

        Args:
            input_name (str): name of a graph input or node input

        Returns:
            A tuple of the casted input name and the Cast node.
            int32_output (str): If the input is already int32, it is the input name; otherwise it is the output name of the Cast node.
input_cast_node (Union[None, NodeProto]): Cast node. It could be None if input is int32. """ input_cast_node = None graph_input = self.model.find_graph_input(input_name) if graph_input is not None: if graph_input.type.tensor_type.elem_type != TensorProto.INT32: int32_output, input_cast_node = self.utils.cast_input_to_int32( input_name) else: int32_output = input_name else: int32_output, input_cast_node = self.utils.cast_input_to_int32( input_name) return int32_output, input_cast_node def create_fused_node(self, input_ids: str, layernorm: NodeProto, word_embedding_gather: NodeProto, position_embedding_gather: NodeProto, segment_embedding_gather: Union[None, NodeProto]): """Create an EmbedLayerNormalization node. Note that segment embedding is optional. Args: input_ids (str): input_ids for word embeddings layernorm (NodeProto): LayerNormalization or SkipLayerNormalization node. word_embedding_gather (NodeProto): the Gather node for word embedding position_embedding_gather (NodeProto): the Gather node for position embedding segment_embedding_gather (Union[None, NodeProto]): the Gather node for segment embedding, or None. Returns: NodeProto: the EmbedLayerNormalization node created. """ nodes_to_add = [] input_ids, _ = self.cast_to_int32(input_ids) node_name = self.model.create_node_name('EmbedLayerNormalization') if layernorm.op_type == "LayerNormalization": gamma = layernorm.input[1] beta = layernorm.input[2] else: # SkipLayerNormalization gamma = layernorm.input[2] beta = layernorm.input[3] embed_node_inputs = None if segment_embedding_gather is not None: segment_ids, _ = self.cast_to_int32( segment_embedding_gather.input[1]) embed_node_inputs = [ input_ids, segment_ids, word_embedding_gather.input[0], position_embedding_gather.input[0], segment_embedding_gather.input[0], gamma, beta ] else: # no segment embedding embed_node_inputs = [ input_ids, '', word_embedding_gather.input[0], position_embedding_gather.input[0], '', gamma, beta ] embed_node = helper.make_node( 'EmbedLayerNormalization', embed_node_inputs, outputs=[node_name + "_output", node_name + "_dummy_mask_index"], name=node_name) embed_node.domain = "com.microsoft" # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization. for att in layernorm.attribute: if att.name == 'epsilon': embed_node.attribute.extend([att]) # Set default value to 1e-12 if no attribute is found. # OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later. if len(embed_node.attribute) == 0: embed_node.attribute.extend( [helper.make_attribute("epsilon", 1.0E-12)]) # Make sure new EmbedLayerNormalization node is the last one in self.nodes_to_add. 
        nodes_to_add.append(embed_node)
        for node in nodes_to_add:
            self.node_name_to_graph_name[node.name] = self.this_graph_name
        self.nodes_to_add.extend(nodes_to_add)

        self.embed_node = embed_node
        return embed_node

    def finish_fusion(self, layernorm, embed_node):
        self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0])
        # use prune graph to remove nodes that are no longer needed
        self.prune_graph = True

    def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
        """Fuse embedding layer for DistilBert
        Args:
            layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
            add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
            input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
            output_name_to_node (Dict[str, NodeProto]): map from output name to node
        """

        # DistilBert has no segment embedding, subgraph pattern is like
        #        input_ids
        #        |      \
        #        |     (position_embedding_subgraph)
        #        |        |
        #     Gather    Gather
        #          \   /
        #           Add
        #            |
        #    LayerNormalization
        two_gather = self.match_two_gather(add_before_layernorm)
        if two_gather is None:
            return False

        word_embedding_gather, position_embedding_gather = two_gather
        input_ids = word_embedding_gather.input[1]

        if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=True):
            return False

        if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
            return False

        if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
            return False

        embed_node = self.create_fused_node(input_ids, layernorm, word_embedding_gather, position_embedding_gather,
                                            None)
        self.finish_fusion(layernorm, embed_node)
        return True

    def fuse_bert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
        """Fuse embedding layer for Bert
        Args:
            layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
            add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
            input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
            output_name_to_node (Dict[str, NodeProto]): map from output name to node
        """

        add_2_gather = self.model.match_parent_path(add_before_layernorm, ['Add'], [0])
        if add_2_gather is None:
            return False

        two_gather = self.match_two_gather(add_2_gather[0])
        if two_gather is None:
            return False

        word_embedding_gather, segment_embedding_gather = two_gather

        input_ids = word_embedding_gather.input[1]

        if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
            return False

        position_embedding_path = self.model.match_parent_path(add_before_layernorm, ['Gather'], [1])
        if position_embedding_path is None:
            return False

        position_embedding_gather = position_embedding_path[0]
        if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
            if not self.match_position_embedding(segment_embedding_gather, input_ids, output_name_to_node):
                return False
            # position and segment embeddings are switched
            temp = segment_embedding_gather
            segment_embedding_gather = position_embedding_gather
            position_embedding_gather = temp

        if not self.check_embedding(word_embedding_gather, segment_embedding_gather, position_embedding_gather):
            return False

        embed_node = self.create_fused_node(input_ids, layernorm, word_embedding_gather, position_embedding_gather,
                                            segment_embedding_gather)
        self.finish_fusion(layernorm, embed_node)
        return True

    def fuse(self, node,
input_name_to_nodes, output_name_to_node): if node.op_type == "LayerNormalization": first_add_path = self.model.match_parent_path(node, ['Add'], [0]) if first_add_path is None: return add_before_layernorm = first_add_path[0] else: # SkipLayerNormalization add_before_layernorm = node # Add is fused into SkipLayerNormalization if self.fuse_distilbert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node): return if self.fuse_bert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node): return
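# A minimal usage sketch (illustrative, not part of this file's public API): drive the
# embedding fusion over a loaded model. It assumes the Fusion base class exposes apply()
# to walk candidate nodes and call fuse(), and that OnnxModel wraps an onnx.ModelProto.
# The file names and the function name below are placeholders.
def _example_fuse_embed_layer():  # pragma: no cover
    import onnx
    model = OnnxModel(onnx.load("bert.onnx"))
    fusion = FusionEmbedLayerNoMask(model)
    fusion.apply()  # matches LayerNormalization/SkipLayerNormalization nodes and attempts fusion
    model.prune_graph()  # drops subgraph nodes replaced by EmbedLayerNormalization
    onnx.save(model.model, "bert_embed_fused.onnx")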
def auto_mixed_precision(onnx_model: OnnxModel,
                         op_block_list: List[str] = ['Add', 'LayerNormalization', 'FastGelu']):
    """Convert GPT-2 model to mixed precision.
       It detects whether the original model has fp16 precision weights, and sets parameters for float16 conversion automatically.
    Args:
        onnx_model (OnnxModel): optimized ONNX model
        op_block_list (List[str], optional): operators to keep in float32 instead of converting to float16.
                                             Defaults to ['Add', 'LayerNormalization', 'FastGelu']
    Returns:
        parameters (dict): a dictionary of parameters used in float16 conversion
    """
    op_full_set = set([node.op_type for node in onnx_model.nodes()])
    fp32_op_set = set(op_block_list)
    fp16_op_set = op_full_set.difference(fp32_op_set)
    logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")

    # logits is the first output
    logits_output_name = onnx_model.graph().output[0].name

    # We use the weight in the last MatMul node to detect whether the model is stored with float16 weights from training.
    is_weight_fp16_precision = False
    output_name_to_node = onnx_model.output_name_to_node()
    assert logits_output_name in output_name_to_node
    node = output_name_to_node[logits_output_name]
    last_matmul_node = None
    if node.op_type == "MatMul":
        last_matmul_node = node
        logger.info(f"Found last MatMul node for logits: {node.name}")
        initializer = None
        for input in node.input:
            initializer = onnx_model.get_initializer(input)
            if initializer is not None:
                break

        # When the max difference of values after converting float to float16 is lower than a threshold (1e-6),
        # we can deduce that the weights are stored in float16 precision.
        max_diff = float_to_float16_max_diff(initializer)
        logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
        is_weight_fp16_precision = (max_diff < 1E-6)
    else:
        logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")

    if is_weight_fp16_precision:
        keep_io_types = []
        node_block_list = []
    else:
        # When the original weights are float32, keeping logits and the last MatMul in float32 could give better precision.
        keep_io_types = [logits_output_name]
        node_block_list = [last_matmul_node.name]

    parameters = {
        "keep_io_types": keep_io_types,
        "op_block_list": op_block_list,
        "node_block_list": node_block_list,
        "force_fp16_initializers": is_weight_fp16_precision
    }

    logger.info(f"auto_mixed_precision parameters: {parameters}")
    onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)

    fusion_utils = FusionUtils(onnx_model)
    fusion_utils.remove_cascaded_cast_nodes()
    fusion_utils.remove_useless_cast_nodes()

    return parameters
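# Hedged example of calling auto_mixed_precision above. The model path is a placeholder and
# loading via OnnxModel(onnx.load(...)) is an assumption for illustration; the function itself
# mutates the model in place and returns the conversion parameters it chose.
def _example_auto_mixed_precision():  # pragma: no cover
    import onnx
    opt_model = OnnxModel(onnx.load("gpt2_optimized.onnx"))
    parameters = auto_mixed_precision(opt_model)
    # parameters is e.g. {'keep_io_types': [...], 'op_block_list': [...],
    #                     'node_block_list': [...], 'force_fp16_initializers': ...}
    logger.info(parameters)
    onnx.save(opt_model.model, "gpt2_fp16.onnx")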
class FusionEmbedLayerNoMask(Fusion):
    """
    Embed Layer Normalization will fuse embeddings and mask processing into one node.
    The embeddings before conversion:

    (input_ids) --------->  Gather ----------+          (segment_ids)
       |                                     |               |
       |                                     v               v
       +--> Shape --> Expand -> Gather---->Add            Gather
       |                ^                    |               |
       |                |                    v               v
       +---(optional graph)            SkipLayerNormalization

    Optional graph is used to generate position list (0, 1, ...) per batch. It can be a constant in some model.

    (input_ids) --> Gather -----+       Slice
                                |         |
                                v         v
    (segment_ids)--> Gather --->Add    Reshape
                                |         |
                                v         v
                         SkipLayerNormalization
    """
    def __init__(self, model: OnnxModel, description='no mask'):
        super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
        self.utils = FusionUtils(model)

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        # Already fused. Assumes that there is only one embedding layer in a transformer model.
        if self.nodes_to_add:
            return

        if self.model.match_parent_path(node, ['Add', 'Gather'], [0, 0]) is None:
            return

        if self.model.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False) is None:
            # In case user disables attention fusion, check whether subgraph looks like Attention.
            if node.output[0] not in input_name_to_nodes:
                return
            children = input_name_to_nodes[node.output[0]]
            children_types = sorted([child.op_type for child in children])
            if children_types != ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization']:
                return

        # Assume the order of embeddings is word_embedding + position_embedding + segment_embedding.
        normalize_node = node
        word_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 0])
        if word_embedding_path is None:
            logger.info("Failed to find word embedding")
            return
        add_node, word_embedding_gather = word_embedding_path
        input_ids = word_embedding_gather.input[1]

        position_embedding_expand = None
        position_embedding_shape = None

        position_embedding_path = self.model.match_parent_path(normalize_node, ['Reshape', 'Slice'], [1, 0])
        if position_embedding_path is not None:
            _, position_embedding_weight_node = position_embedding_path
        else:
            position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand', 'Shape'], [1, 1, 1])
            if position_embedding_path is not None:
                position_embedding_weight_node, position_embedding_expand, position_embedding_shape = position_embedding_path
            else:
                position_embedding_path = self.model.match_parent_path(
                    add_node, ['Gather', 'Expand', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [1, 1, 1, 1, 0, 0])
                if position_embedding_path is not None:
                    position_embedding_weight_node, position_embedding_expand, _, _, _, position_embedding_shape = position_embedding_path
                else:
                    # Here we will not try to get an exact match. Instead, we only try to identify position embedding weights.
                    position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand'], [1, 1])
                    if position_embedding_path is not None:
                        position_embedding_weight_node, position_embedding_expand = position_embedding_path
                    else:
                        logger.info("Failed to find position embedding")
                        return

        if position_embedding_shape is not None and position_embedding_shape.input[0] != input_ids:
            logger.info("position and word embeddings are expected to be applied on the same input")
            return

        segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])
        if segment_embedding_path is None:
            segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
            if segment_embedding_path is None:
                logger.info("Failed to find segment embedding")
                return
            _, segment_embedding_gather = segment_embedding_path
        else:
            segment_embedding_gather = segment_embedding_path[0]

        segment_ids = segment_embedding_gather.input[1]

        if position_embedding_expand and position_embedding_shape:
            input_parent = self.model.get_parent(position_embedding_shape, 0, output_name_to_node)
            subgraph_nodes = self.model.get_parent_subgraph_nodes(position_embedding_expand,
                                                                  [input_parent] if input_parent else [],
                                                                  output_name_to_node)
            self.nodes_to_remove.extend(subgraph_nodes)

        self.nodes_to_remove.extend(word_embedding_path)
        self.nodes_to_remove.extend(position_embedding_path)
        self.nodes_to_remove.extend(segment_embedding_path)

        self.nodes_to_remove.extend([normalize_node])

        # store inputs for further processing
        if self.model.find_graph_input(input_ids):
            self.model.bert_inputs = [input_ids, segment_ids
                                      ] if self.model.find_graph_input(segment_ids) else [input_ids]

        # Cast input_ids and segment_ids to int32.
        input_ids_cast_node = None
        if self.model.find_graph_input(input_ids):
            casted, input_ids = self.utils.cast_graph_input_to_int32(input_ids)
        else:
            input_ids, input_ids_cast_node = self.utils.cast_input_to_int32(input_ids)

        if self.model.find_graph_input(segment_ids):
            casted, segment_ids = self.utils.cast_graph_input_to_int32(segment_ids)
        else:
            segment_ids, segment_ids_cast_node = self.utils.cast_input_to_int32(segment_ids)

            # Cast might be removed by OnnxRuntime.
            _, segment_id_path, _ = self.model.match_parent_paths(
                segment_ids_cast_node,
                [(['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape', 'Cast'], [0, 0, 1, 0, 0, 0]),
                 (['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [0, 0, 1, 0, 0])], output_name_to_node)

            if segment_id_path and input_ids_cast_node and input_ids_cast_node.input[0] == segment_id_path[-1].input[0]:
                logger.debug("Simplify segment id path...")
                self.model.add_node(
                    helper.make_node('Shape', inputs=[input_ids_cast_node.input[0]], outputs=["input_shape"]))
                self.model.add_node(
                    helper.make_node('ConstantOfShape',
                                     inputs=["input_shape"],
                                     outputs=["zeros_for_input_shape"],
                                     value=helper.make_tensor("value", onnx.TensorProto.INT32, [1], [1])))
                segment_ids = "zeros_for_input_shape"

        embed_node = helper.make_node(
            'EmbedLayerNormalization',
            inputs=[
                input_ids, segment_ids, word_embedding_gather.input[0], position_embedding_weight_node.input[0],
                segment_embedding_gather.input[0],
                normalize_node.input[2],
                normalize_node.input[3]  # gamma and beta
            ],
            outputs=["embed_output", "dummy_mask_index"],
            name="EmbedLayer")

        embed_node.domain = "com.microsoft"

        # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
        for att in normalize_node.attribute:
            if att.name == 'epsilon':
                embed_node.attribute.extend([att])

        # Set default value to 1e-12 if no attribute is found.
if len(embed_node.attribute) == 0: embed_node.attribute.extend([onnx.helper.make_attribute("epsilon", 1.0E-12)]) self.model.replace_input_of_all_nodes(normalize_node.output[0], 'embed_output') self.nodes_to_add.append(embed_node)
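# Illustrative sketch of the EmbedLayerNormalization node layout that the fusions above
# produce. The input order mirrors the inputs list built in fuse(); every name below is a
# placeholder, and the function itself is only an example, not used by the fusion passes.
def _example_embed_layer_norm_node():  # pragma: no cover
    node = helper.make_node(
        'EmbedLayerNormalization',
        inputs=[
            'input_ids',           # 0: int32 token ids
            'segment_ids',         # 1: int32 segment ids ('' when there is no segment embedding)
            'word_embedding',      # 2: 2D word embedding table
            'position_embedding',  # 3: 2D position embedding table
            'segment_embedding',   # 4: 2D segment embedding table ('' when absent)
            'gamma',               # 5: LayerNormalization weight
            'beta'                 # 6: LayerNormalization bias
        ],
        outputs=['embed_output', 'dummy_mask_index'],
        name='EmbedLayer_example')
    node.domain = "com.microsoft"
    node.attribute.extend([helper.make_attribute("epsilon", 1.0E-12)])
    return node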
class FusionGptAttention(Fusion): """ Fuse GPT-2 Attention with past state subgraph into one Attention node. This does not support attention_mask graph input right now. """ def __init__(self, model: OnnxModel, num_heads: int): super().__init__(model, "Attention", "LayerNormalization", "with past") # TODO: detect num_heads from graph like FusionAttention self.num_heads = num_heads self.utils = FusionUtils(model) self.casted_attention_mask = { } # map from name of attention mask to the name that casted to int32 def create_attention_node(self, fc_weight, fc_bias, gemm_qkv, past, present, input, output, mask, is_unidirectional): attention_node_name = self.model.create_node_name('GptAttention') attention_node = helper.make_node( 'Attention', inputs=[input, fc_weight, fc_bias, mask, past], outputs=[attention_node_name + "_output", present], name=attention_node_name) attention_node.domain = "com.microsoft" attention_node.attribute.extend([ helper.make_attribute("num_heads", self.num_heads), helper.make_attribute("unidirectional", 1 if is_unidirectional else 0) ]) matmul_node = helper.make_node( 'MatMul', inputs=[attention_node_name + "_output", gemm_qkv.input[1]], outputs=[attention_node_name + "_matmul_output"], name=attention_node_name + "_matmul") add_node = helper.make_node( 'Add', inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]], outputs=[output], name=attention_node_name + "_add") self.nodes_to_add.extend([attention_node, matmul_node, add_node]) self.node_name_to_graph_name[ attention_node.name] = self.this_graph_name self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name self.node_name_to_graph_name[add_node.name] = self.this_graph_name def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node): # Pattern 1: # {past} # / \ # / \ # Gather(axes=0, indices=0) Gather(indices=1) # | | # Transpose (perm=0,1,3,2) | # | | # Concat_k Concat_v # | / # Transpose (perm=0,1,3,2) / # | / # Unsqueeze Unsqueeze # \ / # \ / # Concat # | # {present} gather = self.model.get_parent(concat_v, 0, output_name_to_node) if gather.op_type != 'Gather': logger.debug("match_past_pattern_1: expect Gather for past") return None if not self.model.find_constant_input(gather, 1) == 1: logger.debug( "match_past_pattern_1: expect indices=1 for Gather of past") return None past = gather.input[0] past_k_nodes = self.model.match_parent_path(concat_k, ['Transpose', 'Gather'], [0, 0]) if past_k_nodes is None: logger.debug( "match_past_pattern_1: failed match Transpose and Gather") return None gather_past_k = past_k_nodes[-1] if not self.model.find_constant_input(gather_past_k, 0) == 1: logger.debug( "match_past_pattern_1: expect indices=0 for Gather k of past") return None past_k = gather_past_k.input[0] if past != past_k: logger.debug("match_past_pattern_1: expect past to be same") return None return past def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node): # Pattern 2: # Split (QKV) # / | | # / | +----------------------+ # | | # | {past} | # | | | # Reshape Split Reshape # | / \ | # Transpose_k Squeeze Squeeze Transpose_v # | | \ / # +------|---+ \ / # | | \ / # Concat_k Concat_v # | | # Unsqueeze Unsqueeze # \ / # Concat # | # {present} # squeeze = self.model.get_parent(concat_v, 0, output_name_to_node) if squeeze.op_type != 'Squeeze': logger.debug( "match_past_pattern_2: expect Squeeze as parent of concat_v") return None split = self.model.get_parent(squeeze, 0, output_name_to_node) if split.op_type != "Split": logger.debug("match_past_pattern_2: expect Split 
for past path") return None opset_version = self.model.get_opset_version() if opset_version < 13: if not FusionUtils.check_node_attribute(squeeze, 'axes', [0]): logger.debug( "match_past_pattern_2: axes != [0] for Squeeze in past path" ) return None if not FusionUtils.check_node_attribute(split, 'split', [1, 1]): logger.debug( "match_past_pattern_2: split != [1, 1] for Split in past path" ) return None else: if not self.utils.check_node_input_value(squeeze, 1, [0]): logger.debug( "match_past_pattern_2: axes != [0] for Squeeze in past path" ) return None if not self.utils.check_node_input_value(split, 1, [1, 1]): logger.debug( "match_past_pattern_2: split != [1, 1] for Split in past path" ) return None if not FusionUtils.check_node_attribute( split, 'axis', 0, default_value=0): logger.debug( "match_past_pattern_2: attribute axis of Split are not expected in past path" ) return None past = split.input[0] past_k_nodes = self.model.match_parent_path(concat_k, ['Squeeze', 'Split'], [0, 0]) if past_k_nodes is None: logger.debug( "match_past_pattern_2: failed to match past_k_nodes path") return None past_k = past_k_nodes[-1].input[0] if past != past_k: logger.info("match_past_pattern_2: expect past to be same") return None return past def match_present(self, concat_v, input_name_to_nodes): unsqueeze_present_v = self.model.find_first_child_by_type( concat_v, 'Unsqueeze', input_name_to_nodes, recursive=False) if not unsqueeze_present_v: logger.info("expect unsqueeze for present") return None concat_present = self.model.find_first_child_by_type( unsqueeze_present_v, 'Concat', input_name_to_nodes, recursive=False) if not concat_present: logger.info("expect concat for present") return None present = concat_present.output[0] return present def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): past = None present = None return_indice = [] qkv_nodes = self.model.match_parent_path( normalize_node, ['Add', 'Reshape', 'Gemm', 'Reshape', 'Reshape', 'Transpose', 'MatMul'], [0, None, 0, 0, 0, 0, 0], output_name_to_node=output_name_to_node, return_indice=return_indice ) # yapf: disable if qkv_nodes is None: return (add_qkv, reshape_qkv, gemm_qkv, reshape_1, reshape_2, transpose_qkv, matmul_qkv) = qkv_nodes another_input = add_qkv.input[1 - return_indice[0]] v_nodes = self.model.match_parent_path( matmul_qkv, ['Concat', 'Transpose', 'Reshape', 'Split'], [1, 1, 0, 0]) if v_nodes is None: logger.debug("fuse_attention: failed to match v path") return (concat_v, transpose_v, reshape_v, split_fc) = v_nodes fc_nodes = self.model.match_parent_path( split_fc, ['Reshape', 'Gemm', 'Reshape', 'LayerNormalization'], [0, 0, 0, 0], output_name_to_node) if fc_nodes is None: fc_nodes = self.model.match_parent_path( split_fc, ['Add', 'MatMul', 'LayerNormalization'], [0, None, 0], output_name_to_node) if fc_nodes is None: logger.debug("fuse_attention: failed to match fc path") return fc_weight = fc_nodes[1].input[1] i, _ = self.model.get_constant_input(fc_nodes[0]) fc_bias = fc_nodes[0].input[i] else: fc_weight = fc_nodes[1].input[1] fc_bias = fc_nodes[1].input[2] layernorm_before_attention = fc_nodes[-1] if not another_input in layernorm_before_attention.input: logger.debug( "Add and LayerNormalization shall have one same input") return is_unidirectional = True slice_mask = None input_mask_nodes = None concat_k_to_match = None qk_nodes = self.model.match_parent_path( matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'], [0, 0, 0, 0, 0]) if qk_nodes is not None: (softmax_qk, sub_qk, mul_qk, div_qk, 
matmul_qk) = qk_nodes mask_nodes = self.model.match_parent_path( sub_qk, ['Mul', 'Sub', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'], [1, 0, 1, 0, 1, 0, 0, 0, 0, 0]) # yapf: disable if mask_nodes is None: logger.debug( "fuse_attention: failed to match unidirectional mask path") return div_mask = mask_nodes[-1] slice_mask = mask_nodes[3] if div_qk != div_mask: logger.debug("fuse_attention: skip since div_qk != div_mask") return else: # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0. i, qk_nodes, _ = self.model.match_parent_paths( matmul_qkv, [(['Softmax', 'Where', 'Div', 'MatMul'], [0, 0, 1, 0]), (['Softmax', 'Add', 'Where', 'Div', 'MatMul' ], [0, 0, None, 1, 0])], output_name_to_node) if qk_nodes is None: logger.debug("fuse_attention: failed to match qk nodes") return where_qk = qk_nodes[-3] div_qk = qk_nodes[-2] matmul_qk = qk_nodes[-1] if i == 1: add_qk = qk_nodes[1] _, input_mask_nodes, _ = self.model.match_parent_paths( add_qk, [([ 'Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape' ], [None, 0, 1, 0, 0, 0]), (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape' ], [None, 0, 1, 0, 0])], output_name_to_node) if input_mask_nodes is None: logger.debug( "fuse_attention: failed to match input attention mask path" ) return mask_nodes = self.model.match_parent_path( where_qk, ['Cast', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape'], [ 0, 0, 0, 1, 0, 0, 0, 0], output_name_to_node) # yapf: disable if mask_nodes is None: # TODO: match mask path for GPT2LMHeadModel_BeamSearchStep. logger.debug("fuse_attention: failed to match mask path") return slice_mask = mask_nodes[2] div_or_concat = self.model.get_parent(mask_nodes[-1], 0, output_name_to_node) if div_or_concat.op_type == "Div": div_mask = div_or_concat if div_qk != div_mask: logger.debug( "fuse_attention: skip since div_qk != div_mask") return elif div_or_concat.op_type == "Concat": concat_k_to_match = div_or_concat else: logger.debug("fuse_attention: failed to match mask path") # Validate that the mask data is either lower triangular (unidirectional) or all ones mask_data = numpy_helper.to_array( self.model.get_initializer(slice_mask.input[0])) if not (len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1) and mask_data.shape[2] == mask_data.shape[3]): logger.debug( "fuse_attention: skip since mask shape is not 1x1xWxW") return if np.allclose(mask_data, np.ones_like(mask_data)): is_unidirectional = False elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))): logger.debug( "fuse_attention: skip since mask is neither lower triangular nor ones" ) return q_nodes = self.model.match_parent_path( matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0]) if q_nodes is None: logger.debug("fuse_attention: failed to match q path") return (transpose_q, reshape_q, split_q) = q_nodes if split_fc != split_q: logger.debug("fuse_attention: skip since split_fc != split_q") return k_nodes = self.model.match_parent_path( matmul_qk, ['Concat', 'Transpose', 'Reshape', 'Split'], [1, 1, 0, 0]) if k_nodes is None: # This pattern is from pytorch 1.7.1 and transformers 4.6.1 k_nodes = self.model.match_parent_path( matmul_qk, ['Transpose', 'Concat', 'Transpose', 'Reshape', 'Split'], [1, 0, 1, 0, 0]) if k_nodes is None: logger.debug("fuse_attention: failed to match k path") return else: (_, concat_k, transpose_k, reshape_k, split_k) = k_nodes else: (concat_k, transpose_k, reshape_k, split_k) = k_nodes if split_fc != split_k: logger.debug("fuse_attention: skip since split_fc != split_k") 
return if concat_k_to_match and concat_k != concat_k_to_match: logger.debug( "fuse_attention: skip since concat_k != concat_k_to_match") return attention_mask_input_name = '' if input_mask_nodes is not None: input_name = input_mask_nodes[-1].input[0] if input_name in self.casted_attention_mask: attention_mask_input_name = self.casted_attention_mask[ input_name] elif self.model.find_graph_input(input_name): casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32( input_name) self.casted_attention_mask[ input_name] = attention_mask_input_name else: attention_mask_input_name, cast_node = self.utils.cast_input_to_int32( input_name) self.casted_attention_mask[ input_name] = attention_mask_input_name # Match past and present paths past = self.match_past_pattern_1(concat_k, concat_v, output_name_to_node) or \ self.match_past_pattern_2(concat_k, concat_v, output_name_to_node) if past is None: logger.info("fuse_attention: failed to match past path") return if not self.model.find_graph_input(past): logger.debug("past is not graph input.") # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select beam index so it is not graph input. present = self.match_present(concat_v, input_name_to_nodes) if present is None: logger.info("fuse_attention: failed to match present path") return if not self.model.find_graph_output(present): logger.info("expect present to be graph output") return self.create_attention_node(fc_weight, fc_bias, gemm_qkv, past, present, layernorm_before_attention.output[0], reshape_qkv.output[0], attention_mask_input_name, is_unidirectional) # we rely on prune_graph() to clean old subgraph nodes: # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv] self.prune_graph = True
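# Hedged usage sketch for the GPT-2 attention fusion above: num_heads must be supplied by
# the caller (see the TODO in __init__), and the apply()/prune_graph() calls assume the
# Fusion base class and OnnxModel behave as elsewhere in this file. Paths are placeholders.
def _example_fuse_gpt2_attention():  # pragma: no cover
    import onnx
    model = OnnxModel(onnx.load("gpt2_with_past.onnx"))
    FusionGptAttention(model, num_heads=12).apply()
    model.prune_graph()  # removes the matched q/k/v/mask subgraph nodes
    onnx.save(model.model, "gpt2_attention_fused.onnx")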
class FusionGptAttention(Fusion):
    """
    Fuse GPT-2 Attention with past state subgraph into one Attention node.
    This does not support attention_mask graph input right now.
    """
    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", "LayerNormalization", "with past")
        # TODO: detect num_heads from graph like FusionAttention
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {}  # map from name of attention mask to the name that casted to int32

    def create_attention_node(self, gemm, gemm_qkv, past, present, input, output, mask, is_unidirectional):
        attention_node_name = self.model.create_node_name('GptAttention')
        attention_node = helper.make_node('Attention',
                                          inputs=[input, gemm.input[1], gemm.input[2], mask, past],
                                          outputs=[attention_node_name + "_output", present],
                                          name=attention_node_name)
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([
            helper.make_attribute("num_heads", self.num_heads),
            helper.make_attribute("unidirectional", 1 if is_unidirectional else 0)
        ])

        matmul_node = helper.make_node('MatMul',
                                       inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
                                       outputs=[attention_node_name + "_matmul_output"],
                                       name=attention_node_name + "_matmul")

        add_node = helper.make_node('Add',
                                    inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
                                    outputs=[output],
                                    name=attention_node_name + "_add")
        self.nodes_to_add.extend([attention_node, matmul_node, add_node])

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        past = None
        present = None
        return_indice = []
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ['Add', 'Reshape', 'Gemm', 'Reshape', 'Reshape', 'Transpose', 'MatMul'],
            [0, None, 0, 0, 0, 0, 0],
            output_name_to_node=output_name_to_node,
            return_indice=return_indice)  # yapf: disable
        if qkv_nodes is None:
            return
        (add_qkv, reshape_qkv, gemm_qkv, reshape_1, reshape_2, transpose_qkv, matmul_qkv) = qkv_nodes

        another_input = add_qkv.input[1 - return_indice[0]]

        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ['Concat', 'Transpose', 'Reshape', 'Split', 'Reshape', 'Gemm', 'Reshape'],
            [1, 1, 0, 0, 0, 0, 0])  # yapf: disable
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (concat_v, transpose_v, reshape_v, split_v, reshape_after_gemm, gemm, reshape_before_gemm) = v_nodes

        #      concat  <--  Gather(indices=1)  <--  past
        #        |
        #    unsqueeze
        #        |
        #      concat  -->  present
        gather_v = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather_v.op_type != 'Gather':
            logger.info("expect Gather for past")
            return
        if not self.model.find_constant_input(gather_v, 1) == 1:
            logger.info("expect indices=1 for Gather of past")
            return
        past = gather_v.input[0]
        if not self.model.find_graph_input(past):
            logger.info("expect past to be graph input")
            return

        unsqueeze_present_v = self.model.find_first_child_by_type(concat_v,
                                                                  'Unsqueeze',
                                                                  input_name_to_nodes,
                                                                  recursive=False)
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return
        concat_present = self.model.find_first_child_by_type(unsqueeze_present_v,
                                                             'Concat',
                                                             input_name_to_nodes,
                                                             recursive=False)
        if not concat_present:
            logger.info("expect concat for present")
            return

        present = concat_present.output[0]
        if not self.model.find_graph_output(present):
            logger.info("expect present to be graph output")
            return

        layernorm_before_attention = self.model.get_parent(reshape_before_gemm, 0, output_name_to_node)
        if layernorm_before_attention is None or layernorm_before_attention.op_type != 'LayerNormalization':
            logger.debug("failed to get LayerNormalization before Gemm. Got %s",
                         layernorm_before_attention.op_type if layernorm_before_attention else None)
            return

        if another_input not in layernorm_before_attention.input:
            logger.debug("Add and LayerNormalization shall have one same input")
            return

        is_unidirectional = True
        slice_mask = None
        input_mask_nodes = None
        qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'],
                                                [0, 0, 0, 0, 0])
        if qk_nodes is not None:
            (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
            mask_nodes = self.model.match_parent_path(
                sub_qk,
                ['Mul', 'Sub', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
                [1,      0,    1,       0,       1,           0,     0,         0,       0,       0])  # yapf: disable
            if mask_nodes is None:
                logger.debug("fuse_attention: failed to match unidirectional mask path")
                return
            div_mask = mask_nodes[-1]
            slice_mask = mask_nodes[3]

            if div_qk != div_mask:
                logger.debug("fuse_attention: skip since div_qk != div_mask")
                return
        else:
            # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
            i, qk_nodes, _ = self.model.match_parent_paths(
                matmul_qkv,
                [(['Softmax', 'Where', 'Div', 'MatMul'], [0, 0, 1, 0]),
                 (['Softmax', 'Add', 'Where', 'Div', 'MatMul'], [0, 0, 0, 1, 0])], output_name_to_node)
            if qk_nodes is None:
                logger.debug("fuse_attention: failed to match qk nodes")
                return

            where_qk = qk_nodes[-3]
            div_qk = qk_nodes[-2]
            matmul_qk = qk_nodes[-1]

            if i == 1:
                add_qk = qk_nodes[1]
                _, input_mask_nodes, _ = self.model.match_parent_paths(
                    add_qk,
                    [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [1, 0, 1, 0, 0, 0]),
                     (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [1, 0, 1, 0, 0])], output_name_to_node)
                if input_mask_nodes is None:
                    logger.debug("fuse_attention: failed to match input attention mask path")
                    return

            mask_nodes = self.model.match_parent_path(
                where_qk,
                ['Cast', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
                [0,       0,       0,       1,           0,     0,         0,       0,       0])  # yapf: disable
            if mask_nodes is None:
                logger.debug("fuse_attention: failed to match mask path")
                return
            div_mask = mask_nodes[-1]
            slice_mask = mask_nodes[2]

            if div_qk != div_mask:
                logger.debug("fuse_attention: skip since div_qk != div_mask")
                return

        # Validate that the mask data is either lower triangular (unidirectional) or all ones
        mask_data = numpy_helper.to_array(self.model.get_initializer(slice_mask.input[0]))
        if not (len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1)
                and mask_data.shape[2] == mask_data.shape[3]):
            logger.debug("fuse_attention: skip since mask shape is not 1x1xWxW")
            return
        if np.allclose(mask_data, np.ones_like(mask_data)):
            is_unidirectional = False
        elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
            logger.debug("fuse_attention: skip since mask is neither lower triangular nor ones")
            return

        q_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (transpose_q, reshape_q, split_q) = q_nodes
        if split_v != split_q:
            logger.debug("fuse_attention: skip since split_v != split_q")
            return

        k_nodes = self.model.match_parent_path(matmul_qk, ['Concat', 'Transpose', 'Reshape', 'Split'], [1, 1, 0, 0])
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return
        (concat_k, transpose_k, reshape_k, split_k) = k_nodes
        if split_v != split_k:
            logger.debug("fuse_attention: skip since split_v != split_k")
            return

        #    concat_k  <--  Transpose (perm=0,1,3,2)  <--  Gather(axes=0, indices=0)  <--  past
        #        |
        #    Transpose
        #    (perm=0,1,3,2)
        #        |
        #    unsqueeze
        #        |
        #     concat  -->  present
        past_k_nodes = self.model.match_parent_path(concat_k, ['Transpose', 'Gather'], [0, 0])
        if past_k_nodes is None:
            logger.debug("fuse_attention: failed to match past_k_nodes path")
            return
        gather_past_k = past_k_nodes[-1]
        if not self.model.find_constant_input(gather_past_k, 0) == 1:
            logger.info("expect indices=0 for Gather k of past")
            return
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.info("expect past to be the same")
            return

        attention_mask_input_name = ''
        if input_mask_nodes is not None:
            input_name = input_mask_nodes[-1].input[0]
            if input_name in self.casted_attention_mask:
                attention_mask_input_name = self.casted_attention_mask[input_name]
            elif self.model.find_graph_input(input_name):
                casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name
            else:
                attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name

        self.create_attention_node(gemm, gemm_qkv, past, present, layernorm_before_attention.output[0],
                                   reshape_qkv.output[0], attention_mask_input_name, is_unidirectional)

        # we rely on prune_graph() to clean old subgraph nodes:
        # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
        self.prune_graph = True
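# Worked illustration of the mask validation used in fuse() above: a 1x1xWxW mask that is
# lower triangular marks causal (unidirectional) attention, while an all-ones mask marks
# bidirectional attention. Self-contained numpy sketch; W=4 and the function name are
# arbitrary placeholders for illustration only.
def _example_mask_is_unidirectional():  # pragma: no cover
    w = 4
    causal = np.tril(np.ones((1, 1, w, w), dtype=np.float32))
    assert np.allclose(causal, np.tril(np.ones_like(causal)))  # -> unidirectional = 1
    assert not np.allclose(causal, np.ones_like(causal))       # not the bidirectional (all ones) case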
def __init__(self, model: OnnxModel): super().__init__(model, "Shape", "Concat") self.utils = FusionUtils(model) self.shape_infer = None self.shape_infer_done = False
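    # Hedged sketch of the lazy rank lookup these flags suggest (the fuse logic calls
    # self.get_dimensions on the root input): answer from graph inputs when possible,
    # otherwise run symbolic shape inference once and cache it. infer_runtime_shape and
    # get_edge_shape follow their usage elsewhere in this file; the exact body here is
    # an assumption for illustration, not a verbatim implementation.
    def get_dimensions(self, input_name: str) -> Union[int, None]:
        graph_input = self.model.find_graph_input(input_name)
        if graph_input:
            return len(graph_input.type.tensor_type.shape.dim)

        if not self.shape_infer_done:
            self.shape_infer = self.model.infer_runtime_shape({}, update=True)
            self.shape_infer_done = True

        if self.shape_infer is not None:
            shape = self.shape_infer.get_edge_shape(input_name)
            return len(shape) if shape is not None else None

        return None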