def execution_session_run_forward(execution_session, onnx_model, device, *inputs):
    """Runs the forward graph on execution_session with given model inputs and device"""

    # Assert that the input and model device match
    _utils._check_same_device(device, "Input argument to forward", *inputs)

    # TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
    #       especially the backward graph outputs.
    # REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on C++ side to not
    #                  have the need for passing IOBinding.
    state = C.PartialGraphExecutionState()
    forward_inputs = C.OrtValueVector()
    for input in inputs:
        forward_inputs.append(_utils._ortvalue_from_torch_tensor(input))
    forward_outputs = C.OrtValueVector()

    # Run and return module outputs.
    execution_session.run_forward(forward_inputs, forward_outputs, state)
    user_outputs = tuple(_utils._ortvalue_to_torch_tensor(forward_output)
                         for forward_output in forward_outputs)

    # Assert that the outputs and model device match
    _utils._check_same_device(device, "Output argument from forward", *user_outputs)

    output_info = [(output.shape, output.device, output.dtype) for output in user_outputs]
    run_info = RunStateInfo(state, output_info)

    # Return user outputs and forward run information
    return user_outputs, run_info
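# Hedged usage sketch, not taken from the source: a helper like the one above is
# typically invoked from a torch.autograd.Function whose forward() stashes run_info
# on ctx for the backward pass (the backward variants below read ctx.run_info).
# _ORTModuleFunction and the captured training-manager `self` (providing
# _execution_agent, _onnx_models, _device) are assumptions for illustration; in
# practice the class would be defined inside a method so that `self` is closed over.
class _ORTModuleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *inputs):
        user_outputs, ctx.run_info = execution_session_run_forward(
            self._execution_agent, self._onnx_models.optimized_model, self._device, *inputs
        )
        return user_outputs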
def execution_session_run_forward(execution_session, onnx_model, device, gradient_accumulation_manager, *inputs):
    """Runs the forward graph on execution_session with given model inputs and device"""

    # TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
    #       especially the backward graph outputs.
    # REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on C++ side to not
    #                  have the need for passing IOBinding.
    state = C.PartialGraphExecutionState()
    forward_inputs = C.OrtValueVector()
    forward_inputs.reserve(len(inputs))
    for input in inputs:
        # TODO: Non-contiguous tensor input in execution_session_run_forward, need tensor copy.
        if not input.is_contiguous():
            input = input.contiguous()
        if input.device.type == "ort":
            forward_inputs.push_back(C.aten_ort_tensor_to_ort_value(input))
        else:
            # DLPack cannot round-trip bool tensors, so a bool input travels as uint8;
            # the extra flag lets ORT restore the original dtype on its side.
            valid_ort_tensor = _utils._torch_tensor_to_dlpack(input)
            forward_inputs.push_back(valid_ort_tensor, input.dtype == torch.bool)
    forward_outputs = C.OrtValueVector()

    # Run and return module outputs.
    execution_session.run_forward(forward_inputs, forward_outputs, state, gradient_accumulation_manager.cache)
    user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(forward_outputs, device)

    output_info = [(output.shape, output.device, output.dtype) for output in user_outputs]
    run_info = _RunStateInfo(state, output_info)

    # Return user outputs and forward run information
    return user_outputs, run_info
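# A plausible sketch of the DLPack/bool workaround implied by the push_back flag
# above (assumption: this mirrors what a helper such as _utils._torch_tensor_to_dlpack
# might do; it is not copied from the source). DLPack historically could not
# represent bool tensors, so they travel as uint8 and the boolean flag tells the
# ORT side to restore torch.bool.
from torch.utils.dlpack import to_dlpack

def _torch_tensor_to_dlpack_sketch(tensor):
    # Bool tensors are reinterpreted as uint8 before export; the caller passes
    # `tensor.dtype == torch.bool` alongside so the dtype can be recovered.
    if tensor.dtype == torch.bool:
        tensor = tensor.to(torch.uint8)
    return to_dlpack(tensor)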
def execution_session_run_forward(execution_session, onnx_model, *inputs):
    """Runs the forward graph on execution_session with given model inputs"""

    # TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
    #       especially the backward graph outputs.
    # REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on C++ side to not
    #                  have the need for passing IOBinding.
    state = C.PartialGraphExecutionState()
    forward_inputs = C.OrtValueVector()
    forward_inputs.reserve(len(inputs))
    for input in inputs:
        forward_inputs.push_back(to_dlpack(input), input.dtype == torch.bool)
    forward_outputs = C.OrtValueVector()

    # Run and return module outputs.
    execution_session.run_forward(forward_inputs, forward_outputs, state)
    user_outputs = tuple(_utils._ortvalue_to_torch_tensor(forward_output)
                         for forward_output in forward_outputs)

    output_info = [(output.shape, output.device, output.dtype) for output in user_outputs]
    run_info = RunStateInfo(state, output_info)

    # Return user outputs and forward run information
    return user_outputs, run_info
def execution_session_run_forward(execution_session, onnx_model, device, gradient_accumulation_manager, *inputs):
    """Runs the forward graph on execution_session with given model inputs and device"""

    # Clear all gradient functions, to avoid a deadlock issue.
    # Check the called function for more detailed comments.
    clear_all_grad_fns()

    # TODO: Try to reuse the output buffers as some of the output tensors are same sizes,
    #       especially the backward graph outputs.
    # REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on C++ side to not
    #                  have the need for passing IOBinding.
    state = C.PartialGraphExecutionState()
    forward_inputs = C.OrtValueVector()
    forward_inputs.reserve(len(inputs))
    for input in inputs:
        if input.device.type == 'ort':
            forward_inputs.push_back(C.aten_ort_tensor_to_ort_value(input))
        else:
            valid_ort_tensor = _utils._torch_tensor_to_dlpack(input)
            forward_inputs.push_back(valid_ort_tensor, input.dtype == torch.bool)
    forward_outputs = C.OrtValueVector()

    # Run and return module outputs.
    execution_session.run_forward(forward_inputs, forward_outputs, state, gradient_accumulation_manager.cache)
    user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(forward_outputs, device)

    output_info = [(output.shape, output.device, output.dtype) for output in user_outputs]
    run_info = _RunStateInfo(state, output_info)

    # Return user outputs and forward run information
    return user_outputs, run_info
def extract_outputs_and_maybe_update_cache(self, forward_outputs, device):
    """Extract the user outputs from the forward outputs as torch tensors and update the cache, if needed

    Args:
        forward_outputs (OrtValueVector): List of outputs returned by the forward function
        device (torch.device): Device used when converting the OrtValues to torch tensors
    """
    if not self.enabled:
        return _utils._ortvalues_to_torch_tensor(forward_outputs, device)

    if self._update_cache:
        for i in range(self._cache_start, len(forward_outputs)):
            self.cache.insert(self._cached_node_arg_names[i - self._cache_start], forward_outputs[i])
        self._update_cache = False

    # The first _cache_start entries are the user outputs; the remainder are cached node args.
    ort_value_vector = C.OrtValueVector()
    ort_value_vector.reserve(self._cache_start)
    for i in range(self._cache_start):
        ort_value_vector.push_back(forward_outputs[i])
    return _utils._ortvalues_to_torch_tensor(ort_value_vector, device)  # pylint: disable=W0212
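# Illustrative walk-through with assumed values, to make the indexing above concrete:
# if self._cache_start == 2 and forward_outputs == [user_out0, user_out1, cached_arg0,
# cached_arg1], then a call with _update_cache set stores cached_arg0/cached_arg1 in
# self.cache under self._cached_node_arg_names[0]/[1], and only user_out0/user_out1
# are converted and returned as torch tensors.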
def _ortvalues_to_torch_tensor_list(self, device, tensor_type, new_impl):
    narrays = [
        np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32),
        np.array([[6.0, 7.0], [8.0, 9.0], [1.0, 6.0]], dtype=np.float32),
    ]
    vect = C.OrtValueVector()
    vect.reserve(len(narrays))
    ptr = []
    for a in narrays:
        ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(a, device.type if device.type != "ort" else "cpu")
        vect.push_back(ortvalue._ortvalue)
        ptr.append(ortvalue.data_ptr())
    self.assertEqual(len(vect), 2)

    if new_impl:
        tensors = _utils._ortvalues_to_torch_tensor(vect, device)
    else:
        tensors = _ortvalues_to_torch_tensor(vect, device)

    # The conversion must be zero-copy: each torch tensor reuses its OrtValue's buffer,
    # which is what the data_ptr comparison verifies.
    self.assertEqual(len(tensors), len(vect))
    self.assertEqual(ptr, [t.data_ptr() for t in tensors])
    self.assertTrue(all(isinstance(v, tensor_type) for v in tensors))
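# Hedged usage sketch (the test name and parameters are assumptions, not from the
# source): the helper above is written to be driven by thin per-configuration tests,
# e.g. one per device/implementation combination.
def test_ortvalues_to_torch_tensor_list_cpu_new_impl(self):
    self._ortvalues_to_torch_tensor_list(torch.device("cpu"), torch.Tensor, new_impl=True)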
def backward(ctx, *grad_outputs):
    '''Performs backward pass based on grad wrt module output'''

    assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()'
    _utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)

    # Use IO binding
    # Push user output grads to ONNX backend.
    contiguous_grad_outputs = []
    for idx, grad_output in enumerate(grad_outputs):
        if idx in self._graph_info.output_grad_indices_non_differentiable:
            assert grad_output is None, ("ORT found the {}-th module output '{}' is "
                                         "non-differentiable according to the onnx graph. "
                                         "However, the gradient value is still provided by "
                                         "PyTorch's autograd engine."
                                         .format(idx, self._graph_info.user_output_names[idx]))
            continue

        if grad_output is None:
            shape, device, dtype = ctx.run_info.output_info[idx]
            if idx in self._graph_info.output_grad_indices_require_full_shape:
                grad_output = torch.zeros(shape, device=device, dtype=dtype)
            else:
                grad_output = torch.tensor(0., device=device, dtype=dtype)
        elif not grad_output.is_contiguous():
            grad_output = grad_output.contiguous()
        contiguous_grad_outputs.append(grad_output)

    # Run and get results
    backward_inputs = C.OrtValueVector()
    for input in contiguous_grad_outputs:
        backward_inputs.append(_utils._ortvalue_from_torch_tensor(input))

    backward_outputs = C.OrtValueVector()
    self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)

    # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
    # affect peak memory usage in a subsequent graph run.
    del ctx.run_info.state

    # Return input and initializer gradients
    num_user_input_grads = len(self._input_info.require_grad_names)
    results = []
    require_grad_names_set = set(self._input_info.require_grad_names)
    require_grad_names_index = 0
    for input_name in self._graph_info.user_input_names:
        # Append to the results the backward output for each input that required grad
        if input_name in require_grad_names_set:
            results.append(_utils._ortvalue_to_torch_tensor(backward_outputs[require_grad_names_index]))
            require_grad_names_index += 1
        else:
            # input_name is not found in the self._input_info.require_grad_names list;
            # append None to results for each input that did not require grad
            results.append(None)

    # Append gradients of initializers to results.
    # Go over each initializer, check if it required grad, and append to results accordingly.
    initializer_index = num_user_input_grads
    for initializer_name in self._graph_info.initializer_names:
        if initializer_name in self._graph_initializer_names_to_train:
            results.append(_utils._ortvalue_to_torch_tensor(backward_outputs[initializer_index]))
            initializer_index += 1
        else:
            results.append(None)

    return tuple(results)
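# Minimal, self-contained illustration (an assumed example, not from the source) of
# the autograd contract the loops above satisfy: backward() must return exactly one
# entry per forward() input, in order, with None for inputs that cannot or need not
# receive a gradient.
import torch

class _Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One entry per forward input: a gradient for x, None for the plain float.
        return grad_out * ctx.factor, None

x = torch.ones(3, requires_grad=True)
_Scale.apply(x, 2.0).sum().backward()
assert torch.equal(x.grad, torch.full((3,), 2.0))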
def backward(ctx, *grad_outputs):
    """Performs backward pass based on grad wrt module output"""

    assert ctx.run_info is not None, "forward() or __call__() methods must be called before backward()"
    if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
        _utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)

    # Unpack saved_tensor to trigger version detection that catches inplace corruption
    _ = ctx.saved_tensors

    # Use IO binding
    # Push user output grads to ONNX backend.
    backward_inputs = C.OrtValueVector()
    # Preallocate length of the vector. And then delete as required towards the end.
    backward_inputs.reserve(len(grad_outputs))
    for idx, grad_output in enumerate(grad_outputs):
        if idx in self._graph_info.output_grad_indices_non_differentiable:
            assert grad_output is None, ("ORT found the {}-th module output '{}' is "
                                         "non-differentiable according to the onnx graph. "
                                         "However, the gradient value is still provided by "
                                         "PyTorch's autograd engine.".format(
                                             idx, self._graph_info.user_output_names[idx]))
            continue

        if grad_output is None:
            shape, device, dtype = ctx.run_info.output_info[idx]
            if idx in self._graph_info.output_grad_indices_require_full_shape:
                grad_output = torch.zeros(shape, device=device, dtype=dtype)
            else:
                grad_output = torch.tensor(0.0, device=device, dtype=dtype)
        elif not grad_output.is_contiguous():
            grad_output = grad_output.contiguous()

        if grad_output.device.type == "ort":
            backward_inputs.push_back(C.aten_ort_tensor_to_ort_value(grad_output))
        else:
            backward_inputs.push_back(_utils._torch_tensor_to_dlpack(grad_output),
                                      grad_output.dtype is torch.bool)
    backward_inputs.shrink_to_fit()

    # Run and get results
    backward_outputs = C.OrtValueVector()
    self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)

    # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
    # affect peak memory usage in a subsequent graph run.
    del ctx.run_info.state

    # Fast version: all backward_outputs are converted first.
    # This version only works if backward_outputs is an OrtValueVector.
    transfered_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
    return tuple(transfered_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
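# Sketch of how self._gradient_map could be precomputed (an assumption: this mirrors
# the per-name loops of the earlier backward variant above, moved to one-time
# graph-build time). Each user input and each initializer maps to the index of its
# gradient in backward_outputs, or to -1 when no gradient is produced for it.
def _build_gradient_map(self):
    self._gradient_map = []
    require_grad_names_set = set(self._input_info.require_grad_names)
    require_grad_names_index = 0
    for input_name in self._graph_info.user_input_names:
        if input_name in require_grad_names_set:
            self._gradient_map.append(require_grad_names_index)
            require_grad_names_index += 1
        else:
            self._gradient_map.append(-1)
    # Initializer gradients follow the user-input gradients in backward_outputs.
    initializer_index = len(self._input_info.require_grad_names)
    for initializer_name in self._graph_info.initializer_names:
        if initializer_name in self._graph_initializer_names_to_train:
            self._gradient_map.append(initializer_index)
            initializer_index += 1
        else:
            self._gradient_map.append(-1)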