def infer_shapes(
        graph,  # type: ONNXGraph
        source_shapes=None,
        # type: typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None]
        custom_shapes=None,  # type: typing.Optional[typing.Dict[str, typing.Callable]]
):
    # type: (...)->None
    """Infer output shapes and dtypes for every operation of an ONNX graph, in place.

    The graph is sorted, its input shapes are fixed with
    ``shape_fixer.fix_input_shapes(graph, source_shapes)``, then the shape function
    of every operation is called and its results are written to the operation's
    outputs. Finally ``graph_utils.remove_unreachable`` prunes the graph.

    :param graph: the graph to process; modified in place.
    :param source_shapes: input shape specification, forwarded to ``fix_input_shapes``.
    :param custom_shapes: optional mapping of op name -> shape function. Entries
        override those of ``_DefaultShapes``. A shape function takes the op and
        returns ``(shapes, dtypes)`` with one entry per output.
    """
    shape_function_by_name = dict(_DefaultShapes)
    shape_function_by_name.update(custom_shapes or {})

    graph.sort()
    shape_fixer.fix_input_shapes(graph, source_shapes)

    for op in graph.operations:
        assert op.name in shape_function_by_name, "No shape function for {}".format(op.name)
        shapes, dtypes = shape_function_by_name[op.name](op)

        # Sanity-check what the shape function returned.
        assert not utils.has_le_0(shapes)
        assert len(shapes) == len(dtypes) == len(op.outputs)

        for output, shape, dtype in zip(op.outputs, shapes, dtypes):
            # Any shape/dtype already attached to the tensor must agree with the inferred one.
            assert utils.compatible_shapes(output.shape, shape)
            output.shape = shape
            assert output.dtype is None or output.dtype == dtype
            output.dtype = dtype

    graph_utils.remove_unreachable(graph)
# Example #2
# 0
def fix_input_shapes(graph, source_shapes):
    # type: (BaseGraph, typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None])->None
    """Resolve the shapes of the graph's input (placeholder) tensors, in place.

    A placeholder is a tensor with no producer that is neither constant nor variable.

    ``source_shapes`` may be:
      * ``None``: keep the existing shapes; every ``-1`` dimension becomes 1
        (a warning is printed);
      * a list (or tuple) of ints: used as the full shape of every placeholder;
      * an int: used to replace every ``-1`` dimension of every placeholder;
      * a dict mapping tensor name to one of the above; placeholders missing
        from the dict behave as if ``None`` was given.

    Raises:
        utils.NNEFToolsException: if a given full shape is incompatible with the
            original shape, if a placeholder's shape is unknown and no full shape
            was given, if a shape specification has an unsupported type, or if a
            placeholder has no dtype.
    """

    def get_shape_for(name):
        # A dict holds per-tensor specs; any other value applies to every placeholder.
        spec = source_shapes.get(name) if isinstance(source_shapes, dict) else source_shapes
        if isinstance(spec, (list, tuple)):
            # Always copy, so tensor.shape never aliases the caller's list.
            return list(spec)
        if spec is not None and utils.is_anyint(spec):
            return utils.anyint_to_int(spec)
        # None (no spec) or an unsupported value, rejected by the caller below.
        return spec

    placeholders = [
        tensor for tensor in graph.tensors if len(tensor.producers) == 0
        and not tensor.is_constant and not tensor.is_variable
    ]

    if source_shapes is None:
        # Nothing was specified, but some input shape is unknown or partial:
        # show all input shapes so the user knows what to pass.
        if any(tensor.shape is None or -1 in tensor.shape
               for tensor in placeholders):
            for tensor in placeholders:
                print("Info: Input shape: {}: {}".format(
                    tensor.name, tensor.shape))

    for tensor in placeholders:  # type: BaseTensor
        shape_for_this = get_shape_for(tensor.name) if tensor.name else None
        if isinstance(shape_for_this, list):
            # A full shape was given: it must agree with whatever is already known.
            if not utils.compatible_shapes(tensor.shape, shape_for_this):
                raise utils.NNEFToolsException(
                    "The specified shape is incompatible with the original shape for {}. {} vs. {}"
                    .format(tensor.name, shape_for_this, tensor.shape))
            tensor.shape = shape_for_this
        elif shape_for_this is None or isinstance(shape_for_this, int):
            # Only the unknown (-1) dimensions can be filled from an int (default: 1).
            if tensor.shape is None:
                raise utils.NNEFToolsException(
                    "The full shape must be specified for {}, because it is unknown."
                    .format(tensor.name))
            elif -1 in tensor.shape:
                if shape_for_this is None:
                    shape_for_this = 1
                    print(
                        "Warning: Incomplete input shape is auto-fixed: {}. {} -> {}. "
                        "Use --input-shape if other shape is desired.".format(
                            tensor.name, tensor.shape, [
                                shape_for_this if dim == -1 else dim
                                for dim in tensor.shape
                            ]))
                tensor.shape = [
                    shape_for_this if dim == -1 else dim
                    for dim in tensor.shape
                ]
        else:
            # Raise instead of `assert False`: asserts are stripped under `python -O`,
            # which would silently ignore the user's shape specification.
            raise utils.NNEFToolsException(
                "Unsupported input shape specification for {}: {!r}. "
                "Expected a list of ints, an int or None.".format(tensor.name, shape_for_this))

        if tensor.dtype is None:
            raise utils.NNEFToolsException(
                "An input tensor has incomplete dtype, "
                "we have thought that this is impossible, "
                "please file a bug report to NNEF Tools.")
