def execution_session_run_forward(execution_session, onnx_model, device, gradient_accumulation_manager, *inputs):
    """Run the forward graph on ``execution_session`` and collect the user outputs.

    Args:
        execution_session: Object whose ``run_forward`` method executes the graph.
        onnx_model: Not used by this function.
        device: Forwarded to
            ``gradient_accumulation_manager.extract_outputs_and_maybe_update_cache``.
        gradient_accumulation_manager: Supplies ``cache`` for the run and
            extracts the user-visible outputs afterwards.
        *inputs: Tensors fed to the graph. Any tensor that is not contiguous
            is replaced by its ``.contiguous()`` copy before conversion.

    Returns:
        A ``(user_outputs, run_info)`` pair. ``run_info`` is a
        ``_RunStateInfo`` built from the partial execution state and a list of
        ``(shape, device, dtype)`` tuples, one per user output.
    """
    # TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
    #   especially the backward graph outputs.
    # REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on C++ side to not
    #   have the need for passing IOBinding.
    state = C.PartialGraphExecutionState()

    forward_inputs = C.OrtValueVector()
    forward_inputs.reserve(len(inputs))
    for tensor in inputs:
        # TODO: Non-contiguous tensor input in execution_session_run_forward, need tensor copy.
        if not tensor.is_contiguous():
            tensor = tensor.contiguous()

        if tensor.device.type != "ort":
            # Non-ORT tensors are handed over through DLPack; the extra flag
            # tells the vector whether the source dtype was bool.
            dlpack_capsule = _utils._torch_tensor_to_dlpack(tensor)
            forward_inputs.push_back(dlpack_capsule, tensor.dtype == torch.bool)
            continue

        forward_inputs.push_back(C.aten_ort_tensor_to_ort_value(tensor))

    # Run the graph and pull out what the caller should see.
    forward_outputs = C.OrtValueVector()
    execution_session.run_forward(forward_inputs, forward_outputs, state, gradient_accumulation_manager.cache)
    user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(forward_outputs, device)

    # Record per-output metadata alongside the state.
    output_info = [(out.shape, out.device, out.dtype) for out in user_outputs]
    return user_outputs, _RunStateInfo(state, output_info)
def _ortvalue_from_torch_tensor(torch_tensor):
    """Convert a torch tensor into an ORT value.

    Tensors whose device type is ``"ort"`` are converted with
    ``C.aten_ort_tensor_to_ort_value``; all others go through DLPack via
    ``C.OrtValue.from_dlpack``, which is also told whether the source tensor
    was of bool dtype.

    Args:
        torch_tensor: The tensor to convert.

    Returns:
        Whatever the selected ``C`` conversion routine returns.
    """
    # TODO: Current DLPack doesn't support bool and PyTorch disables converting bool tensor
    # to DLPack in recent commit.
    # https://github.com/pytorch/pytorch/blob/7e7be526c9d9179f35084e9cca5b5c5ad5172100/aten/src/ATen/DLConvertor.cpp#L41
    # We need to convert bool tensor to uint8 tensor to work around this.
    # DLPack is discussing how to support bool type, we can remove this workaround once both
    # DLPack and PyTorch support bool type.
    was_bool = torch_tensor.dtype == torch.bool
    needs_uint8_cast = was_bool and LooseVersion(torch.__version__) >= LooseVersion("1.10.0")

    tensor = torch_tensor.to(torch.uint8) if needs_uint8_cast else torch_tensor

    if tensor.device.type != "ort":
        return C.OrtValue.from_dlpack(to_dlpack(tensor), was_bool)
    return C.aten_ort_tensor_to_ort_value(tensor)
