コード例 #1
0
 def export(cls, model, input_args, save_path, **export_kwargs):
     """Wrap ``model`` in a TracingAdapter, trace/save it, and return its schemas.

     The adapter is built from ``model`` and the example ``input_args``; the
     adapter (not the raw model) is handed to ``trace_and_save_torchscript``
     together with ``adapter.flattened_inputs``.

     Args:
         model: the model to export.
         input_args: example inputs used to build the adapter.
         save_path: destination forwarded to ``trace_and_save_torchscript``.
         **export_kwargs: forwarded unchanged to ``trace_and_save_torchscript``.

     Returns:
         A dict with ``"inputs_schema"`` and ``"outputs_schema"``, each the
         ``dump_dataclass`` form of the corresponding adapter schema
         (presumably what a loader needs to restore the original input/output
         structure -- confirm against the loading side).
     """
     wrapped = TracingAdapter(model, input_args)
     trace_and_save_torchscript(
         wrapped, wrapped.flattened_inputs, save_path, **export_kwargs
     )
     # Keyword arguments are evaluated left to right, so the inputs schema is
     # dumped before the outputs schema, as before.
     return dict(
         inputs_schema=dump_dataclass(wrapped.inputs_schema),
         outputs_schema=dump_dataclass(wrapped.outputs_schema),
     )
コード例 #2
0
    def new_f(cls, model, input_args, save_path, export_method,
              **export_kwargs):
        """Export ``model`` via ``old_f``, wrapping it in a TracingAdapter only if needed.

        The adapter is bypassed entirely (``old_f`` is called on the raw model)
        when ``force_disable_tracing_adapter=True`` is passed or when
        ``jit_mode`` is not ``"trace"``. Otherwise:

        * if ``input_args`` and the outputs of a dry run ``model(*input_args)``
          both satisfy ``_is_data_flattened_tensors``, the model is exported
          as is and the result is tagged ``tracing_adapted=False``;
        * else the model is wrapped in ``TracingAdapter``, exported with the
          adapter's ``flattened_inputs``, and the result is tagged
          ``tracing_adapted=True`` together with the dumped input/output
          schemas.

        Args:
            model: the model to export.
            input_args: example inputs; also used for the dry run.
            save_path: forwarded to ``old_f``.
            export_method: forwarded to ``old_f``.
            **export_kwargs: forwarded to ``old_f`` after removing
                ``force_disable_tracing_adapter``. ``jit_mode`` (default
                ``"trace"``) is read, not removed.

        Returns:
            Whatever ``old_f`` returns. In the trace paths it is updated in
            place with ``"tracing_adapted"`` (and, when the adapter is used,
            ``"inputs_schema"``/``"outputs_schema"``) before being returned.
        """
        # Consume the control flag so it is not forwarded to ``old_f``.
        force_disable_tracing_adapter = export_kwargs.pop(
            "force_disable_tracing_adapter", False)
        is_trace_mode = export_kwargs.get("jit_mode", "trace") == "trace"
        if force_disable_tracing_adapter or not is_trace_mode:
            # Log the actual reason the adapter is bypassed: a forced disable
            # can happen while still in trace mode.
            if force_disable_tracing_adapter:
                logger.info("TracingAdapter is force disabled, export normally")
            else:
                logger.info("Not trace mode, export normally")
            return old_f(cls, model, input_args, save_path, export_method,
                         **export_kwargs)

        if _is_data_flattened_tensors(input_args):
            logger.info(
                "Dry run the model to check if TracingAdapter is needed ...")
            # Run the model once on the example inputs to inspect the
            # structure of its outputs.
            outputs = model(*input_args)
            if _is_data_flattened_tensors(outputs):
                logger.info(
                    "Both inputs and outputs are flattened tensors, export the model as is."
                )
                load_kwargs = old_f(cls, model, input_args, save_path,
                                    export_method, **export_kwargs)
                # ``old_f`` must not already set the key this wrapper owns.
                assert "tracing_adapted" not in load_kwargs
                load_kwargs.update({"tracing_adapted": False})
                return load_kwargs
            else:
                logger.info(
                    "The outputs are not flattened tensors, can't trace normally."
                )
        else:
            logger.info(
                "The inputs are not flattened tensors, can't trace normally.")

        logger.warning(
            "Wrap the model with TracingAdapter to handle non-flattened inputs/outputs,"
            " please be aware that the exported model will have different input/output data structure."
        )
        adapter = TracingAdapter(model, input_args)
        load_kwargs = old_f(
            cls,
            adapter,
            adapter.flattened_inputs,
            save_path,
            export_method,
            **export_kwargs,
        )
        inputs_schema = dump_dataclass(adapter.inputs_schema)
        outputs_schema = dump_dataclass(adapter.outputs_schema)
        # These keys are owned by this wrapper; ``old_f`` must not set them.
        assert "tracing_adapted" not in load_kwargs
        assert "inputs_schema" not in load_kwargs
        assert "outputs_schema" not in load_kwargs
        load_kwargs.update({
            "tracing_adapted": True,
            "inputs_schema": inputs_schema,
            "outputs_schema": outputs_schema,
        })
        return load_kwargs
コード例 #3
0
 def new_f(cls, model, input_args, *args, **kwargs):
     """Run ``old_f`` on a TracingAdapter-wrapped model and attach its schemas.

     ``model`` is wrapped together with ``input_args``; ``old_f`` receives the
     adapter and ``adapter.flattened_inputs`` in place of the original model
     and inputs, plus any extra positional/keyword arguments unchanged.

     Returns:
         The mapping returned by ``old_f``, updated in place with
         ``"inputs_schema"`` and ``"outputs_schema"`` (the ``dump_dataclass``
         form of the adapter's schemas).
     """
     wrapped = TracingAdapter(model, input_args)
     result = old_f(cls, wrapped, wrapped.flattened_inputs, *args, **kwargs)
     schemas = {
         "inputs_schema": dump_dataclass(wrapped.inputs_schema),
         "outputs_schema": dump_dataclass(wrapped.outputs_schema),
     }
     # These keys are added here, so ``old_f`` must not have produced them.
     for key in schemas:
         assert key not in result
     result.update(schemas)
     return result
コード例 #4
0
    def _check_schema(self, schema):
        """Assert that ``schema`` round-trips through dump -> JSON -> instantiate.

        Fails the test if the dumped form is not JSON-serializable or if
        instantiating the dumped form does not give back an equal schema.
        """
        serialized = dump_dataclass(schema)
        # json.dumps raises for non-serializable content; the resulting string
        # itself is not needed. (YAML may be more readable for deeply nested
        # schemas, but JSON is the stricter check.)
        json.dumps(serialized)
        # Rebuilding from the dumped form must reproduce the original schema.
        self.assertEqual(schema, instantiate(serialized))