Example #1
    def execution_session_run_forward(execution_session, onnx_model, device,
                                      *inputs):
        """Runs the forward graph on execution_session with given model inputs and device"""

        # Assert that the input and model device match
        _utils._check_same_device(device, "Input argument to forward", *inputs)

        # TODO: Try to reuse the output buffers, as some of the output tensors are the same size,
        #   especially the backward graph outputs.
        # REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on the C++ side to
        # avoid the need to pass IOBinding.
        state = C.PartialGraphExecutionState()
        forward_inputs = C.OrtValueVector()
        for input in inputs:
            forward_inputs.append(_utils._ortvalue_from_torch_tensor(input))

        forward_outputs = C.OrtValueVector()
        # Run and return module outputs.
        execution_session.run_forward(forward_inputs, forward_outputs, state)
        user_outputs = tuple(
            _utils._ortvalue_to_torch_tensor(forward_output)
            for forward_output in forward_outputs)

        # Assert that the outputs and model device match
        _utils._check_same_device(device, "Output argument from forward",
                                  *user_outputs)

        output_info = [(output.shape, output.device, output.dtype)
                       for output in user_outputs]
        run_info = RunStateInfo(state, output_info)
        # Return user outputs and forward run information
        return user_outputs, run_info
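
For context, a hedged sketch of how the returned pair is typically consumed: run_info is stashed on ctx inside a torch.autograd.Function.forward so that backward (see Examples #7 and #8, which read ctx.run_info.state and reach self._device through a closure) can resume the partial graph execution. The class name and wiring below are assumptions, not the actual ORTModule code.

    # A minimal sketch; assumes it is defined inside a module method so that
    # `self` is captured by closure, as in Examples #7 and #8.
    class _ORTModuleFunction(torch.autograd.Function):  # hypothetical name
        @staticmethod
        def forward(ctx, *inputs):
            user_outputs, run_info = execution_session_run_forward(
                self._execution_agent, self._onnx_model, self._device, *inputs)
            # Stash the run information; backward() later reads
            # ctx.run_info.state and ctx.run_info.output_info.
            ctx.run_info = run_info
            return user_outputs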
Example #2
    def execution_session_run_forward(execution_session, onnx_model, device,
                                      gradient_accumulation_manager, *inputs):
        """Runs the forward graph on execution_session with given model inputs and device"""

        # TODO: Try to reuse the output buffers, as some of the output tensors are the same size,
        #   especially the backward graph outputs.
        # REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on the C++ side to
        # avoid the need to pass IOBinding.
        state = C.PartialGraphExecutionState()
        forward_inputs = C.OrtValueVector()
        forward_inputs.reserve(len(inputs))
        for input in inputs:
            # TODO: A non-contiguous tensor input to execution_session_run_forward requires a tensor copy.
            if not input.is_contiguous():
                input = input.contiguous()
            if input.device.type == "ort":
                forward_inputs.push_back(C.aten_ort_tensor_to_ort_value(input))
            else:
                valid_ort_tensor = _utils._torch_tensor_to_dlpack(input)
                forward_inputs.push_back(valid_ort_tensor,
                                         input.dtype == torch.bool)

        forward_outputs = C.OrtValueVector()
        # Run and return module outputs.
        execution_session.run_forward(forward_inputs, forward_outputs, state,
                                      gradient_accumulation_manager.cache)

        user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(
            forward_outputs, device)

        output_info = [(output.shape, output.device, output.dtype)
                       for output in user_outputs]
        run_info = _RunStateInfo(state, output_info)
        # Return user outputs and forward run information
        return user_outputs, run_info
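
Compared with Example #1, this variant enforces contiguity, routes "ort"-device tensors through aten_ort_tensor_to_ort_value, and converts everything else via DLPack with an extra boolean flag. A hedged reading of that flag: the DLPack exchange used here has no bool data type, so a torch.bool tensor travels as raw byte data and the flag tells ORT to reinterpret it as bool; this rationale is inferred from the call sites, not from documented API behavior.

    # A minimal sketch of the conversion step, reusing only calls visible above;
    # the onnxruntime.training.ortmodule._utils import path is an assumption.
    import torch
    from onnxruntime.capi import _pybind_state as C
    from onnxruntime.training.ortmodule import _utils

    t = torch.ones(3, dtype=torch.bool)
    vec = C.OrtValueVector()
    vec.push_back(_utils._torch_tensor_to_dlpack(t), t.dtype == torch.bool)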
Example #3
    def execution_session_run_forward(execution_session, onnx_model, *inputs):
        """Runs the forward graph on execution_session with given model inputs and device"""

        # TODO: Try to reuse the output buffers, as some of the output tensors are the same size,
        #   especially the backward graph outputs.
        # REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on the C++ side to
        # avoid the need to pass IOBinding.
        state = C.PartialGraphExecutionState()
        forward_inputs = C.OrtValueVector()
        forward_inputs.reserve(len(inputs))
        for input in inputs:
            forward_inputs.push_back(to_dlpack(input), input.dtype == torch.bool)

        forward_outputs = C.OrtValueVector()
        # Run and return module outputs.
        execution_session.run_forward(forward_inputs, forward_outputs, state)
        user_outputs = tuple(_utils._ortvalue_to_torch_tensor(forward_output) for forward_output in forward_outputs)

        output_info = [(output.shape, output.device, output.dtype) for output in user_outputs]
        run_info = RunStateInfo(state, output_info)
        # Return user outputs and forward run information
        return user_outputs, run_info
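
This is the simplest variant: no device checks, no "ort"-device branch, and PyTorch's stock DLPack export is called directly. A hedged sketch of the imports the snippet appears to assume; these are inferred from the code above rather than copied from the source file:

    import torch
    from torch.utils.dlpack import to_dlpack          # DLPack export used above
    from onnxruntime.capi import _pybind_state as C   # provides OrtValueVector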
Example #4
    def execution_session_run_forward(execution_session, onnx_model, device,
                                      gradient_accumulation_manager, *inputs):
        """Runs the forward graph on execution_session with given model inputs and device"""

        # Clear all gradient functions to avoid a deadlock issue.
        # See the called function for more detailed comments.
        clear_all_grad_fns()

        # TODO: Try to reuse the output buffers, as some of the output tensors are the same size,
        #   especially the backward graph outputs.
        # REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on the C++ side to
        # avoid the need to pass IOBinding.
        state = C.PartialGraphExecutionState()
        forward_inputs = C.OrtValueVector()
        forward_inputs.reserve(len(inputs))
        for input in inputs:
            if input.device.type == 'ort':
                forward_inputs.push_back(C.aten_ort_tensor_to_ort_value(input))
            else:
                valid_ort_tensor = _utils._torch_tensor_to_dlpack(input)
                forward_inputs.push_back(valid_ort_tensor,
                                         input.dtype == torch.bool)

        forward_outputs = C.OrtValueVector()
        # Run and return module outputs.
        execution_session.run_forward(forward_inputs, forward_outputs, state,
                                      gradient_accumulation_manager.cache)

        user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(
            forward_outputs, device)

        output_info = [(output.shape, output.device, output.dtype)
                       for output in user_outputs]
        run_info = _RunStateInfo(state, output_info)
        # Return user outputs and forward run information
        return user_outputs, run_info
Example #5
    def extract_outputs_and_maybe_update_cache(self, forward_outputs, device):
        """Extract the user outputs from the forward outputs as torch tensor and update cache, if needed

        Args:
            forward_outputs (OrtValueVector): List of outputs returned by forward function
        """
        if not self.enabled:
            return _utils._ortvalues_to_torch_tensor(forward_outputs, device)
        if self._update_cache:
            for i in range(self._cache_start, len(forward_outputs)):
                self.cache.insert(self._cached_node_arg_names[i - self._cache_start], forward_outputs[i])
            self._update_cache = False
        ort_value_vector = C.OrtValueVector()
        ort_value_vector.reserve(self._cache_start)
        for i in range(self._cache_start):
            ort_value_vector.push_back(forward_outputs[i])
        return _utils._ortvalues_to_torch_tensor(ort_value_vector, device)  # pylint: disable=W0212
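
A hedged walk-through of the index arithmetic above: outputs in [0, _cache_start) are the user-visible outputs, while outputs in [_cache_start, len(forward_outputs)) are cached node args keyed by _cached_node_arg_names. The concrete values below are made up for illustration:

    # Hypothetical trace with _cache_start == 2 and
    # _cached_node_arg_names == ["w1_cached"]:
    #
    #   forward_outputs = [out0, out1, cached0]
    #
    # With _update_cache set, cached0 (index 2) is inserted into self.cache
    # under _cached_node_arg_names[2 - _cache_start] == "w1_cached"; only
    # out0 and out1 are then converted to torch tensors on `device`.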
Example #6
    def _ortvalues_to_torch_tensor_list(self, device, tensor_type, new_impl):
        narrays = [
            np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32),
            np.array([[6.0, 7.0], [8.0, 9.0], [1.0, 6.0]], dtype=np.float32),
        ]
        vect = C.OrtValueVector()
        vect.reserve(len(narrays))
        ptr = []
        for a in narrays:
            ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(a, device.type if device.type != "ort" else "cpu")
            vect.push_back(ortvalue._ortvalue)
            ptr.append(ortvalue.data_ptr())
        self.assertEqual(len(vect), 2)
        if new_impl:
            tensors = _utils._ortvalues_to_torch_tensor(vect, device)
        else:
            tensors = _ortvalues_to_torch_tensor(vect, device)
        self.assertEqual(len(tensors), len(vect))
        self.assertEqual(ptr, [t.data_ptr() for t in tensors])
        assert all(isinstance(v, tensor_type) for v in tensors)
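
The data_ptr comparison checks that the conversion is zero-copy: the torch tensors alias the memory backing the OrtValues. A plausible invocation of this helper, with new_impl toggling between the packaged _utils._ortvalues_to_torch_tensor and a locally imported legacy _ortvalues_to_torch_tensor (the parameter values below are assumptions):

    # Hypothetical call sites inside the same unittest class:
    self._ortvalues_to_torch_tensor_list(torch.device("cpu"), torch.Tensor, new_impl=True)
    self._ortvalues_to_torch_tensor_list(torch.device("cpu"), torch.Tensor, new_impl=False)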
Example #7
            def backward(ctx, *grad_outputs):
                '''Performs backward pass based on grad wrt module output'''

                assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()'
                _utils._check_same_device(self._device,
                                          "Input argument to backward",
                                          *grad_outputs)

                # Use IO binding
                # Push user output grads to ONNX backend.
                contiguous_grad_outputs = []
                for idx, grad_output in enumerate(grad_outputs):
                    if idx in self._graph_info.output_grad_indices_non_differentiable:
                        assert grad_output is None, "ORT found the {}-th module output '{}' is " \
                                                    "non-differentiable according to the onnx graph. " \
                                                    "However, the gradient value is still provided by " \
                                                    "PyTorch's autograd engine." \
                                                    .format(idx, self._graph_info.user_output_names[idx])
                        continue

                    if grad_output is None:
                        shape, device, dtype = ctx.run_info.output_info[idx]
                        if idx in self._graph_info.output_grad_indices_require_full_shape:
                            grad_output = torch.zeros(shape,
                                                      device=device,
                                                      dtype=dtype)
                        else:
                            grad_output = torch.tensor(0.,
                                                       device=device,
                                                       dtype=dtype)
                    elif not grad_output.is_contiguous():
                        grad_output = grad_output.contiguous()
                    contiguous_grad_outputs.append(grad_output)

                # Run and get results
                backward_inputs = C.OrtValueVector()
                for input in contiguous_grad_outputs:
                    backward_inputs.append(
                        _utils._ortvalue_from_torch_tensor(input))

                backward_outputs = C.OrtValueVector()
                self._execution_agent.run_backward(backward_inputs,
                                                   backward_outputs,
                                                   ctx.run_info.state)
                # Destroy the state immediately (rather than leaving it to the garbage collector)
                # so it does not affect peak memory usage in a subsequent graph run.
                del ctx.run_info.state
                # Return input and initializer gradients
                num_user_input_grads = len(self._input_info.require_grad_names)
                results = []
                require_grad_names_set = set(
                    self._input_info.require_grad_names)
                require_grad_names_index = 0
                for input_name in self._graph_info.user_input_names:
                    # Append to the results the backward output for each input that required grad
                    if input_name in require_grad_names_set:
                        results.append(
                            _utils._ortvalue_to_torch_tensor(
                                backward_outputs[require_grad_names_index]))
                        require_grad_names_index += 1
                    else:
                        # input_name is not found in the self._input_info.require_grad_names list
                        # Append None to results for each input that did not require grad
                        results.append(None)

                # Append gradients of initializer to results
                # Go over each initializer, check if it required grad and append to results accordingly
                initializer_index = num_user_input_grads
                for initializer_name in self._graph_info.initializer_names:
                    if initializer_name in self._graph_initializer_names_to_train:
                        results.append(
                            _utils._ortvalue_to_torch_tensor(
                                backward_outputs[initializer_index]))
                        initializer_index += 1
                    else:
                        results.append(None)

                return tuple(results)
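
The ordering logic above satisfies PyTorch's autograd contract: backward must return exactly one entry per forward argument, positionally aligned, with None for arguments that cannot receive a gradient. A toy, self-contained illustration of that contract (unrelated to ORT internals):

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, factor):  # factor: plain float, never differentiable
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_out):
            # One return value per forward argument, in the same order;
            # None where no gradient is defined (here: factor).
            return grad_out * ctx.factor, None

    x = torch.ones(2, requires_grad=True)
    Scale.apply(x, 3.0).sum().backward()
    assert torch.equal(x.grad, torch.full((2,), 3.0))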
Example #8
            def backward(ctx, *grad_outputs):
                """Performs backward pass based on grad wrt module output"""

                assert ctx.run_info is not None, "forward() or __call__() methods must be called before backward()"
                if not self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE):
                    _utils._check_same_device(self._device,
                                              "Input argument to backward",
                                              *grad_outputs)

                # Unpack saved_tensors to trigger version detection that catches in-place corruption
                _ = ctx.saved_tensors

                # Use IO binding
                # Push user output grads to ONNX backend.
                backward_inputs = C.OrtValueVector()
                # Preallocate capacity for the vector, then shrink it to the actual size at the end.
                backward_inputs.reserve(len(grad_outputs))
                for idx, grad_output in enumerate(grad_outputs):
                    if idx in self._graph_info.output_grad_indices_non_differentiable:
                        assert grad_output is None, (
                            "ORT found the {}-th module output '{}' is "
                            "non-differentiable according to the onnx graph. "
                            "However, the gradient value is still provided by "
                            "PyTorch's autograd engine.".format(
                                idx, self._graph_info.user_output_names[idx]))
                        continue

                    if grad_output is None:
                        shape, device, dtype = ctx.run_info.output_info[idx]
                        if idx in self._graph_info.output_grad_indices_require_full_shape:
                            grad_output = torch.zeros(shape,
                                                      device=device,
                                                      dtype=dtype)
                        else:
                            grad_output = torch.tensor(0.0,
                                                       device=device,
                                                       dtype=dtype)
                    elif not grad_output.is_contiguous():
                        grad_output = grad_output.contiguous()
                    if grad_output.device.type == "ort":
                        backward_inputs.push_back(
                            C.aten_ort_tensor_to_ort_value(grad_output))
                    else:
                        backward_inputs.push_back(
                            _utils._torch_tensor_to_dlpack(grad_output),
                            grad_output.dtype == torch.bool)
                backward_inputs.shrink_to_fit()

                # Run and get results
                backward_outputs = C.OrtValueVector()
                self._execution_agent.run_backward(backward_inputs,
                                                   backward_outputs,
                                                   ctx.run_info.state)
                # Destroy the state immediately (rather than leaving it to the garbage collector)
                # so it does not affect peak memory usage in a subsequent graph run.
                del ctx.run_info.state

                # Fast version: all backward_outputs are converted first.
                # This version only works if backward_outputs is an OrtValueVector.
                transfered_backward_outputs = _utils._ortvalues_to_torch_tensor(
                    backward_outputs, self._device)
                return tuple(
                    transfered_backward_outputs[idx] if idx != -1 else None
                    for idx in self._gradient_map)
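
Example #8 replaces Example #7's per-name loops with a precomputed self._gradient_map: one entry per expected gradient slot, holding either an index into backward_outputs or -1 where autograd expects None. A hedged sketch of how such a map could be built, inferred from Example #7's logic rather than taken from the source:

    # Hypothetical construction of the map; all names are assumptions.
    gradient_map = []
    num_grads = 0
    for name in user_input_names:            # same order as Example #7's loop
        if name in require_grad_names_set:
            gradient_map.append(num_grads)   # position in backward_outputs
            num_grads += 1
        else:
            gradient_map.append(-1)          # autograd receives None here
    for name in initializer_names:
        if name in initializer_names_to_train:
            gradient_map.append(num_grads)
            num_grads += 1
        else:
            gradient_map.append(-1)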