def execution_session_run_forward(execution_session, onnx_model, device, gradient_accumulation_manager, *inputs):
    """Run the forward graph on ``execution_session`` and collect the user outputs.

    Args:
        execution_session: Object whose ``run_forward`` method executes the graph.
        onnx_model: Not used by this function.
        device: Forwarded to
            ``gradient_accumulation_manager.extract_outputs_and_maybe_update_cache``.
        gradient_accumulation_manager: Supplies ``cache`` for the run and
            extracts the user-visible outputs afterwards.
        *inputs: Tensors fed to the graph. Non-contiguous tensors are copied
            into contiguous memory before being converted.

    Returns:
        A ``(user_outputs, run_info)`` pair. ``run_info`` is a
        ``_RunStateInfo`` built from the partial execution state and a list of
        ``(shape, device, dtype)`` tuples, one per user output.
    """
    # Clear all gradient functions, to avoid a deadlock issue.
    # Check the called function for more detailed comments.
    clear_all_grad_fns()

    # TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
    #   especially the backward graph outputs.
    # REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on C++ side to not
    #   have the need for passing IOBinding.
    state = C.PartialGraphExecutionState()
    forward_inputs = C.OrtValueVector()
    forward_inputs.reserve(len(inputs))
    for tensor in inputs:
        # Normalize the memory layout before conversion; ``contiguous()`` is
        # only called when needed, so already-contiguous inputs are untouched.
        # NOTE(review): the ORT-side conversion is understood to require
        # contiguous tensors - confirm against the C++ DLPack converter.
        if not tensor.is_contiguous():
            tensor = tensor.contiguous()

        if tensor.device.type == "ort":
            forward_inputs.push_back(C.aten_ort_tensor_to_ort_value(tensor))
        else:
            # Non-ORT tensors are handed over through DLPack; the extra flag
            # tells the vector whether the source dtype was bool.
            valid_ort_tensor = _utils._torch_tensor_to_dlpack(tensor)
            forward_inputs.push_back(valid_ort_tensor, tensor.dtype == torch.bool)

    forward_outputs = C.OrtValueVector()
    # Run and return module outputs.
    execution_session.run_forward(forward_inputs, forward_outputs, state, gradient_accumulation_manager.cache)
    user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(forward_outputs, device)

    output_info = [(output.shape, output.device, output.dtype) for output in user_outputs]
    run_info = _RunStateInfo(state, output_info)

    # Return user outputs and forward run information
    return user_outputs, run_info
def backward(ctx, *grad_outputs):
    """Run the backward graph for the gradients received w.r.t. the module outputs.

    ``self`` is not a parameter here; it is resolved from the enclosing scope.

    Args:
        ctx: Autograd context. ``ctx.run_info`` must have been populated by the
            forward pass; its ``state`` attribute is deleted once the backward
            graph has run.
        *grad_outputs: One gradient (or ``None``) per module output.

    Returns:
        A tuple with one entry per element of ``self._gradient_map``: the
        converted backward output at that index, or ``None`` where the map
        holds ``-1``.
    """
    assert ctx.run_info is not None, "forward() or __call__() methods must be called before backward()"

    device_check_enabled = self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False
    if device_check_enabled:
        _utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)

    # Unpack saved_tensor to trigger version detection that catches inplace corruption
    _ = ctx.saved_tensors

    # Use IO binding
    # Push user output grads to ONNX backend.
    backward_inputs = C.OrtValueVector()
    # Reserve room for every grad output up front; the vector is shrunk after
    # the loop because non-differentiable outputs are skipped.
    backward_inputs.reserve(len(grad_outputs))
    for position, grad in enumerate(grad_outputs):
        if position in self._graph_info.output_grad_indices_non_differentiable:
            assert grad is None, (
                "ORT found the {}-th module output '{}' is "
                "non-differentiable according to the onnx graph. "
                "However, the gradient value is still provided by "
                "PyTorch's autograd engine."
            ).format(position, self._graph_info.user_output_names[position])
            continue

        if grad is not None:
            if not grad.is_contiguous():
                grad = grad.contiguous()
        else:
            # No gradient was supplied: synthesize zeros, full-shaped only
            # when the graph info asks for it, otherwise a scalar.
            shape, device, dtype = ctx.run_info.output_info[position]
            if position in self._graph_info.output_grad_indices_require_full_shape:
                grad = torch.zeros(shape, device=device, dtype=dtype)
            else:
                grad = torch.tensor(0.0, device=device, dtype=dtype)

        if grad.device.type != "ort":
            backward_inputs.push_back(_utils._torch_tensor_to_dlpack(grad), grad.dtype is torch.bool)
        else:
            backward_inputs.push_back(C.aten_ort_tensor_to_ort_value(grad))
    backward_inputs.shrink_to_fit()

    # Run and get results
    backward_outputs = C.OrtValueVector()
    self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
    # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
    # affect peak memory usage in a subsequent graph run.
    del ctx.run_info.state

    # Fast version: all backward_outputs are converted first.
    # This version only works if backward_outputs is an OrtValueVector.
    torch_grads = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)

    gradients = []
    for output_index in self._gradient_map:
        gradients.append(None if output_index == -1 else torch_grads[output_index])
    return tuple(gradients)
