Example #1
    def __init__(self, model: OnnxModel):
        self.model = model
        # A lookup table with mask input as key, and mask index output as value
        self.mask_indice = {}
        # A lookup table with mask input as key, and cast (to int32) output as value
        self.mask_casted = {}
        self.utils = FusionUtils(model)
Example #2
    def change_input_to_int32(self):
        original_opset_version = self.model.opset_import[0].version
        graph = self.graph()

        new_graph_inputs = []
        casted_bert_graph_inputs = self.get_graph_inputs_from_embed_nodes(casted=True)
        utils = FusionUtils(self)

        for input in graph.input:
            if input.name in casted_bert_graph_inputs:
                utils.remove_cast_int32(input.name)
                int32_input = helper.make_tensor_value_info(input.name, TensorProto.INT32,
                                                            self.tensor_shape_to_list(input.type.tensor_type))
                new_graph_inputs.append(int32_input)
            else:
                new_graph_inputs.append(input)

        graph_def = helper.make_graph(graph.node,
                                      'int32 inputs',
                                      new_graph_inputs,
                                      graph.output,
                                      initializer=graph.initializer,
                                      value_info=graph.value_info)

        self.model = helper.make_model(graph_def, producer_name='onnxruntime-tools')

        # restore opset version
        self.model.opset_import[0].version = original_opset_version
Example #3
    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", "LayerNormalization", "with past")
        # TODO: detect num_heads from graph like FusionAttention
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        # map from name of attention mask to the name casted to int32
        self.casted_attention_mask = {}
Example #4
class AttentionMask():
    """
    Fuse Attention subgraph into one Attention node.
    """
    def __init__(self, model: OnnxModel):
        self.model = model
        # A lookup table with mask input as key, and mask index output as value
        self.mask_indice = {}
        # A lookup table with mask input as key, and cast (to int32) output as value
        self.mask_casted = {}
        self.utils = FusionUtils(model)
        self.mask_format = AttentionMaskFormat.MaskIndexEnd

    def set_mask_format(self, mask_format: AttentionMaskFormat):
        self.mask_format = mask_format

    def set_mask_indice(self, mask, mask_index):
        if mask in self.mask_indice:
            assert mask_index == self.mask_indice[mask]
        self.mask_indice[mask] = mask_index

    def get_first_mask(self):
        assert len(self.mask_indice) > 0
        return next(iter(self.mask_indice))

    def process_mask(self, input):
        if input in self.mask_indice:
            return self.mask_indice[input]

        # Add cast to convert int64 to int32
        if self.model.find_graph_input(input):
            casted, input_name = self.utils.cast_graph_input_to_int32(input)
        else:
            input_name, cast_node = self.utils.cast_input_to_int32(input)
            casted = True

        if casted:
            self.mask_casted[input] = input_name

        # Attention supports int32 attention mask (2D) since 1.4.0
        if self.mask_format == AttentionMaskFormat.AttentionMask:
            self.mask_indice[input] = input_name
            return input_name

        # Add a mask processing node to convert attention mask to mask index (1D)
        output_name = self.model.create_node_name('mask_index')
        mask_index_node = helper.make_node('ReduceSum',
                                           inputs=[input_name],
                                           outputs=[output_name],
                                           name=self.model.create_node_name(
                                               'ReduceSum', 'MaskReduceSum'))
        mask_index_node.attribute.extend([
            helper.make_attribute("axes", [1]),
            helper.make_attribute("keepdims", 0)
        ])
        self.model.add_node(mask_index_node)

        self.mask_indice[input] = output_name
        return output_name
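
For context, a minimal usage sketch (not from the source; it assumes the AttentionMask class and AttentionMaskFormat enum above are in scope, and the model path and input name are hypothetical):

import onnx
from onnxruntime.transformers.onnx_model import OnnxModel

model = OnnxModel(onnx.load("bert.onnx"))
mask = AttentionMask(model)
mask.set_mask_format(AttentionMaskFormat.AttentionMask)  # keep the 2D int32 mask

# Repeated calls with the same graph input hit the mask_indice cache, so only
# one Cast (and, for other formats, one ReduceSum) is added per distinct mask.
name_1 = mask.process_mask("attention_mask")
name_2 = mask.process_mask("attention_mask")
assert name_1 == name_2
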
Example #5
    def fuse(self, concat_node: NodeProto, input_name_to_nodes: Dict[str, List[NodeProto]],
             output_name_to_node: Dict[str, NodeProto]):
        """
        Simplify a subgraph like:

                   (2d_input)
                    /       \
                Shape       Shape
                /             \
            Gather(indices=0)  Gather(indices=1)
                |                |
            Unsqueeze(axes=0)   Unsqueeze(axes=0)
                   \          /
                      Concat 
                        |

        into  (2d_input) --> Shape -->
        """
        opset_version = self.model.get_opset_version()

        inputs = len(concat_node.input)
        root = None
        shape_output = None
        for i in range(inputs):
            path = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [i, 0, 0],
                                                output_name_to_node)
            if path is None:
                return

            unsqueeze, gather, shape = path
            if i == 0:
                shape_output = shape.output[0]
            if root is None:
                root = shape.input[0]
                if self.get_dimensions(root) != inputs:
                    return
            elif shape.input[0] != root:
                return

            if not FusionUtils.check_node_attribute(unsqueeze, 'axis', 0, default_value=0):
                return

            if opset_version < 13:
                if not FusionUtils.check_node_attribute(unsqueeze, 'axes', [0]):
                    return
            else:
                if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
                    return

            value = self.model.get_constant_value(gather.input[1])
            from numpy import ndarray
            if not (isinstance(value, ndarray) and value.size == 1 and value.item() == i):
                return

        if self.model.find_graph_output(concat_node.output[0]) is None:
            self.model.replace_input_of_all_nodes(concat_node.output[0], shape_output)
            self.fused_count += 1
            self.prune_graph = True
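
A quick numpy check (illustrative only, not part of the source) of why this rewrite is safe: for a 2-D input, concatenating the unsqueezed Gather(i) outputs of Shape reproduces the Shape output itself.

import numpy as np

x = np.zeros((3, 5))
shape = np.array(x.shape, dtype=np.int64)            # Shape
dims = [np.expand_dims(shape[i], axis=0)             # Gather(indices=i) + Unsqueeze(axes=[0])
        for i in range(x.ndim)]
assert np.array_equal(np.concatenate(dims), shape)   # Concat output == Shape output
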
Example #6
    def __init__(self, model: OnnxModel, description: str = 'no mask'):
        super().__init__(model, "EmbedLayerNormalization",
                         ["LayerNormalization", "SkipLayerNormalization"],
                         description)
        self.utils = FusionUtils(model)
        self.shape_infer_helper = self.model.infer_runtime_shape({}, update=True)
        # The following will be reset in each fuse call of FusionEmbedLayerNormalization
        self.attention = None
        self.embed_node = None
Example #7
    def __init__(self, model, num_heads, hidden_size):
        assert num_heads > 0
        assert hidden_size % num_heads == 0

        super().__init__(model)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
        self.utils = FusionUtils(self)
Example #8
    def optimize(self, options: FusionOptions = None, add_dynamic_axes=False):
        if (options is None) or options.enable_layer_norm:
            self.fuse_layer_norm()

        if (options is None) or options.enable_gelu:
            self.fuse_gelu()

        self.preprocess()

        self.fuse_reshape()

        if (options is None) or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()

        if (options is None) or options.enable_attention:
            if options is not None:
                self.attention_mask.set_mask_format(options.attention_mask_format)
            self.fuse_attention()

        if (options is None) or options.enable_embed_layer_norm:
            self.fuse_embed_layer()

        # Remove Reshape nodes whose input and output shapes are identical, based on symbolic shape inference.
        FusionUtils.remove_useless_reshape_nodes(self)

        self.postprocess()

        # Bias fusion is done after postprocess to avoid extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization
        if (options is None) or options.enable_bias_gelu:
            # Fuse Gelu and Add Bias before it.
            self.fuse_bias_gelu(is_fastgelu=True)
            self.fuse_bias_gelu(is_fastgelu=False)

        if (options is None) or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization and Add Bias before it.
            self.fuse_add_bias_skip_layer_norm()

        if (options is not None and options.enable_gelu_approximation):
            self.gelu_approximation()

        self.remove_unused_constant()

        # Use symbolic batch dimension in input and output.
        if add_dynamic_axes:
            self.use_dynamic_axes()

        logger.info(f"opset verion: {self.model.opset_import[0].version}")
Example #9
    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize BERT ONNX Model.
           
        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0, and we will detect the parameter automatically.
            hidden_size (int, optional): hidden dimension. Defaults to 0, and we will detect the parameter automatically.
        """
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)

        super().__init__(model)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
        self.utils = FusionUtils(self)
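
The leading assert admits exactly two configurations, illustrated below (hedged sketch; the model file is hypothetical):

import onnx

model = onnx.load("bert.onnx")
BertOnnxModel(model)                                  # num_heads=0, hidden_size=0: auto-detect
BertOnnxModel(model, num_heads=12, hidden_size=768)   # explicit values; 768 % 12 == 0
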
Example #10
    def adjust_reshape_and_expand(self):
        # Remove Reshape nodes whose input and output shapes are identical, based on symbolic shape inference.
        FusionUtils.remove_useless_reshape_nodes(self)

        nodes_to_remove = []
        for node in self.nodes():
            if node.op_type == 'Reshape':
                # Clean up unnecessary Reshape nodes.
                # Find Reshape nodes whose "shape" input is an empty constant, and remove them.
                reshape_shape = self.get_constant_value(node.input[1])
                if reshape_shape is not None and reshape_shape.size == 0:
                    nodes_to_remove.extend([node])
                    self.replace_input_of_all_nodes(node.output[0],
                                                    node.input[0])
                    continue

                # Find path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", simplify the graph by
                # changing current reshape's input to output of slice.
                reshape_path = self.match_parent_path(
                    node, ['Expand', 'Expand', 'Reshape', 'Slice'],
                    [0, 0, 0, 0], self.output_name_to_node())
                if reshape_path is not None:
                    expand_node = reshape_path[-3]
                    expand_shape_value = self.get_constant_value(
                        expand_node.input[1])

                    reshape_before_expand = reshape_path[-2]
                    shape_value = self.get_constant_value(
                        reshape_before_expand.input[1])

                    slice_node = reshape_path[-1]
                    if (expand_shape_value is not None and shape_value is not None
                            and len(expand_shape_value) == 2 and len(shape_value) == 1
                            and expand_shape_value[1] == shape_value[0]):
                        node.input[0] = slice_node.output[0]

        if nodes_to_remove:
            self.remove_nodes(nodes_to_remove)
            logger.info(
                f"Removed Reshape and Expand count: {len(nodes_to_remove)}")
Example #11
    def change_input_to_int32(self):
        original_opset_version = self.model.opset_import[0].version
        graph = self.graph()

        batch_size, sequence_length = self.get_bert_input_shape()
        new_graph_inputs = []

        bert_inputs = self.get_bert_inputs()
        utils = FusionUtils(self)
        for input in graph.input:
            if input.name in bert_inputs:
                utils.remove_cast_int32(input.name)
                input_shape = [
                    batch_size if isinstance(batch_size, int) else 1,
                    sequence_length
                    if isinstance(sequence_length, int) else 128
                ]
                int32_input = helper.make_tensor_value_info(
                    input.name, TensorProto.INT32, input_shape)
                new_graph_inputs.append(int32_input)
            else:
                new_graph_inputs.append(input)

        graph_def = helper.make_graph(graph.node,
                                      'int32 inputs',
                                      new_graph_inputs,
                                      graph.output,
                                      initializer=graph.initializer,
                                      value_info=graph.value_info)

        self.model = helper.make_model(graph_def,
                                       producer_name='bert model optimizer')

        if isinstance(batch_size, str) or isinstance(sequence_length, str):
            self.use_dynamic_axes(
                batch_size if isinstance(batch_size, str) else None,
                sequence_length if isinstance(sequence_length, str) else None)

        # restore opset version
        self.model.opset_import[0].version = original_opset_version
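
After this conversion the graph consumes int32 ids directly; a hedged inference sketch (the file and input names are assumptions, not from the source):

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("bert_int32.onnx",
                                       providers=["CPUExecutionProvider"])
feeds = {
    "input_ids":      np.ones((1, 128), dtype=np.int32),
    "attention_mask": np.ones((1, 128), dtype=np.int32),
    "token_type_ids": np.zeros((1, 128), dtype=np.int32),
}
outputs = session.run(None, feeds)
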
Example #12
class BertOnnxModel(OnnxModel):
    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize BERT ONNX Model.
           
        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0, and we will detect the parameter automatically.
            hidden_size (int, optional): hidden dimension. Defaults to 0, and we will detect the parameter automatically.
        """
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)

        super().__init__(model)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
        self.utils = FusionUtils(self)

    def fuse_attention(self):
        self.attention_fusion.apply()

    def fuse_gelu(self):
        fusion = FusionGelu(self)
        fusion.apply()
        fusion = FusionFastGelu(self)
        fusion.apply()

    def fuse_bias_gelu(self, is_fastgelu):
        fusion = FusionBiasGelu(self, is_fastgelu)
        fusion.apply()

    def gelu_approximation(self):
        fusion = FusionGeluApproximation(self)
        fusion.apply()

    def fuse_add_bias_skip_layer_norm(self):
        fusion = FusionBiasSkipLayerNormalization(self)
        fusion.apply()

    def fuse_reshape(self):
        fusion = FusionReshape(self)
        fusion.apply()

    def fuse_embed_layer(self):
        fusion = FusionEmbedLayerNormalization(self)
        fusion.apply()

    def fuse_layer_norm(self):
        fusion = FusionLayerNormalization(self)
        fusion.apply()

        fusion = FusionLayerNormalizationTF(self)
        fusion.apply()

    def fuse_skip_layer_norm(self):
        fusion = FusionSkipLayerNormalization(self)
        fusion.apply()

    def get_graph_inputs_from_node_type(self, op_type: str, input_indices: List[int], casted: bool):
        """
        Get graph inputs that feed into node type (like EmbedLayerNormalization or Attention).
        Returns a list of graph input names, filtered by whether the input is casted or not.
        """
        graph_inputs = []

        output_name_to_node = self.output_name_to_node()
        nodes = self.get_nodes_by_op_type(op_type)
        for node in nodes:
            bert_inputs = [node.input[i] for i in input_indices if i < len(node.input)]
            for bert_input in bert_inputs:
                if self.find_graph_input(bert_input):
                    if not casted:
                        graph_inputs.append(bert_input)
                elif bert_input in output_name_to_node:
                    parent = output_name_to_node[bert_input]
                    if parent.op_type == 'Cast' and self.find_graph_input(parent.input[0]) is not None:
                        if casted:
                            graph_inputs.append(parent.input[0])
        return graph_inputs

    def get_graph_inputs_from_fused_nodes(self, casted: bool):
        inputs = self.get_graph_inputs_from_node_type('EmbedLayerNormalization', [0, 1, 7], casted)
        inputs += self.get_graph_inputs_from_node_type('Attention', [3], casted)
        return inputs

    def change_input_to_int32(self):
        original_opset_version = self.model.opset_import[0].version
        graph = self.graph()

        new_graph_inputs = []
        casted_bert_graph_inputs = self.get_graph_inputs_from_fused_nodes(casted=True)

        for input in graph.input:
            if input.name in casted_bert_graph_inputs:
                self.utils.remove_cast_int32(input.name)
                int32_input = helper.make_tensor_value_info(input.name, TensorProto.INT32,
                                                            self.tensor_shape_to_list(input.type.tensor_type))
                new_graph_inputs.append(int32_input)
            else:
                new_graph_inputs.append(input)

        graph_def = helper.make_graph(graph.node,
                                      'int32 inputs',
                                      new_graph_inputs,
                                      graph.output,
                                      initializer=graph.initializer,
                                      value_info=graph.value_info)

        self.model = helper.make_model(graph_def, producer_name='onnxruntime-tools')

        # restore opset version
        self.model.opset_import[0].version = original_opset_version

    def use_dynamic_axes(self, dynamic_batch_dim='batch_size', dynamic_seq_len='max_seq_len'):
        """
        Update input and output shape to use dynamic axes.
        """
        bert_graph_inputs = self.get_graph_inputs_from_fused_nodes(
            casted=True) + self.get_graph_inputs_from_fused_nodes(casted=False)

        dynamic_batch_inputs = {}
        for input in self.model.graph.input:
            if input.name in bert_graph_inputs:
                dim_proto = input.type.tensor_type.shape.dim[0]
                dim_proto.dim_param = dynamic_batch_dim
                if dynamic_seq_len is not None:
                    dim_proto = input.type.tensor_type.shape.dim[1]
                    dim_proto.dim_param = dynamic_seq_len

        for output in self.model.graph.output:
            dim_proto = output.type.tensor_type.shape.dim[0]
            dim_proto.dim_param = dynamic_batch_dim

    def preprocess(self):
        self.adjust_reshape_and_expand()
        return

    def adjust_reshape_and_expand(self):
        nodes_to_remove = []
        for node in self.nodes():
            if node.op_type == 'Reshape':
                # Clean up unnecessary Reshape nodes.
                # Find Reshape nodes whose "shape" input is an empty constant, and remove them.
                reshape_shape = self.get_constant_value(node.input[1])
                if reshape_shape is not None and reshape_shape.size == 0:
                    nodes_to_remove.extend([node])
                    self.replace_input_of_all_nodes(node.output[0], node.input[0])
                    continue

                # Find path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", simplify the graph by
                # changing current reshape's input to output of slice.
                reshape_path = self.match_parent_path(node, ['Expand', 'Expand', 'Reshape', 'Slice'], [0, 0, 0, 0],
                                                      self.output_name_to_node())
                if reshape_path is not None:
                    expand_node = reshape_path[-3]
                    expand_shape_value = self.get_constant_value(expand_node.input[1])

                    reshape_before_expand = reshape_path[-2]
                    shape_value = self.get_constant_value(reshape_before_expand.input[1])

                    slice_node = reshape_path[-1]
                    if (expand_shape_value is not None and shape_value is not None
                            and len(expand_shape_value) == 2 and len(shape_value) == 1
                            and expand_shape_value[1] == shape_value[0]):
                        node.input[0] = slice_node.output[0]
        self.remove_nodes(nodes_to_remove)
        logger.info(f"Removed Reshape and Expand count: {len(nodes_to_remove)}")

    def clean_graph(self):
        output_name_to_node = self.output_name_to_node()
        nodes_to_add = []
        nodes_to_remove = []
        for node in self.nodes():
            # Before:
            #  input_ids --> Shape --> Gather(indices=0) --> Unsqueeze ------+
            #          |                                                     |
            #          |                                                     v
            #          +----> Shape --> Gather(indices=1) --> Unsqueeze--->  Concat --> ConstantOfShape -->Cast --> EmbedLayerNormalization/ReduceSum
            # After:
            #  input_ids --> Shape                                                  --> ConstantOfShape -->Cast --> EmbedLayerNormalization/ReduceSum
            # TODO: merge ConstantOfShape -->Cast to ConstantOfShape (need update the data type of value)
            op_input_id = {"EmbedLayerNormalization": 1, "ReduceSum": 0, "Attention": 3}
            if node.op_type in op_input_id:
                i = op_input_id[node.op_type]
                parent_nodes = self.match_parent_path(
                    node, ['Cast', 'ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [i, 0, 0, 0, 0, 0],
                    output_name_to_node)
                if parent_nodes is not None:
                    cast, constantOfShape, concat, unsqueeze, gather, shape = parent_nodes
                    if shape.input[0] == self.graph().input[0].name:
                        constantOfShape.input[0] = shape.output[0]
                        output_name_to_node = self.output_name_to_node()

            if node.op_type == 'Attention':
                # Before:
                #   input_ids --> Shape -->ConstantOfShape -->Cast --> ReduceSum --> Attention
                # After:
                #   remove this path, and remove the optional mask_index input of Attention node.
                parent_nodes = self.match_parent_path(node, ['ReduceSum', 'Cast', 'ConstantOfShape', 'Shape'],
                                                      [3, 0, 0, 0], output_name_to_node)
                if parent_nodes is not None:
                    if parent_nodes[-1].input[0] == self.graph().input[0].name:
                        attention_node = helper.make_node('Attention',
                                                          inputs=node.input[0:len(node.input) - 1],
                                                          outputs=node.output,
                                                          name=node.name + "_remove_mask")
                        attention_node.domain = "com.microsoft"
                        attention_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
                        nodes_to_add.append(attention_node)
                        nodes_to_remove.append(node)
        self.remove_nodes(nodes_to_remove)
        self.add_nodes(nodes_to_add)

    def postprocess(self):
        self.clean_graph()
        self.prune_graph()

    def optimize(self, options: BertOptimizationOptions = None, add_dynamic_axes=False):
        if (options is None) or options.enable_layer_norm:
            self.fuse_layer_norm()

        if (options is None) or options.enable_gelu:
            self.fuse_gelu()

        self.preprocess()

        self.fuse_reshape()

        if (options is None) or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()

        if (options is None) or options.enable_attention:
            if options is not None:
                self.attention_mask.set_mask_format(options.attention_mask_format)
            self.fuse_attention()

        if (options is None) or options.enable_embed_layer_norm:
            self.fuse_embed_layer()

        # Post-processing like removing extra reshape nodes.
        self.postprocess()

        # Bias fusion is done after postprocess to avoid extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization
        if (options is None) or options.enable_bias_gelu:
            # Fuse Gelu and Add Bias before it.
            self.fuse_bias_gelu(is_fastgelu=True)
            self.fuse_bias_gelu(is_fastgelu=False)

        if (options is None) or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization and Add Bias before it.
            self.fuse_add_bias_skip_layer_norm()

        if (options is not None and options.enable_gelu_approximation):
            self.gelu_approximation()

        self.remove_unused_constant()

        # Use symbolic batch dimension in input and output.
        if add_dynamic_axes:
            self.use_dynamic_axes()

        logger.info(f"opset verion: {self.model.opset_import[0].version}")

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        op_count = {}
        ops = [
            'EmbedLayerNormalization', 'Attention', 'Gelu', 'FastGelu', 'BiasGelu', 'LayerNormalization',
            'SkipLayerNormalization'
        ]
        for op in ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)
        logger.info(f"Optimized operators:{op_count}")
        return op_count

    def is_fully_optimized(self):
        """
        Returns True when the model is fully optimized.
        """
        op_count = self.get_fused_operator_statistics()
        embed = op_count['EmbedLayerNormalization']
        attention = op_count['Attention']
        gelu = op_count['Gelu'] + op_count['BiasGelu'] + op_count['FastGelu']
        layer_norm = op_count['LayerNormalization'] + op_count['SkipLayerNormalization']
        is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention)

        if layer_norm == 0:
            logger.debug("Layer Normalization not fused")

        if gelu == 0:
            logger.debug("Gelu/FastGelu not fused")

        if embed == 0:
            logger.debug("Embed Layer not fused")

        if attention == 0:
            logger.warning("Attention not fused")

        return is_perfect
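
A short sketch of checking the fusion outcome with the two methods above (hedged; the optimized file name is a placeholder and BertOnnxModel is the class defined in this example):

import onnx

bert_model = BertOnnxModel(onnx.load("bert_opt.onnx"))
op_count = bert_model.get_fused_operator_statistics()
# is_fully_optimized expects one Gelu/FastGelu/BiasGelu per Attention and at
# least two (Skip)LayerNormalization nodes per attention layer.
print(op_count, bert_model.is_fully_optimized())
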
Example #13
    def __init__(self,
                 model: OnnxModel,
                 description='no mask'):
        super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
        self.utils = FusionUtils(model)
Example #14
class FusionGptAttentionPastBase(Fusion):
    """Base class for GPT Attention Fusion with past state
    """
    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", "LayerNormalization", "with past")
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        # map from name of attention mask to the name casted to int32
        self.casted_attention_mask = {}

    def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node):
        # Pattern 1:
        #                      {past}
        #                    /        \
        #                   /          \
        #    Gather(axes=0, indices=0)  Gather(indices=1)
        #      |                          |
        #    Transpose (perm=0,1,3,2)     |
        #      |                          |
        #  Concat_k                     Concat_v
        #      |                        /
        #  Transpose (perm=0,1,3,2)    /
        #      |                      /
        #  Unsqueeze        Unsqueeze
        #        \        /
        #         \      /
        #           Concat
        #             |
        #         {present}
        gather = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather.op_type != 'Gather':
            logger.debug("match_past_pattern_1: expect Gather for past")
            return None

        if not self.model.find_constant_input(gather, 1) == 1:
            logger.debug(
                "match_past_pattern_1: expect indices=1 for Gather of past")
            return None
        past = gather.input[0]

        parent = self.model.get_parent(concat_k, 0, output_name_to_node)
        if parent.op_type == 'Gather':
            gather_past_k = parent
        else:
            past_k_nodes = self.model.match_parent_path(
                concat_k, ['Transpose', 'Gather'], [0, 0])
            if past_k_nodes is None:
                logger.debug(
                    "match_past_pattern_1: failed match Transpose and Gather")
                return None
            gather_past_k = past_k_nodes[-1]

        if not self.model.find_constant_input(gather_past_k, 0) == 1:
            logger.debug(
                "match_past_pattern_1: expect indices=0 for Gather k of past")
            return None
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.debug("match_past_pattern_1: expect past to be same")
            return None

        return past

    def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node):
        # Pattern 2:
        #      Split (QKV)
        #      / |   |
        #     /  |   +----------------------+
        #        |                          |
        #        |         {past}           |
        #        |           |              |
        #      Reshape     Split         Reshape
        #        |         /    \           |
        # Transpose_k  Squeeze  Squeeze  Transpose_v
        #        |      |        \        /
        #        +------|---+     \      /
        #               |   |      \    /
        #              Concat_k   Concat_v
        #               |            |
        #          Unsqueeze    Unsqueeze
        #                \       /
        #                 Concat
        #                   |
        #               {present}
        #
        squeeze = self.model.get_parent(concat_v, 0, output_name_to_node)
        if squeeze.op_type != 'Squeeze':
            logger.debug(
                "match_past_pattern_2: expect Squeeze as parent of concat_v")
            return None

        split = self.model.get_parent(squeeze, 0, output_name_to_node)
        if split.op_type != "Split":
            logger.debug("match_past_pattern_2: expect Split for past path")
            return None

        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(squeeze, 'axes', [0]):
                logger.debug(
                    "match_past_pattern_2: axes != [0] for Squeeze in past path"
                )
                return None

            if not FusionUtils.check_node_attribute(split, 'split', [1, 1]):
                logger.debug(
                    "match_past_pattern_2: split != [1, 1] for Split in past path"
                )
                return None
        else:
            if not self.utils.check_node_input_value(squeeze, 1, [0]):
                logger.debug(
                    "match_past_pattern_2: axes != [0] for Squeeze in past path"
                )
                return None

            if not self.utils.check_node_input_value(split, 1, [1, 1]):
                logger.debug(
                    "match_past_pattern_2: split != [1, 1] for Split in past path"
                )
                return None

        if not FusionUtils.check_node_attribute(
                split, 'axis', 0, default_value=0):
            logger.debug(
                "match_past_pattern_2: attribute axis of Split are not expected in past path"
            )
            return None
        past = split.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k,
                                                    ['Squeeze', 'Split'],
                                                    [0, 0])
        if past_k_nodes is None:
            logger.debug(
                "match_past_pattern_2: failed to match past_k_nodes path")
            return None
        past_k = past_k_nodes[-1].input[0]

        if past != past_k:
            logger.info("match_past_pattern_2: expect past to be same")
            return None

        return past

    def match_present(self, concat_v, input_name_to_nodes):
        unsqueeze_present_v = self.model.find_first_child_by_type(
            concat_v, 'Unsqueeze', input_name_to_nodes, recursive=False)
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return None
        concat_present = self.model.find_first_child_by_type(
            unsqueeze_present_v,
            'Concat',
            input_name_to_nodes,
            recursive=False)
        if not concat_present:
            logger.info("expect concat for present")
            return None

        present = concat_present.output[0]
        return present

    def cast_attention_mask(self, input_name):
        if input_name in self.casted_attention_mask:
            attention_mask_input_name = self.casted_attention_mask[input_name]
        elif self.model.find_graph_input(input_name):
            casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(
                input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        else:
            attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(
                input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        return attention_mask_input_name
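
For orientation, both past patterns recover a single packed tensor of shape [2, batch, num_heads, past_seq_len, head_size]: index 0 on the first axis holds K and index 1 holds V. An illustrative numpy analogue (shapes are examples only):

import numpy as np

past = np.zeros((2, 1, 12, 7, 64), dtype=np.float32)  # [2, batch, heads, seq, head_size]
past_k, past_v = past[0], past[1]        # Gather(indices=0) / Gather(indices=1) on axis 0
present = np.stack([past_k, past_v])     # Unsqueeze + Concat rebuilds {present}
assert present.shape == past.shape
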
Example #16
class BertOnnxModel(OnnxModel):
    def __init__(self,
                 model: ModelProto,
                 num_heads: int = 0,
                 hidden_size: int = 0):
        """Initialize BERT ONNX Model.

        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0, and we will detect the parameter automatically.
            hidden_size (int, optional): hidden dimension. Defaults to 0, and we will detect the parameter automatically.
        """
        assert (num_heads == 0
                and hidden_size == 0) or (num_heads > 0
                                          and hidden_size % num_heads == 0)

        super().__init__(model)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionAttention(self, self.hidden_size,
                                                self.num_heads,
                                                self.attention_mask)
        self.utils = FusionUtils(self)

    def fuse_attention(self):
        self.attention_fusion.apply()

    def fuse_gelu(self):
        fusion = FusionGelu(self)
        fusion.apply()
        fusion = FusionFastGelu(self)
        fusion.apply()

    def fuse_bias_gelu(self, is_fastgelu):
        fusion = FusionBiasGelu(self, is_fastgelu)
        fusion.apply()

    def gelu_approximation(self):
        fusion = FusionGeluApproximation(self)
        fusion.apply()

    def fuse_add_bias_skip_layer_norm(self):
        fusion = FusionBiasSkipLayerNormalization(self)
        fusion.apply()

    def fuse_reshape(self):
        fusion = FusionReshape(self)
        fusion.apply()

    def fuse_shape(self):
        fusion = FusionShape(self)
        fusion.apply()

    def fuse_embed_layer(self):
        fusion = FusionEmbedLayerNormalization(self)
        fusion.apply()

    def fuse_layer_norm(self):
        fusion = FusionLayerNormalization(self)
        fusion.apply()

        fusion = FusionLayerNormalizationTF(self)
        fusion.apply()

    def fuse_skip_layer_norm(self):
        fusion = FusionSkipLayerNormalization(self)
        fusion.apply()

    def get_graph_inputs_from_node_type(self, op_type: str,
                                        input_indices: List[int],
                                        casted: bool):
        """
        Get graph inputs that feed into node type (like EmbedLayerNormalization or Attention).
        Returns a list of graph input names, filtered by whether the input is casted or not.
        """
        graph_inputs = []

        output_name_to_node = self.output_name_to_node()
        nodes = self.get_nodes_by_op_type(op_type)
        for node in nodes:
            bert_inputs = [
                node.input[i] for i in input_indices if i < len(node.input)
            ]
            for bert_input in bert_inputs:
                if self.find_graph_input(bert_input):
                    if not casted:
                        graph_inputs.append(bert_input)
                elif bert_input in output_name_to_node:
                    parent = output_name_to_node[bert_input]
                    if parent.op_type == "Cast" and self.find_graph_input(
                            parent.input[0]) is not None:
                        if casted:
                            graph_inputs.append(parent.input[0])
        return graph_inputs

    def get_graph_inputs_from_fused_nodes(self, casted: bool):
        inputs = self.get_graph_inputs_from_node_type(
            "EmbedLayerNormalization", [0, 1, 7], casted)
        inputs += self.get_graph_inputs_from_node_type("Attention", [3],
                                                       casted)
        return inputs

    def change_graph_input_type(
        self,
        graph: GraphProto,
        graph_input: ValueInfoProto,
        new_type: int = TensorProto.INT32,
    ):
        """Change graph input type, and add Cast node if needed.

        Args:
            graph (GraphProto): graph
            graph_input (TensorProto): input of the graph
            new_type (int, optional): new data type. Defaults to TensorProto.INT32.

        Returns:
            NodeProto: the new Cast node that was added, or None if no Cast node was added.
            List[NodeProto]: Cast nodes that have been removed.
        """
        assert isinstance(graph, GraphProto)
        assert isinstance(graph_input, ValueInfoProto)
        assert self.find_graph_input(graph_input.name)

        if graph_input.type.tensor_type.elem_type == int(new_type):
            return None, []

        new_cast_node = None
        nodes_to_remove = []

        input_name_to_nodes = self.input_name_to_nodes()
        if graph_input.name in input_name_to_nodes:
            nodes = input_name_to_nodes[graph_input.name]

            # For child nodes that are not Cast nodes, insert a Cast node to convert int32 back to the original data type.
            nodes_not_cast = [node for node in nodes if node.op_type != "Cast"]
            if nodes_not_cast:
                node_name = self.create_node_name("Cast")
                output_name = node_name + "_" + graph_input.name
                new_value_info = graph.value_info.add()
                new_value_info.CopyFrom(graph_input)
                new_value_info.name = output_name
                new_cast_node = helper.make_node(
                    "Cast",
                    [graph_input.name],
                    [output_name],
                    to=int(graph_input.type.tensor_type.elem_type),
                    name=node_name,
                )
                graph.node.extend([new_cast_node])

                for node in nodes_not_cast:
                    OnnxModel.replace_node_input(node, graph_input.name,
                                                 output_name)

            # For child nodes that are Cast nodes, there is no need to insert another Cast.
            # When a child casts to int32, we can remove that Cast node since the input type is int32 now.
            nodes_cast = [node for node in nodes if node.op_type == "Cast"]
            for node in nodes_cast:
                if OnnxModel.get_node_attribute(node, "to") == int(new_type):
                    self.replace_input_of_all_nodes(node.output[0],
                                                    graph_input.name)
                if not self.find_graph_output(node.output[0]):
                    nodes_to_remove.append(node)
            if nodes_to_remove:
                self.remove_nodes(nodes_to_remove)

        graph_input.type.tensor_type.elem_type = int(new_type)
        return new_cast_node, nodes_to_remove

    def change_graph_inputs_to_int32(self):
        """Change data type of all graph inputs to int32 type, and add Cast node if needed."""
        graph = self.graph()
        add_cast_count = 0
        remove_cast_count = 0
        for graph_input in graph.input:
            new_node, removed_nodes = self.change_graph_input_type(
                graph, graph_input, TensorProto.INT32)
            if new_node:
                add_cast_count += 1
            remove_cast_count += len(removed_nodes)
        logger.info(
            f"Graph inputs are changed to int32. Added {add_cast_count} Cast nodes, and removed {remove_cast_count} Cast nodes."
        )

    def use_dynamic_axes(self,
                         dynamic_batch_dim="batch_size",
                         dynamic_seq_len="max_seq_len"):
        """
        Update input and output shape to use dynamic axes.
        """
        bert_graph_inputs = self.get_graph_inputs_from_fused_nodes(
            casted=True) + self.get_graph_inputs_from_fused_nodes(casted=False)

        dynamic_batch_inputs = {}
        for input in self.model.graph.input:
            if input.name in bert_graph_inputs:
                dim_proto = input.type.tensor_type.shape.dim[0]
                dim_proto.dim_param = dynamic_batch_dim
                if dynamic_seq_len is not None:
                    dim_proto = input.type.tensor_type.shape.dim[1]
                    dim_proto.dim_param = dynamic_seq_len

        for output in self.model.graph.output:
            dim_proto = output.type.tensor_type.shape.dim[0]
            dim_proto.dim_param = dynamic_batch_dim

    def preprocess(self):
        self.adjust_reshape_and_expand()
        return

    def adjust_reshape_and_expand(self):
        nodes_to_remove = []
        for node in self.nodes():
            if node.op_type == "Reshape":
                # Clean up unnecessary Reshape nodes.
                # Find Reshape nodes whose "shape" input is an empty constant, and remove them.
                reshape_shape = self.get_constant_value(node.input[1])
                if reshape_shape is not None and reshape_shape.size == 0:
                    nodes_to_remove.extend([node])
                    self.replace_input_of_all_nodes(node.output[0],
                                                    node.input[0])
                    continue

                # Find path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", simplify the graph by
                # changing current reshape's input to output of slice.
                reshape_path = self.match_parent_path(
                    node,
                    ["Expand", "Expand", "Reshape", "Slice"],
                    [0, 0, 0, 0],
                    self.output_name_to_node(),
                )
                if reshape_path is not None:
                    expand_node = reshape_path[-3]
                    expand_shape_value = self.get_constant_value(
                        expand_node.input[1])

                    reshape_before_expand = reshape_path[-2]
                    shape_value = self.get_constant_value(
                        reshape_before_expand.input[1])

                    slice_node = reshape_path[-1]
                    if (expand_shape_value is not None
                            and shape_value is not None
                            and len(expand_shape_value) == 2
                            and len(shape_value) == 1
                            and expand_shape_value[1] == shape_value[0]):
                        node.input[0] = slice_node.output[0]

        if nodes_to_remove:
            self.remove_nodes(nodes_to_remove)
            logger.info(
                f"Removed Reshape and Expand count: {len(nodes_to_remove)}")

    def clean_graph(self):
        output_name_to_node = self.output_name_to_node()
        nodes_to_remove = []
        for node in self.nodes():
            # Before:
            #  input_ids --> Shape --> Gather(indices=0) --> Unsqueeze ------+
            #          |                                                     |
            #          |                                                     v
            #          +----> Shape --> Gather(indices=1) --> Unsqueeze--->  Concat --> ConstantOfShape -->Cast --> EmbedLayerNormalization/ReduceSum
            # After:
            #  input_ids --> Shape                                                  --> ConstantOfShape -->Cast --> EmbedLayerNormalization/ReduceSum
            # TODO: merge ConstantOfShape -->Cast to ConstantOfShape (need update the data type of value)
            op_input_id = {
                "EmbedLayerNormalization": 1,
                "ReduceSum": 0,
                "Attention": 3
            }
            if node.op_type in op_input_id:
                i = op_input_id[node.op_type]
                parent_nodes = self.match_parent_path(
                    node,
                    [
                        "Cast",
                        "ConstantOfShape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                    ],
                    [i, 0, 0, 0, 0, 0],
                    output_name_to_node,
                )
                if parent_nodes is not None:
                    (
                        cast,
                        constantOfShape,
                        concat,
                        unsqueeze,
                        gather,
                        shape,
                    ) = parent_nodes
                    if shape.input[0] == self.graph().input[0].name:
                        constantOfShape.input[0] = shape.output[0]
                        output_name_to_node = self.output_name_to_node()

            if node.op_type == "Attention":
                # Before:
                #   input_ids --> Shape -->ConstantOfShape -->Cast --> ReduceSum --> Attention
                # After:
                #   remove this path, and remove the optional mask_index input of Attention node.
                parent_nodes = self.match_parent_path(
                    node,
                    ["ReduceSum", "Cast", "ConstantOfShape", "Shape"],
                    [3, 0, 0, 0],
                    output_name_to_node,
                )
                if parent_nodes is not None:
                    if parent_nodes[-1].input[0] == self.graph().input[0].name:
                        attention_node = helper.make_node(
                            "Attention",
                            inputs=node.input[0:len(node.input) - 1],
                            outputs=node.output,
                            name=node.name + "_remove_mask",
                        )
                        attention_node.domain = "com.microsoft"
                        attention_node.attribute.extend([
                            helper.make_attribute("num_heads", self.num_heads)
                        ])
                        self.add_node(
                            attention_node,
                            self.get_graph_by_node(attention_node).name)
                        nodes_to_remove.append(node)
        self.remove_nodes(nodes_to_remove)

    def postprocess(self):
        self.clean_graph()
        self.prune_graph()

    def optimize(self, options: FusionOptions = None, add_dynamic_axes=False):
        # Remove Cast nodes whose input and output data types are identical, based on symbolic shape inference.
        self.utils.remove_useless_cast_nodes()

        if (options is None) or options.enable_layer_norm:
            self.fuse_layer_norm()

        if (options is None) or options.enable_gelu:
            self.fuse_gelu()

        self.preprocess()

        self.fuse_reshape()

        if (options is None) or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()

        if (options is None) or options.enable_attention:
            if options is not None:
                self.attention_mask.set_mask_format(
                    options.attention_mask_format)
            self.fuse_attention()

        self.fuse_shape()

        if (options is None) or options.enable_embed_layer_norm:
            self.fuse_embed_layer()

        # Remove Reshape nodes whose input and output shapes are identical, based on symbolic shape inference.
        self.utils.remove_useless_reshape_nodes()

        self.postprocess()

        # Bias fusion is done after postprocess to avoid extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization
        if (options is None) or options.enable_bias_gelu:
            # Fuse Gelu and Add Bias before it.
            self.fuse_bias_gelu(is_fastgelu=True)
            self.fuse_bias_gelu(is_fastgelu=False)

        if (options is None) or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization and Add Bias before it.
            self.fuse_add_bias_skip_layer_norm()

        if options is not None and options.enable_gelu_approximation:
            self.gelu_approximation()

        self.remove_unused_constant()

        # Use symbolic batch dimension in input and output.
        if add_dynamic_axes:
            self.use_dynamic_axes()

        logger.info(f"opset version: {self.get_opset_version()}")

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        op_count = {}
        ops = [
            "EmbedLayerNormalization",
            "Attention",
            "Gelu",
            "FastGelu",
            "BiasGelu",
            "LayerNormalization",
            "SkipLayerNormalization",
        ]
        for op in ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)
        logger.info(f"Optimized operators:{op_count}")
        return op_count

    def is_fully_optimized(self):
        """
        Returns True when the model is fully optimized.
        """
        op_count = self.get_fused_operator_statistics()
        embed = op_count["EmbedLayerNormalization"]
        attention = op_count["Attention"]
        gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"]
        layer_norm = op_count["LayerNormalization"] + op_count[
            "SkipLayerNormalization"]
        is_perfect = (embed > 0) and (attention > 0) and (
            attention == gelu) and (layer_norm >= 2 * attention)

        if layer_norm == 0:
            logger.debug("Layer Normalization not fused")

        if gelu == 0:
            logger.debug("Gelu/FastGelu not fused")

        if embed == 0:
            logger.debug("Embed Layer not fused")

        if attention == 0:
            logger.warning("Attention not fused")

        return is_perfect
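
A hedged sketch of driving the newer input-type API defined above (file names are placeholders; BertOnnxModel is the class in this example):

import onnx

bert_model = BertOnnxModel(onnx.load("bert.onnx"))
bert_model.change_graph_inputs_to_int32()   # inserts/removes Cast nodes as needed
bert_model.save_model_to_file("bert_int32.onnx")
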
Example #17
    def __init__(self,
                 model: OnnxModel,
                 name: str = "EmbedLayerNormalization(no mask)",
                 search_op_types="SkipLayerNormalization"):
        super().__init__(model, name, search_op_types)
        self.utils = FusionUtils(model)
Example #18
    def match_position_embedding_bert(self, position_embedding_gather,
                                      input_ids, output_name_to_node):
        """  Match position embedding path from input_ids to Gather for BERT.

        BERT Embedding Layer Pattern:       
                                    (input_ids)
                                   /         \
                                 /          Shape
                                /              |
                              /              Gather (indices=1)
                             /                  |
                            /                  Add (optional, B=0)
                           /                    |
                        Gather (segment_ids) Unsqueeze (axes=0)
                           \        |           |
                            \     Gather      Slice (data[1,512], starts=0, ends=*, axes=1, steps=1)
                              \    /            |
                                Add          Gather 
                                   \       /
                                      Add
                                       |
                                LayerNormalization
        """
        path = self.model.match_parent_path(position_embedding_gather,
                                            ['Slice', 'Unsqueeze'], [1, 2],
                                            output_name_to_node)
        if path is None:
            return False

        slice, unsqueeze = path
        slice_weight = self.model.get_constant_value(slice.input[0])
        if not (slice_weight is not None  and len(slice_weight.shape) == 2 and slice_weight.shape[0] == 1 \
                and self.utils.check_node_input_value(slice, 1, [0]) \
                and self.utils.check_node_input_value(slice, 3, [1]) \
                and (len(slice.input) == 4 or self.utils.check_node_input_value(slice, 4, [1]))):
            return False

        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(unsqueeze, 'axes', [0]):
                return False
        else:
            if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
                return False

        node = self.model.get_parent(unsqueeze, 0, output_name_to_node)
        if node is None:
            return False
        if node.op_type == "Add":
            if not self.utils.check_node_input_value(node, 1, 0):
                return False
            gather = self.model.get_parent(node, 0, output_name_to_node)
        else:
            gather = node

        if gather is None or gather.op_type != "Gather":
            return False
        if not (self.utils.check_node_input_value(gather, 1, 1)):
            return False

        shape = self.model.get_parent(gather, 0, output_name_to_node)
        if shape is None or shape.op_type != "Shape":
            return False

        return input_ids == shape.input[0]
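# Note on the matching helper (behavior inferred from its uses in this file,
# not a definitive API description): match_parent_path walks upward from a
# node, following the listed op types along the listed input indices, and
# returns the matched parents in order, or None when any hop fails. For the
# BERT pattern above:
#
#     path = self.model.match_parent_path(position_embedding_gather,
#                                         ['Slice', 'Unsqueeze'], [1, 2])
#     # path[0] is the Slice feeding input 1 of the Gather;
#     # path[1] is the Unsqueeze feeding input 2 (ends) of the Slice.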
class FusionEmbedLayerNoMask(Fusion):
    """
     Embed Layer Normalization will fuse embeddings and mask processing into one node.
     The embeddings before conversion:

     (input_ids) -------->  Gather ----------+       (segment_ids)
        |                                    |            |
        |                                    v            v
        +--> Shape --> Expand -> Gather---->Add         Gather
        |                ^                   |            |
        |                |                   v            v
        +---(optional graph)               SkipLayerNormalization

      Optional graph is used to generate the position list (0, 1, ...) per batch. It can be a constant in some models.

      (input_ids) --> Gather -----+           Slice
                                  |            |
                                  v            v
     (segment_ids)--> Gather --->Add        Reshape
                                  |            |
                                  v            v
                              SkipLayerNormalization
    """
    def __init__(self, model: OnnxModel, description='no mask'):
        super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
        self.utils = FusionUtils(model)
        self.attention = None

    def match_segment_path(self, normalize_node, input_name_to_nodes, output_name_to_node, input_ids_cast_node):
        segment_ids = None
        segment_embedding_gather = None

        segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])

        if segment_embedding_path is None:
            segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
            if segment_embedding_path is None:
                logger.info("Segment embedding is not found. Embed layer cannot be fused.")
                return
            _, segment_embedding_gather = segment_embedding_path
        else:
            segment_embedding_gather = segment_embedding_path[0]

        segment_ids = segment_embedding_gather.input[1]

        self.nodes_to_remove.extend(segment_embedding_path)

        if self.model.find_graph_input(segment_ids):
            casted, segment_ids = self.utils.cast_graph_input_to_int32(segment_ids)
        else:
            segment_ids, segment_ids_cast_node = self.utils.cast_input_to_int32(segment_ids)

            # Cast might be removed by OnnxRuntime.
            _, segment_id_path, _ = self.model.match_parent_paths(
                segment_ids_cast_node,
                [(['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape', 'Cast'], [0, 0, 1, 0, 0, 0]),
                 (['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [0, 0, 1, 0, 0])], output_name_to_node)

            if segment_id_path and input_ids_cast_node and input_ids_cast_node.input[0] == segment_id_path[-1].input[0]:
                logger.debug("Simplify segment id path...")
                self.model.add_node(
                    helper.make_node('Shape', inputs=[input_ids_cast_node.input[0]], outputs=["input_shape"]))
                self.model.add_node(
                    helper.make_node('ConstantOfShape',
                                     inputs=["input_shape"],
                                     outputs=["zeros_for_input_shape"],
                                     value=helper.make_tensor("value", onnx.TensorProto.INT32, [1], [1])))
                segment_ids = "zeros_for_input_shape"

        return segment_ids, segment_embedding_gather

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        is_distill = False

        if self.model.match_parent_path(node, ['Add', 'Gather'], [0, 0]) is None and self.model.match_parent_path(node, ['Gather'], [0]) is None:
            logger.debug("Failed to match path SkipLayerNormalization[0] <-- Add <-- Gather or SkipLayerNormalization[0] <-- Gather")
            return

        self.attention = self.model.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False)
        if self.attention is None:
            # In case user disables attention fusion, check whether subgraph looks like Attention.
            if node.output[0] not in input_name_to_nodes:
                return
            children = input_name_to_nodes[node.output[0]]
            children_types = sorted([child.op_type for child in children])
            if children_types != ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization'] and children_types != ['MatMul', 'MatMul', 'MatMul', 'Shape', 'Shape', 'SkipLayerNormalization']:
                logger.debug("No Attention like subgraph in children of SkipLayerNormalization")
                return

        # Assume the order of embeddings is word_embedding + position_embedding + segment_embedding
        normalize_node = node
        add_node = None
        word_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 0])
        if word_embedding_path is not None:
            add_node, word_embedding_gather = word_embedding_path
        else:
            word_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [0])
            if word_embedding_path is not None:
                word_embedding_gather = word_embedding_path[0]
                is_distill = True
            else:
                logger.info("Word embedding path is not found. Embed layer cannot be fused.")
                return

        input_ids = word_embedding_gather.input[1]

        position_embedding_expand = None
        position_embedding_shape = None

        position_embedding_path = self.model.match_parent_path(normalize_node, ['Gather', 'Expand'], [1, 1]) # for distill-bert
        if position_embedding_path is not None:
            position_embedding_weight_node, position_embedding_expand = position_embedding_path
        else:
            position_embedding_path = self.model.match_parent_path(normalize_node, ['Reshape', 'Slice'], [1, 0])
            if position_embedding_path is not None:
                _, position_embedding_weight_node = position_embedding_path
            else:
                position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand', 'Shape'], [1, 1, 1])
                if position_embedding_path is not None:
                    position_embedding_weight_node, position_embedding_expand, position_embedding_shape = position_embedding_path
                else:
                    position_embedding_path = self.model.match_parent_path(
                        add_node, ['Gather', 'Expand', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [1, 1, 1, 1, 0, 0])
                    if position_embedding_path is not None:
                        position_embedding_weight_node, position_embedding_expand, _, _, _, position_embedding_shape = position_embedding_path
                    else:
                        # Here we will not try to get an exact match. Instead, we only try to identify the position embedding weights.
                        position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand'], [1, 1])
                        if position_embedding_path is not None:
                            position_embedding_weight_node, position_embedding_expand = position_embedding_path
                        else:
                            logger.info("Position embedding path is not found. Embed layer cannot be fused.")
                            return

                if position_embedding_shape is not None and position_embedding_shape.input[0] != input_ids:
                    logger.info("position and word embedding is expected to be applied on same input")
                    return

        if position_embedding_expand and position_embedding_shape:
            input_parent = self.model.get_parent(position_embedding_shape, 0, output_name_to_node)
            subgraph_nodes = self.model.get_parent_subgraph_nodes(position_embedding_expand,
                                                                  [input_parent] if input_parent else [],
                                                                  output_name_to_node)
            self.nodes_to_remove.extend(subgraph_nodes)

        self.nodes_to_remove.extend(word_embedding_path)
        self.nodes_to_remove.extend(position_embedding_path)

        self.nodes_to_remove.extend([normalize_node])

        # Cast input_ids and segment_ids to int32.
        input_ids_cast_node = None
        if self.model.find_graph_input(input_ids):
            casted, input_ids = self.utils.cast_graph_input_to_int32(input_ids)
        else:
            input_ids, input_ids_cast_node = self.utils.cast_input_to_int32(input_ids)

        node_name = self.model.create_node_name('EmbedLayerNormalization')
        output_name = node_name + "_output"

        embed_node_inputs = None
        if not is_distill:
            segment_path = self.match_segment_path(normalize_node, input_name_to_nodes, output_name_to_node,
                                                   input_ids_cast_node)
            if segment_path is None:
                return

            segment_ids, segment_embedding_gather = segment_path

            embed_node_inputs = [
                input_ids,
                segment_ids,
                word_embedding_gather.input[0],
                position_embedding_weight_node.input[0],
                segment_embedding_gather.input[0],
                normalize_node.input[2],
                normalize_node.input[3]  # gamma and beta
            ]
        else:
            # DistilBert has no segment embedding. EmbedLayerNormalization with empty segment inputs
            # requires onnxruntime newer than 1.4.0.
            from packaging.version import Version
            import onnxruntime
            if Version(onnxruntime.__version__) <= Version("1.4.0"):
                logger.warning('Please install onnxruntime with version > 1.4.0 for embedlayer fusion support for distilbert')
                return

            embed_node_inputs = [
                input_ids,
                '',
                word_embedding_gather.input[0],
                position_embedding_weight_node.input[0],
                '',
                normalize_node.input[2],
                normalize_node.input[3]  # gamma and beta
            ]

        embed_node = helper.make_node(
            'EmbedLayerNormalization',
            embed_node_inputs,
            outputs=[output_name, node_name + "_dummy_mask_index"],
            name=node_name)

        embed_node.domain = "com.microsoft"

        # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
        for att in normalize_node.attribute:
            if att.name == 'epsilon':
                embed_node.attribute.extend([att])
        # Set default value to 1e-12 if no attribute is found.
        # OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later.
        if len(embed_node.attribute) == 0:
            embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0E-12)])

        self.model.replace_input_of_all_nodes(normalize_node.output[0], output_name)
        self.nodes_to_add.append(embed_node)
Example #20
class FusionShape(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "Shape", "Concat")
        self.utils = FusionUtils(model)
        self.shape_infer = None
        self.shape_infer_done = False

    def get_dimensions_from_tensor_proto(
            self, tensor_proto: TensorProto) -> Union[int, None]:
        if tensor_proto.type.tensor_type.HasField("shape"):
            return len(tensor_proto.type.tensor_type.shape.dim)
        else:
            return None

    def get_dimensions(self, input_name: str) -> Union[int, None]:
        graph_input = self.model.find_graph_input(input_name)
        if graph_input:
            return self.get_dimensions_from_tensor_proto(graph_input)

        if not self.shape_infer_done:
            self.shape_infer = self.model.infer_runtime_shape({}, update=True)
            self.shape_infer_done = True

        if self.shape_infer is not None:
            return self.get_dimensions_from_tensor_proto(
                self.shape_infer.known_vi_[input_name])

        return None
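    # Note: despite the name, get_dimensions returns the rank (number of
    # dimensions) of a tensor, not the dimension values themselves; fuse()
    # below compares it against the number of Concat inputs to confirm that
    # the Concat rebuilds the full shape.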

    def fuse(
        self,
        concat_node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """
        Simplify a subgraph like

                   (2d_input)
                    /       \
                Shape       Shape
                /             \
            Gather(indices=0)  Gather(indices=1)
                |                |
            Unsqueeze(axes=0)   Unsqueeze(axes=0)
                   \          /
                      Concat 
                        |

        into  (2d_input) --> Shape -->
        """
        opset_version = self.model.get_opset_version()

        inputs = len(concat_node.input)
        root = None
        shape_output = None
        for i in range(inputs):
            path = self.model.match_parent_path(
                concat_node,
                ["Unsqueeze", "Gather", "Shape"],
                [i, 0, 0],
                output_name_to_node,
            )
            if path is None:
                return

            unsqueeze, gather, shape = path
            if i == 0:
                shape_output = shape.output[0]
            if root is None:
                root = shape.input[0]
                if self.get_dimensions(root) != inputs:
                    return
            elif shape.input[0] != root:
                return

            if not FusionUtils.check_node_attribute(
                    unsqueeze, "axis", 0, default_value=0):
                return

            if opset_version < 13:
                if not FusionUtils.check_node_attribute(
                        unsqueeze, "axes", [0]):
                    return
            else:
                if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
                    return

            value = self.model.get_constant_value(gather.input[1])
            from numpy import ndarray

            if not (isinstance(value, ndarray) and value.size == 1
                    and value.item() == i):
                return

        if self.model.find_graph_output(concat_node.output[0]) is None:
            self.model.replace_input_of_all_nodes(concat_node.output[0],
                                                  shape_output)
            self.fused_count += 1
            self.prune_graph = True
Example #21
class FusionEmbedLayerNoMask(Fusion):
    """
     Fuse embedding layer into one node (EmbedLayerNormalization).
     It supports the following model types: BERT, DistilBert, ALBert.
    """
    def __init__(self, model: OnnxModel, description: str = 'no mask'):
        super().__init__(model, "EmbedLayerNormalization",
                         ["LayerNormalization", "SkipLayerNormalization"],
                         description)
        self.utils = FusionUtils(model)
        self.shape_infer_helper = self.model.infer_runtime_shape({},
                                                                 update=True)
        # The following will be reset in each fuse call of FusionEmbedLayerNormalization
        self.attention = None
        self.embed_node = None

    def match_two_gather(
            self, add: NodeProto) -> Union[None, Tuple[NodeProto, NodeProto]]:
        gather_0_path = self.model.match_parent_path(add, ['Gather'], [0])
        if gather_0_path is None:
            return None

        gather_1_path = self.model.match_parent_path(add, ['Gather'], [1])
        if gather_1_path is None:
            return None

        return gather_0_path[0], gather_1_path[0]

    def check_attention_subgraph(self, layernorm: NodeProto,
                                 input_name_to_nodes: Dict[str,
                                                           List[NodeProto]],
                                 is_distil_bert: bool) -> bool:
        """Check that LayerNormalization has a child of Attention node or subgraph like Attention.

        Args:
            layernorm (NodeProto): LayerNormalization node
            input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
            is_distil_bert (bool): whether it is DistilBert or not

        Returns:
            bool: whether there is Attention node or subgraph like Attention
        """
        self.attention = self.model.find_first_child_by_type(
            layernorm, 'Attention', input_name_to_nodes, recursive=False)
        if self.attention is None:
            # In case user disables attention fusion, check whether subgraph looks like Attention.
            if layernorm.output[0] not in input_name_to_nodes:
                return False
            children = input_name_to_nodes[layernorm.output[0]]

            # For Albert, there is MatMul+Add after embedding layer before attention.
            if len(children) == 1 and children[0].op_type == "MatMul" \
                    and children[0].output[0] in input_name_to_nodes:
                grandchildren = input_name_to_nodes[children[0].output[0]]
                if len(grandchildren) == 1 and grandchildren[0].op_type == "Add" \
                        and grandchildren[0].output[0] in input_name_to_nodes:
                    nodes = input_name_to_nodes[grandchildren[0].output[0]]
                    for node in nodes:
                        if node.op_type == "Attention":
                            self.attention = node
                            return True
                    children_types = sorted([child.op_type for child in nodes])
            else:
                children_types = sorted([child.op_type for child in children])

            # Two Shape nodes might be merged by ORT
            if is_distil_bert:
                # SkipLayerNormalization might exist when model has been optimized by ORT first.
                if children_types != ['MatMul', 'MatMul', 'MatMul', 'Shape', 'SkipLayerNormalization'] and \
                   children_types != ['Add', 'MatMul', 'MatMul', 'MatMul', 'Shape', 'Shape'] and \
                   children_types != ['Add', 'MatMul', 'MatMul', 'MatMul', 'Shape']:
                    logger.debug(
                        "No Attention like subgraph in children of LayerNormalization"
                    )
                    return False
            else:
                if children_types != ['Add', 'MatMul', 'MatMul', 'MatMul'] and \
                   children_types != ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization']:
                    logger.debug(
                        "No Attention like subgraph in children of LayerNormalization"
                    )
                    return False
        return True

    def match_position_embedding_distilbert(self, position_embedding_gather,
                                            input_ids, output_name_to_node):
        """  Match position embedding path from input_ids to Gather for DistilBert.

        Pattern is like the following:
                 (input_ids)
                      |
                     Shape
                       |   \
                       |    Gather (indices=1)
                       |       |
                       |      Cast (optional)
                       |       |
                       |      Range (start=0, end=*, delta=1)
                       |       |
                       |    Unsqueeze
                       |    /
                      Expand
                        |
                      Gather
        """
        path1 = self.model.match_parent_path(position_embedding_gather,
                                             ['Expand', 'Shape'], [1, 1])
        if path1 is None:
            return False

        expand, shape = path1
        if shape.input[0] != input_ids:
            return False

        _, path2, _ = self.model.match_parent_paths(expand, [(['Unsqueeze', 'Range', 'Cast', 'Gather', 'Shape'], [0, 0, 1, 0, 0]), \
                                                             (['Unsqueeze', 'Range', 'Gather', 'Shape'], [0, 0, 1, 0])], output_name_to_node)
        if path2 is None:
            return False

        range_node = path2[1]
        if not (self.utils.check_node_input_value(range_node, 0, 0)
                and self.utils.check_node_input_value(range_node, 2, 1)):
            return False

        gather_node = path2[-2]
        if not (self.utils.check_node_input_value(gather_node, 1, 1)):
            return False

        shape_node = path2[-1]
        if shape_node.input[0] != input_ids:
            return False

        return True

    def match_position_embedding_roberta(self, position_embedding_gather,
                                         input_ids, output_name_to_node):
        """  Match position embedding path from input_ids to Gather for Roberta.

        Roberta Embedding Layer Pattern (* is optional since it might be removed by ORT, ? is the padding word id):       
          (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Mul -- Cast(to=7) -- Add(B=1) -- Cast(to=7)* --> Gather
                                                |                              ^
                                                V                              |
                                                +------------------------------+  

        Roberta new pattern from transformers v4.9:
           (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Add(B=0) -- Mul -- Cast(to=7) -- Add(B=1) --> Gather
                                                |                                           ^
                                                V                                           |
                                                +-------------------------------------------+  

        start_node = position_embedding_gather
        start_index = 1

        # match optional Cast node.
        parent = self.model.get_parent(start_node, start_index, output_name_to_node)
        if parent is None:
            return
        if parent.op_type == "Cast":
            if OnnxModel.get_node_attribute(parent, "to") != 7:
                return
            start_node = parent
            start_index = 0

        i, path, return_indices = self.model.match_parent_paths(
            start_node,
            [ (['Add', 'Cast', 'Mul', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0]),
              (['Add', 'Cast', 'Mul', 'Add', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0, 0])],
            output_name_to_node)

        if path is not None:
            # constant input of Add shall be 1.
            i, value = self.model.get_constant_input(path[0])
            if value != 1:
                return False

            _, self.padding_word_id = self.model.get_constant_input(path[-1])

            return input_ids == path[-1].input[0]
        """

        # The Roberta matching logic above is intentionally kept inside the
        # docstring (disabled); see the TODO in match_position_embedding below.
        return False

    def match_position_embedding_bert(self, position_embedding_gather,
                                      input_ids, output_name_to_node):
        """  Match position embedding path from input_ids to Gather for BERT.

        BERT Embedding Layer Pattern:       
                                    (input_ids)
                                   /         \
                                 /          Shape
                                /              |
                              /              Gather (indices=1)
                             /                  |
                            /                  Add (optional, B=0)
                           /                    |
                        Gather (segment_ids) Unsqueeze (axes=0)
                           \        |           |
                            \     Gather      Slice (data[1,512], starts=0, ends=*, axes=1, steps=1)
                              \    /            |
                                Add          Gather 
                                   \       /
                                      Add
                                       |
                                LayerNormalization
        """
        path = self.model.match_parent_path(position_embedding_gather,
                                            ['Slice', 'Unsqueeze'], [1, 2],
                                            output_name_to_node)
        if path is None:
            return False

        slice, unsqueeze = path
        slice_weight = self.model.get_constant_value(slice.input[0])
        if not (slice_weight is not None  and len(slice_weight.shape) == 2 and slice_weight.shape[0] == 1 \
                and self.utils.check_node_input_value(slice, 1, [0]) \
                and self.utils.check_node_input_value(slice, 3, [1]) \
                and (len(slice.input) == 4 or self.utils.check_node_input_value(slice, 4, [1]))):
            return False

        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(unsqueeze, 'axes', [0]):
                return False
        else:
            if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
                return False

        node = self.model.get_parent(unsqueeze, 0, output_name_to_node)
        if node is None:
            return False
        if node.op_type == "Add":
            if not self.utils.check_node_input_value(node, 1, 0):
                return False
            gather = self.model.get_parent(node, 0, output_name_to_node)
        else:
            gather = node

        if gather is None or gather.op_type != "Gather":
            return False
        if not (self.utils.check_node_input_value(gather, 1, 1)):
            return False

        shape = self.model.get_parent(gather, 0, output_name_to_node)
        if shape is None or shape.op_type != "Shape":
            return False

        return input_ids == shape.input[0]

    def match_position_embedding(self, position_embedding_gather, input_ids,
                                 output_name_to_node):
        if self.match_position_embedding_bert(position_embedding_gather,
                                              input_ids, output_name_to_node):
            return True

        # TODO: Support roberta (position starts from 2 instead of 0) in EmbedLayerNormalization kernel
        #       related: https://github.com/huggingface/transformers/issues/10736
        #if self.match_position_embedding_roberta(position_embedding_gather, input_ids, output_name_to_node):
        #    return True

        if self.match_position_embedding_distilbert(position_embedding_gather,
                                                    input_ids,
                                                    output_name_to_node):
            return True

        return False

    def check_embedding(self, word_embedding_gather, segment_embedding_gather,
                        position_embedding_gather):
        """Sanity check of embedding weights, and match hidden_size of weights and shape of inputs.
        """
        input_ids = word_embedding_gather.input[1]
        segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None
        position_ids = position_embedding_gather.input[1]

        if self.shape_infer_helper is not None:
            input_ids_shape = self.shape_infer_helper.get_edge_shape(input_ids)
            position_ids_shape = self.shape_infer_helper.get_edge_shape(
                position_ids)
            assert input_ids_shape and position_ids_shape
            if not (len(input_ids_shape) == 2 and len(position_ids_shape) == 2
                    and input_ids_shape[1] == position_ids_shape[1]):
                logger.info(
                    "Cannot fuse EmbedLayerNormalization: input_ids and position_ids not matched in 2nd dimension: {} vs {}"
                    .format(input_ids_shape, position_ids_shape))
                return False

            if segment_ids and not self.shape_infer_helper.compare_shape(
                    input_ids, segment_ids):
                logger.info(
                    "Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {} != {}"
                    .format(
                        input_ids_shape,
                        self.shape_infer_helper.get_edge_shape(segment_ids)))
                return False

        word_embedding_table = self.model.get_constant_value(
            word_embedding_gather.input[0])
        if word_embedding_table is None or len(word_embedding_table.shape) != 2:
            logger.info(
                "Cannot fuse EmbedLayerNormalization: word embedding table is not expected"
            )
            return False

        position_embedding_table = self.model.get_constant_value(
            position_embedding_gather.input[0])
        if position_embedding_table is None or len(position_embedding_table.shape) != 2 \
                or word_embedding_table.shape[1] != position_embedding_table.shape[1]:
            logger.info(
                "Cannot fuse EmbedLayerNormalization: position embedding table is not expected"
            )
            return False

        if segment_ids:
            segment_embedding_table = self.model.get_constant_value(
                segment_embedding_gather.input[0])
            if segment_embedding_table is None or len(segment_embedding_table.shape) != 2 \
                    or word_embedding_table.shape[1] != segment_embedding_table.shape[1]:
                logger.info(
                    "Cannot fuse EmbedLayerNormalization: segment embedding table is not expected"
                )
                return False

        # In the normal case, the word embedding table is the largest, the segment embedding table is the smallest, and the position embedding table is in between.
        # TODO: use other information (like initializer names) to identify different embedding weights automatically.
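        # For illustration (typical BERT-base shapes; an assumption, not checked here):
        #   word_embedding_table     [30522, 768]
        #   position_embedding_table [  512, 768]
        #   segment_embedding_table  [    2, 768]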
        if word_embedding_table.shape[0] <= position_embedding_table.shape[0]:
            logger.warning(
                f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]}"
            )

        if segment_ids:
            if word_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
                logger.warning(
                    f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
                )

            if position_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
                logger.warning(
                    f"position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
                )

        return True

    def cast_to_int32(self,
                      input_name: str) -> Tuple[str, Union[None, NodeProto]]:
        """Cast a graph input or node input to int32.

        Args:
            input_name (str): name of graph input or node input

        Returns:
            A tuple of casted input name and the cast node.
            int32_output (str): If the input is already int32, this is the input name. Otherwise it is the output name of the Cast node.
            input_cast_node (Union[None, NodeProto]): the Cast node, or None if the input is already int32.
        """
        input_cast_node = None
        graph_input = self.model.find_graph_input(input_name)
        if graph_input is not None:
            if graph_input.type.tensor_type.elem_type != TensorProto.INT32:
                int32_output, input_cast_node = self.utils.cast_input_to_int32(
                    input_name)
            else:
                int32_output = input_name
        else:
            int32_output, input_cast_node = self.utils.cast_input_to_int32(
                input_name)

        return int32_output, input_cast_node

    def create_fused_node(self, input_ids: str, layernorm: NodeProto,
                          word_embedding_gather: NodeProto,
                          position_embedding_gather: NodeProto,
                          segment_embedding_gather: Union[None, NodeProto]):
        """Create an EmbedLayerNormalization node. Note that segment embedding is optional.

        Args:
            input_ids (str): input_ids for word embeddings
            layernorm (NodeProto): LayerNormalization or SkipLayerNormalization node.
            word_embedding_gather (NodeProto): the Gather node for word embedding
            position_embedding_gather (NodeProto): the Gather node for position embedding
            segment_embedding_gather (Union[None, NodeProto]): the Gather node for segment embedding, or None.

        Returns:
            NodeProto: the EmbedLayerNormalization node created.
        """
        nodes_to_add = []
        input_ids, _ = self.cast_to_int32(input_ids)

        node_name = self.model.create_node_name('EmbedLayerNormalization')

        if layernorm.op_type == "LayerNormalization":
            gamma = layernorm.input[1]
            beta = layernorm.input[2]
        else:  # SkipLayerNormalization
            gamma = layernorm.input[2]
            beta = layernorm.input[3]

        embed_node_inputs = None
        if segment_embedding_gather is not None:
            segment_ids, _ = self.cast_to_int32(
                segment_embedding_gather.input[1])

            embed_node_inputs = [
                input_ids, segment_ids, word_embedding_gather.input[0],
                position_embedding_gather.input[0],
                segment_embedding_gather.input[0], gamma, beta
            ]
        else:  # no segment embedding
            embed_node_inputs = [
                input_ids, '', word_embedding_gather.input[0],
                position_embedding_gather.input[0], '', gamma, beta
            ]

        embed_node = helper.make_node(
            'EmbedLayerNormalization',
            embed_node_inputs,
            outputs=[node_name + "_output", node_name + "_dummy_mask_index"],
            name=node_name)

        embed_node.domain = "com.microsoft"

        # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
        for att in layernorm.attribute:
            if att.name == 'epsilon':
                embed_node.attribute.extend([att])
        # Set default value to 1e-12 if no attribute is found.
        # OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later.
        if len(embed_node.attribute) == 0:
            embed_node.attribute.extend(
                [helper.make_attribute("epsilon", 1.0E-12)])

        # Make sure new EmbedLayerNormalization node is the last one in self.nodes_to_add.
        nodes_to_add.append(embed_node)
        for node in nodes_to_add:
            self.node_name_to_graph_name[node.name] = self.this_graph_name
        self.nodes_to_add.extend(nodes_to_add)

        self.embed_node = embed_node
        return embed_node

    def finish_fusion(self, layernorm, embed_node):
        self.model.replace_input_of_all_nodes(layernorm.output[0],
                                              embed_node.output[0])
        # use prune_graph to remove nodes that are no longer needed
        self.prune_graph = True

    def fuse_distilbert(self, layernorm, add_before_layernorm,
                        input_name_to_nodes, output_name_to_node):
        """Fuse embedding layer for DistilBert
        Args:
            layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
            add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
            input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
            output_name_to_node (Dict[str, NodeProto]): map from output name to node
        """

        # DistilBert has no segment embedding, subgraph pattern is like
        #       input_ids
        #        |      \
        #        |     (position_embedding_subgraph)
        #        |        |
        #     Gather    Gather
        #          \   /
        #           Add
        #            |
        #    LayerNormalization
        two_gather = self.match_two_gather(add_before_layernorm)
        if two_gather is None:
            return False

        word_embedding_gather, position_embedding_gather = two_gather
        input_ids = word_embedding_gather.input[1]

        if not self.check_attention_subgraph(
                layernorm, input_name_to_nodes, is_distil_bert=True):
            return False

        if not self.match_position_embedding(position_embedding_gather,
                                             input_ids, output_name_to_node):
            return False

        if not self.check_embedding(word_embedding_gather, None,
                                    position_embedding_gather):
            return False

        embed_node = self.create_fused_node(input_ids, layernorm,
                                            word_embedding_gather,
                                            position_embedding_gather, None)
        self.finish_fusion(layernorm, embed_node)
        return True

    def fuse_bert(self, layernorm, add_before_layernorm, input_name_to_nodes,
                  output_name_to_node):
        """Fuse embedding layer for Bert
        Args:
            layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
            add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
            input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
            output_name_to_node (Dict[str, NodeProto]): map from output name to node
        """

        add_2_gather = self.model.match_parent_path(add_before_layernorm,
                                                    ['Add'], [0])
        if add_2_gather is None:
            return False

        two_gather = self.match_two_gather(add_2_gather[0])
        if two_gather is None:
            return False

        word_embedding_gather, segment_embedding_gather = two_gather

        input_ids = word_embedding_gather.input[1]

        if not self.check_attention_subgraph(
                layernorm, input_name_to_nodes, is_distil_bert=False):
            return False

        position_embedding_path = self.model.match_parent_path(
            add_before_layernorm, ['Gather'], [1])
        if position_embedding_path is None:
            return False

        position_embedding_gather = position_embedding_path[0]
        if not self.match_position_embedding(position_embedding_gather,
                                             input_ids, output_name_to_node):
            if not self.match_position_embedding(
                    segment_embedding_gather, input_ids, output_name_to_node):
                return False
            # position and segment are switched
            position_embedding_gather, segment_embedding_gather = \
                segment_embedding_gather, position_embedding_gather

        if not self.check_embedding(word_embedding_gather,
                                    segment_embedding_gather,
                                    position_embedding_gather):
            return False

        embed_node = self.create_fused_node(input_ids, layernorm,
                                            word_embedding_gather,
                                            position_embedding_gather,
                                            segment_embedding_gather)
        self.finish_fusion(layernorm, embed_node)
        return True

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        if node.op_type == "LayerNormalization":
            first_add_path = self.model.match_parent_path(node, ['Add'], [0])
            if first_add_path is None:
                return
            add_before_layernorm = first_add_path[0]
        else:  # SkipLayerNormalization
            add_before_layernorm = node  # Add is fused into SkipLayerNormalization

        if self.fuse_distilbert(node, add_before_layernorm,
                                input_name_to_nodes, output_name_to_node):
            return

        if self.fuse_bert(node, add_before_layernorm, input_name_to_nodes,
                          output_name_to_node):
            return
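# Hedged usage sketch for the Fusion subclass above (apply() and prune_graph()
# are assumed from the Fusion base class and OnnxModel, as used elsewhere in
# onnxruntime.transformers; not shown in this snippet):
#
#     model = OnnxModel(onnx.load("bert.onnx"))
#     fusion = FusionEmbedLayerNoMask(model)
#     fusion.apply()       # runs fuse() on each matched (Skip)LayerNormalization
#     model.prune_graph()  # drops nodes orphaned by the fusion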
Example #22
    def auto_mixed_precision(onnx_model: OnnxModel,
                             op_block_list: List[str] = [
                                 'Add', 'LayerNormalization', 'FastGelu'
                             ]):
        """Convert GPT-2 model to mixed precision.
           It detects whether the original model has float16 weights, and sets the parameters for float16 conversion automatically.
        Args:
            onnx_model (OnnxModel): optimized ONNX model
            op_block_list (List[str], optional): operators to keep in float32. Defaults to ['Add', 'LayerNormalization', 'FastGelu'].
        Returns:
            parameters(dict): a dictionary of parameters used in float16 conversion
        """
        op_full_set = set([node.op_type for node in onnx_model.nodes()])
        fp32_op_set = set(op_block_list)
        fp16_op_set = op_full_set.difference(fp32_op_set)
        logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")

        # logits is the first output
        logits_output_name = onnx_model.graph().output[0].name

        # We use the weight in last MatMul node to detect whether the model is stored with float16 weights from training.
        is_weight_fp16_precision = False
        output_name_to_node = onnx_model.output_name_to_node()
        assert logits_output_name in output_name_to_node
        node = output_name_to_node[logits_output_name]
        last_matmul_node = None
        if node.op_type == "MatMul":
            last_matmul_node = node
            logger.info(f"Found last MatMul node for logits: {node.name}")
            initializer = None
            for input in node.input:
                initializer = onnx_model.get_initializer(input)
                if initializer is not None:
                    break

            # when the max difference of value after converting float to float16 is lower than a threshold (1e-6),
            # we can deduce that the weights are stored in float16 precision.
            max_diff = float_to_float16_max_diff(initializer)
            logger.debug(
                f"max diff of converting weights in last MatMul node {node.name}: {max_diff}"
            )
            is_weight_fp16_precision = (max_diff < 1E-6)
        else:
            logger.warning(
                f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}"
            )

        if is_weight_fp16_precision:
            keep_io_types = []
            node_block_list = []
        else:
            # When the original weights are float32 precision, keeping logits and the last MatMul in float32 can give better precision.
            keep_io_types = [logits_output_name]
            node_block_list = [last_matmul_node.name]

        parameters = {
            "keep_io_types": keep_io_types,
            "op_block_list": op_block_list,
            "node_block_list": node_block_list,
            "force_fp16_initializers": is_weight_fp16_precision
        }

        logger.info(f"auto_mixed_precision parameters: {parameters}")
        onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True,
                                            **parameters)

        fusion_utils = FusionUtils(onnx_model)
        fusion_utils.remove_cascaded_cast_nodes()
        fusion_utils.remove_useless_cast_nodes()

        return parameters
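# Hedged usage sketch (file names are placeholders; loading follows the
# OnnxModel constructor used throughout these snippets):
#
#     import onnx
#     onnx_model = OnnxModel(onnx.load("gpt2_optimized.onnx"))
#     params = auto_mixed_precision(onnx_model)  # converts most ops to fp16 in place
#     onnx_model.save_model_to_file("gpt2_fp16.onnx")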
Example #23
class FusionEmbedLayerNoMask(Fusion):
    """
     Embed Layer Normalization will fuse embeddings and mask processing into one node.
     The embeddings before conversion:

     (input_ids) -------->  Gather ----------+       (segment_ids)
        |                                    |            |
        |                                    v            v
        +--> Shape --> Expand -> Gather---->Add         Gather
        |                ^                   |            |
        |                |                   v            v
        +---(optional graph)               SkipLayerNormalization

      Optional graph is used to generate the position list (0, 1, ...) per batch. It can be a constant in some models.

      (input_ids) --> Gather -----+           Slice
                                  |            |
                                  v            v
     (segment_ids)--> Gather --->Add        Reshape
                                  |            |
                                  v            v
                              SkipLayerNormalization
    """
    def __init__(self,
                 model: OnnxModel,
                 description='no mask'):
        super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
        self.utils = FusionUtils(model)

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        # Already fused. Assumes that there is only one embedding layer in a transformer model.
        if self.nodes_to_add:
            return

        if self.model.match_parent_path(node, ['Add', 'Gather'], [0, 0]) is None:
            return

        if self.model.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False) is None:
            # In case user disables attention fusion, check whether subgraph looks like Attention.
            if node.output[0] not in input_name_to_nodes:
                return
            children = input_name_to_nodes[node.output[0]]
            children_types = sorted([child.op_type for child in children])
            if children_types != ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization']:
                return

        # Assume the order of embeddings is word_embedding + position_embedding + segment_embedding
        normalize_node = node
        word_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 0])
        if word_embedding_path is None:
            logger.info("Failed to find word embedding")
            return
        add_node, word_embedding_gather = word_embedding_path
        input_ids = word_embedding_gather.input[1]

        position_embedding_expand = None
        position_embedding_shape = None

        position_embedding_path = self.model.match_parent_path(normalize_node, ['Reshape', 'Slice'], [1, 0])
        if position_embedding_path is not None:
            _, position_embedding_weight_node = position_embedding_path
        else:
            position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand', 'Shape'], [1, 1, 1])
            if position_embedding_path is not None:
                position_embedding_weight_node, position_embedding_expand, position_embedding_shape = position_embedding_path
            else:
                position_embedding_path = self.model.match_parent_path(
                    add_node, ['Gather', 'Expand', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [1, 1, 1, 1, 0, 0])
                if position_embedding_path is not None:
                    position_embedding_weight_node, position_embedding_expand, _, _, _, position_embedding_shape = position_embedding_path
                else:
                    # Here we will not try to get an exact match. Instead, we only try to identify the position embedding weights.
                    position_embedding_path = self.model.match_parent_path(add_node, ['Gather', 'Expand'], [1, 1])
                    if position_embedding_path is not None:
                        position_embedding_weight_node, position_embedding_expand = position_embedding_path
                    else:
                        logger.info("Failed to find position embedding")
                        return

            if position_embedding_shape is not None and position_embedding_shape.input[0] != input_ids:
                logger.info("position and word embedding is expected to be applied on same input")
                return

        segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])
        if segment_embedding_path is None:
            segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
            if segment_embedding_path is None:
                logger.info("Failed to find segment embedding")
                return
            _, segment_embedding_gather = segment_embedding_path
        else:
            segment_embedding_gather = segment_embedding_path[0]

        segment_ids = segment_embedding_gather.input[1]

        if position_embedding_expand and position_embedding_shape:
            input_parent = self.model.get_parent(position_embedding_shape, 0, output_name_to_node)
            subgraph_nodes = self.model.get_parent_subgraph_nodes(position_embedding_expand,
                                                                  [input_parent] if input_parent else [],
                                                                  output_name_to_node)
            self.nodes_to_remove.extend(subgraph_nodes)

        self.nodes_to_remove.extend(word_embedding_path)
        self.nodes_to_remove.extend(position_embedding_path)
        self.nodes_to_remove.extend(segment_embedding_path)

        self.nodes_to_remove.extend([normalize_node])

        # store inputs for further processing
        if self.model.find_graph_input(input_ids):
            self.model.bert_inputs = [input_ids, segment_ids] \
                if self.model.find_graph_input(segment_ids) else [input_ids]

        # Cast input_ids and segment_ids to int32.
        input_ids_cast_node = None
        if self.model.find_graph_input(input_ids):
            casted, input_ids = self.utils.cast_graph_input_to_int32(input_ids)
        else:
            input_ids, input_ids_cast_node = self.utils.cast_input_to_int32(input_ids)

        if self.model.find_graph_input(segment_ids):
            casted, segment_ids = self.utils.cast_graph_input_to_int32(segment_ids)
        else:
            segment_ids, segment_ids_cast_node = self.utils.cast_input_to_int32(segment_ids)

            # Cast might be removed by OnnxRuntime.
            _, segment_id_path, _ = self.model.match_parent_paths(
                segment_ids_cast_node,
                [(['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape', 'Cast'], [0, 0, 1, 0, 0, 0]),
                 (['ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [0, 0, 1, 0, 0])], output_name_to_node)

            if segment_id_path and input_ids_cast_node and input_ids_cast_node.input[0] == segment_id_path[-1].input[0]:
                logger.debug("Simplify semgent id path...")
                self.model.add_node(
                    helper.make_node('Shape', inputs=[input_ids_cast_node.input[0]], outputs=["input_shape"]))
                self.model.add_node(
                    helper.make_node('ConstantOfShape',
                                     inputs=["input_shape"],
                                     outputs=["zeros_for_input_shape"],
                                     value=helper.make_tensor("value", onnx.TensorProto.INT32, [1], [1])))
                segment_ids = "zeros_for_input_shape"

        embed_node = helper.make_node(
            'EmbedLayerNormalization',
            inputs=[
                input_ids,
                segment_ids,
                word_embedding_gather.input[0],
                position_embedding_weight_node.input[0],
                segment_embedding_gather.input[0],
                normalize_node.input[2],
                normalize_node.input[3]  # gamma and beta
            ],
            outputs=["embed_output", "dummy_mask_index"],
            name="EmbedLayer")

        embed_node.domain = "com.microsoft"

        # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
        for att in normalize_node.attribute:
            if att.name == 'epsilon':
                embed_node.attribute.extend([att])
        # Set default value to 1e-12 if no attribute is found.
        if len(embed_node.attribute) == 0:
            embed_node.attribute.extend([onnx.helper.make_attribute("epsilon", 1.0E-12)])

        self.model.replace_input_of_all_nodes(normalize_node.output[0], 'embed_output')
        self.nodes_to_add.append(embed_node)
Example #24
class FusionGptAttention(Fusion):
    """
    Fuse GPT-2 Attention with past state subgraph into one Attention node.
    This does not support attention_mask graph input right now.
    """
    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", "LayerNormalization", "with past")
        # TODO: detect num_heads from graph like FusionAttention
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {
        }  # map from name of attention mask to the name that casted to int32

    def create_attention_node(self, fc_weight, fc_bias, gemm_qkv, past,
                              present, input, output, mask, is_unidirectional):
        attention_node_name = self.model.create_node_name('GptAttention')
        attention_node = helper.make_node(
            'Attention',
            inputs=[input, fc_weight, fc_bias, mask, past],
            outputs=[attention_node_name + "_output", present],
            name=attention_node_name)
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([
            helper.make_attribute("num_heads", self.num_heads),
            helper.make_attribute("unidirectional",
                                  1 if is_unidirectional else 0)
        ])

        matmul_node = helper.make_node(
            'MatMul',
            inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
            outputs=[attention_node_name + "_matmul_output"],
            name=attention_node_name + "_matmul")

        add_node = helper.make_node(
            'Add',
            inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
            outputs=[output],
            name=attention_node_name + "_add")
        self.nodes_to_add.extend([attention_node, matmul_node, add_node])
        self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
        self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
        self.node_name_to_graph_name[add_node.name] = self.this_graph_name

    def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node):
        # Pattern 1:
        #                      {past}
        #                    /        \
        #                   /          \
        #    Gather(axes=0, indices=0)  Gather(indices=1)
        #      |                          |
        #    Transpose (perm=0,1,3,2)     |
        #      |                          |
        #  Concat_k                     Concat_v
        #      |                        /
        #  Transpose (perm=0,1,3,2)    /
        #      |                      /
        #  Unsqueeze        Unsqueeze
        #        \        /
        #         \      /
        #           Concat
        #             |
        #         {present}
        gather = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather is None or gather.op_type != 'Gather':
            logger.debug("match_past_pattern_1: expect Gather for past")
            return None

        if self.model.find_constant_input(gather, 1) != 1:
            logger.debug(
                "match_past_pattern_1: expect indices=1 for Gather of past")
            return None
        past = gather.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k,
                                                    ['Transpose', 'Gather'],
                                                    [0, 0])
        if past_k_nodes is None:
            logger.debug(
                "match_past_pattern_1: failed match Transpose and Gather")
            return None

        gather_past_k = past_k_nodes[-1]
        if self.model.find_constant_input(gather_past_k, 0) != 1:
            logger.debug(
                "match_past_pattern_1: expect indices=0 for Gather k of past")
            return None
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.debug("match_past_pattern_1: expect past to be same")
            return None

        return past
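
    # Background for the Gather indices checked above (assumed GPT-2 export
    # layout, not stated here): past is one stacked tensor whose first axis has
    # size 2; index 0 holds K (stored transposed in older exports, hence the
    # surrounding Transpose nodes) and index 1 holds V, i.e. roughly
    #   past_k, past_v = past[0], past[1]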

    def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node):
        # Pattern 2:
        #      Split (QKV)
        #      / |   |
        #     /  |   +----------------------+
        #        |                          |
        #        |         {past}           |
        #        |           |              |
        #      Reshape     Split         Reshape
        #        |         /    \           |
        # Transpose_k  Squeeze  Squeeze  Transpose_v
        #        |      |        \        /
        #        +------|---+     \      /
        #               |   |      \    /
        #              Concat_k   Concat_v
        #               |            |
        #          Unsqueeze    Unsqueeze
        #                \       /
        #                 Concat
        #                   |
        #               {present}
        #
        squeeze = self.model.get_parent(concat_v, 0, output_name_to_node)
        if squeeze is None or squeeze.op_type != 'Squeeze':
            logger.debug(
                "match_past_pattern_2: expect Squeeze as parent of concat_v")
            return None

        split = self.model.get_parent(squeeze, 0, output_name_to_node)
        if split is None or split.op_type != "Split":
            logger.debug("match_past_pattern_2: expect Split for past path")
            return None

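        # Opset 13 moved Squeeze "axes" and Split "split" from node attributes
        # to inputs, so the checks below depend on the opset version.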
        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(squeeze, 'axes', [0]):
                logger.debug(
                    "match_past_pattern_2: axes != [0] for Squeeze in past path"
                )
                return None

            if not FusionUtils.check_node_attribute(split, 'split', [1, 1]):
                logger.debug(
                    "match_past_pattern_2: split != [1, 1] for Split in past path"
                )
                return None
        else:
            if not self.utils.check_node_input_value(squeeze, 1, [0]):
                logger.debug(
                    "match_past_pattern_2: axes != [0] for Squeeze in past path"
                )
                return None

            if not self.utils.check_node_input_value(split, 1, [1, 1]):
                logger.debug(
                    "match_past_pattern_2: split != [1, 1] for Split in past path"
                )
                return None

        if not FusionUtils.check_node_attribute(
                split, 'axis', 0, default_value=0):
            logger.debug(
                "match_past_pattern_2: attribute axis of Split are not expected in past path"
            )
            return None
        past = split.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k,
                                                    ['Squeeze', 'Split'],
                                                    [0, 0])
        if past_k_nodes is None:
            logger.debug(
                "match_past_pattern_2: failed to match past_k_nodes path")
            return None
        past_k = past_k_nodes[-1].input[0]

        if past != past_k:
            logger.info("match_past_pattern_2: expect past to be same")
            return None

        return past

    def match_present(self, concat_v, input_name_to_nodes):
        unsqueeze_present_v = self.model.find_first_child_by_type(
            concat_v, 'Unsqueeze', input_name_to_nodes, recursive=False)
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return None
        concat_present = self.model.find_first_child_by_type(
            unsqueeze_present_v,
            'Concat',
            input_name_to_nodes,
            recursive=False)
        if not concat_present:
            logger.info("expect concat for present")
            return None

        present = concat_present.output[0]
        return present

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        past = None
        present = None
        return_indice = []
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ['Add', 'Reshape', 'Gemm', 'Reshape', 'Reshape', 'Transpose', 'MatMul'],
            [0,      None,      0,     0,          0,         0,           0],
            output_name_to_node=output_name_to_node,
            return_indice=return_indice
            ) # yapf: disable
        if qkv_nodes is None:
            return
        (add_qkv, reshape_qkv, gemm_qkv, reshape_1, reshape_2, transpose_qkv,
         matmul_qkv) = qkv_nodes

        another_input = add_qkv.input[1 - return_indice[0]]

        v_nodes = self.model.match_parent_path(
            matmul_qkv, ['Concat', 'Transpose', 'Reshape', 'Split'],
            [1, 1, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (concat_v, transpose_v, reshape_v, split_fc) = v_nodes

        fc_nodes = self.model.match_parent_path(
            split_fc, ['Reshape', 'Gemm', 'Reshape', 'LayerNormalization'],
            [0, 0, 0, 0], output_name_to_node)
        if fc_nodes is None:
            fc_nodes = self.model.match_parent_path(
                split_fc, ['Add', 'MatMul', 'LayerNormalization'],
                [0, None, 0], output_name_to_node)
            if fc_nodes is None:
                logger.debug("fuse_attention: failed to match fc path")
                return
            fc_weight = fc_nodes[1].input[1]
            i, _ = self.model.get_constant_input(fc_nodes[0])
            fc_bias = fc_nodes[0].input[i]
        else:
            fc_weight = fc_nodes[1].input[1]
            fc_bias = fc_nodes[1].input[2]

        layernorm_before_attention = fc_nodes[-1]

        if another_input not in layernorm_before_attention.input:
            logger.debug(
                "Add and LayerNormalization are expected to share an input")
            return

        is_unidirectional = True
        slice_mask = None
        input_mask_nodes = None
        concat_k_to_match = None
        qk_nodes = self.model.match_parent_path(
            matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'],
            [0, 0, 0, 0, 0])
        if qk_nodes is not None:
            (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
            mask_nodes = self.model.match_parent_path(
                sub_qk,
                ['Mul', 'Sub', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
                [1,      0,     1,       0,       1,           0,     0,         0,       0,       0])  # yapf: disable
            if mask_nodes is None:
                logger.debug(
                    "fuse_attention: failed to match unidirectional mask path")
                return
            div_mask = mask_nodes[-1]
            slice_mask = mask_nodes[3]

            if div_qk != div_mask:
                logger.debug("fuse_attention: skip since div_qk != div_mask")
                return
        else:
            # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
            i, qk_nodes, _ = self.model.match_parent_paths(
                matmul_qkv,
                [(['Softmax', 'Where', 'Div', 'MatMul'], [0, 0, 1, 0]),
                 (['Softmax', 'Add', 'Where', 'Div', 'MatMul'], [0, 0, None, 1, 0])],
                output_name_to_node)
            if qk_nodes is None:
                logger.debug("fuse_attention: failed to match qk nodes")
                return

            where_qk = qk_nodes[-3]
            div_qk = qk_nodes[-2]
            matmul_qk = qk_nodes[-1]

            if i == 1:
                add_qk = qk_nodes[1]
                _, input_mask_nodes, _ = self.model.match_parent_paths(
                    add_qk,
                    [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0, 0]),
                     (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0])],
                    output_name_to_node)
                if input_mask_nodes is None:
                    logger.debug(
                        "fuse_attention: failed to match input attention mask path"
                    )
                    return

            mask_nodes = self.model.match_parent_path(
                where_qk,
                ['Cast', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape'],
                [ 0,     0,       0,       1,           0,     0,         0,       0],
                output_name_to_node)  # yapf: disable
            if mask_nodes is None:
                # TODO: match mask path for GPT2LMHeadModel_BeamSearchStep.
                logger.debug("fuse_attention: failed to match mask path")
                return

            slice_mask = mask_nodes[2]

            div_or_concat = self.model.get_parent(mask_nodes[-1], 0,
                                                  output_name_to_node)
            if div_or_concat is None:
                logger.debug("fuse_attention: failed to match mask path")
                return
            if div_or_concat.op_type == "Div":
                div_mask = div_or_concat
                if div_qk != div_mask:
                    logger.debug(
                        "fuse_attention: skip since div_qk != div_mask")
                    return
            elif div_or_concat.op_type == "Concat":
                concat_k_to_match = div_or_concat
            else:
                logger.debug("fuse_attention: failed to match mask path")
                return

        # Validate that the mask data is either lower triangular (unidirectional) or all ones
        mask_data = numpy_helper.to_array(
            self.model.get_initializer(slice_mask.input[0]))
        if not (len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1)
                and mask_data.shape[2] == mask_data.shape[3]):
            logger.debug(
                "fuse_attention: skip since mask shape is not 1x1xWxW")
            return
        if np.allclose(mask_data, np.ones_like(mask_data)):
            is_unidirectional = False
        elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
            logger.debug(
                "fuse_attention: skip since mask is neither lower triangular nor ones"
            )
            return

        q_nodes = self.model.match_parent_path(
            matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (transpose_q, reshape_q, split_q) = q_nodes
        if split_fc != split_q:
            logger.debug("fuse_attention: skip since split_fc != split_q")
            return

        k_nodes = self.model.match_parent_path(
            matmul_qk, ['Concat', 'Transpose', 'Reshape', 'Split'],
            [1, 1, 0, 0])
        if k_nodes is None:
            # This pattern is from pytorch 1.7.1 and transformers 4.6.1
            k_nodes = self.model.match_parent_path(
                matmul_qk,
                ['Transpose', 'Concat', 'Transpose', 'Reshape', 'Split'],
                [1, 0, 1, 0, 0])
            if k_nodes is None:
                logger.debug("fuse_attention: failed to match k path")
                return
            else:
                (_, concat_k, transpose_k, reshape_k, split_k) = k_nodes
        else:
            (concat_k, transpose_k, reshape_k, split_k) = k_nodes
        if split_fc != split_k:
            logger.debug("fuse_attention: skip since split_fc != split_k")
            return

        if concat_k_to_match and concat_k != concat_k_to_match:
            logger.debug(
                "fuse_attention: skip since concat_k != concat_k_to_match")
            return

        attention_mask_input_name = ''
        if input_mask_nodes is not None:
            input_name = input_mask_nodes[-1].input[0]
            if input_name in self.casted_attention_mask:
                attention_mask_input_name = self.casted_attention_mask[input_name]
            elif self.model.find_graph_input(input_name):
                casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name
            else:
                attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name

        # Match past and present paths
        past = self.match_past_pattern_1(concat_k, concat_v, output_name_to_node) or \
               self.match_past_pattern_2(concat_k, concat_v, output_name_to_node)
        if past is None:
            logger.info("fuse_attention: failed to match past path")
            return
        if not self.model.find_graph_input(past):
            logger.debug("past is not graph input.")
            # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select beam index so it is not graph input.

        present = self.match_present(concat_v, input_name_to_nodes)
        if present is None:
            logger.info("fuse_attention: failed to match present path")
            return
        if not self.model.find_graph_output(present):
            logger.info("expect present to be graph output")
            return

        self.create_attention_node(fc_weight, fc_bias, gemm_qkv, past, present,
                                   layernorm_before_attention.output[0],
                                   reshape_qkv.output[0],
                                   attention_mask_input_name,
                                   is_unidirectional)

        # we rely on prune_graph() to clean old subgraph nodes:
        # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
        self.prune_graph = True
Example #25
class FusionGptAttention(Fusion):
    """
    Fuse GPT-2 Attention with past state subgraph into one Attention node.
    This does not support attention_mask graph input right now.
    """
    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", "LayerNormalization", "with past")
        # TODO: detect num_heads from graph like FusionAttention
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {}  # map from name of attention mask to the name casted to int32

    def create_attention_node(self, gemm, gemm_qkv, past, present, input,
                              output, mask, is_unidirectional):
        attention_node_name = self.model.create_node_name('GptAttention')
        attention_node = helper.make_node(
            'Attention',
            inputs=[input, gemm.input[1], gemm.input[2], mask, past],
            outputs=[attention_node_name + "_output", present],
            name=attention_node_name)
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([
            helper.make_attribute("num_heads", self.num_heads),
            helper.make_attribute("unidirectional",
                                  1 if is_unidirectional else 0)
        ])

        matmul_node = helper.make_node(
            'MatMul',
            inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
            outputs=[attention_node_name + "_matmul_output"],
            name=attention_node_name + "_matmul")

        add_node = helper.make_node(
            'Add',
            inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
            outputs=[output],
            name=attention_node_name + "_add")
        self.nodes_to_add.extend([attention_node, matmul_node, add_node])

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        past = None
        present = None
        return_indice = []
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ['Add', 'Reshape', 'Gemm', 'Reshape', 'Reshape', 'Transpose', 'MatMul'],
            [0,      None,      0,     0,          0,         0,           0],
            output_name_to_node=output_name_to_node,
            return_indice=return_indice
            ) # yapf: disable
        if qkv_nodes is None:
            return
        (add_qkv, reshape_qkv, gemm_qkv, reshape_1, reshape_2, transpose_qkv,
         matmul_qkv) = qkv_nodes

        another_input = add_qkv.input[1 - return_indice[0]]

        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ['Concat', 'Transpose', 'Reshape', 'Split', 'Reshape', 'Gemm', 'Reshape'],
            [1,        1,            0,         0,       0,         0,      0]) # yapf: disable
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (concat_v, transpose_v, reshape_v, split_v, reshape_after_gemm, gemm,
         reshape_before_gemm) = v_nodes

        #      concat <-- Gather(indices=1) <-- past
        #        |
        #      unsqueeze
        #        |
        #    concat  -->  present
        gather_v = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather_v is None or gather_v.op_type != 'Gather':
            logger.info("expect Gather for past")
            return
        if self.model.find_constant_input(gather_v, 1) != 1:
            logger.info("expect indices=1 for Gather of past")
            return
        past = gather_v.input[0]
        if not self.model.find_graph_input(past):
            logger.info("expect past to be graph input")
            return
        unsqueeze_present_v = self.model.find_first_child_by_type(
            concat_v, 'Unsqueeze', input_name_to_nodes, recursive=False)
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return
        concat_present = self.model.find_first_child_by_type(
            unsqueeze_present_v,
            'Concat',
            input_name_to_nodes,
            recursive=False)
        if not concat_present:
            logger.info("expect concat for present")
            return
        present = concat_present.output[0]
        if not self.model.find_graph_output(present):
            logger.info("expect present to be graph input")
            return

        layernorm_before_attention = self.model.get_parent(
            reshape_before_gemm, 0, output_name_to_node)
        if layernorm_before_attention is None or layernorm_before_attention.op_type != 'LayerNormalization':
            logger.debug(
                "failed to get LayerNormalization before Gemm. Got %s",
                None if layernorm_before_attention is None else layernorm_before_attention.op_type)
            return

        if another_input not in layernorm_before_attention.input:
            logger.debug(
                "Add and LayerNormalization are expected to share an input")
            return

        is_unidirectional = True
        slice_mask = None
        input_mask_nodes = None
        qk_nodes = self.model.match_parent_path(
            matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'],
            [0, 0, 0, 0, 0])
        if qk_nodes is not None:
            (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
            mask_nodes = self.model.match_parent_path(
                sub_qk,
                ['Mul', 'Sub', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
                [1,      0,     1,       0,       1,           0,     0,         0,       0,       0])  # yapf: disable
            if mask_nodes is None:
                logger.debug(
                    "fuse_attention: failed to match unidirectional mask path")
                return
            div_mask = mask_nodes[-1]
            slice_mask = mask_nodes[3]

            if div_qk != div_mask:
                logger.debug("fuse_attention: skip since div_qk != div_mask")
                return
        else:
            # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
            i, qk_nodes, _ = self.model.match_parent_paths(
                matmul_qkv,
                [(['Softmax', 'Where', 'Div', 'MatMul'], [0, 0, 1, 0]),
                 (['Softmax', 'Add', 'Where', 'Div', 'MatMul'], [0, 0, 0, 1, 0])],
                output_name_to_node)
            if qk_nodes is None:
                logger.debug("fuse_attention: failed to match qk nodes")
                return

            where_qk = qk_nodes[-3]
            div_qk = qk_nodes[-2]
            matmul_qk = qk_nodes[-1]

            if i == 1:
                add_qk = qk_nodes[1]
                _, input_mask_nodes, _ = self.model.match_parent_paths(
                    add_qk,
                    [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [1, 0, 1, 0, 0, 0]),
                     (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [1, 0, 1, 0, 0])],
                    output_name_to_node)
                if input_mask_nodes is None:
                    logger.debug(
                        "fuse_attention: failed to match input attention mask path"
                    )
                    return

            mask_nodes = self.model.match_parent_path(
                where_qk,
                ['Cast', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'],
                [ 0,     0,       0,       1,           0,     0,         0,       0,       0])  # yapf: disable
            if mask_nodes is None:
                logger.debug("fuse_attention: failed to match mask path")
                return
            div_mask = mask_nodes[-1]
            slice_mask = mask_nodes[2]

            if div_qk != div_mask:
                logger.debug("fuse_attention: skip since div_qk != div_mask")
                return

        # Validate that the mask data is either lower triangular (unidirectional) or all ones
        mask_data = numpy_helper.to_array(
            self.model.get_initializer(slice_mask.input[0]))
        if not (len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1)
                and mask_data.shape[2] == mask_data.shape[3]):
            logger.debug(
                "fuse_attention: skip since mask shape is not 1x1xWxW")
            return
        if np.allclose(mask_data, np.ones_like(mask_data)):
            is_unidirectional = False
        elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
            logger.debug(
                "fuse_attention: skip since mask is neither lower triangular nor ones"
            )
            return

        q_nodes = self.model.match_parent_path(
            matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (transpose_q, reshape_q, split_q) = q_nodes
        if split_v != split_q:
            logger.debug("fuse_attention: skip since split_v != split_q")
            return

        k_nodes = self.model.match_parent_path(
            matmul_qk, ['Concat', 'Transpose', 'Reshape', 'Split'],
            [1, 1, 0, 0])
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return
        (concat_k, transpose_k, reshape_k, split_k) = k_nodes
        if split_v != split_k:
            logger.debug("fuse_attention: skip since split_v != split_k")
            return

        #     concat_k <-- Transpose (perm=0,1,3,2) <-- Gather(axes=0, indices=0) <-- past
        #        |
        #     Transpose (perm=0,1,3,2)
        #        |
        #      unsqueeze
        #        |
        #    concat  -->  present
        past_k_nodes = self.model.match_parent_path(concat_k,
                                                    ['Transpose', 'Gather'],
                                                    [0, 0])
        if past_k_nodes is None:
            logger.debug("fuse_attention: failed to match past_k_nodes path")
            return

        gather_past_k = past_k_nodes[-1]
        if self.model.find_constant_input(gather_past_k, 0) != 1:
            logger.info("expect indices=0 for Gather k of past")
            return
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.info("expect past to be same")
            return

        attention_mask_input_name = ''
        if input_mask_nodes is not None:
            input_name = input_mask_nodes[-1].input[0]
            if input_name in self.casted_attention_mask:
                attention_mask_input_name = self.casted_attention_mask[input_name]
            elif self.model.find_graph_input(input_name):
                casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name
            else:
                attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name)
                self.casted_attention_mask[input_name] = attention_mask_input_name

        self.create_attention_node(gemm, gemm_qkv, past, present,
                                   layernorm_before_attention.output[0],
                                   reshape_qkv.output[0],
                                   attention_mask_input_name,
                                   is_unidirectional)

        # we rely on prune_graph() to clean old subgraph nodes:
        # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
        self.prune_graph = True
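
As a usage sketch (the module names, model path, and head count below are illustrative assumptions, not taken from this example), a fusion class like this is typically driven through the Fusion base class:

import onnx
from onnx_model import OnnxModel                     # assumed helper module
from fusion_gpt_attention import FusionGptAttention  # assumed module name

model = OnnxModel(onnx.load("gpt2_with_past.onnx"))  # illustrative path
fusion = FusionGptAttention(model, num_heads=12)     # GPT-2 base has 12 heads
fusion.apply()        # walks the graph and calls fuse() on candidate nodes
model.prune_graph()   # drops the subgraph nodes replaced by the fused Attention
onnx.save(model.model, "gpt2_with_past_fused.onnx")
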
Example #26
 def __init__(self, model: OnnxModel):
     super().__init__(model, "Shape", "Concat")
     self.utils = FusionUtils(model)
     self.shape_infer = None
     self.shape_infer_done = False