示例#1
0
def infer_shapes(
        graph,  # type: ONNXGraph
        source_shapes=None,
        # type: typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None]
        custom_shapes=None,  # type: typing.Optional[typing.Dict[str, typing.Callable]]
):
    # type: (...)->None
    """Infer the shape and dtype of every operation output in ``graph``, in place.

    The graph is sorted, its input shapes are fixed by
    ``shape_fixer.fix_input_shapes(graph, source_shapes)``, and then each
    operation's shape function is run in graph order. Each shape function is
    called with the operation and returns a ``(shapes, dtypes)`` pair, one
    entry per output tensor; those values are written onto the outputs.
    Finally ``graph_utils.remove_unreachable(graph)`` is called.

    Args:
        graph: The graph to annotate; modified in place.
        source_shapes: Input shape specification, forwarded unchanged to
            ``shape_fixer.fix_input_shapes``.
        custom_shapes: Optional mapping from operation name to shape function.
            Its entries override those taken from ``_DefaultShapes``.

    Raises:
        AssertionError: If an operation has no shape function, if
            ``utils.has_le_0`` flags the inferred shapes (presumably a
            dimension <= 0), if the number of results does not match the
            number of outputs, or if a result conflicts with an output's
            existing shape/dtype. These checks are skipped under ``python -O``.
    """
    functions_by_op_name = dict(_DefaultShapes)
    if custom_shapes:
        # User-supplied shape functions take precedence over the defaults.
        functions_by_op_name.update(custom_shapes)

    graph.sort()
    shape_fixer.fix_input_shapes(graph, source_shapes)

    for operation in graph.operations:
        op_name = operation.name
        assert op_name in functions_by_op_name, "No shape function for {}".format(op_name)

        shapes, dtypes = functions_by_op_name[op_name](operation)
        assert not utils.has_le_0(shapes)
        assert len(shapes) == len(dtypes) == len(operation.outputs)

        for output, shape, dtype in zip(operation.outputs, shapes, dtypes):
            # A result must agree with whatever the tensor already records;
            # a dtype of None is treated as "not yet known".
            assert utils.compatible_shapes(output.shape, shape)
            output.shape = shape
            assert output.dtype is None or output.dtype == dtype
            output.dtype = dtype

    graph_utils.remove_unreachable(graph)
