def torchscriptify(self, tensorizers, traced_model):
     """Wrap ``traced_model`` in a script module chosen from the export setup.

     The "tokens" tensorizer is always converted through its own
     ``torchscriptify()`` and handed to whichever wrapper gets built:

     * ``self.encoder.export_encoder`` is truthy: a
       ``ScriptPyTextEmbeddingModuleIndex`` over the traced model and the
       scripted tensorizer, constructed with ``index=0``.
     * both "right_dense" and "left_dense" are keys of ``tensorizers``: a
       ``ScriptPyTextTwoTowerModuleWithDense`` which additionally receives
       the ``normalizer`` attribute of each dense tensorizer.
     * otherwise: a plain ``ScriptPyTextModule``.

     The last two wrappers also receive the result of
     ``self.output_layer.torchscript_predictions()`` as their output layer.
     """
     token_script = tensorizers["tokens"].torchscriptify()

     # Encoder export short-circuits: no output layer is attached here.
     if self.encoder.export_encoder:
         return ScriptPyTextEmbeddingModuleIndex(
             traced_model, token_script, index=0
         )

     # The dense variant needs both towers' tensorizers to be present.
     has_dense_pair = "right_dense" in tensorizers and "left_dense" in tensorizers
     prediction_layer = self.output_layer.torchscript_predictions()

     if not has_dense_pair:
         return ScriptPyTextModule(
             model=traced_model,
             output_layer=prediction_layer,
             tensorizer=token_script,
         )

     return ScriptPyTextTwoTowerModuleWithDense(
         model=traced_model,
         output_layer=prediction_layer,
         tensorizer=token_script,
         right_normalizer=tensorizers["right_dense"].normalizer,
         left_normalizer=tensorizers["left_dense"].normalizer,
     )
예제 #2
0
 def torchscriptify(self, tensorizers, traced_model):
     """Build a ``ScriptPyTextModule`` around ``traced_model``.

     The wrapper is assembled from three pieces: the traced model itself,
     the scripted form of the "tokens" tensorizer (obtained from its
     ``torchscriptify()`` method), and the output layer returned by
     ``self.output_layer.torchscript_predictions()``.
     """
     # Script the tokens tensorizer first, then fetch the prediction layer.
     token_script = tensorizers["tokens"].torchscriptify()
     prediction_layer = self.output_layer.torchscript_predictions()
     return ScriptPyTextModule(
         model=traced_model,
         output_layer=prediction_layer,
         tensorizer=token_script,
     )