def evaluate_and_convert(tf_graph, source_shapes=None):
    # type: (TFGraph, typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None])->None
    """Propagate shapes, evaluate constants and convert every op of a TF graph, in place.

    Steps, in order:
      1. Sort the graph (presumably topologically -- TODO confirm in TFGraph.sort).
      2. If ``source_shapes`` is a dict, append ``':0'`` to every key without a
         ``':'``, then fix the input shapes with ``shape_fixer.fix_input_shapes``.
      3. Seed the constant table: constant tensors are evaluated with
         ``tf_pb_eval._evaluate_constant``; variable tensors contribute ``data``.
      4. For each operation: propagate shapes/dtypes, run its evaluator (if it
         has one), then run its converter from ``DefaultConverters``.
      5. Map each tensor dtype through ``_tf_py_dtype_by_tf_pb_dtype`` (unmapped -> None).
      6. Label each variable tensor after its name, with a trailing ``':0'``
         dropped and any remaining ``':'`` replaced by ``'_'``.

    Raises:
        utils.NNEFToolsException: if an operation has no shape propagator.
    """
    tf_graph.sort()

    if isinstance(source_shapes, dict):
        source_shapes = {
            (key if ':' in key else key + ':0'): value
            for key, value in six.iteritems(source_shapes)
        }
    shape_fixer.fix_input_shapes(tf_graph, source_shapes)

    const_value_by_tensor = {}
    for tensor in tf_graph.tensors:
        if tensor.is_constant:
            const_value_by_tensor[tensor] = tf_pb_eval._evaluate_constant(tensor)
        elif tensor.is_variable:
            const_value_by_tensor[tensor] = tensor.data

    propagators = tf_pb_shape_inference._DefaultPropagators
    evaluators = tf_pb_eval._DefaultOpEvaluators

    for op in tf_graph.operations:
        # 1) shape/dtype propagation
        if op.name not in propagators:
            raise utils.NNEFToolsException(
                "Operation '{}' is not supported".format(op.name))
        new_shapes, new_dtypes = propagators[op.name](op, const_value_by_tensor)
        assert not utils.has_le_0(new_shapes)
        assert len(new_shapes) == len(new_dtypes) == len(op.outputs)
        for output, shape, dtype in zip(op.outputs, new_shapes, new_dtypes):
            assert utils.compatible_shapes(output.shape, shape)
            output.shape = shape
            assert output.dtype is None or output.dtype == dtype
            output.dtype = dtype

        # 2) constant evaluation, only for ops that have an evaluator
        if op.name in evaluators:
            evaluators[op.name](op, const_value_by_tensor)

        # 3) conversion
        assert op.name in DefaultConverters, "No tf_pb_to_tf_py converter for {}".format(
            op.name)
        DefaultConverters[op.name](op, const_value_by_tensor)

    for tensor in tf_graph.tensors:
        tensor.dtype = _tf_py_dtype_by_tf_pb_dtype.get(tensor.dtype, None)

    for tensor in tf_graph.tensors:
        if not tensor.is_variable:
            continue
        name = tensor.name
        if name is None:
            tensor.label = None
        else:
            stem = name[:-2] if name.endswith(':0') else name
            tensor.label = stem.replace(':', '_')
