예제 #1
0
파일: utils.py 프로젝트: theo-char/armnn
def create_network(model_file: str, backends: list):
    """
    Creates a network based on the model file and a list of backends.

    The parser is selected from the model file extension (case-insensitive):
    ``.tflite`` -> ITfLiteParser, ``.pb`` -> ITfParser, ``.onnx`` -> IOnnxParser.

    Args:
        model_file: User-specified model file.
        backends: Backend names used to optimize the network, in order of preference.

    Returns:
        net_id: Unique ID of the network to run.
        runtime: Runtime context for executing inference.
        input_binding_info: Binding information for the model's first input tensor.
        output_binding_info: List of binding information, one entry per model output tensor.

    Raises:
        FileNotFoundError: If ``model_file`` does not exist.
        ValueError: If the model file extension is not one of the supported types.
    """
    if not os.path.exists(model_file):
        raise FileNotFoundError(f'Model file not found for: {model_file}')

    # Determine which parser to create based on model file extension.
    # An explicit exception is used instead of `assert`, which is stripped under `python -O`.
    _, ext = os.path.splitext(model_file)
    ext = ext.lower()
    if ext == '.tflite':
        parser = ann.ITfLiteParser()
    elif ext == '.pb':
        parser = ann.ITfParser()
    elif ext == '.onnx':
        parser = ann.IOnnxParser()
    else:
        raise ValueError(f'Unsupported model file extension {ext!r} for: {model_file}. '
                         f'Supported extensions are .tflite, .pb and .onnx')
    network = parser.CreateNetworkFromBinaryFile(model_file)

    # Specify backends to optimize network (order expresses preference)
    preferred_backends = [ann.BackendId(b) for b in backends]

    # Select appropriate device context and optimize the network for that device
    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)
    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())
    print(f'Preferred backends: {backends}\n{runtime.GetDeviceSpec()}\n'
          f'Optimization warnings: {messages}')

    # Load the optimized network onto the Runtime device
    net_id, _ = runtime.LoadNetwork(opt_network)

    # Get input and output binding information from the last subgraph.
    # NOTE(review): GetSubgraphCount / GetSubgraph*TensorNames and the (graph_id, name)
    # binding-info signature match the TfLite parser usage; elsewhere the ONNX parser is
    # called as GetNetworkInputBindingInfo(name). The .pb/.onnx paths may fail here — confirm.
    graph_id = parser.GetSubgraphCount() - 1
    input_names = parser.GetSubgraphInputTensorNames(graph_id)
    input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
    output_names = parser.GetSubgraphOutputTensorNames(graph_id)
    output_binding_info = [parser.GetNetworkOutputBindingInfo(graph_id, output_name)
                           for output_name in output_names]
    return net_id, runtime, input_binding_info, output_binding_info
예제 #2
0
def create_onnx_network(model_file: str, backends: list = None):
    """Creates a network from an onnx model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference, in order of
            preference. When omitted (or None), ['CpuAcc', 'CpuRef'] is used.

    Returns:
        int: Network ID.
        IOnnxParser: ONNX parser instance.
        IRuntime: Runtime object instance.
    """
    # Build the default per call instead of using a mutable default argument,
    # which would be a single list object shared across all calls.
    if backends is None:
        backends = ['CpuAcc', 'CpuRef']
    return __create_network(model_file, backends, ann.IOnnxParser())
예제 #3
0
def test_onnx_filenotfound_exception(shared_data_folder):
    """Parsing a non-existent ONNX model file must raise a RuntimeError.

    Only a stable fragment of the error text is checked, because the full
    message embeds an absolute path that differs from machine to machine.
    """
    onnx_parser = ann.IOnnxParser()
    missing_model = os.path.join(shared_data_folder, 'some_unknown_model.onnx')

    with pytest.raises(RuntimeError) as exc_info:
        onnx_parser.CreateNetworkFromBinaryFile(missing_model)

    assert 'Invalid (null) filename' in str(exc_info.value)
