def export(cls, model, input_args, save_path, **export_kwargs):
    """Export ``model`` through a ``TracingAdapter`` and report its schemas.

    The model is wrapped as ``TracingAdapter(model, input_args)``. The adapter
    and its ``flattened_inputs`` are then handed to
    ``trace_and_save_torchscript`` together with ``save_path`` and any extra
    ``export_kwargs``.

    Returns:
        A dict with two entries, ``"inputs_schema"`` and ``"outputs_schema"``.
        Each holds the ``dump_dataclass`` serialization of the matching adapter
        schema, so the flattened model can later be mapped back to the
        original input/output structure.
    """
    wrapped = TracingAdapter(model, input_args)
    trace_and_save_torchscript(
        wrapped, wrapped.flattened_inputs, save_path, **export_kwargs
    )
    # Dict literal values are evaluated left to right: inputs first, then outputs.
    return {
        "inputs_schema": dump_dataclass(wrapped.inputs_schema),
        "outputs_schema": dump_dataclass(wrapped.outputs_schema),
    }
def new_f(cls, model, input_args, save_path, export_method, **export_kwargs):
    """Export via ``old_f``, wrapping the model in ``TracingAdapter`` only if needed.

    Cases, in the order they are checked:

    * ``force_disable_tracing_adapter`` is truthy, or ``jit_mode`` is not
      ``"trace"`` (a missing ``jit_mode`` counts as ``"trace"``). ``old_f``
      is called unchanged and its result is returned as-is. The
      ``force_disable_tracing_adapter`` flag is popped, so ``old_f`` never
      sees it.
    * ``input_args`` passes ``_is_data_flattened_tensors`` and so does the
      result of one dry-run call ``model(*input_args)``. ``old_f`` is called
      on the original model, and the returned mapping is updated with
      ``tracing_adapted=False``.
    * Anything else. The model is wrapped in ``TracingAdapter(model,
      input_args)`` and ``old_f`` is called on the adapter with
      ``adapter.flattened_inputs``. The returned mapping is updated with
      ``tracing_adapted=True`` and with the ``dump_dataclass``-serialized
      ``inputs_schema`` and ``outputs_schema`` of the adapter.

    Returns:
        The value returned by ``old_f``. In the two tracing cases it is
        mutated in place with ``update``. Those cases ``assert`` that none of
        the keys being added are already present.
    """
    skip_adapter = export_kwargs.pop("force_disable_tracing_adapter", False)
    jit_mode_is_trace = export_kwargs.get("jit_mode", "trace") == "trace"
    if skip_adapter or not jit_mode_is_trace:
        logger.info("Not trace mode, export normally")
        return old_f(cls, model, input_args, save_path, export_method, **export_kwargs)

    if not _is_data_flattened_tensors(input_args):
        logger.info("The inputs are not flattened tensors, can't trace normally.")
    else:
        logger.info("Dry run the model to check if TracingAdapter is needed ...")
        # The model is actually invoked once here; only the structure of its
        # outputs is inspected to decide whether plain tracing is possible.
        dry_run_outputs = model(*input_args)
        if _is_data_flattened_tensors(dry_run_outputs):
            logger.info(
                "Both inputs and outputs are flattened tensors, export the model as is."
            )
            plain_result = old_f(
                cls, model, input_args, save_path, export_method, **export_kwargs
            )
            assert "tracing_adapted" not in plain_result
            plain_result.update({"tracing_adapted": False})
            return plain_result
        logger.info("The outputs are not flattened tensors, can't trace normally.")

    # Fallback: export a TracingAdapter wrapper instead of the raw model.
    logger.warning(
        "Wrap the model with TracingAdapter to handle non-flattened inputs/outputs,"
        " please be aware that the exported model will have different input/output data structure."
    )
    adapter = TracingAdapter(model, input_args)
    adapted_result = old_f(
        cls,
        adapter,
        adapter.flattened_inputs,
        save_path,
        export_method,
        **export_kwargs,
    )
    # Schemas are dumped (inputs, then outputs) before the key checks below,
    # matching the original evaluation order.
    extra_entries = {
        "tracing_adapted": True,
        "inputs_schema": dump_dataclass(adapter.inputs_schema),
        "outputs_schema": dump_dataclass(adapter.outputs_schema),
    }
    for key in extra_entries:
        assert key not in adapted_result
    adapted_result.update(extra_entries)
    return adapted_result
def new_f(cls, model, input_args, *args, **kwargs):
    """Always export a ``TracingAdapter``-wrapped model via ``old_f``.

    ``old_f`` receives the adapter and ``adapter.flattened_inputs`` in place
    of ``model`` and ``input_args``. All remaining positional and keyword
    arguments are forwarded untouched.

    Returns:
        The mapping returned by ``old_f``, updated in place with
        ``"inputs_schema"`` and ``"outputs_schema"``. Each value is the
        ``dump_dataclass`` serialization of the corresponding adapter schema.
        The function asserts that ``old_f`` did not already supply either key.
    """
    adapter = TracingAdapter(model, input_args)
    load_kwargs = old_f(cls, adapter, adapter.flattened_inputs, *args, **kwargs)
    schemas = {
        "inputs_schema": dump_dataclass(adapter.inputs_schema),
        "outputs_schema": dump_dataclass(adapter.outputs_schema),
    }
    for key in schemas:
        assert key not in load_kwargs
    load_kwargs.update(schemas)
    return load_kwargs
def _check_schema(self, schema):
    """Assert that ``schema`` survives a dump / JSON / instantiate round trip.

    ``schema`` is serialized with ``dump_dataclass``. The dumped form must be
    accepted by ``json.dumps``, and ``instantiate`` on it must rebuild an
    object equal to the original (checked with ``self.assertEqual``).
    """
    dumped = dump_dataclass(schema)
    # Only checks JSON-serializability; the encoded string is discarded.
    # YAML may be more readable in practice for deeply nested schemas.
    json.dumps(dumped)
    self.assertEqual(schema, instantiate(dumped))