def torchscriptify(self, tensorizers, traced_model):
    """Wrap the traced model in a ScriptModule with a friendlier API.

    The wrapper builds tensors from simple data types via the scripted
    ``"tokens"`` tensorizer. Which wrapper is returned depends on the model:

    * ``self.encoder.export_encoder`` is truthy: an embedding wrapper
      created with ``index=0``;
    * both ``"right_dense"`` and ``"left_dense"`` tensorizers are present:
      a two-tower wrapper that also receives their normalizers;
    * otherwise: the plain module wrapper.

    The last two wrappers additionally receive the scripted predictions of
    ``self.output_layer`` (e.g. to return a dict mapping class name to score).

    Args:
        tensorizers: mapping of tensorizer name to tensorizer; must contain
            ``"tokens"``, may contain ``"right_dense"`` and ``"left_dense"``.
        traced_model: the traced model the wrapper delegates to.

    Returns:
        One of the script wrapper modules described above.
    """
    token_script = tensorizers["tokens"].torchscriptify()

    # Encoder-only export: no output layer is attached to the wrapper.
    if self.encoder.export_encoder:
        return ScriptPyTextEmbeddingModuleIndex(
            traced_model, token_script, index=0
        )

    # Without *both* dense tensorizers, fall back to the plain wrapper.
    if "right_dense" not in tensorizers or "left_dense" not in tensorizers:
        return ScriptPyTextModule(
            model=traced_model,
            output_layer=self.output_layer.torchscript_predictions(),
            tensorizer=token_script,
        )

    return ScriptPyTextTwoTowerModuleWithDense(
        model=traced_model,
        output_layer=self.output_layer.torchscript_predictions(),
        tensorizer=token_script,
        right_normalizer=tensorizers["right_dense"].normalizer,
        left_normalizer=tensorizers["left_dense"].normalizer,
    )
Esempio n. 2
0
    def torchscriptify(self, tensorizers, traced_model, trace_both_encoders):
        """Wrap ``traced_model`` in a script module that tensorizes raw inputs.

        When ``trace_both_encoders`` is falsy, only the ``"tokens1"``
        tensorizer is scripted and an embedding wrapper created with
        ``index=0`` is returned; the dense-aware wrapper is chosen when a
        ``"dense"`` tensorizer is present. Otherwise a locally defined
        ``ScriptModel`` is returned, which tensorizes two separate inputs with
        the ``"tokens1"`` and ``"tokens2"`` tensorizers and passes both
        results to the model.

        Args:
            tensorizers: mapping of tensorizer name to tensorizer; must
                contain ``"tokens1"`` (and ``"tokens2"`` when both encoders
                are traced), may contain ``"dense"``.
            traced_model: the traced model the wrapper delegates to.
            trace_both_encoders: truthy to export the two-input wrapper,
                falsy to export a single-input embedding wrapper.

        Returns:
            The script module wrapping ``traced_model``.
        """
        if not trace_both_encoders:
            # optionally trace only one encoder
            first_tensorizer = tensorizers["tokens1"].torchscriptify()
            if "dense" not in tensorizers:
                return ScriptPyTextEmbeddingModuleIndex(
                    model=traced_model, tensorizer=first_tensorizer, index=0
                )
            return ScriptPyTextEmbeddingModuleWithDenseIndex(
                model=traced_model,
                tensorizer=first_tensorizer,
                normalizer=tensorizers["dense"].normalizer,
                index=0,
            )

        class ScriptModel(torch.jit.ScriptModule):
            # Two-input wrapper: one tensorizer per input, one shared model.
            def __init__(self, model, tensorizer1, tensorizer2):
                super().__init__()
                self.model = model
                self.tensorizer1 = tensorizer1
                self.tensorizer2 = tensorizer2

            @torch.jit.script_method
            def forward(
                self,
                # first input
                texts1: Optional[List[str]] = None,
                tokens1: Optional[List[List[str]]] = None,
                # second input
                texts2: Optional[List[str]] = None,
                tokens2: Optional[List[List[str]]] = None,
            ):
                # Each input may arrive as raw texts or as pre-split tokens;
                # languages are never supplied.
                first_batch: ScriptBatchInput = ScriptBatchInput(
                    texts=squeeze_1d(texts1),
                    tokens=squeeze_2d(tokens1),
                    languages=None,
                )
                second_batch: ScriptBatchInput = ScriptBatchInput(
                    texts=squeeze_1d(texts2),
                    tokens=squeeze_2d(tokens2),
                    languages=None,
                )
                return self.model(
                    self.tensorizer1(first_batch),
                    self.tensorizer2(second_batch),
                )

        return ScriptModel(
            traced_model,
            tensorizers["tokens1"].torchscriptify(),
            tensorizers["tokens2"].torchscriptify(),
        )