class AttentionMask:
    """
    Process the attention mask for Attention fusion, converting a 2D attention mask
    to a 1D mask index when the mask format requires it.
    """
    def __init__(self, model: OnnxModel):
        self.model = model
        # A lookup table with mask input as key, and mask index output as value
        self.mask_indice = {}
        # A lookup table with mask input as key, and cast (to int32) output as value
        self.mask_casted = {}
        self.utils = FusionUtils(model)
        self.mask_format = AttentionMaskFormat.MaskIndexEnd

    def set_mask_format(self, mask_format: AttentionMaskFormat):
        self.mask_format = mask_format

    def set_mask_indice(self, mask, mask_index):
        if mask in self.mask_indice:
            assert mask_index == self.mask_indice[mask]
        self.mask_indice[mask] = mask_index

    def get_first_mask(self):
        assert len(self.mask_indice) > 0
        return next(iter(self.mask_indice))

    def process_mask(self, input):
        if input in self.mask_indice:
            return self.mask_indice[input]

        # Add cast to convert int64 to int32
        if self.model.find_graph_input(input):
            casted, input_name = self.utils.cast_graph_input_to_int32(input)
        else:
            input_name, cast_node = self.utils.cast_input_to_int32(input)
            casted = True

        if casted:
            self.mask_casted[input] = input_name

        # Attention supports int32 attention mask (2D) since 1.4.0
        if self.mask_format == AttentionMaskFormat.AttentionMask:
            self.mask_indice[input] = input_name
            return input_name

        # Add a mask processing node to convert attention mask to mask index (1D)
        output_name = self.model.create_node_name('mask_index')
        mask_index_node = helper.make_node('ReduceSum',
                                           inputs=[input_name],
                                           outputs=[output_name],
                                           name=self.model.create_node_name('ReduceSum', 'MaskReduceSum'))
        mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)])
        self.model.add_node(mask_index_node)

        self.mask_indice[input] = output_name
        return output_name
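# Illustrative sketch, not used by the fusion itself: with the default
# AttentionMaskFormat.MaskIndexEnd, process_mask() converts a 2D attention
# mask into a 1D mask index via ReduceSum(axes=[1], keepdims=0). This helper
# shows the NumPy equivalent, assuming a right-padded mask (which is what
# MaskIndexEnd implies). The function name and example data are hypothetical.
def _demo_mask_index():
    import numpy as np
    mask = np.array([[1, 1, 1, 0, 0],
                     [1, 1, 0, 0, 0]], dtype=np.int32)
    mask_index = mask.sum(axis=1)  # equivalent to ReduceSum over axis 1 with keepdims=0
    return mask_index  # array([3, 2]): number of valid tokens per batch row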
class FusionGptAttentionPastBase(Fusion):
    """Base class for GPT Attention Fusion with past state
    """
    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", "LayerNormalization", "with past")
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {}  # map from name of attention mask to the name casted to int32

    def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node):
        # Pattern 1:
        #                          {past}
        #                         /      \
        #                        /        \
        #    Gather(axes=0, indices=0)   Gather(indices=1)
        #          |                          |
        #    Transpose (perm=0,1,3,2)         |
        #          |                          |
        #       Concat_k                   Concat_v
        #          |                         /
        #    Transpose (perm=0,1,3,2)       /
        #          |                       /
        #      Unsqueeze           Unsqueeze
        #           \                 /
        #            \               /
        #                 Concat
        #                   |
        #               {present}
        gather = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather.op_type != 'Gather':
            logger.debug("match_past_pattern_1: expect Gather for past")
            return None

        if not self.model.find_constant_input(gather, 1) == 1:
            logger.debug("match_past_pattern_1: expect indices=1 for Gather of past")
            return None
        past = gather.input[0]

        parent = self.model.get_parent(concat_k, 0, output_name_to_node)
        if parent.op_type == 'Gather':
            gather_past_k = parent
        else:
            past_k_nodes = self.model.match_parent_path(concat_k, ['Transpose', 'Gather'], [0, 0])
            if past_k_nodes is None:
                logger.debug("match_past_pattern_1: failed match Transpose and Gather")
                return None
            gather_past_k = past_k_nodes[-1]

        if not self.model.find_constant_input(gather_past_k, 0) == 1:
            logger.debug("match_past_pattern_1: expect indices=0 for Gather k of past")
            return None
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.debug("match_past_pattern_1: expect past to be same")
            return None

        return past

    def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node):
        # Pattern 2:
        #       Split (QKV)
        #      /    |     |
        #     /     |     +----------------------+
        #           |                             |
        #           |          {past}             |
        #           |            |                |
        #        Reshape       Split           Reshape
        #           |         /     \             |
        #   Transpose_k   Squeeze  Squeeze   Transpose_v
        #           |        |        \          /
        #           +--------|---+     \        /
        #                    |   |      \      /
        #                  Concat_k    Concat_v
        #                     |           |
        #                Unsqueeze    Unsqueeze
        #                      \         /
        #                        Concat
        #                          |
        #                      {present}
        #
        squeeze = self.model.get_parent(concat_v, 0, output_name_to_node)
        if squeeze.op_type != 'Squeeze':
            logger.debug("match_past_pattern_2: expect Squeeze as parent of concat_v")
            return None

        split = self.model.get_parent(squeeze, 0, output_name_to_node)
        if split.op_type != "Split":
            logger.debug("match_past_pattern_2: expect Split for past path")
            return None

        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(squeeze, 'axes', [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not FusionUtils.check_node_attribute(split, 'split', [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None
        else:
            if not self.utils.check_node_input_value(squeeze, 1, [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not self.utils.check_node_input_value(split, 1, [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None

        if not FusionUtils.check_node_attribute(split, 'axis', 0, default_value=0):
            logger.debug("match_past_pattern_2: attribute axis of Split is not expected in past path")
            return None

        past = split.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k, ['Squeeze', 'Split'], [0, 0])
        if past_k_nodes is None:
            logger.debug("match_past_pattern_2: failed to match past_k_nodes path")
            return None
        past_k = past_k_nodes[-1].input[0]

        if past != past_k:
            logger.info("match_past_pattern_2: expect past to be same")
            return None

        return past

    def match_present(self, concat_v, input_name_to_nodes):
        unsqueeze_present_v = self.model.find_first_child_by_type(concat_v,
                                                                  'Unsqueeze',
                                                                  input_name_to_nodes,
                                                                  recursive=False)
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return None
        concat_present = self.model.find_first_child_by_type(unsqueeze_present_v,
                                                             'Concat',
                                                             input_name_to_nodes,
                                                             recursive=False)
        if not concat_present:
            logger.info("expect concat for present")
            return None

        present = concat_present.output[0]
        return present

    def cast_attention_mask(self, input_name):
        if input_name in self.casted_attention_mask:
            attention_mask_input_name = self.casted_attention_mask[input_name]
        elif self.model.find_graph_input(input_name):
            casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        else:
            attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        return attention_mask_input_name
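# Illustrative sketch, not used by the fusion itself: match_past_pattern_1
# implies that the {past} tensor stacks key and value along axis 0, so that
# Gather(axis=0, indices=0) selects past_k and Gather(indices=1) selects
# past_v. The concrete shape below is an assumption for illustration only.
def _demo_past_layout():
    import numpy as np
    batch, num_heads, past_seq_len, head_size = 1, 12, 5, 64
    past = np.zeros((2, batch, num_heads, past_seq_len, head_size), dtype=np.float32)
    past_k = past[0]  # Gather(axis=0, indices=0); the graph then applies Transpose(perm=0,1,3,2)
    past_v = past[1]  # Gather(axis=0, indices=1)
    return past_k.shape, past_v.shape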
class FusionEmbedLayerNoMask(Fusion):
    """
    Embed Layer Normalization will fuse embeddings and mask processing into one node.
    The embeddings before conversion:

    (input_ids) --------> Gather ----------+       (segment_ids)
       |                                   |            |
       |                                   v            v
       +--> Shape --> Expand -> Gather---->Add        Gather
       |                ^                  |            |
       |                |                  v            v
       +---(optional graph)          SkipLayerNormalization

    The optional graph is used to generate a position list (0, 1, ...) per batch. It can be a constant in some models.

    (input_ids) --> Gather -----+    Slice
                                |      |
                                v      v
    (segment_ids)--> Gather --->Add  Reshape
                                |      |
                                v      v
                        SkipLayerNormalization
    """
    def __init__(self, model: OnnxModel, description='no mask'):
        super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
        self.utils = FusionUtils(model)

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        # Already fused. Assumes there is only one embedding layer in a transformer model.
        if self.nodes_to_add:
            return

        if self.model.match_parent_path(node, ['Add', 'Gather'], [0, 0]) is None:
            return

        if self.model.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False) is None:
            # In case user disables attention fusion, check whether subgraph looks like Attention.
            if node.output[0] not in input_name_to_nodes:
                return
            children = input_name_to_nodes[node.output[0]]
            children_types = sorted([child.op_type for child in children])
            if children_types != ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization']:
                return

        # Assume the order of embeddings is word_embedding + position_embedding + segment_embedding.
        normalize_node = node
        word_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 0])
        if word_embedding_path is None:
            logger.info("Failed to find word embedding")
            return
        add_node, word_embedding_gather = word_embedding_path
        input_ids = word_embedding_gather.input[1]

        position_embedding_expand = None
        position_embedding_shape = None

        position_embedding_path = self.model.match_parent_path(normalize_node, ['Reshape', 'Slice'], [1, 0])
        if position_embedding_path is not None:
            _, position_embedding_weight_node = position_embedding_path
        else:
            position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand', 'Shape'], [1, 1, 1])
            if position_embedding_path is not None:
                position_embedding_weight_node, position_embedding_expand, position_embedding_shape = position_embedding_path
            else:
                position_embedding_path = self.model.match_parent_path(
                    add_node, ['Gather', 'Expand', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [1, 1, 1, 1, 0, 0])
                if position_embedding_path is not None:
                    position_embedding_weight_node, position_embedding_expand, _, _, _, position_embedding_shape = position_embedding_path
                else:
                    # Here we will not try to get an exact match. Instead, we only try to identify position embedding weights.
                    position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand'], [1, 1])
                    if position_embedding_path is not None:
                        position_embedding_weight_node, position_embedding_expand = position_embedding_path
                    else:
                        logger.info("Failed to find position embedding")
                        return

        if position_embedding_shape is not None and position_embedding_shape.input[0] != input_ids:
            logger.info("position and word embeddings are expected to be applied on the same input")
            return

        segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])
        if segment_embedding_path is None:
            segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
            if segment_embedding_path is None:
                logger.info("Failed to find segment embedding")
                return
            _, segment_embedding_gather = segment_embedding_path
        else:
            segment_embedding_gather = segment_embedding_path[0]

        segment_ids = segment_embedding_gather.input[1]

        if position_embedding_expand and position_embedding_shape:
            input_parent = self.model.get_parent(position_embedding_shape, 0, output_name_to_node)
            subgraph_nodes = self.model.get_parent_subgraph_nodes(position_embedding_expand,
                                                                  [input_parent] if input_parent else [],
                                                                  output_name_to_node)
            self.nodes_to_remove.extend(subgraph_nodes)

        self.nodes_to_remove.extend(word_embedding_path)
        self.nodes_to_remove.extend(position_embedding_path)
        self.nodes_to_remove.extend(segment_embedding_path)
        self.nodes_to_remove.extend([normalize_node])

        # Store inputs for further processing.
        if self.model.find_graph_input(input_ids):
            self.model.bert_inputs = [input_ids, segment_ids
                                      ] if self.model.find_graph_input(segment_ids) else [input_ids]

        # Cast input_ids and segment_ids to int32.
        input_ids_cast_node = None
        if self.model.find_graph_input(input_ids):
            casted, input_ids = self.utils.cast_graph_input_to_int32(input_ids)
        else:
            input_ids, input_ids_cast_node = self.utils.cast_input_to_int32(input_ids)

        if self.model.find_graph_input(segment_ids):
            casted, segment_ids = self.utils.cast_graph_input_to_int32(segment_ids)
        else:
            segment_ids, segment_ids_cast_node = self.utils.cast_input_to_int32(segment_ids)

            # Cast might be removed by OnnxRuntime.
            _, segment_id_path, _ = self.model.match_parent_paths(
                segment_ids_cast_node,
                [(['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape', 'Cast'], [0, 0, 1, 0, 0, 0]),
                 (['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [0, 0, 1, 0, 0])],
                output_name_to_node)

            if segment_id_path and input_ids_cast_node and input_ids_cast_node.input[0] == segment_id_path[-1].input[0]:
                logger.debug("Simplify segment id path...")
                self.model.add_node(
                    helper.make_node('Shape', inputs=[input_ids_cast_node.input[0]], outputs=["input_shape"]))
                self.model.add_node(
                    helper.make_node('ConstantOfShape',
                                     inputs=["input_shape"],
                                     outputs=["zeros_for_input_shape"],
                                     value=helper.make_tensor("value", onnx.TensorProto.INT32, [1], [1])))
                segment_ids = "zeros_for_input_shape"

        embed_node = helper.make_node(
            'EmbedLayerNormalization',
            inputs=[
                input_ids, segment_ids, word_embedding_gather.input[0], position_embedding_weight_node.input[0],
                segment_embedding_gather.input[0], normalize_node.input[2], normalize_node.input[3]  # gamma and beta
            ],
            outputs=["embed_output", "dummy_mask_index"],
            name="EmbedLayer")

        embed_node.domain = "com.microsoft"

        # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
        for att in normalize_node.attribute:
            if att.name == 'epsilon':
                embed_node.attribute.extend([att])

        # Set default value to 1e-12 if no attribute is found.
        if len(embed_node.attribute) == 0:
            embed_node.attribute.extend([onnx.helper.make_attribute("epsilon", 1.0E-12)])

        self.model.replace_input_of_all_nodes(normalize_node.output[0], 'embed_output')
        self.nodes_to_add.append(embed_node)
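# Hypothetical usage sketch: how this fusion might be driven directly. The
# file names are placeholders, and Fusion.apply() is assumed to be the base
# class entry point that calls fuse() on each matched SkipLayerNormalization.
def _demo_fuse_embed_layer():
    import onnx
    model = OnnxModel(onnx.load("bert.onnx"))  # OnnxModel assumed importable from this package
    fusion = FusionEmbedLayerNoMask(model)
    fusion.apply()
    onnx.save(model.model, "bert_fused.onnx")  # OnnxModel is assumed to keep the ModelProto in .model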
class FusionGptAttention(Fusion):
    """
    Fuse GPT-2 Attention with past state subgraph into one Attention node.
    This does not support attention_mask graph input right now.
    """
    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", "LayerNormalization", "with past")
        # TODO: detect num_heads from graph like FusionAttention
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {}  # map from name of attention mask to the name casted to int32

    def create_attention_node(self, gemm, gemm_qkv, past, present, input, output, mask, is_unidirectional):
        attention_node_name = self.model.create_node_name('GptAttention')
        attention_node = helper.make_node('Attention',
                                          inputs=[input, gemm.input[1], gemm.input[2], mask, past],
                                          outputs=[attention_node_name + "_output", present],
                                          name=attention_node_name)
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([
            helper.make_attribute("num_heads", self.num_heads),
            helper.make_attribute("unidirectional", 1 if is_unidirectional else 0)
        ])

        matmul_node = helper.make_node('MatMul',
                                       inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
                                       outputs=[attention_node_name + "_matmul_output"],
                                       name=attention_node_name + "_matmul")

        add_node = helper.make_node('Add',
                                    inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
                                    outputs=[output],
                                    name=attention_node_name + "_add")
        self.nodes_to_add.extend([attention_node, matmul_node, add_node])

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        past = None
        present = None
        return_indice = []
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ['Add', 'Reshape', 'Gemm', 'Reshape', 'Reshape', 'Transpose', 'MatMul'],
            [0, None, 0, 0, 0, 0, 0],
            output_name_to_node=output_name_to_node,
            return_indice=return_indice)  # yapf: disable
        if qkv_nodes is None:
            return
        (add_qkv, reshape_qkv, gemm_qkv, reshape_1, reshape_2, transpose_qkv, matmul_qkv) = qkv_nodes

        another_input = add_qkv.input[1 - return_indice[0]]

        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ['Concat', 'Transpose', 'Reshape', 'Split', 'Reshape', 'Gemm', 'Reshape'],
            [1, 1, 0, 0, 0, 0, 0])  # yapf: disable
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (concat_v, transpose_v, reshape_v, split_v, reshape_after_gemm, gemm, reshape_before_gemm) = v_nodes

        #   concat <-- Gather(indices=1) <-- past
        #     |
        #   unsqueeze
        #     |
        #   concat --> present
        gather_v = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather_v.op_type != 'Gather':
            logger.info("expect Gather for past")
            return
        if not self.model.find_constant_input(gather_v, 1) == 1:
            logger.info("expect indices=1 for Gather of past")
            return
        past = gather_v.input[0]
        if not self.model.find_graph_input(past):
            logger.info("expect past to be graph input")
            return

        unsqueeze_present_v = self.model.find_first_child_by_type(concat_v,
                                                                  'Unsqueeze',
                                                                  input_name_to_nodes,
                                                                  recursive=False)
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return
        concat_present = self.model.find_first_child_by_type(unsqueeze_present_v,
                                                             'Concat',
                                                             input_name_to_nodes,
                                                             recursive=False)
        if not concat_present:
            logger.info("expect concat for present")
            return
        present = concat_present.output[0]
        if not self.model.find_graph_output(present):
            logger.info("expect present to be graph output")
            return

        layernorm_before_attention = self.model.get_parent(reshape_before_gemm, 0, output_name_to_node)
        if layernorm_before_attention is None or layernorm_before_attention.op_type != 'LayerNormalization':
            logger.debug(f"failed to get layernorm before gemm. Got {layernorm_before_attention.op_type}")
            return

        if another_input not in layernorm_before_attention.input:
            logger.debug("Add and LayerNormalization shall have one same input")
            return

        is_unidirectional = True
        slice_mask = None
        input_mask_nodes = None
        qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'],
                                                [0, 0, 0, 0, 0])
        if qk_nodes is not None:
            (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
            mask_nodes = self.model.match_parent_path(
                sub_qk,
                ['Mul', 'Sub', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
                [1, 0, 1, 0, 1, 0, 0, 0, 0, 0])  # yapf: disable
            if mask_nodes is None:
                logger.debug("fuse_attention: failed to match unidirectional mask path")
                return
            div_mask = mask_nodes[-1]
            slice_mask = mask_nodes[3]

            if div_qk != div_mask:
                logger.debug("fuse_attention: skip since div_qk != div_mask")
                return
        else:
            # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
            i, qk_nodes, _ = self.model.match_parent_paths(
                matmul_qkv,
                [(['Softmax', 'Where', 'Div', 'MatMul'], [0, 0, 1, 0]),
                 (['Softmax', 'Add', 'Where', 'Div', 'MatMul'], [0, 0, 0, 1, 0])], output_name_to_node)
            if qk_nodes is None:
                logger.debug("fuse_attention: failed to match qk nodes")
                return

            where_qk = qk_nodes[-3]
            div_qk = qk_nodes[-2]
            matmul_qk = qk_nodes[-1]

            if i == 1:
                add_qk = qk_nodes[1]
                _, input_mask_nodes, _ = self.model.match_parent_paths(
                    add_qk,
                    [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [1, 0, 1, 0, 0, 0]),
                     (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [1, 0, 1, 0, 0])], output_name_to_node)
                if input_mask_nodes is None:
                    logger.debug("fuse_attention: failed to match input attention mask path")
                    return

            mask_nodes = self.model.match_parent_path(
                where_qk,
                ['Cast', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
                [0, 0, 0, 1, 0, 0, 0, 0, 0])  # yapf: disable
            if mask_nodes is None:
                logger.debug("fuse_attention: failed to match mask path")
                return
            div_mask = mask_nodes[-1]
            slice_mask = mask_nodes[2]

            if div_qk != div_mask:
                logger.debug("fuse_attention: skip since div_qk != div_mask")
                return

        # Validate that the mask data is either lower triangular (unidirectional) or all ones.
        mask_data = numpy_helper.to_array(self.model.get_initializer(slice_mask.input[0]))
        if not (len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1)
                and mask_data.shape[2] == mask_data.shape[3]):
            logger.debug("fuse_attention: skip since mask shape is not 1x1xWxW")
            return
        if np.allclose(mask_data, np.ones_like(mask_data)):
            is_unidirectional = False
        elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
            logger.debug("fuse_attention: skip since mask is neither lower triangular nor ones")
            return

        q_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (transpose_q, reshape_q, split_q) = q_nodes
        if split_v != split_q:
            logger.debug("fuse_attention: skip since split_v != split_q")
            return

        k_nodes = self.model.match_parent_path(matmul_qk, ['Concat', 'Transpose', 'Reshape', 'Split'], [1, 1, 0, 0])
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return
        (concat_k, transpose_k, reshape_k, split_k) = k_nodes
        if split_v != split_k:
            logger.debug("fuse_attention: skip since split_v != split_k")
            return

        #   concat_k <-- Transpose (perm=0,1,3,2) <-- Gather(axes=0, indices=0) <-- past
        #      |
        #   Transpose (perm=0,1,3,2)
        #      |
        #   unsqueeze
        #      |
        #   concat --> present
        past_k_nodes = self.model.match_parent_path(concat_k, ['Transpose', 'Gather'], [0, 0])
        if past_k_nodes is None:
            logger.debug("fuse_attention: failed to match past_k_nodes path")
            return
        gather_past_k = past_k_nodes[-1]
        if not self.model.find_constant_input(gather_past_k, 0) == 1:
            logger.info("expect indices=0 for Gather k of past")
            return
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.info("expect past to be same")
            return

        attention_mask_input_name = ''
        if input_mask_nodes is not None:
            input_name = input_mask_nodes[-1].input[0]
            if input_name in self.casted_attention_mask:
                attention_mask_input_name = self.casted_attention_mask[input_name]
            elif self.model.find_graph_input(input_name):
                casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name
            else:
                attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name

        self.create_attention_node(gemm, gemm_qkv, past, present, layernorm_before_attention.output[0],
                                   reshape_qkv.output[0], attention_mask_input_name, is_unidirectional)

        # We rely on prune_graph() to clean old subgraph nodes:
        # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
        self.prune_graph = True
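# Standalone illustration of the mask-data validation in fuse() above: a
# 1x1xWxW lower-triangular constant means unidirectional (causal) attention,
# while an all-ones constant means bidirectional. The helper name and the
# default width are hypothetical.
def _demo_mask_validation(w=4):
    import numpy as np
    mask_data = np.tril(np.ones((1, 1, w, w)))
    is_unidirectional = not np.allclose(mask_data, np.ones_like(mask_data)) and \
        np.allclose(mask_data, np.tril(np.ones_like(mask_data)))
    return is_unidirectional  # True for the lower-triangular mask built above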
class FusionGptAttention(Fusion):
    """
    Fuse GPT-2 Attention with past state subgraph into one Attention node.
    This does not support attention_mask graph input right now.
    """
    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", "LayerNormalization", "with past")
        # TODO: detect num_heads from graph like FusionAttention
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {}  # map from name of attention mask to the name casted to int32

    def create_attention_node(self, fc_weight, fc_bias, gemm_qkv, past, present, input, output, mask,
                              is_unidirectional):
        attention_node_name = self.model.create_node_name('GptAttention')
        attention_node = helper.make_node('Attention',
                                          inputs=[input, fc_weight, fc_bias, mask, past],
                                          outputs=[attention_node_name + "_output", present],
                                          name=attention_node_name)
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([
            helper.make_attribute("num_heads", self.num_heads),
            helper.make_attribute("unidirectional", 1 if is_unidirectional else 0)
        ])

        matmul_node = helper.make_node('MatMul',
                                       inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
                                       outputs=[attention_node_name + "_matmul_output"],
                                       name=attention_node_name + "_matmul")

        add_node = helper.make_node('Add',
                                    inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
                                    outputs=[output],
                                    name=attention_node_name + "_add")
        self.nodes_to_add.extend([attention_node, matmul_node, add_node])
        self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
        self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
        self.node_name_to_graph_name[add_node.name] = self.this_graph_name

    def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node):
        # Pattern 1:
        #                          {past}
        #                         /      \
        #                        /        \
        #    Gather(axes=0, indices=0)   Gather(indices=1)
        #          |                          |
        #    Transpose (perm=0,1,3,2)         |
        #          |                          |
        #       Concat_k                   Concat_v
        #          |                         /
        #    Transpose (perm=0,1,3,2)       /
        #          |                       /
        #      Unsqueeze           Unsqueeze
        #           \                 /
        #            \               /
        #                 Concat
        #                   |
        #               {present}
        gather = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather.op_type != 'Gather':
            logger.debug("match_past_pattern_1: expect Gather for past")
            return None

        if not self.model.find_constant_input(gather, 1) == 1:
            logger.debug("match_past_pattern_1: expect indices=1 for Gather of past")
            return None
        past = gather.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k, ['Transpose', 'Gather'], [0, 0])
        if past_k_nodes is None:
            logger.debug("match_past_pattern_1: failed match Transpose and Gather")
            return None
        gather_past_k = past_k_nodes[-1]

        if not self.model.find_constant_input(gather_past_k, 0) == 1:
            logger.debug("match_past_pattern_1: expect indices=0 for Gather k of past")
            return None
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.debug("match_past_pattern_1: expect past to be same")
            return None

        return past

    def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node):
        # Pattern 2:
        #       Split (QKV)
        #      /    |     |
        #     /     |     +----------------------+
        #           |                             |
        #           |          {past}             |
        #           |            |                |
        #        Reshape       Split           Reshape
        #           |         /     \             |
        #   Transpose_k   Squeeze  Squeeze   Transpose_v
        #           |        |        \          /
        #           +--------|---+     \        /
        #                    |   |      \      /
        #                  Concat_k    Concat_v
        #                     |           |
        #                Unsqueeze    Unsqueeze
        #                      \         /
        #                        Concat
        #                          |
        #                      {present}
        #
        squeeze = self.model.get_parent(concat_v, 0, output_name_to_node)
        if squeeze.op_type != 'Squeeze':
            logger.debug("match_past_pattern_2: expect Squeeze as parent of concat_v")
            return None

        split = self.model.get_parent(squeeze, 0, output_name_to_node)
        if split.op_type != "Split":
            logger.debug("match_past_pattern_2: expect Split for past path")
            return None

        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(squeeze, 'axes', [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not FusionUtils.check_node_attribute(split, 'split', [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None
        else:
            if not self.utils.check_node_input_value(squeeze, 1, [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not self.utils.check_node_input_value(split, 1, [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None

        if not FusionUtils.check_node_attribute(split, 'axis', 0, default_value=0):
            logger.debug("match_past_pattern_2: attribute axis of Split is not expected in past path")
            return None

        past = split.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k, ['Squeeze', 'Split'], [0, 0])
        if past_k_nodes is None:
            logger.debug("match_past_pattern_2: failed to match past_k_nodes path")
            return None
        past_k = past_k_nodes[-1].input[0]

        if past != past_k:
            logger.info("match_past_pattern_2: expect past to be same")
            return None

        return past

    def match_present(self, concat_v, input_name_to_nodes):
        unsqueeze_present_v = self.model.find_first_child_by_type(concat_v,
                                                                  'Unsqueeze',
                                                                  input_name_to_nodes,
                                                                  recursive=False)
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return None
        concat_present = self.model.find_first_child_by_type(unsqueeze_present_v,
                                                             'Concat',
                                                             input_name_to_nodes,
                                                             recursive=False)
        if not concat_present:
            logger.info("expect concat for present")
            return None

        present = concat_present.output[0]
        return present

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        past = None
        present = None
        return_indice = []
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ['Add', 'Reshape', 'Gemm', 'Reshape', 'Reshape', 'Transpose', 'MatMul'],
            [0, None, 0, 0, 0, 0, 0],
            output_name_to_node=output_name_to_node,
            return_indice=return_indice)  # yapf: disable
        if qkv_nodes is None:
            return
        (add_qkv, reshape_qkv, gemm_qkv, reshape_1, reshape_2, transpose_qkv, matmul_qkv) = qkv_nodes

        another_input = add_qkv.input[1 - return_indice[0]]

        v_nodes = self.model.match_parent_path(matmul_qkv, ['Concat', 'Transpose', 'Reshape', 'Split'], [1, 1, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (concat_v, transpose_v, reshape_v, split_fc) = v_nodes

        fc_nodes = self.model.match_parent_path(split_fc, ['Reshape', 'Gemm', 'Reshape', 'LayerNormalization'],
                                                [0, 0, 0, 0], output_name_to_node)
        if fc_nodes is None:
            fc_nodes = self.model.match_parent_path(split_fc, ['Add', 'MatMul', 'LayerNormalization'], [0, None, 0],
                                                    output_name_to_node)
            if fc_nodes is None:
                logger.debug("fuse_attention: failed to match fc path")
                return
            fc_weight = fc_nodes[1].input[1]
            i, _ = self.model.get_constant_input(fc_nodes[0])
            fc_bias = fc_nodes[0].input[i]
        else:
            fc_weight = fc_nodes[1].input[1]
            fc_bias = fc_nodes[1].input[2]

        layernorm_before_attention = fc_nodes[-1]

        if another_input not in layernorm_before_attention.input:
            logger.debug("Add and LayerNormalization shall have one same input")
            return

        is_unidirectional = True
        slice_mask = None
        input_mask_nodes = None
        concat_k_to_match = None
        qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'],
                                                [0, 0, 0, 0, 0])
        if qk_nodes is not None:
            (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
            mask_nodes = self.model.match_parent_path(
                sub_qk,
                ['Mul', 'Sub', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
                [1, 0, 1, 0, 1, 0, 0, 0, 0, 0])  # yapf: disable
            if mask_nodes is None:
                logger.debug("fuse_attention: failed to match unidirectional mask path")
                return
            div_mask = mask_nodes[-1]
            slice_mask = mask_nodes[3]

            if div_qk != div_mask:
                logger.debug("fuse_attention: skip since div_qk != div_mask")
                return
        else:
            # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
            i, qk_nodes, _ = self.model.match_parent_paths(
                matmul_qkv,
                [(['Softmax', 'Where', 'Div', 'MatMul'], [0, 0, 1, 0]),
                 (['Softmax', 'Add', 'Where', 'Div', 'MatMul'], [0, 0, None, 1, 0])], output_name_to_node)
            if qk_nodes is None:
                logger.debug("fuse_attention: failed to match qk nodes")
                return

            where_qk = qk_nodes[-3]
            div_qk = qk_nodes[-2]
            matmul_qk = qk_nodes[-1]

            if i == 1:
                add_qk = qk_nodes[1]
                _, input_mask_nodes, _ = self.model.match_parent_paths(
                    add_qk,
                    [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0, 0]),
                     (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0])], output_name_to_node)
                if input_mask_nodes is None:
                    logger.debug("fuse_attention: failed to match input attention mask path")
                    return

            mask_nodes = self.model.match_parent_path(
                where_qk,
                ['Cast', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape'],
                [0, 0, 0, 1, 0, 0, 0, 0],
                output_name_to_node)  # yapf: disable
            if mask_nodes is None:
                # TODO: match mask path for GPT2LMHeadModel_BeamSearchStep.
                logger.debug("fuse_attention: failed to match mask path")
                return
            slice_mask = mask_nodes[2]

            div_or_concat = self.model.get_parent(mask_nodes[-1], 0, output_name_to_node)
            if div_or_concat.op_type == "Div":
                div_mask = div_or_concat
                if div_qk != div_mask:
                    logger.debug("fuse_attention: skip since div_qk != div_mask")
                    return
            elif div_or_concat.op_type == "Concat":
                concat_k_to_match = div_or_concat
            else:
                logger.debug("fuse_attention: failed to match mask path")

        # Validate that the mask data is either lower triangular (unidirectional) or all ones.
        mask_data = numpy_helper.to_array(self.model.get_initializer(slice_mask.input[0]))
        if not (len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1)
                and mask_data.shape[2] == mask_data.shape[3]):
            logger.debug("fuse_attention: skip since mask shape is not 1x1xWxW")
            return
        if np.allclose(mask_data, np.ones_like(mask_data)):
            is_unidirectional = False
        elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
            logger.debug("fuse_attention: skip since mask is neither lower triangular nor ones")
            return

        q_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (transpose_q, reshape_q, split_q) = q_nodes
        if split_fc != split_q:
            logger.debug("fuse_attention: skip since split_fc != split_q")
            return

        k_nodes = self.model.match_parent_path(matmul_qk, ['Concat', 'Transpose', 'Reshape', 'Split'], [1, 1, 0, 0])
        if k_nodes is None:
            # This pattern is from pytorch 1.7.1 and transformers 4.6.1
            k_nodes = self.model.match_parent_path(matmul_qk,
                                                   ['Transpose', 'Concat', 'Transpose', 'Reshape', 'Split'],
                                                   [1, 0, 1, 0, 0])
            if k_nodes is None:
                logger.debug("fuse_attention: failed to match k path")
                return
            (_, concat_k, transpose_k, reshape_k, split_k) = k_nodes
        else:
            (concat_k, transpose_k, reshape_k, split_k) = k_nodes
        if split_fc != split_k:
            logger.debug("fuse_attention: skip since split_fc != split_k")
            return

        if concat_k_to_match and concat_k != concat_k_to_match:
            logger.debug("fuse_attention: skip since concat_k != concat_k_to_match")
            return

        attention_mask_input_name = ''
        if input_mask_nodes is not None:
            input_name = input_mask_nodes[-1].input[0]
            if input_name in self.casted_attention_mask:
                attention_mask_input_name = self.casted_attention_mask[input_name]
            elif self.model.find_graph_input(input_name):
                casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name
            else:
                attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name

        # Match past and present paths.
        past = self.match_past_pattern_1(concat_k, concat_v, output_name_to_node) or \
            self.match_past_pattern_2(concat_k, concat_v, output_name_to_node)
        if past is None:
            logger.info("fuse_attention: failed to match past path")
            return
        if not self.model.find_graph_input(past):
            logger.debug("past is not graph input.")
            # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select beam index so it is not graph input.

        present = self.match_present(concat_v, input_name_to_nodes)
        if present is None:
            logger.info("fuse_attention: failed to match present path")
            return
        if not self.model.find_graph_output(present):
            logger.info("expect present to be graph output")
            return

        self.create_attention_node(fc_weight, fc_bias, gemm_qkv, past, present,
                                   layernorm_before_attention.output[0], reshape_qkv.output[0],
                                   attention_mask_input_name, is_unidirectional)

        # We rely on prune_graph() to clean old subgraph nodes:
        # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
        self.prune_graph = True
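# Minimal sketch of the fused node that create_attention_node builds: a
# com.microsoft Attention node wired as [input, weight, bias, mask, past] ->
# [output, present]. All tensor names and the default num_heads are
# placeholders, not taken from a real graph.
def _demo_fused_attention_node(num_heads=12):
    from onnx import helper
    return helper.make_node('Attention',
                            inputs=['hidden_states', 'fc_weight', 'fc_bias', 'mask', 'past'],
                            outputs=['attention_output', 'present'],
                            name='GptAttention_example',
                            domain='com.microsoft',
                            num_heads=num_heads,   # becomes an int attribute
                            unidirectional=1)      # 1 = causal attention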
class FusionEmbedLayerNoMask(Fusion):
    """
    Embed Layer Normalization will fuse embeddings and mask processing into one node.
    The embeddings before conversion:

    (input_ids) --------> Gather ----------+       (segment_ids)
       |                                   |            |
       |                                   v            v
       +--> Shape --> Expand -> Gather---->Add        Gather
       |                ^                  |            |
       |                |                  v            v
       +---(optional graph)          SkipLayerNormalization

    The optional graph is used to generate a position list (0, 1, ...) per batch. It can be a constant in some models.

    (input_ids) --> Gather -----+    Slice
                                |      |
                                v      v
    (segment_ids)--> Gather --->Add  Reshape
                                |      |
                                v      v
                        SkipLayerNormalization
    """
    def __init__(self, model: OnnxModel, description='no mask'):
        super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
        self.utils = FusionUtils(model)
        self.attention = None

    def match_segment_path(self, normalize_node, input_name_to_nodes, output_name_to_node, input_ids_cast_node):
        segment_ids = None
        segment_embedding_gather = None

        segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])
        if segment_embedding_path is None:
            segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
            if segment_embedding_path is None:
                logger.info("Segment embedding is not found. Embed layer cannot be fused.")
                return
            _, segment_embedding_gather = segment_embedding_path
        else:
            segment_embedding_gather = segment_embedding_path[0]

        segment_ids = segment_embedding_gather.input[1]

        self.nodes_to_remove.extend(segment_embedding_path)

        if self.model.find_graph_input(segment_ids):
            casted, segment_ids = self.utils.cast_graph_input_to_int32(segment_ids)
        else:
            segment_ids, segment_ids_cast_node = self.utils.cast_input_to_int32(segment_ids)

            # Cast might be removed by OnnxRuntime.
            _, segment_id_path, _ = self.model.match_parent_paths(
                segment_ids_cast_node,
                [(['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape', 'Cast'], [0, 0, 1, 0, 0, 0]),
                 (['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [0, 0, 1, 0, 0])],
                output_name_to_node)

            if segment_id_path and input_ids_cast_node and input_ids_cast_node.input[0] == segment_id_path[-1].input[0]:
                logger.debug("Simplify segment id path...")
                self.model.add_node(
                    helper.make_node('Shape', inputs=[input_ids_cast_node.input[0]], outputs=["input_shape"]))
                self.model.add_node(
                    helper.make_node('ConstantOfShape',
                                     inputs=["input_shape"],
                                     outputs=["zeros_for_input_shape"],
                                     value=helper.make_tensor("value", onnx.TensorProto.INT32, [1], [1])))
                segment_ids = "zeros_for_input_shape"

        return segment_ids, segment_embedding_gather

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        is_distill = False

        if self.model.match_parent_path(node, ['Add', 'Gather'], [0, 0]) is None and \
                self.model.match_parent_path(node, ['Gather'], [0]) is None:
            logger.debug(
                "Failed to match path SkipLayerNormalization[0] <-- Add <-- Gather or SkipLayerNormalization[0] <-- Gather")
            return

        self.attention = self.model.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False)
        if self.attention is None:
            # In case user disables attention fusion, check whether subgraph looks like Attention.
            if node.output[0] not in input_name_to_nodes:
                return
            children = input_name_to_nodes[node.output[0]]
            children_types = sorted([child.op_type for child in children])
            if children_types != ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization'] and \
                    children_types != ['MatMul', 'MatMul', 'MatMul', 'Shape', 'Shape', 'SkipLayerNormalization']:
                logger.debug("No Attention-like subgraph in children of SkipLayerNormalization")
                return

        # Assume the order of embeddings is word_embedding + position_embedding + segment_embedding.
        normalize_node = node
        add_node = None
        word_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 0])
        if word_embedding_path is not None:
            add_node, word_embedding_gather = word_embedding_path
        else:
            word_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [0])
            if word_embedding_path is not None:
                word_embedding_gather = word_embedding_path[0]
                is_distill = True
            else:
                logger.info("Word embedding path is not found. Embed layer cannot be fused.")
                return

        input_ids = word_embedding_gather.input[1]

        position_embedding_expand = None
        position_embedding_shape = None

        position_embedding_path = self.model.match_parent_path(normalize_node, ['Gather', 'Expand'],
                                                               [1, 1])  # for distill-bert
        if position_embedding_path is not None:
            position_embedding_weight_node, position_embedding_expand = position_embedding_path
        else:
            position_embedding_path = self.model.match_parent_path(normalize_node, ['Reshape', 'Slice'], [1, 0])
            if position_embedding_path is not None:
                _, position_embedding_weight_node = position_embedding_path
            else:
                position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand', 'Shape'],
                                                                       [1, 1, 1])
                if position_embedding_path is not None:
                    position_embedding_weight_node, position_embedding_expand, position_embedding_shape = position_embedding_path
                else:
                    position_embedding_path = self.model.match_parent_path(
                        add_node, ['Gather', 'Expand', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [1, 1, 1, 1, 0, 0])
                    if position_embedding_path is not None:
                        position_embedding_weight_node, position_embedding_expand, _, _, _, position_embedding_shape = position_embedding_path
                    else:
                        # Here we will not try to get an exact match. Instead, we only try to identify position embedding weights.
                        position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand'], [1, 1])
                        if position_embedding_path is not None:
                            position_embedding_weight_node, position_embedding_expand = position_embedding_path
                        else:
                            logger.info("Position embedding path is not found. Embed layer cannot be fused.")
                            return

        if position_embedding_shape is not None and position_embedding_shape.input[0] != input_ids:
            logger.info("position and word embeddings are expected to be applied on the same input")
            return

        if position_embedding_expand and position_embedding_shape:
            input_parent = self.model.get_parent(position_embedding_shape, 0, output_name_to_node)
            subgraph_nodes = self.model.get_parent_subgraph_nodes(position_embedding_expand,
                                                                  [input_parent] if input_parent else [],
                                                                  output_name_to_node)
            self.nodes_to_remove.extend(subgraph_nodes)

        self.nodes_to_remove.extend(word_embedding_path)
        self.nodes_to_remove.extend(position_embedding_path)
        self.nodes_to_remove.extend([normalize_node])

        # Cast input_ids and segment_ids to int32.
        input_ids_cast_node = None
        if self.model.find_graph_input(input_ids):
            casted, input_ids = self.utils.cast_graph_input_to_int32(input_ids)
        else:
            input_ids, input_ids_cast_node = self.utils.cast_input_to_int32(input_ids)

        node_name = self.model.create_node_name('EmbedLayerNormalization')
        output_name = node_name + "_output"

        embed_node_inputs = None
        if not is_distill:
            segment_path = self.match_segment_path(normalize_node, input_name_to_nodes, output_name_to_node,
                                                   input_ids_cast_node)
            if segment_path is None:
                return
            segment_ids, segment_embedding_gather = segment_path
            embed_node_inputs = [
                input_ids, segment_ids, word_embedding_gather.input[0], position_embedding_weight_node.input[0],
                segment_embedding_gather.input[0], normalize_node.input[2], normalize_node.input[3]  # gamma and beta
            ]
        else:
            from packaging.version import Version
            import onnxruntime
            if Version(onnxruntime.__version__) <= Version("1.4.0"):
                logger.warning(
                    'Please install onnxruntime with version > 1.4.0 for embedlayer fusion support for distilbert')
                return
            embed_node_inputs = [
                input_ids, '', word_embedding_gather.input[0], position_embedding_weight_node.input[0], '',
                normalize_node.input[2], normalize_node.input[3]  # gamma and beta
            ]

        embed_node = helper.make_node('EmbedLayerNormalization',
                                      embed_node_inputs,
                                      outputs=[node_name + "_output", node_name + "_dummy_mask_index"],
                                      name=node_name)

        embed_node.domain = "com.microsoft"

        # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
        for att in normalize_node.attribute:
            if att.name == 'epsilon':
                embed_node.attribute.extend([att])

        # Set default value to 1e-12 if no attribute is found.
        # OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later.
        if len(embed_node.attribute) == 0:
            embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0E-12)])

        self.model.replace_input_of_all_nodes(normalize_node.output[0], output_name)
        self.nodes_to_add.append(embed_node)
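# Sketch of the distilbert-style EmbedLayerNormalization node built above:
# the optional segment_ids and segment embedding inputs are passed as empty
# strings (omitted optional inputs in ONNX), and epsilon becomes a node
# attribute. All tensor names below are placeholders, not from a real graph.
def _demo_embed_node_distill():
    from onnx import helper
    return helper.make_node('EmbedLayerNormalization',
                            inputs=['input_ids', '', 'word_weight', 'position_weight', '', 'gamma', 'beta'],
                            outputs=['embed_output', 'dummy_mask_index'],
                            name='EmbedLayerNormalization_example',
                            domain='com.microsoft',
                            epsilon=1e-12)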