示例#2
0
def evaluate_and_convert(tf_graph, source_shapes=None):
    # type: (TFGraph, typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None])->None
    """Propagate shapes, evaluate constants and convert the ops of ``tf_graph`` in place.

    Steps:
      1. Sort the graph. If ``source_shapes`` is a dict, append ``':0'`` to
         every key that has no ``':'``. Then fix the input shapes with
         ``shape_fixer.fix_input_shapes``.
      2. Seed a tensor -> value table: constant tensors are evaluated with
         ``tf_pb_eval._evaluate_constant`` and variable tensors contribute
         their ``data``.
      3. For each operation in graph order:
         - run its shape propagator and write the resulting shapes/dtypes
           onto its outputs;
         - run its evaluator, if one is registered;
         - run its ``DefaultConverters`` entry.
      4. Map every tensor dtype through ``_tf_py_dtype_by_tf_pb_dtype``
         (unmapped dtypes become None). Then label every variable tensor
         with its name, with a trailing ``':0'`` dropped and the remaining
         ``':'`` replaced by ``'_'``.

    Raises:
        utils.NNEFToolsException: If an operation has no shape propagator.
        AssertionError: On inconsistent propagation results or a missing
            converter (skipped under ``python -O``).
    """
    tf_graph.sort()

    if isinstance(source_shapes, dict):
        # Tensor names carry an output index (e.g. "x:0"); accept bare names too.
        normalized_shapes = {}
        for name, shape in six.iteritems(source_shapes):
            normalized_shapes[name if ':' in name else name + ':0'] = shape
        source_shapes = normalized_shapes
    shape_fixer.fix_input_shapes(tf_graph, source_shapes)

    const_value_by_tensor = {}
    for tensor in tf_graph.tensors:
        if tensor.is_constant:
            const_value_by_tensor[tensor] = tf_pb_eval._evaluate_constant(tensor)
        elif tensor.is_variable:
            const_value_by_tensor[tensor] = tensor.data

    propagators = tf_pb_shape_inference._DefaultPropagators
    evaluators = tf_pb_eval._DefaultOpEvaluators

    for op in tf_graph.operations:
        # Shape propagation
        if op.name not in propagators:
            raise utils.NNEFToolsException(
                "Operation '{}' is not supported".format(op.name))
        shapes, dtypes = propagators[op.name](op, const_value_by_tensor)
        assert not utils.has_le_0(shapes)
        assert len(shapes) == len(dtypes) == len(op.outputs)
        for output, shape, dtype in zip(op.outputs, shapes, dtypes):
            assert utils.compatible_shapes(output.shape, shape)
            output.shape = shape
            assert output.dtype is None or output.dtype == dtype
            output.dtype = dtype

        # Evaluation, only for ops with a registered evaluator
        if op.name in evaluators:
            evaluators[op.name](op, const_value_by_tensor)

        # Conversion
        assert op.name in DefaultConverters, "No tf_pb_to_tf_py converter for {}".format(
            op.name)
        DefaultConverters[op.name](op, const_value_by_tensor)

    def _label_from_name(name):
        # Drop a trailing ':0' and make any remaining ':' safe as '_'.
        if name is None:
            return None
        if name.endswith(':0'):
            name = name[:-2]
        return name.replace(':', '_')

    for tensor in tf_graph.tensors:
        tensor.dtype = _tf_py_dtype_by_tf_pb_dtype.get(tensor.dtype, None)

    for tensor in tf_graph.tensors:
        if tensor.is_variable:
            tensor.label = _label_from_name(tensor.name)
示例#3
0
def propagate(graph, source_shapes=None, source_dtypes=None):
    # type: (ONNXGraph, typing.Dict[str, typing.List[int]], typing.Dict[str, str])->None
    """Propagate shapes and dtypes through ``graph`` in place.

    First, for every tensor:
      - a shape or dtype given for its name in ``source_shapes`` /
        ``source_dtypes`` replaces the recorded one, after a compatibility
        check;
      - every tensor without producers must end up with a known shape (not
        None, no -1 entries) and a known dtype.
    Then each operation's propagator from ``_DefaultPropagators`` is run in
    graph order, and its ``(shapes, dtypes)`` result is written onto the
    operation's outputs.

    Args:
        graph: The graph to annotate; modified in place.
        source_shapes: Optional mapping from tensor name to shape.
        source_dtypes: Optional mapping from tensor name to dtype.

    Raises:
        utils.NNEFToolsException: If an override conflicts with the recorded
            shape/dtype, or a source tensor is still incomplete.
        AssertionError: If an operation has no propagator or propagation
            results are inconsistent (skipped under ``python -O``).
    """
    source_shapes = {} if source_shapes is None else source_shapes
    source_dtypes = {} if source_dtypes is None else source_dtypes

    graph.sort()

    for tensor in graph.tensors:
        name = tensor.name

        if name and name in source_shapes:
            wanted_shape = source_shapes[name]
            if not _is_compatible(tensor.shape, wanted_shape):
                raise utils.NNEFToolsException(
                    "The specified shape is incompatible with the original shape for {}. {} vs. {}"
                    .format(name, wanted_shape, tensor.shape))
            tensor.shape = wanted_shape

        if name and name in source_dtypes:
            wanted_dtype = source_dtypes[name]
            dtype_ok = tensor.dtype is None or tensor.dtype == wanted_dtype
            if not dtype_ok:
                raise utils.NNEFToolsException(
                    "The specified dtype is incompatible with the original dtype for {}. {} vs. {}"
                    .format(name, wanted_dtype, tensor.dtype))
            tensor.dtype = wanted_dtype

        # Source tensors (no producers) cannot be inferred, so they must be fully specified.
        if len(tensor.producers) == 0:
            shape_unknown = tensor.shape is None or -1 in tensor.shape
            if shape_unknown or tensor.dtype is None:
                raise utils.NNEFToolsException(
                    "Source tensor '{}' has incomplete dtype or shape: {} {}\n"
                    "Please specify it in --input-shape or through the corresponding API."
                    .format(name, tensor.dtype, tensor.shape))

    for op in graph.operations:
        assert op.name in _DefaultPropagators, "No shape propagator for {}".format(op.name)
        shapes, dtypes = _DefaultPropagators[op.name](op)
        assert not utils.has_le_0(shapes)
        assert len(shapes) == len(dtypes) == len(op.outputs)
        for output, shape, dtype in zip(op.outputs, shapes, dtypes):
            assert _is_compatible(output.shape, shape)
            output.shape = shape
            assert output.dtype is None or output.dtype == dtype
            output.dtype = dtype
示例#4
0
def propagate(graph, source_shapes=None):
    # type: (ONNXGraph, typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None])->None
    """Propagate shapes and dtypes through ``graph`` in place, then prune it.

    The graph is sorted and its input shapes are fixed by
    ``shape_fixer.fix_input_shapes(graph, source_shapes)``. Then each
    operation's propagator from ``_DefaultPropagators`` is run in graph
    order, and its ``(shapes, dtypes)`` result is written onto the
    operation's outputs. Finally ``graph_utils.remove_unreachable(graph)``
    is called.

    Raises:
        AssertionError: If an operation has no propagator or the results are
            inconsistent with the outputs (skipped under ``python -O``).
    """
    def _store_results(outputs, shapes, dtypes):
        # Each result must agree with what the tensor already records;
        # a dtype of None is treated as "not yet known".
        for shape, dtype, tensor in zip(shapes, dtypes, outputs):
            assert utils.compatible_shapes(tensor.shape, shape)
            tensor.shape = shape
            assert tensor.dtype is None or tensor.dtype == dtype
            tensor.dtype = dtype

    graph.sort()
    shape_fixer.fix_input_shapes(graph, source_shapes)

    for op in graph.operations:
        assert op.name in _DefaultPropagators, "No shape propagator for {}".format(op.name)
        shapes, dtypes = _DefaultPropagators[op.name](op)
        assert not utils.has_le_0(shapes)
        assert len(shapes) == len(dtypes) == len(op.outputs)
        _store_results(op.outputs, shapes, dtypes)

    graph_utils.remove_unreachable(graph)
示例#5
0
def evaluate_and_convert(tf_graph, source_shapes=None, source_dtypes=None):
    # type: (TFGraph, typing.Dict[str, typing.List[int]], typing.Dict[str, str])->None
    """Propagate shapes, evaluate constants and convert the ops of ``tf_graph`` in place.

    Steps:
      1. Append ``':0'`` to every ``source_shapes`` / ``source_dtypes`` key
         that has no ``':'``, then sort the graph.
      2. For every tensor:
         - map its dtype through ``_tf_py_dtype_by_tf_pb_dtype`` (unmapped
           dtypes become None);
         - apply any shape/dtype override after a compatibility check;
         - require tensors without producers to have a complete shape and a
           dtype.
      3. Seed a tensor -> value table from constant tensors (via
         ``tf_pb_eval._evaluate_constant``) and variable tensors (their
         ``data``).
      4. For each operation in graph order:
         - run its shape propagator;
         - run its evaluator, if one is registered;
         - run its ``_DefaultConverters`` entry.
      5. Label every variable tensor with its name, with a trailing ``':0'``
         dropped and the remaining ``':'`` replaced by ``'_'``.

    Raises:
        utils.NNEFToolsException: If an override conflicts with the recorded
            shape/dtype, or a source tensor is still incomplete.
        AssertionError: If an operation has no propagator or converter, or
            propagation results are inconsistent (skipped under ``python -O``).
    """
    def _with_output_index(mapping):
        # Tensor names carry an output index (e.g. "x:0"); accept bare names too.
        return {(key if ':' in key else key + ':0'): value
                for key, value in six.iteritems(mapping)}

    source_shapes = _with_output_index({} if source_shapes is None else source_shapes)
    source_dtypes = _with_output_index({} if source_dtypes is None else source_dtypes)

    tf_graph.sort()

    for tensor in tf_graph.tensors:
        tensor.dtype = _tf_py_dtype_by_tf_pb_dtype.get(tensor.dtype, None)
        name = tensor.name

        if name and name in source_shapes:
            wanted_shape = source_shapes[name]
            if not _is_compatible(tensor.shape, wanted_shape):
                raise utils.NNEFToolsException(
                    "The specified shape is incompatible with the original shape for {}. {} vs. {}"
                    .format(name, wanted_shape, tensor.shape))
            tensor.shape = wanted_shape

        if name and name in source_dtypes:
            wanted_dtype = source_dtypes[name]
            dtype_ok = tensor.dtype is None or tensor.dtype == wanted_dtype
            if not dtype_ok:
                raise utils.NNEFToolsException(
                    "The specified dtype is incompatible with the original dtype for {}. {} vs. {}"
                    .format(name, wanted_dtype, tensor.dtype))
            tensor.dtype = wanted_dtype

        # Source tensors (no producers) cannot be inferred, so they must be fully specified.
        if len(tensor.producers) == 0:
            shape_unknown = tensor.shape is None or -1 in tensor.shape
            if shape_unknown or tensor.dtype is None:
                raise utils.NNEFToolsException(
                    "Source tensor '{}' has incomplete dtype or shape: {} {}\n"
                    "Please specify it in --input-shape or through the corresponding API."
                    .format(name, tensor.dtype, tensor.shape))

    const_value_by_tensor = {}
    for tensor in tf_graph.tensors:
        if tensor.is_constant:
            const_value_by_tensor[tensor] = tf_pb_eval._evaluate_constant(tensor)
        elif tensor.is_variable:
            const_value_by_tensor[tensor] = tensor.data

    propagators = tf_pb_shape_inference._DefaultPropagators
    evaluators = tf_pb_eval._DefaultOpEvaluators

    for op in tf_graph.operations:
        # Shape propagation
        assert op.name in propagators, "No shape propagator for {}".format(op.name)
        shapes, dtypes = propagators[op.name](op, const_value_by_tensor)
        assert not utils.has_le_0(shapes)
        assert len(shapes) == len(dtypes) == len(op.outputs)
        for output, shape, dtype in zip(op.outputs, shapes, dtypes):
            assert _is_compatible(output.shape, shape)
            output.shape = shape
            # NOTE(review): unlike the shape, the propagated dtype is only checked
            # here, never stored, so outputs whose dtype is None stay None. Other
            # variants of this loop also assign it; confirm whether that is
            # intended, given that tensor dtypes were already mapped through
            # _tf_py_dtype_by_tf_pb_dtype above.
            assert output.dtype is None or output.dtype == dtype

        # Evaluation, only for ops with a registered evaluator
        if op.name in evaluators:
            evaluators[op.name](op, const_value_by_tensor)

        # Conversion
        assert op.name in _DefaultConverters, "No tf_pb_to_tf_py converter for {}".format(
            op.name)
        _DefaultConverters[op.name](op, const_value_by_tensor)

    def _label_from_name(label):
        # Drop a trailing ':0' and make any remaining ':' safe as '_'.
        if label is None:
            return None
        if label.endswith(':0'):
            label = label[:-2]
        return label.replace(':', '_')

    for tensor in tf_graph.tensors:
        if tensor.is_variable:
            tensor.label = _label_from_name(tensor.name)