def evaluate_shape_of_operation(op, const_value_by_tensor):
    # type: (TFOperation, typing.Dict[TFTensor, np.ndarray])->None
    """Fill in the output shapes of ``op`` when any of them is not fully known.

    If every output already has a shape with no ``None`` dimension, nothing is done.
    Otherwise shapes are taken from the constant values, when all outputs have one
    in ``const_value_by_tensor``. Failing that, they come from the op's entry in
    ``_DefaultOpShapeEvaluators``, if there is one. Afterwards every output must
    have a fully known shape that is compatible with its previous shape
    (checked with asserts).
    """

    def is_fully_known(shape):
        return shape is not None and all(dim is not None for dim in shape)

    if all(is_fully_known(output.shape) for output in op.outputs):
        return

    previous_shapes = [output.shape for output in op.outputs]

    if all(output in const_value_by_tensor for output in op.outputs):
        for output in op.outputs:
            output.shape = list(np.shape(const_value_by_tensor[output]))
    elif op.name in _DefaultOpShapeEvaluators:
        _DefaultOpShapeEvaluators[op.name](op)

    for previous_shape, output in zip(previous_shapes, op.outputs):
        assert is_fully_known(output.shape)
        assert utils.compatible_shapes(previous_shape, output.shape), \
            "{}: Evaluated shape ({}) not compatible with original shape ({})".format(
                output, output.shape, previous_shape)
# Example #5
# 0
def propagate(graph, source_shapes=None):
    # type: (ONNXGraph, typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None])->None
    """Propagate shapes and dtypes through an ONNX graph, in place.

    The graph is sorted and its input shapes are fixed with
    ``shape_fixer.fix_input_shapes``. Each operation is then run through its
    entry in ``_DefaultPropagators``, which returns ``(shapes, dtypes)``, one per
    output. The results are stored on the output tensors, and
    ``graph_utils.remove_unreachable`` prunes the graph at the end.
    """
    graph.sort()
    shape_fixer.fix_input_shapes(graph, source_shapes)

    for op in graph.operations:
        op_name = op.name
        outputs = op.outputs
        assert op_name in _DefaultPropagators, "No shape propagator for {}".format(op_name)
        shapes, dtypes = _DefaultPropagators[op_name](op)

        assert not utils.has_le_0(shapes)
        assert len(shapes) == len(dtypes) == len(outputs)

        for output, shape, dtype in zip(outputs, shapes, dtypes):
            # Information already on the tensor must agree with the propagated one.
            assert utils.compatible_shapes(output.shape, shape)
            output.shape = shape
            assert output.dtype is None or output.dtype == dtype
            output.dtype = dtype

    graph_utils.remove_unreachable(graph)
