Example #1
def export_onnx_model_from_pt(model_name, opset_version, use_external_data_format, model_type, model_class,
                              config_modifier, cache_dir, onnx_dir, input_names, use_gpu, precision, optimizer_info,
                              validate_onnx, use_raw_attention_mask, overwrite, model_fusion_statistics, fusion_options):

    config, model = load_pt_model(model_name, model_class, cache_dir, config_modifier)
    # config, model = load_pt_model_from_tf(model_name)
    model.cpu()

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    max_input_size = (tokenizer.max_model_input_sizes[model_name]
                      if model_name in tokenizer.max_model_input_sizes else 1024)

    example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")

    example_inputs = filter_inputs(example_inputs, input_names)

    example_outputs = model(**example_inputs)

    assert isinstance(example_outputs, (list, tuple)), f"type of output is not list or tuple: {type(example_outputs)}"

    # Flatten is needed for gpt2 and distilgpt2.
    example_outputs_flatten = flatten(example_outputs)
    example_outputs_flatten = update_flatten_list(example_outputs_flatten, [])

    onnx_model_path = get_onnx_file_path(onnx_dir, model_name, len(input_names), False, use_gpu, precision, False,
                                         use_external_data_format)

    if overwrite or not os.path.exists(onnx_model_path):
        logger.info("Exporting ONNX model to {}".format(onnx_model_path))
        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)

        replace_torch_functions()
        torch_onnx_export(model=model,
                          args=tuple(example_inputs.values()),
                          f=onnx_model_path,
                          input_names=list(example_inputs.keys()),
                          output_names=output_names,
                          dynamic_axes=dynamic_axes,
                          do_constant_folding=True,
                          opset_version=opset_version,
                          use_external_data_format=use_external_data_format)
        restore_torch_functions()
    else:
        logger.info(f"Skip export since model existed: {onnx_model_path}")

    onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx(
        model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimizer_info,
        validate_onnx, use_raw_attention_mask, overwrite, config, model_fusion_statistics, onnx_model_path,
        example_inputs, example_outputs_flatten, None, fusion_options)

    return onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size
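
A hedged usage sketch follows; every argument value is an illustrative assumption (the real caller may pass enums or helper objects), and config_modifier, optimizer_info, fusion_options and model_fusion_statistics are simply forwarded to load_pt_model and validate_and_optimize_onnx by the function above.

# Hypothetical invocation; all values below are placeholders, not verified defaults.
onnx_path, is_valid, vocab_size, max_len = export_onnx_model_from_pt(
    model_name="gpt2",
    opset_version=12,
    use_external_data_format=False,
    model_type="gpt2",
    model_class="GPT2LMHeadModel",
    config_modifier=None,        # assumption: placeholder for the config-modifier helper
    cache_dir="./cache_models",
    onnx_dir="./onnx_models",
    input_names=["input_ids"],
    use_gpu=False,
    precision="fp32",            # assumption: may be a Precision enum in the real caller
    optimizer_info=None,         # assumption: forwarded to validate_and_optimize_onnx
    validate_onnx=True,
    use_raw_attention_mask=True,
    overwrite=False,
    model_fusion_statistics={},
    fusion_options=None,
)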
Example #2
    def export_onnx(encoder: T5Encoder,
                    device: torch.device,
                    onnx_model_path: str,
                    verbose: bool = True,
                    use_external_data_format: bool = False):
        """Export encoder to ONNX

        Args:
            encoder (T5Encoder): encoder object
            device (torch.device): device of encoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
        """
        config = encoder.config
        encoder_inputs = T5EncoderInputs.create_dummy(
            batch_size=2,
            sequence_length=4,
            vocab_size=config.vocab_size,
            device=device)

        with torch.no_grad():
            outputs = encoder(encoder_inputs.input_ids,
                              encoder_inputs.attention_mask)

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
        torch_onnx_export(encoder,
                          args=tuple(encoder_inputs.to_list()),
                          f=onnx_model_path,
                          export_params=True,
                          input_names=['input_ids', 'attention_mask'],
                          output_names=['hidden_states'],
                          dynamic_axes={
                              'input_ids': {
                                  0: 'batch_size',
                                  1: 'sequence_length'
                              },
                              'attention_mask': {
                                  0: 'batch_size',
                                  1: 'sequence_length'
                              },
                              'hidden_states': {
                                  0: 'batch_size',
                                  1: 'sequence_length'
                              },
                          },
                          opset_version=12,
                          do_constant_folding=True,
                          use_external_data_format=use_external_data_format,
                          verbose=verbose)
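
A minimal usage sketch, assuming the method above is a static member of a helper class (named T5EncoderHelper here purely for illustration) and that `encoder` is an already constructed T5Encoder wrapping a Hugging Face T5 checkpoint:

# Hypothetical usage; class and variable names outside the method above are assumptions.
device = torch.device("cpu")
T5EncoderHelper.export_onnx(encoder,
                            device,
                            onnx_model_path="./onnx/t5_encoder.onnx",
                            verbose=False,
                            use_external_data_format=False)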
Example #3
    def export_onnx(
        decoder: Union[T5Decoder, T5DecoderInit],
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
    ):
        """Export decoder to ONNX

        Args:
            decoder (Union[T5Decoder, T5DecoderInit]): decoder object
            device (torch.device): device of decoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
        """
        assert isinstance(decoder, (T5Decoder, T5DecoderInit))

        inputs = T5DecoderInputs.create_dummy(
            decoder.config,
            batch_size=2,
            encode_sequence_length=3,
            past_decode_sequence_length=5
            if isinstance(decoder, T5Decoder) else 0,
            device=device,
        )
        input_list = inputs.to_list()
        with torch.no_grad():
            outputs = decoder(*input_list)

        past_names = PastKeyValuesHelper.get_past_names(
            decoder.config.num_layers, present=False)
        present_names = PastKeyValuesHelper.get_past_names(
            decoder.config.num_layers, present=True)
        present_self_names = present_names[:2 * decoder.config.num_layers]

        input_past_names = past_names if isinstance(decoder, T5Decoder) else []
        output_present_names = (present_self_names if isinstance(decoder, T5Decoder)
                                else present_names)
        output_names = ["logits"] + output_present_names

        # Shape of input tensors (sequence_length==1):
        #    input_ids: (batch_size, sequence_length)
        #    encoder_attention_mask: (batch_size, encode_sequence_length)
        #    encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length, hidden_size/num_heads)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, hidden_size/num_heads)

        # Shape of output tensors:
        #    logits: (batch_size, sequence_length, vocab_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, hidden_size/num_heads)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, hidden_size/num_heads)
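        # As a concrete illustration with the dummy values passed to create_dummy above
        # (batch_size=2, encode_sequence_length=3, past_decode_sequence_length=5 for T5Decoder):
        #    input_ids is (2, 1), encoder_attention_mask is (2, 3),
        #    encoder_hidden_states is (2, 3, hidden_size),
        #    past_self_* is (2, num_heads, 5, hidden_size/num_heads) and
        #    past_cross_* is (2, num_heads, 3, hidden_size/num_heads).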

        input_names = ["input_ids"]
        input_names.append("encoder_attention_mask")
        input_names.append("encoder_hidden_states")
        input_names.extend(input_past_names)

        dynamic_axes = {
            "input_ids": {
                0: "batch_size",
                # 1: 'sequence_length'
            },
            "encoder_attention_mask": {
                0: "batch_size",
                1: "encode_sequence_length"
            },
            "encoder_hidden_states": {
                0: "batch_size",
                1: "encode_sequence_length"
            },
            "logits": {
                0: "batch_size",
                # 1: 'sequence_length'
            },
        }

        for name in input_past_names:
            dynamic_axes[name] = {
                0: "batch_size",
                2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
            }

        for name in output_present_names:
            if "cross" in name:
                dynamic_axes[name] = {
                    0: "batch_size",
                    2: "encode_sequence_length"
                }
            else:  # self attention past state
                if isinstance(decoder, T5Decoder):
                    dynamic_axes[name] = {
                        0: "batch_size",
                        2: "past_decode_sequence_length + 1",
                    }
                else:
                    dynamic_axes[name] = {
                        0: "batch_size",
                        # 2: 'sequence_length'
                    }

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
        torch_onnx_export(
            decoder,
            args=tuple(input_list),
            f=onnx_model_path,
            export_params=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=12,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            verbose=verbose,
        )
Example #4
    def export_onnx(model: T5EncoderDecoderInit,
                    device: torch.device,
                    onnx_model_path: str,
                    use_decoder_input_ids: bool = True,
                    verbose: bool = True,
                    use_external_data_format: bool = False):
        """Export decoder to ONNX

        Args:
            model (T5EncoderDecoderInit): the model to export
            device (torch.device): device of decoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
        """
        assert isinstance(model, T5EncoderDecoderInit)

        inputs = T5EncoderDecoderInitInputs.create_dummy(model.config,
                                                         batch_size=2,
                                                         encode_sequence_length=3,
                                                         use_decoder_input_ids=use_decoder_input_ids,
                                                         device=device)
        input_list = inputs.to_list()
        outputs = model(*input_list)

        present_names = PastKeyValuesHelper.get_past_names(model.config.num_layers, present=True)

        output_names = ["logits", "encoder_hidden_states"] + present_names

        # Shape of input tensors (sequence_length==1):
        #    encoder_input_ids: (batch_size, encode_sequence_length)
        #    encoder_attention_mask: (batch_size, encode_sequence_length)
        #    decoder_input_ids (optional): (batch_size, sequence_length)

        # Shape of output tensors:
        #    logits: (batch_size, sequence_length, vocab_size)
        #    encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
        #    present_self_*: (batch_size, num_heads, sequence_length, hidden_size/num_heads)
        #    present_cross_*: (batch_size, num_heads, encode_sequence_length, hidden_size/num_heads)

        input_names = ["encoder_input_ids", "encoder_attention_mask"]

        # ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2'. Use more friendly string here.
        sequence_length = '1'
        num_heads = str(model.config.num_heads)
        hidden_size = str(model.config.d_model)
        head_size = str(model.config.d_model // model.config.num_heads)

        dynamic_axes = {
            'encoder_input_ids': {
                0: 'batch_size',
                1: 'encode_sequence_length'
            },
            'encoder_attention_mask': {
                0: 'batch_size',
                1: 'encode_sequence_length'
            },
            'encoder_hidden_states': {
                0: 'batch_size',
                1: 'encode_sequence_length',
                2: hidden_size
            },
            "logits": {
                0: 'batch_size',
                1: sequence_length
            }
        }

        if use_decoder_input_ids:
            input_names.append("decoder_input_ids")
            dynamic_axes["decoder_input_ids"] = {0: 'batch_size', 1: sequence_length}

        for name in present_names:
            if "cross" in name:
                dynamic_axes[name] = {0: 'batch_size', 1: num_heads, 2: 'encode_sequence_length', 3: head_size}

            else:  # self attention past state
                dynamic_axes[name] = {0: 'batch_size', 1: num_heads, 2: sequence_length, 3: head_size}

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
        torch_onnx_export(model,
                          args=tuple(input_list),
                          f=onnx_model_path,
                          export_params=True,
                          input_names=input_names,
                          output_names=output_names,
                          dynamic_axes=dynamic_axes,
                          opset_version=12,
                          do_constant_folding=True,
                          use_external_data_format=use_external_data_format,
                          verbose=verbose)
Example #5
    def export_onnx(model,
                    device,
                    onnx_model_path: str,
                    verbose: bool = False,
                    use_external_data_format: bool = False,
                    has_position_ids: bool = True,
                    has_attention_mask: bool = True):
        """Export GPT-2 model with past state to ONNX model."""
        config: GPT2Config = model.config
        num_layer = config.n_layer
        dummy_inputs = Gpt2BeamSearchHelper.get_dummy_inputs(
            batch_size=1,
            past_sequence_length=1,
            sequence_length=2,
            num_attention_heads=config.num_attention_heads,
            hidden_size=config.hidden_size,
            num_layer=num_layer,
            vocab_size=config.vocab_size,
            device=device,
            float16=False,
            has_position_ids=has_position_ids,
            has_attention_mask=has_attention_mask)
        input_list = dummy_inputs.to_list()

        with torch.no_grad():
            # outputs = model(input_ids, position_id, attention_mask, beam_select_idx, past)
            outputs = model(*input_list)

        past_names = [f"past_{i}" for i in range(num_layer)]
        present_names = [f"present_{i}" for i in range(num_layer)]

        output_names = ["last_state"] + present_names

        if has_position_ids:
            output_names += [
                "output_selected_indices",
                "output_log_probs",
                "output_unfinished_sents",
                "current_step_results",
                "current_step_scores",
            ]
        else:
            output_names += [
                "output_selected_indices",
                "output_log_probs",
                "output_unfinished_sents",
                "current_step_scores",
            ]

        # Shape of input tensors:
        #    input_ids: (batch_size, seq_len)
        #    past_{i}:  (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads)
        #    attention_mask: (batch_size, past_seq_len + seq_len)
        # Shape of output tensors:
        #    last_state: (batch_size, seq_len, hidden_size)
        #      or logits: (batch_size, seq_len, vocab_size)
        #    present_{i}:  (2, batch_size, num_heads, past_seq_len + seq_len, hidden_size/num_heads)
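        # As a concrete illustration with the dummy values used above
        # (batch_size=1, past_sequence_length=1, sequence_length=2):
        #    input_ids is (1, 2), past_{i} is (2, 1, num_heads, 1, hidden_size/num_heads),
        #    attention_mask is (1, 3) and present_{i} is (2, 1, num_heads, 3, hidden_size/num_heads).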
        dynamic_axes = {
            "input_ids": {
                0: "batch_size",
                1: "seq_len"
            },
            output_names[0]: {
                0: "batch_size",
                1: "seq_len"
            },
        }
        for name in past_names:
            dynamic_axes[name] = {1: "batch_size", 3: "past_seq_len"}
        for name in present_names:
            dynamic_axes[name] = {1: "batch_size", 3: "cur_seq_len"}

        input_names = ["input_ids"]
        if has_position_ids:
            dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
            input_names.append("position_ids")
        if has_attention_mask:
            dynamic_axes["attention_mask"] = {
                0: "batch_size",
                1: "total_seq_len"
            }
            input_names.append("attention_mask")
        dynamic_axes["beam_select_idx"] = {1: "batch_size"}
        input_names.append("beam_select_idx")
        dynamic_axes["input_log_probs"] = {0: "batch_size", 1: "beam_size"}
        input_names.append("input_log_probs")
        dynamic_axes["input_unfinished_sents"] = {
            0: "batch_size",
            1: "beam_size"
        }
        input_names.append("input_unfinished_sents")
        if has_position_ids:
            dynamic_axes["prev_step_results"] = {
                0: "batch_size",
                1: "total_seq_len"
            }
            input_names.append("prev_step_results")
        dynamic_axes["prev_step_scores"] = {
            0: "batch_size",
            1: "total_seq_len"
        }
        input_names.append("prev_step_scores")
        input_names.extend(past_names)

        # add dynamic output axes
        present_axes = {1: 'batch_size', 3: 'cur_seq_len'}
        dynamic_axes["last_state"] = {0: 'batch_size', 1: 'beam_size'}
        for i in range(num_layer):
            dynamic_axes["present_" + str(i)] = present_axes

        dynamic_axes["output_selected_indices"] = {
            0: "batch_size",
            1: "'beam_size_or_1'"
        }
        dynamic_axes["output_log_probs"] = {0: "batch_size", 1: "'beam_size'"}
        dynamic_axes["output_unfinished_sents"] = {
            0: "batch_size",
            1: "'beam_size'"
        }
        dynamic_axes["current_step_results"] = {
            0: "beam_size_or_1",
            1: "total_seq_len"
        }
        dynamic_axes["current_step_scores"] = {
            0: "beam_size_or_1",
            1: "total_seq_len"
        }

        logger.info(
            f"Shapes: input_ids={dummy_inputs.input_ids.shape} past={dummy_inputs.past[0].shape} output={outputs[0].shape} present={outputs[1][0].shape}"
        )

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        torch_onnx_export(
            model,
            args=tuple(input_list),
            f=onnx_model_path,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=12,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            verbose=verbose,
        )
Example #6
    def export_onnx(model,
                    device,
                    onnx_model_path: str,
                    verbose: bool = False,
                    use_external_data_format: bool = False,
                    has_position_ids: bool = True,
                    has_attention_mask: bool = True,
                    input_ids_dtype: torch.dtype = torch.int32,
                    position_ids_dtype: torch.dtype = torch.int32,
                    attention_mask_dtype: torch.dtype = torch.int32):
        """ Export GPT-2 model with past state to ONNX model.
        """
        config: GPT2Config = model.config
        num_layer = config.n_layer
        dummy_inputs = Gpt2Helper.get_dummy_inputs(
            batch_size=1,
            past_sequence_length=1,
            sequence_length=1,
            num_attention_heads=config.num_attention_heads,
            hidden_size=config.hidden_size,
            num_layer=num_layer,
            vocab_size=config.vocab_size,
            device=device,
            float16=False,
            has_position_ids=has_position_ids,
            has_attention_mask=has_attention_mask,
            input_ids_dtype=input_ids_dtype,
            position_ids_dtype=position_ids_dtype,
            attention_mask_dtype=attention_mask_dtype)
        input_list = dummy_inputs.to_list()

        with torch.no_grad():
            outputs = model(*input_list)

        past_names = [f'past_{i}' for i in range(num_layer)]
        present_names = [f'present_{i}' for i in range(num_layer)]

        # GPT2Model outputs last_state; GPT2LMHeadModel outputs logits (prediction_scores)
        assert (outputs[0].shape[2] == config.vocab_size
                or outputs[0].shape[2] == config.hidden_size)
        output_names = ["logits" if outputs[0].shape[2] == config.vocab_size
                        else "last_state"] + present_names

        # Shape of input tensors:
        #    input_ids: (batch_size, seq_len)
        #    past_{i}:  (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads)
        #    attention_mask: (batch_size, past_seq_len + seq_len)
        # Shape of output tensors:
        #    last_state: (batch_size, seq_len, hidden_size)
        #      or logits: (batch_size, seq_len, vocab_size)
        #    present_{i}:  (2, batch_size, num_heads, past_seq_len + seq_len, hidden_size/num_heads)
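        # As a concrete illustration with the dummy values used above
        # (batch_size=1, past_sequence_length=1, sequence_length=1):
        #    input_ids is (1, 1), past_{i} is (2, 1, num_heads, 1, hidden_size/num_heads),
        #    attention_mask is (1, 2) and present_{i} is (2, 1, num_heads, 2, hidden_size/num_heads).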
        dynamic_axes = {
            'input_ids': {
                0: 'batch_size',
                1: 'seq_len'
            },
            output_names[0]: {
                0: 'batch_size',
                1: 'seq_len'
            }
        }
        for name in past_names:
            dynamic_axes[name] = {1: 'batch_size', 3: 'past_seq_len'}
        for name in present_names:
            dynamic_axes[name] = {1: 'batch_size', 3: 'total_seq_len'}

        input_names = ['input_ids']
        if has_position_ids:
            dynamic_axes['position_ids'] = {0: 'batch_size', 1: 'seq_len'}
            input_names.append('position_ids')
        if has_attention_mask:
            dynamic_axes['attention_mask'] = {
                0: 'batch_size',
                1: 'total_seq_len'
            }
            input_names.append('attention_mask')
        input_names.extend(past_names)

        assert len(outputs) == 2 and len(outputs[1]) == num_layer

        logger.info(
            f"Shapes: input_ids={dummy_inputs.input_ids.shape} past={dummy_inputs.past[0].shape} output={outputs[0].shape} present={outputs[1][0].shape}"
        )

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        torch_onnx_export(model,
                          args=tuple(input_list),
                          f=onnx_model_path,
                          input_names=input_names,
                          output_names=output_names,
                          dynamic_axes=dynamic_axes,
                          opset_version=11,
                          do_constant_folding=True,
                          use_external_data_format=use_external_data_format,
                          verbose=verbose)
Example #7
def export_longformer(model, onnx_model_path, export_padding):
    input_ids, attention_mask, global_attention_mask = get_dummy_inputs(
        model.config, export_padding, device=torch.device("cpu"))

    example_outputs = model(
        input_ids,
        attention_mask=attention_mask,
        global_attention_mask=global_attention_mask,
    )

    if version.parse(transformers.__version__) < version.parse("4.0.0"):
        raise RuntimeError("This tool requires transformers 4.0.0 or later.")

    # Here we replace LongformerSelfAttention.forward with our implementation for exporting the ONNX model.
    import inspect

    from transformers import LongformerSelfAttention

    key = " ".join(
        inspect.getfullargspec(LongformerSelfAttention.forward).args)
    args_to_func = {
        "self hidden_states attention_mask layer_head_mask is_index_masked is_index_global_attn is_global_attn output_attentions":
        my_longformer_self_attention_forward_4_3_2,
        "self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn output_attentions":
        my_longformer_self_attention_forward_4_3,
        "self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn":
        my_longformer_self_attention_forward_4,
    }

    if key not in args_to_func:
        print(
            "Current arguments",
            inspect.getfullargspec(LongformerSelfAttention.forward).args,
        )
        raise RuntimeError(
            "LongformerSelfAttention.forward arguments are different. Please install supported version (like transformers 4.3.0)."
        )

    # Store for restoring later
    original_forward = LongformerSelfAttention.forward

    LongformerSelfAttention.forward = args_to_func[key]

    example_inputs = (input_ids, attention_mask, global_attention_mask)

    Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

    torch_onnx_export(
        model,
        example_inputs,
        onnx_model_path,
        opset_version=12,
        input_names=["input_ids", "attention_mask", "global_attention_mask"],
        output_names=["last_state", "pooler"],
        dynamic_axes={
            "input_ids": {
                0: "batch_size",
                1: "sequence_length"
            },
            "attention_mask": {
                0: "batch_size",
                1: "sequence_length"
            },
            "global_attention_mask": {
                0: "batch_size",
                1: "sequence_length"
            },
            "last_state": {
                0: "batch_size",
                1: "sequence_length"
            },
            "pooler": {
                0: "batch_size",
                1: "sequence_length"
            },
        },
        custom_opsets={"com.microsoft": 1},
    )
    print(f"ONNX model exported to {onnx_model_path}")

    # Restore the original implementation:
    LongformerSelfAttention.forward = original_forward
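
A short, hedged smoke test of the exported model with onnxruntime; it is not part of the original function, it reuses the dummy tensors created inside export_longformer, and it assumes a CUDA-enabled onnxruntime build because the com.microsoft LongformerAttention contrib op is, to our knowledge, implemented for the CUDA execution provider:

import onnxruntime

# Run the exported model once with the same dummy inputs used for export.
session = onnxruntime.InferenceSession(onnx_model_path, providers=["CUDAExecutionProvider"])
ort_outputs = session.run(None, {
    "input_ids": input_ids.cpu().numpy(),
    "attention_mask": attention_mask.cpu().numpy(),
    "global_attention_mask": global_attention_mask.cpu().numpy(),
})
print("last_state shape:", ort_outputs[0].shape)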
Example #8
    def export_onnx(
        model,
        device,
        onnx_model_path: str,
        verbose: bool = False,
        use_external_data_format: bool = False,
        has_position_ids: bool = True,
        has_attention_mask: bool = True,
        input_ids_dtype: torch.dtype = torch.int32,
        position_ids_dtype: torch.dtype = torch.int32,
        attention_mask_dtype: torch.dtype = torch.int32,
    ):
        """Export GPT-2 model with past state to ONNX model."""
        config: GPT2Config = model.config
        num_layer = config.n_layer
        dummy_inputs = Gpt2Helper.get_dummy_inputs(
            batch_size=1,
            past_sequence_length=1,
            sequence_length=1,
            num_attention_heads=config.num_attention_heads,
            hidden_size=config.hidden_size,
            num_layer=num_layer,
            vocab_size=config.vocab_size,
            device=device,
            float16=False,
            has_position_ids=has_position_ids,
            has_attention_mask=has_attention_mask,
            input_ids_dtype=input_ids_dtype,
            position_ids_dtype=position_ids_dtype,
            attention_mask_dtype=attention_mask_dtype,
        )
        input_list = dummy_inputs.to_list()

        with torch.no_grad():
            outputs = model(*input_list)

        past_names = [f"past_{i}" for i in range(num_layer)]
        present_names = [f"present_{i}" for i in range(num_layer)]

        # GPT2Model outputs last_state; GPT2LMHeadModel outputs logits (prediction_scores)
        assert (outputs[0].shape[2] == config.vocab_size
                or outputs[0].shape[2] == config.hidden_size)
        output_names = ["logits" if outputs[0].shape[2] == config.vocab_size
                        else "last_state"] + present_names

        # Shape of input tensors:
        #    input_ids: (batch_size, seq_len)
        #    past_{i}:  (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads)
        #    attention_mask: (batch_size, past_seq_len + seq_len)
        # Shape of output tensors:
        #    last_state: (batch_size, seq_len, hidden_size)
        #      or logits: (batch_size, seq_len, vocab_size)
        #    present_{i}:  (2, batch_size, num_heads, past_seq_len + seq_len, hidden_size/num_heads)
        dynamic_axes = {
            "input_ids": {
                0: "batch_size",
                1: "seq_len"
            },
            output_names[0]: {
                0: "batch_size",
                1: "seq_len"
            },
        }
        for name in past_names:
            dynamic_axes[name] = {1: "batch_size", 3: "past_seq_len"}
        for name in present_names:
            dynamic_axes[name] = {1: "batch_size", 3: "total_seq_len"}

        input_names = ["input_ids"]
        if has_position_ids:
            dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
            input_names.append("position_ids")
        if has_attention_mask:
            dynamic_axes["attention_mask"] = {
                0: "batch_size",
                1: "total_seq_len"
            }
            input_names.append("attention_mask")
        input_names.extend(past_names)

        assert len(outputs) == 2 and len(outputs[1]) == num_layer

        logger.info(
            f"Shapes: input_ids={dummy_inputs.input_ids.shape} past={dummy_inputs.past[0].shape} output={outputs[0].shape} present={outputs[1][0].shape}"
        )

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        if use_external_data_format:
            # We let PyTorch export onnx to a temp directory first, then convert external data to one file.
            with tempfile.TemporaryDirectory() as tmp_dir_name:
                temp_onnx_model_path = os.path.join(tmp_dir_name, "gpt2.onnx")
                Path(temp_onnx_model_path).parent.mkdir(parents=True,
                                                        exist_ok=True)

                torch_onnx_export(
                    model,
                    args=tuple(input_list),
                    f=temp_onnx_model_path,
                    export_params=True,
                    input_names=input_names,
                    output_names=output_names,
                    dynamic_axes=dynamic_axes,
                    opset_version=11,
                    do_constant_folding=True,
                    use_external_data_format=True,
                    verbose=verbose,
                )

                model = onnx.load_model(temp_onnx_model_path,
                                        load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )
        else:
            torch_onnx_export(
                model,
                args=tuple(input_list),
                f=onnx_model_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=11,
                do_constant_folding=True,
                use_external_data_format=False,
                verbose=verbose,
            )
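
A minimal usage sketch, assuming the method above is a static member of the Gpt2Helper class it references and that `model` is a GPT-2 model (or wrapper) whose forward accepts the flattened inputs produced by Gpt2Helper.get_dummy_inputs (input_ids, position_ids, attention_mask and per-layer past tensors):

# Hypothetical invocation; names and paths below are illustrative assumptions.
device = torch.device("cpu")
Gpt2Helper.export_onnx(model.eval().to(device),
                       device,
                       onnx_model_path="./onnx/gpt2_past.onnx",
                       use_external_data_format=False)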
Example #9
    def export_onnx(
        encoder: T5Encoder,
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export encoder to ONNX

        Args:
            encoder (T5Encoder): encoder object
            device (torch.device): device of encoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): use int32 instead of int64 for the dummy inputs. Defaults to False.
        """
        config = encoder.config
        encoder_inputs = T5EncoderInputs.create_dummy(
            batch_size=2,
            sequence_length=4,
            vocab_size=config.vocab_size,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True,
                                                    exist_ok=True)
            torch_onnx_export(
                encoder,
                args=tuple(encoder_inputs.to_list()),
                f=temp_onnx_model_path
                if use_external_data_format else onnx_model_path,
                export_params=True,
                input_names=["input_ids", "attention_mask"],
                output_names=["hidden_states"],
                dynamic_axes={
                    "input_ids": {
                        0: "batch_size",
                        1: "sequence_length"
                    },
                    "attention_mask": {
                        0: "batch_size",
                        1: "sequence_length"
                    },
                    "hidden_states": {
                        0: "batch_size",
                        1: "sequence_length"
                    },
                },
                opset_version=12,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                model = onnx.load_model(temp_onnx_model_path,
                                        load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )
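
As a quick smoke test, the exported encoder can be fed the same dummy tensors through onnxruntime; this sketch is not part of the original helper and assumes it runs in the same scope as the code above:

import onnxruntime

# Run the exported T5 encoder once with the dummy inputs created earlier.
session = onnxruntime.InferenceSession(onnx_model_path)
hidden_states = session.run(
    ["hidden_states"],
    {
        "input_ids": encoder_inputs.input_ids.cpu().numpy(),
        "attention_mask": encoder_inputs.attention_mask.cpu().numpy(),
    },
)[0]
print("hidden_states shape:", hidden_states.shape)  # expected (2, 4, hidden_size)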