def allow_non_deterministic():
    """Temporarily turn off PyTorch deterministic algorithms around one ``yield``.

    The function is a generator with a single ``yield``. On entry it records
    the current deterministic-algorithms configuration and disables
    deterministic algorithms; the recorded configuration is put back in a
    ``finally`` clause, so it is restored whether the suspended body finishes
    normally, raises, or the generator is closed.

    Both the on/off flag and the ``warn_only`` flag are restored. Restoring
    only the on/off flag would silently turn a "warn only" configuration into
    a strict (raising) one after the first use.

    NOTE(review): presumably wrapped by ``contextlib.contextmanager`` via a
    decorator that is outside this view - confirm.

    Yields:
        None, exactly once, while deterministic algorithms are disabled.
    """
    prev_state = torch.are_deterministic_algorithms_enabled()
    # The warn-only query does not exist in older PyTorch releases; look it up
    # defensively so the function keeps working there with the previous behavior.
    warn_only_getter = getattr(torch, "is_deterministic_algorithms_warn_only_enabled", None)
    prev_warn_only = warn_only_getter() if warn_only_getter is not None else None
    try:
        torch.use_deterministic_algorithms(False)
        yield
    finally:
        if prev_warn_only is None:
            torch.use_deterministic_algorithms(prev_state)
        else:
            torch.use_deterministic_algorithms(prev_state, warn_only=prev_warn_only)
def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
    """Count the occurrences of each non-negative integer value in ``x``.

    ``torch.bincount`` currently does not support deterministic mode on GPU.
    In that case this implementation falls back to a for-loop that counts the
    occurrences of every bin separately.

    Args:
        x: tensor of non-negative integers to count
        minlength: minimum number of bins in the result; ``None`` means no
            minimum (treated as 0, the default of ``torch.bincount``)

    Returns:
        ``torch.long`` tensor whose entry ``i`` is the number of occurrences of
        ``i`` in ``x``. Both code paths produce ``max(minlength, x.max() + 1)``
        bins (``minlength`` bins for an empty ``x``).
    """
    # torch.bincount only accepts an int for ``minlength``; normalise the
    # ``None`` default once so both paths below can rely on an int.
    if minlength is None:
        minlength = 0

    if x.is_cuda and torch.are_deterministic_algorithms_enabled():
        # Size the output like torch.bincount does: large enough for the
        # biggest value present, and never smaller than ``minlength``.
        num_bins = minlength
        if x.numel() > 0:
            num_bins = max(num_bins, int(x.max()) + 1)
        output = torch.zeros(num_bins, device=x.device, dtype=torch.long)
        for i in range(num_bins):
            output[i] = (x == i).sum()
        return output

    return torch.bincount(x, minlength=minlength)