def tf_get_input_shapes(input_shape=None):
    """Return the dtype name and resolved shape of every TensorFlow placeholder.

    Placeholders come from ``tf_get_placeholders()``; their shapes are read through
    ``tf_shape_normalize``. The tensors themselves are not modified.

    ``input_shape`` may be:
      * ``None``: keep the original shapes; every ``-1`` dimension becomes 1
        (a warning is printed);
      * a list (or tuple) of ints: used as the full shape of every placeholder;
      * an int: used to replace every ``-1`` dimension of every placeholder;
      * a dict mapping tensor name to one of the above. Keys without a ``':'``
        get ``':0'`` appended. Placeholders missing from the dict behave as if
        ``None`` was given.

    Returns:
        dict mapping placeholder name -> ``(tensor.dtype.name, shape)``.

    Raises:
        utils.NNEFToolsException: if a given full shape is incompatible with the
            original shape, if a placeholder's shape is unknown and no full shape
            was given, if a shape specification has an unsupported type, or if a
            placeholder has no dtype.
    """

    def get_shape_for(name):
        # A dict holds per-tensor specs; any other value applies to every placeholder.
        # `input_shape` is read at call time, i.e. after the key normalization below.
        spec = input_shape.get(name) if isinstance(input_shape, dict) else input_shape
        if isinstance(spec, (list, tuple)):
            # Always copy, so the result never aliases the caller's list.
            return list(spec)
        if spec is not None and utils.is_anyint(spec):
            return utils.anyint_to_int(spec)
        # None (no spec) or an unsupported value, rejected by the caller below.
        return spec

    if isinstance(input_shape, dict):
        input_shape = {(k + ':0' if ':' not in k else k): v
                       for k, v in six.iteritems(input_shape)}

    # Materialized once: the placeholders are iterated several times below.
    placeholders = list(tf_get_placeholders())
    normalized_shapes = [tf_shape_normalize(tensor.shape) for tensor in placeholders]
    new_input_shapes = {}

    if input_shape is None:
        # Nothing was specified, but some input shape is unknown or partial:
        # show all input shapes so the user knows what to pass.
        if any(shape is None or -1 in shape for shape in normalized_shapes):
            for tensor, shape in zip(placeholders, normalized_shapes):
                print("Info: Input shape: {}: {}".format(tensor.name, shape))

    for tensor, tensor_shape in zip(placeholders, normalized_shapes):
        shape_for_this = get_shape_for(tensor.name) if tensor.name else None
        if isinstance(shape_for_this, list):
            # A full shape was given: it must agree with whatever is already known.
            if not utils.compatible_shapes(tensor_shape, shape_for_this):
                raise utils.NNEFToolsException(
                    "The specified shape is incompatible with the original shape for {}. {} vs. {}"
                    .format(tensor.name, shape_for_this, tensor_shape))
            tensor_shape = shape_for_this
        elif shape_for_this is None or isinstance(shape_for_this, int):
            # Only the unknown (-1) dimensions can be filled from an int (default: 1).
            if tensor_shape is None:
                raise utils.NNEFToolsException(
                    "The full shape must be specified for {}, because it is unknown."
                    .format(tensor.name))
            elif -1 in tensor_shape:
                if shape_for_this is None:
                    shape_for_this = 1
                    print(
                        "Warning: Incomplete input shape is auto-fixed: {}. {} -> {}. "
                        "Use --input-shape if other shape is desired.".format(
                            tensor.name, tensor_shape, [
                                shape_for_this if dim == -1 else dim
                                for dim in tensor_shape
                            ]))
                tensor_shape = [
                    shape_for_this if dim == -1 else dim
                    for dim in tensor_shape
                ]
        else:
            # Raise instead of `assert False`: asserts are stripped under `python -O`,
            # which would silently ignore the user's shape specification.
            raise utils.NNEFToolsException(
                "Unsupported input shape specification for {}: {!r}. "
                "Expected a list of ints, an int or None.".format(tensor.name, shape_for_this))

        if tensor.dtype is None:
            raise utils.NNEFToolsException(
                "An input tensor has incomplete dtype, "
                "we have thought that this is impossible, "
                "please file a bug report to NNEF Tools.")

        new_input_shapes[tensor.name] = (tensor.dtype.name, tensor_shape)

    return new_input_shapes