Example #1
from contextlib import contextmanager

import torch


@contextmanager
def allow_non_deterministic():
    """Temporarily disable deterministic algorithms, restoring the previous state on exit."""
    prev_state = torch.are_deterministic_algorithms_enabled()
    try:
        torch.use_deterministic_algorithms(False)
        yield
    finally:
        torch.use_deterministic_algorithms(prev_state)
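A minimal usage sketch (illustrative, not from the source): with deterministic mode enabled globally, the flag is cleared inside the block and restored on exit.

torch.use_deterministic_algorithms(True)

with allow_non_deterministic():
    # Non-deterministic kernels are allowed to run here.
    assert not torch.are_deterministic_algorithms_enabled()

# The previous global state is restored on exit.
assert torch.are_deterministic_algorithms_enabled()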
Example #2
from typing import Optional

import torch
from torch import Tensor


def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
    """``torch.bincount`` currently does not support deterministic mode on GPU.

    This implementation falls back to a for loop that counts occurrences in that case.

    Args:
        x: tensor to count
        minlength: minimum length to count

    Returns:
        Number of occurrences for each unique element in x
    """
    if x.is_cuda and torch.are_deterministic_algorithms_enabled():
        if minlength is None:
            # Note: without an explicit ``minlength`` this assumes the values
            # in ``x`` are dense in ``[0, len(unique(x)))``.
            minlength = len(torch.unique(x))
        output = torch.zeros(minlength, device=x.device, dtype=torch.long)
        for i in range(minlength):
            output[i] = (x == i).sum()
        return output
    # ``torch.bincount`` expects an int, so map ``None`` to its default of 0.
    return torch.bincount(x, minlength=minlength if minlength is not None else 0)
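A brief usage sketch (values illustrative): on CPU, or whenever deterministic mode is off, the call dispatches straight to torch.bincount; only a CUDA tensor under deterministic mode takes the loop-based path.

x = torch.tensor([0, 1, 1, 3])

print(_bincount(x))               # tensor([1, 2, 0, 1])
print(_bincount(x, minlength=6))  # tensor([1, 2, 0, 1, 0, 0])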
Example #3
    def forward(self, *inputs, **kwargs):
        """Forward pass of the inference model

        ONNX model is exported the first time this method is executed.
        Next, we build an optimized inference graph with module_graph_builder.
        Finally, we instantiate the ONNX Runtime InferenceSession through the InferenceAgent.
        """

        # Fallback to PyTorch due to failures *external* to forward(),
        #  typically from initialization
        if self._fallback_manager.is_pending():
            return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs)

        try:
            # Issue at most one warning message about the fast path
            if (
                self._first_skip_check_warning is True
                and self._skip_check.is_disabled() is False
                and self._debug_options.logging.log_level <= _logger.LogLevel.WARNING
            ):
                self._first_skip_check_warning = False
                warnings.warn(
                    f"Fast path enabled - skipping checks."
                    f"rebuild gradient graph: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT)},"
                    f"execution agent recreation: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT)},"
                    f"device check: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE)}",
                    UserWarning,
                )

            # If exporting module to ONNX for the first time, this skip check will not take effect.
            # It will only take effect on subsequent forward calls.
            build_graph = False
            if (
                self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
                or not self._onnx_models.exported_model
            ):
                # Exporting module to ONNX for the first time
                build_graph = self._export_model(*inputs, **kwargs)
                if build_graph:
                    # If the model was exported, initialize the graph builder
                    # and build the inference graph.
                    self._initialize_graph_builder(training=False)
                    self._build_graph()

            # If creating the execution agent for the first time, this skip check will not take effect.
            # It will only take effect on subsequent forward calls.
            create_execution_session = False
            if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent:
                module_device = _utils.get_device_from_module(self._original_module)

                create_execution_session = (
                    build_graph
                    or self._device != module_device
                    or torch.are_deterministic_algorithms_enabled() is not _are_deterministic_algorithms_enabled()
                )
                _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled())

                if self._device != module_device:
                    self._device = module_device

            if create_execution_session:
                # Creating the execution agent instantiates the underlying inference session
                self._create_execution_agent()

            if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
                # Assert that the input and model device match
                _utils._check_same_device(self._device, "Input argument to forward", *inputs)

            user_outputs, _ = InferenceManager.execution_session_run_forward(
                self._execution_agent,
                self._onnx_models.optimized_model,
                self._device,
                *_io._combine_input_buffers_initializers(
                    self._graph_initializers,
                    self._graph_info.user_input_names,
                    self._input_info,
                    self._flattened_module.named_buffers(),
                    inputs,
                    kwargs,
                    self._device,
                ),
            )

            return _io.unflatten_user_output(self._module_output_schema, user_outputs)
        except ORTModuleFallbackException as e:
            # Exceptions subject to fallback are handled here
            self._fallback_manager.handle_exception(exception=e, log_level=self._debug_options.logging.log_level)
        except Exception as e:
            # Catch-all FALLBACK_FORCE_TORCH_FORWARD fallback is handled here
            self._fallback_manager.handle_exception(
                exception=e,
                log_level=self._debug_options.logging.log_level,
                override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD,
            )
        # Fallback to PyTorch due to failures *during* forward(),
        #  (e.g. export, model/input post-processing, forward, output processing, etc)
        if self._fallback_manager.is_pending():
            return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs)
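Both this method and the training forward in Example #5 compare torch.are_deterministic_algorithms_enabled() against a cached copy via _are_deterministic_algorithms_enabled() and _use_deterministic_algorithms(). A minimal sketch of such helpers, assuming they simply mirror the ORTMODULE_IS_DETERMINISTIC global shown in Example #4:

ORTMODULE_IS_DETERMINISTIC = torch.are_deterministic_algorithms_enabled()


def _are_deterministic_algorithms_enabled():
    # The deterministic-mode state last recorded by ORTModule.
    return ORTMODULE_IS_DETERMINISTIC


def _use_deterministic_algorithms(enabled):
    # Record the current PyTorch setting so the next forward() call can
    # detect a change and recreate the execution session.
    global ORTMODULE_IS_DETERMINISTIC
    ORTMODULE_IS_DETERMINISTIC = enabled

Under this reading, flipping torch.use_deterministic_algorithms() between calls makes create_execution_session true on the next forward, which is exactly the condition checked above.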
Example #4
from .torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed

################################################################################
# All global constants go here, before ORTModule is imported ##################
################################################################################
ONNX_OPSET_VERSION = 12
MINIMUM_RUNTIME_PYTORCH_VERSION_STR = '1.8.1'
ORTMODULE_TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__),
                                       'torch_cpp_extensions')
_FALLBACK_INIT_EXCEPTION = None
ORTMODULE_FALLBACK_POLICY = _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE |\
                            _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA |\
                            _FallbackPolicy.FALLBACK_UNSUPPORTED_TORCH_MODEL |\
                            _FallbackPolicy.FALLBACK_UNSUPPORTED_ONNX_MODEL
ORTMODULE_FALLBACK_RETRY = False
ORTMODULE_IS_DETERMINISTIC = torch.are_deterministic_algorithms_enabled()

ONNXRUNTIME_CUDA_VERSION = ort_info.cuda_version if hasattr(
    ort_info, 'cuda_version') else ''
ONNXRUNTIME_ROCM_VERSION = ort_info.rocm_version if hasattr(
    ort_info, 'rocm_version') else ''

# Verify the minimum PyTorch version is installed before proceeding to ONNX Runtime initialization
try:
    import torch
    runtime_pytorch_version = version.parse(torch.__version__.split('+')[0])
    minimum_runtime_pytorch_version = version.parse(
        MINIMUM_RUNTIME_PYTORCH_VERSION_STR)
    if runtime_pytorch_version < minimum_runtime_pytorch_version:
        raise wrap_exception(
            ORTModuleInitException,
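The excerpt above is cut off mid-raise, but the version gate itself is self-contained. A standalone sketch of the same pattern; the helper name and the plain RuntimeError standing in for the wrapped ORTModuleInitException are illustrative:

import torch
from packaging import version


def _check_minimum_pytorch(minimum: str = '1.8.1') -> None:
    # Strip a local build suffix such as '+cu113' before parsing.
    runtime = version.parse(torch.__version__.split('+')[0])
    if runtime < version.parse(minimum):
        raise RuntimeError(
            f"ORTModule requires PyTorch >= {minimum}, "
            f"but version {torch.__version__} was found instead."
        )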
Example #5
    def forward(self, *inputs, **kwargs):
        """Forward pass starts here and continues at `_ORTModuleFunction.forward`

        ONNX model is exported the first time this method is executed.
        Next, we build a full training graph with module_graph_builder.
        Finally, we instantiate the ONNX Runtime InferenceSession.
        """

        # Fallback to PyTorch due to failures *external* to forward(),
        #  typically from initialization
        if self._fallback_manager.is_pending():
            return self._fallback_manager.fallback(
                self._debug_options.logging.log_level, *inputs, **kwargs)

        try:
            # Issue at most one warning message about the fast path
            if (
                self._first_skip_check_warning is True
                and self._skip_check.is_disabled() is False
                and self._debug_options.logging.log_level <= _logger.LogLevel.WARNING
            ):
                # Flip the flag after the first warning so it is only issued once.
                self._first_skip_check_warning = False
                warnings.warn(
                    f"Fast path enabled - skipping checks."
                    f" Rebuild graph: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT)},"
                    f" Execution agent: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT)},"
                    f" Device check: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE)}",
                    UserWarning,
                )

            # If exporting module to ONNX for the first time, this skip check will not take effect.
            # It will only take effect on subsequent forward calls.
            build_gradient_graph = False
            if (
                self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
                or not self._onnx_models.exported_model
            ):
                build_gradient_graph = self._export_model(*inputs, **kwargs)
                if build_gradient_graph:
                    # If model was exported, then initialize the graph builder
                    self._initialize_graph_builder(training=True)

                # The schema was just extracted during export and either saved to
                # self._input_info.schema or checked for equality against it, so it
                # does not need updating again; pass it into parse_inputs_for_onnx_export.
                input_info = _io.parse_inputs_for_onnx_export(
                    self._module_parameters, self._onnx_models.exported_model,
                    self._input_info.schema, inputs, kwargs)

                # Reinitialize graph builder if the inputs or initializers requiring gradient have changed.
                # Order of or operation is important here because we always need to call
                # _reinitialize_graph_builder irrespective of the value of build_gradient_graph.
                build_gradient_graph = self._reinitialize_graph_builder(
                    input_info) or build_gradient_graph

                # Build the gradient graph
                if build_gradient_graph:
                    self._build_graph()

            # If creating the execution agent for the first time, this skip check will not take effect.
            # It will only take effect on subsequent forward calls.
            create_execution_session = False
            if (
                self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False
                or not self._execution_agent
            ):
                device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs(
                    inputs, kwargs
                )
                create_execution_session = (
                    build_gradient_graph
                    or self._device != device
                    or torch.are_deterministic_algorithms_enabled() is not _are_deterministic_algorithms_enabled()
                )
                _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled())
                if self._device != device:
                    self._device = device

            if create_execution_session:
                # Creating the execution agent instantiates the underlying training session
                self._create_execution_agent()

                self._gradient_accumulation_manager.initialize(
                    self._enable_grad_acc_optimization, self._flattened_module,
                    self._graph_info)

            self._gradient_accumulation_manager.maybe_update_cache_before_run()

            return _io.unflatten_user_output(
                self._module_output_schema,
                self._forward_class.apply(
                    *_io._combine_input_buffers_initializers(
                        self._graph_initializers,
                        self._graph_info.user_input_names,
                        self._input_info,
                        self._flattened_module.named_buffers(),
                        inputs,
                        kwargs,
                        self._device,
                    )),
            )
        except ORTModuleFallbackException as e:
            # Exceptions subject to fallback are handled here
            self._fallback_manager.handle_exception(
                exception=e, log_level=self._debug_options.logging.log_level)
        except Exception as e:
            # Catch-all FALLBACK_FORCE_TORCH_FORWARD fallback is handled here
            self._fallback_manager.handle_exception(
                exception=e,
                log_level=self._debug_options.logging.log_level,
                override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD,
            )

        # Fallback to PyTorch due to failures *during* forward(),
        #  (e.g. export, model/input post-processing, forward, output processing, etc)
        if self._fallback_manager.is_pending():
            return self._fallback_manager.fallback(
                self._debug_options.logging.log_level, *inputs, **kwargs)
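A hedged end-to-end sketch of how the determinism check plays out across calls (MyModule stands in for any torch.nn.Module): flipping the global flag between forward passes makes create_execution_session true on the next call, so the ONNX Runtime session is recreated with a matching setting.

import torch
from onnxruntime.training.ortmodule import ORTModule

model = ORTModule(MyModule())  # MyModule is an illustrative torch.nn.Module
x = torch.randn(4, 10)

y1 = model(x)  # first call: ONNX export, graph build, session creation

torch.use_deterministic_algorithms(True)
y2 = model(x)  # flag changed since the last call -> execution agent is recreated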