def backward(ctx, *grad_outputs):
    """Run the backward graph for the gradients received w.r.t. the module outputs.

    ``self`` is not a parameter here; it is resolved from the enclosing scope.

    Args:
        ctx: Autograd context. ``ctx.run_info`` must have been populated by the
            forward pass; its ``state`` attribute is deleted once the backward
            graph has run.
        *grad_outputs: One gradient (or ``None``) per module output.

    Returns:
        A tuple with one entry per name in ``self._graph_info.user_input_names``
        followed by one entry per name in
        ``self._graph_info.initializer_names``. Entries are converted backward
        outputs where a gradient is required and ``None`` elsewhere.
    """
    assert ctx.run_info is not None, "forward() or __call__() methods must be called before backward()"

    device_check_enabled = self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False
    if device_check_enabled:
        _utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)

    # Unpack saved_tensor to trigger version detection that catches inplace corruption
    _ = ctx.saved_tensors

    # Use IO binding
    # Push user output grads to ONNX backend.
    backward_inputs = C.OrtValueVector()
    # Reserve room for every grad output up front; the vector is shrunk after
    # the loop because non-differentiable outputs are skipped.
    backward_inputs.reserve(len(grad_outputs))
    for position, grad in enumerate(grad_outputs):
        if position in self._graph_info.output_grad_indices_non_differentiable:
            assert grad is None, (
                "ORT found the {}-th module output '{}' is "
                "non-differentiable according to the onnx graph. "
                "However, the gradient value is still provided by "
                "PyTorch's autograd engine."
            ).format(position, self._graph_info.user_output_names[position])
            continue

        if grad is not None:
            if not grad.is_contiguous():
                grad = grad.contiguous()
        else:
            # No gradient was supplied: synthesize zeros, full-shaped only
            # when the graph info asks for it, otherwise a scalar.
            shape, device, dtype = ctx.run_info.output_info[position]
            if position in self._graph_info.output_grad_indices_require_full_shape:
                grad = torch.zeros(shape, device=device, dtype=dtype)
            else:
                grad = torch.tensor(0.0, device=device, dtype=dtype)

        if grad.device.type != "ort":
            backward_inputs.push_back(_utils._torch_tensor_to_dlpack(grad), grad.dtype is torch.bool)
        else:
            backward_inputs.push_back(C.aten_ort_tensor_to_ort_value(grad))
    backward_inputs.shrink_to_fit()

    # Run and get results
    backward_outputs = C.OrtValueVector()
    self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
    # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
    # affect peak memory usage in a subsequent graph run.
    del ctx.run_info.state

    # Return input and initializer gradients.
    # The backward outputs are consumed in order: first the user inputs that
    # require grad, then (starting at this offset) the trained initializers.
    initializer_offset = len(self._input_info.require_grad_names)
    results = []

    grad_required = set(self._input_info.require_grad_names)
    next_output = 0
    for input_name in self._graph_info.user_input_names:
        if input_name not in grad_required:
            # This input did not require grad.
            results.append(None)
            continue
        results.append(_utils._ortvalue_to_torch_tensor(backward_outputs[next_output], self._device))
        next_output += 1

    # Go over each initializer and append its gradient only if it is trained.
    next_output = initializer_offset
    for initializer_name in self._graph_info.initializer_names:
        if initializer_name not in self._graph_initializer_names_to_train:
            results.append(None)
            continue
        results.append(_utils._ortvalue_to_torch_tensor(backward_outputs[next_output], self._device))
        next_output += 1

    return tuple(results)