Example #1
    def test_ensure_valid_input(self):
        """
        Validate parameters are correctly exported
        GPT2 has "past" parameter in the middle of input_ids, token_type_ids and attention_mask.
        ONNX doesn't support export with a dictionary, only a tuple. Thus we need to ensure we remove
        token_type_ids and attention_mask for now to not having a None tensor in the middle
        """
        # All generated args are valid
        input_names = ["input_ids", "attention_mask", "token_type_ids"]
        tokens = {
            "input_ids": [1, 2, 3, 4],
            "attention_mask": [0, 0, 0, 0],
            "token_type_ids": [1, 1, 1, 1]
        }
        inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens,
                                         input_names)

        # Should have exactly the same number of args (all are valid)
        self.assertEqual(len(inputs_args), 3)

        # Parameter should be reordered according to their respective place in the function:
        # (input_ids, token_type_ids, attention_mask)
        self.assertEqual(inputs_args,
                         (tokens["input_ids"], tokens["token_type_ids"],
                          tokens["attention_mask"]))

        # Generated args are interleaved with other args (for instance the "past" parameter in GPT2)
        inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens,
                                         input_names)

        # Should have exactly one arg (only the args before the not-provided "some_other_args" are kept)
        self.assertEqual(len(inputs_args), 1)

        # Should have only "input_ids"
        self.assertEqual(inputs_args[0], tokens["input_ids"])
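
The test above relies on two dummy argument holders that aren't shown in this snippet. A minimal sketch consistent with the assertions, assuming the signatures used in the transformers test suite: FuncContiguousArgs exposes the three generated parameters back to back, while FuncNonContiguousArgs interleaves an unprovided "some_other_args" right after input_ids.

class FuncContiguousArgs:
    def forward(self, input_ids, token_type_ids, attention_mask):
        return None


class FuncNonContiguousArgs:
    def forward(self, input_ids, some_other_args, token_type_ids, attention_mask):
        return None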
from pathlib import Path

import torch
import transformers
from transformers import convert_graph_to_onnx


def export():
    span = args.input

    model_name = "rinna/japanese-gpt2-" + args.model_name
    model_pth = Path(f"../japanese-gpt2-" + args.model_name + ".onnx")

    pipeline_name = "text-generation"

    model_pth.parent.mkdir(exist_ok=True, parents=True)

    nlp = transformers.pipeline(pipeline_name,
                                model=model_name,
                                tokenizer=model_name)
    tokenizer = nlp.tokenizer
    model = nlp.model

    # Let transformers infer input/output names, dynamic axes and dummy tokens from the pipeline
    with torch.no_grad():
        (
            input_names,
            output_names,
            dynamic_axes,
            tokens,
        ) = convert_graph_to_onnx.infer_shapes(nlp, "pt")
        ordered_input_names, model_args = convert_graph_to_onnx.ensure_valid_input(
            nlp.model, tokens, input_names)

    # Wrap GPT2LMHeadModel so forward() takes plain positional tensors and returns
    # only the logits, keeping the traced graph simple for ONNX export
    class GPT2Sent(transformers.GPT2LMHeadModel):
        def __init__(self, config):
            super().__init__(config)
            self.sentence_embedding = torch.nn.Identity()

        def forward(self, input_ids, attention_mask):
            return self.sentence_embedding(super().forward(
                input_ids, attention_mask=attention_mask).logits)

    # Load the wrapped model with the pretrained weights of the original pipeline
    model = GPT2Sent.from_pretrained(model_name)

    encoding = nlp.tokenizer([span], return_tensors="pt")
    print(encoding)

    if not model_pth.exists():
        inputs = ['input_ids', 'attention_mask']
        # '3062' is presumably the auto-generated name of the logits output in the traced graph
        outputs = ['3062']
        # Mark the sequence dimension as dynamic so the model accepts any input length
        dynamic_axes = {
            'input_ids': {1: 'len'},
            'attention_mask': {1: 'len'},
            '3062': {1: 'len'}
        }

        torch.onnx.export(model,
                          (encoding["input_ids"], encoding["attention_mask"]),
                          f=model_pth.as_posix(),
                          input_names=inputs,
                          output_names=outputs,
                          do_constant_folding=True,
                          use_external_data_format=False,
                          enable_onnx_checker=True,
                          opset_version=11,
                          dynamic_axes=dynamic_axes)
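
        # Optional sanity check (a sketch; assumes onnxruntime is installed):
        # run the freshly exported graph once and confirm it yields logits
        # with the expected (batch, len, vocab) shape.
        import onnxruntime
        session = onnxruntime.InferenceSession(model_pth.as_posix())
        onnx_logits = session.run(
            None,
            {
                "input_ids": encoding["input_ids"].numpy(),
                "attention_mask": encoding["attention_mask"].numpy(),
            },
        )[0]
        print(onnx_logits.shape)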

    output = generate_text(nlp.tokenizer, model, args.input,
                           int(args.outlength))
    print(output)
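
generate_text is not defined in this snippet. A minimal greedy-decoding sketch that matches the call above; the name and its (tokenizer, model, prompt, out_length) signature are assumptions, and it feeds the PyTorch GPT2Sent wrapper, whose forward returns raw logits.

def generate_text(tokenizer, model, prompt, out_length):
    encoding = tokenizer(prompt, return_tensors="pt")
    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]
    with torch.no_grad():
        for _ in range(out_length):
            # Pick the most likely next token from the last position's logits
            logits = model(input_ids, attention_mask)
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_id], dim=-1)
            attention_mask = torch.cat(
                [attention_mask, torch.ones_like(next_id)], dim=-1)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)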