コード例 #1
0
def test_forward_trace_functor():
    """Check how ``ForwardTraceOnly`` forwards ``TensorMeta`` from traced inputs to outputs.

    Covers one traced input feeding one output (``view``), one traced input
    feeding several outputs (``chunk``), a pairwise N -> N mapping, and the two
    input/output count mismatches (2 -> 3 and 2 -> 1) that must raise
    ``RuntimeError``.
    """
    from nncf.dynamic_graph.patch_pytorch import ForwardTraceOnly
    from nncf.dynamic_graph.trace_tensor import TracedTensor, TensorMeta

    func = ForwardTraceOnly()
    shape1, shape2 = ([32, 1, 4, 8], [1, 8, 12, 16])
    meta1, meta2 = (TensorMeta(5, 1, shape1), TensorMeta(3, 8, shape2))
    # torch.empty(shape) allocates a tensor *of that shape*. The legacy
    # torch.Tensor(list) constructor would instead build a 1-D tensor whose
    # values are the list elements, so the tensors would not match their metas.
    input_tensor1 = TracedTensor.from_torch_tensor(torch.empty(shape1), meta1)
    input_tensor2 = TracedTensor.from_torch_tensor(torch.empty(shape2), meta2)

    # 1 -> 1
    output_tensor = func(torch.Tensor.view, input_tensor1, [-1])
    assert output_tensor.tensor_meta == input_tensor1.tensor_meta
    # The forwarded meta must describe the actual output, not the input.
    assert output_tensor.tensor_meta.shape == tuple(output_tensor.shape)

    # 1 -> N
    outputs = func(torch.Tensor.chunk, input_tensor1, 3)
    assert len(outputs) == 3  # 32 rows along dim 0 -> chunks of 11, 11, 10
    for out in outputs:
        assert out.tensor_meta == input_tensor1.tensor_meta
        assert out.tensor_meta.shape == tuple(out.shape)

    # N -> N (2 -> 2); the non-tensor element 5 is ignored when counting outputs
    outputs = func(lambda x: x + [5], [input_tensor1, input_tensor2])
    assert outputs[0].tensor_meta == input_tensor1.tensor_meta
    assert outputs[1].tensor_meta == input_tensor2.tensor_meta

    # M -> N (2 -> 3): more tensor outputs than traced inputs
    with pytest.raises(RuntimeError):
        func(lambda x: x + [torch.empty(shape2)],
             [input_tensor1, input_tensor2])

    # M -> N (2 -> 1): several traced inputs collapse into a single output
    with pytest.raises(RuntimeError):
        func(lambda x: x[0], [input_tensor1, input_tensor2])
コード例 #2
0
    def __call__(self, operator: Callable, *args, **kwargs):
        """Run ``operator`` without adding it to the graph, but keep tracing through it.

        The result still contains TracedTensors whose metadata (parent IDs) is copied
        from the traced inputs; only the ``shape`` is updated to the actual output
        shape. Useful for operators which are not likely to be present in patterns
        considered for compression, but still have to be accounted for so that the
        NNCF internal graph representation does not become disjoint.

        The mapping counts only TracedTensor inputs (after flattening args/kwargs)
        and Tensor outputs (a non-list/tuple result counts as one candidate output):

        * no traced inputs - the operator result is returned untouched;
        * one traced input - its metadata is broadcast to every output tensor;
        * N traced inputs - output tensor k inherits from traced input k, which
          requires exactly N output tensors.

        :param operator: The callable to run, e.g. ``torch.Tensor.view``.
        :param args: Positional arguments passed through to ``operator``.
        :param kwargs: Keyword arguments passed through to ``operator``.
        :return: The operator result with its tensor outputs re-wrapped as
            TracedTensors. When traced inputs exist, a list/tuple result is
            returned as a plain ``list``/``tuple``.
        :raises RuntimeError: If there are several traced inputs and their number
            differs from the number of tensor outputs.
        """
        result = operator(*args, **kwargs)

        traced_inputs = [arg for arg in flatten_args(args, kwargs)
                         if isinstance(arg, TracedTensor)]
        if not traced_inputs:
            # No traced tensors in input, return a usual result as well
            # (for sequences too, instead of reporting a count mismatch).
            return result

        is_sequence = isinstance(result, (list, tuple))
        # Treat a single result as a one-element sequence so that both cases
        # share the forwarding logic below.
        outputs = list(result) if is_sequence else [result]
        tensor_output_indices = [idx for idx, out in enumerate(outputs)
                                 if isinstance(out, Tensor)]

        if len(traced_inputs) == 1:
            # Broadcast one and the same creator ID of input to all outputs
            sources = traced_inputs * len(tensor_output_indices)
        elif len(traced_inputs) == len(tensor_output_indices):
            # Assume that output tensor order corresponds to input tensor order
            sources = traced_inputs
        else:
            # getattr: not every callable (e.g. functools.partial) has __name__,
            # and failing here would hide the real error.
            operator_name = getattr(operator, '__name__', repr(operator))
            raise RuntimeError(
                "Unable to forward trace through operator {} - "
                "input and output tensor count mismatch!".format(
                    operator_name))

        for source, out_idx in zip(sources, tensor_output_indices):
            # Copy per output so that overwriting the shape touches neither the
            # input's metadata nor that of the other outputs.
            forwarded_meta = deepcopy(source.tensor_meta)
            forwarded_meta.shape = tuple(outputs[out_idx].shape)
            outputs[out_idx] = TracedTensor.from_torch_tensor(
                outputs[out_idx], forwarded_meta)

        if not is_sequence:
            return outputs[0]
        # NOTE(review): a tuple subclass (e.g. a named tuple) comes back as a
        # plain tuple here, as in the original implementation.
        return tuple(outputs) if isinstance(result, tuple) else outputs