예제 #4
0
def parser(shared_data_folder):
    """
    Yield an ONNX parser that has already parsed the mock test model.

    The network object returned by CreateNetworkFromBinaryFile is discarded;
    only the parser itself is handed to the tests.
    """
    onnx_parser = ann.IOnnxParser()
    onnx_parser.CreateNetworkFromBinaryFile(
        os.path.join(shared_data_folder, 'mock_model.onnx'))
    yield onnx_parser
예제 #5
0
def __create_network(model_file: str, backends: list, parser=None):
    """Creates a network based on a file and parser type.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference, in order of preference.
        parser: Parser instance (pyarmnn.ITfLiteParser/pyarmnn.IOnnxParser...). When None, the
            parser is chosen from the model file extension (case-insensitive):
            ``.onnx`` -> IOnnxParser, ``.tflite`` -> ITfLiteParser.

    Returns:
        int: Network ID.
        IParser: The parser instance used to create the network.
        IRuntime: Runtime object instance.

    Raises:
        ValueError: If ``parser`` is None and the model file extension is not supported.
    """
    # Command-line options are only consulted for the `verbose` flag below.
    args = parse_command_line()
    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)

    if parser is None:
        # Try to determine what parser to create based on the model extension.
        # Raise explicitly: an `assert` would be stripped under `python -O`.
        _, ext = os.path.splitext(model_file)
        ext = ext.lower()
        if ext == ".onnx":
            parser = ann.IOnnxParser()
        elif ext == ".tflite":
            parser = ann.ITfLiteParser()
        else:
            raise ValueError(f"Cannot infer a parser for model file extension {ext!r} "
                             f"({model_file}); supported extensions are .onnx and .tflite")

    network = parser.CreateNetworkFromBinaryFile(model_file)

    preferred_backends = [ann.BackendId(b) for b in backends]

    opt_network, messages = ann.Optimize(network, preferred_backends,
                                         runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())
    # Optimizer and load messages are surfaced as warnings only in verbose mode.
    if args.verbose:
        for m in messages:
            warnings.warn(m)

    net_id, load_warnings = runtime.LoadNetwork(opt_network)
    if args.verbose and load_warnings:
        warnings.warn(load_warnings)

    return net_id, parser, runtime
예제 #6
0
def test_onnx_parser_end_to_end(shared_data_folder):
    """Parse the mock ONNX model, run one inference and compare against the golden output."""
    # Plain assignment only: chaining `= ann.IOnnxParser =` here would also overwrite the
    # module attribute `ann.IOnnxParser` with this instance, breaking later constructor calls.
    parser = ann.IOnnxParser()

    network = parser.CreateNetworkFromBinaryFile(
        os.path.join(shared_data_folder, 'mock_model.onnx'))

    # Binding info for the input tensor named "input", plus the test data stored in input_onnx.npy
    input_binding_info = parser.GetNetworkInputBindingInfo("input")
    input_tensor_data = np.load(
        os.path.join(shared_data_folder,
                     'onnx_parser/input_onnx.npy')).astype(np.float32)

    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)

    preferred_backends = [ann.BackendId('CpuAcc'), ann.BackendId('CpuRef')]
    opt_network, messages = ann.Optimize(network, preferred_backends,
                                         runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())

    # Optimization must succeed without any messages.
    assert 0 == len(messages)

    net_id, messages = runtime.LoadNetwork(opt_network)

    # Loading must succeed with an empty message string.
    assert "" == messages

    input_tensors = ann.make_input_tensors([input_binding_info],
                                           [input_tensor_data])
    output_tensors = ann.make_output_tensors(
        [parser.GetNetworkOutputBindingInfo("output")])

    runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

    output = ann.workload_tensors_to_ndarray(output_tensors)

    # Load golden output file for result comparison.
    golden_output = np.load(
        os.path.join(shared_data_folder, 'onnx_parser/golden_output_onnx.npy'))

    # Check that output matches golden output to 4 decimal places (there are slight rounding differences after this)
    np.testing.assert_almost_equal(output[0], golden_output, decimal=4)