Example #1
def reinitialize_ortmodule(ortmodule):
    # Re-register contrib OPs
    pytorch_export_contrib_ops.register()
    CustomOpSymbolicRegistry.register_all()
    CustomGradientRegistry.register_all()

    # Re-initialize the ORTModule forward method
    patch_ortmodule_forward_method(ortmodule)

    # Re-bind the user's custom methods to the ORTModule
    check_for_name_collisions_and_bind_methods_to_ortmodule(
        ortmodule, ortmodule.module)
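
A hedged usage sketch: this helper restores state that is lost when an ORTModule instance is recreated outside of __init__ (for example, when restored from a pickle). The wrapped model is a placeholder, and the import path is assumed:

import torch
from onnxruntime.training import ORTModule  # assumed import path

ort_model = ORTModule(torch.nn.Linear(4, 4))
# After any operation that bypasses __init__ (e.g. restoring from a pickle),
# bring the instance back to a usable state:
reinitialize_ortmodule(ort_model)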
Example #2
def setUp(self):
    torch.manual_seed(0)
    pytorch_export_contrib_ops.register()
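
For context, a sketch of the kind of unittest.TestCase this setUp might belong to; the test class, model, and test body are placeholders, and the import path for pytorch_export_contrib_ops is assumed:

import unittest

import torch
from onnxruntime.tools import pytorch_export_contrib_ops  # assumed import path


class ContribOpExportTests(unittest.TestCase):  # hypothetical test class
    def setUp(self):
        torch.manual_seed(0)                   # make weights/inputs deterministic
        pytorch_export_contrib_ops.register()  # map contrib ops for torch.onnx export

    def test_export_runs(self):
        model = torch.nn.Linear(4, 2)
        torch.onnx.export(model, torch.randn(1, 4), "model.onnx")


if __name__ == "__main__":
    unittest.main()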
Example #3
    def __init__(self, module, debug_options=None):

        # NOTE: torch.nn.Modules that regularly call setattr on their internal attributes
        #       (for example, PyTorch Lightning) will trigger frequent re-exports. This is
        #       because ORTModule auto-detects such setattr calls on the original module and
        #       marks the model as stale. This is a known limitation. To disable repeated
        #       re-export checks when they are not required, set the environment variable
        #       ORTMODULE_SKIPCHECK_POLICY to SKIP_CHECK_BUILD_GRADIENT|SKIP_CHECK_EXECUTION_AGENT
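        #       For example (hypothetical placement; the variable must be set before
        #       ORTModule is constructed):
        #           os.environ["ORTMODULE_SKIPCHECK_POLICY"] = \
        #               "SKIP_CHECK_BUILD_GRADIENT|SKIP_CHECK_EXECUTION_AGENT"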

        # Set the _is_initialized attribute first; it starts off as False.
        # Its name is string-compared against incoming attribute names in __setattr__
        # and __getattr__ to decide how attribute access is routed.
        # NOTE: Do not rename/move.
        self._is_initialized = False
        # Python default arguments are evaluated at function definition time,
        # not at invocation time. So, if debug_options is not provided,
        # instantiate it inside the function.
        if not debug_options:
            debug_options = DebugOptions()
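        # (For example, a signature default of `debug_options=DebugOptions()` would
        # create one shared DebugOptions instance at definition time, reused by
        # every call that omits the argument.)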

        # Fallback settings
        self._fallback_manager = _FallbackManager(
            pytorch_module=module,
            policy=ORTMODULE_FALLBACK_POLICY,
            retry=ORTMODULE_FALLBACK_RETRY)

        try:
            # Read ORTModule module initialization status
            if _FALLBACK_INIT_EXCEPTION:
                raise _FALLBACK_INIT_EXCEPTION

            super(ORTModule, self).__init__()

            self._torch_module = TorchModuleFactory()(module, debug_options,
                                                      self._fallback_manager)

            # Create forward dynamically, so each ORTModule instance will have its own copy.
            # This is needed so the forward signature can be copied from the original
            # PyTorch model, allowing different instances to have different signatures.
            def _forward(self, *inputs, **kwargs):
                '''Forward pass starts here and continues at `_ORTModuleFunction.forward`

                The ONNX model is exported the first time this method is executed.
                Next, we build a full training graph with module_gradient_graph_builder.
                Finally, we instantiate the ONNX Runtime InferenceSession.
                '''

                return self._torch_module.forward(*inputs, **kwargs)

            # Bind the forward method.
            self.forward = _forward.__get__(self)
            # Copy the forward signature and metadata from _torch_module's forward.
            functools.update_wrapper(self.forward.__func__,
                                     self._torch_module.forward.__func__)
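            # (`_forward.__get__(self)` uses the function descriptor protocol to
            # produce a method bound to this specific instance; update_wrapper then
            # copies the wrapped forward's metadata, including `__wrapped__`, so
            # that inspect.signature reports the original model's signature.)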

            # Support contrib OPs
            pytorch_export_contrib_ops.register()
            CustomOpSymbolicRegistry.register_all()
            CustomGradientRegistry.register_all()

            # Warn the user if there are name collisions between the user model's
            # attributes and ORTModule's attributes, and if there are custom methods
            # defined on the user's model, copy and bind them to ORTModule.
            _utils.check_for_name_collisions_and_bind_methods_to_ortmodule(
                self, module)

        except ORTModuleFallbackException as e:
            # Although the backend is switched to PyTorch here, it is up to
            # _FallbackManager to actually terminate execution or fall back
            _utils.switch_backend_to_pytorch(self, module)

            # Exceptions subject to fallback are handled here
            self._fallback_manager.handle_exception(
                exception=e, log_level=debug_options.logging.log_level)
        except Exception as e:
            # Although the backend is switched to PyTorch here, it is up to
            # _FallbackManager to actually terminate execution or fall back
            _utils.switch_backend_to_pytorch(self, module)

            # Catch-all FALLBACK_FORCE_TORCH_FORWARD fallback is handled here
            self._fallback_manager.handle_exception(
                exception=e,
                log_level=debug_options.logging.log_level,
                override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD)

        # Finally, ORTModule initialization is complete.
        # Assign self._is_initialized = True only after all ORTModule class attributes
        # have been assigned; otherwise, they would be assigned to
        # self._torch_module.original_module instead.
        self._is_initialized = True
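
A minimal usage sketch of the constructor above; the wrapped model and tensor shapes are placeholders, and the import path is assumed. The first forward call triggers the ONNX export and gradient-graph build described in the docstring of `_forward`:

import torch
from onnxruntime.training import ORTModule  # assumed import path

pt_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
ort_model = ORTModule(pt_model)         # runs the __init__ shown above
out = ort_model(torch.randn(2, 8))      # first call exports ONNX and builds the training graph
out.sum().backward()                    # gradients computed through ONNX Runtime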