Ejemplo n.º 1
0
 def export(cls, model, input_args, save_path, **export_kwargs):
     """Trace ``model`` through a ``TracingAdapter`` and save it as TorchScript.

     The model is wrapped in ``TracingAdapter(model, input_args)``. The adapter
     (not the raw model) is traced using its ``flattened_inputs`` and written
     to ``save_path`` by ``trace_and_save_torchscript``.

     Args:
         cls: Not used by this body.
         model: Model to wrap and export.
         input_args: Example inputs used to build the adapter.
         save_path: Destination forwarded to ``trace_and_save_torchscript``.
         **export_kwargs: Forwarded unchanged to ``trace_and_save_torchscript``.

     Returns:
         dict with keys ``"inputs_schema"`` and ``"outputs_schema"``, holding
         ``dump_dataclass`` of the adapter's input and output schemas.
         Presumably a caller needs these to rebuild structured inputs/outputs
         around the flattened traced module; confirm against the caller.
     """
     tracing_adapter = TracingAdapter(model, input_args)
     trace_and_save_torchscript(
         tracing_adapter,
         tracing_adapter.flattened_inputs,
         save_path,
         **export_kwargs,
     )
     # Dict literal entries are evaluated in order, so inputs are dumped
     # before outputs.
     return {
         "inputs_schema": dump_dataclass(tracing_adapter.inputs_schema),
         "outputs_schema": dump_dataclass(tracing_adapter.outputs_schema),
     }
Ejemplo n.º 2
0
 def export(cls, model, input_args, save_path, **export_kwargs):
     """Trace ``model`` and save it as TorchScript with mobile optimization.

     A fresh ``MobileOptimizationConfig()`` is always passed as the
     ``mobile_optimization`` keyword. If ``export_kwargs`` also contains
     ``mobile_optimization``, Python raises ``TypeError`` for the duplicate
     keyword argument.

     Args:
         cls: Not used by this body.
         model: Model to trace and save.
         input_args: Example inputs forwarded to ``trace_and_save_torchscript``.
         save_path: Destination forwarded to ``trace_and_save_torchscript``.
         **export_kwargs: Extra keyword arguments, forwarded unchanged.

     Returns:
         An empty dict; this exporter produces no extra metadata.
     """
     trace_and_save_torchscript(
         model,
         input_args,
         save_path,
         mobile_optimization=MobileOptimizationConfig(),
         **export_kwargs,
     )
     no_metadata = {}
     return no_metadata
Ejemplo n.º 3
0
Archivo: api.py Proyecto: zhiqwang/d2go
def standard_model_export(model, model_type, save_path, input_args, **kwargs):
    """Export ``model`` to ``save_path`` in the format selected by ``model_type``.

    Args:
        model: The model to export.
        model_type: Export format. Any string starting with ``"torchscript"``
            selects TorchScript export; exactly ``"caffe2"`` selects Caffe2.
        save_path: Destination forwarded to the selected backend exporter.
        input_args: Example inputs. The TorchScript path receives all of
            them; the Caffe2 path receives only ``input_args[0]``.
        **kwargs: Forwarded unchanged to the selected backend exporter.

    Returns:
        None.

    Raises:
        NotImplementedError: If ``model_type`` matches neither format.
    """
    # Each backend is imported inside its own branch, so a backend module is
    # only imported when its format is actually requested.
    if model_type.startswith("torchscript"):
        from d2go.export.torchscript import trace_and_save_torchscript

        trace_and_save_torchscript(model, input_args, save_path, **kwargs)
        return

    if model_type != "caffe2":
        raise NotImplementedError("Incorrect model_type: {}".format(model_type))

    from d2go.export.caffe2 import export_caffe2

    # TODO: export_caffe2 depends on D2; make a copy of the implementation.
    # TODO: support specifying the optimization pass via kwargs.
    export_caffe2(model, input_args[0], save_path, **kwargs)
Ejemplo n.º 4
0
    def export(cls, model, input_args, save_path, **export_kwargs):
        """Trace ``model`` and save it as TorchScript at ``save_path``.

        Args:
            cls: Not used by this body.
            model: Model to trace and save.
            input_args: Example inputs forwarded to the saver.
            save_path: Destination forwarded to the saver.
            **export_kwargs: Extra keyword arguments, forwarded unchanged.

        Returns:
            An empty dict; this exporter produces no extra metadata.
        """
        # Imported at call time rather than when the enclosing module loads.
        from d2go.export.torchscript import (
            trace_and_save_torchscript as save_torchscript,
        )

        save_torchscript(model, input_args, save_path, **export_kwargs)
        no_metadata = {}
        return no_metadata
Ejemplo n.º 5
0
 def export(cls, model, input_args, save_path, **export_kwargs):
     """Trace ``model`` and save it as TorchScript at ``save_path``.

     Args:
         cls: Not used by this body.
         model: Model to trace and save.
         input_args: Example inputs forwarded to ``trace_and_save_torchscript``.
         save_path: Destination forwarded to ``trace_and_save_torchscript``.
         **export_kwargs: Extra keyword arguments, forwarded unchanged.

     Returns:
         An empty dict; this exporter produces no extra metadata.
     """
     trace_and_save_torchscript(
         model,
         input_args,
         save_path,
         **export_kwargs,
     )
     no_metadata = {}
     return no_metadata