Example #1
    def generate_and_download_framework(
            self, metadata: NetworkMetadata,
            workspace: NNFolderWorkspace) -> NetworkModels:

        cache_variant = False
        if metadata.other.kv_cache:
            cache_variant = True

        trt_gpt2_config = self.config
        metadata_serialized = trt_gpt2_config.get_metadata_string(metadata)
        workspace_dir = workspace.get_path()

        pytorch_model_dir = os.path.join(workspace_dir, metadata_serialized)
        # We keep track of the generated torch location for cleanup later
        self.torch_gpt2_dir = pytorch_model_dir

        if not os.path.exists(pytorch_model_dir):
            # Download the pre-trained weights and cache them locally.
            # from_pretrained is a classmethod, so the use_cache override is
            # passed as a kwarg; calling it on a throwaway
            # GPT2LMHeadModel(GPT2Config(...)) instance would discard that
            # config entirely.
            model = GPT2LMHeadModel.from_pretrained(
                metadata.variant, use_cache=cache_variant)
            model.save_pretrained(pytorch_model_dir)
            print("Pytorch Model saved to {}".format(pytorch_model_dir))
        else:
            print(
                "Frameworks file already exists, skipping generation and loading from file instead."
            )
            model = GPT2LMHeadModel.from_pretrained(
                pytorch_model_dir, use_cache=cache_variant)

        root_onnx_model_name = "{}.onnx".format(metadata_serialized)
        root_onnx_model_fpath = os.path.join(os.getcwd(), workspace_dir,
                                             root_onnx_model_name)
        onnx_model_fpath = root_onnx_model_fpath

        gpt2 = GPT2TorchFile(model, metadata)
        self.onnx_gpt2 = gpt2.as_onnx_model(onnx_model_fpath,
                                            force_overwrite=False)

        onnx_models = [
            NetworkModel(
                name=GPT2ModelTRTConfig.NETWORK_DECODER_SEGMENT_NAME,
                fpath=self.onnx_gpt2.fpath,
            )
        ]
        torch_models = [
            NetworkModel(
                name=GPT2ModelTRTConfig.NETWORK_DECODER_SEGMENT_NAME,
                fpath=pytorch_model_dir,
            )
        ]

        return NetworkModels(torch=torch_models, onnx=onnx_models, trt=None)
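The caching logic above is the standard HuggingFace save_pretrained/from_pretrained round trip: download once, persist to disk, and reload locally on later runs. A minimal standalone sketch of that pattern, assuming only that transformers is installed (the "./models/gpt2" path and the "gpt2" variant name are illustrative, not taken from the example):

import os

from transformers import GPT2LMHeadModel

# Illustrative cache location; the real example derives this from the
# workspace path and the serialized metadata string.
pytorch_model_dir = "./models/gpt2"

if not os.path.exists(pytorch_model_dir):
    # First run: download from the hub; the use_cache kwarg is applied as
    # an override on top of the variant's own config.
    model = GPT2LMHeadModel.from_pretrained("gpt2", use_cache=False)
    model.save_pretrained(pytorch_model_dir)
else:
    # Later runs: load from the local cache instead of the hub.
    model = GPT2LMHeadModel.from_pretrained(pytorch_model_dir, use_cache=False)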
Example #2
    def execute_inference(
        self,
        metadata: NetworkMetadata,
        onnx_fpaths: Dict[str, NetworkModel],
        inference_input: str,
        timing_profile: TimingProfile,
    ) -> NetworkResult:

        tokenizer = GPT2Tokenizer.from_pretrained(metadata.variant)
        input_ids = tokenizer(inference_input, return_tensors="pt").input_ids

        # get single decoder iteration inference timing profile
        _, decoder_e2e_median_time = gpt2_inference(self.gpt2_trt, input_ids,
                                                    timing_profile)

        # get complete decoder inference result and its timing profile
        sample_output, full_e2e_median_runtime = full_inference_greedy(
            self.gpt2_trt,
            input_ids,
            timing_profile,
            max_length=GPT2ModelTRTConfig.MAX_SEQUENCE_LENGTH[
                metadata.variant])

        semantic_outputs = []
        # Use a distinct loop variable: reusing "sample_output" here would
        # shadow the full output tensor that NetworkResult reports below.
        for output in sample_output:
            semantic_outputs.append(
                tokenizer.decode(output, skip_special_tokens=True))

        return NetworkResult(
            input=inference_input,
            output_tensor=sample_output,
            semantic_output=semantic_outputs,
            median_runtime=[
                NetworkRuntime(
                    name=GPT2ModelTRTConfig.NETWORK_DECODER_SEGMENT_NAME,
                    runtime=decoder_e2e_median_time,
                ),
                NetworkRuntime(
                    name=GPT2ModelTRTConfig.NETWORK_FULL_NAME,
                    runtime=full_e2e_median_runtime,
                ),
            ],
            models=NetworkModels(
                torch=None,
                onnx=list(onnx_fpaths.values()),
                trt=[
                    NetworkModel(
                        name=GPT2ModelTRTConfig.NETWORK_DECODER_SEGMENT_NAME,
                        fpath=self.gpt2_engine.fpath,
                    ),
                ],
            ),
        )
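full_inference_greedy and gpt2_inference are demo-specific helpers, but the quantity they report, a median wall-clock time for a greedy decode, can be approximated with plain transformers and a stopwatch. A rough standalone sketch (the model name, iteration count, and median_runtime helper are illustrative, not part of the demo):

import statistics
import time

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tokenizer("TensorRT is", return_tensors="pt").input_ids

def median_runtime(fn, iterations=10):
    # Median over repeated runs, mirroring the TimingProfile idea.
    timings = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)

runtime = median_runtime(lambda: model.generate(input_ids, max_length=64))
print("median e2e greedy decode: {:.3f}s".format(runtime))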
Example #3
    def args_to_network_models(self, args) -> Tuple[NetworkModel, ...]:
        gpt2_fpath_check = args.onnx_fpath is None

        network_models = None
        if gpt2_fpath_check:
            network_models = tuple()
        else:
            onnx_decoder = NetworkModel(
                name=GPT2ModelTRTConfig.NETWORK_DECODER_SEGMENT_NAME,
                fpath=args.onnx_fpath,
            )
            # The trailing comma is required to build a one-element tuple;
            # "(onnx_decoder)" alone is just a parenthesized NetworkModel.
            network_models = (onnx_decoder,)

        return network_models
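One detail worth calling out in the branch above: parentheses alone never create a tuple in Python, so the trailing comma is load-bearing. A quick illustration:

onnx_decoder = "decoder.onnx"

not_a_tuple = (onnx_decoder)   # just a parenthesized string
a_tuple = (onnx_decoder,)      # a one-element tuple

print(type(not_a_tuple))  # <class 'str'>
print(type(a_tuple))      # <class 'tuple'>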
Example #4
    def args_to_network_models(self, args) -> Tuple[NetworkModel, ...]:
        # Require both flags together or neither; otherwise error out.
        decoder_fpath_check = args.onnx_decoder_fpath is None
        encoder_fpath_check = args.onnx_encoder_fpath is None

        network_models = None
        if decoder_fpath_check and encoder_fpath_check:
            network_models = tuple()
        elif decoder_fpath_check or encoder_fpath_check:
            # parser.error() prints the message and exits; it never returns,
            # so there is nothing to raise here.
            self._parser.error(
                "--onnx-decoder-fpath and --onnx-encoder-fpath must be given together; omit both to have the script download them."
            )
        else:
            onnx_decoder = NetworkModel(
                name=T5ModelTRTConfig.NETWORK_DECODER_SEGMENT_NAME,
                fpath=args.onnx_decoder_fpath,
            )
            onnx_encoder = NetworkModel(
                name=T5ModelTRTConfig.NETWORK_ENCODER_SEGMENT_NAME,
                fpath=args.onnx_encoder_fpath,
            )
            network_models = (onnx_decoder, onnx_encoder)

        return network_models
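The both-or-neither validation leans on argparse's parser.error(), which prints the usage string plus the message and then exits the process, so it never returns and there is nothing to raise. A self-contained sketch of the same check (the flag names are reused from the example; everything else is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--onnx-decoder-fpath")
parser.add_argument("--onnx-encoder-fpath")
args = parser.parse_args()

given = (args.onnx_decoder_fpath is not None,
         args.onnx_encoder_fpath is not None)
if any(given) and not all(given):
    # error() prints usage plus this message, then exits with status 2.
    parser.error("--onnx-decoder-fpath and --onnx-encoder-fpath must be "
                 "given together, or both omitted.")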
Example #5
    def generate_and_download_framework(
            self, metadata: NetworkMetadata,
            workspace: NNFolderWorkspace) -> NetworkModels:

        cache_variant = False
        if metadata.other.kv_cache:
            cache_variant = True

        trt_t5_config = self.config
        metadata_serialized = trt_t5_config.get_metadata_string(metadata)
        workspace_dir = workspace.get_path()

        pytorch_model_dir = os.path.join(workspace_dir, metadata_serialized)
        # We keep track of the generated torch location for cleanup later
        self.torch_t5_dir = pytorch_model_dir

        if not os.path.exists(pytorch_model_dir):
            # Download the pre-trained weights and cache them locally.
            # As with GPT-2, from_pretrained is a classmethod, so the config
            # overrides are passed as kwargs rather than set on a throwaway
            # T5ForConditionalGeneration(tfm_config) instance.
            model = T5ForConditionalGeneration.from_pretrained(
                metadata.variant,
                use_cache=cache_variant,
                num_layers=T5ModelTRTConfig.NUMBER_OF_LAYERS[metadata.variant],
            )
            model.save_pretrained(pytorch_model_dir)
            print("Pytorch Model saved to {}".format(pytorch_model_dir))
        else:
            print(
                "Frameworks file already exists, skipping generation and loading from file instead."
            )
            model = T5ForConditionalGeneration.from_pretrained(
                pytorch_model_dir,
                use_cache=cache_variant,
                num_layers=T5ModelTRTConfig.NUMBER_OF_LAYERS[metadata.variant],
            )

        # These ONNX models can be converted using special encoder and decoder classes.
        # Build the sibling paths from the extensionless base name so the
        # files come out as "<metadata>-encoder.onnx" rather than
        # "<metadata>.onnx-encoder.onnx".
        onnx_model_base_fpath = os.path.join(os.getcwd(), workspace_dir,
                                             metadata_serialized)
        encoder_onnx_model_fpath = onnx_model_base_fpath + "-encoder.onnx"
        decoder_onnx_model_fpath = onnx_model_base_fpath + "-decoder-with-lm-head.onnx"

        t5_encoder = T5EncoderTorchFile(model, metadata)
        t5_decoder = T5DecoderTorchFile(model, metadata)
        self.onnx_t5_encoder = t5_encoder.as_onnx_model(
            encoder_onnx_model_fpath, force_overwrite=False)
        self.onnx_t5_decoder = t5_decoder.as_onnx_model(
            decoder_onnx_model_fpath, force_overwrite=False)

        onnx_models = [
            NetworkModel(
                name=T5ModelTRTConfig.NETWORK_DECODER_SEGMENT_NAME,
                fpath=self.onnx_t5_decoder.fpath,
            ),
            NetworkModel(
                name=T5ModelTRTConfig.NETWORK_ENCODER_SEGMENT_NAME,
                fpath=self.onnx_t5_encoder.fpath,
            ),
        ]
        torch_models = [
            NetworkModel(name=T5ModelTRTConfig.NETWORK_FULL_NAME,
                         fpath=pytorch_model_dir)
        ]

        return NetworkModels(torch=torch_models, onnx=onnx_models, trt=None)
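T5EncoderTorchFile and T5DecoderTorchFile are demo-specific wrappers, but an as_onnx_model-style conversion presumably reduces to torch.onnx.export with dynamic batch and sequence axes. A rough sketch under that assumption (the wrapper class, axis names, output path, and opset are illustrative, not the demo's actual export code):

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

class T5EncoderWrapper(torch.nn.Module):
    # Thin wrapper so the traced graph has a single tensor output.
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids):
        return self.encoder(input_ids=input_ids).last_hidden_state

model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()
tokenizer = T5Tokenizer.from_pretrained("t5-small")
dummy_ids = tokenizer("translate English to German: hello",
                      return_tensors="pt").input_ids

torch.onnx.export(
    T5EncoderWrapper(model.encoder),
    (dummy_ids,),
    "t5-small-encoder.onnx",
    input_names=["input_ids"],
    output_names=["hidden_states"],
    dynamic_axes={"input_ids": {0: "batch", 1: "sequence"},
                  "hidden_states": {0: "batch", 1: "sequence"}},
    opset_version=13,
)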
Example #6
    def execute_inference(
        self,
        metadata: NetworkMetadata,
        onnx_fpaths: Dict[str, NetworkModel],
        inference_input: str,
        timing_profile: TimingProfile,
    ) -> NetworkResult:

        tokenizer = T5Tokenizer.from_pretrained(metadata.variant)
        input_ids = tokenizer(inference_input, return_tensors="pt").input_ids
        encoder_last_hidden_state, encoder_e2e_median_time = encoder_inference(
            self.t5_trt_encoder, input_ids, timing_profile
        )
        _, decoder_e2e_median_time = decoder_inference(
            self.t5_trt_decoder,
            input_ids,
            encoder_last_hidden_state,
            timing_profile,
            use_cuda=False,
        )
        decoder_output_greedy, full_e2e_median_runtime = full_inference_greedy(
            self.t5_trt_encoder,
            self.t5_trt_decoder,
            input_ids,
            tokenizer,
            timing_profile,
            max_length=T5ModelTRTConfig.MAX_SEQUENCE_LENGTH[metadata.variant],
            use_cuda=False,
        )

        # Remove the padding and end tokens.
        semantic_outputs = tokenizer.convert_ids_to_tokens(
            decoder_output_greedy.tolist()[0]
        )[1:-1]
        remove_underscore = "".join(
            [s.replace("\u2581", " ") for s in semantic_outputs]
        )

        return NetworkResult(
            input=inference_input,
            output_tensor=encoder_last_hidden_state,
            semantic_output=remove_underscore.strip(),
            median_runtime=[
                NetworkRuntime(
                    name=T5ModelTRTConfig.NETWORK_DECODER_SEGMENT_NAME,
                    runtime=decoder_e2e_median_time,
                ),
                NetworkRuntime(
                    name=T5ModelTRTConfig.NETWORK_ENCODER_SEGMENT_NAME,
                    runtime=encoder_e2e_median_time,
                ),
                NetworkRuntime(
                    name=T5ModelTRTConfig.NETWORK_FULL_NAME,
                    runtime=full_e2e_median_runtime,
                ),
            ],
            models=NetworkModels(
                torch=None,
                onnx=list(onnx_fpaths.values()),
                trt=[
                    NetworkModel(
                        name=T5ModelTRTConfig.NETWORK_DECODER_SEGMENT_NAME,
                        fpath=self.t5_trt_decoder_engine.fpath,
                    ),
                    NetworkModel(
                        name=T5ModelTRTConfig.NETWORK_ENCODER_SEGMENT_NAME,
                        fpath=self.t5_trt_encoder_engine.fpath,
                    ),
                ],
            ),
        )
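The post-processing step relies on the SentencePiece convention that "\u2581" marks the start of a word. A tiny standalone illustration of that detokenization (the token list is made up):

# Tokens as a SentencePiece tokenizer emits them, with "\u2581" marking
# word starts.
tokens = ["\u2581Das", "\u2581Haus", "\u2581ist", "\u2581wunderbar", "."]

detokenized = "".join(t.replace("\u2581", " ") for t in tokens).strip()
print(detokenized)  # "Das Haus ist wunderbar."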