def forward(self, *inputs, **kwargs):
    """Forward pass of the inference model

    ONNX model is exported the first time this method is executed.
    Next, we build an optimized inference graph with module_graph_builder.
    Finally, we instantiate the ONNX Runtime InferenceSession through the InferenceAgent.

    Args:
        *inputs: positional arguments of the forward call.
        **kwargs: keyword arguments of the forward call.

    Returns:
        The session outputs passed through ``_io.unflatten_user_output`` with
        ``self._module_output_schema``. When a fallback is pending (on entry, or
        after an exception raised during this call was handed to the fallback
        manager) the result of ``self._fallback_manager.fallback`` is returned
        instead. If an exception was handled and no fallback is pending, the
        method falls off the end and returns ``None``.
    """
    # Fallback to PyTorch due to failures *external* to forward(),
    # typically from initialization
    if self._fallback_manager.is_pending():
        return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs)

    try:
        # Issue at most one warning message about fast path
        if (
            self._first_skip_check_warning is True
            and self._skip_check.is_disabled() is False
            and self._debug_options.logging.log_level <= _logger.LogLevel.WARNING
        ):
            self._first_skip_check_warning = False
            # Adjacent string literals are concatenated verbatim, so every
            # fragment after the first carries its own leading space.
            warnings.warn(
                "Fast path enabled - skipping checks."
                f" rebuild gradient graph: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT)},"
                f" execution agent recreation: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT)},"
                f" device check: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE)}",
                UserWarning,
            )

        # If exporting module to ONNX for the first time, this skip check will not take effect.
        # It will only take effect on subsequent forward calls.
        build_graph = False
        if (
            self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
            or not self._onnx_models.exported_model
        ):
            # Exporting module to ONNX for the first time
            build_graph = self._export_model(*inputs, **kwargs)
            if build_graph:
                # If model was exported, then initialize the graph builder
                # and build the inference graph
                self._initialize_graph_builder(training=False)
                self._build_graph()

        # If creating the execution agent for the first time, this skip check will not take effect.
        # It will only take effect on subsequent forward calls.
        create_execution_session = False
        if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent:
            module_device = _utils.get_device_from_module(self._original_module)

            # A new session is needed when the graph was rebuilt, the module moved to
            # another device, or PyTorch's determinism setting no longer matches the
            # one recorded through _use_deterministic_algorithms.
            create_execution_session = (
                build_graph
                or self._device != module_device
                or torch.are_deterministic_algorithms_enabled() != _are_deterministic_algorithms_enabled()
            )
            _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled())

            if self._device != module_device:
                self._device = module_device

        if create_execution_session:
            # Create execution session creates the inference_session
            self._create_execution_agent()

        if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
            # Assert that the input and model device match
            _utils._check_same_device(self._device, "Input argument to forward", *inputs)

        user_outputs, _ = InferenceManager.execution_session_run_forward(
            self._execution_agent,
            self._onnx_models.optimized_model,
            self._device,
            *_io._combine_input_buffers_initializers(
                self._graph_initializers,
                self._graph_info.user_input_names,
                self._input_info,
                self._flattened_module.named_buffers(),
                inputs,
                kwargs,
                self._device,
            ),
        )

        return _io.unflatten_user_output(self._module_output_schema, user_outputs)
    except ORTModuleFallbackException as e:
        # Exceptions subject to fallback are handled here
        self._fallback_manager.handle_exception(exception=e, log_level=self._debug_options.logging.log_level)
    except Exception as e:
        # Catch-all FALLBACK_FORCE_TORCH_FORWARD fallback is handled here
        self._fallback_manager.handle_exception(
            exception=e,
            log_level=self._debug_options.logging.log_level,
            override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD,
        )

    # Fallback to PyTorch due to failures *during* forward(),
    # (e.g. export, model/input post-processing, forward, output processing, etc)
    if self._fallback_manager.is_pending():
        return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs)
