def test_forward_trace_functor():
    """Check how ForwardTraceOnly forwards tensor metas from inputs to outputs.

    Covers 1 -> 1, 1 -> N (broadcast), N -> N (positional pairing) and the
    M -> N count-mismatch cases, which must raise RuntimeError.
    """
    from nncf.dynamic_graph.patch_pytorch import ForwardTraceOnly
    from nncf.dynamic_graph.trace_tensor import TracedTensor, TensorMeta

    func = ForwardTraceOnly()
    shape1, shape2 = ([32, 1, 4, 8], [1, 8, 12, 16])
    meta1, meta2 = (TensorMeta(5, 1, shape1), TensorMeta(3, 8, shape2))

    input_tensor1 = TracedTensor.from_torch_tensor(torch.Tensor(shape1), meta1)
    input_tensor2 = TracedTensor.from_torch_tensor(torch.Tensor(shape2), meta2)

    # 1 -> 1
    output_tensor = func(torch.Tensor.view, input_tensor1, [-1])
    assert isinstance(output_tensor, TracedTensor)
    assert output_tensor.tensor_meta == input_tensor1.tensor_meta
    # The forwarded meta records the shape of the output, not of the input.
    assert tuple(output_tensor.tensor_meta.shape) == (32 * 1 * 4 * 8,)

    # 1 -> N: the single traced input is broadcast to every tensor output
    outputs = func(torch.Tensor.chunk, input_tensor1, 3)
    assert len(outputs) == 3  # 32 rows -> chunks of 11, 11, 10
    for out in outputs:
        assert isinstance(out, TracedTensor)
        assert out.tensor_meta == input_tensor1.tensor_meta

    # N -> N (2 -> 2): tensors are paired positionally; non-tensors untouched
    outputs = func(lambda x: x + [5], [input_tensor1, input_tensor2])
    assert outputs[0].tensor_meta == input_tensor1.tensor_meta
    assert outputs[1].tensor_meta == input_tensor2.tensor_meta
    assert outputs[2] == 5

    # M -> N (2 -> 3)
    with pytest.raises(RuntimeError, match="tensor count mismatch"):
        func(lambda x: x + [torch.Tensor(shape2)], [input_tensor1, input_tensor2])

    # M -> N (2 -> 1)
    with pytest.raises(RuntimeError, match="tensor count mismatch"):
        func(lambda x: x[0], [input_tensor1, input_tensor2])
def __call__(self, operator: Callable, *args, **kwargs):
    """
    This wrapper override will result in the operator not being added to graph,
    but the result will still have TracedTensors with parent IDs left the same as in input.
    Useful for operators which are not likely to be present in patterns considered for
    compression, but still have to be accounted for so that the NNCF internal graph representation
    does not become disjoint.

    :param operator: The callable to execute; it is called as ``operator(*args, **kwargs)``.
    :param args: Positional arguments for ``operator``.
    :param kwargs: Keyword arguments for ``operator``.
    :return: The operator's result.
        - If none of the flattened arguments is a TracedTensor, the result is returned
          untouched.
        - A list/tuple result is returned as a list / plain tuple. Its Tensor elements are
          re-wrapped as TracedTensors carrying a copy of the matching input's tensor_meta,
          with the shape updated to the output's shape. A single traced input is broadcast
          to every Tensor output; otherwise inputs and outputs are paired positionally.
        - A single Tensor result with exactly one traced input is wrapped the same way.
        - Any other result is returned untouched.
    :raises RuntimeError: If the number of traced inputs cannot be matched to the number
        of Tensor outputs.
    """
    result = operator(*args, **kwargs)
    # NOTE(review): flatten_args presumably flattens nested containers in args/kwargs
    # into one flat sequence - confirm against its definition.
    fargs = flatten_args(args, kwargs)
    input_traced_tensor_indices = [
        i for i, arg in enumerate(fargs) if isinstance(arg, TracedTensor)
    ]

    # No traced tensors in input - there is no meta to forward, so return the result
    # as-is, whether it is a single tensor or a container of tensors.
    if not input_traced_tensor_indices:
        return result

    def _mismatch_error():
        # Not every callable has __name__ (e.g. functools.partial); don't let building
        # the message mask the actual error.
        op_name = getattr(operator, "__name__", repr(operator))
        return RuntimeError(
            "Unable to forward trace through operator {} - "
            "input and output tensor count mismatch!".format(op_name))

    def _forward(source, target):
        # Copy the meta so each output keeps the source's creator info but records
        # its own shape without aliasing the input's meta object.
        forwarded_meta = deepcopy(source.tensor_meta)
        forwarded_meta.shape = tuple(target.shape)
        return TracedTensor.from_torch_tensor(target, forwarded_meta)

    if isinstance(result, (list, tuple)):
        was_tuple = isinstance(result, tuple)
        result = list(result)
        output_tensors_to_be_traced_indices = [
            i for i, out in enumerate(result) if isinstance(out, Tensor)
        ]
        if len(input_traced_tensor_indices) == 1:
            # Broadcast one and the same creator ID of input to all outputs
            source = fargs[input_traced_tensor_indices[0]]
            for out_idx in output_tensors_to_be_traced_indices:
                result[out_idx] = _forward(source, result[out_idx])
        elif len(input_traced_tensor_indices) != len(output_tensors_to_be_traced_indices):
            raise _mismatch_error()
        else:
            # Assume that output tensor order corresponds to input tensor order
            for in_idx, out_idx in zip(input_traced_tensor_indices,
                                       output_tensors_to_be_traced_indices):
                result[out_idx] = _forward(fargs[in_idx], result[out_idx])
        return tuple(result) if was_tuple else result

    if len(input_traced_tensor_indices) > 1:
        # Several traced inputs cannot be collapsed into a single output.
        raise _mismatch_error()

    if isinstance(result, Tensor):
        return _forward(fargs[input_traced_tensor_indices[0]], result)

    # Non-tensor result (e.g. a Python scalar): nothing to trace.
    return result