Example #1
class AttentionMask:
    """
    Fuse Attention subgraph into one Attention node.
    """
    def __init__(self, model: OnnxModel):
        self.model = model
        # A lookup table with mask input as key, and mask index output as value
        self.mask_indice = {}
        # A lookup table with mask input as key, and cast (to int32) output as value
        self.mask_casted = {}
        self.utils = FusionUtils(model)
        self.mask_format = AttentionMaskFormat.MaskIndexEnd

    def set_mask_format(self, mask_format: AttentionMaskFormat):
        self.mask_format = mask_format

    def set_mask_indice(self, mask, mask_index):
        if mask in self.mask_indice:
            assert mask_index == self.mask_indice[mask]
        self.mask_indice[mask] = mask_index

    def get_first_mask(self):
        assert len(self.mask_indice) > 0
        return next(iter(self.mask_indice))

    def process_mask(self, input):
        if input in self.mask_indice:
            return self.mask_indice[input]

        # Add cast to convert int64 to int32
        if self.model.find_graph_input(input):
            casted, input_name = self.utils.cast_graph_input_to_int32(input)
        else:
            input_name, cast_node = self.utils.cast_input_to_int32(input)
            casted = True

        if casted:
            self.mask_casted[input] = input_name

        # Attention supports int32 attention mask (2D) since 1.4.0
        if self.mask_format == AttentionMaskFormat.AttentionMask:
            self.mask_indice[input] = input_name
            return input_name

        # Add a mask processing node to convert attention mask to mask index (1D)
        output_name = self.model.create_node_name('mask_index')
        mask_index_node = helper.make_node('ReduceSum',
                                           inputs=[input_name],
                                           outputs=[output_name],
                                           name=self.model.create_node_name(
                                               'ReduceSum', 'MaskReduceSum'))
        mask_index_node.attribute.extend([
            helper.make_attribute("axes", [1]),
            helper.make_attribute("keepdims", 0)
        ])
        self.model.add_node(mask_index_node)

        self.mask_indice[input] = output_name
        return output_name
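
A small standalone sketch of what process_mask produces when the mask format is MaskIndexEnd (the data below is illustrative): cast the int64 mask to int32, then ReduceSum over axis 1 to compute the 1D mask index.

import numpy as np

# For a right-padded attention mask, summing each row gives the 1D "mask index"
# (the count of attended tokens), which is what the inserted ReduceSum node computes.
mask = np.array([[1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]], dtype=np.int32)
mask_index = mask.sum(axis=1)  # equivalent to ReduceSum(axes=[1], keepdims=0)
print(mask_index)              # [3 2]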
Example #2
class FusionGptAttentionPastBase(Fusion):
    """Base class for GPT Attention Fusion with past state
    """
    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", "LayerNormalization", "with past")
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {}  # map from name of attention mask to the name casted to int32

    def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node):
        # Pattern 1:
        #                      {past}
        #                    /        \
        #                   /          \
        #    Gather(axes=0, indices=0)  Gather(indices=1)
        #      |                          |
        #    Transpose (perm=0,1,3,2)     |
        #      |                          |
        #  Concat_k                     Concat_v
        #      |                        /
        #  Transpose (perm=0,1,3,2)    /
        #      |                      /
        #  Unsqueeze        Unsqueeze
        #        \        /
        #         \      /
        #           Concat
        #             |
        #         {present}
        gather = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather.op_type != 'Gather':
            logger.debug("match_past_pattern_1: expect Gather for past")
            return None

        if self.model.find_constant_input(gather, 1) != 1:
            logger.debug(
                "match_past_pattern_1: expect indices=1 for Gather of past")
            return None
        past = gather.input[0]

        parent = self.model.get_parent(concat_k, 0, output_name_to_node)
        if parent.op_type == 'Gather':
            gather_past_k = parent
        else:
            past_k_nodes = self.model.match_parent_path(
                concat_k, ['Transpose', 'Gather'], [0, 0])
            if past_k_nodes is None:
                logger.debug(
                    "match_past_pattern_1: failed match Transpose and Gather")
                return None
            gather_past_k = past_k_nodes[-1]

        if self.model.find_constant_input(gather_past_k, 0) != 1:
            logger.debug(
                "match_past_pattern_1: expect indices=0 for Gather k of past")
            return None
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.debug("match_past_pattern_1: expect past to be same")
            return None

        return past

    def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node):
        # Pattern 2:
        #      Split (QKV)
        #      / |   |
        #     /  |   +----------------------+
        #        |                          |
        #        |         {past}           |
        #        |           |              |
        #      Reshape     Split         Reshape
        #        |         /    \           |
        # Transpose_k  Squeeze  Squeeze  Transpose_v
        #        |      |        \        /
        #        +------|---+     \      /
        #               |   |      \    /
        #              Concat_k   Concat_v
        #               |            |
        #          Unsqueeze    Unsqueeze
        #                \       /
        #                 Concat
        #                   |
        #               {present}
        #
        squeeze = self.model.get_parent(concat_v, 0, output_name_to_node)
        if squeeze.op_type != 'Squeeze':
            logger.debug(
                "match_past_pattern_2: expect Squeeze as parent of concat_v")
            return None

        split = self.model.get_parent(squeeze, 0, output_name_to_node)
        if split.op_type != "Split":
            logger.debug("match_past_pattern_2: expect Split for past path")
            return None

        # Opset 13 moved the "axes" of Squeeze and the "split" of Split from attributes to inputs.
        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(squeeze, 'axes', [0]):
                logger.debug(
                    "match_past_pattern_2: axes != [0] for Squeeze in past path"
                )
                return None

            if not FusionUtils.check_node_attribute(split, 'split', [1, 1]):
                logger.debug(
                    "match_past_pattern_2: split != [1, 1] for Split in past path"
                )
                return None
        else:
            if not self.utils.check_node_input_value(squeeze, 1, [0]):
                logger.debug(
                    "match_past_pattern_2: axes != [0] for Squeeze in past path"
                )
                return None

            if not self.utils.check_node_input_value(split, 1, [1, 1]):
                logger.debug(
                    "match_past_pattern_2: split != [1, 1] for Split in past path"
                )
                return None

        if not FusionUtils.check_node_attribute(
                split, 'axis', 0, default_value=0):
            logger.debug(
                "match_past_pattern_2: attribute axis of Split are not expected in past path"
            )
            return None
        past = split.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k,
                                                    ['Squeeze', 'Split'],
                                                    [0, 0])
        if past_k_nodes is None:
            logger.debug(
                "match_past_pattern_2: failed to match past_k_nodes path")
            return None
        past_k = past_k_nodes[-1].input[0]

        if past != past_k:
            logger.info("match_past_pattern_2: expect past to be same")
            return None

        return past

    def match_present(self, concat_v, input_name_to_nodes):
        unsqueeze_present_v = self.model.find_first_child_by_type(
            concat_v, 'Unsqueeze', input_name_to_nodes, recursive=False)
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return None
        concat_present = self.model.find_first_child_by_type(
            unsqueeze_present_v,
            'Concat',
            input_name_to_nodes,
            recursive=False)
        if not concat_present:
            logger.info("expect concat for present")
            return None

        present = concat_present.output[0]
        return present

    def cast_attention_mask(self, input_name):
        if input_name in self.casted_attention_mask:
            attention_mask_input_name = self.casted_attention_mask[input_name]
        elif self.model.find_graph_input(input_name):
            casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(
                input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        else:
            attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(
                input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        return attention_mask_input_name
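
A standalone sketch of the opset difference that match_past_pattern_2 has to handle (node and tensor names here are illustrative): before opset 13, Squeeze carries its axes as a node attribute; from opset 13 on, axes is a second input, which is why the code checks an attribute in one case and an input value in the other.

from onnx import helper

# opset < 13: "axes" is an attribute of the node.
squeeze_old = helper.make_node('Squeeze', inputs=['past_k_4d'], outputs=['past_k'], axes=[0])

# opset >= 13: "axes" is a tensor input (an initializer named 'axes_0' here).
squeeze_new = helper.make_node('Squeeze', inputs=['past_k_4d', 'axes_0'], outputs=['past_k'])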
Example #3
class FusionEmbedLayerNoMask(Fusion):
    """
     Embed Layer Normalization will fuse embeddings and mask processing into one node.
     The embeddings before conversion:

     (input_ids) -------->  Gather ----------+       (segment_ids)
        |                                    |            |
        |                                    v            v
        +--> Shape --> Expand -> Gather---->Add         Gather
        |                ^                   |            |
        |                |                   v            v
        +---(optional graph)               SkipLayerNormalization

      The optional graph is used to generate the position list (0, 1, ...) per batch. It can be a constant in some models.

      (input_ids) --> Gather -----+           Slice
                                  |            |
                                  v            v
     (segment_ids)--> Gather --->Add        Reshape
                                  |            |
                                  v            v
                              SkipLayerNormalization
    """
    def __init__(self,
                 model: OnnxModel,
                 description='no mask'):
        super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
        self.utils = FusionUtils(model)

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        # Already fused. This assumes that there is only one embedding layer in a transformer model.
        if self.nodes_to_add:
            return

        if self.model.match_parent_path(node, ['Add', 'Gather'], [0, 0]) is None:
            return

        if self.model.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False) is None:
            # In case user disables attention fusion, check whether subgraph looks like Attention.
            if node.output[0] not in input_name_to_nodes:
                return
            children = input_name_to_nodes[node.output[0]]
            children_types = sorted([child.op_type for child in children])
            if children_types != ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization']:
                return

        # Assume the order of embeddings are word_embedding + position_embedding + segment_embedding
        normalize_node = node
        word_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 0])
        if word_embedding_path is None:
            logger.info("Failed to find word embedding")
            return
        add_node, word_embedding_gather = word_embedding_path
        input_ids = word_embedding_gather.input[1]

        position_embedding_expand = None
        position_embedding_shape = None

        position_embedding_path = self.model.match_parent_path(normalize_node, ['Reshape', 'Slice'], [1, 0])
        if position_embedding_path is not None:
            _, position_embedding_weight_node = position_embedding_path
        else:
            position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand', 'Shape'], [1, 1, 1])
            if position_embedding_path is not None:
                position_embedding_weight_node, position_embedding_expand, position_embedding_shape = position_embedding_path
            else:
                position_embedding_path = self.model.match_parent_path(
                    add_node, ['Gather', 'Expand', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [1, 1, 1, 1, 0, 0])
                if position_embedding_path is not None:
                    position_embedding_weight_node, position_embedding_expand, _, _, _, position_embedding_shape = position_embedding_path
                else:
                    # Here we will not try to get an exact match. Instead, we only try to identify the position embedding weights.
                    position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand'], [1, 1])
                    if position_embedding_path is not None:
                        position_embedding_weight_node, position_embedding_expand = position_embedding_path
                    else:
                        logger.info("Failed to find position embedding")
                        return

            if position_embedding_shape is not None and position_embedding_shape.input[0] != input_ids:
                logger.info("position and word embedding is expected to be applied on same input")
                return

        segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])
        if segment_embedding_path is None:
            segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
            if segment_embedding_path is None:
                logger.info("Failed to find segment embedding")
                return
            _, segment_embedding_gather = segment_embedding_path
        else:
            segment_embedding_gather = segment_embedding_path[0]

        segment_ids = segment_embedding_gather.input[1]

        if position_embedding_expand and position_embedding_shape:
            input_parent = self.model.get_parent(position_embedding_shape, 0, output_name_to_node)
            subgraph_nodes = self.model.get_parent_subgraph_nodes(position_embedding_expand,
                                                                  [input_parent] if input_parent else [],
                                                                  output_name_to_node)
            self.nodes_to_remove.extend(subgraph_nodes)

        self.nodes_to_remove.extend(word_embedding_path)
        self.nodes_to_remove.extend(position_embedding_path)
        self.nodes_to_remove.extend(segment_embedding_path)

        self.nodes_to_remove.extend([normalize_node])

        # store inputs for further processing
        if self.model.find_graph_input(input_ids):
            self.model.bert_inputs = ([input_ids, segment_ids]
                                      if self.model.find_graph_input(segment_ids) else [input_ids])

        # Cast input_ids and segment_ids to int32.
        input_ids_cast_node = None
        if self.model.find_graph_input(input_ids):
            casted, input_ids = self.utils.cast_graph_input_to_int32(input_ids)
        else:
            input_ids, input_ids_cast_node = self.utils.cast_input_to_int32(input_ids)

        if self.model.find_graph_input(segment_ids):
            casted, segment_ids = self.utils.cast_graph_input_to_int32(segment_ids)
        else:
            segment_ids, segment_ids_cast_node = self.utils.cast_input_to_int32(segment_ids)

            # Cast might be removed by OnnxRuntime.
            _, segment_id_path, _ = self.model.match_parent_paths(
                segment_ids_cast_node,
                [(['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape', 'Cast'], [0, 0, 1, 0, 0, 0]),
                 (['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [0, 0, 1, 0, 0])], output_name_to_node)

            if segment_id_path and input_ids_cast_node and input_ids_cast_node.input[0] == segment_id_path[-1].input[0]:
                logger.debug("Simplify semgent id path...")
                self.model.add_node(
                    helper.make_node('Shape', inputs=[input_ids_cast_node.input[0]], outputs=["input_shape"]))
                self.model.add_node(
                    helper.make_node('ConstantOfShape',
                                     inputs=["input_shape"],
                                     outputs=["zeros_for_input_shape"],
                                     value=helper.make_tensor("value", onnx.TensorProto.INT32, [1], [1])))
                segment_ids = "zeros_for_input_shape"

        embed_node = helper.make_node(
            'EmbedLayerNormalization',
            inputs=[
                input_ids,
                segment_ids,
                word_embedding_gather.input[0],
                position_embedding_weight_node.input[0],
                segment_embedding_gather.input[0],
                normalize_node.input[2],
                normalize_node.input[3]  # gamma and beta
            ],
            outputs=["embed_output", "dummy_mask_index"],
            name="EmbedLayer")

        embed_node.domain = "com.microsoft"

        # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
        for att in normalize_node.attribute:
            if att.name == 'epsilon':
                embed_node.attribute.extend([att])
        # Set default value to 1e-12 if no attribute is found.
        if len(embed_node.attribute) == 0:
            embed_node.attribute.extend([onnx.helper.make_attribute("epsilon", 1.0E-12)])

        self.model.replace_input_of_all_nodes(normalize_node.output[0], 'embed_output')
        self.nodes_to_add.append(embed_node)
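
For reference, a numpy sketch (shapes assumed and illustrative) of what the fused EmbedLayerNormalization node computes: the sum of the word, position, and segment embedding lookups, followed by layer normalization with the gamma/beta inherited from the SkipLayerNormalization node.

import numpy as np

def embed_layer_norm(input_ids, segment_ids, word_w, pos_w, seg_w, gamma, beta, epsilon=1e-12):
    # input_ids, segment_ids: (batch, seq_len); word_w/pos_w/seg_w: embedding tables of shape (vocab, hidden).
    position_ids = np.arange(input_ids.shape[1])   # position list (0, 1, ...) per batch
    x = word_w[input_ids] + pos_w[position_ids] + seg_w[segment_ids]
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + epsilon) + beta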
Example #4
class FusionGptAttention(Fusion):
    """
    Fuse GPT-2 Attention with past state subgraph into one Attention node.
    This does not support attention_mask graph input right now.
    """
    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", "LayerNormalization", "with past")
        # TODO: detect num_heads from graph like FusionAttention
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {}  # map from name of attention mask to the name casted to int32

    def create_attention_node(self, gemm, gemm_qkv, past, present, input,
                              output, mask, is_unidirectional):
        attention_node_name = self.model.create_node_name('GptAttention')
        attention_node = helper.make_node(
            'Attention',
            inputs=[input, gemm.input[1], gemm.input[2], mask, past],
            outputs=[attention_node_name + "_output", present],
            name=attention_node_name)
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([
            helper.make_attribute("num_heads", self.num_heads),
            helper.make_attribute("unidirectional",
                                  1 if is_unidirectional else 0)
        ])

        matmul_node = helper.make_node(
            'MatMul',
            inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
            outputs=[attention_node_name + "_matmul_output"],
            name=attention_node_name + "_matmul")

        add_node = helper.make_node(
            'Add',
            inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
            outputs=[output],
            name=attention_node_name + "_add")
        self.nodes_to_add.extend([attention_node, matmul_node, add_node])

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        past = None
        present = None
        return_indice = []
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ['Add', 'Reshape', 'Gemm', 'Reshape', 'Reshape', 'Transpose', 'MatMul'],
            [0,      None,      0,     0,          0,         0,           0],
            output_name_to_node=output_name_to_node,
            return_indice=return_indice
            ) # yapf: disable
        if qkv_nodes is None:
            return
        (add_qkv, reshape_qkv, gemm_qkv, reshape_1, reshape_2, transpose_qkv,
         matmul_qkv) = qkv_nodes

        another_input = add_qkv.input[1 - return_indice[0]]

        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ['Concat', 'Transpose', 'Reshape', 'Split', 'Reshape', 'Gemm', 'Reshape'],
            [1,        1,            0,         0,       0,         0,      0]) # yapf: disable
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (concat_v, transpose_v, reshape_v, split_v, reshape_after_gemm, gemm,
         reshape_before_gemm) = v_nodes

        #      concat <-- Gather(indices=1) <-- past
        #        |
        #      unsqueeze
        #        |
        #    concat  -->  present
        gather_v = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather_v.op_type != 'Gather':
            logger.info("expect Gather for past")
            return
        if self.model.find_constant_input(gather_v, 1) != 1:
            logger.info("expect indices=1 for Gather of past")
            return
        past = gather_v.input[0]
        if not self.model.find_graph_input(past):
            logger.info("expect past to be graph input")
            return
        unsqueeze_present_v = self.model.find_first_child_by_type(
            concat_v, 'Unsqueeze', input_name_to_nodes, recursive=False)
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return
        concat_present = self.model.find_first_child_by_type(
            unsqueeze_present_v,
            'Concat',
            input_name_to_nodes,
            recursive=False)
        if not concat_present:
            logger.info("expect concat for present")
            return
        present = concat_present.output[0]
        if not self.model.find_graph_output(present):
            logger.info("expect present to be graph input")
            return

        layernorm_before_attention = self.model.get_parent(
            reshape_before_gemm, 0, output_name_to_node)
        if layernorm_before_attention is None or layernorm_before_attention.op_type != 'LayerNormalization':
            logger.debug(
                f"failed to get layernorm before gemm. Got {layernorm_before_attention.op_type if layernorm_before_attention else None}"
            )
            return

        if another_input not in layernorm_before_attention.input:
            logger.debug(
                "Add and LayerNormalization shall have one same input")
            return

        is_unidirectional = True
        slice_mask = None
        input_mask_nodes = None
        qk_nodes = self.model.match_parent_path(
            matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'],
            [0, 0, 0, 0, 0])
        if qk_nodes is not None:
            (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
            mask_nodes = self.model.match_parent_path(
                sub_qk,
                ['Mul', 'Sub', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
                [1,      0,     1,       0,       1,           0,     0,         0,       0,       0])  # yapf: disable
            if mask_nodes is None:
                logger.debug(
                    "fuse_attention: failed to match unidirectional mask path")
                return
            div_mask = mask_nodes[-1]
            slice_mask = mask_nodes[3]

            if div_qk != div_mask:
                logger.debug("fuse_attention: skip since div_qk != div_mask")
                return
        else:
            # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
            i, qk_nodes, _ = self.model.match_parent_paths(
                matmul_qkv,
                [(['Softmax', 'Where', 'Div', 'MatMul'], [0, 0, 1, 0]),
                 (['Softmax', 'Add', 'Where', 'Div', 'MatMul'], [0, 0, 0, 1, 0])],
                output_name_to_node)
            if qk_nodes is None:
                logger.debug("fuse_attention: failed to match qk nodes")
                return

            where_qk = qk_nodes[-3]
            div_qk = qk_nodes[-2]
            matmul_qk = qk_nodes[-1]

            if i == 1:
                add_qk = qk_nodes[1]
                _, input_mask_nodes, _ = self.model.match_parent_paths(
                    add_qk,
                    [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [1, 0, 1, 0, 0, 0]),
                     (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [1, 0, 1, 0, 0])],
                    output_name_to_node)
                if input_mask_nodes is None:
                    logger.debug(
                        "fuse_attention: failed to match input attention mask path"
                    )
                    return

            mask_nodes = self.model.match_parent_path(
                where_qk,
                ['Cast', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
                [ 0,     0,       0,       1,           0,     0,         0,       0,       0])  # yapf: disable
            if mask_nodes is None:
                logger.debug("fuse_attention: failed to match mask path")
                return
            div_mask = mask_nodes[-1]
            slice_mask = mask_nodes[2]

            if div_qk != div_mask:
                logger.debug("fuse_attention: skip since div_qk != div_mask")
                return

        # Validate that the mask data is either lower triangular (unidirectional) or all ones
        mask_data = numpy_helper.to_array(
            self.model.get_initializer(slice_mask.input[0]))
        if not (len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1)
                and mask_data.shape[2] == mask_data.shape[3]):
            logger.debug(
                "fuse_attention: skip since mask shape is not 1x1xWxW")
            return
        if np.allclose(mask_data, np.ones_like(mask_data)):
            is_unidirectional = False
        elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
            logger.debug(
                "fuse_attention: skip since mask is neither lower triangular nor ones"
            )
            return

        q_nodes = self.model.match_parent_path(
            matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (transpose_q, reshape_q, split_q) = q_nodes
        if split_v != split_q:
            logger.debug("fuse_attention: skip since split_v != split_q")
            return

        k_nodes = self.model.match_parent_path(
            matmul_qk, ['Concat', 'Transpose', 'Reshape', 'Split'],
            [1, 1, 0, 0])
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return
        (concat_k, transpose_k, reshape_k, split_k) = k_nodes
        if split_v != split_k:
            logger.debug("fuse_attention: skip since split_v != split_k")
            return

        #     concat_k <-- Transpose (perm=0,1,3,2) <-- Gather(axes=0, indices=0) <-- past
        #        |
        #     Transpose (perm=0,1,3,2)
        #        |
        #      unsqueeze
        #        |
        #    concat  -->  present
        past_k_nodes = self.model.match_parent_path(concat_k,
                                                    ['Transpose', 'Gather'],
                                                    [0, 0])
        if past_k_nodes is None:
            logger.debug("fuse_attention: failed to match past_k_nodes path")
            return

        gather_past_k = past_k_nodes[-1]
        if self.model.find_constant_input(gather_past_k, 0) != 1:
            logger.info("expect indices=0 for Gather k of past")
            return
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.info("expect past to be same")
            return

        attention_mask_input_name = ''
        if input_mask_nodes is not None:
            input_name = input_mask_nodes[-1].input[0]
            if input_name in self.casted_attention_mask:
                attention_mask_input_name = self.casted_attention_mask[input_name]
            elif self.model.find_graph_input(input_name):
                casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name
            else:
                attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name

        self.create_attention_node(gemm, gemm_qkv, past, present,
                                   layernorm_before_attention.output[0],
                                   reshape_qkv.output[0],
                                   attention_mask_input_name,
                                   is_unidirectional)

        # we rely on prune_graph() to clean old subgraph nodes:
        # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
        self.prune_graph = True
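
The mask-data validation above can be illustrated with a standalone numpy sketch: the GPT-2 bias buffer is 1x1xWxW, and the fusion only needs to distinguish a lower-triangular buffer (causal, unidirectional attention) from an all-ones buffer (bidirectional).

import numpy as np

def classify_mask(mask_data):
    # Mirrors the checks in fuse(): require shape 1x1xWxW, then compare content.
    assert len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1) \
        and mask_data.shape[2] == mask_data.shape[3]
    if np.allclose(mask_data, np.ones_like(mask_data)):
        return 'bidirectional'
    if np.allclose(mask_data, np.tril(np.ones_like(mask_data))):  # np.tril acts on the last two axes
        return 'unidirectional'
    return 'unsupported'

print(classify_mask(np.tril(np.ones((1, 1, 8, 8)))))  # unidirectional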
Example #5
class FusionGptAttention(Fusion):
    """
    Fuse GPT-2 Attention with past state subgraph into one Attention node.
    This does not support attention_mask graph input right now.
    """
    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", "LayerNormalization", "with past")
        # TODO: detect num_heads from graph like FusionAttention
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {}  # map from name of attention mask to the name casted to int32

    def create_attention_node(self, fc_weight, fc_bias, gemm_qkv, past,
                              present, input, output, mask, is_unidirectional):
        attention_node_name = self.model.create_node_name('GptAttention')
        attention_node = helper.make_node(
            'Attention',
            inputs=[input, fc_weight, fc_bias, mask, past],
            outputs=[attention_node_name + "_output", present],
            name=attention_node_name)
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([
            helper.make_attribute("num_heads", self.num_heads),
            helper.make_attribute("unidirectional",
                                  1 if is_unidirectional else 0)
        ])

        matmul_node = helper.make_node(
            'MatMul',
            inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
            outputs=[attention_node_name + "_matmul_output"],
            name=attention_node_name + "_matmul")

        add_node = helper.make_node(
            'Add',
            inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
            outputs=[output],
            name=attention_node_name + "_add")
        self.nodes_to_add.extend([attention_node, matmul_node, add_node])
        self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
        self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
        self.node_name_to_graph_name[add_node.name] = self.this_graph_name

    def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node):
        # Pattern 1:
        #                      {past}
        #                    /        \
        #                   /          \
        #    Gather(axes=0, indices=0)  Gather(indices=1)
        #      |                          |
        #    Transpose (perm=0,1,3,2)     |
        #      |                          |
        #  Concat_k                     Concat_v
        #      |                        /
        #  Transpose (perm=0,1,3,2)    /
        #      |                      /
        #  Unsqueeze        Unsqueeze
        #        \        /
        #         \      /
        #           Concat
        #             |
        #         {present}
        gather = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather.op_type != 'Gather':
            logger.debug("match_past_pattern_1: expect Gather for past")
            return None

        if self.model.find_constant_input(gather, 1) != 1:
            logger.debug(
                "match_past_pattern_1: expect indices=1 for Gather of past")
            return None
        past = gather.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k,
                                                    ['Transpose', 'Gather'],
                                                    [0, 0])
        if past_k_nodes is None:
            logger.debug(
                "match_past_pattern_1: failed match Transpose and Gather")
            return None

        gather_past_k = past_k_nodes[-1]
        if self.model.find_constant_input(gather_past_k, 0) != 1:
            logger.debug(
                "match_past_pattern_1: expect indices=0 for Gather k of past")
            return None
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.debug("match_past_pattern_1: expect past to be same")
            return None

        return past

    def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node):
        # Pattern 2:
        #      Split (QKV)
        #      / |   |
        #     /  |   +----------------------+
        #        |                          |
        #        |         {past}           |
        #        |           |              |
        #      Reshape     Split         Reshape
        #        |         /    \           |
        # Transpose_k  Squeeze  Squeeze  Transpose_v
        #        |      |        \        /
        #        +------|---+     \      /
        #               |   |      \    /
        #              Concat_k   Concat_v
        #               |            |
        #          Unsqueeze    Unsqueeze
        #                \       /
        #                 Concat
        #                   |
        #               {present}
        #
        squeeze = self.model.get_parent(concat_v, 0, output_name_to_node)
        if squeeze.op_type != 'Squeeze':
            logger.debug(
                "match_past_pattern_2: expect Squeeze as parent of concat_v")
            return None

        split = self.model.get_parent(squeeze, 0, output_name_to_node)
        if split.op_type != "Split":
            logger.debug("match_past_pattern_2: expect Split for past path")
            return None

        # Opset 13 moved the "axes" of Squeeze and the "split" of Split from attributes to inputs.
        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(squeeze, 'axes', [0]):
                logger.debug(
                    "match_past_pattern_2: axes != [0] for Squeeze in past path"
                )
                return None

            if not FusionUtils.check_node_attribute(split, 'split', [1, 1]):
                logger.debug(
                    "match_past_pattern_2: split != [1, 1] for Split in past path"
                )
                return None
        else:
            if not self.utils.check_node_input_value(squeeze, 1, [0]):
                logger.debug(
                    "match_past_pattern_2: axes != [0] for Squeeze in past path"
                )
                return None

            if not self.utils.check_node_input_value(split, 1, [1, 1]):
                logger.debug(
                    "match_past_pattern_2: split != [1, 1] for Split in past path"
                )
                return None

        if not FusionUtils.check_node_attribute(
                split, 'axis', 0, default_value=0):
            logger.debug(
                "match_past_pattern_2: attribute axis of Split are not expected in past path"
            )
            return None
        past = split.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k,
                                                    ['Squeeze', 'Split'],
                                                    [0, 0])
        if past_k_nodes is None:
            logger.debug(
                "match_past_pattern_2: failed to match past_k_nodes path")
            return None
        past_k = past_k_nodes[-1].input[0]

        if past != past_k:
            logger.info("match_past_pattern_2: expect past to be same")
            return None

        return past

    def match_present(self, concat_v, input_name_to_nodes):
        unsqueeze_present_v = self.model.find_first_child_by_type(
            concat_v, 'Unsqueeze', input_name_to_nodes, recursive=False)
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return None
        concat_present = self.model.find_first_child_by_type(
            unsqueeze_present_v,
            'Concat',
            input_name_to_nodes,
            recursive=False)
        if not concat_present:
            logger.info("expect concat for present")
            return None

        present = concat_present.output[0]
        return present

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        past = None
        present = None
        return_indice = []
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ['Add', 'Reshape', 'Gemm', 'Reshape', 'Reshape', 'Transpose', 'MatMul'],
            [0,      None,      0,     0,          0,         0,           0],
            output_name_to_node=output_name_to_node,
            return_indice=return_indice
            ) # yapf: disable
        if qkv_nodes is None:
            return
        (add_qkv, reshape_qkv, gemm_qkv, reshape_1, reshape_2, transpose_qkv,
         matmul_qkv) = qkv_nodes

        another_input = add_qkv.input[1 - return_indice[0]]

        v_nodes = self.model.match_parent_path(
            matmul_qkv, ['Concat', 'Transpose', 'Reshape', 'Split'],
            [1, 1, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (concat_v, transpose_v, reshape_v, split_fc) = v_nodes

        fc_nodes = self.model.match_parent_path(
            split_fc, ['Reshape', 'Gemm', 'Reshape', 'LayerNormalization'],
            [0, 0, 0, 0], output_name_to_node)
        if fc_nodes is None:
            fc_nodes = self.model.match_parent_path(
                split_fc, ['Add', 'MatMul', 'LayerNormalization'],
                [0, None, 0], output_name_to_node)
            if fc_nodes is None:
                logger.debug("fuse_attention: failed to match fc path")
                return
            fc_weight = fc_nodes[1].input[1]
            i, _ = self.model.get_constant_input(fc_nodes[0])
            fc_bias = fc_nodes[0].input[i]
        else:
            fc_weight = fc_nodes[1].input[1]
            fc_bias = fc_nodes[1].input[2]

        layernorm_before_attention = fc_nodes[-1]

        if another_input not in layernorm_before_attention.input:
            logger.debug(
                "Add and LayerNormalization shall have one same input")
            return

        is_unidirectional = True
        slice_mask = None
        input_mask_nodes = None
        concat_k_to_match = None
        qk_nodes = self.model.match_parent_path(
            matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'],
            [0, 0, 0, 0, 0])
        if qk_nodes is not None:
            (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
            mask_nodes = self.model.match_parent_path(
                sub_qk,
                ['Mul', 'Sub', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
                [1,      0,     1,       0,       1,           0,     0,         0,       0,       0])  # yapf: disable
            if mask_nodes is None:
                logger.debug(
                    "fuse_attention: failed to match unidirectional mask path")
                return
            div_mask = mask_nodes[-1]
            slice_mask = mask_nodes[3]

            if div_qk != div_mask:
                logger.debug("fuse_attention: skip since div_qk != div_mask")
                return
        else:
            # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
            i, qk_nodes, _ = self.model.match_parent_paths(
                matmul_qkv,
                [(['Softmax', 'Where', 'Div', 'MatMul'], [0, 0, 1, 0]),
                 (['Softmax', 'Add', 'Where', 'Div', 'MatMul'], [0, 0, None, 1, 0])],
                output_name_to_node)
            if qk_nodes is None:
                logger.debug("fuse_attention: failed to match qk nodes")
                return

            where_qk = qk_nodes[-3]
            div_qk = qk_nodes[-2]
            matmul_qk = qk_nodes[-1]

            if i == 1:
                add_qk = qk_nodes[1]
                _, input_mask_nodes, _ = self.model.match_parent_paths(
                    add_qk,
                    [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0, 0]),
                     (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0])],
                    output_name_to_node)
                if input_mask_nodes is None:
                    logger.debug(
                        "fuse_attention: failed to match input attention mask path"
                    )
                    return

            mask_nodes = self.model.match_parent_path(
                where_qk,
                ['Cast', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape'],
                [ 0,     0,       0,       1,           0,     0,         0,       0],
                output_name_to_node)  # yapf: disable
            if mask_nodes is None:
                # TODO: match mask path for GPT2LMHeadModel_BeamSearchStep.
                logger.debug("fuse_attention: failed to match mask path")
                return

            slice_mask = mask_nodes[2]

            div_or_concat = self.model.get_parent(mask_nodes[-1], 0,
                                                  output_name_to_node)
            if div_or_concat.op_type == "Div":
                div_mask = div_or_concat
                if div_qk != div_mask:
                    logger.debug(
                        "fuse_attention: skip since div_qk != div_mask")
                    return
            elif div_or_concat.op_type == "Concat":
                concat_k_to_match = div_or_concat
            else:
                logger.debug("fuse_attention: failed to match mask path")
                return

        # Validate that the mask data is either lower triangular (unidirectional) or all ones
        mask_data = numpy_helper.to_array(
            self.model.get_initializer(slice_mask.input[0]))
        if not (len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1)
                and mask_data.shape[2] == mask_data.shape[3]):
            logger.debug(
                "fuse_attention: skip since mask shape is not 1x1xWxW")
            return
        if np.allclose(mask_data, np.ones_like(mask_data)):
            is_unidirectional = False
        elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
            logger.debug(
                "fuse_attention: skip since mask is neither lower triangular nor ones"
            )
            return

        q_nodes = self.model.match_parent_path(
            matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (transpose_q, reshape_q, split_q) = q_nodes
        if split_fc != split_q:
            logger.debug("fuse_attention: skip since split_fc != split_q")
            return

        k_nodes = self.model.match_parent_path(
            matmul_qk, ['Concat', 'Transpose', 'Reshape', 'Split'],
            [1, 1, 0, 0])
        if k_nodes is None:
            # This pattern is from pytorch 1.7.1 and transformers 4.6.1
            k_nodes = self.model.match_parent_path(
                matmul_qk,
                ['Transpose', 'Concat', 'Transpose', 'Reshape', 'Split'],
                [1, 0, 1, 0, 0])
            if k_nodes is None:
                logger.debug("fuse_attention: failed to match k path")
                return
            else:
                (_, concat_k, transpose_k, reshape_k, split_k) = k_nodes
        else:
            (concat_k, transpose_k, reshape_k, split_k) = k_nodes
        if split_fc != split_k:
            logger.debug("fuse_attention: skip since split_fc != split_k")
            return

        if concat_k_to_match and concat_k != concat_k_to_match:
            logger.debug(
                "fuse_attention: skip since concat_k != concat_k_to_match")
            return

        attention_mask_input_name = ''
        if input_mask_nodes is not None:
            input_name = input_mask_nodes[-1].input[0]
            if input_name in self.casted_attention_mask:
                attention_mask_input_name = self.casted_attention_mask[input_name]
            elif self.model.find_graph_input(input_name):
                casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name
            else:
                attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name

        # Match past and present paths
        past = self.match_past_pattern_1(concat_k, concat_v, output_name_to_node) or \
               self.match_past_pattern_2(concat_k, concat_v, output_name_to_node)
        if past is None:
            logger.info("fuse_attention: failed to match past path")
            return
        if not self.model.find_graph_input(past):
            logger.debug("past is not graph input.")
            # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select beam index so it is not graph input.

        present = self.match_present(concat_v, input_name_to_nodes)
        if present is None:
            logger.info("fuse_attention: failed to match present path")
            return
        if not self.model.find_graph_output(present):
            logger.info("expect present to be graph output")
            return

        self.create_attention_node(fc_weight, fc_bias, gemm_qkv, past, present,
                                   layernorm_before_attention.output[0],
                                   reshape_qkv.output[0],
                                   attention_mask_input_name,
                                   is_unidirectional)

        # we rely on prune_graph() to clean old subgraph nodes:
        # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
        self.prune_graph = True
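
Both past patterns recover the same tensor layout, sketched below in numpy (shapes are illustrative, and the extra Transposes that pattern 1 inserts around K are ignored): past stacks K and V along axis 0, the new K/V are concatenated along the sequence axis, and present is the re-stacked result.

import numpy as np

batch, heads, past_len, new_len, head_size = 1, 12, 4, 2, 64
past = np.zeros((2, batch, heads, past_len, head_size))   # {past}
k_new = np.ones((batch, heads, new_len, head_size))
v_new = np.ones((batch, heads, new_len, head_size))

k = np.concatenate([past[0], k_new], axis=2)              # Gather(indices=0) -> Concat_k
v = np.concatenate([past[1], v_new], axis=2)              # Gather(indices=1) -> Concat_v
present = np.stack([k, v], axis=0)                        # Unsqueeze + Concat -> {present}
assert present.shape == (2, batch, heads, past_len + new_len, head_size)
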
class FusionEmbedLayerNoMask(Fusion):
    """
     Embed Layer Normalization will fuse embeddings and mask processing into one node.
     The embeddings before conversion:

     (input_ids) -------->  Gather ----------+       (segment_ids)
        |                                    |            |
        |                                    v            v
        +--> Shape --> Expand -> Gather---->Add         Gather
        |                ^                   |            |
        |                |                   v            v
        +---(optional graph)               SkipLayerNormalization

      The optional graph is used to generate the position list (0, 1, ...) per batch. It can be a constant in some models.

      (input_ids) --> Gather -----+           Slice
                                  |            |
                                  v            v
     (segment_ids)--> Gather --->Add        Reshape
                                  |            |
                                  v            v
                              SkipLayerNormalization
    """
    def __init__(self, model: OnnxModel, description='no mask'):
        super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
        self.utils = FusionUtils(model)
        self.attention = None

    def match_segment_path(self, normalize_node, input_name_to_nodes, output_name_to_node, input_ids_cast_node):
        segment_ids = None
        segment_embedding_gather = None

        segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])

        if segment_embedding_path is None:
            segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
            if segment_embedding_path is None:
                logger.info("Segment embedding is not found. Embed layer cannot be fused.")
                return
            _, segment_embedding_gather = segment_embedding_path
        else:
            segment_embedding_gather = segment_embedding_path[0]

        segment_ids = segment_embedding_gather.input[1]

        self.nodes_to_remove.extend(segment_embedding_path)

        if self.model.find_graph_input(segment_ids):
            casted, segment_ids = self.utils.cast_graph_input_to_int32(segment_ids)
        else:
            segment_ids, segment_ids_cast_node = self.utils.cast_input_to_int32(segment_ids)

            # Cast might be removed by OnnxRuntime.
            _, segment_id_path, _ = self.model.match_parent_paths(
                segment_ids_cast_node,
                [(['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape', 'Cast'], [0, 0, 1, 0, 0, 0]),
                 (['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [0, 0, 1, 0, 0])], output_name_to_node)

            if segment_id_path and input_ids_cast_node and input_ids_cast_node.input[0] == segment_id_path[-1].input[0]:
                logger.debug("Simplify segment id path...")
                self.model.add_node(
                    helper.make_node('Shape', inputs=[input_ids_cast_node.input[0]], outputs=["input_shape"]))
                self.model.add_node(
                    helper.make_node('ConstantOfShape',
                                     inputs=["input_shape"],
                                     outputs=["zeros_for_input_shape"],
                                     value=helper.make_tensor("value", onnx.TensorProto.INT32, [1], [1])))
                segment_ids = "zeros_for_input_shape"

        return segment_ids, segment_embedding_gather

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        is_distill = False

        if self.model.match_parent_path(node, ['Add', 'Gather'], [0, 0]) is None and self.model.match_parent_path(node, ['Gather'], [0]) is None:
            logger.debug("Failed to match path SkipLayerNormalization[0] <-- Add <-- Gather or SkipLayerNormalization[0] <-- Gather")
            return

        self.attention = self.model.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False)
        if self.attention is None:
            # In case user disables attention fusion, check whether subgraph looks like Attention.
            if node.output[0] not in input_name_to_nodes:
                return
            children = input_name_to_nodes[node.output[0]]
            children_types = sorted([child.op_type for child in children])
            if children_types != ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization'] and \
               children_types != ['MatMul', 'MatMul', 'MatMul', 'Shape', 'Shape', 'SkipLayerNormalization']:
                logger.debug("No Attention-like subgraph in children of SkipLayerNormalization")
                return

        # Assume the order of embeddings are word_embedding + position_embedding + segment_embedding
        normalize_node = node
        add_node = None
        word_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 0])
        if word_embedding_path is not None:
            add_node, word_embedding_gather = word_embedding_path
        else:
            word_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [0])
            if word_embedding_path is not None:
                word_embedding_gather = word_embedding_path[0]
                is_distill = True
            else:
                logger.info("Word embedding path is not found. Embed layer cannot be fused.")
                return

        input_ids = word_embedding_gather.input[1]

        position_embedding_expand = None
        position_embedding_shape = None

        position_embedding_path = self.model.match_parent_path(normalize_node, ['Gather', 'Expand'], [1, 1])  # for distilbert
        if position_embedding_path is not None:
            position_embedding_weight_node, position_embedding_expand = position_embedding_path
        else:
            position_embedding_path = self.model.match_parent_path(normalize_node, ['Reshape', 'Slice'], [1, 0])
            if position_embedding_path is not None:
                _, position_embedding_weight_node = position_embedding_path
            else:
                position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand', 'Shape'], [1, 1, 1])
                if position_embedding_path is not None:
                    position_embedding_weight_node, position_embedding_expand, position_embedding_shape = position_embedding_path
                else:
                    position_embedding_path = self.model.match_parent_path(
                        add_node, ['Gather', 'Expand', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [1, 1, 1, 1, 0, 0])
                    if position_embedding_path is not None:
                        position_embedding_weight_node, position_embedding_expand, _, _, _, position_embedding_shape = position_embedding_path
                    else:
                        # Here we will not try to get an exact match. Instead, we only try to identify the position embedding weights.
                        position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand'], [1, 1])
                        if position_embedding_path is not None:
                            position_embedding_weight_node, position_embedding_expand = position_embedding_path
                        else:
                            logger.info("Position embedding path is not found. Embed layer cannot be fused.")
                            return

                if position_embedding_shape is not None and position_embedding_shape.input[0] != input_ids:
                    logger.info("position and word embedding is expected to be applied on same input")
                    return

        if position_embedding_expand and position_embedding_shape:
            input_parent = self.model.get_parent(position_embedding_shape, 0, output_name_to_node)
            subgraph_nodes = self.model.get_parent_subgraph_nodes(position_embedding_expand,
                                                                  [input_parent] if input_parent else [],
                                                                  output_name_to_node)
            self.nodes_to_remove.extend(subgraph_nodes)

        self.nodes_to_remove.extend(word_embedding_path)
        self.nodes_to_remove.extend(position_embedding_path)

        self.nodes_to_remove.extend([normalize_node])

        # Cast input_ids and segment_ids to int32.
        input_ids_cast_node = None
        if self.model.find_graph_input(input_ids):
            casted, input_ids = self.utils.cast_graph_input_to_int32(input_ids)
        else:
            input_ids, input_ids_cast_node = self.utils.cast_input_to_int32(input_ids)

        node_name = self.model.create_node_name('EmbedLayerNormalization')
        output_name = node_name + "_output"

        embed_node_inputs = None
        if not is_distill:
            segment_path = self.match_segment_path(normalize_node, input_name_to_nodes, output_name_to_node,
                                                   input_ids_cast_node)
            if segment_path is None:
                return
            segment_ids, segment_embedding_gather = segment_path
            embed_node_inputs = [
                input_ids,
                segment_ids,
                word_embedding_gather.input[0],
                position_embedding_weight_node.input[0],
                segment_embedding_gather.input[0],
                normalize_node.input[2],
                normalize_node.input[3]  # gamma and beta
            ]
        else:
            # EmbedLayerNormalization with empty (optional) segment inputs needs onnxruntime > 1.4.0.
            from packaging.version import Version
            import onnxruntime
            if Version(onnxruntime.__version__) <= Version("1.4.0"):
                logger.warning('Please install onnxruntime with version > 1.4.0 for embedlayer fusion support for distilbert')
                return
            embed_node_inputs = [
                input_ids,
                '',
                word_embedding_gather.input[0],
                position_embedding_weight_node.input[0],
                '',
                normalize_node.input[2],
                normalize_node.input[3]  # gamma and beta
            ]

        embed_node = helper.make_node(
            'EmbedLayerNormalization',
            embed_node_inputs,
            outputs=[node_name + "_output", node_name + "_dummy_mask_index"],
            name=node_name)

        embed_node.domain = "com.microsoft"

        # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
        for att in normalize_node.attribute:
            if att.name == 'epsilon':
                embed_node.attribute.extend([att])
        # Set default value to 1e-12 if no attribute is found.
        # OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later.
        if len(embed_node.attribute) == 0:
            embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0E-12)])

        self.model.replace_input_of_all_nodes(normalize_node.output[0], output_name)
        self.nodes_to_add.append(embed_node)
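
For the distilbert branch, a standalone sketch (onnx.helper only; all names are illustrative) of the fused node it emits: ONNX marks an optional input as absent with an empty string, which is how the missing segment ids and segment embedding table are encoded.

from onnx import helper

embed_node = helper.make_node(
    'EmbedLayerNormalization',
    inputs=['input_ids_int32', '', 'word_embedding', 'position_embedding', '',
            'layernorm_gamma', 'layernorm_beta'],  # '' = optional segment inputs omitted
    outputs=['embed_output', 'dummy_mask_index'],
    name='EmbedLayerNormalization_example',
    domain='com.microsoft',
    epsilon=1.0e-12)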