from .torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed ################################################################################ # All global constant goes here, before ORTModule is imported ################## ################################################################################ ONNX_OPSET_VERSION = 12 MINIMUM_RUNTIME_PYTORCH_VERSION_STR = '1.8.1' ORTMODULE_TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__), 'torch_cpp_extensions') _FALLBACK_INIT_EXCEPTION = None ORTMODULE_FALLBACK_POLICY = _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE |\ _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA |\ _FallbackPolicy.FALLBACK_UNSUPPORTED_TORCH_MODEL |\ _FallbackPolicy.FALLBACK_UNSUPPORTED_ONNX_MODEL ORTMODULE_FALLBACK_RETRY = False ORTMODULE_IS_DETERMINISTIC = torch.are_deterministic_algorithms_enabled() ONNXRUNTIME_CUDA_VERSION = ort_info.cuda_version if hasattr( ort_info, 'cuda_version') else '' ONNXRUNTIME_ROCM_VERSION = ort_info.rocm_version if hasattr( ort_info, 'rocm_version') else '' # Verify minimum PyTorch version is installed before proceding to ONNX Runtime initialization try: import torch runtime_pytorch_version = version.parse(torch.__version__.split('+')[0]) minimum_runtime_pytorch_version = version.parse( MINIMUM_RUNTIME_PYTORCH_VERSION_STR) if runtime_pytorch_version < minimum_runtime_pytorch_version: raise wrap_exception( ORTModuleInitException,
def forward(self, *inputs, **kwargs):
    """Run the training forward pass; it continues at `_ORTModuleFunction.forward`.

    The first time this method runs, the module is exported to ONNX, a full
    training graph is built with module_graph_builder, and the ONNX Runtime
    InferenceSession is instantiated. On later calls the skip-check flags and
    the change detection below decide which of those steps are repeated.

    When a fallback is pending on entry, or becomes pending because an
    exception raised during this call was handed to the fallback manager, the
    result of ``self._fallback_manager.fallback`` is returned instead.
    """
    # A failure *external* to forward() (typically initialization) already
    # requested a fallback to PyTorch.
    if self._fallback_manager.is_pending():
        return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs)

    try:
        should_warn = (
            self._first_skip_check_warning is True
            and self._skip_check.is_disabled() is False
            and self._debug_options.logging.log_level <= _logger.LogLevel.WARNING
        )
        if should_warn:
            # Flip the flag first so the warning is issued at most once.
            self._first_skip_check_warning = False
            skips_rebuild = self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT)
            skips_agent = self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT)
            skips_device = self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE)
            warnings.warn(
                "Fast path enabled - skipping checks."
                f" Rebuild graph: {skips_rebuild},"
                f" Execution agent: {skips_agent},"
                f" Device check: {skips_device}",
                UserWarning,
            )

        # The skip check cannot take effect on the very first export; it only
        # matters for subsequent forward calls.
        build_gradient_graph = False
        export_needed = (
            self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
            or not self._onnx_models.exported_model
        )
        if export_needed:
            build_gradient_graph = self._export_model(*inputs, **kwargs)
            if build_gradient_graph:
                # A fresh export requires a fresh graph builder.
                self._initialize_graph_builder(training=True)

            # The schema was just extracted (or verified against
            # self._input_info.schema) during export, so it is passed along
            # instead of being computed again.
            input_info = _io.parse_inputs_for_onnx_export(
                self._module_parameters,
                self._onnx_models.exported_model,
                self._input_info.schema,
                inputs,
                kwargs,
            )

            # _reinitialize_graph_builder must run on every pass through here,
            # whatever build_gradient_graph currently holds; its result only
            # replaces the flag when it is truthy.
            reinitialized = self._reinitialize_graph_builder(input_info)
            if reinitialized:
                build_gradient_graph = reinitialized

            if build_gradient_graph:
                self._build_graph()

        # Likewise, the execution-agent skip check only matters once an agent
        # exists.
        create_execution_session = False
        agent_check_needed = (
            self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False
            or not self._execution_agent
        )
        if agent_check_needed:
            # Prefer the module's device; use the inputs' device otherwise.
            device = _utils.get_device_from_module(self._original_module)
            if not device:
                device = _utils.get_device_from_inputs(inputs, kwargs)

            create_execution_session = (
                build_gradient_graph
                or self._device != device
                or torch.are_deterministic_algorithms_enabled() is not _are_deterministic_algorithms_enabled()
            )
            _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled())

            if self._device != device:
                self._device = device

        if create_execution_session:
            # Creating the execution agent creates the training session.
            self._create_execution_agent()
            self._gradient_accumulation_manager.initialize(
                self._enable_grad_acc_optimization,
                self._flattened_module,
                self._graph_info,
            )

        self._gradient_accumulation_manager.maybe_update_cache_before_run()

        output_schema = self._module_output_schema
        flat_inputs = _io._combine_input_buffers_initializers(
            self._graph_initializers,
            self._graph_info.user_input_names,
            self._input_info,
            self._flattened_module.named_buffers(),
            inputs,
            kwargs,
            self._device,
        )
        return _io.unflatten_user_output(output_schema, self._forward_class.apply(*flat_inputs))
    except ORTModuleFallbackException as exc:
        # Exceptions that are subject to the fallback policy.
        self._fallback_manager.handle_exception(exception=exc, log_level=self._debug_options.logging.log_level)
    except Exception as exc:
        # Everything else goes through the catch-all
        # FALLBACK_FORCE_TORCH_FORWARD policy.
        self._fallback_manager.handle_exception(
            exception=exc,
            log_level=self._debug_options.logging.log_level,
            override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD,
        )

    # A failure *during* forward() (export, model/input post-processing,
    # forward, output processing, ...) may have requested a fallback to PyTorch.
    if self._fallback_manager.is_pending():
        return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs)