Example #1
    def execution_session_run_forward(execution_session, onnx_model, device,
                                      gradient_accumulation_manager, *inputs):
        """Runs the forward graph on execution_session with given model inputs and device"""

        # TODO: Try to reuse the output buffers, as some of the output tensors have the same sizes,
        #   especially the backward graph outputs.
        # REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on the C++ side to avoid
        # the need to pass IOBinding.
        state = C.PartialGraphExecutionState()
        forward_inputs = C.OrtValueVector()
        forward_inputs.reserve(len(inputs))
        for input in inputs:
            # TODO: Non-contiguous tensor input in execution_session_run_forward requires a tensor copy.
            if not input.is_contiguous():
                input = input.contiguous()
            if input.device.type == "ort":
                forward_inputs.push_back(C.aten_ort_tensor_to_ort_value(input))
            else:
                valid_ort_tensor = _utils._torch_tensor_to_dlpack(input)
                forward_inputs.push_back(valid_ort_tensor,
                                         input.dtype == torch.bool)

        forward_outputs = C.OrtValueVector()
        # Run and return module outputs.
        execution_session.run_forward(forward_inputs, forward_outputs, state,
                                      gradient_accumulation_manager.cache)

        user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(
            forward_outputs, device)

        output_info = [(output.shape, output.device, output.dtype)
                       for output in user_outputs]
        run_info = _RunStateInfo(state, output_info)
        # Return user outputs and forward run information
        return user_outputs, run_info
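
The loop above copies any non-contiguous input before handing it off via DLPack. A minimal sketch of what contiguity means here, using only public PyTorch APIs (the tensor and values are illustrative, not from the original):

import torch
from torch.utils.dlpack import to_dlpack, from_dlpack

# A transpose yields a non-contiguous view; the example above calls
# .contiguous() on such inputs before converting them for ORT.
t = torch.arange(6, dtype=torch.float32).reshape(2, 3).T
print(t.is_contiguous())             # False: strided view over the base storage
capsule = to_dlpack(t.contiguous())  # copy into contiguous memory first
roundtrip = from_dlpack(capsule)
print(torch.equal(roundtrip, t))     # True: same values, contiguous layout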
Example #2
def _ortvalue_from_torch_tensor(torch_tensor):
    # TODO: Current DLPack doesn't support bool, and a recent PyTorch commit disables converting bool tensors to DLPack.
    # https://github.com/pytorch/pytorch/blob/7e7be526c9d9179f35084e9cca5b5c5ad5172100/aten/src/ATen/DLConvertor.cpp#L41
    # We need to convert the bool tensor to a uint8 tensor to work around this.
    # The DLPack community is discussing how to support the bool type; this workaround can be
    # removed once both DLPack and PyTorch support bool.
    is_bool_tensor = torch_tensor.dtype == torch.bool
    if is_bool_tensor and LooseVersion(
            torch.__version__) >= LooseVersion("1.10.0"):
        torch_tensor = torch_tensor.to(torch.uint8)
    if torch_tensor.device.type == "ort":
        return C.aten_ort_tensor_to_ort_value(torch_tensor)
    return C.OrtValue.from_dlpack(to_dlpack(torch_tensor), is_bool_tensor)
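
The is_bool_tensor flag exists because recent PyTorch releases reject bool tensors in to_dlpack. A minimal sketch of the uint8 round trip this enables, in plain PyTorch rather than the ORT path (names here are illustrative):

import torch
from torch.utils.dlpack import to_dlpack, from_dlpack

mask = torch.tensor([True, False, True])

# Exchange a uint8 copy instead of the bool tensor, then cast back on
# the consumer side -- the role the is_bool_tensor flag plays for ORT.
capsule = to_dlpack(mask.to(torch.uint8))
restored = from_dlpack(capsule).to(torch.bool)
assert torch.equal(restored, mask)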
Example #3
    def execution_session_run_forward(execution_session, onnx_model, device,
                                      gradient_accumulation_manager, *inputs):
        """Runs the forward graph on execution_session with given model inputs and device"""

        # Clear all gradient functions to avoid a deadlock issue.
        # Check the called function for more detailed comments.
        clear_all_grad_fns()

        # TODO: Try to reuse the output buffers, as some of the output tensors have the same sizes,
        #   especially the backward graph outputs.
        # REVIEW(codemzs): Consolidate Training Agent with InferenceAgent on the C++ side to avoid
        # the need to pass IOBinding.
        state = C.PartialGraphExecutionState()
        forward_inputs = C.OrtValueVector()
        forward_inputs.reserve(len(inputs))
        for input in inputs:
            if input.device.type == 'ort':
                forward_inputs.push_back(C.aten_ort_tensor_to_ort_value(input))
            else:
                valid_ort_tensor = _utils._torch_tensor_to_dlpack(input)
                forward_inputs.push_back(valid_ort_tensor,
                                         input.dtype == torch.bool)

        forward_outputs = C.OrtValueVector()
        # Run and return module outputs.
        execution_session.run_forward(forward_inputs, forward_outputs, state,
                                      gradient_accumulation_manager.cache)

        user_outputs = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache(
            forward_outputs, device)

        output_info = [(output.shape, output.device, output.dtype)
                       for output in user_outputs]
        run_info = _RunStateInfo(state, output_info)
        # Return user outputs and forward run information
        return user_outputs, run_info
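
For context, this helper is presumably invoked from inside a torch.autograd.Function, which stashes the returned run_info on ctx so backward() can resume the partial graph execution (see Examples #4 and #5). A hedged sketch, with the wrapper class name hypothetical:

import torch

class _ORTForwardFunction(torch.autograd.Function):  # hypothetical name
    @staticmethod
    def forward(ctx, execution_session, onnx_model, device,
                gradient_accumulation_manager, *inputs):
        user_outputs, run_info = execution_session_run_forward(
            execution_session, onnx_model, device,
            gradient_accumulation_manager, *inputs)
        # Keep the PartialGraphExecutionState alive until backward() runs.
        ctx.run_info = run_info
        return user_outputs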
Example #4
            def backward(ctx, *grad_outputs):
                """Performs backward pass based on grad wrt module output"""

                assert ctx.run_info is not None, "forward() or __call__() methods must be called before backward()"
                if not self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE):
                    _utils._check_same_device(self._device,
                                              "Input argument to backward",
                                              *grad_outputs)

                # Unpack saved_tensors to trigger version detection that catches in-place corruption
                _ = ctx.saved_tensors

                # Use IO binding
                # Push user output grads to ONNX backend.
                backward_inputs = C.OrtValueVector()
                # Preallocate the vector's capacity, then shrink it to the needed size at the end.
                backward_inputs.reserve(len(grad_outputs))
                for idx, grad_output in enumerate(grad_outputs):
                    if idx in self._graph_info.output_grad_indices_non_differentiable:
                        assert grad_output is None, (
                            "ORT found the {}-th module output '{}' is "
                            "non-differentiable according to the onnx graph. "
                            "However, the gradient value is still provided by "
                            "PyTorch's autograd engine.".format(
                                idx, self._graph_info.user_output_names[idx]))
                        continue

                    if grad_output is None:
                        shape, device, dtype = ctx.run_info.output_info[idx]
                        if idx in self._graph_info.output_grad_indices_require_full_shape:
                            grad_output = torch.zeros(shape,
                                                      device=device,
                                                      dtype=dtype)
                        else:
                            grad_output = torch.tensor(0.0,
                                                       device=device,
                                                       dtype=dtype)
                    elif not grad_output.is_contiguous():
                        grad_output = grad_output.contiguous()
                    if grad_output.device.type == "ort":
                        backward_inputs.push_back(
                            C.aten_ort_tensor_to_ort_value(grad_output))
                    else:
                        backward_inputs.push_back(
                            _utils._torch_tensor_to_dlpack(grad_output),
                            grad_output.dtype == torch.bool)
                backward_inputs.shrink_to_fit()

                # Run and get results
                backward_outputs = C.OrtValueVector()
                self._execution_agent.run_backward(backward_inputs,
                                                   backward_outputs,
                                                   ctx.run_info.state)
                # Destroy the state immediately (rather than leaving it to the garbage collector) so it
                # does not affect peak memory usage in a subsequent graph run.
                del ctx.run_info.state

                # Fast version: all backward_outputs are converted first.
                # This version only works if backward_outputs is an OrtValueVector.
                transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(
                    backward_outputs, self._device)
                return tuple(
                    transferred_backward_outputs[idx] if idx != -1 else None
                    for idx in self._gradient_map)
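
The final lookup relies on self._gradient_map: entry i holds the position of input i's gradient inside backward_outputs, or -1 when that input needs no gradient. A toy illustration with made-up values:

# Stand-ins for the converted gradient tensors (hypothetical values).
transferred = ["grad_x", "grad_w"]
gradient_map = [0, -1, 1]  # x gets grads, the middle input does not, w does

grads = tuple(transferred[idx] if idx != -1 else None
              for idx in gradient_map)
assert grads == ("grad_x", None, "grad_w")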
Example #5
            def backward(ctx, *grad_outputs):
                '''Performs backward pass based on grad wrt module output'''

                assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()'
                if not self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE):
                    _utils._check_same_device(self._device,
                                              "Input argument to backward",
                                              *grad_outputs)

                # Unpack saved_tensors to trigger version detection that catches in-place corruption
                _ = ctx.saved_tensors

                # Use IO binding
                # Push user output grads to ONNX backend.
                backward_inputs = C.OrtValueVector()
                # Preallocate the vector's capacity, then shrink it to the needed size at the end.
                backward_inputs.reserve(len(grad_outputs))
                for idx, grad_output in enumerate(grad_outputs):
                    if idx in self._graph_info.output_grad_indices_non_differentiable:
                        assert grad_output is None, "ORT found the {}-th module output '{}' is " \
                                                    "non-differentiable according to the onnx graph. " \
                                                    "However, the gradient value is still provided by " \
                                                    "PyTorch's autograd engine." \
                                                    .format(idx, self._graph_info.user_output_names[idx])
                        continue

                    if grad_output is None:
                        shape, device, dtype = ctx.run_info.output_info[idx]
                        if idx in self._graph_info.output_grad_indices_require_full_shape:
                            grad_output = torch.zeros(shape,
                                                      device=device,
                                                      dtype=dtype)
                        else:
                            grad_output = torch.tensor(0.,
                                                       device=device,
                                                       dtype=dtype)
                    elif not grad_output.is_contiguous():
                        grad_output = grad_output.contiguous()
                    if grad_output.device.type == 'ort':
                        backward_inputs.push_back(
                            C.aten_ort_tensor_to_ort_value(grad_output))
                    else:
                        backward_inputs.push_back(
                            _utils._torch_tensor_to_dlpack(grad_output),
                            grad_output.dtype == torch.bool)
                backward_inputs.shrink_to_fit()

                # Run and get results
                backward_outputs = C.OrtValueVector()
                self._execution_agent.run_backward(backward_inputs,
                                                   backward_outputs,
                                                   ctx.run_info.state)
                # Destroy the state immediately (rather than leaving it to the garbage collector) so it
                # does not affect peak memory usage in a subsequent graph run.
                del ctx.run_info.state
                # Return input and initializer gradients
                num_user_input_grads = len(self._input_info.require_grad_names)
                results = []
                require_grad_names_set = set(
                    self._input_info.require_grad_names)
                require_grad_names_index = 0
                for input_name in self._graph_info.user_input_names:
                    # Append the backward output for each input that required grad
                    if input_name in require_grad_names_set:
                        results.append(
                            _utils._ortvalue_to_torch_tensor(
                                backward_outputs[require_grad_names_index],
                                self._device))
                        require_grad_names_index += 1
                    else:
                        # input_name is not found in the self._input_info.require_grad_names list
                        # Append None to results for each input that did not require grad
                        results.append(None)

                # Append gradients of initializer to results
                # Go over each initializer, check if it required grad and append to results accordingly
                initializer_index = num_user_input_grads
                for initializer_name in self._graph_info.initializer_names:
                    if initializer_name in self._graph_initializer_names_to_train:
                        results.append(
                            _utils._ortvalue_to_torch_tensor(
                                backward_outputs[initializer_index],
                                self._device))
                        initializer_index += 1
                    else:
                        results.append(None)

                return tuple(results)
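
This slower path depends on an ordering contract: backward_outputs lists gradients for the require-grad user inputs first, then gradients for the trainable initializers, and the two loops slot None everywhere else. A toy sketch of that interleaving with made-up names (the real code walks inputs and initializers in two separate loops, but a single cursor shows the idea):

# Hypothetical graph inputs and initializers.
user_input_names = ["x", "mask"]
require_grad_names = {"x"}
initializer_names = ["weight", "frozen_stat"]
trainable_initializers = {"weight"}
backward_outputs = ["grad_x", "grad_weight"]  # ordered: inputs, then initializers

results, cursor = [], 0
for name in user_input_names + initializer_names:
    if name in require_grad_names | trainable_initializers:
        results.append(backward_outputs[cursor])
        cursor += 1
    else:
        results.append(None)
assert tuple(results) == ("grad_x", None, "grad_weight", None)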