Example #1
    def _export_onnx_graph(self, input_names_path: Path):
        # If the graph file already exists but we reached this point, the
        # previous load must have failed, so delete the stale artifacts
        if self.graph_path.exists():
            self.graph_path.unlink()
        if input_names_path.exists():
            input_names_path.unlink()

        # Create the parent directory if needed
        self.graph_path.parent.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving onnx graph at { self.graph_path.as_posix()}")

        if self.framework == "pt":
            convert_pytorch(self,
                            opset=12,
                            output=self.graph_path,
                            use_external_format=False)
        else:
            convert_tensorflow(self, opset=12, output=self.graph_path)

        # Save the input names, inferred for the pipeline's own framework
        self.input_names = infer_shapes(self, self.framework)[0]
        with open(input_names_path, "w") as f:
            json.dump(self.input_names, f)
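
The method above layers caching and error recovery on top of the helpers in transformers.convert_graph_to_onnx. For comparison, a minimal one-shot export using that module's convert helper might look like the sketch below (the checkpoint name and output path are placeholders):

from pathlib import Path

from transformers.convert_graph_to_onnx import convert

# Builds a pipeline for the checkpoint, infers the dynamic axes, and
# writes the ONNX graph to the given path
convert(
    framework="pt",                        # "pt" or "tf", as in the method above
    model="bert-base-cased",               # placeholder checkpoint
    output=Path("onnx/bert-base-cased.onnx"),
    opset=12,
)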
Example #2
    def _test_infer_dynamic_axis(self, model, tokenizer, framework):
        nlp = FeatureExtractionPipeline(model, tokenizer)

        variable_names = ["input_ids", "token_type_ids", "attention_mask", "output_0", "output_1"]
        input_vars, output_vars, shapes, tokens = infer_shapes(nlp, framework)

        # Assert all variables are present
        self.assertEqual(len(shapes), len(variable_names))
        self.assertTrue(all(var_name in shapes for var_name in variable_names))
        self.assertSequenceEqual(variable_names[:3], input_vars)
        self.assertSequenceEqual(variable_names[3:], output_vars)

        # Assert inputs are {0: batch, 1: sequence}
        for var_name in ["input_ids", "token_type_ids", "attention_mask"]:
            self.assertDictEqual(shapes[var_name], {0: "batch", 1: "sequence"})

        # Assert outputs are {0: batch, 1: sequence} and {0: batch}
        self.assertDictEqual(shapes["output_0"], {0: "batch", 1: "sequence"})
        self.assertDictEqual(shapes["output_1"], {0: "batch"})
Example #3
import torch
import transformers
from pathlib import Path
from transformers import convert_graph_to_onnx


def export():
    # `args` (parsed CLI arguments) and `generate_text` are defined elsewhere
    # in the original script
    span = args.input

    model_name = f"rinna/japanese-gpt2-{args.model_name}"
    model_pth = Path(f"../japanese-gpt2-{args.model_name}.onnx")

    pipeline_name = "text-generation"

    model_pth.parent.mkdir(exist_ok=True, parents=True)

    nlp = transformers.pipeline(pipeline_name,
                                model=model_name,
                                tokenizer=model_name)

    with torch.no_grad():
        (
            input_names,
            output_names,
            dynamic_axes,
            tokens,
        ) = convert_graph_to_onnx.infer_shapes(nlp, "pt")
        ordered_input_names, model_args = convert_graph_to_onnx.ensure_valid_input(
            nlp.model, tokens, input_names)

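    # Wrap GPT2LMHeadModel so forward() returns a single logits tensor,
    # giving the exported ONNX graph one well-defined output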
    class GPT2Sent(transformers.GPT2LMHeadModel):
        def __init__(self, config):
            super().__init__(config)
            self.sentence_embedding = torch.nn.Identity()

        def forward(self, input_ids, attention_mask):
            return self.sentence_embedding(super().forward(
                input_ids, attention_mask=attention_mask).logits)

    # Load the pretrained weights directly into the wrapper class;
    # from_pretrained is a classmethod, so instantiating with a config
    # first would just be discarded
    model = GPT2Sent.from_pretrained(model_name)

    encoding = nlp.tokenizer([span], return_tensors="pt")
    print(encoding)

    if not model_pth.exists():
        inputs = ['input_ids', 'attention_mask']
        # '3062' matches the key used in dynamic_axes below; it appears to be
        # the auto-generated name of the logits output in the traced graph
        outputs = ['3062']
        dynamic_axes = {
            'input_ids': {1: 'len'},
            'attention_mask': {1: 'len'},
            '3062': {1: 'len'},
        }

        torch.onnx.export(model,
                          (encoding["input_ids"], encoding["attention_mask"]),
                          f=model_pth.as_posix(),
                          input_names=inputs,
                          output_names=outputs,
                          do_constant_folding=True,
                          opset_version=11,
                          dynamic_axes=dynamic_axes)

    output = generate_text(nlp.tokenizer, model, args.input,
                           int(args.outlength))
    print(output)
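
Once exported, the graph can be sanity-checked with onnxruntime. A minimal sketch, assuming onnxruntime is installed and reusing model_pth and encoding from the function above:

import onnxruntime as ort

session = ort.InferenceSession(model_pth.as_posix())
ort_inputs = {
    "input_ids": encoding["input_ids"].numpy(),
    "attention_mask": encoding["attention_mask"].numpy(),
}
(logits,) = session.run(None, ort_inputs)  # single output: the logits tensor
print(logits.shape)  # (batch, sequence_length, vocab_size)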