Example no. 1
0
def infer_shapes(
        graph,  # type: ONNXGraph
        source_shapes=None,
        # type: typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None]
        custom_shapes=None,  # type: typing.Optional[typing.Dict[str, typing.Callable]]
):
    # type: (...)->None
    """Infer and write shapes/dtypes for every tensor in the graph, in place.

    Builds the shape-function table from the defaults plus any user-supplied
    overrides, fixes the graph inputs from source_shapes, then walks the
    topologically sorted operations and stamps each output tensor with the
    inferred shape and dtype. Unreachable graph parts are pruned at the end.
    """
    # Default table first, so custom entries win on name collisions.
    shape_functions = dict(_DefaultShapes)
    shape_functions.update(custom_shapes or {})

    graph.sort()
    shape_fixer.fix_input_shapes(graph, source_shapes)

    for op in graph.operations:
        # Shape prop
        assert op.name in shape_functions, "No shape function for {}".format(
            op.name)
        shapes, dtypes = shape_functions[op.name](op)
        assert not utils.has_le_0(shapes)
        assert len(shapes) == len(dtypes) == len(op.outputs)
        for shape, dtype, out in zip(shapes, dtypes, op.outputs):
            assert utils.compatible_shapes(out.shape, shape)
            out.shape = shape
            # A dtype may only be filled in, never contradicted.
            assert out.dtype is None or out.dtype == dtype
            out.dtype = dtype

    graph_utils.remove_unreachable(graph)
Example no. 2
0
def evaluate_and_convert(tf_graph, source_shapes=None):
    # type: (TFGraph, typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None])->None
    """Propagate shapes, evaluate constants, and convert ops to tf_py form.

    Works on the sorted graph: per operation it propagates output
    shapes/dtypes, evaluates the op when a constant evaluator exists, and
    runs the tf_pb->tf_py converter. Afterwards the pb dtypes are mapped to
    py dtypes and variable tensors receive a filesystem-safe label.
    """
    tf_graph.sort()

    if isinstance(source_shapes, dict):
        # TF tensor names carry an output index; default bare names to ':0'.
        normalized = {}
        for name, shape in six.iteritems(source_shapes):
            if ':' not in name:
                name = name + ':0'
            normalized[name] = shape
        source_shapes = normalized
    shape_fixer.fix_input_shapes(tf_graph, source_shapes)

    # Seed the evaluation cache with everything that is known up front.
    const_value_by_tensor = {}
    for tensor in tf_graph.tensors:
        if tensor.is_constant:
            const_value_by_tensor[tensor] = tf_pb_eval._evaluate_constant(
                tensor)
        elif tensor.is_variable:
            const_value_by_tensor[tensor] = tensor.data

    for op in tf_graph.operations:
        # Shape prop
        if op.name not in tf_pb_shape_inference._DefaultPropagators:
            raise utils.NNEFToolsException(
                "Operation '{}' is not supported".format(op.name))
        propagator = tf_pb_shape_inference._DefaultPropagators[op.name]
        shapes, dtypes = propagator(op, const_value_by_tensor)
        assert not utils.has_le_0(shapes)
        assert len(shapes) == len(dtypes) == len(op.outputs)
        for shape, dtype, out in zip(shapes, dtypes, op.outputs):
            assert utils.compatible_shapes(out.shape, shape)
            out.shape = shape
            # A dtype may only be filled in, never contradicted.
            assert out.dtype is None or out.dtype == dtype
            out.dtype = dtype

        # Evaluation (best effort: only ops that have an evaluator)
        evaluator = tf_pb_eval._DefaultOpEvaluators.get(op.name)
        if evaluator is not None:
            evaluator(op, const_value_by_tensor)

        # Conversion
        assert op.name in DefaultConverters, "No tf_pb_to_tf_py converter for {}".format(
            op.name)
        DefaultConverters[op.name](op, const_value_by_tensor)

    # Translate pb dtypes to py dtypes; unknown ones become None.
    for tensor in tf_graph.tensors:
        tensor.dtype = _tf_py_dtype_by_tf_pb_dtype.get(tensor.dtype, None)

    # Derive a label for each variable from its name: drop the ':0' output
    # suffix and make any remaining ':' filesystem/identifier friendly.
    for tensor in tf_graph.tensors:
        if not tensor.is_variable:
            continue
        label = tensor.name
        if label is not None:
            if label.endswith(':0'):
                label = label[:-2]
            label = label.replace(':', '_')
        tensor.label = label
Example no. 3
0
def propagate(graph, source_shapes=None):
    # type: (ONNXGraph, typing.Union[typing.Dict[str, typing.List[int]], typing.List[int], int, None])->None
    """Run shape/dtype propagation over the graph, in place.

    Sorts the graph topologically, fixes input shapes from source_shapes,
    then applies the default propagator of each operation and stamps the
    results onto the output tensors. Unreachable parts are pruned at the end.
    """
    graph.sort()
    shape_fixer.fix_input_shapes(graph, source_shapes)

    for op in graph.operations:
        # Shape prop
        assert op.name in _DefaultPropagators, "No shape propagator for {}".format(
            op.name)
        shapes, dtypes = _DefaultPropagators[op.name](op)
        assert not utils.has_le_0(shapes)
        assert len(shapes) == len(dtypes) == len(op.outputs)
        for shape, dtype, out in zip(shapes, dtypes, op.outputs):
            assert utils.compatible_shapes(out.shape, shape)
            out.shape = shape
            # A dtype may only be filled in, never contradicted.
            assert out.dtype is None or out.dtype == dtype
            out.dtype = dtype

    graph_utils.remove_unreachable(graph)