Example #1
def create_dummy_inputs(onnx_model_path, batch_size, sequence_length, samples):
    import numpy
    import onnx
    from onnx import TensorProto
    from onnx_model import OnnxModel

    onnx_model = OnnxModel(onnx.load(onnx_model_path))
    dummy_inputs = {}
    for input in onnx_model.get_graph_inputs_excluding_initializers():
        shape = get_shape_from_type_proto(input.type)
        symbol_dims = []
        for i, dim in enumerate(shape):
            if type(dim) == str:
                symbol_dims.append(i)

        # allowed symbolic dimensions: batch_size and sequence_length
        if len(symbol_dims) > 2:
            return None
        if len(symbol_dims) > 0:
            shape[symbol_dims[0]] = batch_size
        if len(symbol_dims) > 1:
            shape[symbol_dims[1]] = sequence_length

        elem_type = input.type.tensor_type.elem_type
        assert elem_type in [
            TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64
        ]
        data_type = numpy.float32 if elem_type == TensorProto.FLOAT else (
            numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32)
        data = numpy.ones(shape, dtype=data_type)
        dummy_inputs[input.name] = data

    all_inputs = [dummy_inputs for _ in range(samples)]
    return all_inputs
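A minimal usage sketch for the helper above; the file name model.onnx and the shapes are illustrative assumptions, and onnx_model.py from onnxruntime's transformers tools must be importable:

import onnxruntime

all_inputs = create_dummy_inputs("model.onnx", batch_size=1, sequence_length=128, samples=8)
if all_inputs is not None:  # None means the model has more symbolic dimensions than batch/sequence
    session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    outputs = session.run(None, all_inputs[0])  # run one sample feed dict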
Example #2
def create_gpt2_inputs(onnx_model_path, batch_size, sequence_length,
                       past_sequence_length, samples):
    import numpy
    import onnx
    from onnx import TensorProto
    from onnx_model import OnnxModel

    onnx_model = OnnxModel(onnx.load(onnx_model_path))
    # The symbolic names must be the same as those used in the Gpt2Helper.export_onnx(...) function.
    symbols = {
        'batch_size': batch_size,
        'seq_len': sequence_length,
        'past_seq_len': past_sequence_length,
        'total_seq_len': sequence_length + past_sequence_length
    }

    dummy_inputs = {}
    for input in onnx_model.get_graph_inputs_excluding_initializers():
        shape = get_shape_from_type_proto(input.type)
        for i, dim in enumerate(shape):
            if type(dim) == str:
                if dim not in symbols:
                    raise RuntimeError(f"symbol is not supported: {dim}")
                shape[i] = symbols[dim]

        elem_type = input.type.tensor_type.elem_type
        assert elem_type in [
            TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64
        ]
        data_type = numpy.float32 if elem_type == TensorProto.FLOAT else (
            numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32)
        data = numpy.ones(shape, dtype=data_type)
        dummy_inputs[input.name] = data

    all_inputs = [dummy_inputs for _ in range(samples)]
    return all_inputs
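For reference, a small sketch of how the symbols map resolves a symbolic shape; the dimension values below are illustrative, not taken from a real GPT-2 graph:

symbols = {'batch_size': 1, 'seq_len': 32, 'past_seq_len': 16, 'total_seq_len': 48}
# A past state input declared as (2, 'batch_size', num_heads, 'past_seq_len', head_size):
shape = [2, 'batch_size', 12, 'past_seq_len', 64]
resolved = [symbols[dim] if isinstance(dim, str) else dim for dim in shape]
assert resolved == [2, 1, 12, 16, 64]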
Example #3
def create_longformer_inputs(onnx_model_path, batch_size, sequence_length,
                             global_length, samples):
    import numpy
    import onnx
    from onnx import TensorProto
    from onnx_model import OnnxModel

    onnx_model = OnnxModel(onnx.load(onnx_model_path))
    symbols = {'batch_size': batch_size, 'sequence_length': sequence_length}

    dummy_inputs = {}
    for input in onnx_model.get_graph_inputs_excluding_initializers():
        shape = get_shape_from_type_proto(input.type)
        for i, dim in enumerate(shape):
            if type(dim) == str:
                if dim not in symbols:
                    raise RuntimeError(f"symbol is not supported: {dim}")
                shape[i] = symbols[dim]

        elem_type = input.type.tensor_type.elem_type
        assert elem_type in [
            TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64
        ]
        data_type = numpy.float32 if elem_type == TensorProto.FLOAT else (
            numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32)

        if "global" in input.name:
            data = numpy.zeros(shape, dtype=data_type)
            data[:, :global_length] = 1
        else:
            data = numpy.ones(shape, dtype=data_type)
        dummy_inputs[input.name] = data

    all_inputs = [dummy_inputs for _ in range(samples)]
    return all_inputs
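The zeros/ones branch above marks the first global_length tokens as global attention positions. A tiny illustration with made-up shapes:

import numpy

mask = numpy.zeros((1, 8), dtype=numpy.int32)  # (batch_size, sequence_length)
mask[:, :2] = 1  # global_length = 2
# mask is now [[1 1 0 0 0 0 0 0]]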
Example #4
    def verify_fusion(self, optimized_model, expected_model_filename):
        optimized_model.topological_sort()

        expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename)
        expected_model = OnnxModel(onnx.load(expected_model_path))
        expected_model.topological_sort()

        self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph))
Example #5
def get_last_matmul_node_name(raw_onnx_model: str):
    model = onnx.load(raw_onnx_model)
    onnx_model = OnnxModel(model)
    output_name_to_node = onnx_model.output_name_to_node()

    assert model.graph.output[0].name in output_name_to_node
    node = output_name_to_node[model.graph.output[0].name]
    if node.op_type == "MatMul":
        logger.info(f"Found last MatMul node for logits: {node.name}")
        return node.name

    logger.warning(
        f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}"
    )
    return None
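A plausible use of this helper (the file name and the exclusion list are assumptions, not shown in the original source) is to keep the logits MatMul out of quantization:

last_matmul = get_last_matmul_node_name("gpt2_optimized.onnx")
nodes_to_exclude = [last_matmul] if last_matmul is not None else []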
Example #6
def get_bert_inputs(
    onnx_file: str,
    input_ids_name: Optional[str] = None,
    segment_ids_name: Optional[str] = None,
    input_mask_name: Optional[str] = None,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """Find graph inputs for BERT model.
    First, we will deduce inputs from the EmbedLayerNormalization node.
    If not found, we will guess the meaning of graph inputs based on naming.

    Args:
        onnx_file (str): onnx model path
        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.

    Returns:
        Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
                                                                                 segment_ids and input_mask
    """
    model = ModelProto()
    with open(onnx_file, "rb") as file:
        model.ParseFromString(file.read())

    onnx_model = OnnxModel(model)
    return find_bert_inputs(onnx_model, input_ids_name, segment_ids_name,
                            input_mask_name)
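A minimal call sketch (the file name bert.onnx is an illustrative assumption); with no names given, the helper deduces the three inputs itself:

input_ids, segment_ids, input_mask = get_bert_inputs("bert.onnx")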
Example #7
    def test_pytorch_model_0_gpu_onnxruntime(self):
        if 'CUDAExecutionProvider' not in onnxruntime.get_available_providers():
            print(
                "skip test_pytorch_model_0_gpu_onnxruntime since no gpu found")
            return

        input = _get_test_model_path('bert_pytorch_0')
        output = 'temp.onnx'
        optimize_by_onnxruntime(input,
                                use_gpu=True,
                                optimized_model_path=output)
        model = ModelProto()
        with open(output, "rb") as f:
            model.ParseFromString(f.read())
        os.remove(output)
        bert_model = OnnxModel(model)
        expected_node_count = {
            'EmbedLayerNormalization': 1,
            'Attention': 12,
            'SkipLayerNormalization': 24,
            'Gelu': 0,
            'FastGelu': 12,
            'BiasGelu': 0
        }
        self.verify_node_count(bert_model, expected_node_count,
                               'test_pytorch_model_0_gpu_onnxruntime')
Example #8
def main():
    args = parse_arguments()
    setup_logging(args.verbose)

    output_names = None if args.output_names is None else args.output_names.split(
        ";")

    model = ModelProto()
    with open(args.input, "rb") as input_file:
        model.ParseFromString(input_file.read())
    onnx_model = OnnxModel(model)

    optimizer = BertOnnxModelShapeOptimizer(onnx_model)

    optimizer.optimize(
        args.output,
        args.input_ids,
        args.segment_ids,
        args.input_mask,
        args.enable_shape_opt,
        args.enable_reshape_opt,
        output_names,
        args.batch_size,
        args.sequence_length,
        args.verbose,
    )
Example #9
def run(args):
    num_threads = args.thread_num if args.thread_num > 0 else psutil.cpu_count(
        logical=False)

    # Set the OMP environment variable before importing onnxruntime. Needed for CPU only; it has no impact on the onnxruntime-gpu package.
    if "OMP_NUM_THREADS" not in os.environ:
        os.environ["OMP_NUM_THREADS"] = str(num_threads)

    from onnx import load
    from onnx_model import OnnxModel

    onnx_model = OnnxModel(load(args.model))

    all_inputs = None
    if args.dummy_inputs == "bert":
        all_inputs = create_bert_inputs(
            onnx_model,
            args.batch_size,
            args.sequence_length,
            args.samples,
            args.input_ids_name,
            args.segment_ids_name,
            args.input_mask_name,
        )
    elif args.dummy_inputs == "gpt2":
        all_inputs = create_gpt2_inputs(
            onnx_model,
            args.batch_size,
            args.sequence_length,
            args.past_sequence_length,
            args.samples,
        )
    elif args.dummy_inputs == "longformer":
        all_inputs = create_longformer_inputs(
            onnx_model,
            args.batch_size,
            args.sequence_length,
            args.global_length,
            args.samples,
        )
    else:  # default
        all_inputs = create_dummy_inputs(onnx_model, args.batch_size,
                                         args.sequence_length, args.samples)

    profile_file = run_profile(
        args.model,
        args.use_gpu,
        args.provider,
        args.basic_optimization,
        args.thread_num,
        all_inputs,
    )

    return profile_file
Example #10
    def test_gpt2_past_fp16(self):
        input_model_path = _get_test_model_path('gpt2_past')
        model = OnnxModel(load_model(input_model_path, format=None, load_external_data=True))
        model.convert_model_float32_to_float16(cast_input_output=False, use_symbolic_shape_infer=False)
        for input in model.graph().input[1:]:
            self.assertEqual(input.type.tensor_type.elem_type, TensorProto.FLOAT16)
        for output in model.graph().output:
            self.assertEqual(output.type.tensor_type.elem_type, TensorProto.FLOAT16)
Example #11
def run(args):
    num_threads = args.thread_num if args.thread_num > 0 else psutil.cpu_count(
        logical=False)

    # Set the OMP environment variable before importing onnxruntime. Needed for CPU only; it has no impact on the onnxruntime-gpu package.
    if "OMP_NUM_THREADS" not in os.environ:
        os.environ["OMP_NUM_THREADS"] = str(num_threads)

    from onnx import load
    from onnx_model import OnnxModel
    onnx_model = OnnxModel(load(args.model))

    all_inputs = None
    if args.dummy_inputs == 'bert':
        all_inputs = create_bert_inputs(onnx_model, args.batch_size,
                                        args.sequence_length, args.samples,
                                        args.input_ids_name,
                                        args.segment_ids_name,
                                        args.input_mask_name)
    elif args.dummy_inputs == 'gpt2':
        all_inputs = create_gpt2_inputs(onnx_model, args.batch_size,
                                        args.sequence_length,
                                        args.past_sequence_length,
                                        args.samples)
    elif args.dummy_inputs == 'longformer':
        all_inputs = create_longformer_inputs(onnx_model, args.batch_size,
                                              args.sequence_length,
                                              args.global_length, args.samples)
    else:  # default
        all_inputs = create_dummy_inputs(onnx_model, args.batch_size,
                                         args.sequence_length, args.samples)

    profile_file = run_profile(args.model, args.use_gpu,
                               args.basic_optimization, args.thread_num,
                               all_inputs)

    profile_records = load_profile_json(profile_file)

    lines = parse_profile_results(profile_records, args.kernel_time_only,
                                  args.threshold)

    lines.append("\nGrouped by operator type:")
    lines.append("-" * 64)
    lines += group_profile_results(profile_records, args.kernel_time_only,
                                   args.use_gpu)

    return lines
Example #12
    def test_pytorch_model_0_cpu_onnxruntime(self):
        input = _get_test_model_path('bert_pytorch_0')
        output = 'temp.onnx'
        optimize_by_onnxruntime(input, use_gpu=False, optimized_model_path=output)
        model = ModelProto()
        with open(output, "rb") as f:
            model.ParseFromString(f.read())
        os.remove(output)
        bert_model = OnnxModel(model)
        expected_node_count = {
            'EmbedLayerNormalization': 1,
            'Attention': 12,
            'SkipLayerNormalization': 24,
            'Gelu': 0,
            'FastGelu': 0,
            'BiasGelu': 12
        }
        self.verify_node_count(bert_model, expected_node_count, 'test_pytorch_model_0_cpu_onnxruntime')
Example #13
def get_bert_inputs(onnx_file,
                    input_ids_name=None,
                    segment_ids_name=None,
                    input_mask_name=None):
    """Find graph inputs for BERT model.
    First, we will deduce them from the EmbedLayerNormalization node. If not found, we will guess based on naming.

    Args:
        onnx_file (str): onnx model path
        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.
    """
    model = ModelProto()
    with open(onnx_file, "rb") as f:
        model.ParseFromString(f.read())

    onnx_model = OnnxModel(model)
    return find_bert_inputs(onnx_model, input_ids_name, segment_ids_name,
                            input_mask_name)
Example #14
    def remove_useless_reshape_nodes(model: OnnxModel):
        """Remove reshape node that is not needed based on symbolic shape inference: input and output has same shape
        """
        shape_infer = model.infer_runtime_shape(update=True)
        if shape_infer is None:
            return

        nodes_to_remove = []
        for node in model.nodes():
            if node.op_type == 'Reshape':
                input_shape = shape_infer.get_edge_shape(node.input[0])
                output_shape = shape_infer.get_edge_shape(node.output[0])
                if input_shape and output_shape and input_shape == output_shape:
                    logger.info(
                        f"Remove reshape node {node.name} since its input shape is the same as its output: {input_shape}"
                    )
                    nodes_to_remove.append(node)

        if nodes_to_remove:
            for node in nodes_to_remove:
                model.replace_input_of_all_nodes(node.output[0], node.input[0])
                model.remove_node(node)
            model.prune_graph()
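For context, a Reshape qualifies for removal here when symbolic shape inference proves its input and output shapes equal. A minimal graph containing such a no-op Reshape, built with onnx.helper (all names are illustrative):

from onnx import TensorProto, helper

shape_init = helper.make_tensor("shape", TensorProto.INT64, [2], [2, 3])
reshape = helper.make_node("Reshape", ["x", "shape"], ["y"], name="noop_reshape")
graph = helper.make_graph(
    [reshape], "noop_example",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3])],
    initializer=[shape_init])
model = helper.make_model(graph)
# remove_useless_reshape_nodes(OnnxModel(model)) would rewire consumers of y to x and
# drop noop_reshape, assuming shape inference resolves both edges to (2, 3).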
Example #15
    def export_onnx(
        encoder: T5Encoder,
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export encoder to ONNX

        Args:
            encoder (T5Encoder): encoder object
            device (torch.device): device of encoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): use int32 inputs or not. Defaults to False.
        """
        config = encoder.config
        encoder_inputs = T5EncoderInputs.create_dummy(
            batch_size=2,
            sequence_length=4,
            vocab_size=config.vocab_size,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True,
                                                    exist_ok=True)
            torch_onnx_export(
                encoder,
                args=tuple(encoder_inputs.to_list()),
                f=temp_onnx_model_path
                if use_external_data_format else onnx_model_path,
                export_params=True,
                input_names=["input_ids", "attention_mask"],
                output_names=["hidden_states"],
                dynamic_axes={
                    "input_ids": {
                        0: "batch_size",
                        1: "sequence_length"
                    },
                    "attention_mask": {
                        0: "batch_size",
                        1: "sequence_length"
                    },
                    "hidden_states": {
                        0: "batch_size",
                        1: "sequence_length"
                    },
                },
                opset_version=12,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                model = onnx.load_model(temp_onnx_model_path,
                                        load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )
Example #16
def optimize_fp16_onnx_no_cast(input_onnx_path, optimized_onnx_path, epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)

    nodes_to_remove = list(onnx_model.nodes())  # copy the node list before mutating the graph
    node_to_add = onnx.helper.make_node(
        "LayerNormalization",
        ["input", "layer_norm.weight", "layer_norm.bias"], ["output"],
        "layer_norm",
        epsilon=epsilon)

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_node(node_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
Example #17
def get_bert_inputs(onnx_file,
                    input_ids_name=None,
                    segment_ids_name=None,
                    input_mask_name=None):
    """
    Get graph inputs for BERT model.
    First, we will deduce them from the EmbedLayerNormalization node. If not found, we will guess based on naming.
    """
    model = ModelProto()
    with open(onnx_file, "rb") as f:
        model.ParseFromString(f.read())

    onnx_model = OnnxModel(model)
    graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()

    if input_ids_name is not None:
        input_ids = onnx_model.find_graph_input(input_ids_name)
        if input_ids is None:
            raise ValueError(
                f"Graph does not have input named {input_ids_name}")

        segment_ids = None
        if segment_ids_name:
            segment_ids = onnx_model.find_graph_input(segment_ids_name)
            if segment_ids is None:
                raise ValueError(
                    f"Graph does not have input named {segment_ids_name}")

        input_mask = None
        if input_mask_name:
            input_mask = onnx_model.find_graph_input(input_mask_name)
            if input_mask is None:
                raise ValueError(
                    f"Graph does not have input named {input_mask_name}")

        expected_inputs = 1 + (1 if segment_ids else 0) + (1 if input_mask else
                                                           0)
        if len(graph_inputs) != expected_inputs:
            raise ValueError(
                f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}"
            )

        return input_ids, segment_ids, input_mask

    if len(graph_inputs) != 3:
        raise ValueError("Expect the graph to have 3 inputs. Got {}".format(
            len(graph_inputs)))

    embed_nodes = onnx_model.get_nodes_by_op_type('EmbedLayerNormalization')
    if len(embed_nodes) == 1:
        embed_node = embed_nodes[0]
        input_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 0)
        segment_ids = get_graph_input_from_embed_node(onnx_model, embed_node,
                                                      1)
        input_mask = get_graph_input_from_embed_node(onnx_model, embed_node, 7)
        return input_ids, segment_ids, input_mask

    # Try to guess the inputs based on naming.
    input_ids = None
    segment_ids = None
    input_mask = None
    for input in graph_inputs:
        input_name_lower = input.name.lower()
        if "mask" in input_name_lower:  # matches input with name like "attention_mask" or "input_mask"
            input_mask = input
        elif "token" in input_name_lower or "segment" in input_name_lower:  # matches input with name like "segment_ids" or "token_type_ids"
            segment_ids = input
        else:
            input_ids = input

    if input_ids and segment_ids and input_mask:
        return input_ids, segment_ids, input_mask

    raise ValueError(
        "Failed to assign 3 inputs. You might try renaming the graph inputs.")
Example #18
    def export_onnx(
        model,
        device,
        onnx_model_path: str,
        verbose: bool = False,
        use_external_data_format: bool = False,
        has_position_ids: bool = True,
        has_attention_mask: bool = True,
        input_ids_dtype: torch.dtype = torch.int32,
        position_ids_dtype: torch.dtype = torch.int32,
        attention_mask_dtype: torch.dtype = torch.int32,
    ):
        """Export GPT-2 model with past state to ONNX model."""
        config: GPT2Config = model.config
        num_layer = config.n_layer
        dummy_inputs = Gpt2Helper.get_dummy_inputs(
            batch_size=1,
            past_sequence_length=1,
            sequence_length=1,
            num_attention_heads=config.num_attention_heads,
            hidden_size=config.hidden_size,
            num_layer=num_layer,
            vocab_size=config.vocab_size,
            device=device,
            float16=False,
            has_position_ids=has_position_ids,
            has_attention_mask=has_attention_mask,
            input_ids_dtype=input_ids_dtype,
            position_ids_dtype=position_ids_dtype,
            attention_mask_dtype=attention_mask_dtype,
        )
        input_list = dummy_inputs.to_list()

        with torch.no_grad():
            outputs = model(*input_list)

        past_names = [f"past_{i}" for i in range(num_layer)]
        present_names = [f"present_{i}" for i in range(num_layer)]

        # GPT2Model outputs last_state; GPT2LMHeadModel outputs logits (prediction_scores)
        assert outputs[0].shape[2] == config.vocab_size or outputs[0].shape[2] == config.hidden_size
        output_names = [
            "logits"
            if outputs[0].shape[2] == config.vocab_size else "last_state"
        ] + present_names

        # Shape of input tensors:
        #    input_ids: (batch_size, seq_len)
        #    past_{i}:  (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads)
        #    attention_mask: (batch_size, past_seq_len + seq_len)
        # Shape of output tensors:
        #    last_state: (batch_size, seq_len, hidden_size)
        #      or logits: (batch_size, seq_len, vocab_size)
        #    present_{i}:  (2, batch_size, num_heads, past_seq_len + seq_len, hidden_size/num_heads)
        dynamic_axes = {
            "input_ids": {
                0: "batch_size",
                1: "seq_len"
            },
            output_names[0]: {
                0: "batch_size",
                1: "seq_len"
            },
        }
        for name in past_names:
            dynamic_axes[name] = {1: "batch_size", 3: "past_seq_len"}
        for name in present_names:
            dynamic_axes[name] = {1: "batch_size", 3: "total_seq_len"}

        input_names = ["input_ids"]
        if has_position_ids:
            dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
            input_names.append("position_ids")
        if has_attention_mask:
            dynamic_axes["attention_mask"] = {
                0: "batch_size",
                1: "total_seq_len"
            }
            input_names.append("attention_mask")
        input_names.extend(past_names)

        assert len(outputs) == 2 and len(outputs[1]) == num_layer

        logger.info(
            f"Shapes: input_ids={dummy_inputs.input_ids.shape} past={dummy_inputs.past[0].shape} output={outputs[0].shape} present={outputs[1][0].shape}"
        )

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        if use_external_data_format:
            # We let PyTorch export onnx to a temp directory first, then convert external data to one file.
            with tempfile.TemporaryDirectory() as tmp_dir_name:
                temp_onnx_model_path = os.path.join(tmp_dir_name, "gpt2.onnx")
                Path(temp_onnx_model_path).parent.mkdir(parents=True,
                                                        exist_ok=True)

                torch_onnx_export(
                    model,
                    args=tuple(input_list),
                    f=temp_onnx_model_path,
                    export_params=True,
                    input_names=input_names,
                    output_names=output_names,
                    dynamic_axes=dynamic_axes,
                    opset_version=11,
                    do_constant_folding=True,
                    use_external_data_format=True,
                    verbose=verbose,
                )

                model = onnx.load_model(temp_onnx_model_path,
                                        load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )
        else:
            torch_onnx_export(
                model,
                args=tuple(input_list),
                f=onnx_model_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=11,
                do_constant_folding=True,
                use_external_data_format=False,
                verbose=verbose,
            )
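As a small self-contained sketch, this is what the past/present names and their dynamic axes from the code above look like for a 2-layer config (the layer count is illustrative):

num_layer = 2
past_names = [f"past_{i}" for i in range(num_layer)]        # ['past_0', 'past_1']
present_names = [f"present_{i}" for i in range(num_layer)]  # ['present_0', 'present_1']
dynamic_axes = {name: {1: "batch_size", 3: "past_seq_len"} for name in past_names}
dynamic_axes.update({name: {1: "batch_size", 3: "total_seq_len"} for name in present_names})
assert dynamic_axes["present_1"] == {1: "batch_size", 3: "total_seq_len"}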
Example #19
def get_longformer_inputs(onnx_file,
                          input_ids_name=None,
                          input_mask_name=None,
                          global_mask_name=None):
    """
    Get graph inputs for Longformer model.
    """
    model = ModelProto()
    with open(onnx_file, "rb") as f:
        model.ParseFromString(f.read())

    onnx_model = OnnxModel(model)
    graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()

    if input_ids_name is not None:
        input_ids = onnx_model.find_graph_input(input_ids_name)
        if input_ids is None:
            raise ValueError(
                f"Graph does not have input named {input_ids_name}")

        input_mask = None
        if input_mask_name:
            input_mask = onnx_model.find_graph_input(input_mask_name)
            if input_mask is None:
                raise ValueError(
                    f"Graph does not have input named {input_mask_name}")

        global_mask = None
        if global_mask_name:
            global_mask = onnx_model.find_graph_input(global_mask_name)
            if global_mask is None:
                raise ValueError(
                    f"Graph does not have input named {global_mask_name}")

        expected_inputs = 1 + (1 if input_mask else 0) + (1 if global_mask else
                                                          0)
        if len(graph_inputs) != expected_inputs:
            raise ValueError(
                f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}"
            )

        return input_ids, input_mask, global_mask

    if len(graph_inputs) != 3:
        raise ValueError("Expect the graph to have 3 inputs. Got {}".format(
            len(graph_inputs)))

    # Try to guess the inputs based on naming.
    input_ids = None
    input_mask = None
    global_mask = None
    for input in graph_inputs:
        input_name_lower = input.name.lower()
        if "global" in input_name_lower:
            global_mask = input
        elif "mask" in input_name_lower:
            input_mask = input
        else:
            input_ids = input

    if input_ids and input_mask and global_mask:
        return input_ids, input_mask, global_mask

    raise ValueError(
        "Failed to assign 3 inputs. You might try renaming the graph inputs.")
Example #20
    def change_graph_input_type(
        self,
        graph: GraphProto,
        graph_input: ValueInfoProto,
        new_type: int = TensorProto.INT32,
    ):
        """Change graph input type, and add Cast node if needed.

        Args:
            graph (GraphProto): graph
            graph_input (TensorProto): input of the graph
            new_type (int, optional): new data type. Defaults to TensorProto.INT32.

        Returns:
            NodeProto: a new Cast node that added. None if Cast node is not added.
            List[NodeProto]: Cast nodes that have been removed.
        """
        assert isinstance(graph, GraphProto)
        assert isinstance(graph_input, ValueInfoProto)
        assert self.find_graph_input(graph_input.name)

        if graph_input.type.tensor_type.elem_type == int(new_type):
            return None, []

        new_cast_node = None
        nodes_to_remove = []

        input_name_to_nodes = self.input_name_to_nodes()
        if graph_input.name in input_name_to_nodes:
            nodes = input_name_to_nodes[graph_input.name]

            # For child nodes that are not Cast, insert a Cast node to convert from int32 back to the original data type.
            nodes_not_cast = [node for node in nodes if node.op_type != "Cast"]
            if nodes_not_cast:
                node_name = self.create_node_name("Cast")
                output_name = node_name + "_" + graph_input.name
                new_value_info = graph.value_info.add()
                new_value_info.CopyFrom(graph_input)
                new_value_info.name = output_name
                new_cast_node = helper.make_node(
                    "Cast",
                    [graph_input.name],
                    [output_name],
                    to=int(graph_input.type.tensor_type.elem_type),
                    name=node_name,
                )
                graph.node.extend([new_cast_node])

                for node in nodes_not_cast:
                    OnnxModel.replace_node_input(node, graph_input.name,
                                                 output_name)

            # For child nodes that are Cast, there is no need to insert another Cast.
            # When a child casts to int32, we can remove that Cast node since the input type is now int32.
            nodes_cast = [node for node in nodes if node.op_type == "Cast"]
            for node in nodes_cast:
                if OnnxModel.get_node_attribute(node, "to") == int(new_type):
                    self.replace_input_of_all_nodes(node.output[0],
                                                    graph_input.name)
                if not self.find_graph_output(node.output[0]):
                    nodes_to_remove.append(node)
            if nodes_to_remove:
                self.remove_nodes(nodes_to_remove)

        graph_input.type.tensor_type.elem_type = int(new_type)
        return new_cast_node, nodes_to_remove
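A usage sketch, assuming this method lives on OnnxModel as in onnxruntime's transformers tools; the file name is an illustrative assumption:

import onnx
from onnx import TensorProto
from onnx_model import OnnxModel

model = OnnxModel(onnx.load("model.onnx"))
for graph_input in model.get_graph_inputs_excluding_initializers():
    if graph_input.type.tensor_type.elem_type == TensorProto.INT64:
        # Returns the inserted Cast node (or None) and any removed Cast nodes.
        cast_node, removed_nodes = model.change_graph_input_type(model.graph(), graph_input, TensorProto.INT32)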
Example #21
    def auto_mixed_precision(onnx_model: OnnxModel,
                             op_block_list: List[str] = [
                                 'Add', 'LayerNormalization', 'FastGelu'
                             ]):
        """Convert GPT-2 model to mixed precision.
           It detects whether original model has fp16 precision weights, and set parameters for float16 conversion automatically.
        Args:
            onnx_model (OnnxModel): optimized ONNX model
            op_block_list (List[str], optional): . Defaults to ['Add', 'LayerNormalization', 'FastGelu']
        Returns:
            parameters(dict): a dictionary of parameters used in float16 conversion
        """
        op_full_set = set([node.op_type for node in onnx_model.nodes()])
        fp32_op_set = set(op_block_list)
        fp16_op_set = op_full_set.difference(fp32_op_set)
        logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")

        # logits is the first output
        logits_output_name = onnx_model.graph().output[0].name

        # We use the weight of the last MatMul node to detect whether the model was stored with float16 weights from training.
        is_weight_fp16_precision = False
        output_name_to_node = onnx_model.output_name_to_node()
        assert logits_output_name in output_name_to_node
        node = output_name_to_node[logits_output_name]
        last_matmul_node = None
        if node.op_type == "MatMul":
            last_matmul_node = node
            logger.info(f"Found last MatMul node for logits: {node.name}")
            initializer = None
            for input in node.input:
                initializer = onnx_model.get_initializer(input)
                if initializer is not None:
                    break

            # when the max difference of value after converting float to float16 is lower than a threshold (1e-6),
            # we can deduce that the weights are stored in float16 precision.
            max_diff = float_to_float16_max_diff(initializer)
            logger.debug(
                f"max diff of converting weights in last MatMul node {node.name}: {max_diff}"
            )
            is_weight_fp16_precision = (max_diff < 1E-6)
        else:
            logger.warning(
                f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}"
            )

        if is_weight_fp16_precision:
            keep_io_types = []
            node_block_list = []
        else:
            # When the original weights are float32 precision, keeping logits and the last MatMul in float32 can give better precision.
            keep_io_types = [logits_output_name]
            node_block_list = [last_matmul_node.name]

        parameters = {
            "keep_io_types": keep_io_types,
            "op_block_list": op_block_list,
            "node_block_list": node_block_list,
            "force_fp16_initializers": is_weight_fp16_precision
        }

        logger.info(f"auto_mixed_precision parameters: {parameters}")
        onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True,
                                            **parameters)

        fusion_utils = FusionUtils(onnx_model)
        fusion_utils.remove_cascaded_cast_nodes()
        fusion_utils.remove_useless_cast_nodes()

        return parameters
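A usage sketch; it assumes the function is exposed as a static helper (for example on Gpt2Helper) and the file names are illustrative:

import onnx
from onnx_model import OnnxModel

onnx_model = OnnxModel(onnx.load("gpt2_optimized.onnx"))
parameters = auto_mixed_precision(onnx_model)  # e.g. Gpt2Helper.auto_mixed_precision(onnx_model)
onnx_model.save_model_to_file("gpt2_fp16.onnx")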
Example #22
    def export_onnx(
        decoder: Union[T5Decoder, T5DecoderInit],
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export decoder to ONNX

        Args:
            decoder (Union[T5Decoder, T5DecoderInit]): decoder object
            device (torch.device): device of decoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): use int32 inputs
        """
        assert isinstance(decoder, (T5Decoder, T5DecoderInit))

        inputs = T5DecoderInputs.create_dummy(
            decoder.config,
            batch_size=2,
            encode_sequence_length=3,
            past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )
        input_list = inputs.to_list()

        past_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=False)
        present_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=True)
        present_self_names = present_names[: 2 * decoder.config.num_layers]

        input_past_names = past_names if isinstance(decoder, T5Decoder) else []
        output_present_names = present_self_names if isinstance(decoder, T5Decoder) else present_names
        output_names = ["logits"] + output_present_names

        # Shape of input tensors (sequence_length==1):
        #    input_ids: (batch_size, sequence_length)
        #    encoder_attention_mask: (batch_size, encode_sequence_length)
        #    encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        # Shape of output tensors:
        #    logits: (batch_size, sequence_length, vocab_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        input_names = ["input_ids"]
        input_names.append("encoder_attention_mask")
        input_names.append("encoder_hidden_states")
        input_names.extend(input_past_names)

        dynamic_axes = {
            "input_ids": {
                0: "batch_size",
                # 1: 'sequence_length'
            },
            "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
            "encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length"},
            "logits": {
                0: "batch_size",
                # 1: 'sequence_length'
            },
        }

        for name in input_past_names:
            dynamic_axes[name] = {
                0: "batch_size",
                2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
            }

        for name in output_present_names:
            if "cross" in name:
                dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
            else:  # self attention past state
                if isinstance(decoder, T5Decoder):
                    dynamic_axes[name] = {
                        0: "batch_size",
                        2: "past_decode_sequence_length + 1",
                    }
                else:
                    dynamic_axes[name] = {
                        0: "batch_size",
                        # 2: 'sequence_length'
                    }

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            torch_onnx_export(
                decoder,
                args=tuple(input_list),
                f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=12,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )
Example #23
def find_bert_inputs(
    onnx_model: OnnxModel,
    input_ids_name: Optional[str] = None,
    segment_ids_name: Optional[str] = None,
    input_mask_name: Optional[str] = None,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """Find graph inputs for BERT model.
    First, we will deduce inputs from the EmbedLayerNormalization node.
    If not found, we will guess the meaning of graph inputs based on naming.

    Args:
        onnx_model (OnnxModel): onnx model object
        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.

    Raises:
        ValueError: Graph does not have an input named input_ids_name, segment_ids_name, or input_mask_name
        ValueError: The number of graph inputs does not match the specified input_ids_name, segment_ids_name
                    and input_mask_name

    Returns:
        Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
                                                                                 segment_ids and input_mask
    """

    graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()

    if input_ids_name is not None:
        input_ids = onnx_model.find_graph_input(input_ids_name)
        if input_ids is None:
            raise ValueError(
                f"Graph does not have input named {input_ids_name}")

        segment_ids = None
        if segment_ids_name:
            segment_ids = onnx_model.find_graph_input(segment_ids_name)
            if segment_ids is None:
                raise ValueError(
                    f"Graph does not have input named {segment_ids_name}")

        input_mask = None
        if input_mask_name:
            input_mask = onnx_model.find_graph_input(input_mask_name)
            if input_mask is None:
                raise ValueError(
                    f"Graph does not have input named {input_mask_name}")

        expected_inputs = 1 + (1 if segment_ids else 0) + (1 if input_mask else
                                                           0)
        if len(graph_inputs) != expected_inputs:
            raise ValueError(
                f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}"
            )

        return input_ids, segment_ids, input_mask

    if len(graph_inputs) != 3:
        raise ValueError("Expect the graph to have 3 inputs. Got {}".format(
            len(graph_inputs)))

    embed_nodes = onnx_model.get_nodes_by_op_type("EmbedLayerNormalization")
    if len(embed_nodes) == 1:
        embed_node = embed_nodes[0]
        input_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 0)
        segment_ids = get_graph_input_from_embed_node(onnx_model, embed_node,
                                                      1)
        input_mask = get_graph_input_from_embed_node(onnx_model, embed_node, 7)

        if input_mask is None:
            for input in graph_inputs:
                input_name_lower = input.name.lower()
                if "mask" in input_name_lower:
                    input_mask = input
        if input_mask is None:
            raise ValueError(f"Failed to find attention mask input")

        return input_ids, segment_ids, input_mask

    # Try to guess the inputs based on naming.
    input_ids = None
    segment_ids = None
    input_mask = None
    for input in graph_inputs:
        input_name_lower = input.name.lower()
        if "mask" in input_name_lower:  # matches input with name like "attention_mask" or "input_mask"
            input_mask = input
        elif (
                "token" in input_name_lower or "segment" in input_name_lower
        ):  # matches input with name like "segment_ids" or "token_type_ids"
            segment_ids = input
        else:
            input_ids = input

    if input_ids and segment_ids and input_mask:
        return input_ids, segment_ids, input_mask

    raise ValueError(
        "Failed to assign 3 inputs. You might try renaming the graph inputs.")
def optimize_fp16_onnx_no_cast(input_onnx_path, optimized_onnx_path, epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)

    weight_name = get_weight(onnx_model)
    bias_name = get_bias(onnx_model)
    nodes_to_remove = [n for n in onnx_model.nodes() if n.output[0] != weight_name and n.output[0] != bias_name]

    node_to_add = onnx.helper.make_node("LayerNormalization", ["input", weight_name, bias_name], ["output"],
                                        "layer_norm",
                                        epsilon=epsilon)

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_node(node_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
Example #25
def optimize_fp16_onnx_with_cast(input_onnx_path, optimized_onnx_path, epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)
    weight_name = get_weight(onnx_model)
    bias_name = get_bias(onnx_model)
    nodes_to_remove = [n for n in onnx_model.nodes() if n.output[0] != weight_name and n.output[0] != bias_name]
    nodes_to_add = [
        onnx.helper.make_node("Cast", ["input"], ["fp32_input"], "cast_input", to=1),
        onnx.helper.make_node("Cast", [weight_name], ["fp32_layer_norm.weight"], "cast_weight", to=1),
        onnx.helper.make_node("Cast", [bias_name], ["fp32_layer_norm.bias"], "cast_bias", to=1),
        onnx.helper.make_node("LayerNormalization", ["fp32_input", "fp32_layer_norm.weight", "fp32_layer_norm.bias"],
                              ["fp32_output"],
                              "layer_norm",
                              epsilon=epsilon),  # use fp32 epsilon
        onnx.helper.make_node("Cast", ["fp32_output"], ["output"], "cast_output", to=10)
    ]

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_nodes(nodes_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
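For reference, the Cast to attribute values used above map to onnx.TensorProto enums (to=1 is FLOAT, to=10 is FLOAT16), so the with-cast variant computes LayerNormalization in float32 and casts the result back to float16. A quick check:

from onnx import TensorProto

assert TensorProto.FLOAT == 1 and TensorProto.FLOAT16 == 10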