Esempio n. 1
0
def infer_shapes(model: Union[ModelProto, bytes],
                 check_type: bool = False,
                 strict_mode: bool = False,
                 data_prop: bool = False) -> ModelProto:
    """Apply shape inference to the provided ModelProto.

    Inferred shapes are added to the value_info field of the graph.

    If the inferred values conflict with values already provided in the
    graph, that means that the provided values are invalid (or there is a
    bug in shape inference), and the result is unspecified.

    Arguments:
        model (Union[ModelProto, bytes]): The model to run shape inference
            on, either as a ModelProto or as already-serialized bytes
        check_type (bool): Checks the type-equality for input and output
        strict_mode (bool): Stricter shape inference, it will throw errors if any;
            Otherwise, simply stop if any error
        data_prop (bool): Enables data propagation for limited operators to perform shape computation

    Returns:
        (ModelProto) model with inferred shape information

    Raises:
        TypeError: If model is neither a ModelProto nor bytes. A str gets a
            dedicated message pointing at infer_shapes_path.
    """
    if isinstance(model, (ModelProto, bytes)):
        # The C binding works on the serialized form; bytes are passed
        # through untouched so they are not parsed and re-serialized here.
        model_str = model if isinstance(model,
                                        bytes) else model.SerializeToString()
        inferred_model_str = C.infer_shapes(model_str, check_type, strict_mode,
                                            data_prop)
        return onnx.load_from_string(inferred_model_str)
    elif isinstance(model, str):
        # Adjacent literals are concatenated verbatim, so the separating
        # space has to live inside the first literal.
        raise TypeError(
            'infer_shapes only accepts ModelProto or bytes, '
            'you can use infer_shapes_path for the model path (String).')
    else:
        raise TypeError('infer_shapes only accepts ModelProto or bytes, '
                        'incorrect type: {}'.format(type(model)))
Esempio n. 2
0
def infer_shapes(model, check_type=False):  # type: (ModelProto,bool) -> ModelProto
    """Run shape inference on a ModelProto and return the resulting model.

    The model is serialized, handed to C.infer_shapes together with
    check_type, and the returned string is parsed back with
    onnx.load_from_string.

    Arguments:
        model (ModelProto): the model to process
        check_type (bool): forwarded unchanged to C.infer_shapes

    Returns:
        (ModelProto) the model parsed from the output of C.infer_shapes

    Raises:
        ValueError: if model is not a ModelProto instance
    """
    if isinstance(model, ModelProto):
        serialized = model.SerializeToString()
        return onnx.load_from_string(C.infer_shapes(serialized, check_type))

    raise ValueError('Shape inference only accepts ModelProto, '
                     'incorrect type: {}'.format(type(model)))
Esempio n. 3
0
def infer_shapes(model):  # type: (ModelProto) -> ModelProto
    """Run shape inference on a ModelProto and return the resulting model.

    The model is serialized, passed through C.infer_shapes, and the
    returned string is parsed back with onnx.load_from_string.

    Arguments:
        model (ModelProto): the model to process

    Returns:
        (ModelProto) the model parsed from the output of C.infer_shapes

    Raises:
        ValueError: if model is not a ModelProto instance
    """
    is_model_proto = isinstance(model, ModelProto)
    if is_model_proto:
        serialized = model.SerializeToString()
        inferred = C.infer_shapes(serialized)
        return onnx.load_from_string(inferred)

    raise ValueError('Shape inference only accepts ModelProto, '
                     'incorrect type: {}'.format(type(model)))
Esempio n. 4
0
def infer_shapes(model, check_type=False):  # type: (ModelProto, bool) -> ModelProto
    """Run shape inference on a ModelProto and return the resulting model.

    The model is serialized, handed to C.infer_shapes together with
    check_type, and the returned string is parsed back with
    onnx.load_from_string.

    Arguments:
        model (ModelProto): the model to process
        check_type (bool): forwarded unchanged to C.infer_shapes

    Returns:
        (ModelProto) the model parsed from the output of C.infer_shapes

    Raises:
        TypeError: if model is not a ModelProto. A string gets a dedicated
            message pointing at infer_shapes_path.
    """
    if isinstance(model, ModelProto):
        model_str = model.SerializeToString()
        inferred_model_str = C.infer_shapes(model_str, check_type)
        return onnx.load_from_string(inferred_model_str)
    elif isinstance(model, string_types):
        # Adjacent literals are concatenated verbatim, so the separating
        # space has to live inside the first literal.
        raise TypeError('infer_shapes only accepts ModelProto, '
                        'you can use infer_shapes_path for the model path (String).')
    else:
        raise TypeError('infer_shapes only accepts ModelProto, '
                        'incorrect type: {}'.format(type(model)))
Esempio n. 5
0
def infer_shapes(model: Union[ModelProto, bytes],
                 check_type: bool = False,
                 strict_mode: bool = False,
                 data_prop: bool = False) -> ModelProto:
    """Run shape inference on a model and return the resulting ModelProto.

    A ModelProto is serialized first; bytes are used as-is. The serialized
    model is passed to C.infer_shapes with the three flags, and the
    returned string is parsed back with onnx.load_from_string.

    Arguments:
        model (Union[ModelProto, bytes]): the model to process, either as a
            ModelProto or as already-serialized bytes
        check_type (bool): forwarded unchanged to C.infer_shapes
        strict_mode (bool): forwarded unchanged to C.infer_shapes
        data_prop (bool): forwarded unchanged to C.infer_shapes

    Returns:
        (ModelProto) the model parsed from the output of C.infer_shapes

    Raises:
        TypeError: if model is neither a ModelProto nor bytes. A str gets a
            dedicated message pointing at infer_shapes_path.
    """
    if isinstance(model, (ModelProto, bytes)):
        # bytes are passed through untouched so they are not parsed and
        # re-serialized here.
        model_str = model if isinstance(model,
                                        bytes) else model.SerializeToString()
        inferred_model_str = C.infer_shapes(model_str, check_type, strict_mode,
                                            data_prop)
        return onnx.load_from_string(inferred_model_str)
    elif isinstance(model, str):
        # Adjacent literals are concatenated verbatim, so the separating
        # space has to live inside the first literal.
        raise TypeError(
            'infer_shapes only accepts ModelProto or bytes, '
            'you can use infer_shapes_path for the model path (String).')
    else:
        raise TypeError('infer_shapes only accepts ModelProto or bytes, '
                        'incorrect type: {}'.format(type(model)))