def torchscriptify(self, tensorizers, traced_model):
    """Wrap ``traced_model`` in a ScriptModule that exposes a friendlier API.

    The returned module can generate tensors from simple data types by using
    the TorchScript version of the ``"tokens"`` tensorizer. Which wrapper is
    built depends on the configuration:

    * ``self.encoder.export_encoder`` is truthy: a
      ``ScriptPyTextEmbeddingModuleIndex`` constructed with ``index=0``
      (no output layer is attached).
    * Both ``"right_dense"`` and ``"left_dense"`` tensorizers are present:
      a ``ScriptPyTextTwoTowerModuleWithDense`` that also receives each
      dense tensorizer's ``normalizer``.
    * Otherwise: a ``ScriptPyTextModule``.

    In the last two cases the output layer's TorchScript predictions module
    is attached, so results are reported according to the output layer
    (e.g. as a dict mapping class name to score).

    Args:
        tensorizers: Mapping from tensorizer name to tensorizer; must
            contain ``"tokens"``.
        traced_model: The traced model to wrap.

    Returns:
        The wrapping ScriptModule described above.
    """
    token_tensorizer = tensorizers["tokens"].torchscriptify()

    # Encoder-only export short-circuits before the output layer is touched.
    if self.encoder.export_encoder:
        return ScriptPyTextEmbeddingModuleIndex(
            traced_model, token_tensorizer, index=0
        )

    # Dense features are only wired in when BOTH sides are available;
    # otherwise the plain text module is used.
    has_both_dense = "right_dense" in tensorizers and "left_dense" in tensorizers
    predictions = self.output_layer.torchscript_predictions()

    if not has_both_dense:
        return ScriptPyTextModule(
            model=traced_model,
            output_layer=predictions,
            tensorizer=token_tensorizer,
        )

    return ScriptPyTextTwoTowerModuleWithDense(
        model=traced_model,
        output_layer=predictions,
        tensorizer=token_tensorizer,
        right_normalizer=tensorizers["right_dense"].normalizer,
        left_normalizer=tensorizers["left_dense"].normalizer,
    )
def torchscriptify(self, tensorizers, traced_model):
    """Wrap ``traced_model`` in a ``ScriptPyTextModule`` with a simpler API.

    The wrapper can generate tensors from simple data types by using the
    TorchScript version of the ``"tokens"`` tensorizer. It reports results
    according to this model's output layer (e.g. as a dict mapping class
    name to score).

    Args:
        tensorizers: Mapping from tensorizer name to tensorizer; must
            contain ``"tokens"``.
        traced_model: The traced model to wrap.

    Returns:
        A ``ScriptPyTextModule`` combining the traced model, the output
        layer's TorchScript predictions module and the scripted tensorizer.
    """
    token_tensorizer = tensorizers["tokens"].torchscriptify()
    predictions = self.output_layer.torchscript_predictions()
    return ScriptPyTextModule(
        model=traced_model,
        output_layer=predictions,
        tensorizer=token_tensorizer,
    )