def torchscriptify(self, tensorizers, traced_model):
    """Wrap ``traced_model`` in a ScriptModule with a friendlier inference API.

    The returned module tensorizes simple input types itself, using the
    TorchScript version of the ``"tokens"`` tensorizer. Unless only the
    encoder is exported, it also post-processes outputs through the output
    layer's ``torchscript_predictions()`` module. Per the original design,
    that module produces e.g. a mapping from class name to score.

    Args:
        tensorizers: Mapping from tensorizer name to tensorizer. ``"tokens"``
            is always read. ``"right_dense"`` and ``"left_dense"`` are read
            only when both are present and the full model is exported.
        traced_model: The traced model to wrap.

    Returns:
        One of three modules:
        - ``ScriptPyTextEmbeddingModuleIndex`` (index 0) when
          ``self.encoder.export_encoder`` is set.
        - ``ScriptPyTextTwoTowerModuleWithDense`` when both dense tensorizers
          exist.
        - ``ScriptPyTextModule`` otherwise.
    """
    token_tensorizer = tensorizers["tokens"].torchscriptify()

    # Encoder-only export: no output layer, just select output index 0.
    if self.encoder.export_encoder:
        return ScriptPyTextEmbeddingModuleIndex(traced_model, token_tensorizer, index=0)

    uses_dense = "right_dense" in tensorizers and "left_dense" in tensorizers
    predictions = self.output_layer.torchscript_predictions()

    if not uses_dense:
        return ScriptPyTextModule(
            model=traced_model,
            output_layer=predictions,
            tensorizer=token_tensorizer,
        )

    # Both dense feature tensorizers exist: forward their normalizers too.
    return ScriptPyTextTwoTowerModuleWithDense(
        model=traced_model,
        output_layer=predictions,
        tensorizer=token_tensorizer,
        right_normalizer=tensorizers["right_dense"].normalizer,
        left_normalizer=tensorizers["left_dense"].normalizer,
    )
def torchscriptify(self, tensorizers, traced_model, trace_both_encoders):
    """Build a TorchScript wrapper around ``traced_model`` that tensorizes raw inputs.

    Args:
        tensorizers: Mapping from tensorizer name to tensorizer.
            ``"tokens1"`` is always read. ``"tokens2"`` is read only when
            ``trace_both_encoders`` is true. ``"dense"`` is used, if present,
            only in the single-encoder path.
        traced_model: The traced model to wrap. In the two-encoder path it is
            called with the tensorized first and second inputs.
        trace_both_encoders: When true, return a ScriptModule whose
            ``forward`` accepts two inputs, one per encoder. Otherwise wrap a
            single encoder and select its output at ``index=0``.

    Returns:
        A ScriptModule wrapping ``traced_model``.
    """
    if not trace_both_encoders:
        # Only one encoder was traced; its inputs are tensorized via "tokens1".
        single_tensorizer = tensorizers["tokens1"].torchscriptify()
        if "dense" not in tensorizers:
            return ScriptPyTextEmbeddingModuleIndex(
                model=traced_model, tensorizer=single_tensorizer, index=0
            )
        return ScriptPyTextEmbeddingModuleWithDenseIndex(
            model=traced_model,
            tensorizer=single_tensorizer,
            normalizer=tensorizers["dense"].normalizer,
            index=0,
        )

    class ScriptModel(torch.jit.ScriptModule):
        # Scripted wrapper: builds one ScriptBatchInput per input, tensorizes
        # each with its own tensorizer, and passes both tensor batches to the
        # traced model.

        def __init__(self, model, tensorizer1, tensorizer2):
            super().__init__()
            self.model = model
            self.tensorizer1 = tensorizer1
            self.tensorizer2 = tensorizer2

        @torch.jit.script_method
        def forward(
            self,
            # first input
            texts1: Optional[List[str]] = None,
            tokens1: Optional[List[List[str]]] = None,
            # second input
            texts2: Optional[List[str]] = None,
            tokens2: Optional[List[List[str]]] = None,
        ):
            # NOTE(review): squeeze_1d/squeeze_2d are defined elsewhere;
            # their exact handling of None/nesting is not visible here.
            first_batch: ScriptBatchInput = ScriptBatchInput(
                texts=squeeze_1d(texts1),
                tokens=squeeze_2d(tokens1),
                languages=None,
            )
            second_batch: ScriptBatchInput = ScriptBatchInput(
                texts=squeeze_1d(texts2),
                tokens=squeeze_2d(tokens2),
                languages=None,
            )
            first_tensors = self.tensorizer1(first_batch)
            second_tensors = self.tensorizer2(second_batch)
            return self.model(first_tensors, second_tensors)

    # Arguments are evaluated left to right: "tokens1" is scripted before "tokens2".
    return ScriptModel(
        traced_model,
        tensorizers["tokens1"].torchscriptify(),
        tensorizers["tokens2"].torchscriptify(),
    )